1use crate::epoch::{get_current_epoch, ConfigEpoch};
2use crate::pool::{pool_get, pool_put};
3pub use crate::pool::{set_gc_on_put, set_max_age, set_max_spare, set_max_use};
4use anyhow::Context;
5use mlua::{
6 FromLua, FromLuaMulti, IntoLua, IntoLuaMulti, Lua, LuaSerdeExt, MetaMethod, RegistryKey, Table,
7 UserData, UserDataMethods, Value,
8};
9use parking_lot::FairMutex as Mutex;
10pub use paste;
11use prometheus::{CounterVec, HistogramTimer, HistogramVec};
12use serde::Serialize;
13use std::borrow::Cow;
14use std::collections::HashSet;
15use std::path::PathBuf;
16use std::sync::atomic::{AtomicBool, Ordering};
17use std::sync::{LazyLock, Once};
18use std::time::Instant;
19
20pub mod epoch;
21mod pool;
22
23static POLICY_FILE: LazyLock<Mutex<Option<PathBuf>>> = LazyLock::new(|| Mutex::new(None));
24static FUNCS: LazyLock<Mutex<Vec<RegisterFunc>>> = LazyLock::new(|| Mutex::new(vec![]));
25static LUA_LOAD_COUNT: LazyLock<metrics::Counter> = LazyLock::new(|| {
26 metrics::describe_counter!(
27 "lua_load_count",
28 "how many times the policy lua script has been \
29 loaded into a new context"
30 );
31 metrics::counter!("lua_load_count")
32});
33static LUA_COUNT: LazyLock<metrics::Gauge> = LazyLock::new(|| {
34 metrics::describe_gauge!("lua_count", "the number of lua contexts currently alive");
35 metrics::gauge!("lua_count")
36});
37static CALLBACK_ALLOWS_MULTIPLE: LazyLock<Mutex<HashSet<String>>> =
38 LazyLock::new(|| Mutex::new(HashSet::new()));
39
40pub static VALIDATE_ONLY: AtomicBool = AtomicBool::new(false);
41pub static VALIDATION_FAILED: AtomicBool = AtomicBool::new(false);
42static LATENCY_HIST: LazyLock<HistogramVec> = LazyLock::new(|| {
43 prometheus::register_histogram_vec!(
44 "lua_event_latency",
45 "how long a given lua event callback took",
46 &["event"]
47 )
48 .unwrap()
49});
50static EVENT_STARTED_COUNT: LazyLock<CounterVec> = LazyLock::new(|| {
51 prometheus::register_counter_vec!(
52 "lua_event_started",
53 "Incremented each time we start to call a lua event callback. Use lua_event_latency_count to track completed events",
54 &["event"]
55 )
56 .unwrap()
57});
58
59pub type RegisterFunc = fn(&Lua) -> anyhow::Result<()>;
60
61fn latency_timer(label: &str) -> HistogramTimer {
62 EVENT_STARTED_COUNT
63 .get_metric_with_label_values(&[label])
64 .expect("to get counter")
65 .inc();
66 LATENCY_HIST
67 .get_metric_with_label_values(&[label])
68 .expect("to get histo")
69 .start_timer()
70}
71
/// State of one lua context plus the bookkeeping recorded when `load_config` built it.
///
/// Dropping a value decrements the `lua_count` gauge.
#[derive(Debug)]
struct LuaConfigInner {
    /// The lua interpreter state.
    lua: Lua,
    /// `Instant::now()` captured when the context was built.
    /// NOTE(review): presumably compared against the pool's max age (`set_max_age`) — confirm in pool.rs.
    created: Instant,
    /// Starts at 1 when built.
    /// NOTE(review): presumably bumped on reuse and checked against `set_max_use` by the pool — confirm.
    use_count: usize,
    /// `get_current_epoch()` at build time.
    /// NOTE(review): presumably lets the pool discard contexts from an older config epoch — confirm.
    epoch: ConfigEpoch,
}
79
80impl Drop for LuaConfigInner {
81 fn drop(&mut self) {
82 LUA_COUNT.decrement(1.);
83 }
84}
85
/// A checked-out lua context, obtained from `load_config` (freshly built or taken from the pool).
#[derive(Debug)]
pub struct LuaConfig {
    /// `Some` until `put` takes it to hand the context back to the pool;
    /// the other methods unwrap it.
    inner: Option<LuaConfigInner>,
}
90
91pub async fn set_policy_path(path: PathBuf) -> anyhow::Result<()> {
92 POLICY_FILE.lock().replace(path);
93 let config = load_config().await?;
94 config.put();
95 Ok(())
96}
97
98fn get_policy_path() -> Option<PathBuf> {
99 POLICY_FILE.lock().clone()
100}
101
102fn get_funcs() -> Vec<RegisterFunc> {
103 FUNCS.lock().clone()
104}
105pub fn is_validating() -> bool {
106 VALIDATE_ONLY.load(Ordering::Relaxed)
107}
108
109pub fn validation_failed() -> bool {
110 VALIDATION_FAILED.load(Ordering::Relaxed)
111}
112
113pub fn set_validation_failed() {
114 VALIDATION_FAILED.store(true, Ordering::Relaxed)
115}
116
117pub async fn load_config() -> anyhow::Result<LuaConfig> {
118 if let Some(pool) = pool_get() {
119 return Ok(pool);
120 }
121
122 LUA_LOAD_COUNT.increment(1);
123 let lua = Lua::new();
124 let created = Instant::now();
125 let epoch = get_current_epoch();
126
127 {
128 let globals = lua.globals();
129
130 if is_validating() {
131 globals.set("_VALIDATING_CONFIG", true)?;
132 }
133
134 let package: Table = globals.get("package")?;
135 let package_path: String = package.get("path")?;
136 let mut path_array: Vec<String> = package_path.split(";").map(|s| s.to_owned()).collect();
137
138 fn prefix_path(array: &mut Vec<String>, path: &str) {
139 array.insert(0, format!("{}/?.lua", path));
140 array.insert(1, format!("{}/?/init.lua", path));
141 }
142
143 prefix_path(&mut path_array, "/opt/kumomta/etc/policy");
144 prefix_path(&mut path_array, "/opt/kumomta/share");
145
146 #[cfg(debug_assertions)]
147 prefix_path(&mut path_array, "assets");
148
149 package.set("path", path_array.join(";"))?;
150 }
151
152 register_declared_events();
153
154 for func in get_funcs() {
155 (func)(&lua)?;
156 }
157
158 if let Some(policy) = get_policy_path() {
159 let code = tokio::fs::read_to_string(&policy)
160 .await
161 .with_context(|| format!("reading policy file {policy:?}"))?;
162
163 let func = {
164 let chunk = lua.load(&code);
165 let chunk = chunk.set_name(policy.to_string_lossy());
166 chunk.into_function()?
167 };
168
169 let _timer = latency_timer("context-creation");
170 func.call_async::<()>(()).await?;
171 }
172 LUA_COUNT.increment(1.);
173
174 Ok(LuaConfig {
175 inner: Some(LuaConfigInner {
176 lua,
177 created,
178 use_count: 1,
179 epoch,
180 }),
181 })
182}
183
184pub fn register(func: RegisterFunc) {
185 FUNCS.lock().push(func);
186}
187
188impl LuaConfig {
189 fn set_current_event(&mut self, name: &str) -> mlua::Result<()> {
190 self.inner
191 .as_mut()
192 .unwrap()
193 .lua
194 .globals()
195 .set("_KUMO_CURRENT_EVENT", name.to_string())
196 }
197
198 pub async fn convert_args_and_call_callback<A: Serialize>(
200 &mut self,
201 sig: &CallbackSignature<Value, ()>,
202 args: A,
203 ) -> anyhow::Result<()> {
204 let lua = self.inner.as_mut().unwrap();
205 let args = lua.lua.to_value(&args)?;
206
207 let name = sig.name();
208 let decorated_name = sig.decorated_name();
209
210 match lua
211 .lua
212 .named_registry_value::<mlua::Function>(&decorated_name)
213 {
214 Ok(func) => {
215 let _timer = latency_timer(name);
216 Ok(func.call_async(args).await?)
217 }
218 _ => anyhow::bail!("{name} has not been registered"),
219 }
220 }
221
222 pub fn put(mut self) {
240 if let Some(inner) = self.inner.take() {
241 pool_put(inner);
242 }
243 }
244
245 pub async fn async_call_callback<A: IntoLuaMulti + Clone, R: FromLuaMulti + Default>(
246 &mut self,
247 sig: &CallbackSignature<A, R>,
248 args: A,
249 ) -> anyhow::Result<R> {
250 let name = sig.name();
251 self.set_current_event(name)?;
252 let lua = self.inner.as_mut().unwrap();
253 async_call_callback(&lua.lua, sig, args).await
254 }
255
256 pub async fn async_call_callback_non_default<A: IntoLuaMulti + Clone, R: FromLuaMulti>(
257 &mut self,
258 sig: &CallbackSignature<A, R>,
259 args: A,
260 ) -> anyhow::Result<R> {
261 let name = sig.name();
262 self.set_current_event(name)?;
263 let lua = self.inner.as_mut().unwrap();
264 async_call_callback_non_default(&lua.lua, sig, args).await
265 }
266
267 pub async fn async_call_callback_non_default_opt<A: IntoLuaMulti + Clone, R: FromLua>(
268 &mut self,
269 sig: &CallbackSignature<A, Option<R>>,
270 args: A,
271 ) -> anyhow::Result<Option<R>> {
272 let name = sig.name();
273 let decorated_name = sig.decorated_name();
274 self.set_current_event(name)?;
275 let lua = self.inner.as_mut().unwrap();
276
277 match lua
278 .lua
279 .named_registry_value::<mlua::Value>(&decorated_name)?
280 {
281 Value::Table(tbl) => {
282 for func in tbl.sequence_values::<mlua::Function>().collect::<Vec<_>>() {
283 let func = func?;
284 let _timer = latency_timer(name);
285 let result: mlua::MultiValue = func.call_async(args.clone()).await?;
286 if result.is_empty() {
287 continue;
289 }
290 let result = R::from_lua_multi(result, &lua.lua)?;
291 return Ok(Some(result));
292 }
293 Ok(None)
294 }
295 Value::Function(func) => {
296 sig.raise_error_if_allow_multiple()?;
297 let _timer = latency_timer(name);
298 let value: Value = func.call_async(args.clone()).await?;
299
300 match value {
301 Value::Nil => Ok(None),
302 value => {
303 let result = R::from_lua(value, &lua.lua)?;
304 Ok(Some(result))
305 }
306 }
307 }
308 _ => Ok(None),
309 }
310 }
311
312 pub fn remove_registry_value(&mut self, value: RegistryKey) -> anyhow::Result<()> {
313 Ok(self
314 .inner
315 .as_mut()
316 .unwrap()
317 .lua
318 .remove_registry_value(value)?)
319 }
320
321 pub async fn async_call_ctor<A: IntoLuaMulti + Clone>(
324 &mut self,
325 sig: &CallbackSignature<A, Value>,
326 args: A,
327 ) -> anyhow::Result<RegistryKey> {
328 let name = sig.name();
329 anyhow::ensure!(
330 !sig.allow_multiple(),
331 "ctor event signature for {name} is defined as allow_multiple, which is not supported"
332 );
333
334 let decorated_name = sig.decorated_name();
335 self.set_current_event(name)?;
336
337 let inner = self.inner.as_mut().unwrap();
338
339 let func = inner
340 .lua
341 .named_registry_value::<mlua::Function>(&decorated_name)?;
342
343 let _timer = latency_timer(name);
344 let value: Value = func.call_async(args.clone()).await?;
345 drop(func);
346
347 Ok(inner.lua.create_registry_value(value)?)
348 }
349
350 pub async fn with_registry_value<F, R, FUT>(
353 &mut self,
354 value: &RegistryKey,
355 func: F,
356 ) -> anyhow::Result<R>
357 where
358 R: FromLuaMulti,
359 F: FnOnce(Value) -> anyhow::Result<FUT>,
360 FUT: std::future::Future<Output = anyhow::Result<R>>,
361 {
362 let inner = self.inner.as_mut().unwrap();
363 let value = inner.lua.registry_value(value)?;
364 let future = (func)(value)?;
365 future.await
366 }
367}
368
369pub async fn async_call_callback<A: IntoLuaMulti + Clone, R: FromLuaMulti + Default>(
370 lua: &Lua,
371 sig: &CallbackSignature<A, R>,
372 args: A,
373) -> anyhow::Result<R> {
374 let name = sig.name();
375 let decorated_name = sig.decorated_name();
376
377 match lua.named_registry_value::<mlua::Value>(&decorated_name)? {
378 Value::Table(tbl) => {
379 for func in tbl.sequence_values::<mlua::Function>().collect::<Vec<_>>() {
380 let func = func?;
381 let _timer = latency_timer(name);
382 let result: mlua::MultiValue = func.call_async(args.clone()).await?;
383 if result.is_empty() {
384 continue;
386 }
387 let result = R::from_lua_multi(result, lua)?;
388 return Ok(result);
389 }
390 Ok(R::default())
391 }
392 Value::Function(func) => {
393 sig.raise_error_if_allow_multiple()?;
394 let _timer = latency_timer(name);
395 Ok(func.call_async(args.clone()).await?)
396 }
397 _ => Ok(R::default()),
398 }
399}
400
401pub async fn async_call_callback_non_default<A: IntoLuaMulti + Clone, R: FromLuaMulti>(
402 lua: &Lua,
403 sig: &CallbackSignature<A, R>,
404 args: A,
405) -> anyhow::Result<R> {
406 let name = sig.name();
407 let decorated_name = sig.decorated_name();
408
409 match lua.named_registry_value::<mlua::Value>(&decorated_name)? {
410 Value::Table(tbl) => {
411 for func in tbl.sequence_values::<mlua::Function>().collect::<Vec<_>>() {
412 let func = func?;
413 let _timer = latency_timer(name);
414 let result: mlua::MultiValue = func.call_async(args.clone()).await?;
415 if result.is_empty() {
416 continue;
418 }
419 let result = R::from_lua_multi(result, lua)?;
420 return Ok(result);
421 }
422 anyhow::bail!("invalid return type for {name} event");
423 }
424 Value::Function(func) => {
425 sig.raise_error_if_allow_multiple()?;
426 let _timer = latency_timer(name);
427 Ok(func.call_async(args.clone()).await?)
428 }
429 _ => anyhow::bail!("Event {name} has not been registered"),
430 }
431}
432
433pub fn get_or_create_module(lua: &Lua, name: &str) -> anyhow::Result<mlua::Table> {
434 let globals = lua.globals();
435 let package: Table = globals.get("package")?;
436 let loaded: Table = package.get("loaded")?;
437
438 let module = loaded.get(name)?;
439 match module {
440 Value::Nil => {
441 let module = lua.create_table()?;
442 loaded.set(name, module.clone())?;
443 Ok(module)
444 }
445 Value::Table(table) => Ok(table),
446 wat => anyhow::bail!(
447 "cannot register module {} as package.loaded.{} is already set to a value of type {}",
448 name,
449 name,
450 wat.type_name()
451 ),
452 }
453}
454
455pub fn get_or_create_sub_module(lua: &Lua, name_path: &str) -> anyhow::Result<mlua::Table> {
460 let mut parent = get_or_create_module(lua, "kumo")?;
461 let mut path_so_far = String::new();
462
463 for name in name_path.split('.') {
464 if !path_so_far.is_empty() {
465 path_so_far.push('.');
466 }
467 path_so_far.push_str(name);
468
469 let sub = parent.get(name)?;
470 match sub {
471 Value::Nil => {
472 let sub = lua.create_table()?;
473 parent.set(name, sub.clone())?;
474 parent = sub;
475 }
476 Value::Table(sub) => {
477 parent = sub;
478 }
479 wat => anyhow::bail!(
480 "cannot register module kumo.{path_so_far} as it is already set to a value of type {}",
481 wat.type_name()
482 ),
483 }
484 }
485
486 Ok(parent)
487}
488
489pub fn any_err<E: std::fmt::Display>(err: E) -> mlua::Error {
491 mlua::Error::external(format!("{err:#}"))
492}
493
494pub fn impl_pairs_and_index<T, M>(methods: &mut M)
501where
502 T: UserData + Serialize,
503 M: UserDataMethods<T>,
504{
505 methods.add_meta_method(MetaMethod::Pairs, move |lua, this, _: ()| {
506 let Ok(serde_json::Value::Object(map)) = serde_json::to_value(this).map_err(any_err) else {
507 return Err(mlua::Error::external("must serialize to Map"));
508 };
509
510 let mut value_iter = map.into_iter();
511
512 let iter_func = lua.create_function_mut(
513 move |lua, (_state, _control): (Value, Value)| match value_iter.next() {
514 Some((key, value)) => {
515 let key = lua.to_value(&key)?;
516 let value = lua.to_value(&value)?;
517 Ok((key, value))
518 }
519 None => Ok((Value::Nil, Value::Nil)),
520 },
521 )?;
522
523 Ok((Value::Function(iter_func), Value::Nil, Value::Nil))
524 });
525
526 methods.add_meta_method(MetaMethod::Index, move |lua, this, field: Value| {
527 let value = lua.to_value(this)?;
528 match value {
529 Value::Table(t) => t.get(field),
530 _ => Ok(Value::Nil),
531 }
532 });
533
534 methods.add_meta_method(MetaMethod::Len, move |lua, this, _: ()| {
535 let value = lua.to_value(this)?;
536 match value {
537 Value::Table(v) => v.len(),
538 Value::String(v) => Ok(v.as_bytes().len() as i64),
539 _ => Ok(0),
540 }
541 });
542}
543
544pub fn materialize_to_lua_value(lua: &Lua, value: mlua::Value) -> mlua::Result<mlua::Value> {
552 match value {
553 mlua::Value::UserData(ud) => {
554 let mt = ud.metatable()?;
555 let Ok(pairs) = mt.get::<mlua::Function>("__pairs") else {
556 let value = ud.into_lua(lua)?;
557 return Err(mlua::Error::external(format!(
558 "cannot materialize_to_lua_value {value:?} \
559 because it has no __pairs metamethod"
560 )));
561 };
562 let tbl = lua.create_table()?;
563 let (iter_func, state, mut control): (mlua::Function, mlua::Value, mlua::Value) =
564 pairs.call(mlua::Value::UserData(ud.clone()))?;
565
566 loop {
567 let (k, v): (mlua::Value, mlua::Value) =
568 iter_func.call((state.clone(), control))?;
569 if k.is_nil() {
570 break;
571 }
572
573 tbl.set(k.clone(), materialize_to_lua_value(lua, v)?)?;
574 control = k;
575 }
576
577 Ok(mlua::Value::Table(tbl))
578 }
579 mlua::Value::Table(t) => {
580 let tbl = lua.create_table()?;
581 for pair in t.pairs::<mlua::Value, mlua::Value>() {
582 let (k, v) = pair?;
583 tbl.set(k.clone(), materialize_to_lua_value(lua, v)?)?;
584 }
585 Ok(mlua::Value::Table(tbl))
586 }
587 value => Ok(value),
588 }
589}
590
591pub fn from_lua_value<R>(lua: &Lua, value: mlua::Value) -> mlua::Result<R>
596where
597 R: serde::de::DeserializeOwned,
598{
599 let value_cloned = value.clone();
600 match lua.from_value(value) {
601 Ok(r) => Ok(r),
602 Err(err) => match materialize_to_lua_value(lua, value_cloned.clone()) {
603 Ok(materialized) => match lua.from_value(materialized.clone()) {
604 Ok(r) => Ok(r),
605 Err(err) => {
606 let mut serializer = serde_json::Serializer::new(Vec::new());
607 let serialized = match materialized.serialize(&mut serializer) {
608 Ok(_) => String::from_utf8_lossy(&serializer.into_inner()).to_string(),
609 Err(err) => format!("<unable to encode as json: {err:#}>"),
610 };
611 Err(mlua::Error::external(format!(
612 "{err:#}, while processing {serialized}"
613 )))
614 }
615 },
616 Err(materialize_err) => Err(mlua::Error::external(format!(
617 "{err:#}, while processing a userdata. \
618 Additionally, encountered {materialize_err:#} \
619 when trying to iterate the pairs of that userdata"
620 ))),
621 },
622 }
623}
624
625pub struct CallbackSignature<A, R>
639where
640 A: IntoLuaMulti,
641 R: FromLuaMulti,
642{
643 marker: std::marker::PhantomData<(A, R)>,
644 allow_multiple: bool,
645 name: Cow<'static, str>,
646}
647
/// Registration thunks gathered at link time (via `linkme`) from every
/// `declare_event!` invocation. `register_declared_events` calls each of them once.
#[linkme::distributed_slice]
pub static CALLBACK_SIGNATURES: [fn()];
650
/// Declares a lazily initialized `CallbackSignature` static for a lua event and
/// submits a thunk to `CALLBACK_SIGNATURES` that calls `register()` on it.
///
/// Two forms are accepted; `Multiple` permits several handlers for the event,
/// `Single` does not:
///
/// ```ignore
/// declare_event! {
///     pub static MY_EVENT: Single("my-event", msg: Message) -> bool;
/// }
/// declare_event! {
///     pub static MY_MULTI: Multiple("my-multi", a: String, b: u32) -> ();
/// }
/// ```
///
/// The parameter types are wrapped as `( $($args),* )` to form `A`. With no
/// parameters that is `()`, with exactly one it is that type itself (not a
/// 1-tuple), and with more it is a tuple. Parameter names are accepted only for
/// readability.
///
/// NOTE(review): the expansion names `linkme::distributed_slice` by its plain path,
/// so the invoking crate presumably needs `linkme` as a direct dependency — confirm.
#[macro_export]
macro_rules! declare_event {
    // Multiple-handler form.
    ($vis:vis static $sym:ident: Multiple($name:literal $(,)? $($param_name:ident: $args:ty),* $(,)? ) -> $ret:ty;) => {
        $vis static $sym: ::std::sync::LazyLock<
            $crate::CallbackSignature<($($args),*), $ret>> =
            ::std::sync::LazyLock::new(|| $crate::CallbackSignature::new_with_multiple($name));

        // Link-time registration: register_declared_events() runs this thunk.
        $crate::paste::paste! {
            #[linkme::distributed_slice($crate::CALLBACK_SIGNATURES)]
            static [<CALLBACK_SIG_REGISTER_ $sym>]: fn() = || {
                $sym.register();
            };
        }
    };
    // Single-handler form.
    ($vis:vis static $sym:ident: Single($name:literal $(,)? $($param_name:ident: $args:ty),* $(,)? ) -> $ret:ty;) => {
        $vis static $sym: ::std::sync::LazyLock<
            $crate::CallbackSignature<($($args),*), $ret>> =
            ::std::sync::LazyLock::new(|| $crate::CallbackSignature::new($name));

        // Link-time registration: register_declared_events() runs this thunk.
        $crate::paste::paste! {
            #[linkme::distributed_slice($crate::CALLBACK_SIGNATURES)]
            static [<CALLBACK_SIG_REGISTER_ $sym>]: fn() = || {
                $sym.register();
            };
        }
    };
}
705
706fn register_declared_events() {
713 static ONCE: Once = Once::new();
714 ONCE.call_once(|| {
715 for reg_func in CALLBACK_SIGNATURES {
716 reg_func();
717 }
718 });
719}
720
721impl<A, R> CallbackSignature<A, R>
722where
723 A: IntoLuaMulti,
724 R: FromLuaMulti,
725{
726 pub fn new<S: Into<Cow<'static, str>>>(name: S) -> Self {
727 let name = name.into();
728
729 Self {
730 marker: std::marker::PhantomData,
731 allow_multiple: false,
732 name,
733 }
734 }
735
736 pub fn new_with_multiple<S: Into<Cow<'static, str>>>(name: S) -> Self {
740 let name = name.into();
741
742 Self {
743 marker: std::marker::PhantomData,
744 allow_multiple: true,
745 name,
746 }
747 }
748
749 pub fn register(&self) {
750 if self.allow_multiple {
751 CALLBACK_ALLOWS_MULTIPLE
752 .lock()
753 .insert(self.name.to_string());
754 }
755 }
756
757 pub fn raise_error_if_allow_multiple(&self) -> anyhow::Result<()> {
758 anyhow::ensure!(
759 !self.allow_multiple(),
760 "handler {} is set to allow multiple handlers \
761 but is registered with a single instance. This indicates that \
762 register() was not called on the signature when initializing \
763 the lua context. Please report this issue to the KumoMTA team!",
764 self.name
765 );
766 Ok(())
767 }
768
769 pub fn allow_multiple(&self) -> bool {
772 self.allow_multiple
773 }
774
775 pub fn name(&self) -> &str {
776 &self.name
777 }
778
779 pub fn decorated_name(&self) -> String {
780 decorate_callback_name(&self.name)
781 }
782}
783
784pub fn does_callback_allow_multiple(name: &str) -> bool {
785 CALLBACK_ALLOWS_MULTIPLE.lock().contains(name)
786}
787
/// Builds the lua registry key for an event's handler(s): `kumomta-on-<name>`.
pub fn decorate_callback_name(name: &str) -> String {
    ["kumomta-on-", name].concat()
}
791
792pub fn serialize_options() -> mlua::SerializeOptions {
793 mlua::SerializeOptions::new()
794 .serialize_none_to_null(false)
795 .serialize_unit_to_null(false)
796}