Skip to content

Commit df92f42

Browse files
committed
Add BwdModelView instead of shallow copy for bwd model
1 parent aa0d049 commit df92f42

File tree

2 files changed

+31
-19
lines changed

2 files changed

+31
-19
lines changed

rewarped/autograd.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,18 @@
66

77

88
# for checkpointing method
9-
def assign_tensors(x, x_out, names, tensors):
9+
def assign_tensors(x, x_out, names, tensors, view=False):
1010
# need to assign b/c state_0, state_1 cannot be swapped
11+
# if view=True, then x == x_out except for tensors given by names, so we can skip assigning some
1112
# TODO: Add fn to get wp.array attributes instead of vars(..)
12-
for name in vars(x):
13-
if name in names:
14-
continue
15-
attr = getattr(x, name)
16-
if isinstance(attr, wp.array):
17-
wp_array = getattr(x_out, name)
18-
wp_array.assign(attr)
13+
if not view:
14+
for name in vars(x):
15+
if name in names:
16+
continue
17+
attr = getattr(x, name)
18+
if isinstance(attr, wp.array):
19+
wp_array = getattr(x_out, name)
20+
wp_array.assign(attr)
1921
for name, tensor in zip(names, tensors, strict=True):
2022
# assert not torch.isnan(tensor).any(), print("NaN tensor", name)
2123
wp_array = getattr(x_out, name)
@@ -115,7 +117,7 @@ def forward(
115117
finally:
116118
tape.bwd_update_graph = wp.capture_end()
117119

118-
assign_tensors(model, model_bwd, model_tensors_names, model_tensors)
120+
assign_tensors(model, model_bwd, model_tensors_names, model_tensors, view=True)
119121
assign_tensors(state_in, state_in_bwd, state_tensors_names, state_tensors)
120122
assign_tensors(control, control_bwd, control_tensors_names, control_tensors)
121123
wp.capture_launch(tape.update_graph)
@@ -197,7 +199,7 @@ def backward(ctx, *adj_tensors):
197199

198200
if use_graph_capture:
199201
# checkpointing method
200-
assign_tensors(model, model_bwd, model_tensors_names, model_tensors)
202+
assign_tensors(model, model_bwd, model_tensors_names, model_tensors, view=True)
201203
assign_tensors(state_in, state_in_bwd, state_tensors_names, state_tensors)
202204
assign_tensors(control, control_bwd, control_tensors_names, control_tensors)
203205
wp.capture_launch(tape.update_graph)

rewarped/warp_env.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import copy
2-
31
import numpy as np
42
import torch
53
from gym import spaces
@@ -52,6 +50,23 @@ def __getattr__(self, name):
5250
return wp.to_torch(getattr(self.data, name))
5351

5452

53+
class BwdModelView:
    """Read-through view of a Warp model used for the backward pass.

    Attribute access falls through to the wrapped ``model`` for everything
    except the names listed in ``model_tensors_names``; those are served
    from freshly allocated zeroed arrays (matching shape and ``requires_grad``
    of the originals), so the backward copy shares all other model state.
    """

    def __init__(self, model, model_tensors_names):
        self.model = model
        self.model_tensors_names = model_tensors_names

        # One zeroed replacement array per overridden attribute name.
        self.bwd_tensors = {
            name: wp.zeros_like(src, requires_grad=src.requires_grad)
            for name, src in ((n, getattr(model, n)) for n in model_tensors_names)
        }

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails, i.e. for
        # anything not set directly on this view instance.
        if name in self.model_tensors_names:
            return self.bwd_tensors[name]
        return getattr(self.model, name)
68+
69+
5570
class WarpEnv(Environment):
5671
r"""Base class for gym-like Warp environments that builds on `Environment`.
5772
@@ -290,13 +305,8 @@ def init_sim(self):
290305
if self.use_graph_capture:
291306
self.tape = wp.Tape() # persistent tape for graph capture
292307

293-
# shallow copy
294-
# TODO: need a better day to have separate copies when not using graph capture (for randomization)
295-
self.model_bwd = copy.copy(self.model)
296-
for k in self.model_tensors_names:
297-
v = getattr(self.model_bwd, k)
298-
v = wp.zeros_like(v, requires_grad=self.requires_grad)
299-
setattr(self.model_bwd, k, v)
308+
# shallow copy of model with new arrays for `model_tensors`
309+
self.model_bwd = BwdModelView(self.model, self.model_tensors_names)
300310

301311
self.state_0_bwd = self.model.state(copy="zeros")
302312
self.state_1_bwd = self.model.state(copy="zeros")

0 commit comments

Comments
 (0)