Skip to content

Commit df00d61

Browse files
authored
[BugFix] Fix SACLoss target_entropy="auto" ignoring action space dimensionality (#3292)
1 parent ab35c36 commit df00d61

2 files changed

Lines changed: 24 additions & 1 deletion

File tree

test/test_objectives.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5245,6 +5245,27 @@ def test_state_dict(self, version):
5245 5245
)
5246 5246
loss.load_state_dict(state)
5247 5247

5248+
@pytest.mark.parametrize("action_dim", [1, 2, 4, 8])
def test_sac_target_entropy_auto(self, version, action_dim):
    """Check that ``SACLoss.target_entropy`` evaluates to ``-action_dim``.

    Regression test for issue #3291. The loss is built without an explicit
    ``target_entropy`` argument for each parametrized ``action_dim``.
    """
    torch.manual_seed(self.seed)

    # Mock networks are created in a fixed order: actor, Q-value, then the
    # optional value network.
    policy = self._create_mock_actor(action_dim=action_dim)
    critic = self._create_mock_qvalue(action_dim=action_dim)
    # Only version 1 is given a separate value network; otherwise pass None.
    state_value = (
        self._create_mock_value(action_dim=action_dim) if version == 1 else None
    )

    loss_fn = SACLoss(
        actor_network=policy,
        qvalue_network=critic,
        value_network=state_value,
    )

    # target_entropy is not passed here; the "auto" setting is expected to
    # resolve to the negated action dimensionality.
    expected = -action_dim
    assert loss_fn.target_entropy.item() == expected, (
        f"target_entropy should be -{action_dim}, got {loss_fn.target_entropy.item()}"
    )
5268+
5248 5269
@pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
5249 5270
@pytest.mark.parametrize("composite_action_dist", [True, False])
5250 5271
def test_sac_reduction(self, reduction, version, composite_action_dist):

torchrl/objectives/sac.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -499,7 +499,9 @@ def target_entropy(self):
499 499
else:
500 500
action_container_shape = action_spec.shape
501 501
target_entropy = -float(
502-
action_spec.shape[len(action_container_shape) :].numel()
502+
action_spec[self.tensor_keys.action]
503+
.shape[len(action_container_shape) :]
504+
.numel()
503 505
)
504 506
delattr(self, "_target_entropy")
505 507
self.register_buffer(

0 commit comments

Comments
 (0)