mod_nats/
lib.rs

1use async_nats::jetstream::Context;
2use async_nats::{ConnectOptions, HeaderMap};
3use config::{any_err, get_or_create_sub_module, SerdeWrappedValue};
4use data_loader::KeySource;
5use mlua::prelude::LuaUserData;
6use mlua::{Lua, LuaSerdeExt, UserDataMethods, Value};
7use parking_lot::Mutex;
8use serde::Deserialize;
9use std::collections::HashMap;
10use std::path::PathBuf;
11use std::sync::Arc;
12use std::time::Duration;
13
14// https://docs.rs/async-nats/0.46.0/src/async_nats/options.rs.html#43
15#[derive(Debug, Deserialize)]
16#[serde(deny_unknown_fields)]
17struct Config {
18    servers: Vec<String>,
19    #[serde(default)]
20    auth: Option<ConfigAuth>,
21
22    name: Option<String>,
23    no_echo: Option<bool>,
24    max_reconnects: Option<usize>,
25    #[serde(default, with = "duration_serde")]
26    connection_timeout: Option<Duration>,
27    tls_required: Option<bool>,
28    tls_first: Option<bool>,
29    certificate: Option<PathBuf>,
30    client_cert: Option<PathBuf>,
31    client_key: Option<PathBuf>,
32    ping_interval: Option<Duration>,
33    client_capacity: Option<usize>,
34    inbox_prefix: Option<String>,
35    #[serde(default, with = "duration_serde")]
36    request_timeout: Option<Duration>,
37    retry_on_initial_connect: Option<bool>,
38    ignore_discovered_servers: Option<bool>,
39    retain_servers_order: Option<bool>,
40}
41
42#[derive(Debug, Deserialize)]
43#[serde(deny_unknown_fields)]
44struct ConfigAuth {
45    username: Option<KeySource>,
46    password: Option<KeySource>,
47    token: Option<KeySource>,
48}
49
50#[derive(Clone)]
51struct Client {
52    context: Arc<Mutex<Option<Arc<Context>>>>,
53}
54
55impl Client {
56    fn get_context(&self) -> mlua::Result<Arc<Context>> {
57        self.context
58            .lock_arc()
59            .as_ref()
60            .map(Arc::clone)
61            .ok_or_else(|| mlua::Error::external("client was closed"))
62    }
63}
64
65#[derive(Deserialize, Debug)]
66#[serde(deny_unknown_fields)]
67struct Message {
68    /// Required destination subject
69    subject: String,
70    /// Payload
71    #[serde(with = "serde_bytes")]
72    payload: Vec<u8>,
73    /// Optional headers
74    #[serde(default)]
75    headers: HashMap<String, String>,
76    /// Optional acknowledgement
77    #[serde(default = "default_true")]
78    await_ack: bool,
79}
80
/// Serde default provider for boolean fields that should be `true` when
/// omitted (used by `Message::await_ack`).
fn default_true() -> bool {
    true
}
84
85impl LuaUserData for Client {
86    fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
87        methods.add_async_method("publish", |lua, this, value: Value| async move {
88            let message: Message = lua.from_value(value)?;
89
90            let mut headers = HeaderMap::new();
91            for (key, v) in message.headers {
92                headers.insert(key, v);
93            }
94
95            let ack_fut = this
96                .get_context()?
97                .publish_with_headers(message.subject, headers, message.payload.into())
98                .await
99                .map_err(|err| any_err(err))?;
100
101            let ret = lua.create_table()?;
102
103            if message.await_ack {
104                let resp = ack_fut.await.map_err(|err| any_err(err))?;
105                ret.set("stream", resp.stream)?;
106                ret.set("value", resp.value.unwrap_or_default())?;
107                ret.set("duplicate", resp.duplicate)?;
108                ret.set("sequence", resp.sequence)?;
109                ret.set("domain", resp.domain)?;
110            }
111
112            Ok(ret)
113        });
114
115        methods.add_async_method("close", |_lua, this, _: ()| async move {
116            this.context.lock().take();
117
118            Ok(())
119        });
120    }
121}
122
123async fn build_client(config: Config) -> anyhow::Result<async_nats::Client> {
124    let mut opts = ConnectOptions::new();
125
126    if let Some(name) = config.name {
127        opts = opts.name(name);
128    }
129    if let Some(true) = config.no_echo {
130        opts = opts.no_echo();
131    }
132    if let Some(max_reconnects) = config.max_reconnects {
133        opts = opts.max_reconnects(max_reconnects);
134    }
135    if let Some(connection_timeout) = config.connection_timeout {
136        opts = opts.connection_timeout(connection_timeout);
137    }
138
139    if let Some(auth) = &config.auth {
140        match (&auth.username, &auth.password) {
141            (Some(username), Some(password)) => {
142                let username = String::from_utf8(username.get().await?)?;
143                let password = String::from_utf8(password.get().await?)?;
144                opts = opts.user_and_password(username, password);
145            }
146            (None, None) => {}
147            _ => {
148                anyhow::bail!("either specify both of username and password or neither");
149            }
150        }
151
152        if let Some(token) = &auth.token {
153            let token = String::from_utf8(token.get().await?)?;
154            opts = opts.token(token);
155        }
156    }
157    if let Some(tls_required) = config.tls_required {
158        opts = opts.require_tls(tls_required);
159    }
160    if let Some(true) = config.tls_first {
161        opts = opts.tls_first();
162    }
163    if let Some(certificate) = config.certificate {
164        opts = opts.add_root_certificates(certificate);
165    }
166
167    match (config.client_cert, config.client_key) {
168        (Some(client_cert), Some(client_key)) => {
169            opts = opts.add_client_certificate(client_cert, client_key);
170        }
171        (None, None) => {}
172        _ => {
173            anyhow::bail!("either specify both of client_cert and client_key or neither");
174        }
175    }
176
177    if let Some(ping_interval) = config.ping_interval {
178        opts = opts.ping_interval(ping_interval);
179    }
180    if let Some(sender_capacity) = config.client_capacity {
181        opts = opts.client_capacity(sender_capacity);
182    }
183    if let Some(inbox_prefix) = config.inbox_prefix {
184        opts = opts.custom_inbox_prefix(inbox_prefix);
185    }
186    opts = opts.request_timeout(config.request_timeout);
187
188    if let Some(true) = config.retry_on_initial_connect {
189        opts = opts.retry_on_initial_connect();
190    }
191    if let Some(true) = config.ignore_discovered_servers {
192        opts = opts.ignore_discovered_servers();
193    }
194    if let Some(true) = config.retain_servers_order {
195        opts = opts.retain_servers_order();
196    }
197
198    Ok(opts.connect(config.servers).await?)
199}
200
201pub fn register(lua: &Lua) -> anyhow::Result<()> {
202    let nats_mod = get_or_create_sub_module(lua, "nats")?;
203
204    nats_mod.set(
205        "connect",
206        lua.create_async_function(|_lua, config: SerdeWrappedValue<Config>| async move {
207            let client = build_client(config.0).await.map_err(any_err)?;
208            let context = async_nats::jetstream::new(client);
209
210            Ok(Client {
211                context: Arc::new(Mutex::new(Some(Arc::new(context)))),
212            })
213        })?,
214    )?;
215
216    Ok(())
217}