Skip to content

Commit 683a17d

Browse files
committed
v1: Allow collection of empty episodes (done on reset)
Slightly enhanced docstrings in collector
1 parent 4089675 commit 683a17d

File tree

2 files changed

+14
-4
lines changed

2 files changed

+14
-4
lines changed

tianshou/data/buffer/buffer_base.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -145,8 +145,7 @@ def _get_start_stop_tuples_for_edge_crossing_interval(
145145
if stop >= start:
146146
raise ValueError(
147147
f"Expected stop < start, but got {start=}, {stop=}. "
148-
f"For stop larger than start this method should never be called, "
149-
f"and stop=start should never occur. This can occur either due to an implementation error, "
148+
f"For stop larger-equal than start this method should never be called. This can occur either due to an implementation error, "
150149
f"or due a bad configuration of the buffer that resulted in a single episode being so long that "
151150
f"it completely filled a subbuffer (of size len(buffer)/degree_of_vectorization). "
152151
f"Consider either shortening the episode, increasing the size of the buffer, or decreasing the "
@@ -213,7 +212,7 @@ def get_buffer_indices(self, start: int, stop: int) -> np.ndarray:
213212
f"Start and stop indices must be within the same subbuffer. "
214213
f"Got {start=} in subbuffer edge {start_left_edge} and {stop=} in subbuffer edge {stop_left_edge}.",
215214
)
216-
if stop > start:
215+
if stop >= start:
217216
return np.arange(start, stop, dtype=int)
218217
else:
219218
(start, upper_edge), (

tianshou/data/collector.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@
4242

4343
_TArrLike = TypeVar("_TArrLike", bound="np.ndarray | torch.Tensor | Batch | None")
4444

45+
TScalarArrayShape = TypeVar("TScalarArrayShape")
46+
4547

4648
class CollectActionBatchProtocol(Protocol):
4749
"""A protocol for results of computing actions from a batch of observations within a single collect step.
@@ -777,10 +779,13 @@ def _collect( # noqa: C901
777779
# TODO: can't do it init since AsyncCollector is currently a subclass of Collector
778780
if self.env.is_async:
779781
raise ValueError(
780-
f"Please use {AsyncCollector.__name__} for asynchronous environments. "
782+
f"Please use AsyncCollector for asynchronous environments. "
781783
f"Env class: {self.env.__class__.__name__}.",
782784
)
783785

786+
ready_env_ids_R: np.ndarray[Any, np.dtype[np.signedinteger]]
787+
"""provides a mapping from local indices (indexing within `1, ..., R` where `R` is the number of ready envs)
788+
to global ones (indexing within `1, ..., num_envs`). So the entry i in this array is the global index of the i-th ready env."""
784789
if n_step is not None:
785790
ready_env_ids_R = np.arange(self.env_num)
786791
elif n_episode is not None:
@@ -914,6 +919,8 @@ def _collect( # noqa: C901
914919
# local_idx - see block comment on class level
915920
# Step 7
916921
env_done_local_idx_D = np.where(done_R)[0]
922+
"""Indexes which episodes are done within the ready envs, so it can be used for selecting from `..._R` arrays.
923+
Stands in contrast to the "global" index, which counts within all envs and is unsuitable for selecting from `..._R` arrays."""
917924
episode_lens_D = ep_len_R[env_done_local_idx_D]
918925
episode_returns_D = ep_return_R[env_done_local_idx_D]
919926
episode_start_indices_D = ep_start_idx_R[env_done_local_idx_D]
@@ -932,6 +939,10 @@ def _collect( # noqa: C901
932939
# 0,...,R and this global index is maintained by the ready_env_ids_R array.
933940
# See the class block comment for more details
934941
env_done_global_idx_D = ready_env_ids_R[env_done_local_idx_D]
942+
"""Indexes which episodes are done within all envs, i.e., within the index `1, ..., num_envs`. It can be
943+
used to communicate with the vector env, where env ids are selected from this "global" index.
944+
Is not suited for selecting from the ready envs (`..._R` arrays), use the local counterpart instead.
945+
"""
935946
obs_reset_DO, info_reset_D = self.env.reset(
936947
env_id=env_done_global_idx_D,
937948
**gym_reset_kwargs,

0 commit comments

Comments
 (0)