lruttl/lib.rs

use crate::metrics::*;
use dashmap::DashMap;
use kumo_prometheus::prometheus::{IntCounter, IntGauge};
use kumo_server_memory::subscribe_to_memory_status_changes_async;
use parking_lot::Mutex;
use scopeguard::defer;
use std::borrow::Borrow;
use std::collections::HashMap;
use std::fmt::Debug;
use std::future::Future;
use std::hash::Hash;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::{Arc, LazyLock, Weak};
use tokio::sync::Semaphore;
use tokio::time::{timeout_at, Duration, Instant};
pub use {linkme, pastey as paste};

mod metrics;

static CACHES: LazyLock<Mutex<Vec<Weak<dyn CachePurger + Send + Sync>>>> =
    LazyLock::new(Mutex::default);

struct Inner<K: Clone + Hash + Eq + Debug, V: Clone + Send + Sync + Debug> {
    name: String,
    tick: AtomicUsize,
    capacity: AtomicUsize,
    allow_stale_reads: AtomicBool,
    cache: DashMap<K, Item<V>>,
    lru_samples: AtomicUsize,
    sema_timeout_milliseconds: AtomicUsize,
    lookup_counter: IntCounter,
    evict_counter: IntCounter,
    expire_counter: IntCounter,
    hit_counter: IntCounter,
    miss_counter: IntCounter,
    populate_counter: IntCounter,
    insert_counter: IntCounter,
    stale_counter: IntCounter,
    error_counter: IntCounter,
    wait_gauge: IntGauge,
    size_gauge: IntGauge,
}

impl<
        K: Clone + Debug + Send + Sync + Hash + Eq + 'static,
        V: Clone + Debug + Send + Sync + 'static,
    > Inner<K, V>
{
    pub fn clear(&self) -> usize {
        let num_entries = self.cache.len();

        // We don't simply clear all elements here, as any pending
        // items will be trying to wait to coordinate; we need
        // to aggressively close the semaphore and wake them all up
        // before we remove those entries.
        self.cache.retain(|_k, item| {
            if let ItemState::Pending(sema) = &item.item {
                // Force everyone to wake up and error out
                sema.close();
            }
            false
        });

        self.size_gauge.set(self.cache.len() as i64);
        num_entries
    }
    /// Evict up to `target` entries.
    ///
    /// We use a probabilistic approach to the LRU, because
    /// it is challenging to safely thread the classic doubly-linked-list
    /// through dashmap.
    ///
    /// target is bounded to half of the number of selected samples, in
    /// order to ensure that we don't randomly pick the newest element
    /// from the set when under pressure.
    ///
    /// Redis uses a similar technique for its LRU, as described
    /// in <https://redis.io/docs/latest/develop/reference/eviction/#apx-lru>,
    /// which suggests that sampling 10 keys at random and then comparing
    /// their recency yields a reasonably close approximation of the
    /// 100% precise LRU.
    ///
    /// Since we also support TTLs, we'll just go ahead and remove
    /// any expired keys that show up in the sampled set.
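    ///
    /// A compressed sketch of the idea (illustrative pseudo-Rust,
    /// not the exact code below):
    ///
    /// ```ignore
    /// // Sample up to `lru_samples` random (key, last_tick) pairs,
    /// // dropping any expired entries encountered along the way, then:
    /// samples.sort_by_key(|(_key, tick)| *tick); // oldest first
    /// for (key, tick) in samples.into_iter().take(target) {
    ///     // Evict only if the entry was not touched since we sampled it
    ///     cache.remove_if(&key, |_, item| item.last_tick() == tick);
    /// }
    /// ```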
    pub fn evict_some(&self, target: usize) -> usize {
        let now = Instant::now();

        // Approximate (since it could change immediately after reading)
        // cache size
        let cache_size = self.cache.len();
        // How many keys to sample
        let num_samples = self.lru_samples.load(Ordering::Relaxed).min(cache_size);

        // a list of keys which have expired
        let mut expired_keys = vec![];
        // a random selection of up to num_samples (key, tick) tuples
        let mut samples = vec![];

        // Pick some random keys.
        // The rand crate has some helpers for working with iterators,
        // but they appear to copy many elements into an internal buffer
        // in order to make a selection, and we want to avoid directly
        // considering every possible element because some users have
        // very large capacity caches.
        //
        // The approach taken here is to produce a random list of iterator
        // offsets so that we can skim across the iterator in a single
        // pass and pull out a random selection of elements.
        // The sample function provides a randomized list of indices that
        // we can use for this; we need to sort it first, but the cost
        // should be reasonably low as num_samples should be ~10 or so
        // in the most common configuration.
        {
            let mut rng = rand::thread_rng();
            let mut indices =
                rand::seq::index::sample(&mut rng, cache_size, num_samples).into_vec();
            indices.sort();
            let mut iter = self.cache.iter();
            let mut current_idx = 0;

            /// Advance an iterator by skip_amount.
            /// Ideally we'd use Iterator::advance_by for this, but at the
            /// time of writing that method is nightly only.
            /// Note that it also uses next() internally anyway.
            fn advance_by(iter: &mut impl Iterator, skip_amount: usize) {
                for _ in 0..skip_amount {
                    if iter.next().is_none() {
                        return;
                    }
                }
            }

            for idx in indices {
                // idx is the index we want to be on; we'll need to skip ahead
                // by some number of slots based on the current one. skip_amount
                // is that number.
                let skip_amount = idx - current_idx;
                advance_by(&mut iter, skip_amount);

                match iter.next() {
                    Some(map_entry) => {
                        current_idx = idx + 1;
                        let item = map_entry.value();
                        match &item.item {
                            ItemState::Pending(_) | ItemState::Refreshing { .. } => {
                                // Cannot evict a pending lookup
                            }
                            ItemState::Present(_) | ItemState::Failed(_) => {
                                if now >= item.expiration {
                                    expired_keys.push(map_entry.key().clone());
                                } else {
                                    let last_tick = item.last_tick.load(Ordering::Relaxed);
                                    samples.push((map_entry.key().clone(), last_tick));
                                }
                            }
                        }
                    }
                    None => {
                        break;
                    }
                }
            }
        }

        let mut num_removed = 0;
        for key in expired_keys {
            // Sanity check that it is still expired before removing it,
            // because it would be a shame to remove it if another actor
            // has just updated it
            let removed = self
                .cache
                .remove_if(&key, |_k, entry| now >= entry.expiration)
                .is_some();
            if removed {
                tracing::trace!("{} expired {key:?}", self.name);
                num_removed += 1;
                self.expire_counter.inc();
            }
        }

        // Since we're picking random elements, we want to ensure that
        // we never pick the newest element from the set to evict because
        // that is likely the wrong choice. We need enough samples to
        // know that the lowest number we picked is representative
        // of the eldest element in the map overall.
        // We limit ourselves to half of the number of selected samples.
        let target = target.min(samples.len() / 2).max(1);

        // If we met our target, skip the extra work below
        if num_removed >= target {
            self.size_gauge.set(self.cache.len() as i64);
            tracing::trace!("{} expired {num_removed} of target {target}", self.name);
            return num_removed;
        }

        // Sort by ascending tick, which is equivalent to having the
        // LRU within that set towards the front of the vec
        samples.sort_by(|(_ka, tick_a), (_kb, tick_b)| tick_a.cmp(tick_b));

        for (key, tick) in samples {
            // Sanity check that the tick value is the same as we expect.
            // If it has changed since we sampled it, then that element
            // is no longer a good candidate for LRU eviction.
            if self
                .cache
                .remove_if(&key, |_k, item| {
                    item.last_tick.load(Ordering::Relaxed) == tick
                })
                .is_some()
            {
                tracing::debug!("{} evicted {key:?}", self.name);
                num_removed += 1;
                self.evict_counter.inc();
                self.size_gauge.set(self.cache.len() as i64);
                if num_removed >= target {
                    return num_removed;
                }
            }
        }

        if num_removed == 0 {
            tracing::debug!(
                "{} did not find anything to evict, target was {target}",
                self.name
            );
        }

        tracing::trace!("{} removed {num_removed} of target {target}", self.name);

        num_removed
    }

    /// Potentially make some progress to get back under
    /// budget on the cache capacity
    pub fn maybe_evict(&self) -> usize {
        let cache_size = self.cache.len();
        let capacity = self.capacity.load(Ordering::Relaxed);
        if cache_size > capacity {
            self.evict_some(cache_size - capacity)
        } else {
            0
        }
    }
}

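/// A type-erased view of a cache, allowing the global CACHES list
/// to purge, expire, and resize caches without knowing their
/// concrete key/value types.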
trait CachePurger {
    fn name(&self) -> &str;
    fn purge(&self) -> usize;
    fn process_expirations(&self) -> usize;
    fn update_capacity(&self, capacity: usize);
}

impl<
        K: Clone + Debug + Send + Sync + Hash + Eq + 'static,
        V: Clone + Debug + Send + Sync + 'static,
    > CachePurger for Inner<K, V>
{
    fn name(&self) -> &str {
        &self.name
    }
    fn purge(&self) -> usize {
        self.clear()
    }
    fn process_expirations(&self) -> usize {
        let now = Instant::now();
        let mut expired_keys = vec![];
        for map_entry in self.cache.iter() {
            let item = map_entry.value();
            match &item.item {
                ItemState::Pending(_) | ItemState::Refreshing { .. } => {
                    // Cannot evict a pending lookup
                }
                ItemState::Present(_) | ItemState::Failed(_) => {
                    if now >= item.expiration {
                        expired_keys.push(map_entry.key().clone());
                    }
                }
            }
        }

        let mut num_removed = 0;
        for key in expired_keys {
            // Sanity check that it is still expired before removing it,
            // because it would be a shame to remove it if another actor
            // has just updated it
            let removed = self
                .cache
                .remove_if(&key, |_k, entry| now >= entry.expiration)
                .is_some();
            if removed {
                num_removed += 1;
                self.expire_counter.inc();
                self.size_gauge.set(self.cache.len() as i64);
            }
        }

        num_removed + self.maybe_evict()
    }

    fn update_capacity(&self, capacity: usize) {
        self.capacity.store(capacity, Ordering::Relaxed);
        // Bring it within capacity.
        // At the time of writing this is a bit half-hearted,
        // but we'll eventually trim down via ongoing process_expirations()
        // calls
        self.process_expirations();
    }
}

fn all_caches() -> Vec<Arc<dyn CachePurger + Send + Sync>> {
    let mut result = vec![];
    let mut caches = CACHES.lock();
    caches.retain(|entry| match entry.upgrade() {
        Some(purger) => {
            result.push(purger);
            true
        }
        None => false,
    });
    result
}

pub fn purge_all_caches() {
    let purgers = all_caches();

    tracing::error!("purging {} caches", purgers.len());
    for purger in purgers {
        let name = purger.name();
        let num_entries = purger.purge();
        tracing::error!("cleared {num_entries} entries from cache {name}");
    }
}

async fn prune_expired_caches() {
    loop {
        tokio::time::sleep(tokio::time::Duration::from_secs(30)).await;
        let purgers = all_caches();

        for p in purgers {
            let n = p.process_expirations();
            if n > 0 {
                tracing::debug!("expired {n} entries from cache {}", p.name());
            }
        }
    }
}

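/// The registry of cache initializers; declare_cache! adds one entry
/// per declared cache, and vivify() runs them all to instantiate the
/// pre-defined caches.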
#[linkme::distributed_slice]
pub static LRUTTL_VIVIFY: [fn() -> CacheDefinition];

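/// Helper for declare_cache!: collapses zero or more #[doc] attribute
/// strings into an Option<&'static str>, trimming surrounding whitespace.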
#[macro_export]
macro_rules! optional_doc {
    ($doc:expr) => {
        Some($doc.trim())
    };
    ($($doc:expr)+) => {
        Some(concat!($($doc,)+).trim())
    };
    () => {
        None
    };
}

/// Declare a cache as a static, and link it into the list of possible
/// pre-defined caches.
///
/// Due to a limitation in implementation details, you must also add
/// `linkme.workspace = true` to the manifest of the crate where you
/// use this macro.
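///
/// A minimal usage sketch (the cache name, types, and capacity here
/// are illustrative):
///
/// ```ignore
/// declare_cache! {
///     /// Caches resolved results by domain name
///     static DOMAIN_CACHE: LruCacheWithTtl<String, u64>::new("domain_cache", 1024);
/// }
/// ```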
#[macro_export]
macro_rules! declare_cache {
    (
     $(#[doc = $doc:expr])*
     $vis:vis
        static $sym:ident:
        LruCacheWithTtl<$key:ty, $value:ty>::new($name:expr, $capacity:expr);
    ) => {
        $(#[doc = $doc])*
        $vis static $sym: ::std::sync::LazyLock<$crate::LruCacheWithTtl<$key, $value>> =
            ::std::sync::LazyLock::new(
                || $crate::LruCacheWithTtl::new($name, $capacity));

        // Link into LRUTTL_VIVIFY
        $crate::paste::paste! {
            #[linkme::distributed_slice($crate::LRUTTL_VIVIFY)]
            static [<VIVIFY_ $sym>]: fn() -> $crate::CacheDefinition = || {
                ::std::sync::LazyLock::force(&$sym);
                $crate::CacheDefinition {
                    name: $name,
                    capacity: $capacity,
                    doc: $crate::optional_doc!($($doc)*),
                }
            };
        }
    };
}

/// Ensure that all caches declared via declare_cache!
/// have been instantiated.
fn vivify() {
    LazyLock::force(&PREDEFINED_CACHES);
}

fn vivify_impl() -> HashMap<&'static str, CacheDefinition> {
    let mut map = HashMap::new();

    for vivify_func in LRUTTL_VIVIFY {
        let definition = vivify_func();
        assert!(
            !map.contains_key(definition.name),
            "duplicate cache name {}",
            definition.name
        );
        map.insert(definition.name, definition);
    }

    map
}

#[derive(serde::Serialize)]
pub struct CacheDefinition {
    pub name: &'static str,
    pub capacity: usize,
    pub doc: Option<&'static str>,
}

static PREDEFINED_CACHES: LazyLock<HashMap<&'static str, CacheDefinition>> =
    LazyLock::new(vivify_impl);

pub fn get_definitions() -> Vec<&'static CacheDefinition> {
    let mut defs = PREDEFINED_CACHES.values().collect::<Vec<_>>();
    defs.sort_by(|a, b| a.name.cmp(&b.name));
    defs
}

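/// Returns true if `name` is not already claimed by one of the
/// pre-defined caches.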
pub fn is_name_available(name: &str) -> bool {
    !PREDEFINED_CACHES.contains_key(name)
}

/// Update the capacity value for a pre-defined cache.
/// Returns false if no cache with that name exists or has
/// been instantiated.
pub fn set_cache_capacity(name: &str, capacity: usize) -> bool {
    if !PREDEFINED_CACHES.contains_key(name) {
        return false;
    }
    let caches = all_caches();
    match caches.iter().find(|p| p.name() == name) {
        Some(p) => {
            p.update_capacity(capacity);
            true
        }
        None => false,
    }
}

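/// Instantiate all pre-defined caches and spawn the background tasks
/// that purge caches under memory pressure and prune expired entries.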
pub fn spawn_memory_monitor() {
    vivify();
    tokio::spawn(purge_caches_on_memory_shortage());
    tokio::spawn(prune_expired_caches());
}

async fn purge_caches_on_memory_shortage() {
    tracing::debug!("starting memory monitor");
    let mut memory_status = subscribe_to_memory_status_changes_async().await;
    while let Ok(()) = memory_status.changed().await {
        if kumo_server_memory::get_headroom() == 0 {
            purge_all_caches();

            // Wait a little bit so that we can debounce
            // in the case where we're riding the cusp of
            // the limit and would thrash the caches
            tokio::time::sleep(tokio::time::Duration::from_secs(30)).await;
        }
    }
}

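/// The lifecycle state of a cache entry: a resolved value, a lookup
/// in flight (gated by a semaphore), a cached failure, or a stale
/// value that is being refreshed while readers may still observe it.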
#[derive(Debug, Clone)]
enum ItemState<V>
where
    V: Send,
    V: Sync,
{
    Present(V),
    Pending(Arc<Semaphore>),
    Failed(Arc<anyhow::Error>),
    Refreshing {
        stale_value: V,
        pending: Arc<Semaphore>,
    },
}

#[derive(Debug)]
struct Item<V>
where
    V: Send,
    V: Sync,
{
    item: ItemState<V>,
    expiration: Instant,
    last_tick: AtomicUsize,
}

impl<V: Clone + Send + Sync> Clone for Item<V> {
    fn clone(&self) -> Self {
        Self {
            item: self.item.clone(),
            expiration: self.expiration,
            last_tick: self.last_tick.load(Ordering::Relaxed).into(),
        }
    }
}

#[derive(Debug)]
pub struct ItemLookup<V: Debug> {
    /// A copy of the item
    pub item: V,
    /// If true, the get_or_try_insert operation populated the entry;
    /// the operation was a cache miss
    pub is_fresh: bool,
    /// The instant at which this entry will expire
    pub expiration: Instant,
}

pub struct LruCacheWithTtl<K: Clone + Debug + Hash + Eq, V: Clone + Debug + Send + Sync> {
    inner: Arc<Inner<K, V>>,
}

impl<
        K: Clone + Debug + Hash + Eq + Send + Sync + 'static,
        V: Clone + Debug + Send + Sync + 'static,
    > LruCacheWithTtl<K, V>
{
    pub fn new<S: Into<String>>(name: S, capacity: usize) -> Self {
        let name = name.into();
        let cache = DashMap::new();

        let lookup_counter = CACHE_LOOKUP
            .get_metric_with_label_values(&[&name])
            .expect("failed to get counter");
        let hit_counter = CACHE_HIT
            .get_metric_with_label_values(&[&name])
            .expect("failed to get counter");
        let stale_counter = CACHE_STALE
            .get_metric_with_label_values(&[&name])
            .expect("failed to get counter");
        let evict_counter = CACHE_EVICT
            .get_metric_with_label_values(&[&name])
            .expect("failed to get counter");
        let expire_counter = CACHE_EXPIRE
            .get_metric_with_label_values(&[&name])
            .expect("failed to get counter");
        let miss_counter = CACHE_MISS
            .get_metric_with_label_values(&[&name])
            .expect("failed to get counter");
        let populate_counter = CACHE_POPULATED
            .get_metric_with_label_values(&[&name])
            .expect("failed to get counter");
        let insert_counter = CACHE_INSERT
            .get_metric_with_label_values(&[&name])
            .expect("failed to get counter");
        let error_counter = CACHE_ERROR
            .get_metric_with_label_values(&[&name])
            .expect("failed to get counter");
        let wait_gauge = CACHE_WAIT
            .get_metric_with_label_values(&[&name])
            .expect("failed to get gauge");
        let size_gauge = CACHE_SIZE
            .get_metric_with_label_values(&[&name])
            .expect("failed to get gauge");

        let inner = Arc::new(Inner {
            name,
            cache,
            tick: AtomicUsize::new(0),
            allow_stale_reads: AtomicBool::new(false),
            capacity: AtomicUsize::new(capacity),
            lru_samples: AtomicUsize::new(10),
            sema_timeout_milliseconds: AtomicUsize::new(120_000),
            lookup_counter,
            evict_counter,
            expire_counter,
            hit_counter,
            miss_counter,
            populate_counter,
            error_counter,
            wait_gauge,
            insert_counter,
            stale_counter,
            size_gauge,
        });

        // Register with the global list of caches using a weak reference.
        // We need to "erase" the K/V types in order to do that, so we
        // use the CachePurger trait for this purpose.
        {
            let generic: Arc<dyn CachePurger + Send + Sync> = inner.clone();
            CACHES.lock().push(Arc::downgrade(&generic));
            tracing::debug!(
                "registered cache {} with capacity {capacity}",
                generic.name()
            );
        }

        Self { inner }
    }

    fn allow_stale_reads(&self) -> bool {
        self.inner.allow_stale_reads.load(Ordering::Relaxed)
    }

    pub fn set_allow_stale_reads(&self, value: bool) {
        self.inner.allow_stale_reads.store(value, Ordering::Relaxed);
    }

    pub fn set_sema_timeout(&self, duration: Duration) {
        self.inner
            .sema_timeout_milliseconds
            .store(duration.as_millis() as usize, Ordering::Relaxed);
    }

    pub fn clear(&self) -> usize {
        self.inner.clear()
    }

    fn inc_tick(&self) -> usize {
        self.inner.tick.fetch_add(1, Ordering::Relaxed) + 1
    }

    fn update_tick(&self, item: &Item<V>) {
        let v = self.inc_tick();
        item.last_tick.store(v, Ordering::Relaxed);
    }

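    /// Look up `name`, returning the full ItemLookup metadata on a hit.
    /// Expired entries are treated as misses (and removed, unless stale
    /// reads are enabled); entries that are pending, refreshing, or
    /// cached failures are also reported as misses.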
    pub fn lookup<Q: ?Sized>(&self, name: &Q) -> Option<ItemLookup<V>>
    where
        K: Borrow<Q>,
        Q: Hash + Eq,
    {
        self.inner.lookup_counter.inc();
        match self.inner.cache.get_mut(name) {
            None => {
                self.inner.miss_counter.inc();
                None
            }
            Some(entry) => {
                match &entry.item {
                    ItemState::Present(item) => {
                        let now = Instant::now();
                        if now >= entry.expiration {
                            // Expired
                            if self.allow_stale_reads() {
                                // We don't furnish a result directly, but we
                                // also do not want to remove it from the map
                                // at this stage.
                                // We're assuming that lookup() is called only
                                // via get_or_try_insert when allow_stale_reads
                                // is enabled.
                                self.inner.miss_counter.inc();
                                return None;
                            }

                            // otherwise: remove it from the map.
                            // Take care to drop our ref first so that we don't
                            // self-deadlock
                            drop(entry);
                            if self
                                .inner
                                .cache
                                .remove_if(name, |_k, entry| now >= entry.expiration)
                                .is_some()
                            {
                                self.inner.expire_counter.inc();
                                self.inner.size_gauge.set(self.inner.cache.len() as i64);
                            }
                            self.inner.miss_counter.inc();
                            return None;
                        }
                        self.inner.hit_counter.inc();
                        self.update_tick(&entry);
                        Some(ItemLookup {
                            item: item.clone(),
                            expiration: entry.expiration,
                            is_fresh: false,
                        })
                    }
                    ItemState::Refreshing { .. } | ItemState::Pending(_) | ItemState::Failed(_) => {
                        self.inner.miss_counter.inc();
                        None
                    }
                }
            }
        }
    }

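    /// Like lookup(), but returns just the value on a hit.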
    pub fn get<Q: ?Sized>(&self, name: &Q) -> Option<V>
    where
        K: Borrow<Q>,
        Q: Hash + Eq,
    {
        self.lookup(name).map(|lookup| lookup.item)
    }

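    /// Unconditionally insert `item` with the given expiration,
    /// replacing any existing entry for `name`, then evict if we are
    /// over capacity. Returns a copy of the inserted value.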
    pub async fn insert(&self, name: K, item: V, expiration: Instant) -> V {
        self.inner.cache.insert(
            name,
            Item {
                item: ItemState::Present(item.clone()),
                expiration,
                last_tick: self.inc_tick().into(),
            },
        );

        self.inner.insert_counter.inc();
        self.inner.size_gauge.set(self.inner.cache.len() as i64);
        self.inner.maybe_evict();

        item
    }

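    /// Snapshot the state of the entry for `name`, creating a Pending
    /// placeholder if the entry is missing, and replacing closed
    /// semaphores or expired values with fresh Pending/Refreshing
    /// states as appropriate. Returns a clone of the state along with
    /// the entry's expiration.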
    fn clone_item_state(&self, name: &K) -> (ItemState<V>, Instant) {
        let mut is_new = false;
        let mut entry = self.inner.cache.entry(name.clone()).or_insert_with(|| {
            is_new = true;
            Item {
                item: ItemState::Pending(Arc::new(Semaphore::new(1))),
                expiration: Instant::now() + Duration::from_secs(60),
                last_tick: self.inc_tick().into(),
            }
        });

        match &entry.value().item {
            ItemState::Pending(sema) => {
                if sema.is_closed() {
                    entry.value_mut().item = ItemState::Pending(Arc::new(Semaphore::new(1)));
                }
            }
            ItemState::Refreshing {
                stale_value,
                pending,
            } => {
                if pending.is_closed() {
                    entry.value_mut().item = ItemState::Refreshing {
                        stale_value: stale_value.clone(),
                        pending: Arc::new(Semaphore::new(1)),
                    };
                }
            }
            ItemState::Present(item) => {
                let now = Instant::now();
                if now >= entry.expiration {
                    // Expired; we will need to fetch it
                    let pending = Arc::new(Semaphore::new(1));
                    if self.allow_stale_reads() {
                        entry.value_mut().item = ItemState::Refreshing {
                            stale_value: item.clone(),
                            pending,
                        };
                    } else {
                        entry.value_mut().item = ItemState::Pending(pending);
                    }
                }
            }
            ItemState::Failed(_) => {
                let now = Instant::now();
                if now >= entry.expiration {
                    // Expired; we will need to fetch it
                    entry.value_mut().item = ItemState::Pending(Arc::new(Semaphore::new(1)));
                }
            }
        }

        self.update_tick(&entry);
        let item = entry.value();
        let result = (item.item.clone(), entry.expiration);
        drop(entry);

        if is_new {
            self.inner.size_gauge.set(self.inner.cache.len() as i64);
            self.inner.maybe_evict();
        }

        result
    }

    /// Get an existing item, but if that item doesn't already exist,
    /// execute the future `fut` to provide a value that will be inserted and then
    /// returned. This is done atomically with respect to other callers.
    /// The TTL parameter is a function that can extract the TTL from the value type,
    /// or just return a constant TTL.
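    ///
    /// A minimal usage sketch (the key, TTL, and fetch future here are
    /// illustrative):
    ///
    /// ```ignore
    /// let lookup = CACHE
    ///     .get_or_try_insert(
    ///         &key,
    ///         |_value| Duration::from_secs(300), // constant TTL
    ///         async { fetch(&key).await },       // runs only on a miss
    ///     )
    ///     .await?;
    /// if lookup.is_fresh {
    ///     // this caller populated the entry
    /// }
    /// ```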
    pub async fn get_or_try_insert<E: Into<anyhow::Error>, TTL: FnOnce(&V) -> Duration>(
        &self,
        name: &K,
        ttl_func: TTL,
        fut: impl Future<Output = Result<V, E>>,
    ) -> Result<ItemLookup<V>, Arc<anyhow::Error>> {
        // Fast path avoids cloning the key
        if let Some(entry) = self.lookup(name) {
            return Ok(entry);
        }

        let timeout_duration = Duration::from_millis(
            self.inner.sema_timeout_milliseconds.load(Ordering::Relaxed) as u64,
        );
        let start = Instant::now();
        let deadline = start + timeout_duration;

        // Note: the lookup call increments lookup_counter and miss_counter
        'retry: loop {
            let (stale_value, sema) = match self.clone_item_state(name) {
                (ItemState::Present(item), expiration) => {
                    return Ok(ItemLookup {
                        item,
                        expiration,
                        is_fresh: false,
                    });
                }
                (ItemState::Failed(error), _) => {
                    return Err(error);
                }
                (
                    ItemState::Refreshing {
                        stale_value,
                        pending,
                    },
                    expiration,
                ) => (Some((stale_value, expiration)), pending),
                (ItemState::Pending(sema), _) => (None, sema),
            };

            let wait_result = {
                self.inner.wait_gauge.inc();
                defer! {
                    self.inner.wait_gauge.dec();
                }

                match timeout_at(deadline, sema.acquire_owned()).await {
                    Err(_) => {
                        if let Some((item, expiration)) = stale_value {
                            tracing::debug!(
                                "{} semaphore acquire for {name:?} timed out after \
                                {timeout_duration:?}, allowing stale value to satisfy the lookup",
                                self.inner.name
                            );
                            self.inner.stale_counter.inc();
                            return Ok(ItemLookup {
                                item,
                                expiration,
                                is_fresh: false,
                            });
                        }
                        tracing::debug!(
                            "{} semaphore acquire for {name:?} timed out after \
                                {timeout_duration:?}, returning error",
                            self.inner.name
                        );

                        self.inner.error_counter.inc();
                        return Err(Arc::new(anyhow::anyhow!(
                            "{} lookup for {name:?} \
                            timed out after {timeout_duration:?} \
                            on semaphore acquire while waiting for cache to populate",
                            self.inner.name
                        )));
                    }
                    Ok(r) => r,
                }
            };

            // While we slept, someone else may have satisfied
            // the lookup; check it
            let current_sema = match self.clone_item_state(name) {
                (ItemState::Present(item), expiration) => {
                    return Ok(ItemLookup {
                        item,
                        expiration,
                        is_fresh: false,
                    });
                }
                (ItemState::Failed(error), _) => {
                    self.inner.hit_counter.inc();
                    return Err(error);
                }
                (
                    ItemState::Refreshing {
                        stale_value: _,
                        pending,
                    },
                    _,
                ) => pending,
                (ItemState::Pending(current_sema), _) => current_sema,
            };

            // It's still outstanding
            match wait_result {
                Ok(permit) => {
                    // We're responsible for resolving it.
                    // We will always close the semaphore when
                    // we're done with this logic (and when we unwind
                    // or are cancelled) so that we can wake up any
                    // waiters.
                    // We use defer! for this so that if we are cancelled
                    // at the await point below, others are still woken up.
                    defer! {
                        permit.semaphore().close();
                    }

                    if !Arc::ptr_eq(&current_sema, permit.semaphore()) {
                        self.inner.error_counter.inc();
                        tracing::warn!(
                            "{} mismatched semaphores for {name:?}, \
                                    will restart cache resolve.",
                            self.inner.name
                        );
                        continue 'retry;
                    }

                    self.inner.populate_counter.inc();
                    let mut ttl = Duration::from_secs(60);
                    let future_result = fut.await;
                    let now = Instant::now();

                    let (item_result, return_value) = match future_result {
                        Ok(item) => {
                            ttl = ttl_func(&item);
                            (
                                ItemState::Present(item.clone()),
                                Ok(ItemLookup {
                                    item,
                                    expiration: now + ttl,
                                    is_fresh: true,
                                }),
                            )
                        }
                        Err(err) => {
                            self.inner.error_counter.inc();
                            let err = Arc::new(err.into());
                            (ItemState::Failed(err.clone()), Err(err))
                        }
                    };

                    self.inner.cache.insert(
                        name.clone(),
                        Item {
                            item: item_result,
                            expiration: Instant::now() + ttl,
                            last_tick: self.inc_tick().into(),
                        },
                    );
                    self.inner.maybe_evict();
                    return return_value;
                }
                Err(_) => {
                    self.inner.error_counter.inc();

                    // semaphore was closed, but the status is
                    // still somehow pending
                    tracing::debug!(
                        "{} lookup for {name:?} woke up semaphores \
                                but is still marked pending, \
                                will restart cache lookup",
                        self.inner.name
                    );
                    continue 'retry;
                }
            }
        }
    }
}

#[cfg(test)]
mod test {
    use super::*;
    use test_log::test; // run with RUST_LOG=lruttl=trace to see trace output

    #[test(tokio::test)]
    async fn test_capacity() {
        let cache = LruCacheWithTtl::new("test_capacity", 40);

        let expiration = Instant::now() + Duration::from_secs(60);
        for i in 0..100 {
            cache.insert(i, i, expiration).await;
        }

        assert_eq!(cache.inner.cache.len(), 40, "capacity is respected");
    }

    #[test(tokio::test)]
    async fn test_expiration() {
        let cache = LruCacheWithTtl::new("test_expiration", 1);

        tokio::time::pause();
        let expiration = Instant::now() + Duration::from_secs(1);
        cache.insert(0, 0, expiration).await;

        cache.get(&0).expect("still in cache");
        tokio::time::advance(Duration::from_secs(2)).await;
        assert!(cache.get(&0).is_none(), "evicted due to ttl");
    }

    #[test(tokio::test)]
    async fn test_over_capacity_slow_resolve() {
        let cache = Arc::new(LruCacheWithTtl::<String, u64>::new(
            "test_over_capacity_slow_resolve",
            1,
        ));

        let mut foos = vec![];
        for idx in 0..2 {
            let cache = cache.clone();
            foos.push(tokio::spawn(async move {
                eprintln!("spawned task {idx} is running");
                cache
                    .get_or_try_insert(&"foo".to_string(), |_| Duration::from_secs(86400), async {
                        if idx == 0 {
                            eprintln!("foo {idx} getter sleeping");
                            tokio::time::sleep(Duration::from_secs(300)).await;
                        }
                        eprintln!("foo {idx} getter done");
                        Ok::<_, anyhow::Error>(idx)
                    })
                    .await
            }));
        }

        tokio::task::yield_now().await;

        eprintln!("calling again with immediate getter");
        let result = cache
            .get_or_try_insert(&"bar".to_string(), |_| Duration::from_secs(60), async {
                eprintln!("bar immediate getter running");
                Ok::<_, anyhow::Error>(42)
            })
            .await
            .unwrap();

        assert_eq!(result.item, 42);
        assert_eq!(cache.inner.cache.len(), 1);

        eprintln!("aborting first one");
        foos.remove(0).abort();

        eprintln!("try new key");
        let result = cache
            .get_or_try_insert(&"baz".to_string(), |_| Duration::from_secs(60), async {
                eprintln!("baz immediate getter running");
                Ok::<_, anyhow::Error>(32)
            })
            .await
            .unwrap();
        assert_eq!(result.item, 32);
        assert_eq!(cache.inner.cache.len(), 1);

        eprintln!("waiting for second one");
        assert_eq!(1, foos.pop().unwrap().await.unwrap().unwrap().item);

        assert_eq!(cache.inner.cache.len(), 1);
    }
}