Commit dd0a862

Separate forward and backward compilation for default partition

ghstack-source-id: 248de2f577fe3d61d4f2c40dc04570978dcc1543
Pull Request resolved: #856

1 parent 130582c

File tree: 3 files changed, +101 -26 lines changed
functorch/_src/aot_autograd.py (+56 -19)
@@ -54,7 +54,7 @@ def _dict_unflatten(values: List[Any], context: Context) -> Dict[Any, Any]:
 
 def create_joint_forward_backward(fn):
     def joint_forward_backward(
-        primals: List[Any], tangents: List[Any]
+        primals: List[Any], cotangents: List[Any]
     ) -> Tuple[List[Any], List[Any]]:
         # Call the forward pass
         outs = fn(*primals)
@@ -68,21 +68,21 @@ def joint_forward_backward(
                 grad_primals.append(p)
 
         # Get the outputs that need gradients
-        assert len(tangents) == len(outs)
+        assert len(cotangents) == len(outs)
         needed_outs = []
-        needed_tangents = []
-        for out, tangent in zip(outs, tangents):
+        needed_cotangents = []
+        for out, cotangent in zip(outs, cotangents):
             if isinstance(out, Tensor) and out.requires_grad:
                 needed_outs.append(out)
-                needed_tangents.append(tangent)
+                needed_cotangents.append(cotangent)
         backward_out = []
         # Call the backwards pass
         if grad_primals:
             backward_out = torch.autograd.grad(
                 needed_outs,
                 grad_primals,
-                grad_outputs=needed_tangents,
-                allow_unused=True,
+                grad_outputs=needed_cotangents,
+                allow_unused=True
             )
         backward_out_iter = iter(backward_out)
         return outs, [
@@ -140,12 +140,13 @@ def create_aot_autograd_function(
     compiled_fw = None
     compiled_bw = None
     num_outs = None
+    aot_decompositions = {**aot_autograd_decompositions, **decompositions}
 
     class CompiledFunction(torch.autograd.Function):
         @staticmethod
         @disable_torchdynamo
         def forward(ctx, *flat_tensor_args):
-            nonlocal compiled_fw, compiled_bw, num_outs
+            nonlocal compiled_fw, num_outs
             if compiled_fw is None:
                 with torch.set_grad_enabled(grad_state):
                     out = flat_fn(*flat_tensor_args)
@@ -159,31 +160,67 @@ def forward(ctx, *flat_tensor_args):
                     num_outs = 1
 
                 joint_inputs = (flat_tensor_args, out)
-                aot_decompositions = {**aot_autograd_decompositions, **decompositions}
+                # Need it because autograd.Function disables grad in forward
                 with torch.set_grad_enabled(grad_state):
                     fx_g = make_fx(joint_forward_backward, aot_decompositions)(
                         *joint_inputs
                     )
                 fw_module, bw_module = partition_fn(fx_g, joint_inputs)
-                # print(fw_module.code, bw_module.code)
 
                 compiled_fw = fw_compiler(fw_module, flat_tensor_args)
                 fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))
-
-                bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs]
-                compiled_bw = bw_compiler(bw_module, bw_args)
+                if partition_fn is default_partition:
+                    ctx.num_intermediate = len(fw_outs[num_outs:])
+                    ctx.num_inputs = len(flat_tensor_args)
+                    to_be_saved = fw_outs[num_outs:] + list(flat_tensor_args) + out
+                    ctx.fx_g = fx_g
+                    ctx.save_for_backward(*to_be_saved)
+                    ctx.fwd_graph = fw_module.code
+                    ctx.bw_graph = bw_module.code
+                else:
+                    nonlocal compiled_bw
+                    bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs]
+                    compiled_bw = bw_compiler(bw_module, bw_args)
+                    ctx.save_for_backward(*fw_outs[num_outs:])
             else:
                 fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))
-            ctx.save_for_backward(*fw_outs[num_outs:])
+                if partition_fn is default_partition:
+                    with torch.set_grad_enabled(grad_state):
+                        out = flat_fn(*flat_tensor_args)
+                    out = pytree.tree_map(
+                        lambda x: x.detach().contiguous() if isinstance(x, Tensor) else x, out
+                    )
+                    ctx.num_intermediate = len(fw_outs[num_outs:])
+                    ctx.num_inputs = len(flat_tensor_args)
+                    to_be_saved = fw_outs[num_outs:] + list(flat_tensor_args) + out
+                    ctx.save_for_backward(*to_be_saved)
+                else:
+                    ctx.save_for_backward(*fw_outs[num_outs:])
             return tuple(fw_outs[0:num_outs])
 
         @staticmethod
         @disable_torchdynamo
-        def backward(ctx, *flat_args):
-            contiguous_args = [t.contiguous() for t in flat_args]
-            # contiguous_args = [t for t in flat_args]
-            out = normalize_as_list(compiled_bw(*ctx.saved_tensors, *contiguous_args))
-            return tuple(out)
+        def backward(ctx, *flat_grad_outs):
+            contiguous_args = [t.contiguous() for t in flat_grad_outs]
+            if compiled_bw is None:
+                assert partition_fn is default_partition
+                with torch.set_grad_enabled(grad_state):
+                    inputs = ctx.saved_tensors[ctx.num_intermediate:ctx.num_intermediate+ctx.num_inputs]
+                    fx_g = make_fx(joint_forward_backward, aot_decompositions)(inputs, contiguous_args)
+                    fw_module, bw_module = partition_fn(fx_g, ctx.saved_tensors[ctx.num_intermediate:])
+                    assert fx_g.code == ctx.fx_g.code
+                f = aot_function(bw_module, bw_compiler, bw_compiler, partition_fn, aot_decompositions)
+                print("INPUTS----->", *ctx.saved_tensors[:ctx.num_intermediate], *contiguous_args)
+                print(bw_module.code)
+                out = f(*ctx.saved_tensors[:ctx.num_intermediate], *contiguous_args)
+                return out
+            else:
+                if partition_fn is default_partition:
+                    out = normalize_as_list(compiled_bw(*ctx.saved_tensors[:ctx.num_intermediate], *contiguous_args))
+                else:
+                    assert not torch.is_grad_enabled()
+                    out = normalize_as_list(compiled_bw(*ctx.saved_tensors, *contiguous_args))
                return tuple(out)
 
     return CompiledFunction
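The core of the change above: when partition_fn is default_partition, compiled_bw is no longer built eagerly in forward(); instead the intermediates, inputs, and outputs are saved on ctx, and the joint graph is re-traced, partitioned, and compiled on the first backward() call. Below is a self-contained sketch of that lazy-backward pattern with a plain torch.autograd.Function; compile_backward is a hypothetical stand-in for bw_compiler, not a functorch API.

import torch

def compile_backward(bw_fn, example_inputs):
    # Stand-in for a real backward compiler; a real one would trace/lower bw_fn.
    return bw_fn

class LazyBackwardCube(torch.autograd.Function):
    compiled_bw = None  # populated on the first backward call

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)  # stash what the deferred backward will need
        return x ** 3

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        if LazyBackwardCube.compiled_bw is None:
            # Deferred compilation: only happens if backward is actually run.
            LazyBackwardCube.compiled_bw = compile_backward(
                lambda saved, g: 3 * saved ** 2 * g, (x, grad_out)
            )
        return LazyBackwardCube.compiled_bw(x, grad_out)

a = torch.tensor(2.3, requires_grad=True)
(grad,) = torch.autograd.grad(LazyBackwardCube.apply(a), a)
print(grad)  # 3 * 2.3 ** 2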

functorch/_src/partitioners.py (+1 -1)

@@ -153,7 +153,7 @@ def default_partition(
                 saved_values.append(user)
         else:
             saved_values.append(node)
-    saved_values = list(set(saved_values))
+    saved_values = list(saved_values)
 
     return _extract_fwd_bwd_modules(joint_module, saved_values)
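A note on the one-line change above: list(set(saved_values)) deduplicates but yields the saved values in arbitrary set order, while list(saved_values) keeps the original, deterministic order at the cost of possible duplicates; a stable order plausibly matters once the forward and backward are traced and compiled separately, though the diff itself does not state the motivation. If deduplication were still wanted, an order-preserving variant is straightforward (illustrative only, not part of the commit):

def dedupe_preserving_order(values):
    # Unlike list(set(values)), dict.fromkeys keeps the first-seen order.
    return list(dict.fromkeys(values))

assert dedupe_preserving_order([3, 1, 3, 2, 1]) == [3, 1, 2]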

test/test_pythonkey.py (+44 -6)

@@ -246,25 +246,57 @@ def f(args, kwargs):
 
 def _outs_and_grads(fn, inps):
     outs = fn(*inps)
+    diff_outs = []
     for out in pytree.tree_flatten(outs)[0]:
         if isinstance(out, torch.Tensor) and out.requires_grad:
-            out.sum().backward(retain_graph=True)
-    grads = [inp.grad for inp in pytree.tree_flatten(inps)[0]]
-    for inp in pytree.tree_flatten(inps)[0]:
-        inp.grad = None
+            diff_outs.append(out)
+    def full_reduce(outs):
+        res = 0
+        for out in outs:
+            res=res+out.sum()
+        return res
+    # print(inps)
+    grads = torch.autograd.grad(full_reduce(diff_outs), pytree.tree_flatten(inps)[0], create_graph=True)
     return outs, grads
 
+def _outs_and_grads_and_grad_grads(fn, inps):
+    outs = fn(*inps)
+    diff_outs = []
+    diff_inps = []
+    for out in pytree.tree_flatten(outs)[0]:
+        if isinstance(out, torch.Tensor) and out.requires_grad:
+            diff_outs.append(out)
+    for inp in pytree.tree_flatten(inps)[0]:
+        if isinstance(inp, torch.Tensor) and inp.requires_grad:
+            diff_inps.append(inp)
+    def full_reduce(outs):
+        res = 0
+        # print("entering full_reduce: ", type(outs))
+        for out in outs:
+            res=res+out.sum()
+        return res
+    print("diff_outs, diff_inps: ", diff_outs, diff_inps)
+    grads = torch.autograd.grad(diff_outs, diff_inps, create_graph=True)
+    # print("grad call with: ", full_reduce(diff_outs), diff_inps)
+    diff_grads = []
+    for grad_ in grads:
+        if isinstance(grad_, torch.Tensor) and grad_.requires_grad:
+            diff_grads.append(grad_)
+    # print("grad grad call with: ", grads, full_reduce(diff_grads), diff_inps)
+    grad_grads = torch.autograd.grad(diff_grads, diff_inps)
+    return outs, grads, grad_grads
 
 class TestAOTAutograd(TestCase):
     def verify_aot_autograd(self, f, inp):
         if isinstance(f, nn.Module):
             compiled_f = aot_module(f, nop)
         else:
             compiled_f = aot_function(f, nop)
-        ref_out, ref_grad = _outs_and_grads(f, inp)
-        test_out, test_grad = _outs_and_grads(compiled_f, inp)
+        ref_out, ref_grad, ref_grad_grad = _outs_and_grads_and_grad_grads(f, inp)
+        test_out, test_grad, test_grad_grad = _outs_and_grads_and_grad_grads(compiled_f, inp)
         self.assertEqual(ref_out, test_out)
         self.assertEqual(ref_grad, test_grad)
+        self.assertEqual(ref_grad_grad, test_grad_grad)
 
     def test_single_output(self):
         def f(a, b):
@@ -284,6 +316,12 @@ def f(a, b):
         inp = [torch.randn(3, 3, requires_grad=True), torch.randn(3, 3)]
         self.verify_aot_autograd(f, inp)
 
+    def test_cube(self):
+        def f(a):
+            return a ** 3
+        inp = [torch.tensor(2.3, requires_grad=True)]
+        self.verify_aot_autograd(f, inp)
+
     def test_no_grad_input_output(self):
         def f(a, b):
             return a.cos(), b.cos(), a * b
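The reworked test helpers replace .backward() plus .grad bookkeeping with torch.autograd.grad(..., create_graph=True), so second-order gradients can be compared between the eager function and the AOT-compiled one. A plain-PyTorch illustration of the quantities the new test_cube ends up checking:

import torch

a = torch.tensor(2.3, requires_grad=True)
out = a ** 3                                               # f(a)   = a**3
(grad,) = torch.autograd.grad(out, a, create_graph=True)   # f'(a)  = 3 * a**2
(grad_grad,) = torch.autograd.grad(grad, a)                # f''(a) = 6 * a
print(grad.item(), grad_grad.item())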
