1use rama_core::error::ErrorExt as _;
4
5use crate::{
6 Layer, Service,
7 cli::ForwardKind,
8 combinators::Either,
9 combinators::{Either3, Either7},
10 error::BoxError,
11 http::{
12 Request, Response, Version,
13 headers::exotic::XClacksOverhead,
14 headers::forwarded::{CFConnectingIp, ClientIp, TrueClientIp, XClientIp, XRealIp},
15 layer::set_header::SetResponseHeaderLayer,
16 layer::{
17 forwarded::GetForwardedHeaderLayer, required_header::AddRequiredResponseHeadersLayer,
18 trace::TraceLayer,
19 },
20 server::HttpServer,
21 service::{
22 fs::{DirectoryServeMode, ServeDir, ServeFile},
23 web::StaticService,
24 web::response::{Html, IntoResponse},
25 },
26 },
27 layer::limit::policy::UnlimitedPolicy,
28 layer::{ConsumeErrLayer, LimitLayer, TimeoutLayer, limit::policy::ConcurrentPolicy},
29 net::stream::layer::http::BodyLimitLayer,
30 proxy::haproxy::server::HaProxyLayer,
31 rt::Executor,
32 tcp::TcpStream,
33 telemetry::tracing,
34 ua::layer::classifier::UserAgentClassifierLayer,
35};
36
37use std::{convert::Infallible, path::PathBuf, sync::Arc, time::Duration};
38
39#[cfg(feature = "boring")]
40use crate::{
41 net::tls::server::ServerConfig,
42 tls::boring::server::{TlsAcceptorData, TlsAcceptorLayer},
43};
44
45#[cfg(all(feature = "rustls", not(feature = "boring")))]
46use crate::tls::rustls::server::{TlsAcceptorData, TlsAcceptorLayer};
47
/// Server-side TLS configuration stored by the builder.
///
/// With the `boring` feature this is a generic [`ServerConfig`]; `build`
/// converts it into boring acceptor data via `try_into` (which can fail).
#[cfg(feature = "boring")]
type TlsConfig = ServerConfig;

/// Server-side TLS configuration stored by the builder.
///
/// With only the `rustls` feature (and not `boring`) this is the rustls
/// acceptor data itself, which `build` uses as-is.
#[cfg(all(feature = "rustls", not(feature = "boring")))]
type TlsConfig = TlsAcceptorData;
53
/// Builder for a TCP service that serves either a file, a directory, or a
/// built-in HTML page over HTTP.
///
/// `H` is the stack of extra HTTP layers added via `with_http_layer`; it
/// starts out as `()` (see the `Default` impl).
#[derive(Debug, Clone)]
pub struct FsServiceBuilder<H> {
    /// Maximum number of concurrent connections; `0` means no limit.
    concurrent_limit: usize,
    /// Limit applied to request bodies (via `BodyLimitLayer::request_only`).
    body_limit: usize,
    /// Timeout for the TCP service; `Duration::ZERO` disables it.
    timeout: Duration,
    /// How the client address is obtained. `HaProxy` is handled on the TCP
    /// stack (`HaProxyLayer`); every other kind installs a header-based
    /// `GetForwardedHeaderLayer` in the HTTP stack. `None` disables both.
    forward: Option<ForwardKind>,

    /// Optional TLS configuration; when `None` no TLS acceptor layer is added.
    #[cfg(any(feature = "rustls", feature = "boring"))]
    tls_server_config: Option<TlsConfig>,

    /// Pinned HTTP version; `None` uses an auto-detecting HTTP server.
    http_version: Option<Version>,

    /// Extra HTTP layers wrapped around the content-serving service.
    http_service_builder: H,

    /// File or directory to serve; `None` serves the embedded `docs/index.html`.
    content_path: Option<PathBuf>,
    /// How directories are served when `content_path` is a directory.
    dir_serve_mode: DirectoryServeMode,
}
73
74impl Default for FsServiceBuilder<()> {
75 fn default() -> Self {
76 Self {
77 concurrent_limit: 0,
78 body_limit: 1024 * 1024,
79 timeout: Duration::ZERO,
80 forward: None,
81
82 #[cfg(any(feature = "rustls", feature = "boring"))]
83 tls_server_config: None,
84
85 http_version: None,
86
87 http_service_builder: (),
88
89 content_path: None,
90 dir_serve_mode: DirectoryServeMode::HtmlFileList,
91 }
92 }
93}
94
95impl FsServiceBuilder<()> {
96 #[must_use]
98 pub fn new() -> Self {
99 Self::default()
100 }
101}
102
impl<H> FsServiceBuilder<H> {
    rama_utils::macros::generate_set_and_with! {
        /// Set the maximum number of concurrent connections.
        ///
        /// `0` (the default) means no limit.
        pub fn concurrent(mut self, limit: usize) -> Self {
            self.concurrent_limit = limit;
            self
        }
    }

    rama_utils::macros::generate_set_and_with! {
        /// Set the request body size limit.
        ///
        /// The limit is applied to request bodies only.
        pub fn body_limit(mut self, limit: usize) -> Self {
            self.body_limit = limit;
            self
        }
    }

    rama_utils::macros::generate_set_and_with! {
        /// Set the timeout of the TCP service.
        ///
        /// `Duration::ZERO` (the default) disables the timeout.
        pub fn timeout(mut self, timeout: Duration) -> Self {
            self.timeout = timeout;
            self
        }
    }

    rama_utils::macros::generate_set_and_with! {
        /// Set how the original client address is obtained.
        ///
        /// - `Some(ForwardKind::HaProxy)` adds an `HaProxyLayer` to the TCP stack;
        /// - any other `Some(kind)` adds a `GetForwardedHeaderLayer` for the
        ///   matching header to the HTTP stack;
        /// - `None` (the default) disables forwarding support.
        pub fn forward(mut self, kind: Option<ForwardKind>) -> Self {
            self.forward = kind;
            self
        }
    }

    #[cfg(any(feature = "rustls", feature = "boring"))]
    rama_utils::macros::generate_set_and_with! {
        /// Set the TLS server configuration.
        ///
        /// When `None` (the default), no TLS acceptor layer is installed.
        pub fn tls_server_config(mut self, cfg: Option<TlsConfig>) -> Self {
            self.tls_server_config = cfg;
            self
        }
    }

    rama_utils::macros::generate_set_and_with! {
        /// Pin the HTTP version served.
        ///
        /// `HTTP_2` uses an HTTP/2 server. `HTTP_11`, `HTTP_10` and `HTTP_09`
        /// use an HTTP/1 server. `None` (the default) uses an auto-detecting
        /// server. Any other version makes `build` return an error.
        pub fn http_version(mut self, version: Option<Version>) -> Self {
            self.http_version = version;
            self
        }
    }

    /// Add an extra HTTP layer around the service that serves the content.
    ///
    /// The new layer is stored together with the existing ones as the tuple
    /// `(H, H2)`. All other settings are carried over unchanged.
    #[must_use]
    pub fn with_http_layer<H2>(self, layer: H2) -> FsServiceBuilder<(H, H2)> {
        FsServiceBuilder {
            concurrent_limit: self.concurrent_limit,
            body_limit: self.body_limit,
            timeout: self.timeout,
            forward: self.forward,

            #[cfg(any(feature = "rustls", feature = "boring"))]
            tls_server_config: self.tls_server_config,

            http_version: self.http_version,

            http_service_builder: (self.http_service_builder, layer),

            content_path: self.content_path,
            dir_serve_mode: self.dir_serve_mode,
        }
    }

    rama_utils::macros::generate_set_and_with! {
        /// Serve the given file or directory instead of the built-in page.
        ///
        /// The path is only checked when the service is built: it must then be
        /// an existing file or directory.
        pub fn content_path(mut self, path: impl Into<PathBuf>) -> Self {
            self.content_path = Some(path.into());
            self
        }
    }

    /// Optionally set the file or directory to serve.
    ///
    /// `None` clears any previously configured path, so the built-in page is
    /// served again.
    #[must_use]
    pub fn maybe_with_content_path<P: Into<PathBuf>>(mut self, path: Option<P>) -> Self {
        self.content_path = path.map(Into::into);
        self
    }

    /// Optionally set the file or directory to serve (by mutable reference).
    ///
    /// `None` clears any previously configured path, so the built-in page is
    /// served again.
    pub fn maybe_set_content_path<P: Into<PathBuf>>(&mut self, path: Option<P>) -> &mut Self {
        self.content_path = path.map(Into::into);
        self
    }

    rama_utils::macros::generate_set_and_with! {
        /// Set how a directory is served when `content_path` points to one.
        pub fn directory_serve_mode(mut self, mode: DirectoryServeMode) -> Self {
            self.dir_serve_mode = mode;
            self
        }
    }
}
224
225impl<H> FsServiceBuilder<H>
226where
227 H: Layer<ServeService, Service: Service<Request, Output = Response, Error: Into<BoxError>>>,
228{
229 pub fn build(
231 self,
232 executor: Executor,
233 ) -> Result<impl Service<TcpStream, Output = (), Error = Infallible>, BoxError> {
234 let tcp_forwarded_layer = match &self.forward {
235 Some(ForwardKind::HaProxy) => Some(HaProxyLayer::default()),
236 _ => None,
237 };
238
239 let http_service = Arc::new(self.build_http()?);
240
241 #[cfg(all(feature = "rustls", not(feature = "boring")))]
242 let tls_cfg = self.tls_server_config;
243
244 #[cfg(feature = "boring")]
245 let tls_cfg: Option<TlsAcceptorData> = match self.tls_server_config {
246 Some(cfg) => Some(cfg.try_into()?),
247 None => None,
248 };
249
250 let tcp_service_builder = (
251 ConsumeErrLayer::trace_as(tracing::Level::DEBUG),
252 LimitLayer::new(if self.concurrent_limit > 0 {
253 Either::A(ConcurrentPolicy::max(self.concurrent_limit))
254 } else {
255 Either::B(UnlimitedPolicy::new())
256 }),
257 if !self.timeout.is_zero() {
258 TimeoutLayer::new(self.timeout)
259 } else {
260 TimeoutLayer::never()
261 },
262 tcp_forwarded_layer,
263 BodyLimitLayer::request_only(self.body_limit),
264 #[cfg(any(feature = "rustls", feature = "boring"))]
265 tls_cfg.map(|cfg| {
266 #[cfg(feature = "boring")]
267 return TlsAcceptorLayer::new(cfg).with_store_client_hello(true);
268 #[cfg(all(feature = "rustls", not(feature = "boring")))]
269 TlsAcceptorLayer::new(cfg).with_store_client_hello(true)
270 }),
271 );
272
273 let http_transport_service = match self.http_version {
274 Some(Version::HTTP_2) => Either3::A(HttpServer::h2(executor).service(http_service)),
275 Some(Version::HTTP_11 | Version::HTTP_10 | Version::HTTP_09) => {
276 Either3::B(HttpServer::http1(executor).service(http_service))
277 }
278 Some(version) => {
279 return Err(BoxError::from("unsupported http version")
280 .context_debug_field("version", version));
281 }
282 None => Either3::C(HttpServer::auto(executor).service(http_service)),
283 };
284
285 Ok(tcp_service_builder.into_layer(http_transport_service))
286 }
287
288 pub fn build_http(
290 &self,
291 ) -> Result<impl Service<Request, Output: IntoResponse, Error = Infallible> + use<H>, BoxError>
292 {
293 let http_forwarded_layer = match &self.forward {
294 None | Some(ForwardKind::HaProxy) => None,
295 Some(ForwardKind::Forwarded) => Some(Either7::A(GetForwardedHeaderLayer::forwarded())),
296 Some(ForwardKind::XForwardedFor) => {
297 Some(Either7::B(GetForwardedHeaderLayer::x_forwarded_for()))
298 }
299 Some(ForwardKind::XClientIp) => {
300 Some(Either7::C(GetForwardedHeaderLayer::<XClientIp>::new()))
301 }
302 Some(ForwardKind::ClientIp) => {
303 Some(Either7::D(GetForwardedHeaderLayer::<ClientIp>::new()))
304 }
305 Some(ForwardKind::XRealIp) => {
306 Some(Either7::E(GetForwardedHeaderLayer::<XRealIp>::new()))
307 }
308 Some(ForwardKind::CFConnectingIp) => {
309 Some(Either7::F(GetForwardedHeaderLayer::<CFConnectingIp>::new()))
310 }
311 Some(ForwardKind::TrueClientIp) => {
312 Some(Either7::G(GetForwardedHeaderLayer::<TrueClientIp>::new()))
313 }
314 };
315
316 let serve_service = match &self.content_path {
317 None => Either3::A(StaticService::new(Html(include_str!(
318 "../../../docs/index.html"
319 )))),
320 Some(path) if path.is_file() => Either3::B(ServeFile::new(path.clone())),
321 Some(path) if path.is_dir() => {
322 Either3::C(ServeDir::new(path).with_directory_serve_mode(self.dir_serve_mode))
323 }
324 Some(path) => {
325 return Err(BoxError::from("invalid path: no such file or directory")
326 .with_context_debug_field("path", || path.clone()));
327 }
328 };
329
330 let http_service = (
331 TraceLayer::new_for_http(),
332 SetResponseHeaderLayer::<XClacksOverhead>::if_not_present_default_typed(),
333 AddRequiredResponseHeadersLayer::default(),
334 UserAgentClassifierLayer::new(),
335 ConsumeErrLayer::default(),
336 http_forwarded_layer,
337 )
338 .into_layer(self.http_service_builder.layer(serve_service));
339
340 Ok(http_service)
341 }
342}
343
/// The built-in HTML page served when no `content_path` is configured.
type ServeStaticHtml = StaticService<Html<&'static str>>;
/// The innermost content service: the built-in page, a single file, or a directory.
type ServeService = Either3<ServeStaticHtml, ServeFile, ServeDir>;