@@ -54,7 +54,7 @@ def _dict_unflatten(values: List[Any], context: Context) -> Dict[Any, Any]:
 
 def create_joint_forward_backward(fn):
     def joint_forward_backward(
-        primals: List[Any], tangents: List[Any]
+        primals: List[Any], cotangents: List[Any]
     ) -> Tuple[List[Any], List[Any]]:
         # Call the forward pass
         outs = fn(*primals)
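For context (not part of the diff), the renamed cotangents are the incoming gradients that backward-mode AD consumes, i.e. what torch.autograd.grad accepts as grad_outputs; a minimal illustrative sketch:

import torch

x = torch.randn(3, requires_grad=True)
y = x * 2                        # primal output
cotangent = torch.ones_like(y)   # incoming gradient with respect to y
(grad_x,) = torch.autograd.grad(y, (x,), grad_outputs=cotangent)
assert torch.allclose(grad_x, torch.full_like(x, 2.0))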
@@ -68,21 +68,21 @@ def joint_forward_backward(
                 grad_primals.append(p)
 
         # Get the outputs that need gradients
-        assert len(tangents) == len(outs)
+        assert len(cotangents) == len(outs)
         needed_outs = []
-        needed_tangents = []
-        for out, tangent in zip(outs, tangents):
+        needed_cotangents = []
+        for out, cotangent in zip(outs, cotangents):
             if isinstance(out, Tensor) and out.requires_grad:
                 needed_outs.append(out)
-                needed_tangents.append(tangent)
+                needed_cotangents.append(cotangent)
         backward_out = []
         # Call the backwards pass
         if grad_primals:
             backward_out = torch.autograd.grad(
                 needed_outs,
                 grad_primals,
-                grad_outputs=needed_tangents,
-                allow_unused=True,
+                grad_outputs=needed_cotangents,
+                allow_unused=True
             )
         backward_out_iter = iter(backward_out)
         return outs, [
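A small sketch (outside the diff) of the allow_unused=True behavior the joint graph relies on: torch.autograd.grad returns None for inputs that the outputs do not depend on instead of raising an error:

import torch

a = torch.randn(2, requires_grad=True)
b = torch.randn(2, requires_grad=True)   # not used in the output below
out = (a * 3).sum()
grad_a, grad_b = torch.autograd.grad(out, (a, b), allow_unused=True)
assert grad_b is None                    # unused input yields None rather than an error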
@@ -140,12 +140,13 @@ def create_aot_autograd_function(
     compiled_fw = None
     compiled_bw = None
     num_outs = None
+    aot_decompositions = {**aot_autograd_decompositions, **decompositions}
 
     class CompiledFunction(torch.autograd.Function):
         @staticmethod
         @disable_torchdynamo
         def forward(ctx, *flat_tensor_args):
-            nonlocal compiled_fw, compiled_bw, num_outs
+            nonlocal compiled_fw, num_outs
             if compiled_fw is None:
                 with torch.set_grad_enabled(grad_state):
                     out = flat_fn(*flat_tensor_args)
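For reference (not part of the diff), hoisting aot_decompositions to function scope keeps the usual dict-unpacking precedence, where user-supplied decompositions override the AOT Autograd defaults; an illustrative sketch with made-up keys:

default_table = {"aten.add": "default_rule", "aten.mul": "default_rule"}
user_table = {"aten.mul": "user_rule"}
merged = {**default_table, **user_table}
assert merged["aten.mul"] == "user_rule"     # later unpacking wins
assert merged["aten.add"] == "default_rule"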
@@ -159,31 +160,83 @@ def forward(ctx, *flat_tensor_args):
                     num_outs = 1
 
                 joint_inputs = (flat_tensor_args, out)
-                aot_decompositions = {**aot_autograd_decompositions, **decompositions}
+                # Need it because autograd.Function disables grad in forward
                 with torch.set_grad_enabled(grad_state):
                     fx_g = make_fx(joint_forward_backward, aot_decompositions)(
                         *joint_inputs
                     )
                 fw_module, bw_module = partition_fn(fx_g, joint_inputs)
-                # print(fw_module.code, bw_module.code)
 
                 compiled_fw = fw_compiler(fw_module, flat_tensor_args)
                 fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))
-
-                bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs]
-                compiled_bw = bw_compiler(bw_module, bw_args)
+                if partition_fn is default_partition:
+                    print("ENTERING default_partition")
+                    ctx.num_intermediate = len(fw_outs[num_outs:])
+                    ctx.num_inputs = len(flat_tensor_args)
+                    to_be_saved = fw_outs[num_outs:] + list(flat_tensor_args) + out
+                    print("fw outs: ", fw_outs, "-------")
+                    ctx.save_for_backward(*to_be_saved)
+                    ctx.fwd_graph = fw_module.code
+                else:
+                    nonlocal compiled_bw
+                    bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs]
+                    compiled_bw = bw_compiler(bw_module, bw_args)
+                    ctx.save_for_backward(*fw_outs[num_outs:])
             else:
                 fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))
-            ctx.save_for_backward(*fw_outs[num_outs:])
+                if partition_fn is default_partition:
+                    with torch.set_grad_enabled(grad_state):
+                        out = flat_fn(*flat_tensor_args)
+                    out = pytree.tree_map(
+                        lambda x: x.detach().contiguous() if isinstance(x, Tensor) else x, out
+                    )
+                    ctx.num_intermediate = len(fw_outs[num_outs:])
+                    ctx.num_inputs = len(flat_tensor_args)
+                    to_be_saved = fw_outs[num_outs:] + list(flat_tensor_args) + out
+                    ctx.save_for_backward(*to_be_saved)
+                else:
+                    ctx.save_for_backward(*fw_outs[num_outs:])
             return tuple(fw_outs[0:num_outs])
 
         @staticmethod
         @disable_torchdynamo
-        def backward(ctx, *flat_args):
-            contiguous_args = [t.contiguous() for t in flat_args]
-            # contiguous_args = [t for t in flat_args]
-            out = normalize_as_list(compiled_bw(*ctx.saved_tensors, *contiguous_args))
-            return tuple(out)
+        def backward(ctx, *flat_grad_outs):
+            print(flat_grad_outs)
+            contiguous_args = [t.contiguous() for t in flat_grad_outs]
+            if compiled_bw is None:
+                assert partition_fn is default_partition
+                with torch.set_grad_enabled(grad_state):
+                    inputs = ctx.saved_tensors[ctx.num_intermediate:ctx.num_intermediate + ctx.num_inputs]
+                    fx_g = make_fx(joint_forward_backward, aot_decompositions)(inputs, contiguous_args)
+                    # assert that the forward graph generated here is the same
+                    # if it's specified that the user might want to calculate double backwards
+                    fw_module, bw_module = partition_fn(fx_g, ctx.saved_tensors[ctx.num_intermediate:])
+                    print(fw_module.code)
+                    print(ctx.fwd_graph)
+                    assert fw_module.code == ctx.fwd_graph
+                func_code = bw_module.code.split('self, ')
+                # print(func_code[0] + func_code[1])
+                exec(func_code[0] + func_code[1], globals())
+                f = create_aot_autograd_function(forward, bw_compiler, bw_compiler, partition_fn, aot_decompositions, grad_state)
+                # print(bw_module.code, *ctx.saved_tensors, contiguous_args)
+                # print(*ctx.saved_tensors[:ctx.num_intermediate], *contiguous_args)
+                # print(*ctx.saved_tensors[ctx.num_intermediate:], *contiguous_args)
+                return f.apply(*ctx.saved_tensors[:ctx.num_intermediate], *contiguous_args)
+            else:
+                assert not torch.is_grad_enabled()
+                out = normalize_as_list(compiled_bw(*ctx.saved_tensors, *contiguous_args))
+                return tuple(out)
+            # nonlocal compiled_bw
+            # contiguous_args = [t.contiguous() for t in flat_grad_outs]
+            # if compiled_bw is None:
+            #     with torch.set_grad_enabled(grad_state):
+            #         fx_g = make_fx(joint_forward_backward, aot_decompositions)(joint_inputs[0], contiguous_args)
+            #     # assert that the forward graph generated here is the same
+            #     # if it's specified that the user might want to calculate double backwards
+            #     fw_module, bw_module = partition_fn(fx_g, joint_inputs)
+            #     compiled_bw = bw_compiler(bw_module, fw_outs[num_outs:] + contiguous_args)
+            #     out = normalize_as_list(compiled_bw(*ctx.saved_tensors, *contiguous_args))
+            #     return tuple(out)
 
 
     return CompiledFunction
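For orientation (not part of the diff), CompiledFunction follows the standard torch.autograd.Function pattern: forward stashes tensors via ctx.save_for_backward and backward produces gradients from ctx.saved_tensors, one per forward input; a minimal illustrative sketch:

import torch

class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_out

x = torch.randn(4, requires_grad=True)
Square.apply(x).sum().backward()         # x.grad == 2 * x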