rama/http/client/
proxy_connector.rs

1use crate::{
2    Layer, Service,
3    error::{BoxError, OpaqueError},
4    extensions::{Extensions, ExtensionsMut, ExtensionsRef},
5    http::client::proxy::layer::{
6        HttpProxyConnector, HttpProxyConnectorLayer, MaybeHttpProxiedConnection,
7    },
8    net::{
9        Protocol,
10        address::ProxyAddress,
11        client::{ConnectorService, EstablishedClientConnection},
12        transport::TryRefIntoTransportContext,
13    },
14    proxy::socks5::{Socks5ProxyConnector, Socks5ProxyConnectorLayer},
15    stream::Stream,
16    telemetry::tracing,
17};
18use pin_project_lite::pin_project;
19use std::{
20    fmt::Debug,
21    pin::Pin,
22    sync::Arc,
23    task::{self, Poll},
24};
25use tokio::io::{AsyncRead, AsyncWrite};
26
/// Proxy connector which supports http(s) and socks5(h) proxy address
///
/// Connector will look at [`ProxyAddress`] to determine which proxy
/// connector to use if one is configured
pub struct ProxyConnector<S> {
    // Connector used as-is when the request carries no [`ProxyAddress`]
    // (only reachable when `required` is `false`).
    inner: S,
    // Connector used when the proxy protocol reports itself as socks5.
    socks: Socks5ProxyConnector<S>,
    // Connector used when the proxy protocol reports itself as http;
    // also the one picked when the proxy address carries no protocol at all.
    http: HttpProxyConnector<S>,
    // When `true`, a request without a [`ProxyAddress`] is rejected with an
    // error instead of being forwarded to `inner`.
    required: bool,
}
37
38impl<S> ProxyConnector<S> {
39    /// Creates a new [`ProxyConnector`].
40    fn new(
41        inner: S,
42        socks_proxy_layer: Socks5ProxyConnectorLayer,
43        http_proxy_layer: HttpProxyConnectorLayer,
44        required: bool,
45    ) -> ProxyConnector<Arc<S>> {
46        let inner = Arc::new(inner);
47        ProxyConnector {
48            socks: socks_proxy_layer.into_layer(inner.clone()),
49            http: http_proxy_layer.into_layer(inner.clone()),
50            inner,
51            required,
52        }
53    }
54
55    #[inline]
56    /// Creates a new required [`ProxyConnector`].
57    ///
58    /// This connector will fail if no [`ProxyAddress`] is configured
59    pub fn required(
60        inner: S,
61        socks_proxy_layer: Socks5ProxyConnectorLayer,
62        http_proxy_layer: HttpProxyConnectorLayer,
63    ) -> ProxyConnector<Arc<S>> {
64        Self::new(inner, socks_proxy_layer, http_proxy_layer, true)
65    }
66
67    #[inline]
68    /// Creates a new optional [`ProxyConnector`].
69    ///
70    /// This connector will forward to the inner connector if no [`ProxyAddress`] is configured
71    pub fn optional(
72        inner: S,
73        socks_proxy_layer: Socks5ProxyConnectorLayer,
74        http_proxy_layer: HttpProxyConnectorLayer,
75    ) -> ProxyConnector<Arc<S>> {
76        Self::new(inner, socks_proxy_layer, http_proxy_layer, false)
77    }
78}
79
80impl<Request, S> Service<Request> for ProxyConnector<S>
81where
82    S: ConnectorService<Request, Connection: Stream + Unpin>,
83    Request: TryRefIntoTransportContext<Error: Into<BoxError> + Send + 'static>
84        + Send
85        + ExtensionsMut
86        + 'static,
87{
88    type Response = EstablishedClientConnection<MaybeProxiedConnection<S::Connection>, Request>;
89
90    type Error = BoxError;
91
92    async fn serve(&self, req: Request) -> Result<Self::Response, Self::Error> {
93        let proxy = req.extensions().get::<ProxyAddress>();
94
95        match proxy {
96            None => {
97                if self.required {
98                    return Err("proxy required but none is defined".into());
99                }
100                tracing::trace!("no proxy detected in ctx, using inner connector");
101                let EstablishedClientConnection { req, conn } =
102                    self.inner.connect(req).await.map_err(Into::into)?;
103
104                let conn = MaybeProxiedConnection::direct(conn);
105                Ok(EstablishedClientConnection { req, conn })
106            }
107            Some(proxy) => {
108                let protocol = proxy.protocol.as_ref();
109                tracing::trace!(?protocol, "proxy detected in ctx");
110
111                let protocol = protocol.unwrap_or_else(|| {
112                    tracing::trace!("no protocol detected, using http as protocol");
113                    &Protocol::HTTP
114                });
115
116                if protocol.is_socks5() {
117                    tracing::trace!("using socks proxy connector");
118                    let EstablishedClientConnection { req, conn } = self.socks.connect(req).await?;
119
120                    let conn = MaybeProxiedConnection::socks(conn);
121                    Ok(EstablishedClientConnection { req, conn })
122                } else if protocol.is_http() {
123                    tracing::trace!("using http proxy connector");
124                    let EstablishedClientConnection { req, conn } = self.http.connect(req).await?;
125
126                    let conn = MaybeProxiedConnection::http(conn);
127                    Ok(EstablishedClientConnection { req, conn })
128                } else {
129                    Err(OpaqueError::from_display(format!(
130                        "received unsupport proxy protocol {protocol:?}"
131                    ))
132                    .into_boxed())
133                }
134            }
135        }
136    }
137}
138
pin_project! {
    /// A connection which will be proxied if a [`ProxyAddress`] was configured
    pub struct MaybeProxiedConnection<S> {
        // Which kind of connection is wrapped (direct, socks or http).
        // Pinned so the I/O trait impls can project down to the wrapped stream.
        #[pin]
        inner: Connection<S>,
    }
}
146
147impl<S: ExtensionsMut> MaybeProxiedConnection<S> {
148    pub fn direct(conn: S) -> Self {
149        Self {
150            inner: Connection::Direct { conn },
151        }
152    }
153
154    pub fn socks(conn: S) -> Self {
155        Self {
156            inner: Connection::Socks { conn },
157        }
158    }
159
160    pub fn http(conn: MaybeHttpProxiedConnection<S>) -> Self {
161        Self {
162            inner: Connection::Http { conn },
163        }
164    }
165}
166
167impl<S: Debug> Debug for MaybeProxiedConnection<S> {
168    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
169        f.debug_struct("MaybeProxiedConnection")
170            .field("inner", &self.inner)
171            .finish()
172    }
173}
174
175impl<S: ExtensionsRef> ExtensionsRef for MaybeProxiedConnection<S> {
176    fn extensions(&self) -> &Extensions {
177        match &self.inner {
178            Connection::Direct { conn } | Connection::Socks { conn } => conn.extensions(),
179            Connection::Http { conn } => conn.extensions(),
180        }
181    }
182}
183
184impl<S: ExtensionsMut> ExtensionsMut for MaybeProxiedConnection<S> {
185    fn extensions_mut(&mut self) -> &mut Extensions {
186        match &mut self.inner {
187            Connection::Direct { conn } | Connection::Socks { conn } => conn.extensions_mut(),
188            Connection::Http { conn } => conn.extensions_mut(),
189        }
190    }
191}
192
// Internal representation of the connection kinds wrapped by `MaybeProxiedConnection`.
// Every payload is pinned so the I/O trait impls can poll it through the projection.
pin_project! {
    #[project = ConnectionProj]
    enum Connection<S> {
        // Connection made with the inner connector, no proxy involved.
        Direct{ #[pin] conn: S },
        // Connection returned by the socks proxy connector.
        Socks{ #[pin] conn: S },
        // Connection returned by the http proxy connector.
        Http{ #[pin] conn: MaybeHttpProxiedConnection<S> },

    }
}
202
203impl<S: Debug> Debug for Connection<S> {
204    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
205        match self {
206            Self::Direct { conn } => f.debug_struct("Direct").field("conn", conn).finish(),
207            Self::Socks { conn } => f.debug_struct("Socks").field("conn", conn).finish(),
208            Self::Http { conn } => f.debug_struct("Http").field("conn", conn).finish(),
209        }
210    }
211}
212
213#[warn(clippy::missing_trait_methods)]
214impl<Conn: AsyncWrite> AsyncWrite for MaybeProxiedConnection<Conn> {
215    fn poll_write(
216        self: Pin<&mut Self>,
217        cx: &mut task::Context<'_>,
218        buf: &[u8],
219    ) -> Poll<Result<usize, std::io::Error>> {
220        match self.project().inner.project() {
221            ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => {
222                conn.poll_write(cx, buf)
223            }
224            ConnectionProj::Http { conn } => conn.poll_write(cx, buf),
225        }
226    }
227
228    fn poll_flush(
229        self: Pin<&mut Self>,
230        cx: &mut task::Context<'_>,
231    ) -> Poll<Result<(), std::io::Error>> {
232        match self.project().inner.project() {
233            ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => conn.poll_flush(cx),
234            ConnectionProj::Http { conn } => conn.poll_flush(cx),
235        }
236    }
237
238    fn poll_shutdown(
239        self: Pin<&mut Self>,
240        cx: &mut task::Context<'_>,
241    ) -> Poll<Result<(), std::io::Error>> {
242        match self.project().inner.project() {
243            ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => {
244                conn.poll_shutdown(cx)
245            }
246            ConnectionProj::Http { conn } => conn.poll_shutdown(cx),
247        }
248    }
249
250    fn is_write_vectored(&self) -> bool {
251        match &self.inner {
252            Connection::Direct { conn } | Connection::Socks { conn } => conn.is_write_vectored(),
253            Connection::Http { conn } => conn.is_write_vectored(),
254        }
255    }
256
257    fn poll_write_vectored(
258        self: Pin<&mut Self>,
259        cx: &mut task::Context<'_>,
260        bufs: &[std::io::IoSlice<'_>],
261    ) -> Poll<Result<usize, std::io::Error>> {
262        match self.project().inner.project() {
263            ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => {
264                conn.poll_write_vectored(cx, bufs)
265            }
266            ConnectionProj::Http { conn } => conn.poll_write_vectored(cx, bufs),
267        }
268    }
269}
270
271#[warn(clippy::missing_trait_methods)]
272impl<Conn: AsyncRead> AsyncRead for MaybeProxiedConnection<Conn> {
273    fn poll_read(
274        self: Pin<&mut Self>,
275        cx: &mut task::Context<'_>,
276        buf: &mut tokio::io::ReadBuf<'_>,
277    ) -> Poll<std::io::Result<()>> {
278        match self.project().inner.project() {
279            ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => {
280                conn.poll_read(cx, buf)
281            }
282            ConnectionProj::Http { conn } => conn.poll_read(cx, buf),
283        }
284    }
285}
286
/// Proxy connector layer which supports http(s) and socks5(h) proxy address
///
/// Connector will look at [`ProxyAddress`] to determine which proxy
/// connector to use if one is configured
pub struct ProxyConnectorLayer {
    // Layer that builds the socks5 proxy connector of the resulting service.
    socks_layer: Socks5ProxyConnectorLayer,
    // Layer that builds the http proxy connector of the resulting service.
    http_layer: HttpProxyConnectorLayer,
    // Passed on to the resulting `ProxyConnector`: `true` when created through
    // `required`, `false` when created through `optional`.
    required: bool,
}
296
297impl ProxyConnectorLayer {
298    #[must_use]
299    /// Creates a new required [`ProxyConnectorLayer`].
300    ///
301    /// This connector will fail if no [`ProxyAddress`] is configured
302    pub fn required(
303        socks_proxy_layer: Socks5ProxyConnectorLayer,
304        http_proxy_layer: HttpProxyConnectorLayer,
305    ) -> Self {
306        Self {
307            socks_layer: socks_proxy_layer,
308            http_layer: http_proxy_layer,
309            required: true,
310        }
311    }
312
313    #[must_use]
314    /// Creates a new optional [`ProxyConnectorLayer`].
315    ///
316    /// This connector will forward to the inner connector if no [`ProxyAddress`] is configured
317    pub fn optional(
318        socks_proxy_layer: Socks5ProxyConnectorLayer,
319        http_proxy_layer: HttpProxyConnectorLayer,
320    ) -> Self {
321        Self {
322            socks_layer: socks_proxy_layer,
323            http_layer: http_proxy_layer,
324            required: false,
325        }
326    }
327}
328
329impl<S> Layer<S> for ProxyConnectorLayer {
330    type Service = ProxyConnector<Arc<S>>;
331
332    fn layer(&self, inner: S) -> Self::Service {
333        ProxyConnector::new(
334            inner,
335            self.socks_layer.clone(),
336            self.http_layer.clone(),
337            self.required,
338        )
339    }
340
341    fn into_layer(self, inner: S) -> Self::Service {
342        ProxyConnector::new(inner, self.socks_layer, self.http_layer, self.required)
343    }
344}