mod_mpsc/
lib.rs

1use config::{any_err, get_or_create_sub_module};
2use dashmap::DashMap;
3use mlua::prelude::*;
4use mlua::{Lua, UserDataMethods};
5use mod_memoize::CacheValue;
6use std::sync::{Arc, LazyLock};
7use tokio::sync::Mutex;
8use tokio::sync::mpsc::{
9    Receiver, Sender, UnboundedReceiver, UnboundedSender, channel, unbounded_channel,
10};
11use tokio::time::Duration;
12
13enum SenderWrapper {
14    Bounded(Sender<CacheValue>),
15    Unbounded(UnboundedSender<CacheValue>),
16}
17
18impl SenderWrapper {
19    pub async fn send(&self, value: CacheValue) -> anyhow::Result<()> {
20        match self {
21            Self::Bounded(s) => {
22                s.send(value).await?;
23            }
24            Self::Unbounded(s) => {
25                s.send(value)?;
26            }
27        };
28        Ok(())
29    }
30
31    pub fn try_send(&self, value: CacheValue) -> anyhow::Result<()> {
32        match self {
33            Self::Bounded(s) => {
34                s.try_send(value)?;
35            }
36            Self::Unbounded(s) => {
37                s.send(value)?;
38            }
39        };
40        Ok(())
41    }
42
43    pub async fn send_timeout(&self, value: CacheValue, duration: Duration) -> anyhow::Result<()> {
44        match self {
45            Self::Bounded(s) => {
46                s.send_timeout(value, duration).await?;
47            }
48            Self::Unbounded(s) => {
49                s.send(value)?;
50            }
51        };
52        Ok(())
53    }
54    pub fn is_closed(&self) -> bool {
55        match self {
56            Self::Bounded(r) => r.is_closed(),
57            Self::Unbounded(r) => r.is_closed(),
58        }
59    }
60}
61
62enum ReceiverWrapper {
63    Bounded(Receiver<CacheValue>),
64    Unbounded(UnboundedReceiver<CacheValue>),
65}
66
67impl ReceiverWrapper {
68    pub async fn recv(&mut self) -> Option<CacheValue> {
69        match self {
70            Self::Bounded(r) => r.recv().await,
71            Self::Unbounded(r) => r.recv().await,
72        }
73    }
74
75    pub fn try_recv(&mut self) -> Option<CacheValue> {
76        match self {
77            Self::Bounded(r) => r.try_recv().ok(),
78            Self::Unbounded(r) => r.try_recv().ok(),
79        }
80    }
81
82    pub async fn recv_many(&mut self, limit: usize) -> Vec<CacheValue> {
83        let mut buffer = vec![];
84        match self {
85            Self::Bounded(r) => r.recv_many(&mut buffer, limit).await,
86            Self::Unbounded(r) => r.recv_many(&mut buffer, limit).await,
87        };
88
89        buffer
90    }
91
92    pub fn close(&mut self) {
93        match self {
94            Self::Bounded(r) => r.close(),
95            Self::Unbounded(r) => r.close(),
96        }
97    }
98
99    pub fn is_empty(&self) -> bool {
100        match self {
101            Self::Bounded(r) => r.is_empty(),
102            Self::Unbounded(r) => r.is_empty(),
103        }
104    }
105
106    pub fn len(&self) -> usize {
107        match self {
108            Self::Bounded(r) => r.len(),
109            Self::Unbounded(r) => r.len(),
110        }
111    }
112}
113
114struct Queue {
115    sender: Mutex<Option<Arc<SenderWrapper>>>,
116    receiver: Arc<Mutex<ReceiverWrapper>>,
117}
118
119struct QueueHandle(Arc<Queue>);
120
121impl QueueHandle {
122    async fn get_sender(&self) -> Option<Arc<SenderWrapper>> {
123        let sender = self.0.sender.lock().await;
124        sender.clone()
125    }
126}
127
128impl LuaUserData for QueueHandle {
129    fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
130        methods.add_async_method(
131            "send",
132            move |_lua, this: LuaUserDataRef<QueueHandle>, value: CacheValue| async move {
133                this.get_sender()
134                    .await
135                    .ok_or_else(|| LuaError::external("cannot send to closed channel"))?
136                    .send(value)
137                    .await
138                    .map_err(any_err)?;
139                Ok(())
140            },
141        );
142
143        methods.add_async_method(
144            "send_timeout",
145            move |_lua,
146                  this: LuaUserDataRef<QueueHandle>,
147                  (value, timeout_seconds): (CacheValue, f32)| async move {
148                this.get_sender()
149                    .await
150                    .ok_or_else(|| LuaError::external("cannot send to closed channel"))?
151                    .send_timeout(value, Duration::from_secs_f32(timeout_seconds))
152                    .await
153                    .map_err(any_err)?;
154                Ok(())
155            },
156        );
157
158        methods.add_async_method(
159            "try_send",
160            move |_lua, this: LuaUserDataRef<QueueHandle>, value: CacheValue| async move {
161                match this
162                    .get_sender()
163                    .await
164                    .ok_or_else(|| LuaError::external("cannot send to closed channel"))?
165                    .try_send(value)
166                {
167                    Ok(()) => Ok(true),
168                    Err(_) => Ok(false),
169                }
170            },
171        );
172
173        methods.add_async_method(
174            "close",
175            move |_lua, this: LuaUserDataRef<QueueHandle>, ()| async move {
176                let mut sender = this.get_sender().await;
177                sender.take();
178
179                if let Ok(mut rx) = this.0.receiver.try_lock() {
180                    rx.close();
181                }
182                Ok(())
183            },
184        );
185
186        methods.add_async_method(
187            "is_closed",
188            move |_lua, this: LuaUserDataRef<QueueHandle>, ()| async move {
189                Ok(match this.get_sender().await {
190                    None => true,
191                    Some(sender) => sender.is_closed(),
192                })
193            },
194        );
195
196        methods.add_async_method(
197            "is_empty",
198            move |_lua, this: LuaUserDataRef<QueueHandle>, ()| async move {
199                let rx = this.0.receiver.try_lock().map_err(any_err)?;
200                Ok(rx.is_empty())
201            },
202        );
203
204        methods.add_async_method(
205            "len",
206            move |_lua, this: LuaUserDataRef<QueueHandle>, ()| async move {
207                let rx = this.0.receiver.try_lock().map_err(any_err)?;
208                Ok(rx.len())
209            },
210        );
211
212        methods.add_async_method(
213            "recv",
214            move |_lua, this: LuaUserDataRef<QueueHandle>, ()| async move {
215                let mut rx = this.0.receiver.try_lock().map_err(any_err)?;
216                Ok(rx.recv().await)
217            },
218        );
219
220        methods.add_async_method(
221            "try_recv",
222            move |_lua, this: LuaUserDataRef<QueueHandle>, ()| async move {
223                let mut rx = this.0.receiver.try_lock().map_err(any_err)?;
224                Ok(rx.try_recv())
225            },
226        );
227
228        methods.add_async_method(
229            "recv_many",
230            move |_lua, this: LuaUserDataRef<QueueHandle>, limit: usize| async move {
231                let mut rx = this.0.receiver.try_lock().map_err(any_err)?;
232                Ok(rx.recv_many(limit).await)
233            },
234        );
235    }
236}
237
238static QUEUES: LazyLock<DashMap<String, Arc<Queue>>> = LazyLock::new(DashMap::new);
239
240impl Queue {
241    pub fn define_unbounded(name: &str) -> anyhow::Result<QueueHandle> {
242        let queue = QUEUES.entry(name.to_string()).or_insert_with(|| {
243            let (sender, receiver) = unbounded_channel();
244            Arc::new(Queue {
245                sender: Mutex::new(Some(Arc::new(SenderWrapper::Unbounded(sender)))),
246                receiver: Arc::new(Mutex::new(ReceiverWrapper::Unbounded(receiver))),
247            })
248        });
249
250        Ok(QueueHandle(Arc::clone(queue.value())))
251    }
252
253    pub fn define_bounded(name: &str, buffer: usize) -> anyhow::Result<QueueHandle> {
254        let queue = QUEUES.entry(name.to_string()).or_insert_with(|| {
255            let (sender, receiver) = channel(buffer);
256            Arc::new(Queue {
257                sender: Mutex::new(Some(Arc::new(SenderWrapper::Bounded(sender)))),
258                receiver: Arc::new(Mutex::new(ReceiverWrapper::Bounded(receiver))),
259            })
260        });
261
262        Ok(QueueHandle(Arc::clone(queue.value())))
263    }
264}
265
266pub fn register(lua: &Lua) -> anyhow::Result<()> {
267    let kumo_mpsc = get_or_create_sub_module(lua, "mpsc")?;
268
269    kumo_mpsc.set(
270        "define",
271        lua.create_function(|_lua, (name, buffer): (String, Option<usize>)| {
272            match buffer {
273                Some(buffer) => Queue::define_bounded(&name, buffer),
274                None => Queue::define_unbounded(&name),
275            }
276            .map_err(any_err)
277        })?,
278    )?;
279
280    Ok(())
281}