1use anyhow::Context;
2use aws_lc_rs::hmac::Key;
3use chrono::{DateTime, Utc};
4use config::{any_err, from_lua_value, get_or_create_sub_module};
5use data_encoding::HEXLOWER;
6use data_loader::KeySource;
7use mlua::{Lua, LuaSerdeExt, Value};
8use percent_encoding::{percent_encode, AsciiSet, CONTROLS};
9use serde::{Deserialize, Serialize};
10use std::collections::BTreeMap;
11
12const URI_ENCODE_SET: &AsciiSet = &CONTROLS
15 .add(b' ')
16 .add(b'!')
17 .add(b'"')
18 .add(b'#')
19 .add(b'$')
20 .add(b'%')
21 .add(b'&')
22 .add(b'\'')
23 .add(b'(')
24 .add(b')')
25 .add(b'*')
26 .add(b'+')
27 .add(b',')
28 .add(b'/')
29 .add(b':')
30 .add(b';')
31 .add(b'=')
32 .add(b'?')
33 .add(b'@')
34 .add(b'[')
35 .add(b']');
36
/// Input accepted by the Lua `aws_sign_v4` function, deserialized from a Lua value
/// via `from_lua_value`.
#[derive(Deserialize, Debug)]
struct SigV4Request {
    /// Access key ID. Resolved with `KeySource::get` and must decode as UTF-8;
    /// it is placed verbatim in the `Credential=` part of the Authorization header.
    access_key: KeySource,
    /// Secret access key. Resolved with `KeySource::get` and must decode as UTF-8;
    /// it seeds the signing-key derivation and never appears in the output.
    secret_key: KeySource,
    /// Region name, used in the credential scope and the signing-key derivation.
    region: String,
    /// Service name, used in the credential scope and the signing-key derivation.
    /// It is also compared against `"s3"` to decide whether an
    /// `x-amz-content-sha256` header is added.
    service: String,
    /// HTTP method; copied verbatim (not upper-cased) as the first line of the
    /// canonical request.
    method: String,
    /// Request path, canonicalized by `create_canonical_uri`.
    uri: String,
    /// Query parameters; each key and value is URI-encoded when canonicalized.
    /// Defaults to empty when absent.
    #[serde(default)]
    query_params: BTreeMap<String, String>,
    /// Extra headers to sign. Names are lower-cased and values trimmed during
    /// canonicalization. Defaults to empty when absent.
    #[serde(default)]
    headers: BTreeMap<String, String>,
    /// Request body; its SHA-256 hex digest is the payload hash. Defaults to "".
    #[serde(default)]
    payload: String,
    /// Signing time; `Utc::now()` is used when omitted.
    timestamp: Option<DateTime<Utc>>,
    /// Optional temporary-credential token, signed as `x-amz-security-token`.
    session_token: Option<String>,
}
65
/// Result of signing, serialized back to Lua with `lua.to_value`.
#[derive(Deserialize, Serialize, Debug)]
struct SigV4Response {
    /// Complete `Authorization` header value:
    /// `AWS4-HMAC-SHA256 Credential=..., SignedHeaders=..., Signature=...`.
    authorization: String,
    /// Signing time formatted as `%Y%m%dT%H%M%SZ`. This is the same value that was
    /// signed as `x-amz-date`, so presumably the caller must send it in that header.
    timestamp: String,
    /// The canonical request that was hashed; presumably exposed to help debug
    /// signature mismatches.
    canonical_request: String,
    /// The string that was HMAC-signed with the derived signing key.
    string_to_sign: String,
    /// Lowercase hex encoding of the HMAC-SHA256 signature.
    signature: String,
}
79
80fn hmac_sha256(key: &[u8], data: &[u8]) -> Vec<u8> {
81 let key = Key::new(aws_lc_rs::hmac::HMAC_SHA256, key);
82 let tag = aws_lc_rs::hmac::sign(&key, data);
83 tag.as_ref().to_vec()
84}
85
86fn sha256_hex(data: &[u8]) -> String {
87 use aws_lc_rs::digest;
88 let hash = digest::digest(&digest::SHA256, data);
89 HEXLOWER.encode(hash.as_ref())
90}
91
92fn uri_encode(input: &str) -> String {
93 percent_encode(input.as_bytes(), URI_ENCODE_SET).to_string()
94}
95
96fn create_canonical_uri(path: &str) -> String {
97 if path.is_empty() {
98 "/".to_string()
99 } else {
100 path.split('/')
102 .map(uri_encode)
103 .collect::<Vec<_>>()
104 .join("/")
105 }
106}
107
108fn create_canonical_query_string(params: &BTreeMap<String, String>) -> String {
109 if params.is_empty() {
110 return String::new();
111 }
112
113 let mut encoded_params: Vec<(String, String)> = params
119 .iter()
120 .map(|(k, v)| (uri_encode(k), uri_encode(v)))
121 .collect();
122 encoded_params.sort();
123
124 encoded_params
125 .iter()
126 .map(|(k, v)| format!("{}={}", k, v))
127 .collect::<Vec<_>>()
128 .join("&")
129}
130
/// Builds the SigV4 canonical header block and the signed-headers list.
///
/// Header names are lower-cased. Values get SigV4 "Trimall" treatment: leading and
/// trailing whitespace is removed and internal whitespace runs collapse to one space.
/// Names that collide after lower-casing are merged into one comma-separated value,
/// in the input map's iteration order, rather than silently overwriting each other.
///
/// Returns `(canonical_headers, signed_headers)`:
/// - `canonical_headers` holds `name:value` lines joined by `\n`, with no trailing newline.
/// - `signed_headers` holds the lower-cased names joined by `;`.
///
/// Both are sorted by name.
fn create_canonical_headers(headers: &BTreeMap<String, String>) -> (String, String) {
    let mut canonical_headers: BTreeMap<String, String> = BTreeMap::new();
    for (name, value) in headers {
        let name = name.to_lowercase();
        let value = value.split_whitespace().collect::<Vec<_>>().join(" ");
        if let Some(existing) = canonical_headers.get_mut(&name) {
            existing.push(',');
            existing.push_str(&value);
        } else {
            canonical_headers.insert(name, value);
        }
    }

    let header_string = canonical_headers
        .iter()
        .map(|(k, v)| format!("{k}:{v}"))
        .collect::<Vec<_>>()
        .join("\n");

    let signed_headers = canonical_headers
        .keys()
        .cloned()
        .collect::<Vec<_>>()
        .join(";");

    (header_string, signed_headers)
}
154
155fn create_signing_key(secret_key: &str, date_stamp: &str, region: &str, service: &str) -> Vec<u8> {
156 let k_date = hmac_sha256(
157 format!("AWS4{secret_key}").as_bytes(),
158 date_stamp.as_bytes(),
159 );
160 let k_region = hmac_sha256(&k_date, region.as_bytes());
161 let k_service = hmac_sha256(&k_region, service.as_bytes());
162 hmac_sha256(&k_service, b"aws4_request")
163}
164
165async fn sign_request(req: SigV4Request) -> anyhow::Result<SigV4Response> {
166 let access_key_bytes = req.access_key.get().await?;
168 let access_key = std::str::from_utf8(&access_key_bytes)
169 .context("access_key must be valid UTF-8")?
170 .to_string();
171
172 let secret_key_bytes = req.secret_key.get().await?;
174 let secret_key = std::str::from_utf8(&secret_key_bytes)
175 .context("secret_key must be valid UTF-8")?
176 .to_string();
177
178 let timestamp = req.timestamp.unwrap_or_else(Utc::now);
180 let amz_date = timestamp.format("%Y%m%dT%H%M%SZ").to_string();
181 let date_stamp = timestamp.format("%Y%m%d").to_string();
182
183 let payload_hash = sha256_hex(req.payload.as_bytes());
185
186 let mut headers = req.headers.clone();
188 headers.insert("host".to_string(), "".to_string()); headers.insert("x-amz-date".to_string(), amz_date.clone());
190
191 if let Some(token) = &req.session_token {
192 headers.insert("x-amz-security-token".to_string(), token.clone());
193 }
194
195 if req.service != "s3" {
197 headers.insert("x-amz-content-sha256".to_string(), payload_hash.clone());
198 }
199
200 let canonical_uri = create_canonical_uri(&req.uri);
202 let canonical_query_string = create_canonical_query_string(&req.query_params);
203 let (canonical_headers, signed_headers) = create_canonical_headers(&headers);
204
205 let canonical_request = format!(
209 "{method}\n{canonical_uri}\n{canonical_query_string}\n{canonical_headers}\n\n{signed_headers}\n{payload_hash}",
210 method = req.method,
211 );
212
213 let algorithm = "AWS4-HMAC-SHA256";
215 let credential_scope = format!(
216 "{date_stamp}/{region}/{service}/aws4_request",
217 region = req.region,
218 service = req.service
219 );
220 let canonical_request_hash = sha256_hex(canonical_request.as_bytes());
221
222 let string_to_sign =
223 format!("{algorithm}\n{amz_date}\n{credential_scope}\n{canonical_request_hash}");
224
225 let signing_key = create_signing_key(&secret_key, &date_stamp, &req.region, &req.service);
227 let signature_bytes = hmac_sha256(&signing_key, string_to_sign.as_bytes());
228 let signature = HEXLOWER.encode(&signature_bytes);
229
230 let authorization = format!(
232 "{algorithm} Credential={access_key}/{credential_scope}, SignedHeaders={signed_headers}, Signature={signature}",
233 access_key = access_key
234 );
235
236 Ok(SigV4Response {
237 authorization,
238 timestamp: amz_date,
239 canonical_request,
240 string_to_sign,
241 signature,
242 })
243}
244
245pub fn register(lua: &Lua) -> anyhow::Result<()> {
246 let aws_mod = get_or_create_sub_module(lua, "crypto")?;
249
250 aws_mod.set(
251 "aws_sign_v4",
252 lua.create_async_function(|lua, request: Value| async move {
253 let req: SigV4Request = from_lua_value(&lua, request)?;
254 let response = sign_request(req).await.map_err(any_err)?;
255
256 lua.to_value(&response)
257 })?,
258 )?;
259
260 Ok(())
261}
262
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_uri_encode() {
        assert_eq!(uri_encode("test"), "test");
        assert_eq!(uri_encode("test value"), "test%20value");
        assert_eq!(uri_encode("test/path"), "test%2Fpath");
        assert_eq!(uri_encode("test-value_123.txt~"), "test-value_123.txt~");
    }

    #[test]
    fn test_canonical_uri() {
        assert_eq!(create_canonical_uri(""), "/");
        assert_eq!(create_canonical_uri("/"), "/");
        assert_eq!(create_canonical_uri("/path"), "/path");
        assert_eq!(create_canonical_uri("/path/to/file"), "/path/to/file");
        assert_eq!(
            create_canonical_uri("/path with spaces"),
            "/path%20with%20spaces"
        );
    }

    #[test]
    fn test_canonical_query_string() {
        let mut params = BTreeMap::new();
        assert_eq!(create_canonical_query_string(&params), "");

        params.insert("key".to_string(), "value".to_string());
        assert_eq!(create_canonical_query_string(&params), "key=value");

        params.insert("another".to_string(), "test".to_string());
        assert_eq!(
            create_canonical_query_string(&params),
            "another=test&key=value"
        );
    }

    #[test]
    fn test_canonical_headers() {
        let mut headers = BTreeMap::new();
        headers.insert("Host".to_string(), " example.com ".to_string());
        headers.insert("X-Amz-Date".to_string(), "20240101T000000Z".to_string());
        let (canonical, signed) = create_canonical_headers(&headers);
        // Names are lower-cased, values trimmed, lines sorted by name.
        assert_eq!(canonical, "host:example.com\nx-amz-date:20240101T000000Z");
        assert_eq!(signed, "host;x-amz-date");
    }

    #[test]
    fn test_sha256_hex() {
        let result = sha256_hex(b"test");
        assert_eq!(
            result,
            "9f86d081884c7d659a2feaa0c55ad015a3bf4f1b2b0b822cd15d6c15b0f00a08"
        );
    }

    #[test]
    fn test_hmac_sha256() {
        let result = hmac_sha256(b"key", b"message");
        let hex = HEXLOWER.encode(&result);
        assert_eq!(
            hex,
            "6e9ef29b75fffc5b7abae527d58fdadb2fe42e7219011976917343065f58ed4a"
        );
    }
}