mod_http/
lib.rs

1use config::{any_err, from_lua_value, get_or_create_sub_module};
2use futures_util::StreamExt;
3use mlua::prelude::LuaUserData;
4use mlua::{Lua, LuaSerdeExt, MetaMethod, UserDataMethods, Value};
5use reqwest::header::HeaderMap;
6use reqwest::{Body, Client, ClientBuilder, RequestBuilder, Response, StatusCode, Url};
7use serde::Deserialize;
8use std::collections::HashMap;
9use std::sync::{Arc, Mutex};
10use std::time::Duration;
11use tokio::sync::Mutex as TokioMutex;
12use tokio::time::Instant;
13use tokio_tungstenite::tungstenite::Message;
14
15// Client ----
16
/// Options accepted by the Lua `build_client` function (see `register`).
///
/// Every field is optional; a field absent from the options value is `None`,
/// and `build_client` then leaves the corresponding `ClientBuilder` setting
/// alone — except `timeout`, which `build_client` defaults to 60 seconds.
#[derive(Deserialize, Debug, Clone)]
struct ClientOptions {
    /// Passed to `ClientBuilder::user_agent` when set.
    #[serde(default)]
    user_agent: Option<String>,
    /// Passed to `ClientBuilder::connection_verbose` when set.
    #[serde(default)]
    connection_verbose: Option<bool>,
    /// Passed to `ClientBuilder::pool_idle_timeout` when set; decoded by the
    /// `duration_serde` helper module.
    #[serde(default, with = "duration_serde")]
    pool_idle_timeout: Option<Duration>,
    /// Overall client timeout; `build_client` substitutes 60 seconds when unset.
    /// Decoded by the `duration_serde` helper module.
    #[serde(default, with = "duration_serde")]
    timeout: Option<Duration>,
}
28
29#[derive(Clone)]
30struct ClientWrapper {
31    client: Arc<Mutex<Option<Arc<Client>>>>,
32}
33
34impl ClientWrapper {
35    fn get_client(&self) -> mlua::Result<Arc<Client>> {
36        let inner = self.client.lock().unwrap();
37        inner
38            .as_ref()
39            .map(Arc::clone)
40            .ok_or_else(|| mlua::Error::external("client was closed"))
41    }
42}
43
44impl LuaUserData for ClientWrapper {
45    fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
46        methods.add_method("get", |_, this, url: String| {
47            let builder = this.get_client()?.get(url);
48            Ok(RequestWrapper::new(builder))
49        });
50        methods.add_method("post", |_, this, url: String| {
51            let builder = this.get_client()?.post(url);
52            Ok(RequestWrapper::new(builder))
53        });
54        methods.add_method("put", |_, this, url: String| {
55            let builder = this.get_client()?.put(url);
56            Ok(RequestWrapper::new(builder))
57        });
58        methods.add_method("close", |_, this, _: ()| {
59            this.client.lock().unwrap().take();
60            Ok(())
61        });
62    }
63}
64
65// Request ----
66
67#[derive(Clone)]
68struct RequestWrapper {
69    builder: Arc<Mutex<Option<RequestBuilder>>>,
70}
71
72impl RequestWrapper {
73    fn new(builder: RequestBuilder) -> Self {
74        Self {
75            builder: Arc::new(Mutex::new(Some(builder))),
76        }
77    }
78
79    fn apply<F>(&self, func: F) -> mlua::Result<()>
80    where
81        F: FnOnce(RequestBuilder) -> anyhow::Result<RequestBuilder>,
82    {
83        let b = self
84            .builder
85            .lock()
86            .unwrap()
87            .take()
88            .ok_or_else(|| mlua::Error::external("broken request builder"))?;
89
90        let b = (func)(b).map_err(any_err)?;
91
92        self.builder.lock().unwrap().replace(b);
93        Ok(())
94    }
95
96    async fn send(&self) -> mlua::Result<Response> {
97        let b = self
98            .builder
99            .lock()
100            .unwrap()
101            .take()
102            .ok_or_else(|| mlua::Error::external("broken request builder"))?;
103
104        b.send().await.map_err(any_err)
105    }
106}
107
/// A file entry for the `form_multipart_data` request method.
///
/// Any non-string value in the form parameters is deserialized into this
/// struct. `data` becomes the part body (typed `application/octet-stream`),
/// and `file_name` goes into the `filename` parameter of the part's
/// `Content-Disposition` header.
// NOTE(review): `data` is a `String`, so presumably non-UTF-8 file contents
// fail to deserialize from Lua — confirm whether binary uploads are needed.
#[derive(Deserialize, Clone, Hash, PartialEq, Eq, Debug)]
pub struct FilePart {
    // Contents of the uploaded file.
    data: String,
    // Name reported for the file in the `Content-Disposition` header.
    file_name: String,
}
113
114impl LuaUserData for RequestWrapper {
115    fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
116        methods.add_method("header", |_, this, (key, value): (String, String)| {
117            this.apply(|b| Ok(b.header(key, value)))?;
118            Ok(this.clone())
119        });
120
121        methods.add_method("headers", |_, this, headers: HashMap<String, String>| {
122            for (key, value) in headers {
123                this.apply(|b| Ok(b.header(key, value)))?;
124            }
125            Ok(this.clone())
126        });
127
128        methods.add_method("timeout", |_, this, duration: Value| {
129            let duration = match duration {
130                Value::Number(n) => std::time::Duration::from_secs_f64(n),
131                Value::String(s) => {
132                    let s = s.to_str()?;
133                    humantime::parse_duration(&s).map_err(any_err)?
134                }
135                _ => {
136                    return Err(mlua::Error::external("invalid timeout duration"));
137                }
138            };
139            this.apply(|b| Ok(b.timeout(duration)))?;
140            Ok(this.clone())
141        });
142
143        methods.add_method(
144            "basic_auth",
145            |_, this, (username, password): (String, Option<String>)| {
146                this.apply(|b| Ok(b.basic_auth(username, password)))?;
147                Ok(this.clone())
148            },
149        );
150
151        methods.add_method("bearer_auth", |_, this, token: String| {
152            this.apply(|b| Ok(b.bearer_auth(token)))?;
153            Ok(this.clone())
154        });
155
156        methods.add_method("body", |_, this, body: String| {
157            this.apply(|b| Ok(b.body(body)))?;
158            Ok(this.clone())
159        });
160
161        methods.add_method(
162            "form_url_encoded",
163            |_, this, params: HashMap<String, String>| {
164                this.apply(|b| Ok(b.form(&params)))?;
165                Ok(this.clone())
166            },
167        );
168
169        methods.add_method(
170            "form_multipart_data",
171            |lua, this, params: HashMap<String, mlua::Value>| {
172                // Generate a MIME body from the provided parameters
173                use mail_builder::headers::text::Text;
174                use mail_builder::headers::HeaderType;
175                use mail_builder::mime::MimePart;
176                use mailparse::MailHeaderMap;
177                use std::borrow::Cow;
178
179                let mut data = MimePart::new_multipart("multipart/form-data", vec![]);
180
181                for (k, v) in params {
182                    match v {
183                        mlua::Value::String(s) => {
184                            let part = if let Ok(s) = s.to_str() {
185                                MimePart::new_text(Cow::Owned(s.to_string()))
186                            } else {
187                                MimePart::new_binary(
188                                    "application/octet-stream",
189                                    Cow::Owned(s.as_bytes().to_vec()),
190                                )
191                            };
192                            data.add_part(part.header(
193                                "Content-Disposition",
194                                HeaderType::Text(Text::new(format!("form-data; name=\"{k}\""))),
195                            ));
196                        }
197                        _ => {
198                            let file: FilePart = lua.from_value(v.clone())?;
199
200                            let part = MimePart::new_binary(
201                                "application/octet-stream",
202                                file.data.into_bytes(),
203                            );
204                            data.add_part(part.header(
205                                "Content-Disposition",
206                                HeaderType::Text(Text::new(format!(
207                                    "form-data; name=\"{k}\"; filename=\"{}\"",
208                                    file.file_name
209                                ))),
210                            ));
211                        }
212                    }
213                }
214                let builder = mail_builder::MessageBuilder::new();
215                let builder = builder.body(data);
216                let body = builder.write_to_vec().map_err(any_err)?;
217
218                // Now, parse out the Content-Type header so that we can set that in
219                // the request, and get the generated body with its generated boundary
220                // string into a separate variable so that we can assign it as the body
221                // of the HTTP request.
222
223                let (headers, body_offset) = mailparse::parse_headers(&body).map_err(any_err)?;
224
225                let content_type = headers
226                    .get_first_value("Content-Type")
227                    .ok_or_else(|| mlua::Error::external("missing Content-Type!?".to_string()))?;
228
229                let body = &body[body_offset..];
230
231                this.apply(|b| Ok(b.header("Content-Type", content_type).body(body.to_vec())))?;
232
233                Ok(this.clone())
234            },
235        );
236
237        methods.add_async_method("send", |_, this, _: ()| async move {
238            let response = this.send().await?;
239            let status = response.status();
240            Ok(ResponseWrapper {
241                status,
242                response: Arc::new(Mutex::new(Some(response))),
243            })
244        });
245    }
246}
247
248// Response ----
249
250#[derive(Clone)]
251struct ResponseWrapper {
252    status: StatusCode,
253    response: Arc<Mutex<Option<Response>>>,
254}
255
256impl ResponseWrapper {
257    fn with<F, T>(&self, func: F) -> mlua::Result<T>
258    where
259        F: FnOnce(&Response) -> anyhow::Result<T>,
260    {
261        let locked = self.response.lock().unwrap();
262        let response = locked
263            .as_ref()
264            .ok_or_else(|| mlua::Error::external("broken response wrapper"))?;
265
266        (func)(response).map_err(any_err)
267    }
268
269    async fn text(&self) -> mlua::Result<String> {
270        let r = self
271            .response
272            .lock()
273            .unwrap()
274            .take()
275            .ok_or_else(|| mlua::Error::external("broken response wrapper"))?;
276
277        r.text().await.map_err(any_err)
278    }
279
280    async fn bytes(&self, lua: &Lua) -> mlua::Result<mlua::String> {
281        let r = self
282            .response
283            .lock()
284            .unwrap()
285            .take()
286            .ok_or_else(|| mlua::Error::external("broken response wrapper"))?;
287
288        let bytes = r.bytes().await.map_err(any_err)?;
289
290        lua.create_string(bytes.as_ref())
291    }
292}
293
294impl LuaUserData for ResponseWrapper {
295    fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
296        methods.add_method("status_code", |_, this, _: ()| Ok(this.status.as_u16()));
297        methods.add_method("status_reason", |_, this, _: ()| {
298            Ok(this.status.canonical_reason())
299        });
300        methods.add_method("status_is_informational", |_, this, _: ()| {
301            Ok(this.status.is_informational())
302        });
303        methods.add_method("status_is_success", |_, this, _: ()| {
304            Ok(this.status.is_success())
305        });
306        methods.add_method("status_is_redirection", |_, this, _: ()| {
307            Ok(this.status.is_redirection())
308        });
309        methods.add_method("status_is_client_error", |_, this, _: ()| {
310            Ok(this.status.is_client_error())
311        });
312        methods.add_method("status_is_server_error", |_, this, _: ()| {
313            Ok(this.status.is_server_error())
314        });
315        methods.add_method("headers", |_, this, _: ()| {
316            this.with(|response| Ok(HeaderMapWrapper(response.headers().clone())))
317        });
318        methods.add_method("content_length", |_, this, _: ()| {
319            this.with(|response| Ok(response.content_length()))
320        });
321
322        methods.add_async_method("text", |_, this, _: ()| async move { this.text().await });
323
324        methods.add_async_method(
325            "bytes",
326            |lua, this, _: ()| async move { this.bytes(&lua).await },
327        );
328    }
329}
330
331// Headermap ---
332
333#[derive(Clone, mlua::FromLua)]
334struct HeaderMapWrapper(HeaderMap);
335
336impl LuaUserData for HeaderMapWrapper {
337    fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
338        methods.add_meta_method(MetaMethod::Index, |lua, this, key: String| {
339            if let Some(value) = this.0.get(&key) {
340                let s = lua.create_string(value.as_bytes())?;
341                return Ok(Value::String(s));
342            }
343            Ok(Value::Nil)
344        });
345
346        methods.add_meta_method(MetaMethod::Pairs, |lua, this, ()| {
347            let stateless_iter =
348                lua.create_function(|lua, (this, key): (HeaderMapWrapper, Option<String>)| {
349                    let iter = this.0.iter();
350
351                    let mut this_is_key = false;
352
353                    if key.is_none() {
354                        this_is_key = true;
355                    }
356
357                    for (this_key, value) in iter {
358                        if this_is_key {
359                            let key = lua.create_string(this_key.as_str().as_bytes())?;
360                            let value = lua.create_string(value.as_bytes())?;
361
362                            return Ok(mlua::MultiValue::from_vec(vec![
363                                Value::String(key),
364                                Value::String(value),
365                            ]));
366                        }
367                        if Some(this_key.as_str()) == key.as_deref() {
368                            this_is_key = true;
369                        }
370                    }
371                    Ok(mlua::MultiValue::new())
372                })?;
373            Ok((stateless_iter, this.clone(), Value::Nil))
374        });
375    }
376}
377
378#[derive(Clone)]
379struct WebSocketStream {
380    stream: Arc<
381        TokioMutex<
382            tokio_tungstenite::WebSocketStream<
383                tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
384            >,
385        >,
386    >,
387}
388
389impl LuaUserData for WebSocketStream {
390    fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
391        methods.add_async_method("recv", |lua, this, ()| async move {
392            let maybe_msg = {
393                let mut stream = this.stream.lock().await;
394                stream.next().await
395            };
396            let msg = match maybe_msg {
397                Some(msg) => msg.map_err(any_err)?,
398                None => return Ok(None),
399            };
400            Ok(match msg {
401                Message::Text(s) => Some(lua.create_string(&s)?),
402                Message::Close(_close_frame) => {
403                    return Ok(None);
404                }
405                Message::Pong(s) | Message::Binary(s) => Some(lua.create_string(&s)?),
406                Message::Ping(_) | Message::Frame(_) => {
407                    unreachable!()
408                }
409            })
410        });
411
412        methods.add_async_method("recv_batch", |lua, this, duration| async move {
413            let duration = match duration {
414                Value::Number(n) => std::time::Duration::from_secs_f64(n),
415                Value::String(s) => {
416                    let s = s.to_str()?;
417                    humantime::parse_duration(&s).map_err(any_err)?
418                }
419                _ => {
420                    return Err(mlua::Error::external("invalid timeout duration"));
421                }
422            };
423            let deadline = Instant::now() + duration;
424            let mut messages = vec![];
425            while let Ok(maybe_msg) = tokio::time::timeout_at(deadline, async {
426                let mut stream = this.stream.lock().await;
427                stream.next().await
428            })
429            .await
430            {
431                let msg = match maybe_msg {
432                    Some(msg) => msg.map_err(any_err)?,
433                    None => {
434                        if messages.is_empty() {
435                            return Err(mlua::Error::external("websocket closed"));
436                        }
437                        break;
438                    }
439                };
440                match msg {
441                    Message::Text(s) => messages.push(lua.create_string(&s)?),
442                    Message::Close(_close_frame) => {
443                        if messages.is_empty() {
444                            return Err(mlua::Error::external("websocket closed"));
445                        }
446                        break;
447                    }
448                    Message::Pong(s) | Message::Binary(s) => messages.push(lua.create_string(&s)?),
449                    Message::Ping(_) | Message::Frame(_) => {
450                        unreachable!()
451                    }
452                }
453            }
454
455            Ok(messages)
456        });
457    }
458}
459
460pub fn register(lua: &Lua) -> anyhow::Result<()> {
461    let http_mod = get_or_create_sub_module(lua, "http")?;
462
463    http_mod.set(
464        "build_url",
465        lua.create_function(|_lua, (url, params): (String, HashMap<String, String>)| {
466            let url = Url::parse_with_params(&url, params.into_iter()).map_err(any_err)?;
467            let url: String = url.into();
468            Ok(url)
469        })?,
470    )?;
471
472    http_mod.set(
473        "build_client",
474        lua.create_function(|lua, options: Value| {
475            let options: ClientOptions = from_lua_value(lua, options)?;
476            let mut builder = ClientBuilder::new().timeout(
477                options
478                    .timeout
479                    .unwrap_or_else(|| std::time::Duration::from_secs(60)),
480            );
481
482            if let Some(verbose) = options.connection_verbose {
483                builder = builder.connection_verbose(verbose);
484            }
485
486            if let Some(idle) = options.pool_idle_timeout {
487                builder = builder.pool_idle_timeout(idle);
488            }
489
490            if let Some(user_agent) = options.user_agent {
491                builder = builder.user_agent(user_agent);
492            }
493
494            let client = builder.build().map_err(any_err)?;
495            Ok(ClientWrapper {
496                client: Arc::new(Mutex::new(Some(Arc::new(client)))),
497            })
498        })?,
499    )?;
500
501    http_mod.set(
502        "connect_websocket",
503        lua.create_async_function(|_, url: String| async move {
504            let (stream, response) = tokio_tungstenite::connect_async(url)
505                .await
506                .map_err(any_err)?;
507            let stream = WebSocketStream {
508                stream: Arc::new(TokioMutex::new(stream)),
509            };
510
511            // Adapt the retured http::response into a reqwest::Response
512            // so that we can use our existing ResponseWrapper type with it
513            let status = response.status();
514            let (parts, body) = response.into_parts();
515            let body = Body::from(body.unwrap_or_else(std::vec::Vec::new));
516            let response = tokio_tungstenite::tungstenite::http::Response::from_parts(parts, body);
517
518            let response = ResponseWrapper {
519                status,
520                response: Arc::new(Mutex::new(Some(Response::from(response)))),
521            };
522
523            Ok((stream, response))
524        })?,
525    )?;
526
527    Ok(())
528}