Skip to content

Commit eeaedd7

Browse files
committed
Add finalize/prepare and requirements text
1 parent b618820 commit eeaedd7

File tree

3 files changed

+101
-0
lines changed

3 files changed

+101
-0
lines changed
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import argparse
2+
import json
3+
import numpy as np
4+
import pandas as pd
5+
6+
7+
def convert_survival_curve_to_risk_score(curve):
8+
curve = np.array(curve)
9+
return 1 - np.cumprod(curve[:25])[-1]
10+
11+
12+
def finalize(input_csv, predictions_json, output_csv):
13+
with open(predictions_json, "r") as f:
14+
prediction_data = json.load(f)
15+
16+
df = pd.read_csv(input_csv, dtype={"file_id": str})
17+
18+
age = prediction_data["output_age_from_wide_csv_continuous"]
19+
af = prediction_data["output_af_in_read_categorical"]
20+
sex = prediction_data["output_sex_from_wide_categorical"]
21+
curves = prediction_data["output_survival_curve_af_survival_curve"]
22+
23+
if len(age) != len(df):
24+
raise ValueError(f"Mismatch: {len(age)} predictions but {len(df)} rows in input CSV!")
25+
26+
df["output_age"] = [row[0] for row in age]
27+
df["output_af_0"] = [row[0] for row in af]
28+
df["output_af_1"] = [row[1] for row in af]
29+
df["output_sex_male"] = [row[0] for row in sex]
30+
df["output_sex_female"] = [row[1] for row in sex]
31+
df["af_risk_score"] = [convert_survival_curve_to_risk_score(row) for row in curves]
32+
33+
df.to_csv(output_csv, index=False)
34+
print(f"✅ Predictions written to {output_csv} ({len(df)} rows).")
35+
36+
37+
if __name__ == "__main__":
38+
parser = argparse.ArgumentParser()
39+
parser.add_argument("--input", required=True, help="Path to input CSV")
40+
parser.add_argument("--output", required=True, help="Path to final CSV with predictions")
41+
parser.add_argument("--predictions", required=True, help="Path to predictions JSON")
42+
args = parser.parse_args()
43+
44+
finalize(args.input, args.predictions, args.output)
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import argparse
2+
3+
import h5py
4+
import numpy as np
5+
import pandas as pd
6+
import smart_open
7+
8+
ECG_REST_LEADS = {
9+
'strip_I': 0, 'strip_II': 1, 'strip_III': 2, 'strip_V1': 6, 'strip_V2': 7, 'strip_V3': 8,
10+
'strip_V4': 9, 'strip_V5': 10, 'strip_V6': 11, 'strip_aVF': 5, 'strip_aVL': 4, 'strip_aVR': 3,
11+
}
12+
ECG_SHAPE = (5000, 12)
13+
ECG_HD5_PATH = 'ukb_ecg_rest'
14+
15+
16+
def ecg_as_tensor(ecg_file):
17+
with smart_open.open(ecg_file, 'rb') as f:
18+
with h5py.File(f, 'r') as hd5:
19+
tensor = np.zeros(ECG_SHAPE, dtype=np.float32)
20+
for lead in ECG_REST_LEADS:
21+
data = np.array(hd5[f'{ECG_HD5_PATH}/{lead}/instance_0'])
22+
tensor[:, ECG_REST_LEADS[lead]] = data
23+
24+
mean = np.mean(tensor)
25+
std = np.std(tensor) + 1e-7
26+
tensor = (tensor - mean) / std
27+
return tensor
28+
29+
30+
def prepare(input_csv, output_h5):
31+
"""Processes ECG files into HDF5 tensor format from GCS/Azure/Local."""
32+
df = pd.read_csv(input_csv, dtype={"file": str})
33+
h5_file = h5py.File(output_h5, "w")
34+
tensors_group = h5_file.create_group("tensors")
35+
df = df.dropna(subset=["file"])
36+
df["file"] = df["file"].astype(str)
37+
for _, row in df.iterrows():
38+
sample_id, file_path = row["file_id"], row["file"]
39+
print(f"Processing: sample_id={sample_id}, file_path={file_path}, type={type(file_path)}")
40+
tensor = ecg_as_tensor(file_path)
41+
tensors_group.create_dataset(str(sample_id), data=tensor)
42+
43+
h5_file.close()
44+
print(f"Processed ECG tensors saved to {output_h5}")
45+
46+
47+
if __name__ == "__main__":
48+
parser = argparse.ArgumentParser()
49+
parser.add_argument("--input", required=True, help="Path to input CSV")
50+
parser.add_argument("--output", required=True, help="Path to output HDF5 file")
51+
args = parser.parse_args()
52+
53+
prepare(args.input, args.output)
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
pandas
2+
numpy
3+
h5py
4+
smart-open[gcs]

0 commit comments

Comments
 (0)