
Gradient Scaling With Pipeline Parallelism #803

Open
@windsornguyen

Description

The idiomatic way to perform gradient scaling is something like this:

preds = model(inputs)
loss = loss_fn(preds, targets)
scaler.scale(loss).backward()
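For context, the full recipe also routes the optimizer step through the scaler, roughly like this (sketch only; model, optimizer, loss_fn, and dataloader stand in for whatever the training loop actually uses):

scaler = torch.amp.GradScaler("cuda")  # torch.cuda.amp.GradScaler() on older releases

for inputs, targets in dataloader:
    optimizer.zero_grad()
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        preds = model(inputs)
        loss = loss_fn(preds, targets)
    scaler.scale(loss).backward()  # backward runs on the scaled loss
    scaler.step(optimizer)         # unscales grads, skips the step on inf/NaN
    scaler.update()                # adjusts the scale factor for the next iteration

The important part is that the scaler sees the loss before backward() and the optimizer after it.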

Because the current PyTorch PP API performs the backward pass internally inside step(), it's unclear how to apply gradient scaling under a PP regime:

if is_first_stage:
    pp_schedule.step(inputs)                        # bwd performed internally
elif is_last_stage:
    losses = []
    pp_schedule.step(target=targets, losses=losses) # bwd performed internally
else:
    pp_schedule.step()                              # bwd performed internally

loss = (
    torch.mean(torch.stack(losses)).to(device)
    if is_last_stage
    else torch.tensor([-1.0], device=device)
)

# scaler.scale(loss).backward() <-- !? backward pass has already been performed
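The only workaround I can think of is static loss scaling: fold a fixed scale into the loss_fn the schedule was built with (as the target=/losses= arguments above suggest), then unscale the gradients by hand before the optimizer step. A rough, untested sketch, where stage_module and optimizer are placeholders for the local stage's module and its optimizer:

scale = 2.0 ** 16  # fixed loss scale

def scaled_loss_fn(preds, targets):
    # The schedule's internal backward then produces scaled gradients.
    return loss_fn(preds, targets) * scale

# Build the schedule with loss_fn=scaled_loss_fn, run pp_schedule.step(...) as above, then:

for p in stage_module.parameters():
    if p.grad is not None:
        p.grad.div_(scale)  # unscale before stepping
optimizer.step()

But that gives up GradScaler's dynamic scale adjustment and its inf/NaN step skipping, which is most of the point of using it.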

Is there currently a good way to do gradient scaling with Pipeline Parallelism? If not, is gradient-scaling support planned for the Pipeline Parallelism API in the near term?
