Skip to content

Commit 2010aca

Browse files
authored
fix gaussian policy double squashing (#155)
* fix gaussian policy double squashing
* fix gaussian_test
1 parent 6d1111a commit 2010aca

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

all/policies/gaussian.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def __init__(self, model, space):
     def forward(self, state):
         outputs = super().forward(state)
         action_dim = outputs.shape[1] // 2
-        means = self._squash(torch.tanh(outputs[:, 0:action_dim]))
+        means = self._squash(outputs[:, 0:action_dim])
 
         if not self.training:
             return means

all/policies/gaussian_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,10 @@ def test_converge(self):
     def test_eval(self):
         state = State(torch.randn(1, STATE_DIM))
         dist = self.policy.no_grad(state)
-        tt.assert_almost_equal(dist.mean, torch.tensor([[-0.229, 0.43, -0.058]]), decimal=3)
+        tt.assert_almost_equal(dist.mean, torch.tensor([[-0.233, 0.459, -0.058]]), decimal=3)
         tt.assert_almost_equal(dist.entropy(), torch.tensor([4.251]), decimal=3)
         best = self.policy.eval(state)
-        tt.assert_almost_equal(best, torch.tensor([[-0.229, 0.43, -0.058]]), decimal=3)
+        tt.assert_almost_equal(best, torch.tensor([[-0.233, 0.459, -0.058]]), decimal=3)
 
 
 if __name__ == '__main__':

0 commit comments

Comments
 (0)