mod_sqlite/
lib.rs

1use anyhow::Context;
2use config::{any_err, from_lua_value, get_or_create_module};
3use mlua::{Lua, LuaSerdeExt, MultiValue, UserData, UserDataMethods, Value};
4use serde_json::{Map, Value as JsonValue};
5use sqlite::{Connection, ConnectionThreadSafe, ParameterIndex, State, Statement, Type};
6use std::sync::{Arc, Mutex};
7
8fn bind_param<I: ParameterIndex>(
9    stmt: &mut Statement,
10    index: I,
11    value: &JsonValue,
12) -> anyhow::Result<()> {
13    match value {
14        JsonValue::Null => stmt.bind((index, ()))?,
15        JsonValue::Number(n) => {
16            if let Some(i) = n.as_i64() {
17                stmt.bind((index, i))?
18            } else if let Some(f) = n.as_f64() {
19                stmt.bind((index, f))?
20            } else {
21                anyhow::bail!("numeric value {n} is out of range for sqlite");
22            }
23        }
24        JsonValue::String(s) => stmt.bind((index, s.as_str()))?,
25        _ => {
26            anyhow::bail!("only numbers, strings and nil can be passed as parameter values");
27        }
28    };
29    Ok(())
30}
31
32fn params_to_json(lua: &Lua, mut params: MultiValue) -> mlua::Result<JsonValue> {
33    match params.len() {
34        0 => Ok(JsonValue::Null),
35        1 => {
36            let param = params
37                .pop_front()
38                .expect("we checked and we have at least one");
39            let param: JsonValue = from_lua_value(lua, param)?;
40            Ok(param)
41        }
42        _ => {
43            let mut result = vec![];
44            for p in params {
45                let p: JsonValue = from_lua_value(lua, p)?;
46                result.push(p);
47            }
48            Ok(JsonValue::Array(result))
49        }
50    }
51}
52
53fn bind_params(stmt: &mut Statement, params: &JsonValue) -> anyhow::Result<()> {
54    match params {
55        JsonValue::Object(obj) => {
56            for (name, value) in obj.iter() {
57                bind_param(stmt, format!(":{name}").as_str(), value)
58                    .with_context(|| format!("binding parameter :{name} with value {value:?}"))?;
59            }
60            Ok(())
61        }
62        JsonValue::Array(arr) => {
63            for (i, value) in arr.iter().enumerate() {
64                // Parameter indices are 1-based
65                let i = i + 1;
66                bind_param(stmt, i, value)
67                    .with_context(|| format!("binding parameter {i} with value {value:?}"))?;
68            }
69            Ok(())
70        }
71        JsonValue::Null => Ok(()),
72        p => bind_param(stmt, 1, p)
73            .with_context(|| format!("binding sole parameter with value {p:?}")),
74    }
75}
76
77fn get_column(stmt: &Statement, index: usize) -> anyhow::Result<JsonValue> {
78    match stmt.column_type(index)? {
79        Type::Binary | Type::String => {
80            let s: String = stmt.read(index).map_err(any_err)?;
81            Ok(s.into())
82        }
83        Type::Integer => {
84            let i: i64 = stmt.read(index)?;
85            Ok(i.into())
86        }
87        Type::Float => {
88            let f: f64 = stmt.read(index)?;
89            Ok(f.into())
90        }
91        Type::Null => Ok(JsonValue::Null),
92    }
93}
94
95#[derive(Clone)]
96struct Conn(Arc<Mutex<Option<Arc<ConnectionThreadSafe>>>>);
97
98impl Conn {
99    fn get_conn(&self) -> anyhow::Result<Arc<ConnectionThreadSafe>> {
100        self.0
101            .lock()
102            .unwrap()
103            .as_ref()
104            .map(Arc::clone)
105            .ok_or_else(|| anyhow::anyhow!("connection was closed"))
106    }
107
108    // Sqlite queries are blocking and we cannot safely block an async
109    // function, so we push the work over to this blocking function
110    // via spawn_blocking.
111    fn execute(&self, sql: String, params: JsonValue) -> anyhow::Result<JsonValue> {
112        let conn = self.get_conn()?;
113        let mut stmt = conn.prepare(&sql)?;
114        bind_params(&mut stmt, &params)
115            .with_context(|| format!("bind parameters {params:?} in query `{sql}'"))?;
116
117        let state = stmt.next()?;
118        if state == State::Done && stmt.column_count() == 0 {
119            // Query cannot return any rows, so we'll return
120            // the affected row count
121            return Ok(conn.change_count().into());
122        }
123
124        let mut table = vec![];
125        // Query has rows. Decide whether we are returning a simple
126        // array of single column results, or an array of objects
127        let col_count = stmt.column_count();
128        if col_count == 1 {
129            loop {
130                let value = get_column(&mut stmt, 0)?;
131                table.push(value);
132
133                if stmt.next()? == State::Done {
134                    break;
135                }
136            }
137        } else {
138            loop {
139                let mut obj = Map::new();
140                let col_names = stmt.column_names();
141                for i in 0..col_count {
142                    let value = get_column(&stmt, i)?;
143                    obj.insert(col_names[i].to_string(), value);
144                }
145                table.push(JsonValue::Object(obj));
146
147                if stmt.next()? == State::Done {
148                    break;
149                }
150            }
151        }
152
153        Ok(JsonValue::Array(table))
154    }
155
156    async fn async_execute(self, sql: String, params: JsonValue) -> anyhow::Result<JsonValue> {
157        tokio::task::Builder::new()
158            .name(&format!("sqlite {sql}"))
159            .spawn_blocking(move || -> anyhow::Result<JsonValue> { self.execute(sql, params) })
160            .map_err(any_err)?
161            .await
162            .map_err(any_err)?
163    }
164}
165
166impl UserData for Conn {
167    fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
168        methods.add_async_method(
169            "execute",
170            |lua, this, (sql, params): (String, MultiValue)| async move {
171                let json_params = params_to_json(&lua, params)?;
172                let result = this
173                    .clone()
174                    .async_execute(sql, json_params)
175                    .await
176                    .map_err(any_err)?;
177
178                let result: Value = lua
179                    .to_value_with(&result, config::serialize_options())
180                    .map_err(any_err)?;
181                Ok(result)
182            },
183        );
184
185        methods.add_method("close", |_lua, this, _: ()| {
186            this.0.lock().unwrap().take();
187            Ok(())
188        });
189    }
190}
191
192pub fn register(lua: &Lua) -> anyhow::Result<()> {
193    let sqlite_mod = get_or_create_module(lua, "sqlite")?;
194
195    sqlite_mod.set(
196        "open",
197        lua.create_function(move |_, (path, busy_timeout): (String, Option<usize>)| {
198            let mut db = Connection::open_thread_safe(path).map_err(any_err)?;
199            db.set_busy_timeout(busy_timeout.unwrap_or(500))
200                .map_err(any_err)?;
201            Ok(Conn(Arc::new(Mutex::new(Some(Arc::new(db))))))
202        })?,
203    )?;
204
205    Ok(())
206}