mod_aws_sigv4/
lib.rs

1use anyhow::Context;
2use aws_lc_rs::hmac::Key;
3use chrono::{DateTime, Utc};
4use config::{any_err, from_lua_value, get_or_create_sub_module};
5use data_encoding::HEXLOWER;
6use data_loader::KeySource;
7use mlua::{Lua, LuaSerdeExt, Value};
8use percent_encoding::{percent_encode, AsciiSet, CONTROLS};
9use serde::{Deserialize, Serialize};
10use std::collections::BTreeMap;
11
12/// AWS SigV4 URI encoding set
13/// Encodes everything except: A-Z a-z 0-9 - _ . ~
14const URI_ENCODE_SET: &AsciiSet = &CONTROLS
15    .add(b' ')
16    .add(b'!')
17    .add(b'"')
18    .add(b'#')
19    .add(b'$')
20    .add(b'%')
21    .add(b'&')
22    .add(b'\'')
23    .add(b'(')
24    .add(b')')
25    .add(b'*')
26    .add(b'+')
27    .add(b',')
28    .add(b'/')
29    .add(b':')
30    .add(b';')
31    .add(b'=')
32    .add(b'?')
33    .add(b'@')
34    .add(b'[')
35    .add(b']');
36
37#[derive(Deserialize, Debug)]
38struct SigV4Request {
39    /// AWS access key ID (can be a KeySource)
40    access_key: KeySource,
41    /// AWS secret access key (can be a KeySource)
42    secret_key: KeySource,
43    /// AWS region (e.g., "us-east-1")
44    region: String,
45    /// AWS service name (e.g., "s3", "sns", "sqs")
46    service: String,
47    /// HTTP method (e.g., "GET", "POST")
48    method: String,
49    /// URI path (e.g., "/")
50    uri: String,
51    /// Optional query string parameters
52    #[serde(default)]
53    query_params: BTreeMap<String, String>,
54    /// HTTP headers to sign
55    #[serde(default)]
56    headers: BTreeMap<String, String>,
57    /// Request payload (body)
58    #[serde(default)]
59    payload: String,
60    /// Optional timestamp (defaults to current time)
61    timestamp: Option<DateTime<Utc>>,
62    /// Optional session token for temporary credentials
63    session_token: Option<String>,
64}
65
66#[derive(Deserialize, Serialize, Debug)]
67struct SigV4Response {
68    /// The authorization header value
69    authorization: String,
70    /// The timestamp used in ISO8601 format (YYYYMMDD'T'HHMMSS'Z')
71    timestamp: String,
72    /// The canonical request (for debugging)
73    canonical_request: String,
74    /// The string to sign (for debugging)
75    string_to_sign: String,
76    /// The signature
77    signature: String,
78}
79
80fn hmac_sha256(key: &[u8], data: &[u8]) -> Vec<u8> {
81    let key = Key::new(aws_lc_rs::hmac::HMAC_SHA256, key);
82    let tag = aws_lc_rs::hmac::sign(&key, data);
83    tag.as_ref().to_vec()
84}
85
86fn sha256_hex(data: &[u8]) -> String {
87    use aws_lc_rs::digest;
88    let hash = digest::digest(&digest::SHA256, data);
89    HEXLOWER.encode(hash.as_ref())
90}
91
92fn uri_encode(input: &str) -> String {
93    percent_encode(input.as_bytes(), URI_ENCODE_SET).to_string()
94}
95
96fn create_canonical_uri(path: &str) -> String {
97    if path.is_empty() {
98        "/".to_string()
99    } else {
100        // Split path and encode each segment
101        path.split('/')
102            .map(uri_encode)
103            .collect::<Vec<_>>()
104            .join("/")
105    }
106}
107
108fn create_canonical_query_string(params: &BTreeMap<String, String>) -> String {
109    if params.is_empty() {
110        return String::new();
111    }
112
113    // Sort parameters and URI encode them.
114    //
115    // We collect into a Vec and sort on the *encoded* keys to ensure
116    // the ordering is correct even when encoding changes the byte
117    // ordering of the original key/value strings.
118    let mut encoded_params: Vec<(String, String)> = params
119        .iter()
120        .map(|(k, v)| (uri_encode(k), uri_encode(v)))
121        .collect();
122    encoded_params.sort();
123
124    encoded_params
125        .iter()
126        .map(|(k, v)| format!("{}={}", k, v))
127        .collect::<Vec<_>>()
128        .join("&")
129}
130
/// Builds the SigV4 canonical header block and the signed-headers list.
///
/// Header names are lowercased. Values have leading and trailing whitespace
/// removed, and every interior run of whitespace is collapsed to a single
/// space (the SigV4 "Trimall" step). Entries are ordered by lowercased name.
///
/// Returns `(canonical_headers, signed_headers)`:
/// * `canonical_headers` holds `name:value` lines joined by `\n`, with no
///   trailing newline;
/// * `signed_headers` holds the lowercased names joined by `;`.
///
/// If two names differ only in letter case, they collapse to a single entry.
/// The value of whichever sorts last in `headers` wins.
fn create_canonical_headers(headers: &BTreeMap<String, String>) -> (String, String) {
    let canonical_headers: BTreeMap<String, String> = headers
        .iter()
        .map(|(name, value)| {
            // `split_whitespace` drops leading/trailing whitespace and empty
            // runs, so re-joining with one space yields the normalized value.
            let value = value.split_whitespace().collect::<Vec<_>>().join(" ");
            (name.to_lowercase(), value)
        })
        .collect();

    let header_string = canonical_headers
        .iter()
        .map(|(name, value)| format!("{name}:{value}"))
        .collect::<Vec<_>>()
        .join("\n");

    let signed_headers = canonical_headers
        .keys()
        .map(String::as_str)
        .collect::<Vec<_>>()
        .join(";");

    (header_string, signed_headers)
}
154
155fn create_signing_key(secret_key: &str, date_stamp: &str, region: &str, service: &str) -> Vec<u8> {
156    let k_date = hmac_sha256(
157        format!("AWS4{secret_key}").as_bytes(),
158        date_stamp.as_bytes(),
159    );
160    let k_region = hmac_sha256(&k_date, region.as_bytes());
161    let k_service = hmac_sha256(&k_region, service.as_bytes());
162    hmac_sha256(&k_service, b"aws4_request")
163}
164
165async fn sign_request(req: SigV4Request) -> anyhow::Result<SigV4Response> {
166    // Get the access key id and secret key from their KeySource values
167    let access_key_bytes = req.access_key.get().await?;
168    let access_key = std::str::from_utf8(&access_key_bytes)
169        .context("access_key must be valid UTF-8")?
170        .to_string();
171
172    // Get the secret key
173    let secret_key_bytes = req.secret_key.get().await?;
174    let secret_key = std::str::from_utf8(&secret_key_bytes)
175        .context("secret_key must be valid UTF-8")?
176        .to_string();
177
178    // Use provided timestamp or current time
179    let timestamp = req.timestamp.unwrap_or_else(Utc::now);
180    let amz_date = timestamp.format("%Y%m%dT%H%M%SZ").to_string();
181    let date_stamp = timestamp.format("%Y%m%d").to_string();
182
183    // Create payload hash
184    let payload_hash = sha256_hex(req.payload.as_bytes());
185
186    // Prepare headers - add required AWS headers
187    let mut headers = req.headers.clone();
188    headers.insert("host".to_string(), "".to_string()); // Will be set by caller
189    headers.insert("x-amz-date".to_string(), amz_date.clone());
190
191    if let Some(token) = &req.session_token {
192        headers.insert("x-amz-security-token".to_string(), token.clone());
193    }
194
195    // Add content hash header for some services
196    if req.service != "s3" {
197        headers.insert("x-amz-content-sha256".to_string(), payload_hash.clone());
198    }
199
200    // Create canonical request
201    let canonical_uri = create_canonical_uri(&req.uri);
202    let canonical_query_string = create_canonical_query_string(&req.query_params);
203    let (canonical_headers, signed_headers) = create_canonical_headers(&headers);
204
205    // See https://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html
206    // for the canonical request structure. The blank line between the
207    // canonical headers and the signed headers is required by the spec.
208    let canonical_request = format!(
209        "{method}\n{canonical_uri}\n{canonical_query_string}\n{canonical_headers}\n\n{signed_headers}\n{payload_hash}",
210        method = req.method,
211    );
212
213    // Create string to sign
214    let algorithm = "AWS4-HMAC-SHA256";
215    let credential_scope = format!(
216        "{date_stamp}/{region}/{service}/aws4_request",
217        region = req.region,
218        service = req.service
219    );
220    let canonical_request_hash = sha256_hex(canonical_request.as_bytes());
221
222    let string_to_sign =
223        format!("{algorithm}\n{amz_date}\n{credential_scope}\n{canonical_request_hash}");
224
225    // Calculate signature
226    let signing_key = create_signing_key(&secret_key, &date_stamp, &req.region, &req.service);
227    let signature_bytes = hmac_sha256(&signing_key, string_to_sign.as_bytes());
228    let signature = HEXLOWER.encode(&signature_bytes);
229
230    // Create authorization header
231    let authorization = format!(
232        "{algorithm} Credential={access_key}/{credential_scope}, SignedHeaders={signed_headers}, Signature={signature}",
233        access_key = access_key
234    );
235
236    Ok(SigV4Response {
237        authorization,
238        timestamp: amz_date,
239        canonical_request,
240        string_to_sign,
241        signature,
242    })
243}
244
245pub fn register(lua: &Lua) -> anyhow::Result<()> {
246    // Register under kumo.crypto as aws_sign_v4 so that the function
247    // shows up alongside the other crypto helpers in the reference docs.
248    let aws_mod = get_or_create_sub_module(lua, "crypto")?;
249
250    aws_mod.set(
251        "aws_sign_v4",
252        lua.create_async_function(|lua, request: Value| async move {
253            let req: SigV4Request = from_lua_value(&lua, request)?;
254            let response = sign_request(req).await.map_err(any_err)?;
255
256            lua.to_value(&response)
257        })?,
258    )?;
259
260    Ok(())
261}
262
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_uri_encode() {
        // Unreserved characters pass through; everything else is escaped.
        let cases = [
            ("test", "test"),
            ("test value", "test%20value"),
            ("test/path", "test%2Fpath"),
            ("test-value_123.txt~", "test-value_123.txt~"),
        ];
        for (input, expected) in cases {
            assert_eq!(uri_encode(input), expected, "uri_encode({input:?})");
        }
    }

    #[test]
    fn test_canonical_uri() {
        let cases = [
            ("", "/"),
            ("/", "/"),
            ("/path", "/path"),
            ("/path/to/file", "/path/to/file"),
            ("/path with spaces", "/path%20with%20spaces"),
        ];
        for (input, expected) in cases {
            assert_eq!(
                create_canonical_uri(input),
                expected,
                "create_canonical_uri({input:?})"
            );
        }
    }

    #[test]
    fn test_canonical_query_string() {
        let mut params: BTreeMap<String, String> = BTreeMap::new();
        assert_eq!(create_canonical_query_string(&params), "");

        params.insert("key".into(), "value".into());
        assert_eq!(create_canonical_query_string(&params), "key=value");

        // A second parameter that sorts first must appear first.
        params.insert("another".into(), "test".into());
        assert_eq!(
            create_canonical_query_string(&params),
            "another=test&key=value"
        );
    }

    #[test]
    fn test_sha256_hex() {
        const EXPECTED: &str = "9f86d081884c7d659a2feaa0c55ad015a3bf4f1b2b0b822cd15d6c15b0f00a08";
        assert_eq!(sha256_hex(b"test"), EXPECTED);
    }

    #[test]
    fn test_hmac_sha256() {
        const EXPECTED: &str = "6e9ef29b75fffc5b7abae527d58fdadb2fe42e7219011976917343065f58ed4a";
        let tag = hmac_sha256(b"key", b"message");
        assert_eq!(HEXLOWER.encode(&tag), EXPECTED);
    }
}