data_loader/
lib.rs

1#[cfg(feature = "impl")]
2use anyhow::{anyhow, Context};
3#[cfg(feature = "impl")]
4use config::{any_err, from_lua_value, get_or_create_sub_module};
5#[cfg(feature = "impl")]
6use mlua::Lua;
7use serde::{Deserialize, Serialize};
8#[cfg(feature = "impl")]
9use vaultrs::client::{VaultClient, VaultClientSettingsBuilder};
10
11#[derive(Deserialize, Serialize, Clone, Hash, PartialEq, Eq, Debug)]
12#[serde(untagged)]
13pub enum KeySource {
14    File(String),
15    Data {
16        key_data: String,
17    },
18    Vault {
19        vault_address: Option<String>,
20        vault_token: Option<String>,
21        vault_mount: String,
22        vault_path: String,
23        #[serde(default = "default_vault_key")]
24        vault_key: String,
25    },
26}
27
/// Serde default for `KeySource::Vault`'s `vault_key`: the name of the secret
/// field that is read when no field name is configured.
fn default_vault_key() -> String {
    String::from("key")
}
31
#[cfg(feature = "impl")]
impl KeySource {
    /// Resolve this source to the raw key bytes.
    ///
    /// - `File`: reads the whole file at the given path.
    /// - `Data`: returns the UTF-8 bytes of `key_data`.
    /// - `Vault`: reads the KV v2 secret at `vault_mount`/`vault_path` and
    ///   returns the string stored under `vault_key`. A missing
    ///   `vault_address` or `vault_token` falls back to `$VAULT_ADDR` /
    ///   `$VAULT_TOKEN`.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - the file cannot be read;
    /// - a Vault address or token is neither configured nor available from
    ///   the environment;
    /// - building the Vault client or the `kv2::read` request fails;
    /// - the secret has no `vault_key` field, or that field is not a string.
    ///
    /// Error messages identify the secret by mount and path only. They never
    /// include the Vault token or any key material.
    pub async fn get(&self) -> anyhow::Result<Vec<u8>> {
        match self {
            Self::File(path) => tokio::fs::read(path)
                .await
                .with_context(|| format!("failed to read key file {path}")),
            Self::Data { key_data } => Ok(key_data.as_bytes().to_vec()),
            Self::Vault {
                vault_address,
                vault_token,
                vault_mount,
                vault_path,
                vault_key,
            } => {
                // Identify the secret without formatting `self`, so the token
                // can never end up in an error message.
                let location = format!("vault_mount={vault_mount}, vault_path={vault_path}");

                let address = match vault_address {
                    Some(a) => a.to_string(),
                    None => std::env::var("VAULT_ADDR").map_err(|err| {
                        anyhow!(
                            "vault_address was not specified and $VAULT_ADDR is not set/usable: {location} {err:#}"
                        )
                    })?,
                };
                let token = match vault_token {
                    Some(a) => a.to_string(),
                    None => std::env::var("VAULT_TOKEN").map_err(|err| {
                        anyhow!(
                            "vault_token was not specified and $VAULT_TOKEN is not set/usable: {location} {err:#}"
                        )
                    })?,
                };

                let client = VaultClient::new(
                    VaultClientSettingsBuilder::default()
                        .address(address)
                        .token(token)
                        .build()?,
                )?;

                let entry: serde_json::Value = vaultrs::kv2::read(&client, vault_mount, vault_path)
                    .await
                    .with_context(|| format!("kv2::read {location}"))?;

                let value = match entry.get(vault_key.as_str()) {
                    Some(serde_json::Value::String(value)) => value,
                    // Deliberately don't echo the value: it is secret material.
                    Some(_) => anyhow::bail!(
                        "vault secret at {vault_path} key '{vault_key}' is not a string"
                    ),
                    None => anyhow::bail!(
                        "vault secret at {vault_path} does not contain key '{vault_key}'"
                    ),
                };

                Ok(value.as_bytes().to_vec())
            }
        }
    }
}
89
#[cfg(feature = "impl")]
/// Register `load` in the Lua `secrets` sub-module (reachable from Lua as
/// `kumo.secrets.load`).
///
/// The Lua argument is converted into a [`KeySource`] and resolved with
/// [`KeySource::get`]. The resulting bytes are handed back as a Lua string.
/// Resolution failures are converted with `any_err` and surface as Lua errors.
pub fn register(lua: &Lua) -> anyhow::Result<()> {
    let secrets = get_or_create_sub_module(lua, "secrets")?;

    let load = lua.create_async_function(|lua, source: mlua::Value| async move {
        let key_source: KeySource = from_lua_value(&lua, source)?;
        let key_bytes = key_source.get().await.map_err(any_err)?;
        lua.create_string(&key_bytes)
    })?;
    secrets.set("load", load)?;

    Ok(())
}
104
#[cfg(test)]
#[cfg(feature = "impl")]
mod test {
    use super::*;
    use anyhow::Context;
    use std::process::Stdio;
    use std::time::Duration;
    use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt};
    use tokio::process::{Child, Command};
    use tokio::time::timeout;
    use vaultrs::client::Client;

    /// Ask the kernel to assign a free port.
    /// We pass this to the vault dev server and tell it to listen on that port.
    /// This is racy, as releasing the socket technically makes
    /// that port available to others using the same technique;
    /// `VaultServer::spawn` retries on a new port to compensate.
    fn allocate_port() -> u16 {
        let listener = std::net::TcpListener::bind("127.0.0.1:0").expect("bind 127.0.0.1:0 failed");
        listener.local_addr().unwrap().port()
    }

    /// Root token of the dev server; the tests also use it as the client token.
    const KEY: &str = "woot";

    /// A `vault server -dev` child process listening on `127.0.0.1:port`.
    /// The process is killed when this value is dropped (`kill_on_drop`).
    struct VaultServer {
        port: u16,
        _daemon: Child,
    }

    impl VaultServer {
        /// Start a dev server. Makes up to two attempts, each on a freshly
        /// allocated port and each bounded to 5 seconds.
        pub async fn spawn() -> anyhow::Result<Self> {
            let mut errors = vec![];

            for _ in 0..2 {
                let port = allocate_port();
                match timeout(Duration::from_secs(5), Self::spawn_with_port(port)).await {
                    Ok(Ok(me)) => return Ok(me),
                    Ok(Err(err)) => errors.push(format!("{err:#}")),
                    // A timeout counts as a failed attempt instead of aborting
                    // the retry loop. Dropping the timed-out future drops its
                    // Child, which kill_on_drop terminates.
                    Err(_) => errors.push(format!("timed out starting vault on port {port}")),
                }
            }
            anyhow::bail!("failed to spawn vault: {}", errors.join(". "));
        }

        /// Base URL of this server.
        pub fn address(&self) -> String {
            format!("http://127.0.0.1:{}", self.port)
        }

        /// Launch vault on `port`, forward its output to our stderr, and poll
        /// its status (up to 25 times, 100ms apart) until it reports OK.
        async fn spawn_with_port(port: u16) -> anyhow::Result<Self> {
            eprintln!("Trying to start vault on port {port}");

            let mut daemon = Command::new("vault")
                .args([
                    "server",
                    "-dev",
                    &format!("-dev-listen-address=127.0.0.1:{port}"),
                    &format!("-dev-root-token-id={KEY}"),
                ])
                .stdin(Stdio::piped())
                .stderr(Stdio::piped())
                .stdout(Stdio::piped())
                .kill_on_drop(true)
                .spawn()
                .context("spawning vault")?;

            let mut stderr = daemon.stderr.take().unwrap();
            tokio::spawn(async move {
                copy_stream_with_line_prefix("vault stderr", &mut stderr, &mut tokio::io::stderr())
                    .await
            });
            let mut stdout = daemon.stdout.take().unwrap();
            tokio::spawn(async move {
                copy_stream_with_line_prefix("vault stdout", &mut stdout, &mut tokio::io::stderr())
                    .await
            });

            // One client is enough for all the polling attempts.
            let client = VaultClient::new(
                VaultClientSettingsBuilder::default()
                    .address(format!("http://127.0.0.1:{port}"))
                    .token(KEY)
                    .build()?,
            )?;

            let mut ok = false;
            for _ in 0..25 {
                // Fail fast if vault died (e.g. lost the race for the port)
                // instead of polling a dead address until we give up.
                if let Ok(Some(status)) = daemon.try_wait() {
                    anyhow::bail!("daemon exited already: {status:?}");
                }
                let status = client.status().await;
                eprintln!("checking status: {status:?}");
                if let Ok(vaultrs::sys::ServerStatus::OK) = status {
                    ok = true;
                    break;
                }
                tokio::time::sleep(Duration::from_millis(100)).await;
            }

            anyhow::ensure!(ok, "server didn't startup successfully");

            if let Ok(Some(status)) = daemon.try_wait() {
                anyhow::bail!("daemon exited already: {status:?}");
            }

            Ok(Self {
                _daemon: daemon,
                port,
            })
        }

        /// Run `vault kv put -mount=secret <vault_path> <assignment>` against
        /// this server. Any CLI output is echoed to stderr prefixed by `label`.
        async fn kv_put(&self, label: &str, vault_path: &str, assignment: &str) -> anyhow::Result<()> {
            let output = Command::new("vault")
                .args([
                    "kv",
                    "put",
                    &format!("-address={}", self.address()),
                    "-mount=secret",
                    vault_path,
                    assignment,
                ])
                // Authenticate explicitly so the CLI does not depend on ambient
                // credentials (an inherited $VAULT_TOKEN or a token file).
                .env("VAULT_TOKEN", KEY)
                .output()
                .await?;
            let stdout = String::from_utf8_lossy(&output.stdout);
            if !stdout.is_empty() {
                eprintln!("{label}: {stdout}");
            }
            let stderr = String::from_utf8_lossy(&output.stderr);
            if !stderr.is_empty() {
                eprintln!("{label}: {stderr}");
            }
            anyhow::ensure!(output.status.success(), "{:?}", output.status);
            Ok(())
        }

        /// Store the contents of the local file `path` under field `key` of
        /// the secret at `vault_path`.
        pub async fn put_from_file(&self, vault_path: &str, path: &str) -> anyhow::Result<()> {
            self.kv_put("put_from_file", vault_path, &format!("key=@{path}"))
                .await
        }

        /// Store `value` under field `key` of the secret at `vault_path`.
        pub async fn put(&self, vault_path: &str, value: &str) -> anyhow::Result<()> {
            self.kv_put("put", vault_path, &format!("key={value}")).await
        }

        /// A `KeySource::Vault` pointing at field `key` of `path` on this server.
        pub fn make_source(&self, path: &str) -> KeySource {
            KeySource::Vault {
                vault_address: Some(self.address()),
                vault_token: Some(KEY.to_string()),
                vault_mount: "secret".to_string(),
                vault_path: path.to_string(),
                vault_key: "key".to_string(),
            }
        }
    }

    /// Copy `src` to `dest` line by line, prefixing each line with `"{prefix}: "`.
    /// Returns `Ok(())` once `src` reaches EOF (e.g. the child process exited).
    async fn copy_stream_with_line_prefix<SRC, DEST>(
        prefix: &str,
        src: SRC,
        mut dest: DEST,
    ) -> std::io::Result<()>
    where
        SRC: AsyncRead + Unpin,
        DEST: AsyncWrite + Unpin,
    {
        let mut src = tokio::io::BufReader::new(src);
        let mut line = String::new();
        loop {
            line.clear();
            // read_line reports EOF as Ok(0). Without this check the loop
            // would spin forever once the child closes its pipe.
            if src.read_line(&mut line).await? == 0 {
                return Ok(());
            }
            dest.write_all(format!("{prefix}: {line}").as_bytes())
                .await?;
        }
    }

    #[tokio::test]
    async fn test_vault() -> anyhow::Result<()> {
        if which::which("vault").is_err() {
            return Ok(());
        }
        let vault = VaultServer::spawn().await?;

        vault
            .put_from_file("example.com", "../../example-private-dkim-key.pem")
            .await?;

        let source = vault.make_source("example.com");
        let data = source.get().await?;

        assert_eq!(
            data,
            std::fs::read("../../example-private-dkim-key.pem").unwrap()
        );

        vault.put("foo", "bar").await?;

        let source = vault.make_source("foo");
        let data = source.get().await?;

        assert_eq!(data, b"bar");

        let lua = Lua::new();
        register(&lua).unwrap();
        lua.globals().set("ADDR", vault.address())?;
        lua.globals().set("KEY", KEY)?;
        let pw: String = lua
            .load(
                r#"
            local kumo = require 'kumo';
            print("ADDR", ADDR)
            print("KEY", KEY)
            return kumo.secrets.load {
                vault_path = "foo",
                vault_address = ADDR,
                vault_mount = "secret",
                vault_token = KEY,
            }
            "#,
            )
            .eval_async()
            .await
            .unwrap();

        assert_eq!(pw, "bar");

        // Test with a different key name
        vault.put("custom_key", "custom_value").await?;

        let source = KeySource::Vault {
            vault_address: Some(vault.address()),
            vault_token: Some(KEY.to_string()),
            vault_mount: "secret".to_string(),
            vault_path: "custom_key".to_string(),
            vault_key: "custom_field".to_string(),
        };

        // This should fail because the vault secret has "key" but we're looking for "custom_field"
        let result = source.get().await;
        let err = result.unwrap_err();
        assert!(err
            .to_string()
            .contains("vault secret at custom_key does not contain key 'custom_field'"));

        // Test with the correct key name
        let source = KeySource::Vault {
            vault_address: Some(vault.address()),
            vault_token: Some(KEY.to_string()),
            vault_mount: "secret".to_string(),
            vault_path: "custom_key".to_string(),
            vault_key: "key".to_string(),
        };

        let data = source.get().await?;
        assert_eq!(data, b"custom_value");

        Ok(())
    }
}