rama/http/client/proxy_connector.rs

1use crate::{
2    Layer, Service,
3    error::{BoxError, OpaqueError},
4    extensions::{Extensions, ExtensionsMut, ExtensionsRef},
5    http::client::proxy::layer::{
6        HttpProxyConnector, HttpProxyConnectorLayer, MaybeHttpProxiedConnection,
7    },
8    net::{
9        Protocol,
10        address::ProxyAddress,
11        client::{ConnectorService, EstablishedClientConnection},
12        transport::TryRefIntoTransportContext,
13    },
14    proxy::socks5::{Socks5ProxyConnector, Socks5ProxyConnectorLayer},
15    stream::Stream,
16    telemetry::tracing,
17};
18use pin_project_lite::pin_project;
19use std::{
20    fmt::Debug,
21    pin::Pin,
22    sync::Arc,
23    task::{self, Poll},
24};
25use tokio::io::{AsyncRead, AsyncWrite};
26
27/// Proxy connector which supports http(s) and socks5(h) proxy address
28///
29/// Connector will look at [`ProxyAddress`] to determine which proxy
30/// connector to use if one is configured
31pub struct ProxyConnector<S> {
32    inner: S,
33    socks: Socks5ProxyConnector<S>,
34    http: HttpProxyConnector<S>,
35    required: bool,
36}
37
38impl<S> ProxyConnector<S> {
39    /// Creates a new [`ProxyConnector`].
40    fn new(
41        inner: S,
42        socks_proxy_layer: Socks5ProxyConnectorLayer,
43        http_proxy_layer: HttpProxyConnectorLayer,
44        required: bool,
45    ) -> ProxyConnector<Arc<S>> {
46        let inner = Arc::new(inner);
47        ProxyConnector {
48            socks: socks_proxy_layer.into_layer(inner.clone()),
49            http: http_proxy_layer.into_layer(inner.clone()),
50            inner,
51            required,
52        }
53    }
54
55    #[inline]
56    /// Creates a new required [`ProxyConnector`].
57    ///
58    /// This connector will fail if no [`ProxyAddress`] is configured
59    pub fn required(
60        inner: S,
61        socks_proxy_layer: Socks5ProxyConnectorLayer,
62        http_proxy_layer: HttpProxyConnectorLayer,
63    ) -> ProxyConnector<Arc<S>> {
64        Self::new(inner, socks_proxy_layer, http_proxy_layer, true)
65    }
66
67    #[inline]
68    /// Creates a new optional [`ProxyConnector`].
69    ///
70    /// This connector will forward to the inner connector if no [`ProxyAddress`] is configured
71    pub fn optional(
72        inner: S,
73        socks_proxy_layer: Socks5ProxyConnectorLayer,
74        http_proxy_layer: HttpProxyConnectorLayer,
75    ) -> ProxyConnector<Arc<S>> {
76        Self::new(inner, socks_proxy_layer, http_proxy_layer, false)
77    }
78}
79
80impl<Input, S> Service<Input> for ProxyConnector<S>
81where
82    S: ConnectorService<Input, Connection: Stream + Unpin>,
83    Input: TryRefIntoTransportContext<Error: Into<BoxError> + Send + 'static>
84        + Send
85        + ExtensionsMut
86        + 'static,
87{
88    type Output = EstablishedClientConnection<MaybeProxiedConnection<S::Connection>, Input>;
89    type Error = BoxError;
90
91    async fn serve(&self, input: Input) -> Result<Self::Output, Self::Error> {
92        let proxy = input.extensions().get::<ProxyAddress>();
93
94        match proxy {
95            None => {
96                if self.required {
97                    return Err("proxy required but none is defined".into());
98                }
99                tracing::trace!("no proxy detected in ctx, using inner connector");
100                let EstablishedClientConnection { input, conn } =
101                    self.inner.connect(input).await.map_err(Into::into)?;
102
103                let conn = MaybeProxiedConnection::direct(conn);
104                Ok(EstablishedClientConnection { input, conn })
105            }
106            Some(proxy) => {
107                let protocol = proxy.protocol.as_ref();
108                tracing::trace!(?protocol, "proxy detected in ctx");
109
110                let protocol = protocol.unwrap_or_else(|| {
111                    tracing::trace!("no protocol detected, using http as protocol");
112                    &Protocol::HTTP
113                });
114
115                if protocol.is_socks5() {
116                    tracing::trace!("using socks proxy connector");
117                    let EstablishedClientConnection { input, conn } =
118                        self.socks.connect(input).await?;
119
120                    let conn = MaybeProxiedConnection::socks(conn);
121                    Ok(EstablishedClientConnection { input, conn })
122                } else if protocol.is_http() {
123                    tracing::trace!("using http proxy connector");
124                    let EstablishedClientConnection { input, conn } =
125                        self.http.connect(input).await?;
126
127                    let conn = MaybeProxiedConnection::http(conn);
128                    Ok(EstablishedClientConnection { input, conn })
129                } else {
130                    Err(OpaqueError::from_display(format!(
131                        "received unsupport proxy protocol {protocol:?}"
132                    ))
133                    .into_boxed())
134                }
135            }
136        }
137    }
138}
139
140pin_project! {
141    /// A connection which will be proxied if a [`ProxyAddress`] was configured
142    pub struct MaybeProxiedConnection<S> {
143        #[pin]
144        inner: Connection<S>,
145    }
146}
147
148impl<S: ExtensionsMut> MaybeProxiedConnection<S> {
149    pub fn direct(conn: S) -> Self {
150        Self {
151            inner: Connection::Direct { conn },
152        }
153    }
154
155    pub fn socks(conn: S) -> Self {
156        Self {
157            inner: Connection::Socks { conn },
158        }
159    }
160
161    pub fn http(conn: MaybeHttpProxiedConnection<S>) -> Self {
162        Self {
163            inner: Connection::Http { conn },
164        }
165    }
166}
167
168impl<S: Debug> Debug for MaybeProxiedConnection<S> {
169    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
170        f.debug_struct("MaybeProxiedConnection")
171            .field("inner", &self.inner)
172            .finish()
173    }
174}
175
176impl<S: ExtensionsRef> ExtensionsRef for MaybeProxiedConnection<S> {
177    fn extensions(&self) -> &Extensions {
178        match &self.inner {
179            Connection::Direct { conn } | Connection::Socks { conn } => conn.extensions(),
180            Connection::Http { conn } => conn.extensions(),
181        }
182    }
183}
184
185impl<S: ExtensionsMut> ExtensionsMut for MaybeProxiedConnection<S> {
186    fn extensions_mut(&mut self) -> &mut Extensions {
187        match &mut self.inner {
188            Connection::Direct { conn } | Connection::Socks { conn } => conn.extensions_mut(),
189            Connection::Http { conn } => conn.extensions_mut(),
190        }
191    }
192}
193
194pin_project! {
195    #[project = ConnectionProj]
196    enum Connection<S> {
197        Direct{ #[pin] conn: S },
198        Socks{ #[pin] conn: S },
199        Http{ #[pin] conn: MaybeHttpProxiedConnection<S> },
200
201    }
202}
203
204impl<S: Debug> Debug for Connection<S> {
205    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
206        match self {
207            Self::Direct { conn } => f.debug_struct("Direct").field("conn", conn).finish(),
208            Self::Socks { conn } => f.debug_struct("Socks").field("conn", conn).finish(),
209            Self::Http { conn } => f.debug_struct("Http").field("conn", conn).finish(),
210        }
211    }
212}
213
214#[warn(clippy::missing_trait_methods)]
215impl<Conn: AsyncWrite> AsyncWrite for MaybeProxiedConnection<Conn> {
216    fn poll_write(
217        self: Pin<&mut Self>,
218        cx: &mut task::Context<'_>,
219        buf: &[u8],
220    ) -> Poll<Result<usize, std::io::Error>> {
221        match self.project().inner.project() {
222            ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => {
223                conn.poll_write(cx, buf)
224            }
225            ConnectionProj::Http { conn } => conn.poll_write(cx, buf),
226        }
227    }
228
229    fn poll_flush(
230        self: Pin<&mut Self>,
231        cx: &mut task::Context<'_>,
232    ) -> Poll<Result<(), std::io::Error>> {
233        match self.project().inner.project() {
234            ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => conn.poll_flush(cx),
235            ConnectionProj::Http { conn } => conn.poll_flush(cx),
236        }
237    }
238
239    fn poll_shutdown(
240        self: Pin<&mut Self>,
241        cx: &mut task::Context<'_>,
242    ) -> Poll<Result<(), std::io::Error>> {
243        match self.project().inner.project() {
244            ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => {
245                conn.poll_shutdown(cx)
246            }
247            ConnectionProj::Http { conn } => conn.poll_shutdown(cx),
248        }
249    }
250
251    fn is_write_vectored(&self) -> bool {
252        match &self.inner {
253            Connection::Direct { conn } | Connection::Socks { conn } => conn.is_write_vectored(),
254            Connection::Http { conn } => conn.is_write_vectored(),
255        }
256    }
257
258    fn poll_write_vectored(
259        self: Pin<&mut Self>,
260        cx: &mut task::Context<'_>,
261        bufs: &[std::io::IoSlice<'_>],
262    ) -> Poll<Result<usize, std::io::Error>> {
263        match self.project().inner.project() {
264            ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => {
265                conn.poll_write_vectored(cx, bufs)
266            }
267            ConnectionProj::Http { conn } => conn.poll_write_vectored(cx, bufs),
268        }
269    }
270}
271
272#[warn(clippy::missing_trait_methods)]
273impl<Conn: AsyncRead> AsyncRead for MaybeProxiedConnection<Conn> {
274    fn poll_read(
275        self: Pin<&mut Self>,
276        cx: &mut task::Context<'_>,
277        buf: &mut tokio::io::ReadBuf<'_>,
278    ) -> Poll<std::io::Result<()>> {
279        match self.project().inner.project() {
280            ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => {
281                conn.poll_read(cx, buf)
282            }
283            ConnectionProj::Http { conn } => conn.poll_read(cx, buf),
284        }
285    }
286}
287
288/// Proxy connector layer which supports http(s) and socks5(h) proxy address
289///
290/// Connector will look at [`ProxyAddress`] to determine which proxy
291/// connector to use if one is configured
292pub struct ProxyConnectorLayer {
293    socks_layer: Socks5ProxyConnectorLayer,
294    http_layer: HttpProxyConnectorLayer,
295    required: bool,
296}
297
298impl ProxyConnectorLayer {
299    #[must_use]
300    /// Creates a new required [`ProxyConnectorLayer`].
301    ///
302    /// This connector will fail if no [`ProxyAddress`] is configured
303    pub fn required(
304        socks_proxy_layer: Socks5ProxyConnectorLayer,
305        http_proxy_layer: HttpProxyConnectorLayer,
306    ) -> Self {
307        Self {
308            socks_layer: socks_proxy_layer,
309            http_layer: http_proxy_layer,
310            required: true,
311        }
312    }
313
314    #[must_use]
315    /// Creates a new optional [`ProxyConnectorLayer`].
316    ///
317    /// This connector will forward to the inner connector if no [`ProxyAddress`] is configured
318    pub fn optional(
319        socks_proxy_layer: Socks5ProxyConnectorLayer,
320        http_proxy_layer: HttpProxyConnectorLayer,
321    ) -> Self {
322        Self {
323            socks_layer: socks_proxy_layer,
324            http_layer: http_proxy_layer,
325            required: false,
326        }
327    }
328}
329
330impl<S> Layer<S> for ProxyConnectorLayer {
331    type Service = ProxyConnector<Arc<S>>;
332
333    fn layer(&self, inner: S) -> Self::Service {
334        ProxyConnector::new(
335            inner,
336            self.socks_layer.clone(),
337            self.http_layer.clone(),
338            self.required,
339        )
340    }
341
342    fn into_layer(self, inner: S) -> Self::Service {
343        ProxyConnector::new(inner, self.socks_layer, self.http_layer, self.required)
344    }
345}