kumo_api_types/
shaping.rs

1use crate::egress_path::EgressPathConfig;
2#[cfg(feature = "lua")]
3use anyhow::Context;
4#[cfg(feature = "lua")]
5use config::any_err;
6#[cfg(feature = "lua")]
7use config::serialize_options;
8#[cfg(feature = "lua")]
9use dns_resolver::{fully_qualify, MailExchanger};
10#[cfg(feature = "lua")]
11use kumo_log_types::JsonLogRecord;
12#[cfg(feature = "lua")]
13use mlua::prelude::LuaUserData;
14#[cfg(feature = "lua")]
15use mlua::{LuaSerdeExt, UserDataMethods};
16use ordermap::OrderMap;
17use serde::de::{SeqAccess, Visitor};
18use serde::{Deserialize, Deserializer, Serialize};
19use serde_with::formats::PreferOne;
20use serde_with::{serde_as, DeserializeAs, OneOrMany};
21#[cfg(feature = "lua")]
22use sha2::{Digest, Sha256};
23#[cfg(feature = "lua")]
24use std::collections::BTreeMap;
25use std::hash::{Hash, Hasher};
26use std::marker::PhantomData;
27#[cfg(feature = "lua")]
28use std::sync::Arc;
29use std::time::Duration;
30#[cfg(feature = "lua")]
31use throttle::LimitSpec;
32use throttle::ThrottleSpec;
33
/// A compiled `fancy_regex::Regex` that (de)serializes as its pattern string.
///
/// Deserialization accepts a string and compiles it through the
/// `TryFrom<String>` impl, so an invalid pattern surfaces as a
/// deserialization error. Serialization emits the pattern text through
/// `From<Regex> for String`.
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(try_from = "String", into = "String")]
pub struct Regex(fancy_regex::Regex);
37
38impl TryFrom<String> for Regex {
39    type Error = fancy_regex::Error;
40
41    fn try_from(s: String) -> fancy_regex::Result<Self> {
42        Ok(Self(fancy_regex::Regex::new(&s)?))
43    }
44}
45
46impl From<Regex> for String {
47    fn from(r: Regex) -> String {
48        r.0.as_str().to_string()
49    }
50}
51
52impl std::ops::Deref for Regex {
53    type Target = fancy_regex::Regex;
54    fn deref(&self) -> &Self::Target {
55        &self.0
56    }
57}
58
59impl std::hash::Hash for Regex {
60    fn hash<H: std::hash::Hasher>(&self, hasher: &mut H) {
61        self.0.as_str().hash(hasher)
62    }
63}
64
/// `toml::Value` is not `Hash` because it may contain floating
/// point numbers, which are problematic from an `Ord` and `Eq`
/// perspective. We're okay with skirting around that for
/// our purposes here, so we implement our own hashable
/// wrapper around the toml value.
///
/// (De)serializes exactly like the wrapped `toml::Value`.
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(from = "toml::Value", into = "toml::Value")]
pub struct HashableTomlValue {
    value: toml::Value,
}
75
76impl From<toml::Value> for HashableTomlValue {
77    fn from(value: toml::Value) -> Self {
78        Self { value }
79    }
80}
81
82impl From<HashableTomlValue> for toml::Value {
83    fn from(value: HashableTomlValue) -> toml::Value {
84        value.value
85    }
86}
87
88impl std::ops::Deref for HashableTomlValue {
89    type Target = toml::Value;
90    fn deref(&self) -> &toml::Value {
91        &self.value
92    }
93}
94
95fn hash_toml<H>(value: &toml::Value, h: &mut H)
96where
97    H: Hasher,
98{
99    match value {
100        toml::Value::Boolean(v) => v.hash(h),
101        toml::Value::Datetime(v) => {
102            if let Some(d) = &v.date {
103                d.year.hash(h);
104                d.month.hash(h);
105                d.day.hash(h);
106            }
107            if let Some(t) = &v.time {
108                t.hour.hash(h);
109                t.minute.hash(h);
110                t.second.hash(h);
111                t.nanosecond.hash(h);
112            }
113            if let Some(toml::value::Offset::Custom { minutes }) = &v.offset {
114                minutes.hash(h);
115            }
116        }
117        toml::Value::String(v) => v.hash(h),
118        toml::Value::Integer(v) => v.hash(h),
119        toml::Value::Float(v) => v.to_ne_bytes().hash(h),
120        toml::Value::Array(a) => {
121            for v in a.iter() {
122                hash_toml(v, h);
123            }
124        }
125        toml::Value::Table(m) => {
126            for (k, v) in m.iter() {
127                k.hash(h);
128                hash_toml(v, h);
129            }
130        }
131    }
132}
133
134impl Hash for HashableTomlValue {
135    fn hash<H>(&self, h: &mut H)
136    where
137        H: Hasher,
138    {
139        hash_toml(&self.value, h);
140    }
141}
142
/// Represents an individual EgressPathConfig field name and value.
/// It only allows deserializing from valid EgressPathConfig field + values.
///
/// serde is directed to deserialize an `EgressPathConfigValueUnchecked` and
/// then convert it via `TryFrom`, which performs the validation.
/// Serialization goes back through the unchecked form.
#[derive(Deserialize, Serialize, Debug, Clone, Hash)]
#[serde(
    try_from = "EgressPathConfigValueUnchecked",
    into = "EgressPathConfigValueUnchecked"
)]
pub struct EgressPathConfigValue {
    /// The `EgressPathConfig` field name.
    pub name: String,
    /// The value for that field, wrapped so that it is `Hash`.
    pub value: HashableTomlValue,
}
154
/// This is the type that we actually use to deserialize EgressPathConfigValue items.
/// It doesn't care about validity; it is used solely to tell serde what shape of
/// data to expect.
/// The validation is performed by the TryFrom impl that is used to convert to the
/// checked `EgressPathConfigValue` form.
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct EgressPathConfigValueUnchecked {
    /// Candidate `EgressPathConfig` field name (not yet validated).
    pub name: String,
    /// Candidate value for that field (not yet validated).
    pub value: toml::Value,
}
165
166impl TryFrom<EgressPathConfigValueUnchecked> for EgressPathConfigValue {
167    type Error = anyhow::Error;
168    fn try_from(config: EgressPathConfigValueUnchecked) -> anyhow::Result<EgressPathConfigValue> {
169        let mut map = toml::map::Map::new();
170        map.insert(config.name.clone(), config.value.clone());
171        let table = toml::Value::Table(map);
172
173        // Attempt to deserialize as EgressPathConfig.
174        // If it fails, then the field name/value are invalid
175        EgressPathConfig::deserialize(table)?;
176
177        // If we reach this point, we can pass along the name/value
178        Ok(EgressPathConfigValue {
179            name: config.name,
180            value: HashableTomlValue {
181                value: config.value,
182            },
183        })
184    }
185}
186
187impl From<EgressPathConfigValue> for EgressPathConfigValueUnchecked {
188    fn from(config: EgressPathConfigValue) -> EgressPathConfigValueUnchecked {
189        EgressPathConfigValueUnchecked {
190            name: config.name,
191            value: config.value.value,
192        }
193    }
194}
195
/// An action requested by a matching automation [`Rule`].
///
/// This type only carries data; the code that carries out each action is
/// not in this file. `SetConfig` and `SetDomainConfig` each carry a single
/// `EgressPathConfig` field/value pair that was validated when it was
/// deserialized. NOTE(review): the other variants appear to suspend or
/// bounce at path, tenant or campaign scope — confirm against the
/// automation implementation.
///
/// The derived `Hash` depends on variant declaration order, so reordering
/// variants changes any hash computed over rules.
#[derive(Deserialize, Serialize, Debug, Clone, Hash)]
pub enum Action {
    Suspend,
    SetConfig(EgressPathConfigValue),
    SuspendTenant,
    SuspendCampaign,
    SetDomainConfig(EgressPathConfigValue),
    Bounce,
    BounceTenant,
    BounceCampaign,
}
207
/// Controls when a matching [`Rule`] fires its actions.
/// Defaults to [`Trigger::Immediate`].
#[derive(Deserialize, Serialize, Debug, Clone, Hash, Default)]
pub enum Trigger {
    /// Trigger on the first match, immediately
    #[default]
    Immediate,
    /// Trigger when a certain number of matches occur
    /// over a certain time period.
    Threshold(ThrottleSpec),
}
217
/// A shaping automation rule.
///
/// A rule matches a response when any of its `regex` patterns matches (see
/// [`Rule::matches`]). What happens next is described by `action`,
/// `trigger` and `duration`, and is carried out outside this file.
#[serde_as]
#[derive(Deserialize, Serialize, Debug, Hash, Clone)]
pub struct Rule {
    /// Patterns tested against the response; any single match counts.
    /// NOTE(review): `regex_string_or_array` presumably accepts either one
    /// string or an array of strings — confirm at its definition.
    #[serde(deserialize_with = "regex_string_or_array")]
    pub regex: Vec<Regex>,

    /// Actions requested when the rule fires.
    /// NOTE(review): `one_or_many_action` presumably accepts a single action
    /// or a list — confirm at its definition.
    #[serde(deserialize_with = "one_or_many_action")]
    pub action: Vec<Action>,

    /// When the rule fires; defaults to [`Trigger::Immediate`].
    #[serde(default)]
    pub trigger: Trigger,

    /// (De)serialized via `duration_serde`.
    /// NOTE(review): presumably how long the actions remain in effect —
    /// confirm with the automation consumer.
    #[serde(with = "duration_serde")]
    pub duration: Duration,

    /// Set by [`Rule::clone_and_set_rollup`]; never (de)serialized.
    #[serde(skip)]
    pub was_rollup: bool,

    /// if true, this rule can match kumomta internally generated
    /// Response messages, otherwise, the rule will skip testing
    /// against those.
    #[serde(default)]
    pub match_internal: bool,
}
242
243impl Rule {
244    pub fn matches(&self, is_internal: bool, response: &str) -> bool {
245        if is_internal && !self.match_internal {
246            return false;
247        }
248        self.regex
249            .iter()
250            .any(|r| r.is_match(response).unwrap_or(false))
251    }
252
253    pub fn clone_and_set_rollup(&self) -> Self {
254        let mut result = self.clone();
255        result.was_rollup = true;
256        result
257    }
258}
259
/// Merged shaping data built by `Shaping::merge_files`, together with the
/// diagnostics collected while building it.
#[cfg(feature = "lua")]
#[derive(Debug, Default)]
struct ShapingInner {
    /// Entries for `mx_rollup` domains, keyed by their MX site name.
    by_site: OrderMap<String, PartialEntry>,
    /// Non-rollup domain entries (including `"default"`), keyed by domain.
    by_domain: OrderMap<String, PartialEntry>,
    /// Provider entries, keyed by provider name.
    by_provider: OrderMap<String, ProviderEntry>,
    /// Diagnostics reported at `CheckLevel::Warn`.
    warnings: Vec<String>,
    /// Diagnostics reported at `CheckLevel::Error`.
    errors: Vec<String>,
    /// Hex-encoded SHA-256 over the merged entries and diagnostics.
    hash: String,
}
270
#[cfg(feature = "lua")]
impl ShapingInner {
    /// Computes the effective configuration for the egress path identified
    /// by `domain`, `egress_source` and `site_name`.
    ///
    /// Layers are merged in increasing order of precedence:
    ///
    /// 1. the `"default"` domain entry;
    /// 2. only if `domain` parses as a DNS name: for every provider whose
    ///    `domain_matches(domain)`, its params followed by
    ///    `apply_provider_params_to(egress_source, ..)`;
    /// 3. the `egress_source` section of those providers (again followed by
    ///    `apply_provider_params_to`);
    /// 4. the site entry for `site_name`;
    /// 5. the domain entry for `domain`;
    /// 6. the `egress_source` section of the site entry;
    /// 7. the `egress_source` section of the domain entry.
    pub async fn get_egress_path_config(
        &self,
        domain: &str,
        egress_source: &str,
        site_name: &str,
    ) -> PartialEntry {
        let mut params = PartialEntry::default();

        // Apply basic/default configuration
        if let Some(default) = self.by_domain.get("default") {
            params.merge_from(default.clone());
        }

        // Provider rules come next.
        // They can only match valid domain names, so we'll
        // skip processing them if we have something like our
        // TSA http "domain name" here
        let is_domain_name = dns_resolver::Name::from_str_relaxed(domain).is_ok();
        if is_domain_name {
            let mut prov_with_sources = vec![];

            for prov in self.by_provider.values() {
                if prov.domain_matches(domain).await {
                    toml_table_merge_from(&mut params.params, &prov.params);
                    prov.apply_provider_params_to(egress_source, &mut params.params);

                    if !prov.sources.is_empty() {
                        // Remember this matching provider, so that we
                        // can apply any source rules after we've applied
                        // any/all base provider rules for this domain
                        prov_with_sources.push(prov);
                    }
                }
            }

            // Then Provider source rules
            for prov in prov_with_sources {
                if let Some(source) = prov.sources.get(egress_source) {
                    toml_table_merge_from(&mut params.params, source);
                    prov.apply_provider_params_to(egress_source, &mut params.params);
                }
            }
        }

        // Each of these is consulted twice below; look them up once.
        let site_entry = self.by_site.get(site_name);
        let domain_entry = self.by_domain.get(domain);

        // Then site config
        if let Some(by_site) = site_entry {
            params.merge_from(by_site.clone());
        }

        // Then domain config
        if let Some(by_domain) = domain_entry {
            params.merge_from(by_domain.clone());
        }

        // Then source config for the site
        if let Some(source) = site_entry.and_then(|entry| entry.sources.get(egress_source)) {
            toml_table_merge_from(&mut params.params, source);
        }

        // Then source config for the domain
        if let Some(source) = domain_entry.and_then(|entry| entry.sources.get(egress_source)) {
            toml_table_merge_from(&mut params.params, source);
        }

        params
    }

    /// Returns the automation rules whose patterns match `record`'s response.
    ///
    /// The recipient domain and the MX site name are derived from `record`.
    /// A postmaster recipient never matches anything.
    ///
    /// # Errors
    ///
    /// Fails if `record.recipient` cannot be parsed as a forward path.
    pub async fn match_rules(&self, record: &JsonLogRecord) -> anyhow::Result<Vec<Rule>> {
        use rfc5321::ForwardPath;
        // Extract the domain from the recipient.
        let recipient = ForwardPath::try_from(record.recipient.as_str())
            .map_err(|err| anyhow::anyhow!("parsing record.recipient: {err}"))?;

        let recipient = match recipient {
            ForwardPath::Postmaster => {
                // It doesn't make sense to apply automation on the
                // local postmaster address, so we ignore this.
                return Ok(vec![]);
            }
            ForwardPath::Path(path) => path.mailbox,
        };
        let domain = recipient.domain.to_string();

        // Track events/outcomes by site.
        let source = record.egress_source.as_deref().unwrap_or("unspecified");
        // record.site is poorly named; it is really an identifier for the
        // egress path. For matching purposes, we want just the site_name
        // in the form produced by our MX resolution process.
        // In an earlier incarnation of this logic, we would resolve the
        // site_name for ourselves based on other data in the record,
        // but that could lead to over-resolution of some names and
        // yield surprising results.
        // What we do here is extract the egress path decoration from
        // record.site to arrive at something that looks like the
        // mx site_name.
        // NOTE: this is coupled with the logic in
        // ReadyQueueManager::compute_queue_name
        // The decoration is a single `{source}->` prefix and a single
        // `@smtp_client` suffix, so each is removed at most once
        // (`trim_*_matches` would keep stripping repeated occurrences).
        let egress_prefix = format!("{source}->");
        let site: &str = &record.site;
        let site = site.strip_prefix(egress_prefix.as_str()).unwrap_or(site);
        let site_name = site.strip_suffix("@smtp_client").unwrap_or(site);

        Ok(self.match_rules_impl(record, &domain, site_name).await)
    }

    /// Returns the rules whose patterns match `record.response`, in this
    /// order:
    ///
    /// 1. rules from the `"default"` domain entry;
    /// 2. rules from every provider whose `domain_matches(domain)`;
    /// 3. rules from the `site_name` entry;
    /// 4. rules from the `domain` entry.
    ///
    /// Rules from `"default"` and from the site entry are returned with
    /// `was_rollup` set. A response whose content starts with
    /// `"KumoMTA internal: "` is treated as internal and only matches rules
    /// that set `match_internal`.
    pub async fn match_rules_impl(
        &self,
        record: &JsonLogRecord,
        domain: &str,
        site_name: &str,
    ) -> Vec<Rule> {
        let mut result = vec![];
        let response = record.response.to_single_line();
        tracing::trace!("Consider rules for {response}");

        let is_internal = record.response.content.starts_with("KumoMTA internal: ");

        if let Some(default) = self.by_domain.get("default") {
            for rule in &default.automation {
                tracing::trace!("Consider \"default\" rule {rule:?} for {response}");
                if rule.matches(is_internal, &response) {
                    // For automation under `default`, we always
                    // assume that mx_rollup should be true.
                    // If you somehow have a domain where that isn't
                    // true, you should avoid using `default` for
                    // automation.  Honestly, it's best to avoid
                    // using `default` for automation.
                    result.push(rule.clone_and_set_rollup());
                }
            }
        }

        for prov in self.by_provider.values() {
            if prov.domain_matches(domain).await {
                for rule in &prov.automation {
                    tracing::trace!(
                        "Consider provider \"{}\" rule {rule:?} for {response}",
                        prov.provider_name
                    );
                    if rule.matches(is_internal, &response) {
                        result.push(rule.clone());
                    }
                }
            }
        }

        // Then site config
        if let Some(by_site) = self.by_site.get(site_name) {
            for rule in &by_site.automation {
                tracing::trace!("Consider \"{site_name}\" rule {rule:?} for {response}");
                if rule.matches(is_internal, &response) {
                    result.push(rule.clone_and_set_rollup());
                }
            }
        }

        // Then domain config
        if let Some(by_domain) = self.by_domain.get(domain) {
            for rule in &by_domain.automation {
                tracing::trace!("Consider \"{domain}\" rule {rule:?} for {response}");
                if rule.matches(is_internal, &response) {
                    result.push(rule.clone());
                }
            }
        }

        result
    }
}
447
/// Handle to merged shaping data.
///
/// Cloning is cheap: clones share the same `ShapingInner` through an `Arc`.
#[cfg(feature = "lua")]
#[derive(Debug, Default, Clone, mlua::FromLua)]
pub struct Shaping {
    inner: Arc<ShapingInner>,
}
453
/// Deserializes a `T` from the JSON text `json`.
///
/// On failure, the error includes the path to the offending field (via
/// `serde_path_to_error`). The entire input must be consumed: anything
/// other than trailing whitespace after the value is rejected, matching
/// `serde_json::from_str`.
#[cfg(feature = "lua")]
fn from_json<'a, T: Deserialize<'a>>(json: &'a str) -> anyhow::Result<T> {
    let mut d = serde_json::Deserializer::from_str(json);
    let value = serde_path_to_error::deserialize(&mut d)?;
    // Without this check, a document like `{...} trailing junk` would be
    // silently accepted with the junk ignored.
    d.end()?;
    Ok(value)
}
459
/// Deserializes a `T` from the TOML text `toml`.
///
/// On failure, the error includes the path to the offending field (via
/// `serde_path_to_error`).
#[cfg(feature = "lua")]
fn from_toml<'a, T: Deserialize<'a>>(toml: &'a str) -> anyhow::Result<T> {
    serde_path_to_error::deserialize(toml::Deserializer::new(toml)).map_err(anyhow::Error::from)
}
465
/// Severity at which a shaping-merge diagnostic is reported.
///
/// `FromStr` parses the variant names case-insensitively. The derived
/// `Deserialize` matches the exact variant names. The default is `Ignore`.
#[cfg(feature = "lua")]
#[derive(Default, Debug, Clone, Copy, Deserialize, PartialEq, Eq)]
pub enum CheckLevel {
    /// Drop the diagnostic.
    #[default]
    Ignore,
    /// Record the diagnostic as a warning.
    Warn,
    /// Record the diagnostic as an error.
    Error,
}
#[cfg(feature = "lua")]
impl std::str::FromStr for CheckLevel {
    type Err = String;

    /// Parses `Ignore`, `Warn` or `Error`, ignoring ASCII case.
    fn from_str(s: &str) -> Result<Self, String> {
        match s.to_ascii_lowercase().as_str() {
            "ignore" => Ok(CheckLevel::Ignore),
            "warn" => Ok(CheckLevel::Warn),
            "error" => Ok(CheckLevel::Error),
            _ => Err(format!(
                "Expected one of `Ignore`, `Warn` or `Error`, got `{s}`"
            )),
        }
    }
}
492
/// Accumulates merge diagnostics, routed by [`CheckLevel`].
#[cfg(feature = "lua")]
#[derive(Default)]
struct Collector {
    /// Messages pushed at `CheckLevel::Warn`.
    warnings: Vec<String>,
    /// Messages pushed at `CheckLevel::Error`.
    errors: Vec<String>,
}
499
#[cfg(feature = "lua")]
impl Collector {
    /// Records `msg` as a warning or an error according to `level`.
    /// At `CheckLevel::Ignore` the message is dropped without converting it.
    fn push<S: Into<String>>(&mut self, level: CheckLevel, msg: S) {
        let sink = match level {
            CheckLevel::Ignore => return,
            CheckLevel::Warn => &mut self.warnings,
            CheckLevel::Error => &mut self.errors,
        };
        sink.push(msg.into());
    }
}
510
/// Options controlling how `Shaping::merge_files` loads and validates
/// shaping sources.
///
/// Any field omitted from the deserialized input takes its value from
/// `ShapingMergeOptions::default()`. Previously each field used its
/// field-type default. That meant a partially specified options table
/// silently got `local_load = Ignore` (`CheckLevel::default()`) instead of
/// the `Error` that `ShapingMergeOptions::default()` specifies.
#[cfg(feature = "lua")]
#[serde_as]
#[derive(Debug, Clone, Deserialize)]
#[serde(default, deny_unknown_fields)]
pub struct ShapingMergeOptions {
    /// Severity for an `mx_rollup` domain that is also matched by a provider.
    pub provider_overlap: CheckLevel,
    /// Severity for a failure to resolve the MX of an `mx_rollup` domain.
    pub dns_fail: CheckLevel,
    /// Severity for an `mx_rollup` domain whose MX is a NULL MX.
    pub null_mx: CheckLevel,
    /// Severity for multiple domain blocks that roll up to the same site.
    pub aliased_site: CheckLevel,
    /// When true, `http://` and `https://` sources are skipped with a warning.
    pub skip_remote: bool,
    /// Severity for failing to fetch a remote source, or to parse one whose
    /// path ends in neither `.toml` nor `.json`.
    pub remote_load: CheckLevel,
    /// Severity for failing to read a local source, or to parse one whose
    /// path ends in neither `.toml` nor `.json`.
    pub local_load: CheckLevel,
    /// Timeout for fetching remote sources; `None` uses the loader's
    /// built-in default (5 seconds).
    #[serde(with = "duration_serde")]
    pub http_timeout: Option<Duration>,
}
533
#[cfg(feature = "lua")]
impl Default for ShapingMergeOptions {
    /// Every check is `Ignore` except local load failures, which are
    /// `Error`. Remote sources are fetched, and `http_timeout` is left to
    /// the loader's built-in default.
    fn default() -> Self {
        let ignore = CheckLevel::Ignore;
        Self {
            local_load: CheckLevel::Error,
            remote_load: ignore,
            provider_overlap: ignore,
            dns_fail: ignore,
            null_mx: ignore,
            aliased_site: ignore,
            skip_remote: false,
            http_timeout: None,
        }
    }
}
549
550#[cfg(feature = "lua")]
551impl Shaping {
552    async fn load_from_file(
553        path: &str,
554        options: &ShapingMergeOptions,
555        collector: &mut Collector,
556    ) -> anyhow::Result<ShapingFile> {
557        let (data, level): (String, CheckLevel) =
558            if path.starts_with("http://") || path.starts_with("https://") {
559                if options.skip_remote {
560                    collector.push(
561                        CheckLevel::Warn,
562                        format!("Ignoring {path} because skip_remote is set to true"),
563                    );
564                    return Ok(ShapingFile::default());
565                }
566
567                // To facilitate startup ordering races, and listing multiple subscription
568                // host replicas and allowing one or more of them to be temporarily down,
569                // we allow the http request to fail.
570                // We'll log the error message but consider it to be an empty map
571
572                async fn http_get(url: &str, timeout: Duration) -> anyhow::Result<String> {
573                    tokio::time::timeout(timeout, async {
574                        reqwest::Client::builder()
575                            .timeout(timeout)
576                            .connect_timeout(timeout)
577                            .read_timeout(timeout)
578                            .build()?
579                            .get(url)
580                            .send()
581                            .await
582                            .with_context(|| format!("making HTTP request to {url}"))?
583                            .text()
584                            .await
585                            .with_context(|| format!("reading text from {url}"))
586                    })
587                    .await
588                    .with_context(|| format!("timeout making HTTP request to {url}"))?
589                }
590
591                let timeout = options.http_timeout.unwrap_or(Duration::from_secs(5));
592
593                match http_get(path, timeout).await {
594                    Ok(s) => (s, options.remote_load),
595                    Err(err) => {
596                        tracing::error!("{err:#}. Ignoring this shaping source for now");
597                        collector.push(
598                            options.remote_load,
599                            format!("remote shaping source {path} error: {err:#}"),
600                        );
601                        return Ok(ShapingFile::default());
602                    }
603                }
604            } else {
605                match std::fs::read_to_string(path)
606                    .with_context(|| format!("loading data from file {path}"))
607                {
608                    Err(err) => {
609                        collector.push(
610                            options.local_load,
611                            format!("local shaping source {path} error: {err:#}"),
612                        );
613                        return Ok(ShapingFile::default());
614                    }
615                    Ok(s) => (s, options.local_load),
616                }
617            };
618
619        if path.ends_with(".toml") {
620            from_toml(&data).with_context(|| format!("parsing toml from file {path}"))
621        } else if path.ends_with(".json") {
622            from_json(&data).with_context(|| format!("parsing json from file {path}"))
623        } else {
624            // Try parsing both ways and see which wins
625            let mut errors = vec![];
626            match from_toml(&data) {
627                Ok(s) => return Ok(s),
628                Err(err) => errors.push(format!("as toml: {err:#}")),
629            }
630            match from_json(&data) {
631                Ok(s) => return Ok(s),
632                Err(err) => errors.push(format!("as json: {err:#}")),
633            }
634
635            collector.push(level, format!("parsing {path}: {}", errors.join(", ")));
636            Ok(ShapingFile::default())
637        }
638    }
639
640    pub async fn merge_files(
641        files: &[String],
642        options: &ShapingMergeOptions,
643    ) -> anyhow::Result<Self> {
644        use futures_util::stream::FuturesUnordered;
645        use futures_util::StreamExt;
646
647        let mut collector = Collector::default();
648        let mut loaded = vec![];
649        for p in files {
650            loaded.push(Self::load_from_file(p, options, &mut collector).await?);
651        }
652
653        let mut by_site: OrderMap<String, PartialEntry> = OrderMap::new();
654        let mut by_domain: OrderMap<String, PartialEntry> = OrderMap::new();
655        let mut by_provider: OrderMap<String, ProviderEntry> = OrderMap::new();
656        let mut site_aliases: OrderMap<String, Vec<String>> = OrderMap::new();
657
658        // Pre-resolve domains. We don't interleave the resolution with
659        // the work below, because we want to ensure that the ordering
660        // is preserved
661        let mut mx = std::collections::HashMap::new();
662        let mut lookups = FuturesUnordered::new();
663        for item in &loaded {
664            for (domain, partial) in &item.domains {
665                if partial.mx_rollup {
666                    let domain = domain.to_string();
667                    lookups.push(tokio::spawn(async move {
668                        let mx_result = MailExchanger::resolve(&domain).await;
669                        (domain, mx_result)
670                    }));
671                }
672            }
673        }
674
675        while let Some(Ok((domain, result))) = lookups.next().await {
676            mx.insert(domain, result);
677        }
678
679        for mut item in loaded {
680            if let Some(mut partial) = item.default.take() {
681                let domain = "default";
682                partial.domain_name.replace(domain.to_string());
683                match by_domain.get_mut(domain) {
684                    Some(existing) => {
685                        existing.merge_from(partial);
686                    }
687                    None => {
688                        by_domain.insert(domain.to_string(), partial);
689                    }
690                }
691            }
692
693            for (domain, mut partial) in item.domains {
694                partial.domain_name.replace(domain.clone());
695
696                if let Ok(name) = fully_qualify(&domain) {
697                    if name.num_labels() == 1 {
698                        collector.push(
699                            CheckLevel::Warn,
700                            format!(
701                                "Entry for domain '{domain}' consists of a \
702                                 single DNS label. Domain names in TOML sections \
703                                 need to be quoted like '[\"{domain}.com\"]` otherwise \
704                                 the '.' will create a nested table rather than being \
705                                 added to the domain name."
706                            ),
707                        );
708                    }
709                }
710
711                #[cfg(test)]
712                if partial._treat_domain_name_as_site_name {
713                    match by_site.get_mut(&domain) {
714                        Some(existing) => {
715                            existing.merge_from(partial);
716                        }
717                        None => {
718                            by_site.insert(domain.to_string(), partial);
719                        }
720                    }
721                    continue;
722                }
723                if partial.mx_rollup {
724                    let mx = match mx.get(&domain) {
725                        Some(Ok(mx)) => mx,
726                        Some(Err(err)) => {
727                            collector.push(
728                                options.dns_fail,
729                                format!(
730                                    "error resolving MX for {domain}: {err:#}. \
731                                 Ignoring the shaping config for that domain."
732                                ),
733                            );
734                            continue;
735                        }
736                        None => {
737                            collector.push(
738                                options.dns_fail,
739                                format!(
740                                "We didn't try to resolve the MX for {domain} for some reason!?. \
741                                 Ignoring the shaping config for that domain."
742                            ),
743                            );
744                            continue;
745                        }
746                    };
747
748                    if mx.site_name.is_empty() {
749                        collector.push(
750                            options.null_mx,
751                            format!(
752                            "domain {domain} has a NULL MX and cannot be used with mx_rollup=true. \
753                             Ignoring the shaping config for that domain."),
754                        );
755                        continue;
756                    }
757
758                    site_aliases
759                        .entry(mx.site_name.to_string())
760                        .or_default()
761                        .push(domain.to_string());
762
763                    match by_site.get_mut(&mx.site_name) {
764                        Some(existing) => {
765                            existing.merge_from(partial);
766                        }
767                        None => {
768                            by_site.insert(mx.site_name.clone(), partial);
769                        }
770                    }
771                } else {
772                    match by_domain.get_mut(&domain) {
773                        Some(existing) => {
774                            existing.merge_from(partial);
775                        }
776                        None => {
777                            by_domain.insert(domain, partial);
778                        }
779                    }
780                }
781            }
782
783            for (provider, mut prov) in item.provider {
784                prov.provider_name = provider.to_string();
785                match by_provider.get_mut(&provider) {
786                    Some(existing) => {
787                        existing.merge_from(prov);
788                    }
789                    None => {
790                        by_provider.insert(provider.to_string(), prov);
791                    }
792                }
793            }
794        }
795
796        for (site, partial) in &by_site {
797            partial
798                .clone()
799                .finish()
800                .with_context(|| format!("site: {site}"))?;
801        }
802
803        for (domain, partial) in &by_domain {
804            partial
805                .clone()
806                .finish()
807                .with_context(|| format!("domain: {domain}"))?;
808        }
809
810        for (provider, prov) in &by_provider {
811            prov.finish_params()
812                .with_context(|| format!("provider: {provider}"))?;
813        }
814
815        if options.aliased_site != CheckLevel::Ignore {
816            for (site, aliases) in site_aliases {
817                if aliases.len() > 1 {
818                    collector.push(
819                        options.aliased_site,
820                        format!(
821                            "multiple domain blocks alias to the same site: {site}: {}",
822                            aliases.join(", ")
823                        ),
824                    );
825                }
826            }
827        }
828
829        if options.provider_overlap != CheckLevel::Ignore {
830            for domain in mx.keys() {
831                let mut matching_providers = vec![];
832                for (prov_name, prov) in &by_provider {
833                    if prov.domain_matches(domain).await {
834                        matching_providers.push(prov_name.to_string());
835                    }
836                }
837                if !matching_providers.is_empty() {
838                    collector.push(
839                        options.provider_overlap,
840                        format!(
841                            "domain {domain} is also matched by provider(s): {}",
842                            matching_providers.join(", ")
843                        ),
844                    );
845                }
846            }
847        }
848
849        let mut ctx = Sha256::new();
850        ctx.update("by_site");
851        for (site, entry) in &by_site {
852            ctx.update(site);
853            entry.hash_into(&mut ctx);
854        }
855        ctx.update("by_domain");
856        for (domain, entry) in &by_domain {
857            ctx.update(domain);
858            entry.hash_into(&mut ctx);
859        }
860        ctx.update("by_provider");
861        for (provider, prov) in &by_provider {
862            ctx.update(provider);
863            prov.hash_into(&mut ctx);
864        }
865        ctx.update("warnings");
866        for warn in &collector.warnings {
867            ctx.update(warn);
868        }
869        ctx.update("errors");
870        for err in &collector.errors {
871            ctx.update(err);
872        }
873        let hash = ctx.finalize();
874        let hash = data_encoding::HEXLOWER.encode(&hash);
875
876        Ok(Self {
877            inner: Arc::new(ShapingInner {
878                by_site,
879                by_domain,
880                by_provider,
881                warnings: collector.warnings,
882                errors: collector.errors,
883                hash,
884            }),
885        })
886    }
887
888    async fn get_egress_path_config(
889        &self,
890        domain: &str,
891        egress_source: &str,
892        site_name: &str,
893    ) -> PartialEntry {
894        self.inner
895            .get_egress_path_config(domain, egress_source, site_name)
896            .await
897    }
898
899    pub async fn get_egress_path_config_value(
900        &self,
901        domain: &str,
902        egress_source: &str,
903        site_name: &str,
904    ) -> anyhow::Result<serde_json::Value> {
905        let partial = self
906            .get_egress_path_config(domain, egress_source, site_name)
907            .await;
908        Ok(serde_json::to_value(&partial)?)
909    }
910
911    pub fn get_errors(&self) -> &[String] {
912        &self.inner.errors
913    }
914
915    pub fn get_warnings(&self) -> &[String] {
916        &self.inner.warnings
917    }
918
919    pub async fn match_rules(&self, record: &JsonLogRecord) -> anyhow::Result<Vec<Rule>> {
920        self.inner.match_rules(record).await
921    }
922
923    pub fn get_referenced_sources(&self) -> BTreeMap<String, Vec<String>> {
924        let mut result = BTreeMap::new();
925
926        for (site_name, site) in &self.inner.by_site {
927            for source_name in site.sources.keys() {
928                result
929                    .entry(source_name.to_string())
930                    .or_insert(vec![])
931                    .push(format!("site:{site_name}"));
932            }
933        }
934        for (domain_name, domain) in &self.inner.by_domain {
935            for source_name in domain.sources.keys() {
936                result
937                    .entry(source_name.to_string())
938                    .or_insert(vec![])
939                    .push(format!("domain:{domain_name}"));
940            }
941        }
942
943        result
944    }
945
946    pub fn hash(&self) -> String {
947        self.inner.hash.clone()
948    }
949}
950
/// Lua bindings for `Shaping`: memoization support plus the
/// `get_egress_path_config`, `get_errors`, `get_warnings`,
/// `get_referenced_sources`, `match_rules` and `hash` methods.
#[cfg(feature = "lua")]
impl LuaUserData for Shaping {
    fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
        mod_memoize::Memoized::impl_memoize(methods);

        // Only the `params` portion of the partial entry is handed to Lua,
        // serialized with the shared config serialization options.
        methods.add_async_method(
            "get_egress_path_config",
            |lua, this, (domain, egress_source, site_name): (String, String, String)| async move {
                let partial = this
                    .get_egress_path_config(&domain, &egress_source, &site_name)
                    .await;
                lua.to_value_with(&partial.params, serialize_options())
            },
        );

        // The borrowed slices are copied into owned vectors for Lua.
        methods.add_method("get_errors", |_lua, this, ()| {
            Ok(this.get_errors().to_vec())
        });
        methods.add_method("get_warnings", |_lua, this, ()| {
            Ok(this.get_warnings().to_vec())
        });

        methods.add_method("get_referenced_sources", |_lua, this, ()| {
            Ok(this.get_referenced_sources())
        });

        // Each matching rule becomes a Lua value; the first conversion
        // failure aborts the whole call.
        methods.add_async_method("match_rules", |lua, this, record: mlua::Value| async move {
            let record: JsonLogRecord = lua.from_value(record)?;
            let rules = this.match_rules(&record).await.map_err(any_err)?;
            rules
                .iter()
                .map(|rule| lua.to_value(rule))
                .collect::<mlua::Result<Vec<_>>>()
        });

        methods.add_method("hash", |_lua, this, ()| Ok(this.hash()));
    }
}
992
/// Fully-typed shaping configuration, as produced by `PartialEntry::finish`
/// or `ProviderEntry::finish_params` from their raw TOML tables.
#[derive(Default, Debug)]
pub struct MergedEntry {
    /// Base egress path parameters.
    pub params: EgressPathConfig,
    /// Per-egress-source parameters, keyed by source name.
    pub sources: OrderMap<String, EgressPathConfig>,
    /// Automation rules attached to this entry.
    pub automation: Vec<Rule>,
}
999
/// Top-level layout of a single shaping TOML file.
#[cfg(feature = "lua")]
#[derive(Deserialize, Serialize, Debug, Clone, Default)]
struct ShapingFile {
    /// Entry declared under the `default` key.
    /// NOTE(review): presumably applied to every destination; confirm in merge_files.
    pub default: Option<PartialEntry>,
    /// Every other top-level table (flattened), e.g. `["example.com"]`,
    /// keyed by its table name.
    #[serde(flatten, default)]
    pub domains: OrderMap<String, PartialEntry>,
    /// Provider definitions from `[provider."<name>"]` tables.
    #[serde(default)]
    pub provider: OrderMap<String, ProviderEntry>,
}
1009
/// A not-yet-typed shaping entry for a domain (or the default), as read from
/// a shaping file and merged across files before `finish` converts it.
#[cfg(feature = "lua")]
#[derive(Deserialize, Serialize, Debug, Clone, Default)]
struct PartialEntry {
    /// Name this entry was declared under. It is not read from the entry body
    /// (`serde(skip)`), so presumably the loader fills it in (TODO confirm).
    /// It is used as a label in `finish` errors and is fed into `hash_into`.
    #[serde(skip)]
    pub domain_name: Option<String>,

    /// All keys not claimed by another field. `finish` interprets them as an
    /// `EgressPathConfig`.
    #[serde(flatten)]
    pub params: toml::Table,

    /// Defaults to `true` when omitted.
    /// NOTE(review): presumably selects keying by the MX-derived site name
    /// instead of the domain; confirm against merge_files.
    #[serde(default = "default_true")]
    pub mx_rollup: bool,

    // This is present to facilitate unit testing without requiring
    // DNS to resolve the site_name. When set to true, the domain_name
    // is considered to be the site_name for this entry.
    #[cfg(test)]
    #[serde(default)]
    pub _treat_domain_name_as_site_name: bool,

    /// When true, `merge_from` replaces the accumulated params, automation
    /// and sources with this entry's instead of merging into them.
    #[serde(default)]
    pub replace_base: bool,

    /// Automation rules. `merge_from` appends them unless `replace_base` is set.
    #[serde(default)]
    pub automation: Vec<Rule>,

    /// Per-egress-source parameter tables keyed by source name. `finish`
    /// interprets each one as an `EgressPathConfig`.
    #[serde(default)]
    pub sources: OrderMap<String, toml::Table>,
}
1038
/// A provider definition (`[provider."<name>"]`): match rules deciding which
/// domains belong to the provider, plus params applied to those domains.
#[cfg(feature = "lua")]
#[derive(Deserialize, Serialize, Debug, Clone, Default)]
pub struct ProviderEntry {
    /// Provider name. It is not read from the entry body (`serde(skip)`), so
    /// presumably the loader sets it from the table key (TODO confirm).
    #[serde(skip, default)]
    pub provider_name: String,

    /// Optional connection limit. `apply_provider_params_to` emits it into
    /// `additional_connection_limits` under a per-provider, per-source key.
    #[serde(default)]
    pub provider_connection_limit: Option<LimitSpec>,

    /// Optional message rate. `apply_provider_params_to` emits it into
    /// `additional_message_rate_throttles` under a per-provider, per-source key.
    #[serde(default)]
    pub provider_max_message_rate: Option<ThrottleSpec>,

    /// Rules (TOML key `match`) deciding which domains belong to this
    /// provider; see `domain_matches`.
    #[serde(default, rename = "match")]
    pub matches: Vec<ProviderMatch>,

    /// When true, `merge_from` replaces the existing entry's fields rather
    /// than merging into them.
    #[serde(default)]
    pub replace_base: bool,

    /// All remaining keys. `finish_params` interprets them as an `EgressPathConfig`.
    #[serde(flatten)]
    pub params: toml::Table,

    /// Automation rules attached to this provider.
    #[serde(default)]
    pub automation: Vec<Rule>,

    /// Per-egress-source parameter tables keyed by source name. The
    /// `provider_source_selection_rate` key is handled specially; see
    /// `finish_params` and `apply_provider_params_to`.
    #[serde(default)]
    pub sources: OrderMap<String, toml::Table>,
}
1066
1067#[cfg(feature = "lua")]
/// Returns true if `candidate` ends with `suffix`. A single trailing `.` on
/// `candidate` is ignored, because our resolver tends to leave the canonical
/// root dot on host names.
///
/// The comparison is byte-exact, so case mismatches are NOT handled. We
/// assume that the dns-resolver crate normalized the names to lowercase and
/// that the input suffix is lowercase as well.
fn suffix_matches(candidate: &str, suffix: &str) -> bool {
    let trimmed = match candidate.strip_suffix('.') {
        Some(without_root_dot) => without_root_dot,
        None => candidate,
    };
    trimmed.ends_with(suffix)
}
1077
#[cfg(feature = "lua")]
#[cfg(test)]
#[test]
fn test_suffix_matches() {
    // (candidate, suffix, expected result)
    let cases = [
        ("a", "a", true),
        ("foo.com", "foo.com", true),
        ("foo.com", ".foo.com", false),
        ("foo.com", "longer.com", false),
        ("réputation.net", ".mx.microsoft", false),
        // We DON'T handle case mismatches in this code.
        // We assume that the dns-resolver crate normalized the
        // names to lowercase and that the input suffix is lowercase as well.
        ("foo.Com", ".com", false),
        ("foo.Cam", ".com", false),
    ];

    for (candidate, suffix, expected) in cases {
        assert_eq!(
            suffix_matches(candidate, suffix),
            expected,
            "suffix_matches({candidate:?}, {suffix:?})"
        );
    }
}
1094
1095#[cfg(feature = "lua")]
/// Returns true if `candidate` is exactly `name`. A single trailing `.` on
/// `candidate` is ignored, because our resolver tends to leave the canonical
/// root dot on host names.
fn host_matches(candidate: &str, name: &str) -> bool {
    match candidate.strip_suffix('.') {
        Some(without_root_dot) => without_root_dot == name,
        None => candidate == name,
    }
}
1102
#[cfg(feature = "lua")]
#[cfg(test)]
#[test]
fn test_host_matches() {
    // Exact names match, with or without the resolver's trailing dot.
    for candidate in ["foo.com", "foo.com."] {
        assert!(host_matches(candidate, "foo.com"), "{candidate:?} should match");
    }
    // Anything other than an exact match is rejected.
    assert!(!host_matches("foo.com", "notfoo.com"));
}
1111
#[cfg(feature = "lua")]
impl ProviderEntry {
    /// Returns true if `domain` is considered to be handled by this provider.
    ///
    /// `DomainSuffix` rules are tried first because they need no DNS, and any
    /// hit returns immediately. Otherwise, if at least one `HostName` or
    /// `MXSuffix` rule exists, the MX for `domain` is resolved. The domain
    /// then matches only if every MX host is matched by at least one of
    /// those rules. A resolution failure, or an MX with no hosts, counts as
    /// no match.
    async fn domain_matches(&self, domain: &str) -> bool {
        // We'd like to avoid doing DNS if we can do a simple suffix match,
        // so we bias to looking at those first
        let mut need_mx = false;

        tracing::trace!(
            "ProviderEntry::domain_matches({domain}) vs {:?}",
            self.matches
        );

        for rule in &self.matches {
            match rule {
                ProviderMatch::DomainSuffix(suffix) => {
                    if suffix_matches(domain, suffix) {
                        tracing::trace!("{domain} suffix matches {suffix}");
                        return true;
                    }
                }
                ProviderMatch::HostName(_) | ProviderMatch::MXSuffix(_) => {
                    need_mx = true;
                }
            }
        }

        if !need_mx {
            return false;
        }

        // Now we can consider DNS
        match MailExchanger::resolve(domain).await {
            Err(err) => {
                // Didn't resolve; could be legit, for example, could
                // be an internal or fake name that is handled via
                // smart host or custom routing, so we log what
                // happened as a trace rather than polluting the
                // logs about it
                tracing::trace!(
                    "Error resolving MX for {domain}: {err:#}. \
                    Provider {} match rules will be ignored",
                    self.provider_name
                );
                false
            }
            Ok(mx) => {
                // The loop below only rejects on a non-matching host. An empty
                // host list would therefore vacuously match every MX-based
                // provider, so treat it as no match instead.
                if mx.hosts.is_empty() {
                    tracing::trace!("{domain} has no MX hosts; no match");
                    return false;
                }

                tracing::trace!("Consider MXSuffix rules");
                // Every MX host must be matched by at least one MXSuffix or
                // HostName rule. Requiring all hosts to match keeps us from
                // falsely lumping in a vanity domain that blends providers
                // together.
                for host in &mx.hosts {
                    let mut matched = false;

                    for rule in &self.matches {
                        match rule {
                            ProviderMatch::MXSuffix(suffix) => {
                                tracing::trace!("suffix={suffix} vs host {host}");
                                if suffix_matches(host, suffix) {
                                    matched = true;
                                    break;
                                }
                            }
                            ProviderMatch::HostName(name) => {
                                if host_matches(host, name) {
                                    matched = true;
                                    break;
                                }
                            }
                            ProviderMatch::DomainSuffix(_) => {}
                        }
                    }

                    if !matched {
                        tracing::trace!("host didn't match any of these rules");
                        return false;
                    }
                }

                true
            }
        }
    }

    /// Merges `other` into `self`, with `other` taking precedence.
    ///
    /// If `other.replace_base` is set, `other`'s connection limit, message
    /// rate, match rules, params, sources and automation replace ours
    /// outright. Otherwise:
    /// - `other`'s connection limit and message rate override ours only when set.
    /// - Params and per-source tables are combined with `toml_table_merge_from`.
    /// - Match rules and automation rules are appended after ours.
    fn merge_from(&mut self, mut other: Self) {
        if other.replace_base {
            self.provider_connection_limit = other.provider_connection_limit;
            self.provider_max_message_rate = other.provider_max_message_rate;
            self.matches = other.matches;
            self.params = other.params;
            self.sources = other.sources;
            self.automation = other.automation;
        } else {
            if other.provider_connection_limit.is_some() {
                self.provider_connection_limit = other.provider_connection_limit;
            }
            // Without this, a rate set in a later file would be silently dropped.
            if other.provider_max_message_rate.is_some() {
                self.provider_max_message_rate = other.provider_max_message_rate;
            }

            toml_table_merge_from(&mut self.params, &other.params);

            for (source, tbl) in other.sources {
                match self.sources.get_mut(&source) {
                    Some(existing) => {
                        toml_table_merge_from(existing, &tbl);
                    }
                    None => {
                        self.sources.insert(source, tbl);
                    }
                }
            }

            self.matches.append(&mut other.matches);
            self.automation.append(&mut other.automation);
        }
    }

    /// Merges provider-derived settings into `target`, the params table for
    /// egress source `source`.
    ///
    /// It always sets `provider_name`. It also adds named entries to:
    /// - `additional_connection_limits`, when a connection limit is configured.
    /// - `additional_message_rate_throttles`, when a message rate is configured.
    /// - `additional_source_selection_rates`, by removing any
    ///   `provider_source_selection_rate` key from `target`.
    ///
    /// The generated entry names depend only on the provider name and
    /// `source`.
    fn apply_provider_params_to(&self, source: &str, target: &mut toml::Table) {
        let mut implied = toml::Table::new();
        implied.insert(
            "provider_name".to_string(),
            toml::Value::String(self.provider_name.to_string()),
        );

        if let Some(limit) = &self.provider_connection_limit {
            let mut limits = toml::Table::new();
            limits.insert(
                format!("shaping-provider-{}-{source}-limit", self.provider_name),
                toml::Value::String(limit.to_string()),
            );
            implied.insert(
                "additional_connection_limits".to_string(),
                toml::Value::Table(limits),
            );
        }
        if let Some(rate) = &self.provider_max_message_rate {
            let rate = rate.as_string();
            let mut limits = toml::Table::new();
            limits.insert(
                format!("shaping-provider-{}-{source}-rate", self.provider_name),
                rate.into(),
            );
            implied.insert(
                "additional_message_rate_throttles".to_string(),
                toml::Value::Table(limits),
            );
        }

        if let Some(rate) = target.remove("provider_source_selection_rate") {
            let mut limits = toml::Table::new();
            limits.insert(
                format!(
                    "shaping-provider-{}-{source}-selection-rate",
                    self.provider_name
                ),
                rate,
            );
            implied.insert(
                "additional_source_selection_rates".to_string(),
                toml::Value::Table(limits),
            );
        }

        toml_table_merge_from(target, &implied);
    }

    /// Converts this provider's raw params and per-source tables into typed
    /// `EgressPathConfig` values.
    ///
    /// # Errors
    /// Fails if the params or any source table cannot be interpreted as an
    /// `EgressPathConfig`. The error context names the provider and source.
    fn finish_params(&self) -> anyhow::Result<MergedEntry> {
        let provider_name = &self.provider_name;

        let params = EgressPathConfig::deserialize(self.params.clone()).with_context(|| {
            format!(
                "interpreting provider '{provider_name}' params {:#?} as EgressPathConfig",
                self.params
            )
        })?;
        let mut sources = OrderMap::new();

        for (source, params) in &self.sources {
            let mut params = params.clone();
            // I don't really like this remove call. The issue is that we don't
            // have an alternative way to filter this out of the partial source
            // definition, and this provider_ option is not valid in the
            // EgressPathConfig struct itself; it is a shaping source-specific
            // addition to help with setting up shared throttles across providers
            params.remove("provider_source_selection_rate");
            sources.insert(
                source.clone(),
                EgressPathConfig::deserialize(params.clone()).with_context(|| {
                    format!("interpreting provider '{provider_name}' source '{source}' {params:#} as EgressPathConfig")
                })?,
            );
        }

        Ok(MergedEntry {
            params,
            sources,
            automation: self.automation.clone(),
        })
    }

    /// Feeds this entry into `ctx`. The provider name is hashed separately
    /// because it is `serde(skip)` and so absent from the JSON form. If
    /// serialization fails, an empty string is hashed in its place.
    fn hash_into(&self, ctx: &mut Sha256) {
        ctx.update(&self.provider_name);
        ctx.update(serde_json::to_string(self).unwrap_or_else(|_| String::new()));
    }
}
1314
/// One entry of a provider's `match` list; see `ProviderEntry::domain_matches`.
#[cfg(feature = "lua")]
#[derive(Deserialize, Serialize, Debug, Clone)]
pub enum ProviderMatch {
    /// Matches an MX host ending with this suffix. Every MX host of the
    /// domain must be matched by some MX-based rule of the provider.
    MXSuffix(String),
    /// Matches when the destination domain itself ends with this suffix.
    /// No DNS lookup is needed.
    DomainSuffix(String),
    /// Matches an MX host equal to this name. A trailing `.` on the host is
    /// ignored.
    HostName(String),
}
1322
/// Copies every key of `source` into `tbl`, overwriting existing values.
///
/// The exception is the throttle-related keys accepted by `is_mergeable`.
/// When both sides hold a table for one of those keys, the inner entries of
/// `source` are inserted into the existing table in `tbl`, so the two tables
/// are unioned with `source` winning per inner key.
#[cfg(feature = "lua")]
fn toml_table_merge_from(tbl: &mut toml::Table, source: &toml::Table) {
    // Limit merging to just the throttle related fields, as their purpose
    // is for creating broader scoped limits that cut across normal boundaries
    fn is_mergeable(name: &str) -> bool {
        matches!(
            name,
            "additional_connection_limits"
                | "additional_message_rate_throttles"
                | "additional_source_selection_rates"
        )
    }

    for (key, value) in source {
        if is_mergeable(key) {
            if let (Some(toml::Value::Table(existing)), Some(incoming)) =
                (tbl.get_mut(key), value.as_table())
            {
                for (inner_key, inner_value) in incoming {
                    existing.insert(inner_key.clone(), inner_value.clone());
                }
                continue;
            }
        }
        tbl.insert(key.clone(), value.clone());
    }
}
1350
#[cfg(feature = "lua")]
impl PartialEntry {
    /// Folds `other` into `self`, with `other` taking precedence.
    ///
    /// If `other.replace_base` is set, `other`'s params, automation rules and
    /// sources replace ours outright. Otherwise:
    /// - Params and per-source tables are combined with `toml_table_merge_from`.
    /// - `other`'s automation rules are appended after ours.
    fn merge_from(&mut self, mut other: Self) {
        if !other.replace_base {
            toml_table_merge_from(&mut self.params, &other.params);

            for (source, incoming) in other.sources {
                if let Some(existing) = self.sources.get_mut(&source) {
                    toml_table_merge_from(existing, &incoming);
                } else {
                    self.sources.insert(source, incoming);
                }
            }

            self.automation.append(&mut other.automation);
            return;
        }

        self.params = other.params;
        self.automation = other.automation;
        self.sources = other.sources;
    }

    /// Converts the raw params and per-source tables into typed
    /// `EgressPathConfig` values. Sources keep their original order.
    ///
    /// # Errors
    /// Fails at the first table that cannot be interpreted as an
    /// `EgressPathConfig`. The error context names the domain and source.
    fn finish(self) -> anyhow::Result<MergedEntry> {
        let Self {
            domain_name,
            params: raw_params,
            automation,
            sources: raw_sources,
            ..
        } = self;
        let domain = domain_name.unwrap_or_default();

        let params = EgressPathConfig::deserialize(raw_params.clone()).with_context(|| {
            format!(
                "interpreting domain '{domain}' params {:#?} as EgressPathConfig",
                raw_params
            )
        })?;

        let sources = raw_sources
            .into_iter()
            .map(|(source, source_params)| {
                let config = EgressPathConfig::deserialize(source_params.clone()).with_context(|| {
                    format!("interpreting domain '{domain}' source '{source}' {source_params:#} as EgressPathConfig")
                })?;
                Ok((source, config))
            })
            .collect::<anyhow::Result<OrderMap<String, EgressPathConfig>>>()?;

        Ok(MergedEntry {
            params,
            sources,
            automation,
        })
    }

    /// Feeds the domain name (when known) and the JSON form of this entry
    /// into `ctx`. If serialization fails, an empty string is hashed instead.
    fn hash_into(&self, ctx: &mut Sha256) {
        if let Some(name) = &self.domain_name {
            ctx.update(name);
        }
        let serialized = serde_json::to_string(self).unwrap_or_default();
        ctx.update(serialized);
    }
}
1410
1411fn one_or_many<'de, T, D>(deserializer: D, expecting: &str) -> Result<Vec<T>, D::Error>
1412where
1413    T: Deserialize<'de>,
1414    D: Deserializer<'de>,
1415{
1416    let result: Result<Vec<T>, _> =
1417        OneOrMany::<serde_with::Same, PreferOne>::deserialize_as(deserializer);
1418    match result {
1419        Ok(r) => Ok(r),
1420        Err(err) => Err(serde::de::Error::custom(format!(
1421            "{expecting}.\nThe underlying error message is:\n{err:#}"
1422        ))),
1423    }
1424}
1425
1426fn one_or_many_action<'de, D>(deserializer: D) -> Result<Vec<Action>, D::Error>
1427where
1428    D: Deserializer<'de>,
1429{
1430    one_or_many(
1431        deserializer,
1432        "\"action\" field expected either a single Action or an array of Actions",
1433    )
1434}
1435
1436fn regex_string_or_array<'de, D>(deserializer: D) -> Result<Vec<Regex>, D::Error>
1437where
1438    D: Deserializer<'de>,
1439{
1440    string_or_array(
1441        deserializer,
1442        "regex string or array of regex strings for field regex",
1443    )
1444}
1445
1446fn string_or_array<'de, T, D>(deserializer: D, expecting: &'static str) -> Result<Vec<T>, D::Error>
1447where
1448    T: Deserialize<'de> + TryFrom<String>,
1449    <T as TryFrom<String>>::Error: std::fmt::Debug,
1450    D: Deserializer<'de>,
1451{
1452    // This is a Visitor that forwards string types to T's `TryFrom<String>` impl and
1453    // forwards map types to T's `Deserialize` impl. The `PhantomData` is to
1454    // keep the compiler from complaining about T being an unused generic type
1455    // parameter. We need T in order to know the Value type for the Visitor
1456    // impl.
1457    struct StringOrArray<T>(PhantomData<fn() -> T>, &'static str);
1458
1459    impl<'de, T> Visitor<'de> for StringOrArray<T>
1460    where
1461        T: Deserialize<'de> + TryFrom<String>,
1462        <T as TryFrom<String>>::Error: std::fmt::Debug,
1463    {
1464        type Value = Vec<T>;
1465
1466        fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
1467            formatter.write_str(self.1)
1468        }
1469
1470        fn visit_str<E>(self, value: &str) -> Result<Vec<T>, E>
1471        where
1472            E: serde::de::Error,
1473        {
1474            Ok(vec![
1475                T::try_from(value.to_string()).map_err(|e| E::custom(format!("{e:?}")))?
1476            ])
1477        }
1478
1479        fn visit_seq<S>(self, seq: S) -> Result<Vec<T>, S::Error>
1480        where
1481            S: SeqAccess<'de>,
1482        {
1483            Deserialize::deserialize(serde::de::value::SeqAccessDeserializer::new(seq))
1484        }
1485    }
1486
1487    deserializer.deserialize_any(StringOrArray(PhantomData, expecting))
1488}
1489
1490#[cfg(feature = "lua")]
/// Serde `default` provider for boolean fields that should be `true` when
/// the key is omitted, such as `PartialEntry::mx_rollup`.
fn default_true() -> bool { true }
1494
/// Installs the `shaping` Lua sub-module, which exposes
/// `shaping.load(paths, options?)`.
///
/// `load` deserializes the optional `options` value (falling back to the
/// default when it is absent) and returns the `Shaping` produced by
/// `Shaping::merge_files`.
#[cfg(feature = "lua")]
pub fn register(lua: &mlua::Lua) -> anyhow::Result<()> {
    let shaping_mod = config::get_or_create_sub_module(lua, "shaping")?;

    let load = lua.create_async_function(
        move |lua, (paths, options): (Vec<String>, Option<mlua::Value>)| async move {
            let options = if let Some(value) = options {
                lua.from_value(value)?
            } else {
                Default::default()
            };
            let merged = Shaping::merge_files(&paths, &options)
                .await
                .map_err(any_err)?;
            Ok(merged)
        },
    )?;
    shaping_mod.set("load", load)?;

    Ok(())
}
1517
1518#[cfg(test)]
1519mod test {
1520    use super::*;
1521    use kumo_log_types::RecordType;
1522    use rfc5321::Response;
1523    use std::io::Write;
1524    use tempfile::NamedTempFile;
1525    use uuid::Uuid;
1526
1527    async fn make_shaping_configs(inputs: &[&str]) -> Shaping {
1528        let mut files = vec![];
1529        let mut file_names = vec![];
1530
1531        for (i, content) in inputs.iter().enumerate() {
1532            let mut shaping_file = NamedTempFile::with_prefix(format!("file{i}")).unwrap();
1533            shaping_file.write_all(content.as_bytes()).unwrap();
1534            file_names.push(shaping_file.path().to_str().unwrap().to_string());
1535            files.push(shaping_file);
1536        }
1537
1538        Shaping::merge_files(&file_names, &ShapingMergeOptions::default())
1539            .await
1540            .unwrap()
1541    }
1542
1543    #[tokio::test]
1544    async fn test_merge_additional() {
1545        let shaping = make_shaping_configs(&[
1546            r#"
1547["example.com"]
1548mx_rollup = false
1549additional_connection_limits = {"first"=10}
1550        "#,
1551            r#"
1552["example.com"]
1553mx_rollup = false
1554additional_connection_limits = {"second"=32}
1555additional_message_rate_throttles = {"second"="100/hr"}
1556        "#,
1557        ])
1558        .await;
1559
1560        let resolved = shaping
1561            .get_egress_path_config("example.com", "invalid.source", "invalid.site")
1562            .await
1563            .finish()
1564            .unwrap();
1565
1566        k9::snapshot!(
1567            resolved.params.additional_connection_limits,
1568            r#"
1569{
1570    "first": 10,
1571    "second": 32,
1572}
1573"#
1574        );
1575        k9::snapshot!(
1576            resolved.params.additional_message_rate_throttles,
1577            r#"
1578{
1579    "second": 100/h,
1580}
1581"#
1582        );
1583    }
1584
1585    #[tokio::test]
1586    async fn test_provider_multi_hostname() {
1587        let shaping = make_shaping_configs(&[r#"
1588[provider."yahoo"]
1589match=[{HostName="mta5.am0.yahoodns.net"},{HostName="mta6.am0.yahoodns.net"},{HostName="mta7.am0.yahoodns.net"}]
1590enable_tls = "Required"
1591        "#])
1592        .await;
1593
1594        let resolved = shaping
1595            .get_egress_path_config("yahoo.com", "invalid.source", "invalid.site")
1596            .await
1597            .finish()
1598            .unwrap();
1599
1600        k9::assert_equal!(
1601            resolved.params.enable_tls,
1602            crate::egress_path::Tls::Required
1603        );
1604        k9::assert_equal!(resolved.params.provider_name.unwrap(), "yahoo");
1605    }
1606
1607    #[tokio::test]
1608    async fn test_provider_multi_suffix() {
1609        let shaping = make_shaping_configs(&[r#"
1610[provider."yahoo"]
1611match=[{MXSuffix="mta5.am0.yahoodns.net"},{MXSuffix="mta6.am0.yahoodns.net"},{MXSuffix="mta7.am0.yahoodns.net"}]
1612enable_tls = "Required"
1613        "#])
1614        .await;
1615
1616        let resolved = shaping
1617            .get_egress_path_config("yahoo.com", "invalid.source", "invalid.site")
1618            .await
1619            .finish()
1620            .unwrap();
1621
1622        k9::assert_equal!(
1623            resolved.params.enable_tls,
1624            crate::egress_path::Tls::Required
1625        );
1626        k9::assert_equal!(resolved.params.provider_name.unwrap(), "yahoo");
1627    }
1628
1629    #[tokio::test]
1630    async fn test_provider() {
1631        let shaping = make_shaping_configs(&[r#"
1632[provider."Office 365"]
1633match=[{MXSuffix=".olc.protection.outlook.com"},{DomainSuffix=".outlook.com"}]
1634enable_tls = "Required"
1635provider_connection_limit = 10
1636provider_max_message_rate = "120/s"
1637
1638[provider."Office 365".sources."new-source"]
1639provider_source_selection_rate = "500/d,max_burst=1"
1640        "#])
1641        .await;
1642
1643        let resolved = shaping
1644            .get_egress_path_config("outlook.com", "invalid.source", "invalid.site")
1645            .await
1646            .finish()
1647            .unwrap();
1648
1649        k9::assert_equal!(
1650            resolved.params.enable_tls,
1651            crate::egress_path::Tls::Required
1652        );
1653        k9::assert_equal!(resolved.params.provider_name.unwrap(), "Office 365");
1654
1655        k9::snapshot!(
1656            resolved.params.additional_connection_limits,
1657            r#"
1658{
1659    "shaping-provider-Office 365-invalid.source-limit": 10,
1660}
1661"#
1662        );
1663        k9::snapshot!(
1664            resolved.params.additional_message_rate_throttles,
1665            r#"
1666{
1667    "shaping-provider-Office 365-invalid.source-rate": 120/s,
1668}
1669"#
1670        );
1671        assert!(resolved.params.source_selection_rate.is_none());
1672        assert!(resolved.params.additional_source_selection_rates.is_empty());
1673
1674        let resolved = shaping
1675            .get_egress_path_config("outlook.com", "new-source", "invalid.site")
1676            .await
1677            .finish()
1678            .unwrap();
1679        assert!(resolved.params.source_selection_rate.is_none());
1680        k9::snapshot!(
1681            resolved.params.additional_source_selection_rates,
1682            r#"
1683{
1684    "shaping-provider-Office 365-new-source-selection-rate": 500/d,max_burst=1,
1685}
1686"#
1687        );
1688    }
1689
1690    #[tokio::test]
1691    async fn test_rule_matching() {
1692        let shaping = make_shaping_configs(&[r#"
1693[["default".automation]]
1694regex="default"
1695action = {SetConfig={name="connection_limit", value=1}}
1696duration = "1hr"
1697
1698["fake.site"]
1699_treat_domain_name_as_site_name = true
1700
1701[["fake.site".automation]]
1702regex="fake_rollup"
1703action = {SetConfig={name="connection_limit", value=2}}
1704duration = "1hr"
1705
1706["woot.provider"]
1707mx_rollup = false
1708
1709[["woot.provider".automation]]
1710regex="woot_domain"
1711action = {SetConfig={name="connection_limit", value=2}}
1712duration = "1hr"
1713
1714[provider."provider"]
1715match=[{DomainSuffix=".provider"}]
1716
1717[[provider."provider".automation]]
1718regex="provider"
1719action = {SetConfig={name="connection_limit", value=3}}
1720duration = "1hr"
1721match_internal = true
1722
1723"#])
1724        .await;
1725
1726        eprintln!("{:?}", shaping.inner.warnings);
1727
1728        fn make_record(content: &str, recipient: &str, site: &str) -> JsonLogRecord {
1729            JsonLogRecord {
1730                kind: RecordType::TransientFailure,
1731                id: String::new(),
1732                sender: String::new(),
1733                recipient: recipient.to_string(),
1734                queue: String::new(),
1735                site: site.to_string(),
1736                size: 0,
1737                response: Response {
1738                    code: 400,
1739                    command: None,
1740                    enhanced_code: None,
1741                    content: content.to_string(),
1742                },
1743                peer_address: None,
1744                timestamp: Default::default(),
1745                created: Default::default(),
1746                num_attempts: 1,
1747                bounce_classification: Default::default(),
1748                egress_pool: None,
1749                egress_source: None,
1750                source_address: None,
1751                feedback_report: None,
1752                meta: Default::default(),
1753                headers: Default::default(),
1754                delivery_protocol: None,
1755                reception_protocol: None,
1756                nodeid: Uuid::default(),
1757                tls_cipher: None,
1758                tls_protocol_version: None,
1759                tls_peer_subject_name: None,
1760                provider_name: None,
1761                session_id: None,
1762            }
1763        }
1764
1765        let matches = shaping
1766            .match_rules(&make_record("default", "user@example.com", "dummy_site"))
1767            .await
1768            .unwrap();
1769        k9::assert_equal!(
1770            matches[0].regex[0].to_string(),
1771            "default",
1772            "matches against default automation rule"
1773        );
1774
1775        let matches = shaping
1776            .match_rules(&make_record(
1777                "KumoMTA internal: default",
1778                "user@example.com",
1779                "dummy_site",
1780            ))
1781            .await
1782            .unwrap();
1783        assert!(matches.is_empty(), "internal bounce should not match");
1784
1785        let matches = shaping
1786            .match_rules(&make_record(
1787                "woot_domain",
1788                "user@woot.provider",
1789                "dummy_site",
1790            ))
1791            .await
1792            .unwrap();
1793        k9::assert_equal!(
1794            matches[0].regex[0].to_string(),
1795            "woot_domain",
1796            "matches against domain rule with mx_rollup=false"
1797        );
1798
1799        let matches = shaping
1800            .match_rules(&make_record("fake_rollup", "user@fake.rollup", "fake.site"))
1801            .await
1802            .unwrap();
1803        k9::assert_equal!(
1804            matches[0].regex[0].to_string(),
1805            "fake_rollup",
1806            "matches against domain rule with mx_rollup=true"
1807        );
1808
1809        let matches = shaping
1810            .match_rules(&make_record("provider", "user@woot.provider", "dummy_site"))
1811            .await
1812            .unwrap();
1813        k9::assert_equal!(
1814            matches[0].regex[0].to_string(),
1815            "provider",
1816            "matches against provider rule"
1817        );
1818
1819        let matches = shaping
1820            .match_rules(&make_record(
1821                "KumoMTA internal: provider",
1822                "user@woot.provider",
1823                "dummy_site",
1824            ))
1825            .await
1826            .unwrap();
1827        k9::assert_equal!(
1828            matches[0].regex[0].to_string(),
1829            "provider",
1830            "internal response matches against provider rule"
1831        );
1832    }
1833
    /// Snapshot-tests how a single shaping config resolves for three lookups:
    /// an unknown domain (only the "default" section applies), a domain-keyed
    /// entry (`example.com`, which sets `mx_rollup = false` and has a
    /// per-source override), and a site-keyed entry (`yahoo.com`) whose own
    /// automation rule is appended after the "default" automation rules.
    #[tokio::test]
    async fn test_defaults() {
        // Fixture config: a "default" section with two automation rules, plus
        // several domain entries. Only "default", "example.com" and
        // "yahoo.com" are asserted on below; the others just have to load.
        // NOTE(review): entries without `mx_rollup = false` presumably trigger
        // MX resolution at load time, which would make this test depend on
        // DNS access — TODO confirm.
        let shaping = make_shaping_configs(&[
            r#"
["default"]
connection_limit = 10
max_connection_rate = "100/min"
max_deliveries_per_connection = 100
max_message_rate = "100/s"
idle_timeout = "60s"
data_timeout = "30s"
data_dot_timeout = "60s"
enable_tls = "Opportunistic"
consecutive_connection_failures_before_delay = 100

[["default".automation]]
regex=[
        '/Messages from \d+\.\d+\.\d+\.\d+ temporarily deferred/',
        '/All messages from \d+\.\d+\.\d+\.\d+ will be permanently deferred/',
        '/has been temporarily rate limited due to IP reputation/',
        '/Unfortunately, messages from \d+\.\d+\.\d+\.\d+ weren.t sent/',
        '/Server busy\. Please try again later from/'
]
action = [
        {SetConfig={name="max_message_rate", value="1/minute"}},
        {SetConfig={name="connection_limit", value=1}}
]
duration = "90m"

[["default".automation]]
regex="KumoMTA internal: failed to connect to any candidate hosts: All failures are related to OpportunisticInsecure STARTTLS. Consider setting enable_tls=Disabled for this site"
action = {SetConfig={name="enable_tls", value="Disabled"}}
duration = "30 days"

["gmail.com"]
max_deliveries_per_connection = 50
connection_limit = 5
enable_tls = "Required"
consecutive_connection_failures_before_delay = 5

["yahoo.com"]
max_deliveries_per_connection = 20

[["yahoo.com".automation]]
regex = "\\[TS04\\]"
action = "Suspend"
duration = "2 hours"

["comcast.net"]
connection_limit = 25
max_deliveries_per_connection = 250
enable_tls = "Required"
idle_timeout = "30s"
consecutive_connection_failures_before_delay = 24

["mail.com"]
max_deliveries_per_connection = 100

["orange.fr"]
connection_limit = 3

["smtp.mailgun.com"]
connection_limit = 7000
max_deliveries_per_connection = 3

["example.com"]
mx_rollup = false
max_deliveries_per_connection = 100
connection_limit = 3

["example.com".sources."my source name"]
connection_limit = 5
        "#,
        ])
        .await;

        // A domain/source/site that matches no explicit entry resolves to the
        // "default" section alone. Its values (e.g. connection_limit 10,
        // data_timeout 30s) overlay the built-in EgressPathConfig defaults, and
        // both "default" automation rules are present. Human-readable
        // durations are normalized ("90m" -> 5400s, "30 days" -> 2592000s).
        let default = shaping
            .get_egress_path_config("invalid.domain", "invalid.source", "invalid.site")
            .await
            .finish()
            .unwrap();
        k9::snapshot!(
            default,
            r#"
MergedEntry {
    params: EgressPathConfig {
        connection_limit: 10,
        additional_connection_limits: {},
        enable_tls: Opportunistic,
        enable_mta_sts: true,
        enable_dane: false,
        enable_pipelining: true,
        enable_rset: true,
        tls_prefer_openssl: false,
        openssl_cipher_list: None,
        openssl_cipher_suites: None,
        openssl_options: None,
        rustls_cipher_suites: [],
        client_timeouts: SmtpClientTimeouts {
            connect_timeout: 60s,
            banner_timeout: 60s,
            ehlo_timeout: 300s,
            mail_from_timeout: 300s,
            rcpt_to_timeout: 300s,
            data_timeout: 30s,
            data_dot_timeout: 60s,
            rset_timeout: 5s,
            idle_timeout: 60s,
            starttls_timeout: 5s,
            auth_timeout: 60s,
        },
        system_shutdown_timeout: None,
        max_ready: 1024,
        consecutive_connection_failures_before_delay: 100,
        smtp_port: 25,
        smtp_auth_plain_username: None,
        smtp_auth_plain_password: None,
        allow_smtp_auth_plain_without_tls: false,
        max_message_rate: Some(
            100/s,
        ),
        additional_message_rate_throttles: {},
        source_selection_rate: None,
        additional_source_selection_rates: {},
        max_connection_rate: Some(
            100/m,
        ),
        max_deliveries_per_connection: 100,
        prohibited_hosts: {
            "127.0.0.0/8",
            "::1",
        },
        skip_hosts: {},
        ehlo_domain: None,
        aggressive_connection_opening: false,
        refresh_interval: 60s,
        refresh_strategy: Ttl,
        dispatcher_wakeup_strategy: Aggressive,
        maintainer_wakeup_strategy: Aggressive,
        provider_name: None,
        remember_broken_tls: None,
        opportunistic_tls_reconnect_on_failed_handshake: false,
        use_lmtp: false,
        reconnect_strategy: ConnectNextHost,
        readyq_pool_name: None,
        low_memory_reduction_policy: ShrinkDataAndMeta,
        no_memory_reduction_policy: ShrinkDataAndMeta,
    },
    sources: {},
    automation: [
        Rule {
            regex: [
                Regex(
                    /Messages from \d+\.\d+\.\d+\.\d+ temporarily deferred/,
                ),
                Regex(
                    /All messages from \d+\.\d+\.\d+\.\d+ will be permanently deferred/,
                ),
                Regex(
                    /has been temporarily rate limited due to IP reputation/,
                ),
                Regex(
                    /Unfortunately, messages from \d+\.\d+\.\d+\.\d+ weren.t sent/,
                ),
                Regex(
                    /Server busy\. Please try again later from/,
                ),
            ],
            action: [
                SetConfig(
                    EgressPathConfigValue {
                        name: "max_message_rate",
                        value: HashableTomlValue {
                            value: String(
                                "1/minute",
                            ),
                        },
                    },
                ),
                SetConfig(
                    EgressPathConfigValue {
                        name: "connection_limit",
                        value: HashableTomlValue {
                            value: Integer(
                                1,
                            ),
                        },
                    },
                ),
            ],
            trigger: Immediate,
            duration: 5400s,
            was_rollup: false,
            match_internal: false,
        },
        Rule {
            regex: [
                Regex(
                    KumoMTA internal: failed to connect to any candidate hosts: All failures are related to OpportunisticInsecure STARTTLS. Consider setting enable_tls=Disabled for this site,
                ),
            ],
            action: [
                SetConfig(
                    EgressPathConfigValue {
                        name: "enable_tls",
                        value: HashableTomlValue {
                            value: String(
                                "Disabled",
                            ),
                        },
                    },
                ),
            ],
            trigger: Immediate,
            duration: 2592000s,
            was_rollup: false,
            match_internal: false,
        },
    ],
}
"#
        );

        // example.com sets `mx_rollup = false`, so its entry applies even
        // though the site passed here is "invalid.site". Its connection_limit
        // (3) overrides the default's 10, and its per-source block appears
        // under `sources`. Note that the per-source EgressPathConfig in the
        // snapshot is NOT layered over the "default" section: fields it does
        // not set show built-in values (e.g. data_timeout 300s, idle_timeout
        // 5s, max_deliveries_per_connection 1024).
        let example_com = shaping
            .get_egress_path_config("example.com", "invalid.source", "invalid.site")
            .await
            .finish()
            .unwrap();
        k9::snapshot!(
            example_com,
            r#"
MergedEntry {
    params: EgressPathConfig {
        connection_limit: 3,
        additional_connection_limits: {},
        enable_tls: Opportunistic,
        enable_mta_sts: true,
        enable_dane: false,
        enable_pipelining: true,
        enable_rset: true,
        tls_prefer_openssl: false,
        openssl_cipher_list: None,
        openssl_cipher_suites: None,
        openssl_options: None,
        rustls_cipher_suites: [],
        client_timeouts: SmtpClientTimeouts {
            connect_timeout: 60s,
            banner_timeout: 60s,
            ehlo_timeout: 300s,
            mail_from_timeout: 300s,
            rcpt_to_timeout: 300s,
            data_timeout: 30s,
            data_dot_timeout: 60s,
            rset_timeout: 5s,
            idle_timeout: 60s,
            starttls_timeout: 5s,
            auth_timeout: 60s,
        },
        system_shutdown_timeout: None,
        max_ready: 1024,
        consecutive_connection_failures_before_delay: 100,
        smtp_port: 25,
        smtp_auth_plain_username: None,
        smtp_auth_plain_password: None,
        allow_smtp_auth_plain_without_tls: false,
        max_message_rate: Some(
            100/s,
        ),
        additional_message_rate_throttles: {},
        source_selection_rate: None,
        additional_source_selection_rates: {},
        max_connection_rate: Some(
            100/m,
        ),
        max_deliveries_per_connection: 100,
        prohibited_hosts: {
            "127.0.0.0/8",
            "::1",
        },
        skip_hosts: {},
        ehlo_domain: None,
        aggressive_connection_opening: false,
        refresh_interval: 60s,
        refresh_strategy: Ttl,
        dispatcher_wakeup_strategy: Aggressive,
        maintainer_wakeup_strategy: Aggressive,
        provider_name: None,
        remember_broken_tls: None,
        opportunistic_tls_reconnect_on_failed_handshake: false,
        use_lmtp: false,
        reconnect_strategy: ConnectNextHost,
        readyq_pool_name: None,
        low_memory_reduction_policy: ShrinkDataAndMeta,
        no_memory_reduction_policy: ShrinkDataAndMeta,
    },
    sources: {
        "my source name": EgressPathConfig {
            connection_limit: 5,
            additional_connection_limits: {},
            enable_tls: Opportunistic,
            enable_mta_sts: true,
            enable_dane: false,
            enable_pipelining: true,
            enable_rset: true,
            tls_prefer_openssl: false,
            openssl_cipher_list: None,
            openssl_cipher_suites: None,
            openssl_options: None,
            rustls_cipher_suites: [],
            client_timeouts: SmtpClientTimeouts {
                connect_timeout: 60s,
                banner_timeout: 60s,
                ehlo_timeout: 300s,
                mail_from_timeout: 300s,
                rcpt_to_timeout: 300s,
                data_timeout: 300s,
                data_dot_timeout: 300s,
                rset_timeout: 5s,
                idle_timeout: 5s,
                starttls_timeout: 5s,
                auth_timeout: 60s,
            },
            system_shutdown_timeout: None,
            max_ready: 1024,
            consecutive_connection_failures_before_delay: 100,
            smtp_port: 25,
            smtp_auth_plain_username: None,
            smtp_auth_plain_password: None,
            allow_smtp_auth_plain_without_tls: false,
            max_message_rate: None,
            additional_message_rate_throttles: {},
            source_selection_rate: None,
            additional_source_selection_rates: {},
            max_connection_rate: None,
            max_deliveries_per_connection: 1024,
            prohibited_hosts: {
                "127.0.0.0/8",
                "::1",
            },
            skip_hosts: {},
            ehlo_domain: None,
            aggressive_connection_opening: false,
            refresh_interval: 60s,
            refresh_strategy: Ttl,
            dispatcher_wakeup_strategy: Aggressive,
            maintainer_wakeup_strategy: Aggressive,
            provider_name: None,
            remember_broken_tls: None,
            opportunistic_tls_reconnect_on_failed_handshake: false,
            use_lmtp: false,
            reconnect_strategy: ConnectNextHost,
            readyq_pool_name: None,
            low_memory_reduction_policy: ShrinkDataAndMeta,
            no_memory_reduction_policy: ShrinkDataAndMeta,
        },
    },
    automation: [
        Rule {
            regex: [
                Regex(
                    /Messages from \d+\.\d+\.\d+\.\d+ temporarily deferred/,
                ),
                Regex(
                    /All messages from \d+\.\d+\.\d+\.\d+ will be permanently deferred/,
                ),
                Regex(
                    /has been temporarily rate limited due to IP reputation/,
                ),
                Regex(
                    /Unfortunately, messages from \d+\.\d+\.\d+\.\d+ weren.t sent/,
                ),
                Regex(
                    /Server busy\. Please try again later from/,
                ),
            ],
            action: [
                SetConfig(
                    EgressPathConfigValue {
                        name: "max_message_rate",
                        value: HashableTomlValue {
                            value: String(
                                "1/minute",
                            ),
                        },
                    },
                ),
                SetConfig(
                    EgressPathConfigValue {
                        name: "connection_limit",
                        value: HashableTomlValue {
                            value: Integer(
                                1,
                            ),
                        },
                    },
                ),
            ],
            trigger: Immediate,
            duration: 5400s,
            was_rollup: false,
            match_internal: false,
        },
        Rule {
            regex: [
                Regex(
                    KumoMTA internal: failed to connect to any candidate hosts: All failures are related to OpportunisticInsecure STARTTLS. Consider setting enable_tls=Disabled for this site,
                ),
            ],
            action: [
                SetConfig(
                    EgressPathConfigValue {
                        name: "enable_tls",
                        value: HashableTomlValue {
                            value: String(
                                "Disabled",
                            ),
                        },
                    },
                ),
            ],
            trigger: Immediate,
            duration: 2592000s,
            was_rollup: false,
            match_internal: false,
        },
    ],
}
"#
        );

        // yahoo.com does not set `mx_rollup = false`, so its entry is matched
        // via the site name rather than the domain. The site name here is
        // presumably what the loader derives from yahoo's MX hosts, so it will
        // need to be updated if yahoo changes their MX records.
        // Expected: max_deliveries_per_connection 20 overrides the default's
        // 100, and yahoo's TS04 Suspend rule is appended after the two
        // "default" automation rules.
        let yahoo_com = shaping
            .get_egress_path_config(
                "yahoo.com",
                "invalid.source",
                "(mta5|mta6|mta7).am0.yahoodns.net",
            )
            .await
            .finish()
            .unwrap();
        k9::snapshot!(
            yahoo_com,
            r#"
MergedEntry {
    params: EgressPathConfig {
        connection_limit: 10,
        additional_connection_limits: {},
        enable_tls: Opportunistic,
        enable_mta_sts: true,
        enable_dane: false,
        enable_pipelining: true,
        enable_rset: true,
        tls_prefer_openssl: false,
        openssl_cipher_list: None,
        openssl_cipher_suites: None,
        openssl_options: None,
        rustls_cipher_suites: [],
        client_timeouts: SmtpClientTimeouts {
            connect_timeout: 60s,
            banner_timeout: 60s,
            ehlo_timeout: 300s,
            mail_from_timeout: 300s,
            rcpt_to_timeout: 300s,
            data_timeout: 30s,
            data_dot_timeout: 60s,
            rset_timeout: 5s,
            idle_timeout: 60s,
            starttls_timeout: 5s,
            auth_timeout: 60s,
        },
        system_shutdown_timeout: None,
        max_ready: 1024,
        consecutive_connection_failures_before_delay: 100,
        smtp_port: 25,
        smtp_auth_plain_username: None,
        smtp_auth_plain_password: None,
        allow_smtp_auth_plain_without_tls: false,
        max_message_rate: Some(
            100/s,
        ),
        additional_message_rate_throttles: {},
        source_selection_rate: None,
        additional_source_selection_rates: {},
        max_connection_rate: Some(
            100/m,
        ),
        max_deliveries_per_connection: 20,
        prohibited_hosts: {
            "127.0.0.0/8",
            "::1",
        },
        skip_hosts: {},
        ehlo_domain: None,
        aggressive_connection_opening: false,
        refresh_interval: 60s,
        refresh_strategy: Ttl,
        dispatcher_wakeup_strategy: Aggressive,
        maintainer_wakeup_strategy: Aggressive,
        provider_name: None,
        remember_broken_tls: None,
        opportunistic_tls_reconnect_on_failed_handshake: false,
        use_lmtp: false,
        reconnect_strategy: ConnectNextHost,
        readyq_pool_name: None,
        low_memory_reduction_policy: ShrinkDataAndMeta,
        no_memory_reduction_policy: ShrinkDataAndMeta,
    },
    sources: {},
    automation: [
        Rule {
            regex: [
                Regex(
                    /Messages from \d+\.\d+\.\d+\.\d+ temporarily deferred/,
                ),
                Regex(
                    /All messages from \d+\.\d+\.\d+\.\d+ will be permanently deferred/,
                ),
                Regex(
                    /has been temporarily rate limited due to IP reputation/,
                ),
                Regex(
                    /Unfortunately, messages from \d+\.\d+\.\d+\.\d+ weren.t sent/,
                ),
                Regex(
                    /Server busy\. Please try again later from/,
                ),
            ],
            action: [
                SetConfig(
                    EgressPathConfigValue {
                        name: "max_message_rate",
                        value: HashableTomlValue {
                            value: String(
                                "1/minute",
                            ),
                        },
                    },
                ),
                SetConfig(
                    EgressPathConfigValue {
                        name: "connection_limit",
                        value: HashableTomlValue {
                            value: Integer(
                                1,
                            ),
                        },
                    },
                ),
            ],
            trigger: Immediate,
            duration: 5400s,
            was_rollup: false,
            match_internal: false,
        },
        Rule {
            regex: [
                Regex(
                    KumoMTA internal: failed to connect to any candidate hosts: All failures are related to OpportunisticInsecure STARTTLS. Consider setting enable_tls=Disabled for this site,
                ),
            ],
            action: [
                SetConfig(
                    EgressPathConfigValue {
                        name: "enable_tls",
                        value: HashableTomlValue {
                            value: String(
                                "Disabled",
                            ),
                        },
                    },
                ),
            ],
            trigger: Immediate,
            duration: 2592000s,
            was_rollup: false,
            match_internal: false,
        },
        Rule {
            regex: [
                Regex(
                    \[TS04\],
                ),
            ],
            action: [
                Suspend,
            ],
            trigger: Immediate,
            duration: 7200s,
            was_rollup: false,
            match_internal: false,
        },
    ],
}
"#
        );
    }
2431
2432    #[tokio::test]
2433    async fn test_load_default_shaping_toml() {
2434        Shaping::merge_files(
2435            &["../../assets/policy-extras/shaping.toml".into()],
2436            &ShapingMergeOptions::default(),
2437        )
2438        .await
2439        .unwrap();
2440    }
2441}