Skip to content

Commit f87e98a

Browse files
Fix wrong observation dimension in recurrent teacher (#136)
Fixes issue #133
1 parent 2ac99c0 commit f87e98a

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

rsl_rl/modules/student_teacher_recurrent.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,8 @@ def __init__(
8181
# Teacher
8282
if self.teacher_recurrent:
8383
self.memory_t = Memory(num_teacher_obs, rnn_hidden_dim, rnn_num_layers, rnn_type)
84-
self.teacher = MLP(rnn_hidden_dim, num_actions, teacher_hidden_dims, activation)
84+
teacher_input_dim = rnn_hidden_dim if self.teacher_recurrent else num_teacher_obs
85+
self.teacher = MLP(teacher_input_dim, num_actions, teacher_hidden_dims, activation)
8586
if self.teacher_recurrent:
8687
print(f"Teacher RNN: {self.memory_t}")
8788
print(f"Teacher MLP: {self.teacher}")

0 commit comments

Comments
 (0)