1use crate::{
4 error::{ErrorContext, OpaqueError},
5 http::{
6 Body, Method, Request, Uri,
7 header::{ACCEPT, CONTENT_LENGTH, CONTENT_TYPE, Entry, HeaderValue},
8 },
9};
10use ahash::{HashMap, HashMapExt as _};
11use rama_core::extensions::ExtensionsMut;
12use rama_http::proto::h1::{Http1HeaderName, headers::original::OriginalHttp1Headers};
13use rama_utils::macros::match_ignore_ascii_case_str;
14use serde_json::Value;
15
/// Incrementally builds an HTTP [`Request`] from command-line style arguments.
///
/// Arguments are fed one at a time through `parse_arg` (an optional method, then the
/// url, then request items) and the final request is produced by `build`.
#[derive(Debug, Clone)]
pub struct RequestArgsBuilder {
    /// Current position in the argument grammar, together with everything parsed so far.
    state: BuilderState,
}
21
22impl Default for RequestArgsBuilder {
23 fn default() -> Self {
24 Self::new()
25 }
26}
27
28impl RequestArgsBuilder {
29 #[must_use]
31 pub const fn new() -> Self {
32 Self {
33 state: BuilderState::MethodOrUrl { content_type: None },
34 }
35 }
36
37 #[must_use]
39 pub const fn new_json() -> Self {
40 Self {
41 state: BuilderState::MethodOrUrl {
42 content_type: Some(ContentType::Json),
43 },
44 }
45 }
46
47 #[must_use]
49 pub const fn new_form() -> Self {
50 Self {
51 state: BuilderState::MethodOrUrl {
52 content_type: Some(ContentType::Form),
53 },
54 }
55 }
56
57 pub fn parse_arg(&mut self, arg: String) {
60 let new_state = match &mut self.state {
61 BuilderState::MethodOrUrl { content_type } => {
62 if let Some(method) = parse_arg_as_method(&arg) {
63 Some(BuilderState::Url {
64 content_type: *content_type,
65 method: Some(method),
66 })
67 } else {
68 Some(BuilderState::Data {
69 content_type: *content_type,
70 method: None,
71 url: arg,
72 query: HashMap::new(),
73 headers: Vec::new(),
74 body: HashMap::new(),
75 })
76 }
77 }
78 BuilderState::Url {
79 content_type,
80 method,
81 } => Some(BuilderState::Data {
82 content_type: *content_type,
83 method: method.clone(),
84 url: arg,
85 query: HashMap::new(),
86 headers: Vec::new(),
87 body: HashMap::new(),
88 }),
89 BuilderState::Data {
90 query,
91 headers,
92 body,
93 ..
94 } => match parse_arg_as_data(&arg, query, headers, body) {
95 Ok(_) => None,
96 Err(msg) => Some(BuilderState::Error {
97 message: msg,
98 ignored: vec![],
99 }),
100 },
101 BuilderState::Error { ignored, .. } => {
102 ignored.push(arg);
103 None
104 }
105 };
106 if let Some(new_state) = new_state {
107 self.state = new_state;
108 }
109 }
110
111 pub fn build(self) -> Result<Request, OpaqueError> {
113 match self.state {
114 BuilderState::MethodOrUrl { .. } | BuilderState::Url { .. } => {
115 Err(OpaqueError::from_display("no url defined"))
116 }
117 BuilderState::Error { message, ignored } => {
118 Err(OpaqueError::from_display(if ignored.is_empty() {
119 format!("request arg parser failed: {message}")
120 } else {
121 format!("request arg parser failed: {message} (ignored: {ignored:?})")
122 }))
123 }
124 BuilderState::Data {
125 content_type,
126 method,
127 url,
128 query,
129 headers,
130 body,
131 } => {
132 let mut req = Request::builder();
133
134 let url = expand_url(url);
135
136 let uri: Uri = url
137 .parse()
138 .map_err(OpaqueError::from_std)
139 .context("parse base uri")?;
140
141 if query.is_empty() {
142 req = req.uri(url);
143 } else {
144 let mut uri_parts = uri.into_parts();
145 uri_parts.path_and_query = Some(if let Some(pq) = uri_parts.path_and_query {
146 if let Some(q) = pq.query() {
147 let mut existing_query: HashMap<String, Vec<String>> =
148 serde_html_form::from_str(q)
149 .map_err(OpaqueError::from_std)
150 .context("parse existing query")?;
151 for (k, v) in query {
152 existing_query.entry(k).or_default().extend(v);
153 }
154 let query = serde_html_form::to_string(&existing_query)
155 .map_err(OpaqueError::from_std)
156 .context("serialize extended query")?;
157 format!("{}?{}", pq.path(), query)
158 .parse()
159 .map_err(OpaqueError::from_std)
160 .context("create new path+query from extended query")?
161 } else {
162 let query = serde_html_form::to_string(&query)
163 .map_err(OpaqueError::from_std)
164 .context("serialize new and only query params")?;
165 format!("{}?{}", pq.path(), query)
166 .parse()
167 .map_err(OpaqueError::from_std)
168 .context("create path+query from given query params")?
169 }
170 } else {
171 let query =
172 serde_html_form::to_string(&query).map_err(OpaqueError::from_std)?;
173 format!("/?{query}")
174 .parse()
175 .map_err(OpaqueError::from_std)?
176 });
177 req = req.uri(Uri::from_parts(uri_parts).map_err(OpaqueError::from_std)?);
178 }
179
180 match method {
181 Some(method) => req = req.method(method),
182 None => {
183 if body.is_empty() {
184 req = req.method(Method::GET);
185 } else {
186 req = req.method(Method::POST);
187 }
188 }
189 }
190
191 let mut header_order = OriginalHttp1Headers::with_capacity(headers.len());
192 for (name, value) in headers {
193 let header_name = Http1HeaderName::try_copy_from_str(name.as_str())
194 .context("convert string into Http1HeaderName")?;
195 req = req.header(header_name.clone(), value);
196 header_order.push(header_name);
197 }
198
199 if body.is_empty() {
200 let mut req = req
201 .body(Body::empty())
202 .map_err(OpaqueError::from_std)
203 .context("create request without body")?;
204
205 req.extensions_mut().insert(header_order);
206
207 return Ok(req);
208 }
209
210 let ct = content_type.unwrap_or_else(|| {
211 match req
212 .headers_ref()
213 .and_then(|h| h.get(CONTENT_TYPE))
214 .and_then(|h| h.to_str().ok())
215 {
216 Some(cv) if cv.contains("application/x-www-form-urlencoded") => {
217 ContentType::Form
218 }
219 _ => ContentType::Json,
220 }
221 });
222
223 let req = if req.headers_ref().is_none() {
224 let req = req.header(CONTENT_TYPE, ct.header_value());
225 header_order.push(CONTENT_TYPE.into());
226 if ct == ContentType::Json {
227 header_order.push(ACCEPT.into());
228 req.header(ACCEPT, ct.header_value())
229 } else {
230 req
231 }
232 } else {
233 let headers = req.headers_mut().unwrap();
234
235 if let Entry::Vacant(entry) = headers.entry(CONTENT_TYPE) {
236 header_order.push(CONTENT_TYPE.into());
237 entry.insert(ct.header_value());
238 }
239
240 if ct == ContentType::Json
241 && let Entry::Vacant(entry) = headers.entry(ACCEPT)
242 {
243 header_order.push(ACCEPT.into());
244 entry.insert(ct.header_value());
245 }
246
247 req
248 };
249
250 let mut req = match ct {
251 ContentType::Json => {
252 let body = serde_json::to_string(&body)
253 .map_err(OpaqueError::from_std)
254 .context("serialize form body")?;
255 header_order.push(CONTENT_LENGTH.into());
256 req.header(CONTENT_LENGTH, body.len().to_string())
257 .body(Body::from(body))
258 }
259 ContentType::Form => {
260 let body = serde_html_form::to_string(&body)
261 .map_err(OpaqueError::from_std)
262 .context("serialize json body")?;
263 header_order.push(CONTENT_LENGTH.into());
264 req.header(CONTENT_LENGTH, body.len().to_string())
265 .body(Body::from(body))
266 }
267 }
268 .map_err(OpaqueError::from_std)
269 .context("create request with body")?;
270
271 req.extensions_mut().insert(header_order);
272
273 Ok(req)
274 }
275 }
276 }
277}
278
279fn parse_arg_as_data(
280 arg: &str,
281 query: &mut HashMap<String, Vec<String>>,
282 headers: &mut Vec<(String, String)>,
283 body: &mut HashMap<String, Value>,
284) -> Result<(), String> {
285 let mut state = DataParseArgState::None;
286 for (i, c) in arg.chars().enumerate() {
287 match state {
288 DataParseArgState::None => match c {
289 '\\' => state = DataParseArgState::Escaped,
290 '=' => state = DataParseArgState::Equal,
291 ':' => state = DataParseArgState::Colon,
292 _ => (),
293 },
294 DataParseArgState::Escaped => {
295 state = DataParseArgState::None;
297 }
298 DataParseArgState::Equal => {
299 let (name, value) = arg.split_at(i - 1);
300 if c == '=' {
301 let value = &value[2..];
303 query
304 .entry(name.to_owned())
305 .or_default()
306 .push(value.to_owned());
307 } else {
308 let value = &value[1..];
310 body.insert(name.to_owned(), Value::String(value.to_owned()));
311 }
312 break;
313 }
314 DataParseArgState::Colon => {
315 let (name, value) = arg.split_at(i - 1);
316 if c == '=' {
317 let value = &value[2..];
319 let value: Value =
320 serde_json::from_str(value).map_err(|err| err.to_string())?;
321 body.insert(name.to_owned(), value);
322 } else {
323 let value = &value[1..];
325 headers.push((name.to_owned(), value.to_owned()));
326 }
327 break;
328 }
329 }
330 }
331 Ok(())
332}
333
334fn parse_arg_as_method(arg: impl AsRef<str>) -> Option<Method> {
335 match_ignore_ascii_case_str! {
336 match (arg.as_ref()) {
337 "GET" => Some(Method::GET),
338 "POST" => Some(Method::POST),
339 "PUT" => Some(Method::PUT),
340 "DELETE" => Some(Method::DELETE),
341 "PATCH" => Some(Method::PATCH),
342 "HEAD" => Some(Method::HEAD),
343 "OPTIONS" => Some(Method::OPTIONS),
344 "CONNECT" => Some(Method::CONNECT),
345 "TRACE" => Some(Method::TRACE),
346 _ => None,
347
348 }
349 }
350}
351
/// Expands command-line url shorthands into an absolute url.
///
/// - `""` or `":"` becomes `http://localhost`;
/// - `:<port>...` becomes `http://localhost:<port>...`;
/// - any other `:<rest>` becomes `http://localhost<rest>` (e.g. `:/foo`);
/// - a url without `://` gets the `http://` scheme prepended;
/// - everything else is returned unchanged.
fn expand_url(url: String) -> String {
    match url.strip_prefix(':') {
        Some("") => "http://localhost".to_owned(),
        Some(rest) if rest.starts_with(|c: char| c.is_ascii_digit()) => {
            format!("http://localhost:{rest}")
        }
        Some(rest) => format!("http://localhost{rest}"),
        None if url.is_empty() => "http://localhost".to_owned(),
        None if url.contains("://") => url,
        None => format!("http://{url}"),
    }
}
376
/// Scanner state used by `parse_arg_as_data` while looking for the separator of a
/// request item.
enum DataParseArgState {
    /// No separator seen yet.
    None,
    /// The previous character was a backslash, so the current one is taken literally.
    Escaped,
    /// An unescaped `=` was seen.
    Equal,
    /// An unescaped `:` was seen.
    Colon,
}
383
/// Encoding used for the request body built from `name=value` / `name:=json` items.
// `Eq` accompanies `PartialEq` + `Hash` (clippy `derive_partial_eq_without_eq`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum ContentType {
    /// `application/json`
    Json,
    /// `application/x-www-form-urlencoded`
    Form,
}
389
390impl ContentType {
391 fn header_value(self) -> HeaderValue {
392 HeaderValue::from_static(match self {
393 Self::Json => "application/json",
394 Self::Form => "application/x-www-form-urlencoded",
395 })
396 }
397}
398
/// Position of [`RequestArgsBuilder`] in the argument grammar, plus the data collected
/// so far.
#[derive(Debug, Clone)]
enum BuilderState {
    /// Nothing parsed yet: the next argument is either a method or the url.
    MethodOrUrl {
        /// Body encoding forced by the constructor, if any.
        content_type: Option<ContentType>,
    },
    /// A method was parsed: the next argument is the url.
    Url {
        /// Body encoding forced by the constructor, if any.
        content_type: Option<ContentType>,
        /// The parsed method.
        method: Option<Method>,
    },
    /// The url was parsed: every further argument is a request item.
    Data {
        /// Body encoding forced by the constructor, if any.
        content_type: Option<ContentType>,
        /// `None` lets `build` pick `GET` or `POST` depending on whether body items exist.
        method: Option<Method>,
        /// Url as given on the command line; `build` expands it with `expand_url`.
        url: String,
        /// Query parameters from `name==value` items (values in the order given).
        query: HashMap<String, Vec<String>>,
        /// Headers from `name:value` items, in the order given.
        headers: Vec<(String, String)>,
        /// Body fields from `name=value` and `name:=json` items.
        body: HashMap<String, Value>,
    },
    /// Parsing failed: `build` reports `message` together with any later arguments.
    Error {
        /// Description of the parse failure.
        message: String,
        /// Arguments received after the failure.
        ignored: Vec<String>,
    },
}
421
#[cfg(test)]
mod tests {
    use super::*;
    use crate::http::io::write_http_request;

    // Methods are matched case-insensitively; unknown or empty input is not a method.
    #[test]
    fn test_parse_arg_as_method() {
        for (arg, expected) in [
            ("GET", Some(Method::GET)),
            ("POST", Some(Method::POST)),
            ("PUT", Some(Method::PUT)),
            ("DELETE", Some(Method::DELETE)),
            ("PATCH", Some(Method::PATCH)),
            ("HEAD", Some(Method::HEAD)),
            ("OPTIONS", Some(Method::OPTIONS)),
            ("CONNECT", Some(Method::CONNECT)),
            ("TRACE", Some(Method::TRACE)),
            ("get", Some(Method::GET)),
            ("post", Some(Method::POST)),
            ("put", Some(Method::PUT)),
            ("delete", Some(Method::DELETE)),
            ("patch", Some(Method::PATCH)),
            ("head", Some(Method::HEAD)),
            ("options", Some(Method::OPTIONS)),
            ("connect", Some(Method::CONNECT)),
            ("trace", Some(Method::TRACE)),
            ("invalid", None),
            ("", None),
        ] {
            assert_eq!(parse_arg_as_method(arg), expected);
        }
    }

    // Scheme defaulting and the `:port` / empty-url localhost shorthands.
    #[test]
    fn test_expand_url() {
        for (url, expected) in [
            ("example.com", "http://example.com"),
            ("http://example.com", "http://example.com"),
            ("https://example.com", "https://example.com"),
            ("example.com:8080", "http://example.com:8080"),
            (":8080/foo", "http://localhost:8080/foo"),
            (":8080", "http://localhost:8080"),
            ("", "http://localhost"),
        ] {
            assert_eq!(expand_url(url.to_owned()), expected);
        }
    }

    // End-to-end: arguments -> request -> serialized HTTP/1.1 bytes, using the
    // default builder (body encoding inferred from the headers).
    #[tokio::test]
    async fn test_request_args_builder_happy() {
        for (args, expected_request_str) in [
            (vec![":8080"], "GET / HTTP/1.1\r\n\r\n"),
            (vec!["HeAD", ":8000/foo"], "HEAD /foo HTTP/1.1\r\n\r\n"),
            (
                vec!["example.com/bar", "FOO:bar", "AnSweR:42"],
                "GET /bar HTTP/1.1\r\nFOO: bar\r\nAnSweR: 42\r\n\r\n",
            ),
            (
                vec![
                    "example.com/foo",
                    "c=d",
                    "Content-Type:application/x-www-form-urlencoded",
                ],
                "POST /foo HTTP/1.1\r\nContent-Type: application/x-www-form-urlencoded\r\ncontent-length: 3\r\n\r\nc=d",
            ),
            (
                vec!["example.com/foo", "a=b", "Content-Type:application/json"],
                "POST /foo HTTP/1.1\r\nContent-Type: application/json\r\naccept: application/json\r\ncontent-length: 9\r\n\r\n{\"a\":\"b\"}",
            ),
            (
                vec!["example.com/foo", "a=b"],
                "POST /foo HTTP/1.1\r\ncontent-type: application/json\r\naccept: application/json\r\ncontent-length: 9\r\n\r\n{\"a\":\"b\"}",
            ),
            (
                vec!["example.com/foo", "x-a:1", "a=b"],
                "POST /foo HTTP/1.1\r\nx-a: 1\r\ncontent-type: application/json\r\naccept: application/json\r\ncontent-length: 9\r\n\r\n{\"a\":\"b\"}",
            ),
            (
                vec!["put", "example.com/foo?a=2", "x-a:1", "a:=42", "a==3"],
                "PUT /foo?a=2&a=3 HTTP/1.1\r\nx-a: 1\r\ncontent-type: application/json\r\naccept: application/json\r\ncontent-length: 8\r\n\r\n{\"a\":42}",
            ),
            (
                vec![":3000", "Cookie:foo=bar"],
                "GET / HTTP/1.1\r\nCookie: foo=bar\r\n\r\n",
            ),
            (
                vec![":/foo", "search==rama"],
                "GET /foo?search=rama HTTP/1.1\r\n\r\n",
            ),
            (
                vec!["example.com", "description='CLI HTTP client'"],
                "POST / HTTP/1.1\r\ncontent-type: application/json\r\naccept: application/json\r\ncontent-length: 35\r\n\r\n{\"description\":\"'CLI HTTP client'\"}",
            ),
            (
                vec!["example.com", "x-cfg:a=1&foo=bar&foo=baz"],
                "GET / HTTP/1.1\r\nx-cfg: a=1&foo=bar&foo=baz\r\n\r\n",
            ),
        ] {
            let mut builder = RequestArgsBuilder::new();
            for arg in args {
                builder.parse_arg(arg.to_owned());
            }
            let request = builder.build().unwrap();
            let mut w = Vec::new();
            let _ = write_http_request(&mut w, request, true, true)
                .await
                .unwrap();
            assert_eq!(String::from_utf8(w).unwrap(), expected_request_str);
        }
    }

    // `new_form` forces url-encoded bodies without any Content-Type argument.
    #[tokio::test]
    async fn test_request_args_builder_form_happy() {
        for (args, expected_request_str) in [(
            vec!["example.com/foo", "c=d"],
            "POST /foo HTTP/1.1\r\ncontent-type: application/x-www-form-urlencoded\r\ncontent-length: 3\r\n\r\nc=d",
        )] {
            let mut builder = RequestArgsBuilder::new_form();
            for arg in args {
                builder.parse_arg(arg.to_owned());
            }
            let request = builder.build().unwrap();
            let mut w = Vec::new();
            let _ = write_http_request(&mut w, request, true, true)
                .await
                .unwrap();
            assert_eq!(String::from_utf8(w).unwrap(), expected_request_str);
        }
    }

    // JSON body output.
    // NOTE(review): this uses `RequestArgsBuilder::new()` rather than `new_json()`,
    // so it exercises the JSON default, not the forced-JSON constructor.
    #[tokio::test]
    async fn test_request_args_builder_json_happy() {
        for (args, expected_request_str) in [(
            vec!["example.com/foo", "a=b"],
            "POST /foo HTTP/1.1\r\ncontent-type: application/json\r\naccept: application/json\r\ncontent-length: 9\r\n\r\n{\"a\":\"b\"}",
        )] {
            let mut builder = RequestArgsBuilder::new();
            for arg in args {
                builder.parse_arg(arg.to_owned());
            }
            let request = builder.build().unwrap();
            let mut w = Vec::new();
            let _ = write_http_request(&mut w, request, true, true)
                .await
                .unwrap();
            assert_eq!(String::from_utf8(w).unwrap(), expected_request_str);
        }
    }

    // Missing url (no args, or only a method) and unparsable urls must fail in `build`.
    #[tokio::test]
    async fn test_request_args_builder_error() {
        for test in [
            vec![],
            vec!["invalid url"],
            vec!["get"],
            vec!["get", "invalid url"],
        ] {
            let mut builder = RequestArgsBuilder::new();
            for arg in test {
                builder.parse_arg(arg.to_owned());
            }
            let request = builder.build();
            assert!(request.is_err());
        }
    }
}