rfc5321/
client.rs

1#![allow(clippy::result_large_err)]
2use crate::client_types::*;
3use crate::{
4    AsyncReadAndWrite, BoxedAsyncReadAndWrite, Command, Domain, EsmtpParameter, ForwardPath,
5    ReversePath,
6};
7use hickory_proto::rr::rdata::tlsa::{CertUsage, Matching, Selector};
8use hickory_proto::rr::rdata::TLSA;
9use memchr::memmem::Finder;
10use openssl::ssl::{DaneMatchType, DaneSelector, DaneUsage, SslOptions};
11use openssl::x509::{X509Ref, X509};
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14use std::net::IpAddr;
15use std::str::FromStr;
16use std::sync::{Arc, LazyLock};
17use std::time::Duration;
18use thiserror::Error;
19use tokio::io::{AsyncReadExt, AsyncWriteExt};
20use tokio::net::{TcpStream, ToSocketAddrs};
21use tokio::time::timeout;
22use tokio_rustls::rustls::pki_types::ServerName;
23use tracing::Level;
24
25pub use crate::tls::TlsOptions;
26pub use {openssl, tokio_rustls};
27
/// Upper bound, in bytes, on a single response line buffered by
/// `read_line` before it reports `ClientError::ResponseTooLong`.
/// It also sizes the scratch buffer used for each socket read.
const MAX_LINE_LEN: usize = 4 * 1024;
29
/// Errors returned by [`SmtpClient`] operations.
///
/// `ClientError` is `Clone`; `pipeline_commands` relies on that to report a
/// single write failure for several commands.
#[derive(Error, Debug, Clone)]
pub enum ClientError {
    /// A response line was not valid UTF-8.
    #[error("response is not UTF8")]
    Utf8(#[from] std::string::FromUtf8Error),
    /// A response line was rejected as malformed while assembling a response.
    #[error("Malformed Response: {0}")]
    MalformedResponseLine(String),
    /// A response line grew beyond `MAX_LINE_LEN` bytes without a CRLF.
    #[error("Response line is too long")]
    ResponseTooLong,
    /// No socket was available (raised by `starttls`).
    #[error("Not connected")]
    NotConnected,
    /// The server answered with a status code the operation does not accept.
    #[error("Command rejected {0:?}")]
    Rejected(Response),
    /// The peer hostname could not be used as a rustls server name.
    #[error("STARTTLS: {0} is not a valid DNS name")]
    InvalidDnsName(String),
    /// No response line arrived within `duration`; `command` is the command
    /// whose reply was being awaited, if any. The socket is dropped.
    #[error("Timed Out waiting {duration:?} for response to {command:?}")]
    TimeOutResponse {
        command: Option<Command>,
        duration: Duration,
    },
    /// Writing `commands` did not finish within `duration`. The socket is dropped.
    #[error("Timed Out writing {duration:?} {commands:?}")]
    TimeOutRequest {
        commands: Vec<Command>,
        duration: Duration,
    },
    /// Reading a response failed: an IO error, the peer closed the
    /// connection, or the socket had already been dropped. `partial` holds
    /// any bytes buffered so far, lossily decoded as UTF-8.
    #[error("Error {error} reading response to {command:?}")]
    ReadError {
        command: Option<Command>,
        error: String,
        partial: String,
    },
    /// Flushing buffered writes before reading a response failed.
    #[error("Error {error} flushing send buffer")]
    FlushError { error: String },
    /// Writing failed, or the socket had already been dropped. `commands`
    /// is empty when the failed write was message payload data.
    #[error("Error {error} writing {commands:?}")]
    WriteError {
        commands: Vec<Command>,
        error: String,
    },
    /// Writing the message payload did not finish in time.
    #[error("Timed Out sending message payload data")]
    TimeOutData,
    /// An error reported by OpenSSL.
    #[error("SSL Error: {0}")]
    SslErrorStack(#[from] openssl::error::ErrorStack),
    /// NOTE(review): not constructed in the visible part of this file;
    /// presumably raised when none of the supplied DANE TLSA records for
    /// `hostname` can be used. Confirm at the construction site.
    #[error("No usable DANE TLSA records for {hostname}: {tlsa:?}")]
    NoUsableDaneTlsa { hostname: String, tlsa: Vec<TLSA> },
}
74
/// A single ESMTP extension advertised in an EHLO/LHLO response.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EsmtpCapability {
    /// The extension keyword as the server spelled it (case preserved).
    pub name: String,
    /// Whatever followed the keyword after its first space, if anything.
    pub param: Option<String>,
}
80
/// Wire-level events reported to an [`SmtpClientTracer`].
#[derive(Clone, Debug)]
pub enum SmtpClientTraceEvent {
    /// The client dropped its socket (IO error, timeout or peer close).
    Closed,
    /// Raw bytes exactly as they were read from the socket.
    Read(Vec<u8>),
    /// Data written to the socket; binary payloads are lossily decoded as UTF-8.
    Write(String),
    /// A free-form diagnostic about the session.
    Diagnostic {
        level: tracing::Level,
        message: String,
    },
}
91
/// A trace event whose construction is deferred, so that callers can hand
/// over borrowed buffers instead of copying them up front.
pub trait DeferredTracer {
    /// Materialize the event.
    fn trace(&self) -> SmtpClientTraceEvent;
}

/// Receiver for [`SmtpClientTraceEvent`]s emitted by [`SmtpClient`].
pub trait SmtpClientTracer: std::fmt::Debug {
    /// Deliver an event that has already been built.
    fn trace_event(&self, event: SmtpClientTraceEvent);
    /// Deliver an event lazily. Presumably implementations call
    /// `deferred.trace()` only when they actually want the event. Verify
    /// against implementors.
    fn lazy_trace(&self, deferred: &dyn DeferredTracer);
}
100
101// helper to avoid making a second copy of every write buffer
102struct WriteTracer<'a> {
103    data: &'a str,
104}
105impl DeferredTracer for WriteTracer<'_> {
106    fn trace(&self) -> SmtpClientTraceEvent {
107        SmtpClientTraceEvent::Write(self.data.to_string())
108    }
109}
110impl<'a> WriteTracer<'a> {
111    fn trace(tracer: &Arc<dyn SmtpClientTracer + Send + Sync>, data: &'a str) {
112        tracer.lazy_trace(&Self { data });
113    }
114}
115
116struct BinWriteTracer<'a> {
117    data: &'a [u8],
118}
119impl DeferredTracer for BinWriteTracer<'_> {
120    fn trace(&self) -> SmtpClientTraceEvent {
121        let data = String::from_utf8_lossy(self.data).to_string();
122        SmtpClientTraceEvent::Write(data)
123    }
124}
125impl<'a> BinWriteTracer<'a> {
126    fn trace(tracer: &Arc<dyn SmtpClientTracer + Send + Sync>, data: &'a [u8]) {
127        tracer.lazy_trace(&Self { data });
128    }
129}
130
131// A little bit of gymnastics to avoid making a second
132// copy of every read buffer
133struct ReadTracer<'a> {
134    data: &'a [u8],
135}
136impl DeferredTracer for ReadTracer<'_> {
137    fn trace(&self) -> SmtpClientTraceEvent {
138        SmtpClientTraceEvent::Read(self.data.to_vec())
139    }
140}
141
/// An SMTP/LMTP client session over a single connection.
#[derive(Debug)]
pub struct SmtpClient {
    /// The transport; `None` once it has been dropped after an IO error,
    /// timeout or peer close (it is also taken while STARTTLS swaps it out).
    socket: Option<BoxedAsyncReadAndWrite>,
    /// Peer hostname with any `:port`, surrounding `[]` and trailing dot
    /// removed; used in log output and as the TLS server name.
    hostname: String,
    /// Extensions advertised by the server (populated by EHLO/LHLO),
    /// keyed by the upper-cased keyword.
    capabilities: HashMap<String, EsmtpCapability>,
    /// Received bytes not yet consumed as complete CRLF-terminated lines.
    read_buffer: Vec<u8>,
    timeouts: SmtpClientTimeouts,
    /// Optional observer for wire-level events.
    tracer: Option<Arc<dyn SmtpClientTracer + Send + Sync>>,
    /// Whether the next `send_mail` leads with RSET; set once a transaction
    /// has been attempted on this connection.
    use_rset: bool,
    /// Set via `set_enable_rset`. NOTE(review): not read in the visible code.
    /// Confirm how it interacts with `use_rset`.
    enable_rset: bool,
    /// Allow pipelining when the server also advertises PIPELINING.
    enable_pipelining: bool,
}
154
/// Reduce a peer address string to a bare hostname.
///
/// Accepts `host`, `host:port`, `[literal]` and `[literal]:port`, and strips
/// a trailing FQDN dot from the result. An unbracketed string containing
/// more than one `:` is treated as a bare IPv6 literal and returned whole,
/// because an IPv6 literal cannot carry a port without brackets.
///
/// Examples: `"mx.example.com.:25"` -> `"mx.example.com"`,
/// `"[::1]:25"` -> `"::1"`, `"[::1]"` -> `"::1"`,
/// `"2001:db8::1"` -> `"2001:db8::1"`.
fn extract_hostname(hostname: &str) -> &str {
    // "[literal]" or "[literal]:port": keep what is inside the brackets.
    // This must be checked before port stripping, otherwise the colons
    // inside an IPv6 literal are mistaken for a port separator.
    let bracketed = hostname
        .strip_prefix('[')
        .and_then(|rest| rest.split_once(']'))
        .filter(|(_, after)| after.is_empty() || after.starts_with(':'))
        .map(|(literal, _)| literal);

    let host = match bracketed {
        Some(literal) => literal,
        // Only strip a ":port" suffix when the remainder has no other colon;
        // otherwise this is a bare IPv6 literal such as "::1".
        None => match hostname.rsplit_once(':') {
            Some((host, _port)) if !host.contains(':') => host,
            _ => hostname,
        },
    };

    // Remove any trailing FQDN dot
    host.strip_suffix('.').unwrap_or(host)
}
173
174impl SmtpClient {
175    pub async fn new<A: ToSocketAddrs + ToString + Clone>(
176        addr: A,
177        timeouts: SmtpClientTimeouts,
178    ) -> std::io::Result<Self> {
179        let stream = TcpStream::connect(addr.clone()).await?;
180        // No need for Nagle with SMTP request/response
181        stream.set_nodelay(true)?;
182        Ok(Self::with_stream(stream, addr.to_string(), timeouts))
183    }
184
185    pub fn with_stream<S: AsyncReadAndWrite + 'static, H: AsRef<str>>(
186        stream: S,
187        peer_hostname: H,
188        timeouts: SmtpClientTimeouts,
189    ) -> Self {
190        let hostname = extract_hostname(peer_hostname.as_ref()).to_string();
191
192        Self {
193            socket: Some(Box::new(stream)),
194            hostname,
195            capabilities: HashMap::new(),
196            read_buffer: Vec::with_capacity(1024),
197            timeouts,
198            tracer: None,
199            use_rset: false,
200            enable_rset: false,
201            enable_pipelining: false,
202        }
203    }
204
205    pub fn is_connected(&self) -> bool {
206        self.socket.is_some()
207    }
208
209    pub fn set_enable_rset(&mut self, enable: bool) {
210        self.enable_rset = enable;
211    }
212
213    pub fn set_enable_pipelining(&mut self, enable: bool) {
214        self.enable_pipelining = enable;
215    }
216
217    pub fn set_tracer(&mut self, tracer: Arc<dyn SmtpClientTracer + Send + Sync>) {
218        self.tracer.replace(tracer);
219    }
220
221    pub fn timeouts(&self) -> &SmtpClientTimeouts {
222        &self.timeouts
223    }
224
225    async fn read_line(
226        &mut self,
227        timeout_duration: Duration,
228        cmd: Option<&Command>,
229    ) -> Result<String, ClientError> {
230        let mut too_long = false;
231        loop {
232            let mut iter = self.read_buffer.iter().enumerate();
233            while let Some((i, &b)) = iter.next() {
234                if b != b'\r' {
235                    continue;
236                }
237                if let Some((_, b'\n')) = iter.next() {
238                    if too_long {
239                        self.read_buffer.drain(0..i + 2);
240
241                        if let Some(tracer) = &self.tracer {
242                            tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
243                                level: Level::ERROR,
244                                message: "Response too long".to_string(),
245                            });
246                        }
247
248                        return Err(ClientError::ResponseTooLong);
249                    }
250
251                    let line = String::from_utf8(self.read_buffer[0..i].to_vec());
252                    self.read_buffer.drain(0..i + 2);
253                    return Ok(line?);
254                }
255            }
256            if self.read_buffer.len() > MAX_LINE_LEN {
257                self.read_buffer.clear();
258                too_long = true;
259            }
260
261            // Didn't find a complete line, fill up the rest of the buffer
262            let mut data = [0u8; MAX_LINE_LEN];
263            let size = match self.socket.as_mut() {
264                Some(s) => match timeout(timeout_duration, s.read(&mut data)).await {
265                    Ok(Ok(size)) => size,
266                    Ok(Err(err)) => {
267                        self.socket.take();
268                        if let Some(tracer) = &self.tracer {
269                            tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
270                                level: Level::ERROR,
271                                message: format!("Error during read: {err:#}"),
272                            });
273                            tracer.trace_event(SmtpClientTraceEvent::Closed);
274                        }
275                        return Err(ClientError::ReadError {
276                            command: cmd.cloned(),
277                            error: format!("{err:#}"),
278                            partial: String::from_utf8_lossy(&self.read_buffer).to_string(),
279                        });
280                    }
281                    Err(_) => {
282                        self.socket.take();
283                        if let Some(tracer) = &self.tracer {
284                            tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
285                                level: Level::ERROR,
286                                message: format!("Read Timeout after {timeout_duration:?}"),
287                            });
288                            tracer.trace_event(SmtpClientTraceEvent::Closed);
289                        }
290                        return Err(ClientError::TimeOutResponse {
291                            command: cmd.cloned(),
292                            duration: timeout_duration,
293                        });
294                    }
295                },
296                None => {
297                    return Err(ClientError::ReadError {
298                        command: cmd.cloned(),
299                        error: "the socket was closed in response to an earlier issue".to_string(),
300                        partial: String::from_utf8_lossy(&self.read_buffer).to_string(),
301                    });
302                }
303            };
304            if size == 0 {
305                self.socket.take();
306                if let Some(tracer) = &self.tracer {
307                    tracer.trace_event(SmtpClientTraceEvent::Closed);
308                }
309                return Err(ClientError::ReadError {
310                    command: cmd.cloned(),
311                    error: "Connection closed by peer".to_string(),
312                    partial: String::from_utf8_lossy(&self.read_buffer).to_string(),
313                });
314            }
315            if let Some(tracer) = &self.tracer {
316                tracer.lazy_trace(&ReadTracer {
317                    data: &data[0..size],
318                });
319            }
320            self.read_buffer.extend_from_slice(&data[0..size]);
321        }
322    }
323
324    pub async fn read_response(
325        &mut self,
326        command: Option<&Command>,
327        timeout_duration: Duration,
328    ) -> Result<Response, ClientError> {
329        if let Some(sock) = self.socket.as_mut() {
330            if let Err(err) = sock.flush().await {
331                self.socket.take();
332                if let Some(tracer) = &self.tracer {
333                    tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
334                        level: Level::ERROR,
335                        message: format!("Error during flush: {err:#}"),
336                    });
337                    tracer.trace_event(SmtpClientTraceEvent::Closed);
338                }
339                return Err(ClientError::FlushError {
340                    error: format!("{err:#}"),
341                });
342            }
343        }
344
345        let mut line = self.read_line(timeout_duration, command).await?;
346        tracing::trace!("recv<-{}: {line}", self.hostname);
347        let mut parsed = parse_response_line(&line)?;
348        let mut response_builder = ResponseBuilder::new(&parsed);
349
350        let subsequent_line_timeout_duration = Duration::from_secs(60).min(timeout_duration);
351        while !parsed.is_final {
352            line = self
353                .read_line(subsequent_line_timeout_duration, command)
354                .await?;
355            parsed = parse_response_line(&line)?;
356            response_builder
357                .add_line(&parsed)
358                .map_err(ClientError::MalformedResponseLine)?;
359        }
360
361        let response = response_builder.build(command.map(|cmd| cmd.encode()));
362
363        tracing::trace!("{}: {response:?}", self.hostname);
364
365        Ok(response)
366    }
367
368    pub async fn send_command(&mut self, command: &Command) -> Result<Response, ClientError> {
369        self.write_command_request(command).await?;
370        self.read_response(Some(command), command.client_timeout(&self.timeouts))
371            .await
372    }
373
374    /// Wrapper around socket.write_all() that will emit trace diagnostics and synthesize
375    /// a Close event to the tracer if a timeout or IO error occurs.
376    /// If an error or timeout, occurs ensures that the socket will not be reused.
377    async fn write_all_with_timeout<F, G>(
378        &mut self,
379        timeout_duration: Duration,
380        bytes: &[u8],
381        make_timeout_err: F,
382        make_write_err: G,
383    ) -> Result<(), ClientError>
384    where
385        F: FnOnce() -> ClientError,
386        G: FnOnce(String) -> ClientError,
387    {
388        match self.socket.as_mut() {
389            Some(socket) => match timeout(timeout_duration, socket.write_all(bytes)).await {
390                Ok(Ok(response)) => Ok(response),
391                Ok(Err(err)) => {
392                    self.socket.take();
393                    if let Some(tracer) = &self.tracer {
394                        tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
395                            level: Level::ERROR,
396                            message: format!("Error during write: {err:#}"),
397                        });
398                        tracer.trace_event(SmtpClientTraceEvent::Closed);
399                    }
400                    Err(make_write_err(format!("{err:#}")))
401                }
402                Err(_) => {
403                    self.socket.take();
404                    if let Some(tracer) = &self.tracer {
405                        tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
406                            level: Level::ERROR,
407                            message: format!("Write Timeout after {timeout_duration:?}"),
408                        });
409                        tracer.trace_event(SmtpClientTraceEvent::Closed);
410                    }
411                    Err(make_timeout_err())
412                }
413            },
414            None => Err(make_write_err(
415                "the socket was closed in response to an earlier issue".to_string(),
416            )),
417        }
418    }
419
420    async fn write_pipeline_request(&mut self, commands: &[Command]) -> Result<(), ClientError> {
421        let total_timeout: Duration = commands
422            .iter()
423            .map(|cmd| cmd.client_timeout_request(&self.timeouts))
424            .sum();
425
426        let mut lines: Vec<String> = vec![];
427        let mut all = String::new();
428        for cmd in commands {
429            let line = cmd.encode();
430            all.push_str(&line);
431            lines.push(line);
432        }
433        tracing::trace!("send->{}: (PIPELINE) {all}", self.hostname);
434        if self.socket.is_some() {
435            if let Some(tracer) = &self.tracer {
436                // Send the lines individually to the tracer, so that we
437                // don't break --terse mode
438                for line in lines {
439                    WriteTracer::trace(tracer, &line);
440                }
441            }
442        }
443        self.write_all_with_timeout(
444            total_timeout,
445            all.as_bytes(),
446            || ClientError::TimeOutRequest {
447                duration: total_timeout,
448                commands: commands.to_vec(),
449            },
450            |error| ClientError::WriteError {
451                error,
452                commands: commands.to_vec(),
453            },
454        )
455        .await
456    }
457
458    async fn write_command_request(&mut self, command: &Command) -> Result<(), ClientError> {
459        let line = command.encode();
460        tracing::trace!("send->{}: {line}", self.hostname);
461        if self.socket.is_some() {
462            if let Some(tracer) = &self.tracer {
463                WriteTracer::trace(tracer, &line);
464            }
465        }
466
467        let timeout_duration = command.client_timeout_request(&self.timeouts);
468        self.write_all_with_timeout(
469            timeout_duration,
470            line.as_bytes(),
471            || ClientError::TimeOutRequest {
472                duration: timeout_duration,
473                commands: vec![command.clone()],
474            },
475            |error| ClientError::WriteError {
476                error,
477                commands: vec![command.clone()],
478            },
479        )
480        .await
481    }
482
483    async fn write_data_with_timeout(&mut self, data: &[u8]) -> Result<(), ClientError> {
484        if self.socket.is_some() {
485            if let Some(tracer) = &self.tracer {
486                BinWriteTracer::trace(tracer, data);
487            }
488        }
489        let timeout_duration = Command::Data.client_timeout_request(&self.timeouts);
490        self.write_all_with_timeout(
491            timeout_duration,
492            data,
493            || ClientError::TimeOutData,
494            |error| ClientError::WriteError {
495                error,
496                commands: vec![],
497            },
498        )
499        .await
500    }
501
502    /// Issue a series of commands, and return the responses to
503    /// those commands.
504    ///
505    /// If the server advertised the RFC 2920 PIPELINING extension,
506    /// the commands are written one after the other before waiting
507    /// to read any data, resulting in lower overall latency due
508    /// to round-trip-times.
509    ///
510    /// If PIPELINING is not available, each command is written
511    /// and the response read before attempting to write the next
512    /// command.
513    ///
514    /// The number of returned responses may be smaller than the
515    /// number of requested commands if there is an issue with
516    /// the network connection.
517    pub async fn pipeline_commands(
518        &mut self,
519        commands: Vec<Command>,
520    ) -> Vec<Result<Response, ClientError>> {
521        let mut results: Vec<Result<Response, ClientError>> = vec![];
522
523        let pipeline = self.enable_pipelining && self.capabilities.contains_key("PIPELINING");
524        if pipeline {
525            if let Err(err) = self.write_pipeline_request(&commands).await {
526                let err: ClientError = err;
527                results.push(Err(err.clone()));
528                while results.len() < commands.len() {
529                    // Synthesize failures for the remaining commands
530                    results.push(Err(err.clone()));
531                }
532                return results;
533            }
534
535            // Now read the responses effectively in a batch
536            for cmd in &commands {
537                results.push(
538                    self.read_response(Some(cmd), cmd.client_timeout(&self.timeouts))
539                        .await,
540                );
541            }
542            return results;
543        }
544
545        for cmd in &commands {
546            if let Err(err) = self.write_command_request(cmd).await {
547                let err: ClientError = err;
548                results.push(Err(err.clone()));
549                while results.len() < commands.len() {
550                    // Synthesize failures for the remaining commands
551                    results.push(Err(err.clone()));
552                }
553                return results;
554            }
555            // Immediately request the response if the server
556            // doesn't support pipelining
557            results.push(
558                self.read_response(Some(cmd), cmd.client_timeout(&self.timeouts))
559                    .await,
560            );
561        }
562        results
563    }
564
565    pub async fn ehlo_lhlo(
566        &mut self,
567        ehlo_name: &str,
568        use_lmtp: bool,
569    ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
570        if use_lmtp {
571            self.lhlo(ehlo_name).await
572        } else {
573            self.ehlo(ehlo_name).await
574        }
575    }
576
577    pub async fn lhlo(
578        &mut self,
579        ehlo_name: &str,
580    ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
581        let response = self
582            .send_command(&Command::Lhlo(Domain::Name(ehlo_name.to_string())))
583            .await?;
584        self.ehlo_common(response)
585    }
586
587    pub async fn ehlo(
588        &mut self,
589        ehlo_name: &str,
590    ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
591        let response = self
592            .send_command(&Command::Ehlo(Domain::Name(ehlo_name.to_string())))
593            .await?;
594        self.ehlo_common(response)
595    }
596
597    fn ehlo_common(
598        &mut self,
599        response: Response,
600    ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
601        if response.code != 250 {
602            return Err(ClientError::Rejected(response));
603        }
604
605        let mut capabilities = HashMap::new();
606
607        for line in response.content.lines().skip(1) {
608            let mut fields = line.splitn(2, ' ');
609            if let Some(name) = fields.next() {
610                let param = fields.next().map(|s| s.to_string());
611                let cap = EsmtpCapability {
612                    name: name.to_string(),
613                    param,
614                };
615                capabilities.insert(name.to_ascii_uppercase(), cap);
616            }
617        }
618
619        self.capabilities = capabilities;
620        Ok(&self.capabilities)
621    }
622
623    pub async fn auth_plain(
624        &mut self,
625        username: &str,
626        password: Option<&str>,
627    ) -> Result<(), ClientError> {
628        // RFC 4616 says that the format is:
629        // [authzid] NUL authcid NUL passwd
630        let password = password.unwrap_or("");
631        let payload = format!("\x00{username}\x00{password}");
632        let payload = data_encoding::BASE64.encode(payload.as_bytes());
633
634        let response = self
635            .send_command(&Command::Auth {
636                sasl_mech: "PLAIN".to_string(),
637                initial_response: Some(payload),
638            })
639            .await?;
640
641        if response.code != 235 {
642            return Err(ClientError::Rejected(response));
643        }
644
645        Ok(())
646    }
647
648    /// Attempt TLS handshake.
649    /// Returns Err for IO errors.
650    /// On completion, return an option that will be:
651    /// * Some(handshake_error) - if the handshake failed
652    /// * None - if the handshake succeeded
653    pub async fn starttls(&mut self, options: TlsOptions) -> Result<TlsStatus, ClientError> {
654        let resp = self.send_command(&Command::StartTls).await?;
655        if resp.code != 220 {
656            return Err(ClientError::Rejected(resp));
657        }
658
659        let mut handshake_error = None;
660        let mut tls_info = TlsInformation::default();
661
662        let stream: BoxedAsyncReadAndWrite =
663            if options.prefer_openssl || !options.dane_tlsa.is_empty() {
664                let connector = options.build_openssl_connector(&self.hostname)?;
665                let ssl = connector.into_ssl(self.hostname.as_str())?;
666
667                let (stream, dup_stream) = match self.socket.take() {
668                    Some(s) => {
669                        let d = s.try_dup();
670                        (s, d)
671                    }
672                    None => return Err(ClientError::NotConnected),
673                };
674
675                let mut ssl_stream = tokio_openssl::SslStream::new(ssl, stream)?;
676
677                if let Err(err) = std::pin::Pin::new(&mut ssl_stream).connect().await {
678                    handshake_error.replace(format!("{err:#}"));
679                }
680
681                tls_info.provider_name = "openssl".to_string();
682                tls_info.cipher = match ssl_stream.ssl().current_cipher() {
683                    Some(cipher) => cipher.standard_name().unwrap_or(cipher.name()).to_string(),
684                    None => String::new(),
685                };
686                tls_info.protocol_version = ssl_stream.ssl().version_str().to_string();
687
688                if let Some(cert) = ssl_stream.ssl().peer_certificate() {
689                    tls_info.subject_name = subject_name(&cert);
690                }
691                if let Ok(authority) = ssl_stream.ssl().dane_authority() {
692                    if let Some(cert) = &authority.cert {
693                        tls_info.subject_name = subject_name(cert);
694                    }
695                }
696
697                match (&handshake_error, dup_stream) {
698                    (Some(_), Some(dup_stream)) if !ssl_stream.ssl().is_init_finished() => {
699                        // Try falling back to clear text on the duplicate stream.
700                        // This is imperfect: in a failed validation scenario we will
701                        // end up trying to read binary data as a string and get a UTF-8
702                        // error if the peer thinks the session is encrypted.
703                        drop(ssl_stream);
704                        Box::new(dup_stream)
705                    }
706                    _ => Box::new(ssl_stream),
707                }
708            } else {
709                tls_info.provider_name = "rustls".to_string();
710                let connector = options.build_tls_connector().await;
711                let server_name = match IpAddr::from_str(self.hostname.as_str()) {
712                    Ok(ip) => ServerName::IpAddress(ip.into()),
713                    Err(_) => ServerName::try_from(self.hostname.clone())
714                        .map_err(|_| ClientError::InvalidDnsName(self.hostname.clone()))?,
715                };
716
717                match connector
718                    .connect(
719                        server_name,
720                        match self.socket.take() {
721                            Some(s) => s,
722                            None => return Err(ClientError::NotConnected),
723                        },
724                    )
725                    .into_fallible()
726                    .await
727                {
728                    Ok(stream) => {
729                        let (_, conn) = stream.get_ref();
730                        tls_info.cipher = match conn.negotiated_cipher_suite() {
731                            Some(suite) => suite.suite().as_str().unwrap_or("UNKNOWN").to_string(),
732                            None => String::new(),
733                        };
734                        tls_info.protocol_version = match conn.protocol_version() {
735                            Some(version) => version.as_str().unwrap_or("UNKNOWN").to_string(),
736                            None => String::new(),
737                        };
738
739                        if let Some(certs) = conn.peer_certificates() {
740                            let peer_cert = &certs[0];
741                            if let Ok(cert) = X509::from_der(peer_cert.as_ref()) {
742                                tls_info.subject_name = subject_name(&cert);
743                            }
744                        }
745
746                        Box::new(stream)
747                    }
748                    Err((err, stream)) => {
749                        handshake_error.replace(format!("{err:#}"));
750                        stream
751                    }
752                }
753            };
754
755        if let Some(tracer) = &self.tracer {
756            tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
757                level: Level::INFO,
758                message: match &handshake_error {
759                    Some(error) => format!("STARTTLS handshake failed: {error:?}"),
760                    None => format!("STARTTLS handshake -> {tls_info:?}"),
761                },
762            });
763        }
764
765        self.socket.replace(stream);
766        Ok(match handshake_error {
767            Some(error) => TlsStatus::FailedHandshake(error),
768            None => TlsStatus::Info(tls_info),
769        })
770    }
771
772    pub async fn send_mail<B: AsRef<[u8]>, SENDER: Into<ReversePath>, RECIP: Into<ForwardPath>>(
773        &mut self,
774        sender: SENDER,
775        recipient: RECIP,
776        data: B,
777    ) -> Result<Response, ClientError> {
778        let sender = sender.into();
779        let recipient = recipient.into();
780
781        let data: &[u8] = data.as_ref();
782        let stuffed;
783
784        let data = match apply_dot_stuffing(data) {
785            Some(d) => {
786                stuffed = d;
787                &stuffed
788            }
789            None => data,
790        };
791
792        let data_is_8bit = data.iter().any(|&b| b >= 0x80);
793        let envelope_is_8bit = !sender.is_ascii() || !recipient.is_ascii();
794
795        let mut mail_from_params = vec![];
796        if data_is_8bit && self.capabilities.contains_key("8BITMIME") {
797            mail_from_params.push(EsmtpParameter {
798                name: "BODY".to_string(),
799                value: Some("8BITMIME".to_string()),
800            });
801        }
802
803        if envelope_is_8bit && self.capabilities.contains_key("SMTPUTF8") {
804            mail_from_params.push(EsmtpParameter {
805                name: "SMTPUTF8".to_string(),
806                value: None,
807            });
808        }
809
810        let mut commands = vec![];
811
812        // We want to avoid using RSET for the first message we send on
813        // a given connection, because postfix can run in a mode where
814        // it will not tolerate RSET because it considers it to be a "junk"
815        // command, and rejecting junk commands will cut down on its load
816        // when it is under stress; it is used as a load shedding approach.
817        // If we always RSET then we will never deliver to a site that is
818        // configured that way. If we take care to RSET only for subsequent
819        // sends, then we should get one message per connection through
820        // without being unfairly penalized for defensively RSETing.
821        let used_rset = self.use_rset;
822        if self.use_rset {
823            commands.push(Command::Rset);
824        }
825        commands.push(Command::MailFrom {
826            address: sender,
827            parameters: mail_from_params,
828        });
829        commands.push(Command::RcptTo {
830            address: recipient,
831            parameters: vec![],
832        });
833        commands.push(Command::Data);
834
835        // Assume that something might break below: if it does, we want
836        // to ensure that we RSET the connection on the next go around.
837        self.use_rset = true;
838
839        let mut responses = self.pipeline_commands(commands).await;
840
841        // This is a little awkward. We want to handle the RFC 2090 3.1 case
842        // below, which requires deferring checking the actual response codes
843        // until later, but we must also handle the case where we had a hard
844        // transport error partway through pipelining.
845        // So we set a flag for that case and will then "eagerly", wrt. the
846        // RFC 2090 3.1 logic, evaluate the SMTP response codes, so that we
847        // can propagate the correct error disposition up to the caller.
848        let is_err = responses.iter().any(|r| r.is_err());
849
850        if used_rset {
851            let rset_resp = responses.remove(0)?;
852            if rset_resp.code != 250 {
853                return Err(ClientError::Rejected(rset_resp));
854            }
855        }
856
857        let mail_resp = responses.remove(0)?;
858        if is_err && mail_resp.code != 250 {
859            return Err(ClientError::Rejected(mail_resp));
860        }
861
862        let rcpt_resp = responses.remove(0)?;
863        if is_err && rcpt_resp.code != 250 {
864            return Err(ClientError::Rejected(rcpt_resp));
865        }
866
867        let data_resp = responses.remove(0)?;
868        if is_err && data_resp.code != 354 {
869            return Err(ClientError::Rejected(data_resp));
870        }
871
872        if data_resp.code == 354 && (mail_resp.code != 250 || rcpt_resp.code != 250) {
873            // RFC 2920 3.1:
874            // the client cannot assume that the DATA command will be rejected
875            // just because none of the RCPT TO commands worked.  If the DATA
876            // command was properly rejected the client SMTP can just issue
877            // RSET, but if the DATA command was accepted the client SMTP
878            // should send a single dot.
879
880            // Send dummy data
881            self.write_data_with_timeout(b".\r\n").await?;
882            let data_dot = Command::DataDot;
883            // wait for its response
884            let _ = self
885                .read_response(Some(&data_dot), data_dot.client_timeout(&self.timeouts))
886                .await?;
887
888            // Continue below: we will match one of the failure cases and
889            // return a ClientError::Rejected from one of the earlier
890            // commands
891        }
892
893        if mail_resp.code != 250 {
894            return Err(ClientError::Rejected(mail_resp));
895        }
896        if rcpt_resp.code != 250 {
897            return Err(ClientError::Rejected(rcpt_resp));
898        }
899        if data_resp.code != 354 {
900            return Err(ClientError::Rejected(data_resp));
901        }
902
903        let needs_newline = data.last().map(|&b| b != b'\n').unwrap_or(true);
904
905        tracing::trace!("message data is {} bytes", data.len());
906
907        self.write_data_with_timeout(data).await?;
908
909        let marker = if needs_newline { "\r\n.\r\n" } else { ".\r\n" };
910
911        tracing::trace!("send->{}: {}", self.hostname, marker.escape_debug());
912
913        self.write_data_with_timeout(marker.as_bytes()).await?;
914
915        let data_dot = Command::DataDot;
916        let resp = self
917            .read_response(Some(&data_dot), data_dot.client_timeout(&self.timeouts))
918            .await?;
919        if resp.code != 250 {
920            return Err(ClientError::Rejected(resp));
921        }
922
923        // If everything went well, respect the user preference for speculatively
924        // issuing an RSET next time around
925        self.use_rset = self.enable_rset;
926
927        Ok(resp)
928    }
929}
930
/// Status of an attempt to establish TLS on a connection.
///
/// Serializable so it can be carried in logs/trace records alongside other
/// connection metadata (NOTE(review): confirm where this is serialized).
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize, Clone)]
pub enum TlsStatus {
    /// The TLS handshake failed. The string describes the failure
    /// (presumably the underlying TLS error message — confirm at the
    /// construction site).
    FailedHandshake(String),
    /// Details about the negotiated TLS session.
    Info(TlsInformation),
}
936
/// Human-readable description of a negotiated TLS session.
///
/// `Default` yields all-empty fields.
#[derive(Debug, Default, PartialEq, Eq, Serialize, Deserialize, Clone)]
pub struct TlsInformation {
    /// The negotiated cipher, as reported by the TLS provider
    /// (NOTE(review): exact naming scheme depends on the provider — confirm).
    pub cipher: String,
    /// The negotiated protocol version, rendered as a string.
    pub protocol_version: String,
    /// `object=value` pairs taken from the peer certificate's subject name.
    /// Presumably filled in by this module's `subject_name()` helper — TODO
    /// confirm at the construction site.
    pub subject_name: Vec<String>,
    /// Name of the TLS implementation that handled the session
    /// (presumably distinguishes the openssl and rustls code paths — confirm).
    pub provider_name: String,
}
944
945impl Drop for SmtpClient {
946    fn drop(&mut self) {
947        if let Some(tracer) = &self.tracer {
948            if self.socket.is_some() {
949                tracer.trace_event(SmtpClientTraceEvent::Closed);
950            }
951        }
952    }
953}
954fn parse_response_line(line: &str) -> Result<ResponseLine, ClientError> {
955    if line.len() < 4 {
956        return Err(ClientError::MalformedResponseLine(line.to_string()));
957    }
958
959    match line.as_bytes()[3] {
960        b' ' | b'-' => match line[0..3].parse::<u16>() {
961            Ok(code) => Ok(ResponseLine {
962                code,
963                is_final: line.as_bytes()[3] == b' ',
964                content: &line[4..],
965            }),
966            Err(_) => Err(ClientError::MalformedResponseLine(line.to_string())),
967        },
968        _ => Err(ClientError::MalformedResponseLine(line.to_string())),
969    }
970}
971
972impl TlsOptions {
973    pub fn build_openssl_connector(
974        &self,
975        hostname: &str,
976    ) -> Result<openssl::ssl::ConnectConfiguration, ClientError> {
977        tracing::trace!("build_openssl_connector for {hostname}");
978        let mut builder =
979            openssl::ssl::SslConnector::builder(openssl::ssl::SslMethod::tls_client())?;
980
981        if let Some(list) = &self.openssl_cipher_list {
982            builder.set_cipher_list(list)?;
983        }
984
985        if let Some(suites) = &self.openssl_cipher_suites {
986            builder.set_ciphersuites(suites)?;
987        }
988
989        if let Some(options) = &self.openssl_options {
990            builder.clear_options(SslOptions::all());
991            builder.set_options(*options);
992        }
993
994        if self.insecure {
995            builder.set_verify(openssl::ssl::SslVerifyMode::NONE);
996        }
997
998        if !self.dane_tlsa.is_empty() {
999            builder.dane_enable()?;
1000            builder.set_no_dane_ee_namechecks();
1001        }
1002
1003        let connector = builder.build();
1004
1005        let mut config = connector.configure()?;
1006
1007        if !self.dane_tlsa.is_empty() {
1008            config.dane_enable(hostname)?;
1009            let mut any_usable = false;
1010            for tlsa in &self.dane_tlsa {
1011                let usable = config.dane_tlsa_add(
1012                    match tlsa.cert_usage() {
1013                        CertUsage::PkixTa => DaneUsage::PKIX_TA,
1014                        CertUsage::PkixEe => DaneUsage::PKIX_EE,
1015                        CertUsage::DaneTa => DaneUsage::DANE_TA,
1016                        CertUsage::DaneEe => DaneUsage::DANE_EE,
1017                        CertUsage::Unassigned(n) => DaneUsage::from_raw(n),
1018                        CertUsage::Private => DaneUsage::PRIV_CERT,
1019                    },
1020                    match tlsa.selector() {
1021                        Selector::Full => DaneSelector::CERT,
1022                        Selector::Spki => DaneSelector::SPKI,
1023                        Selector::Unassigned(n) => DaneSelector::from_raw(n),
1024                        Selector::Private => DaneSelector::PRIV_SEL,
1025                    },
1026                    match tlsa.matching() {
1027                        Matching::Raw => DaneMatchType::FULL,
1028                        Matching::Sha256 => DaneMatchType::SHA2_256,
1029                        Matching::Sha512 => DaneMatchType::SHA2_512,
1030                        Matching::Unassigned(n) => DaneMatchType::from_raw(n),
1031                        Matching::Private => DaneMatchType::PRIV_MATCH,
1032                    },
1033                    tlsa.cert_data(),
1034                )?;
1035
1036                tracing::trace!("build_dane_connector usable={usable} {tlsa:?}");
1037                if usable {
1038                    any_usable = true;
1039                }
1040            }
1041
1042            if !any_usable {
1043                return Err(ClientError::NoUsableDaneTlsa {
1044                    hostname: hostname.to_string(),
1045                    tlsa: self.dane_tlsa.clone(),
1046                });
1047            }
1048        }
1049
1050        Ok(config)
1051    }
1052}
1053
1054fn apply_dot_stuffing(data: &[u8]) -> Option<Vec<u8>> {
1055    static LFDOT: LazyLock<Finder> = LazyLock::new(|| memchr::memmem::Finder::new("\n."));
1056
1057    if !data.starts_with(b".") && LFDOT.find(data).is_none() {
1058        return None;
1059    }
1060
1061    let mut stuffed = vec![];
1062    if data.starts_with(b".") {
1063        stuffed.push(b'.');
1064    }
1065    let mut last_idx = 0;
1066    for i in LFDOT.find_iter(data) {
1067        stuffed.extend_from_slice(&data[last_idx..=i]);
1068        stuffed.push(b'.');
1069        last_idx = i + 1;
1070    }
1071    stuffed.extend_from_slice(&data[last_idx..]);
1072    Some(stuffed)
1073}
1074
1075/// Extracts the object=name pairs of the subject name from a cert.
1076/// eg:
1077/// ```norun
1078/// ["C=US", "ST=CA", "L=SanFrancisco", "O=Fort-Funston", "OU=MyOrganizationalUnit",
1079/// "CN=do.havedane.net", "name=EasyRSA", "emailAddress=me@myhost.mydomain"]
1080/// ```
1081fn subject_name(cert: &X509Ref) -> Vec<String> {
1082    let mut subject_name = vec![];
1083    for entry in cert.subject_name().entries() {
1084        if let Ok(obj) = entry.object().nid().short_name() {
1085            if let Ok(name) = entry.data().as_utf8() {
1086                subject_name.push(format!("{obj}={name}"));
1087            }
1088        }
1089    }
1090    subject_name
1091}
1092
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_stuffing() {
        assert_eq!(apply_dot_stuffing(b"foo"), None);
        assert_eq!(apply_dot_stuffing(b".foo").unwrap(), b"..foo");
        assert_eq!(apply_dot_stuffing(b"foo\n.bar").unwrap(), b"foo\n..bar");
        assert_eq!(
            apply_dot_stuffing(b"foo\n.bar\n..baz\n").unwrap(),
            b"foo\n..bar\n...baz\n"
        );
    }

    #[test]
    fn test_stuffing_edge_cases() {
        // Nothing to do for empty input, or for dots that do not start a line.
        assert_eq!(apply_dot_stuffing(b""), None);
        assert_eq!(apply_dot_stuffing(b"foo.bar\r\n"), None);
        // A lone dot would otherwise be read as the end-of-data marker.
        assert_eq!(apply_dot_stuffing(b".").unwrap(), b"..");
        // CRLF line endings, as used on the wire.
        assert_eq!(apply_dot_stuffing(b"a\r\n.b\r\n").unwrap(), b"a\r\n..b\r\n");
        // Adjacent dot-lines are each stuffed.
        assert_eq!(apply_dot_stuffing(b"\n.\n.").unwrap(), b"\n..\n..");
    }

    #[test]
    fn response_line_parsing() {
        assert_eq!(
            parse_response_line("220 woot").unwrap(),
            ResponseLine {
                code: 220,
                is_final: true,
                content: "woot"
            }
        );
        assert_eq!(
            parse_response_line("220-woot").unwrap(),
            ResponseLine {
                code: 220,
                is_final: false,
                content: "woot"
            }
        );

        assert!(matches!(
            parse_response_line("220_woot"),
            Err(ClientError::MalformedResponseLine(_))
        ));
        assert!(matches!(
            parse_response_line("not really"),
            Err(ClientError::MalformedResponseLine(_))
        ));
    }

    #[test]
    fn response_line_parsing_edge_cases() {
        // A code and separator with no text is a valid (empty) final line.
        assert_eq!(
            parse_response_line("250 ").unwrap(),
            ResponseLine {
                code: 250,
                is_final: true,
                content: ""
            }
        );

        // Too short to hold a code and a separator.
        for line in ["", "2", "250"] {
            assert!(
                matches!(
                    parse_response_line(line),
                    Err(ClientError::MalformedResponseLine(_))
                ),
                "{line:?} should be rejected"
            );
        }

        // Non-ASCII input must be rejected rather than panic on slicing.
        assert!(matches!(
            parse_response_line("é2 x"),
            Err(ClientError::MalformedResponseLine(_))
        ));
    }

    fn parse_multi_line(lines: &[&str]) -> Result<Response, ClientError> {
        let mut parsed = parse_response_line(lines[0])?;
        let mut b = ResponseBuilder::new(&parsed);
        for line in &lines[1..] {
            parsed = parse_response_line(line)?;
            b.add_line(&parsed)
                .map_err(ClientError::MalformedResponseLine)?;
        }
        assert!(parsed.is_final);
        Ok(b.build(None))
    }

    #[test]
    fn multi_line_response() {
        assert_eq!(
            parse_multi_line(&["220-woot", "220-more", "220 done",]).unwrap(),
            Response {
                code: 220,
                enhanced_code: None,
                content: "woot\nmore\ndone".to_string(),
                command: None
            }
        );

        let res = parse_multi_line(&["220-woot", "221-more", "220 done"]).unwrap_err();
        assert!(
            matches!(
                res,
                ClientError::MalformedResponseLine(ref err) if err == "221-more"
            ),
            "got error {res:?}"
        );

        let res = parse_multi_line(&["220-woot", "220-more", "221 done"]).unwrap_err();
        assert!(
            matches!(
                res,
                ClientError::MalformedResponseLine(ref err) if err == "221 done"
            ),
            "got error {res:?}"
        );

        assert_eq!(
            parse_multi_line(&["220-4.1.0 woot", "220-more", "220 done",]).unwrap(),
            Response {
                code: 220,
                enhanced_code: Some(EnhancedStatusCode {
                    class: 4,
                    subject: 1,
                    detail: 0
                }),
                content: "woot\nmore\ndone".to_string(),
                command: None
            }
        );

        // Confirm that we strip the enhanced status code from each line
        assert_eq!(
            parse_multi_line(&["220-4.1.0 woot", "220-4.1.0 more", "220 done",]).unwrap(),
            Response {
                code: 220,
                enhanced_code: Some(EnhancedStatusCode {
                    class: 4,
                    subject: 1,
                    detail: 0
                }),
                content: "woot\nmore\ndone".to_string(),
                command: None
            }
        );

        // ... but only if the code matches that of the first line
        assert_eq!(
            parse_multi_line(&["220-4.1.0 woot", "220-4.1.0 more", "220 5.5.5 done",]).unwrap(),
            Response {
                code: 220,
                enhanced_code: Some(EnhancedStatusCode {
                    class: 4,
                    subject: 1,
                    detail: 0
                }),
                content: "woot\nmore\n5.5.5 done".to_string(),
                command: None
            }
        );
    }

    #[test]
    fn test_extract_hostname() {
        assert_eq!(extract_hostname("foo"), "foo");
        assert_eq!(extract_hostname("foo."), "foo");
        assert_eq!(extract_hostname("foo:25"), "foo");
        assert_eq!(extract_hostname("foo.:25"), "foo");
        assert_eq!(extract_hostname("[foo]:25"), "foo");
        assert_eq!(extract_hostname("[foo.]:25"), "foo");
        assert_eq!(extract_hostname("[::1]:25"), "::1");
        assert_eq!(extract_hostname("::1:25"), "::1");
    }
}