mod_redis/
lib.rs

1use anyhow::Context;
2use config::{any_err, from_lua_value, get_or_create_module};
3use deadpool::managed::{Manager, Metrics, Pool, RecycleError, RecycleResult};
4use kumo_prometheus::declare_metric;
5use mlua::{Lua, MultiValue, UserData, UserDataMethods, Value};
6use redis::aio::{ConnectionLike, ConnectionManager, ConnectionManagerConfig};
7use redis::cluster::ClusterClient;
8use redis::cluster_async::ClusterConnection;
9pub use redis::{
10    cmd, Cmd, FromRedisValue, RedisError, Script, ScriptInvocation, Value as RedisValue,
11};
12use redis::{
13    Client, ConnectionInfo, IntoConnectionInfo, Pipeline, RedisFuture, RedisWrite, ToRedisArgs,
14};
15use serde::Deserialize;
16use serde_json::Value as JsonValue;
17use std::collections::HashMap;
18use std::future::Future;
19use std::sync::{Arc, LazyLock, Mutex};
20use std::time::{Duration, Instant};
21
22pub mod test;
23
/// Process-wide cache of connection pools, keyed by the complete
/// `RedisConnKey` (nodes, credentials, timeouts, pool sizing, ...).
///
/// Entries are created lazily by `RedisConnKey::get_pool`. No code in this
/// file removes entries, so every distinct configuration that gets used keeps
/// its pool for the remaining life of the process.
static POOLS: LazyLock<Mutex<HashMap<RedisConnKey, Pool<ClientManager>>>> =
    LazyLock::new(Mutex::default);
26
/// deadpool `Manager` that creates connections from, and health-checks them
/// for, the wrapped `ClientWrapper` (single node or cluster).
pub struct ClientManager(ClientWrapper);
28
29impl Manager for ClientManager {
30    type Type = ConnectionWrapper;
31    type Error = anyhow::Error;
32
33    async fn create(&self) -> Result<Self::Type, Self::Error> {
34        let c = self.0.connect().await?;
35        Ok(c)
36    }
37
38    async fn recycle(
39        &self,
40        conn: &mut Self::Type,
41        _metrics: &Metrics,
42    ) -> RecycleResult<anyhow::Error> {
43        conn.ping()
44            .await
45            .map_err(|err| RecycleError::message(format!("{err:#}")))
46    }
47}
48
// Histogram of redis operation latency, in seconds, sampled by
// `RedisConnection::sample_latency`. The `service` label value is the string
// computed once by `RedisConnKey::hash_label` when the connection is opened.
// NOTE(review): the `///` lines inside this invocation are presumably consumed
// by `declare_metric!` as the metric's help/documentation text, so treat them
// as user-facing output rather than ordinary comments.
declare_metric! {
/// The latency of an operation talking to Redis.
///
/// {{since('dev')}}
///
/// The `service` key represents the redis server/service. It is not
/// a direct match to a server name as it is really a hash of the
/// overall redis configuration information used in the client.
/// It might look something like:
/// `redis://127.0.0.1:24419,redis://127.0.0.1:7779,redis://127.0.0.1:29469-2ce79dd1`
/// for a cluster configuration, or `redis://127.0.0.1:16267-f4da6e64`
/// for a single node cluster configuration.
/// You should anticipate that the `-HEX` suffix can and will change
/// in an unspecified way as you vary the redis connection parameters.
///
/// The `operation` key indicates the operation, which can be a `ping`,
/// a `query` or a `script`.
///
/// `status` will be either `ok` or `error` to indicate whether this
/// is tracking a successful or failed operation.
///
/// Since histograms track a count of operations, you can track the
/// rate of `redis_operation_latency_count` where `status=error`
/// to have an indication of the failure rate of redis operations.
static REDIS_LATENCY:  HistogramVec("redis_operation_latency",
    &["service", "operation", "status"]);
}
76
/// Shared state behind a `RedisConnection` handle: the configuration used to
/// look up the connection pool, plus the metrics label derived from it.
#[derive(Debug)]
struct KeyAndLabel {
    /// Configuration used to find (or lazily create) the pool on each operation.
    key: RedisConnKey,
    /// Value of the `service` label on latency metrics; computed once by
    /// `RedisConnKey::hash_label` when the connection is opened.
    label: String,
}
82
/// Cheaply clonable (reference-counted) handle to a redis service, exposed to
/// Lua as userdata.
///
/// The handle does not hold a connection itself: each operation checks one
/// out of the process-wide pool associated with its `RedisConnKey`.
#[derive(Clone, Debug)]
pub struct RedisConnection(Arc<KeyAndLabel>);
85
86impl RedisConnection {
87    async fn sample_latency<T, E>(
88        &self,
89        operation: &str,
90        fut: impl Future<Output = Result<T, E>>,
91    ) -> Result<T, E> {
92        let now = Instant::now();
93        let result = (fut).await;
94        let elapsed = now.elapsed().as_secs_f64();
95        let status = if result.is_ok() { "ok" } else { "error" };
96
97        if let Ok(hist) =
98            REDIS_LATENCY.get_metric_with_label_values(&[self.0.label.as_str(), operation, status])
99        {
100            hist.observe(elapsed);
101        }
102
103        result
104    }
105
106    pub async fn ping(&self) -> anyhow::Result<()> {
107        self.sample_latency("ping", async {
108            let pool = self.0.key.get_pool()?;
109            let mut conn = pool.get().await.map_err(|err| anyhow::anyhow!("{err:#}"))?;
110            conn.ping().await
111        })
112        .await
113    }
114
115    pub async fn query(&self, cmd: Cmd) -> anyhow::Result<RedisValue> {
116        self.sample_latency("query", async {
117            let pool = self.0.key.get_pool()?;
118            let mut conn = pool.get().await.map_err(|err| anyhow::anyhow!("{err:#}"))?;
119            Ok(cmd.query_async(&mut *conn).await?)
120        })
121        .await
122    }
123
124    pub async fn invoke_script(
125        &self,
126        script: ScriptInvocation<'static>,
127    ) -> anyhow::Result<RedisValue> {
128        self.sample_latency("script", async {
129            let pool = self.0.key.get_pool()?;
130            let mut conn = pool.get().await.map_err(|err| anyhow::anyhow!("{err:#}"))?;
131            Ok(script.invoke_async(&mut *conn).await?)
132        })
133        .await
134    }
135}
136
137fn redis_value_to_lua(lua: &Lua, value: RedisValue) -> mlua::Result<Value> {
138    Ok(match value {
139        RedisValue::Nil => Value::Nil,
140        RedisValue::Int(i) => Value::Integer(i),
141        RedisValue::Boolean(i) => Value::Boolean(i),
142        RedisValue::BigNumber(i) => Value::String(lua.create_string(i.to_string())?),
143        RedisValue::Double(i) => Value::Number(i),
144        RedisValue::BulkString(bytes) => Value::String(lua.create_string(&bytes)?),
145        RedisValue::SimpleString(s) => Value::String(lua.create_string(&s)?),
146        RedisValue::Map(pairs) => {
147            let map = lua.create_table()?;
148            for (k, v) in pairs {
149                let k = redis_value_to_lua(lua, k)?;
150                let v = redis_value_to_lua(lua, v)?;
151                map.set(k, v)?;
152            }
153            Value::Table(map)
154        }
155        RedisValue::Array(values) => {
156            let array = lua.create_table()?;
157            for v in values {
158                array.push(redis_value_to_lua(lua, v)?)?;
159            }
160            Value::Table(array)
161        }
162        RedisValue::Set(values) => {
163            let array = lua.create_table()?;
164            for v in values {
165                array.push(redis_value_to_lua(lua, v)?)?;
166            }
167            Value::Table(array)
168        }
169        RedisValue::Attribute { data, attributes } => {
170            let map = lua.create_table()?;
171            for (k, v) in attributes {
172                let k = redis_value_to_lua(lua, k)?;
173                let v = redis_value_to_lua(lua, v)?;
174                map.set(k, v)?;
175            }
176
177            let attribute = lua.create_table()?;
178            attribute.set("data", redis_value_to_lua(lua, *data)?)?;
179            attribute.set("attributes", map)?;
180
181            Value::Table(attribute)
182        }
183        RedisValue::VerbatimString { format, text } => {
184            let vstr = lua.create_table()?;
185            vstr.set("format", format.to_string())?;
186            vstr.set("text", text)?;
187            Value::Table(vstr)
188        }
189        RedisValue::ServerError(_) => {
190            return Err(value
191                .extract_error()
192                .map_err(mlua::Error::external)
193                .unwrap_err());
194        }
195        RedisValue::Okay => Value::Boolean(true),
196        RedisValue::Push { kind, data } => {
197            let array = lua.create_table()?;
198            for v in data {
199                let v = redis_value_to_lua(lua, v)?;
200                array.push(v)?;
201            }
202
203            let push = lua.create_table()?;
204            push.set("data", array)?;
205            push.set("kind", kind.to_string())?;
206
207            Value::Table(push)
208        }
209    })
210}
211
212impl UserData for RedisConnection {
213    fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
214        methods.add_async_method("query", |lua, this, params: MultiValue| async move {
215            let mut args = vec![];
216            for p in params {
217                args.push(from_lua_value(&lua, p)?);
218            }
219            let cmd = build_cmd(args).map_err(any_err)?;
220            let result = this.query(cmd).await.map_err(any_err)?;
221            redis_value_to_lua(&lua, result)
222        });
223    }
224}
225
/// Borrowed adapter that lets a JSON value be written out as redis command
/// arguments (see its `ToRedisArgs` impl for how each JSON type is encoded).
struct RedisJsonValue<'a>(&'a JsonValue);
227
228impl ToRedisArgs for RedisJsonValue<'_> {
229    fn write_redis_args<W>(&self, write: &mut W)
230    where
231        W: ?Sized + RedisWrite,
232    {
233        match self.0 {
234            JsonValue::Null => {}
235            JsonValue::Bool(b) => {
236                b.write_redis_args(write);
237            }
238            JsonValue::Number(n) => n.to_string().write_redis_args(write),
239            JsonValue::String(s) => s.write_redis_args(write),
240            JsonValue::Array(array) => {
241                for item in array {
242                    RedisJsonValue(item).write_redis_args(write);
243                }
244            }
245            JsonValue::Object(map) => {
246                for (k, v) in map {
247                    k.write_redis_args(write);
248                    RedisJsonValue(v).write_redis_args(write);
249                }
250            }
251        }
252    }
253
254    fn num_of_args(&self) -> usize {
255        match self.0 {
256            JsonValue::Array(array) => array.len(),
257            JsonValue::Null => 1,
258            JsonValue::Object(map) => map.len(),
259            JsonValue::Number(_) | JsonValue::Bool(_) | JsonValue::String(_) => 1,
260        }
261    }
262}
263
264pub fn build_cmd(args: Vec<JsonValue>) -> anyhow::Result<Cmd> {
265    let mut cmd = Cmd::new();
266    for a in args {
267        cmd.arg(RedisJsonValue(&a));
268    }
269    Ok(cmd)
270}
271
/// The redis node(s) to connect to.
///
/// Deserialized untagged: a single string selects `Single`, a list of strings
/// selects `Cluster`. The variant only provides the default for cluster mode;
/// `RedisConnKey::cluster` overrides it when set.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Deserialize)]
#[serde(untagged)]
pub enum NodeSpec {
    /// A single, non-clustered redis node
    /// (used as a cluster seed instead if `RedisConnKey::cluster` is `Some(true)`).
    Single(String),
    /// List of redis URLs for hosts in the cluster
    /// (only the first is used if `RedisConnKey::cluster` is `Some(false)`).
    Cluster(Vec<String>),
}
280
281#[derive(Debug, Clone, PartialEq, Eq, Hash, Deserialize)]
282pub struct RedisConnKey {
283    pub node: NodeSpec,
284    /// Enables reading from replicas for all new connections
285    #[serde(default)]
286    pub read_from_replicas: bool,
287    #[serde(default)]
288    pub username: Option<String>,
289    #[serde(default)]
290    pub password: Option<String>,
291    #[serde(default)]
292    pub cluster: Option<bool>,
293    /// Maximum number of connections managed by the pool.
294    /// Default is 10
295    #[serde(default)]
296    pub pool_size: Option<usize>,
297    #[serde(default, with = "duration_serde")]
298    pub connect_timeout: Option<Duration>,
299    #[serde(default, with = "duration_serde")]
300    pub recycle_timeout: Option<Duration>,
301    #[serde(default, with = "duration_serde")]
302    pub wait_timeout: Option<Duration>,
303    #[serde(default, with = "duration_serde")]
304    pub response_timeout: Option<Duration>,
305}
306
/// A configured redis client, as produced by `RedisConnKey::build_client`.
/// Connections are opened from it by `ClientWrapper::connect`.
pub enum ClientWrapper {
    /// Non-clustered client, plus the settings applied to each
    /// `ConnectionManager` created from it.
    Single(Client, ConnectionManagerConfig),
    /// Client for a redis cluster.
    Cluster(ClusterClient),
}
311
312impl ClientWrapper {
313    pub async fn connect(&self) -> anyhow::Result<ConnectionWrapper> {
314        match self {
315            Self::Single(client, config) => Ok(ConnectionWrapper::Single(
316                ConnectionManager::new_with_config(client.clone(), config.clone()).await?,
317            )),
318            Self::Cluster(c) => Ok(ConnectionWrapper::Cluster(c.get_async_connection().await?)),
319        }
320    }
321}
322
/// A live connection of either flavour; this is the object type stored in the
/// deadpool pool (`ClientManager::Type`).
pub enum ConnectionWrapper {
    /// Connection to a single, non-clustered node.
    Single(ConnectionManager),
    /// Connection to a redis cluster.
    Cluster(ClusterConnection),
}
327
328impl ConnectionWrapper {
329    pub async fn ping(&mut self) -> anyhow::Result<()> {
330        Ok(redis::cmd("PING").query_async(self).await?)
331    }
332}
333
/// Lets a `ConnectionWrapper` be used wherever redis-rs expects an async
/// connection (e.g. `Cmd::query_async` and `ScriptInvocation::invoke_async`
/// in `RedisConnection`) by delegating each call to the wrapped connection.
impl ConnectionLike for ConnectionWrapper {
    // Required methods
    fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, RedisValue> {
        match self {
            Self::Single(c) => c.req_packed_command(cmd),
            Self::Cluster(c) => c.req_packed_command(cmd),
        }
    }

    fn req_packed_commands<'a>(
        &'a mut self,
        cmd: &'a crate::Pipeline,
        offset: usize,
        count: usize,
    ) -> RedisFuture<'a, Vec<RedisValue>> {
        match self {
            Self::Single(c) => c.req_packed_commands(cmd, offset, count),
            Self::Cluster(c) => c.req_packed_commands(cmd, offset, count),
        }
    }

    fn get_db(&self) -> i64 {
        match self {
            Self::Single(c) => c.get_db(),
            Self::Cluster(c) => c.get_db(),
        }
    }
}
362
363impl RedisConnKey {
364    pub fn build_client(&self) -> anyhow::Result<ClientWrapper> {
365        let cluster = self
366            .cluster
367            .unwrap_or(matches!(&self.node, NodeSpec::Cluster(_)));
368        let nodes = match &self.node {
369            NodeSpec::Single(node) => vec![node.to_string()],
370            NodeSpec::Cluster(nodes) => nodes.clone(),
371        };
372
373        if cluster {
374            let mut builder = ClusterClient::builder(nodes);
375            if self.read_from_replicas {
376                builder = builder.read_from_replicas();
377            }
378            if let Some(user) = &self.username {
379                builder = builder.username(user.to_string());
380            }
381            if let Some(pass) = &self.password {
382                builder = builder.password(pass.to_string());
383            }
384            if let Some(duration) = self.connect_timeout {
385                builder = builder.connection_timeout(duration);
386            }
387            if let Some(duration) = self.response_timeout {
388                builder = builder.response_timeout(duration);
389            }
390
391            Ok(ClientWrapper::Cluster(builder.build().with_context(
392                || format!("building redis client {self:?}"),
393            )?))
394        } else {
395            let mut config = ConnectionManagerConfig::new();
396            if let Some(duration) = self.connect_timeout {
397                config = config.set_connection_timeout(duration);
398            }
399            if let Some(duration) = self.response_timeout {
400                config = config.set_response_timeout(duration);
401            }
402
403            let mut info: ConnectionInfo = nodes[0]
404                .as_str()
405                .into_connection_info()
406                .with_context(|| format!("building redis client {self:?}"))?;
407            if let Some(user) = &self.username {
408                info.redis.username.replace(user.to_string());
409            }
410            if let Some(pass) = &self.password {
411                info.redis.password.replace(pass.to_string());
412            }
413
414            Ok(ClientWrapper::Single(
415                Client::open(info).with_context(|| format!("building redis client {self:?}"))?,
416                config,
417            ))
418        }
419    }
420
421    pub fn get_pool(&self) -> anyhow::Result<Pool<ClientManager>> {
422        let mut pools = POOLS.lock().unwrap();
423        if let Some(pool) = pools.get(self) {
424            return Ok(pool.clone());
425        }
426
427        let client = self.build_client()?;
428        let mut builder = Pool::builder(ClientManager(client))
429            .runtime(deadpool::Runtime::Tokio1)
430            .create_timeout(self.connect_timeout)
431            .recycle_timeout(self.recycle_timeout)
432            .wait_timeout(self.wait_timeout);
433
434        if let Some(limit) = self.pool_size {
435            builder = builder.max_size(limit);
436        }
437
438        let pool = builder.build()?;
439
440        pools.insert(self.clone(), pool.clone());
441
442        Ok(pool)
443    }
444
445    pub fn open(&self) -> anyhow::Result<RedisConnection> {
446        self.build_client()?;
447        Ok(RedisConnection(Arc::new(KeyAndLabel {
448            key: self.clone(),
449            label: self.hash_label(),
450        })))
451    }
452
453    /// Produces a human readable label string that is representitive
454    /// of this RedisConnKey.  We pull out the node and username to
455    /// include in the label.
456    /// Now, since the entire RedisConnKey is the actual key, that
457    /// readable subset is not sufficient to uniquely identify the
458    /// entry in the pool, although in reality it is probably OK,
459    /// there exists the possibility that eg: on a config update,
460    /// multiple entries have the same list of nodes but different
461    /// auth or other parameters.
462    /// To smooth over such a transition, we'll include a basic
463    /// crc32 of the entire RedisConnKey in the label that is
464    /// returned.  This will probably be sufficient to avoid
465    /// an obvious collision between such names, but it will not
466    /// guarantee it.
467    /// This is a best effort really; I doubt that this will cause
468    /// any meaningful issues in practice, as the intended use case
469    /// for this label string is to sample metrics rather than to
470    /// guarantee isolation.
471    pub fn hash_label(&self) -> String {
472        use crc32fast::Hasher;
473        use std::hash::Hash;
474        let mut hasher = Hasher::new();
475        self.hash(&mut hasher);
476        let crc = hasher.finalize();
477
478        let mut label = String::new();
479        if let Some(user) = &self.username {
480            label.push_str(user);
481            label.push('@');
482        }
483        match &self.node {
484            NodeSpec::Single(node) => {
485                label.push_str(node);
486            }
487            NodeSpec::Cluster(nodes) => {
488                for (idx, node) in nodes.iter().enumerate() {
489                    if idx > 0 {
490                        label.push(',');
491                    }
492                    label.push_str(node);
493                }
494            }
495        }
496        label.push_str(&format!("-{crc:08x}"));
497        label
498    }
499}
500
501pub fn register(lua: &Lua) -> anyhow::Result<()> {
502    let redis_mod = get_or_create_module(lua, "redis")?;
503
504    redis_mod.set(
505        "open",
506        lua.create_function(move |lua, key: Value| {
507            let key: RedisConnKey = from_lua_value(lua, key)?;
508            key.open().map_err(any_err)
509        })?,
510    )?;
511
512    Ok(())
513}