1use crate::error::{IngestionError, IngestionResult};
64use crate::pipeline::{CastMode, DataFrame};
65use crate::types::{DataSet, DataType, Schema, Value};
66use serde::{Deserialize, Serialize};
67use sha2::{Digest, Sha256};
68
69#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
71pub enum TransformStep {
72 Select { columns: Vec<String> },
74 Drop { columns: Vec<String> },
76 Rename { pairs: Vec<(String, String)> },
78 Cast {
80 column: String,
81 to: DataType,
82 #[serde(default)]
83 mode: CastMode,
84 },
85 FillNull { column: String, value: Value },
87 WithLiteral { name: String, value: Value },
89 DeriveMulF64 {
91 name: String,
92 source: String,
93 factor: f64,
94 },
95 DeriveAddF64 {
97 name: String,
98 source: String,
99 delta: f64,
100 },
101 Utf8Truncate { column: String, max_chars: usize },
103 Utf8Sha256Hex { column: String },
105 Utf8RedactMiddle {
107 column: String,
108 keep_left: usize,
109 keep_right: usize,
110 redaction: String,
112 },
113}
114
115#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
121pub struct TransformSpec {
122 pub output_schema: Schema,
123 pub steps: Vec<TransformStep>,
124}
125
126impl TransformSpec {
127 pub fn new(output_schema: Schema) -> Self {
128 Self {
129 output_schema,
130 steps: Vec::new(),
131 }
132 }
133
134 pub fn with_step(mut self, step: TransformStep) -> Self {
135 self.steps.push(step);
136 self
137 }
138
139 pub fn apply(&self, input: &DataSet) -> IngestionResult<DataSet> {
141 let mut df = DataFrame::from_dataset(input)?;
142
143 for step in &self.steps {
144 df = match step {
145 TransformStep::Select { columns } => {
146 let cols: Vec<&str> = columns.iter().map(|s| s.as_str()).collect();
147 df.select(&cols)?
148 }
149 TransformStep::Drop { columns } => {
150 let cols: Vec<&str> = columns.iter().map(|s| s.as_str()).collect();
151 df.drop(&cols)?
152 }
153 TransformStep::Rename { pairs } => {
154 let pairs_ref: Vec<(&str, &str)> = pairs
155 .iter()
156 .map(|(a, b)| (a.as_str(), b.as_str()))
157 .collect();
158 df.rename(&pairs_ref)?
159 }
160 TransformStep::Cast { column, to, mode } => {
161 df.cast_with_mode(column, to.clone(), *mode)?
162 }
163 TransformStep::FillNull { column, value } => df.fill_null(column, value.clone())?,
164 TransformStep::WithLiteral { name, value } => {
165 df.with_literal(name, value.clone())?
166 }
167 TransformStep::DeriveMulF64 {
168 name,
169 source,
170 factor,
171 } => df.with_mul_f64(name, source, *factor)?,
172 TransformStep::DeriveAddF64 {
173 name,
174 source,
175 delta,
176 } => df.with_add_f64(name, source, *delta)?,
177 TransformStep::Utf8Truncate { column, max_chars } => {
178 Self::apply_utf8_dataset_step(df, |ds| {
179 utf8_truncate_dataset(ds, column, *max_chars)
180 })?
181 }
182 TransformStep::Utf8Sha256Hex { column } => {
183 Self::apply_utf8_dataset_step(df, |ds| utf8_sha256_dataset(ds, column))?
184 }
185 TransformStep::Utf8RedactMiddle {
186 column,
187 keep_left,
188 keep_right,
189 redaction,
190 } => Self::apply_utf8_dataset_step(df, |ds| {
191 utf8_redact_middle_dataset(ds, column, *keep_left, *keep_right, redaction)
192 })?,
193 };
194 }
195
196 df.collect_with_schema(&self.output_schema)
197 }
198
199 fn apply_utf8_dataset_step<F>(df: DataFrame, mut f: F) -> IngestionResult<DataFrame>
200 where
201 F: FnMut(&mut DataSet) -> IngestionResult<()>,
202 {
203 let mut ds = df.collect()?;
204 f(&mut ds)?;
205 DataFrame::from_dataset(&ds)
206 }
207}
208
209fn utf8_field_index(ds: &DataSet, column: &str) -> IngestionResult<usize> {
210 let idx = ds
211 .schema
212 .index_of(column)
213 .ok_or_else(|| IngestionError::SchemaMismatch {
214 message: format!("unknown column '{column}' for UTF-8 transform"),
215 })?;
216 if ds.schema.fields[idx].data_type != DataType::Utf8 {
217 return Err(IngestionError::SchemaMismatch {
218 message: format!("column '{column}' must be Utf8 for this transform"),
219 });
220 }
221 Ok(idx)
222}
223
224fn utf8_truncate_dataset(ds: &mut DataSet, column: &str, max_chars: usize) -> IngestionResult<()> {
225 let idx = utf8_field_index(ds, column)?;
226 for row in &mut ds.rows {
227 if let Value::Utf8(s) = &mut row[idx] {
228 let t: String = s.chars().take(max_chars).collect();
229 *s = t;
230 }
231 }
232 Ok(())
233}
234
235fn utf8_sha256_dataset(ds: &mut DataSet, column: &str) -> IngestionResult<()> {
236 use std::fmt::Write as _;
237 let idx = utf8_field_index(ds, column)?;
238 for row in &mut ds.rows {
239 if let Value::Utf8(s) = &mut row[idx] {
240 let mut h = Sha256::new();
241 h.update(s.as_bytes());
242 let out = h.finalize();
243 let mut hex = String::with_capacity(64);
244 for b in out.iter() {
245 let _ = write!(&mut hex, "{b:02x}");
246 }
247 *s = hex;
248 }
249 }
250 Ok(())
251}
252
253fn utf8_redact_middle_dataset(
254 ds: &mut DataSet,
255 column: &str,
256 keep_left: usize,
257 keep_right: usize,
258 redaction: &str,
259) -> IngestionResult<()> {
260 let idx = utf8_field_index(ds, column)?;
261 for row in &mut ds.rows {
262 if let Value::Utf8(s) = &mut row[idx] {
263 let chs: Vec<char> = s.chars().collect();
264 let n = chs.len();
265 if n > keep_left + keep_right {
266 let left: String = chs.iter().take(keep_left).collect();
267 let right: String = chs.iter().skip(n.saturating_sub(keep_right)).collect();
268 *s = format!("{left}{redaction}{right}");
269 }
270 }
271 }
272 Ok(())
273}
274
/// Conversions between the crate's row-oriented `DataSet` and Arrow
/// `RecordBatch` values.
///
/// Only `Int64`, `Float64`, `Bool` and `Utf8` columns are handled. When reading
/// from Arrow, both `Utf8` and `LargeUtf8` arrays are accepted for `Utf8` fields.
#[cfg(feature = "arrow")]
pub mod arrow {
    use std::sync::Arc;

    use arrow::array::{
        Array, ArrayRef, BooleanArray, Float64Array, Int64Array, LargeStringArray, StringArray,
    };
    use arrow::compute::concat_batches;
    use arrow::datatypes::{DataType as ArrowDataType, Field, Schema as ArrowSchema};
    use arrow::record_batch::RecordBatch;

    use crate::error::{IngestionError, IngestionResult};
    use crate::types::{DataSet, DataType, Field as DsField, Schema, Value};

    /// Derives a crate `Schema` from the Arrow schema attached to `batch`.
    ///
    /// # Errors
    ///
    /// Returns `IngestionError::SchemaMismatch` for the first column whose Arrow
    /// type is not `Int64`, `Float64`, `Boolean`, `Utf8` or `LargeUtf8`.
    pub fn schema_from_record_batch(batch: &RecordBatch) -> IngestionResult<Schema> {
        let arrow_schema = batch.schema();
        let fields = arrow_schema
            .fields()
            .iter()
            .map(|f| {
                let dt = match f.data_type() {
                    ArrowDataType::Int64 => DataType::Int64,
                    ArrowDataType::Float64 => DataType::Float64,
                    ArrowDataType::Boolean => DataType::Bool,
                    ArrowDataType::Utf8 | ArrowDataType::LargeUtf8 => DataType::Utf8,
                    other => {
                        return Err(IngestionError::SchemaMismatch {
                            message: format!("unsupported Arrow dtype in schema: {other:?}"),
                        });
                    }
                };
                Ok(DsField::new(f.name().to_string(), dt))
            })
            .collect::<IngestionResult<Vec<_>>>()?;
        Ok(Schema::new(fields))
    }

    /// Gathers column `col_idx` of `ds` into a vector of optional values.
    ///
    /// `Value::Null` and cells missing from a short row become `None`. For any
    /// other cell, `extract` returns the typed payload or `None` when the cell
    /// has the wrong variant, which is reported as a `ParseError` carrying the
    /// 1-based position of the offending row within `ds.rows`.
    fn column_values<'a, T>(
        ds: &'a DataSet,
        col_idx: usize,
        column: &str,
        type_name: &str,
        extract: impl Fn(&'a Value) -> Option<T>,
    ) -> IngestionResult<Vec<Option<T>>> {
        ds.rows
            .iter()
            .enumerate()
            .map(|(row_idx, row)| match row.get(col_idx) {
                Some(Value::Null) | None => Ok(None),
                Some(cell) => match extract(cell) {
                    Some(value) => Ok(Some(value)),
                    None => Err(IngestionError::ParseError {
                        // NOTE(review): 1-based data-row position; confirm this
                        // matches the numbering used by other ParseError sites.
                        row: row_idx + 1,
                        column: column.to_string(),
                        raw: format!("{cell:?}"),
                        message: format!("value does not match schema type {type_name}"),
                    }),
                },
            })
            .collect()
    }

    /// Builds an Arrow `RecordBatch` with one nullable column per schema field.
    ///
    /// Nulls, and cells missing because a row is shorter than the schema, are
    /// written as Arrow nulls.
    ///
    /// # Errors
    ///
    /// * `IngestionError::ParseError` when a cell's variant does not match the
    ///   declared type of its column; `row` identifies the offending row.
    /// * `IngestionError::Engine` when Arrow rejects the assembled batch.
    pub fn dataset_to_record_batch(ds: &DataSet) -> IngestionResult<RecordBatch> {
        let ncols = ds.schema.fields.len();
        let mut arrow_fields = Vec::with_capacity(ncols);
        let mut cols: Vec<ArrayRef> = Vec::with_capacity(ncols);

        for (col_idx, field) in ds.schema.fields.iter().enumerate() {
            let (array, arrow_type) = match field.data_type {
                DataType::Int64 => {
                    let values = column_values(ds, col_idx, &field.name, "Int64", |v| match v {
                        Value::Int64(x) => Some(*x),
                        _ => None,
                    })?;
                    (
                        Arc::new(Int64Array::from(values)) as ArrayRef,
                        ArrowDataType::Int64,
                    )
                }
                DataType::Float64 => {
                    let values =
                        column_values(ds, col_idx, &field.name, "Float64", |v| match v {
                            Value::Float64(x) => Some(*x),
                            _ => None,
                        })?;
                    (
                        Arc::new(Float64Array::from(values)) as ArrayRef,
                        ArrowDataType::Float64,
                    )
                }
                DataType::Bool => {
                    let values = column_values(ds, col_idx, &field.name, "Bool", |v| match v {
                        Value::Bool(x) => Some(*x),
                        _ => None,
                    })?;
                    (
                        Arc::new(BooleanArray::from(values)) as ArrayRef,
                        ArrowDataType::Boolean,
                    )
                }
                DataType::Utf8 => {
                    let values = column_values(ds, col_idx, &field.name, "Utf8", |v| match v {
                        Value::Utf8(x) => Some(x.as_str()),
                        _ => None,
                    })?;
                    (
                        Arc::new(StringArray::from(values)) as ArrayRef,
                        ArrowDataType::Utf8,
                    )
                }
            };
            cols.push(array);
            arrow_fields.push(Field::new(&field.name, arrow_type, true));
        }

        let schema = Arc::new(ArrowSchema::new(arrow_fields));
        RecordBatch::try_new(schema, cols).map_err(|e| IngestionError::Engine {
            message: "failed to build Arrow RecordBatch".to_string(),
            source: Box::new(e),
        })
    }

    /// An Arrow column downcast once to the concrete array type expected by the
    /// target schema, so that cells can be read without a downcast per cell.
    enum TypedColumn<'a> {
        Int64(&'a Int64Array),
        Float64(&'a Float64Array),
        Bool(&'a BooleanArray),
        Utf8(&'a StringArray),
        LargeUtf8(&'a LargeStringArray),
    }

    impl<'a> TypedColumn<'a> {
        /// Downcasts `array` to the array type implied by `field.data_type`.
        fn resolve(array: &'a ArrayRef, field: &DsField) -> IngestionResult<Self> {
            let any = array.as_any();
            let mismatch = |expected: &str| IngestionError::SchemaMismatch {
                message: format!("arrow column '{}' is not {expected}", field.name),
            };
            match field.data_type {
                DataType::Int64 => any
                    .downcast_ref::<Int64Array>()
                    .map(TypedColumn::Int64)
                    .ok_or_else(|| mismatch("Int64")),
                DataType::Float64 => any
                    .downcast_ref::<Float64Array>()
                    .map(TypedColumn::Float64)
                    .ok_or_else(|| mismatch("Float64")),
                DataType::Bool => any
                    .downcast_ref::<BooleanArray>()
                    .map(TypedColumn::Bool)
                    .ok_or_else(|| mismatch("Boolean")),
                DataType::Utf8 => any
                    .downcast_ref::<StringArray>()
                    .map(TypedColumn::Utf8)
                    .or_else(|| {
                        any.downcast_ref::<LargeStringArray>()
                            .map(TypedColumn::LargeUtf8)
                    })
                    .ok_or_else(|| mismatch("Utf8/LargeUtf8")),
            }
        }

        /// Reads the cell at `row`, mapping Arrow nulls to `Value::Null`.
        fn value_at(&self, row: usize) -> Value {
            match self {
                TypedColumn::Int64(a) => {
                    if a.is_null(row) {
                        Value::Null
                    } else {
                        Value::Int64(a.value(row))
                    }
                }
                TypedColumn::Float64(a) => {
                    if a.is_null(row) {
                        Value::Null
                    } else {
                        Value::Float64(a.value(row))
                    }
                }
                TypedColumn::Bool(a) => {
                    if a.is_null(row) {
                        Value::Null
                    } else {
                        Value::Bool(a.value(row))
                    }
                }
                TypedColumn::Utf8(a) => {
                    if a.is_null(row) {
                        Value::Null
                    } else {
                        Value::Utf8(a.value(row).to_string())
                    }
                }
                TypedColumn::LargeUtf8(a) => {
                    if a.is_null(row) {
                        Value::Null
                    } else {
                        Value::Utf8(a.value(row).to_string())
                    }
                }
            }
        }
    }

    /// Converts `batch` into a `DataSet` laid out according to `schema`.
    ///
    /// Columns are looked up by name, so the batch may order its columns
    /// differently from `schema` and may contain columns that `schema` does not
    /// mention; those are not read.
    ///
    /// # Errors
    ///
    /// Returns `IngestionError::SchemaMismatch` when a schema field has no
    /// column of the same name in the batch, or when a non-empty batch holds a
    /// column whose Arrow array type does not match the field's declared type.
    pub fn record_batch_to_dataset(
        batch: &RecordBatch,
        schema: &Schema,
    ) -> IngestionResult<DataSet> {
        let arrow_schema = batch.schema();
        let positions = schema
            .fields
            .iter()
            .map(|f| {
                arrow_schema
                    .index_of(&f.name)
                    .map_err(|_| IngestionError::SchemaMismatch {
                        message: format!("missing required column '{}'", f.name),
                    })
            })
            .collect::<IngestionResult<Vec<usize>>>()?;

        let nrows = batch.num_rows();
        if nrows == 0 {
            // A batch without rows has no cells to convert. Column types are
            // deliberately not checked here, so empty batches keep succeeding
            // as long as every required column name is present.
            return Ok(DataSet::new(schema.clone(), Vec::new()));
        }

        // Downcast each column once instead of once per cell.
        let columns = schema
            .fields
            .iter()
            .zip(positions.iter().copied())
            .map(|(field, pos)| TypedColumn::resolve(batch.column(pos), field))
            .collect::<IngestionResult<Vec<_>>>()?;

        let rows: Vec<Vec<Value>> = (0..nrows)
            .map(|row_i| columns.iter().map(|col| col.value_at(row_i)).collect())
            .collect();
        Ok(DataSet::new(schema.clone(), rows))
    }

    /// Converts several batches into a single `DataSet`, preserving batch order.
    ///
    /// An empty slice yields an empty data set with `schema`.
    ///
    /// # Errors
    ///
    /// * `IngestionError::SchemaMismatch` when the batches do not all share the
    ///   Arrow schema of the first one, or when the conversion of the combined
    ///   batch fails (see [`record_batch_to_dataset`]).
    /// * `IngestionError::Engine` when Arrow fails to concatenate the batches.
    pub fn record_batches_to_dataset(
        batches: &[RecordBatch],
        schema: &Schema,
    ) -> IngestionResult<DataSet> {
        let Some((first, rest)) = batches.split_first() else {
            return Ok(DataSet::new(schema.clone(), Vec::new()));
        };

        let sch_ref = first.schema();
        if rest
            .iter()
            .any(|b| b.schema().as_ref() != sch_ref.as_ref())
        {
            return Err(IngestionError::SchemaMismatch {
                message: "record_batches_to_dataset: all batches must share the same Arrow schema"
                    .to_string(),
            });
        }

        if rest.is_empty() {
            // A single batch needs no concatenation (and no clone).
            return record_batch_to_dataset(first, schema);
        }

        let merged = concat_batches(&sch_ref, batches).map_err(|e| IngestionError::Engine {
            message: "arrow concat_batches failed".to_string(),
            source: Box::new(e),
        })?;
        record_batch_to_dataset(&merged, schema)
    }
}
522
/// Bridges between serde-serializable record types and Arrow `RecordBatch`
/// values using the `serde_arrow` crate.
#[cfg(feature = "serde_arrow")]
pub mod serde_interop {
    use arrow::datatypes::FieldRef;
    use arrow::record_batch::RecordBatch;
    use serde_arrow::schema::{SchemaLike, TracingOptions};

    use crate::error::{IngestionError, IngestionResult};

    /// Serializes `records` into a single Arrow `RecordBatch`.
    ///
    /// The Arrow fields are traced from the type `T` itself (with default
    /// tracing options), not from the values in `records`.
    ///
    /// Accepts any slice; a `&Vec<T>` argument still works through deref
    /// coercion.
    ///
    /// # Errors
    ///
    /// Returns `IngestionError::Engine` when the schema cannot be traced from
    /// `T` or when the records cannot be converted.
    pub fn to_record_batch<T>(records: &[T]) -> IngestionResult<RecordBatch>
    where
        T: serde::Serialize + for<'de> serde::Deserialize<'de>,
    {
        let fields = Vec::<FieldRef>::from_type::<T>(TracingOptions::default()).map_err(|e| {
            IngestionError::Engine {
                message: "failed to trace Arrow schema from type".to_string(),
                source: Box::new(e),
            }
        })?;

        serde_arrow::to_record_batch(&fields, records).map_err(|e| IngestionError::Engine {
            message: "failed to convert records to Arrow RecordBatch".to_string(),
            source: Box::new(e),
        })
    }

    /// Deserializes every row of `batch` into a `T`.
    ///
    /// # Errors
    ///
    /// Returns `IngestionError::Engine` when `serde_arrow` cannot deserialize
    /// the batch into `Vec<T>`.
    pub fn from_record_batch<T>(batch: &RecordBatch) -> IngestionResult<Vec<T>>
    where
        T: serde::de::DeserializeOwned,
    {
        serde_arrow::from_record_batch(batch).map_err(|e| IngestionError::Engine {
            message: "failed to deserialize records from Arrow RecordBatch".to_string(),
            source: Box::new(e),
        })
    }
}
563
#[cfg(test)]
mod tests {
    use super::{TransformSpec, TransformStep};
    use crate::pipeline::CastMode;
    use crate::types::{DataSet, DataType, Field, Schema, Value};

    /// Two rows of `(id, score)`; the second row has a null score.
    fn sample_dataset() -> DataSet {
        let schema = Schema::new(vec![
            Field::new("id", DataType::Int64),
            Field::new("score", DataType::Int64),
        ]);
        let rows = vec![
            vec![Value::Int64(1), Value::Int64(10)],
            vec![Value::Int64(2), Value::Null],
        ];
        DataSet::new(schema, rows)
    }

    #[test]
    fn transform_spec_can_rename_cast_fill_and_derive() {
        let ds = sample_dataset();

        let out_schema = Schema::new(vec![
            Field::new("id", DataType::Int64),
            Field::new("score_x2", DataType::Float64),
            Field::new("score_f", DataType::Float64),
            Field::new("tag", DataType::Utf8),
        ]);

        let spec = TransformSpec::new(out_schema.clone())
            .with_step(TransformStep::Rename {
                pairs: vec![("score".to_string(), "score_f".to_string())],
            })
            .with_step(TransformStep::Cast {
                column: "score_f".to_string(),
                to: DataType::Float64,
                mode: CastMode::Strict,
            })
            .with_step(TransformStep::FillNull {
                column: "score_f".to_string(),
                value: Value::Float64(0.0),
            })
            .with_step(TransformStep::DeriveMulF64 {
                name: "score_x2".to_string(),
                source: "score_f".to_string(),
                factor: 2.0,
            })
            .with_step(TransformStep::WithLiteral {
                name: "tag".to_string(),
                value: Value::Utf8("A".to_string()),
            })
            .with_step(TransformStep::Select {
                columns: vec![
                    "id".to_string(),
                    "score_x2".to_string(),
                    "score_f".to_string(),
                    "tag".to_string(),
                ],
            });

        let out = spec.apply(&ds).unwrap();
        assert_eq!(out.schema, out_schema);
        assert_eq!(out.row_count(), 2);

        // Expected cells, row by row, in output-schema column order.
        let expected = [
            [
                Value::Int64(1),
                Value::Float64(20.0),
                Value::Float64(10.0),
                Value::Utf8("A".to_string()),
            ],
            [
                Value::Int64(2),
                Value::Float64(0.0),
                Value::Float64(0.0),
                Value::Utf8("A".to_string()),
            ],
        ];
        for (r, want_row) in expected.iter().enumerate() {
            for (c, want) in want_row.iter().enumerate() {
                assert_eq!(&out.rows[r][c], want, "row {r}, column {c}");
            }
        }
    }

    #[test]
    fn utf8_privacy_transforms_apply() {
        let schema = Schema::new(vec![Field::new("s", DataType::Utf8)]);
        let ds = DataSet::new(
            schema.clone(),
            vec![
                vec![Value::Utf8("abcdef".into())],
                vec![Value::Utf8("hi".into())],
            ],
        );
        let out_schema = schema.clone();
        let spec = TransformSpec::new(out_schema)
            .with_step(TransformStep::Utf8Truncate {
                column: "s".into(),
                max_chars: 3,
            })
            .with_step(TransformStep::Utf8RedactMiddle {
                column: "s".into(),
                keep_left: 1,
                keep_right: 1,
                redaction: "***".into(),
            });
        let out = spec.apply(&ds).unwrap();
        // "abcdef" -> "abc" -> "a***c"; "hi" is too short to redact.
        assert_eq!(out.rows[0][0], Value::Utf8("a***c".into()));
        assert_eq!(out.rows[1][0], Value::Utf8("hi".into()));

        let ds2 = DataSet::new(
            Schema::new(vec![Field::new("s", DataType::Utf8)]),
            vec![vec![Value::Utf8("abc".into())]],
        );
        let spec2 = TransformSpec::new(ds2.schema.clone())
            .with_step(TransformStep::Utf8Sha256Hex { column: "s".into() });
        let h = spec2.apply(&ds2).unwrap().rows[0][0].clone();
        let Value::Utf8(hex) = h else {
            panic!("expected utf8");
        };
        assert_eq!(hex.len(), 64);
        // Standard SHA-256 test vector for the message "abc"; a length check
        // alone would accept any 64-character output.
        assert_eq!(
            hex,
            "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"
        );
    }
}