Commit 4cc60ae

Separate forward and backward compilation for default partition
ghstack-source-id: 4de63f2aff78e0575fc342e13688308c542aa62f
Pull Request resolved: #856
1 parent 130582c commit 4cc60ae
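
The gist of the change, as a minimal standalone sketch (the helper make_lazily_compiled_fn and the toy compilers below are hypothetical, not the actual functorch code): the forward graph is still compiled on the first forward call, but for the default partition the backward graph is now compiled lazily, on the first backward call, with the needed state shared through nonlocal closure variables.

import torch


def make_lazily_compiled_fn(fn, fw_compiler, bw_compiler):
    # State shared between forward and backward via nonlocal, mirroring
    # compiled_fw / compiled_bw in create_aot_autograd_function below.
    compiled_fw = None
    compiled_bw = None

    class CompiledFn(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            nonlocal compiled_fw
            if compiled_fw is None:
                compiled_fw = fw_compiler(fn)   # compiled eagerly, as before
            ctx.save_for_backward(x)
            return compiled_fw(x)

        @staticmethod
        def backward(ctx, grad_out):
            nonlocal compiled_bw
            if compiled_bw is None:
                compiled_bw = bw_compiler(fn)   # now deferred to the first backward
            (x,) = ctx.saved_tensors
            return compiled_bw(x, grad_out)

    return CompiledFn.apply


def fw_compiler(f):
    print("compiling forward")
    return f                                    # identity "compiler" for the toy example


def bw_compiler(f):
    print("compiling backward")
    return lambda x, g: g * torch.cos(x)        # hand-written backward for torch.sin


sin = make_lazily_compiled_fn(torch.sin, fw_compiler, bw_compiler)
y = sin(torch.randn(3, requires_grad=True))    # prints "compiling forward" only
y.sum().backward()                              # "compiling backward" happens here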

File tree

functorch/_src/aot_autograd.py

1 file changed: +23 -16 lines

functorch/_src/aot_autograd.py (+23 -16)
@@ -54,7 +54,7 @@ def _dict_unflatten(values: List[Any], context: Context) -> Dict[Any, Any]:

 def create_joint_forward_backward(fn):
     def joint_forward_backward(
-        primals: List[Any], tangents: List[Any]
+        primals: List[Any], cotangents: List[Any]
     ) -> Tuple[List[Any], List[Any]]:
         # Call the forward pass
         outs = fn(*primals)
@@ -68,20 +68,20 @@ def joint_forward_backward(
                 grad_primals.append(p)

         # Get the outputs that need gradients
-        assert len(tangents) == len(outs)
+        assert len(cotangents) == len(outs)
         needed_outs = []
-        needed_tangents = []
-        for out, tangent in zip(outs, tangents):
+        needed_cotangents = []
+        for out, cotangent in zip(outs, cotangents):
             if isinstance(out, Tensor) and out.requires_grad:
                 needed_outs.append(out)
-                needed_tangents.append(tangent)
+                needed_cotangents.append(cotangent)
         backward_out = []
         # Call the backwards pass
         if grad_primals:
             backward_out = torch.autograd.grad(
                 needed_outs,
                 grad_primals,
-                grad_outputs=needed_tangents,
+                grad_outputs=needed_cotangents,
                 allow_unused=True,
             )
         backward_out_iter = iter(backward_out)
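
For readers less used to the terminology, the renamed cotangents are the grad_outputs of a vector-Jacobian product. A minimal, standalone illustration (toy function, not the traced joint graph) of what joint_forward_backward computes from primals and cotangents:

import torch

def toy_fn(a, b):
    return a * b + a.sin()

primals = [torch.randn(3, requires_grad=True), torch.randn(3, requires_grad=True)]
cotangents = [torch.ones(3)]            # one cotangent per forward output

outs = [toy_fn(*primals)]               # forward pass
grads = torch.autograd.grad(outs, primals, grad_outputs=cotangents, allow_unused=True)
# grads[i] is the vector-Jacobian product w.r.t. primals[i]; the joint graph
# returns (outs, grads), which the partitioner then splits into fw/bw modules.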
@@ -140,12 +140,14 @@ def create_aot_autograd_function(
     compiled_fw = None
     compiled_bw = None
     num_outs = None
-
+    joint_inputs = None
+    fw_outs = None
+    aot_decompositions = {**aot_autograd_decompositions, **decompositions}
     class CompiledFunction(torch.autograd.Function):
         @staticmethod
         @disable_torchdynamo
         def forward(ctx, *flat_tensor_args):
-            nonlocal compiled_fw, compiled_bw, num_outs
+            nonlocal compiled_fw, num_outs, joint_inputs, fw_outs
             if compiled_fw is None:
                 with torch.set_grad_enabled(grad_state):
                     out = flat_fn(*flat_tensor_args)
@@ -159,29 +161,34 @@ def forward(ctx, *flat_tensor_args):
                     num_outs = 1

                 joint_inputs = (flat_tensor_args, out)
-                aot_decompositions = {**aot_autograd_decompositions, **decompositions}
+                # Need it because autograd.Function disables grad in forward
                 with torch.set_grad_enabled(grad_state):
                     fx_g = make_fx(joint_forward_backward, aot_decompositions)(
                         *joint_inputs
                     )
                 fw_module, bw_module = partition_fn(fx_g, joint_inputs)
-                # print(fw_module.code, bw_module.code)

                 compiled_fw = fw_compiler(fw_module, flat_tensor_args)
                 fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))
-
-                bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs]
-                compiled_bw = bw_compiler(bw_module, bw_args)
+                if partition_fn is default_partition:
+                    nonlocal compiled_bw
+                    bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs]
+                    compiled_bw = bw_compiler(bw_module, bw_args)
             else:
                 fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))
             ctx.save_for_backward(*fw_outs[num_outs:])
             return tuple(fw_outs[0:num_outs])

         @staticmethod
         @disable_torchdynamo
-        def backward(ctx, *flat_args):
-            contiguous_args = [t.contiguous() for t in flat_args]
-            # contiguous_args = [t for t in flat_args]
+        def backward(ctx, *flat_grad_outs):
+            nonlocal compiled_bw
+            contiguous_args = [t.contiguous() for t in flat_grad_outs]
+            if compiled_bw is None:
+                with torch.set_grad_enabled(grad_state):
+                    fx_g = make_fx(joint_forward_backward, aot_decompositions)(joint_inputs[0], contiguous_args)
+                fw_module, bw_module = partition_fn(fx_g, joint_inputs)
+                compiled_bw = bw_compiler(bw_module, fw_outs[num_outs:] + contiguous_args)
             out = normalize_as_list(compiled_bw(*ctx.saved_tensors, *contiguous_args))
             return tuple(out)
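
A hedged usage sketch of the observable effect (assuming the functorch.compile.aot_function entry point with fw_compiler/bw_compiler keyword arguments from this era of the repo; treat the exact signature as an assumption): with the default partition, the backward compiler should now fire on the first .backward() call rather than during the first forward call.

import torch
from functorch.compile import aot_function

def print_compile(name):
    def compiler(fx_module, example_inputs):
        print(f"compiling {name} graph")
        return fx_module                # run the traced graph as-is
    return compiler

def f(x):
    return x.sin().sum()

compiled_f = aot_function(f, fw_compiler=print_compile("forward"),
                          bw_compiler=print_compile("backward"))

x = torch.randn(4, requires_grad=True)
loss = compiled_f(x)                    # prints "compiling forward graph"
loss.backward()                         # "compiling backward graph" now happens here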