Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ dependencies = [
"pytimeparse == 1.1.*",
"networkx == 3.3.*",
"pyarrow == 17.*",
"meds == 0.3.3",
"meds ~= 0.4.0",
]

[tool.setuptools]
Expand Down
45 changes: 18 additions & 27 deletions src/aces/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import polars as pl
import pyarrow as pa
import pyarrow.parquet as pq
from meds import label_schema, prediction_time_field, subject_id_field
from meds import LabelSchema
from omegaconf import DictConfig, OmegaConf

from . import config, predicates, query
Expand All @@ -20,15 +20,15 @@
config_yaml = files("aces").joinpath("configs/_aces.yaml")

MEDS_LABEL_MANDATORY_TYPES = {
subject_id_field: pl.Int64,
LabelSchema.subject_id_name: pl.Int64,
}

MEDS_LABEL_OPTIONAL_TYPES = {
"boolean_value": pl.Boolean,
"integer_value": pl.Int64,
"float_value": pl.Float64,
"categorical_value": pl.String,
prediction_time_field: pl.Datetime("us"),
LabelSchema.prediction_time_name: pl.Datetime("us"),
LabelSchema.boolean_value_name: pl.Boolean,
LabelSchema.integer_value_name: pl.Int64,
LabelSchema.float_value_name: pl.Float64,
LabelSchema.categorical_value_name: pl.String,
}


Expand Down Expand Up @@ -56,9 +56,9 @@ def get_and_validate_label_schema(df: pl.DataFrame) -> pa.Table:
>>> get_and_validate_label_schema(df)
Traceback (most recent call last):
...
ValueError: MEDS Data DataFrame must have a 'subject_id' column of type Int64.
ValueError: MEDS Label DataFrame must have a 'subject_id' column of type Int64.
>>> df = pl.DataFrame({
... subject_id_field: pl.Series([1, 3, 2], dtype=pl.UInt32),
... "subject_id": pl.Series([1, 3, 2], dtype=pl.UInt32),
... "time": [datetime(2021, 1, 1), datetime(2021, 1, 2), datetime(2021, 1, 3)],
... "boolean_value": [1, 0, 100],
... })
Expand All @@ -68,7 +68,7 @@ def get_and_validate_label_schema(df: pl.DataFrame) -> pa.Table:
prediction_time: timestamp[us]
boolean_value: bool
integer_value: int64
float_value: double
float_value: float
categorical_value: string
----
subject_id: [[1,3,2]]
Expand All @@ -80,7 +80,7 @@ def get_and_validate_label_schema(df: pl.DataFrame) -> pa.Table:
"""

schema = df.schema
if "prediction_time" not in schema:
if LabelSchema.prediction_time_name not in schema:
logger.warning(
"Output DataFrame is missing a 'prediction_time' column. If this is not intentional, add a "
"'index_timestamp' (yes, it should be different) key to the task configuration identifying "
Expand All @@ -92,7 +92,7 @@ def get_and_validate_label_schema(df: pl.DataFrame) -> pa.Table:
if col in schema and schema[col] != dtype:
df = df.with_columns(pl.col(col).cast(dtype, strict=False))
elif col not in schema:
errors.append(f"MEDS Data DataFrame must have a '{col}' column of type {dtype}.")
errors.append(f"MEDS Label DataFrame must have a '{col}' column of type {dtype}.")

if errors:
raise ValueError("\n".join(errors))
Expand All @@ -115,16 +115,7 @@ def get_and_validate_label_schema(df: pl.DataFrame) -> pa.Table:
)
df = df.drop(extra_cols)

df = df.select(
subject_id_field,
"prediction_time",
"boolean_value",
"integer_value",
"float_value",
"categorical_value",
)

return df.to_arrow().cast(label_schema)
return LabelSchema.align(df.to_arrow())


@hydra.main(version_base=None, config_path=str(config_yaml.parent.resolve()), config_name=config_yaml.stem)
Expand Down Expand Up @@ -154,18 +145,18 @@ def main(cfg: DictConfig) -> None: # pragma: no cover

if cfg.data.standard.lower() == "meds":
for in_col, out_col in [
("subject_id", subject_id_field),
("index_timestamp", "prediction_time"),
("label", "boolean_value"),
("subject_id", LabelSchema.subject_id_name),
("index_timestamp", LabelSchema.prediction_time_name),
("label", LabelSchema.boolean_value_name),
]:
if in_col in result.columns:
result = result.rename({in_col: out_col})
if subject_id_field not in result.columns:
if LabelSchema.subject_id_name not in result.columns:
if not result_is_empty:
raise ValueError("Output dataframe is missing a 'subject_id' column.")
else:
logger.warning("Output dataframe is empty; adding an empty patient ID column.")
result = result.with_columns(pl.lit(None, dtype=pl.Int64).alias(subject_id_field))
result = result.with_columns(pl.lit(None, dtype=pl.Int64).alias(LabelSchema.subject_id_name))
result = result.head(0)
if cfg.window_stats_dir:
Path(cfg.window_stats_filepath).parent.mkdir(exist_ok=True, parents=True)
Expand Down
60 changes: 25 additions & 35 deletions tests/test_meds.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import polars as pl
import pyarrow as pa
from meds import label_schema, subject_id_field
from meds import DataSchema, LabelSchema
from yaml import load as load_yaml

from .utils import (
Expand Down Expand Up @@ -36,24 +36,23 @@

# TODO: Make use of the meds library
MEDS_PL_SCHEMA = {
subject_id_field: pl.Int64,
"time": pl.Datetime("us"),
"code": pl.Utf8,
"numeric_value": pl.Float32,
"numeric_value/is_inlier": pl.Boolean,
DataSchema.subject_id_name: pl.Int64,
DataSchema.time_name: pl.Datetime("us"),
DataSchema.code_name: pl.Utf8,
DataSchema.numeric_value_name: pl.Float32,
}


MEDS_LABEL_MANDATORY_TYPES = {
subject_id_field: pl.Int64,
LabelSchema.subject_id_name: pl.Int64,
}

MEDS_LABEL_OPTIONAL_TYPES = {
"boolean_value": pl.Boolean,
"integer_value": pl.Int64,
"float_value": pl.Float64,
"categorical_value": pl.String,
"prediction_time": pl.Datetime("us"),
LabelSchema.boolean_value_name: pl.Boolean,
LabelSchema.integer_value_name: pl.Int64,
LabelSchema.float_value_name: pl.Float64,
LabelSchema.categorical_value_name: pl.String,
LabelSchema.prediction_time_name: pl.Datetime("us"),
}


Expand Down Expand Up @@ -113,16 +112,7 @@ def get_and_validate_label_schema(df: pl.DataFrame) -> pa.Table:
)
df = df.drop(extra_cols)

df = df.select(
subject_id_field,
"prediction_time",
"boolean_value",
"integer_value",
"float_value",
"categorical_value",
)

return df.to_arrow().cast(label_schema)
return LabelSchema.align(df.to_arrow())


def parse_meds_csvs(
Expand All @@ -140,7 +130,7 @@ def reader(csv_str: str) -> pl.DataFrame:
cols = csv_str.strip().split("\n")[0].split(",")
read_schema = {k: v for k, v in default_read_schema.items() if k in cols}
return pl.read_csv(StringIO(csv_str), schema=read_schema).with_columns(
pl.col("time").str.strptime(MEDS_PL_SCHEMA["time"], DEFAULT_CSV_TS_FORMAT)
pl.col("time").str.strptime(MEDS_PL_SCHEMA[DataSchema.time_name], DEFAULT_CSV_TS_FORMAT)
)

if isinstance(csvs, str):
Expand Down Expand Up @@ -169,9 +159,9 @@ def parse_labels_yaml(yaml_str: str) -> dict[str, pl.DataFrame]:

# Data (input)
MEDS_SHARDS = parse_shards_yaml(
f"""
"""
"train/0": |-
{subject_id_field},time,code,numeric_value
subject_id,time,code,numeric_value
2,,SNP//rs234567,
2,,SNP//rs345678,
2,,GENDER//FEMALE,
Expand All @@ -196,7 +186,7 @@ def parse_labels_yaml(yaml_str: str) -> dict[str, pl.DataFrame]:
2,6/8/1996 3:00,DEATH,

"train/1": |-2
{subject_id_field},time,code,numeric_value
subject_id,time,code,numeric_value
4,,GENDER//MALE,
4,,SNP//rs123456,
4,12/1/1989 12:03,ADMISSION//CARDIAC,
Expand Down Expand Up @@ -246,7 +236,7 @@ def parse_labels_yaml(yaml_str: str) -> dict[str, pl.DataFrame]:
6,3/12/1996 0:00,DEATH,

"held_out/0/0": |-2
{subject_id_field},time,code,numeric_value
subject_id,time,code,numeric_value
3,,GENDER//FEMALE,
3,,SNP//rs234567,
3,,SNP//rs345678,
Expand All @@ -261,10 +251,10 @@ def parse_labels_yaml(yaml_str: str) -> dict[str, pl.DataFrame]:
3,3/12/1996 0:00,DEATH,

"empty_shard": |-2
{subject_id_field},time,code,numeric_value
subject_id,time,code,numeric_value

"held_out": |-2
{subject_id_field},time,code,numeric_value
subject_id,time,code,numeric_value
1,,GENDER//MALE,
1,,SNP//rs123456,
1,12/1/1989 12:03,ADMISSION//CARDIAC,
Expand Down Expand Up @@ -349,22 +339,22 @@ def parse_labels_yaml(yaml_str: str) -> dict[str, pl.DataFrame]:
"""

WANT_SHARDS = parse_labels_yaml(
f"""
"""
"train/0": |-2
{subject_id_field},prediction_time,boolean_value,integer_value,float_value,categorical_value
subject_id,prediction_time,boolean_value,integer_value,float_value,categorical_value

"train/1": |-2
{subject_id_field},prediction_time,boolean_value,integer_value,float_value,categorical_value
subject_id,prediction_time,boolean_value,integer_value,float_value,categorical_value
4,1/28/1991 23:32,False,,,,

"held_out/0/0": |-2
{subject_id_field},prediction_time,boolean_value,integer_value,float_value,categorical_value
subject_id,prediction_time,boolean_value,integer_value,float_value,categorical_value

"empty_shard": |-2
{subject_id_field},prediction_time,boolean_value,integer_value,float_value,categorical_value
subject_id,prediction_time,boolean_value,integer_value,float_value,categorical_value

"held_out": |-2
{subject_id_field},prediction_time,boolean_value,integer_value,float_value,categorical_value
subject_id,prediction_time,boolean_value,integer_value,float_value,categorical_value
1,1/28/1991 23:32,False,,,,
"""
)
Expand Down
Loading