1use crate::metrics::*;
2use dashmap::DashMap;
3use kumo_server_memory::subscribe_to_memory_status_changes_async;
4use parking_lot::Mutex;
5use prometheus::{IntCounter, IntGauge};
6use scopeguard::defer;
7use std::borrow::Borrow;
8use std::collections::HashMap;
9use std::fmt::Debug;
10use std::future::Future;
11use std::hash::Hash;
12use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
13use std::sync::{Arc, LazyLock, Weak};
14use tokio::sync::Semaphore;
15use tokio::time::{timeout_at, Duration, Instant};
16pub use {linkme, pastey as paste};
17
18mod metrics;
19
20static CACHES: LazyLock<Mutex<Vec<Weak<dyn CachePurger + Send + Sync>>>> =
21 LazyLock::new(Mutex::default);
22
23struct Inner<K: Clone + Hash + Eq + Debug, V: Clone + Send + Sync + Debug> {
24 name: String,
25 tick: AtomicUsize,
26 capacity: AtomicUsize,
27 allow_stale_reads: AtomicBool,
28 cache: DashMap<K, Item<V>>,
29 lru_samples: AtomicUsize,
30 sema_timeout_milliseconds: AtomicUsize,
31 lookup_counter: IntCounter,
32 evict_counter: IntCounter,
33 expire_counter: IntCounter,
34 hit_counter: IntCounter,
35 miss_counter: IntCounter,
36 populate_counter: IntCounter,
37 insert_counter: IntCounter,
38 stale_counter: IntCounter,
39 error_counter: IntCounter,
40 wait_gauge: IntGauge,
41 size_gauge: IntGauge,
42}
43
44impl<
45 K: Clone + Debug + Send + Sync + Hash + Eq + 'static,
46 V: Clone + Debug + Send + Sync + 'static,
47 > Inner<K, V>
48{
49 pub fn clear(&self) -> usize {
50 let num_entries = self.cache.len();
51
52 self.cache.retain(|_k, item| {
57 if let ItemState::Pending(sema) = &item.item {
58 sema.close();
60 }
61 false
62 });
63
64 self.size_gauge.set(self.cache.len() as i64);
65 num_entries
66 }
67
68 pub fn evict_some(&self, target: usize) -> usize {
87 let now = Instant::now();
88
89 let cache_size = self.cache.len();
92 let num_samples = self.lru_samples.load(Ordering::Relaxed).min(cache_size);
94
95 let mut expired_keys = vec![];
97 let mut samples = vec![];
99
100 {
115 let mut rng = rand::thread_rng();
116 let mut indices =
117 rand::seq::index::sample(&mut rng, cache_size, num_samples).into_vec();
118 indices.sort();
119 let mut iter = self.cache.iter();
120 let mut current_idx = 0;
121
122 fn advance_by(iter: &mut impl Iterator, skip_amount: usize) {
127 for _ in 0..skip_amount {
128 if iter.next().is_none() {
129 return;
130 }
131 }
132 }
133
134 for idx in indices {
135 let skip_amount = idx - current_idx;
139 advance_by(&mut iter, skip_amount);
140
141 match iter.next() {
142 Some(map_entry) => {
143 current_idx = idx + 1;
144 let item = map_entry.value();
145 match &item.item {
146 ItemState::Pending(_) | ItemState::Refreshing { .. } => {
147 }
149 ItemState::Present(_) | ItemState::Failed(_) => {
150 if now >= item.expiration {
151 expired_keys.push(map_entry.key().clone());
152 } else {
153 let last_tick = item.last_tick.load(Ordering::Relaxed);
154 samples.push((map_entry.key().clone(), last_tick));
155 }
156 }
157 }
158 }
159 None => {
160 break;
161 }
162 }
163 }
164 }
165
166 let mut num_removed = 0;
167 for key in expired_keys {
168 let removed = self
172 .cache
173 .remove_if(&key, |_k, entry| now >= entry.expiration)
174 .is_some();
175 if removed {
176 tracing::trace!("{} expired {key:?}", self.name);
177 num_removed += 1;
178 self.expire_counter.inc();
179 }
180 }
181
182 let target = target.min(samples.len() / 2).max(1);
189
190 if num_removed >= target {
192 self.size_gauge.set(self.cache.len() as i64);
193 tracing::trace!("{} expired {num_removed} of target {target}", self.name);
194 return num_removed;
195 }
196
197 samples.sort_by(|(_ka, tick_a), (_kb, tick_b)| tick_a.cmp(tick_b));
200
201 for (key, tick) in samples {
202 if self
206 .cache
207 .remove_if(&key, |_k, item| {
208 item.last_tick.load(Ordering::Relaxed) == tick
209 })
210 .is_some()
211 {
212 tracing::debug!("{} evicted {key:?}", self.name);
213 num_removed += 1;
214 self.evict_counter.inc();
215 self.size_gauge.set(self.cache.len() as i64);
216 if num_removed >= target {
217 return num_removed;
218 }
219 }
220 }
221
222 if num_removed == 0 {
223 tracing::debug!(
224 "{} did not find anything to evict, target was {target}",
225 self.name
226 );
227 }
228
229 tracing::trace!("{} removed {num_removed} of target {target}", self.name);
230
231 num_removed
232 }
233
234 pub fn maybe_evict(&self) -> usize {
237 let cache_size = self.cache.len();
238 let capacity = self.capacity.load(Ordering::Relaxed);
239 if cache_size > capacity {
240 self.evict_some(cache_size - capacity)
241 } else {
242 0
243 }
244 }
245}
246
/// Type-erased maintenance interface to a cache, used by the global
/// registry so that caches with different key and value types can be
/// managed uniformly.
trait CachePurger {
    /// The name the cache was created with.
    fn name(&self) -> &str;
    /// Drops every entry; returns how many entries there were.
    fn purge(&self) -> usize;
    /// Removes expired entries; returns how many entries were removed.
    fn process_expirations(&self) -> usize;
    /// Changes the capacity of the cache.
    fn update_capacity(&self, capacity: usize);
}
253
254impl<
255 K: Clone + Debug + Send + Sync + Hash + Eq + 'static,
256 V: Clone + Debug + Send + Sync + 'static,
257 > CachePurger for Inner<K, V>
258{
259 fn name(&self) -> &str {
260 &self.name
261 }
262 fn purge(&self) -> usize {
263 self.clear()
264 }
265 fn process_expirations(&self) -> usize {
266 let now = Instant::now();
267 let mut expired_keys = vec![];
268 for map_entry in self.cache.iter() {
269 let item = map_entry.value();
270 match &item.item {
271 ItemState::Pending(_) | ItemState::Refreshing { .. } => {
272 }
274 ItemState::Present(_) | ItemState::Failed(_) => {
275 if now >= item.expiration {
276 expired_keys.push(map_entry.key().clone());
277 }
278 }
279 }
280 }
281
282 let mut num_removed = 0;
283 for key in expired_keys {
284 let removed = self
288 .cache
289 .remove_if(&key, |_k, entry| now >= entry.expiration)
290 .is_some();
291 if removed {
292 num_removed += 1;
293 self.expire_counter.inc();
294 self.size_gauge.set(self.cache.len() as i64);
295 }
296 }
297
298 num_removed + self.maybe_evict()
299 }
300
301 fn update_capacity(&self, capacity: usize) {
302 self.capacity.store(capacity, Ordering::Relaxed);
303 self.process_expirations();
308 }
309}
310
311fn all_caches() -> Vec<Arc<dyn CachePurger + Send + Sync>> {
312 let mut result = vec![];
313 let mut caches = CACHES.lock();
314 caches.retain(|entry| match entry.upgrade() {
315 Some(purger) => {
316 result.push(purger);
317 true
318 }
319 None => false,
320 });
321 result
322}
323
324pub fn purge_all_caches() {
325 let purgers = all_caches();
326
327 tracing::error!("purging {} caches", purgers.len());
328 for purger in purgers {
329 let name = purger.name();
330 let num_entries = purger.purge();
331 tracing::error!("cleared {num_entries} entries from cache {name}");
332 }
333}
334
335async fn prune_expired_caches() {
336 loop {
337 tokio::time::sleep(tokio::time::Duration::from_secs(30)).await;
338 let purgers = all_caches();
339
340 for p in purgers {
341 let n = p.process_expirations();
342 if n > 0 {
343 tracing::debug!("expired {n} entries from cache {}", p.name());
344 }
345 }
346 }
347}
348
349#[linkme::distributed_slice]
350pub static LRUTTL_VIVIFY: [fn() -> CacheDefinition];
351
/// Helper for `declare_cache!`: expands to `None` when invoked with no
/// arguments, and to `Some` of the trimmed text when given a single
/// doc expression.
#[macro_export]
macro_rules! optional_doc {
    () => {
        ::std::option::Option::None
    };
    ($doc:expr) => {
        ::std::option::Option::Some($doc.trim())
    };
}
361
/// Declares a lazily-created, named `LruCacheWithTtl` static and
/// registers it in `LRUTTL_VIVIFY`, so that it is instantiated and
/// listed among the predefined caches when `vivify_impl` runs.
///
/// Expected invocation shape:
///
/// ```ignore
/// declare_cache! {
///     /// Optional documentation for the cache
///     pub static MY_CACHE: LruCacheWithTtl<String, u64>::new("my_cache", 128);
/// }
/// ```
#[macro_export]
macro_rules! declare_cache {
    (
        $(#[doc = $doc:expr])*
        $vis:vis
        static $sym:ident:
        LruCacheWithTtl<$key:ty, $value:ty>::new($name:expr, $capacity:expr);
    ) => {
        // The cache itself, created on first use.
        $(#[doc = $doc])*
        $vis static $sym: ::std::sync::LazyLock<$crate::LruCacheWithTtl<$key, $value>> =
            ::std::sync::LazyLock::new(|| $crate::LruCacheWithTtl::new($name, $capacity));

        // The registration entry: forces the cache into existence and
        // reports its definition.
        $crate::paste::paste! {
            #[linkme::distributed_slice($crate::LRUTTL_VIVIFY)]
            static [<VIVIFY_ $sym>]: fn() -> $crate::CacheDefinition = || {
                ::std::sync::LazyLock::force(&$sym);
                $crate::CacheDefinition {
                    name: $name,
                    capacity: $capacity,
                    doc: $crate::optional_doc!($($doc)*),
                }
            };
        }
    };
}
395
396fn vivify() {
399 LazyLock::force(&PREDEFINED_CACHES);
400}
401
402fn vivify_impl() -> HashMap<&'static str, CacheDefinition> {
403 let mut map = HashMap::new();
404
405 for vivify_func in LRUTTL_VIVIFY {
406 let definition = vivify_func();
407 assert!(
408 !map.contains_key(definition.name),
409 "duplicate cache name {}",
410 definition.name
411 );
412 map.insert(definition.name, definition);
413 }
414
415 map
416}
417
418#[derive(serde::Serialize)]
419pub struct CacheDefinition {
420 pub name: &'static str,
421 pub capacity: usize,
422 pub doc: Option<&'static str>,
423}
424
425static PREDEFINED_CACHES: LazyLock<HashMap<&'static str, CacheDefinition>> =
426 LazyLock::new(vivify_impl);
427
428pub fn get_definitions() -> Vec<&'static CacheDefinition> {
429 let mut defs = PREDEFINED_CACHES.values().collect::<Vec<_>>();
430 defs.sort_by(|a, b| a.name.cmp(&b.name));
431 defs
432}
433
434pub fn is_name_available(name: &str) -> bool {
435 !PREDEFINED_CACHES.contains_key(name)
436}
437
438pub fn set_cache_capacity(name: &str, capacity: usize) -> bool {
440 if !PREDEFINED_CACHES.contains_key(name) {
441 return false;
442 }
443 let caches = all_caches();
444 match caches.iter().find(|p| p.name() == name) {
445 Some(p) => {
446 p.update_capacity(capacity);
447 true
448 }
449 None => false,
450 }
451}
452
453pub fn spawn_memory_monitor() {
454 vivify();
455 tokio::spawn(purge_caches_on_memory_shortage());
456 tokio::spawn(prune_expired_caches());
457}
458
459async fn purge_caches_on_memory_shortage() {
460 tracing::debug!("starting memory monitor");
461 let mut memory_status = subscribe_to_memory_status_changes_async().await;
462 while let Ok(()) = memory_status.changed().await {
463 if kumo_server_memory::get_headroom() == 0 {
464 purge_all_caches();
465
466 tokio::time::sleep(tokio::time::Duration::from_secs(30)).await;
470 }
471 }
472}
473
474#[derive(Debug, Clone)]
475enum ItemState<V>
476where
477 V: Send,
478 V: Sync,
479{
480 Present(V),
481 Pending(Arc<Semaphore>),
482 Failed(Arc<anyhow::Error>),
483 Refreshing {
484 stale_value: V,
485 pending: Arc<Semaphore>,
486 },
487}
488
489#[derive(Debug)]
490struct Item<V>
491where
492 V: Send,
493 V: Sync,
494{
495 item: ItemState<V>,
496 expiration: Instant,
497 last_tick: AtomicUsize,
498}
499
500impl<V: Clone + Send + Sync> Clone for Item<V> {
501 fn clone(&self) -> Self {
502 Self {
503 item: self.item.clone(),
504 expiration: self.expiration,
505 last_tick: self.last_tick.load(Ordering::Relaxed).into(),
506 }
507 }
508}
509
510#[derive(Debug)]
511pub struct ItemLookup<V: Debug> {
512 pub item: V,
514 pub is_fresh: bool,
517 pub expiration: Instant,
519}
520
521pub struct LruCacheWithTtl<K: Clone + Debug + Hash + Eq, V: Clone + Debug + Send + Sync> {
522 inner: Arc<Inner<K, V>>,
523}
524
525impl<
526 K: Clone + Debug + Hash + Eq + Send + Sync + std::fmt::Debug + 'static,
527 V: Clone + Debug + Send + Sync + 'static,
528 > LruCacheWithTtl<K, V>
529{
530 pub fn new<S: Into<String>>(name: S, capacity: usize) -> Self {
531 let name = name.into();
532 let cache = DashMap::new();
533
534 let lookup_counter = CACHE_LOOKUP
535 .get_metric_with_label_values(&[&name])
536 .expect("failed to get counter");
537 let hit_counter = CACHE_HIT
538 .get_metric_with_label_values(&[&name])
539 .expect("failed to get counter");
540 let stale_counter = CACHE_STALE
541 .get_metric_with_label_values(&[&name])
542 .expect("failed to get counter");
543 let evict_counter = CACHE_EVICT
544 .get_metric_with_label_values(&[&name])
545 .expect("failed to get counter");
546 let expire_counter = CACHE_EXPIRE
547 .get_metric_with_label_values(&[&name])
548 .expect("failed to get counter");
549 let miss_counter = CACHE_MISS
550 .get_metric_with_label_values(&[&name])
551 .expect("failed to get counter");
552 let populate_counter = CACHE_POPULATED
553 .get_metric_with_label_values(&[&name])
554 .expect("failed to get counter");
555 let insert_counter = CACHE_INSERT
556 .get_metric_with_label_values(&[&name])
557 .expect("failed to get counter");
558 let error_counter = CACHE_ERROR
559 .get_metric_with_label_values(&[&name])
560 .expect("failed to get counter");
561 let wait_gauge = CACHE_WAIT
562 .get_metric_with_label_values(&[&name])
563 .expect("failed to get counter");
564 let size_gauge = CACHE_SIZE
565 .get_metric_with_label_values(&[&name])
566 .expect("failed to get counter");
567
568 let inner = Arc::new(Inner {
569 name,
570 cache,
571 tick: AtomicUsize::new(0),
572 allow_stale_reads: AtomicBool::new(false),
573 capacity: AtomicUsize::new(capacity),
574 lru_samples: AtomicUsize::new(10),
575 sema_timeout_milliseconds: AtomicUsize::new(120_000),
576 lookup_counter,
577 evict_counter,
578 expire_counter,
579 hit_counter,
580 miss_counter,
581 populate_counter,
582 error_counter,
583 wait_gauge,
584 insert_counter,
585 stale_counter,
586 size_gauge,
587 });
588
589 {
593 let generic: Arc<dyn CachePurger + Send + Sync> = inner.clone();
594 CACHES.lock().push(Arc::downgrade(&generic));
595 tracing::debug!(
596 "registered cache {} with capacity {capacity}",
597 generic.name()
598 );
599 }
600
601 Self { inner }
602 }
603
604 fn allow_stale_reads(&self) -> bool {
605 self.inner.allow_stale_reads.load(Ordering::Relaxed)
606 }
607
608 pub fn set_allow_stale_reads(&self, value: bool) {
609 self.inner.allow_stale_reads.store(value, Ordering::Relaxed);
610 }
611
612 pub fn set_sema_timeout(&self, duration: Duration) {
613 self.inner
614 .sema_timeout_milliseconds
615 .store(duration.as_millis() as usize, Ordering::Relaxed);
616 }
617
618 pub fn clear(&self) -> usize {
619 self.inner.clear()
620 }
621
622 fn inc_tick(&self) -> usize {
623 self.inner.tick.fetch_add(1, Ordering::Relaxed) + 1
624 }
625
626 fn update_tick(&self, item: &Item<V>) {
627 let v = self.inc_tick();
628 item.last_tick.store(v, Ordering::Relaxed);
629 }
630
631 pub fn lookup<Q: ?Sized>(&self, name: &Q) -> Option<ItemLookup<V>>
632 where
633 K: Borrow<Q>,
634 Q: Hash + Eq,
635 {
636 self.inner.lookup_counter.inc();
637 match self.inner.cache.get_mut(name) {
638 None => {
639 self.inner.miss_counter.inc();
640 None
641 }
642 Some(entry) => {
643 match &entry.item {
644 ItemState::Present(item) => {
645 let now = Instant::now();
646 if now >= entry.expiration {
647 if self.allow_stale_reads() {
649 self.inner.miss_counter.inc();
656 return None;
657 }
658
659 drop(entry);
663 if self
664 .inner
665 .cache
666 .remove_if(name, |_k, entry| now >= entry.expiration)
667 .is_some()
668 {
669 self.inner.expire_counter.inc();
670 self.inner.size_gauge.set(self.inner.cache.len() as i64);
671 }
672 self.inner.miss_counter.inc();
673 return None;
674 }
675 self.inner.hit_counter.inc();
676 self.update_tick(&entry);
677 Some(ItemLookup {
678 item: item.clone(),
679 expiration: entry.expiration,
680 is_fresh: false,
681 })
682 }
683 ItemState::Refreshing { .. } | ItemState::Pending(_) | ItemState::Failed(_) => {
684 self.inner.miss_counter.inc();
685 None
686 }
687 }
688 }
689 }
690 }
691
692 pub fn get<Q: ?Sized>(&self, name: &Q) -> Option<V>
693 where
694 K: Borrow<Q>,
695 Q: Hash + Eq,
696 {
697 self.lookup(name).map(|lookup| lookup.item)
698 }
699
700 pub async fn insert(&self, name: K, item: V, expiration: Instant) -> V {
701 self.inner.cache.insert(
702 name,
703 Item {
704 item: ItemState::Present(item.clone()),
705 expiration,
706 last_tick: self.inc_tick().into(),
707 },
708 );
709
710 self.inner.insert_counter.inc();
711 self.inner.size_gauge.set(self.inner.cache.len() as i64);
712 self.inner.maybe_evict();
713
714 item
715 }
716
717 fn clone_item_state(&self, name: &K) -> (ItemState<V>, Instant) {
718 let mut is_new = false;
719 let mut entry = self.inner.cache.entry(name.clone()).or_insert_with(|| {
720 is_new = true;
721 Item {
722 item: ItemState::Pending(Arc::new(Semaphore::new(1))),
723 expiration: Instant::now() + Duration::from_secs(60),
724 last_tick: self.inc_tick().into(),
725 }
726 });
727
728 match &entry.value().item {
729 ItemState::Pending(sema) => {
730 if sema.is_closed() {
731 entry.value_mut().item = ItemState::Pending(Arc::new(Semaphore::new(1)));
732 }
733 }
734 ItemState::Refreshing {
735 stale_value,
736 pending,
737 } => {
738 if pending.is_closed() {
739 entry.value_mut().item = ItemState::Refreshing {
740 stale_value: stale_value.clone(),
741 pending: Arc::new(Semaphore::new(1)),
742 };
743 }
744 }
745 ItemState::Present(item) => {
746 let now = Instant::now();
747 if now >= entry.expiration {
748 let pending = Arc::new(Semaphore::new(1));
750 if self.allow_stale_reads() {
751 entry.value_mut().item = ItemState::Refreshing {
752 stale_value: item.clone(),
753 pending,
754 };
755 } else {
756 entry.value_mut().item = ItemState::Pending(pending);
757 }
758 }
759 }
760 ItemState::Failed(_) => {
761 let now = Instant::now();
762 if now >= entry.expiration {
763 entry.value_mut().item = ItemState::Pending(Arc::new(Semaphore::new(1)));
765 }
766 }
767 }
768
769 self.update_tick(&entry);
770 let item = entry.value();
771 let result = (item.item.clone(), entry.expiration);
772 drop(entry);
773
774 if is_new {
775 self.inner.size_gauge.set(self.inner.cache.len() as i64);
776 self.inner.maybe_evict();
777 }
778
779 result
780 }
781
782 pub async fn get_or_try_insert<E: Into<anyhow::Error>, TTL: FnOnce(&V) -> Duration>(
788 &self,
789 name: &K,
790 ttl_func: TTL,
791 fut: impl Future<Output = Result<V, E>>,
792 ) -> Result<ItemLookup<V>, Arc<anyhow::Error>> {
793 if let Some(entry) = self.lookup(name) {
795 return Ok(entry);
796 }
797
798 let timeout_duration = Duration::from_millis(
799 self.inner.sema_timeout_milliseconds.load(Ordering::Relaxed) as u64,
800 );
801 let start = Instant::now();
802 let deadline = start + timeout_duration;
803
804 'retry: loop {
806 let (stale_value, sema) = match self.clone_item_state(name) {
807 (ItemState::Present(item), expiration) => {
808 return Ok(ItemLookup {
809 item,
810 expiration,
811 is_fresh: false,
812 });
813 }
814 (ItemState::Failed(error), _) => {
815 return Err(error);
816 }
817 (
818 ItemState::Refreshing {
819 stale_value,
820 pending,
821 },
822 expiration,
823 ) => (Some((stale_value, expiration)), pending),
824 (ItemState::Pending(sema), _) => (None, sema),
825 };
826
827 let wait_result = {
828 self.inner.wait_gauge.inc();
829 defer! {
830 self.inner.wait_gauge.dec();
831 }
832
833 match timeout_at(deadline, sema.acquire_owned()).await {
834 Err(_) => {
835 if let Some((item, expiration)) = stale_value {
836 tracing::debug!(
837 "{} semaphore acquire for {name:?} timed out after \
838 {timeout_duration:?}, allowing stale value to satisfy the lookup",
839 self.inner.name
840 );
841 self.inner.stale_counter.inc();
842 return Ok(ItemLookup {
843 item,
844 expiration,
845 is_fresh: false,
846 });
847 }
848 tracing::debug!(
849 "{} semaphore acquire for {name:?} timed out after \
850 {timeout_duration:?}, returning error",
851 self.inner.name
852 );
853
854 self.inner.error_counter.inc();
855 return Err(Arc::new(anyhow::anyhow!(
856 "{} lookup for {name:?} \
857 timed out after {timeout_duration:?} \
858 on semaphore acquire while waiting for cache to populate",
859 self.inner.name
860 )));
861 }
862 Ok(r) => r,
863 }
864 };
865
866 let current_sema = match self.clone_item_state(name) {
869 (ItemState::Present(item), expiration) => {
870 return Ok(ItemLookup {
871 item,
872 expiration,
873 is_fresh: false,
874 });
875 }
876 (ItemState::Failed(error), _) => {
877 self.inner.hit_counter.inc();
878 return Err(error);
879 }
880 (
881 ItemState::Refreshing {
882 stale_value: _,
883 pending,
884 },
885 _,
886 ) => pending,
887 (ItemState::Pending(current_sema), _) => current_sema,
888 };
889
890 match wait_result {
892 Ok(permit) => {
893 defer! {
901 permit.semaphore().close();
902 }
903
904 if !Arc::ptr_eq(¤t_sema, permit.semaphore()) {
905 self.inner.error_counter.inc();
906 tracing::warn!(
907 "{} mismatched semaphores for {name:?}, \
908 will restart cache resolve.",
909 self.inner.name
910 );
911 continue 'retry;
912 }
913
914 self.inner.populate_counter.inc();
915 let mut ttl = Duration::from_secs(60);
916 let future_result = fut.await;
917 let now = Instant::now();
918
919 let (item_result, return_value) = match future_result {
920 Ok(item) => {
921 ttl = ttl_func(&item);
922 (
923 ItemState::Present(item.clone()),
924 Ok(ItemLookup {
925 item,
926 expiration: now + ttl,
927 is_fresh: true,
928 }),
929 )
930 }
931 Err(err) => {
932 self.inner.error_counter.inc();
933 let err = Arc::new(err.into());
934 (ItemState::Failed(err.clone()), Err(err))
935 }
936 };
937
938 self.inner.cache.insert(
939 name.clone(),
940 Item {
941 item: item_result,
942 expiration: Instant::now() + ttl,
943 last_tick: self.inc_tick().into(),
944 },
945 );
946 self.inner.maybe_evict();
947 return return_value;
948 }
949 Err(_) => {
950 self.inner.error_counter.inc();
951
952 tracing::debug!(
955 "{} lookup for {name:?} woke up semaphores \
956 but is still marked pending, \
957 will restart cache lookup",
958 self.inner.name
959 );
960 continue 'retry;
961 }
962 }
963 }
964 }
965}
966
#[cfg(test)]
mod test {
    use super::*;
    use test_log::test;

    /// Inserting more entries than the capacity leaves exactly
    /// `capacity` entries behind.
    #[test(tokio::test)]
    async fn test_capacity() {
        let cache = LruCacheWithTtl::new("test_capacity", 40);

        let expiration = Instant::now() + Duration::from_secs(60);
        for value in 0..100 {
            cache.insert(value, value, expiration).await;
        }

        assert_eq!(cache.inner.cache.len(), 40, "capacity is respected");
    }

    /// An entry is readable before its TTL elapses and gone afterwards.
    #[test(tokio::test)]
    async fn test_expiration() {
        let cache = LruCacheWithTtl::new("test_expiration", 1);

        // Use the paused clock so time can be advanced by hand.
        tokio::time::pause();
        let expiration = Instant::now() + Duration::from_secs(1);
        cache.insert(0, 0, expiration).await;

        cache.get(&0).expect("still in cache");
        tokio::time::advance(Duration::from_secs(2)).await;
        assert!(cache.get(&0).is_none(), "evicted due to ttl");
    }

    /// A slow in-flight resolve must not be disturbed by other keys
    /// pushing a capacity-1 cache over its limit, and aborting the slow
    /// resolver must let the waiting caller take over.
    #[test(tokio::test)]
    async fn test_over_capacity_slow_resolve() {
        let cache = Arc::new(LruCacheWithTtl::<String, u64>::new(
            "test_over_capacity_slow_resolve",
            1,
        ));

        // Two tasks resolve "foo"; the first one sleeps for a long time.
        let mut tasks = vec![];
        for idx in 0..2 {
            let cache = cache.clone();
            tasks.push(tokio::spawn(async move {
                eprintln!("spawned task {idx} is running");
                cache
                    .get_or_try_insert(&"foo".to_string(), |_| Duration::from_secs(86400), async {
                        if idx == 0 {
                            eprintln!("foo {idx} getter sleeping");
                            tokio::time::sleep(Duration::from_secs(300)).await;
                        }
                        eprintln!("foo {idx} getter done");
                        Ok::<_, anyhow::Error>(idx)
                    })
                    .await
            }));
        }

        tokio::task::yield_now().await;

        eprintln!("calling again with immediate getter");
        let bar = cache
            .get_or_try_insert(&"bar".to_string(), |_| Duration::from_secs(60), async {
                eprintln!("bar immediate getter running");
                Ok::<_, anyhow::Error>(42)
            })
            .await
            .unwrap();

        assert_eq!(bar.item, 42);
        assert_eq!(cache.inner.cache.len(), 1);

        eprintln!("aborting first one");
        tasks.remove(0).abort();

        eprintln!("try new key");
        let baz = cache
            .get_or_try_insert(&"baz".to_string(), |_| Duration::from_secs(60), async {
                eprintln!("baz immediate getter running");
                Ok::<_, anyhow::Error>(32)
            })
            .await
            .unwrap();
        assert_eq!(baz.item, 32);
        assert_eq!(cache.inner.cache.len(), 1);

        eprintln!("waiting second one");
        let second = tasks.pop().unwrap();
        assert_eq!(1, second.await.unwrap().unwrap().item);

        assert_eq!(cache.inner.cache.len(), 1);
    }
}