Description
The idiomatic way to perform gradient scaling is something like this:
```python
preds = model(inputs)
loss = loss_fn(preds, targets)
scaler.scale(loss).backward()
```
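For context, here is the full recipe that snippet comes from (standard `GradScaler` usage; `model`, `optimizer`, `dataloader`, and `loss_fn` are assumed to exist):

```python
import torch

scaler = torch.cuda.amp.GradScaler()

for inputs, targets in dataloader:
    optimizer.zero_grad()
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        preds = model(inputs)
        loss = loss_fn(preds, targets)
    scaler.scale(loss).backward()  # backward must run on the *scaled* loss
    scaler.step(optimizer)         # unscales grads; skips the step on inf/nan
    scaler.update()                # adjusts the scale factor for the next iteration
```

The key point is that `scaler.scale(loss)` has to wrap the loss *before* `.backward()` is called, and `scaler.step(optimizer)` expects that to have happened in the same iteration.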
Given that the current PyTorch pipeline parallelism (PP) API performs the backward pass internally, I find it difficult to do gradient scaling under a PP regime. For example:
```python
if is_first_stage:
    pp_schedule.step(inputs)  # bwd performed internally
elif is_last_stage:
    losses = []
    pp_schedule.step(target=targets, losses=losses)  # bwd performed internally
else:
    pp_schedule.step()  # bwd performed internally

loss = (
    torch.mean(torch.stack(losses)).to(device)
    if is_last_stage
    else torch.tensor([-1.0], device=device)
)
# scaler.scale(loss).backward()  # <-- !? backward pass has already been performed
```
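The closest workaround I can see is to give up on `GradScaler`'s dynamic scaling and bake a fixed (static) loss scale into the `loss_fn` the schedule runs internally, then unscale this rank's gradients by hand before the optimizer step. A rough sketch, assuming the schedule is constructed with `loss_fn=...` and that `stage_module` / `optimizer` (illustrative names) refer to this rank's stage submodule and its optimizer:

```python
SCALE = 2.0 ** 16  # fixed loss scale; no dynamic growth/backoff, no inf/nan skipping

def scaled_loss_fn(preds, targets):
    # The schedule calls this per microbatch and backprops the result,
    # so scaling here happens before the internal backward pass.
    return loss_fn(preds, targets) * SCALE

# pp_schedule built with loss_fn=scaled_loss_fn instead of loss_fn

# ... pp_schedule.step(...) as above ...

# Unscale this rank's gradients in place before stepping the optimizer.
with torch.no_grad():
    for p in stage_module.parameters():
        if p.grad is not None:
            p.grad.div_(SCALE)
optimizer.step()
optimizer.zero_grad()
```

But that loses exactly what `GradScaler` provides (dynamic scale updates and skipping the step on inf/nan), which is why first-class support in the PP API would be preferable.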
Is there currently a good way to do gradient scaling with Pipeline Parallelism? And if not, will the Pipeline Parallelism API support gradient scaling in the near-term future?