1use config::{any_err, from_lua_value, get_or_create_sub_module};
2use futures_util::StreamExt;
3use mlua::prelude::LuaUserData;
4use mlua::{Lua, LuaSerdeExt, MetaMethod, UserDataMethods, Value};
5use reqwest::header::HeaderMap;
6use reqwest::{Body, Client, ClientBuilder, RequestBuilder, Response, StatusCode, Url};
7use serde::Deserialize;
8use std::collections::HashMap;
9use std::sync::{Arc, Mutex};
10use std::time::Duration;
11use tokio::sync::Mutex as TokioMutex;
12use tokio::time::Instant;
13use tokio_tungstenite::tungstenite::Message;
14
/// Options accepted by the `http` module's `build_client` function.
///
/// Deserialized (via `from_lua_value`) from the single Lua value passed to
/// `build_client`; every field is optional and is only applied to the
/// `ClientBuilder` when present.
#[derive(Deserialize, Debug, Clone)]
struct ClientOptions {
    /// Passed to `ClientBuilder::user_agent` when set.
    #[serde(default)]
    user_agent: Option<String>,
    /// Passed to `ClientBuilder::connection_verbose` when set.
    #[serde(default)]
    connection_verbose: Option<bool>,
    /// Passed to `ClientBuilder::pool_idle_timeout` when set.
    /// Decoded by the `duration_serde` helper; NOTE(review): presumably this
    /// accepts human-readable strings such as "30s" — confirm in that crate.
    #[serde(default, with = "duration_serde")]
    pool_idle_timeout: Option<Duration>,
    /// Request timeout for the client; `build_client` uses 60 seconds when
    /// this is absent.
    #[serde(default, with = "duration_serde")]
    timeout: Option<Duration>,
}
28
29#[derive(Clone)]
30struct ClientWrapper {
31 client: Arc<Mutex<Option<Arc<Client>>>>,
32}
33
34impl ClientWrapper {
35 fn get_client(&self) -> mlua::Result<Arc<Client>> {
36 let inner = self.client.lock().unwrap();
37 inner
38 .as_ref()
39 .map(Arc::clone)
40 .ok_or_else(|| mlua::Error::external("client was closed"))
41 }
42}
43
44impl LuaUserData for ClientWrapper {
45 fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
46 methods.add_method("get", |_, this, url: String| {
47 let builder = this.get_client()?.get(url);
48 Ok(RequestWrapper::new(builder))
49 });
50 methods.add_method("post", |_, this, url: String| {
51 let builder = this.get_client()?.post(url);
52 Ok(RequestWrapper::new(builder))
53 });
54 methods.add_method("put", |_, this, url: String| {
55 let builder = this.get_client()?.put(url);
56 Ok(RequestWrapper::new(builder))
57 });
58 methods.add_method("close", |_, this, _: ()| {
59 this.client.lock().unwrap().take();
60 Ok(())
61 });
62 }
63}
64
65#[derive(Clone)]
68struct RequestWrapper {
69 builder: Arc<Mutex<Option<RequestBuilder>>>,
70}
71
72impl RequestWrapper {
73 fn new(builder: RequestBuilder) -> Self {
74 Self {
75 builder: Arc::new(Mutex::new(Some(builder))),
76 }
77 }
78
79 fn apply<F>(&self, func: F) -> mlua::Result<()>
80 where
81 F: FnOnce(RequestBuilder) -> anyhow::Result<RequestBuilder>,
82 {
83 let b = self
84 .builder
85 .lock()
86 .unwrap()
87 .take()
88 .ok_or_else(|| mlua::Error::external("broken request builder"))?;
89
90 let b = (func)(b).map_err(any_err)?;
91
92 self.builder.lock().unwrap().replace(b);
93 Ok(())
94 }
95
96 async fn send(&self) -> mlua::Result<Response> {
97 let b = self
98 .builder
99 .lock()
100 .unwrap()
101 .take()
102 .ok_or_else(|| mlua::Error::external("broken request builder"))?;
103
104 b.send().await.map_err(any_err)
105 }
106}
107
/// A file attachment for the request method `form_multipart_data`.
///
/// Any non-string value in the params table given to `form_multipart_data` is
/// deserialized into this type and emitted as an `application/octet-stream`
/// part.
#[derive(Deserialize, Clone, Hash, PartialEq, Eq, Debug)]
pub struct FilePart {
    /// Part body. Because this is a `String`, the Lua value presumably has to
    /// be valid UTF-8; NOTE(review): binary payloads may need a byte-oriented
    /// type here.
    data: String,
    /// Placed in the `filename` parameter of the part's
    /// `Content-Disposition` header.
    file_name: String,
}
113
114impl LuaUserData for RequestWrapper {
115 fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
116 methods.add_method("header", |_, this, (key, value): (String, String)| {
117 this.apply(|b| Ok(b.header(key, value)))?;
118 Ok(this.clone())
119 });
120
121 methods.add_method("headers", |_, this, headers: HashMap<String, String>| {
122 for (key, value) in headers {
123 this.apply(|b| Ok(b.header(key, value)))?;
124 }
125 Ok(this.clone())
126 });
127
128 methods.add_method("timeout", |_, this, duration: Value| {
129 let duration = match duration {
130 Value::Number(n) => std::time::Duration::from_secs_f64(n),
131 Value::String(s) => {
132 let s = s.to_str()?;
133 humantime::parse_duration(&s).map_err(any_err)?
134 }
135 _ => {
136 return Err(mlua::Error::external("invalid timeout duration"));
137 }
138 };
139 this.apply(|b| Ok(b.timeout(duration)))?;
140 Ok(this.clone())
141 });
142
143 methods.add_method(
144 "basic_auth",
145 |_, this, (username, password): (String, Option<String>)| {
146 this.apply(|b| Ok(b.basic_auth(username, password)))?;
147 Ok(this.clone())
148 },
149 );
150
151 methods.add_method("bearer_auth", |_, this, token: String| {
152 this.apply(|b| Ok(b.bearer_auth(token)))?;
153 Ok(this.clone())
154 });
155
156 methods.add_method("body", |_, this, body: String| {
157 this.apply(|b| Ok(b.body(body)))?;
158 Ok(this.clone())
159 });
160
161 methods.add_method(
162 "form_url_encoded",
163 |_, this, params: HashMap<String, String>| {
164 this.apply(|b| Ok(b.form(¶ms)))?;
165 Ok(this.clone())
166 },
167 );
168
169 methods.add_method(
170 "form_multipart_data",
171 |lua, this, params: HashMap<String, mlua::Value>| {
172 use mail_builder::headers::text::Text;
174 use mail_builder::headers::HeaderType;
175 use mail_builder::mime::MimePart;
176 use mailparse::MailHeaderMap;
177 use std::borrow::Cow;
178
179 let mut data = MimePart::new_multipart("multipart/form-data", vec![]);
180
181 for (k, v) in params {
182 match v {
183 mlua::Value::String(s) => {
184 let part = if let Ok(s) = s.to_str() {
185 MimePart::new_text(Cow::Owned(s.to_string()))
186 } else {
187 MimePart::new_binary(
188 "application/octet-stream",
189 Cow::Owned(s.as_bytes().to_vec()),
190 )
191 };
192 data.add_part(part.header(
193 "Content-Disposition",
194 HeaderType::Text(Text::new(format!("form-data; name=\"{k}\""))),
195 ));
196 }
197 _ => {
198 let file: FilePart = lua.from_value(v.clone())?;
199
200 let part = MimePart::new_binary(
201 "application/octet-stream",
202 file.data.into_bytes(),
203 );
204 data.add_part(part.header(
205 "Content-Disposition",
206 HeaderType::Text(Text::new(format!(
207 "form-data; name=\"{k}\"; filename=\"{}\"",
208 file.file_name
209 ))),
210 ));
211 }
212 }
213 }
214 let builder = mail_builder::MessageBuilder::new();
215 let builder = builder.body(data);
216 let body = builder.write_to_vec().map_err(any_err)?;
217
218 let (headers, body_offset) = mailparse::parse_headers(&body).map_err(any_err)?;
224
225 let content_type = headers
226 .get_first_value("Content-Type")
227 .ok_or_else(|| mlua::Error::external("missing Content-Type!?".to_string()))?;
228
229 let body = &body[body_offset..];
230
231 this.apply(|b| Ok(b.header("Content-Type", content_type).body(body.to_vec())))?;
232
233 Ok(this.clone())
234 },
235 );
236
237 methods.add_async_method("send", |_, this, _: ()| async move {
238 let response = this.send().await?;
239 let status = response.status();
240 Ok(ResponseWrapper {
241 status,
242 response: Arc::new(Mutex::new(Some(response))),
243 })
244 });
245 }
246}
247
248#[derive(Clone)]
251struct ResponseWrapper {
252 status: StatusCode,
253 response: Arc<Mutex<Option<Response>>>,
254}
255
256impl ResponseWrapper {
257 fn with<F, T>(&self, func: F) -> mlua::Result<T>
258 where
259 F: FnOnce(&Response) -> anyhow::Result<T>,
260 {
261 let locked = self.response.lock().unwrap();
262 let response = locked
263 .as_ref()
264 .ok_or_else(|| mlua::Error::external("broken response wrapper"))?;
265
266 (func)(response).map_err(any_err)
267 }
268
269 async fn text(&self) -> mlua::Result<String> {
270 let r = self
271 .response
272 .lock()
273 .unwrap()
274 .take()
275 .ok_or_else(|| mlua::Error::external("broken response wrapper"))?;
276
277 r.text().await.map_err(any_err)
278 }
279
280 async fn bytes(&self, lua: &Lua) -> mlua::Result<mlua::String> {
281 let r = self
282 .response
283 .lock()
284 .unwrap()
285 .take()
286 .ok_or_else(|| mlua::Error::external("broken response wrapper"))?;
287
288 let bytes = r.bytes().await.map_err(any_err)?;
289
290 lua.create_string(bytes.as_ref())
291 }
292}
293
294impl LuaUserData for ResponseWrapper {
295 fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
296 methods.add_method("status_code", |_, this, _: ()| Ok(this.status.as_u16()));
297 methods.add_method("status_reason", |_, this, _: ()| {
298 Ok(this.status.canonical_reason())
299 });
300 methods.add_method("status_is_informational", |_, this, _: ()| {
301 Ok(this.status.is_informational())
302 });
303 methods.add_method("status_is_success", |_, this, _: ()| {
304 Ok(this.status.is_success())
305 });
306 methods.add_method("status_is_redirection", |_, this, _: ()| {
307 Ok(this.status.is_redirection())
308 });
309 methods.add_method("status_is_client_error", |_, this, _: ()| {
310 Ok(this.status.is_client_error())
311 });
312 methods.add_method("status_is_server_error", |_, this, _: ()| {
313 Ok(this.status.is_server_error())
314 });
315 methods.add_method("headers", |_, this, _: ()| {
316 this.with(|response| Ok(HeaderMapWrapper(response.headers().clone())))
317 });
318 methods.add_method("content_length", |_, this, _: ()| {
319 this.with(|response| Ok(response.content_length()))
320 });
321
322 methods.add_async_method("text", |_, this, _: ()| async move { this.text().await });
323
324 methods.add_async_method(
325 "bytes",
326 |lua, this, _: ()| async move { this.bytes(&lua).await },
327 );
328 }
329}
330
331#[derive(Clone, mlua::FromLua)]
334struct HeaderMapWrapper(HeaderMap);
335
336impl LuaUserData for HeaderMapWrapper {
337 fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
338 methods.add_meta_method(MetaMethod::Index, |lua, this, key: String| {
339 if let Some(value) = this.0.get(&key) {
340 let s = lua.create_string(value.as_bytes())?;
341 return Ok(Value::String(s));
342 }
343 Ok(Value::Nil)
344 });
345
346 methods.add_meta_method(MetaMethod::Pairs, |lua, this, ()| {
347 let stateless_iter =
348 lua.create_function(|lua, (this, key): (HeaderMapWrapper, Option<String>)| {
349 let iter = this.0.iter();
350
351 let mut this_is_key = false;
352
353 if key.is_none() {
354 this_is_key = true;
355 }
356
357 for (this_key, value) in iter {
358 if this_is_key {
359 let key = lua.create_string(this_key.as_str().as_bytes())?;
360 let value = lua.create_string(value.as_bytes())?;
361
362 return Ok(mlua::MultiValue::from_vec(vec![
363 Value::String(key),
364 Value::String(value),
365 ]));
366 }
367 if Some(this_key.as_str()) == key.as_deref() {
368 this_is_key = true;
369 }
370 }
371 Ok(mlua::MultiValue::new())
372 })?;
373 Ok((stateless_iter, this.clone(), Value::Nil))
374 });
375 }
376}
377
378#[derive(Clone)]
379struct WebSocketStream {
380 stream: Arc<
381 TokioMutex<
382 tokio_tungstenite::WebSocketStream<
383 tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
384 >,
385 >,
386 >,
387}
388
389impl LuaUserData for WebSocketStream {
390 fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
391 methods.add_async_method("recv", |lua, this, ()| async move {
392 let maybe_msg = {
393 let mut stream = this.stream.lock().await;
394 stream.next().await
395 };
396 let msg = match maybe_msg {
397 Some(msg) => msg.map_err(any_err)?,
398 None => return Ok(None),
399 };
400 Ok(match msg {
401 Message::Text(s) => Some(lua.create_string(&s)?),
402 Message::Close(_close_frame) => {
403 return Ok(None);
404 }
405 Message::Pong(s) | Message::Binary(s) => Some(lua.create_string(&s)?),
406 Message::Ping(_) | Message::Frame(_) => {
407 unreachable!()
408 }
409 })
410 });
411
412 methods.add_async_method("recv_batch", |lua, this, duration| async move {
413 let duration = match duration {
414 Value::Number(n) => std::time::Duration::from_secs_f64(n),
415 Value::String(s) => {
416 let s = s.to_str()?;
417 humantime::parse_duration(&s).map_err(any_err)?
418 }
419 _ => {
420 return Err(mlua::Error::external("invalid timeout duration"));
421 }
422 };
423 let deadline = Instant::now() + duration;
424 let mut messages = vec![];
425 while let Ok(maybe_msg) = tokio::time::timeout_at(deadline, async {
426 let mut stream = this.stream.lock().await;
427 stream.next().await
428 })
429 .await
430 {
431 let msg = match maybe_msg {
432 Some(msg) => msg.map_err(any_err)?,
433 None => {
434 if messages.is_empty() {
435 return Err(mlua::Error::external("websocket closed"));
436 }
437 break;
438 }
439 };
440 match msg {
441 Message::Text(s) => messages.push(lua.create_string(&s)?),
442 Message::Close(_close_frame) => {
443 if messages.is_empty() {
444 return Err(mlua::Error::external("websocket closed"));
445 }
446 break;
447 }
448 Message::Pong(s) | Message::Binary(s) => messages.push(lua.create_string(&s)?),
449 Message::Ping(_) | Message::Frame(_) => {
450 unreachable!()
451 }
452 }
453 }
454
455 Ok(messages)
456 });
457 }
458}
459
460pub fn register(lua: &Lua) -> anyhow::Result<()> {
461 let http_mod = get_or_create_sub_module(lua, "http")?;
462
463 http_mod.set(
464 "build_url",
465 lua.create_function(|_lua, (url, params): (String, HashMap<String, String>)| {
466 let url = Url::parse_with_params(&url, params.into_iter()).map_err(any_err)?;
467 let url: String = url.into();
468 Ok(url)
469 })?,
470 )?;
471
472 http_mod.set(
473 "build_client",
474 lua.create_function(|lua, options: Value| {
475 let options: ClientOptions = from_lua_value(lua, options)?;
476 let mut builder = ClientBuilder::new().timeout(
477 options
478 .timeout
479 .unwrap_or_else(|| std::time::Duration::from_secs(60)),
480 );
481
482 if let Some(verbose) = options.connection_verbose {
483 builder = builder.connection_verbose(verbose);
484 }
485
486 if let Some(idle) = options.pool_idle_timeout {
487 builder = builder.pool_idle_timeout(idle);
488 }
489
490 if let Some(user_agent) = options.user_agent {
491 builder = builder.user_agent(user_agent);
492 }
493
494 let client = builder.build().map_err(any_err)?;
495 Ok(ClientWrapper {
496 client: Arc::new(Mutex::new(Some(Arc::new(client)))),
497 })
498 })?,
499 )?;
500
501 http_mod.set(
502 "connect_websocket",
503 lua.create_async_function(|_, url: String| async move {
504 let (stream, response) = tokio_tungstenite::connect_async(url)
505 .await
506 .map_err(any_err)?;
507 let stream = WebSocketStream {
508 stream: Arc::new(TokioMutex::new(stream)),
509 };
510
511 let status = response.status();
514 let (parts, body) = response.into_parts();
515 let body = Body::from(body.unwrap_or_else(std::vec::Vec::new));
516 let response = tokio_tungstenite::tungstenite::http::Response::from_parts(parts, body);
517
518 let response = ResponseWrapper {
519 status,
520 response: Arc::new(Mutex::new(Some(Response::from(response)))),
521 };
522
523 Ok((stream, response))
524 })?,
525 )?;
526
527 Ok(())
528}