Path location for best model storage #1100
-
Beta Was this translation helpful? Give feedback.
Replies: 5 comments 12 replies
-
Hi @Minxiangliu , if you are using monailabel's basic trainer, the saved best metric model should be named by the |
Beta Was this translation helpful? Give feedback.
-
Hi @tangy5 ,
def getPreTrans(**kwargs):
if kwargs['mode'] == 'train':
return [
LoadImaged(keys=['image','label'], reader='ITKReader'),
NormalizeLabelsInDatasetd(keys='label', label_names=kwargs['labels']),
EnsureChannelFirstd(keys=['image','label']),
Orientationd(keys=['image','label'], axcodes='RAS'),
Spacingd(keys=['image','label'], pixdim=kwargs['target_spacing'], mode=('bilinear', 'nearest')),
Lambdad(keys=['image','label'], func=lambda x:np.moveaxis(x,-1,1)),
CropForegroundd(keys=['image','label'], source_key='image', k_divisible=kwargs['padd_size'])
] path: class SegmentConfig(TaskConfig):
def init(self, name: str, model_dir: str, conf:dict, planner: Any, **kwargs):
super().init(name, model_dir, conf, planner, **kwargs)
# Model Files
self.path = [
os.path.join(self.model_dir, f"pretrained_{name}.pt"), # pretrained
os.path.join(self.model_dir, f"{name}.pt"), # published
]
self.trans_params ={
'labels':{"cancer": 1},
'target_spacing':(3.33, 3.33, 2.18),
'padd_size':(96, 96, 96),
'crop_size':(96, 96, 96)
}
network_params = {
'spatial_dims':3,
'in_channels':1,
'out_channels':len(self.trans_params['labels']) + 1, # labels plus background,
'channels':(16, 32, 64, 128, 256),
'strides':(2, 2, 2, 2),
'num_res_units':2,
'norm':'batch',
}
self.slidingWindowInfererParams = {
'roi_size':self.trans_params['crop_size'], 'sw_batch_size':4, 'overlap':0.5, 'padding_mode':'replicate'
}
self.network = UNet(**network_params)
self.description = "A model for volumetric (3D) segmentation of the HNC from PET CT image"
def trainer(self) -> Optional[TrainTask]:
output_dir = os.path.join(self.model_dir, self.name)
load_path = self.path[0] if os.path.exists(self.path[0]) else self.path[1]
_trans_params = self.trans_params.copy()
_trans_params.update({'mode':'train'})
task: TrainTask = lib.trainers.ModelTrain(
description=self.description,
model_dir=output_dir,
publish_path=self.path[1],
network=self.network,
load_path=load_path,
crop_size=self.trans_params['crop_size'],
labels=self.trans_params['labels'],
slidingWindowInfererParams = self.slidingWindowInfererParams,
config=self.conf,
preTrans=getPreTrans(**_trans_params)
)
return task path: class ModelTrain(BasicTrainTask):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._network = kwargs['network']
self.crop_size = kwargs['crop_size']
self.labels = kwargs['labels']
self.pre_trans = kwargs['preTrans']
self.slidingWindowInfererParams = kwargs['slidingWindowInfererParams']
def network(self, context: Context):
return self._network
def optimizer(self, context: Context):
return torch.optim.Adam(context.network.parameters(), lr=1e-4, weight_decay=1e-5)
def loss_function(self, context: Context):
return DiceCELoss(to_onehot_y=True, softmax=True)
def lr_scheduler_handler(self, context: Context):
return None
def train_data_loader(self, context, num_workers=0, shuffle=False):
return super().train_data_loader(context, num_workers, True)
def train_pre_transforms(self, context: Context):
train_transforms = self.pre_trans.copy()
train_transforms.extend([
EnsureTyped(keys=['image','label'], device=context.device),
RandSpatialCropd(keys=['image', 'label'], roi_size=self.crop_size, random_size=False, random_center=True),
RandFlipd(keys=['image','label'], prob=0.5, spatial_axis=0),
RandFlipd(keys=['image','label'], prob=0.5, spatial_axis=1),
RandFlipd(keys=['image','label'], prob=0.5, spatial_axis=2),
NormalizeIntensityd(keys='image', nonzero=True, channel_wise=True),
RandScaleIntensityd(keys='image', factors=0.1, prob=1.0),
RandShiftIntensityd(keys='image', offsets=0.1, prob=1.0),
SelectItemsd(keys=['image','label'])
])
return train_transforms
def train_post_transforms(self, context: Context):
return [
EnsureTyped(keys="pred", device=context.device),
Activationsd(keys="pred", softmax=True),
AsDiscreted(
keys=("pred", "label"),
argmax=(True, False),
to_onehot=(len(self.labels) + 1, len(self.labels) + 1),
),
]
def val_pre_transforms(self, context: Context):
val_transforms = self.pre_trans.copy()
val_transforms.extend([
SelectItemsd(keys=['image','label'])
])
return val_transforms
def val_inferer(self, context: Context):
return SlidingWindowInferer(**self.slidingWindowInfererParams)
def train_key_metric(self, context: Context):
return region_wise_metrics(self.labels, self.TRAIN_KEY_METRIC, "train")
def val_key_metric(self, context: Context):
return region_wise_metrics(self.labels, self.VAL_KEY_METRIC, "val")
def train_handlers(self, context: Context):
handlers = super().train_handlers(context)
if context.local_rank == 0:
handlers.append(
TensorBoardImageHandler(
log_dir=context.events_dir,
batch_transform=from_engine(["image", "label"]),
output_transform=from_engine(["pred"]),
interval=10,
epoch_level=True,
)
)
return handlers Additional verification code: def load_model(path:str, model:UNet):
state_dict = torch.load(path, map_location='cuda')
model.load_state_dict(state_dict, strict=True)
model.eval().cuda()
return model
def getDataList():
with open(r'D:\AI-Project-Code\HNC\Segmentation\DataSets\cvDataSets\testCV1.json', 'r') as file:
val_ds = json.load(file)
return val_ds
def val_dataloader(test_ds):
return ThreadDataLoader(test_ds,
num_workers=0,
batch_size=1,
shuffle=False,
drop_last=False)
trans_params ={
'mode':'train',
'labels':{"cancer": 1},
'target_spacing':(3.33, 3.33, 2.18),
'padd_size':(96, 96, 96),
'crop_size':(96, 96, 96)
}
transform=getPreTrans(**trans_params)
transform.extend([EnsureTyped(keys=['image','label']), ToDeviced(keys=['image','label'],device='cuda')])
transform = Compose(transform)
datasetParameter = {'transform':transform, 'cache_rate':1.0, 'copy_cache':False, 'num_workers':2}
datasetParameter['data'] = getDataList()
test_ds = CacheDataset(**datasetParameter)
loader = val_dataloader(test_ds)
network_params = {
'spatial_dims':3,
'in_channels':1,
'out_channels':len(trans_params['labels']) + 1, # labels plus background,
'channels':(16, 32, 64, 128, 256),
'strides':(2, 2, 2, 2),
'num_res_units':2,
'norm':'batch',
}
model = UNet(**network_params)
model = load_model(
path=r'D:\AI-Project-Code\HNC\Segmentation\monailabel\app-hnc\modelCV1\segmentation\train_01\model.pt', model=model)
post_pred = Compose([Activations(softmax=True), AsDiscrete(argmax=True, to_onehot=len(trans_params['labels']) + 1)])
post_label = Compose([AsDiscrete(argmax=False, to_onehot=len(trans_params['labels']) + 1)])
dice_metric = DiceMetric(include_background=False, reduction="mean")
for idx, batch_data in enumerate(loader):
image, label = batch_data['image'], batch_data['label']
outputs = sliding_window_inference(
inputs=image,
roi_size=trans_params['crop_size'],
sw_batch_size=4,
predictor=model,
overlap=0.5,
padding_mode='replicate')
outputs = [post_pred(i) for i in decollate_batch(outputs)]
label = [post_label(i) for i in decollate_batch(label)]
dice_metric(y_pred=outputs, y=label)
metric = dice_metric.aggregate().item()
dice_metric.reset()
print(metric) The metric result is about 0.65 |
Beta Was this translation helpful? Give feedback.
-
Hi @Minxiangliu, I see two different metrics, one is DiceMetric and the other Dice + Cross Entropy (DiceCELoss). Any reason for this? |
Beta Was this translation helpful? Give feedback.
-
Hi @tangy5 , In the post transform, I wrote the same method according to the writing method of monailabel. In monailabel: def train_post_transforms(self, context: Context):
return [
EnsureTyped(keys="pred", device=context.device),
Activationsd(keys="pred", softmax=True),
AsDiscreted(
keys=("pred", "label"),
argmax=(True, False),
to_onehot=(len(self.labels) + 1, len(self.labels) + 1),
),
]
def val_inferer(self, context: Context):
self.slidingWindowInfererParams= {
'roi_size':self.trans_params['crop_size'], 'sw_batch_size':4, 'overlap':0.5, 'padding_mode':'replicate'
}
return SlidingWindowInferer(**self.slidingWindowInfererParams) In custom method: post_pred = Compose([Activations(softmax=True), AsDiscrete(argmax=True, to_onehot=len(trans_params['labels']) + 1)])
post_label = Compose([AsDiscrete(argmax=False, to_onehot=len(trans_params['labels']) + 1)])
image, label = batch_data['image'], batch_data['label']
outputs = sliding_window_inference(
inputs=image,
roi_size=trans_params['crop_size'],
sw_batch_size=4,
predictor=model,
overlap=0.5,
padding_mode='replicate')
outputs = [post_pred(i) for i in decollate_batch(outputs)]
label = [post_label(i) for i in decollate_batch(label)] In DiceMetric, I refer to the writing of monailabel: MONAILabel/monailabel/tasks/train/utils.py Lines 17 to 28 in 22d1e54 In custom method: dice_metric = DiceMetric(include_background=False, reduction="mean")
dice_metric(y_pred=outputs, y=label) Here I am not sure if it is consistent with the setting of monailabel. |
Beta Was this translation helpful? Give feedback.
-
Hi @tangy5 , I've packaged up my code, model, and test files, can you help me identify what's the problem? Thanks for your time. If you are in the If you can directly execute 3D Slicer Execution Screen: monailabel log:
|
Beta Was this translation helpful? Give feedback.
Hi @Minxiangliu , if you are using monailabel's basic trainer, the saved best metric model should be named by the
key_metric_filename
, if you are using the default setting, it should be the "model.pt"MONAILabel/monailabel/tasks/train/basic_train.py
Line 110 in 22d1e54