rfc5321/
client.rs

1#![allow(clippy::result_large_err)]
2use crate::client_types::*;
3use crate::parser::{Command, Domain, EsmtpParameter, ForwardPath, ReversePath};
4use crate::{AsyncReadAndWrite, BoxedAsyncReadAndWrite};
5use bstr::ByteSlice;
6use hickory_proto::rr::rdata::TLSA;
7use hickory_proto::rr::Name;
8use memchr::memmem::Finder;
9use nom_utils::DomainString;
10use openssl::x509::{X509Ref, X509};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::net::IpAddr;
14use std::str::FromStr;
15use std::sync::{Arc, LazyLock};
16use std::time::Duration;
17use thiserror::Error;
18use tokio::io::{AsyncReadExt, AsyncWriteExt};
19use tokio::net::{TcpStream, ToSocketAddrs};
20use tokio::time::timeout;
21use tokio_rustls::rustls::pki_types::ServerName;
22use tracing::Level;
23
24pub use kumo_tls_helper::TlsOptions;
25pub use openssl;
26pub use tokio_rustls;
27
28const MAX_LINE_LEN: usize = 4096;
29
/// Errors produced while talking to an SMTP peer through `SmtpClient`.
#[derive(Error, Debug, Clone)]
pub enum ClientError {
    /// A response line could not be decoded as UTF-8
    #[error("response is not UTF8")]
    Utf8(#[from] std::string::FromUtf8Error),
    /// A response line could not be combined into a well-formed response;
    /// the payload is the explanation
    #[error("Malformed Response: {0}")]
    MalformedResponseLine(String),
    /// A response line exceeded the internal line length limit
    #[error("Response line is too long")]
    ResponseTooLong,
    /// An operation required a socket but none is present
    #[error("Not connected")]
    NotConnected,
    /// The peer answered a command with an unexpected response code
    #[error("Command rejected {0:?}")]
    Rejected(Response),
    /// Like `Rejected`, but carrying the responses to several commands
    #[error("Commands rejected {0:?}")]
    RejectedBatch(Vec<Response>),
    /// The supplied name could not be parsed as a DNS name
    #[error("STARTTLS: {0} is not a valid DNS name")]
    InvalidDnsName(String),
    /// The TLS connector could not be constructed from the provided options
    #[error("Invalid client certificate configured: {error:?}")]
    FailedToBuildConnector { error: String },
    /// No data arrived within `duration` while waiting for the response
    /// to `command` (`None` when no command was outstanding)
    #[error("Timed Out waiting {duration:?} for response to cmd={}", self.command_for_err())]
    TimeOutResponse {
        command: Option<Command>,
        duration: Duration,
    },
    /// Writing `commands` to the socket did not complete within `duration`
    #[error("Timed Out after {duration:?} writing cmd={}", self.command_for_err())]
    TimeOutRequest {
        commands: Vec<Command>,
        duration: Duration,
    },
    /// Reading the response to `command` failed; `partial` holds a lossy
    /// UTF-8 rendering of whatever was buffered at the time
    #[error("Error {error} reading response to cmd={}", self.command_for_err())]
    ReadError {
        command: Option<Command>,
        error: String,
        partial: String,
    },
    /// Flushing the socket prior to reading a response failed
    #[error("Error {error} flushing send buffer")]
    FlushError { error: String },
    /// Writing to the socket failed; `commands` is empty when the
    /// write was message payload data rather than commands
    #[error("Error {error} writing {}", self.command_for_err())]
    WriteError {
        commands: Vec<Command>,
        error: String,
    },
    /// Writing message payload data did not complete in time
    #[error("Timed Out sending message payload data")]
    TimeOutData,
    /// An error reported by OpenSSL
    #[error("SSL Error: {0}")]
    SslErrorStack(#[from] openssl::error::ErrorStack),
    /// None of the DANE TLSA records in `tlsa` were usable for `hostname`
    #[error("No usable DANE TLSA records for {hostname}: {tlsa:?}")]
    NoUsableDaneTlsa { hostname: String, tlsa: Vec<TLSA> },
}
78
79impl ClientError {
80    /// Returns the command(s) string suitable for passing into a Response
81    pub fn command(&self) -> Option<String> {
82        match self {
83            Self::TimeOutResponse {
84                command: Some(command),
85                ..
86            }
87            | Self::ReadError {
88                command: Some(command),
89                ..
90            } => Some(command.encode().to_string()),
91            Self::TimeOutRequest { commands, .. } | Self::WriteError { commands, .. }
92                if !commands.is_empty() =>
93            {
94                let s: String = commands
95                    .iter()
96                    .map(|cmd| cmd.encode().to_string())
97                    .collect();
98                Some(s)
99            }
100            _ => None,
101        }
102    }
103
104    /// Returns the command(s) string suitable for embedding in the
105    /// overall error message
106    fn command_for_err(&self) -> String {
107        self.command()
108            .map(|cmd| cmd.replace("\r\n", ""))
109            .unwrap_or_else(|| "NONE".to_string())
110    }
111
112    /// If the error contents were likely caused by something
113    /// about the mostly recently attempted message, rather than
114    /// a transport issue, or a carry-over from a prior message
115    /// (eg: previous message was rejected and destination chose
116    /// to drop the connection, which we detect later in RSET
117    /// on the next message), then we return true.
118    /// The expectation is that the caller will transiently
119    /// fail the message for later retry.
120    /// If we return false then the caller might decide to
121    /// try that same message again more immediately on
122    /// a separate connection
123    pub fn was_due_to_message(&self) -> bool {
124        match self {
125            Self::Utf8(_)
126            | Self::MalformedResponseLine(_)
127            | Self::ResponseTooLong
128            | Self::NotConnected
129            | Self::InvalidDnsName(_)
130            | Self::FailedToBuildConnector { .. }
131            | Self::TimeOutResponse { .. }
132            | Self::TimeOutRequest { .. }
133            | Self::ReadError { .. }
134            | Self::FlushError { .. }
135            | Self::WriteError { .. }
136            | Self::TimeOutData
137            | Self::SslErrorStack(_)
138            | Self::NoUsableDaneTlsa { .. } => false,
139            Self::Rejected(response) => response.was_due_to_message(),
140            Self::RejectedBatch(responses) => match responses.len() {
141                1 => responses[0].was_due_to_message(),
142                _ => false,
143            },
144        }
145    }
146}
147
/// A single capability advertised by the peer in its EHLO/LHLO response.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EsmtpCapability {
    /// The capability keyword exactly as the peer sent it
    pub name: String,
    /// Whatever followed the first space on the capability line, if anything
    pub param: Option<String>,
}
153
/// Events reported to an `SmtpClientTracer` as the client operates.
#[derive(Clone, Debug)]
pub enum SmtpClientTraceEvent {
    /// The client discarded its socket
    Closed,
    /// Raw bytes received from the peer
    Read(Vec<u8>),
    /// Data sent to the peer, rendered as text
    Write(String),
    /// A human readable note, such as an IO error or timeout description
    Diagnostic {
        level: tracing::Level,
        message: String,
    },
}
164
/// Something that can build an `SmtpClientTraceEvent` on demand, so that
/// the cost of building it is only paid when `trace` is called.
pub trait DeferredTracer {
    /// Construct the trace event
    fn trace(&self) -> SmtpClientTraceEvent;
}
168
/// Receiver for trace events emitted by `SmtpClient`.
pub trait SmtpClientTracer: std::fmt::Debug {
    /// Called with an already constructed event
    fn trace_event(&self, event: SmtpClientTraceEvent);
    /// Called with a deferred event; presumably the implementation calls
    /// `deferred.trace()` only when it wants the event — confirm against
    /// the implementors
    fn lazy_trace(&self, deferred: &dyn DeferredTracer);
}
173
174// helper to avoid making a second copy of every write buffer
175struct WriteTracer<'a> {
176    data: &'a str,
177}
178impl DeferredTracer for WriteTracer<'_> {
179    fn trace(&self) -> SmtpClientTraceEvent {
180        SmtpClientTraceEvent::Write(self.data.to_string())
181    }
182}
183impl<'a> WriteTracer<'a> {
184    fn trace(tracer: &Arc<dyn SmtpClientTracer + Send + Sync>, data: &'a str) {
185        tracer.lazy_trace(&Self { data });
186    }
187}
188
189struct BinWriteTracer<'a> {
190    data: &'a [u8],
191}
192impl DeferredTracer for BinWriteTracer<'_> {
193    fn trace(&self) -> SmtpClientTraceEvent {
194        let data = String::from_utf8_lossy(self.data).to_string();
195        SmtpClientTraceEvent::Write(data)
196    }
197}
198impl<'a> BinWriteTracer<'a> {
199    fn trace(tracer: &Arc<dyn SmtpClientTracer + Send + Sync>, data: &'a [u8]) {
200        tracer.lazy_trace(&Self { data });
201    }
202}
203
204// A little bit of gymnastics to avoid making a second
205// copy of every read buffer
206struct ReadTracer<'a> {
207    data: &'a [u8],
208}
209impl DeferredTracer for ReadTracer<'_> {
210    fn trace(&self) -> SmtpClientTraceEvent {
211        SmtpClientTraceEvent::Read(self.data.to_vec())
212    }
213}
214
/// An SMTP/LMTP client session over a single connection.
#[derive(Debug)]
pub struct SmtpClient {
    /// The connection; set to `None` once it has been discarded after
    /// an IO error, timeout or close by the peer
    socket: Option<BoxedAsyncReadAndWrite>,
    /// The peer name with any `:port`, brackets and trailing dot removed
    hostname: String,
    /// Capabilities from the most recent successful EHLO/LHLO,
    /// keyed by the upper-cased capability name
    capabilities: HashMap<String, EsmtpCapability>,
    /// Bytes received from the peer that have not yet been consumed as lines
    read_buffer: Vec<u8>,
    /// Timeout configuration used when writing commands and reading responses
    timeouts: SmtpClientTimeouts,
    /// Optional sink for trace events
    tracer: Option<Arc<dyn SmtpClientTracer + Send + Sync>>,
    // NOTE(review): use_rset and enable_rset are only initialized/assigned
    // in the visible portion of this file; confirm how they interact
    use_rset: bool,
    enable_rset: bool,
    /// When true, commands are pipelined if the peer advertised PIPELINING
    enable_pipelining: bool,
    /// When true, 8BITMIME and SMTPUTF8 checks are ignored
    ignore_8bit_checks: bool,
}
228
/// Extract just the host portion from a `host`, `host:port`,
/// `[host]:port` or `[host]` string.
///
/// * A trailing `:port` is removed; everything after the last `:` is
///   treated as the port, so `::1:25` yields `::1`.
/// * Enclosing square brackets are removed. A fully bracketed input such
///   as `[::1]` carries no port, so its colons are left alone.
/// * A trailing FQDN dot is removed.
///
/// Never panics: malformed input such as a lone `[` is returned as-is.
fn extract_hostname(hostname: &str) -> &str {
    // Returns the text between an enclosing `[` `]` pair, if present
    fn strip_brackets(s: &str) -> Option<&str> {
        s.strip_prefix('[')?.strip_suffix(']')
    }

    let host = match strip_brackets(hostname) {
        // Entirely bracketed: there is no :port to remove
        Some(inner) => inner,
        None => {
            // Just the hostname, without any :port
            let without_port = hostname
                .rsplit_once(':')
                .map_or(hostname, |(host, _port)| host);
            strip_brackets(without_port).unwrap_or(without_port)
        }
    };

    // Remove any trailing FQDN dot
    host.strip_suffix('.').unwrap_or(host)
}
247
248impl SmtpClient {
249    pub async fn new<A: ToSocketAddrs + ToString + Clone>(
250        addr: A,
251        timeouts: SmtpClientTimeouts,
252    ) -> std::io::Result<Self> {
253        let stream = TcpStream::connect(addr.clone()).await?;
254        // No need for Nagle with SMTP request/response
255        stream.set_nodelay(true)?;
256        Ok(Self::with_stream(stream, addr.to_string(), timeouts))
257    }
258
259    pub fn with_stream<S: AsyncReadAndWrite + 'static, H: AsRef<str>>(
260        stream: S,
261        peer_hostname: H,
262        timeouts: SmtpClientTimeouts,
263    ) -> Self {
264        let hostname = extract_hostname(peer_hostname.as_ref()).to_string();
265
266        Self {
267            socket: Some(Box::new(stream)),
268            hostname,
269            capabilities: HashMap::new(),
270            read_buffer: Vec::with_capacity(1024),
271            timeouts,
272            tracer: None,
273            use_rset: false,
274            enable_rset: false,
275            enable_pipelining: false,
276            ignore_8bit_checks: false,
277        }
278    }
279
    /// Setting this to true causes 8BITMIME and SMTPUTF8 checks to be ignored.
    ///
    /// This only records the flag; the checks themselves are not part of
    /// this method.
    pub fn set_ignore_8bit_checks(&mut self, enable: bool) {
        self.ignore_8bit_checks = enable;
    }
284
    /// Returns true while the client still holds a socket.
    ///
    /// The socket is discarded after IO errors and non-probe timeouts,
    /// so false means the session can no longer be used.
    pub fn is_connected(&self) -> bool {
        self.socket.is_some()
    }
288
    /// Record whether RSET is enabled for this client.
    // NOTE(review): the consumer of this flag is outside the visible
    // portion of the file; presumably it controls issuing RSET between
    // messages — confirm
    pub fn set_enable_rset(&mut self, enable: bool) {
        self.enable_rset = enable;
    }
292
    /// Allow (or forbid) pipelining of commands.
    ///
    /// Even when enabled, commands are only pipelined if the peer
    /// advertised the PIPELINING capability.
    pub fn set_enable_pipelining(&mut self, enable: bool) {
        self.enable_pipelining = enable;
    }
296
297    pub fn set_tracer(&mut self, tracer: Arc<dyn SmtpClientTracer + Send + Sync>) {
298        self.tracer.replace(tracer);
299    }
300
    /// Returns the timeout configuration this client was created with.
    pub fn timeouts(&self) -> &SmtpClientTimeouts {
        &self.timeouts
    }
304
305    async fn read_line(
306        &mut self,
307        timeout_duration: Duration,
308        cmd: Option<&Command>,
309    ) -> Result<String, ClientError> {
310        let mut too_long = false;
311        loop {
312            let mut iter = self.read_buffer.iter().enumerate();
313            while let Some((i, &b)) = iter.next() {
314                if b != b'\r' {
315                    continue;
316                }
317                if let Some((_, b'\n')) = iter.next() {
318                    if too_long {
319                        self.read_buffer.drain(0..i + 2);
320
321                        if let Some(tracer) = &self.tracer {
322                            tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
323                                level: Level::ERROR,
324                                message: "Response too long".to_string(),
325                            });
326                        }
327
328                        return Err(ClientError::ResponseTooLong);
329                    }
330
331                    let line = String::from_utf8(self.read_buffer[0..i].to_vec());
332                    self.read_buffer.drain(0..i + 2);
333                    return Ok(line?);
334                }
335            }
336            if self.read_buffer.len() > MAX_LINE_LEN {
337                self.read_buffer.clear();
338                too_long = true;
339            }
340
341            // Didn't find a complete line, fill up the rest of the buffer
342            let mut data = [0u8; MAX_LINE_LEN];
343            let size = match self.socket.as_mut() {
344                Some(s) => match timeout(timeout_duration, s.read(&mut data)).await {
345                    Ok(Ok(size)) => size,
346                    Ok(Err(err)) => {
347                        self.socket.take();
348                        if let Some(tracer) = &self.tracer {
349                            tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
350                                level: Level::ERROR,
351                                message: format!("Error during read: {err:#}"),
352                            });
353                            tracer.trace_event(SmtpClientTraceEvent::Closed);
354                        }
355                        return Err(ClientError::ReadError {
356                            command: cmd.cloned(),
357                            error: format!("{err:#}"),
358                            partial: String::from_utf8_lossy(&self.read_buffer).to_string(),
359                        });
360                    }
361                    Err(_) => {
362                        // If we're using a zero timeout we're probing
363                        // for liveness, so we don't want to treat this
364                        // timeout as a fatal error that closes the
365                        // connection
366                        if timeout_duration != Duration::ZERO {
367                            self.socket.take();
368                        }
369                        if let Some(tracer) = &self.tracer {
370                            tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
371                                level: Level::ERROR,
372                                message: format!("Read Timeout after {timeout_duration:?}"),
373                            });
374                            if self.socket.is_none() {
375                                tracer.trace_event(SmtpClientTraceEvent::Closed);
376                            }
377                        }
378                        return Err(ClientError::TimeOutResponse {
379                            command: cmd.cloned(),
380                            duration: timeout_duration,
381                        });
382                    }
383                },
384                None => {
385                    return Err(ClientError::ReadError {
386                        command: cmd.cloned(),
387                        error: "the socket was closed in response to an earlier issue".to_string(),
388                        partial: String::from_utf8_lossy(&self.read_buffer).to_string(),
389                    });
390                }
391            };
392            if size == 0 {
393                self.socket.take();
394                if let Some(tracer) = &self.tracer {
395                    tracer.trace_event(SmtpClientTraceEvent::Closed);
396                }
397                return Err(ClientError::ReadError {
398                    command: cmd.cloned(),
399                    error: "Connection closed by peer".to_string(),
400                    partial: String::from_utf8_lossy(&self.read_buffer).to_string(),
401                });
402            }
403            if let Some(tracer) = &self.tracer {
404                tracer.lazy_trace(&ReadTracer {
405                    data: &data[0..size],
406                });
407            }
408            self.read_buffer.extend_from_slice(&data[0..size]);
409        }
410    }
411
412    /// Check to see if there is either a unilateral response
413    /// or an error condition immediately available on the socket.
414    ///
415    /// Intended to be used after the connection has idled to see if
416    /// the peer has opted to close the connection, either gracefully
417    /// by sending what should be a 421, or abruptly by simply snipping
418    /// the connection.
419    pub async fn check_unilateral_response(&mut self) -> Result<Option<Response>, ClientError> {
420        match self.read_response(None, Duration::ZERO).await {
421            // Unilateral response was received
422            Ok(response) => Ok(Some(response)),
423            // Nothing immediately available: this should hopefully be the common case
424            Err(ClientError::TimeOutResponse { .. }) => Ok(None),
425            // An actual error
426            Err(err) => Err(err),
427        }
428    }
429
430    pub async fn read_response(
431        &mut self,
432        command: Option<&Command>,
433        timeout_duration: Duration,
434    ) -> Result<Response, ClientError> {
435        if let Some(sock) = self.socket.as_mut() {
436            if let Err(err) = sock.flush().await {
437                self.socket.take();
438                if let Some(tracer) = &self.tracer {
439                    tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
440                        level: Level::ERROR,
441                        message: format!("Error during flush: {err:#}"),
442                    });
443                    tracer.trace_event(SmtpClientTraceEvent::Closed);
444                }
445                return Err(ClientError::FlushError {
446                    error: format!("{err:#}"),
447                });
448            }
449        }
450
451        let mut line = self.read_line(timeout_duration, command).await?;
452        tracing::trace!("recv<-{}: {line}", self.hostname);
453        let mut parsed = parse_response_line(&line)?;
454        let mut response_builder = ResponseBuilder::new(&parsed);
455
456        let subsequent_line_timeout_duration = Duration::from_secs(60).min(timeout_duration);
457        while !parsed.is_final {
458            line = self
459                .read_line(subsequent_line_timeout_duration, command)
460                .await?;
461            parsed = parse_response_line(&line)?;
462            response_builder
463                .add_line(&parsed)
464                .map_err(ClientError::MalformedResponseLine)?;
465        }
466
467        let response = response_builder.build(command.map(|cmd| cmd.encode().to_string()));
468
469        tracing::trace!("{}: {response:?}", self.hostname);
470
471        Ok(response)
472    }
473
474    pub async fn send_command(&mut self, command: &Command) -> Result<Response, ClientError> {
475        self.write_command_request(command).await?;
476        self.read_response(Some(command), command.client_timeout(&self.timeouts))
477            .await
478    }
479
480    /// Wrapper around socket.write_all() that will emit trace diagnostics and synthesize
481    /// a Close event to the tracer if a timeout or IO error occurs.
482    /// If an error or timeout, occurs ensures that the socket will not be reused.
483    async fn write_all_with_timeout<F, G>(
484        &mut self,
485        timeout_duration: Duration,
486        bytes: &[u8],
487        make_timeout_err: F,
488        make_write_err: G,
489    ) -> Result<(), ClientError>
490    where
491        F: FnOnce() -> ClientError,
492        G: FnOnce(String) -> ClientError,
493    {
494        match self.socket.as_mut() {
495            Some(socket) => match timeout(timeout_duration, socket.write_all(bytes)).await {
496                Ok(Ok(response)) => Ok(response),
497                Ok(Err(err)) => {
498                    self.socket.take();
499                    if let Some(tracer) = &self.tracer {
500                        tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
501                            level: Level::ERROR,
502                            message: format!("Error during write: {err:#}"),
503                        });
504                        tracer.trace_event(SmtpClientTraceEvent::Closed);
505                    }
506                    Err(make_write_err(format!("{err:#}")))
507                }
508                Err(_) => {
509                    self.socket.take();
510                    if let Some(tracer) = &self.tracer {
511                        tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
512                            level: Level::ERROR,
513                            message: format!("Write Timeout after {timeout_duration:?}"),
514                        });
515                        tracer.trace_event(SmtpClientTraceEvent::Closed);
516                    }
517                    Err(make_timeout_err())
518                }
519            },
520            None => Err(make_write_err(
521                "the socket was closed in response to an earlier issue".to_string(),
522            )),
523        }
524    }
525
526    async fn write_pipeline_request(&mut self, commands: &[Command]) -> Result<(), ClientError> {
527        let total_timeout: Duration = commands
528            .iter()
529            .map(|cmd| cmd.client_timeout_request(&self.timeouts))
530            .sum();
531
532        let mut lines: Vec<String> = vec![];
533        let mut all: Vec<u8> = vec![];
534        for cmd in commands {
535            let line = cmd.encode();
536            all.extend_from_slice(&line);
537            lines.push(line.to_string());
538        }
539        tracing::trace!(
540            "send->{}: (PIPELINE) {}",
541            self.hostname,
542            all.as_bstr().escape_bytes()
543        );
544        if self.socket.is_some() {
545            if let Some(tracer) = &self.tracer {
546                // Send the lines individually to the tracer, so that we
547                // don't break --terse mode
548                for line in &lines {
549                    WriteTracer::trace(tracer, line);
550                }
551            }
552        }
553        self.write_all_with_timeout(
554            total_timeout,
555            &all,
556            || ClientError::TimeOutRequest {
557                duration: total_timeout,
558                commands: commands.to_vec(),
559            },
560            |error| ClientError::WriteError {
561                error,
562                commands: commands.to_vec(),
563            },
564        )
565        .await
566    }
567
568    async fn write_command_request(&mut self, command: &Command) -> Result<(), ClientError> {
569        let line = command.encode();
570        tracing::trace!("send->{}: {line}", self.hostname);
571        if self.socket.is_some() {
572            if let Some(tracer) = &self.tracer {
573                WriteTracer::trace(tracer, &line.to_string());
574            }
575        }
576
577        let timeout_duration = command.client_timeout_request(&self.timeouts);
578        self.write_all_with_timeout(
579            timeout_duration,
580            &line,
581            || ClientError::TimeOutRequest {
582                duration: timeout_duration,
583                commands: vec![command.clone()],
584            },
585            |error| ClientError::WriteError {
586                error,
587                commands: vec![command.clone()],
588            },
589        )
590        .await
591    }
592
593    async fn write_data_with_timeout(&mut self, data: &[u8]) -> Result<(), ClientError> {
594        if self.socket.is_some() {
595            if let Some(tracer) = &self.tracer {
596                BinWriteTracer::trace(tracer, data);
597            }
598        }
599        let timeout_duration = Command::Data.client_timeout_request(&self.timeouts);
600        self.write_all_with_timeout(
601            timeout_duration,
602            data,
603            || ClientError::TimeOutData,
604            |error| ClientError::WriteError {
605                error,
606                commands: vec![],
607            },
608        )
609        .await
610    }
611
612    /// Issue a series of commands, and return the responses to
613    /// those commands.
614    ///
615    /// If the server advertised the RFC 2920 PIPELINING extension,
616    /// the commands are written one after the other before waiting
617    /// to read any data, resulting in lower overall latency due
618    /// to round-trip-times.
619    ///
620    /// If PIPELINING is not available, each command is written
621    /// and the response read before attempting to write the next
622    /// command.
623    ///
624    /// The number of returned responses may be smaller than the
625    /// number of requested commands if there is an issue with
626    /// the network connection.
627    pub async fn pipeline_commands(
628        &mut self,
629        commands: Vec<Command>,
630    ) -> Vec<Result<Response, ClientError>> {
631        let mut results: Vec<Result<Response, ClientError>> = vec![];
632
633        let pipeline = self.enable_pipelining && self.capabilities.contains_key("PIPELINING");
634        if pipeline {
635            if let Err(err) = self.write_pipeline_request(&commands).await {
636                let err: ClientError = err;
637                results.push(Err(err.clone()));
638                while results.len() < commands.len() {
639                    // Synthesize failures for the remaining commands
640                    results.push(Err(err.clone()));
641                }
642                return results;
643            }
644
645            // Now read the responses effectively in a batch
646            for cmd in &commands {
647                results.push(
648                    self.read_response(Some(cmd), cmd.client_timeout(&self.timeouts))
649                        .await,
650                );
651            }
652            return results;
653        }
654
655        for cmd in &commands {
656            if let Err(err) = self.write_command_request(cmd).await {
657                let err: ClientError = err;
658                results.push(Err(err.clone()));
659                while results.len() < commands.len() {
660                    // Synthesize failures for the remaining commands
661                    results.push(Err(err.clone()));
662                }
663                return results;
664            }
665            // Immediately request the response if the server
666            // doesn't support pipelining
667            results.push(
668                self.read_response(Some(cmd), cmd.client_timeout(&self.timeouts))
669                    .await,
670            );
671        }
672        results
673    }
674
675    pub async fn ehlo_lhlo(
676        &mut self,
677        ehlo_name: &str,
678        use_lmtp: bool,
679    ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
680        if use_lmtp {
681            self.lhlo(ehlo_name).await
682        } else {
683            self.ehlo(ehlo_name).await
684        }
685    }
686
687    pub async fn lhlo(
688        &mut self,
689        ehlo_name: &str,
690    ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
691        let response = self
692            .send_command(&Command::Lhlo(Domain::DomainName(
693                ehlo_name
694                    .parse::<DomainString>()
695                    .map_err(|_| ClientError::InvalidDnsName(ehlo_name.to_string()))?,
696            )))
697            .await?;
698        self.ehlo_common(response)
699    }
700
701    pub async fn ehlo(
702        &mut self,
703        ehlo_name: &str,
704    ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
705        let response = self
706            .send_command(&Command::Ehlo(Domain::DomainName(
707                ehlo_name
708                    .parse::<DomainString>()
709                    .map_err(|_| ClientError::InvalidDnsName(ehlo_name.to_string()))?,
710            )))
711            .await?;
712        self.ehlo_common(response)
713    }
714
715    fn ehlo_common(
716        &mut self,
717        response: Response,
718    ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
719        if response.code != 250 {
720            return Err(ClientError::Rejected(response));
721        }
722
723        let mut capabilities = HashMap::new();
724
725        for line in response.content.lines().skip(1) {
726            let mut fields = line.splitn(2, ' ');
727            if let Some(name) = fields.next() {
728                let param = fields.next().map(|s| s.to_string());
729                let cap = EsmtpCapability {
730                    name: name.to_string(),
731                    param,
732                };
733                capabilities.insert(name.to_ascii_uppercase(), cap);
734            }
735        }
736
737        self.capabilities = capabilities;
738        Ok(&self.capabilities)
739    }
740
741    pub async fn auth_plain(
742        &mut self,
743        username: &str,
744        password: Option<&str>,
745    ) -> Result<(), ClientError> {
746        // RFC 4616 says that the format is:
747        // [authzid] NUL authcid NUL passwd
748        let password = password.unwrap_or("");
749        let payload = format!("\x00{username}\x00{password}");
750        let payload = data_encoding::BASE64.encode(payload.as_bytes());
751
752        let response = self
753            .send_command(&Command::Auth {
754                sasl_mech: "PLAIN".to_string(),
755                initial_response: Some(payload),
756            })
757            .await?;
758
759        if response.code != 235 {
760            return Err(ClientError::Rejected(response));
761        }
762
763        Ok(())
764    }
765
766    /// Attempt TLS handshake.
767    /// Returns Err for IO errors.
768    /// On completion, return an option that will be:
769    /// * Some(handshake_error) - if the handshake failed
770    /// * None - if the handshake succeeded
771    pub async fn starttls(&mut self, options: TlsOptions) -> Result<TlsStatus, ClientError> {
772        let resp = self.send_command(&Command::StartTls).await?;
773        if resp.code != 220 {
774            return Err(ClientError::Rejected(resp));
775        }
776
777        let mut handshake_error = None;
778        let mut tls_info = TlsInformation::default();
779
780        let stream: BoxedAsyncReadAndWrite = if options.prefer_openssl
781            || !options.dane_tlsa.is_empty()
782        {
783            let connector = options
784                .build_openssl_connector(&self.hostname)
785                .map_err(|error| ClientError::FailedToBuildConnector {
786                    error: error.to_string(),
787                })?;
788            let ssl = connector.into_ssl(self.hostname.as_str())?;
789
790            let (stream, dup_stream) = match self.socket.take() {
791                Some(s) => {
792                    let d = s.try_dup();
793                    (s, d)
794                }
795                None => return Err(ClientError::NotConnected),
796            };
797
798            let mut ssl_stream = tokio_openssl::SslStream::new(ssl, stream)?;
799
800            if let Err(err) = std::pin::Pin::new(&mut ssl_stream).connect().await {
801                handshake_error.replace(format!("{err:#}"));
802            }
803
804            tls_info.provider_name = "openssl".to_string();
805            tls_info.cipher = match ssl_stream.ssl().current_cipher() {
806                Some(cipher) => cipher.standard_name().unwrap_or(cipher.name()).to_string(),
807                None => String::new(),
808            };
809            tls_info.protocol_version = ssl_stream.ssl().version_str().to_string();
810
811            if let Some(cert) = ssl_stream.ssl().peer_certificate() {
812                tls_info.subject_name = subject_name(&cert);
813            }
814            if let Ok(authority) = ssl_stream.ssl().dane_authority() {
815                if let Some(cert) = &authority.cert {
816                    tls_info.subject_name = subject_name(cert);
817                }
818            }
819
820            match (&handshake_error, dup_stream) {
821                (Some(_), Some(dup_stream)) if !ssl_stream.ssl().is_init_finished() => {
822                    // Try falling back to clear text on the duplicate stream.
823                    // This is imperfect: in a failed validation scenario we will
824                    // end up trying to read binary data as a string and get a UTF-8
825                    // error if the peer thinks the session is encrypted.
826                    drop(ssl_stream);
827                    Box::new(dup_stream)
828                }
829                _ => Box::new(ssl_stream),
830            }
831        } else {
832            tls_info.provider_name = "rustls".to_string();
833            let connector = options.build_tls_connector().await.map_err(|error| {
834                ClientError::FailedToBuildConnector {
835                    error: error.to_string(),
836                }
837            })?;
838            let server_name = parse_server_name(self.hostname.as_str())?;
839
840            match connector
841                .connect(
842                    server_name,
843                    match self.socket.take() {
844                        Some(s) => s,
845                        None => return Err(ClientError::NotConnected),
846                    },
847                )
848                .into_fallible()
849                .await
850            {
851                Ok(stream) => {
852                    let (_, conn) = stream.get_ref();
853                    tls_info.cipher = match conn.negotiated_cipher_suite() {
854                        Some(suite) => suite.suite().as_str().unwrap_or("UNKNOWN").to_string(),
855                        None => String::new(),
856                    };
857                    tls_info.protocol_version = match conn.protocol_version() {
858                        Some(version) => version.as_str().unwrap_or("UNKNOWN").to_string(),
859                        None => String::new(),
860                    };
861
862                    if let Some(certs) = conn.peer_certificates() {
863                        let peer_cert = &certs[0];
864                        if let Ok(cert) = X509::from_der(peer_cert.as_ref()) {
865                            tls_info.subject_name = subject_name(&cert);
866                        }
867                    }
868
869                    Box::new(stream)
870                }
871                Err((err, stream)) => {
872                    handshake_error.replace(format!("{err:#}"));
873                    stream
874                }
875            }
876        };
877
878        if let Some(tracer) = &self.tracer {
879            tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
880                level: Level::INFO,
881                message: match &handshake_error {
882                    Some(error) => format!("STARTTLS handshake failed: {error:?}"),
883                    None => format!("STARTTLS handshake -> {tls_info:?}"),
884                },
885            });
886        }
887
888        self.socket.replace(stream);
889        Ok(match handshake_error {
890            Some(error) => TlsStatus::FailedHandshake(error),
891            None => TlsStatus::Info(tls_info),
892        })
893    }
894
895    pub async fn send_mail<B: AsRef<[u8]>, SENDER: Into<ReversePath>, RECIP: Into<ForwardPath>>(
896        &mut self,
897        sender: SENDER,
898        recipient: RECIP,
899        data: B,
900    ) -> Result<Response, ClientError> {
901        let recipient: ForwardPath = recipient.into();
902        let status = self
903            .send_mail_multi_recip(sender, vec![recipient], data)
904            .await?;
905        Ok(status.response)
906    }
907
908    pub async fn send_mail_multi_recip<B: AsRef<[u8]>, SENDER: Into<ReversePath>>(
909        &mut self,
910        sender: SENDER,
911        recipient_list: Vec<ForwardPath>,
912        data: B,
913    ) -> Result<BatchSendSuccess, ClientError> {
914        let sender = sender.into();
915
916        let data: &[u8] = data.as_ref();
917        let stuffed;
918
919        let data = match apply_dot_stuffing(data) {
920            Some(d) => {
921                stuffed = d;
922                &stuffed
923            }
924            None => data,
925        };
926
927        let data_is_8bit = data.iter().any(|&b| b >= 0x80);
928        let envelope_is_8bit =
929            !sender.is_ascii() || recipient_list.iter().any(|recipient| !recipient.is_ascii());
930
931        let mut mail_from_params = vec![];
932        if data_is_8bit {
933            if self.capabilities.contains_key("8BITMIME") {
934                mail_from_params.push(EsmtpParameter {
935                    name: "BODY".to_string(),
936                    value: Some("8BITMIME".to_string()),
937                });
938            } else if !self.ignore_8bit_checks {
939                return Err(ClientError::Rejected(Response {
940                    code: 554,
941                    command: None,
942                    enhanced_code: Some(EnhancedStatusCode {
943                        class: 5,
944                        subject: 6,
945                        detail: 3,
946                    }),
947                    content: "KumoMTA internal: DATA is 8bit, destination does \
948                        not support 8BITMIME. Conversion via msg:check_fix_conformance \
949                        during reception is required"
950                        .to_string(),
951                }));
952            }
953        }
954
955        if envelope_is_8bit {
956            if self.capabilities.contains_key("SMTPUTF8") {
957                mail_from_params.push(EsmtpParameter {
958                    name: "SMTPUTF8".to_string(),
959                    value: None,
960                });
961            } else if !self.ignore_8bit_checks {
962                return Err(ClientError::Rejected(Response {
963                    code: 554,
964                    command: None,
965                    enhanced_code: Some(EnhancedStatusCode {
966                        class: 5,
967                        subject: 6,
968                        detail: 7,
969                    }),
970                    content: "KumoMTA internal: envelope is 8bit, destination does \
971                        not support SMTPUTF8."
972                        .to_string(),
973                }));
974            }
975        }
976
977        let mut commands = vec![];
978
979        // We want to avoid using RSET for the first message we send on
980        // a given connection, because postfix can run in a mode where
981        // it will not tolerate RSET because it considers it to be a "junk"
982        // command, and rejecting junk commands will cut down on its load
983        // when it is under stress; it is used as a load shedding approach.
984        // If we always RSET then we will never deliver to a site that is
985        // configured that way. If we take care to RSET only for subsequent
986        // sends, then we should get one message per connection through
987        // without being unfairly penalized for defensively RSETing.
988        let used_rset = self.use_rset;
989        if self.use_rset {
990            commands.push(Command::Rset);
991        }
992        commands.push(Command::MailFrom {
993            address: sender,
994            parameters: mail_from_params,
995        });
996
997        for recipient in &recipient_list {
998            commands.push(Command::RcptTo {
999                address: recipient.clone(),
1000                parameters: vec![],
1001            });
1002        }
1003        commands.push(Command::Data);
1004
1005        // Assume that something might break below: if it does, we want
1006        // to ensure that we RSET the connection on the next go around.
1007        self.use_rset = true;
1008
1009        let mut responses = self.pipeline_commands(commands).await;
1010
1011        // This is a little awkward. We want to handle the RFC 2090 3.1 case
1012        // below, which requires deferring checking the actual response codes
1013        // until later, but we must also handle the case where we had a hard
1014        // transport error partway through pipelining.
1015        // So we set a flag for that case and will then "eagerly", wrt. the
1016        // RFC 2090 3.1 logic, evaluate the SMTP response codes, so that we
1017        // can propagate the correct error disposition up to the caller.
1018        let is_err = responses.iter().any(|r| r.is_err());
1019
1020        if used_rset {
1021            let rset_resp = responses.remove(0)?;
1022            if rset_resp.code != 250 {
1023                return Err(ClientError::Rejected(rset_resp));
1024            }
1025        }
1026
1027        let mail_resp = responses.remove(0)?;
1028        if is_err && mail_resp.code != 250 {
1029            return Err(ClientError::Rejected(mail_resp));
1030        }
1031
1032        let mut rcpt_responses = vec![];
1033        for _ in &recipient_list {
1034            rcpt_responses.push(responses.remove(0)?);
1035        }
1036
1037        if is_err && rcpt_responses.iter().all(|resp| resp.code != 250) {
1038            return Err(ClientError::RejectedBatch(rcpt_responses));
1039        }
1040
1041        let data_resp = responses.remove(0)?;
1042        if is_err && data_resp.code != 354 {
1043            return Err(ClientError::Rejected(data_resp));
1044        }
1045
1046        if data_resp.code == 354
1047            && (mail_resp.code != 250 || rcpt_responses.iter().all(|resp| resp.code != 250))
1048        {
1049            // RFC 2920 3.1:
1050            // the client cannot assume that the DATA command will be rejected
1051            // just because none of the RCPT TO commands worked.  If the DATA
1052            // command was properly rejected the client SMTP can just issue
1053            // RSET, but if the DATA command was accepted the client SMTP
1054            // should send a single dot.
1055
1056            // Send dummy data
1057            self.write_data_with_timeout(b".\r\n").await?;
1058            let data_dot = Command::DataDot;
1059            // wait for its response
1060            let _ = self
1061                .read_response(Some(&data_dot), data_dot.client_timeout(&self.timeouts))
1062                .await?;
1063
1064            // Continue below: we will match one of the failure cases and
1065            // return a ClientError::Rejected from one of the earlier
1066            // commands
1067        }
1068
1069        if mail_resp.code != 250 {
1070            return Err(ClientError::Rejected(mail_resp));
1071        }
1072        if rcpt_responses.iter().all(|resp| resp.code != 250) {
1073            if rcpt_responses.len() == 1 {
1074                return Err(ClientError::Rejected(
1075                    rcpt_responses.pop().expect("have at least one"),
1076                ));
1077            }
1078            return Err(ClientError::RejectedBatch(rcpt_responses));
1079        }
1080        if data_resp.code != 354 {
1081            return Err(ClientError::Rejected(data_resp));
1082        }
1083
1084        let needs_newline = data.last().map(|&b| b != b'\n').unwrap_or(true);
1085
1086        tracing::trace!("message data is {} bytes", data.len());
1087
1088        self.write_data_with_timeout(data).await?;
1089
1090        let marker = if needs_newline { "\r\n.\r\n" } else { ".\r\n" };
1091
1092        tracing::trace!("send->{}: {}", self.hostname, marker.escape_debug());
1093
1094        self.write_data_with_timeout(marker.as_bytes()).await?;
1095
1096        let data_dot = Command::DataDot;
1097        let resp = self
1098            .read_response(Some(&data_dot), data_dot.client_timeout(&self.timeouts))
1099            .await?;
1100        if resp.code != 250 {
1101            return Err(ClientError::Rejected(resp));
1102        }
1103
1104        // If everything went well, respect the user preference for speculatively
1105        // issuing an RSET next time around
1106        self.use_rset = self.enable_rset;
1107
1108        Ok(BatchSendSuccess {
1109            response: resp,
1110            rcpt_responses,
1111        })
1112    }
1113}
1114
/// Outcome of a successful `send_mail_multi_recip` call.
#[derive(Debug)]
pub struct BatchSendSuccess {
    /// Response to the end-of-data marker that concluded the transaction.
    pub response: Response,
    /// The `RCPT TO` responses, one per recipient. A batch is considered
    /// successful as long as at least one recipient was accepted, so this
    /// list may contain rejections alongside the accepted recipients.
    pub rcpt_responses: Vec<Response>,
}
1120
/// Result of attempting a STARTTLS handshake.
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize, Clone)]
pub enum TlsStatus {
    /// The handshake failed; carries the formatted error text.
    FailedHandshake(String),
    /// The handshake completed; carries details of the negotiated session.
    Info(TlsInformation),
}
1126
/// Details about a negotiated TLS session.
#[derive(Debug, Default, PartialEq, Eq, Serialize, Deserialize, Clone)]
pub struct TlsInformation {
    /// Name of the negotiated cipher; empty when none was reported.
    pub cipher: String,
    /// Negotiated protocol version, as reported by the TLS provider.
    pub protocol_version: String,
    /// `object=name` pairs from the certificate subject,
    /// in the format produced by [`subject_name`].
    pub subject_name: Vec<String>,
    /// TLS implementation that performed the handshake
    /// (`"openssl"` or `"rustls"`).
    pub provider_name: String,
}
1134
1135impl Drop for SmtpClient {
1136    fn drop(&mut self) {
1137        if let Some(tracer) = &self.tracer {
1138            if self.socket.is_some() {
1139                tracer.trace_event(SmtpClientTraceEvent::Closed);
1140            }
1141        }
1142    }
1143}
1144fn parse_response_line(line: &'_ str) -> Result<ResponseLine<'_>, ClientError> {
1145    if line.len() < 4 {
1146        return Err(ClientError::MalformedResponseLine(line.to_string()));
1147    }
1148
1149    match line.as_bytes()[3] {
1150        b' ' | b'-' => match line[0..3].parse::<u16>() {
1151            Ok(code) => Ok(ResponseLine {
1152                code,
1153                is_final: line.as_bytes()[3] == b' ',
1154                content: &line[4..],
1155            }),
1156            Err(_) => Err(ClientError::MalformedResponseLine(line.to_string())),
1157        },
1158        _ => Err(ClientError::MalformedResponseLine(line.to_string())),
1159    }
1160}
1161
1162fn apply_dot_stuffing(data: &[u8]) -> Option<Vec<u8>> {
1163    static LFDOT: LazyLock<Finder> = LazyLock::new(|| memchr::memmem::Finder::new("\n."));
1164
1165    if !data.starts_with(b".") && LFDOT.find(data).is_none() {
1166        return None;
1167    }
1168
1169    let mut stuffed = vec![];
1170    if data.starts_with(b".") {
1171        stuffed.push(b'.');
1172    }
1173    let mut last_idx = 0;
1174    for i in LFDOT.find_iter(data) {
1175        stuffed.extend_from_slice(&data[last_idx..=i]);
1176        stuffed.push(b'.');
1177        last_idx = i + 1;
1178    }
1179    stuffed.extend_from_slice(&data[last_idx..]);
1180    Some(stuffed)
1181}
1182
1183/// Extracts the object=name pairs of the subject name from a cert.
1184/// eg:
1185/// ```text
1186/// ["C=US", "ST=CA", "L=SanFrancisco", "O=Fort-Funston", "OU=MyOrganizationalUnit",
1187/// "CN=do.havedane.net", "name=EasyRSA", "emailAddress=me@myhost.mydomain"]
1188/// ```
1189pub fn subject_name(cert: &X509Ref) -> Vec<String> {
1190    let mut subject_name = vec![];
1191    for entry in cert.subject_name().entries() {
1192        if let Ok(obj) = entry.object().nid().short_name() {
1193            if let Ok(name) = entry.data().as_utf8() {
1194                subject_name.push(format!("{obj}={name}"));
1195            }
1196        }
1197    }
1198    subject_name
1199}
1200
1201fn parse_server_name(input: &str) -> Result<ServerName<'static>, ClientError> {
1202    match IpAddr::from_str(input) {
1203        Ok(ip) => Ok(ServerName::IpAddress(ip.into())),
1204        Err(_) => {
1205            let name = Name::from_str_relaxed(input)
1206                .map_err(|_| ClientError::InvalidDnsName(input.to_string()))?;
1207            ServerName::try_from(name.to_ascii())
1208                .map_err(|_| ClientError::InvalidDnsName(name.to_ascii()))
1209        }
1210    }
1211}
1212
/// Unit tests for the parsing and formatting helpers in this file.
#[cfg(test)]
mod test {
    use super::*;
    use crate::parser::{EnvelopeAddress, ReversePath};

    /// Dot-stuffing returns `None` when nothing needs escaping and doubles
    /// the leading dot of every line that starts with one.
    #[test]
    fn test_stuffing() {
        assert_eq!(apply_dot_stuffing(b"foo"), None);
        assert_eq!(apply_dot_stuffing(b".foo").unwrap(), b"..foo");
        assert_eq!(apply_dot_stuffing(b"foo\n.bar").unwrap(), b"foo\n..bar");
        assert_eq!(
            apply_dot_stuffing(b"foo\n.bar\n..baz\n").unwrap(),
            b"foo\n..bar\n...baz\n"
        );
    }

    // Disabled manual end-to-end test: it connects to an SMTP listener on
    // 127.0.0.1:2025 and deliberately panics at the end to print the response.
    /*
    #[tokio::test]
    async fn test_against_sink() {
        use tokio::net::TcpStream;
        let stream = TcpStream::connect("127.0.0.1:2025").await.unwrap();
        let mut client =
            SmtpClient::with_stream(stream, "localhost", SmtpClientTimeouts::default());
        dbg!(client.read_response(None).await).unwrap();
        dbg!(client.ehlo("localhost").await).unwrap();
        let insecure = true;
        dbg!(client.starttls(insecure).await).unwrap();
        let resp = client
            .send_mail(
                ReversePath::try_from("wez@mail.example.com").unwrap(),
                ForwardPath::try_from("wez@mail.example.com").unwrap(),
                "Subject: hello\r\n\r\nwoot\r\n",
            )
            .await
            .unwrap();
        panic!("{resp:#?}");
    }
    */

    /// A space after the code marks the final line, `-` marks a
    /// continuation; any other separator or a non-numeric code is malformed.
    #[test]
    fn response_line_parsing() {
        assert_eq!(
            parse_response_line("220 woot").unwrap(),
            ResponseLine {
                code: 220,
                is_final: true,
                content: "woot"
            }
        );
        assert_eq!(
            parse_response_line("220-woot").unwrap(),
            ResponseLine {
                code: 220,
                is_final: false,
                content: "woot"
            }
        );

        assert!(matches!(
            parse_response_line("220_woot"),
            Err(ClientError::MalformedResponseLine(_))
        ));
        assert!(matches!(
            parse_response_line("not really"),
            Err(ClientError::MalformedResponseLine(_))
        ));
    }

    /// Test helper: feeds `lines` through `parse_response_line` and
    /// `ResponseBuilder`, asserting that the last line is a final line.
    fn parse_multi_line(lines: &[&str]) -> Result<Response, ClientError> {
        let mut parsed = parse_response_line(lines[0])?;
        let mut b = ResponseBuilder::new(&parsed);
        for line in &lines[1..] {
            parsed = parse_response_line(line)?;
            b.add_line(&parsed)
                .map_err(ClientError::MalformedResponseLine)?;
        }
        assert!(parsed.is_final);
        Ok(b.build(None))
    }

    /// Multi-line replies are joined with `\n`, must use the same code on
    /// every line, and have a leading enhanced status code stripped.
    #[test]
    fn multi_line_response() {
        assert_eq!(
            parse_multi_line(&["220-woot", "220-more", "220 done",]).unwrap(),
            Response {
                code: 220,
                enhanced_code: None,
                content: "woot\nmore\ndone".to_string(),
                command: None
            }
        );

        // A continuation line with a different code is rejected ...
        let res = parse_multi_line(&["220-woot", "221-more", "220 done"]).unwrap_err();
        assert!(
            matches!(
                    res,
                ClientError::MalformedResponseLine(ref err) if err == "221-more"
            ),
            "got error {res:?}"
        );

        // ... and so is a final line with a different code.
        let res = parse_multi_line(&["220-woot", "220-more", "221 done"]).unwrap_err();
        assert!(
            matches!(
                    res,
                ClientError::MalformedResponseLine(ref err) if err == "221 done"
            ),
            "got error {res:?}"
        );

        // The enhanced status code on the first line is extracted.
        assert_eq!(
            parse_multi_line(&["220-4.1.0 woot", "220-more", "220 done",]).unwrap(),
            Response {
                code: 220,
                enhanced_code: Some(EnhancedStatusCode {
                    class: 4,
                    subject: 1,
                    detail: 0
                }),
                content: "woot\nmore\ndone".to_string(),
                command: None
            }
        );

        // Confirm that we strip the enhanced status code from each line
        assert_eq!(
            parse_multi_line(&["220-4.1.0 woot", "220-4.1.0 more", "220 done",]).unwrap(),
            Response {
                code: 220,
                enhanced_code: Some(EnhancedStatusCode {
                    class: 4,
                    subject: 1,
                    detail: 0
                }),
                content: "woot\nmore\ndone".to_string(),
                command: None
            }
        );

        // ... but only if the code matches that of the first line
        assert_eq!(
            parse_multi_line(&["220-4.1.0 woot", "220-4.1.0 more", "220 5.5.5 done",]).unwrap(),
            Response {
                code: 220,
                enhanced_code: Some(EnhancedStatusCode {
                    class: 4,
                    subject: 1,
                    detail: 0
                }),
                content: "woot\nmore\n5.5.5 done".to_string(),
                command: None
            }
        );
    }

    /// `extract_hostname` drops a port suffix, surrounding brackets and a
    /// trailing dot, for both host names and IPv6 literals.
    #[test]
    fn test_extract_hostname() {
        assert_eq!(extract_hostname("foo"), "foo");
        assert_eq!(extract_hostname("foo."), "foo");
        assert_eq!(extract_hostname("foo:25"), "foo");
        assert_eq!(extract_hostname("foo.:25"), "foo");
        assert_eq!(extract_hostname("[foo]:25"), "foo");
        assert_eq!(extract_hostname("[foo.]:25"), "foo");
        assert_eq!(extract_hostname("[::1]:25"), "::1");
        assert_eq!(extract_hostname("::1:25"), "::1");
    }

    /// The timeout error variants render the command they were waiting on.
    #[test]
    fn test_format_error_command() {
        assert_eq!(
            format!(
                "{:#}",
                ClientError::TimeOutRequest {
                    commands: vec![Command::DataDot],
                    duration: Duration::from_secs(10),
                }
            ),
            "Timed Out after 10s writing cmd=."
        );
        assert_eq!(
            format!(
                "{:#}",
                ClientError::TimeOutResponse {
                    command: Some(Command::MailFrom {
                        address: {
                            let EnvelopeAddress::Path(p) =
                                EnvelopeAddress::parse("user@host").unwrap()
                            else {
                                panic!("expected Path")
                            };
                            ReversePath::Path(p)
                        },
                        parameters: vec![],
                    }),
                    duration: Duration::from_secs(10),
                }
            ),
            r#"Timed Out waiting 10s for response to cmd=MAIL FROM:<user@host>"#
        );
    }

    /// Regression test: an internationalized host name must be accepted by
    /// `parse_server_name`.
    #[test]
    fn test_issue_533() {
        let _name = parse_server_name("münchen.de").unwrap();
    }
}
1418}