1use crate::{
6 Context, Layer, Service,
7 cli::ForwardKind,
8 combinators::Either7,
9 error::{BoxError, OpaqueError},
10 http::{
11 Request, Response, StatusCode,
12 headers::forwarded::{CFConnectingIp, ClientIp, TrueClientIp, XClientIp, XRealIp},
13 layer::{
14 forwarded::GetForwardedHeaderLayer, required_header::AddRequiredResponseHeadersLayer,
15 trace::TraceLayer, ua::UserAgentClassifierLayer,
16 },
17 server::HttpServer,
18 },
19 layer::{ConsumeErrLayer, LimitLayer, TimeoutLayer, limit::policy::ConcurrentPolicy},
20 net::forwarded::Forwarded,
21 net::stream::{SocketInfo, Stream, layer::http::BodyLimitLayer},
22 proxy::haproxy::server::HaProxyLayer,
23 rt::Executor,
24 telemetry::tracing,
25};
26use rama_http::service::web::response::IntoResponse;
27use std::{convert::Infallible, marker::PhantomData, time::Duration};
28use tokio::{io::AsyncWriteExt, net::TcpStream};
29
30#[derive(Debug, Clone)]
31pub struct IpServiceBuilder<M> {
34 concurrent_limit: usize,
35 timeout: Duration,
36 forward: Option<ForwardKind>,
37 _mode: PhantomData<fn(M)>,
38}
39
40impl Default for IpServiceBuilder<mode::Http> {
41 fn default() -> Self {
42 Self {
43 concurrent_limit: 0,
44 timeout: Duration::ZERO,
45 forward: None,
46 _mode: PhantomData,
47 }
48 }
49}
50
51impl IpServiceBuilder<mode::Http> {
52 #[must_use]
54 pub fn http() -> Self {
55 Self::default()
56 }
57}
58
59impl IpServiceBuilder<mode::Transport> {
60 #[must_use]
62 pub fn tcp() -> Self {
63 Self {
64 concurrent_limit: 0,
65 timeout: Duration::ZERO,
66 forward: None,
67 _mode: PhantomData,
68 }
69 }
70}
71
72impl<M> IpServiceBuilder<M> {
73 #[must_use]
77 pub fn concurrent(mut self, limit: usize) -> Self {
78 self.concurrent_limit = limit;
79 self
80 }
81
82 pub fn set_concurrent(&mut self, limit: usize) -> &mut Self {
86 self.concurrent_limit = limit;
87 self
88 }
89
90 #[must_use]
94 pub fn timeout(mut self, timeout: Duration) -> Self {
95 self.timeout = timeout;
96 self
97 }
98
99 pub fn set_timeout(&mut self, timeout: Duration) -> &mut Self {
103 self.timeout = timeout;
104 self
105 }
106
107 #[must_use]
119 pub fn forward(self, kind: ForwardKind) -> Self {
120 self.maybe_forward(Some(kind))
121 }
122
123 pub fn set_forward(&mut self, kind: ForwardKind) -> &mut Self {
127 self.forward = Some(kind);
128 self
129 }
130
131 #[must_use]
135 pub fn maybe_forward(mut self, maybe_kind: Option<ForwardKind>) -> Self {
136 self.forward = maybe_kind;
137 self
138 }
139}
140
141impl IpServiceBuilder<mode::Http> {
142 pub fn build(
144 self,
145 executor: Executor,
146 ) -> Result<impl Service<TcpStream, Response = (), Error = Infallible>, BoxError> {
147 let (tcp_forwarded_layer, http_forwarded_layer) = match &self.forward {
148 None => (None, None),
149 Some(ForwardKind::Forwarded) => {
150 (None, Some(Either7::A(GetForwardedHeaderLayer::forwarded())))
151 }
152 Some(ForwardKind::XForwardedFor) => (
153 None,
154 Some(Either7::B(GetForwardedHeaderLayer::x_forwarded_for())),
155 ),
156 Some(ForwardKind::XClientIp) => (
157 None,
158 Some(Either7::C(GetForwardedHeaderLayer::<XClientIp>::new())),
159 ),
160 Some(ForwardKind::ClientIp) => (
161 None,
162 Some(Either7::D(GetForwardedHeaderLayer::<ClientIp>::new())),
163 ),
164 Some(ForwardKind::XRealIp) => (
165 None,
166 Some(Either7::E(GetForwardedHeaderLayer::<XRealIp>::new())),
167 ),
168 Some(ForwardKind::CFConnectingIp) => (
169 None,
170 Some(Either7::F(GetForwardedHeaderLayer::<CFConnectingIp>::new())),
171 ),
172 Some(ForwardKind::TrueClientIp) => (
173 None,
174 Some(Either7::G(GetForwardedHeaderLayer::<TrueClientIp>::new())),
175 ),
176 Some(ForwardKind::HaProxy) => (Some(HaProxyLayer::default()), None),
177 };
178
179 let tcp_service_builder = (
180 ConsumeErrLayer::trace(tracing::Level::DEBUG),
181 (self.concurrent_limit > 0)
182 .then(|| LimitLayer::new(ConcurrentPolicy::max(self.concurrent_limit))),
183 (!self.timeout.is_zero()).then(|| TimeoutLayer::new(self.timeout)),
184 tcp_forwarded_layer,
185 BodyLimitLayer::request_only(1024 * 1024),
187 );
188
189 let http_service = (
192 TraceLayer::new_for_http(),
193 AddRequiredResponseHeadersLayer::default(),
194 UserAgentClassifierLayer::new(),
195 ConsumeErrLayer::default(),
196 http_forwarded_layer,
197 )
198 .into_layer(HttpEchoService);
199
200 Ok(tcp_service_builder.into_layer(HttpServer::auto(executor).service(http_service)))
201 }
202}
203
204#[derive(Debug, Clone)]
205#[non_exhaustive]
206pub struct HttpEchoService;
208
209impl Service<Request> for HttpEchoService {
210 type Response = Response;
211 type Error = BoxError;
212
213 async fn serve(&self, ctx: Context, _req: Request) -> Result<Self::Response, Self::Error> {
214 let peer_ip = ctx
215 .get::<Forwarded>()
216 .and_then(|f| f.client_ip())
217 .or_else(|| ctx.get::<SocketInfo>().map(|s| s.peer_addr().ip()));
218
219 Ok(match peer_ip {
220 Some(ip) => ip.to_string().into_response(),
221 None => StatusCode::INTERNAL_SERVER_ERROR.into_response(),
222 })
223 }
224}
225
226#[derive(Debug, Clone)]
227#[non_exhaustive]
228pub struct TcpEchoService;
230
231impl<Input> Service<Input> for TcpEchoService
232where
233 Input: Stream + Unpin,
234{
235 type Response = ();
236 type Error = BoxError;
237
238 async fn serve(&self, ctx: Context, stream: Input) -> Result<Self::Response, Self::Error> {
239 let peer_ip = ctx
240 .get::<Forwarded>()
241 .and_then(|f| f.client_ip())
242 .or_else(|| ctx.get::<SocketInfo>().map(|s| s.peer_addr().ip()));
243 let Some(peer_ip) = peer_ip else {
244 tracing::error!("missing peer information");
245 return Ok(());
246 };
247
248 let mut stream = std::pin::pin!(stream);
249
250 match peer_ip {
251 std::net::IpAddr::V4(ip) => {
252 if let Err(err) = stream.write_all(&ip.octets()).await {
253 tracing::error!("error writing IPv4 of peer to peer: {}", err);
254 }
255 }
256 std::net::IpAddr::V6(ip) => {
257 if let Err(err) = stream.write_all(&ip.octets()).await {
258 tracing::error!("error writing IPv6 of peer to peer: {}", err);
259 }
260 }
261 };
262
263 Ok(())
264 }
265}
266
267impl IpServiceBuilder<mode::Transport> {
268 pub fn build(
270 self,
271 ) -> Result<impl Service<TcpStream, Response = (), Error = Infallible>, BoxError> {
272 let tcp_forwarded_layer = match &self.forward {
273 None => None,
274 Some(ForwardKind::HaProxy) => Some(HaProxyLayer::default()),
275 Some(other) => {
276 return Err(OpaqueError::from_display(format!(
277 "invalid forward kind for Transport mode: {other:?}"
278 ))
279 .into());
280 }
281 };
282
283 let tcp_service_builder = (
284 ConsumeErrLayer::trace(tracing::Level::DEBUG),
285 (self.concurrent_limit > 0)
286 .then(|| LimitLayer::new(ConcurrentPolicy::max(self.concurrent_limit))),
287 (!self.timeout.is_zero()).then(|| TimeoutLayer::new(self.timeout)),
288 tcp_forwarded_layer,
289 );
290
291 Ok(tcp_service_builder.into_layer(TcpEchoService))
292 }
293}
294
/// Marker types selecting the flavour of `IpServiceBuilder`.
///
/// They carry no data. Besides `Debug` and `Clone`, they derive the other common traits
/// (`Copy`, `PartialEq`, `Eq`, `Hash`) so they can be freely copied, compared and hashed.
pub mod mode {
    /// HTTP mode: the client IP is echoed back as an HTTP response body.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
    #[non_exhaustive]
    pub struct Http;

    /// Transport mode: the client IP is written back as raw octets on the TCP stream.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
    #[non_exhaustive]
    pub struct Transport;
}