rama/http/client/
proxy_connector.rs

1use crate::{
2    Layer, Service,
3    error::{BoxError, OpaqueError},
4    http::client::proxy::layer::{
5        HttpProxyConnector, HttpProxyConnectorLayer, MaybeHttpProxiedConnection,
6    },
7    net::{
8        Protocol,
9        address::ProxyAddress,
10        client::{ConnectorService, EstablishedClientConnection},
11        stream::Stream,
12        transport::TryRefIntoTransportContext,
13    },
14    proxy::socks5::{Socks5ProxyConnector, Socks5ProxyConnectorLayer},
15    telemetry::tracing,
16};
17use pin_project_lite::pin_project;
18use std::{
19    fmt::Debug,
20    pin::Pin,
21    sync::Arc,
22    task::{self, Poll},
23};
24use tokio::io::{AsyncRead, AsyncWrite};
25
26/// Proxy connector which supports http(s) and socks5(h) proxy address
27///
28/// Connector will look at [`ProxyAddress`] to determine which proxy
29/// connector to use if one is configured
30pub struct ProxyConnector<S> {
31    inner: S,
32    socks: Socks5ProxyConnector<S>,
33    http: HttpProxyConnector<S>,
34    required: bool,
35}
36
37impl<S> ProxyConnector<S> {
38    /// Creates a new [`ProxyConnector`].
39    fn new(
40        inner: S,
41        socks_proxy_layer: Socks5ProxyConnectorLayer,
42        http_proxy_layer: HttpProxyConnectorLayer,
43        required: bool,
44    ) -> ProxyConnector<Arc<S>> {
45        let inner = Arc::new(inner);
46        ProxyConnector {
47            socks: socks_proxy_layer.into_layer(inner.clone()),
48            http: http_proxy_layer.into_layer(inner.clone()),
49            inner,
50            required,
51        }
52    }
53
54    #[inline]
55    /// Creates a new required [`ProxyConnector`].
56    ///
57    /// This connector will fail if no [`ProxyAddress`] is configured
58    pub fn required(
59        inner: S,
60        socks_proxy_layer: Socks5ProxyConnectorLayer,
61        http_proxy_layer: HttpProxyConnectorLayer,
62    ) -> ProxyConnector<Arc<S>> {
63        Self::new(inner, socks_proxy_layer, http_proxy_layer, true)
64    }
65
66    #[inline]
67    /// Creates a new optional [`ProxyConnector`].
68    ///
69    /// This connector will forward to the inner connector if no [`ProxyAddress`] is configured
70    pub fn optional(
71        inner: S,
72        socks_proxy_layer: Socks5ProxyConnectorLayer,
73        http_proxy_layer: HttpProxyConnectorLayer,
74    ) -> ProxyConnector<Arc<S>> {
75        Self::new(inner, socks_proxy_layer, http_proxy_layer, false)
76    }
77}
78
79impl<Request, S> Service<Request> for ProxyConnector<S>
80where
81    S: ConnectorService<Request, Connection: Stream + Unpin, Error: Into<BoxError>>,
82    Request: TryRefIntoTransportContext<Error: Into<BoxError> + Send + 'static> + Send + 'static,
83{
84    type Response = EstablishedClientConnection<MaybeProxiedConnection<S::Connection>, Request>;
85
86    type Error = BoxError;
87
88    async fn serve(
89        &self,
90        ctx: rama_core::Context,
91        req: Request,
92    ) -> Result<Self::Response, Self::Error> {
93        let proxy = ctx.get::<ProxyAddress>();
94
95        match proxy {
96            None => {
97                if self.required {
98                    return Err("proxy required but none is defined".into());
99                }
100                tracing::trace!("no proxy detected in ctx, using inner connector");
101                let EstablishedClientConnection { ctx, req, conn } =
102                    self.inner.connect(ctx, req).await.map_err(Into::into)?;
103                Ok(EstablishedClientConnection {
104                    ctx,
105                    req,
106                    conn: MaybeProxiedConnection {
107                        inner: Connection::Direct { conn },
108                    },
109                })
110            }
111            Some(proxy) => {
112                let protocol = proxy.protocol.as_ref();
113                tracing::trace!(?protocol, "proxy detected in ctx");
114
115                let protocol = protocol.unwrap_or_else(|| {
116                    tracing::trace!("no protocol detected, using http as protocol");
117                    &Protocol::HTTP
118                });
119
120                if protocol.is_socks5() {
121                    tracing::trace!("using socks proxy connector");
122                    let EstablishedClientConnection { ctx, req, conn } =
123                        self.socks.connect(ctx, req).await?;
124                    Ok(EstablishedClientConnection {
125                        ctx,
126                        req,
127                        conn: MaybeProxiedConnection {
128                            inner: Connection::Socks { conn },
129                        },
130                    })
131                } else if protocol.is_http() {
132                    tracing::trace!("using http proxy connector");
133                    let EstablishedClientConnection { ctx, req, conn } =
134                        self.http.connect(ctx, req).await?;
135                    Ok(EstablishedClientConnection {
136                        ctx,
137                        req,
138                        conn: MaybeProxiedConnection {
139                            inner: Connection::Http { conn },
140                        },
141                    })
142                } else {
143                    Err(OpaqueError::from_display(format!(
144                        "received unsupport proxy protocol {protocol:?}"
145                    ))
146                    .into_boxed())
147                }
148            }
149        }
150    }
151}
152
153pin_project! {
154    /// A connection which will be proxied if a [`ProxyAddress`] was configured
155    pub struct MaybeProxiedConnection<S> {
156        #[pin]
157        inner: Connection<S>,
158    }
159}
160
161impl<S: Debug> Debug for MaybeProxiedConnection<S> {
162    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
163        f.debug_struct("MaybeProxiedConnection")
164            .field("inner", &self.inner)
165            .finish()
166    }
167}
168
169pin_project! {
170    #[project = ConnectionProj]
171    enum Connection<S> {
172        Direct{ #[pin] conn: S },
173        Socks{ #[pin] conn: S },
174        Http{ #[pin] conn: MaybeHttpProxiedConnection<S> },
175
176    }
177}
178
179impl<S: Debug> Debug for Connection<S> {
180    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
181        match self {
182            Self::Direct { conn } => f.debug_struct("Direct").field("conn", conn).finish(),
183            Self::Socks { conn } => f.debug_struct("Socks").field("conn", conn).finish(),
184            Self::Http { conn } => f.debug_struct("Http").field("conn", conn).finish(),
185        }
186    }
187}
188
189impl<Conn: AsyncWrite> AsyncWrite for MaybeProxiedConnection<Conn> {
190    fn poll_write(
191        self: Pin<&mut Self>,
192        cx: &mut task::Context<'_>,
193        buf: &[u8],
194    ) -> Poll<Result<usize, std::io::Error>> {
195        match self.project().inner.project() {
196            ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => {
197                conn.poll_write(cx, buf)
198            }
199            ConnectionProj::Http { conn } => conn.poll_write(cx, buf),
200        }
201    }
202
203    fn poll_flush(
204        self: Pin<&mut Self>,
205        cx: &mut task::Context<'_>,
206    ) -> Poll<Result<(), std::io::Error>> {
207        match self.project().inner.project() {
208            ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => conn.poll_flush(cx),
209            ConnectionProj::Http { conn } => conn.poll_flush(cx),
210        }
211    }
212
213    fn poll_shutdown(
214        self: Pin<&mut Self>,
215        cx: &mut task::Context<'_>,
216    ) -> Poll<Result<(), std::io::Error>> {
217        match self.project().inner.project() {
218            ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => {
219                conn.poll_shutdown(cx)
220            }
221            ConnectionProj::Http { conn } => conn.poll_shutdown(cx),
222        }
223    }
224
225    fn is_write_vectored(&self) -> bool {
226        match &self.inner {
227            Connection::Direct { conn } | Connection::Socks { conn } => conn.is_write_vectored(),
228            Connection::Http { conn } => conn.is_write_vectored(),
229        }
230    }
231
232    fn poll_write_vectored(
233        self: Pin<&mut Self>,
234        cx: &mut task::Context<'_>,
235        bufs: &[std::io::IoSlice<'_>],
236    ) -> Poll<Result<usize, std::io::Error>> {
237        match self.project().inner.project() {
238            ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => {
239                conn.poll_write_vectored(cx, bufs)
240            }
241            ConnectionProj::Http { conn } => conn.poll_write_vectored(cx, bufs),
242        }
243    }
244}
245
246impl<Conn: AsyncRead> AsyncRead for MaybeProxiedConnection<Conn> {
247    fn poll_read(
248        self: Pin<&mut Self>,
249        cx: &mut task::Context<'_>,
250        buf: &mut tokio::io::ReadBuf<'_>,
251    ) -> Poll<std::io::Result<()>> {
252        match self.project().inner.project() {
253            ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => {
254                conn.poll_read(cx, buf)
255            }
256            ConnectionProj::Http { conn } => conn.poll_read(cx, buf),
257        }
258    }
259}
260
261/// Proxy connector layer which supports http(s) and socks5(h) proxy address
262///
263/// Connector will look at [`ProxyAddress`] to determine which proxy
264/// connector to use if one is configured
265pub struct ProxyConnectorLayer {
266    socks_layer: Socks5ProxyConnectorLayer,
267    http_layer: HttpProxyConnectorLayer,
268    required: bool,
269}
270
271impl ProxyConnectorLayer {
272    #[must_use]
273    /// Creates a new required [`ProxyConnectorLayer`].
274    ///
275    /// This connector will fail if no [`ProxyAddress`] is configured
276    pub fn required(
277        socks_proxy_layer: Socks5ProxyConnectorLayer,
278        http_proxy_layer: HttpProxyConnectorLayer,
279    ) -> Self {
280        Self {
281            socks_layer: socks_proxy_layer,
282            http_layer: http_proxy_layer,
283            required: true,
284        }
285    }
286
287    #[must_use]
288    /// Creates a new optional [`ProxyConnectorLayer`].
289    ///
290    /// This connector will forward to the inner connector if no [`ProxyAddress`] is configured
291    pub fn optional(
292        socks_proxy_layer: Socks5ProxyConnectorLayer,
293        http_proxy_layer: HttpProxyConnectorLayer,
294    ) -> Self {
295        Self {
296            socks_layer: socks_proxy_layer,
297            http_layer: http_proxy_layer,
298            required: false,
299        }
300    }
301}
302
303impl<S> Layer<S> for ProxyConnectorLayer {
304    type Service = ProxyConnector<Arc<S>>;
305
306    fn layer(&self, inner: S) -> Self::Service {
307        ProxyConnector::new(
308            inner,
309            self.socks_layer.clone(),
310            self.http_layer.clone(),
311            self.required,
312        )
313    }
314
315    fn into_layer(self, inner: S) -> Self::Service {
316        ProxyConnector::new(inner, self.socks_layer, self.http_layer, self.required)
317    }
318}