1use crate::{Error, ThrottleResult, REDIS};
2use anyhow::Context;
3use mod_redis::{Cmd, FromRedisValue, RedisConnection, Script};
4use redis_cell_impl::{time, MemoryStore, Rate, RateLimiter, RateQuota};
5use std::sync::{LazyLock, Mutex};
6use std::time::Duration;
7
8static MEMORY: LazyLock<Mutex<MemoryStore>> = LazyLock::new(|| Mutex::new(MemoryStore::new()));
9
10static GCRA_SCRIPT: LazyLock<Script> = LazyLock::new(|| {
12 Script::new(
13 r#"
14local key = KEYS[1]
15local limit = ARGV[1]
16local period = ARGV[2]
17local max_burst = ARGV[3]
18local quantity = ARGV[4]
19
20local interval = period / limit
21local increment = interval * quantity
22local burst_offset = interval * max_burst
23
24local now = tonumber(redis.call("TIME")[1])
25local tat = redis.call("GET", key)
26
27if not tat then
28 tat = now
29else
30 tat = tonumber(tat)
31end
32tat = math.max(tat, now)
33
34local new_tat = tat + increment
35local allow_at = new_tat - burst_offset
36local diff = now - allow_at
37
38local throttled
39local reset_after
40local retry_after
41
42local remaining = math.floor(diff / interval) -- poor man's round
43
44if remaining < 0 then
45 throttled = 1
46 -- calculate how many tokens there actually are, since
47 -- remaining is how many there would have been if we had been able to limit
48 -- and we did not limit
49 remaining = math.floor((now - (tat - burst_offset)) / interval)
50 reset_after = math.ceil(tat - now)
51 retry_after = math.ceil(diff * -1)
52elseif remaining == 0 and increment <= 0 then
53 -- request with cost of 0
54 -- cost of 0 with remaining 0 is still limited
55 throttled = 1
56 remaining = 0
57 reset_after = math.ceil(tat - now)
58 retry_after = 0 -- retry_after is meaningless when quantity is 0
59else
60 throttled = 0
61 reset_after = math.ceil(new_tat - now)
62 retry_after = 0
63 redis.call("SET", key, new_tat, "EX", reset_after)
64end
65
66return {throttled, remaining, reset_after, retry_after, tostring(diff), tostring(interval)}
67"#,
68 )
69});
70
71fn local_throttle(
72 key: &str,
73 limit: u64,
74 period: Duration,
75 max_burst: u64,
76 quantity: Option<u64>,
77) -> Result<ThrottleResult, Error> {
78 let mut store = MEMORY.lock().unwrap();
79 let max_rate = Rate::per_period(
80 limit as i64,
81 time::Duration::try_from(period).map_err(|err| Error::Generic(format!("{err:#}")))?,
82 );
83 let mut limiter = RateLimiter::new(
84 &mut *store,
85 &RateQuota {
86 max_burst: max_burst.min(limit - 1) as i64,
87 max_rate,
88 },
89 );
90 let quantity = quantity.unwrap_or(1) as i64;
91 let (throttled, rate_limit_result) = limiter
92 .rate_limit(key, quantity)
93 .map_err(|err| Error::Generic(format!("{err:#}")))?;
94
95 let mut retry_after = rate_limit_result.retry_after.whole_seconds();
99 if rate_limit_result.retry_after.subsec_milliseconds() > 0 {
100 retry_after += 1
101 }
102 let mut reset_after = rate_limit_result.reset_after.whole_seconds();
103 if rate_limit_result.reset_after.subsec_milliseconds() > 0 {
104 reset_after += 1
105 }
106
107 Ok(ThrottleResult {
108 throttled,
109 limit: rate_limit_result.limit as u64,
110 remaining: rate_limit_result.remaining as u64,
111 reset_after: Duration::from_secs(reset_after.max(0) as u64),
112 retry_after: if retry_after <= 0 {
113 None
114 } else {
115 Some(Duration::from_secs(retry_after.max(0) as u64))
116 },
117 })
118}
119
120async fn redis_cell_throttle(
121 conn: &RedisConnection,
122 key: &str,
123 limit: u64,
124 period: Duration,
125 max_burst: u64,
126 quantity: Option<u64>,
127) -> Result<ThrottleResult, Error> {
128 let mut cmd = Cmd::new();
129 cmd.arg("CL.THROTTLE")
130 .arg(key)
131 .arg(max_burst)
132 .arg(limit)
133 .arg(period.as_secs())
134 .arg(quantity.unwrap_or(1));
135 let result = conn.query(cmd).await?;
136 let result = <Vec<i64> as FromRedisValue>::from_redis_value(&result)?;
137
138 Ok(ThrottleResult {
139 throttled: result[0] != 0,
140 limit: result[1] as u64,
141 remaining: result[2] as u64,
142 retry_after: match result[3] {
143 n if n <= 0 => None,
144 n => Some(Duration::from_secs(n as u64)),
145 },
146 reset_after: Duration::from_secs(result[4].max(0) as u64),
147 })
148}
149
150async fn redis_script_throttle(
151 conn: &RedisConnection,
152 key: &str,
153 limit: u64,
154 period: Duration,
155 max_burst: u64,
156 quantity: Option<u64>,
157) -> Result<ThrottleResult, Error> {
158 let mut script = GCRA_SCRIPT.prepare_invoke();
159 script
160 .key(key)
161 .arg(limit)
162 .arg(period.as_secs())
163 .arg(max_burst)
164 .arg(quantity.unwrap_or(1));
165
166 let result = conn
167 .invoke_script(script)
168 .await
169 .context("error invoking redis GCRA script")?;
170 let result =
171 <(u64, u64, u64, u64, String, String) as FromRedisValue>::from_redis_value(&result)?;
172
173 Ok(ThrottleResult {
174 throttled: result.0 == 1,
175 limit: max_burst + 1,
176 remaining: result.1,
177 retry_after: match result.3 {
178 0 => None,
179 n => Some(Duration::from_secs(n)),
180 },
181 reset_after: Duration::from_secs(result.2),
182 })
183}
184
185pub async fn throttle(
205 key: &str,
206 limit: u64,
207 period: Duration,
208 max_burst: u64,
209 quantity: Option<u64>,
210 force_local: bool,
211) -> Result<ThrottleResult, Error> {
212 match (force_local, REDIS.get()) {
213 (false, Some(cx)) => match cx.has_redis_cell {
214 true => redis_cell_throttle(&cx, key, limit, period, max_burst, quantity).await,
215 false => redis_script_throttle(&cx, key, limit, period, max_burst, quantity).await,
216 },
217 _ => local_throttle(key, limit, period, max_burst, quantity),
218 }
219}
220
#[cfg(test)]
mod test {
    use super::*;
    use crate::redis::RedisContext;
    use mod_redis::test::RedisServer;

    /// Common interface over the throttle backends so the same scenario can
    /// be exercised against each of them.
    trait Throttler {
        async fn throttle(
            &self,
            key: &str,
            limit: u64,
            period: Duration,
            max_burst: u64,
            quantity: Option<u64>,
        ) -> Result<ThrottleResult, Error>;
    }

    /// In-memory backend. `self` is unused: `local_throttle` always operates
    /// on the global `MEMORY` store.
    impl Throttler for Mutex<MemoryStore> {
        async fn throttle(
            &self,
            key: &str,
            limit: u64,
            period: Duration,
            max_burst: u64,
            quantity: Option<u64>,
        ) -> Result<ThrottleResult, Error> {
            local_throttle(key, limit, period, max_burst, quantity)
        }
    }

    /// Redis backend using the redis-cell `CL.THROTTLE` command.
    struct RedisWithCell(RedisConnection);

    impl Throttler for RedisWithCell {
        async fn throttle(
            &self,
            key: &str,
            limit: u64,
            period: Duration,
            max_burst: u64,
            quantity: Option<u64>,
        ) -> Result<ThrottleResult, Error> {
            redis_cell_throttle(&self.0, key, limit, period, max_burst, quantity).await
        }
    }

    /// Redis backend using the GCRA Lua script.
    struct VanillaRedis(RedisConnection);

    impl Throttler for VanillaRedis {
        async fn throttle(
            &self,
            key: &str,
            limit: u64,
            period: Duration,
            max_burst: u64,
            quantity: Option<u64>,
        ) -> Result<ThrottleResult, Error> {
            redis_script_throttle(&self.0, key, limit, period, max_burst, quantity).await
        }
    }

    /// Issues up to `2 * limit` single-unit requests (60s period) and asserts
    /// that the first throttled request index is within
    /// `permitted_tolerance * max_burst` of `max_burst` (which defaults to
    /// `limit`).
    ///
    /// The key is derived only from `limit` and `max_burst`, and the
    /// in-memory backend is a single global store shared by concurrently
    /// running tests, so each test must use a distinct `(limit, max_burst)`
    /// pair.
    async fn test_big_limits(
        limit: u64,
        max_burst: Option<u64>,
        permitted_tolerance: f64,
        throttler: &impl Throttler,
    ) {
        let period = Duration::from_secs(60);
        let max_burst = max_burst.unwrap_or(limit);
        let key = format!("test_big_limits-{limit}-{max_burst}");

        let mut throttled_iter = None;

        for i in 0..limit * 2 {
            let result = throttler
                .throttle(&key, limit, period, max_burst, None)
                .await
                .unwrap();
            if result.throttled {
                println!("iter: {i} -> {result:?}");
                throttled_iter = Some(i);
                break;
            }
        }

        let throttled_iter = throttled_iter.expect("to hit the throttle limit");
        let diff = ((max_burst as f64) - (throttled_iter as f64)).abs();
        let tolerance = (max_burst as f64) * permitted_tolerance;
        println!(
            "throttled after {throttled_iter} iterations for \
             limit {limit}. diff={diff}. tolerance {tolerance}"
        );
        let max_rate = Rate::per_period(limit as i64, time::Duration::try_from(period).unwrap());
        println!("max_rate: {max_rate:?}");

        assert!(
            diff <= tolerance,
            "throttled after {throttled_iter} iterations for \
             limit {limit}. diff={diff} is not within tolerance {tolerance}"
        );
    }

    #[tokio::test]
    async fn basic_throttle_100() {
        test_big_limits(100, None, 0.01, &*MEMORY).await;
    }

    #[tokio::test]
    async fn basic_throttle_1_000() {
        test_big_limits(1_000, Some(100), 0.02, &*MEMORY).await;
    }

    #[tokio::test]
    async fn basic_throttle_6_000() {
        test_big_limits(6_000, Some(100), 0.02, &*MEMORY).await;
    }

    #[tokio::test]
    async fn basic_throttle_60_000() {
        test_big_limits(60_000, Some(100), 0.1, &*MEMORY).await;
    }

    #[tokio::test]
    async fn basic_throttle_60_000_burst_30k() {
        // Must differ from basic_throttle_60_000: identical parameters would
        // produce an identical key in the shared in-memory store, and the two
        // concurrently running tests would consume each other's capacity.
        test_big_limits(60_000, Some(30_000), 0.1, &*MEMORY).await;
    }

    #[tokio::test]
    async fn redis_cell_throttle_1_000() {
        if !RedisServer::is_available() {
            return;
        }

        let redis = RedisServer::spawn("").await.unwrap();
        let conn = redis.connection().await.unwrap();
        let cx = RedisContext::try_from(conn).await.unwrap();
        if !cx.has_redis_cell {
            return;
        }

        test_big_limits(1_000, None, 0.02, &RedisWithCell(cx.connection)).await;
    }

    #[tokio::test]
    async fn redis_script_throttle_1_000() {
        if !RedisServer::is_available() {
            return;
        }

        let redis = RedisServer::spawn("").await.unwrap();
        let conn = redis.connection().await.unwrap();
        let cx = RedisContext::try_from(conn).await.unwrap();
        test_big_limits(1_000, None, 0.2, &VanillaRedis(cx.connection)).await;
    }
}
377}