rfc5321/
client.rs

1#![allow(clippy::result_large_err)]
2use crate::client_types::*;
3use crate::{
4    AsyncReadAndWrite, BoxedAsyncReadAndWrite, Command, Domain, EsmtpParameter, ForwardPath,
5    ReversePath,
6};
7use hickory_proto::rr::rdata::TLSA;
8use memchr::memmem::Finder;
9use openssl::x509::{X509Ref, X509};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::net::IpAddr;
13use std::str::FromStr;
14use std::sync::{Arc, LazyLock};
15use std::time::Duration;
16use thiserror::Error;
17use tokio::io::{AsyncReadExt, AsyncWriteExt};
18use tokio::net::{TcpStream, ToSocketAddrs};
19use tokio::time::timeout;
20use tokio_rustls::rustls::pki_types::ServerName;
21use tracing::Level;
22
23pub use kumo_tls_helper::TlsOptions;
24pub use {openssl, tokio_rustls};
25
/// Limit, in bytes, used by `SmtpClient::read_line`: when more than this many
/// bytes are buffered without finding a CRLF, the line is treated as too long
/// and `ClientError::ResponseTooLong` is eventually returned. It is also the
/// size of the scratch buffer used for each socket read.
const MAX_LINE_LEN: usize = 4096;
27
/// Errors produced by [`SmtpClient`] operations.
#[derive(Error, Debug, Clone)]
pub enum ClientError {
    /// A response line received from the server was not valid UTF-8.
    #[error("response is not UTF8")]
    Utf8(#[from] std::string::FromUtf8Error),
    /// A response line was rejected while assembling a (multi-line) response.
    #[error("Malformed Response: {0}")]
    MalformedResponseLine(String),
    /// More than `MAX_LINE_LEN` bytes were received without a line terminator.
    #[error("Response line is too long")]
    ResponseTooLong,
    /// An operation needed the connection, but it was no longer present.
    #[error("Not connected")]
    NotConnected,
    /// The server replied with a code other than the one the command expects
    /// (e.g. EHLO/LHLO not 250, AUTH not 235, STARTTLS not 220).
    #[error("Command rejected {0:?}")]
    Rejected(Response),
    /// Several responses were rejected together.
    #[error("Commands rejected {0:?}")]
    RejectedBatch(Vec<Response>),
    /// The peer hostname could not be used as a TLS server name.
    #[error("STARTTLS: {0} is not a valid DNS name")]
    InvalidDnsName(String),
    /// Building the TLS connector failed.
    /// NOTE(review): the message mentions a client certificate, but `starttls`
    /// returns this for any connector-build failure.
    #[error("Invalid client certificate configured: {error:?}")]
    FailedToBuildConnector { error: String },
    /// No response (or no further response line) arrived within `duration`.
    #[error("Timed Out waiting {duration:?} for response to cmd={}", self.command_for_err())]
    TimeOutResponse {
        command: Option<Command>,
        duration: Duration,
    },
    /// Writing `commands` did not complete within `duration`.
    #[error("Timed Out after {duration:?} writing cmd={}", self.command_for_err())]
    TimeOutRequest {
        commands: Vec<Command>,
        duration: Duration,
    },
    /// Reading a response failed: an I/O error, the peer closing the
    /// connection, or the socket having already been closed. `partial` holds
    /// any buffered bytes that did not form a complete line.
    #[error("Error {error} reading response to cmd={}", self.command_for_err())]
    ReadError {
        command: Option<Command>,
        error: String,
        partial: String,
    },
    /// Flushing buffered writes before reading a response failed.
    #[error("Error {error} flushing send buffer")]
    FlushError { error: String },
    /// Writing `commands` (empty for message payload data) failed, or the
    /// socket had already been closed.
    #[error("Error {error} writing {}", self.command_for_err())]
    WriteError {
        commands: Vec<Command>,
        error: String,
    },
    /// Writing the message payload timed out.
    #[error("Timed Out sending message payload data")]
    TimeOutData,
    /// An OpenSSL error during TLS setup.
    #[error("SSL Error: {0}")]
    SslErrorStack(#[from] openssl::error::ErrorStack),
    /// None of the supplied DANE TLSA records for `hostname` were usable.
    #[error("No usable DANE TLSA records for {hostname}: {tlsa:?}")]
    NoUsableDaneTlsa { hostname: String, tlsa: Vec<TLSA> },
}
76
77impl ClientError {
78    /// Returns the command(s) string suitable for passing into a Response
79    pub fn command(&self) -> Option<String> {
80        match self {
81            Self::TimeOutResponse {
82                command: Some(command),
83                ..
84            }
85            | Self::ReadError {
86                command: Some(command),
87                ..
88            } => Some(command.encode()),
89            Self::TimeOutRequest { commands, .. } | Self::WriteError { commands, .. }
90                if !commands.is_empty() =>
91            {
92                let commands: Vec<String> = commands.into_iter().map(|cmd| cmd.encode()).collect();
93                Some(commands.join(""))
94            }
95            _ => None,
96        }
97    }
98
99    /// Returns the command(s) string suitable for embedding in the
100    /// overall error message
101    fn command_for_err(&self) -> String {
102        self.command()
103            .map(|cmd| cmd.replace("\r\n", ""))
104            .unwrap_or_else(|| "NONE".to_string())
105    }
106
107    /// If the error contents were likely caused by something
108    /// about the mostly recently attempted message, rather than
109    /// a transport issue, or a carry-over from a prior message
110    /// (eg: previous message was rejected and destination chose
111    /// to drop the connection, which we detect later in RSET
112    /// on the next message), then we return true.
113    /// The expectation is that the caller will transiently
114    /// fail the message for later retry.
115    /// If we return false then the caller might decide to
116    /// try that same message again more immediately on
117    /// a separate connection
118    pub fn was_due_to_message(&self) -> bool {
119        match self {
120            Self::Utf8(_)
121            | Self::MalformedResponseLine(_)
122            | Self::ResponseTooLong
123            | Self::NotConnected
124            | Self::InvalidDnsName(_)
125            | Self::FailedToBuildConnector { .. }
126            | Self::TimeOutResponse { .. }
127            | Self::TimeOutRequest { .. }
128            | Self::ReadError { .. }
129            | Self::FlushError { .. }
130            | Self::WriteError { .. }
131            | Self::TimeOutData
132            | Self::SslErrorStack(_)
133            | Self::NoUsableDaneTlsa { .. } => false,
134            Self::Rejected(response) => response.was_due_to_message(),
135            Self::RejectedBatch(responses) => match responses.len() {
136                1 => responses[0].was_due_to_message(),
137                _ => false,
138            },
139        }
140    }
141}
142
/// One ESMTP extension line from the server's EHLO/LHLO response.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EsmtpCapability {
    /// The extension keyword as sent by the server. Case is preserved here;
    /// the capabilities map is keyed by the upper-cased form.
    pub name: String,
    /// Everything after the first space on the line, if there was a space.
    pub param: Option<String>,
}
148
/// Events reported to an [`SmtpClientTracer`] while the client talks to the server.
#[derive(Clone, Debug)]
pub enum SmtpClientTraceEvent {
    /// The connection was closed: EOF from the peer, or the client dropping
    /// the socket after an I/O error or timeout.
    Closed,
    /// Raw bytes received from the server.
    Read(Vec<u8>),
    /// Data written to the server. Binary payloads are converted lossily to UTF-8.
    Write(String),
    /// A diagnostic message, such as an I/O error, timeout or STARTTLS outcome.
    Diagnostic {
        level: tracing::Level,
        message: String,
    },
}
159
/// A trace event whose construction is deferred, so the client can hand a
/// tracer borrowed buffers without copying them up front.
pub trait DeferredTracer {
    /// Build the (owned) trace event.
    fn trace(&self) -> SmtpClientTraceEvent;
}
163
/// Receiver for [`SmtpClientTraceEvent`]s emitted by an [`SmtpClient`].
pub trait SmtpClientTracer: std::fmt::Debug {
    /// Deliver an already-constructed event.
    fn trace_event(&self, event: SmtpClientTraceEvent);
    /// Deliver an event lazily. Presumably implementations call
    /// `deferred.trace()` only when they actually need the event, which is
    /// what makes the deferral worthwhile. Confirm against implementations.
    fn lazy_trace(&self, deferred: &dyn DeferredTracer);
}
168
169// helper to avoid making a second copy of every write buffer
170struct WriteTracer<'a> {
171    data: &'a str,
172}
173impl DeferredTracer for WriteTracer<'_> {
174    fn trace(&self) -> SmtpClientTraceEvent {
175        SmtpClientTraceEvent::Write(self.data.to_string())
176    }
177}
178impl<'a> WriteTracer<'a> {
179    fn trace(tracer: &Arc<dyn SmtpClientTracer + Send + Sync>, data: &'a str) {
180        tracer.lazy_trace(&Self { data });
181    }
182}
183
184struct BinWriteTracer<'a> {
185    data: &'a [u8],
186}
187impl DeferredTracer for BinWriteTracer<'_> {
188    fn trace(&self) -> SmtpClientTraceEvent {
189        let data = String::from_utf8_lossy(self.data).to_string();
190        SmtpClientTraceEvent::Write(data)
191    }
192}
193impl<'a> BinWriteTracer<'a> {
194    fn trace(tracer: &Arc<dyn SmtpClientTracer + Send + Sync>, data: &'a [u8]) {
195        tracer.lazy_trace(&Self { data });
196    }
197}
198
199// A little bit of gymnastics to avoid making a second
200// copy of every read buffer
201struct ReadTracer<'a> {
202    data: &'a [u8],
203}
204impl DeferredTracer for ReadTracer<'_> {
205    fn trace(&self) -> SmtpClientTraceEvent {
206        SmtpClientTraceEvent::Read(self.data.to_vec())
207    }
208}
209
/// An SMTP/LMTP client connection.
#[derive(Debug)]
pub struct SmtpClient {
    /// The transport. Set to `None` after a read/write/flush error, timeout or
    /// EOF so that it is not reused.
    socket: Option<BoxedAsyncReadAndWrite>,
    /// Peer host name, with any `:port`, `[...]` brackets and trailing dot
    /// removed. Used in trace logging and as the TLS server name for STARTTLS.
    hostname: String,
    /// Capabilities from the most recent successful EHLO/LHLO, keyed by the
    /// upper-cased extension name.
    capabilities: HashMap<String, EsmtpCapability>,
    /// Bytes received from the socket that have not yet formed a complete line.
    read_buffer: Vec<u8>,
    /// Timeouts consulted via `client_timeout` / `client_timeout_request`.
    timeouts: SmtpClientTimeouts,
    /// Optional receiver for protocol trace events.
    tracer: Option<Arc<dyn SmtpClientTracer + Send + Sync>>,
    // NOTE(review): initialized to false in `with_stream`; how it is used is
    // outside this view.
    use_rset: bool,
    /// Set via `set_enable_rset`.
    enable_rset: bool,
    /// When true and the server advertises PIPELINING, `pipeline_commands`
    /// writes all commands before reading any responses.
    enable_pipelining: bool,
    /// When true, 8BITMIME and SMTPUTF8 checks are ignored.
    ignore_8bit_checks: bool,
}
223
/// Reduce a peer address such as `"mx.example.com:25"`, `"[10.0.0.1]:25"`,
/// `"[::1]:25"` or `"[::1]"` to just the host part. Any `:port` suffix,
/// enclosing `[...]` brackets and a trailing FQDN dot are removed.
///
/// A bracketed literal (optionally followed by `:port`) is recognized first.
/// This keeps bracketed IPv6 literals without a port intact instead of
/// splitting them at their last `:`. For anything else, the text after the
/// last `:` is treated as a port, so an unbracketed IPv6 literal is expected
/// to carry one (`"::1:25"` -> `"::1"`).
fn extract_hostname(hostname: &str) -> &str {
    let host = match hostname
        .strip_prefix('[')
        .and_then(|rest| rest.split_once(']'))
    {
        // "[literal]" or "[literal]:port"
        Some((inner, tail)) if tail.is_empty() || tail.starts_with(':') => inner,
        _ => {
            // Just the hostname, without any :port
            let host = hostname
                .rsplit_once(':')
                .map_or(hostname, |(host, _port)| host);
            host.strip_prefix('[')
                .and_then(|h| h.strip_suffix(']'))
                .unwrap_or(host)
        }
    };

    // Remove any trailing FQDN dot
    host.strip_suffix('.').unwrap_or(host)
}
242
243impl SmtpClient {
244    pub async fn new<A: ToSocketAddrs + ToString + Clone>(
245        addr: A,
246        timeouts: SmtpClientTimeouts,
247    ) -> std::io::Result<Self> {
248        let stream = TcpStream::connect(addr.clone()).await?;
249        // No need for Nagle with SMTP request/response
250        stream.set_nodelay(true)?;
251        Ok(Self::with_stream(stream, addr.to_string(), timeouts))
252    }
253
254    pub fn with_stream<S: AsyncReadAndWrite + 'static, H: AsRef<str>>(
255        stream: S,
256        peer_hostname: H,
257        timeouts: SmtpClientTimeouts,
258    ) -> Self {
259        let hostname = extract_hostname(peer_hostname.as_ref()).to_string();
260
261        Self {
262            socket: Some(Box::new(stream)),
263            hostname,
264            capabilities: HashMap::new(),
265            read_buffer: Vec::with_capacity(1024),
266            timeouts,
267            tracer: None,
268            use_rset: false,
269            enable_rset: false,
270            enable_pipelining: false,
271            ignore_8bit_checks: false,
272        }
273    }
274
275    /// Setting this to true causes 8BITMIME and SMTPUTF8 checks to be ignored
276    pub fn set_ignore_8bit_checks(&mut self, enable: bool) {
277        self.ignore_8bit_checks = enable;
278    }
279
280    pub fn is_connected(&self) -> bool {
281        self.socket.is_some()
282    }
283
284    pub fn set_enable_rset(&mut self, enable: bool) {
285        self.enable_rset = enable;
286    }
287
288    pub fn set_enable_pipelining(&mut self, enable: bool) {
289        self.enable_pipelining = enable;
290    }
291
292    pub fn set_tracer(&mut self, tracer: Arc<dyn SmtpClientTracer + Send + Sync>) {
293        self.tracer.replace(tracer);
294    }
295
296    pub fn timeouts(&self) -> &SmtpClientTimeouts {
297        &self.timeouts
298    }
299
300    async fn read_line(
301        &mut self,
302        timeout_duration: Duration,
303        cmd: Option<&Command>,
304    ) -> Result<String, ClientError> {
305        let mut too_long = false;
306        loop {
307            let mut iter = self.read_buffer.iter().enumerate();
308            while let Some((i, &b)) = iter.next() {
309                if b != b'\r' {
310                    continue;
311                }
312                if let Some((_, b'\n')) = iter.next() {
313                    if too_long {
314                        self.read_buffer.drain(0..i + 2);
315
316                        if let Some(tracer) = &self.tracer {
317                            tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
318                                level: Level::ERROR,
319                                message: "Response too long".to_string(),
320                            });
321                        }
322
323                        return Err(ClientError::ResponseTooLong);
324                    }
325
326                    let line = String::from_utf8(self.read_buffer[0..i].to_vec());
327                    self.read_buffer.drain(0..i + 2);
328                    return Ok(line?);
329                }
330            }
331            if self.read_buffer.len() > MAX_LINE_LEN {
332                self.read_buffer.clear();
333                too_long = true;
334            }
335
336            // Didn't find a complete line, fill up the rest of the buffer
337            let mut data = [0u8; MAX_LINE_LEN];
338            let size = match self.socket.as_mut() {
339                Some(s) => match timeout(timeout_duration, s.read(&mut data)).await {
340                    Ok(Ok(size)) => size,
341                    Ok(Err(err)) => {
342                        self.socket.take();
343                        if let Some(tracer) = &self.tracer {
344                            tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
345                                level: Level::ERROR,
346                                message: format!("Error during read: {err:#}"),
347                            });
348                            tracer.trace_event(SmtpClientTraceEvent::Closed);
349                        }
350                        return Err(ClientError::ReadError {
351                            command: cmd.cloned(),
352                            error: format!("{err:#}"),
353                            partial: String::from_utf8_lossy(&self.read_buffer).to_string(),
354                        });
355                    }
356                    Err(_) => {
357                        self.socket.take();
358                        if let Some(tracer) = &self.tracer {
359                            tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
360                                level: Level::ERROR,
361                                message: format!("Read Timeout after {timeout_duration:?}"),
362                            });
363                            tracer.trace_event(SmtpClientTraceEvent::Closed);
364                        }
365                        return Err(ClientError::TimeOutResponse {
366                            command: cmd.cloned(),
367                            duration: timeout_duration,
368                        });
369                    }
370                },
371                None => {
372                    return Err(ClientError::ReadError {
373                        command: cmd.cloned(),
374                        error: "the socket was closed in response to an earlier issue".to_string(),
375                        partial: String::from_utf8_lossy(&self.read_buffer).to_string(),
376                    });
377                }
378            };
379            if size == 0 {
380                self.socket.take();
381                if let Some(tracer) = &self.tracer {
382                    tracer.trace_event(SmtpClientTraceEvent::Closed);
383                }
384                return Err(ClientError::ReadError {
385                    command: cmd.cloned(),
386                    error: "Connection closed by peer".to_string(),
387                    partial: String::from_utf8_lossy(&self.read_buffer).to_string(),
388                });
389            }
390            if let Some(tracer) = &self.tracer {
391                tracer.lazy_trace(&ReadTracer {
392                    data: &data[0..size],
393                });
394            }
395            self.read_buffer.extend_from_slice(&data[0..size]);
396        }
397    }
398
399    pub async fn read_response(
400        &mut self,
401        command: Option<&Command>,
402        timeout_duration: Duration,
403    ) -> Result<Response, ClientError> {
404        if let Some(sock) = self.socket.as_mut() {
405            if let Err(err) = sock.flush().await {
406                self.socket.take();
407                if let Some(tracer) = &self.tracer {
408                    tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
409                        level: Level::ERROR,
410                        message: format!("Error during flush: {err:#}"),
411                    });
412                    tracer.trace_event(SmtpClientTraceEvent::Closed);
413                }
414                return Err(ClientError::FlushError {
415                    error: format!("{err:#}"),
416                });
417            }
418        }
419
420        let mut line = self.read_line(timeout_duration, command).await?;
421        tracing::trace!("recv<-{}: {line}", self.hostname);
422        let mut parsed = parse_response_line(&line)?;
423        let mut response_builder = ResponseBuilder::new(&parsed);
424
425        let subsequent_line_timeout_duration = Duration::from_secs(60).min(timeout_duration);
426        while !parsed.is_final {
427            line = self
428                .read_line(subsequent_line_timeout_duration, command)
429                .await?;
430            parsed = parse_response_line(&line)?;
431            response_builder
432                .add_line(&parsed)
433                .map_err(ClientError::MalformedResponseLine)?;
434        }
435
436        let response = response_builder.build(command.map(|cmd| cmd.encode()));
437
438        tracing::trace!("{}: {response:?}", self.hostname);
439
440        Ok(response)
441    }
442
443    pub async fn send_command(&mut self, command: &Command) -> Result<Response, ClientError> {
444        self.write_command_request(command).await?;
445        self.read_response(Some(command), command.client_timeout(&self.timeouts))
446            .await
447    }
448
449    /// Wrapper around socket.write_all() that will emit trace diagnostics and synthesize
450    /// a Close event to the tracer if a timeout or IO error occurs.
451    /// If an error or timeout, occurs ensures that the socket will not be reused.
452    async fn write_all_with_timeout<F, G>(
453        &mut self,
454        timeout_duration: Duration,
455        bytes: &[u8],
456        make_timeout_err: F,
457        make_write_err: G,
458    ) -> Result<(), ClientError>
459    where
460        F: FnOnce() -> ClientError,
461        G: FnOnce(String) -> ClientError,
462    {
463        match self.socket.as_mut() {
464            Some(socket) => match timeout(timeout_duration, socket.write_all(bytes)).await {
465                Ok(Ok(response)) => Ok(response),
466                Ok(Err(err)) => {
467                    self.socket.take();
468                    if let Some(tracer) = &self.tracer {
469                        tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
470                            level: Level::ERROR,
471                            message: format!("Error during write: {err:#}"),
472                        });
473                        tracer.trace_event(SmtpClientTraceEvent::Closed);
474                    }
475                    Err(make_write_err(format!("{err:#}")))
476                }
477                Err(_) => {
478                    self.socket.take();
479                    if let Some(tracer) = &self.tracer {
480                        tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
481                            level: Level::ERROR,
482                            message: format!("Write Timeout after {timeout_duration:?}"),
483                        });
484                        tracer.trace_event(SmtpClientTraceEvent::Closed);
485                    }
486                    Err(make_timeout_err())
487                }
488            },
489            None => Err(make_write_err(
490                "the socket was closed in response to an earlier issue".to_string(),
491            )),
492        }
493    }
494
495    async fn write_pipeline_request(&mut self, commands: &[Command]) -> Result<(), ClientError> {
496        let total_timeout: Duration = commands
497            .iter()
498            .map(|cmd| cmd.client_timeout_request(&self.timeouts))
499            .sum();
500
501        let mut lines: Vec<String> = vec![];
502        let mut all = String::new();
503        for cmd in commands {
504            let line = cmd.encode();
505            all.push_str(&line);
506            lines.push(line);
507        }
508        tracing::trace!("send->{}: (PIPELINE) {all}", self.hostname);
509        if self.socket.is_some() {
510            if let Some(tracer) = &self.tracer {
511                // Send the lines individually to the tracer, so that we
512                // don't break --terse mode
513                for line in lines {
514                    WriteTracer::trace(tracer, &line);
515                }
516            }
517        }
518        self.write_all_with_timeout(
519            total_timeout,
520            all.as_bytes(),
521            || ClientError::TimeOutRequest {
522                duration: total_timeout,
523                commands: commands.to_vec(),
524            },
525            |error| ClientError::WriteError {
526                error,
527                commands: commands.to_vec(),
528            },
529        )
530        .await
531    }
532
533    async fn write_command_request(&mut self, command: &Command) -> Result<(), ClientError> {
534        let line = command.encode();
535        tracing::trace!("send->{}: {line}", self.hostname);
536        if self.socket.is_some() {
537            if let Some(tracer) = &self.tracer {
538                WriteTracer::trace(tracer, &line);
539            }
540        }
541
542        let timeout_duration = command.client_timeout_request(&self.timeouts);
543        self.write_all_with_timeout(
544            timeout_duration,
545            line.as_bytes(),
546            || ClientError::TimeOutRequest {
547                duration: timeout_duration,
548                commands: vec![command.clone()],
549            },
550            |error| ClientError::WriteError {
551                error,
552                commands: vec![command.clone()],
553            },
554        )
555        .await
556    }
557
558    async fn write_data_with_timeout(&mut self, data: &[u8]) -> Result<(), ClientError> {
559        if self.socket.is_some() {
560            if let Some(tracer) = &self.tracer {
561                BinWriteTracer::trace(tracer, data);
562            }
563        }
564        let timeout_duration = Command::Data.client_timeout_request(&self.timeouts);
565        self.write_all_with_timeout(
566            timeout_duration,
567            data,
568            || ClientError::TimeOutData,
569            |error| ClientError::WriteError {
570                error,
571                commands: vec![],
572            },
573        )
574        .await
575    }
576
577    /// Issue a series of commands, and return the responses to
578    /// those commands.
579    ///
580    /// If the server advertised the RFC 2920 PIPELINING extension,
581    /// the commands are written one after the other before waiting
582    /// to read any data, resulting in lower overall latency due
583    /// to round-trip-times.
584    ///
585    /// If PIPELINING is not available, each command is written
586    /// and the response read before attempting to write the next
587    /// command.
588    ///
589    /// The number of returned responses may be smaller than the
590    /// number of requested commands if there is an issue with
591    /// the network connection.
592    pub async fn pipeline_commands(
593        &mut self,
594        commands: Vec<Command>,
595    ) -> Vec<Result<Response, ClientError>> {
596        let mut results: Vec<Result<Response, ClientError>> = vec![];
597
598        let pipeline = self.enable_pipelining && self.capabilities.contains_key("PIPELINING");
599        if pipeline {
600            if let Err(err) = self.write_pipeline_request(&commands).await {
601                let err: ClientError = err;
602                results.push(Err(err.clone()));
603                while results.len() < commands.len() {
604                    // Synthesize failures for the remaining commands
605                    results.push(Err(err.clone()));
606                }
607                return results;
608            }
609
610            // Now read the responses effectively in a batch
611            for cmd in &commands {
612                results.push(
613                    self.read_response(Some(cmd), cmd.client_timeout(&self.timeouts))
614                        .await,
615                );
616            }
617            return results;
618        }
619
620        for cmd in &commands {
621            if let Err(err) = self.write_command_request(cmd).await {
622                let err: ClientError = err;
623                results.push(Err(err.clone()));
624                while results.len() < commands.len() {
625                    // Synthesize failures for the remaining commands
626                    results.push(Err(err.clone()));
627                }
628                return results;
629            }
630            // Immediately request the response if the server
631            // doesn't support pipelining
632            results.push(
633                self.read_response(Some(cmd), cmd.client_timeout(&self.timeouts))
634                    .await,
635            );
636        }
637        results
638    }
639
640    pub async fn ehlo_lhlo(
641        &mut self,
642        ehlo_name: &str,
643        use_lmtp: bool,
644    ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
645        if use_lmtp {
646            self.lhlo(ehlo_name).await
647        } else {
648            self.ehlo(ehlo_name).await
649        }
650    }
651
652    pub async fn lhlo(
653        &mut self,
654        ehlo_name: &str,
655    ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
656        let response = self
657            .send_command(&Command::Lhlo(Domain::Name(ehlo_name.to_string())))
658            .await?;
659        self.ehlo_common(response)
660    }
661
662    pub async fn ehlo(
663        &mut self,
664        ehlo_name: &str,
665    ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
666        let response = self
667            .send_command(&Command::Ehlo(Domain::Name(ehlo_name.to_string())))
668            .await?;
669        self.ehlo_common(response)
670    }
671
672    fn ehlo_common(
673        &mut self,
674        response: Response,
675    ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
676        if response.code != 250 {
677            return Err(ClientError::Rejected(response));
678        }
679
680        let mut capabilities = HashMap::new();
681
682        for line in response.content.lines().skip(1) {
683            let mut fields = line.splitn(2, ' ');
684            if let Some(name) = fields.next() {
685                let param = fields.next().map(|s| s.to_string());
686                let cap = EsmtpCapability {
687                    name: name.to_string(),
688                    param,
689                };
690                capabilities.insert(name.to_ascii_uppercase(), cap);
691            }
692        }
693
694        self.capabilities = capabilities;
695        Ok(&self.capabilities)
696    }
697
698    pub async fn auth_plain(
699        &mut self,
700        username: &str,
701        password: Option<&str>,
702    ) -> Result<(), ClientError> {
703        // RFC 4616 says that the format is:
704        // [authzid] NUL authcid NUL passwd
705        let password = password.unwrap_or("");
706        let payload = format!("\x00{username}\x00{password}");
707        let payload = data_encoding::BASE64.encode(payload.as_bytes());
708
709        let response = self
710            .send_command(&Command::Auth {
711                sasl_mech: "PLAIN".to_string(),
712                initial_response: Some(payload),
713            })
714            .await?;
715
716        if response.code != 235 {
717            return Err(ClientError::Rejected(response));
718        }
719
720        Ok(())
721    }
722
723    /// Attempt TLS handshake.
724    /// Returns Err for IO errors.
725    /// On completion, return an option that will be:
726    /// * Some(handshake_error) - if the handshake failed
727    /// * None - if the handshake succeeded
728    pub async fn starttls(&mut self, options: TlsOptions) -> Result<TlsStatus, ClientError> {
729        let resp = self.send_command(&Command::StartTls).await?;
730        if resp.code != 220 {
731            return Err(ClientError::Rejected(resp));
732        }
733
734        let mut handshake_error = None;
735        let mut tls_info = TlsInformation::default();
736
737        let stream: BoxedAsyncReadAndWrite = if options.prefer_openssl
738            || !options.dane_tlsa.is_empty()
739        {
740            let connector = options
741                .build_openssl_connector(&self.hostname)
742                .map_err(|error| ClientError::FailedToBuildConnector {
743                    error: error.to_string(),
744                })?;
745            let ssl = connector.into_ssl(self.hostname.as_str())?;
746
747            let (stream, dup_stream) = match self.socket.take() {
748                Some(s) => {
749                    let d = s.try_dup();
750                    (s, d)
751                }
752                None => return Err(ClientError::NotConnected),
753            };
754
755            let mut ssl_stream = tokio_openssl::SslStream::new(ssl, stream)?;
756
757            if let Err(err) = std::pin::Pin::new(&mut ssl_stream).connect().await {
758                handshake_error.replace(format!("{err:#}"));
759            }
760
761            tls_info.provider_name = "openssl".to_string();
762            tls_info.cipher = match ssl_stream.ssl().current_cipher() {
763                Some(cipher) => cipher.standard_name().unwrap_or(cipher.name()).to_string(),
764                None => String::new(),
765            };
766            tls_info.protocol_version = ssl_stream.ssl().version_str().to_string();
767
768            if let Some(cert) = ssl_stream.ssl().peer_certificate() {
769                tls_info.subject_name = subject_name(&cert);
770            }
771            if let Ok(authority) = ssl_stream.ssl().dane_authority() {
772                if let Some(cert) = &authority.cert {
773                    tls_info.subject_name = subject_name(cert);
774                }
775            }
776
777            match (&handshake_error, dup_stream) {
778                (Some(_), Some(dup_stream)) if !ssl_stream.ssl().is_init_finished() => {
779                    // Try falling back to clear text on the duplicate stream.
780                    // This is imperfect: in a failed validation scenario we will
781                    // end up trying to read binary data as a string and get a UTF-8
782                    // error if the peer thinks the session is encrypted.
783                    drop(ssl_stream);
784                    Box::new(dup_stream)
785                }
786                _ => Box::new(ssl_stream),
787            }
788        } else {
789            tls_info.provider_name = "rustls".to_string();
790            let connector = options.build_tls_connector().await.map_err(|error| {
791                ClientError::FailedToBuildConnector {
792                    error: error.to_string(),
793                }
794            })?;
795            let server_name = match IpAddr::from_str(self.hostname.as_str()) {
796                Ok(ip) => ServerName::IpAddress(ip.into()),
797                Err(_) => ServerName::try_from(self.hostname.clone())
798                    .map_err(|_| ClientError::InvalidDnsName(self.hostname.clone()))?,
799            };
800
801            match connector
802                .connect(
803                    server_name,
804                    match self.socket.take() {
805                        Some(s) => s,
806                        None => return Err(ClientError::NotConnected),
807                    },
808                )
809                .into_fallible()
810                .await
811            {
812                Ok(stream) => {
813                    let (_, conn) = stream.get_ref();
814                    tls_info.cipher = match conn.negotiated_cipher_suite() {
815                        Some(suite) => suite.suite().as_str().unwrap_or("UNKNOWN").to_string(),
816                        None => String::new(),
817                    };
818                    tls_info.protocol_version = match conn.protocol_version() {
819                        Some(version) => version.as_str().unwrap_or("UNKNOWN").to_string(),
820                        None => String::new(),
821                    };
822
823                    if let Some(certs) = conn.peer_certificates() {
824                        let peer_cert = &certs[0];
825                        if let Ok(cert) = X509::from_der(peer_cert.as_ref()) {
826                            tls_info.subject_name = subject_name(&cert);
827                        }
828                    }
829
830                    Box::new(stream)
831                }
832                Err((err, stream)) => {
833                    handshake_error.replace(format!("{err:#}"));
834                    stream
835                }
836            }
837        };
838
839        if let Some(tracer) = &self.tracer {
840            tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
841                level: Level::INFO,
842                message: match &handshake_error {
843                    Some(error) => format!("STARTTLS handshake failed: {error:?}"),
844                    None => format!("STARTTLS handshake -> {tls_info:?}"),
845                },
846            });
847        }
848
849        self.socket.replace(stream);
850        Ok(match handshake_error {
851            Some(error) => TlsStatus::FailedHandshake(error),
852            None => TlsStatus::Info(tls_info),
853        })
854    }
855
856    pub async fn send_mail<B: AsRef<[u8]>, SENDER: Into<ReversePath>, RECIP: Into<ForwardPath>>(
857        &mut self,
858        sender: SENDER,
859        recipient: RECIP,
860        data: B,
861    ) -> Result<Response, ClientError> {
862        let recipient: ForwardPath = recipient.into();
863        let status = self
864            .send_mail_multi_recip(sender, vec![recipient], data)
865            .await?;
866        Ok(status.response)
867    }
868
869    pub async fn send_mail_multi_recip<B: AsRef<[u8]>, SENDER: Into<ReversePath>>(
870        &mut self,
871        sender: SENDER,
872        recipient_list: Vec<ForwardPath>,
873        data: B,
874    ) -> Result<BatchSendSuccess, ClientError> {
875        let sender = sender.into();
876
877        let data: &[u8] = data.as_ref();
878        let stuffed;
879
880        let data = match apply_dot_stuffing(data) {
881            Some(d) => {
882                stuffed = d;
883                &stuffed
884            }
885            None => data,
886        };
887
888        let data_is_8bit = data.iter().any(|&b| b >= 0x80);
889        let envelope_is_8bit =
890            !sender.is_ascii() || recipient_list.iter().any(|recipient| !recipient.is_ascii());
891
892        let mut mail_from_params = vec![];
893        if data_is_8bit {
894            if self.capabilities.contains_key("8BITMIME") {
895                mail_from_params.push(EsmtpParameter {
896                    name: "BODY".to_string(),
897                    value: Some("8BITMIME".to_string()),
898                });
899            } else if !self.ignore_8bit_checks {
900                return Err(ClientError::Rejected(Response {
901                    code: 554,
902                    command: None,
903                    enhanced_code: Some(EnhancedStatusCode {
904                        class: 5,
905                        subject: 6,
906                        detail: 3,
907                    }),
908                    content: "KumoMTA internal: DATA is 8bit, destination does \
909                        not support 8BITMIME. Conversion via msg:check_fix_conformance \
910                        during reception is required"
911                        .to_string(),
912                }));
913            }
914        }
915
916        if envelope_is_8bit {
917            if self.capabilities.contains_key("SMTPUTF8") {
918                mail_from_params.push(EsmtpParameter {
919                    name: "SMTPUTF8".to_string(),
920                    value: None,
921                });
922            } else if !self.ignore_8bit_checks {
923                return Err(ClientError::Rejected(Response {
924                    code: 554,
925                    command: None,
926                    enhanced_code: Some(EnhancedStatusCode {
927                        class: 5,
928                        subject: 6,
929                        detail: 7,
930                    }),
931                    content: "KumoMTA internal: envelope is 8bit, destination does \
932                        not support SMTPUTF8."
933                        .to_string(),
934                }));
935            }
936        }
937
938        let mut commands = vec![];
939
940        // We want to avoid using RSET for the first message we send on
941        // a given connection, because postfix can run in a mode where
942        // it will not tolerate RSET because it considers it to be a "junk"
943        // command, and rejecting junk commands will cut down on its load
944        // when it is under stress; it is used as a load shedding approach.
945        // If we always RSET then we will never deliver to a site that is
946        // configured that way. If we take care to RSET only for subsequent
947        // sends, then we should get one message per connection through
948        // without being unfairly penalized for defensively RSETing.
949        let used_rset = self.use_rset;
950        if self.use_rset {
951            commands.push(Command::Rset);
952        }
953        commands.push(Command::MailFrom {
954            address: sender,
955            parameters: mail_from_params,
956        });
957
958        for recipient in &recipient_list {
959            commands.push(Command::RcptTo {
960                address: recipient.clone(),
961                parameters: vec![],
962            });
963        }
964        commands.push(Command::Data);
965
966        // Assume that something might break below: if it does, we want
967        // to ensure that we RSET the connection on the next go around.
968        self.use_rset = true;
969
970        let mut responses = self.pipeline_commands(commands).await;
971
972        // This is a little awkward. We want to handle the RFC 2090 3.1 case
973        // below, which requires deferring checking the actual response codes
974        // until later, but we must also handle the case where we had a hard
975        // transport error partway through pipelining.
976        // So we set a flag for that case and will then "eagerly", wrt. the
977        // RFC 2090 3.1 logic, evaluate the SMTP response codes, so that we
978        // can propagate the correct error disposition up to the caller.
979        let is_err = responses.iter().any(|r| r.is_err());
980
981        if used_rset {
982            let rset_resp = responses.remove(0)?;
983            if rset_resp.code != 250 {
984                return Err(ClientError::Rejected(rset_resp));
985            }
986        }
987
988        let mail_resp = responses.remove(0)?;
989        if is_err && mail_resp.code != 250 {
990            return Err(ClientError::Rejected(mail_resp));
991        }
992
993        let mut rcpt_responses = vec![];
994        for _ in &recipient_list {
995            rcpt_responses.push(responses.remove(0)?);
996        }
997
998        if is_err && rcpt_responses.iter().all(|resp| resp.code != 250) {
999            return Err(ClientError::RejectedBatch(rcpt_responses));
1000        }
1001
1002        let data_resp = responses.remove(0)?;
1003        if is_err && data_resp.code != 354 {
1004            return Err(ClientError::Rejected(data_resp));
1005        }
1006
1007        if data_resp.code == 354
1008            && (mail_resp.code != 250 || rcpt_responses.iter().all(|resp| resp.code != 250))
1009        {
1010            // RFC 2920 3.1:
1011            // the client cannot assume that the DATA command will be rejected
1012            // just because none of the RCPT TO commands worked.  If the DATA
1013            // command was properly rejected the client SMTP can just issue
1014            // RSET, but if the DATA command was accepted the client SMTP
1015            // should send a single dot.
1016
1017            // Send dummy data
1018            self.write_data_with_timeout(b".\r\n").await?;
1019            let data_dot = Command::DataDot;
1020            // wait for its response
1021            let _ = self
1022                .read_response(Some(&data_dot), data_dot.client_timeout(&self.timeouts))
1023                .await?;
1024
1025            // Continue below: we will match one of the failure cases and
1026            // return a ClientError::Rejected from one of the earlier
1027            // commands
1028        }
1029
1030        if mail_resp.code != 250 {
1031            return Err(ClientError::Rejected(mail_resp));
1032        }
1033        if rcpt_responses.iter().all(|resp| resp.code != 250) {
1034            if rcpt_responses.len() == 1 {
1035                return Err(ClientError::Rejected(
1036                    rcpt_responses.pop().expect("have at least one"),
1037                ));
1038            }
1039            return Err(ClientError::RejectedBatch(rcpt_responses));
1040        }
1041        if data_resp.code != 354 {
1042            return Err(ClientError::Rejected(data_resp));
1043        }
1044
1045        let needs_newline = data.last().map(|&b| b != b'\n').unwrap_or(true);
1046
1047        tracing::trace!("message data is {} bytes", data.len());
1048
1049        self.write_data_with_timeout(data).await?;
1050
1051        let marker = if needs_newline { "\r\n.\r\n" } else { ".\r\n" };
1052
1053        tracing::trace!("send->{}: {}", self.hostname, marker.escape_debug());
1054
1055        self.write_data_with_timeout(marker.as_bytes()).await?;
1056
1057        let data_dot = Command::DataDot;
1058        let resp = self
1059            .read_response(Some(&data_dot), data_dot.client_timeout(&self.timeouts))
1060            .await?;
1061        if resp.code != 250 {
1062            return Err(ClientError::Rejected(resp));
1063        }
1064
1065        // If everything went well, respect the user preference for speculatively
1066        // issuing an RSET next time around
1067        self.use_rset = self.enable_rset;
1068
1069        Ok(BatchSendSuccess {
1070            response: resp,
1071            rcpt_responses,
1072        })
1073    }
1074}
1075
/// Outcome of a successful [`SmtpClient::send_mail_multi_recip`] call.
#[derive(Debug)]
pub struct BatchSendSuccess {
    /// The server's reply to the end-of-data `.`; always a 250 on success.
    pub response: Response,
    /// The replies to each `RCPT TO`, in the same order as the recipient
    /// list. At least one of them is a 250; the others may be rejections.
    pub rcpt_responses: Vec<Response>,
}
1081
/// Outcome of a STARTTLS handshake attempt.
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize, Clone)]
pub enum TlsStatus {
    /// The handshake failed; holds a human-readable description of the error.
    FailedHandshake(String),
    /// The handshake succeeded; holds details of the negotiated session.
    Info(TlsInformation),
}
1087
/// Details about a negotiated TLS session, recorded after a STARTTLS handshake.
#[derive(Debug, Default, PartialEq, Eq, Serialize, Deserialize, Clone)]
pub struct TlsInformation {
    /// Name of the negotiated cipher suite. The rustls path leaves this
    /// empty when no suite is reported, and records "UNKNOWN" for a suite
    /// without a name.
    pub cipher: String,
    /// Name of the negotiated protocol version. The rustls path uses the
    /// same empty / "UNKNOWN" conventions as `cipher`.
    pub protocol_version: String,
    /// `object=value` pairs from the peer certificate's subject name, as
    /// produced by [`subject_name`]. Empty if no certificate could be decoded.
    pub subject_name: Vec<String>,
    /// The TLS implementation that performed the handshake, eg: "rustls".
    pub provider_name: String,
}
1095
1096impl Drop for SmtpClient {
1097    fn drop(&mut self) {
1098        if let Some(tracer) = &self.tracer {
1099            if self.socket.is_some() {
1100                tracer.trace_event(SmtpClientTraceEvent::Closed);
1101            }
1102        }
1103    }
1104}
1105fn parse_response_line(line: &'_ str) -> Result<ResponseLine<'_>, ClientError> {
1106    if line.len() < 4 {
1107        return Err(ClientError::MalformedResponseLine(line.to_string()));
1108    }
1109
1110    match line.as_bytes()[3] {
1111        b' ' | b'-' => match line[0..3].parse::<u16>() {
1112            Ok(code) => Ok(ResponseLine {
1113                code,
1114                is_final: line.as_bytes()[3] == b' ',
1115                content: &line[4..],
1116            }),
1117            Err(_) => Err(ClientError::MalformedResponseLine(line.to_string())),
1118        },
1119        _ => Err(ClientError::MalformedResponseLine(line.to_string())),
1120    }
1121}
1122
1123fn apply_dot_stuffing(data: &[u8]) -> Option<Vec<u8>> {
1124    static LFDOT: LazyLock<Finder> = LazyLock::new(|| memchr::memmem::Finder::new("\n."));
1125
1126    if !data.starts_with(b".") && LFDOT.find(data).is_none() {
1127        return None;
1128    }
1129
1130    let mut stuffed = vec![];
1131    if data.starts_with(b".") {
1132        stuffed.push(b'.');
1133    }
1134    let mut last_idx = 0;
1135    for i in LFDOT.find_iter(data) {
1136        stuffed.extend_from_slice(&data[last_idx..=i]);
1137        stuffed.push(b'.');
1138        last_idx = i + 1;
1139    }
1140    stuffed.extend_from_slice(&data[last_idx..]);
1141    Some(stuffed)
1142}
1143
1144/// Extracts the object=name pairs of the subject name from a cert.
1145/// eg:
1146/// ```no_run
1147/// ["C=US", "ST=CA", "L=SanFrancisco", "O=Fort-Funston", "OU=MyOrganizationalUnit",
1148/// "CN=do.havedane.net", "name=EasyRSA", "emailAddress=me@myhost.mydomain"]
1149/// ```
1150pub fn subject_name(cert: &X509Ref) -> Vec<String> {
1151    let mut subject_name = vec![];
1152    for entry in cert.subject_name().entries() {
1153        if let Ok(obj) = entry.object().nid().short_name() {
1154            if let Ok(name) = entry.data().as_utf8() {
1155                subject_name.push(format!("{obj}={name}"));
1156            }
1157        }
1158    }
1159    subject_name
1160}
1161
// Unit tests for the helpers in this module: dot stuffing, SMTP reply line
// parsing, multi-line response assembly, hostname extraction and error
// formatting.
#[cfg(test)]
mod test {
    use super::*;
    use crate::{MailPath, Mailbox};

    // Lines starting with '.' (at the start of the data or after '\n') gain
    // one extra leading '.'. Data without any such line yields None.
    #[test]
    fn test_stuffing() {
        assert_eq!(apply_dot_stuffing(b"foo"), None);
        assert_eq!(apply_dot_stuffing(b".foo").unwrap(), b"..foo");
        assert_eq!(apply_dot_stuffing(b"foo\n.bar").unwrap(), b"foo\n..bar");
        assert_eq!(
            apply_dot_stuffing(b"foo\n.bar\n..baz\n").unwrap(),
            b"foo\n..bar\n...baz\n"
        );
    }

    // NOTE(review): this disabled manual test expects an SMTP sink listening
    // on 127.0.0.1:2025. Some of its calls no longer match the client API
    // used elsewhere in this file: eg, `read_response(None)` passes one
    // argument, while `read_response` is now called with a command and a
    // timeout. It would need updating before being re-enabled.
    /*
    #[tokio::test]
    async fn test_against_sink() {
        use tokio::net::TcpStream;
        let stream = TcpStream::connect("127.0.0.1:2025").await.unwrap();
        let mut client =
            SmtpClient::with_stream(stream, "localhost", SmtpClientTimeouts::default());
        dbg!(client.read_response(None).await).unwrap();
        dbg!(client.ehlo("localhost").await).unwrap();
        let insecure = true;
        dbg!(client.starttls(insecure).await).unwrap();
        let resp = client
            .send_mail(
                ReversePath::try_from("wez@mail.example.com").unwrap(),
                ForwardPath::try_from("wez@mail.example.com").unwrap(),
                "Subject: hello\r\n\r\nwoot\r\n",
            )
            .await
            .unwrap();
        panic!("{resp:#?}");
    }
    */

    // A space after the code marks the final line; a hyphen marks a
    // continuation line. Anything else is malformed.
    #[test]
    fn response_line_parsing() {
        assert_eq!(
            parse_response_line("220 woot").unwrap(),
            ResponseLine {
                code: 220,
                is_final: true,
                content: "woot"
            }
        );
        assert_eq!(
            parse_response_line("220-woot").unwrap(),
            ResponseLine {
                code: 220,
                is_final: false,
                content: "woot"
            }
        );

        assert!(matches!(
            parse_response_line("220_woot"),
            Err(ClientError::MalformedResponseLine(_))
        ));
        assert!(matches!(
            parse_response_line("not really"),
            Err(ClientError::MalformedResponseLine(_))
        ));
    }

    // Test helper: parses each line and folds them into a single Response via
    // ResponseBuilder, asserting that the last line is a final line.
    fn parse_multi_line(lines: &[&str]) -> Result<Response, ClientError> {
        let mut parsed = parse_response_line(lines[0])?;
        let mut b = ResponseBuilder::new(&parsed);
        for line in &lines[1..] {
            parsed = parse_response_line(line)?;
            b.add_line(&parsed)
                .map_err(ClientError::MalformedResponseLine)?;
        }
        assert!(parsed.is_final);
        Ok(b.build(None))
    }

    #[test]
    fn multi_line_response() {
        // Continuation contents are joined with '\n'.
        assert_eq!(
            parse_multi_line(&["220-woot", "220-more", "220 done",]).unwrap(),
            Response {
                code: 220,
                enhanced_code: None,
                content: "woot\nmore\ndone".to_string(),
                command: None
            }
        );

        // Every line must carry the same reply code as the first one; the
        // error names the first mismatching line.
        let res = parse_multi_line(&["220-woot", "221-more", "220 done"]).unwrap_err();
        assert!(
            matches!(
                    res,
                ClientError::MalformedResponseLine(ref err) if err == "221-more"
            ),
            "got error {res:?}"
        );

        let res = parse_multi_line(&["220-woot", "220-more", "221 done"]).unwrap_err();
        assert!(
            matches!(
                    res,
                ClientError::MalformedResponseLine(ref err) if err == "221 done"
            ),
            "got error {res:?}"
        );

        // The enhanced status code is taken from the first line.
        assert_eq!(
            parse_multi_line(&["220-4.1.0 woot", "220-more", "220 done",]).unwrap(),
            Response {
                code: 220,
                enhanced_code: Some(EnhancedStatusCode {
                    class: 4,
                    subject: 1,
                    detail: 0
                }),
                content: "woot\nmore\ndone".to_string(),
                command: None
            }
        );

        // Confirm that we strip the enhanced status code from each line
        assert_eq!(
            parse_multi_line(&["220-4.1.0 woot", "220-4.1.0 more", "220 done",]).unwrap(),
            Response {
                code: 220,
                enhanced_code: Some(EnhancedStatusCode {
                    class: 4,
                    subject: 1,
                    detail: 0
                }),
                content: "woot\nmore\ndone".to_string(),
                command: None
            }
        );

        // ... but only if the code matches that of the first line
        assert_eq!(
            parse_multi_line(&["220-4.1.0 woot", "220-4.1.0 more", "220 5.5.5 done",]).unwrap(),
            Response {
                code: 220,
                enhanced_code: Some(EnhancedStatusCode {
                    class: 4,
                    subject: 1,
                    detail: 0
                }),
                content: "woot\nmore\n5.5.5 done".to_string(),
                command: None
            }
        );
    }

    // Port suffixes, a trailing '.', and [] brackets are removed from the host.
    #[test]
    fn test_extract_hostname() {
        assert_eq!(extract_hostname("foo"), "foo");
        assert_eq!(extract_hostname("foo."), "foo");
        assert_eq!(extract_hostname("foo:25"), "foo");
        assert_eq!(extract_hostname("foo.:25"), "foo");
        assert_eq!(extract_hostname("[foo]:25"), "foo");
        assert_eq!(extract_hostname("[foo.]:25"), "foo");
        assert_eq!(extract_hostname("[::1]:25"), "::1");
        assert_eq!(extract_hostname("::1:25"), "::1");
    }

    // Timeout errors render the associated command in its wire form.
    #[test]
    fn test_format_error_command() {
        assert_eq!(
            format!(
                "{:#}",
                ClientError::TimeOutRequest {
                    commands: vec![Command::DataDot],
                    duration: Duration::from_secs(10),
                }
            ),
            "Timed Out after 10s writing cmd=."
        );
        assert_eq!(
            format!(
                "{:#}",
                ClientError::TimeOutResponse {
                    command: Some(Command::MailFrom {
                        address: ReversePath::Path(MailPath {
                            at_domain_list: vec![],
                            mailbox: Mailbox {
                                local_part: "user".to_string(),
                                domain: Domain::Name("host".to_string())
                            }
                        }),
                        parameters: vec![],
                    }),
                    duration: Duration::from_secs(10),
                }
            ),
            r#"Timed Out waiting 10s for response to cmd=MAIL FROM:<user@host>"#
        );
    }
}