mod_memoize/
lib.rs

1use config::epoch::{get_current_epoch, ConfigEpoch};
2use config::{any_err, from_lua_value, get_or_create_module, serialize_options};
3use dashmap::DashMap;
4use lruttl::LruCacheWithTtl;
5use mlua::{
6    FromLua, Function, IntoLua, Lua, LuaSerdeExt, MetaMethod, MultiValue, UserData,
7    UserDataMethods, UserDataRef,
8};
9use prometheus::CounterVec;
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::sync::{Arc, LazyLock};
13use std::time::Duration;
14
/// Memoized is a helper type that allows native Rust types to be captured
/// in memoization caches.
/// Unfortunately, we cannot automatically make that work for all UserData
/// that are exported to lua, but we can make it simple for a type to opt-in
/// to that behavior.
///
/// When you impl UserData for your type, you can call
/// `Memoized::impl_memoize(methods)` from your add_methods impl.
/// That will add a `__memoize` metamethod to your UserData type that will
/// clone your value and wrap it into a Memoized wrapper.
///
/// Since Clone is used, it is recommended that you use an Arc inside your
/// type to avoid making large or expensive clones.
///
/// Two `Memoized` values compare equal only when they share the same
/// `to_value` closure allocation (pointer identity); the captured values
/// themselves are never compared.
#[derive(Clone, mlua::FromLua)]
pub struct Memoized {
    /// Produces a Lua value from the captured Rust value. It is invoked each
    /// time the cached value is converted back into Lua.
    pub to_value: Arc<dyn Fn(&Lua) -> mlua::Result<mlua::Value> + Send + Sync>,
}
32
33impl PartialEq for Memoized {
34    fn eq(&self, other: &Self) -> bool {
35        Arc::ptr_eq(&self.to_value, &other.to_value)
36    }
37}
38
39impl Memoized {
40    /// Call this from your `UserData::add_methods` implementation to
41    /// enable memoization for your UserData type
42    pub fn impl_memoize<T, M>(methods: &mut M)
43    where
44        T: UserData + Send + Sync + Clone + 'static,
45        M: UserDataMethods<T>,
46    {
47        methods.add_meta_method(
48            "__memoize",
49            move |_lua, this, _: ()| -> mlua::Result<Memoized> {
50                let this = this.clone();
51                Ok(Memoized {
52                    to_value: Arc::new(move |lua| this.clone().into_lua(lua)),
53                })
54            },
55        );
56    }
57}
58
59impl UserData for Memoized {}
60
/// Options accepted by `kumo.memoize`, deserialized from the Lua table passed
/// as its second argument. Unknown fields are rejected.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq)]
#[serde(deny_unknown_fields)]
pub struct MemoizeParams {
    /// TTL supplied to the cache for each entry that gets populated.
    #[serde(with = "duration_serde")]
    pub ttl: Duration,
    /// Capacity passed to `LruCacheWithTtl::new` when the cache is created.
    pub capacity: usize,
    /// Name of the cache. It must not collide with a built-in `lruttl` cache
    /// name. Repeated `kumo.memoize` calls with the same name reuse the
    /// existing cache only if all of their parameters are equal; otherwise
    /// the cache is replaced.
    pub name: String,
    /// When true (and `allow_stale_reads` is false), the config epoch at the
    /// start of a lookup is made part of the cache key, so that an epoch
    /// change causes subsequent lookups to miss.
    #[serde(default)]
    pub invalidate_with_epoch: bool,
    /// When true, a failed lookup is retried, for up to three attempts in
    /// total. NOTE(review): despite the name, any lookup error is retried,
    /// not only populate timeouts; confirm that this is intended.
    #[serde(default)]
    pub retry_on_populate_timeout: bool,
    /// Optional timeout passed to `LruCacheWithTtl::set_sema_timeout`.
    #[serde(default, with = "duration_serde")]
    pub populate_timeout: Option<Duration>,
    /// Passed to `LruCacheWithTtl::set_allow_stale_reads`. When true, the
    /// config epoch is also left out of the cache key.
    #[serde(default)]
    pub allow_stale_reads: bool,
}
77
/// A key in a memoizable table. Only integer and string keys are supported.
#[derive(Clone, Hash, Eq, PartialEq)]
pub enum MapKey {
    /// An integer key.
    Integer(mlua::Integer),
    /// A string key, stored as its raw bytes (Lua strings need not be UTF-8).
    String(Vec<u8>),
}
83
84impl MapKey {
85    pub fn from_lua(v: mlua::Value) -> Option<Self> {
86        match v {
87            mlua::Value::String(s) => Some(Self::String(s.as_bytes().to_vec())),
88            mlua::Value::Integer(n) => Some(Self::Integer(n)),
89            _ => None,
90        }
91    }
92
93    pub fn as_lua(self, lua: &Lua) -> mlua::Result<mlua::Value> {
94        match self {
95            Self::Integer(j) => Ok(mlua::Value::Integer(j)),
96            Self::String(b) => Ok(mlua::Value::String(lua.create_string(b)?)),
97        }
98    }
99}
100
/// A Lua value captured in a form that can be stored in a memoize cache and
/// converted back into Lua on each cache hit.
#[derive(Clone, PartialEq)]
pub enum CacheValue {
    /// A Lua table with integer/string keys, converted recursively. It is
    /// returned to Lua as a `MemoizedTable` userdata that shares this map.
    Table(Arc<HashMap<MapKey, CacheValue>>),
    /// Any non-table, non-userdata value, stored as JSON.
    Json(serde_json::Value),
    /// A userdata value whose type opted in via `Memoized::impl_memoize`.
    Memoized(Memoized),
}
107
108impl std::fmt::Debug for CacheValue {
109    fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
110        fmt.debug_struct("CacheValue").finish()
111    }
112}
113
114impl FromLua for CacheValue {
115    fn from_lua(value: mlua::Value, lua: &Lua) -> mlua::Result<Self> {
116        match value {
117            mlua::Value::UserData(ud) => {
118                let mt = ud.metatable()?;
119                let func: Function = mt.get("__memoize")?;
120                let m: Memoized = func.call(mlua::Value::UserData(ud))?;
121                Ok(Self::Memoized(m))
122            }
123            mlua::Value::Table(tbl) => {
124                let mut map = HashMap::new();
125                for pair in tbl.pairs::<mlua::Value, mlua::Value>() {
126                    let (key, value) = pair?;
127                    let key = match key {
128                        mlua::Value::Integer(n) => MapKey::Integer(n),
129                        mlua::Value::String(n) => MapKey::String(n.as_bytes().to_vec()),
130                        _ => {
131                            return Err(anyhow::anyhow!(
132                                "table key {key:?} cannot be used as a key in a memoizable table"
133                            ))
134                            .map_err(any_err)
135                        }
136                    };
137                    let value = CacheValue::from_lua(value, lua)?;
138                    map.insert(key, value);
139                }
140                Ok(Self::Table(map.into()))
141            }
142            _ => Ok(Self::Json(from_lua_value(lua, value)?)),
143        }
144    }
145}
146
147impl IntoLua for CacheValue {
148    fn into_lua(self, lua: &Lua) -> mlua::Result<mlua::Value> {
149        self.as_lua(lua)
150    }
151}
152
153impl CacheValue {
154    pub fn as_lua(&self, lua: &Lua) -> mlua::Result<mlua::Value> {
155        match self {
156            Self::Json(j) => lua.to_value_with(j, serialize_options()),
157            Self::Memoized(m) => (m.to_value)(lua),
158            Self::Table(m) => Ok(mlua::Value::UserData(
159                lua.create_userdata(MemoizedTable::Shared(m.clone()))?,
160            )),
161        }
162    }
163}
164
/// MemoizedTable is a helper type that is returned to represent
/// cached table values.  We'll return the Shared variant by
/// default as that presents the cheapest way to return the cached
/// data--only a clone of the underlying Arc is required to return
/// the value.
///
/// This type implements __index, __newindex, __len, and __pairs
/// metamethods which allow reading, writing and iterating the table.
///
/// Writing to the table via __newindex will "unshare" the table in
/// a similar manner to the Cow type, creating a mutable copy of the top
/// level of the table. Nested table values in that copy remain shared
/// with the cached data.
enum MemoizedTable {
    /// A view that shares the map stored in the cache.
    Shared(Arc<HashMap<MapKey, CacheValue>>),
    /// A private copy of the top-level map, created on the first write.
    Mut(HashMap<MapKey, CacheValue>),
}
181
182impl MemoizedTable {
183    /// Get a reference to the table, facilitating get() and iter(),
184    /// regardless of whether we are Shared or Mut.
185    fn table(&self) -> &HashMap<MapKey, CacheValue> {
186        match self {
187            Self::Shared(s) => s,
188            Self::Mut(s) => s,
189        }
190    }
191
192    /// Transform Shared -> Mut
193    fn unshare(&mut self) -> &mut HashMap<MapKey, CacheValue> {
194        if let Self::Shared(t) = self {
195            *self = Self::Mut(t.iter().map(|(k, v)| (k.clone(), v.clone())).collect());
196        }
197
198        match self {
199            Self::Shared(_) => unreachable!(),
200            Self::Mut(map) => map,
201        }
202    }
203}
204
205impl UserData for MemoizedTable {
206    fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
207        // Index allows reading fields of the table
208        methods.add_meta_method(MetaMethod::Index, move |lua, this, key: mlua::Value| {
209            match MapKey::from_lua(key) {
210                Some(key) => match this.table().get(&key) {
211                    Some(value) => value.as_lua(lua),
212                    None => Ok(mlua::Value::Nil),
213                },
214                None => Ok(mlua::Value::Nil),
215            }
216        });
217
218        // NewIndex allows writing fields of the table
219        methods.add_meta_method_mut(
220            MetaMethod::NewIndex,
221            move |lua, this, (key, value): (mlua::Value, mlua::Value)| match MapKey::from_lua(key) {
222                Some(key) => {
223                    let value = CacheValue::from_lua(value, lua)?;
224                    this.unshare().insert(key, value);
225                    Ok(())
226                }
227                None => Err(mlua::Error::external(
228                    "invalid key type while trying to call __newindex and assign a value",
229                )),
230            },
231        );
232        methods.add_meta_method(MetaMethod::Len, move |_lua, this, ()| {
233            Ok(this.table().len())
234        });
235
236        // Pairs iterates the keys of the table.
237        // We use add_meta_function rather than add_meta_method here
238        // because we need to return `this` as the "state" parameter
239        // for use in a generic-for statement
240        methods.add_meta_function(MetaMethod::Pairs, move |lua, this: mlua::Value| {
241            // Maintain our own local idea of the control variable,
242            // as it is much cheaper and simpler to iterate based
243            // on skipping than to keep comparing keys
244            let mut idx = 0;
245
246            let iter_func =
247                lua.create_function_mut(
248                    move |lua, (state, _control): (UserDataRef<MemoizedTable>, mlua::Value)| {
249                        match state.table().iter().nth(idx) {
250                            Some((key, value)) => {
251                                idx += 1;
252                                let key = key.clone().as_lua(lua)?;
253                                let value = value.as_lua(lua)?;
254                                Ok((key, value))
255                            }
256                            None => Ok((mlua::Value::Nil, mlua::Value::Nil)),
257                        }
258                    },
259                )?;
260
261            // Return the iterator, state and control values.
262            // The state and control will be passed back into iter_func
263            // as the for-loop iterates.
264            // Control is Nil here because we track our own idx
265            // value in the iter_func closure.
266            Ok((mlua::Value::Function(iter_func), this, mlua::Value::Nil))
267        });
268    }
269}
270
/// The cached result of one call to a memoized function.
#[derive(Clone, Debug)]
enum CacheEntry {
    /// The function returned no values.
    Null,
    /// The function returned exactly one value.
    Single(CacheValue),
    /// The function returned more than one value. `to_value` hands these
    /// back to Lua packed into a single sequence table, not as multiple
    /// return values.
    Multi(Vec<CacheValue>),
}
277
278impl CacheEntry {
279    fn to_value(&self, lua: &Lua) -> mlua::Result<mlua::Value> {
280        match self {
281            Self::Null => Ok(mlua::Value::Nil),
282            Self::Single(value) => value.as_lua(lua),
283            Self::Multi(values) => {
284                let mut result = vec![];
285                for v in values {
286                    result.push(v.as_lua(lua)?);
287                }
288                result.into_lua(lua)
289            }
290        }
291    }
292
293    fn from_multi_value(lua: &Lua, multi: MultiValue) -> mlua::Result<Self> {
294        let mut values = multi.into_vec();
295        if values.is_empty() {
296            Ok(Self::Null)
297        } else if values.len() == 1 {
298            Ok(Self::Single(CacheValue::from_lua(
299                values.pop().unwrap(),
300                lua,
301            )?))
302        } else {
303            let mut cvalues = vec![];
304            for v in values.into_iter() {
305                cvalues.push(CacheValue::from_lua(v, lua)?);
306            }
307            Ok(Self::Multi(cvalues))
308        }
309    }
310}
311
/// A registered memoize cache, together with the parameters it was created with.
struct MemoizeCache {
    /// The parameters used to create `cache`. `kumo.memoize` compares them
    /// against a later call with the same name to decide whether to replace it.
    params: MemoizeParams,
    cache: Arc<LruCacheWithTtl<CacheKey, CacheEntry>>,
}

/// All memoize caches, keyed by `MemoizeParams::name`.
static CACHES: LazyLock<DashMap<String, MemoizeCache>> = LazyLock::new(DashMap::new);

/// Cache key: the config epoch (only when epoch invalidation applies)
/// paired with the JSON serialization of the call's arguments.
type CacheKey = (Option<ConfigEpoch>, String);
320
321fn get_cache_by_name(
322    name: &str,
323) -> Option<(Arc<LruCacheWithTtl<CacheKey, CacheEntry>>, Duration, bool)> {
324    CACHES.get(name).map(|item| {
325        (
326            item.cache.clone(),
327            item.params.ttl,
328            item.params.invalidate_with_epoch,
329        )
330    })
331}
332
333static CACHE_LOOKUP: LazyLock<CounterVec> = LazyLock::new(|| {
334    prometheus::register_counter_vec!(
335        "memoize_cache_lookup_count",
336        "how many times a memoize cache lookup was initiated for a given cache",
337        &["cache_name"]
338    )
339    .unwrap()
340});
341static CACHE_HIT: LazyLock<CounterVec> = LazyLock::new(|| {
342    prometheus::register_counter_vec!(
343        "memoize_cache_hit_count",
344        "how many times a memoize cache lookup was a hit for a given cache",
345        &["cache_name"]
346    )
347    .unwrap()
348});
349static CACHE_MISS: LazyLock<CounterVec> = LazyLock::new(|| {
350    prometheus::register_counter_vec!(
351        "memoize_cache_miss_count",
352        "how many times a memoize cache lookup was a miss for a given cache",
353        &["cache_name"]
354    )
355    .unwrap()
356});
357static CACHE_POPULATED: LazyLock<CounterVec> = LazyLock::new(|| {
358    prometheus::register_counter_vec!(
359        "memoize_cache_populated_count",
360        "how many times a memoize cache lookup resulted in performing the work to populate the entry",
361        &["cache_name"]
362    )
363    .unwrap()
364});
365
366fn multi_value_to_json_value(lua: &Lua, multi: MultiValue) -> mlua::Result<serde_json::Value> {
367    let mut values = multi.into_vec();
368    if values.is_empty() {
369        Ok(serde_json::Value::Null)
370    } else if values.len() == 1 {
371        from_lua_value(lua, values.pop().unwrap())
372    } else {
373        let mut jvalues = vec![];
374        for v in values.into_iter() {
375            jvalues.push(from_lua_value(lua, v)?);
376        }
377        Ok(serde_json::Value::Array(jvalues))
378    }
379}
380
381pub fn register(lua: &Lua) -> anyhow::Result<()> {
382    let kumo_mod = get_or_create_module(lua, "kumo")?;
383
384    kumo_mod.set(
385        "memoize",
386        lua.create_function(move |lua, (func, params): (mlua::Function, mlua::Value)| {
387            let params: MemoizeParams = from_lua_value(lua, params)?;
388
389            let cache_name = params.name.to_string();
390
391            if !lruttl::is_name_available(&cache_name) {
392                return Err(mlua::Error::external(format!(
393                    "cannot use name `{cache_name}` for a memoize cache, \
394                    as it collides with a built-in cache. \
395                    Suggestion: prefix your cache name with `user.` to \
396                    avoid conflicts with current and future caches."
397                )));
398            }
399
400            CACHES.remove_if(&params.name, |_k, item| {
401                let changed = item.params != params;
402                if changed {
403                    tracing::trace!("memoize parameters changed, replacing old cache {params:?}");
404                }
405                changed
406            });
407            CACHES.entry(cache_name.to_string()).or_insert_with(|| {
408                let cache = LruCacheWithTtl::new(cache_name.clone(), params.capacity);
409                if let Some(duration) = params.populate_timeout {
410                    cache.set_sema_timeout(duration);
411                }
412                cache.set_allow_stale_reads(params.allow_stale_reads);
413
414                MemoizeCache {
415                    params: params.clone(),
416                    cache: Arc::new(cache),
417                }
418            });
419
420            let lookup_counter = CACHE_LOOKUP
421                .get_metric_with_label_values(&[&cache_name])
422                .map_err(any_err)?;
423            let hit_counter = CACHE_HIT
424                .get_metric_with_label_values(&[&cache_name])
425                .map_err(any_err)?;
426            let miss_counter = CACHE_MISS
427                .get_metric_with_label_values(&[&cache_name])
428                .map_err(any_err)?;
429            let populate_counter = CACHE_POPULATED
430                .get_metric_with_label_values(&[&cache_name])
431                .map_err(any_err)?;
432            let retry_on_populate_timeout = params.retry_on_populate_timeout;
433            let allow_stale_reads = params.allow_stale_reads;
434
435            let func_ref = lua.create_registry_value(func)?;
436
437            lua.create_async_function(move |lua, params: MultiValue| {
438                let cache_name = cache_name.clone();
439                let func = lua.registry_value::<mlua::Function>(&func_ref);
440                let lookup_counter = lookup_counter.clone();
441                let hit_counter = hit_counter.clone();
442                let miss_counter = miss_counter.clone();
443                let populate_counter = populate_counter.clone();
444                async move {
445                    lookup_counter.inc();
446                    let key = multi_value_to_json_value(&lua, params.clone())?;
447
448                    let func = func?;
449
450                    let mut last_failure = None;
451
452                    for _attempt in 0..3 {
453                        // We use the epoch from the start of the lookup as part
454                        // of the cache key. If the epoch changes while we are in
455                        // the middle of computing this value then subsequent calls
456                        // through to the cached function will see the newer epoch
457                        // and encounter a cache miss. This prevents a race condition
458                        // poisoning the cache with a stale value during an epoch
459                        // bump. The caller will still observe the stale value, so
460                        // ultimately should have some accommodation for detecting
461                        // the epoch change and retrying their call through here,
462                        // if it is important to not see a stale value.
463                        let epoch_at_start = get_current_epoch();
464
465                        let (cache, ttl, invalidate_with_epoch) = get_cache_by_name(&cache_name)
466                            .ok_or_else(|| anyhow::anyhow!("cache is somehow undefined!?"))
467                            .map_err(any_err)?;
468
469                        let epoch_key = if invalidate_with_epoch && !allow_stale_reads {
470                            Some(epoch_at_start)
471                        } else {
472                            None
473                        };
474                        let key = serde_json::to_string(&key).map_err(any_err)?;
475                        let key = (epoch_key, key);
476
477                        let value_result = cache
478                            .get_or_try_insert(&key, |_| ttl, async {
479                                tracing::trace!("populate {key:?}");
480                                populate_counter.inc();
481                                let result: MultiValue =
482                                    (func.clone()).call_async(params.clone()).await?;
483                                CacheEntry::from_multi_value(&lua, result.clone())
484                            })
485                            .await;
486
487                        match value_result {
488                            Ok(lookup) => {
489                                if lookup.is_fresh {
490                                    miss_counter.inc();
491                                } else {
492                                    hit_counter.inc();
493                                }
494                                return lookup.item.to_value(&lua);
495                            }
496                            Err(err) => {
497                                tracing::error!("{cache_name} {key:?} failed: {err:#}");
498                                let error = format!("{err:#}");
499                                if !retry_on_populate_timeout {
500                                    return Err(mlua::Error::external(error));
501                                }
502                                last_failure.replace(error);
503                            }
504                        }
505                    }
506
507                    Err(mlua::Error::external(
508                        last_failure.expect("last_failure to always be set in loop above"),
509                    ))
510                }
511            })
512        })?,
513    )?;
514
515    Ok(())
516}
517
#[cfg(test)]
mod test {
    use super::*;
    use mlua::UserDataMethods;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Memoizing a plain Lua-callable function: repeated calls within the TTL
    /// invoke the underlying function once, and a call after the TTL has
    /// elapsed invokes it again.
    #[tokio::test]
    async fn test_memoize() {
        let lua = Lua::new();
        register(&lua).unwrap();

        // Number of times the underlying (uncached) function has run.
        let call_count = Arc::new(AtomicUsize::new(0));

        let globals = lua.globals();
        let counter = Arc::clone(&call_count);
        globals
            .set(
                "do_thing",
                // Returns the pre-increment count: 0 on the first call, 1 on the second, ...
                lua.create_function(move |_lua, _: ()| {
                    let count = counter.fetch_add(1, Ordering::SeqCst);
                    Ok(count)
                })
                .unwrap(),
            )
            .unwrap();

        // All three calls should return the cached 0, so the sum is 0.
        let result: usize = lua
            .load(
                r#"
            local kumo = require 'kumo';
            -- make cached_do_thing a global for use in the expiry test below
            cached_do_thing = kumo.memoize(do_thing, {
                ttl = "1s",
                capacity = 4,
                name = "test_memoize_do_thing",
            })
            return cached_do_thing() + cached_do_thing() + cached_do_thing()
        "#,
            )
            .eval_async()
            .await
            .unwrap();

        assert_eq!(result, 0);
        assert_eq!(call_count.load(Ordering::SeqCst), 1);

        // And confirm that expiry works
        tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;

        // After the 1s TTL has lapsed the entry is repopulated, so
        // do_thing runs a second time and returns 1.
        // NOTE(review): this uses synchronous `eval()` on an async function;
        // it presumably relies on the call completing without yielding. Confirm.
        let result: usize = lua
            .load(
                r#"
            return cached_do_thing()
        "#,
            )
            .eval()
            .unwrap();

        assert_eq!(result, 1);
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }

    /// Memoizing a function that returns a Rust UserData type which opted in
    /// via `Memoized::impl_memoize`: the constructor runs once, and every
    /// cached call yields a usable userdata with the original value.
    #[tokio::test]
    async fn test_memoize_rust() {
        let lua = Lua::new();
        register(&lua).unwrap();

        // Number of times `make_foo` has actually constructed a Foo.
        let call_count = Arc::new(AtomicUsize::new(0));

        #[derive(Clone)]
        struct Foo {
            value: usize,
        }

        impl UserData for Foo {
            fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
                Memoized::impl_memoize(methods);
                methods.add_method("get_value", move |_lua, this, _: ()| Ok(this.value));
            }
        }

        let globals = lua.globals();
        let counter = Arc::clone(&call_count);
        globals
            .set(
                "make_foo",
                lua.create_function(move |_lua, _: ()| {
                    let count = counter.fetch_add(1, Ordering::SeqCst);
                    Ok(Foo { value: count })
                })
                .unwrap(),
            )
            .unwrap();

        // Each call returns a clone of the cached Foo { value: 0 }, so the sum is 0.
        let result: usize = lua
            .load(
                r#"
            local kumo = require 'kumo';
            local cached_make_foo = kumo.memoize(make_foo, {
                ttl = "1s",
                capacity = 4,
                name = "test_memoize_make_foo",
            })
            return cached_make_foo():get_value() +
                   cached_make_foo():get_value() +
                   cached_make_foo():get_value()
        "#,
            )
            .eval()
            .unwrap();

        assert_eq!(result, 0);
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }
}