rama/cli/service/
ip.rs

//! IP [`Service`] that echoes the client IP either over HTTP or directly over TCP.
//!
//! [`Service`]: crate::Service
4
5use crate::{
6    Context, Layer, Service,
7    cli::ForwardKind,
8    combinators::Either7,
9    error::{BoxError, OpaqueError},
10    http::{
11        Request, Response, StatusCode,
12        headers::forwarded::{CFConnectingIp, ClientIp, TrueClientIp, XClientIp, XRealIp},
13        layer::{
14            forwarded::GetForwardedHeaderLayer, required_header::AddRequiredResponseHeadersLayer,
15            trace::TraceLayer, ua::UserAgentClassifierLayer,
16        },
17        server::HttpServer,
18    },
19    layer::{ConsumeErrLayer, LimitLayer, TimeoutLayer, limit::policy::ConcurrentPolicy},
20    net::forwarded::Forwarded,
21    net::stream::{SocketInfo, Stream, layer::http::BodyLimitLayer},
22    proxy::haproxy::server::HaProxyLayer,
23    rt::Executor,
24    telemetry::tracing,
25};
26use rama_http::service::web::response::IntoResponse;
27use std::{convert::Infallible, marker::PhantomData, time::Duration};
28use tokio::{io::AsyncWriteExt, net::TcpStream};
29
30#[derive(Debug, Clone)]
31/// Builder that can be used to run your own ip [`Service`],
32/// echo'ing back the client IP over http or tcp.
33pub struct IpServiceBuilder<M> {
34    concurrent_limit: usize,
35    timeout: Duration,
36    forward: Option<ForwardKind>,
37    _mode: PhantomData<fn(M)>,
38}
39
40impl Default for IpServiceBuilder<mode::Http> {
41    fn default() -> Self {
42        Self {
43            concurrent_limit: 0,
44            timeout: Duration::ZERO,
45            forward: None,
46            _mode: PhantomData,
47        }
48    }
49}
50
51impl IpServiceBuilder<mode::Http> {
52    /// Create a new [`IpServiceBuilder`], echoing the IP back over HTTP.
53    #[must_use]
54    pub fn http() -> Self {
55        Self::default()
56    }
57}
58
59impl IpServiceBuilder<mode::Transport> {
60    /// Create a new [`IpServiceBuilder`], echoing the IP back over L4.
61    #[must_use]
62    pub fn tcp() -> Self {
63        Self {
64            concurrent_limit: 0,
65            timeout: Duration::ZERO,
66            forward: None,
67            _mode: PhantomData,
68        }
69    }
70}
71
72impl<M> IpServiceBuilder<M> {
73    /// set the number of concurrent connections to allow
74    ///
75    /// (0 = no limit)
76    #[must_use]
77    pub fn concurrent(mut self, limit: usize) -> Self {
78        self.concurrent_limit = limit;
79        self
80    }
81
82    /// set the number of concurrent connections to allow
83    ///
84    /// (0 = no limit)
85    pub fn set_concurrent(&mut self, limit: usize) -> &mut Self {
86        self.concurrent_limit = limit;
87        self
88    }
89
90    /// set the timeout in seconds for each connection
91    ///
92    /// (0 = no timeout)
93    #[must_use]
94    pub fn timeout(mut self, timeout: Duration) -> Self {
95        self.timeout = timeout;
96        self
97    }
98
99    /// set the timeout in seconds for each connection
100    ///
101    /// (0 = no timeout)
102    pub fn set_timeout(&mut self, timeout: Duration) -> &mut Self {
103        self.timeout = timeout;
104        self
105    }
106
107    /// enable support for one of the following "forward" headers or protocols
108    ///
109    /// Supported headers:
110    ///
111    /// Forwarded ("for="), X-Forwarded-For
112    ///
113    /// X-Client-IP Client-IP, X-Real-IP
114    ///
115    /// CF-Connecting-IP, True-Client-IP
116    ///
117    /// Or using HaProxy protocol.
118    #[must_use]
119    pub fn forward(self, kind: ForwardKind) -> Self {
120        self.maybe_forward(Some(kind))
121    }
122
123    /// enable support for one of the following "forward" headers or protocols
124    ///
125    /// Same as [`Self::forward`] but without consuming `self`.
126    pub fn set_forward(&mut self, kind: ForwardKind) -> &mut Self {
127        self.forward = Some(kind);
128        self
129    }
130
131    /// maybe enable support for one of the following "forward" headers or protocols.
132    ///
133    /// See [`Self::forward`] for more information.
134    #[must_use]
135    pub fn maybe_forward(mut self, maybe_kind: Option<ForwardKind>) -> Self {
136        self.forward = maybe_kind;
137        self
138    }
139}
140
141impl IpServiceBuilder<mode::Http> {
142    /// build a tcp service ready to echo http traffic back
143    pub fn build(
144        self,
145        executor: Executor,
146    ) -> Result<impl Service<TcpStream, Response = (), Error = Infallible>, BoxError> {
147        let (tcp_forwarded_layer, http_forwarded_layer) = match &self.forward {
148            None => (None, None),
149            Some(ForwardKind::Forwarded) => {
150                (None, Some(Either7::A(GetForwardedHeaderLayer::forwarded())))
151            }
152            Some(ForwardKind::XForwardedFor) => (
153                None,
154                Some(Either7::B(GetForwardedHeaderLayer::x_forwarded_for())),
155            ),
156            Some(ForwardKind::XClientIp) => (
157                None,
158                Some(Either7::C(GetForwardedHeaderLayer::<XClientIp>::new())),
159            ),
160            Some(ForwardKind::ClientIp) => (
161                None,
162                Some(Either7::D(GetForwardedHeaderLayer::<ClientIp>::new())),
163            ),
164            Some(ForwardKind::XRealIp) => (
165                None,
166                Some(Either7::E(GetForwardedHeaderLayer::<XRealIp>::new())),
167            ),
168            Some(ForwardKind::CFConnectingIp) => (
169                None,
170                Some(Either7::F(GetForwardedHeaderLayer::<CFConnectingIp>::new())),
171            ),
172            Some(ForwardKind::TrueClientIp) => (
173                None,
174                Some(Either7::G(GetForwardedHeaderLayer::<TrueClientIp>::new())),
175            ),
176            Some(ForwardKind::HaProxy) => (Some(HaProxyLayer::default()), None),
177        };
178
179        let tcp_service_builder = (
180            ConsumeErrLayer::trace(tracing::Level::DEBUG),
181            (self.concurrent_limit > 0)
182                .then(|| LimitLayer::new(ConcurrentPolicy::max(self.concurrent_limit))),
183            (!self.timeout.is_zero()).then(|| TimeoutLayer::new(self.timeout)),
184            tcp_forwarded_layer,
185            // Limit the body size to 1MB for requests
186            BodyLimitLayer::request_only(1024 * 1024),
187        );
188
189        // TODO: support opt-in TLS)
190
191        let http_service = (
192            TraceLayer::new_for_http(),
193            AddRequiredResponseHeadersLayer::default(),
194            UserAgentClassifierLayer::new(),
195            ConsumeErrLayer::default(),
196            http_forwarded_layer,
197        )
198            .into_layer(HttpEchoService);
199
200        Ok(tcp_service_builder.into_layer(HttpServer::auto(executor).service(http_service)))
201    }
202}
203
204#[derive(Debug, Clone)]
205#[non_exhaustive]
206/// The inner http echo-service used by the [`IpServiceBuilder`].
207pub struct HttpEchoService;
208
209impl Service<Request> for HttpEchoService {
210    type Response = Response;
211    type Error = BoxError;
212
213    async fn serve(&self, ctx: Context, _req: Request) -> Result<Self::Response, Self::Error> {
214        let peer_ip = ctx
215            .get::<Forwarded>()
216            .and_then(|f| f.client_ip())
217            .or_else(|| ctx.get::<SocketInfo>().map(|s| s.peer_addr().ip()));
218
219        Ok(match peer_ip {
220            Some(ip) => ip.to_string().into_response(),
221            None => StatusCode::INTERNAL_SERVER_ERROR.into_response(),
222        })
223    }
224}
225
226#[derive(Debug, Clone)]
227#[non_exhaustive]
228/// The inner tcp echo-service used by the [`IpServiceBuilder`].
229pub struct TcpEchoService;
230
231impl<Input> Service<Input> for TcpEchoService
232where
233    Input: Stream + Unpin,
234{
235    type Response = ();
236    type Error = BoxError;
237
238    async fn serve(&self, ctx: Context, stream: Input) -> Result<Self::Response, Self::Error> {
239        let peer_ip = ctx
240            .get::<Forwarded>()
241            .and_then(|f| f.client_ip())
242            .or_else(|| ctx.get::<SocketInfo>().map(|s| s.peer_addr().ip()));
243        let Some(peer_ip) = peer_ip else {
244            tracing::error!("missing peer information");
245            return Ok(());
246        };
247
248        let mut stream = std::pin::pin!(stream);
249
250        match peer_ip {
251            std::net::IpAddr::V4(ip) => {
252                if let Err(err) = stream.write_all(&ip.octets()).await {
253                    tracing::error!("error writing IPv4 of peer to peer: {}", err);
254                }
255            }
256            std::net::IpAddr::V6(ip) => {
257                if let Err(err) = stream.write_all(&ip.octets()).await {
258                    tracing::error!("error writing IPv6 of peer to peer: {}", err);
259                }
260            }
261        };
262
263        Ok(())
264    }
265}
266
267impl IpServiceBuilder<mode::Transport> {
268    /// build a tcp service ready to echo http traffic back
269    pub fn build(
270        self,
271    ) -> Result<impl Service<TcpStream, Response = (), Error = Infallible>, BoxError> {
272        let tcp_forwarded_layer = match &self.forward {
273            None => None,
274            Some(ForwardKind::HaProxy) => Some(HaProxyLayer::default()),
275            Some(other) => {
276                return Err(OpaqueError::from_display(format!(
277                    "invalid forward kind for Transport mode: {other:?}"
278                ))
279                .into());
280            }
281        };
282
283        let tcp_service_builder = (
284            ConsumeErrLayer::trace(tracing::Level::DEBUG),
285            (self.concurrent_limit > 0)
286                .then(|| LimitLayer::new(ConcurrentPolicy::max(self.concurrent_limit))),
287            (!self.timeout.is_zero()).then(|| TimeoutLayer::new(self.timeout)),
288            tcp_forwarded_layer,
289        );
290
291        Ok(tcp_service_builder.into_layer(TcpEchoService))
292    }
293}
294
pub mod mode {
    //! Operation modes of the ip service.

    /// Default mode of the ip service, echoing the client IP back over http.
    #[derive(Debug, Clone)]
    #[non_exhaustive]
    pub struct Http;

    /// Alternative mode of the ip service, echoing the client IP back over tcp.
    #[derive(Debug, Clone)]
    #[non_exhaustive]
    pub struct Transport;
}