|
6 | 6 |
|
7 | 7 |
|
8 | 8 | # for checkpointing method |
9 | | -def assign_tensors(x, x_out, names, tensors): |
| 9 | +def assign_tensors(x, x_out, names, tensors, view=False): |
10 | 10 | # need to assign b/c state_0, state_1 cannot be swapped |
| 11 | + # if view=True, x and x_out already match except for the tensors listed in names, so copying the remaining wp.array attributes is skipped |
11 | 12 | # TODO: Add fn to get wp.array attributes instead of vars(..) |
12 | | - for name in vars(x): |
13 | | - if name in names: |
14 | | - continue |
15 | | - attr = getattr(x, name) |
16 | | - if isinstance(attr, wp.array): |
17 | | - wp_array = getattr(x_out, name) |
18 | | - wp_array.assign(attr) |
| 13 | + if not view: |
| 14 | + for name in vars(x): |
| 15 | + if name in names: |
| 16 | + continue |
| 17 | + attr = getattr(x, name) |
| 18 | + if isinstance(attr, wp.array): |
| 19 | + wp_array = getattr(x_out, name) |
| 20 | + wp_array.assign(attr) |
19 | 21 | for name, tensor in zip(names, tensors, strict=True): |
20 | 22 | # assert not torch.isnan(tensor).any(), print("NaN tensor", name) |
21 | 23 | wp_array = getattr(x_out, name) |
@@ -115,7 +117,7 @@ def forward( |
115 | 117 | finally: |
116 | 118 | tape.bwd_update_graph = wp.capture_end() |
117 | 119 |
|
118 | | - assign_tensors(model, model_bwd, model_tensors_names, model_tensors) |
| 120 | + assign_tensors(model, model_bwd, model_tensors_names, model_tensors, view=True) |
119 | 121 | assign_tensors(state_in, state_in_bwd, state_tensors_names, state_tensors) |
120 | 122 | assign_tensors(control, control_bwd, control_tensors_names, control_tensors) |
121 | 123 | wp.capture_launch(tape.update_graph) |
@@ -197,7 +199,7 @@ def backward(ctx, *adj_tensors): |
197 | 199 |
|
198 | 200 | if use_graph_capture: |
199 | 201 | # checkpointing method |
200 | | - assign_tensors(model, model_bwd, model_tensors_names, model_tensors) |
| 202 | + assign_tensors(model, model_bwd, model_tensors_names, model_tensors, view=True) |
201 | 203 | assign_tensors(state_in, state_in_bwd, state_tensors_names, state_tensors) |
202 | 204 | assign_tensors(control, control_bwd, control_tensors_names, control_tensors) |
203 | 205 | wp.capture_launch(tape.update_graph) |
|
0 commit comments