Skip to main content

rust_data_processing/
export.rs

1//! Deterministic **JSON Lines** export and simple **train / test** row splits (Phase 2).
2//!
3//! This module does **not** implement tokenizers or model-specific chat templates. Callers align
4//! exported text with their trainer’s expected fields.
5
6use serde_json::{Map, Value as JsonValue};
7
8use crate::error::{IngestionError, IngestionResult};
9use crate::types::{DataSet, Value};
10
11fn cell_to_json(v: &Value) -> JsonValue {
12    match v {
13        Value::Null => JsonValue::Null,
14        Value::Int64(i) => JsonValue::from(*i),
15        Value::Float64(x) => JsonValue::from(*x),
16        Value::Bool(b) => JsonValue::from(*b),
17        Value::Utf8(s) => JsonValue::from(s.clone()),
18    }
19}
20
21/// Serialize each row as one JSON object per line (UTF-8), columns in `column_order`.
22///
23/// Column names must exist on `ds`. Row order is preserved.
24pub fn dataset_to_jsonl(ds: &DataSet, column_order: &[String]) -> IngestionResult<String> {
25    let idx: Vec<usize> = column_order
26        .iter()
27        .map(|name| {
28            ds.schema
29                .index_of(name)
30                .ok_or_else(|| IngestionError::SchemaMismatch {
31                    message: format!("dataset_to_jsonl: unknown column '{name}'"),
32                })
33        })
34        .collect::<Result<_, _>>()?;
35
36    let mut out = String::new();
37    for row in &ds.rows {
38        let mut m = Map::new();
39        for (name, &i) in column_order.iter().zip(&idx) {
40            m.insert(name.clone(), cell_to_json(&row[i]));
41        }
42        let line = serde_json::to_string(&JsonValue::Object(m)).map_err(|e| {
43            IngestionError::SchemaMismatch {
44                message: format!("dataset_to_jsonl: json encode failed: {e}"),
45            }
46        })?;
47        out.push_str(&line);
48        out.push('\n');
49    }
50    Ok(out)
51}
52
/// Split the row indices `0..row_count` into a leading train part and a trailing test part.
///
/// The test part holds `round(row_count * test_fraction)` indices, where `test_fraction` is first
/// clamped to `[0, 1]` and rounding is half away from zero; the count is capped at `row_count`.
/// Everything before it is the train part. No shuffling takes place, so the result depends only
/// on the two arguments.
///
/// A NaN `test_fraction` ends up as an empty test part, because the float-to-integer cast maps
/// NaN to zero.
///
/// Returns `(train_indices, test_indices)`, both in ascending order.
pub fn train_test_row_indices(row_count: usize, test_fraction: f64) -> (Vec<usize>, Vec<usize>) {
    let fraction = test_fraction.clamp(0.0, 1.0);
    // `as usize` saturates, and the `min` keeps the held-out count within the available rows.
    let held_out = ((fraction * row_count as f64).round() as usize).min(row_count);
    // `held_out <= row_count` holds here, so the subtraction cannot underflow.
    let boundary = row_count - held_out;
    (0..row_count).partition(|&index| index < boundary)
}
64
65/// Keep only rows whose UTF-8 value in `column` has at most `max_chars` Unicode scalars; other rows dropped.
66pub fn filter_rows_max_utf8_chars(
67    ds: &DataSet,
68    column: &str,
69    max_chars: usize,
70) -> IngestionResult<DataSet> {
71    let idx = ds
72        .schema
73        .index_of(column)
74        .ok_or_else(|| IngestionError::SchemaMismatch {
75            message: format!("filter_rows_max_utf8_chars: unknown column '{column}'"),
76        })?;
77    if ds.schema.fields[idx].data_type != crate::types::DataType::Utf8 {
78        return Err(IngestionError::SchemaMismatch {
79            message: format!("column '{column}' must be Utf8"),
80        });
81    }
82    let mut rows = Vec::new();
83    for row in &ds.rows {
84        match row.get(idx) {
85            Some(Value::Utf8(s)) if s.chars().count() <= max_chars => rows.push(row.clone()),
86            Some(Value::Null) | None => rows.push(row.clone()),
87            _ => {}
88        }
89    }
90    Ok(DataSet::new(ds.schema.clone(), rows))
91}
92
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{DataType, Field, Schema};

    /// Exports a three-row fixture and checks the exact JSONL text, then checks that a one-third
    /// test fraction holds out the last of three rows.
    #[test]
    fn jsonl_roundtrip_ordering_and_split() {
        let fields = vec![
            Field::new("a", DataType::Int64),
            Field::new("b", DataType::Utf8),
        ];
        let rows = vec![
            vec![Value::Int64(1), Value::Utf8("x".into())],
            vec![Value::Int64(2), Value::Utf8("yy".into())],
            vec![Value::Int64(3), Value::Utf8("zzz".into())],
        ];
        let ds = DataSet::new(Schema::new(fields), rows);

        let columns = vec![String::from("a"), String::from("b")];
        let exported = dataset_to_jsonl(&ds, &columns).unwrap();
        let expected = concat!(
            r#"{"a":1,"b":"x"}"#,
            "\n",
            r#"{"a":2,"b":"yy"}"#,
            "\n",
            r#"{"a":3,"b":"zzz"}"#,
            "\n",
        );
        assert_eq!(exported, expected);

        let (train, test) = train_test_row_indices(3, 1.0 / 3.0);
        assert_eq!(train, vec![0, 1]);
        assert_eq!(test, vec![2]);
    }

    /// With a limit of two characters, only the two-character row survives.
    #[test]
    fn filter_max_chars_drops_long_rows() {
        let rows = vec![
            vec![Value::Utf8("ab".into())],
            vec![Value::Utf8("abc".into())],
        ];
        let ds = DataSet::new(Schema::new(vec![Field::new("s", DataType::Utf8)]), rows);

        let filtered = filter_rows_max_utf8_chars(&ds, "s", 2).unwrap();
        assert_eq!(filtered.row_count(), 1);
    }

    /// A dataset without rows exports as the empty string.
    #[test]
    fn jsonl_empty_dataset() {
        let ds = DataSet::new(
            Schema::new(vec![Field::new("id", DataType::Int64)]),
            vec![],
        );
        let columns = vec![String::from("id")];
        assert_eq!(dataset_to_jsonl(&ds, &columns).unwrap(), "");
    }
}