
Error about using a grad transform with in-place operation is inconsistent with and without DDP #1112

Open
@XuchanBao

Description

Hi,

I was using torch.func in PyTorch 2.0 to compute the Hessian-vector product (HVP) of a neural network.

I first used torch.func.functional_call to define a functional version of the neural network model, and then used torch.func.jvp and torch.func.grad to compute the HVP.
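For context, here is a minimal sketch of the pattern I'm using (the model, loss, and tensor names below are illustrative stand-ins, not my actual code):

```python
import torch
from torch.func import functional_call, grad, jvp

# Toy stand-ins for the real model and data.
model = torch.nn.Linear(4, 1)
x, y = torch.randn(8, 4), torch.randn(8, 1)
params = dict(model.named_parameters())
# Tangent "vector" with the same pytree structure as the parameters.
v = {k: torch.randn_like(p) for k, p in params.items()}

def loss_fn(params):
    # Functional version of the model: parameters are passed in explicitly.
    pred = functional_call(model, params, (x,))
    return torch.nn.functional.mse_loss(pred, y)

# Forward-over-reverse HVP: jvp of the gradient function in direction v.
_, hvp = jvp(grad(loss_fn), (params,), (v,))
```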

The above works when using a single GPU without parallel processing. However, when I wrapped the model with DistributedDataParallel (DDP), I got the following error:

*** RuntimeError: During a grad (vjp, jvp, grad, etc) transform, the function provided attempted to call in-place operation (aten::copy_) that would mutate a captured Tensor. This is not supported; please rewrite the function being transformed to explicitly accept the mutated Tensor(s) as inputs.
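Roughly, the DDP version is the same sketch with the model wrapped (again illustrative; it assumes the process group has already been initialized and that functional_call is applied to the DDP-wrapped module, which matches my setup as far as I can tell):

```python
import torch
from torch.func import functional_call, grad, jvp
from torch.nn.parallel import DistributedDataParallel as DDP

# Illustrative only; assumes torch.distributed.init_process_group(...) has
# already been called and a GPU is assigned to this rank.
device = torch.device("cuda")
ddp_model = DDP(torch.nn.Linear(4, 1).to(device))
x, y = torch.randn(8, 4, device=device), torch.randn(8, 1, device=device)
params = dict(ddp_model.named_parameters())
v = {k: torch.randn_like(p) for k, p in params.items()}

def loss_fn(params):
    # Same functional call as before, but through the DDP wrapper.
    pred = functional_call(ddp_model, params, (x,))
    return torch.nn.functional.mse_loss(pred, y)

# This call raises the RuntimeError above when the model is wrapped in DDP,
# but runs fine on the bare module.
_, hvp = jvp(grad(loss_fn), (params,), (v,))
```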

I am confused by this error, because if there really were such an in-place operation (I couldn't find one in my model.forward() code), I'd expect the error to occur regardless of DDP. Given this inconsistent behaviour, can I still trust the HVP result computed without DDP?

My torch version is 2.0.0.dev20230119+cu117.
