Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
f49b4fb
generalizing echo training script
Sep 1, 2023
4e6b1fe
echo training
Sep 1, 2023
749c596
fix import
alalusim Sep 14, 2023
f69dc20
save wide_df_selected
alalusim Sep 14, 2023
e171e2a
change patience=3
alalusim Sep 14, 2023
a82a1d5
fix output dir bug
alalusim Sep 14, 2023
988585e
Updating DROID to allow simultaneous regression and classification tr…
shnitzer Oct 5, 2023
7cabb97
Updating DROID to allow simultaneous regression and classification tr…
shnitzer Oct 6, 2023
6e182f5
Updating DROID to allow simultaneous regression and classification tr…
shnitzer Oct 6, 2023
0525769
Changed warnings about 'output_labels_types' to raised type errors (f…
shnitzer Oct 11, 2023
d7d429f
Adding new flags (early stopping and loss weights) and removing the n…
shnitzer Oct 19, 2023
ce6d53b
Using saved model information for loading the checkpoint model and re…
shnitzer Oct 19, 2023
16a6a6b
Changing the output_labels, selected_views, selected_doppler, selecte…
shnitzer Oct 19, 2023
361c774
Fixing bug that occurs when fine-tuning with a subset of the classifi…
shnitzer Oct 24, 2023
e19297c
autotune prefetch, fix inference classification bug
alalusim Nov 7, 2023
b427c52
save wide_df_selected during inference; fix cls_pred bug
alalusim Nov 10, 2023
35aa5e4
workaround due to parquet bug
alalusim Nov 14, 2023
ceebd6f
Revert "workaround due to parquet bug"
alalusim Nov 20, 2023
0b6b7de
change start_beat to start_frame, add randomize_start_frame option
alalusim Nov 20, 2023
79af016
add randomize_start_frame option, more classification metrics
alalusim Nov 20, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions model_zoo/DROID/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ python echo_supervised_inference_recipe.py \
--wide_file {WIDE_FILE_PATH} \
--splits_file {SPLITS_JSON} \
--lmdb_folder {LMDB_DIRECTORY_PATH} \
--pretrained_ckpt_dir {SPECIALIZED_ENCODER_PATH} \
--movinet_ckpt_dir {MoViNet-A2-Base_PATH} \
--pretrained_chkp_dir {SPECIALIZED_ENCODER_PATH} \
--movinet_chkp_dir {MoViNet-A2-Base_PATH} \
--output_dir {WHERE_TO_STORE_PREDICTIONS}
```
24 changes: 17 additions & 7 deletions model_zoo/DROID/data_descriptions/echo.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,19 @@ def __init__(
transforms=None,
nframes: int = None,
skip_modulo: int = 1,
start_beat=0,
start_frame=0,
randomize_start_frame = False
):

self.local_lmdb_dir = local_lmdb_dir
self._name = name
self.start_frame = start_frame
self.nframes = nframes
self.nframes = (nframes + start_beat) * skip_modulo
self.start_beat = start_beat
# transformations
self.transforms = transforms or []
self.skip_modulo = skip_modulo

self.randomize_start_frame = randomize_start_frame

def get_loading_options(self, sample_id):
_, study, view = sample_id.split('_')
lmdb_folder = os.path.join(self.local_lmdb_dir, f"{study}.lmdb")
Expand Down Expand Up @@ -81,13 +82,22 @@ def get_raw_data(self, sample_id, loading_option=None):
in_mem_bytes_io = io.BytesIO(txn.get(view.encode('utf-8')))
video_container = av.open(in_mem_bytes_io, metadata_errors="ignore")
video_frames = itertools.cycle(video_container.decode(video=0))

total_frames = len(list(video_container.decode(video=0)))
video_container.seek(0)

if self.randomize_start_frame:
frame_range = total_frames - (self.nframes * self.skip_modulo)
if frame_range > 0:
self.start_frame = np.random.randint(frame_range)

for i, frame in enumerate(video_frames):
if i == nframes:
if len(frames) == self.nframes:
break
if i < (self.start_beat * self.skip_modulo):
if i < (self.start_frame):
continue
if self.skip_modulo > 1:
if (i % self.skip_modulo) != 0:
if ((i - self.start_frame) % self.skip_modulo) != 0:
continue
frame = np.array(frame.to_image())
for transform in self.transforms:
Expand Down
96 changes: 96 additions & 0 deletions model_zoo/DROID/data_descriptions/wide_file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
from typing import Dict

import numpy as np
import pandas as pd
import tensorflow as tf

from ml4ht.data.data_description import DataDescription

from data_descriptions.echo import VIEW_OPTION_KEY


class EcholabDataDescription(DataDescription):
    # DataDescription backed by a "wide file": one row per sample, one column
    # per output label, indexed by a sample-id column.

    def __init__(
        self,
        wide_df: pd.DataFrame,
        sample_id_column: str,
        column_names: "list[str]",
        name: str,
        categories: Dict = None,
        cls_categories_map: Dict = None,
        transforms=None,
    ):
        """Wrap a wide-format label DataFrame as a DataDescription.

        :param wide_df: table with one row per sample; re-indexed in place by
            ``prep_df`` using ``sample_id_column``
        :param sample_id_column: name of the column holding sample ids
        :param column_names: output label columns to read in ``get_raw_data``
        :param name: identifier returned by the ``name`` property (used to
            tell multiple wide-file DataDescriptions apart)
        :param categories: optional map ``label value -> {'index': int}`` for
            the single-task one-hot classification path in ``get_raw_data``
        :param cls_categories_map: optional map for mixed
            regression + classification; must contain key
            ``'cls_output_order'`` (ordered list of classification columns)
            and, per classification column, a ``value -> class index`` map
        :param transforms: optional list of zero-argument callables whose
            results are summed into ``label_noise`` in ``get_raw_data``
            # NOTE(review): label_noise is currently never applied — see below
        """
        self.wide_df = wide_df
        self._name = name
        self.sample_id_column = sample_id_column
        self.column_names = column_names
        self.categories = categories
        self.prep_df()
        self.transforms = transforms or []
        self.cls_categories_map = cls_categories_map

    def prep_df(self):
        """Index the table by sample id and drop exact-duplicate rows.

        NOTE(review): the first assignment mutates the caller's DataFrame
        (index set in place before the ``drop_duplicates`` copy is made) —
        confirm callers do not rely on the original index.
        """
        self.wide_df.index = self.wide_df[self.sample_id_column]
        self.wide_df = self.wide_df.drop_duplicates()

    def get_loading_options(self, sample_id):
        """Return the list of loading options for one sample (a single-row dict)."""
        row = self.wide_df.loc[sample_id]

        # a loading option is a dictionary of options to use at loading time
        # we use VIEW_OPTION_KEY so the view selection utilities work
        loading_options = [{VIEW_OPTION_KEY: row}]

        # it's get_loading_options, not get loading_option, so we return a list
        return loading_options

    def get_raw_data(self, sample_id, loading_option=None):
        """Return the label(s) for ``sample_id`` as numpy/one-hot arrays.

        Three mutually exclusive output paths, checked in order:
        1. ``categories`` set: one-hot float32 vector over ``categories``.
        2. ``cls_categories_map`` set: a list mixing one regression vector
           (label columns not in ``cls_output_order``) and one one-hot vector
           per classification column; unwrapped if the list has one element.
        3. Neither set: squeezed float32 vector of ``column_names`` values.
        """
        # sample_id may arrive as an array-like (take first element) ...
        try:
            if sample_id.shape[0] > 1:
                sample_id = sample_id[0]
        except AttributeError:
            pass
        # ... and/or as bytes (decode); plain strings fall through unchanged
        try:
            sample_id = sample_id.decode('UTF-8')
        except (UnicodeDecodeError, AttributeError):
            pass
        row = self.wide_df.loc[sample_id]
        data = row[self.column_names].values
        # NOTE(review): label_noise is accumulated from the transforms but is
        # never added to any returned value — dead code or a missing
        # ``data + label_noise``? Confirm intended behavior.
        label_noise = np.zeros(len(self.column_names))
        for transform in self.transforms:
            label_noise += transform()
        if self.categories:
            # Single classification task: one-hot encode the first label value
            output_data = np.zeros(len(self.categories), dtype=np.float32)
            output_data[self.categories[data[0]]['index']] = 1.0
            return output_data
        # ---------- Adaptation for regression + classification ---------- #
        if self.cls_categories_map:
            # If training include classification tasks:
            data = []
            # Regression targets = label columns minus the classification ones
            reg_data = row[self.column_names].drop(self.cls_categories_map['cls_output_order']).values
            if len(reg_data) > 0:
                data.append(np.squeeze(np.array(reg_data, dtype=np.float32)))

            for k in self.cls_categories_map['cls_output_order']:
                # Changing values to class labels:
                row_cls_lbl = self.cls_categories_map[k][row[k]]
                # Changing class indices to one hot vectors
                cls_one_hot = tf.keras.utils.to_categorical(row_cls_lbl,
                                                            num_classes=len(self.cls_categories_map[k]))
                data.append(cls_one_hot)

            # Single output head: return the bare array, not a 1-element list
            if len(data) == 1:
                data = data[0]

            return data
        # ---------------------------------------------------------------- #
        return np.squeeze(np.array(data, dtype=np.float32))

    @property
    def name(self):
        # if we have multiple wide file DataDescriptions at the same time,
        # this will allow us to differentiate between them
        return self._name
Loading