
Commit 4d7096b

Separate forward and backward compilation for default partition

ghstack-source-id: c24ee1b8c252d9aebe99b0beb9139dd3eb223dd4
Pull Request resolved: #856

1 parent 130582c commit 4d7096b
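In short: with default_partition, forward() now compiles and runs only the forward graph and saves the intermediates together with the inputs; backward() re-traces the joint graph, partitions it again, and dispatches through a freshly created CompiledFunction so the backward itself stays differentiable. The sketch below is not part of the commit; it only shows how this path might be exercised, assuming the functorch.compile API of this era and a hypothetical pass-through compiler:

import torch
from functorch.compile import aot_function, default_partition

# Hypothetical pass-through "compiler": return the traced FX module unchanged.
def passthrough_compiler(fx_module, example_inputs):
    return fx_module

def f(a, b):
    return (a * b).sin().sum()

a = torch.randn(4, requires_grad=True)
b = torch.randn(4, requires_grad=True)

# With default_partition, only the forward graph is compiled eagerly here;
# the backward graph is traced and partitioned lazily when .backward() runs.
compiled_f = aot_function(f, passthrough_compiler, passthrough_compiler,
                          partition_fn=default_partition)

compiled_f(a, b).backward()
print(a.grad, b.grad)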

File tree: 2 files changed (+107 -25 lines)

functorch/_src/aot_autograd.py (+72 -19)
@@ -54,7 +54,7 @@ def _dict_unflatten(values: List[Any], context: Context) -> Dict[Any, Any]:
 
 def create_joint_forward_backward(fn):
     def joint_forward_backward(
-        primals: List[Any], tangents: List[Any]
+        primals: List[Any], cotangents: List[Any]
     ) -> Tuple[List[Any], List[Any]]:
         # Call the forward pass
         outs = fn(*primals)
@@ -68,21 +68,21 @@ def joint_forward_backward(
                 grad_primals.append(p)
 
         # Get the outputs that need gradients
-        assert len(tangents) == len(outs)
+        assert len(cotangents) == len(outs)
         needed_outs = []
-        needed_tangents = []
-        for out, tangent in zip(outs, tangents):
+        needed_cotangents = []
+        for out, cotangent in zip(outs, cotangents):
             if isinstance(out, Tensor) and out.requires_grad:
                 needed_outs.append(out)
-                needed_tangents.append(tangent)
+                needed_cotangents.append(cotangent)
         backward_out = []
         # Call the backwards pass
         if grad_primals:
             backward_out = torch.autograd.grad(
                 needed_outs,
                 grad_primals,
-                grad_outputs=needed_tangents,
-                allow_unused=True,
+                grad_outputs=needed_cotangents,
+                allow_unused=True
             )
         backward_out_iter = iter(backward_out)
         return outs, [
@@ -140,12 +140,13 @@ def create_aot_autograd_function(
     compiled_fw = None
     compiled_bw = None
     num_outs = None
+    aot_decompositions = {**aot_autograd_decompositions, **decompositions}
 
     class CompiledFunction(torch.autograd.Function):
         @staticmethod
        @disable_torchdynamo
         def forward(ctx, *flat_tensor_args):
-            nonlocal compiled_fw, compiled_bw, num_outs
+            nonlocal compiled_fw, num_outs
             if compiled_fw is None:
                 with torch.set_grad_enabled(grad_state):
                     out = flat_fn(*flat_tensor_args)
@@ -159,31 +160,83 @@ def forward(ctx, *flat_tensor_args):
                     num_outs = 1
 
                 joint_inputs = (flat_tensor_args, out)
-                aot_decompositions = {**aot_autograd_decompositions, **decompositions}
+                # Need it because autograd.Function disables grad in forward
                 with torch.set_grad_enabled(grad_state):
                     fx_g = make_fx(joint_forward_backward, aot_decompositions)(
                         *joint_inputs
                     )
                 fw_module, bw_module = partition_fn(fx_g, joint_inputs)
-                # print(fw_module.code, bw_module.code)
 
                 compiled_fw = fw_compiler(fw_module, flat_tensor_args)
                 fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))
-
-                bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs]
-                compiled_bw = bw_compiler(bw_module, bw_args)
+                if partition_fn is default_partition:
+                    print("ENTERING default_partition")
+                    ctx.num_intermediate = len(fw_outs[num_outs:])
+                    ctx.num_inputs = len(flat_tensor_args)
+                    to_be_saved = fw_outs[num_outs:] + list(flat_tensor_args) + out
+                    print("fw outs: ", fw_outs, "-------")
+                    ctx.save_for_backward(*to_be_saved)
+                    ctx.fwd_graph = fw_module.code
+                else:
+                    nonlocal compiled_bw
+                    bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs]
+                    compiled_bw = bw_compiler(bw_module, bw_args)
+                    ctx.save_for_backward(*fw_outs[num_outs:])
             else:
                 fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))
-                ctx.save_for_backward(*fw_outs[num_outs:])
+                if partition_fn is default_partition:
+                    with torch.set_grad_enabled(grad_state):
+                        out = flat_fn(*flat_tensor_args)
+                    out = pytree.tree_map(
+                        lambda x: x.detach().contiguous() if isinstance(x, Tensor) else x, out
+                    )
+                    ctx.num_intermediate = len(fw_outs[num_outs:])
+                    ctx.num_inputs = len(flat_tensor_args)
+                    to_be_saved = fw_outs[num_outs:] + list(flat_tensor_args) + out
+                    ctx.save_for_backward(*to_be_saved)
+                else:
+                    ctx.save_for_backward(*fw_outs[num_outs:])
             return tuple(fw_outs[0:num_outs])
 
         @staticmethod
         @disable_torchdynamo
-        def backward(ctx, *flat_args):
-            contiguous_args = [t.contiguous() for t in flat_args]
-            # contiguous_args = [t for t in flat_args]
-            out = normalize_as_list(compiled_bw(*ctx.saved_tensors, *contiguous_args))
-            return tuple(out)
+        def backward(ctx, *flat_grad_outs):
+            print(flat_grad_outs)
+            contiguous_args = [t.contiguous() for t in flat_grad_outs]
+            if compiled_bw is None:
+                assert partition_fn is default_partition
+                with torch.set_grad_enabled(grad_state):
+                    inputs = ctx.saved_tensors[ctx.num_intermediate:ctx.num_intermediate+ctx.num_inputs]
+                    fx_g = make_fx(joint_forward_backward, aot_decompositions)(inputs, contiguous_args)
+                    # assert that the forward graph generated here is the same
+                    # if it's specified that the user might want to calculate double backwards
+                    fw_module, bw_module = partition_fn(fx_g, ctx.saved_tensors[ctx.num_intermediate:])
+                    print(fw_module.code)
+                    print(ctx.fwd_graph)
+                    assert fw_module.code == ctx.fwd_graph
+                func_code = bw_module.code.split('self, ')
+                # print(func_code[0] + func_code[1])
+                exec(func_code[0] + func_code[1], globals())
+                f = create_aot_autograd_function(forward, bw_compiler, bw_compiler, partition_fn, aot_decompositions, grad_state)
+                # print(bw_module.code, *ctx.saved_tensors, contiguous_args)
+                # print(*ctx.saved_tensors[:ctx.num_intermediate], *contiguous_args)
+                # print(*ctx.saved_tensors[ctx.num_intermediate:], *contiguous_args)
+                return f.apply(*ctx.saved_tensors[:ctx.num_intermediate], *contiguous_args)
+            else:
+                assert not torch.is_grad_enabled()
+                out = normalize_as_list(compiled_bw(*ctx.saved_tensors, *contiguous_args))
+                return tuple(out)
+            # nonlocal compiled_bw
+            # contiguous_args = [t.contiguous() for t in flat_grad_outs]
+            # if compiled_bw is None:
+            #     with torch.set_grad_enabled(grad_state):
+            #         fx_g = make_fx(joint_forward_backward, aot_decompositions)(joint_inputs[0], contiguous_args)
+            #         # assert that the forward graph generated here is the same
+            #         # if it's specified that the user might want to calculate double backwards
+            #         fw_module, bw_module = partition_fn(fx_g, joint_inputs)
+            #     compiled_bw = bw_compiler(bw_module, fw_outs[num_outs:] + contiguous_args)
+            #     out = normalize_as_list(compiled_bw(*ctx.saved_tensors, *contiguous_args))
+            # return tuple(out)
 
     return CompiledFunction

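The tangents to cotangents rename above reflects what these values are: the grad_outputs of a vector-Jacobian product. A standalone sketch (not functorch code; names are hypothetical) of the pattern that joint_forward_backward traces:

import torch

def joint_fw_bw(fn, primals, cotangents):
    # Forward pass, then a VJP: the cotangents play the role of grad_outputs.
    outs = fn(*primals)
    outs = outs if isinstance(outs, (list, tuple)) else (outs,)
    grads = torch.autograd.grad(
        outs, primals, grad_outputs=cotangents, allow_unused=True
    )
    return outs, grads

x = torch.randn(3, requires_grad=True)
y = torch.randn(3, requires_grad=True)
outs, grads = joint_fw_bw(lambda a, b: a * b, (x, y), (torch.ones(3),))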
test/test_pythonkey.py (+35 -6)
@@ -246,25 +246,54 @@ def f(args, kwargs):
 
 def _outs_and_grads(fn, inps):
     outs = fn(*inps)
+    diff_outs = []
     for out in pytree.tree_flatten(outs)[0]:
         if isinstance(out, torch.Tensor) and out.requires_grad:
-            out.sum().backward(retain_graph=True)
-    grads = [inp.grad for inp in pytree.tree_flatten(inps)[0]]
-    for inp in pytree.tree_flatten(inps)[0]:
-        inp.grad = None
+            diff_outs.append(out)
+    def full_reduce(outs):
+        res = 0
+        for out in outs:
+            res = res + out.sum()
+        return res
+    print(inps)
+    grads = torch.autograd.grad(full_reduce(diff_outs), pytree.tree_flatten(inps)[0], create_graph=True)
     return outs, grads
 
+def _outs_and_grads_and_grad_grads(fn, inps):
+    outs = fn(*inps)
+    diff_outs = []
+    diff_inps = []
+    for out in pytree.tree_flatten(outs)[0]:
+        if isinstance(out, torch.Tensor) and out.requires_grad:
+            diff_outs.append(out)
+    for inp in pytree.tree_flatten(inps)[0]:
+        if isinstance(inp, torch.Tensor) and inp.requires_grad:
+            diff_inps.append(inp)
+    def full_reduce(outs):
+        res = 0
+        for out in outs:
+            res = res + out.sum()
+        return res
+    grads = torch.autograd.grad(full_reduce(diff_outs), diff_inps, create_graph=True)
+    print("grads: ", grads)
+    diff_grads = []
+    for grad_ in grads:
+        if isinstance(grad_, torch.Tensor) and grad_.requires_grad:
+            diff_grads.append(grad_)
+    grad_grads = torch.autograd.grad(full_reduce(diff_grads), diff_inps)
+    return outs, grads, grad_grads
 
 class TestAOTAutograd(TestCase):
     def verify_aot_autograd(self, f, inp):
         if isinstance(f, nn.Module):
             compiled_f = aot_module(f, nop)
         else:
             compiled_f = aot_function(f, nop)
-        ref_out, ref_grad = _outs_and_grads(f, inp)
-        test_out, test_grad = _outs_and_grads(compiled_f, inp)
+        ref_out, ref_grad, ref_grad_grad = _outs_and_grads_and_grad_grads(f, inp)
+        test_out, test_grad, test_grad_grad = _outs_and_grads_and_grad_grads(compiled_f, inp)
         self.assertEqual(ref_out, test_out)
         self.assertEqual(ref_grad, test_grad)
+        # self.assertEqual(ref_grad_grad, test_grad_grad)
 
     def test_single_output(self):
         def f(a, b):

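The new test helpers compute gradients with torch.autograd.grad(..., create_graph=True) instead of calling .backward(), which keeps the first-order gradients differentiable and makes the (currently commented-out) grad-of-grad comparison possible. A standalone sketch of that double-backward pattern, not tied to this test file:

import torch

x = torch.randn(3, requires_grad=True)
out = (x ** 3).sum()

# create_graph=True keeps the first-order gradient differentiable ...
(grad_x,) = torch.autograd.grad(out, x, create_graph=True)   # 3 * x**2
# ... so a second torch.autograd.grad call can compute grad-of-grad.
(grad_grad_x,) = torch.autograd.grad(grad_x.sum(), x)        # 6 * x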