mod_mpsc/
lib.rs

1use config::{any_err, get_or_create_sub_module};
2use dashmap::DashMap;
3use mlua::prelude::*;
4use mlua::{Lua, UserDataMethods};
5use mod_memoize::CacheValue;
6use std::sync::{Arc, LazyLock};
7use tokio::sync::Mutex;
8use tokio::sync::mpsc::{
9    Receiver, Sender, UnboundedReceiver, UnboundedSender, channel, unbounded_channel,
10};
11use tokio::time::Duration;
12
13enum SenderWrapper {
14    Bounded(Sender<CacheValue>),
15    Unbounded(UnboundedSender<CacheValue>),
16}
17
18impl SenderWrapper {
19    pub async fn send(&self, value: CacheValue) -> anyhow::Result<()> {
20        match self {
21            Self::Bounded(s) => {
22                s.send(value).await?;
23            }
24            Self::Unbounded(s) => {
25                s.send(value)?;
26            }
27        };
28        Ok(())
29    }
30
31    pub fn try_send(&self, value: CacheValue) -> anyhow::Result<()> {
32        match self {
33            Self::Bounded(s) => {
34                s.try_send(value)?;
35            }
36            Self::Unbounded(s) => {
37                s.send(value)?;
38            }
39        };
40        Ok(())
41    }
42
43    pub async fn send_timeout(&self, value: CacheValue, duration: Duration) -> anyhow::Result<()> {
44        match self {
45            Self::Bounded(s) => {
46                s.send_timeout(value, duration).await?;
47            }
48            Self::Unbounded(s) => {
49                s.send(value)?;
50            }
51        };
52        Ok(())
53    }
54    pub fn is_closed(&self) -> bool {
55        match self {
56            Self::Bounded(r) => r.is_closed(),
57            Self::Unbounded(r) => r.is_closed(),
58        }
59    }
60}
61
62enum ReceiverWrapper {
63    Bounded(Receiver<CacheValue>),
64    Unbounded(UnboundedReceiver<CacheValue>),
65}
66
67impl ReceiverWrapper {
68    pub async fn recv(&mut self) -> Option<CacheValue> {
69        match self {
70            Self::Bounded(r) => r.recv().await,
71            Self::Unbounded(r) => r.recv().await,
72        }
73    }
74
75    pub fn try_recv(&mut self) -> Option<CacheValue> {
76        match self {
77            Self::Bounded(r) => r.try_recv().ok(),
78            Self::Unbounded(r) => r.try_recv().ok(),
79        }
80    }
81
82    pub async fn recv_many(&mut self, limit: usize) -> Vec<CacheValue> {
83        let mut buffer = vec![];
84        match self {
85            Self::Bounded(r) => r.recv_many(&mut buffer, limit).await,
86            Self::Unbounded(r) => r.recv_many(&mut buffer, limit).await,
87        };
88
89        buffer
90    }
91
92    pub fn close(&mut self) {
93        match self {
94            Self::Bounded(r) => r.close(),
95            Self::Unbounded(r) => r.close(),
96        }
97    }
98
99    pub fn is_empty(&self) -> bool {
100        match self {
101            Self::Bounded(r) => r.is_empty(),
102            Self::Unbounded(r) => r.is_empty(),
103        }
104    }
105
106    pub fn len(&self) -> usize {
107        match self {
108            Self::Bounded(r) => r.len(),
109            Self::Unbounded(r) => r.len(),
110        }
111    }
112}
113
114struct Queue {
115    sender: Arc<SenderWrapper>,
116    receiver: Arc<Mutex<ReceiverWrapper>>,
117}
118
119struct QueueHandle(Arc<Queue>);
120
121impl LuaUserData for QueueHandle {
122    fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
123        methods.add_async_method(
124            "send",
125            move |_lua, this: LuaUserDataRef<QueueHandle>, value: CacheValue| async move {
126                this.0.sender.send(value).await.map_err(any_err)?;
127                Ok(())
128            },
129        );
130
131        methods.add_async_method(
132            "send_timeout",
133            move |_lua,
134                  this: LuaUserDataRef<QueueHandle>,
135                  (value, timeout_seconds): (CacheValue, f32)| async move {
136                this.0
137                    .sender
138                    .send_timeout(value, Duration::from_secs_f32(timeout_seconds))
139                    .await
140                    .map_err(any_err)?;
141                Ok(())
142            },
143        );
144
145        methods.add_method(
146            "try_send",
147            move |_lua, this: &QueueHandle, value: CacheValue| match this.0.sender.try_send(value) {
148                Ok(()) => Ok(true),
149                Err(_) => Ok(false),
150            },
151        );
152
153        methods.add_async_method(
154            "close",
155            move |_lua, this: LuaUserDataRef<QueueHandle>, ()| async move {
156                let mut rx = this.0.receiver.lock().await;
157                Ok(rx.close())
158            },
159        );
160
161        methods.add_method("is_closed", move |_lua, this: &QueueHandle, ()| {
162            Ok(this.0.sender.is_closed())
163        });
164
165        methods.add_async_method(
166            "is_empty",
167            move |_lua, this: LuaUserDataRef<QueueHandle>, ()| async move {
168                let rx = this.0.receiver.lock().await;
169                Ok(rx.is_empty())
170            },
171        );
172
173        methods.add_async_method(
174            "len",
175            move |_lua, this: LuaUserDataRef<QueueHandle>, ()| async move {
176                let rx = this.0.receiver.lock().await;
177                Ok(rx.len())
178            },
179        );
180
181        methods.add_async_method(
182            "recv",
183            move |_lua, this: LuaUserDataRef<QueueHandle>, ()| async move {
184                let mut rx = this.0.receiver.lock().await;
185                Ok(rx.recv().await)
186            },
187        );
188
189        methods.add_async_method(
190            "try_recv",
191            move |_lua, this: LuaUserDataRef<QueueHandle>, ()| async move {
192                let mut rx = this.0.receiver.lock().await;
193                Ok(rx.try_recv())
194            },
195        );
196
197        methods.add_async_method(
198            "recv_many",
199            move |_lua, this: LuaUserDataRef<QueueHandle>, limit: usize| async move {
200                let mut rx = this.0.receiver.lock().await;
201                Ok(rx.recv_many(limit).await)
202            },
203        );
204    }
205}
206
207static QUEUES: LazyLock<DashMap<String, Arc<Queue>>> = LazyLock::new(DashMap::new);
208
209impl Queue {
210    pub fn define_unbounded(name: &str) -> anyhow::Result<QueueHandle> {
211        let queue = QUEUES.entry(name.to_string()).or_insert_with(|| {
212            let (sender, receiver) = unbounded_channel();
213            Arc::new(Queue {
214                sender: Arc::new(SenderWrapper::Unbounded(sender)),
215                receiver: Arc::new(Mutex::new(ReceiverWrapper::Unbounded(receiver))),
216            })
217        });
218
219        Ok(QueueHandle(Arc::clone(queue.value())))
220    }
221
222    pub fn define_bounded(name: &str, buffer: usize) -> anyhow::Result<QueueHandle> {
223        let queue = QUEUES.entry(name.to_string()).or_insert_with(|| {
224            let (sender, receiver) = channel(buffer);
225            Arc::new(Queue {
226                sender: Arc::new(SenderWrapper::Bounded(sender)),
227                receiver: Arc::new(Mutex::new(ReceiverWrapper::Bounded(receiver))),
228            })
229        });
230
231        Ok(QueueHandle(Arc::clone(queue.value())))
232    }
233}
234
235pub fn register(lua: &Lua) -> anyhow::Result<()> {
236    let kumo_mpsc = get_or_create_sub_module(lua, "mpsc")?;
237
238    kumo_mpsc.set(
239        "define",
240        lua.create_function(|_lua, (name, buffer): (String, Option<usize>)| {
241            match buffer {
242                Some(buffer) => Queue::define_bounded(&name, buffer),
243                None => Queue::define_unbounded(&name),
244            }
245            .map_err(any_err)
246        })?,
247    )?;
248
249    Ok(())
250}