use crate::{
cli::ForwardKind,
combinators::Either7,
error::{BoxError, OpaqueError},
http::{
headers::{CFConnectingIp, ClientIp, TrueClientIp, XClientIp, XRealIp},
layer::{
forwarded::GetForwardedHeadersLayer, required_header::AddRequiredResponseHeadersLayer,
trace::TraceLayer, ua::UserAgentClassifierLayer,
},
server::HttpServer,
IntoResponse, Request, Response, StatusCode,
},
layer::{limit::policy::ConcurrentPolicy, ConsumeErrLayer, LimitLayer, TimeoutLayer},
net::forwarded::Forwarded,
net::stream::{layer::http::BodyLimitLayer, SocketInfo, Stream},
proxy::haproxy::server::HaProxyLayer,
rt::Executor,
Context, Layer, Service,
};
use std::{convert::Infallible, marker::PhantomData, time::Duration};
use tokio::{io::AsyncWriteExt, net::TcpStream};
/// Builder for a service that reports the connecting client's IP address.
///
/// The `M` type parameter is one of the marker types from the `mode` module
/// (`mode::Http` or `mode::Transport`); it selects which `build` method is
/// available and therefore which kind of service gets produced.
#[derive(Debug, Clone)]
pub struct IpServiceBuilder<M> {
    /// Maximum number of connections served concurrently; `0` disables the limit.
    concurrent_limit: usize,
    /// Timeout installed via `TimeoutLayer`; `Duration::ZERO` disables it.
    timeout: Duration,
    /// Source of the client IP when running behind a proxy; `None` means the
    /// socket peer address is used.
    forward: Option<ForwardKind>,
    /// Mode marker; `fn(M)` stores no `M` value and keeps the builder `Send`/`Sync`
    /// regardless of `M`.
    _mode: PhantomData<fn(M)>,
}
impl Default for IpServiceBuilder<mode::Http> {
    /// Creates an HTTP-mode builder with every option disabled: no concurrency
    /// limit (`0`), no timeout (zero duration) and no forwarding source.
    fn default() -> Self {
        Self {
            forward: None,
            timeout: Duration::from_secs(0),
            concurrent_limit: 0,
            _mode: PhantomData,
        }
    }
}
impl IpServiceBuilder<mode::Http> {
    /// Starts a builder for the HTTP flavour of the IP service.
    ///
    /// Equivalent to `IpServiceBuilder::<mode::Http>::default()`.
    pub fn http() -> Self {
        <Self as Default>::default()
    }
}
impl IpServiceBuilder<mode::Transport> {
    /// Starts a builder for the raw-TCP flavour of the IP service.
    ///
    /// Every option starts disabled: no concurrency limit (`0`), no timeout
    /// (zero duration) and no forwarding source.
    pub fn tcp() -> Self {
        let (concurrent_limit, timeout, forward) = (0, Duration::from_secs(0), None);
        Self {
            concurrent_limit,
            timeout,
            forward,
            _mode: PhantomData,
        }
    }
}
impl<M> IpServiceBuilder<M> {
    /// Caps the number of connections served at the same time.
    ///
    /// A `limit` of `0` disables the cap.
    pub fn concurrent(self, limit: usize) -> Self {
        Self {
            concurrent_limit: limit,
            ..self
        }
    }

    /// In-place counterpart of [`Self::concurrent`].
    pub fn set_concurrent(&mut self, limit: usize) -> &mut Self {
        self.concurrent_limit = limit;
        self
    }

    /// Sets the duration handed to the `TimeoutLayer` wrapping the service.
    ///
    /// A zero duration disables the timeout.
    pub fn timeout(self, timeout: Duration) -> Self {
        Self { timeout, ..self }
    }

    /// In-place counterpart of [`Self::timeout`].
    pub fn set_timeout(&mut self, timeout: Duration) -> &mut Self {
        self.timeout = timeout;
        self
    }

    /// Reads the client IP from `kind` instead of the socket peer address.
    ///
    /// The accepted kinds depend on the mode: the transport-mode `build`
    /// rejects everything except `ForwardKind::HaProxy`.
    pub fn forward(self, kind: ForwardKind) -> Self {
        Self {
            forward: Some(kind),
            ..self
        }
    }

    /// In-place counterpart of [`Self::forward`].
    pub fn set_forward(&mut self, kind: ForwardKind) -> &mut Self {
        self.forward = Some(kind);
        self
    }

    /// Like [`Self::forward`], but `None` clears any previously configured kind.
    pub fn maybe_forward(self, maybe_kind: Option<ForwardKind>) -> Self {
        Self {
            forward: maybe_kind,
            ..self
        }
    }
}
impl IpServiceBuilder<mode::Http> {
pub fn build(
self,
executor: Executor,
) -> Result<impl Service<(), TcpStream, Response = (), Error = Infallible>, BoxError> {
let (tcp_forwarded_layer, http_forwarded_layer) = match &self.forward {
None => (None, None),
Some(ForwardKind::Forwarded) => (
None,
Some(Either7::A(GetForwardedHeadersLayer::forwarded())),
),
Some(ForwardKind::XForwardedFor) => (
None,
Some(Either7::B(GetForwardedHeadersLayer::x_forwarded_for())),
),
Some(ForwardKind::XClientIp) => (
None,
Some(Either7::C(GetForwardedHeadersLayer::<XClientIp>::new())),
),
Some(ForwardKind::ClientIp) => (
None,
Some(Either7::D(GetForwardedHeadersLayer::<ClientIp>::new())),
),
Some(ForwardKind::XRealIp) => (
None,
Some(Either7::E(GetForwardedHeadersLayer::<XRealIp>::new())),
),
Some(ForwardKind::CFConnectingIp) => (
None,
Some(Either7::F(GetForwardedHeadersLayer::<CFConnectingIp>::new())),
),
Some(ForwardKind::TrueClientIp) => (
None,
Some(Either7::G(GetForwardedHeadersLayer::<TrueClientIp>::new())),
),
Some(ForwardKind::HaProxy) => (Some(HaProxyLayer::default()), None),
};
let tcp_service_builder = (
ConsumeErrLayer::trace(tracing::Level::DEBUG),
(self.concurrent_limit > 0)
.then(|| LimitLayer::new(ConcurrentPolicy::max(self.concurrent_limit))),
(!self.timeout.is_zero()).then(|| TimeoutLayer::new(self.timeout)),
tcp_forwarded_layer,
BodyLimitLayer::request_only(1024 * 1024),
);
let http_service = (
TraceLayer::new_for_http(),
AddRequiredResponseHeadersLayer::default(),
UserAgentClassifierLayer::new(),
ConsumeErrLayer::default(),
http_forwarded_layer,
)
.layer(HttpEchoService);
Ok(tcp_service_builder.layer(HttpServer::auto(executor).service(http_service)))
}
}
/// HTTP [`Service`] that answers every request with the caller's IP address.
///
/// The IP comes from the [`Forwarded`] value in the request context when one
/// is present and yields a client IP; otherwise it falls back to the peer
/// address in [`SocketInfo`]. The response body is the textual form of that
/// IP. If neither source is available, a `500 Internal Server Error` response
/// is returned. The request itself is ignored.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct HttpEchoService;

impl Service<(), Request> for HttpEchoService {
    type Response = Response;
    type Error = BoxError;

    async fn serve(&self, ctx: Context<()>, _req: Request) -> Result<Self::Response, Self::Error> {
        // A client IP reported through `Forwarded` wins over the socket peer.
        let forwarded_ip = ctx.get::<Forwarded>().and_then(|forwarded| forwarded.client_ip());
        let peer_ip =
            forwarded_ip.or_else(|| ctx.get::<SocketInfo>().map(|info| info.peer_addr().ip()));

        let response = peer_ip.map_or_else(
            || StatusCode::INTERNAL_SERVER_ERROR.into_response(),
            |ip| ip.to_string().into_response(),
        );
        Ok(response)
    }
}
/// Transport-level [`Service`] that writes the caller's IP address onto the stream.
///
/// The IP is resolved like in `HttpEchoService`: a client IP from [`Forwarded`]
/// first, then the [`SocketInfo`] peer address. It is written as raw network
/// octets with no framing (4 bytes for IPv4, 16 bytes for IPv6).
///
/// The service always returns `Ok(())`. A missing peer IP or a failed write is
/// only logged at `ERROR` level.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct TcpEchoService;

impl<Input> Service<(), Input> for TcpEchoService
where
    Input: Stream + Unpin,
{
    type Response = ();
    type Error = BoxError;

    async fn serve(&self, ctx: Context<()>, stream: Input) -> Result<Self::Response, Self::Error> {
        let forwarded_ip = ctx.get::<Forwarded>().and_then(|forwarded| forwarded.client_ip());
        let socket_ip = || ctx.get::<SocketInfo>().map(|info| info.peer_addr().ip());

        let Some(peer_ip) = forwarded_ip.or_else(socket_ip) else {
            // Nothing to echo: log and finish without writing or failing.
            tracing::error!("missing peer information");
            return Ok(());
        };

        let mut stream = std::pin::pin!(stream);
        match peer_ip {
            std::net::IpAddr::V4(v4) => {
                let written = stream.write_all(&v4.octets()).await;
                if let Err(err) = written {
                    tracing::error!("error writing IPv4 of peer to peer: {}", err);
                }
            }
            std::net::IpAddr::V6(v6) => {
                let written = stream.write_all(&v6.octets()).await;
                if let Err(err) = written {
                    tracing::error!("error writing IPv6 of peer to peer: {}", err);
                }
            }
        }

        Ok(())
    }
}
impl IpServiceBuilder<mode::Transport> {
pub fn build(
self,
) -> Result<impl Service<(), TcpStream, Response = (), Error = Infallible>, BoxError> {
let tcp_forwarded_layer = match &self.forward {
None => None,
Some(ForwardKind::HaProxy) => Some(HaProxyLayer::default()),
Some(other) => {
return Err(OpaqueError::from_display(format!(
"invalid forward kind for Transport mode: {other:?}"
))
.into())
}
};
let tcp_service_builder = (
ConsumeErrLayer::trace(tracing::Level::DEBUG),
(self.concurrent_limit > 0)
.then(|| LimitLayer::new(ConcurrentPolicy::max(self.concurrent_limit))),
(!self.timeout.is_zero()).then(|| TimeoutLayer::new(self.timeout)),
tcp_forwarded_layer,
);
Ok(tcp_service_builder.layer(TcpEchoService))
}
}
/// Zero-sized marker types that select which service an `IpServiceBuilder` builds.
pub mod mode {
    /// Marker: the peer IP is returned as the textual body of an HTTP response.
    #[derive(Clone, Debug)]
    #[non_exhaustive]
    pub struct Http;

    /// Marker: the peer IP is written as raw octets directly on the transport stream.
    #[derive(Clone, Debug)]
    #[non_exhaustive]
    pub struct Transport;
}