Skip to content

Commit fd93ab3

Browse files
committed
v1: replace all isinstance checks against BatchProtocol with checks against Batch
Seriously improves performance of Batch constructor
1 parent 0db2e74 commit fd93ab3

File tree

4 files changed

+14
-13
lines changed

4 files changed

+14
-13
lines changed

test/base/test_collector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def forward(
7272
if self.dict_state:
7373
if self.action_shape:
7474
action_shape = self.action_shape
75-
elif isinstance(batch.obs, BatchProtocol):
75+
elif isinstance(batch.obs, Batch):
7676
action_shape = len(batch.obs["index"])
7777
else:
7878
action_shape = len(batch.obs)

tianshou/data/batch.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -983,7 +983,7 @@ def __cat(self, batches: Sequence[dict | Self], lens: list[int]) -> None:
983983
self.__dict__[key][sum_lens[i] : sum_lens[i + 1]] = value
984984

985985
def cat_(self, batches: BatchProtocol | Sequence[dict | BatchProtocol]) -> None:
986-
if isinstance(batches, BatchProtocol | dict):
986+
if isinstance(batches, Batch | dict):
987987
batches = [batches]
988988
# check input format
989989
batch_list = []
@@ -1069,7 +1069,7 @@ def stack_(self, batches: Sequence[dict | BatchProtocol], axis: int = 0) -> None
10691069
{
10701070
batch_key
10711071
for batch_key, obj in batch.items()
1072-
if not (isinstance(obj, BatchProtocol) and len(obj.get_keys()) == 0)
1072+
if not (isinstance(obj, Batch) and len(obj.get_keys()) == 0)
10731073
}
10741074
for batch in batches
10751075
]
@@ -1080,7 +1080,7 @@ def stack_(self, batches: Sequence[dict | BatchProtocol], axis: int = 0) -> None
10801080
if all(isinstance(element, torch.Tensor) for element in value):
10811081
self.__dict__[shared_key] = torch.stack(value, axis)
10821082
# third often
1083-
elif all(isinstance(element, BatchProtocol | dict) for element in value):
1083+
elif all(isinstance(element, Batch | dict) for element in value):
10841084
self.__dict__[shared_key] = Batch.stack(value, axis)
10851085
else: # most often case is np.ndarray
10861086
try:
@@ -1114,7 +1114,7 @@ def stack_(self, batches: Sequence[dict | BatchProtocol], axis: int = 0) -> None
11141114
value = batch.get(key)
11151115
# TODO: fix code/annotations s.t. the ignores can be removed
11161116
if (
1117-
isinstance(value, BatchProtocol) # type: ignore
1117+
isinstance(value, Batch) # type: ignore
11181118
and len(value.get_keys()) == 0 # type: ignore
11191119
):
11201120
continue # type: ignore
@@ -1288,7 +1288,7 @@ def set_array_at_key(
12881288
) from exception
12891289
else:
12901290
existing_entry = self[key]
1291-
if isinstance(existing_entry, BatchProtocol):
1291+
if isinstance(existing_entry, Batch):
12921292
raise ValueError(
12931293
f"Cannot set sequence at key {key} because it is a nested batch, "
12941294
f"can only set a subsequence of an array.",
@@ -1312,7 +1312,7 @@ def hasnull(self) -> bool:
13121312

13131313
def is_any_true(boolean_batch: BatchProtocol) -> bool:
13141314
for val in boolean_batch.values():
1315-
if isinstance(val, BatchProtocol):
1315+
if isinstance(val, Batch):
13161316
if is_any_true(val):
13171317
return True
13181318
else:
@@ -1375,7 +1375,7 @@ def _apply_batch_values_func_recursively(
13751375
"""
13761376
result = batch if inplace else deepcopy(batch)
13771377
for key, val in batch.__dict__.items():
1378-
if isinstance(val, BatchProtocol):
1378+
if isinstance(val, Batch):
13791379
result[key] = _apply_batch_values_func_recursively(val, values_transform, inplace=False)
13801380
else:
13811381
result[key] = values_transform(val)

tianshou/data/buffer/her.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -150,12 +150,12 @@ def rewrite_transitions(self, indices: np.ndarray) -> None:
150150
ep_obs = self[unique_ep_indices].obs
151151
# to satisfy mypy
152152
# TODO: add protocol covering these batches
153-
assert isinstance(ep_obs, BatchProtocol)
153+
assert isinstance(ep_obs, Batch)
154154
ep_rew = self[unique_ep_indices].rew
155155
if self._save_obs_next:
156156
ep_obs_next = self[unique_ep_indices].obs_next
157157
# to satisfy mypy
158-
assert isinstance(ep_obs_next, BatchProtocol)
158+
assert isinstance(ep_obs_next, Batch)
159159
future_obs = self[future_t[unique_ep_close_indices]].obs_next
160160
else:
161161
future_obs = self[self.next(future_t[unique_ep_close_indices])].obs
@@ -172,7 +172,7 @@ def rewrite_transitions(self, indices: np.ndarray) -> None:
172172
ep_rew[:, her_ep_indices] = self._compute_reward(ep_obs_next)[:, her_ep_indices]
173173
else:
174174
tmp_ep_obs_next = self[self.next(unique_ep_indices)].obs
175-
assert isinstance(tmp_ep_obs_next, BatchProtocol)
175+
assert isinstance(tmp_ep_obs_next, Batch)
176176
ep_rew[:, her_ep_indices] = self._compute_reward(tmp_ep_obs_next)[:, her_ep_indices]
177177

178178
# Sanity check
@@ -181,7 +181,7 @@ def rewrite_transitions(self, indices: np.ndarray) -> None:
181181
assert ep_rew.shape == unique_ep_indices.shape
182182

183183
# Re-write meta
184-
assert isinstance(self._meta.obs, BatchProtocol)
184+
assert isinstance(self._meta.obs, Batch)
185185
self._meta.obs[unique_ep_indices] = ep_obs
186186
if self._save_obs_next:
187187
self._meta.obs_next[unique_ep_indices] = ep_obs_next # type: ignore

tianshou/policy/modelbased/psrl.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,8 @@ def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TPSRL
236236
for minibatch in batch.split(size=1):
237237
obs, act, obs_next = minibatch.obs, minibatch.act, minibatch.obs_next
238238
obs_next = cast(np.ndarray, obs_next)
239-
assert not isinstance(obs, BatchProtocol), "Observations cannot be Batches here"
239+
assert not isinstance(obs, Batch), "Observations cannot be Batches here"
240+
obs = cast(np.ndarray, obs)
240241
trans_count[obs, act, obs_next] += 1
241242
rew_sum[obs, act] += minibatch.rew
242243
rew_square_sum[obs, act] += minibatch.rew**2

0 commit comments

Comments
 (0)