Error about using a grad transform with in-place operation is inconsistent with and without DDP #1112
Description
Hi,
I was using torch.func in PyTorch 2.0 to compute the Hessian-vector product (HVP) of a neural network. I first used torch.func.functional_call to define a functional version of the model, and then used torch.func.jvp and torch.func.grad to compute the HVP.
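A minimal sketch of what I mean (simplified; `model`, `x`, and `y` here are placeholders for my actual network and batch):

```python
import torch
from torch.func import functional_call, grad, jvp

# Placeholder model and data standing in for my real setup.
model = torch.nn.Linear(4, 1)
x, y = torch.randn(8, 4), torch.randn(8, 1)

params = dict(model.named_parameters())

def loss_fn(p):
    # Functional version of the model: parameters are explicit inputs.
    out = functional_call(model, p, (x,))
    return torch.nn.functional.mse_loss(out, y)

# Direction vector with the same structure as the parameters.
v = {k: torch.randn_like(t) for k, t in params.items()}

# Forward-over-reverse HVP: jvp of the gradient function.
_, hvp = jvp(grad(loss_fn), (params,), (v,))
```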
The above works when I use a single GPU without parallel processing. However, when I wrap the model with DistributedDataParallel (DDP), I get the following error:
*** RuntimeError: During a grad (vjp, jvp, grad, etc) transform, the function provided attempted to call in-place operation (aten::copy_) that would mutate a captured Tensor. This is not supported; please rewrite the function being transformed to explicitly accept the mutated Tensor(s) as inputs.
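For context, the DDP variant is set up roughly like this (again a simplified sketch; the process group is initialised elsewhere and `local_rank` is a placeholder):

```python
from torch.nn.parallel import DistributedDataParallel as DDP

# Process group assumed to be initialised already, e.g. via
# torch.distributed.init_process_group("nccl"); local_rank is a placeholder.
ddp_model = DDP(model.to(local_rank), device_ids=[local_rank])

# Parameter names are now prefixed with "module." by the DDP wrapper.
ddp_params = dict(ddp_model.named_parameters())

def ddp_loss_fn(p):
    out = functional_call(ddp_model, p, (x.to(local_rank),))
    return torch.nn.functional.mse_loss(out, y.to(local_rank))

v = {k: torch.randn_like(t) for k, t in ddp_params.items()}
_, hvp = jvp(grad(ddp_loss_fn), (ddp_params,), (v,))  # the RuntimeError above is raised here
```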
I am confused by this error: if there really were such an in-place operation (I couldn't find any in my model.forward() code), I'd expect the error to occur regardless of DDP. Given this inconsistent behaviour, can I still trust the HVP result computed without DDP?
My torch version is 2.0.0.dev20230119+cu117