
Commit 5be19f3

Adds Recurrent Student-Teacher Distillation

Author: Clemens Schwarke
Approved-by: Mayank Mittal
Parent: 73de8a3

10 files changed, +227 -62 lines

Diff for: config/dummy_config.yaml

+1 -1

@@ -71,7 +71,7 @@ policy:

   # only needed for `ActorCriticRecurrent`
   # rnn_type: 'lstm'
-  # rnn_hidden_size: 512
+  # rnn_hidden_dim: 512
   # rnn_num_layers: 1

 runner:
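
The only functional change here is the rename of the commented-out rnn_hidden_size key to rnn_hidden_dim, matching the renamed constructor argument below. Existing experiment configs that still use the old key would need the same rename; a minimal migration sketch (the helper name and the flat dict layout are illustrative assumptions, not part of this commit):

def migrate_policy_config(policy_cfg: dict) -> dict:
    # rename the pre-commit key to the one the recurrent modules now expect
    if "rnn_hidden_size" in policy_cfg:
        policy_cfg["rnn_hidden_dim"] = policy_cfg.pop("rnn_hidden_size")
    return policy_cfg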

Diff for: rsl_rl/algorithms/distillation.py

+28 -14

@@ -8,14 +8,14 @@
 import torch.optim as optim

 # rsl-rl
-from rsl_rl.modules import StudentTeacher
+from rsl_rl.modules import StudentTeacher, StudentTeacherRecurrent
 from rsl_rl.storage import RolloutStorage


 class Distillation:
     """Distillation algorithm for training a student model to mimic a teacher model."""

-    policy: StudentTeacher
+    policy: StudentTeacher | StudentTeacherRecurrent
     """The student teacher model."""

     def __init__(
@@ -24,6 +24,7 @@ def __init__(
         num_learning_epochs=1,
         gradient_length=15,
         learning_rate=1e-3,
+        loss_type="mse",
         device="cpu",
     ):
         self.device = device
@@ -37,11 +38,20 @@ def __init__(
         self.storage = None  # initialized later
         self.optimizer = optim.Adam(self.policy.student.parameters(), lr=self.learning_rate)
         self.transition = RolloutStorage.Transition()
+        self.last_hidden_states = None

         # distillation parameters
         self.num_learning_epochs = num_learning_epochs
         self.gradient_length = gradient_length

+        # initialize the loss function
+        if loss_type == "mse":
+            self.loss_fn = nn.functional.mse_loss
+        elif loss_type == "huber":
+            self.loss_fn = nn.functional.huber_loss
+        else:
+            raise ValueError(f"Unknown loss type: {loss_type}. Supported types are: mse, huber")
+
         self.num_updates = 0

     def init_storage(
@@ -79,25 +89,24 @@ def process_env_step(self, rewards, dones, infos):

     def update(self):
         self.num_updates += 1
-        mean_behaviour_loss = 0
+        mean_behavior_loss = 0
         loss = 0
         cnt = 0

-        for epoch in range(self.num_learning_epochs):  # TODO unify num_steps_per_env and gradient_length
-            self.policy.reset()
+        for epoch in range(self.num_learning_epochs):
+            self.policy.reset(hidden_states=self.last_hidden_states)
             self.policy.detach_hidden_states()
-            for obs, _, _, privileged_actions in self.storage.generator():
+            for obs, _, _, privileged_actions, dones in self.storage.generator():

                 # inference the student for gradient computation
                 actions = self.policy.act_inference(obs)

-                # behaviour cloning loss
-                behaviour_loss = nn.functional.mse_loss(actions, privileged_actions)
+                # behavior cloning loss
+                behavior_loss = self.loss_fn(actions, privileged_actions)

                 # total loss
-                loss = loss + behaviour_loss
-
-                mean_behaviour_loss += behaviour_loss.item()
+                loss = loss + behavior_loss
+                mean_behavior_loss += behavior_loss.item()
                 cnt += 1

                 # gradient step
@@ -108,11 +117,16 @@ def update(self):
                     self.policy.detach_hidden_states()
                     loss = 0

-        mean_behaviour_loss /= cnt
+                # reset dones
+                self.policy.reset(dones.view(-1))
+                self.policy.detach_hidden_states(dones.view(-1))
+
+        mean_behavior_loss /= cnt
         self.storage.clear()
-        self.policy.reset()  # TODO needed?
+        self.last_hidden_states = self.policy.get_hidden_states()
+        self.policy.detach_hidden_states()

         # construct the loss dictionary
-        loss_dict = {"behaviour": mean_behaviour_loss}
+        loss_dict = {"behavior": mean_behavior_loss}

         return loss_dict
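
The update loop now carries the student's recurrent state across learning iterations: last_hidden_states is saved after each update and restored via reset(hidden_states=...) at the start of the next one, hidden states are reset and detached per environment whenever dones come out of the storage generator, and the behavior-cloning loss is selectable at construction time. A minimal sketch of wiring up the new loss_type argument (the observation/action sizes are placeholders, and passing the policy as the first constructor argument is an assumption based on the attribute accesses above):

from rsl_rl.algorithms import Distillation
from rsl_rl.modules import StudentTeacherRecurrent

policy = StudentTeacherRecurrent(
    num_student_obs=48,   # placeholder sizes, not from this commit
    num_teacher_obs=96,
    num_actions=12,
    rnn_type="lstm",
    rnn_hidden_dim=256,
)
algorithm = Distillation(policy, gradient_length=15, loss_type="huber", device="cpu")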

Diff for: rsl_rl/modules/__init__.py

+2 -0

@@ -10,11 +10,13 @@
 from .normalizer import EmpiricalNormalization
 from .rnd import RandomNetworkDistillation
 from .student_teacher import StudentTeacher
+from .student_teacher_recurrent import StudentTeacherRecurrent

 __all__ = [
     "ActorCritic",
     "ActorCriticRecurrent",
     "EmpiricalNormalization",
     "RandomNetworkDistillation",
     "StudentTeacher",
+    "StudentTeacherRecurrent",
 ]

Diff for: rsl_rl/modules/actor_critic_recurrent.py

+8 -39

@@ -5,11 +5,9 @@

 from __future__ import annotations

-import torch
-import torch.nn as nn
-
-from rsl_rl.modules.actor_critic import ActorCritic
-from rsl_rl.utils import resolve_nn_activation, unpad_trajectories
+from rsl_rl.modules import ActorCritic
+from rsl_rl.networks import Memory
+from rsl_rl.utils import resolve_nn_activation


 class ActorCriticRecurrent(ActorCritic):
@@ -24,7 +22,7 @@ def __init__(
         critic_hidden_dims=[256, 256, 256],
         activation="elu",
         rnn_type="lstm",
-        rnn_hidden_size=256,
+        rnn_hidden_dim=256,
         rnn_num_layers=1,
         init_noise_std=1.0,
         **kwargs,
@@ -35,8 +33,8 @@ def __init__(
         )

         super().__init__(
-            num_actor_obs=rnn_hidden_size,
-            num_critic_obs=rnn_hidden_size,
+            num_actor_obs=rnn_hidden_dim,
+            num_critic_obs=rnn_hidden_dim,
             num_actions=num_actions,
             actor_hidden_dims=actor_hidden_dims,
             critic_hidden_dims=critic_hidden_dims,
@@ -46,8 +44,8 @@ def __init__(

         activation = resolve_nn_activation(activation)

-        self.memory_a = Memory(num_actor_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_size)
-        self.memory_c = Memory(num_critic_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_size)
+        self.memory_a = Memory(num_actor_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_dim)
+        self.memory_c = Memory(num_critic_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_dim)

         print(f"Actor RNN: {self.memory_a}")
         print(f"Critic RNN: {self.memory_c}")
@@ -70,32 +68,3 @@ def evaluate(self, critic_observations, masks=None, hidden_states=None):

     def get_hidden_states(self):
         return self.memory_a.hidden_states, self.memory_c.hidden_states
-
-
-class Memory(torch.nn.Module):
-    def __init__(self, input_size, type="lstm", num_layers=1, hidden_size=256):
-        super().__init__()
-        # RNN
-        rnn_cls = nn.GRU if type.lower() == "gru" else nn.LSTM
-        self.rnn = rnn_cls(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers)
-        self.hidden_states = None
-
-    def forward(self, input, masks=None, hidden_states=None):
-        batch_mode = masks is not None
-        if batch_mode:
-            # batch mode (policy update): need saved hidden states
-            if hidden_states is None:
-                raise ValueError("Hidden states not passed to memory module during policy update")
-            out, _ = self.rnn(input, hidden_states)
-            out = unpad_trajectories(out, masks)
-        else:
-            # inference mode (collection): use hidden states of last step
-            out, self.hidden_states = self.rnn(input.unsqueeze(0), self.hidden_states)
-        return out
-
-    def reset(self, dones=None):
-        # When the RNN is an LSTM, self.hidden_states_a is a list with hidden_state and cell_state
-        if self.hidden_states is None:
-            return
-        for hidden_state in self.hidden_states:
-            hidden_state[..., dones == 1, :] = 0.0
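
Nothing is lost here: the Memory class removed at the bottom of this file reappears, extended, as rsl_rl/networks/memory.py further down, and ActorCriticRecurrent now imports it from the new rsl_rl.networks package. The only interface change is the rename of rnn_hidden_size to rnn_hidden_dim. Code that imported Memory from its old location would switch imports roughly like this (a sketch; the placeholder sizes are assumptions):

# before this commit
# from rsl_rl.modules.actor_critic_recurrent import Memory
# after this commit
from rsl_rl.networks import Memory

memory = Memory(input_size=48, type="gru", num_layers=1, hidden_size=256)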

Diff for: rsl_rl/modules/student_teacher.py

+8 -2

@@ -72,7 +72,7 @@ def __init__(
         # disable args validation for speedup
         Normal.set_default_validate_args = False

-    def reset(self, dones=None):
+    def reset(self, dones=None, hidden_states=None):
         pass

     def forward(self):
@@ -128,6 +128,9 @@ def load_state_dict(self, state_dict, strict=True):
                 if "actor." in key:
                     teacher_state_dict[key.replace("actor.", "")] = value
             self.teacher.load_state_dict(teacher_state_dict, strict=strict)
+            # also load recurrent memory if teacher is recurrent
+            if self.is_recurrent and self.teacher_recurrent:
+                raise NotImplementedError("Loading recurrent memory for the teacher is not implemented yet")  # TODO
             # set flag for successfully loading the parameters
             self.loaded_teacher = True
             self.teacher.eval()
@@ -141,5 +144,8 @@ def load_state_dict(self, state_dict, strict=True):
         else:
             raise ValueError("state_dict does not contain student or teacher parameters")

-    def detach_hidden_states(self):
+    def get_hidden_states(self):
+        return None
+
+    def detach_hidden_states(self, dones=None):
         pass
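
The widened signatures give the feed-forward StudentTeacher the same hidden-state interface as the new recurrent variant, so Distillation.update() can drive either policy without type checks. A condensed sketch of the duck-typed call pattern (paraphrased from the update loop above, not verbatim):

# works for both StudentTeacher and StudentTeacherRecurrent
last_hidden_states = policy.get_hidden_states()  # None for the feed-forward student
policy.reset(hidden_states=last_hidden_states)   # no-op for the feed-forward student
policy.detach_hidden_states(dones)               # no-op for the feed-forward student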

Diff for: rsl_rl/modules/student_teacher_recurrent.py

+90 -0

@@ -0,0 +1,90 @@
+# Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+from __future__ import annotations
+
+from rsl_rl.modules import StudentTeacher
+from rsl_rl.networks import Memory
+from rsl_rl.utils import resolve_nn_activation
+
+
+class StudentTeacherRecurrent(StudentTeacher):
+    is_recurrent = True
+
+    def __init__(
+        self,
+        num_student_obs,
+        num_teacher_obs,
+        num_actions,
+        student_hidden_dims=[256, 256, 256],
+        teacher_hidden_dims=[256, 256, 256],
+        activation="elu",
+        rnn_type="lstm",
+        rnn_hidden_dim=256,
+        rnn_num_layers=1,
+        init_noise_std=0.1,
+        teacher_recurrent=False,
+        **kwargs,
+    ):
+        if kwargs:
+            print(
+                "StudentTeacherRecurrent.__init__ got unexpected arguments, which will be ignored: "
+                + str(kwargs.keys()),
+            )
+
+        self.teacher_recurrent = teacher_recurrent
+
+        super().__init__(
+            num_student_obs=rnn_hidden_dim,
+            num_teacher_obs=rnn_hidden_dim if teacher_recurrent else num_teacher_obs,
+            num_actions=num_actions,
+            student_hidden_dims=student_hidden_dims,
+            teacher_hidden_dims=teacher_hidden_dims,
+            activation=activation,
+            init_noise_std=init_noise_std,
+        )
+
+        activation = resolve_nn_activation(activation)
+
+        self.memory_s = Memory(num_student_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_dim)
+        if self.teacher_recurrent:
+            self.memory_t = Memory(
+                num_teacher_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_dim
+            )
+
+        print(f"Student RNN: {self.memory_s}")
+        if self.teacher_recurrent:
+            print(f"Teacher RNN: {self.memory_t}")
+
+    def reset(self, dones=None, hidden_states=None):
+        if hidden_states is None:
+            hidden_states = (None, None)
+        self.memory_s.reset(dones, hidden_states[0])
+        if self.teacher_recurrent:
+            self.memory_t.reset(dones, hidden_states[1])
+
+    def act(self, observations):
+        input_s = self.memory_s(observations)
+        return super().act(input_s.squeeze(0))
+
+    def act_inference(self, observations):
+        input_s = self.memory_s(observations)
+        return super().act_inference(input_s.squeeze(0))
+
+    def evaluate(self, teacher_observations):
+        if self.teacher_recurrent:
+            teacher_observations = self.memory_t(teacher_observations)
+        return super().evaluate(teacher_observations.squeeze(0))
+
+    def get_hidden_states(self):
+        if self.teacher_recurrent:
+            return self.memory_s.hidden_states, self.memory_t.hidden_states
+        else:
+            return self.memory_s.hidden_states, None
+
+    def detach_hidden_states(self, dones=None):
+        self.memory_s.detach_hidden_states(dones)
+        if self.teacher_recurrent:
+            self.memory_t.detach_hidden_states(dones)
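
When teacher_recurrent=True, the teacher gets its own Memory and its MLP input size becomes rnn_hidden_dim, mirroring the student path; otherwise the teacher consumes raw observations. The hidden-state round trip that Distillation relies on looks roughly like this (a sketch; tensor shapes and sizes are illustrative assumptions):

import torch
from rsl_rl.modules import StudentTeacherRecurrent

policy = StudentTeacherRecurrent(num_student_obs=48, num_teacher_obs=96, num_actions=12)

obs = torch.zeros(4, 48)            # 4 environments, placeholder observation size
actions = policy.act(obs)           # advances memory_s.hidden_states

saved = policy.get_hidden_states()  # (student_states, None) without a recurrent teacher
policy.reset(hidden_states=saved)   # restore the saved states at the next update
policy.detach_hidden_states()       # cut the graph before the next backward pass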

Diff for: rsl_rl/networks/__init__.py

+10 -0

@@ -0,0 +1,10 @@
+# Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+"""Definitions for neural networks."""
+
+from .memory import Memory
+
+__all__ = ["Memory"]

Diff for: rsl_rl/networks/memory.py

+65 -0

@@ -0,0 +1,65 @@
+# Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+
+from rsl_rl.utils import unpad_trajectories
+
+
+class Memory(torch.nn.Module):
+    def __init__(self, input_size, type="lstm", num_layers=1, hidden_size=256):
+        super().__init__()
+        # RNN
+        rnn_cls = nn.GRU if type.lower() == "gru" else nn.LSTM
+        self.rnn = rnn_cls(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers)
+        self.hidden_states = None
+
+    def forward(self, input, masks=None, hidden_states=None):
+        batch_mode = masks is not None
+        if batch_mode:
+            # batch mode: needs saved hidden states
+            if hidden_states is None:
+                raise ValueError("Hidden states not passed to memory module during policy update")
+            out, _ = self.rnn(input, hidden_states)
+            out = unpad_trajectories(out, masks)
+        else:
+            # inference/distillation mode: uses hidden states of last step
+            out, self.hidden_states = self.rnn(input.unsqueeze(0), self.hidden_states)
+        return out
+
+    def reset(self, dones=None, hidden_states=None):
+        if dones is None:  # reset all hidden states
+            if hidden_states is None:
+                self.hidden_states = None
+            else:
+                self.hidden_states = hidden_states
+        elif self.hidden_states is not None:  # reset hidden states of done environments
+            if hidden_states is None:
+                if isinstance(self.hidden_states, tuple):  # tuple in case of LSTM
+                    for hidden_state in self.hidden_states:
+                        hidden_state[..., dones == 1, :] = 0.0
+                else:
+                    self.hidden_states[..., dones == 1, :] = 0.0
+            else:
+                NotImplementedError(
+                    "Resetting hidden states of done environments with custom hidden states is not implemented"
+                )
+
+    def detach_hidden_states(self, dones=None):
+        if self.hidden_states is not None:
+            if dones is None:  # detach all hidden states
+                if isinstance(self.hidden_states, tuple):  # tuple in case of LSTM
+                    self.hidden_states = tuple(hidden_state.detach() for hidden_state in self.hidden_states)
+                else:
+                    self.hidden_states = self.hidden_states.detach()
+            else:  # detach hidden states of done environments
+                if isinstance(self.hidden_states, tuple):  # tuple in case of LSTM
+                    for hidden_state in self.hidden_states:
+                        hidden_state[..., dones == 1, :] = hidden_state[..., dones == 1, :].detach()
+                else:
+                    self.hidden_states[..., dones == 1, :] = self.hidden_states[..., dones == 1, :].detach()
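
Memory accumulates hidden states across forward calls in inference mode and supports per-environment reset and detach through a dones mask, which is what lets the distillation update truncate backpropagation at episode boundaries. Note that in reset(), the NotImplementedError in the final branch is constructed but not raised as written, so that path currently does nothing. A small sketch of the per-environment semantics (shapes and sizes are illustrative assumptions):

import torch
from rsl_rl.networks import Memory

memory = Memory(input_size=8, type="lstm", num_layers=1, hidden_size=16)

obs = torch.zeros(4, 8)             # 4 environments, placeholder observation size
out = memory(obs)                   # inference mode: hidden states are kept on the module

dones = torch.tensor([0, 1, 0, 0])
memory.reset(dones)                 # zero the hidden and cell state of environment 1 only
memory.detach_hidden_states(dones)  # cut the graph for environment 1, keep it elsewhere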
