mod_memoize/
lib.rs

1use config::epoch::{get_current_epoch, ConfigEpoch};
2use config::{any_err, from_lua_value, get_or_create_module, serialize_options};
3use dashmap::DashMap;
4use kumo_prometheus::declare_metric;
5use lruttl::LruCacheWithTtl;
6use mlua::{
7    FromLua, Function, IntoLua, Lua, LuaSerdeExt, MetaMethod, MultiValue, UserData,
8    UserDataMethods, UserDataRef,
9};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::sync::{Arc, LazyLock};
13use std::time::Duration;
14
15/// Memoized is a helper type that allows native Rust types to be captured
16/// in memoization caches.
17/// Unfortunately, we cannot automatically make that work for all UserData
18/// that are exported to lua, but we can make it simple for a type to opt-in
19/// to that behavior.
20///
21/// When you impl UserData for your type, you can call
22/// `Memoized::impl_memoize(methods)` from your add_methods impl.
23/// That will add a metamethod to your UserData type that will clone your
24/// value and wrap it into a Memoized wrapper.
25///
26/// Since Clone is used, it is recommended that you use an Arc inside your
27/// type to avoid making large or expensive clones.
28#[derive(Clone, mlua::FromLua)]
29pub struct Memoized {
30    pub to_value: Arc<dyn Fn(&Lua) -> mlua::Result<mlua::Value> + Send + Sync>,
31}
32
33impl PartialEq for Memoized {
34    fn eq(&self, other: &Self) -> bool {
35        Arc::ptr_eq(&self.to_value, &other.to_value)
36    }
37}
38
39impl Memoized {
40    /// Call this from your `UserData::add_methods` implementation to
41    /// enable memoization for your UserData type
42    pub fn impl_memoize<T, M>(methods: &mut M)
43    where
44        T: UserData + Send + Sync + Clone + 'static,
45        M: UserDataMethods<T>,
46    {
47        methods.add_meta_method(
48            "__memoize",
49            move |_lua, this, _: ()| -> mlua::Result<Memoized> {
50                let this = this.clone();
51                Ok(Memoized {
52                    to_value: Arc::new(move |lua| this.clone().into_lua(lua)),
53                })
54            },
55        );
56    }
57}
58
59impl UserData for Memoized {}
60
61#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq)]
62#[serde(deny_unknown_fields)]
63pub struct MemoizeParams {
64    #[serde(with = "duration_serde")]
65    pub ttl: Duration,
66    pub capacity: usize,
67    pub name: String,
68    #[serde(default)]
69    pub invalidate_with_epoch: bool,
70    #[serde(default)]
71    pub retry_on_populate_timeout: bool,
72    #[serde(default, with = "duration_serde")]
73    pub populate_timeout: Option<Duration>,
74    #[serde(default)]
75    pub allow_stale_reads: bool,
76}
77
78#[derive(Clone, Hash, Eq, PartialEq)]
79pub enum MapKey {
80    Integer(mlua::Integer),
81    String(Vec<u8>),
82}
83
84impl MapKey {
85    pub fn from_lua(v: mlua::Value) -> Option<Self> {
86        match v {
87            mlua::Value::String(s) => Some(Self::String(s.as_bytes().to_vec())),
88            mlua::Value::Integer(n) => Some(Self::Integer(n)),
89            _ => None,
90        }
91    }
92
93    pub fn as_lua(self, lua: &Lua) -> mlua::Result<mlua::Value> {
94        match self {
95            Self::Integer(j) => Ok(mlua::Value::Integer(j)),
96            Self::String(b) => Ok(mlua::Value::String(lua.create_string(b)?)),
97        }
98    }
99}
100
101#[derive(Clone, PartialEq)]
102pub enum CacheValue {
103    Table(Arc<HashMap<MapKey, CacheValue>>),
104    Json(serde_json::Value),
105    Memoized(Memoized),
106}
107
108impl std::fmt::Debug for CacheValue {
109    fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
110        fmt.debug_struct("CacheValue").finish()
111    }
112}
113
114impl FromLua for CacheValue {
115    fn from_lua(value: mlua::Value, lua: &Lua) -> mlua::Result<Self> {
116        match value {
117            mlua::Value::UserData(ud) => {
118                let mt = ud.metatable()?;
119                let func: Function = mt.get("__memoize")?;
120                let m: Memoized = func.call(mlua::Value::UserData(ud))?;
121                Ok(Self::Memoized(m))
122            }
123            mlua::Value::Table(tbl) => {
124                let mut map = HashMap::new();
125                for pair in tbl.pairs::<mlua::Value, mlua::Value>() {
126                    let (key, value) = pair?;
127                    let key = match key {
128                        mlua::Value::Integer(n) => MapKey::Integer(n),
129                        mlua::Value::String(n) => MapKey::String(n.as_bytes().to_vec()),
130                        _ => {
131                            return Err(anyhow::anyhow!(
132                                "table key {key:?} cannot be used as a key in a memoizable table"
133                            ))
134                            .map_err(any_err)
135                        }
136                    };
137                    let value = CacheValue::from_lua(value, lua)?;
138                    map.insert(key, value);
139                }
140                Ok(Self::Table(map.into()))
141            }
142            _ => Ok(Self::Json(from_lua_value(lua, value)?)),
143        }
144    }
145}
146
147impl IntoLua for CacheValue {
148    fn into_lua(self, lua: &Lua) -> mlua::Result<mlua::Value> {
149        self.as_lua(lua)
150    }
151}
152
153impl CacheValue {
154    pub fn as_lua(&self, lua: &Lua) -> mlua::Result<mlua::Value> {
155        match self {
156            Self::Json(j) => lua.to_value_with(j, serialize_options()),
157            Self::Memoized(m) => (m.to_value)(lua),
158            Self::Table(m) => Ok(mlua::Value::UserData(
159                lua.create_userdata(MemoizedTable::Shared(m.clone()))?,
160            )),
161        }
162    }
163}
164
165/// MemoizedTable is a helper type that is returned to represent
166/// cached table values.  We'll return the Shared variant by
167/// default as that presents the cheapest way to return the cached
168/// data--only a clone of the underlying Arc is required to return
169/// the value.
170///
171/// This type implements __index, __newindex, __len, and __pairs
172/// metamethods which allow reading and iterating the table.
173///
174/// Writing to the table via __newindex will "unshare" the table in
175/// a similar manner to the Cow type, creating a mutable copy of the top
176/// level of the table.
177enum MemoizedTable {
178    Shared(Arc<HashMap<MapKey, CacheValue>>),
179    Mut(HashMap<MapKey, CacheValue>),
180}
181
182impl MemoizedTable {
183    /// Get a reference to the table, facilitating get() and iter(),
184    /// regardless of whether we are Shared or Mut.
185    fn table(&self) -> &HashMap<MapKey, CacheValue> {
186        match self {
187            Self::Shared(s) => s,
188            Self::Mut(s) => s,
189        }
190    }
191
192    /// Transform Shared -> Mut
193    fn unshare(&mut self) -> &mut HashMap<MapKey, CacheValue> {
194        if let Self::Shared(t) = self {
195            *self = Self::Mut(t.iter().map(|(k, v)| (k.clone(), v.clone())).collect());
196        }
197
198        match self {
199            Self::Shared(_) => unreachable!(),
200            Self::Mut(map) => map,
201        }
202    }
203}
204
205impl UserData for MemoizedTable {
206    fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
207        // Index allows reading fields of the table
208        methods.add_meta_method(MetaMethod::Index, move |lua, this, key: mlua::Value| {
209            match MapKey::from_lua(key) {
210                Some(key) => match this.table().get(&key) {
211                    Some(value) => value.as_lua(lua),
212                    None => Ok(mlua::Value::Nil),
213                },
214                None => Ok(mlua::Value::Nil),
215            }
216        });
217
218        // NewIndex allows writing fields of the table
219        methods.add_meta_method_mut(
220            MetaMethod::NewIndex,
221            move |lua, this, (key, value): (mlua::Value, mlua::Value)| match MapKey::from_lua(key) {
222                Some(key) => {
223                    let value = CacheValue::from_lua(value, lua)?;
224                    this.unshare().insert(key, value);
225                    Ok(())
226                }
227                None => Err(mlua::Error::external(
228                    "invalid key type while trying to call __newindex and assign a value",
229                )),
230            },
231        );
232        methods.add_meta_method(MetaMethod::Len, move |_lua, this, ()| {
233            Ok(this.table().len())
234        });
235
236        // Pairs iterates the keys of the table.
237        // We use add_meta_function rather than add_meta_method here
238        // because we need to return `this` as the "state" parameter
239        // for use in a generic-for statement
240        methods.add_meta_function(MetaMethod::Pairs, move |lua, this: mlua::Value| {
241            // Maintain our own local idea of the control variable,
242            // as it is much cheaper and simpler to iterate based
243            // on skipping than to keep comparing keys
244            let mut idx = 0;
245
246            let iter_func =
247                lua.create_function_mut(
248                    move |lua, (state, _control): (UserDataRef<MemoizedTable>, mlua::Value)| {
249                        match state.table().iter().nth(idx) {
250                            Some((key, value)) => {
251                                idx += 1;
252                                let key = key.clone().as_lua(lua)?;
253                                let value = value.as_lua(lua)?;
254                                Ok((key, value))
255                            }
256                            None => Ok((mlua::Value::Nil, mlua::Value::Nil)),
257                        }
258                    },
259                )?;
260
261            // Return the iterator, state and control values.
262            // The state and control will be passed back into iter_func
263            // as the for-loop iterates.
264            // Control is Nil here because we track our own idx
265            // value in the iter_func closure.
266            Ok((mlua::Value::Function(iter_func), this, mlua::Value::Nil))
267        });
268    }
269}
270
271#[derive(Clone, Debug)]
272enum CacheEntry {
273    Null,
274    Single(CacheValue),
275    Multi(Vec<CacheValue>),
276}
277
278impl CacheEntry {
279    fn to_value(&self, lua: &Lua) -> mlua::Result<mlua::Value> {
280        match self {
281            Self::Null => Ok(mlua::Value::Nil),
282            Self::Single(value) => value.as_lua(lua),
283            Self::Multi(values) => {
284                let mut result = vec![];
285                for v in values {
286                    result.push(v.as_lua(lua)?);
287                }
288                result.into_lua(lua)
289            }
290        }
291    }
292
293    fn from_multi_value(lua: &Lua, multi: MultiValue) -> mlua::Result<Self> {
294        let mut values = multi.into_vec();
295        if values.is_empty() {
296            Ok(Self::Null)
297        } else if values.len() == 1 {
298            Ok(Self::Single(CacheValue::from_lua(
299                values.pop().unwrap(),
300                lua,
301            )?))
302        } else {
303            let mut cvalues = vec![];
304            for v in values.into_iter() {
305                cvalues.push(CacheValue::from_lua(v, lua)?);
306            }
307            Ok(Self::Multi(cvalues))
308        }
309    }
310}
311
312struct MemoizeCache {
313    params: MemoizeParams,
314    cache: Arc<LruCacheWithTtl<CacheKey, CacheEntry>>,
315}
316
317static CACHES: LazyLock<DashMap<String, MemoizeCache>> = LazyLock::new(DashMap::new);
318
319type CacheKey = (Option<ConfigEpoch>, String);
320
321fn get_cache_by_name(
322    name: &str,
323) -> Option<(Arc<LruCacheWithTtl<CacheKey, CacheEntry>>, Duration, bool)> {
324    CACHES.get(name).map(|item| {
325        (
326            item.cache.clone(),
327            item.params.ttl,
328            item.params.invalidate_with_epoch,
329        )
330    })
331}
332
333declare_metric! {
334/// How many times a memoize cache lookup was initiated for a given cache.
335///
336/// Redundant with the newer [lruttl_lookup_count](lruttl_lookup_count.md) metric.
337static CACHE_LOOKUP: CounterVec(
338        "memoize_cache_lookup_count",
339        &["cache_name"]);
340}
341
342declare_metric! {
343/// How many times a memoize cache lookup was a hit for a given cache.
344///
345/// Redundant with the newer [lruttl_hit_count](lruttl_hit_count.md) metric.
346static CACHE_HIT: CounterVec(
347        "memoize_cache_hit_count",
348        &["cache_name"]);
349}
350
351declare_metric! {
352/// How many times a memoize cache lookup was a miss for a given cache
353///
354/// Redundant with the newer [lruttl_miss_count](lruttl_miss_count.md) metric.
355static CACHE_MISS: CounterVec(
356        "memoize_cache_miss_count",
357        &["cache_name"]);
358}
359
360declare_metric! {
361/// How many times a memoize cache lookup resulted in performing the work to populate the entry
362///
363/// Redundant with the newer [lruttl_populated_count](lruttl_populated_count.md) metric.
364static CACHE_POPULATED: CounterVec(
365        "memoize_cache_populated_count",
366        &["cache_name"]);
367}
368
369fn multi_value_to_json_value(lua: &Lua, multi: MultiValue) -> mlua::Result<serde_json::Value> {
370    let mut values = multi.into_vec();
371    if values.is_empty() {
372        Ok(serde_json::Value::Null)
373    } else if values.len() == 1 {
374        from_lua_value(lua, values.pop().unwrap())
375    } else {
376        let mut jvalues = vec![];
377        for v in values.into_iter() {
378            jvalues.push(from_lua_value(lua, v)?);
379        }
380        Ok(serde_json::Value::Array(jvalues))
381    }
382}
383
384pub fn register(lua: &Lua) -> anyhow::Result<()> {
385    let kumo_mod = get_or_create_module(lua, "kumo")?;
386
387    kumo_mod.set(
388        "memoize",
389        lua.create_function(move |lua, (func, params): (mlua::Function, mlua::Value)| {
390            let params: MemoizeParams = from_lua_value(lua, params)?;
391
392            let cache_name = params.name.to_string();
393
394            if !lruttl::is_name_available(&cache_name) {
395                return Err(mlua::Error::external(format!(
396                    "cannot use name `{cache_name}` for a memoize cache, \
397                    as it collides with a built-in cache. \
398                    Suggestion: prefix your cache name with `user.` to \
399                    avoid conflicts with current and future caches."
400                )));
401            }
402
403            CACHES.remove_if(&params.name, |_k, item| {
404                let changed = item.params != params;
405                if changed {
406                    tracing::trace!("memoize parameters changed, replacing old cache {params:?}");
407                }
408                changed
409            });
410            CACHES.entry(cache_name.to_string()).or_insert_with(|| {
411                let cache = LruCacheWithTtl::new(cache_name.clone(), params.capacity);
412                if let Some(duration) = params.populate_timeout {
413                    cache.set_sema_timeout(duration);
414                }
415                cache.set_allow_stale_reads(params.allow_stale_reads);
416
417                MemoizeCache {
418                    params: params.clone(),
419                    cache: Arc::new(cache),
420                }
421            });
422
423            let lookup_counter = CACHE_LOOKUP
424                .get_metric_with_label_values(&[&cache_name])
425                .map_err(any_err)?;
426            let hit_counter = CACHE_HIT
427                .get_metric_with_label_values(&[&cache_name])
428                .map_err(any_err)?;
429            let miss_counter = CACHE_MISS
430                .get_metric_with_label_values(&[&cache_name])
431                .map_err(any_err)?;
432            let populate_counter = CACHE_POPULATED
433                .get_metric_with_label_values(&[&cache_name])
434                .map_err(any_err)?;
435            let retry_on_populate_timeout = params.retry_on_populate_timeout;
436            let allow_stale_reads = params.allow_stale_reads;
437
438            let func_ref = lua.create_registry_value(func)?;
439
440            lua.create_async_function(move |lua, params: MultiValue| {
441                let cache_name = cache_name.clone();
442                let func = lua.registry_value::<mlua::Function>(&func_ref);
443                let lookup_counter = lookup_counter.clone();
444                let hit_counter = hit_counter.clone();
445                let miss_counter = miss_counter.clone();
446                let populate_counter = populate_counter.clone();
447                async move {
448                    lookup_counter.inc();
449                    let key = multi_value_to_json_value(&lua, params.clone())?;
450
451                    let func = func?;
452
453                    let mut last_failure = None;
454
455                    for _attempt in 0..3 {
456                        // We use the epoch from the start of the lookup as part
457                        // of the cache key. If the epoch changes while we are in
458                        // the middle of computing this value then subsequent calls
459                        // through to the cached function will see the newer epoch
460                        // and encounter a cache miss. This prevents a race condition
461                        // poisoning the cache with a stale value during an epoch
462                        // bump. The caller will still observe the stale value, so
463                        // ultimately should have some accommodation for detecting
464                        // the epoch change and retrying their call through here,
465                        // if it is important to not see a stale value.
466                        let epoch_at_start = get_current_epoch();
467
468                        let (cache, ttl, invalidate_with_epoch) = get_cache_by_name(&cache_name)
469                            .ok_or_else(|| anyhow::anyhow!("cache is somehow undefined!?"))
470                            .map_err(any_err)?;
471
472                        let epoch_key = if invalidate_with_epoch && !allow_stale_reads {
473                            Some(epoch_at_start)
474                        } else {
475                            None
476                        };
477                        let key = serde_json::to_string(&key).map_err(any_err)?;
478                        let key = (epoch_key, key);
479
480                        let value_result = cache
481                            .get_or_try_insert(&key, |_| ttl, async {
482                                tracing::trace!("populate {key:?}");
483                                populate_counter.inc();
484                                let result: MultiValue =
485                                    (func.clone()).call_async(params.clone()).await?;
486                                CacheEntry::from_multi_value(&lua, result.clone())
487                            })
488                            .await;
489
490                        match value_result {
491                            Ok(lookup) => {
492                                if lookup.is_fresh {
493                                    miss_counter.inc();
494                                } else {
495                                    hit_counter.inc();
496                                }
497                                return lookup.item.to_value(&lua);
498                            }
499                            Err(err) => {
500                                tracing::error!("{cache_name} {key:?} failed: {err:#}");
501                                let error = format!("{err:#}");
502                                if !retry_on_populate_timeout {
503                                    return Err(mlua::Error::external(error));
504                                }
505                                last_failure.replace(error);
506                            }
507                        }
508                    }
509
510                    Err(mlua::Error::external(
511                        last_failure.expect("last_failure to always be set in loop above"),
512                    ))
513                }
514            })
515        })?,
516    )?;
517
518    Ok(())
519}
520
#[cfg(test)]
mod test {
    use super::*;
    use mlua::UserDataMethods;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// A memoized function is called once for repeated lookups, and is
    /// called again once the ttl has elapsed.
    #[tokio::test]
    async fn test_memoize() {
        let lua = Lua::new();
        register(&lua).unwrap();

        let call_count = Arc::new(AtomicUsize::new(0));

        // `do_thing` returns how many times it had been called before.
        let do_thing = {
            let counter = Arc::clone(&call_count);
            lua.create_function(move |_lua, _: ()| Ok(counter.fetch_add(1, Ordering::SeqCst)))
                .unwrap()
        };
        lua.globals().set("do_thing", do_thing).unwrap();

        let first_chunk = r#"
            local kumo = require 'kumo';
            -- make cached_do_thing a global for use in the expiry test below
            cached_do_thing = kumo.memoize(do_thing, {
                ttl = "1s",
                capacity = 4,
                name = "test_memoize_do_thing",
            })
            return cached_do_thing() + cached_do_thing() + cached_do_thing()
        "#;
        let summed: usize = lua.load(first_chunk).eval_async().await.unwrap();

        assert_eq!(summed, 0);
        assert_eq!(call_count.load(Ordering::SeqCst), 1);

        // And confirm that expiry works
        tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;

        let second_chunk = r#"
            return cached_do_thing()
        "#;
        let after_expiry: usize = lua.load(second_chunk).eval().unwrap();

        assert_eq!(after_expiry, 1);
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }

    /// A Rust UserData type that opts in via `Memoized::impl_memoize` can
    /// be returned from a memoized function and served from the cache.
    #[tokio::test]
    async fn test_memoize_rust() {
        let lua = Lua::new();
        register(&lua).unwrap();

        let call_count = Arc::new(AtomicUsize::new(0));

        #[derive(Clone)]
        struct Foo {
            value: usize,
        }

        impl UserData for Foo {
            fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
                Memoized::impl_memoize(methods);
                methods.add_method("get_value", |_lua, this, _: ()| Ok(this.value));
            }
        }

        // `make_foo` stamps each new Foo with the number of earlier calls.
        let make_foo = {
            let counter = Arc::clone(&call_count);
            lua.create_function(move |_lua, _: ()| {
                let value = counter.fetch_add(1, Ordering::SeqCst);
                Ok(Foo { value })
            })
            .unwrap()
        };
        lua.globals().set("make_foo", make_foo).unwrap();

        let chunk = r#"
            local kumo = require 'kumo';
            local cached_make_foo = kumo.memoize(make_foo, {
                ttl = "1s",
                capacity = 4,
                name = "test_memoize_make_foo",
            })
            return cached_make_foo():get_value() +
                   cached_make_foo():get_value() +
                   cached_make_foo():get_value()
        "#;
        let summed: usize = lua.load(chunk).eval().unwrap();

        assert_eq!(summed, 0);
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }
}