1use dns_resolver::Resolver;
2use futures::future::BoxFuture;
3use hickory_resolver::Name;
4use policy::MtaStsPolicy;
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7
// Process-wide cache of fetched MTA-STS policies, keyed by the lowercased
// policy domain `Name`. Each entry records the DNS policy `id` it was fetched
// under, and is inserted with an expiry derived from the policy's `max_age`.
// NOTE(review): `64 * 1024` is presumably the maximum number of entries —
// confirm against `lruttl::declare_cache!`.
lruttl::declare_cache! {
static CACHE: LruCacheWithTtl<Name, CachedPolicy>::new("mta_sts_policy", 64 * 1024);
}

/// DNS side of MTA-STS discovery; provides `resolve_dns_record`, whose result
/// carries the policy `id` used to validate cached entries.
pub mod dns;
/// Policy retrieval and parsing; provides `MtaStsPolicy`, the `Get` HTTP
/// abstraction, and `load_policy_for_domain`.
pub mod policy;
14
15#[derive(Clone, Debug)]
16struct CachedPolicy {
17 pub id: String,
18 pub policy: Arc<MtaStsPolicy>,
19}
20
21struct Getter {}
22
23impl policy::Get for Getter {
24 fn http_get<'a>(&'a self, url: &'a str) -> BoxFuture<'a, anyhow::Result<String>> {
25 Box::pin(async move {
26 let response = reqwest::Client::builder()
27 .redirect(reqwest::redirect::Policy::none())
30 .timeout(std::time::Duration::from_secs(20))
31 .build()?
32 .request(reqwest::Method::GET, url)
33 .send()
34 .await?;
35
36 let status = response.status();
40 if status != reqwest::StatusCode::OK {
41 anyhow::bail!("failed to GET {url}: {status}");
42 }
43
44 let content_type = response
54 .headers()
55 .get(reqwest::header::CONTENT_TYPE)
56 .ok_or_else(|| anyhow::anyhow!("missing required Content-Type header"))?;
57
58 let content_type = content_type.to_str()?;
59
60 let ct = if let Some((ct, _)) = content_type.split_once(';') {
61 ct.trim()
62 } else {
63 content_type.trim()
64 };
65 if ct != "text/plain" {
66 anyhow::bail!("Content-Type must be text/plain, got {content_type}");
67 }
68
69 Ok(response.text().await?)
70 })
71 }
72}
73
74pub async fn get_policy_for_domain(policy_domain: &str) -> anyhow::Result<Arc<MtaStsPolicy>> {
75 let resolver = dns_resolver::get_resolver();
76 get_policy_for_domain_impl(policy_domain, &**resolver, &Getter {}).await
77}
78
79fn cache_lookup(name: &Name) -> Option<CachedPolicy> {
80 CACHE.get(name)
81}
82
83async fn get_policy_for_domain_impl(
84 policy_domain: &str,
85 resolver: &dyn Resolver,
86 getter: &dyn policy::Get,
87) -> anyhow::Result<Arc<MtaStsPolicy>> {
88 let name = Name::from_str_relaxed(policy_domain)?.to_lowercase();
89
90 if let Some(cached) = cache_lookup(&name) {
91 let still_valid = dns::resolve_dns_record(policy_domain, resolver)
94 .await
95 .map(|r| cached.id == r.id)
96 .unwrap_or(true);
97
98 if still_valid {
99 return Ok(Arc::clone(&cached.policy));
100 }
101 }
102
103 let record = dns::resolve_dns_record(policy_domain, resolver).await?;
104
105 let policy = Arc::new(policy::load_policy_for_domain(policy_domain, getter).await?);
106
107 let expires = Instant::now() + Duration::from_secs(policy.max_age);
108
109 CACHE
110 .insert(
111 name,
112 CachedPolicy {
113 id: record.id,
114 policy: Arc::clone(&policy),
115 },
116 expires.into(),
117 )
118 .await;
119
120 Ok(policy)
121}
122
123