1use config::{any_err, from_lua_value, get_or_create_sub_module, SerdeWrappedValue};
2use futures_util::StreamExt;
3use mlua::prelude::LuaUserData;
4use mlua::{Lua, LuaSerdeExt, MetaMethod, UserDataMethods, Value};
5use reqwest::header::HeaderMap;
6use reqwest::{Body, Client, ClientBuilder, RequestBuilder, Response, StatusCode, Url};
7use serde::Deserialize;
8use std::collections::HashMap;
9use std::sync::{Arc, Mutex};
10use std::time::Duration;
11use tokio::sync::Mutex as TokioMutex;
12use tokio::time::Instant;
13use tokio_tungstenite::tungstenite::Message;
14
15#[derive(Deserialize, Debug, Clone)]
18struct ClientOptions {
19 #[serde(default)]
20 user_agent: Option<String>,
21 #[serde(default)]
22 connection_verbose: Option<bool>,
23 #[serde(default, with = "duration_serde")]
24 pool_idle_timeout: Option<Duration>,
25 #[serde(default, with = "duration_serde")]
26 timeout: Option<Duration>,
27 #[serde(default)]
28 accept_invalid_certs: Option<bool>,
29}
30
31#[derive(Clone)]
32struct ClientWrapper {
33 client: Arc<Mutex<Option<Arc<Client>>>>,
34}
35
36impl ClientWrapper {
37 fn get_client(&self) -> mlua::Result<Arc<Client>> {
38 let inner = self.client.lock().unwrap();
39 inner
40 .as_ref()
41 .map(Arc::clone)
42 .ok_or_else(|| mlua::Error::external("client was closed"))
43 }
44}
45
46impl LuaUserData for ClientWrapper {
47 fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
48 methods.add_method("get", |_, this, url: String| {
49 let builder = this.get_client()?.get(url);
50 Ok(RequestWrapper::new(builder))
51 });
52 methods.add_method("post", |_, this, url: String| {
53 let builder = this.get_client()?.post(url);
54 Ok(RequestWrapper::new(builder))
55 });
56 methods.add_method("put", |_, this, url: String| {
57 let builder = this.get_client()?.put(url);
58 Ok(RequestWrapper::new(builder))
59 });
60 methods.add_method("close", |_, this, _: ()| {
61 this.client.lock().unwrap().take();
62 Ok(())
63 });
64 }
65}
66
67#[derive(Clone)]
70struct RequestWrapper {
71 builder: Arc<Mutex<Option<RequestBuilder>>>,
72}
73
74impl RequestWrapper {
75 fn new(builder: RequestBuilder) -> Self {
76 Self {
77 builder: Arc::new(Mutex::new(Some(builder))),
78 }
79 }
80
81 fn apply<F>(&self, func: F) -> mlua::Result<()>
82 where
83 F: FnOnce(RequestBuilder) -> anyhow::Result<RequestBuilder>,
84 {
85 let b = self
86 .builder
87 .lock()
88 .unwrap()
89 .take()
90 .ok_or_else(|| mlua::Error::external("broken request builder"))?;
91
92 let b = (func)(b).map_err(any_err)?;
93
94 self.builder.lock().unwrap().replace(b);
95 Ok(())
96 }
97
98 async fn send(&self) -> mlua::Result<Response> {
99 let b = self
100 .builder
101 .lock()
102 .unwrap()
103 .take()
104 .ok_or_else(|| mlua::Error::external("broken request builder"))?;
105
106 b.send().await.map_err(any_err)
107 }
108}
109
110#[derive(Deserialize, Clone, Hash, PartialEq, Eq, Debug)]
111pub struct FilePart {
112 data: String,
113 file_name: String,
114}
115
116impl LuaUserData for RequestWrapper {
117 fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
118 methods.add_method("header", |_, this, (key, value): (String, String)| {
119 this.apply(|b| Ok(b.header(key, value)))?;
120 Ok(this.clone())
121 });
122
123 methods.add_method("headers", |_, this, headers: HashMap<String, String>| {
124 for (key, value) in headers {
125 this.apply(|b| Ok(b.header(key, value)))?;
126 }
127 Ok(this.clone())
128 });
129
130 methods.add_method("timeout", |_, this, duration: Value| {
131 let duration = match duration {
132 Value::Number(n) => std::time::Duration::from_secs_f64(n),
133 Value::String(s) => {
134 let s = s.to_str()?;
135 humantime::parse_duration(&s).map_err(any_err)?
136 }
137 _ => {
138 return Err(mlua::Error::external("invalid timeout duration"));
139 }
140 };
141 this.apply(|b| Ok(b.timeout(duration)))?;
142 Ok(this.clone())
143 });
144
145 methods.add_method(
146 "basic_auth",
147 |_, this, (username, password): (String, Option<String>)| {
148 this.apply(|b| Ok(b.basic_auth(username, password)))?;
149 Ok(this.clone())
150 },
151 );
152
153 methods.add_method("bearer_auth", |_, this, token: String| {
154 this.apply(|b| Ok(b.bearer_auth(token)))?;
155 Ok(this.clone())
156 });
157
158 methods.add_method("body", |_, this, body: String| {
159 this.apply(|b| Ok(b.body(body)))?;
160 Ok(this.clone())
161 });
162
163 methods.add_method(
164 "form_url_encoded",
165 |_, this, params: HashMap<String, String>| {
166 this.apply(|b| Ok(b.form(¶ms)))?;
167 Ok(this.clone())
168 },
169 );
170
171 methods.add_method(
172 "form_multipart_data",
173 |lua, this, params: HashMap<String, mlua::Value>| {
174 use mail_builder::headers::text::Text;
176 use mail_builder::headers::HeaderType;
177 use mail_builder::mime::MimePart;
178 use mailparse::MailHeaderMap;
179 use std::borrow::Cow;
180
181 let mut data = MimePart::new_multipart("multipart/form-data", vec![]);
182
183 for (k, v) in params {
184 match v {
185 mlua::Value::String(s) => {
186 let part = if let Ok(s) = s.to_str() {
187 MimePart::new_text(Cow::Owned(s.to_string()))
188 } else {
189 MimePart::new_binary(
190 "application/octet-stream",
191 Cow::Owned(s.as_bytes().to_vec()),
192 )
193 };
194 data.add_part(part.header(
195 "Content-Disposition",
196 HeaderType::Text(Text::new(format!("form-data; name=\"{k}\""))),
197 ));
198 }
199 _ => {
200 let file: FilePart = lua.from_value(v.clone())?;
201
202 let part = MimePart::new_binary(
203 "application/octet-stream",
204 file.data.into_bytes(),
205 );
206 data.add_part(part.header(
207 "Content-Disposition",
208 HeaderType::Text(Text::new(format!(
209 "form-data; name=\"{k}\"; filename=\"{}\"",
210 file.file_name
211 ))),
212 ));
213 }
214 }
215 }
216 let builder = mail_builder::MessageBuilder::new();
217 let builder = builder.body(data);
218 let body = builder.write_to_vec().map_err(any_err)?;
219
220 let (headers, body_offset) = mailparse::parse_headers(&body).map_err(any_err)?;
226
227 let content_type = headers
228 .get_first_value("Content-Type")
229 .ok_or_else(|| mlua::Error::external("missing Content-Type!?".to_string()))?;
230
231 let body = &body[body_offset..];
232
233 this.apply(|b| Ok(b.header("Content-Type", content_type).body(body.to_vec())))?;
234
235 Ok(this.clone())
236 },
237 );
238
239 methods.add_async_method("send", |_, this, _: ()| async move {
240 let response = this.send().await?;
241 let status = response.status();
242 Ok(ResponseWrapper {
243 status,
244 response: Arc::new(Mutex::new(Some(response))),
245 })
246 });
247
248 methods.add_async_method(
250 "aws_sign_v4",
251 |_lua, this, params: SerdeWrappedValue<mod_aws_sigv4::SigV4Request>| async move {
252 let mut signer_params = params.0;
253
254 let req_builder = this
256 .builder
257 .lock()
258 .unwrap()
259 .as_ref()
260 .ok_or_else(|| mlua::Error::external("broken request builder"))?
261 .try_clone()
262 .ok_or_else(|| mlua::Error::external("failed to clone request builder"))?;
263
264 let req = req_builder.build().map_err(any_err)?;
265
266 signer_params.method = req.method().as_str().to_string();
268 signer_params.uri = req.url().path().to_string();
269 signer_params.query_params = req
270 .url()
271 .query_pairs()
272 .map(|(k, v)| (k.to_string(), v.to_string()))
273 .collect();
274
275 if !signer_params.headers.contains_key("host") {
277 if let Some(host) = req.url().host_str() {
278 signer_params
279 .headers
280 .insert("host".to_string(), host.to_string());
281 }
282 }
283
284 let sig = mod_aws_sigv4::sign_request(signer_params)
285 .await
286 .map_err(any_err)?;
287
288 this.apply(|b| {
290 Ok(b.header("Authorization", sig.authorization)
291 .header("X-Amz-Date", sig.timestamp))
292 })?;
293
294 Ok(this.clone())
295 },
296 );
297 }
298}
299
300#[derive(Clone)]
303struct ResponseWrapper {
304 status: StatusCode,
305 response: Arc<Mutex<Option<Response>>>,
306}
307
308impl ResponseWrapper {
309 fn with<F, T>(&self, func: F) -> mlua::Result<T>
310 where
311 F: FnOnce(&Response) -> anyhow::Result<T>,
312 {
313 let locked = self.response.lock().unwrap();
314 let response = locked
315 .as_ref()
316 .ok_or_else(|| mlua::Error::external("broken response wrapper"))?;
317
318 (func)(response).map_err(any_err)
319 }
320
321 async fn text(&self) -> mlua::Result<String> {
322 let r = self
323 .response
324 .lock()
325 .unwrap()
326 .take()
327 .ok_or_else(|| mlua::Error::external("broken response wrapper"))?;
328
329 r.text().await.map_err(any_err)
330 }
331
332 async fn bytes(&self, lua: &Lua) -> mlua::Result<mlua::String> {
333 let r = self
334 .response
335 .lock()
336 .unwrap()
337 .take()
338 .ok_or_else(|| mlua::Error::external("broken response wrapper"))?;
339
340 let bytes = r.bytes().await.map_err(any_err)?;
341
342 lua.create_string(bytes.as_ref())
343 }
344}
345
346impl LuaUserData for ResponseWrapper {
347 fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
348 methods.add_method("status_code", |_, this, _: ()| Ok(this.status.as_u16()));
349 methods.add_method("status_reason", |_, this, _: ()| {
350 Ok(this.status.canonical_reason())
351 });
352 methods.add_method("status_is_informational", |_, this, _: ()| {
353 Ok(this.status.is_informational())
354 });
355 methods.add_method("status_is_success", |_, this, _: ()| {
356 Ok(this.status.is_success())
357 });
358 methods.add_method("status_is_redirection", |_, this, _: ()| {
359 Ok(this.status.is_redirection())
360 });
361 methods.add_method("status_is_client_error", |_, this, _: ()| {
362 Ok(this.status.is_client_error())
363 });
364 methods.add_method("status_is_server_error", |_, this, _: ()| {
365 Ok(this.status.is_server_error())
366 });
367 methods.add_method("headers", |_, this, _: ()| {
368 this.with(|response| Ok(HeaderMapWrapper(response.headers().clone())))
369 });
370 methods.add_method("content_length", |_, this, _: ()| {
371 this.with(|response| Ok(response.content_length()))
372 });
373
374 methods.add_async_method("text", |_, this, _: ()| async move { this.text().await });
375
376 methods.add_async_method(
377 "bytes",
378 |lua, this, _: ()| async move { this.bytes(&lua).await },
379 );
380 }
381}
382
383#[derive(Clone, mlua::FromLua)]
386struct HeaderMapWrapper(HeaderMap);
387
388impl LuaUserData for HeaderMapWrapper {
389 fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
390 methods.add_meta_method(MetaMethod::Index, |lua, this, key: String| {
391 if let Some(value) = this.0.get(&key) {
392 let s = lua.create_string(value.as_bytes())?;
393 return Ok(Value::String(s));
394 }
395 Ok(Value::Nil)
396 });
397
398 methods.add_meta_method(MetaMethod::Pairs, |lua, this, ()| {
399 let stateless_iter =
400 lua.create_function(|lua, (this, key): (HeaderMapWrapper, Option<String>)| {
401 let iter = this.0.iter();
402
403 let mut this_is_key = false;
404
405 if key.is_none() {
406 this_is_key = true;
407 }
408
409 for (this_key, value) in iter {
410 if this_is_key {
411 let key = lua.create_string(this_key.as_str().as_bytes())?;
412 let value = lua.create_string(value.as_bytes())?;
413
414 return Ok(mlua::MultiValue::from_vec(vec![
415 Value::String(key),
416 Value::String(value),
417 ]));
418 }
419 if Some(this_key.as_str()) == key.as_deref() {
420 this_is_key = true;
421 }
422 }
423 Ok(mlua::MultiValue::new())
424 })?;
425 Ok((stateless_iter, this.clone(), Value::Nil))
426 });
427 }
428}
429
430#[derive(Clone)]
431struct WebSocketStream {
432 stream: Arc<
433 TokioMutex<
434 tokio_tungstenite::WebSocketStream<
435 tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
436 >,
437 >,
438 >,
439}
440
441impl LuaUserData for WebSocketStream {
442 fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
443 methods.add_async_method("recv", |lua, this, ()| async move {
444 let maybe_msg = {
445 let mut stream = this.stream.lock().await;
446 stream.next().await
447 };
448 let msg = match maybe_msg {
449 Some(msg) => msg.map_err(any_err)?,
450 None => return Ok(None),
451 };
452 Ok(match msg {
453 Message::Text(s) => Some(lua.create_string(&s)?),
454 Message::Close(_close_frame) => {
455 return Ok(None);
456 }
457 Message::Pong(s) | Message::Binary(s) => Some(lua.create_string(&s)?),
458 Message::Ping(_) | Message::Frame(_) => {
459 unreachable!()
460 }
461 })
462 });
463
464 methods.add_async_method("recv_batch", |lua, this, duration| async move {
465 let duration = match duration {
466 Value::Number(n) => std::time::Duration::from_secs_f64(n),
467 Value::String(s) => {
468 let s = s.to_str()?;
469 humantime::parse_duration(&s).map_err(any_err)?
470 }
471 _ => {
472 return Err(mlua::Error::external("invalid timeout duration"));
473 }
474 };
475 let deadline = Instant::now() + duration;
476 let mut messages = vec![];
477 while let Ok(maybe_msg) = tokio::time::timeout_at(deadline, async {
478 let mut stream = this.stream.lock().await;
479 stream.next().await
480 })
481 .await
482 {
483 let msg = match maybe_msg {
484 Some(msg) => msg.map_err(any_err)?,
485 None => {
486 if messages.is_empty() {
487 return Err(mlua::Error::external("websocket closed"));
488 }
489 break;
490 }
491 };
492 match msg {
493 Message::Text(s) => messages.push(lua.create_string(&s)?),
494 Message::Close(_close_frame) => {
495 if messages.is_empty() {
496 return Err(mlua::Error::external("websocket closed"));
497 }
498 break;
499 }
500 Message::Pong(s) | Message::Binary(s) => messages.push(lua.create_string(&s)?),
501 Message::Ping(_) | Message::Frame(_) => {
502 unreachable!()
503 }
504 }
505 }
506
507 Ok(messages)
508 });
509 }
510}
511
512pub fn register(lua: &Lua) -> anyhow::Result<()> {
513 let http_mod = get_or_create_sub_module(lua, "http")?;
514
515 http_mod.set(
516 "build_url",
517 lua.create_function(|_lua, (url, params): (String, HashMap<String, String>)| {
518 let url = Url::parse_with_params(&url, params.into_iter()).map_err(any_err)?;
519 let url: String = url.into();
520 Ok(url)
521 })?,
522 )?;
523
524 http_mod.set(
525 "build_client",
526 lua.create_function(|lua, options: Value| {
527 let options: ClientOptions = from_lua_value(lua, options)?;
528 let mut builder = ClientBuilder::new().timeout(
529 options
530 .timeout
531 .unwrap_or_else(|| std::time::Duration::from_secs(60)),
532 );
533
534 if let Some(verbose) = options.connection_verbose {
535 builder = builder.connection_verbose(verbose);
536 }
537
538 if let Some(idle) = options.pool_idle_timeout {
539 builder = builder.pool_idle_timeout(idle);
540 }
541
542 if let Some(user_agent) = options.user_agent {
543 builder = builder.user_agent(user_agent);
544 }
545
546 if let Some(accept_invalid_certs) = options.accept_invalid_certs {
547 builder = builder.danger_accept_invalid_certs(accept_invalid_certs)
548 }
549
550 let client = builder.build().map_err(any_err)?;
551 Ok(ClientWrapper {
552 client: Arc::new(Mutex::new(Some(Arc::new(client)))),
553 })
554 })?,
555 )?;
556
557 http_mod.set(
558 "connect_websocket",
559 lua.create_async_function(|_, url: String| async move {
560 let (stream, response) = tokio_tungstenite::connect_async(url)
561 .await
562 .map_err(any_err)?;
563 let stream = WebSocketStream {
564 stream: Arc::new(TokioMutex::new(stream)),
565 };
566
567 let status = response.status();
570 let (parts, body) = response.into_parts();
571 let body = Body::from(body.unwrap_or_else(std::vec::Vec::new));
572 let response = tokio_tungstenite::tungstenite::http::Response::from_parts(parts, body);
573
574 let response = ResponseWrapper {
575 status,
576 response: Arc::new(Mutex::new(Some(Response::from(response)))),
577 };
578
579 Ok((stream, response))
580 })?,
581 )?;
582
583 Ok(())
584}