1use amqprs::callbacks::{DefaultChannelCallback, DefaultConnectionCallback};
2use amqprs::channel::{BasicPublishArguments, Channel, ConfirmSelectArguments};
3use amqprs::connection::{Connection, OpenConnectionArguments};
4use amqprs::tls::TlsAdaptor;
5use amqprs::{BasicProperties, FieldTable, TimeStamp};
6use deadpool::managed::{Manager, Metrics, Pool, RecycleError, RecycleResult};
7use kumo_server_memory::subscribe_to_memory_status_changes_async;
8use parking_lot::Mutex;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::path::Path;
12use std::sync::{Arc, LazyLock};
13use std::time::Duration;
14
/// Process-wide registry of connection pools, keyed by the full `ConnectionInfo`.
///
/// Initialized on first access through `Pools::new`, which also spawns the memory monitor
/// task with `tokio::spawn`; the first access must therefore happen inside a tokio runtime.
static POOLS: LazyLock<Mutex<Pools>> = LazyLock::new(Pools::new);
16
/// Map from connection parameters to the shared deadpool pool serving them.
///
/// Two `ConnectionInfo` values that compare equal share a single pool.
struct Pools {
    map: HashMap<ConnectionInfo, Arc<Pool<ConnectionManager>>>,
}
20
21impl Pools {
22 pub fn new() -> Mutex<Self> {
23 tokio::spawn(Self::memory_monitor());
24 Mutex::new(Self {
25 map: HashMap::new(),
26 })
27 }
28
29 pub fn get(&self, info: &ConnectionInfo) -> Option<&Arc<Pool<ConnectionManager>>> {
30 self.map.get(info)
31 }
32
33 pub fn insert(&mut self, info: ConnectionInfo, pool: Arc<Pool<ConnectionManager>>) {
34 self.map.insert(info, pool);
35 }
36
37 async fn memory_monitor() {
38 tracing::debug!("starting memory monitor");
39 tokio::time::sleep(Duration::from_secs(10)).await;
40
41 let mut memory_status = subscribe_to_memory_status_changes_async().await;
42 while let Ok(()) = memory_status.changed().await {
43 if kumo_server_memory::get_headroom() == 0 {
44 Self::purge().await;
45 }
46 }
47 }
48
49 fn all_pools() -> Vec<Arc<Pool<ConnectionManager>>> {
50 POOLS.lock().map.values().cloned().collect()
51 }
52
53 async fn purge() {
54 let pools = Self::all_pools();
55 for pool in pools {
56 let info = &pool.manager().0;
57
58 let result = pool.retain(|_, _| false);
59 tracing::error!(
60 "purging {} amqprs connections for {}:{:?}",
61 result.removed.len(),
62 info.host,
63 info.port
64 );
65
66 for connection in result.removed {
67 if let Err(err) = connection.channel.close().await {
68 tracing::error!(
69 "Error closing channel to {}:{:?}: {err:#}",
70 info.host,
71 info.port
72 );
73 }
74 if let Err(err) = connection.connection.connection.close().await {
75 tracing::error!(
76 "Error closing connection to {}:{:?}: {err:#}",
77 info.host,
78 info.port
79 );
80 }
81 }
82 }
83 }
84}
85
86#[derive(Clone, Debug, Serialize, Deserialize, Hash, Eq, PartialEq)]
87#[serde(deny_unknown_fields)]
88pub struct ConnectionInfo {
89 pub host: String,
90 pub port: Option<u16>,
91 pub username: Option<String>,
92 pub password: Option<String>,
93 pub vhost: Option<String>,
94 pub connection_name: Option<String>,
95 pub heartbeat: Option<u16>,
96 #[serde(default)]
97 pub enable_tls: bool,
98 pub root_ca_cert: Option<String>,
99 pub client_cert: Option<String>,
100 pub client_private_key: Option<String>,
101
102 #[serde(default)]
104 pub confirm_select: bool,
105
106 #[serde(default)]
107 pub pool_size: Option<usize>,
108 #[serde(default, with = "duration_serde")]
109 pub connect_timeout: Option<Duration>,
110 #[serde(default, with = "duration_serde")]
111 pub recycle_timeout: Option<Duration>,
112 #[serde(default, with = "duration_serde")]
113 pub wait_timeout: Option<Duration>,
114 #[serde(default, with = "duration_serde")]
115 pub publish_timeout: Option<Duration>,
116}
117
/// deadpool manager that opens connection/channel pairs described by the wrapped
/// `ConnectionInfo`.
pub struct ConnectionManager(ConnectionInfo);
/// An open broker connection together with the `ConnectionInfo` it was created from.
pub struct ConnectionWithInfo {
    connection: Connection,
    // `info` is filled in by `ConnectionManager::create` but is not read anywhere in the
    // code visible here, hence the lint allowance.
    // NOTE(review): presumably retained for diagnostics — confirm before removing.
    #[allow(unused)]
    info: ConnectionInfo,
}
126
/// The pooled object: an open connection plus the single channel opened on it.
pub struct ConnectionAndChannel {
    connection: ConnectionWithInfo,
    channel: Channel,
}
131
132impl Manager for ConnectionManager {
133 type Type = ConnectionAndChannel;
134 type Error = anyhow::Error;
135
136 async fn create(&self) -> Result<Self::Type, Self::Error> {
137 let connection = self.0.connect().await?;
138
139 connection
140 .register_callback(DefaultConnectionCallback)
141 .await?;
142
143 let channel = connection.open_channel(None).await?;
144 channel.register_callback(DefaultChannelCallback).await?;
145 if self.0.confirm_select {
146 channel
147 .confirm_select(ConfirmSelectArguments::default())
148 .await?;
149 }
150
151 Ok(ConnectionAndChannel {
152 connection: ConnectionWithInfo {
153 connection,
154 info: self.0.clone(),
155 },
156 channel,
157 })
158 }
159
160 async fn recycle(
161 &self,
162 conn: &mut Self::Type,
163 _metrics: &Metrics,
164 ) -> RecycleResult<anyhow::Error> {
165 if conn.connection.connection.is_open() && conn.channel.is_open() {
166 Ok(())
167 } else {
168 Err(RecycleError::message("channel/connection is closed"))
169 }
170 }
171}
172
173impl ConnectionInfo {
174 pub async fn connect(&self) -> anyhow::Result<Connection> {
175 let mut args = OpenConnectionArguments::new(
176 &self.host,
177 self.port.unwrap_or(5672),
178 self.username.as_deref().unwrap_or("guest"),
179 self.password.as_deref().unwrap_or("guest"),
180 );
181 if let Some(vhost) = &self.vhost {
182 args.virtual_host(vhost);
183 }
184 if let Some(name) = &self.connection_name {
185 args.connection_name(name);
186 }
187 if let Some(hb) = self.heartbeat {
188 args.heartbeat(hb);
189 }
190 if self.enable_tls {
191 let adaptor = match (&self.client_cert, &self.client_private_key) {
192 (Some(cert), Some(key)) => TlsAdaptor::with_client_auth(
193 self.root_ca_cert.as_deref().map(Path::new),
194 Path::new(cert),
195 Path::new(key),
196 self.host.to_string(),
197 )?,
198 (None, None) => TlsAdaptor::without_client_auth(
199 self.root_ca_cert.as_deref().map(Path::new),
200 self.host.to_string(),
201 )?,
202 _ => anyhow::bail!(
203 "Either both client_cert and client_private_key must be specified, or neither"
204 ),
205 };
206 args.tls_adaptor(adaptor);
207 }
208
209 let connection = Connection::open(&args).await?;
210
211 Ok(connection)
212 }
213
214 pub fn get_pool(&self) -> anyhow::Result<Arc<Pool<ConnectionManager>>> {
215 let mut pools = POOLS.lock();
216 if let Some(pool) = pools.get(self) {
217 return Ok(pool.clone());
218 }
219
220 let mut builder = Pool::builder(ConnectionManager(self.clone()))
221 .runtime(deadpool::Runtime::Tokio1)
222 .create_timeout(self.connect_timeout)
223 .recycle_timeout(self.recycle_timeout)
224 .wait_timeout(self.wait_timeout);
225
226 if let Some(limit) = self.pool_size {
227 builder = builder.max_size(limit);
228 }
229
230 let pool = Arc::new(builder.build()?);
231
232 pools.insert(self.clone(), pool.clone());
233
234 Ok(pool)
235 }
236}
237
/// Everything needed to publish one message: target, body, connection and optional AMQP
/// message properties.
#[derive(Deserialize, Debug)]
#[serde(deny_unknown_fields)]
pub struct PublishParams {
    /// Routing key passed to `basic_publish`.
    pub routing_key: String,
    /// Message body; sent as its UTF-8 bytes.
    pub payload: String,
    /// Connection parameters; selects (or creates) the pool used to publish.
    pub connection: ConnectionInfo,

    // Optional message properties. Each one that is `Some` is copied onto the
    // `BasicProperties` sent with the message; unset ones keep the default.
    pub app_id: Option<String>,
    pub cluster_id: Option<String>,
    pub content_encoding: Option<String>,
    pub content_type: Option<String>,
    pub correlation_id: Option<String>,
    pub delivery_mode: Option<u8>,
    pub expiration: Option<String>,
    pub headers: Option<FieldTable>,
    pub message_id: Option<String>,
    pub message_type: Option<String>,
    pub priority: Option<u8>,
    pub reply_to: Option<String>,
    pub timestamp: Option<TimeStamp>,
    pub user_id: Option<String>,

    /// Exchange to publish to; defaults to the empty string when omitted.
    #[serde(default)]
    pub exchange: String,
    /// `mandatory` publish flag; defaults to `false`.
    #[serde(default)]
    pub mandatory: bool,
    /// `immediate` publish flag; defaults to `false`.
    #[serde(default)]
    pub immediate: bool,
}
267
268pub async fn publish(params: PublishParams) -> anyhow::Result<()> {
269 kumo_server_runtime::get_main_runtime()
270 .spawn(async move {
271 let connection = params
272 .connection
273 .get_pool()?
274 .get()
275 .await
276 .map_err(|err| anyhow::anyhow!("{err:#}"))?;
277
278 let mut props = BasicProperties::default();
279 if let Some(v) = ¶ms.app_id {
280 props.with_app_id(v);
281 }
282 if let Some(v) = ¶ms.cluster_id {
283 props.with_cluster_id(v);
284 }
285 if let Some(v) = ¶ms.content_encoding {
286 props.with_content_encoding(v);
287 }
288 if let Some(v) = ¶ms.content_type {
289 props.with_content_type(v);
290 }
291 if let Some(v) = ¶ms.correlation_id {
292 props.with_correlation_id(v);
293 }
294 if let Some(v) = params.delivery_mode {
295 props.with_delivery_mode(v);
296 }
297 if let Some(v) = ¶ms.expiration {
298 props.with_expiration(v);
299 }
300 if let Some(v) = params.headers {
301 props.with_headers(v);
302 }
303 if let Some(v) = ¶ms.message_id {
304 props.with_message_id(v);
305 }
306 if let Some(v) = ¶ms.message_type {
307 props.with_message_type(v);
308 }
309 if let Some(v) = params.priority {
310 props.with_priority(v);
311 }
312 if let Some(v) = ¶ms.reply_to {
313 props.with_reply_to(v);
314 }
315 if let Some(v) = params.timestamp {
316 props.with_timestamp(v);
317 }
318 if let Some(v) = ¶ms.user_id {
319 props.with_user_id(v);
320 }
321
322 let args = BasicPublishArguments {
323 exchange: params.exchange,
324 routing_key: params.routing_key,
325 mandatory: params.mandatory,
326 immediate: params.immediate,
327 };
328
329 let timeout_duration = params
330 .connection
331 .publish_timeout
332 .unwrap_or_else(|| Duration::from_secs(60));
333
334 tokio::time::timeout(
335 timeout_duration,
336 connection
337 .channel
338 .basic_publish(props, params.payload.into_bytes(), args),
339 )
340 .await??;
341
342 Ok(())
343 })
344 .await?
345}