Skip to main content

rust_data_processing/processing/
reduce.rs

1//! Reduction operations for [`crate::types::DataSet`].
2
3use std::collections::HashSet;
4
5use crate::types::{DataSet, DataType, Value};
6
/// Selects the divisor used for variance and standard deviation
/// (delta degrees of freedom: `ddof = 0` for population, `ddof = 1` for sample).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VarianceKind {
    /// Population statistic: the squared deviations are divided by `n`
    /// (defined whenever `n > 0`).
    Population,
    /// Sample statistic: the squared deviations are divided by `n - 1`
    /// (defined only when `n >= 2`; otherwise the result is [`None`] / null).
    Sample,
}
15
16/// Built-in reduction operations over a single column.
17#[derive(Debug, Clone, Copy, PartialEq, Eq)]
18pub enum ReduceOp {
19    /// Count all rows (including nulls).
20    Count,
21    /// Sum numeric values, ignoring nulls.
22    Sum,
23    /// Minimum numeric value, ignoring nulls.
24    Min,
25    /// Maximum numeric value, ignoring nulls.
26    Max,
27    /// Arithmetic mean of numeric values as [`Value::Float64`], ignoring nulls.
28    Mean,
29    /// Variance (Welford); null if no values, or sample with fewer than two values.
30    Variance(VarianceKind),
31    /// Standard deviation from variance; same null rules as [`ReduceOp::Variance`].
32    StdDev(VarianceKind),
33    /// \(\sum x^2\) over non-null numeric values as [`Value::Float64`].
34    SumSquares,
35    /// \(\sqrt{\sum x^2}\) over non-null numeric values as [`Value::Float64`].
36    L2Norm,
37    /// Count of distinct non-null values (returns [`Value::Int64`]).
38    CountDistinctNonNull,
39    /// Median of numeric values as [`Value::Float64`], ignoring nulls. Even count: average of the two middle values.
40    /// Non-numeric columns yield [`Value::Null`].
41    Median,
42}
43
44/// Reduce a column using a built-in [`ReduceOp`].
45///
46/// - Returns `None` if `column` does not exist in the schema.
47/// - For `Count`, always returns `Some(Value::Int64(row_count))`.
48/// - For numeric aggregates other than `Count` / `CountDistinctNonNull`, returns
49///   `Some(Value::Null)` if there are no non-null numeric values, or if the column type is not
50///   numeric (for those ops). `CountDistinctNonNull` supports [`DataType::Bool`] and
51///   [`DataType::Utf8`] as well as numeric types.
52/// - For [`ReduceOp::Median`], non-numeric columns return `Some(Value::Null)`; numeric columns
53///   return [`Value::Float64`] (even-length inputs use the mean of the two middle values).
54pub fn reduce(dataset: &DataSet, column: &str, op: ReduceOp) -> Option<Value> {
55    let idx = dataset.schema.index_of(column)?;
56
57    match op {
58        ReduceOp::Count => Some(Value::Int64(dataset.row_count() as i64)),
59        ReduceOp::CountDistinctNonNull => {
60            let field = dataset.schema.fields.get(idx)?;
61            reduce_count_distinct_non_null(dataset, idx, &field.data_type)
62        }
63        ReduceOp::Sum | ReduceOp::Min | ReduceOp::Max => match dataset.schema.fields.get(idx) {
64            Some(field) => reduce_numeric_typed(dataset, idx, field.data_type.clone(), op),
65            None => None,
66        },
67        ReduceOp::Mean
68        | ReduceOp::Variance(_)
69        | ReduceOp::StdDev(_)
70        | ReduceOp::SumSquares
71        | ReduceOp::L2Norm => match dataset.schema.fields.get(idx) {
72            Some(field) => reduce_numeric_float_stats(dataset, idx, field.data_type.clone(), op),
73            None => None,
74        },
75        ReduceOp::Median => reduce_median(dataset, idx),
76    }
77}
78
79fn reduce_median(dataset: &DataSet, idx: usize) -> Option<Value> {
80    let field = dataset.schema.fields.get(idx)?;
81    if !matches!(field.data_type, DataType::Int64 | DataType::Float64) {
82        return Some(Value::Null);
83    }
84    let is_int = matches!(field.data_type, DataType::Int64);
85    let mut xs: Vec<f64> = Vec::new();
86    for row in &dataset.rows {
87        let x = match row.get(idx) {
88            Some(Value::Null) | None => None,
89            Some(Value::Int64(v)) if is_int => Some(*v as f64),
90            Some(Value::Float64(v)) if !is_int => Some(*v),
91            _ => None,
92        };
93        if let Some(x) = x {
94            xs.push(x);
95        }
96    }
97    if xs.is_empty() {
98        return Some(Value::Null);
99    }
100    xs.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
101    let n = xs.len();
102    let med = if n % 2 == 1 {
103        xs[n / 2]
104    } else {
105        (xs[n / 2 - 1] + xs[n / 2]) / 2.0
106    };
107    Some(Value::Float64(med))
108}
109
110#[derive(Default)]
111pub(crate) struct Welford {
112    n: u64,
113    mean: f64,
114    m2: f64,
115}
116
117impl Welford {
118    pub(crate) fn observe(&mut self, x: f64) {
119        self.n += 1;
120        let delta = x - self.mean;
121        self.mean += delta / self.n as f64;
122        let delta2 = x - self.mean;
123        self.m2 += delta * delta2;
124    }
125
126    pub(crate) fn mean(&self) -> Option<f64> {
127        (self.n > 0).then_some(self.mean)
128    }
129
130    pub(crate) fn variance(&self, kind: VarianceKind) -> Option<f64> {
131        if self.n == 0 {
132            return None;
133        }
134        match kind {
135            VarianceKind::Population => Some(self.m2 / self.n as f64),
136            VarianceKind::Sample => {
137                if self.n < 2 {
138                    None
139                } else {
140                    Some(self.m2 / (self.n - 1) as f64)
141                }
142            }
143        }
144    }
145
146    pub(crate) fn observation_count(&self) -> u64 {
147        self.n
148    }
149}
150
151fn reduce_numeric_float_stats(
152    dataset: &DataSet,
153    idx: usize,
154    data_type: DataType,
155    op: ReduceOp,
156) -> Option<Value> {
157    match data_type {
158        dt @ (DataType::Int64 | DataType::Float64) => {
159            let is_int = matches!(dt, DataType::Int64);
160            let mut w = Welford::default();
161            let mut sum_squares = 0.0_f64;
162            let mut any = false;
163
164            for row in &dataset.rows {
165                let x = match row.get(idx) {
166                    Some(Value::Null) | None => None,
167                    Some(Value::Int64(v)) if is_int => Some(*v as f64),
168                    Some(Value::Float64(v)) if !is_int => Some(*v),
169                    Some(_) => None,
170                };
171                if let Some(x) = x {
172                    any = true;
173                    w.observe(x);
174                    sum_squares += x * x;
175                }
176            }
177
178            if !any {
179                return Some(Value::Null);
180            }
181
182            let out = match op {
183                ReduceOp::Mean => Value::Float64(w.mean().expect("n > 0")),
184                ReduceOp::Variance(kind) => match w.variance(kind) {
185                    Some(v) => Value::Float64(v),
186                    None => Value::Null,
187                },
188                ReduceOp::StdDev(kind) => match w.variance(kind) {
189                    Some(v) => Value::Float64(v.sqrt()),
190                    None => Value::Null,
191                },
192                ReduceOp::SumSquares => Value::Float64(sum_squares),
193                ReduceOp::L2Norm => Value::Float64(sum_squares.sqrt()),
194                _ => unreachable!("caller only dispatches float stats ops"),
195            };
196            Some(out)
197        }
198        _ => Some(Value::Null),
199    }
200}
201
202fn reduce_count_distinct_non_null(
203    dataset: &DataSet,
204    idx: usize,
205    data_type: &DataType,
206) -> Option<Value> {
207    let n = match data_type {
208        DataType::Int64 => {
209            let mut set = HashSet::new();
210            for row in &dataset.rows {
211                if let Some(Value::Int64(v)) = row.get(idx) {
212                    set.insert(*v);
213                }
214            }
215            set.len() as i64
216        }
217        DataType::Float64 => {
218            let mut set = HashSet::new();
219            for row in &dataset.rows {
220                if let Some(Value::Float64(v)) = row.get(idx) {
221                    set.insert(v.to_bits());
222                }
223            }
224            set.len() as i64
225        }
226        DataType::Bool => {
227            let mut set = HashSet::new();
228            for row in &dataset.rows {
229                if let Some(Value::Bool(v)) = row.get(idx) {
230                    set.insert(*v);
231                }
232            }
233            set.len() as i64
234        }
235        DataType::Utf8 => {
236            let mut set = HashSet::new();
237            for row in &dataset.rows {
238                if let Some(Value::Utf8(s)) = row.get(idx) {
239                    set.insert(s.clone());
240                }
241            }
242            set.len() as i64
243        }
244    };
245    Some(Value::Int64(n))
246}
247
248fn reduce_numeric_typed(
249    dataset: &DataSet,
250    idx: usize,
251    data_type: DataType,
252    op: ReduceOp,
253) -> Option<Value> {
254    match data_type {
255        DataType::Int64 => {
256            let mut acc: Option<i64> = None;
257            for row in &dataset.rows {
258                match row.get(idx) {
259                    Some(Value::Null) | None => {}
260                    Some(Value::Int64(v)) => {
261                        acc = Some(match (op, acc) {
262                            (ReduceOp::Sum, Some(a)) => a + v,
263                            (ReduceOp::Sum, None) => *v,
264                            (ReduceOp::Min, Some(a)) => a.min(*v),
265                            (ReduceOp::Min, None) => *v,
266                            (ReduceOp::Max, Some(a)) => a.max(*v),
267                            (ReduceOp::Max, None) => *v,
268                            _ => unreachable!("non-numeric op handled earlier"),
269                        });
270                    }
271                    Some(_) => {}
272                }
273            }
274            Some(acc.map(Value::Int64).unwrap_or(Value::Null))
275        }
276        DataType::Float64 => {
277            let mut acc: Option<f64> = None;
278            for row in &dataset.rows {
279                match row.get(idx) {
280                    Some(Value::Null) | None => {}
281                    Some(Value::Float64(v)) => {
282                        acc = Some(match (op, acc) {
283                            (ReduceOp::Sum, Some(a)) => a + v,
284                            (ReduceOp::Sum, None) => *v,
285                            (ReduceOp::Min, Some(a)) => a.min(*v),
286                            (ReduceOp::Min, None) => *v,
287                            (ReduceOp::Max, Some(a)) => a.max(*v),
288                            (ReduceOp::Max, None) => *v,
289                            _ => unreachable!("non-numeric op handled earlier"),
290                        });
291                    }
292                    Some(_) => {}
293                }
294            }
295            Some(acc.map(Value::Float64).unwrap_or(Value::Null))
296        }
297        _ => Some(Value::Null),
298    }
299}
300
#[cfg(test)]
mod tests {
    use super::{ReduceOp, VarianceKind, reduce};
    use crate::types::{DataSet, DataType, Field, Schema, Value};

    /// Builds a dataset with a single column holding `cells`, one cell per row.
    fn single_column(name: &str, data_type: DataType, cells: Vec<Value>) -> DataSet {
        let schema = Schema::new(vec![Field::new(name, data_type)]);
        let rows: Vec<Vec<Value>> = cells.into_iter().map(|cell| vec![cell]).collect();
        DataSet::new(schema, rows)
    }

    /// Builds a single `Float64` column without nulls.
    fn float_column(name: &str, xs: &[f64]) -> DataSet {
        let cells = xs.iter().copied().map(Value::Float64).collect();
        single_column(name, DataType::Float64, cells)
    }

    /// Unwraps a reduction result that must be a `Float64`.
    fn float64_of(result: Option<Value>) -> f64 {
        match result.unwrap() {
            Value::Float64(v) => v,
            other => panic!("expected Float64, got {other:?}"),
        }
    }

    /// Two columns: `id` = 1, 2, 3 (Int64) and `score` = 10.0, null, 5.5 (Float64).
    fn numeric_dataset_with_nulls() -> DataSet {
        let schema = Schema::new(vec![
            Field::new("id", DataType::Int64),
            Field::new("score", DataType::Float64),
        ]);
        let scores = vec![Value::Float64(10.0), Value::Null, Value::Float64(5.5)];
        let rows: Vec<Vec<Value>> = (1_i64..=3)
            .zip(scores)
            .map(|(id, score)| vec![Value::Int64(id), score])
            .collect();
        DataSet::new(schema, rows)
    }

    #[test]
    fn reduce_count_counts_rows() {
        let ds = numeric_dataset_with_nulls();
        for column in ["score", "id"] {
            assert_eq!(reduce(&ds, column, ReduceOp::Count), Some(Value::Int64(3)));
        }
    }

    #[test]
    fn reduce_sum_ignores_nulls_and_preserves_type() {
        let ds = numeric_dataset_with_nulls();
        let float_total = reduce(&ds, "score", ReduceOp::Sum);
        let int_total = reduce(&ds, "id", ReduceOp::Sum);
        assert_eq!(float_total, Some(Value::Float64(15.5)));
        assert_eq!(int_total, Some(Value::Int64(6)));
    }

    #[test]
    fn reduce_min_max_ignore_nulls() {
        let ds = numeric_dataset_with_nulls();
        let expectations = [
            ("score", ReduceOp::Min, Value::Float64(5.5)),
            ("score", ReduceOp::Max, Value::Float64(10.0)),
            ("id", ReduceOp::Min, Value::Int64(1)),
            ("id", ReduceOp::Max, Value::Int64(3)),
        ];
        for (column, op, expected) in expectations {
            assert_eq!(reduce(&ds, column, op), Some(expected));
        }
    }

    #[test]
    fn reduce_returns_none_for_missing_column() {
        let ds = numeric_dataset_with_nulls();
        for op in [ReduceOp::Count, ReduceOp::Sum] {
            assert_eq!(reduce(&ds, "missing", op), None);
        }
    }

    #[test]
    fn reduce_numeric_returns_null_if_all_values_null() {
        let ds = single_column(
            "score",
            DataType::Float64,
            vec![Value::Null, Value::Null],
        );
        let ops = [
            ReduceOp::Sum,
            ReduceOp::Min,
            ReduceOp::Max,
            ReduceOp::Mean,
            ReduceOp::Variance(VarianceKind::Population),
            ReduceOp::StdDev(VarianceKind::Sample),
        ];
        for op in ops {
            assert_eq!(reduce(&ds, "score", op), Some(Value::Null));
        }
    }

    #[test]
    fn reduce_mean_float_and_int() {
        let ds = numeric_dataset_with_nulls();
        let score_mean = reduce(&ds, "score", ReduceOp::Mean);
        let id_mean = reduce(&ds, "id", ReduceOp::Mean);
        assert_eq!(score_mean, Some(Value::Float64(7.75)));
        assert_eq!(id_mean, Some(Value::Float64(2.0)));
    }

    #[test]
    fn reduce_variance_std_known_values() {
        let ds = float_column("x", &[1.0, 2.0, 3.0]);
        let pop: f64 = 2.0 / 3.0;

        let var_pop = reduce(&ds, "x", ReduceOp::Variance(VarianceKind::Population));
        assert_eq!(var_pop, Some(Value::Float64(pop)));

        let var_sample = reduce(&ds, "x", ReduceOp::Variance(VarianceKind::Sample));
        assert_eq!(var_sample, Some(Value::Float64(1.0)));

        let std_pop = float64_of(reduce(&ds, "x", ReduceOp::StdDev(VarianceKind::Population)));
        assert!((std_pop - pop.sqrt()).abs() < 1e-12);
    }

    #[test]
    fn reduce_sample_variance_single_value_is_null() {
        let ds = float_column("x", &[42.0]);
        let var_sample = reduce(&ds, "x", ReduceOp::Variance(VarianceKind::Sample));
        assert_eq!(var_sample, Some(Value::Null));
    }

    #[test]
    fn reduce_population_variance_single_value_is_zero() {
        let ds = float_column("x", &[42.0]);

        let var_pop = reduce(&ds, "x", ReduceOp::Variance(VarianceKind::Population));
        assert_eq!(var_pop, Some(Value::Float64(0.0)));

        let std_pop = float64_of(reduce(&ds, "x", ReduceOp::StdDev(VarianceKind::Population)));
        assert_eq!(std_pop, 0.0);
    }

    #[test]
    fn reduce_int64_mean_sum_squares_and_distinct() {
        let ds = single_column(
            "k",
            DataType::Int64,
            vec![Value::Int64(2), Value::Int64(3), Value::Null],
        );
        let expectations = [
            (ReduceOp::Mean, Value::Float64(2.5)),
            (ReduceOp::SumSquares, Value::Float64(13.0)),
            (ReduceOp::L2Norm, Value::Float64(13.0_f64.sqrt())),
            (ReduceOp::CountDistinctNonNull, Value::Int64(2)),
        ];
        for (op, expected) in expectations {
            assert_eq!(reduce(&ds, "k", op), Some(expected));
        }
    }

    #[test]
    fn reduce_sum_squares_and_l2() {
        let ds = single_column(
            "x",
            DataType::Float64,
            vec![Value::Float64(3.0), Value::Float64(4.0), Value::Null],
        );
        let sum_squares = reduce(&ds, "x", ReduceOp::SumSquares);
        let l2 = reduce(&ds, "x", ReduceOp::L2Norm);
        assert_eq!(sum_squares, Some(Value::Float64(25.0)));
        assert_eq!(l2, Some(Value::Float64(5.0)));
    }

    #[test]
    fn reduce_count_distinct_non_null() {
        let schema = Schema::new(vec![
            Field::new("f", DataType::Float64),
            Field::new("s", DataType::Utf8),
        ]);
        let text = |s: &str| Value::Utf8(s.to_string());
        let rows = vec![
            vec![Value::Float64(1.0), text("a")],
            vec![Value::Float64(1.0), text("b")],
            vec![Value::Null, Value::Null],
        ];
        let ds = DataSet::new(schema, rows);

        let distinct_floats = reduce(&ds, "f", ReduceOp::CountDistinctNonNull);
        let distinct_strings = reduce(&ds, "s", ReduceOp::CountDistinctNonNull);
        assert_eq!(distinct_floats, Some(Value::Int64(1)));
        assert_eq!(distinct_strings, Some(Value::Int64(2)));
    }

    #[test]
    fn reduce_new_ops_return_none_for_missing_column() {
        let ds = numeric_dataset_with_nulls();
        let ops = [
            ReduceOp::Mean,
            ReduceOp::Variance(VarianceKind::Sample),
            ReduceOp::CountDistinctNonNull,
        ];
        for op in ops {
            assert_eq!(reduce(&ds, "nope", op), None);
        }
    }

    #[test]
    fn reduce_sum_squares_and_l2_all_null() {
        let ds = single_column("x", DataType::Float64, vec![Value::Null]);
        for op in [ReduceOp::SumSquares, ReduceOp::L2Norm] {
            assert_eq!(reduce(&ds, "x", op), Some(Value::Null));
        }
    }

    #[test]
    fn reduce_count_distinct_bool_and_empty_rows() {
        let empty = single_column("b", DataType::Bool, vec![]);
        assert_eq!(
            reduce(&empty, "b", ReduceOp::CountDistinctNonNull),
            Some(Value::Int64(0))
        );

        let populated = single_column(
            "b",
            DataType::Bool,
            vec![
                Value::Bool(true),
                Value::Bool(false),
                Value::Bool(true),
                Value::Null,
            ],
        );
        assert_eq!(
            reduce(&populated, "b", ReduceOp::CountDistinctNonNull),
            Some(Value::Int64(2))
        );
    }

    #[test]
    fn reduce_mean_variance_null_for_non_numeric_column() {
        let ds = single_column(
            "label",
            DataType::Utf8,
            vec![Value::Utf8("a".to_string()), Value::Utf8("b".to_string())],
        );
        let ops = [
            ReduceOp::Mean,
            ReduceOp::Variance(VarianceKind::Population),
            ReduceOp::SumSquares,
        ];
        for op in ops {
            assert_eq!(reduce(&ds, "label", op), Some(Value::Null));
        }
    }

    #[test]
    fn reduce_std_dev_sample_matches_sqrt_of_sample_variance() {
        let ds = float_column("x", &[0.0, 4.0, 8.0]);
        let var_s = float64_of(reduce(&ds, "x", ReduceOp::Variance(VarianceKind::Sample)));
        let std_s = float64_of(reduce(&ds, "x", ReduceOp::StdDev(VarianceKind::Sample)));
        assert!((std_s - var_s.sqrt()).abs() < 1e-12);
    }

    #[test]
    fn reduce_l2_squared_matches_sum_squares_for_non_nulls() {
        let ds = float_column("x", &[2.0, 3.0]);
        let ss = float64_of(reduce(&ds, "x", ReduceOp::SumSquares));
        let l2 = float64_of(reduce(&ds, "x", ReduceOp::L2Norm));
        assert!((l2 * l2 - ss).abs() < 1e-12);
    }

    #[test]
    fn reduce_median_odd_and_even() {
        let ints = |xs: &[i64]| -> Vec<Value> { xs.iter().copied().map(Value::Int64).collect() };

        let odd = single_column("x", DataType::Int64, ints(&[3, 1, 2]));
        assert_eq!(
            reduce(&odd, "x", ReduceOp::Median),
            Some(Value::Float64(2.0))
        );

        let even = single_column("x", DataType::Int64, ints(&[10, 20, 30, 40]));
        assert_eq!(
            reduce(&even, "x", ReduceOp::Median),
            Some(Value::Float64(25.0))
        );
    }
}