-
Notifications
You must be signed in to change notification settings - Fork 328
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Feature] AbsorbingStateTransform #2290
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8557,3 +8557,157 @@ def _inv_call(self, tensordict): | |
if self.sampling == self.SamplingStrategy.RANDOM: | ||
action = action + self.jitters * torch.rand_like(self.jitters) | ||
return tensordict.set(self.in_keys_inv[0], action) | ||
|
||
|
||
class AbsorbingStateTransform(ObservationTransform):
    """Adds an absorbing state to the observation space.

    Every transformed observation gets one extra element appended along its
    last dimension, acting as an "absorbing" indicator:

    - a regular observation ``[o_1, ..., o_n]`` becomes ``[o_1, ..., o_n, 0]``;
    - the absorbing state is ``[0, ..., 0, 1]`` (all features zeroed, flag set).

    When the wrapped environment first reports ``done``, the observation written
    by this transform is replaced by the absorbing state and the ``done`` /
    ``terminated`` entries are overwritten with ``False``. On every subsequent
    call (until the next reset) the observation stays the absorbing state and
    ``done`` / ``terminated`` are overwritten with ``True``. The episode is thus
    extended by an explicit transition into, and out of, the absorbing state.

    .. note:: The episode-end state is shared by the whole batch: as soon as
        *any* entry of the ``done`` tensor is ``True``, every entry is treated
        as done. Batched environments whose sub-environments terminate at
        different times are therefore not handled individually.

    Args:
        max_episode_length (int): Maximum length of an episode. The value is
            stored on the instance; the transform logic does not consult it.
        in_keys (Sequence[NestedKey], optional): Keys to use for input observation. Defaults to ``["observation"]``.
        out_keys (Sequence[NestedKey], optional): Keys to use for output observation. Defaults to ``in_keys``.
        done_key (Optional[NestedKey]): Key indicating if the episode is done. Defaults to ``"done"``.
        terminate_key (Optional[NestedKey]): Key indicating if the episode is terminated. Defaults to ``"terminated"``.

    Examples:
        >>> from torchrl.envs import GymEnv
        >>> t = AbsorbingStateTransform(max_episode_length=1000)
        >>> base_env = GymEnv("HalfCheetah-v4")
        >>> env = TransformedEnv(base_env, t)
        >>> # The last dimension of the observation is one element longer than
        >>> # in ``base_env``; that trailing element is the absorbing flag.
        >>> td = env.reset()

    """

    def __init__(
        self,
        max_episode_length: int,
        in_keys: Sequence[NestedKey] | None = None,
        out_keys: Sequence[NestedKey] | None = None,
        done_key: Optional[NestedKey] = "done",
        terminate_key: Optional[NestedKey] = "terminated",
    ):
        if in_keys is None:
            # A list of keys, not a bare string.
            in_keys = ["observation"]
        if out_keys is None:
            out_keys = copy(in_keys)
        super().__init__(in_keys=in_keys, out_keys=out_keys)
        self.max_episode_length = max_episode_length
        self.done_key = done_key
        self.terminate_key = terminate_key
        # Plain Python bool so that ``is_done`` honours its ``-> bool`` contract.
        self._done = False
        self._curr_timestep = 0

    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        """Not supported: this transform only works inside an environment."""
        raise RuntimeError(FORWARD_NOT_IMPLEMENTED.format(type(self)))

    @staticmethod
    def _append_flag(observation: torch.Tensor, absorbing: bool) -> torch.Tensor:
        """Append the absorbing flag along the last dimension of ``observation``.

        Works for any number of leading (batch) dimensions and keeps the dtype
        and device of ``observation``.

        Args:
            observation: tensor with at least one dimension.
            absorbing: if ``True``, return the absorbing state (zeros followed
                by a flag of 1); otherwise return ``observation`` followed by a
                flag of 0.

        Raises:
            ValueError: if ``observation`` has no dimension to extend.
        """
        if observation.dim() == 0:
            raise ValueError(
                "Unsupported observation dimension: {}".format(observation.dim())
            )
        flag_shape = (*observation.shape[:-1], 1)
        if absorbing:
            # Built directly instead of indexing a (n+1)x(n+1) identity matrix.
            return torch.cat(
                (torch.zeros_like(observation), observation.new_ones(flag_shape)),
                dim=-1,
            )
        return torch.cat((observation, observation.new_zeros(flag_shape)), dim=-1)

    def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor:
        """Return ``observation`` with the flag appended, or the absorbing state if done."""
        return self._append_flag(observation, absorbing=bool(self._done))

    def _reset(
        self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
    ) -> TensorDictBase:
        """Clear the episode-end state and transform the reset observation."""
        self._curr_timestep = 0
        self._done = False
        with _set_missing_tolerance(self, True):
            return self._call(tensordict_reset)

    def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
        """Transform the observations and rewrite the done/terminated entries.

        Raises:
            RuntimeError: if the transform is not attached to an environment.
            KeyError: if an input key is missing and missing keys are not tolerated.
        """
        parent = self.parent
        if parent is None:
            raise RuntimeError(
                f"{type(self)}.parent cannot be None: make sure this transform is executed within an environment."
            )

        # Whether the absorbing state had already been entered before this call.
        # This decides the done/terminated values written below.
        was_done = bool(self._done)
        if not was_done:
            # Shared across the batch: any done entry ends the episode for all.
            self._done = bool(tensordict.get(self.done_key).any())

        for in_key, out_key in zip(self.in_keys, self.out_keys):
            value = tensordict.get(in_key, default=None)
            if value is not None:
                tensordict.set(out_key, self._apply_transform(value))
            elif not self.missing_tolerance:
                raise KeyError(
                    f"{self}: '{in_key}' not found in tensordict {tensordict}"
                )

        # First done step: flags are forced to False (one more step follows).
        # Steps taken from the absorbing state: flags are forced to True.
        fill = torch.ones_like if was_done else torch.zeros_like
        tensordict.set(
            self.done_key, fill(tensordict.get(self.done_key), dtype=torch.bool)
        )
        tensordict.set(
            self.terminate_key,
            fill(tensordict.get(self.terminate_key), dtype=torch.bool),
        )
        return tensordict

    @property
    def is_done(self) -> bool:
        """Whether the absorbing state has been entered since the last reset."""
        return self._done

    @_apply_to_composite
    def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
        """Extend the last dimension of the observation spec by one (the flag).

        The spec does not depend on the current episode state. For continuous
        boxes the flag is bounded by ``[0, 1]`` so that both regular
        observations (flag 0) and the absorbing state (flag 1) are within bounds.
        """
        space = observation_spec.space

        if isinstance(space, ContinuousBox):
            low, high = space.low, space.high
            if high.dim() == 0:
                raise ValueError(
                    "Unsupported observation dimension: {}".format(high.dim())
                )
            space.low = self._append_flag(low, absorbing=False)
            space.high = torch.cat(
                (high, high.new_ones((*high.shape[:-1], 1))), dim=-1
            )
            observation_spec.shape = space.low.shape
        else:
            shape = observation_spec.shape
            if len(shape) == 0:
                raise ValueError(
                    "Unsupported observation dimension: {}".format(len(shape))
                )
            # Pure shape arithmetic: no need to allocate a tensor for this.
            observation_spec.shape = torch.Size((*shape[:-1], shape[-1] + 1))
        return observation_spec

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            f"max_episode_length={self.max_episode_length}, "
            f"keys={self.in_keys})"
        )
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We need tests for this class.
It should also be registered in `__init__.py` and added to the documentation.