|
4 | 4 | import json |
5 | 5 | from datetime import datetime |
6 | 6 | from typing import Any, Dict, List, Optional, Union |
| 7 | +from collections import OrderedDict |
7 | 8 |
|
8 | 9 | import gym |
9 | 10 | import numpy as np |
@@ -53,7 +54,9 @@ class JerichoEnv(BaseEnv): |
53 | 54 | 'save_replay': False, |
54 | 55 | 'save_replay_path': None, |
55 | 56 | 'env_type': "zork1", |
56 | | - 'collect_policy_mode': "agent" |
| 57 | + 'collect_policy_mode': "agent", |
| 58 | + 'use_cache': True, |
| 59 | + 'cache_size': 100000, |
57 | 60 | } |
58 | 61 |
|
59 | 62 | def __init__(self, cfg: Dict[str, Any]) -> None: |
@@ -92,6 +95,13 @@ def __init__(self, cfg: Dict[str, Any]) -> None: |
92 | 95 | self.add_location_and_inventory: bool = self.cfg['add_location_and_inventory'] |
93 | 96 | self.for_unizero: bool = self.cfg['for_unizero'] |
94 | 97 |
|
| 98 | + self.use_cache = self.cfg['use_cache'] |
| 99 | + if self.use_cache: |
| 100 | + self.cache_size = self.cfg['cache_size'] |
| 101 | + self.cache_buffer = OrderedDict() |
| 102 | + print(f'[jericho]: use_cache: {self.use_cache}, cache_size={self.cache_size}') |
| 103 | + |
| 104 | + |
95 | 105 | # Initialize the tokenizer once (only in rank 0 process if distributed) |
96 | 106 | if JerichoEnv.tokenizer is None: |
97 | 107 | if self.rank == 0: |
@@ -134,7 +144,18 @@ def prepare_obs(self, obs: str, return_str: bool = False) -> Dict[str, Any]: |
134 | 144 | and action mask. For unizero, an additional "to_play" key is provided. |
135 | 145 | """ |
136 | 146 | if self._action_list is None: |
137 | | - self._action_list = self._env.get_valid_actions() |
| 147 | + if self.use_cache: |
| 148 | + cache_key = self._env.get_world_state_hash() |
| 149 | + if cache_key in self.cache_buffer: |
| 150 | + self.cache_buffer.move_to_end(cache_key) |
| 151 | + self._action_list = self.cache_buffer[cache_key] |
| 152 | + else: |
| 153 | + self._action_list = self._env.get_valid_actions() |
| 154 | + self.cache_buffer[cache_key] = self._action_list |
| 155 | + if len(self.cache_buffer) > self.cache_size: |
| 156 | + self.cache_buffer.popitem(last=False) |
| 157 | + else: |
| 158 | + self._action_list = self._env.get_valid_actions() |
138 | 159 |
|
139 | 160 | # Filter available actions based on whether stuck actions are removed. |
140 | 161 | if self.remove_stuck_actions: |
|
0 commit comments