
Commit 957cd6b

Separate forward and backward compilation
ghstack-source-id: 0ce1d4ab26357b8614c57d539cdb61f1ae90a25e
Pull Request resolved: #856
1 parent 130582c commit 957cd6b

4 files changed: +254 -71 lines changed

functorch/_src/aot_autograd.py (+107 -34)
@@ -1,14 +1,14 @@
 import torch
 import torch.nn as nn
-from torch import Tensor
+from torch import Tensor, is_grad_enabled
 from functorch import make_fx
 from torch.fx import immutable_collections
 import torch.utils._pytree as pytree
 import torch.utils.dlpack
 from torch.nn.utils import _stateless
 from functorch._C import CompileCache
 from .decompositions import register_decomposition
-from .partitioners import default_partition
+from .partitioners import default_partition, _get_saved_values, _extract_fwd_bwd_modules, _extract_fwd_bwd_modules_db
 from .named_members_polyfill import _named_parameters, _named_buffers
 from typing import Callable, List, Dict, Any, Tuple, Optional
 from functools import wraps
@@ -54,7 +54,7 @@ def _dict_unflatten(values: List[Any], context: Context) -> Dict[Any, Any]:

 def create_joint_forward_backward(fn):
     def joint_forward_backward(
-        primals: List[Any], tangents: List[Any]
+        primals: List[Any], cotangents: List[Any]
     ) -> Tuple[List[Any], List[Any]]:
         # Call the forward pass
         outs = fn(*primals)
@@ -68,21 +68,21 @@ def joint_forward_backward(
                 grad_primals.append(p)

         # Get the outputs that need gradients
-        assert len(tangents) == len(outs)
+        assert len(cotangents) == len(outs)
         needed_outs = []
-        needed_tangents = []
-        for out, tangent in zip(outs, tangents):
+        needed_cotangents = []
+        for out, cotangent in zip(outs, cotangents):
             if isinstance(out, Tensor) and out.requires_grad:
                 needed_outs.append(out)
-                needed_tangents.append(tangent)
+                needed_cotangents.append(cotangent)
         backward_out = []
         # Call the backwards pass
         if grad_primals:
             backward_out = torch.autograd.grad(
                 needed_outs,
                 grad_primals,
-                grad_outputs=needed_tangents,
-                allow_unused=True,
+                grad_outputs=needed_cotangents,
+                allow_unused=True
             )
         backward_out_iter = iter(backward_out)
         return outs, [
@@ -138,14 +138,18 @@ def create_aot_autograd_function(
     joint_forward_backward = create_joint_forward_backward(flat_fn)

     compiled_fw = None
-    compiled_bw = None
+    bw_modules = []
+    fw_module = None
     num_outs = None
+    saved_value_names = None
+    aot_decompositions = {**aot_autograd_decompositions, **decompositions}

     class CompiledFunction(torch.autograd.Function):
         @staticmethod
         @disable_torchdynamo
         def forward(ctx, *flat_tensor_args):
-            nonlocal compiled_fw, compiled_bw, num_outs
+            ctx.set_materialize_grads(False)
+            nonlocal compiled_fw, num_outs, saved_value_names, fw_module
             if compiled_fw is None:
                 with torch.set_grad_enabled(grad_state):
                     out = flat_fn(*flat_tensor_args)
@@ -159,34 +163,101 @@ def forward(ctx, *flat_tensor_args):
                     num_outs = 1

                 joint_inputs = (flat_tensor_args, out)
-                aot_decompositions = {**aot_autograd_decompositions, **decompositions}
+                # Need it because autograd.Function disables grad in forward
                 with torch.set_grad_enabled(grad_state):
                     fx_g = make_fx(joint_forward_backward, aot_decompositions)(
                         *joint_inputs
                     )
-                    fw_module, bw_module = partition_fn(fx_g, joint_inputs)
-                    # print(fw_module.code, bw_module.code)
-
+                    # This means the forward and backward graphs are created based on the input fn
+                    # However we need to take in grad_out for the saved intermediates as well.
+                    fw_module, bw_module, saved_value_nodes = partition_fn(fx_g, joint_inputs)
+                    saved_value_names = [node.name for node in saved_value_nodes]
                 compiled_fw = fw_compiler(fw_module, flat_tensor_args)
                 fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))
-
-                bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs]
-                compiled_bw = bw_compiler(bw_module, bw_args)
             else:
                 fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))
-            ctx.save_for_backward(*fw_outs[num_outs:])
-            return tuple(fw_outs[0:num_outs])
+
+            # print(fw_module.code)
+            ctx.num_intermediate = len(fw_outs[num_outs:])
+            ctx.num_inputs = len(flat_tensor_args)
+            to_be_saved = fw_outs[num_outs:] + list(flat_tensor_args) + fw_outs[0:num_outs]
+            ctx.save_for_backward(*to_be_saved)
+            return tuple(fw_outs)

         @staticmethod
         @disable_torchdynamo
-        def backward(ctx, *flat_args):
-            contiguous_args = [t.contiguous() for t in flat_args]
-            # contiguous_args = [t for t in flat_args]
-            out = normalize_as_list(compiled_bw(*ctx.saved_tensors, *contiguous_args))
-            return tuple(out)
-
-    return CompiledFunction
-
+        def backward(ctx, *flat_grad_outs):
+            nonlocal bw_modules, saved_value_names, fw_module, num_outs
+            intermediates = ctx.saved_tensors[:ctx.num_intermediate]
+            inputs = ctx.saved_tensors[ctx.num_intermediate:ctx.num_intermediate+ctx.num_inputs]
+            is_grad_enabled = torch.is_grad_enabled()
+            if not is_grad_enabled:
+                input_flat_grad_outs = []
+                for grad in flat_grad_outs:
+                    if grad is not None:
+                        input_flat_grad_outs.append(grad)
+                with torch.set_grad_enabled(grad_state):
+                    fx_g_b = make_fx(joint_forward_backward, aot_decompositions)(inputs, input_flat_grad_outs)
+                saved_value_nodes = _get_saved_values(fx_g_b, saved_value_names)
+                assert len(saved_value_nodes) <= len(saved_value_names)
+                fw_module_b, bw_module_b, saved_values_new = _extract_fwd_bwd_modules(fx_g_b, saved_value_nodes)
+                if len(saved_values_new) != len(saved_value_names):
+                    new_intermediates = []
+                    # Forward saves more intermediates than needed
+                    assert len(saved_values_new) < len(saved_value_names)
+                    j = 0
+                    for node in saved_values_new:
+                        while node.name != saved_value_names[j]:
+                            j+=1
+                        new_intermediates.append(intermediates[j])
+                        j+=1
+                    intermediates = new_intermediates
+            # else:
+            #     input_flat_grad_outs = flat_grad_outs
+            #     # create_joint_forward_backward takes inputs and cotangents as inps
+            #     # inps: inputs, cotangents: flat_grad_outs
+            #     j_b = create_joint_forward_backward(ctx.fw_module)
+            #     # setting grad is not needed
+            #     with torch.set_grad_enabled(grad_state):
+            #         fx_g_b = make_fx(j_b, aot_decompositions)(inputs, input_flat_grad_outs)
+            #         saved_value_nodes = _get_saved_values(fx_g_b, saved_value_names)
+            #         # print(saved_value_nodes)
+            #         # print(saved_value_names)
+            #         # assert len(saved_value_nodes) == len(saved_value_names)
+            #         fw_module_b, bw_module_b, saved_values_new = _extract_fwd_bwd_modules_db(fx_g_b, saved_value_nodes)
+            #         # print(fx_g_b.code, ctx.fw_module.code, fw_module_b.code, bw_module_b.code)
+            #         # assert fw_module_b.code == fw_module.code
+            #         # print(len(sew), len(saved_value_names))
+            #         if len(saved_values_new) != len(saved_value_names):
+            #             new_intermediates = []
+            #             # Forward saves more intermediates than needed
+            #             assert len(saved_values_new) < len(saved_value_names)
+            #             for node in saved_values_new:
+            #                 j = 0
+            #                 while node.name != saved_value_names[j]:
+            #                     j+=1
+            #                 new_intermediates.append(intermediates[j])
+            #                 j+=1
+            #             intermediates = new_intermediates
+
+            # This is needed because aot function caching uses function id right now
+            bw_module_fn = None
+            for elem in bw_modules:
+                if elem.code == bw_module_b.code:
+                    bw_module_fn = elem
+                    break
+            if bw_module_fn is None:
+                bw_modules.append(bw_module_b)
+                bw_module_fn = bw_module_b
+
+            f = aot_function(bw_module_fn, bw_compiler, bw_compiler, partition_fn, aot_decompositions)
+            out = f(*intermediates, *input_flat_grad_outs)
+            return tuple(normalize_as_list(out))
+
+    def return_fn(*args, **kwargs):
+        out = CompiledFunction.apply(*args, **kwargs)
+        return out[0:num_outs]
+    return return_fn

 class _CompileCache(CompileCache):
     pass
@@ -275,7 +346,7 @@ def rearrange(tensor_args, static_args, static_argnums):
     return args


-KNOWN_TYPES = [torch.Tensor, int, str, float, bool]
+KNOWN_TYPES = [torch.Tensor, int, str, float, bool, None]


 def aot_function(
@@ -411,7 +482,9 @@ def returned_function(*args, **kwargs):
             hasher_type,
             *flat_args_for_cache,
         )
-
+        # print("fn_id: ", fn_id)
+        # print("size: ", compile_cache.size())
+        # print("num_tensor_args: ", num_tensor_args)
         # Compile the function and save it in the cache
         if cached_res is None:
             # Save the args_spec for flat_tensor_args to unflatten while tracing
@@ -436,7 +509,7 @@ def flat_fn(*flat_tensor_args):
                 for i in flat_out:
                     is_known_type = False
                     for j in KNOWN_TYPES:
-                        if isinstance(i, j):
+                        if j is None or isinstance(i, j):
                             is_known_type = True
                             break
                     if not is_known_type:
@@ -458,7 +531,7 @@ def flat_fn(*flat_tensor_args):
                 partition_fn,
                 decompositions,
                 grad_state=torch.is_grad_enabled(),
-            ).apply
+            )
             cached_res = (compiled_fn, out_spec)

         # Save the compiled_fn in the cache
@@ -598,7 +671,7 @@ def aot_function_simplified(
         partition_fn,
         decompositions,
         grad_state=torch.is_grad_enabled(),
-    ).apply
+    )

     return compiled_fn

@@ -620,4 +693,4 @@ def forward(self, *args, **kwargs):


 compiled_function = aot_function
-compiled_module = aot_module
+compiled_module = aot_module
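Note on behavior: with this patch the first call into the compiled function only traces the joint graph and compiles the forward module; the backward module is re-derived from a fresh joint trace and compiled lazily inside CompiledFunction.backward, using the saved-value names recorded during the forward. Below is a minimal sketch of how that can be observed from user code; the import path, the toy fn, and the print-only compilers are illustrative assumptions, not part of this commit.

# Sketch only: observing the separate forward/backward compilation.
# The compilers just print and return the FX module unchanged
# (an fx.GraphModule is itself callable), which aot_function accepts.
import torch
from functorch.compile import aot_function  # assumption: adjust the import to this checkout

def fn(x):
    return (x.sin() + x.cos()).sum()

def fw_compiler(fx_module, example_inputs):
    print("forward graph compiled")
    return fx_module

def bw_compiler(fx_module, example_inputs):
    print("backward graph compiled")
    return fx_module

aot_fn = aot_function(fn, fw_compiler, bw_compiler)
x = torch.randn(4, requires_grad=True)

out = aot_fn(x)   # only the forward module is compiled here
out.backward()    # the backward module is traced and compiled on first backward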

functorch/_src/partitioners.py (+46 -2)
@@ -108,8 +108,52 @@ def _extract_fwd_bwd_modules(joint_module: fx.GraphModule, saved_values):

     fwd_module = fx.GraphModule(joint_module, fwd_graph)
     bwd_module = fx.GraphModule(joint_module, bwd_graph)
-    return fwd_module, bwd_module
+    return fwd_module, bwd_module, saved_values

+def _extract_fwd_bwd_modules_db(joint_module: fx.GraphModule, saved_values):
+    fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs(joint_module)
+    print("FWD OUTS: ", fwd_outputs)
+    print("BWD OUTS: ", bwd_outputs)
+    primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
+    tangent_inputs = list(filter(_is_tangent, joint_module.graph.nodes))
+    print("primal_inputs: ", primal_inputs)
+    print("tangent_inputs: ", tangent_inputs)
+    # Construct the forward module
+    fwd_graph = _extract_graph_with_inputs_outputs(joint_module.graph, primal_inputs, fwd_outputs)
+    bwd_graph = _extract_graph_with_inputs_outputs(joint_module.graph, saved_values + tangent_inputs, bwd_outputs)
+
+    # This is to filter out saved values that don't actually end up being used by the backwards pass
+    for node in bwd_graph.nodes:
+        if node.op == 'placeholder' and not node.users:
+            for saved_value in saved_values:
+                if saved_value.name == node.name:
+                    saved_values.remove(saved_value)
+                    break
+
+    # Now, we re-generate the fwd/bwd graphs.
+    # NB: This might increase compilation time, but I doubt it matters
+    fwd_graph = _extract_graph_with_inputs_outputs(joint_module.graph, primal_inputs, fwd_outputs)
+    bwd_graph = _extract_graph_with_inputs_outputs(joint_module.graph, saved_values + tangent_inputs, bwd_outputs)
+
+    fwd_module = fx.GraphModule(joint_module, fwd_graph)
+    bwd_module = fx.GraphModule(joint_module, bwd_graph)
+    return fwd_module, bwd_module, saved_values
+
+def _get_saved_values(new_module: fx.GraphModule, saved_value_names):
+    saved_values = []
+    for node in new_module.graph.nodes:
+        if node.name in saved_value_names:
+            if 'tensor_meta' not in node.meta and node.op == 'call_function':
+                users = node.users
+                assert all(user.target == operator.getitem for user in users)
+                for user in users:
+                    saved_values.append(user)
+            else:
+                saved_values.append(node)
+
+    saved_values = list(saved_values)
+
+    return saved_values

 def default_partition(
     joint_module: fx.GraphModule, _joint_inputs
@@ -153,8 +197,8 @@ def default_partition(
                 saved_values.append(user)
         else:
             saved_values.append(node)
-    saved_values = list(set(saved_values))

+    saved_values = list(saved_values)
     return _extract_fwd_bwd_modules(joint_module, saved_values)


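For reference, a short sketch of the three-value partition API introduced above: default_partition now also returns the saved values, and _get_saved_values re-locates nodes with those names in a freshly traced joint graph, which is what the new backward path relies on. The toy joint function, tensor shapes, and import paths below are assumptions for illustration only.

# Sketch only: build a joint graph with make_fx, partition it, and match
# the saved-value names against a re-traced copy of the same graph.
import torch
from functorch import make_fx
from functorch._src.partitioners import default_partition, _get_saved_values

def joint(primals, tangents):
    # Toy joint forward/backward, mirroring create_joint_forward_backward:
    # forward output plus the gradient w.r.t. the primal input.
    out = primals[0].sin()
    grads = torch.autograd.grad(out, primals, grad_outputs=tangents, allow_unused=True)
    return [out], list(grads)

x = torch.randn(3, requires_grad=True)
g = torch.randn(3)

joint_graph = make_fx(joint)([x], [g])
fwd_module, bwd_module, saved_values = default_partition(joint_graph, ([x], [g]))
print([node.name for node in saved_values])

# Re-trace the joint graph and look the saved values up again by name,
# as CompiledFunction.backward now does.
retraced = make_fx(joint)([x], [g])
print([n.name for n in _get_saved_values(retraced, [node.name for node in saved_values])])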