diff --git a/pfrl/agents/acer.py b/pfrl/agents/acer.py
index 47ef07d5d..1f80997e3 100644
--- a/pfrl/agents/acer.py
+++ b/pfrl/agents/acer.py
@@ -170,12 +170,19 @@ def evaluator(action):
 
 
 def get_params_of_distribution(distrib):
+    """Returns learnable parameters of a given distribution."""
     if isinstance(distrib, torch.distributions.Independent):
         return get_params_of_distribution(distrib.base_dist)
     elif isinstance(distrib, torch.distributions.Categorical):
+        assert distrib._param.requires_grad
         return (distrib._param,)
     elif isinstance(distrib, torch.distributions.Normal):
-        return distrib.loc, distrib.scale
+        # Either loc or scale must be learnable
+        params = tuple(
+            param for param in [distrib.loc, distrib.scale] if param.requires_grad
+        )
+        assert len(params) > 0
+        return params
     else:
         raise NotImplementedError("{} is not supported by ACER".format(type(distrib)))
 
diff --git a/tests/agents_tests/test_acer.py b/tests/agents_tests/test_acer.py
index 25fbebc2d..721290042 100644
--- a/tests/agents_tests/test_acer.py
+++ b/tests/agents_tests/test_acer.py
@@ -15,7 +15,11 @@
 from pfrl.experiments.evaluator import run_evaluation_episodes
 from pfrl.experiments.train_agent_async import train_agent_async
 from pfrl.nn import ConcatObsAndAction
-from pfrl.policies import GaussianHeadWithDiagonalCovariance, SoftmaxCategoricalHead
+from pfrl.policies import (
+    GaussianHeadWithDiagonalCovariance,
+    GaussianHeadWithFixedCovariance,
+    SoftmaxCategoricalHead,
+)
 from pfrl.q_functions import DiscreteActionValueHead
 from pfrl.replay_buffers import EpisodicReplayBuffer
 
@@ -263,6 +267,15 @@ def test_compute_loss_with_kl_constraint_gaussian():
     _test_compute_loss_with_kl_constraint(policy)
 
 
+def test_compute_loss_with_kl_constraint_gaussian_with_fixed_covariance():
+    action_size = 3
+    policy = nn.Sequential(
+        nn.Linear(1, action_size),
+        GaussianHeadWithFixedCovariance(),
+    )
+    _test_compute_loss_with_kl_constraint(policy)
+
+
 def test_compute_loss_with_kl_constraint_softmax():
     n_actions = 3
     policy = nn.Sequential(
@@ -282,11 +295,13 @@ def _test_compute_loss_with_kl_constraint(base_policy):
     with torch.no_grad():
         # Compute KL divergence against the original distribution
         base_distrib = base_policy(x)
+        some_action = base_distrib.sample()
 
     def base_loss_func(distrib):
-        # Any loss that tends to increase KL divergence should be ok
-        kl = torch.distributions.kl_divergence(base_distrib, distrib)
-        return -(kl + distrib.entropy())
+        # Any loss that tends to increase KL divergence should be ok.
+        # Here I choose to minimize the log probability of some fixed action.
+        # The loss is clipped to avoid NaN.
+        return torch.max(distrib.log_prob(some_action), torch.as_tensor(-20.0))
 
     def compute_kl_after_update(loss_func, n=100):
         policy = copy.deepcopy(base_policy)
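
For reference, a minimal sketch (not part of the diff) of what the updated get_params_of_distribution is expected to return for a fixed-covariance Gaussian policy head. It assumes get_params_of_distribution is importable from pfrl.agents.acer and that GaussianHeadWithFixedCovariance() uses a constant scale, matching the new test above.

import torch
from torch import nn

from pfrl.agents.acer import get_params_of_distribution
from pfrl.policies import GaussianHeadWithFixedCovariance

# Same policy shape as the new test: a fixed-covariance Gaussian head whose
# scale is a constant, so only the mean (loc) depends on network parameters.
policy = nn.Sequential(nn.Linear(1, 3), GaussianHeadWithFixedCovariance())
distrib = policy(torch.rand(1, 1))

# Before this change, (loc, scale) was returned unconditionally even when the
# scale is a constant; now only tensors with requires_grad=True come back,
# so the KL-constraint gradient computation only touches learnable parameters.
params = get_params_of_distribution(distrib)
assert all(p.requires_grad for p in params)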