Skip to content

Commit 91a7744

Browse files
committed
v2: WIP, improve offline algo performance
1 parent 4e60cf9 commit 91a7744

File tree

3 files changed

+58
-13
lines changed

3 files changed

+58
-13
lines changed

tianshou/data/buffer/base.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -298,10 +298,8 @@ def reset(self, keep_statistics: bool = False) -> None:
298298
if not keep_statistics:
299299
self._ep_return, self._ep_len = 0.0, 0
300300

301-
# TODO: is this method really necessary? It's kinda dangerous, can accidentally
302-
# remove all references to collected data
303301
def set_batch(self, batch: RolloutBatchProtocol) -> None:
304-
"""Manually choose the batch you want the ReplayBuffer to manage."""
302+
"""Manually choose the batch you want the ReplayBuffer to manage. Use with caution!"""
305303
assert len(batch) == self.maxsize and set(batch.get_keys()).issubset(
306304
self._reserved_keys,
307305
), "Input batch doesn't meet ReplayBuffer's data form requirement."
@@ -496,12 +494,10 @@ def add(
496494
def sample_indices(self, batch_size: int | None) -> np.ndarray:
497495
"""Get a random sample of index with size = batch_size.
498496
499-
Return all available indices in the buffer if batch_size is 0; return an empty
500-
numpy array if batch_size < 0 or no available index can be sampled.
501-
502-
:param batch_size: the number of indices to be sampled. If None, it will be set
503-
to the length of the buffer (i.e. return all available indices in a
504-
random order).
497+
:param batch_size: the number of indices to be sampled. Three cases are possible:
498+
1. positive int - sample random indices of that size
499+
2. zero - all indices in current order
500+
3. None - all indices, but in random order
505501
"""
506502
if batch_size is None:
507503
batch_size = len(self)
@@ -534,8 +530,10 @@ def sample_indices(self, batch_size: int | None) -> np.ndarray:
534530
def sample(self, batch_size: int | None) -> tuple[RolloutBatchProtocol, np.ndarray]:
535531
"""Get a random sample from buffer with size = batch_size.
536532
537-
Return all the data in the buffer if batch_size is 0.
538-
533+
:param batch_size: the number of indices to be sampled. Three cases are possible:
534+
1. positive int - sample random indices of that size
535+
2. zero - all indices in current order
536+
3. None - all indices, but in random order
539537
:return: Sample data and its corresponding index inside the buffer.
540538
"""
541539
indices = self.sample_indices(batch_size)

tianshou/policy/base.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import time
33
from abc import ABC, abstractmethod
44
from collections.abc import Callable, Mapping
5+
from copy import copy
56
from dataclasses import dataclass, field
67
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, cast
78

@@ -985,6 +986,47 @@ def update(
985986
)
986987

987988

989+
class OfflineAlgorithmFromOffPolicyAlgorithm(
990+
OfflineAlgorithm[TPolicy],
991+
Generic[TPolicy],
992+
ABC,
993+
):
994+
"""Base class for offline algorithms that use the same data preprocessing as an off-policy algorithm.
995+
996+
Typically used within a diamond inheritance pattern for transforming the respective off-policy algorithm
997+
into a derived offline variant. See usages.
998+
"""
999+
1000+
def __init__(
1001+
self, *, policy: TPolicy, off_policy_algorithm_class: type[OfflineAlgorithm[TPolicy]]
1002+
):
1003+
self._off_policy_algorithm_class = off_policy_algorithm_class
1004+
OfflineAlgorithm.__init__(self, policy=policy)
1005+
1006+
@override
1007+
def process_buffer(self, buffer: TBuffer) -> TBuffer:
1008+
"""Use the off-policy algorithm's batch pre-processing for processing the buffer once before training.
1009+
1010+
This implementation avoids unnecessary re-computation of preprocessing.
1011+
"""
1012+
buffer = copy(buffer)
1013+
batch, indices = buffer.sample(0)
1014+
processed_batch = self._off_policy_algorithm_class._preprocess_batch(
1015+
self, batch, buffer, indices # type: ignore[arg-type]
1016+
)
1017+
buffer.set_batch(processed_batch)
1018+
return buffer
1019+
1020+
@override
1021+
def _preprocess_batch(
1022+
self,
1023+
batch: RolloutBatchProtocol,
1024+
buffer: ReplayBuffer,
1025+
indices: np.ndarray,
1026+
) -> RolloutBatchProtocol | BatchWithReturnsProtocol:
1027+
return batch
1028+
1029+
9881030
class OnPolicyWrapperAlgorithm(
9891031
OnPolicyAlgorithm[TPolicy],
9901032
Generic[TPolicy],

tianshou/policy/imitation/td3_bc.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,16 @@
44
from tianshou.data import to_torch_as
55
from tianshou.data.types import RolloutBatchProtocol
66
from tianshou.policy import TD3
7-
from tianshou.policy.base import OfflineAlgorithm
7+
from tianshou.policy.base import (
8+
OfflineAlgorithmFromOffPolicyAlgorithm,
9+
)
810
from tianshou.policy.modelfree.ddpg import ContinuousDeterministicPolicy
911
from tianshou.policy.modelfree.td3 import TD3TrainingStats
1012
from tianshou.policy.optim import OptimizerFactory
1113

1214

1315
# NOTE: This uses diamond inheritance to convert from off-policy to offline
14-
class TD3BC(OfflineAlgorithm[ContinuousDeterministicPolicy], TD3): # type: ignore
16+
class TD3BC(OfflineAlgorithmFromOffPolicyAlgorithm[ContinuousDeterministicPolicy], TD3): # type: ignore
1517
"""Implementation of TD3+BC. arXiv:2106.06860."""
1618

1719
def __init__(
@@ -97,6 +99,9 @@ def __init__(
9799
update_actor_freq=update_actor_freq,
98100
estimation_step=estimation_step,
99101
)
102+
OfflineAlgorithmFromOffPolicyAlgorithm.__init__(
103+
self, policy=policy, off_policy_algorithm_class=TD3 # type: ignore[arg-type]
104+
)
100105
self.alpha = alpha
101106

102107
def _update_with_batch(self, batch: RolloutBatchProtocol) -> TD3TrainingStats:

0 commit comments

Comments
 (0)