1#![allow(clippy::result_large_err)]
2use crate::client_types::*;
3use crate::{
4 AsyncReadAndWrite, BoxedAsyncReadAndWrite, Command, Domain, EsmtpParameter, ForwardPath,
5 ReversePath,
6};
7use hickory_proto::rr::rdata::TLSA;
8use memchr::memmem::Finder;
9use openssl::x509::{X509Ref, X509};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::net::IpAddr;
13use std::str::FromStr;
14use std::sync::{Arc, LazyLock};
15use std::time::Duration;
16use thiserror::Error;
17use tokio::io::{AsyncReadExt, AsyncWriteExt};
18use tokio::net::{TcpStream, ToSocketAddrs};
19use tokio::time::timeout;
20use tokio_rustls::rustls::pki_types::ServerName;
21use tracing::Level;
22
23pub use kumo_tls_helper::TlsOptions;
24pub use openssl;
25pub use tokio_rustls;
26
/// Upper bound, in bytes, on how much response data is buffered while looking for a CRLF;
/// once exceeded the line is treated as too long (`ClientError::ResponseTooLong`).
/// Also used as the size of each individual socket read in `SmtpClient::read_line`.
const MAX_LINE_LEN: usize = 4096;
28
/// Errors produced by [`SmtpClient`] operations.
#[derive(Error, Debug, Clone)]
pub enum ClientError {
    /// A response line from the peer was not valid UTF-8.
    #[error("response is not UTF8")]
    Utf8(#[from] std::string::FromUtf8Error),
    /// A response line did not have the expected shape, or a continuation line was
    /// inconsistent with the rest of the reply; carries the offending line.
    #[error("Malformed Response: {0}")]
    MalformedResponseLine(String),
    /// More than `MAX_LINE_LEN` bytes were received without a line terminator.
    #[error("Response line is too long")]
    ResponseTooLong,
    /// The connection has already been closed (for example after an earlier error).
    #[error("Not connected")]
    NotConnected,
    /// A command was rejected. Usually the peer's reply, but `send_mail_multi_recip` also
    /// synthesizes local 554 responses for 8-bit content the peer cannot accept.
    #[error("Command rejected {0:?}")]
    Rejected(Response),
    /// Several replies that together constitute a rejection (e.g. every RCPT TO of a
    /// transaction was refused).
    #[error("Commands rejected {0:?}")]
    RejectedBatch(Vec<Response>),
    /// The peer host name is neither an IP address nor a usable TLS server name.
    #[error("STARTTLS: {0} is not a valid DNS name")]
    InvalidDnsName(String),
    /// Building the TLS connector from the configured `TlsOptions` failed.
    #[error("Invalid client certificate configured: {error:?}")]
    FailedToBuildConnector { error: String },
    /// No (complete) reply arrived within `duration`.
    #[error("Timed Out waiting {duration:?} for response to cmd={}", self.command_for_err())]
    TimeOutResponse {
        command: Option<Command>,
        duration: Duration,
    },
    /// Writing the command(s) did not complete within `duration`.
    #[error("Timed Out after {duration:?} writing cmd={}", self.command_for_err())]
    TimeOutRequest {
        commands: Vec<Command>,
        duration: Duration,
    },
    /// Reading a reply failed (I/O error, peer closed, or socket already closed).
    /// `partial` holds whatever unconsumed bytes were buffered, lossily decoded.
    #[error("Error {error} reading response to cmd={}", self.command_for_err())]
    ReadError {
        command: Option<Command>,
        error: String,
        partial: String,
    },
    /// Flushing buffered output before reading a reply failed.
    #[error("Error {error} flushing send buffer")]
    FlushError { error: String },
    /// Writing failed; `commands` is empty when the data being written was message payload.
    #[error("Error {error} writing {}", self.command_for_err())]
    WriteError {
        commands: Vec<Command>,
        error: String,
    },
    /// Writing message payload data timed out.
    #[error("Timed Out sending message payload data")]
    TimeOutData,
    /// An error reported by OpenSSL.
    #[error("SSL Error: {0}")]
    SslErrorStack(#[from] openssl::error::ErrorStack),
    /// No usable DANE TLSA records for `hostname` (not constructed in this file).
    #[error("No usable DANE TLSA records for {hostname}: {tlsa:?}")]
    NoUsableDaneTlsa { hostname: String, tlsa: Vec<TLSA> },
}
77
78impl ClientError {
79 pub fn command(&self) -> Option<String> {
81 match self {
82 Self::TimeOutResponse {
83 command: Some(command),
84 ..
85 }
86 | Self::ReadError {
87 command: Some(command),
88 ..
89 } => Some(command.encode()),
90 Self::TimeOutRequest { commands, .. } | Self::WriteError { commands, .. }
91 if !commands.is_empty() =>
92 {
93 let commands: Vec<String> = commands.into_iter().map(|cmd| cmd.encode()).collect();
94 Some(commands.join(""))
95 }
96 _ => None,
97 }
98 }
99
100 fn command_for_err(&self) -> String {
103 self.command()
104 .map(|cmd| cmd.replace("\r\n", ""))
105 .unwrap_or_else(|| "NONE".to_string())
106 }
107
108 pub fn was_due_to_message(&self) -> bool {
120 match self {
121 Self::Utf8(_)
122 | Self::MalformedResponseLine(_)
123 | Self::ResponseTooLong
124 | Self::NotConnected
125 | Self::InvalidDnsName(_)
126 | Self::FailedToBuildConnector { .. }
127 | Self::TimeOutResponse { .. }
128 | Self::TimeOutRequest { .. }
129 | Self::ReadError { .. }
130 | Self::FlushError { .. }
131 | Self::WriteError { .. }
132 | Self::TimeOutData
133 | Self::SslErrorStack(_)
134 | Self::NoUsableDaneTlsa { .. } => false,
135 Self::Rejected(response) => response.was_due_to_message(),
136 Self::RejectedBatch(responses) => match responses.len() {
137 1 => responses[0].was_due_to_message(),
138 _ => false,
139 },
140 }
141 }
142}
143
/// One capability line advertised in an EHLO/LHLO reply.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EsmtpCapability {
    /// The capability keyword as sent by the peer (case preserved).
    pub name: String,
    /// Everything after the first space on the capability line, if there was one.
    pub param: Option<String>,
}
149
/// Events reported to an [`SmtpClientTracer`].
#[derive(Clone, Debug)]
pub enum SmtpClientTraceEvent {
    /// The client dropped its connection: after an I/O error, timeout or peer close, or
    /// when the client is dropped while still connected.
    Closed,
    /// Raw bytes as read from the socket.
    Read(Vec<u8>),
    /// Data about to be written to the peer; binary payload is decoded as lossy UTF-8.
    Write(String),
    /// A diagnostic about the connection, such as an I/O error, a timeout or the outcome
    /// of a STARTTLS handshake.
    Diagnostic {
        level: tracing::Level,
        message: String,
    },
}
160
/// A trace event whose construction is deferred until a tracer asks for it.
///
/// NOTE(review): presumably this lets tracers skip building (and copying) the event when
/// nobody is listening — confirm against the `SmtpClientTracer` implementations.
pub trait DeferredTracer {
    /// Build the trace event.
    fn trace(&self) -> SmtpClientTraceEvent;
}
164
/// Receives trace events describing an [`SmtpClient`] connection.
pub trait SmtpClientTracer: std::fmt::Debug {
    /// Handle an already-constructed event.
    fn trace_event(&self, event: SmtpClientTraceEvent);
    /// Handle an event that is only built if/when `deferred.trace()` is called.
    fn lazy_trace(&self, deferred: &dyn DeferredTracer);
}
169
170struct WriteTracer<'a> {
172 data: &'a str,
173}
174impl DeferredTracer for WriteTracer<'_> {
175 fn trace(&self) -> SmtpClientTraceEvent {
176 SmtpClientTraceEvent::Write(self.data.to_string())
177 }
178}
179impl<'a> WriteTracer<'a> {
180 fn trace(tracer: &Arc<dyn SmtpClientTracer + Send + Sync>, data: &'a str) {
181 tracer.lazy_trace(&Self { data });
182 }
183}
184
185struct BinWriteTracer<'a> {
186 data: &'a [u8],
187}
188impl DeferredTracer for BinWriteTracer<'_> {
189 fn trace(&self) -> SmtpClientTraceEvent {
190 let data = String::from_utf8_lossy(self.data).to_string();
191 SmtpClientTraceEvent::Write(data)
192 }
193}
194impl<'a> BinWriteTracer<'a> {
195 fn trace(tracer: &Arc<dyn SmtpClientTracer + Send + Sync>, data: &'a [u8]) {
196 tracer.lazy_trace(&Self { data });
197 }
198}
199
200struct ReadTracer<'a> {
203 data: &'a [u8],
204}
205impl DeferredTracer for ReadTracer<'_> {
206 fn trace(&self) -> SmtpClientTraceEvent {
207 SmtpClientTraceEvent::Read(self.data.to_vec())
208 }
209}
210
/// An SMTP (or LMTP) client session over a single connection.
#[derive(Debug)]
pub struct SmtpClient {
    /// The connection; `None` once it has been dropped, e.g. after an I/O error or timeout.
    socket: Option<BoxedAsyncReadAndWrite>,
    /// Peer host name (no port, brackets or trailing dot); used as the TLS server name
    /// and in trace logs.
    hostname: String,
    /// Capabilities from the last successful EHLO/LHLO, keyed by upper-cased keyword.
    capabilities: HashMap<String, EsmtpCapability>,
    /// Bytes received but not yet consumed as complete lines.
    read_buffer: Vec<u8>,
    timeouts: SmtpClientTimeouts,
    tracer: Option<Arc<dyn SmtpClientTracer + Send + Sync>>,
    /// Whether the next transaction must start with RSET.
    use_rset: bool,
    /// Value `use_rset` is reset to after a successful transaction.
    enable_rset: bool,
    /// Pipeline commands when the peer also advertises PIPELINING.
    enable_pipelining: bool,
    /// Send 8-bit data / non-ASCII envelopes even without 8BITMIME / SMTPUTF8.
    ignore_8bit_checks: bool,
}
224
/// Extract the bare host name from a peer address string.
///
/// Accepts `host`, `host:port`, `[host]` and `[host]:port`. Any trailing `.` (an
/// absolute DNS name) is removed. Inside brackets the host may contain colons (IPv6
/// literal). Without brackets, the text after the *last* `:` is treated as a port.
fn extract_hostname(hostname: &str) -> &str {
    // Bracketed form: the closing bracket delimits the host, so an IPv6 literal with no
    // port ("[::1]") is not mistaken for "host:port". Only accept it when the bracket is
    // followed by nothing or by a ":port" suffix.
    let bracketed = hostname.strip_prefix('[').and_then(|rest| {
        let (inner, tail) = rest.split_once(']')?;
        (tail.is_empty() || tail.starts_with(':')).then_some(inner)
    });

    let hostname = match bracketed {
        Some(inner) => inner,
        None => hostname
            .rsplit_once(':')
            .map_or(hostname, |(host, _port)| host),
    };

    hostname.strip_suffix('.').unwrap_or(hostname)
}
243
244impl SmtpClient {
245 pub async fn new<A: ToSocketAddrs + ToString + Clone>(
246 addr: A,
247 timeouts: SmtpClientTimeouts,
248 ) -> std::io::Result<Self> {
249 let stream = TcpStream::connect(addr.clone()).await?;
250 stream.set_nodelay(true)?;
252 Ok(Self::with_stream(stream, addr.to_string(), timeouts))
253 }
254
255 pub fn with_stream<S: AsyncReadAndWrite + 'static, H: AsRef<str>>(
256 stream: S,
257 peer_hostname: H,
258 timeouts: SmtpClientTimeouts,
259 ) -> Self {
260 let hostname = extract_hostname(peer_hostname.as_ref()).to_string();
261
262 Self {
263 socket: Some(Box::new(stream)),
264 hostname,
265 capabilities: HashMap::new(),
266 read_buffer: Vec::with_capacity(1024),
267 timeouts,
268 tracer: None,
269 use_rset: false,
270 enable_rset: false,
271 enable_pipelining: false,
272 ignore_8bit_checks: false,
273 }
274 }
275
276 pub fn set_ignore_8bit_checks(&mut self, enable: bool) {
278 self.ignore_8bit_checks = enable;
279 }
280
281 pub fn is_connected(&self) -> bool {
282 self.socket.is_some()
283 }
284
285 pub fn set_enable_rset(&mut self, enable: bool) {
286 self.enable_rset = enable;
287 }
288
289 pub fn set_enable_pipelining(&mut self, enable: bool) {
290 self.enable_pipelining = enable;
291 }
292
293 pub fn set_tracer(&mut self, tracer: Arc<dyn SmtpClientTracer + Send + Sync>) {
294 self.tracer.replace(tracer);
295 }
296
297 pub fn timeouts(&self) -> &SmtpClientTimeouts {
298 &self.timeouts
299 }
300
301 async fn read_line(
302 &mut self,
303 timeout_duration: Duration,
304 cmd: Option<&Command>,
305 ) -> Result<String, ClientError> {
306 let mut too_long = false;
307 loop {
308 let mut iter = self.read_buffer.iter().enumerate();
309 while let Some((i, &b)) = iter.next() {
310 if b != b'\r' {
311 continue;
312 }
313 if let Some((_, b'\n')) = iter.next() {
314 if too_long {
315 self.read_buffer.drain(0..i + 2);
316
317 if let Some(tracer) = &self.tracer {
318 tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
319 level: Level::ERROR,
320 message: "Response too long".to_string(),
321 });
322 }
323
324 return Err(ClientError::ResponseTooLong);
325 }
326
327 let line = String::from_utf8(self.read_buffer[0..i].to_vec());
328 self.read_buffer.drain(0..i + 2);
329 return Ok(line?);
330 }
331 }
332 if self.read_buffer.len() > MAX_LINE_LEN {
333 self.read_buffer.clear();
334 too_long = true;
335 }
336
337 let mut data = [0u8; MAX_LINE_LEN];
339 let size = match self.socket.as_mut() {
340 Some(s) => match timeout(timeout_duration, s.read(&mut data)).await {
341 Ok(Ok(size)) => size,
342 Ok(Err(err)) => {
343 self.socket.take();
344 if let Some(tracer) = &self.tracer {
345 tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
346 level: Level::ERROR,
347 message: format!("Error during read: {err:#}"),
348 });
349 tracer.trace_event(SmtpClientTraceEvent::Closed);
350 }
351 return Err(ClientError::ReadError {
352 command: cmd.cloned(),
353 error: format!("{err:#}"),
354 partial: String::from_utf8_lossy(&self.read_buffer).to_string(),
355 });
356 }
357 Err(_) => {
358 if timeout_duration != Duration::ZERO {
363 self.socket.take();
364 }
365 if let Some(tracer) = &self.tracer {
366 tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
367 level: Level::ERROR,
368 message: format!("Read Timeout after {timeout_duration:?}"),
369 });
370 if self.socket.is_none() {
371 tracer.trace_event(SmtpClientTraceEvent::Closed);
372 }
373 }
374 return Err(ClientError::TimeOutResponse {
375 command: cmd.cloned(),
376 duration: timeout_duration,
377 });
378 }
379 },
380 None => {
381 return Err(ClientError::ReadError {
382 command: cmd.cloned(),
383 error: "the socket was closed in response to an earlier issue".to_string(),
384 partial: String::from_utf8_lossy(&self.read_buffer).to_string(),
385 });
386 }
387 };
388 if size == 0 {
389 self.socket.take();
390 if let Some(tracer) = &self.tracer {
391 tracer.trace_event(SmtpClientTraceEvent::Closed);
392 }
393 return Err(ClientError::ReadError {
394 command: cmd.cloned(),
395 error: "Connection closed by peer".to_string(),
396 partial: String::from_utf8_lossy(&self.read_buffer).to_string(),
397 });
398 }
399 if let Some(tracer) = &self.tracer {
400 tracer.lazy_trace(&ReadTracer {
401 data: &data[0..size],
402 });
403 }
404 self.read_buffer.extend_from_slice(&data[0..size]);
405 }
406 }
407
408 pub async fn check_unilateral_response(&mut self) -> Result<Option<Response>, ClientError> {
416 match self.read_response(None, Duration::ZERO).await {
417 Ok(response) => Ok(Some(response)),
419 Err(ClientError::TimeOutResponse { .. }) => Ok(None),
421 Err(err) => Err(err),
423 }
424 }
425
426 pub async fn read_response(
427 &mut self,
428 command: Option<&Command>,
429 timeout_duration: Duration,
430 ) -> Result<Response, ClientError> {
431 if let Some(sock) = self.socket.as_mut() {
432 if let Err(err) = sock.flush().await {
433 self.socket.take();
434 if let Some(tracer) = &self.tracer {
435 tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
436 level: Level::ERROR,
437 message: format!("Error during flush: {err:#}"),
438 });
439 tracer.trace_event(SmtpClientTraceEvent::Closed);
440 }
441 return Err(ClientError::FlushError {
442 error: format!("{err:#}"),
443 });
444 }
445 }
446
447 let mut line = self.read_line(timeout_duration, command).await?;
448 tracing::trace!("recv<-{}: {line}", self.hostname);
449 let mut parsed = parse_response_line(&line)?;
450 let mut response_builder = ResponseBuilder::new(&parsed);
451
452 let subsequent_line_timeout_duration = Duration::from_secs(60).min(timeout_duration);
453 while !parsed.is_final {
454 line = self
455 .read_line(subsequent_line_timeout_duration, command)
456 .await?;
457 parsed = parse_response_line(&line)?;
458 response_builder
459 .add_line(&parsed)
460 .map_err(ClientError::MalformedResponseLine)?;
461 }
462
463 let response = response_builder.build(command.map(|cmd| cmd.encode()));
464
465 tracing::trace!("{}: {response:?}", self.hostname);
466
467 Ok(response)
468 }
469
470 pub async fn send_command(&mut self, command: &Command) -> Result<Response, ClientError> {
471 self.write_command_request(command).await?;
472 self.read_response(Some(command), command.client_timeout(&self.timeouts))
473 .await
474 }
475
476 async fn write_all_with_timeout<F, G>(
480 &mut self,
481 timeout_duration: Duration,
482 bytes: &[u8],
483 make_timeout_err: F,
484 make_write_err: G,
485 ) -> Result<(), ClientError>
486 where
487 F: FnOnce() -> ClientError,
488 G: FnOnce(String) -> ClientError,
489 {
490 match self.socket.as_mut() {
491 Some(socket) => match timeout(timeout_duration, socket.write_all(bytes)).await {
492 Ok(Ok(response)) => Ok(response),
493 Ok(Err(err)) => {
494 self.socket.take();
495 if let Some(tracer) = &self.tracer {
496 tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
497 level: Level::ERROR,
498 message: format!("Error during write: {err:#}"),
499 });
500 tracer.trace_event(SmtpClientTraceEvent::Closed);
501 }
502 Err(make_write_err(format!("{err:#}")))
503 }
504 Err(_) => {
505 self.socket.take();
506 if let Some(tracer) = &self.tracer {
507 tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
508 level: Level::ERROR,
509 message: format!("Write Timeout after {timeout_duration:?}"),
510 });
511 tracer.trace_event(SmtpClientTraceEvent::Closed);
512 }
513 Err(make_timeout_err())
514 }
515 },
516 None => Err(make_write_err(
517 "the socket was closed in response to an earlier issue".to_string(),
518 )),
519 }
520 }
521
522 async fn write_pipeline_request(&mut self, commands: &[Command]) -> Result<(), ClientError> {
523 let total_timeout: Duration = commands
524 .iter()
525 .map(|cmd| cmd.client_timeout_request(&self.timeouts))
526 .sum();
527
528 let mut lines: Vec<String> = vec![];
529 let mut all = String::new();
530 for cmd in commands {
531 let line = cmd.encode();
532 all.push_str(&line);
533 lines.push(line);
534 }
535 tracing::trace!("send->{}: (PIPELINE) {all}", self.hostname);
536 if self.socket.is_some() {
537 if let Some(tracer) = &self.tracer {
538 for line in lines {
541 WriteTracer::trace(tracer, &line);
542 }
543 }
544 }
545 self.write_all_with_timeout(
546 total_timeout,
547 all.as_bytes(),
548 || ClientError::TimeOutRequest {
549 duration: total_timeout,
550 commands: commands.to_vec(),
551 },
552 |error| ClientError::WriteError {
553 error,
554 commands: commands.to_vec(),
555 },
556 )
557 .await
558 }
559
560 async fn write_command_request(&mut self, command: &Command) -> Result<(), ClientError> {
561 let line = command.encode();
562 tracing::trace!("send->{}: {line}", self.hostname);
563 if self.socket.is_some() {
564 if let Some(tracer) = &self.tracer {
565 WriteTracer::trace(tracer, &line);
566 }
567 }
568
569 let timeout_duration = command.client_timeout_request(&self.timeouts);
570 self.write_all_with_timeout(
571 timeout_duration,
572 line.as_bytes(),
573 || ClientError::TimeOutRequest {
574 duration: timeout_duration,
575 commands: vec![command.clone()],
576 },
577 |error| ClientError::WriteError {
578 error,
579 commands: vec![command.clone()],
580 },
581 )
582 .await
583 }
584
585 async fn write_data_with_timeout(&mut self, data: &[u8]) -> Result<(), ClientError> {
586 if self.socket.is_some() {
587 if let Some(tracer) = &self.tracer {
588 BinWriteTracer::trace(tracer, data);
589 }
590 }
591 let timeout_duration = Command::Data.client_timeout_request(&self.timeouts);
592 self.write_all_with_timeout(
593 timeout_duration,
594 data,
595 || ClientError::TimeOutData,
596 |error| ClientError::WriteError {
597 error,
598 commands: vec![],
599 },
600 )
601 .await
602 }
603
604 pub async fn pipeline_commands(
620 &mut self,
621 commands: Vec<Command>,
622 ) -> Vec<Result<Response, ClientError>> {
623 let mut results: Vec<Result<Response, ClientError>> = vec![];
624
625 let pipeline = self.enable_pipelining && self.capabilities.contains_key("PIPELINING");
626 if pipeline {
627 if let Err(err) = self.write_pipeline_request(&commands).await {
628 let err: ClientError = err;
629 results.push(Err(err.clone()));
630 while results.len() < commands.len() {
631 results.push(Err(err.clone()));
633 }
634 return results;
635 }
636
637 for cmd in &commands {
639 results.push(
640 self.read_response(Some(cmd), cmd.client_timeout(&self.timeouts))
641 .await,
642 );
643 }
644 return results;
645 }
646
647 for cmd in &commands {
648 if let Err(err) = self.write_command_request(cmd).await {
649 let err: ClientError = err;
650 results.push(Err(err.clone()));
651 while results.len() < commands.len() {
652 results.push(Err(err.clone()));
654 }
655 return results;
656 }
657 results.push(
660 self.read_response(Some(cmd), cmd.client_timeout(&self.timeouts))
661 .await,
662 );
663 }
664 results
665 }
666
667 pub async fn ehlo_lhlo(
668 &mut self,
669 ehlo_name: &str,
670 use_lmtp: bool,
671 ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
672 if use_lmtp {
673 self.lhlo(ehlo_name).await
674 } else {
675 self.ehlo(ehlo_name).await
676 }
677 }
678
679 pub async fn lhlo(
680 &mut self,
681 ehlo_name: &str,
682 ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
683 let response = self
684 .send_command(&Command::Lhlo(Domain::Name(ehlo_name.to_string())))
685 .await?;
686 self.ehlo_common(response)
687 }
688
689 pub async fn ehlo(
690 &mut self,
691 ehlo_name: &str,
692 ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
693 let response = self
694 .send_command(&Command::Ehlo(Domain::Name(ehlo_name.to_string())))
695 .await?;
696 self.ehlo_common(response)
697 }
698
699 fn ehlo_common(
700 &mut self,
701 response: Response,
702 ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
703 if response.code != 250 {
704 return Err(ClientError::Rejected(response));
705 }
706
707 let mut capabilities = HashMap::new();
708
709 for line in response.content.lines().skip(1) {
710 let mut fields = line.splitn(2, ' ');
711 if let Some(name) = fields.next() {
712 let param = fields.next().map(|s| s.to_string());
713 let cap = EsmtpCapability {
714 name: name.to_string(),
715 param,
716 };
717 capabilities.insert(name.to_ascii_uppercase(), cap);
718 }
719 }
720
721 self.capabilities = capabilities;
722 Ok(&self.capabilities)
723 }
724
725 pub async fn auth_plain(
726 &mut self,
727 username: &str,
728 password: Option<&str>,
729 ) -> Result<(), ClientError> {
730 let password = password.unwrap_or("");
733 let payload = format!("\x00{username}\x00{password}");
734 let payload = data_encoding::BASE64.encode(payload.as_bytes());
735
736 let response = self
737 .send_command(&Command::Auth {
738 sasl_mech: "PLAIN".to_string(),
739 initial_response: Some(payload),
740 })
741 .await?;
742
743 if response.code != 235 {
744 return Err(ClientError::Rejected(response));
745 }
746
747 Ok(())
748 }
749
750 pub async fn starttls(&mut self, options: TlsOptions) -> Result<TlsStatus, ClientError> {
756 let resp = self.send_command(&Command::StartTls).await?;
757 if resp.code != 220 {
758 return Err(ClientError::Rejected(resp));
759 }
760
761 let mut handshake_error = None;
762 let mut tls_info = TlsInformation::default();
763
764 let stream: BoxedAsyncReadAndWrite = if options.prefer_openssl
765 || !options.dane_tlsa.is_empty()
766 {
767 let connector = options
768 .build_openssl_connector(&self.hostname)
769 .map_err(|error| ClientError::FailedToBuildConnector {
770 error: error.to_string(),
771 })?;
772 let ssl = connector.into_ssl(self.hostname.as_str())?;
773
774 let (stream, dup_stream) = match self.socket.take() {
775 Some(s) => {
776 let d = s.try_dup();
777 (s, d)
778 }
779 None => return Err(ClientError::NotConnected),
780 };
781
782 let mut ssl_stream = tokio_openssl::SslStream::new(ssl, stream)?;
783
784 if let Err(err) = std::pin::Pin::new(&mut ssl_stream).connect().await {
785 handshake_error.replace(format!("{err:#}"));
786 }
787
788 tls_info.provider_name = "openssl".to_string();
789 tls_info.cipher = match ssl_stream.ssl().current_cipher() {
790 Some(cipher) => cipher.standard_name().unwrap_or(cipher.name()).to_string(),
791 None => String::new(),
792 };
793 tls_info.protocol_version = ssl_stream.ssl().version_str().to_string();
794
795 if let Some(cert) = ssl_stream.ssl().peer_certificate() {
796 tls_info.subject_name = subject_name(&cert);
797 }
798 if let Ok(authority) = ssl_stream.ssl().dane_authority() {
799 if let Some(cert) = &authority.cert {
800 tls_info.subject_name = subject_name(cert);
801 }
802 }
803
804 match (&handshake_error, dup_stream) {
805 (Some(_), Some(dup_stream)) if !ssl_stream.ssl().is_init_finished() => {
806 drop(ssl_stream);
811 Box::new(dup_stream)
812 }
813 _ => Box::new(ssl_stream),
814 }
815 } else {
816 tls_info.provider_name = "rustls".to_string();
817 let connector = options.build_tls_connector().await.map_err(|error| {
818 ClientError::FailedToBuildConnector {
819 error: error.to_string(),
820 }
821 })?;
822 let server_name = match IpAddr::from_str(self.hostname.as_str()) {
823 Ok(ip) => ServerName::IpAddress(ip.into()),
824 Err(_) => ServerName::try_from(self.hostname.clone())
825 .map_err(|_| ClientError::InvalidDnsName(self.hostname.clone()))?,
826 };
827
828 match connector
829 .connect(
830 server_name,
831 match self.socket.take() {
832 Some(s) => s,
833 None => return Err(ClientError::NotConnected),
834 },
835 )
836 .into_fallible()
837 .await
838 {
839 Ok(stream) => {
840 let (_, conn) = stream.get_ref();
841 tls_info.cipher = match conn.negotiated_cipher_suite() {
842 Some(suite) => suite.suite().as_str().unwrap_or("UNKNOWN").to_string(),
843 None => String::new(),
844 };
845 tls_info.protocol_version = match conn.protocol_version() {
846 Some(version) => version.as_str().unwrap_or("UNKNOWN").to_string(),
847 None => String::new(),
848 };
849
850 if let Some(certs) = conn.peer_certificates() {
851 let peer_cert = &certs[0];
852 if let Ok(cert) = X509::from_der(peer_cert.as_ref()) {
853 tls_info.subject_name = subject_name(&cert);
854 }
855 }
856
857 Box::new(stream)
858 }
859 Err((err, stream)) => {
860 handshake_error.replace(format!("{err:#}"));
861 stream
862 }
863 }
864 };
865
866 if let Some(tracer) = &self.tracer {
867 tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
868 level: Level::INFO,
869 message: match &handshake_error {
870 Some(error) => format!("STARTTLS handshake failed: {error:?}"),
871 None => format!("STARTTLS handshake -> {tls_info:?}"),
872 },
873 });
874 }
875
876 self.socket.replace(stream);
877 Ok(match handshake_error {
878 Some(error) => TlsStatus::FailedHandshake(error),
879 None => TlsStatus::Info(tls_info),
880 })
881 }
882
883 pub async fn send_mail<B: AsRef<[u8]>, SENDER: Into<ReversePath>, RECIP: Into<ForwardPath>>(
884 &mut self,
885 sender: SENDER,
886 recipient: RECIP,
887 data: B,
888 ) -> Result<Response, ClientError> {
889 let recipient: ForwardPath = recipient.into();
890 let status = self
891 .send_mail_multi_recip(sender, vec![recipient], data)
892 .await?;
893 Ok(status.response)
894 }
895
896 pub async fn send_mail_multi_recip<B: AsRef<[u8]>, SENDER: Into<ReversePath>>(
897 &mut self,
898 sender: SENDER,
899 recipient_list: Vec<ForwardPath>,
900 data: B,
901 ) -> Result<BatchSendSuccess, ClientError> {
902 let sender = sender.into();
903
904 let data: &[u8] = data.as_ref();
905 let stuffed;
906
907 let data = match apply_dot_stuffing(data) {
908 Some(d) => {
909 stuffed = d;
910 &stuffed
911 }
912 None => data,
913 };
914
915 let data_is_8bit = data.iter().any(|&b| b >= 0x80);
916 let envelope_is_8bit =
917 !sender.is_ascii() || recipient_list.iter().any(|recipient| !recipient.is_ascii());
918
919 let mut mail_from_params = vec![];
920 if data_is_8bit {
921 if self.capabilities.contains_key("8BITMIME") {
922 mail_from_params.push(EsmtpParameter {
923 name: "BODY".to_string(),
924 value: Some("8BITMIME".to_string()),
925 });
926 } else if !self.ignore_8bit_checks {
927 return Err(ClientError::Rejected(Response {
928 code: 554,
929 command: None,
930 enhanced_code: Some(EnhancedStatusCode {
931 class: 5,
932 subject: 6,
933 detail: 3,
934 }),
935 content: "KumoMTA internal: DATA is 8bit, destination does \
936 not support 8BITMIME. Conversion via msg:check_fix_conformance \
937 during reception is required"
938 .to_string(),
939 }));
940 }
941 }
942
943 if envelope_is_8bit {
944 if self.capabilities.contains_key("SMTPUTF8") {
945 mail_from_params.push(EsmtpParameter {
946 name: "SMTPUTF8".to_string(),
947 value: None,
948 });
949 } else if !self.ignore_8bit_checks {
950 return Err(ClientError::Rejected(Response {
951 code: 554,
952 command: None,
953 enhanced_code: Some(EnhancedStatusCode {
954 class: 5,
955 subject: 6,
956 detail: 7,
957 }),
958 content: "KumoMTA internal: envelope is 8bit, destination does \
959 not support SMTPUTF8."
960 .to_string(),
961 }));
962 }
963 }
964
965 let mut commands = vec![];
966
967 let used_rset = self.use_rset;
977 if self.use_rset {
978 commands.push(Command::Rset);
979 }
980 commands.push(Command::MailFrom {
981 address: sender,
982 parameters: mail_from_params,
983 });
984
985 for recipient in &recipient_list {
986 commands.push(Command::RcptTo {
987 address: recipient.clone(),
988 parameters: vec![],
989 });
990 }
991 commands.push(Command::Data);
992
993 self.use_rset = true;
996
997 let mut responses = self.pipeline_commands(commands).await;
998
999 let is_err = responses.iter().any(|r| r.is_err());
1007
1008 if used_rset {
1009 let rset_resp = responses.remove(0)?;
1010 if rset_resp.code != 250 {
1011 return Err(ClientError::Rejected(rset_resp));
1012 }
1013 }
1014
1015 let mail_resp = responses.remove(0)?;
1016 if is_err && mail_resp.code != 250 {
1017 return Err(ClientError::Rejected(mail_resp));
1018 }
1019
1020 let mut rcpt_responses = vec![];
1021 for _ in &recipient_list {
1022 rcpt_responses.push(responses.remove(0)?);
1023 }
1024
1025 if is_err && rcpt_responses.iter().all(|resp| resp.code != 250) {
1026 return Err(ClientError::RejectedBatch(rcpt_responses));
1027 }
1028
1029 let data_resp = responses.remove(0)?;
1030 if is_err && data_resp.code != 354 {
1031 return Err(ClientError::Rejected(data_resp));
1032 }
1033
1034 if data_resp.code == 354
1035 && (mail_resp.code != 250 || rcpt_responses.iter().all(|resp| resp.code != 250))
1036 {
1037 self.write_data_with_timeout(b".\r\n").await?;
1046 let data_dot = Command::DataDot;
1047 let _ = self
1049 .read_response(Some(&data_dot), data_dot.client_timeout(&self.timeouts))
1050 .await?;
1051
1052 }
1056
1057 if mail_resp.code != 250 {
1058 return Err(ClientError::Rejected(mail_resp));
1059 }
1060 if rcpt_responses.iter().all(|resp| resp.code != 250) {
1061 if rcpt_responses.len() == 1 {
1062 return Err(ClientError::Rejected(
1063 rcpt_responses.pop().expect("have at least one"),
1064 ));
1065 }
1066 return Err(ClientError::RejectedBatch(rcpt_responses));
1067 }
1068 if data_resp.code != 354 {
1069 return Err(ClientError::Rejected(data_resp));
1070 }
1071
1072 let needs_newline = data.last().map(|&b| b != b'\n').unwrap_or(true);
1073
1074 tracing::trace!("message data is {} bytes", data.len());
1075
1076 self.write_data_with_timeout(data).await?;
1077
1078 let marker = if needs_newline { "\r\n.\r\n" } else { ".\r\n" };
1079
1080 tracing::trace!("send->{}: {}", self.hostname, marker.escape_debug());
1081
1082 self.write_data_with_timeout(marker.as_bytes()).await?;
1083
1084 let data_dot = Command::DataDot;
1085 let resp = self
1086 .read_response(Some(&data_dot), data_dot.client_timeout(&self.timeouts))
1087 .await?;
1088 if resp.code != 250 {
1089 return Err(ClientError::Rejected(resp));
1090 }
1091
1092 self.use_rset = self.enable_rset;
1095
1096 Ok(BatchSendSuccess {
1097 response: resp,
1098 rcpt_responses,
1099 })
1100 }
1101}
1102
/// Successful outcome of [`SmtpClient::send_mail_multi_recip`].
#[derive(Debug)]
pub struct BatchSendSuccess {
    /// The peer's reply to the end-of-data `.`.
    pub response: Response,
    /// The RCPT TO replies in recipient order; some may be rejections as long as at least
    /// one recipient was accepted.
    pub rcpt_responses: Vec<Response>,
}
1107}
1108
/// Outcome of [`SmtpClient::starttls`].
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize, Clone)]
pub enum TlsStatus {
    /// The TLS handshake failed; carries the handshake error text.
    FailedHandshake(String),
    /// The handshake succeeded; describes the negotiated session.
    Info(TlsInformation),
}
1113}
1114
/// Details of a TLS session negotiated via STARTTLS.
#[derive(Debug, Default, PartialEq, Eq, Serialize, Deserialize, Clone)]
pub struct TlsInformation {
    /// Negotiated cipher suite name; empty if none was reported.
    pub cipher: String,
    /// Negotiated protocol version; empty if none was reported.
    pub protocol_version: String,
    /// Subject entries of the peer certificate as `short_name=value` strings. With DANE,
    /// the subject of the certificate that DANE authenticated.
    pub subject_name: Vec<String>,
    /// TLS implementation used: `"openssl"` or `"rustls"`.
    pub provider_name: String,
}
1122
1123impl Drop for SmtpClient {
1124 fn drop(&mut self) {
1125 if let Some(tracer) = &self.tracer {
1126 if self.socket.is_some() {
1127 tracer.trace_event(SmtpClientTraceEvent::Closed);
1128 }
1129 }
1130 }
1131}
1132fn parse_response_line(line: &'_ str) -> Result<ResponseLine<'_>, ClientError> {
1133 if line.len() < 4 {
1134 return Err(ClientError::MalformedResponseLine(line.to_string()));
1135 }
1136
1137 match line.as_bytes()[3] {
1138 b' ' | b'-' => match line[0..3].parse::<u16>() {
1139 Ok(code) => Ok(ResponseLine {
1140 code,
1141 is_final: line.as_bytes()[3] == b' ',
1142 content: &line[4..],
1143 }),
1144 Err(_) => Err(ClientError::MalformedResponseLine(line.to_string())),
1145 },
1146 _ => Err(ClientError::MalformedResponseLine(line.to_string())),
1147 }
1148}
1149
1150fn apply_dot_stuffing(data: &[u8]) -> Option<Vec<u8>> {
1151 static LFDOT: LazyLock<Finder> = LazyLock::new(|| memchr::memmem::Finder::new("\n."));
1152
1153 if !data.starts_with(b".") && LFDOT.find(data).is_none() {
1154 return None;
1155 }
1156
1157 let mut stuffed = vec![];
1158 if data.starts_with(b".") {
1159 stuffed.push(b'.');
1160 }
1161 let mut last_idx = 0;
1162 for i in LFDOT.find_iter(data) {
1163 stuffed.extend_from_slice(&data[last_idx..=i]);
1164 stuffed.push(b'.');
1165 last_idx = i + 1;
1166 }
1167 stuffed.extend_from_slice(&data[last_idx..]);
1168 Some(stuffed)
1169}
1170
1171pub fn subject_name(cert: &X509Ref) -> Vec<String> {
1178 let mut subject_name = vec![];
1179 for entry in cert.subject_name().entries() {
1180 if let Ok(obj) = entry.object().nid().short_name() {
1181 if let Ok(name) = entry.data().as_utf8() {
1182 subject_name.push(format!("{obj}={name}"));
1183 }
1184 }
1185 }
1186 subject_name
1187}
1188
#[cfg(test)]
mod test {
    use super::*;
    use crate::{MailPath, Mailbox};

    /// Data without line-leading dots is untouched; every line-leading dot is doubled.
    #[test]
    fn test_stuffing() {
        assert_eq!(apply_dot_stuffing(b"foo"), None);

        let cases: &[(&[u8], &[u8])] = &[
            (b".foo", b"..foo"),
            (b"foo\n.bar", b"foo\n..bar"),
            (b"foo\n.bar\n..baz\n", b"foo\n..bar\n...baz\n"),
        ];
        for &(input, expected) in cases {
            assert_eq!(apply_dot_stuffing(input).as_deref(), Some(expected));
        }
    }

    /// ' ' after the code marks the final line and '-' a continuation; other separators
    /// and non-numeric codes are malformed.
    #[test]
    fn response_line_parsing() {
        let well_formed = [
            (
                "220 woot",
                ResponseLine {
                    code: 220,
                    is_final: true,
                    content: "woot",
                },
            ),
            (
                "220-woot",
                ResponseLine {
                    code: 220,
                    is_final: false,
                    content: "woot",
                },
            ),
        ];
        for (line, expected) in well_formed {
            assert_eq!(parse_response_line(line).unwrap(), expected);
        }

        for line in ["220_woot", "not really"] {
            assert!(
                matches!(
                    parse_response_line(line),
                    Err(ClientError::MalformedResponseLine(_))
                ),
                "expected {line:?} to be malformed"
            );
        }
    }

    /// Feed `lines` through `ResponseBuilder` the same way `read_response` does.
    fn parse_multi_line(lines: &[&str]) -> Result<Response, ClientError> {
        let (first, rest) = lines.split_first().expect("at least one line");
        let mut parsed = parse_response_line(first)?;
        let mut builder = ResponseBuilder::new(&parsed);
        for line in rest {
            parsed = parse_response_line(line)?;
            builder
                .add_line(&parsed)
                .map_err(ClientError::MalformedResponseLine)?;
        }
        assert!(parsed.is_final);
        Ok(builder.build(None))
    }

    #[test]
    fn multi_line_response() {
        let expected = |enhanced_code: Option<EnhancedStatusCode>, content: &str| Response {
            code: 220,
            enhanced_code,
            content: content.to_string(),
            command: None,
        };
        let enhanced_4_1_0 = || {
            Some(EnhancedStatusCode {
                class: 4,
                subject: 1,
                detail: 0,
            })
        };

        assert_eq!(
            parse_multi_line(&["220-woot", "220-more", "220 done"]).unwrap(),
            expected(None, "woot\nmore\ndone")
        );

        // Every line of a reply must carry the same code.
        for (lines, bad_line) in [
            (["220-woot", "221-more", "220 done"], "221-more"),
            (["220-woot", "220-more", "221 done"], "221 done"),
        ] {
            let res = parse_multi_line(&lines).unwrap_err();
            assert!(
                matches!(
                    res,
                    ClientError::MalformedResponseLine(ref err) if err == bad_line
                ),
                "got error {res:?}"
            );
        }

        // The enhanced code comes from the first line; a repeated identical prefix on
        // continuation lines is stripped, while a different one is kept as text.
        assert_eq!(
            parse_multi_line(&["220-4.1.0 woot", "220-more", "220 done"]).unwrap(),
            expected(enhanced_4_1_0(), "woot\nmore\ndone")
        );
        assert_eq!(
            parse_multi_line(&["220-4.1.0 woot", "220-4.1.0 more", "220 done"]).unwrap(),
            expected(enhanced_4_1_0(), "woot\nmore\ndone")
        );
        assert_eq!(
            parse_multi_line(&["220-4.1.0 woot", "220-4.1.0 more", "220 5.5.5 done"]).unwrap(),
            expected(enhanced_4_1_0(), "woot\nmore\n5.5.5 done")
        );
    }

    #[test]
    fn test_extract_hostname() {
        for (input, want) in [
            ("foo", "foo"),
            ("foo.", "foo"),
            ("foo:25", "foo"),
            ("foo.:25", "foo"),
            ("[foo]:25", "foo"),
            ("[foo.]:25", "foo"),
            ("[::1]:25", "::1"),
            ("::1:25", "::1"),
        ] {
            assert_eq!(extract_hostname(input), want, "input {input:?}");
        }
    }

    #[test]
    fn test_format_error_command() {
        let write_timeout = ClientError::TimeOutRequest {
            commands: vec![Command::DataDot],
            duration: Duration::from_secs(10),
        };
        assert_eq!(
            format!("{write_timeout:#}"),
            "Timed Out after 10s writing cmd=."
        );

        let read_timeout = ClientError::TimeOutResponse {
            command: Some(Command::MailFrom {
                address: ReversePath::Path(MailPath {
                    at_domain_list: vec![],
                    mailbox: Mailbox {
                        local_part: "user".to_string(),
                        domain: Domain::Name("host".to_string()),
                    },
                }),
                parameters: vec![],
            }),
            duration: Duration::from_secs(10),
        };
        assert_eq!(
            format!("{read_timeout:#}"),
            r#"Timed Out waiting 10s for response to cmd=MAIL FROM:<user@host>"#
        );
    }
}