Commit 1e8d502

add parquet support (#192)

1 parent 2233c39 commit 1e8d502

File tree

5 files changed: +69 -30 lines changed

- pyproject.toml
- ultrack/core/export/exporter.py
- ultrack/napari.yaml
- ultrack/reader/_test/test_napari_reader.py
- ultrack/reader/napari_reader.py

Diff for: pyproject.toml (+2)

@@ -32,6 +32,7 @@ dependencies = [
     "pillow >=10.0.0",
     "psycopg2-binary >=2.9.6",
     "psygnal >=0.9.0",
+    "pyarrow >=16.1.0,<20",
     "pydantic >= 2",
     "pydot >=2.0.0",
     "qtawesome >=1.3.1",
@@ -118,6 +119,7 @@ uvicorn = ">=0.27.0.post1"
 websocket = ">=0.2.1"
 websockets = ">=12.0"
 zarr = ">=2.15.0,<3.0.0"
+pyarrow = ">=16.1.0,<20"

 [tool.pixi.feature.cuda]
 channels = ["conda-forge", "rapidsai"]
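The same pyarrow pin appears twice: once in the PyPI dependency list and once in the pixi table. A quick way to confirm a resolved environment honors it (a minimal sketch, not part of the commit; assumes the `packaging` helper library is available):

# Verify the installed pyarrow falls inside the `>=16.1.0,<20` pin above.
from importlib.metadata import version

from packaging.version import Version  # assumption: `packaging` is installed

v = Version(version("pyarrow"))
assert Version("16.1.0") <= v < Version("20"), f"pyarrow {v} violates the pin"
print(f"pyarrow {v} satisfies >=16.1.0,<20")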

Diff for: ultrack/core/export/exporter.py (+6 -1)

@@ -22,6 +22,7 @@ def export_tracks_by_extension(
     Supported file extensions are .xml, .csv, .zarr, .dot, and .json.
     - `.xml` exports to a TrackMate compatible XML file.
     - `.csv` exports to a CSV file.
+    - `.parquet` exports to a Parquet file.
     - `.zarr` exports the tracks to dense segments in a `zarr` array format.
     - `.dot` exports to a Graphviz DOT file.
     - `.json` exports to a networkx JSON file.
@@ -60,6 +61,9 @@ def export_tracks_by_extension(
     elif file_ext.lower() == ".zarr":
         df, _ = to_tracks_layer(config)
         tracks_to_zarr(config, df, filename, overwrite=True)
+    elif file_ext.lower() == ".parquet":
+        df, _ = to_tracks_layer(config)
+        df.to_parquet(filename)
     elif file_ext.lower() == ".dot":
         G = to_networkx(config)
         nx.drawing.nx_pydot.write_dot(G, filename)
@@ -70,5 +74,6 @@ def export_tracks_by_extension(
             json.dump(json_data, f)
     else:
         raise ValueError(
-            f"Unknown file extension: {file_ext}. Supported extensions are .xml, .csv, .zarr, .dot, and .json."
+            f"Unknown file extension: {file_ext}. "
+            "Supported extensions are .xml, .csv, .zarr, .parquet, .dot, and .json."
         )
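Usage sketch for the new branch (hypothetical, not part of the commit; assumes `export_tracks_by_extension` takes a populated ultrack config plus an output filename, as the hunks above suggest):

import pandas as pd

from ultrack.core.export.exporter import export_tracks_by_extension

# `config` stands for an already-populated ultrack configuration object.
export_tracks_by_extension(config, "tracks.parquet")  # dispatches on the extension

# The Parquet branch writes the `to_tracks_layer` dataframe with
# `DataFrame.to_parquet`, so it round-trips through pandas:
df = pd.read_parquet("tracks.parquet")
print(df.columns.tolist())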

Diff for: ultrack/napari.yaml (+1 -1)

@@ -41,7 +41,7 @@ contributions:
   readers:
     - command: ultrack.get_reader
       accepts_directories: false
-      filename_patterns: ['*.csv']
+      filename_patterns: ['*.csv', '*.parquet']

  # writers:
  #   - command: ultrack.write_multiple
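With '*.parquet' added to `filename_patterns`, napari routes Parquet files to `ultrack.get_reader`. A minimal sketch of what this enables, mirroring `test_napari_viewer_open_tracks` below:

import napari

viewer = napari.Viewer()
# Assumes `tracks.parquet` contains the required track_id/t/(z/)y/x columns.
viewer.open("tracks.parquet", plugin="ultrack")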

Diff for: ultrack/reader/_test/test_napari_reader.py (+40 -17)

@@ -25,14 +25,18 @@ def tracks_df(n_nodes: int = 10) -> pd.DataFrame:
     return pd.DataFrame(tracks_data, columns=["track_id", "t", "z", "y", "x"])


-def test_reader(tracks_df: pd.DataFrame, tmp_path: Path):
-    reader = napari_get_reader("tracks.csv")
+@pytest.mark.parametrize("file_ext", ["csv", "parquet"])
+def test_reader(tracks_df: pd.DataFrame, tmp_path: Path, file_ext: str):
+    reader = napari_get_reader(f"tracks.{file_ext}")
     assert reader is None

-    path = tmp_path / "good_tracks.csv"
+    path = tmp_path / f"good_tracks.{file_ext}"
     tracks_df["node_id"] = np.arange(len(tracks_df)) + 1
     tracks_df["labels"] = np.random.randint(2, size=len(tracks_df))
-    tracks_df.to_csv(path, index=False)
+    if file_ext == "csv":
+        tracks_df.to_csv(path, index=False)
+    else:
+        tracks_df.to_parquet(path)

     reader = napari_get_reader(path)
     assert callable(reader)
@@ -47,13 +51,17 @@ def test_reader(tracks_df: pd.DataFrame, tmp_path: Path):
     assert np.allclose(data, tracks_df[["track_id", "t", "z", "y", "x"]])


-def test_reader_2d(tracks_df: pd.DataFrame, tmp_path: Path):
-    reader = napari_get_reader("tracks.csv")
+@pytest.mark.parametrize("file_ext", ["csv", "parquet"])
+def test_reader_2d(tracks_df: pd.DataFrame, tmp_path: Path, file_ext: str):
+    reader = napari_get_reader(f"tracks.{file_ext}")
     assert reader is None

-    path = tmp_path / "good_tracks.csv"
+    path = tmp_path / f"good_tracks.{file_ext}"
     tracks_df = tracks_df.drop(columns=["z"])
-    tracks_df.to_csv(path, index=False)
+    if file_ext == "csv":
+        tracks_df.to_csv(path, index=False)
+    else:
+        tracks_df.to_parquet(path)

     reader = napari_get_reader(path)
     assert callable(reader)
@@ -64,7 +72,8 @@ def test_reader_2d(tracks_df: pd.DataFrame, tmp_path: Path):
     assert np.allclose(data, tracks_df[["track_id", "t", "y", "x"]])


-def test_reader_with_lineage(tmp_path: Path):
+@pytest.mark.parametrize("file_ext", ["csv", "parquet"])
+def test_reader_with_lineage(tmp_path: Path, file_ext: str):
     tracks_df = pd.DataFrame(
         {
             "track_id": [1, 1, 2, 3],
@@ -76,8 +85,11 @@ def test_reader_with_lineage(tmp_path: Path):
         }
     )

-    path = tmp_path / "tracks.csv"
-    tracks_df.to_csv(path, index=False)
+    path = tmp_path / f"tracks.{file_ext}"
+    if file_ext == "csv":
+        tracks_df.to_csv(path, index=False)
+    else:
+        tracks_df.to_parquet(path)

     reader = napari_get_reader(path)
     assert callable(reader)
@@ -95,26 +107,37 @@ def test_non_existing_track():
     assert reader is None


-def test_wrong_columns_track(tracks_df: pd.DataFrame, tmp_path: Path):
-    reader = napari_get_reader("tracks.csv")
+@pytest.mark.parametrize("file_ext", ["csv", "parquet"])
+def test_wrong_columns_track(tracks_df: pd.DataFrame, tmp_path: Path, file_ext: str):
+    reader = napari_get_reader(f"tracks.{file_ext}")
     assert reader is None

-    path = tmp_path / "bad_tracks.csv"
+    path = tmp_path / f"bad_tracks.{file_ext}"
     tracks_df = tracks_df.rename(columns={"track_id": "id"})
-    tracks_df.to_csv(path, index=False)
+    if file_ext == "csv":
+        tracks_df.to_csv(path, index=False)
+    else:
+        tracks_df.to_parquet(path)
+
     reader = napari_get_reader(path)
     assert reader is None


+@pytest.mark.parametrize("file_ext", ["csv", "parquet"])
 def test_napari_viewer_open_tracks(
     make_napari_viewer: Callable[[], ViewerModel],
     tracks_df: pd.DataFrame,
     tmp_path: Path,
+    file_ext: str,
 ) -> None:

     _initialize_plugins()

-    tracks_df.to_csv(tmp_path / "tracks.csv", index=False)
+    path = tmp_path / f"tracks.{file_ext}"
+    if file_ext == "csv":
+        tracks_df.to_csv(path, index=False)
+    else:
+        tracks_df.to_parquet(path)

     viewer = make_napari_viewer()
-    viewer.open(tmp_path / "tracks.csv", plugin="ultrack")
+    viewer.open(path, plugin="ultrack")
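The same `if file_ext == "csv"` write branch now repeats in five tests. A sketch of a helper that would factor it out (hypothetical, not part of the commit):

from pathlib import Path

import pandas as pd


def write_tracks(df: pd.DataFrame, path: Path) -> None:
    """Write `df` to `path`, dispatching on the file extension."""
    if path.suffix == ".csv":
        df.to_csv(path, index=False)
    else:
        df.to_parquet(path)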

Diff for: ultrack/reader/napari_reader.py (+20 -11)

@@ -3,6 +3,7 @@
 from typing import Callable, List, Union

 import pandas as pd
+import pyarrow.parquet as pq
 from napari.types import LayerDataTuple

 from ultrack.tracks.graph import inv_tracks_df_forest
@@ -46,15 +47,21 @@ def napari_get_reader(

     LOG.info(f"Reading tracks from {path}")

-    if not path.name.endswith(".csv"):
-        LOG.info(f"{path} must end with `.csv`.")
+    file_name = path.name.lower()
+
+    if not file_name.endswith(".csv") and not file_name.endswith(".parquet"):
+        LOG.info(f"{path} must end with `.csv` or `.parquet`.")
         return None

     if not path.exists():
         LOG.info(f"{path} does not exist.")
         return None

-    header = pd.read_csv(path, nrows=0).columns.tolist()
+    if file_name.endswith(".csv"):
+        header = pd.read_csv(path, nrows=0).columns.tolist()
+    else:
+        header = pq.read_table(path).schema.names
+
     LOG.info(f"Tracks file header: {header}")

     for colname in TRACKS_HEADER:
@@ -68,14 +75,14 @@ def napari_get_reader(
     return reader_function


-def read_csv(path: Union[Path, str]) -> LayerDataTuple:
+def read_dataframe(path: Union[Path, str]) -> LayerDataTuple:
     """
-    Read track data from a CSV file.
+    Read track data from a CSV or Parquet file.

     Parameters
     ----------
     path : Union[Path, str]
-        Path to the CSV file.
+        Path to the CSV or Parquet file.

     Returns
     -------
@@ -90,10 +97,12 @@ def read_csv(path: Union[Path, str]) -> LayerDataTuple:
     If the CSV file contains a 'parent_track_id' column, a track lineage graph
     is constructed.
     """
-    if isinstance(path, str):
-        path = Path(path)
-
-    df = pd.read_csv(path)
+    path = Path(path)
+    file_name = path.name.lower()
+    if file_name.endswith(".csv"):
+        df = pd.read_csv(path)
+    elif file_name.endswith(".parquet"):
+        df = pd.read_parquet(path)

     LOG.info(f"Read {len(df)} tracks from {path}")
     LOG.info(df.head())
@@ -132,4 +141,4 @@ def reader_function(path: Union[List[str], str]) -> List:
     List of track data tuples.
     """
     paths = [path] if isinstance(path, (str, Path)) else path
-    return [read_csv(p) for p in paths]
+    return [read_dataframe(p) for p in paths]
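One note on the header check in `napari_get_reader`: `pq.read_table(path).schema.names` materializes the whole table just to list its columns. Parquet stores the schema in the file footer, so a footer-only read is cheaper (a possible refinement, not part of the commit):

import pyarrow.parquet as pq

# Reads only the Parquet footer instead of loading the full table.
header = pq.read_schema("tracks.parquet").names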
