mod_counter_series/
lib.rs

1use config::{SerdeWrappedValue, get_or_create_sub_module};
2use dashmap::DashMap;
3use dashmap::mapref::entry::Entry;
4use kumo_counter_series::{CounterSeries, CounterSeriesConfig};
5use mlua::{Lua, UserData, UserDataMethods};
6use parking_lot::Mutex;
7use serde::Deserialize;
8use std::sync::{Arc, LazyLock};
9use std::time::Duration;
10
11type LuaDuration = SerdeWrappedValue<duration_serde::Wrap<Duration>>;
12
/// Parameters accepted by `counter_series.define`.
///
/// Deserialized from the Lua argument; unknown keys are rejected because of
/// `deny_unknown_fields`.
#[derive(Deserialize)]
#[serde(deny_unknown_fields)]
struct DefineParams {
    /// Cache key for the series; defining the same name again either reuses
    /// or replaces the existing cached series.
    name: String,
    /// Number of buckets; zero is rejected later by `make_config`.
    num_buckets: u8,
    /// Width of each bucket, deserialized via `duration_serde`. It is rounded
    /// up to whole seconds before use: a non-zero sub-second value becomes
    /// 1 second, while a zero duration is rejected.
    #[serde(with = "duration_serde")]
    bucket_size: Duration,
    /// Optional starting value; only used when a new series is built, and
    /// ignored when an existing series with the same shape is reused.
    #[serde(default)]
    initial_value: Option<u64>,
}
23
24struct CachedSeries {
25    num_buckets: u8,
26    bucket_size_seconds: u64,
27    series: Arc<Mutex<CounterSeries>>,
28}
29
30static CACHE: LazyLock<DashMap<String, CachedSeries>> = LazyLock::new(DashMap::new);
31
32struct LuaCounterSeries {
33    series: Arc<Mutex<CounterSeries>>,
34}
35
36impl UserData for LuaCounterSeries {
37    fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
38        methods.add_method("increment", |_lua, this, inc: u64| {
39            this.series.lock().increment(inc);
40            Ok(())
41        });
42        methods.add_method("delta", |_lua, this, delta: i64| {
43            this.series.lock().delta(delta);
44            Ok(())
45        });
46        methods.add_method("observe", |_lua, this, value: u64| {
47            this.series.lock().observe(value);
48            Ok(())
49        });
50        methods.add_method("sum", |_lua, this, ()| Ok(this.series.lock().sum()));
51        methods.add_method("sum_over", |_lua, this, duration: LuaDuration| {
52            Ok(this.series.lock().sum_over((*duration).into_inner()))
53        });
54    }
55}
56
/// Rounds a `Duration` up to a whole number of seconds.
///
/// A duration with any non-zero sub-second component is rounded up to the
/// next full second (`1.5s` -> `2`, `1ns` -> `1`); whole-second durations,
/// including zero, are returned unchanged.
///
/// The result saturates at `u64::MAX` rather than overflowing. When
/// `as_secs()` is already `u64::MAX` and there is a fractional part (e.g.
/// `Duration::MAX`), a plain `+ 1` would panic in debug builds and wrap to
/// `0` in release builds. The caller would then reject that `0` with a
/// misleading "bucket_size must be >= 1 second" error.
fn round_up_to_seconds(d: Duration) -> u64 {
    let secs = d.as_secs();
    if d.subsec_nanos() > 0 {
        secs.saturating_add(1)
    } else {
        secs
    }
}
63
64fn make_config(num_buckets: u8, bucket_size_seconds: u64) -> mlua::Result<CounterSeriesConfig> {
65    if num_buckets == 0 {
66        return Err(mlua::Error::external("num_buckets must be >= 1"));
67    }
68    if bucket_size_seconds == 0 {
69        return Err(mlua::Error::external("bucket_size must be >= 1 second"));
70    }
71
72    Ok(CounterSeriesConfig {
73        num_buckets,
74        bucket_size: bucket_size_seconds,
75    })
76}
77
78fn build_cached(
79    num_buckets: u8,
80    bucket_size_seconds: u64,
81    initial_value: Option<u64>,
82) -> mlua::Result<CachedSeries> {
83    let config = make_config(num_buckets, bucket_size_seconds)?;
84    let series = match initial_value {
85        Some(value) => CounterSeries::with_initial_value(config, value),
86        None => CounterSeries::with_config(config),
87    };
88    Ok(CachedSeries {
89        num_buckets,
90        bucket_size_seconds,
91        series: Arc::new(Mutex::new(series)),
92    })
93}
94
95pub fn register(lua: &Lua) -> anyhow::Result<()> {
96    let module = get_or_create_sub_module(lua, "counter_series")?;
97
98    module.set(
99        "define",
100        lua.create_function(|_lua, params: SerdeWrappedValue<DefineParams>| {
101            let DefineParams {
102                name,
103                num_buckets,
104                bucket_size,
105                initial_value,
106            } = params.0;
107            let bucket_size_seconds = round_up_to_seconds(bucket_size);
108
109            let series = match CACHE.entry(name) {
110                Entry::Occupied(mut entry) => {
111                    let cached = entry.get();
112                    if cached.num_buckets == num_buckets
113                        && cached.bucket_size_seconds == bucket_size_seconds
114                    {
115                        // Same shape: preserve existing values, ignore initial_value.
116                        Arc::clone(&cached.series)
117                    } else {
118                        // Shape changed: replace the cached series with a fresh one.
119                        let fresh = build_cached(num_buckets, bucket_size_seconds, initial_value)?;
120                        let series = Arc::clone(&fresh.series);
121                        entry.insert(fresh);
122                        series
123                    }
124                }
125                Entry::Vacant(entry) => {
126                    let fresh = build_cached(num_buckets, bucket_size_seconds, initial_value)?;
127                    let series = Arc::clone(&fresh.series);
128                    entry.insert(fresh);
129                    series
130                }
131            };
132
133            Ok(LuaCounterSeries { series })
134        })?,
135    )?;
136
137    Ok(())
138}