throttle/
throttle.rs

1use crate::{Error, ThrottleResult, REDIS};
2use anyhow::Context;
3use mod_redis::{Cmd, FromRedisValue, RedisConnection, Script};
4use redis_cell_impl::{time, MemoryStore, Rate, RateLimiter, RateQuota};
5use std::sync::{LazyLock, Mutex};
6use std::time::Duration;
7
8static MEMORY: LazyLock<Mutex<MemoryStore>> = LazyLock::new(|| Mutex::new(MemoryStore::new()));
9
10// Adapted from https://github.com/Losant/redis-gcra/blob/master/lib/gcra.lua
11static GCRA_SCRIPT: LazyLock<Script> = LazyLock::new(|| {
12    Script::new(
13        r#"
14local key = KEYS[1]
15local limit = ARGV[1]
16local period = ARGV[2]
17local max_burst = ARGV[3]
18local quantity = ARGV[4]
19
20local interval = period / limit
21local increment = interval * quantity
22local burst_offset = interval * max_burst
23
24local now = tonumber(redis.call("TIME")[1])
25local tat = redis.call("GET", key)
26
27if not tat then
28  tat = now
29else
30  tat = tonumber(tat)
31end
32tat = math.max(tat, now)
33
34local new_tat = tat + increment
35local allow_at = new_tat - burst_offset
36local diff = now - allow_at
37
38local throttled
39local reset_after
40local retry_after
41
42local remaining = math.floor(diff / interval) -- poor man's round
43
44if remaining < 0 then
45  throttled = 1
46  -- calculate how many tokens there actually are, since
47  -- remaining is how many there would have been if we had been able to limit
48  -- and we did not limit
49  remaining = math.floor((now - (tat - burst_offset)) / interval)
50  reset_after = math.ceil(tat - now)
51  retry_after = math.ceil(diff * -1)
52elseif remaining == 0 and increment <= 0 then
53  -- request with cost of 0
54  -- cost of 0 with remaining 0 is still limited
55  throttled = 1
56  remaining = 0
57  reset_after = math.ceil(tat - now)
58  retry_after = 0 -- retry_after is meaningless when quantity is 0
59else
60  throttled = 0
61  reset_after = math.ceil(new_tat - now)
62  retry_after = 0
63  redis.call("SET", key, new_tat, "EX", reset_after)
64end
65
66return {throttled, remaining, reset_after, retry_after, tostring(diff), tostring(interval)}
67"#,
68    )
69});
70
71fn local_throttle(
72    key: &str,
73    limit: u64,
74    period: Duration,
75    max_burst: u64,
76    quantity: Option<u64>,
77) -> Result<ThrottleResult, Error> {
78    let mut store = MEMORY.lock().unwrap();
79    let max_rate = Rate::per_period(
80        limit as i64,
81        time::Duration::try_from(period).map_err(|err| Error::Generic(format!("{err:#}")))?,
82    );
83    let mut limiter = RateLimiter::new(
84        &mut *store,
85        &RateQuota {
86            max_burst: max_burst.min(limit - 1) as i64,
87            max_rate,
88        },
89    );
90    let quantity = quantity.unwrap_or(1) as i64;
91    let (throttled, rate_limit_result) = limiter
92        .rate_limit(key, quantity)
93        .map_err(|err| Error::Generic(format!("{err:#}")))?;
94
95    // If either time had a partial component, bump it up to the next full
96    // second because otherwise a fast-paced caller could try again too
97    // early.
98    let mut retry_after = rate_limit_result.retry_after.whole_seconds();
99    if rate_limit_result.retry_after.subsec_milliseconds() > 0 {
100        retry_after += 1
101    }
102    let mut reset_after = rate_limit_result.reset_after.whole_seconds();
103    if rate_limit_result.reset_after.subsec_milliseconds() > 0 {
104        reset_after += 1
105    }
106
107    Ok(ThrottleResult {
108        throttled,
109        limit: rate_limit_result.limit as u64,
110        remaining: rate_limit_result.remaining as u64,
111        reset_after: Duration::from_secs(reset_after.max(0) as u64),
112        retry_after: if retry_after <= 0 {
113            None
114        } else {
115            Some(Duration::from_secs(retry_after.max(0) as u64))
116        },
117    })
118}
119
120async fn redis_cell_throttle(
121    conn: &RedisConnection,
122    key: &str,
123    limit: u64,
124    period: Duration,
125    max_burst: u64,
126    quantity: Option<u64>,
127) -> Result<ThrottleResult, Error> {
128    let mut cmd = Cmd::new();
129    cmd.arg("CL.THROTTLE")
130        .arg(key)
131        .arg(max_burst)
132        .arg(limit)
133        .arg(period.as_secs())
134        .arg(quantity.unwrap_or(1));
135    let result = conn.query(cmd).await?;
136    let result = <Vec<i64> as FromRedisValue>::from_redis_value(&result)?;
137
138    Ok(ThrottleResult {
139        throttled: result[0] != 0,
140        limit: result[1] as u64,
141        remaining: result[2] as u64,
142        retry_after: match result[3] {
143            n if n <= 0 => None,
144            n => Some(Duration::from_secs(n as u64)),
145        },
146        reset_after: Duration::from_secs(result[4].max(0) as u64),
147    })
148}
149
150async fn redis_script_throttle(
151    conn: &RedisConnection,
152    key: &str,
153    limit: u64,
154    period: Duration,
155    max_burst: u64,
156    quantity: Option<u64>,
157) -> Result<ThrottleResult, Error> {
158    let mut script = GCRA_SCRIPT.prepare_invoke();
159    script
160        .key(key)
161        .arg(limit)
162        .arg(period.as_secs())
163        .arg(max_burst)
164        .arg(quantity.unwrap_or(1));
165
166    let result = conn
167        .invoke_script(script)
168        .await
169        .context("error invoking redis GCRA script")?;
170    let result =
171        <(u64, u64, u64, u64, String, String) as FromRedisValue>::from_redis_value(&result)?;
172
173    Ok(ThrottleResult {
174        throttled: result.0 == 1,
175        limit: max_burst + 1,
176        remaining: result.1,
177        retry_after: match result.3 {
178            0 => None,
179            n => Some(Duration::from_secs(n)),
180        },
181        reset_after: Duration::from_secs(result.2),
182    })
183}
184
185/// It is very important for `key` to be used with the same `limit`,
186/// `period` and `max_burst` values in order to produce meaningful
187/// results.
188///
189/// This interface cannot detect or report that kind of misuse.
190/// It is recommended that those parameters be encoded into the
191/// key to make it impossible to misuse.
192///
193/// * `limit` - specifies the maximum number of tokens allow
194///             over the specified `period`
195/// * `period` - the time period over which `limit` is allowed.
196/// * `max_burst` - the maximum initial burst that will be permitted.
197///                 set this smaller than `limit` to prevent using
198///                 up the entire budget immediately and force it
199///                 to spread out across time.
200/// * `quantity` - how many tokens to add to the throttle. If omitted,
201///                1 token is added.
202/// * `force_local` - if true, always use the in-memory store on the local
203///                   machine even if the redis backend has been configured.
204pub async fn throttle(
205    key: &str,
206    limit: u64,
207    period: Duration,
208    max_burst: u64,
209    quantity: Option<u64>,
210    force_local: bool,
211) -> Result<ThrottleResult, Error> {
212    match (force_local, REDIS.get()) {
213        (false, Some(cx)) => match cx.has_redis_cell {
214            true => redis_cell_throttle(&cx, key, limit, period, max_burst, quantity).await,
215            false => redis_script_throttle(&cx, key, limit, period, max_burst, quantity).await,
216        },
217        _ => local_throttle(key, limit, period, max_burst, quantity),
218    }
219}
220
#[cfg(test)]
mod test {
    use super::*;
    use crate::redis::RedisContext;
    use mod_redis::test::RedisServer;

    /// Common interface over the throttle backends, so the same scenario can
    /// be run against each of them.
    trait Throttler {
        async fn throttle(
            &self,
            key: &str,
            limit: u64,
            period: Duration,
            max_burst: u64,
            quantity: Option<u64>,
        ) -> Result<ThrottleResult, Error>;
    }

    // The in-memory backend. `self` is not used: `local_throttle` always
    // operates on the global `MEMORY` store.
    impl Throttler for Mutex<MemoryStore> {
        async fn throttle(
            &self,
            key: &str,
            limit: u64,
            period: Duration,
            max_burst: u64,
            quantity: Option<u64>,
        ) -> Result<ThrottleResult, Error> {
            local_throttle(key, limit, period, max_burst, quantity)
        }
    }

    /// Redis with the redis-cell module loaded (`CL.THROTTLE`).
    struct RedisWithCell(RedisConnection);

    impl Throttler for RedisWithCell {
        async fn throttle(
            &self,
            key: &str,
            limit: u64,
            period: Duration,
            max_burst: u64,
            quantity: Option<u64>,
        ) -> Result<ThrottleResult, Error> {
            redis_cell_throttle(&self.0, key, limit, period, max_burst, quantity).await
        }
    }

    /// Redis without redis-cell, using the GCRA Lua script.
    struct VanillaRedis(RedisConnection);

    impl Throttler for VanillaRedis {
        async fn throttle(
            &self,
            key: &str,
            limit: u64,
            period: Duration,
            max_burst: u64,
            quantity: Option<u64>,
        ) -> Result<ThrottleResult, Error> {
            redis_script_throttle(&self.0, key, limit, period, max_burst, quantity).await
        }
    }

    /// Send single-token requests to `throttler` (at most `limit * 2` of
    /// them) until one is throttled. Assert that this happens within
    /// `permitted_tolerance` (a fraction of `max_burst`) of `max_burst`
    /// iterations. `max_burst` defaults to `limit` when `None`.
    async fn test_big_limits(
        limit: u64,
        max_burst: Option<u64>,
        permitted_tolerance: f64,
        throttler: &impl Throttler,
    ) {
        let period = Duration::from_secs(60);
        let max_burst = max_burst.unwrap_or(limit);
        let key = format!("test_big_limits-{limit}-{max_burst}");

        let mut throttled_iter = None;

        for i in 0..limit * 2 {
            let result = throttler
                .throttle(&key, limit, period, max_burst, None)
                .await
                .unwrap();
            if result.throttled {
                println!("iter: {i} -> {result:?}");
                throttled_iter = Some(i);
                break;
            }
        }

        let throttled_iter = throttled_iter.expect("to hit the throttle limit");
        let diff = ((max_burst as f64) - (throttled_iter as f64)).abs();
        let tolerance = (max_burst as f64) * permitted_tolerance;
        println!(
            "throttled after {throttled_iter} iterations for \
                 limit {limit}. diff={diff}. tolerance {tolerance}"
        );
        let max_rate = Rate::per_period(limit as i64, time::Duration::try_from(period).unwrap());
        println!("max_rate: {max_rate:?}");

        assert!(
            diff <= tolerance,
            "throttled after {throttled_iter} iterations for \
                limit {limit}. diff={diff} is not within tolerance {tolerance}"
        );
    }

    #[tokio::test]
    async fn basic_throttle_100() {
        test_big_limits(100, None, 0.01, &*MEMORY).await;
    }

    #[tokio::test]
    async fn basic_throttle_1_000() {
        test_big_limits(1_000, Some(100), 0.02, &*MEMORY).await;
    }

    #[tokio::test]
    async fn basic_throttle_6_000() {
        test_big_limits(6_000, Some(100), 0.02, &*MEMORY).await;
    }

    #[tokio::test]
    async fn basic_throttle_60_000() {
        test_big_limits(60_000, Some(100), 0.1, &*MEMORY).await;
    }

    #[tokio::test]
    async fn basic_throttle_60_000_burst_30k() {
        // Note that the 10% tolerance here is the same as the basic_throttle_60_000
        // test case because the variance is due to timing issues with very small
        // time periods produced by the overall limit, rather than the burst.
        test_big_limits(60_000, Some(30_000), 0.1, &*MEMORY).await;
    }

    #[tokio::test]
    async fn redis_cell_throttle_1_000() {
        if !RedisServer::is_available() {
            return;
        }

        let redis = RedisServer::spawn("").await.unwrap();
        let conn = redis.connection().await.unwrap();
        let cx = RedisContext::try_from(conn).await.unwrap();
        if !cx.has_redis_cell {
            return;
        }

        test_big_limits(1_000, None, 0.02, &RedisWithCell(cx.connection)).await;
    }

    #[tokio::test]
    async fn redis_script_throttle_1_000() {
        if !RedisServer::is_available() {
            return;
        }

        let redis = RedisServer::spawn("").await.unwrap();
        let conn = redis.connection().await.unwrap();
        let cx = RedisContext::try_from(conn).await.unwrap();
        test_big_limits(1_000, None, 0.2, &VanillaRedis(cx.connection)).await;
    }
}