mod_counter_series/
lib.rs

1use config::{SerdeWrappedValue, get_or_create_sub_module};
2use dashmap::DashMap;
3use dashmap::mapref::entry::Entry;
4use kumo_counter_series::{CounterSeries, CounterSeriesConfig};
5use mlua::{Lua, UserData, UserDataMethods};
6use parking_lot::Mutex;
7use serde::Deserialize;
8use std::sync::{Arc, LazyLock};
9use std::time::Duration;
10
11type LuaDuration = SerdeWrappedValue<duration_serde::Wrap<Duration>>;
12
/// Parameters accepted by the Lua `counter_series.define` function.
///
/// Unknown keys are rejected (`deny_unknown_fields`).
#[derive(Deserialize)]
#[serde(deny_unknown_fields)]
struct DefineParams {
    /// Key under which the series is cached; a later `define` with the same
    /// name and shape returns the already-cached series.
    name: String,
    /// Number of buckets; `define` rejects 0.
    num_buckets: u8,
    /// Width of each bucket, deserialized via `duration_serde`. `define`
    /// rounds it up to whole seconds and rejects a zero result.
    #[serde(with = "duration_serde")]
    bucket_size: Duration,
    /// Optional starting value. Only applied when `define` builds a new
    /// series; ignored when an existing series of the same shape is reused.
    #[serde(default)]
    initial_value: Option<u64>,
}
23
24struct CachedSeries {
25    num_buckets: u8,
26    bucket_size_seconds: u64,
27    series: Arc<Mutex<CounterSeries>>,
28}
29
30static CACHE: LazyLock<DashMap<String, CachedSeries>> = LazyLock::new(DashMap::new);
31
/// Lua userdata handle for a counter series. Repeated `define` calls that
/// reuse a cached series receive clones of the same `Arc`, so all such handles
/// read and mutate the same underlying data.
struct LuaCounterSeries {
    series: Arc<Mutex<CounterSeries>>,
}
35
36impl UserData for LuaCounterSeries {
37    fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
38        methods.add_method("increment", |_lua, this, inc: u64| {
39            this.series.lock().increment(inc);
40            Ok(())
41        });
42        methods.add_method("delta", |_lua, this, delta: i64| {
43            this.series.lock().delta(delta);
44            Ok(())
45        });
46        methods.add_method("observe", |_lua, this, value: u64| {
47            this.series.lock().observe(value);
48            Ok(())
49        });
50        methods.add_method("sum", |_lua, this, ()| Ok(this.series.lock().sum()));
51        methods.add_method("sum_over", |_lua, this, duration: LuaDuration| {
52            Ok(this.series.lock().sum_over((*duration).into_inner()))
53        });
54    }
55}
56
/// Rounds a `Duration` up to the nearest whole second.
///
/// A duration with any non-zero sub-second component is rounded up to the
/// next full second. The result saturates at `u64::MAX`, so even
/// `Duration::MAX` (which has a fractional part) is handled without overflow.
fn round_up_to_seconds(d: Duration) -> u64 {
    let whole = d.as_secs();
    if d.subsec_nanos() == 0 {
        whole
    } else {
        // A plain `+ 1` overflows when `whole == u64::MAX`: it panics in debug
        // builds and wraps to 0 in release, and a 0 is then rejected downstream
        // as a zero bucket size. Saturate instead.
        whole.saturating_add(1)
    }
}
67
68fn make_config(num_buckets: u8, bucket_size_seconds: u64) -> mlua::Result<CounterSeriesConfig> {
69    if num_buckets == 0 {
70        return Err(mlua::Error::external("num_buckets must be >= 1"));
71    }
72    if bucket_size_seconds == 0 {
73        return Err(mlua::Error::external("bucket_size must be >= 1 second"));
74    }
75
76    Ok(CounterSeriesConfig {
77        num_buckets,
78        bucket_size: bucket_size_seconds,
79    })
80}
81
82fn build_cached(
83    num_buckets: u8,
84    bucket_size_seconds: u64,
85    initial_value: Option<u64>,
86) -> mlua::Result<CachedSeries> {
87    let config = make_config(num_buckets, bucket_size_seconds)?;
88    let series = match initial_value {
89        Some(value) => CounterSeries::with_initial_value(config, value),
90        None => CounterSeries::with_config(config),
91    };
92    Ok(CachedSeries {
93        num_buckets,
94        bucket_size_seconds,
95        series: Arc::new(Mutex::new(series)),
96    })
97}
98
99pub fn register(lua: &Lua) -> anyhow::Result<()> {
100    let module = get_or_create_sub_module(lua, "counter_series")?;
101
102    module.set(
103        "define",
104        lua.create_function(|_lua, params: SerdeWrappedValue<DefineParams>| {
105            let DefineParams {
106                name,
107                num_buckets,
108                bucket_size,
109                initial_value,
110            } = params.0;
111            let bucket_size_seconds = round_up_to_seconds(bucket_size);
112
113            let series = match CACHE.entry(name) {
114                Entry::Occupied(mut entry) => {
115                    let cached = entry.get();
116                    if cached.num_buckets == num_buckets
117                        && cached.bucket_size_seconds == bucket_size_seconds
118                    {
119                        // Same shape: preserve existing values, ignore initial_value.
120                        Arc::clone(&cached.series)
121                    } else {
122                        // Shape changed: replace the cached series with a fresh one.
123                        let fresh = build_cached(num_buckets, bucket_size_seconds, initial_value)?;
124                        let series = Arc::clone(&fresh.series);
125                        entry.insert(fresh);
126                        series
127                    }
128                }
129                Entry::Vacant(entry) => {
130                    let fresh = build_cached(num_buckets, bucket_size_seconds, initial_value)?;
131                    let series = Arc::clone(&fresh.series);
132                    entry.insert(fresh);
133                    series
134                }
135            };
136
137            Ok(LuaCounterSeries { series })
138        })?,
139    )?;
140
141    Ok(())
142}