Skip to content

Commit b400719

Browse files
authored
feature(xjy): add valid_actions cache in the jericho env (#457)
1 parent d4c687a commit b400719

File tree

5 files changed

+32
-3
lines changed

5 files changed

+32
-3
lines changed

zoo/jericho/configs/jericho_ppo_config.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,9 @@
3737
collector_env_num=collector_env_num,
3838
evaluator_env_num=evaluator_env_num,
3939
n_evaluator_episode=evaluator_env_num,
40-
manager=dict(shared_memory=False, )
40+
manager=dict(shared_memory=False, ),
41+
use_cache=True,
42+
cache_size=100000,
4143
),
4244
policy=dict(
4345
cuda=True,

zoo/jericho/configs/jericho_unizero_config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,8 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e
9696
evaluator_env_num=evaluator_env_num,
9797
n_evaluator_episode=evaluator_env_num,
9898
manager=dict(shared_memory=False),
99+
use_cache=True,
100+
cache_size=100000,
99101
),
100102
policy=dict(
101103
multi_gpu=False,

zoo/jericho/configs/jericho_unizero_ddp_config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,8 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e
102102
evaluator_env_num=evaluator_env_num,
103103
n_evaluator_episode=evaluator_env_num,
104104
manager=dict(shared_memory=False),
105+
use_cache=True,
106+
cache_size=100000,
105107
),
106108
policy=dict(
107109
multi_gpu=True, # Important for distributed data parallel (DDP)

zoo/jericho/configs/jericho_unizero_segment_config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,8 @@ def main(env_id: str = 'detective.z5', seed: int = 0) -> None:
8383
evaluator_env_num=evaluator_env_num,
8484
n_evaluator_episode=evaluator_env_num,
8585
manager=dict(shared_memory=False),
86+
use_cache=True,
87+
cache_size=100000,
8688
),
8789
policy=dict(
8890
learn=dict(

zoo/jericho/envs/jericho_env.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import json
55
from datetime import datetime
66
from typing import Any, Dict, List, Optional, Union
7+
from collections import OrderedDict
78

89
import gym
910
import numpy as np
@@ -53,7 +54,9 @@ class JerichoEnv(BaseEnv):
5354
'save_replay': False,
5455
'save_replay_path': None,
5556
'env_type': "zork1",
56-
'collect_policy_mode': "agent"
57+
'collect_policy_mode': "agent",
58+
'use_cache': True,
59+
'cache_size': 100000,
5760
}
5861

5962
def __init__(self, cfg: Dict[str, Any]) -> None:
@@ -92,6 +95,13 @@ def __init__(self, cfg: Dict[str, Any]) -> None:
9295
self.add_location_and_inventory: bool = self.cfg['add_location_and_inventory']
9396
self.for_unizero: bool = self.cfg['for_unizero']
9497

98+
self.use_cache = self.cfg['use_cache']
99+
if self.use_cache:
100+
self.cache_size = self.cfg['cache_size']
101+
self.cache_buffer = OrderedDict()
102+
print(f'[jericho]: use_cache: {self.use_cache}, cache_size={self.cache_size}')
103+
104+
95105
# Initialize the tokenizer once (only in rank 0 process if distributed)
96106
if JerichoEnv.tokenizer is None:
97107
if self.rank == 0:
@@ -134,7 +144,18 @@ def prepare_obs(self, obs: str, return_str: bool = False) -> Dict[str, Any]:
134144
and action mask. For unizero, an additional "to_play" key is provided.
135145
"""
136146
if self._action_list is None:
137-
self._action_list = self._env.get_valid_actions()
147+
if self.use_cache:
148+
cache_key = self._env.get_world_state_hash()
149+
if cache_key in self.cache_buffer:
150+
self.cache_buffer.move_to_end(cache_key)
151+
self._action_list = self.cache_buffer[cache_key]
152+
else:
153+
self._action_list = self._env.get_valid_actions()
154+
self.cache_buffer[cache_key] = self._action_list
155+
if len(self.cache_buffer) > self.cache_size:
156+
self.cache_buffer.popitem(last=False)
157+
else:
158+
self._action_list = self._env.get_valid_actions()
138159

139160
# Filter available actions based on whether stuck actions are removed.
140161
if self.remove_stuck_actions:

0 commit comments

Comments (0)