throttle/
throttle.rs

1use crate::{Error, ThrottleResult, REDIS};
2use anyhow::Context;
3use mod_redis::{Cmd, FromRedisValue, RedisConnection, Script};
4use redis_cell_impl::{time, MemoryStore, Rate, RateLimiter, RateQuota};
5use std::sync::{LazyLock, Mutex};
6use std::time::Duration;
7
8static MEMORY: LazyLock<Mutex<MemoryStore>> = LazyLock::new(|| Mutex::new(MemoryStore::new()));
9
10// Adapted from https://github.com/Losant/redis-gcra/blob/master/lib/gcra.lua
11static GCRA_SCRIPT: LazyLock<Script> = LazyLock::new(|| {
12    Script::new(
13        r#"
14local key = KEYS[1]
15local limit = ARGV[1]
16local period = ARGV[2]
17local max_burst = ARGV[3]
18local quantity = ARGV[4]
19
20local interval = period / limit
21local increment = interval * quantity
22local burst_offset = interval * max_burst
23
24local now = tonumber(redis.call("TIME")[1])
25local tat = redis.call("GET", key)
26
27if not tat then
28  tat = now
29else
30  tat = tonumber(tat)
31end
32tat = math.max(tat, now)
33
34local new_tat = tat + increment
35local allow_at = new_tat - burst_offset
36local diff = now - allow_at
37
38local throttled
39local reset_after
40local retry_after
41
42local remaining = math.floor(diff / interval) -- poor man's round
43
44if remaining < 0 then
45  throttled = 1
46  -- calculate how many tokens there actually are, since
47  -- remaining is how many there would have been if we had been able to limit
48  -- and we did not limit
49  remaining = math.floor((now - (tat - burst_offset)) / interval)
50  reset_after = math.ceil(tat - now)
51  retry_after = math.ceil(diff * -1)
52elseif remaining == 0 and increment <= 0 then
53  -- request with cost of 0
54  -- cost of 0 with remaining 0 is still limited
55  throttled = 1
56  remaining = 0
57  reset_after = math.ceil(tat - now)
58  retry_after = 0 -- retry_after is meaningless when quantity is 0
59else
60  throttled = 0
61  reset_after = math.ceil(new_tat - now)
62  retry_after = 0
63  redis.call("SET", key, new_tat, "EX", reset_after)
64end
65
66return {throttled, remaining, reset_after, retry_after, tostring(diff), tostring(interval)}
67"#,
68    )
69});
70
71fn local_throttle(
72    key: &str,
73    limit: u64,
74    period: Duration,
75    max_burst: u64,
76    quantity: Option<u64>,
77) -> Result<ThrottleResult, Error> {
78    let mut store = MEMORY.lock().unwrap();
79    let max_rate = Rate::per_period(
80        limit as i64,
81        time::Duration::try_from(period).map_err(|err| Error::Generic(format!("{err:#}")))?,
82    );
83    let mut limiter = RateLimiter::new(
84        &mut *store,
85        &RateQuota {
86            max_burst: max_burst.min(limit - 1) as i64,
87            max_rate,
88        },
89    );
90    let quantity = quantity.unwrap_or(1) as i64;
91    let (throttled, rate_limit_result) = limiter
92        .rate_limit(key, quantity)
93        .map_err(|err| Error::Generic(format!("{err:#}")))?;
94
95    // If either time had a partial component, bump it up to the next full
96    // second because otherwise a fast-paced caller could try again too
97    // early.
98    let mut retry_after = rate_limit_result.retry_after.whole_seconds();
99    if rate_limit_result.retry_after.subsec_milliseconds() > 0 {
100        retry_after += 1
101    }
102    let mut reset_after = rate_limit_result.reset_after.whole_seconds();
103    if rate_limit_result.reset_after.subsec_milliseconds() > 0 {
104        reset_after += 1
105    }
106
107    Ok(ThrottleResult {
108        throttled,
109        limit: rate_limit_result.limit as u64,
110        remaining: rate_limit_result.remaining as u64,
111        reset_after: Duration::from_secs(reset_after.max(0) as u64),
112        retry_after: if retry_after <= 0 {
113            None
114        } else {
115            Some(Duration::from_secs(retry_after.max(0) as u64))
116        },
117    })
118}
119
120async fn redis_cell_throttle(
121    conn: &RedisConnection,
122    key: &str,
123    limit: u64,
124    period: Duration,
125    max_burst: u64,
126    quantity: Option<u64>,
127) -> Result<ThrottleResult, Error> {
128    let mut cmd = Cmd::new();
129    cmd.arg("CL.THROTTLE")
130        .arg(key)
131        .arg(max_burst)
132        .arg(limit)
133        .arg(period.as_secs())
134        .arg(quantity.unwrap_or(1));
135    let result = conn.query(cmd).await?;
136    let result = <Vec<i64> as FromRedisValue>::from_redis_value(&result)?;
137
138    Ok(ThrottleResult {
139        throttled: result[0] != 0,
140        limit: result[1] as u64,
141        remaining: result[2] as u64,
142        retry_after: match result[3] {
143            n if n <= 0 => None,
144            n => Some(Duration::from_secs(n as u64)),
145        },
146        reset_after: Duration::from_secs(result[4].max(0) as u64),
147    })
148}
149
150async fn redis_script_throttle(
151    conn: &RedisConnection,
152    key: &str,
153    limit: u64,
154    period: Duration,
155    max_burst: u64,
156    quantity: Option<u64>,
157) -> Result<ThrottleResult, Error> {
158    let mut script = GCRA_SCRIPT.prepare_invoke();
159    script
160        .key(key)
161        .arg(limit)
162        .arg(period.as_secs())
163        .arg(max_burst)
164        .arg(quantity.unwrap_or(1));
165
166    let result = conn
167        .invoke_script(script)
168        .await
169        .context("error invoking redis GCRA script")?;
170    let result =
171        <(u64, u64, u64, u64, String, String) as FromRedisValue>::from_redis_value(&result)?;
172
173    Ok(ThrottleResult {
174        throttled: result.0 == 1,
175        limit: max_burst + 1,
176        remaining: result.1,
177        retry_after: match result.3 {
178            0 => None,
179            n => Some(Duration::from_secs(n)),
180        },
181        reset_after: Duration::from_secs(result.2),
182    })
183}
184
185/// It is very important for `key` to be used with the same `limit`,
186/// `period` and `max_burst` values in order to produce meaningful
187/// results.
188///
189/// This interface cannot detect or report that kind of misuse.
190/// It is recommended that those parameters be encoded into the
191/// key to make it impossible to misuse.
192///
193/// * `limit` - specifies the maximum number of tokens allow
194///             over the specified `period`
195/// * `period` - the time period over which `limit` is allowed.
196/// * `max_burst` - the maximum initial burst that will be permitted.
197///                 set this smaller than `limit` to prevent using
198///                 up the entire budget immediately and force it
199///                 to spread out across time.
200/// * `quantity` - how many tokens to add to the throttle. If omitted,
201///                1 token is added.
202/// * `force_local` - if true, always use the in-memory store on the local
203///                   machine even if the redis backend has been configured.
204pub async fn throttle(
205    key: &str,
206    limit: u64,
207    period: Duration,
208    max_burst: u64,
209    quantity: Option<u64>,
210    force_local: bool,
211) -> Result<ThrottleResult, Error> {
212    match (force_local, REDIS.get()) {
213        (false, Some(cx)) => {
214            if cx.has_redis_cell {
215                redis_cell_throttle(&cx, key, limit, period, max_burst, quantity).await
216            } else {
217                redis_script_throttle(&cx, key, limit, period, max_burst, quantity).await
218            }
219        }
220        _ => local_throttle(key, limit, period, max_burst, quantity),
221    }
222}
223
#[cfg(test)]
mod test {
    use super::*;
    use crate::redis::RedisContext;
    use mod_redis::test::RedisServer;

    /// Uniform front end over the throttle backends so that the same test
    /// body can be run against each of them.
    trait Throttler {
        async fn throttle(
            &self,
            key: &str,
            limit: u64,
            period: Duration,
            max_burst: u64,
            quantity: Option<u64>,
        ) -> Result<ThrottleResult, Error>;
    }

    /// In-memory backend. Note that `local_throttle` always operates on the
    /// global `MEMORY` store, regardless of which mutex `self` refers to.
    impl Throttler for Mutex<MemoryStore> {
        async fn throttle(
            &self,
            key: &str,
            limit: u64,
            period: Duration,
            max_burst: u64,
            quantity: Option<u64>,
        ) -> Result<ThrottleResult, Error> {
            local_throttle(key, limit, period, max_burst, quantity)
        }
    }

    /// Redis connection driven through `redis_cell_throttle`.
    struct RedisWithCell(RedisConnection);

    impl Throttler for RedisWithCell {
        async fn throttle(
            &self,
            key: &str,
            limit: u64,
            period: Duration,
            max_burst: u64,
            quantity: Option<u64>,
        ) -> Result<ThrottleResult, Error> {
            redis_cell_throttle(&self.0, key, limit, period, max_burst, quantity).await
        }
    }

    /// Redis connection driven through `redis_script_throttle`.
    struct VanillaRedis(RedisConnection);

    impl Throttler for VanillaRedis {
        async fn throttle(
            &self,
            key: &str,
            limit: u64,
            period: Duration,
            max_burst: u64,
            quantity: Option<u64>,
        ) -> Result<ThrottleResult, Error> {
            redis_script_throttle(&self.0, key, limit, period, max_burst, quantity).await
        }
    }

    /// Issues up to `limit * 2` single-token requests over a 60 second
    /// period and asserts that the first throttled iteration lies within
    /// `permitted_tolerance * max_burst` of `max_burst`.
    ///
    /// `max_burst` defaults to `limit` when not specified. The key embeds
    /// both `limit` and `max_burst`, so two tests that share a store must
    /// not use the same pair of values.
    async fn test_big_limits(
        limit: u64,
        max_burst: Option<u64>,
        permitted_tolerance: f64,
        throttler: &impl Throttler,
    ) {
        let period = Duration::from_secs(60);
        let max_burst = max_burst.unwrap_or(limit);
        let key = format!("test_big_limits-{limit}-{max_burst}");

        let mut throttled_iter = None;

        for i in 0..limit * 2 {
            let result = throttler
                .throttle(&key, limit, period, max_burst, None)
                .await
                .unwrap();
            if result.throttled {
                println!("iter: {i} -> {result:?}");
                throttled_iter.replace(i);
                break;
            }
        }

        let throttled_iter = throttled_iter.expect("to hit the throttle limit");
        let diff = ((max_burst as f64) - (throttled_iter as f64)).abs();
        let tolerance = (max_burst as f64) * permitted_tolerance;
        println!(
            "throttled after {throttled_iter} iterations for \
                 limit {limit}. diff={diff}. tolerance {tolerance}"
        );
        let max_rate = Rate::per_period(limit as i64, time::Duration::try_from(period).unwrap());
        println!("max_rate: {max_rate:?}");

        assert!(
            diff <= tolerance,
            "throttled after {throttled_iter} iterations for \
                limit {limit}. diff={diff} is not within tolerance {tolerance}"
        );
    }

    #[tokio::test]
    async fn basic_throttle_100() {
        test_big_limits(100, None, 0.01, &*MEMORY).await;
    }

    #[tokio::test]
    async fn basic_throttle_1_000() {
        test_big_limits(1_000, Some(100), 0.02, &*MEMORY).await;
    }

    #[tokio::test]
    async fn basic_throttle_6_000() {
        test_big_limits(6_000, Some(100), 0.02, &*MEMORY).await;
    }

    #[tokio::test]
    async fn basic_throttle_60_000() {
        test_big_limits(60_000, Some(100), 0.1, &*MEMORY).await;
    }

    #[tokio::test]
    async fn basic_throttle_60_000_burst_30k() {
        // Note that the 10% tolerance here is the same as the basic_throttle_60_000
        // test case because the variance is due to timing issues with very small
        // time periods produced by the overall limit, rather than the burst.
        //
        // The burst must differ from basic_throttle_60_000: the key is derived
        // from (limit, max_burst) and both tests share the MEMORY store, so
        // identical parameters would have the two tests draining one bucket.
        test_big_limits(60_000, Some(30_000), 0.1, &*MEMORY).await;
    }

    #[tokio::test]
    async fn redis_cell_throttle_1_000() {
        // Skip when no redis server is available to spawn.
        if !RedisServer::is_available() {
            return;
        }

        let redis = RedisServer::spawn("").await.unwrap();
        let conn = redis.connection().await.unwrap();
        let cx = RedisContext::try_from(conn).await.unwrap();
        // Skip when the context reports that redis-cell is not available.
        if !cx.has_redis_cell {
            return;
        }

        test_big_limits(1_000, None, 0.02, &RedisWithCell(cx.connection)).await;
    }

    #[tokio::test]
    async fn redis_script_throttle_1_000() {
        // Skip when no redis server is available to spawn.
        if !RedisServer::is_available() {
            return;
        }

        let redis = RedisServer::spawn("").await.unwrap();
        let conn = redis.connection().await.unwrap();
        let cx = RedisContext::try_from(conn).await.unwrap();
        test_big_limits(1_000, None, 0.2, &VanillaRedis(cx.connection)).await;
    }
}