
Commit e846968

Separate forward and backward compilation for default partition
ghstack-source-id: 4bc7082b2b8cfc1f980b303689ab193e17ccb6c7
Pull Request resolved: #856
1 parent 130582c commit e846968
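
Roughly, the diff below keeps compiling the forward graph eagerly on the first call to forward(), compiles the backward graph there only when partition_fn is default_partition, and otherwise defers tracing and compiling the backward graph to the first call to backward(), when the actual gradient outputs are available. What follows is a minimal standalone sketch of that lazy backward-compilation pattern, not the functorch code itself; compile_fw and compile_bw are hypothetical stand-ins for the fw_compiler/bw_compiler hooks.

# Minimal, self-contained sketch of compiling the backward pass lazily
# inside a torch.autograd.Function. compile_fw/compile_bw are hypothetical
# placeholders, not functorch APIs.
import torch


def compile_fw(fn):
    # Stand-in "compiler": the real code traces and compiles an FX graph.
    return fn


def compile_bw(fn):
    # Stand-in "compiler" for the backward graph.
    return fn


def make_compiled(fn):
    compiled_fw = None
    compiled_bw = None

    class Compiled(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            nonlocal compiled_fw
            if compiled_fw is None:
                # Forward is compiled eagerly, on the first call.
                compiled_fw = compile_fw(fn)
            ctx.save_for_backward(x)
            return compiled_fw(x)

        @staticmethod
        def backward(ctx, grad_out):
            nonlocal compiled_bw
            (x,) = ctx.saved_tensors
            if compiled_bw is None:
                # Backward is compiled lazily, only once gradients are
                # actually requested and the real grad_out exists.
                def bw(x, g):
                    # autograd.Function runs backward with grad disabled,
                    # so re-enable it to differentiate the forward fn.
                    with torch.enable_grad():
                        x = x.detach().requires_grad_(True)
                        out = fn(x)
                        return torch.autograd.grad(out, x, grad_outputs=g)[0]

                compiled_bw = compile_bw(bw)
            return compiled_bw(x, grad_out)

    return Compiled.apply


f = make_compiled(torch.sin)
x = torch.randn(3, requires_grad=True)
f(x).sum().backward()
print(torch.allclose(x.grad, torch.cos(x)))  # True
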

File tree

1 file changed: +23 −15 lines changed


functorch/_src/aot_autograd.py

+23 −15
@@ -54,7 +54,7 @@ def _dict_unflatten(values: List[Any], context: Context) -> Dict[Any, Any]:
 
 def create_joint_forward_backward(fn):
     def joint_forward_backward(
-        primals: List[Any], tangents: List[Any]
+        primals: List[Any], cotangents: List[Any]
     ) -> Tuple[List[Any], List[Any]]:
         # Call the forward pass
         outs = fn(*primals)
@@ -68,20 +68,20 @@ def joint_forward_backward(
                 grad_primals.append(p)
 
         # Get the outputs that need gradients
-        assert len(tangents) == len(outs)
+        assert len(cotangents) == len(outs)
         needed_outs = []
-        needed_tangents = []
-        for out, tangent in zip(outs, tangents):
+        needed_cotangents = []
+        for out, cotangent in zip(outs, cotangents):
             if isinstance(out, Tensor) and out.requires_grad:
                 needed_outs.append(out)
-                needed_tangents.append(tangent)
+                needed_cotangents.append(cotangent)
         backward_out = []
         # Call the backwards pass
         if grad_primals:
             backward_out = torch.autograd.grad(
                 needed_outs,
                 grad_primals,
-                grad_outputs=needed_tangents,
+                grad_outputs=needed_cotangents,
                 allow_unused=True,
             )
         backward_out_iter = iter(backward_out)
@@ -140,12 +140,15 @@ def create_aot_autograd_function(
     compiled_fw = None
     compiled_bw = None
     num_outs = None
+    joint_inputs = None
+    fw_outs = None
+    aot_decompositions = {**aot_autograd_decompositions, **decompositions}
 
     class CompiledFunction(torch.autograd.Function):
         @staticmethod
         @disable_torchdynamo
         def forward(ctx, *flat_tensor_args):
-            nonlocal compiled_fw, compiled_bw, num_outs
+            nonlocal compiled_fw, num_outs, joint_inputs, fw_outs
             if compiled_fw is None:
                 with torch.set_grad_enabled(grad_state):
                     out = flat_fn(*flat_tensor_args)
@@ -159,29 +162,34 @@ def forward(ctx, *flat_tensor_args):
                     num_outs = 1
 
                 joint_inputs = (flat_tensor_args, out)
-                aot_decompositions = {**aot_autograd_decompositions, **decompositions}
+                # Need it because autograd.Function disables grad in forward
                 with torch.set_grad_enabled(grad_state):
                     fx_g = make_fx(joint_forward_backward, aot_decompositions)(
                         *joint_inputs
                     )
                 fw_module, bw_module = partition_fn(fx_g, joint_inputs)
-                # print(fw_module.code, bw_module.code)
 
                 compiled_fw = fw_compiler(fw_module, flat_tensor_args)
                 fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))
-
-                bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs]
-                compiled_bw = bw_compiler(bw_module, bw_args)
+                if partition_fn is default_partition:
+                    nonlocal compiled_bw
+                    bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs]
+                    compiled_bw = bw_compiler(bw_module, bw_args)
             else:
                 fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))
             ctx.save_for_backward(*fw_outs[num_outs:])
             return tuple(fw_outs[0:num_outs])
 
         @staticmethod
         @disable_torchdynamo
-        def backward(ctx, *flat_args):
-            contiguous_args = [t.contiguous() for t in flat_args]
-            # contiguous_args = [t for t in flat_args]
+        def backward(ctx, *flat_grad_outs):
+            nonlocal compiled_bw
+            contiguous_args = [t.contiguous() for t in flat_grad_outs]
+            if compiled_bw is None:
+                with torch.set_grad_enabled(grad_state):
+                    fx_g = make_fx(joint_forward_backward, aot_decompositions)(joint_inputs[0], contiguous_args)
+                fw_module, bw_module = partition_fn(fx_g, joint_inputs)
+                compiled_bw = bw_compiler(bw_module, fw_outs[num_outs:] + contiguous_args)
             out = normalize_as_list(compiled_bw(*ctx.saved_tensors, *contiguous_args))
             return tuple(out)

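For context on the rename in the first two hunks: the values previously called tangents are the vectors fed to torch.autograd.grad as grad_outputs, i.e. the cotangents of a vector-Jacobian product. Below is a small standalone illustration of that usage, separate from the functorch internals above.

# Standalone illustration: "cotangents" are the grad_outputs passed to
# torch.autograd.grad, one per forward output, as in the hunks above.
import torch

primals = [torch.randn(4, requires_grad=True)]
outs = [primals[0] * 3.0]                  # toy forward pass
cotangents = [torch.ones_like(outs[0])]    # one cotangent per output

grads = torch.autograd.grad(
    outs,
    primals,
    grad_outputs=cotangents,  # the vector in the vector-Jacobian product
    allow_unused=True,
)
print(grads[0])  # tensor([3., 3., 3., 3.]): cotangent contracted with d(3*x)/dx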