Better device handling #1301

Open · wants to merge 1 commit into base: main
imagenet/main.py: 69 changes (24 additions, 45 deletions)
@@ -137,6 +137,7 @@ def main_worker(gpu, ngpus_per_node, args):
# For multiprocessing distributed training, rank needs to be the
# global rank among all the processes
args.rank = args.rank * ngpus_per_node + gpu
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.rank)
dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
world_size=args.world_size, rank=args.rank)
# create model
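Note: the added `os.environ['CUDA_VISIBLE_DEVICES'] = str(args.rank)` line pins each spawned worker to a single physical GPU by masking device visibility before the process group is created; inside that process the chosen GPU then appears as `cuda:0`. A minimal sketch of the masking behaviour (not part of this diff; the index `'2'` just stands in for a rank, and the mask only takes effect if CUDA has not yet been initialized in the process):

```python
import os

# Must run before the first CUDA call in this process; '2' stands in for args.rank.
os.environ['CUDA_VISIBLE_DEVICES'] = '2'

import torch

print(torch.cuda.device_count())       # 1 -- only physical GPU 2 is visible
x = torch.zeros(1, device='cuda')      # allocated on physical GPU 2, seen here as cuda:0
print(x.device)                        # cuda:0
```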
@@ -154,20 +155,14 @@ def main_worker(gpu, ngpus_per_node, args):
# should always set the single device scope, otherwise,
# DistributedDataParallel will use all available devices.
if torch.cuda.is_available():
model.cuda()
if args.gpu is not None:
torch.cuda.set_device(args.gpu)
model.cuda(args.gpu)
# When using a single GPU per process and per
# DistributedDataParallel, we need to divide the batch size
# ourselves based on the total number of GPUs of the current node.
args.batch_size = int(args.batch_size / ngpus_per_node)
args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
else:
model.cuda()
# DistributedDataParallel will divide and allocate batch_size to all
# available GPUs if device_ids are not set
model = torch.nn.parallel.DistributedDataParallel(model)
model = torch.nn.parallel.DistributedDataParallel(model)
elif args.gpu is not None and torch.cuda.is_available():
torch.cuda.set_device(args.gpu)
model = model.cuda(args.gpu)
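With visibility already restricted to one GPU per process, `DistributedDataParallel` no longer needs explicit `device_ids`, which is why the two DDP branches collapse into one. A self-contained single-node sketch of that pattern (toy model, hypothetical master address/port; not the PR's code):

```python
import os
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank, world_size):
    os.environ['CUDA_VISIBLE_DEVICES'] = str(rank)    # one visible GPU per process
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group('nccl', rank=rank, world_size=world_size)

    model = nn.Linear(8, 2).cuda()                    # cuda:0 == this rank's GPU
    ddp_model = torch.nn.parallel.DistributedDataParallel(model)  # no device_ids needed

    out = ddp_model(torch.randn(4, 8, device='cuda'))
    out.sum().backward()                              # gradients all-reduced across ranks
    dist.destroy_process_group()

if __name__ == '__main__':
    n = torch.cuda.device_count()
    mp.spawn(worker, args=(n,), nprocs=n)
```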
@@ -183,7 +178,7 @@ def main_worker(gpu, ngpus_per_node, args):
model = torch.nn.DataParallel(model).cuda()

if torch.cuda.is_available():
if args.gpu:
if args.gpu and not args.distributed:
device = torch.device('cuda:{}'.format(args.gpu))
else:
device = torch.device("cuda")
@@ -205,17 +200,11 @@ def main_worker(gpu, ngpus_per_node, args):
if args.resume:
if os.path.isfile(args.resume):
print("=> loading checkpoint '{}'".format(args.resume))
if args.gpu is None:
checkpoint = torch.load(args.resume)
elif torch.cuda.is_available():
# Map model to be loaded to specified single gpu.
loc = 'cuda:{}'.format(args.gpu)
checkpoint = torch.load(args.resume, map_location=loc)
checkpoint = torch.load(args.resume, map_location=device)
args.start_epoch = checkpoint['epoch']
best_acc1 = checkpoint['best_acc1']
if args.gpu is not None:
# best_acc1 may be from a checkpoint from a different GPU
best_acc1 = best_acc1.to(args.gpu)
# best_acc1 may be from a checkpoint from a different GPU
best_acc1 = best_acc1.to(device=device)
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
scheduler.load_state_dict(checkpoint['scheduler'])
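Loading the checkpoint with `map_location=device` replaces the per-GPU special cases: whatever device the process resolved earlier, the tensors land there directly, and `best_acc1` is moved the same way. A small stand-alone sketch of the idea, using a throwaway checkpoint path:

```python
import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = nn.Linear(4, 2).to(device)

# Save a checkpoint (in real use it may have been written on a different GPU)...
torch.save({'state_dict': model.state_dict(), 'best_acc1': torch.tensor(71.3)},
           'example_checkpoint.pth')

# ...and load it straight onto this process's device, whatever it is.
checkpoint = torch.load('example_checkpoint.pth', map_location=device)
model.load_state_dict(checkpoint['state_dict'])
best_acc1 = checkpoint['best_acc1'].to(device=device)
```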
@@ -270,7 +259,7 @@ def main_worker(gpu, ngpus_per_node, args):
num_workers=args.workers, pin_memory=True, sampler=val_sampler)

if args.evaluate:
validate(val_loader, model, criterion, args)
validate(val_loader, model, criterion, device, args)
return

for epoch in range(args.start_epoch, args.epochs):
@@ -281,7 +270,7 @@ def main_worker(gpu, ngpus_per_node, args):
train(train_loader, model, criterion, optimizer, epoch, device, args)

# evaluate on validation set
acc1 = validate(val_loader, model, criterion, args)
acc1 = validate(val_loader, model, criterion, device, args)

scheduler.step()

@@ -302,11 +291,11 @@ def main_worker(gpu, ngpus_per_node, args):


def train(train_loader, model, criterion, optimizer, epoch, device, args):
batch_time = AverageMeter('Time', ':6.3f')
data_time = AverageMeter('Data', ':6.3f')
losses = AverageMeter('Loss', ':.4e')
top1 = AverageMeter('Acc@1', ':6.2f')
top5 = AverageMeter('Acc@5', ':6.2f')
batch_time = AverageMeter('Time', device, ':6.3f')
data_time = AverageMeter('Data', device, ':6.3f')
losses = AverageMeter('Loss', device, ':.4e')
top1 = AverageMeter('Acc@1', device, ':6.2f')
top5 = AverageMeter('Acc@5', device, ':6.2f')
progress = ProgressMeter(
len(train_loader),
[batch_time, data_time, losses, top1, top5],
@@ -347,20 +336,15 @@ def train(train_loader, model, criterion, optimizer, epoch, device, args):
progress.display(i + 1)


def validate(val_loader, model, criterion, args):
def validate(val_loader, model, criterion, device, args):

def run_validate(loader, base_progress=0):
with torch.no_grad():
end = time.time()
for i, (images, target) in enumerate(loader):
i = base_progress + i
if args.gpu is not None and torch.cuda.is_available():
images = images.cuda(args.gpu, non_blocking=True)
if torch.backends.mps.is_available():
images = images.to('mps')
target = target.to('mps')
if torch.cuda.is_available():
target = target.cuda(args.gpu, non_blocking=True)
images = images.to(device, non_blocking=True)
target = target.to(device, non_blocking=True)

# compute output
output = model(images)
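In `run_validate` the per-backend branches collapse into two `.to(device, non_blocking=True)` calls, which work unchanged on CUDA, MPS, and CPU. A sketch of the same transfer pattern with a toy loader (dataset shapes are arbitrary); `non_blocking=True` only makes the copy asynchronous when the source is in pinned host memory and the target is a GPU, and is otherwise harmless:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

dataset = TensorDataset(torch.randn(64, 3, 8, 8), torch.randint(0, 10, (64,)))
loader = DataLoader(dataset, batch_size=16,
                    pin_memory=torch.cuda.is_available())  # enables async host-to-device copies

for images, target in loader:
    images = images.to(device, non_blocking=True)
    target = target.to(device, non_blocking=True)
    # ... forward pass, loss, accuracy ...
```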
@@ -379,10 +363,10 @@ def run_validate(loader, base_progress=0):
if i % args.print_freq == 0:
progress.display(i + 1)

batch_time = AverageMeter('Time', ':6.3f', Summary.NONE)
losses = AverageMeter('Loss', ':.4e', Summary.NONE)
top1 = AverageMeter('Acc@1', ':6.2f', Summary.AVERAGE)
top5 = AverageMeter('Acc@5', ':6.2f', Summary.AVERAGE)
batch_time = AverageMeter('Time', device, ':6.3f', Summary.NONE)
losses = AverageMeter('Loss', device, ':.4e', Summary.NONE)
top1 = AverageMeter('Acc@1', device, ':6.2f', Summary.AVERAGE)
top5 = AverageMeter('Acc@5', device, ':6.2f', Summary.AVERAGE)
progress = ProgressMeter(
len(val_loader) + (args.distributed and (len(val_loader.sampler) * args.world_size < len(val_loader.dataset))),
[batch_time, losses, top1, top5],
@@ -422,8 +406,9 @@ class Summary(Enum):

class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self, name, fmt=':f', summary_type=Summary.AVERAGE):
def __init__(self, name, device, fmt=':f', summary_type=Summary.AVERAGE):
self.name = name
self.device = device
self.fmt = fmt
self.summary_type = summary_type
self.reset()
@@ -441,13 +426,7 @@ def update(self, val, n=1):
self.avg = self.sum / self.count

def all_reduce(self):
if torch.cuda.is_available():
device = torch.device("cuda")
elif torch.backends.mps.is_available():
device = torch.device("mps")
else:
device = torch.device("cpu")
total = torch.tensor([self.sum, self.count], dtype=torch.float32, device=device)
total = torch.tensor([self.sum, self.count], dtype=torch.float32, device=self.device)
dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False)
self.sum, self.count = total.tolist()
self.avg = self.sum / self.count
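Because each meter now remembers its device, `all_reduce` no longer has to re-derive one from `torch.cuda`/`torch.backends.mps`; it simply builds the reduction tensor on `self.device`. A sketch of that aggregation step in isolation (the `meter` in the comment is hypothetical, and an initialized process group is assumed):

```python
import torch
import torch.distributed as dist

def aggregate(local_sum, local_count, device):
    # Pack sum and count into one tensor so a single all_reduce covers both.
    total = torch.tensor([local_sum, local_count], dtype=torch.float32, device=device)
    dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False)
    global_sum, global_count = total.tolist()
    return global_sum, global_count, global_sum / global_count

# e.g. inside validate(): s, c, avg = aggregate(meter.sum, meter.count, meter.device)
```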