Skip to main content

rust_data_processing/pipeline/
mod.rs

1//! DataFrame-centric pipeline/transforms backed by a Polars lazy plan.
2//!
3//! This module provides a small, engine-delegated pipeline API that compiles to a Polars
4//! [`polars::prelude::LazyFrame`] and then collects results back into our in-memory [`crate::types::DataSet`].
5//!
6//! Design goals for Phase 1:
7//! - Keep the public API in our own types (no Polars types in signatures)
8//! - Support a minimal set of transformation primitives needed for parity/benchmarks
9//! - Provide deterministic, testable behavior (null handling, missing column errors)
10//!
11//! # Examples
12//!
13//! ```no_run
14//! use rust_data_processing::pipeline::{Agg, DataFrame, JoinKind, Predicate};
15//! use rust_data_processing::types::{DataSet, DataType, Field, Schema, Value};
16//!
17//! # fn main() -> Result<(), rust_data_processing::IngestionError> {
18//! let ds = DataSet::new(
19//!     Schema::new(vec![
20//!         Field::new("id", DataType::Int64),
21//!         Field::new("active", DataType::Bool),
22//!         Field::new("score", DataType::Int64),
23//!         Field::new("grp", DataType::Utf8),
24//!     ]),
25//!     vec![
26//!         vec![Value::Int64(1), Value::Bool(true), Value::Int64(10), Value::Utf8("A".to_string())],
27//!         vec![Value::Int64(2), Value::Bool(true), Value::Null, Value::Utf8("A".to_string())],
28//!     ],
29//! );
30//!
31//! // Rename + cast + fill nulls.
32//! let cleaned = DataFrame::from_dataset(&ds)?
33//!     .rename(&[("score", "score_i")])?
34//!     .cast("score_i", DataType::Float64)?
35//!     .fill_null("score_i", Value::Float64(0.0))?;
36//!
37//! // Filter + group_by.
38//! let _out = cleaned
39//!     .filter(Predicate::Eq {
40//!         column: "active".to_string(),
41//!         value: Value::Bool(true),
42//!     })?
43//!     .group_by(
44//!         &["grp"],
45//!         &[Agg::Sum {
46//!             column: "score_i".to_string(),
47//!             alias: "sum_score".to_string(),
48//!         }],
49//!     )?
50//!     .collect()?;
51//!
52//! // Join two DataFrames.
53//! let left = DataFrame::from_dataset(&ds)?;
54//! let right = DataFrame::from_dataset(&ds)?;
55//! let _joined = left.join(right, &["id"], &["id"], JoinKind::Inner)?;
56//! # Ok(())
57//! # }
58//! ```
59
60use crate::error::{IngestionError, IngestionResult};
61use crate::ingestion::polars_bridge::{
62    dataframe_to_dataset, dataset_to_dataframe, infer_schema_from_dataframe,
63    polars_error_to_ingestion,
64};
65use crate::processing::{FeatureMeanStd, ReduceOp, VarianceKind};
66use crate::types::{DataSet, DataType, Schema, Value};
67
68use polars::chunked_array::cast::CastOptions;
69use polars::prelude::*;
70use serde::{Deserialize, Serialize};
71
72const REDUCE_SCALAR_COL: &str = "__rust_dp_reduce_scalar";
73
74/// A predicate used by [`DataFrame::filter`].
75#[derive(Debug, Clone, PartialEq)]
76pub enum Predicate {
77    /// Keep rows where `column == value`.
78    Eq { column: String, value: Value },
79    /// Keep rows where `column` is not null.
80    NotNull { column: String },
81    /// Keep rows where `column % modulus == equals` (Int64 only).
82    ModEqInt64 {
83        column: String,
84        modulus: i64,
85        equals: i64,
86    },
87}
88
89/// Join behavior for [`DataFrame::join`].
90#[derive(Debug, Clone, Copy, PartialEq, Eq)]
91pub enum JoinKind {
92    Inner,
93    Left,
94    Right,
95    Full,
96}
97
98/// Aggregations for [`DataFrame::group_by`].
99#[derive(Debug, Clone, PartialEq)]
100pub enum Agg {
101    /// Count rows in each group (includes nulls).
102    CountRows {
103        alias: String,
104    },
105    /// Count non-null values of a column in each group.
106    CountNotNull {
107        column: String,
108        alias: String,
109    },
110    Sum {
111        column: String,
112        alias: String,
113    },
114    Min {
115        column: String,
116        alias: String,
117    },
118    Max {
119        column: String,
120        alias: String,
121    },
122    /// Mean of numeric values (cast to `Float64` first), nulls ignored.
123    Mean {
124        column: String,
125        alias: String,
126    },
127    Variance {
128        column: String,
129        alias: String,
130        kind: VarianceKind,
131    },
132    StdDev {
133        column: String,
134        alias: String,
135        kind: VarianceKind,
136    },
137    SumSquares {
138        column: String,
139        alias: String,
140    },
141    L2Norm {
142        column: String,
143        alias: String,
144    },
145    /// Distinct count of non-null values in each group.
146    CountDistinctNonNull {
147        column: String,
148        alias: String,
149    },
150    /// Median of numeric values (cast to `Float64` first), nulls ignored.
151    Median {
152        column: String,
153        alias: String,
154    },
155}
156
157/// Casting behavior for [`DataFrame::cast_with_mode`].
158#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
159#[serde(rename_all = "snake_case")]
160pub enum CastMode {
161    /// Casting errors fail the pipeline at `collect()` time.
162    #[default]
163    Strict,
164    /// Casting errors yield nulls instead of failing.
165    Lossy,
166}
167
168/// A DataFrame-centric pipeline compiled into a lazy plan.
169///
170/// The public API stays in this crate's own types. The current engine implementation is Polars,
171/// but callers do not need to depend on Polars types.
172#[derive(Clone)]
173pub struct DataFrame {
174    lf: LazyFrame,
175}
176
177impl DataFrame {
178    /// Build a pipeline starting from an in-memory [`DataSet`].
179    ///
180    /// Note: this converts the dataset into a Polars `DataFrame` first. The transformations after
181    /// that are planned lazily.
182    pub fn from_dataset(ds: &DataSet) -> IngestionResult<Self> {
183        let df = dataset_to_dataframe(ds)?;
184        Ok(Self { lf: df.lazy() })
185    }
186
187    /// Add a filter predicate.
188    pub fn filter(mut self, predicate: Predicate) -> IngestionResult<Self> {
189        let expr = match predicate {
190            Predicate::Eq { column, value } => match value {
191                Value::Null => col(&column).is_null(),
192                Value::Int64(x) => col(&column).eq(lit(x)),
193                Value::Float64(x) => col(&column).eq(lit(x)),
194                Value::Bool(x) => col(&column).eq(lit(x)),
195                Value::Utf8(s) => col(&column).eq(lit(s)),
196            },
197            Predicate::NotNull { column } => col(&column).is_not_null(),
198            Predicate::ModEqInt64 {
199                column,
200                modulus,
201                equals,
202            } => (col(&column) % lit(modulus)).eq(lit(equals)),
203        };
204        // Planning ops are infallible; errors surface at `collect` time.
205        self.lf = self.lf.filter(expr);
206        Ok(self)
207    }
208
209    /// Multiply a Float64 column by a constant factor (nulls remain null).
210    pub fn multiply_f64(mut self, column: &str, factor: f64) -> IngestionResult<Self> {
211        // Planning ops are infallible; errors surface at `collect` time.
212        self.lf = self
213            .lf
214            .with_columns([(col(column) * lit(factor)).alias(column)]);
215        Ok(self)
216    }
217
218    /// Add a constant Float64 value to a column (nulls remain null).
219    pub fn add_f64(mut self, column: &str, delta: f64) -> IngestionResult<Self> {
220        self.lf = self
221            .lf
222            .with_columns([(col(column) + lit(delta)).alias(column)]);
223        Ok(self)
224    }
225
226    /// Add a derived Float64 column: `name = source * factor` (nulls remain null).
227    pub fn with_mul_f64(mut self, name: &str, source: &str, factor: f64) -> IngestionResult<Self> {
228        self.lf = self
229            .lf
230            .with_columns([(col(source) * lit(factor)).alias(name)]);
231        Ok(self)
232    }
233
234    /// Add a derived Float64 column: `name = source + delta` (nulls remain null).
235    pub fn with_add_f64(mut self, name: &str, source: &str, delta: f64) -> IngestionResult<Self> {
236        self.lf = self
237            .lf
238            .with_columns([(col(source) + lit(delta)).alias(name)]);
239        Ok(self)
240    }
241
242    /// Select a subset of columns (in the provided order).
243    pub fn select(mut self, columns: &[&str]) -> IngestionResult<Self> {
244        let exprs: Vec<Expr> = columns.iter().map(|c| col(*c)).collect();
245        // Planning ops are infallible; errors surface at `collect` time.
246        self.lf = self.lf.select(exprs);
247        Ok(self)
248    }
249
250    /// Rename columns.
251    ///
252    /// This uses Polars' `rename(..., strict=true)` behavior: all `from` columns must exist.
253    pub fn rename(mut self, pairs: &[(&str, &str)]) -> IngestionResult<Self> {
254        let (existing, new): (Vec<&str>, Vec<&str>) = pairs.iter().copied().unzip();
255        self.lf = self.lf.rename(existing, new, true);
256        Ok(self)
257    }
258
259    /// Cast a column to a target type.
260    ///
261    /// Note: cast errors (e.g. invalid parses) surface at `collect()` time.
262    pub fn cast(self, column: &str, to: DataType) -> IngestionResult<Self> {
263        self.cast_with_mode(column, to, CastMode::Strict)
264    }
265
266    /// Cast a column with an explicit mode (strict vs lossy).
267    pub fn cast_with_mode(
268        mut self,
269        column: &str,
270        to: DataType,
271        mode: CastMode,
272    ) -> IngestionResult<Self> {
273        let dt = to_polars_dtype(&to);
274        let expr = match mode {
275            CastMode::Strict => col(column).strict_cast(dt),
276            CastMode::Lossy => col(column).cast_with_options(dt, CastOptions::NonStrict),
277        }
278        .alias(column);
279        self.lf = self.lf.with_columns([expr]);
280        Ok(self)
281    }
282
283    /// Drop columns by name.
284    pub fn drop(mut self, columns: &[&str]) -> IngestionResult<Self> {
285        let names: Vec<PlSmallStr> = columns.iter().map(|c| (*c).into()).collect();
286        let sel = Selector::ByName {
287            names: names.into(),
288            strict: true,
289        };
290        self.lf = self.lf.drop(sel);
291        Ok(self)
292    }
293
294    /// Fill nulls in a column with a literal.
295    pub fn fill_null(mut self, column: &str, value: Value) -> IngestionResult<Self> {
296        let lit_expr = value_to_lit_expr(value)?;
297        self.lf = self
298            .lf
299            .with_columns([col(column).fill_null(lit_expr).alias(column)]);
300        Ok(self)
301    }
302
303    /// Add a derived column with a literal value.
304    pub fn with_literal(mut self, name: &str, value: Value) -> IngestionResult<Self> {
305        let lit_expr = value_to_lit_expr(value)?;
306        self.lf = self.lf.with_columns([lit_expr.alias(name)]);
307        Ok(self)
308    }
309
310    /// Group rows by `keys` and compute aggregations.
311    pub fn group_by(mut self, keys: &[&str], aggs: &[Agg]) -> IngestionResult<Self> {
312        if keys.is_empty() {
313            return Err(IngestionError::SchemaMismatch {
314                message: "group_by requires at least one key column".to_string(),
315            });
316        }
317        if aggs.is_empty() {
318            return Err(IngestionError::SchemaMismatch {
319                message: "group_by requires at least one aggregation".to_string(),
320            });
321        }
322
323        let key_exprs: Vec<Expr> = keys.iter().map(|k| col(*k)).collect();
324        let agg_exprs: Vec<Expr> = aggs.iter().map(agg_to_expr).collect();
325        self.lf = self.lf.group_by(key_exprs).agg(agg_exprs);
326        Ok(self)
327    }
328
329    /// Join this pipeline with another [`DataFrame`] on key columns.
330    ///
331    /// Note: join planning is infallible; missing-column errors surface at `collect()` time.
332    pub fn join(
333        mut self,
334        other: DataFrame,
335        left_on: &[&str],
336        right_on: &[&str],
337        how: JoinKind,
338    ) -> IngestionResult<Self> {
339        if left_on.is_empty() || right_on.is_empty() {
340            return Err(IngestionError::SchemaMismatch {
341                message: "join requires at least one join key on each side".to_string(),
342            });
343        }
344        if left_on.len() != right_on.len() {
345            return Err(IngestionError::SchemaMismatch {
346                message: format!(
347                    "join requires left_on and right_on to have same length (left_on={}, right_on={})",
348                    left_on.len(),
349                    right_on.len()
350                ),
351            });
352        }
353
354        let left_exprs: Vec<Expr> = left_on.iter().map(|c| col(*c)).collect();
355        let right_exprs: Vec<Expr> = right_on.iter().map(|c| col(*c)).collect();
356
357        let how = match how {
358            JoinKind::Inner => JoinType::Inner,
359            JoinKind::Left => JoinType::Left,
360            JoinKind::Right => JoinType::Right,
361            JoinKind::Full => JoinType::Full,
362        };
363
364        self.lf = self
365            .lf
366            .join(other.lf, left_exprs, right_exprs, JoinArgs::new(how));
367        Ok(self)
368    }
369
370    /// Collect the pipeline into an in-memory [`DataSet`].
371    pub fn collect(self) -> IngestionResult<DataSet> {
372        let df = self
373            .lf
374            .collect()
375            .map_err(|e| polars_error_to_ingestion("failed to collect polars lazy plan", e))?;
376        let out_schema = infer_schema_from_dataframe(&df)?;
377        dataframe_to_dataset(&df, &out_schema, "column", 1)
378    }
379
380    /// Collect the pipeline into an in-memory [`DataSet`], enforcing an explicit output schema.
381    pub fn collect_with_schema(self, schema: &Schema) -> IngestionResult<DataSet> {
382        let df = self
383            .lf
384            .collect()
385            .map_err(|e| polars_error_to_ingestion("failed to collect polars lazy plan", e))?;
386        dataframe_to_dataset(&df, schema, "column", 1)
387    }
388
389    /// Reduce a column using a built-in [`ReduceOp`] (Polars-backed).
390    ///
391    /// Returns `None` if `column` does not exist (aligned with [`crate::processing::reduce()`]).
392    pub fn reduce(mut self, column: &str, op: ReduceOp) -> IngestionResult<Option<Value>> {
393        let df_schema = self
394            .lf
395            .collect_schema()
396            .map_err(|e| polars_error_to_ingestion("failed to collect polars schema", e))?;
397        if df_schema.get(column).is_none() {
398            return Ok(None);
399        }
400
401        let expr = polars_reduce_expr(column, op);
402        let df = self
403            .lf
404            .select([expr.alias(REDUCE_SCALAR_COL)])
405            .collect()
406            .map_err(|e| polars_error_to_ingestion("failed to collect polars reduce", e))?;
407
408        let s = df
409            .column(REDUCE_SCALAR_COL)
410            .map_err(|_| IngestionError::SchemaMismatch {
411                message: format!("missing reduce output column '{REDUCE_SCALAR_COL}'"),
412            })?
413            .as_materialized_series();
414        if s.is_empty() {
415            return Ok(Some(Value::Null));
416        }
417        let av = s.get(0).map_err(|e| IngestionError::SchemaMismatch {
418            message: format!("polars reduce output error: {e}"),
419        })?;
420        Ok(Some(anyvalue_to_value(av)))
421    }
422
423    /// Reduce a numeric column by summing values (nulls ignored; all-null -> null).
424    ///
425    /// Returns `None` if `column` does not exist (aligned with `processing::reduce`).
426    pub fn sum(self, column: &str) -> IngestionResult<Option<Value>> {
427        self.reduce(column, ReduceOp::Sum)
428    }
429
430    /// Single Polars collect: for each column, mean and standard deviation (`std_kind` maps to
431    /// Polars `ddof`). Columns are cast to `Float64` first (aligned with scalar reduces).
432    ///
433    /// Returns an error if any column name is missing from the lazy schema.
434    pub fn feature_wise_mean_std(
435        mut self,
436        columns: &[&str],
437        std_kind: VarianceKind,
438    ) -> IngestionResult<Vec<(String, FeatureMeanStd)>> {
439        let df_schema = self
440            .lf
441            .collect_schema()
442            .map_err(|e| polars_error_to_ingestion("failed to collect polars schema", e))?;
443        for c in columns {
444            if df_schema.get(c).is_none() {
445                return Err(IngestionError::SchemaMismatch {
446                    message: format!("feature_wise_mean_std: unknown column '{c}'"),
447                });
448            }
449        }
450        let ddof = match std_kind {
451            VarianceKind::Population => 0u8,
452            VarianceKind::Sample => 1u8,
453        };
454        use polars::datatypes::DataType as P;
455        let mut exprs: Vec<Expr> = Vec::with_capacity(columns.len() * 2);
456        for (i, c) in columns.iter().enumerate() {
457            let cf = col(*c).strict_cast(P::Float64);
458            exprs.push(cf.clone().mean().alias(format!("__fwm_{i}_mean").as_str()));
459            exprs.push(cf.std(ddof).alias(format!("__fwm_{i}_std").as_str()));
460        }
461        let df =
462            self.lf.select(exprs).collect().map_err(|e| {
463                polars_error_to_ingestion("failed to collect feature_wise_mean_std", e)
464            })?;
465
466        if df.height() == 0 {
467            return Ok(columns
468                .iter()
469                .map(|c| {
470                    (
471                        (*c).to_string(),
472                        FeatureMeanStd {
473                            mean: Value::Null,
474                            std_dev: Value::Null,
475                        },
476                    )
477                })
478                .collect());
479        }
480
481        let mut out = Vec::with_capacity(columns.len());
482        for (i, _) in columns.iter().enumerate() {
483            let mean_s = df
484                .column(&format!("__fwm_{i}_mean"))
485                .map_err(|_| IngestionError::SchemaMismatch {
486                    message: format!("missing __fwm_{i}_mean"),
487                })?
488                .as_materialized_series();
489            let std_s = df
490                .column(&format!("__fwm_{i}_std"))
491                .map_err(|_| IngestionError::SchemaMismatch {
492                    message: format!("missing __fwm_{i}_std"),
493                })?
494                .as_materialized_series();
495            let mean_av = mean_s.get(0).map_err(|e| IngestionError::SchemaMismatch {
496                message: format!("feature_wise mean get: {e}"),
497            })?;
498            let std_av = std_s.get(0).map_err(|e| IngestionError::SchemaMismatch {
499                message: format!("feature_wise std get: {e}"),
500            })?;
501            out.push((
502                columns[i].to_string(),
503                FeatureMeanStd {
504                    mean: anyvalue_to_value(mean_av),
505                    std_dev: anyvalue_to_value(std_av),
506                },
507            ));
508        }
509        Ok(out)
510    }
511
512    pub(crate) fn lazy_clone(&self) -> LazyFrame {
513        self.lf.clone()
514    }
515
516    pub(crate) fn from_lazyframe(lf: LazyFrame) -> Self {
517        Self { lf }
518    }
519}
520
521fn polars_reduce_expr(column: &str, op: ReduceOp) -> Expr {
522    use polars::datatypes::DataType as P;
523    let c = col(column);
524    match op {
525        ReduceOp::Count => len(),
526        ReduceOp::Sum => c.sum(),
527        ReduceOp::Min => c.min(),
528        ReduceOp::Max => c.max(),
529        ReduceOp::Mean => c.clone().strict_cast(P::Float64).mean(),
530        ReduceOp::Variance(kind) => {
531            let ddof = match kind {
532                VarianceKind::Population => 0u8,
533                VarianceKind::Sample => 1u8,
534            };
535            c.clone().strict_cast(P::Float64).var(ddof)
536        }
537        ReduceOp::StdDev(kind) => {
538            let ddof = match kind {
539                VarianceKind::Population => 0u8,
540                VarianceKind::Sample => 1u8,
541            };
542            c.clone().strict_cast(P::Float64).std(ddof)
543        }
544        ReduceOp::SumSquares => c.clone().strict_cast(P::Float64).pow(lit(2.0)).sum(),
545        ReduceOp::L2Norm => c.clone().strict_cast(P::Float64).pow(lit(2.0)).sum().sqrt(),
546        ReduceOp::CountDistinctNonNull => c.drop_nulls().n_unique(),
547        ReduceOp::Median => c.clone().strict_cast(P::Float64).median(),
548    }
549}
550
551fn agg_to_expr(agg: &Agg) -> Expr {
552    use polars::datatypes::DataType as P;
553    match agg {
554        Agg::CountRows { alias } => len().alias(alias.as_str()),
555        Agg::CountNotNull { column, alias } => col(column.as_str()).count().alias(alias.as_str()),
556        Agg::Sum { column, alias } => col(column.as_str()).sum().alias(alias.as_str()),
557        Agg::Min { column, alias } => col(column.as_str()).min().alias(alias.as_str()),
558        Agg::Max { column, alias } => col(column.as_str()).max().alias(alias.as_str()),
559        Agg::Mean { column, alias } => col(column.as_str())
560            .strict_cast(P::Float64)
561            .mean()
562            .alias(alias.as_str()),
563        Agg::Variance {
564            column,
565            alias,
566            kind,
567        } => {
568            let ddof = match kind {
569                VarianceKind::Population => 0u8,
570                VarianceKind::Sample => 1u8,
571            };
572            col(column.as_str())
573                .strict_cast(P::Float64)
574                .var(ddof)
575                .alias(alias.as_str())
576        }
577        Agg::StdDev {
578            column,
579            alias,
580            kind,
581        } => {
582            let ddof = match kind {
583                VarianceKind::Population => 0u8,
584                VarianceKind::Sample => 1u8,
585            };
586            col(column.as_str())
587                .strict_cast(P::Float64)
588                .std(ddof)
589                .alias(alias.as_str())
590        }
591        Agg::SumSquares { column, alias } => col(column.as_str())
592            .strict_cast(P::Float64)
593            .pow(lit(2.0))
594            .sum()
595            .alias(alias.as_str()),
596        Agg::L2Norm { column, alias } => col(column.as_str())
597            .strict_cast(P::Float64)
598            .pow(lit(2.0))
599            .sum()
600            .sqrt()
601            .alias(alias.as_str()),
602        Agg::CountDistinctNonNull { column, alias } => col(column.as_str())
603            .drop_nulls()
604            .n_unique()
605            .alias(alias.as_str()),
606        Agg::Median { column, alias } => col(column.as_str())
607            .strict_cast(P::Float64)
608            .median()
609            .alias(alias.as_str()),
610    }
611}
612
613fn to_polars_dtype(dt: &DataType) -> polars::datatypes::DataType {
614    match dt {
615        DataType::Int64 => polars::datatypes::DataType::Int64,
616        DataType::Float64 => polars::datatypes::DataType::Float64,
617        DataType::Bool => polars::datatypes::DataType::Boolean,
618        DataType::Utf8 => polars::datatypes::DataType::String,
619    }
620}
621
622fn value_to_lit_expr(value: Value) -> IngestionResult<Expr> {
623    match value {
624        Value::Null => Err(IngestionError::SchemaMismatch {
625            message: "Value::Null is not supported as a literal expression; use fill_null or cast/collect to materialize".to_string(),
626        }),
627        Value::Int64(v) => Ok(lit(v)),
628        Value::Float64(v) => Ok(lit(v)),
629        Value::Bool(v) => Ok(lit(v)),
630        Value::Utf8(v) => Ok(lit(v)),
631    }
632}
633
634fn anyvalue_to_value(av: AnyValue) -> Value {
635    match av {
636        AnyValue::Null => Value::Null,
637        AnyValue::Int8(v) => Value::Int64(v as i64),
638        AnyValue::Int16(v) => Value::Int64(v as i64),
639        AnyValue::Int32(v) => Value::Int64(v as i64),
640        AnyValue::Int64(v) => Value::Int64(v),
641        AnyValue::UInt8(v) => Value::Int64(v as i64),
642        AnyValue::UInt16(v) => Value::Int64(v as i64),
643        AnyValue::UInt32(v) => Value::Int64(v as i64),
644        AnyValue::UInt64(v) => Value::Int64(v as i64),
645        AnyValue::Float64(v) => Value::Float64(v),
646        AnyValue::Boolean(v) => Value::Bool(v),
647        AnyValue::String(v) => Value::Utf8(v.to_string()),
648        AnyValue::StringOwned(v) => Value::Utf8(v.to_string()),
649        other => Value::Utf8(other.to_string()),
650    }
651}
652
653/// Backwards-compatible alias for earlier naming.
654pub type PolarsPipeline = DataFrame;
655
656#[cfg(test)]
657mod tests {
658    use super::{Agg, DataFrame, JoinKind, PolarsPipeline, Predicate};
659    use crate::processing::{ReduceOp, VarianceKind, feature_wise_mean_std, filter, map, reduce};
660    use crate::types::{DataSet, DataType, Field, Schema, Value};
661
662    fn sample_dataset() -> DataSet {
663        let schema = Schema::new(vec![
664            Field::new("id", DataType::Int64),
665            Field::new("active", DataType::Bool),
666            Field::new("score", DataType::Float64),
667        ]);
668        let rows = vec![
669            vec![Value::Int64(1), Value::Bool(true), Value::Float64(10.0)],
670            vec![Value::Int64(2), Value::Bool(true), Value::Float64(20.0)],
671            vec![Value::Int64(3), Value::Bool(false), Value::Float64(30.0)],
672            vec![Value::Int64(4), Value::Bool(true), Value::Null],
673        ];
674        DataSet::new(schema, rows)
675    }
676
677    #[test]
678    fn polars_pipeline_filter_map_reduce_parity_with_in_memory() {
679        let ds = sample_dataset();
680
681        // In-memory baseline: active && even id, score *= 2.0, then sum(score)
682        let active_idx = ds.schema.index_of("active").unwrap();
683        let id_idx = ds.schema.index_of("id").unwrap();
684        let filtered = filter(&ds, |row| {
685            let is_active = matches!(row.get(active_idx), Some(Value::Bool(true)));
686            let even_id = matches!(row.get(id_idx), Some(Value::Int64(v)) if *v % 2 == 0);
687            is_active && even_id
688        });
689        let mapped = map(&filtered, |row| {
690            let mut out = row.to_vec();
691            if let Some(Value::Float64(v)) = out.get(2) {
692                out[2] = Value::Float64(v * 2.0);
693            }
694            out
695        });
696        let expected = reduce(&mapped, "score", ReduceOp::Sum).unwrap();
697
698        // Polars-delegated pipeline.
699        let got = DataFrame::from_dataset(&ds)
700            .unwrap()
701            .filter(Predicate::Eq {
702                column: "active".to_string(),
703                value: Value::Bool(true),
704            })
705            .unwrap()
706            .filter(Predicate::ModEqInt64 {
707                column: "id".to_string(),
708                modulus: 2,
709                equals: 0,
710            })
711            .unwrap()
712            .multiply_f64("score", 2.0)
713            .unwrap()
714            .sum("score")
715            .unwrap()
716            .unwrap();
717
718        assert_eq!(got, expected);
719    }
720
721    #[test]
722    fn polars_pipeline_reduce_parity_mean_variance_l2_distinct() {
723        let schema = Schema::new(vec![
724            Field::new("x", DataType::Float64),
725            Field::new("tag", DataType::Utf8),
726        ]);
727        let ds = DataSet::new(
728            schema,
729            vec![
730                vec![Value::Float64(1.0), Value::Utf8("a".to_string())],
731                vec![Value::Float64(2.0), Value::Utf8("b".to_string())],
732                vec![Value::Null, Value::Utf8("a".to_string())],
733            ],
734        );
735
736        let mean = reduce(&ds, "x", ReduceOp::Mean).unwrap();
737        let var_pop = reduce(&ds, "x", ReduceOp::Variance(VarianceKind::Population)).unwrap();
738        let l2 = reduce(&ds, "x", ReduceOp::L2Norm).unwrap();
739        let dcnt = reduce(&ds, "tag", ReduceOp::CountDistinctNonNull).unwrap();
740
741        assert_eq!(
742            DataFrame::from_dataset(&ds)
743                .unwrap()
744                .reduce("x", ReduceOp::Mean)
745                .unwrap()
746                .unwrap(),
747            mean
748        );
749        assert_eq!(
750            DataFrame::from_dataset(&ds)
751                .unwrap()
752                .reduce("x", ReduceOp::Variance(VarianceKind::Population))
753                .unwrap()
754                .unwrap(),
755            var_pop
756        );
757        assert_eq!(
758            DataFrame::from_dataset(&ds)
759                .unwrap()
760                .reduce("x", ReduceOp::L2Norm)
761                .unwrap()
762                .unwrap(),
763            l2
764        );
765        assert_eq!(
766            DataFrame::from_dataset(&ds)
767                .unwrap()
768                .reduce("tag", ReduceOp::CountDistinctNonNull)
769                .unwrap()
770                .unwrap(),
771            dcnt
772        );
773    }
774
775    #[test]
776    fn polars_pipeline_collect_select_works() {
777        let ds = sample_dataset();
778        let out = DataFrame::from_dataset(&ds)
779            .unwrap()
780            .select(&["score", "id"])
781            .unwrap()
782            .collect()
783            .unwrap();
784
785        assert_eq!(
786            out.schema.field_names().collect::<Vec<_>>(),
787            vec!["score", "id"]
788        );
789        assert_eq!(out.row_count(), ds.row_count());
790        assert_eq!(out.rows[0][0], Value::Float64(10.0));
791        assert_eq!(out.rows[0][1], Value::Int64(1));
792    }
793
794    #[test]
795    fn polars_pipeline_sum_returns_none_for_missing_column() {
796        let ds = sample_dataset();
797        let out = DataFrame::from_dataset(&ds)
798            .unwrap()
799            .sum("missing")
800            .unwrap();
801        assert_eq!(out, None);
802    }
803
804    #[test]
805    fn polars_errors_are_preserved_as_engine_error_sources() {
806        // Trigger a Polars execution error by applying a numeric multiply to a Utf8 column.
807        let schema = Schema::new(vec![Field::new("name", DataType::Utf8)]);
808        let ds = DataSet::new(schema, vec![vec![Value::Utf8("x".to_string())]]);
809
810        let err = DataFrame::from_dataset(&ds)
811            .unwrap()
812            .multiply_f64("name", 2.0)
813            .unwrap()
814            .collect()
815            .unwrap_err();
816
817        // This should not be stringified into SchemaMismatch; it should preserve a source() chain.
818        match err {
819            crate::error::IngestionError::Engine { source, .. } => {
820                assert!(!source.to_string().is_empty());
821            }
822            other => panic!("expected Engine error, got: {other:?}"),
823        }
824    }
825
826    #[test]
827    fn backwards_compatible_polars_pipeline_alias_exists() {
828        let ds = sample_dataset();
829        let _ = PolarsPipeline::from_dataset(&ds)
830            .unwrap()
831            .select(&["id"])
832            .unwrap();
833    }
834
835    #[test]
836    fn rename_cast_fill_null_group_by_and_join_work() {
837        // rename + cast + fill_null
838        let schema = Schema::new(vec![
839            Field::new("id", DataType::Int64),
840            Field::new("score", DataType::Int64),
841        ]);
842        let ds = DataSet::new(
843            schema,
844            vec![
845                vec![Value::Int64(1), Value::Int64(10)],
846                vec![Value::Int64(2), Value::Null],
847            ],
848        );
849
850        let out = DataFrame::from_dataset(&ds)
851            .unwrap()
852            .rename(&[("score", "score_i")])
853            .unwrap()
854            .cast("score_i", DataType::Float64)
855            .unwrap()
856            .fill_null("score_i", Value::Float64(0.0))
857            .unwrap()
858            .collect()
859            .unwrap();
860
861        assert_eq!(
862            out.schema.field_names().collect::<Vec<_>>(),
863            vec!["id", "score_i"]
864        );
865        assert_eq!(out.rows[0][1], Value::Float64(10.0));
866        assert_eq!(out.rows[1][1], Value::Float64(0.0));
867
868        // group_by
869        let schema = Schema::new(vec![
870            Field::new("grp", DataType::Utf8),
871            Field::new("score", DataType::Float64),
872        ]);
873        let ds = DataSet::new(
874            schema,
875            vec![
876                vec![Value::Utf8("A".to_string()), Value::Float64(1.0)],
877                vec![Value::Utf8("A".to_string()), Value::Float64(2.0)],
878                vec![Value::Utf8("B".to_string()), Value::Null],
879            ],
880        );
881
882        let out = DataFrame::from_dataset(&ds)
883            .unwrap()
884            .group_by(
885                &["grp"],
886                &[
887                    Agg::Sum {
888                        column: "score".to_string(),
889                        alias: "sum_score".to_string(),
890                    },
891                    Agg::CountRows {
892                        alias: "cnt".to_string(),
893                    },
894                ],
895            )
896            .unwrap()
897            .collect()
898            .unwrap();
899
900        // Order is not guaranteed; validate via a lookup.
901        let mut sums: std::collections::HashMap<String, (Value, Value)> =
902            std::collections::HashMap::new();
903        for row in &out.rows {
904            if let Value::Utf8(g) = &row[0] {
905                sums.insert(g.clone(), (row[1].clone(), row[2].clone()));
906            }
907        }
908        assert_eq!(sums.get("A"), Some(&(Value::Float64(3.0), Value::Int64(2))));
909        assert_eq!(
910            sums.get("B"),
911            // Polars `sum` ignores nulls and returns 0.0 for all-null groups.
912            Some(&(Value::Float64(0.0), Value::Int64(1)))
913        );
914
915        // join
916        let left = DataSet::new(
917            Schema::new(vec![
918                Field::new("id", DataType::Int64),
919                Field::new("name", DataType::Utf8),
920            ]),
921            vec![
922                vec![Value::Int64(1), Value::Utf8("Ada".to_string())],
923                vec![Value::Int64(2), Value::Utf8("Grace".to_string())],
924            ],
925        );
926        let right = DataSet::new(
927            Schema::new(vec![
928                Field::new("id", DataType::Int64),
929                Field::new("score", DataType::Float64),
930            ]),
931            vec![
932                vec![Value::Int64(1), Value::Float64(9.0)],
933                vec![Value::Int64(3), Value::Float64(7.0)],
934            ],
935        );
936
937        let out = DataFrame::from_dataset(&left)
938            .unwrap()
939            .join(
940                DataFrame::from_dataset(&right).unwrap(),
941                &["id"],
942                &["id"],
943                JoinKind::Inner,
944            )
945            .unwrap()
946            .collect()
947            .unwrap();
948        assert_eq!(out.row_count(), 1);
949        // One matched row with id=1.
950        assert_eq!(out.rows[0][0], Value::Int64(1));
951    }
952
953    #[test]
954    fn polars_feature_wise_mean_std_matches_in_memory() {
955        let schema = Schema::new(vec![
956            Field::new("a", DataType::Int64),
957            Field::new("b", DataType::Float64),
958        ]);
959        let ds = DataSet::new(
960            schema,
961            vec![
962                vec![Value::Int64(1), Value::Float64(10.0)],
963                vec![Value::Int64(3), Value::Float64(20.0)],
964            ],
965        );
966        let mem = feature_wise_mean_std(&ds, &["a", "b"], VarianceKind::Sample).unwrap();
967        let pol = DataFrame::from_dataset(&ds)
968            .unwrap()
969            .feature_wise_mean_std(&["a", "b"], VarianceKind::Sample)
970            .unwrap();
971        assert_eq!(mem.len(), pol.len());
972        for i in 0..mem.len() {
973            assert_eq!(mem[i].0, pol[i].0);
974            assert_eq!(mem[i].1.mean, pol[i].1.mean);
975            match (&mem[i].1.std_dev, &pol[i].1.std_dev) {
976                (Value::Float64(m), Value::Float64(p)) => assert!((m - p).abs() < 1e-9),
977                (a, b) => assert_eq!(a, b),
978            }
979        }
980    }
981
982    #[test]
983    fn group_by_mean_std_count_distinct_all_null_numeric_is_null() {
984        let schema = Schema::new(vec![
985            Field::new("g", DataType::Utf8),
986            Field::new("x", DataType::Float64),
987            Field::new("tag", DataType::Utf8),
988        ]);
989        let ds = DataSet::new(
990            schema,
991            vec![
992                vec![
993                    Value::Utf8("A".to_string()),
994                    Value::Null,
995                    Value::Utf8("p".to_string()),
996                ],
997                vec![
998                    Value::Utf8("A".to_string()),
999                    Value::Null,
1000                    Value::Utf8("q".to_string()),
1001                ],
1002            ],
1003        );
1004        let out = DataFrame::from_dataset(&ds)
1005            .unwrap()
1006            .group_by(
1007                &["g"],
1008                &[
1009                    Agg::Mean {
1010                        column: "x".to_string(),
1011                        alias: "mx".to_string(),
1012                    },
1013                    Agg::StdDev {
1014                        column: "x".to_string(),
1015                        alias: "sx".to_string(),
1016                        kind: VarianceKind::Sample,
1017                    },
1018                    Agg::CountDistinctNonNull {
1019                        column: "tag".to_string(),
1020                        alias: "dt".to_string(),
1021                    },
1022                ],
1023            )
1024            .unwrap()
1025            .collect()
1026            .unwrap();
1027        assert_eq!(out.row_count(), 1);
1028        assert_eq!(out.rows[0][0], Value::Utf8("A".to_string()));
1029        assert_eq!(out.rows[0][1], Value::Null);
1030        assert_eq!(out.rows[0][2], Value::Null);
1031        assert_eq!(out.rows[0][3], Value::Int64(2));
1032    }
1033}