Reported in #143
ACER assumes that all the parameters of a distribution (as returned by `get_params_of_distribution`) require grad, so that the algorithm can compute the gradient w.r.t. those parameters.
Lines 172 to 180 of `pfrl/agents/acer.py` in 44bf2e4:

```python
def get_params_of_distribution(distrib):
    if isinstance(distrib, torch.distributions.Independent):
        return get_params_of_distribution(distrib.base_dist)
    elif isinstance(distrib, torch.distributions.Categorical):
        return (distrib._param,)
    elif isinstance(distrib, torch.distributions.Normal):
        return distrib.loc, distrib.scale
    else:
        raise NotImplementedError("{} is not supported by ACER".format(type(distrib)))
```
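For instance, for a Gaussian head that wraps a `Normal` in `Independent`, the helper recurses into the base distribution and returns its `loc` and `scale` tensors whether or not they require grad. A minimal sketch (assuming `get_params_of_distribution` can be imported from `pfrl.agents.acer`):

```python
import torch
from pfrl.agents.acer import get_params_of_distribution

loc = torch.zeros(1, 2, requires_grad=True)  # e.g. produced by a network
scale = torch.ones(1, 2)                     # plain tensor: requires_grad is False
distrib = torch.distributions.Independent(
    torch.distributions.Normal(loc=loc, scale=scale), 1
)

# Recurses through Independent into the base Normal and returns (loc, scale).
print([p.requires_grad for p in get_params_of_distribution(distrib)])
# [True, False]
```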
Lines 218 to 221 of `pfrl/agents/acer.py` in 44bf2e4:

```python
distrib_params = get_params_of_distribution(distrib)
for param in distrib_params:
    assert param.shape[0] == 1
    assert param.requires_grad
```
However, with `GaussianHeadWithFixedCovariance` (`pfrl/policies/gaussian_policy.py`, line 96 in 44bf2e4):

```python
class GaussianHeadWithFixedCovariance(nn.Module):
```

the `scale` parameter of the `torch.distributions.Normal` it builds does not require grad, so the second assertion above fails with an `AssertionError`.
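A minimal reproduction sketch, assuming `GaussianHeadWithFixedCovariance` takes its fixed scale via a `scale` argument and that `get_params_of_distribution` is importable from `pfrl.agents.acer` (the toy policy and observation below are illustrative, not from the repo):

```python
import torch
from torch import nn

from pfrl.agents.acer import get_params_of_distribution
from pfrl.policies import GaussianHeadWithFixedCovariance

# Hypothetical minimal policy: a linear layer that outputs the mean,
# followed by the fixed-covariance Gaussian head.
policy = nn.Sequential(
    nn.Linear(4, 2),
    GaussianHeadWithFixedCovariance(scale=0.5),
)

obs = torch.zeros(1, 4)  # batch of one observation
distrib = policy(obs)

# The same checks ACER performs (lines 218 to 221 above).
for param in get_params_of_distribution(distrib):
    assert param.shape[0] == 1
    assert param.requires_grad  # AssertionError: the fixed scale has requires_grad=False
```

The `loc` produced by the linear layer passes both checks; the broadcasted fixed `scale` trips the `requires_grad` assertion.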