Skip to content

Commit b446e81

Browse files
feat(breadbox): Support list of strings values matrix (#217)
1 parent ac08320 commit b446e81

File tree

7 files changed

+273
-32
lines changed

7 files changed

+273
-32
lines changed
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
"""Add list_strings as matrix value type
2+
3+
Revision ID: 020788c82611
4+
Revises: e593fefbe9fc
5+
Create Date: 2025-03-13 14:42:21.461226
6+
7+
"""
8+
from alembic import op
9+
import sqlalchemy as sa
10+
11+
12+
# revision identifiers, used by Alembic.
revision = "020788c82611"  # unique id of this migration
down_revision = "e593fefbe9fc"  # migration this one applies on top of
branch_labels = None  # no named branch labels for this revision
depends_on = None  # no cross-branch dependency
17+
18+
19+
def upgrade():
    """Widen matrix_dataset.value_type to the valuetype enum that includes list_strings."""
    # ### commands auto generated by Alembic - please adjust! ###
    value_type_enum = sa.Enum(
        "continuous", "categorical", "list_strings", name="valuetype"
    )
    with op.batch_alter_table("matrix_dataset", schema=None) as batch_op:
        batch_op.alter_column(
            "value_type",
            existing_type=sa.VARCHAR(length=11),
            type_=value_type_enum,
            existing_nullable=False,
        )

    # ### end Alembic commands ###
32+
33+
34+
def downgrade():
    """Revert matrix_dataset.value_type back to a plain VARCHAR(11) column."""
    # ### commands auto generated by Alembic - please adjust! ###
    value_type_enum = sa.Enum(
        "continuous", "categorical", "list_strings", name="valuetype"
    )
    with op.batch_alter_table("matrix_dataset", schema=None) as batch_op:
        batch_op.alter_column(
            "value_type",
            existing_type=value_type_enum,
            type_=sa.VARCHAR(length=11),
            existing_nullable=False,
        )

    # ### end Alembic commands ###

breadbox/breadbox/compute/dataset_uploads_tasks.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,9 @@ def dataset_upload(
155155
dataset_params.version,
156156
dataset_params.description,
157157
)
158-
save_dataset_file(dataset_id, data_df, settings.filestore_location)
158+
save_dataset_file(
159+
dataset_id, data_df, dataset_params.value_type, settings.filestore_location
160+
)
159161

160162
else:
161163
index_type = _get_dimension_type(db, dataset_params.index_type)

breadbox/breadbox/io/data_validation.py

Lines changed: 38 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,26 @@ def _validate_dimension_type_metadata_file(
134134
return df
135135

136136

137+
def _parse_list_strings(val):
138+
example_list_string = '["x", "y"]'
139+
try:
140+
deserialized_str_list = json.loads(val)
141+
except Exception as e:
142+
raise FileValidationError(
143+
f"Value: {val} must be able to be deserialized into a list. Please make sure values for columns of type list_strings are a stringified list (ex: {example_list_string})"
144+
) from e
145+
146+
if not isinstance(deserialized_str_list, list):
147+
raise FileValidationError(
148+
f"Value: {val} must be able to be deserialized into a list. Please make sure values for columns of type list_strings are a stringified list (ex: {example_list_string})"
149+
)
150+
151+
if not all(isinstance(x, str) for x in deserialized_str_list):
152+
raise FileValidationError(
153+
f"All values in {deserialized_str_list} must be a string (ex: {example_list_string})"
154+
)
155+
156+
137157
def _validate_data_value_type(
138158
df: pd.DataFrame, value_type: ValueType, allowed_values: Optional[List]
139159
):
@@ -170,6 +190,18 @@ def _validate_data_value_type(
170190

171191
int_df = int_df.astype(int)
172192
return int_df
193+
elif value_type == ValueType.list_strings:
194+
195+
def validate_list_strings(val):
196+
if not pd.isnull(val):
197+
_parse_list_strings(val)
198+
return val
199+
else:
200+
# hdf5 will stringify 'None' or '<NA>'. Use empty string to represent NAs instead
201+
return ""
202+
203+
df = df.applymap(validate_list_strings)
204+
return df.astype(str)
173205
else:
174206
if not all([is_numeric_dtype(df[col].dtypes) for col in df.columns]):
175207
raise FileValidationError(
@@ -190,7 +222,7 @@ def _read_parquet(file, value_type: ValueType) -> pd.DataFrame:
190222
# parquet files have the types encoded in the file, so we'll convert after the fact
191223
if value_type == ValueType.continuous:
192224
dtype = "Float64"
193-
elif value_type == ValueType.categorical:
225+
elif value_type == ValueType.categorical or value_type == ValueType.list_strings:
194226
dtype = "string"
195227
else:
196228
raise ValueError(f"Invalid value type: {value_type}")
@@ -217,7 +249,7 @@ def _read_csv(file: BinaryIO, value_type: ValueType) -> pd.DataFrame:
217249

218250
if value_type == ValueType.continuous:
219251
dtypes_ = dict(zip(cols, ["string"] + (["Float64"] * (len(cols) - 1))))
220-
elif value_type == ValueType.categorical:
252+
elif value_type == ValueType.categorical or value_type == ValueType.list_strings:
221253
dtypes_ = dict(zip(cols, ["string"] * len(cols)))
222254
else:
223255
raise ValueError(f"Invalid value type: {value_type}")
@@ -422,7 +454,7 @@ def validate_and_upload_dataset_files(
422454
)
423455

424456
# TODO: Move save function to api layer. Need to make sure the db save is successful first
425-
save_dataset_file(dataset_id, data_df, filestore_location)
457+
save_dataset_file(dataset_id, data_df, value_type, filestore_location)
426458

427459
return dataframe_validated_dimensions
428460

@@ -575,27 +607,9 @@ def _validate_tabular_df_schema(
575607
dimension_type_identifier: str,
576608
):
577609
def can_parse_list_strings(val):
578-
example_list_string = '["x", "y"]'
579-
if val is not None and not pd.isnull(val):
580-
try:
581-
deserialized_str_list = json.loads(val)
582-
except Exception as e:
583-
raise FileValidationError(
584-
f"Value: {val} must be able to be deserialized into a list. Please make sure values for columns of type list_strings are a stringified list (ex: {example_list_string})"
585-
) from e
586-
587-
if not isinstance(deserialized_str_list, list):
588-
raise FileValidationError(
589-
f"Value: {val} must be able to be deserialized into a list. Please make sure values for columns of type list_strings are a stringified list (ex: {example_list_string})"
590-
)
591-
592-
if not all(isinstance(x, str) for x in deserialized_str_list):
593-
raise FileValidationError(
594-
f"All values in {deserialized_str_list} must be a string (ex: {example_list_string})"
595-
)
596-
return True
597-
else:
598-
return True
610+
if not pd.isnull(val):
611+
_parse_list_strings(val)
612+
return True
599613

600614
def get_checks_for_col(annotation_type: AnnotationType):
601615
checks = []

breadbox/breadbox/io/filestore_crud.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
import shutil
3+
import json
34
from typing import Any, List, Optional, Union
45

56
import pandas as pd
@@ -11,13 +12,21 @@
1112

1213

1314
def save_dataset_file(
    dataset_id: str,
    data_df: pd.DataFrame,
    value_type: ValueType,
    filestore_location: str,
):
    """Persist a dataset's matrix as an HDF5 file under the filestore.

    Creates a fresh per-dataset directory (errors if it already exists),
    then writes the dataframe with a string dtype for list_strings
    datasets and a float dtype for everything else.
    """
    os.makedirs(os.path.join(filestore_location, dataset_id))

    dtype = "str" if value_type == ValueType.list_strings else "float"

    write_hdf5_file(
        get_file_location(dataset_id, filestore_location, DATA_FILE), data_df, dtype
    )
2231

2332

@@ -94,6 +103,10 @@ def get_df_by_value_type(
94103
# Convert numerical values back to origincal categorical value
95104
df = df.astype(int)
96105
df = df.applymap(lambda x: dataset_allowed_values[x])
106+
elif value_type == ValueType.list_strings:
107+
# NOTE: String data in HDF5 datasets is read as bytes by default
108+
# len of byte encoded empty string should be 0
109+
df = df.applymap(lambda x: json.loads(x) if len(x) != 0 else None)
97110
return df
98111

99112

breadbox/breadbox/io/hdf5_utils.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List, Optional
1+
from typing import List, Optional, Literal
22

33
import h5py
44
import numpy as np
@@ -21,10 +21,15 @@ def create_index_dataset(f: h5py.File, key: str, idx: pd.Index):
2121
)
2222

2323

24-
def write_hdf5_file(path: str, df: pd.DataFrame):
24+
def write_hdf5_file(path: str, df: pd.DataFrame, dtype: Literal["float", "str"]):
2525
f = h5py.File(path, mode="w")
2626
try:
27-
f.create_dataset("data", shape=df.shape, dtype=np.float64, data=df.values)
27+
f.create_dataset(
28+
"data",
29+
shape=df.shape,
30+
dtype=h5py.string_dtype() if dtype == "str" else np.float64,
31+
data=df.values,
32+
)
2833

2934
create_index_dataset(f, "features", df.columns)
3035
create_index_dataset(f, "samples", df.index)

breadbox/breadbox/schemas/dataset.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ class FeatureSampleIdentifier(enum.Enum):
2424
class ValueType(enum.Enum):
    """How the cells of a matrix dataset are typed.

    continuous: numeric (float) values.
    categorical: values drawn from a fixed set of allowed strings.
    list_strings: each cell is a JSON-encoded list of strings.
    """

    continuous = "continuous"
    categorical = "categorical"
    list_strings = "list_strings"
2728

2829

2930
class AnnotationType(enum.Enum):

0 commit comments

Comments
 (0)