mod_redis/
lib.rs

1use anyhow::Context;
2use config::{any_err, from_lua_value, get_or_create_module};
3use deadpool::managed::{Manager, Metrics, Pool, RecycleError, RecycleResult};
4use mlua::{Lua, MultiValue, UserData, UserDataMethods, Value};
5use redis::aio::{ConnectionLike, ConnectionManager, ConnectionManagerConfig};
6use redis::cluster::ClusterClient;
7use redis::cluster_async::ClusterConnection;
8pub use redis::{
9    cmd, Cmd, FromRedisValue, RedisError, Script, ScriptInvocation, Value as RedisValue,
10};
11use redis::{
12    Client, ConnectionInfo, IntoConnectionInfo, Pipeline, RedisFuture, RedisWrite, ToRedisArgs,
13};
14use serde::Deserialize;
15use serde_json::Value as JsonValue;
16use std::collections::HashMap;
17use std::sync::{Arc, LazyLock, Mutex};
18use std::time::Duration;
19
20pub mod test;
21
22static POOLS: LazyLock<Mutex<HashMap<RedisConnKey, Pool<ClientManager>>>> =
23    LazyLock::new(Mutex::default);
24
25pub struct ClientManager(ClientWrapper);
26
27impl Manager for ClientManager {
28    type Type = ConnectionWrapper;
29    type Error = anyhow::Error;
30
31    async fn create(&self) -> Result<Self::Type, Self::Error> {
32        let c = self.0.connect().await?;
33        Ok(c)
34    }
35
36    async fn recycle(
37        &self,
38        conn: &mut Self::Type,
39        _metrics: &Metrics,
40    ) -> RecycleResult<anyhow::Error> {
41        conn.ping()
42            .await
43            .map_err(|err| RecycleError::message(format!("{err:#}")))
44    }
45}
46
/// Cheaply cloneable handle to the redis deployment described by a [`RedisConnKey`].
///
/// Only the configuration is stored (behind an `Arc`); each operation checks a
/// connection out of the shared pool for that configuration.
#[derive(Clone, Debug)]
pub struct RedisConnection(Arc<RedisConnKey>);
49
50impl RedisConnection {
51    pub async fn ping(&self) -> anyhow::Result<()> {
52        let pool = self.0.get_pool()?;
53        let mut conn = pool.get().await.map_err(|err| anyhow::anyhow!("{err:#}"))?;
54        conn.ping().await
55    }
56
57    pub async fn query(&self, cmd: Cmd) -> anyhow::Result<RedisValue> {
58        let pool = self.0.get_pool()?;
59        let mut conn = pool.get().await.map_err(|err| anyhow::anyhow!("{err:#}"))?;
60        Ok(cmd.query_async(&mut *conn).await?)
61    }
62
63    pub async fn invoke_script(
64        &self,
65        script: ScriptInvocation<'static>,
66    ) -> anyhow::Result<RedisValue> {
67        let pool = self.0.get_pool()?;
68        let mut conn = pool.get().await.map_err(|err| anyhow::anyhow!("{err:#}"))?;
69        Ok(script.invoke_async(&mut *conn).await?)
70    }
71}
72
73fn redis_value_to_lua(lua: &Lua, value: RedisValue) -> mlua::Result<Value> {
74    Ok(match value {
75        RedisValue::Nil => Value::Nil,
76        RedisValue::Int(i) => Value::Integer(i),
77        RedisValue::Boolean(i) => Value::Boolean(i),
78        RedisValue::BigNumber(i) => Value::String(lua.create_string(i.to_string())?),
79        RedisValue::Double(i) => Value::Number(i),
80        RedisValue::BulkString(bytes) => Value::String(lua.create_string(&bytes)?),
81        RedisValue::SimpleString(s) => Value::String(lua.create_string(&s)?),
82        RedisValue::Map(pairs) => {
83            let map = lua.create_table()?;
84            for (k, v) in pairs {
85                let k = redis_value_to_lua(lua, k)?;
86                let v = redis_value_to_lua(lua, v)?;
87                map.set(k, v)?;
88            }
89            Value::Table(map)
90        }
91        RedisValue::Array(values) => {
92            let array = lua.create_table()?;
93            for v in values {
94                array.push(redis_value_to_lua(lua, v)?)?;
95            }
96            Value::Table(array)
97        }
98        RedisValue::Set(values) => {
99            let array = lua.create_table()?;
100            for v in values {
101                array.push(redis_value_to_lua(lua, v)?)?;
102            }
103            Value::Table(array)
104        }
105        RedisValue::Attribute { data, attributes } => {
106            let map = lua.create_table()?;
107            for (k, v) in attributes {
108                let k = redis_value_to_lua(lua, k)?;
109                let v = redis_value_to_lua(lua, v)?;
110                map.set(k, v)?;
111            }
112
113            let attribute = lua.create_table()?;
114            attribute.set("data", redis_value_to_lua(lua, *data)?)?;
115            attribute.set("attributes", map)?;
116
117            Value::Table(attribute)
118        }
119        RedisValue::VerbatimString { format, text } => {
120            let vstr = lua.create_table()?;
121            vstr.set("format", format.to_string())?;
122            vstr.set("text", text)?;
123            Value::Table(vstr)
124        }
125        RedisValue::ServerError(_) => {
126            return Err(value
127                .extract_error()
128                .map_err(mlua::Error::external)
129                .unwrap_err());
130        }
131        RedisValue::Okay => Value::Boolean(true),
132        RedisValue::Push { kind, data } => {
133            let array = lua.create_table()?;
134            for v in data {
135                let v = redis_value_to_lua(lua, v)?;
136                array.push(v)?;
137            }
138
139            let push = lua.create_table()?;
140            push.set("data", array)?;
141            push.set("kind", kind.to_string())?;
142
143            Value::Table(push)
144        }
145    })
146}
147
148impl UserData for RedisConnection {
149    fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
150        methods.add_async_method("query", |lua, this, params: MultiValue| async move {
151            let mut args = vec![];
152            for p in params {
153                args.push(from_lua_value(&lua, p)?);
154            }
155            let cmd = build_cmd(args).map_err(any_err)?;
156            let result = this.query(cmd).await.map_err(any_err)?;
157            redis_value_to_lua(&lua, result)
158        });
159    }
160}
161
162struct RedisJsonValue<'a>(&'a JsonValue);
163
164impl ToRedisArgs for RedisJsonValue<'_> {
165    fn write_redis_args<W>(&self, write: &mut W)
166    where
167        W: ?Sized + RedisWrite,
168    {
169        match self.0 {
170            JsonValue::Null => {}
171            JsonValue::Bool(b) => {
172                b.write_redis_args(write);
173            }
174            JsonValue::Number(n) => n.to_string().write_redis_args(write),
175            JsonValue::String(s) => s.write_redis_args(write),
176            JsonValue::Array(array) => {
177                for item in array {
178                    RedisJsonValue(item).write_redis_args(write);
179                }
180            }
181            JsonValue::Object(map) => {
182                for (k, v) in map {
183                    k.write_redis_args(write);
184                    RedisJsonValue(v).write_redis_args(write);
185                }
186            }
187        }
188    }
189
190    fn num_of_args(&self) -> usize {
191        match self.0 {
192            JsonValue::Array(array) => array.len(),
193            JsonValue::Null => 1,
194            JsonValue::Object(map) => map.len(),
195            JsonValue::Number(_) | JsonValue::Bool(_) | JsonValue::String(_) => 1,
196        }
197    }
198}
199
200pub fn build_cmd(args: Vec<JsonValue>) -> anyhow::Result<Cmd> {
201    let mut cmd = Cmd::new();
202    for a in args {
203        cmd.arg(RedisJsonValue(&a));
204    }
205    Ok(cmd)
206}
207
/// Where to find redis.
///
/// Deserialized untagged: a single string selects `Single`, a list of strings
/// selects `Cluster`. Unless `RedisConnKey::cluster` overrides it, the variant
/// also decides whether cluster mode is used (a list means cluster mode, even
/// with one entry).
#[derive(Debug, Clone, PartialEq, Eq, Hash, Deserialize)]
#[serde(untagged)]
pub enum NodeSpec {
    /// A single, non-clustered redis node
    Single(String),
    /// List of redis URLs for hosts in the cluster
    Cluster(Vec<String>),
}
216
/// Configuration describing how to reach a redis deployment.
///
/// It is deserialized from the value passed to `redis.open`. It also serves as
/// the key into the process-wide pool cache, so equal configurations share one
/// connection pool.
///
/// NOTE: the derived `Debug` output includes `password`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Deserialize)]
pub struct RedisConnKey {
    /// A single redis URL, or a list of URLs for cluster nodes.
    pub node: NodeSpec,
    /// Enables reading from replicas for all new connections.
    /// Only applied when connecting in cluster mode.
    #[serde(default)]
    pub read_from_replicas: bool,
    /// Username to authenticate with, when set.
    #[serde(default)]
    pub username: Option<String>,
    /// Password to authenticate with, when set.
    #[serde(default)]
    pub password: Option<String>,
    /// Forces cluster (`true`) or single-node (`false`) mode. When unset,
    /// cluster mode is used iff `node` is a list.
    #[serde(default)]
    pub cluster: Option<bool>,
    /// Maximum number of connections managed by the pool.
    /// When unset, deadpool's own default pool size is used; this crate does
    /// not set one (TODO: confirm the exact default for the deadpool version in use).
    #[serde(default)]
    pub pool_size: Option<usize>,
    /// Timeout for establishing a redis connection; also used as the pool's
    /// create timeout.
    #[serde(default, with = "duration_serde")]
    pub connect_timeout: Option<Duration>,
    /// Timeout for the pool's recycle (health-check `PING`) step.
    #[serde(default, with = "duration_serde")]
    pub recycle_timeout: Option<Duration>,
    /// How long to wait for a free connection from the pool.
    #[serde(default, with = "duration_serde")]
    pub wait_timeout: Option<Duration>,
    /// Timeout for a redis response on an established connection.
    #[serde(default, with = "duration_serde")]
    pub response_timeout: Option<Duration>,
}
242
/// A configured, not-yet-connected redis client for either topology.
pub enum ClientWrapper {
    /// Single-node client plus the `ConnectionManager` settings (timeouts)
    /// applied to each connection it opens.
    Single(Client, ConnectionManagerConfig),
    /// Cluster client; its settings are baked in by the builder.
    Cluster(ClusterClient),
}
247
248impl ClientWrapper {
249    pub async fn connect(&self) -> anyhow::Result<ConnectionWrapper> {
250        match self {
251            Self::Single(client, config) => Ok(ConnectionWrapper::Single(
252                ConnectionManager::new_with_config(client.clone(), config.clone()).await?,
253            )),
254            Self::Cluster(c) => Ok(ConnectionWrapper::Cluster(c.get_async_connection().await?)),
255        }
256    }
257}
258
/// A live redis connection for either topology. It implements `ConnectionLike`
/// by delegating to the wrapped connection.
pub enum ConnectionWrapper {
    /// Connection to a single node, via `ConnectionManager`.
    Single(ConnectionManager),
    /// Async cluster connection.
    Cluster(ClusterConnection),
}
263
264impl ConnectionWrapper {
265    pub async fn ping(&mut self) -> anyhow::Result<()> {
266        Ok(redis::cmd("PING").query_async(self).await?)
267    }
268}
269
270impl ConnectionLike for ConnectionWrapper {
271    // Required methods
272    fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, RedisValue> {
273        match self {
274            Self::Single(c) => c.req_packed_command(cmd),
275            Self::Cluster(c) => c.req_packed_command(cmd),
276        }
277    }
278
279    fn req_packed_commands<'a>(
280        &'a mut self,
281        cmd: &'a crate::Pipeline,
282        offset: usize,
283        count: usize,
284    ) -> RedisFuture<'a, Vec<RedisValue>> {
285        match self {
286            Self::Single(c) => c.req_packed_commands(cmd, offset, count),
287            Self::Cluster(c) => c.req_packed_commands(cmd, offset, count),
288        }
289    }
290
291    fn get_db(&self) -> i64 {
292        match self {
293            Self::Single(c) => c.get_db(),
294            Self::Cluster(c) => c.get_db(),
295        }
296    }
297}
298
299impl RedisConnKey {
300    pub fn build_client(&self) -> anyhow::Result<ClientWrapper> {
301        let cluster = self
302            .cluster
303            .unwrap_or(matches!(&self.node, NodeSpec::Cluster(_)));
304        let nodes = match &self.node {
305            NodeSpec::Single(node) => vec![node.to_string()],
306            NodeSpec::Cluster(nodes) => nodes.clone(),
307        };
308
309        if cluster {
310            let mut builder = ClusterClient::builder(nodes);
311            if self.read_from_replicas {
312                builder = builder.read_from_replicas();
313            }
314            if let Some(user) = &self.username {
315                builder = builder.username(user.to_string());
316            }
317            if let Some(pass) = &self.password {
318                builder = builder.password(pass.to_string());
319            }
320            if let Some(duration) = self.connect_timeout {
321                builder = builder.connection_timeout(duration);
322            }
323            if let Some(duration) = self.response_timeout {
324                builder = builder.response_timeout(duration);
325            }
326
327            Ok(ClientWrapper::Cluster(builder.build().with_context(
328                || format!("building redis client {self:?}"),
329            )?))
330        } else {
331            let mut config = ConnectionManagerConfig::new();
332            if let Some(duration) = self.connect_timeout {
333                config = config.set_connection_timeout(duration);
334            }
335            if let Some(duration) = self.response_timeout {
336                config = config.set_response_timeout(duration);
337            }
338
339            let mut info: ConnectionInfo = nodes[0]
340                .as_str()
341                .into_connection_info()
342                .with_context(|| format!("building redis client {self:?}"))?;
343            if let Some(user) = &self.username {
344                info.redis.username.replace(user.to_string());
345            }
346            if let Some(pass) = &self.password {
347                info.redis.password.replace(pass.to_string());
348            }
349
350            Ok(ClientWrapper::Single(
351                Client::open(info).with_context(|| format!("building redis client {self:?}"))?,
352                config,
353            ))
354        }
355    }
356
357    pub fn get_pool(&self) -> anyhow::Result<Pool<ClientManager>> {
358        let mut pools = POOLS.lock().unwrap();
359        if let Some(pool) = pools.get(self) {
360            return Ok(pool.clone());
361        }
362
363        let client = self.build_client()?;
364        let mut builder = Pool::builder(ClientManager(client))
365            .runtime(deadpool::Runtime::Tokio1)
366            .create_timeout(self.connect_timeout)
367            .recycle_timeout(self.recycle_timeout)
368            .wait_timeout(self.wait_timeout);
369
370        if let Some(limit) = self.pool_size {
371            builder = builder.max_size(limit);
372        }
373
374        let pool = builder.build()?;
375
376        pools.insert(self.clone(), pool.clone());
377
378        Ok(pool)
379    }
380
381    pub fn open(&self) -> anyhow::Result<RedisConnection> {
382        self.build_client()?;
383        Ok(RedisConnection(Arc::new(self.clone())))
384    }
385}
386
387pub fn register(lua: &Lua) -> anyhow::Result<()> {
388    let redis_mod = get_or_create_module(lua, "redis")?;
389
390    redis_mod.set(
391        "open",
392        lua.create_function(move |lua, key: Value| {
393            let key: RedisConnKey = from_lua_value(lua, key)?;
394            key.open().map_err(any_err)
395        })?,
396    )?;
397
398    Ok(())
399}