1use crate::{
2 Layer, Service,
3 error::{BoxError, BoxErrorExt, ErrorContext as _, ErrorExt as _},
4 extensions::{Extensions, ExtensionsRef},
5 http::client::proxy::layer::{
6 HttpProxyConnector, HttpProxyConnectorLayer, MaybeHttpProxiedConnection,
7 },
8 io::Io,
9 net::{
10 AuthorityInputExt, Protocol, ProtocolInputExt,
11 address::ProxyAddress,
12 client::{ConnectorService, EstablishedClientConnection},
13 },
14 proxy::socks5::{Socks5ProxyConnector, Socks5ProxyConnectorLayer},
15 telemetry::tracing,
16};
17use pin_project_lite::pin_project;
18use std::{
19 fmt::Debug,
20 pin::Pin,
21 task::{self, Poll},
22};
23use tokio::io::{AsyncRead, AsyncWrite};
24
25#[derive(Debug, Clone)]
30pub struct ProxyConnector<S> {
31 inner: S,
32 socks: Socks5ProxyConnector<S>,
33 http: HttpProxyConnector<S>,
34 required: bool,
35}
36
37impl<S: Clone> ProxyConnector<S> {
38 fn new(
40 inner: S,
41 socks_proxy_layer: Socks5ProxyConnectorLayer,
42 http_proxy_layer: HttpProxyConnectorLayer,
43 required: bool,
44 ) -> Self {
45 Self {
46 socks: socks_proxy_layer.into_layer(inner.clone()),
47 http: http_proxy_layer.into_layer(inner.clone()),
48 inner,
49 required,
50 }
51 }
52
53 #[inline]
54 pub fn required(
58 inner: S,
59 socks_proxy_layer: Socks5ProxyConnectorLayer,
60 http_proxy_layer: HttpProxyConnectorLayer,
61 ) -> Self {
62 Self::new(inner, socks_proxy_layer, http_proxy_layer, true)
63 }
64
65 #[inline]
66 pub fn optional(
70 inner: S,
71 socks_proxy_layer: Socks5ProxyConnectorLayer,
72 http_proxy_layer: HttpProxyConnectorLayer,
73 ) -> Self {
74 Self::new(inner, socks_proxy_layer, http_proxy_layer, false)
75 }
76}
77
78impl<Input, S> Service<Input> for ProxyConnector<S>
79where
80 S: ConnectorService<Input, Connection: Io + Unpin>,
81 Input: AuthorityInputExt + ProtocolInputExt + Send + ExtensionsRef + 'static,
82{
83 type Output = EstablishedClientConnection<MaybeProxiedConnection<S::Connection>, Input>;
84 type Error = BoxError;
85
86 async fn serve(&self, input: Input) -> Result<Self::Output, Self::Error> {
87 let proxy = input.extensions().get_ref::<ProxyAddress>();
88
89 match proxy {
90 None => {
91 if self.required {
92 return Err(BoxError::from_static_str(
93 "proxy required but none is defined",
94 ));
95 }
96 tracing::trace!("no proxy detected in ctx, using inner connector");
97 let EstablishedClientConnection { input, conn } =
98 self.inner.connect(input).await.into_box_error()?;
99
100 let conn = MaybeProxiedConnection::direct(conn);
101 Ok(EstablishedClientConnection { input, conn })
102 }
103 Some(proxy) => {
104 let protocol = proxy.protocol.as_ref();
105 tracing::trace!(?protocol, "proxy detected in ctx");
106
107 let protocol = protocol.unwrap_or_else(|| {
108 tracing::trace!("no protocol detected, using http as protocol");
109 &Protocol::HTTP
110 });
111
112 if protocol.is_socks5() {
113 tracing::trace!("using socks proxy connector");
114 let EstablishedClientConnection { input, conn } =
115 self.socks.connect(input).await?;
116
117 let conn = MaybeProxiedConnection::socks(conn);
118 Ok(EstablishedClientConnection { input, conn })
119 } else if protocol.is_http() {
120 tracing::trace!("using http proxy connector");
121 let EstablishedClientConnection { input, conn } =
122 self.http.connect(input).await?;
123
124 let conn = MaybeProxiedConnection::http(conn);
125 Ok(EstablishedClientConnection { input, conn })
126 } else {
127 Err(
128 BoxError::from_static_str("received unsupport proxy protocol")
129 .with_context_debug_field("protocol", || protocol.clone()),
130 )
131 }
132 }
133 }
134 }
135}
136
pin_project! {
    /// Connection produced by [`ProxyConnector`].
    ///
    /// It is either a direct connection from the inner connector, or one
    /// established through the socks5 or http proxy connector.
    pub struct MaybeProxiedConnection<S> {
        // The tagged underlying connection. It is structurally pinned so that
        // the `AsyncRead`/`AsyncWrite` impls can project and forward polls.
        #[pin]
        inner: Connection<S>,
    }
}
144
145impl<S: ExtensionsRef> MaybeProxiedConnection<S> {
146 pub fn direct(conn: S) -> Self {
147 Self {
148 inner: Connection::Direct { conn },
149 }
150 }
151
152 pub fn socks(conn: S) -> Self {
153 Self {
154 inner: Connection::Socks { conn },
155 }
156 }
157
158 pub fn http(conn: MaybeHttpProxiedConnection<S>) -> Self {
159 Self {
160 inner: Connection::Http { conn },
161 }
162 }
163}
164
165impl<S: Debug> Debug for MaybeProxiedConnection<S> {
166 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
167 f.debug_struct("MaybeProxiedConnection")
168 .field("inner", &self.inner)
169 .finish()
170 }
171}
172
173impl<S: ExtensionsRef> ExtensionsRef for MaybeProxiedConnection<S> {
174 fn extensions(&self) -> &Extensions {
175 match &self.inner {
176 Connection::Direct { conn } | Connection::Socks { conn } => conn.extensions(),
177 Connection::Http { conn } => conn.extensions(),
178 }
179 }
180}
181
pin_project! {
    // The concrete connection inside a `MaybeProxiedConnection`, tagged by how
    // it was established. `ConnectionProj` is the generated projection type
    // that yields a pinned reference to the `conn` of each variant.
    #[project = ConnectionProj]
    enum Connection<S> {
        // Established directly by the inner connector (no proxy).
        Direct{ #[pin] conn: S },
        // Established through the socks5 proxy connector.
        Socks{ #[pin] conn: S },
        // Established through the http proxy connector.
        Http{ #[pin] conn: MaybeHttpProxiedConnection<S> },

    }
}
191
192impl<S: Debug> Debug for Connection<S> {
193 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
194 match self {
195 Self::Direct { conn } => f.debug_struct("Direct").field("conn", conn).finish(),
196 Self::Socks { conn } => f.debug_struct("Socks").field("conn", conn).finish(),
197 Self::Http { conn } => f.debug_struct("Http").field("conn", conn).finish(),
198 }
199 }
200}
201
202#[warn(clippy::missing_trait_methods)]
203impl<Conn: AsyncWrite> AsyncWrite for MaybeProxiedConnection<Conn> {
204 fn poll_write(
205 self: Pin<&mut Self>,
206 cx: &mut task::Context<'_>,
207 buf: &[u8],
208 ) -> Poll<Result<usize, std::io::Error>> {
209 match self.project().inner.project() {
210 ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => {
211 conn.poll_write(cx, buf)
212 }
213 ConnectionProj::Http { conn } => conn.poll_write(cx, buf),
214 }
215 }
216
217 fn poll_flush(
218 self: Pin<&mut Self>,
219 cx: &mut task::Context<'_>,
220 ) -> Poll<Result<(), std::io::Error>> {
221 match self.project().inner.project() {
222 ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => conn.poll_flush(cx),
223 ConnectionProj::Http { conn } => conn.poll_flush(cx),
224 }
225 }
226
227 fn poll_shutdown(
228 self: Pin<&mut Self>,
229 cx: &mut task::Context<'_>,
230 ) -> Poll<Result<(), std::io::Error>> {
231 match self.project().inner.project() {
232 ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => {
233 conn.poll_shutdown(cx)
234 }
235 ConnectionProj::Http { conn } => conn.poll_shutdown(cx),
236 }
237 }
238
239 fn is_write_vectored(&self) -> bool {
240 match &self.inner {
241 Connection::Direct { conn } | Connection::Socks { conn } => conn.is_write_vectored(),
242 Connection::Http { conn } => conn.is_write_vectored(),
243 }
244 }
245
246 fn poll_write_vectored(
247 self: Pin<&mut Self>,
248 cx: &mut task::Context<'_>,
249 bufs: &[std::io::IoSlice<'_>],
250 ) -> Poll<Result<usize, std::io::Error>> {
251 match self.project().inner.project() {
252 ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => {
253 conn.poll_write_vectored(cx, bufs)
254 }
255 ConnectionProj::Http { conn } => conn.poll_write_vectored(cx, bufs),
256 }
257 }
258}
259
260#[warn(clippy::missing_trait_methods)]
261impl<Conn: AsyncRead> AsyncRead for MaybeProxiedConnection<Conn> {
262 fn poll_read(
263 self: Pin<&mut Self>,
264 cx: &mut task::Context<'_>,
265 buf: &mut tokio::io::ReadBuf<'_>,
266 ) -> Poll<std::io::Result<()>> {
267 match self.project().inner.project() {
268 ConnectionProj::Direct { conn } | ConnectionProj::Socks { conn } => {
269 conn.poll_read(cx, buf)
270 }
271 ConnectionProj::Http { conn } => conn.poll_read(cx, buf),
272 }
273 }
274}
275
276pub struct ProxyConnectorLayer {
281 socks_layer: Socks5ProxyConnectorLayer,
282 http_layer: HttpProxyConnectorLayer,
283 required: bool,
284}
285
286impl ProxyConnectorLayer {
287 #[must_use]
288 pub fn required(
292 socks_proxy_layer: Socks5ProxyConnectorLayer,
293 http_proxy_layer: HttpProxyConnectorLayer,
294 ) -> Self {
295 Self {
296 socks_layer: socks_proxy_layer,
297 http_layer: http_proxy_layer,
298 required: true,
299 }
300 }
301
302 #[must_use]
303 pub fn optional(
307 socks_proxy_layer: Socks5ProxyConnectorLayer,
308 http_proxy_layer: HttpProxyConnectorLayer,
309 ) -> Self {
310 Self {
311 socks_layer: socks_proxy_layer,
312 http_layer: http_proxy_layer,
313 required: false,
314 }
315 }
316}
317
318impl<S: Clone> Layer<S> for ProxyConnectorLayer {
319 type Service = ProxyConnector<S>;
320
321 fn layer(&self, inner: S) -> Self::Service {
322 ProxyConnector::new(
323 inner,
324 self.socks_layer.clone(),
325 self.http_layer.clone(),
326 self.required,
327 )
328 }
329
330 fn into_layer(self, inner: S) -> Self::Service {
331 ProxyConnector::new(inner, self.socks_layer, self.http_layer, self.required)
332 }
333}