lruttl/
lib.rs

1use crate::metrics::*;
2use dashmap::DashMap;
3use kumo_prometheus::prometheus::{IntCounter, IntGauge};
4use kumo_server_memory::subscribe_to_memory_status_changes_async;
5pub use linkme;
6use parking_lot::Mutex;
7pub use pastey as paste;
8use scopeguard::defer;
9use std::borrow::Borrow;
10use std::collections::HashMap;
11use std::fmt::Debug;
12use std::future::Future;
13use std::hash::Hash;
14use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
15use std::sync::{Arc, LazyLock, Weak};
16use tokio::sync::Semaphore;
17use tokio::time::{timeout_at, Duration, Instant};
18
19mod metrics;
20
/// Global registry of every live cache instance, held via weak
/// references so that dropping a cache does not keep it alive here;
/// dead entries are pruned lazily by all_caches().
static CACHES: LazyLock<Mutex<Vec<Weak<dyn CachePurger + Send + Sync>>>> =
    LazyLock::new(Mutex::default);
23
/// Shared state behind an LruCacheWithTtl handle.
/// Held in an Arc so that the global CACHES registry can keep a
/// type-erased Weak reference to it (see CachePurger).
struct Inner<K: Clone + Hash + Eq + Debug, V: Clone + Send + Sync + Debug> {
    /// Cache name, used as the metrics label and in log messages
    name: String,
    /// Monotonically increasing access clock; entries are stamped
    /// with the current tick on access to approximate recency
    tick: AtomicUsize,
    /// Soft upper bound on the number of entries
    capacity: AtomicUsize,
    /// When true, expired entries may be served while a refresh
    /// is in progress (see ItemState::Refreshing)
    allow_stale_reads: AtomicBool,
    /// The underlying concurrent map of entries
    cache: DashMap<K, Item<V>>,
    /// Number of random keys sampled per eviction pass (see evict_some)
    lru_samples: AtomicUsize,
    /// How long get_or_try_insert waits on a pending population
    /// before giving up, in milliseconds
    sema_timeout_milliseconds: AtomicUsize,
    lookup_counter: IntCounter,
    evict_counter: IntCounter,
    expire_counter: IntCounter,
    hit_counter: IntCounter,
    miss_counter: IntCounter,
    populate_counter: IntCounter,
    insert_counter: IntCounter,
    stale_counter: IntCounter,
    error_counter: IntCounter,
    /// Number of tasks currently waiting for a pending population
    wait_gauge: IntGauge,
    /// Current number of entries in the cache
    size_gauge: IntGauge,
}
44
45impl<
46        K: Clone + Debug + Send + Sync + Hash + Eq + 'static,
47        V: Clone + Debug + Send + Sync + 'static,
48    > Inner<K, V>
49{
50    pub fn clear(&self) -> usize {
51        let num_entries = self.cache.len();
52
53        // We don't simply clear all elements here, as any pending
54        // items will be trying to wait to coordinate; we need
55        // to aggressively close the semaphore and wake them all up
56        // before we remove those entries.
57        self.cache.retain(|_k, item| {
58            if let ItemState::Pending(sema) = &item.item {
59                // Force everyone to wakeup and error out
60                sema.close();
61            }
62            false
63        });
64
65        self.size_gauge.set(self.cache.len() as i64);
66        num_entries
67    }
68
69    /// Evict up to target entries.
70    ///
71    /// We use a probablistic approach to the LRU, because
72    /// it is challenging to safely thread the classic doubly-linked-list
73    /// through dashmap.
74    ///
75    /// target is bounded to half of number of selected samples, in
76    /// order to ensure that we don't randomly pick the newest element
77    /// from the set when under pressure.
78    ///
79    /// Redis uses a similar technique for its LRU as described
80    /// in <https://redis.io/docs/latest/develop/reference/eviction/#apx-lru>
81    /// which suggests that sampling 10 keys at random to them compare
82    /// their recency yields a reasonably close approximation to the
83    /// 100% precise LRU.
84    ///
85    /// Since we also support TTLs, we'll just go ahead and remove
86    /// any expired keys that show up in the sampled set.
87    pub fn evict_some(&self, target: usize) -> usize {
88        let now = Instant::now();
89
90        // Approximate (since it could change immediately after reading)
91        // cache size
92        let cache_size = self.cache.len();
93        // How many keys to sample
94        let num_samples = self.lru_samples.load(Ordering::Relaxed).min(cache_size);
95
96        // a list of keys which have expired
97        let mut expired_keys = vec![];
98        // a random selection of up to num_samples (key, tick) tuples
99        let mut samples = vec![];
100
101        // Pick some random keys.
102        // The rand crate has some helpers for working with iterators,
103        // but they appear to copy many elements into an internal buffer
104        // in order to make a selection, and we want to avoid directly
105        // considering every possible element because some users have
106        // very large capacity caches.
107        //
108        // The approach taken here is to produce a random list of iterator
109        // offsets so that we can skim across the iterator in a single
110        // pass and pull out a random selection of elements.
111        // The sample function provides a randomized list of indices that
112        // we can use for this; we need to sort it first, but the cost
113        // should be reasonably low as num_samples should be ~10 or so
114        // in the most common configuration.
115        {
116            let mut rng = rand::thread_rng();
117            let mut indices =
118                rand::seq::index::sample(&mut rng, cache_size, num_samples).into_vec();
119            indices.sort();
120            let mut iter = self.cache.iter();
121            let mut current_idx = 0;
122
123            /// Advance an iterator by skip_amount.
124            /// Ideally we'd use Iterator::advance_by for this, but at the
125            /// time of writing that method is nightly only.
126            /// Note that it also uses next() internally anyway
127            fn advance_by(iter: &mut impl Iterator, skip_amount: usize) {
128                for _ in 0..skip_amount {
129                    if iter.next().is_none() {
130                        return;
131                    }
132                }
133            }
134
135            for idx in indices {
136                // idx is the index we want to be on; we'll need to skip ahead
137                // by some number of slots based on the current one. skip_amount
138                // is that number.
139                let skip_amount = idx - current_idx;
140                advance_by(&mut iter, skip_amount);
141
142                match iter.next() {
143                    Some(map_entry) => {
144                        current_idx = idx + 1;
145                        let item = map_entry.value();
146                        match &item.item {
147                            ItemState::Pending(_) | ItemState::Refreshing { .. } => {
148                                // Cannot evict a pending lookup
149                            }
150                            ItemState::Present(_) | ItemState::Failed(_) => {
151                                if now >= item.expiration {
152                                    expired_keys.push(map_entry.key().clone());
153                                } else {
154                                    let last_tick = item.last_tick.load(Ordering::Relaxed);
155                                    samples.push((map_entry.key().clone(), last_tick));
156                                }
157                            }
158                        }
159                    }
160                    None => {
161                        break;
162                    }
163                }
164            }
165        }
166
167        let mut num_removed = 0;
168        for key in expired_keys {
169            // Sanity check that it is still expired before removing it,
170            // because it would be a shame to remove it if another actor
171            // has just updated it
172            let removed = self
173                .cache
174                .remove_if(&key, |_k, entry| now >= entry.expiration)
175                .is_some();
176            if removed {
177                tracing::trace!("{} expired {key:?}", self.name);
178                num_removed += 1;
179                self.expire_counter.inc();
180            }
181        }
182
183        // Since we're picking random elements, we want to ensure that
184        // we never pick the newest element from the set to evict because
185        // that is likely the wrong choice. We need enough samples to
186        // know that the lowest number we picked is representative
187        // of the eldest element in the map overall.
188        // We limit ourselves to half of the number of selected samples.
189        let target = target.min(samples.len() / 2).max(1);
190
191        // If we met our target, skip the extra work below
192        if num_removed >= target {
193            self.size_gauge.set(self.cache.len() as i64);
194            tracing::trace!("{} expired {num_removed} of target {target}", self.name);
195            return num_removed;
196        }
197
198        // Sort by ascending tick, which is equivalent to having the
199        // LRU within that set towards the front of the vec
200        samples.sort_by(|(_ka, tick_a), (_kb, tick_b)| tick_a.cmp(tick_b));
201
202        for (key, tick) in samples {
203            // Sanity check that the tick value is the same as we expect.
204            // If it has changed since we sampled it, then that element
205            // is no longer a good candidate for LRU eviction.
206            if self
207                .cache
208                .remove_if(&key, |_k, item| {
209                    item.last_tick.load(Ordering::Relaxed) == tick
210                })
211                .is_some()
212            {
213                tracing::debug!("{} evicted {key:?}", self.name);
214                num_removed += 1;
215                self.evict_counter.inc();
216                self.size_gauge.set(self.cache.len() as i64);
217                if num_removed >= target {
218                    return num_removed;
219                }
220            }
221        }
222
223        if num_removed == 0 {
224            tracing::debug!(
225                "{} did not find anything to evict, target was {target}",
226                self.name
227            );
228        }
229
230        tracing::trace!("{} removed {num_removed} of target {target}", self.name);
231
232        num_removed
233    }
234
235    /// Potentially make some progress to get back under
236    /// budget on the cache capacity
237    pub fn maybe_evict(&self) -> usize {
238        let cache_size = self.cache.len();
239        let capacity = self.capacity.load(Ordering::Relaxed);
240        if cache_size > capacity {
241            self.evict_some(cache_size - capacity)
242        } else {
243            0
244        }
245    }
246}
247
/// Type-erased view of a cache, so the global CACHES registry
/// and the background maintenance tasks can operate on caches
/// without knowing their concrete key/value types.
trait CachePurger {
    /// The name the cache was registered with
    fn name(&self) -> &str;
    /// Remove all entries, returning the number removed
    fn purge(&self) -> usize;
    /// Remove expired entries and evict towards capacity,
    /// returning the number of entries removed
    fn process_expirations(&self) -> usize;
    /// Update the capacity limit and trim towards the new value
    fn update_capacity(&self, capacity: usize);
}
254
255impl<
256        K: Clone + Debug + Send + Sync + Hash + Eq + 'static,
257        V: Clone + Debug + Send + Sync + 'static,
258    > CachePurger for Inner<K, V>
259{
260    fn name(&self) -> &str {
261        &self.name
262    }
263    fn purge(&self) -> usize {
264        self.clear()
265    }
266    fn process_expirations(&self) -> usize {
267        let now = Instant::now();
268        let mut expired_keys = vec![];
269        for map_entry in self.cache.iter() {
270            let item = map_entry.value();
271            match &item.item {
272                ItemState::Pending(_) | ItemState::Refreshing { .. } => {
273                    // Cannot evict a pending lookup
274                }
275                ItemState::Present(_) | ItemState::Failed(_) => {
276                    if now >= item.expiration {
277                        expired_keys.push(map_entry.key().clone());
278                    }
279                }
280            }
281        }
282
283        let mut num_removed = 0;
284        for key in expired_keys {
285            // Sanity check that it is still expired before removing it,
286            // because it would be a shame to remove it if another actor
287            // has just updated it
288            let removed = self
289                .cache
290                .remove_if(&key, |_k, entry| now >= entry.expiration)
291                .is_some();
292            if removed {
293                num_removed += 1;
294                self.expire_counter.inc();
295                self.size_gauge.set(self.cache.len() as i64);
296            }
297        }
298
299        num_removed + self.maybe_evict()
300    }
301
302    fn update_capacity(&self, capacity: usize) {
303        self.capacity.store(capacity, Ordering::Relaxed);
304        // Bring it within capacity.
305        // At the time of writing this is a bit half-hearted,
306        // but we'll eventually trim down via ongoing process_expirations()
307        // calls
308        self.process_expirations();
309    }
310}
311
312fn all_caches() -> Vec<Arc<dyn CachePurger + Send + Sync>> {
313    let mut result = vec![];
314    let mut caches = CACHES.lock();
315    caches.retain(|entry| match entry.upgrade() {
316        Some(purger) => {
317            result.push(purger);
318            true
319        }
320        None => false,
321    });
322    result
323}
324
325pub fn purge_all_caches() {
326    let purgers = all_caches();
327
328    tracing::error!("purging {} caches", purgers.len());
329    for purger in purgers {
330        let name = purger.name();
331        let num_entries = purger.purge();
332        tracing::error!("cleared {num_entries} entries from cache {name}");
333    }
334}
335
336async fn prune_expired_caches() {
337    loop {
338        tokio::time::sleep(tokio::time::Duration::from_secs(30)).await;
339        let purgers = all_caches();
340
341        for p in purgers {
342            let n = p.process_expirations();
343            if n > 0 {
344                tracing::debug!("expired {n} entries from cache {}", p.name());
345            }
346        }
347    }
348}
349
/// Distributed slice populated by declare_cache!; each element is a
/// function that forces the corresponding cache static into existence
/// and returns its CacheDefinition.
#[linkme::distributed_slice]
pub static LRUTTL_VIVIFY: [fn() -> CacheDefinition];
352
/// Helper for declare_cache!: collapses zero or more captured
/// `#[doc = "..."]` strings into an `Option<&'static str>`.
#[macro_export]
macro_rules! optional_doc {
    // Exactly one doc line: trim surrounding whitespace
    ($doc:expr) => {
        Some($doc.trim())
    };
    // Multiple doc lines: concatenate at compile time, then trim
    ($($doc:expr)+) => {
        Some(concat!($($doc,)+).trim())
    };
    // No doc comment was present on the declaration
    () => {
        None
    };
}
365
/// Declare a cache as a static, and link it into the list of possible
/// pre-defined caches.
///
/// Due to a limitation in implementation details, you must also add
/// `linkme.workspace = true` to the manifest of the crate where you
/// use this macro.
#[macro_export]
macro_rules! declare_cache {
    (
     $(#[doc = $doc:expr])*
     $vis:vis
        static $sym:ident:
        LruCacheWithTtl<$key:ty, $value:ty>::new($name:expr, $capacity:expr);
    ) => {
        // Emit the cache itself as a lazily-constructed static,
        // preserving the caller's doc comments and visibility
        $(#[doc = $doc])*
        $vis static $sym: ::std::sync::LazyLock<$crate::LruCacheWithTtl<$key, $value>> =
            ::std::sync::LazyLock::new(
                || $crate::LruCacheWithTtl::new($name, $capacity));

        // Link into LRUTTL_VIVIFY so that vivify() can force this
        // static into existence and report its definition
        $crate::paste::paste! {
            #[linkme::distributed_slice($crate::LRUTTL_VIVIFY)]
            static [<VIVIFY_ $sym>]: fn() -> $crate::CacheDefinition = || {
                ::std::sync::LazyLock::force(&$sym);
                $crate::CacheDefinition {
                    name: $name,
                    capacity: $capacity,
                    doc: $crate::optional_doc!($($doc)*),
                }
            };
        }
    };
}
399
/// Ensure that all caches declared via declare_cache!
/// have been instantiated.
/// The collected set of definitions is available via
/// PREDEFINED_CACHES / get_definitions().
fn vivify() {
    LazyLock::force(&PREDEFINED_CACHES);
}
405
406fn vivify_impl() -> HashMap<&'static str, CacheDefinition> {
407    let mut map = HashMap::new();
408
409    for vivify_func in LRUTTL_VIVIFY {
410        let definition = vivify_func();
411        assert!(
412            !map.contains_key(definition.name),
413            "duplicate cache name {}",
414            definition.name
415        );
416        map.insert(definition.name, definition);
417    }
418
419    map
420}
421
/// Describes a cache declared via declare_cache!, as surfaced
/// through get_definitions() (serializable for reporting APIs).
#[derive(serde::Serialize)]
pub struct CacheDefinition {
    /// Unique cache name (also the metrics label)
    pub name: &'static str,
    /// The capacity the cache was declared with
    pub capacity: usize,
    /// The doc comment from the declaration site, if any
    pub doc: Option<&'static str>,
}
428
/// Lazily-built index of cache name -> definition for every cache
/// declared via declare_cache!. Forcing this static instantiates
/// all of those caches (see vivify_impl).
static PREDEFINED_CACHES: LazyLock<HashMap<&'static str, CacheDefinition>> =
    LazyLock::new(vivify_impl);
431
432pub fn get_definitions() -> Vec<&'static CacheDefinition> {
433    let mut defs = PREDEFINED_CACHES.values().collect::<Vec<_>>();
434    defs.sort_by(|a, b| a.name.cmp(&b.name));
435    defs
436}
437
438pub fn is_name_available(name: &str) -> bool {
439    !PREDEFINED_CACHES.contains_key(name)
440}
441
442/// Update the capacity value for a pre-defined cache
443pub fn set_cache_capacity(name: &str, capacity: usize) -> bool {
444    if !PREDEFINED_CACHES.contains_key(name) {
445        return false;
446    }
447    let caches = all_caches();
448    match caches.iter().find(|p| p.name() == name) {
449        Some(p) => {
450            p.update_capacity(capacity);
451            true
452        }
453        None => false,
454    }
455}
456
/// Instantiate all pre-declared caches and spawn the background
/// tasks that purge caches on memory shortage and periodically
/// prune expired entries. Must be called from within a tokio runtime.
pub fn spawn_memory_monitor() {
    vivify();
    tokio::spawn(purge_caches_on_memory_shortage());
    tokio::spawn(prune_expired_caches());
}
462
463async fn purge_caches_on_memory_shortage() {
464    tracing::debug!("starting memory monitor");
465    let mut memory_status = subscribe_to_memory_status_changes_async().await;
466    while let Ok(()) = memory_status.changed().await {
467        if kumo_server_memory::get_headroom() == 0 {
468            purge_all_caches();
469
470            // Wait a little bit so that we can debounce
471            // in the case where we're riding the cusp of
472            // the limit and would thrash the caches
473            tokio::time::sleep(tokio::time::Duration::from_secs(30)).await;
474        }
475    }
476}
477
/// The population state of a cache entry.
#[derive(Debug, Clone)]
enum ItemState<V>
where
    V: Send,
    V: Sync,
{
    /// A successfully populated value
    Present(V),
    /// A lookup is in flight; other callers wait on the semaphore
    /// until the owning caller populates (or fails) the entry
    Pending(Arc<Semaphore>),
    /// The most recent population attempt failed with this error
    Failed(Arc<anyhow::Error>),
    /// The value expired but stale reads are allowed: serve
    /// stale_value while a refresh coordinates via `pending`
    Refreshing {
        stale_value: V,
        pending: Arc<Semaphore>,
    },
}
492
/// A single cache slot: the entry state plus its TTL and
/// recency bookkeeping.
#[derive(Debug)]
struct Item<V>
where
    V: Send,
    V: Sync,
{
    // Current population state of this slot
    item: ItemState<V>,
    // When this entry ceases to be valid
    expiration: Instant,
    // Tick stamp of the most recent access; used by the sampled
    // LRU eviction in Inner::evict_some
    last_tick: AtomicUsize,
}
503
504impl<V: Clone + Send + Sync> Clone for Item<V> {
505    fn clone(&self) -> Self {
506        Self {
507            item: self.item.clone(),
508            expiration: self.expiration,
509            last_tick: self.last_tick.load(Ordering::Relaxed).into(),
510        }
511    }
512}
513
/// The result of a successful cache lookup or population.
#[derive(Debug)]
pub struct ItemLookup<V: Debug> {
    /// A copy of the item
    pub item: V,
    /// If true, the get_or_try_insert operation populated the entry;
    /// the operation was a cache miss
    pub is_fresh: bool,
    /// The instant at which this entry will expire
    pub expiration: Instant,
}
524
/// A concurrent cache with TTL-based expiry and sampled-LRU
/// eviction. Cloning of values happens on lookup; the shared state
/// lives in an Arc so the global registry can hold a weak reference.
pub struct LruCacheWithTtl<K: Clone + Debug + Hash + Eq, V: Clone + Debug + Send + Sync> {
    inner: Arc<Inner<K, V>>,
}
528
529impl<
530        K: Clone + Debug + Hash + Eq + Send + Sync + std::fmt::Debug + 'static,
531        V: Clone + Debug + Send + Sync + 'static,
532    > LruCacheWithTtl<K, V>
533{
    /// Construct a cache with the given name and capacity, wiring up
    /// its per-cache prometheus metrics and registering it (weakly)
    /// with the global cache registry so the background maintenance
    /// tasks can find it.
    ///
    /// Prefer declare_cache! over calling this directly, so the
    /// cache is discoverable via get_definitions().
    pub fn new<S: Into<String>>(name: S, capacity: usize) -> Self {
        let name = name.into();
        let cache = DashMap::new();

        // Resolve all the labelled metric children up front so that
        // the hot paths can inc()/set() without a label lookup.
        // (The "failed to get counter" message is reused for the
        // gauges below as well.)
        let lookup_counter = CACHE_LOOKUP
            .get_metric_with_label_values(&[&name])
            .expect("failed to get counter");
        let hit_counter = CACHE_HIT
            .get_metric_with_label_values(&[&name])
            .expect("failed to get counter");
        let stale_counter = CACHE_STALE
            .get_metric_with_label_values(&[&name])
            .expect("failed to get counter");
        let evict_counter = CACHE_EVICT
            .get_metric_with_label_values(&[&name])
            .expect("failed to get counter");
        let expire_counter = CACHE_EXPIRE
            .get_metric_with_label_values(&[&name])
            .expect("failed to get counter");
        let miss_counter = CACHE_MISS
            .get_metric_with_label_values(&[&name])
            .expect("failed to get counter");
        let populate_counter = CACHE_POPULATED
            .get_metric_with_label_values(&[&name])
            .expect("failed to get counter");
        let insert_counter = CACHE_INSERT
            .get_metric_with_label_values(&[&name])
            .expect("failed to get counter");
        let error_counter = CACHE_ERROR
            .get_metric_with_label_values(&[&name])
            .expect("failed to get counter");
        let wait_gauge = CACHE_WAIT
            .get_metric_with_label_values(&[&name])
            .expect("failed to get counter");
        let size_gauge = CACHE_SIZE
            .get_metric_with_label_values(&[&name])
            .expect("failed to get counter");

        let inner = Arc::new(Inner {
            name,
            cache,
            tick: AtomicUsize::new(0),
            // Stale reads are opt-in via set_allow_stale_reads
            allow_stale_reads: AtomicBool::new(false),
            capacity: AtomicUsize::new(capacity),
            // Default sample size; see Inner::evict_some for rationale
            lru_samples: AtomicUsize::new(10),
            // Default populate-wait timeout: 2 minutes
            sema_timeout_milliseconds: AtomicUsize::new(120_000),
            lookup_counter,
            evict_counter,
            expire_counter,
            hit_counter,
            miss_counter,
            populate_counter,
            error_counter,
            wait_gauge,
            insert_counter,
            stale_counter,
            size_gauge,
        });

        // Register with the global list of caches using a weak reference.
        // We need to "erase" the K/V types in order to do that, so we
        // use the CachePurger trait for this purpose.
        {
            let generic: Arc<dyn CachePurger + Send + Sync> = inner.clone();
            CACHES.lock().push(Arc::downgrade(&generic));
            tracing::debug!(
                "registered cache {} with capacity {capacity}",
                generic.name()
            );
        }

        Self { inner }
    }
607
    /// Whether expired entries may be served while a refresh
    /// is in progress (see ItemState::Refreshing).
    fn allow_stale_reads(&self) -> bool {
        self.inner.allow_stale_reads.load(Ordering::Relaxed)
    }
611
    /// Enable or disable serving stale values while an expired
    /// entry is being refreshed.
    pub fn set_allow_stale_reads(&self, value: bool) {
        self.inner.allow_stale_reads.store(value, Ordering::Relaxed);
    }
615
616    pub fn set_sema_timeout(&self, duration: Duration) {
617        self.inner
618            .sema_timeout_milliseconds
619            .store(duration.as_millis() as usize, Ordering::Relaxed);
620    }
621
    /// Remove all entries, waking any pending waiters.
    /// Returns the number of entries removed.
    pub fn clear(&self) -> usize {
        self.inner.clear()
    }
625
    /// Advance the recency clock and return the new tick value.
    // NOTE(review): fetch_add wraps on overflow; presumably acceptable
    // for an approximate LRU clock -- confirm.
    fn inc_tick(&self) -> usize {
        self.inner.tick.fetch_add(1, Ordering::Relaxed) + 1
    }
629
    /// Stamp the item with the current tick so the sampled LRU
    /// treats it as recently used.
    fn update_tick(&self, item: &Item<V>) {
        let v = self.inc_tick();
        item.last_tick.store(v, Ordering::Relaxed);
    }
634
    /// Look up `name`, returning a copy of its value and expiry
    /// metadata if it is present and not expired.
    ///
    /// Expired entries are removed as a side effect (unless stale
    /// reads are enabled, in which case removal is deferred to the
    /// refresh machinery). Pending/Refreshing/Failed entries report
    /// as misses.
    pub fn lookup<Q: ?Sized>(&self, name: &Q) -> Option<ItemLookup<V>>
    where
        K: Borrow<Q>,
        Q: Hash + Eq,
    {
        self.inner.lookup_counter.inc();
        match self.inner.cache.get_mut(name) {
            None => {
                self.inner.miss_counter.inc();
                None
            }
            Some(entry) => {
                match &entry.item {
                    ItemState::Present(item) => {
                        let now = Instant::now();
                        if now >= entry.expiration {
                            // Expired
                            if self.allow_stale_reads() {
                                // We don't furnish a result directly, but we
                                // also do not want to remove it from the map
                                // at this stage.
                                // We're assuming that lookup() is called only
                                // via get_or_try_insert when allow_stale_reads
                                // is enabled.
                                self.inner.miss_counter.inc();
                                return None;
                            }

                            // otherwise: remove it from the map.
                            // Take care to drop our ref first so that we don't
                            // self-deadlock
                            drop(entry);
                            // Re-check expiry inside remove_if: another actor
                            // may have refreshed the entry since we dropped it
                            if self
                                .inner
                                .cache
                                .remove_if(name, |_k, entry| now >= entry.expiration)
                                .is_some()
                            {
                                self.inner.expire_counter.inc();
                                self.inner.size_gauge.set(self.inner.cache.len() as i64);
                            }
                            self.inner.miss_counter.inc();
                            return None;
                        }
                        // Valid hit: bump recency and hand back a copy
                        self.inner.hit_counter.inc();
                        self.update_tick(&entry);
                        Some(ItemLookup {
                            item: item.clone(),
                            expiration: entry.expiration,
                            is_fresh: false,
                        })
                    }
                    ItemState::Refreshing { .. } | ItemState::Pending(_) | ItemState::Failed(_) => {
                        // Not a directly servable value; callers go through
                        // get_or_try_insert to coordinate population
                        self.inner.miss_counter.inc();
                        None
                    }
                }
            }
        }
    }
695
696    pub fn get<Q: ?Sized>(&self, name: &Q) -> Option<V>
697    where
698        K: Borrow<Q>,
699        Q: Hash + Eq,
700    {
701        self.lookup(name).map(|lookup| lookup.item)
702    }
703
    /// Unconditionally insert (or replace) an entry with the given
    /// expiration, then make some eviction progress if over capacity.
    /// Returns the inserted value.
    // NOTE(review): this fn contains no awaits; `async` appears to be
    // kept for API compatibility with callers that `.await` it.
    pub async fn insert(&self, name: K, item: V, expiration: Instant) -> V {
        self.inner.cache.insert(
            name,
            Item {
                item: ItemState::Present(item.clone()),
                expiration,
                last_tick: self.inc_tick().into(),
            },
        );

        self.inner.insert_counter.inc();
        self.inner.size_gauge.set(self.inner.cache.len() as i64);
        self.inner.maybe_evict();

        item
    }
720
    /// Atomically inspect-or-create the entry for `name` and return a
    /// snapshot of its state plus its expiration.
    ///
    /// This is the coordination core of get_or_try_insert: while the
    /// dashmap entry lock is held, it transitions stuck/closed/expired
    /// states into a fresh Pending (or Refreshing, when stale reads
    /// are enabled) so that exactly one caller becomes responsible for
    /// population. The returned clone tells the caller which role it
    /// plays: Present/Failed terminate, Pending/Refreshing carry the
    /// semaphore to wait on or populate through.
    ///
    /// `deadline` is the expiration used for a newly created entry;
    /// `timeout_duration` re-arms the expiration when a stuck or
    /// failed entry is reset.
    fn clone_item_state(
        &self,
        name: &K,
        deadline: Instant,
        timeout_duration: Duration,
    ) -> (ItemState<V>, Instant) {
        let mut is_new = false;
        // entry() holds the shard lock for this key until dropped,
        // making the state transitions below atomic wrt other callers
        let mut entry = self.inner.cache.entry(name.clone()).or_insert_with(|| {
            is_new = true;
            Item {
                item: ItemState::Pending(Arc::new(Semaphore::new(1))),
                expiration: deadline,
                last_tick: self.inc_tick().into(),
            }
        });

        match &entry.value().item {
            ItemState::Pending(sema) => {
                if sema.is_closed() {
                    // Previous populate attempt finished/aborted;
                    // start a fresh round with a new semaphore
                    entry.value_mut().item = ItemState::Pending(Arc::new(Semaphore::new(1)));
                } else {
                    let now = Instant::now();
                    if now >= entry.expiration {
                        // Exceeded deadline without closing the semaphore.
                        // Perhaps there is some blocking issue preventing progress?
                        // Force it into an error state to help unclog the rest of
                        // the system
                        tracing::warn!(
                            "{} semaphore for {name:?} remains open, \
                            but the lookup state has not been satisfied within the \
                            populate deadline. Assuming that something is stuck and \
                            making the now-active caller responsible for populating \
                            this entry",
                            self.inner.name
                        );
                        entry.value_mut().item = ItemState::Pending(Arc::new(Semaphore::new(1)));
                        entry.value_mut().expiration = now + timeout_duration;
                    }
                }
            }
            ItemState::Refreshing {
                stale_value,
                pending,
            } => {
                if pending.is_closed() {
                    // Refresh round ended; keep the stale value available
                    // and arm a new semaphore for the next refresher
                    entry.value_mut().item = ItemState::Refreshing {
                        stale_value: stale_value.clone(),
                        pending: Arc::new(Semaphore::new(1)),
                    };
                }
            }
            ItemState::Present(item) => {
                let now = Instant::now();
                if now >= entry.expiration {
                    // Expired; we will need to fetch it
                    let pending = Arc::new(Semaphore::new(1));
                    if self.allow_stale_reads() {
                        // Keep the old value servable while refreshing
                        entry.value_mut().item = ItemState::Refreshing {
                            stale_value: item.clone(),
                            pending,
                        };
                    } else {
                        entry.value_mut().item = ItemState::Pending(pending);
                    }
                }
            }
            ItemState::Failed(_) => {
                let now = Instant::now();
                if now >= entry.expiration {
                    // Expired; we will need to fetch it
                    entry.value_mut().item = ItemState::Pending(Arc::new(Semaphore::new(1)));
                    entry.value_mut().expiration = now + timeout_duration;
                }
            }
        }

        self.update_tick(&entry);
        let item = entry.value();
        let result = (item.item.clone(), entry.expiration);
        // Release the shard lock before touching gauges/eviction,
        // which may need other shard locks
        drop(entry);

        if is_new {
            self.inner.size_gauge.set(self.inner.cache.len() as i64);
            self.inner.maybe_evict();
        }

        result
    }
809
    /// Get an existing item, but if that item doesn't already exist,
    /// execute the future `fut` to provide a value that will be inserted and then
    /// returned.  This is done atomically wrt. other callers.
    /// The TTL parameter is a function that can extract the TTL from the value type,
    /// or just return a constant TTL.
    ///
    /// Coordination: only one caller at a time runs `fut` for a given key.
    /// The winner is whoever acquires the entry's semaphore; every other
    /// caller waits on that same semaphore, bounded by a deadline derived
    /// from `sema_timeout_milliseconds`.  If the wait times out while the
    /// entry is in the `Refreshing` state, the stale value satisfies the
    /// lookup instead of producing an error.
    ///
    /// Errors are shared: when the populating caller's `fut` fails, the same
    /// `Arc<anyhow::Error>` is recorded in the cache as `ItemState::Failed`
    /// and handed to every waiter.  A semaphore-acquire timeout (with no
    /// stale value), or exhausting the retry loop, produces a fresh error.
    pub async fn get_or_try_insert<E: Into<anyhow::Error>, TTL: FnOnce(&V) -> Duration>(
        &self,
        name: &K,
        ttl_func: TTL,
        fut: impl Future<Output = Result<V, E>>,
    ) -> Result<ItemLookup<V>, Arc<anyhow::Error>> {
        // Fast path avoids cloning the key
        if let Some(entry) = self.lookup(name) {
            return Ok(entry);
        }

        // One deadline bounds every semaphore wait below, so the overall
        // call cannot block much longer than the configured timeout.
        let timeout_duration = Duration::from_millis(
            self.inner.sema_timeout_milliseconds.load(Ordering::Relaxed) as u64,
        );
        let start = Instant::now();
        let deadline = start + timeout_duration;

        // Note: the lookup call increments lookup_counter and miss_counter
        // Bound the restart loop: each `continue 'retry` below means the
        // entry changed underneath us while we waited (semaphore closed or
        // replaced), so we re-inspect the state rather than spin forever.
        const MAX_ATTEMPTS: usize = 10;
        'retry: for _ in 0..MAX_ATTEMPTS {
            // Snapshot the entry state:
            // - Present/Failed: already resolved; return immediately.
            // - Refreshing: a stale value exists alongside the pending
            //   semaphore; keep it as a fallback if the wait times out.
            // - Pending: nothing usable yet; we can only wait.
            let (stale_value, sema) = match self.clone_item_state(name, deadline, timeout_duration)
            {
                (ItemState::Present(item), expiration) => {
                    return Ok(ItemLookup {
                        item,
                        expiration,
                        is_fresh: false,
                    });
                }
                (ItemState::Failed(error), _) => {
                    return Err(error);
                }
                (
                    ItemState::Refreshing {
                        stale_value,
                        pending,
                    },
                    expiration,
                ) => (Some((stale_value, expiration)), pending),
                (ItemState::Pending(sema), _) => (None, sema),
            };

            // Wait our turn on the entry's semaphore.  wait_gauge tracks how
            // many tasks are parked here; the defer! keeps it accurate even
            // if we are cancelled at the await point.
            let wait_result = {
                self.inner.wait_gauge.inc();
                defer! {
                    self.inner.wait_gauge.dec();
                }

                match timeout_at(deadline, sema.acquire_owned()).await {
                    Err(_) => {
                        // Timed out.  Prefer serving the stale value (if the
                        // entry was Refreshing) over surfacing an error.
                        if let Some((item, expiration)) = stale_value {
                            tracing::debug!(
                                "{} semaphore acquire for {name:?} timed out after \
                                {timeout_duration:?}, allowing stale value to satisfy the lookup",
                                self.inner.name
                            );
                            self.inner.stale_counter.inc();
                            return Ok(ItemLookup {
                                item,
                                expiration,
                                is_fresh: false,
                            });
                        }
                        tracing::debug!(
                            "{} semaphore acquire for {name:?} timed out after \
                                {timeout_duration:?}, returning error",
                            self.inner.name
                        );

                        self.inner.error_counter.inc();
                        return Err(Arc::new(anyhow::anyhow!(
                            "{} lookup for {name:?} \
                            timed out after {timeout_duration:?} \
                            on semaphore acquire while waiting for cache to populate",
                            self.inner.name
                        )));
                    }
                    Ok(r) => r,
                }
            };

            // While we slept, someone else may have satisfied
            // the lookup; check it
            let current_sema = match self.clone_item_state(name, deadline, timeout_duration) {
                (ItemState::Present(item), expiration) => {
                    return Ok(ItemLookup {
                        item,
                        expiration,
                        is_fresh: false,
                    });
                }
                (ItemState::Failed(error), _) => {
                    // NOTE(review): this path bumps hit_counter while the
                    // pre-wait Failed arm above does not; presumably because
                    // here the freshly-recorded failure satisfied our wait —
                    // confirm the intended metric semantics.
                    self.inner.hit_counter.inc();
                    return Err(error);
                }
                (
                    ItemState::Refreshing {
                        stale_value: _,
                        pending,
                    },
                    _,
                ) => pending,
                (ItemState::Pending(current_sema), _) => current_sema,
            };

            // It's still outstanding
            match wait_result {
                Ok(permit) => {
                    // We're responsible for resolving it.
                    // We will always close the semaphore when
                    // we're done with this logic (and when we unwind
                    // or are cancelled) so that we can wake up any
                    // waiters.
                    // We use defer! for this so that if we are cancelled
                    // at the await point below, others are still woken up.
                    defer! {
                        permit.semaphore().close();
                    }

                    // The permit we won must belong to the semaphore that is
                    // currently stored in the entry; otherwise the entry was
                    // replaced while we waited and we must start over.
                    if !Arc::ptr_eq(&current_sema, permit.semaphore()) {
                        self.inner.error_counter.inc();
                        tracing::warn!(
                            "{} mismatched semaphores for {name:?}, \
                                    will restart cache resolve.",
                            self.inner.name
                        );
                        continue 'retry;
                    }

                    self.inner.populate_counter.inc();
                    // 60s is the TTL used for the failure (negative-cache)
                    // case below; on success it is replaced by ttl_func's
                    // result.
                    let mut ttl = Duration::from_secs(60);
                    let future_result = fut.await;
                    let now = Instant::now();

                    let (item_result, return_value) = match future_result {
                        Ok(item) => {
                            ttl = ttl_func(&item);
                            (
                                ItemState::Present(item.clone()),
                                Ok(ItemLookup {
                                    item,
                                    expiration: now + ttl,
                                    is_fresh: true,
                                }),
                            )
                        }
                        Err(err) => {
                            self.inner.error_counter.inc();
                            let err = Arc::new(err.into());
                            (ItemState::Failed(err.clone()), Err(err))
                        }
                    };

                    // NOTE(review): the stored expiration takes a second
                    // Instant::now() reading slightly after `now`, so it can
                    // trail the expiration reported in the returned
                    // ItemLookup by a tiny amount.
                    self.inner.cache.insert(
                        name.clone(),
                        Item {
                            item: item_result,
                            expiration: Instant::now() + ttl,
                            last_tick: self.inc_tick().into(),
                        },
                    );
                    self.inner.maybe_evict();
                    return return_value;
                }
                Err(_) => {
                    self.inner.error_counter.inc();

                    // semaphore was closed, but the status is
                    // still somehow pending
                    tracing::debug!(
                        "{} lookup for {name:?} woke up semaphores \
                                but is still marked pending, \
                                will restart cache lookup",
                        self.inner.name
                    );
                    continue 'retry;
                }
            }
        }

        // Every attempt observed an inconsistent/pending state; give up.
        return Err(Arc::new(anyhow::anyhow!(
            "{} lookup for {name:?} failed after {MAX_ATTEMPTS} attempts",
            self.inner.name
        )));
    }
1000}
1001
#[cfg(test)]
mod test {
    use super::*;
    use test_log::test; // run with RUST_LOG=lruttl=trace to trace

    /// Inserting well beyond the configured capacity must trigger eviction
    /// so the cache never retains more than `capacity` entries.
    #[test(tokio::test)]
    async fn test_capacity() {
        let cache = LruCacheWithTtl::new("test_capacity", 40);

        // Every entry shares a far-future expiry, so only the capacity
        // limit (never the TTL) can be responsible for removals.
        let expiration = Instant::now() + Duration::from_secs(60);
        for key in 0..100 {
            cache.insert(key, key, expiration).await;
        }

        assert_eq!(cache.inner.cache.len(), 40, "capacity is respected");
    }

    /// An entry must be readable before its TTL elapses and gone afterwards.
    #[test(tokio::test)]
    async fn test_expiration() {
        let cache = LruCacheWithTtl::new("test_expiration", 1);

        // Freeze tokio's clock so the TTL can be crossed deterministically.
        tokio::time::pause();
        let expiration = Instant::now() + Duration::from_secs(1);
        cache.insert(0, 0, expiration).await;

        cache.get(&0).expect("still in cache");
        tokio::time::advance(Duration::from_secs(2)).await;
        assert!(cache.get(&0).is_none(), "evicted due to ttl");
    }

    /// Exercise a capacity-1 cache while one populating future is stuck:
    /// unrelated keys must still resolve, and aborting the stuck populator
    /// must let the remaining waiter for the same key take over.
    #[test(tokio::test)]
    async fn test_over_capacity_slow_resolve() {
        let cache = Arc::new(LruCacheWithTtl::<String, u64>::new(
            "test_over_capacity_slow_resolve",
            1,
        ));

        // Two tasks race to populate "foo"; task 0 wins the semaphore and
        // then sleeps "forever", leaving task 1 parked as a waiter.
        let mut handles: Vec<_> = (0..2)
            .map(|idx| {
                let cache = cache.clone();
                tokio::spawn(async move {
                    eprintln!("spawned task {idx} is running");
                    cache
                        .get_or_try_insert(
                            &"foo".to_string(),
                            |_| Duration::from_secs(86400),
                            async {
                                if idx == 0 {
                                    eprintln!("foo {idx} getter sleeping");
                                    tokio::time::sleep(Duration::from_secs(300)).await;
                                }
                                eprintln!("foo {idx} getter done");
                                Ok::<_, anyhow::Error>(idx)
                            },
                        )
                        .await
                })
            })
            .collect();

        // Let the spawned tasks reach their await points before proceeding.
        tokio::task::yield_now().await;

        eprintln!("calling again with immediate getter");
        let result = cache
            .get_or_try_insert(&"bar".to_string(), |_| Duration::from_secs(60), async {
                eprintln!("bar immediate getter running");
                Ok::<_, anyhow::Error>(42)
            })
            .await
            .unwrap();

        assert_eq!(result.item, 42);
        assert_eq!(cache.inner.cache.len(), 1);

        // Cancel the stuck populator; its defer! must close the semaphore
        // and wake the surviving waiter.
        eprintln!("aborting first one");
        handles.remove(0).abort();

        eprintln!("try new key");
        let result = cache
            .get_or_try_insert(&"baz".to_string(), |_| Duration::from_secs(60), async {
                eprintln!("baz immediate getter running");
                Ok::<_, anyhow::Error>(32)
            })
            .await
            .unwrap();
        assert_eq!(result.item, 32);
        assert_eq!(cache.inner.cache.len(), 1);

        // The surviving waiter takes over and resolves "foo" itself,
        // returning its own idx (1) as the value.
        eprintln!("waiting second one");
        assert_eq!(1, handles.pop().unwrap().await.unwrap().unwrap().item);

        assert_eq!(cache.inner.cache.len(), 1);
    }
}