mod_kafka/lib.rs

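//! Lua bindings for producing records to Apache Kafka via rdkafka.
//!
//! A minimal usage sketch from the Lua side; the module lookup, broker
//! address, topic and timeout string are illustrative assumptions, not
//! taken from a real configuration:
//!
//! ```lua
//! local kafka = kumo.kafka -- assuming the sub-module is exposed via kumo
//! local producer = kafka.build_producer {
//!     ["bootstrap.servers"] = "localhost:9092",
//! }
//! local partition, offset = producer:send {
//!     topic = "my.topic",
//!     payload = "hello",
//!     timeout = "1 minute",
//! }
//! producer:close()
//! ```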
use config::{any_err, get_or_create_sub_module};
use futures::stream::FuturesOrdered;
use futures::StreamExt;
use mlua::prelude::LuaUserData;
use mlua::{Lua, LuaSerdeExt, UserDataMethods, Value};
use rdkafka::message::{Header, OwnedHeaders};
use rdkafka::producer::{FutureProducer, FutureRecord};
use rdkafka::util::Timeout;
use rdkafka::ClientConfig;
use serde::Deserialize;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::Duration;

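/// A producer handle exposed to Lua. The producer lives behind
/// `Mutex<Option<..>>` so that `close` can drop the underlying rdkafka
/// client even while clones of this handle are still alive; any later
/// `send` on a closed handle fails with a "client was closed" error.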
#[derive(Clone)]
struct Producer {
    producer: Arc<Mutex<Option<Arc<FutureProducer>>>>,
}

impl Producer {
    /// Returns the live producer, or an error if `close` has already
    /// been called on this handle.
    fn get_producer(&self) -> mlua::Result<Arc<FutureProducer>> {
        self.producer
            .lock()
            .unwrap()
            .as_ref()
            .map(Arc::clone)
            .ok_or_else(|| mlua::Error::external("client was closed"))
    }
}

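/// A single record to publish, deserialized from a Lua table.
/// As a sketch, a table shaped like the following (all values here are
/// illustrative) maps onto this struct:
///
/// ```lua
/// {
///     topic = "my.topic",          -- required
///     partition = 0,               -- optional
///     payload = "message body",    -- optional
///     key = "routing-key",         -- optional
///     headers = { trace = "abc" }, -- optional
///     timeout = "30 seconds",      -- optional; defaults to 1 minute
/// }
/// ```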
#[derive(Deserialize, Debug)]
struct Record {
    /// Required destination topic
    topic: String,
    /// Optional destination partition
    #[serde(default)]
    partition: Option<i32>,
    /// Optional payload
    #[serde(default)]
    #[serde(with = "serde_bytes")]
    payload: Option<Vec<u8>>,
    /// Optional key
    #[serde(default)]
    key: Option<String>,

    /// Optional headers
    #[serde(default)]
    headers: HashMap<String, String>,

    /// Optional timeout. If no timeout is provided, assume 1 minute.
    /// The timeout is how long to keep retrying to submit to Kafka
    /// before giving up on this attempt.
    /// Note that the underlying library supports retrying forever,
    /// but in kumomta we don't allow that; we can retry later without
    /// keeping the system occupied for an indefinite time.
    #[serde(default)]
    #[serde(with = "duration_serde")]
    timeout: Option<Duration>,
}

impl LuaUserData for Producer {
    fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
        methods.add_async_method("send", |lua, this, value: Value| async move {
            let record: Record = lua.from_value(value)?;

            // Translate the optional header map into rdkafka's OwnedHeaders.
            let headers = if record.headers.is_empty() {
                None
            } else {
                let mut headers = OwnedHeaders::new();
                for (key, v) in &record.headers {
                    headers = headers.insert(Header {
                        key,
                        value: Some(v),
                    });
                }
                Some(headers)
            };

            let future_record = FutureRecord {
                topic: &record.topic,
                partition: record.partition,
                payload: record.payload.as_ref(),
                key: record.key.as_ref(),
                headers,
                timestamp: None,
            };

            // On success, report the partition and offset at which the
            // record was placed back to the Lua caller.
            let (partition, offset) = this
                .get_producer()?
                .send(
                    future_record,
                    Timeout::After(record.timeout.unwrap_or(Duration::from_secs(60))),
                )
                .await
                .map_err(|(code, _msg)| any_err(code))?;

            Ok((partition, offset))
        });

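        // send_batch submits every record concurrently, but FuturesOrdered
        // keeps completion results in submission order. It returns two
        // parallel Lua arrays: the 1-based indexes of the records that
        // failed, and the corresponding error strings; both are empty when
        // every record was accepted.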
        methods.add_async_method("send_batch", |lua, this, values: Vec<Value>| async move {
            let mut tasks = FuturesOrdered::new();
            let producer = this.get_producer()?;

            for value in values {
                let record: Record = lua.from_value(value)?;

                let headers = if record.headers.is_empty() {
                    None
                } else {
                    let mut headers = OwnedHeaders::new();
                    for (key, v) in &record.headers {
                        headers = headers.insert(Header {
                            key,
                            value: Some(v),
                        });
                    }
                    Some(headers)
                };

                let producer = producer.clone();

                tasks.push_back(tokio::spawn(async move {
                    producer
                        .send(
                            FutureRecord {
                                topic: &record.topic,
                                partition: record.partition,
                                payload: record.payload.as_ref(),
                                key: record.key.as_ref(),
                                headers,
                                timestamp: None,
                            },
                            Timeout::After(record.timeout.unwrap_or(Duration::from_secs(60))),
                        )
                        .await
                }));
            }

            let failed_indexes = lua.create_table()?;
            let errors = lua.create_table()?;
            // Lua arrays are 1-based, so the first record is index 1.
            let mut index = 1;

            while let Some(result) = tasks.next().await {
                match result {
                    // The record was accepted by the broker.
                    Ok(Ok(_)) => {}
                    // The producer reported a kafka error for this record.
                    Ok(Err((error, _msg))) => {
                        failed_indexes.push(index)?;
                        errors.push(format!("{error:#}"))?;
                    }
                    // The spawned task itself failed (eg: panicked).
                    Err(error) => {
                        failed_indexes.push(index)?;
                        errors.push(format!("{error:#}"))?;
                    }
                }
                index += 1;
            }
            Ok((failed_indexes, errors))
        });

        methods.add_method("close", |_lua, this, _: ()| {
            this.producer.lock().unwrap().take();
            Ok(())
        });
    }
}

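/// Registers the `kafka` Lua sub-module, exposing `build_producer`.
/// The string map passed to `build_producer` is applied verbatim to
/// rdkafka's `ClientConfig`, so any librdkafka configuration property
/// (such as `bootstrap.servers`) may be supplied.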
pub fn register(lua: &Lua) -> anyhow::Result<()> {
    let kafka_mod = get_or_create_sub_module(lua, "kafka")?;

    kafka_mod.set(
        "build_producer",
        lua.create_async_function(|_, config: HashMap<String, String>| async move {
            let mut builder = ClientConfig::new();
            for (k, v) in config {
                builder.set(k, v);
            }

            let producer = builder.create().map_err(any_err)?;

            Ok(Producer {
                producer: Arc::new(Mutex::new(Some(Arc::new(producer)))),
            })
        })?,
    )?;

    Ok(())
}