rfc5321/
client.rs

1#![allow(clippy::result_large_err)]
2use crate::client_types::*;
3use crate::{
4    AsyncReadAndWrite, BoxedAsyncReadAndWrite, Command, Domain, EsmtpParameter, ForwardPath,
5    ReversePath,
6};
7use hickory_proto::rr::rdata::TLSA;
8use memchr::memmem::Finder;
9use openssl::x509::{X509Ref, X509};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::net::IpAddr;
13use std::str::FromStr;
14use std::sync::{Arc, LazyLock};
15use std::time::Duration;
16use thiserror::Error;
17use tokio::io::{AsyncReadExt, AsyncWriteExt};
18use tokio::net::{TcpStream, ToSocketAddrs};
19use tokio::time::timeout;
20use tokio_rustls::rustls::pki_types::ServerName;
21use tracing::Level;
22
23pub use kumo_tls_helper::TlsOptions;
24pub use {openssl, tokio_rustls};
25
/// Upper bound, in bytes, on the unterminated data `SmtpClient::read_line`
/// will buffer while looking for a CRLF; also the size of each socket read.
const MAX_LINE_LEN: usize = 4 * 1024;
27
28#[derive(Error, Debug, Clone)]
29pub enum ClientError {
30    #[error("response is not UTF8")]
31    Utf8(#[from] std::string::FromUtf8Error),
32    #[error("Malformed Response: {0}")]
33    MalformedResponseLine(String),
34    #[error("Response line is too long")]
35    ResponseTooLong,
36    #[error("Not connected")]
37    NotConnected,
38    #[error("Command rejected {0:?}")]
39    Rejected(Response),
40    #[error("Commands rejected {0:?}")]
41    RejectedBatch(Vec<Response>),
42    #[error("STARTTLS: {0} is not a valid DNS name")]
43    InvalidDnsName(String),
44    #[error("Invalid client certificate configured: {error:?}")]
45    FailedToBuildConnector { error: String },
46    #[error("Timed Out waiting {duration:?} for response to cmd={}", self.command_for_err())]
47    TimeOutResponse {
48        command: Option<Command>,
49        duration: Duration,
50    },
51    #[error("Timed Out after {duration:?} writing cmd={}", self.command_for_err())]
52    TimeOutRequest {
53        commands: Vec<Command>,
54        duration: Duration,
55    },
56    #[error("Error {error} reading response to cmd={}", self.command_for_err())]
57    ReadError {
58        command: Option<Command>,
59        error: String,
60        partial: String,
61    },
62    #[error("Error {error} flushing send buffer")]
63    FlushError { error: String },
64    #[error("Error {error} writing {}", self.command_for_err())]
65    WriteError {
66        commands: Vec<Command>,
67        error: String,
68    },
69    #[error("Timed Out sending message payload data")]
70    TimeOutData,
71    #[error("SSL Error: {0}")]
72    SslErrorStack(#[from] openssl::error::ErrorStack),
73    #[error("No usable DANE TLSA records for {hostname}: {tlsa:?}")]
74    NoUsableDaneTlsa { hostname: String, tlsa: Vec<TLSA> },
75}
76
77impl ClientError {
78    /// Returns the command(s) string suitable for passing into a Response
79    pub fn command(&self) -> Option<String> {
80        match self {
81            Self::TimeOutResponse {
82                command: Some(command),
83                ..
84            }
85            | Self::ReadError {
86                command: Some(command),
87                ..
88            } => Some(command.encode()),
89            Self::TimeOutRequest { commands, .. } | Self::WriteError { commands, .. }
90                if !commands.is_empty() =>
91            {
92                let commands: Vec<String> = commands.into_iter().map(|cmd| cmd.encode()).collect();
93                Some(commands.join(""))
94            }
95            _ => None,
96        }
97    }
98
99    /// Returns the command(s) string suitable for embedding in the
100    /// overall error message
101    fn command_for_err(&self) -> String {
102        self.command()
103            .map(|cmd| cmd.replace("\r\n", ""))
104            .unwrap_or_else(|| "NONE".to_string())
105    }
106
107    /// If the error contents were likely caused by something
108    /// about the mostly recently attempted message, rather than
109    /// a transport issue, or a carry-over from a prior message
110    /// (eg: previous message was rejected and destination chose
111    /// to drop the connection, which we detect later in RSET
112    /// on the next message), then we return true.
113    /// The expectation is that the caller will transiently
114    /// fail the message for later retry.
115    /// If we return false then the caller might decide to
116    /// try that same message again more immediately on
117    /// a separate connection
118    pub fn was_due_to_message(&self) -> bool {
119        match self {
120            Self::Utf8(_)
121            | Self::MalformedResponseLine(_)
122            | Self::ResponseTooLong
123            | Self::NotConnected
124            | Self::InvalidDnsName(_)
125            | Self::FailedToBuildConnector { .. }
126            | Self::TimeOutResponse { .. }
127            | Self::TimeOutRequest { .. }
128            | Self::ReadError { .. }
129            | Self::FlushError { .. }
130            | Self::WriteError { .. }
131            | Self::TimeOutData
132            | Self::SslErrorStack(_)
133            | Self::NoUsableDaneTlsa { .. } => false,
134            Self::Rejected(response) => response.was_due_to_message(),
135            Self::RejectedBatch(responses) => match responses.len() {
136                1 => responses[0].was_due_to_message(),
137                _ => false,
138            },
139        }
140    }
141}
142
/// One ESMTP service extension advertised in an EHLO/LHLO reply.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct EsmtpCapability {
    /// The extension keyword exactly as the server sent it.
    pub name: String,
    /// Text following the first space after the keyword, if any.
    pub param: Option<String>,
}
148
149#[derive(Clone, Debug)]
150pub enum SmtpClientTraceEvent {
151    Closed,
152    Read(Vec<u8>),
153    Write(String),
154    Diagnostic {
155        level: tracing::Level,
156        message: String,
157    },
158}
159
160pub trait DeferredTracer {
161    fn trace(&self) -> SmtpClientTraceEvent;
162}
163
164pub trait SmtpClientTracer: std::fmt::Debug {
165    fn trace_event(&self, event: SmtpClientTraceEvent);
166    fn lazy_trace(&self, deferred: &dyn DeferredTracer);
167}
168
169// helper to avoid making a second copy of every write buffer
170struct WriteTracer<'a> {
171    data: &'a str,
172}
173impl DeferredTracer for WriteTracer<'_> {
174    fn trace(&self) -> SmtpClientTraceEvent {
175        SmtpClientTraceEvent::Write(self.data.to_string())
176    }
177}
178impl<'a> WriteTracer<'a> {
179    fn trace(tracer: &Arc<dyn SmtpClientTracer + Send + Sync>, data: &'a str) {
180        tracer.lazy_trace(&Self { data });
181    }
182}
183
184struct BinWriteTracer<'a> {
185    data: &'a [u8],
186}
187impl DeferredTracer for BinWriteTracer<'_> {
188    fn trace(&self) -> SmtpClientTraceEvent {
189        let data = String::from_utf8_lossy(self.data).to_string();
190        SmtpClientTraceEvent::Write(data)
191    }
192}
193impl<'a> BinWriteTracer<'a> {
194    fn trace(tracer: &Arc<dyn SmtpClientTracer + Send + Sync>, data: &'a [u8]) {
195        tracer.lazy_trace(&Self { data });
196    }
197}
198
199// A little bit of gymnastics to avoid making a second
200// copy of every read buffer
201struct ReadTracer<'a> {
202    data: &'a [u8],
203}
204impl DeferredTracer for ReadTracer<'_> {
205    fn trace(&self) -> SmtpClientTraceEvent {
206        SmtpClientTraceEvent::Read(self.data.to_vec())
207    }
208}
209
210#[derive(Debug)]
211pub struct SmtpClient {
212    socket: Option<BoxedAsyncReadAndWrite>,
213    hostname: String,
214    capabilities: HashMap<String, EsmtpCapability>,
215    read_buffer: Vec<u8>,
216    timeouts: SmtpClientTimeouts,
217    tracer: Option<Arc<dyn SmtpClientTracer + Send + Sync>>,
218    use_rset: bool,
219    enable_rset: bool,
220    enable_pipelining: bool,
221    ignore_8bit_checks: bool,
222}
223
/// Reduce a peer address string to just its host part, suitable for logging
/// and for use as a TLS server name.
///
/// Accepted forms:
/// * `host`, `host:port`, optionally with a trailing FQDN dot (removed)
/// * `1.2.3.4` and `1.2.3.4:port`
/// * bracketed literals `[::1]`, `[::1]:port`, `[1.2.3.4]:port` (brackets removed)
/// * bare IPv6 literals such as `::1` or `2001:db8::1`, returned unchanged
///   rather than having their last `:group` mistaken for a port
fn extract_hostname(hostname: &str) -> &str {
    // A bracketed literal, optionally followed by `:port`, must be handled
    // before looking for a port separator, because IPv6 addresses contain ':'.
    if let Some(rest) = hostname.strip_prefix('[') {
        if let Some((inner, after)) = rest.split_once(']') {
            if after.is_empty() || after.starts_with(':') {
                return inner.strip_suffix('.').unwrap_or(inner);
            }
        }
    }

    // An unbracketed IP address carries no port; in particular an IPv6
    // literal must not be split at its final ':'.
    if hostname.parse::<IpAddr>().is_ok() {
        return hostname;
    }

    // Just the hostname, without any :port
    let hostname = hostname
        .rsplit_once(':')
        .map_or(hostname, |(host, _port)| host);

    let hostname = hostname
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(hostname);

    // Remove any trailing FQDN dot
    hostname.strip_suffix('.').unwrap_or(hostname)
}
242
243impl SmtpClient {
244    pub async fn new<A: ToSocketAddrs + ToString + Clone>(
245        addr: A,
246        timeouts: SmtpClientTimeouts,
247    ) -> std::io::Result<Self> {
248        let stream = TcpStream::connect(addr.clone()).await?;
249        // No need for Nagle with SMTP request/response
250        stream.set_nodelay(true)?;
251        Ok(Self::with_stream(stream, addr.to_string(), timeouts))
252    }
253
254    pub fn with_stream<S: AsyncReadAndWrite + 'static, H: AsRef<str>>(
255        stream: S,
256        peer_hostname: H,
257        timeouts: SmtpClientTimeouts,
258    ) -> Self {
259        let hostname = extract_hostname(peer_hostname.as_ref()).to_string();
260
261        Self {
262            socket: Some(Box::new(stream)),
263            hostname,
264            capabilities: HashMap::new(),
265            read_buffer: Vec::with_capacity(1024),
266            timeouts,
267            tracer: None,
268            use_rset: false,
269            enable_rset: false,
270            enable_pipelining: false,
271            ignore_8bit_checks: false,
272        }
273    }
274
275    /// Setting this to true causes 8BITMIME and SMTPUTF8 checks to be ignored
276    pub fn set_ignore_8bit_checks(&mut self, enable: bool) {
277        self.ignore_8bit_checks = enable;
278    }
279
280    pub fn is_connected(&self) -> bool {
281        self.socket.is_some()
282    }
283
284    pub fn set_enable_rset(&mut self, enable: bool) {
285        self.enable_rset = enable;
286    }
287
288    pub fn set_enable_pipelining(&mut self, enable: bool) {
289        self.enable_pipelining = enable;
290    }
291
292    pub fn set_tracer(&mut self, tracer: Arc<dyn SmtpClientTracer + Send + Sync>) {
293        self.tracer.replace(tracer);
294    }
295
296    pub fn timeouts(&self) -> &SmtpClientTimeouts {
297        &self.timeouts
298    }
299
300    async fn read_line(
301        &mut self,
302        timeout_duration: Duration,
303        cmd: Option<&Command>,
304    ) -> Result<String, ClientError> {
305        let mut too_long = false;
306        loop {
307            let mut iter = self.read_buffer.iter().enumerate();
308            while let Some((i, &b)) = iter.next() {
309                if b != b'\r' {
310                    continue;
311                }
312                if let Some((_, b'\n')) = iter.next() {
313                    if too_long {
314                        self.read_buffer.drain(0..i + 2);
315
316                        if let Some(tracer) = &self.tracer {
317                            tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
318                                level: Level::ERROR,
319                                message: "Response too long".to_string(),
320                            });
321                        }
322
323                        return Err(ClientError::ResponseTooLong);
324                    }
325
326                    let line = String::from_utf8(self.read_buffer[0..i].to_vec());
327                    self.read_buffer.drain(0..i + 2);
328                    return Ok(line?);
329                }
330            }
331            if self.read_buffer.len() > MAX_LINE_LEN {
332                self.read_buffer.clear();
333                too_long = true;
334            }
335
336            // Didn't find a complete line, fill up the rest of the buffer
337            let mut data = [0u8; MAX_LINE_LEN];
338            let size = match self.socket.as_mut() {
339                Some(s) => match timeout(timeout_duration, s.read(&mut data)).await {
340                    Ok(Ok(size)) => size,
341                    Ok(Err(err)) => {
342                        self.socket.take();
343                        if let Some(tracer) = &self.tracer {
344                            tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
345                                level: Level::ERROR,
346                                message: format!("Error during read: {err:#}"),
347                            });
348                            tracer.trace_event(SmtpClientTraceEvent::Closed);
349                        }
350                        return Err(ClientError::ReadError {
351                            command: cmd.cloned(),
352                            error: format!("{err:#}"),
353                            partial: String::from_utf8_lossy(&self.read_buffer).to_string(),
354                        });
355                    }
356                    Err(_) => {
357                        // If we're using a zero timeout we're probing
358                        // for liveness, so we don't want to treat this
359                        // timeout as a fatal error that closes the
360                        // connection
361                        if timeout_duration != Duration::ZERO {
362                            self.socket.take();
363                        }
364                        if let Some(tracer) = &self.tracer {
365                            tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
366                                level: Level::ERROR,
367                                message: format!("Read Timeout after {timeout_duration:?}"),
368                            });
369                            if self.socket.is_none() {
370                                tracer.trace_event(SmtpClientTraceEvent::Closed);
371                            }
372                        }
373                        return Err(ClientError::TimeOutResponse {
374                            command: cmd.cloned(),
375                            duration: timeout_duration,
376                        });
377                    }
378                },
379                None => {
380                    return Err(ClientError::ReadError {
381                        command: cmd.cloned(),
382                        error: "the socket was closed in response to an earlier issue".to_string(),
383                        partial: String::from_utf8_lossy(&self.read_buffer).to_string(),
384                    });
385                }
386            };
387            if size == 0 {
388                self.socket.take();
389                if let Some(tracer) = &self.tracer {
390                    tracer.trace_event(SmtpClientTraceEvent::Closed);
391                }
392                return Err(ClientError::ReadError {
393                    command: cmd.cloned(),
394                    error: "Connection closed by peer".to_string(),
395                    partial: String::from_utf8_lossy(&self.read_buffer).to_string(),
396                });
397            }
398            if let Some(tracer) = &self.tracer {
399                tracer.lazy_trace(&ReadTracer {
400                    data: &data[0..size],
401                });
402            }
403            self.read_buffer.extend_from_slice(&data[0..size]);
404        }
405    }
406
407    /// Check to see if there is either a unilateral response
408    /// or an error condition immediately available on the socket.
409    ///
410    /// Intended to be used after the connection has idled to see if
411    /// the peer has opted to close the connection, either gracefully
412    /// by sending what should be a 421, or abruptly by simply snipping
413    /// the connection.
414    pub async fn check_unilateral_response(&mut self) -> Result<Option<Response>, ClientError> {
415        match self.read_response(None, Duration::ZERO).await {
416            // Unilateral response was received
417            Ok(response) => Ok(Some(response)),
418            // Nothing immediately available: this should hopefully be the common case
419            Err(ClientError::TimeOutResponse { .. }) => Ok(None),
420            // An actual error
421            Err(err) => Err(err),
422        }
423    }
424
425    pub async fn read_response(
426        &mut self,
427        command: Option<&Command>,
428        timeout_duration: Duration,
429    ) -> Result<Response, ClientError> {
430        if let Some(sock) = self.socket.as_mut() {
431            if let Err(err) = sock.flush().await {
432                self.socket.take();
433                if let Some(tracer) = &self.tracer {
434                    tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
435                        level: Level::ERROR,
436                        message: format!("Error during flush: {err:#}"),
437                    });
438                    tracer.trace_event(SmtpClientTraceEvent::Closed);
439                }
440                return Err(ClientError::FlushError {
441                    error: format!("{err:#}"),
442                });
443            }
444        }
445
446        let mut line = self.read_line(timeout_duration, command).await?;
447        tracing::trace!("recv<-{}: {line}", self.hostname);
448        let mut parsed = parse_response_line(&line)?;
449        let mut response_builder = ResponseBuilder::new(&parsed);
450
451        let subsequent_line_timeout_duration = Duration::from_secs(60).min(timeout_duration);
452        while !parsed.is_final {
453            line = self
454                .read_line(subsequent_line_timeout_duration, command)
455                .await?;
456            parsed = parse_response_line(&line)?;
457            response_builder
458                .add_line(&parsed)
459                .map_err(ClientError::MalformedResponseLine)?;
460        }
461
462        let response = response_builder.build(command.map(|cmd| cmd.encode()));
463
464        tracing::trace!("{}: {response:?}", self.hostname);
465
466        Ok(response)
467    }
468
469    pub async fn send_command(&mut self, command: &Command) -> Result<Response, ClientError> {
470        self.write_command_request(command).await?;
471        self.read_response(Some(command), command.client_timeout(&self.timeouts))
472            .await
473    }
474
475    /// Wrapper around socket.write_all() that will emit trace diagnostics and synthesize
476    /// a Close event to the tracer if a timeout or IO error occurs.
477    /// If an error or timeout, occurs ensures that the socket will not be reused.
478    async fn write_all_with_timeout<F, G>(
479        &mut self,
480        timeout_duration: Duration,
481        bytes: &[u8],
482        make_timeout_err: F,
483        make_write_err: G,
484    ) -> Result<(), ClientError>
485    where
486        F: FnOnce() -> ClientError,
487        G: FnOnce(String) -> ClientError,
488    {
489        match self.socket.as_mut() {
490            Some(socket) => match timeout(timeout_duration, socket.write_all(bytes)).await {
491                Ok(Ok(response)) => Ok(response),
492                Ok(Err(err)) => {
493                    self.socket.take();
494                    if let Some(tracer) = &self.tracer {
495                        tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
496                            level: Level::ERROR,
497                            message: format!("Error during write: {err:#}"),
498                        });
499                        tracer.trace_event(SmtpClientTraceEvent::Closed);
500                    }
501                    Err(make_write_err(format!("{err:#}")))
502                }
503                Err(_) => {
504                    self.socket.take();
505                    if let Some(tracer) = &self.tracer {
506                        tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
507                            level: Level::ERROR,
508                            message: format!("Write Timeout after {timeout_duration:?}"),
509                        });
510                        tracer.trace_event(SmtpClientTraceEvent::Closed);
511                    }
512                    Err(make_timeout_err())
513                }
514            },
515            None => Err(make_write_err(
516                "the socket was closed in response to an earlier issue".to_string(),
517            )),
518        }
519    }
520
521    async fn write_pipeline_request(&mut self, commands: &[Command]) -> Result<(), ClientError> {
522        let total_timeout: Duration = commands
523            .iter()
524            .map(|cmd| cmd.client_timeout_request(&self.timeouts))
525            .sum();
526
527        let mut lines: Vec<String> = vec![];
528        let mut all = String::new();
529        for cmd in commands {
530            let line = cmd.encode();
531            all.push_str(&line);
532            lines.push(line);
533        }
534        tracing::trace!("send->{}: (PIPELINE) {all}", self.hostname);
535        if self.socket.is_some() {
536            if let Some(tracer) = &self.tracer {
537                // Send the lines individually to the tracer, so that we
538                // don't break --terse mode
539                for line in lines {
540                    WriteTracer::trace(tracer, &line);
541                }
542            }
543        }
544        self.write_all_with_timeout(
545            total_timeout,
546            all.as_bytes(),
547            || ClientError::TimeOutRequest {
548                duration: total_timeout,
549                commands: commands.to_vec(),
550            },
551            |error| ClientError::WriteError {
552                error,
553                commands: commands.to_vec(),
554            },
555        )
556        .await
557    }
558
559    async fn write_command_request(&mut self, command: &Command) -> Result<(), ClientError> {
560        let line = command.encode();
561        tracing::trace!("send->{}: {line}", self.hostname);
562        if self.socket.is_some() {
563            if let Some(tracer) = &self.tracer {
564                WriteTracer::trace(tracer, &line);
565            }
566        }
567
568        let timeout_duration = command.client_timeout_request(&self.timeouts);
569        self.write_all_with_timeout(
570            timeout_duration,
571            line.as_bytes(),
572            || ClientError::TimeOutRequest {
573                duration: timeout_duration,
574                commands: vec![command.clone()],
575            },
576            |error| ClientError::WriteError {
577                error,
578                commands: vec![command.clone()],
579            },
580        )
581        .await
582    }
583
584    async fn write_data_with_timeout(&mut self, data: &[u8]) -> Result<(), ClientError> {
585        if self.socket.is_some() {
586            if let Some(tracer) = &self.tracer {
587                BinWriteTracer::trace(tracer, data);
588            }
589        }
590        let timeout_duration = Command::Data.client_timeout_request(&self.timeouts);
591        self.write_all_with_timeout(
592            timeout_duration,
593            data,
594            || ClientError::TimeOutData,
595            |error| ClientError::WriteError {
596                error,
597                commands: vec![],
598            },
599        )
600        .await
601    }
602
603    /// Issue a series of commands, and return the responses to
604    /// those commands.
605    ///
606    /// If the server advertised the RFC 2920 PIPELINING extension,
607    /// the commands are written one after the other before waiting
608    /// to read any data, resulting in lower overall latency due
609    /// to round-trip-times.
610    ///
611    /// If PIPELINING is not available, each command is written
612    /// and the response read before attempting to write the next
613    /// command.
614    ///
615    /// The number of returned responses may be smaller than the
616    /// number of requested commands if there is an issue with
617    /// the network connection.
618    pub async fn pipeline_commands(
619        &mut self,
620        commands: Vec<Command>,
621    ) -> Vec<Result<Response, ClientError>> {
622        let mut results: Vec<Result<Response, ClientError>> = vec![];
623
624        let pipeline = self.enable_pipelining && self.capabilities.contains_key("PIPELINING");
625        if pipeline {
626            if let Err(err) = self.write_pipeline_request(&commands).await {
627                let err: ClientError = err;
628                results.push(Err(err.clone()));
629                while results.len() < commands.len() {
630                    // Synthesize failures for the remaining commands
631                    results.push(Err(err.clone()));
632                }
633                return results;
634            }
635
636            // Now read the responses effectively in a batch
637            for cmd in &commands {
638                results.push(
639                    self.read_response(Some(cmd), cmd.client_timeout(&self.timeouts))
640                        .await,
641                );
642            }
643            return results;
644        }
645
646        for cmd in &commands {
647            if let Err(err) = self.write_command_request(cmd).await {
648                let err: ClientError = err;
649                results.push(Err(err.clone()));
650                while results.len() < commands.len() {
651                    // Synthesize failures for the remaining commands
652                    results.push(Err(err.clone()));
653                }
654                return results;
655            }
656            // Immediately request the response if the server
657            // doesn't support pipelining
658            results.push(
659                self.read_response(Some(cmd), cmd.client_timeout(&self.timeouts))
660                    .await,
661            );
662        }
663        results
664    }
665
666    pub async fn ehlo_lhlo(
667        &mut self,
668        ehlo_name: &str,
669        use_lmtp: bool,
670    ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
671        if use_lmtp {
672            self.lhlo(ehlo_name).await
673        } else {
674            self.ehlo(ehlo_name).await
675        }
676    }
677
678    pub async fn lhlo(
679        &mut self,
680        ehlo_name: &str,
681    ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
682        let response = self
683            .send_command(&Command::Lhlo(Domain::Name(ehlo_name.to_string())))
684            .await?;
685        self.ehlo_common(response)
686    }
687
688    pub async fn ehlo(
689        &mut self,
690        ehlo_name: &str,
691    ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
692        let response = self
693            .send_command(&Command::Ehlo(Domain::Name(ehlo_name.to_string())))
694            .await?;
695        self.ehlo_common(response)
696    }
697
698    fn ehlo_common(
699        &mut self,
700        response: Response,
701    ) -> Result<&HashMap<String, EsmtpCapability>, ClientError> {
702        if response.code != 250 {
703            return Err(ClientError::Rejected(response));
704        }
705
706        let mut capabilities = HashMap::new();
707
708        for line in response.content.lines().skip(1) {
709            let mut fields = line.splitn(2, ' ');
710            if let Some(name) = fields.next() {
711                let param = fields.next().map(|s| s.to_string());
712                let cap = EsmtpCapability {
713                    name: name.to_string(),
714                    param,
715                };
716                capabilities.insert(name.to_ascii_uppercase(), cap);
717            }
718        }
719
720        self.capabilities = capabilities;
721        Ok(&self.capabilities)
722    }
723
724    pub async fn auth_plain(
725        &mut self,
726        username: &str,
727        password: Option<&str>,
728    ) -> Result<(), ClientError> {
729        // RFC 4616 says that the format is:
730        // [authzid] NUL authcid NUL passwd
731        let password = password.unwrap_or("");
732        let payload = format!("\x00{username}\x00{password}");
733        let payload = data_encoding::BASE64.encode(payload.as_bytes());
734
735        let response = self
736            .send_command(&Command::Auth {
737                sasl_mech: "PLAIN".to_string(),
738                initial_response: Some(payload),
739            })
740            .await?;
741
742        if response.code != 235 {
743            return Err(ClientError::Rejected(response));
744        }
745
746        Ok(())
747    }
748
749    /// Attempt TLS handshake.
750    /// Returns Err for IO errors.
751    /// On completion, return an option that will be:
752    /// * Some(handshake_error) - if the handshake failed
753    /// * None - if the handshake succeeded
754    pub async fn starttls(&mut self, options: TlsOptions) -> Result<TlsStatus, ClientError> {
755        let resp = self.send_command(&Command::StartTls).await?;
756        if resp.code != 220 {
757            return Err(ClientError::Rejected(resp));
758        }
759
760        let mut handshake_error = None;
761        let mut tls_info = TlsInformation::default();
762
763        let stream: BoxedAsyncReadAndWrite = if options.prefer_openssl
764            || !options.dane_tlsa.is_empty()
765        {
766            let connector = options
767                .build_openssl_connector(&self.hostname)
768                .map_err(|error| ClientError::FailedToBuildConnector {
769                    error: error.to_string(),
770                })?;
771            let ssl = connector.into_ssl(self.hostname.as_str())?;
772
773            let (stream, dup_stream) = match self.socket.take() {
774                Some(s) => {
775                    let d = s.try_dup();
776                    (s, d)
777                }
778                None => return Err(ClientError::NotConnected),
779            };
780
781            let mut ssl_stream = tokio_openssl::SslStream::new(ssl, stream)?;
782
783            if let Err(err) = std::pin::Pin::new(&mut ssl_stream).connect().await {
784                handshake_error.replace(format!("{err:#}"));
785            }
786
787            tls_info.provider_name = "openssl".to_string();
788            tls_info.cipher = match ssl_stream.ssl().current_cipher() {
789                Some(cipher) => cipher.standard_name().unwrap_or(cipher.name()).to_string(),
790                None => String::new(),
791            };
792            tls_info.protocol_version = ssl_stream.ssl().version_str().to_string();
793
794            if let Some(cert) = ssl_stream.ssl().peer_certificate() {
795                tls_info.subject_name = subject_name(&cert);
796            }
797            if let Ok(authority) = ssl_stream.ssl().dane_authority() {
798                if let Some(cert) = &authority.cert {
799                    tls_info.subject_name = subject_name(cert);
800                }
801            }
802
803            match (&handshake_error, dup_stream) {
804                (Some(_), Some(dup_stream)) if !ssl_stream.ssl().is_init_finished() => {
805                    // Try falling back to clear text on the duplicate stream.
806                    // This is imperfect: in a failed validation scenario we will
807                    // end up trying to read binary data as a string and get a UTF-8
808                    // error if the peer thinks the session is encrypted.
809                    drop(ssl_stream);
810                    Box::new(dup_stream)
811                }
812                _ => Box::new(ssl_stream),
813            }
814        } else {
815            tls_info.provider_name = "rustls".to_string();
816            let connector = options.build_tls_connector().await.map_err(|error| {
817                ClientError::FailedToBuildConnector {
818                    error: error.to_string(),
819                }
820            })?;
821            let server_name = match IpAddr::from_str(self.hostname.as_str()) {
822                Ok(ip) => ServerName::IpAddress(ip.into()),
823                Err(_) => ServerName::try_from(self.hostname.clone())
824                    .map_err(|_| ClientError::InvalidDnsName(self.hostname.clone()))?,
825            };
826
827            match connector
828                .connect(
829                    server_name,
830                    match self.socket.take() {
831                        Some(s) => s,
832                        None => return Err(ClientError::NotConnected),
833                    },
834                )
835                .into_fallible()
836                .await
837            {
838                Ok(stream) => {
839                    let (_, conn) = stream.get_ref();
840                    tls_info.cipher = match conn.negotiated_cipher_suite() {
841                        Some(suite) => suite.suite().as_str().unwrap_or("UNKNOWN").to_string(),
842                        None => String::new(),
843                    };
844                    tls_info.protocol_version = match conn.protocol_version() {
845                        Some(version) => version.as_str().unwrap_or("UNKNOWN").to_string(),
846                        None => String::new(),
847                    };
848
849                    if let Some(certs) = conn.peer_certificates() {
850                        let peer_cert = &certs[0];
851                        if let Ok(cert) = X509::from_der(peer_cert.as_ref()) {
852                            tls_info.subject_name = subject_name(&cert);
853                        }
854                    }
855
856                    Box::new(stream)
857                }
858                Err((err, stream)) => {
859                    handshake_error.replace(format!("{err:#}"));
860                    stream
861                }
862            }
863        };
864
865        if let Some(tracer) = &self.tracer {
866            tracer.trace_event(SmtpClientTraceEvent::Diagnostic {
867                level: Level::INFO,
868                message: match &handshake_error {
869                    Some(error) => format!("STARTTLS handshake failed: {error:?}"),
870                    None => format!("STARTTLS handshake -> {tls_info:?}"),
871                },
872            });
873        }
874
875        self.socket.replace(stream);
876        Ok(match handshake_error {
877            Some(error) => TlsStatus::FailedHandshake(error),
878            None => TlsStatus::Info(tls_info),
879        })
880    }
881
882    pub async fn send_mail<B: AsRef<[u8]>, SENDER: Into<ReversePath>, RECIP: Into<ForwardPath>>(
883        &mut self,
884        sender: SENDER,
885        recipient: RECIP,
886        data: B,
887    ) -> Result<Response, ClientError> {
888        let recipient: ForwardPath = recipient.into();
889        let status = self
890            .send_mail_multi_recip(sender, vec![recipient], data)
891            .await?;
892        Ok(status.response)
893    }
894
895    pub async fn send_mail_multi_recip<B: AsRef<[u8]>, SENDER: Into<ReversePath>>(
896        &mut self,
897        sender: SENDER,
898        recipient_list: Vec<ForwardPath>,
899        data: B,
900    ) -> Result<BatchSendSuccess, ClientError> {
901        let sender = sender.into();
902
903        let data: &[u8] = data.as_ref();
904        let stuffed;
905
906        let data = match apply_dot_stuffing(data) {
907            Some(d) => {
908                stuffed = d;
909                &stuffed
910            }
911            None => data,
912        };
913
914        let data_is_8bit = data.iter().any(|&b| b >= 0x80);
915        let envelope_is_8bit =
916            !sender.is_ascii() || recipient_list.iter().any(|recipient| !recipient.is_ascii());
917
918        let mut mail_from_params = vec![];
919        if data_is_8bit {
920            if self.capabilities.contains_key("8BITMIME") {
921                mail_from_params.push(EsmtpParameter {
922                    name: "BODY".to_string(),
923                    value: Some("8BITMIME".to_string()),
924                });
925            } else if !self.ignore_8bit_checks {
926                return Err(ClientError::Rejected(Response {
927                    code: 554,
928                    command: None,
929                    enhanced_code: Some(EnhancedStatusCode {
930                        class: 5,
931                        subject: 6,
932                        detail: 3,
933                    }),
934                    content: "KumoMTA internal: DATA is 8bit, destination does \
935                        not support 8BITMIME. Conversion via msg:check_fix_conformance \
936                        during reception is required"
937                        .to_string(),
938                }));
939            }
940        }
941
942        if envelope_is_8bit {
943            if self.capabilities.contains_key("SMTPUTF8") {
944                mail_from_params.push(EsmtpParameter {
945                    name: "SMTPUTF8".to_string(),
946                    value: None,
947                });
948            } else if !self.ignore_8bit_checks {
949                return Err(ClientError::Rejected(Response {
950                    code: 554,
951                    command: None,
952                    enhanced_code: Some(EnhancedStatusCode {
953                        class: 5,
954                        subject: 6,
955                        detail: 7,
956                    }),
957                    content: "KumoMTA internal: envelope is 8bit, destination does \
958                        not support SMTPUTF8."
959                        .to_string(),
960                }));
961            }
962        }
963
964        let mut commands = vec![];
965
966        // We want to avoid using RSET for the first message we send on
967        // a given connection, because postfix can run in a mode where
968        // it will not tolerate RSET because it considers it to be a "junk"
969        // command, and rejecting junk commands will cut down on its load
970        // when it is under stress; it is used as a load shedding approach.
971        // If we always RSET then we will never deliver to a site that is
972        // configured that way. If we take care to RSET only for subsequent
973        // sends, then we should get one message per connection through
974        // without being unfairly penalized for defensively RSETing.
975        let used_rset = self.use_rset;
976        if self.use_rset {
977            commands.push(Command::Rset);
978        }
979        commands.push(Command::MailFrom {
980            address: sender,
981            parameters: mail_from_params,
982        });
983
984        for recipient in &recipient_list {
985            commands.push(Command::RcptTo {
986                address: recipient.clone(),
987                parameters: vec![],
988            });
989        }
990        commands.push(Command::Data);
991
992        // Assume that something might break below: if it does, we want
993        // to ensure that we RSET the connection on the next go around.
994        self.use_rset = true;
995
996        let mut responses = self.pipeline_commands(commands).await;
997
998        // This is a little awkward. We want to handle the RFC 2090 3.1 case
999        // below, which requires deferring checking the actual response codes
1000        // until later, but we must also handle the case where we had a hard
1001        // transport error partway through pipelining.
1002        // So we set a flag for that case and will then "eagerly", wrt. the
1003        // RFC 2090 3.1 logic, evaluate the SMTP response codes, so that we
1004        // can propagate the correct error disposition up to the caller.
1005        let is_err = responses.iter().any(|r| r.is_err());
1006
1007        if used_rset {
1008            let rset_resp = responses.remove(0)?;
1009            if rset_resp.code != 250 {
1010                return Err(ClientError::Rejected(rset_resp));
1011            }
1012        }
1013
1014        let mail_resp = responses.remove(0)?;
1015        if is_err && mail_resp.code != 250 {
1016            return Err(ClientError::Rejected(mail_resp));
1017        }
1018
1019        let mut rcpt_responses = vec![];
1020        for _ in &recipient_list {
1021            rcpt_responses.push(responses.remove(0)?);
1022        }
1023
1024        if is_err && rcpt_responses.iter().all(|resp| resp.code != 250) {
1025            return Err(ClientError::RejectedBatch(rcpt_responses));
1026        }
1027
1028        let data_resp = responses.remove(0)?;
1029        if is_err && data_resp.code != 354 {
1030            return Err(ClientError::Rejected(data_resp));
1031        }
1032
1033        if data_resp.code == 354
1034            && (mail_resp.code != 250 || rcpt_responses.iter().all(|resp| resp.code != 250))
1035        {
1036            // RFC 2920 3.1:
1037            // the client cannot assume that the DATA command will be rejected
1038            // just because none of the RCPT TO commands worked.  If the DATA
1039            // command was properly rejected the client SMTP can just issue
1040            // RSET, but if the DATA command was accepted the client SMTP
1041            // should send a single dot.
1042
1043            // Send dummy data
1044            self.write_data_with_timeout(b".\r\n").await?;
1045            let data_dot = Command::DataDot;
1046            // wait for its response
1047            let _ = self
1048                .read_response(Some(&data_dot), data_dot.client_timeout(&self.timeouts))
1049                .await?;
1050
1051            // Continue below: we will match one of the failure cases and
1052            // return a ClientError::Rejected from one of the earlier
1053            // commands
1054        }
1055
1056        if mail_resp.code != 250 {
1057            return Err(ClientError::Rejected(mail_resp));
1058        }
1059        if rcpt_responses.iter().all(|resp| resp.code != 250) {
1060            if rcpt_responses.len() == 1 {
1061                return Err(ClientError::Rejected(
1062                    rcpt_responses.pop().expect("have at least one"),
1063                ));
1064            }
1065            return Err(ClientError::RejectedBatch(rcpt_responses));
1066        }
1067        if data_resp.code != 354 {
1068            return Err(ClientError::Rejected(data_resp));
1069        }
1070
1071        let needs_newline = data.last().map(|&b| b != b'\n').unwrap_or(true);
1072
1073        tracing::trace!("message data is {} bytes", data.len());
1074
1075        self.write_data_with_timeout(data).await?;
1076
1077        let marker = if needs_newline { "\r\n.\r\n" } else { ".\r\n" };
1078
1079        tracing::trace!("send->{}: {}", self.hostname, marker.escape_debug());
1080
1081        self.write_data_with_timeout(marker.as_bytes()).await?;
1082
1083        let data_dot = Command::DataDot;
1084        let resp = self
1085            .read_response(Some(&data_dot), data_dot.client_timeout(&self.timeouts))
1086            .await?;
1087        if resp.code != 250 {
1088            return Err(ClientError::Rejected(resp));
1089        }
1090
1091        // If everything went well, respect the user preference for speculatively
1092        // issuing an RSET next time around
1093        self.use_rset = self.enable_rset;
1094
1095        Ok(BatchSendSuccess {
1096            response: resp,
1097            rcpt_responses,
1098        })
1099    }
1100}
1101
/// Successful outcome of `SmtpClient::send_mail_multi_recip`.
///
/// The final `.` was answered with 250 and at least one recipient was
/// answered with 250. Other recipients may still have been rejected, so
/// callers should inspect `rcpt_responses` individually.
#[derive(Debug)]
pub struct BatchSendSuccess {
    /// The server's response to the end-of-data `.`.
    pub response: Response,
    /// The `RCPT TO` responses, one per recipient, in the same order as the
    /// recipient list passed to `send_mail_multi_recip`.
    pub rcpt_responses: Vec<Response>,
}
1107
/// Outcome of a STARTTLS handshake attempt.
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize, Clone)]
pub enum TlsStatus {
    /// A handshake error was recorded. Holds that error rendered with the
    /// alternate (`{:#}`) formatter.
    FailedHandshake(String),
    /// No handshake error was recorded. Holds the negotiated session details.
    Info(TlsInformation),
}
1113
/// Details about a TLS session established via STARTTLS.
#[derive(Debug, Default, PartialEq, Eq, Serialize, Deserialize, Clone)]
pub struct TlsInformation {
    /// Name of the negotiated cipher (suite). Empty when the TLS provider
    /// reported none.
    pub cipher: String,
    /// Negotiated protocol version as reported by the TLS provider.
    /// May be empty when the provider reported none.
    pub protocol_version: String,
    /// `object=value` entries of the peer's (or DANE authority's)
    /// certificate subject, as produced by `subject_name`.
    pub subject_name: Vec<String>,
    /// Which TLS implementation ran the handshake: `"openssl"` or `"rustls"`.
    pub provider_name: String,
}
1121
1122impl Drop for SmtpClient {
1123    fn drop(&mut self) {
1124        if let Some(tracer) = &self.tracer {
1125            if self.socket.is_some() {
1126                tracer.trace_event(SmtpClientTraceEvent::Closed);
1127            }
1128        }
1129    }
1130}
1131fn parse_response_line(line: &'_ str) -> Result<ResponseLine<'_>, ClientError> {
1132    if line.len() < 4 {
1133        return Err(ClientError::MalformedResponseLine(line.to_string()));
1134    }
1135
1136    match line.as_bytes()[3] {
1137        b' ' | b'-' => match line[0..3].parse::<u16>() {
1138            Ok(code) => Ok(ResponseLine {
1139                code,
1140                is_final: line.as_bytes()[3] == b' ',
1141                content: &line[4..],
1142            }),
1143            Err(_) => Err(ClientError::MalformedResponseLine(line.to_string())),
1144        },
1145        _ => Err(ClientError::MalformedResponseLine(line.to_string())),
1146    }
1147}
1148
1149fn apply_dot_stuffing(data: &[u8]) -> Option<Vec<u8>> {
1150    static LFDOT: LazyLock<Finder> = LazyLock::new(|| memchr::memmem::Finder::new("\n."));
1151
1152    if !data.starts_with(b".") && LFDOT.find(data).is_none() {
1153        return None;
1154    }
1155
1156    let mut stuffed = vec![];
1157    if data.starts_with(b".") {
1158        stuffed.push(b'.');
1159    }
1160    let mut last_idx = 0;
1161    for i in LFDOT.find_iter(data) {
1162        stuffed.extend_from_slice(&data[last_idx..=i]);
1163        stuffed.push(b'.');
1164        last_idx = i + 1;
1165    }
1166    stuffed.extend_from_slice(&data[last_idx..]);
1167    Some(stuffed)
1168}
1169
1170/// Extracts the object=name pairs of the subject name from a cert.
1171/// eg:
1172/// ```no_run
1173/// ["C=US", "ST=CA", "L=SanFrancisco", "O=Fort-Funston", "OU=MyOrganizationalUnit",
1174/// "CN=do.havedane.net", "name=EasyRSA", "emailAddress=me@myhost.mydomain"]
1175/// ```
1176pub fn subject_name(cert: &X509Ref) -> Vec<String> {
1177    let mut subject_name = vec![];
1178    for entry in cert.subject_name().entries() {
1179        if let Ok(obj) = entry.object().nid().short_name() {
1180            if let Ok(name) = entry.data().as_utf8() {
1181                subject_name.push(format!("{obj}={name}"));
1182            }
1183        }
1184    }
1185    subject_name
1186}
1187
#[cfg(test)]
mod test {
    use super::*;
    use crate::{MailPath, Mailbox};

    #[test]
    fn test_stuffing() {
        assert_eq!(apply_dot_stuffing(b"foo"), None);
        assert_eq!(apply_dot_stuffing(b""), None);
        assert_eq!(apply_dot_stuffing(b".foo").unwrap(), b"..foo");
        assert_eq!(apply_dot_stuffing(b".").unwrap(), b"..");
        assert_eq!(apply_dot_stuffing(b"foo\n.bar").unwrap(), b"foo\n..bar");
        assert_eq!(
            apply_dot_stuffing(b"foo\n.bar\n..baz\n").unwrap(),
            b"foo\n..bar\n...baz\n"
        );
        // CRLF line endings, as sent on the wire.
        assert_eq!(
            apply_dot_stuffing(b"a\r\n.\r\nb").unwrap(),
            b"a\r\n..\r\nb"
        );
    }

    #[test]
    fn response_line_parsing() {
        assert_eq!(
            parse_response_line("220 woot").unwrap(),
            ResponseLine {
                code: 220,
                is_final: true,
                content: "woot"
            }
        );
        assert_eq!(
            parse_response_line("220-woot").unwrap(),
            ResponseLine {
                code: 220,
                is_final: false,
                content: "woot"
            }
        );
        // A continuation line with no text is valid.
        assert_eq!(
            parse_response_line("250-").unwrap(),
            ResponseLine {
                code: 250,
                is_final: false,
                content: ""
            }
        );

        assert!(matches!(
            parse_response_line("220_woot"),
            Err(ClientError::MalformedResponseLine(_))
        ));
        assert!(matches!(
            parse_response_line("not really"),
            Err(ClientError::MalformedResponseLine(_))
        ));
    }

    fn parse_multi_line(lines: &[&str]) -> Result<Response, ClientError> {
        let mut parsed = parse_response_line(lines[0])?;
        let mut b = ResponseBuilder::new(&parsed);
        for line in &lines[1..] {
            parsed = parse_response_line(line)?;
            b.add_line(&parsed)
                .map_err(ClientError::MalformedResponseLine)?;
        }
        assert!(parsed.is_final);
        Ok(b.build(None))
    }

    #[test]
    fn multi_line_response() {
        assert_eq!(
            parse_multi_line(&["220-woot", "220-more", "220 done",]).unwrap(),
            Response {
                code: 220,
                enhanced_code: None,
                content: "woot\nmore\ndone".to_string(),
                command: None
            }
        );

        let res = parse_multi_line(&["220-woot", "221-more", "220 done"]).unwrap_err();
        assert!(
            matches!(
                res,
                ClientError::MalformedResponseLine(ref err) if err == "221-more"
            ),
            "got error {res:?}"
        );

        let res = parse_multi_line(&["220-woot", "220-more", "221 done"]).unwrap_err();
        assert!(
            matches!(
                res,
                ClientError::MalformedResponseLine(ref err) if err == "221 done"
            ),
            "got error {res:?}"
        );

        assert_eq!(
            parse_multi_line(&["220-4.1.0 woot", "220-more", "220 done",]).unwrap(),
            Response {
                code: 220,
                enhanced_code: Some(EnhancedStatusCode {
                    class: 4,
                    subject: 1,
                    detail: 0
                }),
                content: "woot\nmore\ndone".to_string(),
                command: None
            }
        );

        // Confirm that we strip the enhanced status code from each line
        assert_eq!(
            parse_multi_line(&["220-4.1.0 woot", "220-4.1.0 more", "220 done",]).unwrap(),
            Response {
                code: 220,
                enhanced_code: Some(EnhancedStatusCode {
                    class: 4,
                    subject: 1,
                    detail: 0
                }),
                content: "woot\nmore\ndone".to_string(),
                command: None
            }
        );

        // ... but only if the code matches that of the first line
        assert_eq!(
            parse_multi_line(&["220-4.1.0 woot", "220-4.1.0 more", "220 5.5.5 done",]).unwrap(),
            Response {
                code: 220,
                enhanced_code: Some(EnhancedStatusCode {
                    class: 4,
                    subject: 1,
                    detail: 0
                }),
                content: "woot\nmore\n5.5.5 done".to_string(),
                command: None
            }
        );
    }

    #[test]
    fn test_extract_hostname() {
        assert_eq!(extract_hostname("foo"), "foo");
        assert_eq!(extract_hostname("foo."), "foo");
        assert_eq!(extract_hostname("foo:25"), "foo");
        assert_eq!(extract_hostname("foo.:25"), "foo");
        assert_eq!(extract_hostname("[foo]:25"), "foo");
        assert_eq!(extract_hostname("[foo.]:25"), "foo");
        assert_eq!(extract_hostname("[::1]:25"), "::1");
        assert_eq!(extract_hostname("::1:25"), "::1");
    }

    #[test]
    fn test_format_error_command() {
        assert_eq!(
            format!(
                "{:#}",
                ClientError::TimeOutRequest {
                    commands: vec![Command::DataDot],
                    duration: Duration::from_secs(10),
                }
            ),
            "Timed Out after 10s writing cmd=."
        );
        assert_eq!(
            format!(
                "{:#}",
                ClientError::TimeOutResponse {
                    command: Some(Command::MailFrom {
                        address: ReversePath::Path(MailPath {
                            at_domain_list: vec![],
                            mailbox: Mailbox {
                                local_part: "user".to_string(),
                                domain: Domain::Name("host".to_string())
                            }
                        }),
                        parameters: vec![],
                    }),
                    duration: Duration::from_secs(10),
                }
            ),
            r#"Timed Out waiting 10s for response to cmd=MAIL FROM:<user@host>"#
        );
    }
}