// lruttl/lib.rs

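//! A concurrent cache that combines per-entry TTLs with an
//! approximate, sampled LRU eviction policy, built on DashMap.
//!
//! Every cache registers itself in a global list so that a background
//! task can prune expired entries periodically, and so that all caches
//! can be purged when the system runs low on memory.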
use crate::metrics::*;
use dashmap::DashMap;
use kumo_server_memory::subscribe_to_memory_status_changes_async;
use parking_lot::Mutex;
use prometheus::{IntCounter, IntGauge};
use scopeguard::defer;
use std::borrow::Borrow;
use std::collections::HashMap;
use std::fmt::Debug;
use std::future::Future;
use std::hash::Hash;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::{Arc, LazyLock, Weak};
use tokio::sync::Semaphore;
use tokio::time::{timeout_at, Duration, Instant};
pub use {linkme, pastey as paste};

mod metrics;

static CACHES: LazyLock<Mutex<Vec<Weak<dyn CachePurger + Send + Sync>>>> =
    LazyLock::new(Mutex::default);

struct Inner<K: Clone + Hash + Eq + Debug, V: Clone + Send + Sync + Debug> {
    name: String,
    tick: AtomicUsize,
    capacity: AtomicUsize,
    allow_stale_reads: AtomicBool,
    cache: DashMap<K, Item<V>>,
    lru_samples: AtomicUsize,
    sema_timeout_milliseconds: AtomicUsize,
    lookup_counter: IntCounter,
    evict_counter: IntCounter,
    expire_counter: IntCounter,
    hit_counter: IntCounter,
    miss_counter: IntCounter,
    populate_counter: IntCounter,
    insert_counter: IntCounter,
    stale_counter: IntCounter,
    error_counter: IntCounter,
    wait_gauge: IntGauge,
    size_gauge: IntGauge,
}

impl<
        K: Clone + Debug + Send + Sync + Hash + Eq + 'static,
        V: Clone + Debug + Send + Sync + 'static,
    > Inner<K, V>
{
    pub fn clear(&self) -> usize {
        let num_entries = self.cache.len();

        // We don't simply clear all elements here, as any pending
        // items will have waiters coordinating on their semaphores; we
        // need to aggressively close those semaphores and wake everyone
        // up before we remove the entries.
        self.cache.retain(|_k, item| {
            if let ItemState::Pending(sema) = &item.item {
                // Force everyone to wake up and error out
                sema.close();
            }
            false
        });

        self.size_gauge.set(self.cache.len() as i64);
        num_entries
    }

    /// Evict up to target entries.
    ///
    /// We use a probabilistic approach to the LRU, because
    /// it is challenging to safely thread the classic doubly-linked-list
    /// through dashmap.
    ///
    /// target is bounded to half of the number of selected samples, in
    /// order to ensure that we don't randomly pick the newest element
    /// from the set when under pressure.
    ///
    /// Redis uses a similar technique for its LRU, as described
    /// in <https://redis.io/docs/latest/develop/reference/eviction/#apx-lru>,
    /// which suggests that sampling 10 keys at random and then comparing
    /// their recency yields a reasonably close approximation to the
    /// 100% precise LRU.
    ///
    /// Since we also support TTLs, we'll just go ahead and remove
    /// any expired keys that show up in the sampled set.
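    ///
    /// For example, with the default of 10 sampled keys (all of them
    /// resident and unexpired) and a requested target of 8, the
    /// effective target is clamped to 5, so at most the 5 least
    /// recently used keys from that sample are removed in this pass.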
    pub fn evict_some(&self, target: usize) -> usize {
        let now = Instant::now();

        // Approximate (since it could change immediately after reading)
        // cache size
        let cache_size = self.cache.len();
        // How many keys to sample
        let num_samples = self.lru_samples.load(Ordering::Relaxed).min(cache_size);

        // a list of keys which have expired
        let mut expired_keys = vec![];
        // a random selection of up to num_samples (key, tick) tuples
        let mut samples = vec![];

        // Pick some random keys.
        // The rand crate has some helpers for working with iterators,
        // but they appear to copy many elements into an internal buffer
        // in order to make a selection, and we want to avoid directly
        // considering every possible element because some users have
        // very large capacity caches.
        //
        // The approach taken here is to produce a random list of iterator
        // offsets so that we can skim across the iterator in a single
        // pass and pull out a random selection of elements.
        // The sample function provides a randomized list of indices that
        // we can use for this; we need to sort it first, but the cost
        // should be reasonably low as num_samples should be ~10 or so
        // in the most common configuration.
        {
            let mut rng = rand::thread_rng();
            let mut indices =
                rand::seq::index::sample(&mut rng, cache_size, num_samples).into_vec();
            indices.sort();
            let mut iter = self.cache.iter();
            let mut current_idx = 0;

            /// Advance an iterator by skip_amount.
            /// Ideally we'd use Iterator::advance_by for this, but at the
            /// time of writing that method is nightly only.
            /// Note that it also uses next() internally anyway
            fn advance_by(iter: &mut impl Iterator, skip_amount: usize) {
                for _ in 0..skip_amount {
                    if iter.next().is_none() {
                        return;
                    }
                }
            }

            for idx in indices {
                // idx is the index we want to be on; we'll need to skip ahead
                // by some number of slots based on the current one. skip_amount
                // is that number.
                let skip_amount = idx - current_idx;
                advance_by(&mut iter, skip_amount);

                match iter.next() {
                    Some(map_entry) => {
                        current_idx = idx + 1;
                        let item = map_entry.value();
                        match &item.item {
                            ItemState::Pending(_) | ItemState::Refreshing { .. } => {
                                // Cannot evict a pending lookup
                            }
                            ItemState::Present(_) | ItemState::Failed(_) => {
                                if now >= item.expiration {
                                    expired_keys.push(map_entry.key().clone());
                                } else {
                                    let last_tick = item.last_tick.load(Ordering::Relaxed);
                                    samples.push((map_entry.key().clone(), last_tick));
                                }
                            }
                        }
                    }
                    None => {
                        break;
                    }
                }
            }
        }

        let mut num_removed = 0;
        for key in expired_keys {
            // Sanity check that it is still expired before removing it,
            // because it would be a shame to remove it if another actor
            // has just updated it
            let removed = self
                .cache
                .remove_if(&key, |_k, entry| now >= entry.expiration)
                .is_some();
            if removed {
                tracing::trace!("{} expired {key:?}", self.name);
                num_removed += 1;
                self.expire_counter.inc();
            }
        }

        // Since we're picking random elements, we want to ensure that
        // we never pick the newest element from the set to evict because
        // that is likely the wrong choice. We need enough samples to
        // know that the lowest number we picked is representative
        // of the eldest element in the map overall.
        // We limit ourselves to half of the number of selected samples.
        let target = target.min(samples.len() / 2).max(1);

        // If we met our target, skip the extra work below
        if num_removed >= target {
            self.size_gauge.set(self.cache.len() as i64);
            tracing::trace!("{} expired {num_removed} of target {target}", self.name);
            return num_removed;
        }

        // Sort by ascending tick, which is equivalent to having the
        // LRU within that set towards the front of the vec
        samples.sort_by(|(_ka, tick_a), (_kb, tick_b)| tick_a.cmp(tick_b));

        for (key, tick) in samples {
            // Sanity check that the tick value is the same as we expect.
            // If it has changed since we sampled it, then that element
            // is no longer a good candidate for LRU eviction.
            if self
                .cache
                .remove_if(&key, |_k, item| {
                    item.last_tick.load(Ordering::Relaxed) == tick
                })
                .is_some()
            {
                tracing::debug!("{} evicted {key:?}", self.name);
                num_removed += 1;
                self.evict_counter.inc();
                self.size_gauge.set(self.cache.len() as i64);
                if num_removed >= target {
                    return num_removed;
                }
            }
        }

        if num_removed == 0 {
            tracing::debug!(
                "{} did not find anything to evict, target was {target}",
                self.name
            );
        }

        tracing::trace!("{} removed {num_removed} of target {target}", self.name);

        num_removed
    }

    /// Potentially make some progress to get back under
    /// budget on the cache capacity
    pub fn maybe_evict(&self) -> usize {
        let cache_size = self.cache.len();
        let capacity = self.capacity.load(Ordering::Relaxed);
        if cache_size > capacity {
            self.evict_some(cache_size - capacity)
        } else {
            0
        }
    }
}

trait CachePurger {
    fn name(&self) -> &str;
    fn purge(&self) -> usize;
    fn process_expirations(&self) -> usize;
    fn update_capacity(&self, capacity: usize);
}

impl<
        K: Clone + Debug + Send + Sync + Hash + Eq + 'static,
        V: Clone + Debug + Send + Sync + 'static,
    > CachePurger for Inner<K, V>
{
    fn name(&self) -> &str {
        &self.name
    }
    fn purge(&self) -> usize {
        self.clear()
    }
    fn process_expirations(&self) -> usize {
        let now = Instant::now();
        let mut expired_keys = vec![];
        for map_entry in self.cache.iter() {
            let item = map_entry.value();
            match &item.item {
                ItemState::Pending(_) | ItemState::Refreshing { .. } => {
                    // Cannot evict a pending lookup
                }
                ItemState::Present(_) | ItemState::Failed(_) => {
                    if now >= item.expiration {
                        expired_keys.push(map_entry.key().clone());
                    }
                }
            }
        }

        let mut num_removed = 0;
        for key in expired_keys {
            // Sanity check that it is still expired before removing it,
            // because it would be a shame to remove it if another actor
            // has just updated it
            let removed = self
                .cache
                .remove_if(&key, |_k, entry| now >= entry.expiration)
                .is_some();
            if removed {
                num_removed += 1;
                self.expire_counter.inc();
                self.size_gauge.set(self.cache.len() as i64);
            }
        }

        num_removed + self.maybe_evict()
    }

    fn update_capacity(&self, capacity: usize) {
        self.capacity.store(capacity, Ordering::Relaxed);
        // Bring it within capacity.
        // At the time of writing this is a bit half-hearted,
        // but we'll eventually trim down via ongoing process_expirations()
        // calls
        self.process_expirations();
    }
}

fn all_caches() -> Vec<Arc<dyn CachePurger + Send + Sync>> {
    let mut result = vec![];
    let mut caches = CACHES.lock();
    caches.retain(|entry| match entry.upgrade() {
        Some(purger) => {
            result.push(purger);
            true
        }
        None => false,
    });
    result
}

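/// Clear every cache registered in the global cache list, logging how
/// many entries were removed from each one. This is used to shed
/// memory when the system runs out of headroom.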
pub fn purge_all_caches() {
    let purgers = all_caches();

    tracing::error!("purging {} caches", purgers.len());
    for purger in purgers {
        let name = purger.name();
        let num_entries = purger.purge();
        tracing::error!("cleared {num_entries} entries from cache {name}");
    }
}

async fn prune_expired_caches() {
    loop {
        tokio::time::sleep(tokio::time::Duration::from_secs(30)).await;
        let purgers = all_caches();

        for p in purgers {
            let n = p.process_expirations();
            if n > 0 {
                tracing::debug!("expired {n} entries from cache {}", p.name());
            }
        }
    }
}

#[linkme::distributed_slice]
pub static LRUTTL_VIVIFY: [fn() -> CacheDefinition];

#[macro_export]
macro_rules! optional_doc {
    ($doc:expr) => {
        Some($doc.trim())
    };
    () => {
        None
    };
}

/// Declare a cache as a static, and link it into the list of possible
/// pre-defined caches.
///
/// Because the macro expansion references the `linkme` crate directly,
/// you must also add `linkme.workspace = true` to the manifest of the
/// crate where you use this macro.
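///
/// Example (a sketch; the cache name, key/value types and capacity
/// shown here are illustrative):
///
/// ```ignore
/// lruttl::declare_cache! {
///     /// Maps a string key to a numeric value.
///     static MY_CACHE: LruCacheWithTtl<String, u64>::new("my_cache", 1024);
/// }
/// ```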
#[macro_export]
macro_rules! declare_cache {
    (
     $(#[doc = $doc:expr])*
     $vis:vis
        static $sym:ident:
        LruCacheWithTtl<$key:ty, $value:ty>::new($name:expr, $capacity:expr);
    ) => {
        $(#[doc = $doc])*
        $vis static $sym: ::std::sync::LazyLock<$crate::LruCacheWithTtl<$key, $value>> =
            ::std::sync::LazyLock::new(
                || $crate::LruCacheWithTtl::new($name, $capacity));

        // Link into LRUTTL_VIVIFY
        $crate::paste::paste! {
            #[linkme::distributed_slice($crate::LRUTTL_VIVIFY)]
            static [<VIVIFY_ $sym>]: fn() -> $crate::CacheDefinition = || {
                ::std::sync::LazyLock::force(&$sym);
                $crate::CacheDefinition {
                    name: $name,
                    capacity: $capacity,
                    doc: $crate::optional_doc!($($doc)*),
                }
            };
        }
    };
}

/// Ensure that all caches declared via declare_cache!
/// have been instantiated.
fn vivify() {
    LazyLock::force(&PREDEFINED_CACHES);
}

fn vivify_impl() -> HashMap<&'static str, CacheDefinition> {
    let mut map = HashMap::new();

    for vivify_func in LRUTTL_VIVIFY {
        let definition = vivify_func();
        assert!(
            !map.contains_key(definition.name),
            "duplicate cache name {}",
            definition.name
        );
        map.insert(definition.name, definition);
    }

    map
}

#[derive(serde::Serialize)]
pub struct CacheDefinition {
    pub name: &'static str,
    pub capacity: usize,
    pub doc: Option<&'static str>,
}

static PREDEFINED_CACHES: LazyLock<HashMap<&'static str, CacheDefinition>> =
    LazyLock::new(vivify_impl);

pub fn get_definitions() -> Vec<&'static CacheDefinition> {
    let mut defs = PREDEFINED_CACHES.values().collect::<Vec<_>>();
    defs.sort_by(|a, b| a.name.cmp(&b.name));
    defs
}

pub fn is_name_available(name: &str) -> bool {
    !PREDEFINED_CACHES.contains_key(name)
}

/// Update the capacity value for a pre-defined cache
pub fn set_cache_capacity(name: &str, capacity: usize) -> bool {
    if !PREDEFINED_CACHES.contains_key(name) {
        return false;
    }
    let caches = all_caches();
    match caches.iter().find(|p| p.name() == name) {
        Some(p) => {
            p.update_capacity(capacity);
            true
        }
        None => false,
    }
}

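/// Force all pre-defined caches into existence, then spawn the
/// background tasks that purge every cache when memory headroom is
/// exhausted and that prune expired entries every 30 seconds.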
pub fn spawn_memory_monitor() {
    vivify();
    tokio::spawn(purge_caches_on_memory_shortage());
    tokio::spawn(prune_expired_caches());
}

async fn purge_caches_on_memory_shortage() {
    tracing::debug!("starting memory monitor");
    let mut memory_status = subscribe_to_memory_status_changes_async().await;
    while let Ok(()) = memory_status.changed().await {
        if kumo_server_memory::get_headroom() == 0 {
            purge_all_caches();

            // Wait a little bit so that we can debounce
            // in the case where we're riding the cusp of
            // the limit and would thrash the caches
            tokio::time::sleep(tokio::time::Duration::from_secs(30)).await;
        }
    }
}

#[derive(Debug, Clone)]
enum ItemState<V>
where
    V: Send,
    V: Sync,
{
    Present(V),
    Pending(Arc<Semaphore>),
    Failed(Arc<anyhow::Error>),
    Refreshing {
        stale_value: V,
        pending: Arc<Semaphore>,
    },
}

#[derive(Debug)]
struct Item<V>
where
    V: Send,
    V: Sync,
{
    item: ItemState<V>,
    expiration: Instant,
    last_tick: AtomicUsize,
}

impl<V: Clone + Send + Sync> Clone for Item<V> {
    fn clone(&self) -> Self {
        Self {
            item: self.item.clone(),
            expiration: self.expiration,
            last_tick: self.last_tick.load(Ordering::Relaxed).into(),
        }
    }
}

#[derive(Debug)]
pub struct ItemLookup<V: Debug> {
    /// A copy of the item
    pub item: V,
    /// If true, the get_or_try_insert operation populated the entry;
    /// the operation was a cache miss
    pub is_fresh: bool,
    /// The instant at which this entry will expire
    pub expiration: Instant,
}

pub struct LruCacheWithTtl<K: Clone + Debug + Hash + Eq, V: Clone + Debug + Send + Sync> {
    inner: Arc<Inner<K, V>>,
}

impl<
        K: Clone + Debug + Hash + Eq + Send + Sync + 'static,
        V: Clone + Debug + Send + Sync + 'static,
    > LruCacheWithTtl<K, V>
{
    pub fn new<S: Into<String>>(name: S, capacity: usize) -> Self {
        let name = name.into();
        let cache = DashMap::new();

        let lookup_counter = CACHE_LOOKUP
            .get_metric_with_label_values(&[&name])
            .expect("failed to get counter");
        let hit_counter = CACHE_HIT
            .get_metric_with_label_values(&[&name])
            .expect("failed to get counter");
        let stale_counter = CACHE_STALE
            .get_metric_with_label_values(&[&name])
            .expect("failed to get counter");
        let evict_counter = CACHE_EVICT
            .get_metric_with_label_values(&[&name])
            .expect("failed to get counter");
        let expire_counter = CACHE_EXPIRE
            .get_metric_with_label_values(&[&name])
            .expect("failed to get counter");
        let miss_counter = CACHE_MISS
            .get_metric_with_label_values(&[&name])
            .expect("failed to get counter");
        let populate_counter = CACHE_POPULATED
            .get_metric_with_label_values(&[&name])
            .expect("failed to get counter");
        let insert_counter = CACHE_INSERT
            .get_metric_with_label_values(&[&name])
            .expect("failed to get counter");
        let error_counter = CACHE_ERROR
            .get_metric_with_label_values(&[&name])
            .expect("failed to get counter");
        let wait_gauge = CACHE_WAIT
            .get_metric_with_label_values(&[&name])
            .expect("failed to get gauge");
        let size_gauge = CACHE_SIZE
            .get_metric_with_label_values(&[&name])
            .expect("failed to get gauge");

        let inner = Arc::new(Inner {
            name,
            cache,
            tick: AtomicUsize::new(0),
            allow_stale_reads: AtomicBool::new(false),
            capacity: AtomicUsize::new(capacity),
            lru_samples: AtomicUsize::new(10),
            sema_timeout_milliseconds: AtomicUsize::new(120_000),
            lookup_counter,
            evict_counter,
            expire_counter,
            hit_counter,
            miss_counter,
            populate_counter,
            error_counter,
            wait_gauge,
            insert_counter,
            stale_counter,
            size_gauge,
        });

        // Register with the global list of caches using a weak reference.
        // We need to "erase" the K/V types in order to do that, so we
        // use the CachePurger trait for this purpose.
        {
            let generic: Arc<dyn CachePurger + Send + Sync> = inner.clone();
            CACHES.lock().push(Arc::downgrade(&generic));
            tracing::debug!(
                "registered cache {} with capacity {capacity}",
                generic.name()
            );
        }

        Self { inner }
    }

    fn allow_stale_reads(&self) -> bool {
        self.inner.allow_stale_reads.load(Ordering::Relaxed)
    }

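    /// When stale reads are enabled, an expired entry is kept in the
    /// map while a refresh is in flight, and get_or_try_insert will
    /// fall back to the stale value rather than returning an error if
    /// its wait for that refresh times out.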
    pub fn set_allow_stale_reads(&self, value: bool) {
        self.inner.allow_stale_reads.store(value, Ordering::Relaxed);
    }

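    /// Set how long get_or_try_insert will wait for another caller to
    /// finish populating a pending entry before giving up.
    /// The default is 2 minutes.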
    pub fn set_sema_timeout(&self, duration: Duration) {
        self.inner
            .sema_timeout_milliseconds
            .store(duration.as_millis() as usize, Ordering::Relaxed);
    }

    pub fn clear(&self) -> usize {
        self.inner.clear()
    }

    fn inc_tick(&self) -> usize {
        self.inner.tick.fetch_add(1, Ordering::Relaxed) + 1
    }

    fn update_tick(&self, item: &Item<V>) {
        let v = self.inc_tick();
        item.last_tick.store(v, Ordering::Relaxed);
    }

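    /// Look up name in the cache, returning a copy of the value along
    /// with its expiration time.
    /// Returns None if the entry is missing, still being populated,
    /// recorded as a failure, or has expired.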
    pub fn lookup<Q: ?Sized>(&self, name: &Q) -> Option<ItemLookup<V>>
    where
        K: Borrow<Q>,
        Q: Hash + Eq,
    {
        self.inner.lookup_counter.inc();
        match self.inner.cache.get_mut(name) {
            None => {
                self.inner.miss_counter.inc();
                None
            }
            Some(entry) => {
                match &entry.item {
                    ItemState::Present(item) => {
                        let now = Instant::now();
                        if now >= entry.expiration {
                            // Expired
                            if self.allow_stale_reads() {
                                // We don't furnish a result directly, but we
                                // also do not want to remove it from the map
                                // at this stage.
                                // We're assuming that lookup() is called only
                                // via get_or_try_insert when allow_stale_reads
                                // is enabled.
                                self.inner.miss_counter.inc();
                                return None;
                            }

                            // otherwise: remove it from the map.
                            // Take care to drop our ref first so that we don't
                            // self-deadlock
                            drop(entry);
                            if self
                                .inner
                                .cache
                                .remove_if(name, |_k, entry| now >= entry.expiration)
                                .is_some()
                            {
                                self.inner.expire_counter.inc();
                                self.inner.size_gauge.set(self.inner.cache.len() as i64);
                            }
                            self.inner.miss_counter.inc();
                            return None;
                        }
                        self.inner.hit_counter.inc();
                        self.update_tick(&entry);
                        Some(ItemLookup {
                            item: item.clone(),
                            expiration: entry.expiration,
                            is_fresh: false,
                        })
                    }
                    ItemState::Refreshing { .. } | ItemState::Pending(_) | ItemState::Failed(_) => {
                        self.inner.miss_counter.inc();
                        None
                    }
                }
            }
        }
    }

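    /// Like lookup, but returns just a copy of the cached value.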
    pub fn get<Q: ?Sized>(&self, name: &Q) -> Option<V>
    where
        K: Borrow<Q>,
        Q: Hash + Eq,
    {
        self.lookup(name).map(|lookup| lookup.item)
    }

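    /// Unconditionally insert a value for name with the given
    /// expiration, replacing any existing entry, and return a copy of
    /// the inserted value. May trigger an eviction pass if the cache
    /// is over capacity.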
    pub async fn insert(&self, name: K, item: V, expiration: Instant) -> V {
        self.inner.cache.insert(
            name,
            Item {
                item: ItemState::Present(item.clone()),
                expiration,
                last_tick: self.inc_tick().into(),
            },
        );

        self.inner.insert_counter.inc();
        self.inner.size_gauge.set(self.inner.cache.len() as i64);
        self.inner.maybe_evict();

        item
    }

    fn clone_item_state(&self, name: &K) -> (ItemState<V>, Instant) {
        let mut is_new = false;
        let mut entry = self.inner.cache.entry(name.clone()).or_insert_with(|| {
            is_new = true;
            Item {
                item: ItemState::Pending(Arc::new(Semaphore::new(1))),
                expiration: Instant::now() + Duration::from_secs(60),
                last_tick: self.inc_tick().into(),
            }
        });

        match &entry.value().item {
            ItemState::Pending(sema) => {
                if sema.is_closed() {
                    entry.value_mut().item = ItemState::Pending(Arc::new(Semaphore::new(1)));
                }
            }
            ItemState::Refreshing {
                stale_value,
                pending,
            } => {
                if pending.is_closed() {
                    entry.value_mut().item = ItemState::Refreshing {
                        stale_value: stale_value.clone(),
                        pending: Arc::new(Semaphore::new(1)),
                    };
                }
            }
            ItemState::Present(item) => {
                let now = Instant::now();
                if now >= entry.expiration {
                    // Expired; we will need to fetch it
                    let pending = Arc::new(Semaphore::new(1));
                    if self.allow_stale_reads() {
                        entry.value_mut().item = ItemState::Refreshing {
                            stale_value: item.clone(),
                            pending,
                        };
                    } else {
                        entry.value_mut().item = ItemState::Pending(pending);
                    }
                }
            }
            ItemState::Failed(_) => {
                let now = Instant::now();
                if now >= entry.expiration {
                    // Expired; we will need to fetch it
                    entry.value_mut().item = ItemState::Pending(Arc::new(Semaphore::new(1)));
                }
            }
        }

        self.update_tick(&entry);
        let item = entry.value();
        let result = (item.item.clone(), entry.expiration);
        drop(entry);

        if is_new {
            self.inner.size_gauge.set(self.inner.cache.len() as i64);
            self.inner.maybe_evict();
        }

        result
    }

    /// Get an existing item, but if that item doesn't already exist,
    /// execute the future `fut` to provide a value that will be inserted and then
    /// returned.  This is done atomically wrt. other callers.
    /// The TTL parameter is a function that can extract the TTL from the value type,
    /// or just return a constant TTL.
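    ///
    /// Example (a sketch; `CACHE`, `key` and the fallible
    /// `fetch_value` lookup are illustrative, not part of this crate):
    ///
    /// ```ignore
    /// let lookup = CACHE
    ///     .get_or_try_insert(
    ///         &key,
    ///         |_v| Duration::from_secs(300), // constant TTL for every value
    ///         async { fetch_value(&key).await },
    ///     )
    ///     .await?;
    /// if lookup.is_fresh {
    ///     // this call populated the entry (i.e. it was a cache miss)
    /// }
    /// ```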
    pub async fn get_or_try_insert<E: Into<anyhow::Error>, TTL: FnOnce(&V) -> Duration>(
        &self,
        name: &K,
        ttl_func: TTL,
        fut: impl Future<Output = Result<V, E>>,
    ) -> Result<ItemLookup<V>, Arc<anyhow::Error>> {
        // Fast path avoids cloning the key
        if let Some(entry) = self.lookup(name) {
            return Ok(entry);
        }

        let timeout_duration = Duration::from_millis(
            self.inner.sema_timeout_milliseconds.load(Ordering::Relaxed) as u64,
        );
        let start = Instant::now();
        let deadline = start + timeout_duration;

        // Note: the lookup call increments lookup_counter and miss_counter
        'retry: loop {
            let (stale_value, sema) = match self.clone_item_state(name) {
                (ItemState::Present(item), expiration) => {
                    return Ok(ItemLookup {
                        item,
                        expiration,
                        is_fresh: false,
                    });
                }
                (ItemState::Failed(error), _) => {
                    return Err(error);
                }
                (
                    ItemState::Refreshing {
                        stale_value,
                        pending,
                    },
                    expiration,
                ) => (Some((stale_value, expiration)), pending),
                (ItemState::Pending(sema), _) => (None, sema),
            };

            let wait_result = {
                self.inner.wait_gauge.inc();
                defer! {
                    self.inner.wait_gauge.dec();
                }

                match timeout_at(deadline, sema.acquire_owned()).await {
                    Err(_) => {
                        if let Some((item, expiration)) = stale_value {
                            tracing::debug!(
                                "{} semaphore acquire for {name:?} timed out after \
                                {timeout_duration:?}, allowing stale value to satisfy the lookup",
                                self.inner.name
                            );
                            self.inner.stale_counter.inc();
                            return Ok(ItemLookup {
                                item,
                                expiration,
                                is_fresh: false,
                            });
                        }
                        tracing::debug!(
                            "{} semaphore acquire for {name:?} timed out after \
                                {timeout_duration:?}, returning error",
                            self.inner.name
                        );

                        self.inner.error_counter.inc();
                        return Err(Arc::new(anyhow::anyhow!(
                            "{} lookup for {name:?} \
                            timed out after {timeout_duration:?} \
                            on semaphore acquire while waiting for cache to populate",
                            self.inner.name
                        )));
                    }
                    Ok(r) => r,
                }
            };

            // While we slept, someone else may have satisfied
            // the lookup; check it
            let current_sema = match self.clone_item_state(name) {
                (ItemState::Present(item), expiration) => {
                    return Ok(ItemLookup {
                        item,
                        expiration,
                        is_fresh: false,
                    });
                }
                (ItemState::Failed(error), _) => {
                    self.inner.hit_counter.inc();
                    return Err(error);
                }
                (
                    ItemState::Refreshing {
                        stale_value: _,
                        pending,
                    },
                    _,
                ) => pending,
                (ItemState::Pending(current_sema), _) => current_sema,
            };

            // It's still outstanding
            match wait_result {
                Ok(permit) => {
                    // We're responsible for resolving it.
                    // We will always close the semaphore when
                    // we're done with this logic (and when we unwind
                    // or are cancelled) so that we can wake up any
                    // waiters.
                    // We use defer! for this so that if we are cancelled
                    // at the await point below, others are still woken up.
                    defer! {
                        permit.semaphore().close();
                    }

                    if !Arc::ptr_eq(&current_sema, permit.semaphore()) {
                        self.inner.error_counter.inc();
                        tracing::warn!(
                            "{} mismatched semaphores for {name:?}, \
                                    will restart cache resolve.",
                            self.inner.name
                        );
                        continue 'retry;
                    }

                    self.inner.populate_counter.inc();
                    let mut ttl = Duration::from_secs(60);
                    let future_result = fut.await;
                    let now = Instant::now();

                    let (item_result, return_value) = match future_result {
                        Ok(item) => {
                            ttl = ttl_func(&item);
                            (
                                ItemState::Present(item.clone()),
                                Ok(ItemLookup {
                                    item,
                                    expiration: now + ttl,
                                    is_fresh: true,
                                }),
                            )
                        }
                        Err(err) => {
                            self.inner.error_counter.inc();
                            let err = Arc::new(err.into());
                            (ItemState::Failed(err.clone()), Err(err))
                        }
                    };

                    self.inner.cache.insert(
                        name.clone(),
                        Item {
                            item: item_result,
                            expiration: Instant::now() + ttl,
                            last_tick: self.inc_tick().into(),
                        },
                    );
                    self.inner.maybe_evict();
                    return return_value;
                }
                Err(_) => {
                    self.inner.error_counter.inc();

                    // semaphore was closed, but the status is
                    // still somehow pending
                    tracing::debug!(
                        "{} lookup for {name:?} woke up semaphores \
                                but is still marked pending, \
                                will restart cache lookup",
                        self.inner.name
                    );
                    continue 'retry;
                }
            }
        }
    }
}

#[cfg(test)]
mod test {
    use super::*;
    use test_log::test; // run with RUST_LOG=lruttl=trace to trace

    #[test(tokio::test)]
    async fn test_capacity() {
        let cache = LruCacheWithTtl::new("test_capacity", 40);

        let expiration = Instant::now() + Duration::from_secs(60);
        for i in 0..100 {
            cache.insert(i, i, expiration).await;
        }

        assert_eq!(cache.inner.cache.len(), 40, "capacity is respected");
    }

    #[test(tokio::test)]
    async fn test_expiration() {
        let cache = LruCacheWithTtl::new("test_expiration", 1);

        tokio::time::pause();
        let expiration = Instant::now() + Duration::from_secs(1);
        cache.insert(0, 0, expiration).await;

        cache.get(&0).expect("still in cache");
        tokio::time::advance(Duration::from_secs(2)).await;
        assert!(cache.get(&0).is_none(), "evicted due to ttl");
    }

    #[test(tokio::test)]
    async fn test_over_capacity_slow_resolve() {
        let cache = Arc::new(LruCacheWithTtl::<String, u64>::new(
            "test_over_capacity_slow_resolve",
            1,
        ));

        let mut foos = vec![];
        for idx in 0..2 {
            let cache = cache.clone();
            foos.push(tokio::spawn(async move {
                eprintln!("spawned task {idx} is running");
                cache
                    .get_or_try_insert(&"foo".to_string(), |_| Duration::from_secs(86400), async {
                        if idx == 0 {
                            eprintln!("foo {idx} getter sleeping");
                            tokio::time::sleep(Duration::from_secs(300)).await;
                        }
                        eprintln!("foo {idx} getter done");
                        Ok::<_, anyhow::Error>(idx)
                    })
                    .await
            }));
        }

        tokio::task::yield_now().await;

        eprintln!("calling again with immediate getter");
        let result = cache
            .get_or_try_insert(&"bar".to_string(), |_| Duration::from_secs(60), async {
                eprintln!("bar immediate getter running");
                Ok::<_, anyhow::Error>(42)
            })
            .await
            .unwrap();

        assert_eq!(result.item, 42);
        assert_eq!(cache.inner.cache.len(), 1);

        eprintln!("aborting first one");
        foos.remove(0).abort();

        eprintln!("try new key");
        let result = cache
            .get_or_try_insert(&"baz".to_string(), |_| Duration::from_secs(60), async {
                eprintln!("baz immediate getter running");
                Ok::<_, anyhow::Error>(32)
            })
            .await
            .unwrap();
        assert_eq!(result.item, 32);
        assert_eq!(cache.inner.cache.len(), 1);

        eprintln!("waiting second one");
        assert_eq!(1, foos.pop().unwrap().await.unwrap().unwrap().item);

        assert_eq!(cache.inner.cache.len(), 1);
    }
}