mod_amqp/
amqprs_client.rs

1use amqprs::callbacks::{DefaultChannelCallback, DefaultConnectionCallback};
2use amqprs::channel::{BasicPublishArguments, Channel, ConfirmSelectArguments};
3use amqprs::connection::{Connection, OpenConnectionArguments};
4use amqprs::tls::TlsAdaptor;
5use amqprs::{BasicProperties, FieldTable, TimeStamp};
6use deadpool::managed::{Manager, Metrics, Pool, RecycleError, RecycleResult};
7use kumo_server_memory::subscribe_to_memory_status_changes_async;
8use parking_lot::Mutex;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::path::Path;
12use std::sync::{Arc, LazyLock};
13use std::time::Duration;
14
15static POOLS: LazyLock<Mutex<Pools>> = LazyLock::new(Pools::new);
16
17struct Pools {
18    map: HashMap<ConnectionInfo, Arc<Pool<ConnectionManager>>>,
19}
20
21impl Pools {
22    pub fn new() -> Mutex<Self> {
23        tokio::spawn(Self::memory_monitor());
24        Mutex::new(Self {
25            map: HashMap::new(),
26        })
27    }
28
29    pub fn get(&self, info: &ConnectionInfo) -> Option<&Arc<Pool<ConnectionManager>>> {
30        self.map.get(info)
31    }
32
33    pub fn insert(&mut self, info: ConnectionInfo, pool: Arc<Pool<ConnectionManager>>) {
34        self.map.insert(info, pool);
35    }
36
37    async fn memory_monitor() {
38        tracing::debug!("starting memory monitor");
39        tokio::time::sleep(Duration::from_secs(10)).await;
40
41        let mut memory_status = subscribe_to_memory_status_changes_async().await;
42        while let Ok(()) = memory_status.changed().await {
43            if kumo_server_memory::get_headroom() == 0 {
44                Self::purge().await;
45            }
46        }
47    }
48
49    fn all_pools() -> Vec<Arc<Pool<ConnectionManager>>> {
50        POOLS.lock().map.values().cloned().collect()
51    }
52
53    async fn purge() {
54        let pools = Self::all_pools();
55        for pool in pools {
56            let info = &pool.manager().0;
57
58            let result = pool.retain(|_, _| false);
59            tracing::error!(
60                "purging {} amqprs connections for {}:{:?}",
61                result.removed.len(),
62                info.host,
63                info.port
64            );
65
66            for connection in result.removed {
67                if let Err(err) = connection.channel.close().await {
68                    tracing::error!(
69                        "Error closing channel to {}:{:?}: {err:#}",
70                        info.host,
71                        info.port
72                    );
73                }
74                if let Err(err) = connection.connection.connection.close().await {
75                    tracing::error!(
76                        "Error closing connection to {}:{:?}: {err:#}",
77                        info.host,
78                        info.port
79                    );
80                }
81            }
82        }
83    }
84}
85
/// Everything needed to open (and pool) a connection to an AMQP broker.
///
/// This is also the key of the process-wide pool registry. Every field takes
/// part in the derived `Hash`/`Eq`, so configurations that differ in any field
/// (timeouts included) get separate pools.
// NOTE(review): the derived `Debug` prints `password` in clear text; avoid
// logging this type with `{:?}`, or consider a redacting `Debug` impl.
#[derive(Clone, Debug, Serialize, Deserialize, Hash, Eq, PartialEq)]
#[serde(deny_unknown_fields)]
pub struct ConnectionInfo {
    /// Broker hostname. Also passed as the TLS domain when `enable_tls` is set.
    pub host: String,
    /// Broker port. `connect` uses 5672 when unset, even with TLS enabled.
    pub port: Option<u16>,
    /// Login user name. `connect` uses "guest" when unset.
    pub username: Option<String>,
    /// Login password. `connect` uses "guest" when unset.
    pub password: Option<String>,
    /// Virtual host. Only applied to the connection arguments when set.
    pub vhost: Option<String>,
    /// Client-provided connection name. Only applied when set.
    pub connection_name: Option<String>,
    /// Heartbeat value for the connection arguments. Only applied when set.
    pub heartbeat: Option<u16>,
    /// Wrap the connection in TLS. Defaults to `false`.
    #[serde(default)]
    pub enable_tls: bool,
    /// Path to a CA certificate file handed to the TLS adaptor (TLS only).
    pub root_ca_cert: Option<String>,
    /// Path to the client certificate for TLS client authentication. It must
    /// be given together with `client_private_key`, or not at all.
    pub client_cert: Option<String>,
    /// Path to the client private key. See `client_cert`.
    pub client_private_key: Option<String>,

    // TODO: not fully implemented
    /// Enable publisher-confirm mode (`confirm_select`) on each pooled channel.
    /// Publishing does not yet wait for the confirms.
    #[serde(default)]
    pub confirm_select: bool,

    /// Maximum pool size. The pool builder's default applies when unset.
    #[serde(default)]
    pub pool_size: Option<usize>,
    /// Timeout for creating a new pooled connection (deadpool `create_timeout`).
    #[serde(default, with = "duration_serde")]
    pub connect_timeout: Option<Duration>,
    /// Timeout for recycling a pooled connection (deadpool `recycle_timeout`).
    #[serde(default, with = "duration_serde")]
    pub recycle_timeout: Option<Duration>,
    /// Timeout for waiting on a pool slot (deadpool `wait_timeout`).
    #[serde(default, with = "duration_serde")]
    pub wait_timeout: Option<Duration>,
    /// Timeout for a single publish. `publish` uses 60 seconds when unset.
    #[serde(default, with = "duration_serde")]
    pub publish_timeout: Option<Duration>,
}
117
118pub struct ConnectionManager(ConnectionInfo);
119pub struct ConnectionWithInfo {
120    connection: Connection,
121    // TODO: when wiring up confirm_select, we need to look
122    // at this to see whether we should await for the publish
123    #[allow(unused)]
124    info: ConnectionInfo,
125}
126
127pub struct ConnectionAndChannel {
128    connection: ConnectionWithInfo,
129    channel: Channel,
130}
131
132impl Manager for ConnectionManager {
133    type Type = ConnectionAndChannel;
134    type Error = anyhow::Error;
135
136    async fn create(&self) -> Result<Self::Type, Self::Error> {
137        let connection = self.0.connect().await?;
138
139        connection
140            .register_callback(DefaultConnectionCallback)
141            .await?;
142
143        let channel = connection.open_channel(None).await?;
144        channel.register_callback(DefaultChannelCallback).await?;
145        if self.0.confirm_select {
146            channel
147                .confirm_select(ConfirmSelectArguments::default())
148                .await?;
149        }
150
151        Ok(ConnectionAndChannel {
152            connection: ConnectionWithInfo {
153                connection,
154                info: self.0.clone(),
155            },
156            channel,
157        })
158    }
159
160    async fn recycle(
161        &self,
162        conn: &mut Self::Type,
163        _metrics: &Metrics,
164    ) -> RecycleResult<anyhow::Error> {
165        if conn.connection.connection.is_open() && conn.channel.is_open() {
166            Ok(())
167        } else {
168            Err(RecycleError::message("channel/connection is closed"))
169        }
170    }
171}
172
173impl ConnectionInfo {
174    pub async fn connect(&self) -> anyhow::Result<Connection> {
175        let mut args = OpenConnectionArguments::new(
176            &self.host,
177            self.port.unwrap_or(5672),
178            self.username.as_deref().unwrap_or("guest"),
179            self.password.as_deref().unwrap_or("guest"),
180        );
181        if let Some(vhost) = &self.vhost {
182            args.virtual_host(vhost);
183        }
184        if let Some(name) = &self.connection_name {
185            args.connection_name(name);
186        }
187        if let Some(hb) = self.heartbeat {
188            args.heartbeat(hb);
189        }
190        if self.enable_tls {
191            let adaptor = match (&self.client_cert, &self.client_private_key) {
192                (Some(cert), Some(key)) => TlsAdaptor::with_client_auth(
193                    self.root_ca_cert.as_deref().map(Path::new),
194                    Path::new(cert),
195                    Path::new(key),
196                    self.host.to_string(),
197                )?,
198                (None, None) => TlsAdaptor::without_client_auth(
199                    self.root_ca_cert.as_deref().map(Path::new),
200                    self.host.to_string(),
201                )?,
202                _ => anyhow::bail!(
203                    "Either both client_cert and client_private_key must be specified, or neither"
204                ),
205            };
206            args.tls_adaptor(adaptor);
207        }
208
209        let connection = Connection::open(&args).await?;
210
211        Ok(connection)
212    }
213
214    pub fn get_pool(&self) -> anyhow::Result<Arc<Pool<ConnectionManager>>> {
215        let mut pools = POOLS.lock();
216        if let Some(pool) = pools.get(self) {
217            return Ok(pool.clone());
218        }
219
220        let mut builder = Pool::builder(ConnectionManager(self.clone()))
221            .runtime(deadpool::Runtime::Tokio1)
222            .create_timeout(self.connect_timeout)
223            .recycle_timeout(self.recycle_timeout)
224            .wait_timeout(self.wait_timeout);
225
226        if let Some(limit) = self.pool_size {
227            builder = builder.max_size(limit);
228        }
229
230        let pool = Arc::new(builder.build()?);
231
232        pools.insert(self.clone(), pool.clone());
233
234        Ok(pool)
235    }
236}
237
/// Arguments accepted by `publish`: where to send, what to send, and the
/// optional AMQP message properties.
// NOTE(review): the derived `Debug` includes `connection.password`; avoid
// logging this type with `{:?}`.
#[derive(Deserialize, Debug)]
#[serde(deny_unknown_fields)]
pub struct PublishParams {
    /// Routing key for the published message.
    pub routing_key: String,
    /// Message body. It is published as its UTF-8 bytes.
    pub payload: String,
    /// Broker connection settings. This also selects (or creates) the pool.
    pub connection: ConnectionInfo,

    // Optional message properties. `publish` copies each one that is set onto
    // the outgoing `BasicProperties`; unset ones are left at their defaults.
    pub app_id: Option<String>,
    pub cluster_id: Option<String>,
    pub content_encoding: Option<String>,
    pub content_type: Option<String>,
    pub correlation_id: Option<String>,
    pub delivery_mode: Option<u8>,
    pub expiration: Option<String>,
    pub headers: Option<FieldTable>,
    pub message_id: Option<String>,
    pub message_type: Option<String>,
    pub priority: Option<u8>,
    pub reply_to: Option<String>,
    pub timestamp: Option<TimeStamp>,
    pub user_id: Option<String>,

    /// Exchange to publish to. Defaults to the empty string, which presumably
    /// selects the broker's default exchange (TODO confirm).
    #[serde(default)]
    pub exchange: String,
    /// `mandatory` publish flag. Defaults to `false`.
    #[serde(default)]
    pub mandatory: bool,
    /// `immediate` publish flag. Defaults to `false`.
    #[serde(default)]
    pub immediate: bool,
}
267
268pub async fn publish(params: PublishParams) -> anyhow::Result<()> {
269    kumo_server_runtime::get_main_runtime()
270        .spawn(async move {
271            let connection = params
272                .connection
273                .get_pool()?
274                .get()
275                .await
276                .map_err(|err| anyhow::anyhow!("{err:#}"))?;
277
278            let mut props = BasicProperties::default();
279            if let Some(v) = &params.app_id {
280                props.with_app_id(v);
281            }
282            if let Some(v) = &params.cluster_id {
283                props.with_cluster_id(v);
284            }
285            if let Some(v) = &params.content_encoding {
286                props.with_content_encoding(v);
287            }
288            if let Some(v) = &params.content_type {
289                props.with_content_type(v);
290            }
291            if let Some(v) = &params.correlation_id {
292                props.with_correlation_id(v);
293            }
294            if let Some(v) = params.delivery_mode {
295                props.with_delivery_mode(v);
296            }
297            if let Some(v) = &params.expiration {
298                props.with_expiration(v);
299            }
300            if let Some(v) = params.headers {
301                props.with_headers(v);
302            }
303            if let Some(v) = &params.message_id {
304                props.with_message_id(v);
305            }
306            if let Some(v) = &params.message_type {
307                props.with_message_type(v);
308            }
309            if let Some(v) = params.priority {
310                props.with_priority(v);
311            }
312            if let Some(v) = &params.reply_to {
313                props.with_reply_to(v);
314            }
315            if let Some(v) = params.timestamp {
316                props.with_timestamp(v);
317            }
318            if let Some(v) = &params.user_id {
319                props.with_user_id(v);
320            }
321
322            let args = BasicPublishArguments {
323                exchange: params.exchange,
324                routing_key: params.routing_key,
325                mandatory: params.mandatory,
326                immediate: params.immediate,
327            };
328
329            let timeout_duration = params
330                .connection
331                .publish_timeout
332                .unwrap_or_else(|| Duration::from_secs(60));
333
334            tokio::time::timeout(
335                timeout_duration,
336                connection
337                    .channel
338                    .basic_publish(props, params.payload.into_bytes(), args),
339            )
340            .await??;
341
342            Ok(())
343        })
344        .await?
345}