Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Including code for training from checkpoint #61

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion data_aug/contrastive_learning_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ def get_dataset(self, name, n_views):
transform=ContrastiveLearningViewGenerator(
self.get_simclr_pipeline_transform(96),
n_views),
download=True)}
download=True)
}

try:
dataset_fn = valid_datasets[name]
Expand Down
3 changes: 3 additions & 0 deletions exceptions/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,6 @@ class InvalidBackboneError(BaseSimCLRException):

class InvalidDatasetSelection(BaseSimCLRException):
    """Raised when the selected dataset is not a valid choice."""

class InvalidCheckpointPath(BaseSimCLRException):
    """Raised when a given checkpoint path is invalid."""
5 changes: 4 additions & 1 deletion run.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
help='model architecture: ' +
' | '.join(model_names) +
' (default: resnet50)')
parser.add_argument('-ckpt', default=None, type=str, metavar='CKPT',
help='the checkpoint to resume training')
parser.add_argument('-j', '--workers', default=12, type=int, metavar='N',
help='number of data loading workers (default: 32)')
parser.add_argument('--epochs', default=200, type=int, metavar='N',
Expand Down Expand Up @@ -73,6 +75,7 @@ def main():
num_workers=args.workers, pin_memory=True, drop_last=True)

model = ResNetSimCLR(base_model=args.arch, out_dim=args.out_dim)
ckpt = args.ckpt

optimizer = torch.optim.Adam(model.parameters(), args.lr, weight_decay=args.weight_decay)

Expand All @@ -81,7 +84,7 @@ def main():

# It’s a no-op if the 'gpu_index' argument is a negative integer or None.
with torch.cuda.device(args.gpu_index):
simclr = SimCLR(model=model, optimizer=optimizer, scheduler=scheduler, args=args)
simclr = SimCLR(model=model, optimizer=optimizer, scheduler=scheduler, args=args, ckpt=ckpt)
simclr.train(train_loader)


Expand Down
18 changes: 13 additions & 5 deletions simclr.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from torch.cuda.amp import GradScaler, autocast
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from utils import save_config_file, accuracy, save_checkpoint
from exceptions.exceptions import InvalidCheckpointPath
from utils import save_config_file, accuracy, save_checkpoint, load_checkpoint

torch.manual_seed(0)

Expand Down Expand Up @@ -60,12 +61,19 @@ def train(self, train_loader):

# save config file
save_config_file(self.writer.log_dir, self.args)

n_iter = 0
n_iter, start_epochs, end_epochs = 0, 0, self.args.epochs

if(self.args.ckpt):
try:
self.model, start_epochs = load_checkpoint(self.model, self.args.ckpt)
end_epochs = end_epochs + start_epochs
except:
InvalidCheckpointPath()

logging.info(f"Start SimCLR training for {self.args.epochs} epochs.")
logging.info(f"Training with gpu: {self.args.disable_cuda}.")
logging.info(f"Training with gpu: {self.args.gpu_index}.")

for epoch_counter in range(self.args.epochs):
for epoch_counter in range(start_epochs, end_epochs):
for images, _ in tqdm(train_loader):
images = torch.cat(images, dim=0)

Expand Down
8 changes: 8 additions & 0 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,14 @@
import yaml


def load_checkpoint(model, filepath):
    """Restore model weights and the last completed epoch from a checkpoint.

    Args:
        model: the model whose parameters are overwritten in place via
            ``load_state_dict``.
        filepath: path to a checkpoint file containing at least the keys
            ``'state_dict'`` and ``'epoch'``.

    Returns:
        A ``(model, epoch)`` tuple, where ``epoch`` is the value stored under
        the checkpoint's ``'epoch'`` key.

    Raises:
        FileNotFoundError: if ``filepath`` does not exist.  (The original
            fell through and implicitly returned ``None``, which surfaced as
            a confusing ``TypeError`` at the caller's tuple unpack.)
        KeyError: if the checkpoint lacks the expected keys.
    """
    if not os.path.exists(filepath):
        raise FileNotFoundError(f"checkpoint not found: {filepath}")
    # map_location='cpu' so a checkpoint saved on a GPU machine can still be
    # loaded on a CPU-only host; the caller moves the model to its device.
    ckpt = torch.load(filepath, map_location='cpu')
    model.load_state_dict(ckpt['state_dict'])
    return model, ckpt['epoch']


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
torch.save(state, filename)
if is_best:
Expand Down