Skip to content

Commit

Permalink
fix gaussian policy double squashing (#155)
Browse files Browse the repository at this point in the history
* fix gaussian policy double squashing

* fix gaussian_test
  • Loading branch information
cpnota authored Jul 4, 2020
1 parent 6d1111a commit 2010aca
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion all/policies/gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __init__(self, model, space):
def forward(self, state):
outputs = super().forward(state)
action_dim = outputs.shape[1] // 2
- means = self._squash(torch.tanh(outputs[:, 0:action_dim]))
+ means = self._squash(outputs[:, 0:action_dim])

if not self.training:
return means
Expand Down
4 changes: 2 additions & 2 deletions all/policies/gaussian_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,10 @@ def test_converge(self):
def test_eval(self):
state = State(torch.randn(1, STATE_DIM))
dist = self.policy.no_grad(state)
- tt.assert_almost_equal(dist.mean, torch.tensor([[-0.229, 0.43, -0.058]]), decimal=3)
+ tt.assert_almost_equal(dist.mean, torch.tensor([[-0.233, 0.459, -0.058]]), decimal=3)
tt.assert_almost_equal(dist.entropy(), torch.tensor([4.251]), decimal=3)
best = self.policy.eval(state)
- tt.assert_almost_equal(best, torch.tensor([[-0.229, 0.43, -0.058]]), decimal=3)
+ tt.assert_almost_equal(best, torch.tensor([[-0.233, 0.459, -0.058]]), decimal=3)


if __name__ == '__main__':
Expand Down

0 comments on commit 2010aca

Please sign in to comment.