1use std::collections::HashSet;
4
5use crate::types::{DataSet, DataType, Value};
6
/// Selects the divisor used to turn a sum of squared deviations into a variance.
///
/// `Population` divides by the number of observations `n`; `Sample` divides by
/// `n - 1` (see `Welford::variance` in this module).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum VarianceKind {
    /// Divide by `n`.
    Population,
    /// Divide by `n - 1`; undefined (reported as null) for fewer than two observations.
    Sample,
}
15
16#[derive(Debug, Clone, Copy, PartialEq, Eq)]
18pub enum ReduceOp {
19 Count,
21 Sum,
23 Min,
25 Max,
27 Mean,
29 Variance(VarianceKind),
31 StdDev(VarianceKind),
33 SumSquares,
35 L2Norm,
37 CountDistinctNonNull,
39 Median,
42}
43
44pub fn reduce(dataset: &DataSet, column: &str, op: ReduceOp) -> Option<Value> {
55 let idx = dataset.schema.index_of(column)?;
56
57 match op {
58 ReduceOp::Count => Some(Value::Int64(dataset.row_count() as i64)),
59 ReduceOp::CountDistinctNonNull => {
60 let field = dataset.schema.fields.get(idx)?;
61 reduce_count_distinct_non_null(dataset, idx, &field.data_type)
62 }
63 ReduceOp::Sum | ReduceOp::Min | ReduceOp::Max => match dataset.schema.fields.get(idx) {
64 Some(field) => reduce_numeric_typed(dataset, idx, field.data_type.clone(), op),
65 None => None,
66 },
67 ReduceOp::Mean
68 | ReduceOp::Variance(_)
69 | ReduceOp::StdDev(_)
70 | ReduceOp::SumSquares
71 | ReduceOp::L2Norm => match dataset.schema.fields.get(idx) {
72 Some(field) => reduce_numeric_float_stats(dataset, idx, field.data_type.clone(), op),
73 None => None,
74 },
75 ReduceOp::Median => reduce_median(dataset, idx),
76 }
77}
78
79fn reduce_median(dataset: &DataSet, idx: usize) -> Option<Value> {
80 let field = dataset.schema.fields.get(idx)?;
81 if !matches!(field.data_type, DataType::Int64 | DataType::Float64) {
82 return Some(Value::Null);
83 }
84 let is_int = matches!(field.data_type, DataType::Int64);
85 let mut xs: Vec<f64> = Vec::new();
86 for row in &dataset.rows {
87 let x = match row.get(idx) {
88 Some(Value::Null) | None => None,
89 Some(Value::Int64(v)) if is_int => Some(*v as f64),
90 Some(Value::Float64(v)) if !is_int => Some(*v),
91 _ => None,
92 };
93 if let Some(x) = x {
94 xs.push(x);
95 }
96 }
97 if xs.is_empty() {
98 return Some(Value::Null);
99 }
100 xs.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
101 let n = xs.len();
102 let med = if n % 2 == 1 {
103 xs[n / 2]
104 } else {
105 (xs[n / 2 - 1] + xs[n / 2]) / 2.0
106 };
107 Some(Value::Float64(med))
108}
109
110#[derive(Default)]
111pub(crate) struct Welford {
112 n: u64,
113 mean: f64,
114 m2: f64,
115}
116
117impl Welford {
118 pub(crate) fn observe(&mut self, x: f64) {
119 self.n += 1;
120 let delta = x - self.mean;
121 self.mean += delta / self.n as f64;
122 let delta2 = x - self.mean;
123 self.m2 += delta * delta2;
124 }
125
126 pub(crate) fn mean(&self) -> Option<f64> {
127 (self.n > 0).then_some(self.mean)
128 }
129
130 pub(crate) fn variance(&self, kind: VarianceKind) -> Option<f64> {
131 if self.n == 0 {
132 return None;
133 }
134 match kind {
135 VarianceKind::Population => Some(self.m2 / self.n as f64),
136 VarianceKind::Sample => {
137 if self.n < 2 {
138 None
139 } else {
140 Some(self.m2 / (self.n - 1) as f64)
141 }
142 }
143 }
144 }
145
146 pub(crate) fn observation_count(&self) -> u64 {
147 self.n
148 }
149}
150
151fn reduce_numeric_float_stats(
152 dataset: &DataSet,
153 idx: usize,
154 data_type: DataType,
155 op: ReduceOp,
156) -> Option<Value> {
157 match data_type {
158 dt @ (DataType::Int64 | DataType::Float64) => {
159 let is_int = matches!(dt, DataType::Int64);
160 let mut w = Welford::default();
161 let mut sum_squares = 0.0_f64;
162 let mut any = false;
163
164 for row in &dataset.rows {
165 let x = match row.get(idx) {
166 Some(Value::Null) | None => None,
167 Some(Value::Int64(v)) if is_int => Some(*v as f64),
168 Some(Value::Float64(v)) if !is_int => Some(*v),
169 Some(_) => None,
170 };
171 if let Some(x) = x {
172 any = true;
173 w.observe(x);
174 sum_squares += x * x;
175 }
176 }
177
178 if !any {
179 return Some(Value::Null);
180 }
181
182 let out = match op {
183 ReduceOp::Mean => Value::Float64(w.mean().expect("n > 0")),
184 ReduceOp::Variance(kind) => match w.variance(kind) {
185 Some(v) => Value::Float64(v),
186 None => Value::Null,
187 },
188 ReduceOp::StdDev(kind) => match w.variance(kind) {
189 Some(v) => Value::Float64(v.sqrt()),
190 None => Value::Null,
191 },
192 ReduceOp::SumSquares => Value::Float64(sum_squares),
193 ReduceOp::L2Norm => Value::Float64(sum_squares.sqrt()),
194 _ => unreachable!("caller only dispatches float stats ops"),
195 };
196 Some(out)
197 }
198 _ => Some(Value::Null),
199 }
200}
201
202fn reduce_count_distinct_non_null(
203 dataset: &DataSet,
204 idx: usize,
205 data_type: &DataType,
206) -> Option<Value> {
207 let n = match data_type {
208 DataType::Int64 => {
209 let mut set = HashSet::new();
210 for row in &dataset.rows {
211 if let Some(Value::Int64(v)) = row.get(idx) {
212 set.insert(*v);
213 }
214 }
215 set.len() as i64
216 }
217 DataType::Float64 => {
218 let mut set = HashSet::new();
219 for row in &dataset.rows {
220 if let Some(Value::Float64(v)) = row.get(idx) {
221 set.insert(v.to_bits());
222 }
223 }
224 set.len() as i64
225 }
226 DataType::Bool => {
227 let mut set = HashSet::new();
228 for row in &dataset.rows {
229 if let Some(Value::Bool(v)) = row.get(idx) {
230 set.insert(*v);
231 }
232 }
233 set.len() as i64
234 }
235 DataType::Utf8 => {
236 let mut set = HashSet::new();
237 for row in &dataset.rows {
238 if let Some(Value::Utf8(s)) = row.get(idx) {
239 set.insert(s.clone());
240 }
241 }
242 set.len() as i64
243 }
244 };
245 Some(Value::Int64(n))
246}
247
248fn reduce_numeric_typed(
249 dataset: &DataSet,
250 idx: usize,
251 data_type: DataType,
252 op: ReduceOp,
253) -> Option<Value> {
254 match data_type {
255 DataType::Int64 => {
256 let mut acc: Option<i64> = None;
257 for row in &dataset.rows {
258 match row.get(idx) {
259 Some(Value::Null) | None => {}
260 Some(Value::Int64(v)) => {
261 acc = Some(match (op, acc) {
262 (ReduceOp::Sum, Some(a)) => a + v,
263 (ReduceOp::Sum, None) => *v,
264 (ReduceOp::Min, Some(a)) => a.min(*v),
265 (ReduceOp::Min, None) => *v,
266 (ReduceOp::Max, Some(a)) => a.max(*v),
267 (ReduceOp::Max, None) => *v,
268 _ => unreachable!("non-numeric op handled earlier"),
269 });
270 }
271 Some(_) => {}
272 }
273 }
274 Some(acc.map(Value::Int64).unwrap_or(Value::Null))
275 }
276 DataType::Float64 => {
277 let mut acc: Option<f64> = None;
278 for row in &dataset.rows {
279 match row.get(idx) {
280 Some(Value::Null) | None => {}
281 Some(Value::Float64(v)) => {
282 acc = Some(match (op, acc) {
283 (ReduceOp::Sum, Some(a)) => a + v,
284 (ReduceOp::Sum, None) => *v,
285 (ReduceOp::Min, Some(a)) => a.min(*v),
286 (ReduceOp::Min, None) => *v,
287 (ReduceOp::Max, Some(a)) => a.max(*v),
288 (ReduceOp::Max, None) => *v,
289 _ => unreachable!("non-numeric op handled earlier"),
290 });
291 }
292 Some(_) => {}
293 }
294 }
295 Some(acc.map(Value::Float64).unwrap_or(Value::Null))
296 }
297 _ => Some(Value::Null),
298 }
299}
300
#[cfg(test)]
mod tests {
    // Tests exercise the public `reduce` entry point end to end.
    use super::{ReduceOp, VarianceKind, reduce};
    use crate::types::{DataSet, DataType, Field, Schema, Value};

    /// Two columns: `id` = [1, 2, 3] and `score` = [10.0, null, 5.5].
    fn numeric_dataset_with_nulls() -> DataSet {
        let schema = Schema::new(vec![
            Field::new("id", DataType::Int64),
            Field::new("score", DataType::Float64),
        ]);

        let rows = vec![
            vec![Value::Int64(1), Value::Float64(10.0)],
            vec![Value::Int64(2), Value::Null],
            vec![Value::Int64(3), Value::Float64(5.5)],
        ];

        DataSet::new(schema, rows)
    }

    #[test]
    fn reduce_count_counts_rows() {
        let ds = numeric_dataset_with_nulls();
        assert_eq!(reduce(&ds, "score", ReduceOp::Count), Some(Value::Int64(3)));
        assert_eq!(reduce(&ds, "id", ReduceOp::Count), Some(Value::Int64(3)));
    }

    #[test]
    fn reduce_sum_ignores_nulls_and_preserves_type() {
        let ds = numeric_dataset_with_nulls();
        assert_eq!(
            reduce(&ds, "score", ReduceOp::Sum),
            Some(Value::Float64(15.5))
        );
        assert_eq!(reduce(&ds, "id", ReduceOp::Sum), Some(Value::Int64(6)));
    }

    #[test]
    fn reduce_min_max_ignore_nulls() {
        let ds = numeric_dataset_with_nulls();
        assert_eq!(
            reduce(&ds, "score", ReduceOp::Min),
            Some(Value::Float64(5.5))
        );
        assert_eq!(
            reduce(&ds, "score", ReduceOp::Max),
            Some(Value::Float64(10.0))
        );
        assert_eq!(reduce(&ds, "id", ReduceOp::Min), Some(Value::Int64(1)));
        assert_eq!(reduce(&ds, "id", ReduceOp::Max), Some(Value::Int64(3)));
    }

    #[test]
    fn reduce_returns_none_for_missing_column() {
        let ds = numeric_dataset_with_nulls();
        assert_eq!(reduce(&ds, "missing", ReduceOp::Count), None);
        assert_eq!(reduce(&ds, "missing", ReduceOp::Sum), None);
    }

    #[test]
    fn reduce_numeric_returns_null_if_all_values_null() {
        let schema = Schema::new(vec![Field::new("score", DataType::Float64)]);
        let ds = DataSet::new(schema, vec![vec![Value::Null], vec![Value::Null]]);
        assert_eq!(reduce(&ds, "score", ReduceOp::Sum), Some(Value::Null));
        assert_eq!(reduce(&ds, "score", ReduceOp::Min), Some(Value::Null));
        assert_eq!(reduce(&ds, "score", ReduceOp::Max), Some(Value::Null));
        assert_eq!(reduce(&ds, "score", ReduceOp::Mean), Some(Value::Null));
        assert_eq!(
            reduce(&ds, "score", ReduceOp::Variance(VarianceKind::Population)),
            Some(Value::Null)
        );
        assert_eq!(
            reduce(&ds, "score", ReduceOp::StdDev(VarianceKind::Sample)),
            Some(Value::Null)
        );
    }

    #[test]
    fn reduce_mean_float_and_int() {
        let ds = numeric_dataset_with_nulls();
        assert_eq!(
            reduce(&ds, "score", ReduceOp::Mean),
            Some(Value::Float64(7.75))
        );
        assert_eq!(reduce(&ds, "id", ReduceOp::Mean), Some(Value::Float64(2.0)));
    }

    #[test]
    fn reduce_variance_std_known_values() {
        let schema = Schema::new(vec![Field::new("x", DataType::Float64)]);
        let ds = DataSet::new(
            schema,
            vec![
                vec![Value::Float64(1.0)],
                vec![Value::Float64(2.0)],
                vec![Value::Float64(3.0)],
            ],
        );
        let pop = 2.0 / 3.0;
        assert_eq!(
            reduce(&ds, "x", ReduceOp::Variance(VarianceKind::Population)),
            Some(Value::Float64(pop))
        );
        assert_eq!(
            reduce(&ds, "x", ReduceOp::Variance(VarianceKind::Sample)),
            Some(Value::Float64(1.0))
        );
        let std_pop = reduce(&ds, "x", ReduceOp::StdDev(VarianceKind::Population)).unwrap();
        match std_pop {
            Value::Float64(v) => assert!((v - pop.sqrt()).abs() < 1e-12),
            other => panic!("expected Float64, got {other:?}"),
        }
    }

    #[test]
    fn reduce_sample_variance_single_value_is_null() {
        let schema = Schema::new(vec![Field::new("x", DataType::Float64)]);
        let ds = DataSet::new(schema, vec![vec![Value::Float64(42.0)]]);
        assert_eq!(
            reduce(&ds, "x", ReduceOp::Variance(VarianceKind::Sample)),
            Some(Value::Null)
        );
    }

    #[test]
    fn reduce_population_variance_single_value_is_zero() {
        let schema = Schema::new(vec![Field::new("x", DataType::Float64)]);
        let ds = DataSet::new(schema, vec![vec![Value::Float64(42.0)]]);
        assert_eq!(
            reduce(&ds, "x", ReduceOp::Variance(VarianceKind::Population)),
            Some(Value::Float64(0.0))
        );
        let std0 = reduce(&ds, "x", ReduceOp::StdDev(VarianceKind::Population)).unwrap();
        match std0 {
            Value::Float64(v) => assert_eq!(v, 0.0),
            other => panic!("expected Float64, got {other:?}"),
        }
    }

    #[test]
    fn reduce_int64_mean_sum_squares_and_distinct() {
        let schema = Schema::new(vec![Field::new("k", DataType::Int64)]);
        let ds = DataSet::new(
            schema,
            vec![
                vec![Value::Int64(2)],
                vec![Value::Int64(3)],
                vec![Value::Null],
            ],
        );
        assert_eq!(reduce(&ds, "k", ReduceOp::Mean), Some(Value::Float64(2.5)));
        assert_eq!(
            reduce(&ds, "k", ReduceOp::SumSquares),
            Some(Value::Float64(13.0))
        );
        assert_eq!(
            reduce(&ds, "k", ReduceOp::L2Norm),
            Some(Value::Float64(13.0_f64.sqrt()))
        );
        assert_eq!(
            reduce(&ds, "k", ReduceOp::CountDistinctNonNull),
            Some(Value::Int64(2))
        );
    }

    #[test]
    fn reduce_sum_squares_and_l2() {
        let schema = Schema::new(vec![Field::new("x", DataType::Float64)]);
        let ds = DataSet::new(
            schema,
            vec![
                vec![Value::Float64(3.0)],
                vec![Value::Float64(4.0)],
                vec![Value::Null],
            ],
        );
        assert_eq!(
            reduce(&ds, "x", ReduceOp::SumSquares),
            Some(Value::Float64(25.0))
        );
        assert_eq!(
            reduce(&ds, "x", ReduceOp::L2Norm),
            Some(Value::Float64(5.0))
        );
    }

    #[test]
    fn reduce_count_distinct_non_null() {
        let schema = Schema::new(vec![
            Field::new("f", DataType::Float64),
            Field::new("s", DataType::Utf8),
        ]);
        let ds = DataSet::new(
            schema,
            vec![
                vec![Value::Float64(1.0), Value::Utf8("a".to_string())],
                vec![Value::Float64(1.0), Value::Utf8("b".to_string())],
                vec![Value::Null, Value::Null],
            ],
        );
        assert_eq!(
            reduce(&ds, "f", ReduceOp::CountDistinctNonNull),
            Some(Value::Int64(1))
        );
        assert_eq!(
            reduce(&ds, "s", ReduceOp::CountDistinctNonNull),
            Some(Value::Int64(2))
        );
    }

    #[test]
    fn reduce_new_ops_return_none_for_missing_column() {
        let ds = numeric_dataset_with_nulls();
        assert_eq!(reduce(&ds, "nope", ReduceOp::Mean), None);
        assert_eq!(
            reduce(&ds, "nope", ReduceOp::Variance(VarianceKind::Sample)),
            None
        );
        assert_eq!(reduce(&ds, "nope", ReduceOp::CountDistinctNonNull), None);
    }

    #[test]
    fn reduce_sum_squares_and_l2_all_null() {
        let schema = Schema::new(vec![Field::new("x", DataType::Float64)]);
        let ds = DataSet::new(schema, vec![vec![Value::Null]]);
        assert_eq!(reduce(&ds, "x", ReduceOp::SumSquares), Some(Value::Null));
        assert_eq!(reduce(&ds, "x", ReduceOp::L2Norm), Some(Value::Null));
    }

    #[test]
    fn reduce_count_distinct_bool_and_empty_rows() {
        let schema = Schema::new(vec![Field::new("b", DataType::Bool)]);
        let ds = DataSet::new(schema.clone(), vec![]);
        assert_eq!(
            reduce(&ds, "b", ReduceOp::CountDistinctNonNull),
            Some(Value::Int64(0))
        );

        let ds = DataSet::new(
            schema,
            vec![
                vec![Value::Bool(true)],
                vec![Value::Bool(false)],
                vec![Value::Bool(true)],
                vec![Value::Null],
            ],
        );
        assert_eq!(
            reduce(&ds, "b", ReduceOp::CountDistinctNonNull),
            Some(Value::Int64(2))
        );
    }

    #[test]
    fn reduce_mean_variance_null_for_non_numeric_column() {
        let schema = Schema::new(vec![Field::new("label", DataType::Utf8)]);
        let ds = DataSet::new(
            schema,
            vec![
                vec![Value::Utf8("a".to_string())],
                vec![Value::Utf8("b".to_string())],
            ],
        );
        assert_eq!(reduce(&ds, "label", ReduceOp::Mean), Some(Value::Null));
        assert_eq!(
            reduce(&ds, "label", ReduceOp::Variance(VarianceKind::Population)),
            Some(Value::Null)
        );
        assert_eq!(
            reduce(&ds, "label", ReduceOp::SumSquares),
            Some(Value::Null)
        );
    }

    #[test]
    fn reduce_std_dev_sample_matches_sqrt_of_sample_variance() {
        let schema = Schema::new(vec![Field::new("x", DataType::Float64)]);
        let ds = DataSet::new(
            schema,
            vec![
                vec![Value::Float64(0.0)],
                vec![Value::Float64(4.0)],
                vec![Value::Float64(8.0)],
            ],
        );
        let var_s = match reduce(&ds, "x", ReduceOp::Variance(VarianceKind::Sample)).unwrap() {
            Value::Float64(v) => v,
            other => panic!("expected Float64, got {other:?}"),
        };
        let std_s = match reduce(&ds, "x", ReduceOp::StdDev(VarianceKind::Sample)).unwrap() {
            Value::Float64(v) => v,
            other => panic!("expected Float64, got {other:?}"),
        };
        assert!((std_s - var_s.sqrt()).abs() < 1e-12);
    }

    #[test]
    fn reduce_l2_squared_matches_sum_squares_for_non_nulls() {
        let schema = Schema::new(vec![Field::new("x", DataType::Float64)]);
        let ds = DataSet::new(
            schema,
            vec![vec![Value::Float64(2.0)], vec![Value::Float64(3.0)]],
        );
        let ss = match reduce(&ds, "x", ReduceOp::SumSquares).unwrap() {
            Value::Float64(v) => v,
            other => panic!("expected Float64, got {other:?}"),
        };
        let l2 = match reduce(&ds, "x", ReduceOp::L2Norm).unwrap() {
            Value::Float64(v) => v,
            other => panic!("expected Float64, got {other:?}"),
        };
        assert!((l2 * l2 - ss).abs() < 1e-12);
    }

    #[test]
    fn reduce_median_odd_and_even() {
        let schema = Schema::new(vec![Field::new("x", DataType::Int64)]);
        let odd = DataSet::new(
            schema.clone(),
            vec![
                vec![Value::Int64(3)],
                vec![Value::Int64(1)],
                vec![Value::Int64(2)],
            ],
        );
        assert_eq!(
            reduce(&odd, "x", ReduceOp::Median),
            Some(Value::Float64(2.0))
        );
        let even = DataSet::new(
            schema,
            vec![
                vec![Value::Int64(10)],
                vec![Value::Int64(20)],
                vec![Value::Int64(30)],
                vec![Value::Int64(40)],
            ],
        );
        assert_eq!(
            reduce(&even, "x", ReduceOp::Median),
            Some(Value::Float64(25.0))
        );
    }

    /// Median skips nulls and averages the two middle values of a float column.
    #[test]
    fn reduce_median_skips_nulls_and_averages_even_float_counts() {
        let ds = numeric_dataset_with_nulls();
        // score = [10.0, null, 5.5] -> non-null values [5.5, 10.0].
        assert_eq!(
            reduce(&ds, "score", ReduceOp::Median),
            Some(Value::Float64(7.75))
        );
        // id = [1, 2, 3] -> odd count, middle value, reported as Float64.
        assert_eq!(reduce(&ds, "id", ReduceOp::Median), Some(Value::Float64(2.0)));
    }

    /// Median is null for an all-null or non-numeric column and `None` for a missing one.
    #[test]
    fn reduce_median_null_for_all_null_or_non_numeric_and_none_for_missing() {
        let schema = Schema::new(vec![
            Field::new("x", DataType::Float64),
            Field::new("label", DataType::Utf8),
        ]);
        let ds = DataSet::new(
            schema,
            vec![
                vec![Value::Null, Value::Utf8("a".to_string())],
                vec![Value::Null, Value::Utf8("b".to_string())],
            ],
        );
        assert_eq!(reduce(&ds, "x", ReduceOp::Median), Some(Value::Null));
        assert_eq!(reduce(&ds, "label", ReduceOp::Median), Some(Value::Null));
        assert_eq!(reduce(&ds, "nope", ReduceOp::Median), None);
    }
}