1use crate::{
2 Layer, Service,
3 error::{BoxError, OpaqueError},
4 extensions::{Extensions, ExtensionsMut, ExtensionsRef},
5 http::client::proxy::layer::{
6 HttpProxyConnector, HttpProxyConnectorLayer, MaybeHttpProxiedConnection,
7 },
8 net::{
9 Protocol,
10 address::ProxyAddress,
11 client::{ConnectorService, EstablishedClientConnection},
12 transport::TryRefIntoTransportContext,
13 },
14 proxy::socks5::{Socks5ProxyConnector, Socks5ProxyConnectorLayer},
15 stream::Stream,
16 telemetry::tracing,
17};
18use pin_project_lite::pin_project;
19use std::{
20 fmt::Debug,
21 pin::Pin,
22 sync::Arc,
23 task::{self, Poll},
24};
25use tokio::io::{AsyncRead, AsyncWrite};
26
27pub struct ProxyConnector<S> {
32 inner: S,
33 socks: Socks5ProxyConnector<S>,
34 http: HttpProxyConnector<S>,
35 required: bool,
36}
37
38impl<S> ProxyConnector<S> {
39 fn new(
41 inner: S,
42 socks_proxy_layer: Socks5ProxyConnectorLayer,
43 http_proxy_layer: HttpProxyConnectorLayer,
44 required: bool,
45 ) -> ProxyConnector<Arc<S>> {
46 let inner = Arc::new(inner);
47 ProxyConnector {
48 socks: socks_proxy_layer.into_layer(inner.clone()),
49 http: http_proxy_layer.into_layer(inner.clone()),
50 inner,
51 required,
52 }
53 }
54
55 #[inline]
56 pub fn required(
60 inner: S,
61 socks_proxy_layer: Socks5ProxyConnectorLayer,
62 http_proxy_layer: HttpProxyConnectorLayer,
63 ) -> ProxyConnector<Arc<S>> {
64 Self::new(inner, socks_proxy_layer, http_proxy_layer, true)
65 }
66
67 #[inline]
68 pub fn optional(
72 inner: S,
73 socks_proxy_layer: Socks5ProxyConnectorLayer,
74 http_proxy_layer: HttpProxyConnectorLayer,
75 ) -> ProxyConnector<Arc<S>> {
76 Self::new(inner, socks_proxy_layer, http_proxy_layer, false)
77 }
78}
79
80impl<Request, S> Service<Request> for ProxyConnector<S>
81where
82 S: ConnectorService<Request, Connection: Stream + Unpin>,
83 Request: TryRefIntoTransportContext<Error: Into<BoxError> + Send + 'static>
84 + Send
85 + ExtensionsMut
86 + 'static,
87{
88 type Response = EstablishedClientConnection<MaybeProxiedConnection<S::Connection>, Request>;
89
90 type Error = BoxError;
91
92 async fn serve(&self, req: Request) -> Result<Self::Response, Self::Error> {
93 let proxy = req.extensions().get::<ProxyAddress>();
94
95 match proxy {
96 None => {
97 if self.required {
98 return Err("proxy required but none is defined".into());
99 }
100 tracing::trace!("no proxy detected in ctx, using inner connector");
101 let EstablishedClientConnection { req, conn } =
102 self.inner.connect(req).await.map_err(Into::into)?;
103
104 let conn = MaybeProxiedConnection::direct(conn);
105 Ok(EstablishedClientConnection { req, conn })
106 }
107 Some(proxy) => {
108 let protocol = proxy.protocol.as_ref();
109 tracing::trace!(?protocol, "proxy detected in ctx");
110
111 let protocol = protocol.unwrap_or_else(|| {
112 tracing::trace!("no protocol detected, using http as protocol");
113 &Protocol::HTTP
114 });
115
116 if protocol.is_socks5() {
117 tracing::trace!("using socks proxy connector");
118 let EstablishedClientConnection { req, conn } = self.socks.connect(req).await?;
119
120 let conn = MaybeProxiedConnection::socks(conn);
121 Ok(EstablishedClientConnection { req, conn })
122 } else if protocol.is_http() {
123 tracing::trace!("using http proxy connector");
124 let EstablishedClientConnection { req, conn } = self.http.connect(req).await?;
125
126 let conn = MaybeProxiedConnection::http(conn);
127 Ok(EstablishedClientConnection { req, conn })
128 } else {
129 Err(OpaqueError::from_display(format!(
130 "received unsupport proxy protocol {protocol:?}"
131 ))
132 .into_boxed())
133 }
134 }
135 }
136 }
137}
138
139pin_project! {
140 pub struct MaybeProxiedConnection<S> {
142 #[pin]
143 inner: Connection<S>,
144 }
145}
146
147impl<S: ExtensionsMut> MaybeProxiedConnection<S> {
148 pub fn direct(conn: S) -> Self {
149 Self {
150 inner: Connection::Direct { conn },
151 }
152 }
153
154 pub fn socks(conn: S) -> Self {
155 Self {
156 inner: Connection::Socks { conn },
157 }
158 }
159
160 pub fn http(conn: MaybeHttpProxiedConnection<S>) -> Self {
161 Self {
162 inner: Connection::Http { conn },
163 }
164 }
165}
166
167impl<S: Debug> Debug for MaybeProxiedConnection<S> {
168 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
169 f.debug_struct("MaybeProxiedConnection")
170 .field("inner", &self.inner)
171 .finish()
172 }
173}
174
175impl<S: ExtensionsRef> ExtensionsRef for MaybeProxiedConnection<S> {
176 fn extensions(&self) -> &Extensions {
177 match &self.inner {
178 Connection::Direct { conn } | Connection::Socks { conn } => conn.extensions(),
179 Connection::Http { conn } => conn.extensions(),
180 }
181 }
182}
183
184impl<S: ExtensionsMut> ExtensionsMut for MaybeProxiedConnection<S> {
185 fn extensions_mut(&mut self) -> &mut Extensions {
186 match &mut self.inner {
187 Connection::Direct { conn } | Connection::Socks { conn } => conn.extensions_mut(),
188 Connection::Http { conn } => conn.extensions_mut(),
189 }
190 }
191}
192
193pin_project! {
194 #[project = ConnectionProj]
195 enum Connection<S> {
196 Direct{ #[pin] conn: S },
197 Socks{ #[pin] conn: S },
198 Http{ #[pin] conn: MaybeHttpProxiedConnection<S> },
199
200 }
201}
202
203impl<S: Debug> Debug for Connection<S> {
204 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
205 match self {
206 Self::Direct { conn } => f.debug_struct("Direct").field("conn", conn).finish(),
207 Self::Socks { conn } => f.debug_struct("Socks").field("conn", conn).finish(),
208 Self::Http { conn } => f.debug_struct("Http").field("conn", conn).finish(),
209 }
210 }
211}
212
213#[warn(clippy::missing_trait_methods)]
214impl<Conn: AsyncWrite> AsyncWrite for MaybeProxiedConnection<Conn> {
215 fn poll_write(
216 self: Pin<&mut Self>,
217 cx: &mut task::Context<'_>,
218 buf: &[u8],
219 ) -> Poll<Result<usize, std::io::Error>> {
220 match self.project().inner.project() {
221 ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => {
222 conn.poll_write(cx, buf)
223 }
224 ConnectionProj::Http { conn } => conn.poll_write(cx, buf),
225 }
226 }
227
228 fn poll_flush(
229 self: Pin<&mut Self>,
230 cx: &mut task::Context<'_>,
231 ) -> Poll<Result<(), std::io::Error>> {
232 match self.project().inner.project() {
233 ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => conn.poll_flush(cx),
234 ConnectionProj::Http { conn } => conn.poll_flush(cx),
235 }
236 }
237
238 fn poll_shutdown(
239 self: Pin<&mut Self>,
240 cx: &mut task::Context<'_>,
241 ) -> Poll<Result<(), std::io::Error>> {
242 match self.project().inner.project() {
243 ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => {
244 conn.poll_shutdown(cx)
245 }
246 ConnectionProj::Http { conn } => conn.poll_shutdown(cx),
247 }
248 }
249
250 fn is_write_vectored(&self) -> bool {
251 match &self.inner {
252 Connection::Direct { conn } | Connection::Socks { conn } => conn.is_write_vectored(),
253 Connection::Http { conn } => conn.is_write_vectored(),
254 }
255 }
256
257 fn poll_write_vectored(
258 self: Pin<&mut Self>,
259 cx: &mut task::Context<'_>,
260 bufs: &[std::io::IoSlice<'_>],
261 ) -> Poll<Result<usize, std::io::Error>> {
262 match self.project().inner.project() {
263 ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => {
264 conn.poll_write_vectored(cx, bufs)
265 }
266 ConnectionProj::Http { conn } => conn.poll_write_vectored(cx, bufs),
267 }
268 }
269}
270
271#[warn(clippy::missing_trait_methods)]
272impl<Conn: AsyncRead> AsyncRead for MaybeProxiedConnection<Conn> {
273 fn poll_read(
274 self: Pin<&mut Self>,
275 cx: &mut task::Context<'_>,
276 buf: &mut tokio::io::ReadBuf<'_>,
277 ) -> Poll<std::io::Result<()>> {
278 match self.project().inner.project() {
279 ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => {
280 conn.poll_read(cx, buf)
281 }
282 ConnectionProj::Http { conn } => conn.poll_read(cx, buf),
283 }
284 }
285}
286
287pub struct ProxyConnectorLayer {
292 socks_layer: Socks5ProxyConnectorLayer,
293 http_layer: HttpProxyConnectorLayer,
294 required: bool,
295}
296
297impl ProxyConnectorLayer {
298 #[must_use]
299 pub fn required(
303 socks_proxy_layer: Socks5ProxyConnectorLayer,
304 http_proxy_layer: HttpProxyConnectorLayer,
305 ) -> Self {
306 Self {
307 socks_layer: socks_proxy_layer,
308 http_layer: http_proxy_layer,
309 required: true,
310 }
311 }
312
313 #[must_use]
314 pub fn optional(
318 socks_proxy_layer: Socks5ProxyConnectorLayer,
319 http_proxy_layer: HttpProxyConnectorLayer,
320 ) -> Self {
321 Self {
322 socks_layer: socks_proxy_layer,
323 http_layer: http_proxy_layer,
324 required: false,
325 }
326 }
327}
328
329impl<S> Layer<S> for ProxyConnectorLayer {
330 type Service = ProxyConnector<Arc<S>>;
331
332 fn layer(&self, inner: S) -> Self::Service {
333 ProxyConnector::new(
334 inner,
335 self.socks_layer.clone(),
336 self.http_layer.clone(),
337 self.required,
338 )
339 }
340
341 fn into_layer(self, inner: S) -> Self::Service {
342 ProxyConnector::new(inner, self.socks_layer, self.http_layer, self.required)
343 }
344}