1#[cfg(feature = "redis")]
6use mod_redis::RedisError;
7use serde::{Deserialize, Deserializer, Serialize};
8use std::convert::TryFrom;
9use std::time::Duration;
10use thiserror::Error;
11
12#[cfg(feature = "redis")]
13pub mod limit;
14#[cfg(feature = "redis")]
15mod throttle;
16
#[cfg(feature = "redis")]
mod redis {
    use super::*;
    use mod_redis::{Cmd, RedisConnection, RedisValue};
    use std::ops::Deref;
    use std::sync::OnceLock;

    /// A redis connection together with the server capabilities that were
    /// probed when it was configured.
    #[derive(Debug)]
    pub(crate) struct RedisContext {
        /// Connection used to issue throttle commands.
        pub(crate) connection: RedisConnection,
        /// `true` when `COMMAND INFO CL.THROTTLE` returned a sequence containing
        /// at least one non-nil entry, i.e. the server knows `CL.THROTTLE`.
        pub(crate) has_redis_cell: bool,
    }

    impl RedisContext {
        /// Wraps `connection`, probing the server for the `CL.THROTTLE` command.
        ///
        /// # Errors
        /// Propagates any failure from issuing the `COMMAND INFO` query.
        pub async fn try_from(connection: RedisConnection) -> anyhow::Result<Self> {
            let mut probe = Cmd::new();
            probe.arg("COMMAND").arg("INFO").arg("CL.THROTTLE");

            let reply = connection.query(probe).await?;
            // A non-sequence reply, or a sequence made up only of nils, means
            // the command is not available on this server.
            let has_redis_cell = reply
                .as_sequence()
                .map(|entries| entries.iter().any(|entry| entry != &RedisValue::Nil))
                .unwrap_or(false);

            Ok(Self {
                connection,
                has_redis_cell,
            })
        }
    }

    impl Deref for RedisContext {
        type Target = RedisConnection;

        fn deref(&self) -> &RedisConnection {
            &self.connection
        }
    }

    /// Process-wide redis context shared by the throttles; it can be set only once.
    pub(crate) static REDIS: OnceLock<RedisContext> = OnceLock::new();

    /// Installs `conn` as the shared redis connection used by throttles.
    ///
    /// # Errors
    /// Fails if probing the connection fails, or if redis was already configured.
    pub async fn use_redis(conn: RedisConnection) -> Result<(), Error> {
        let context = RedisContext::try_from(conn).await?;
        if REDIS.set(context).is_err() {
            return Err(Error::Generic(
                "redis already configured for throttles".to_string(),
            ));
        }
        Ok(())
    }
}
63
64#[cfg(feature = "redis")]
65pub use redis::use_redis;
66#[cfg(feature = "redis")]
67pub(crate) use redis::REDIS;
68
/// Errors produced by the throttle and limit APIs.
#[derive(Error, Debug)]
pub enum Error {
    /// Free-form error carrying only a message.
    #[error("{0}")]
    Generic(String),
    /// Wraps an `anyhow::Error`, e.g. a failure while probing the redis
    /// connection in `use_redis`.
    #[error("{0}")]
    AnyHow(#[from] anyhow::Error),
    /// Wraps an error from the redis client; only exists with the `redis` feature.
    #[cfg(feature = "redis")]
    #[error("{0}")]
    Redis(#[from] RedisError),
    /// Too many leases are outstanding; the payload is how long to wait before
    /// trying again.
    #[error("TooManyLeases, try again in {0:?}")]
    TooManyLeases(Duration),
    /// A lease that does not exist was referenced (per the variant name; it is
    /// constructed outside this file).
    #[error("NonExistentLease")]
    NonExistentLease,
}
83
/// A rate specification: `limit` units per `period` seconds, optionally with an
/// explicit burst size.
///
/// Serde uses the string form (e.g. `"100/h"` or `"local:50/d,max_burst=1"`),
/// converting through `From<ThrottleSpec> for String` and
/// `TryFrom<String> for ThrottleSpec`.
#[derive(Eq, PartialEq, Clone, Copy, Serialize, Deserialize, Hash)]
#[serde(try_from = "String", into = "String")]
pub struct ThrottleSpec {
    /// Units permitted per `period`; parsing rejects 0.
    pub limit: u64,
    /// Window length in seconds.
    pub period: u64,
    /// Maximum burst; when `None`, `limit` is used as the burst.
    pub max_burst: Option<u64>,
    /// Set by the `local:` prefix and passed through to the throttle backend;
    /// presumably forces a process-local throttle instead of redis — confirm in
    /// the `throttle` module.
    pub force_local: bool,
}
101
102#[cfg(feature = "redis")]
103impl ThrottleSpec {
104 pub async fn throttle<S: AsRef<str>>(&self, key: S) -> Result<ThrottleResult, Error> {
105 self.throttle_quantity(key, 1).await
106 }
107
108 pub async fn throttle_quantity<S: AsRef<str>>(
109 &self,
110 key: S,
111 quantity: u64,
112 ) -> Result<ThrottleResult, Error> {
113 let key = key.as_ref();
114 let limit = self.limit;
115 let period = self.period;
116 let max_burst = self.max_burst.unwrap_or(limit);
117 let key = format!("{key}:{limit}:{max_burst}:{period}");
118 throttle::throttle(
119 &key,
120 limit,
121 Duration::from_secs(period),
122 max_burst,
123 Some(quantity),
124 self.force_local,
125 )
126 .await
127 }
128
129 pub fn burst(&self) -> u64 {
131 self.max_burst.unwrap_or(self.limit)
132 }
133
134 pub fn interval(&self) -> Duration {
136 Duration::from_secs_f64(self.period as f64 / self.limit.max(1) as f64)
137 }
138
139 pub fn as_local(&self) -> Self {
140 let mut copy = self.clone();
141 copy.force_local = true;
142 copy
143 }
144}
145
146impl std::fmt::Debug for ThrottleSpec {
147 fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
148 write!(fmt, "{}", self.as_string())
149 }
150}
151
152impl std::fmt::Display for ThrottleSpec {
153 fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
154 write!(fmt, "{}", self.as_string())
155 }
156}
157
158impl ThrottleSpec {
159 pub fn as_string(&self) -> String {
160 let mut period_scale = None;
161 let period = match self.period {
162 86400 => "d",
163 3600 => "h",
164 60 => "m",
165 1 => "s",
166 other => {
167 period_scale.replace(other.to_string());
168 "s"
169 }
170 };
171
172 let burst = match self.max_burst {
173 Some(b) => format!(",max_burst={b}"),
174 None => String::new(),
175 };
176
177 format!(
178 "{}{}/{}{period}{burst}",
179 if self.force_local { "local:" } else { "" },
180 self.limit,
181 match &period_scale {
182 Some(scale) => scale.as_str(),
183 None => "",
184 }
185 )
186 }
187}
188
189impl From<ThrottleSpec> for String {
190 fn from(spec: ThrottleSpec) -> String {
191 spec.as_string()
192 }
193}
194
195impl TryFrom<String> for ThrottleSpec {
196 type Error = String;
197 fn try_from(s: String) -> Result<Self, String> {
198 Self::try_from(s.as_str())
199 }
200}
201
/// Splits an optional leading decimal multiplier off a period unit.
///
/// Returns `(multiplier, unit)`. The multiplier defaults to 1 when `s` does not
/// start with a digit: `"m"` -> `(1, "m")`, `"12s"` -> `(12, "s")`.
///
/// # Errors
/// * `"number too big"` if the leading digits overflow `u64`;
/// * `"invalid period quantity {s}"` if nothing follows the digits (this
///   includes an empty `s`).
fn opt_digit_prefix(s: &str) -> Result<(u64, &str), String> {
    // Every char before `split` is an ASCII digit (one byte each), so `split`
    // is always a valid char boundary.
    let split = s
        .find(|c: char| !c.is_ascii_digit())
        .unwrap_or(s.len());
    let (digits, unit) = s.split_at(split);

    // Overflow is reported before a missing unit, matching the order in which a
    // left-to-right scan meets the two problems.
    let multiplier = if digits.is_empty() {
        1
    } else {
        digits
            .bytes()
            .try_fold(0u64, |acc, b| {
                acc.checked_mul(10)
                    .and_then(|acc| acc.checked_add(u64::from(b - b'0')))
            })
            .ok_or_else(|| "number too big".to_string())?
    };

    if unit.is_empty() {
        return Err(format!("invalid period quantity {s}"));
    }
    Ok((multiplier, unit))
}
231
/// Parses an unsigned integer that may contain `_` or `,` digit separators,
/// e.g. `"1_000"` or `"1,000"`.
///
/// # Errors
/// Returns `"invalid limit '<input>': <reason>"` when the text left after
/// removing separators is not a valid `u64`.
fn parse_separated_number(limit: &str) -> Result<u64, String> {
    let mut digits = String::with_capacity(limit.len());
    for ch in limit.chars() {
        if ch != '_' && ch != ',' {
            digits.push(ch);
        }
    }

    match digits.parse::<u64>() {
        Ok(value) => Ok(value),
        Err(err) => Err(format!("invalid limit '{limit}': {err:#}")),
    }
}
246
247impl TryFrom<&str> for ThrottleSpec {
248 type Error = String;
249 fn try_from(s: &str) -> Result<Self, String> {
250 let (force_local, s) = match s.strip_prefix("local:") {
251 Some(s) => (true, s),
252 None => (false, s),
253 };
254
255 let (s, max_burst) = match s.split_once(",max_burst=") {
256 Some((s, burst_spec)) => {
257 let burst = parse_separated_number(burst_spec)?;
258 (s, Some(burst))
259 }
260 None => (s, None),
261 };
262
263 let (limit, period) = s
264 .split_once("/")
265 .ok_or_else(|| format!("expected 'limit/period', got {s}"))?;
266
267 let (period_scale, period) = opt_digit_prefix(period)?;
268
269 let period = match period {
270 "h" | "hr" | "hour" => 3600,
271 "m" | "min" | "minute" => 60,
272 "s" | "sec" | "second" => 1,
273 "d" | "day" => 86400,
274 invalid => return Err(format!("unknown period quantity {invalid}")),
275 } * period_scale;
276
277 let limit = parse_separated_number(limit)?;
279
280 if limit == 0 {
281 return Err(format!(
282 "invalid ThrottleSpec `{s}`: limit must be greater than 0!"
283 ));
284 }
285
286 Ok(Self {
287 limit,
288 period,
289 max_burst,
290 force_local,
291 })
292 }
293}
294
/// Outcome of a throttle check, as returned by `ThrottleSpec::throttle` and
/// `ThrottleSpec::throttle_quantity`.
///
/// NOTE(review): the field set appears to mirror the reply of redis-cell's
/// `CL.THROTTLE` command. The values are filled in by the `throttle` module
/// (not visible here), so confirm the exact semantics there.
#[derive(Debug, Eq, PartialEq, Serialize)]
pub struct ThrottleResult {
    /// Presumably `true` when the request was denied by the throttle.
    pub throttled: bool,
    /// Presumably the limit (burst capacity) that was applied.
    pub limit: u64,
    /// Presumably the number of units still available.
    pub remaining: u64,
    /// Presumably the time until the throttle fully resets.
    pub reset_after: Duration,
    /// Presumably `Some(wait)` when throttled: how long to wait before retrying.
    pub retry_after: Option<Duration>,
}
311
/// A plain numeric limit, optionally marked as local.
///
/// Parsed from strings such as `"1_000"` or `"local:100"`, or deserialized from
/// a positive integer.
///
/// NOTE(review): the derived `Serialize` emits a struct
/// (`{limit, force_local}`), but the hand-written `Deserialize` only accepts a
/// string or an integer, so serialized values do not round-trip. Confirm
/// whether that is intended.
#[derive(Eq, PartialEq, Clone, Copy, Serialize, Hash)]
pub struct LimitSpec {
    /// The limit value; parsing and deserialization reject 0.
    pub limit: u64,
    /// Set by the `local:` prefix when parsing from a string.
    pub force_local: bool,
}
318
319impl std::fmt::Debug for LimitSpec {
320 fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
321 if self.force_local {
322 write!(fmt, "local:{:?}", self.limit)
323 } else {
324 self.limit.fmt(fmt)
325 }
326 }
327}
328
329impl std::fmt::Display for LimitSpec {
330 fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
331 write!(fmt, "{:?}", self)
332 }
333}
334
335impl LimitSpec {
336 pub const fn new(limit: u64) -> Self {
337 Self {
338 limit,
339 force_local: false,
340 }
341 }
342}
343
344impl TryFrom<&str> for LimitSpec {
345 type Error = String;
346 fn try_from(s: &str) -> Result<Self, String> {
347 let (force_local, s) = match s.strip_prefix("local:") {
348 Some(s) => (true, s),
349 None => (false, s),
350 };
351
352 let limit: String = s
354 .chars()
355 .filter_map(|c| match c {
356 '_' | ',' => None,
357 c => Some(c),
358 })
359 .collect();
360
361 let limit = limit
362 .parse::<u64>()
363 .map_err(|err| format!("invalid limit '{limit}': {err:#}"))?;
364
365 if limit == 0 {
366 return Err(format!(
367 "invalid LimitSpec `{s}`: limit must be greater than 0!"
368 ));
369 }
370
371 Ok(Self { limit, force_local })
372 }
373}
374
375impl<'de> Deserialize<'de> for LimitSpec {
376 fn deserialize<D>(deserializer: D) -> Result<Self, <D as Deserializer<'de>>::Error>
377 where
378 D: Deserializer<'de>,
379 {
380 use serde::de::Visitor;
381
382 struct Helper {}
383 impl<'de> Visitor<'de> for Helper {
384 type Value = LimitSpec;
385
386 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
387 formatter.write_str("string or numeric limit spec")
388 }
389
390 fn visit_str<E>(self, value: &str) -> Result<LimitSpec, E>
391 where
392 E: serde::de::Error,
393 {
394 value.try_into().map_err(|err| E::custom(err))
395 }
396
397 fn visit_i8<E>(self, value: i8) -> Result<LimitSpec, E>
398 where
399 E: serde::de::Error,
400 {
401 if value < 1 {
402 return Err(E::custom("limit must be 1 or larger"));
403 }
404 Ok(LimitSpec {
405 limit: value as u64,
406 force_local: false,
407 })
408 }
409
410 fn visit_i16<E>(self, value: i16) -> Result<LimitSpec, E>
411 where
412 E: serde::de::Error,
413 {
414 if value < 1 {
415 return Err(E::custom("limit must be 1 or larger"));
416 }
417 Ok(LimitSpec {
418 limit: value as u64,
419 force_local: false,
420 })
421 }
422
423 fn visit_i32<E>(self, value: i32) -> Result<LimitSpec, E>
424 where
425 E: serde::de::Error,
426 {
427 if value < 1 {
428 return Err(E::custom("limit must be 1 or larger"));
429 }
430 Ok(LimitSpec {
431 limit: value as u64,
432 force_local: false,
433 })
434 }
435
436 fn visit_i64<E>(self, value: i64) -> Result<LimitSpec, E>
437 where
438 E: serde::de::Error,
439 {
440 if value < 1 {
441 return Err(E::custom("limit must be 1 or larger"));
442 }
443 Ok(LimitSpec {
444 limit: value as u64,
445 force_local: false,
446 })
447 }
448
449 fn visit_u8<E>(self, value: u8) -> Result<LimitSpec, E>
450 where
451 E: serde::de::Error,
452 {
453 if value < 1 {
454 return Err(E::custom("limit must be 1 or larger"));
455 }
456 Ok(LimitSpec {
457 limit: value as u64,
458 force_local: false,
459 })
460 }
461
462 fn visit_u16<E>(self, value: u16) -> Result<LimitSpec, E>
463 where
464 E: serde::de::Error,
465 {
466 if value < 1 {
467 return Err(E::custom("limit must be 1 or larger"));
468 }
469 Ok(LimitSpec {
470 limit: value as u64,
471 force_local: false,
472 })
473 }
474
475 fn visit_u32<E>(self, value: u32) -> Result<LimitSpec, E>
476 where
477 E: serde::de::Error,
478 {
479 if value < 1 {
480 return Err(E::custom("limit must be 1 or larger"));
481 }
482 Ok(LimitSpec {
483 limit: value as u64,
484 force_local: false,
485 })
486 }
487
488 fn visit_u64<E>(self, value: u64) -> Result<LimitSpec, E>
489 where
490 E: serde::de::Error,
491 {
492 if value < 1 {
493 return Err(E::custom("limit must be 1 or larger"));
494 }
495 Ok(LimitSpec {
496 limit: value,
497 force_local: false,
498 })
499 }
500 }
501
502 deserializer.deserialize_any(Helper {})
503 }
504}
505
#[cfg(test)]
mod test {
    use super::*;

    /// Parsing and formatting of `ThrottleSpec`: unit aliases, the `local:`
    /// prefix, multiplied periods, digit separators, error messages and
    /// `max_burst`.
    #[test]
    fn throttle_spec_parse() {
        assert_eq!(
            ThrottleSpec::try_from("100/hr").unwrap(),
            ThrottleSpec {
                limit: 100,
                period: 3600,
                max_burst: None,
                force_local: false,
            }
        );
        assert_eq!(
            ThrottleSpec::try_from("local:100/hr").unwrap(),
            ThrottleSpec {
                limit: 100,
                period: 3600,
                max_burst: None,
                force_local: true,
            }
        );

        // Formatting uses the shortest unit name.
        assert_eq!(
            ThrottleSpec {
                limit: 100,
                period: 3600,
                max_burst: None,
                force_local: false,
            }
            .as_string(),
            "100/h"
        );
        assert_eq!(
            ThrottleSpec {
                limit: 100,
                period: 3600,
                max_burst: None,
                force_local: true,
            }
            .as_string(),
            "local:100/h"
        );

        // A multiplied unit is stored in seconds and printed back in seconds.
        let weird_duration = ThrottleSpec::try_from("local:100/123m").unwrap();
        assert_eq!(
            weird_duration,
            ThrottleSpec {
                limit: 100,
                period: 123 * 60,
                max_burst: None,
                force_local: true,
            }
        );
        assert_eq!(weird_duration.as_string(), "local:100/7380s");

        // `_` and `,` are accepted as digit separators in the limit.
        assert_eq!(
            ThrottleSpec::try_from("1_0,0/hour").unwrap(),
            ThrottleSpec {
                limit: 100,
                period: 3600,
                max_burst: None,
                force_local: false,
            }
        );
        assert_eq!(
            ThrottleSpec::try_from("100/our").unwrap_err(),
            "unknown period quantity our".to_string()
        );
        assert_eq!(
            ThrottleSpec::try_from("three/hour").unwrap_err(),
            "invalid limit 'three': invalid digit found in string".to_string()
        );

        // An explicit burst overrides the limit-derived burst.
        let burst = ThrottleSpec::try_from("50/day,max_burst=1").unwrap();
        assert_eq!(
            burst,
            ThrottleSpec {
                limit: 50,
                period: 86400,
                max_burst: Some(1),
                force_local: false,
            }
        );
        assert_eq!(burst.as_string(), "50/d,max_burst=1");
        assert_eq!(burst.burst(), 1);
        // 86400s / 50 units = 1728s between units.
        assert_eq!(format!("{:?}", burst.interval()), "1728s");
    }

    /// The optional numeric multiplier in front of a period unit.
    #[test]
    fn test_opt_digit_prefix() {
        assert_eq!(opt_digit_prefix("m").unwrap(), (1, "m"));
        // Digits without a following unit are rejected.
        assert_eq!(
            opt_digit_prefix("1").unwrap_err(),
            "invalid period quantity 1"
        );
        assert_eq!(opt_digit_prefix("1q").unwrap(), (1, "q"));
        assert_eq!(opt_digit_prefix("2s").unwrap(), (2, "s"));
        assert_eq!(opt_digit_prefix("20s").unwrap(), (20, "s"));
        assert_eq!(opt_digit_prefix("12378s").unwrap(), (12378, "s"));
    }

    /// Parsing of `LimitSpec` strings, including separators and `local:`.
    #[test]
    fn limit_spec_parse() {
        assert_eq!(LimitSpec::try_from("100").unwrap(), LimitSpec::new(100));
        assert_eq!(LimitSpec::try_from("1_00").unwrap(), LimitSpec::new(100));
        assert_eq!(
            LimitSpec::try_from("local:1_00").unwrap(),
            LimitSpec {
                limit: 100,
                force_local: true
            }
        );
        assert_eq!(
            LimitSpec::try_from("three").unwrap_err(),
            "invalid limit 'three': invalid digit found in string".to_string()
        );
    }
}