
Commit d8e3b16

Separate forward and backward compilation
ghstack-source-id: f154dd1cbab518acd5890090ca081db1ec7fa20a
Pull Request resolved: #856
1 parent ca3ac11 commit d8e3b16

File tree

4 files changed: +257 −93 lines


functorch/_src/aot_autograd.py

+84 −34
@@ -1,7 +1,7 @@
 from contextlib import contextmanager
 import torch
 import torch.nn as nn
-from torch import Tensor
+from torch import Tensor, is_grad_enabled
 from functorch import make_fx
 from torch.fx import immutable_collections
 import torch.utils._pytree as pytree
@@ -11,7 +11,7 @@
 from functorch.experimental import functionalize
 from . import config
 from .decompositions import register_decomposition
-from .partitioners import default_partition
+from .partitioners import default_partition, _get_saved_values, _extract_fwd_bwd_modules, _extract_fwd_bwd_modules_db
 from .named_members_polyfill import _named_parameters, _named_buffers
 from typing import Callable, List, Dict, Any, Tuple, Optional
 from functools import wraps
@@ -70,7 +70,7 @@ def preserve_rng_state():

 def create_joint_forward_backward(fn):
     def joint_forward_backward(
-        primals: List[Any], tangents: List[Any]
+        primals: List[Any], cotangents: List[Any]
     ) -> Tuple[List[Any], List[Any]]:
         # Call the forward pass
         outs = fn(*primals)
@@ -84,21 +84,21 @@ def joint_forward_backward(
                 grad_primals.append(p)

         # Get the outputs that need gradients
-        assert len(tangents) == len(outs)
+        assert len(cotangents) == len(outs)
         needed_outs = []
-        needed_tangents = []
-        for out, tangent in zip(outs, tangents):
+        needed_cotangents = []
+        for out, cotangent in zip(outs, cotangents):
             if isinstance(out, Tensor) and out.requires_grad:
                 needed_outs.append(out)
-                needed_tangents.append(tangent)
+                needed_cotangents.append(cotangent)
         backward_out = []
         # Call the backwards pass
         if grad_primals:
             backward_out = torch.autograd.grad(
                 needed_outs,
                 grad_primals,
-                grad_outputs=needed_tangents,
-                allow_unused=True,
+                grad_outputs=needed_cotangents,
+                allow_unused=True
             )
         backward_out_iter = iter(backward_out)
         return outs, [
@@ -152,16 +152,21 @@ def create_aot_autograd_function(
     if decompositions is None:
         decompositions = {}
     joint_forward_backward = create_joint_forward_backward(flat_fn)
-
+    # create_joint_forward_backward takes inputs and cotangents as inps
+    # inps: inputs, cotangents: flat_grad_outs
+    j_b = None
     compiled_fw = None
-    compiled_bw = None
+    bw_modules = []
     num_outs = None
+    saved_value_names = None
+    aot_decompositions = {**aot_autograd_decompositions, **decompositions}

     class CompiledFunction(torch.autograd.Function):
         @staticmethod
         @disable_torchdynamo
         def forward(ctx, *flat_tensor_args):
-            nonlocal compiled_fw, compiled_bw, num_outs
+            # ctx.set_materialize_grads(False)
+            nonlocal compiled_fw, num_outs, saved_value_names, aot_decompositions, j_b
             # Disable the JIT Autocast flag to prevent re-autocasting of jitted graph.
             # TODO - Remove when https://github.com/pytorch/functorch/pull/794 is fixed.
             old_jit_autocast_flag = torch._C._jit_set_autocast_mode(False)
@@ -184,8 +189,9 @@ def forward(ctx, *flat_tensor_args):
                         num_outs = 1

                     joint_inputs = (flat_tensor_args, out)
-                    aot_decompositions = {**aot_autograd_decompositions, **decompositions}
                     with torch.set_grad_enabled(grad_state):
+                        # This means the forward and backward graphs are created based on the input fn
+                        # However we need to take in grad_out for the saved intermediates as well.
                         fx_g = make_fx(joint_forward_backward, aot_decompositions)(
                             *joint_inputs
                         )
@@ -196,34 +202,79 @@ def forward(ctx, *flat_tensor_args):
                             def fake_fn(primals, tangents):
                                 return fx_g(primals, tangents)
                             fx_g = make_fx(functionalize(fake_fn))(*joint_inputs)
-                        fw_module, bw_module = partition_fn(fx_g, joint_inputs)
-                        # print(fw_module.code, bw_module.code)
-
+                        fw_module, bw_module, saved_value_nodes = partition_fn(fx_g, joint_inputs)
+                        saved_value_names = [node.name for node in saved_value_nodes]
                     compiled_fw = fw_compiler(fw_module, flat_tensor_args)
                     fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))
-
-                    bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs]
-                    compiled_bw = bw_compiler(bw_module, bw_args)
+                    j_b = create_joint_forward_backward(fw_module)
             else:
                 fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))
+            ctx.num_intermediate = len(fw_outs[num_outs:])
+            to_be_saved = fw_outs[num_outs:] + list(flat_tensor_args)
+            ctx.save_for_backward(*to_be_saved)
             torch._C._jit_set_autocast_mode(old_jit_autocast_flag)
-            ctx.save_for_backward(*fw_outs[num_outs:])
-            return tuple(fw_outs[0:num_outs])
+            return tuple(fw_outs)

         @staticmethod
         @disable_torchdynamo
-        def backward(ctx, *flat_args):
+        def backward(ctx, *flat_grad_outs):
             # Disable the JIT Autocast flag to prevent re-autocasting of jitted graph.
             # TODO - Remove when https://github.com/pytorch/functorch/pull/794 is fixed.
             old_jit_autocast_flag = torch._C._jit_set_autocast_mode(False)
-            contiguous_args = [t.contiguous() for t in flat_args]
-            # contiguous_args = [t for t in flat_args]
-            out = normalize_as_list(compiled_bw(*ctx.saved_tensors, *contiguous_args))
+            nonlocal bw_modules, saved_value_names, num_outs, aot_decompositions, j_b
+            with preserve_rng_state():
+                intermediates = ctx.saved_tensors[:ctx.num_intermediate]
+                flat_tensor_args = ctx.saved_tensors[ctx.num_intermediate:]
+                flat_tensor_args = pytree.tree_map(
+                    lambda x: x.detach().requires_grad_(x.requires_grad)
+                    if isinstance(x, Tensor) else x, flat_tensor_args
+                )
+                inp_grad_outs = pytree.tree_map(
+                    lambda x: x.detach() if isinstance(x, Tensor) else x, flat_grad_outs
+                )
+                # inp_grad_outs = flat_grad_outs
+                with torch.set_grad_enabled(grad_state):
+                    fx_g_b = make_fx(j_b, aot_decompositions)(flat_tensor_args, inp_grad_outs)
+                if config.use_functionalize:
+                    # Functionalize the foward backward graph. First create a
+                    # fake fn to make functionalize happy
+                    def fake_fn(primals, tangents):
+                        return fx_g(primals, tangents)
+                    fx_g = make_fx(functionalize(fake_fn))(flat_tensor_args, inp_grad_outs)
+                saved_value_nodes = _get_saved_values(fx_g_b, saved_value_names)
+                assert len(saved_value_nodes) <= len(saved_value_names)
+                fw_module_b, bw_module_b, saved_values_new = _extract_fwd_bwd_modules(fx_g_b, saved_value_nodes)
+                if len(saved_values_new) != len(saved_value_names):
+                    new_intermediates = []
+                    # Forward saves more intermediates than needed
+                    assert len(saved_values_new) < len(saved_value_names)
+                    j = 0
+                    for node in saved_values_new:
+                        while node.name != saved_value_names[j]:
+                            j+=1
+                        new_intermediates.append(intermediates[j])
+                        j+=1
+                    intermediates = new_intermediates
+
+            # This is needed because aot function caching uses function id right now
+            bw_module_fn = None
+            for elem in bw_modules:
+                if elem.code == bw_module_b.code:
+                    bw_module_fn = elem
+                    break
+            if bw_module_fn is None:
+                bw_modules.append(bw_module_b)
+                bw_module_fn = bw_module_b
+
+            f = aot_function(bw_module_fn, bw_compiler, bw_compiler, partition_fn, aot_decompositions)
+            out = f(*intermediates, *flat_grad_outs)
             torch._C._jit_set_autocast_mode(old_jit_autocast_flag)
-            return tuple(out)
-
-    return CompiledFunction
+            return tuple(normalize_as_list(out))

+    def return_fn(*args, **kwargs):
+        out = CompiledFunction.apply(*args, **kwargs)
+        return out[0:num_outs]
+    return return_fn

 class _CompileCache(CompileCache):
     pass
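The forward/backward split above rests on a small bookkeeping trick: `forward` stashes both the extra intermediate outputs and the original inputs in one `ctx.save_for_backward` call, and records `ctx.num_intermediate` so `backward` can slice `ctx.saved_tensors` back apart. A minimal sketch of that pattern in a plain `torch.autograd.Function` (the class and names below are illustrative, not part of this commit):

```python
import torch

class SaveSplitFunction(torch.autograd.Function):
    # Illustrative only: mirrors the save_for_backward bookkeeping above, where
    # intermediates and inputs share one saved_tensors tuple and
    # ctx.num_intermediate marks the split point.
    @staticmethod
    def forward(ctx, x):
        intermediate = torch.sin(x)             # an "intermediate" backward may want to reuse
        out = intermediate * 2
        ctx.num_intermediate = 1
        ctx.save_for_backward(intermediate, x)  # intermediates first, then the original inputs
        return out

    @staticmethod
    def backward(ctx, grad_out):
        intermediates = ctx.saved_tensors[:ctx.num_intermediate]   # sliced off the front
        (x,) = ctx.saved_tensors[ctx.num_intermediate:]            # original inputs at the back
        # d/dx (2 * sin(x)) = 2 * cos(x); `intermediates` is unused here, it only shows the slicing.
        return grad_out * 2 * torch.cos(x)

x = torch.randn(3, requires_grad=True)
SaveSplitFunction.apply(x).sum().backward()
print(x.grad)
```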
@@ -312,7 +363,7 @@ def rearrange(tensor_args, static_args, static_argnums):
     return args


-KNOWN_TYPES = [torch.Tensor, int, str, float, bool]
+KNOWN_TYPES = [torch.Tensor, int, str, float, bool, None]


 def aot_function(
@@ -448,7 +499,6 @@ def returned_function(*args, **kwargs):
             hasher_type,
             *flat_args_for_cache,
         )
-
         # Compile the function and save it in the cache
         if cached_res is None:
             # Save the args_spec for flat_tensor_args to unflatten while tracing
@@ -473,7 +523,7 @@ def flat_fn(*flat_tensor_args):
                 for i in flat_out:
                     is_known_type = False
                     for j in KNOWN_TYPES:
-                        if isinstance(i, j):
+                        if j is None or isinstance(i, j):
                             is_known_type = True
                             break
                     if not is_known_type:
@@ -495,7 +545,7 @@ def flat_fn(*flat_tensor_args):
                 partition_fn,
                 decompositions,
                 grad_state=torch.is_grad_enabled(),
-            ).apply
+            )
             cached_res = (compiled_fn, out_spec)

         # Save the compiled_fn in the cache
@@ -635,7 +685,7 @@ def aot_function_simplified(
         partition_fn,
         decompositions,
         grad_state=torch.is_grad_enabled(),
-    ).apply
+    )

     return compiled_fn

@@ -657,4 +707,4 @@ def forward(self, *args, **kwargs):


 compiled_function = aot_function
-compiled_module = aot_module
+compiled_module = aot_module
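Taken together, the aot_autograd.py changes move backward compilation out of the first forward call: the forward graph is still partitioned and compiled eagerly, but the backward graph is now re-traced from the compiled forward module and compiled lazily when `.backward()` actually runs (routed back through `aot_function`, so the usual compiler and partitioner hooks still apply). A rough usage sketch against the functorch `aot_function` API of this era; `print_compiler` is a made-up pass-through compiler and the import path is assumed, neither is part of this PR:

```python
import torch
from functorch.compile import aot_function  # import path assumed for this functorch version

def print_compiler(fx_module, example_inputs):
    # A pass-through "compiler": dump the FX graph it receives and run it as-is.
    print(fx_module.code)
    return fx_module

def fn(x):
    return torch.sin(x).cos().sum()

compiled_fn = aot_function(fn, fw_compiler=print_compiler, bw_compiler=print_compiler)

x = torch.randn(4, requires_grad=True)
out = compiled_fn(x)   # forward graph traced, partitioned, and compiled here
out.backward()         # with this change, the backward graph is traced and compiled here instead
```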

functorch/_src/partitioners.py

+46 −2
@@ -109,8 +109,52 @@ def _extract_fwd_bwd_modules(joint_module: fx.GraphModule, saved_values):

     fwd_module = fx.GraphModule(joint_module, fwd_graph)
     bwd_module = fx.GraphModule(joint_module, bwd_graph)
-    return fwd_module, bwd_module
+    return fwd_module, bwd_module, saved_values

+def _extract_fwd_bwd_modules_db(joint_module: fx.GraphModule, saved_values):
+    fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs(joint_module)
+    print("FWD OUTS: ", fwd_outputs)
+    print("BWD OUTS: ", bwd_outputs)
+    primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
+    tangent_inputs = list(filter(_is_tangent, joint_module.graph.nodes))
+    print("primal_inputs: ", primal_inputs)
+    print("tangent_inputs: ", tangent_inputs)
+    # Construct the forward module
+    fwd_graph = _extract_graph_with_inputs_outputs(joint_module.graph, primal_inputs, fwd_outputs)
+    bwd_graph = _extract_graph_with_inputs_outputs(joint_module.graph, saved_values + tangent_inputs, bwd_outputs)
+
+    # This is to filter out saved values that don't actually end up being used by the backwards pass
+    for node in bwd_graph.nodes:
+        if node.op == 'placeholder' and not node.users:
+            for saved_value in saved_values:
+                if saved_value.name == node.name:
+                    saved_values.remove(saved_value)
+                    break
+
+    # Now, we re-generate the fwd/bwd graphs.
+    # NB: This might increase compilation time, but I doubt it matters
+    fwd_graph = _extract_graph_with_inputs_outputs(joint_module.graph, primal_inputs, fwd_outputs)
+    bwd_graph = _extract_graph_with_inputs_outputs(joint_module.graph, saved_values + tangent_inputs, bwd_outputs)
+
+    fwd_module = fx.GraphModule(joint_module, fwd_graph)
+    bwd_module = fx.GraphModule(joint_module, bwd_graph)
+    return fwd_module, bwd_module, saved_values
+
+def _get_saved_values(new_module: fx.GraphModule, saved_value_names):
+    saved_values = []
+    for node in new_module.graph.nodes:
+        if node.name in saved_value_names:
+            if 'tensor_meta' not in node.meta and node.op == 'call_function':
+                users = node.users
+                assert all(user.target == operator.getitem for user in users)
+                for user in users:
+                    saved_values.append(user)
+            else:
+                saved_values.append(node)
+
+    saved_values = list(saved_values)
+
+    return saved_values

 def default_partition(
     joint_module: fx.GraphModule, _joint_inputs
@@ -154,8 +198,8 @@ def default_partition(
                 saved_values.append(user)
         else:
             saved_values.append(node)
-    saved_values = list(set(saved_values))

+    saved_values = list(saved_values)
     return _extract_fwd_bwd_modules(joint_module, saved_values)

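The new `_get_saved_values` helper leans on FX's stable node naming: node names recorded from the first joint trace are used to re-locate the corresponding nodes when the joint graph is traced again at backward time. A stripped-down sketch of that name-matching idea on an ordinary FX trace (the function and names below are illustrative, not this repo's code):

```python
import torch.fx as fx

def find_nodes_by_name(gm: fx.GraphModule, wanted_names):
    # Re-locate nodes in a (re)traced graph purely by names recorded earlier,
    # the same idea _get_saved_values uses to match saved intermediates.
    return [node for node in gm.graph.nodes if node.name in wanted_names]

def f(x, y):
    a = x + y
    b = a * x
    return b.sum()

gm = fx.symbolic_trace(f)
# Record some node names from the first trace...
recorded = {n.name for n in gm.graph.nodes if n.op == "call_function"}

# ...then re-trace and match them up again by name.
gm2 = fx.symbolic_trace(f)
matched = find_nodes_by_name(gm2, recorded)
print([n.name for n in matched])  # e.g. ['add', 'mul'] under default FX naming
```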