// mod_kafka/lib.rs

1use config::{any_err, get_or_create_sub_module};
2use futures::stream::FuturesOrdered;
3use futures::StreamExt;
4use mlua::prelude::LuaUserData;
5use mlua::{Lua, LuaSerdeExt, UserDataMethods, Value};
6use rdkafka::message::{Header, OwnedHeaders};
7use rdkafka::producer::future_producer::Delivery;
8use rdkafka::producer::{FutureProducer, FutureRecord};
9use rdkafka::util::Timeout;
10use rdkafka::ClientConfig;
11use serde::Deserialize;
12use std::collections::HashMap;
13use std::sync::{Arc, Mutex};
14use std::time::Duration;
15
/// Lua userdata wrapper around an rdkafka `FutureProducer`.
///
/// The producer is held behind `Option` so that `close` can drop it
/// while other cloned handles still exist; once it is `None`, any
/// attempt to use the handle reports "client was closed".
/// Cloning a `Producer` clones the outer `Arc`, so all clones share
/// the same underlying slot.
#[derive(Clone)]
struct Producer {
    // Mutex guards the take-on-close; inner Arc lets callers hold the
    // producer across awaits without keeping the lock held.
    producer: Arc<Mutex<Option<Arc<FutureProducer>>>>,
}
20
21impl Producer {
22    fn get_producer(&self) -> mlua::Result<Arc<FutureProducer>> {
23        self.producer
24            .lock()
25            .unwrap()
26            .as_ref()
27            .map(Arc::clone)
28            .ok_or_else(|| mlua::Error::external("client was closed"))
29    }
30}
31
32#[derive(Deserialize, Debug)]
33struct Record {
34    /// Required destination topic
35    topic: String,
36    /// Optional destination partition
37    #[serde(default)]
38    partition: Option<i32>,
39    /// Optional payload
40    #[serde(with = "serde_bytes")]
41    payload: Option<Vec<u8>>,
42    /// Optional key
43    #[serde(default)]
44    key: Option<String>,
45
46    /// Optional headers
47    #[serde(default)]
48    headers: HashMap<String, String>,
49
50    /// Optional timeout. If no timeout is provided, assume 1 minute.
51    /// The timeout is how long to keep retrying to submit to kafka
52    /// before giving up on this attempt.
53    /// Note that the underlying library supports retrying forever,
54    /// but in kumomta we don't allow that; we can retry later without
55    /// keeping the system occupied for an indefinite time.
56    #[serde(default)]
57    #[serde(with = "duration_serde")]
58    timeout: Option<Duration>,
59}
60
61impl LuaUserData for Producer {
62    fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
63        methods.add_async_method("send", |lua, this, value: Value| async move {
64            let record: Record = lua.from_value(value)?;
65
66            let headers = if record.headers.is_empty() {
67                None
68            } else {
69                let mut headers = OwnedHeaders::new();
70                for (key, v) in &record.headers {
71                    headers = headers.insert(Header {
72                        key,
73                        value: Some(v),
74                    });
75                }
76                Some(headers)
77            };
78
79            let future_record = FutureRecord {
80                topic: &record.topic,
81                partition: record.partition,
82                payload: record.payload.as_ref(),
83                key: record.key.as_ref(),
84                headers,
85                timestamp: None,
86            };
87
88            let Delivery {
89                partition,
90                offset,
91                timestamp: _,
92            } = this
93                .get_producer()?
94                .send(
95                    future_record,
96                    Timeout::After(record.timeout.unwrap_or(Duration::from_secs(60))),
97                )
98                .await
99                .map_err(|(code, _msg)| any_err(code))?;
100
101            Ok((partition, offset))
102        });
103
104        methods.add_async_method("send_batch", |lua, this, values: Vec<Value>| async move {
105            let mut tasks = FuturesOrdered::new();
106            let producer = this.get_producer()?;
107
108            for value in values {
109                let record: Record = lua.from_value(value)?;
110
111                let headers = if record.headers.is_empty() {
112                    None
113                } else {
114                    let mut headers = OwnedHeaders::new();
115                    for (key, v) in &record.headers {
116                        headers = headers.insert(Header {
117                            key,
118                            value: Some(v),
119                        });
120                    }
121                    Some(headers)
122                };
123
124                let producer = producer.clone();
125
126                tasks.push_back(tokio::spawn(async move {
127                    producer
128                        .send(
129                            FutureRecord {
130                                topic: &record.topic,
131                                partition: record.partition,
132                                payload: record.payload.as_ref(),
133                                key: record.key.as_ref(),
134                                headers,
135                                timestamp: None,
136                            },
137                            Timeout::After(record.timeout.unwrap_or(Duration::from_secs(60))),
138                        )
139                        .await
140                }));
141            }
142
143            let failed_indexes = lua.create_table()?;
144            let errors = lua.create_table()?;
145            let mut index = 1;
146
147            while let Some(result) = tasks.next().await {
148                match result {
149                    Ok(Ok(_)) => {}
150                    Ok(Err((error, _msg))) => {
151                        failed_indexes.push(index)?;
152                        errors.push(format!("{error:#}"))?;
153                    }
154                    Err(error) => {
155                        failed_indexes.push(index)?;
156                        errors.push(format!("{error:#}"))?;
157                    }
158                }
159                index += 1;
160            }
161            Ok((failed_indexes, errors))
162        });
163
164        methods.add_method("close", |_lua, this, _: ()| {
165            this.producer.lock().unwrap().take();
166            Ok(())
167        });
168    }
169}
170
171pub fn register(lua: &Lua) -> anyhow::Result<()> {
172    let kafka_mod = get_or_create_sub_module(lua, "kafka")?;
173
174    kafka_mod.set(
175        "build_producer",
176        lua.create_async_function(|_, config: HashMap<String, String>| async move {
177            let mut builder = ClientConfig::new();
178            for (k, v) in config {
179                builder.set(k, v);
180            }
181
182            let producer = builder.create().map_err(any_err)?;
183
184            Ok(Producer {
185                producer: Arc::new(Mutex::new(Some(Arc::new(producer)))),
186            })
187        })?,
188    )?;
189
190    Ok(())
191}