1use crate::{
4 error::{ErrorContext, OpaqueError},
5 http::{
6 Body, Method, Request, Uri,
7 header::{ACCEPT, CONTENT_LENGTH, CONTENT_TYPE, Entry, HeaderValue},
8 },
9};
10use rama_http::proto::h1::{Http1HeaderName, headers::original::OriginalHttp1Headers};
11use rama_utils::macros::match_ignore_ascii_case_str;
12use serde_json::Value;
13use std::collections::HashMap;
14
/// Builds a [`Request`] from a sequence of positional CLI-style arguments.
///
/// Arguments are fed one at a time through [`RequestArgsBuilder::parse_arg`]:
/// an optional HTTP method, then the url, then any number of request items
/// (query params, headers and body fields). [`RequestArgsBuilder::build`]
/// turns the collected state into the final request.
#[derive(Debug, Clone)]
pub struct RequestArgsBuilder {
    // Parser progress; each `BuilderState` variant is one stage of parsing.
    state: BuilderState,
}
20
21impl Default for RequestArgsBuilder {
22 fn default() -> Self {
23 Self::new()
24 }
25}
26
27impl RequestArgsBuilder {
28 #[must_use]
30 pub const fn new() -> Self {
31 Self {
32 state: BuilderState::MethodOrUrl { content_type: None },
33 }
34 }
35
36 #[must_use]
38 pub const fn new_json() -> Self {
39 Self {
40 state: BuilderState::MethodOrUrl {
41 content_type: Some(ContentType::Json),
42 },
43 }
44 }
45
46 #[must_use]
48 pub const fn new_form() -> Self {
49 Self {
50 state: BuilderState::MethodOrUrl {
51 content_type: Some(ContentType::Form),
52 },
53 }
54 }
55
56 pub fn parse_arg(&mut self, arg: String) {
59 let new_state = match &mut self.state {
60 BuilderState::MethodOrUrl { content_type } => {
61 if let Some(method) = parse_arg_as_method(&arg) {
62 Some(BuilderState::Url {
63 content_type: *content_type,
64 method: Some(method),
65 })
66 } else {
67 Some(BuilderState::Data {
68 content_type: *content_type,
69 method: None,
70 url: arg,
71 query: HashMap::new(),
72 headers: Vec::new(),
73 body: HashMap::new(),
74 })
75 }
76 }
77 BuilderState::Url {
78 content_type,
79 method,
80 } => Some(BuilderState::Data {
81 content_type: *content_type,
82 method: method.clone(),
83 url: arg,
84 query: HashMap::new(),
85 headers: Vec::new(),
86 body: HashMap::new(),
87 }),
88 BuilderState::Data {
89 query,
90 headers,
91 body,
92 ..
93 } => match parse_arg_as_data(&arg, query, headers, body) {
94 Ok(_) => None,
95 Err(msg) => Some(BuilderState::Error {
96 message: msg,
97 ignored: vec![],
98 }),
99 },
100 BuilderState::Error { ignored, .. } => {
101 ignored.push(arg);
102 None
103 }
104 };
105 if let Some(new_state) = new_state {
106 self.state = new_state;
107 }
108 }
109
110 pub fn build(self) -> Result<Request, OpaqueError> {
112 match self.state {
113 BuilderState::MethodOrUrl { .. } | BuilderState::Url { .. } => {
114 Err(OpaqueError::from_display("no url defined"))
115 }
116 BuilderState::Error { message, ignored } => {
117 Err(OpaqueError::from_display(if ignored.is_empty() {
118 format!("request arg parser failed: {message}")
119 } else {
120 format!("request arg parser failed: {message} (ignored: {ignored:?})")
121 }))
122 }
123 BuilderState::Data {
124 content_type,
125 method,
126 url,
127 query,
128 headers,
129 body,
130 } => {
131 let mut req = Request::builder();
132
133 let url = expand_url(url);
134
135 let uri: Uri = url
136 .parse()
137 .map_err(OpaqueError::from_std)
138 .context("parse base uri")?;
139
140 if query.is_empty() {
141 req = req.uri(url);
142 } else {
143 let mut uri_parts = uri.into_parts();
144 uri_parts.path_and_query = Some(if let Some(pq) = uri_parts.path_and_query {
145 if let Some(q) = pq.query() {
146 let mut existing_query: HashMap<String, Vec<String>> =
147 serde_html_form::from_str(q)
148 .map_err(OpaqueError::from_std)
149 .context("parse existing query")?;
150 for (k, v) in query {
151 existing_query.entry(k).or_default().extend(v);
152 }
153 let query = serde_html_form::to_string(&existing_query)
154 .map_err(OpaqueError::from_std)
155 .context("serialize extended query")?;
156 format!("{}?{}", pq.path(), query)
157 .parse()
158 .map_err(OpaqueError::from_std)
159 .context("create new path+query from extended query")?
160 } else {
161 let query = serde_html_form::to_string(&query)
162 .map_err(OpaqueError::from_std)
163 .context("serialize new and only query params")?;
164 format!("{}?{}", pq.path(), query)
165 .parse()
166 .map_err(OpaqueError::from_std)
167 .context("create path+query from given query params")?
168 }
169 } else {
170 let query =
171 serde_html_form::to_string(&query).map_err(OpaqueError::from_std)?;
172 format!("/?{query}")
173 .parse()
174 .map_err(OpaqueError::from_std)?
175 });
176 req = req.uri(Uri::from_parts(uri_parts).map_err(OpaqueError::from_std)?);
177 }
178
179 match method {
180 Some(method) => req = req.method(method),
181 None => {
182 if body.is_empty() {
183 req = req.method(Method::GET);
184 } else {
185 req = req.method(Method::POST);
186 }
187 }
188 }
189
190 let mut header_order = OriginalHttp1Headers::with_capacity(headers.len());
191 for (name, value) in headers {
192 let header_name = Http1HeaderName::try_copy_from_str(name.as_str())
193 .context("convert string into Http1HeaderName")?;
194 req = req.header(header_name.clone(), value);
195 header_order.push(header_name);
196 }
197
198 if body.is_empty() {
199 let mut req = req
200 .body(Body::empty())
201 .map_err(OpaqueError::from_std)
202 .context("create request without body")?;
203
204 req.extensions_mut().insert(header_order);
205
206 return Ok(req);
207 }
208
209 let ct = content_type.unwrap_or_else(|| {
210 match req
211 .headers_ref()
212 .and_then(|h| h.get(CONTENT_TYPE))
213 .and_then(|h| h.to_str().ok())
214 {
215 Some(cv) if cv.contains("application/x-www-form-urlencoded") => {
216 ContentType::Form
217 }
218 _ => ContentType::Json,
219 }
220 });
221
222 let req = if req.headers_ref().is_none() {
223 let req = req.header(CONTENT_TYPE, ct.header_value());
224 header_order.push(CONTENT_TYPE.into());
225 if ct == ContentType::Json {
226 header_order.push(ACCEPT.into());
227 req.header(ACCEPT, ct.header_value())
228 } else {
229 req
230 }
231 } else {
232 let headers = req.headers_mut().unwrap();
233
234 if let Entry::Vacant(entry) = headers.entry(CONTENT_TYPE) {
235 header_order.push(CONTENT_TYPE.into());
236 entry.insert(ct.header_value());
237 }
238
239 if ct == ContentType::Json
240 && let Entry::Vacant(entry) = headers.entry(ACCEPT)
241 {
242 header_order.push(ACCEPT.into());
243 entry.insert(ct.header_value());
244 }
245
246 req
247 };
248
249 let mut req = match ct {
250 ContentType::Json => {
251 let body = serde_json::to_string(&body)
252 .map_err(OpaqueError::from_std)
253 .context("serialize form body")?;
254 header_order.push(CONTENT_LENGTH.into());
255 req.header(CONTENT_LENGTH, body.len().to_string())
256 .body(Body::from(body))
257 }
258 ContentType::Form => {
259 let body = serde_html_form::to_string(&body)
260 .map_err(OpaqueError::from_std)
261 .context("serialize json body")?;
262 header_order.push(CONTENT_LENGTH.into());
263 req.header(CONTENT_LENGTH, body.len().to_string())
264 .body(Body::from(body))
265 }
266 }
267 .map_err(OpaqueError::from_std)
268 .context("create request with body")?;
269
270 req.extensions_mut().insert(header_order);
271
272 Ok(req)
273 }
274 }
275 }
276}
277
278fn parse_arg_as_data(
279 arg: &str,
280 query: &mut HashMap<String, Vec<String>>,
281 headers: &mut Vec<(String, String)>,
282 body: &mut HashMap<String, Value>,
283) -> Result<(), String> {
284 let mut state = DataParseArgState::None;
285 for (i, c) in arg.chars().enumerate() {
286 match state {
287 DataParseArgState::None => match c {
288 '\\' => state = DataParseArgState::Escaped,
289 '=' => state = DataParseArgState::Equal,
290 ':' => state = DataParseArgState::Colon,
291 _ => (),
292 },
293 DataParseArgState::Escaped => {
294 state = DataParseArgState::None;
296 }
297 DataParseArgState::Equal => {
298 let (name, value) = arg.split_at(i - 1);
299 if c == '=' {
300 let value = &value[2..];
302 query
303 .entry(name.to_owned())
304 .or_default()
305 .push(value.to_owned());
306 } else {
307 let value = &value[1..];
309 body.insert(name.to_owned(), Value::String(value.to_owned()));
310 }
311 break;
312 }
313 DataParseArgState::Colon => {
314 let (name, value) = arg.split_at(i - 1);
315 if c == '=' {
316 let value = &value[2..];
318 let value: Value =
319 serde_json::from_str(value).map_err(|err| err.to_string())?;
320 body.insert(name.to_owned(), value);
321 } else {
322 let value = &value[1..];
324 headers.push((name.to_owned(), value.to_owned()));
325 }
326 break;
327 }
328 }
329 }
330 Ok(())
331}
332
/// Interpret `arg` as one of the nine standard HTTP methods.
///
/// The comparison ignores ASCII case (e.g. `"HeAD"` matches `HEAD`) and is
/// done by the `match_ignore_ascii_case_str!` macro. Returns `None` for any
/// other input, in which case the caller treats `arg` as the url instead.
fn parse_arg_as_method(arg: impl AsRef<str>) -> Option<Method> {
    match_ignore_ascii_case_str! {
        match (arg.as_ref()) {
            "GET" => Some(Method::GET),
            "POST" => Some(Method::POST),
            "PUT" => Some(Method::PUT),
            "DELETE" => Some(Method::DELETE),
            "PATCH" => Some(Method::PATCH),
            "HEAD" => Some(Method::HEAD),
            "OPTIONS" => Some(Method::OPTIONS),
            "CONNECT" => Some(Method::CONNECT),
            "TRACE" => Some(Method::TRACE),
            _ => None,

        }
    }
}
350
/// Expand the shorthand url forms accepted on the command line.
///
/// - `""` or `":"` becomes `http://localhost`
/// - `:PORT[/path]` becomes `http://localhost:PORT[/path]`
/// - `:/path` becomes `http://localhost/path`
/// - anything without a leading `scheme://` gets `http://` prepended
///
/// Only a real leading scheme counts. A `://` later in the url (e.g. inside
/// a query value such as `?next=http://x`) does not suppress the default.
fn expand_url(url: String) -> String {
    /// Whether `url` starts with `scheme://`, where the scheme is an ASCII
    /// letter followed by letters, digits, `+`, `-` or `.` (RFC 3986 §3.1).
    fn has_scheme(url: &str) -> bool {
        url.split_once("://").is_some_and(|(scheme, _)| {
            let mut chars = scheme.chars();
            chars.next().is_some_and(|c| c.is_ascii_alphabetic())
                && chars.all(|c| c.is_ascii_alphanumeric() || matches!(c, '+' | '-' | '.'))
        })
    }

    if url.is_empty() {
        "http://localhost".to_owned()
    } else if let Some(stripped_url) = url.strip_prefix(':') {
        if stripped_url.is_empty() {
            "http://localhost".to_owned()
        } else if stripped_url.starts_with(|c: char| c.is_ascii_digit()) {
            // keep the leading ':' so it separates host and port
            format!("http://localhost{url}")
        } else {
            format!("http://localhost{stripped_url}")
        }
    } else if has_scheme(&url) {
        url
    } else {
        format!("http://{url}")
    }
}
375
/// Scanner state used by `parse_arg_as_data` while looking for the separator
/// of a request item.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum DataParseArgState {
    /// No separator seen yet.
    None,
    /// The previous character was a backslash; the current one is literal.
    Escaped,
    /// The previous character was an unescaped `=`.
    Equal,
    /// The previous character was an unescaped `:`.
    Colon,
}
382
/// Encoding used to serialize a non-empty request body.
///
/// `Eq` accompanies the `Hash` derive so the type satisfies the usual
/// requirements for hash-based collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum ContentType {
    /// `application/json`
    Json,
    /// `application/x-www-form-urlencoded`
    Form,
}
388
389impl ContentType {
390 fn header_value(self) -> HeaderValue {
391 HeaderValue::from_static(match self {
392 Self::Json => "application/json",
393 Self::Form => "application/x-www-form-urlencoded",
394 })
395 }
396}
397
/// Parsing stage of a [`RequestArgsBuilder`].
#[derive(Debug, Clone)]
enum BuilderState {
    /// Nothing parsed yet; the next argument is a method or the url.
    MethodOrUrl {
        // body encoding forced by the constructor, if any
        content_type: Option<ContentType>,
    },
    /// A method was parsed; the next argument is the url.
    Url {
        content_type: Option<ContentType>,
        method: Option<Method>,
    },
    /// The url is known; every further argument is a request item.
    Data {
        content_type: Option<ContentType>,
        // `None` lets `build` pick GET or POST depending on the body
        method: Option<Method>,
        url: String,
        // values of `name==value` items, grouped by name
        query: HashMap<String, Vec<String>>,
        // `Name:value` items in the order they were given
        headers: Vec<(String, String)>,
        // `name=value` and `name:=json` items
        body: HashMap<String, Value>,
    },
    /// A request item failed to parse; `build` reports `message`.
    Error {
        message: String,
        // arguments received after the failure
        ignored: Vec<String>,
    },
}
420
#[cfg(test)]
mod tests {
    use super::*;
    use crate::http::io::write_http_request;

    /// Feed `args` into `builder`, build the request and render it as raw
    /// HTTP/1.x text so it can be compared against an expected string.
    async fn render(mut builder: RequestArgsBuilder, args: &[&str]) -> String {
        for arg in args {
            builder.parse_arg((*arg).to_owned());
        }
        let request = builder.build().unwrap();
        let mut w = Vec::new();
        let _ = write_http_request(&mut w, request, true, true)
            .await
            .unwrap();
        String::from_utf8(w).unwrap()
    }

    #[test]
    fn test_parse_arg_as_method() {
        for (arg, expected) in [
            ("GET", Some(Method::GET)),
            ("POST", Some(Method::POST)),
            ("PUT", Some(Method::PUT)),
            ("DELETE", Some(Method::DELETE)),
            ("PATCH", Some(Method::PATCH)),
            ("HEAD", Some(Method::HEAD)),
            ("OPTIONS", Some(Method::OPTIONS)),
            ("CONNECT", Some(Method::CONNECT)),
            ("TRACE", Some(Method::TRACE)),
            ("get", Some(Method::GET)),
            ("post", Some(Method::POST)),
            ("put", Some(Method::PUT)),
            ("delete", Some(Method::DELETE)),
            ("patch", Some(Method::PATCH)),
            ("head", Some(Method::HEAD)),
            ("options", Some(Method::OPTIONS)),
            ("connect", Some(Method::CONNECT)),
            ("trace", Some(Method::TRACE)),
            ("invalid", None),
            ("", None),
        ] {
            assert_eq!(parse_arg_as_method(arg), expected);
        }
    }

    #[test]
    fn test_expand_url() {
        for (url, expected) in [
            ("example.com", "http://example.com"),
            ("http://example.com", "http://example.com"),
            ("https://example.com", "https://example.com"),
            ("example.com:8080", "http://example.com:8080"),
            (":8080/foo", "http://localhost:8080/foo"),
            (":8080", "http://localhost:8080"),
            ("", "http://localhost"),
        ] {
            assert_eq!(expand_url(url.to_owned()), expected);
        }
    }

    #[tokio::test]
    async fn test_request_args_builder_happy() {
        for (args, expected_request_str) in [
            (vec![":8080"], "GET / HTTP/1.1\r\n\r\n"),
            (vec!["HeAD", ":8000/foo"], "HEAD /foo HTTP/1.1\r\n\r\n"),
            (
                vec!["example.com/bar", "FOO:bar", "AnSweR:42"],
                "GET /bar HTTP/1.1\r\nFOO: bar\r\nAnSweR: 42\r\n\r\n",
            ),
            (
                vec![
                    "example.com/foo",
                    "c=d",
                    "Content-Type:application/x-www-form-urlencoded",
                ],
                "POST /foo HTTP/1.1\r\nContent-Type: application/x-www-form-urlencoded\r\ncontent-length: 3\r\n\r\nc=d",
            ),
            (
                vec!["example.com/foo", "a=b", "Content-Type:application/json"],
                "POST /foo HTTP/1.1\r\nContent-Type: application/json\r\naccept: application/json\r\ncontent-length: 9\r\n\r\n{\"a\":\"b\"}",
            ),
            (
                vec!["example.com/foo", "a=b"],
                "POST /foo HTTP/1.1\r\ncontent-type: application/json\r\naccept: application/json\r\ncontent-length: 9\r\n\r\n{\"a\":\"b\"}",
            ),
            (
                vec!["example.com/foo", "x-a:1", "a=b"],
                "POST /foo HTTP/1.1\r\nx-a: 1\r\ncontent-type: application/json\r\naccept: application/json\r\ncontent-length: 9\r\n\r\n{\"a\":\"b\"}",
            ),
            (
                vec!["put", "example.com/foo?a=2", "x-a:1", "a:=42", "a==3"],
                "PUT /foo?a=2&a=3 HTTP/1.1\r\nx-a: 1\r\ncontent-type: application/json\r\naccept: application/json\r\ncontent-length: 8\r\n\r\n{\"a\":42}",
            ),
            (
                vec![":3000", "Cookie:foo=bar"],
                "GET / HTTP/1.1\r\nCookie: foo=bar\r\n\r\n",
            ),
            (
                vec![":/foo", "search==rama"],
                "GET /foo?search=rama HTTP/1.1\r\n\r\n",
            ),
            (
                vec!["example.com", "description='CLI HTTP client'"],
                "POST / HTTP/1.1\r\ncontent-type: application/json\r\naccept: application/json\r\ncontent-length: 35\r\n\r\n{\"description\":\"'CLI HTTP client'\"}",
            ),
            (
                vec!["example.com", "x-cfg:a=1&foo=bar&foo=baz"],
                "GET / HTTP/1.1\r\nx-cfg: a=1&foo=bar&foo=baz\r\n\r\n",
            ),
        ] {
            assert_eq!(
                render(RequestArgsBuilder::new(), &args).await,
                expected_request_str
            );
        }
    }

    #[tokio::test]
    async fn test_request_args_builder_form_happy() {
        for (args, expected_request_str) in [(
            vec!["example.com/foo", "c=d"],
            "POST /foo HTTP/1.1\r\ncontent-type: application/x-www-form-urlencoded\r\ncontent-length: 3\r\n\r\nc=d",
        )] {
            assert_eq!(
                render(RequestArgsBuilder::new_form(), &args).await,
                expected_request_str
            );
        }
    }

    #[tokio::test]
    async fn test_request_args_builder_json_happy() {
        for (args, expected_request_str) in [(
            vec!["example.com/foo", "a=b"],
            "POST /foo HTTP/1.1\r\ncontent-type: application/json\r\naccept: application/json\r\ncontent-length: 9\r\n\r\n{\"a\":\"b\"}",
        )] {
            // exercise the JSON-forcing constructor, not the inferring one
            assert_eq!(
                render(RequestArgsBuilder::new_json(), &args).await,
                expected_request_str
            );
        }
    }

    #[tokio::test]
    async fn test_request_args_builder_error() {
        for test in [
            vec![],
            vec!["invalid url"],
            vec!["get"],
            vec!["get", "invalid url"],
            vec!["example.com", "a:=notjson"],
        ] {
            let mut builder = RequestArgsBuilder::new();
            for arg in test {
                builder.parse_arg(arg.to_owned());
            }
            let request = builder.build();
            assert!(request.is_err());
        }
    }
}