1use crate::{
2 Layer, Service,
3 error::{BoxError, BoxErrorExt, ErrorContext as _, ErrorExt as _},
4 extensions::{Extensions, ExtensionsRef},
5 http::client::proxy::layer::{
6 HttpProxyConnector, HttpProxyConnectorLayer, MaybeHttpProxiedConnection,
7 },
8 io::Io,
9 net::{
10 Protocol,
11 address::ProxyAddress,
12 client::{ConnectorService, EstablishedClientConnection},
13 transport::TryRefIntoTransportContext,
14 },
15 proxy::socks5::{Socks5ProxyConnector, Socks5ProxyConnectorLayer},
16 telemetry::tracing,
17};
18use pin_project_lite::pin_project;
19use std::{
20 fmt::Debug,
21 pin::Pin,
22 task::{self, Poll},
23};
24use tokio::io::{AsyncRead, AsyncWrite};
25
/// Connector that routes a connection through a SOCKS5 or HTTP proxy when a
/// [`ProxyAddress`] is present in the input's extensions, and otherwise
/// connects directly using the inner connector.
#[derive(Debug, Clone)]
pub struct ProxyConnector<S> {
    /// Connector used when no proxy is defined (direct connection).
    inner: S,
    /// Connector used when the proxy protocol reports `is_socks5()`.
    socks: Socks5ProxyConnector<S>,
    /// Connector used when the proxy protocol reports `is_http()`, and also
    /// when the proxy address carries no protocol at all.
    http: HttpProxyConnector<S>,
    /// When `true`, a missing proxy is an error instead of a direct connection.
    required: bool,
}
37
38impl<S: Clone> ProxyConnector<S> {
39 fn new(
41 inner: S,
42 socks_proxy_layer: Socks5ProxyConnectorLayer,
43 http_proxy_layer: HttpProxyConnectorLayer,
44 required: bool,
45 ) -> Self {
46 Self {
47 socks: socks_proxy_layer.into_layer(inner.clone()),
48 http: http_proxy_layer.into_layer(inner.clone()),
49 inner,
50 required,
51 }
52 }
53
54 #[inline]
55 pub fn required(
59 inner: S,
60 socks_proxy_layer: Socks5ProxyConnectorLayer,
61 http_proxy_layer: HttpProxyConnectorLayer,
62 ) -> Self {
63 Self::new(inner, socks_proxy_layer, http_proxy_layer, true)
64 }
65
66 #[inline]
67 pub fn optional(
71 inner: S,
72 socks_proxy_layer: Socks5ProxyConnectorLayer,
73 http_proxy_layer: HttpProxyConnectorLayer,
74 ) -> Self {
75 Self::new(inner, socks_proxy_layer, http_proxy_layer, false)
76 }
77}
78
79impl<Input, S> Service<Input> for ProxyConnector<S>
80where
81 S: ConnectorService<Input, Connection: Io + Unpin>,
82 Input: TryRefIntoTransportContext<Error: Into<BoxError> + Send + 'static>
83 + Send
84 + ExtensionsRef
85 + 'static,
86{
87 type Output = EstablishedClientConnection<MaybeProxiedConnection<S::Connection>, Input>;
88 type Error = BoxError;
89
90 async fn serve(&self, input: Input) -> Result<Self::Output, Self::Error> {
91 let proxy = input.extensions().get_ref::<ProxyAddress>();
92
93 match proxy {
94 None => {
95 if self.required {
96 return Err(BoxError::from_static_str(
97 "proxy required but none is defined",
98 ));
99 }
100 tracing::trace!("no proxy detected in ctx, using inner connector");
101 let EstablishedClientConnection { input, conn } =
102 self.inner.connect(input).await.into_box_error()?;
103
104 let conn = MaybeProxiedConnection::direct(conn);
105 Ok(EstablishedClientConnection { input, conn })
106 }
107 Some(proxy) => {
108 let protocol = proxy.protocol.as_ref();
109 tracing::trace!(?protocol, "proxy detected in ctx");
110
111 let protocol = protocol.unwrap_or_else(|| {
112 tracing::trace!("no protocol detected, using http as protocol");
113 &Protocol::HTTP
114 });
115
116 if protocol.is_socks5() {
117 tracing::trace!("using socks proxy connector");
118 let EstablishedClientConnection { input, conn } =
119 self.socks.connect(input).await?;
120
121 let conn = MaybeProxiedConnection::socks(conn);
122 Ok(EstablishedClientConnection { input, conn })
123 } else if protocol.is_http() {
124 tracing::trace!("using http proxy connector");
125 let EstablishedClientConnection { input, conn } =
126 self.http.connect(input).await?;
127
128 let conn = MaybeProxiedConnection::http(conn);
129 Ok(EstablishedClientConnection { input, conn })
130 } else {
131 Err(
132 BoxError::from_static_str("received unsupport proxy protocol")
133 .with_context_debug_field("protocol", || protocol.clone()),
134 )
135 }
136 }
137 }
138 }
139}
140
pin_project! {
    /// A connection established by [`ProxyConnector`], either direct, via a
    /// SOCKS5 proxy, or via an HTTP proxy.
    ///
    /// Implements `AsyncRead`, `AsyncWrite` and `ExtensionsRef` by forwarding
    /// to whichever connection it wraps.
    pub struct MaybeProxiedConnection<S> {
        // The wrapped connection; pinned so poll methods can be forwarded to it.
        #[pin]
        inner: Connection<S>,
    }
}
148
149impl<S: ExtensionsRef> MaybeProxiedConnection<S> {
150 pub fn direct(conn: S) -> Self {
151 Self {
152 inner: Connection::Direct { conn },
153 }
154 }
155
156 pub fn socks(conn: S) -> Self {
157 Self {
158 inner: Connection::Socks { conn },
159 }
160 }
161
162 pub fn http(conn: MaybeHttpProxiedConnection<S>) -> Self {
163 Self {
164 inner: Connection::Http { conn },
165 }
166 }
167}
168
169impl<S: Debug> Debug for MaybeProxiedConnection<S> {
170 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
171 f.debug_struct("MaybeProxiedConnection")
172 .field("inner", &self.inner)
173 .finish()
174 }
175}
176
177impl<S: ExtensionsRef> ExtensionsRef for MaybeProxiedConnection<S> {
178 fn extensions(&self) -> &Extensions {
179 match &self.inner {
180 Connection::Direct { conn } | Connection::Socks { conn } => conn.extensions(),
181 Connection::Http { conn } => conn.extensions(),
182 }
183 }
184}
185
pin_project! {
    // The concrete kind of connection held by `MaybeProxiedConnection`.
    // `ConnectionProj` is the pinned projection used to forward poll calls.
    #[project = ConnectionProj]
    enum Connection<S> {
        // Connection made by the inner connector, no proxy involved.
        Direct{ #[pin] conn: S },
        // Connection made by the SOCKS5 proxy connector.
        Socks{ #[pin] conn: S },
        // Connection made by the HTTP proxy connector.
        Http{ #[pin] conn: MaybeHttpProxiedConnection<S> },

    }
}
195
196impl<S: Debug> Debug for Connection<S> {
197 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
198 match self {
199 Self::Direct { conn } => f.debug_struct("Direct").field("conn", conn).finish(),
200 Self::Socks { conn } => f.debug_struct("Socks").field("conn", conn).finish(),
201 Self::Http { conn } => f.debug_struct("Http").field("conn", conn).finish(),
202 }
203 }
204}
205
206#[warn(clippy::missing_trait_methods)]
207impl<Conn: AsyncWrite> AsyncWrite for MaybeProxiedConnection<Conn> {
208 fn poll_write(
209 self: Pin<&mut Self>,
210 cx: &mut task::Context<'_>,
211 buf: &[u8],
212 ) -> Poll<Result<usize, std::io::Error>> {
213 match self.project().inner.project() {
214 ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => {
215 conn.poll_write(cx, buf)
216 }
217 ConnectionProj::Http { conn } => conn.poll_write(cx, buf),
218 }
219 }
220
221 fn poll_flush(
222 self: Pin<&mut Self>,
223 cx: &mut task::Context<'_>,
224 ) -> Poll<Result<(), std::io::Error>> {
225 match self.project().inner.project() {
226 ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => conn.poll_flush(cx),
227 ConnectionProj::Http { conn } => conn.poll_flush(cx),
228 }
229 }
230
231 fn poll_shutdown(
232 self: Pin<&mut Self>,
233 cx: &mut task::Context<'_>,
234 ) -> Poll<Result<(), std::io::Error>> {
235 match self.project().inner.project() {
236 ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => {
237 conn.poll_shutdown(cx)
238 }
239 ConnectionProj::Http { conn } => conn.poll_shutdown(cx),
240 }
241 }
242
243 fn is_write_vectored(&self) -> bool {
244 match &self.inner {
245 Connection::Direct { conn } | Connection::Socks { conn } => conn.is_write_vectored(),
246 Connection::Http { conn } => conn.is_write_vectored(),
247 }
248 }
249
250 fn poll_write_vectored(
251 self: Pin<&mut Self>,
252 cx: &mut task::Context<'_>,
253 bufs: &[std::io::IoSlice<'_>],
254 ) -> Poll<Result<usize, std::io::Error>> {
255 match self.project().inner.project() {
256 ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => {
257 conn.poll_write_vectored(cx, bufs)
258 }
259 ConnectionProj::Http { conn } => conn.poll_write_vectored(cx, bufs),
260 }
261 }
262}
263
264#[warn(clippy::missing_trait_methods)]
265impl<Conn: AsyncRead> AsyncRead for MaybeProxiedConnection<Conn> {
266 fn poll_read(
267 self: Pin<&mut Self>,
268 cx: &mut task::Context<'_>,
269 buf: &mut tokio::io::ReadBuf<'_>,
270 ) -> Poll<std::io::Result<()>> {
271 match self.project().inner.project() {
272 ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => {
273 conn.poll_read(cx, buf)
274 }
275 ConnectionProj::Http { conn } => conn.poll_read(cx, buf),
276 }
277 }
278}
279
280pub struct ProxyConnectorLayer {
285 socks_layer: Socks5ProxyConnectorLayer,
286 http_layer: HttpProxyConnectorLayer,
287 required: bool,
288}
289
290impl ProxyConnectorLayer {
291 #[must_use]
292 pub fn required(
296 socks_proxy_layer: Socks5ProxyConnectorLayer,
297 http_proxy_layer: HttpProxyConnectorLayer,
298 ) -> Self {
299 Self {
300 socks_layer: socks_proxy_layer,
301 http_layer: http_proxy_layer,
302 required: true,
303 }
304 }
305
306 #[must_use]
307 pub fn optional(
311 socks_proxy_layer: Socks5ProxyConnectorLayer,
312 http_proxy_layer: HttpProxyConnectorLayer,
313 ) -> Self {
314 Self {
315 socks_layer: socks_proxy_layer,
316 http_layer: http_proxy_layer,
317 required: false,
318 }
319 }
320}
321
322impl<S: Clone> Layer<S> for ProxyConnectorLayer {
323 type Service = ProxyConnector<S>;
324
325 fn layer(&self, inner: S) -> Self::Service {
326 ProxyConnector::new(
327 inner,
328 self.socks_layer.clone(),
329 self.http_layer.clone(),
330 self.required,
331 )
332 }
333
334 fn into_layer(self, inner: S) -> Self::Service {
335 ProxyConnector::new(inner, self.socks_layer, self.http_layer, self.required)
336 }
337}