throttle/
lib.rs

//! This crate implements a throttling API based on the generic cell rate algorithm (GCRA).
//! Throttles are tracked in an in-memory store by default. When built with the `redis`
//! feature, `use_redis` can configure a redis server so that throttles are shared among
//! multiple machines; the server is probed for the redis-cell `CL.THROTTLE` command at
//! configuration time.
5#[cfg(feature = "redis")]
6use mod_redis::RedisError;
7use serde::{Deserialize, Deserializer, Serialize};
8use std::convert::TryFrom;
9use std::time::Duration;
10use thiserror::Error;
11
12#[cfg(feature = "redis")]
13pub mod limit;
14#[cfg(feature = "redis")]
15mod throttle;
16
#[cfg(feature = "redis")]
mod redis {
    use super::*;
    use mod_redis::{Cmd, RedisConnection, RedisValue};
    use std::ops::Deref;
    use std::sync::OnceLock;

    /// A redis connection configured for shared throttles, plus whether the
    /// server provides the redis-cell `CL.THROTTLE` command.
    #[derive(Debug)]
    pub(crate) struct RedisContext {
        pub(crate) connection: RedisConnection,
        /// `true` when `COMMAND INFO CL.THROTTLE` returned a non-nil entry.
        pub(crate) has_redis_cell: bool,
    }

    impl RedisContext {
        /// Wraps `connection`, probing the server once with
        /// `COMMAND INFO CL.THROTTLE` to detect the redis-cell module.
        ///
        /// # Errors
        /// Propagates any failure of the probe query.
        pub async fn try_from(connection: RedisConnection) -> anyhow::Result<Self> {
            let mut cmd = Cmd::new();
            cmd.arg("COMMAND").arg("INFO").arg("CL.THROTTLE");

            let rsp = connection.query(cmd).await?;
            // Any non-nil entry in the reply means the server knows the command;
            // a reply that is not a sequence is treated as "not available".
            let has_redis_cell = rsp
                .as_sequence()
                .is_some_and(|entries| entries.iter().any(|v| v != &RedisValue::Nil));

            Ok(Self {
                has_redis_cell,
                connection,
            })
        }
    }

    impl Deref for RedisContext {
        type Target = RedisConnection;
        fn deref(&self) -> &Self::Target {
            &self.connection
        }
    }

    /// The process-wide redis context for throttles; set at most once by `use_redis`.
    pub(crate) static REDIS: OnceLock<RedisContext> = OnceLock::new();

    /// Configures `conn` as the shared redis server for throttles.
    ///
    /// # Errors
    /// Returns `Error::Generic` if redis was already configured, or the probe
    /// error (as `Error::AnyHow`) if `COMMAND INFO` fails.
    pub async fn use_redis(conn: RedisConnection) -> Result<(), Error> {
        let already_configured =
            || Error::Generic("redis already configured for throttles".to_string());

        // Fail fast without a network round trip; the `set` below still guards
        // against a concurrent caller winning the race.
        if REDIS.get().is_some() {
            return Err(already_configured());
        }

        let context = RedisContext::try_from(conn).await?;
        REDIS.set(context).map_err(|_| already_configured())
    }
}
63
64#[cfg(feature = "redis")]
65pub use redis::use_redis;
66#[cfg(feature = "redis")]
67pub(crate) use redis::REDIS;
68
69#[derive(Error, Debug)]
70pub enum Error {
71    #[error("{0}")]
72    Generic(String),
73    #[error("{0}")]
74    AnyHow(#[from] anyhow::Error),
75    #[cfg(feature = "redis")]
76    #[error("{0}")]
77    Redis(#[from] RedisError),
78    #[error("TooManyLeases, try again in {0:?}")]
79    TooManyLeases(Duration),
80    #[error("NonExistentLease")]
81    NonExistentLease,
82}
83
84#[derive(Eq, PartialEq, Clone, Copy, Serialize, Deserialize, Hash)]
85#[serde(try_from = "String", into = "String")]
86pub struct ThrottleSpec {
87    pub limit: u64,
88    /// Period, in seconds
89    pub period: u64,
90    /// Constrain how quickly the throttle can be consumed per
91    /// throttle interval (period / limit).
92    /// max_burst defaults to limit, allowing the entire throttle
93    /// to be used in an instant. Setting max_burst to 1 will
94    /// only allow 1 throttle bump per interval, spreading out
95    /// the utilization of the throttle more evenly over time.
96    /// Larger values allow more of the throttle to be used
97    /// per period.
98    pub max_burst: Option<u64>,
99    pub force_local: bool,
100}
101
102#[cfg(feature = "redis")]
103impl ThrottleSpec {
104    pub async fn throttle<S: AsRef<str>>(&self, key: S) -> Result<ThrottleResult, Error> {
105        self.throttle_quantity(key, 1).await
106    }
107
108    pub async fn throttle_quantity<S: AsRef<str>>(
109        &self,
110        key: S,
111        quantity: u64,
112    ) -> Result<ThrottleResult, Error> {
113        let key = key.as_ref();
114        let limit = self.limit;
115        let period = self.period;
116        let max_burst = self.max_burst.unwrap_or(limit);
117        let key = format!("{key}:{limit}:{max_burst}:{period}");
118        throttle::throttle(
119            &key,
120            limit,
121            Duration::from_secs(period),
122            max_burst,
123            Some(quantity),
124            self.force_local,
125        )
126        .await
127    }
128
129    /// Returns the effective burst value for this throttle spec
130    pub fn burst(&self) -> u64 {
131        self.max_burst.unwrap_or(self.limit)
132    }
133
134    /// Returns the throttle interval over which the burst applies
135    pub fn interval(&self) -> Duration {
136        Duration::from_secs_f64(self.period as f64 / self.limit.max(1) as f64)
137    }
138
139    pub fn as_local(&self) -> Self {
140        let mut copy = self.clone();
141        copy.force_local = true;
142        copy
143    }
144}
145
146impl std::fmt::Debug for ThrottleSpec {
147    fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
148        write!(fmt, "{}", self.as_string())
149    }
150}
151
152impl std::fmt::Display for ThrottleSpec {
153    fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
154        write!(fmt, "{}", self.as_string())
155    }
156}
157
158impl ThrottleSpec {
159    pub fn as_string(&self) -> String {
160        let mut period_scale = None;
161        let period = match self.period {
162            86400 => "d",
163            3600 => "h",
164            60 => "m",
165            1 => "s",
166            other => {
167                period_scale.replace(other.to_string());
168                "s"
169            }
170        };
171
172        let burst = match self.max_burst {
173            Some(b) => format!(",max_burst={b}"),
174            None => String::new(),
175        };
176
177        format!(
178            "{}{}/{}{period}{burst}",
179            if self.force_local { "local:" } else { "" },
180            self.limit,
181            match &period_scale {
182                Some(scale) => scale.as_str(),
183                None => "",
184            }
185        )
186    }
187}
188
189impl From<ThrottleSpec> for String {
190    fn from(spec: ThrottleSpec) -> String {
191        spec.as_string()
192    }
193}
194
195impl TryFrom<String> for ThrottleSpec {
196    type Error = String;
197    fn try_from(s: String) -> Result<Self, String> {
198        Self::try_from(s.as_str())
199    }
200}
201
/// Splits an optional decimal multiplier off the front of a period unit:
/// `"15m"` gives `(15, "m")` and `"m"` gives `(1, "m")`.
///
/// # Errors
/// Returns `"number too big"` when the digits overflow a `u64`. Returns
/// `"invalid period quantity …"` when nothing follows the digits (including
/// the empty string), because a period needs a unit.
fn opt_digit_prefix(s: &str) -> Result<(u64, &str), String> {
    let split = s.find(|c: char| !c.is_ascii_digit()).unwrap_or(s.len());
    let (digits, unit) = s.split_at(split);

    let n = if digits.is_empty() {
        1
    } else {
        // `digits` is non-empty and all ASCII digits, so overflow is the only
        // way this parse can fail.
        digits
            .parse::<u64>()
            .map_err(|_| "number too big".to_string())?
    };

    if unit.is_empty() {
        return Err(format!("invalid period quantity {s}"));
    }
    Ok((n, unit))
}
231
/// Parses an unsigned integer that may contain `_` or `,` digit separators,
/// so config can say `1_000` or `1,000` for readability.
///
/// The error message quotes the input exactly as given, separators included.
fn parse_separated_number(limit: &str) -> Result<u64, String> {
    let digits = limit.replace(['_', ','], "");
    match digits.parse::<u64>() {
        Ok(number) => Ok(number),
        Err(err) => Err(format!("invalid limit '{limit}': {err:#}")),
    }
}
246
247impl TryFrom<&str> for ThrottleSpec {
248    type Error = String;
249    fn try_from(s: &str) -> Result<Self, String> {
250        let (force_local, s) = match s.strip_prefix("local:") {
251            Some(s) => (true, s),
252            None => (false, s),
253        };
254
255        let (s, max_burst) = match s.split_once(",max_burst=") {
256            Some((s, burst_spec)) => {
257                let burst = parse_separated_number(burst_spec)?;
258                (s, Some(burst))
259            }
260            None => (s, None),
261        };
262
263        let (limit, period) = s
264            .split_once("/")
265            .ok_or_else(|| format!("expected 'limit/period', got {s}"))?;
266
267        let (period_scale, period) = opt_digit_prefix(period)?;
268
269        let period = match period {
270            "h" | "hr" | "hour" => 3600,
271            "m" | "min" | "minute" => 60,
272            "s" | "sec" | "second" => 1,
273            "d" | "day" => 86400,
274            invalid => return Err(format!("unknown period quantity {invalid}")),
275        } * period_scale;
276
277        // Allow "1_000/hr" and "1,000/hr" for more readable config
278        let limit = parse_separated_number(limit)?;
279
280        if limit == 0 {
281            return Err(format!(
282                "invalid ThrottleSpec `{s}`: limit must be greater than 0!"
283            ));
284        }
285
286        Ok(Self {
287            limit,
288            period,
289            max_burst,
290            force_local,
291        })
292    }
293}
294
295#[derive(Debug, Eq, PartialEq, Serialize)]
296pub struct ThrottleResult {
297    /// true if the action was limited
298    pub throttled: bool,
299    /// The total limit of the key (max_burst + 1). This is equivalent to the common
300    /// X-RateLimit-Limit HTTP header.
301    pub limit: u64,
302    /// The remaining limit of the key. Equivalent to X-RateLimit-Remaining.
303    pub remaining: u64,
304    /// The number of seconds until the limit will reset to its maximum capacity.
305    /// Equivalent to X-RateLimit-Reset.
306    pub reset_after: Duration,
307    /// The number of seconds until the user should retry, but None if the action was
308    /// allowed. Equivalent to Retry-After.
309    pub retry_after: Option<Duration>,
310}
311
312#[derive(Eq, PartialEq, Clone, Copy, Serialize, Hash)]
313pub struct LimitSpec {
314    /// Maximum amount
315    pub limit: u64,
316    pub force_local: bool,
317}
318
319impl std::fmt::Debug for LimitSpec {
320    fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
321        if self.force_local {
322            write!(fmt, "local:{:?}", self.limit)
323        } else {
324            self.limit.fmt(fmt)
325        }
326    }
327}
328
329impl std::fmt::Display for LimitSpec {
330    fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
331        write!(fmt, "{:?}", self)
332    }
333}
334
335impl LimitSpec {
336    pub const fn new(limit: u64) -> Self {
337        Self {
338            limit,
339            force_local: false,
340        }
341    }
342}
343
344impl TryFrom<&str> for LimitSpec {
345    type Error = String;
346    fn try_from(s: &str) -> Result<Self, String> {
347        let (force_local, s) = match s.strip_prefix("local:") {
348            Some(s) => (true, s),
349            None => (false, s),
350        };
351
352        // Allow "1_000/hr" and "1,000/hr" for more readable config
353        let limit: String = s
354            .chars()
355            .filter_map(|c| match c {
356                '_' | ',' => None,
357                c => Some(c),
358            })
359            .collect();
360
361        let limit = limit
362            .parse::<u64>()
363            .map_err(|err| format!("invalid limit '{limit}': {err:#}"))?;
364
365        if limit == 0 {
366            return Err(format!(
367                "invalid LimitSpec `{s}`: limit must be greater than 0!"
368            ));
369        }
370
371        Ok(Self { limit, force_local })
372    }
373}
374
375impl<'de> Deserialize<'de> for LimitSpec {
376    fn deserialize<D>(deserializer: D) -> Result<Self, <D as Deserializer<'de>>::Error>
377    where
378        D: Deserializer<'de>,
379    {
380        use serde::de::Visitor;
381
382        struct Helper {}
383        impl<'de> Visitor<'de> for Helper {
384            type Value = LimitSpec;
385
386            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
387                formatter.write_str("string or numeric limit spec")
388            }
389
390            fn visit_str<E>(self, value: &str) -> Result<LimitSpec, E>
391            where
392                E: serde::de::Error,
393            {
394                value.try_into().map_err(|err| E::custom(err))
395            }
396
397            fn visit_i8<E>(self, value: i8) -> Result<LimitSpec, E>
398            where
399                E: serde::de::Error,
400            {
401                if value < 1 {
402                    return Err(E::custom("limit must be 1 or larger"));
403                }
404                Ok(LimitSpec {
405                    limit: value as u64,
406                    force_local: false,
407                })
408            }
409
410            fn visit_i16<E>(self, value: i16) -> Result<LimitSpec, E>
411            where
412                E: serde::de::Error,
413            {
414                if value < 1 {
415                    return Err(E::custom("limit must be 1 or larger"));
416                }
417                Ok(LimitSpec {
418                    limit: value as u64,
419                    force_local: false,
420                })
421            }
422
423            fn visit_i32<E>(self, value: i32) -> Result<LimitSpec, E>
424            where
425                E: serde::de::Error,
426            {
427                if value < 1 {
428                    return Err(E::custom("limit must be 1 or larger"));
429                }
430                Ok(LimitSpec {
431                    limit: value as u64,
432                    force_local: false,
433                })
434            }
435
436            fn visit_i64<E>(self, value: i64) -> Result<LimitSpec, E>
437            where
438                E: serde::de::Error,
439            {
440                if value < 1 {
441                    return Err(E::custom("limit must be 1 or larger"));
442                }
443                Ok(LimitSpec {
444                    limit: value as u64,
445                    force_local: false,
446                })
447            }
448
449            fn visit_u8<E>(self, value: u8) -> Result<LimitSpec, E>
450            where
451                E: serde::de::Error,
452            {
453                if value < 1 {
454                    return Err(E::custom("limit must be 1 or larger"));
455                }
456                Ok(LimitSpec {
457                    limit: value as u64,
458                    force_local: false,
459                })
460            }
461
462            fn visit_u16<E>(self, value: u16) -> Result<LimitSpec, E>
463            where
464                E: serde::de::Error,
465            {
466                if value < 1 {
467                    return Err(E::custom("limit must be 1 or larger"));
468                }
469                Ok(LimitSpec {
470                    limit: value as u64,
471                    force_local: false,
472                })
473            }
474
475            fn visit_u32<E>(self, value: u32) -> Result<LimitSpec, E>
476            where
477                E: serde::de::Error,
478            {
479                if value < 1 {
480                    return Err(E::custom("limit must be 1 or larger"));
481                }
482                Ok(LimitSpec {
483                    limit: value as u64,
484                    force_local: false,
485                })
486            }
487
488            fn visit_u64<E>(self, value: u64) -> Result<LimitSpec, E>
489            where
490                E: serde::de::Error,
491            {
492                if value < 1 {
493                    return Err(E::custom("limit must be 1 or larger"));
494                }
495                Ok(LimitSpec {
496                    limit: value,
497                    force_local: false,
498                })
499            }
500        }
501
502        deserializer.deserialize_any(Helper {})
503    }
504}
505
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn throttle_spec_parse() {
        assert_eq!(
            ThrottleSpec::try_from("100/hr").unwrap(),
            ThrottleSpec {
                limit: 100,
                period: 3600,
                max_burst: None,
                force_local: false,
            }
        );
        assert_eq!(
            ThrottleSpec::try_from("local:100/hr").unwrap(),
            ThrottleSpec {
                limit: 100,
                period: 3600,
                max_burst: None,
                force_local: true,
            }
        );

        assert_eq!(
            ThrottleSpec {
                limit: 100,
                period: 3600,
                max_burst: None,
                force_local: false,
            }
            .as_string(),
            "100/h"
        );
        assert_eq!(
            ThrottleSpec {
                limit: 100,
                period: 3600,
                max_burst: None,
                force_local: true,
            }
            .as_string(),
            "local:100/h"
        );

        let weird_duration = ThrottleSpec::try_from("local:100/123m").unwrap();
        assert_eq!(
            weird_duration,
            ThrottleSpec {
                limit: 100,
                period: 123 * 60,
                max_burst: None,
                force_local: true,
            }
        );
        assert_eq!(weird_duration.as_string(), "local:100/7380s");

        assert_eq!(
            ThrottleSpec::try_from("1_0,0/hour").unwrap(),
            ThrottleSpec {
                limit: 100,
                period: 3600,
                max_burst: None,
                force_local: false,
            }
        );
        assert_eq!(
            ThrottleSpec::try_from("100/our").unwrap_err(),
            "unknown period quantity our".to_string()
        );
        assert_eq!(
            ThrottleSpec::try_from("three/hour").unwrap_err(),
            "invalid limit 'three': invalid digit found in string".to_string()
        );

        let burst = ThrottleSpec::try_from("50/day,max_burst=1").unwrap();
        assert_eq!(
            burst,
            ThrottleSpec {
                limit: 50,
                period: 86400,
                max_burst: Some(1),
                force_local: false,
            }
        );
        assert_eq!(burst.as_string(), "50/d,max_burst=1");
        assert_eq!(burst.burst(), 1);
        assert_eq!(format!("{:?}", burst.interval()), "1728s");
    }

    /// `as_string` must produce text that parses back to the same spec.
    #[test]
    fn throttle_spec_round_trips_through_string() {
        for text in [
            "100/h",
            "local:100/h",
            "local:100/7380s",
            "50/d,max_burst=1",
            "local:3/2s,max_burst=10",
            "1/m",
            "7/s",
        ] {
            let spec = ThrottleSpec::try_from(text).unwrap();
            assert_eq!(spec.as_string(), text);
            assert_eq!(ThrottleSpec::try_from(spec.as_string()).unwrap(), spec);
        }
    }

    #[test]
    fn throttle_spec_parse_errors() {
        assert_eq!(
            ThrottleSpec::try_from("0/h").unwrap_err(),
            "invalid ThrottleSpec `0/h`: limit must be greater than 0!"
        );
        assert_eq!(
            ThrottleSpec::try_from("100").unwrap_err(),
            "expected 'limit/period', got 100"
        );
        assert_eq!(
            ThrottleSpec::try_from("100/h,max_burst=x").unwrap_err(),
            "invalid limit 'x': invalid digit found in string"
        );
    }

    #[test]
    fn test_opt_digit_prefix() {
        assert_eq!(opt_digit_prefix("m").unwrap(), (1, "m"));
        assert_eq!(
            opt_digit_prefix("1").unwrap_err(),
            "invalid period quantity 1"
        );
        assert_eq!(opt_digit_prefix("").unwrap_err(), "invalid period quantity ");
        assert_eq!(opt_digit_prefix("1q").unwrap(), (1, "q"));
        assert_eq!(opt_digit_prefix("2s").unwrap(), (2, "s"));
        assert_eq!(opt_digit_prefix("20s").unwrap(), (20, "s"));
        assert_eq!(opt_digit_prefix("12378s").unwrap(), (12378, "s"));
        assert_eq!(
            opt_digit_prefix("99999999999999999999s").unwrap_err(),
            "number too big"
        );
    }

    #[test]
    fn test_parse_separated_number() {
        assert_eq!(parse_separated_number("1,000_000"), Ok(1_000_000));
        assert_eq!(
            parse_separated_number("three").unwrap_err(),
            "invalid limit 'three': invalid digit found in string"
        );
    }

    #[test]
    fn limit_spec_parse() {
        assert_eq!(LimitSpec::try_from("100").unwrap(), LimitSpec::new(100));
        assert_eq!(LimitSpec::try_from("1_00").unwrap(), LimitSpec::new(100));
        assert_eq!(
            LimitSpec::try_from("local:1_00").unwrap(),
            LimitSpec {
                limit: 100,
                force_local: true
            }
        );
        assert_eq!(
            LimitSpec::try_from("three").unwrap_err(),
            "invalid limit 'three': invalid digit found in string".to_string()
        );
        assert_eq!(
            LimitSpec::try_from("0").unwrap_err(),
            "invalid LimitSpec `0`: limit must be greater than 0!"
        );
    }

    #[test]
    fn limit_spec_display() {
        assert_eq!(LimitSpec::new(5).to_string(), "5");
        assert_eq!(
            LimitSpec {
                limit: 5,
                force_local: true
            }
            .to_string(),
            "local:5"
        );
    }
}