
Commit ae74c10

feat(connect): read/write → csv, write → json (Eventual-Inc#3361)
1 parent ea8f8bd commit ae74c10

3 files changed (+109, -14 lines)

src/daft-connect/src/op/execute/write.rs

Lines changed: 2 additions & 4 deletions
@@ -55,9 +55,7 @@ impl Session {
             bail!("Source is required");
         };
 
-        if source != "parquet" {
-            bail!("Unsupported source: {source}; only parquet is supported");
-        }
+        let file_format: FileFormat = source.parse()?;
 
         let Ok(mode) = SaveMode::try_from(mode) else {
             bail!("Invalid save mode: {mode}");
@@ -115,7 +113,7 @@ impl Session {
         let plan = translator.to_logical_plan(input).await?;
 
         let plan = plan
-            .table_write(&path, FileFormat::Parquet, None, None, None)
+            .table_write(&path, file_format, None, None, None)
             .wrap_err("Failed to create table write plan")?;
 
         let optimized_plan = plan.optimize()?;
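
For orientation, a minimal client-side sketch of what this change enables. It assumes a Spark Connect session backed by Daft, in the style of the `spark_session` pytest fixture used by the tests below; it is illustrative only, not part of the commit. The write path now accepts any source string that parses into a `FileFormat`, so CSV and JSON writes reach `table_write` instead of bailing on everything except Parquet.

import os


def write_examples(spark_session, tmp_path):
    # Illustrative sketch; assumes a Daft-backed Spark Connect session fixture.
    df = spark_session.range(10)

    # "csv" now parses into a FileFormat and reaches table_write.
    df.write.csv(os.path.join(tmp_path, "csv_out"))

    # "json" is parsed the same way (the "write → json" part of the commit title).
    df.write.json(os.path.join(tmp_path, "json_out"))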

src/daft-connect/src/translation/logical_plan/read/data_source.rs

Lines changed: 19 additions & 10 deletions
@@ -1,5 +1,5 @@
 use daft_logical_plan::LogicalPlanBuilder;
-use daft_scan::builder::ParquetScanBuilder;
+use daft_scan::builder::{CsvScanBuilder, ParquetScanBuilder};
 use eyre::{bail, ensure, WrapErr};
 use tracing::warn;
 
@@ -18,10 +18,6 @@ pub async fn data_source(
         bail!("Format is required");
     };
 
-    if format != "parquet" {
-        bail!("Unsupported format: {format}; only parquet is supported");
-    }
-
     ensure!(!paths.is_empty(), "Paths are required");
 
     if let Some(schema) = schema {
@@ -36,10 +32,23 @@
         warn!("Ignoring predicates: {predicates:?}; not yet implemented");
     }
 
-    let builder = ParquetScanBuilder::new(paths)
-        .finish()
-        .await
-        .wrap_err("Failed to create parquet scan builder")?;
+    let plan = match &*format {
+        "parquet" => ParquetScanBuilder::new(paths)
+            .finish()
+            .await
+            .wrap_err("Failed to create parquet scan builder")?,
+        "csv" => CsvScanBuilder::new(paths)
+            .finish()
+            .await
+            .wrap_err("Failed to create csv scan builder")?,
+        "json" => {
+            // todo(completeness): implement json reading
+            bail!("json reading is not yet implemented");
+        }
+        other => {
+            bail!("Unsupported format: {other}; only parquet and csv are supported");
+        }
+    };
 
-    Ok(builder)
+    Ok(plan)
 }
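
On the read side, dispatch is now driven by the format string: "csv" resolves to CsvScanBuilder, "parquet" keeps its existing path, and "json" still errors until the todo above is implemented. A hedged client-side sketch, again assuming the Daft-backed `spark_session` fixture and not part of the commit itself:

import os


def read_examples(spark_session, tmp_path):
    # Illustrative sketch; assumes a Daft-backed Spark Connect session fixture.
    df = spark_session.range(10)
    csv_dir = os.path.join(tmp_path, "csv")
    df.write.csv(csv_dir)

    # "csv" now dispatches to CsvScanBuilder on the server side.
    read_back = spark_session.read.csv(csv_dir)
    assert len(read_back.toPandas()) == 10

    # "json" reads still fail with "json reading is not yet implemented",
    # and any other format hits the "Unsupported format" branch.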

tests/connect/test_csv.py

Lines changed: 88 additions & 0 deletions
@@ -0,0 +1,88 @@ (new file)
from __future__ import annotations

import os

import pytest


def test_write_csv_basic(spark_session, tmp_path):
    df = spark_session.range(10)
    csv_dir = os.path.join(tmp_path, "csv")
    df.write.csv(csv_dir)

    csv_files = [f for f in os.listdir(csv_dir) if f.endswith(".csv")]
    assert len(csv_files) > 0, "Expected at least one CSV file to be written"

    df_read = spark_session.read.csv(str(csv_dir))
    df_pandas = df.toPandas()
    df_read_pandas = df_read.toPandas()
    assert df_pandas["id"].equals(df_read_pandas["id"]), "Data should be unchanged after write/read"


def test_write_csv_with_header(spark_session, tmp_path):
    df = spark_session.range(10)
    csv_dir = os.path.join(tmp_path, "csv")
    df.write.option("header", True).csv(csv_dir)

    df_read = spark_session.read.option("header", True).csv(str(csv_dir))
    df_pandas = df.toPandas()
    df_read_pandas = df_read.toPandas()
    assert df_pandas["id"].equals(df_read_pandas["id"])


def test_write_csv_with_delimiter(spark_session, tmp_path):
    df = spark_session.range(10)
    csv_dir = os.path.join(tmp_path, "csv")
    df.write.option("sep", "|").csv(csv_dir)

    df_read = spark_session.read.option("sep", "|").csv(str(csv_dir))
    df_pandas = df.toPandas()
    df_read_pandas = df_read.toPandas()
    assert df_pandas["id"].equals(df_read_pandas["id"])


def test_write_csv_with_quote(spark_session, tmp_path):
    df = spark_session.createDataFrame([("a,b",), ("c'd",)], ["text"])
    csv_dir = os.path.join(tmp_path, "csv")
    df.write.option("quote", "'").csv(csv_dir)

    df_read = spark_session.read.option("quote", "'").csv(str(csv_dir))
    df_pandas = df.toPandas()
    df_read_pandas = df_read.toPandas()
    assert df_pandas["text"].equals(df_read_pandas["text"])


def test_write_csv_with_escape(spark_session, tmp_path):
    df = spark_session.createDataFrame([("a'b",), ("c'd",)], ["text"])
    csv_dir = os.path.join(tmp_path, "csv")
    df.write.option("escape", "\\").csv(csv_dir)

    df_read = spark_session.read.option("escape", "\\").csv(str(csv_dir))
    df_pandas = df.toPandas()
    df_read_pandas = df_read.toPandas()
    assert df_pandas["text"].equals(df_read_pandas["text"])


@pytest.mark.skip(
    reason="https://github.com/Eventual-Inc/Daft/issues/3609: CSV null value handling not yet implemented"
)
def test_write_csv_with_null_value(spark_session, tmp_path):
    df = spark_session.createDataFrame([(1, None), (2, "test")], ["id", "value"])
    csv_dir = os.path.join(tmp_path, "csv")
    df.write.option("nullValue", "NULL").csv(csv_dir)

    df_read = spark_session.read.option("nullValue", "NULL").csv(str(csv_dir))
    df_pandas = df.toPandas()
    df_read_pandas = df_read.toPandas()
    assert df_pandas["value"].isna().equals(df_read_pandas["value"].isna())


def test_write_csv_with_compression(spark_session, tmp_path):
    df = spark_session.range(10)
    csv_dir = os.path.join(tmp_path, "csv")
    df.write.option("compression", "gzip").csv(csv_dir)

    df_read = spark_session.read.csv(str(csv_dir))
    df_pandas = df.toPandas()
    df_read_pandas = df_read.toPandas()
    assert df_pandas["id"].equals(df_read_pandas["id"])