 from functorch.experimental import functionalize
 from . import config
 from .decompositions import register_decomposition
-from .partitioners import default_partition
+from .partitioners import default_partition, _get_saved_values, _extract_fwd_bwd_modules
 from .named_members_polyfill import _named_parameters, _named_buffers
 from typing import Callable, List, Dict, Any, Tuple, Optional
 from functools import wraps
@@ -70,7 +70,7 @@ def preserve_rng_state():
 
 def create_joint_forward_backward(fn):
     def joint_forward_backward(
-        primals: List[Any], tangents: List[Any]
+        primals: List[Any], cotangents: List[Any]
     ) -> Tuple[List[Any], List[Any]]:
         # Call the forward pass
         outs = fn(*primals)
@@ -84,21 +84,21 @@ def joint_forward_backward(
                 grad_primals.append(p)
 
         # Get the outputs that need gradients
-        assert len(tangents) == len(outs)
+        assert len(cotangents) == len(outs)
         needed_outs = []
-        needed_tangents = []
-        for out, tangent in zip(outs, tangents):
+        needed_cotangents = []
+        for out, cotangent in zip(outs, cotangents):
             if isinstance(out, Tensor) and out.requires_grad:
                 needed_outs.append(out)
-                needed_tangents.append(tangent)
+                needed_cotangents.append(cotangent)
         backward_out = []
         # Call the backwards pass
         if grad_primals:
             backward_out = torch.autograd.grad(
                 needed_outs,
                 grad_primals,
-                grad_outputs=needed_tangents,
-                allow_unused=True,
+                grad_outputs=needed_cotangents,
+                allow_unused=True
             )
         backward_out_iter = iter(backward_out)
         return outs, [
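For orientation, a minimal illustrative sketch (not part of the diff) of what the joint function computes: given the primals and one cotangent per forward output, it returns the forward outputs together with gradients ordered like the primals. The toy fn and shapes below are hypothetical; in aot autograd this function is normally only invoked under make_fx tracing.

import torch

def fn(x, w):
    return (x * w, x + w)

# create_joint_forward_backward is the function defined in the hunk above
joint = create_joint_forward_backward(fn)

x = torch.randn(3, requires_grad=True)
w = torch.randn(3, requires_grad=True)
cotangents = [torch.ones(3), torch.ones(3)]  # one grad_output per forward output
outs, grads = joint([x, w], cotangents)      # grads line up with the primals [x, w]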
@@ -152,22 +152,31 @@ def create_aot_autograd_function(
     if decompositions is None:
         decompositions = {}
     joint_forward_backward = create_joint_forward_backward(flat_fn)
-
+    # create_joint_forward_backward takes inputs and cotangents as inps
+    # inps: inputs, cotangents: flat_grad_outs
+    j_b = None
     compiled_fw = None
-    compiled_bw = None
+    bw_modules = []
     num_outs = None
+    saved_value_names = None
+    aot_decompositions = {**aot_autograd_decompositions, **decompositions}
 
     class CompiledFunction(torch.autograd.Function):
         @staticmethod
         @disable_torchdynamo
         def forward(ctx, *flat_tensor_args):
-            nonlocal compiled_fw, compiled_bw, num_outs
+            # ctx.set_materialize_grads(False)
+            nonlocal compiled_fw, num_outs, saved_value_names, aot_decompositions, j_b
             # Disable the JIT Autocast flag to prevent re-autocasting of jitted graph.
             # TODO - Remove when https://github.com/pytorch/functorch/pull/794 is fixed.
             old_jit_autocast_flag = torch._C._jit_set_autocast_mode(False)
+            # creating this to save the original inputs since the inputs might be returned as outs
+            # and would then have grad_fn set on them which is incorrect.
+            flat_tensor_args_0 = flat_tensor_args
             if compiled_fw is None:
                 with preserve_rng_state():
                     # Set input tensors that require grad to leaves
+                    # Detach to not accidentally extend the graph
                     flat_tensor_args = pytree.tree_map(
                         lambda x: x.detach().requires_grad_(x.requires_grad)
                         if isinstance(x, Tensor) else x, flat_tensor_args
@@ -184,8 +193,9 @@ def forward(ctx, *flat_tensor_args):
                         num_outs = 1
 
                     joint_inputs = (flat_tensor_args, out)
-                    aot_decompositions = {**aot_autograd_decompositions, **decompositions}
                     with torch.set_grad_enabled(grad_state):
+                        # This means the forward and backward graphs are created based on the input fn
+                        # However we need to take in grad_out for the saved intermediates as well.
                         fx_g = make_fx(joint_forward_backward, aot_decompositions)(
                             *joint_inputs
                         )
@@ -196,33 +206,76 @@ def forward(ctx, *flat_tensor_args):
                             def fake_fn(primals, tangents):
                                 return fx_g(primals, tangents)
                             fx_g = make_fx(functionalize(fake_fn))(*joint_inputs)
-                    fw_module, bw_module = partition_fn(fx_g, joint_inputs)
-                    # print(fw_module.code, bw_module.code)
-
+                    fw_module, bw_module, saved_value_nodes = partition_fn(fx_g, joint_inputs)
+                    saved_value_names = [node.name for node in saved_value_nodes]
                 compiled_fw = fw_compiler(fw_module, flat_tensor_args)
                 fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))
-
-                bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs]
-                compiled_bw = bw_compiler(bw_module, bw_args)
+                j_b = create_joint_forward_backward(fw_module)
             else:
                 fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))
+            ctx.num_intermediate = len(fw_outs[num_outs:])
+            to_be_saved = fw_outs[num_outs:] + list(flat_tensor_args_0)
+            ctx.save_for_backward(*to_be_saved)
             torch._C._jit_set_autocast_mode(old_jit_autocast_flag)
-            ctx.save_for_backward(*fw_outs[num_outs:])
-            return tuple(fw_outs[0:num_outs])
+            return tuple(fw_outs)
 
         @staticmethod
         @disable_torchdynamo
-        def backward(ctx, *flat_args):
+        def backward(ctx, *flat_grad_outs):
             # Disable the JIT Autocast flag to prevent re-autocasting of jitted graph.
             # TODO - Remove when https://github.com/pytorch/functorch/pull/794 is fixed.
             old_jit_autocast_flag = torch._C._jit_set_autocast_mode(False)
-            contiguous_args = [t.contiguous() for t in flat_args]
-            # contiguous_args = [t for t in flat_args]
-            out = normalize_as_list(compiled_bw(*ctx.saved_tensors, *contiguous_args))
+            nonlocal bw_modules, saved_value_names, num_outs, aot_decompositions, j_b
+            with preserve_rng_state():
+                intermediates = ctx.saved_tensors[:ctx.num_intermediate]
+                flat_tensor_args = ctx.saved_tensors[ctx.num_intermediate:]
+                flat_tensor_args = pytree.tree_map(
+                    lambda x: x.detach().requires_grad_(x.requires_grad)
+                    if isinstance(x, Tensor) else x, flat_tensor_args
+                )
+                inp_grad_outs = flat_grad_outs
+                with torch.set_grad_enabled(grad_state):
+                    fx_g_b = make_fx(j_b, aot_decompositions)(flat_tensor_args, inp_grad_outs)
+                if config.use_functionalize:
+                    # Functionalize the forward backward graph. First create a
+                    # fake fn to make functionalize happy
+                    def fake_fn(primals, tangents):
+                        return fx_g_b(primals, tangents)
+                    fx_g_b = make_fx(functionalize(fake_fn))(flat_tensor_args, inp_grad_outs)
+                saved_value_nodes = _get_saved_values(fx_g_b, saved_value_names)
+                assert len(saved_value_nodes) <= len(saved_value_names)
+                fw_module_b, bw_module_b, saved_values_new = _extract_fwd_bwd_modules(fx_g_b, saved_value_nodes)
+                if len(saved_values_new) != len(saved_value_names):
+                    new_intermediates = []
+                    # Forward saves more intermediates than needed
+                    assert len(saved_values_new) < len(saved_value_names)
+                    j = 0
+                    for node in saved_values_new:
+                        while node.name != saved_value_names[j]:
+                            j += 1
+                        new_intermediates.append(intermediates[j])
+                        j += 1
+                    intermediates = new_intermediates
+
+                # This is needed because aot function caching uses function id right now
+                bw_module_fn = None
+                for elem in bw_modules:
+                    if elem.code == bw_module_b.code:
+                        bw_module_fn = elem
+                        break
+                if bw_module_fn is None:
+                    bw_modules.append(bw_module_b)
+                    bw_module_fn = bw_module_b
+
+                f = aot_function(bw_module_fn, bw_compiler, bw_compiler, partition_fn, aot_decompositions)
+                out = f(*intermediates, *inp_grad_outs)
             torch._C._jit_set_autocast_mode(old_jit_autocast_flag)
-            return tuple(out)
+            return tuple(normalize_as_list(out))
 
-    return CompiledFunction
+    def return_fn(*args, **kwargs):
+        out = CompiledFunction.apply(*args, **kwargs)
+        return out[0:num_outs]
+    return return_fn
 
 
 class _CompileCache(CompileCache):
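To make the intermediate-filtering step in the backward above concrete, here is a small, self-contained sketch (not part of the diff) of the name-matching loop, using made-up node names and placeholder strings in place of fx nodes and saved tensors. It shows how intermediates recorded at forward time are pruned down to the ones the re-partitioned backward actually needs.

# Hypothetical data; the real code walks fx.Node objects and saved tensors.
class FakeNode:
    def __init__(self, name):
        self.name = name

saved_value_names = ["add", "mul", "sin", "cos"]        # recorded in forward
intermediates = ["t_add", "t_mul", "t_sin", "t_cos"]    # tensors saved by forward
saved_values_new = [FakeNode("mul"), FakeNode("cos")]   # what this backward needs

new_intermediates = []
j = 0
for node in saved_values_new:
    while node.name != saved_value_names[j]:
        j += 1
    new_intermediates.append(intermediates[j])
    j += 1

assert new_intermediates == ["t_mul", "t_cos"]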
@@ -312,7 +365,7 @@ def rearrange(tensor_args, static_args, static_argnums):
     return args
 
 
-KNOWN_TYPES = [torch.Tensor, int, str, float, bool]
+KNOWN_TYPES = [torch.Tensor, int, str, float, bool, None]
 
 
 def aot_function(
@@ -448,7 +501,6 @@ def returned_function(*args, **kwargs):
             hasher_type,
             *flat_args_for_cache,
         )
-
         # Compile the function and save it in the cache
         if cached_res is None:
             # Save the args_spec for flat_tensor_args to unflatten while tracing
@@ -473,7 +525,7 @@ def flat_fn(*flat_tensor_args):
                 for i in flat_out:
                     is_known_type = False
                     for j in KNOWN_TYPES:
-                        if isinstance(i, j):
+                        if j is None or isinstance(i, j):
                             is_known_type = True
                             break
                     if not is_known_type:
@@ -495,7 +547,7 @@ def flat_fn(*flat_tensor_args):
                 partition_fn,
                 decompositions,
                 grad_state=torch.is_grad_enabled(),
-            ).apply
+            )
             cached_res = (compiled_fn, out_spec)
 
             # Save the compiled_fn in the cache
@@ -635,7 +687,7 @@ def aot_function_simplified(
         partition_fn,
         decompositions,
         grad_state=torch.is_grad_enabled(),
-    ).apply
+    )
 
     return compiled_fn
 
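For context, a minimal usage sketch (not part of the diff), assuming the functorch.compile.aot_function entry point of this era; print_compile is a hypothetical pass-through compiler used only for illustration. With this change the returned object is a plain function rather than CompiledFunction.apply, and the backward graph is partitioned and compiled lazily, during the first call to .backward().

import torch
from functorch.compile import aot_function

def print_compile(fx_module, example_inputs):
    # Pass-through "compiler": print the traced graph and run it as-is.
    print(fx_module.code)
    return fx_module

def f(x, w):
    return (x * w).sin().sum()

aot_f = aot_function(f, fw_compiler=print_compile, bw_compiler=print_compile)

x = torch.randn(4, requires_grad=True)
w = torch.randn(4, requires_grad=True)
loss = aot_f(x, w)   # forward graph is traced and compiled on the first call
loss.backward()      # with this change, the backward graph is built and compiled here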