1use anyhow::Context;
2use config::{any_err, from_lua_value, get_or_create_module};
3use mlua::{Lua, LuaSerdeExt, MultiValue, UserData, UserDataMethods, Value};
4use serde_json::{Map, Value as JsonValue};
5use sqlite::{Connection, ConnectionThreadSafe, ParameterIndex, State, Statement, Type};
6use std::sync::{Arc, Mutex};
7
8fn bind_param<I: ParameterIndex>(
9 stmt: &mut Statement,
10 index: I,
11 value: &JsonValue,
12) -> anyhow::Result<()> {
13 match value {
14 JsonValue::Null => stmt.bind((index, ()))?,
15 JsonValue::Number(n) => {
16 if let Some(i) = n.as_i64() {
17 stmt.bind((index, i))?
18 } else if let Some(f) = n.as_f64() {
19 stmt.bind((index, f))?
20 } else {
21 anyhow::bail!("numeric value {n} is out of range for sqlite");
22 }
23 }
24 JsonValue::String(s) => stmt.bind((index, s.as_str()))?,
25 _ => {
26 anyhow::bail!("only numbers, strings and nil can be passed as parameter values");
27 }
28 };
29 Ok(())
30}
31
32fn params_to_json(lua: &Lua, mut params: MultiValue) -> mlua::Result<JsonValue> {
33 match params.len() {
34 0 => Ok(JsonValue::Null),
35 1 => {
36 let param = params
37 .pop_front()
38 .expect("we checked and we have at least one");
39 let param: JsonValue = from_lua_value(lua, param)?;
40 Ok(param)
41 }
42 _ => {
43 let mut result = vec![];
44 for p in params {
45 let p: JsonValue = from_lua_value(lua, p)?;
46 result.push(p);
47 }
48 Ok(JsonValue::Array(result))
49 }
50 }
51}
52
53fn bind_params(stmt: &mut Statement, params: &JsonValue) -> anyhow::Result<()> {
54 match params {
55 JsonValue::Object(obj) => {
56 for (name, value) in obj.iter() {
57 bind_param(stmt, format!(":{name}").as_str(), value)
58 .with_context(|| format!("binding parameter :{name} with value {value:?}"))?;
59 }
60 Ok(())
61 }
62 JsonValue::Array(arr) => {
63 for (i, value) in arr.iter().enumerate() {
64 let i = i + 1;
66 bind_param(stmt, i, value)
67 .with_context(|| format!("binding parameter {i} with value {value:?}"))?;
68 }
69 Ok(())
70 }
71 JsonValue::Null => Ok(()),
72 p => bind_param(stmt, 1, p)
73 .with_context(|| format!("binding sole parameter with value {p:?}")),
74 }
75}
76
77fn get_column(stmt: &Statement, index: usize) -> anyhow::Result<JsonValue> {
78 match stmt.column_type(index)? {
79 Type::Binary | Type::String => {
80 let s: String = stmt.read(index).map_err(any_err)?;
81 Ok(s.into())
82 }
83 Type::Integer => {
84 let i: i64 = stmt.read(index)?;
85 Ok(i.into())
86 }
87 Type::Float => {
88 let f: f64 = stmt.read(index)?;
89 Ok(f.into())
90 }
91 Type::Null => Ok(JsonValue::Null),
92 }
93}
94
95#[derive(Clone)]
96struct Conn(Arc<Mutex<Option<Arc<ConnectionThreadSafe>>>>);
97
98impl Conn {
99 fn get_conn(&self) -> anyhow::Result<Arc<ConnectionThreadSafe>> {
100 self.0
101 .lock()
102 .unwrap()
103 .as_ref()
104 .map(Arc::clone)
105 .ok_or_else(|| anyhow::anyhow!("connection was closed"))
106 }
107
108 fn execute(&self, sql: String, params: JsonValue) -> anyhow::Result<JsonValue> {
112 let conn = self.get_conn()?;
113 let mut stmt = conn.prepare(&sql)?;
114 bind_params(&mut stmt, ¶ms)
115 .with_context(|| format!("bind parameters {params:?} in query `{sql}'"))?;
116
117 let state = stmt.next()?;
118 if state == State::Done && stmt.column_count() == 0 {
119 return Ok(conn.change_count().into());
122 }
123
124 let mut table = vec![];
125 let col_count = stmt.column_count();
128 if col_count == 1 {
129 loop {
130 let value = get_column(&mut stmt, 0)?;
131 table.push(value);
132
133 if stmt.next()? == State::Done {
134 break;
135 }
136 }
137 } else {
138 loop {
139 let mut obj = Map::new();
140 let col_names = stmt.column_names();
141 for i in 0..col_count {
142 let value = get_column(&stmt, i)?;
143 obj.insert(col_names[i].to_string(), value);
144 }
145 table.push(JsonValue::Object(obj));
146
147 if stmt.next()? == State::Done {
148 break;
149 }
150 }
151 }
152
153 Ok(JsonValue::Array(table))
154 }
155
156 async fn async_execute(self, sql: String, params: JsonValue) -> anyhow::Result<JsonValue> {
157 tokio::task::Builder::new()
158 .name(&format!("sqlite {sql}"))
159 .spawn_blocking(move || -> anyhow::Result<JsonValue> { self.execute(sql, params) })
160 .map_err(any_err)?
161 .await
162 .map_err(any_err)?
163 }
164}
165
166impl UserData for Conn {
167 fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
168 methods.add_async_method(
169 "execute",
170 |lua, this, (sql, params): (String, MultiValue)| async move {
171 let json_params = params_to_json(&lua, params)?;
172 let result = this
173 .clone()
174 .async_execute(sql, json_params)
175 .await
176 .map_err(any_err)?;
177
178 let result: Value = lua
179 .to_value_with(&result, config::serialize_options())
180 .map_err(any_err)?;
181 Ok(result)
182 },
183 );
184
185 methods.add_method("close", |_lua, this, _: ()| {
186 this.0.lock().unwrap().take();
187 Ok(())
188 });
189 }
190}
191
192pub fn register(lua: &Lua) -> anyhow::Result<()> {
193 let sqlite_mod = get_or_create_module(lua, "sqlite")?;
194
195 sqlite_mod.set(
196 "open",
197 lua.create_function(move |_, (path, busy_timeout): (String, Option<usize>)| {
198 let mut db = Connection::open_thread_safe(path).map_err(any_err)?;
199 db.set_busy_timeout(busy_timeout.unwrap_or(500))
200 .map_err(any_err)?;
201 Ok(Conn(Arc::new(Mutex::new(Some(Arc::new(db))))))
202 })?,
203 )?;
204
205 Ok(())
206}