Skip to main content

rama/http/client/
proxy_connector.rs

1use crate::{
2    Layer, Service,
3    error::BoxError,
4    extensions::{Extensions, ExtensionsMut, ExtensionsRef},
5    http::client::proxy::layer::{
6        HttpProxyConnector, HttpProxyConnectorLayer, MaybeHttpProxiedConnection,
7    },
8    net::{
9        Protocol,
10        address::ProxyAddress,
11        client::{ConnectorService, EstablishedClientConnection},
12        transport::TryRefIntoTransportContext,
13    },
14    proxy::socks5::{Socks5ProxyConnector, Socks5ProxyConnectorLayer},
15    stream::Stream,
16    telemetry::tracing,
17};
18use pin_project_lite::pin_project;
19use rama_core::error::{ErrorContext as _, ErrorExt as _};
20use std::{
21    fmt::Debug,
22    pin::Pin,
23    task::{self, Poll},
24};
25use tokio::io::{AsyncRead, AsyncWrite};
26
27/// Proxy connector which supports http(s) and socks5(h) proxy address
28///
29/// Connector will look at [`ProxyAddress`] to determine which proxy
30/// connector to use if one is configured
31#[derive(Debug, Clone)]
32pub struct ProxyConnector<S> {
33    inner: S,
34    socks: Socks5ProxyConnector<S>,
35    http: HttpProxyConnector<S>,
36    required: bool,
37}
38
39impl<S: Clone> ProxyConnector<S> {
40    /// Creates a new [`ProxyConnector`].
41    fn new(
42        inner: S,
43        socks_proxy_layer: Socks5ProxyConnectorLayer,
44        http_proxy_layer: HttpProxyConnectorLayer,
45        required: bool,
46    ) -> Self {
47        Self {
48            socks: socks_proxy_layer.into_layer(inner.clone()),
49            http: http_proxy_layer.into_layer(inner.clone()),
50            inner,
51            required,
52        }
53    }
54
55    #[inline]
56    /// Creates a new required [`ProxyConnector`].
57    ///
58    /// This connector will fail if no [`ProxyAddress`] is configured
59    pub fn required(
60        inner: S,
61        socks_proxy_layer: Socks5ProxyConnectorLayer,
62        http_proxy_layer: HttpProxyConnectorLayer,
63    ) -> Self {
64        Self::new(inner, socks_proxy_layer, http_proxy_layer, true)
65    }
66
67    #[inline]
68    /// Creates a new optional [`ProxyConnector`].
69    ///
70    /// This connector will forward to the inner connector if no [`ProxyAddress`] is configured
71    pub fn optional(
72        inner: S,
73        socks_proxy_layer: Socks5ProxyConnectorLayer,
74        http_proxy_layer: HttpProxyConnectorLayer,
75    ) -> Self {
76        Self::new(inner, socks_proxy_layer, http_proxy_layer, false)
77    }
78}
79
80impl<Input, S> Service<Input> for ProxyConnector<S>
81where
82    S: ConnectorService<Input, Connection: Stream + Unpin>,
83    Input: TryRefIntoTransportContext<Error: Into<BoxError> + Send + 'static>
84        + Send
85        + ExtensionsMut
86        + 'static,
87{
88    type Output = EstablishedClientConnection<MaybeProxiedConnection<S::Connection>, Input>;
89    type Error = BoxError;
90
91    async fn serve(&self, input: Input) -> Result<Self::Output, Self::Error> {
92        let proxy = input.extensions().get::<ProxyAddress>();
93
94        match proxy {
95            None => {
96                if self.required {
97                    return Err("proxy required but none is defined".into());
98                }
99                tracing::trace!("no proxy detected in ctx, using inner connector");
100                let EstablishedClientConnection { input, conn } =
101                    self.inner.connect(input).await.into_box_error()?;
102
103                let conn = MaybeProxiedConnection::direct(conn);
104                Ok(EstablishedClientConnection { input, conn })
105            }
106            Some(proxy) => {
107                let protocol = proxy.protocol.as_ref();
108                tracing::trace!(?protocol, "proxy detected in ctx");
109
110                let protocol = protocol.unwrap_or_else(|| {
111                    tracing::trace!("no protocol detected, using http as protocol");
112                    &Protocol::HTTP
113                });
114
115                if protocol.is_socks5() {
116                    tracing::trace!("using socks proxy connector");
117                    let EstablishedClientConnection { input, conn } =
118                        self.socks.connect(input).await?;
119
120                    let conn = MaybeProxiedConnection::socks(conn);
121                    Ok(EstablishedClientConnection { input, conn })
122                } else if protocol.is_http() {
123                    tracing::trace!("using http proxy connector");
124                    let EstablishedClientConnection { input, conn } =
125                        self.http.connect(input).await?;
126
127                    let conn = MaybeProxiedConnection::http(conn);
128                    Ok(EstablishedClientConnection { input, conn })
129                } else {
130                    Err(BoxError::from("received unsupport proxy protocol")
131                        .with_context_debug_field("protocol", || protocol.clone()))
132                }
133            }
134        }
135    }
136}
137
138pin_project! {
139    /// A connection which will be proxied if a [`ProxyAddress`] was configured
140    pub struct MaybeProxiedConnection<S> {
141        #[pin]
142        inner: Connection<S>,
143    }
144}
145
146impl<S: ExtensionsMut> MaybeProxiedConnection<S> {
147    pub fn direct(conn: S) -> Self {
148        Self {
149            inner: Connection::Direct { conn },
150        }
151    }
152
153    pub fn socks(conn: S) -> Self {
154        Self {
155            inner: Connection::Socks { conn },
156        }
157    }
158
159    pub fn http(conn: MaybeHttpProxiedConnection<S>) -> Self {
160        Self {
161            inner: Connection::Http { conn },
162        }
163    }
164}
165
166impl<S: Debug> Debug for MaybeProxiedConnection<S> {
167    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
168        f.debug_struct("MaybeProxiedConnection")
169            .field("inner", &self.inner)
170            .finish()
171    }
172}
173
174impl<S: ExtensionsRef> ExtensionsRef for MaybeProxiedConnection<S> {
175    fn extensions(&self) -> &Extensions {
176        match &self.inner {
177            Connection::Direct { conn } | Connection::Socks { conn } => conn.extensions(),
178            Connection::Http { conn } => conn.extensions(),
179        }
180    }
181}
182
183impl<S: ExtensionsMut> ExtensionsMut for MaybeProxiedConnection<S> {
184    fn extensions_mut(&mut self) -> &mut Extensions {
185        match &mut self.inner {
186            Connection::Direct { conn } | Connection::Socks { conn } => conn.extensions_mut(),
187            Connection::Http { conn } => conn.extensions_mut(),
188        }
189    }
190}
191
192pin_project! {
193    #[project = ConnectionProj]
194    enum Connection<S> {
195        Direct{ #[pin] conn: S },
196        Socks{ #[pin] conn: S },
197        Http{ #[pin] conn: MaybeHttpProxiedConnection<S> },
198
199    }
200}
201
202impl<S: Debug> Debug for Connection<S> {
203    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
204        match self {
205            Self::Direct { conn } => f.debug_struct("Direct").field("conn", conn).finish(),
206            Self::Socks { conn } => f.debug_struct("Socks").field("conn", conn).finish(),
207            Self::Http { conn } => f.debug_struct("Http").field("conn", conn).finish(),
208        }
209    }
210}
211
212#[warn(clippy::missing_trait_methods)]
213impl<Conn: AsyncWrite> AsyncWrite for MaybeProxiedConnection<Conn> {
214    fn poll_write(
215        self: Pin<&mut Self>,
216        cx: &mut task::Context<'_>,
217        buf: &[u8],
218    ) -> Poll<Result<usize, std::io::Error>> {
219        match self.project().inner.project() {
220            ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => {
221                conn.poll_write(cx, buf)
222            }
223            ConnectionProj::Http { conn } => conn.poll_write(cx, buf),
224        }
225    }
226
227    fn poll_flush(
228        self: Pin<&mut Self>,
229        cx: &mut task::Context<'_>,
230    ) -> Poll<Result<(), std::io::Error>> {
231        match self.project().inner.project() {
232            ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => conn.poll_flush(cx),
233            ConnectionProj::Http { conn } => conn.poll_flush(cx),
234        }
235    }
236
237    fn poll_shutdown(
238        self: Pin<&mut Self>,
239        cx: &mut task::Context<'_>,
240    ) -> Poll<Result<(), std::io::Error>> {
241        match self.project().inner.project() {
242            ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => {
243                conn.poll_shutdown(cx)
244            }
245            ConnectionProj::Http { conn } => conn.poll_shutdown(cx),
246        }
247    }
248
249    fn is_write_vectored(&self) -> bool {
250        match &self.inner {
251            Connection::Direct { conn } | Connection::Socks { conn } => conn.is_write_vectored(),
252            Connection::Http { conn } => conn.is_write_vectored(),
253        }
254    }
255
256    fn poll_write_vectored(
257        self: Pin<&mut Self>,
258        cx: &mut task::Context<'_>,
259        bufs: &[std::io::IoSlice<'_>],
260    ) -> Poll<Result<usize, std::io::Error>> {
261        match self.project().inner.project() {
262            ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => {
263                conn.poll_write_vectored(cx, bufs)
264            }
265            ConnectionProj::Http { conn } => conn.poll_write_vectored(cx, bufs),
266        }
267    }
268}
269
270#[warn(clippy::missing_trait_methods)]
271impl<Conn: AsyncRead> AsyncRead for MaybeProxiedConnection<Conn> {
272    fn poll_read(
273        self: Pin<&mut Self>,
274        cx: &mut task::Context<'_>,
275        buf: &mut tokio::io::ReadBuf<'_>,
276    ) -> Poll<std::io::Result<()>> {
277        match self.project().inner.project() {
278            ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => {
279                conn.poll_read(cx, buf)
280            }
281            ConnectionProj::Http { conn } => conn.poll_read(cx, buf),
282        }
283    }
284}
285
286/// Proxy connector layer which supports http(s) and socks5(h) proxy address
287///
288/// Connector will look at [`ProxyAddress`] to determine which proxy
289/// connector to use if one is configured
290pub struct ProxyConnectorLayer {
291    socks_layer: Socks5ProxyConnectorLayer,
292    http_layer: HttpProxyConnectorLayer,
293    required: bool,
294}
295
296impl ProxyConnectorLayer {
297    #[must_use]
298    /// Creates a new required [`ProxyConnectorLayer`].
299    ///
300    /// This connector will fail if no [`ProxyAddress`] is configured
301    pub fn required(
302        socks_proxy_layer: Socks5ProxyConnectorLayer,
303        http_proxy_layer: HttpProxyConnectorLayer,
304    ) -> Self {
305        Self {
306            socks_layer: socks_proxy_layer,
307            http_layer: http_proxy_layer,
308            required: true,
309        }
310    }
311
312    #[must_use]
313    /// Creates a new optional [`ProxyConnectorLayer`].
314    ///
315    /// This connector will forward to the inner connector if no [`ProxyAddress`] is configured
316    pub fn optional(
317        socks_proxy_layer: Socks5ProxyConnectorLayer,
318        http_proxy_layer: HttpProxyConnectorLayer,
319    ) -> Self {
320        Self {
321            socks_layer: socks_proxy_layer,
322            http_layer: http_proxy_layer,
323            required: false,
324        }
325    }
326}
327
328impl<S: Clone> Layer<S> for ProxyConnectorLayer {
329    type Service = ProxyConnector<S>;
330
331    fn layer(&self, inner: S) -> Self::Service {
332        ProxyConnector::new(
333            inner,
334            self.socks_layer.clone(),
335            self.http_layer.clone(),
336            self.required,
337        )
338    }
339
340    fn into_layer(self, inner: S) -> Self::Service {
341        ProxyConnector::new(inner, self.socks_layer, self.http_layer, self.required)
342    }
343}