1use crate::{
4 error::{ErrorContext, OpaqueError},
5 http::{
6 Body, Method, Request, Uri,
7 header::{ACCEPT, CONTENT_LENGTH, CONTENT_TYPE, Entry, HeaderValue},
8 },
9};
10use rama_http::proto::h1::{Http1HeaderName, headers::original::OriginalHttp1Headers};
11use rama_utils::macros::match_ignore_ascii_case_str;
12use serde_json::Value;
13use std::collections::HashMap;
14
/// Builder that turns a sequence of command-line style arguments into a [`Request`].
///
/// Arguments are fed one at a time through `parse_arg`, in order, and the
/// request is produced by `build`.
#[derive(Debug, Clone)]
pub struct RequestArgsBuilder {
    /// Current position in the argument grammar, together with everything parsed so far.
    state: BuilderState,
}
20
21impl Default for RequestArgsBuilder {
22 fn default() -> Self {
23 Self::new()
24 }
25}
26
27impl RequestArgsBuilder {
28 pub const fn new() -> Self {
30 Self {
31 state: BuilderState::MethodOrUrl { content_type: None },
32 }
33 }
34
35 pub const fn new_json() -> RequestArgsBuilder {
37 RequestArgsBuilder {
38 state: BuilderState::MethodOrUrl {
39 content_type: Some(ContentType::Json),
40 },
41 }
42 }
43
44 pub const fn new_form() -> RequestArgsBuilder {
46 RequestArgsBuilder {
47 state: BuilderState::MethodOrUrl {
48 content_type: Some(ContentType::Form),
49 },
50 }
51 }
52
53 pub fn parse_arg(&mut self, arg: String) {
56 let new_state = match &mut self.state {
57 BuilderState::MethodOrUrl { content_type } => {
58 if let Some(method) = parse_arg_as_method(&arg) {
59 Some(BuilderState::Url {
60 content_type: *content_type,
61 method: Some(method),
62 })
63 } else {
64 Some(BuilderState::Data {
65 content_type: *content_type,
66 method: None,
67 url: arg,
68 query: HashMap::new(),
69 headers: Vec::new(),
70 body: HashMap::new(),
71 })
72 }
73 }
74 BuilderState::Url {
75 content_type,
76 method,
77 } => Some(BuilderState::Data {
78 content_type: *content_type,
79 method: method.clone(),
80 url: arg,
81 query: HashMap::new(),
82 headers: Vec::new(),
83 body: HashMap::new(),
84 }),
85 BuilderState::Data {
86 query,
87 headers,
88 body,
89 ..
90 } => match parse_arg_as_data(arg, query, headers, body) {
91 Ok(_) => None,
92 Err(msg) => Some(BuilderState::Error {
93 message: msg,
94 ignored: vec![],
95 }),
96 },
97 BuilderState::Error { ignored, .. } => {
98 ignored.push(arg);
99 None
100 }
101 };
102 if let Some(new_state) = new_state {
103 self.state = new_state;
104 }
105 }
106
107 pub fn build(self) -> Result<Request, OpaqueError> {
109 match self.state {
110 BuilderState::MethodOrUrl { .. } | BuilderState::Url { .. } => {
111 Err(OpaqueError::from_display("no url defined"))
112 }
113 BuilderState::Error { message, ignored } => {
114 Err(OpaqueError::from_display(if ignored.is_empty() {
115 format!("request arg parser failed: {}", message)
116 } else {
117 format!(
118 "request arg parser failed: {} (ignored: {:?})",
119 message, ignored
120 )
121 }))
122 }
123 BuilderState::Data {
124 content_type,
125 method,
126 url,
127 query,
128 headers,
129 body,
130 } => {
131 let mut req = Request::builder();
132
133 let url = expand_url(url);
134
135 let uri: Uri = url
136 .parse()
137 .map_err(OpaqueError::from_std)
138 .context("parse base uri")?;
139
140 if query.is_empty() {
141 req = req.uri(url);
142 } else {
143 let mut uri_parts = uri.into_parts();
144 uri_parts.path_and_query = Some(match uri_parts.path_and_query {
145 Some(pq) => match pq.query() {
146 Some(q) => {
147 let mut existing_query: HashMap<String, Vec<String>> =
148 serde_html_form::from_str(q)
149 .map_err(OpaqueError::from_std)
150 .context("parse existing query")?;
151 for (k, v) in query {
152 existing_query.entry(k).or_default().extend(v);
153 }
154 let query = serde_html_form::to_string(&existing_query)
155 .map_err(OpaqueError::from_std)
156 .context("serialize extended query")?;
157 format!("{}?{}", pq.path(), query)
158 .parse()
159 .map_err(OpaqueError::from_std)
160 .context("create new path+query from extended query")?
161 }
162 None => {
163 let query = serde_html_form::to_string(&query)
164 .map_err(OpaqueError::from_std)
165 .context("serialize new and only query params")?;
166 format!("{}?{}", pq.path(), query)
167 .parse()
168 .map_err(OpaqueError::from_std)
169 .context("create path+query from given query params")?
170 }
171 },
172 None => {
173 let query = serde_html_form::to_string(&query)
174 .map_err(OpaqueError::from_std)?;
175 format!("/?{}", query)
176 .parse()
177 .map_err(OpaqueError::from_std)?
178 }
179 });
180 req = req.uri(Uri::from_parts(uri_parts).map_err(OpaqueError::from_std)?);
181 }
182
183 match method {
184 Some(method) => req = req.method(method),
185 None => {
186 if body.is_empty() {
187 req = req.method(Method::GET);
188 } else {
189 req = req.method(Method::POST);
190 }
191 }
192 }
193
194 let mut header_order = OriginalHttp1Headers::with_capacity(headers.len());
195 for (name, value) in headers {
196 let header_name = Http1HeaderName::try_copy_from_str(name.as_str())
197 .context("convert string into Http1HeaderName")?;
198 req = req.header(header_name.clone(), value);
199 header_order.push(header_name);
200 }
201
202 if body.is_empty() {
203 let mut req = req
204 .body(Body::empty())
205 .map_err(OpaqueError::from_std)
206 .context("create request without body")?;
207
208 req.extensions_mut().insert(header_order);
209
210 return Ok(req);
211 }
212
213 let ct = content_type.unwrap_or_else(|| {
214 match req
215 .headers_ref()
216 .and_then(|h| h.get(CONTENT_TYPE))
217 .and_then(|h| h.to_str().ok())
218 {
219 Some(cv) if cv.contains("application/x-www-form-urlencoded") => {
220 ContentType::Form
221 }
222 _ => ContentType::Json,
223 }
224 });
225
226 let req = if req.headers_ref().is_none() {
227 let req = req.header(CONTENT_TYPE, ct.header_value());
228 header_order.push(CONTENT_TYPE.into());
229 if ct == ContentType::Json {
230 header_order.push(ACCEPT.into());
231 req.header(ACCEPT, ct.header_value())
232 } else {
233 req
234 }
235 } else {
236 let headers = req.headers_mut().unwrap();
237
238 if let Entry::Vacant(entry) = headers.entry(CONTENT_TYPE) {
239 header_order.push(CONTENT_TYPE.into());
240 entry.insert(ct.header_value());
241 }
242
243 if ct == ContentType::Json {
244 if let Entry::Vacant(entry) = headers.entry(ACCEPT) {
245 header_order.push(ACCEPT.into());
246 entry.insert(ct.header_value());
247 }
248 }
249
250 req
251 };
252
253 let mut req = match ct {
254 ContentType::Json => {
255 let body = serde_json::to_string(&body)
256 .map_err(OpaqueError::from_std)
257 .context("serialize form body")?;
258 header_order.push(CONTENT_LENGTH.into());
259 req.header(CONTENT_LENGTH, body.len().to_string())
260 .body(Body::from(body))
261 }
262 ContentType::Form => {
263 let body = serde_html_form::to_string(&body)
264 .map_err(OpaqueError::from_std)
265 .context("serialize json body")?;
266 header_order.push(CONTENT_LENGTH.into());
267 req.header(CONTENT_LENGTH, body.len().to_string())
268 .body(Body::from(body))
269 }
270 }
271 .map_err(OpaqueError::from_std)
272 .context("create request with body")?;
273
274 req.extensions_mut().insert(header_order);
275
276 Ok(req)
277 }
278 }
279 }
280}
281
282fn parse_arg_as_data(
283 arg: String,
284 query: &mut HashMap<String, Vec<String>>,
285 headers: &mut Vec<(String, String)>,
286 body: &mut HashMap<String, Value>,
287) -> Result<(), String> {
288 let mut state = DataParseArgState::None;
289 for (i, c) in arg.chars().enumerate() {
290 match state {
291 DataParseArgState::None => match c {
292 '\\' => state = DataParseArgState::Escaped,
293 '=' => state = DataParseArgState::Equal,
294 ':' => state = DataParseArgState::Colon,
295 _ => (),
296 },
297 DataParseArgState::Escaped => {
298 state = DataParseArgState::None;
300 }
301 DataParseArgState::Equal => {
302 let (name, value) = arg.split_at(i - 1);
303 if c == '=' {
304 let value = &value[2..];
306 query
307 .entry(name.to_owned())
308 .or_default()
309 .push(value.to_owned());
310 } else {
311 let value = &value[1..];
313 body.insert(name.to_owned(), Value::String(value.to_owned()));
314 }
315 break;
316 }
317 DataParseArgState::Colon => {
318 let (name, value) = arg.split_at(i - 1);
319 if c == '=' {
320 let value = &value[2..];
322 let value: Value =
323 serde_json::from_str(value).map_err(|err| err.to_string())?;
324 body.insert(name.to_owned(), value);
325 } else {
326 let value = &value[1..];
328 headers.push((name.to_owned(), value.to_owned()));
329 }
330 break;
331 }
332 }
333 }
334 Ok(())
335}
336
337fn parse_arg_as_method(arg: impl AsRef<str>) -> Option<Method> {
338 match_ignore_ascii_case_str! {
339 match (arg.as_ref()) {
340 "GET" => Some(Method::GET),
341 "POST" => Some(Method::POST),
342 "PUT" => Some(Method::PUT),
343 "DELETE" => Some(Method::DELETE),
344 "PATCH" => Some(Method::PATCH),
345 "HEAD" => Some(Method::HEAD),
346 "OPTIONS" => Some(Method::OPTIONS),
347 "CONNECT" => Some(Method::CONNECT),
348 "TRACE" => Some(Method::TRACE),
349 _ => None,
350
351 }
352 }
353}
354
/// Expands shorthand URLs into absolute ones.
///
/// * an empty string or a lone `:` becomes `http://localhost`,
/// * `:<digit>...` is treated as a port on localhost (`:8080/foo` -> `http://localhost:8080/foo`),
/// * `:<anything else>` has the colon dropped and the rest appended to `http://localhost`,
/// * a value without `://` gets `http://` prepended,
/// * everything else is returned unchanged.
fn expand_url(url: String) -> String {
    const LOCALHOST: &str = "http://localhost";

    if url.is_empty() {
        return LOCALHOST.to_owned();
    }

    if let Some(rest) = url.strip_prefix(':') {
        return match rest.chars().next() {
            None => LOCALHOST.to_owned(),
            // Keep the leading colon: it separates the host from the port.
            Some(first) if first.is_ascii_digit() => format!("{}{}", LOCALHOST, url),
            Some(_) => format!("{}{}", LOCALHOST, rest),
        };
    }

    if url.contains("://") {
        url
    } else {
        format!("http://{}", url)
    }
}
379
/// Scanner state used while looking for the separator in a data argument.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum DataParseArgState {
    /// No separator seen yet.
    None,
    /// The previous character was a backslash; the current one is skipped.
    Escaped,
    /// The previous character was an unescaped `=`.
    Equal,
    /// The previous character was an unescaped `:`.
    Colon,
}
386
/// Body encodings the builder can produce.
///
/// `Eq` is derived alongside `PartialEq` and `Hash` so the type can be used as a
/// key in hashed collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum ContentType {
    /// `application/json`
    Json,
    /// `application/x-www-form-urlencoded`
    Form,
}
392
393impl ContentType {
394 fn header_value(&self) -> HeaderValue {
395 HeaderValue::from_static(match self {
396 ContentType::Json => "application/json",
397 ContentType::Form => "application/x-www-form-urlencoded",
398 })
399 }
400}
401
/// Position of the builder in the argument grammar, together with the data
/// gathered so far.
#[derive(Debug, Clone)]
enum BuilderState {
    /// Initial state: the next argument is either an HTTP method or the URL.
    MethodOrUrl {
        /// Content type preset by the constructor, if any.
        content_type: Option<ContentType>,
    },
    /// A method was parsed: the next argument is the URL.
    Url {
        /// Content type preset by the constructor, if any.
        content_type: Option<ContentType>,
        /// Method parsed from the first argument.
        method: Option<Method>,
    },
    /// The URL is known: every further argument is a data item.
    Data {
        /// Content type preset by the constructor, if any.
        content_type: Option<ContentType>,
        /// Explicit method, or `None` when the first argument was the URL.
        method: Option<Method>,
        /// URL exactly as given on the command line (not yet expanded).
        url: String,
        /// Query parameters from `name==value` arguments.
        query: HashMap<String, Vec<String>>,
        /// Headers from `name:value` arguments, in the order they were given.
        headers: Vec<(String, String)>,
        /// Body fields from `name=value` and `name:=value` arguments.
        body: HashMap<String, Value>,
    },
    /// A data argument failed to parse.
    Error {
        /// Message describing the parse failure.
        message: String,
        /// Arguments received after the failure; they are not parsed.
        ignored: Vec<String>,
    },
}
424
#[cfg(test)]
mod tests {
    use super::*;
    use crate::http::io::write_http_request;

    /// Method names are recognized regardless of ASCII case; anything else is `None`.
    #[test]
    fn test_parse_arg_as_method() {
        for (arg, expected) in [
            ("GET", Some(Method::GET)),
            ("POST", Some(Method::POST)),
            ("PUT", Some(Method::PUT)),
            ("DELETE", Some(Method::DELETE)),
            ("PATCH", Some(Method::PATCH)),
            ("HEAD", Some(Method::HEAD)),
            ("OPTIONS", Some(Method::OPTIONS)),
            ("CONNECT", Some(Method::CONNECT)),
            ("TRACE", Some(Method::TRACE)),
            ("get", Some(Method::GET)),
            ("post", Some(Method::POST)),
            ("put", Some(Method::PUT)),
            ("delete", Some(Method::DELETE)),
            ("patch", Some(Method::PATCH)),
            ("head", Some(Method::HEAD)),
            ("options", Some(Method::OPTIONS)),
            ("connect", Some(Method::CONNECT)),
            ("trace", Some(Method::TRACE)),
            ("invalid", None),
            ("", None),
        ] {
            assert_eq!(parse_arg_as_method(arg), expected);
        }
    }

    /// Shorthand URLs are expanded to absolute `http://` URLs.
    #[test]
    fn test_expand_url() {
        for (url, expected) in [
            ("example.com", "http://example.com"),
            ("http://example.com", "http://example.com"),
            ("https://example.com", "https://example.com"),
            ("example.com:8080", "http://example.com:8080"),
            (":8080/foo", "http://localhost:8080/foo"),
            (":8080", "http://localhost:8080"),
            ("", "http://localhost"),
        ] {
            assert_eq!(expand_url(url.to_owned()), expected);
        }
    }

    /// End-to-end checks with the auto-detecting builder, compared on the serialized request.
    #[tokio::test]
    async fn test_request_args_builder_happy() {
        for (args, expected_request_str) in [
            (vec![":8080"], "GET / HTTP/1.1\r\n\r\n"),
            (vec!["HeAD", ":8000/foo"], "HEAD /foo HTTP/1.1\r\n\r\n"),
            (
                vec!["example.com/bar", "FOO:bar", "AnSweR:42"],
                "GET /bar HTTP/1.1\r\nFOO: bar\r\nAnSweR: 42\r\n\r\n",
            ),
            (
                vec![
                    "example.com/foo",
                    "c=d",
                    "Content-Type:application/x-www-form-urlencoded",
                ],
                "POST /foo HTTP/1.1\r\nContent-Type: application/x-www-form-urlencoded\r\ncontent-length: 3\r\n\r\nc=d",
            ),
            (
                vec!["example.com/foo", "a=b", "Content-Type:application/json"],
                "POST /foo HTTP/1.1\r\nContent-Type: application/json\r\naccept: application/json\r\ncontent-length: 9\r\n\r\n{\"a\":\"b\"}",
            ),
            (
                vec!["example.com/foo", "a=b"],
                "POST /foo HTTP/1.1\r\ncontent-type: application/json\r\naccept: application/json\r\ncontent-length: 9\r\n\r\n{\"a\":\"b\"}",
            ),
            (
                vec!["example.com/foo", "x-a:1", "a=b"],
                "POST /foo HTTP/1.1\r\nx-a: 1\r\ncontent-type: application/json\r\naccept: application/json\r\ncontent-length: 9\r\n\r\n{\"a\":\"b\"}",
            ),
            (
                vec!["put", "example.com/foo?a=2", "x-a:1", "a:=42", "a==3"],
                "PUT /foo?a=2&a=3 HTTP/1.1\r\nx-a: 1\r\ncontent-type: application/json\r\naccept: application/json\r\ncontent-length: 8\r\n\r\n{\"a\":42}",
            ),
            (
                vec![":3000", "Cookie:foo=bar"],
                "GET / HTTP/1.1\r\nCookie: foo=bar\r\n\r\n",
            ),
            (
                vec![":/foo", "search==rama"],
                "GET /foo?search=rama HTTP/1.1\r\n\r\n",
            ),
            (
                vec!["example.com", "description='CLI HTTP client'"],
                "POST / HTTP/1.1\r\ncontent-type: application/json\r\naccept: application/json\r\ncontent-length: 35\r\n\r\n{\"description\":\"'CLI HTTP client'\"}",
            ),
            (
                vec!["example.com", "x-cfg:a=1&foo=bar&foo=baz"],
                "GET / HTTP/1.1\r\nx-cfg: a=1&foo=bar&foo=baz\r\n\r\n",
            ),
        ] {
            let mut builder = RequestArgsBuilder::new();
            for arg in args {
                builder.parse_arg(arg.to_owned());
            }
            let request = builder.build().unwrap();
            let mut w = Vec::new();
            let _ = write_http_request(&mut w, request, true, true)
                .await
                .unwrap();
            assert_eq!(String::from_utf8(w).unwrap(), expected_request_str);
        }
    }

    /// The form-preset builder encodes the body as an URL-encoded form.
    #[tokio::test]
    async fn test_request_args_builder_form_happy() {
        for (args, expected_request_str) in [(
            vec!["example.com/foo", "c=d"],
            "POST /foo HTTP/1.1\r\ncontent-type: application/x-www-form-urlencoded\r\ncontent-length: 3\r\n\r\nc=d",
        )] {
            let mut builder = RequestArgsBuilder::new_form();
            for arg in args {
                builder.parse_arg(arg.to_owned());
            }
            let request = builder.build().unwrap();
            let mut w = Vec::new();
            let _ = write_http_request(&mut w, request, true, true)
                .await
                .unwrap();
            assert_eq!(String::from_utf8(w).unwrap(), expected_request_str);
        }
    }

    /// The JSON-preset builder encodes the body as JSON.
    #[tokio::test]
    async fn test_request_args_builder_json_happy() {
        for (args, expected_request_str) in [(
            vec!["example.com/foo", "a=b"],
            "POST /foo HTTP/1.1\r\ncontent-type: application/json\r\naccept: application/json\r\ncontent-length: 9\r\n\r\n{\"a\":\"b\"}",
        )] {
            // Exercise the JSON constructor itself rather than the auto-detecting `new()`.
            let mut builder = RequestArgsBuilder::new_json();
            for arg in args {
                builder.parse_arg(arg.to_owned());
            }
            let request = builder.build().unwrap();
            let mut w = Vec::new();
            let _ = write_http_request(&mut w, request, true, true)
                .await
                .unwrap();
            assert_eq!(String::from_utf8(w).unwrap(), expected_request_str);
        }
    }

    /// Missing or invalid URLs and unparsable data items make `build` fail.
    #[tokio::test]
    async fn test_request_args_builder_error() {
        for test in [
            vec![],
            vec!["invalid url"],
            vec!["get"],
            vec!["get", "invalid url"],
            // `:=` requires valid JSON as value
            vec!["example.com", "a:={invalid"],
        ] {
            let mut builder = RequestArgsBuilder::new();
            for arg in test {
                builder.parse_arg(arg.to_owned());
            }
            let request = builder.build();
            assert!(request.is_err());
        }
    }
}