Skip to main content

rama/cli/service/
ip.rs

//! IP [`Service`] that echoes the client IP back, either over HTTP or directly over TCP.
2//!
3//! [`Service`]: crate::Service
4
5use crate::{
6    Layer, Service,
7    cli::ForwardKind,
8    combinators::Either,
9    combinators::Either7,
10    error::BoxError,
11    extensions::{ExtensionsMut, ExtensionsRef},
12    http::{
13        Request, Response, StatusCode,
14        headers::exotic::XClacksOverhead,
15        headers::forwarded::{CFConnectingIp, ClientIp, TrueClientIp, XClientIp, XRealIp},
16        headers::{Accept, HeaderMapExt},
17        layer::{
18            forwarded::GetForwardedHeaderLayer, required_header::AddRequiredResponseHeadersLayer,
19            set_header::SetResponseHeaderLayer, trace::TraceLayer,
20        },
21        mime,
22        server::HttpServer,
23        service::web::response::{Html, IntoResponse, Json, Redirect},
24    },
25    layer::limit::policy::UnlimitedPolicy,
26    layer::{ConsumeErrLayer, LimitLayer, TimeoutLayer, limit::policy::ConcurrentPolicy},
27    net::forwarded::Forwarded,
28    net::stream::{SocketInfo, layer::http::BodyLimitLayer},
29    proxy::haproxy::server::HaProxyLayer,
30    rt::Executor,
31    stream::Stream,
32    tcp::TcpStream,
33    telemetry::tracing,
34};
35
36#[cfg(all(feature = "rustls", not(feature = "boring")))]
37use crate::tls::rustls::server::{TlsAcceptorData, TlsAcceptorLayer};
38
39#[cfg(any(feature = "rustls", feature = "boring"))]
40use crate::http::headers::StrictTransportSecurity;
41
42#[cfg(feature = "boring")]
43use crate::{
44    net::tls::server::ServerConfig,
45    tls::boring::server::{TlsAcceptorData, TlsAcceptorLayer},
46};
47
/// TLS server configuration accepted by the builder when the `boring` backend is
/// enabled (it takes precedence over `rustls`). It is converted into
/// [`TlsAcceptorData`] when the service is built, which may fail.
#[cfg(feature = "boring")]
type TlsConfig = ServerConfig;

/// TLS server configuration accepted by the builder when only the `rustls`
/// backend is enabled: the acceptor data is supplied directly.
#[cfg(all(feature = "rustls", not(feature = "boring")))]
type TlsConfig = TlsAcceptorData;
53
54use rama_core::error::ErrorExt as _;
55use std::{convert::Infallible, marker::PhantomData, net::IpAddr, time::Duration};
56use tokio::io::AsyncWriteExt;
57
58#[derive(Debug, Clone)]
59/// Builder that can be used to run your own ip [`Service`],
60/// echo'ing back the client IP over http or tcp.
61pub struct IpServiceBuilder<M> {
62    #[cfg(any(feature = "rustls", feature = "boring"))]
63    tls_server_config: Option<TlsConfig>,
64    concurrent_limit: usize,
65    timeout: Duration,
66    forward: Option<ForwardKind>,
67    _mode: PhantomData<fn(M)>,
68}
69
70impl IpServiceBuilder<mode::Http> {
71    /// Create a new [`IpServiceBuilder`], echoing the IP back over L4.
72    #[must_use]
73    pub fn http() -> Self {
74        Self {
75            #[cfg(any(feature = "rustls", feature = "boring"))]
76            tls_server_config: None,
77            concurrent_limit: 0,
78            timeout: Duration::ZERO,
79            forward: None,
80            _mode: PhantomData,
81        }
82    }
83}
84
85impl IpServiceBuilder<mode::Transport> {
86    /// Create a new [`IpServiceBuilder`], echoing the IP back over L4.
87    #[must_use]
88    pub fn tcp() -> Self {
89        Self {
90            #[cfg(any(feature = "rustls", feature = "boring"))]
91            tls_server_config: None,
92            concurrent_limit: 0,
93            timeout: Duration::ZERO,
94            forward: None,
95            _mode: PhantomData,
96        }
97    }
98}
99
100impl<M> IpServiceBuilder<M> {
101    crate::utils::macros::generate_set_and_with! {
102        /// set the number of concurrent connections to allow
103        #[must_use]
104        pub fn concurrent(mut self, limit: usize) -> Self {
105            self.concurrent_limit = limit;
106            self
107        }
108    }
109
110    crate::utils::macros::generate_set_and_with! {
111        /// set the timeout in seconds for each connection
112        #[must_use]
113        pub fn timeout(mut self, timeout: Duration) -> Self {
114            self.timeout = timeout;
115            self
116        }
117    }
118
119    crate::utils::macros::generate_set_and_with! {
120        /// maybe enable support for one of the following "forward" headers or protocols
121        ///
122        /// Supported headers:
123        ///
124        /// Forwarded ("for="), X-Forwarded-For
125        ///
126        /// X-Client-IP Client-IP, X-Real-IP
127        ///
128        /// CF-Connecting-IP, True-Client-IP
129        ///
130        /// Or using HaProxy protocol.
131        #[must_use]
132        pub fn forward(mut self, maybe_kind: Option<ForwardKind>) -> Self {
133            self.forward = maybe_kind;
134            self
135        }
136    }
137
138    crate::utils::macros::generate_set_and_with! {
139        #[cfg(any(feature = "rustls", feature = "boring"))]
140        /// define a tls server cert config to be used for tls terminaton
141        /// by the IP service.
142        pub fn tls_server_config(mut self, cfg: Option<TlsConfig>) -> Self {
143            self.tls_server_config = cfg;
144            self
145        }
146    }
147}
148
149impl IpServiceBuilder<mode::Http> {
150    #[allow(unused_mut)]
151    #[inline]
152    /// build a tcp service ready to echo the client IP back
153    pub fn build(
154        mut self,
155        executor: Executor,
156    ) -> Result<impl Service<TcpStream, Output = (), Error = Infallible>, BoxError> {
157        #[cfg(all(feature = "rustls", not(feature = "boring")))]
158        let tls_cfg = self.tls_server_config.take();
159
160        #[cfg(feature = "boring")]
161        let tls_cfg: Option<TlsAcceptorData> = match self.tls_server_config.take() {
162            Some(cfg) => Some(cfg.try_into()?),
163            None => None,
164        };
165
166        #[cfg(any(feature = "rustls", feature = "boring"))]
167        {
168            let maybe_tls_acceptor_layer = tls_cfg.map(TlsAcceptorLayer::new);
169            self.build_http(executor, maybe_tls_acceptor_layer)
170        }
171
172        #[cfg(not(any(feature = "rustls", feature = "boring")))]
173        self.build_http(executor)
174    }
175}
176
/// Innermost HTTP service assembled by the [`IpServiceBuilder`]; it replies
/// with the client IP.
#[derive(Debug, Clone)]
#[non_exhaustive]
struct HttpIpService;
181
182impl Service<Request> for HttpIpService {
183    type Output = Response;
184    type Error = BoxError;
185
186    async fn serve(&self, req: Request) -> Result<Self::Output, Self::Error> {
187        let norm_req_path = req.uri().path().trim_matches('/');
188        if !norm_req_path.is_empty() {
189            tracing::debug!("unexpected request path '{norm_req_path}', redirect to root");
190            return Ok(Redirect::permanent("/").into_response());
191        }
192
193        let peer_ip = req
194            .extensions()
195            .get::<Forwarded>()
196            .and_then(|f| f.client_ip())
197            .or_else(|| {
198                req.extensions()
199                    .get::<SocketInfo>()
200                    .map(|s| s.peer_addr().ip_addr)
201            });
202
203        Ok(match peer_ip {
204            Some(ip) => match HttpBodyContentFormat::derive_from_req(&req) {
205                HttpBodyContentFormat::Txt => ip.to_string().into_response(),
206                HttpBodyContentFormat::Html => format_html_page(ip).into_response(),
207                HttpBodyContentFormat::Json => Json(serde_json::json!({
208                    "ip": ip,
209                }))
210                .into_response(),
211            },
212            None => StatusCode::INTERNAL_SERVER_ERROR.into_response(),
213        })
214    }
215}
216
/// Body format of the ip response, negotiated from the `Accept` header.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
enum HttpBodyContentFormat {
    /// The bare IP address as plain text; the fallback format.
    #[default]
    Txt,
    /// An HTML page displaying the IP address.
    Html,
    /// A JSON object holding the IP address under the `"ip"` key.
    Json,
}
224
225impl HttpBodyContentFormat {
226    fn derive_from_req(req: &Request) -> Self {
227        let Some(accept) = req.headers().typed_get::<Accept>() else {
228            return Self::default();
229        };
230        accept
231            .0
232            .iter()
233            .find_map(|qv| {
234                let r#type = qv.value.subtype();
235                if r#type == mime::JSON {
236                    Some(Self::Json)
237                } else if r#type == mime::HTML {
238                    Some(Self::Html)
239                } else if r#type == mime::TEXT {
240                    Some(Self::Txt)
241                } else {
242                    None
243                }
244            })
245            .unwrap_or_default()
246    }
247}
248
/// Innermost TCP service assembled by the [`IpServiceBuilder`]; it writes the
/// client IP straight onto the stream.
#[derive(Debug, Clone)]
#[non_exhaustive]
struct TcpIpService;
253
254impl<Input> Service<Input> for TcpIpService
255where
256    Input: Stream + Unpin + ExtensionsRef,
257{
258    type Output = ();
259    type Error = BoxError;
260
261    async fn serve(&self, stream: Input) -> Result<Self::Output, Self::Error> {
262        tracing::info!("connection received");
263        let peer_ip = stream
264            .extensions()
265            .get::<Forwarded>()
266            .and_then(|f| f.client_ip())
267            .or_else(|| {
268                stream
269                    .extensions()
270                    .get::<SocketInfo>()
271                    .map(|s| s.peer_addr().ip_addr)
272            });
273        let Some(peer_ip) = peer_ip else {
274            tracing::error!("missing peer information");
275            return Ok(());
276        };
277
278        let mut stream = std::pin::pin!(stream);
279
280        match peer_ip {
281            std::net::IpAddr::V4(ip) => {
282                if let Err(err) = stream.write_all(&ip.octets()).await {
283                    tracing::error!("error writing IPv4 of peer to peer: {}", err);
284                }
285            }
286            std::net::IpAddr::V6(ip) => {
287                if let Err(err) = stream.write_all(&ip.octets()).await {
288                    tracing::error!("error writing IPv6 of peer to peer: {}", err);
289                }
290            }
291        };
292
293        Ok(())
294    }
295}
296
297impl IpServiceBuilder<mode::Transport> {
298    #[allow(unused_mut)]
299    #[inline]
300    /// build a tcp service ready to echo client IP back
301    pub fn build(
302        mut self,
303    ) -> Result<impl Service<TcpStream, Output = (), Error = Infallible>, BoxError> {
304        #[cfg(all(feature = "rustls", not(feature = "boring")))]
305        let tls_cfg = self.tls_server_config.take();
306
307        #[cfg(feature = "boring")]
308        let tls_cfg: Option<TlsAcceptorData> = match self.tls_server_config.take() {
309            Some(cfg) => Some(cfg.try_into()?),
310            None => None,
311        };
312
313        #[cfg(any(feature = "rustls", feature = "boring"))]
314        {
315            let maybe_tls_acceptor_layer = tls_cfg.map(TlsAcceptorLayer::new);
316            self.build_tcp(maybe_tls_acceptor_layer)
317        }
318
319        #[cfg(not(any(feature = "rustls", feature = "boring")))]
320        self.build_tcp()
321    }
322}
323
impl<M> IpServiceBuilder<M> {
    /// Assemble the transport-mode service stack around [`TcpIpService`].
    ///
    /// The layers are listed below in tuple order. NOTE(review): they are
    /// assumed to apply outermost-first, as with a tower-style service
    /// builder; confirm against rama's tuple `Layer` impl.
    ///
    /// 1. [`ConsumeErrLayer`] (`trace_as(DEBUG)`), so the resulting service's
    ///    error type is [`Infallible`].
    /// 2. [`LimitLayer`]: a [`ConcurrentPolicy`] capped at `concurrent_limit`,
    ///    or an [`UnlimitedPolicy`] when the limit is `0`.
    /// 3. [`TimeoutLayer`]: `timeout` per connection, or `never` when it is zero.
    /// 4. Optional [`HaProxyLayer`], only when forwarding is [`ForwardKind::HaProxy`].
    /// 5. Optional TLS acceptor layer, only with a TLS feature and a provided config.
    ///
    /// # Errors
    ///
    /// Fails when `forward` is set to any kind other than [`ForwardKind::HaProxy`],
    /// because the HTTP forwarding headers are not available on a raw stream.
    fn build_tcp<S: Stream + ExtensionsMut + Unpin + Send + Sync + 'static>(
        self,
        #[cfg(any(feature = "rustls", feature = "boring"))] maybe_tls_accept_layer: Option<
            TlsAcceptorLayer,
        >,
    ) -> Result<impl Service<S, Output = (), Error = Infallible>, BoxError> {
        // Only the HaProxy (PROXY) protocol can carry the client IP at the transport level.
        let tcp_forwarded_layer = match &self.forward {
            None => None,
            Some(ForwardKind::HaProxy) => Some(HaProxyLayer::default()),
            Some(other) => {
                return Err(BoxError::from("invalid forward kind for Transport mode")
                    .with_context_debug_field("kind", || other.clone()));
            }
        };

        let tcp_service_builder = (
            ConsumeErrLayer::trace_as(tracing::Level::DEBUG),
            // `Either` lets both policies share one `LimitLayer` type; `0` means unlimited.
            LimitLayer::new(if self.concurrent_limit > 0 {
                Either::A(ConcurrentPolicy::max(self.concurrent_limit))
            } else {
                Either::B(UnlimitedPolicy::new())
            }),
            // A zero timeout disables the timeout.
            if !self.timeout.is_zero() {
                TimeoutLayer::new(self.timeout)
            } else {
                TimeoutLayer::never()
            },
            tcp_forwarded_layer,
            #[cfg(any(feature = "rustls", feature = "boring"))]
            maybe_tls_accept_layer,
        );

        Ok(tcp_service_builder.into_layer(TcpIpService))
    }

    /// Assemble the HTTP-mode service stack.
    ///
    /// A TCP-level stack wraps an [`HttpServer::auto`] server, driven by
    /// `executor`, which serves [`HttpIpService`].
    ///
    /// The client IP source is selected by `forward`:
    /// - HTTP forwarding headers are handled by a [`GetForwardedHeaderLayer`]
    ///   inside the HTTP stack.
    /// - The HaProxy protocol is handled by a [`HaProxyLayer`] at the TCP level.
    ///
    /// When a TLS acceptor layer is given, an HSTS header (one year, excluding
    /// subdomains) is added to responses that do not already carry one.
    ///
    /// Unlike `build_tcp`, a disabled concurrency limit or timeout is passed as
    /// a `None` layer rather than an unlimited/never variant.
    ///
    /// Currently this always returns `Ok`.
    fn build_http<S: Stream + Unpin + Send + Sync + ExtensionsMut + 'static>(
        self,
        executor: Executor,
        #[cfg(any(feature = "rustls", feature = "boring"))] maybe_tls_accept_layer: Option<
            TlsAcceptorLayer,
        >,
    ) -> Result<impl Service<S, Output = (), Error = Infallible>, BoxError> {
        // Each header kind maps to its own `Either7` variant so that all arms share
        // a single type; HaProxy is the only TCP-level option.
        let (tcp_forwarded_layer, http_forwarded_layer) = match &self.forward {
            None => (None, None),
            Some(ForwardKind::Forwarded) => {
                (None, Some(Either7::A(GetForwardedHeaderLayer::forwarded())))
            }
            Some(ForwardKind::XForwardedFor) => (
                None,
                Some(Either7::B(GetForwardedHeaderLayer::x_forwarded_for())),
            ),
            Some(ForwardKind::XClientIp) => (
                None,
                Some(Either7::C(GetForwardedHeaderLayer::<XClientIp>::new())),
            ),
            Some(ForwardKind::ClientIp) => (
                None,
                Some(Either7::D(GetForwardedHeaderLayer::<ClientIp>::new())),
            ),
            Some(ForwardKind::XRealIp) => (
                None,
                Some(Either7::E(GetForwardedHeaderLayer::<XRealIp>::new())),
            ),
            Some(ForwardKind::CFConnectingIp) => (
                None,
                Some(Either7::F(GetForwardedHeaderLayer::<CFConnectingIp>::new())),
            ),
            Some(ForwardKind::TrueClientIp) => (
                None,
                Some(Either7::G(GetForwardedHeaderLayer::<TrueClientIp>::new())),
            ),
            Some(ForwardKind::HaProxy) => (Some(HaProxyLayer::default()), None),
        };

        // HSTS is only meaningful when TLS is terminated here; 31536000 s = 365 days.
        #[cfg(any(feature = "rustls", feature = "boring"))]
        let hsts_layer = maybe_tls_accept_layer.is_some().then(|| {
            SetResponseHeaderLayer::if_not_present_typed(
                StrictTransportSecurity::excluding_subdomains_for_max_seconds(31536000),
            )
        });

        let tcp_service_builder = (
            ConsumeErrLayer::trace_as(tracing::Level::DEBUG),
            // `None` when the limit is 0 (unlimited).
            (self.concurrent_limit > 0)
                .then(|| LimitLayer::new(ConcurrentPolicy::max(self.concurrent_limit))),
            // `None` when the timeout is zero (disabled).
            (!self.timeout.is_zero()).then(|| TimeoutLayer::new(self.timeout)),
            tcp_forwarded_layer,
            // Limit the body size to 1MB for requests
            BodyLimitLayer::request_only(1024 * 1024),
            #[cfg(any(feature = "rustls", feature = "boring"))]
            maybe_tls_accept_layer,
        );

        let http_service = (
            TraceLayer::new_for_http(),
            SetResponseHeaderLayer::<XClacksOverhead>::if_not_present_default_typed(),
            AddRequiredResponseHeadersLayer::default(),
            ConsumeErrLayer::default(),
            #[cfg(any(feature = "rustls", feature = "boring"))]
            hsts_layer,
            http_forwarded_layer,
        )
            .into_layer(HttpIpService);

        Ok(tcp_service_builder.into_layer(HttpServer::auto(executor).service(http_service)))
    }
}
432
pub mod mode {
    //! Operation modes of the ip service.
    //!
    //! These are marker types used as the type parameter of `IpServiceBuilder`
    //! to select how the client IP is echoed back.

    #[derive(Debug, Clone)]
    #[non_exhaustive]
    /// Default mode of the ip service, echoing the client IP back over HTTP.
    pub struct Http;

    #[derive(Debug, Clone)]
    #[non_exhaustive]
    /// Alternative mode of the ip service, echoing the client IP back directly over TCP.
    pub struct Transport;
}
446
447fn format_html_page(ip: IpAddr) -> Html<String> {
448    Html(format!(
449        r##"<!doctype html> <html lang="en"> <head> <meta charset="utf-8" /> <meta name="viewport" content="width=device-width,initial-scale=1" /> <link rel="icon" href="data:image/svg+xml,<svg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 100 100'><text y='0.9em' font-size='90'>🦙</text></svg>" /> <title>Rama IP</title> <style> *, *::before, *::after {{ box-sizing: border-box; }} :root{{ --bg:#000; --panel:#0f0f0f; --green:#45d23a; --muted:#bfbfbf; }} html,body{{height:100%;margin:0;font-family:system-ui,-apple-system,Segoe UI,Roboto,"Helvetica Neue",Arial;}} body{{ background:var(--bg); color:var(--muted); display:flex; align-items:center; justify-content:center; padding:2.8rem; }} .card{{ text-align:center; }} .logo{{ display:flex; align-items:center; justify-content:center; gap:0.8rem; margin-bottom:1.1rem; }} .logo, .logo a, .logo a:hover {{ color:var(--green); font-weight:700; font-size:2rem; letter-spacing:0.4rem; }} .logo a {{ text-decoration: none; }} .logo a:hover {{ text-decoration: underline; }} .subtitle{{ font-size:1.1rem; margin:0.3rem 0 2rem 0; color:var(--muted); }} .panel{{ background:linear-gradient(180deg,#0b0b0b 0%, #111 100%); border-radius:0.8rem; padding:2rem; box-shadow:0 0.3rem 2rem rgba(0,0,0,0.7), inset 0 0.05rem 0 rgba(255,255,255,0.02); border:0.1rem solid rgba(69,210,58,0.06); }} .ip{{ background:transparent; border-radius:0.6rem; padding:1rem 1.1rem; font-family: ui-monospace,SFMono-Regular,Menlo,monospace; font-size:1.1rem; color:#fff139; margin:0.6rem auto 1.1rem auto; word-break:break-all; border:0.05rem solid rgba(69,210,58,0.12); }} .muted{{ color:var(--muted); font-size:1rem; margin-bottom:0.9rem; }} .controls{{display:flex;gap:0.8rem;justify-content:center;flex-wrap:wrap;}} button{{ background:transparent; color:var(--green); padding:0.8rem 1.1rem; border-radius:0.6rem; font-weight:700; border:0.1rem solid rgba(69,210,58,0.9); cursor:pointer; }} button.primary{{ background:var(--green); color:#032; box-shadow:0 0.4rem 
1.2rem rgba(69,210,58,0.08); }} .note{{font-size:0.95rem;color:#9aa; margin-top:1rem;}} .small{{font-size:0.9rem;color:#808080;margin-top:0.7rem}} </style> </head> <body> <div class="card"> <div class="logo"> <div>🦙</div> <div><a href="https://ramaproxy.org">ラマ</a></div> </div> <div class="panel" role="region" aria-label="ip panel"> <div class="muted">Your public ip</div><div id="ip" class="ip"> <code>{ip}</code> </div> <div class="controls"> <button id="copyBtn" class="primary" title="Copy ip to clipboard">📋 Copy IP</button></div> </div> <script> (async function(){{ const ipEl = document.getElementById('ip'); const copyBtn = document.getElementById('copyBtn'); copyBtn.addEventListener('click', async ()=>{{ const txt = ipEl.textContent.trim(); try{{ await navigator.clipboard.writeText(txt); copyBtn.textContent = 'Copied'; setTimeout(()=> copyBtn.textContent = 'Copy IP', 1400); }}catch(e){{ const ta = document.createElement('textarea'); ta.value = txt; document.body.appendChild(ta); ta.select(); try{{ document.execCommand('copy'); copyBtn.textContent = 'Copied'; }} catch(e){{ alert('Copy failed. Select and copy manually.'); }} ta.remove(); setTimeout(()=> copyBtn.textContent = 'Copy IP', 1400); }} }}); }})(); </script> </body> </html>"##,
450    ))
451}