mod_http/
lib.rs

1use config::{any_err, from_lua_value, get_or_create_sub_module, SerdeWrappedValue};
2use futures_util::StreamExt;
3use mlua::prelude::LuaUserData;
4use mlua::{Lua, LuaSerdeExt, MetaMethod, UserDataMethods, Value};
5use reqwest::header::HeaderMap;
6use reqwest::{Body, Client, ClientBuilder, RequestBuilder, Response, StatusCode, Url};
7use serde::Deserialize;
8use std::collections::HashMap;
9use std::sync::{Arc, Mutex};
10use std::time::Duration;
11use tokio::sync::Mutex as TokioMutex;
12use tokio::time::Instant;
13use tokio_tungstenite::tungstenite::Message;
14
15// Client ----
16
17#[derive(Deserialize, Debug, Clone)]
18struct ClientOptions {
19    #[serde(default)]
20    user_agent: Option<String>,
21    #[serde(default)]
22    connection_verbose: Option<bool>,
23    #[serde(default, with = "duration_serde")]
24    pool_idle_timeout: Option<Duration>,
25    #[serde(default, with = "duration_serde")]
26    timeout: Option<Duration>,
27}
28
29#[derive(Clone)]
30struct ClientWrapper {
31    client: Arc<Mutex<Option<Arc<Client>>>>,
32}
33
34impl ClientWrapper {
35    fn get_client(&self) -> mlua::Result<Arc<Client>> {
36        let inner = self.client.lock().unwrap();
37        inner
38            .as_ref()
39            .map(Arc::clone)
40            .ok_or_else(|| mlua::Error::external("client was closed"))
41    }
42}
43
44impl LuaUserData for ClientWrapper {
45    fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
46        methods.add_method("get", |_, this, url: String| {
47            let builder = this.get_client()?.get(url);
48            Ok(RequestWrapper::new(builder))
49        });
50        methods.add_method("post", |_, this, url: String| {
51            let builder = this.get_client()?.post(url);
52            Ok(RequestWrapper::new(builder))
53        });
54        methods.add_method("put", |_, this, url: String| {
55            let builder = this.get_client()?.put(url);
56            Ok(RequestWrapper::new(builder))
57        });
58        methods.add_method("close", |_, this, _: ()| {
59            this.client.lock().unwrap().take();
60            Ok(())
61        });
62    }
63}
64
65// Request ----
66
67#[derive(Clone)]
68struct RequestWrapper {
69    builder: Arc<Mutex<Option<RequestBuilder>>>,
70}
71
72impl RequestWrapper {
73    fn new(builder: RequestBuilder) -> Self {
74        Self {
75            builder: Arc::new(Mutex::new(Some(builder))),
76        }
77    }
78
79    fn apply<F>(&self, func: F) -> mlua::Result<()>
80    where
81        F: FnOnce(RequestBuilder) -> anyhow::Result<RequestBuilder>,
82    {
83        let b = self
84            .builder
85            .lock()
86            .unwrap()
87            .take()
88            .ok_or_else(|| mlua::Error::external("broken request builder"))?;
89
90        let b = (func)(b).map_err(any_err)?;
91
92        self.builder.lock().unwrap().replace(b);
93        Ok(())
94    }
95
96    async fn send(&self) -> mlua::Result<Response> {
97        let b = self
98            .builder
99            .lock()
100            .unwrap()
101            .take()
102            .ok_or_else(|| mlua::Error::external("broken request builder"))?;
103
104        b.send().await.map_err(any_err)
105    }
106}
107
108#[derive(Deserialize, Clone, Hash, PartialEq, Eq, Debug)]
109pub struct FilePart {
110    data: String,
111    file_name: String,
112}
113
114impl LuaUserData for RequestWrapper {
115    fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
116        methods.add_method("header", |_, this, (key, value): (String, String)| {
117            this.apply(|b| Ok(b.header(key, value)))?;
118            Ok(this.clone())
119        });
120
121        methods.add_method("headers", |_, this, headers: HashMap<String, String>| {
122            for (key, value) in headers {
123                this.apply(|b| Ok(b.header(key, value)))?;
124            }
125            Ok(this.clone())
126        });
127
128        methods.add_method("timeout", |_, this, duration: Value| {
129            let duration = match duration {
130                Value::Number(n) => std::time::Duration::from_secs_f64(n),
131                Value::String(s) => {
132                    let s = s.to_str()?;
133                    humantime::parse_duration(&s).map_err(any_err)?
134                }
135                _ => {
136                    return Err(mlua::Error::external("invalid timeout duration"));
137                }
138            };
139            this.apply(|b| Ok(b.timeout(duration)))?;
140            Ok(this.clone())
141        });
142
143        methods.add_method(
144            "basic_auth",
145            |_, this, (username, password): (String, Option<String>)| {
146                this.apply(|b| Ok(b.basic_auth(username, password)))?;
147                Ok(this.clone())
148            },
149        );
150
151        methods.add_method("bearer_auth", |_, this, token: String| {
152            this.apply(|b| Ok(b.bearer_auth(token)))?;
153            Ok(this.clone())
154        });
155
156        methods.add_method("body", |_, this, body: String| {
157            this.apply(|b| Ok(b.body(body)))?;
158            Ok(this.clone())
159        });
160
161        methods.add_method(
162            "form_url_encoded",
163            |_, this, params: HashMap<String, String>| {
164                this.apply(|b| Ok(b.form(&params)))?;
165                Ok(this.clone())
166            },
167        );
168
169        methods.add_method(
170            "form_multipart_data",
171            |lua, this, params: HashMap<String, mlua::Value>| {
172                // Generate a MIME body from the provided parameters
173                use mail_builder::headers::text::Text;
174                use mail_builder::headers::HeaderType;
175                use mail_builder::mime::MimePart;
176                use mailparse::MailHeaderMap;
177                use std::borrow::Cow;
178
179                let mut data = MimePart::new_multipart("multipart/form-data", vec![]);
180
181                for (k, v) in params {
182                    match v {
183                        mlua::Value::String(s) => {
184                            let part = if let Ok(s) = s.to_str() {
185                                MimePart::new_text(Cow::Owned(s.to_string()))
186                            } else {
187                                MimePart::new_binary(
188                                    "application/octet-stream",
189                                    Cow::Owned(s.as_bytes().to_vec()),
190                                )
191                            };
192                            data.add_part(part.header(
193                                "Content-Disposition",
194                                HeaderType::Text(Text::new(format!("form-data; name=\"{k}\""))),
195                            ));
196                        }
197                        _ => {
198                            let file: FilePart = lua.from_value(v.clone())?;
199
200                            let part = MimePart::new_binary(
201                                "application/octet-stream",
202                                file.data.into_bytes(),
203                            );
204                            data.add_part(part.header(
205                                "Content-Disposition",
206                                HeaderType::Text(Text::new(format!(
207                                    "form-data; name=\"{k}\"; filename=\"{}\"",
208                                    file.file_name
209                                ))),
210                            ));
211                        }
212                    }
213                }
214                let builder = mail_builder::MessageBuilder::new();
215                let builder = builder.body(data);
216                let body = builder.write_to_vec().map_err(any_err)?;
217
218                // Now, parse out the Content-Type header so that we can set that in
219                // the request, and get the generated body with its generated boundary
220                // string into a separate variable so that we can assign it as the body
221                // of the HTTP request.
222
223                let (headers, body_offset) = mailparse::parse_headers(&body).map_err(any_err)?;
224
225                let content_type = headers
226                    .get_first_value("Content-Type")
227                    .ok_or_else(|| mlua::Error::external("missing Content-Type!?".to_string()))?;
228
229                let body = &body[body_offset..];
230
231                this.apply(|b| Ok(b.header("Content-Type", content_type).body(body.to_vec())))?;
232
233                Ok(this.clone())
234            },
235        );
236
237        methods.add_async_method("send", |_, this, _: ()| async move {
238            let response = this.send().await?;
239            let status = response.status();
240            Ok(ResponseWrapper {
241                status,
242                response: Arc::new(Mutex::new(Some(response))),
243            })
244        });
245
246        // Convenience helper to apply AWS SigV4 signing to this request
247        methods.add_async_method(
248            "aws_sign_v4",
249            |_lua, this, params: SerdeWrappedValue<mod_aws_sigv4::SigV4Request>| async move {
250                let mut signer_params = params.0;
251
252                // Clone the current RequestBuilder so we can inspect method, URI, and query params
253                let req_builder = this
254                    .builder
255                    .lock()
256                    .unwrap()
257                    .as_ref()
258                    .ok_or_else(|| mlua::Error::external("broken request builder"))?
259                    .try_clone()
260                    .ok_or_else(|| mlua::Error::external("failed to clone request builder"))?;
261
262                let req = req_builder.build().map_err(any_err)?;
263
264                // Populate method, uri and query parameters from the request
265                signer_params.method = req.method().as_str().to_string();
266                signer_params.uri = req.url().path().to_string();
267                signer_params.query_params = req
268                    .url()
269                    .query_pairs()
270                    .map(|(k, v)| (k.to_string(), v.to_string()))
271                    .collect();
272
273                // Ensure the Host header is present
274                if !signer_params.headers.contains_key("host") {
275                    if let Some(host) = req.url().host_str() {
276                        signer_params
277                            .headers
278                            .insert("host".to_string(), host.to_string());
279                    }
280                }
281
282                let sig = mod_aws_sigv4::sign_request(signer_params)
283                    .await
284                    .map_err(any_err)?;
285
286                // Apply the computed signature headers to the underlying request builder
287                this.apply(|b| {
288                    Ok(b.header("Authorization", sig.authorization)
289                        .header("X-Amz-Date", sig.timestamp))
290                })?;
291
292                Ok(this.clone())
293            },
294        );
295    }
296}
297
298// Response ----
299
300#[derive(Clone)]
301struct ResponseWrapper {
302    status: StatusCode,
303    response: Arc<Mutex<Option<Response>>>,
304}
305
306impl ResponseWrapper {
307    fn with<F, T>(&self, func: F) -> mlua::Result<T>
308    where
309        F: FnOnce(&Response) -> anyhow::Result<T>,
310    {
311        let locked = self.response.lock().unwrap();
312        let response = locked
313            .as_ref()
314            .ok_or_else(|| mlua::Error::external("broken response wrapper"))?;
315
316        (func)(response).map_err(any_err)
317    }
318
319    async fn text(&self) -> mlua::Result<String> {
320        let r = self
321            .response
322            .lock()
323            .unwrap()
324            .take()
325            .ok_or_else(|| mlua::Error::external("broken response wrapper"))?;
326
327        r.text().await.map_err(any_err)
328    }
329
330    async fn bytes(&self, lua: &Lua) -> mlua::Result<mlua::String> {
331        let r = self
332            .response
333            .lock()
334            .unwrap()
335            .take()
336            .ok_or_else(|| mlua::Error::external("broken response wrapper"))?;
337
338        let bytes = r.bytes().await.map_err(any_err)?;
339
340        lua.create_string(bytes.as_ref())
341    }
342}
343
344impl LuaUserData for ResponseWrapper {
345    fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
346        methods.add_method("status_code", |_, this, _: ()| Ok(this.status.as_u16()));
347        methods.add_method("status_reason", |_, this, _: ()| {
348            Ok(this.status.canonical_reason())
349        });
350        methods.add_method("status_is_informational", |_, this, _: ()| {
351            Ok(this.status.is_informational())
352        });
353        methods.add_method("status_is_success", |_, this, _: ()| {
354            Ok(this.status.is_success())
355        });
356        methods.add_method("status_is_redirection", |_, this, _: ()| {
357            Ok(this.status.is_redirection())
358        });
359        methods.add_method("status_is_client_error", |_, this, _: ()| {
360            Ok(this.status.is_client_error())
361        });
362        methods.add_method("status_is_server_error", |_, this, _: ()| {
363            Ok(this.status.is_server_error())
364        });
365        methods.add_method("headers", |_, this, _: ()| {
366            this.with(|response| Ok(HeaderMapWrapper(response.headers().clone())))
367        });
368        methods.add_method("content_length", |_, this, _: ()| {
369            this.with(|response| Ok(response.content_length()))
370        });
371
372        methods.add_async_method("text", |_, this, _: ()| async move { this.text().await });
373
374        methods.add_async_method(
375            "bytes",
376            |lua, this, _: ()| async move { this.bytes(&lua).await },
377        );
378    }
379}
380
381// Headermap ---
382
383#[derive(Clone, mlua::FromLua)]
384struct HeaderMapWrapper(HeaderMap);
385
386impl LuaUserData for HeaderMapWrapper {
387    fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
388        methods.add_meta_method(MetaMethod::Index, |lua, this, key: String| {
389            if let Some(value) = this.0.get(&key) {
390                let s = lua.create_string(value.as_bytes())?;
391                return Ok(Value::String(s));
392            }
393            Ok(Value::Nil)
394        });
395
396        methods.add_meta_method(MetaMethod::Pairs, |lua, this, ()| {
397            let stateless_iter =
398                lua.create_function(|lua, (this, key): (HeaderMapWrapper, Option<String>)| {
399                    let iter = this.0.iter();
400
401                    let mut this_is_key = false;
402
403                    if key.is_none() {
404                        this_is_key = true;
405                    }
406
407                    for (this_key, value) in iter {
408                        if this_is_key {
409                            let key = lua.create_string(this_key.as_str().as_bytes())?;
410                            let value = lua.create_string(value.as_bytes())?;
411
412                            return Ok(mlua::MultiValue::from_vec(vec![
413                                Value::String(key),
414                                Value::String(value),
415                            ]));
416                        }
417                        if Some(this_key.as_str()) == key.as_deref() {
418                            this_is_key = true;
419                        }
420                    }
421                    Ok(mlua::MultiValue::new())
422                })?;
423            Ok((stateless_iter, this.clone(), Value::Nil))
424        });
425    }
426}
427
428#[derive(Clone)]
429struct WebSocketStream {
430    stream: Arc<
431        TokioMutex<
432            tokio_tungstenite::WebSocketStream<
433                tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
434            >,
435        >,
436    >,
437}
438
439impl LuaUserData for WebSocketStream {
440    fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
441        methods.add_async_method("recv", |lua, this, ()| async move {
442            let maybe_msg = {
443                let mut stream = this.stream.lock().await;
444                stream.next().await
445            };
446            let msg = match maybe_msg {
447                Some(msg) => msg.map_err(any_err)?,
448                None => return Ok(None),
449            };
450            Ok(match msg {
451                Message::Text(s) => Some(lua.create_string(&s)?),
452                Message::Close(_close_frame) => {
453                    return Ok(None);
454                }
455                Message::Pong(s) | Message::Binary(s) => Some(lua.create_string(&s)?),
456                Message::Ping(_) | Message::Frame(_) => {
457                    unreachable!()
458                }
459            })
460        });
461
462        methods.add_async_method("recv_batch", |lua, this, duration| async move {
463            let duration = match duration {
464                Value::Number(n) => std::time::Duration::from_secs_f64(n),
465                Value::String(s) => {
466                    let s = s.to_str()?;
467                    humantime::parse_duration(&s).map_err(any_err)?
468                }
469                _ => {
470                    return Err(mlua::Error::external("invalid timeout duration"));
471                }
472            };
473            let deadline = Instant::now() + duration;
474            let mut messages = vec![];
475            while let Ok(maybe_msg) = tokio::time::timeout_at(deadline, async {
476                let mut stream = this.stream.lock().await;
477                stream.next().await
478            })
479            .await
480            {
481                let msg = match maybe_msg {
482                    Some(msg) => msg.map_err(any_err)?,
483                    None => {
484                        if messages.is_empty() {
485                            return Err(mlua::Error::external("websocket closed"));
486                        }
487                        break;
488                    }
489                };
490                match msg {
491                    Message::Text(s) => messages.push(lua.create_string(&s)?),
492                    Message::Close(_close_frame) => {
493                        if messages.is_empty() {
494                            return Err(mlua::Error::external("websocket closed"));
495                        }
496                        break;
497                    }
498                    Message::Pong(s) | Message::Binary(s) => messages.push(lua.create_string(&s)?),
499                    Message::Ping(_) | Message::Frame(_) => {
500                        unreachable!()
501                    }
502                }
503            }
504
505            Ok(messages)
506        });
507    }
508}
509
510pub fn register(lua: &Lua) -> anyhow::Result<()> {
511    let http_mod = get_or_create_sub_module(lua, "http")?;
512
513    http_mod.set(
514        "build_url",
515        lua.create_function(|_lua, (url, params): (String, HashMap<String, String>)| {
516            let url = Url::parse_with_params(&url, params.into_iter()).map_err(any_err)?;
517            let url: String = url.into();
518            Ok(url)
519        })?,
520    )?;
521
522    http_mod.set(
523        "build_client",
524        lua.create_function(|lua, options: Value| {
525            let options: ClientOptions = from_lua_value(lua, options)?;
526            let mut builder = ClientBuilder::new().timeout(
527                options
528                    .timeout
529                    .unwrap_or_else(|| std::time::Duration::from_secs(60)),
530            );
531
532            if let Some(verbose) = options.connection_verbose {
533                builder = builder.connection_verbose(verbose);
534            }
535
536            if let Some(idle) = options.pool_idle_timeout {
537                builder = builder.pool_idle_timeout(idle);
538            }
539
540            if let Some(user_agent) = options.user_agent {
541                builder = builder.user_agent(user_agent);
542            }
543
544            let client = builder.build().map_err(any_err)?;
545            Ok(ClientWrapper {
546                client: Arc::new(Mutex::new(Some(Arc::new(client)))),
547            })
548        })?,
549    )?;
550
551    http_mod.set(
552        "connect_websocket",
553        lua.create_async_function(|_, url: String| async move {
554            let (stream, response) = tokio_tungstenite::connect_async(url)
555                .await
556                .map_err(any_err)?;
557            let stream = WebSocketStream {
558                stream: Arc::new(TokioMutex::new(stream)),
559            };
560
561            // Adapt the retured http::response into a reqwest::Response
562            // so that we can use our existing ResponseWrapper type with it
563            let status = response.status();
564            let (parts, body) = response.into_parts();
565            let body = Body::from(body.unwrap_or_else(std::vec::Vec::new));
566            let response = tokio_tungstenite::tungstenite::http::Response::from_parts(parts, body);
567
568            let response = ResponseWrapper {
569                status,
570                response: Arc::new(Mutex::new(Some(Response::from(response)))),
571            };
572
573            Ok((stream, response))
574        })?,
575    )?;
576
577    Ok(())
578}