throttle/
limit.rs

1use crate::{Error, LimitSpec, REDIS};
2use anyhow::{anyhow, Context};
3use mod_redis::{RedisConnection, Script};
4use parking_lot::Mutex;
5use std::collections::HashMap;
6use std::sync::{Arc, LazyLock};
7use std::time::{Duration, Instant, SystemTime};
8use tokio::sync::Notify;
9use uuid::Uuid;
10
/// Process-wide store of in-memory lease sets, built lazily on first access
/// and guarded by a mutex.
static MEMORY: LazyLock<Mutex<MemoryStore>> = LazyLock::new(|| Mutex::new(MemoryStore::new()));

/// Script used to acquire a lease through Redis. Its source is the `limit.lua`
/// file that sits next to this one, embedded into the binary at compile time.
static ACQUIRE_SCRIPT: LazyLock<Script> = LazyLock::new(|| Script::new(include_str!("limit.lua")));
14
/// A [`LimitSpec`] paired with the lifetime granted to each lease that is
/// acquired against it.
pub struct LimitSpecWithDuration {
    /// The underlying limit definition; this file reads its `limit` and
    /// `force_local` fields.
    pub spec: LimitSpec,
    /// Maximum lease duration for a single count
    pub duration: Duration,
}
20
/// Handle for one acquired slot of a limit.
///
/// While `armed`, dropping the handle schedules a background task that
/// releases the slot in the backend it was acquired from.
#[derive(Debug)]
pub struct LimitLease {
    /// Name of the element to release on Drop
    name: String,
    /// Identifier of this particular lease within the named element.
    uuid: Uuid,
    /// Whether dropping this handle should still release the lease.
    armed: bool,
    /// Which store the lease was acquired from.
    backend: Backend,
}
29
/// Identifies where a lease is tracked, so that release and extend
/// operations are sent to the same store that granted it.
#[derive(Debug, PartialEq, Clone, Copy)]
enum Backend {
    /// Tracked in the process-local `MEMORY` store.
    Memory,
    /// Tracked in Redis through the shared `REDIS` connection.
    Redis,
}
35
36impl LimitSpecWithDuration {
37    pub async fn acquire_lease<S: AsRef<str>>(
38        &self,
39        key: S,
40        deadline: Instant,
41    ) -> Result<LimitLease, Error> {
42        match (self.spec.force_local, REDIS.get()) {
43            (false, Some(redis)) => {
44                self.acquire_lease_redis(redis, key.as_ref(), deadline)
45                    .await
46            }
47            (true, _) | (false, None) => self.acquire_lease_memory(key.as_ref(), deadline).await,
48        }
49    }
50
51    pub async fn acquire_lease_redis(
52        &self,
53        conn: &RedisConnection,
54        key: &str,
55        deadline: Instant,
56    ) -> Result<LimitLease, Error> {
57        loop {
58            let now_ts = SystemTime::now()
59                .duration_since(SystemTime::UNIX_EPOCH)
60                .map(|d| d.as_secs_f64())
61                .unwrap_or(0.0);
62
63            let expires_ts = now_ts + self.duration.as_secs_f64();
64            let uuid = Uuid::new_v4();
65            let uuid_str = uuid.to_string();
66
67            let mut script = ACQUIRE_SCRIPT.prepare_invoke();
68            script
69                .key(key)
70                .arg(now_ts)
71                .arg(expires_ts)
72                .arg(self.spec.limit)
73                .arg(uuid_str);
74
75            match conn.invoke_script(script).await.with_context(|| {
76                format!(
77                    "error invoking redis lease acquisition script \
78                     key={key} now={now_ts} expires={expires_ts} \
79                     limit={} uuid={uuid}",
80                    self.spec.limit
81                )
82            })? {
83                mod_redis::RedisValue::Okay => {
84                    return Ok(LimitLease {
85                        name: key.to_string(),
86                        uuid,
87                        armed: true,
88                        backend: Backend::Redis,
89                    });
90                }
91                mod_redis::RedisValue::Int(next_expiration_interval) => {
92                    if Instant::now() >= deadline {
93                        return Err(Error::TooManyLeases(Duration::from_secs(
94                            next_expiration_interval as u64,
95                        )));
96                    }
97
98                    tokio::time::sleep(Duration::from_secs(3)).await;
99                }
100                mod_redis::RedisValue::Double(next_expiration_interval) => {
101                    if Instant::now() >= deadline {
102                        return Err(Error::TooManyLeases(Duration::from_secs(
103                            next_expiration_interval as u64,
104                        )));
105                    }
106
107                    tokio::time::sleep(Duration::from_secs(3)).await;
108                }
109                value => {
110                    return Err(anyhow!("acquire script succeeded but returned {value:?}").into());
111                }
112            }
113        }
114    }
115
116    pub async fn acquire_lease_memory(
117        &self,
118        key: &str,
119        deadline: Instant,
120    ) -> Result<LimitLease, Error> {
121        let uuid = Uuid::new_v4();
122
123        fn resolve_set(key: &str) -> Arc<LeaseSet> {
124            MEMORY.lock().get_or_create(key)
125        }
126
127        let set = resolve_set(key);
128
129        set.acquire(uuid, self.spec.limit, self.duration, deadline)
130            .await?;
131
132        Ok(LimitLease {
133            name: key.to_string(),
134            uuid,
135            armed: true,
136            backend: Backend::Memory,
137        })
138    }
139}
140
141impl LimitLease {
142    pub async fn release(&mut self) {
143        self.armed = false;
144        match self.backend {
145            Backend::Memory => self.release_memory().await,
146            Backend::Redis => {
147                if let Some(redis) = REDIS.get() {
148                    self.release_redis(redis).await;
149                } else {
150                    eprintln!("LimitLease::release: backend is Redis but REDIS is not set");
151                }
152            }
153        }
154    }
155
156    pub async fn extend(&self, duration: Duration) -> Result<(), Error> {
157        match self.backend {
158            Backend::Memory => self.extend_memory(duration).await,
159            Backend::Redis => {
160                if let Some(redis) = REDIS.get() {
161                    self.extend_redis(redis, duration).await
162                } else {
163                    Err(anyhow::anyhow!(
164                        "LimitLease::extend: backend is Redis but REDIS is not set"
165                    )
166                    .into())
167                }
168            }
169        }
170    }
171
172    pub fn take(&mut self) -> Self {
173        let armed = self.armed;
174        self.armed = false;
175        Self {
176            name: self.name.clone(),
177            uuid: self.uuid,
178            armed,
179            backend: self.backend,
180        }
181    }
182
183    async fn extend_memory(&self, duration: Duration) -> Result<(), Error> {
184        let store = MEMORY.lock();
185        if let Some(set) = store.get(&self.name) {
186            set.extend(self.uuid, duration)
187        } else {
188            Err(Error::NonExistentLease)
189        }
190    }
191
192    async fn extend_redis(&self, conn: &RedisConnection, duration: Duration) -> Result<(), Error> {
193        let now_ts = SystemTime::now()
194            .duration_since(SystemTime::UNIX_EPOCH)
195            .map(|d| d.as_secs_f64())
196            .unwrap_or(0.0);
197
198        let expires = now_ts + duration.as_secs_f64();
199
200        let mut cmd = mod_redis::cmd("ZADD");
201        cmd.arg(&self.name)
202            .arg("XX") // only allow updating existing
203            .arg("CH") // return number of changed entries
204            .arg(expires)
205            .arg(self.uuid.to_string());
206        let value = conn.query(cmd).await?;
207
208        if value != mod_redis::RedisValue::Int(1) {
209            return Err(anyhow!("Failed to extend lease").into());
210        }
211
212        Ok(())
213    }
214
215    async fn release_memory(&self) {
216        let store = MEMORY.lock();
217        if let Some(set) = store.get(&self.name) {
218            set.release(self.uuid);
219        }
220    }
221
222    async fn release_redis(&mut self, conn: &RedisConnection) {
223        let mut cmd = mod_redis::cmd("ZREM");
224        cmd.arg(&self.name).arg(self.uuid.to_string());
225        conn.query(cmd).await.ok();
226    }
227}
228
229impl Drop for LimitLease {
230    fn drop(&mut self) {
231        if self.armed {
232            self.armed = false;
233            let mut deferred = Self {
234                armed: false,
235                name: self.name.clone(),
236                uuid: self.uuid,
237                backend: self.backend,
238            };
239            tokio::task::Builder::new()
240                .name("LimitLeaseDropper")
241                .spawn(async move {
242                    deferred.release().await;
243                })
244                .ok();
245        }
246    }
247}
248
/// The in-memory leases held under a single name.
struct LeaseSet {
    /// Lease id mapped to the instant at which that lease expires.
    members: Mutex<HashMap<Uuid, Instant>>,
    /// Signalled on release so that a waiting acquirer can retry.
    notify: Notify,
}
253
254impl LeaseSet {
255    fn new() -> Self {
256        Self {
257            members: Mutex::new(HashMap::new()),
258            notify: Notify::new(),
259        }
260    }
261
262    fn acquire_immediate(&self, uuid: Uuid, limit: u64, duration: Duration) -> bool {
263        let mut members = self.members.lock();
264        let now = Instant::now();
265        members.retain(|_k, expiry| *expiry > now);
266
267        if members.len() as u64 + 1 <= limit {
268            members.insert(uuid, now + duration);
269            return true;
270        }
271
272        false
273    }
274
275    async fn acquire(
276        &self,
277        uuid: Uuid,
278        limit: u64,
279        duration: Duration,
280        deadline: Instant,
281    ) -> Result<(), Error> {
282        loop {
283            if self.acquire_immediate(uuid, limit, duration) {
284                return Ok(());
285            }
286
287            match tokio::time::timeout_at(deadline.into(), self.notify.notified()).await {
288                Err(_) => {
289                    if self.acquire_immediate(uuid, limit, duration) {
290                        return Ok(());
291                    }
292                    let min_expiration = self
293                        .members
294                        .lock()
295                        .values()
296                        .cloned()
297                        .min()
298                        .expect("some elements");
299                    return Err(Error::TooManyLeases(min_expiration - Instant::now()));
300                }
301                Ok(_) => {
302                    // Try to acquire again
303                    continue;
304                }
305            }
306        }
307    }
308
309    fn extend(&self, uuid: Uuid, duration: Duration) -> Result<(), Error> {
310        match self.members.lock().get_mut(&uuid) {
311            Some(entry) => {
312                *entry = Instant::now() + duration;
313                Ok(())
314            }
315            None => Err(Error::NonExistentLease),
316        }
317    }
318
319    fn release(&self, uuid: Uuid) {
320        let mut members = self.members.lock();
321        members.remove(&uuid);
322        self.notify.notify_one();
323    }
324}
325
/// Registry of in-memory lease sets, looked up by name.
struct MemoryStore {
    /// Lease sets keyed by name; each set is shared through an `Arc`.
    sets: HashMap<String, Arc<LeaseSet>>,
}
329
330impl MemoryStore {
331    fn new() -> Self {
332        Self {
333            sets: HashMap::new(),
334        }
335    }
336
337    fn get(&self, name: &str) -> Option<Arc<LeaseSet>> {
338        self.sets.get(name).map(Arc::clone)
339    }
340
341    fn get_or_create(&mut self, name: &str) -> Arc<LeaseSet> {
342        self.sets
343            .entry(name.to_string())
344            .or_insert_with(|| Arc::new(LeaseSet::new()))
345            .clone()
346    }
347}
348
#[cfg(test)]
mod test {
    use super::*;
    use mod_redis::test::{RedisCluster, RedisServer};

    /// Reports whether a non-waiting in-memory acquisition for `key` is refused.
    async fn memory_is_full(limiter: &LimitSpecWithDuration, key: &str) -> bool {
        limiter
            .acquire_lease_memory(key, Instant::now())
            .await
            .is_err()
    }

    /// Reports whether a non-waiting Redis acquisition for `key` is refused.
    async fn redis_is_full(
        limiter: &LimitSpecWithDuration,
        conn: &RedisConnection,
        key: &str,
    ) -> bool {
        limiter
            .acquire_lease_redis(conn, key, Instant::now())
            .await
            .is_err()
    }

    #[tokio::test]
    async fn test_memory() {
        let limiter = LimitSpecWithDuration {
            spec: LimitSpec::new(2),
            duration: Duration::from_secs(2),
        };
        let key = format!("test_memory-{}", Uuid::new_v4());

        let lease1 = limiter
            .acquire_lease_memory(&key, Instant::now())
            .await
            .unwrap();
        eprintln!("lease1: {lease1:?}");
        let mut lease2 = limiter
            .acquire_lease_memory(&key, Instant::now())
            .await
            .unwrap();
        eprintln!("lease2: {lease2:?}");

        // Both slots are taken, so a third lease is refused
        assert!(memory_is_full(&limiter, &key).await);

        // Giving one slot back makes room for a third lease
        lease2.release().await;
        let _lease3 = limiter
            .acquire_lease_memory(&key, Instant::now())
            .await
            .unwrap();

        // ...after which the limit is exhausted again
        assert!(memory_is_full(&limiter, &key).await);

        let start = Instant::now();

        // With a generous deadline, a slot opens up once leases expire
        let _lease4 = limiter
            .acquire_lease_memory(&key, start + limiter.duration * 2)
            .await
            .unwrap();

        assert!(
            start.elapsed() > limiter.duration,
            "elapsed is {:?}",
            start.elapsed()
        );
    }

    #[tokio::test]
    async fn test_redis() {
        if !RedisServer::is_available() {
            return;
        }
        let server = RedisServer::spawn("").await.unwrap();
        let conn = server.connection().await.unwrap();

        let limiter = LimitSpecWithDuration {
            spec: LimitSpec::new(2),
            duration: Duration::from_secs(2),
        };
        let key = format!("test_redis-{}", Uuid::new_v4());

        let mut lease1 = limiter
            .acquire_lease_redis(&conn, &key, Instant::now())
            .await
            .unwrap();
        eprintln!("lease1: {lease1:?}");
        let mut lease2 = limiter
            .acquire_lease_redis(&conn, &key, Instant::now())
            .await
            .unwrap();
        eprintln!("lease2: {lease2:?}");

        // Both slots are taken, so a third lease is refused
        assert!(redis_is_full(&limiter, &conn, &key).await);

        // Giving one slot back makes room for a third lease
        lease2.release_redis(&conn).await;
        let mut lease3 = limiter
            .acquire_lease_redis(&conn, &key, Instant::now())
            .await
            .unwrap();

        // ...after which the limit is exhausted again
        assert!(redis_is_full(&limiter, &conn, &key).await);

        let start = Instant::now();

        // With a generous deadline, a slot opens up once leases expire
        let mut lease4 = limiter
            .acquire_lease_redis(&conn, &key, start + limiter.duration * 2)
            .await
            .unwrap();

        assert!(
            start.elapsed() > limiter.duration,
            "elapsed is {:?}",
            start.elapsed()
        );

        lease1.release_redis(&conn).await;
        lease3.release_redis(&conn).await;
        lease4.release_redis(&conn).await;
    }

    #[tokio::test]
    async fn test_redis_cluster() {
        if !RedisCluster::is_available().await {
            return;
        }
        let cluster = RedisCluster::spawn().await.unwrap();
        let conn = cluster.connection().await.unwrap();

        let limiter = LimitSpecWithDuration {
            spec: LimitSpec::new(2),
            duration: Duration::from_secs(2),
        };
        let key = format!("test_redis-{}", Uuid::new_v4());

        let mut lease1 = limiter
            .acquire_lease_redis(&conn, &key, Instant::now())
            .await
            .unwrap();
        eprintln!("lease1: {lease1:?}");
        let mut lease2 = limiter
            .acquire_lease_redis(&conn, &key, Instant::now())
            .await
            .unwrap();
        eprintln!("lease2: {lease2:?}");

        // Both slots are taken, so a third lease is refused
        assert!(redis_is_full(&limiter, &conn, &key).await);

        // Giving one slot back makes room for a third lease
        lease2.release_redis(&conn).await;
        let mut lease3 = limiter
            .acquire_lease_redis(&conn, &key, Instant::now())
            .await
            .unwrap();

        // ...after which the limit is exhausted again
        assert!(redis_is_full(&limiter, &conn, &key).await);

        // Let the outstanding leases run out
        tokio::time::sleep(limiter.duration * 2).await;

        let mut lease4 = limiter
            .acquire_lease_redis(&conn, &key, Instant::now())
            .await
            .unwrap();

        lease1.release_redis(&conn).await;
        lease3.release_redis(&conn).await;
        lease4.release_redis(&conn).await;
    }

    #[tokio::test]
    async fn test_memory_extension() {
        let limiter = LimitSpecWithDuration {
            spec: LimitSpec::new(1),
            duration: Duration::from_secs(2),
        };
        let key = format!("test_redis-{}", Uuid::new_v4());

        let lease1 = limiter
            .acquire_lease_memory(&key, Instant::now())
            .await
            .unwrap();
        eprintln!("lease1: {lease1:?}");

        // The single slot is taken, so a second lease is refused
        assert!(memory_is_full(&limiter, &key).await);

        tokio::time::sleep(Duration::from_secs(1)).await;

        lease1.extend_memory(Duration::from_secs(6)).await.unwrap();

        // Sleep past the point where the original lease would have expired
        tokio::time::sleep(limiter.duration * 2).await;

        // The extension keeps the slot occupied
        assert!(memory_is_full(&limiter, &key).await);

        // Sleep past the end of the extension
        tokio::time::sleep(limiter.duration * 2).await;

        let _lease2 = limiter
            .acquire_lease_memory(&key, Instant::now())
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_redis_extension() {
        if !RedisServer::is_available() {
            return;
        }
        let server = RedisServer::spawn("").await.unwrap();
        let conn = server.connection().await.unwrap();

        let limiter = LimitSpecWithDuration {
            spec: LimitSpec::new(1),
            duration: Duration::from_secs(2),
        };
        let key = format!("test_redis-{}", Uuid::new_v4());

        let mut lease1 = limiter
            .acquire_lease_redis(&conn, &key, Instant::now())
            .await
            .unwrap();
        eprintln!("lease1: {lease1:?}");

        // The single slot is taken, so a second lease is refused
        assert!(redis_is_full(&limiter, &conn, &key).await);

        tokio::time::sleep(Duration::from_secs(1)).await;

        lease1
            .extend_redis(&conn, Duration::from_secs(6))
            .await
            .unwrap();

        // Sleep past the point where the original lease would have expired
        tokio::time::sleep(limiter.duration * 2).await;

        // The extension keeps the slot occupied
        assert!(redis_is_full(&limiter, &conn, &key).await);

        // Sleep past the end of the extension
        tokio::time::sleep(limiter.duration * 2).await;

        let mut lease2 = limiter
            .acquire_lease_redis(&conn, &key, Instant::now())
            .await
            .unwrap();

        lease1.release_redis(&conn).await;
        lease2.release_redis(&conn).await;
    }
}