1use crate::epoch::{get_current_epoch, ConfigEpoch};
2use crate::pool::{pool_get, pool_put};
3pub use crate::pool::{set_gc_on_put, set_max_age, set_max_spare, set_max_use};
4use anyhow::Context;
5use kumo_prometheus::declare_metric;
6use kumo_prometheus::prometheus::HistogramTimer;
7use mlua::{
8 FromLua, FromLuaMulti, IntoLua, IntoLuaMulti, Lua, LuaSerdeExt, MetaMethod, RegistryKey, Table,
9 UserData, UserDataMethods, Value,
10};
11use parking_lot::FairMutex as Mutex;
12pub use pastey as paste;
13use serde::Serialize;
14use std::borrow::Cow;
15use std::collections::HashSet;
16use std::path::PathBuf;
17use std::sync::atomic::{AtomicBool, Ordering};
18use std::sync::{LazyLock, Once};
19use std::time::Instant;
20
21pub mod epoch;
22mod pool;
23
24static POLICY_FILE: LazyLock<Mutex<Option<PathBuf>>> = LazyLock::new(|| Mutex::new(None));
25static FUNCS: LazyLock<Mutex<Vec<RegisterFunc>>> = LazyLock::new(|| Mutex::new(vec![]));
26
// Counter of brand-new lua states built by `load_config`; it is only
// incremented after `pool_get()` failed to supply a pooled context.
declare_metric! {
static LUA_LOAD_COUNT: IntCounter("lua_load_count");
}
31
// Gauge of live lua states: incremented when `load_config` finishes
// building one, decremented when a `LuaConfigInner` is dropped.
declare_metric! {
static LUA_COUNT: IntGauge("lua_count");
}
36
37static CALLBACK_ALLOWS_MULTIPLE: LazyLock<Mutex<HashSet<String>>> =
38 LazyLock::new(|| Mutex::new(HashSet::new()));
39
/// When true, `load_config` sets the lua global `_VALIDATING_CONFIG = true`
/// in each newly created lua state. Read through `is_validating`.
pub static VALIDATE_ONLY: AtomicBool = AtomicBool::new(false);
/// Flag set by `set_validation_failed` and read by `validation_failed`;
/// nothing in this file clears it once set.
pub static VALIDATION_FAILED: AtomicBool = AtomicBool::new(false);
42
// Histogram of lua event handler latency, labelled by event name.
// Timers are started by `latency_timer`.
declare_metric! {
static LATENCY_HIST: HistogramVec(
    "lua_event_latency",
    &["event"]);
}
49
// Counter of lua event dispatches, labelled by event name.
// Incremented by every `latency_timer` call.
declare_metric! {
static EVENT_STARTED_COUNT: CounterVec(
    "lua_event_started",
    &["event"]
    );
}
57
/// A function passed to [`register`]. `load_config` calls every such
/// function with each newly created lua state; an error aborts the load.
pub type RegisterFunc = fn(&Lua) -> anyhow::Result<()>;
59
60fn latency_timer(label: &str) -> HistogramTimer {
61 EVENT_STARTED_COUNT
62 .get_metric_with_label_values(&[label])
63 .expect("to get counter")
64 .inc();
65 LATENCY_HIST
66 .get_metric_with_label_values(&[label])
67 .expect("to get histo")
68 .start_timer()
69}
70
/// The state owned by a [`LuaConfig`]: the lua VM plus the bookkeeping
/// recorded when `load_config` created it.
#[derive(Debug)]
struct LuaConfigInner {
    /// The lua VM with policy and registered modules loaded.
    lua: Lua,
    /// When `load_config` created this state. Presumably consulted by the
    /// pool for age-based retirement (`set_max_age`) — not visible here.
    created: Instant,
    /// Starts at 1 in `load_config`. NOTE(review): presumably advanced by
    /// the pool on reuse (`set_max_use`) — confirm in pool.rs.
    use_count: usize,
    /// The config epoch current when this state was created.
    epoch: ConfigEpoch,
}

impl Drop for LuaConfigInner {
    fn drop(&mut self) {
        // Balances the LUA_COUNT.inc() done when load_config builds a new state.
        LUA_COUNT.dec();
    }
}
84
/// A lua context ready to dispatch policy events.
///
/// `inner` is populated by `load_config`. In this file it is only taken by
/// [`LuaConfig::put`], which hands it to the pool and consumes `self`.
#[derive(Debug)]
pub struct LuaConfig {
    inner: Option<LuaConfigInner>,
}
89
90pub async fn set_policy_path(path: PathBuf) -> anyhow::Result<()> {
91 POLICY_FILE.lock().replace(path);
92 let config = load_config().await?;
93 config.put();
94 Ok(())
95}
96
97fn get_policy_path() -> Option<PathBuf> {
98 POLICY_FILE.lock().clone()
99}
100
101fn get_funcs() -> Vec<RegisterFunc> {
102 FUNCS.lock().clone()
103}
/// Returns the current value of [`VALIDATE_ONLY`].
pub fn is_validating() -> bool {
    VALIDATE_ONLY.load(Ordering::Relaxed)
}

/// Returns whether [`set_validation_failed`] has been called.
pub fn validation_failed() -> bool {
    VALIDATION_FAILED.load(Ordering::Relaxed)
}

/// Record that validation failed; the flag is never reset in this file.
pub fn set_validation_failed() {
    VALIDATION_FAILED.store(true, Ordering::Relaxed)
}
115
116pub async fn load_config() -> anyhow::Result<LuaConfig> {
117 if let Some(pool) = pool_get() {
118 return Ok(pool);
119 }
120
121 LUA_LOAD_COUNT.inc();
122 let lua = Lua::new();
123 let created = Instant::now();
124 let epoch = get_current_epoch();
125
126 {
127 let globals = lua.globals();
128
129 if is_validating() {
130 globals.set("_VALIDATING_CONFIG", true)?;
131 }
132
133 let package: Table = globals.get("package")?;
134 let package_path: String = package.get("path")?;
135 let mut path_array: Vec<String> = package_path.split(";").map(|s| s.to_owned()).collect();
136
137 fn prefix_path(array: &mut Vec<String>, path: &str) {
138 array.insert(0, format!("{}/?.lua", path));
139 array.insert(1, format!("{}/?/init.lua", path));
140 }
141
142 prefix_path(&mut path_array, "/opt/kumomta/etc/policy");
143 prefix_path(&mut path_array, "/opt/kumomta/share");
144
145 #[cfg(debug_assertions)]
146 prefix_path(&mut path_array, "assets");
147
148 package.set("path", path_array.join(";"))?;
149 }
150
151 register_declared_events();
152
153 for func in get_funcs() {
154 (func)(&lua)?;
155 }
156
157 if let Some(policy) = get_policy_path() {
158 let code = tokio::fs::read_to_string(&policy)
159 .await
160 .with_context(|| format!("reading policy file {policy:?}"))?;
161
162 let func = {
163 let chunk = lua.load(&code);
164 let chunk = chunk.set_name(policy.to_string_lossy());
165 chunk.into_function()?
166 };
167
168 let _timer = latency_timer("context-creation");
169 func.call_async::<()>(()).await?;
170 }
171 LUA_COUNT.inc();
172
173 Ok(LuaConfig {
174 inner: Some(LuaConfigInner {
175 lua,
176 created,
177 use_count: 1,
178 epoch,
179 }),
180 })
181}
182
183pub fn register(func: RegisterFunc) {
184 FUNCS.lock().push(func);
185}
186
187impl LuaConfig {
188 fn set_current_event(&mut self, name: &str) -> mlua::Result<()> {
189 self.inner
190 .as_mut()
191 .unwrap()
192 .lua
193 .globals()
194 .set("_KUMO_CURRENT_EVENT", name.to_string())
195 }
196
197 pub fn convert_args_to_multi<A: Serialize>(
200 &self,
201 args: &[A],
202 ) -> anyhow::Result<mlua::MultiValue> {
203 let lua = self.inner.as_ref().unwrap();
204 let mut arg_vec = vec![];
205 for a in args.iter() {
206 arg_vec.push(lua.lua.to_value(a)?);
207 }
208 Ok(mlua::MultiValue::from_vec(arg_vec))
209 }
210
211 pub async fn convert_args_and_call_callback<A: Serialize>(
213 &mut self,
214 sig: &CallbackSignature<Value, ()>,
215 args: A,
216 ) -> anyhow::Result<()> {
217 let lua = self.inner.as_mut().unwrap();
218 let args = lua.lua.to_value(&args)?;
219
220 let name = sig.name();
221 let decorated_name = sig.decorated_name();
222
223 match lua
224 .lua
225 .named_registry_value::<mlua::Function>(&decorated_name)
226 {
227 Ok(func) => {
228 let _timer = latency_timer(name);
229 Ok(func.call_async(args).await?)
230 }
231 _ => anyhow::bail!("{name} has not been registered"),
232 }
233 }
234
235 pub fn put(mut self) {
253 if let Some(inner) = self.inner.take() {
254 pool_put(inner);
255 }
256 }
257
258 pub async fn call_callback<'a, A: IntoLuaMulti + Clone, R: FromLuaMulti>(
259 &mut self,
260 sig: &'a CallbackSignature<A, R>,
261 args: A,
262 ) -> anyhow::Result<CallbackDisposition<'a, R>> {
263 let name = sig.name();
264 self.set_current_event(name)?;
265 let lua = self.inner.as_mut().unwrap();
266 sig.call(&lua.lua, args).await
267 }
268
269 pub async fn async_call_callback<A: IntoLuaMulti + Clone, R: FromLuaMulti + Default>(
270 &mut self,
271 sig: &CallbackSignature<A, R>,
272 args: A,
273 ) -> anyhow::Result<R> {
274 let name = sig.name();
275 self.set_current_event(name)?;
276 let lua = self.inner.as_mut().unwrap();
277 Ok(sig.call(&lua.lua, args).await?.or_default())
278 }
279
280 pub async fn async_call_callback_non_default<A: IntoLuaMulti + Clone, R: FromLuaMulti>(
281 &mut self,
282 sig: &CallbackSignature<A, R>,
283 args: A,
284 ) -> anyhow::Result<R> {
285 let name = sig.name();
286 self.set_current_event(name)?;
287 let lua = self.inner.as_mut().unwrap();
288 sig.call(&lua.lua, args).await?.require_value()
289 }
290
291 pub async fn async_call_callback_non_default_opt<A: IntoLuaMulti + Clone, R: FromLua>(
292 &mut self,
293 sig: &CallbackSignature<A, Option<R>>,
294 args: A,
295 ) -> anyhow::Result<Option<R>> {
296 let name = sig.name();
297 self.set_current_event(name)?;
298 let lua = self.inner.as_mut().unwrap();
299 let result = sig.call(&lua.lua, args).await?;
300 match result.result {
301 None => Ok(None),
302 Some(result) => Ok(result),
303 }
304 }
305
306 pub fn remove_registry_value(&mut self, value: RegistryKey) -> anyhow::Result<()> {
307 Ok(self
308 .inner
309 .as_mut()
310 .unwrap()
311 .lua
312 .remove_registry_value(value)?)
313 }
314
315 pub async fn async_call_ctor<A: IntoLuaMulti + Clone>(
318 &mut self,
319 sig: &CallbackSignature<A, Value>,
320 args: A,
321 ) -> anyhow::Result<RegistryKey> {
322 let name = sig.name();
323 anyhow::ensure!(
324 !sig.allow_multiple(),
325 "ctor event signature for {name} is defined as allow_multiple, which is not supported"
326 );
327
328 let decorated_name = sig.decorated_name();
329 self.set_current_event(name)?;
330
331 let inner = self.inner.as_mut().unwrap();
332
333 let func = inner
334 .lua
335 .named_registry_value::<mlua::Function>(&decorated_name)?;
336
337 let _timer = latency_timer(name);
338 let value: Value = func.call_async(args.clone()).await?;
339 drop(func);
340
341 Ok(inner.lua.create_registry_value(value)?)
342 }
343
344 pub async fn with_registry_value<F, R, FUT>(
347 &mut self,
348 value: &RegistryKey,
349 func: F,
350 ) -> anyhow::Result<R>
351 where
352 R: FromLuaMulti,
353 F: FnOnce(Value) -> anyhow::Result<FUT>,
354 FUT: std::future::Future<Output = anyhow::Result<R>>,
355 {
356 let inner = self.inner.as_mut().unwrap();
357 let value = inner.lua.registry_value(value)?;
358 let future = (func)(value)?;
359 future.await
360 }
361}
362
363pub fn get_or_create_module(lua: &Lua, name: &str) -> anyhow::Result<mlua::Table> {
364 let globals = lua.globals();
365 let package: Table = globals.get("package")?;
366 let loaded: Table = package.get("loaded")?;
367
368 let module = loaded.get(name)?;
369 match module {
370 Value::Nil => {
371 let module = lua.create_table()?;
372 loaded.set(name, module.clone())?;
373 Ok(module)
374 }
375 Value::Table(table) => Ok(table),
376 wat => anyhow::bail!(
377 "cannot register module {} as package.loaded.{} is already set to a value of type {}",
378 name,
379 name,
380 wat.type_name()
381 ),
382 }
383}
384
385pub fn get_or_create_sub_module(lua: &Lua, name_path: &str) -> anyhow::Result<mlua::Table> {
390 let mut parent = get_or_create_module(lua, "kumo")?;
391 let mut path_so_far = String::new();
392
393 for name in name_path.split('.') {
394 if !path_so_far.is_empty() {
395 path_so_far.push('.');
396 }
397 path_so_far.push_str(name);
398
399 let sub = parent.get(name)?;
400 match sub {
401 Value::Nil => {
402 let sub = lua.create_table()?;
403 parent.set(name, sub.clone())?;
404 parent = sub;
405 }
406 Value::Table(sub) => {
407 parent = sub;
408 }
409 wat => anyhow::bail!(
410 "cannot register module kumo.{path_so_far} as it is already set to a value of type {}",
411 wat.type_name()
412 ),
413 }
414 }
415
416 Ok(parent)
417}
418
419pub fn any_err<E: std::fmt::Display>(err: E) -> mlua::Error {
421 mlua::Error::external(format!("{err:#}"))
422}
423
424pub fn impl_pairs_and_index<T, M>(methods: &mut M)
431where
432 T: UserData + Serialize,
433 M: UserDataMethods<T>,
434{
435 methods.add_meta_method(MetaMethod::Pairs, move |lua, this, _: ()| {
436 let Ok(serde_json::Value::Object(map)) = serde_json::to_value(this).map_err(any_err) else {
437 return Err(mlua::Error::external("must serialize to Map"));
438 };
439
440 let mut value_iter = map.into_iter();
441
442 let iter_func = lua.create_function_mut(
443 move |lua, (_state, _control): (Value, Value)| match value_iter.next() {
444 Some((key, value)) => {
445 let key = lua.to_value(&key)?;
446 let value = lua.to_value(&value)?;
447 Ok((key, value))
448 }
449 None => Ok((Value::Nil, Value::Nil)),
450 },
451 )?;
452
453 Ok((Value::Function(iter_func), Value::Nil, Value::Nil))
454 });
455
456 methods.add_meta_method(MetaMethod::Index, move |lua, this, field: Value| {
457 let value = lua.to_value(this)?;
458 match value {
459 Value::Table(t) => t.get(field),
460 _ => Ok(Value::Nil),
461 }
462 });
463
464 methods.add_meta_method(MetaMethod::Len, move |lua, this, _: ()| {
465 let value = lua.to_value(this)?;
466 match value {
467 Value::Table(v) => v.len(),
468 Value::String(v) => Ok(v.as_bytes().len() as i64),
469 _ => Ok(0),
470 }
471 });
472}
473
474pub fn materialize_to_lua_value(lua: &Lua, value: mlua::Value) -> mlua::Result<mlua::Value> {
482 match value {
483 mlua::Value::UserData(ud) => {
484 let mt = ud.metatable()?;
485 let Ok(pairs) = mt.get::<mlua::Function>("__pairs") else {
486 let value = ud.into_lua(lua)?;
487 return Err(mlua::Error::external(format!(
488 "cannot materialize_to_lua_value {value:?} \
489 because it has no __pairs metamethod"
490 )));
491 };
492 let tbl = lua.create_table()?;
493 let (iter_func, state, mut control): (mlua::Function, mlua::Value, mlua::Value) =
494 pairs.call(mlua::Value::UserData(ud.clone()))?;
495
496 loop {
497 let (k, v): (mlua::Value, mlua::Value) =
498 iter_func.call((state.clone(), control))?;
499 if k.is_nil() {
500 break;
501 }
502
503 tbl.set(k.clone(), materialize_to_lua_value(lua, v)?)?;
504 control = k;
505 }
506
507 Ok(mlua::Value::Table(tbl))
508 }
509 mlua::Value::Table(t) => {
510 let tbl = lua.create_table()?;
511 for pair in t.pairs::<mlua::Value, mlua::Value>() {
512 let (k, v) = pair?;
513 tbl.set(k.clone(), materialize_to_lua_value(lua, v)?)?;
514 }
515 Ok(mlua::Value::Table(tbl))
516 }
517 value => Ok(value),
518 }
519}
520
521pub struct SerdeWrappedValue<T>(pub T);
523
524impl<T: serde::Serialize> serde::Serialize for SerdeWrappedValue<T> {
525 fn serialize<S>(
526 &self,
527 s: S,
528 ) -> Result<<S as serde::Serializer>::Ok, <S as serde::Serializer>::Error>
529 where
530 S: serde::Serializer,
531 {
532 self.0.serialize(s)
533 }
534}
535
536impl<T: Clone> Clone for SerdeWrappedValue<T> {
537 fn clone(&self) -> Self {
538 SerdeWrappedValue(self.0.clone())
539 }
540}
541
542impl<T: Default> Default for SerdeWrappedValue<T> {
543 fn default() -> Self {
544 SerdeWrappedValue(Default::default())
545 }
546}
547
548impl<T: serde::Serialize> SerdeWrappedValue<T> {
549 pub fn to_lua_value(&self, lua: &Lua) -> mlua::Result<mlua::Value> {
550 lua.to_value_with(&self.0, serialize_options())
551 }
552}
553
554impl<T: serde::Serialize> IntoLua for SerdeWrappedValue<T> {
555 fn into_lua(self, lua: &Lua) -> mlua::Result<mlua::Value> {
556 lua.to_value_with(&self.0, serialize_options())
557 }
558}
559
560impl<T: serde::de::DeserializeOwned> FromLua for SerdeWrappedValue<T> {
561 fn from_lua(value: mlua::Value, lua: &Lua) -> mlua::Result<SerdeWrappedValue<T>> {
562 let inner: T = from_lua_value(lua, value)?;
563 Ok(SerdeWrappedValue(inner))
564 }
565}
566
567impl<T> std::ops::Deref for SerdeWrappedValue<T> {
568 type Target = T;
569 fn deref(&self) -> &T {
570 &self.0
571 }
572}
573
574impl<T> std::ops::DerefMut for SerdeWrappedValue<T> {
575 fn deref_mut(&mut self) -> &mut T {
576 &mut self.0
577 }
578}
579
580pub fn from_lua_value<R>(lua: &Lua, value: mlua::Value) -> mlua::Result<R>
585where
586 R: serde::de::DeserializeOwned,
587{
588 let value_cloned = value.clone();
589 match lua.from_value(value) {
590 Ok(r) => Ok(r),
591 Err(err) => match materialize_to_lua_value(lua, value_cloned.clone()) {
592 Ok(materialized) => match lua.from_value(materialized.clone()) {
593 Ok(r) => Ok(r),
594 Err(err) => {
595 let mut serializer = serde_json::Serializer::new(Vec::new());
596 let serialized = match materialized.serialize(&mut serializer) {
597 Ok(_) => String::from_utf8_lossy(&serializer.into_inner()).to_string(),
598 Err(err) => format!("<unable to encode as json: {err:#}>"),
599 };
600 Err(mlua::Error::external(format!(
601 "{err:#}, while processing {serialized}"
602 )))
603 }
604 },
605 Err(materialize_err) => Err(mlua::Error::external(format!(
606 "{err:#}, while processing a userdata. \
607 Additionally, encountered {materialize_err:#} \
608 when trying to iterate the pairs of that userdata"
609 ))),
610 },
611 }
612}
613
/// Describes a lua event that rust code dispatches to.
///
/// The signature covers the event's name, its argument type `A`, its return
/// type `R`, and whether several handlers may be registered for it. The
/// handler is looked up in the lua registry under
/// [`CallbackSignature::decorated_name`].
pub struct CallbackSignature<A, R>
where
    A: IntoLuaMulti,
    R: FromLuaMulti,
{
    /// Ties the argument and return types to the signature without storing values.
    marker: std::marker::PhantomData<(A, R)>,
    /// When true the registry entry may be a sequence table of handler
    /// functions (see `call`); `register` records such names in
    /// `CALLBACK_ALLOWS_MULTIPLE`.
    allow_multiple: bool,
    /// The undecorated event name.
    name: Cow<'static, str>,
}
636
// Link-time collection of registration hooks: each `declare_event!`
// invocation contributes one entry that calls `register()` on its
// signature. `register_declared_events` runs all entries exactly once.
#[linkme::distributed_slice]
pub static CALLBACK_SIGNATURES: [fn()];
639
/// Declare a lua event signature as a lazily-initialized static
/// [`CallbackSignature`] named `$sym`.
///
/// Two forms are accepted:
///
/// * `Single("name", arg: Type, ...) -> Ret;` is built with
///   `CallbackSignature::new`, i.e. one handler.
/// * `Multiple("name", arg: Type, ...) -> Ret;` is built with
///   `CallbackSignature::new_with_multiple`.
///
/// Only the parameter types are used in the expansion; the names are
/// ignored. The argument type is spelled `($($args),*)`. A single parameter
/// of type `T` therefore yields `(T)`, which is `T` itself rather than a
/// one-element tuple, and no parameters yields `()`.
///
/// Each invocation also adds an entry to `CALLBACK_SIGNATURES` that calls
/// `$sym.register()`. `register_declared_events` runs those entries once,
/// the first time `load_config` builds a lua state. The expansion names
/// `linkme::distributed_slice` directly, so the invoking crate must be able
/// to resolve `linkme`.
#[macro_export]
macro_rules! declare_event {
    // Multiple handlers allowed.
    ($vis:vis static $sym:ident: Multiple($name:literal $(,)? $($param_name:ident: $args:ty),* $(,)? ) -> $ret:ty;) => {
        $vis static $sym: ::std::sync::LazyLock<
            $crate::CallbackSignature<($($args),*), $ret>> =
                ::std::sync::LazyLock::new(|| $crate::CallbackSignature::new_with_multiple($name));

        $crate::paste::paste! {
            #[linkme::distributed_slice($crate::CALLBACK_SIGNATURES)]
            static [<CALLBACK_SIG_REGISTER_ $sym>]: fn() = || {
                $sym.register();
            };
        }
    };
    // Single handler; `register()` is a no-op for these.
    ($vis:vis static $sym:ident: Single($name:literal $(,)? $($param_name:ident: $args:ty),* $(,)? ) -> $ret:ty;) => {
        $vis static $sym: ::std::sync::LazyLock<
            $crate::CallbackSignature<($($args),*), $ret>> =
                ::std::sync::LazyLock::new(|| $crate::CallbackSignature::new($name));

        $crate::paste::paste! {
            #[linkme::distributed_slice($crate::CALLBACK_SIGNATURES)]
            static [<CALLBACK_SIG_REGISTER_ $sym>]: fn() = || {
                $sym.register();
            };
        }
    };
}
694
695fn register_declared_events() {
702 static ONCE: Once = Once::new();
703 ONCE.call_once(|| {
704 for reg_func in CALLBACK_SIGNATURES {
705 reg_func();
706 }
707 });
708}
709
710impl<A, R> CallbackSignature<A, R>
711where
712 A: IntoLuaMulti,
713 R: FromLuaMulti,
714{
715 pub fn new<S: Into<Cow<'static, str>>>(name: S) -> Self {
716 let name = name.into();
717
718 Self {
719 marker: std::marker::PhantomData,
720 allow_multiple: false,
721 name,
722 }
723 }
724
725 pub fn new_with_multiple<S: Into<Cow<'static, str>>>(name: S) -> Self {
729 let name = name.into();
730
731 Self {
732 marker: std::marker::PhantomData,
733 allow_multiple: true,
734 name,
735 }
736 }
737
738 pub fn register(&self) {
739 if self.allow_multiple {
740 CALLBACK_ALLOWS_MULTIPLE
741 .lock()
742 .insert(self.name.to_string());
743 }
744 }
745
746 pub fn raise_error_if_allow_multiple(&self) -> anyhow::Result<()> {
747 anyhow::ensure!(
748 !self.allow_multiple(),
749 "handler {} is set to allow multiple handlers \
750 but is registered with a single instance. This indicates that \
751 register() was not called on the signature when initializing \
752 the lua context. Please report this issue to the KumoMTA team!",
753 self.name
754 );
755 Ok(())
756 }
757
758 pub fn allow_multiple(&self) -> bool {
761 self.allow_multiple
762 }
763
764 pub fn name(&self) -> &str {
765 &self.name
766 }
767
768 pub fn decorated_name(&self) -> String {
769 decorate_callback_name(&self.name)
770 }
771}
772
impl<A, R> CallbackSignature<A, R>
where
    A: IntoLuaMulti + Clone,
    R: FromLuaMulti,
{
    /// Invoke the handler(s) registered for this event in `lua`.
    ///
    /// The registry entry under `decorated_name()` is interpreted as follows:
    /// * a table: its sequence of functions is called in order with a clone
    ///   of `args`. The first handler that returns at least one value wins,
    ///   and that value is converted to `R`. If every handler returns
    ///   nothing (or the table is empty), `handler_was_defined` is `false`.
    /// * a function: it is called once and its result converted to `R`.
    ///   This is an error if the signature allows multiple handlers.
    /// * anything else (e.g. nil): no handler, with `handler_was_defined: false`.
    ///
    /// Each handler invocation is timed via `latency_timer`.
    pub async fn call<'a>(
        &'a self,
        lua: &Lua,
        args: A,
    ) -> anyhow::Result<CallbackDisposition<'a, R>> {
        let name = self.name();
        let decorated_name = self.decorated_name();

        match lua.named_registry_value::<mlua::Value>(&decorated_name)? {
            Value::Table(tbl) => {
                // Snapshot the handler list before calling any of them.
                for func in tbl.sequence_values::<mlua::Function>().collect::<Vec<_>>() {
                    let func = func?;
                    let _timer = latency_timer(name);
                    let result: mlua::MultiValue = func.call_async(args.clone()).await?;
                    if result.is_empty() {
                        // No return value: fall through to the next handler.
                        continue;
                    }
                    let result = R::from_lua_multi(result, lua)?;
                    return Ok(CallbackDisposition {
                        handler_was_defined: true,
                        result: Some(result),
                        event_name: name,
                    });
                }
                // No handler produced a value; reported as if none were defined.
                Ok(CallbackDisposition {
                    handler_was_defined: false,
                    result: None,
                    event_name: name,
                })
            }
            Value::Function(func) => {
                // A lone function is only valid for single-handler events.
                self.raise_error_if_allow_multiple()?;
                let _timer = latency_timer(name);
                Ok(CallbackDisposition {
                    handler_was_defined: true,
                    result: Some(func.call_async(args.clone()).await?),
                    event_name: name,
                })
            }
            _ => Ok(CallbackDisposition {
                handler_was_defined: false,
                result: None,
                event_name: name,
            }),
        }
    }
}
831
832pub fn does_callback_allow_multiple(name: &str) -> bool {
833 CALLBACK_ALLOWS_MULTIPLE.lock().contains(name)
834}
835
/// The lua registry key used for event `name`: `kumomta-on-<name>`.
pub fn decorate_callback_name(name: &str) -> String {
    let mut decorated = String::from("kumomta-on-");
    decorated.push_str(name);
    decorated
}
839
840pub struct CallbackDisposition<'a, T> {
843 pub handler_was_defined: bool,
848 pub result: Option<T>,
850 pub event_name: &'a str,
852}
853
854impl<'a, T> CallbackDisposition<'a, T> {
855 pub fn require_value(mut self) -> anyhow::Result<T> {
858 if !self.handler_was_defined {
859 anyhow::bail!("Event {} has not been registered", self.event_name);
860 }
861 match self.result.take() {
862 Some(value) => Ok(value),
863 None => anyhow::bail!("Event {} did not return a value", self.event_name),
864 }
865 }
866}
867impl<'a, T: Default> CallbackDisposition<'a, T> {
868 pub fn or_default(self) -> T {
870 self.result.unwrap_or_default()
871 }
872}
873
874pub fn serialize_options() -> mlua::SerializeOptions {
875 mlua::SerializeOptions::new()
876 .serialize_none_to_null(false)
877 .serialize_unit_to_null(false)
878}