mod_http/
lib.rs

1use config::{any_err, from_lua_value, get_or_create_sub_module, SerdeWrappedValue};
2use futures_util::StreamExt;
3use mlua::prelude::LuaUserData;
4use mlua::{Lua, LuaSerdeExt, MetaMethod, UserDataMethods, Value};
5use reqwest::header::HeaderMap;
6use reqwest::{Body, Client, ClientBuilder, RequestBuilder, Response, StatusCode, Url};
7use serde::Deserialize;
8use std::collections::HashMap;
9use std::sync::{Arc, Mutex};
10use std::time::Duration;
11use tokio::sync::Mutex as TokioMutex;
12use tokio::time::Instant;
13use tokio_tungstenite::tungstenite::Message;
14
15// Client ----
16
/// Options accepted by `http.build_client` (consumed in `register`).
///
/// Every field is optional. An absent field leaves the corresponding
/// `reqwest::ClientBuilder` setting untouched, except `timeout`, for which
/// `register` substitutes 60 seconds.
#[derive(Deserialize, Debug, Clone)]
struct ClientOptions {
    /// Value passed to `ClientBuilder::user_agent`.
    #[serde(default)]
    user_agent: Option<String>,
    /// Value passed to `ClientBuilder::connection_verbose`.
    #[serde(default)]
    connection_verbose: Option<bool>,
    /// Value passed to `ClientBuilder::pool_idle_timeout`.
    /// Deserialized through the `duration_serde` crate.
    /// NOTE(review): presumably accepts strings such as "30s"; confirm against that crate.
    #[serde(default, with = "duration_serde")]
    pool_idle_timeout: Option<Duration>,
    /// Value passed to `ClientBuilder::timeout`; `register` uses 60 seconds
    /// when this is unset. Same `duration_serde` format as above.
    #[serde(default, with = "duration_serde")]
    timeout: Option<Duration>,
    /// Value passed to `ClientBuilder::danger_accept_invalid_certs`.
    /// `true` makes the client accept invalid TLS certificates.
    #[serde(default)]
    accept_invalid_certs: Option<bool>,
}
30
31#[derive(Clone)]
32struct ClientWrapper {
33    client: Arc<Mutex<Option<Arc<Client>>>>,
34}
35
36impl ClientWrapper {
37    fn get_client(&self) -> mlua::Result<Arc<Client>> {
38        let inner = self.client.lock().unwrap();
39        inner
40            .as_ref()
41            .map(Arc::clone)
42            .ok_or_else(|| mlua::Error::external("client was closed"))
43    }
44}
45
46impl LuaUserData for ClientWrapper {
47    fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
48        methods.add_method("get", |_, this, url: String| {
49            let builder = this.get_client()?.get(url);
50            Ok(RequestWrapper::new(builder))
51        });
52        methods.add_method("post", |_, this, url: String| {
53            let builder = this.get_client()?.post(url);
54            Ok(RequestWrapper::new(builder))
55        });
56        methods.add_method("put", |_, this, url: String| {
57            let builder = this.get_client()?.put(url);
58            Ok(RequestWrapper::new(builder))
59        });
60        methods.add_method("close", |_, this, _: ()| {
61            this.client.lock().unwrap().take();
62            Ok(())
63        });
64    }
65}
66
67// Request ----
68
69#[derive(Clone)]
70struct RequestWrapper {
71    builder: Arc<Mutex<Option<RequestBuilder>>>,
72}
73
74impl RequestWrapper {
75    fn new(builder: RequestBuilder) -> Self {
76        Self {
77            builder: Arc::new(Mutex::new(Some(builder))),
78        }
79    }
80
81    fn apply<F>(&self, func: F) -> mlua::Result<()>
82    where
83        F: FnOnce(RequestBuilder) -> anyhow::Result<RequestBuilder>,
84    {
85        let b = self
86            .builder
87            .lock()
88            .unwrap()
89            .take()
90            .ok_or_else(|| mlua::Error::external("broken request builder"))?;
91
92        let b = (func)(b).map_err(any_err)?;
93
94        self.builder.lock().unwrap().replace(b);
95        Ok(())
96    }
97
98    async fn send(&self) -> mlua::Result<Response> {
99        let b = self
100            .builder
101            .lock()
102            .unwrap()
103            .take()
104            .ok_or_else(|| mlua::Error::external("broken request builder"))?;
105
106        b.send().await.map_err(any_err)
107    }
108}
109
/// A file field for `form_multipart_data`.
///
/// Any non-string value in that method's parameter table is deserialized
/// into this shape.
///
/// NOTE(review): `data` is a `String`, so the Lua value must presumably hold
/// valid UTF-8 to deserialize. Arbitrary binary file contents likely cannot
/// be sent this way; confirm whether that is intended.
#[derive(Deserialize, Clone, Hash, PartialEq, Eq, Debug)]
pub struct FilePart {
    /// File contents, sent as an `application/octet-stream` part.
    data: String,
    /// Used for the `filename` parameter of the part's Content-Disposition.
    file_name: String,
}
115
116impl LuaUserData for RequestWrapper {
117    fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
118        methods.add_method("header", |_, this, (key, value): (String, String)| {
119            this.apply(|b| Ok(b.header(key, value)))?;
120            Ok(this.clone())
121        });
122
123        methods.add_method("headers", |_, this, headers: HashMap<String, String>| {
124            for (key, value) in headers {
125                this.apply(|b| Ok(b.header(key, value)))?;
126            }
127            Ok(this.clone())
128        });
129
130        methods.add_method("timeout", |_, this, duration: Value| {
131            let duration = match duration {
132                Value::Number(n) => std::time::Duration::from_secs_f64(n),
133                Value::String(s) => {
134                    let s = s.to_str()?;
135                    humantime::parse_duration(&s).map_err(any_err)?
136                }
137                _ => {
138                    return Err(mlua::Error::external("invalid timeout duration"));
139                }
140            };
141            this.apply(|b| Ok(b.timeout(duration)))?;
142            Ok(this.clone())
143        });
144
145        methods.add_method(
146            "basic_auth",
147            |_, this, (username, password): (String, Option<String>)| {
148                this.apply(|b| Ok(b.basic_auth(username, password)))?;
149                Ok(this.clone())
150            },
151        );
152
153        methods.add_method("bearer_auth", |_, this, token: String| {
154            this.apply(|b| Ok(b.bearer_auth(token)))?;
155            Ok(this.clone())
156        });
157
158        methods.add_method("body", |_, this, body: String| {
159            this.apply(|b| Ok(b.body(body)))?;
160            Ok(this.clone())
161        });
162
163        methods.add_method(
164            "form_url_encoded",
165            |_, this, params: HashMap<String, String>| {
166                this.apply(|b| Ok(b.form(&params)))?;
167                Ok(this.clone())
168            },
169        );
170
171        methods.add_method(
172            "form_multipart_data",
173            |lua, this, params: HashMap<String, mlua::Value>| {
174                // Generate a MIME body from the provided parameters
175                use mail_builder::headers::text::Text;
176                use mail_builder::headers::HeaderType;
177                use mail_builder::mime::MimePart;
178                use mailparse::MailHeaderMap;
179                use std::borrow::Cow;
180
181                let mut data = MimePart::new_multipart("multipart/form-data", vec![]);
182
183                for (k, v) in params {
184                    match v {
185                        mlua::Value::String(s) => {
186                            let part = if let Ok(s) = s.to_str() {
187                                MimePart::new_text(Cow::Owned(s.to_string()))
188                            } else {
189                                MimePart::new_binary(
190                                    "application/octet-stream",
191                                    Cow::Owned(s.as_bytes().to_vec()),
192                                )
193                            };
194                            data.add_part(part.header(
195                                "Content-Disposition",
196                                HeaderType::Text(Text::new(format!("form-data; name=\"{k}\""))),
197                            ));
198                        }
199                        _ => {
200                            let file: FilePart = lua.from_value(v.clone())?;
201
202                            let part = MimePart::new_binary(
203                                "application/octet-stream",
204                                file.data.into_bytes(),
205                            );
206                            data.add_part(part.header(
207                                "Content-Disposition",
208                                HeaderType::Text(Text::new(format!(
209                                    "form-data; name=\"{k}\"; filename=\"{}\"",
210                                    file.file_name
211                                ))),
212                            ));
213                        }
214                    }
215                }
216                let builder = mail_builder::MessageBuilder::new();
217                let builder = builder.body(data);
218                let body = builder.write_to_vec().map_err(any_err)?;
219
220                // Now, parse out the Content-Type header so that we can set that in
221                // the request, and get the generated body with its generated boundary
222                // string into a separate variable so that we can assign it as the body
223                // of the HTTP request.
224
225                let (headers, body_offset) = mailparse::parse_headers(&body).map_err(any_err)?;
226
227                let content_type = headers
228                    .get_first_value("Content-Type")
229                    .ok_or_else(|| mlua::Error::external("missing Content-Type!?".to_string()))?;
230
231                let body = &body[body_offset..];
232
233                this.apply(|b| Ok(b.header("Content-Type", content_type).body(body.to_vec())))?;
234
235                Ok(this.clone())
236            },
237        );
238
239        methods.add_async_method("send", |_, this, _: ()| async move {
240            let response = this.send().await?;
241            let status = response.status();
242            Ok(ResponseWrapper {
243                status,
244                response: Arc::new(Mutex::new(Some(response))),
245            })
246        });
247
248        // Convenience helper to apply AWS SigV4 signing to this request
249        methods.add_async_method(
250            "aws_sign_v4",
251            |_lua, this, params: SerdeWrappedValue<mod_aws_sigv4::SigV4Request>| async move {
252                let mut signer_params = params.0;
253
254                // Clone the current RequestBuilder so we can inspect method, URI, and query params
255                let req_builder = this
256                    .builder
257                    .lock()
258                    .unwrap()
259                    .as_ref()
260                    .ok_or_else(|| mlua::Error::external("broken request builder"))?
261                    .try_clone()
262                    .ok_or_else(|| mlua::Error::external("failed to clone request builder"))?;
263
264                let req = req_builder.build().map_err(any_err)?;
265
266                // Populate method, uri and query parameters from the request
267                signer_params.method = req.method().as_str().to_string();
268                signer_params.uri = req.url().path().to_string();
269                signer_params.query_params = req
270                    .url()
271                    .query_pairs()
272                    .map(|(k, v)| (k.to_string(), v.to_string()))
273                    .collect();
274
275                // Ensure the Host header is present
276                if !signer_params.headers.contains_key("host") {
277                    if let Some(host) = req.url().host_str() {
278                        signer_params
279                            .headers
280                            .insert("host".to_string(), host.to_string());
281                    }
282                }
283
284                let sig = mod_aws_sigv4::sign_request(signer_params)
285                    .await
286                    .map_err(any_err)?;
287
288                // Apply the computed signature headers to the underlying request builder
289                this.apply(|b| {
290                    Ok(b.header("Authorization", sig.authorization)
291                        .header("X-Amz-Date", sig.timestamp))
292                })?;
293
294                Ok(this.clone())
295            },
296        );
297    }
298}
299
300// Response ----
301
302#[derive(Clone)]
303struct ResponseWrapper {
304    status: StatusCode,
305    response: Arc<Mutex<Option<Response>>>,
306}
307
308impl ResponseWrapper {
309    fn with<F, T>(&self, func: F) -> mlua::Result<T>
310    where
311        F: FnOnce(&Response) -> anyhow::Result<T>,
312    {
313        let locked = self.response.lock().unwrap();
314        let response = locked
315            .as_ref()
316            .ok_or_else(|| mlua::Error::external("broken response wrapper"))?;
317
318        (func)(response).map_err(any_err)
319    }
320
321    async fn text(&self) -> mlua::Result<String> {
322        let r = self
323            .response
324            .lock()
325            .unwrap()
326            .take()
327            .ok_or_else(|| mlua::Error::external("broken response wrapper"))?;
328
329        r.text().await.map_err(any_err)
330    }
331
332    async fn bytes(&self, lua: &Lua) -> mlua::Result<mlua::String> {
333        let r = self
334            .response
335            .lock()
336            .unwrap()
337            .take()
338            .ok_or_else(|| mlua::Error::external("broken response wrapper"))?;
339
340        let bytes = r.bytes().await.map_err(any_err)?;
341
342        lua.create_string(bytes.as_ref())
343    }
344}
345
346impl LuaUserData for ResponseWrapper {
347    fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
348        methods.add_method("status_code", |_, this, _: ()| Ok(this.status.as_u16()));
349        methods.add_method("status_reason", |_, this, _: ()| {
350            Ok(this.status.canonical_reason())
351        });
352        methods.add_method("status_is_informational", |_, this, _: ()| {
353            Ok(this.status.is_informational())
354        });
355        methods.add_method("status_is_success", |_, this, _: ()| {
356            Ok(this.status.is_success())
357        });
358        methods.add_method("status_is_redirection", |_, this, _: ()| {
359            Ok(this.status.is_redirection())
360        });
361        methods.add_method("status_is_client_error", |_, this, _: ()| {
362            Ok(this.status.is_client_error())
363        });
364        methods.add_method("status_is_server_error", |_, this, _: ()| {
365            Ok(this.status.is_server_error())
366        });
367        methods.add_method("headers", |_, this, _: ()| {
368            this.with(|response| Ok(HeaderMapWrapper(response.headers().clone())))
369        });
370        methods.add_method("content_length", |_, this, _: ()| {
371            this.with(|response| Ok(response.content_length()))
372        });
373
374        methods.add_async_method("text", |_, this, _: ()| async move { this.text().await });
375
376        methods.add_async_method(
377            "bytes",
378            |lua, this, _: ()| async move { this.bytes(&lua).await },
379        );
380    }
381}
382
383// Headermap ---
384
385#[derive(Clone, mlua::FromLua)]
386struct HeaderMapWrapper(HeaderMap);
387
388impl LuaUserData for HeaderMapWrapper {
389    fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
390        methods.add_meta_method(MetaMethod::Index, |lua, this, key: String| {
391            if let Some(value) = this.0.get(&key) {
392                let s = lua.create_string(value.as_bytes())?;
393                return Ok(Value::String(s));
394            }
395            Ok(Value::Nil)
396        });
397
398        methods.add_meta_method(MetaMethod::Pairs, |lua, this, ()| {
399            let stateless_iter =
400                lua.create_function(|lua, (this, key): (HeaderMapWrapper, Option<String>)| {
401                    let iter = this.0.iter();
402
403                    let mut this_is_key = false;
404
405                    if key.is_none() {
406                        this_is_key = true;
407                    }
408
409                    for (this_key, value) in iter {
410                        if this_is_key {
411                            let key = lua.create_string(this_key.as_str().as_bytes())?;
412                            let value = lua.create_string(value.as_bytes())?;
413
414                            return Ok(mlua::MultiValue::from_vec(vec![
415                                Value::String(key),
416                                Value::String(value),
417                            ]));
418                        }
419                        if Some(this_key.as_str()) == key.as_deref() {
420                            this_is_key = true;
421                        }
422                    }
423                    Ok(mlua::MultiValue::new())
424                })?;
425            Ok((stateless_iter, this.clone(), Value::Nil))
426        });
427    }
428}
429
430#[derive(Clone)]
431struct WebSocketStream {
432    stream: Arc<
433        TokioMutex<
434            tokio_tungstenite::WebSocketStream<
435                tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
436            >,
437        >,
438    >,
439}
440
441impl LuaUserData for WebSocketStream {
442    fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
443        methods.add_async_method("recv", |lua, this, ()| async move {
444            let maybe_msg = {
445                let mut stream = this.stream.lock().await;
446                stream.next().await
447            };
448            let msg = match maybe_msg {
449                Some(msg) => msg.map_err(any_err)?,
450                None => return Ok(None),
451            };
452            Ok(match msg {
453                Message::Text(s) => Some(lua.create_string(&s)?),
454                Message::Close(_close_frame) => {
455                    return Ok(None);
456                }
457                Message::Pong(s) | Message::Binary(s) => Some(lua.create_string(&s)?),
458                Message::Ping(_) | Message::Frame(_) => {
459                    unreachable!()
460                }
461            })
462        });
463
464        methods.add_async_method("recv_batch", |lua, this, duration| async move {
465            let duration = match duration {
466                Value::Number(n) => std::time::Duration::from_secs_f64(n),
467                Value::String(s) => {
468                    let s = s.to_str()?;
469                    humantime::parse_duration(&s).map_err(any_err)?
470                }
471                _ => {
472                    return Err(mlua::Error::external("invalid timeout duration"));
473                }
474            };
475            let deadline = Instant::now() + duration;
476            let mut messages = vec![];
477            while let Ok(maybe_msg) = tokio::time::timeout_at(deadline, async {
478                let mut stream = this.stream.lock().await;
479                stream.next().await
480            })
481            .await
482            {
483                let msg = match maybe_msg {
484                    Some(msg) => msg.map_err(any_err)?,
485                    None => {
486                        if messages.is_empty() {
487                            return Err(mlua::Error::external("websocket closed"));
488                        }
489                        break;
490                    }
491                };
492                match msg {
493                    Message::Text(s) => messages.push(lua.create_string(&s)?),
494                    Message::Close(_close_frame) => {
495                        if messages.is_empty() {
496                            return Err(mlua::Error::external("websocket closed"));
497                        }
498                        break;
499                    }
500                    Message::Pong(s) | Message::Binary(s) => messages.push(lua.create_string(&s)?),
501                    Message::Ping(_) | Message::Frame(_) => {
502                        unreachable!()
503                    }
504                }
505            }
506
507            Ok(messages)
508        });
509    }
510}
511
512pub fn register(lua: &Lua) -> anyhow::Result<()> {
513    let http_mod = get_or_create_sub_module(lua, "http")?;
514
515    http_mod.set(
516        "build_url",
517        lua.create_function(|_lua, (url, params): (String, HashMap<String, String>)| {
518            let url = Url::parse_with_params(&url, params.into_iter()).map_err(any_err)?;
519            let url: String = url.into();
520            Ok(url)
521        })?,
522    )?;
523
524    http_mod.set(
525        "build_client",
526        lua.create_function(|lua, options: Value| {
527            let options: ClientOptions = from_lua_value(lua, options)?;
528            let mut builder = ClientBuilder::new().timeout(
529                options
530                    .timeout
531                    .unwrap_or_else(|| std::time::Duration::from_secs(60)),
532            );
533
534            if let Some(verbose) = options.connection_verbose {
535                builder = builder.connection_verbose(verbose);
536            }
537
538            if let Some(idle) = options.pool_idle_timeout {
539                builder = builder.pool_idle_timeout(idle);
540            }
541
542            if let Some(user_agent) = options.user_agent {
543                builder = builder.user_agent(user_agent);
544            }
545
546            if let Some(accept_invalid_certs) = options.accept_invalid_certs {
547                builder = builder.danger_accept_invalid_certs(accept_invalid_certs)
548            }
549
550            let client = builder.build().map_err(any_err)?;
551            Ok(ClientWrapper {
552                client: Arc::new(Mutex::new(Some(Arc::new(client)))),
553            })
554        })?,
555    )?;
556
557    http_mod.set(
558        "connect_websocket",
559        lua.create_async_function(|_, url: String| async move {
560            let (stream, response) = tokio_tungstenite::connect_async(url)
561                .await
562                .map_err(any_err)?;
563            let stream = WebSocketStream {
564                stream: Arc::new(TokioMutex::new(stream)),
565            };
566
567            // Adapt the retured http::response into a reqwest::Response
568            // so that we can use our existing ResponseWrapper type with it
569            let status = response.status();
570            let (parts, body) = response.into_parts();
571            let body = Body::from(body.unwrap_or_else(std::vec::Vec::new));
572            let response = tokio_tungstenite::tungstenite::http::Response::from_parts(parts, body);
573
574            let response = ResponseWrapper {
575                status,
576                response: Arc::new(Mutex::new(Some(Response::from(response)))),
577            };
578
579            Ok((stream, response))
580        })?,
581    )?;
582
583    Ok(())
584}