1#![allow(clippy::result_large_err)]
2use crate::client_types::*;
3use crate::{
4 AsyncReadAndWrite, BoxedAsyncReadAndWrite, Command, Domain, EsmtpParameter, ForwardPath,
5 ReversePath,
6};
7use hickory_proto::rr::rdata::tlsa::{CertUsage, Matching, Selector};
8use hickory_proto::rr::rdata::TLSA;
9use memchr::memmem::Finder;
10use openssl::pkey::PKey;
11use openssl::ssl::{DaneMatchType, DaneSelector, DaneUsage, SslOptions};
12use openssl::x509::{X509Ref, X509};
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15use std::net::IpAddr;
16use std::str::FromStr;
17use std::sync::{Arc, LazyLock};
18use std::time::Duration;
19use thiserror::Error;
20use tokio::io::{AsyncReadExt, AsyncWriteExt};
21use tokio::net::{TcpStream, ToSocketAddrs};
22use tokio::time::timeout;
23use tokio_rustls::rustls::pki_types::ServerName;
24use tracing::Level;
25
26pub use crate::tls::TlsOptions;
27pub use {openssl, tokio_rustls};
28
/// Size threshold, in bytes, used by `SmtpClient::read_line`: it is both the
/// size of the scratch buffer for each socket read and the amount of buffered
/// data without a CRLF after which a response line is declared too long.
const MAX_LINE_LEN: usize = 4096;
30
/// Errors reported by [`SmtpClient`] and by `TlsOptions::build_openssl_connector`.
#[derive(Error, Debug, Clone)]
pub enum ClientError {
    /// A response line was not valid UTF-8.
    #[error("response is not UTF8")]
    Utf8(#[from] std::string::FromUtf8Error),
    /// A response line could not be parsed, or could not be added to the
    /// response being assembled; carries the line or the builder's message.
    #[error("Malformed Response: {0}")]
    MalformedResponseLine(String),
    /// More than `MAX_LINE_LEN` bytes were buffered without finding a CRLF.
    #[error("Response line is too long")]
    ResponseTooLong,
    /// An operation required a socket but none was present.
    #[error("Not connected")]
    NotConnected,
    /// The peer answered a single command with an unexpected status code.
    #[error("Command rejected {0:?}")]
    Rejected(Response),
    /// Several commands (e.g. every RCPT TO of a batch) were rejected.
    #[error("Commands rejected {0:?}")]
    RejectedBatch(Vec<Response>),
    /// The peer hostname could not be converted to a TLS server name.
    #[error("STARTTLS: {0} is not a valid DNS name")]
    InvalidDnsName(String),
    /// Building the TLS connector failed; `error` is the stringified cause.
    #[error("Invalid client certificate configured: {error:?}")]
    FailedToBuildConnector { error: String },
    /// No response arrived within `duration`; the socket has been dropped.
    #[error("Timed Out waiting {duration:?} for response to cmd={}", self.command_for_err())]
    TimeOutResponse {
        command: Option<Command>,
        duration: Duration,
    },
    /// Writing `commands` did not complete within `duration`; the socket has
    /// been dropped.
    #[error("Timed Out after {duration:?} writing cmd={}", self.command_for_err())]
    TimeOutRequest {
        commands: Vec<Command>,
        duration: Duration,
    },
    /// Reading a response failed. `partial` holds a lossy UTF-8 rendering of
    /// whatever was buffered at the time of the failure.
    #[error("Error {error} reading response to cmd={}", self.command_for_err())]
    ReadError {
        command: Option<Command>,
        error: String,
        partial: String,
    },
    /// Flushing the socket prior to reading a response failed.
    #[error("Error {error} flushing send buffer")]
    FlushError { error: String },
    /// Writing failed. `commands` is empty when the write carried message
    /// payload data rather than commands.
    #[error("Error {error} writing {}", self.command_for_err())]
    WriteError {
        commands: Vec<Command>,
        error: String,
    },
    /// Writing message payload data timed out.
    #[error("Timed Out sending message payload data")]
    TimeOutData,
    /// An error reported by OpenSSL.
    #[error("SSL Error: {0}")]
    SslErrorStack(#[from] openssl::error::ErrorStack),
    /// None of the supplied TLSA records was reported usable by OpenSSL.
    #[error("No usable DANE TLSA records for {hostname}: {tlsa:?}")]
    NoUsableDaneTlsa { hostname: String, tlsa: Vec<TLSA> },
}
79
80impl ClientError {
81 pub fn command(&self) -> Option<String> {
83 match self {
84 Self::TimeOutResponse {
85 command: Some(command),
86 ..
87 }
88 | Self::ReadError {
89 command: Some(command),
90 ..
91 } => Some(command.encode()),
92 Self::TimeOutRequest { commands, .. } | Self::WriteError { commands, .. }
93 if !commands.is_empty() =>
94 {
95 let commands: Vec<String> = commands.into_iter().map(|cmd| cmd.encode()).collect();
96 Some(commands.join(""))
97 }
98 _ => None,
99 }
100 }
101
102 fn command_for_err(&self) -> String {
105 self.command()
106 .map(|cmd| cmd.replace("\r\n", ""))
107 .unwrap_or_else(|| "NONE".to_string())
108 }
109
110 pub fn was_due_to_message(&self) -> bool {
122 match self {
123 Self::Utf8(_)
124 | Self::MalformedResponseLine(_)
125 | Self::ResponseTooLong
126 | Self::NotConnected
127 | Self::InvalidDnsName(_)
128 | Self::FailedToBuildConnector { .. }
129 | Self::TimeOutResponse { .. }
130 | Self::TimeOutRequest { .. }
131 | Self::ReadError { .. }
132 | Self::FlushError { .. }
133 | Self::WriteError { .. }
134 | Self::TimeOutData
135 | Self::SslErrorStack(_)
136 | Self::NoUsableDaneTlsa { .. } => false,
137 Self::Rejected(response) => response.was_due_to_message(),
138 Self::RejectedBatch(responses) => match responses.len() {
139 1 => responses[0].was_due_to_message(),
140 _ => false,
141 },
142 }
143 }
144}
145
/// One capability line from an EHLO/LHLO response.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EsmtpCapability {
    /// The capability keyword exactly as the server sent it.
    pub name: String,
    /// Everything after the first space on the line, if there was one.
    pub param: Option<String>,
}
151
/// Events delivered to an [`SmtpClientTracer`].
#[derive(Clone, Debug)]
pub enum SmtpClientTraceEvent {
    /// The client gave up its socket (after an error, EOF, or on drop).
    Closed,
    /// Raw bytes received from the socket.
    Read(Vec<u8>),
    /// Text about to be written to the socket.
    Write(String),
    /// A human-readable note with an associated severity.
    Diagnostic {
        level: tracing::Level,
        message: String,
    },
}
162
/// A value that can build a trace event when asked.
///
/// NOTE(review): presumably this lets a tracer skip building the event when
/// tracing is not wanted — confirm against the `SmtpClientTracer` implementors.
pub trait DeferredTracer {
    /// Builds the trace event.
    fn trace(&self) -> SmtpClientTraceEvent;
}
166
/// Sink for trace events emitted by [`SmtpClient`].
pub trait SmtpClientTracer: std::fmt::Debug {
    /// Receives an event that has already been built.
    fn trace_event(&self, event: SmtpClientTraceEvent);
    /// Receives a [`DeferredTracer`]; the implementor decides whether to call
    /// its `trace` method to obtain the event.
    fn lazy_trace(&self, deferred: &dyn DeferredTracer);
}
171
172struct WriteTracer<'a> {
174 data: &'a str,
175}
176impl DeferredTracer for WriteTracer<'_> {
177 fn trace(&self) -> SmtpClientTraceEvent {
178 SmtpClientTraceEvent::Write(self.data.to_string())
179 }
180}
181impl<'a> WriteTracer<'a> {
182 fn trace(tracer: &Arc<dyn SmtpClientTracer + Send + Sync>, data: &'a str) {
183 tracer.lazy_trace(&Self { data });
184 }
185}
186
187struct BinWriteTracer<'a> {
188 data: &'a [u8],
189}
190impl DeferredTracer for BinWriteTracer<'_> {
191 fn trace(&self) -> SmtpClientTraceEvent {
192 let data = String::from_utf8_lossy(self.data).to_string();
193 SmtpClientTraceEvent::Write(data)
194 }
195}
196impl<'a> BinWriteTracer<'a> {
197 fn trace(tracer: &Arc<dyn SmtpClientTracer + Send + Sync>, data: &'a [u8]) {
198 tracer.lazy_trace(&Self { data });
199 }
200}
201
202struct ReadTracer<'a> {
205 data: &'a [u8],
206}
207impl DeferredTracer for ReadTracer<'_> {
208 fn trace(&self) -> SmtpClientTraceEvent {
209 SmtpClientTraceEvent::Read(self.data.to_vec())
210 }
211}
212
/// An SMTP/LMTP client bound to a single connection.
#[derive(Debug)]
pub struct SmtpClient {
    /// The connection; set to `None` once a read/write/flush error, timeout
    /// or EOF has been observed.
    socket: Option<BoxedAsyncReadAndWrite>,
    /// Peer host name with any `:port`, surrounding brackets and trailing
    /// dot removed (see `extract_hostname`).
    hostname: String,
    /// Capabilities from the most recent successful EHLO/LHLO, keyed by the
    /// upper-cased capability name.
    capabilities: HashMap<String, EsmtpCapability>,
    /// Bytes received from the socket that have not yet been returned as a line.
    read_buffer: Vec<u8>,
    timeouts: SmtpClientTimeouts,
    tracer: Option<Arc<dyn SmtpClientTracer + Send + Sync>>,
    /// Whether the next `send_mail_multi_recip` starts with an RSET.
    use_rset: bool,
    /// Value that `use_rset` is reset to after a fully successful send.
    enable_rset: bool,
    /// Pipelining is used only when this is set and the server advertised
    /// `PIPELINING`.
    enable_pipelining: bool,
    /// When set, 8-bit data/envelopes are sent even if the server lacks
    /// `8BITMIME`/`SMTPUTF8`.
    ignore_8bit_checks: bool,
}
226
/// Reduces a `host[:port]` string to the bare host.
///
/// Everything from the last `:` onwards is dropped, then one pair of
/// enclosing square brackets is removed, then a single trailing `.` is
/// removed.
fn extract_hostname(hostname: &str) -> &str {
    // Cut at the LAST colon so that `[::1]:25` keeps its address intact.
    let host = match hostname.rsplit_once(':') {
        Some((before_port, _port)) => before_port,
        None => hostname,
    };

    // Unwrap a bracketed literal only when both brackets are present.
    let host = host
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(host);

    match host.strip_suffix('.') {
        Some(without_dot) => without_dot,
        None => host,
    }
}
245
246impl SmtpClient {
247 pub async fn new<A: ToSocketAddrs + ToString + Clone>(
248 addr: A,
249 timeouts: SmtpClientTimeouts,
250 ) -> std::io::Result<Self> {
251 let stream = TcpStream::connect(addr.clone()).await?;
252 stream.set_nodelay(true)?;
254 Ok(Self::with_stream(stream, addr.to_string(), timeouts))
255 }
256
257 pub fn with_stream<S: AsyncReadAndWrite + 'static, H: AsRef<str>>(
258 stream: S,
259 peer_hostname: H,
260 timeouts: SmtpClientTimeouts,
261 ) -> Self {
262 let hostname = extract_hostname(peer_hostname.as_ref()).to_string();
263
264 Self {
265 socket: Some(Box::new(stream)),
266 hostname,
267 capabilities: HashMap::new(),
268 read_buffer: Vec::with_capacity(1024),
269 timeouts,
270 tracer: None,
271 use_rset: false,
272 enable_rset: false,
273 enable_pipelining: false,
274 ignore_8bit_checks: false,
275 }
276 }
277
/// When enabled, `send_mail_multi_recip` no longer refuses to send 8-bit
/// data or envelopes to a server that did not advertise `8BITMIME` /
/// `SMTPUTF8`.
pub fn set_ignore_8bit_checks(&mut self, enable: bool) {
    self.ignore_8bit_checks = enable;
}
282
/// Returns `true` while the client still holds a socket. The socket is
/// discarded after I/O errors, timeouts and EOF.
pub fn is_connected(&self) -> bool {
    self.socket.is_some()
}
286
/// Sets the RSET preference that is restored after each fully successful
/// `send_mail_multi_recip`; it decides whether the following send is
/// prefixed with an RSET command.
pub fn set_enable_rset(&mut self, enable: bool) {
    self.enable_rset = enable;
}
290
/// Allows `pipeline_commands` to write all commands in one batch, provided
/// the server also advertised `PIPELINING`.
pub fn set_enable_pipelining(&mut self, enable: bool) {
    self.enable_pipelining = enable;
}
294
/// Installs `tracer`, replacing (and dropping) any previously set tracer.
pub fn set_tracer(&mut self, tracer: Arc<dyn SmtpClientTracer + Send + Sync>) {
    self.tracer.replace(tracer);
}
298
/// Returns the timeout configuration this client was created with.
pub fn timeouts(&self) -> &SmtpClientTimeouts {
    &self.timeouts
}
302
303 async fn read_line(
304 &mut self,
305 timeout_duration: Duration,
306 cmd: Option<&Command>,
307 ) -> Result<String, ClientError> {
308 let mut too_long = false;
309 loop {
310 let mut iter = self.read_buffer.iter().enumerate();
311 while let Some((i, &b)) = iter.next() {
312 if b != b'\r' {
313 continue;
314 }
315 if let Some((_, b'\n')) = iter.next() {
316 if too_long {
317 self.read_buffer.drain(0..i + 2);
318
319 if let Some(tracer) = &self.tracer {
320 tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
321 level: Level::ERROR,
322 message: "Response too long".to_string(),
323 });
324 }
325
326 return Err(ClientError::ResponseTooLong);
327 }
328
329 let line = String::from_utf8(self.read_buffer[0..i].to_vec());
330 self.read_buffer.drain(0..i + 2);
331 return Ok(line?);
332 }
333 }
334 if self.read_buffer.len() > MAX_LINE_LEN {
335 self.read_buffer.clear();
336 too_long = true;
337 }
338
339 let mut data = [0u8; MAX_LINE_LEN];
341 let size = match self.socket.as_mut() {
342 Some(s) => match timeout(timeout_duration, s.read(&mut data)).await {
343 Ok(Ok(size)) => size,
344 Ok(Err(err)) => {
345 self.socket.take();
346 if let Some(tracer) = &self.tracer {
347 tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
348 level: Level::ERROR,
349 message: format!("Error during read: {err:#}"),
350 });
351 tracer.trace_event(SmtpClientTraceEvent::Closed);
352 }
353 return Err(ClientError::ReadError {
354 command: cmd.cloned(),
355 error: format!("{err:#}"),
356 partial: String::from_utf8_lossy(&self.read_buffer).to_string(),
357 });
358 }
359 Err(_) => {
360 self.socket.take();
361 if let Some(tracer) = &self.tracer {
362 tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
363 level: Level::ERROR,
364 message: format!("Read Timeout after {timeout_duration:?}"),
365 });
366 tracer.trace_event(SmtpClientTraceEvent::Closed);
367 }
368 return Err(ClientError::TimeOutResponse {
369 command: cmd.cloned(),
370 duration: timeout_duration,
371 });
372 }
373 },
374 None => {
375 return Err(ClientError::ReadError {
376 command: cmd.cloned(),
377 error: "the socket was closed in response to an earlier issue".to_string(),
378 partial: String::from_utf8_lossy(&self.read_buffer).to_string(),
379 });
380 }
381 };
382 if size == 0 {
383 self.socket.take();
384 if let Some(tracer) = &self.tracer {
385 tracer.trace_event(SmtpClientTraceEvent::Closed);
386 }
387 return Err(ClientError::ReadError {
388 command: cmd.cloned(),
389 error: "Connection closed by peer".to_string(),
390 partial: String::from_utf8_lossy(&self.read_buffer).to_string(),
391 });
392 }
393 if let Some(tracer) = &self.tracer {
394 tracer.lazy_trace(&ReadTracer {
395 data: &data[0..size],
396 });
397 }
398 self.read_buffer.extend_from_slice(&data[0..size]);
399 }
400 }
401
402 pub async fn read_response(
403 &mut self,
404 command: Option<&Command>,
405 timeout_duration: Duration,
406 ) -> Result<Response, ClientError> {
407 if let Some(sock) = self.socket.as_mut() {
408 if let Err(err) = sock.flush().await {
409 self.socket.take();
410 if let Some(tracer) = &self.tracer {
411 tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
412 level: Level::ERROR,
413 message: format!("Error during flush: {err:#}"),
414 });
415 tracer.trace_event(SmtpClientTraceEvent::Closed);
416 }
417 return Err(ClientError::FlushError {
418 error: format!("{err:#}"),
419 });
420 }
421 }
422
423 let mut line = self.read_line(timeout_duration, command).await?;
424 tracing::trace!("recv<-{}: {line}", self.hostname);
425 let mut parsed = parse_response_line(&line)?;
426 let mut response_builder = ResponseBuilder::new(&parsed);
427
428 let subsequent_line_timeout_duration = Duration::from_secs(60).min(timeout_duration);
429 while !parsed.is_final {
430 line = self
431 .read_line(subsequent_line_timeout_duration, command)
432 .await?;
433 parsed = parse_response_line(&line)?;
434 response_builder
435 .add_line(&parsed)
436 .map_err(ClientError::MalformedResponseLine)?;
437 }
438
439 let response = response_builder.build(command.map(|cmd| cmd.encode()));
440
441 tracing::trace!("{}: {response:?}", self.hostname);
442
443 Ok(response)
444 }
445
446 pub async fn send_command(&mut self, command: &Command) -> Result<Response, ClientError> {
447 self.write_command_request(command).await?;
448 self.read_response(Some(command), command.client_timeout(&self.timeouts))
449 .await
450 }
451
452 async fn write_all_with_timeout<F, G>(
456 &mut self,
457 timeout_duration: Duration,
458 bytes: &[u8],
459 make_timeout_err: F,
460 make_write_err: G,
461 ) -> Result<(), ClientError>
462 where
463 F: FnOnce() -> ClientError,
464 G: FnOnce(String) -> ClientError,
465 {
466 match self.socket.as_mut() {
467 Some(socket) => match timeout(timeout_duration, socket.write_all(bytes)).await {
468 Ok(Ok(response)) => Ok(response),
469 Ok(Err(err)) => {
470 self.socket.take();
471 if let Some(tracer) = &self.tracer {
472 tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
473 level: Level::ERROR,
474 message: format!("Error during write: {err:#}"),
475 });
476 tracer.trace_event(SmtpClientTraceEvent::Closed);
477 }
478 Err(make_write_err(format!("{err:#}")))
479 }
480 Err(_) => {
481 self.socket.take();
482 if let Some(tracer) = &self.tracer {
483 tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
484 level: Level::ERROR,
485 message: format!("Write Timeout after {timeout_duration:?}"),
486 });
487 tracer.trace_event(SmtpClientTraceEvent::Closed);
488 }
489 Err(make_timeout_err())
490 }
491 },
492 None => Err(make_write_err(
493 "the socket was closed in response to an earlier issue".to_string(),
494 )),
495 }
496 }
497
498 async fn write_pipeline_request(&mut self, commands: &[Command]) -> Result<(), ClientError> {
499 let total_timeout: Duration = commands
500 .iter()
501 .map(|cmd| cmd.client_timeout_request(&self.timeouts))
502 .sum();
503
504 let mut lines: Vec<String> = vec![];
505 let mut all = String::new();
506 for cmd in commands {
507 let line = cmd.encode();
508 all.push_str(&line);
509 lines.push(line);
510 }
511 tracing::trace!("send->{}: (PIPELINE) {all}", self.hostname);
512 if self.socket.is_some() {
513 if let Some(tracer) = &self.tracer {
514 for line in lines {
517 WriteTracer::trace(tracer, &line);
518 }
519 }
520 }
521 self.write_all_with_timeout(
522 total_timeout,
523 all.as_bytes(),
524 || ClientError::TimeOutRequest {
525 duration: total_timeout,
526 commands: commands.to_vec(),
527 },
528 |error| ClientError::WriteError {
529 error,
530 commands: commands.to_vec(),
531 },
532 )
533 .await
534 }
535
536 async fn write_command_request(&mut self, command: &Command) -> Result<(), ClientError> {
537 let line = command.encode();
538 tracing::trace!("send->{}: {line}", self.hostname);
539 if self.socket.is_some() {
540 if let Some(tracer) = &self.tracer {
541 WriteTracer::trace(tracer, &line);
542 }
543 }
544
545 let timeout_duration = command.client_timeout_request(&self.timeouts);
546 self.write_all_with_timeout(
547 timeout_duration,
548 line.as_bytes(),
549 || ClientError::TimeOutRequest {
550 duration: timeout_duration,
551 commands: vec![command.clone()],
552 },
553 |error| ClientError::WriteError {
554 error,
555 commands: vec![command.clone()],
556 },
557 )
558 .await
559 }
560
561 async fn write_data_with_timeout(&mut self, data: &[u8]) -> Result<(), ClientError> {
562 if self.socket.is_some() {
563 if let Some(tracer) = &self.tracer {
564 BinWriteTracer::trace(tracer, data);
565 }
566 }
567 let timeout_duration = Command::Data.client_timeout_request(&self.timeouts);
568 self.write_all_with_timeout(
569 timeout_duration,
570 data,
571 || ClientError::TimeOutData,
572 |error| ClientError::WriteError {
573 error,
574 commands: vec![],
575 },
576 )
577 .await
578 }
579
580 pub async fn pipeline_commands(
596 &mut self,
597 commands: Vec<Command>,
598 ) -> Vec<Result<Response, ClientError>> {
599 let mut results: Vec<Result<Response, ClientError>> = vec![];
600
601 let pipeline = self.enable_pipelining && self.capabilities.contains_key("PIPELINING");
602 if pipeline {
603 if let Err(err) = self.write_pipeline_request(&commands).await {
604 let err: ClientError = err;
605 results.push(Err(err.clone()));
606 while results.len() < commands.len() {
607 results.push(Err(err.clone()));
609 }
610 return results;
611 }
612
613 for cmd in &commands {
615 results.push(
616 self.read_response(Some(cmd), cmd.client_timeout(&self.timeouts))
617 .await,
618 );
619 }
620 return results;
621 }
622
623 for cmd in &commands {
624 if let Err(err) = self.write_command_request(cmd).await {
625 let err: ClientError = err;
626 results.push(Err(err.clone()));
627 while results.len() < commands.len() {
628 results.push(Err(err.clone()));
630 }
631 return results;
632 }
633 results.push(
636 self.read_response(Some(cmd), cmd.client_timeout(&self.timeouts))
637 .await,
638 );
639 }
640 results
641 }
642
643 pub async fn ehlo_lhlo(
644 &mut self,
645 ehlo_name: &str,
646 use_lmtp: bool,
647 ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
648 if use_lmtp {
649 self.lhlo(ehlo_name).await
650 } else {
651 self.ehlo(ehlo_name).await
652 }
653 }
654
655 pub async fn lhlo(
656 &mut self,
657 ehlo_name: &str,
658 ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
659 let response = self
660 .send_command(&Command::Lhlo(Domain::Name(ehlo_name.to_string())))
661 .await?;
662 self.ehlo_common(response)
663 }
664
665 pub async fn ehlo(
666 &mut self,
667 ehlo_name: &str,
668 ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
669 let response = self
670 .send_command(&Command::Ehlo(Domain::Name(ehlo_name.to_string())))
671 .await?;
672 self.ehlo_common(response)
673 }
674
675 fn ehlo_common(
676 &mut self,
677 response: Response,
678 ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
679 if response.code != 250 {
680 return Err(ClientError::Rejected(response));
681 }
682
683 let mut capabilities = HashMap::new();
684
685 for line in response.content.lines().skip(1) {
686 let mut fields = line.splitn(2, ' ');
687 if let Some(name) = fields.next() {
688 let param = fields.next().map(|s| s.to_string());
689 let cap = EsmtpCapability {
690 name: name.to_string(),
691 param,
692 };
693 capabilities.insert(name.to_ascii_uppercase(), cap);
694 }
695 }
696
697 self.capabilities = capabilities;
698 Ok(&self.capabilities)
699 }
700
701 pub async fn auth_plain(
702 &mut self,
703 username: &str,
704 password: Option<&str>,
705 ) -> Result<(), ClientError> {
706 let password = password.unwrap_or("");
709 let payload = format!("\x00{username}\x00{password}");
710 let payload = data_encoding::BASE64.encode(payload.as_bytes());
711
712 let response = self
713 .send_command(&Command::Auth {
714 sasl_mech: "PLAIN".to_string(),
715 initial_response: Some(payload),
716 })
717 .await?;
718
719 if response.code != 235 {
720 return Err(ClientError::Rejected(response));
721 }
722
723 Ok(())
724 }
725
/// Issues STARTTLS and attempts to upgrade the connection to TLS.
///
/// OpenSSL is used when `options.prefer_openssl` is set or when DANE TLSA
/// records are supplied; otherwise rustls is used.
///
/// A failed handshake is NOT an `Err`: it is reported as
/// `Ok(TlsStatus::FailedHandshake(..))` and the client keeps whichever
/// stream the chosen provider left it with. On success the negotiated
/// parameters are returned in `TlsStatus::Info`.
///
/// # Errors
/// * `Rejected` if the server does not answer STARTTLS with 220.
/// * `FailedToBuildConnector`, `SslErrorStack` or `InvalidDnsName` if the
///   TLS machinery cannot be set up.
/// * `NotConnected` if there is no socket to upgrade.
pub async fn starttls(&mut self, options: TlsOptions) -> Result<TlsStatus, ClientError> {
    let resp = self.send_command(&Command::StartTls).await?;
    if resp.code != 220 {
        return Err(ClientError::Rejected(resp));
    }

    let mut handshake_error = None;
    let mut tls_info = TlsInformation::default();

    let stream: BoxedAsyncReadAndWrite = if options.prefer_openssl
        || !options.dane_tlsa.is_empty()
    {
        let connector = options
            .build_openssl_connector(&self.hostname)
            .map_err(|error| ClientError::FailedToBuildConnector {
                error: error.to_string(),
            })?;
        let ssl = connector.into_ssl(self.hostname.as_str())?;

        // Take ownership of the socket, and ask it for a duplicate handle
        // that can be used if the handshake fails.
        // NOTE(review): `try_dup` is defined elsewhere; presumably it returns
        // `None` when the stream cannot be duplicated — confirm.
        let (stream, dup_stream) = match self.socket.take() {
            Some(s) => {
                let d = s.try_dup();
                (s, d)
            }
            None => return Err(ClientError::NotConnected),
        };

        let mut ssl_stream = tokio_openssl::SslStream::new(ssl, stream)?;

        // Record, rather than propagate, a handshake failure.
        if let Err(err) = std::pin::Pin::new(&mut ssl_stream).connect().await {
            handshake_error.replace(format!("{err:#}"));
        }

        tls_info.provider_name = "openssl".to_string();
        tls_info.cipher = match ssl_stream.ssl().current_cipher() {
            Some(cipher) => cipher.standard_name().unwrap_or(cipher.name()).to_string(),
            None => String::new(),
        };
        tls_info.protocol_version = ssl_stream.ssl().version_str().to_string();

        if let Some(cert) = ssl_stream.ssl().peer_certificate() {
            tls_info.subject_name = subject_name(&cert);
        }
        // When a DANE authority certificate is available its subject
        // replaces the peer certificate's subject.
        if let Ok(authority) = ssl_stream.ssl().dane_authority() {
            if let Some(cert) = &authority.cert {
                tls_info.subject_name = subject_name(cert);
            }
        }

        match (&handshake_error, dup_stream) {
            // The handshake failed before completing and a duplicate handle
            // exists: discard the TLS stream and continue on the duplicate.
            (Some(_), Some(dup_stream)) if !ssl_stream.ssl().is_init_finished() => {
                drop(ssl_stream);
                Box::new(dup_stream)
            }
            _ => Box::new(ssl_stream),
        }
    } else {
        tls_info.provider_name = "rustls".to_string();
        let connector = options.build_tls_connector().await.map_err(|error| {
            ClientError::FailedToBuildConnector {
                error: error.to_string(),
            }
        })?;
        // An IP literal becomes an IP server name; anything else must be a
        // valid DNS name.
        let server_name = match IpAddr::from_str(self.hostname.as_str()) {
            Ok(ip) => ServerName::IpAddress(ip.into()),
            Err(_) => ServerName::try_from(self.hostname.clone())
                .map_err(|_| ClientError::InvalidDnsName(self.hostname.clone()))?,
        };

        match connector
            .connect(
                server_name,
                match self.socket.take() {
                    Some(s) => s,
                    None => return Err(ClientError::NotConnected),
                },
            )
            .into_fallible()
            .await
        {
            Ok(stream) => {
                let (_, conn) = stream.get_ref();
                tls_info.cipher = match conn.negotiated_cipher_suite() {
                    Some(suite) => suite.suite().as_str().unwrap_or("UNKNOWN").to_string(),
                    None => String::new(),
                };
                tls_info.protocol_version = match conn.protocol_version() {
                    Some(version) => version.as_str().unwrap_or("UNKNOWN").to_string(),
                    None => String::new(),
                };

                if let Some(certs) = conn.peer_certificates() {
                    // NOTE(review): indexes the first certificate directly;
                    // assumes the list is never empty — confirm.
                    let peer_cert = &certs[0];
                    if let Ok(cert) = X509::from_der(peer_cert.as_ref()) {
                        tls_info.subject_name = subject_name(&cert);
                    }
                }

                Box::new(stream)
            }
            // On failure the fallible connect hands the original stream back.
            Err((err, stream)) => {
                handshake_error.replace(format!("{err:#}"));
                stream
            }
        }
    };

    if let Some(tracer) = &self.tracer {
        tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
            level: Level::INFO,
            message: match &handshake_error {
                Some(error) => format!("STARTTLS handshake failed: {error:?}"),
                None => format!("STARTTLS handshake -> {tls_info:?}"),
            },
        });
    }

    self.socket.replace(stream);
    Ok(match handshake_error {
        Some(error) => TlsStatus::FailedHandshake(error),
        None => TlsStatus::Info(tls_info),
    })
}
858
859 pub async fn send_mail<B: AsRef<[u8]>, SENDER: Into<ReversePath>, RECIP: Into<ForwardPath>>(
860 &mut self,
861 sender: SENDER,
862 recipient: RECIP,
863 data: B,
864 ) -> Result<Response, ClientError> {
865 let recipient: ForwardPath = recipient.into();
866 let status = self
867 .send_mail_multi_recip(sender, vec![recipient], data)
868 .await?;
869 Ok(status.response)
870 }
871
/// Sends one message to every recipient in `recipient_list`.
///
/// Issues `[RSET] MAIL FROM, RCPT TO..., DATA` through `pipeline_commands`,
/// then the dot-stuffed message body and the terminating dot.
///
/// On success returns the final response together with the per-recipient
/// RCPT TO responses; individual recipients may have been rejected as long
/// as at least one was accepted with 250.
///
/// # Errors
/// * `Rejected` with a locally synthesized 554 response if the data or the
///   envelope is 8-bit, the server lacks `8BITMIME` / `SMTPUTF8`
///   respectively, and `ignore_8bit_checks` is off. Nothing is sent.
/// * `Rejected` / `RejectedBatch` when RSET, MAIL FROM, every RCPT TO, DATA
///   or the final dot is refused.
/// * Any transport error raised while writing or reading.
pub async fn send_mail_multi_recip<B: AsRef<[u8]>, SENDER: Into<ReversePath>>(
    &mut self,
    sender: SENDER,
    recipient_list: Vec<ForwardPath>,
    data: B,
) -> Result<BatchSendSuccess, ClientError> {
    let sender = sender.into();

    let data: &[u8] = data.as_ref();
    // Owns the dot-stuffed copy, if one is needed, so `data` can borrow it.
    let stuffed;

    let data = match apply_dot_stuffing(data) {
        Some(d) => {
            stuffed = d;
            &stuffed
        }
        None => data,
    };

    let data_is_8bit = data.iter().any(|&b| b >= 0x80);
    let envelope_is_8bit =
        !sender.is_ascii() || recipient_list.iter().any(|recipient| !recipient.is_ascii());

    let mut mail_from_params = vec![];
    if data_is_8bit {
        if self.capabilities.contains_key("8BITMIME") {
            mail_from_params.push(EsmtpParameter {
                name: "BODY".to_string(),
                value: Some("8BITMIME".to_string()),
            });
        } else if !self.ignore_8bit_checks {
            // Fail locally, before anything is put on the wire.
            return Err(ClientError::Rejected(Response {
                code: 554,
                command: None,
                enhanced_code: Some(EnhancedStatusCode {
                    class: 5,
                    subject: 6,
                    detail: 3,
                }),
                content: "KumoMTA internal: DATA is 8bit, destination does \
                    not support 8BITMIME. Conversion via msg:check_fix_conformance \
                    during reception is required"
                    .to_string(),
            }));
        }
    }

    if envelope_is_8bit {
        if self.capabilities.contains_key("SMTPUTF8") {
            mail_from_params.push(EsmtpParameter {
                name: "SMTPUTF8".to_string(),
                value: None,
            });
        } else if !self.ignore_8bit_checks {
            // Fail locally, before anything is put on the wire.
            return Err(ClientError::Rejected(Response {
                code: 554,
                command: None,
                enhanced_code: Some(EnhancedStatusCode {
                    class: 5,
                    subject: 6,
                    detail: 7,
                }),
                content: "KumoMTA internal: envelope is 8bit, destination does \
                    not support SMTPUTF8."
                    .to_string(),
            }));
        }
    }

    let mut commands = vec![];

    // Remember whether an RSET leads the batch so that its response can be
    // consumed first below.
    let used_rset = self.use_rset;
    if self.use_rset {
        commands.push(Command::Rset);
    }
    commands.push(Command::MailFrom {
        address: sender,
        parameters: mail_from_params,
    });

    for recipient in &recipient_list {
        commands.push(Command::RcptTo {
            address: recipient.clone(),
            parameters: vec![],
        });
    }
    commands.push(Command::Data);

    // Pessimistically request an RSET for the next transaction; this is
    // only relaxed at the very end, once everything has succeeded.
    self.use_rset = true;

    let mut responses = self.pipeline_commands(commands).await;

    // If any command hit a transport error, the status codes of the
    // responses that did arrive are checked eagerly (the `is_err &&` tests
    // below) so that an SMTP rejection seen before the transport error is
    // reported instead of the transport error.
    let is_err = responses.iter().any(|r| r.is_err());

    // `responses` holds one entry per command, in order; each `remove(0)`
    // consumes the next one, and `?` propagates a transport error.
    if used_rset {
        let rset_resp = responses.remove(0)?;
        if rset_resp.code != 250 {
            return Err(ClientError::Rejected(rset_resp));
        }
    }

    let mail_resp = responses.remove(0)?;
    if is_err && mail_resp.code != 250 {
        return Err(ClientError::Rejected(mail_resp));
    }

    let mut rcpt_responses = vec![];
    for _ in &recipient_list {
        rcpt_responses.push(responses.remove(0)?);
    }

    if is_err && rcpt_responses.iter().all(|resp| resp.code != 250) {
        return Err(ClientError::RejectedBatch(rcpt_responses));
    }

    let data_resp = responses.remove(0)?;
    if is_err && data_resp.code != 354 {
        return Err(ClientError::Rejected(data_resp));
    }

    if data_resp.code == 354
        && (mail_resp.code != 250 || rcpt_responses.iter().all(|resp| resp.code != 250))
    {
        // The server accepted DATA even though MAIL FROM or every RCPT TO
        // failed (see RFC 2920 section 3.1). Finish the DATA phase with an
        // empty body and read its response, then fall through so that the
        // earlier failure is the one reported.
        self.write_data_with_timeout(b".\r\n").await?;
        let data_dot = Command::DataDot;
        let _ = self
            .read_response(Some(&data_dot), data_dot.client_timeout(&self.timeouts))
            .await?;

    }

    if mail_resp.code != 250 {
        return Err(ClientError::Rejected(mail_resp));
    }
    if rcpt_responses.iter().all(|resp| resp.code != 250) {
        // A single recipient is reported as a plain `Rejected`.
        if rcpt_responses.len() == 1 {
            return Err(ClientError::Rejected(
                rcpt_responses.pop().expect("have at least one"),
            ));
        }
        return Err(ClientError::RejectedBatch(rcpt_responses));
    }
    if data_resp.code != 354 {
        return Err(ClientError::Rejected(data_resp));
    }

    // The terminating dot must start a line: add CRLF first unless the
    // data already ends with a newline.
    let needs_newline = data.last().map(|&b| b != b'\n').unwrap_or(true);

    tracing::trace!("message data is {} bytes", data.len());

    self.write_data_with_timeout(data).await?;

    let marker = if needs_newline { "\r\n.\r\n" } else { ".\r\n" };

    tracing::trace!("send->{}: {}", self.hostname, marker.escape_debug());

    self.write_data_with_timeout(marker.as_bytes()).await?;

    let data_dot = Command::DataDot;
    let resp = self
        .read_response(Some(&data_dot), data_dot.client_timeout(&self.timeouts))
        .await?;
    if resp.code != 250 {
        return Err(ClientError::Rejected(resp));
    }

    // Everything succeeded: fall back to the configured RSET preference.
    self.use_rset = self.enable_rset;

    Ok(BatchSendSuccess {
        response: resp,
        rcpt_responses,
    })
}
1077}
1078
/// Outcome of a successful `SmtpClient::send_mail_multi_recip`.
#[derive(Debug)]
pub struct BatchSendSuccess {
    /// The response to the end-of-data marker.
    pub response: Response,
    /// One RCPT TO response per recipient, in the order given. At least one
    /// has code 250; others may be rejections.
    pub rcpt_responses: Vec<Response>,
}
1084
/// Result of `SmtpClient::starttls` once the server has accepted STARTTLS.
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize, Clone)]
pub enum TlsStatus {
    /// The TLS handshake failed; carries the formatted error.
    FailedHandshake(String),
    /// The TLS handshake succeeded; carries the negotiated parameters.
    Info(TlsInformation),
}
1090
/// Parameters of an established TLS session.
#[derive(Debug, Default, PartialEq, Eq, Serialize, Deserialize, Clone)]
pub struct TlsInformation {
    /// Name of the negotiated cipher; empty if none was reported.
    pub cipher: String,
    /// Negotiated protocol version; may be empty if none was reported.
    pub protocol_version: String,
    /// `SHORTNAME=value` entries from the certificate subject (see
    /// `subject_name`).
    pub subject_name: Vec<String>,
    /// The TLS implementation used: `"openssl"` or `"rustls"`.
    pub provider_name: String,
}
1098
1099impl Drop for SmtpClient {
1100 fn drop(&mut self) {
1101 if let Some(tracer) = &self.tracer {
1102 if self.socket.is_some() {
1103 tracer.trace_event(SmtpClientTraceEvent::Closed);
1104 }
1105 }
1106 }
1107}
1108fn parse_response_line(line: &'_ str) -> Result<ResponseLine<'_>, ClientError> {
1109 if line.len() < 4 {
1110 return Err(ClientError::MalformedResponseLine(line.to_string()));
1111 }
1112
1113 match line.as_bytes()[3] {
1114 b' ' | b'-' => match line[0..3].parse::<u16>() {
1115 Ok(code) => Ok(ResponseLine {
1116 code,
1117 is_final: line.as_bytes()[3] == b' ',
1118 content: &line[4..],
1119 }),
1120 Err(_) => Err(ClientError::MalformedResponseLine(line.to_string())),
1121 },
1122 _ => Err(ClientError::MalformedResponseLine(line.to_string())),
1123 }
1124}
1125
impl TlsOptions {
    /// Builds an OpenSSL client configuration for connecting to `hostname`
    /// from these options.
    ///
    /// Applies, when configured: the client certificate and key (both must
    /// be present, and must match each other), the cipher list, the cipher
    /// suites, the OpenSSL option flags, and disabling of peer verification
    /// (`insecure`). When TLSA records are supplied, DANE is enabled and
    /// each record is added to the configuration.
    ///
    /// # Errors
    /// * `SslErrorStack` for any failure reported by OpenSSL.
    /// * `NoUsableDaneTlsa` if TLSA records were supplied but OpenSSL
    ///   reported none of them as usable.
    pub fn build_openssl_connector(
        &self,
        hostname: &str,
    ) -> Result<openssl::ssl::ConnectConfiguration, ClientError> {
        tracing::trace!("build_openssl_connector for {hostname}");
        let mut builder =
            openssl::ssl::SslConnector::builder(openssl::ssl::SslMethod::tls_client())?;

        // A client certificate is only installed when both halves are given.
        if let (Some(cert_data), Some(key_data)) =
            (&self.certificate_from_pem, &self.private_key_from_pem)
        {
            let cert = X509::from_pem(cert_data)?;
            builder.set_certificate(&cert)?;

            let key = PKey::private_key_from_pem(key_data)?;
            builder.set_private_key(&key)?;

            builder.check_private_key()?;
        }

        if let Some(list) = &self.openssl_cipher_list {
            builder.set_cipher_list(list)?;
        }

        if let Some(suites) = &self.openssl_cipher_suites {
            builder.set_ciphersuites(suites)?;
        }

        // Explicit options replace, rather than add to, the builder's
        // current option flags.
        if let Some(options) = &self.openssl_options {
            builder.clear_options(SslOptions::all());
            builder.set_options(*options);
        }

        if self.insecure {
            builder.set_verify(openssl::ssl::SslVerifyMode::NONE);
        }

        // DANE has to be enabled on the context before it can be enabled
        // on the per-connection configuration below.
        if !self.dane_tlsa.is_empty() {
            builder.dane_enable()?;
            builder.set_no_dane_ee_namechecks();
        }

        let connector = builder.build();

        let mut config = connector.configure()?;

        if !self.dane_tlsa.is_empty() {
            config.dane_enable(hostname)?;
            let mut any_usable = false;
            for tlsa in &self.dane_tlsa {
                // Translate the DNS record's fields into their OpenSSL
                // equivalents; unassigned values are passed through raw.
                let usable = config.dane_tlsa_add(
                    match tlsa.cert_usage() {
                        CertUsage::PkixTa => DaneUsage::PKIX_TA,
                        CertUsage::PkixEe => DaneUsage::PKIX_EE,
                        CertUsage::DaneTa => DaneUsage::DANE_TA,
                        CertUsage::DaneEe => DaneUsage::DANE_EE,
                        CertUsage::Unassigned(n) => DaneUsage::from_raw(n),
                        CertUsage::Private => DaneUsage::PRIV_CERT,
                    },
                    match tlsa.selector() {
                        Selector::Full => DaneSelector::CERT,
                        Selector::Spki => DaneSelector::SPKI,
                        Selector::Unassigned(n) => DaneSelector::from_raw(n),
                        Selector::Private => DaneSelector::PRIV_SEL,
                    },
                    match tlsa.matching() {
                        Matching::Raw => DaneMatchType::FULL,
                        Matching::Sha256 => DaneMatchType::SHA2_256,
                        Matching::Sha512 => DaneMatchType::SHA2_512,
                        Matching::Unassigned(n) => DaneMatchType::from_raw(n),
                        Matching::Private => DaneMatchType::PRIV_MATCH,
                    },
                    tlsa.cert_data(),
                )?;

                tracing::trace!("build_dane_connector usable={usable} {tlsa:?}");
                if usable {
                    any_usable = true;
                }
            }

            // With no usable record, give up rather than return a
            // configuration.
            if !any_usable {
                return Err(ClientError::NoUsableDaneTlsa {
                    hostname: hostname.to_string(),
                    tlsa: self.dane_tlsa.clone(),
                });
            }
        }

        Ok(config)
    }
}
1219
1220fn apply_dot_stuffing(data: &[u8]) -> Option<Vec<u8>> {
1221 static LFDOT: LazyLock<Finder> = LazyLock::new(|| memchr::memmem::Finder::new("\n."));
1222
1223 if !data.starts_with(b".") && LFDOT.find(data).is_none() {
1224 return None;
1225 }
1226
1227 let mut stuffed = vec![];
1228 if data.starts_with(b".") {
1229 stuffed.push(b'.');
1230 }
1231 let mut last_idx = 0;
1232 for i in LFDOT.find_iter(data) {
1233 stuffed.extend_from_slice(&data[last_idx..=i]);
1234 stuffed.push(b'.');
1235 last_idx = i + 1;
1236 }
1237 stuffed.extend_from_slice(&data[last_idx..]);
1238 Some(stuffed)
1239}
1240
1241pub fn subject_name(cert: &X509Ref) -> Vec<String> {
1248 let mut subject_name = vec![];
1249 for entry in cert.subject_name().entries() {
1250 if let Ok(obj) = entry.object().nid().short_name() {
1251 if let Ok(name) = entry.data().as_utf8() {
1252 subject_name.push(format!("{obj}={name}"));
1253 }
1254 }
1255 }
1256 subject_name
1257}
1258
#[cfg(test)]
mod test {
    use super::*;
    use crate::{MailPath, Mailbox};

    // Dot stuffing: input with no leading dot and no LF-dot sequence yields
    // `None`; otherwise a leading dot and every dot directly after `\n` is
    // expected to be doubled.
    #[test]
    fn test_stuffing() {
        assert_eq!(apply_dot_stuffing(b"foo"), None);
        assert_eq!(apply_dot_stuffing(b".foo").unwrap(), b"..foo");
        assert_eq!(apply_dot_stuffing(b"foo\n.bar").unwrap(), b"foo\n..bar");
        assert_eq!(
            apply_dot_stuffing(b"foo\n.bar\n..baz\n").unwrap(),
            b"foo\n..bar\n...baz\n"
        );
    }

    // Single response lines: a space after the code marks the final line, a
    // `-` marks a continuation line; anything else is expected to be rejected
    // as `MalformedResponseLine`.
    #[test]
    fn response_line_parsing() {
        assert_eq!(
            parse_response_line("220 woot").unwrap(),
            ResponseLine {
                code: 220,
                is_final: true,
                content: "woot"
            }
        );
        assert_eq!(
            parse_response_line("220-woot").unwrap(),
            ResponseLine {
                code: 220,
                is_final: false,
                content: "woot"
            }
        );

        // `_` is neither of the two accepted separators.
        assert!(matches!(
            parse_response_line("220_woot"),
            Err(ClientError::MalformedResponseLine(_))
        ));
        // No numeric code at all.
        assert!(matches!(
            parse_response_line("not really"),
            Err(ClientError::MalformedResponseLine(_))
        ));
    }

    /// Test helper: parses each of `lines` with `parse_response_line` and
    /// feeds them to a `ResponseBuilder` seeded with the first line.
    ///
    /// Errors returned by `add_line` are wrapped in
    /// `ClientError::MalformedResponseLine`. Asserts that the last parsed
    /// line is final, then builds the `Response` with no associated command.
    /// Indexes `lines[0]`, so `lines` must not be empty.
    fn parse_multi_line(lines: &[&str]) -> Result<Response, ClientError> {
        let mut parsed = parse_response_line(lines[0])?;
        let mut b = ResponseBuilder::new(&parsed);
        for line in &lines[1..] {
            parsed = parse_response_line(line)?;
            b.add_line(&parsed)
                .map_err(ClientError::MalformedResponseLine)?;
        }
        assert!(parsed.is_final);
        Ok(b.build(None))
    }

    // Multi-line responses assembled through `parse_multi_line`.
    #[test]
    fn multi_line_response() {
        // The content of each line is expected to be joined with `\n`.
        assert_eq!(
            parse_multi_line(&["220-woot", "220-more", "220 done",]).unwrap(),
            Response {
                code: 220,
                enhanced_code: None,
                content: "woot\nmore\ndone".to_string(),
                command: None
            }
        );

        // A continuation line with a different code is expected to fail,
        // with the error carrying the offending line.
        let res = parse_multi_line(&["220-woot", "221-more", "220 done"]).unwrap_err();
        assert!(
            matches!(
                res,
                ClientError::MalformedResponseLine(ref err) if err == "221-more"
            ),
            "got error {res:?}"
        );

        // Same expectation when the mismatching code is on the final line.
        let res = parse_multi_line(&["220-woot", "220-more", "221 done"]).unwrap_err();
        assert!(
            matches!(
                res,
                ClientError::MalformedResponseLine(ref err) if err == "221 done"
            ),
            "got error {res:?}"
        );

        // An enhanced status code on the first line is expected to be parsed
        // into `enhanced_code` and removed from the content.
        assert_eq!(
            parse_multi_line(&["220-4.1.0 woot", "220-more", "220 done",]).unwrap(),
            Response {
                code: 220,
                enhanced_code: Some(EnhancedStatusCode {
                    class: 4,
                    subject: 1,
                    detail: 0
                }),
                content: "woot\nmore\ndone".to_string(),
                command: None
            }
        );

        // The same enhanced code repeated on a later line is expected to be
        // removed from that line's content as well.
        assert_eq!(
            parse_multi_line(&["220-4.1.0 woot", "220-4.1.0 more", "220 done",]).unwrap(),
            Response {
                code: 220,
                enhanced_code: Some(EnhancedStatusCode {
                    class: 4,
                    subject: 1,
                    detail: 0
                }),
                content: "woot\nmore\ndone".to_string(),
                command: None
            }
        );

        // A different enhanced code on a later line is expected to stay in
        // the content verbatim, with the first line's code being the one
        // reported in `enhanced_code`.
        assert_eq!(
            parse_multi_line(&["220-4.1.0 woot", "220-4.1.0 more", "220 5.5.5 done",]).unwrap(),
            Response {
                code: 220,
                enhanced_code: Some(EnhancedStatusCode {
                    class: 4,
                    subject: 1,
                    detail: 0
                }),
                content: "woot\nmore\n5.5.5 done".to_string(),
                command: None
            }
        );
    }

    // `extract_hostname` is expected to drop a trailing dot, a `:port`
    // suffix and surrounding brackets, for both names and IPv6 literals.
    #[test]
    fn test_extract_hostname() {
        assert_eq!(extract_hostname("foo"), "foo");
        assert_eq!(extract_hostname("foo."), "foo");
        assert_eq!(extract_hostname("foo:25"), "foo");
        assert_eq!(extract_hostname("foo.:25"), "foo");
        assert_eq!(extract_hostname("[foo]:25"), "foo");
        assert_eq!(extract_hostname("[foo.]:25"), "foo");
        assert_eq!(extract_hostname("[::1]:25"), "::1");
        // Unbracketed IPv6 literal followed by a port.
        assert_eq!(extract_hostname("::1:25"), "::1");
    }

    // Pins the text produced when formatting the timeout errors with
    // `{:#}`, including how the associated command is rendered after `cmd=`.
    #[test]
    fn test_format_error_command() {
        // A request timeout for `Command::DataDot` is expected to show `.`.
        assert_eq!(
            format!(
                "{:#}",
                ClientError::TimeOutRequest {
                    commands: vec![Command::DataDot],
                    duration: Duration::from_secs(10),
                }
            ),
            "Timed Out after 10s writing cmd=."
        );
        // A response timeout for `Command::MailFrom` is expected to show the
        // full `MAIL FROM:<...>` line.
        assert_eq!(
            format!(
                "{:#}",
                ClientError::TimeOutResponse {
                    command: Some(Command::MailFrom {
                        address: ReversePath::Path(MailPath {
                            at_domain_list: vec![],
                            mailbox: Mailbox {
                                local_part: "user".to_string(),
                                domain: Domain::Name("host".to_string())
                            }
                        }),
                        parameters: vec![],
                    }),
                    duration: Duration::from_secs(10),
                }
            ),
            r#"Timed Out waiting 10s for response to cmd=MAIL FROM:<user@host>"#
        );
    }
}