1use dns_resolver::Resolver;
2use futures::future::BoxFuture;
3use hickory_resolver::Name;
4use policy::MtaStsPolicy;
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7
// Process-wide cache of fetched MTA-STS policies, keyed by the lowercased
// domain `Name` (see `get_policy_for_domain_impl`). Entries are inserted with an
// expiry derived from the policy's `max_age`.
// NOTE(review): the constructor arguments look like a cache label
// ("mta_sts_policy") and a capacity of 64 * 1024 entries — confirm against
// `lruttl::declare_cache!`.
lruttl::declare_cache! {
static CACHE: LruCacheWithTtl<Name, CachedPolicy>::new("mta_sts_policy", 64 * 1024);
}
12
13pub mod dns;
14pub mod policy;
15
16#[derive(Clone, Debug)]
17struct CachedPolicy {
18 pub id: String,
19 pub policy: Arc<MtaStsPolicy>,
20}
21
22struct Getter {}
23
24impl policy::Get for Getter {
25 fn http_get<'a>(&'a self, url: &'a str) -> BoxFuture<'a, anyhow::Result<String>> {
26 Box::pin(async move {
27 let response = reqwest::Client::builder()
28 .redirect(reqwest::redirect::Policy::none())
31 .timeout(std::time::Duration::from_secs(20))
32 .build()?
33 .request(reqwest::Method::GET, url)
34 .send()
35 .await?;
36
37 let status = response.status();
41 if status != reqwest::StatusCode::OK {
42 anyhow::bail!("failed to GET {url}: {status}");
43 }
44
45 let content_type = response
55 .headers()
56 .get(reqwest::header::CONTENT_TYPE)
57 .ok_or_else(|| anyhow::anyhow!("missing required Content-Type header"))?;
58
59 let content_type = content_type.to_str()?;
60
61 let ct = if let Some((ct, _)) = content_type.split_once(';') {
62 ct.trim()
63 } else {
64 content_type.trim()
65 };
66 if ct != "text/plain" {
67 anyhow::bail!("Content-Type must be text/plain, got {content_type}");
68 }
69
70 Ok(response.text().await?)
71 })
72 }
73}
74
75pub async fn get_policy_for_domain(policy_domain: &str) -> anyhow::Result<Arc<MtaStsPolicy>> {
76 let resolver = dns_resolver::get_resolver();
77 get_policy_for_domain_impl(policy_domain, &**resolver, &Getter {}).await
78}
79
/// Returns a clone of the cached policy entry for `name`, or `None` if the
/// cache holds nothing for it.
///
/// NOTE(review): presumably `LruCacheWithTtl::get` does not return entries whose
/// expiry has passed — confirm in `lruttl`.
fn cache_lookup(name: &Name) -> Option<CachedPolicy> {
    CACHE.get(name)
}
83
84async fn get_policy_for_domain_impl(
85 policy_domain: &str,
86 resolver: &dyn Resolver,
87 getter: &dyn policy::Get,
88) -> anyhow::Result<Arc<MtaStsPolicy>> {
89 let name = Name::from_str_relaxed(policy_domain)?.to_lowercase();
90
91 if let Some(cached) = cache_lookup(&name) {
92 let still_valid = dns::resolve_dns_record(policy_domain, resolver)
95 .await
96 .map(|r| cached.id == r.id)
97 .unwrap_or(true);
98
99 if still_valid {
100 return Ok(Arc::clone(&cached.policy));
101 }
102 }
103
104 let record = dns::resolve_dns_record(policy_domain, resolver).await?;
105
106 let policy = Arc::new(policy::load_policy_for_domain(policy_domain, getter).await?);
107
108 let expires = Instant::now() + Duration::from_secs(policy.max_age);
109
110 CACHE
111 .insert(
112 name,
113 CachedPolicy {
114 id: record.id,
115 policy: Arc::clone(&policy),
116 },
117 expires.into(),
118 )
119 .await;
120
121 Ok(policy)
122}
123
124