Skip to content

Commit bc222e8

Browse files
authored
Fix #811 (#817)
1 parent c8be85b commit bc222e8

File tree

4 files changed

+39 -7 lines changed

4 files changed

+39 -7 lines changed

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,7 @@ exclude =
88
dist
99
*.egg-info
1010
max-line-length = 87
11-
ignore = B305,W504,B006,B008,B024,W503
11+
ignore = B305,W504,B006,B008,B024,W503,B028
1212

1313
[yapf]
1414
based_on_style = pep8

setup.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -22,17 +22,16 @@ def get_install_requires() -> str:
2222
"torch>=1.4.0",
2323
"numba>=0.51.0",
2424
"h5py>=2.10.0", # to match tensorflow's minimal requirements
25-
"protobuf~=3.19.0", # breaking change, sphinx fail
2625
"packaging",
2726
]
2827

2928

3029
def get_extras_require() -> str:
3130
req = {
3231
"dev": [
33-
"sphinx<4",
32+
"sphinx",
3433
"sphinx_rtd_theme",
35-
"jinja2<3.1", # temporary fix
34+
"jinja2",
3635
"sphinxcontrib-bibtex",
3736
"flake8",
3837
"flake8-bugbear",

test/base/test_buffer.py

Lines changed: 32 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -437,6 +437,38 @@ def compute_reward_fn(ag, g):
437437
assert np.all(buf[10:].obs.desired_goal == buf[0].obs.desired_goal) # (same ep)
438438
assert np.all(buf[0].obs.desired_goal != buf[5].obs.desired_goal) # (diff ep)
439439

440+
# Another test case for cycled indices
441+
env_size = 99
442+
bufsize = 15
443+
env = MyGoalEnv(env_size, array_state=False)
444+
buf = HERReplayBuffer(
445+
bufsize, compute_reward_fn=compute_reward_fn, horizon=30, future_k=8
446+
)
447+
buf.future_p = 1
448+
for x, ep_len in enumerate([10, 20]):
449+
obs, _ = env.reset()
450+
for i in range(ep_len):
451+
act = 1
452+
obs_next, rew, terminated, truncated, info = env.step(act)
453+
batch = Batch(
454+
obs=obs,
455+
act=[act],
456+
rew=rew,
457+
terminated=(i == ep_len - 1),
458+
truncated=(i == ep_len - 1),
459+
obs_next=obs_next,
460+
info=info
461+
)
462+
if x == 1 and obs["observation"] < 10:
463+
obs = obs_next
464+
continue
465+
buf.add(batch)
466+
obs = obs_next
467+
buf._restore_cache()
468+
sample_indices = np.array([10]) # Suppose the sampled indices is [10]
469+
buf.rewrite_transitions(sample_indices)
470+
assert int(buf.obs.desired_goal[10][0]) in [11, 12, 13, 14, 15, 16, 17, 18, 19, 20]
471+
440472

441473
def test_update():
442474
buf1 = ReplayBuffer(4, stack_num=2)

tianshou/data/buffer/her.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -120,9 +120,10 @@ def rewrite_transitions(self, indices: np.ndarray) -> None:
120120
# Calculate future timestep to use
121121
current = indices[0]
122122
terminal = indices[-1]
123-
future_offset = np.random.uniform(size=len(indices[0])) * (terminal - current)
124-
future_offset = future_offset.astype(int)
125-
future_t = (current + future_offset)
123+
episodes_len = (terminal - current + self.maxsize) % self.maxsize
124+
future_offset = np.random.uniform(size=len(indices[0])) * episodes_len
125+
future_offset = np.round(future_offset).astype(int)
126+
future_t = (current + future_offset) % self.maxsize
126127

127128
# Compute indices
128129
# open indices are used to find longest, unique trajectories among

0 commit comments

Comments
 (0)