1use crate::message::{Message, MessageList};
2use kumo_chrono_helper::{DateTime, Utc};
3use spool::SpoolId;
4use std::collections::HashMap;
5use tokio::time::{Duration, Instant};
6
/// Number of bits of a tick number consumed by each wheel tier.
const WHEEL_BITS: usize = 8;
/// Lists per tier. Derived from `WHEEL_BITS` so the two can never disagree.
const WHEEL_SIZE: usize = 1 << WHEEL_BITS;
/// Mask selecting one tier's `WHEEL_BITS`-wide digit of a tick number.
const WHEEL_MASK: usize = WHEEL_SIZE - 1;
10
/// A hierarchical timing wheel holding [`Message`]s until they are due.
///
/// There are `SLOTS` tiers (`buckets`). Each tier is a wheel of `WHEEL_SIZE`
/// lists of slot ids. An entry whose distance from the reference tick is
/// below `1 << ((n + 1) * WHEEL_BITS)` ticks is filed in tier `n`. Larger
/// distances go to the top tier. `pop` sweeps tier 0 tick by tick and
/// cascades higher-tier lists downward whenever a tick crosses a
/// `WHEEL_SIZE` boundary.
#[derive(Debug)]
pub struct TimeQ<const SLOTS: usize = 4> {
    /// Length of one wheel tick.
    tick_resolution: Duration,
    /// Origin from which absolute tick numbers are measured.
    created: Instant,
    /// `pop` returns without doing any work while `now < next_run`.
    /// It is reset to `now + tick_resolution` after each sweep.
    next_run: Instant,
    /// The `now` of the most recent sweep (initially the creation instant).
    last_dispatched: Instant,
    /// One wheel per tier. Tier 0 holds the entries that are due soonest.
    buckets: [Bucket; SLOTS],
    /// Authoritative record of queued messages, keyed by spool id.
    entry_by_id: HashMap<SpoolId, ListEntry>,
    /// Resolves slot ids found in wheel lists back to spool ids. A slot id
    /// with no mapping, or whose entry records a different `entry_slot`, is
    /// stale and is skipped when encountered.
    entry_slot_to_id: HashMap<EntrySlotId, SpoolId>,
    /// Next slot id to hand out. It only grows, except that `clear` resets it.
    next_entry_slot_id: EntrySlotId,
}
23
/// Identifier stored in the wheel lists. It is resolved to a message through
/// `TimeQ::entry_slot_to_id`. A fresh id is allocated on every filing, which
/// lets outdated list entries be recognised and skipped.
type EntrySlotId = usize;
/// A four-tier timing wheel (tiers span up to `1 << 32` ticks).
pub type QuadTimeQ = TimeQ<4>;
/// A three-tier timing wheel (tiers span up to `1 << 24` ticks).
pub type TriTimeQ = TimeQ<3>;
34
/// A queued message together with the slot id it is currently filed under.
#[derive(Debug)]
struct ListEntry {
    msg: Message,
    /// Only the wheel-list occurrence carrying this id is live; other
    /// occurrences for the same message are stale.
    entry_slot: EntrySlotId,
}
40
/// One tier of the wheel. List `i` holds the slot ids whose due tick has
/// `i` as this tier's `WHEEL_BITS`-wide digit.
#[derive(Debug)]
struct Bucket {
    lists: [Vec<EntrySlotId>; WHEEL_SIZE],
}
45
46impl Default for Bucket {
47 fn default() -> Self {
48 Self {
49 lists: std::array::from_fn(|_| Default::default()),
50 }
51 }
52}
53
/// Millisecond conversion that rounds partial milliseconds up.
trait RoundedMillis {
    /// Whole milliseconds in `self`, with any partial millisecond counted as
    /// a full one. Only whole microseconds are considered, so a remainder
    /// below 1µs is discarded before rounding.
    fn as_millis_round_up(&self) -> u128;
}

impl RoundedMillis for Duration {
    fn as_millis_round_up(&self) -> u128 {
        let micros = self.as_micros();
        let whole_millis = micros / 1000;
        if micros % 1000 == 0 {
            whole_millis
        } else {
            whole_millis + 1
        }
    }
}
65
/// How a partial tick is rounded when converting an `Instant` to a tick number.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum RoundDirection {
    /// Round partial ticks up.
    Up,
    /// Round partial ticks down (truncate).
    Down,
}
71
72impl<const SLOTS: usize> TimeQ<SLOTS> {
73 fn new_impl(now: Instant, tick_resolution: Duration) -> Self {
74 assert!(SLOTS > 0 && SLOTS <= 4, "SLOTS must be 1..=4");
75 Self {
76 tick_resolution,
77 next_run: now + tick_resolution,
78 last_dispatched: now,
79 created: now,
80 buckets: std::array::from_fn(|_| Default::default()),
81 entry_by_id: HashMap::new(),
82 entry_slot_to_id: HashMap::new(),
83 next_entry_slot_id: 0,
84 }
85 }
86
87 pub fn new(tick_resolution: Duration) -> Self {
88 Self::new_impl(Instant::now(), tick_resolution)
89 }
90
91 pub fn clear(&mut self) {
92 for bucket in &mut self.buckets {
93 for list in &mut bucket.lists {
94 list.clear();
95 }
96 }
97 self.entry_by_id.clear();
98 self.entry_slot_to_id.clear();
99 self.next_entry_slot_id = 0;
100 }
101
102 pub fn len(&self) -> usize {
103 self.entry_by_id.len()
104 }
105
106 pub fn is_empty(&self) -> bool {
107 self.entry_by_id.is_empty()
108 }
109
110 pub fn tick_resolution(&self) -> Duration {
111 self.tick_resolution
112 }
113
114 fn compute_abs_tick(&self, due: Instant, round_direction: RoundDirection) -> Option<usize> {
117 let delta = due.checked_duration_since(self.created)?;
118 match round_direction {
119 RoundDirection::Up => Some(
120 (delta
121 .as_millis_round_up()
122 .div_ceil(self.tick_resolution.as_millis_round_up())) as usize,
123 ),
124 RoundDirection::Down => {
125 Some((delta.as_millis() / self.tick_resolution.as_millis()) as usize)
126 }
127 }
128 }
129
130 fn compute_list(
131 &mut self,
132 due: Instant,
133 round_direction: RoundDirection,
134 ) -> Option<&mut Vec<EntrySlotId>> {
135 let next_run_tick = self.compute_abs_tick(self.next_run, round_direction)?;
136 let mut due = self.compute_abs_tick(due, round_direction)?;
137 let diff = due.checked_sub(next_run_tick)?;
138
139 for slot in 0..SLOTS {
140 let ceiling = 1 << ((slot + 1) * WHEEL_BITS);
141 if diff < ceiling {
142 return self
143 .buckets
144 .get_mut(slot)
145 .unwrap()
146 .lists
147 .get_mut((due >> (slot * WHEEL_BITS)) & WHEEL_MASK);
148 }
149 }
150
151 if diff > u32::MAX as usize {
153 due = next_run_tick + u32::MAX as usize
154 }
155
156 self.buckets
157 .last_mut()
158 .unwrap()
159 .lists
160 .get_mut((due >> ((SLOTS - 1) * WHEEL_BITS)) & WHEEL_MASK)
161 }
162
163 fn insert_impl(
164 &mut self,
165 now: Instant,
166 now_chrono: DateTime<Utc>,
167 message: Message,
168 round_direction: RoundDirection,
169 ) -> Result<(), Message> {
170 let Some(due) = message.get_due() else {
171 return Err(message);
173 };
174
175 let Ok(due_in) = (due - now_chrono).to_std() else {
176 return Err(message);
178 };
179
180 if due_in <= Duration::ZERO {
181 return Err(message);
183 }
184
185 let due_instant = now + due_in;
186
187 let id = *message.id();
188 let entry_slot = self.next_entry_slot_id;
189 self.next_entry_slot_id += 1;
190
191 match self.compute_list(due_instant, round_direction) {
192 Some(list) => {
193 list.push(entry_slot);
194 self.entry_slot_to_id.insert(entry_slot, id);
195 self.entry_by_id.insert(
196 id,
197 ListEntry {
198 msg: message,
199 entry_slot,
200 },
201 );
202
203 Ok(())
204 }
205 None => Err(message),
206 }
207 }
208
209 fn pop_impl(&mut self, now: Instant, now_utc: DateTime<Utc>) -> MessageList {
211 let mut ready_messages = MessageList::new();
212
213 if now < self.next_run {
214 return ready_messages;
216 }
217
218 let mut reinsert = vec![];
219
220 let last_slot = self
223 .compute_abs_tick(self.last_dispatched, RoundDirection::Down)
224 .expect("never negative");
225 let now_slot = self
226 .compute_abs_tick(now, RoundDirection::Down)
227 .expect("pop_impl failed because now is prior to the TimeQ creation");
228
229 for idx in last_slot + 1..=now_slot {
234 if idx & WHEEL_MASK != 0 {
235 continue;
236 }
237 fn cascade(bucket: &mut Bucket, slot: usize, reinsert: &mut Vec<EntrySlotId>) -> bool {
242 while let Some(entry_slot) = bucket.lists[slot].pop() {
243 reinsert.push(entry_slot);
244 }
245 bucket.lists[slot].shrink_to_fit();
246 slot == 0
247 }
248
249 for tier in 1..SLOTS {
250 if !cascade(
251 &mut self.buckets[tier],
252 (idx >> (tier * WHEEL_BITS)) & WHEEL_MASK,
253 &mut reinsert,
254 ) {
255 break;
256 }
257 }
258
259 while let Some(entry_slot) = reinsert.pop() {
264 let Some(spool_id) = self.entry_slot_to_id.remove(&entry_slot) else {
265 continue;
267 };
268 let Some(entry) = self.entry_by_id.get(&spool_id) else {
269 continue;
272 };
273 if entry.entry_slot != entry_slot {
274 continue;
278 }
279
280 let msg = entry.msg.clone();
281 if let Err(msg) = self.insert_impl(now, now_utc, msg, RoundDirection::Down) {
282 ready_messages.push_back(msg);
283 self.entry_by_id.remove(&spool_id);
284 }
285 }
286 }
287
288 let num_slots = (now_slot - last_slot).min(WHEEL_SIZE);
291 for idx in last_slot + 1..=last_slot + num_slots {
292 let mut nominally_ready = std::mem::take(&mut self.buckets[0].lists[idx & WHEEL_MASK]);
294 while let Some(entry_slot) = nominally_ready.pop() {
295 let Some(spool_id) = self.entry_slot_to_id.remove(&entry_slot) else {
296 continue;
298 };
299 let Some(entry) = self.entry_by_id.get(&spool_id) else {
300 continue;
303 };
304 if entry.entry_slot != entry_slot {
305 continue;
309 }
310
311 let msg = entry.msg.clone();
312 if let Err(msg) = self.insert_impl(now, now_utc, msg, RoundDirection::Down) {
313 ready_messages.push_back(msg);
314 self.entry_by_id.remove(&spool_id);
315 }
316 }
317 }
318
319 self.last_dispatched = now;
320 self.next_run = now + self.tick_resolution;
321
322 ready_messages
323 }
324
325 pub fn insert(&mut self, message: Message) -> Result<(), Message> {
328 self.insert_impl(Instant::now(), Utc::now(), message, RoundDirection::Up)
331 }
332
333 pub fn cancel(&mut self, message: &Message) -> bool {
336 match self.entry_by_id.remove(message.id()) {
337 Some(entry) => {
338 self.entry_slot_to_id.remove(&entry.entry_slot);
339 true
344 }
345 None => false,
346 }
347 }
348
349 pub fn contains(&self, message: &Message) -> bool {
350 self.entry_by_id
351 .get(message.id())
352 .and_then(|entry| self.entry_slot_to_id.get(&entry.entry_slot))
353 .is_some()
354 }
355
356 #[cfg(test)]
357 fn insert_for_test(
358 &mut self,
359 message: Message,
360 start: Instant,
361 start_utc: DateTime<Utc>,
362 ) -> Result<(), Message> {
363 self.insert_impl(
364 Instant::now(),
365 start_utc + start.elapsed(),
366 message,
367 RoundDirection::Up,
368 )
369 }
370
371 pub fn pop(&mut self) -> MessageList {
373 self.pop_impl(Instant::now(), Utc::now())
374 }
375
376 pub fn drain(&mut self) -> impl Iterator<Item = Message> + use<'_, SLOTS> {
378 self.buckets
379 .iter_mut()
380 .flat_map(|bucket| bucket.lists.iter_mut())
381 .flat_map(|list| std::mem::take(list).into_iter())
382 .filter_map(|entry_slot| {
383 let spool_id = self.entry_slot_to_id.remove(&entry_slot)?;
384 let entry = self.entry_by_id.get(&spool_id)?;
385 if entry.entry_slot == entry_slot {
386 let entry = self.entry_by_id.remove(&spool_id)?;
387 self.entry_by_id.remove(&spool_id);
388 Some(entry.msg)
389 } else {
390 None
391 }
392 })
393 }
394
395 pub fn retain<KEEPER>(&mut self, mut keeper: KEEPER)
399 where
400 KEEPER: FnMut(&Message) -> bool,
401 {
402 for bucket in self.buckets.iter_mut() {
403 for list in bucket.lists.iter_mut() {
404 let to_process = std::mem::take(list);
405 for entry_slot in to_process {
406 let Some(spool_id) = self.entry_slot_to_id.get(&entry_slot).copied() else {
407 continue;
409 };
410 let Some(entry) = self.entry_by_id.get(&spool_id) else {
411 self.entry_slot_to_id.remove(&entry_slot);
413 continue;
414 };
415 if entry.entry_slot != entry_slot {
416 self.entry_slot_to_id.remove(&entry_slot);
418 continue;
419 }
420
421 if (keeper)(&entry.msg) {
422 list.push(entry_slot);
424 } else {
425 self.entry_slot_to_id.remove(&entry_slot);
427 self.entry_by_id.remove(&spool_id);
428 }
429 }
430 }
431 }
432 }
433}
434
// Scenario tests for `TimeQ`. Most of them drive the queue with tokio's
// paused clock (see `Time`), so time only moves when a test advances it.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::EnvelopeAddress;
    use spool::SpoolId;
    use std::sync::Arc;

    /// Test clock. It pairs a paused tokio `Instant` origin with a wall-clock
    /// origin, so `now_utc` moves in lock-step with `tokio::time::advance`.
    #[derive(Debug)]
    struct Time {
        start: Instant,
        start_utc: DateTime<Utc>,
    }

    impl Time {
        /// Pauses tokio's clock and records both origins.
        pub fn new() -> Self {
            tokio::time::pause();
            let start_utc = Utc::now();
            let start = Instant::now();
            Self { start, start_utc }
        }

        /// Paused-clock time since `new`.
        pub fn elapsed(&self) -> Duration {
            self.start.elapsed()
        }

        /// Moves the paused clock forward by `duration`.
        pub async fn advance(&self, duration: Duration) {
            tokio::time::advance(duration).await;
        }

        /// Wall-clock time consistent with the paused `Instant` clock.
        pub fn now_utc(&self) -> DateTime<Utc> {
            self.start_utc + self.start.elapsed()
        }

        /// Builds a message whose due time is `duration` after `now_utc()`.
        pub async fn new_msg_due_in(&self, duration: Duration) -> Message {
            let msg = new_msg();
            msg.set_due(Some(self.now_utc() + duration)).await.unwrap();
            msg
        }

        /// Inserts using this clock's view of "now".
        pub fn insert<const SLOTS: usize>(
            &self,
            timeq: &mut TimeQ<SLOTS>,
            msg: Message,
        ) -> Result<(), Message> {
            timeq.insert_for_test(msg, self.start, self.start_utc)
        }

        /// Pops using this clock's view of "now".
        pub fn pop<const SLOTS: usize>(&self, timeq: &mut TimeQ<SLOTS>) -> MessageList {
            timeq.pop_impl(Instant::now(), self.now_utc())
        }

        /// Advances one second at a time, `num_seconds` times. After each
        /// step it pops and records the elapsed time of every released message.
        pub async fn advance_and_collect<const SLOTS: usize>(
            &self,
            num_seconds: usize,
            timeq: &mut TimeQ<SLOTS>,
            popped: &mut Vec<Duration>,
        ) {
            for _ in 0..num_seconds {
                self.advance(Duration::from_secs(1)).await;
                let mut ready = self.pop(timeq);
                while let Some(_msg) = ready.pop_front() {
                    popped.push(self.start.elapsed());
                }
            }
        }
    }

    /// Builds a minimal message with a fresh spool id; no delay is applied.
    fn new_msg() -> Message {
        Message::new_dirty(
            SpoolId::new(),
            EnvelopeAddress::parse("sender@example.com").unwrap(),
            EnvelopeAddress::parse("recip@example.com").unwrap(),
            serde_json::json!({}),
            Arc::new("test".as_bytes().to_vec().into_boxed_slice()),
        )
        .unwrap()
    }

    // A message built by `new_msg` without any delay is not accepted.
    #[tokio::test]
    async fn cannot_insert_immediately_due() {
        let mut timeq = QuadTimeQ::new(Duration::from_secs(3));
        assert!(timeq.is_empty());
        let msg1 = new_msg();
        assert!(timeq.insert(msg1).is_err());
        assert!(timeq.is_empty());
    }

    // Inserting the same message twice keeps a single live entry, and
    // draining leaves both internal maps empty.
    #[tokio::test]
    async fn double_insert() {
        let mut timeq = QuadTimeQ::new(Duration::from_secs(3));
        assert!(timeq.is_empty());
        let msg1 = new_msg();
        msg1.delay_by(chrono::Duration::seconds(10)).await.unwrap();
        assert!(timeq.insert(msg1.clone()).is_ok());
        assert!(timeq.insert(msg1.clone()).is_ok());
        assert_eq!(timeq.len(), 1);
        let drained = timeq.drain().collect::<Vec<_>>();
        assert_eq!(drained.len(), 1);
        assert_eq!(drained[0], msg1);
        assert!(timeq.is_empty());
        assert!(timeq.entry_slot_to_id.is_empty());
        assert!(timeq.entry_by_id.is_empty());
    }

    // `retain` drops rejected messages. msg2 (due in a day) lands in a
    // higher tier than msg1 (due in 10s).
    #[tokio::test]
    async fn retain() {
        let time = Time::new();

        let mut timeq = QuadTimeQ::new(Duration::from_secs(3));
        assert!(timeq.is_empty());

        let msg1 = time.new_msg_due_in(Duration::from_secs(10)).await;
        time.insert(&mut timeq, msg1.clone()).unwrap();
        assert_eq!(timeq.len(), 1);

        let msg2 = time.new_msg_due_in(Duration::from_secs(86400)).await;
        time.insert(&mut timeq, msg2.clone()).unwrap();
        assert_eq!(timeq.len(), 2);

        timeq.retain(|msg| *msg == msg2);
        assert_eq!(timeq.len(), 1);
    }

    /// Schedules one message `tick_resolution * WHEEL_SIZE^(tier + 1)` into
    /// the future. It asserts the message is not released before that much
    /// time has elapsed.
    async fn schedule_in_tier<const SLOTS: usize>(tier: usize) {
        let time = Time::new();

        let tick_resolution = Duration::from_secs(3);
        let mut base = tick_resolution;
        for _ in 0..tier {
            base *= WHEEL_SIZE as u32;
        }
        let limit = base * (WHEEL_SIZE as u32);
        eprintln!("max delay for tier {tier} is {limit:?}");

        let msg1 = time.new_msg_due_in(limit).await;

        eprintln!("schedule_in_tier: {time:?}");

        let mut timeq = TimeQ::<SLOTS>::new(tick_resolution);
        assert!(timeq.is_empty());

        eprintln!("msg is due: {:?}", msg1.get_due());
        time.insert(&mut timeq, msg1.clone()).unwrap();
        assert_eq!(timeq.len(), 1);

        assert!(time.pop(&mut timeq).is_empty());

        // Advance in shrinking steps (halving each time, but never below one
        // tick) until the message pops. This exercises pops that jump many
        // ticks at once as well as pops of a single tick.
        let mut wait = limit / 2;
        let mut ready_messages;
        loop {
            eprintln!("waiting for {wait:?}");
            time.advance(wait).await;
            wait = (wait / 2).max(tick_resolution);
            ready_messages = time.pop(&mut timeq);
            if !ready_messages.is_empty() {
                break;
            }
        }

        let elapsed = time.elapsed();
        let now_utc = time.now_utc();
        eprintln!("schedule_in_tier: {elapsed:?} {now_utc:?}");
        eprintln!("limit was {limit:?}, {elapsed:?} have elapsed");
        assert!(
            elapsed >= limit,
            "waited until {limit:?} had elapsed, but {elapsed:?} have elapsed",
        );
    }

    #[tokio::test]
    async fn quad_schedule_in_tier_0() {
        schedule_in_tier::<4>(0).await;
    }
    #[tokio::test]
    async fn quad_schedule_in_tier_1() {
        schedule_in_tier::<4>(1).await;
    }

    #[tokio::test]
    async fn quad_schedule_in_tier_2() {
        schedule_in_tier::<4>(2).await;
    }

    // Release builds only; presumably too slow without optimizations (TODO confirm).
    #[tokio::test]
    #[cfg(not(debug_assertions))]
    async fn quad_schedule_in_tier_3() {
        schedule_in_tier::<4>(3).await;
    }

    #[tokio::test]
    async fn tri_schedule_in_tier_0() {
        schedule_in_tier::<3>(0).await;
    }
    #[tokio::test]
    async fn tri_schedule_in_tier_1() {
        schedule_in_tier::<3>(1).await;
    }
    #[tokio::test]
    async fn tri_schedule_in_tier_2() {
        schedule_in_tier::<3>(2).await;
    }

    // Release builds only; presumably too slow without optimizations (TODO confirm).
    // With 3 tiers this delay exceeds the top tier's span.
    #[tokio::test]
    #[cfg(not(debug_assertions))]
    async fn tri_schedule_in_tier_3() {
        schedule_in_tier::<3>(3).await;
    }

    #[tokio::test]
    async fn bi_schedule_in_tier_0() {
        schedule_in_tier::<2>(0).await;
    }
    #[tokio::test]
    async fn bi_schedule_in_tier_1() {
        schedule_in_tier::<2>(1).await;
    }
    // With 2 tiers this delay exceeds the top tier's span.
    #[tokio::test]
    async fn bi_schedule_in_tier_2() {
        schedule_in_tier::<2>(2).await;
    }

    #[tokio::test]
    async fn uni_schedule_in_tier_0() {
        schedule_in_tier::<1>(0).await;
    }
    // With 1 tier this delay exceeds the only tier's span.
    #[tokio::test]
    async fn uni_schedule_in_tier_1() {
        schedule_in_tier::<1>(1).await;
    }

    // Messages spread over tier 0 plus one in tier 1. Each must pop no
    // earlier than its due time and no later than the next 3s tick boundary.
    #[tokio::test]
    async fn schedule_tier_0_and_1() {
        let time = Time::new();

        let mut timeq = QuadTimeQ::new(Duration::from_secs(3));
        assert!(timeq.is_empty());

        let intervals = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 768 * 2];
        for &seconds in &intervals {
            let msg = time.new_msg_due_in(Duration::from_secs(seconds)).await;
            time.insert(&mut timeq, msg.clone()).unwrap();
        }

        assert_eq!(timeq.len(), intervals.len());

        let mut popped = vec![];
        loop {
            time.advance(Duration::from_secs(1)).await;
            let mut ready = time.pop(&mut timeq);
            while let Some(_msg) = ready.pop_front() {
                popped.push(time.elapsed());
            }

            if timeq.is_empty() {
                break;
            }
        }

        eprintln!("{popped:?} vs {intervals:?}");

        assert_eq!(popped.len(), intervals.len());

        for (idx, (expected, actual)) in intervals.iter().zip(popped.iter()).enumerate() {
            let upper_limit = Duration::from_secs({ *expected }.div_ceil(3) * 3);
            assert!(
                *actual >= Duration::from_secs(*expected) && *actual <= upper_limit,
                "idx={idx}, expected {expected}-{upper_limit:?} seconds, got {actual:?}"
            );
        }
    }

    // 1ms ticks with a message due in 2ms. It must not pop after the first
    // tick, must pop after the second, and must not be in the future when
    // it pops.
    #[tokio::test]
    async fn schedule_cusp() {
        let time = Time::new();

        let msg1 = time.new_msg_due_in(Duration::from_millis(2)).await;
        eprintln!("msg is due: {:?}", msg1.get_due());

        let mut timeq = QuadTimeQ::new(Duration::from_millis(1));

        time.insert(&mut timeq, msg1.clone()).unwrap();
        assert_eq!(timeq.len(), 1);

        assert!(time.pop(&mut timeq).is_empty());

        time.advance(Duration::from_millis(1)).await;
        let ready_list = time.pop(&mut timeq);
        assert_eq!(ready_list.len(), 0);

        time.advance(Duration::from_millis(1)).await;
        let mut ready_list = time.pop(&mut timeq);
        assert_eq!(ready_list.len(), 1);

        let msg = ready_list.pop_front().unwrap();
        let due = msg.get_due().unwrap();
        let now_utc = time.now_utc();

        assert!(
            due <= now_utc,
            "cannot be due in the future. due={due} now={now_utc}"
        );
    }

    // The second message is inserted after the queue has been popped for 6s.
    // It is due at t=9s, ahead of the first message (due at t=10s), and both
    // must pop within their 3s tick window.
    #[tokio::test]
    async fn schedule_after_creation() {
        let time = Time::new();

        let mut timeq = QuadTimeQ::new(Duration::from_secs(3));
        assert!(timeq.is_empty());

        let mut popped = vec![];

        // Due at t=10s.
        let msg = time.new_msg_due_in(Duration::from_secs(10)).await;
        time.insert(&mut timeq, msg.clone()).unwrap();

        time.advance_and_collect(6, &mut timeq, &mut popped).await;

        // Now at t=6s; this one is due at t=9s.
        let msg = time.new_msg_due_in(Duration::from_secs(3)).await;
        time.insert(&mut timeq, msg.clone()).unwrap();

        assert_eq!(timeq.len(), 2);
        eprintln!("{timeq:?}");

        loop {
            time.advance(Duration::from_secs(1)).await;
            let mut ready = time.pop(&mut timeq);
            while let Some(_msg) = ready.pop_front() {
                popped.push(time.elapsed());
            }

            eprintln!(
                "popped.len={} timeq.empty={}",
                popped.len(),
                timeq.is_empty()
            );

            if timeq.is_empty() {
                break;
            }
        }
        eprintln!("{timeq:?}");

        let intervals = [9, 12];
        eprintln!("{popped:?} vs {intervals:?}");
        assert_eq!(popped.len(), intervals.len());

        for (expected, actual) in intervals.iter().zip(popped.iter()) {
            let upper_limit = Duration::from_secs((*expected as u64).div_ceil(3) * 3);
            assert!(
                *actual >= Duration::from_secs(*expected as u64) && *actual <= upper_limit,
                "expected {expected}-{upper_limit:?} seconds, got {actual:?}"
            );
        }
    }
}