Skip to content

Commit b9364bb

Browse files
authored
[Feature] RND Implementation (#3889)
1 parent 120ffde commit b9364bb

7 files changed

Lines changed: 751 additions & 0 deletions

File tree

docs/source/reference/envs_transforms.rst

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -305,10 +305,12 @@ Available Transforms
305305
RemoveEmptySpecs
306306
RenameTransform
307307
Resize
308+
RNDTransform
308309
Reward2GoTransform
309310
RewardClipping
310311
RewardScaling
311312
RewardSum
313+
RunningMeanStd
312314
SelectTransform
313315
SignTransform
314316
SqueezeTransform

docs/source/reference/objectives_other.rst

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@ Additional loss modules for specialized algorithms.
2020
DreamerValueLoss
2121
WorldModelLoss
2222
ExponentialQuadraticCost
23+
RNDLoss
2324

2425
DreamerV3
2526
---------

test/test_rnd.py

Lines changed: 321 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,321 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
from __future__ import annotations
6+
7+
import argparse
8+
9+
import pytest
10+
import torch
11+
import torch.nn as nn
12+
from tensordict import TensorDict
13+
14+
from torchrl.data.tensor_specs import Composite, Unbounded
15+
from torchrl.envs.transforms import RNDTransform, RunningMeanStd
16+
from torchrl.objectives import RNDLoss
17+
from torchrl.testing import get_default_devices
18+
19+
20+
# ---------------------------------------------------------------------------
21+
# Helpers
22+
# ---------------------------------------------------------------------------
23+
24+
25+
def _make_networks(
    obs_dim: int = 4, embed_dim: int = 16
) -> tuple[nn.Sequential, nn.Sequential]:
    """Build a ``(target, predictor)`` pair of identically shaped MLPs.

    Each network is ``Linear(obs_dim, embed_dim) -> ReLU -> Linear(embed_dim,
    embed_dim)``. The two networks are constructed independently, so they do
    not share parameters.

    Args:
        obs_dim: size of the last dimension of the network input.
        embed_dim: width of the hidden layer and of the output embedding.

    Returns:
        A ``(target, predictor)`` tuple of ``nn.Sequential`` modules.
    """

    def _mlp() -> nn.Sequential:
        # Single definition of the architecture shared by both networks.
        return nn.Sequential(
            nn.Linear(obs_dim, embed_dim), nn.ReLU(), nn.Linear(embed_dim, embed_dim)
        )

    # Target is built before predictor so the order in which the global RNG is
    # consumed for weight initialisation stays the same as before.
    target = _mlp()
    predictor = _mlp()
    return target, predictor
33+
34+
35+
# ---------------------------------------------------------------------------
36+
# RunningMeanStd
37+
# ---------------------------------------------------------------------------
38+
39+
40+
class TestRunningMeanStd:
    """Unit tests for the ``RunningMeanStd`` running-statistics module."""

    def test_scalar_update(self):
        """One update on scalar data should reproduce the batch mean/variance."""
        rms = RunningMeanStd(shape=())
        x = torch.arange(100, dtype=torch.float32)
        rms.update(x)
        assert abs(rms.mean.item() - x.mean().item()) < 1e-3
        assert abs(rms.var.item() - x.var(unbiased=False).item()) < 1e-1

    def test_vector_update(self):
        """One update on ``(N, 4)`` data should track per-feature statistics."""
        rms = RunningMeanStd(shape=(4,))
        x = torch.randn(1000, 4)
        rms.update(x)
        assert torch.allclose(rms.mean, x.mean(0), atol=0.1)
        assert torch.allclose(rms.var, x.var(0, unbiased=False), atol=0.2)

    def test_incremental_updates(self):
        """Two half-batch updates should agree with one full-batch update."""
        rms = RunningMeanStd(shape=(4,))
        full = torch.randn(200, 4)
        rms.update(full[:100])
        rms.update(full[100:])
        rms_full = RunningMeanStd(shape=(4,))
        rms_full.update(full)
        assert torch.allclose(rms.mean, rms_full.mean, atol=1e-4)
        assert torch.allclose(rms.var, rms_full.var, atol=1e-4)

    def test_normalize_shape_preserved(self):
        """``normalize`` should return a tensor with the input's shape."""
        rms = RunningMeanStd(shape=(4,))
        x = torch.randn(8, 4)
        rms.update(x)
        out = rms.normalize(x)
        assert out.shape == x.shape

    def test_normalize_nested_key(self):
        """``update``/``normalize`` should accept two leading batch dimensions.

        Despite the test name, no ``NestedKey`` is involved: the input is a
        plain ``(3, 5, 4)`` tensor and only the output shape is checked.
        """
        rms = RunningMeanStd(shape=(4,))
        x = torch.randn(3, 5, 4)
        rms.update(x)
        out = rms.normalize(x)
        assert out.shape == x.shape

    def test_state_dict_roundtrip(self):
        """Mean and variance should survive ``state_dict``/``load_state_dict``."""
        rms = RunningMeanStd(shape=(4,))
        rms.update(torch.randn(32, 4))
        sd = rms.state_dict()
        rms2 = RunningMeanStd(shape=(4,))
        rms2.load_state_dict(sd)
        assert torch.allclose(rms.mean, rms2.mean)
        assert torch.allclose(rms.var, rms2.var)

    @pytest.mark.parametrize("device", get_default_devices())
    def test_device_move(self, device):
        """After ``.to(device)`` the normalised output should live on that device."""
        rms = RunningMeanStd(shape=(4,)).to(device)
        x = torch.randn(16, 4, device=device)
        rms.update(x)
        out = rms.normalize(x)
        # Compare device *types* so that e.g. "cuda" matches "cuda:0".
        assert out.device.type == torch.device(device).type
96+
97+
98+
# ---------------------------------------------------------------------------
99+
# RNDTransform
100+
# ---------------------------------------------------------------------------
101+
102+
103+
class TestRNDTransform:
    """Unit tests for ``RNDTransform`` (intrinsic-reward environment transform)."""

    @staticmethod
    def _run_step(transform, obs, key="observation"):
        """Call ``transform._step`` on a fresh next-tensordict holding ``obs``.

        The batch size is every dimension of ``obs`` except the last (feature)
        one. Returns the next-tensordict that was passed to ``_step`` so callers
        can inspect what the transform wrote into it.
        """
        batch_size = obs.shape[:-1]
        next_td = TensorDict({key: obs}, batch_size=batch_size)
        transform._step(TensorDict({}, batch_size=batch_size), next_td)
        return next_td

    @pytest.mark.parametrize("device", get_default_devices())
    def test_intrinsic_reward_written(self, device):
        """An unbatched step should write an ``intrinsic_reward`` of shape ``[1]``."""
        target, predictor = _make_networks()
        transform = RNDTransform(target, predictor).to(device)
        next_td = self._run_step(transform, torch.randn(4, device=device))
        assert "intrinsic_reward" in next_td.keys()
        assert next_td["intrinsic_reward"].shape == torch.Size([1])

    @pytest.mark.parametrize("device", get_default_devices())
    def test_batched_intrinsic_reward(self, device):
        """A batch of 8 observations should yield a reward of shape ``[8, 1]``."""
        target, predictor = _make_networks()
        transform = RNDTransform(target, predictor).to(device)
        next_td = self._run_step(transform, torch.randn(8, 4, device=device))
        assert next_td["intrinsic_reward"].shape == torch.Size([8, 1])

    def test_target_frozen(self):
        """No parameter of the transform's target network should require grad."""
        target, predictor = _make_networks()
        transform = RNDTransform(target, predictor)
        for p in transform.target_network.parameters():
            assert not p.requires_grad

    def test_obs_rms_updated_in_train_mode(self):
        """In train mode a step should create and update ``obs_rms``."""
        target, predictor = _make_networks()
        transform = RNDTransform(target, predictor, normalize_obs=True)
        transform.train()
        self._run_step(transform, torch.randn(32, 4))
        assert transform.obs_rms is not None
        assert transform.obs_rms.count.item() > 1e-4

    def test_obs_rms_not_updated_in_eval_mode(self):
        """In eval mode a step should leave the ``obs_rms`` count untouched."""
        target, predictor = _make_networks()
        transform = RNDTransform(target, predictor, normalize_obs=True)
        transform.train()
        obs = torch.randn(32, 4)
        self._run_step(transform, obs)
        count_after_train = transform.obs_rms.count.item()

        transform.eval()
        self._run_step(transform, obs)
        assert transform.obs_rms.count.item() == count_after_train

    def test_no_normalization(self):
        """With both normalisations off, no running stats should be created."""
        target, predictor = _make_networks()
        transform = RNDTransform(
            target, predictor, normalize_obs=False, normalize_reward=False
        )
        transform.train()
        next_td = self._run_step(transform, torch.randn(8, 4))
        assert transform.obs_rms is None
        assert transform.reward_rms is None
        assert "intrinsic_reward" in next_td.keys()

    def test_reward_rms_updated(self):
        """Repeated train-mode steps should accumulate counts in ``reward_rms``."""
        target, predictor = _make_networks()
        transform = RNDTransform(target, predictor, normalize_reward=True)
        transform.train()
        for _ in range(5):
            self._run_step(transform, torch.randn(16, 4))
        assert transform.reward_rms is not None
        assert transform.reward_rms.count.item() > 1

    def test_custom_keys(self):
        """Custom ``in_keys``/``out_keys`` should be read from and written to."""
        target, predictor = _make_networks()
        transform = RNDTransform(
            target,
            predictor,
            in_keys=["obs_feat"],
            out_keys=["curiosity"],
        )
        next_td = self._run_step(transform, torch.randn(4), key="obs_feat")
        assert "curiosity" in next_td.keys()

    def test_state_dict_includes_rms(self):
        """After one train-mode step the state dict should contain ``obs_rms``."""
        target, predictor = _make_networks()
        transform = RNDTransform(target, predictor)
        transform.train()
        self._run_step(transform, torch.randn(8, 4))
        sd = transform.state_dict()
        assert any("obs_rms" in k for k in sd)

    def test_state_dict_roundtrip_with_lazy_rms(self):
        """Loading a state dict into a fresh transform should restore the RMS.

        The copy has never been stepped before ``load_state_dict`` is called,
        so this checks that loading alone leaves it with populated ``obs_rms``
        and ``reward_rms`` matching the source transform.
        """
        target, predictor = _make_networks()
        transform = RNDTransform(target, predictor)
        transform.train()
        self._run_step(transform, torch.randn(8, 4))

        target_copy, predictor_copy = _make_networks()
        transform_copy = RNDTransform(target_copy, predictor_copy)
        transform_copy.load_state_dict(transform.state_dict())

        assert transform_copy.obs_rms is not None
        assert transform_copy.reward_rms is not None
        assert torch.allclose(transform.obs_rms.mean, transform_copy.obs_rms.mean)
        assert torch.allclose(transform.reward_rms.var, transform_copy.reward_rms.var)

    def test_transform_reward_spec_has_reward_shape(self):
        """The added ``intrinsic_reward`` spec should have shape ``[3, 1]``.

        The input composite spec has batch shape ``(3,)`` and a ``reward``
        entry of shape ``(3, 1)``.
        """
        target, predictor = _make_networks()
        transform = RNDTransform(target, predictor)
        reward_spec = Composite(
            reward=Unbounded(shape=(3, 1), dtype=torch.float32),
            shape=(3,),
        )

        reward_spec = transform.transform_reward_spec(reward_spec)

        assert reward_spec["intrinsic_reward"].shape == torch.Size([3, 1])
228+
229+
230+
# ---------------------------------------------------------------------------
231+
# RNDLoss
232+
# ---------------------------------------------------------------------------
233+
234+
235+
class TestRNDLoss:
    """Unit tests for ``RNDLoss`` (predictor-network distillation loss)."""

    @pytest.mark.parametrize("device", get_default_devices())
    def test_forward_returns_loss(self, device):
        """The forward pass should return a scalar ``loss_predictor`` entry."""
        target, predictor = _make_networks()
        loss_fn = RNDLoss(predictor, target).to(device)
        batch = TensorDict(
            {"next": {"observation": torch.randn(32, 4, device=device)}}, [32]
        )
        out = loss_fn(batch)
        assert "loss_predictor" in out.keys()
        assert out["loss_predictor"].shape == torch.Size([])

    def test_backward(self):
        """Backward should populate predictor grads and leave target grads unset."""
        target, predictor = _make_networks()
        loss_fn = RNDLoss(predictor, target)
        batch = TensorDict({"next": {"observation": torch.randn(32, 4)}}, [32])
        out = loss_fn(batch)
        out["loss_predictor"].backward()
        for p in predictor.parameters():
            assert p.grad is not None
        for p in target.parameters():
            assert p.grad is None

    def test_target_frozen(self):
        """No parameter of the loss module's target network should require grad."""
        target, predictor = _make_networks()
        loss_fn = RNDLoss(predictor, target)
        for p in loss_fn.target_network.parameters():
            assert not p.requires_grad

    def test_update_fraction_reduces_effective_batch(self):
        """Smoke test: the loss is a scalar for full and fractional updates.

        Only the output shape is asserted; the test does not measure how many
        samples actually contribute to the loss.
        """
        torch.manual_seed(0)
        target, predictor = _make_networks()
        loss_full = RNDLoss(predictor, target, update_fraction=1.0)
        loss_partial = RNDLoss(predictor, target, update_fraction=0.25)
        batch = TensorDict({"next": {"observation": torch.randn(1000, 4)}}, [1000])
        # Both should return a scalar without error
        out_full = loss_full(batch)
        out_partial = loss_partial(batch)
        assert out_full["loss_predictor"].shape == torch.Size([])
        assert out_partial["loss_predictor"].shape == torch.Size([])

    def test_set_keys(self):
        """``set_keys`` should redirect where the observation is read from."""
        target, predictor = _make_networks()
        loss_fn = RNDLoss(predictor, target)
        loss_fn.set_keys(observation=("next", "obs_encoded"))
        assert loss_fn.tensor_keys.observation == ("next", "obs_encoded")
        batch = TensorDict({"next": {"obs_encoded": torch.randn(16, 4)}}, [16])
        out = loss_fn(batch)
        assert "loss_predictor" in out.keys()

    def test_obs_rms_shared_with_transform(self):
        """A loss built with the transform's ``obs_rms`` should train the predictor.

        The transform is stepped once in train mode so that ``obs_rms`` exists,
        then the same object is handed to ``RNDLoss``. The loss must be finite
        and backward must populate the predictor's gradients.
        """
        target, predictor = _make_networks()
        transform = RNDTransform(target, predictor, normalize_obs=True)
        transform.train()
        obs = torch.randn(64, 4)
        next_td = TensorDict({"observation": obs}, batch_size=[64])
        transform._step(TensorDict({}, batch_size=[64]), next_td)

        loss_fn = RNDLoss(predictor, target, obs_rms=transform.obs_rms)
        batch = TensorDict({"next": {"observation": obs}}, [64])
        out = loss_fn(batch)
        assert torch.isfinite(out["loss_predictor"]).all()
        out["loss_predictor"].backward()
        # Without these checks the test would pass as long as nothing raised.
        for p in predictor.parameters():
            assert p.grad is not None

    @pytest.mark.parametrize("reduction", ["mean", "sum", "none"])
    def test_reduction_modes(self, reduction):
        """``"none"`` should keep the batch dimension; other modes give a scalar."""
        target, predictor = _make_networks()
        loss_fn = RNDLoss(predictor, target, reduction=reduction, update_fraction=1.0)
        batch = TensorDict({"next": {"observation": torch.randn(16, 4)}}, [16])
        out = loss_fn(batch)
        if reduction == "none":
            assert out["loss_predictor"].shape == torch.Size([16])
        else:
            assert out["loss_predictor"].shape == torch.Size([])

    def test_nested_observation_key(self):
        """NestedKey tuple should work as the observation key."""
        target, predictor = _make_networks()
        loss_fn = RNDLoss(predictor, target)
        loss_fn.set_keys(observation=("next", "obs"))
        batch = TensorDict({"next": {"obs": torch.randn(8, 4)}}, [8])
        out = loss_fn(batch)
        assert "loss_predictor" in out.keys()
317+
318+
319+
if __name__ == "__main__":
    # Only the flags argparse does not recognise are needed; they are passed
    # through to pytest unchanged.
    _, passthrough = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst", *passthrough])

torchrl/envs/transforms/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@
1717
from .r3m import R3MTransform
1818
from .ray_service import RayTransform
1919
from .rb_transforms import MultiStepTransform, NextStateReconstructor
20+
from .rnd import RNDTransform, RunningMeanStd
2021
from .transforms import (
2122
ActionChunkTransform,
2223
ActionDiscretizer,
@@ -188,10 +189,12 @@ def __getattr__(name: str):
188189
"RemoveEmptySpecs",
189190
"RenameTransform",
190191
"Resize",
192+
"RNDTransform",
191193
"Reward2GoTransform",
192194
"RewardClipping",
193195
"RewardScaling",
194196
"RewardSum",
197+
"RunningMeanStd",
195198
"SelectTransform",
196199
"SignTransform",
197200
"SqueezeTransform",

0 commit comments

Comments
 (0)