1use crate::error::{IngestionError, IngestionResult};
61use crate::ingestion::polars_bridge::{
62 dataframe_to_dataset, dataset_to_dataframe, infer_schema_from_dataframe,
63 polars_error_to_ingestion,
64};
65use crate::processing::{FeatureMeanStd, ReduceOp, VarianceKind};
66use crate::types::{DataSet, DataType, Schema, Value};
67
68use polars::chunked_array::cast::CastOptions;
69use polars::prelude::*;
70use serde::{Deserialize, Serialize};
71
72const REDUCE_SCALAR_COL: &str = "__rust_dp_reduce_scalar";
73
74#[derive(Debug, Clone, PartialEq)]
76pub enum Predicate {
77 Eq { column: String, value: Value },
79 NotNull { column: String },
81 ModEqInt64 {
83 column: String,
84 modulus: i64,
85 equals: i64,
86 },
87}
88
/// Join strategies supported by `DataFrame::join`.
///
/// Each variant is mapped onto the Polars `JoinType` of the same name.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum JoinKind {
    /// Mapped to `JoinType::Inner`.
    Inner,
    /// Mapped to `JoinType::Left`.
    Left,
    /// Mapped to `JoinType::Right`.
    Right,
    /// Mapped to `JoinType::Full`.
    Full,
}
97
98#[derive(Debug, Clone, PartialEq)]
100pub enum Agg {
101 CountRows {
103 alias: String,
104 },
105 CountNotNull {
107 column: String,
108 alias: String,
109 },
110 Sum {
111 column: String,
112 alias: String,
113 },
114 Min {
115 column: String,
116 alias: String,
117 },
118 Max {
119 column: String,
120 alias: String,
121 },
122 Mean {
124 column: String,
125 alias: String,
126 },
127 Variance {
128 column: String,
129 alias: String,
130 kind: VarianceKind,
131 },
132 StdDev {
133 column: String,
134 alias: String,
135 kind: VarianceKind,
136 },
137 SumSquares {
138 column: String,
139 alias: String,
140 },
141 L2Norm {
142 column: String,
143 alias: String,
144 },
145 CountDistinctNonNull {
147 column: String,
148 alias: String,
149 },
150 Median {
152 column: String,
153 alias: String,
154 },
155}
156
157#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
159#[serde(rename_all = "snake_case")]
160pub enum CastMode {
161 #[default]
163 Strict,
164 Lossy,
166}
167
168#[derive(Clone)]
173pub struct DataFrame {
174 lf: LazyFrame,
175}
176
177impl DataFrame {
178 pub fn from_dataset(ds: &DataSet) -> IngestionResult<Self> {
183 let df = dataset_to_dataframe(ds)?;
184 Ok(Self { lf: df.lazy() })
185 }
186
187 pub fn filter(mut self, predicate: Predicate) -> IngestionResult<Self> {
189 let expr = match predicate {
190 Predicate::Eq { column, value } => match value {
191 Value::Null => col(&column).is_null(),
192 Value::Int64(x) => col(&column).eq(lit(x)),
193 Value::Float64(x) => col(&column).eq(lit(x)),
194 Value::Bool(x) => col(&column).eq(lit(x)),
195 Value::Utf8(s) => col(&column).eq(lit(s)),
196 },
197 Predicate::NotNull { column } => col(&column).is_not_null(),
198 Predicate::ModEqInt64 {
199 column,
200 modulus,
201 equals,
202 } => (col(&column) % lit(modulus)).eq(lit(equals)),
203 };
204 self.lf = self.lf.filter(expr);
206 Ok(self)
207 }
208
209 pub fn multiply_f64(mut self, column: &str, factor: f64) -> IngestionResult<Self> {
211 self.lf = self
213 .lf
214 .with_columns([(col(column) * lit(factor)).alias(column)]);
215 Ok(self)
216 }
217
218 pub fn add_f64(mut self, column: &str, delta: f64) -> IngestionResult<Self> {
220 self.lf = self
221 .lf
222 .with_columns([(col(column) + lit(delta)).alias(column)]);
223 Ok(self)
224 }
225
226 pub fn with_mul_f64(mut self, name: &str, source: &str, factor: f64) -> IngestionResult<Self> {
228 self.lf = self
229 .lf
230 .with_columns([(col(source) * lit(factor)).alias(name)]);
231 Ok(self)
232 }
233
234 pub fn with_add_f64(mut self, name: &str, source: &str, delta: f64) -> IngestionResult<Self> {
236 self.lf = self
237 .lf
238 .with_columns([(col(source) + lit(delta)).alias(name)]);
239 Ok(self)
240 }
241
242 pub fn select(mut self, columns: &[&str]) -> IngestionResult<Self> {
244 let exprs: Vec<Expr> = columns.iter().map(|c| col(*c)).collect();
245 self.lf = self.lf.select(exprs);
247 Ok(self)
248 }
249
250 pub fn rename(mut self, pairs: &[(&str, &str)]) -> IngestionResult<Self> {
254 let (existing, new): (Vec<&str>, Vec<&str>) = pairs.iter().copied().unzip();
255 self.lf = self.lf.rename(existing, new, true);
256 Ok(self)
257 }
258
259 pub fn cast(self, column: &str, to: DataType) -> IngestionResult<Self> {
263 self.cast_with_mode(column, to, CastMode::Strict)
264 }
265
266 pub fn cast_with_mode(
268 mut self,
269 column: &str,
270 to: DataType,
271 mode: CastMode,
272 ) -> IngestionResult<Self> {
273 let dt = to_polars_dtype(&to);
274 let expr = match mode {
275 CastMode::Strict => col(column).strict_cast(dt),
276 CastMode::Lossy => col(column).cast_with_options(dt, CastOptions::NonStrict),
277 }
278 .alias(column);
279 self.lf = self.lf.with_columns([expr]);
280 Ok(self)
281 }
282
283 pub fn drop(mut self, columns: &[&str]) -> IngestionResult<Self> {
285 let names: Vec<PlSmallStr> = columns.iter().map(|c| (*c).into()).collect();
286 let sel = Selector::ByName {
287 names: names.into(),
288 strict: true,
289 };
290 self.lf = self.lf.drop(sel);
291 Ok(self)
292 }
293
294 pub fn fill_null(mut self, column: &str, value: Value) -> IngestionResult<Self> {
296 let lit_expr = value_to_lit_expr(value)?;
297 self.lf = self
298 .lf
299 .with_columns([col(column).fill_null(lit_expr).alias(column)]);
300 Ok(self)
301 }
302
303 pub fn with_literal(mut self, name: &str, value: Value) -> IngestionResult<Self> {
305 let lit_expr = value_to_lit_expr(value)?;
306 self.lf = self.lf.with_columns([lit_expr.alias(name)]);
307 Ok(self)
308 }
309
310 pub fn group_by(mut self, keys: &[&str], aggs: &[Agg]) -> IngestionResult<Self> {
312 if keys.is_empty() {
313 return Err(IngestionError::SchemaMismatch {
314 message: "group_by requires at least one key column".to_string(),
315 });
316 }
317 if aggs.is_empty() {
318 return Err(IngestionError::SchemaMismatch {
319 message: "group_by requires at least one aggregation".to_string(),
320 });
321 }
322
323 let key_exprs: Vec<Expr> = keys.iter().map(|k| col(*k)).collect();
324 let agg_exprs: Vec<Expr> = aggs.iter().map(agg_to_expr).collect();
325 self.lf = self.lf.group_by(key_exprs).agg(agg_exprs);
326 Ok(self)
327 }
328
329 pub fn join(
333 mut self,
334 other: DataFrame,
335 left_on: &[&str],
336 right_on: &[&str],
337 how: JoinKind,
338 ) -> IngestionResult<Self> {
339 if left_on.is_empty() || right_on.is_empty() {
340 return Err(IngestionError::SchemaMismatch {
341 message: "join requires at least one join key on each side".to_string(),
342 });
343 }
344 if left_on.len() != right_on.len() {
345 return Err(IngestionError::SchemaMismatch {
346 message: format!(
347 "join requires left_on and right_on to have same length (left_on={}, right_on={})",
348 left_on.len(),
349 right_on.len()
350 ),
351 });
352 }
353
354 let left_exprs: Vec<Expr> = left_on.iter().map(|c| col(*c)).collect();
355 let right_exprs: Vec<Expr> = right_on.iter().map(|c| col(*c)).collect();
356
357 let how = match how {
358 JoinKind::Inner => JoinType::Inner,
359 JoinKind::Left => JoinType::Left,
360 JoinKind::Right => JoinType::Right,
361 JoinKind::Full => JoinType::Full,
362 };
363
364 self.lf = self
365 .lf
366 .join(other.lf, left_exprs, right_exprs, JoinArgs::new(how));
367 Ok(self)
368 }
369
370 pub fn collect(self) -> IngestionResult<DataSet> {
372 let df = self
373 .lf
374 .collect()
375 .map_err(|e| polars_error_to_ingestion("failed to collect polars lazy plan", e))?;
376 let out_schema = infer_schema_from_dataframe(&df)?;
377 dataframe_to_dataset(&df, &out_schema, "column", 1)
378 }
379
380 pub fn collect_with_schema(self, schema: &Schema) -> IngestionResult<DataSet> {
382 let df = self
383 .lf
384 .collect()
385 .map_err(|e| polars_error_to_ingestion("failed to collect polars lazy plan", e))?;
386 dataframe_to_dataset(&df, schema, "column", 1)
387 }
388
389 pub fn reduce(mut self, column: &str, op: ReduceOp) -> IngestionResult<Option<Value>> {
393 let df_schema = self
394 .lf
395 .collect_schema()
396 .map_err(|e| polars_error_to_ingestion("failed to collect polars schema", e))?;
397 if df_schema.get(column).is_none() {
398 return Ok(None);
399 }
400
401 let expr = polars_reduce_expr(column, op);
402 let df = self
403 .lf
404 .select([expr.alias(REDUCE_SCALAR_COL)])
405 .collect()
406 .map_err(|e| polars_error_to_ingestion("failed to collect polars reduce", e))?;
407
408 let s = df
409 .column(REDUCE_SCALAR_COL)
410 .map_err(|_| IngestionError::SchemaMismatch {
411 message: format!("missing reduce output column '{REDUCE_SCALAR_COL}'"),
412 })?
413 .as_materialized_series();
414 if s.is_empty() {
415 return Ok(Some(Value::Null));
416 }
417 let av = s.get(0).map_err(|e| IngestionError::SchemaMismatch {
418 message: format!("polars reduce output error: {e}"),
419 })?;
420 Ok(Some(anyvalue_to_value(av)))
421 }
422
423 pub fn sum(self, column: &str) -> IngestionResult<Option<Value>> {
427 self.reduce(column, ReduceOp::Sum)
428 }
429
430 pub fn feature_wise_mean_std(
435 mut self,
436 columns: &[&str],
437 std_kind: VarianceKind,
438 ) -> IngestionResult<Vec<(String, FeatureMeanStd)>> {
439 let df_schema = self
440 .lf
441 .collect_schema()
442 .map_err(|e| polars_error_to_ingestion("failed to collect polars schema", e))?;
443 for c in columns {
444 if df_schema.get(c).is_none() {
445 return Err(IngestionError::SchemaMismatch {
446 message: format!("feature_wise_mean_std: unknown column '{c}'"),
447 });
448 }
449 }
450 let ddof = match std_kind {
451 VarianceKind::Population => 0u8,
452 VarianceKind::Sample => 1u8,
453 };
454 use polars::datatypes::DataType as P;
455 let mut exprs: Vec<Expr> = Vec::with_capacity(columns.len() * 2);
456 for (i, c) in columns.iter().enumerate() {
457 let cf = col(*c).strict_cast(P::Float64);
458 exprs.push(cf.clone().mean().alias(format!("__fwm_{i}_mean").as_str()));
459 exprs.push(cf.std(ddof).alias(format!("__fwm_{i}_std").as_str()));
460 }
461 let df =
462 self.lf.select(exprs).collect().map_err(|e| {
463 polars_error_to_ingestion("failed to collect feature_wise_mean_std", e)
464 })?;
465
466 if df.height() == 0 {
467 return Ok(columns
468 .iter()
469 .map(|c| {
470 (
471 (*c).to_string(),
472 FeatureMeanStd {
473 mean: Value::Null,
474 std_dev: Value::Null,
475 },
476 )
477 })
478 .collect());
479 }
480
481 let mut out = Vec::with_capacity(columns.len());
482 for (i, _) in columns.iter().enumerate() {
483 let mean_s = df
484 .column(&format!("__fwm_{i}_mean"))
485 .map_err(|_| IngestionError::SchemaMismatch {
486 message: format!("missing __fwm_{i}_mean"),
487 })?
488 .as_materialized_series();
489 let std_s = df
490 .column(&format!("__fwm_{i}_std"))
491 .map_err(|_| IngestionError::SchemaMismatch {
492 message: format!("missing __fwm_{i}_std"),
493 })?
494 .as_materialized_series();
495 let mean_av = mean_s.get(0).map_err(|e| IngestionError::SchemaMismatch {
496 message: format!("feature_wise mean get: {e}"),
497 })?;
498 let std_av = std_s.get(0).map_err(|e| IngestionError::SchemaMismatch {
499 message: format!("feature_wise std get: {e}"),
500 })?;
501 out.push((
502 columns[i].to_string(),
503 FeatureMeanStd {
504 mean: anyvalue_to_value(mean_av),
505 std_dev: anyvalue_to_value(std_av),
506 },
507 ));
508 }
509 Ok(out)
510 }
511
512 pub(crate) fn lazy_clone(&self) -> LazyFrame {
513 self.lf.clone()
514 }
515
516 pub(crate) fn from_lazyframe(lf: LazyFrame) -> Self {
517 Self { lf }
518 }
519}
520
521fn polars_reduce_expr(column: &str, op: ReduceOp) -> Expr {
522 use polars::datatypes::DataType as P;
523 let c = col(column);
524 match op {
525 ReduceOp::Count => len(),
526 ReduceOp::Sum => c.sum(),
527 ReduceOp::Min => c.min(),
528 ReduceOp::Max => c.max(),
529 ReduceOp::Mean => c.clone().strict_cast(P::Float64).mean(),
530 ReduceOp::Variance(kind) => {
531 let ddof = match kind {
532 VarianceKind::Population => 0u8,
533 VarianceKind::Sample => 1u8,
534 };
535 c.clone().strict_cast(P::Float64).var(ddof)
536 }
537 ReduceOp::StdDev(kind) => {
538 let ddof = match kind {
539 VarianceKind::Population => 0u8,
540 VarianceKind::Sample => 1u8,
541 };
542 c.clone().strict_cast(P::Float64).std(ddof)
543 }
544 ReduceOp::SumSquares => c.clone().strict_cast(P::Float64).pow(lit(2.0)).sum(),
545 ReduceOp::L2Norm => c.clone().strict_cast(P::Float64).pow(lit(2.0)).sum().sqrt(),
546 ReduceOp::CountDistinctNonNull => c.drop_nulls().n_unique(),
547 ReduceOp::Median => c.clone().strict_cast(P::Float64).median(),
548 }
549}
550
551fn agg_to_expr(agg: &Agg) -> Expr {
552 use polars::datatypes::DataType as P;
553 match agg {
554 Agg::CountRows { alias } => len().alias(alias.as_str()),
555 Agg::CountNotNull { column, alias } => col(column.as_str()).count().alias(alias.as_str()),
556 Agg::Sum { column, alias } => col(column.as_str()).sum().alias(alias.as_str()),
557 Agg::Min { column, alias } => col(column.as_str()).min().alias(alias.as_str()),
558 Agg::Max { column, alias } => col(column.as_str()).max().alias(alias.as_str()),
559 Agg::Mean { column, alias } => col(column.as_str())
560 .strict_cast(P::Float64)
561 .mean()
562 .alias(alias.as_str()),
563 Agg::Variance {
564 column,
565 alias,
566 kind,
567 } => {
568 let ddof = match kind {
569 VarianceKind::Population => 0u8,
570 VarianceKind::Sample => 1u8,
571 };
572 col(column.as_str())
573 .strict_cast(P::Float64)
574 .var(ddof)
575 .alias(alias.as_str())
576 }
577 Agg::StdDev {
578 column,
579 alias,
580 kind,
581 } => {
582 let ddof = match kind {
583 VarianceKind::Population => 0u8,
584 VarianceKind::Sample => 1u8,
585 };
586 col(column.as_str())
587 .strict_cast(P::Float64)
588 .std(ddof)
589 .alias(alias.as_str())
590 }
591 Agg::SumSquares { column, alias } => col(column.as_str())
592 .strict_cast(P::Float64)
593 .pow(lit(2.0))
594 .sum()
595 .alias(alias.as_str()),
596 Agg::L2Norm { column, alias } => col(column.as_str())
597 .strict_cast(P::Float64)
598 .pow(lit(2.0))
599 .sum()
600 .sqrt()
601 .alias(alias.as_str()),
602 Agg::CountDistinctNonNull { column, alias } => col(column.as_str())
603 .drop_nulls()
604 .n_unique()
605 .alias(alias.as_str()),
606 Agg::Median { column, alias } => col(column.as_str())
607 .strict_cast(P::Float64)
608 .median()
609 .alias(alias.as_str()),
610 }
611}
612
613fn to_polars_dtype(dt: &DataType) -> polars::datatypes::DataType {
614 match dt {
615 DataType::Int64 => polars::datatypes::DataType::Int64,
616 DataType::Float64 => polars::datatypes::DataType::Float64,
617 DataType::Bool => polars::datatypes::DataType::Boolean,
618 DataType::Utf8 => polars::datatypes::DataType::String,
619 }
620}
621
622fn value_to_lit_expr(value: Value) -> IngestionResult<Expr> {
623 match value {
624 Value::Null => Err(IngestionError::SchemaMismatch {
625 message: "Value::Null is not supported as a literal expression; use fill_null or cast/collect to materialize".to_string(),
626 }),
627 Value::Int64(v) => Ok(lit(v)),
628 Value::Float64(v) => Ok(lit(v)),
629 Value::Bool(v) => Ok(lit(v)),
630 Value::Utf8(v) => Ok(lit(v)),
631 }
632}
633
634fn anyvalue_to_value(av: AnyValue) -> Value {
635 match av {
636 AnyValue::Null => Value::Null,
637 AnyValue::Int8(v) => Value::Int64(v as i64),
638 AnyValue::Int16(v) => Value::Int64(v as i64),
639 AnyValue::Int32(v) => Value::Int64(v as i64),
640 AnyValue::Int64(v) => Value::Int64(v),
641 AnyValue::UInt8(v) => Value::Int64(v as i64),
642 AnyValue::UInt16(v) => Value::Int64(v as i64),
643 AnyValue::UInt32(v) => Value::Int64(v as i64),
644 AnyValue::UInt64(v) => Value::Int64(v as i64),
645 AnyValue::Float64(v) => Value::Float64(v),
646 AnyValue::Boolean(v) => Value::Bool(v),
647 AnyValue::String(v) => Value::Utf8(v.to_string()),
648 AnyValue::StringOwned(v) => Value::Utf8(v.to_string()),
649 other => Value::Utf8(other.to_string()),
650 }
651}
652
653pub type PolarsPipeline = DataFrame;
655
#[cfg(test)]
mod tests {
    use super::{Agg, DataFrame, JoinKind, PolarsPipeline, Predicate};
    use crate::processing::{ReduceOp, VarianceKind, feature_wise_mean_std, filter, map, reduce};
    use crate::types::{DataSet, DataType, Field, Schema, Value};
    use std::collections::HashMap;

    /// Shorthand for an owned UTF-8 cell.
    fn text(s: &str) -> Value {
        Value::Utf8(s.to_string())
    }

    /// Wraps `ds` in a lazy pipeline, panicking if the conversion fails.
    fn lazy(ds: &DataSet) -> DataFrame {
        DataFrame::from_dataset(ds).unwrap()
    }

    /// Four rows of `(id, active, score)`; the last row has a null score.
    fn sample_dataset() -> DataSet {
        let fields = vec![
            Field::new("id", DataType::Int64),
            Field::new("active", DataType::Bool),
            Field::new("score", DataType::Float64),
        ];
        DataSet::new(
            Schema::new(fields),
            vec![
                vec![Value::Int64(1), Value::Bool(true), Value::Float64(10.0)],
                vec![Value::Int64(2), Value::Bool(true), Value::Float64(20.0)],
                vec![Value::Int64(3), Value::Bool(false), Value::Float64(30.0)],
                vec![Value::Int64(4), Value::Bool(true), Value::Null],
            ],
        )
    }

    #[test]
    fn polars_pipeline_filter_map_reduce_parity_with_in_memory() {
        let ds = sample_dataset();

        // Reference result from the in-memory filter/map/reduce primitives.
        let active_col = ds.schema.index_of("active").unwrap();
        let id_col = ds.schema.index_of("id").unwrap();
        let kept = filter(&ds, |row| {
            matches!(row.get(active_col), Some(Value::Bool(true)))
                && matches!(row.get(id_col), Some(Value::Int64(v)) if *v % 2 == 0)
        });
        let doubled = map(&kept, |row| {
            let mut cells = row.to_vec();
            if let Some(Value::Float64(score)) = cells.get_mut(2) {
                *score *= 2.0;
            }
            cells
        });
        let expected = reduce(&doubled, "score", ReduceOp::Sum).unwrap();

        // Same computation expressed as a lazy pipeline.
        let only_active = Predicate::Eq {
            column: "active".to_string(),
            value: Value::Bool(true),
        };
        let only_even_ids = Predicate::ModEqInt64 {
            column: "id".to_string(),
            modulus: 2,
            equals: 0,
        };
        let got = lazy(&ds)
            .filter(only_active)
            .and_then(|p| p.filter(only_even_ids))
            .and_then(|p| p.multiply_f64("score", 2.0))
            .and_then(|p| p.sum("score"))
            .unwrap()
            .unwrap();

        assert_eq!(got, expected);
    }

    #[test]
    fn polars_pipeline_reduce_parity_mean_variance_l2_distinct() {
        let ds = DataSet::new(
            Schema::new(vec![
                Field::new("x", DataType::Float64),
                Field::new("tag", DataType::Utf8),
            ]),
            vec![
                vec![Value::Float64(1.0), text("a")],
                vec![Value::Float64(2.0), text("b")],
                vec![Value::Null, text("a")],
            ],
        );

        // In-memory reference values, computed up front.
        let mean = reduce(&ds, "x", ReduceOp::Mean).unwrap();
        let var_pop = reduce(&ds, "x", ReduceOp::Variance(VarianceKind::Population)).unwrap();
        let l2 = reduce(&ds, "x", ReduceOp::L2Norm).unwrap();
        let distinct = reduce(&ds, "tag", ReduceOp::CountDistinctNonNull).unwrap();

        // Runs one reduction through the lazy pipeline.
        let via_polars =
            |column: &str, op: ReduceOp| lazy(&ds).reduce(column, op).unwrap().unwrap();

        assert_eq!(via_polars("x", ReduceOp::Mean), mean);
        assert_eq!(
            via_polars("x", ReduceOp::Variance(VarianceKind::Population)),
            var_pop
        );
        assert_eq!(via_polars("x", ReduceOp::L2Norm), l2);
        assert_eq!(via_polars("tag", ReduceOp::CountDistinctNonNull), distinct);
    }

    #[test]
    fn polars_pipeline_collect_select_works() {
        let ds = sample_dataset();
        let projected = lazy(&ds)
            .select(&["score", "id"])
            .and_then(DataFrame::collect)
            .unwrap();

        let names: Vec<_> = projected.schema.field_names().collect();
        assert_eq!(names, vec!["score", "id"]);
        assert_eq!(projected.row_count(), ds.row_count());

        let first_row = &projected.rows[0];
        assert_eq!(first_row[0], Value::Float64(10.0));
        assert_eq!(first_row[1], Value::Int64(1));
    }

    #[test]
    fn polars_pipeline_sum_returns_none_for_missing_column() {
        let ds = sample_dataset();
        let total = lazy(&ds).sum("missing").unwrap();
        assert_eq!(total, None);
    }

    #[test]
    fn polars_errors_are_preserved_as_engine_error_sources() {
        use crate::error::IngestionError;

        let ds = DataSet::new(
            Schema::new(vec![Field::new("name", DataType::Utf8)]),
            vec![vec![text("x")]],
        );

        // Multiplying a string column only fails once the plan is executed.
        let err = lazy(&ds)
            .multiply_f64("name", 2.0)
            .unwrap()
            .collect()
            .unwrap_err();

        match err {
            IngestionError::Engine { source, .. } => {
                assert!(!source.to_string().is_empty());
            }
            other => panic!("expected Engine error, got: {other:?}"),
        }
    }

    #[test]
    fn backwards_compatible_polars_pipeline_alias_exists() {
        let ds = sample_dataset();
        let _ = PolarsPipeline::from_dataset(&ds)
            .unwrap()
            .select(&["id"])
            .unwrap();
    }

    #[test]
    fn rename_cast_fill_null_group_by_and_join_work() {
        // rename -> cast -> fill_null
        {
            let ds = DataSet::new(
                Schema::new(vec![
                    Field::new("id", DataType::Int64),
                    Field::new("score", DataType::Int64),
                ]),
                vec![
                    vec![Value::Int64(1), Value::Int64(10)],
                    vec![Value::Int64(2), Value::Null],
                ],
            );

            let out = lazy(&ds)
                .rename(&[("score", "score_i")])
                .unwrap()
                .cast("score_i", DataType::Float64)
                .unwrap()
                .fill_null("score_i", Value::Float64(0.0))
                .unwrap()
                .collect()
                .unwrap();

            let names: Vec<_> = out.schema.field_names().collect();
            assert_eq!(names, vec!["id", "score_i"]);
            assert_eq!(out.rows[0][1], Value::Float64(10.0));
            assert_eq!(out.rows[1][1], Value::Float64(0.0));
        }

        // group_by with sum and row count
        {
            let ds = DataSet::new(
                Schema::new(vec![
                    Field::new("grp", DataType::Utf8),
                    Field::new("score", DataType::Float64),
                ]),
                vec![
                    vec![text("A"), Value::Float64(1.0)],
                    vec![text("A"), Value::Float64(2.0)],
                    vec![text("B"), Value::Null],
                ],
            );

            let aggs = [
                Agg::Sum {
                    column: "score".to_string(),
                    alias: "sum_score".to_string(),
                },
                Agg::CountRows {
                    alias: "cnt".to_string(),
                },
            ];
            let out = lazy(&ds)
                .group_by(&["grp"], &aggs)
                .unwrap()
                .collect()
                .unwrap();

            // Index by group key so the assertions do not depend on row order.
            let mut by_group: HashMap<String, (Value, Value)> = HashMap::new();
            for row in &out.rows {
                if let Value::Utf8(group) = &row[0] {
                    by_group.insert(group.clone(), (row[1].clone(), row[2].clone()));
                }
            }
            assert_eq!(
                by_group.get("A"),
                Some(&(Value::Float64(3.0), Value::Int64(2)))
            );
            assert_eq!(
                by_group.get("B"),
                Some(&(Value::Float64(0.0), Value::Int64(1)))
            );
        }

        // inner join on `id`
        {
            let left = DataSet::new(
                Schema::new(vec![
                    Field::new("id", DataType::Int64),
                    Field::new("name", DataType::Utf8),
                ]),
                vec![
                    vec![Value::Int64(1), text("Ada")],
                    vec![Value::Int64(2), text("Grace")],
                ],
            );
            let right = DataSet::new(
                Schema::new(vec![
                    Field::new("id", DataType::Int64),
                    Field::new("score", DataType::Float64),
                ]),
                vec![
                    vec![Value::Int64(1), Value::Float64(9.0)],
                    vec![Value::Int64(3), Value::Float64(7.0)],
                ],
            );

            let out = lazy(&left)
                .join(lazy(&right), &["id"], &["id"], JoinKind::Inner)
                .unwrap()
                .collect()
                .unwrap();
            assert_eq!(out.row_count(), 1);
            assert_eq!(out.rows[0][0], Value::Int64(1));
        }
    }

    #[test]
    fn polars_feature_wise_mean_std_matches_in_memory() {
        let ds = DataSet::new(
            Schema::new(vec![
                Field::new("a", DataType::Int64),
                Field::new("b", DataType::Float64),
            ]),
            vec![
                vec![Value::Int64(1), Value::Float64(10.0)],
                vec![Value::Int64(3), Value::Float64(20.0)],
            ],
        );
        let columns = ["a", "b"];

        let mem = feature_wise_mean_std(&ds, &columns, VarianceKind::Sample).unwrap();
        let pol = lazy(&ds)
            .feature_wise_mean_std(&columns, VarianceKind::Sample)
            .unwrap();

        assert_eq!(mem.len(), pol.len());
        for (expected, actual) in mem.iter().zip(pol.iter()) {
            assert_eq!(expected.0, actual.0);
            assert_eq!(expected.1.mean, actual.1.mean);
            // Standard deviations may differ by rounding, so compare with a tolerance.
            match (&expected.1.std_dev, &actual.1.std_dev) {
                (Value::Float64(m), Value::Float64(p)) => assert!((m - p).abs() < 1e-9),
                (a, b) => assert_eq!(a, b),
            }
        }
    }

    #[test]
    fn group_by_mean_std_count_distinct_all_null_numeric_is_null() {
        let ds = DataSet::new(
            Schema::new(vec![
                Field::new("g", DataType::Utf8),
                Field::new("x", DataType::Float64),
                Field::new("tag", DataType::Utf8),
            ]),
            vec![
                vec![text("A"), Value::Null, text("p")],
                vec![text("A"), Value::Null, text("q")],
            ],
        );

        let aggs = [
            Agg::Mean {
                column: "x".to_string(),
                alias: "mx".to_string(),
            },
            Agg::StdDev {
                column: "x".to_string(),
                alias: "sx".to_string(),
                kind: VarianceKind::Sample,
            },
            Agg::CountDistinctNonNull {
                column: "tag".to_string(),
                alias: "dt".to_string(),
            },
        ];
        let out = lazy(&ds).group_by(&["g"], &aggs).unwrap().collect().unwrap();

        assert_eq!(out.row_count(), 1);
        let row = &out.rows[0];
        assert_eq!(row[0], text("A"));
        assert_eq!(row[1], Value::Null);
        assert_eq!(row[2], Value::Null);
        assert_eq!(row[3], Value::Int64(2));
    }
}