1use crate::{Error, LimitSpec, REDIS};
2use anyhow::{anyhow, Context};
3use mod_redis::{RedisConnection, Script};
4use parking_lot::Mutex;
5use std::collections::HashMap;
6use std::sync::{Arc, LazyLock};
7use std::time::{Duration, Instant, SystemTime};
8use tokio::sync::Notify;
9use uuid::Uuid;
10
11static MEMORY: LazyLock<Mutex<MemoryStore>> = LazyLock::new(|| Mutex::new(MemoryStore::new()));
12
13static ACQUIRE_SCRIPT: LazyLock<Script> = LazyLock::new(|| Script::new(include_str!("limit.lua")));
14
/// A limit specification paired with the lifetime of each lease granted
/// under it.
pub struct LimitSpecWithDuration {
    /// The limit itself: the maximum number of concurrent leases (`limit`)
    /// and whether the in-memory backend must be used (`force_local`).
    pub spec: LimitSpec,
    /// How long a freshly acquired lease stays valid before it expires.
    pub duration: Duration,
}
20
/// A lease held against a named limit.
///
/// While `armed`, dropping the lease schedules its release on the backend
/// that granted it.
#[derive(Debug)]
pub struct LimitLease {
    /// Key the lease was acquired under; identifies the lease set.
    name: String,
    /// Identifier of this particular lease within the set.
    uuid: Uuid,
    /// Whether `Drop` still has to release the lease.
    armed: bool,
    /// Backend that granted the lease and must be used to extend/release it.
    backend: Backend,
}
29
/// Identifies which store granted a lease, so that extend and release are
/// routed back to the same place.
#[derive(Clone, Copy, Debug, PartialEq)]
enum Backend {
    /// Lease is tracked in this process' `MEMORY` store.
    Memory,
    /// Lease is tracked in redis.
    Redis,
}
35
36impl LimitSpecWithDuration {
37 pub async fn acquire_lease<S: AsRef<str>>(
38 &self,
39 key: S,
40 deadline: Instant,
41 ) -> Result<LimitLease, Error> {
42 match (self.spec.force_local, REDIS.get()) {
43 (false, Some(redis)) => {
44 self.acquire_lease_redis(redis, key.as_ref(), deadline)
45 .await
46 }
47 (true, _) | (false, None) => self.acquire_lease_memory(key.as_ref(), deadline).await,
48 }
49 }
50
51 pub async fn acquire_lease_redis(
52 &self,
53 conn: &RedisConnection,
54 key: &str,
55 deadline: Instant,
56 ) -> Result<LimitLease, Error> {
57 loop {
58 let now_ts = SystemTime::now()
59 .duration_since(SystemTime::UNIX_EPOCH)
60 .map(|d| d.as_secs_f64())
61 .unwrap_or(0.0);
62
63 let expires_ts = now_ts + self.duration.as_secs_f64();
64 let uuid = Uuid::new_v4();
65 let uuid_str = uuid.to_string();
66
67 let mut script = ACQUIRE_SCRIPT.prepare_invoke();
68 script
69 .key(key)
70 .arg(now_ts)
71 .arg(expires_ts)
72 .arg(self.spec.limit)
73 .arg(uuid_str);
74
75 match conn.invoke_script(script).await.with_context(|| {
76 format!(
77 "error invoking redis lease acquisition script \
78 key={key} now={now_ts} expires={expires_ts} \
79 limit={} uuid={uuid}",
80 self.spec.limit
81 )
82 })? {
83 mod_redis::RedisValue::Okay => {
84 return Ok(LimitLease {
85 name: key.to_string(),
86 uuid,
87 armed: true,
88 backend: Backend::Redis,
89 });
90 }
91 mod_redis::RedisValue::Int(next_expiration_interval) => {
92 if Instant::now() >= deadline {
93 return Err(Error::TooManyLeases(Duration::from_secs(
94 next_expiration_interval as u64,
95 )));
96 }
97
98 tokio::time::sleep(Duration::from_secs(3)).await;
99 }
100 mod_redis::RedisValue::Double(next_expiration_interval) => {
101 if Instant::now() >= deadline {
102 return Err(Error::TooManyLeases(Duration::from_secs(
103 next_expiration_interval as u64,
104 )));
105 }
106
107 tokio::time::sleep(Duration::from_secs(3)).await;
108 }
109 value => {
110 return Err(anyhow!("acquire script succeeded but returned {value:?}").into());
111 }
112 }
113 }
114 }
115
116 pub async fn acquire_lease_memory(
117 &self,
118 key: &str,
119 deadline: Instant,
120 ) -> Result<LimitLease, Error> {
121 let uuid = Uuid::new_v4();
122
123 fn resolve_set(key: &str) -> Arc<LeaseSet> {
124 MEMORY.lock().get_or_create(key)
125 }
126
127 let set = resolve_set(key);
128
129 set.acquire(uuid, self.spec.limit, self.duration, deadline)
130 .await?;
131
132 Ok(LimitLease {
133 name: key.to_string(),
134 uuid,
135 armed: true,
136 backend: Backend::Memory,
137 })
138 }
139}
140
141impl LimitLease {
142 pub async fn release(&mut self) {
143 self.armed = false;
144 match self.backend {
145 Backend::Memory => self.release_memory().await,
146 Backend::Redis => {
147 if let Some(redis) = REDIS.get() {
148 self.release_redis(redis).await;
149 } else {
150 eprintln!("LimitLease::release: backend is Redis but REDIS is not set");
151 }
152 }
153 }
154 }
155
156 pub async fn extend(&self, duration: Duration) -> Result<(), Error> {
157 match self.backend {
158 Backend::Memory => self.extend_memory(duration).await,
159 Backend::Redis => {
160 if let Some(redis) = REDIS.get() {
161 self.extend_redis(redis, duration).await
162 } else {
163 Err(anyhow::anyhow!(
164 "LimitLease::extend: backend is Redis but REDIS is not set"
165 )
166 .into())
167 }
168 }
169 }
170 }
171
172 pub fn take(&mut self) -> Self {
173 let armed = self.armed;
174 self.armed = false;
175 Self {
176 name: self.name.clone(),
177 uuid: self.uuid,
178 armed,
179 backend: self.backend,
180 }
181 }
182
183 async fn extend_memory(&self, duration: Duration) -> Result<(), Error> {
184 let store = MEMORY.lock();
185 if let Some(set) = store.get(&self.name) {
186 set.extend(self.uuid, duration)
187 } else {
188 Err(Error::NonExistentLease)
189 }
190 }
191
192 async fn extend_redis(&self, conn: &RedisConnection, duration: Duration) -> Result<(), Error> {
193 let now_ts = SystemTime::now()
194 .duration_since(SystemTime::UNIX_EPOCH)
195 .map(|d| d.as_secs_f64())
196 .unwrap_or(0.0);
197
198 let expires = now_ts + duration.as_secs_f64();
199
200 let mut cmd = mod_redis::cmd("ZADD");
201 cmd.arg(&self.name)
202 .arg("XX") .arg("CH") .arg(expires)
205 .arg(self.uuid.to_string());
206 let value = conn.query(cmd).await?;
207
208 if value != mod_redis::RedisValue::Int(1) {
209 return Err(anyhow!("Failed to extend lease").into());
210 }
211
212 Ok(())
213 }
214
215 async fn release_memory(&self) {
216 let store = MEMORY.lock();
217 if let Some(set) = store.get(&self.name) {
218 set.release(self.uuid);
219 }
220 }
221
222 async fn release_redis(&mut self, conn: &RedisConnection) {
223 let mut cmd = mod_redis::cmd("ZREM");
224 cmd.arg(&self.name).arg(self.uuid.to_string());
225 conn.query(cmd).await.ok();
226 }
227}
228
229impl Drop for LimitLease {
230 fn drop(&mut self) {
231 if self.armed {
232 self.armed = false;
233 let mut deferred = Self {
234 armed: false,
235 name: self.name.clone(),
236 uuid: self.uuid,
237 backend: self.backend,
238 };
239 tokio::task::Builder::new()
240 .name("LimitLeaseDropper")
241 .spawn(async move {
242 deferred.release().await;
243 })
244 .ok();
245 }
246 }
247}
248
/// The leases currently held under one named limit in the memory backend.
struct LeaseSet {
    /// Expiry instant of every lease, keyed by the lease's uuid.
    members: Mutex<HashMap<Uuid, Instant>>,
    /// Signalled on release so that a task waiting in `acquire` retries.
    notify: Notify,
}
253
254impl LeaseSet {
255 fn new() -> Self {
256 Self {
257 members: Mutex::new(HashMap::new()),
258 notify: Notify::new(),
259 }
260 }
261
262 fn acquire_immediate(&self, uuid: Uuid, limit: u64, duration: Duration) -> bool {
263 let mut members = self.members.lock();
264 let now = Instant::now();
265 members.retain(|_k, expiry| *expiry > now);
266
267 if members.len() as u64 + 1 <= limit {
268 members.insert(uuid, now + duration);
269 return true;
270 }
271
272 false
273 }
274
275 async fn acquire(
276 &self,
277 uuid: Uuid,
278 limit: u64,
279 duration: Duration,
280 deadline: Instant,
281 ) -> Result<(), Error> {
282 loop {
283 if self.acquire_immediate(uuid, limit, duration) {
284 return Ok(());
285 }
286
287 match tokio::time::timeout_at(deadline.into(), self.notify.notified()).await {
288 Err(_) => {
289 if self.acquire_immediate(uuid, limit, duration) {
290 return Ok(());
291 }
292 let min_expiration = self
293 .members
294 .lock()
295 .values()
296 .cloned()
297 .min()
298 .expect("some elements");
299 return Err(Error::TooManyLeases(min_expiration - Instant::now()));
300 }
301 Ok(_) => {
302 continue;
304 }
305 }
306 }
307 }
308
309 fn extend(&self, uuid: Uuid, duration: Duration) -> Result<(), Error> {
310 match self.members.lock().get_mut(&uuid) {
311 Some(entry) => {
312 *entry = Instant::now() + duration;
313 Ok(())
314 }
315 None => Err(Error::NonExistentLease),
316 }
317 }
318
319 fn release(&self, uuid: Uuid) {
320 let mut members = self.members.lock();
321 members.remove(&uuid);
322 self.notify.notify_one();
323 }
324}
325
/// Registry of in-memory lease sets.
struct MemoryStore {
    /// Lease set for each limit name; sets are shared via `Arc` so callers
    /// can use them without holding the registry.
    sets: HashMap<String, Arc<LeaseSet>>,
}
329
330impl MemoryStore {
331 fn new() -> Self {
332 Self {
333 sets: HashMap::new(),
334 }
335 }
336
337 fn get(&self, name: &str) -> Option<Arc<LeaseSet>> {
338 self.sets.get(name).map(Arc::clone)
339 }
340
341 fn get_or_create(&mut self, name: &str) -> Arc<LeaseSet> {
342 self.sets
343 .entry(name.to_string())
344 .or_insert_with(|| Arc::new(LeaseSet::new()))
345 .clone()
346 }
347}
348
#[cfg(test)]
mod test {
    use super::*;
    use mod_redis::test::{RedisCluster, RedisServer};

    /// Asserts that an immediate in-memory acquisition for `key` is refused.
    async fn assert_memory_denied(limit: &LimitSpecWithDuration, key: &str) {
        let attempt = limit.acquire_lease_memory(key, Instant::now()).await;
        assert!(attempt.is_err());
    }

    /// Asserts that an immediate redis acquisition for `key` is refused.
    async fn assert_redis_denied(
        limit: &LimitSpecWithDuration,
        conn: &RedisConnection,
        key: &str,
    ) {
        let attempt = limit.acquire_lease_redis(conn, key, Instant::now()).await;
        assert!(attempt.is_err());
    }

    /// Two leases fit, a third is refused until one is released, and a
    /// waiting acquisition succeeds once the held leases have expired.
    #[tokio::test]
    async fn test_memory() {
        let limit = LimitSpecWithDuration {
            spec: LimitSpec::new(2),
            duration: Duration::from_secs(2),
        };
        let key = format!("test_memory-{}", Uuid::new_v4());

        let lease1 = limit
            .acquire_lease_memory(&key, Instant::now())
            .await
            .unwrap();
        eprintln!("lease1: {lease1:?}");
        let mut lease2 = limit
            .acquire_lease_memory(&key, Instant::now())
            .await
            .unwrap();
        eprintln!("lease2: {lease2:?}");
        // Both slots are taken now.
        assert_memory_denied(&limit, &key).await;

        // Freeing one slot makes room for exactly one more lease.
        lease2.release().await;
        let _lease3 = limit
            .acquire_lease_memory(&key, Instant::now())
            .await
            .unwrap();
        assert_memory_denied(&limit, &key).await;

        // With a generous deadline the acquisition waits for an expiry.
        let start = Instant::now();
        let deadline = start + limit.duration + limit.duration;
        let _lease4 = limit.acquire_lease_memory(&key, deadline).await.unwrap();

        assert!(
            start.elapsed() > limit.duration,
            "elapsed is {:?}",
            start.elapsed()
        );
    }

    /// Same scenario as `test_memory`, against a standalone redis server.
    #[tokio::test]
    async fn test_redis() {
        if !RedisServer::is_available() {
            return;
        }
        let redis = RedisServer::spawn("").await.unwrap();
        let conn = redis.connection().await.unwrap();

        let limit = LimitSpecWithDuration {
            spec: LimitSpec::new(2),
            duration: Duration::from_secs(2),
        };
        let key = format!("test_redis-{}", Uuid::new_v4());

        let mut lease1 = limit
            .acquire_lease_redis(&conn, &key, Instant::now())
            .await
            .unwrap();
        eprintln!("lease1: {lease1:?}");
        let mut lease2 = limit
            .acquire_lease_redis(&conn, &key, Instant::now())
            .await
            .unwrap();
        eprintln!("lease2: {lease2:?}");
        assert_redis_denied(&limit, &conn, &key).await;

        lease2.release_redis(&conn).await;
        let mut lease3 = limit
            .acquire_lease_redis(&conn, &key, Instant::now())
            .await
            .unwrap();
        assert_redis_denied(&limit, &conn, &key).await;

        let start = Instant::now();
        let deadline = start + limit.duration + limit.duration;
        let mut lease4 = limit
            .acquire_lease_redis(&conn, &key, deadline)
            .await
            .unwrap();

        assert!(
            start.elapsed() > limit.duration,
            "elapsed is {:?}",
            start.elapsed()
        );

        lease1.release_redis(&conn).await;
        lease3.release_redis(&conn).await;
        lease4.release_redis(&conn).await;
    }

    /// Limit enforcement and expiry against a redis cluster.
    #[tokio::test]
    async fn test_redis_cluster() {
        if !RedisCluster::is_available().await {
            return;
        }
        let redis = RedisCluster::spawn().await.unwrap();
        let conn = redis.connection().await.unwrap();

        let limit = LimitSpecWithDuration {
            spec: LimitSpec::new(2),
            duration: Duration::from_secs(2),
        };
        let key = format!("test_redis-{}", Uuid::new_v4());

        let mut lease1 = limit
            .acquire_lease_redis(&conn, &key, Instant::now())
            .await
            .unwrap();
        eprintln!("lease1: {lease1:?}");
        let mut lease2 = limit
            .acquire_lease_redis(&conn, &key, Instant::now())
            .await
            .unwrap();
        eprintln!("lease2: {lease2:?}");
        assert_redis_denied(&limit, &conn, &key).await;

        lease2.release_redis(&conn).await;
        let mut lease3 = limit
            .acquire_lease_redis(&conn, &key, Instant::now())
            .await
            .unwrap();
        assert_redis_denied(&limit, &conn, &key).await;

        // Let every outstanding lease expire before trying again.
        tokio::time::sleep(limit.duration + limit.duration).await;

        let mut lease4 = limit
            .acquire_lease_redis(&conn, &key, Instant::now())
            .await
            .unwrap();

        lease1.release_redis(&conn).await;
        lease3.release_redis(&conn).await;
        lease4.release_redis(&conn).await;
    }

    /// An extended in-memory lease keeps blocking past its original expiry
    /// and frees the slot once the extension has run out.
    #[tokio::test]
    async fn test_memory_extension() {
        let limit = LimitSpecWithDuration {
            spec: LimitSpec::new(1),
            duration: Duration::from_secs(2),
        };
        let key = format!("test_redis-{}", Uuid::new_v4());

        let lease1 = limit
            .acquire_lease_memory(&key, Instant::now())
            .await
            .unwrap();
        eprintln!("lease1: {lease1:?}");
        assert_memory_denied(&limit, &key).await;

        tokio::time::sleep(Duration::from_secs(1)).await;

        lease1.extend_memory(Duration::from_secs(6)).await.unwrap();

        // Past the original expiry, but still inside the extension.
        tokio::time::sleep(limit.duration + limit.duration).await;
        assert_memory_denied(&limit, &key).await;

        // Past the extension as well.
        tokio::time::sleep(limit.duration + limit.duration).await;
        let _lease2 = limit
            .acquire_lease_memory(&key, Instant::now())
            .await
            .unwrap();
    }

    /// Same scenario as `test_memory_extension`, against redis.
    #[tokio::test]
    async fn test_redis_extension() {
        if !RedisServer::is_available() {
            return;
        }
        let redis = RedisServer::spawn("").await.unwrap();
        let conn = redis.connection().await.unwrap();

        let limit = LimitSpecWithDuration {
            spec: LimitSpec::new(1),
            duration: Duration::from_secs(2),
        };
        let key = format!("test_redis-{}", Uuid::new_v4());

        let mut lease1 = limit
            .acquire_lease_redis(&conn, &key, Instant::now())
            .await
            .unwrap();
        eprintln!("lease1: {lease1:?}");
        assert_redis_denied(&limit, &conn, &key).await;

        tokio::time::sleep(Duration::from_secs(1)).await;

        lease1
            .extend_redis(&conn, Duration::from_secs(6))
            .await
            .unwrap();

        // Past the original expiry, but still inside the extension.
        tokio::time::sleep(limit.duration + limit.duration).await;
        assert_redis_denied(&limit, &conn, &key).await;

        // Past the extension as well.
        tokio::time::sleep(limit.duration + limit.duration).await;
        let mut lease2 = limit
            .acquire_lease_redis(&conn, &key, Instant::now())
            .await
            .unwrap();

        lease1.release_redis(&conn).await;
        lease2.release_redis(&conn).await;
    }
}