rama/cli/
args.rs

1//! build requests from command line arguments
2
3use crate::{
4    error::{ErrorContext, OpaqueError},
5    http::{
6        Body, Method, Request, Uri,
7        header::{ACCEPT, CONTENT_LENGTH, CONTENT_TYPE, Entry, HeaderValue},
8    },
9};
10use rama_http::proto::h1::{Http1HeaderName, headers::original::OriginalHttp1Headers};
11use rama_utils::macros::match_ignore_ascii_case_str;
12use serde_json::Value;
13use std::collections::HashMap;
14
15#[derive(Debug, Clone)]
16/// A builder to create a request from command line arguments.
17pub struct RequestArgsBuilder {
18    state: BuilderState,
19}
20
21impl Default for RequestArgsBuilder {
22    fn default() -> Self {
23        Self::new()
24    }
25}
26
27impl RequestArgsBuilder {
28    /// Create a new [`RequestArgsBuilder`], which auto-detects the content type.
29    pub const fn new() -> Self {
30        Self {
31            state: BuilderState::MethodOrUrl { content_type: None },
32        }
33    }
34
35    /// Create a new [`RequestArgsBuilder`], which expects JSON data.
36    pub const fn new_json() -> RequestArgsBuilder {
37        RequestArgsBuilder {
38            state: BuilderState::MethodOrUrl {
39                content_type: Some(ContentType::Json),
40            },
41        }
42    }
43
44    /// Create a new [`RequestArgsBuilder`], which expects Form data.
45    pub const fn new_form() -> RequestArgsBuilder {
46        RequestArgsBuilder {
47            state: BuilderState::MethodOrUrl {
48                content_type: Some(ContentType::Form),
49            },
50        }
51    }
52
53    /// parse a command line argument, the possible meaning
54    /// depend on the current state of the builder, driven by the position of the argument.
55    pub fn parse_arg(&mut self, arg: String) {
56        let new_state = match &mut self.state {
57            BuilderState::MethodOrUrl { content_type } => {
58                if let Some(method) = parse_arg_as_method(&arg) {
59                    Some(BuilderState::Url {
60                        content_type: *content_type,
61                        method: Some(method),
62                    })
63                } else {
64                    Some(BuilderState::Data {
65                        content_type: *content_type,
66                        method: None,
67                        url: arg,
68                        query: HashMap::new(),
69                        headers: Vec::new(),
70                        body: HashMap::new(),
71                    })
72                }
73            }
74            BuilderState::Url {
75                content_type,
76                method,
77            } => Some(BuilderState::Data {
78                content_type: *content_type,
79                method: method.clone(),
80                url: arg,
81                query: HashMap::new(),
82                headers: Vec::new(),
83                body: HashMap::new(),
84            }),
85            BuilderState::Data {
86                query,
87                headers,
88                body,
89                ..
90            } => match parse_arg_as_data(arg, query, headers, body) {
91                Ok(_) => None,
92                Err(msg) => Some(BuilderState::Error {
93                    message: msg,
94                    ignored: vec![],
95                }),
96            },
97            BuilderState::Error { ignored, .. } => {
98                ignored.push(arg);
99                None
100            }
101        };
102        if let Some(new_state) = new_state {
103            self.state = new_state;
104        }
105    }
106
107    /// Build the request from the parsed arguments.
108    pub fn build(self) -> Result<Request, OpaqueError> {
109        match self.state {
110            BuilderState::MethodOrUrl { .. } | BuilderState::Url { .. } => {
111                Err(OpaqueError::from_display("no url defined"))
112            }
113            BuilderState::Error { message, ignored } => {
114                Err(OpaqueError::from_display(if ignored.is_empty() {
115                    format!("request arg parser failed: {}", message)
116                } else {
117                    format!(
118                        "request arg parser failed: {} (ignored: {:?})",
119                        message, ignored
120                    )
121                }))
122            }
123            BuilderState::Data {
124                content_type,
125                method,
126                url,
127                query,
128                headers,
129                body,
130            } => {
131                let mut req = Request::builder();
132
133                let url = expand_url(url);
134
135                let uri: Uri = url
136                    .parse()
137                    .map_err(OpaqueError::from_std)
138                    .context("parse base uri")?;
139
140                if query.is_empty() {
141                    req = req.uri(url);
142                } else {
143                    let mut uri_parts = uri.into_parts();
144                    uri_parts.path_and_query = Some(match uri_parts.path_and_query {
145                        Some(pq) => match pq.query() {
146                            Some(q) => {
147                                let mut existing_query: HashMap<String, Vec<String>> =
148                                    serde_html_form::from_str(q)
149                                        .map_err(OpaqueError::from_std)
150                                        .context("parse existing query")?;
151                                for (k, v) in query {
152                                    existing_query.entry(k).or_default().extend(v);
153                                }
154                                let query = serde_html_form::to_string(&existing_query)
155                                    .map_err(OpaqueError::from_std)
156                                    .context("serialize extended query")?;
157                                format!("{}?{}", pq.path(), query)
158                                    .parse()
159                                    .map_err(OpaqueError::from_std)
160                                    .context("create new path+query from extended query")?
161                            }
162                            None => {
163                                let query = serde_html_form::to_string(&query)
164                                    .map_err(OpaqueError::from_std)
165                                    .context("serialize new and only query params")?;
166                                format!("{}?{}", pq.path(), query)
167                                    .parse()
168                                    .map_err(OpaqueError::from_std)
169                                    .context("create path+query from given query params")?
170                            }
171                        },
172                        None => {
173                            let query = serde_html_form::to_string(&query)
174                                .map_err(OpaqueError::from_std)?;
175                            format!("/?{}", query)
176                                .parse()
177                                .map_err(OpaqueError::from_std)?
178                        }
179                    });
180                    req = req.uri(Uri::from_parts(uri_parts).map_err(OpaqueError::from_std)?);
181                }
182
183                match method {
184                    Some(method) => req = req.method(method),
185                    None => {
186                        if body.is_empty() {
187                            req = req.method(Method::GET);
188                        } else {
189                            req = req.method(Method::POST);
190                        }
191                    }
192                }
193
194                let mut header_order = OriginalHttp1Headers::with_capacity(headers.len());
195                for (name, value) in headers {
196                    let header_name = Http1HeaderName::try_copy_from_str(name.as_str())
197                        .context("convert string into Http1HeaderName")?;
198                    req = req.header(header_name.clone(), value);
199                    header_order.push(header_name);
200                }
201
202                if body.is_empty() {
203                    let mut req = req
204                        .body(Body::empty())
205                        .map_err(OpaqueError::from_std)
206                        .context("create request without body")?;
207
208                    req.extensions_mut().insert(header_order);
209
210                    return Ok(req);
211                }
212
213                let ct = content_type.unwrap_or_else(|| {
214                    match req
215                        .headers_ref()
216                        .and_then(|h| h.get(CONTENT_TYPE))
217                        .and_then(|h| h.to_str().ok())
218                    {
219                        Some(cv) if cv.contains("application/x-www-form-urlencoded") => {
220                            ContentType::Form
221                        }
222                        _ => ContentType::Json,
223                    }
224                });
225
226                let req = if req.headers_ref().is_none() {
227                    let req = req.header(CONTENT_TYPE, ct.header_value());
228                    header_order.push(CONTENT_TYPE.into());
229                    if ct == ContentType::Json {
230                        header_order.push(ACCEPT.into());
231                        req.header(ACCEPT, ct.header_value())
232                    } else {
233                        req
234                    }
235                } else {
236                    let headers = req.headers_mut().unwrap();
237
238                    if let Entry::Vacant(entry) = headers.entry(CONTENT_TYPE) {
239                        header_order.push(CONTENT_TYPE.into());
240                        entry.insert(ct.header_value());
241                    }
242
243                    if ct == ContentType::Json {
244                        if let Entry::Vacant(entry) = headers.entry(ACCEPT) {
245                            header_order.push(ACCEPT.into());
246                            entry.insert(ct.header_value());
247                        }
248                    }
249
250                    req
251                };
252
253                let mut req = match ct {
254                    ContentType::Json => {
255                        let body = serde_json::to_string(&body)
256                            .map_err(OpaqueError::from_std)
257                            .context("serialize form body")?;
258                        header_order.push(CONTENT_LENGTH.into());
259                        req.header(CONTENT_LENGTH, body.len().to_string())
260                            .body(Body::from(body))
261                    }
262                    ContentType::Form => {
263                        let body = serde_html_form::to_string(&body)
264                            .map_err(OpaqueError::from_std)
265                            .context("serialize json body")?;
266                        header_order.push(CONTENT_LENGTH.into());
267                        req.header(CONTENT_LENGTH, body.len().to_string())
268                            .body(Body::from(body))
269                    }
270                }
271                .map_err(OpaqueError::from_std)
272                .context("create request with body")?;
273
274                req.extensions_mut().insert(header_order);
275
276                Ok(req)
277            }
278        }
279    }
280}
281
282fn parse_arg_as_data(
283    arg: String,
284    query: &mut HashMap<String, Vec<String>>,
285    headers: &mut Vec<(String, String)>,
286    body: &mut HashMap<String, Value>,
287) -> Result<(), String> {
288    let mut state = DataParseArgState::None;
289    for (i, c) in arg.chars().enumerate() {
290        match state {
291            DataParseArgState::None => match c {
292                '\\' => state = DataParseArgState::Escaped,
293                '=' => state = DataParseArgState::Equal,
294                ':' => state = DataParseArgState::Colon,
295                _ => (),
296            },
297            DataParseArgState::Escaped => {
298                // \*
299                state = DataParseArgState::None;
300            }
301            DataParseArgState::Equal => {
302                let (name, value) = arg.split_at(i - 1);
303                if c == '=' {
304                    // ==
305                    let value = &value[2..];
306                    query
307                        .entry(name.to_owned())
308                        .or_default()
309                        .push(value.to_owned());
310                } else {
311                    // =
312                    let value = &value[1..];
313                    body.insert(name.to_owned(), Value::String(value.to_owned()));
314                }
315                break;
316            }
317            DataParseArgState::Colon => {
318                let (name, value) = arg.split_at(i - 1);
319                if c == '=' {
320                    // :=
321                    let value = &value[2..];
322                    let value: Value =
323                        serde_json::from_str(value).map_err(|err| err.to_string())?;
324                    body.insert(name.to_owned(), value);
325                } else {
326                    // :
327                    let value = &value[1..];
328                    headers.push((name.to_owned(), value.to_owned()));
329                }
330                break;
331            }
332        }
333    }
334    Ok(())
335}
336
337fn parse_arg_as_method(arg: impl AsRef<str>) -> Option<Method> {
338    match_ignore_ascii_case_str! {
339        match (arg.as_ref()) {
340            "GET" => Some(Method::GET),
341            "POST" => Some(Method::POST),
342            "PUT" => Some(Method::PUT),
343            "DELETE" => Some(Method::DELETE),
344            "PATCH" => Some(Method::PATCH),
345            "HEAD" => Some(Method::HEAD),
346            "OPTIONS" => Some(Method::OPTIONS),
347            "CONNECT" => Some(Method::CONNECT),
348            "TRACE" => Some(Method::TRACE),
349            _ => None,
350
351        }
352    }
353}
354
/// Turn a (possibly abbreviated) URL argument into a full URL.
///
/// - an empty string or a lone `:` becomes `http://localhost`;
/// - `:8080/foo` (colon followed by a digit) becomes `http://localhost:8080/foo`;
/// - `:/foo` (colon followed by anything else) becomes `http://localhost/foo`;
/// - `example.com` (no `://`) becomes `http://example.com`;
/// - anything already containing `://` is returned untouched.
fn expand_url(url: String) -> String {
    const LOCALHOST: &str = "http://localhost";

    if url.is_empty() {
        return LOCALHOST.to_owned();
    }

    if !url.starts_with(':') {
        return if url.contains("://") {
            url
        } else {
            format!("http://{}", url)
        };
    }

    // Localhost shorthand: look at what follows the leading (one byte) colon.
    let shorthand = &url[1..];
    if shorthand.is_empty() {
        LOCALHOST.to_owned()
    } else if shorthand.starts_with(|c: char| c.is_ascii_digit()) {
        // A port number: keep the colon so it ends up right after the host.
        format!("{}{}", LOCALHOST, url)
    } else {
        format!("{}{}", LOCALHOST, shorthand)
    }
}
379
/// Scanner state used by `parse_arg_as_data` while looking for the
/// separator of a data argument.
enum DataParseArgState {
    /// No separator seen yet (also the initial state).
    None,
    /// The previous character was a backslash: the next one is skipped.
    Escaped,
    /// The previous character was an unescaped `=`; the next character decides
    /// between `=` (body field) and `==` (query parameter).
    Equal,
    /// The previous character was an unescaped `:`; the next character decides
    /// between `:` (header) and `:=` (raw JSON body field).
    Colon,
}
386
/// The body encodings the builder can produce.
///
/// `Eq` is derived alongside `PartialEq` and `Hash`: equality on this
/// field-less enum is total, and hashed collections require `Eq` on their keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum ContentType {
    /// Body serialized as a JSON object.
    Json,
    /// Body serialized as URL-encoded form data.
    Form,
}
392
393impl ContentType {
394    fn header_value(&self) -> HeaderValue {
395        HeaderValue::from_static(match self {
396            ContentType::Json => "application/json",
397            ContentType::Form => "application/x-www-form-urlencoded",
398        })
399    }
400}
401
/// Position of the request builder within the sequence of command line
/// arguments; it decides how the next argument is interpreted.
#[derive(Debug, Clone)]
enum BuilderState {
    /// Initial state: the next argument is either an HTTP method or the URL.
    MethodOrUrl {
        /// Content type forced by the constructor, `None` to auto-detect.
        content_type: Option<ContentType>,
    },
    /// A method was given: the next argument is the URL.
    Url {
        /// Content type forced by the constructor, `None` to auto-detect.
        content_type: Option<ContentType>,
        /// The method parsed from the previous argument.
        method: Option<Method>,
    },
    /// The URL is known: every further argument is a data item.
    Data {
        /// Content type forced by the constructor, `None` to auto-detect.
        content_type: Option<ContentType>,
        /// Explicit method, `None` to derive it from the presence of a body.
        method: Option<Method>,
        /// The URL argument exactly as given (expanded only when building).
        url: String,
        /// Query parameters from `name==value` items, grouped by name.
        query: HashMap<String, Vec<String>>,
        /// Headers from `name:value` items, in the order they were given.
        headers: Vec<(String, String)>,
        /// Body fields from `name=value` and `name:=json` items.
        body: HashMap<String, Value>,
    },
    /// A data item could not be parsed; reported when building the request.
    Error {
        /// Description of the parse failure.
        message: String,
        /// Arguments received after the failure, kept only for the error report.
        ignored: Vec<String>,
    },
}
424
#[cfg(test)]
mod tests {
    use super::*;
    use crate::http::io::write_http_request;

    /// Feed `args` to `builder`, build the request and return it serialized
    /// as an HTTP/1.1 request string.
    async fn build_and_serialize(mut builder: RequestArgsBuilder, args: Vec<&str>) -> String {
        for arg in args {
            builder.parse_arg(arg.to_owned());
        }
        let request = builder.build().unwrap();
        let mut w = Vec::new();
        let _ = write_http_request(&mut w, request, true, true)
            .await
            .unwrap();
        String::from_utf8(w).unwrap()
    }

    #[test]
    fn test_parse_arg_as_method() {
        for (arg, expected) in [
            ("GET", Some(Method::GET)),
            ("POST", Some(Method::POST)),
            ("PUT", Some(Method::PUT)),
            ("DELETE", Some(Method::DELETE)),
            ("PATCH", Some(Method::PATCH)),
            ("HEAD", Some(Method::HEAD)),
            ("OPTIONS", Some(Method::OPTIONS)),
            ("CONNECT", Some(Method::CONNECT)),
            ("TRACE", Some(Method::TRACE)),
            ("get", Some(Method::GET)),
            ("post", Some(Method::POST)),
            ("put", Some(Method::PUT)),
            ("delete", Some(Method::DELETE)),
            ("patch", Some(Method::PATCH)),
            ("head", Some(Method::HEAD)),
            ("options", Some(Method::OPTIONS)),
            ("connect", Some(Method::CONNECT)),
            ("trace", Some(Method::TRACE)),
            ("invalid", None),
            ("", None),
        ] {
            assert_eq!(parse_arg_as_method(arg), expected);
        }
    }

    #[test]
    fn test_expand_url() {
        for (url, expected) in [
            ("example.com", "http://example.com"),
            ("http://example.com", "http://example.com"),
            ("https://example.com", "https://example.com"),
            ("example.com:8080", "http://example.com:8080"),
            (":8080/foo", "http://localhost:8080/foo"),
            (":8080", "http://localhost:8080"),
            ("", "http://localhost"),
        ] {
            assert_eq!(expand_url(url.to_owned()), expected);
        }
    }

    #[tokio::test]
    async fn test_request_args_builder_happy() {
        for (args, expected_request_str) in [
            (vec![":8080"], "GET / HTTP/1.1\r\n\r\n"),
            (vec!["HeAD", ":8000/foo"], "HEAD /foo HTTP/1.1\r\n\r\n"),
            (
                vec!["example.com/bar", "FOO:bar", "AnSweR:42"],
                "GET /bar HTTP/1.1\r\nFOO: bar\r\nAnSweR: 42\r\n\r\n",
            ),
            (
                vec![
                    "example.com/foo",
                    "c=d",
                    "Content-Type:application/x-www-form-urlencoded",
                ],
                "POST /foo HTTP/1.1\r\nContent-Type: application/x-www-form-urlencoded\r\ncontent-length: 3\r\n\r\nc=d",
            ),
            (
                vec!["example.com/foo", "a=b", "Content-Type:application/json"],
                "POST /foo HTTP/1.1\r\nContent-Type: application/json\r\naccept: application/json\r\ncontent-length: 9\r\n\r\n{\"a\":\"b\"}",
            ),
            (
                vec!["example.com/foo", "a=b"],
                "POST /foo HTTP/1.1\r\ncontent-type: application/json\r\naccept: application/json\r\ncontent-length: 9\r\n\r\n{\"a\":\"b\"}",
            ),
            (
                vec!["example.com/foo", "x-a:1", "a=b"],
                "POST /foo HTTP/1.1\r\nx-a: 1\r\ncontent-type: application/json\r\naccept: application/json\r\ncontent-length: 9\r\n\r\n{\"a\":\"b\"}",
            ),
            (
                vec!["put", "example.com/foo?a=2", "x-a:1", "a:=42", "a==3"],
                "PUT /foo?a=2&a=3 HTTP/1.1\r\nx-a: 1\r\ncontent-type: application/json\r\naccept: application/json\r\ncontent-length: 8\r\n\r\n{\"a\":42}",
            ),
            (
                vec![":3000", "Cookie:foo=bar"],
                "GET / HTTP/1.1\r\nCookie: foo=bar\r\n\r\n",
            ),
            (
                vec![":/foo", "search==rama"],
                "GET /foo?search=rama HTTP/1.1\r\n\r\n",
            ),
            (
                vec!["example.com", "description='CLI HTTP client'"],
                "POST / HTTP/1.1\r\ncontent-type: application/json\r\naccept: application/json\r\ncontent-length: 35\r\n\r\n{\"description\":\"'CLI HTTP client'\"}",
            ),
            (
                vec!["example.com", "x-cfg:a=1&foo=bar&foo=baz"],
                "GET / HTTP/1.1\r\nx-cfg: a=1&foo=bar&foo=baz\r\n\r\n",
            ),
        ] {
            let request_str = build_and_serialize(RequestArgsBuilder::new(), args).await;
            assert_eq!(request_str, expected_request_str);
        }
    }

    #[tokio::test]
    async fn test_request_args_builder_form_happy() {
        for (args, expected_request_str) in [(
            vec!["example.com/foo", "c=d"],
            "POST /foo HTTP/1.1\r\ncontent-type: application/x-www-form-urlencoded\r\ncontent-length: 3\r\n\r\nc=d",
        )] {
            let request_str = build_and_serialize(RequestArgsBuilder::new_form(), args).await;
            assert_eq!(request_str, expected_request_str);
        }
    }

    #[tokio::test]
    async fn test_request_args_builder_json_happy() {
        for (args, expected_request_str) in [(
            vec!["example.com/foo", "a=b"],
            "POST /foo HTTP/1.1\r\ncontent-type: application/json\r\naccept: application/json\r\ncontent-length: 9\r\n\r\n{\"a\":\"b\"}",
        )] {
            // Use the JSON constructor so that `new_json` is actually covered.
            let request_str = build_and_serialize(RequestArgsBuilder::new_json(), args).await;
            assert_eq!(request_str, expected_request_str);
        }
    }

    #[tokio::test]
    async fn test_request_args_builder_error() {
        for test in [
            vec![],
            vec!["invalid url"],
            vec!["get"],
            vec!["get", "invalid url"],
        ] {
            let mut builder = RequestArgsBuilder::new();
            for arg in test {
                builder.parse_arg(arg.to_owned());
            }
            let request = builder.build();
            assert!(request.is_err());
        }
    }
}