1use config::epoch::{get_current_epoch, ConfigEpoch};
2use config::{any_err, from_lua_value, get_or_create_module, serialize_options};
3use dashmap::DashMap;
4use lruttl::LruCacheWithTtl;
5use mlua::{
6 FromLua, Function, IntoLua, Lua, LuaSerdeExt, MetaMethod, MultiValue, UserData,
7 UserDataMethods, UserDataRef,
8};
9use prometheus::CounterVec;
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::sync::{Arc, LazyLock};
13use std::time::Duration;
14
/// A type-erased, cheaply clonable recipe for reproducing a Lua value.
///
/// Instances are produced by the `__memoize` metamethod that
/// [`Memoized::impl_memoize`] installs on a userdata type. Each invocation of
/// `to_value` converts the captured state into a new Lua value.
#[derive(Clone, mlua::FromLua)]
pub struct Memoized {
    /// Closure that converts the captured state into a Lua value; shared via
    /// `Arc` so cloning a `Memoized` never clones the captured state itself.
    pub to_value: Arc<dyn Fn(&Lua) -> mlua::Result<mlua::Value> + Send + Sync>,
}
32
33impl PartialEq for Memoized {
34 fn eq(&self, other: &Self) -> bool {
35 Arc::ptr_eq(&self.to_value, &other.to_value)
36 }
37}
38
39impl Memoized {
40 pub fn impl_memoize<T, M>(methods: &mut M)
43 where
44 T: UserData + Send + Sync + Clone + 'static,
45 M: UserDataMethods<T>,
46 {
47 methods.add_meta_method(
48 "__memoize",
49 move |_lua, this, _: ()| -> mlua::Result<Memoized> {
50 let this = this.clone();
51 Ok(Memoized {
52 to_value: Arc::new(move |lua| this.clone().into_lua(lua)),
53 })
54 },
55 );
56 }
57}
58
59impl UserData for Memoized {}
60
/// Parameters accepted by `kumo.memoize(func, params)`.
///
/// Unknown fields are rejected. Two sets of parameters are compared with
/// `PartialEq` when a cache name is re-registered; any difference causes the
/// existing cache of that name to be replaced.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq)]
#[serde(deny_unknown_fields)]
pub struct MemoizeParams {
    /// Time-to-live applied to each populated entry.
    #[serde(with = "duration_serde")]
    pub ttl: Duration,
    /// Capacity passed to the underlying `LruCacheWithTtl`.
    pub capacity: usize,
    /// Cache name; also used as the `cache_name` metrics label. Must not
    /// collide with a built-in cache name (checked via `lruttl::is_name_available`).
    pub name: String,
    /// When true (and `allow_stale_reads` is false), the current config epoch
    /// becomes part of the cache key.
    #[serde(default)]
    pub invalidate_with_epoch: bool,
    /// When true, a failed lookup is retried (up to 3 attempts in total).
    /// NOTE(review): as written, any lookup error triggers a retry, not only a
    /// populate timeout.
    #[serde(default)]
    pub retry_on_populate_timeout: bool,
    /// Optional timeout handed to the cache via `set_sema_timeout`.
    #[serde(default, with = "duration_serde")]
    pub populate_timeout: Option<Duration>,
    /// Handed to the cache via `set_allow_stale_reads`; when true the config
    /// epoch is not included in the cache key.
    #[serde(default)]
    pub allow_stale_reads: bool,
}
77
78#[derive(Clone, Hash, Eq, PartialEq)]
79pub enum MapKey {
80 Integer(mlua::Integer),
81 String(Vec<u8>),
82}
83
84impl MapKey {
85 pub fn from_lua(v: mlua::Value) -> Option<Self> {
86 match v {
87 mlua::Value::String(s) => Some(Self::String(s.as_bytes().to_vec())),
88 mlua::Value::Integer(n) => Some(Self::Integer(n)),
89 _ => None,
90 }
91 }
92
93 pub fn as_lua(self, lua: &Lua) -> mlua::Result<mlua::Value> {
94 match self {
95 Self::Integer(j) => Ok(mlua::Value::Integer(j)),
96 Self::String(b) => Ok(mlua::Value::String(lua.create_string(b)?)),
97 }
98 }
99}
100
101#[derive(Clone, PartialEq)]
102pub enum CacheValue {
103 Table(Arc<HashMap<MapKey, CacheValue>>),
104 Json(serde_json::Value),
105 Memoized(Memoized),
106}
107
108impl std::fmt::Debug for CacheValue {
109 fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
110 fmt.debug_struct("CacheValue").finish()
111 }
112}
113
114impl FromLua for CacheValue {
115 fn from_lua(value: mlua::Value, lua: &Lua) -> mlua::Result<Self> {
116 match value {
117 mlua::Value::UserData(ud) => {
118 let mt = ud.metatable()?;
119 let func: Function = mt.get("__memoize")?;
120 let m: Memoized = func.call(mlua::Value::UserData(ud))?;
121 Ok(Self::Memoized(m))
122 }
123 mlua::Value::Table(tbl) => {
124 let mut map = HashMap::new();
125 for pair in tbl.pairs::<mlua::Value, mlua::Value>() {
126 let (key, value) = pair?;
127 let key = match key {
128 mlua::Value::Integer(n) => MapKey::Integer(n),
129 mlua::Value::String(n) => MapKey::String(n.as_bytes().to_vec()),
130 _ => {
131 return Err(anyhow::anyhow!(
132 "table key {key:?} cannot be used as a key in a memoizable table"
133 ))
134 .map_err(any_err)
135 }
136 };
137 let value = CacheValue::from_lua(value, lua)?;
138 map.insert(key, value);
139 }
140 Ok(Self::Table(map.into()))
141 }
142 _ => Ok(Self::Json(from_lua_value(lua, value)?)),
143 }
144 }
145}
146
147impl IntoLua for CacheValue {
148 fn into_lua(self, lua: &Lua) -> mlua::Result<mlua::Value> {
149 self.as_lua(lua)
150 }
151}
152
153impl CacheValue {
154 pub fn as_lua(&self, lua: &Lua) -> mlua::Result<mlua::Value> {
155 match self {
156 Self::Json(j) => lua.to_value_with(j, serialize_options()),
157 Self::Memoized(m) => (m.to_value)(lua),
158 Self::Table(m) => Ok(mlua::Value::UserData(
159 lua.create_userdata(MemoizedTable::Shared(m.clone()))?,
160 )),
161 }
162 }
163}
164
165enum MemoizedTable {
178 Shared(Arc<HashMap<MapKey, CacheValue>>),
179 Mut(HashMap<MapKey, CacheValue>),
180}
181
182impl MemoizedTable {
183 fn table(&self) -> &HashMap<MapKey, CacheValue> {
186 match self {
187 Self::Shared(s) => s,
188 Self::Mut(s) => s,
189 }
190 }
191
192 fn unshare(&mut self) -> &mut HashMap<MapKey, CacheValue> {
194 if let Self::Shared(t) = self {
195 *self = Self::Mut(t.iter().map(|(k, v)| (k.clone(), v.clone())).collect());
196 }
197
198 match self {
199 Self::Shared(_) => unreachable!(),
200 Self::Mut(map) => map,
201 }
202 }
203}
204
205impl UserData for MemoizedTable {
206 fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
207 methods.add_meta_method(MetaMethod::Index, move |lua, this, key: mlua::Value| {
209 match MapKey::from_lua(key) {
210 Some(key) => match this.table().get(&key) {
211 Some(value) => value.as_lua(lua),
212 None => Ok(mlua::Value::Nil),
213 },
214 None => Ok(mlua::Value::Nil),
215 }
216 });
217
218 methods.add_meta_method_mut(
220 MetaMethod::NewIndex,
221 move |lua, this, (key, value): (mlua::Value, mlua::Value)| match MapKey::from_lua(key) {
222 Some(key) => {
223 let value = CacheValue::from_lua(value, lua)?;
224 this.unshare().insert(key, value);
225 Ok(())
226 }
227 None => Err(mlua::Error::external(
228 "invalid key type while trying to call __newindex and assign a value",
229 )),
230 },
231 );
232 methods.add_meta_method(MetaMethod::Len, move |_lua, this, ()| {
233 Ok(this.table().len())
234 });
235
236 methods.add_meta_function(MetaMethod::Pairs, move |lua, this: mlua::Value| {
241 let mut idx = 0;
245
246 let iter_func =
247 lua.create_function_mut(
248 move |lua, (state, _control): (UserDataRef<MemoizedTable>, mlua::Value)| {
249 match state.table().iter().nth(idx) {
250 Some((key, value)) => {
251 idx += 1;
252 let key = key.clone().as_lua(lua)?;
253 let value = value.as_lua(lua)?;
254 Ok((key, value))
255 }
256 None => Ok((mlua::Value::Nil, mlua::Value::Nil)),
257 }
258 },
259 )?;
260
261 Ok((mlua::Value::Function(iter_func), this, mlua::Value::Nil))
267 });
268 }
269}
270
271#[derive(Clone, Debug)]
272enum CacheEntry {
273 Null,
274 Single(CacheValue),
275 Multi(Vec<CacheValue>),
276}
277
278impl CacheEntry {
279 fn to_value(&self, lua: &Lua) -> mlua::Result<mlua::Value> {
280 match self {
281 Self::Null => Ok(mlua::Value::Nil),
282 Self::Single(value) => value.as_lua(lua),
283 Self::Multi(values) => {
284 let mut result = vec![];
285 for v in values {
286 result.push(v.as_lua(lua)?);
287 }
288 result.into_lua(lua)
289 }
290 }
291 }
292
293 fn from_multi_value(lua: &Lua, multi: MultiValue) -> mlua::Result<Self> {
294 let mut values = multi.into_vec();
295 if values.is_empty() {
296 Ok(Self::Null)
297 } else if values.len() == 1 {
298 Ok(Self::Single(CacheValue::from_lua(
299 values.pop().unwrap(),
300 lua,
301 )?))
302 } else {
303 let mut cvalues = vec![];
304 for v in values.into_iter() {
305 cvalues.push(CacheValue::from_lua(v, lua)?);
306 }
307 Ok(Self::Multi(cvalues))
308 }
309 }
310}
311
/// A named memoize cache together with the parameters it was created with.
struct MemoizeCache {
    /// Parameters used to build `cache`; compared against new parameters when
    /// the same name is registered again, to decide whether to replace it.
    params: MemoizeParams,
    cache: Arc<LruCacheWithTtl<CacheKey, CacheEntry>>,
}

/// Every memoize cache created via `kumo.memoize`, keyed by its name.
static CACHES: LazyLock<DashMap<String, MemoizeCache>> = LazyLock::new(DashMap::new);

/// Cache key: the config epoch (only set when epoch invalidation is active)
/// paired with the JSON serialization of the call arguments.
type CacheKey = (Option<ConfigEpoch>, String);
320
321fn get_cache_by_name(
322 name: &str,
323) -> Option<(Arc<LruCacheWithTtl<CacheKey, CacheEntry>>, Duration, bool)> {
324 CACHES.get(name).map(|item| {
325 (
326 item.cache.clone(),
327 item.params.ttl,
328 item.params.invalidate_with_epoch,
329 )
330 })
331}
332
// Per-cache Prometheus counters, each labelled by `cache_name`. They are
// registered lazily on first use; `.unwrap()` panics if registration fails
// (for example, because a metric with the same name is already registered).

/// Incremented once per call of a memoized function.
static CACHE_LOOKUP: LazyLock<CounterVec> = LazyLock::new(|| {
    prometheus::register_counter_vec!(
        "memoize_cache_lookup_count",
        "how many times a memoize cache lookup was initiated for a given cache",
        &["cache_name"]
    )
    .unwrap()
});
/// Incremented when a lookup is served by an already-populated entry.
static CACHE_HIT: LazyLock<CounterVec> = LazyLock::new(|| {
    prometheus::register_counter_vec!(
        "memoize_cache_hit_count",
        "how many times a memoize cache lookup was a hit for a given cache",
        &["cache_name"]
    )
    .unwrap()
});
/// Incremented when a lookup returns a freshly populated entry.
static CACHE_MISS: LazyLock<CounterVec> = LazyLock::new(|| {
    prometheus::register_counter_vec!(
        "memoize_cache_miss_count",
        "how many times a memoize cache lookup was a miss for a given cache",
        &["cache_name"]
    )
    .unwrap()
});
/// Incremented each time the wrapped function is invoked to populate an entry.
static CACHE_POPULATED: LazyLock<CounterVec> = LazyLock::new(|| {
    prometheus::register_counter_vec!(
        "memoize_cache_populated_count",
        "how many times a memoize cache lookup resulted in performing the work to populate the entry",
        &["cache_name"]
    )
    .unwrap()
});
365
366fn multi_value_to_json_value(lua: &Lua, multi: MultiValue) -> mlua::Result<serde_json::Value> {
367 let mut values = multi.into_vec();
368 if values.is_empty() {
369 Ok(serde_json::Value::Null)
370 } else if values.len() == 1 {
371 from_lua_value(lua, values.pop().unwrap())
372 } else {
373 let mut jvalues = vec![];
374 for v in values.into_iter() {
375 jvalues.push(from_lua_value(lua, v)?);
376 }
377 Ok(serde_json::Value::Array(jvalues))
378 }
379}
380
381pub fn register(lua: &Lua) -> anyhow::Result<()> {
382 let kumo_mod = get_or_create_module(lua, "kumo")?;
383
384 kumo_mod.set(
385 "memoize",
386 lua.create_function(move |lua, (func, params): (mlua::Function, mlua::Value)| {
387 let params: MemoizeParams = from_lua_value(lua, params)?;
388
389 let cache_name = params.name.to_string();
390
391 if !lruttl::is_name_available(&cache_name) {
392 return Err(mlua::Error::external(format!(
393 "cannot use name `{cache_name}` for a memoize cache, \
394 as it collides with a built-in cache. \
395 Suggestion: prefix your cache name with `user.` to \
396 avoid conflicts with current and future caches."
397 )));
398 }
399
400 CACHES.remove_if(¶ms.name, |_k, item| {
401 let changed = item.params != params;
402 if changed {
403 tracing::trace!("memoize parameters changed, replacing old cache {params:?}");
404 }
405 changed
406 });
407 CACHES.entry(cache_name.to_string()).or_insert_with(|| {
408 let cache = LruCacheWithTtl::new(cache_name.clone(), params.capacity);
409 if let Some(duration) = params.populate_timeout {
410 cache.set_sema_timeout(duration);
411 }
412 cache.set_allow_stale_reads(params.allow_stale_reads);
413
414 MemoizeCache {
415 params: params.clone(),
416 cache: Arc::new(cache),
417 }
418 });
419
420 let lookup_counter = CACHE_LOOKUP
421 .get_metric_with_label_values(&[&cache_name])
422 .map_err(any_err)?;
423 let hit_counter = CACHE_HIT
424 .get_metric_with_label_values(&[&cache_name])
425 .map_err(any_err)?;
426 let miss_counter = CACHE_MISS
427 .get_metric_with_label_values(&[&cache_name])
428 .map_err(any_err)?;
429 let populate_counter = CACHE_POPULATED
430 .get_metric_with_label_values(&[&cache_name])
431 .map_err(any_err)?;
432 let retry_on_populate_timeout = params.retry_on_populate_timeout;
433 let allow_stale_reads = params.allow_stale_reads;
434
435 let func_ref = lua.create_registry_value(func)?;
436
437 lua.create_async_function(move |lua, params: MultiValue| {
438 let cache_name = cache_name.clone();
439 let func = lua.registry_value::<mlua::Function>(&func_ref);
440 let lookup_counter = lookup_counter.clone();
441 let hit_counter = hit_counter.clone();
442 let miss_counter = miss_counter.clone();
443 let populate_counter = populate_counter.clone();
444 async move {
445 lookup_counter.inc();
446 let key = multi_value_to_json_value(&lua, params.clone())?;
447
448 let func = func?;
449
450 let mut last_failure = None;
451
452 for _attempt in 0..3 {
453 let epoch_at_start = get_current_epoch();
464
465 let (cache, ttl, invalidate_with_epoch) = get_cache_by_name(&cache_name)
466 .ok_or_else(|| anyhow::anyhow!("cache is somehow undefined!?"))
467 .map_err(any_err)?;
468
469 let epoch_key = if invalidate_with_epoch && !allow_stale_reads {
470 Some(epoch_at_start)
471 } else {
472 None
473 };
474 let key = serde_json::to_string(&key).map_err(any_err)?;
475 let key = (epoch_key, key);
476
477 let value_result = cache
478 .get_or_try_insert(&key, |_| ttl, async {
479 tracing::trace!("populate {key:?}");
480 populate_counter.inc();
481 let result: MultiValue =
482 (func.clone()).call_async(params.clone()).await?;
483 CacheEntry::from_multi_value(&lua, result.clone())
484 })
485 .await;
486
487 match value_result {
488 Ok(lookup) => {
489 if lookup.is_fresh {
490 miss_counter.inc();
491 } else {
492 hit_counter.inc();
493 }
494 return lookup.item.to_value(&lua);
495 }
496 Err(err) => {
497 tracing::error!("{cache_name} {key:?} failed: {err:#}");
498 let error = format!("{err:#}");
499 if !retry_on_populate_timeout {
500 return Err(mlua::Error::external(error));
501 }
502 last_failure.replace(error);
503 }
504 }
505 }
506
507 Err(mlua::Error::external(
508 last_failure.expect("last_failure to always be set in loop above"),
509 ))
510 }
511 })
512 })?,
513 )?;
514
515 Ok(())
516}
517
#[cfg(test)]
mod test {
    use super::*;
    use mlua::UserDataMethods;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Repeated calls within the TTL are served from the cache. Once the
    /// 1s TTL has elapsed (the test sleeps 2s), the wrapped function runs again.
    #[tokio::test]
    async fn test_memoize() {
        let lua = Lua::new();
        register(&lua).unwrap();

        // Counts how many times the underlying (uncached) function runs.
        let call_count = Arc::new(AtomicUsize::new(0));

        let globals = lua.globals();
        let counter = Arc::clone(&call_count);
        globals
            .set(
                "do_thing",
                // Returns the number of previous invocations: 0, 1, 2, ...
                lua.create_function(move |_lua, _: ()| {
                    let count = counter.fetch_add(1, Ordering::SeqCst);
                    Ok(count)
                })
                .unwrap(),
            )
            .unwrap();

        let result: usize = lua
            .load(
                r#"
local kumo = require 'kumo';
-- make cached_do_thing a global for use in the expiry test below
cached_do_thing = kumo.memoize(do_thing, {
    ttl = "1s",
    capacity = 4,
    name = "test_memoize_do_thing",
})
return cached_do_thing() + cached_do_thing() + cached_do_thing()
            "#,
            )
            .eval_async()
            .await
            .unwrap();

        // All three calls saw the first result (0) and the function ran once.
        assert_eq!(result, 0);
        assert_eq!(call_count.load(Ordering::SeqCst), 1);

        // Wait past the 1s TTL so the cached entry is no longer usable.
        tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;

        let result: usize = lua
            .load(
                r#"
return cached_do_thing()
            "#,
            )
            .eval()
            .unwrap();

        // The function ran a second time, producing its second result (1).
        assert_eq!(result, 1);
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }

    /// A userdata type that opts in via `Memoized::impl_memoize` can be
    /// returned from a memoized function, and subsequent calls reuse it.
    #[tokio::test]
    async fn test_memoize_rust() {
        let lua = Lua::new();
        register(&lua).unwrap();

        let call_count = Arc::new(AtomicUsize::new(0));

        #[derive(Clone)]
        struct Foo {
            value: usize,
        }

        impl UserData for Foo {
            fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
                // Installs `__memoize`, making Foo storable in a memoize cache.
                Memoized::impl_memoize(methods);
                methods.add_method("get_value", move |_lua, this, _: ()| Ok(this.value));
            }
        }

        let globals = lua.globals();
        let counter = Arc::clone(&call_count);
        globals
            .set(
                "make_foo",
                // Each invocation builds a Foo carrying the invocation index.
                lua.create_function(move |_lua, _: ()| {
                    let count = counter.fetch_add(1, Ordering::SeqCst);
                    Ok(Foo { value: count })
                })
                .unwrap(),
            )
            .unwrap();

        let result: usize = lua
            .load(
                r#"
local kumo = require 'kumo';
local cached_make_foo = kumo.memoize(make_foo, {
    ttl = "1s",
    capacity = 4,
    name = "test_memoize_make_foo",
})
return cached_make_foo():get_value() +
    cached_make_foo():get_value() +
    cached_make_foo():get_value()
            "#,
            )
            .eval()
            .unwrap();

        // Every call returned a copy of the first Foo (value 0); make_foo ran once.
        assert_eq!(result, 0);
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }
}