Skip to content

Training script for echos #539

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 20 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions model_zoo/DROID/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ python echo_supervised_inference_recipe.py \
--wide_file {WIDE_FILE_PATH} \
--splits_file {SPLITS_JSON} \
--lmdb_folder {LMDB_DIRECTORY_PATH} \
--pretrained_ckpt_dir {SPECIALIZED_ENCODER_PATH} \
--movinet_ckpt_dir {MoViNet-A2-Base_PATH} \
--pretrained_chkp_dir {SPECIALIZED_ENCODER_PATH} \
--movinet_chkp_dir {MoViNet-A2-Base_PATH} \
--output_dir {WHERE_TO_STORE_PREDICTIONS}
```
24 changes: 17 additions & 7 deletions model_zoo/DROID/data_descriptions/echo.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,19 @@ def __init__(
transforms=None,
nframes: int = None,
skip_modulo: int = 1,
start_beat=0,
start_frame=0,
randomize_start_frame = False
):

self.local_lmdb_dir = local_lmdb_dir
self._name = name
self.start_frame = start_frame
self.nframes = nframes
self.nframes = (nframes + start_beat) * skip_modulo
self.start_beat = start_beat
# transformations
self.transforms = transforms or []
self.skip_modulo = skip_modulo

self.randomize_start_frame = randomize_start_frame

def get_loading_options(self, sample_id):
_, study, view = sample_id.split('_')
lmdb_folder = os.path.join(self.local_lmdb_dir, f"{study}.lmdb")
Expand Down Expand Up @@ -81,13 +82,22 @@ def get_raw_data(self, sample_id, loading_option=None):
in_mem_bytes_io = io.BytesIO(txn.get(view.encode('utf-8')))
video_container = av.open(in_mem_bytes_io, metadata_errors="ignore")
video_frames = itertools.cycle(video_container.decode(video=0))

total_frames = len(list(video_container.decode(video=0)))
video_container.seek(0)

if self.randomize_start_frame:
frame_range = total_frames - (self.nframes * self.skip_modulo)
if frame_range > 0:
self.start_frame = np.random.randint(frame_range)

for i, frame in enumerate(video_frames):
if i == nframes:
if len(frames) == self.nframes:
break
if i < (self.start_beat * self.skip_modulo):
if i < (self.start_frame):
continue
if self.skip_modulo > 1:
if (i % self.skip_modulo) != 0:
if ((i - self.start_frame) % self.skip_modulo) != 0:
continue
frame = np.array(frame.to_image())
for transform in self.transforms:
Expand Down
96 changes: 96 additions & 0 deletions model_zoo/DROID/data_descriptions/wide_file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
from typing import Dict

import numpy as np
import pandas as pd
import tensorflow as tf

from ml4ht.data.data_description import DataDescription

from data_descriptions.echo import VIEW_OPTION_KEY


class EcholabDataDescription(DataDescription):
    # DataDescription backed by a "wide file": a pandas DataFrame with one
    # row per sample and one column per label.

    def __init__(
        self,
        wide_df: pd.DataFrame,
        sample_id_column: str,
        column_names: list,
        name: str,
        categories: Dict = None,
        cls_categories_map: Dict = None,
        transforms=None,
    ):
        """Wrap a wide-format label DataFrame as a DataDescription.

        :param wide_df: DataFrame with one row per sample. NOTE: ``prep_df``
            re-indexes this object in place (see below).
        :param sample_id_column: column whose values identify samples; it
            becomes the DataFrame index used by ``get_raw_data`` lookups.
        :param column_names: label column(s) whose values ``get_raw_data``
            returns.
        :param name: identifier returned by the ``name`` property; lets
            several wide-file DataDescriptions coexist.
        :param categories: optional map from a label value to a dict holding
            an ``'index'`` key. When given, ``get_raw_data`` returns a one-hot
            vector of length ``len(categories)`` instead of raw values.
        :param cls_categories_map: optional description of classification
            outputs. Must contain ``'cls_output_order'`` (list of column
            names treated as classification targets) and, for each such
            column, a map from raw cell value to integer class index.
        :param transforms: optional callables taking no arguments, each
            returning a noise vector of length ``len(column_names)``.
        """
        self.wide_df = wide_df
        self._name = name
        self.sample_id_column = sample_id_column
        self.column_names = column_names
        self.categories = categories
        # Re-index by sample id so get_raw_data can use .loc[sample_id].
        self.prep_df()
        self.transforms = transforms or []
        self.cls_categories_map = cls_categories_map

    def prep_df(self):
        """Index ``wide_df`` by sample id and drop fully duplicated rows.

        NOTE(review): assigning to ``wide_df.index`` mutates the caller's
        DataFrame in place; ``drop_duplicates`` then rebinds ``self.wide_df``
        to a new copy. Confirm callers do not rely on the original index.
        """
        self.wide_df.index = self.wide_df[self.sample_id_column]
        self.wide_df = self.wide_df.drop_duplicates()

    def get_loading_options(self, sample_id):
        """Return the loading options for one sample: its DataFrame row."""
        row = self.wide_df.loc[sample_id]

        # a loading option is a dictionary of options to use at loading time
        # we use VIEW_OPTION_KEY to make the view selection utilities work
        loading_options = [{VIEW_OPTION_KEY: row}]

        # it's get_loading_options, not get loading_option, so we return a list
        return loading_options

    def get_raw_data(self, sample_id, loading_option=None):
        """Return the label data for ``sample_id``.

        Depending on configuration this returns one of:
          * one-hot ``np.float32`` vector (``categories`` given);
          * list of [regression array, one-hot arrays...] — or a single
            array when only one output exists (``cls_categories_map`` given);
          * squeezed ``np.float32`` array of raw label values (default).

        :param sample_id: sample identifier; may arrive as an array of ids
            (first element is used) or as UTF-8 bytes (decoded) — presumably
            when invoked from a batched tf.data pipeline; TODO confirm.
        :param loading_option: unused here; part of the DataDescription API.
        """
        try:
            # If an array of ids was passed, keep only the first one.
            if sample_id.shape[0] > 1:
                sample_id = sample_id[0]
        except AttributeError:
            pass
        try:
            # Byte strings (e.g. from TensorFlow) are decoded to str.
            sample_id = sample_id.decode('UTF-8')
        except (UnicodeDecodeError, AttributeError):
            pass
        row = self.wide_df.loc[sample_id]
        data = row[self.column_names].values
        # NOTE(review): label_noise is accumulated from the transforms but is
        # never applied to the returned data below — confirm whether the
        # transforms were meant to perturb the labels.
        label_noise = np.zeros(len(self.column_names))
        for transform in self.transforms:
            label_noise += transform()
        if self.categories:
            # Single classification target: one-hot encode the first label
            # value using the index stored in the categories map.
            output_data = np.zeros(len(self.categories), dtype=np.float32)
            output_data[self.categories[data[0]]['index']] = 1.0
            return output_data
        # ---------- Adaptation for regression + classification ---------- #
        if self.cls_categories_map:
            # If training include classification tasks:
            data = []
            # Regression targets are every requested column that is not a
            # classification column.
            reg_data = row[self.column_names].drop(self.cls_categories_map['cls_output_order']).values
            if len(reg_data) > 0:
                data.append(np.squeeze(np.array(reg_data, dtype=np.float32)))

            for k in self.cls_categories_map['cls_output_order']:
                # Changing values to class labels:
                row_cls_lbl = self.cls_categories_map[k][row[k]]
                # Changing class indices to one hot vectors
                cls_one_hot = tf.keras.utils.to_categorical(row_cls_lbl,
                                                            num_classes=len(self.cls_categories_map[k]))
                data.append(cls_one_hot)

            # With a single output, unwrap the list for convenience.
            if len(data) == 1:
                data = data[0]

            return data
        # ---------------------------------------------------------------- #
        return np.squeeze(np.array(data, dtype=np.float32))

    @property
    def name(self):
        # if we have multiple wide file DataDescriptions at the same time,
        # this will allow us to differentiate between them
        return self._name
Loading