Describe the bug
The following code section is logically flawed. The bug was introduced here:
    elif interaction_type is InteractionType.MEAN:
        if hasattr(dist, "mean"):
            try:
                return dist.mean
            except NotImplementedError:
                pass
        if dist.has_rsample:
            return dist.rsample((self.n_empirical_estimate,)).mean(0)
        else:
            return dist.sample((self.n_empirical_estimate,)).mean(0)
The hasattr check itself accesses dist.mean to test whether the attribute exists. In Python 3, hasattr only catches AttributeError, so if dist.mean raises NotImplementedError, hasattr(dist, "mean") propagates that NotImplementedError instead of returning False. The subsequent try block is therefore in the wrong place: the exception is raised by hasattr before execution ever reaches it.
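A minimal, torch-free sketch of the hasattr behaviour described above. RaisesOnMean, NoMean and probe are hypothetical stand-ins written for illustration only; they are not tensordict or torchrl classes.

```python
class RaisesOnMean:
    """Hypothetical stand-in for a distribution whose `mean` is not implemented."""

    @property
    def mean(self):
        raise NotImplementedError("no closed-form mean")


class NoMean:
    """Hypothetical stand-in for a distribution with no `mean` attribute at all."""


def probe(obj):
    """Return hasattr(obj, "mean"), or the exception that hasattr lets through."""
    try:
        return hasattr(obj, "mean")
    except NotImplementedError as err:
        return err


print(probe(NoMean()))  # False: hasattr swallows AttributeError
print(type(probe(RaisesOnMean())).__name__)  # NotImplementedError: hasattr re-raises it
```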
To Reproduce
    from tensordict.nn import ProbabilisticTensorDictModule, InteractionType
    from torchrl.modules import TanhNormal

    prob_module = ProbabilisticTensorDictModule(
        in_keys=["loc", "scale"],
        out_keys=["action"],
        distribution_class=TanhNormal,
        distribution_kwargs={
            "low": -1,
            "high": 1,
        },
        return_log_prob=True,
    )
    # Raises NotImplementedError from hasattr(dist, "mean") instead of
    # falling back to an empirical estimate of the mean
    prob_module._dist_sample(
        dist=TanhNormal(low=-1, high=1, loc=0, scale=1),
        interaction_type=InteractionType.MEAN,
    )
Expected behavior
The code should catch the NotImplementedError and fall back to an empirical estimate of the mean, computed from n_empirical_estimate samples.
Suggestion
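    elif interaction_type is InteractionType.MEAN:
        try:
            return dist.mean
        # Multiple exception types must be parenthesized:
        # "except AttributeError, NotImplementedError:" is a SyntaxError in Python 3.10
        except (AttributeError, NotImplementedError):
            # No usable closed-form mean: fall back to an empirical estimate
            if dist.has_rsample:
                return dist.rsample((self.n_empirical_estimate,)).mean(0)
            else:
                return dist.sample((self.n_empirical_estimate,)).mean(0)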
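To check the control flow of the suggested fix in isolation, here is a torch-free sketch. mean_or_estimate and the Uniform* classes are hypothetical stand-ins (not the tensordict or torchrl API). Their sample(n) returns a plain list of floats rather than a tensor of shape (n, ...), so statistics.fmean replaces .mean(0).

```python
import random
import statistics


def mean_or_estimate(dist, n_empirical_estimate=1000):
    """Hypothetical helper mirroring the suggested try/except fallback."""
    try:
        # Closed-form mean if the distribution provides one
        return dist.mean
    except (AttributeError, NotImplementedError):
        # Missing or unimplemented mean: average samples instead
        if getattr(dist, "has_rsample", False):
            samples = dist.rsample(n_empirical_estimate)
        else:
            samples = dist.sample(n_empirical_estimate)
        return statistics.fmean(samples)


class UniformWithMean:
    """Hypothetical Uniform(0, 1) stand-in that exposes a closed-form mean."""

    has_rsample = False

    @property
    def mean(self):
        return 0.5

    def sample(self, n):
        return [random.random() for _ in range(n)]


class UniformMeanNotImplemented(UniformWithMean):
    """Hypothetical stand-in whose `mean` raises, like TanhNormal in this report."""

    @property
    def mean(self):
        raise NotImplementedError("no closed-form mean")


class UniformWithoutMean:
    """Hypothetical stand-in with no `mean` attribute at all."""

    has_rsample = False

    def sample(self, n):
        return [random.random() for _ in range(n)]


random.seed(0)
print(mean_or_estimate(UniformWithMean()))  # 0.5, closed form
print(mean_or_estimate(UniformMeanNotImplemented(), 10_000))  # empirical, close to 0.5
print(mean_or_estimate(UniformWithoutMean(), 10_000))  # empirical, close to 0.5
```

One trade-off of catching AttributeError here: an AttributeError raised by a bug inside some distribution's mean implementation would also be silently replaced by the empirical estimate.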
System info
Python 3.10
torch==2.5.1+cu121
torchrl==0.6.0
tensordict==0.6.2
Checklist
- I have checked that there is no similar issue in the repo (required)
- I have read the documentation (required)
- I have provided a minimal working example to reproduce the bug (required)