Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 169 additions & 0 deletions fuse_examples/classification/multimodality/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
import sys
from typing import Callable, Optional
import logging
import pandas as pd
import pydicom
import os, glob
from pathlib import Path
from typing import Tuple

from fuse.data.visualizer.visualizer_default import FuseVisualizerDefault
from fuse.data.augmentor.augmentor_default import FuseAugmentorDefault
from fuse.data.augmentor.augmentor_toolbox import aug_op_color, aug_op_gaussian, aug_op_affine
from fuse.data.dataset.dataset_default import FuseDatasetDefault
from fuse.data.dataset.dataset_generator import FuseDatasetGenerator
from fuse.data.data_source.data_source_default import FuseDataSourceDefault

from fuse.utils.utils_param_sampler import FuseUtilsParamSamplerUniform as Uniform
from fuse.utils.utils_param_sampler import FuseUtilsParamSamplerRandInt as RandInt
from fuse.utils.utils_param_sampler import FuseUtilsParamSamplerRandBool as RandBool


from fuse_examples.classification.multimodality.input_processor import ImagingTabularProcessor





def IMAGING_dataset():

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This one returns augmentor and not a dataset, right?
If yes, lets rename (name should be lower case with underscores) and also add type annotation for the returned value.

"""
Creates Fuse Dataset object for training, validation and test

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

comment not up to date I guess.

:param data_dir: dataset root path
:param data_misc_dir path to save misc files to be used later
:param cache_dir: Optional, name of the cache folder
:param reset_cache: Optional,specifies if we want to clear the cache first
:param post_cache_processing_func: Optional, function run post cache processing
:return: training, validation and test FuseDatasetDefault objects
"""
augmentation_pipeline = [
[
('data.image',),
aug_op_affine,
{'rotate': Uniform(-30.0, 30.0), 'translate': (RandInt(-10, 10), RandInt(-10, 10)),
'flip': (RandBool(0.3), RandBool(0.3)), 'scale': Uniform(0.9, 1.1)},
{'apply': RandBool(0.5)}
],
[
('data.image',),
aug_op_color,
{'add': Uniform(-0.06, 0.06), 'mul': Uniform(0.95, 1.05), 'gamma': Uniform(0.9, 1.1),
'contrast': Uniform(0.85, 1.15)},
{'apply': RandBool(0.5)}
],
[
('data.image',),
aug_op_gaussian,
{'std': 0.03},
{'apply': RandBool(0.5)}
],
]



# Create data augmentation (optional)
augmentor = FuseAugmentorDefault(
augmentation_pipeline=augmentation_pipeline)




return augmentor


def TABULAR_dataset(tabular_processor,df,tabular_features,sample_key):

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

type annotations, name with lower case letters (tabular processor). and comments.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed this func

tabular_features.remove(sample_key)
tabular_processor = tabular_processor(data=df,
sample_desc_column=sample_key,
columns_to_extract=tabular_features + [sample_key],
columns_to_tensor=tabular_features)
return tabular_processor


def IMAGING_TABULAR_dataset(df, imaging_processor, tabular_processor,label_key:str,img_key:str,tabular_features_lst: list,sample_key: str,

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add missing type annotations.
rename df to something like data_split: List[pd.Dataframe]

cache_dir: str = 'cache', reset_cache: bool = False,
post_cache_processing_func: Optional[Callable] = None) -> Tuple[FuseDatasetDefault, FuseDatasetDefault]:


lgr = logging.getLogger('Fuse')

if isinstance(df,list):
df_train = df[0]
if len(df)>1:

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that you only support len(df)==3 - so add assert with an error message instead.

df_val = df[1]
if len(df)>2:
df_test = df[2]

Comment thread
taltlusty marked this conversation as resolved.
#----------------------------------------------
# -----Datasource
train_data_source = FuseDataSourceDefault(input_source=df_train)
validation_data_source = FuseDataSourceDefault(input_source=df_val)
test_data_source = FuseDataSourceDefault(input_source=df_test)

# ----------------------------------------------
# -----Data-processors
img_clinical_processor_train = ImagingTabularProcessor(data=df_train,
label=label_key,
img_key = img_key,
image_processor=imaging_processor(''),
tabular_processor= \
TABULAR_dataset(tabular_processor,df_train,tabular_features_lst.copy(),sample_key))

img_clinical_processor_val = ImagingTabularProcessor(data=df_val,
label=label_key,
img_key=img_key,
image_processor=imaging_processor(''),
tabular_processor=\
TABULAR_dataset(tabular_processor,df_val,tabular_features_lst.copy(),sample_key))

img_clinical_processor_test = ImagingTabularProcessor(data=df_test,
label=label_key,
img_key=img_key,
image_processor=imaging_processor(''),
tabular_processor= \
TABULAR_dataset(tabular_processor,df_test,tabular_features_lst.copy(),sample_key))



visualiser = FuseVisualizerDefault(image_name='data.image', label_name='data.gt')


# ----------------------------------------------
# ------ Dataset
train_dataset = FuseDatasetGenerator(cache_dest=cache_dir,
data_source=train_data_source,
processor=img_clinical_processor_train,
augmentor=IMAGING_dataset(),
visualizer=visualiser,
post_processing_func=post_cache_processing_func,)


validation_dataset = FuseDatasetGenerator(cache_dest=cache_dir,
data_source=validation_data_source,
processor=img_clinical_processor_val,
augmentor=None,
visualizer=visualiser,
post_processing_func=post_cache_processing_func,)

test_dataset = FuseDatasetGenerator(cache_dest=cache_dir,
data_source=test_data_source,
processor=img_clinical_processor_test,
augmentor=None,
visualizer=visualiser,
post_processing_func=post_cache_processing_func,)


# ----------------------------------------------
# ------ Cache

# create cache
train_dataset.create(reset_cache=reset_cache) # use ThreadPool to create this dataset, to avoid cv2 problems in multithreading
validation_dataset.create() # use ThreadPool to create this dataset, to avoid cv2 problems in multithreading
test_dataset.create() # use ThreadPool to create this dataset, to avoid cv2 problems in multithreading

lgr.info(f'- Load and cache data:')

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

move this line to line 158


lgr.info(f'- Load and cache data: Done')

return train_dataset, validation_dataset, test_dataset


25 changes: 25 additions & 0 deletions fuse_examples/classification/multimodality/input_processor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
def sample_desc_to_xml_path(df, sample_desc,img_key):
xml_path = df[img_key][df.sample_desc == sample_desc].values
return xml_path
def get_gt_from_tabular_sample(tabular_sample_dict,gt_key):
gt = tabular_sample_dict[gt_key]
tabular_sample_dict.pop(gt_key)
return tabular_sample_dict,gt

class ImagingTabularProcessor:
def __init__(self, data, label,img_key,image_processor, tabular_processor):
self.image_processor = image_processor
self.tabular_processor = tabular_processor
self.data = data
self.label = label
self.img_key = img_key
def __call__(self, sample_desc):
img_path = sample_desc_to_xml_path(self.data, sample_desc,self.img_key)
tabular_sample_dict = self.tabular_processor(sample_desc)
image_dict_list = self.image_processor(img_path[0][0])
tabular_sample_dict,gt = get_gt_from_tabular_sample(tabular_sample_dict.copy(), self.label)
img_sample_dict = image_dict_list
sample = tabular_sample_dict
sample['image'] = img_sample_dict
sample['gt'] = gt
return sample
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from typing import Dict
import torch
import torch.nn.functional as F
from fuse.utils.utils_hierarchical_dict import FuseUtilsHierarchicalDict


def softcrossentropyloss(target, logits):

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add underscores?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess that targets is soft label with shape [N, NUM_CLASSES], can you please mention it in a comment.

"""
From the pytorch discussion Forum:
https://discuss.pytorch.org/t/soft-cross-entropy-loss-tf-has-it-does-pytorch-have-it/69501
"""
logprobs = torch.nn.functional.log_softmax(logits, dim=1)
loss = -(target * logprobs).sum() / logits.shape[0]
return loss


class FuseLossMultimodalContrastiveLearning:
def __init__(self,

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please add an explanation/link to an explanation?

imaging_representations: str = None,
tabular_representations: str = None,
label: str = None,
temperature: float = 1.0,
alpha: float = 0.5
) -> None:
self.imaging_representations = imaging_representations
self.tabular_representations = tabular_representations
self.temperature = temperature
self.label = label
self.alpha = alpha

def __call__(self, batch_dict: Dict) -> torch.Tensor:
# filter batch_dict if required
imaging_representations = FuseUtilsHierarchicalDict.get(batch_dict, self.imaging_representations)
tabular_representations = FuseUtilsHierarchicalDict.get(batch_dict, self.tabular_representations)
label = FuseUtilsHierarchicalDict.get(batch_dict, self.label)
if len(imaging_representations.shape)<2:
imaging_representations = imaging_representations.unsqueeze(dim=0)
if len(imaging_representations.shape) < 2:
tabular_representations = tabular_representations.unsqueeze(dim=0)
imaging_representations = F.normalize(imaging_representations, p=2, dim=1)
tabular_representations = F.normalize(tabular_representations, p=2, dim=1)
label_vec = torch.unsqueeze(label, 0)
mask = torch.eq(torch.transpose(label_vec, 0, 1), label_vec).float()
logits_imaging_tabular = torch.matmul(imaging_representations, torch.transpose(tabular_representations, 0, 1))/self.temperature
logits_tabular_imaging = torch.matmul(tabular_representations, torch.transpose(imaging_representations, 0, 1))/self.temperature
loss_imaging_tabular = softcrossentropyloss(mask, logits_imaging_tabular)/torch.sum(mask, 0)
loss_tabular_imaging = softcrossentropyloss(mask, logits_tabular_imaging)/torch.sum(mask, 0)
return self.alpha*loss_tabular_imaging.sum() + (1-self.alpha)*loss_imaging_tabular.sum()


if __name__ == '__main__':

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove main

import torch

batch_dict = {'model.imaging_representations': torch.randn(3, 2),
'model.tabular_representations': torch.randn(3, 2),
'data.label': torch.empty(3, dtype=torch.long).random_(2)}

loss = FuseLossMultimodalContrastiveLearning(temperature=0.1,
imaging_representations='model.imaging_representations',
tabular_representations='model.tabular_representations',
label='data.label')
res = loss(batch_dict)
print('Loss output = ' + str(res))
Loading