@@ -1,14 +1,14 @@
 import torch
 import torch.nn as nn
-from torch import Tensor
+from torch import Tensor, is_grad_enabled
 from functorch import make_fx
 from torch.fx import immutable_collections
 import torch.utils._pytree as pytree
 import torch.utils.dlpack
 from torch.nn.utils import _stateless
 from functorch._C import CompileCache
 from .decompositions import register_decomposition
-from .partitioners import default_partition
+from .partitioners import default_partition, _get_saved_values, _extract_fwd_bwd_modules
 from .named_members_polyfill import _named_parameters, _named_buffers
 from typing import Callable, List, Dict, Any, Tuple, Optional
 from functools import wraps
@@ -54,7 +54,7 @@ def _dict_unflatten(values: List[Any], context: Context) -> Dict[Any, Any]:
 
 def create_joint_forward_backward(fn):
     def joint_forward_backward(
-        primals: List[Any], tangents: List[Any]
+        primals: List[Any], cotangents: List[Any]
     ) -> Tuple[List[Any], List[Any]]:
         # Call the forward pass
         outs = fn(*primals)
@@ -68,21 +68,21 @@ def joint_forward_backward(
                 grad_primals.append(p)
 
         # Get the outputs that need gradients
-        assert len(tangents) == len(outs)
+        assert len(cotangents) == len(outs)
         needed_outs = []
-        needed_tangents = []
-        for out, tangent in zip(outs, tangents):
+        needed_cotangents = []
+        for out, cotangent in zip(outs, cotangents):
             if isinstance(out, Tensor) and out.requires_grad:
                 needed_outs.append(out)
-                needed_tangents.append(tangent)
+                needed_cotangents.append(cotangent)
         backward_out = []
         # Call the backwards pass
         if grad_primals:
             backward_out = torch.autograd.grad(
                 needed_outs,
                 grad_primals,
-                grad_outputs=needed_tangents,
-                allow_unused=True,
+                grad_outputs=needed_cotangents,
+                allow_unused=True
             )
         backward_out_iter = iter(backward_out)
         return outs, [
@@ -138,14 +138,18 @@ def create_aot_autograd_function(
     joint_forward_backward = create_joint_forward_backward(flat_fn)
 
     compiled_fw = None
-    compiled_bw = None
+    fw_module = None
+    bw_modules = []
     num_outs = None
+    saved_value_names = None
+    aot_decompositions = {**aot_autograd_decompositions, **decompositions}
 
     class CompiledFunction(torch.autograd.Function):
         @staticmethod
         @disable_torchdynamo
         def forward(ctx, *flat_tensor_args):
-            nonlocal compiled_fw, compiled_bw, num_outs
+            ctx.set_materialize_grads(False)
+            nonlocal compiled_fw, num_outs, fw_module, saved_value_names
             if compiled_fw is None:
                 with torch.set_grad_enabled(grad_state):
                     out = flat_fn(*flat_tensor_args)
@@ -159,34 +163,78 @@ def forward(ctx, *flat_tensor_args):
                     num_outs = 1
 
                 joint_inputs = (flat_tensor_args, out)
-                aot_decompositions = {**aot_autograd_decompositions, **decompositions}
+                # Need it because autograd.Function disables grad in forward
                 with torch.set_grad_enabled(grad_state):
                     fx_g = make_fx(joint_forward_backward, aot_decompositions)(
                         *joint_inputs
                     )
-                fw_module, bw_module = partition_fn(fx_g, joint_inputs)
-                # print(fw_module.code, bw_module.code)
-
+                # This means the forward and backward graphs are created based on the input fn
+                # However we need to take in grad_out for the saved intermediates as well.
+                fw_module, bw_module, saved_value_nodes = partition_fn(fx_g, joint_inputs)
+                saved_value_names = [node.name for node in saved_value_nodes]
                 compiled_fw = fw_compiler(fw_module, flat_tensor_args)
                 fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))
-
-                bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs]
-                compiled_bw = bw_compiler(bw_module, bw_args)
             else:
                 fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))
-            ctx.save_for_backward(*fw_outs[num_outs:])
-            return tuple(fw_outs[0:num_outs])
+
+            ctx.num_intermediate = len(fw_outs[num_outs:])
+            ctx.num_inputs = len(flat_tensor_args)
+            to_be_saved = fw_outs[num_outs:] + list(flat_tensor_args) + fw_outs[0:num_outs]
+            ctx.save_for_backward(*to_be_saved)
+            return tuple(fw_outs)
 
         @staticmethod
         @disable_torchdynamo
-        def backward(ctx, *flat_args):
-            contiguous_args = [t.contiguous() for t in flat_args]
-            # contiguous_args = [t for t in flat_args]
-            out = normalize_as_list(compiled_bw(*ctx.saved_tensors, *contiguous_args))
-            return tuple(out)
-
-    return CompiledFunction
-
+        def backward(ctx, *flat_grad_outs):
+            nonlocal fw_module, bw_modules, saved_value_names
+            intermediates = ctx.saved_tensors[:ctx.num_intermediate]
+            inputs = ctx.saved_tensors[ctx.num_intermediate:ctx.num_intermediate + ctx.num_inputs]
+            is_grad_enabled = torch.is_grad_enabled()
+
+            if not is_grad_enabled:
+                input_flat_grad_outs = []
+                for grad in flat_grad_outs:
+                    if grad is not None:
+                        input_flat_grad_outs.append(grad)
+                with torch.set_grad_enabled(grad_state):
+                    fx_g_b = make_fx(joint_forward_backward, aot_decompositions)(inputs, input_flat_grad_outs)
+            else:
+                input_flat_grad_outs = flat_grad_outs
+                j_b = create_joint_forward_backward(fw_module)
+                with torch.set_grad_enabled(grad_state):
+                    fx_g_b = make_fx(j_b, aot_decompositions)(inputs, input_flat_grad_outs)
+
+            saved_value_nodes = _get_saved_values(fx_g_b, saved_value_names)
+            assert len(saved_value_nodes) <= len(saved_value_names)
+            fw_module_b, bw_module_b, saved_values_new = _extract_fwd_bwd_modules(fx_g_b, saved_value_nodes)
+            bw_module_fn = None
+            for elem in bw_modules:
+                if elem.code == bw_module_b.code:
+                    bw_module_fn = elem
+            if bw_module_fn is None:
+                bw_modules.append(bw_module_b)
+                bw_module_fn = bw_module_b
+
+            f = aot_function(bw_module_fn, bw_compiler, bw_compiler, partition_fn, aot_decompositions)
+
+            if len(saved_values_new) != len(saved_value_names):
+                new_intermediates = []
+                # Forward saves more intermediates than needed
+                assert len(saved_values_new) < len(saved_value_names)
+                j = 0
+                for node in saved_values_new:
+                    while node.name != saved_value_names[j]:
+                        j += 1
+                    new_intermediates.append(intermediates[j])
+                    j += 1
+                intermediates = new_intermediates
+            out = f(*intermediates, *input_flat_grad_outs)
+            return tuple(normalize_as_list(out))
+
+    def return_fn(*args, **kwargs):
+        out = CompiledFunction.apply(*args, **kwargs)
+        return out[0:num_outs]
+    return return_fn
 
 class _CompileCache(CompileCache):
     pass
@@ -275,7 +323,7 @@ def rearrange(tensor_args, static_args, static_argnums):
     return args
 
 
-KNOWN_TYPES = [torch.Tensor, int, str, float, bool]
+KNOWN_TYPES = [torch.Tensor, int, str, float, bool, None]
 
 
 def aot_function(
@@ -411,7 +459,9 @@ def returned_function(*args, **kwargs):
             hasher_type,
             *flat_args_for_cache,
         )
-
+        # print("fn_id: ", fn_id)
+        # print("size: ", compile_cache.size())
+        # print("num_tensor_args: ", num_tensor_args)
         # Compile the function and save it in the cache
         if cached_res is None:
             # Save the args_spec for flat_tensor_args to unflatten while tracing
@@ -436,7 +486,7 @@ def flat_fn(*flat_tensor_args):
                 for i in flat_out:
                     is_known_type = False
                     for j in KNOWN_TYPES:
-                        if isinstance(i, j):
+                        if j is None or isinstance(i, j):
                             is_known_type = True
                             break
                     if not is_known_type:
@@ -458,7 +508,7 @@ def flat_fn(*flat_tensor_args):
                 partition_fn,
                 decompositions,
                 grad_state=torch.is_grad_enabled(),
-            ).apply
+            )
             cached_res = (compiled_fn, out_spec)
 
             # Save the compiled_fn in the cache
@@ -598,7 +648,7 @@ def aot_function_simplified(
         partition_fn,
         decompositions,
         grad_state=torch.is_grad_enabled(),
-    ).apply
+    )
 
     return compiled_fn
 
@@ -620,4 +670,4 @@ def forward(self, *args, **kwargs):
 
 
 compiled_function = aot_function
-compiled_module = aot_module
+compiled_module = aot_module
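
For reference, a minimal usage sketch of the patched entry point (an illustration under assumptions, not part of the diff): it assumes the functorch.compile import path, and "noop_compiler" is a hypothetical stand-in for any fw_compiler/bw_compiler that returns a callable. With this change, create_aot_autograd_function returns return_fn (a plain callable wrapping CompiledFunction.apply) rather than CompiledFunction.apply itself, but the user-facing calling convention of aot_function is unchanged; the backward graph is traced, partitioned, and compiled lazily inside the new CompiledFunction.backward.

# Illustrative sketch only; names below are assumptions, not part of this patch.
import torch
from functorch.compile import aot_function

def noop_compiler(fx_module, example_inputs):
    # Identity "compiler": return the traced FX module unchanged.
    return fx_module

def fn(x, y):
    return (x * y).sin().sum()

compiled_fn = aot_function(fn, fw_compiler=noop_compiler, bw_compiler=noop_compiler)

x = torch.randn(4, requires_grad=True)
y = torch.randn(4, requires_grad=True)
compiled_fn(x, y).backward()  # backward graph is built and compiled on first use
print(x.grad, y.grad)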