Skip to main content

rama/http/client/
proxy_connector.rs

1use crate::{
2    Layer, Service,
3    error::{BoxError, BoxErrorExt, ErrorContext as _, ErrorExt as _},
4    extensions::{Extensions, ExtensionsRef},
5    http::client::proxy::layer::{
6        HttpProxyConnector, HttpProxyConnectorLayer, MaybeHttpProxiedConnection,
7    },
8    io::Io,
9    net::{
10        AuthorityInputExt, Protocol, ProtocolInputExt,
11        address::ProxyAddress,
12        client::{ConnectorService, EstablishedClientConnection},
13    },
14    proxy::socks5::{Socks5ProxyConnector, Socks5ProxyConnectorLayer},
15    telemetry::tracing,
16};
17use pin_project_lite::pin_project;
18use std::{
19    fmt::Debug,
20    pin::Pin,
21    task::{self, Poll},
22};
23use tokio::io::{AsyncRead, AsyncWrite};
24
25/// Proxy connector which supports http(s) and socks5(h) proxy address
26///
27/// Connector will look at [`ProxyAddress`] to determine which proxy
28/// connector to use if one is configured
29#[derive(Debug, Clone)]
30pub struct ProxyConnector<S> {
31    inner: S,
32    socks: Socks5ProxyConnector<S>,
33    http: HttpProxyConnector<S>,
34    required: bool,
35}
36
37impl<S: Clone> ProxyConnector<S> {
38    /// Creates a new [`ProxyConnector`].
39    fn new(
40        inner: S,
41        socks_proxy_layer: Socks5ProxyConnectorLayer,
42        http_proxy_layer: HttpProxyConnectorLayer,
43        required: bool,
44    ) -> Self {
45        Self {
46            socks: socks_proxy_layer.into_layer(inner.clone()),
47            http: http_proxy_layer.into_layer(inner.clone()),
48            inner,
49            required,
50        }
51    }
52
53    #[inline]
54    /// Creates a new required [`ProxyConnector`].
55    ///
56    /// This connector will fail if no [`ProxyAddress`] is configured
57    pub fn required(
58        inner: S,
59        socks_proxy_layer: Socks5ProxyConnectorLayer,
60        http_proxy_layer: HttpProxyConnectorLayer,
61    ) -> Self {
62        Self::new(inner, socks_proxy_layer, http_proxy_layer, true)
63    }
64
65    #[inline]
66    /// Creates a new optional [`ProxyConnector`].
67    ///
68    /// This connector will forward to the inner connector if no [`ProxyAddress`] is configured
69    pub fn optional(
70        inner: S,
71        socks_proxy_layer: Socks5ProxyConnectorLayer,
72        http_proxy_layer: HttpProxyConnectorLayer,
73    ) -> Self {
74        Self::new(inner, socks_proxy_layer, http_proxy_layer, false)
75    }
76}
77
78impl<Input, S> Service<Input> for ProxyConnector<S>
79where
80    S: ConnectorService<Input, Connection: Io + Unpin>,
81    Input: AuthorityInputExt + ProtocolInputExt + Send + ExtensionsRef + 'static,
82{
83    type Output = EstablishedClientConnection<MaybeProxiedConnection<S::Connection>, Input>;
84    type Error = BoxError;
85
86    async fn serve(&self, input: Input) -> Result<Self::Output, Self::Error> {
87        let proxy = input.extensions().get_ref::<ProxyAddress>();
88
89        match proxy {
90            None => {
91                if self.required {
92                    return Err(BoxError::from_static_str(
93                        "proxy required but none is defined",
94                    ));
95                }
96                tracing::trace!("no proxy detected in ctx, using inner connector");
97                let EstablishedClientConnection { input, conn } =
98                    self.inner.connect(input).await.into_box_error()?;
99
100                let conn = MaybeProxiedConnection::direct(conn);
101                Ok(EstablishedClientConnection { input, conn })
102            }
103            Some(proxy) => {
104                let protocol = proxy.protocol.as_ref();
105                tracing::trace!(?protocol, "proxy detected in ctx");
106
107                let protocol = protocol.unwrap_or_else(|| {
108                    tracing::trace!("no protocol detected, using http as protocol");
109                    &Protocol::HTTP
110                });
111
112                if protocol.is_socks5() {
113                    tracing::trace!("using socks proxy connector");
114                    let EstablishedClientConnection { input, conn } =
115                        self.socks.connect(input).await?;
116
117                    let conn = MaybeProxiedConnection::socks(conn);
118                    Ok(EstablishedClientConnection { input, conn })
119                } else if protocol.is_http() {
120                    tracing::trace!("using http proxy connector");
121                    let EstablishedClientConnection { input, conn } =
122                        self.http.connect(input).await?;
123
124                    let conn = MaybeProxiedConnection::http(conn);
125                    Ok(EstablishedClientConnection { input, conn })
126                } else {
127                    Err(
128                        BoxError::from_static_str("received unsupport proxy protocol")
129                            .with_context_debug_field("protocol", || protocol.clone()),
130                    )
131                }
132            }
133        }
134    }
135}
136
137pin_project! {
138    /// A connection which will be proxied if a [`ProxyAddress`] was configured
139    pub struct MaybeProxiedConnection<S> {
140        #[pin]
141        inner: Connection<S>,
142    }
143}
144
145impl<S: ExtensionsRef> MaybeProxiedConnection<S> {
146    pub fn direct(conn: S) -> Self {
147        Self {
148            inner: Connection::Direct { conn },
149        }
150    }
151
152    pub fn socks(conn: S) -> Self {
153        Self {
154            inner: Connection::Socks { conn },
155        }
156    }
157
158    pub fn http(conn: MaybeHttpProxiedConnection<S>) -> Self {
159        Self {
160            inner: Connection::Http { conn },
161        }
162    }
163}
164
165impl<S: Debug> Debug for MaybeProxiedConnection<S> {
166    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
167        f.debug_struct("MaybeProxiedConnection")
168            .field("inner", &self.inner)
169            .finish()
170    }
171}
172
173impl<S: ExtensionsRef> ExtensionsRef for MaybeProxiedConnection<S> {
174    fn extensions(&self) -> &Extensions {
175        match &self.inner {
176            Connection::Direct { conn } | Connection::Socks { conn } => conn.extensions(),
177            Connection::Http { conn } => conn.extensions(),
178        }
179    }
180}
181
182pin_project! {
183    #[project = ConnectionProj]
184    enum Connection<S> {
185        Direct{ #[pin] conn: S },
186        Socks{ #[pin] conn: S },
187        Http{ #[pin] conn: MaybeHttpProxiedConnection<S> },
188
189    }
190}
191
192impl<S: Debug> Debug for Connection<S> {
193    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
194        match self {
195            Self::Direct { conn } => f.debug_struct("Direct").field("conn", conn).finish(),
196            Self::Socks { conn } => f.debug_struct("Socks").field("conn", conn).finish(),
197            Self::Http { conn } => f.debug_struct("Http").field("conn", conn).finish(),
198        }
199    }
200}
201
202#[warn(clippy::missing_trait_methods)]
203impl<Conn: AsyncWrite> AsyncWrite for MaybeProxiedConnection<Conn> {
204    fn poll_write(
205        self: Pin<&mut Self>,
206        cx: &mut task::Context<'_>,
207        buf: &[u8],
208    ) -> Poll<Result<usize, std::io::Error>> {
209        match self.project().inner.project() {
210            ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => {
211                conn.poll_write(cx, buf)
212            }
213            ConnectionProj::Http { conn } => conn.poll_write(cx, buf),
214        }
215    }
216
217    fn poll_flush(
218        self: Pin<&mut Self>,
219        cx: &mut task::Context<'_>,
220    ) -> Poll<Result<(), std::io::Error>> {
221        match self.project().inner.project() {
222            ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => conn.poll_flush(cx),
223            ConnectionProj::Http { conn } => conn.poll_flush(cx),
224        }
225    }
226
227    fn poll_shutdown(
228        self: Pin<&mut Self>,
229        cx: &mut task::Context<'_>,
230    ) -> Poll<Result<(), std::io::Error>> {
231        match self.project().inner.project() {
232            ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => {
233                conn.poll_shutdown(cx)
234            }
235            ConnectionProj::Http { conn } => conn.poll_shutdown(cx),
236        }
237    }
238
239    fn is_write_vectored(&self) -> bool {
240        match &self.inner {
241            Connection::Direct { conn } | Connection::Socks { conn } => conn.is_write_vectored(),
242            Connection::Http { conn } => conn.is_write_vectored(),
243        }
244    }
245
246    fn poll_write_vectored(
247        self: Pin<&mut Self>,
248        cx: &mut task::Context<'_>,
249        bufs: &[std::io::IoSlice<'_>],
250    ) -> Poll<Result<usize, std::io::Error>> {
251        match self.project().inner.project() {
252            ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => {
253                conn.poll_write_vectored(cx, bufs)
254            }
255            ConnectionProj::Http { conn } => conn.poll_write_vectored(cx, bufs),
256        }
257    }
258}
259
260#[warn(clippy::missing_trait_methods)]
261impl<Conn: AsyncRead> AsyncRead for MaybeProxiedConnection<Conn> {
262    fn poll_read(
263        self: Pin<&mut Self>,
264        cx: &mut task::Context<'_>,
265        buf: &mut tokio::io::ReadBuf<'_>,
266    ) -> Poll<std::io::Result<()>> {
267        match self.project().inner.project() {
268            ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => {
269                conn.poll_read(cx, buf)
270            }
271            ConnectionProj::Http { conn } => conn.poll_read(cx, buf),
272        }
273    }
274}
275
276/// Proxy connector layer which supports http(s) and socks5(h) proxy address
277///
278/// Connector will look at [`ProxyAddress`] to determine which proxy
279/// connector to use if one is configured
280pub struct ProxyConnectorLayer {
281    socks_layer: Socks5ProxyConnectorLayer,
282    http_layer: HttpProxyConnectorLayer,
283    required: bool,
284}
285
286impl ProxyConnectorLayer {
287    #[must_use]
288    /// Creates a new required [`ProxyConnectorLayer`].
289    ///
290    /// This connector will fail if no [`ProxyAddress`] is configured
291    pub fn required(
292        socks_proxy_layer: Socks5ProxyConnectorLayer,
293        http_proxy_layer: HttpProxyConnectorLayer,
294    ) -> Self {
295        Self {
296            socks_layer: socks_proxy_layer,
297            http_layer: http_proxy_layer,
298            required: true,
299        }
300    }
301
302    #[must_use]
303    /// Creates a new optional [`ProxyConnectorLayer`].
304    ///
305    /// This connector will forward to the inner connector if no [`ProxyAddress`] is configured
306    pub fn optional(
307        socks_proxy_layer: Socks5ProxyConnectorLayer,
308        http_proxy_layer: HttpProxyConnectorLayer,
309    ) -> Self {
310        Self {
311            socks_layer: socks_proxy_layer,
312            http_layer: http_proxy_layer,
313            required: false,
314        }
315    }
316}
317
318impl<S: Clone> Layer<S> for ProxyConnectorLayer {
319    type Service = ProxyConnector<S>;
320
321    fn layer(&self, inner: S) -> Self::Service {
322        ProxyConnector::new(
323            inner,
324            self.socks_layer.clone(),
325            self.http_layer.clone(),
326            self.required,
327        )
328    }
329
330    fn into_layer(self, inner: S) -> Self::Service {
331        ProxyConnector::new(inner, self.socks_layer, self.http_layer, self.required)
332    }
333}