1use crate::{
2 Layer, Service,
3 error::BoxError,
4 extensions::{Extensions, ExtensionsRef},
5 http::client::proxy::layer::{
6 HttpProxyConnector, HttpProxyConnectorLayer, MaybeHttpProxiedConnection,
7 },
8 io::Io,
9 net::{
10 Protocol,
11 address::ProxyAddress,
12 client::{ConnectorService, EstablishedClientConnection},
13 transport::TryRefIntoTransportContext,
14 },
15 proxy::socks5::{Socks5ProxyConnector, Socks5ProxyConnectorLayer},
16 telemetry::tracing,
17};
18use pin_project_lite::pin_project;
19use rama_core::error::{ErrorContext as _, ErrorExt as _, extra::OpaqueError};
20use std::{
21 fmt::Debug,
22 pin::Pin,
23 task::{self, Poll},
24};
25use tokio::io::{AsyncRead, AsyncWrite};
26
#[derive(Debug, Clone)]
/// Connector service that connects either directly through `inner`, or through
/// a SOCKS5 or HTTP proxy, depending on the [`ProxyAddress`] found in the
/// input's extensions.
///
/// Built via [`ProxyConnector::required`] / [`ProxyConnector::optional`] or
/// through [`ProxyConnectorLayer`].
pub struct ProxyConnector<S> {
    /// Connector used when no proxy is present.
    inner: S,
    /// A clone of `inner` wrapped by the SOCKS5 proxy connector layer.
    socks: Socks5ProxyConnector<S>,
    /// A clone of `inner` wrapped by the HTTP proxy connector layer.
    http: HttpProxyConnector<S>,
    /// When `true`, `serve` fails if no [`ProxyAddress`] is present.
    required: bool,
}
38
39impl<S: Clone> ProxyConnector<S> {
40 fn new(
42 inner: S,
43 socks_proxy_layer: Socks5ProxyConnectorLayer,
44 http_proxy_layer: HttpProxyConnectorLayer,
45 required: bool,
46 ) -> Self {
47 Self {
48 socks: socks_proxy_layer.into_layer(inner.clone()),
49 http: http_proxy_layer.into_layer(inner.clone()),
50 inner,
51 required,
52 }
53 }
54
55 #[inline]
56 pub fn required(
60 inner: S,
61 socks_proxy_layer: Socks5ProxyConnectorLayer,
62 http_proxy_layer: HttpProxyConnectorLayer,
63 ) -> Self {
64 Self::new(inner, socks_proxy_layer, http_proxy_layer, true)
65 }
66
67 #[inline]
68 pub fn optional(
72 inner: S,
73 socks_proxy_layer: Socks5ProxyConnectorLayer,
74 http_proxy_layer: HttpProxyConnectorLayer,
75 ) -> Self {
76 Self::new(inner, socks_proxy_layer, http_proxy_layer, false)
77 }
78}
79
80impl<Input, S> Service<Input> for ProxyConnector<S>
81where
82 S: ConnectorService<Input, Connection: Io + Unpin>,
83 Input: TryRefIntoTransportContext<Error: Into<BoxError> + Send + 'static>
84 + Send
85 + ExtensionsRef
86 + 'static,
87{
88 type Output = EstablishedClientConnection<MaybeProxiedConnection<S::Connection>, Input>;
89 type Error = BoxError;
90
91 async fn serve(&self, input: Input) -> Result<Self::Output, Self::Error> {
92 let proxy = input.extensions().get_ref::<ProxyAddress>();
93
94 match proxy {
95 None => {
96 if self.required {
97 return Err(
98 OpaqueError::from_static_str("proxy required but none is defined")
99 .into_box_error(),
100 );
101 }
102 tracing::trace!("no proxy detected in ctx, using inner connector");
103 let EstablishedClientConnection { input, conn } =
104 self.inner.connect(input).await.into_box_error()?;
105
106 let conn = MaybeProxiedConnection::direct(conn);
107 Ok(EstablishedClientConnection { input, conn })
108 }
109 Some(proxy) => {
110 let protocol = proxy.protocol.as_ref();
111 tracing::trace!(?protocol, "proxy detected in ctx");
112
113 let protocol = protocol.unwrap_or_else(|| {
114 tracing::trace!("no protocol detected, using http as protocol");
115 &Protocol::HTTP
116 });
117
118 if protocol.is_socks5() {
119 tracing::trace!("using socks proxy connector");
120 let EstablishedClientConnection { input, conn } =
121 self.socks.connect(input).await?;
122
123 let conn = MaybeProxiedConnection::socks(conn);
124 Ok(EstablishedClientConnection { input, conn })
125 } else if protocol.is_http() {
126 tracing::trace!("using http proxy connector");
127 let EstablishedClientConnection { input, conn } =
128 self.http.connect(input).await?;
129
130 let conn = MaybeProxiedConnection::http(conn);
131 Ok(EstablishedClientConnection { input, conn })
132 } else {
133 Err(
134 OpaqueError::from_static_str("received unsupport proxy protocol")
135 .with_context_debug_field("protocol", || protocol.clone()),
136 )
137 }
138 }
139 }
140 }
141}
142
pin_project! {
    /// Connection returned by [`ProxyConnector`]: a direct connection, one
    /// established via the SOCKS5 proxy connector, or one produced by the HTTP
    /// proxy connector. IO and extensions are forwarded to the wrapped connection.
    pub struct MaybeProxiedConnection<S> {
        // Which kind of connection is wrapped; pinned so the AsyncRead /
        // AsyncWrite impls can project into it.
        #[pin]
        inner: Connection<S>,
    }
}
150
151impl<S: ExtensionsRef> MaybeProxiedConnection<S> {
152 pub fn direct(conn: S) -> Self {
153 Self {
154 inner: Connection::Direct { conn },
155 }
156 }
157
158 pub fn socks(conn: S) -> Self {
159 Self {
160 inner: Connection::Socks { conn },
161 }
162 }
163
164 pub fn http(conn: MaybeHttpProxiedConnection<S>) -> Self {
165 Self {
166 inner: Connection::Http { conn },
167 }
168 }
169}
170
171impl<S: Debug> Debug for MaybeProxiedConnection<S> {
172 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
173 f.debug_struct("MaybeProxiedConnection")
174 .field("inner", &self.inner)
175 .finish()
176 }
177}
178
179impl<S: ExtensionsRef> ExtensionsRef for MaybeProxiedConnection<S> {
180 fn extensions(&self) -> &Extensions {
181 match &self.inner {
182 Connection::Direct { conn } | Connection::Socks { conn } => conn.extensions(),
183 Connection::Http { conn } => conn.extensions(),
184 }
185 }
186}
187
pin_project! {
    // Tagged connection stored inside `MaybeProxiedConnection`; the variant
    // records how the connection was established. `ConnectionProj` is the
    // pinned projection used by the AsyncRead / AsyncWrite impls.
    #[project = ConnectionProj]
    enum Connection<S> {
        // Established directly by the inner connector (no proxy).
        Direct{ #[pin] conn: S },
        // Established through the SOCKS5 proxy connector.
        Socks{ #[pin] conn: S },
        // Produced by the HTTP proxy connector.
        Http{ #[pin] conn: MaybeHttpProxiedConnection<S> },

    }
}
197
198impl<S: Debug> Debug for Connection<S> {
199 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
200 match self {
201 Self::Direct { conn } => f.debug_struct("Direct").field("conn", conn).finish(),
202 Self::Socks { conn } => f.debug_struct("Socks").field("conn", conn).finish(),
203 Self::Http { conn } => f.debug_struct("Http").field("conn", conn).finish(),
204 }
205 }
206}
207
208#[warn(clippy::missing_trait_methods)]
209impl<Conn: AsyncWrite> AsyncWrite for MaybeProxiedConnection<Conn> {
210 fn poll_write(
211 self: Pin<&mut Self>,
212 cx: &mut task::Context<'_>,
213 buf: &[u8],
214 ) -> Poll<Result<usize, std::io::Error>> {
215 match self.project().inner.project() {
216 ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => {
217 conn.poll_write(cx, buf)
218 }
219 ConnectionProj::Http { conn } => conn.poll_write(cx, buf),
220 }
221 }
222
223 fn poll_flush(
224 self: Pin<&mut Self>,
225 cx: &mut task::Context<'_>,
226 ) -> Poll<Result<(), std::io::Error>> {
227 match self.project().inner.project() {
228 ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => conn.poll_flush(cx),
229 ConnectionProj::Http { conn } => conn.poll_flush(cx),
230 }
231 }
232
233 fn poll_shutdown(
234 self: Pin<&mut Self>,
235 cx: &mut task::Context<'_>,
236 ) -> Poll<Result<(), std::io::Error>> {
237 match self.project().inner.project() {
238 ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => {
239 conn.poll_shutdown(cx)
240 }
241 ConnectionProj::Http { conn } => conn.poll_shutdown(cx),
242 }
243 }
244
245 fn is_write_vectored(&self) -> bool {
246 match &self.inner {
247 Connection::Direct { conn } | Connection::Socks { conn } => conn.is_write_vectored(),
248 Connection::Http { conn } => conn.is_write_vectored(),
249 }
250 }
251
252 fn poll_write_vectored(
253 self: Pin<&mut Self>,
254 cx: &mut task::Context<'_>,
255 bufs: &[std::io::IoSlice<'_>],
256 ) -> Poll<Result<usize, std::io::Error>> {
257 match self.project().inner.project() {
258 ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => {
259 conn.poll_write_vectored(cx, bufs)
260 }
261 ConnectionProj::Http { conn } => conn.poll_write_vectored(cx, bufs),
262 }
263 }
264}
265
266#[warn(clippy::missing_trait_methods)]
267impl<Conn: AsyncRead> AsyncRead for MaybeProxiedConnection<Conn> {
268 fn poll_read(
269 self: Pin<&mut Self>,
270 cx: &mut task::Context<'_>,
271 buf: &mut tokio::io::ReadBuf<'_>,
272 ) -> Poll<std::io::Result<()>> {
273 match self.project().inner.project() {
274 ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => {
275 conn.poll_read(cx, buf)
276 }
277 ConnectionProj::Http { conn } => conn.poll_read(cx, buf),
278 }
279 }
280}
281
282pub struct ProxyConnectorLayer {
287 socks_layer: Socks5ProxyConnectorLayer,
288 http_layer: HttpProxyConnectorLayer,
289 required: bool,
290}
291
292impl ProxyConnectorLayer {
293 #[must_use]
294 pub fn required(
298 socks_proxy_layer: Socks5ProxyConnectorLayer,
299 http_proxy_layer: HttpProxyConnectorLayer,
300 ) -> Self {
301 Self {
302 socks_layer: socks_proxy_layer,
303 http_layer: http_proxy_layer,
304 required: true,
305 }
306 }
307
308 #[must_use]
309 pub fn optional(
313 socks_proxy_layer: Socks5ProxyConnectorLayer,
314 http_proxy_layer: HttpProxyConnectorLayer,
315 ) -> Self {
316 Self {
317 socks_layer: socks_proxy_layer,
318 http_layer: http_proxy_layer,
319 required: false,
320 }
321 }
322}
323
324impl<S: Clone> Layer<S> for ProxyConnectorLayer {
325 type Service = ProxyConnector<S>;
326
327 fn layer(&self, inner: S) -> Self::Service {
328 ProxyConnector::new(
329 inner,
330 self.socks_layer.clone(),
331 self.http_layer.clone(),
332 self.required,
333 )
334 }
335
336 fn into_layer(self, inner: S) -> Self::Service {
337 ProxyConnector::new(inner, self.socks_layer, self.http_layer, self.required)
338 }
339}