rama/cli/
args.rs

1//! build requests from command line arguments
2
3use crate::{
4    error::{ErrorContext, OpaqueError},
5    http::{
6        Body, Method, Request, Uri,
7        header::{ACCEPT, CONTENT_LENGTH, CONTENT_TYPE, Entry, HeaderValue},
8    },
9};
10use rama_http::proto::h1::{Http1HeaderName, headers::original::OriginalHttp1Headers};
11use rama_utils::macros::match_ignore_ascii_case_str;
12use serde_json::Value;
13use std::collections::HashMap;
14
#[derive(Debug, Clone)]
/// A builder to create a request from command line arguments.
///
/// Arguments are fed one at a time, in order, through
/// [`RequestArgsBuilder::parse_arg`]; the request is produced at the end by
/// [`RequestArgsBuilder::build`].
pub struct RequestArgsBuilder {
    /// Positional parser state: tracks which kind of argument is expected
    /// next and accumulates everything parsed so far.
    state: BuilderState,
}
20
impl Default for RequestArgsBuilder {
    /// Same as [`RequestArgsBuilder::new`].
    fn default() -> Self {
        Self::new()
    }
}
26
27impl RequestArgsBuilder {
28    /// Create a new [`RequestArgsBuilder`], which auto-detects the content type.
29    #[must_use]
30    pub const fn new() -> Self {
31        Self {
32            state: BuilderState::MethodOrUrl { content_type: None },
33        }
34    }
35
36    /// Create a new [`RequestArgsBuilder`], which expects JSON data.
37    #[must_use]
38    pub const fn new_json() -> Self {
39        Self {
40            state: BuilderState::MethodOrUrl {
41                content_type: Some(ContentType::Json),
42            },
43        }
44    }
45
46    /// Create a new [`RequestArgsBuilder`], which expects Form data.
47    #[must_use]
48    pub const fn new_form() -> Self {
49        Self {
50            state: BuilderState::MethodOrUrl {
51                content_type: Some(ContentType::Form),
52            },
53        }
54    }
55
56    /// parse a command line argument, the possible meaning
57    /// depend on the current state of the builder, driven by the position of the argument.
58    pub fn parse_arg(&mut self, arg: String) {
59        let new_state = match &mut self.state {
60            BuilderState::MethodOrUrl { content_type } => {
61                if let Some(method) = parse_arg_as_method(&arg) {
62                    Some(BuilderState::Url {
63                        content_type: *content_type,
64                        method: Some(method),
65                    })
66                } else {
67                    Some(BuilderState::Data {
68                        content_type: *content_type,
69                        method: None,
70                        url: arg,
71                        query: HashMap::new(),
72                        headers: Vec::new(),
73                        body: HashMap::new(),
74                    })
75                }
76            }
77            BuilderState::Url {
78                content_type,
79                method,
80            } => Some(BuilderState::Data {
81                content_type: *content_type,
82                method: method.clone(),
83                url: arg,
84                query: HashMap::new(),
85                headers: Vec::new(),
86                body: HashMap::new(),
87            }),
88            BuilderState::Data {
89                query,
90                headers,
91                body,
92                ..
93            } => match parse_arg_as_data(&arg, query, headers, body) {
94                Ok(_) => None,
95                Err(msg) => Some(BuilderState::Error {
96                    message: msg,
97                    ignored: vec![],
98                }),
99            },
100            BuilderState::Error { ignored, .. } => {
101                ignored.push(arg);
102                None
103            }
104        };
105        if let Some(new_state) = new_state {
106            self.state = new_state;
107        }
108    }
109
110    /// Build the request from the parsed arguments.
111    pub fn build(self) -> Result<Request, OpaqueError> {
112        match self.state {
113            BuilderState::MethodOrUrl { .. } | BuilderState::Url { .. } => {
114                Err(OpaqueError::from_display("no url defined"))
115            }
116            BuilderState::Error { message, ignored } => {
117                Err(OpaqueError::from_display(if ignored.is_empty() {
118                    format!("request arg parser failed: {message}")
119                } else {
120                    format!("request arg parser failed: {message} (ignored: {ignored:?})")
121                }))
122            }
123            BuilderState::Data {
124                content_type,
125                method,
126                url,
127                query,
128                headers,
129                body,
130            } => {
131                let mut req = Request::builder();
132
133                let url = expand_url(url);
134
135                let uri: Uri = url
136                    .parse()
137                    .map_err(OpaqueError::from_std)
138                    .context("parse base uri")?;
139
140                if query.is_empty() {
141                    req = req.uri(url);
142                } else {
143                    let mut uri_parts = uri.into_parts();
144                    uri_parts.path_and_query = Some(if let Some(pq) = uri_parts.path_and_query {
145                        if let Some(q) = pq.query() {
146                            let mut existing_query: HashMap<String, Vec<String>> =
147                                serde_html_form::from_str(q)
148                                    .map_err(OpaqueError::from_std)
149                                    .context("parse existing query")?;
150                            for (k, v) in query {
151                                existing_query.entry(k).or_default().extend(v);
152                            }
153                            let query = serde_html_form::to_string(&existing_query)
154                                .map_err(OpaqueError::from_std)
155                                .context("serialize extended query")?;
156                            format!("{}?{}", pq.path(), query)
157                                .parse()
158                                .map_err(OpaqueError::from_std)
159                                .context("create new path+query from extended query")?
160                        } else {
161                            let query = serde_html_form::to_string(&query)
162                                .map_err(OpaqueError::from_std)
163                                .context("serialize new and only query params")?;
164                            format!("{}?{}", pq.path(), query)
165                                .parse()
166                                .map_err(OpaqueError::from_std)
167                                .context("create path+query from given query params")?
168                        }
169                    } else {
170                        let query =
171                            serde_html_form::to_string(&query).map_err(OpaqueError::from_std)?;
172                        format!("/?{query}")
173                            .parse()
174                            .map_err(OpaqueError::from_std)?
175                    });
176                    req = req.uri(Uri::from_parts(uri_parts).map_err(OpaqueError::from_std)?);
177                }
178
179                match method {
180                    Some(method) => req = req.method(method),
181                    None => {
182                        if body.is_empty() {
183                            req = req.method(Method::GET);
184                        } else {
185                            req = req.method(Method::POST);
186                        }
187                    }
188                }
189
190                let mut header_order = OriginalHttp1Headers::with_capacity(headers.len());
191                for (name, value) in headers {
192                    let header_name = Http1HeaderName::try_copy_from_str(name.as_str())
193                        .context("convert string into Http1HeaderName")?;
194                    req = req.header(header_name.clone(), value);
195                    header_order.push(header_name);
196                }
197
198                if body.is_empty() {
199                    let mut req = req
200                        .body(Body::empty())
201                        .map_err(OpaqueError::from_std)
202                        .context("create request without body")?;
203
204                    req.extensions_mut().insert(header_order);
205
206                    return Ok(req);
207                }
208
209                let ct = content_type.unwrap_or_else(|| {
210                    match req
211                        .headers_ref()
212                        .and_then(|h| h.get(CONTENT_TYPE))
213                        .and_then(|h| h.to_str().ok())
214                    {
215                        Some(cv) if cv.contains("application/x-www-form-urlencoded") => {
216                            ContentType::Form
217                        }
218                        _ => ContentType::Json,
219                    }
220                });
221
222                let req = if req.headers_ref().is_none() {
223                    let req = req.header(CONTENT_TYPE, ct.header_value());
224                    header_order.push(CONTENT_TYPE.into());
225                    if ct == ContentType::Json {
226                        header_order.push(ACCEPT.into());
227                        req.header(ACCEPT, ct.header_value())
228                    } else {
229                        req
230                    }
231                } else {
232                    let headers = req.headers_mut().unwrap();
233
234                    if let Entry::Vacant(entry) = headers.entry(CONTENT_TYPE) {
235                        header_order.push(CONTENT_TYPE.into());
236                        entry.insert(ct.header_value());
237                    }
238
239                    if ct == ContentType::Json
240                        && let Entry::Vacant(entry) = headers.entry(ACCEPT)
241                    {
242                        header_order.push(ACCEPT.into());
243                        entry.insert(ct.header_value());
244                    }
245
246                    req
247                };
248
249                let mut req = match ct {
250                    ContentType::Json => {
251                        let body = serde_json::to_string(&body)
252                            .map_err(OpaqueError::from_std)
253                            .context("serialize form body")?;
254                        header_order.push(CONTENT_LENGTH.into());
255                        req.header(CONTENT_LENGTH, body.len().to_string())
256                            .body(Body::from(body))
257                    }
258                    ContentType::Form => {
259                        let body = serde_html_form::to_string(&body)
260                            .map_err(OpaqueError::from_std)
261                            .context("serialize json body")?;
262                        header_order.push(CONTENT_LENGTH.into());
263                        req.header(CONTENT_LENGTH, body.len().to_string())
264                            .body(Body::from(body))
265                    }
266                }
267                .map_err(OpaqueError::from_std)
268                .context("create request with body")?;
269
270                req.extensions_mut().insert(header_order);
271
272                Ok(req)
273            }
274        }
275    }
276}
277
278fn parse_arg_as_data(
279    arg: &str,
280    query: &mut HashMap<String, Vec<String>>,
281    headers: &mut Vec<(String, String)>,
282    body: &mut HashMap<String, Value>,
283) -> Result<(), String> {
284    let mut state = DataParseArgState::None;
285    for (i, c) in arg.chars().enumerate() {
286        match state {
287            DataParseArgState::None => match c {
288                '\\' => state = DataParseArgState::Escaped,
289                '=' => state = DataParseArgState::Equal,
290                ':' => state = DataParseArgState::Colon,
291                _ => (),
292            },
293            DataParseArgState::Escaped => {
294                // \*
295                state = DataParseArgState::None;
296            }
297            DataParseArgState::Equal => {
298                let (name, value) = arg.split_at(i - 1);
299                if c == '=' {
300                    // ==
301                    let value = &value[2..];
302                    query
303                        .entry(name.to_owned())
304                        .or_default()
305                        .push(value.to_owned());
306                } else {
307                    // =
308                    let value = &value[1..];
309                    body.insert(name.to_owned(), Value::String(value.to_owned()));
310                }
311                break;
312            }
313            DataParseArgState::Colon => {
314                let (name, value) = arg.split_at(i - 1);
315                if c == '=' {
316                    // :=
317                    let value = &value[2..];
318                    let value: Value =
319                        serde_json::from_str(value).map_err(|err| err.to_string())?;
320                    body.insert(name.to_owned(), value);
321                } else {
322                    // :
323                    let value = &value[1..];
324                    headers.push((name.to_owned(), value.to_owned()));
325                }
326                break;
327            }
328        }
329    }
330    Ok(())
331}
332
/// Try to interpret `arg` as one of the HTTP methods listed below.
///
/// The comparison ignores ASCII case, so `get` and `GeT` both yield
/// [`Method::GET`]. Returns `None` for anything that is not one of the
/// listed methods, in which case the caller treats the argument as the url.
fn parse_arg_as_method(arg: impl AsRef<str>) -> Option<Method> {
    match_ignore_ascii_case_str! {
        match (arg.as_ref()) {
            "GET" => Some(Method::GET),
            "POST" => Some(Method::POST),
            "PUT" => Some(Method::PUT),
            "DELETE" => Some(Method::DELETE),
            "PATCH" => Some(Method::PATCH),
            "HEAD" => Some(Method::HEAD),
            "OPTIONS" => Some(Method::OPTIONS),
            "CONNECT" => Some(Method::CONNECT),
            "TRACE" => Some(Method::TRACE),
            _ => None,

        }
    }
}
350
/// Expand a URL string to a full URL,
/// e.g. `example.com` -> `http://example.com`
///
/// The rules, in order:
///
/// - an empty string or a lone `:` becomes `http://localhost`;
/// - `:<digit>...` is a port (and optional path) on localhost,
///   e.g. `:8080/foo` -> `http://localhost:8080/foo`;
/// - `:<anything else>` is appended to `http://localhost` without the colon,
///   e.g. `:/foo` -> `http://localhost/foo`;
/// - a url that already starts with `<scheme>://` is returned unchanged;
/// - everything else gets `http://` prepended.
fn expand_url(url: String) -> String {
    if url.is_empty() {
        return "http://localhost".to_owned();
    }

    if let Some(stripped_url) = url.strip_prefix(':') {
        return if stripped_url.is_empty() {
            "http://localhost".to_owned()
        } else if stripped_url.starts_with(|c: char| c.is_ascii_digit()) {
            // Keep the colon: it separates the host from the port.
            format!("http://localhost{url}")
        } else {
            format!("http://localhost{stripped_url}")
        };
    }

    if has_url_scheme(&url) {
        url
    } else {
        format!("http://{url}")
    }
}

/// Report whether `url` starts with a syntactically valid scheme followed by `://`.
///
/// Only the text in front of the first `://` is considered, so a `://` that
/// shows up later on (e.g. inside a query value such as
/// `example.com/?next=https://foo`) is not mistaken for the scheme of the url.
fn has_url_scheme(url: &str) -> bool {
    url.split_once("://").is_some_and(|(scheme, _)| {
        let mut chars = scheme.chars();
        // scheme = ALPHA *( ALPHA / DIGIT / "+" / "-" / "." )
        chars.next().is_some_and(|c| c.is_ascii_alphabetic())
            && chars.all(|c| c.is_ascii_alphanumeric() || matches!(c, '+' | '-' | '.'))
    })
}
375
/// Scanner state used while looking for the separator of a data argument.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum DataParseArgState {
    /// No separator seen yet.
    None,
    /// The previous character was a backslash: the current one is skipped.
    Escaped,
    /// The previous character was an unescaped `=`.
    Equal,
    /// The previous character was an unescaped `:`.
    Colon,
}
382
/// The body encodings supported by the request builder.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum ContentType {
    /// The body is serialized as JSON (`application/json`).
    Json,
    /// The body is serialized as a form (`application/x-www-form-urlencoded`).
    Form,
}
388
389impl ContentType {
390    fn header_value(self) -> HeaderValue {
391        HeaderValue::from_static(match self {
392            Self::Json => "application/json",
393            Self::Form => "application/x-www-form-urlencoded",
394        })
395    }
396}
397
/// Positional state of the [`RequestArgsBuilder`] argument parser.
#[derive(Debug, Clone)]
enum BuilderState {
    /// Nothing parsed yet: the next argument is either a method or the url.
    MethodOrUrl {
        /// Content type requested at construction time, if any.
        content_type: Option<ContentType>,
    },
    /// A method was parsed: the next argument is the url.
    Url {
        /// Content type requested at construction time, if any.
        content_type: Option<ContentType>,
        /// The method parsed from the first argument.
        method: Option<Method>,
    },
    /// The url was parsed: every further argument is a data item.
    Data {
        /// Content type requested at construction time, if any.
        content_type: Option<ContentType>,
        /// The explicitly given method, if any.
        method: Option<Method>,
        /// The url as given on the command line (not yet expanded).
        url: String,
        /// Query parameters collected from `name==value` arguments.
        query: HashMap<String, Vec<String>>,
        /// Headers collected from `name:value` arguments, in argument order.
        headers: Vec<(String, String)>,
        /// Body fields collected from `name=value` and `name:=json` arguments.
        body: HashMap<String, Value>,
    },
    /// A data item failed to parse; reported when the request is built.
    Error {
        /// Description of the parse failure.
        message: String,
        /// Arguments received after the failure, which were not parsed.
        ignored: Vec<String>,
    },
}
420
#[cfg(test)]
mod tests {
    use super::*;
    use crate::http::io::write_http_request;

    /// Feed `args` to `builder`, build the request and return it in the
    /// textual form produced by `write_http_request`.
    async fn build_request_str(mut builder: RequestArgsBuilder, args: Vec<&str>) -> String {
        for arg in args {
            builder.parse_arg(arg.to_owned());
        }
        let request = builder.build().unwrap();
        let mut w = Vec::new();
        let _ = write_http_request(&mut w, request, true, true)
            .await
            .unwrap();
        String::from_utf8(w).unwrap()
    }

    #[test]
    fn test_parse_arg_as_method() {
        for (arg, expected) in [
            ("GET", Some(Method::GET)),
            ("POST", Some(Method::POST)),
            ("PUT", Some(Method::PUT)),
            ("DELETE", Some(Method::DELETE)),
            ("PATCH", Some(Method::PATCH)),
            ("HEAD", Some(Method::HEAD)),
            ("OPTIONS", Some(Method::OPTIONS)),
            ("CONNECT", Some(Method::CONNECT)),
            ("TRACE", Some(Method::TRACE)),
            ("get", Some(Method::GET)),
            ("post", Some(Method::POST)),
            ("put", Some(Method::PUT)),
            ("delete", Some(Method::DELETE)),
            ("patch", Some(Method::PATCH)),
            ("head", Some(Method::HEAD)),
            ("options", Some(Method::OPTIONS)),
            ("connect", Some(Method::CONNECT)),
            ("trace", Some(Method::TRACE)),
            ("invalid", None),
            ("", None),
        ] {
            assert_eq!(parse_arg_as_method(arg), expected);
        }
    }

    #[test]
    fn test_expand_url() {
        for (url, expected) in [
            ("example.com", "http://example.com"),
            ("http://example.com", "http://example.com"),
            ("https://example.com", "https://example.com"),
            ("example.com:8080", "http://example.com:8080"),
            (":8080/foo", "http://localhost:8080/foo"),
            (":8080", "http://localhost:8080"),
            ("", "http://localhost"),
        ] {
            assert_eq!(expand_url(url.to_owned()), expected);
        }
    }

    #[tokio::test]
    async fn test_request_args_builder_happy() {
        for (args, expected_request_str) in [
            (vec![":8080"], "GET / HTTP/1.1\r\n\r\n"),
            (vec!["HeAD", ":8000/foo"], "HEAD /foo HTTP/1.1\r\n\r\n"),
            (
                vec!["example.com/bar", "FOO:bar", "AnSweR:42"],
                "GET /bar HTTP/1.1\r\nFOO: bar\r\nAnSweR: 42\r\n\r\n",
            ),
            (
                vec![
                    "example.com/foo",
                    "c=d",
                    "Content-Type:application/x-www-form-urlencoded",
                ],
                "POST /foo HTTP/1.1\r\nContent-Type: application/x-www-form-urlencoded\r\ncontent-length: 3\r\n\r\nc=d",
            ),
            (
                vec!["example.com/foo", "a=b", "Content-Type:application/json"],
                "POST /foo HTTP/1.1\r\nContent-Type: application/json\r\naccept: application/json\r\ncontent-length: 9\r\n\r\n{\"a\":\"b\"}",
            ),
            (
                vec!["example.com/foo", "a=b"],
                "POST /foo HTTP/1.1\r\ncontent-type: application/json\r\naccept: application/json\r\ncontent-length: 9\r\n\r\n{\"a\":\"b\"}",
            ),
            (
                vec!["example.com/foo", "x-a:1", "a=b"],
                "POST /foo HTTP/1.1\r\nx-a: 1\r\ncontent-type: application/json\r\naccept: application/json\r\ncontent-length: 9\r\n\r\n{\"a\":\"b\"}",
            ),
            (
                vec!["put", "example.com/foo?a=2", "x-a:1", "a:=42", "a==3"],
                "PUT /foo?a=2&a=3 HTTP/1.1\r\nx-a: 1\r\ncontent-type: application/json\r\naccept: application/json\r\ncontent-length: 8\r\n\r\n{\"a\":42}",
            ),
            (
                vec![":3000", "Cookie:foo=bar"],
                "GET / HTTP/1.1\r\nCookie: foo=bar\r\n\r\n",
            ),
            (
                vec![":/foo", "search==rama"],
                "GET /foo?search=rama HTTP/1.1\r\n\r\n",
            ),
            (
                vec!["example.com", "description='CLI HTTP client'"],
                "POST / HTTP/1.1\r\ncontent-type: application/json\r\naccept: application/json\r\ncontent-length: 35\r\n\r\n{\"description\":\"'CLI HTTP client'\"}",
            ),
            (
                vec!["example.com", "x-cfg:a=1&foo=bar&foo=baz"],
                "GET / HTTP/1.1\r\nx-cfg: a=1&foo=bar&foo=baz\r\n\r\n",
            ),
        ] {
            assert_eq!(
                build_request_str(RequestArgsBuilder::new(), args).await,
                expected_request_str
            );
        }
    }

    #[tokio::test]
    async fn test_request_args_builder_form_happy() {
        for (args, expected_request_str) in [(
            vec!["example.com/foo", "c=d"],
            "POST /foo HTTP/1.1\r\ncontent-type: application/x-www-form-urlencoded\r\ncontent-length: 3\r\n\r\nc=d",
        )] {
            assert_eq!(
                build_request_str(RequestArgsBuilder::new_form(), args).await,
                expected_request_str
            );
        }
    }

    #[tokio::test]
    async fn test_request_args_builder_json_happy() {
        for (args, expected_request_str) in [(
            vec!["example.com/foo", "a=b"],
            "POST /foo HTTP/1.1\r\ncontent-type: application/json\r\naccept: application/json\r\ncontent-length: 9\r\n\r\n{\"a\":\"b\"}",
        )] {
            // Use the JSON constructor: the auto-detecting one is already
            // covered by `test_request_args_builder_happy`.
            assert_eq!(
                build_request_str(RequestArgsBuilder::new_json(), args).await,
                expected_request_str
            );
        }
    }

    #[test]
    fn test_request_args_builder_error() {
        for test in [
            vec![],
            vec!["invalid url"],
            vec!["get"],
            vec!["get", "invalid url"],
        ] {
            let mut builder = RequestArgsBuilder::new();
            for arg in test {
                builder.parse_arg(arg.to_owned());
            }
            let request = builder.build();
            assert!(request.is_err());
        }
    }
}