@@ -53,6 +53,7 @@ def _dict_unflatten(values: List[Any], context: Context) -> Dict[Any, Any]:


 def create_joint_forward_backward(fn):
+    # tangents are just grad_outs/cotangents (wrong naming)
     def joint_forward_backward(
         primals: List[Any], tangents: List[Any]
     ) -> Tuple[List[Any], List[Any]]:
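Note on the comment added above: despite the parameter name, `tangents` are the reverse-mode cotangents (the `grad_outputs`), not forward-mode tangents. As an illustration only, here is a minimal standalone sketch of what a joint forward-backward function computes; it is not the functorch implementation, and the `make_joint`/`joint` names are made up:

import torch
from typing import Any, List, Tuple

def make_joint(fn):
    # Sketch: run the forward, then reverse-mode AD, in one traceable call.
    # "tangents" play the role of grad_outputs (cotangents), as the diff's
    # comment notes.
    def joint(primals: List[Any], tangents: List[Any]) -> Tuple[List[Any], List[Any]]:
        outs = fn(*primals)
        outs = list(outs) if isinstance(outs, (list, tuple)) else [outs]
        diff_primals = [p for p in primals
                        if isinstance(p, torch.Tensor) and p.requires_grad]
        grads = torch.autograd.grad(outs, diff_primals,
                                    grad_outputs=list(tangents),
                                    allow_unused=True)
        return outs, list(grads)
    return joint

# Usage: the second argument is the cotangent for the single output.
x = torch.randn(3, requires_grad=True)
outs, grads = make_joint(lambda t: t.sin())([x], [torch.ones(3)])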
@@ -140,12 +141,14 @@ def create_aot_autograd_function(
     compiled_fw = None
     compiled_bw = None
     num_outs = None
-
+    joint_inputs = None
+    fw_outs = None
+    aot_decompositions = {**aot_autograd_decompositions, **decompositions}
     class CompiledFunction(torch.autograd.Function):
         @staticmethod
         @disable_torchdynamo
         def forward(ctx, *flat_tensor_args):
-            nonlocal compiled_fw, compiled_bw, num_outs
+            nonlocal compiled_fw, num_outs, joint_inputs, fw_outs
             if compiled_fw is None:
                 with torch.set_grad_enabled(grad_state):
                     out = flat_fn(*flat_tensor_args)
@@ -159,29 +162,34 @@ def forward(ctx, *flat_tensor_args):
                     num_outs = 1

                 joint_inputs = (flat_tensor_args, out)
-                aot_decompositions = {**aot_autograd_decompositions, **decompositions}
+                # Need it because autograd.Function disables grad in forward
                 with torch.set_grad_enabled(grad_state):
                     fx_g = make_fx(joint_forward_backward, aot_decompositions)(
                         *joint_inputs
                     )
                 fw_module, bw_module = partition_fn(fx_g, joint_inputs)
-                # print(fw_module.code, bw_module.code)

                 compiled_fw = fw_compiler(fw_module, flat_tensor_args)
                 fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))
-
-                bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs]
-                compiled_bw = bw_compiler(bw_module, bw_args)
+                if partition_fn is default_partition:
+                    nonlocal compiled_bw
+                    bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs]
+                    compiled_bw = bw_compiler(bw_module, bw_args)
             else:
                 fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))
             ctx.save_for_backward(*fw_outs[num_outs:])
             return tuple(fw_outs[0:num_outs])

         @staticmethod
         @disable_torchdynamo
-        def backward(ctx, *flat_args):
-            contiguous_args = [t.contiguous() for t in flat_args]
-            # contiguous_args = [t for t in flat_args]
+        def backward(ctx, *flat_grad_outs):
+            nonlocal compiled_bw
+            contiguous_args = [t.contiguous() for t in flat_grad_outs]
+            if compiled_bw is None:
+                with torch.set_grad_enabled(grad_state):
+                    fx_g = make_fx(joint_forward_backward, aot_decompositions)(joint_inputs[0], contiguous_args)
+                fw_module, bw_module = partition_fn(fx_g, joint_inputs)
+                compiled_bw = bw_compiler(bw_module, fw_outs[num_outs:] + contiguous_args)
             out = normalize_as_list(compiled_bw(*ctx.saved_tensors, *contiguous_args))
             return tuple(out)

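Net effect of the change above: with `default_partition` the backward graph is still partitioned and compiled eagerly in `forward`, while for any other partitioner the joint graph is re-traced and the backward is compiled lazily, on the first `backward` call, once the actual (contiguous) grad_outputs are available. As a rough standalone sketch of that compile-on-first-backward pattern (the `lazily_compiled_backward` helper and the no-op compiler are illustrative, not part of this codebase):

import torch

def lazily_compiled_backward(bw_fn, bw_compiler):
    # Keep the compiled backward in a closure cell and build it only when
    # the first real grad_outputs arrive, so the compiler can see their
    # actual shapes and strides.
    compiled_bw = None

    def run_backward(saved, grad_outs):
        nonlocal compiled_bw
        contiguous = [g.contiguous() for g in grad_outs]
        if compiled_bw is None:
            compiled_bw = bw_compiler(bw_fn, list(saved) + contiguous)
        return compiled_bw(*saved, *contiguous)

    return run_backward

# Usage with a no-op "compiler" that returns the function unchanged.
run_bw = lazily_compiled_backward(lambda s, g: s * g, lambda f, example_args: f)
print(run_bw((torch.randn(3),), (torch.ones(3),)))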