1use crate::{
2 Layer, Service,
3 error::{BoxError, OpaqueError},
4 http::client::proxy::layer::{
5 HttpProxyConnector, HttpProxyConnectorLayer, MaybeHttpProxiedConnection,
6 },
7 net::{
8 Protocol,
9 address::ProxyAddress,
10 client::{ConnectorService, EstablishedClientConnection},
11 stream::Stream,
12 transport::TryRefIntoTransportContext,
13 },
14 proxy::socks5::{Socks5ProxyConnector, Socks5ProxyConnectorLayer},
15 telemetry::tracing,
16};
17use pin_project_lite::pin_project;
18use std::{
19 fmt::Debug,
20 pin::Pin,
21 sync::Arc,
22 task::{self, Poll},
23};
24use tokio::io::{AsyncRead, AsyncWrite};
25
26pub struct ProxyConnector<S> {
31 inner: S,
32 socks: Socks5ProxyConnector<S>,
33 http: HttpProxyConnector<S>,
34 required: bool,
35}
36
37impl<S> ProxyConnector<S> {
38 fn new(
40 inner: S,
41 socks_proxy_layer: Socks5ProxyConnectorLayer,
42 http_proxy_layer: HttpProxyConnectorLayer,
43 required: bool,
44 ) -> ProxyConnector<Arc<S>> {
45 let inner = Arc::new(inner);
46 ProxyConnector {
47 socks: socks_proxy_layer.into_layer(inner.clone()),
48 http: http_proxy_layer.into_layer(inner.clone()),
49 inner,
50 required,
51 }
52 }
53
54 #[inline]
55 pub fn required(
59 inner: S,
60 socks_proxy_layer: Socks5ProxyConnectorLayer,
61 http_proxy_layer: HttpProxyConnectorLayer,
62 ) -> ProxyConnector<Arc<S>> {
63 Self::new(inner, socks_proxy_layer, http_proxy_layer, true)
64 }
65
66 #[inline]
67 pub fn optional(
71 inner: S,
72 socks_proxy_layer: Socks5ProxyConnectorLayer,
73 http_proxy_layer: HttpProxyConnectorLayer,
74 ) -> ProxyConnector<Arc<S>> {
75 Self::new(inner, socks_proxy_layer, http_proxy_layer, false)
76 }
77}
78
79impl<Request, S> Service<Request> for ProxyConnector<S>
80where
81 S: ConnectorService<Request, Connection: Stream + Unpin, Error: Into<BoxError>>,
82 Request: TryRefIntoTransportContext<Error: Into<BoxError> + Send + 'static> + Send + 'static,
83{
84 type Response = EstablishedClientConnection<MaybeProxiedConnection<S::Connection>, Request>;
85
86 type Error = BoxError;
87
88 async fn serve(
89 &self,
90 ctx: rama_core::Context,
91 req: Request,
92 ) -> Result<Self::Response, Self::Error> {
93 let proxy = ctx.get::<ProxyAddress>();
94
95 match proxy {
96 None => {
97 if self.required {
98 return Err("proxy required but none is defined".into());
99 }
100 tracing::trace!("no proxy detected in ctx, using inner connector");
101 let EstablishedClientConnection { ctx, req, conn } =
102 self.inner.connect(ctx, req).await.map_err(Into::into)?;
103 Ok(EstablishedClientConnection {
104 ctx,
105 req,
106 conn: MaybeProxiedConnection {
107 inner: Connection::Direct { conn },
108 },
109 })
110 }
111 Some(proxy) => {
112 let protocol = proxy.protocol.as_ref();
113 tracing::trace!(?protocol, "proxy detected in ctx");
114
115 let protocol = protocol.unwrap_or_else(|| {
116 tracing::trace!("no protocol detected, using http as protocol");
117 &Protocol::HTTP
118 });
119
120 if protocol.is_socks5() {
121 tracing::trace!("using socks proxy connector");
122 let EstablishedClientConnection { ctx, req, conn } =
123 self.socks.connect(ctx, req).await?;
124 Ok(EstablishedClientConnection {
125 ctx,
126 req,
127 conn: MaybeProxiedConnection {
128 inner: Connection::Socks { conn },
129 },
130 })
131 } else if protocol.is_http() {
132 tracing::trace!("using http proxy connector");
133 let EstablishedClientConnection { ctx, req, conn } =
134 self.http.connect(ctx, req).await?;
135 Ok(EstablishedClientConnection {
136 ctx,
137 req,
138 conn: MaybeProxiedConnection {
139 inner: Connection::Http { conn },
140 },
141 })
142 } else {
143 Err(OpaqueError::from_display(format!(
144 "received unsupport proxy protocol {protocol:?}"
145 ))
146 .into_boxed())
147 }
148 }
149 }
150 }
151}
152
153pin_project! {
154 pub struct MaybeProxiedConnection<S> {
156 #[pin]
157 inner: Connection<S>,
158 }
159}
160
161impl<S: Debug> Debug for MaybeProxiedConnection<S> {
162 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
163 f.debug_struct("MaybeProxiedConnection")
164 .field("inner", &self.inner)
165 .finish()
166 }
167}
168
169pin_project! {
170 #[project = ConnectionProj]
171 enum Connection<S> {
172 Direct{ #[pin] conn: S },
173 Socks{ #[pin] conn: S },
174 Http{ #[pin] conn: MaybeHttpProxiedConnection<S> },
175
176 }
177}
178
179impl<S: Debug> Debug for Connection<S> {
180 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
181 match self {
182 Self::Direct { conn } => f.debug_struct("Direct").field("conn", conn).finish(),
183 Self::Socks { conn } => f.debug_struct("Socks").field("conn", conn).finish(),
184 Self::Http { conn } => f.debug_struct("Http").field("conn", conn).finish(),
185 }
186 }
187}
188
189impl<Conn: AsyncWrite> AsyncWrite for MaybeProxiedConnection<Conn> {
190 fn poll_write(
191 self: Pin<&mut Self>,
192 cx: &mut task::Context<'_>,
193 buf: &[u8],
194 ) -> Poll<Result<usize, std::io::Error>> {
195 match self.project().inner.project() {
196 ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => {
197 conn.poll_write(cx, buf)
198 }
199 ConnectionProj::Http { conn } => conn.poll_write(cx, buf),
200 }
201 }
202
203 fn poll_flush(
204 self: Pin<&mut Self>,
205 cx: &mut task::Context<'_>,
206 ) -> Poll<Result<(), std::io::Error>> {
207 match self.project().inner.project() {
208 ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => conn.poll_flush(cx),
209 ConnectionProj::Http { conn } => conn.poll_flush(cx),
210 }
211 }
212
213 fn poll_shutdown(
214 self: Pin<&mut Self>,
215 cx: &mut task::Context<'_>,
216 ) -> Poll<Result<(), std::io::Error>> {
217 match self.project().inner.project() {
218 ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => {
219 conn.poll_shutdown(cx)
220 }
221 ConnectionProj::Http { conn } => conn.poll_shutdown(cx),
222 }
223 }
224
225 fn is_write_vectored(&self) -> bool {
226 match &self.inner {
227 Connection::Direct { conn } | Connection::Socks { conn } => conn.is_write_vectored(),
228 Connection::Http { conn } => conn.is_write_vectored(),
229 }
230 }
231
232 fn poll_write_vectored(
233 self: Pin<&mut Self>,
234 cx: &mut task::Context<'_>,
235 bufs: &[std::io::IoSlice<'_>],
236 ) -> Poll<Result<usize, std::io::Error>> {
237 match self.project().inner.project() {
238 ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => {
239 conn.poll_write_vectored(cx, bufs)
240 }
241 ConnectionProj::Http { conn } => conn.poll_write_vectored(cx, bufs),
242 }
243 }
244}
245
246impl<Conn: AsyncRead> AsyncRead for MaybeProxiedConnection<Conn> {
247 fn poll_read(
248 self: Pin<&mut Self>,
249 cx: &mut task::Context<'_>,
250 buf: &mut tokio::io::ReadBuf<'_>,
251 ) -> Poll<std::io::Result<()>> {
252 match self.project().inner.project() {
253 ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => {
254 conn.poll_read(cx, buf)
255 }
256 ConnectionProj::Http { conn } => conn.poll_read(cx, buf),
257 }
258 }
259}
260
261pub struct ProxyConnectorLayer {
266 socks_layer: Socks5ProxyConnectorLayer,
267 http_layer: HttpProxyConnectorLayer,
268 required: bool,
269}
270
271impl ProxyConnectorLayer {
272 #[must_use]
273 pub fn required(
277 socks_proxy_layer: Socks5ProxyConnectorLayer,
278 http_proxy_layer: HttpProxyConnectorLayer,
279 ) -> Self {
280 Self {
281 socks_layer: socks_proxy_layer,
282 http_layer: http_proxy_layer,
283 required: true,
284 }
285 }
286
287 #[must_use]
288 pub fn optional(
292 socks_proxy_layer: Socks5ProxyConnectorLayer,
293 http_proxy_layer: HttpProxyConnectorLayer,
294 ) -> Self {
295 Self {
296 socks_layer: socks_proxy_layer,
297 http_layer: http_proxy_layer,
298 required: false,
299 }
300 }
301}
302
303impl<S> Layer<S> for ProxyConnectorLayer {
304 type Service = ProxyConnector<Arc<S>>;
305
306 fn layer(&self, inner: S) -> Self::Service {
307 ProxyConnector::new(
308 inner,
309 self.socks_layer.clone(),
310 self.http_layer.clone(),
311 self.required,
312 )
313 }
314
315 fn into_layer(self, inner: S) -> Self::Service {
316 ProxyConnector::new(inner, self.socks_layer, self.http_layer, self.required)
317 }
318}