1#![allow(clippy::result_large_err)]
2use crate::client_types::*;
3use crate::{
4 AsyncReadAndWrite, BoxedAsyncReadAndWrite, Command, Domain, EsmtpParameter, ForwardPath,
5 ReversePath,
6};
7use hickory_proto::rr::rdata::tlsa::{CertUsage, Matching, Selector};
8use hickory_proto::rr::rdata::TLSA;
9use memchr::memmem::Finder;
10use openssl::pkey::PKey;
11use openssl::ssl::{DaneMatchType, DaneSelector, DaneUsage, SslOptions};
12use openssl::x509::{X509Ref, X509};
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15use std::net::IpAddr;
16use std::str::FromStr;
17use std::sync::{Arc, LazyLock};
18use std::time::Duration;
19use thiserror::Error;
20use tokio::io::{AsyncReadExt, AsyncWriteExt};
21use tokio::net::{TcpStream, ToSocketAddrs};
22use tokio::time::timeout;
23use tokio_rustls::rustls::pki_types::ServerName;
24use tracing::Level;
25
26pub use crate::tls::TlsOptions;
27pub use {openssl, tokio_rustls};
28
/// Maximum number of bytes buffered for a single response line before it is
/// treated as too long. It is also the size of the scratch buffer used for
/// each individual socket read.
const MAX_LINE_LEN: usize = 4096;
30
31#[derive(Error, Debug, Clone)]
32pub enum ClientError {
33 #[error("response is not UTF8")]
34 Utf8(#[from] std::string::FromUtf8Error),
35 #[error("Malformed Response: {0}")]
36 MalformedResponseLine(String),
37 #[error("Response line is too long")]
38 ResponseTooLong,
39 #[error("Not connected")]
40 NotConnected,
41 #[error("Command rejected {0:?}")]
42 Rejected(Response),
43 #[error("STARTTLS: {0} is not a valid DNS name")]
44 InvalidDnsName(String),
45 #[error("Invalid client certificate configured: {error:?}")]
46 FailedToBuildConnector { error: String },
47 #[error("Timed Out waiting {duration:?} for response to {command:?}")]
48 TimeOutResponse {
49 command: Option<Command>,
50 duration: Duration,
51 },
52 #[error("Timed Out writing {duration:?} {commands:?}")]
53 TimeOutRequest {
54 commands: Vec<Command>,
55 duration: Duration,
56 },
57 #[error("Error {error} reading response to {command:?}")]
58 ReadError {
59 command: Option<Command>,
60 error: String,
61 partial: String,
62 },
63 #[error("Error {error} flushing send buffer")]
64 FlushError { error: String },
65 #[error("Error {error} writing {commands:?}")]
66 WriteError {
67 commands: Vec<Command>,
68 error: String,
69 },
70 #[error("Timed Out sending message payload data")]
71 TimeOutData,
72 #[error("SSL Error: {0}")]
73 SslErrorStack(#[from] openssl::error::ErrorStack),
74 #[error("No usable DANE TLSA records for {hostname}: {tlsa:?}")]
75 NoUsableDaneTlsa { hostname: String, tlsa: Vec<TLSA> },
76}
77
78impl ClientError {
79 pub fn command(&self) -> Option<String> {
81 match self {
82 Self::TimeOutResponse {
83 command: Some(command),
84 ..
85 }
86 | Self::ReadError {
87 command: Some(command),
88 ..
89 } => Some(command.encode()),
90 Self::TimeOutRequest { commands, .. } | Self::WriteError { commands, .. }
91 if !commands.is_empty() =>
92 {
93 let commands: Vec<String> = commands.into_iter().map(|cmd| cmd.encode()).collect();
94 Some(commands.join(""))
95 }
96 _ => None,
97 }
98 }
99
100 pub fn was_due_to_message(&self) -> bool {
112 match self {
113 Self::Utf8(_)
114 | Self::MalformedResponseLine(_)
115 | Self::ResponseTooLong
116 | Self::NotConnected
117 | Self::InvalidDnsName(_)
118 | Self::FailedToBuildConnector { .. }
119 | Self::TimeOutResponse { .. }
120 | Self::TimeOutRequest { .. }
121 | Self::ReadError { .. }
122 | Self::FlushError { .. }
123 | Self::WriteError { .. }
124 | Self::TimeOutData
125 | Self::SslErrorStack(_)
126 | Self::NoUsableDaneTlsa { .. } => false,
127 Self::Rejected(response) => response.was_due_to_message(),
128 }
129 }
130}
131
/// One capability line advertised by the server in its EHLO/LHLO reply.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EsmtpCapability {
    /// The capability keyword as the server sent it. The capability map is
    /// keyed on the upper-cased form.
    pub name: String,
    /// Everything after the first space on the capability line, if anything.
    pub param: Option<String>,
}
137
138#[derive(Clone, Debug)]
139pub enum SmtpClientTraceEvent {
140 Closed,
141 Read(Vec<u8>),
142 Write(String),
143 Diagnostic {
144 level: tracing::Level,
145 message: String,
146 },
147}
148
149pub trait DeferredTracer {
150 fn trace(&self) -> SmtpClientTraceEvent;
151}
152
153pub trait SmtpClientTracer: std::fmt::Debug {
154 fn trace_event(&self, event: SmtpClientTraceEvent);
155 fn lazy_trace(&self, deferred: &dyn DeferredTracer);
156}
157
158struct WriteTracer<'a> {
160 data: &'a str,
161}
162impl DeferredTracer for WriteTracer<'_> {
163 fn trace(&self) -> SmtpClientTraceEvent {
164 SmtpClientTraceEvent::Write(self.data.to_string())
165 }
166}
167impl<'a> WriteTracer<'a> {
168 fn trace(tracer: &Arc<dyn SmtpClientTracer + Send + Sync>, data: &'a str) {
169 tracer.lazy_trace(&Self { data });
170 }
171}
172
173struct BinWriteTracer<'a> {
174 data: &'a [u8],
175}
176impl DeferredTracer for BinWriteTracer<'_> {
177 fn trace(&self) -> SmtpClientTraceEvent {
178 let data = String::from_utf8_lossy(self.data).to_string();
179 SmtpClientTraceEvent::Write(data)
180 }
181}
182impl<'a> BinWriteTracer<'a> {
183 fn trace(tracer: &Arc<dyn SmtpClientTracer + Send + Sync>, data: &'a [u8]) {
184 tracer.lazy_trace(&Self { data });
185 }
186}
187
188struct ReadTracer<'a> {
191 data: &'a [u8],
192}
193impl DeferredTracer for ReadTracer<'_> {
194 fn trace(&self) -> SmtpClientTraceEvent {
195 SmtpClientTraceEvent::Read(self.data.to_vec())
196 }
197}
198
199#[derive(Debug)]
200pub struct SmtpClient {
201 socket: Option<BoxedAsyncReadAndWrite>,
202 hostname: String,
203 capabilities: HashMap<String, EsmtpCapability>,
204 read_buffer: Vec<u8>,
205 timeouts: SmtpClientTimeouts,
206 tracer: Option<Arc<dyn SmtpClientTracer + Send + Sync>>,
207 use_rset: bool,
208 enable_rset: bool,
209 enable_pipelining: bool,
210}
211
/// Reduces a peer address string to a bare host name.
///
/// - Accepts `host`, `host:port`, `[host]` and `[host]:port`.
/// - Removes the surrounding brackets (as used for IPv6 literals) and a
///   single trailing `.` (fully-qualified DNS form).
/// - An un-bracketed value is split on its *last* `:`, so `::1:25` yields
///   `::1`. An un-bracketed IPv6 literal without a port is therefore
///   ambiguous; callers should bracket it.
fn extract_hostname(hostname: &str) -> &str {
    // A fully bracketed value carries no port, and the address inside it
    // may itself contain ':' (IPv6), so it must not be split on ':'.
    // Previously "[::1]" was split into "[:" and returned unchanged.
    let fully_bracketed = hostname.starts_with('[') && hostname.ends_with(']');
    let hostname = if fully_bracketed {
        hostname
    } else {
        hostname
            .rsplit_once(':')
            .map_or(hostname, |(host, _port)| host)
    };

    let hostname = hostname
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(hostname);

    hostname.strip_suffix('.').unwrap_or(hostname)
}
230
231impl SmtpClient {
232 pub async fn new<A: ToSocketAddrs + ToString + Clone>(
233 addr: A,
234 timeouts: SmtpClientTimeouts,
235 ) -> std::io::Result<Self> {
236 let stream = TcpStream::connect(addr.clone()).await?;
237 stream.set_nodelay(true)?;
239 Ok(Self::with_stream(stream, addr.to_string(), timeouts))
240 }
241
242 pub fn with_stream<S: AsyncReadAndWrite + 'static, H: AsRef<str>>(
243 stream: S,
244 peer_hostname: H,
245 timeouts: SmtpClientTimeouts,
246 ) -> Self {
247 let hostname = extract_hostname(peer_hostname.as_ref()).to_string();
248
249 Self {
250 socket: Some(Box::new(stream)),
251 hostname,
252 capabilities: HashMap::new(),
253 read_buffer: Vec::with_capacity(1024),
254 timeouts,
255 tracer: None,
256 use_rset: false,
257 enable_rset: false,
258 enable_pipelining: false,
259 }
260 }
261
262 pub fn is_connected(&self) -> bool {
263 self.socket.is_some()
264 }
265
266 pub fn set_enable_rset(&mut self, enable: bool) {
267 self.enable_rset = enable;
268 }
269
270 pub fn set_enable_pipelining(&mut self, enable: bool) {
271 self.enable_pipelining = enable;
272 }
273
274 pub fn set_tracer(&mut self, tracer: Arc<dyn SmtpClientTracer + Send + Sync>) {
275 self.tracer.replace(tracer);
276 }
277
278 pub fn timeouts(&self) -> &SmtpClientTimeouts {
279 &self.timeouts
280 }
281
282 async fn read_line(
283 &mut self,
284 timeout_duration: Duration,
285 cmd: Option<&Command>,
286 ) -> Result<String, ClientError> {
287 let mut too_long = false;
288 loop {
289 let mut iter = self.read_buffer.iter().enumerate();
290 while let Some((i, &b)) = iter.next() {
291 if b != b'\r' {
292 continue;
293 }
294 if let Some((_, b'\n')) = iter.next() {
295 if too_long {
296 self.read_buffer.drain(0..i + 2);
297
298 if let Some(tracer) = &self.tracer {
299 tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
300 level: Level::ERROR,
301 message: "Response too long".to_string(),
302 });
303 }
304
305 return Err(ClientError::ResponseTooLong);
306 }
307
308 let line = String::from_utf8(self.read_buffer[0..i].to_vec());
309 self.read_buffer.drain(0..i + 2);
310 return Ok(line?);
311 }
312 }
313 if self.read_buffer.len() > MAX_LINE_LEN {
314 self.read_buffer.clear();
315 too_long = true;
316 }
317
318 let mut data = [0u8; MAX_LINE_LEN];
320 let size = match self.socket.as_mut() {
321 Some(s) => match timeout(timeout_duration, s.read(&mut data)).await {
322 Ok(Ok(size)) => size,
323 Ok(Err(err)) => {
324 self.socket.take();
325 if let Some(tracer) = &self.tracer {
326 tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
327 level: Level::ERROR,
328 message: format!("Error during read: {err:#}"),
329 });
330 tracer.trace_event(SmtpClientTraceEvent::Closed);
331 }
332 return Err(ClientError::ReadError {
333 command: cmd.cloned(),
334 error: format!("{err:#}"),
335 partial: String::from_utf8_lossy(&self.read_buffer).to_string(),
336 });
337 }
338 Err(_) => {
339 self.socket.take();
340 if let Some(tracer) = &self.tracer {
341 tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
342 level: Level::ERROR,
343 message: format!("Read Timeout after {timeout_duration:?}"),
344 });
345 tracer.trace_event(SmtpClientTraceEvent::Closed);
346 }
347 return Err(ClientError::TimeOutResponse {
348 command: cmd.cloned(),
349 duration: timeout_duration,
350 });
351 }
352 },
353 None => {
354 return Err(ClientError::ReadError {
355 command: cmd.cloned(),
356 error: "the socket was closed in response to an earlier issue".to_string(),
357 partial: String::from_utf8_lossy(&self.read_buffer).to_string(),
358 });
359 }
360 };
361 if size == 0 {
362 self.socket.take();
363 if let Some(tracer) = &self.tracer {
364 tracer.trace_event(SmtpClientTraceEvent::Closed);
365 }
366 return Err(ClientError::ReadError {
367 command: cmd.cloned(),
368 error: "Connection closed by peer".to_string(),
369 partial: String::from_utf8_lossy(&self.read_buffer).to_string(),
370 });
371 }
372 if let Some(tracer) = &self.tracer {
373 tracer.lazy_trace(&ReadTracer {
374 data: &data[0..size],
375 });
376 }
377 self.read_buffer.extend_from_slice(&data[0..size]);
378 }
379 }
380
381 pub async fn read_response(
382 &mut self,
383 command: Option<&Command>,
384 timeout_duration: Duration,
385 ) -> Result<Response, ClientError> {
386 if let Some(sock) = self.socket.as_mut() {
387 if let Err(err) = sock.flush().await {
388 self.socket.take();
389 if let Some(tracer) = &self.tracer {
390 tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
391 level: Level::ERROR,
392 message: format!("Error during flush: {err:#}"),
393 });
394 tracer.trace_event(SmtpClientTraceEvent::Closed);
395 }
396 return Err(ClientError::FlushError {
397 error: format!("{err:#}"),
398 });
399 }
400 }
401
402 let mut line = self.read_line(timeout_duration, command).await?;
403 tracing::trace!("recv<-{}: {line}", self.hostname);
404 let mut parsed = parse_response_line(&line)?;
405 let mut response_builder = ResponseBuilder::new(&parsed);
406
407 let subsequent_line_timeout_duration = Duration::from_secs(60).min(timeout_duration);
408 while !parsed.is_final {
409 line = self
410 .read_line(subsequent_line_timeout_duration, command)
411 .await?;
412 parsed = parse_response_line(&line)?;
413 response_builder
414 .add_line(&parsed)
415 .map_err(ClientError::MalformedResponseLine)?;
416 }
417
418 let response = response_builder.build(command.map(|cmd| cmd.encode()));
419
420 tracing::trace!("{}: {response:?}", self.hostname);
421
422 Ok(response)
423 }
424
425 pub async fn send_command(&mut self, command: &Command) -> Result<Response, ClientError> {
426 self.write_command_request(command).await?;
427 self.read_response(Some(command), command.client_timeout(&self.timeouts))
428 .await
429 }
430
431 async fn write_all_with_timeout<F, G>(
435 &mut self,
436 timeout_duration: Duration,
437 bytes: &[u8],
438 make_timeout_err: F,
439 make_write_err: G,
440 ) -> Result<(), ClientError>
441 where
442 F: FnOnce() -> ClientError,
443 G: FnOnce(String) -> ClientError,
444 {
445 match self.socket.as_mut() {
446 Some(socket) => match timeout(timeout_duration, socket.write_all(bytes)).await {
447 Ok(Ok(response)) => Ok(response),
448 Ok(Err(err)) => {
449 self.socket.take();
450 if let Some(tracer) = &self.tracer {
451 tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
452 level: Level::ERROR,
453 message: format!("Error during write: {err:#}"),
454 });
455 tracer.trace_event(SmtpClientTraceEvent::Closed);
456 }
457 Err(make_write_err(format!("{err:#}")))
458 }
459 Err(_) => {
460 self.socket.take();
461 if let Some(tracer) = &self.tracer {
462 tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
463 level: Level::ERROR,
464 message: format!("Write Timeout after {timeout_duration:?}"),
465 });
466 tracer.trace_event(SmtpClientTraceEvent::Closed);
467 }
468 Err(make_timeout_err())
469 }
470 },
471 None => Err(make_write_err(
472 "the socket was closed in response to an earlier issue".to_string(),
473 )),
474 }
475 }
476
477 async fn write_pipeline_request(&mut self, commands: &[Command]) -> Result<(), ClientError> {
478 let total_timeout: Duration = commands
479 .iter()
480 .map(|cmd| cmd.client_timeout_request(&self.timeouts))
481 .sum();
482
483 let mut lines: Vec<String> = vec![];
484 let mut all = String::new();
485 for cmd in commands {
486 let line = cmd.encode();
487 all.push_str(&line);
488 lines.push(line);
489 }
490 tracing::trace!("send->{}: (PIPELINE) {all}", self.hostname);
491 if self.socket.is_some() {
492 if let Some(tracer) = &self.tracer {
493 for line in lines {
496 WriteTracer::trace(tracer, &line);
497 }
498 }
499 }
500 self.write_all_with_timeout(
501 total_timeout,
502 all.as_bytes(),
503 || ClientError::TimeOutRequest {
504 duration: total_timeout,
505 commands: commands.to_vec(),
506 },
507 |error| ClientError::WriteError {
508 error,
509 commands: commands.to_vec(),
510 },
511 )
512 .await
513 }
514
515 async fn write_command_request(&mut self, command: &Command) -> Result<(), ClientError> {
516 let line = command.encode();
517 tracing::trace!("send->{}: {line}", self.hostname);
518 if self.socket.is_some() {
519 if let Some(tracer) = &self.tracer {
520 WriteTracer::trace(tracer, &line);
521 }
522 }
523
524 let timeout_duration = command.client_timeout_request(&self.timeouts);
525 self.write_all_with_timeout(
526 timeout_duration,
527 line.as_bytes(),
528 || ClientError::TimeOutRequest {
529 duration: timeout_duration,
530 commands: vec![command.clone()],
531 },
532 |error| ClientError::WriteError {
533 error,
534 commands: vec![command.clone()],
535 },
536 )
537 .await
538 }
539
540 async fn write_data_with_timeout(&mut self, data: &[u8]) -> Result<(), ClientError> {
541 if self.socket.is_some() {
542 if let Some(tracer) = &self.tracer {
543 BinWriteTracer::trace(tracer, data);
544 }
545 }
546 let timeout_duration = Command::Data.client_timeout_request(&self.timeouts);
547 self.write_all_with_timeout(
548 timeout_duration,
549 data,
550 || ClientError::TimeOutData,
551 |error| ClientError::WriteError {
552 error,
553 commands: vec![],
554 },
555 )
556 .await
557 }
558
559 pub async fn pipeline_commands(
575 &mut self,
576 commands: Vec<Command>,
577 ) -> Vec<Result<Response, ClientError>> {
578 let mut results: Vec<Result<Response, ClientError>> = vec![];
579
580 let pipeline = self.enable_pipelining && self.capabilities.contains_key("PIPELINING");
581 if pipeline {
582 if let Err(err) = self.write_pipeline_request(&commands).await {
583 let err: ClientError = err;
584 results.push(Err(err.clone()));
585 while results.len() < commands.len() {
586 results.push(Err(err.clone()));
588 }
589 return results;
590 }
591
592 for cmd in &commands {
594 results.push(
595 self.read_response(Some(cmd), cmd.client_timeout(&self.timeouts))
596 .await,
597 );
598 }
599 return results;
600 }
601
602 for cmd in &commands {
603 if let Err(err) = self.write_command_request(cmd).await {
604 let err: ClientError = err;
605 results.push(Err(err.clone()));
606 while results.len() < commands.len() {
607 results.push(Err(err.clone()));
609 }
610 return results;
611 }
612 results.push(
615 self.read_response(Some(cmd), cmd.client_timeout(&self.timeouts))
616 .await,
617 );
618 }
619 results
620 }
621
622 pub async fn ehlo_lhlo(
623 &mut self,
624 ehlo_name: &str,
625 use_lmtp: bool,
626 ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
627 if use_lmtp {
628 self.lhlo(ehlo_name).await
629 } else {
630 self.ehlo(ehlo_name).await
631 }
632 }
633
634 pub async fn lhlo(
635 &mut self,
636 ehlo_name: &str,
637 ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
638 let response = self
639 .send_command(&Command::Lhlo(Domain::Name(ehlo_name.to_string())))
640 .await?;
641 self.ehlo_common(response)
642 }
643
644 pub async fn ehlo(
645 &mut self,
646 ehlo_name: &str,
647 ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
648 let response = self
649 .send_command(&Command::Ehlo(Domain::Name(ehlo_name.to_string())))
650 .await?;
651 self.ehlo_common(response)
652 }
653
654 fn ehlo_common(
655 &mut self,
656 response: Response,
657 ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
658 if response.code != 250 {
659 return Err(ClientError::Rejected(response));
660 }
661
662 let mut capabilities = HashMap::new();
663
664 for line in response.content.lines().skip(1) {
665 let mut fields = line.splitn(2, ' ');
666 if let Some(name) = fields.next() {
667 let param = fields.next().map(|s| s.to_string());
668 let cap = EsmtpCapability {
669 name: name.to_string(),
670 param,
671 };
672 capabilities.insert(name.to_ascii_uppercase(), cap);
673 }
674 }
675
676 self.capabilities = capabilities;
677 Ok(&self.capabilities)
678 }
679
680 pub async fn auth_plain(
681 &mut self,
682 username: &str,
683 password: Option<&str>,
684 ) -> Result<(), ClientError> {
685 let password = password.unwrap_or("");
688 let payload = format!("\x00{username}\x00{password}");
689 let payload = data_encoding::BASE64.encode(payload.as_bytes());
690
691 let response = self
692 .send_command(&Command::Auth {
693 sasl_mech: "PLAIN".to_string(),
694 initial_response: Some(payload),
695 })
696 .await?;
697
698 if response.code != 235 {
699 return Err(ClientError::Rejected(response));
700 }
701
702 Ok(())
703 }
704
705 pub async fn starttls(&mut self, options: TlsOptions) -> Result<TlsStatus, ClientError> {
711 let resp = self.send_command(&Command::StartTls).await?;
712 if resp.code != 220 {
713 return Err(ClientError::Rejected(resp));
714 }
715
716 let mut handshake_error = None;
717 let mut tls_info = TlsInformation::default();
718
719 let stream: BoxedAsyncReadAndWrite = if options.prefer_openssl
720 || !options.dane_tlsa.is_empty()
721 {
722 let connector = options
723 .build_openssl_connector(&self.hostname)
724 .map_err(|error| ClientError::FailedToBuildConnector {
725 error: error.to_string(),
726 })?;
727 let ssl = connector.into_ssl(self.hostname.as_str())?;
728
729 let (stream, dup_stream) = match self.socket.take() {
730 Some(s) => {
731 let d = s.try_dup();
732 (s, d)
733 }
734 None => return Err(ClientError::NotConnected),
735 };
736
737 let mut ssl_stream = tokio_openssl::SslStream::new(ssl, stream)?;
738
739 if let Err(err) = std::pin::Pin::new(&mut ssl_stream).connect().await {
740 handshake_error.replace(format!("{err:#}"));
741 }
742
743 tls_info.provider_name = "openssl".to_string();
744 tls_info.cipher = match ssl_stream.ssl().current_cipher() {
745 Some(cipher) => cipher.standard_name().unwrap_or(cipher.name()).to_string(),
746 None => String::new(),
747 };
748 tls_info.protocol_version = ssl_stream.ssl().version_str().to_string();
749
750 if let Some(cert) = ssl_stream.ssl().peer_certificate() {
751 tls_info.subject_name = subject_name(&cert);
752 }
753 if let Ok(authority) = ssl_stream.ssl().dane_authority() {
754 if let Some(cert) = &authority.cert {
755 tls_info.subject_name = subject_name(cert);
756 }
757 }
758
759 match (&handshake_error, dup_stream) {
760 (Some(_), Some(dup_stream)) if !ssl_stream.ssl().is_init_finished() => {
761 drop(ssl_stream);
766 Box::new(dup_stream)
767 }
768 _ => Box::new(ssl_stream),
769 }
770 } else {
771 tls_info.provider_name = "rustls".to_string();
772 let connector = options.build_tls_connector().await.map_err(|error| {
773 ClientError::FailedToBuildConnector {
774 error: error.to_string(),
775 }
776 })?;
777 let server_name = match IpAddr::from_str(self.hostname.as_str()) {
778 Ok(ip) => ServerName::IpAddress(ip.into()),
779 Err(_) => ServerName::try_from(self.hostname.clone())
780 .map_err(|_| ClientError::InvalidDnsName(self.hostname.clone()))?,
781 };
782
783 match connector
784 .connect(
785 server_name,
786 match self.socket.take() {
787 Some(s) => s,
788 None => return Err(ClientError::NotConnected),
789 },
790 )
791 .into_fallible()
792 .await
793 {
794 Ok(stream) => {
795 let (_, conn) = stream.get_ref();
796 tls_info.cipher = match conn.negotiated_cipher_suite() {
797 Some(suite) => suite.suite().as_str().unwrap_or("UNKNOWN").to_string(),
798 None => String::new(),
799 };
800 tls_info.protocol_version = match conn.protocol_version() {
801 Some(version) => version.as_str().unwrap_or("UNKNOWN").to_string(),
802 None => String::new(),
803 };
804
805 if let Some(certs) = conn.peer_certificates() {
806 let peer_cert = &certs[0];
807 if let Ok(cert) = X509::from_der(peer_cert.as_ref()) {
808 tls_info.subject_name = subject_name(&cert);
809 }
810 }
811
812 Box::new(stream)
813 }
814 Err((err, stream)) => {
815 handshake_error.replace(format!("{err:#}"));
816 stream
817 }
818 }
819 };
820
821 if let Some(tracer) = &self.tracer {
822 tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
823 level: Level::INFO,
824 message: match &handshake_error {
825 Some(error) => format!("STARTTLS handshake failed: {error:?}"),
826 None => format!("STARTTLS handshake -> {tls_info:?}"),
827 },
828 });
829 }
830
831 self.socket.replace(stream);
832 Ok(match handshake_error {
833 Some(error) => TlsStatus::FailedHandshake(error),
834 None => TlsStatus::Info(tls_info),
835 })
836 }
837
838 pub async fn send_mail<B: AsRef<[u8]>, SENDER: Into<ReversePath>, RECIP: Into<ForwardPath>>(
839 &mut self,
840 sender: SENDER,
841 recipient: RECIP,
842 data: B,
843 ) -> Result<Response, ClientError> {
844 let sender = sender.into();
845 let recipient = recipient.into();
846
847 let data: &[u8] = data.as_ref();
848 let stuffed;
849
850 let data = match apply_dot_stuffing(data) {
851 Some(d) => {
852 stuffed = d;
853 &stuffed
854 }
855 None => data,
856 };
857
858 let data_is_8bit = data.iter().any(|&b| b >= 0x80);
859 let envelope_is_8bit = !sender.is_ascii() || !recipient.is_ascii();
860
861 let mut mail_from_params = vec![];
862 if data_is_8bit && self.capabilities.contains_key("8BITMIME") {
863 mail_from_params.push(EsmtpParameter {
864 name: "BODY".to_string(),
865 value: Some("8BITMIME".to_string()),
866 });
867 }
868
869 if envelope_is_8bit && self.capabilities.contains_key("SMTPUTF8") {
870 mail_from_params.push(EsmtpParameter {
871 name: "SMTPUTF8".to_string(),
872 value: None,
873 });
874 }
875
876 let mut commands = vec![];
877
878 let used_rset = self.use_rset;
888 if self.use_rset {
889 commands.push(Command::Rset);
890 }
891 commands.push(Command::MailFrom {
892 address: sender,
893 parameters: mail_from_params,
894 });
895 commands.push(Command::RcptTo {
896 address: recipient,
897 parameters: vec![],
898 });
899 commands.push(Command::Data);
900
901 self.use_rset = true;
904
905 let mut responses = self.pipeline_commands(commands).await;
906
907 let is_err = responses.iter().any(|r| r.is_err());
915
916 if used_rset {
917 let rset_resp = responses.remove(0)?;
918 if rset_resp.code != 250 {
919 return Err(ClientError::Rejected(rset_resp));
920 }
921 }
922
923 let mail_resp = responses.remove(0)?;
924 if is_err && mail_resp.code != 250 {
925 return Err(ClientError::Rejected(mail_resp));
926 }
927
928 let rcpt_resp = responses.remove(0)?;
929 if is_err && rcpt_resp.code != 250 {
930 return Err(ClientError::Rejected(rcpt_resp));
931 }
932
933 let data_resp = responses.remove(0)?;
934 if is_err && data_resp.code != 354 {
935 return Err(ClientError::Rejected(data_resp));
936 }
937
938 if data_resp.code == 354 && (mail_resp.code != 250 || rcpt_resp.code != 250) {
939 self.write_data_with_timeout(b".\r\n").await?;
948 let data_dot = Command::DataDot;
949 let _ = self
951 .read_response(Some(&data_dot), data_dot.client_timeout(&self.timeouts))
952 .await?;
953
954 }
958
959 if mail_resp.code != 250 {
960 return Err(ClientError::Rejected(mail_resp));
961 }
962 if rcpt_resp.code != 250 {
963 return Err(ClientError::Rejected(rcpt_resp));
964 }
965 if data_resp.code != 354 {
966 return Err(ClientError::Rejected(data_resp));
967 }
968
969 let needs_newline = data.last().map(|&b| b != b'\n').unwrap_or(true);
970
971 tracing::trace!("message data is {} bytes", data.len());
972
973 self.write_data_with_timeout(data).await?;
974
975 let marker = if needs_newline { "\r\n.\r\n" } else { ".\r\n" };
976
977 tracing::trace!("send->{}: {}", self.hostname, marker.escape_debug());
978
979 self.write_data_with_timeout(marker.as_bytes()).await?;
980
981 let data_dot = Command::DataDot;
982 let resp = self
983 .read_response(Some(&data_dot), data_dot.client_timeout(&self.timeouts))
984 .await?;
985 if resp.code != 250 {
986 return Err(ClientError::Rejected(resp));
987 }
988
989 self.use_rset = self.enable_rset;
992
993 Ok(resp)
994 }
995}
996
997#[derive(Debug, PartialEq, Eq, Serialize, Deserialize, Clone)]
998pub enum TlsStatus {
999 FailedHandshake(String),
1000 Info(TlsInformation),
1001}
1002
1003#[derive(Debug, Default, PartialEq, Eq, Serialize, Deserialize, Clone)]
1004pub struct TlsInformation {
1005 pub cipher: String,
1006 pub protocol_version: String,
1007 pub subject_name: Vec<String>,
1008 pub provider_name: String,
1009}
1010
1011impl Drop for SmtpClient {
1012 fn drop(&mut self) {
1013 if let Some(tracer) = &self.tracer {
1014 if self.socket.is_some() {
1015 tracer.trace_event(SmtpClientTraceEvent::Closed);
1016 }
1017 }
1018 }
1019}
1020fn parse_response_line(line: &'_ str) -> Result<ResponseLine<'_>, ClientError> {
1021 if line.len() < 4 {
1022 return Err(ClientError::MalformedResponseLine(line.to_string()));
1023 }
1024
1025 match line.as_bytes()[3] {
1026 b' ' | b'-' => match line[0..3].parse::<u16>() {
1027 Ok(code) => Ok(ResponseLine {
1028 code,
1029 is_final: line.as_bytes()[3] == b' ',
1030 content: &line[4..],
1031 }),
1032 Err(_) => Err(ClientError::MalformedResponseLine(line.to_string())),
1033 },
1034 _ => Err(ClientError::MalformedResponseLine(line.to_string())),
1035 }
1036}
1037
1038impl TlsOptions {
1039 pub fn build_openssl_connector(
1040 &self,
1041 hostname: &str,
1042 ) -> Result<openssl::ssl::ConnectConfiguration, ClientError> {
1043 tracing::trace!("build_openssl_connector for {hostname}");
1044 let mut builder =
1045 openssl::ssl::SslConnector::builder(openssl::ssl::SslMethod::tls_client())?;
1046
1047 if let (Some(cert_data), Some(key_data)) =
1048 (&self.certificate_from_pem, &self.private_key_from_pem)
1049 {
1050 let cert = X509::from_pem(cert_data)?;
1051 builder.set_certificate(&cert)?;
1052
1053 let key = PKey::private_key_from_pem(key_data)?;
1054 builder.set_private_key(&key)?;
1055
1056 builder.check_private_key()?;
1057 }
1058
1059 if let Some(list) = &self.openssl_cipher_list {
1060 builder.set_cipher_list(list)?;
1061 }
1062
1063 if let Some(suites) = &self.openssl_cipher_suites {
1064 builder.set_ciphersuites(suites)?;
1065 }
1066
1067 if let Some(options) = &self.openssl_options {
1068 builder.clear_options(SslOptions::all());
1069 builder.set_options(*options);
1070 }
1071
1072 if self.insecure {
1073 builder.set_verify(openssl::ssl::SslVerifyMode::NONE);
1074 }
1075
1076 if !self.dane_tlsa.is_empty() {
1077 builder.dane_enable()?;
1078 builder.set_no_dane_ee_namechecks();
1079 }
1080
1081 let connector = builder.build();
1082
1083 let mut config = connector.configure()?;
1084
1085 if !self.dane_tlsa.is_empty() {
1086 config.dane_enable(hostname)?;
1087 let mut any_usable = false;
1088 for tlsa in &self.dane_tlsa {
1089 let usable = config.dane_tlsa_add(
1090 match tlsa.cert_usage() {
1091 CertUsage::PkixTa => DaneUsage::PKIX_TA,
1092 CertUsage::PkixEe => DaneUsage::PKIX_EE,
1093 CertUsage::DaneTa => DaneUsage::DANE_TA,
1094 CertUsage::DaneEe => DaneUsage::DANE_EE,
1095 CertUsage::Unassigned(n) => DaneUsage::from_raw(n),
1096 CertUsage::Private => DaneUsage::PRIV_CERT,
1097 },
1098 match tlsa.selector() {
1099 Selector::Full => DaneSelector::CERT,
1100 Selector::Spki => DaneSelector::SPKI,
1101 Selector::Unassigned(n) => DaneSelector::from_raw(n),
1102 Selector::Private => DaneSelector::PRIV_SEL,
1103 },
1104 match tlsa.matching() {
1105 Matching::Raw => DaneMatchType::FULL,
1106 Matching::Sha256 => DaneMatchType::SHA2_256,
1107 Matching::Sha512 => DaneMatchType::SHA2_512,
1108 Matching::Unassigned(n) => DaneMatchType::from_raw(n),
1109 Matching::Private => DaneMatchType::PRIV_MATCH,
1110 },
1111 tlsa.cert_data(),
1112 )?;
1113
1114 tracing::trace!("build_dane_connector usable={usable} {tlsa:?}");
1115 if usable {
1116 any_usable = true;
1117 }
1118 }
1119
1120 if !any_usable {
1121 return Err(ClientError::NoUsableDaneTlsa {
1122 hostname: hostname.to_string(),
1123 tlsa: self.dane_tlsa.clone(),
1124 });
1125 }
1126 }
1127
1128 Ok(config)
1129 }
1130}
1131
1132fn apply_dot_stuffing(data: &[u8]) -> Option<Vec<u8>> {
1133 static LFDOT: LazyLock<Finder> = LazyLock::new(|| memchr::memmem::Finder::new("\n."));
1134
1135 if !data.starts_with(b".") && LFDOT.find(data).is_none() {
1136 return None;
1137 }
1138
1139 let mut stuffed = vec![];
1140 if data.starts_with(b".") {
1141 stuffed.push(b'.');
1142 }
1143 let mut last_idx = 0;
1144 for i in LFDOT.find_iter(data) {
1145 stuffed.extend_from_slice(&data[last_idx..=i]);
1146 stuffed.push(b'.');
1147 last_idx = i + 1;
1148 }
1149 stuffed.extend_from_slice(&data[last_idx..]);
1150 Some(stuffed)
1151}
1152
1153pub fn subject_name(cert: &X509Ref) -> Vec<String> {
1160 let mut subject_name = vec![];
1161 for entry in cert.subject_name().entries() {
1162 if let Ok(obj) = entry.object().nid().short_name() {
1163 if let Ok(name) = entry.data().as_utf8() {
1164 subject_name.push(format!("{obj}={name}"));
1165 }
1166 }
1167 }
1168 subject_name
1169}
1170
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_stuffing() {
        // Nothing to stuff: no leading dot and no "\n.".
        assert_eq!(apply_dot_stuffing(b"foo"), None);

        let cases: [(&[u8], &[u8]); 3] = [
            (b".foo", b"..foo"),
            (b"foo\n.bar", b"foo\n..bar"),
            (b"foo\n.bar\n..baz\n", b"foo\n..bar\n...baz\n"),
        ];
        for (input, expected) in cases {
            assert_eq!(apply_dot_stuffing(input).unwrap(), expected);
        }
    }

    #[test]
    fn response_line_parsing() {
        for (line, is_final) in [("220 woot", true), ("220-woot", false)] {
            assert_eq!(
                parse_response_line(line).unwrap(),
                ResponseLine {
                    code: 220,
                    is_final,
                    content: "woot"
                }
            );
        }

        for bad in ["220_woot", "not really"] {
            assert!(matches!(
                parse_response_line(bad),
                Err(ClientError::MalformedResponseLine(_))
            ));
        }
    }

    /// Parses `lines` as one multi-line reply, asserting the last line is final.
    fn parse_multi_line(lines: &[&str]) -> Result<Response, ClientError> {
        let (first, rest) = lines.split_first().expect("at least one line");
        let mut parsed = parse_response_line(first)?;
        let mut builder = ResponseBuilder::new(&parsed);
        for line in rest {
            parsed = parse_response_line(line)?;
            builder
                .add_line(&parsed)
                .map_err(ClientError::MalformedResponseLine)?;
        }
        assert!(parsed.is_final);
        Ok(builder.build(None))
    }

    #[test]
    fn multi_line_response() {
        let response = |enhanced_code: Option<EnhancedStatusCode>, content: &str| Response {
            code: 220,
            enhanced_code,
            content: content.to_string(),
            command: None,
        };
        let enhanced_4_1_0 = || {
            Some(EnhancedStatusCode {
                class: 4,
                subject: 1,
                detail: 0,
            })
        };

        assert_eq!(
            parse_multi_line(&["220-woot", "220-more", "220 done"]).unwrap(),
            response(None, "woot\nmore\ndone")
        );

        // A reply code that changes mid-response is rejected, naming the offending line.
        for (lines, offending) in [
            (["220-woot", "221-more", "220 done"], "221-more"),
            (["220-woot", "220-more", "221 done"], "221 done"),
        ] {
            let res = parse_multi_line(&lines).unwrap_err();
            assert!(
                matches!(
                    res,
                    ClientError::MalformedResponseLine(ref err) if err == offending
                ),
                "got error {res:?}"
            );
        }

        let enhanced_cases = [
            (["220-4.1.0 woot", "220-more", "220 done"], "woot\nmore\ndone"),
            (["220-4.1.0 woot", "220-4.1.0 more", "220 done"], "woot\nmore\ndone"),
            // The final line's differing 5.5.5 code is kept as text.
            (
                ["220-4.1.0 woot", "220-4.1.0 more", "220 5.5.5 done"],
                "woot\nmore\n5.5.5 done",
            ),
        ];
        for (lines, content) in enhanced_cases {
            assert_eq!(
                parse_multi_line(&lines).unwrap(),
                response(enhanced_4_1_0(), content)
            );
        }
    }

    #[test]
    fn test_extract_hostname() {
        let cases = [
            ("foo", "foo"),
            ("foo.", "foo"),
            ("foo:25", "foo"),
            ("foo.:25", "foo"),
            ("[foo]:25", "foo"),
            ("[foo.]:25", "foo"),
            ("[::1]:25", "::1"),
            ("::1:25", "::1"),
        ];
        for (input, expected) in cases {
            assert_eq!(extract_hostname(input), expected, "input {input:?}");
        }
    }
}