rfc5321/
client.rs

1#![allow(clippy::result_large_err)]
2use crate::client_types::*;
3use crate::{
4    AsyncReadAndWrite, BoxedAsyncReadAndWrite, Command, Domain, EsmtpParameter, ForwardPath,
5    ReversePath,
6};
7use hickory_proto::rr::rdata::tlsa::{CertUsage, Matching, Selector};
8use hickory_proto::rr::rdata::TLSA;
9use memchr::memmem::Finder;
10use openssl::pkey::PKey;
11use openssl::ssl::{DaneMatchType, DaneSelector, DaneUsage, SslOptions};
12use openssl::x509::{X509Ref, X509};
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15use std::net::IpAddr;
16use std::str::FromStr;
17use std::sync::{Arc, LazyLock};
18use std::time::Duration;
19use thiserror::Error;
20use tokio::io::{AsyncReadExt, AsyncWriteExt};
21use tokio::net::{TcpStream, ToSocketAddrs};
22use tokio::time::timeout;
23use tokio_rustls::rustls::pki_types::ServerName;
24use tracing::Level;
25
26pub use crate::tls::TlsOptions;
27pub use {openssl, tokio_rustls};
28
/// Maximum number of bytes buffered while waiting for a response line's CRLF
/// before the line is treated as too long; also the size of each socket read
/// performed by `SmtpClient::read_line`.
const MAX_LINE_LEN: usize = 4096;
30
/// Errors produced by [`SmtpClient`] operations.
#[derive(Error, Debug, Clone)]
pub enum ClientError {
    /// A response line was not valid UTF-8.
    #[error("response is not UTF8")]
    Utf8(#[from] std::string::FromUtf8Error),
    /// A response line could not be parsed, or was rejected by
    /// `ResponseBuilder::add_line` while assembling a multi-line response.
    #[error("Malformed Response: {0}")]
    MalformedResponseLine(String),
    /// More than `MAX_LINE_LEN` bytes arrived without a CRLF line terminator.
    #[error("Response line is too long")]
    ResponseTooLong,
    /// The operation needed a socket, but the client no longer holds one.
    #[error("Not connected")]
    NotConnected,
    /// The server replied with a code other than the one the command expects.
    #[error("Command rejected {0:?}")]
    Rejected(Response),
    /// A batch of commands was rejected; presumably holds the responses
    /// received. NOTE(review): not constructed in the visible part of this file.
    #[error("Commands rejected {0:?}")]
    RejectedBatch(Vec<Response>),
    /// The peer hostname could not be used as a TLS server name.
    #[error("STARTTLS: {0} is not a valid DNS name")]
    InvalidDnsName(String),
    /// Building the TLS connector failed. Despite the display text, this is
    /// returned for any connector construction error, not only client
    /// certificate problems.
    #[error("Invalid client certificate configured: {error:?}")]
    FailedToBuildConnector { error: String },
    /// No response data arrived within `duration`.
    #[error("Timed Out waiting {duration:?} for response to {command:?}")]
    TimeOutResponse {
        /// The command being answered, if any.
        command: Option<Command>,
        duration: Duration,
    },
    /// Writing `commands` did not complete within `duration`.
    #[error("Timed Out writing {duration:?} {commands:?}")]
    TimeOutRequest {
        commands: Vec<Command>,
        duration: Duration,
    },
    /// Reading a response failed: an I/O error, the peer closed the
    /// connection, or the socket had already been dropped.
    #[error("Error {error} reading response to {command:?}")]
    ReadError {
        /// The command being answered, if any.
        command: Option<Command>,
        error: String,
        /// Buffered but incomplete response data, lossily decoded.
        partial: String,
    },
    /// Flushing buffered outgoing data failed before reading a response.
    #[error("Error {error} flushing send buffer")]
    FlushError { error: String },
    /// Writing failed (or the socket had already been dropped).
    /// `commands` is empty for message payload writes.
    #[error("Error {error} writing {commands:?}")]
    WriteError {
        commands: Vec<Command>,
        error: String,
    },
    /// Writing message payload data did not complete in time.
    #[error("Timed Out sending message payload data")]
    TimeOutData,
    /// An OpenSSL error, e.g. while preparing the STARTTLS session.
    #[error("SSL Error: {0}")]
    SslErrorStack(#[from] openssl::error::ErrorStack),
    /// None of the supplied DANE TLSA records could be used.
    /// NOTE(review): not constructed in the visible part of this file.
    #[error("No usable DANE TLSA records for {hostname}: {tlsa:?}")]
    NoUsableDaneTlsa { hostname: String, tlsa: Vec<TLSA> },
}
79
80impl ClientError {
81    /// Returns the command(s) string suitable for passing into a Response
82    pub fn command(&self) -> Option<String> {
83        match self {
84            Self::TimeOutResponse {
85                command: Some(command),
86                ..
87            }
88            | Self::ReadError {
89                command: Some(command),
90                ..
91            } => Some(command.encode()),
92            Self::TimeOutRequest { commands, .. } | Self::WriteError { commands, .. }
93                if !commands.is_empty() =>
94            {
95                let commands: Vec<String> = commands.into_iter().map(|cmd| cmd.encode()).collect();
96                Some(commands.join(""))
97            }
98            _ => None,
99        }
100    }
101
102    /// If the error contents were likely caused by something
103    /// about the mostly recently attempted message, rather than
104    /// a transport issue, or a carry-over from a prior message
105    /// (eg: previous message was rejected and destination chose
106    /// to drop the connection, which we detect later in RSET
107    /// on the next message), then we return true.
108    /// The expectation is that the caller will transiently
109    /// fail the message for later retry.
110    /// If we return false then the caller might decide to
111    /// try that same message again more immediately on
112    /// a separate connection
113    pub fn was_due_to_message(&self) -> bool {
114        match self {
115            Self::Utf8(_)
116            | Self::MalformedResponseLine(_)
117            | Self::ResponseTooLong
118            | Self::NotConnected
119            | Self::InvalidDnsName(_)
120            | Self::FailedToBuildConnector { .. }
121            | Self::TimeOutResponse { .. }
122            | Self::TimeOutRequest { .. }
123            | Self::ReadError { .. }
124            | Self::FlushError { .. }
125            | Self::WriteError { .. }
126            | Self::TimeOutData
127            | Self::SslErrorStack(_)
128            | Self::NoUsableDaneTlsa { .. } => false,
129            Self::Rejected(response) => response.was_due_to_message(),
130            Self::RejectedBatch(responses) => match responses.len() {
131                1 => responses[0].was_due_to_message(),
132                _ => false,
133            },
134        }
135    }
136}
137
/// One capability line advertised in an EHLO/LHLO response.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EsmtpCapability {
    /// The keyword exactly as sent by the server (case preserved).
    pub name: String,
    /// The remainder of the line after the first space, if any.
    pub param: Option<String>,
}
143
/// Events reported to an [`SmtpClientTracer`].
#[derive(Clone, Debug)]
pub enum SmtpClientTraceEvent {
    /// The client dropped its socket after an I/O error, a timeout, or the
    /// peer closing the connection.
    Closed,
    /// Raw bytes read from the peer.
    Read(Vec<u8>),
    /// Data written to the peer (binary payloads are lossily decoded).
    Write(String),
    /// A diagnostic message about the session.
    Diagnostic {
        level: tracing::Level,
        message: String,
    },
}
154
/// A trace event whose construction is deferred until requested.
///
/// Passed to [`SmtpClientTracer::lazy_trace`] so that the (potentially
/// copying) conversion into a [`SmtpClientTraceEvent`] only happens if the
/// tracer actually calls [`DeferredTracer::trace`].
pub trait DeferredTracer {
    /// Build the trace event.
    fn trace(&self) -> SmtpClientTraceEvent;
}
158
/// Receives protocol-level trace events from an [`SmtpClient`].
pub trait SmtpClientTracer: std::fmt::Debug {
    /// Record an already-constructed event.
    fn trace_event(&self, event: SmtpClientTraceEvent);
    /// Record an event that is built on demand via [`DeferredTracer::trace`];
    /// implementations may skip building it when it would be discarded.
    fn lazy_trace(&self, deferred: &dyn DeferredTracer);
}
163
164// helper to avoid making a second copy of every write buffer
165struct WriteTracer<'a> {
166    data: &'a str,
167}
168impl DeferredTracer for WriteTracer<'_> {
169    fn trace(&self) -> SmtpClientTraceEvent {
170        SmtpClientTraceEvent::Write(self.data.to_string())
171    }
172}
173impl<'a> WriteTracer<'a> {
174    fn trace(tracer: &Arc<dyn SmtpClientTracer + Send + Sync>, data: &'a str) {
175        tracer.lazy_trace(&Self { data });
176    }
177}
178
179struct BinWriteTracer<'a> {
180    data: &'a [u8],
181}
182impl DeferredTracer for BinWriteTracer<'_> {
183    fn trace(&self) -> SmtpClientTraceEvent {
184        let data = String::from_utf8_lossy(self.data).to_string();
185        SmtpClientTraceEvent::Write(data)
186    }
187}
188impl<'a> BinWriteTracer<'a> {
189    fn trace(tracer: &Arc<dyn SmtpClientTracer + Send + Sync>, data: &'a [u8]) {
190        tracer.lazy_trace(&Self { data });
191    }
192}
193
194// A little bit of gymnastics to avoid making a second
195// copy of every read buffer
196struct ReadTracer<'a> {
197    data: &'a [u8],
198}
199impl DeferredTracer for ReadTracer<'_> {
200    fn trace(&self) -> SmtpClientTraceEvent {
201        SmtpClientTraceEvent::Read(self.data.to_vec())
202    }
203}
204
/// An SMTP/LMTP client session over a single connection.
#[derive(Debug)]
pub struct SmtpClient {
    /// The connection. Set to `None` once an I/O error, timeout or peer close
    /// is observed so that the socket is never reused.
    socket: Option<BoxedAsyncReadAndWrite>,
    /// Peer hostname with any `:port`, `[...]` brackets and trailing dot
    /// removed; used in trace logging and as the TLS server name.
    hostname: String,
    /// Capabilities from the most recent EHLO/LHLO, keyed by upper-cased keyword.
    capabilities: HashMap<String, EsmtpCapability>,
    /// Bytes received from the peer but not yet consumed as a complete line.
    read_buffer: Vec<u8>,
    timeouts: SmtpClientTimeouts,
    /// Optional protocol tracer.
    tracer: Option<Arc<dyn SmtpClientTracer + Send + Sync>>,
    // NOTE(review): only initialized in the visible code; presumably tracks
    // whether RSET must be sent before the next transaction. Confirm.
    use_rset: bool,
    /// Set via `set_enable_rset`.
    enable_rset: bool,
    /// Allow RFC 2920 pipelining when the server also advertises PIPELINING.
    enable_pipelining: bool,
    /// When true, 8BITMIME and SMTPUTF8 checks are ignored.
    ignore_8bit_checks: bool,
}
218
/// Reduce a peer address string to the bare hostname or IP literal.
///
/// Accepts `host`, `host:port`, `[literal]`, `[literal]:port` and bare IPv6
/// literals such as `2001:db8::1`, and removes any trailing FQDN dot.
fn extract_hostname(hostname: &str) -> &str {
    let hostname = if let Some(bracketed) = hostname.strip_prefix('[') {
        // "[addr]" or "[addr]:port": take everything up to the closing bracket.
        // Splitting on the last ':' first would mangle IPv6 literals that have
        // no port, e.g. "[::1]".
        match bracketed.find(']') {
            Some(end) => &bracketed[..end],
            None => hostname,
        }
    } else if hostname.parse::<std::net::Ipv6Addr>().is_ok() {
        // An unbracketed IPv6 literal cannot carry a port; keep it intact.
        hostname
    } else {
        // Just the hostname, without any :port
        hostname
            .rsplit_once(':')
            .map_or(hostname, |(host, _port)| host)
    };

    // Remove any trailing FQDN dot
    hostname.strip_suffix('.').unwrap_or(hostname)
}
237
238impl SmtpClient {
239    pub async fn new<A: ToSocketAddrs + ToString + Clone>(
240        addr: A,
241        timeouts: SmtpClientTimeouts,
242    ) -> std::io::Result<Self> {
243        let stream = TcpStream::connect(addr.clone()).await?;
244        // No need for Nagle with SMTP request/response
245        stream.set_nodelay(true)?;
246        Ok(Self::with_stream(stream, addr.to_string(), timeouts))
247    }
248
249    pub fn with_stream<S: AsyncReadAndWrite + 'static, H: AsRef<str>>(
250        stream: S,
251        peer_hostname: H,
252        timeouts: SmtpClientTimeouts,
253    ) -> Self {
254        let hostname = extract_hostname(peer_hostname.as_ref()).to_string();
255
256        Self {
257            socket: Some(Box::new(stream)),
258            hostname,
259            capabilities: HashMap::new(),
260            read_buffer: Vec::with_capacity(1024),
261            timeouts,
262            tracer: None,
263            use_rset: false,
264            enable_rset: false,
265            enable_pipelining: false,
266            ignore_8bit_checks: false,
267        }
268    }
269
270    /// Setting this to true causes 8BITMIME and SMTPUTF8 checks to be ignored
271    pub fn set_ignore_8bit_checks(&mut self, enable: bool) {
272        self.ignore_8bit_checks = enable;
273    }
274
275    pub fn is_connected(&self) -> bool {
276        self.socket.is_some()
277    }
278
279    pub fn set_enable_rset(&mut self, enable: bool) {
280        self.enable_rset = enable;
281    }
282
283    pub fn set_enable_pipelining(&mut self, enable: bool) {
284        self.enable_pipelining = enable;
285    }
286
287    pub fn set_tracer(&mut self, tracer: Arc<dyn SmtpClientTracer + Send + Sync>) {
288        self.tracer.replace(tracer);
289    }
290
291    pub fn timeouts(&self) -> &SmtpClientTimeouts {
292        &self.timeouts
293    }
294
295    async fn read_line(
296        &mut self,
297        timeout_duration: Duration,
298        cmd: Option<&Command>,
299    ) -> Result<String, ClientError> {
300        let mut too_long = false;
301        loop {
302            let mut iter = self.read_buffer.iter().enumerate();
303            while let Some((i, &b)) = iter.next() {
304                if b != b'\r' {
305                    continue;
306                }
307                if let Some((_, b'\n')) = iter.next() {
308                    if too_long {
309                        self.read_buffer.drain(0..i + 2);
310
311                        if let Some(tracer) = &self.tracer {
312                            tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
313                                level: Level::ERROR,
314                                message: "Response too long".to_string(),
315                            });
316                        }
317
318                        return Err(ClientError::ResponseTooLong);
319                    }
320
321                    let line = String::from_utf8(self.read_buffer[0..i].to_vec());
322                    self.read_buffer.drain(0..i + 2);
323                    return Ok(line?);
324                }
325            }
326            if self.read_buffer.len() > MAX_LINE_LEN {
327                self.read_buffer.clear();
328                too_long = true;
329            }
330
331            // Didn't find a complete line, fill up the rest of the buffer
332            let mut data = [0u8; MAX_LINE_LEN];
333            let size = match self.socket.as_mut() {
334                Some(s) => match timeout(timeout_duration, s.read(&mut data)).await {
335                    Ok(Ok(size)) => size,
336                    Ok(Err(err)) => {
337                        self.socket.take();
338                        if let Some(tracer) = &self.tracer {
339                            tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
340                                level: Level::ERROR,
341                                message: format!("Error during read: {err:#}"),
342                            });
343                            tracer.trace_event(SmtpClientTraceEvent::Closed);
344                        }
345                        return Err(ClientError::ReadError {
346                            command: cmd.cloned(),
347                            error: format!("{err:#}"),
348                            partial: String::from_utf8_lossy(&self.read_buffer).to_string(),
349                        });
350                    }
351                    Err(_) => {
352                        self.socket.take();
353                        if let Some(tracer) = &self.tracer {
354                            tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
355                                level: Level::ERROR,
356                                message: format!("Read Timeout after {timeout_duration:?}"),
357                            });
358                            tracer.trace_event(SmtpClientTraceEvent::Closed);
359                        }
360                        return Err(ClientError::TimeOutResponse {
361                            command: cmd.cloned(),
362                            duration: timeout_duration,
363                        });
364                    }
365                },
366                None => {
367                    return Err(ClientError::ReadError {
368                        command: cmd.cloned(),
369                        error: "the socket was closed in response to an earlier issue".to_string(),
370                        partial: String::from_utf8_lossy(&self.read_buffer).to_string(),
371                    });
372                }
373            };
374            if size == 0 {
375                self.socket.take();
376                if let Some(tracer) = &self.tracer {
377                    tracer.trace_event(SmtpClientTraceEvent::Closed);
378                }
379                return Err(ClientError::ReadError {
380                    command: cmd.cloned(),
381                    error: "Connection closed by peer".to_string(),
382                    partial: String::from_utf8_lossy(&self.read_buffer).to_string(),
383                });
384            }
385            if let Some(tracer) = &self.tracer {
386                tracer.lazy_trace(&ReadTracer {
387                    data: &data[0..size],
388                });
389            }
390            self.read_buffer.extend_from_slice(&data[0..size]);
391        }
392    }
393
394    pub async fn read_response(
395        &mut self,
396        command: Option<&Command>,
397        timeout_duration: Duration,
398    ) -> Result<Response, ClientError> {
399        if let Some(sock) = self.socket.as_mut() {
400            if let Err(err) = sock.flush().await {
401                self.socket.take();
402                if let Some(tracer) = &self.tracer {
403                    tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
404                        level: Level::ERROR,
405                        message: format!("Error during flush: {err:#}"),
406                    });
407                    tracer.trace_event(SmtpClientTraceEvent::Closed);
408                }
409                return Err(ClientError::FlushError {
410                    error: format!("{err:#}"),
411                });
412            }
413        }
414
415        let mut line = self.read_line(timeout_duration, command).await?;
416        tracing::trace!("recv<-{}: {line}", self.hostname);
417        let mut parsed = parse_response_line(&line)?;
418        let mut response_builder = ResponseBuilder::new(&parsed);
419
420        let subsequent_line_timeout_duration = Duration::from_secs(60).min(timeout_duration);
421        while !parsed.is_final {
422            line = self
423                .read_line(subsequent_line_timeout_duration, command)
424                .await?;
425            parsed = parse_response_line(&line)?;
426            response_builder
427                .add_line(&parsed)
428                .map_err(ClientError::MalformedResponseLine)?;
429        }
430
431        let response = response_builder.build(command.map(|cmd| cmd.encode()));
432
433        tracing::trace!("{}: {response:?}", self.hostname);
434
435        Ok(response)
436    }
437
438    pub async fn send_command(&mut self, command: &Command) -> Result<Response, ClientError> {
439        self.write_command_request(command).await?;
440        self.read_response(Some(command), command.client_timeout(&self.timeouts))
441            .await
442    }
443
444    /// Wrapper around socket.write_all() that will emit trace diagnostics and synthesize
445    /// a Close event to the tracer if a timeout or IO error occurs.
446    /// If an error or timeout, occurs ensures that the socket will not be reused.
447    async fn write_all_with_timeout<F, G>(
448        &mut self,
449        timeout_duration: Duration,
450        bytes: &[u8],
451        make_timeout_err: F,
452        make_write_err: G,
453    ) -> Result<(), ClientError>
454    where
455        F: FnOnce() -> ClientError,
456        G: FnOnce(String) -> ClientError,
457    {
458        match self.socket.as_mut() {
459            Some(socket) => match timeout(timeout_duration, socket.write_all(bytes)).await {
460                Ok(Ok(response)) => Ok(response),
461                Ok(Err(err)) => {
462                    self.socket.take();
463                    if let Some(tracer) = &self.tracer {
464                        tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
465                            level: Level::ERROR,
466                            message: format!("Error during write: {err:#}"),
467                        });
468                        tracer.trace_event(SmtpClientTraceEvent::Closed);
469                    }
470                    Err(make_write_err(format!("{err:#}")))
471                }
472                Err(_) => {
473                    self.socket.take();
474                    if let Some(tracer) = &self.tracer {
475                        tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
476                            level: Level::ERROR,
477                            message: format!("Write Timeout after {timeout_duration:?}"),
478                        });
479                        tracer.trace_event(SmtpClientTraceEvent::Closed);
480                    }
481                    Err(make_timeout_err())
482                }
483            },
484            None => Err(make_write_err(
485                "the socket was closed in response to an earlier issue".to_string(),
486            )),
487        }
488    }
489
490    async fn write_pipeline_request(&mut self, commands: &[Command]) -> Result<(), ClientError> {
491        let total_timeout: Duration = commands
492            .iter()
493            .map(|cmd| cmd.client_timeout_request(&self.timeouts))
494            .sum();
495
496        let mut lines: Vec<String> = vec![];
497        let mut all = String::new();
498        for cmd in commands {
499            let line = cmd.encode();
500            all.push_str(&line);
501            lines.push(line);
502        }
503        tracing::trace!("send->{}: (PIPELINE) {all}", self.hostname);
504        if self.socket.is_some() {
505            if let Some(tracer) = &self.tracer {
506                // Send the lines individually to the tracer, so that we
507                // don't break --terse mode
508                for line in lines {
509                    WriteTracer::trace(tracer, &line);
510                }
511            }
512        }
513        self.write_all_with_timeout(
514            total_timeout,
515            all.as_bytes(),
516            || ClientError::TimeOutRequest {
517                duration: total_timeout,
518                commands: commands.to_vec(),
519            },
520            |error| ClientError::WriteError {
521                error,
522                commands: commands.to_vec(),
523            },
524        )
525        .await
526    }
527
528    async fn write_command_request(&mut self, command: &Command) -> Result<(), ClientError> {
529        let line = command.encode();
530        tracing::trace!("send->{}: {line}", self.hostname);
531        if self.socket.is_some() {
532            if let Some(tracer) = &self.tracer {
533                WriteTracer::trace(tracer, &line);
534            }
535        }
536
537        let timeout_duration = command.client_timeout_request(&self.timeouts);
538        self.write_all_with_timeout(
539            timeout_duration,
540            line.as_bytes(),
541            || ClientError::TimeOutRequest {
542                duration: timeout_duration,
543                commands: vec![command.clone()],
544            },
545            |error| ClientError::WriteError {
546                error,
547                commands: vec![command.clone()],
548            },
549        )
550        .await
551    }
552
553    async fn write_data_with_timeout(&mut self, data: &[u8]) -> Result<(), ClientError> {
554        if self.socket.is_some() {
555            if let Some(tracer) = &self.tracer {
556                BinWriteTracer::trace(tracer, data);
557            }
558        }
559        let timeout_duration = Command::Data.client_timeout_request(&self.timeouts);
560        self.write_all_with_timeout(
561            timeout_duration,
562            data,
563            || ClientError::TimeOutData,
564            |error| ClientError::WriteError {
565                error,
566                commands: vec![],
567            },
568        )
569        .await
570    }
571
572    /// Issue a series of commands, and return the responses to
573    /// those commands.
574    ///
575    /// If the server advertised the RFC 2920 PIPELINING extension,
576    /// the commands are written one after the other before waiting
577    /// to read any data, resulting in lower overall latency due
578    /// to round-trip-times.
579    ///
580    /// If PIPELINING is not available, each command is written
581    /// and the response read before attempting to write the next
582    /// command.
583    ///
584    /// The number of returned responses may be smaller than the
585    /// number of requested commands if there is an issue with
586    /// the network connection.
587    pub async fn pipeline_commands(
588        &mut self,
589        commands: Vec<Command>,
590    ) -> Vec<Result<Response, ClientError>> {
591        let mut results: Vec<Result<Response, ClientError>> = vec![];
592
593        let pipeline = self.enable_pipelining && self.capabilities.contains_key("PIPELINING");
594        if pipeline {
595            if let Err(err) = self.write_pipeline_request(&commands).await {
596                let err: ClientError = err;
597                results.push(Err(err.clone()));
598                while results.len() < commands.len() {
599                    // Synthesize failures for the remaining commands
600                    results.push(Err(err.clone()));
601                }
602                return results;
603            }
604
605            // Now read the responses effectively in a batch
606            for cmd in &commands {
607                results.push(
608                    self.read_response(Some(cmd), cmd.client_timeout(&self.timeouts))
609                        .await,
610                );
611            }
612            return results;
613        }
614
615        for cmd in &commands {
616            if let Err(err) = self.write_command_request(cmd).await {
617                let err: ClientError = err;
618                results.push(Err(err.clone()));
619                while results.len() < commands.len() {
620                    // Synthesize failures for the remaining commands
621                    results.push(Err(err.clone()));
622                }
623                return results;
624            }
625            // Immediately request the response if the server
626            // doesn't support pipelining
627            results.push(
628                self.read_response(Some(cmd), cmd.client_timeout(&self.timeouts))
629                    .await,
630            );
631        }
632        results
633    }
634
635    pub async fn ehlo_lhlo(
636        &mut self,
637        ehlo_name: &str,
638        use_lmtp: bool,
639    ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
640        if use_lmtp {
641            self.lhlo(ehlo_name).await
642        } else {
643            self.ehlo(ehlo_name).await
644        }
645    }
646
647    pub async fn lhlo(
648        &mut self,
649        ehlo_name: &str,
650    ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
651        let response = self
652            .send_command(&Command::Lhlo(Domain::Name(ehlo_name.to_string())))
653            .await?;
654        self.ehlo_common(response)
655    }
656
657    pub async fn ehlo(
658        &mut self,
659        ehlo_name: &str,
660    ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
661        let response = self
662            .send_command(&Command::Ehlo(Domain::Name(ehlo_name.to_string())))
663            .await?;
664        self.ehlo_common(response)
665    }
666
667    fn ehlo_common(
668        &mut self,
669        response: Response,
670    ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
671        if response.code != 250 {
672            return Err(ClientError::Rejected(response));
673        }
674
675        let mut capabilities = HashMap::new();
676
677        for line in response.content.lines().skip(1) {
678            let mut fields = line.splitn(2, ' ');
679            if let Some(name) = fields.next() {
680                let param = fields.next().map(|s| s.to_string());
681                let cap = EsmtpCapability {
682                    name: name.to_string(),
683                    param,
684                };
685                capabilities.insert(name.to_ascii_uppercase(), cap);
686            }
687        }
688
689        self.capabilities = capabilities;
690        Ok(&self.capabilities)
691    }
692
693    pub async fn auth_plain(
694        &mut self,
695        username: &str,
696        password: Option<&str>,
697    ) -> Result<(), ClientError> {
698        // RFC 4616 says that the format is:
699        // [authzid] NUL authcid NUL passwd
700        let password = password.unwrap_or("");
701        let payload = format!("\x00{username}\x00{password}");
702        let payload = data_encoding::BASE64.encode(payload.as_bytes());
703
704        let response = self
705            .send_command(&Command::Auth {
706                sasl_mech: "PLAIN".to_string(),
707                initial_response: Some(payload),
708            })
709            .await?;
710
711        if response.code != 235 {
712            return Err(ClientError::Rejected(response));
713        }
714
715        Ok(())
716    }
717
718    /// Attempt TLS handshake.
719    /// Returns Err for IO errors.
720    /// On completion, return an option that will be:
721    /// * Some(handshake_error) - if the handshake failed
722    /// * None - if the handshake succeeded
723    pub async fn starttls(&mut self, options: TlsOptions) -> Result<TlsStatus, ClientError> {
724        let resp = self.send_command(&Command::StartTls).await?;
725        if resp.code != 220 {
726            return Err(ClientError::Rejected(resp));
727        }
728
729        let mut handshake_error = None;
730        let mut tls_info = TlsInformation::default();
731
732        let stream: BoxedAsyncReadAndWrite = if options.prefer_openssl
733            || !options.dane_tlsa.is_empty()
734        {
735            let connector = options
736                .build_openssl_connector(&self.hostname)
737                .map_err(|error| ClientError::FailedToBuildConnector {
738                    error: error.to_string(),
739                })?;
740            let ssl = connector.into_ssl(self.hostname.as_str())?;
741
742            let (stream, dup_stream) = match self.socket.take() {
743                Some(s) => {
744                    let d = s.try_dup();
745                    (s, d)
746                }
747                None => return Err(ClientError::NotConnected),
748            };
749
750            let mut ssl_stream = tokio_openssl::SslStream::new(ssl, stream)?;
751
752            if let Err(err) = std::pin::Pin::new(&mut ssl_stream).connect().await {
753                handshake_error.replace(format!("{err:#}"));
754            }
755
756            tls_info.provider_name = "openssl".to_string();
757            tls_info.cipher = match ssl_stream.ssl().current_cipher() {
758                Some(cipher) => cipher.standard_name().unwrap_or(cipher.name()).to_string(),
759                None => String::new(),
760            };
761            tls_info.protocol_version = ssl_stream.ssl().version_str().to_string();
762
763            if let Some(cert) = ssl_stream.ssl().peer_certificate() {
764                tls_info.subject_name = subject_name(&cert);
765            }
766            if let Ok(authority) = ssl_stream.ssl().dane_authority() {
767                if let Some(cert) = &authority.cert {
768                    tls_info.subject_name = subject_name(cert);
769                }
770            }
771
772            match (&handshake_error, dup_stream) {
773                (Some(_), Some(dup_stream)) if !ssl_stream.ssl().is_init_finished() => {
774                    // Try falling back to clear text on the duplicate stream.
775                    // This is imperfect: in a failed validation scenario we will
776                    // end up trying to read binary data as a string and get a UTF-8
777                    // error if the peer thinks the session is encrypted.
778                    drop(ssl_stream);
779                    Box::new(dup_stream)
780                }
781                _ => Box::new(ssl_stream),
782            }
783        } else {
784            tls_info.provider_name = "rustls".to_string();
785            let connector = options.build_tls_connector().await.map_err(|error| {
786                ClientError::FailedToBuildConnector {
787                    error: error.to_string(),
788                }
789            })?;
790            let server_name = match IpAddr::from_str(self.hostname.as_str()) {
791                Ok(ip) => ServerName::IpAddress(ip.into()),
792                Err(_) => ServerName::try_from(self.hostname.clone())
793                    .map_err(|_| ClientError::InvalidDnsName(self.hostname.clone()))?,
794            };
795
796            match connector
797                .connect(
798                    server_name,
799                    match self.socket.take() {
800                        Some(s) => s,
801                        None => return Err(ClientError::NotConnected),
802                    },
803                )
804                .into_fallible()
805                .await
806            {
807                Ok(stream) => {
808                    let (_, conn) = stream.get_ref();
809                    tls_info.cipher = match conn.negotiated_cipher_suite() {
810                        Some(suite) => suite.suite().as_str().unwrap_or("UNKNOWN").to_string(),
811                        None => String::new(),
812                    };
813                    tls_info.protocol_version = match conn.protocol_version() {
814                        Some(version) => version.as_str().unwrap_or("UNKNOWN").to_string(),
815                        None => String::new(),
816                    };
817
818                    if let Some(certs) = conn.peer_certificates() {
819                        let peer_cert = &certs[0];
820                        if let Ok(cert) = X509::from_der(peer_cert.as_ref()) {
821                            tls_info.subject_name = subject_name(&cert);
822                        }
823                    }
824
825                    Box::new(stream)
826                }
827                Err((err, stream)) => {
828                    handshake_error.replace(format!("{err:#}"));
829                    stream
830                }
831            }
832        };
833
834        if let Some(tracer) = &self.tracer {
835            tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
836                level: Level::INFO,
837                message: match &handshake_error {
838                    Some(error) => format!("STARTTLS handshake failed: {error:?}"),
839                    None => format!("STARTTLS handshake -> {tls_info:?}"),
840                },
841            });
842        }
843
844        self.socket.replace(stream);
845        Ok(match handshake_error {
846            Some(error) => TlsStatus::FailedHandshake(error),
847            None => TlsStatus::Info(tls_info),
848        })
849    }
850
851    pub async fn send_mail<B: AsRef<[u8]>, SENDER: Into<ReversePath>, RECIP: Into<ForwardPath>>(
852        &mut self,
853        sender: SENDER,
854        recipient: RECIP,
855        data: B,
856    ) -> Result<Response, ClientError> {
857        let recipient: ForwardPath = recipient.into();
858        let status = self
859            .send_mail_multi_recip(sender, vec![recipient], data)
860            .await?;
861        Ok(status.response)
862    }
863
864    pub async fn send_mail_multi_recip<B: AsRef<[u8]>, SENDER: Into<ReversePath>>(
865        &mut self,
866        sender: SENDER,
867        recipient_list: Vec<ForwardPath>,
868        data: B,
869    ) -> Result<BatchSendSuccess, ClientError> {
870        let sender = sender.into();
871
872        let data: &[u8] = data.as_ref();
873        let stuffed;
874
875        let data = match apply_dot_stuffing(data) {
876            Some(d) => {
877                stuffed = d;
878                &stuffed
879            }
880            None => data,
881        };
882
883        let data_is_8bit = data.iter().any(|&b| b >= 0x80);
884        let envelope_is_8bit =
885            !sender.is_ascii() || recipient_list.iter().any(|recipient| !recipient.is_ascii());
886
887        let mut mail_from_params = vec![];
888        if data_is_8bit {
889            if self.capabilities.contains_key("8BITMIME") {
890                mail_from_params.push(EsmtpParameter {
891                    name: "BODY".to_string(),
892                    value: Some("8BITMIME".to_string()),
893                });
894            } else if !self.ignore_8bit_checks {
895                return Err(ClientError::Rejected(Response {
896                    code: 554,
897                    command: None,
898                    enhanced_code: Some(EnhancedStatusCode {
899                        class: 5,
900                        subject: 6,
901                        detail: 3,
902                    }),
903                    content: "KumoMTA internal: DATA is 8bit, destination does \
904                        not support 8BITMIME. Conversion via msg:check_fix_conformance \
905                        during reception is required"
906                        .to_string(),
907                }));
908            }
909        }
910
911        if envelope_is_8bit {
912            if self.capabilities.contains_key("SMTPUTF8") {
913                mail_from_params.push(EsmtpParameter {
914                    name: "SMTPUTF8".to_string(),
915                    value: None,
916                });
917            } else if !self.ignore_8bit_checks {
918                return Err(ClientError::Rejected(Response {
919                    code: 554,
920                    command: None,
921                    enhanced_code: Some(EnhancedStatusCode {
922                        class: 5,
923                        subject: 6,
924                        detail: 7,
925                    }),
926                    content: "KumoMTA internal: envelope is 8bit, destination does \
927                        not support SMTPUTF8."
928                        .to_string(),
929                }));
930            }
931        }
932
933        let mut commands = vec![];
934
935        // We want to avoid using RSET for the first message we send on
936        // a given connection, because postfix can run in a mode where
937        // it will not tolerate RSET because it considers it to be a "junk"
938        // command, and rejecting junk commands will cut down on its load
939        // when it is under stress; it is used as a load shedding approach.
940        // If we always RSET then we will never deliver to a site that is
941        // configured that way. If we take care to RSET only for subsequent
942        // sends, then we should get one message per connection through
943        // without being unfairly penalized for defensively RSETing.
944        let used_rset = self.use_rset;
945        if self.use_rset {
946            commands.push(Command::Rset);
947        }
948        commands.push(Command::MailFrom {
949            address: sender,
950            parameters: mail_from_params,
951        });
952
953        for recipient in &recipient_list {
954            commands.push(Command::RcptTo {
955                address: recipient.clone(),
956                parameters: vec![],
957            });
958        }
959        commands.push(Command::Data);
960
961        // Assume that something might break below: if it does, we want
962        // to ensure that we RSET the connection on the next go around.
963        self.use_rset = true;
964
965        let mut responses = self.pipeline_commands(commands).await;
966
967        // This is a little awkward. We want to handle the RFC 2090 3.1 case
968        // below, which requires deferring checking the actual response codes
969        // until later, but we must also handle the case where we had a hard
970        // transport error partway through pipelining.
971        // So we set a flag for that case and will then "eagerly", wrt. the
972        // RFC 2090 3.1 logic, evaluate the SMTP response codes, so that we
973        // can propagate the correct error disposition up to the caller.
974        let is_err = responses.iter().any(|r| r.is_err());
975
976        if used_rset {
977            let rset_resp = responses.remove(0)?;
978            if rset_resp.code != 250 {
979                return Err(ClientError::Rejected(rset_resp));
980            }
981        }
982
983        let mail_resp = responses.remove(0)?;
984        if is_err && mail_resp.code != 250 {
985            return Err(ClientError::Rejected(mail_resp));
986        }
987
988        let mut rcpt_responses = vec![];
989        for _ in &recipient_list {
990            rcpt_responses.push(responses.remove(0)?);
991        }
992
993        if is_err && rcpt_responses.iter().all(|resp| resp.code != 250) {
994            return Err(ClientError::RejectedBatch(rcpt_responses));
995        }
996
997        let data_resp = responses.remove(0)?;
998        if is_err && data_resp.code != 354 {
999            return Err(ClientError::Rejected(data_resp));
1000        }
1001
1002        if data_resp.code == 354
1003            && (mail_resp.code != 250 || rcpt_responses.iter().all(|resp| resp.code != 250))
1004        {
1005            // RFC 2920 3.1:
1006            // the client cannot assume that the DATA command will be rejected
1007            // just because none of the RCPT TO commands worked.  If the DATA
1008            // command was properly rejected the client SMTP can just issue
1009            // RSET, but if the DATA command was accepted the client SMTP
1010            // should send a single dot.
1011
1012            // Send dummy data
1013            self.write_data_with_timeout(b".\r\n").await?;
1014            let data_dot = Command::DataDot;
1015            // wait for its response
1016            let _ = self
1017                .read_response(Some(&data_dot), data_dot.client_timeout(&self.timeouts))
1018                .await?;
1019
1020            // Continue below: we will match one of the failure cases and
1021            // return a ClientError::Rejected from one of the earlier
1022            // commands
1023        }
1024
1025        if mail_resp.code != 250 {
1026            return Err(ClientError::Rejected(mail_resp));
1027        }
1028        if rcpt_responses.iter().all(|resp| resp.code != 250) {
1029            if rcpt_responses.len() == 1 {
1030                return Err(ClientError::Rejected(
1031                    rcpt_responses.pop().expect("have at least one"),
1032                ));
1033            }
1034            return Err(ClientError::RejectedBatch(rcpt_responses));
1035        }
1036        if data_resp.code != 354 {
1037            return Err(ClientError::Rejected(data_resp));
1038        }
1039
1040        let needs_newline = data.last().map(|&b| b != b'\n').unwrap_or(true);
1041
1042        tracing::trace!("message data is {} bytes", data.len());
1043
1044        self.write_data_with_timeout(data).await?;
1045
1046        let marker = if needs_newline { "\r\n.\r\n" } else { ".\r\n" };
1047
1048        tracing::trace!("send->{}: {}", self.hostname, marker.escape_debug());
1049
1050        self.write_data_with_timeout(marker.as_bytes()).await?;
1051
1052        let data_dot = Command::DataDot;
1053        let resp = self
1054            .read_response(Some(&data_dot), data_dot.client_timeout(&self.timeouts))
1055            .await?;
1056        if resp.code != 250 {
1057            return Err(ClientError::Rejected(resp));
1058        }
1059
1060        // If everything went well, respect the user preference for speculatively
1061        // issuing an RSET next time around
1062        self.use_rset = self.enable_rset;
1063
1064        Ok(BatchSendSuccess {
1065            response: resp,
1066            rcpt_responses,
1067        })
1068    }
1069}
1070
/// Successful outcome of [`SmtpClient::send_mail_multi_recip`].
#[derive(Debug)]
pub struct BatchSendSuccess {
    /// Server response to the end-of-data marker (the final `.`).
    pub response: Response,
    /// One response per `RCPT TO`, in the order the recipients were given.
    /// Individual entries may be rejections; the send as a whole only
    /// fails when no recipient was accepted.
    pub rcpt_responses: Vec<Response>,
}
1076
/// Outcome of a STARTTLS handshake attempt.
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize, Clone)]
pub enum TlsStatus {
    /// The handshake failed; carries the formatted handshake error.
    FailedHandshake(String),
    /// The handshake succeeded; carries details of the negotiated session.
    Info(TlsInformation),
}
1082
/// Details recorded about a negotiated TLS session after STARTTLS.
#[derive(Debug, Default, PartialEq, Eq, Serialize, Deserialize, Clone)]
pub struct TlsInformation {
    /// Name of the negotiated cipher suite; `"UNKNOWN"` when the suite has
    /// no known name, empty when none was reported.
    pub cipher: String,
    /// Negotiated protocol version, with the same `"UNKNOWN"`/empty
    /// conventions as `cipher`.
    pub protocol_version: String,
    /// `short_name=value` pairs from the peer certificate's subject,
    /// as produced by `subject_name`.
    pub subject_name: Vec<String>,
    /// NOTE(review): not assigned in the handshake code shown here;
    /// presumably names the TLS implementation in use — confirm.
    pub provider_name: String,
}
1090
1091impl Drop for SmtpClient {
1092    fn drop(&mut self) {
1093        if let Some(tracer) = &self.tracer {
1094            if self.socket.is_some() {
1095                tracer.trace_event(SmtpClientTraceEvent::Closed);
1096            }
1097        }
1098    }
1099}
1100fn parse_response_line(line: &'_ str) -> Result<ResponseLine<'_>, ClientError> {
1101    if line.len() < 4 {
1102        return Err(ClientError::MalformedResponseLine(line.to_string()));
1103    }
1104
1105    match line.as_bytes()[3] {
1106        b' ' | b'-' => match line[0..3].parse::<u16>() {
1107            Ok(code) => Ok(ResponseLine {
1108                code,
1109                is_final: line.as_bytes()[3] == b' ',
1110                content: &line[4..],
1111            }),
1112            Err(_) => Err(ClientError::MalformedResponseLine(line.to_string())),
1113        },
1114        _ => Err(ClientError::MalformedResponseLine(line.to_string())),
1115    }
1116}
1117
1118impl TlsOptions {
1119    pub fn build_openssl_connector(
1120        &self,
1121        hostname: &str,
1122    ) -> Result<openssl::ssl::ConnectConfiguration, ClientError> {
1123        tracing::trace!("build_openssl_connector for {hostname}");
1124        let mut builder =
1125            openssl::ssl::SslConnector::builder(openssl::ssl::SslMethod::tls_client())?;
1126
1127        if let (Some(cert_data), Some(key_data)) =
1128            (&self.certificate_from_pem, &self.private_key_from_pem)
1129        {
1130            let cert = X509::from_pem(cert_data)?;
1131            builder.set_certificate(&cert)?;
1132
1133            let key = PKey::private_key_from_pem(key_data)?;
1134            builder.set_private_key(&key)?;
1135
1136            builder.check_private_key()?;
1137        }
1138
1139        if let Some(list) = &self.openssl_cipher_list {
1140            builder.set_cipher_list(list)?;
1141        }
1142
1143        if let Some(suites) = &self.openssl_cipher_suites {
1144            builder.set_ciphersuites(suites)?;
1145        }
1146
1147        if let Some(options) = &self.openssl_options {
1148            builder.clear_options(SslOptions::all());
1149            builder.set_options(*options);
1150        }
1151
1152        if self.insecure {
1153            builder.set_verify(openssl::ssl::SslVerifyMode::NONE);
1154        }
1155
1156        if !self.dane_tlsa.is_empty() {
1157            builder.dane_enable()?;
1158            builder.set_no_dane_ee_namechecks();
1159        }
1160
1161        let connector = builder.build();
1162
1163        let mut config = connector.configure()?;
1164
1165        if !self.dane_tlsa.is_empty() {
1166            config.dane_enable(hostname)?;
1167            let mut any_usable = false;
1168            for tlsa in &self.dane_tlsa {
1169                let usable = config.dane_tlsa_add(
1170                    match tlsa.cert_usage() {
1171                        CertUsage::PkixTa => DaneUsage::PKIX_TA,
1172                        CertUsage::PkixEe => DaneUsage::PKIX_EE,
1173                        CertUsage::DaneTa => DaneUsage::DANE_TA,
1174                        CertUsage::DaneEe => DaneUsage::DANE_EE,
1175                        CertUsage::Unassigned(n) => DaneUsage::from_raw(n),
1176                        CertUsage::Private => DaneUsage::PRIV_CERT,
1177                    },
1178                    match tlsa.selector() {
1179                        Selector::Full => DaneSelector::CERT,
1180                        Selector::Spki => DaneSelector::SPKI,
1181                        Selector::Unassigned(n) => DaneSelector::from_raw(n),
1182                        Selector::Private => DaneSelector::PRIV_SEL,
1183                    },
1184                    match tlsa.matching() {
1185                        Matching::Raw => DaneMatchType::FULL,
1186                        Matching::Sha256 => DaneMatchType::SHA2_256,
1187                        Matching::Sha512 => DaneMatchType::SHA2_512,
1188                        Matching::Unassigned(n) => DaneMatchType::from_raw(n),
1189                        Matching::Private => DaneMatchType::PRIV_MATCH,
1190                    },
1191                    tlsa.cert_data(),
1192                )?;
1193
1194                tracing::trace!("build_dane_connector usable={usable} {tlsa:?}");
1195                if usable {
1196                    any_usable = true;
1197                }
1198            }
1199
1200            if !any_usable {
1201                return Err(ClientError::NoUsableDaneTlsa {
1202                    hostname: hostname.to_string(),
1203                    tlsa: self.dane_tlsa.clone(),
1204                });
1205            }
1206        }
1207
1208        Ok(config)
1209    }
1210}
1211
1212fn apply_dot_stuffing(data: &[u8]) -> Option<Vec<u8>> {
1213    static LFDOT: LazyLock<Finder> = LazyLock::new(|| memchr::memmem::Finder::new("\n."));
1214
1215    if !data.starts_with(b".") && LFDOT.find(data).is_none() {
1216        return None;
1217    }
1218
1219    let mut stuffed = vec![];
1220    if data.starts_with(b".") {
1221        stuffed.push(b'.');
1222    }
1223    let mut last_idx = 0;
1224    for i in LFDOT.find_iter(data) {
1225        stuffed.extend_from_slice(&data[last_idx..=i]);
1226        stuffed.push(b'.');
1227        last_idx = i + 1;
1228    }
1229    stuffed.extend_from_slice(&data[last_idx..]);
1230    Some(stuffed)
1231}
1232
1233/// Extracts the object=name pairs of the subject name from a cert.
1234/// eg:
1235/// ```no_run
1236/// ["C=US", "ST=CA", "L=SanFrancisco", "O=Fort-Funston", "OU=MyOrganizationalUnit",
1237/// "CN=do.havedane.net", "name=EasyRSA", "emailAddress=me@myhost.mydomain"]
1238/// ```
1239pub fn subject_name(cert: &X509Ref) -> Vec<String> {
1240    let mut subject_name = vec![];
1241    for entry in cert.subject_name().entries() {
1242        if let Ok(obj) = entry.object().nid().short_name() {
1243            if let Ok(name) = entry.data().as_utf8() {
1244                subject_name.push(format!("{obj}={name}"));
1245            }
1246        }
1247    }
1248    subject_name
1249}
1250
/// Unit tests for the pure helpers in this module: dot-stuffing, reply
/// line parsing, multi-line reply assembly and hostname extraction.
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_stuffing() {
        assert_eq!(apply_dot_stuffing(b"foo"), None);
        assert_eq!(apply_dot_stuffing(b""), None);
        assert_eq!(apply_dot_stuffing(b".foo").unwrap(), b"..foo");
        assert_eq!(apply_dot_stuffing(b"foo\n.bar").unwrap(), b"foo\n..bar");
        assert_eq!(
            apply_dot_stuffing(b"foo\n.bar\n..baz\n").unwrap(),
            b"foo\n..bar\n...baz\n"
        );
        // CRLF line endings and a line holding only a dot.
        assert_eq!(
            apply_dot_stuffing(b"foo\r\n.\r\nbar").unwrap(),
            b"foo\r\n..\r\nbar"
        );
        // A dot at the very end, and dots at both start and end.
        assert_eq!(apply_dot_stuffing(b"foo\n.").unwrap(), b"foo\n..");
        assert_eq!(apply_dot_stuffing(b".\n.").unwrap(), b"..\n..");
        // Dots that do not start a line are left alone.
        assert_eq!(apply_dot_stuffing(b"a.b\r\nc.d"), None);
    }

    #[test]
    fn response_line_parsing() {
        assert_eq!(
            parse_response_line("220 woot").unwrap(),
            ResponseLine {
                code: 220,
                is_final: true,
                content: "woot"
            }
        );
        assert_eq!(
            parse_response_line("220-woot").unwrap(),
            ResponseLine {
                code: 220,
                is_final: false,
                content: "woot"
            }
        );

        assert!(matches!(
            parse_response_line("220_woot"),
            Err(ClientError::MalformedResponseLine(_))
        ));
        assert!(matches!(
            parse_response_line("not really"),
            Err(ClientError::MalformedResponseLine(_))
        ));
    }

    /// Feeds `lines` through `parse_response_line` and `ResponseBuilder`,
    /// asserting that the last line is marked final.
    fn parse_multi_line(lines: &[&str]) -> Result<Response, ClientError> {
        let mut parsed = parse_response_line(lines[0])?;
        let mut b = ResponseBuilder::new(&parsed);
        for line in &lines[1..] {
            parsed = parse_response_line(line)?;
            b.add_line(&parsed)
                .map_err(ClientError::MalformedResponseLine)?;
        }
        assert!(parsed.is_final);
        Ok(b.build(None))
    }

    #[test]
    fn multi_line_response() {
        assert_eq!(
            parse_multi_line(&["220-woot", "220-more", "220 done",]).unwrap(),
            Response {
                code: 220,
                enhanced_code: None,
                content: "woot\nmore\ndone".to_string(),
                command: None
            }
        );

        let res = parse_multi_line(&["220-woot", "221-more", "220 done"]).unwrap_err();
        assert!(
            matches!(
                res,
                ClientError::MalformedResponseLine(ref err) if err == "221-more"
            ),
            "got error {res:?}"
        );

        let res = parse_multi_line(&["220-woot", "220-more", "221 done"]).unwrap_err();
        assert!(
            matches!(
                res,
                ClientError::MalformedResponseLine(ref err) if err == "221 done"
            ),
            "got error {res:?}"
        );

        assert_eq!(
            parse_multi_line(&["220-4.1.0 woot", "220-more", "220 done",]).unwrap(),
            Response {
                code: 220,
                enhanced_code: Some(EnhancedStatusCode {
                    class: 4,
                    subject: 1,
                    detail: 0
                }),
                content: "woot\nmore\ndone".to_string(),
                command: None
            }
        );

        // Confirm that we strip the enhanced status code from each line
        assert_eq!(
            parse_multi_line(&["220-4.1.0 woot", "220-4.1.0 more", "220 done",]).unwrap(),
            Response {
                code: 220,
                enhanced_code: Some(EnhancedStatusCode {
                    class: 4,
                    subject: 1,
                    detail: 0
                }),
                content: "woot\nmore\ndone".to_string(),
                command: None
            }
        );

        // ... but only if the code matches that of the first line
        assert_eq!(
            parse_multi_line(&["220-4.1.0 woot", "220-4.1.0 more", "220 5.5.5 done",]).unwrap(),
            Response {
                code: 220,
                enhanced_code: Some(EnhancedStatusCode {
                    class: 4,
                    subject: 1,
                    detail: 0
                }),
                content: "woot\nmore\n5.5.5 done".to_string(),
                command: None
            }
        );
    }

    #[test]
    fn test_extract_hostname() {
        assert_eq!(extract_hostname("foo"), "foo");
        assert_eq!(extract_hostname("foo."), "foo");
        assert_eq!(extract_hostname("foo:25"), "foo");
        assert_eq!(extract_hostname("foo.:25"), "foo");
        assert_eq!(extract_hostname("[foo]:25"), "foo");
        assert_eq!(extract_hostname("[foo.]:25"), "foo");
        assert_eq!(extract_hostname("[::1]:25"), "::1");
        assert_eq!(extract_hostname("::1:25"), "::1");
    }
}