kumo_api_types/
shaping.rs

1use crate::egress_path::EgressPathConfig;
2#[cfg(feature = "lua")]
3use anyhow::Context;
4#[cfg(feature = "lua")]
5use config::any_err;
6#[cfg(feature = "lua")]
7use config::serialize_options;
8#[cfg(feature = "lua")]
9use dns_resolver::{fully_qualify, MailExchanger};
10#[cfg(feature = "lua")]
11use kumo_log_types::JsonLogRecord;
12#[cfg(feature = "lua")]
13use mlua::prelude::LuaUserData;
14#[cfg(feature = "lua")]
15use mlua::{LuaSerdeExt, UserDataMethods};
16use ordermap::OrderMap;
17use serde::de::{SeqAccess, Visitor};
18use serde::{Deserialize, Deserializer, Serialize};
19use serde_with::formats::PreferOne;
20use serde_with::{serde_as, DeserializeAs, OneOrMany};
21#[cfg(feature = "lua")]
22use sha2::{Digest, Sha256};
23#[cfg(feature = "lua")]
24use std::collections::BTreeMap;
25use std::hash::{Hash, Hasher};
26use std::marker::PhantomData;
27#[cfg(feature = "lua")]
28use std::sync::Arc;
29use std::time::Duration;
30#[cfg(feature = "lua")]
31use throttle::LimitSpec;
32use throttle::ThrottleSpec;
33
/// Newtype around a compiled `fancy_regex::Regex`.
///
/// The serde container attributes make it deserialize from a `String`
/// (via the `TryFrom<String>` impl) and serialize back to a `String`
/// (via the `From<Regex> for String` impl).
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(try_from = "String", into = "String")]
pub struct Regex(fancy_regex::Regex);
37
38impl TryFrom<String> for Regex {
39    type Error = fancy_regex::Error;
40
41    fn try_from(s: String) -> fancy_regex::Result<Self> {
42        Ok(Self(fancy_regex::Regex::new(&s)?))
43    }
44}
45
46impl From<Regex> for String {
47    fn from(r: Regex) -> String {
48        r.0.as_str().to_string()
49    }
50}
51
52impl std::ops::Deref for Regex {
53    type Target = fancy_regex::Regex;
54    fn deref(&self) -> &Self::Target {
55        &self.0
56    }
57}
58
59impl std::hash::Hash for Regex {
60    fn hash<H: std::hash::Hasher>(&self, hasher: &mut H) {
61        self.0.as_str().hash(hasher)
62    }
63}
64
/// `toml::Value` is not `Hash` because it may contain floating
/// point numbers, which are problematic from an `Ord` and `Eq`
/// perspective. We're okay with skirting around that for
/// our purposes here, so we implement our own hashable
/// wrapper around the toml value.
///
/// Serde converts to and from a plain `toml::Value` (see the
/// `from`/`into` container attributes).
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(from = "toml::Value", into = "toml::Value")]
pub struct HashableTomlValue {
    // The wrapped value; hashed structurally by `hash_toml`.
    value: toml::Value,
}
75
76impl From<toml::Value> for HashableTomlValue {
77    fn from(value: toml::Value) -> Self {
78        Self { value }
79    }
80}
81
82impl From<HashableTomlValue> for toml::Value {
83    fn from(value: HashableTomlValue) -> toml::Value {
84        value.value
85    }
86}
87
88impl std::ops::Deref for HashableTomlValue {
89    type Target = toml::Value;
90    fn deref(&self) -> &toml::Value {
91        &self.value
92    }
93}
94
95fn hash_toml<H>(value: &toml::Value, h: &mut H)
96where
97    H: Hasher,
98{
99    match value {
100        toml::Value::Boolean(v) => v.hash(h),
101        toml::Value::Datetime(v) => {
102            if let Some(d) = &v.date {
103                d.year.hash(h);
104                d.month.hash(h);
105                d.day.hash(h);
106            }
107            if let Some(t) = &v.time {
108                t.hour.hash(h);
109                t.minute.hash(h);
110                t.second.hash(h);
111                t.nanosecond.hash(h);
112            }
113            if let Some(toml::value::Offset::Custom { minutes }) = &v.offset {
114                minutes.hash(h);
115            }
116        }
117        toml::Value::String(v) => v.hash(h),
118        toml::Value::Integer(v) => v.hash(h),
119        toml::Value::Float(v) => v.to_ne_bytes().hash(h),
120        toml::Value::Array(a) => {
121            for v in a.iter() {
122                hash_toml(v, h);
123            }
124        }
125        toml::Value::Table(m) => {
126            for (k, v) in m.iter() {
127                k.hash(h);
128                hash_toml(v, h);
129            }
130        }
131    }
132}
133
134impl Hash for HashableTomlValue {
135    fn hash<H>(&self, h: &mut H)
136    where
137        H: Hasher,
138    {
139        hash_toml(&self.value, h);
140    }
141}
142
/// Represents an individual EgressPathConfig field name and value.
/// It only allows deserializing from valid EgressPathConfig field + values:
/// serde first reads an `EgressPathConfigValueUnchecked` and then runs the
/// `TryFrom` conversion, which performs the validation.
#[derive(Deserialize, Serialize, Debug, Clone, Hash)]
#[serde(
    try_from = "EgressPathConfigValueUnchecked",
    into = "EgressPathConfigValueUnchecked"
)]
pub struct EgressPathConfigValue {
    /// Name of the EgressPathConfig field.
    pub name: String,
    /// Value for that field, wrapped so that it can be hashed.
    pub value: HashableTomlValue,
}
154
/// This is the type that we actually use to deserialize EgressPathConfigValue items.
/// It doesn't care about validity; it is used solely to tell serde what shape of
/// data to expect.
/// The validation is performed by the TryFrom impl that is used to convert to the
/// checked form below.
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct EgressPathConfigValueUnchecked {
    /// Field name, not yet validated.
    pub name: String,
    /// Field value, not yet validated.
    pub value: toml::Value,
}
165
166impl TryFrom<EgressPathConfigValueUnchecked> for EgressPathConfigValue {
167    type Error = anyhow::Error;
168    fn try_from(config: EgressPathConfigValueUnchecked) -> anyhow::Result<EgressPathConfigValue> {
169        let mut map = toml::map::Map::new();
170        map.insert(config.name.clone(), config.value.clone());
171        let table = toml::Value::Table(map);
172
173        // Attempt to deserialize as EgressPathConfig.
174        // If it fails, then the field name/value are invalid
175        EgressPathConfig::deserialize(table)?;
176
177        // If we reach this point, we can pass along the name/value
178        Ok(EgressPathConfigValue {
179            name: config.name,
180            value: HashableTomlValue {
181                value: config.value,
182            },
183        })
184    }
185}
186
187impl From<EgressPathConfigValue> for EgressPathConfigValueUnchecked {
188    fn from(config: EgressPathConfigValue) -> EgressPathConfigValueUnchecked {
189        EgressPathConfigValueUnchecked {
190            name: config.name,
191            value: config.value.value,
192        }
193    }
194}
195
/// The actions that a matching [`Rule`] can carry.
///
/// `SetConfig` and `SetDomainConfig` carry a validated EgressPathConfig
/// field name and value; the other variants carry no payload.
// NOTE(review): the scope each variant applies to is not visible in this
// file; confirm the semantics against the code that consumes `Action`.
#[derive(Deserialize, Serialize, Debug, Clone, Hash)]
pub enum Action {
    Suspend,
    SetConfig(EgressPathConfigValue),
    SuspendTenant,
    SuspendCampaign,
    SetDomainConfig(EgressPathConfigValue),
    Bounce,
    BounceTenant,
    BounceCampaign,
}
207
/// Describes when a matching [`Rule`] should fire.
/// Defaults to [`Trigger::Immediate`].
#[derive(Deserialize, Serialize, Debug, Clone, Hash, Default)]
pub enum Trigger {
    /// Trigger on the first match, immediately
    #[default]
    Immediate,
    /// Trigger when a certain number of matches occur
    /// over a certain time period.
    Threshold(ThrottleSpec),
}
217
/// An automation rule: a set of regexes to test a response against,
/// together with the actions, trigger and duration that go with a match.
#[serde_as]
#[derive(Deserialize, Serialize, Debug, Hash, Clone)]
pub struct Rule {
    /// Patterns to test; `Rule::matches` succeeds if any one matches.
    /// Deserialized by `regex_string_or_array` (defined elsewhere;
    /// presumably accepts a single string or an array of strings).
    #[serde(deserialize_with = "regex_string_or_array")]
    pub regex: Vec<Regex>,

    /// Actions associated with this rule. Deserialized by
    /// `one_or_many_action` (defined elsewhere; presumably accepts a
    /// single action or a list of them).
    #[serde(deserialize_with = "one_or_many_action")]
    pub action: Vec<Action>,

    /// When the rule fires; `Trigger::Immediate` if omitted.
    #[serde(default)]
    pub trigger: Trigger,

    /// Duration associated with the rule, (de)serialized through the
    /// `duration_serde` module.
    // NOTE(review): presumably how long the actions stay in effect; confirm.
    #[serde(with = "duration_serde")]
    pub duration: Duration,

    /// Never (de)serialized; set by `clone_and_set_rollup`.
    #[serde(skip)]
    pub was_rollup: bool,

    /// if true, this rule can match kumomta internally generated
    /// Response messages, otherwise, the rule will skip testing
    /// against those.
    #[serde(default)]
    pub match_internal: bool,
}
242
243impl Rule {
244    pub fn matches(&self, is_internal: bool, response: &str) -> bool {
245        if is_internal && !self.match_internal {
246            return false;
247        }
248        self.regex
249            .iter()
250            .any(|r| r.is_match(response).unwrap_or(false))
251    }
252
253    pub fn clone_and_set_rollup(&self) -> Self {
254        let mut result = self.clone();
255        result.was_rollup = true;
256        result
257    }
258}
259
/// The merged shaping data that sits behind a `Shaping` handle.
#[cfg(feature = "lua")]
#[derive(Debug, Default)]
struct ShapingInner {
    /// Entries keyed by site name.
    by_site: OrderMap<String, PartialEntry>,
    /// Entries keyed by domain name; the special key "default" holds
    /// the base configuration.
    by_domain: OrderMap<String, PartialEntry>,
    /// Provider entries keyed by provider name.
    by_provider: OrderMap<String, ProviderEntry>,
    /// Warning-level diagnostics collected while merging.
    warnings: Vec<String>,
    /// Error-level diagnostics collected while merging.
    errors: Vec<String>,
    /// Lowercase hex SHA-256 digest computed over the merged data and
    /// the collected diagnostics.
    hash: String,
}
270
#[cfg(feature = "lua")]
impl ShapingInner {
    /// Computes the effective egress path parameters for `domain`,
    /// `egress_source` and `site_name`.
    ///
    /// Layers are merged in this order, later ones being merged over
    /// earlier ones:
    /// 1. the `default` domain entry,
    /// 2. every matching provider's base parameters (only when `domain`
    ///    parses as a DNS name),
    /// 3. the source-specific parameters of those providers,
    /// 4. the site entry,
    /// 5. the domain entry,
    /// 6. the site entry's source-specific parameters,
    /// 7. the domain entry's source-specific parameters.
    pub async fn get_egress_path_config(
        &self,
        domain: &str,
        egress_source: &str,
        site_name: &str,
    ) -> PartialEntry {
        let mut params = PartialEntry::default();

        // Apply basic/default configuration
        if let Some(default) = self.by_domain.get("default") {
            params.merge_from(default.clone());
        }

        // Provider rules come next.
        // They can only match valid domain names, so we'll
        // skip processing them if we have something like our
        // TSA http "domain name" here
        let is_domain_name = dns_resolver::Name::from_str_relaxed(domain).is_ok();
        if is_domain_name {
            let mut prov_with_sources = vec![];

            for prov in self.by_provider.values() {
                if prov.domain_matches(domain).await {
                    toml_table_merge_from(&mut params.params, &prov.params);
                    prov.apply_provider_params_to(egress_source, &mut params.params);

                    if !prov.sources.is_empty() {
                        // Remember this matching provider, so that we
                        // can apply any source rules after we've applied
                        // any/all base provider rules for this domain
                        prov_with_sources.push(prov);
                    }
                }
            }

            // Then Provider source rules
            for prov in prov_with_sources {
                if let Some(source) = prov.sources.get(egress_source) {
                    toml_table_merge_from(&mut params.params, source);
                    prov.apply_provider_params_to(egress_source, &mut params.params);
                }
            }
        }

        // Then site config
        if let Some(by_site) = self.by_site.get(site_name) {
            params.merge_from(by_site.clone());
        }

        // Then domain config
        if let Some(by_domain) = self.by_domain.get(domain) {
            params.merge_from(by_domain.clone());
        }

        // Then source config for the site
        if let Some(by_site) = self.by_site.get(site_name) {
            if let Some(source) = by_site.sources.get(egress_source) {
                toml_table_merge_from(&mut params.params, source);
            }
        }

        // Then source config for the domain
        if let Some(by_domain) = self.by_domain.get(domain) {
            if let Some(source) = by_domain.sources.get(egress_source) {
                toml_table_merge_from(&mut params.params, source);
            }
        }

        params
    }

    /// Returns the automation rules that match the response in `record`.
    ///
    /// The domain is taken from the first recipient of the record and the
    /// site name is derived from `record.site`; the matching itself is
    /// delegated to `match_rules_impl`.
    ///
    /// # Errors
    /// Fails when the record has no recipients or when the first
    /// recipient cannot be parsed as a forward path. A postmaster
    /// recipient is not an error; it yields an empty list.
    pub async fn match_rules(&self, record: &JsonLogRecord) -> anyhow::Result<Vec<Rule>> {
        use rfc5321::ForwardPath;
        // Extract the domain from the recipient.
        let recipient = ForwardPath::try_from(
            record
                .recipient
                .first()
                .ok_or_else(|| anyhow::anyhow!("no recipients!?"))?
                .as_str(),
        )
        .map_err(|err| anyhow::anyhow!("parsing record.recipient: {err}"))?;

        let recipient = match recipient {
            ForwardPath::Postmaster => {
                // It doesn't make sense to apply automation on the
                // local postmaster address, so we ignore this.
                return Ok(vec![]);
            }
            ForwardPath::Path(path) => path.mailbox,
        };
        let domain = recipient.domain.to_string();

        // Track events/outcomes by site.
        let source = record.egress_source.as_deref().unwrap_or("unspecified");
        // record.site is poorly named; it is really an identifier for the
        // egress path. For matching purposes, we want just the site_name
        // in the form produced by our MX resolution process.
        // In an earlier incarnation of this logic, we would resolve the
        // site_name for ourselves based on other data in the record,
        // but that could lead to over-resolution of some names and
        // yield surprising results.
        // What we do here is extract the egress path decoration from
        // record.site to arrive at something that looks like the
        // mx site_name.
        // NOTE: this is coupled with the logic in
        // ReadyQueueManager::compute_queue_name
        //
        // The decoration is removed exactly once on each side:
        // `trim_start_matches`/`trim_end_matches` would keep stripping
        // repeated occurrences and could eat into the site name itself.
        let decorated: &str = &record.site;
        let without_source = decorated
            .strip_prefix(format!("{source}->").as_str())
            .unwrap_or(decorated);
        let site_name = without_source
            .strip_suffix("@smtp_client")
            .unwrap_or(without_source)
            .to_string();

        Ok(self.match_rules_impl(record, &domain, &site_name).await)
    }

    /// Collects every automation rule whose regexes match the response in
    /// `record`, consulting in order: the `default` domain entry, every
    /// provider matching `domain`, the entry for `site_name`, and the
    /// entry for `domain`.
    ///
    /// Rules taken from `default` and from the site entry are returned
    /// with `was_rollup` set; provider and domain rules are returned
    /// as plain clones.
    pub async fn match_rules_impl(
        &self,
        record: &JsonLogRecord,
        domain: &str,
        site_name: &str,
    ) -> Vec<Rule> {
        let mut result = vec![];
        let response = record.response.to_single_line();
        tracing::trace!("Consider rules for {response}");

        // Internally generated responses are only matched by rules that
        // opted in via `match_internal`.
        let is_internal = record.response.content.starts_with("KumoMTA internal: ");

        if let Some(default) = self.by_domain.get("default") {
            for rule in &default.automation {
                tracing::trace!("Consider \"default\" rule {rule:?} for {response}");
                if rule.matches(is_internal, &response) {
                    // For automation under `default`, we always
                    // assume that mx_rollup should be true.
                    // If you somehow have a domain where that isn't
                    // true, you should avoid using `default` for
                    // automation.  Honestly, it's best to avoid
                    // using `default` for automation.
                    result.push(rule.clone_and_set_rollup());
                }
            }
        }

        for prov in self.by_provider.values() {
            if prov.domain_matches(domain).await {
                for rule in &prov.automation {
                    tracing::trace!(
                        "Consider provider \"{}\" rule {rule:?} for {response}",
                        prov.provider_name
                    );
                    if rule.matches(is_internal, &response) {
                        result.push(rule.clone());
                    }
                }
            }
        }

        // Then site config
        if let Some(by_site) = self.by_site.get(site_name) {
            for rule in &by_site.automation {
                tracing::trace!("Consider \"{site_name}\" rule {rule:?} for {response}");
                if rule.matches(is_internal, &response) {
                    result.push(rule.clone_and_set_rollup());
                }
            }
        }

        // Then domain config
        if let Some(by_domain) = self.by_domain.get(domain) {
            for rule in &by_domain.automation {
                tracing::trace!("Consider \"{domain}\" rule {rule:?} for {response}");
                if rule.matches(is_internal, &response) {
                    result.push(rule.clone());
                }
            }
        }

        result
    }
}
453
/// Handle to a merged shaping configuration. The data lives behind an
/// `Arc`, so cloning the handle does not copy the configuration.
#[cfg(feature = "lua")]
#[derive(Debug, Default, Clone, mlua::FromLua)]
pub struct Shaping {
    inner: Arc<ShapingInner>,
}
459
/// Deserializes `T` from JSON text; `serde_path_to_error` supplies the
/// path of the offending element in any error.
#[cfg(feature = "lua")]
fn from_json<'a, T: Deserialize<'a>>(json: &'a str) -> anyhow::Result<T> {
    let mut deserializer = serde_json::Deserializer::from_str(json);
    let parsed = serde_path_to_error::deserialize(&mut deserializer)?;
    Ok(parsed)
}
465
/// Deserializes `T` from TOML text; `serde_path_to_error` supplies the
/// path of the offending element in any error.
#[cfg(feature = "lua")]
fn from_toml<'a, T: Deserialize<'a>>(toml: &'a str) -> anyhow::Result<T> {
    let deserializer = toml::Deserializer::parse(toml)?;
    let parsed = serde_path_to_error::deserialize(deserializer)?;
    Ok(parsed)
}
471
/// How a particular kind of problem found while loading/merging shaping
/// data is reported: dropped, recorded as a warning, or recorded as an
/// error (see `Collector::push`). Defaults to `Ignore`.
#[cfg(feature = "lua")]
#[derive(Default, Debug, Clone, Copy, Deserialize, PartialEq, Eq)]
pub enum CheckLevel {
    #[default]
    Ignore,
    Warn,
    Error,
}
#[cfg(feature = "lua")]
impl std::str::FromStr for CheckLevel {
    type Err = String;

    /// Parses `ignore`, `warn` or `error`, compared ASCII
    /// case-insensitively. Anything else yields a descriptive error
    /// string.
    fn from_str(s: &str) -> Result<Self, String> {
        let candidates = [
            ("ignore", Self::Ignore),
            ("warn", Self::Warn),
            ("error", Self::Error),
        ];
        for (name, level) in candidates {
            if s.eq_ignore_ascii_case(name) {
                return Ok(level);
            }
        }
        Err(format!(
            "Expected one of `Ignore`, `Warn` or `Error`, got `{s}`"
        ))
    }
}
498
/// Accumulates the diagnostics produced while loading and merging
/// shaping files, split by severity.
#[cfg(feature = "lua")]
#[derive(Default)]
struct Collector {
    /// Messages pushed with `CheckLevel::Warn`.
    warnings: Vec<String>,
    /// Messages pushed with `CheckLevel::Error`.
    errors: Vec<String>,
}
505
#[cfg(feature = "lua")]
impl Collector {
    /// Records `msg` under the list selected by `level`; a level of
    /// `Ignore` drops the message without converting it.
    fn push<S: Into<String>>(&mut self, level: CheckLevel, msg: S) {
        let bucket = match level {
            CheckLevel::Ignore => return,
            CheckLevel::Warn => &mut self.warnings,
            CheckLevel::Error => &mut self.errors,
        };
        bucket.push(msg.into());
    }
}
516
/// Options that control how `Shaping::merge_files` loads its inputs and
/// how the problems it detects are reported.
///
/// Unknown keys are rejected. Any field that is omitted when
/// deserializing takes its value from `ShapingMergeOptions::default()`
/// (container-level `#[serde(default)]`). That way a partially specified
/// options table agrees with the hand-written `Default` impl; notably
/// `local_load` stays at `Error` instead of silently becoming `Ignore`.
#[cfg(feature = "lua")]
// `serde_as` is placed ahead of the derive, as it is on `Rule`.
#[serde_as]
#[derive(Debug, Clone, Deserialize)]
#[serde(default, deny_unknown_fields)]
pub struct ShapingMergeOptions {
    /// Level for reporting a rollup domain that is also matched by one
    /// or more providers.
    pub provider_overlap: CheckLevel,
    /// Level for reporting a failure to resolve the MX of a rollup domain.
    pub dns_fail: CheckLevel,
    /// Level for reporting a rollup domain that has a NULL MX.
    pub null_mx: CheckLevel,
    /// Level for reporting multiple domain blocks that resolve to the
    /// same site.
    pub aliased_site: CheckLevel,
    /// When true, `http://` and `https://` sources are skipped (a
    /// warning is recorded for each).
    pub skip_remote: bool,
    /// Level for reporting failures to fetch or parse a remote source.
    pub remote_load: CheckLevel,
    /// Level for reporting failures to read or parse a local source.
    pub local_load: CheckLevel,
    /// Timeout for fetching remote sources; 5 seconds is used when unset.
    #[serde(with = "duration_serde")]
    pub http_timeout: Option<Duration>,
}
539
#[cfg(feature = "lua")]
impl Default for ShapingMergeOptions {
    /// All checks start out ignored, except that local load failures are
    /// treated as errors. Remote sources are not skipped and no explicit
    /// HTTP timeout is set.
    fn default() -> Self {
        let silent = CheckLevel::Ignore;
        ShapingMergeOptions {
            // The one check that is strict out of the box.
            local_load: CheckLevel::Error,
            remote_load: silent,
            provider_overlap: silent,
            dns_fail: silent,
            null_mx: silent,
            aliased_site: silent,
            skip_remote: false,
            http_timeout: None,
        }
    }
}
555
556#[cfg(feature = "lua")]
557impl Shaping {
    /// Loads one shaping source and parses it into a `ShapingFile`.
    ///
    /// `path` is fetched over HTTP when it starts with `http://` or
    /// `https://`, and read from the local filesystem otherwise.
    ///
    /// * Remote sources are skipped (with a warning) when
    ///   `options.skip_remote` is set.
    /// * A failed fetch is reported to `collector` at
    ///   `options.remote_load`, a failed local read at
    ///   `options.local_load`; both yield an empty `ShapingFile`
    ///   rather than an error.
    /// * The format is chosen by extension: `.toml` or `.json`. For any
    ///   other name both formats are tried, TOML first.
    ///
    /// # Errors
    /// Returns an error when a `.toml` or `.json` source fails to parse.
    /// A parse failure for a source without either extension is reported
    /// to `collector` instead and yields an empty `ShapingFile`.
    async fn load_from_file(
        path: &str,
        options: &ShapingMergeOptions,
        collector: &mut Collector,
    ) -> anyhow::Result<ShapingFile> {
        // `level` is the check level that applies to this kind of source;
        // it is used below when reporting an extensionless parse failure.
        let (data, level): (String, CheckLevel) =
            if path.starts_with("http://") || path.starts_with("https://") {
                if options.skip_remote {
                    collector.push(
                        CheckLevel::Warn,
                        format!("Ignoring {path} because skip_remote is set to true"),
                    );
                    return Ok(ShapingFile::default());
                }

                // To facilitate startup ordering races, and listing multiple subscription
                // host replicas and allowing one or more of them to be temporarily down,
                // we allow the http request to fail.
                // We'll log the error message but consider it to be an empty map

                // Fetches `url` and returns the body as text. The same
                // `timeout` is applied to the reqwest client's overall,
                // connect and read timeouts, and again around the whole
                // operation. A non-success status becomes an error that
                // includes the response body when it can be read.
                async fn http_get(url: &str, timeout: Duration) -> anyhow::Result<String> {
                    tokio::time::timeout(timeout, async {
                        let response = reqwest::Client::builder()
                            .timeout(timeout)
                            .connect_timeout(timeout)
                            .read_timeout(timeout)
                            .build()?
                            .get(url)
                            .send()
                            .await
                            .with_context(|| format!("making HTTP request to {url}"))?;
                        let status = response.status();
                        if !status.is_success() {
                            let err_text = match response.text().await {
                                Ok(text) => text,
                                Err(err) => {
                                    format!("Additional error {err:#} while reading response body")
                                }
                            };
                            anyhow::bail!(
                                "HTTP request to {url} failed: status {status} {reason} {err_text}",
                                status = status.as_u16(),
                                reason = status.canonical_reason().unwrap_or("")
                            );
                        }
                        response
                            .text()
                            .await
                            .with_context(|| format!("reading text from {url}"))
                    })
                    .await
                    .with_context(|| format!("timeout making HTTP request to {url}"))?
                }

                // Fall back to a 5 second timeout when none was configured.
                let timeout = options.http_timeout.unwrap_or(Duration::from_secs(5));

                match http_get(path, timeout).await {
                    Ok(s) => (s, options.remote_load),
                    Err(err) => {
                        tracing::error!("{err:#}. Ignoring this shaping source for now");
                        collector.push(
                            options.remote_load,
                            format!("remote shaping source {path} error: {err:#}"),
                        );
                        return Ok(ShapingFile::default());
                    }
                }
            } else {
                match std::fs::read_to_string(path)
                    .with_context(|| format!("loading data from file {path}"))
                {
                    Err(err) => {
                        collector.push(
                            options.local_load,
                            format!("local shaping source {path} error: {err:#}"),
                        );
                        return Ok(ShapingFile::default());
                    }
                    Ok(s) => (s, options.local_load),
                }
            };

        // NOTE(review): with an explicit extension a parse failure is
        // returned as an error and bypasses the collector, unlike the
        // extensionless case below; confirm that this is intended.
        if path.ends_with(".toml") {
            from_toml(&data).with_context(|| format!("parsing toml from file {path}"))
        } else if path.ends_with(".json") {
            from_json(&data).with_context(|| format!("parsing json from file {path}"))
        } else {
            // Try parsing both ways and see which wins
            let mut errors = vec![];
            match from_toml(&data) {
                Ok(s) => return Ok(s),
                Err(err) => errors.push(format!("as toml: {err:#}")),
            }
            match from_json(&data) {
                Ok(s) => return Ok(s),
                Err(err) => errors.push(format!("as json: {err:#}")),
            }

            collector.push(level, format!("parsing {path}: {}", errors.join(", ")));
            Ok(ShapingFile::default())
        }
    }
660
661    pub async fn merge_files(
662        files: &[String],
663        options: &ShapingMergeOptions,
664    ) -> anyhow::Result<Self> {
665        use futures_util::stream::FuturesUnordered;
666        use futures_util::StreamExt;
667
668        let mut collector = Collector::default();
669        let mut loaded = vec![];
670        for p in files {
671            loaded.push(Self::load_from_file(p, options, &mut collector).await?);
672        }
673
674        let mut by_site: OrderMap<String, PartialEntry> = OrderMap::new();
675        let mut by_domain: OrderMap<String, PartialEntry> = OrderMap::new();
676        let mut by_provider: OrderMap<String, ProviderEntry> = OrderMap::new();
677        let mut site_aliases: OrderMap<String, Vec<String>> = OrderMap::new();
678
679        // Pre-resolve domains. We don't interleave the resolution with
680        // the work below, because we want to ensure that the ordering
681        // is preserved
682        let mut mx = std::collections::HashMap::new();
683        let mut lookups = FuturesUnordered::new();
684        for item in &loaded {
685            for (domain, partial) in &item.domains {
686                if partial.mx_rollup {
687                    let domain = domain.to_string();
688                    lookups.push(tokio::spawn(async move {
689                        let mx_result = MailExchanger::resolve(&domain).await;
690                        (domain, mx_result)
691                    }));
692                }
693            }
694        }
695
696        while let Some(Ok((domain, result))) = lookups.next().await {
697            mx.insert(domain, result);
698        }
699
700        for mut item in loaded {
701            if let Some(mut partial) = item.default.take() {
702                let domain = "default";
703                partial.domain_name.replace(domain.to_string());
704                match by_domain.get_mut(domain) {
705                    Some(existing) => {
706                        existing.merge_from(partial);
707                    }
708                    None => {
709                        by_domain.insert(domain.to_string(), partial);
710                    }
711                }
712            }
713
714            for (domain, mut partial) in item.domains {
715                partial.domain_name.replace(domain.clone());
716
717                if let Ok(name) = fully_qualify(&domain) {
718                    if name.num_labels() == 1 {
719                        collector.push(
720                            CheckLevel::Warn,
721                            format!(
722                                "Entry for domain '{domain}' consists of a \
723                                 single DNS label. Domain names in TOML sections \
724                                 need to be quoted like '[\"{domain}.com\"]` otherwise \
725                                 the '.' will create a nested table rather than being \
726                                 added to the domain name."
727                            ),
728                        );
729                    }
730                }
731
732                #[cfg(test)]
733                if partial._treat_domain_name_as_site_name {
734                    match by_site.get_mut(&domain) {
735                        Some(existing) => {
736                            existing.merge_from(partial);
737                        }
738                        None => {
739                            by_site.insert(domain.to_string(), partial);
740                        }
741                    }
742                    continue;
743                }
744                if partial.mx_rollup {
745                    let mx = match mx.get(&domain) {
746                        Some(Ok(mx)) => mx,
747                        Some(Err(err)) => {
748                            collector.push(
749                                options.dns_fail,
750                                format!(
751                                    "error resolving MX for {domain}: {err:#}. \
752                                 Ignoring the shaping config for that domain."
753                                ),
754                            );
755                            continue;
756                        }
757                        None => {
758                            collector.push(
759                                options.dns_fail,
760                                format!(
761                                "We didn't try to resolve the MX for {domain} for some reason!?. \
762                                 Ignoring the shaping config for that domain."
763                            ),
764                            );
765                            continue;
766                        }
767                    };
768
769                    if mx.site_name.is_empty() {
770                        collector.push(
771                            options.null_mx,
772                            format!(
773                            "domain {domain} has a NULL MX and cannot be used with mx_rollup=true. \
774                             Ignoring the shaping config for that domain."),
775                        );
776                        continue;
777                    }
778
779                    site_aliases
780                        .entry(mx.site_name.to_string())
781                        .or_default()
782                        .push(domain.to_string());
783
784                    match by_site.get_mut(&mx.site_name) {
785                        Some(existing) => {
786                            existing.merge_from(partial);
787                        }
788                        None => {
789                            by_site.insert(mx.site_name.clone(), partial);
790                        }
791                    }
792                } else {
793                    match by_domain.get_mut(&domain) {
794                        Some(existing) => {
795                            existing.merge_from(partial);
796                        }
797                        None => {
798                            by_domain.insert(domain, partial);
799                        }
800                    }
801                }
802            }
803
804            for (provider, mut prov) in item.provider {
805                prov.provider_name = provider.to_string();
806                match by_provider.get_mut(&provider) {
807                    Some(existing) => {
808                        existing.merge_from(prov);
809                    }
810                    None => {
811                        by_provider.insert(provider.to_string(), prov);
812                    }
813                }
814            }
815        }
816
817        for (site, partial) in &by_site {
818            partial
819                .clone()
820                .finish()
821                .with_context(|| format!("site: {site}"))?;
822        }
823
824        for (domain, partial) in &by_domain {
825            partial
826                .clone()
827                .finish()
828                .with_context(|| format!("domain: {domain}"))?;
829        }
830
831        for (provider, prov) in &by_provider {
832            prov.finish_params()
833                .with_context(|| format!("provider: {provider}"))?;
834        }
835
836        if options.aliased_site != CheckLevel::Ignore {
837            for (site, aliases) in site_aliases {
838                if aliases.len() > 1 {
839                    collector.push(
840                        options.aliased_site,
841                        format!(
842                            "multiple domain blocks alias to the same site: {site}: {}",
843                            aliases.join(", ")
844                        ),
845                    );
846                }
847            }
848        }
849
850        if options.provider_overlap != CheckLevel::Ignore {
851            for domain in mx.keys() {
852                let mut matching_providers = vec![];
853                for (prov_name, prov) in &by_provider {
854                    if prov.domain_matches(domain).await {
855                        matching_providers.push(prov_name.to_string());
856                    }
857                }
858                if !matching_providers.is_empty() {
859                    collector.push(
860                        options.provider_overlap,
861                        format!(
862                            "domain {domain} is also matched by provider(s): {}",
863                            matching_providers.join(", ")
864                        ),
865                    );
866                }
867            }
868        }
869
870        let mut ctx = Sha256::new();
871        ctx.update("by_site");
872        for (site, entry) in &by_site {
873            ctx.update(site);
874            entry.hash_into(&mut ctx);
875        }
876        ctx.update("by_domain");
877        for (domain, entry) in &by_domain {
878            ctx.update(domain);
879            entry.hash_into(&mut ctx);
880        }
881        ctx.update("by_provider");
882        for (provider, prov) in &by_provider {
883            ctx.update(provider);
884            prov.hash_into(&mut ctx);
885        }
886        ctx.update("warnings");
887        for warn in &collector.warnings {
888            ctx.update(warn);
889        }
890        ctx.update("errors");
891        for err in &collector.errors {
892            ctx.update(err);
893        }
894        let hash = ctx.finalize();
895        let hash = data_encoding::HEXLOWER.encode(&hash);
896
897        Ok(Self {
898            inner: Arc::new(ShapingInner {
899                by_site,
900                by_domain,
901                by_provider,
902                warnings: collector.warnings,
903                errors: collector.errors,
904                hash,
905            }),
906        })
907    }
908
909    async fn get_egress_path_config(
910        &self,
911        domain: &str,
912        egress_source: &str,
913        site_name: &str,
914    ) -> PartialEntry {
915        self.inner
916            .get_egress_path_config(domain, egress_source, site_name)
917            .await
918    }
919
920    pub async fn get_egress_path_config_value(
921        &self,
922        domain: &str,
923        egress_source: &str,
924        site_name: &str,
925    ) -> anyhow::Result<serde_json::Value> {
926        let partial = self
927            .get_egress_path_config(domain, egress_source, site_name)
928            .await;
929        Ok(serde_json::to_value(&partial)?)
930    }
931
932    pub fn get_errors(&self) -> &[String] {
933        &self.inner.errors
934    }
935
936    pub fn get_warnings(&self) -> &[String] {
937        &self.inner.warnings
938    }
939
940    pub async fn match_rules(&self, record: &JsonLogRecord) -> anyhow::Result<Vec<Rule>> {
941        self.inner.match_rules(record).await
942    }
943
944    pub fn get_referenced_sources(&self) -> BTreeMap<String, Vec<String>> {
945        let mut result = BTreeMap::new();
946
947        for (site_name, site) in &self.inner.by_site {
948            for source_name in site.sources.keys() {
949                result
950                    .entry(source_name.to_string())
951                    .or_insert(vec![])
952                    .push(format!("site:{site_name}"));
953            }
954        }
955        for (domain_name, domain) in &self.inner.by_domain {
956            for source_name in domain.sources.keys() {
957                result
958                    .entry(source_name.to_string())
959                    .or_insert(vec![])
960                    .push(format!("domain:{domain_name}"));
961            }
962        }
963
964        result
965    }
966
967    pub fn hash(&self) -> String {
968        self.inner.hash.clone()
969    }
970}
971
// Exposes `Shaping` to Lua as userdata. The methods registered here are
// `get_egress_path_config`, `get_errors`, `get_warnings`,
// `get_referenced_sources`, `match_rules` and `hash`, in addition to
// whatever `mod_memoize::Memoized::impl_memoize` registers.
#[cfg(feature = "lua")]
impl LuaUserData for Shaping {
    fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
        mod_memoize::Memoized::impl_memoize(methods);

        // Only the `params` table of the resolved entry is handed to Lua.
        methods.add_async_method(
            "get_egress_path_config",
            |lua, this, (domain, egress_source, site_name): (String, String, String)| async move {
                let entry = this
                    .get_egress_path_config(&domain, &egress_source, &site_name)
                    .await;
                lua.to_value_with(&entry.params, serialize_options())
            },
        );

        methods.add_method("get_errors", move |_lua, this, ()| {
            Ok(this.get_errors().to_vec())
        });

        methods.add_method("get_warnings", move |_lua, this, ()| {
            Ok(this.get_warnings().to_vec())
        });

        methods.add_method("get_referenced_sources", move |_lua, this, ()| {
            let referenced = this.get_referenced_sources();
            Ok(referenced)
        });

        // The incoming Lua value is decoded as a log record, matched, and
        // each matching rule is converted back into a Lua value.
        methods.add_async_method("match_rules", |lua, this, record: mlua::Value| async move {
            let record: JsonLogRecord = lua.from_value(record)?;
            let rules = this.match_rules(&record).await.map_err(any_err)?;
            rules
                .iter()
                .map(|rule| lua.to_value(rule))
                .collect::<mlua::Result<Vec<_>>>()
        });

        methods.add_method("hash", move |_, this, ()| {
            let digest = this.hash();
            Ok(digest)
        });
    }
}
1013
/// The typed result of finishing a shaping entry: the loosely-typed TOML
/// parameter tables have been deserialized into `EgressPathConfig` values.
#[derive(Default, Debug)]
pub struct MergedEntry {
    /// Path configuration built from the entry's own parameter table.
    pub params: EgressPathConfig,
    /// Path configuration built from each `sources` sub-table, keyed by
    /// source name.
    pub sources: OrderMap<String, EgressPathConfig>,
    /// Automation rules carried over from the entry.
    pub automation: Vec<Rule>,
}
1020
/// Serde model for the top level of one shaping file (the unit tests in
/// this file show the TOML layout).
#[cfg(feature = "lua")]
#[derive(Deserialize, Serialize, Debug, Clone, Default)]
struct ShapingFile {
    /// The optional `default` table.
    pub default: Option<PartialEntry>,
    /// Every remaining top-level table, keyed by its name; collected here
    /// by serde's `flatten`.
    #[serde(flatten, default)]
    pub domains: OrderMap<String, PartialEntry>,
    /// The entries found under the `provider` table, keyed by provider name.
    #[serde(default)]
    pub provider: OrderMap<String, ProviderEntry>,
}
1030
/// A loosely-typed shaping entry whose parameters are still raw TOML;
/// `finish` validates them by deserializing into `EgressPathConfig`.
#[cfg(feature = "lua")]
#[derive(Deserialize, Serialize, Debug, Clone, Default)]
struct PartialEntry {
    /// Name used for error context by `finish` and mixed into the digest by
    /// `hash_into`. Never part of the serialized form (`serde(skip)`).
    #[serde(skip)]
    pub domain_name: Option<String>,

    /// Every key that is not one of the named fields of this struct
    /// (collected by serde's `flatten`).
    #[serde(flatten)]
    pub params: toml::Table,

    /// Defaults to `true` when the key is absent.
    /// NOTE(review): presumably selects whether the entry applies to the
    /// MX-derived site rather than to the literal domain; confirm
    /// against `merge_files`.
    #[serde(default = "default_true")]
    pub mx_rollup: bool,

    // This is present to facilitate unit testing without requiring
    // DNS to resolve the site_name. When set to true, the domain_name
    // is considered to be the site_name for this entry.
    #[cfg(test)]
    #[serde(default)]
    pub _treat_domain_name_as_site_name: bool,

    /// When set on an entry that is merged into an existing one, its
    /// params, sources and automation replace the existing values instead
    /// of being merged with them (see `merge_from`).
    #[serde(default)]
    pub replace_base: bool,

    /// Automation rules attached to this entry.
    #[serde(default)]
    pub automation: Vec<Rule>,

    /// Per-source parameter tables, keyed by source name.
    #[serde(default)]
    pub sources: OrderMap<String, toml::Table>,
}
1059
/// A `provider` block from a shaping file: a set of match rules plus the
/// parameters that apply to every domain those rules select.
#[cfg(feature = "lua")]
#[derive(Deserialize, Serialize, Debug, Clone, Default)]
pub struct ProviderEntry {
    /// The provider's name. Not part of the serialized form; it is filled
    /// in from the key of the `provider` table when files are merged.
    #[serde(skip, default)]
    pub provider_name: String,

    /// When set, `apply_provider_params_to` adds an entry named
    /// `shaping-provider-<provider>-<source>-limit` to
    /// `additional_connection_limits`.
    #[serde(default)]
    pub provider_connection_limit: Option<LimitSpec>,

    /// When set, `apply_provider_params_to` adds an entry named
    /// `shaping-provider-<provider>-<source>-rate` to
    /// `additional_message_rate_throttles`.
    #[serde(default)]
    pub provider_max_message_rate: Option<ThrottleSpec>,

    /// The rules evaluated by `domain_matches`; spelled `match` in the
    /// serialized form.
    #[serde(default, rename = "match")]
    pub matches: Vec<ProviderMatch>,

    /// When set on an entry that is merged into an existing one, its
    /// values replace the existing ones instead of being merged with them
    /// (see `merge_from`).
    #[serde(default)]
    pub replace_base: bool,

    /// Every key that is not one of the named fields of this struct
    /// (collected by serde's `flatten`).
    #[serde(flatten)]
    pub params: toml::Table,

    /// Automation rules attached to this provider.
    #[serde(default)]
    pub automation: Vec<Rule>,

    /// Per-source parameter tables, keyed by source name.
    #[serde(default)]
    pub sources: OrderMap<String, toml::Table>,
}
1087
1088#[cfg(feature = "lua")]
/// Reports whether `candidate` ends with `suffix`, after dropping at most
/// one trailing `.` from `candidate`.
///
/// The comparison is case-sensitive on purpose: the dns-resolver crate is
/// assumed to have lowercased `candidate`, and `suffix` is assumed to be
/// lowercase as well.
fn suffix_matches(candidate: &str, suffix: &str) -> bool {
    // Our resolver tends to leave the canonical trailing dot on host
    // names, so compare against the name without it.
    let undotted = match candidate.strip_suffix('.') {
        Some(bare) => bare,
        None => candidate,
    };
    undotted.ends_with(suffix)
}
1098
#[cfg(feature = "lua")]
#[cfg(test)]
#[test]
fn test_suffix_matches() {
    // Each row is (candidate, suffix, expected result).
    let cases = [
        ("a", "a", true),
        ("foo.com", "foo.com", true),
        ("foo.com", ".foo.com", false),
        ("foo.com", "longer.com", false),
        // Non-ASCII candidate that is shorter in characters than in bytes
        ("réputation.net", ".mx.microsoft", false),
        // We DON'T handle case mismatches in this code.
        // We assume that the dns-resolver crate normalized the
        // names to lowercase and that the input suffix is lowercase as well.
        ("foo.Com", ".com", false),
        ("foo.Cam", ".com", false),
    ];

    for (candidate, suffix, expected) in cases {
        assert_eq!(
            suffix_matches(candidate, suffix),
            expected,
            "suffix_matches({candidate:?}, {suffix:?})"
        );
    }
}
1115
1116#[cfg(feature = "lua")]
/// Reports whether `candidate` is exactly `name`, after dropping at most
/// one trailing `.` from `candidate`. The comparison is case-sensitive.
fn host_matches(candidate: &str, name: &str) -> bool {
    // Our resolver tends to leave the canonical trailing dot on host
    // names, so compare against the name without it.
    match candidate.strip_suffix('.') {
        Some(bare) => bare == name,
        None => candidate == name,
    }
}
1123
#[cfg(feature = "lua")]
#[cfg(test)]
#[test]
fn test_host_matches() {
    // Each row is (candidate, name, expected result).
    let cases = [
        ("foo.com", "foo.com", true),
        // A single trailing dot on the candidate is ignored
        ("foo.com.", "foo.com", true),
        ("foo.com", "notfoo.com", false),
    ];

    for (candidate, name, expected) in cases {
        assert_eq!(
            host_matches(candidate, name),
            expected,
            "host_matches({candidate:?}, {name:?})"
        );
    }
}
1132
#[cfg(feature = "lua")]
impl ProviderEntry {
    /// Decides whether `domain` is selected by this provider's `match`
    /// rules.
    ///
    /// `DomainSuffix` rules are checked first because they need no DNS; one
    /// hit is enough. If none hit and there are `HostName` or `MXSuffix`
    /// rules, the MX records of `domain` are resolved and the provider
    /// matches only when *every* MX host satisfies at least one of those
    /// rules. A failed MX lookup is treated as "no match".
    async fn domain_matches(&self, domain: &str) -> bool {
        tracing::trace!(
            "ProviderEntry::domain_matches({domain}) vs {:?}",
            self.matches
        );

        // We'd like to avoid doing DNS if we can do a simple suffix match,
        // so we bias to looking at those first
        let mut need_mx = false;
        for rule in &self.matches {
            match rule {
                ProviderMatch::DomainSuffix(suffix) => {
                    if suffix_matches(domain, suffix) {
                        tracing::trace!("{domain} suffix matches {suffix}");
                        return true;
                    }
                }
                ProviderMatch::HostName(_) | ProviderMatch::MXSuffix(_) => {
                    need_mx = true;
                }
            }
        }

        if !need_mx {
            return false;
        }

        // Now we can consider DNS
        let mx = match MailExchanger::resolve(domain).await {
            Ok(mx) => mx,
            Err(err) => {
                // Didn't resolve; could be legit, for example, could
                // be an internal or fake name that is handled via
                // smart host or custom routing, so we log what
                // happened as a trace rather than polluting the
                // logs about it
                tracing::trace!(
                    "Error resolving MX for {domain}: {err:#}. \
                    Provider {} match rules will be ignored",
                    self.provider_name
                );
                return false;
            }
        };

        tracing::trace!("Consider MXSuffix rules");
        // Every MX host must be covered by at least one MXSuffix/HostName
        // rule. This is so that we don't falsely lump in a vanity domain
        // that blends several providers together.
        // NOTE(review): if `mx.hosts` is empty this loop is vacuous and the
        // function returns true; confirm that `MailExchanger::resolve`
        // never yields an empty host list.
        for host in &mx.hosts {
            let matched = self.matches.iter().any(|rule| match rule {
                ProviderMatch::MXSuffix(suffix) => {
                    tracing::trace!("suffix={suffix} vs host {host}");
                    suffix_matches(host, suffix)
                }
                ProviderMatch::HostName(name) => host_matches(host, name),
                ProviderMatch::DomainSuffix(_) => false,
            });

            if !matched {
                tracing::trace!("host didn't match any of these rules");
                return false;
            }
        }

        true
    }

    /// Folds `other`, a later definition of the same provider, into `self`.
    ///
    /// With `other.replace_base` set, the provider limits, match rules,
    /// params, sources and automation of `other` replace those of `self`
    /// outright. Otherwise:
    /// * `provider_connection_limit` and `provider_max_message_rate` are
    ///   overridden only when `other` specifies them,
    /// * `params` and each per-source table are merged with
    ///   `toml_table_merge_from`,
    /// * match rules and automation rules are appended.
    fn merge_from(&mut self, mut other: Self) {
        if other.replace_base {
            self.provider_connection_limit = other.provider_connection_limit;
            // The message rate is a provider-level setting just like the
            // connection limit, so it has to be replaced along with it.
            self.provider_max_message_rate = other.provider_max_message_rate;
            self.matches = other.matches;
            self.params = other.params;
            self.sources = other.sources;
            self.automation = other.automation;
        } else {
            if other.provider_connection_limit.is_some() {
                self.provider_connection_limit = other.provider_connection_limit;
            }
            // Without this, a rate given in a later file for an already
            // known provider would be silently discarded.
            if other.provider_max_message_rate.is_some() {
                self.provider_max_message_rate = other.provider_max_message_rate;
            }

            toml_table_merge_from(&mut self.params, &other.params);

            for (source, tbl) in other.sources {
                match self.sources.get_mut(&source) {
                    Some(existing) => {
                        toml_table_merge_from(existing, &tbl);
                    }
                    None => {
                        self.sources.insert(source, tbl);
                    }
                }
            }

            self.matches.append(&mut other.matches);
            self.automation.append(&mut other.automation);
        }
    }

    /// Injects the settings implied by this provider into `target`, the
    /// parameter table being assembled for the egress source `source`.
    ///
    /// The implied settings are:
    /// * `provider_name`, always;
    /// * an `additional_connection_limits` entry when
    ///   `provider_connection_limit` is set;
    /// * an `additional_message_rate_throttles` entry when
    ///   `provider_max_message_rate` is set;
    /// * an `additional_source_selection_rates` entry when `target` holds a
    ///   `provider_source_selection_rate` key. That key is removed from
    ///   `target` in the process.
    ///
    /// Each generated entry is named
    /// `shaping-provider-<provider>-<source>-<kind>`. The implied table is
    /// then merged into `target` via `toml_table_merge_from`.
    fn apply_provider_params_to(&self, source: &str, target: &mut toml::Table) {
        let mut implied = toml::Table::new();
        implied.insert(
            "provider_name".to_string(),
            toml::Value::String(self.provider_name.to_string()),
        );

        if let Some(limit) = &self.provider_connection_limit {
            let mut limits = toml::Table::new();
            limits.insert(
                format!("shaping-provider-{}-{source}-limit", self.provider_name),
                toml::Value::String(limit.to_string()),
            );
            implied.insert(
                "additional_connection_limits".to_string(),
                toml::Value::Table(limits),
            );
        }
        if let Some(rate) = &self.provider_max_message_rate {
            let rate = rate.as_string();
            let mut limits = toml::Table::new();
            limits.insert(
                format!("shaping-provider-{}-{source}-rate", self.provider_name),
                rate.into(),
            );
            implied.insert(
                "additional_message_rate_throttles".to_string(),
                toml::Value::Table(limits),
            );
        }

        if let Some(rate) = target.remove("provider_source_selection_rate") {
            let mut limits = toml::Table::new();
            limits.insert(
                format!(
                    "shaping-provider-{}-{source}-selection-rate",
                    self.provider_name
                ),
                rate,
            );
            implied.insert(
                "additional_source_selection_rates".to_string(),
                toml::Value::Table(limits),
            );
        }

        toml_table_merge_from(target, &implied);
    }

    /// Validates this provider's parameter tables by deserializing them
    /// into `EgressPathConfig`, producing a `MergedEntry`.
    ///
    /// # Errors
    /// Fails, with context naming the provider (and the source, where
    /// applicable), if any table is not a valid `EgressPathConfig`.
    fn finish_params(&self) -> anyhow::Result<MergedEntry> {
        let provider_name = &self.provider_name;

        let params = EgressPathConfig::deserialize(self.params.clone()).with_context(|| {
            format!(
                "interpreting provider '{provider_name}' params {:#?} as EgressPathConfig",
                self.params
            )
        })?;
        let mut sources = OrderMap::new();

        for (source, params) in &self.sources {
            let mut params = params.clone();
            // I don't really like this remove call. The issue is that we don't
            // have an alternative way to filter this out of the partial source
            // definition, and this provider_ option is not valid in the
            // EgressPathConfig struct itself; it is a shaping source-specific
            // addition to help with setting up shared throttles across providers
            params.remove("provider_source_selection_rate");
            sources.insert(
                source.clone(),
                EgressPathConfig::deserialize(params.clone()).with_context(|| {
                    format!("interpreting provider '{provider_name}' source '{source}' {params:#} as EgressPathConfig")
                })?,
            );
        }

        Ok(MergedEntry {
            params,
            sources,
            automation: self.automation.clone(),
        })
    }

    /// Feeds this provider into the configuration digest: first the
    /// provider name (which is skipped by serde), then the JSON rendering
    /// of the entry. A serialization failure contributes an empty string.
    fn hash_into(&self, ctx: &mut Sha256) {
        ctx.update(&self.provider_name);
        ctx.update(serde_json::to_string(self).unwrap_or_default());
    }
}
1335
/// One rule in a provider's `match` list, as evaluated by
/// `ProviderEntry::domain_matches`.
#[cfg(feature = "lua")]
#[derive(Deserialize, Serialize, Debug, Clone)]
pub enum ProviderMatch {
    /// Suffix that is compared against the MX host names of the domain.
    MXSuffix(String),
    /// Suffix that is compared against the domain name itself; needs no DNS.
    DomainSuffix(String),
    /// Exact name that is compared against the MX host names of the domain.
    HostName(String),
}
1343
/// Overlays `source` onto `tbl`.
///
/// Every key of `source` replaces the same key in `tbl`, with one
/// exception: for the throttle-related keys accepted by `is_mergeable`,
/// when both sides hold a table the inner keys of `source` are inserted
/// into the existing table instead of replacing it wholesale.
#[cfg(feature = "lua")]
fn toml_table_merge_from(tbl: &mut toml::Table, source: &toml::Table) {
    // Limit merging to just the throttle related fields, as their purpose
    // is for creating broader scoped limits that cut across normal boundaries
    fn is_mergeable(name: &str) -> bool {
        matches!(
            name,
            "additional_connection_limits"
                | "additional_message_rate_throttles"
                | "additional_source_selection_rates"
        )
    }

    for (key, value) in source {
        if is_mergeable(key) {
            // Merge Table values together, rather than simply replacing them.
            if let (Some(toml::Value::Table(existing)), Some(incoming)) =
                (tbl.get_mut(key), value.as_table())
            {
                for (inner_key, inner_value) in incoming {
                    existing.insert(inner_key.clone(), inner_value.clone());
                }
                continue;
            }
        }

        tbl.insert(key.clone(), value.clone());
    }
}
1371
#[cfg(feature = "lua")]
impl PartialEntry {
    /// Folds `other`, a later definition of the same entry, into `self`.
    ///
    /// With `other.replace_base` set, the params, automation and sources of
    /// `other` replace those of `self`. Otherwise params and per-source
    /// tables are merged with `toml_table_merge_from` and the automation
    /// rules of `other` are appended.
    fn merge_from(&mut self, mut other: Self) {
        if other.replace_base {
            self.params = other.params;
            self.automation = other.automation;
            self.sources = other.sources;
            return;
        }

        toml_table_merge_from(&mut self.params, &other.params);

        for (source, incoming) in other.sources {
            if let Some(existing) = self.sources.get_mut(&source) {
                toml_table_merge_from(existing, &incoming);
            } else {
                self.sources.insert(source, incoming);
            }
        }

        self.automation.append(&mut other.automation);
    }

    /// Validates the parameter tables by deserializing them into
    /// `EgressPathConfig`, producing a `MergedEntry`.
    ///
    /// # Errors
    /// Fails, with context naming the domain (and the source, where
    /// applicable), if any table is not a valid `EgressPathConfig`. An
    /// unset `domain_name` is reported as the empty string.
    fn finish(self) -> anyhow::Result<MergedEntry> {
        let Self {
            domain_name,
            params: raw_params,
            sources: raw_sources,
            automation,
            ..
        } = self;
        let domain = domain_name.unwrap_or_default();

        let params = EgressPathConfig::deserialize(raw_params.clone()).with_context(|| {
            format!(
                "interpreting domain '{domain}' params {:#?} as EgressPathConfig",
                raw_params
            )
        })?;

        let mut sources = OrderMap::new();
        for (source, params) in raw_sources {
            // The table is cloned for deserialization so that the original
            // is still available for the error message.
            let config = EgressPathConfig::deserialize(params.clone()).with_context(|| {
                format!("interpreting domain '{domain}' source '{source}' {params:#} as EgressPathConfig")
            })?;
            sources.insert(source, config);
        }

        Ok(MergedEntry {
            params,
            sources,
            automation,
        })
    }

    /// Feeds this entry into the configuration digest: first the domain
    /// name when there is one (it is skipped by serde), then the JSON
    /// rendering of the entry. A serialization failure contributes an empty
    /// string.
    fn hash_into(&self, ctx: &mut Sha256) {
        if let Some(name) = &self.domain_name {
            ctx.update(name);
        }
        ctx.update(serde_json::to_string(self).unwrap_or_default());
    }
}
1431
1432fn one_or_many<'de, T, D>(deserializer: D, expecting: &str) -> Result<Vec<T>, D::Error>
1433where
1434    T: Deserialize<'de>,
1435    D: Deserializer<'de>,
1436{
1437    let result: Result<Vec<T>, _> =
1438        OneOrMany::<serde_with::Same, PreferOne>::deserialize_as(deserializer);
1439    match result {
1440        Ok(r) => Ok(r),
1441        Err(err) => Err(serde::de::Error::custom(format!(
1442            "{expecting}.\nThe underlying error message is:\n{err:#}"
1443        ))),
1444    }
1445}
1446
1447fn one_or_many_action<'de, D>(deserializer: D) -> Result<Vec<Action>, D::Error>
1448where
1449    D: Deserializer<'de>,
1450{
1451    one_or_many(
1452        deserializer,
1453        "\"action\" field expected either a single Action or an array of Actions",
1454    )
1455}
1456
1457fn regex_string_or_array<'de, D>(deserializer: D) -> Result<Vec<Regex>, D::Error>
1458where
1459    D: Deserializer<'de>,
1460{
1461    string_or_array(
1462        deserializer,
1463        "regex string or array of regex strings for field regex",
1464    )
1465}
1466
1467fn string_or_array<'de, T, D>(deserializer: D, expecting: &'static str) -> Result<Vec<T>, D::Error>
1468where
1469    T: Deserialize<'de> + TryFrom<String>,
1470    <T as TryFrom<String>>::Error: std::fmt::Debug,
1471    D: Deserializer<'de>,
1472{
1473    // This is a Visitor that forwards string types to T's `TryFrom<String>` impl and
1474    // forwards map types to T's `Deserialize` impl. The `PhantomData` is to
1475    // keep the compiler from complaining about T being an unused generic type
1476    // parameter. We need T in order to know the Value type for the Visitor
1477    // impl.
1478    struct StringOrArray<T>(PhantomData<fn() -> T>, &'static str);
1479
1480    impl<'de, T> Visitor<'de> for StringOrArray<T>
1481    where
1482        T: Deserialize<'de> + TryFrom<String>,
1483        <T as TryFrom<String>>::Error: std::fmt::Debug,
1484    {
1485        type Value = Vec<T>;
1486
1487        fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
1488            formatter.write_str(self.1)
1489        }
1490
1491        fn visit_str<E>(self, value: &str) -> Result<Vec<T>, E>
1492        where
1493            E: serde::de::Error,
1494        {
1495            Ok(vec![
1496                T::try_from(value.to_string()).map_err(|e| E::custom(format!("{e:?}")))?
1497            ])
1498        }
1499
1500        fn visit_seq<S>(self, seq: S) -> Result<Vec<T>, S::Error>
1501        where
1502            S: SeqAccess<'de>,
1503        {
1504            Deserialize::deserialize(serde::de::value::SeqAccessDeserializer::new(seq))
1505        }
1506    }
1507
1508    deserializer.deserialize_any(StringOrArray(PhantomData, expecting))
1509}
1510
/// Serde default provider that yields `true`; referenced by name from
/// `#[serde(default = "default_true")]` on `PartialEntry::mx_rollup`.
#[cfg(feature = "lua")]
fn default_true() -> bool {
    true
}
1515
/// Registers the `load` function in the `shaping` Lua sub-module.
///
/// `load(paths, options)` takes a list of shaping file paths and an
/// optional options value, merges the files with `Shaping::merge_files`
/// and returns the resulting `Shaping` object. When `options` is omitted
/// the default `ShapingMergeOptions` are used.
#[cfg(feature = "lua")]
pub fn register(lua: &mlua::Lua) -> anyhow::Result<()> {
    let shaping_mod = config::get_or_create_sub_module(lua, "shaping")?;

    let load = lua.create_async_function(
        move |lua, (paths, options): (Vec<String>, Option<mlua::Value>)| async move {
            let options: ShapingMergeOptions = if let Some(value) = options {
                lua.from_value(value)?
            } else {
                Default::default()
            };

            let shaping = Shaping::merge_files(&paths, &options)
                .await
                .map_err(any_err)?;
            Ok(shaping)
        },
    )?;
    shaping_mod.set("load", load)?;

    Ok(())
}
1538
1539#[cfg(test)]
1540mod test {
1541    use super::*;
1542    use kumo_log_types::RecordType;
1543    use rfc5321::Response;
1544    use std::io::Write;
1545    use tempfile::NamedTempFile;
1546    use uuid::Uuid;
1547
1548    async fn make_shaping_configs(inputs: &[&str]) -> Shaping {
1549        let mut files = vec![];
1550        let mut file_names = vec![];
1551
1552        for (i, content) in inputs.iter().enumerate() {
1553            let mut shaping_file = NamedTempFile::with_prefix(format!("file{i}")).unwrap();
1554            shaping_file.write_all(content.as_bytes()).unwrap();
1555            file_names.push(shaping_file.path().to_str().unwrap().to_string());
1556            files.push(shaping_file);
1557        }
1558
1559        Shaping::merge_files(&file_names, &ShapingMergeOptions::default())
1560            .await
1561            .unwrap()
1562    }
1563
    // Two files define the same `example.com` entry. The `additional_*`
    // tables are among the mergeable keys, so the second file's table must
    // be combined key-by-key with the first rather than replacing it.
    #[tokio::test]
    async fn test_merge_additional() {
        let shaping = make_shaping_configs(&[
            r#"
["example.com"]
mx_rollup = false
additional_connection_limits = {"first"=10}
        "#,
            r#"
["example.com"]
mx_rollup = false
additional_connection_limits = {"second"=32}
additional_message_rate_throttles = {"second"="100/hr"}
        "#,
        ])
        .await;

        // The source and site names are deliberately ones that no entry
        // mentions, so only the domain entry contributes.
        let resolved = shaping
            .get_egress_path_config("example.com", "invalid.source", "invalid.site")
            .await
            .finish()
            .unwrap();

        // Both files' connection limits are present
        k9::snapshot!(
            resolved.params.additional_connection_limits,
            r#"
{
    "first": 10,
    "second": 32,
}
"#
        );
        // The throttle only exists in the second file
        k9::snapshot!(
            resolved.params.additional_message_rate_throttles,
            r#"
{
    "second": 100/h,
}
"#
        );
    }
1605
    // A provider defined purely by several `HostName` rules must be applied
    // to `yahoo.com`.
    // NOTE(review): `HostName` rules are evaluated against the MX hosts of
    // the domain, so this appears to depend on MX resolution for yahoo.com;
    // confirm whether a test resolver is installed.
    #[tokio::test]
    async fn test_provider_multi_hostname() {
        let shaping = make_shaping_configs(&[r#"
[provider."yahoo"]
match=[{HostName="mta5.am0.yahoodns.net"},{HostName="mta6.am0.yahoodns.net"},{HostName="mta7.am0.yahoodns.net"}]
enable_tls = "Required"
        "#])
        .await;

        let resolved = shaping
            .get_egress_path_config("yahoo.com", "invalid.source", "invalid.site")
            .await
            .finish()
            .unwrap();

        // The provider's params were applied, and it was recorded by name
        k9::assert_equal!(
            resolved.params.enable_tls,
            crate::egress_path::Tls::Required
        );
        k9::assert_equal!(resolved.params.provider_name.unwrap(), "yahoo");
    }
1627
    // Same scenario as the multi-HostName test, but expressed with
    // `MXSuffix` rules: each MX host has to end with one of the suffixes.
    // NOTE(review): `MXSuffix` rules are evaluated against the MX hosts of
    // the domain, so this appears to depend on MX resolution for yahoo.com;
    // confirm whether a test resolver is installed.
    #[tokio::test]
    async fn test_provider_multi_suffix() {
        let shaping = make_shaping_configs(&[r#"
[provider."yahoo"]
match=[{MXSuffix="mta5.am0.yahoodns.net"},{MXSuffix="mta6.am0.yahoodns.net"},{MXSuffix="mta7.am0.yahoodns.net"}]
enable_tls = "Required"
        "#])
        .await;

        let resolved = shaping
            .get_egress_path_config("yahoo.com", "invalid.source", "invalid.site")
            .await
            .finish()
            .unwrap();

        // The provider's params were applied, and it was recorded by name
        k9::assert_equal!(
            resolved.params.enable_tls,
            crate::egress_path::Tls::Required
        );
        k9::assert_equal!(resolved.params.provider_name.unwrap(), "yahoo");
    }
1649
    // Exercises the provider-level limit/rate options and the per-source
    // `provider_source_selection_rate` option: each one must surface as an
    // `additional_*` entry named after the provider and the egress source.
    // NOTE(review): "outlook.com" does not end with ".outlook.com", so the
    // match presumably comes from the MXSuffix rule and therefore from MX
    // resolution; confirm whether a test resolver is installed.
    #[tokio::test]
    async fn test_provider() {
        let shaping = make_shaping_configs(&[r#"
[provider."Office 365"]
match=[{MXSuffix=".olc.protection.outlook.com"},{DomainSuffix=".outlook.com"}]
enable_tls = "Required"
provider_connection_limit = 10
provider_max_message_rate = "120/s"

[provider."Office 365".sources."new-source"]
provider_source_selection_rate = "500/d,max_burst=1"
        "#])
        .await;

        // First with a source that the provider block does not mention
        let resolved = shaping
            .get_egress_path_config("outlook.com", "invalid.source", "invalid.site")
            .await
            .finish()
            .unwrap();

        k9::assert_equal!(
            resolved.params.enable_tls,
            crate::egress_path::Tls::Required
        );
        k9::assert_equal!(resolved.params.provider_name.unwrap(), "Office 365");

        k9::snapshot!(
            resolved.params.additional_connection_limits,
            r#"
{
    "shaping-provider-Office 365-invalid.source-limit": 10,
}
"#
        );
        k9::snapshot!(
            resolved.params.additional_message_rate_throttles,
            r#"
{
    "shaping-provider-Office 365-invalid.source-rate": 120/s,
}
"#
        );
        // No selection rate is configured for this source
        assert!(resolved.params.source_selection_rate.is_none());
        assert!(resolved.params.additional_source_selection_rates.is_empty());

        // Then with the source that carries `provider_source_selection_rate`
        let resolved = shaping
            .get_egress_path_config("outlook.com", "new-source", "invalid.site")
            .await
            .finish()
            .unwrap();
        // The option is translated into an additional rate, not into the
        // plain `source_selection_rate`
        assert!(resolved.params.source_selection_rate.is_none());
        k9::snapshot!(
            resolved.params.additional_source_selection_rates,
            r#"
{
    "shaping-provider-Office 365-new-source-selection-rate": 500/d,max_burst=1,
}
"#
        );
    }
1710
1711    #[tokio::test]
1712    async fn test_rule_matching() {
1713        let shaping = make_shaping_configs(&[r#"
1714[["default".automation]]
1715regex="default"
1716action = {SetConfig={name="connection_limit", value=1}}
1717duration = "1hr"
1718
1719["fake.site"]
1720_treat_domain_name_as_site_name = true
1721
1722[["fake.site".automation]]
1723regex="fake_rollup"
1724action = {SetConfig={name="connection_limit", value=2}}
1725duration = "1hr"
1726
1727["woot.provider"]
1728mx_rollup = false
1729
1730[["woot.provider".automation]]
1731regex="woot_domain"
1732action = {SetConfig={name="connection_limit", value=2}}
1733duration = "1hr"
1734
1735[provider."provider"]
1736match=[{DomainSuffix=".provider"}]
1737
1738[[provider."provider".automation]]
1739regex="provider"
1740action = {SetConfig={name="connection_limit", value=3}}
1741duration = "1hr"
1742match_internal = true
1743
1744"#])
1745        .await;
1746
1747        eprintln!("{:?}", shaping.inner.warnings);
1748
1749        fn make_record(content: &str, recipient: &str, site: &str) -> JsonLogRecord {
1750            JsonLogRecord {
1751                kind: RecordType::TransientFailure,
1752                id: String::new(),
1753                sender: String::new(),
1754                recipient: vec![recipient.to_string()],
1755                queue: String::new(),
1756                site: site.to_string(),
1757                size: 0,
1758                response: Response {
1759                    code: 400,
1760                    command: None,
1761                    enhanced_code: None,
1762                    content: content.to_string(),
1763                },
1764                peer_address: None,
1765                timestamp: Default::default(),
1766                created: Default::default(),
1767                num_attempts: 1,
1768                bounce_classification: Default::default(),
1769                egress_pool: None,
1770                egress_source: None,
1771                source_address: None,
1772                feedback_report: None,
1773                meta: Default::default(),
1774                headers: Default::default(),
1775                delivery_protocol: None,
1776                reception_protocol: None,
1777                nodeid: Uuid::default(),
1778                tls_cipher: None,
1779                tls_protocol_version: None,
1780                tls_peer_subject_name: None,
1781                provider_name: None,
1782                session_id: None,
1783            }
1784        }
1785
1786        let matches = shaping
1787            .match_rules(&make_record("default", "user@example.com", "dummy_site"))
1788            .await
1789            .unwrap();
1790        k9::assert_equal!(
1791            matches[0].regex[0].to_string(),
1792            "default",
1793            "matches against default automation rule"
1794        );
1795
1796        let matches = shaping
1797            .match_rules(&make_record(
1798                "KumoMTA internal: default",
1799                "user@example.com",
1800                "dummy_site",
1801            ))
1802            .await
1803            .unwrap();
1804        assert!(matches.is_empty(), "internal bounce should not match");
1805
1806        let matches = shaping
1807            .match_rules(&make_record(
1808                "woot_domain",
1809                "user@woot.provider",
1810                "dummy_site",
1811            ))
1812            .await
1813            .unwrap();
1814        k9::assert_equal!(
1815            matches[0].regex[0].to_string(),
1816            "woot_domain",
1817            "matches against domain rule with mx_rollup=false"
1818        );
1819
1820        let matches = shaping
1821            .match_rules(&make_record("fake_rollup", "user@fake.rollup", "fake.site"))
1822            .await
1823            .unwrap();
1824        k9::assert_equal!(
1825            matches[0].regex[0].to_string(),
1826            "fake_rollup",
1827            "matches against domain rule with mx_rollup=true"
1828        );
1829
1830        let matches = shaping
1831            .match_rules(&make_record("provider", "user@woot.provider", "dummy_site"))
1832            .await
1833            .unwrap();
1834        k9::assert_equal!(
1835            matches[0].regex[0].to_string(),
1836            "provider",
1837            "matches against provider rule"
1838        );
1839
1840        let matches = shaping
1841            .match_rules(&make_record(
1842                "KumoMTA internal: provider",
1843                "user@woot.provider",
1844                "dummy_site",
1845            ))
1846            .await
1847            .unwrap();
1848        k9::assert_equal!(
1849            matches[0].regex[0].to_string(),
1850            "provider",
1851            "internal response matches against provider rule"
1852        );
1853    }
1854
1855    #[tokio::test]
1856    async fn test_defaults() {
1857        let shaping = make_shaping_configs(&[
1858            r#"
1859["default"]
1860connection_limit = 10
1861max_connection_rate = "100/min"
1862max_deliveries_per_connection = 100
1863max_message_rate = "100/s"
1864idle_timeout = "60s"
1865data_timeout = "30s"
1866data_dot_timeout = "60s"
1867enable_tls = "Opportunistic"
1868consecutive_connection_failures_before_delay = 100
1869
1870[["default".automation]]
1871regex=[
1872        '/Messages from \d+\.\d+\.\d+\.\d+ temporarily deferred/',
1873        '/All messages from \d+\.\d+\.\d+\.\d+ will be permanently deferred/',
1874        '/has been temporarily rate limited due to IP reputation/',
1875        '/Unfortunately, messages from \d+\.\d+\.\d+\.\d+ weren.t sent/',
1876        '/Server busy\. Please try again later from/'
1877]
1878action = [
1879        {SetConfig={name="max_message_rate", value="1/minute"}},
1880        {SetConfig={name="connection_limit", value=1}}
1881]
1882duration = "90m"
1883
1884[["default".automation]]
1885regex="KumoMTA internal: failed to connect to any candidate hosts: All failures are related to OpportunisticInsecure STARTTLS. Consider setting enable_tls=Disabled for this site"
1886action = {SetConfig={name="enable_tls", value="Disabled"}}
1887duration = "30 days"
1888
1889["gmail.com"]
1890max_deliveries_per_connection = 50
1891connection_limit = 5
1892enable_tls = "Required"
1893consecutive_connection_failures_before_delay = 5
1894
1895["yahoo.com"]
1896max_deliveries_per_connection = 20
1897
1898[["yahoo.com".automation]]
1899regex = "\\[TS04\\]"
1900action = "Suspend"
1901duration = "2 hours"
1902
1903["comcast.net"]
1904connection_limit = 25
1905max_deliveries_per_connection = 250
1906enable_tls = "Required"
1907idle_timeout = "30s"
1908consecutive_connection_failures_before_delay = 24
1909
1910["mail.com"]
1911max_deliveries_per_connection = 100
1912
1913["orange.fr"]
1914connection_limit = 3
1915
1916["smtp.mailgun.com"]
1917connection_limit = 7000
1918max_deliveries_per_connection = 3
1919
1920["example.com"]
1921mx_rollup = false
1922max_deliveries_per_connection = 100
1923connection_limit = 3
1924
1925["example.com".sources."my source name"]
1926connection_limit = 5
1927        "#,
1928        ])
1929        .await;
1930
1931        let default = shaping
1932            .get_egress_path_config("invalid.domain", "invalid.source", "invalid.site")
1933            .await
1934            .finish()
1935            .unwrap();
1936        k9::snapshot!(
1937            default,
1938            r#"
1939MergedEntry {
1940    params: EgressPathConfig {
1941        connection_limit: 10,
1942        additional_connection_limits: {},
1943        enable_tls: Opportunistic,
1944        enable_mta_sts: true,
1945        enable_dane: false,
1946        enable_pipelining: true,
1947        enable_rset: true,
1948        tls_prefer_openssl: false,
1949        tls_certificate: None,
1950        tls_private_key: None,
1951        openssl_cipher_list: None,
1952        openssl_cipher_suites: None,
1953        openssl_options: None,
1954        rustls_cipher_suites: [],
1955        client_timeouts: SmtpClientTimeouts {
1956            connect_timeout: 60s,
1957            banner_timeout: 60s,
1958            ehlo_timeout: 300s,
1959            mail_from_timeout: 300s,
1960            rcpt_to_timeout: 300s,
1961            data_timeout: 30s,
1962            data_dot_timeout: 60s,
1963            rset_timeout: 5s,
1964            idle_timeout: 60s,
1965            starttls_timeout: 5s,
1966            auth_timeout: 60s,
1967        },
1968        system_shutdown_timeout: None,
1969        max_ready: 1024,
1970        consecutive_connection_failures_before_delay: 100,
1971        smtp_port: 25,
1972        smtp_auth_plain_username: None,
1973        smtp_auth_plain_password: None,
1974        allow_smtp_auth_plain_without_tls: false,
1975        max_message_rate: Some(
1976            100/s,
1977        ),
1978        additional_message_rate_throttles: {},
1979        source_selection_rate: None,
1980        additional_source_selection_rates: {},
1981        max_connection_rate: Some(
1982            100/m,
1983        ),
1984        max_deliveries_per_connection: 100,
1985        max_recipients_per_batch: 100,
1986        prohibited_hosts: {
1987            "0.0.0.0",
1988            "127.0.0.0/8",
1989            "::/127",
1990        },
1991        skip_hosts: {},
1992        ehlo_domain: None,
1993        aggressive_connection_opening: false,
1994        refresh_interval: 60s,
1995        refresh_strategy: Ttl,
1996        dispatcher_wakeup_strategy: Aggressive,
1997        maintainer_wakeup_strategy: Aggressive,
1998        provider_name: None,
1999        remember_broken_tls: None,
2000        opportunistic_tls_reconnect_on_failed_handshake: false,
2001        use_lmtp: false,
2002        reconnect_strategy: ConnectNextHost,
2003        readyq_pool_name: None,
2004        low_memory_reduction_policy: ShrinkDataAndMeta,
2005        no_memory_reduction_policy: ShrinkDataAndMeta,
2006        try_next_host_on_transport_error: false,
2007        ignore_8bit_checks: false,
2008    },
2009    sources: {},
2010    automation: [
2011        Rule {
2012            regex: [
2013                Regex(
2014                    /Messages from \d+\.\d+\.\d+\.\d+ temporarily deferred/,
2015                ),
2016                Regex(
2017                    /All messages from \d+\.\d+\.\d+\.\d+ will be permanently deferred/,
2018                ),
2019                Regex(
2020                    /has been temporarily rate limited due to IP reputation/,
2021                ),
2022                Regex(
2023                    /Unfortunately, messages from \d+\.\d+\.\d+\.\d+ weren.t sent/,
2024                ),
2025                Regex(
2026                    /Server busy\. Please try again later from/,
2027                ),
2028            ],
2029            action: [
2030                SetConfig(
2031                    EgressPathConfigValue {
2032                        name: "max_message_rate",
2033                        value: HashableTomlValue {
2034                            value: String(
2035                                "1/minute",
2036                            ),
2037                        },
2038                    },
2039                ),
2040                SetConfig(
2041                    EgressPathConfigValue {
2042                        name: "connection_limit",
2043                        value: HashableTomlValue {
2044                            value: Integer(
2045                                1,
2046                            ),
2047                        },
2048                    },
2049                ),
2050            ],
2051            trigger: Immediate,
2052            duration: 5400s,
2053            was_rollup: false,
2054            match_internal: false,
2055        },
2056        Rule {
2057            regex: [
2058                Regex(
2059                    KumoMTA internal: failed to connect to any candidate hosts: All failures are related to OpportunisticInsecure STARTTLS. Consider setting enable_tls=Disabled for this site,
2060                ),
2061            ],
2062            action: [
2063                SetConfig(
2064                    EgressPathConfigValue {
2065                        name: "enable_tls",
2066                        value: HashableTomlValue {
2067                            value: String(
2068                                "Disabled",
2069                            ),
2070                        },
2071                    },
2072                ),
2073            ],
2074            trigger: Immediate,
2075            duration: 2592000s,
2076            was_rollup: false,
2077            match_internal: false,
2078        },
2079    ],
2080}
2081"#
2082        );
2083
2084        let example_com = shaping
2085            .get_egress_path_config("example.com", "invalid.source", "invalid.site")
2086            .await
2087            .finish()
2088            .unwrap();
2089        k9::snapshot!(
2090            example_com,
2091            r#"
2092MergedEntry {
2093    params: EgressPathConfig {
2094        connection_limit: 3,
2095        additional_connection_limits: {},
2096        enable_tls: Opportunistic,
2097        enable_mta_sts: true,
2098        enable_dane: false,
2099        enable_pipelining: true,
2100        enable_rset: true,
2101        tls_prefer_openssl: false,
2102        tls_certificate: None,
2103        tls_private_key: None,
2104        openssl_cipher_list: None,
2105        openssl_cipher_suites: None,
2106        openssl_options: None,
2107        rustls_cipher_suites: [],
2108        client_timeouts: SmtpClientTimeouts {
2109            connect_timeout: 60s,
2110            banner_timeout: 60s,
2111            ehlo_timeout: 300s,
2112            mail_from_timeout: 300s,
2113            rcpt_to_timeout: 300s,
2114            data_timeout: 30s,
2115            data_dot_timeout: 60s,
2116            rset_timeout: 5s,
2117            idle_timeout: 60s,
2118            starttls_timeout: 5s,
2119            auth_timeout: 60s,
2120        },
2121        system_shutdown_timeout: None,
2122        max_ready: 1024,
2123        consecutive_connection_failures_before_delay: 100,
2124        smtp_port: 25,
2125        smtp_auth_plain_username: None,
2126        smtp_auth_plain_password: None,
2127        allow_smtp_auth_plain_without_tls: false,
2128        max_message_rate: Some(
2129            100/s,
2130        ),
2131        additional_message_rate_throttles: {},
2132        source_selection_rate: None,
2133        additional_source_selection_rates: {},
2134        max_connection_rate: Some(
2135            100/m,
2136        ),
2137        max_deliveries_per_connection: 100,
2138        max_recipients_per_batch: 100,
2139        prohibited_hosts: {
2140            "0.0.0.0",
2141            "127.0.0.0/8",
2142            "::/127",
2143        },
2144        skip_hosts: {},
2145        ehlo_domain: None,
2146        aggressive_connection_opening: false,
2147        refresh_interval: 60s,
2148        refresh_strategy: Ttl,
2149        dispatcher_wakeup_strategy: Aggressive,
2150        maintainer_wakeup_strategy: Aggressive,
2151        provider_name: None,
2152        remember_broken_tls: None,
2153        opportunistic_tls_reconnect_on_failed_handshake: false,
2154        use_lmtp: false,
2155        reconnect_strategy: ConnectNextHost,
2156        readyq_pool_name: None,
2157        low_memory_reduction_policy: ShrinkDataAndMeta,
2158        no_memory_reduction_policy: ShrinkDataAndMeta,
2159        try_next_host_on_transport_error: false,
2160        ignore_8bit_checks: false,
2161    },
2162    sources: {
2163        "my source name": EgressPathConfig {
2164            connection_limit: 5,
2165            additional_connection_limits: {},
2166            enable_tls: Opportunistic,
2167            enable_mta_sts: true,
2168            enable_dane: false,
2169            enable_pipelining: true,
2170            enable_rset: true,
2171            tls_prefer_openssl: false,
2172            tls_certificate: None,
2173            tls_private_key: None,
2174            openssl_cipher_list: None,
2175            openssl_cipher_suites: None,
2176            openssl_options: None,
2177            rustls_cipher_suites: [],
2178            client_timeouts: SmtpClientTimeouts {
2179                connect_timeout: 60s,
2180                banner_timeout: 60s,
2181                ehlo_timeout: 300s,
2182                mail_from_timeout: 300s,
2183                rcpt_to_timeout: 300s,
2184                data_timeout: 300s,
2185                data_dot_timeout: 300s,
2186                rset_timeout: 5s,
2187                idle_timeout: 5s,
2188                starttls_timeout: 5s,
2189                auth_timeout: 60s,
2190            },
2191            system_shutdown_timeout: None,
2192            max_ready: 1024,
2193            consecutive_connection_failures_before_delay: 100,
2194            smtp_port: 25,
2195            smtp_auth_plain_username: None,
2196            smtp_auth_plain_password: None,
2197            allow_smtp_auth_plain_without_tls: false,
2198            max_message_rate: None,
2199            additional_message_rate_throttles: {},
2200            source_selection_rate: None,
2201            additional_source_selection_rates: {},
2202            max_connection_rate: None,
2203            max_deliveries_per_connection: 1024,
2204            max_recipients_per_batch: 100,
2205            prohibited_hosts: {
2206                "0.0.0.0",
2207                "127.0.0.0/8",
2208                "::/127",
2209            },
2210            skip_hosts: {},
2211            ehlo_domain: None,
2212            aggressive_connection_opening: false,
2213            refresh_interval: 60s,
2214            refresh_strategy: Ttl,
2215            dispatcher_wakeup_strategy: Aggressive,
2216            maintainer_wakeup_strategy: Aggressive,
2217            provider_name: None,
2218            remember_broken_tls: None,
2219            opportunistic_tls_reconnect_on_failed_handshake: false,
2220            use_lmtp: false,
2221            reconnect_strategy: ConnectNextHost,
2222            readyq_pool_name: None,
2223            low_memory_reduction_policy: ShrinkDataAndMeta,
2224            no_memory_reduction_policy: ShrinkDataAndMeta,
2225            try_next_host_on_transport_error: false,
2226            ignore_8bit_checks: false,
2227        },
2228    },
2229    automation: [
2230        Rule {
2231            regex: [
2232                Regex(
2233                    /Messages from \d+\.\d+\.\d+\.\d+ temporarily deferred/,
2234                ),
2235                Regex(
2236                    /All messages from \d+\.\d+\.\d+\.\d+ will be permanently deferred/,
2237                ),
2238                Regex(
2239                    /has been temporarily rate limited due to IP reputation/,
2240                ),
2241                Regex(
2242                    /Unfortunately, messages from \d+\.\d+\.\d+\.\d+ weren.t sent/,
2243                ),
2244                Regex(
2245                    /Server busy\. Please try again later from/,
2246                ),
2247            ],
2248            action: [
2249                SetConfig(
2250                    EgressPathConfigValue {
2251                        name: "max_message_rate",
2252                        value: HashableTomlValue {
2253                            value: String(
2254                                "1/minute",
2255                            ),
2256                        },
2257                    },
2258                ),
2259                SetConfig(
2260                    EgressPathConfigValue {
2261                        name: "connection_limit",
2262                        value: HashableTomlValue {
2263                            value: Integer(
2264                                1,
2265                            ),
2266                        },
2267                    },
2268                ),
2269            ],
2270            trigger: Immediate,
2271            duration: 5400s,
2272            was_rollup: false,
2273            match_internal: false,
2274        },
2275        Rule {
2276            regex: [
2277                Regex(
2278                    KumoMTA internal: failed to connect to any candidate hosts: All failures are related to OpportunisticInsecure STARTTLS. Consider setting enable_tls=Disabled for this site,
2279                ),
2280            ],
2281            action: [
2282                SetConfig(
2283                    EgressPathConfigValue {
2284                        name: "enable_tls",
2285                        value: HashableTomlValue {
2286                            value: String(
2287                                "Disabled",
2288                            ),
2289                        },
2290                    },
2291                ),
2292            ],
2293            trigger: Immediate,
2294            duration: 2592000s,
2295            was_rollup: false,
2296            match_internal: false,
2297        },
2298    ],
2299}
2300"#
2301        );
2302
2303        // The site name here will need to be updated if yahoo changes
2304        // their MX records
2305        let yahoo_com = shaping
2306            .get_egress_path_config(
2307                "yahoo.com",
2308                "invalid.source",
2309                "(mta5|mta6|mta7).am0.yahoodns.net",
2310            )
2311            .await
2312            .finish()
2313            .unwrap();
2314        k9::snapshot!(
2315            yahoo_com,
2316            r#"
2317MergedEntry {
2318    params: EgressPathConfig {
2319        connection_limit: 10,
2320        additional_connection_limits: {},
2321        enable_tls: Opportunistic,
2322        enable_mta_sts: true,
2323        enable_dane: false,
2324        enable_pipelining: true,
2325        enable_rset: true,
2326        tls_prefer_openssl: false,
2327        tls_certificate: None,
2328        tls_private_key: None,
2329        openssl_cipher_list: None,
2330        openssl_cipher_suites: None,
2331        openssl_options: None,
2332        rustls_cipher_suites: [],
2333        client_timeouts: SmtpClientTimeouts {
2334            connect_timeout: 60s,
2335            banner_timeout: 60s,
2336            ehlo_timeout: 300s,
2337            mail_from_timeout: 300s,
2338            rcpt_to_timeout: 300s,
2339            data_timeout: 30s,
2340            data_dot_timeout: 60s,
2341            rset_timeout: 5s,
2342            idle_timeout: 60s,
2343            starttls_timeout: 5s,
2344            auth_timeout: 60s,
2345        },
2346        system_shutdown_timeout: None,
2347        max_ready: 1024,
2348        consecutive_connection_failures_before_delay: 100,
2349        smtp_port: 25,
2350        smtp_auth_plain_username: None,
2351        smtp_auth_plain_password: None,
2352        allow_smtp_auth_plain_without_tls: false,
2353        max_message_rate: Some(
2354            100/s,
2355        ),
2356        additional_message_rate_throttles: {},
2357        source_selection_rate: None,
2358        additional_source_selection_rates: {},
2359        max_connection_rate: Some(
2360            100/m,
2361        ),
2362        max_deliveries_per_connection: 20,
2363        max_recipients_per_batch: 100,
2364        prohibited_hosts: {
2365            "0.0.0.0",
2366            "127.0.0.0/8",
2367            "::/127",
2368        },
2369        skip_hosts: {},
2370        ehlo_domain: None,
2371        aggressive_connection_opening: false,
2372        refresh_interval: 60s,
2373        refresh_strategy: Ttl,
2374        dispatcher_wakeup_strategy: Aggressive,
2375        maintainer_wakeup_strategy: Aggressive,
2376        provider_name: None,
2377        remember_broken_tls: None,
2378        opportunistic_tls_reconnect_on_failed_handshake: false,
2379        use_lmtp: false,
2380        reconnect_strategy: ConnectNextHost,
2381        readyq_pool_name: None,
2382        low_memory_reduction_policy: ShrinkDataAndMeta,
2383        no_memory_reduction_policy: ShrinkDataAndMeta,
2384        try_next_host_on_transport_error: false,
2385        ignore_8bit_checks: false,
2386    },
2387    sources: {},
2388    automation: [
2389        Rule {
2390            regex: [
2391                Regex(
2392                    /Messages from \d+\.\d+\.\d+\.\d+ temporarily deferred/,
2393                ),
2394                Regex(
2395                    /All messages from \d+\.\d+\.\d+\.\d+ will be permanently deferred/,
2396                ),
2397                Regex(
2398                    /has been temporarily rate limited due to IP reputation/,
2399                ),
2400                Regex(
2401                    /Unfortunately, messages from \d+\.\d+\.\d+\.\d+ weren.t sent/,
2402                ),
2403                Regex(
2404                    /Server busy\. Please try again later from/,
2405                ),
2406            ],
2407            action: [
2408                SetConfig(
2409                    EgressPathConfigValue {
2410                        name: "max_message_rate",
2411                        value: HashableTomlValue {
2412                            value: String(
2413                                "1/minute",
2414                            ),
2415                        },
2416                    },
2417                ),
2418                SetConfig(
2419                    EgressPathConfigValue {
2420                        name: "connection_limit",
2421                        value: HashableTomlValue {
2422                            value: Integer(
2423                                1,
2424                            ),
2425                        },
2426                    },
2427                ),
2428            ],
2429            trigger: Immediate,
2430            duration: 5400s,
2431            was_rollup: false,
2432            match_internal: false,
2433        },
2434        Rule {
2435            regex: [
2436                Regex(
2437                    KumoMTA internal: failed to connect to any candidate hosts: All failures are related to OpportunisticInsecure STARTTLS. Consider setting enable_tls=Disabled for this site,
2438                ),
2439            ],
2440            action: [
2441                SetConfig(
2442                    EgressPathConfigValue {
2443                        name: "enable_tls",
2444                        value: HashableTomlValue {
2445                            value: String(
2446                                "Disabled",
2447                            ),
2448                        },
2449                    },
2450                ),
2451            ],
2452            trigger: Immediate,
2453            duration: 2592000s,
2454            was_rollup: false,
2455            match_internal: false,
2456        },
2457        Rule {
2458            regex: [
2459                Regex(
2460                    \[TS04\],
2461                ),
2462            ],
2463            action: [
2464                Suspend,
2465            ],
2466            trigger: Immediate,
2467            duration: 7200s,
2468            was_rollup: false,
2469            match_internal: false,
2470        },
2471    ],
2472}
2473"#
2474        );
2475    }
2476
2477    #[tokio::test]
2478    async fn test_load_default_shaping_toml() {
2479        Shaping::merge_files(
2480            &["../../assets/policy-extras/shaping.toml".into()],
2481            &ShapingMergeOptions::default(),
2482        )
2483        .await
2484        .unwrap();
2485    }
2486}