mod_encode/
lib.rs

1use config::{any_err, get_or_create_sub_module};
2use data_encoding::{
3    Encoding, BASE32, BASE32HEX, BASE32HEX_NOPAD, BASE32_NOPAD, BASE64, BASE64URL, BASE64URL_NOPAD,
4    BASE64_NOPAD, HEXLOWER,
5};
6use mlua::{Lua, Value};
7
8/// data_encoding is very strict when considering padding, making it
9/// incompatible with a number of base64 encoders that apply excess
10/// padding in certain situations.
11/// This decode wrapper considers whether the encoder allows padding,
12/// and if so, speculatively removes any trailing padding bytes from
13/// the string and instead uses the no-padding variant of the encoder
14/// specification in order to avoid triggering any length/padding
15/// checks inside the crate.
16fn decode(enc: &Encoding, data: &[u8]) -> mlua::Result<Vec<u8>> {
17    let mut spec = enc.specification();
18    if let Some(pad_char) = spec.padding {
19        let padding_bytes = [pad_char as u8];
20        let mut stripped = data;
21        while let Some(s) = stripped.strip_suffix(&padding_bytes) {
22            stripped = s;
23        }
24        spec.padding.take();
25        return spec
26            .encoding()
27            .map_err(any_err)?
28            .decode(stripped)
29            .map_err(any_err);
30    }
31    enc.decode(data).map_err(any_err)
32}
33
34pub fn register(lua: &Lua) -> anyhow::Result<()> {
35    let digest_mod = get_or_create_sub_module(lua, "encode")?;
36
37    for (name, enc) in [
38        ("base32", BASE32),
39        ("base32hex", BASE32HEX),
40        ("base32hex_nopad", BASE32HEX_NOPAD),
41        ("base32_nopad", BASE32_NOPAD),
42        ("base64", BASE64),
43        ("base64url", BASE64URL),
44        ("base64url_nopad", BASE64URL_NOPAD),
45        ("base64_nopad", BASE64_NOPAD),
46        ("hex", HEXLOWER),
47    ] {
48        let encoder = enc.clone();
49        digest_mod.set(
50            format!("{name}_encode"),
51            lua.create_function(move |_, data: mlua::Value| match data {
52                Value::String(s) => Ok(encoder.encode(&s.as_bytes())),
53                _ => Err(mlua::Error::external(
54                    "parameter must be a string".to_string(),
55                )),
56            })?,
57        )?;
58        digest_mod.set(
59            format!("{name}_decode"),
60            lua.create_function(move |lua, data: String| {
61                let bytes = decode(&enc, data.as_bytes())?;
62                lua.create_string(&bytes)
63            })?,
64        )?;
65    }
66    Ok(())
67}
68
#[cfg(test)]
#[test]
fn test_decode_padding() {
    // The same base64 payload followed by zero to three `=` bytes; every
    // variant is expected to decode to the same text.
    let inputs: [&[u8]; 4] = [
        b"MmVtYWlsLmxvZwAuY3N2",
        b"MmVtYWlsLmxvZwAuY3N2=",
        b"MmVtYWlsLmxvZwAuY3N2==",
        b"MmVtYWlsLmxvZwAuY3N2===",
    ];

    for input in inputs {
        let decoded = decode(&BASE64, input).unwrap();
        assert_eq!(std::str::from_utf8(&decoded).unwrap(), "2email.log\0.csv");
    }
}