Commit e1301b9

Separate forward and backward compilation for default partition
ghstack-source-id: 3197eae34a9003563e0b56235540c17a0ba0617c
Pull Request resolved: #856
1 parent 130582c commit e1301b9


functorch/_src/aot_autograd.py (+18 -10)
@@ -53,6 +53,7 @@ def _dict_unflatten(values: List[Any], context: Context) -> Dict[Any, Any]:
 
 
 def create_joint_forward_backward(fn):
+    # tangents are just grad_outs/cotangents (wrong naming)
     def joint_forward_backward(
         primals: List[Any], tangents: List[Any]
     ) -> Tuple[List[Any], List[Any]]:
@@ -140,12 +141,14 @@ def create_aot_autograd_function(
     compiled_fw = None
     compiled_bw = None
     num_outs = None
-
+    joint_inputs = None
+    fw_outs = None
+    aot_decompositions = {**aot_autograd_decompositions, **decompositions}
     class CompiledFunction(torch.autograd.Function):
         @staticmethod
         @disable_torchdynamo
         def forward(ctx, *flat_tensor_args):
-            nonlocal compiled_fw, compiled_bw, num_outs
+            nonlocal compiled_fw, num_outs, joint_inputs, fw_outs
             if compiled_fw is None:
                 with torch.set_grad_enabled(grad_state):
                     out = flat_fn(*flat_tensor_args)
@@ -159,29 +162,34 @@ def forward(ctx, *flat_tensor_args):
                     num_outs = 1
 
                 joint_inputs = (flat_tensor_args, out)
-                aot_decompositions = {**aot_autograd_decompositions, **decompositions}
+                # Need it because autograd.Function disables grad in forward
                 with torch.set_grad_enabled(grad_state):
                     fx_g = make_fx(joint_forward_backward, aot_decompositions)(
                         *joint_inputs
                     )
                 fw_module, bw_module = partition_fn(fx_g, joint_inputs)
-                # print(fw_module.code, bw_module.code)
 
                 compiled_fw = fw_compiler(fw_module, flat_tensor_args)
                 fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))
-
-                bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs]
-                compiled_bw = bw_compiler(bw_module, bw_args)
+                if partition_fn is default_partition:
+                    nonlocal compiled_bw
+                    bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs]
+                    compiled_bw = bw_compiler(bw_module, bw_args)
             else:
                 fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))
             ctx.save_for_backward(*fw_outs[num_outs:])
             return tuple(fw_outs[0:num_outs])
 
         @staticmethod
         @disable_torchdynamo
-        def backward(ctx, *flat_args):
-            contiguous_args = [t.contiguous() for t in flat_args]
-            # contiguous_args = [t for t in flat_args]
+        def backward(ctx, *flat_grad_outs):
+            nonlocal compiled_bw
+            contiguous_args = [t.contiguous() for t in flat_grad_outs]
+            if compiled_bw is None:
+                with torch.set_grad_enabled(grad_state):
+                    fx_g = make_fx(joint_forward_backward, aot_decompositions)(joint_inputs[0], contiguous_args)
+                fw_module, bw_module = partition_fn(fx_g, joint_inputs)
+                compiled_bw = bw_compiler(bw_module, fw_outs[num_outs:] + contiguous_args)
             out = normalize_as_list(compiled_bw(*ctx.saved_tensors, *contiguous_args))
             return tuple(out)
 
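For context, this change keeps eager backward compilation only when partition_fn is the default partitioner; with any other partitioner, the joint graph is re-traced and the backward is compiled lazily on the first backward() call, once the real (contiguous) incoming gradients are available. The snippet below is a minimal, self-contained sketch of that lazy-compilation pattern in plain PyTorch; make_lazily_compiled and its stub fw_compiler/bw_compiler arguments are illustrative names, not part of the functorch API.

# Sketch: compile the forward artifact on the first forward() call and the
# backward artifact only on the first backward() call, caching both via
# nonlocal variables (the same caching scheme the diff above uses).
import torch

def make_lazily_compiled(fn, fw_compiler, bw_compiler):
    compiled_fw = None   # built on first forward
    compiled_bw = None   # built on first backward

    class LazyFunction(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            nonlocal compiled_fw
            if compiled_fw is None:
                compiled_fw = fw_compiler(fn)        # compile forward once
            out = compiled_fw(x)
            ctx.save_for_backward(x)
            return out

        @staticmethod
        def backward(ctx, grad_out):
            nonlocal compiled_bw
            (x,) = ctx.saved_tensors
            if compiled_bw is None:
                # Compile the backward lazily, now that the actual incoming
                # gradient (and its memory layout) is known.
                compiled_bw = bw_compiler(fn)
            return compiled_bw(x, grad_out.contiguous())

    return LazyFunction.apply

# Usage: identity/handwritten "compilers" stand in for fw_compiler/bw_compiler.
square = make_lazily_compiled(
    lambda x: x * x,
    fw_compiler=lambda f: f,
    bw_compiler=lambda f: (lambda x, g: 2 * x * g),  # d/dx x**2 = 2x
)
t = torch.randn(3, requires_grad=True)
square(t).sum().backward()
print(t.grad)  # equals 2 * t

The trade-off mirrored here: deferring backward compilation avoids guessing the backward inputs from the forward outputs, at the cost of a second trace/compile the first time gradients flow.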
