1use crate::egress_path::EgressPathConfig;
2#[cfg(feature = "lua")]
3use anyhow::Context;
4#[cfg(feature = "lua")]
5use config::any_err;
6#[cfg(feature = "lua")]
7use config::serialize_options;
8#[cfg(feature = "lua")]
9use dns_resolver::{fully_qualify, MailExchanger};
10#[cfg(feature = "lua")]
11use kumo_log_types::JsonLogRecord;
12#[cfg(feature = "lua")]
13use mlua::prelude::LuaUserData;
14#[cfg(feature = "lua")]
15use mlua::{LuaSerdeExt, UserDataMethods};
16use ordermap::OrderMap;
17use serde::de::{SeqAccess, Visitor};
18use serde::{Deserialize, Deserializer, Serialize};
19use serde_with::formats::PreferOne;
20use serde_with::{serde_as, DeserializeAs, OneOrMany};
21#[cfg(feature = "lua")]
22use sha2::{Digest, Sha256};
23#[cfg(feature = "lua")]
24use std::collections::BTreeMap;
25use std::hash::{Hash, Hasher};
26use std::marker::PhantomData;
27#[cfg(feature = "lua")]
28use std::sync::Arc;
29use std::time::Duration;
30#[cfg(feature = "lua")]
31use throttle::LimitSpec;
32use throttle::ThrottleSpec;
33
/// A compiled regular expression that round-trips through serde as its
/// pattern string.
///
/// Deserialization goes through `TryFrom<String>` and serialization through
/// `Into<String>`, as directed by the `serde` attribute below.
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(try_from = "String", into = "String")]
pub struct Regex(fancy_regex::Regex);
37
38impl TryFrom<String> for Regex {
39 type Error = fancy_regex::Error;
40
41 fn try_from(s: String) -> fancy_regex::Result<Self> {
42 Ok(Self(fancy_regex::Regex::new(&s)?))
43 }
44}
45
46impl From<Regex> for String {
47 fn from(r: Regex) -> String {
48 r.0.as_str().to_string()
49 }
50}
51
52impl std::ops::Deref for Regex {
53 type Target = fancy_regex::Regex;
54 fn deref(&self) -> &Self::Target {
55 &self.0
56 }
57}
58
59impl std::hash::Hash for Regex {
60 fn hash<H: std::hash::Hasher>(&self, hasher: &mut H) {
61 self.0.as_str().hash(hasher)
62 }
63}
64
/// Wrapper around `toml::Value` that adds a `Hash` implementation.
///
/// Serde converts to and from the bare `toml::Value`, so the wrapper does
/// not appear in the serialized form.
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(from = "toml::Value", into = "toml::Value")]
pub struct HashableTomlValue {
    /// The wrapped TOML value.
    value: toml::Value,
}
75
76impl From<toml::Value> for HashableTomlValue {
77 fn from(value: toml::Value) -> Self {
78 Self { value }
79 }
80}
81
82impl From<HashableTomlValue> for toml::Value {
83 fn from(value: HashableTomlValue) -> toml::Value {
84 value.value
85 }
86}
87
88impl std::ops::Deref for HashableTomlValue {
89 type Target = toml::Value;
90 fn deref(&self) -> &toml::Value {
91 &self.value
92 }
93}
94
95fn hash_toml<H>(value: &toml::Value, h: &mut H)
96where
97 H: Hasher,
98{
99 match value {
100 toml::Value::Boolean(v) => v.hash(h),
101 toml::Value::Datetime(v) => {
102 if let Some(d) = &v.date {
103 d.year.hash(h);
104 d.month.hash(h);
105 d.day.hash(h);
106 }
107 if let Some(t) = &v.time {
108 t.hour.hash(h);
109 t.minute.hash(h);
110 t.second.hash(h);
111 t.nanosecond.hash(h);
112 }
113 if let Some(toml::value::Offset::Custom { minutes }) = &v.offset {
114 minutes.hash(h);
115 }
116 }
117 toml::Value::String(v) => v.hash(h),
118 toml::Value::Integer(v) => v.hash(h),
119 toml::Value::Float(v) => v.to_ne_bytes().hash(h),
120 toml::Value::Array(a) => {
121 for v in a.iter() {
122 hash_toml(v, h);
123 }
124 }
125 toml::Value::Table(m) => {
126 for (k, v) in m.iter() {
127 k.hash(h);
128 hash_toml(v, h);
129 }
130 }
131 }
132}
133
134impl Hash for HashableTomlValue {
135 fn hash<H>(&self, h: &mut H)
136 where
137 H: Hasher,
138 {
139 hash_toml(&self.value, h);
140 }
141}
142
/// A single named egress path option together with its value.
///
/// Deserialization goes through `EgressPathConfigValueUnchecked`, whose
/// `TryFrom` conversion rejects a name/value pair that `EgressPathConfig`
/// cannot deserialize.
#[derive(Deserialize, Serialize, Debug, Clone, Hash)]
#[serde(
    try_from = "EgressPathConfigValueUnchecked",
    into = "EgressPathConfigValueUnchecked"
)]
pub struct EgressPathConfigValue {
    /// Name of the option.
    pub name: String,
    /// Value to assign to the option.
    pub value: HashableTomlValue,
}
154
/// The serialized shape of `EgressPathConfigValue`, before the name/value
/// pair has been validated.
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct EgressPathConfigValueUnchecked {
    /// Name of the option.
    pub name: String,
    /// Raw, not yet validated, value for the option.
    pub value: toml::Value,
}
165
166impl TryFrom<EgressPathConfigValueUnchecked> for EgressPathConfigValue {
167 type Error = anyhow::Error;
168 fn try_from(config: EgressPathConfigValueUnchecked) -> anyhow::Result<EgressPathConfigValue> {
169 let mut map = toml::map::Map::new();
170 map.insert(config.name.clone(), config.value.clone());
171 let table = toml::Value::Table(map);
172
173 EgressPathConfig::deserialize(table)?;
176
177 Ok(EgressPathConfigValue {
179 name: config.name,
180 value: HashableTomlValue {
181 value: config.value,
182 },
183 })
184 }
185}
186
187impl From<EgressPathConfigValue> for EgressPathConfigValueUnchecked {
188 fn from(config: EgressPathConfigValue) -> EgressPathConfigValueUnchecked {
189 EgressPathConfigValueUnchecked {
190 name: config.name,
191 value: config.value.value,
192 }
193 }
194}
195
/// The actions that an automation [`Rule`] can request when it matches.
///
/// NOTE(review): the effect of each variant is implemented outside this
/// file; the per-variant notes below only restate what the names and payload
/// types show.
#[derive(Deserialize, Serialize, Debug, Clone, Hash)]
pub enum Action {
    Suspend,
    /// Carries the option name and value to set.
    SetConfig(EgressPathConfigValue),
    SuspendTenant,
    SuspendCampaign,
    /// Carries the option name and value to set.
    SetDomainConfig(EgressPathConfigValue),
    Bounce,
    BounceTenant,
    BounceCampaign,
}
207
/// When a matching [`Rule`] should fire.
#[derive(Deserialize, Serialize, Debug, Clone, Hash, Default)]
pub enum Trigger {
    /// The default trigger.
    #[default]
    Immediate,
    /// Parameterized by a `ThrottleSpec`.
    /// NOTE(review): presumably the rule fires once matches exceed that
    /// rate; confirm against the consumer of this type.
    Threshold(ThrottleSpec),
}
217
/// An automation rule: a set of patterns to look for in a response, and the
/// actions to take when one of them matches.
#[serde_as]
#[derive(Deserialize, Serialize, Debug, Hash, Clone)]
pub struct Rule {
    /// Patterns tested against the response text; the rule matches when any
    /// one of them matches (see `Rule::matches`).
    #[serde(deserialize_with = "regex_string_or_array")]
    pub regex: Vec<Regex>,

    /// The action(s) associated with this rule.
    #[serde(deserialize_with = "one_or_many_action")]
    pub action: Vec<Action>,

    /// When the rule fires; defaults to `Trigger::Immediate` when omitted.
    #[serde(default)]
    pub trigger: Trigger,

    /// Duration associated with the rule, (de)serialized via `duration_serde`.
    #[serde(with = "duration_serde")]
    pub duration: Duration,

    /// Never (de)serialized; set to `true` by `Rule::clone_and_set_rollup`.
    #[serde(skip)]
    pub was_rollup: bool,

    /// When `false` (the default), the rule never matches responses that
    /// the caller flags as internal.
    #[serde(default)]
    pub match_internal: bool,
}
242
243impl Rule {
244 pub fn matches(&self, is_internal: bool, response: &str) -> bool {
245 if is_internal && !self.match_internal {
246 return false;
247 }
248 self.regex
249 .iter()
250 .any(|r| r.is_match(response).unwrap_or(false))
251 }
252
253 pub fn clone_and_set_rollup(&self) -> Self {
254 let mut result = self.clone();
255 result.was_rollup = true;
256 result
257 }
258}
259
/// The merged shaping data shared by every clone of `Shaping`.
#[cfg(feature = "lua")]
#[derive(Debug, Default)]
struct ShapingInner {
    /// Entries keyed by site name.
    by_site: OrderMap<String, PartialEntry>,
    /// Entries keyed by domain name; the key `default` holds the base entry.
    by_domain: OrderMap<String, PartialEntry>,
    /// Entries keyed by provider name.
    by_provider: OrderMap<String, ProviderEntry>,
    /// Warnings collected while loading and merging.
    warnings: Vec<String>,
    /// Errors collected while loading and merging.
    errors: Vec<String>,
    /// Lowercase hex SHA-256 digest computed over the merged data.
    hash: String,
}
270
271#[cfg(feature = "lua")]
272impl ShapingInner {
    /// Builds the effective parameter set for mail to `domain` sent from
    /// `egress_source` towards `site_name`.
    ///
    /// Layers are merged in this order, each one applied over the previous:
    /// 1. the `default` domain entry;
    /// 2. the params of every provider matching `domain` (only when `domain`
    ///    parses as a DNS name);
    /// 3. those providers' per-source params for `egress_source`;
    /// 4. the site entry for `site_name`;
    /// 5. the domain entry for `domain`;
    /// 6. the site entry's per-source params;
    /// 7. the domain entry's per-source params.
    pub async fn get_egress_path_config(
        &self,
        domain: &str,
        egress_source: &str,
        site_name: &str,
    ) -> PartialEntry {
        let mut params = PartialEntry::default();

        if let Some(default) = self.by_domain.get("default") {
            params.merge_from(default.clone());
        }

        // Provider rules are only considered for inputs that parse as a
        // DNS name.
        let is_domain_name = dns_resolver::Name::from_str_relaxed(domain).is_ok();
        if is_domain_name {
            let mut prov_with_sources = vec![];

            for prov in self.by_provider.values() {
                if prov.domain_matches(domain).await {
                    toml_table_merge_from(&mut params.params, &prov.params);
                    prov.apply_provider_params_to(egress_source, &mut params.params);

                    // Remember providers with per-source overrides so they
                    // can be layered on after all provider base params.
                    if !prov.sources.is_empty() {
                        prov_with_sources.push(prov);
                    }
                }
            }

            for prov in prov_with_sources {
                if let Some(source) = prov.sources.get(egress_source) {
                    toml_table_merge_from(&mut params.params, source);
                    // Re-apply so that provider keys introduced by the
                    // source block are translated as well.
                    prov.apply_provider_params_to(egress_source, &mut params.params);
                }
            }
        }

        if let Some(by_site) = self.by_site.get(site_name) {
            params.merge_from(by_site.clone());
        }

        if let Some(by_domain) = self.by_domain.get(domain) {
            params.merge_from(by_domain.clone());
        }

        // Per-source overrides go last: site first, then domain.
        if let Some(by_site) = self.by_site.get(site_name) {
            if let Some(source) = by_site.sources.get(egress_source) {
                toml_table_merge_from(&mut params.params, source);
            }
        }

        if let Some(by_domain) = self.by_domain.get(domain) {
            if let Some(source) = by_domain.sources.get(egress_source) {
                toml_table_merge_from(&mut params.params, source);
            }
        }

        params
    }
343
344 pub async fn match_rules(&self, record: &JsonLogRecord) -> anyhow::Result<Vec<Rule>> {
345 use rfc5321::ForwardPath;
346 let recipient = ForwardPath::try_from(record.recipient.as_str())
348 .map_err(|err| anyhow::anyhow!("parsing record.recipient: {err}"))?;
349
350 let recipient = match recipient {
351 ForwardPath::Postmaster => {
352 return Ok(vec![]);
355 }
356 ForwardPath::Path(path) => path.mailbox,
357 };
358 let domain = recipient.domain.to_string();
359
360 let source = record.egress_source.as_deref().unwrap_or("unspecified");
362 let site_name = record
375 .site
376 .trim_start_matches(&format!("{source}->"))
377 .trim_end_matches("@smtp_client")
378 .to_string();
379
380 Ok(self.match_rules_impl(record, &domain, &site_name).await)
381 }
382
383 pub async fn match_rules_impl(
384 &self,
385 record: &JsonLogRecord,
386 domain: &str,
387 site_name: &str,
388 ) -> Vec<Rule> {
389 let mut result = vec![];
390 let response = record.response.to_single_line();
391 tracing::trace!("Consider rules for {response}");
392
393 let is_internal = record.response.content.starts_with("KumoMTA internal: ");
394
395 if let Some(default) = self.by_domain.get("default") {
396 for rule in &default.automation {
397 tracing::trace!("Consider \"default\" rule {rule:?} for {response}");
398 if rule.matches(is_internal, &response) {
399 result.push(rule.clone_and_set_rollup());
406 }
407 }
408 }
409
410 for prov in self.by_provider.values() {
411 if prov.domain_matches(domain).await {
412 for rule in &prov.automation {
413 tracing::trace!(
414 "Consider provider \"{}\" rule {rule:?} for {response}",
415 prov.provider_name
416 );
417 if rule.matches(is_internal, &response) {
418 result.push(rule.clone());
419 }
420 }
421 }
422 }
423
424 if let Some(by_site) = self.by_site.get(site_name) {
426 for rule in &by_site.automation {
427 tracing::trace!("Consider \"{site_name}\" rule {rule:?} for {response}");
428 if rule.matches(is_internal, &response) {
429 result.push(rule.clone_and_set_rollup());
430 }
431 }
432 }
433
434 if let Some(by_domain) = self.by_domain.get(domain) {
436 for rule in &by_domain.automation {
437 tracing::trace!("Consider \"{domain}\" rule {rule:?} for {response}");
438 if rule.matches(is_internal, &response) {
439 result.push(rule.clone());
440 }
441 }
442 }
443
444 result
445 }
446}
447
/// Handle to a merged shaping configuration.
///
/// The data lives behind an `Arc`, so cloning the handle does not copy it.
#[cfg(feature = "lua")]
#[derive(Debug, Default, Clone, mlua::FromLua)]
pub struct Shaping {
    /// The shared, merged shaping data.
    inner: Arc<ShapingInner>,
}
453
/// Deserializes `T` from JSON text via `serde_path_to_error`, so that a
/// failure reports the path of the offending element.
#[cfg(feature = "lua")]
fn from_json<'a, T: Deserialize<'a>>(json: &'a str) -> anyhow::Result<T> {
    let mut deserializer = serde_json::Deserializer::from_str(json);
    let parsed = serde_path_to_error::deserialize(&mut deserializer)?;
    Ok(parsed)
}
459
/// Deserializes `T` from TOML text via `serde_path_to_error`, so that a
/// failure reports the path of the offending element.
#[cfg(feature = "lua")]
fn from_toml<'a, T: Deserialize<'a>>(toml: &'a str) -> anyhow::Result<T> {
    let deserializer = toml::Deserializer::new(toml);
    let parsed = serde_path_to_error::deserialize(deserializer)?;
    Ok(parsed)
}
465
/// How a particular class of problem found while merging shaping files is
/// reported (see `Collector::push`).
#[cfg(feature = "lua")]
#[derive(Default, Debug, Clone, Copy, Deserialize, PartialEq, Eq)]
pub enum CheckLevel {
    /// The problem is not recorded at all.
    #[default]
    Ignore,
    /// The problem is recorded as a warning.
    Warn,
    /// The problem is recorded as an error.
    Error,
}
/// Parses a level name, ignoring ASCII case.
#[cfg(feature = "lua")]
impl std::str::FromStr for CheckLevel {
    type Err = String;

    /// Accepts `ignore`, `warn` or `error` in any ASCII case; anything else
    /// produces a message naming the rejected input.
    fn from_str(s: &str) -> Result<Self, String> {
        const NAMES: [(&str, CheckLevel); 3] = [
            ("ignore", CheckLevel::Ignore),
            ("warn", CheckLevel::Warn),
            ("error", CheckLevel::Error),
        ];

        NAMES
            .iter()
            .find(|(name, _)| s.eq_ignore_ascii_case(name))
            .map(|&(_, level)| level)
            .ok_or_else(|| format!("Expected one of `Ignore`, `Warn` or `Error`, got `{s}`"))
    }
}
492
/// Accumulates the warnings and errors produced while loading and merging
/// shaping files.
#[cfg(feature = "lua")]
#[derive(Default)]
struct Collector {
    /// Messages pushed at `CheckLevel::Warn`.
    warnings: Vec<String>,
    /// Messages pushed at `CheckLevel::Error`.
    errors: Vec<String>,
}
499
#[cfg(feature = "lua")]
impl Collector {
    /// Records `msg` in the list selected by `level`; `CheckLevel::Ignore`
    /// discards the message.
    fn push<S: Into<String>>(&mut self, level: CheckLevel, msg: S) {
        let bucket = match level {
            CheckLevel::Ignore => return,
            CheckLevel::Warn => &mut self.warnings,
            CheckLevel::Error => &mut self.errors,
        };
        bucket.push(msg.into());
    }
}
510
/// Options controlling how `Shaping::merge_files` loads its sources and how
/// it reports the problems it finds.
///
/// Unknown fields are rejected during deserialization.
#[cfg(feature = "lua")]
#[derive(Debug, Clone, Deserialize)]
#[serde_as]
#[serde(deny_unknown_fields)]
pub struct ShapingMergeOptions {
    /// Level for reporting an `mx_rollup` domain that is also matched by one
    /// or more providers.
    #[serde(default)]
    pub provider_overlap: CheckLevel,
    /// Level for reporting a failed MX resolution of an `mx_rollup` domain.
    #[serde(default)]
    pub dns_fail: CheckLevel,
    /// Level for reporting an `mx_rollup` domain that has a NULL MX.
    #[serde(default)]
    pub null_mx: CheckLevel,
    /// Level for reporting several domain blocks that resolve to the same
    /// site.
    #[serde(default)]
    pub aliased_site: CheckLevel,
    /// When `true`, `http://` and `https://` sources are not fetched.
    #[serde(default)]
    pub skip_remote: bool,
    /// Level for reporting a failure to fetch or parse a remote source.
    #[serde(default)]
    pub remote_load: CheckLevel,
    /// Level for reporting a failure to read or parse a local source.
    // NOTE(review): when this field is omitted from the input, serde uses
    // `CheckLevel::default()` (`Ignore`), whereas
    // `ShapingMergeOptions::default()` uses `Error`; confirm this is intended.
    #[serde(default)]
    pub local_load: CheckLevel,
    /// Timeout for fetching a remote source; 5 seconds is used when unset.
    #[serde(default, with = "duration_serde")]
    pub http_timeout: Option<Duration>,
}
533
/// Defaults: every check is ignored except local load failures, which are
/// errors; remote sources are fetched; no explicit HTTP timeout.
#[cfg(feature = "lua")]
impl Default for ShapingMergeOptions {
    fn default() -> Self {
        let ignore = CheckLevel::Ignore;
        ShapingMergeOptions {
            provider_overlap: ignore,
            dns_fail: ignore,
            null_mx: ignore,
            aliased_site: ignore,
            skip_remote: false,
            remote_load: ignore,
            local_load: CheckLevel::Error,
            http_timeout: None,
        }
    }
}
549
550#[cfg(feature = "lua")]
551impl Shaping {
    /// Loads and parses a single shaping source.
    ///
    /// `path` is fetched over HTTP when it starts with `http://` or
    /// `https://`, and is read from the local filesystem otherwise. A failure
    /// to obtain the data is recorded in `collector` at the level configured
    /// in `options` (`remote_load` / `local_load`) and an empty `ShapingFile`
    /// is returned in its place.
    ///
    /// # Errors
    ///
    /// Returns an error when a source whose path ends in `.toml` or `.json`
    /// fails to parse. A source with any other extension is tried as TOML and
    /// then as JSON; if both fail, the problem is recorded in `collector` and
    /// an empty `ShapingFile` is returned.
    async fn load_from_file(
        path: &str,
        options: &ShapingMergeOptions,
        collector: &mut Collector,
    ) -> anyhow::Result<ShapingFile> {
        // `level` is the reporting level matching where the data came from;
        // it is used below when the format has to be guessed.
        let (data, level): (String, CheckLevel) =
            if path.starts_with("http://") || path.starts_with("https://") {
                if options.skip_remote {
                    collector.push(
                        CheckLevel::Warn,
                        format!("Ignoring {path} because skip_remote is set to true"),
                    );
                    return Ok(ShapingFile::default());
                }

                // Fetches `url` as text. The same `timeout` is given to the
                // client's own timeouts and also wrapped around the whole
                // request.
                async fn http_get(url: &str, timeout: Duration) -> anyhow::Result<String> {
                    tokio::time::timeout(timeout, async {
                        reqwest::Client::builder()
                            .timeout(timeout)
                            .connect_timeout(timeout)
                            .read_timeout(timeout)
                            .build()?
                            .get(url)
                            .send()
                            .await
                            .with_context(|| format!("making HTTP request to {url}"))?
                            .text()
                            .await
                            .with_context(|| format!("reading text from {url}"))
                    })
                    .await
                    .with_context(|| format!("timeout making HTTP request to {url}"))?
                }

                let timeout = options.http_timeout.unwrap_or(Duration::from_secs(5));

                match http_get(path, timeout).await {
                    Ok(s) => (s, options.remote_load),
                    Err(err) => {
                        // Always logged, regardless of the `remote_load` level.
                        tracing::error!("{err:#}. Ignoring this shaping source for now");
                        collector.push(
                            options.remote_load,
                            format!("remote shaping source {path} error: {err:#}"),
                        );
                        return Ok(ShapingFile::default());
                    }
                }
            } else {
                match std::fs::read_to_string(path)
                    .with_context(|| format!("loading data from file {path}"))
                {
                    Err(err) => {
                        collector.push(
                            options.local_load,
                            format!("local shaping source {path} error: {err:#}"),
                        );
                        return Ok(ShapingFile::default());
                    }
                    Ok(s) => (s, options.local_load),
                }
            };

        if path.ends_with(".toml") {
            from_toml(&data).with_context(|| format!("parsing toml from file {path}"))
        } else if path.ends_with(".json") {
            from_json(&data).with_context(|| format!("parsing json from file {path}"))
        } else {
            // Unknown extension: try both formats and report both failures.
            let mut errors = vec![];
            match from_toml(&data) {
                Ok(s) => return Ok(s),
                Err(err) => errors.push(format!("as toml: {err:#}")),
            }
            match from_json(&data) {
                Ok(s) => return Ok(s),
                Err(err) => errors.push(format!("as json: {err:#}")),
            }

            collector.push(level, format!("parsing {path}: {}", errors.join(", ")));
            Ok(ShapingFile::default())
        }
    }
639
640 pub async fn merge_files(
641 files: &[String],
642 options: &ShapingMergeOptions,
643 ) -> anyhow::Result<Self> {
644 use futures_util::stream::FuturesUnordered;
645 use futures_util::StreamExt;
646
647 let mut collector = Collector::default();
648 let mut loaded = vec![];
649 for p in files {
650 loaded.push(Self::load_from_file(p, options, &mut collector).await?);
651 }
652
653 let mut by_site: OrderMap<String, PartialEntry> = OrderMap::new();
654 let mut by_domain: OrderMap<String, PartialEntry> = OrderMap::new();
655 let mut by_provider: OrderMap<String, ProviderEntry> = OrderMap::new();
656 let mut site_aliases: OrderMap<String, Vec<String>> = OrderMap::new();
657
658 let mut mx = std::collections::HashMap::new();
662 let mut lookups = FuturesUnordered::new();
663 for item in &loaded {
664 for (domain, partial) in &item.domains {
665 if partial.mx_rollup {
666 let domain = domain.to_string();
667 lookups.push(tokio::spawn(async move {
668 let mx_result = MailExchanger::resolve(&domain).await;
669 (domain, mx_result)
670 }));
671 }
672 }
673 }
674
675 while let Some(Ok((domain, result))) = lookups.next().await {
676 mx.insert(domain, result);
677 }
678
679 for mut item in loaded {
680 if let Some(mut partial) = item.default.take() {
681 let domain = "default";
682 partial.domain_name.replace(domain.to_string());
683 match by_domain.get_mut(domain) {
684 Some(existing) => {
685 existing.merge_from(partial);
686 }
687 None => {
688 by_domain.insert(domain.to_string(), partial);
689 }
690 }
691 }
692
693 for (domain, mut partial) in item.domains {
694 partial.domain_name.replace(domain.clone());
695
696 if let Ok(name) = fully_qualify(&domain) {
697 if name.num_labels() == 1 {
698 collector.push(
699 CheckLevel::Warn,
700 format!(
701 "Entry for domain '{domain}' consists of a \
702 single DNS label. Domain names in TOML sections \
703 need to be quoted like '[\"{domain}.com\"]` otherwise \
704 the '.' will create a nested table rather than being \
705 added to the domain name."
706 ),
707 );
708 }
709 }
710
711 #[cfg(test)]
712 if partial._treat_domain_name_as_site_name {
713 match by_site.get_mut(&domain) {
714 Some(existing) => {
715 existing.merge_from(partial);
716 }
717 None => {
718 by_site.insert(domain.to_string(), partial);
719 }
720 }
721 continue;
722 }
723 if partial.mx_rollup {
724 let mx = match mx.get(&domain) {
725 Some(Ok(mx)) => mx,
726 Some(Err(err)) => {
727 collector.push(
728 options.dns_fail,
729 format!(
730 "error resolving MX for {domain}: {err:#}. \
731 Ignoring the shaping config for that domain."
732 ),
733 );
734 continue;
735 }
736 None => {
737 collector.push(
738 options.dns_fail,
739 format!(
740 "We didn't try to resolve the MX for {domain} for some reason!?. \
741 Ignoring the shaping config for that domain."
742 ),
743 );
744 continue;
745 }
746 };
747
748 if mx.site_name.is_empty() {
749 collector.push(
750 options.null_mx,
751 format!(
752 "domain {domain} has a NULL MX and cannot be used with mx_rollup=true. \
753 Ignoring the shaping config for that domain."),
754 );
755 continue;
756 }
757
758 site_aliases
759 .entry(mx.site_name.to_string())
760 .or_default()
761 .push(domain.to_string());
762
763 match by_site.get_mut(&mx.site_name) {
764 Some(existing) => {
765 existing.merge_from(partial);
766 }
767 None => {
768 by_site.insert(mx.site_name.clone(), partial);
769 }
770 }
771 } else {
772 match by_domain.get_mut(&domain) {
773 Some(existing) => {
774 existing.merge_from(partial);
775 }
776 None => {
777 by_domain.insert(domain, partial);
778 }
779 }
780 }
781 }
782
783 for (provider, mut prov) in item.provider {
784 prov.provider_name = provider.to_string();
785 match by_provider.get_mut(&provider) {
786 Some(existing) => {
787 existing.merge_from(prov);
788 }
789 None => {
790 by_provider.insert(provider.to_string(), prov);
791 }
792 }
793 }
794 }
795
796 for (site, partial) in &by_site {
797 partial
798 .clone()
799 .finish()
800 .with_context(|| format!("site: {site}"))?;
801 }
802
803 for (domain, partial) in &by_domain {
804 partial
805 .clone()
806 .finish()
807 .with_context(|| format!("domain: {domain}"))?;
808 }
809
810 for (provider, prov) in &by_provider {
811 prov.finish_params()
812 .with_context(|| format!("provider: {provider}"))?;
813 }
814
815 if options.aliased_site != CheckLevel::Ignore {
816 for (site, aliases) in site_aliases {
817 if aliases.len() > 1 {
818 collector.push(
819 options.aliased_site,
820 format!(
821 "multiple domain blocks alias to the same site: {site}: {}",
822 aliases.join(", ")
823 ),
824 );
825 }
826 }
827 }
828
829 if options.provider_overlap != CheckLevel::Ignore {
830 for domain in mx.keys() {
831 let mut matching_providers = vec![];
832 for (prov_name, prov) in &by_provider {
833 if prov.domain_matches(domain).await {
834 matching_providers.push(prov_name.to_string());
835 }
836 }
837 if !matching_providers.is_empty() {
838 collector.push(
839 options.provider_overlap,
840 format!(
841 "domain {domain} is also matched by provider(s): {}",
842 matching_providers.join(", ")
843 ),
844 );
845 }
846 }
847 }
848
849 let mut ctx = Sha256::new();
850 ctx.update("by_site");
851 for (site, entry) in &by_site {
852 ctx.update(site);
853 entry.hash_into(&mut ctx);
854 }
855 ctx.update("by_domain");
856 for (domain, entry) in &by_domain {
857 ctx.update(domain);
858 entry.hash_into(&mut ctx);
859 }
860 ctx.update("by_provider");
861 for (provider, prov) in &by_provider {
862 ctx.update(provider);
863 prov.hash_into(&mut ctx);
864 }
865 ctx.update("warnings");
866 for warn in &collector.warnings {
867 ctx.update(warn);
868 }
869 ctx.update("errors");
870 for err in &collector.errors {
871 ctx.update(err);
872 }
873 let hash = ctx.finalize();
874 let hash = data_encoding::HEXLOWER.encode(&hash);
875
876 Ok(Self {
877 inner: Arc::new(ShapingInner {
878 by_site,
879 by_domain,
880 by_provider,
881 warnings: collector.warnings,
882 errors: collector.errors,
883 hash,
884 }),
885 })
886 }
887
888 async fn get_egress_path_config(
889 &self,
890 domain: &str,
891 egress_source: &str,
892 site_name: &str,
893 ) -> PartialEntry {
894 self.inner
895 .get_egress_path_config(domain, egress_source, site_name)
896 .await
897 }
898
899 pub async fn get_egress_path_config_value(
900 &self,
901 domain: &str,
902 egress_source: &str,
903 site_name: &str,
904 ) -> anyhow::Result<serde_json::Value> {
905 let partial = self
906 .get_egress_path_config(domain, egress_source, site_name)
907 .await;
908 Ok(serde_json::to_value(&partial)?)
909 }
910
911 pub fn get_errors(&self) -> &[String] {
912 &self.inner.errors
913 }
914
915 pub fn get_warnings(&self) -> &[String] {
916 &self.inner.warnings
917 }
918
919 pub async fn match_rules(&self, record: &JsonLogRecord) -> anyhow::Result<Vec<Rule>> {
920 self.inner.match_rules(record).await
921 }
922
923 pub fn get_referenced_sources(&self) -> BTreeMap<String, Vec<String>> {
924 let mut result = BTreeMap::new();
925
926 for (site_name, site) in &self.inner.by_site {
927 for source_name in site.sources.keys() {
928 result
929 .entry(source_name.to_string())
930 .or_insert(vec![])
931 .push(format!("site:{site_name}"));
932 }
933 }
934 for (domain_name, domain) in &self.inner.by_domain {
935 for source_name in domain.sources.keys() {
936 result
937 .entry(source_name.to_string())
938 .or_insert(vec![])
939 .push(format!("domain:{domain_name}"));
940 }
941 }
942
943 result
944 }
945
946 pub fn hash(&self) -> String {
947 self.inner.hash.clone()
948 }
949}
950
/// Exposes `Shaping` to Lua.
///
/// Registered methods: `get_egress_path_config`, `get_errors`,
/// `get_warnings`, `get_referenced_sources`, `match_rules` and `hash`, in
/// addition to whatever `mod_memoize::Memoized::impl_memoize` registers.
#[cfg(feature = "lua")]
impl LuaUserData for Shaping {
    fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
        mod_memoize::Memoized::impl_memoize(methods);

        // Only the merged parameter table is returned to Lua, not the whole
        // entry.
        methods.add_async_method(
            "get_egress_path_config",
            |lua, this, (domain, egress_source, site_name): (String, String, String)| async move {
                let params = this
                    .get_egress_path_config(&domain, &egress_source, &site_name)
                    .await;
                lua.to_value_with(&params.params, serialize_options())
            },
        );

        methods.add_method("get_errors", move |_lua, this, ()| {
            let errors: Vec<String> = this.get_errors().iter().map(|s| s.to_string()).collect();
            Ok(errors)
        });

        methods.add_method("get_warnings", move |_lua, this, ()| {
            let warnings: Vec<String> = this.get_warnings().iter().map(|s| s.to_string()).collect();
            Ok(warnings)
        });

        methods.add_method("get_referenced_sources", move |_lua, this, ()| {
            Ok(this.get_referenced_sources())
        });

        // Takes a log record as a Lua value and returns the matching rules
        // as Lua values.
        methods.add_async_method("match_rules", |lua, this, record: mlua::Value| async move {
            let record: JsonLogRecord = lua.from_value(record)?;
            let rules = this.match_rules(&record).await.map_err(any_err)?;
            let mut result = vec![];
            for rule in rules {
                result.push(lua.to_value(&rule)?);
            }
            Ok(result)
        });

        methods.add_method("hash", move |_, this, ()| Ok(this.hash()));
    }
}
992
/// A shaping entry whose raw tables have been interpreted as
/// `EgressPathConfig` values.
#[derive(Default, Debug)]
pub struct MergedEntry {
    /// The entry's base parameters.
    pub params: EgressPathConfig,
    /// Per-source parameters, keyed by source name.
    pub sources: OrderMap<String, EgressPathConfig>,
    /// The entry's automation rules.
    pub automation: Vec<Rule>,
}
999
/// The contents of a single shaping source file.
#[cfg(feature = "lua")]
#[derive(Deserialize, Serialize, Debug, Clone, Default)]
struct ShapingFile {
    /// The optional `default` block.
    pub default: Option<PartialEntry>,
    /// Every other top-level block, keyed by its name; flattened, so these
    /// sit alongside `default` and `provider` in the file.
    #[serde(flatten, default)]
    pub domains: OrderMap<String, PartialEntry>,
    /// Provider blocks, keyed by provider name.
    #[serde(default)]
    pub provider: OrderMap<String, ProviderEntry>,
}
1009
/// A domain (or site) block as read from a shaping file, before its
/// parameters have been interpreted as an `EgressPathConfig`.
#[cfg(feature = "lua")]
#[derive(Deserialize, Serialize, Debug, Clone, Default)]
struct PartialEntry {
    /// Name of the block; never (de)serialized, filled in while merging.
    #[serde(skip)]
    pub domain_name: Option<String>,

    /// All keys of the block not claimed by one of the other fields.
    #[serde(flatten)]
    pub params: toml::Table,

    /// Whether the block is filed under the site name that the domain's MX
    /// resolves to rather than under the domain name. The value used when
    /// omitted comes from the `default_true` helper.
    #[serde(default = "default_true")]
    pub mx_rollup: bool,

    /// Test-only switch: file the block under its own name in the per-site
    /// map without resolving anything.
    #[cfg(test)]
    #[serde(default)]
    pub _treat_domain_name_as_site_name: bool,

    /// When `true`, merging this block over an existing one replaces the
    /// existing params, automation and sources instead of merging into them.
    #[serde(default)]
    pub replace_base: bool,

    /// Automation rules attached to the block.
    #[serde(default)]
    pub automation: Vec<Rule>,

    /// Per-source parameter overrides, keyed by source name.
    #[serde(default)]
    pub sources: OrderMap<String, toml::Table>,
}
1038
/// A provider block as read from a shaping file.
#[cfg(feature = "lua")]
#[derive(Deserialize, Serialize, Debug, Clone, Default)]
pub struct ProviderEntry {
    /// Name of the provider; never (de)serialized, filled in while merging.
    #[serde(skip, default)]
    pub provider_name: String,

    /// Optional connection limit; when set, it is emitted as an entry under
    /// `additional_connection_limits`.
    #[serde(default)]
    pub provider_connection_limit: Option<LimitSpec>,

    /// Optional message rate; when set, it is emitted as an entry under
    /// `additional_message_rate_throttles`.
    #[serde(default)]
    pub provider_max_message_rate: Option<ThrottleSpec>,

    /// The rules deciding which domains belong to this provider; spelled
    /// `match` in the file.
    #[serde(default, rename = "match")]
    pub matches: Vec<ProviderMatch>,

    /// When `true`, merging this block over an existing one replaces the
    /// existing data instead of merging into it.
    #[serde(default)]
    pub replace_base: bool,

    /// All keys of the block not claimed by one of the other fields.
    #[serde(flatten)]
    pub params: toml::Table,

    /// Automation rules attached to the provider.
    #[serde(default)]
    pub automation: Vec<Rule>,

    /// Per-source parameter overrides, keyed by source name.
    #[serde(default)]
    pub sources: OrderMap<String, toml::Table>,
}
1066
1067#[cfg(feature = "lua")]
/// Returns `true` when `candidate`, with one trailing `.` removed if
/// present, ends with `suffix`.
///
/// The comparison is an exact, case-sensitive string comparison.
fn suffix_matches(candidate: &str, suffix: &str) -> bool {
    let trimmed = match candidate.strip_suffix('.') {
        Some(without_dot) => without_dot,
        None => candidate,
    };
    trimmed.ends_with(suffix)
}
1077
/// Exercises `suffix_matches`: exact matches, a suffix longer than the
/// candidate, a non-ASCII candidate, and differences in letter case.
#[cfg(feature = "lua")]
#[cfg(test)]
#[test]
fn test_suffix_matches() {
    assert!(suffix_matches("a", "a"));
    assert!(suffix_matches("foo.com", "foo.com"));
    assert!(!suffix_matches("foo.com", ".foo.com"));
    assert!(!suffix_matches("foo.com", "longer.com"));
    // Non-ASCII candidate with a suffix that does not match.
    assert!(!suffix_matches("réputation.net", ".mx.microsoft"));

    // Letter case is significant.
    assert!(!suffix_matches("foo.Com", ".com"));
    assert!(!suffix_matches("foo.Cam", ".com"));
}
1094
1095#[cfg(feature = "lua")]
/// Returns `true` when `candidate`, with one trailing `.` removed if
/// present, is exactly equal to `name`.
fn host_matches(candidate: &str, name: &str) -> bool {
    let trimmed = match candidate.strip_suffix('.') {
        Some(without_dot) => without_dot,
        None => candidate,
    };
    trimmed == name
}
1102
/// Exercises `host_matches`: an exact match, a candidate carrying a trailing
/// dot, and a different host name.
#[cfg(feature = "lua")]
#[cfg(test)]
#[test]
fn test_host_matches() {
    assert!(host_matches("foo.com", "foo.com"));
    // The trailing dot on the candidate is ignored.
    assert!(host_matches("foo.com.", "foo.com"));
    assert!(!host_matches("foo.com", "notfoo.com"));
}
1111
1112#[cfg(feature = "lua")]
1113impl ProviderEntry {
    /// Decides whether `domain` belongs to this provider.
    ///
    /// `DomainSuffix` rules are checked against `domain` itself and any
    /// single match is enough. Otherwise, if the provider has `HostName` or
    /// `MXSuffix` rules, the MX for `domain` is resolved and every MX host
    /// must match at least one of those rules. A failed MX resolution counts
    /// as no match.
    async fn domain_matches(&self, domain: &str) -> bool {
        // Set when there is at least one rule that needs the MX hosts.
        let mut need_mx = false;

        tracing::trace!(
            "ProviderEntry::domain_matches({domain}) vs {:?}",
            self.matches
        );

        for rule in &self.matches {
            match rule {
                ProviderMatch::DomainSuffix(suffix) => {
                    if suffix_matches(domain, suffix) {
                        tracing::trace!("{domain} suffix matches {suffix}");
                        return true;
                    }
                }
                ProviderMatch::HostName(_) | ProviderMatch::MXSuffix(_) => {
                    need_mx = true;
                }
            }
        }

        // Skip the MX resolution when no rule could make use of it.
        if !need_mx {
            return false;
        }

        match MailExchanger::resolve(domain).await {
            Err(err) => {
                tracing::trace!(
                    "Error resolving MX for {domain}: {err:#}. \
                     Provider {} match rules will be ignored",
                    self.provider_name
                );
                false
            }
            Ok(mx) => {
                tracing::trace!("Consider MXSuffix rules");
                // NOTE(review): when `mx.hosts` is empty this loop never
                // runs and the function returns `true`; confirm that this is
                // the intended result for a domain without MX hosts.
                for host in &mx.hosts {
                    let mut matched = false;

                    for rule in &self.matches {
                        match rule {
                            ProviderMatch::MXSuffix(suffix) => {
                                tracing::trace!("suffix={suffix} vs host {host}");
                                if suffix_matches(host, suffix) {
                                    matched = true;
                                    break;
                                }
                            }
                            ProviderMatch::HostName(name) => {
                                if host_matches(host, name) {
                                    matched = true;
                                    break;
                                }
                            }
                            ProviderMatch::DomainSuffix(_) => {}
                        }
                    }

                    // A single host that no rule covers rules the provider out.
                    if !matched {
                        tracing::trace!("host didn't match any of these rules");
                        return false;
                    }
                }

                true
            }
        }
    }
1195
1196 fn merge_from(&mut self, mut other: Self) {
1197 if other.replace_base {
1198 self.provider_connection_limit = other.provider_connection_limit;
1199 self.matches = other.matches;
1200 self.params = other.params;
1201 self.sources = other.sources;
1202 self.automation = other.automation;
1203 } else {
1204 if other.provider_connection_limit.is_some() {
1205 self.provider_connection_limit = other.provider_connection_limit;
1206 }
1207
1208 toml_table_merge_from(&mut self.params, &other.params);
1209
1210 for (source, tbl) in other.sources {
1211 match self.sources.get_mut(&source) {
1212 Some(existing) => {
1213 toml_table_merge_from(existing, &tbl);
1214 }
1215 None => {
1216 self.sources.insert(source, tbl);
1217 }
1218 }
1219 }
1220
1221 self.matches.append(&mut other.matches);
1222 self.automation.append(&mut other.automation);
1223 }
1224 }
1225
1226 fn apply_provider_params_to(&self, source: &str, target: &mut toml::Table) {
1227 let mut implied = toml::Table::new();
1228 implied.insert(
1229 "provider_name".to_string(),
1230 toml::Value::String(self.provider_name.to_string()),
1231 );
1232
1233 if let Some(limit) = &self.provider_connection_limit {
1234 let mut limits = toml::Table::new();
1235 limits.insert(
1236 format!("shaping-provider-{}-{source}-limit", self.provider_name),
1237 toml::Value::String(limit.to_string()),
1238 );
1239 implied.insert(
1240 "additional_connection_limits".to_string(),
1241 toml::Value::Table(limits),
1242 );
1243 }
1244 if let Some(rate) = &self.provider_max_message_rate {
1245 let rate = rate.as_string();
1246 let mut limits = toml::Table::new();
1247 limits.insert(
1248 format!("shaping-provider-{}-{source}-rate", self.provider_name),
1249 rate.into(),
1250 );
1251 implied.insert(
1252 "additional_message_rate_throttles".to_string(),
1253 toml::Value::Table(limits),
1254 );
1255 }
1256
1257 if let Some(rate) = target.remove("provider_source_selection_rate") {
1258 let mut limits = toml::Table::new();
1259 limits.insert(
1260 format!(
1261 "shaping-provider-{}-{source}-selection-rate",
1262 self.provider_name
1263 ),
1264 rate,
1265 );
1266 implied.insert(
1267 "additional_source_selection_rates".to_string(),
1268 toml::Value::Table(limits),
1269 );
1270 }
1271
1272 toml_table_merge_from(target, &implied);
1273 }
1274
1275 fn finish_params(&self) -> anyhow::Result<MergedEntry> {
1276 let provider_name = &self.provider_name;
1277
1278 let params = EgressPathConfig::deserialize(self.params.clone()).with_context(|| {
1279 format!(
1280 "interpreting provider '{provider_name}' params {:#?} as EgressPathConfig",
1281 self.params
1282 )
1283 })?;
1284 let mut sources = OrderMap::new();
1285
1286 for (source, params) in &self.sources {
1287 let mut params = params.clone();
1288 params.remove("provider_source_selection_rate");
1294 sources.insert(
1295 source.clone(),
1296 EgressPathConfig::deserialize(params.clone()).with_context(|| {
1297 format!("interpreting provider '{provider_name}' source '{source}' {params:#} as EgressPathConfig")
1298 })?,
1299 );
1300 }
1301
1302 Ok(MergedEntry {
1303 params,
1304 sources,
1305 automation: self.automation.clone(),
1306 })
1307 }
1308
1309 fn hash_into(&self, ctx: &mut Sha256) {
1310 ctx.update(&self.provider_name);
1311 ctx.update(serde_json::to_string(self).unwrap_or_else(|_| String::new()));
1312 }
1313}
1314
/// A single rule deciding whether a domain belongs to a provider (see
/// `ProviderEntry::domain_matches`).
#[cfg(feature = "lua")]
#[derive(Deserialize, Serialize, Debug, Clone)]
pub enum ProviderMatch {
    /// Compared, as a suffix, against each MX host of the domain.
    MXSuffix(String),
    /// Compared, as a suffix, against the domain name itself.
    DomainSuffix(String),
    /// Compared, as a whole name, against each MX host of the domain.
    HostName(String),
}
1322
/// Overlays every entry of `source` onto `tbl`.
///
/// For the keys accepted by `is_mergeable`, when both the existing value and
/// the incoming value are tables, the incoming inner entries are inserted
/// into the existing table (overriding inner keys that are already present).
/// In every other case the incoming value is cloned into `tbl`, replacing
/// whatever was stored under that key.
#[cfg(feature = "lua")]
fn toml_table_merge_from(tbl: &mut toml::Table, source: &toml::Table) {
    /// Keys whose table values are combined entry by entry instead of being
    /// replaced wholesale.
    fn is_mergeable(name: &str) -> bool {
        matches!(
            name,
            "additional_connection_limits"
                | "additional_message_rate_throttles"
                | "additional_source_selection_rates"
        )
    }

    for (key, incoming) in source {
        let merged_in_place = is_mergeable(key)
            && match (tbl.get_mut(key), incoming.as_table()) {
                (Some(toml::Value::Table(existing)), Some(additions)) => {
                    for (inner_key, inner_value) in additions {
                        existing.insert(inner_key.clone(), inner_value.clone());
                    }
                    true
                }
                _ => false,
            };

        if !merged_in_place {
            tbl.insert(key.clone(), incoming.clone());
        }
    }
}
1350
#[cfg(feature = "lua")]
impl PartialEntry {
    /// Folds `other` into `self`.
    ///
    /// When `other.replace_base` is set, the params, automation rules and
    /// per-source tables of `self` are replaced by those of `other`.
    /// Otherwise the params are overlaid via `toml_table_merge_from`, each
    /// per-source table is overlaid onto the matching existing table (or
    /// inserted when there is none), and the automation rules of `other` are
    /// appended after the existing ones.
    fn merge_from(&mut self, mut other: Self) {
        if other.replace_base {
            self.params = other.params;
            self.automation = other.automation;
            self.sources = other.sources;
            return;
        }

        toml_table_merge_from(&mut self.params, &other.params);

        for (source, incoming) in other.sources {
            if let Some(existing) = self.sources.get_mut(&source) {
                toml_table_merge_from(existing, &incoming);
            } else {
                self.sources.insert(source, incoming);
            }
        }

        self.automation.append(&mut other.automation);
    }

    /// Consumes the entry and converts its raw TOML tables into typed
    /// configuration.
    ///
    /// The `params` table and every per-source table are each deserialized as
    /// an [`EgressPathConfig`]; the automation rules are moved into the
    /// returned [`MergedEntry`] unchanged.
    ///
    /// # Errors
    ///
    /// Returns an error, with context naming the domain (empty when no domain
    /// name is set) and, for per-source tables, the source, when a table
    /// cannot be interpreted as an `EgressPathConfig`.
    fn finish(self) -> anyhow::Result<MergedEntry> {
        // Only used to label error messages.
        let domain = self.domain_name.unwrap_or_default();

        let base_config =
            EgressPathConfig::deserialize(self.params.clone()).with_context(|| {
                format!(
                    "interpreting domain '{domain}' params {:#?} as EgressPathConfig",
                    self.params
                )
            })?;

        let mut per_source = OrderMap::new();
        for (source, table) in self.sources {
            let config = EgressPathConfig::deserialize(table.clone()).with_context(|| {
                format!("interpreting domain '{domain}' source '{source}' {table:#} as EgressPathConfig")
            })?;
            per_source.insert(source, config);
        }

        Ok(MergedEntry {
            params: base_config,
            sources: per_source,
            automation: self.automation,
        })
    }

    /// Feeds this entry into the running digest `ctx`.
    ///
    /// The domain name, when present, is hashed first, followed by the JSON
    /// serialization of the whole entry. If serialization fails, an empty
    /// string is hashed in its place.
    fn hash_into(&self, ctx: &mut Sha256) {
        if let Some(name) = &self.domain_name {
            ctx.update(name);
        }
        ctx.update(serde_json::to_string(self).unwrap_or_default());
    }
}
1410
1411fn one_or_many<'de, T, D>(deserializer: D, expecting: &str) -> Result<Vec<T>, D::Error>
1412where
1413 T: Deserialize<'de>,
1414 D: Deserializer<'de>,
1415{
1416 let result: Result<Vec<T>, _> =
1417 OneOrMany::<serde_with::Same, PreferOne>::deserialize_as(deserializer);
1418 match result {
1419 Ok(r) => Ok(r),
1420 Err(err) => Err(serde::de::Error::custom(format!(
1421 "{expecting}.\nThe underlying error message is:\n{err:#}"
1422 ))),
1423 }
1424}
1425
1426fn one_or_many_action<'de, D>(deserializer: D) -> Result<Vec<Action>, D::Error>
1427where
1428 D: Deserializer<'de>,
1429{
1430 one_or_many(
1431 deserializer,
1432 "\"action\" field expected either a single Action or an array of Actions",
1433 )
1434}
1435
1436fn regex_string_or_array<'de, D>(deserializer: D) -> Result<Vec<Regex>, D::Error>
1437where
1438 D: Deserializer<'de>,
1439{
1440 string_or_array(
1441 deserializer,
1442 "regex string or array of regex strings for field regex",
1443 )
1444}
1445
1446fn string_or_array<'de, T, D>(deserializer: D, expecting: &'static str) -> Result<Vec<T>, D::Error>
1447where
1448 T: Deserialize<'de> + TryFrom<String>,
1449 <T as TryFrom<String>>::Error: std::fmt::Debug,
1450 D: Deserializer<'de>,
1451{
1452 struct StringOrArray<T>(PhantomData<fn() -> T>, &'static str);
1458
1459 impl<'de, T> Visitor<'de> for StringOrArray<T>
1460 where
1461 T: Deserialize<'de> + TryFrom<String>,
1462 <T as TryFrom<String>>::Error: std::fmt::Debug,
1463 {
1464 type Value = Vec<T>;
1465
1466 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
1467 formatter.write_str(self.1)
1468 }
1469
1470 fn visit_str<E>(self, value: &str) -> Result<Vec<T>, E>
1471 where
1472 E: serde::de::Error,
1473 {
1474 Ok(vec![
1475 T::try_from(value.to_string()).map_err(|e| E::custom(format!("{e:?}")))?
1476 ])
1477 }
1478
1479 fn visit_seq<S>(self, seq: S) -> Result<Vec<T>, S::Error>
1480 where
1481 S: SeqAccess<'de>,
1482 {
1483 Deserialize::deserialize(serde::de::value::SeqAccessDeserializer::new(seq))
1484 }
1485 }
1486
1487 deserializer.deserialize_any(StringOrArray(PhantomData, expecting))
1488}
1489
/// Returns `true`.
///
/// NOTE(review): presumably referenced as a serde `default = "..."` helper for
/// boolean fields that default to enabled -- no caller is visible in this part
/// of the file; confirm.
#[cfg(feature = "lua")]
fn default_true() -> bool {
    true
}
1494
/// Registers the `load` function on the `shaping` Lua sub-module.
///
/// `load(paths, options)` is an async Lua function. It converts the optional
/// `options` value from Lua (falling back to the default options when it is
/// absent), merges the shaping files named by `paths`, and returns the
/// resulting [`Shaping`]. A merge failure is converted to a Lua error with
/// `any_err`.
///
/// # Errors
///
/// Returns an error if the sub-module cannot be obtained or the function
/// cannot be created or stored.
#[cfg(feature = "lua")]
pub fn register(lua: &mlua::Lua) -> anyhow::Result<()> {
    let shaping_mod = config::get_or_create_sub_module(lua, "shaping")?;

    let load = lua.create_async_function(
        move |lua, (paths, options): (Vec<String>, Option<mlua::Value>)| async move {
            let options = if let Some(value) = options {
                lua.from_value(value)?
            } else {
                Default::default()
            };

            Shaping::merge_files(&paths, &options)
                .await
                .map_err(any_err)
        },
    )?;
    shaping_mod.set("load", load)?;

    Ok(())
}
1517
1518#[cfg(test)]
1519mod test {
1520 use super::*;
1521 use kumo_log_types::RecordType;
1522 use rfc5321::Response;
1523 use std::io::Write;
1524 use tempfile::NamedTempFile;
1525 use uuid::Uuid;
1526
1527 async fn make_shaping_configs(inputs: &[&str]) -> Shaping {
1528 let mut files = vec![];
1529 let mut file_names = vec![];
1530
1531 for (i, content) in inputs.iter().enumerate() {
1532 let mut shaping_file = NamedTempFile::with_prefix(format!("file{i}")).unwrap();
1533 shaping_file.write_all(content.as_bytes()).unwrap();
1534 file_names.push(shaping_file.path().to_str().unwrap().to_string());
1535 files.push(shaping_file);
1536 }
1537
1538 Shaping::merge_files(&file_names, &ShapingMergeOptions::default())
1539 .await
1540 .unwrap()
1541 }
1542
1543 #[tokio::test]
1544 async fn test_merge_additional() {
1545 let shaping = make_shaping_configs(&[
1546 r#"
1547["example.com"]
1548mx_rollup = false
1549additional_connection_limits = {"first"=10}
1550 "#,
1551 r#"
1552["example.com"]
1553mx_rollup = false
1554additional_connection_limits = {"second"=32}
1555additional_message_rate_throttles = {"second"="100/hr"}
1556 "#,
1557 ])
1558 .await;
1559
1560 let resolved = shaping
1561 .get_egress_path_config("example.com", "invalid.source", "invalid.site")
1562 .await
1563 .finish()
1564 .unwrap();
1565
1566 k9::snapshot!(
1567 resolved.params.additional_connection_limits,
1568 r#"
1569{
1570 "first": 10,
1571 "second": 32,
1572}
1573"#
1574 );
1575 k9::snapshot!(
1576 resolved.params.additional_message_rate_throttles,
1577 r#"
1578{
1579 "second": 100/h,
1580}
1581"#
1582 );
1583 }
1584
1585 #[tokio::test]
1586 async fn test_provider_multi_hostname() {
1587 let shaping = make_shaping_configs(&[r#"
1588[provider."yahoo"]
1589match=[{HostName="mta5.am0.yahoodns.net"},{HostName="mta6.am0.yahoodns.net"},{HostName="mta7.am0.yahoodns.net"}]
1590enable_tls = "Required"
1591 "#])
1592 .await;
1593
1594 let resolved = shaping
1595 .get_egress_path_config("yahoo.com", "invalid.source", "invalid.site")
1596 .await
1597 .finish()
1598 .unwrap();
1599
1600 k9::assert_equal!(
1601 resolved.params.enable_tls,
1602 crate::egress_path::Tls::Required
1603 );
1604 k9::assert_equal!(resolved.params.provider_name.unwrap(), "yahoo");
1605 }
1606
1607 #[tokio::test]
1608 async fn test_provider_multi_suffix() {
1609 let shaping = make_shaping_configs(&[r#"
1610[provider."yahoo"]
1611match=[{MXSuffix="mta5.am0.yahoodns.net"},{MXSuffix="mta6.am0.yahoodns.net"},{MXSuffix="mta7.am0.yahoodns.net"}]
1612enable_tls = "Required"
1613 "#])
1614 .await;
1615
1616 let resolved = shaping
1617 .get_egress_path_config("yahoo.com", "invalid.source", "invalid.site")
1618 .await
1619 .finish()
1620 .unwrap();
1621
1622 k9::assert_equal!(
1623 resolved.params.enable_tls,
1624 crate::egress_path::Tls::Required
1625 );
1626 k9::assert_equal!(resolved.params.provider_name.unwrap(), "yahoo");
1627 }
1628
1629 #[tokio::test]
1630 async fn test_provider() {
1631 let shaping = make_shaping_configs(&[r#"
1632[provider."Office 365"]
1633match=[{MXSuffix=".olc.protection.outlook.com"},{DomainSuffix=".outlook.com"}]
1634enable_tls = "Required"
1635provider_connection_limit = 10
1636provider_max_message_rate = "120/s"
1637
1638[provider."Office 365".sources."new-source"]
1639provider_source_selection_rate = "500/d,max_burst=1"
1640 "#])
1641 .await;
1642
1643 let resolved = shaping
1644 .get_egress_path_config("outlook.com", "invalid.source", "invalid.site")
1645 .await
1646 .finish()
1647 .unwrap();
1648
1649 k9::assert_equal!(
1650 resolved.params.enable_tls,
1651 crate::egress_path::Tls::Required
1652 );
1653 k9::assert_equal!(resolved.params.provider_name.unwrap(), "Office 365");
1654
1655 k9::snapshot!(
1656 resolved.params.additional_connection_limits,
1657 r#"
1658{
1659 "shaping-provider-Office 365-invalid.source-limit": 10,
1660}
1661"#
1662 );
1663 k9::snapshot!(
1664 resolved.params.additional_message_rate_throttles,
1665 r#"
1666{
1667 "shaping-provider-Office 365-invalid.source-rate": 120/s,
1668}
1669"#
1670 );
1671 assert!(resolved.params.source_selection_rate.is_none());
1672 assert!(resolved.params.additional_source_selection_rates.is_empty());
1673
1674 let resolved = shaping
1675 .get_egress_path_config("outlook.com", "new-source", "invalid.site")
1676 .await
1677 .finish()
1678 .unwrap();
1679 assert!(resolved.params.source_selection_rate.is_none());
1680 k9::snapshot!(
1681 resolved.params.additional_source_selection_rates,
1682 r#"
1683{
1684 "shaping-provider-Office 365-new-source-selection-rate": 500/d,max_burst=1,
1685}
1686"#
1687 );
1688 }
1689
    /// Exercises `match_rules` against automation rules declared on the
    /// `default` block, on a domain treated as a site name, on a domain with
    /// `mx_rollup = false`, and on a provider (whose rule sets
    /// `match_internal = true`).
    #[tokio::test]
    async fn test_rule_matching() {
        let shaping = make_shaping_configs(&[r#"
[["default".automation]]
regex="default"
action = {SetConfig={name="connection_limit", value=1}}
duration = "1hr"

["fake.site"]
_treat_domain_name_as_site_name = true

[["fake.site".automation]]
regex="fake_rollup"
action = {SetConfig={name="connection_limit", value=2}}
duration = "1hr"

["woot.provider"]
mx_rollup = false

[["woot.provider".automation]]
regex="woot_domain"
action = {SetConfig={name="connection_limit", value=2}}
duration = "1hr"

[provider."provider"]
match=[{DomainSuffix=".provider"}]

[[provider."provider".automation]]
regex="provider"
action = {SetConfig={name="connection_limit", value=3}}
duration = "1hr"
match_internal = true

"#])
        .await;

        // Surface any load-time warnings in the test output.
        eprintln!("{:?}", shaping.inner.warnings);

        /// Builds a `TransientFailure` log record with response code 400 and
        /// the given response `content`, `recipient` and `site`; every other
        /// field is empty, `None` or defaulted.
        fn make_record(content: &str, recipient: &str, site: &str) -> JsonLogRecord {
            JsonLogRecord {
                kind: RecordType::TransientFailure,
                id: String::new(),
                sender: String::new(),
                recipient: recipient.to_string(),
                queue: String::new(),
                site: site.to_string(),
                size: 0,
                response: Response {
                    code: 400,
                    command: None,
                    enhanced_code: None,
                    content: content.to_string(),
                },
                peer_address: None,
                timestamp: Default::default(),
                created: Default::default(),
                num_attempts: 1,
                bounce_classification: Default::default(),
                egress_pool: None,
                egress_source: None,
                source_address: None,
                feedback_report: None,
                meta: Default::default(),
                headers: Default::default(),
                delivery_protocol: None,
                reception_protocol: None,
                nodeid: Uuid::default(),
                tls_cipher: None,
                tls_protocol_version: None,
                tls_peer_subject_name: None,
                provider_name: None,
                session_id: None,
            }
        }

        // A domain with no block of its own falls through to the default rule.
        let matches = shaping
            .match_rules(&make_record("default", "user@example.com", "dummy_site"))
            .await
            .unwrap();
        k9::assert_equal!(
            matches[0].regex[0].to_string(),
            "default",
            "matches against default automation rule"
        );

        // Same rule text, but prefixed with "KumoMTA internal:": the default
        // rule (which does not set `match_internal`) is expected not to match.
        let matches = shaping
            .match_rules(&make_record(
                "KumoMTA internal: default",
                "user@example.com",
                "dummy_site",
            ))
            .await
            .unwrap();
        assert!(matches.is_empty(), "internal bounce should not match");

        // Rule attached to a domain block that sets `mx_rollup = false`.
        let matches = shaping
            .match_rules(&make_record(
                "woot_domain",
                "user@woot.provider",
                "dummy_site",
            ))
            .await
            .unwrap();
        k9::assert_equal!(
            matches[0].regex[0].to_string(),
            "woot_domain",
            "matches against domain rule with mx_rollup=false"
        );

        // Rule attached to "fake.site", selected through the record's site.
        let matches = shaping
            .match_rules(&make_record("fake_rollup", "user@fake.rollup", "fake.site"))
            .await
            .unwrap();
        k9::assert_equal!(
            matches[0].regex[0].to_string(),
            "fake_rollup",
            "matches against domain rule with mx_rollup=true"
        );

        // Rule attached to the provider that matches the ".provider" suffix.
        let matches = shaping
            .match_rules(&make_record("provider", "user@woot.provider", "dummy_site"))
            .await
            .unwrap();
        k9::assert_equal!(
            matches[0].regex[0].to_string(),
            "provider",
            "matches against provider rule"
        );

        // The provider rule sets `match_internal = true`, so an internal
        // response is expected to match it.
        let matches = shaping
            .match_rules(&make_record(
                "KumoMTA internal: provider",
                "user@woot.provider",
                "dummy_site",
            ))
            .await
            .unwrap();
        k9::assert_equal!(
            matches[0].regex[0].to_string(),
            "provider",
            "internal response matches against provider rule"
        );
    }
1833
    /// Snapshot test of fully resolved configurations built from a single
    /// file with a `default` block plus several per-domain blocks.
    ///
    /// Three lookups are snapshotted: a domain with no block of its own,
    /// `example.com` (which also has a per-source block), and `yahoo.com`
    /// (which adds its own automation rule).
    #[tokio::test]
    async fn test_defaults() {
        let shaping = make_shaping_configs(&[
            r#"
["default"]
connection_limit = 10
max_connection_rate = "100/min"
max_deliveries_per_connection = 100
max_message_rate = "100/s"
idle_timeout = "60s"
data_timeout = "30s"
data_dot_timeout = "60s"
enable_tls = "Opportunistic"
consecutive_connection_failures_before_delay = 100

[["default".automation]]
regex=[
    '/Messages from \d+\.\d+\.\d+\.\d+ temporarily deferred/',
    '/All messages from \d+\.\d+\.\d+\.\d+ will be permanently deferred/',
    '/has been temporarily rate limited due to IP reputation/',
    '/Unfortunately, messages from \d+\.\d+\.\d+\.\d+ weren.t sent/',
    '/Server busy\. Please try again later from/'
]
action = [
    {SetConfig={name="max_message_rate", value="1/minute"}},
    {SetConfig={name="connection_limit", value=1}}
]
duration = "90m"

[["default".automation]]
regex="KumoMTA internal: failed to connect to any candidate hosts: All failures are related to OpportunisticInsecure STARTTLS. Consider setting enable_tls=Disabled for this site"
action = {SetConfig={name="enable_tls", value="Disabled"}}
duration = "30 days"

["gmail.com"]
max_deliveries_per_connection = 50
connection_limit = 5
enable_tls = "Required"
consecutive_connection_failures_before_delay = 5

["yahoo.com"]
max_deliveries_per_connection = 20

[["yahoo.com".automation]]
regex = "\\[TS04\\]"
action = "Suspend"
duration = "2 hours"

["comcast.net"]
connection_limit = 25
max_deliveries_per_connection = 250
enable_tls = "Required"
idle_timeout = "30s"
consecutive_connection_failures_before_delay = 24

["mail.com"]
max_deliveries_per_connection = 100

["orange.fr"]
connection_limit = 3

["smtp.mailgun.com"]
connection_limit = 7000
max_deliveries_per_connection = 3

["example.com"]
mx_rollup = false
max_deliveries_per_connection = 100
connection_limit = 3

["example.com".sources."my source name"]
connection_limit = 5
    "#,
        ])
        .await;

        // A domain without its own block: only the `default` block applies.
        let default = shaping
            .get_egress_path_config("invalid.domain", "invalid.source", "invalid.site")
            .await
            .finish()
            .unwrap();
        k9::snapshot!(
            default,
            r#"
MergedEntry {
    params: EgressPathConfig {
        connection_limit: 10,
        additional_connection_limits: {},
        enable_tls: Opportunistic,
        enable_mta_sts: true,
        enable_dane: false,
        enable_pipelining: true,
        enable_rset: true,
        tls_prefer_openssl: false,
        openssl_cipher_list: None,
        openssl_cipher_suites: None,
        openssl_options: None,
        rustls_cipher_suites: [],
        client_timeouts: SmtpClientTimeouts {
            connect_timeout: 60s,
            banner_timeout: 60s,
            ehlo_timeout: 300s,
            mail_from_timeout: 300s,
            rcpt_to_timeout: 300s,
            data_timeout: 30s,
            data_dot_timeout: 60s,
            rset_timeout: 5s,
            idle_timeout: 60s,
            starttls_timeout: 5s,
            auth_timeout: 60s,
        },
        system_shutdown_timeout: None,
        max_ready: 1024,
        consecutive_connection_failures_before_delay: 100,
        smtp_port: 25,
        smtp_auth_plain_username: None,
        smtp_auth_plain_password: None,
        allow_smtp_auth_plain_without_tls: false,
        max_message_rate: Some(
            100/s,
        ),
        additional_message_rate_throttles: {},
        source_selection_rate: None,
        additional_source_selection_rates: {},
        max_connection_rate: Some(
            100/m,
        ),
        max_deliveries_per_connection: 100,
        prohibited_hosts: {
            "127.0.0.0/8",
            "::1",
        },
        skip_hosts: {},
        ehlo_domain: None,
        aggressive_connection_opening: false,
        refresh_interval: 60s,
        refresh_strategy: Ttl,
        dispatcher_wakeup_strategy: Aggressive,
        maintainer_wakeup_strategy: Aggressive,
        provider_name: None,
        remember_broken_tls: None,
        opportunistic_tls_reconnect_on_failed_handshake: false,
        use_lmtp: false,
        reconnect_strategy: ConnectNextHost,
        readyq_pool_name: None,
        low_memory_reduction_policy: ShrinkDataAndMeta,
        no_memory_reduction_policy: ShrinkDataAndMeta,
    },
    sources: {},
    automation: [
        Rule {
            regex: [
                Regex(
                    /Messages from \d+\.\d+\.\d+\.\d+ temporarily deferred/,
                ),
                Regex(
                    /All messages from \d+\.\d+\.\d+\.\d+ will be permanently deferred/,
                ),
                Regex(
                    /has been temporarily rate limited due to IP reputation/,
                ),
                Regex(
                    /Unfortunately, messages from \d+\.\d+\.\d+\.\d+ weren.t sent/,
                ),
                Regex(
                    /Server busy\. Please try again later from/,
                ),
            ],
            action: [
                SetConfig(
                    EgressPathConfigValue {
                        name: "max_message_rate",
                        value: HashableTomlValue {
                            value: String(
                                "1/minute",
                            ),
                        },
                    },
                ),
                SetConfig(
                    EgressPathConfigValue {
                        name: "connection_limit",
                        value: HashableTomlValue {
                            value: Integer(
                                1,
                            ),
                        },
                    },
                ),
            ],
            trigger: Immediate,
            duration: 5400s,
            was_rollup: false,
            match_internal: false,
        },
        Rule {
            regex: [
                Regex(
                    KumoMTA internal: failed to connect to any candidate hosts: All failures are related to OpportunisticInsecure STARTTLS. Consider setting enable_tls=Disabled for this site,
                ),
            ],
            action: [
                SetConfig(
                    EgressPathConfigValue {
                        name: "enable_tls",
                        value: HashableTomlValue {
                            value: String(
                                "Disabled",
                            ),
                        },
                    },
                ),
            ],
            trigger: Immediate,
            duration: 2592000s,
            was_rollup: false,
            match_internal: false,
        },
    ],
}
"#
        );

        // `example.com` overrides `connection_limit` and carries a per-source
        // block; the snapshot shows that source entry as its own
        // `EgressPathConfig`.
        let example_com = shaping
            .get_egress_path_config("example.com", "invalid.source", "invalid.site")
            .await
            .finish()
            .unwrap();
        k9::snapshot!(
            example_com,
            r#"
MergedEntry {
    params: EgressPathConfig {
        connection_limit: 3,
        additional_connection_limits: {},
        enable_tls: Opportunistic,
        enable_mta_sts: true,
        enable_dane: false,
        enable_pipelining: true,
        enable_rset: true,
        tls_prefer_openssl: false,
        openssl_cipher_list: None,
        openssl_cipher_suites: None,
        openssl_options: None,
        rustls_cipher_suites: [],
        client_timeouts: SmtpClientTimeouts {
            connect_timeout: 60s,
            banner_timeout: 60s,
            ehlo_timeout: 300s,
            mail_from_timeout: 300s,
            rcpt_to_timeout: 300s,
            data_timeout: 30s,
            data_dot_timeout: 60s,
            rset_timeout: 5s,
            idle_timeout: 60s,
            starttls_timeout: 5s,
            auth_timeout: 60s,
        },
        system_shutdown_timeout: None,
        max_ready: 1024,
        consecutive_connection_failures_before_delay: 100,
        smtp_port: 25,
        smtp_auth_plain_username: None,
        smtp_auth_plain_password: None,
        allow_smtp_auth_plain_without_tls: false,
        max_message_rate: Some(
            100/s,
        ),
        additional_message_rate_throttles: {},
        source_selection_rate: None,
        additional_source_selection_rates: {},
        max_connection_rate: Some(
            100/m,
        ),
        max_deliveries_per_connection: 100,
        prohibited_hosts: {
            "127.0.0.0/8",
            "::1",
        },
        skip_hosts: {},
        ehlo_domain: None,
        aggressive_connection_opening: false,
        refresh_interval: 60s,
        refresh_strategy: Ttl,
        dispatcher_wakeup_strategy: Aggressive,
        maintainer_wakeup_strategy: Aggressive,
        provider_name: None,
        remember_broken_tls: None,
        opportunistic_tls_reconnect_on_failed_handshake: false,
        use_lmtp: false,
        reconnect_strategy: ConnectNextHost,
        readyq_pool_name: None,
        low_memory_reduction_policy: ShrinkDataAndMeta,
        no_memory_reduction_policy: ShrinkDataAndMeta,
    },
    sources: {
        "my source name": EgressPathConfig {
            connection_limit: 5,
            additional_connection_limits: {},
            enable_tls: Opportunistic,
            enable_mta_sts: true,
            enable_dane: false,
            enable_pipelining: true,
            enable_rset: true,
            tls_prefer_openssl: false,
            openssl_cipher_list: None,
            openssl_cipher_suites: None,
            openssl_options: None,
            rustls_cipher_suites: [],
            client_timeouts: SmtpClientTimeouts {
                connect_timeout: 60s,
                banner_timeout: 60s,
                ehlo_timeout: 300s,
                mail_from_timeout: 300s,
                rcpt_to_timeout: 300s,
                data_timeout: 300s,
                data_dot_timeout: 300s,
                rset_timeout: 5s,
                idle_timeout: 5s,
                starttls_timeout: 5s,
                auth_timeout: 60s,
            },
            system_shutdown_timeout: None,
            max_ready: 1024,
            consecutive_connection_failures_before_delay: 100,
            smtp_port: 25,
            smtp_auth_plain_username: None,
            smtp_auth_plain_password: None,
            allow_smtp_auth_plain_without_tls: false,
            max_message_rate: None,
            additional_message_rate_throttles: {},
            source_selection_rate: None,
            additional_source_selection_rates: {},
            max_connection_rate: None,
            max_deliveries_per_connection: 1024,
            prohibited_hosts: {
                "127.0.0.0/8",
                "::1",
            },
            skip_hosts: {},
            ehlo_domain: None,
            aggressive_connection_opening: false,
            refresh_interval: 60s,
            refresh_strategy: Ttl,
            dispatcher_wakeup_strategy: Aggressive,
            maintainer_wakeup_strategy: Aggressive,
            provider_name: None,
            remember_broken_tls: None,
            opportunistic_tls_reconnect_on_failed_handshake: false,
            use_lmtp: false,
            reconnect_strategy: ConnectNextHost,
            readyq_pool_name: None,
            low_memory_reduction_policy: ShrinkDataAndMeta,
            no_memory_reduction_policy: ShrinkDataAndMeta,
        },
    },
    automation: [
        Rule {
            regex: [
                Regex(
                    /Messages from \d+\.\d+\.\d+\.\d+ temporarily deferred/,
                ),
                Regex(
                    /All messages from \d+\.\d+\.\d+\.\d+ will be permanently deferred/,
                ),
                Regex(
                    /has been temporarily rate limited due to IP reputation/,
                ),
                Regex(
                    /Unfortunately, messages from \d+\.\d+\.\d+\.\d+ weren.t sent/,
                ),
                Regex(
                    /Server busy\. Please try again later from/,
                ),
            ],
            action: [
                SetConfig(
                    EgressPathConfigValue {
                        name: "max_message_rate",
                        value: HashableTomlValue {
                            value: String(
                                "1/minute",
                            ),
                        },
                    },
                ),
                SetConfig(
                    EgressPathConfigValue {
                        name: "connection_limit",
                        value: HashableTomlValue {
                            value: Integer(
                                1,
                            ),
                        },
                    },
                ),
            ],
            trigger: Immediate,
            duration: 5400s,
            was_rollup: false,
            match_internal: false,
        },
        Rule {
            regex: [
                Regex(
                    KumoMTA internal: failed to connect to any candidate hosts: All failures are related to OpportunisticInsecure STARTTLS. Consider setting enable_tls=Disabled for this site,
                ),
            ],
            action: [
                SetConfig(
                    EgressPathConfigValue {
                        name: "enable_tls",
                        value: HashableTomlValue {
                            value: String(
                                "Disabled",
                            ),
                        },
                    },
                ),
            ],
            trigger: Immediate,
            duration: 2592000s,
            was_rollup: false,
            match_internal: false,
        },
    ],
}
"#
        );

        // `yahoo.com` overrides `max_deliveries_per_connection` and appends its
        // own `Suspend` rule after the default rules.
        // NOTE(review): the site name is passed explicitly here; presumably it
        // stands in for the rolled-up MX site name of yahoo.com -- confirm.
        let yahoo_com = shaping
            .get_egress_path_config(
                "yahoo.com",
                "invalid.source",
                "(mta5|mta6|mta7).am0.yahoodns.net",
            )
            .await
            .finish()
            .unwrap();
        k9::snapshot!(
            yahoo_com,
            r#"
MergedEntry {
    params: EgressPathConfig {
        connection_limit: 10,
        additional_connection_limits: {},
        enable_tls: Opportunistic,
        enable_mta_sts: true,
        enable_dane: false,
        enable_pipelining: true,
        enable_rset: true,
        tls_prefer_openssl: false,
        openssl_cipher_list: None,
        openssl_cipher_suites: None,
        openssl_options: None,
        rustls_cipher_suites: [],
        client_timeouts: SmtpClientTimeouts {
            connect_timeout: 60s,
            banner_timeout: 60s,
            ehlo_timeout: 300s,
            mail_from_timeout: 300s,
            rcpt_to_timeout: 300s,
            data_timeout: 30s,
            data_dot_timeout: 60s,
            rset_timeout: 5s,
            idle_timeout: 60s,
            starttls_timeout: 5s,
            auth_timeout: 60s,
        },
        system_shutdown_timeout: None,
        max_ready: 1024,
        consecutive_connection_failures_before_delay: 100,
        smtp_port: 25,
        smtp_auth_plain_username: None,
        smtp_auth_plain_password: None,
        allow_smtp_auth_plain_without_tls: false,
        max_message_rate: Some(
            100/s,
        ),
        additional_message_rate_throttles: {},
        source_selection_rate: None,
        additional_source_selection_rates: {},
        max_connection_rate: Some(
            100/m,
        ),
        max_deliveries_per_connection: 20,
        prohibited_hosts: {
            "127.0.0.0/8",
            "::1",
        },
        skip_hosts: {},
        ehlo_domain: None,
        aggressive_connection_opening: false,
        refresh_interval: 60s,
        refresh_strategy: Ttl,
        dispatcher_wakeup_strategy: Aggressive,
        maintainer_wakeup_strategy: Aggressive,
        provider_name: None,
        remember_broken_tls: None,
        opportunistic_tls_reconnect_on_failed_handshake: false,
        use_lmtp: false,
        reconnect_strategy: ConnectNextHost,
        readyq_pool_name: None,
        low_memory_reduction_policy: ShrinkDataAndMeta,
        no_memory_reduction_policy: ShrinkDataAndMeta,
    },
    sources: {},
    automation: [
        Rule {
            regex: [
                Regex(
                    /Messages from \d+\.\d+\.\d+\.\d+ temporarily deferred/,
                ),
                Regex(
                    /All messages from \d+\.\d+\.\d+\.\d+ will be permanently deferred/,
                ),
                Regex(
                    /has been temporarily rate limited due to IP reputation/,
                ),
                Regex(
                    /Unfortunately, messages from \d+\.\d+\.\d+\.\d+ weren.t sent/,
                ),
                Regex(
                    /Server busy\. Please try again later from/,
                ),
            ],
            action: [
                SetConfig(
                    EgressPathConfigValue {
                        name: "max_message_rate",
                        value: HashableTomlValue {
                            value: String(
                                "1/minute",
                            ),
                        },
                    },
                ),
                SetConfig(
                    EgressPathConfigValue {
                        name: "connection_limit",
                        value: HashableTomlValue {
                            value: Integer(
                                1,
                            ),
                        },
                    },
                ),
            ],
            trigger: Immediate,
            duration: 5400s,
            was_rollup: false,
            match_internal: false,
        },
        Rule {
            regex: [
                Regex(
                    KumoMTA internal: failed to connect to any candidate hosts: All failures are related to OpportunisticInsecure STARTTLS. Consider setting enable_tls=Disabled for this site,
                ),
            ],
            action: [
                SetConfig(
                    EgressPathConfigValue {
                        name: "enable_tls",
                        value: HashableTomlValue {
                            value: String(
                                "Disabled",
                            ),
                        },
                    },
                ),
            ],
            trigger: Immediate,
            duration: 2592000s,
            was_rollup: false,
            match_internal: false,
        },
        Rule {
            regex: [
                Regex(
                    \[TS04\],
                ),
            ],
            action: [
                Suspend,
            ],
            trigger: Immediate,
            duration: 7200s,
            was_rollup: false,
            match_internal: false,
        },
    ],
}
"#
        );
    }
2431
2432 #[tokio::test]
2433 async fn test_load_default_shaping_toml() {
2434 Shaping::merge_files(
2435 &["../../assets/policy-extras/shaping.toml".into()],
2436 &ShapingMergeOptions::default(),
2437 )
2438 .await
2439 .unwrap();
2440 }
2441}