1use crate::{Error, ThrottleResult, REDIS};
2use anyhow::Context;
3use mod_redis::{Cmd, FromRedisValue, RedisConnection, Script};
4use redis_cell_impl::{time, MemoryStore, Rate, RateLimiter, RateQuota};
5use std::sync::{LazyLock, Mutex};
6use std::time::Duration;
7
8static MEMORY: LazyLock<Mutex<MemoryStore>> = LazyLock::new(|| Mutex::new(MemoryStore::new()));
9
10static GCRA_SCRIPT: LazyLock<Script> = LazyLock::new(|| {
12 Script::new(
13 r#"
14local key = KEYS[1]
15local limit = ARGV[1]
16local period = ARGV[2]
17local max_burst = ARGV[3]
18local quantity = ARGV[4]
19
20local interval = period / limit
21local increment = interval * quantity
22local burst_offset = interval * max_burst
23
24local now = tonumber(redis.call("TIME")[1])
25local tat = redis.call("GET", key)
26
27if not tat then
28 tat = now
29else
30 tat = tonumber(tat)
31end
32tat = math.max(tat, now)
33
34local new_tat = tat + increment
35local allow_at = new_tat - burst_offset
36local diff = now - allow_at
37
38local throttled
39local reset_after
40local retry_after
41
42local remaining = math.floor(diff / interval) -- poor man's round
43
44if remaining < 0 then
45 throttled = 1
46 -- calculate how many tokens there actually are, since
47 -- remaining is how many there would have been if we had been able to limit
48 -- and we did not limit
49 remaining = math.floor((now - (tat - burst_offset)) / interval)
50 reset_after = math.ceil(tat - now)
51 retry_after = math.ceil(diff * -1)
52elseif remaining == 0 and increment <= 0 then
53 -- request with cost of 0
54 -- cost of 0 with remaining 0 is still limited
55 throttled = 1
56 remaining = 0
57 reset_after = math.ceil(tat - now)
58 retry_after = 0 -- retry_after is meaningless when quantity is 0
59else
60 throttled = 0
61 reset_after = math.ceil(new_tat - now)
62 retry_after = 0
63 redis.call("SET", key, new_tat, "EX", reset_after)
64end
65
66return {throttled, remaining, reset_after, retry_after, tostring(diff), tostring(interval)}
67"#,
68 )
69});
70
71fn local_throttle(
72 key: &str,
73 limit: u64,
74 period: Duration,
75 max_burst: u64,
76 quantity: Option<u64>,
77) -> Result<ThrottleResult, Error> {
78 let mut store = MEMORY.lock().unwrap();
79 let max_rate = Rate::per_period(
80 limit as i64,
81 time::Duration::try_from(period).map_err(|err| Error::Generic(format!("{err:#}")))?,
82 );
83 let mut limiter = RateLimiter::new(
84 &mut *store,
85 &RateQuota {
86 max_burst: max_burst.min(limit - 1) as i64,
87 max_rate,
88 },
89 );
90 let quantity = quantity.unwrap_or(1) as i64;
91 let (throttled, rate_limit_result) = limiter
92 .rate_limit(key, quantity)
93 .map_err(|err| Error::Generic(format!("{err:#}")))?;
94
95 let mut retry_after = rate_limit_result.retry_after.whole_seconds();
99 if rate_limit_result.retry_after.subsec_milliseconds() > 0 {
100 retry_after += 1
101 }
102 let mut reset_after = rate_limit_result.reset_after.whole_seconds();
103 if rate_limit_result.reset_after.subsec_milliseconds() > 0 {
104 reset_after += 1
105 }
106
107 Ok(ThrottleResult {
108 throttled,
109 limit: rate_limit_result.limit as u64,
110 remaining: rate_limit_result.remaining as u64,
111 reset_after: Duration::from_secs(reset_after.max(0) as u64),
112 retry_after: if retry_after <= 0 {
113 None
114 } else {
115 Some(Duration::from_secs(retry_after.max(0) as u64))
116 },
117 })
118}
119
120async fn redis_cell_throttle(
121 conn: &RedisConnection,
122 key: &str,
123 limit: u64,
124 period: Duration,
125 max_burst: u64,
126 quantity: Option<u64>,
127) -> Result<ThrottleResult, Error> {
128 let mut cmd = Cmd::new();
129 cmd.arg("CL.THROTTLE")
130 .arg(key)
131 .arg(max_burst)
132 .arg(limit)
133 .arg(period.as_secs())
134 .arg(quantity.unwrap_or(1));
135 let result = conn.query(cmd).await?;
136 let result = <Vec<i64> as FromRedisValue>::from_redis_value(&result)?;
137
138 Ok(ThrottleResult {
139 throttled: result[0] != 0,
140 limit: result[1] as u64,
141 remaining: result[2] as u64,
142 retry_after: match result[3] {
143 n if n <= 0 => None,
144 n => Some(Duration::from_secs(n as u64)),
145 },
146 reset_after: Duration::from_secs(result[4].max(0) as u64),
147 })
148}
149
150async fn redis_script_throttle(
151 conn: &RedisConnection,
152 key: &str,
153 limit: u64,
154 period: Duration,
155 max_burst: u64,
156 quantity: Option<u64>,
157) -> Result<ThrottleResult, Error> {
158 let mut script = GCRA_SCRIPT.prepare_invoke();
159 script
160 .key(key)
161 .arg(limit)
162 .arg(period.as_secs())
163 .arg(max_burst)
164 .arg(quantity.unwrap_or(1));
165
166 let result = conn
167 .invoke_script(script)
168 .await
169 .context("error invoking redis GCRA script")?;
170 let result =
171 <(u64, u64, u64, u64, String, String) as FromRedisValue>::from_redis_value(&result)?;
172
173 Ok(ThrottleResult {
174 throttled: result.0 == 1,
175 limit: max_burst + 1,
176 remaining: result.1,
177 retry_after: match result.3 {
178 0 => None,
179 n => Some(Duration::from_secs(n)),
180 },
181 reset_after: Duration::from_secs(result.2),
182 })
183}
184
185pub async fn throttle(
205 key: &str,
206 limit: u64,
207 period: Duration,
208 max_burst: u64,
209 quantity: Option<u64>,
210 force_local: bool,
211) -> Result<ThrottleResult, Error> {
212 match (force_local, REDIS.get()) {
213 (false, Some(cx)) => {
214 if cx.has_redis_cell {
215 redis_cell_throttle(&cx, key, limit, period, max_burst, quantity).await
216 } else {
217 redis_script_throttle(&cx, key, limit, period, max_burst, quantity).await
218 }
219 }
220 _ => local_throttle(key, limit, period, max_burst, quantity),
221 }
222}
223
#[cfg(test)]
mod test {
    use super::*;
    use crate::redis::RedisContext;
    use mod_redis::test::RedisServer;

    /// Common interface over the throttle backends so the same scenario can
    /// be driven against each of them.
    trait Throttler {
        async fn throttle(
            &self,
            key: &str,
            limit: u64,
            period: Duration,
            max_burst: u64,
            quantity: Option<u64>,
        ) -> Result<ThrottleResult, Error>;
    }

    /// In-process backend. The receiver is unused: `local_throttle` always
    /// operates on the global `MEMORY` store.
    impl Throttler for Mutex<MemoryStore> {
        async fn throttle(
            &self,
            key: &str,
            limit: u64,
            period: Duration,
            max_burst: u64,
            quantity: Option<u64>,
        ) -> Result<ThrottleResult, Error> {
            local_throttle(key, limit, period, max_burst, quantity)
        }
    }

    /// Backend that issues `CL.THROTTLE` (requires the redis-cell module).
    struct RedisWithCell(RedisConnection);

    impl Throttler for RedisWithCell {
        async fn throttle(
            &self,
            key: &str,
            limit: u64,
            period: Duration,
            max_burst: u64,
            quantity: Option<u64>,
        ) -> Result<ThrottleResult, Error> {
            redis_cell_throttle(&self.0, key, limit, period, max_burst, quantity).await
        }
    }

    /// Backend that runs the embedded GCRA Lua script.
    struct VanillaRedis(RedisConnection);

    impl Throttler for VanillaRedis {
        async fn throttle(
            &self,
            key: &str,
            limit: u64,
            period: Duration,
            max_burst: u64,
            quantity: Option<u64>,
        ) -> Result<ThrottleResult, Error> {
            redis_script_throttle(&self.0, key, limit, period, max_burst, quantity).await
        }
    }

    /// Issue single-unit requests (at most `2 * limit`) until one is
    /// throttled, then assert that the iteration at which that happened is
    /// within `permitted_tolerance * max_burst` of `max_burst`.
    /// `max_burst` defaults to `limit` when `None`.
    ///
    /// The key is derived from `(limit, max_burst)`. Tests using the global
    /// `MEMORY` store run concurrently, so each must use a distinct pair or
    /// they will drain each other's bucket.
    async fn test_big_limits(
        limit: u64,
        max_burst: Option<u64>,
        permitted_tolerance: f64,
        throttler: &impl Throttler,
    ) {
        let period = Duration::from_secs(60);
        let max_burst = max_burst.unwrap_or(limit);
        let key = format!("test_big_limits-{limit}-{max_burst}");

        let mut throttled_iter = None;

        for i in 0..limit * 2 {
            let result = throttler
                .throttle(&key, limit, period, max_burst, None)
                .await
                .unwrap();
            if result.throttled {
                println!("iter: {i} -> {result:?}");
                throttled_iter.replace(i);
                break;
            }
        }

        let throttled_iter = throttled_iter.expect("to hit the throttle limit");
        let diff = ((max_burst as f64) - (throttled_iter as f64)).abs();
        let tolerance = (max_burst as f64) * permitted_tolerance;
        println!(
            "throttled after {throttled_iter} iterations for \
             limit {limit}. diff={diff}. tolerance {tolerance}"
        );
        let max_rate = Rate::per_period(limit as i64, time::Duration::try_from(period).unwrap());
        println!("max_rate: {max_rate:?}");

        assert!(
            diff <= tolerance,
            "throttled after {throttled_iter} iterations for \
             limit {limit}. diff={diff} is not within tolerance {tolerance}"
        );
    }

    #[tokio::test]
    async fn basic_throttle_100() {
        test_big_limits(100, None, 0.01, &*MEMORY).await;
    }

    #[tokio::test]
    async fn basic_throttle_1_000() {
        test_big_limits(1_000, Some(100), 0.02, &*MEMORY).await;
    }

    #[tokio::test]
    async fn basic_throttle_6_000() {
        test_big_limits(6_000, Some(100), 0.02, &*MEMORY).await;
    }

    #[tokio::test]
    async fn basic_throttle_60_000() {
        test_big_limits(60_000, Some(100), 0.1, &*MEMORY).await;
    }

    #[tokio::test]
    async fn basic_throttle_60_000_burst_30k() {
        // A burst of 30k (as the name says) also gives this test a key in the
        // shared MEMORY store that is distinct from basic_throttle_60_000's.
        test_big_limits(60_000, Some(30_000), 0.1, &*MEMORY).await;
    }

    #[tokio::test]
    async fn redis_cell_throttle_1_000() {
        if !RedisServer::is_available() {
            return;
        }

        let redis = RedisServer::spawn("").await.unwrap();
        let conn = redis.connection().await.unwrap();
        let cx = RedisContext::try_from(conn).await.unwrap();
        if !cx.has_redis_cell {
            return;
        }

        test_big_limits(1_000, None, 0.02, &RedisWithCell(cx.connection)).await;
    }

    #[tokio::test]
    async fn redis_script_throttle_1_000() {
        if !RedisServer::is_available() {
            return;
        }

        let redis = RedisServer::spawn("").await.unwrap();
        let conn = redis.connection().await.unwrap();
        let cx = RedisContext::try_from(conn).await.unwrap();
        test_big_limits(1_000, None, 0.2, &VanillaRedis(cx.connection)).await;
    }
}