1use anyhow::Context;
2use config::{any_err, from_lua_value, get_or_create_module};
3use deadpool::managed::{Manager, Metrics, Pool, RecycleError, RecycleResult};
4use mlua::{Lua, MultiValue, UserData, UserDataMethods, Value};
5use redis::aio::{ConnectionLike, ConnectionManager, ConnectionManagerConfig};
6use redis::cluster::ClusterClient;
7use redis::cluster_async::ClusterConnection;
8pub use redis::{
9 cmd, Cmd, FromRedisValue, RedisError, Script, ScriptInvocation, Value as RedisValue,
10};
11use redis::{
12 Client, ConnectionInfo, IntoConnectionInfo, Pipeline, RedisFuture, RedisWrite, ToRedisArgs,
13};
14use serde::Deserialize;
15use serde_json::Value as JsonValue;
16use std::collections::HashMap;
17use std::sync::{Arc, LazyLock, Mutex};
18use std::time::Duration;
19
20pub mod test;
21
/// Process-wide cache of connection pools, keyed by the full set of
/// connection parameters.
///
/// Entries are looked up and inserted by `RedisConnKey::get_pool`, so two
/// keys that compare equal share a single pool. Nothing in this file removes
/// an entry once it has been inserted.
static POOLS: LazyLock<Mutex<HashMap<RedisConnKey, Pool<ClientManager>>>> =
    LazyLock::new(Mutex::default);
24
25pub struct ClientManager(ClientWrapper);
26
27impl Manager for ClientManager {
28 type Type = ConnectionWrapper;
29 type Error = anyhow::Error;
30
31 async fn create(&self) -> Result<Self::Type, Self::Error> {
32 let c = self.0.connect().await?;
33 Ok(c)
34 }
35
36 async fn recycle(
37 &self,
38 conn: &mut Self::Type,
39 _metrics: &Metrics,
40 ) -> RecycleResult<anyhow::Error> {
41 conn.ping()
42 .await
43 .map_err(|err| RecycleError::message(format!("{err:#}")))
44 }
45}
46
47#[derive(Clone, Debug)]
48pub struct RedisConnection(Arc<RedisConnKey>);
49
50impl RedisConnection {
51 pub async fn ping(&self) -> anyhow::Result<()> {
52 let pool = self.0.get_pool()?;
53 let mut conn = pool.get().await.map_err(|err| anyhow::anyhow!("{err:#}"))?;
54 conn.ping().await
55 }
56
57 pub async fn query(&self, cmd: Cmd) -> anyhow::Result<RedisValue> {
58 let pool = self.0.get_pool()?;
59 let mut conn = pool.get().await.map_err(|err| anyhow::anyhow!("{err:#}"))?;
60 Ok(cmd.query_async(&mut *conn).await?)
61 }
62
63 pub async fn invoke_script(
64 &self,
65 script: ScriptInvocation<'static>,
66 ) -> anyhow::Result<RedisValue> {
67 let pool = self.0.get_pool()?;
68 let mut conn = pool.get().await.map_err(|err| anyhow::anyhow!("{err:#}"))?;
69 Ok(script.invoke_async(&mut *conn).await?)
70 }
71}
72
73fn redis_value_to_lua(lua: &Lua, value: RedisValue) -> mlua::Result<Value> {
74 Ok(match value {
75 RedisValue::Nil => Value::Nil,
76 RedisValue::Int(i) => Value::Integer(i),
77 RedisValue::Boolean(i) => Value::Boolean(i),
78 RedisValue::BigNumber(i) => Value::String(lua.create_string(i.to_string())?),
79 RedisValue::Double(i) => Value::Number(i),
80 RedisValue::BulkString(bytes) => Value::String(lua.create_string(&bytes)?),
81 RedisValue::SimpleString(s) => Value::String(lua.create_string(&s)?),
82 RedisValue::Map(pairs) => {
83 let map = lua.create_table()?;
84 for (k, v) in pairs {
85 let k = redis_value_to_lua(lua, k)?;
86 let v = redis_value_to_lua(lua, v)?;
87 map.set(k, v)?;
88 }
89 Value::Table(map)
90 }
91 RedisValue::Array(values) => {
92 let array = lua.create_table()?;
93 for v in values {
94 array.push(redis_value_to_lua(lua, v)?)?;
95 }
96 Value::Table(array)
97 }
98 RedisValue::Set(values) => {
99 let array = lua.create_table()?;
100 for v in values {
101 array.push(redis_value_to_lua(lua, v)?)?;
102 }
103 Value::Table(array)
104 }
105 RedisValue::Attribute { data, attributes } => {
106 let map = lua.create_table()?;
107 for (k, v) in attributes {
108 let k = redis_value_to_lua(lua, k)?;
109 let v = redis_value_to_lua(lua, v)?;
110 map.set(k, v)?;
111 }
112
113 let attribute = lua.create_table()?;
114 attribute.set("data", redis_value_to_lua(lua, *data)?)?;
115 attribute.set("attributes", map)?;
116
117 Value::Table(attribute)
118 }
119 RedisValue::VerbatimString { format, text } => {
120 let vstr = lua.create_table()?;
121 vstr.set("format", format.to_string())?;
122 vstr.set("text", text)?;
123 Value::Table(vstr)
124 }
125 RedisValue::ServerError(_) => {
126 return Err(value
127 .extract_error()
128 .map_err(mlua::Error::external)
129 .unwrap_err());
130 }
131 RedisValue::Okay => Value::Boolean(true),
132 RedisValue::Push { kind, data } => {
133 let array = lua.create_table()?;
134 for v in data {
135 let v = redis_value_to_lua(lua, v)?;
136 array.push(v)?;
137 }
138
139 let push = lua.create_table()?;
140 push.set("data", array)?;
141 push.set("kind", kind.to_string())?;
142
143 Value::Table(push)
144 }
145 })
146}
147
148impl UserData for RedisConnection {
149 fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
150 methods.add_async_method("query", |lua, this, params: MultiValue| async move {
151 let mut args = vec![];
152 for p in params {
153 args.push(from_lua_value(&lua, p)?);
154 }
155 let cmd = build_cmd(args).map_err(any_err)?;
156 let result = this.query(cmd).await.map_err(any_err)?;
157 redis_value_to_lua(&lua, result)
158 });
159 }
160}
161
162struct RedisJsonValue<'a>(&'a JsonValue);
163
164impl ToRedisArgs for RedisJsonValue<'_> {
165 fn write_redis_args<W>(&self, write: &mut W)
166 where
167 W: ?Sized + RedisWrite,
168 {
169 match self.0 {
170 JsonValue::Null => {}
171 JsonValue::Bool(b) => {
172 b.write_redis_args(write);
173 }
174 JsonValue::Number(n) => n.to_string().write_redis_args(write),
175 JsonValue::String(s) => s.write_redis_args(write),
176 JsonValue::Array(array) => {
177 for item in array {
178 RedisJsonValue(item).write_redis_args(write);
179 }
180 }
181 JsonValue::Object(map) => {
182 for (k, v) in map {
183 k.write_redis_args(write);
184 RedisJsonValue(v).write_redis_args(write);
185 }
186 }
187 }
188 }
189
190 fn num_of_args(&self) -> usize {
191 match self.0 {
192 JsonValue::Array(array) => array.len(),
193 JsonValue::Null => 1,
194 JsonValue::Object(map) => map.len(),
195 JsonValue::Number(_) | JsonValue::Bool(_) | JsonValue::String(_) => 1,
196 }
197 }
198}
199
200pub fn build_cmd(args: Vec<JsonValue>) -> anyhow::Result<Cmd> {
201 let mut cmd = Cmd::new();
202 for a in args {
203 cmd.arg(RedisJsonValue(&a));
204 }
205 Ok(cmd)
206}
207
/// Address(es) of the redis server(s) to connect to.
///
/// The enum is `#[serde(untagged)]`, so the variant is chosen by the shape
/// of the input: a plain string deserializes as `Single`, a list of strings
/// as `Cluster`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Deserialize)]
#[serde(untagged)]
pub enum NodeSpec {
    /// One node address.
    Single(String),
    /// A list of node addresses.
    Cluster(Vec<String>),
}
216
/// Connection parameters for a redis server or cluster.
///
/// The type is `Eq + Hash` because it is used as the key of the `POOLS`
/// map: keys that compare equal share one connection pool.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Deserialize)]
pub struct RedisConnKey {
    /// Node address(es) to connect to.
    pub node: NodeSpec,
    /// Passed to the cluster client builder when set; `build_client` does
    /// not consult it in non-cluster mode.
    #[serde(default)]
    pub read_from_replicas: bool,
    /// Optional username. In non-cluster mode it replaces any username
    /// parsed from the node address.
    #[serde(default)]
    pub username: Option<String>,
    /// Optional password. In non-cluster mode it replaces any password
    /// parsed from the node address.
    #[serde(default)]
    pub password: Option<String>,
    /// Explicitly selects cluster (`true`) or single-node (`false`) mode.
    /// When unset, cluster mode is used iff `node` is `NodeSpec::Cluster`.
    #[serde(default)]
    pub cluster: Option<bool>,
    /// When set, used as the pool's `max_size`; otherwise the pool
    /// builder's own default applies.
    #[serde(default)]
    pub pool_size: Option<usize>,
    /// Connection timeout for the redis client; also used as the pool's
    /// create timeout. Parsed via the `duration_serde` helper, which is
    /// defined outside this file.
    #[serde(default, with = "duration_serde")]
    pub connect_timeout: Option<Duration>,
    /// Pool recycle timeout.
    #[serde(default, with = "duration_serde")]
    pub recycle_timeout: Option<Duration>,
    /// Pool wait timeout.
    #[serde(default, with = "duration_serde")]
    pub wait_timeout: Option<Duration>,
    /// Response timeout for the redis client.
    #[serde(default, with = "duration_serde")]
    pub response_timeout: Option<Duration>,
}
242
243pub enum ClientWrapper {
244 Single(Client, ConnectionManagerConfig),
245 Cluster(ClusterClient),
246}
247
248impl ClientWrapper {
249 pub async fn connect(&self) -> anyhow::Result<ConnectionWrapper> {
250 match self {
251 Self::Single(client, config) => Ok(ConnectionWrapper::Single(
252 ConnectionManager::new_with_config(client.clone(), config.clone()).await?,
253 )),
254 Self::Cluster(c) => Ok(ConnectionWrapper::Cluster(c.get_async_connection().await?)),
255 }
256 }
257}
258
259pub enum ConnectionWrapper {
260 Single(ConnectionManager),
261 Cluster(ClusterConnection),
262}
263
264impl ConnectionWrapper {
265 pub async fn ping(&mut self) -> anyhow::Result<()> {
266 Ok(redis::cmd("PING").query_async(self).await?)
267 }
268}
269
270impl ConnectionLike for ConnectionWrapper {
271 fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, RedisValue> {
273 match self {
274 Self::Single(c) => c.req_packed_command(cmd),
275 Self::Cluster(c) => c.req_packed_command(cmd),
276 }
277 }
278
279 fn req_packed_commands<'a>(
280 &'a mut self,
281 cmd: &'a crate::Pipeline,
282 offset: usize,
283 count: usize,
284 ) -> RedisFuture<'a, Vec<RedisValue>> {
285 match self {
286 Self::Single(c) => c.req_packed_commands(cmd, offset, count),
287 Self::Cluster(c) => c.req_packed_commands(cmd, offset, count),
288 }
289 }
290
291 fn get_db(&self) -> i64 {
292 match self {
293 Self::Single(c) => c.get_db(),
294 Self::Cluster(c) => c.get_db(),
295 }
296 }
297}
298
299impl RedisConnKey {
300 pub fn build_client(&self) -> anyhow::Result<ClientWrapper> {
301 let cluster = self
302 .cluster
303 .unwrap_or(matches!(&self.node, NodeSpec::Cluster(_)));
304 let nodes = match &self.node {
305 NodeSpec::Single(node) => vec![node.to_string()],
306 NodeSpec::Cluster(nodes) => nodes.clone(),
307 };
308
309 if cluster {
310 let mut builder = ClusterClient::builder(nodes);
311 if self.read_from_replicas {
312 builder = builder.read_from_replicas();
313 }
314 if let Some(user) = &self.username {
315 builder = builder.username(user.to_string());
316 }
317 if let Some(pass) = &self.password {
318 builder = builder.password(pass.to_string());
319 }
320 if let Some(duration) = self.connect_timeout {
321 builder = builder.connection_timeout(duration);
322 }
323 if let Some(duration) = self.response_timeout {
324 builder = builder.response_timeout(duration);
325 }
326
327 Ok(ClientWrapper::Cluster(builder.build().with_context(
328 || format!("building redis client {self:?}"),
329 )?))
330 } else {
331 let mut config = ConnectionManagerConfig::new();
332 if let Some(duration) = self.connect_timeout {
333 config = config.set_connection_timeout(duration);
334 }
335 if let Some(duration) = self.response_timeout {
336 config = config.set_response_timeout(duration);
337 }
338
339 let mut info: ConnectionInfo = nodes[0]
340 .as_str()
341 .into_connection_info()
342 .with_context(|| format!("building redis client {self:?}"))?;
343 if let Some(user) = &self.username {
344 info.redis.username.replace(user.to_string());
345 }
346 if let Some(pass) = &self.password {
347 info.redis.password.replace(pass.to_string());
348 }
349
350 Ok(ClientWrapper::Single(
351 Client::open(info).with_context(|| format!("building redis client {self:?}"))?,
352 config,
353 ))
354 }
355 }
356
357 pub fn get_pool(&self) -> anyhow::Result<Pool<ClientManager>> {
358 let mut pools = POOLS.lock().unwrap();
359 if let Some(pool) = pools.get(self) {
360 return Ok(pool.clone());
361 }
362
363 let client = self.build_client()?;
364 let mut builder = Pool::builder(ClientManager(client))
365 .runtime(deadpool::Runtime::Tokio1)
366 .create_timeout(self.connect_timeout)
367 .recycle_timeout(self.recycle_timeout)
368 .wait_timeout(self.wait_timeout);
369
370 if let Some(limit) = self.pool_size {
371 builder = builder.max_size(limit);
372 }
373
374 let pool = builder.build()?;
375
376 pools.insert(self.clone(), pool.clone());
377
378 Ok(pool)
379 }
380
381 pub fn open(&self) -> anyhow::Result<RedisConnection> {
382 self.build_client()?;
383 Ok(RedisConnection(Arc::new(self.clone())))
384 }
385}
386
387pub fn register(lua: &Lua) -> anyhow::Result<()> {
388 let redis_mod = get_or_create_module(lua, "redis")?;
389
390 redis_mod.set(
391 "open",
392 lua.create_function(move |lua, key: Value| {
393 let key: RedisConnKey = from_lua_value(lua, key)?;
394 key.open().map_err(any_err)
395 })?,
396 )?;
397
398 Ok(())
399}