rama/cli/service/
ip.rs

//! IP [`Service`] that echoes the client IP either over http or directly over tcp.
//!
//! [`Service`]: crate::Service
4
5use crate::{
6    Context, Layer, Service,
7    cli::ForwardKind,
8    combinators::Either7,
9    error::{BoxError, OpaqueError},
10    http::{
11        IntoResponse, Request, Response, StatusCode,
12        headers::{CFConnectingIp, ClientIp, TrueClientIp, XClientIp, XRealIp},
13        layer::{
14            forwarded::GetForwardedHeadersLayer, required_header::AddRequiredResponseHeadersLayer,
15            trace::TraceLayer, ua::UserAgentClassifierLayer,
16        },
17        server::HttpServer,
18    },
19    layer::{ConsumeErrLayer, LimitLayer, TimeoutLayer, limit::policy::ConcurrentPolicy},
20    net::forwarded::Forwarded,
21    net::stream::{SocketInfo, Stream, layer::http::BodyLimitLayer},
22    proxy::haproxy::server::HaProxyLayer,
23    rt::Executor,
24};
25use std::{convert::Infallible, marker::PhantomData, time::Duration};
26use tokio::{io::AsyncWriteExt, net::TcpStream};
27
28#[derive(Debug, Clone)]
29/// Builder that can be used to run your own ip [`Service`],
30/// echo'ing back the client IP over http or tcp.
31pub struct IpServiceBuilder<M> {
32    concurrent_limit: usize,
33    timeout: Duration,
34    forward: Option<ForwardKind>,
35    _mode: PhantomData<fn(M)>,
36}
37
38impl Default for IpServiceBuilder<mode::Http> {
39    fn default() -> Self {
40        Self {
41            concurrent_limit: 0,
42            timeout: Duration::ZERO,
43            forward: None,
44            _mode: PhantomData,
45        }
46    }
47}
48
49impl IpServiceBuilder<mode::Http> {
50    /// Create a new [`IpServiceBuilder`], echoing the IP back over HTTP.
51    pub fn http() -> Self {
52        Self::default()
53    }
54}
55
56impl IpServiceBuilder<mode::Transport> {
57    /// Create a new [`IpServiceBuilder`], echoing the IP back over L4.
58    pub fn tcp() -> Self {
59        Self {
60            concurrent_limit: 0,
61            timeout: Duration::ZERO,
62            forward: None,
63            _mode: PhantomData,
64        }
65    }
66}
67
68impl<M> IpServiceBuilder<M> {
69    /// set the number of concurrent connections to allow
70    ///
71    /// (0 = no limit)
72    pub fn concurrent(mut self, limit: usize) -> Self {
73        self.concurrent_limit = limit;
74        self
75    }
76
77    /// set the number of concurrent connections to allow
78    ///
79    /// (0 = no limit)
80    pub fn set_concurrent(&mut self, limit: usize) -> &mut Self {
81        self.concurrent_limit = limit;
82        self
83    }
84
85    /// set the timeout in seconds for each connection
86    ///
87    /// (0 = no timeout)
88    pub fn timeout(mut self, timeout: Duration) -> Self {
89        self.timeout = timeout;
90        self
91    }
92
93    /// set the timeout in seconds for each connection
94    ///
95    /// (0 = no timeout)
96    pub fn set_timeout(&mut self, timeout: Duration) -> &mut Self {
97        self.timeout = timeout;
98        self
99    }
100
101    /// enable support for one of the following "forward" headers or protocols
102    ///
103    /// Supported headers:
104    ///
105    /// Forwarded ("for="), X-Forwarded-For
106    ///
107    /// X-Client-IP Client-IP, X-Real-IP
108    ///
109    /// CF-Connecting-IP, True-Client-IP
110    ///
111    /// Or using HaProxy protocol.
112    pub fn forward(self, kind: ForwardKind) -> Self {
113        self.maybe_forward(Some(kind))
114    }
115
116    /// enable support for one of the following "forward" headers or protocols
117    ///
118    /// Same as [`Self::forward`] but without consuming `self`.
119    pub fn set_forward(&mut self, kind: ForwardKind) -> &mut Self {
120        self.forward = Some(kind);
121        self
122    }
123
124    /// maybe enable support for one of the following "forward" headers or protocols.
125    ///
126    /// See [`Self::forward`] for more information.
127    pub fn maybe_forward(mut self, maybe_kind: Option<ForwardKind>) -> Self {
128        self.forward = maybe_kind;
129        self
130    }
131}
132
133impl IpServiceBuilder<mode::Http> {
134    /// build a tcp service ready to echo http traffic back
135    pub fn build(
136        self,
137        executor: Executor,
138    ) -> Result<impl Service<(), TcpStream, Response = (), Error = Infallible>, BoxError> {
139        let (tcp_forwarded_layer, http_forwarded_layer) = match &self.forward {
140            None => (None, None),
141            Some(ForwardKind::Forwarded) => (
142                None,
143                Some(Either7::A(GetForwardedHeadersLayer::forwarded())),
144            ),
145            Some(ForwardKind::XForwardedFor) => (
146                None,
147                Some(Either7::B(GetForwardedHeadersLayer::x_forwarded_for())),
148            ),
149            Some(ForwardKind::XClientIp) => (
150                None,
151                Some(Either7::C(GetForwardedHeadersLayer::<XClientIp>::new())),
152            ),
153            Some(ForwardKind::ClientIp) => (
154                None,
155                Some(Either7::D(GetForwardedHeadersLayer::<ClientIp>::new())),
156            ),
157            Some(ForwardKind::XRealIp) => (
158                None,
159                Some(Either7::E(GetForwardedHeadersLayer::<XRealIp>::new())),
160            ),
161            Some(ForwardKind::CFConnectingIp) => (
162                None,
163                Some(Either7::F(GetForwardedHeadersLayer::<CFConnectingIp>::new())),
164            ),
165            Some(ForwardKind::TrueClientIp) => (
166                None,
167                Some(Either7::G(GetForwardedHeadersLayer::<TrueClientIp>::new())),
168            ),
169            Some(ForwardKind::HaProxy) => (Some(HaProxyLayer::default()), None),
170        };
171
172        let tcp_service_builder = (
173            ConsumeErrLayer::trace(tracing::Level::DEBUG),
174            (self.concurrent_limit > 0)
175                .then(|| LimitLayer::new(ConcurrentPolicy::max(self.concurrent_limit))),
176            (!self.timeout.is_zero()).then(|| TimeoutLayer::new(self.timeout)),
177            tcp_forwarded_layer,
178            // Limit the body size to 1MB for requests
179            BodyLimitLayer::request_only(1024 * 1024),
180        );
181
182        // TODO: support opt-in TLS)
183
184        let http_service = (
185            TraceLayer::new_for_http(),
186            AddRequiredResponseHeadersLayer::default(),
187            UserAgentClassifierLayer::new(),
188            ConsumeErrLayer::default(),
189            http_forwarded_layer,
190        )
191            .into_layer(HttpEchoService);
192
193        Ok(tcp_service_builder.into_layer(HttpServer::auto(executor).service(http_service)))
194    }
195}
196
197#[derive(Debug, Clone)]
198#[non_exhaustive]
199/// The inner http echo-service used by the [`IpServiceBuilder`].
200pub struct HttpEchoService;
201
202impl Service<(), Request> for HttpEchoService {
203    type Response = Response;
204    type Error = BoxError;
205
206    async fn serve(&self, ctx: Context<()>, _req: Request) -> Result<Self::Response, Self::Error> {
207        let peer_ip = ctx
208            .get::<Forwarded>()
209            .and_then(|f| f.client_ip())
210            .or_else(|| ctx.get::<SocketInfo>().map(|s| s.peer_addr().ip()));
211
212        Ok(match peer_ip {
213            Some(ip) => ip.to_string().into_response(),
214            None => StatusCode::INTERNAL_SERVER_ERROR.into_response(),
215        })
216    }
217}
218
219#[derive(Debug, Clone)]
220#[non_exhaustive]
221/// The inner tcp echo-service used by the [`IpServiceBuilder`].
222pub struct TcpEchoService;
223
224impl<Input> Service<(), Input> for TcpEchoService
225where
226    Input: Stream + Unpin,
227{
228    type Response = ();
229    type Error = BoxError;
230
231    async fn serve(&self, ctx: Context<()>, stream: Input) -> Result<Self::Response, Self::Error> {
232        let peer_ip = ctx
233            .get::<Forwarded>()
234            .and_then(|f| f.client_ip())
235            .or_else(|| ctx.get::<SocketInfo>().map(|s| s.peer_addr().ip()));
236        let peer_ip = match peer_ip {
237            Some(peer_ip) => peer_ip,
238            None => {
239                tracing::error!("missing peer information");
240                return Ok(());
241            }
242        };
243
244        let mut stream = std::pin::pin!(stream);
245
246        match peer_ip {
247            std::net::IpAddr::V4(ip) => {
248                if let Err(err) = stream.write_all(&ip.octets()).await {
249                    tracing::error!("error writing IPv4 of peer to peer: {}", err);
250                }
251            }
252            std::net::IpAddr::V6(ip) => {
253                if let Err(err) = stream.write_all(&ip.octets()).await {
254                    tracing::error!("error writing IPv6 of peer to peer: {}", err);
255                }
256            }
257        };
258
259        Ok(())
260    }
261}
262
263impl IpServiceBuilder<mode::Transport> {
264    /// build a tcp service ready to echo http traffic back
265    pub fn build(
266        self,
267    ) -> Result<impl Service<(), TcpStream, Response = (), Error = Infallible>, BoxError> {
268        let tcp_forwarded_layer = match &self.forward {
269            None => None,
270            Some(ForwardKind::HaProxy) => Some(HaProxyLayer::default()),
271            Some(other) => {
272                return Err(OpaqueError::from_display(format!(
273                    "invalid forward kind for Transport mode: {other:?}"
274                ))
275                .into());
276            }
277        };
278
279        let tcp_service_builder = (
280            ConsumeErrLayer::trace(tracing::Level::DEBUG),
281            (self.concurrent_limit > 0)
282                .then(|| LimitLayer::new(ConcurrentPolicy::max(self.concurrent_limit))),
283            (!self.timeout.is_zero()).then(|| TimeoutLayer::new(self.timeout)),
284            tcp_forwarded_layer,
285        );
286
287        Ok(tcp_service_builder.into_layer(TcpEchoService))
288    }
289}
290
pub mod mode {
    //! Operation modes of the ip service, used as the type parameter
    //! of the ip service builder.

    /// Default mode of the ip service, echoing the client info back over http.
    #[derive(Clone, Debug)]
    #[non_exhaustive]
    pub struct Http;

    /// Alternative mode of the ip service, echoing the client ip directly over tcp.
    #[derive(Clone, Debug)]
    #[non_exhaustive]
    pub struct Transport;
}