[BUG] ProbabilisticTensorDictModule._dist_sample hasattr error #1152

Open
@olliepro

Description

Describe the bug

The following code section is logically flawed. The bug was introduced here:

        elif interaction_type is InteractionType.MEAN:
            if hasattr(dist, "mean"):
                try:
                    return dist.mean
                except NotImplementedError:
                    pass
            if dist.has_rsample:
                return dist.rsample((self.n_empirical_estimate,)).mean(0)
            else:
                return dist.sample((self.n_empirical_estimate,)).mean(0)

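hasattr checks whether the attribute exists by actually evaluating dist.mean. In Python 3, hasattr only suppresses AttributeError; any other exception raised while computing the attribute propagates. So if dist.mean raises NotImplementedError (as TanhNormal.mean does), hasattr(dist, "mean") itself raises NotImplementedError, and the try block that follows is never reached. The try/except is placed in the wrong spot to catch the error.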
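A minimal, library-free sketch of the underlying Python behavior, mirroring the buggy pattern above. NoClosedFormMean and check_with_hasattr are hypothetical names that stand in for a distribution like TanhNormal and for the MEAN branch of _dist_sample; they are not part of tensordict or torchrl.

```python
class NoClosedFormMean:
    """Hypothetical stand-in for a distribution whose mean has no closed form."""

    @property
    def mean(self):
        raise NotImplementedError("no closed-form mean")


def check_with_hasattr(dist):
    # Mirrors the buggy pattern: hasattr evaluates the property, and since
    # hasattr only swallows AttributeError, NotImplementedError escapes here,
    # before the try/except below ever runs.
    if hasattr(dist, "mean"):
        try:
            return dist.mean
        except NotImplementedError:
            pass
    return "fallback"


try:
    result = check_with_hasattr(NoClosedFormMean())
except NotImplementedError as err:
    result = f"NotImplementedError escaped hasattr: {err}"

print(result)  # NotImplementedError escaped hasattr: no closed-form mean
```

The "fallback" branch is never reached, which is the same failure mode reported for TanhNormal.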

To Reproduce

from tensordict.nn import ProbabilisticTensorDictModule, InteractionType
from torchrl.modules import TanhNormal

prob_module = ProbabilisticTensorDictModule(
     in_keys=["loc", "scale"],
     out_keys=["action"],
     distribution_class=TanhNormal,
     distribution_kwargs={
         "low": -1,
         "high": 1,
     },
     return_log_prob=True,
)

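# TanhNormal has no closed-form mean: accessing .mean raises NotImplementedError,
# which escapes from hasattr() instead of triggering the empirical estimate.
prob_module._dist_sample(
    dist=TanhNormal(low=-1, high=1, loc=0, scale=1),
    interaction_type=InteractionType.MEAN
)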

Expected behavior

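The code should catch the NotImplementedError and fall back to an empirical estimate of the mean computed from n_empirical_estimate samples (using rsample when the distribution supports it, sample otherwise).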

Suggestion

        elif interaction_type is InteractionType.MEAN:
            try:
                return dist.mean
            except (AttributeError, NotImplementedError):
                if dist.has_rsample:
                    return dist.rsample((self.n_empirical_estimate,)).mean(0)
                else:
                    return dist.sample((self.n_empirical_estimate,)).mean(0)
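A pure-Python sketch of the suggested control flow, assuming stdlib stand-ins instead of torch distributions: sample(n) takes an int rather than a shape tuple, and statistics.fmean replaces Tensor.mean(0). The classes and mean_or_estimate are hypothetical names used only for illustration.

```python
import random
import statistics


class ClosedFormMean:
    """Hypothetical stand-in with an analytic mean."""

    has_rsample = False
    mean = 0.25


class NoClosedFormMean:
    """Hypothetical stand-in: .mean raises, but sampling works."""

    has_rsample = False

    def __init__(self, seed=0):
        self._rng = random.Random(seed)

    @property
    def mean(self):
        raise NotImplementedError("no closed-form mean")

    def sample(self, n):
        # Uniform(0, 1) draws, so the true mean is 0.5.
        return [self._rng.random() for _ in range(n)]


class NoMeanAttribute:
    """Hypothetical stand-in with no .mean attribute at all."""

    has_rsample = False

    def sample(self, n):
        return [2.0] * n


def mean_or_estimate(dist, n_empirical_estimate=10_000):
    # Access .mean directly inside try: a missing attribute and an
    # unimplemented property both fall through to the empirical estimate.
    try:
        return dist.mean
    except (AttributeError, NotImplementedError):
        if dist.has_rsample:
            draws = dist.rsample(n_empirical_estimate)
        else:
            draws = dist.sample(n_empirical_estimate)
        return statistics.fmean(draws)


print(mean_or_estimate(ClosedFormMean()))  # 0.25
print(mean_or_estimate(NoClosedFormMean()))
print(mean_or_estimate(NoMeanAttribute()))  # 2.0
```

Catching AttributeError alongside NotImplementedError keeps the behavior of the original hasattr check for distributions that do not define mean at all.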

System info

Python 3.10
torch==2.5.1+cu121
torchrl==0.6.0
tensordict==0.6.2

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)

Metadata

Labels

bug (Something isn't working)
