1use config::{any_err, from_lua_value, get_or_create_sub_module, SerdeWrappedValue};
2use futures_util::StreamExt;
3use mlua::prelude::LuaUserData;
4use mlua::{Lua, LuaSerdeExt, MetaMethod, UserDataMethods, Value};
5use reqwest::header::HeaderMap;
6use reqwest::{Body, Client, ClientBuilder, RequestBuilder, Response, StatusCode, Url};
7use serde::Deserialize;
8use std::collections::HashMap;
9use std::sync::{Arc, Mutex};
10use std::time::Duration;
11use tokio::sync::Mutex as TokioMutex;
12use tokio::time::Instant;
13use tokio_tungstenite::tungstenite::Message;
14
15#[derive(Deserialize, Debug, Clone)]
18struct ClientOptions {
19 #[serde(default)]
20 user_agent: Option<String>,
21 #[serde(default)]
22 connection_verbose: Option<bool>,
23 #[serde(default, with = "duration_serde")]
24 pool_idle_timeout: Option<Duration>,
25 #[serde(default, with = "duration_serde")]
26 timeout: Option<Duration>,
27}
28
29#[derive(Clone)]
30struct ClientWrapper {
31 client: Arc<Mutex<Option<Arc<Client>>>>,
32}
33
34impl ClientWrapper {
35 fn get_client(&self) -> mlua::Result<Arc<Client>> {
36 let inner = self.client.lock().unwrap();
37 inner
38 .as_ref()
39 .map(Arc::clone)
40 .ok_or_else(|| mlua::Error::external("client was closed"))
41 }
42}
43
44impl LuaUserData for ClientWrapper {
45 fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
46 methods.add_method("get", |_, this, url: String| {
47 let builder = this.get_client()?.get(url);
48 Ok(RequestWrapper::new(builder))
49 });
50 methods.add_method("post", |_, this, url: String| {
51 let builder = this.get_client()?.post(url);
52 Ok(RequestWrapper::new(builder))
53 });
54 methods.add_method("put", |_, this, url: String| {
55 let builder = this.get_client()?.put(url);
56 Ok(RequestWrapper::new(builder))
57 });
58 methods.add_method("close", |_, this, _: ()| {
59 this.client.lock().unwrap().take();
60 Ok(())
61 });
62 }
63}
64
65#[derive(Clone)]
68struct RequestWrapper {
69 builder: Arc<Mutex<Option<RequestBuilder>>>,
70}
71
72impl RequestWrapper {
73 fn new(builder: RequestBuilder) -> Self {
74 Self {
75 builder: Arc::new(Mutex::new(Some(builder))),
76 }
77 }
78
79 fn apply<F>(&self, func: F) -> mlua::Result<()>
80 where
81 F: FnOnce(RequestBuilder) -> anyhow::Result<RequestBuilder>,
82 {
83 let b = self
84 .builder
85 .lock()
86 .unwrap()
87 .take()
88 .ok_or_else(|| mlua::Error::external("broken request builder"))?;
89
90 let b = (func)(b).map_err(any_err)?;
91
92 self.builder.lock().unwrap().replace(b);
93 Ok(())
94 }
95
96 async fn send(&self) -> mlua::Result<Response> {
97 let b = self
98 .builder
99 .lock()
100 .unwrap()
101 .take()
102 .ok_or_else(|| mlua::Error::external("broken request builder"))?;
103
104 b.send().await.map_err(any_err)
105 }
106}
107
108#[derive(Deserialize, Clone, Hash, PartialEq, Eq, Debug)]
109pub struct FilePart {
110 data: String,
111 file_name: String,
112}
113
114impl LuaUserData for RequestWrapper {
115 fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
116 methods.add_method("header", |_, this, (key, value): (String, String)| {
117 this.apply(|b| Ok(b.header(key, value)))?;
118 Ok(this.clone())
119 });
120
121 methods.add_method("headers", |_, this, headers: HashMap<String, String>| {
122 for (key, value) in headers {
123 this.apply(|b| Ok(b.header(key, value)))?;
124 }
125 Ok(this.clone())
126 });
127
128 methods.add_method("timeout", |_, this, duration: Value| {
129 let duration = match duration {
130 Value::Number(n) => std::time::Duration::from_secs_f64(n),
131 Value::String(s) => {
132 let s = s.to_str()?;
133 humantime::parse_duration(&s).map_err(any_err)?
134 }
135 _ => {
136 return Err(mlua::Error::external("invalid timeout duration"));
137 }
138 };
139 this.apply(|b| Ok(b.timeout(duration)))?;
140 Ok(this.clone())
141 });
142
143 methods.add_method(
144 "basic_auth",
145 |_, this, (username, password): (String, Option<String>)| {
146 this.apply(|b| Ok(b.basic_auth(username, password)))?;
147 Ok(this.clone())
148 },
149 );
150
151 methods.add_method("bearer_auth", |_, this, token: String| {
152 this.apply(|b| Ok(b.bearer_auth(token)))?;
153 Ok(this.clone())
154 });
155
156 methods.add_method("body", |_, this, body: String| {
157 this.apply(|b| Ok(b.body(body)))?;
158 Ok(this.clone())
159 });
160
161 methods.add_method(
162 "form_url_encoded",
163 |_, this, params: HashMap<String, String>| {
164 this.apply(|b| Ok(b.form(¶ms)))?;
165 Ok(this.clone())
166 },
167 );
168
169 methods.add_method(
170 "form_multipart_data",
171 |lua, this, params: HashMap<String, mlua::Value>| {
172 use mail_builder::headers::text::Text;
174 use mail_builder::headers::HeaderType;
175 use mail_builder::mime::MimePart;
176 use mailparse::MailHeaderMap;
177 use std::borrow::Cow;
178
179 let mut data = MimePart::new_multipart("multipart/form-data", vec![]);
180
181 for (k, v) in params {
182 match v {
183 mlua::Value::String(s) => {
184 let part = if let Ok(s) = s.to_str() {
185 MimePart::new_text(Cow::Owned(s.to_string()))
186 } else {
187 MimePart::new_binary(
188 "application/octet-stream",
189 Cow::Owned(s.as_bytes().to_vec()),
190 )
191 };
192 data.add_part(part.header(
193 "Content-Disposition",
194 HeaderType::Text(Text::new(format!("form-data; name=\"{k}\""))),
195 ));
196 }
197 _ => {
198 let file: FilePart = lua.from_value(v.clone())?;
199
200 let part = MimePart::new_binary(
201 "application/octet-stream",
202 file.data.into_bytes(),
203 );
204 data.add_part(part.header(
205 "Content-Disposition",
206 HeaderType::Text(Text::new(format!(
207 "form-data; name=\"{k}\"; filename=\"{}\"",
208 file.file_name
209 ))),
210 ));
211 }
212 }
213 }
214 let builder = mail_builder::MessageBuilder::new();
215 let builder = builder.body(data);
216 let body = builder.write_to_vec().map_err(any_err)?;
217
218 let (headers, body_offset) = mailparse::parse_headers(&body).map_err(any_err)?;
224
225 let content_type = headers
226 .get_first_value("Content-Type")
227 .ok_or_else(|| mlua::Error::external("missing Content-Type!?".to_string()))?;
228
229 let body = &body[body_offset..];
230
231 this.apply(|b| Ok(b.header("Content-Type", content_type).body(body.to_vec())))?;
232
233 Ok(this.clone())
234 },
235 );
236
237 methods.add_async_method("send", |_, this, _: ()| async move {
238 let response = this.send().await?;
239 let status = response.status();
240 Ok(ResponseWrapper {
241 status,
242 response: Arc::new(Mutex::new(Some(response))),
243 })
244 });
245
246 methods.add_async_method(
248 "aws_sign_v4",
249 |_lua, this, params: SerdeWrappedValue<mod_aws_sigv4::SigV4Request>| async move {
250 let mut signer_params = params.0;
251
252 let req_builder = this
254 .builder
255 .lock()
256 .unwrap()
257 .as_ref()
258 .ok_or_else(|| mlua::Error::external("broken request builder"))?
259 .try_clone()
260 .ok_or_else(|| mlua::Error::external("failed to clone request builder"))?;
261
262 let req = req_builder.build().map_err(any_err)?;
263
264 signer_params.method = req.method().as_str().to_string();
266 signer_params.uri = req.url().path().to_string();
267 signer_params.query_params = req
268 .url()
269 .query_pairs()
270 .map(|(k, v)| (k.to_string(), v.to_string()))
271 .collect();
272
273 if !signer_params.headers.contains_key("host") {
275 if let Some(host) = req.url().host_str() {
276 signer_params
277 .headers
278 .insert("host".to_string(), host.to_string());
279 }
280 }
281
282 let sig = mod_aws_sigv4::sign_request(signer_params)
283 .await
284 .map_err(any_err)?;
285
286 this.apply(|b| {
288 Ok(b.header("Authorization", sig.authorization)
289 .header("X-Amz-Date", sig.timestamp))
290 })?;
291
292 Ok(this.clone())
293 },
294 );
295 }
296}
297
298#[derive(Clone)]
301struct ResponseWrapper {
302 status: StatusCode,
303 response: Arc<Mutex<Option<Response>>>,
304}
305
306impl ResponseWrapper {
307 fn with<F, T>(&self, func: F) -> mlua::Result<T>
308 where
309 F: FnOnce(&Response) -> anyhow::Result<T>,
310 {
311 let locked = self.response.lock().unwrap();
312 let response = locked
313 .as_ref()
314 .ok_or_else(|| mlua::Error::external("broken response wrapper"))?;
315
316 (func)(response).map_err(any_err)
317 }
318
319 async fn text(&self) -> mlua::Result<String> {
320 let r = self
321 .response
322 .lock()
323 .unwrap()
324 .take()
325 .ok_or_else(|| mlua::Error::external("broken response wrapper"))?;
326
327 r.text().await.map_err(any_err)
328 }
329
330 async fn bytes(&self, lua: &Lua) -> mlua::Result<mlua::String> {
331 let r = self
332 .response
333 .lock()
334 .unwrap()
335 .take()
336 .ok_or_else(|| mlua::Error::external("broken response wrapper"))?;
337
338 let bytes = r.bytes().await.map_err(any_err)?;
339
340 lua.create_string(bytes.as_ref())
341 }
342}
343
344impl LuaUserData for ResponseWrapper {
345 fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
346 methods.add_method("status_code", |_, this, _: ()| Ok(this.status.as_u16()));
347 methods.add_method("status_reason", |_, this, _: ()| {
348 Ok(this.status.canonical_reason())
349 });
350 methods.add_method("status_is_informational", |_, this, _: ()| {
351 Ok(this.status.is_informational())
352 });
353 methods.add_method("status_is_success", |_, this, _: ()| {
354 Ok(this.status.is_success())
355 });
356 methods.add_method("status_is_redirection", |_, this, _: ()| {
357 Ok(this.status.is_redirection())
358 });
359 methods.add_method("status_is_client_error", |_, this, _: ()| {
360 Ok(this.status.is_client_error())
361 });
362 methods.add_method("status_is_server_error", |_, this, _: ()| {
363 Ok(this.status.is_server_error())
364 });
365 methods.add_method("headers", |_, this, _: ()| {
366 this.with(|response| Ok(HeaderMapWrapper(response.headers().clone())))
367 });
368 methods.add_method("content_length", |_, this, _: ()| {
369 this.with(|response| Ok(response.content_length()))
370 });
371
372 methods.add_async_method("text", |_, this, _: ()| async move { this.text().await });
373
374 methods.add_async_method(
375 "bytes",
376 |lua, this, _: ()| async move { this.bytes(&lua).await },
377 );
378 }
379}
380
381#[derive(Clone, mlua::FromLua)]
384struct HeaderMapWrapper(HeaderMap);
385
386impl LuaUserData for HeaderMapWrapper {
387 fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
388 methods.add_meta_method(MetaMethod::Index, |lua, this, key: String| {
389 if let Some(value) = this.0.get(&key) {
390 let s = lua.create_string(value.as_bytes())?;
391 return Ok(Value::String(s));
392 }
393 Ok(Value::Nil)
394 });
395
396 methods.add_meta_method(MetaMethod::Pairs, |lua, this, ()| {
397 let stateless_iter =
398 lua.create_function(|lua, (this, key): (HeaderMapWrapper, Option<String>)| {
399 let iter = this.0.iter();
400
401 let mut this_is_key = false;
402
403 if key.is_none() {
404 this_is_key = true;
405 }
406
407 for (this_key, value) in iter {
408 if this_is_key {
409 let key = lua.create_string(this_key.as_str().as_bytes())?;
410 let value = lua.create_string(value.as_bytes())?;
411
412 return Ok(mlua::MultiValue::from_vec(vec![
413 Value::String(key),
414 Value::String(value),
415 ]));
416 }
417 if Some(this_key.as_str()) == key.as_deref() {
418 this_is_key = true;
419 }
420 }
421 Ok(mlua::MultiValue::new())
422 })?;
423 Ok((stateless_iter, this.clone(), Value::Nil))
424 });
425 }
426}
427
428#[derive(Clone)]
429struct WebSocketStream {
430 stream: Arc<
431 TokioMutex<
432 tokio_tungstenite::WebSocketStream<
433 tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
434 >,
435 >,
436 >,
437}
438
439impl LuaUserData for WebSocketStream {
440 fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
441 methods.add_async_method("recv", |lua, this, ()| async move {
442 let maybe_msg = {
443 let mut stream = this.stream.lock().await;
444 stream.next().await
445 };
446 let msg = match maybe_msg {
447 Some(msg) => msg.map_err(any_err)?,
448 None => return Ok(None),
449 };
450 Ok(match msg {
451 Message::Text(s) => Some(lua.create_string(&s)?),
452 Message::Close(_close_frame) => {
453 return Ok(None);
454 }
455 Message::Pong(s) | Message::Binary(s) => Some(lua.create_string(&s)?),
456 Message::Ping(_) | Message::Frame(_) => {
457 unreachable!()
458 }
459 })
460 });
461
462 methods.add_async_method("recv_batch", |lua, this, duration| async move {
463 let duration = match duration {
464 Value::Number(n) => std::time::Duration::from_secs_f64(n),
465 Value::String(s) => {
466 let s = s.to_str()?;
467 humantime::parse_duration(&s).map_err(any_err)?
468 }
469 _ => {
470 return Err(mlua::Error::external("invalid timeout duration"));
471 }
472 };
473 let deadline = Instant::now() + duration;
474 let mut messages = vec![];
475 while let Ok(maybe_msg) = tokio::time::timeout_at(deadline, async {
476 let mut stream = this.stream.lock().await;
477 stream.next().await
478 })
479 .await
480 {
481 let msg = match maybe_msg {
482 Some(msg) => msg.map_err(any_err)?,
483 None => {
484 if messages.is_empty() {
485 return Err(mlua::Error::external("websocket closed"));
486 }
487 break;
488 }
489 };
490 match msg {
491 Message::Text(s) => messages.push(lua.create_string(&s)?),
492 Message::Close(_close_frame) => {
493 if messages.is_empty() {
494 return Err(mlua::Error::external("websocket closed"));
495 }
496 break;
497 }
498 Message::Pong(s) | Message::Binary(s) => messages.push(lua.create_string(&s)?),
499 Message::Ping(_) | Message::Frame(_) => {
500 unreachable!()
501 }
502 }
503 }
504
505 Ok(messages)
506 });
507 }
508}
509
510pub fn register(lua: &Lua) -> anyhow::Result<()> {
511 let http_mod = get_or_create_sub_module(lua, "http")?;
512
513 http_mod.set(
514 "build_url",
515 lua.create_function(|_lua, (url, params): (String, HashMap<String, String>)| {
516 let url = Url::parse_with_params(&url, params.into_iter()).map_err(any_err)?;
517 let url: String = url.into();
518 Ok(url)
519 })?,
520 )?;
521
522 http_mod.set(
523 "build_client",
524 lua.create_function(|lua, options: Value| {
525 let options: ClientOptions = from_lua_value(lua, options)?;
526 let mut builder = ClientBuilder::new().timeout(
527 options
528 .timeout
529 .unwrap_or_else(|| std::time::Duration::from_secs(60)),
530 );
531
532 if let Some(verbose) = options.connection_verbose {
533 builder = builder.connection_verbose(verbose);
534 }
535
536 if let Some(idle) = options.pool_idle_timeout {
537 builder = builder.pool_idle_timeout(idle);
538 }
539
540 if let Some(user_agent) = options.user_agent {
541 builder = builder.user_agent(user_agent);
542 }
543
544 let client = builder.build().map_err(any_err)?;
545 Ok(ClientWrapper {
546 client: Arc::new(Mutex::new(Some(Arc::new(client)))),
547 })
548 })?,
549 )?;
550
551 http_mod.set(
552 "connect_websocket",
553 lua.create_async_function(|_, url: String| async move {
554 let (stream, response) = tokio_tungstenite::connect_async(url)
555 .await
556 .map_err(any_err)?;
557 let stream = WebSocketStream {
558 stream: Arc::new(TokioMutex::new(stream)),
559 };
560
561 let status = response.status();
564 let (parts, body) = response.into_parts();
565 let body = Body::from(body.unwrap_or_else(std::vec::Vec::new));
566 let response = tokio_tungstenite::tungstenite::http::Response::from_parts(parts, body);
567
568 let response = ResponseWrapper {
569 status,
570 response: Arc::new(Mutex::new(Some(Response::from(response)))),
571 };
572
573 Ok((stream, response))
574 })?,
575 )?;
576
577 Ok(())
578}