841,354 changes: 841,354 additions & 0 deletions Notes/GarvSachdev_Notes_and_Files/ACDC_Dataset_test/ACDC visualiser.ipynb

Large diffs are not rendered by default.

17 changes: 17 additions & 0 deletions Notes/GarvSachdev_Notes_and_Files/ACDC_Dataset_test/README.md
@@ -0,0 +1,17 @@
# Implementation of TransUNet Segmentation model on ACDC Dataset

- To run this directly on your system, change the data path "./data/ACDC" in test.py or train.py to the path of your dataset.

- This folder does NOT include the dataset itself, due to its size.

- test.py in this folder does NOT save the visualised images for every predicted file. To visualise your data slice by slice, modify and run the `ACDC visualiser.ipynb` notebook.

- Download Google pre-trained ViT models
* [Get models in this link](https://console.cloud.google.com/storage/vit_models/): R50-ViT-B_16, ViT-B_16, ViT-L_16...
```bash
wget https://storage.googleapis.com/vit_models/imagenet21k/{MODEL_NAME}.npz &&
mkdir -p ../model/vit_checkpoint/imagenet21k &&
mv {MODEL_NAME}.npz ../model/vit_checkpoint/imagenet21k/{MODEL_NAME}.npz
```
I used the R50-ViT-B_16 model (checkpoint file `R50+ViT-B_16.npz`, which is the path the configs below expect).
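
As a quick sanity check that the checkpoint landed where the configs expect it, you can list its arrays with NumPy (a minimal sketch, assuming the directory layout created by the commands above):

```python
import numpy as np

# .npz checkpoints open as a lazy archive; listing the keys confirms the
# file downloaded intact. The path matches the mv command above.
weights = np.load('../model/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz')
print(len(weights.files), "arrays in checkpoint")
```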

142 changes: 142 additions & 0 deletions Notes/GarvSachdev_Notes_and_Files/ACDC_Dataset_test/dataset_acdc.py
@@ -0,0 +1,142 @@
import itertools
import os
import random
import re
from glob import glob

import cv2
import h5py
import numpy as np
import torch
from scipy import ndimage
from scipy.ndimage import zoom  # scipy.ndimage.interpolation is deprecated
from torch.utils.data import Dataset
from skimage import io


class BaseDataSets(Dataset):
def __init__(self, base_dir=None, split='train', list_dir=None, transform=None):
self._base_dir = base_dir
self.sample_list = []
self.split = split
self.transform = transform
train_ids, val_ids, test_ids = self._get_ids()
if self.split.find('train') != -1:
print(f"Looking for slices in: {os.path.abspath('./data/ACDC_training_slices')}")
self.all_slices = os.listdir(
self._base_dir + "/ACDC_training_slices")
self.sample_list = []
for ids in train_ids:
                new_data_list = list(filter(lambda x: re.match('{}.*'.format(ids), x) is not None, self.all_slices))
self.sample_list.extend(new_data_list)

elif self.split.find('val') != -1:
self.all_volumes = os.listdir(
self._base_dir + "/ACDC_training_volumes")
self.sample_list = []
for ids in val_ids:
                new_data_list = list(filter(lambda x: re.match('{}.*'.format(ids), x) is not None, self.all_volumes))
self.sample_list.extend(new_data_list)

elif self.split.find('test') != -1:
self.all_volumes = os.listdir(
self._base_dir + "/ACDC_training_volumes")
self.sample_list = []
for ids in test_ids:
                new_data_list = list(filter(lambda x: re.match('{}.*'.format(ids), x) is not None, self.all_volumes))
self.sample_list.extend(new_data_list)

# if num is not None and self.split == "train":
# self.sample_list = self.sample_list[:num]
print("total {} samples".format(len(self.sample_list)))

def _get_ids(self):
all_cases_set = ["patient{:0>3}".format(i) for i in range(1, 101)]
testing_set = ["patient{:0>3}".format(i) for i in range(1, 21)]
validation_set = ["patient{:0>3}".format(i) for i in range(21, 31)]
training_set = [i for i in all_cases_set if i not in testing_set+validation_set]

return [training_set, validation_set, testing_set]

def __len__(self):
return len(self.sample_list)

def __getitem__(self, idx):
case = self.sample_list[idx]

# image = h5f['image'][:]
# label = h5f['label'][:]
# sample = {'image': image, 'label': label}
        if self.split == "train":
            # Close the HDF5 handle promptly via a context manager.
            with h5py.File(self._base_dir + "/ACDC_training_slices/{}".format(case), 'r') as h5f:
                image = h5f['image'][:]
                label = h5f['label'][:]  # supervision fixed to the full 'label' masks
            sample = {'image': image, 'label': label}
            sample = self.transform(sample)
        else:
            with h5py.File(self._base_dir + "/ACDC_training_volumes/{}".format(case), 'r') as h5f:
                image = h5f['image'][:]
                label = h5f['label'][:]
            sample = {'image': image, 'label': label}
sample["idx"] = idx
sample['case_name'] = case.replace('.h5', '')
return sample


def random_rot_flip(image, label):
k = np.random.randint(0, 4)
image = np.rot90(image, k)
label = np.rot90(label, k)
axis = np.random.randint(0, 2)
image = np.flip(image, axis=axis).copy()
label = np.flip(label, axis=axis).copy()
return image, label


def random_rotate(image, label):
angle = np.random.randint(-20, 20)
image = ndimage.rotate(image, angle, order=0, reshape=False)
label = ndimage.rotate(label, angle, order=0, reshape=False)
return image, label


class RandomGenerator(object):
def __init__(self, output_size):
self.output_size = output_size

def __call__(self, sample):
image, label = sample['image'], sample['label']
# ind = random.randrange(0, img.shape[0])
# image = img[ind, ...]
# label = lab[ind, ...]
        if random.random() > 0.5:
            image, label = random_rot_flip(image, label)
        elif random.random() > 0.5:
            # second independent draw: rotation fires with overall probability 0.25
            image, label = random_rotate(image, label)
        x, y = image.shape
        if x != self.output_size[0] or y != self.output_size[1]:
            # order=0 (nearest-neighbour) keeps label values intact; note scipy's default order is 3
            image = zoom(image, (self.output_size[0] / x, self.output_size[1] / y), order=0)
            label = zoom(label, (self.output_size[0] / x, self.output_size[1] / y), order=0)

assert (image.shape[0] == self.output_size[0]) and (image.shape[1] == self.output_size[1])
image = torch.from_numpy(image.astype(np.float32)).unsqueeze(0)
label = torch.from_numpy(label.astype(np.uint8))
sample = {'image': image, 'label': label}
return sample


def iterate_once(iterable):
return np.random.permutation(iterable)


def iterate_eternally(indices):
def infinite_shuffles():
while True:
yield np.random.permutation(indices)
return itertools.chain.from_iterable(infinite_shuffles())


def grouper(iterable, n):
"Collect data into fixed-length chunks or blocks"
# grouper('ABCDEFG', 3) --> ABC DEF"
args = [iter(iterable)] * n
return zip(*args)
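
For reference, a minimal, hedged sketch of wiring `BaseDataSets` into a `DataLoader`; the `./data/ACDC` layout and the 224×224 output size are illustrative assumptions, not values fixed by this diff:

```python
from torch.utils.data import DataLoader

from dataset_acdc import BaseDataSets, RandomGenerator

# Assumed layout: ./data/ACDC/ACDC_training_slices/*.h5 (see the README above).
db_train = BaseDataSets(base_dir="./data/ACDC", split="train",
                        transform=RandomGenerator(output_size=[224, 224]))
loader = DataLoader(db_train, batch_size=8, shuffle=True)

batch = next(iter(loader))
# image: [B, 1, 224, 224] float32; label: [B, 224, 224] uint8
print(batch['image'].shape, batch['label'].shape)
```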
@@ -0,0 +1,130 @@
import ml_collections

def get_b16_config():
"""Returns the ViT-B/16 configuration."""
config = ml_collections.ConfigDict()
config.patches = ml_collections.ConfigDict({'size': (16, 16)})
config.hidden_size = 768
config.transformer = ml_collections.ConfigDict()
config.transformer.mlp_dim = 3072
config.transformer.num_heads = 12
config.transformer.num_layers = 12
config.transformer.attention_dropout_rate = 0.0
config.transformer.dropout_rate = 0.1

config.classifier = 'seg'
config.representation_size = None
config.resnet_pretrained_path = None
config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-B_16.npz'
config.patch_size = 16

config.decoder_channels = (256, 128, 64, 16)
config.n_classes = 2
config.activation = 'softmax'
return config


def get_testing():
"""Returns a minimal configuration for testing."""
config = ml_collections.ConfigDict()
config.patches = ml_collections.ConfigDict({'size': (16, 16)})
config.hidden_size = 1
config.transformer = ml_collections.ConfigDict()
config.transformer.mlp_dim = 1
config.transformer.num_heads = 1
config.transformer.num_layers = 1
config.transformer.attention_dropout_rate = 0.0
config.transformer.dropout_rate = 0.1
config.classifier = 'token'
config.representation_size = None
return config

def get_r50_b16_config():
"""Returns the Resnet50 + ViT-B/16 configuration."""
config = get_b16_config()
config.patches.grid = (16, 16)
config.resnet = ml_collections.ConfigDict()
config.resnet.num_layers = (3, 4, 9)
config.resnet.width_factor = 1

config.classifier = 'seg'
config.pretrained_path = '../model/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz'
config.decoder_channels = (256, 128, 64, 16)
config.skip_channels = [512, 256, 64, 16]
config.n_classes = 2
config.n_skip = 3
config.activation = 'softmax'

return config


def get_b32_config():
"""Returns the ViT-B/32 configuration."""
config = get_b16_config()
config.patches.size = (32, 32)
config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-B_32.npz'
return config


def get_l16_config():
"""Returns the ViT-L/16 configuration."""
config = ml_collections.ConfigDict()
config.patches = ml_collections.ConfigDict({'size': (16, 16)})
config.hidden_size = 1024
config.transformer = ml_collections.ConfigDict()
config.transformer.mlp_dim = 4096
config.transformer.num_heads = 16
config.transformer.num_layers = 24
config.transformer.attention_dropout_rate = 0.0
config.transformer.dropout_rate = 0.1
config.representation_size = None

# custom
config.classifier = 'seg'
config.resnet_pretrained_path = None
config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-L_16.npz'
config.decoder_channels = (256, 128, 64, 16)
config.n_classes = 2
config.activation = 'softmax'
return config


def get_r50_l16_config():
"""Returns the Resnet50 + ViT-L/16 configuration. customized """
config = get_l16_config()
config.patches.grid = (16, 16)
config.resnet = ml_collections.ConfigDict()
config.resnet.num_layers = (3, 4, 9)
config.resnet.width_factor = 1

config.classifier = 'seg'
config.resnet_pretrained_path = '../model/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz'
config.decoder_channels = (256, 128, 64, 16)
config.skip_channels = [512, 256, 64, 16]
config.n_classes = 2
config.activation = 'softmax'
return config


def get_l32_config():
"""Returns the ViT-L/32 configuration."""
config = get_l16_config()
config.patches.size = (32, 32)
return config


def get_h14_config():
"""Returns the ViT-L/16 configuration."""
config = ml_collections.ConfigDict()
config.patches = ml_collections.ConfigDict({'size': (14, 14)})
config.hidden_size = 1280
config.transformer = ml_collections.ConfigDict()
config.transformer.mlp_dim = 5120
config.transformer.num_heads = 16
config.transformer.num_layers = 32
config.transformer.attention_dropout_rate = 0.0
config.transformer.dropout_rate = 0.1
config.classifier = 'token'
config.representation_size = None

return config
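
A hedged sketch of how a training script might consume these config factories (the names and the 4-class override are illustrative; ACDC segments background, right ventricle, myocardium, and left ventricle, but the actual overrides in train.py may differ):

```python
# Assumes the get_*_config functions above are in scope (or imported).
CONFIGS = {
    'ViT-B_16': get_b16_config,
    'R50-ViT-B_16': get_r50_b16_config,
    'ViT-L_16': get_l16_config,
}

config = CONFIGS['R50-ViT-B_16']()
config.n_classes = 4  # ACDC: background + RV + myocardium + LV
print(config.pretrained_path, config.patches.grid, config.n_classes)
```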