Skip to main content

rama/http/client/
proxy_connector.rs

1use crate::{
2    Layer, Service,
3    error::BoxError,
4    extensions::{Extensions, ExtensionsRef},
5    http::client::proxy::layer::{
6        HttpProxyConnector, HttpProxyConnectorLayer, MaybeHttpProxiedConnection,
7    },
8    io::Io,
9    net::{
10        Protocol,
11        address::ProxyAddress,
12        client::{ConnectorService, EstablishedClientConnection},
13        transport::TryRefIntoTransportContext,
14    },
15    proxy::socks5::{Socks5ProxyConnector, Socks5ProxyConnectorLayer},
16    telemetry::tracing,
17};
18use pin_project_lite::pin_project;
19use rama_core::error::{ErrorContext as _, ErrorExt as _, extra::OpaqueError};
20use std::{
21    fmt::Debug,
22    pin::Pin,
23    task::{self, Poll},
24};
25use tokio::io::{AsyncRead, AsyncWrite};
26
/// Proxy connector which supports http(s) and socks5(h) proxy addresses.
///
/// The connector looks at the [`ProxyAddress`] found in the input's extensions
/// to determine which proxy connector to use, if one is configured.
#[derive(Debug, Clone)]
pub struct ProxyConnector<S> {
    // Connector used as-is when no proxy address is present (and none is required).
    inner: S,
    // A clone of `inner` wrapped by the socks5 proxy connector layer.
    socks: Socks5ProxyConnector<S>,
    // A clone of `inner` wrapped by the http proxy connector layer.
    http: HttpProxyConnector<S>,
    // When `true`, connecting without a `ProxyAddress` results in an error.
    required: bool,
}
38
39impl<S: Clone> ProxyConnector<S> {
40    /// Creates a new [`ProxyConnector`].
41    fn new(
42        inner: S,
43        socks_proxy_layer: Socks5ProxyConnectorLayer,
44        http_proxy_layer: HttpProxyConnectorLayer,
45        required: bool,
46    ) -> Self {
47        Self {
48            socks: socks_proxy_layer.into_layer(inner.clone()),
49            http: http_proxy_layer.into_layer(inner.clone()),
50            inner,
51            required,
52        }
53    }
54
55    #[inline]
56    /// Creates a new required [`ProxyConnector`].
57    ///
58    /// This connector will fail if no [`ProxyAddress`] is configured
59    pub fn required(
60        inner: S,
61        socks_proxy_layer: Socks5ProxyConnectorLayer,
62        http_proxy_layer: HttpProxyConnectorLayer,
63    ) -> Self {
64        Self::new(inner, socks_proxy_layer, http_proxy_layer, true)
65    }
66
67    #[inline]
68    /// Creates a new optional [`ProxyConnector`].
69    ///
70    /// This connector will forward to the inner connector if no [`ProxyAddress`] is configured
71    pub fn optional(
72        inner: S,
73        socks_proxy_layer: Socks5ProxyConnectorLayer,
74        http_proxy_layer: HttpProxyConnectorLayer,
75    ) -> Self {
76        Self::new(inner, socks_proxy_layer, http_proxy_layer, false)
77    }
78}
79
80impl<Input, S> Service<Input> for ProxyConnector<S>
81where
82    S: ConnectorService<Input, Connection: Io + Unpin>,
83    Input: TryRefIntoTransportContext<Error: Into<BoxError> + Send + 'static>
84        + Send
85        + ExtensionsRef
86        + 'static,
87{
88    type Output = EstablishedClientConnection<MaybeProxiedConnection<S::Connection>, Input>;
89    type Error = BoxError;
90
91    async fn serve(&self, input: Input) -> Result<Self::Output, Self::Error> {
92        let proxy = input.extensions().get_ref::<ProxyAddress>();
93
94        match proxy {
95            None => {
96                if self.required {
97                    return Err(
98                        OpaqueError::from_static_str("proxy required but none is defined")
99                            .into_box_error(),
100                    );
101                }
102                tracing::trace!("no proxy detected in ctx, using inner connector");
103                let EstablishedClientConnection { input, conn } =
104                    self.inner.connect(input).await.into_box_error()?;
105
106                let conn = MaybeProxiedConnection::direct(conn);
107                Ok(EstablishedClientConnection { input, conn })
108            }
109            Some(proxy) => {
110                let protocol = proxy.protocol.as_ref();
111                tracing::trace!(?protocol, "proxy detected in ctx");
112
113                let protocol = protocol.unwrap_or_else(|| {
114                    tracing::trace!("no protocol detected, using http as protocol");
115                    &Protocol::HTTP
116                });
117
118                if protocol.is_socks5() {
119                    tracing::trace!("using socks proxy connector");
120                    let EstablishedClientConnection { input, conn } =
121                        self.socks.connect(input).await?;
122
123                    let conn = MaybeProxiedConnection::socks(conn);
124                    Ok(EstablishedClientConnection { input, conn })
125                } else if protocol.is_http() {
126                    tracing::trace!("using http proxy connector");
127                    let EstablishedClientConnection { input, conn } =
128                        self.http.connect(input).await?;
129
130                    let conn = MaybeProxiedConnection::http(conn);
131                    Ok(EstablishedClientConnection { input, conn })
132                } else {
133                    Err(
134                        OpaqueError::from_static_str("received unsupport proxy protocol")
135                            .with_context_debug_field("protocol", || protocol.clone()),
136                    )
137                }
138            }
139        }
140    }
141}
142
pin_project! {
    /// A connection which will be proxied if a [`ProxyAddress`] was configured.
    ///
    /// Holds either a direct connection, a socks5-proxied connection or an
    /// http-proxied connection, and forwards reads and writes to whichever
    /// one it holds.
    pub struct MaybeProxiedConnection<S> {
        // The concrete connection variant; pinned so its I/O methods can be polled.
        #[pin]
        inner: Connection<S>,
    }
}
150
151impl<S: ExtensionsRef> MaybeProxiedConnection<S> {
152    pub fn direct(conn: S) -> Self {
153        Self {
154            inner: Connection::Direct { conn },
155        }
156    }
157
158    pub fn socks(conn: S) -> Self {
159        Self {
160            inner: Connection::Socks { conn },
161        }
162    }
163
164    pub fn http(conn: MaybeHttpProxiedConnection<S>) -> Self {
165        Self {
166            inner: Connection::Http { conn },
167        }
168    }
169}
170
171impl<S: Debug> Debug for MaybeProxiedConnection<S> {
172    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
173        f.debug_struct("MaybeProxiedConnection")
174            .field("inner", &self.inner)
175            .finish()
176    }
177}
178
179impl<S: ExtensionsRef> ExtensionsRef for MaybeProxiedConnection<S> {
180    fn extensions(&self) -> &Extensions {
181        match &self.inner {
182            Connection::Direct { conn } | Connection::Socks { conn } => conn.extensions(),
183            Connection::Http { conn } => conn.extensions(),
184        }
185    }
186}
187
// Internal representation of `MaybeProxiedConnection`: one variant per way the
// connection was established. Each variant pins its connection so that I/O can be
// forwarded through the generated `ConnectionProj` projection.
pin_project! {
    #[project = ConnectionProj]
    enum Connection<S> {
        // Direct (non-proxied) connection.
        Direct{ #[pin] conn: S },
        // Connection produced by the socks5 proxy connector.
        Socks{ #[pin] conn: S },
        // Connection produced by the http proxy connector.
        Http{ #[pin] conn: MaybeHttpProxiedConnection<S> },

    }
}
197
198impl<S: Debug> Debug for Connection<S> {
199    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
200        match self {
201            Self::Direct { conn } => f.debug_struct("Direct").field("conn", conn).finish(),
202            Self::Socks { conn } => f.debug_struct("Socks").field("conn", conn).finish(),
203            Self::Http { conn } => f.debug_struct("Http").field("conn", conn).finish(),
204        }
205    }
206}
207
208#[warn(clippy::missing_trait_methods)]
209impl<Conn: AsyncWrite> AsyncWrite for MaybeProxiedConnection<Conn> {
210    fn poll_write(
211        self: Pin<&mut Self>,
212        cx: &mut task::Context<'_>,
213        buf: &[u8],
214    ) -> Poll<Result<usize, std::io::Error>> {
215        match self.project().inner.project() {
216            ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => {
217                conn.poll_write(cx, buf)
218            }
219            ConnectionProj::Http { conn } => conn.poll_write(cx, buf),
220        }
221    }
222
223    fn poll_flush(
224        self: Pin<&mut Self>,
225        cx: &mut task::Context<'_>,
226    ) -> Poll<Result<(), std::io::Error>> {
227        match self.project().inner.project() {
228            ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => conn.poll_flush(cx),
229            ConnectionProj::Http { conn } => conn.poll_flush(cx),
230        }
231    }
232
233    fn poll_shutdown(
234        self: Pin<&mut Self>,
235        cx: &mut task::Context<'_>,
236    ) -> Poll<Result<(), std::io::Error>> {
237        match self.project().inner.project() {
238            ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => {
239                conn.poll_shutdown(cx)
240            }
241            ConnectionProj::Http { conn } => conn.poll_shutdown(cx),
242        }
243    }
244
245    fn is_write_vectored(&self) -> bool {
246        match &self.inner {
247            Connection::Direct { conn } | Connection::Socks { conn } => conn.is_write_vectored(),
248            Connection::Http { conn } => conn.is_write_vectored(),
249        }
250    }
251
252    fn poll_write_vectored(
253        self: Pin<&mut Self>,
254        cx: &mut task::Context<'_>,
255        bufs: &[std::io::IoSlice<'_>],
256    ) -> Poll<Result<usize, std::io::Error>> {
257        match self.project().inner.project() {
258            ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => {
259                conn.poll_write_vectored(cx, bufs)
260            }
261            ConnectionProj::Http { conn } => conn.poll_write_vectored(cx, bufs),
262        }
263    }
264}
265
266#[warn(clippy::missing_trait_methods)]
267impl<Conn: AsyncRead> AsyncRead for MaybeProxiedConnection<Conn> {
268    fn poll_read(
269        self: Pin<&mut Self>,
270        cx: &mut task::Context<'_>,
271        buf: &mut tokio::io::ReadBuf<'_>,
272    ) -> Poll<std::io::Result<()>> {
273        match self.project().inner.project() {
274            ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => {
275                conn.poll_read(cx, buf)
276            }
277            ConnectionProj::Http { conn } => conn.poll_read(cx, buf),
278        }
279    }
280}
281
282/// Proxy connector layer which supports http(s) and socks5(h) proxy address
283///
284/// Connector will look at [`ProxyAddress`] to determine which proxy
285/// connector to use if one is configured
286pub struct ProxyConnectorLayer {
287    socks_layer: Socks5ProxyConnectorLayer,
288    http_layer: HttpProxyConnectorLayer,
289    required: bool,
290}
291
292impl ProxyConnectorLayer {
293    #[must_use]
294    /// Creates a new required [`ProxyConnectorLayer`].
295    ///
296    /// This connector will fail if no [`ProxyAddress`] is configured
297    pub fn required(
298        socks_proxy_layer: Socks5ProxyConnectorLayer,
299        http_proxy_layer: HttpProxyConnectorLayer,
300    ) -> Self {
301        Self {
302            socks_layer: socks_proxy_layer,
303            http_layer: http_proxy_layer,
304            required: true,
305        }
306    }
307
308    #[must_use]
309    /// Creates a new optional [`ProxyConnectorLayer`].
310    ///
311    /// This connector will forward to the inner connector if no [`ProxyAddress`] is configured
312    pub fn optional(
313        socks_proxy_layer: Socks5ProxyConnectorLayer,
314        http_proxy_layer: HttpProxyConnectorLayer,
315    ) -> Self {
316        Self {
317            socks_layer: socks_proxy_layer,
318            http_layer: http_proxy_layer,
319            required: false,
320        }
321    }
322}
323
324impl<S: Clone> Layer<S> for ProxyConnectorLayer {
325    type Service = ProxyConnector<S>;
326
327    fn layer(&self, inner: S) -> Self::Service {
328        ProxyConnector::new(
329            inner,
330            self.socks_layer.clone(),
331            self.http_layer.clone(),
332            self.required,
333        )
334    }
335
336    fn into_layer(self, inner: S) -> Self::Service {
337        ProxyConnector::new(inner, self.socks_layer, self.http_layer, self.required)
338    }
339}