|
| 1 | +import os |
| 2 | +from functools import cached_property |
| 3 | +from time import time |
| 4 | + |
| 5 | +import click |
| 6 | +import numpy as np |
| 7 | +import torch |
| 8 | +from torch.utils.tensorboard import SummaryWriter |
| 9 | +from yacs.config import CfgNode |
| 10 | + |
| 11 | +from neutorch.data.dataset import BoundaryAugmentationDataset |
| 12 | +from neutorch.model.io import log_tensor, save_chkpt |
| 13 | + |
| 14 | +from .base import SemanticTrainer, TrainerBase |
| 15 | + |
class BoundaryAugTrainer(SemanticTrainer):
    """Trainer specialized for boundary-augmentation experiments.

    Only the dataset construction differs from the parent: both splits are
    built with :class:`BoundaryAugmentationDataset`. The training loop itself
    is presumably inherited from ``SemanticTrainer`` — TODO confirm, since the
    base class is defined outside this file.
    """

    def __init__(self, cfg: CfgNode) -> None:
        """Initialize from a YACS configuration node.

        Args:
            cfg: full experiment configuration (must be a ``CfgNode``).
        """
        assert isinstance(cfg, CfgNode)
        super().__init__(cfg)
        # NOTE(review): the base class likely stores cfg already; kept here
        # because the cached properties below read ``self.cfg`` directly.
        self.cfg = cfg

    @cached_property
    def training_dataset(self):
        """Training split, built lazily once from the configuration."""
        return BoundaryAugmentationDataset.from_config(self.cfg, is_train=True)

    @cached_property
    def validation_dataset(self):
        """Validation split, built lazily once from the configuration."""
        return BoundaryAugmentationDataset.from_config(self.cfg, is_train=False)
| 86 | + |
@click.command()
@click.option('--config-file', '-c',
              type=click.Path(exists=True, dir_okay=False, file_okay=True,
                              readable=True, resolve_path=True),
              default='./config.yaml',
              help='configuration file containing all the parameters.')
def main(config_file: str):
    """CLI entry point: load the configuration and run training.

    Args:
        config_file: path to the YAML configuration file; must exist and be
            readable (enforced by ``click.Path``).
    """
    # Local import — presumably to avoid a circular import at module load
    # time; TODO confirm against neutorch.data.dataset.
    from neutorch.data.dataset import load_cfg
    cfg = load_cfg(config_file)
    trainer = BoundaryAugTrainer(cfg)
    trainer()
| 99 | + |
| 100 | + |
0 commit comments