rama/cli/
args.rs

//! Build HTTP requests from command line arguments.
2
3use crate::{
4    error::{ErrorContext, OpaqueError},
5    http::{
6        Body, Method, Request, Uri,
7        header::{ACCEPT, CONTENT_LENGTH, CONTENT_TYPE, Entry, HeaderValue},
8    },
9};
10use ahash::{HashMap, HashMapExt as _};
11use rama_core::extensions::ExtensionsMut;
12use rama_http::proto::h1::{Http1HeaderName, headers::original::OriginalHttp1Headers};
13use rama_utils::macros::match_ignore_ascii_case_str;
14use serde_json::Value;
15
16#[derive(Debug, Clone)]
17/// A builder to create a request from command line arguments.
18pub struct RequestArgsBuilder {
19    state: BuilderState,
20}
21
22impl Default for RequestArgsBuilder {
23    fn default() -> Self {
24        Self::new()
25    }
26}
27
28impl RequestArgsBuilder {
29    /// Create a new [`RequestArgsBuilder`], which auto-detects the content type.
30    #[must_use]
31    pub const fn new() -> Self {
32        Self {
33            state: BuilderState::MethodOrUrl { content_type: None },
34        }
35    }
36
37    /// Create a new [`RequestArgsBuilder`], which expects JSON data.
38    #[must_use]
39    pub const fn new_json() -> Self {
40        Self {
41            state: BuilderState::MethodOrUrl {
42                content_type: Some(ContentType::Json),
43            },
44        }
45    }
46
47    /// Create a new [`RequestArgsBuilder`], which expects Form data.
48    #[must_use]
49    pub const fn new_form() -> Self {
50        Self {
51            state: BuilderState::MethodOrUrl {
52                content_type: Some(ContentType::Form),
53            },
54        }
55    }
56
57    /// parse a command line argument, the possible meaning
58    /// depend on the current state of the builder, driven by the position of the argument.
59    pub fn parse_arg(&mut self, arg: String) {
60        let new_state = match &mut self.state {
61            BuilderState::MethodOrUrl { content_type } => {
62                if let Some(method) = parse_arg_as_method(&arg) {
63                    Some(BuilderState::Url {
64                        content_type: *content_type,
65                        method: Some(method),
66                    })
67                } else {
68                    Some(BuilderState::Data {
69                        content_type: *content_type,
70                        method: None,
71                        url: arg,
72                        query: HashMap::new(),
73                        headers: Vec::new(),
74                        body: HashMap::new(),
75                    })
76                }
77            }
78            BuilderState::Url {
79                content_type,
80                method,
81            } => Some(BuilderState::Data {
82                content_type: *content_type,
83                method: method.clone(),
84                url: arg,
85                query: HashMap::new(),
86                headers: Vec::new(),
87                body: HashMap::new(),
88            }),
89            BuilderState::Data {
90                query,
91                headers,
92                body,
93                ..
94            } => match parse_arg_as_data(&arg, query, headers, body) {
95                Ok(_) => None,
96                Err(msg) => Some(BuilderState::Error {
97                    message: msg,
98                    ignored: vec![],
99                }),
100            },
101            BuilderState::Error { ignored, .. } => {
102                ignored.push(arg);
103                None
104            }
105        };
106        if let Some(new_state) = new_state {
107            self.state = new_state;
108        }
109    }
110
111    /// Build the request from the parsed arguments.
112    pub fn build(self) -> Result<Request, OpaqueError> {
113        match self.state {
114            BuilderState::MethodOrUrl { .. } | BuilderState::Url { .. } => {
115                Err(OpaqueError::from_display("no url defined"))
116            }
117            BuilderState::Error { message, ignored } => {
118                Err(OpaqueError::from_display(if ignored.is_empty() {
119                    format!("request arg parser failed: {message}")
120                } else {
121                    format!("request arg parser failed: {message} (ignored: {ignored:?})")
122                }))
123            }
124            BuilderState::Data {
125                content_type,
126                method,
127                url,
128                query,
129                headers,
130                body,
131            } => {
132                let mut req = Request::builder();
133
134                let url = expand_url(url);
135
136                let uri: Uri = url
137                    .parse()
138                    .map_err(OpaqueError::from_std)
139                    .context("parse base uri")?;
140
141                if query.is_empty() {
142                    req = req.uri(url);
143                } else {
144                    let mut uri_parts = uri.into_parts();
145                    uri_parts.path_and_query = Some(if let Some(pq) = uri_parts.path_and_query {
146                        if let Some(q) = pq.query() {
147                            let mut existing_query: HashMap<String, Vec<String>> =
148                                serde_html_form::from_str(q)
149                                    .map_err(OpaqueError::from_std)
150                                    .context("parse existing query")?;
151                            for (k, v) in query {
152                                existing_query.entry(k).or_default().extend(v);
153                            }
154                            let query = serde_html_form::to_string(&existing_query)
155                                .map_err(OpaqueError::from_std)
156                                .context("serialize extended query")?;
157                            format!("{}?{}", pq.path(), query)
158                                .parse()
159                                .map_err(OpaqueError::from_std)
160                                .context("create new path+query from extended query")?
161                        } else {
162                            let query = serde_html_form::to_string(&query)
163                                .map_err(OpaqueError::from_std)
164                                .context("serialize new and only query params")?;
165                            format!("{}?{}", pq.path(), query)
166                                .parse()
167                                .map_err(OpaqueError::from_std)
168                                .context("create path+query from given query params")?
169                        }
170                    } else {
171                        let query =
172                            serde_html_form::to_string(&query).map_err(OpaqueError::from_std)?;
173                        format!("/?{query}")
174                            .parse()
175                            .map_err(OpaqueError::from_std)?
176                    });
177                    req = req.uri(Uri::from_parts(uri_parts).map_err(OpaqueError::from_std)?);
178                }
179
180                match method {
181                    Some(method) => req = req.method(method),
182                    None => {
183                        if body.is_empty() {
184                            req = req.method(Method::GET);
185                        } else {
186                            req = req.method(Method::POST);
187                        }
188                    }
189                }
190
191                let mut header_order = OriginalHttp1Headers::with_capacity(headers.len());
192                for (name, value) in headers {
193                    let header_name = Http1HeaderName::try_copy_from_str(name.as_str())
194                        .context("convert string into Http1HeaderName")?;
195                    req = req.header(header_name.clone(), value);
196                    header_order.push(header_name);
197                }
198
199                if body.is_empty() {
200                    let mut req = req
201                        .body(Body::empty())
202                        .map_err(OpaqueError::from_std)
203                        .context("create request without body")?;
204
205                    req.extensions_mut().insert(header_order);
206
207                    return Ok(req);
208                }
209
210                let ct = content_type.unwrap_or_else(|| {
211                    match req
212                        .headers_ref()
213                        .and_then(|h| h.get(CONTENT_TYPE))
214                        .and_then(|h| h.to_str().ok())
215                    {
216                        Some(cv) if cv.contains("application/x-www-form-urlencoded") => {
217                            ContentType::Form
218                        }
219                        _ => ContentType::Json,
220                    }
221                });
222
223                let req = if req.headers_ref().is_none() {
224                    let req = req.header(CONTENT_TYPE, ct.header_value());
225                    header_order.push(CONTENT_TYPE.into());
226                    if ct == ContentType::Json {
227                        header_order.push(ACCEPT.into());
228                        req.header(ACCEPT, ct.header_value())
229                    } else {
230                        req
231                    }
232                } else {
233                    let headers = req.headers_mut().unwrap();
234
235                    if let Entry::Vacant(entry) = headers.entry(CONTENT_TYPE) {
236                        header_order.push(CONTENT_TYPE.into());
237                        entry.insert(ct.header_value());
238                    }
239
240                    if ct == ContentType::Json
241                        && let Entry::Vacant(entry) = headers.entry(ACCEPT)
242                    {
243                        header_order.push(ACCEPT.into());
244                        entry.insert(ct.header_value());
245                    }
246
247                    req
248                };
249
250                let mut req = match ct {
251                    ContentType::Json => {
252                        let body = serde_json::to_string(&body)
253                            .map_err(OpaqueError::from_std)
254                            .context("serialize form body")?;
255                        header_order.push(CONTENT_LENGTH.into());
256                        req.header(CONTENT_LENGTH, body.len().to_string())
257                            .body(Body::from(body))
258                    }
259                    ContentType::Form => {
260                        let body = serde_html_form::to_string(&body)
261                            .map_err(OpaqueError::from_std)
262                            .context("serialize json body")?;
263                        header_order.push(CONTENT_LENGTH.into());
264                        req.header(CONTENT_LENGTH, body.len().to_string())
265                            .body(Body::from(body))
266                    }
267                }
268                .map_err(OpaqueError::from_std)
269                .context("create request with body")?;
270
271                req.extensions_mut().insert(header_order);
272
273                Ok(req)
274            }
275        }
276    }
277}
278
279fn parse_arg_as_data(
280    arg: &str,
281    query: &mut HashMap<String, Vec<String>>,
282    headers: &mut Vec<(String, String)>,
283    body: &mut HashMap<String, Value>,
284) -> Result<(), String> {
285    let mut state = DataParseArgState::None;
286    for (i, c) in arg.chars().enumerate() {
287        match state {
288            DataParseArgState::None => match c {
289                '\\' => state = DataParseArgState::Escaped,
290                '=' => state = DataParseArgState::Equal,
291                ':' => state = DataParseArgState::Colon,
292                _ => (),
293            },
294            DataParseArgState::Escaped => {
295                // \*
296                state = DataParseArgState::None;
297            }
298            DataParseArgState::Equal => {
299                let (name, value) = arg.split_at(i - 1);
300                if c == '=' {
301                    // ==
302                    let value = &value[2..];
303                    query
304                        .entry(name.to_owned())
305                        .or_default()
306                        .push(value.to_owned());
307                } else {
308                    // =
309                    let value = &value[1..];
310                    body.insert(name.to_owned(), Value::String(value.to_owned()));
311                }
312                break;
313            }
314            DataParseArgState::Colon => {
315                let (name, value) = arg.split_at(i - 1);
316                if c == '=' {
317                    // :=
318                    let value = &value[2..];
319                    let value: Value =
320                        serde_json::from_str(value).map_err(|err| err.to_string())?;
321                    body.insert(name.to_owned(), value);
322                } else {
323                    // :
324                    let value = &value[1..];
325                    headers.push((name.to_owned(), value.to_owned()));
326                }
327                break;
328            }
329        }
330    }
331    Ok(())
332}
333
334fn parse_arg_as_method(arg: impl AsRef<str>) -> Option<Method> {
335    match_ignore_ascii_case_str! {
336        match (arg.as_ref()) {
337            "GET" => Some(Method::GET),
338            "POST" => Some(Method::POST),
339            "PUT" => Some(Method::PUT),
340            "DELETE" => Some(Method::DELETE),
341            "PATCH" => Some(Method::PATCH),
342            "HEAD" => Some(Method::HEAD),
343            "OPTIONS" => Some(Method::OPTIONS),
344            "CONNECT" => Some(Method::CONNECT),
345            "TRACE" => Some(Method::TRACE),
346            _ => None,
347
348        }
349    }
350}
351
/// Expand a (shorthand) URL string to a full URL,
/// e.g. `example.com` -> `http://example.com`.
///
/// - an empty string or a lone `:` becomes `http://localhost`;
/// - `:` followed by a digit is a port on localhost,
///   e.g. `:8080/foo` -> `http://localhost:8080/foo`;
/// - `:` followed by anything else is appended to localhost as-is,
///   e.g. `:/foo` -> `http://localhost/foo`;
/// - a string that does not start with `<scheme>://` gets `http://` prepended;
/// - everything else is returned unchanged.
fn expand_url(url: String) -> String {
    /// Whether `url` starts with a syntactically valid scheme followed by `://`.
    ///
    /// Only looking for `://` anywhere is not enough: it can also occur later on,
    /// e.g. in the query of `example.com/r?to=https://a.b`.
    fn has_scheme(url: &str) -> bool {
        url.split_once("://").is_some_and(|(scheme, _)| {
            let mut chars = scheme.chars();
            chars.next().is_some_and(|c| c.is_ascii_alphabetic())
                && chars.all(|c| c.is_ascii_alphanumeric() || matches!(c, '+' | '-' | '.'))
        })
    }

    if url.is_empty() {
        return "http://localhost".to_owned();
    }

    if let Some(stripped_url) = url.strip_prefix(':') {
        return if stripped_url.starts_with(|c: char| c.is_ascii_digit()) {
            // keep the colon: what follows is a port
            format!("http://localhost{url}")
        } else {
            // covers the lone `:` as well, as nothing is appended in that case
            format!("http://localhost{stripped_url}")
        };
    }

    if has_scheme(&url) {
        url
    } else {
        format!("http://{url}")
    }
}
376
/// State of the scanner used by `parse_arg_as_data`.
enum DataParseArgState {
    /// No separator has been seen yet.
    None,
    /// The previous character was a backslash; the current character is skipped.
    Escaped,
    /// The previous character was an unescaped `=`.
    Equal,
    /// The previous character was an unescaped `:`.
    Colon,
}
383
/// Encoding used for the body fields of a request.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum ContentType {
    /// Body fields are serialized as a JSON object.
    Json,
    /// Body fields are serialized as `application/x-www-form-urlencoded`.
    Form,
}
389
390impl ContentType {
391    fn header_value(self) -> HeaderValue {
392        HeaderValue::from_static(match self {
393            Self::Json => "application/json",
394            Self::Form => "application/x-www-form-urlencoded",
395        })
396    }
397}
398
399#[derive(Debug, Clone)]
400enum BuilderState {
401    MethodOrUrl {
402        content_type: Option<ContentType>,
403    },
404    Url {
405        content_type: Option<ContentType>,
406        method: Option<Method>,
407    },
408    Data {
409        content_type: Option<ContentType>,
410        method: Option<Method>,
411        url: String,
412        query: HashMap<String, Vec<String>>,
413        headers: Vec<(String, String)>,
414        body: HashMap<String, Value>,
415    },
416    Error {
417        message: String,
418        ignored: Vec<String>,
419    },
420}
421
#[cfg(test)]
mod tests {
    use super::*;
    use crate::http::io::write_http_request;

    /// Feed `args` to `builder`, build the request and return what
    /// `write_http_request` writes for it.
    async fn render_request(mut builder: RequestArgsBuilder, args: Vec<&str>) -> String {
        for arg in args {
            builder.parse_arg(arg.to_owned());
        }
        let request = builder.build().unwrap();
        let mut w = Vec::new();
        let _ = write_http_request(&mut w, request, true, true)
            .await
            .unwrap();
        String::from_utf8(w).unwrap()
    }

    #[test]
    fn test_parse_arg_as_method() {
        for (arg, expected) in [
            ("GET", Some(Method::GET)),
            ("POST", Some(Method::POST)),
            ("PUT", Some(Method::PUT)),
            ("DELETE", Some(Method::DELETE)),
            ("PATCH", Some(Method::PATCH)),
            ("HEAD", Some(Method::HEAD)),
            ("OPTIONS", Some(Method::OPTIONS)),
            ("CONNECT", Some(Method::CONNECT)),
            ("TRACE", Some(Method::TRACE)),
            ("get", Some(Method::GET)),
            ("post", Some(Method::POST)),
            ("put", Some(Method::PUT)),
            ("delete", Some(Method::DELETE)),
            ("patch", Some(Method::PATCH)),
            ("head", Some(Method::HEAD)),
            ("options", Some(Method::OPTIONS)),
            ("connect", Some(Method::CONNECT)),
            ("trace", Some(Method::TRACE)),
            ("invalid", None),
            ("", None),
        ] {
            assert_eq!(parse_arg_as_method(arg), expected);
        }
    }

    #[test]
    fn test_expand_url() {
        for (url, expected) in [
            ("example.com", "http://example.com"),
            ("http://example.com", "http://example.com"),
            ("https://example.com", "https://example.com"),
            ("example.com:8080", "http://example.com:8080"),
            (":8080/foo", "http://localhost:8080/foo"),
            (":8080", "http://localhost:8080"),
            ("", "http://localhost"),
        ] {
            assert_eq!(expand_url(url.to_owned()), expected);
        }
    }

    #[tokio::test]
    async fn test_request_args_builder_happy() {
        for (args, expected_request_str) in [
            (vec![":8080"], "GET / HTTP/1.1\r\n\r\n"),
            (vec!["HeAD", ":8000/foo"], "HEAD /foo HTTP/1.1\r\n\r\n"),
            (
                vec!["example.com/bar", "FOO:bar", "AnSweR:42"],
                "GET /bar HTTP/1.1\r\nFOO: bar\r\nAnSweR: 42\r\n\r\n",
            ),
            (
                vec![
                    "example.com/foo",
                    "c=d",
                    "Content-Type:application/x-www-form-urlencoded",
                ],
                "POST /foo HTTP/1.1\r\nContent-Type: application/x-www-form-urlencoded\r\ncontent-length: 3\r\n\r\nc=d",
            ),
            (
                vec!["example.com/foo", "a=b", "Content-Type:application/json"],
                "POST /foo HTTP/1.1\r\nContent-Type: application/json\r\naccept: application/json\r\ncontent-length: 9\r\n\r\n{\"a\":\"b\"}",
            ),
            (
                vec!["example.com/foo", "a=b"],
                "POST /foo HTTP/1.1\r\ncontent-type: application/json\r\naccept: application/json\r\ncontent-length: 9\r\n\r\n{\"a\":\"b\"}",
            ),
            (
                vec!["example.com/foo", "x-a:1", "a=b"],
                "POST /foo HTTP/1.1\r\nx-a: 1\r\ncontent-type: application/json\r\naccept: application/json\r\ncontent-length: 9\r\n\r\n{\"a\":\"b\"}",
            ),
            (
                vec!["put", "example.com/foo?a=2", "x-a:1", "a:=42", "a==3"],
                "PUT /foo?a=2&a=3 HTTP/1.1\r\nx-a: 1\r\ncontent-type: application/json\r\naccept: application/json\r\ncontent-length: 8\r\n\r\n{\"a\":42}",
            ),
            (
                vec![":3000", "Cookie:foo=bar"],
                "GET / HTTP/1.1\r\nCookie: foo=bar\r\n\r\n",
            ),
            (
                vec![":/foo", "search==rama"],
                "GET /foo?search=rama HTTP/1.1\r\n\r\n",
            ),
            (
                vec!["example.com", "description='CLI HTTP client'"],
                "POST / HTTP/1.1\r\ncontent-type: application/json\r\naccept: application/json\r\ncontent-length: 35\r\n\r\n{\"description\":\"'CLI HTTP client'\"}",
            ),
            (
                vec!["example.com", "x-cfg:a=1&foo=bar&foo=baz"],
                "GET / HTTP/1.1\r\nx-cfg: a=1&foo=bar&foo=baz\r\n\r\n",
            ),
        ] {
            assert_eq!(
                render_request(RequestArgsBuilder::new(), args).await,
                expected_request_str
            );
        }
    }

    #[tokio::test]
    async fn test_request_args_builder_form_happy() {
        for (args, expected_request_str) in [(
            vec!["example.com/foo", "c=d"],
            "POST /foo HTTP/1.1\r\ncontent-type: application/x-www-form-urlencoded\r\ncontent-length: 3\r\n\r\nc=d",
        )] {
            assert_eq!(
                render_request(RequestArgsBuilder::new_form(), args).await,
                expected_request_str
            );
        }
    }

    #[tokio::test]
    async fn test_request_args_builder_json_happy() {
        for (args, expected_request_str) in [(
            vec!["example.com/foo", "a=b"],
            "POST /foo HTTP/1.1\r\ncontent-type: application/json\r\naccept: application/json\r\ncontent-length: 9\r\n\r\n{\"a\":\"b\"}",
        )] {
            // exercise the explicit JSON constructor, not the auto-detecting one
            assert_eq!(
                render_request(RequestArgsBuilder::new_json(), args).await,
                expected_request_str
            );
        }
    }

    #[test]
    fn test_request_args_builder_error() {
        for test in [
            vec![],
            vec!["invalid url"],
            vec!["get"],
            vec!["get", "invalid url"],
        ] {
            let mut builder = RequestArgsBuilder::new();
            for arg in test {
                builder.parse_arg(arg.to_owned());
            }
            let request = builder.build();
            assert!(request.is_err());
        }
    }
}
587}