Skip to content

Commit 55baed5

Browse files
committed
Separate forward and backwad compilation
ghstack-source-id: 0b78895219a89ec3841cf8b0804e1be69bfeed8a Pull Request resolved: #856
1 parent 44dd1bf commit 55baed5

File tree

4 files changed

+218
-82
lines changed

4 files changed

+218
-82
lines changed

functorch/_src/aot_autograd.py

+83-31
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from functorch.experimental import functionalize
1212
from . import config
1313
from .decompositions import register_decomposition
14-
from .partitioners import default_partition
14+
from .partitioners import default_partition, _get_saved_values, _extract_fwd_bwd_modules
1515
from .named_members_polyfill import _named_parameters, _named_buffers
1616
from typing import Callable, List, Dict, Any, Tuple, Optional
1717
from functools import wraps
@@ -70,7 +70,7 @@ def preserve_rng_state():
7070

7171
def create_joint_forward_backward(fn):
7272
def joint_forward_backward(
73-
primals: List[Any], tangents: List[Any]
73+
primals: List[Any], cotangents: List[Any]
7474
) -> Tuple[List[Any], List[Any]]:
7575
# Call the forward pass
7676
outs = fn(*primals)
@@ -84,21 +84,21 @@ def joint_forward_backward(
8484
grad_primals.append(p)
8585

8686
# Get the outputs that need gradients
87-
assert len(tangents) == len(outs)
87+
assert len(cotangents) == len(outs)
8888
needed_outs = []
89-
needed_tangents = []
90-
for out, tangent in zip(outs, tangents):
89+
needed_cotangents = []
90+
for out, cotangent in zip(outs, cotangents):
9191
if isinstance(out, Tensor) and out.requires_grad:
9292
needed_outs.append(out)
93-
needed_tangents.append(tangent)
93+
needed_cotangents.append(cotangent)
9494
backward_out = []
9595
# Call the backwards pass
9696
if grad_primals:
9797
backward_out = torch.autograd.grad(
9898
needed_outs,
9999
grad_primals,
100-
grad_outputs=needed_tangents,
101-
allow_unused=True,
100+
grad_outputs=needed_cotangents,
101+
allow_unused=True
102102
)
103103
backward_out_iter = iter(backward_out)
104104
return outs, [
@@ -152,22 +152,31 @@ def create_aot_autograd_function(
152152
if decompositions is None:
153153
decompositions = {}
154154
joint_forward_backward = create_joint_forward_backward(flat_fn)
155-
155+
# create_joint_forward_backward takes inputs and cotangents as inps
156+
# inps: inputs, cotangents: flat_grad_outs
157+
j_b = None
156158
compiled_fw = None
157-
compiled_bw = None
159+
bw_modules = []
158160
num_outs = None
161+
saved_value_names = None
162+
aot_decompositions = {**aot_autograd_decompositions, **decompositions}
159163

160164
class CompiledFunction(torch.autograd.Function):
161165
@staticmethod
162166
@disable_torchdynamo
163167
def forward(ctx, *flat_tensor_args):
164-
nonlocal compiled_fw, compiled_bw, num_outs
168+
# ctx.set_materialize_grads(False)
169+
nonlocal compiled_fw, num_outs, saved_value_names, aot_decompositions, j_b
165170
# Disable the JIT Autocast flag to prevent re-autocasting of jitted graph.
166171
# TODO - Remove when https://github.com/pytorch/functorch/pull/794 is fixed.
167172
old_jit_autocast_flag = torch._C._jit_set_autocast_mode(False)
173+
# creating this to save the original inputs since the inputs might be returned as outs
174+
# and would then have grad_fn set on them which is incorrect.
175+
flat_tensor_args_0 = flat_tensor_args
168176
if compiled_fw is None:
169177
with preserve_rng_state():
170178
# Set input tensors that require grad to leaves
179+
# Detach to not accidentally extend the graph
171180
flat_tensor_args = pytree.tree_map(
172181
lambda x: x.detach().requires_grad_(x.requires_grad)
173182
if isinstance(x, Tensor) else x, flat_tensor_args
@@ -184,8 +193,9 @@ def forward(ctx, *flat_tensor_args):
184193
num_outs = 1
185194

186195
joint_inputs = (flat_tensor_args, out)
187-
aot_decompositions = {**aot_autograd_decompositions, **decompositions}
188196
with torch.set_grad_enabled(grad_state):
197+
# This means the forward and backward graphs are created based on the input fn
198+
# However we need to take in grad_out for the saved intermediates as well.
189199
fx_g = make_fx(joint_forward_backward, aot_decompositions)(
190200
*joint_inputs
191201
)
@@ -196,33 +206,76 @@ def forward(ctx, *flat_tensor_args):
196206
def fake_fn(primals, tangents):
197207
return fx_g(primals, tangents)
198208
fx_g = make_fx(functionalize(fake_fn))(*joint_inputs)
199-
fw_module, bw_module = partition_fn(fx_g, joint_inputs)
200-
# print(fw_module.code, bw_module.code)
201-
209+
fw_module, bw_module, saved_value_nodes = partition_fn(fx_g, joint_inputs)
210+
saved_value_names = [node.name for node in saved_value_nodes]
202211
compiled_fw = fw_compiler(fw_module, flat_tensor_args)
203212
fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))
204-
205-
bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs]
206-
compiled_bw = bw_compiler(bw_module, bw_args)
213+
j_b = create_joint_forward_backward(fw_module)
207214
else:
208215
fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))
216+
ctx.num_intermediate = len(fw_outs[num_outs:])
217+
to_be_saved = fw_outs[num_outs:] + list(flat_tensor_args_0)
218+
ctx.save_for_backward(*to_be_saved)
209219
torch._C._jit_set_autocast_mode(old_jit_autocast_flag)
210-
ctx.save_for_backward(*fw_outs[num_outs:])
211-
return tuple(fw_outs[0:num_outs])
220+
return tuple(fw_outs)
212221

213222
@staticmethod
214223
@disable_torchdynamo
215-
def backward(ctx, *flat_args):
224+
def backward(ctx, *flat_grad_outs):
216225
# Disable the JIT Autocast flag to prevent re-autocasting of jitted graph.
217226
# TODO - Remove when https://github.com/pytorch/functorch/pull/794 is fixed.
218227
old_jit_autocast_flag = torch._C._jit_set_autocast_mode(False)
219-
contiguous_args = [t.contiguous() for t in flat_args]
220-
# contiguous_args = [t for t in flat_args]
221-
out = normalize_as_list(compiled_bw(*ctx.saved_tensors, *contiguous_args))
228+
nonlocal bw_modules, saved_value_names, num_outs, aot_decompositions, j_b
229+
with preserve_rng_state():
230+
intermediates = ctx.saved_tensors[:ctx.num_intermediate]
231+
flat_tensor_args = ctx.saved_tensors[ctx.num_intermediate:]
232+
flat_tensor_args = pytree.tree_map(
233+
lambda x: x.detach().requires_grad_(x.requires_grad)
234+
if isinstance(x, Tensor) else x, flat_tensor_args
235+
)
236+
inp_grad_outs = flat_grad_outs
237+
with torch.set_grad_enabled(grad_state):
238+
fx_g_b = make_fx(j_b, aot_decompositions)(flat_tensor_args, inp_grad_outs)
239+
if config.use_functionalize:
240+
# Functionalize the foward backward graph. First create a
241+
# fake fn to make functionalize happy
242+
def fake_fn(primals, tangents):
243+
return fx_g_b(primals, tangents)
244+
fx_g_b = make_fx(functionalize(fake_fn))(flat_tensor_args, inp_grad_outs)
245+
saved_value_nodes = _get_saved_values(fx_g_b, saved_value_names)
246+
assert len(saved_value_nodes) <= len(saved_value_names)
247+
fw_module_b, bw_module_b, saved_values_new = _extract_fwd_bwd_modules(fx_g_b, saved_value_nodes)
248+
if len(saved_values_new) != len(saved_value_names):
249+
new_intermediates = []
250+
# Forward saves more intermediates than needed
251+
assert len(saved_values_new) < len(saved_value_names)
252+
j = 0
253+
for node in saved_values_new:
254+
while node.name != saved_value_names[j]:
255+
j += 1
256+
new_intermediates.append(intermediates[j])
257+
j += 1
258+
intermediates = new_intermediates
259+
260+
# This is needed because aot function caching uses function id right now
261+
bw_module_fn = None
262+
for elem in bw_modules:
263+
if elem.code == bw_module_b.code:
264+
bw_module_fn = elem
265+
break
266+
if bw_module_fn is None:
267+
bw_modules.append(bw_module_b)
268+
bw_module_fn = bw_module_b
269+
270+
f = aot_function(bw_module_fn, bw_compiler, bw_compiler, partition_fn, aot_decompositions)
271+
out = f(*intermediates, *inp_grad_outs)
222272
torch._C._jit_set_autocast_mode(old_jit_autocast_flag)
223-
return tuple(out)
273+
return tuple(normalize_as_list(out))
224274

225-
return CompiledFunction
275+
def return_fn(*args, **kwargs):
276+
out = CompiledFunction.apply(*args, **kwargs)
277+
return out[0:num_outs]
278+
return return_fn
226279

227280

228281
class _CompileCache(CompileCache):
@@ -312,7 +365,7 @@ def rearrange(tensor_args, static_args, static_argnums):
312365
return args
313366

314367

315-
KNOWN_TYPES = [torch.Tensor, int, str, float, bool]
368+
KNOWN_TYPES = [torch.Tensor, int, str, float, bool, None]
316369

317370

318371
def aot_function(
@@ -448,7 +501,6 @@ def returned_function(*args, **kwargs):
448501
hasher_type,
449502
*flat_args_for_cache,
450503
)
451-
452504
# Compile the function and save it in the cache
453505
if cached_res is None:
454506
# Save the args_spec for flat_tensor_args to unflatten while tracing
@@ -473,7 +525,7 @@ def flat_fn(*flat_tensor_args):
473525
for i in flat_out:
474526
is_known_type = False
475527
for j in KNOWN_TYPES:
476-
if isinstance(i, j):
528+
if j is None or isinstance(i, j):
477529
is_known_type = True
478530
break
479531
if not is_known_type:
@@ -495,7 +547,7 @@ def flat_fn(*flat_tensor_args):
495547
partition_fn,
496548
decompositions,
497549
grad_state=torch.is_grad_enabled(),
498-
).apply
550+
)
499551
cached_res = (compiled_fn, out_spec)
500552

501553
# Save the compiled_fn in the cache
@@ -635,7 +687,7 @@ def aot_function_simplified(
635687
partition_fn,
636688
decompositions,
637689
grad_state=torch.is_grad_enabled(),
638-
).apply
690+
)
639691

640692
return compiled_fn
641693

functorch/_src/partitioners.py

+19-2
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,24 @@ def _extract_fwd_bwd_modules(joint_module: fx.GraphModule, saved_values):
109109

110110
fwd_module = fx.GraphModule(joint_module, fwd_graph)
111111
bwd_module = fx.GraphModule(joint_module, bwd_graph)
112-
return fwd_module, bwd_module
112+
return fwd_module, bwd_module, saved_values
113+
114+
115+
def _get_saved_values(new_module: fx.GraphModule, saved_value_names):
116+
saved_values = []
117+
for node in new_module.graph.nodes:
118+
if node.name in saved_value_names:
119+
if 'tensor_meta' not in node.meta and node.op == 'call_function':
120+
users = node.users
121+
assert all(user.target == operator.getitem for user in users)
122+
for user in users:
123+
saved_values.append(user)
124+
else:
125+
saved_values.append(node)
126+
127+
saved_values = list(saved_values)
128+
129+
return saved_values
113130

114131

115132
def default_partition(
@@ -154,8 +171,8 @@ def default_partition(
154171
saved_values.append(user)
155172
else:
156173
saved_values.append(node)
157-
saved_values = list(set(saved_values))
158174

175+
saved_values = list(saved_values)
159176
return _extract_fwd_bwd_modules(joint_module, saved_values)
160177

161178

0 commit comments

Comments
 (0)