1#![allow(clippy::result_large_err)]
2use crate::client_types::*;
3use crate::{
4 AsyncReadAndWrite, BoxedAsyncReadAndWrite, Command, Domain, EsmtpParameter, ForwardPath,
5 ReversePath,
6};
7use hickory_proto::rr::rdata::TLSA;
8use memchr::memmem::Finder;
9use openssl::x509::{X509Ref, X509};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::net::IpAddr;
13use std::str::FromStr;
14use std::sync::{Arc, LazyLock};
15use std::time::Duration;
16use thiserror::Error;
17use tokio::io::{AsyncReadExt, AsyncWriteExt};
18use tokio::net::{TcpStream, ToSocketAddrs};
19use tokio::time::timeout;
20use tokio_rustls::rustls::pki_types::ServerName;
21use tracing::Level;
22
23pub use kumo_tls_helper::TlsOptions;
24pub use {openssl, tokio_rustls};
25
/// Upper bound, in bytes, on buffered response data that has not yet been
/// terminated by CRLF; once exceeded the line is discarded and reported as
/// `ClientError::ResponseTooLong`. Also used as the size of each socket read
/// in `SmtpClient::read_line`.
const MAX_LINE_LEN: usize = 4096;
27
/// Errors produced by `SmtpClient` operations.
#[derive(Error, Debug, Clone)]
pub enum ClientError {
    /// A received response line was not valid UTF-8.
    #[error("response is not UTF8")]
    Utf8(#[from] std::string::FromUtf8Error),
    /// A response line did not have the `DDD<SP|->text` shape, or a
    /// continuation line was rejected by the response builder; carries the
    /// offending line.
    #[error("Malformed Response: {0}")]
    MalformedResponseLine(String),
    /// More than `MAX_LINE_LEN` bytes arrived without a terminating CRLF.
    #[error("Response line is too long")]
    ResponseTooLong,
    /// STARTTLS was attempted but there was no socket to upgrade.
    #[error("Not connected")]
    NotConnected,
    /// A command was rejected; carries the response. The response may also
    /// be synthesized locally (e.g. the 8BITMIME/SMTPUTF8 checks in
    /// `send_mail_multi_recip`).
    #[error("Command rejected {0:?}")]
    Rejected(Response),
    /// Every recipient of a transaction was rejected; one response per
    /// recipient, in order.
    #[error("Commands rejected {0:?}")]
    RejectedBatch(Vec<Response>),
    /// The peer hostname could not be converted into a TLS server name.
    #[error("STARTTLS: {0} is not a valid DNS name")]
    InvalidDnsName(String),
    /// Building the TLS connector from the supplied options failed.
    #[error("Invalid client certificate configured: {error:?}")]
    FailedToBuildConnector { error: String },
    /// No response arrived within `duration`; `command` is the command the
    /// response was awaited for, if any.
    #[error("Timed Out waiting {duration:?} for response to cmd={}", self.command_for_err())]
    TimeOutResponse {
        command: Option<Command>,
        duration: Duration,
    },
    /// Writing `commands` did not complete within `duration`.
    #[error("Timed Out after {duration:?} writing cmd={}", self.command_for_err())]
    TimeOutRequest {
        commands: Vec<Command>,
        duration: Duration,
    },
    /// Reading a response failed or the connection was closed. `partial`
    /// holds (lossily decoded) buffered bytes that did not yet form a line.
    #[error("Error {error} reading response to cmd={}", self.command_for_err())]
    ReadError {
        command: Option<Command>,
        error: String,
        partial: String,
    },
    /// Flushing buffered writes before reading a response failed.
    #[error("Error {error} flushing send buffer")]
    FlushError { error: String },
    /// Writing failed; `commands` is empty when message data was being written.
    #[error("Error {error} writing {}", self.command_for_err())]
    WriteError {
        commands: Vec<Command>,
        error: String,
    },
    /// Writing message payload data timed out.
    #[error("Timed Out sending message payload data")]
    TimeOutData,
    /// An OpenSSL error.
    #[error("SSL Error: {0}")]
    SslErrorStack(#[from] openssl::error::ErrorStack),
    /// Not constructed in this file; presumably raised when none of the DANE
    /// TLSA records for `hostname` can be used — TODO confirm at call sites.
    #[error("No usable DANE TLSA records for {hostname}: {tlsa:?}")]
    NoUsableDaneTlsa { hostname: String, tlsa: Vec<TLSA> },
}
76
77impl ClientError {
78 pub fn command(&self) -> Option<String> {
80 match self {
81 Self::TimeOutResponse {
82 command: Some(command),
83 ..
84 }
85 | Self::ReadError {
86 command: Some(command),
87 ..
88 } => Some(command.encode()),
89 Self::TimeOutRequest { commands, .. } | Self::WriteError { commands, .. }
90 if !commands.is_empty() =>
91 {
92 let commands: Vec<String> = commands.into_iter().map(|cmd| cmd.encode()).collect();
93 Some(commands.join(""))
94 }
95 _ => None,
96 }
97 }
98
99 fn command_for_err(&self) -> String {
102 self.command()
103 .map(|cmd| cmd.replace("\r\n", ""))
104 .unwrap_or_else(|| "NONE".to_string())
105 }
106
107 pub fn was_due_to_message(&self) -> bool {
119 match self {
120 Self::Utf8(_)
121 | Self::MalformedResponseLine(_)
122 | Self::ResponseTooLong
123 | Self::NotConnected
124 | Self::InvalidDnsName(_)
125 | Self::FailedToBuildConnector { .. }
126 | Self::TimeOutResponse { .. }
127 | Self::TimeOutRequest { .. }
128 | Self::ReadError { .. }
129 | Self::FlushError { .. }
130 | Self::WriteError { .. }
131 | Self::TimeOutData
132 | Self::SslErrorStack(_)
133 | Self::NoUsableDaneTlsa { .. } => false,
134 Self::Rejected(response) => response.was_due_to_message(),
135 Self::RejectedBatch(responses) => match responses.len() {
136 1 => responses[0].was_due_to_message(),
137 _ => false,
138 },
139 }
140 }
141}
142
/// One capability line from an EHLO/LHLO response.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EsmtpCapability {
    /// The keyword exactly as the server sent it (the capability map keys
    /// use the upper-cased form).
    pub name: String,
    /// Everything after the first space on the line, if any.
    pub param: Option<String>,
}
148
/// Events reported to an `SmtpClientTracer`.
#[derive(Clone, Debug)]
pub enum SmtpClientTraceEvent {
    /// The client's socket was closed or dropped.
    Closed,
    /// Raw bytes received from the peer.
    Read(Vec<u8>),
    /// Data written to the peer (binary payloads are lossily decoded as UTF-8).
    Write(String),
    /// A client-side diagnostic, such as an I/O error, a timeout or the
    /// STARTTLS outcome.
    Diagnostic {
        level: tracing::Level,
        message: String,
    },
}
159
/// Builds a trace event on request. It is passed to
/// `SmtpClientTracer::lazy_trace` so that a tracer can presumably skip the
/// conversion cost when it does not need the event — depends on the
/// implementation.
pub trait DeferredTracer {
    fn trace(&self) -> SmtpClientTraceEvent;
}
163
/// Observer for protocol traffic and diagnostics of an `SmtpClient`.
pub trait SmtpClientTracer: std::fmt::Debug {
    /// Receives an already-constructed event.
    fn trace_event(&self, event: SmtpClientTraceEvent);
    /// Receives a deferred event; call `deferred.trace()` to materialize it.
    fn lazy_trace(&self, deferred: &dyn DeferredTracer);
}
168
169struct WriteTracer<'a> {
171 data: &'a str,
172}
173impl DeferredTracer for WriteTracer<'_> {
174 fn trace(&self) -> SmtpClientTraceEvent {
175 SmtpClientTraceEvent::Write(self.data.to_string())
176 }
177}
178impl<'a> WriteTracer<'a> {
179 fn trace(tracer: &Arc<dyn SmtpClientTracer + Send + Sync>, data: &'a str) {
180 tracer.lazy_trace(&Self { data });
181 }
182}
183
184struct BinWriteTracer<'a> {
185 data: &'a [u8],
186}
187impl DeferredTracer for BinWriteTracer<'_> {
188 fn trace(&self) -> SmtpClientTraceEvent {
189 let data = String::from_utf8_lossy(self.data).to_string();
190 SmtpClientTraceEvent::Write(data)
191 }
192}
193impl<'a> BinWriteTracer<'a> {
194 fn trace(tracer: &Arc<dyn SmtpClientTracer + Send + Sync>, data: &'a [u8]) {
195 tracer.lazy_trace(&Self { data });
196 }
197}
198
199struct ReadTracer<'a> {
202 data: &'a [u8],
203}
204impl DeferredTracer for ReadTracer<'_> {
205 fn trace(&self) -> SmtpClientTraceEvent {
206 SmtpClientTraceEvent::Read(self.data.to_vec())
207 }
208}
209
/// An SMTP/LMTP client session over a single (possibly TLS-upgraded) stream.
#[derive(Debug)]
pub struct SmtpClient {
    /// The connection; set to `None` when it is dropped after an I/O error,
    /// a (non-zero) timeout or EOF.
    socket: Option<BoxedAsyncReadAndWrite>,
    /// Normalized peer hostname, used as the TLS server name and in trace logs.
    hostname: String,
    /// Capabilities from the last successful EHLO/LHLO, keyed by upper-cased name.
    capabilities: HashMap<String, EsmtpCapability>,
    /// Received bytes not yet consumed as a complete CRLF-terminated line.
    read_buffer: Vec<u8>,
    /// Per-command timeouts.
    timeouts: SmtpClientTimeouts,
    /// Optional observer of traffic and diagnostics.
    tracer: Option<Arc<dyn SmtpClientTracer + Send + Sync>>,
    /// Whether the next transaction must begin with RSET.
    use_rset: bool,
    /// Whether RSET is also sent after a transaction that succeeded.
    enable_rset: bool,
    /// Whether PIPELINING may be used when the server advertises it.
    enable_pipelining: bool,
    /// Skip the 8BITMIME/SMTPUTF8 capability checks when sending.
    ignore_8bit_checks: bool,
}
223
/// Normalizes a peer address string into a bare hostname or IP literal.
///
/// Accepts `host`, `host:port`, `[literal]` and `[literal]:port`, and strips
/// one trailing `.` from the result. Bracketed input is handled first, so an
/// IPv6 literal such as `[::1]` keeps its colons even without a port. For
/// unbracketed input, the text after the last `:` is treated as a port
/// (so `::1:25` yields `::1`).
fn extract_hostname(hostname: &str) -> &str {
    let host = match hostname
        .strip_prefix('[')
        .and_then(|rest| rest.split_once(']'))
    {
        // Everything inside the brackets; whatever follows `]` (normally
        // `:port`) is ignored.
        Some((literal, _suffix)) => literal,
        None => hostname
            .rsplit_once(':')
            .map_or(hostname, |(host, _port)| host),
    };

    host.strip_suffix('.').unwrap_or(host)
}
242
243impl SmtpClient {
244 pub async fn new<A: ToSocketAddrs + ToString + Clone>(
245 addr: A,
246 timeouts: SmtpClientTimeouts,
247 ) -> std::io::Result<Self> {
248 let stream = TcpStream::connect(addr.clone()).await?;
249 stream.set_nodelay(true)?;
251 Ok(Self::with_stream(stream, addr.to_string(), timeouts))
252 }
253
254 pub fn with_stream<S: AsyncReadAndWrite + 'static, H: AsRef<str>>(
255 stream: S,
256 peer_hostname: H,
257 timeouts: SmtpClientTimeouts,
258 ) -> Self {
259 let hostname = extract_hostname(peer_hostname.as_ref()).to_string();
260
261 Self {
262 socket: Some(Box::new(stream)),
263 hostname,
264 capabilities: HashMap::new(),
265 read_buffer: Vec::with_capacity(1024),
266 timeouts,
267 tracer: None,
268 use_rset: false,
269 enable_rset: false,
270 enable_pipelining: false,
271 ignore_8bit_checks: false,
272 }
273 }
274
275 pub fn set_ignore_8bit_checks(&mut self, enable: bool) {
277 self.ignore_8bit_checks = enable;
278 }
279
280 pub fn is_connected(&self) -> bool {
281 self.socket.is_some()
282 }
283
284 pub fn set_enable_rset(&mut self, enable: bool) {
285 self.enable_rset = enable;
286 }
287
288 pub fn set_enable_pipelining(&mut self, enable: bool) {
289 self.enable_pipelining = enable;
290 }
291
292 pub fn set_tracer(&mut self, tracer: Arc<dyn SmtpClientTracer + Send + Sync>) {
293 self.tracer.replace(tracer);
294 }
295
296 pub fn timeouts(&self) -> &SmtpClientTimeouts {
297 &self.timeouts
298 }
299
300 async fn read_line(
301 &mut self,
302 timeout_duration: Duration,
303 cmd: Option<&Command>,
304 ) -> Result<String, ClientError> {
305 let mut too_long = false;
306 loop {
307 let mut iter = self.read_buffer.iter().enumerate();
308 while let Some((i, &b)) = iter.next() {
309 if b != b'\r' {
310 continue;
311 }
312 if let Some((_, b'\n')) = iter.next() {
313 if too_long {
314 self.read_buffer.drain(0..i + 2);
315
316 if let Some(tracer) = &self.tracer {
317 tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
318 level: Level::ERROR,
319 message: "Response too long".to_string(),
320 });
321 }
322
323 return Err(ClientError::ResponseTooLong);
324 }
325
326 let line = String::from_utf8(self.read_buffer[0..i].to_vec());
327 self.read_buffer.drain(0..i + 2);
328 return Ok(line?);
329 }
330 }
331 if self.read_buffer.len() > MAX_LINE_LEN {
332 self.read_buffer.clear();
333 too_long = true;
334 }
335
336 let mut data = [0u8; MAX_LINE_LEN];
338 let size = match self.socket.as_mut() {
339 Some(s) => match timeout(timeout_duration, s.read(&mut data)).await {
340 Ok(Ok(size)) => size,
341 Ok(Err(err)) => {
342 self.socket.take();
343 if let Some(tracer) = &self.tracer {
344 tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
345 level: Level::ERROR,
346 message: format!("Error during read: {err:#}"),
347 });
348 tracer.trace_event(SmtpClientTraceEvent::Closed);
349 }
350 return Err(ClientError::ReadError {
351 command: cmd.cloned(),
352 error: format!("{err:#}"),
353 partial: String::from_utf8_lossy(&self.read_buffer).to_string(),
354 });
355 }
356 Err(_) => {
357 if timeout_duration != Duration::ZERO {
362 self.socket.take();
363 }
364 if let Some(tracer) = &self.tracer {
365 tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
366 level: Level::ERROR,
367 message: format!("Read Timeout after {timeout_duration:?}"),
368 });
369 if self.socket.is_none() {
370 tracer.trace_event(SmtpClientTraceEvent::Closed);
371 }
372 }
373 return Err(ClientError::TimeOutResponse {
374 command: cmd.cloned(),
375 duration: timeout_duration,
376 });
377 }
378 },
379 None => {
380 return Err(ClientError::ReadError {
381 command: cmd.cloned(),
382 error: "the socket was closed in response to an earlier issue".to_string(),
383 partial: String::from_utf8_lossy(&self.read_buffer).to_string(),
384 });
385 }
386 };
387 if size == 0 {
388 self.socket.take();
389 if let Some(tracer) = &self.tracer {
390 tracer.trace_event(SmtpClientTraceEvent::Closed);
391 }
392 return Err(ClientError::ReadError {
393 command: cmd.cloned(),
394 error: "Connection closed by peer".to_string(),
395 partial: String::from_utf8_lossy(&self.read_buffer).to_string(),
396 });
397 }
398 if let Some(tracer) = &self.tracer {
399 tracer.lazy_trace(&ReadTracer {
400 data: &data[0..size],
401 });
402 }
403 self.read_buffer.extend_from_slice(&data[0..size]);
404 }
405 }
406
407 pub async fn check_unilateral_response(&mut self) -> Result<Option<Response>, ClientError> {
415 match self.read_response(None, Duration::ZERO).await {
416 Ok(response) => Ok(Some(response)),
418 Err(ClientError::TimeOutResponse { .. }) => Ok(None),
420 Err(err) => Err(err),
422 }
423 }
424
425 pub async fn read_response(
426 &mut self,
427 command: Option<&Command>,
428 timeout_duration: Duration,
429 ) -> Result<Response, ClientError> {
430 if let Some(sock) = self.socket.as_mut() {
431 if let Err(err) = sock.flush().await {
432 self.socket.take();
433 if let Some(tracer) = &self.tracer {
434 tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
435 level: Level::ERROR,
436 message: format!("Error during flush: {err:#}"),
437 });
438 tracer.trace_event(SmtpClientTraceEvent::Closed);
439 }
440 return Err(ClientError::FlushError {
441 error: format!("{err:#}"),
442 });
443 }
444 }
445
446 let mut line = self.read_line(timeout_duration, command).await?;
447 tracing::trace!("recv<-{}: {line}", self.hostname);
448 let mut parsed = parse_response_line(&line)?;
449 let mut response_builder = ResponseBuilder::new(&parsed);
450
451 let subsequent_line_timeout_duration = Duration::from_secs(60).min(timeout_duration);
452 while !parsed.is_final {
453 line = self
454 .read_line(subsequent_line_timeout_duration, command)
455 .await?;
456 parsed = parse_response_line(&line)?;
457 response_builder
458 .add_line(&parsed)
459 .map_err(ClientError::MalformedResponseLine)?;
460 }
461
462 let response = response_builder.build(command.map(|cmd| cmd.encode()));
463
464 tracing::trace!("{}: {response:?}", self.hostname);
465
466 Ok(response)
467 }
468
469 pub async fn send_command(&mut self, command: &Command) -> Result<Response, ClientError> {
470 self.write_command_request(command).await?;
471 self.read_response(Some(command), command.client_timeout(&self.timeouts))
472 .await
473 }
474
475 async fn write_all_with_timeout<F, G>(
479 &mut self,
480 timeout_duration: Duration,
481 bytes: &[u8],
482 make_timeout_err: F,
483 make_write_err: G,
484 ) -> Result<(), ClientError>
485 where
486 F: FnOnce() -> ClientError,
487 G: FnOnce(String) -> ClientError,
488 {
489 match self.socket.as_mut() {
490 Some(socket) => match timeout(timeout_duration, socket.write_all(bytes)).await {
491 Ok(Ok(response)) => Ok(response),
492 Ok(Err(err)) => {
493 self.socket.take();
494 if let Some(tracer) = &self.tracer {
495 tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
496 level: Level::ERROR,
497 message: format!("Error during write: {err:#}"),
498 });
499 tracer.trace_event(SmtpClientTraceEvent::Closed);
500 }
501 Err(make_write_err(format!("{err:#}")))
502 }
503 Err(_) => {
504 self.socket.take();
505 if let Some(tracer) = &self.tracer {
506 tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
507 level: Level::ERROR,
508 message: format!("Write Timeout after {timeout_duration:?}"),
509 });
510 tracer.trace_event(SmtpClientTraceEvent::Closed);
511 }
512 Err(make_timeout_err())
513 }
514 },
515 None => Err(make_write_err(
516 "the socket was closed in response to an earlier issue".to_string(),
517 )),
518 }
519 }
520
521 async fn write_pipeline_request(&mut self, commands: &[Command]) -> Result<(), ClientError> {
522 let total_timeout: Duration = commands
523 .iter()
524 .map(|cmd| cmd.client_timeout_request(&self.timeouts))
525 .sum();
526
527 let mut lines: Vec<String> = vec![];
528 let mut all = String::new();
529 for cmd in commands {
530 let line = cmd.encode();
531 all.push_str(&line);
532 lines.push(line);
533 }
534 tracing::trace!("send->{}: (PIPELINE) {all}", self.hostname);
535 if self.socket.is_some() {
536 if let Some(tracer) = &self.tracer {
537 for line in lines {
540 WriteTracer::trace(tracer, &line);
541 }
542 }
543 }
544 self.write_all_with_timeout(
545 total_timeout,
546 all.as_bytes(),
547 || ClientError::TimeOutRequest {
548 duration: total_timeout,
549 commands: commands.to_vec(),
550 },
551 |error| ClientError::WriteError {
552 error,
553 commands: commands.to_vec(),
554 },
555 )
556 .await
557 }
558
559 async fn write_command_request(&mut self, command: &Command) -> Result<(), ClientError> {
560 let line = command.encode();
561 tracing::trace!("send->{}: {line}", self.hostname);
562 if self.socket.is_some() {
563 if let Some(tracer) = &self.tracer {
564 WriteTracer::trace(tracer, &line);
565 }
566 }
567
568 let timeout_duration = command.client_timeout_request(&self.timeouts);
569 self.write_all_with_timeout(
570 timeout_duration,
571 line.as_bytes(),
572 || ClientError::TimeOutRequest {
573 duration: timeout_duration,
574 commands: vec![command.clone()],
575 },
576 |error| ClientError::WriteError {
577 error,
578 commands: vec![command.clone()],
579 },
580 )
581 .await
582 }
583
584 async fn write_data_with_timeout(&mut self, data: &[u8]) -> Result<(), ClientError> {
585 if self.socket.is_some() {
586 if let Some(tracer) = &self.tracer {
587 BinWriteTracer::trace(tracer, data);
588 }
589 }
590 let timeout_duration = Command::Data.client_timeout_request(&self.timeouts);
591 self.write_all_with_timeout(
592 timeout_duration,
593 data,
594 || ClientError::TimeOutData,
595 |error| ClientError::WriteError {
596 error,
597 commands: vec![],
598 },
599 )
600 .await
601 }
602
603 pub async fn pipeline_commands(
619 &mut self,
620 commands: Vec<Command>,
621 ) -> Vec<Result<Response, ClientError>> {
622 let mut results: Vec<Result<Response, ClientError>> = vec![];
623
624 let pipeline = self.enable_pipelining && self.capabilities.contains_key("PIPELINING");
625 if pipeline {
626 if let Err(err) = self.write_pipeline_request(&commands).await {
627 let err: ClientError = err;
628 results.push(Err(err.clone()));
629 while results.len() < commands.len() {
630 results.push(Err(err.clone()));
632 }
633 return results;
634 }
635
636 for cmd in &commands {
638 results.push(
639 self.read_response(Some(cmd), cmd.client_timeout(&self.timeouts))
640 .await,
641 );
642 }
643 return results;
644 }
645
646 for cmd in &commands {
647 if let Err(err) = self.write_command_request(cmd).await {
648 let err: ClientError = err;
649 results.push(Err(err.clone()));
650 while results.len() < commands.len() {
651 results.push(Err(err.clone()));
653 }
654 return results;
655 }
656 results.push(
659 self.read_response(Some(cmd), cmd.client_timeout(&self.timeouts))
660 .await,
661 );
662 }
663 results
664 }
665
666 pub async fn ehlo_lhlo(
667 &mut self,
668 ehlo_name: &str,
669 use_lmtp: bool,
670 ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
671 if use_lmtp {
672 self.lhlo(ehlo_name).await
673 } else {
674 self.ehlo(ehlo_name).await
675 }
676 }
677
678 pub async fn lhlo(
679 &mut self,
680 ehlo_name: &str,
681 ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
682 let response = self
683 .send_command(&Command::Lhlo(Domain::Name(ehlo_name.to_string())))
684 .await?;
685 self.ehlo_common(response)
686 }
687
688 pub async fn ehlo(
689 &mut self,
690 ehlo_name: &str,
691 ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
692 let response = self
693 .send_command(&Command::Ehlo(Domain::Name(ehlo_name.to_string())))
694 .await?;
695 self.ehlo_common(response)
696 }
697
698 fn ehlo_common(
699 &mut self,
700 response: Response,
701 ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
702 if response.code != 250 {
703 return Err(ClientError::Rejected(response));
704 }
705
706 let mut capabilities = HashMap::new();
707
708 for line in response.content.lines().skip(1) {
709 let mut fields = line.splitn(2, ' ');
710 if let Some(name) = fields.next() {
711 let param = fields.next().map(|s| s.to_string());
712 let cap = EsmtpCapability {
713 name: name.to_string(),
714 param,
715 };
716 capabilities.insert(name.to_ascii_uppercase(), cap);
717 }
718 }
719
720 self.capabilities = capabilities;
721 Ok(&self.capabilities)
722 }
723
724 pub async fn auth_plain(
725 &mut self,
726 username: &str,
727 password: Option<&str>,
728 ) -> Result<(), ClientError> {
729 let password = password.unwrap_or("");
732 let payload = format!("\x00{username}\x00{password}");
733 let payload = data_encoding::BASE64.encode(payload.as_bytes());
734
735 let response = self
736 .send_command(&Command::Auth {
737 sasl_mech: "PLAIN".to_string(),
738 initial_response: Some(payload),
739 })
740 .await?;
741
742 if response.code != 235 {
743 return Err(ClientError::Rejected(response));
744 }
745
746 Ok(())
747 }
748
749 pub async fn starttls(&mut self, options: TlsOptions) -> Result<TlsStatus, ClientError> {
755 let resp = self.send_command(&Command::StartTls).await?;
756 if resp.code != 220 {
757 return Err(ClientError::Rejected(resp));
758 }
759
760 let mut handshake_error = None;
761 let mut tls_info = TlsInformation::default();
762
763 let stream: BoxedAsyncReadAndWrite = if options.prefer_openssl
764 || !options.dane_tlsa.is_empty()
765 {
766 let connector = options
767 .build_openssl_connector(&self.hostname)
768 .map_err(|error| ClientError::FailedToBuildConnector {
769 error: error.to_string(),
770 })?;
771 let ssl = connector.into_ssl(self.hostname.as_str())?;
772
773 let (stream, dup_stream) = match self.socket.take() {
774 Some(s) => {
775 let d = s.try_dup();
776 (s, d)
777 }
778 None => return Err(ClientError::NotConnected),
779 };
780
781 let mut ssl_stream = tokio_openssl::SslStream::new(ssl, stream)?;
782
783 if let Err(err) = std::pin::Pin::new(&mut ssl_stream).connect().await {
784 handshake_error.replace(format!("{err:#}"));
785 }
786
787 tls_info.provider_name = "openssl".to_string();
788 tls_info.cipher = match ssl_stream.ssl().current_cipher() {
789 Some(cipher) => cipher.standard_name().unwrap_or(cipher.name()).to_string(),
790 None => String::new(),
791 };
792 tls_info.protocol_version = ssl_stream.ssl().version_str().to_string();
793
794 if let Some(cert) = ssl_stream.ssl().peer_certificate() {
795 tls_info.subject_name = subject_name(&cert);
796 }
797 if let Ok(authority) = ssl_stream.ssl().dane_authority() {
798 if let Some(cert) = &authority.cert {
799 tls_info.subject_name = subject_name(cert);
800 }
801 }
802
803 match (&handshake_error, dup_stream) {
804 (Some(_), Some(dup_stream)) if !ssl_stream.ssl().is_init_finished() => {
805 drop(ssl_stream);
810 Box::new(dup_stream)
811 }
812 _ => Box::new(ssl_stream),
813 }
814 } else {
815 tls_info.provider_name = "rustls".to_string();
816 let connector = options.build_tls_connector().await.map_err(|error| {
817 ClientError::FailedToBuildConnector {
818 error: error.to_string(),
819 }
820 })?;
821 let server_name = match IpAddr::from_str(self.hostname.as_str()) {
822 Ok(ip) => ServerName::IpAddress(ip.into()),
823 Err(_) => ServerName::try_from(self.hostname.clone())
824 .map_err(|_| ClientError::InvalidDnsName(self.hostname.clone()))?,
825 };
826
827 match connector
828 .connect(
829 server_name,
830 match self.socket.take() {
831 Some(s) => s,
832 None => return Err(ClientError::NotConnected),
833 },
834 )
835 .into_fallible()
836 .await
837 {
838 Ok(stream) => {
839 let (_, conn) = stream.get_ref();
840 tls_info.cipher = match conn.negotiated_cipher_suite() {
841 Some(suite) => suite.suite().as_str().unwrap_or("UNKNOWN").to_string(),
842 None => String::new(),
843 };
844 tls_info.protocol_version = match conn.protocol_version() {
845 Some(version) => version.as_str().unwrap_or("UNKNOWN").to_string(),
846 None => String::new(),
847 };
848
849 if let Some(certs) = conn.peer_certificates() {
850 let peer_cert = &certs[0];
851 if let Ok(cert) = X509::from_der(peer_cert.as_ref()) {
852 tls_info.subject_name = subject_name(&cert);
853 }
854 }
855
856 Box::new(stream)
857 }
858 Err((err, stream)) => {
859 handshake_error.replace(format!("{err:#}"));
860 stream
861 }
862 }
863 };
864
865 if let Some(tracer) = &self.tracer {
866 tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
867 level: Level::INFO,
868 message: match &handshake_error {
869 Some(error) => format!("STARTTLS handshake failed: {error:?}"),
870 None => format!("STARTTLS handshake -> {tls_info:?}"),
871 },
872 });
873 }
874
875 self.socket.replace(stream);
876 Ok(match handshake_error {
877 Some(error) => TlsStatus::FailedHandshake(error),
878 None => TlsStatus::Info(tls_info),
879 })
880 }
881
882 pub async fn send_mail<B: AsRef<[u8]>, SENDER: Into<ReversePath>, RECIP: Into<ForwardPath>>(
883 &mut self,
884 sender: SENDER,
885 recipient: RECIP,
886 data: B,
887 ) -> Result<Response, ClientError> {
888 let recipient: ForwardPath = recipient.into();
889 let status = self
890 .send_mail_multi_recip(sender, vec![recipient], data)
891 .await?;
892 Ok(status.response)
893 }
894
895 pub async fn send_mail_multi_recip<B: AsRef<[u8]>, SENDER: Into<ReversePath>>(
896 &mut self,
897 sender: SENDER,
898 recipient_list: Vec<ForwardPath>,
899 data: B,
900 ) -> Result<BatchSendSuccess, ClientError> {
901 let sender = sender.into();
902
903 let data: &[u8] = data.as_ref();
904 let stuffed;
905
906 let data = match apply_dot_stuffing(data) {
907 Some(d) => {
908 stuffed = d;
909 &stuffed
910 }
911 None => data,
912 };
913
914 let data_is_8bit = data.iter().any(|&b| b >= 0x80);
915 let envelope_is_8bit =
916 !sender.is_ascii() || recipient_list.iter().any(|recipient| !recipient.is_ascii());
917
918 let mut mail_from_params = vec![];
919 if data_is_8bit {
920 if self.capabilities.contains_key("8BITMIME") {
921 mail_from_params.push(EsmtpParameter {
922 name: "BODY".to_string(),
923 value: Some("8BITMIME".to_string()),
924 });
925 } else if !self.ignore_8bit_checks {
926 return Err(ClientError::Rejected(Response {
927 code: 554,
928 command: None,
929 enhanced_code: Some(EnhancedStatusCode {
930 class: 5,
931 subject: 6,
932 detail: 3,
933 }),
934 content: "KumoMTA internal: DATA is 8bit, destination does \
935 not support 8BITMIME. Conversion via msg:check_fix_conformance \
936 during reception is required"
937 .to_string(),
938 }));
939 }
940 }
941
942 if envelope_is_8bit {
943 if self.capabilities.contains_key("SMTPUTF8") {
944 mail_from_params.push(EsmtpParameter {
945 name: "SMTPUTF8".to_string(),
946 value: None,
947 });
948 } else if !self.ignore_8bit_checks {
949 return Err(ClientError::Rejected(Response {
950 code: 554,
951 command: None,
952 enhanced_code: Some(EnhancedStatusCode {
953 class: 5,
954 subject: 6,
955 detail: 7,
956 }),
957 content: "KumoMTA internal: envelope is 8bit, destination does \
958 not support SMTPUTF8."
959 .to_string(),
960 }));
961 }
962 }
963
964 let mut commands = vec![];
965
966 let used_rset = self.use_rset;
976 if self.use_rset {
977 commands.push(Command::Rset);
978 }
979 commands.push(Command::MailFrom {
980 address: sender,
981 parameters: mail_from_params,
982 });
983
984 for recipient in &recipient_list {
985 commands.push(Command::RcptTo {
986 address: recipient.clone(),
987 parameters: vec![],
988 });
989 }
990 commands.push(Command::Data);
991
992 self.use_rset = true;
995
996 let mut responses = self.pipeline_commands(commands).await;
997
998 let is_err = responses.iter().any(|r| r.is_err());
1006
1007 if used_rset {
1008 let rset_resp = responses.remove(0)?;
1009 if rset_resp.code != 250 {
1010 return Err(ClientError::Rejected(rset_resp));
1011 }
1012 }
1013
1014 let mail_resp = responses.remove(0)?;
1015 if is_err && mail_resp.code != 250 {
1016 return Err(ClientError::Rejected(mail_resp));
1017 }
1018
1019 let mut rcpt_responses = vec![];
1020 for _ in &recipient_list {
1021 rcpt_responses.push(responses.remove(0)?);
1022 }
1023
1024 if is_err && rcpt_responses.iter().all(|resp| resp.code != 250) {
1025 return Err(ClientError::RejectedBatch(rcpt_responses));
1026 }
1027
1028 let data_resp = responses.remove(0)?;
1029 if is_err && data_resp.code != 354 {
1030 return Err(ClientError::Rejected(data_resp));
1031 }
1032
1033 if data_resp.code == 354
1034 && (mail_resp.code != 250 || rcpt_responses.iter().all(|resp| resp.code != 250))
1035 {
1036 self.write_data_with_timeout(b".\r\n").await?;
1045 let data_dot = Command::DataDot;
1046 let _ = self
1048 .read_response(Some(&data_dot), data_dot.client_timeout(&self.timeouts))
1049 .await?;
1050
1051 }
1055
1056 if mail_resp.code != 250 {
1057 return Err(ClientError::Rejected(mail_resp));
1058 }
1059 if rcpt_responses.iter().all(|resp| resp.code != 250) {
1060 if rcpt_responses.len() == 1 {
1061 return Err(ClientError::Rejected(
1062 rcpt_responses.pop().expect("have at least one"),
1063 ));
1064 }
1065 return Err(ClientError::RejectedBatch(rcpt_responses));
1066 }
1067 if data_resp.code != 354 {
1068 return Err(ClientError::Rejected(data_resp));
1069 }
1070
1071 let needs_newline = data.last().map(|&b| b != b'\n').unwrap_or(true);
1072
1073 tracing::trace!("message data is {} bytes", data.len());
1074
1075 self.write_data_with_timeout(data).await?;
1076
1077 let marker = if needs_newline { "\r\n.\r\n" } else { ".\r\n" };
1078
1079 tracing::trace!("send->{}: {}", self.hostname, marker.escape_debug());
1080
1081 self.write_data_with_timeout(marker.as_bytes()).await?;
1082
1083 let data_dot = Command::DataDot;
1084 let resp = self
1085 .read_response(Some(&data_dot), data_dot.client_timeout(&self.timeouts))
1086 .await?;
1087 if resp.code != 250 {
1088 return Err(ClientError::Rejected(resp));
1089 }
1090
1091 self.use_rset = self.enable_rset;
1094
1095 Ok(BatchSendSuccess {
1096 response: resp,
1097 rcpt_responses,
1098 })
1099 }
1100}
1101
/// Result of a successful `SmtpClient::send_mail_multi_recip` transaction.
#[derive(Debug)]
pub struct BatchSendSuccess {
    /// The server's response to the terminating `.` of the message data.
    pub response: Response,
    /// The RCPT TO responses, in recipient order (some may be rejections).
    pub rcpt_responses: Vec<Response>,
}
1107
/// Outcome of `SmtpClient::starttls` after the server accepted STARTTLS.
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize, Clone)]
pub enum TlsStatus {
    /// The TLS handshake failed; carries the formatted handshake error.
    FailedHandshake(String),
    /// The handshake succeeded; carries details of the negotiated session.
    Info(TlsInformation),
}
1113
/// Details of a negotiated TLS session, as filled in by `SmtpClient::starttls`.
#[derive(Debug, Default, PartialEq, Eq, Serialize, Deserialize, Clone)]
pub struct TlsInformation {
    /// Cipher (suite) name; empty if none was negotiated.
    pub cipher: String,
    /// Protocol version string; empty if unknown (rustls).
    pub protocol_version: String,
    /// `SHORTNAME=value` entries from the peer certificate subject (for
    /// OpenSSL, the DANE authority certificate's subject takes precedence).
    pub subject_name: Vec<String>,
    /// `"openssl"` or `"rustls"`.
    pub provider_name: String,
}
1121
1122impl Drop for SmtpClient {
1123 fn drop(&mut self) {
1124 if let Some(tracer) = &self.tracer {
1125 if self.socket.is_some() {
1126 tracer.trace_event(SmtpClientTraceEvent::Closed);
1127 }
1128 }
1129 }
1130}
1131fn parse_response_line(line: &'_ str) -> Result<ResponseLine<'_>, ClientError> {
1132 if line.len() < 4 {
1133 return Err(ClientError::MalformedResponseLine(line.to_string()));
1134 }
1135
1136 match line.as_bytes()[3] {
1137 b' ' | b'-' => match line[0..3].parse::<u16>() {
1138 Ok(code) => Ok(ResponseLine {
1139 code,
1140 is_final: line.as_bytes()[3] == b' ',
1141 content: &line[4..],
1142 }),
1143 Err(_) => Err(ClientError::MalformedResponseLine(line.to_string())),
1144 },
1145 _ => Err(ClientError::MalformedResponseLine(line.to_string())),
1146 }
1147}
1148
1149fn apply_dot_stuffing(data: &[u8]) -> Option<Vec<u8>> {
1150 static LFDOT: LazyLock<Finder> = LazyLock::new(|| memchr::memmem::Finder::new("\n."));
1151
1152 if !data.starts_with(b".") && LFDOT.find(data).is_none() {
1153 return None;
1154 }
1155
1156 let mut stuffed = vec![];
1157 if data.starts_with(b".") {
1158 stuffed.push(b'.');
1159 }
1160 let mut last_idx = 0;
1161 for i in LFDOT.find_iter(data) {
1162 stuffed.extend_from_slice(&data[last_idx..=i]);
1163 stuffed.push(b'.');
1164 last_idx = i + 1;
1165 }
1166 stuffed.extend_from_slice(&data[last_idx..]);
1167 Some(stuffed)
1168}
1169
1170pub fn subject_name(cert: &X509Ref) -> Vec<String> {
1177 let mut subject_name = vec![];
1178 for entry in cert.subject_name().entries() {
1179 if let Ok(obj) = entry.object().nid().short_name() {
1180 if let Ok(name) = entry.data().as_utf8() {
1181 subject_name.push(format!("{obj}={name}"));
1182 }
1183 }
1184 }
1185 subject_name
1186}
1187
// Unit tests for the parsing and formatting helpers in this module.
#[cfg(test)]
mod test {
    use super::*;
    use crate::{MailPath, Mailbox};

    // Leading dots and dots after LF are doubled; data needing no change yields None.
    #[test]
    fn test_stuffing() {
        assert_eq!(apply_dot_stuffing(b"foo"), None);
        assert_eq!(apply_dot_stuffing(b".foo").unwrap(), b"..foo");
        assert_eq!(apply_dot_stuffing(b"foo\n.bar").unwrap(), b"foo\n..bar");
        assert_eq!(
            apply_dot_stuffing(b"foo\n.bar\n..baz\n").unwrap(),
            b"foo\n..bar\n...baz\n"
        );
    }

    // Space marks a final line, hyphen a continuation; other separators are rejected.
    #[test]
    fn response_line_parsing() {
        assert_eq!(
            parse_response_line("220 woot").unwrap(),
            ResponseLine {
                code: 220,
                is_final: true,
                content: "woot"
            }
        );
        assert_eq!(
            parse_response_line("220-woot").unwrap(),
            ResponseLine {
                code: 220,
                is_final: false,
                content: "woot"
            }
        );

        assert!(matches!(
            parse_response_line("220_woot"),
            Err(ClientError::MalformedResponseLine(_))
        ));
        assert!(matches!(
            parse_response_line("not really"),
            Err(ClientError::MalformedResponseLine(_))
        ));
    }

    // Helper: feeds lines through parse_response_line + ResponseBuilder the
    // same way SmtpClient::read_response does.
    fn parse_multi_line(lines: &[&str]) -> Result<Response, ClientError> {
        let mut parsed = parse_response_line(lines[0])?;
        let mut b = ResponseBuilder::new(&parsed);
        for line in &lines[1..] {
            parsed = parse_response_line(line)?;
            b.add_line(&parsed)
                .map_err(ClientError::MalformedResponseLine)?;
        }
        assert!(parsed.is_final);
        Ok(b.build(None))
    }

    // Multi-line joining, code-mismatch rejection and enhanced status code handling.
    #[test]
    fn multi_line_response() {
        assert_eq!(
            parse_multi_line(&["220-woot", "220-more", "220 done",]).unwrap(),
            Response {
                code: 220,
                enhanced_code: None,
                content: "woot\nmore\ndone".to_string(),
                command: None
            }
        );

        let res = parse_multi_line(&["220-woot", "221-more", "220 done"]).unwrap_err();
        assert!(
            matches!(
                res,
                ClientError::MalformedResponseLine(ref err) if err == "221-more"
            ),
            "got error {res:?}"
        );

        let res = parse_multi_line(&["220-woot", "220-more", "221 done"]).unwrap_err();
        assert!(
            matches!(
                res,
                ClientError::MalformedResponseLine(ref err) if err == "221 done"
            ),
            "got error {res:?}"
        );

        assert_eq!(
            parse_multi_line(&["220-4.1.0 woot", "220-more", "220 done",]).unwrap(),
            Response {
                code: 220,
                enhanced_code: Some(EnhancedStatusCode {
                    class: 4,
                    subject: 1,
                    detail: 0
                }),
                content: "woot\nmore\ndone".to_string(),
                command: None
            }
        );

        // A repeated, identical enhanced code on continuation lines is stripped.
        assert_eq!(
            parse_multi_line(&["220-4.1.0 woot", "220-4.1.0 more", "220 done",]).unwrap(),
            Response {
                code: 220,
                enhanced_code: Some(EnhancedStatusCode {
                    class: 4,
                    subject: 1,
                    detail: 0
                }),
                content: "woot\nmore\ndone".to_string(),
                command: None
            }
        );

        // A differing enhanced code on a later line is kept as plain text.
        assert_eq!(
            parse_multi_line(&["220-4.1.0 woot", "220-4.1.0 more", "220 5.5.5 done",]).unwrap(),
            Response {
                code: 220,
                enhanced_code: Some(EnhancedStatusCode {
                    class: 4,
                    subject: 1,
                    detail: 0
                }),
                content: "woot\nmore\n5.5.5 done".to_string(),
                command: None
            }
        );
    }

    // Port, brackets and trailing dot are removed from peer addresses.
    #[test]
    fn test_extract_hostname() {
        assert_eq!(extract_hostname("foo"), "foo");
        assert_eq!(extract_hostname("foo."), "foo");
        assert_eq!(extract_hostname("foo:25"), "foo");
        assert_eq!(extract_hostname("foo.:25"), "foo");
        assert_eq!(extract_hostname("[foo]:25"), "foo");
        assert_eq!(extract_hostname("[foo.]:25"), "foo");
        assert_eq!(extract_hostname("[::1]:25"), "::1");
        assert_eq!(extract_hostname("::1:25"), "::1");
    }

    // Error Display output shows the command(s) with CRLF removed.
    #[test]
    fn test_format_error_command() {
        assert_eq!(
            format!(
                "{:#}",
                ClientError::TimeOutRequest {
                    commands: vec![Command::DataDot],
                    duration: Duration::from_secs(10),
                }
            ),
            "Timed Out after 10s writing cmd=."
        );
        assert_eq!(
            format!(
                "{:#}",
                ClientError::TimeOutResponse {
                    command: Some(Command::MailFrom {
                        address: ReversePath::Path(MailPath {
                            at_domain_list: vec![],
                            mailbox: Mailbox {
                                local_part: "user".to_string(),
                                domain: Domain::Name("host".to_string())
                            }
                        }),
                        parameters: vec![],
                    }),
                    duration: Duration::from_secs(10),
                }
            ),
            r#"Timed Out waiting 10s for response to cmd=MAIL FROM:<user@host>"#
        );
    }
}