1#![allow(clippy::result_large_err)]
2use crate::client_types::*;
3use crate::parser::{Command, Domain, EsmtpParameter, ForwardPath, ReversePath};
4use crate::{AsyncReadAndWrite, BoxedAsyncReadAndWrite};
5use bstr::ByteSlice;
6use hickory_proto::rr::rdata::TLSA;
7use hickory_proto::rr::Name;
8use memchr::memmem::Finder;
9use nom_utils::DomainString;
10use openssl::x509::{X509Ref, X509};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::net::IpAddr;
14use std::str::FromStr;
15use std::sync::{Arc, LazyLock};
16use std::time::Duration;
17use thiserror::Error;
18use tokio::io::{AsyncReadExt, AsyncWriteExt};
19use tokio::net::{TcpStream, ToSocketAddrs};
20use tokio::time::{timeout, timeout_at, Instant};
21use tokio_rustls::rustls::pki_types::ServerName;
22use tracing::Level;
23
24pub use kumo_tls_helper::TlsOptions;
25pub use openssl;
26pub use tokio_rustls;
27
/// Size, in bytes, of the scratch buffer used for each socket read in
/// `SmtpClient::read_line`, and the threshold above which buffered data
/// that contains no line terminator is treated as an over-long response.
const MAX_LINE_LEN: usize = 4096;
29
/// Errors produced while driving an SMTP client session.
#[derive(Error, Debug, Clone)]
pub enum ClientError {
    /// A response line was not valid UTF-8.
    #[error("response is not UTF8")]
    Utf8(#[from] std::string::FromUtf8Error),
    /// A response line could not be parsed, or could not be appended to the
    /// response being assembled. Carries the offending line or a description
    /// of the problem.
    #[error("Malformed Response: {0}")]
    MalformedResponseLine(String),
    /// More than `MAX_LINE_LEN` bytes were buffered without finding a line
    /// terminator.
    #[error("Response line is too long")]
    ResponseTooLong,
    /// An operation needed the socket but it had already been closed.
    #[error("Not connected")]
    NotConnected,
    /// The server answered a command with an unexpected status code.
    #[error("Command rejected {0:?}")]
    Rejected(Response),
    /// Every recipient in a batch was refused; one response per recipient.
    #[error("Commands rejected {0:?}")]
    RejectedBatch(Vec<Response>),
    /// A name could not be parsed as a DNS name. Despite the message prefix
    /// this is also returned for an unparseable EHLO/LHLO name.
    #[error("STARTTLS: {0} is not a valid DNS name")]
    InvalidDnsName(String),
    /// The TLS connector could not be constructed from the supplied options.
    #[error("Invalid client certificate configured: {error:?}")]
    FailedToBuildConnector { error: String },
    /// No response (or TLS handshake completion) arrived within `duration`.
    #[error("Timed Out waiting {duration:?} for response to cmd={}", self.command_for_err())]
    TimeOutResponse {
        command: Option<Command>,
        duration: Duration,
    },
    /// Writing the listed command(s) did not complete within `duration`.
    #[error("Timed Out after {duration:?} writing cmd={}", self.command_for_err())]
    TimeOutRequest {
        commands: Vec<Command>,
        duration: Duration,
    },
    /// Reading a response failed; `partial` holds whatever was buffered
    /// (lossily decoded) at the time of the failure.
    #[error("Error {error} reading response to cmd={}", self.command_for_err())]
    ReadError {
        command: Option<Command>,
        error: String,
        partial: String,
    },
    /// Flushing pending output before reading a response failed.
    #[error("Error {error} flushing send buffer")]
    FlushError { error: String },
    /// Writing the listed command(s) failed. `commands` is empty when the
    /// failed write was message payload rather than a command.
    #[error("Error {error} writing {}", self.command_for_err())]
    WriteError {
        commands: Vec<Command>,
        error: String,
    },
    /// Writing message payload data timed out.
    #[error("Timed Out sending message payload data")]
    TimeOutData,
    /// An error reported by OpenSSL.
    #[error("SSL Error: {0}")]
    SslErrorStack(#[from] openssl::error::ErrorStack),
    /// None of the supplied DANE TLSA records were usable for `hostname`.
    #[error("No usable DANE TLSA records for {hostname}: {tlsa:?}")]
    NoUsableDaneTlsa { hostname: String, tlsa: Vec<TLSA> },
}
78
79impl ClientError {
80 pub fn command(&self) -> Option<String> {
82 match self {
83 Self::TimeOutResponse {
84 command: Some(command),
85 ..
86 }
87 | Self::ReadError {
88 command: Some(command),
89 ..
90 } => Some(command.encode().to_string()),
91 Self::TimeOutRequest { commands, .. } | Self::WriteError { commands, .. }
92 if !commands.is_empty() =>
93 {
94 let s: String = commands
95 .iter()
96 .map(|cmd| cmd.encode().to_string())
97 .collect();
98 Some(s)
99 }
100 _ => None,
101 }
102 }
103
104 fn command_for_err(&self) -> String {
107 self.command()
108 .map(|cmd| cmd.replace("\r\n", ""))
109 .unwrap_or_else(|| "NONE".to_string())
110 }
111
112 pub fn was_due_to_message(&self) -> bool {
124 match self {
125 Self::Utf8(_)
126 | Self::MalformedResponseLine(_)
127 | Self::ResponseTooLong
128 | Self::NotConnected
129 | Self::InvalidDnsName(_)
130 | Self::FailedToBuildConnector { .. }
131 | Self::TimeOutResponse { .. }
132 | Self::TimeOutRequest { .. }
133 | Self::ReadError { .. }
134 | Self::FlushError { .. }
135 | Self::WriteError { .. }
136 | Self::TimeOutData
137 | Self::SslErrorStack(_)
138 | Self::NoUsableDaneTlsa { .. } => false,
139 Self::Rejected(response) => response.was_due_to_message(),
140 Self::RejectedBatch(responses) => match responses.len() {
141 1 => responses[0].was_due_to_message(),
142 _ => false,
143 },
144 }
145 }
146}
147
/// One capability line from an EHLO/LHLO response.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EsmtpCapability {
    /// Capability keyword, exactly as the server sent it.
    pub name: String,
    /// Remainder of the line after the first space, if there was one.
    pub param: Option<String>,
}
153
/// Events reported to an [`SmtpClientTracer`].
#[derive(Clone, Debug)]
pub enum SmtpClientTraceEvent {
    /// The connection was closed.
    Closed,
    /// Bytes received from the peer.
    Read(Vec<u8>),
    /// Data written to the peer (binary payloads are lossily decoded).
    Write(String),
    /// A free-form diagnostic message with a severity level.
    Diagnostic {
        level: tracing::Level,
        message: String,
    },
}
164
/// A trace event whose construction is deferred until
/// [`DeferredTracer::trace`] is called, so the tracer decides whether to
/// pay for building it.
pub trait DeferredTracer {
    /// Builds the trace event.
    fn trace(&self) -> SmtpClientTraceEvent;
}
168
/// Receiver of trace events emitted by [`SmtpClient`].
pub trait SmtpClientTracer: std::fmt::Debug {
    /// Receives an event that has already been constructed.
    fn trace_event(&self, event: SmtpClientTraceEvent);
    /// Receives a deferred event; the implementation chooses whether to
    /// call [`DeferredTracer::trace`] to materialize it.
    fn lazy_trace(&self, deferred: &dyn DeferredTracer);
}
173
174struct WriteTracer<'a> {
176 data: &'a str,
177}
178impl DeferredTracer for WriteTracer<'_> {
179 fn trace(&self) -> SmtpClientTraceEvent {
180 SmtpClientTraceEvent::Write(self.data.to_string())
181 }
182}
183impl<'a> WriteTracer<'a> {
184 fn trace(tracer: &Arc<dyn SmtpClientTracer + Send + Sync>, data: &'a str) {
185 tracer.lazy_trace(&Self { data });
186 }
187}
188
189struct BinWriteTracer<'a> {
190 data: &'a [u8],
191}
192impl DeferredTracer for BinWriteTracer<'_> {
193 fn trace(&self) -> SmtpClientTraceEvent {
194 let data = String::from_utf8_lossy(self.data).to_string();
195 SmtpClientTraceEvent::Write(data)
196 }
197}
198impl<'a> BinWriteTracer<'a> {
199 fn trace(tracer: &Arc<dyn SmtpClientTracer + Send + Sync>, data: &'a [u8]) {
200 tracer.lazy_trace(&Self { data });
201 }
202}
203
204struct ReadTracer<'a> {
207 data: &'a [u8],
208}
209impl DeferredTracer for ReadTracer<'_> {
210 fn trace(&self) -> SmtpClientTraceEvent {
211 SmtpClientTraceEvent::Read(self.data.to_vec())
212 }
213}
214
/// An SMTP/LMTP client session over an arbitrary async stream.
#[derive(Debug)]
pub struct SmtpClient {
    /// The connection; set to `None` once it has been closed after an error.
    socket: Option<BoxedAsyncReadAndWrite>,
    /// Peer host name with any port, surrounding brackets and trailing dot
    /// removed; used for TLS and in log output.
    hostname: String,
    /// Capabilities from the most recent successful EHLO/LHLO, keyed by
    /// upper-cased capability name.
    capabilities: HashMap<String, EsmtpCapability>,
    /// Bytes received from the peer that have not yet been consumed as lines.
    read_buffer: Vec<u8>,
    timeouts: SmtpClientTimeouts,
    /// Optional sink for trace events.
    tracer: Option<Arc<dyn SmtpClientTracer + Send + Sync>>,
    /// Whether the next message transaction is prefixed with RSET.
    use_rset: bool,
    /// Value that `use_rset` is reset to after a successful send.
    enable_rset: bool,
    /// Pipelining is used only when this is set and the peer advertises it.
    enable_pipelining: bool,
    /// When set, 8-bit content is sent even if the peer lacks the
    /// corresponding extension.
    ignore_8bit_checks: bool,
}
228
/// Reduces a peer address such as `host:port`, `[v6]:port`, `[v6]` or a
/// bare host name to just the host.
///
/// The returned slice has any `:port` suffix, one pair of surrounding square
/// brackets and one trailing FQDN dot removed.
///
/// An IPv6 literal given without a port (bare or bracketed) is kept intact;
/// splitting it at its last colon would truncate the address itself.
fn extract_hostname(hostname: &str) -> &str {
    let host = if hostname.parse::<std::net::Ipv6Addr>().is_ok() {
        // Bare IPv6 literal: every colon belongs to the address.
        hostname
    } else if hostname.starts_with('[') && hostname.ends_with(']') {
        // Bracketed literal with no port: nothing to split off.
        hostname
    } else {
        // Everything after the last colon is the port.
        match hostname.rsplit_once(':') {
            Some((host, _port)) => host,
            None => hostname,
        }
    };

    let host = host
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(host);

    // Remove any trailing FQDN dot.
    host.strip_suffix('.').unwrap_or(host)
}
247
248impl SmtpClient {
249 pub async fn new<A: ToSocketAddrs + ToString + Clone>(
250 addr: A,
251 timeouts: SmtpClientTimeouts,
252 ) -> std::io::Result<Self> {
253 let stream = TcpStream::connect(addr.clone()).await?;
254 stream.set_nodelay(true)?;
256 Ok(Self::with_stream(stream, addr.to_string(), timeouts))
257 }
258
259 pub fn with_stream<S: AsyncReadAndWrite + 'static, H: AsRef<str>>(
260 stream: S,
261 peer_hostname: H,
262 timeouts: SmtpClientTimeouts,
263 ) -> Self {
264 let hostname = extract_hostname(peer_hostname.as_ref()).to_string();
265
266 Self {
267 socket: Some(Box::new(stream)),
268 hostname,
269 capabilities: HashMap::new(),
270 read_buffer: Vec::with_capacity(1024),
271 timeouts,
272 tracer: None,
273 use_rset: false,
274 enable_rset: false,
275 enable_pipelining: false,
276 ignore_8bit_checks: false,
277 }
278 }
279
    /// When enabled, 8-bit message data or envelope addresses are sent even
    /// if the peer does not advertise 8BITMIME / SMTPUTF8, instead of
    /// failing the send locally.
    pub fn set_ignore_8bit_checks(&mut self, enable: bool) {
        self.ignore_8bit_checks = enable;
    }
284
    /// Returns true while the client still holds its socket. The socket is
    /// discarded after most I/O errors and timeouts.
    pub fn is_connected(&self) -> bool {
        self.socket.is_some()
    }
288
    /// Sets whether, after a successful send, the next transaction is
    /// prefixed with RSET.
    pub fn set_enable_rset(&mut self, enable: bool) {
        self.enable_rset = enable;
    }
292
    /// Allows command pipelining. It is only used when the peer also
    /// advertises the PIPELINING capability.
    pub fn set_enable_pipelining(&mut self, enable: bool) {
        self.enable_pipelining = enable;
    }
296
    /// Installs `tracer` as the receiver of trace events, replacing any
    /// tracer set earlier.
    pub fn set_tracer(&mut self, tracer: Arc<dyn SmtpClientTracer + Send + Sync>) {
        self.tracer.replace(tracer);
    }
300
    /// Returns the timeout configuration this client was created with.
    pub fn timeouts(&self) -> &SmtpClientTimeouts {
        &self.timeouts
    }
304
    /// Returns the next CRLF-terminated line from the peer, without its
    /// terminator.
    ///
    /// A complete line already held in `read_buffer` is returned without
    /// touching the socket. Otherwise data is read from the socket, each
    /// read bounded by `timeout_duration`, until a terminator is found.
    ///
    /// `cmd` is only used to annotate the returned error.
    ///
    /// # Errors
    /// * `ResponseTooLong` if more than `MAX_LINE_LEN` bytes accumulated
    ///   without a terminator; the rest of that line is consumed first.
    /// * `Utf8` if the line is not valid UTF-8.
    /// * `ReadError` on an I/O error, when the peer closes the connection,
    ///   or when the socket was already closed.
    /// * `TimeOutResponse` if a read does not complete in time.
    ///
    /// I/O errors and peer close discard the socket. A timeout discards it
    /// too, unless `timeout_duration` is zero.
    async fn read_line(
        &mut self,
        timeout_duration: Duration,
        cmd: Option<&Command>,
    ) -> Result<String, ClientError> {
        let mut too_long = false;
        loop {
            // Look for CR immediately followed by LF in what we have so far.
            let mut iter = self.read_buffer.iter().enumerate();
            while let Some((i, &b)) = iter.next() {
                if b != b'\r' {
                    continue;
                }
                if let Some((_, b'\n')) = iter.next() {
                    if too_long {
                        // Discard the tail of the over-long line so that the
                        // buffer starts at the next line.
                        self.read_buffer.drain(0..i + 2);

                        if let Some(tracer) = &self.tracer {
                            tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
                                level: Level::ERROR,
                                message: "Response too long".to_string(),
                            });
                        }

                        return Err(ClientError::ResponseTooLong);
                    }

                    // Consume the line and its terminator even when the
                    // UTF-8 conversion fails.
                    let line = String::from_utf8(self.read_buffer[0..i].to_vec());
                    self.read_buffer.drain(0..i + 2);
                    return Ok(line?);
                }
            }
            if self.read_buffer.len() > MAX_LINE_LEN {
                // No terminator within the limit: drop what we have and keep
                // reading until the end of this line is seen.
                self.read_buffer.clear();
                too_long = true;
            }

            // No complete line yet; read more from the socket.
            let mut data = [0u8; MAX_LINE_LEN];
            let size = match self.socket.as_mut() {
                Some(s) => match timeout(timeout_duration, s.read(&mut data)).await {
                    Ok(Ok(size)) => size,
                    Ok(Err(err)) => {
                        self.socket.take();
                        if let Some(tracer) = &self.tracer {
                            tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
                                level: Level::ERROR,
                                message: format!("Error during read: {err:#}"),
                            });
                            tracer.trace_event(SmtpClientTraceEvent::Closed);
                        }
                        return Err(ClientError::ReadError {
                            command: cmd.cloned(),
                            error: format!("{err:#}"),
                            partial: String::from_utf8_lossy(&self.read_buffer).to_string(),
                        });
                    }
                    Err(_) => {
                        // A zero timeout is a non-blocking poll, so running
                        // out of time is not treated as fatal in that case.
                        if timeout_duration != Duration::ZERO {
                            self.socket.take();
                        }
                        if let Some(tracer) = &self.tracer {
                            tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
                                level: Level::ERROR,
                                message: format!("Read Timeout after {timeout_duration:?}"),
                            });
                            if self.socket.is_none() {
                                tracer.trace_event(SmtpClientTraceEvent::Closed);
                            }
                        }
                        return Err(ClientError::TimeOutResponse {
                            command: cmd.cloned(),
                            duration: timeout_duration,
                        });
                    }
                },
                None => {
                    return Err(ClientError::ReadError {
                        command: cmd.cloned(),
                        error: "the socket was closed in response to an earlier issue".to_string(),
                        partial: String::from_utf8_lossy(&self.read_buffer).to_string(),
                    });
                }
            };
            // A zero-length read means the peer closed the connection.
            if size == 0 {
                self.socket.take();
                if let Some(tracer) = &self.tracer {
                    tracer.trace_event(SmtpClientTraceEvent::Closed);
                }
                return Err(ClientError::ReadError {
                    command: cmd.cloned(),
                    error: "Connection closed by peer".to_string(),
                    partial: String::from_utf8_lossy(&self.read_buffer).to_string(),
                });
            }
            if let Some(tracer) = &self.tracer {
                tracer.lazy_trace(&ReadTracer {
                    data: &data[0..size],
                });
            }
            self.read_buffer.extend_from_slice(&data[0..size]);
        }
    }
411
412 pub async fn check_unilateral_response(&mut self) -> Result<Option<Response>, ClientError> {
420 match self.read_response(None, Duration::ZERO).await {
421 Ok(response) => Ok(Some(response)),
423 Err(ClientError::TimeOutResponse { .. }) => Ok(None),
425 Err(err) => Err(err),
427 }
428 }
429
    /// Flushes pending output and then reads one complete, possibly
    /// multi-line, response.
    ///
    /// `timeout_duration` bounds the whole operation through a single
    /// deadline; each continuation line is additionally capped at 60
    /// seconds. `command`, when given, is recorded in the response and in
    /// any error.
    ///
    /// # Errors
    /// * `FlushError` if flushing fails, `TimeOutResponse` if it times out;
    ///   both discard the socket.
    /// * Any error from `read_line`.
    /// * `MalformedResponseLine` if a line cannot be parsed or cannot be
    ///   added to the response being built.
    pub async fn read_response(
        &mut self,
        command: Option<&Command>,
        timeout_duration: Duration,
    ) -> Result<Response, ClientError> {
        let deadline = Instant::now() + timeout_duration;

        // NOTE(review): with a zero timeout (as passed by
        // check_unilateral_response) a flush that is not immediately
        // complete ends in the timeout arm below and discards the socket,
        // whereas read_line keeps the socket on a zero-timeout expiry.
        // Confirm that this asymmetry is intended.
        if let Some(sock) = self.socket.as_mut() {
            match timeout_at(deadline, sock.flush()).await {
                Ok(Ok(())) => {}
                Ok(Err(err)) => {
                    self.socket.take();
                    if let Some(tracer) = &self.tracer {
                        tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
                            level: Level::ERROR,
                            message: format!("Error during flush: {err:#}"),
                        });
                        tracer.trace_event(SmtpClientTraceEvent::Closed);
                    }
                    return Err(ClientError::FlushError {
                        error: format!("{err:#}"),
                    });
                }
                Err(_elapsed) => {
                    self.socket.take();
                    if let Some(tracer) = &self.tracer {
                        tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
                            level: Level::ERROR,
                            message: format!("Flush timed out after {timeout_duration:?}"),
                        });
                        tracer.trace_event(SmtpClientTraceEvent::Closed);
                    }
                    return Err(ClientError::TimeOutResponse {
                        command: command.cloned(),
                        duration: timeout_duration,
                    });
                }
            }
        }

        // The first line gets whatever remains of the overall budget.
        let mut line = self
            .read_line(deadline.saturating_duration_since(Instant::now()), command)
            .await?;
        tracing::trace!("recv<-{}: {line}", self.hostname);
        let mut parsed = parse_response_line(&line)?;
        let mut response_builder = ResponseBuilder::new(&parsed);

        while !parsed.is_final {
            // Continuation lines: at most 60 seconds each, and never past
            // the overall deadline.
            let per_line =
                Duration::from_secs(60).min(deadline.saturating_duration_since(Instant::now()));
            line = self.read_line(per_line, command).await?;
            parsed = parse_response_line(&line)?;
            response_builder
                .add_line(&parsed)
                .map_err(ClientError::MalformedResponseLine)?;
        }

        let response = response_builder.build(command.map(|cmd| cmd.encode().to_string()));

        tracing::trace!("{}: {response:?}", self.hostname);

        Ok(response)
    }
493
494 pub async fn send_command(&mut self, command: &Command) -> Result<Response, ClientError> {
495 self.write_command_request(command).await?;
496 self.read_response(Some(command), command.client_timeout(&self.timeouts))
497 .await
498 }
499
500 async fn write_all_with_timeout<F, G>(
504 &mut self,
505 timeout_duration: Duration,
506 bytes: &[u8],
507 make_timeout_err: F,
508 make_write_err: G,
509 ) -> Result<(), ClientError>
510 where
511 F: FnOnce() -> ClientError,
512 G: FnOnce(String) -> ClientError,
513 {
514 match self.socket.as_mut() {
515 Some(socket) => match timeout(timeout_duration, socket.write_all(bytes)).await {
516 Ok(Ok(response)) => Ok(response),
517 Ok(Err(err)) => {
518 self.socket.take();
519 if let Some(tracer) = &self.tracer {
520 tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
521 level: Level::ERROR,
522 message: format!("Error during write: {err:#}"),
523 });
524 tracer.trace_event(SmtpClientTraceEvent::Closed);
525 }
526 Err(make_write_err(format!("{err:#}")))
527 }
528 Err(_) => {
529 self.socket.take();
530 if let Some(tracer) = &self.tracer {
531 tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
532 level: Level::ERROR,
533 message: format!("Write Timeout after {timeout_duration:?}"),
534 });
535 tracer.trace_event(SmtpClientTraceEvent::Closed);
536 }
537 Err(make_timeout_err())
538 }
539 },
540 None => Err(make_write_err(
541 "the socket was closed in response to an earlier issue".to_string(),
542 )),
543 }
544 }
545
546 async fn write_pipeline_request(&mut self, commands: &[Command]) -> Result<(), ClientError> {
547 let total_timeout: Duration = commands
548 .iter()
549 .map(|cmd| cmd.client_timeout_request(&self.timeouts))
550 .sum();
551
552 let mut lines: Vec<String> = vec![];
553 let mut all: Vec<u8> = vec![];
554 for cmd in commands {
555 let line = cmd.encode();
556 all.extend_from_slice(&line);
557 lines.push(line.to_string());
558 }
559 tracing::trace!(
560 "send->{}: (PIPELINE) {}",
561 self.hostname,
562 all.as_bstr().escape_bytes()
563 );
564 if self.socket.is_some() {
565 if let Some(tracer) = &self.tracer {
566 for line in &lines {
569 WriteTracer::trace(tracer, line);
570 }
571 }
572 }
573 self.write_all_with_timeout(
574 total_timeout,
575 &all,
576 || ClientError::TimeOutRequest {
577 duration: total_timeout,
578 commands: commands.to_vec(),
579 },
580 |error| ClientError::WriteError {
581 error,
582 commands: commands.to_vec(),
583 },
584 )
585 .await
586 }
587
588 async fn write_command_request(&mut self, command: &Command) -> Result<(), ClientError> {
589 let line = command.encode();
590 tracing::trace!("send->{}: {line}", self.hostname);
591 if self.socket.is_some() {
592 if let Some(tracer) = &self.tracer {
593 WriteTracer::trace(tracer, &line.to_string());
594 }
595 }
596
597 let timeout_duration = command.client_timeout_request(&self.timeouts);
598 self.write_all_with_timeout(
599 timeout_duration,
600 &line,
601 || ClientError::TimeOutRequest {
602 duration: timeout_duration,
603 commands: vec![command.clone()],
604 },
605 |error| ClientError::WriteError {
606 error,
607 commands: vec![command.clone()],
608 },
609 )
610 .await
611 }
612
613 async fn write_data_with_timeout(&mut self, data: &[u8]) -> Result<(), ClientError> {
614 if self.socket.is_some() {
615 if let Some(tracer) = &self.tracer {
616 BinWriteTracer::trace(tracer, data);
617 }
618 }
619 let timeout_duration = Command::Data.client_timeout_request(&self.timeouts);
620 self.write_all_with_timeout(
621 timeout_duration,
622 data,
623 || ClientError::TimeOutData,
624 |error| ClientError::WriteError {
625 error,
626 commands: vec![],
627 },
628 )
629 .await
630 }
631
632 pub async fn pipeline_commands(
648 &mut self,
649 commands: Vec<Command>,
650 ) -> Vec<Result<Response, ClientError>> {
651 let mut results: Vec<Result<Response, ClientError>> = vec![];
652
653 let pipeline = self.enable_pipelining && self.capabilities.contains_key("PIPELINING");
654 if pipeline {
655 if let Err(err) = self.write_pipeline_request(&commands).await {
656 let err: ClientError = err;
657 results.push(Err(err.clone()));
658 while results.len() < commands.len() {
659 results.push(Err(err.clone()));
661 }
662 return results;
663 }
664
665 for cmd in &commands {
667 results.push(
668 self.read_response(Some(cmd), cmd.client_timeout(&self.timeouts))
669 .await,
670 );
671 }
672 return results;
673 }
674
675 for cmd in &commands {
676 if let Err(err) = self.write_command_request(cmd).await {
677 let err: ClientError = err;
678 results.push(Err(err.clone()));
679 while results.len() < commands.len() {
680 results.push(Err(err.clone()));
682 }
683 return results;
684 }
685 results.push(
688 self.read_response(Some(cmd), cmd.client_timeout(&self.timeouts))
689 .await,
690 );
691 }
692 results
693 }
694
695 pub async fn ehlo_lhlo(
696 &mut self,
697 ehlo_name: &str,
698 use_lmtp: bool,
699 ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
700 if use_lmtp {
701 self.lhlo(ehlo_name).await
702 } else {
703 self.ehlo(ehlo_name).await
704 }
705 }
706
707 pub async fn lhlo(
708 &mut self,
709 ehlo_name: &str,
710 ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
711 let response = self
712 .send_command(&Command::Lhlo(Domain::DomainName(
713 ehlo_name
714 .parse::<DomainString>()
715 .map_err(|_| ClientError::InvalidDnsName(ehlo_name.to_string()))?,
716 )))
717 .await?;
718 self.ehlo_common(response)
719 }
720
721 pub async fn ehlo(
722 &mut self,
723 ehlo_name: &str,
724 ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
725 let response = self
726 .send_command(&Command::Ehlo(Domain::DomainName(
727 ehlo_name
728 .parse::<DomainString>()
729 .map_err(|_| ClientError::InvalidDnsName(ehlo_name.to_string()))?,
730 )))
731 .await?;
732 self.ehlo_common(response)
733 }
734
735 fn ehlo_common(
736 &mut self,
737 response: Response,
738 ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
739 if response.code != 250 {
740 return Err(ClientError::Rejected(response));
741 }
742
743 let mut capabilities = HashMap::new();
744
745 for line in response.content.lines().skip(1) {
746 let mut fields = line.splitn(2, ' ');
747 if let Some(name) = fields.next() {
748 let param = fields.next().map(|s| s.to_string());
749 let cap = EsmtpCapability {
750 name: name.to_string(),
751 param,
752 };
753 capabilities.insert(name.to_ascii_uppercase(), cap);
754 }
755 }
756
757 self.capabilities = capabilities;
758 Ok(&self.capabilities)
759 }
760
761 pub async fn auth_plain(
762 &mut self,
763 username: &str,
764 password: Option<&str>,
765 ) -> Result<(), ClientError> {
766 let password = password.unwrap_or("");
769 let payload = format!("\x00{username}\x00{password}");
770 let payload = data_encoding::BASE64.encode(payload.as_bytes());
771
772 let response = self
773 .send_command(&Command::Auth {
774 sasl_mech: "PLAIN".to_string(),
775 initial_response: Some(payload),
776 })
777 .await?;
778
779 if response.code != 235 {
780 return Err(ClientError::Rejected(response));
781 }
782
783 Ok(())
784 }
785
    /// Issues STARTTLS and, if the peer agrees, performs the TLS handshake
    /// on the connection.
    ///
    /// OpenSSL is used when `options.prefer_openssl` is set or when DANE
    /// TLSA records are supplied; rustls is used otherwise.
    ///
    /// A handshake that fails is NOT an `Err`: the function returns
    /// `Ok(TlsStatus::FailedHandshake(..))` and leaves a stream installed
    /// on the client. On success it returns `Ok(TlsStatus::Info(..))`
    /// describing the negotiated session.
    ///
    /// # Errors
    /// * `Rejected` if the reply to STARTTLS is not 220.
    /// * `FailedToBuildConnector`, `SslErrorStack` or `InvalidDnsName` if
    ///   the TLS session cannot be set up.
    /// * `NotConnected` if there is no socket.
    /// * `TimeOutResponse` if the handshake exceeds `starttls_timeout`; the
    ///   socket has already been taken at that point, so the connection is
    ///   gone.
    pub async fn starttls(&mut self, options: TlsOptions) -> Result<TlsStatus, ClientError> {
        let resp = self.send_command(&Command::StartTls).await?;
        if resp.code != 220 {
            return Err(ClientError::Rejected(resp));
        }

        let mut handshake_error = None;
        let mut tls_info = TlsInformation::default();

        let stream: BoxedAsyncReadAndWrite = if options.prefer_openssl
            || !options.dane_tlsa.is_empty()
        {
            let connector = options
                .build_openssl_connector(&self.hostname)
                .map_err(|error| ClientError::FailedToBuildConnector {
                    error: error.to_string(),
                })?;
            let ssl = connector.into_ssl(self.hostname.as_str())?;

            // Keep a duplicate handle (when the stream can provide one) so
            // there is something to fall back to if the handshake fails.
            let (stream, dup_stream) = match self.socket.take() {
                Some(s) => {
                    let d = s.try_dup();
                    (s, d)
                }
                None => return Err(ClientError::NotConnected),
            };

            let mut ssl_stream = tokio_openssl::SslStream::new(ssl, stream)?;

            match timeout(
                self.timeouts.starttls_timeout,
                std::pin::Pin::new(&mut ssl_stream).connect(),
            )
            .await
            {
                Ok(Ok(())) => {}
                Ok(Err(err)) => {
                    handshake_error.replace(format!("{err:#}"));
                }
                Err(_elapsed) => {
                    // Returning here drops both the TLS stream and the
                    // duplicate, so the connection is closed.
                    if let Some(tracer) = &self.tracer {
                        tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
                            level: Level::ERROR,
                            message: format!(
                                "openssl handshake timed out after {:?}",
                                self.timeouts.starttls_timeout
                            ),
                        });
                        tracer.trace_event(SmtpClientTraceEvent::Closed);
                    }
                    return Err(ClientError::TimeOutResponse {
                        command: Some(Command::StartTls),
                        duration: self.timeouts.starttls_timeout,
                    });
                }
            }

            tls_info.provider_name = "openssl".to_string();
            tls_info.cipher = match ssl_stream.ssl().current_cipher() {
                Some(cipher) => cipher.standard_name().unwrap_or(cipher.name()).to_string(),
                None => String::new(),
            };
            tls_info.protocol_version = ssl_stream.ssl().version_str().to_string();

            if let Some(cert) = ssl_stream.ssl().peer_certificate() {
                tls_info.subject_name = subject_name(&cert);
            }
            // When a DANE authority certificate is available, its subject
            // replaces the peer certificate's subject.
            if let Ok(authority) = ssl_stream.ssl().dane_authority() {
                if let Some(cert) = &authority.cert {
                    tls_info.subject_name = subject_name(cert);
                }
            }

            match (&handshake_error, dup_stream) {
                // The handshake failed before finishing and a duplicate is
                // available: continue on the duplicate instead of the TLS
                // stream (presumably so the session can carry on in clear
                // text; confirm against the caller).
                (Some(_), Some(dup_stream)) if !ssl_stream.ssl().is_init_finished() => {
                    drop(ssl_stream);
                    Box::new(dup_stream)
                }
                _ => Box::new(ssl_stream),
            }
        } else {
            tls_info.provider_name = "rustls".to_string();
            let connector = options.build_tls_connector().await.map_err(|error| {
                ClientError::FailedToBuildConnector {
                    error: error.to_string(),
                }
            })?;
            let server_name = parse_server_name(self.hostname.as_str())?;

            let socket = match self.socket.take() {
                Some(s) => s,
                None => return Err(ClientError::NotConnected),
            };

            // The fallible form hands the underlying stream back on failure.
            let connect_future = connector.connect(server_name, socket).into_fallible();
            match timeout(self.timeouts.starttls_timeout, connect_future).await {
                Ok(Ok(stream)) => {
                    let (_, conn) = stream.get_ref();
                    tls_info.cipher = match conn.negotiated_cipher_suite() {
                        Some(suite) => suite.suite().as_str().unwrap_or("UNKNOWN").to_string(),
                        None => String::new(),
                    };
                    tls_info.protocol_version = match conn.protocol_version() {
                        Some(version) => version.as_str().unwrap_or("UNKNOWN").to_string(),
                        None => String::new(),
                    };

                    if let Some(certs) = conn.peer_certificates() {
                        // NOTE(review): indexing assumes the list is
                        // non-empty whenever it is present; confirm.
                        let peer_cert = &certs[0];
                        if let Ok(cert) = X509::from_der(peer_cert.as_ref()) {
                            tls_info.subject_name = subject_name(&cert);
                        }
                    }

                    Box::new(stream)
                }
                Ok(Err((err, stream))) => {
                    handshake_error.replace(format!("{err:#}"));
                    stream
                }
                Err(_elapsed) => {
                    if let Some(tracer) = &self.tracer {
                        tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
                            level: Level::ERROR,
                            message: format!(
                                "rustls handshake timed out after {:?}",
                                self.timeouts.starttls_timeout
                            ),
                        });
                        tracer.trace_event(SmtpClientTraceEvent::Closed);
                    }
                    return Err(ClientError::TimeOutResponse {
                        command: Some(Command::StartTls),
                        duration: self.timeouts.starttls_timeout,
                    });
                }
            }
        };

        if let Some(tracer) = &self.tracer {
            tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
                level: Level::INFO,
                message: match &handshake_error {
                    Some(error) => format!("STARTTLS handshake failed: {error:?}"),
                    None => format!("STARTTLS handshake -> {tls_info:?}"),
                },
            });
        }

        // Install whichever stream we ended up with.
        self.socket.replace(stream);
        Ok(match handshake_error {
            Some(error) => TlsStatus::FailedHandshake(error),
            None => TlsStatus::Info(tls_info),
        })
    }
953
954 pub async fn send_mail<B: AsRef<[u8]>, SENDER: Into<ReversePath>, RECIP: Into<ForwardPath>>(
955 &mut self,
956 sender: SENDER,
957 recipient: RECIP,
958 data: B,
959 ) -> Result<Response, ClientError> {
960 let recipient: ForwardPath = recipient.into();
961 let status = self
962 .send_mail_multi_recip(sender, vec![recipient], data)
963 .await?;
964 Ok(status.response)
965 }
966
    /// Sends one message to one or more recipients in a single transaction.
    ///
    /// The envelope commands (optional RSET, MAIL FROM, one RCPT TO per
    /// recipient, DATA) go through [`Self::pipeline_commands`]. The
    /// dot-stuffed payload and end-of-data marker are written afterwards.
    ///
    /// On success the result holds the final response plus every
    /// per-recipient response, so the caller can see which recipients were
    /// accepted.
    ///
    /// # Errors
    /// * `Rejected` with a locally generated 554 response when the data or
    ///   envelope is 8-bit, the peer lacks 8BITMIME / SMTPUTF8, and
    ///   `ignore_8bit_checks` is off. Nothing is sent in that case.
    /// * `Rejected` when RSET, MAIL FROM, DATA or the final response is
    ///   refused, or when the only recipient is refused.
    /// * `RejectedBatch` when every recipient of a multi-recipient send is
    ///   refused.
    /// * Any transport error from writing or reading.
    pub async fn send_mail_multi_recip<B: AsRef<[u8]>, SENDER: Into<ReversePath>>(
        &mut self,
        sender: SENDER,
        recipient_list: Vec<ForwardPath>,
        data: B,
    ) -> Result<BatchSendSuccess, ClientError> {
        let sender = sender.into();

        let data: &[u8] = data.as_ref();
        // Owns the stuffed copy, when one is needed, so `data` can borrow it.
        let stuffed;

        let data = match apply_dot_stuffing(data) {
            Some(d) => {
                stuffed = d;
                &stuffed
            }
            None => data,
        };

        let data_is_8bit = data.iter().any(|&b| b >= 0x80);
        let envelope_is_8bit =
            !sender.is_ascii() || recipient_list.iter().any(|recipient| !recipient.is_ascii());

        let mut mail_from_params = vec![];
        if data_is_8bit {
            if self.capabilities.contains_key("8BITMIME") {
                mail_from_params.push(EsmtpParameter {
                    name: "BODY".to_string(),
                    value: Some("8BITMIME".to_string()),
                });
            } else if !self.ignore_8bit_checks {
                return Err(ClientError::Rejected(Response {
                    code: 554,
                    command: None,
                    enhanced_code: Some(EnhancedStatusCode {
                        class: 5,
                        subject: 6,
                        detail: 3,
                    }),
                    content: "KumoMTA internal: DATA is 8bit, destination does \
                        not support 8BITMIME. Conversion via msg:check_fix_conformance \
                        during reception is required"
                        .to_string(),
                }));
            }
        }

        if envelope_is_8bit {
            if self.capabilities.contains_key("SMTPUTF8") {
                mail_from_params.push(EsmtpParameter {
                    name: "SMTPUTF8".to_string(),
                    value: None,
                });
            } else if !self.ignore_8bit_checks {
                return Err(ClientError::Rejected(Response {
                    code: 554,
                    command: None,
                    enhanced_code: Some(EnhancedStatusCode {
                        class: 5,
                        subject: 6,
                        detail: 7,
                    }),
                    content: "KumoMTA internal: envelope is 8bit, destination does \
                        not support SMTPUTF8."
                        .to_string(),
                }));
            }
        }

        let mut commands = vec![];

        // Remember whether RSET was included; that decides below whether
        // the first response belongs to RSET or to MAIL FROM.
        let used_rset = self.use_rset;
        if self.use_rset {
            commands.push(Command::Rset);
        }
        commands.push(Command::MailFrom {
            address: sender,
            parameters: mail_from_params,
        });

        for recipient in &recipient_list {
            commands.push(Command::RcptTo {
                address: recipient.clone(),
                parameters: vec![],
            });
        }
        commands.push(Command::Data);

        // Assume the worst: if anything below fails, the next transaction
        // on this connection starts with RSET.
        self.use_rset = true;

        let mut responses = self.pipeline_commands(commands).await;

        // If any command hit a transport error, status codes are checked
        // eagerly below so that the most relevant rejection is reported.
        // Otherwise the checks are deferred until after the RFC 2920
        // section 3.1 handling.
        let is_err = responses.iter().any(|r| r.is_err());

        if used_rset {
            let rset_resp = responses.remove(0)?;
            if rset_resp.code != 250 {
                return Err(ClientError::Rejected(rset_resp));
            }
        }

        let mail_resp = responses.remove(0)?;
        if is_err && mail_resp.code != 250 {
            return Err(ClientError::Rejected(mail_resp));
        }

        let mut rcpt_responses = vec![];
        for _ in &recipient_list {
            rcpt_responses.push(responses.remove(0)?);
        }

        if is_err && rcpt_responses.iter().all(|resp| resp.code != 250) {
            return Err(ClientError::RejectedBatch(rcpt_responses));
        }

        let data_resp = responses.remove(0)?;
        if is_err && data_resp.code != 354 {
            return Err(ClientError::Rejected(data_resp));
        }

        if data_resp.code == 354
            && (mail_resp.code != 250 || rcpt_responses.iter().all(|resp| resp.code != 250))
        {
            // RFC 2920 section 3.1: the server may accept DATA even though
            // the envelope was refused. In that case the client sends a
            // lone dot to finish the (empty) data phase.
            self.write_data_with_timeout(b".\r\n").await?;
            let data_dot = Command::DataDot;
            // Its response is read and discarded; one of the checks below
            // reports the envelope rejection instead.
            let _ = self
                .read_response(Some(&data_dot), data_dot.client_timeout(&self.timeouts))
                .await?;

        }

        if mail_resp.code != 250 {
            return Err(ClientError::Rejected(mail_resp));
        }
        if rcpt_responses.iter().all(|resp| resp.code != 250) {
            // A single refused recipient is reported as a plain rejection.
            if rcpt_responses.len() == 1 {
                return Err(ClientError::Rejected(
                    rcpt_responses.pop().expect("have at least one"),
                ));
            }
            return Err(ClientError::RejectedBatch(rcpt_responses));
        }
        if data_resp.code != 354 {
            return Err(ClientError::Rejected(data_resp));
        }

        // The end-of-data dot must start a line, so a line ending is added
        // first when the payload does not already finish with a newline.
        let needs_newline = data.last().map(|&b| b != b'\n').unwrap_or(true);

        tracing::trace!("message data is {} bytes", data.len());

        self.write_data_with_timeout(data).await?;

        let marker = if needs_newline { "\r\n.\r\n" } else { ".\r\n" };

        tracing::trace!("send->{}: {}", self.hostname, marker.escape_debug());

        self.write_data_with_timeout(marker.as_bytes()).await?;

        let data_dot = Command::DataDot;
        let resp = self
            .read_response(Some(&data_dot), data_dot.client_timeout(&self.timeouts))
            .await?;
        if resp.code != 250 {
            return Err(ClientError::Rejected(resp));
        }

        // All went well: go back to the configured RSET preference.
        self.use_rset = self.enable_rset;

        Ok(BatchSendSuccess {
            response: resp,
            rcpt_responses,
        })
    }
1172}
1173
/// Result of a successful `SmtpClient::send_mail_multi_recip` call.
#[derive(Debug)]
pub struct BatchSendSuccess {
    /// The final response, received after the end-of-data marker.
    pub response: Response,
    /// One response per recipient, in the order the recipients were given.
    /// Some of them may be rejections.
    pub rcpt_responses: Vec<Response>,
}
1179
/// Outcome of a STARTTLS attempt that got as far as the handshake.
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize, Clone)]
pub enum TlsStatus {
    /// The handshake failed; carries a description of the error.
    FailedHandshake(String),
    /// The handshake succeeded; carries details of the session.
    Info(TlsInformation),
}
1185
/// Details of a negotiated TLS session.
#[derive(Debug, Default, PartialEq, Eq, Serialize, Deserialize, Clone)]
pub struct TlsInformation {
    /// Name of the negotiated cipher; empty if unavailable.
    pub cipher: String,
    /// Negotiated protocol version; empty if unavailable.
    pub protocol_version: String,
    /// Certificate subject entries rendered as `KEY=value`.
    pub subject_name: Vec<String>,
    /// TLS implementation that was used: `"openssl"` or `"rustls"`.
    pub provider_name: String,
}
1193
1194impl Drop for SmtpClient {
1195 fn drop(&mut self) {
1196 if let Some(tracer) = &self.tracer {
1197 if self.socket.is_some() {
1198 tracer.trace_event(SmtpClientTraceEvent::Closed);
1199 }
1200 }
1201 }
1202}
1203fn parse_response_line(line: &'_ str) -> Result<ResponseLine<'_>, ClientError> {
1204 if line.len() < 4 {
1205 return Err(ClientError::MalformedResponseLine(line.to_string()));
1206 }
1207
1208 match line.as_bytes()[3] {
1209 b' ' | b'-' => match line[0..3].parse::<u16>() {
1210 Ok(code) => Ok(ResponseLine {
1211 code,
1212 is_final: line.as_bytes()[3] == b' ',
1213 content: &line[4..],
1214 }),
1215 Err(_) => Err(ClientError::MalformedResponseLine(line.to_string())),
1216 },
1217 _ => Err(ClientError::MalformedResponseLine(line.to_string())),
1218 }
1219}
1220
1221fn apply_dot_stuffing(data: &[u8]) -> Option<Vec<u8>> {
1222 static LFDOT: LazyLock<Finder> = LazyLock::new(|| memchr::memmem::Finder::new("\n."));
1223
1224 if !data.starts_with(b".") && LFDOT.find(data).is_none() {
1225 return None;
1226 }
1227
1228 let mut stuffed = vec![];
1229 if data.starts_with(b".") {
1230 stuffed.push(b'.');
1231 }
1232 let mut last_idx = 0;
1233 for i in LFDOT.find_iter(data) {
1234 stuffed.extend_from_slice(&data[last_idx..=i]);
1235 stuffed.push(b'.');
1236 last_idx = i + 1;
1237 }
1238 stuffed.extend_from_slice(&data[last_idx..]);
1239 Some(stuffed)
1240}
1241
1242pub fn subject_name(cert: &X509Ref) -> Vec<String> {
1249 let mut subject_name = vec![];
1250 for entry in cert.subject_name().entries() {
1251 if let Ok(obj) = entry.object().nid().short_name() {
1252 if let Ok(name) = entry.data().as_utf8() {
1253 subject_name.push(format!("{obj}={name}"));
1254 }
1255 }
1256 }
1257 subject_name
1258}
1259
1260fn parse_server_name(input: &str) -> Result<ServerName<'static>, ClientError> {
1261 match IpAddr::from_str(input) {
1262 Ok(ip) => Ok(ServerName::IpAddress(ip.into())),
1263 Err(_) => {
1264 let name = Name::from_str_relaxed(input)
1265 .map_err(|_| ClientError::InvalidDnsName(input.to_string()))?;
1266 ServerName::try_from(name.to_ascii())
1267 .map_err(|_| ClientError::InvalidDnsName(name.to_ascii()))
1268 }
1269 }
1270}
1271
#[cfg(test)]
mod test {
    use super::*;
    use crate::parser::{EnvelopeAddress, ReversePath};

    /// Parses `lines` as one multi-line reply and assembles the `Response`.
    /// Asserts that the last parsed line is marked final.
    fn parse_multi_line(lines: &[&str]) -> Result<Response, ClientError> {
        let mut current = parse_response_line(lines[0])?;
        let mut builder = ResponseBuilder::new(&current);
        for text in lines.iter().skip(1) {
            current = parse_response_line(text)?;
            if let Err(bad_line) = builder.add_line(&current) {
                return Err(ClientError::MalformedResponseLine(bad_line));
            }
        }
        assert!(current.is_final);
        Ok(builder.build(None))
    }

    /// Runs `parse_multi_line` on input that must be rejected and returns the
    /// line reported inside `MalformedResponseLine`.
    fn rejected_line(lines: &[&str]) -> String {
        match parse_multi_line(lines).unwrap_err() {
            ClientError::MalformedResponseLine(line) => line,
            res => panic!("got error {res:?}"),
        }
    }

    /// Expected `Response` with code 220 and no associated command.
    fn reply_220(enhanced_code: Option<EnhancedStatusCode>, content: &str) -> Response {
        Response {
            code: 220,
            enhanced_code,
            content: content.to_string(),
            command: None,
        }
    }

    /// The enhanced status code `4.1.0` used by the multi-line cases.
    fn status_4_1_0() -> Option<EnhancedStatusCode> {
        Some(EnhancedStatusCode {
            class: 4,
            subject: 1,
            detail: 0,
        })
    }

    #[test]
    fn test_stuffing() {
        // Nothing to escape: no buffer is returned.
        assert_eq!(apply_dot_stuffing(b"foo"), None);

        let cases: [(&[u8], &[u8]); 3] = [
            (b".foo", b"..foo"),
            (b"foo\n.bar", b"foo\n..bar"),
            (b"foo\n.bar\n..baz\n", b"foo\n..bar\n...baz\n"),
        ];
        for (input, expected) in cases {
            assert_eq!(apply_dot_stuffing(input).unwrap(), expected);
        }
    }

    #[test]
    fn response_line_parsing() {
        let final_line = ResponseLine {
            code: 220,
            is_final: true,
            content: "woot",
        };
        assert_eq!(parse_response_line("220 woot").unwrap(), final_line);

        let continuation = ResponseLine {
            code: 220,
            is_final: false,
            content: "woot",
        };
        assert_eq!(parse_response_line("220-woot").unwrap(), continuation);

        // Bad separator, then a non-numeric code.
        for bogus in ["220_woot", "not really"] {
            assert!(matches!(
                parse_response_line(bogus),
                Err(ClientError::MalformedResponseLine(_))
            ));
        }
    }

    #[test]
    fn multi_line_response() {
        assert_eq!(
            parse_multi_line(&["220-woot", "220-more", "220 done"]).unwrap(),
            reply_220(None, "woot\nmore\ndone")
        );

        // A line whose code differs from the first line is rejected and
        // reported verbatim.
        assert_eq!(
            rejected_line(&["220-woot", "221-more", "220 done"]),
            "221-more"
        );
        assert_eq!(
            rejected_line(&["220-woot", "220-more", "221 done"]),
            "221 done"
        );

        // Enhanced status code present on the first line only.
        assert_eq!(
            parse_multi_line(&["220-4.1.0 woot", "220-more", "220 done"]).unwrap(),
            reply_220(status_4_1_0(), "woot\nmore\ndone")
        );

        // The same enhanced code repeated on a continuation line does not
        // show up in the content.
        assert_eq!(
            parse_multi_line(&["220-4.1.0 woot", "220-4.1.0 more", "220 done"]).unwrap(),
            reply_220(status_4_1_0(), "woot\nmore\ndone")
        );

        // A different enhanced code on the last line stays in the content and
        // the code from the first line is the one reported.
        assert_eq!(
            parse_multi_line(&["220-4.1.0 woot", "220-4.1.0 more", "220 5.5.5 done"]).unwrap(),
            reply_220(status_4_1_0(), "woot\nmore\n5.5.5 done")
        );
    }

    #[test]
    fn test_extract_hostname() {
        let cases = [
            ("foo", "foo"),
            ("foo.", "foo"),
            ("foo:25", "foo"),
            ("foo.:25", "foo"),
            ("[foo]:25", "foo"),
            ("[foo.]:25", "foo"),
            ("[::1]:25", "::1"),
            ("::1:25", "::1"),
        ];
        for (input, expected) in cases {
            assert_eq!(extract_hostname(input), expected, "input {input:?}");
        }
    }

    #[test]
    fn test_format_error_command() {
        let write_timeout = ClientError::TimeOutRequest {
            commands: vec![Command::DataDot],
            duration: Duration::from_secs(10),
        };
        assert_eq!(
            format!("{write_timeout:#}"),
            "Timed Out after 10s writing cmd=."
        );

        let EnvelopeAddress::Path(path) = EnvelopeAddress::parse("user@host").unwrap() else {
            panic!("expected Path")
        };
        let read_timeout = ClientError::TimeOutResponse {
            command: Some(Command::MailFrom {
                address: ReversePath::Path(path),
                parameters: vec![],
            }),
            duration: Duration::from_secs(10),
        };
        assert_eq!(
            format!("{read_timeout:#}"),
            "Timed Out waiting 10s for response to cmd=MAIL FROM:<user@host>"
        );
    }

    /// A host name containing non-ASCII characters must be accepted.
    #[test]
    fn test_issue_533() {
        if let Err(err) = parse_server_name("münchen.de") {
            panic!("failed to parse server name: {err:?}");
        }
    }
}