1use async_nats::jetstream::Context;
2use async_nats::{ConnectOptions, HeaderMap};
3use config::{any_err, get_or_create_sub_module, SerdeWrappedValue};
4use data_loader::KeySource;
5use mlua::prelude::LuaUserData;
6use mlua::{Lua, LuaSerdeExt, UserDataMethods, Value};
7use parking_lot::Mutex;
8use serde::Deserialize;
9use std::collections::HashMap;
10use std::path::PathBuf;
11use std::sync::Arc;
12use std::time::Duration;
13
14#[derive(Debug, Deserialize)]
16#[serde(deny_unknown_fields)]
17struct Config {
18 servers: Vec<String>,
19 #[serde(default)]
20 auth: Option<ConfigAuth>,
21
22 name: Option<String>,
23 no_echo: Option<bool>,
24 max_reconnects: Option<usize>,
25 #[serde(default, with = "duration_serde")]
26 connection_timeout: Option<Duration>,
27 tls_required: Option<bool>,
28 tls_first: Option<bool>,
29 certificate: Option<PathBuf>,
30 client_cert: Option<PathBuf>,
31 client_key: Option<PathBuf>,
32 ping_interval: Option<Duration>,
33 client_capacity: Option<usize>,
34 inbox_prefix: Option<String>,
35 #[serde(default, with = "duration_serde")]
36 request_timeout: Option<Duration>,
37 retry_on_initial_connect: Option<bool>,
38 ignore_discovered_servers: Option<bool>,
39 retain_servers_order: Option<bool>,
40}
41
42#[derive(Debug, Deserialize)]
43#[serde(deny_unknown_fields)]
44struct ConfigAuth {
45 username: Option<KeySource>,
46 password: Option<KeySource>,
47 token: Option<KeySource>,
48}
49
50#[derive(Clone)]
51struct Client {
52 context: Arc<Mutex<Option<Arc<Context>>>>,
53}
54
55impl Client {
56 fn get_context(&self) -> mlua::Result<Arc<Context>> {
57 self.context
58 .lock_arc()
59 .as_ref()
60 .map(Arc::clone)
61 .ok_or_else(|| mlua::Error::external("client was closed"))
62 }
63}
64
65#[derive(Deserialize, Debug)]
66#[serde(deny_unknown_fields)]
67struct Message {
68 subject: String,
70 #[serde(with = "serde_bytes")]
72 payload: Vec<u8>,
73 #[serde(default)]
75 headers: HashMap<String, String>,
76 #[serde(default = "default_true")]
78 await_ack: bool,
79}
80
81fn default_true() -> bool {
82 true
83}
84
85impl LuaUserData for Client {
86 fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
87 methods.add_async_method("publish", |lua, this, value: Value| async move {
88 let message: Message = lua.from_value(value)?;
89
90 let mut headers = HeaderMap::new();
91 for (key, v) in message.headers {
92 headers.insert(key, v);
93 }
94
95 let ack_fut = this
96 .get_context()?
97 .publish_with_headers(message.subject, headers, message.payload.into())
98 .await
99 .map_err(|err| any_err(err))?;
100
101 let ret = lua.create_table()?;
102
103 if message.await_ack {
104 let resp = ack_fut.await.map_err(|err| any_err(err))?;
105 ret.set("stream", resp.stream)?;
106 ret.set("value", resp.value.unwrap_or_default())?;
107 ret.set("duplicate", resp.duplicate)?;
108 ret.set("sequence", resp.sequence)?;
109 ret.set("domain", resp.domain)?;
110 }
111
112 Ok(ret)
113 });
114
115 methods.add_async_method("close", |_lua, this, _: ()| async move {
116 this.context.lock().take();
117
118 Ok(())
119 });
120 }
121}
122
123async fn build_client(config: Config) -> anyhow::Result<async_nats::Client> {
124 let mut opts = ConnectOptions::new();
125
126 if let Some(name) = config.name {
127 opts = opts.name(name);
128 }
129 if let Some(true) = config.no_echo {
130 opts = opts.no_echo();
131 }
132 if let Some(max_reconnects) = config.max_reconnects {
133 opts = opts.max_reconnects(max_reconnects);
134 }
135 if let Some(connection_timeout) = config.connection_timeout {
136 opts = opts.connection_timeout(connection_timeout);
137 }
138
139 if let Some(auth) = &config.auth {
140 match (&auth.username, &auth.password) {
141 (Some(username), Some(password)) => {
142 let username = String::from_utf8(username.get().await?)?;
143 let password = String::from_utf8(password.get().await?)?;
144 opts = opts.user_and_password(username, password);
145 }
146 (None, None) => {}
147 _ => {
148 anyhow::bail!("either specify both of username and password or neither");
149 }
150 }
151
152 if let Some(token) = &auth.token {
153 let token = String::from_utf8(token.get().await?)?;
154 opts = opts.token(token);
155 }
156 }
157 if let Some(tls_required) = config.tls_required {
158 opts = opts.require_tls(tls_required);
159 }
160 if let Some(true) = config.tls_first {
161 opts = opts.tls_first();
162 }
163 if let Some(certificate) = config.certificate {
164 opts = opts.add_root_certificates(certificate);
165 }
166
167 match (config.client_cert, config.client_key) {
168 (Some(client_cert), Some(client_key)) => {
169 opts = opts.add_client_certificate(client_cert, client_key);
170 }
171 (None, None) => {}
172 _ => {
173 anyhow::bail!("either specify both of client_cert and client_key or neither");
174 }
175 }
176
177 if let Some(ping_interval) = config.ping_interval {
178 opts = opts.ping_interval(ping_interval);
179 }
180 if let Some(sender_capacity) = config.client_capacity {
181 opts = opts.client_capacity(sender_capacity);
182 }
183 if let Some(inbox_prefix) = config.inbox_prefix {
184 opts = opts.custom_inbox_prefix(inbox_prefix);
185 }
186 opts = opts.request_timeout(config.request_timeout);
187
188 if let Some(true) = config.retry_on_initial_connect {
189 opts = opts.retry_on_initial_connect();
190 }
191 if let Some(true) = config.ignore_discovered_servers {
192 opts = opts.ignore_discovered_servers();
193 }
194 if let Some(true) = config.retain_servers_order {
195 opts = opts.retain_servers_order();
196 }
197
198 Ok(opts.connect(config.servers).await?)
199}
200
201pub fn register(lua: &Lua) -> anyhow::Result<()> {
202 let nats_mod = get_or_create_sub_module(lua, "nats")?;
203
204 nats_mod.set(
205 "connect",
206 lua.create_async_function(|_lua, config: SerdeWrappedValue<Config>| async move {
207 let client = build_client(config.0).await.map_err(any_err)?;
208 let context = async_nats::jetstream::new(client);
209
210 Ok(Client {
211 context: Arc::new(Mutex::new(Some(Arc::new(context)))),
212 })
213 })?,
214 )?;
215
216 Ok(())
217}