1use crate::{
2 Layer, Service,
3 error::{BoxError, OpaqueError},
4 extensions::{Extensions, ExtensionsMut, ExtensionsRef},
5 http::client::proxy::layer::{
6 HttpProxyConnector, HttpProxyConnectorLayer, MaybeHttpProxiedConnection,
7 },
8 net::{
9 Protocol,
10 address::ProxyAddress,
11 client::{ConnectorService, EstablishedClientConnection},
12 transport::TryRefIntoTransportContext,
13 },
14 proxy::socks5::{Socks5ProxyConnector, Socks5ProxyConnectorLayer},
15 stream::Stream,
16 telemetry::tracing,
17};
18use pin_project_lite::pin_project;
19use std::{
20 fmt::Debug,
21 pin::Pin,
22 sync::Arc,
23 task::{self, Poll},
24};
25use tokio::io::{AsyncRead, AsyncWrite};
26
/// Connector that chooses one of three paths based on the [`ProxyAddress`] found in
/// the input's extensions:
/// - a SOCKS5 proxy connector,
/// - an HTTP proxy connector, or
/// - the inner connector (a direct connection).
///
/// Instances are built with `required` or `optional`. Both wrap the inner connector
/// in an [`Arc`] so that all three paths share it.
pub struct ProxyConnector<S> {
    /// Connector used when no proxy address is present.
    inner: S,
    /// Connector used when the proxy protocol is SOCKS5.
    socks: Socks5ProxyConnector<S>,
    /// Connector used when the proxy protocol is HTTP. It is also used when the
    /// proxy address carries no protocol at all.
    http: HttpProxyConnector<S>,
    /// When `true`, a missing proxy address is an error instead of a direct connection.
    required: bool,
}
37
38impl<S> ProxyConnector<S> {
39 fn new(
41 inner: S,
42 socks_proxy_layer: Socks5ProxyConnectorLayer,
43 http_proxy_layer: HttpProxyConnectorLayer,
44 required: bool,
45 ) -> ProxyConnector<Arc<S>> {
46 let inner = Arc::new(inner);
47 ProxyConnector {
48 socks: socks_proxy_layer.into_layer(inner.clone()),
49 http: http_proxy_layer.into_layer(inner.clone()),
50 inner,
51 required,
52 }
53 }
54
55 #[inline]
56 pub fn required(
60 inner: S,
61 socks_proxy_layer: Socks5ProxyConnectorLayer,
62 http_proxy_layer: HttpProxyConnectorLayer,
63 ) -> ProxyConnector<Arc<S>> {
64 Self::new(inner, socks_proxy_layer, http_proxy_layer, true)
65 }
66
67 #[inline]
68 pub fn optional(
72 inner: S,
73 socks_proxy_layer: Socks5ProxyConnectorLayer,
74 http_proxy_layer: HttpProxyConnectorLayer,
75 ) -> ProxyConnector<Arc<S>> {
76 Self::new(inner, socks_proxy_layer, http_proxy_layer, false)
77 }
78}
79
80impl<Input, S> Service<Input> for ProxyConnector<S>
81where
82 S: ConnectorService<Input, Connection: Stream + Unpin>,
83 Input: TryRefIntoTransportContext<Error: Into<BoxError> + Send + 'static>
84 + Send
85 + ExtensionsMut
86 + 'static,
87{
88 type Output = EstablishedClientConnection<MaybeProxiedConnection<S::Connection>, Input>;
89 type Error = BoxError;
90
91 async fn serve(&self, input: Input) -> Result<Self::Output, Self::Error> {
92 let proxy = input.extensions().get::<ProxyAddress>();
93
94 match proxy {
95 None => {
96 if self.required {
97 return Err("proxy required but none is defined".into());
98 }
99 tracing::trace!("no proxy detected in ctx, using inner connector");
100 let EstablishedClientConnection { input, conn } =
101 self.inner.connect(input).await.map_err(Into::into)?;
102
103 let conn = MaybeProxiedConnection::direct(conn);
104 Ok(EstablishedClientConnection { input, conn })
105 }
106 Some(proxy) => {
107 let protocol = proxy.protocol.as_ref();
108 tracing::trace!(?protocol, "proxy detected in ctx");
109
110 let protocol = protocol.unwrap_or_else(|| {
111 tracing::trace!("no protocol detected, using http as protocol");
112 &Protocol::HTTP
113 });
114
115 if protocol.is_socks5() {
116 tracing::trace!("using socks proxy connector");
117 let EstablishedClientConnection { input, conn } =
118 self.socks.connect(input).await?;
119
120 let conn = MaybeProxiedConnection::socks(conn);
121 Ok(EstablishedClientConnection { input, conn })
122 } else if protocol.is_http() {
123 tracing::trace!("using http proxy connector");
124 let EstablishedClientConnection { input, conn } =
125 self.http.connect(input).await?;
126
127 let conn = MaybeProxiedConnection::http(conn);
128 Ok(EstablishedClientConnection { input, conn })
129 } else {
130 Err(OpaqueError::from_display(format!(
131 "received unsupport proxy protocol {protocol:?}"
132 ))
133 .into_boxed())
134 }
135 }
136 }
137 }
138}
139
pin_project! {
    /// Connection returned by [`ProxyConnector`]. It is one of:
    /// - a direct connection,
    /// - a connection established through the SOCKS5 proxy connector, or
    /// - a connection established through the HTTP proxy connector.
    ///
    /// `AsyncRead`/`AsyncWrite` are implemented by forwarding to the wrapped
    /// connection.
    pub struct MaybeProxiedConnection<S> {
        // Records which path produced the connection. It is pinned so that I/O
        // calls can be forwarded through a pinned projection.
        #[pin]
        inner: Connection<S>,
    }
}
147
148impl<S: ExtensionsMut> MaybeProxiedConnection<S> {
149 pub fn direct(conn: S) -> Self {
150 Self {
151 inner: Connection::Direct { conn },
152 }
153 }
154
155 pub fn socks(conn: S) -> Self {
156 Self {
157 inner: Connection::Socks { conn },
158 }
159 }
160
161 pub fn http(conn: MaybeHttpProxiedConnection<S>) -> Self {
162 Self {
163 inner: Connection::Http { conn },
164 }
165 }
166}
167
168impl<S: Debug> Debug for MaybeProxiedConnection<S> {
169 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
170 f.debug_struct("MaybeProxiedConnection")
171 .field("inner", &self.inner)
172 .finish()
173 }
174}
175
176impl<S: ExtensionsRef> ExtensionsRef for MaybeProxiedConnection<S> {
177 fn extensions(&self) -> &Extensions {
178 match &self.inner {
179 Connection::Direct { conn } | Connection::Socks { conn } => conn.extensions(),
180 Connection::Http { conn } => conn.extensions(),
181 }
182 }
183}
184
185impl<S: ExtensionsMut> ExtensionsMut for MaybeProxiedConnection<S> {
186 fn extensions_mut(&mut self) -> &mut Extensions {
187 match &mut self.inner {
188 Connection::Direct { conn } | Connection::Socks { conn } => conn.extensions_mut(),
189 Connection::Http { conn } => conn.extensions_mut(),
190 }
191 }
192}
193
pin_project! {
    // The path a `MaybeProxiedConnection` was established through.
    // `ConnectionProj` is the generated pinned projection used by the I/O impls.
    #[project = ConnectionProj]
    enum Connection<S> {
        // Connected directly by the inner connector (no proxy).
        Direct{ #[pin] conn: S },
        // Connected through the SOCKS5 proxy connector.
        Socks{ #[pin] conn: S },
        // Connected through the HTTP proxy connector; holds its
        // `MaybeHttpProxiedConnection` wrapper rather than the bare stream.
        Http{ #[pin] conn: MaybeHttpProxiedConnection<S> },
    }
}
203
204impl<S: Debug> Debug for Connection<S> {
205 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
206 match self {
207 Self::Direct { conn } => f.debug_struct("Direct").field("conn", conn).finish(),
208 Self::Socks { conn } => f.debug_struct("Socks").field("conn", conn).finish(),
209 Self::Http { conn } => f.debug_struct("Http").field("conn", conn).finish(),
210 }
211 }
212}
213
214#[warn(clippy::missing_trait_methods)]
215impl<Conn: AsyncWrite> AsyncWrite for MaybeProxiedConnection<Conn> {
216 fn poll_write(
217 self: Pin<&mut Self>,
218 cx: &mut task::Context<'_>,
219 buf: &[u8],
220 ) -> Poll<Result<usize, std::io::Error>> {
221 match self.project().inner.project() {
222 ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => {
223 conn.poll_write(cx, buf)
224 }
225 ConnectionProj::Http { conn } => conn.poll_write(cx, buf),
226 }
227 }
228
229 fn poll_flush(
230 self: Pin<&mut Self>,
231 cx: &mut task::Context<'_>,
232 ) -> Poll<Result<(), std::io::Error>> {
233 match self.project().inner.project() {
234 ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => conn.poll_flush(cx),
235 ConnectionProj::Http { conn } => conn.poll_flush(cx),
236 }
237 }
238
239 fn poll_shutdown(
240 self: Pin<&mut Self>,
241 cx: &mut task::Context<'_>,
242 ) -> Poll<Result<(), std::io::Error>> {
243 match self.project().inner.project() {
244 ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => {
245 conn.poll_shutdown(cx)
246 }
247 ConnectionProj::Http { conn } => conn.poll_shutdown(cx),
248 }
249 }
250
251 fn is_write_vectored(&self) -> bool {
252 match &self.inner {
253 Connection::Direct { conn } | Connection::Socks { conn } => conn.is_write_vectored(),
254 Connection::Http { conn } => conn.is_write_vectored(),
255 }
256 }
257
258 fn poll_write_vectored(
259 self: Pin<&mut Self>,
260 cx: &mut task::Context<'_>,
261 bufs: &[std::io::IoSlice<'_>],
262 ) -> Poll<Result<usize, std::io::Error>> {
263 match self.project().inner.project() {
264 ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => {
265 conn.poll_write_vectored(cx, bufs)
266 }
267 ConnectionProj::Http { conn } => conn.poll_write_vectored(cx, bufs),
268 }
269 }
270}
271
272#[warn(clippy::missing_trait_methods)]
273impl<Conn: AsyncRead> AsyncRead for MaybeProxiedConnection<Conn> {
274 fn poll_read(
275 self: Pin<&mut Self>,
276 cx: &mut task::Context<'_>,
277 buf: &mut tokio::io::ReadBuf<'_>,
278 ) -> Poll<std::io::Result<()>> {
279 match self.project().inner.project() {
280 ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => {
281 conn.poll_read(cx, buf)
282 }
283 ConnectionProj::Http { conn } => conn.poll_read(cx, buf),
284 }
285 }
286}
287
/// [`Layer`] that wraps an inner connector into a [`ProxyConnector`], using the
/// stored SOCKS5 and HTTP proxy connector layers.
pub struct ProxyConnectorLayer {
    /// Layer used to build the SOCKS5 proxy connector.
    socks_layer: Socks5ProxyConnectorLayer,
    /// Layer used to build the HTTP proxy connector.
    http_layer: HttpProxyConnectorLayer,
    /// Whether the produced connector requires a proxy address to be present.
    required: bool,
}
297
298impl ProxyConnectorLayer {
299 #[must_use]
300 pub fn required(
304 socks_proxy_layer: Socks5ProxyConnectorLayer,
305 http_proxy_layer: HttpProxyConnectorLayer,
306 ) -> Self {
307 Self {
308 socks_layer: socks_proxy_layer,
309 http_layer: http_proxy_layer,
310 required: true,
311 }
312 }
313
314 #[must_use]
315 pub fn optional(
319 socks_proxy_layer: Socks5ProxyConnectorLayer,
320 http_proxy_layer: HttpProxyConnectorLayer,
321 ) -> Self {
322 Self {
323 socks_layer: socks_proxy_layer,
324 http_layer: http_proxy_layer,
325 required: false,
326 }
327 }
328}
329
330impl<S> Layer<S> for ProxyConnectorLayer {
331 type Service = ProxyConnector<Arc<S>>;
332
333 fn layer(&self, inner: S) -> Self::Service {
334 ProxyConnector::new(
335 inner,
336 self.socks_layer.clone(),
337 self.http_layer.clone(),
338 self.required,
339 )
340 }
341
342 fn into_layer(self, inner: S) -> Self::Service {
343 ProxyConnector::new(inner, self.socks_layer, self.http_layer, self.required)
344 }
345}