1use crate::{
2 Layer, Service,
3 error::BoxError,
4 extensions::{Extensions, ExtensionsMut, ExtensionsRef},
5 http::client::proxy::layer::{
6 HttpProxyConnector, HttpProxyConnectorLayer, MaybeHttpProxiedConnection,
7 },
8 net::{
9 Protocol,
10 address::ProxyAddress,
11 client::{ConnectorService, EstablishedClientConnection},
12 transport::TryRefIntoTransportContext,
13 },
14 proxy::socks5::{Socks5ProxyConnector, Socks5ProxyConnectorLayer},
15 stream::Stream,
16 telemetry::tracing,
17};
18use pin_project_lite::pin_project;
19use rama_core::error::{ErrorContext as _, ErrorExt as _};
20use std::{
21 fmt::Debug,
22 pin::Pin,
23 task::{self, Poll},
24};
25use tokio::io::{AsyncRead, AsyncWrite};
26
/// Connector that either connects directly through the inner connector or
/// goes through a socks5 / http proxy connector, depending on the
/// [`ProxyAddress`] (if any) found in the input extensions.
#[derive(Debug, Clone)]
pub struct ProxyConnector<S> {
    /// Connector used as-is when no [`ProxyAddress`] is present.
    inner: S,
    /// Socks5 proxy connector built on top of a clone of `inner`.
    socks: Socks5ProxyConnector<S>,
    /// Http proxy connector built on top of a clone of `inner`.
    http: HttpProxyConnector<S>,
    /// When `true`, serving an input without a [`ProxyAddress`] is an error
    /// instead of falling back to a direct connection.
    required: bool,
}
38
39impl<S: Clone> ProxyConnector<S> {
40 fn new(
42 inner: S,
43 socks_proxy_layer: Socks5ProxyConnectorLayer,
44 http_proxy_layer: HttpProxyConnectorLayer,
45 required: bool,
46 ) -> Self {
47 Self {
48 socks: socks_proxy_layer.into_layer(inner.clone()),
49 http: http_proxy_layer.into_layer(inner.clone()),
50 inner,
51 required,
52 }
53 }
54
55 #[inline]
56 pub fn required(
60 inner: S,
61 socks_proxy_layer: Socks5ProxyConnectorLayer,
62 http_proxy_layer: HttpProxyConnectorLayer,
63 ) -> Self {
64 Self::new(inner, socks_proxy_layer, http_proxy_layer, true)
65 }
66
67 #[inline]
68 pub fn optional(
72 inner: S,
73 socks_proxy_layer: Socks5ProxyConnectorLayer,
74 http_proxy_layer: HttpProxyConnectorLayer,
75 ) -> Self {
76 Self::new(inner, socks_proxy_layer, http_proxy_layer, false)
77 }
78}
79
80impl<Input, S> Service<Input> for ProxyConnector<S>
81where
82 S: ConnectorService<Input, Connection: Stream + Unpin>,
83 Input: TryRefIntoTransportContext<Error: Into<BoxError> + Send + 'static>
84 + Send
85 + ExtensionsMut
86 + 'static,
87{
88 type Output = EstablishedClientConnection<MaybeProxiedConnection<S::Connection>, Input>;
89 type Error = BoxError;
90
91 async fn serve(&self, input: Input) -> Result<Self::Output, Self::Error> {
92 let proxy = input.extensions().get::<ProxyAddress>();
93
94 match proxy {
95 None => {
96 if self.required {
97 return Err("proxy required but none is defined".into());
98 }
99 tracing::trace!("no proxy detected in ctx, using inner connector");
100 let EstablishedClientConnection { input, conn } =
101 self.inner.connect(input).await.into_box_error()?;
102
103 let conn = MaybeProxiedConnection::direct(conn);
104 Ok(EstablishedClientConnection { input, conn })
105 }
106 Some(proxy) => {
107 let protocol = proxy.protocol.as_ref();
108 tracing::trace!(?protocol, "proxy detected in ctx");
109
110 let protocol = protocol.unwrap_or_else(|| {
111 tracing::trace!("no protocol detected, using http as protocol");
112 &Protocol::HTTP
113 });
114
115 if protocol.is_socks5() {
116 tracing::trace!("using socks proxy connector");
117 let EstablishedClientConnection { input, conn } =
118 self.socks.connect(input).await?;
119
120 let conn = MaybeProxiedConnection::socks(conn);
121 Ok(EstablishedClientConnection { input, conn })
122 } else if protocol.is_http() {
123 tracing::trace!("using http proxy connector");
124 let EstablishedClientConnection { input, conn } =
125 self.http.connect(input).await?;
126
127 let conn = MaybeProxiedConnection::http(conn);
128 Ok(EstablishedClientConnection { input, conn })
129 } else {
130 Err(BoxError::from("received unsupport proxy protocol")
131 .with_context_debug_field("protocol", || protocol.clone()))
132 }
133 }
134 }
135 }
136}
137
pin_project! {
    // Connection produced by `ProxyConnector`. It wraps either a direct
    // connection or one established through the socks5 or http proxy
    // connector. I/O and extension access are forwarded to the wrapped
    // connection by the trait impls in this module.
    pub struct MaybeProxiedConnection<S> {
        // Which kind of connection is wrapped. It is pinned so the
        // AsyncRead/AsyncWrite impls can project through it.
        #[pin]
        inner: Connection<S>,
    }
}
145
146impl<S: ExtensionsMut> MaybeProxiedConnection<S> {
147 pub fn direct(conn: S) -> Self {
148 Self {
149 inner: Connection::Direct { conn },
150 }
151 }
152
153 pub fn socks(conn: S) -> Self {
154 Self {
155 inner: Connection::Socks { conn },
156 }
157 }
158
159 pub fn http(conn: MaybeHttpProxiedConnection<S>) -> Self {
160 Self {
161 inner: Connection::Http { conn },
162 }
163 }
164}
165
166impl<S: Debug> Debug for MaybeProxiedConnection<S> {
167 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
168 f.debug_struct("MaybeProxiedConnection")
169 .field("inner", &self.inner)
170 .finish()
171 }
172}
173
174impl<S: ExtensionsRef> ExtensionsRef for MaybeProxiedConnection<S> {
175 fn extensions(&self) -> &Extensions {
176 match &self.inner {
177 Connection::Direct { conn } | Connection::Socks { conn } => conn.extensions(),
178 Connection::Http { conn } => conn.extensions(),
179 }
180 }
181}
182
183impl<S: ExtensionsMut> ExtensionsMut for MaybeProxiedConnection<S> {
184 fn extensions_mut(&mut self) -> &mut Extensions {
185 match &mut self.inner {
186 Connection::Direct { conn } | Connection::Socks { conn } => conn.extensions_mut(),
187 Connection::Http { conn } => conn.extensions_mut(),
188 }
189 }
190}
191
pin_project! {
    // Internal representation of `MaybeProxiedConnection`, recording how the
    // connection was established. `ConnectionProj` is the generated pinned
    // projection used by the AsyncRead/AsyncWrite impls.
    #[project = ConnectionProj]
    enum Connection<S> {
        // Built by `MaybeProxiedConnection::direct` (no proxy involved).
        Direct{ #[pin] conn: S },
        // Built by `MaybeProxiedConnection::socks` (socks5 proxy connector).
        Socks{ #[pin] conn: S },
        // Built by `MaybeProxiedConnection::http` (http proxy connector).
        Http{ #[pin] conn: MaybeHttpProxiedConnection<S> },
    }
}
201
202impl<S: Debug> Debug for Connection<S> {
203 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
204 match self {
205 Self::Direct { conn } => f.debug_struct("Direct").field("conn", conn).finish(),
206 Self::Socks { conn } => f.debug_struct("Socks").field("conn", conn).finish(),
207 Self::Http { conn } => f.debug_struct("Http").field("conn", conn).finish(),
208 }
209 }
210}
211
212#[warn(clippy::missing_trait_methods)]
213impl<Conn: AsyncWrite> AsyncWrite for MaybeProxiedConnection<Conn> {
214 fn poll_write(
215 self: Pin<&mut Self>,
216 cx: &mut task::Context<'_>,
217 buf: &[u8],
218 ) -> Poll<Result<usize, std::io::Error>> {
219 match self.project().inner.project() {
220 ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => {
221 conn.poll_write(cx, buf)
222 }
223 ConnectionProj::Http { conn } => conn.poll_write(cx, buf),
224 }
225 }
226
227 fn poll_flush(
228 self: Pin<&mut Self>,
229 cx: &mut task::Context<'_>,
230 ) -> Poll<Result<(), std::io::Error>> {
231 match self.project().inner.project() {
232 ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => conn.poll_flush(cx),
233 ConnectionProj::Http { conn } => conn.poll_flush(cx),
234 }
235 }
236
237 fn poll_shutdown(
238 self: Pin<&mut Self>,
239 cx: &mut task::Context<'_>,
240 ) -> Poll<Result<(), std::io::Error>> {
241 match self.project().inner.project() {
242 ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => {
243 conn.poll_shutdown(cx)
244 }
245 ConnectionProj::Http { conn } => conn.poll_shutdown(cx),
246 }
247 }
248
249 fn is_write_vectored(&self) -> bool {
250 match &self.inner {
251 Connection::Direct { conn } | Connection::Socks { conn } => conn.is_write_vectored(),
252 Connection::Http { conn } => conn.is_write_vectored(),
253 }
254 }
255
256 fn poll_write_vectored(
257 self: Pin<&mut Self>,
258 cx: &mut task::Context<'_>,
259 bufs: &[std::io::IoSlice<'_>],
260 ) -> Poll<Result<usize, std::io::Error>> {
261 match self.project().inner.project() {
262 ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => {
263 conn.poll_write_vectored(cx, bufs)
264 }
265 ConnectionProj::Http { conn } => conn.poll_write_vectored(cx, bufs),
266 }
267 }
268}
269
270#[warn(clippy::missing_trait_methods)]
271impl<Conn: AsyncRead> AsyncRead for MaybeProxiedConnection<Conn> {
272 fn poll_read(
273 self: Pin<&mut Self>,
274 cx: &mut task::Context<'_>,
275 buf: &mut tokio::io::ReadBuf<'_>,
276 ) -> Poll<std::io::Result<()>> {
277 match self.project().inner.project() {
278 ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => {
279 conn.poll_read(cx, buf)
280 }
281 ConnectionProj::Http { conn } => conn.poll_read(cx, buf),
282 }
283 }
284}
285
286pub struct ProxyConnectorLayer {
291 socks_layer: Socks5ProxyConnectorLayer,
292 http_layer: HttpProxyConnectorLayer,
293 required: bool,
294}
295
296impl ProxyConnectorLayer {
297 #[must_use]
298 pub fn required(
302 socks_proxy_layer: Socks5ProxyConnectorLayer,
303 http_proxy_layer: HttpProxyConnectorLayer,
304 ) -> Self {
305 Self {
306 socks_layer: socks_proxy_layer,
307 http_layer: http_proxy_layer,
308 required: true,
309 }
310 }
311
312 #[must_use]
313 pub fn optional(
317 socks_proxy_layer: Socks5ProxyConnectorLayer,
318 http_proxy_layer: HttpProxyConnectorLayer,
319 ) -> Self {
320 Self {
321 socks_layer: socks_proxy_layer,
322 http_layer: http_proxy_layer,
323 required: false,
324 }
325 }
326}
327
328impl<S: Clone> Layer<S> for ProxyConnectorLayer {
329 type Service = ProxyConnector<S>;
330
331 fn layer(&self, inner: S) -> Self::Service {
332 ProxyConnector::new(
333 inner,
334 self.socks_layer.clone(),
335 self.http_layer.clone(),
336 self.required,
337 )
338 }
339
340 fn into_layer(self, inner: S) -> Self::Service {
341 ProxyConnector::new(inner, self.socks_layer, self.http_layer, self.required)
342 }
343}