1use anyhow::Context;
2use config::{any_err, from_lua_value, get_or_create_module};
3use deadpool::managed::{Manager, Metrics, Pool, RecycleError, RecycleResult};
4use kumo_prometheus::declare_metric;
5use mlua::{Lua, MultiValue, UserData, UserDataMethods, Value};
6use redis::aio::{ConnectionLike, ConnectionManager, ConnectionManagerConfig};
7use redis::cluster::ClusterClient;
8use redis::cluster_async::ClusterConnection;
9pub use redis::{
10 cmd, Cmd, FromRedisValue, RedisError, Script, ScriptInvocation, Value as RedisValue,
11};
12use redis::{
13 Client, ConnectionInfo, IntoConnectionInfo, Pipeline, RedisFuture, RedisWrite, ToRedisArgs,
14};
15use serde::Deserialize;
16use serde_json::Value as JsonValue;
17use std::collections::HashMap;
18use std::future::Future;
19use std::sync::{Arc, LazyLock, Mutex};
20use std::time::{Duration, Instant};
21
22pub mod test;
23
/// Process-wide cache of connection pools, keyed by connection parameters.
///
/// `RedisConnKey::get_pool` consults this map so that keys which compare
/// equal share a single pool instead of building a new one per lookup.
static POOLS: LazyLock<Mutex<HashMap<RedisConnKey, Pool<ClientManager>>>> =
    LazyLock::new(Mutex::default);
26
/// deadpool [`Manager`] implementation that creates and health-checks
/// pooled redis connections using the wrapped [`ClientWrapper`].
pub struct ClientManager(ClientWrapper);
28
29impl Manager for ClientManager {
30 type Type = ConnectionWrapper;
31 type Error = anyhow::Error;
32
33 async fn create(&self) -> Result<Self::Type, Self::Error> {
34 let c = self.0.connect().await?;
35 Ok(c)
36 }
37
38 async fn recycle(
39 &self,
40 conn: &mut Self::Type,
41 _metrics: &Metrics,
42 ) -> RecycleResult<anyhow::Error> {
43 conn.ping()
44 .await
45 .map_err(|err| RecycleError::message(format!("{err:#}")))
46 }
47}
48
// Histogram named "redis_operation_latency" with the labels "service",
// "operation" and "status". `RedisConnection::sample_latency` records the
// elapsed seconds of each operation into it.
declare_metric! {
static REDIS_LATENCY: HistogramVec("redis_operation_latency",
        &["service", "operation", "status"]);
}
76
/// Shared state behind a [`RedisConnection`]: the connection parameters
/// together with the label computed from them once at open time.
#[derive(Debug)]
struct KeyAndLabel {
    /// Connection parameters; used to look up the shared pool.
    key: RedisConnKey,
    /// Result of `RedisConnKey::hash_label`, used as the first ("service")
    /// label value when recording latency metrics.
    label: String,
}
82
/// Cheaply clonable handle to a redis server or cluster.
///
/// The handle itself holds no live connection: each operation looks up the
/// shared pool for its key and checks a connection out of it.
#[derive(Clone, Debug)]
pub struct RedisConnection(Arc<KeyAndLabel>);
85
86impl RedisConnection {
87 async fn sample_latency<T, E>(
88 &self,
89 operation: &str,
90 fut: impl Future<Output = Result<T, E>>,
91 ) -> Result<T, E> {
92 let now = Instant::now();
93 let result = (fut).await;
94 let elapsed = now.elapsed().as_secs_f64();
95 let status = if result.is_ok() { "ok" } else { "error" };
96
97 if let Ok(hist) =
98 REDIS_LATENCY.get_metric_with_label_values(&[self.0.label.as_str(), operation, status])
99 {
100 hist.observe(elapsed);
101 }
102
103 result
104 }
105
106 pub async fn ping(&self) -> anyhow::Result<()> {
107 self.sample_latency("ping", async {
108 let pool = self.0.key.get_pool()?;
109 let mut conn = pool.get().await.map_err(|err| anyhow::anyhow!("{err:#}"))?;
110 conn.ping().await
111 })
112 .await
113 }
114
115 pub async fn query(&self, cmd: Cmd) -> anyhow::Result<RedisValue> {
116 self.sample_latency("query", async {
117 let pool = self.0.key.get_pool()?;
118 let mut conn = pool.get().await.map_err(|err| anyhow::anyhow!("{err:#}"))?;
119 Ok(cmd.query_async(&mut *conn).await?)
120 })
121 .await
122 }
123
124 pub async fn invoke_script(
125 &self,
126 script: ScriptInvocation<'static>,
127 ) -> anyhow::Result<RedisValue> {
128 self.sample_latency("script", async {
129 let pool = self.0.key.get_pool()?;
130 let mut conn = pool.get().await.map_err(|err| anyhow::anyhow!("{err:#}"))?;
131 Ok(script.invoke_async(&mut *conn).await?)
132 })
133 .await
134 }
135}
136
137fn redis_value_to_lua(lua: &Lua, value: RedisValue) -> mlua::Result<Value> {
138 Ok(match value {
139 RedisValue::Nil => Value::Nil,
140 RedisValue::Int(i) => Value::Integer(i),
141 RedisValue::Boolean(i) => Value::Boolean(i),
142 RedisValue::BigNumber(i) => Value::String(lua.create_string(i.to_string())?),
143 RedisValue::Double(i) => Value::Number(i),
144 RedisValue::BulkString(bytes) => Value::String(lua.create_string(&bytes)?),
145 RedisValue::SimpleString(s) => Value::String(lua.create_string(&s)?),
146 RedisValue::Map(pairs) => {
147 let map = lua.create_table()?;
148 for (k, v) in pairs {
149 let k = redis_value_to_lua(lua, k)?;
150 let v = redis_value_to_lua(lua, v)?;
151 map.set(k, v)?;
152 }
153 Value::Table(map)
154 }
155 RedisValue::Array(values) => {
156 let array = lua.create_table()?;
157 for v in values {
158 array.push(redis_value_to_lua(lua, v)?)?;
159 }
160 Value::Table(array)
161 }
162 RedisValue::Set(values) => {
163 let array = lua.create_table()?;
164 for v in values {
165 array.push(redis_value_to_lua(lua, v)?)?;
166 }
167 Value::Table(array)
168 }
169 RedisValue::Attribute { data, attributes } => {
170 let map = lua.create_table()?;
171 for (k, v) in attributes {
172 let k = redis_value_to_lua(lua, k)?;
173 let v = redis_value_to_lua(lua, v)?;
174 map.set(k, v)?;
175 }
176
177 let attribute = lua.create_table()?;
178 attribute.set("data", redis_value_to_lua(lua, *data)?)?;
179 attribute.set("attributes", map)?;
180
181 Value::Table(attribute)
182 }
183 RedisValue::VerbatimString { format, text } => {
184 let vstr = lua.create_table()?;
185 vstr.set("format", format.to_string())?;
186 vstr.set("text", text)?;
187 Value::Table(vstr)
188 }
189 RedisValue::ServerError(_) => {
190 return Err(value
191 .extract_error()
192 .map_err(mlua::Error::external)
193 .unwrap_err());
194 }
195 RedisValue::Okay => Value::Boolean(true),
196 RedisValue::Push { kind, data } => {
197 let array = lua.create_table()?;
198 for v in data {
199 let v = redis_value_to_lua(lua, v)?;
200 array.push(v)?;
201 }
202
203 let push = lua.create_table()?;
204 push.set("data", array)?;
205 push.set("kind", kind.to_string())?;
206
207 Value::Table(push)
208 }
209 })
210}
211
212impl UserData for RedisConnection {
213 fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
214 methods.add_async_method("query", |lua, this, params: MultiValue| async move {
215 let mut args = vec![];
216 for p in params {
217 args.push(from_lua_value(&lua, p)?);
218 }
219 let cmd = build_cmd(args).map_err(any_err)?;
220 let result = this.query(cmd).await.map_err(any_err)?;
221 redis_value_to_lua(&lua, result)
222 });
223 }
224}
225
/// Borrowed JSON value that can be written out as redis command arguments
/// (see its `ToRedisArgs` implementation).
struct RedisJsonValue<'a>(&'a JsonValue);
227
228impl ToRedisArgs for RedisJsonValue<'_> {
229 fn write_redis_args<W>(&self, write: &mut W)
230 where
231 W: ?Sized + RedisWrite,
232 {
233 match self.0 {
234 JsonValue::Null => {}
235 JsonValue::Bool(b) => {
236 b.write_redis_args(write);
237 }
238 JsonValue::Number(n) => n.to_string().write_redis_args(write),
239 JsonValue::String(s) => s.write_redis_args(write),
240 JsonValue::Array(array) => {
241 for item in array {
242 RedisJsonValue(item).write_redis_args(write);
243 }
244 }
245 JsonValue::Object(map) => {
246 for (k, v) in map {
247 k.write_redis_args(write);
248 RedisJsonValue(v).write_redis_args(write);
249 }
250 }
251 }
252 }
253
254 fn num_of_args(&self) -> usize {
255 match self.0 {
256 JsonValue::Array(array) => array.len(),
257 JsonValue::Null => 1,
258 JsonValue::Object(map) => map.len(),
259 JsonValue::Number(_) | JsonValue::Bool(_) | JsonValue::String(_) => 1,
260 }
261 }
262}
263
264pub fn build_cmd(args: Vec<JsonValue>) -> anyhow::Result<Cmd> {
265 let mut cmd = Cmd::new();
266 for a in args {
267 cmd.arg(RedisJsonValue(&a));
268 }
269 Ok(cmd)
270}
271
/// The server address(es) to connect to.
///
/// Deserialized untagged: a plain string yields `Single`, a list of
/// strings yields `Cluster`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Deserialize)]
#[serde(untagged)]
pub enum NodeSpec {
    /// A single node address.
    Single(String),
    /// A list of node addresses; unless overridden by `RedisConnKey::cluster`
    /// this selects cluster mode.
    Cluster(Vec<String>),
}
280
/// Connection parameters for a redis server or cluster.
///
/// The value also serves as the key of the shared pool cache (hence the
/// `Eq`/`Hash` derives): keys that compare equal share one pool.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Deserialize)]
pub struct RedisConnKey {
    /// The server(s) to connect to.
    pub node: NodeSpec,
    /// When set, `read_from_replicas()` is enabled on the cluster client
    /// builder. `build_client` only consults this in cluster mode.
    #[serde(default)]
    pub read_from_replicas: bool,
    /// Optional username applied to the client.
    #[serde(default)]
    pub username: Option<String>,
    /// Optional password applied to the client.
    #[serde(default)]
    pub password: Option<String>,
    /// Explicitly selects (or disables) cluster mode. When unset, cluster
    /// mode is used exactly when `node` is a list.
    #[serde(default)]
    pub cluster: Option<bool>,
    /// When set, used as the pool's `max_size`; otherwise the pool
    /// builder's own default applies.
    #[serde(default)]
    pub pool_size: Option<usize>,
    /// Applied as the client's connection timeout and as the pool's
    /// `create_timeout`.
    // NOTE(review): the accepted textual format is defined by the
    // `duration_serde` module, which is not visible here.
    #[serde(default, with = "duration_serde")]
    pub connect_timeout: Option<Duration>,
    /// Applied as the pool's `recycle_timeout`.
    #[serde(default, with = "duration_serde")]
    pub recycle_timeout: Option<Duration>,
    /// Applied as the pool's `wait_timeout`.
    #[serde(default, with = "duration_serde")]
    pub wait_timeout: Option<Duration>,
    /// Applied as the client's response timeout.
    #[serde(default, with = "duration_serde")]
    pub response_timeout: Option<Duration>,
}
306
/// A configured, not yet connected redis client, in either single-node or
/// cluster flavour.
pub enum ClientWrapper {
    /// Single-node client plus the configuration used when creating its
    /// `ConnectionManager`.
    Single(Client, ConnectionManagerConfig),
    /// Cluster client.
    Cluster(ClusterClient),
}
311
312impl ClientWrapper {
313 pub async fn connect(&self) -> anyhow::Result<ConnectionWrapper> {
314 match self {
315 Self::Single(client, config) => Ok(ConnectionWrapper::Single(
316 ConnectionManager::new_with_config(client.clone(), config.clone()).await?,
317 )),
318 Self::Cluster(c) => Ok(ConnectionWrapper::Cluster(c.get_async_connection().await?)),
319 }
320 }
321}
322
/// A live redis connection, in either single-node or cluster flavour.
///
/// Implements `ConnectionLike` by delegating to the wrapped connection, so
/// commands can be issued without caring which flavour is in use.
pub enum ConnectionWrapper {
    /// Connection to a single node.
    Single(ConnectionManager),
    /// Connection to a cluster.
    Cluster(ClusterConnection),
}
327
328impl ConnectionWrapper {
329 pub async fn ping(&mut self) -> anyhow::Result<()> {
330 Ok(redis::cmd("PING").query_async(self).await?)
331 }
332}
333
334impl ConnectionLike for ConnectionWrapper {
335 fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, RedisValue> {
337 match self {
338 Self::Single(c) => c.req_packed_command(cmd),
339 Self::Cluster(c) => c.req_packed_command(cmd),
340 }
341 }
342
343 fn req_packed_commands<'a>(
344 &'a mut self,
345 cmd: &'a crate::Pipeline,
346 offset: usize,
347 count: usize,
348 ) -> RedisFuture<'a, Vec<RedisValue>> {
349 match self {
350 Self::Single(c) => c.req_packed_commands(cmd, offset, count),
351 Self::Cluster(c) => c.req_packed_commands(cmd, offset, count),
352 }
353 }
354
355 fn get_db(&self) -> i64 {
356 match self {
357 Self::Single(c) => c.get_db(),
358 Self::Cluster(c) => c.get_db(),
359 }
360 }
361}
362
363impl RedisConnKey {
364 pub fn build_client(&self) -> anyhow::Result<ClientWrapper> {
365 let cluster = self
366 .cluster
367 .unwrap_or(matches!(&self.node, NodeSpec::Cluster(_)));
368 let nodes = match &self.node {
369 NodeSpec::Single(node) => vec![node.to_string()],
370 NodeSpec::Cluster(nodes) => nodes.clone(),
371 };
372
373 if cluster {
374 let mut builder = ClusterClient::builder(nodes);
375 if self.read_from_replicas {
376 builder = builder.read_from_replicas();
377 }
378 if let Some(user) = &self.username {
379 builder = builder.username(user.to_string());
380 }
381 if let Some(pass) = &self.password {
382 builder = builder.password(pass.to_string());
383 }
384 if let Some(duration) = self.connect_timeout {
385 builder = builder.connection_timeout(duration);
386 }
387 if let Some(duration) = self.response_timeout {
388 builder = builder.response_timeout(duration);
389 }
390
391 Ok(ClientWrapper::Cluster(builder.build().with_context(
392 || format!("building redis client {self:?}"),
393 )?))
394 } else {
395 let mut config = ConnectionManagerConfig::new();
396 if let Some(duration) = self.connect_timeout {
397 config = config.set_connection_timeout(duration);
398 }
399 if let Some(duration) = self.response_timeout {
400 config = config.set_response_timeout(duration);
401 }
402
403 let mut info: ConnectionInfo = nodes[0]
404 .as_str()
405 .into_connection_info()
406 .with_context(|| format!("building redis client {self:?}"))?;
407 if let Some(user) = &self.username {
408 info.redis.username.replace(user.to_string());
409 }
410 if let Some(pass) = &self.password {
411 info.redis.password.replace(pass.to_string());
412 }
413
414 Ok(ClientWrapper::Single(
415 Client::open(info).with_context(|| format!("building redis client {self:?}"))?,
416 config,
417 ))
418 }
419 }
420
421 pub fn get_pool(&self) -> anyhow::Result<Pool<ClientManager>> {
422 let mut pools = POOLS.lock().unwrap();
423 if let Some(pool) = pools.get(self) {
424 return Ok(pool.clone());
425 }
426
427 let client = self.build_client()?;
428 let mut builder = Pool::builder(ClientManager(client))
429 .runtime(deadpool::Runtime::Tokio1)
430 .create_timeout(self.connect_timeout)
431 .recycle_timeout(self.recycle_timeout)
432 .wait_timeout(self.wait_timeout);
433
434 if let Some(limit) = self.pool_size {
435 builder = builder.max_size(limit);
436 }
437
438 let pool = builder.build()?;
439
440 pools.insert(self.clone(), pool.clone());
441
442 Ok(pool)
443 }
444
445 pub fn open(&self) -> anyhow::Result<RedisConnection> {
446 self.build_client()?;
447 Ok(RedisConnection(Arc::new(KeyAndLabel {
448 key: self.clone(),
449 label: self.hash_label(),
450 })))
451 }
452
453 pub fn hash_label(&self) -> String {
472 use crc32fast::Hasher;
473 use std::hash::Hash;
474 let mut hasher = Hasher::new();
475 self.hash(&mut hasher);
476 let crc = hasher.finalize();
477
478 let mut label = String::new();
479 if let Some(user) = &self.username {
480 label.push_str(user);
481 label.push('@');
482 }
483 match &self.node {
484 NodeSpec::Single(node) => {
485 label.push_str(node);
486 }
487 NodeSpec::Cluster(nodes) => {
488 for (idx, node) in nodes.iter().enumerate() {
489 if idx > 0 {
490 label.push(',');
491 }
492 label.push_str(node);
493 }
494 }
495 }
496 label.push_str(&format!("-{crc:08x}"));
497 label
498 }
499}
500
501pub fn register(lua: &Lua) -> anyhow::Result<()> {
502 let redis_mod = get_or_create_module(lua, "redis")?;
503
504 redis_mod.set(
505 "open",
506 lua.create_function(move |lua, key: Value| {
507 let key: RedisConnKey = from_lua_value(lua, key)?;
508 key.open().map_err(any_err)
509 })?,
510 )?;
511
512 Ok(())
513}