1use config::epoch::{get_current_epoch, ConfigEpoch};
2use config::{any_err, from_lua_value, get_or_create_module, serialize_options};
3use dashmap::DashMap;
4use kumo_prometheus::declare_metric;
5use lruttl::LruCacheWithTtl;
6use mlua::{
7 FromLua, Function, IntoLua, Lua, LuaSerdeExt, MetaMethod, MultiValue, UserData,
8 UserDataMethods, UserDataRef,
9};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::sync::{Arc, LazyLock};
13use std::time::Duration;
14
/// A type-erased recipe for turning a cached Rust userdata value back into
/// a Lua value.
///
/// Values of this type are produced by the `__memoize` metamethod that
/// `Memoized::impl_memoize` installs on a userdata type.
#[derive(Clone, mlua::FromLua)]
pub struct Memoized {
    /// Builds the Lua value for the `Lua` state that is passed in.
    pub to_value: Arc<dyn Fn(&Lua) -> mlua::Result<mlua::Value> + Send + Sync>,
}
32
33impl PartialEq for Memoized {
34 fn eq(&self, other: &Self) -> bool {
35 Arc::ptr_eq(&self.to_value, &other.to_value)
36 }
37}
38
39impl Memoized {
40 pub fn impl_memoize<T, M>(methods: &mut M)
43 where
44 T: UserData + Send + Sync + Clone + 'static,
45 M: UserDataMethods<T>,
46 {
47 methods.add_meta_method(
48 "__memoize",
49 move |_lua, this, _: ()| -> mlua::Result<Memoized> {
50 let this = this.clone();
51 Ok(Memoized {
52 to_value: Arc::new(move |lua| this.clone().into_lua(lua)),
53 })
54 },
55 );
56 }
57}
58
// `Memoized` is itself a userdata type with no methods of its own; this lets it
// cross the Lua boundary as the return value of the `__memoize` metamethod.
impl UserData for Memoized {}
60
/// Options accepted as the second argument of `kumo.memoize`.
///
/// Unknown fields are rejected during deserialization.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq)]
#[serde(deny_unknown_fields)]
pub struct MemoizeParams {
    /// Time-to-live handed to the cache for each populated entry.
    #[serde(with = "duration_serde")]
    pub ttl: Duration,
    /// Capacity passed to `LruCacheWithTtl::new`.
    pub capacity: usize,
    /// Name of the cache; also used as the registry key and as the
    /// `cache_name` metric label.
    pub name: String,
    /// When set (and `allow_stale_reads` is not), the configuration epoch
    /// observed at lookup time becomes part of the cache key.
    #[serde(default)]
    pub invalidate_with_epoch: bool,
    /// When set, a failed lookup is attempted up to three times in total
    /// before the error is reported; otherwise the first failure is returned.
    #[serde(default)]
    pub retry_on_populate_timeout: bool,
    /// Optional timeout forwarded to the cache via `set_sema_timeout`.
    #[serde(default, with = "duration_serde")]
    pub populate_timeout: Option<Duration>,
    /// Forwarded to the cache via `set_allow_stale_reads`; also disables
    /// epoch-based keying.
    #[serde(default)]
    pub allow_stale_reads: bool,
}
77
/// Key of a memoized Lua table. Only integer and string keys are representable.
#[derive(Clone, Hash, Eq, PartialEq)]
pub enum MapKey {
    /// A Lua integer key.
    Integer(mlua::Integer),
    /// The raw bytes of a Lua string key.
    String(Vec<u8>),
}
83
84impl MapKey {
85 pub fn from_lua(v: mlua::Value) -> Option<Self> {
86 match v {
87 mlua::Value::String(s) => Some(Self::String(s.as_bytes().to_vec())),
88 mlua::Value::Integer(n) => Some(Self::Integer(n)),
89 _ => None,
90 }
91 }
92
93 pub fn as_lua(self, lua: &Lua) -> mlua::Result<mlua::Value> {
94 match self {
95 Self::Integer(j) => Ok(mlua::Value::Integer(j)),
96 Self::String(b) => Ok(mlua::Value::String(lua.create_string(b)?)),
97 }
98 }
99}
100
/// A Lua value captured in a form that can be stored in the cache and later
/// converted back into a Lua value.
#[derive(Clone, PartialEq)]
pub enum CacheValue {
    /// A Lua table, captured recursively entry by entry.
    Table(Arc<HashMap<MapKey, CacheValue>>),
    /// Any value that is neither a table nor userdata, captured through
    /// `from_lua_value`.
    Json(serde_json::Value),
    /// A userdata value captured through its `__memoize` metamethod.
    Memoized(Memoized),
}
107
108impl std::fmt::Debug for CacheValue {
109 fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
110 fmt.debug_struct("CacheValue").finish()
111 }
112}
113
114impl FromLua for CacheValue {
115 fn from_lua(value: mlua::Value, lua: &Lua) -> mlua::Result<Self> {
116 match value {
117 mlua::Value::UserData(ud) => {
118 let mt = ud.metatable()?;
119 let func: Function = mt.get("__memoize")?;
120 let m: Memoized = func.call(mlua::Value::UserData(ud))?;
121 Ok(Self::Memoized(m))
122 }
123 mlua::Value::Table(tbl) => {
124 let mut map = HashMap::new();
125 for pair in tbl.pairs::<mlua::Value, mlua::Value>() {
126 let (key, value) = pair?;
127 let key = match key {
128 mlua::Value::Integer(n) => MapKey::Integer(n),
129 mlua::Value::String(n) => MapKey::String(n.as_bytes().to_vec()),
130 _ => {
131 return Err(anyhow::anyhow!(
132 "table key {key:?} cannot be used as a key in a memoizable table"
133 ))
134 .map_err(any_err)
135 }
136 };
137 let value = CacheValue::from_lua(value, lua)?;
138 map.insert(key, value);
139 }
140 Ok(Self::Table(map.into()))
141 }
142 _ => Ok(Self::Json(from_lua_value(lua, value)?)),
143 }
144 }
145}
146
147impl IntoLua for CacheValue {
148 fn into_lua(self, lua: &Lua) -> mlua::Result<mlua::Value> {
149 self.as_lua(lua)
150 }
151}
152
153impl CacheValue {
154 pub fn as_lua(&self, lua: &Lua) -> mlua::Result<mlua::Value> {
155 match self {
156 Self::Json(j) => lua.to_value_with(j, serialize_options()),
157 Self::Memoized(m) => (m.to_value)(lua),
158 Self::Table(m) => Ok(mlua::Value::UserData(
159 lua.create_userdata(MemoizedTable::Shared(m.clone()))?,
160 )),
161 }
162 }
163}
164
/// Userdata wrapper through which a cached table is exposed to Lua.
enum MemoizedTable {
    /// Read-only view that shares the map held by the cache.
    Shared(Arc<HashMap<MapKey, CacheValue>>),
    /// Private copy, created the first time Lua assigns into the table.
    Mut(HashMap<MapKey, CacheValue>),
}
181
182impl MemoizedTable {
183 fn table(&self) -> &HashMap<MapKey, CacheValue> {
186 match self {
187 Self::Shared(s) => s,
188 Self::Mut(s) => s,
189 }
190 }
191
192 fn unshare(&mut self) -> &mut HashMap<MapKey, CacheValue> {
194 if let Self::Shared(t) = self {
195 *self = Self::Mut(t.iter().map(|(k, v)| (k.clone(), v.clone())).collect());
196 }
197
198 match self {
199 Self::Shared(_) => unreachable!(),
200 Self::Mut(map) => map,
201 }
202 }
203}
204
205impl UserData for MemoizedTable {
206 fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
207 methods.add_meta_method(MetaMethod::Index, move |lua, this, key: mlua::Value| {
209 match MapKey::from_lua(key) {
210 Some(key) => match this.table().get(&key) {
211 Some(value) => value.as_lua(lua),
212 None => Ok(mlua::Value::Nil),
213 },
214 None => Ok(mlua::Value::Nil),
215 }
216 });
217
218 methods.add_meta_method_mut(
220 MetaMethod::NewIndex,
221 move |lua, this, (key, value): (mlua::Value, mlua::Value)| match MapKey::from_lua(key) {
222 Some(key) => {
223 let value = CacheValue::from_lua(value, lua)?;
224 this.unshare().insert(key, value);
225 Ok(())
226 }
227 None => Err(mlua::Error::external(
228 "invalid key type while trying to call __newindex and assign a value",
229 )),
230 },
231 );
232 methods.add_meta_method(MetaMethod::Len, move |_lua, this, ()| {
233 Ok(this.table().len())
234 });
235
236 methods.add_meta_function(MetaMethod::Pairs, move |lua, this: mlua::Value| {
241 let mut idx = 0;
245
246 let iter_func =
247 lua.create_function_mut(
248 move |lua, (state, _control): (UserDataRef<MemoizedTable>, mlua::Value)| {
249 match state.table().iter().nth(idx) {
250 Some((key, value)) => {
251 idx += 1;
252 let key = key.clone().as_lua(lua)?;
253 let value = value.as_lua(lua)?;
254 Ok((key, value))
255 }
256 None => Ok((mlua::Value::Nil, mlua::Value::Nil)),
257 }
258 },
259 )?;
260
261 Ok((mlua::Value::Function(iter_func), this, mlua::Value::Nil))
267 });
268 }
269}
270
/// The cached outcome of one call of the wrapped function, shaped by how many
/// values that call returned.
#[derive(Clone, Debug)]
enum CacheEntry {
    /// The call returned no values.
    Null,
    /// The call returned exactly one value.
    Single(CacheValue),
    /// The call returned two or more values.
    Multi(Vec<CacheValue>),
}
277
278impl CacheEntry {
279 fn to_value(&self, lua: &Lua) -> mlua::Result<mlua::Value> {
280 match self {
281 Self::Null => Ok(mlua::Value::Nil),
282 Self::Single(value) => value.as_lua(lua),
283 Self::Multi(values) => {
284 let mut result = vec![];
285 for v in values {
286 result.push(v.as_lua(lua)?);
287 }
288 result.into_lua(lua)
289 }
290 }
291 }
292
293 fn from_multi_value(lua: &Lua, multi: MultiValue) -> mlua::Result<Self> {
294 let mut values = multi.into_vec();
295 if values.is_empty() {
296 Ok(Self::Null)
297 } else if values.len() == 1 {
298 Ok(Self::Single(CacheValue::from_lua(
299 values.pop().unwrap(),
300 lua,
301 )?))
302 } else {
303 let mut cvalues = vec![];
304 for v in values.into_iter() {
305 cvalues.push(CacheValue::from_lua(v, lua)?);
306 }
307 Ok(Self::Multi(cvalues))
308 }
309 }
310}
311
/// A registered memoize cache together with the parameters it was created
/// from.
struct MemoizeCache {
    /// Kept so that a later registration under the same name can be compared
    /// against it.
    params: MemoizeParams,
    /// The cache itself; the `Arc` is cloned out for each lookup.
    cache: Arc<LruCacheWithTtl<CacheKey, CacheEntry>>,
}
316
/// Registry of all memoize caches created so far, keyed by cache name.
static CACHES: LazyLock<DashMap<String, MemoizeCache>> = LazyLock::new(DashMap::new);
318
/// Cache key: an optional configuration epoch (present only when the cache is
/// keyed by epoch) paired with the JSON-serialized call arguments.
type CacheKey = (Option<ConfigEpoch>, String);
320
321fn get_cache_by_name(
322 name: &str,
323) -> Option<(Arc<LruCacheWithTtl<CacheKey, CacheEntry>>, Duration, bool)> {
324 CACHES.get(name).map(|item| {
325 (
326 item.cache.clone(),
327 item.params.ttl,
328 item.params.invalidate_with_epoch,
329 )
330 })
331}
332
// Metric labelled by cache name; incremented once for every call of a
// memoized function, before the cache is consulted.
declare_metric! {
static CACHE_LOOKUP: CounterVec(
    "memoize_cache_lookup_count",
    &["cache_name"]);
}
341
// Metric labelled by cache name; incremented when a lookup succeeds with a
// result that is not marked fresh.
declare_metric! {
static CACHE_HIT: CounterVec(
    "memoize_cache_hit_count",
    &["cache_name"]);
}
350
// Metric labelled by cache name; incremented when a lookup succeeds with a
// result that is marked fresh.
declare_metric! {
static CACHE_MISS: CounterVec(
    "memoize_cache_miss_count",
    &["cache_name"]);
}
359
// Metric labelled by cache name; incremented each time the populate future
// runs, just before it calls the wrapped function.
declare_metric! {
static CACHE_POPULATED: CounterVec(
    "memoize_cache_populated_count",
    &["cache_name"]);
}
368
369fn multi_value_to_json_value(lua: &Lua, multi: MultiValue) -> mlua::Result<serde_json::Value> {
370 let mut values = multi.into_vec();
371 if values.is_empty() {
372 Ok(serde_json::Value::Null)
373 } else if values.len() == 1 {
374 from_lua_value(lua, values.pop().unwrap())
375 } else {
376 let mut jvalues = vec![];
377 for v in values.into_iter() {
378 jvalues.push(from_lua_value(lua, v)?);
379 }
380 Ok(serde_json::Value::Array(jvalues))
381 }
382}
383
384pub fn register(lua: &Lua) -> anyhow::Result<()> {
385 let kumo_mod = get_or_create_module(lua, "kumo")?;
386
387 kumo_mod.set(
388 "memoize",
389 lua.create_function(move |lua, (func, params): (mlua::Function, mlua::Value)| {
390 let params: MemoizeParams = from_lua_value(lua, params)?;
391
392 let cache_name = params.name.to_string();
393
394 if !lruttl::is_name_available(&cache_name) {
395 return Err(mlua::Error::external(format!(
396 "cannot use name `{cache_name}` for a memoize cache, \
397 as it collides with a built-in cache. \
398 Suggestion: prefix your cache name with `user.` to \
399 avoid conflicts with current and future caches."
400 )));
401 }
402
403 CACHES.remove_if(¶ms.name, |_k, item| {
404 let changed = item.params != params;
405 if changed {
406 tracing::trace!("memoize parameters changed, replacing old cache {params:?}");
407 }
408 changed
409 });
410 CACHES.entry(cache_name.to_string()).or_insert_with(|| {
411 let cache = LruCacheWithTtl::new(cache_name.clone(), params.capacity);
412 if let Some(duration) = params.populate_timeout {
413 cache.set_sema_timeout(duration);
414 }
415 cache.set_allow_stale_reads(params.allow_stale_reads);
416
417 MemoizeCache {
418 params: params.clone(),
419 cache: Arc::new(cache),
420 }
421 });
422
423 let lookup_counter = CACHE_LOOKUP
424 .get_metric_with_label_values(&[&cache_name])
425 .map_err(any_err)?;
426 let hit_counter = CACHE_HIT
427 .get_metric_with_label_values(&[&cache_name])
428 .map_err(any_err)?;
429 let miss_counter = CACHE_MISS
430 .get_metric_with_label_values(&[&cache_name])
431 .map_err(any_err)?;
432 let populate_counter = CACHE_POPULATED
433 .get_metric_with_label_values(&[&cache_name])
434 .map_err(any_err)?;
435 let retry_on_populate_timeout = params.retry_on_populate_timeout;
436 let allow_stale_reads = params.allow_stale_reads;
437
438 let func_ref = lua.create_registry_value(func)?;
439
440 lua.create_async_function(move |lua, params: MultiValue| {
441 let cache_name = cache_name.clone();
442 let func = lua.registry_value::<mlua::Function>(&func_ref);
443 let lookup_counter = lookup_counter.clone();
444 let hit_counter = hit_counter.clone();
445 let miss_counter = miss_counter.clone();
446 let populate_counter = populate_counter.clone();
447 async move {
448 lookup_counter.inc();
449 let key = multi_value_to_json_value(&lua, params.clone())?;
450
451 let func = func?;
452
453 let mut last_failure = None;
454
455 for _attempt in 0..3 {
456 let epoch_at_start = get_current_epoch();
467
468 let (cache, ttl, invalidate_with_epoch) = get_cache_by_name(&cache_name)
469 .ok_or_else(|| anyhow::anyhow!("cache is somehow undefined!?"))
470 .map_err(any_err)?;
471
472 let epoch_key = if invalidate_with_epoch && !allow_stale_reads {
473 Some(epoch_at_start)
474 } else {
475 None
476 };
477 let key = serde_json::to_string(&key).map_err(any_err)?;
478 let key = (epoch_key, key);
479
480 let value_result = cache
481 .get_or_try_insert(&key, |_| ttl, async {
482 tracing::trace!("populate {key:?}");
483 populate_counter.inc();
484 let result: MultiValue =
485 (func.clone()).call_async(params.clone()).await?;
486 CacheEntry::from_multi_value(&lua, result.clone())
487 })
488 .await;
489
490 match value_result {
491 Ok(lookup) => {
492 if lookup.is_fresh {
493 miss_counter.inc();
494 } else {
495 hit_counter.inc();
496 }
497 return lookup.item.to_value(&lua);
498 }
499 Err(err) => {
500 tracing::error!("{cache_name} {key:?} failed: {err:#}");
501 let error = format!("{err:#}");
502 if !retry_on_populate_timeout {
503 return Err(mlua::Error::external(error));
504 }
505 last_failure.replace(error);
506 }
507 }
508 }
509
510 Err(mlua::Error::external(
511 last_failure.expect("last_failure to always be set in loop above"),
512 ))
513 }
514 })
515 })?,
516 )?;
517
518 Ok(())
519}
520
#[cfg(test)]
mod test {
    use super::*;
    use mlua::UserDataMethods;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// A memoized plain Lua-callable function is invoked once while its entry
    /// is live, and again after the TTL has elapsed.
    #[tokio::test]
    async fn test_memoize() {
        let lua = Lua::new();
        register(&lua).unwrap();

        let call_count = Arc::new(AtomicUsize::new(0));

        let globals = lua.globals();
        let counter = Arc::clone(&call_count);
        globals
            .set(
                "do_thing",
                lua.create_function(move |_lua, _: ()| {
                    let count = counter.fetch_add(1, Ordering::SeqCst);
                    Ok(count)
                })
                .unwrap(),
            )
            .unwrap();

        let result: usize = lua
            .load(
                r#"
                local kumo = require 'kumo';
                -- make cached_do_thing a global for use in the expiry test below
                cached_do_thing = kumo.memoize(do_thing, {
                    ttl = "1s",
                    capacity = 4,
                    name = "test_memoize_do_thing",
                })
                return cached_do_thing() + cached_do_thing() + cached_do_thing()
                "#,
            )
            .eval_async()
            .await
            .unwrap();

        assert_eq!(result, 0);
        assert_eq!(call_count.load(Ordering::SeqCst), 1);

        // Wait past the 1s TTL so the next call has to repopulate.
        tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;

        // The memoized function is async, so it must be driven through the
        // async evaluator; a synchronous `eval` cannot suspend if the cache
        // lookup is ever pending.
        let result: usize = lua
            .load(
                r#"
                return cached_do_thing()
                "#,
            )
            .eval_async()
            .await
            .unwrap();

        assert_eq!(result, 1);
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }

    /// A memoized function returning Rust userdata is invoked once, and the
    /// cached value is handed back through the `__memoize` mechanism.
    #[tokio::test]
    async fn test_memoize_rust() {
        let lua = Lua::new();
        register(&lua).unwrap();

        let call_count = Arc::new(AtomicUsize::new(0));

        #[derive(Clone)]
        struct Foo {
            value: usize,
        }

        impl UserData for Foo {
            fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
                Memoized::impl_memoize(methods);
                methods.add_method("get_value", move |_lua, this, _: ()| Ok(this.value));
            }
        }

        let globals = lua.globals();
        let counter = Arc::clone(&call_count);
        globals
            .set(
                "make_foo",
                lua.create_function(move |_lua, _: ()| {
                    let count = counter.fetch_add(1, Ordering::SeqCst);
                    Ok(Foo { value: count })
                })
                .unwrap(),
            )
            .unwrap();

        // Driven asynchronously for the same reason as in `test_memoize`.
        let result: usize = lua
            .load(
                r#"
                local kumo = require 'kumo';
                local cached_make_foo = kumo.memoize(make_foo, {
                    ttl = "1s",
                    capacity = 4,
                    name = "test_memoize_make_foo",
                })
                return cached_make_foo():get_value() +
                    cached_make_foo():get_value() +
                    cached_make_foo():get_value()
                "#,
            )
            .eval_async()
            .await
            .unwrap();

        assert_eq!(result, 0);
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }
}