data_loader/
lib.rs

1#[cfg(feature = "impl")]
2use anyhow::{anyhow, Context};
3#[cfg(feature = "impl")]
4use config::{any_err, from_lua_value, get_or_create_sub_module};
5#[cfg(feature = "impl")]
6use mlua::Lua;
7use serde::{Deserialize, Serialize};
8#[cfg(feature = "impl")]
9use vaultrs::client::{VaultClient, VaultClientSettingsBuilder};
10
11#[derive(Deserialize, Serialize, Clone, Hash, PartialEq, Eq, Debug)]
12#[serde(untagged)]
13pub enum KeySource {
14    File(String),
15    Data {
16        #[serde(with = "serde_bytes")]
17        key_data: Vec<u8>,
18    },
19    Vault {
20        vault_address: Option<String>,
21        vault_token: Option<String>,
22        vault_mount: String,
23        vault_path: String,
24        #[serde(default = "default_vault_key")]
25        vault_key: String,
26    },
27    Event {
28        event_name: String,
29        event_args: Vec<serde_json::Value>,
30    },
31}
32
/// Serde default for the `vault_key` field of `KeySource::Vault`: the name of
/// the field read from the Vault secret when the configuration does not say.
fn default_vault_key() -> String {
    String::from("key")
}
36
#[cfg(feature = "impl")]
impl KeySource {
    /// Resolve this source into the raw secret bytes.
    ///
    /// * [`KeySource::File`] reads the whole file at the given path.
    /// * [`KeySource::Data`] returns a copy of the inline bytes.
    /// * [`KeySource::Vault`] reads the KV v2 secret at `vault_mount`/`vault_path`.
    ///   `$VAULT_ADDR` / `$VAULT_TOKEN` are used when the address or token is
    ///   not given. It returns the UTF-8 bytes of the string stored under
    ///   `vault_key`.
    /// * [`KeySource::Event`] loads the configuration, calls the `event_name`
    ///   callback with `event_args`, and returns the bytes of the string it yields.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// * the file cannot be read;
    /// * the Vault address or token is neither configured nor in the environment;
    /// * the Vault client cannot be built or the read fails;
    /// * the secret has no `vault_key` field, or that field is not a string;
    /// * the configuration or callback fails.
    ///
    /// Error messages never include the Vault token.
    pub async fn get(&self) -> anyhow::Result<Vec<u8>> {
        match self {
            Self::File(path) => tokio::fs::read(path)
                .await
                .with_context(|| format!("KeySource failed to load data from file `{path}`")),
            Self::Data { key_data } => Ok(key_data.clone()),
            Self::Vault {
                vault_address,
                vault_token,
                vault_mount,
                vault_path,
                vault_key,
            } => {
                // Identifies the secret in error messages. Deliberately built from
                // non-secret fields only: `vault_token` must never end up in
                // errors or logs.
                let location = format!("vault_mount={vault_mount}, vault_path={vault_path}");

                let address = match vault_address {
                    Some(a) => a.clone(),
                    None => std::env::var("VAULT_ADDR").map_err(|err| {
                        anyhow!(
                            "vault_address was not specified and $VAULT_ADDR is not set/usable: {location} {err:#}"
                        )
                    })?,
                };
                let token = match vault_token {
                    Some(t) => t.clone(),
                    None => std::env::var("VAULT_TOKEN").map_err(|err| {
                        anyhow!(
                            "vault_token was not specified and $VAULT_TOKEN is not set/usable: {location} {err:#}"
                        )
                    })?,
                };

                let client = VaultClient::new(
                    VaultClientSettingsBuilder::default()
                        .address(address)
                        .token(token)
                        .build()?,
                )?;

                let entry: serde_json::Value = vaultrs::kv2::read(&client, vault_mount, vault_path)
                    .await
                    .with_context(|| format!("kv2::read {location}"))?;

                let value = entry.get(vault_key.as_str()).ok_or_else(|| {
                    anyhow!("vault secret at {vault_path} does not contain key '{vault_key}'")
                })?;
                // The field exists but may hold a number, object, etc.; report that
                // precisely rather than claiming the key is missing.
                let value = value.as_str().ok_or_else(|| {
                    anyhow!(
                        "vault secret at {vault_path} has a non-string value for key '{vault_key}'"
                    )
                })?;

                Ok(value.as_bytes().to_vec())
            }
            Self::Event {
                event_name,
                event_args,
            } => {
                let mut config = config::load_config().await?;
                let sig = config::CallbackSignature::<mlua::MultiValue, mlua::String>::new(
                    event_name.clone(),
                );

                let args = config.convert_args_to_multi(event_args)?;
                let result = config.async_call_callback_non_default(&sig, args).await?;

                Ok(result.as_bytes().to_vec())
            }
        }
    }
}
110
#[cfg(feature = "impl")]
/// Install the `secrets` sub-module into `lua` with a single async function.
///
/// `load(source)` converts its Lua argument into a [`KeySource`], resolves it
/// with [`KeySource::get`], and returns the resulting bytes as a Lua string.
/// Presumably this is reachable from Lua as `kumo.secrets.load`; that depends
/// on `get_or_create_sub_module`, so confirm there.
pub fn register(lua: &Lua) -> anyhow::Result<()> {
    let secrets = get_or_create_sub_module(lua, "secrets")?;

    let load = lua.create_async_function(|lua, source: mlua::Value| async move {
        let key_source: KeySource = from_lua_value(&lua, source)?;
        let bytes = key_source.get().await.map_err(any_err)?;
        lua.create_string(&bytes)
    })?;
    secrets.set("load", load)?;

    Ok(())
}
125
#[cfg(test)]
#[cfg(feature = "impl")]
mod test {
    use super::*;
    use anyhow::Context;
    use std::process::Stdio;
    use std::time::Duration;
    use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt};
    use tokio::process::{Child, Command};
    use tokio::time::timeout;
    use vaultrs::client::Client;

    /// Ask the kernel to assign a free port.
    /// We pass this to `vault server -dev` and tell it to listen on that port.
    /// This is racy, as releasing the socket technically makes
    /// that port available to others using the same technique;
    /// `VaultServer::spawn` retries to paper over that.
    fn allocate_port() -> u16 {
        let listener =
            std::net::TcpListener::bind("127.0.0.1:0").expect("bind 127.0.0.1:0 failed");
        listener.local_addr().unwrap().port()
    }

    /// Root token id handed to the dev server; also used as the client token.
    const KEY: &str = "woot";

    /// Private DKIM key shipped with the repository, used as file-based test data.
    const PEM_PATH: &str = "../../example-private-dkim-key.pem";

    /// A throwaway `vault server -dev` instance; the process is killed on drop.
    struct VaultServer {
        port: u16,
        _daemon: Child,
    }

    impl VaultServer {
        /// Start a dev server on a kernel-chosen port.
        ///
        /// Makes up to two attempts, each bounded by a 5 second timeout. The
        /// error from every failed attempt is reported.
        pub async fn spawn() -> anyhow::Result<Self> {
            let mut errors = vec![];

            for _ in 0..2 {
                let port = allocate_port();
                match timeout(Duration::from_secs(5), Self::spawn_with_port(port)).await {
                    Ok(Ok(me)) => return Ok(me),
                    Ok(Err(err)) => errors.push(format!("{err:#}")),
                    // A timeout is just another failed attempt: record it and
                    // retry instead of aborting the whole spawn.
                    Err(_) => errors.push(format!("timed out starting vault on port {port}")),
                }
            }
            anyhow::bail!("failed to spawn vault: {}", errors.join(". "));
        }

        /// Base URL of this server.
        pub fn address(&self) -> String {
            Self::address_for_port(self.port)
        }

        /// Base URL of a server listening on `port` on the loopback interface.
        fn address_for_port(port: u16) -> String {
            format!("http://127.0.0.1:{port}")
        }

        /// Launch the daemon on `port` and wait until it reports itself healthy.
        async fn spawn_with_port(port: u16) -> anyhow::Result<Self> {
            eprintln!("Trying to start vault on port {port}");

            let mut daemon = Command::new("vault")
                .args([
                    "server",
                    "-dev",
                    &format!("-dev-listen-address=127.0.0.1:{port}"),
                    &format!("-dev-root-token-id={KEY}"),
                ])
                .stdin(Stdio::piped())
                .stderr(Stdio::piped())
                .stdout(Stdio::piped())
                .kill_on_drop(true)
                .spawn()
                .context("spawning vault")?;

            // Forward the daemon's output to our stderr so it shows up in test logs.
            let mut stderr = daemon.stderr.take().expect("stderr is piped");
            tokio::spawn(async move {
                copy_stream_with_line_prefix("vault stderr", &mut stderr, &mut tokio::io::stderr())
                    .await
            });
            let mut stdout = daemon.stdout.take().expect("stdout is piped");
            tokio::spawn(async move {
                copy_stream_with_line_prefix("vault stdout", &mut stdout, &mut tokio::io::stderr())
                    .await
            });

            let client = VaultClient::new(
                VaultClientSettingsBuilder::default()
                    .address(Self::address_for_port(port))
                    .token(KEY)
                    .build()?,
            )?;

            for _ in 0..25 {
                // A daemon that could not bind its port exits almost immediately;
                // notice that right away instead of polling until we give up.
                if let Ok(Some(status)) = daemon.try_wait() {
                    anyhow::bail!("daemon exited already: {status:?}");
                }
                let status = client.status().await;
                eprintln!("checking status: {status:?}");
                if let Ok(vaultrs::sys::ServerStatus::OK) = status {
                    // Re-check: another process may have won the port race and be
                    // the one answering while our daemon is on its way out.
                    if let Ok(Some(status)) = daemon.try_wait() {
                        anyhow::bail!("daemon exited already: {status:?}");
                    }
                    return Ok(Self {
                        _daemon: daemon,
                        port,
                    });
                }
                tokio::time::sleep(Duration::from_millis(100)).await;
            }

            anyhow::bail!("server didn't startup successfully");
        }

        /// Run `vault kv put` against this server on the `secret` mount.
        ///
        /// `field` is the CLI's `name=value` or `name=@file` argument. `label`
        /// prefixes any output the CLI produced.
        async fn kv_put(&self, label: &str, vault_path: &str, field: &str) -> anyhow::Result<()> {
            let output = Command::new("vault")
                .args([
                    "kv",
                    "put",
                    &format!("-address={}", self.address()),
                    "-mount=secret",
                    vault_path,
                    field,
                ])
                // Authenticate explicitly. Otherwise the CLI uses whatever
                // `$VAULT_TOKEN` the developer has set, or a `~/.vault-token`
                // that another dev server may have overwritten.
                .env("VAULT_TOKEN", KEY)
                .output()
                .await
                .context("running vault kv put")?;

            for stream in [&output.stdout, &output.stderr] {
                let text = String::from_utf8_lossy(stream);
                if !text.is_empty() {
                    eprintln!("{label}: {text}");
                }
            }
            anyhow::ensure!(output.status.success(), "{:?}", output.status);
            Ok(())
        }

        /// Store the contents of the file at `path` as field `key` of `vault_path`.
        pub async fn put_from_file(&self, vault_path: &str, path: &str) -> anyhow::Result<()> {
            self.kv_put("put_from_file", vault_path, &format!("key=@{path}"))
                .await
        }

        /// Store `value` as field `key` of `vault_path`.
        pub async fn put(&self, vault_path: &str, value: &str) -> anyhow::Result<()> {
            self.kv_put("put", vault_path, &format!("key={value}")).await
        }

        /// A `KeySource` that reads field `key` of `path` from this server.
        pub fn make_source(&self, path: &str) -> KeySource {
            KeySource::Vault {
                vault_address: Some(self.address()),
                vault_token: Some(KEY.to_string()),
                vault_mount: "secret".to_string(),
                vault_path: path.to_string(),
                vault_key: "key".to_string(),
            }
        }
    }

    /// Copy `src` to `dest` line by line, prefixing each line with `prefix: `,
    /// until `src` reaches end-of-file.
    async fn copy_stream_with_line_prefix<SRC, DEST>(
        prefix: &str,
        src: SRC,
        mut dest: DEST,
    ) -> std::io::Result<()>
    where
        SRC: AsyncRead + Unpin,
        DEST: AsyncWrite + Unpin,
    {
        let mut src = tokio::io::BufReader::new(src);
        let mut line = String::new();
        loop {
            line.clear();
            // `read_line` returns 0 only at EOF, meaning the child closed the pipe
            // (typically because it exited). Stop there; continuing would spin
            // forever re-reading an empty line.
            if src.read_line(&mut line).await? == 0 {
                return Ok(());
            }
            dest.write_all(format!("{prefix}: {line}").as_bytes())
                .await?;
        }
    }

    #[tokio::test]
    async fn test_vault() -> anyhow::Result<()> {
        if which::which("vault").is_err() {
            eprintln!("`vault` not found in PATH; skipping test_vault");
            return Ok(());
        }
        let vault = VaultServer::spawn().await?;

        vault.put_from_file("example.com", PEM_PATH).await?;

        let source = vault.make_source("example.com");
        let data = source.get().await?;

        assert_eq!(data, std::fs::read(PEM_PATH).unwrap());

        vault.put("foo", "bar").await?;

        let source = vault.make_source("foo");
        let data = source.get().await?;

        assert_eq!(data, b"bar");

        let lua = Lua::new();
        register(&lua)?;
        lua.globals().set("ADDR", vault.address())?;
        lua.globals().set("KEY", KEY)?;
        let pw: String = lua
            .load(
                r#"
            local kumo = require 'kumo';
            print("ADDR", ADDR)
            print("KEY", KEY)
            return kumo.secrets.load {
                vault_path = "foo",
                vault_address = ADDR,
                vault_mount = "secret",
                vault_token = KEY,
            }
            "#,
            )
            .eval_async()
            .await
            .unwrap();

        assert_eq!(pw, "bar");

        // Test with a different key name
        vault.put("custom_key", "custom_value").await?;

        let source = KeySource::Vault {
            vault_address: Some(vault.address()),
            vault_token: Some(KEY.to_string()),
            vault_mount: "secret".to_string(),
            vault_path: "custom_key".to_string(),
            vault_key: "custom_field".to_string(),
        };

        // This should fail because the vault secret has "key" but we're looking for "custom_field"
        let err = source.get().await.unwrap_err();
        assert!(err
            .to_string()
            .contains("vault secret at custom_key does not contain key 'custom_field'"));

        // Test with the correct key name
        let source = KeySource::Vault {
            vault_address: Some(vault.address()),
            vault_token: Some(KEY.to_string()),
            vault_mount: "secret".to_string(),
            vault_path: "custom_key".to_string(),
            vault_key: "key".to_string(),
        };

        let data = source.get().await?;
        assert_eq!(data, b"custom_value");

        Ok(())
    }
}