data_loader/
lib.rs

1#[cfg(feature = "impl")]
2use anyhow::{anyhow, Context};
3#[cfg(feature = "impl")]
4use config::{any_err, from_lua_value, get_or_create_sub_module};
5#[cfg(feature = "impl")]
6use mlua::Lua;
7use serde::{Deserialize, Serialize};
8#[cfg(feature = "impl")]
9use vaultrs::client::{VaultClient, VaultClientSettingsBuilder};
10
11#[derive(Deserialize, Serialize, Clone, Hash, PartialEq, Eq, Debug)]
12#[serde(untagged)]
13pub enum KeySource {
14    File(String),
15    Data {
16        key_data: String,
17    },
18    Vault {
19        vault_address: Option<String>,
20        vault_token: Option<String>,
21        vault_mount: String,
22        vault_path: String,
23    },
24}
25
#[cfg(feature = "impl")]
impl KeySource {
    /// Fetch the raw key bytes described by this source.
    ///
    /// * [`KeySource::File`]: the entire contents of the file at `path`.
    /// * [`KeySource::Data`]: the UTF-8 bytes of the inline `key_data`.
    /// * [`KeySource::Vault`]: the `key` field of the KV v2 secret stored at
    ///   `vault_path` under `vault_mount`. When `vault_address` or
    ///   `vault_token` is `None`, `$VAULT_ADDR` / `$VAULT_TOKEN` is used
    ///   instead.
    ///
    /// # Errors
    ///
    /// Fails in any of these cases:
    /// * the file cannot be read;
    /// * a required Vault setting is neither configured nor present in the
    ///   environment;
    /// * the Vault client cannot be constructed;
    /// * the secret cannot be read or has no `key` field.
    ///
    /// Error messages name the file or the Vault mount/path, but never include
    /// the Vault token.
    pub async fn get(&self) -> anyhow::Result<Vec<u8>> {
        match self {
            Self::File(path) => tokio::fs::read(path)
                .await
                .with_context(|| format!("reading key file {path}")),
            Self::Data { key_data } => Ok(key_data.as_bytes().to_vec()),
            Self::Vault {
                vault_address,
                vault_token,
                vault_mount,
                vault_path,
            } => {
                // Secrets are kept out of error messages, so only the mount
                // and path are used to identify the secret below.
                let location = format!("vault_mount={vault_mount}, vault_path={vault_path}");

                let address = Self::setting_or_env(
                    vault_address.as_deref(),
                    "vault_address",
                    "VAULT_ADDR",
                    &location,
                )?;
                let token = Self::setting_or_env(
                    vault_token.as_deref(),
                    "vault_token",
                    "VAULT_TOKEN",
                    &location,
                )?;

                let client = VaultClient::new(
                    VaultClientSettingsBuilder::default()
                        .address(address)
                        .token(token)
                        .build()?,
                )?;

                // Shape of the stored secret; only the `key` field is used.
                #[derive(Deserialize, Debug)]
                struct Entry {
                    key: String,
                }

                let entry: Entry = vaultrs::kv2::read(&client, vault_mount, vault_path)
                    .await
                    .with_context(|| format!("kv2::read {location}"))?;

                Ok(entry.key.into_bytes())
            }
        }
    }

    /// Return `configured` if present, otherwise the value of the environment
    /// variable `env_var`.
    ///
    /// `name` and `location` are used only to build the error message when
    /// neither source yields a value.
    fn setting_or_env(
        configured: Option<&str>,
        name: &str,
        env_var: &str,
        location: &str,
    ) -> anyhow::Result<String> {
        match configured {
            Some(value) => Ok(value.to_string()),
            None => std::env::var(env_var).map_err(|err| {
                anyhow!(
                    "{name} was not specified and ${env_var} is not set/usable: {location} {err:#}"
                )
            }),
        }
    }
}
80
/// Install `load` into the `secrets` Lua sub-module.
///
/// `load(source)` converts its Lua argument into a [`KeySource`], resolves
/// it with [`KeySource::get`], and returns the resulting bytes as a Lua
/// string. Resolution failures are surfaced to Lua as errors.
#[cfg(feature = "impl")]
pub fn register(lua: &Lua) -> anyhow::Result<()> {
    let secrets_mod = get_or_create_sub_module(lua, "secrets")?;

    let load = lua.create_async_function(|lua, source: mlua::Value| async move {
        let key_source: KeySource = from_lua_value(&lua, source)?;
        let key_bytes = key_source.get().await.map_err(any_err)?;
        lua.create_string(&key_bytes)
    })?;
    secrets_mod.set("load", load)?;

    Ok(())
}
95
#[cfg(test)]
#[cfg(feature = "impl")]
mod test {
    use super::*;
    use anyhow::Context;
    use std::process::Stdio;
    use std::time::Duration;
    use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt};
    use tokio::process::{Child, Command};
    use tokio::time::timeout;
    use vaultrs::client::Client;

    /// Root token id given to the dev-mode vault server; the tests also
    /// authenticate with it.
    const KEY: &str = "woot";

    /// Sample key fixture, relative to the test's working directory.
    const DKIM_KEY_PATH: &str = "../../example-private-dkim-key.pem";

    /// Ask the kernel to assign a free port.
    /// We pass this to vault and tell it to listen on that port.
    /// This is racy, as releasing the socket technically makes
    /// that port available to others using the same technique;
    /// `VaultServer::spawn` retries with a fresh port to compensate.
    fn allocate_port() -> u16 {
        let listener =
            std::net::TcpListener::bind("127.0.0.1:0").expect("bind 127.0.0.1:0 failed");
        listener.local_addr().unwrap().port()
    }

    /// A throwaway `vault server -dev` instance; the process is killed when
    /// this value is dropped (`kill_on_drop`).
    struct VaultServer {
        port: u16,
        _daemon: Child,
    }

    impl VaultServer {
        /// Start a dev-mode vault server on a free port, retrying once with a
        /// different port if the first attempt fails or times out.
        pub async fn spawn() -> anyhow::Result<Self> {
            let mut errors = vec![];

            for _ in 0..2 {
                let port = allocate_port();
                // A timed-out attempt is recorded and retried like any other
                // failure. Dropping its future drops the `Child`, which kills
                // the half-started daemon.
                match timeout(Duration::from_secs(5), Self::spawn_with_port(port)).await {
                    Ok(Ok(me)) => return Ok(me),
                    Ok(Err(err)) => errors.push(format!("{err:#}")),
                    Err(_) => errors.push(format!("timed out starting vault on port {port}")),
                }
            }
            anyhow::bail!("failed to spawn vault: {}", errors.join(". "));
        }

        /// Base URL of this server's HTTP API.
        pub fn address(&self) -> String {
            Self::address_for_port(self.port)
        }

        /// Base URL of a server listening on `port` on the loopback interface.
        fn address_for_port(port: u16) -> String {
            format!("http://127.0.0.1:{port}")
        }

        /// Launch vault on `port` and forward its output to our stderr.
        /// Then poll its status, up to 25 times at 100ms intervals, until it
        /// reports healthy.
        async fn spawn_with_port(port: u16) -> anyhow::Result<Self> {
            eprintln!("Trying to start vault on port {port}");

            let mut daemon = Command::new("vault")
                .args([
                    "server",
                    "-dev",
                    &format!("-dev-listen-address=127.0.0.1:{port}"),
                    &format!("-dev-root-token-id={KEY}"),
                ])
                .stdin(Stdio::piped())
                .stderr(Stdio::piped())
                .stdout(Stdio::piped())
                .kill_on_drop(true)
                .spawn()
                .context("spawning vault")?;

            let mut stderr = daemon
                .stderr
                .take()
                .expect("stderr was configured as piped");
            tokio::spawn(async move {
                copy_stream_with_line_prefix("vault stderr", &mut stderr, &mut tokio::io::stderr())
                    .await
            });
            let mut stdout = daemon
                .stdout
                .take()
                .expect("stdout was configured as piped");
            tokio::spawn(async move {
                copy_stream_with_line_prefix("vault stdout", &mut stdout, &mut tokio::io::stderr())
                    .await
            });

            // The client settings do not change between polls, so build it once.
            let client = VaultClient::new(
                VaultClientSettingsBuilder::default()
                    .address(Self::address_for_port(port))
                    .token(KEY)
                    .build()?,
            )?;

            for _ in 0..25 {
                // Fail fast if vault has died (e.g. the port was taken) rather
                // than polling a server that will never come up.
                if let Some(status) = daemon.try_wait().context("checking vault process")? {
                    anyhow::bail!("daemon exited already: {status:?}");
                }
                let status = client.status().await;
                eprintln!("checking status: {status:?}");
                if matches!(status, Ok(vaultrs::sys::ServerStatus::OK)) {
                    return Ok(Self {
                        _daemon: daemon,
                        port,
                    });
                }
                tokio::time::sleep(Duration::from_millis(100)).await;
            }

            anyhow::bail!("server didn't startup successfully");
        }

        /// Store the contents of the local file `path` as the `key` field of
        /// the secret at `vault_path` in the `secret` mount.
        pub async fn put_from_file(&self, vault_path: &str, path: &str) -> anyhow::Result<()> {
            self.kv_put("put_from_file", vault_path, &format!("key=@{path}"))
                .await
        }

        /// Store `value` as the `key` field of the secret at `vault_path` in
        /// the `secret` mount.
        pub async fn put(&self, vault_path: &str, value: &str) -> anyhow::Result<()> {
            self.kv_put("put", vault_path, &format!("key={value}")).await
        }

        /// Run `vault kv put` against this server with a single `field=value`
        /// argument, echoing the CLI's stdout/stderr prefixed by `label`.
        async fn kv_put(&self, label: &str, vault_path: &str, field: &str) -> anyhow::Result<()> {
            let output = Command::new("vault")
                .args([
                    "kv",
                    "put",
                    &format!("-address={}", self.address()),
                    "-mount=secret",
                    vault_path,
                    field,
                ])
                // Authenticate explicitly rather than relying on whatever token
                // the caller's environment or token helper happens to provide.
                .env("VAULT_TOKEN", KEY)
                .output()
                .await
                .with_context(|| format!("{label}: running vault kv put"))?;
            for captured in [&output.stdout, &output.stderr] {
                let text = String::from_utf8_lossy(captured);
                if !text.is_empty() {
                    eprintln!("{label}: {text}");
                }
            }
            anyhow::ensure!(output.status.success(), "{:?}", output.status);
            Ok(())
        }

        /// Build a `KeySource::Vault` that reads `path` from this server.
        pub fn make_source(&self, path: &str) -> KeySource {
            KeySource::Vault {
                vault_address: Some(self.address()),
                vault_token: Some(KEY.to_string()),
                vault_mount: "secret".to_string(),
                vault_path: path.to_string(),
            }
        }
    }

    /// Copy `src` to `dest` line by line, prefixing each line with `prefix`.
    /// Returns once `src` reaches EOF (e.g. the child process exited).
    async fn copy_stream_with_line_prefix<SRC, DEST>(
        prefix: &str,
        src: SRC,
        mut dest: DEST,
    ) -> std::io::Result<()>
    where
        SRC: AsyncRead + Unpin,
        DEST: AsyncWrite + Unpin,
    {
        let mut src = tokio::io::BufReader::new(src);
        let mut line = String::new();
        loop {
            line.clear();
            // `read_line` returns 0 at EOF. Without this check the loop would
            // spin forever once the pipe closes.
            if src.read_line(&mut line).await? == 0 {
                return Ok(());
            }
            if !line.ends_with('\n') {
                line.push('\n');
            }
            dest.write_all(format!("{prefix}: {line}").as_bytes())
                .await?;
        }
    }

    #[tokio::test]
    async fn test_vault() -> anyhow::Result<()> {
        if which::which("vault").is_err() {
            eprintln!("vault binary not found in PATH; skipping test_vault");
            return Ok(());
        }
        let vault = VaultServer::spawn().await?;

        // A file fixture round-trips byte-for-byte.
        vault.put_from_file("example.com", DKIM_KEY_PATH).await?;
        let data = vault.make_source("example.com").get().await?;
        assert_eq!(data, std::fs::read(DKIM_KEY_PATH).unwrap());

        vault.put("foo", "bar").await?;
        let data = vault.make_source("foo").get().await?;
        assert_eq!(data, b"bar");

        // The same lookup, driven through the Lua binding.
        let lua = Lua::new();
        register(&lua).unwrap();
        lua.globals().set("ADDR", vault.address())?;
        lua.globals().set("KEY", KEY)?;
        let pw: String = lua
            .load(
                r#"
            local kumo = require 'kumo';
            print("ADDR", ADDR)
            print("KEY", KEY)
            return kumo.secrets.load {
                vault_path = "foo",
                vault_address = ADDR,
                vault_mount = "secret",
                vault_token = KEY,
            }
            "#,
            )
            .eval_async()
            .await
            .unwrap();

        assert_eq!(pw, "bar");

        Ok(())
    }
}