1use crate::{
6 Context, Layer, Service,
7 cli::ForwardKind,
8 combinators::Either7,
9 error::{BoxError, OpaqueError},
10 http::{
11 IntoResponse, Request, Response, StatusCode,
12 headers::{CFConnectingIp, ClientIp, TrueClientIp, XClientIp, XRealIp},
13 layer::{
14 forwarded::GetForwardedHeadersLayer, required_header::AddRequiredResponseHeadersLayer,
15 trace::TraceLayer, ua::UserAgentClassifierLayer,
16 },
17 server::HttpServer,
18 },
19 layer::{ConsumeErrLayer, LimitLayer, TimeoutLayer, limit::policy::ConcurrentPolicy},
20 net::forwarded::Forwarded,
21 net::stream::{SocketInfo, Stream, layer::http::BodyLimitLayer},
22 proxy::haproxy::server::HaProxyLayer,
23 rt::Executor,
24};
25use std::{convert::Infallible, marker::PhantomData, time::Duration};
26use tokio::{io::AsyncWriteExt, net::TcpStream};
27
/// Builder for a service that echoes the IP address of the connecting client.
///
/// The `M` type parameter is one of the marker types from the `mode` module
/// (`mode::Http` or `mode::Transport`); it selects which `build` method is
/// available and therefore which flavour of the service gets built.
#[derive(Debug, Clone)]
pub struct IpServiceBuilder<M> {
    // Maximum number of concurrently served connections; `0` means no limit
    // is installed (see the `concurrent_limit > 0` check in `build`).
    concurrent_limit: usize,
    // Timeout applied via a `TimeoutLayer` on the transport stack;
    // `Duration::ZERO` means no timeout layer is installed.
    timeout: Duration,
    // How the original client IP is obtained when running behind a proxy;
    // `None` means the socket peer address is used as-is.
    forward: Option<ForwardKind>,
    // `fn(M)` never owns an `M`, so the builder is `Send`/`Sync` regardless
    // of the marker type.
    _mode: PhantomData<fn(M)>,
}
37
38impl Default for IpServiceBuilder<mode::Http> {
39 fn default() -> Self {
40 Self {
41 concurrent_limit: 0,
42 timeout: Duration::ZERO,
43 forward: None,
44 _mode: PhantomData,
45 }
46 }
47}
48
49impl IpServiceBuilder<mode::Http> {
50 pub fn http() -> Self {
52 Self::default()
53 }
54}
55
56impl IpServiceBuilder<mode::Transport> {
57 pub fn tcp() -> Self {
59 Self {
60 concurrent_limit: 0,
61 timeout: Duration::ZERO,
62 forward: None,
63 _mode: PhantomData,
64 }
65 }
66}
67
68impl<M> IpServiceBuilder<M> {
69 pub fn concurrent(mut self, limit: usize) -> Self {
73 self.concurrent_limit = limit;
74 self
75 }
76
77 pub fn set_concurrent(&mut self, limit: usize) -> &mut Self {
81 self.concurrent_limit = limit;
82 self
83 }
84
85 pub fn timeout(mut self, timeout: Duration) -> Self {
89 self.timeout = timeout;
90 self
91 }
92
93 pub fn set_timeout(&mut self, timeout: Duration) -> &mut Self {
97 self.timeout = timeout;
98 self
99 }
100
101 pub fn forward(self, kind: ForwardKind) -> Self {
113 self.maybe_forward(Some(kind))
114 }
115
116 pub fn set_forward(&mut self, kind: ForwardKind) -> &mut Self {
120 self.forward = Some(kind);
121 self
122 }
123
124 pub fn maybe_forward(mut self, maybe_kind: Option<ForwardKind>) -> Self {
128 self.forward = maybe_kind;
129 self
130 }
131}
132
133impl IpServiceBuilder<mode::Http> {
134 pub fn build(
136 self,
137 executor: Executor,
138 ) -> Result<impl Service<(), TcpStream, Response = (), Error = Infallible>, BoxError> {
139 let (tcp_forwarded_layer, http_forwarded_layer) = match &self.forward {
140 None => (None, None),
141 Some(ForwardKind::Forwarded) => (
142 None,
143 Some(Either7::A(GetForwardedHeadersLayer::forwarded())),
144 ),
145 Some(ForwardKind::XForwardedFor) => (
146 None,
147 Some(Either7::B(GetForwardedHeadersLayer::x_forwarded_for())),
148 ),
149 Some(ForwardKind::XClientIp) => (
150 None,
151 Some(Either7::C(GetForwardedHeadersLayer::<XClientIp>::new())),
152 ),
153 Some(ForwardKind::ClientIp) => (
154 None,
155 Some(Either7::D(GetForwardedHeadersLayer::<ClientIp>::new())),
156 ),
157 Some(ForwardKind::XRealIp) => (
158 None,
159 Some(Either7::E(GetForwardedHeadersLayer::<XRealIp>::new())),
160 ),
161 Some(ForwardKind::CFConnectingIp) => (
162 None,
163 Some(Either7::F(GetForwardedHeadersLayer::<CFConnectingIp>::new())),
164 ),
165 Some(ForwardKind::TrueClientIp) => (
166 None,
167 Some(Either7::G(GetForwardedHeadersLayer::<TrueClientIp>::new())),
168 ),
169 Some(ForwardKind::HaProxy) => (Some(HaProxyLayer::default()), None),
170 };
171
172 let tcp_service_builder = (
173 ConsumeErrLayer::trace(tracing::Level::DEBUG),
174 (self.concurrent_limit > 0)
175 .then(|| LimitLayer::new(ConcurrentPolicy::max(self.concurrent_limit))),
176 (!self.timeout.is_zero()).then(|| TimeoutLayer::new(self.timeout)),
177 tcp_forwarded_layer,
178 BodyLimitLayer::request_only(1024 * 1024),
180 );
181
182 let http_service = (
185 TraceLayer::new_for_http(),
186 AddRequiredResponseHeadersLayer::default(),
187 UserAgentClassifierLayer::new(),
188 ConsumeErrLayer::default(),
189 http_forwarded_layer,
190 )
191 .into_layer(HttpEchoService);
192
193 Ok(tcp_service_builder.into_layer(HttpServer::auto(executor).service(http_service)))
194 }
195}
196
/// HTTP service that answers every request with the client IP address as
/// text, or with `500 Internal Server Error` when no IP can be determined.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct HttpEchoService;
201
202impl Service<(), Request> for HttpEchoService {
203 type Response = Response;
204 type Error = BoxError;
205
206 async fn serve(&self, ctx: Context<()>, _req: Request) -> Result<Self::Response, Self::Error> {
207 let peer_ip = ctx
208 .get::<Forwarded>()
209 .and_then(|f| f.client_ip())
210 .or_else(|| ctx.get::<SocketInfo>().map(|s| s.peer_addr().ip()));
211
212 Ok(match peer_ip {
213 Some(ip) => ip.to_string().into_response(),
214 None => StatusCode::INTERNAL_SERVER_ERROR.into_response(),
215 })
216 }
217}
218
/// Transport-level service that writes the client IP address to the stream
/// as raw octets (4 bytes for IPv4, 16 bytes for IPv6) and then returns.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct TcpEchoService;
223
224impl<Input> Service<(), Input> for TcpEchoService
225where
226 Input: Stream + Unpin,
227{
228 type Response = ();
229 type Error = BoxError;
230
231 async fn serve(&self, ctx: Context<()>, stream: Input) -> Result<Self::Response, Self::Error> {
232 let peer_ip = ctx
233 .get::<Forwarded>()
234 .and_then(|f| f.client_ip())
235 .or_else(|| ctx.get::<SocketInfo>().map(|s| s.peer_addr().ip()));
236 let peer_ip = match peer_ip {
237 Some(peer_ip) => peer_ip,
238 None => {
239 tracing::error!("missing peer information");
240 return Ok(());
241 }
242 };
243
244 let mut stream = std::pin::pin!(stream);
245
246 match peer_ip {
247 std::net::IpAddr::V4(ip) => {
248 if let Err(err) = stream.write_all(&ip.octets()).await {
249 tracing::error!("error writing IPv4 of peer to peer: {}", err);
250 }
251 }
252 std::net::IpAddr::V6(ip) => {
253 if let Err(err) = stream.write_all(&ip.octets()).await {
254 tracing::error!("error writing IPv6 of peer to peer: {}", err);
255 }
256 }
257 };
258
259 Ok(())
260 }
261}
262
263impl IpServiceBuilder<mode::Transport> {
264 pub fn build(
266 self,
267 ) -> Result<impl Service<(), TcpStream, Response = (), Error = Infallible>, BoxError> {
268 let tcp_forwarded_layer = match &self.forward {
269 None => None,
270 Some(ForwardKind::HaProxy) => Some(HaProxyLayer::default()),
271 Some(other) => {
272 return Err(OpaqueError::from_display(format!(
273 "invalid forward kind for Transport mode: {other:?}"
274 ))
275 .into());
276 }
277 };
278
279 let tcp_service_builder = (
280 ConsumeErrLayer::trace(tracing::Level::DEBUG),
281 (self.concurrent_limit > 0)
282 .then(|| LimitLayer::new(ConcurrentPolicy::max(self.concurrent_limit))),
283 (!self.timeout.is_zero()).then(|| TimeoutLayer::new(self.timeout)),
284 tcp_forwarded_layer,
285 );
286
287 Ok(tcp_service_builder.into_layer(TcpEchoService))
288 }
289}
290
/// Marker types used as the `M` parameter of `IpServiceBuilder` to select
/// which flavour of the IP echo service gets built.
pub mod mode {
    /// Selects the HTTP flavour: the client IP is returned in an HTTP response.
    #[derive(Debug, Clone)]
    #[non_exhaustive]
    pub struct Http;

    /// Selects the transport flavour: the client IP is written as raw octets
    /// directly to the TCP stream.
    #[derive(Debug, Clone)]
    #[non_exhaustive]
    pub struct Transport;
}