Skip to main content

rama/http/client/
proxy_connector.rs

1use crate::{
2    Layer, Service,
3    error::{BoxError, BoxErrorExt, ErrorContext as _, ErrorExt as _},
4    extensions::{Extensions, ExtensionsRef},
5    http::client::proxy::layer::{
6        HttpProxyConnector, HttpProxyConnectorLayer, MaybeHttpProxiedConnection,
7    },
8    io::Io,
9    net::{
10        Protocol,
11        address::ProxyAddress,
12        client::{ConnectorService, EstablishedClientConnection},
13        transport::TryRefIntoTransportContext,
14    },
15    proxy::socks5::{Socks5ProxyConnector, Socks5ProxyConnectorLayer},
16    telemetry::tracing,
17};
18use pin_project_lite::pin_project;
19use std::{
20    fmt::Debug,
21    pin::Pin,
22    task::{self, Poll},
23};
24use tokio::io::{AsyncRead, AsyncWrite};
25
26/// Proxy connector which supports http(s) and socks5(h) proxy address
27///
28/// Connector will look at [`ProxyAddress`] to determine which proxy
29/// connector to use if one is configured
30#[derive(Debug, Clone)]
31pub struct ProxyConnector<S> {
32    inner: S,
33    socks: Socks5ProxyConnector<S>,
34    http: HttpProxyConnector<S>,
35    required: bool,
36}
37
38impl<S: Clone> ProxyConnector<S> {
39    /// Creates a new [`ProxyConnector`].
40    fn new(
41        inner: S,
42        socks_proxy_layer: Socks5ProxyConnectorLayer,
43        http_proxy_layer: HttpProxyConnectorLayer,
44        required: bool,
45    ) -> Self {
46        Self {
47            socks: socks_proxy_layer.into_layer(inner.clone()),
48            http: http_proxy_layer.into_layer(inner.clone()),
49            inner,
50            required,
51        }
52    }
53
54    #[inline]
55    /// Creates a new required [`ProxyConnector`].
56    ///
57    /// This connector will fail if no [`ProxyAddress`] is configured
58    pub fn required(
59        inner: S,
60        socks_proxy_layer: Socks5ProxyConnectorLayer,
61        http_proxy_layer: HttpProxyConnectorLayer,
62    ) -> Self {
63        Self::new(inner, socks_proxy_layer, http_proxy_layer, true)
64    }
65
66    #[inline]
67    /// Creates a new optional [`ProxyConnector`].
68    ///
69    /// This connector will forward to the inner connector if no [`ProxyAddress`] is configured
70    pub fn optional(
71        inner: S,
72        socks_proxy_layer: Socks5ProxyConnectorLayer,
73        http_proxy_layer: HttpProxyConnectorLayer,
74    ) -> Self {
75        Self::new(inner, socks_proxy_layer, http_proxy_layer, false)
76    }
77}
78
79impl<Input, S> Service<Input> for ProxyConnector<S>
80where
81    S: ConnectorService<Input, Connection: Io + Unpin>,
82    Input: TryRefIntoTransportContext<Error: Into<BoxError> + Send + 'static>
83        + Send
84        + ExtensionsRef
85        + 'static,
86{
87    type Output = EstablishedClientConnection<MaybeProxiedConnection<S::Connection>, Input>;
88    type Error = BoxError;
89
90    async fn serve(&self, input: Input) -> Result<Self::Output, Self::Error> {
91        let proxy = input.extensions().get_ref::<ProxyAddress>();
92
93        match proxy {
94            None => {
95                if self.required {
96                    return Err(BoxError::from_static_str(
97                        "proxy required but none is defined",
98                    ));
99                }
100                tracing::trace!("no proxy detected in ctx, using inner connector");
101                let EstablishedClientConnection { input, conn } =
102                    self.inner.connect(input).await.into_box_error()?;
103
104                let conn = MaybeProxiedConnection::direct(conn);
105                Ok(EstablishedClientConnection { input, conn })
106            }
107            Some(proxy) => {
108                let protocol = proxy.protocol.as_ref();
109                tracing::trace!(?protocol, "proxy detected in ctx");
110
111                let protocol = protocol.unwrap_or_else(|| {
112                    tracing::trace!("no protocol detected, using http as protocol");
113                    &Protocol::HTTP
114                });
115
116                if protocol.is_socks5() {
117                    tracing::trace!("using socks proxy connector");
118                    let EstablishedClientConnection { input, conn } =
119                        self.socks.connect(input).await?;
120
121                    let conn = MaybeProxiedConnection::socks(conn);
122                    Ok(EstablishedClientConnection { input, conn })
123                } else if protocol.is_http() {
124                    tracing::trace!("using http proxy connector");
125                    let EstablishedClientConnection { input, conn } =
126                        self.http.connect(input).await?;
127
128                    let conn = MaybeProxiedConnection::http(conn);
129                    Ok(EstablishedClientConnection { input, conn })
130                } else {
131                    Err(
132                        BoxError::from_static_str("received unsupport proxy protocol")
133                            .with_context_debug_field("protocol", || protocol.clone()),
134                    )
135                }
136            }
137        }
138    }
139}
140
141pin_project! {
142    /// A connection which will be proxied if a [`ProxyAddress`] was configured
143    pub struct MaybeProxiedConnection<S> {
144        #[pin]
145        inner: Connection<S>,
146    }
147}
148
149impl<S: ExtensionsRef> MaybeProxiedConnection<S> {
150    pub fn direct(conn: S) -> Self {
151        Self {
152            inner: Connection::Direct { conn },
153        }
154    }
155
156    pub fn socks(conn: S) -> Self {
157        Self {
158            inner: Connection::Socks { conn },
159        }
160    }
161
162    pub fn http(conn: MaybeHttpProxiedConnection<S>) -> Self {
163        Self {
164            inner: Connection::Http { conn },
165        }
166    }
167}
168
169impl<S: Debug> Debug for MaybeProxiedConnection<S> {
170    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
171        f.debug_struct("MaybeProxiedConnection")
172            .field("inner", &self.inner)
173            .finish()
174    }
175}
176
177impl<S: ExtensionsRef> ExtensionsRef for MaybeProxiedConnection<S> {
178    fn extensions(&self) -> &Extensions {
179        match &self.inner {
180            Connection::Direct { conn } | Connection::Socks { conn } => conn.extensions(),
181            Connection::Http { conn } => conn.extensions(),
182        }
183    }
184}
185
186pin_project! {
187    #[project = ConnectionProj]
188    enum Connection<S> {
189        Direct{ #[pin] conn: S },
190        Socks{ #[pin] conn: S },
191        Http{ #[pin] conn: MaybeHttpProxiedConnection<S> },
192
193    }
194}
195
196impl<S: Debug> Debug for Connection<S> {
197    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
198        match self {
199            Self::Direct { conn } => f.debug_struct("Direct").field("conn", conn).finish(),
200            Self::Socks { conn } => f.debug_struct("Socks").field("conn", conn).finish(),
201            Self::Http { conn } => f.debug_struct("Http").field("conn", conn).finish(),
202        }
203    }
204}
205
206#[warn(clippy::missing_trait_methods)]
207impl<Conn: AsyncWrite> AsyncWrite for MaybeProxiedConnection<Conn> {
208    fn poll_write(
209        self: Pin<&mut Self>,
210        cx: &mut task::Context<'_>,
211        buf: &[u8],
212    ) -> Poll<Result<usize, std::io::Error>> {
213        match self.project().inner.project() {
214            ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => {
215                conn.poll_write(cx, buf)
216            }
217            ConnectionProj::Http { conn } => conn.poll_write(cx, buf),
218        }
219    }
220
221    fn poll_flush(
222        self: Pin<&mut Self>,
223        cx: &mut task::Context<'_>,
224    ) -> Poll<Result<(), std::io::Error>> {
225        match self.project().inner.project() {
226            ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => conn.poll_flush(cx),
227            ConnectionProj::Http { conn } => conn.poll_flush(cx),
228        }
229    }
230
231    fn poll_shutdown(
232        self: Pin<&mut Self>,
233        cx: &mut task::Context<'_>,
234    ) -> Poll<Result<(), std::io::Error>> {
235        match self.project().inner.project() {
236            ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => {
237                conn.poll_shutdown(cx)
238            }
239            ConnectionProj::Http { conn } => conn.poll_shutdown(cx),
240        }
241    }
242
243    fn is_write_vectored(&self) -> bool {
244        match &self.inner {
245            Connection::Direct { conn } | Connection::Socks { conn } => conn.is_write_vectored(),
246            Connection::Http { conn } => conn.is_write_vectored(),
247        }
248    }
249
250    fn poll_write_vectored(
251        self: Pin<&mut Self>,
252        cx: &mut task::Context<'_>,
253        bufs: &[std::io::IoSlice<'_>],
254    ) -> Poll<Result<usize, std::io::Error>> {
255        match self.project().inner.project() {
256            ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => {
257                conn.poll_write_vectored(cx, bufs)
258            }
259            ConnectionProj::Http { conn } => conn.poll_write_vectored(cx, bufs),
260        }
261    }
262}
263
264#[warn(clippy::missing_trait_methods)]
265impl<Conn: AsyncRead> AsyncRead for MaybeProxiedConnection<Conn> {
266    fn poll_read(
267        self: Pin<&mut Self>,
268        cx: &mut task::Context<'_>,
269        buf: &mut tokio::io::ReadBuf<'_>,
270    ) -> Poll<std::io::Result<()>> {
271        match self.project().inner.project() {
272            ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => {
273                conn.poll_read(cx, buf)
274            }
275            ConnectionProj::Http { conn } => conn.poll_read(cx, buf),
276        }
277    }
278}
279
280/// Proxy connector layer which supports http(s) and socks5(h) proxy address
281///
282/// Connector will look at [`ProxyAddress`] to determine which proxy
283/// connector to use if one is configured
284pub struct ProxyConnectorLayer {
285    socks_layer: Socks5ProxyConnectorLayer,
286    http_layer: HttpProxyConnectorLayer,
287    required: bool,
288}
289
290impl ProxyConnectorLayer {
291    #[must_use]
292    /// Creates a new required [`ProxyConnectorLayer`].
293    ///
294    /// This connector will fail if no [`ProxyAddress`] is configured
295    pub fn required(
296        socks_proxy_layer: Socks5ProxyConnectorLayer,
297        http_proxy_layer: HttpProxyConnectorLayer,
298    ) -> Self {
299        Self {
300            socks_layer: socks_proxy_layer,
301            http_layer: http_proxy_layer,
302            required: true,
303        }
304    }
305
306    #[must_use]
307    /// Creates a new optional [`ProxyConnectorLayer`].
308    ///
309    /// This connector will forward to the inner connector if no [`ProxyAddress`] is configured
310    pub fn optional(
311        socks_proxy_layer: Socks5ProxyConnectorLayer,
312        http_proxy_layer: HttpProxyConnectorLayer,
313    ) -> Self {
314        Self {
315            socks_layer: socks_proxy_layer,
316            http_layer: http_proxy_layer,
317            required: false,
318        }
319    }
320}
321
322impl<S: Clone> Layer<S> for ProxyConnectorLayer {
323    type Service = ProxyConnector<S>;
324
325    fn layer(&self, inner: S) -> Self::Service {
326        ProxyConnector::new(
327            inner,
328            self.socks_layer.clone(),
329            self.http_layer.clone(),
330            self.required,
331        )
332    }
333
334    fn into_layer(self, inner: S) -> Self::Service {
335        ProxyConnector::new(inner, self.socks_layer, self.http_layer, self.required)
336    }
337}