mod_crypto/
lib.rs

1use anyhow::anyhow;
2use aws_lc_rs::cipher::{
3    AES_128, AES_256, AES_CBC_IV_LEN, DecryptionContext, PaddedBlockDecryptingKey,
4    PaddedBlockEncryptingKey, UnboundCipherKey,
5};
6use config::{any_err, from_lua_value, get_or_create_sub_module};
7use data_loader::KeySource;
8use mlua::{Lua, Value};
9use mod_digest::BinaryResult;
10use serde::Deserialize;
11
12#[derive(Clone, Debug)]
13pub struct AesParams {
14    pub algorithm: AesAlgo,
15    pub key: Vec<u8>,
16}
17
18#[derive(Deserialize, Clone, Copy, Debug)]
19pub enum AesAlgo {
20    Ecb,
21    Cbc,
22}
23
24#[derive(Deserialize, Clone, Debug)]
25pub struct KeyConfig {
26    pub key: KeySource,
27}
28
29fn make_cipher_key(bytes: &[u8]) -> anyhow::Result<UnboundCipherKey> {
30    match bytes.len() {
31        16 => Ok(UnboundCipherKey::new(&AES_128, bytes)?),
32        32 => Ok(UnboundCipherKey::new(&AES_256, bytes)?),
33        _ => anyhow::bail!("Key length must be 16 or 32 bytes"),
34    }
35}
36
37fn make_enc_key(bytes: &[u8], algorithm: AesAlgo) -> anyhow::Result<PaddedBlockEncryptingKey> {
38    let key = make_cipher_key(bytes)?;
39    match algorithm {
40        AesAlgo::Ecb => Ok(PaddedBlockEncryptingKey::ecb_pkcs7(key)?),
41        AesAlgo::Cbc => Ok(PaddedBlockEncryptingKey::cbc_pkcs7(key)?),
42    }
43}
44
45fn make_dec_key(bytes: &[u8], algorithm: AesAlgo) -> anyhow::Result<PaddedBlockDecryptingKey> {
46    let key = make_cipher_key(bytes)?;
47    match algorithm {
48        AesAlgo::Ecb => Ok(PaddedBlockDecryptingKey::ecb_pkcs7(key)?),
49        AesAlgo::Cbc => Ok(PaddedBlockDecryptingKey::cbc_pkcs7(key)?),
50    }
51}
52
53fn aes_encrypt_block(plaintext: &[u8], params: AesParams) -> anyhow::Result<Vec<u8>> {
54    let mut buf_ciphertext = plaintext.to_vec();
55    let enc_key = make_enc_key(&params.key, params.algorithm)?;
56
57    match params.algorithm {
58        AesAlgo::Ecb => {
59            // don't return IV vector for ecb
60            enc_key.encrypt(&mut buf_ciphertext)?;
61            Ok(buf_ciphertext)
62        }
63        AesAlgo::Cbc => {
64            // context contain IV vector
65            let context = enc_key.encrypt(&mut buf_ciphertext)?;
66
67            match context {
68                DecryptionContext::Iv128(iv) => {
69                    let mut result = iv.as_ref().to_vec();
70                    // Append the actual ciphertext after the IV
71                    result.extend_from_slice(&buf_ciphertext);
72                    Ok(result)
73                }
74                unsupported => anyhow::bail!(
75                    "Unexpected IV context {unsupported:?} for encrypting in CBC mode"
76                ),
77            }
78        }
79    }
80}
81
82fn aes_decrypt_block(ciphertext_buf: &[u8], params: AesParams) -> anyhow::Result<BinaryResult> {
83    let mut in_out_buffer = ciphertext_buf.to_vec();
84
85    let dec_key = make_dec_key(&params.key, params.algorithm)?;
86    match params.algorithm {
87        AesAlgo::Ecb => match dec_key.decrypt(&mut in_out_buffer, DecryptionContext::None) {
88            Ok(plaintext) => Ok(BinaryResult(plaintext.to_vec())),
89            Err(e) => Err(anyhow!("Decryption failed with AES ECB mode: {}", e)),
90        },
91
92        AesAlgo::Cbc => {
93            // CBC expects the IV to be prepended to the ciphertext.
94            // Split into IV and actual ciphertext
95            let Some((iv_bytes, actual_ciphertext)) =
96                ciphertext_buf.split_at_checked(AES_CBC_IV_LEN)
97            else {
98                anyhow::bail!(
99                    "ciphertext must be prefixed by iv with len at least {AES_CBC_IV_LEN}"
100                );
101            };
102
103            // Create a mutable buffer for decryption
104            let mut decrypt_buffer = actual_ciphertext.to_vec();
105            let plaintext_slice = dec_key.decrypt(
106                &mut decrypt_buffer,
107                DecryptionContext::Iv128(iv_bytes.try_into()?),
108            )?;
109            Ok(BinaryResult(plaintext_slice.to_vec()))
110        }
111    }
112}
113pub fn register(lua: &Lua) -> anyhow::Result<()> {
114    let crypto = get_or_create_sub_module(lua, "crypto")?;
115    crypto.set(
116        "aes_encrypt_block",
117        lua.create_async_function(
118            |lua, (algorithm, data, config): (Value, mlua::String, Value)| async move {
119                let algorithm: AesAlgo = from_lua_value(&lua, algorithm)?;
120
121                let config: KeyConfig = from_lua_value(&lua, config)?;
122                let key = config.key.get().await.map_err(any_err)?;
123                let p = AesParams { key, algorithm };
124
125                let plaintext_bytes = data.as_bytes();
126                let result = aes_encrypt_block(&plaintext_bytes, p).map_err(any_err)?;
127                lua.create_string(&result)
128            },
129        )?,
130    )?;
131
132    crypto.set(
133        "aes_decrypt_block",
134        lua.create_async_function(
135            |lua, (algorithm, data, config): (Value, mlua::String, Value)| async move {
136                let algorithm: AesAlgo = from_lua_value(&lua, algorithm)?;
137
138                let config: KeyConfig = from_lua_value(&lua, config)?;
139                let key = config.key.get().await.map_err(any_err)?;
140                let p = AesParams { key, algorithm };
141
142                let ciphertext_bytes = data.as_bytes();
143                aes_decrypt_block(&ciphertext_bytes, p).map_err(any_err)
144            },
145        )?,
146    )?;
147    Ok(())
148}
149
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;

    /// Builds `AesParams` from a hex-encoded key.
    fn params_from_hex(key_hex: &str, algorithm: AesAlgo) -> Result<AesParams> {
        Ok(AesParams {
            key: hex::decode(key_hex)?,
            algorithm,
        })
    }

    mod ecb_tests {
        use super::*;

        #[test]
        fn encrypt_decrypt_aes128_ecb_hex() -> Result<()> {
            let plaintext = "helloword-from-the-sun";
            let params = params_from_hex("2b7e151628aed2a6abf7158809cf4f3c", AesAlgo::Ecb)?;

            let ciphertext = aes_encrypt_block(plaintext.as_bytes(), params.clone())?;
            let decrypted_text = aes_decrypt_block(ciphertext.as_slice(), params)?;
            let decrypted_string = String::from_utf8(decrypted_text.0)?;
            assert_eq!(decrypted_string, plaintext);

            Ok(())
        }

        /// NIST SP 800-38A, F.1.1 (ECB-AES128.Encrypt), block #1.
        #[test]
        fn aes128_ecb_matches_nist_vector() -> Result<()> {
            let params = params_from_hex("2b7e151628aed2a6abf7158809cf4f3c", AesAlgo::Ecb)?;
            let plaintext = hex::decode("6bc1bee22e409f96e93d7e117393172a")?;

            let ciphertext = aes_encrypt_block(&plaintext, params.clone())?;
            // A block-aligned input gains one full block of PKCS#7 padding.
            assert_eq!(ciphertext.len(), 32);
            assert_eq!(
                hex::encode(&ciphertext[..16]),
                "3ad77bb40d7a3660a89ecaf32466ef97"
            );

            let decrypted = aes_decrypt_block(&ciphertext, params)?;
            assert_eq!(decrypted.0, plaintext);
            Ok(())
        }

        #[test]
        fn empty_plaintext_is_one_padding_block() -> Result<()> {
            let params = params_from_hex("2b7e151628aed2a6abf7158809cf4f3c", AesAlgo::Ecb)?;

            let ciphertext = aes_encrypt_block(b"", params.clone())?;
            assert_eq!(ciphertext.len(), 16);
            assert!(aes_decrypt_block(&ciphertext, params)?.0.is_empty());
            Ok(())
        }

        #[test]
        fn encrypt_decrypt_aes256_ecb() -> Result<()> {
            let plaintext = "AES-256 in ECB mode round trip";
            let params = params_from_hex(
                "603deb1015ca71be2b73aef0857d7781a5b6b8e5b62c65e9f1f63b7ee7ec6f2f",
                AesAlgo::Ecb,
            )?;

            let ciphertext = aes_encrypt_block(plaintext.as_bytes(), params.clone())?;
            let decrypted = aes_decrypt_block(&ciphertext, params)?;
            assert_eq!(String::from_utf8(decrypted.0)?, plaintext);
            Ok(())
        }
    }

    mod cbc_tests {
        use super::*;

        #[test]
        fn encrypt_decrypt_aes128_cbc_hex() -> Result<()> {
            let plaintext = "This is a secret message to be encrypted using AES-128 CBC mode.";
            let params = params_from_hex("2b7e151628aed2a6abf7158809cf4f3c", AesAlgo::Cbc)?;

            let ciphertext_with_iv = aes_encrypt_block(plaintext.as_bytes(), params.clone())?;
            assert!(
                ciphertext_with_iv.len() > AES_CBC_IV_LEN,
                "Ciphertext should contain IV + data"
            );

            let decrypted_text = aes_decrypt_block(ciphertext_with_iv.as_slice(), params)?;
            let decrypted_string = String::from_utf8(decrypted_text.0)?;
            assert_eq!(decrypted_string, plaintext);

            Ok(())
        }

        #[test]
        fn encrypt_decrypt_aes256_cbc_hex() -> Result<()> {
            let plaintext = "Another secret message, but this one is for AES-256 CBC.";
            let params = params_from_hex(
                "603deb1015ca71be2b73aef0857d7781a5b6b8e5b62c65e9f1f63b7ee7ec6f2f",
                AesAlgo::Cbc,
            )?;

            let ciphertext_with_iv = aes_encrypt_block(plaintext.as_bytes(), params.clone())?;
            assert!(
                ciphertext_with_iv.len() > AES_CBC_IV_LEN,
                "Ciphertext should contain IV + data"
            );

            let decrypted_text = aes_decrypt_block(ciphertext_with_iv.as_slice(), params)?;
            let decrypted_string = String::from_utf8(decrypted_text.0)?;
            assert_eq!(decrypted_string, plaintext);

            Ok(())
        }

        #[test]
        fn cbc_uses_a_fresh_iv_per_encryption() -> Result<()> {
            let params = params_from_hex("2b7e151628aed2a6abf7158809cf4f3c", AesAlgo::Cbc)?;

            let first = aes_encrypt_block(b"same input", params.clone())?;
            let second = aes_encrypt_block(b"same input", params)?;
            // A random IV makes repeated encryptions of the same plaintext differ.
            assert_ne!(first[..AES_CBC_IV_LEN], second[..AES_CBC_IV_LEN]);
            assert_ne!(first, second);
            Ok(())
        }

        #[test]
        fn cbc_rejects_truncated_input() -> Result<()> {
            let params = params_from_hex("2b7e151628aed2a6abf7158809cf4f3c", AesAlgo::Cbc)?;

            // Shorter than the IV.
            assert!(aes_decrypt_block(&[0u8; 4], params.clone()).is_err());
            // IV present but no ciphertext after it.
            assert!(aes_decrypt_block(&[0u8; AES_CBC_IV_LEN], params).is_err());
            Ok(())
        }
    }

    mod key_tests {
        use super::*;

        #[test]
        fn rejects_unsupported_key_lengths() {
            for len in [0usize, 15, 17] {
                let params = AesParams {
                    key: vec![0u8; len],
                    algorithm: AesAlgo::Ecb,
                };
                assert!(aes_encrypt_block(b"data", params.clone()).is_err());
                assert!(aes_decrypt_block(&[0u8; 16], params).is_err());
            }
        }
    }
}