Skip to content

Commit 2eb6d05

Browse files
committed
feature(pu): add toy env to test unizero world_model
1 parent 50bd3c0 commit 2eb6d05

File tree

6 files changed

+2275
-30
lines changed

6 files changed

+2275
-30
lines changed
Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
# test_world_model_cache.py
2+
import torch
3+
import torch.nn as nn
4+
import numpy as np
5+
from easydict import EasyDict
6+
import csv
7+
import os
8+
9+
# Ensure that lzero and toy_env are on the Python path
10+
from lzero.model.unizero_world_models.world_model import WorldModel
11+
from toy_env import ToyEnv
12+
from lzero.model.unizero_world_models.utils import hash_state
13+
14+
# ==============================================================================
15+
# Helper classes and functions for the test
16+
# ==============================================================================
17+
18+
class DummyTokenizer:
    """A minimal tokenizer for vector observations.

    Wraps a single linear layer that projects raw observation vectors
    into the world model's embedding space, inserting a singleton
    token axis so the output matches the world model's expected layout.
    """

    def __init__(self, obs_shape, embed_dim, device):
        # Linear projection from the raw observation dimension to embed_dim.
        self.encoder = nn.Linear(obs_shape[0], embed_dim).to(device)
        self.device = device

    def encode_to_obs_embeddings(self, obs):
        obs_tensor = torch.from_numpy(obs).float().to(self.device)
        rank = obs_tensor.dim()
        if rank == 1:
            # Promote a single observation to a batch of one.
            obs_tensor = obs_tensor.unsqueeze(0)
            rank = 2
        if rank == 2:
            # (B, obs_dim) -> (B, 1, embed_dim)
            return self.encoder(obs_tensor).unsqueeze(1)
        if rank == 3:
            # (B, T, obs_dim) -> (B, T, 1, embed_dim)
            return self.encoder(obs_tensor).unsqueeze(2)
        raise ValueError(f"Unsupported observation tensor shape: {obs_tensor.shape}")
34+
35+
def print_cache_summary(name: str, kv_cache, context_length: int):
    """Print a summary of a KeysValues cache object, highlighting truncation.

    Returns a ``(size, summary_string)`` tuple; ``(0, "None")`` when the
    cache is absent.
    """
    if kv_cache is None:
        print(f" {name}: None")
        return 0, "None"

    cache_size = kv_cache.size
    cache_shape = kv_cache._keys_values[0]._k_cache._cache.shape
    # The model reserves room for the upcoming (act, obs) tokens when it
    # truncates, so warn once the size gets within 3 of the context limit.
    status_msg = (
        f" (!! Approaching/Exceeded Context Limit of {context_length}. Truncation will occur.)"
        if cache_size >= context_length - 3
        else ""
    )

    print(f" {name}: Size = {cache_size}, Shape = {cache_shape}{status_msg}")
    return cache_size, f"Size={cache_size}"
50+
51+
# ==============================================================================
52+
# Main Test Function
53+
# ==============================================================================
54+
def test_cache_logic():
    """Run one ToyEnv episode against a UniZero WorldModel and log KV-cache behaviour.

    For each environment timestep this:
      1. runs initial inference (simulating MCTS root creation),
      2. inspects the per-env init-inference KV cache,
      3. runs one recurrent-inference step from the root,
      4. inspects the recurrent-inference KV cache,
      5. steps the real environment.

    Per-step statistics are appended to ``cache_log.csv`` in the current
    working directory.
    """
    # 1. Set up environment and model configuration.
    env_cfg = ToyEnv.default_config()
    env = ToyEnv(env_cfg)

    # Small transformer config: context_length=8 with tokens_per_block=2 means
    # truncation kicks in after only a few (obs, act) pairs — intentional, so
    # cache-truncation behaviour is exercised within a 15-step episode.
    world_model_cfg = EasyDict(
        dict(
            continuous_action_space=False, num_layers=2, num_heads=4, embed_dim=64,
            context_length=8, max_tokens=100, tokens_per_block=2,
            action_space_size=env.cfg.action_space_size, env_num=1, obs_type='vector',
            device='cuda' if torch.cuda.is_available() else 'cpu', rotary_emb=False,
            policy_entropy_weight=0, predict_latent_loss_type='mse', group_size=8,
            gamma=0.99, dormant_threshold=0.0, analysis_dormant_ratio=False,
            latent_recon_loss_weight=0, perceptual_loss_weight=0, support_size=11,
            max_cache_size=1000, final_norm_option_in_obs_head='SimNorm', norm_type='LN',
            embed_pdrop=0.1, resid_pdrop=0.1, attn_pdrop=0.1, max_blocks=10, gru_gating=False,
        )
    )

    # 2. Instantiate the world model (eval mode: disables dropout configured above).
    tokenizer = DummyTokenizer(env.cfg.observation_shape, world_model_cfg.embed_dim, world_model_cfg.device)
    world_model = WorldModel(world_model_cfg, tokenizer).to(world_model_cfg.device)
    world_model.eval()

    # 3. Set up the CSV log file.
    log_filename = "cache_log.csv"
    log_filepath = os.path.join(os.getcwd(), log_filename)
    print(f"\nLogging statistics to: {log_filepath}")

    with open(log_filepath, 'w', newline='') as log_file:
        csv_writer = csv.writer(log_file)
        header = [
            'Timestep', 'Action_Taken', 'Current_State',
            'Root_Cache_Hit', 'Root_Cache_Size',
            'Recurrent_Cache_Hit', 'Recurrent_Cache_Size',
            'Comment'
        ]
        csv_writer.writerow(header)

        # 4. Run one episode and inspect the caches at every step.
        obs_dict = env.reset()
        # Sentinel previous action/observation for the very first inference call.
        last_action = -1
        last_obs_for_infer = np.zeros_like(obs_dict['observation'])

        for t in range(env.cfg.collect_max_episode_steps):
            print(f"\n{'='*25} Timestep {t} {'='*25}")
            print(f"Environment State: Obs = {obs_dict['observation']}, Timestep from Env = {obs_dict['timestep']}")

            log_row = {'Timestep': t, 'Current_State': str(obs_dict['observation'])}

            # --- Simulate the start of an MCTS search ---
            obs_act_dict = {
                'obs': last_obs_for_infer, 'action': np.array([last_action]),
                'current_obs': obs_dict['observation']
            }
            print("\n[1. Initial Inference] -> Simulating root node creation for MCTS.")
            print(f" Inputs: last_obs={obs_act_dict['obs']}, last_action={obs_act_dict['action']}, current_obs={obs_act_dict['current_obs']}")

            with torch.no_grad():
                # NOTE: start_pos must be a list/array to match the model's
                # batched-inference interface, even for a single env.
                _, latent_state, _, _, _ = world_model.forward_initial_inference(
                    obs_act_dict, start_pos=[obs_dict['timestep']]
                )

            # --- Inspect the root-node cache ---
            # The init-inference cache is keyed by a hash of the flattened latent state.
            print("\n[2. Inspecting Root Node Cache]")
            cache_key = hash_state(latent_state.cpu().numpy().flatten())
            cache_index = world_model.past_kv_cache_init_infer_envs[0].get(cache_key)

            if cache_index is not None:
                root_kv_cache = world_model.shared_pool_init_infer[0][cache_index]
                log_row['Root_Cache_Hit'] = 'Stored'
                size, _ = print_cache_summary("Stored Root KV Cache", root_kv_cache, world_model_cfg.context_length)
                log_row['Root_Cache_Size'] = size
            else:
                log_row['Root_Cache_Hit'] = 'Not_Found'
                log_row['Root_Cache_Size'] = 0
                print(" Status: Cache Not Found! (This is unexpected after the first step).")

            # --- Simulate one MCTS recurrent-inference step ---
            action_to_take = env.action_space.sample()
            log_row['Action_Taken'] = action_to_take
            print(f"\n[3. Recurrent Inference] -> Simulating one search step from the root.")
            print(f" Action to explore: {action_to_take}")

            state_action_history = [(latent_state.cpu().numpy(), np.array([action_to_take]))]

            # Verify the recurrent step can reuse the root's cached context.
            print(" Checking if root cache is available for recurrent step...")
            root_cache_key_for_recur = hash_state(state_action_history[0][0].flatten())
            root_cache_index = world_model.past_kv_cache_init_infer_envs[0].get(root_cache_key_for_recur)
            if root_cache_index is not None:
                log_row['Comment'] = 'Recurrent step found root cache.'
                print(" -> Cache Hit! The recurrent step will build upon the existing root cache.")
            else:
                log_row['Comment'] = 'Recurrent step MISSES root cache!'
                print(" -> Cache Miss! The recurrent step will have to regenerate context. (This indicates a problem)")

            with torch.no_grad():
                # NOTE: start_pos must again be a list/array (see above).
                _, next_latent_state, _, _, _ = world_model.forward_recurrent_inference(
                    state_action_history,
                    start_pos=[obs_dict['timestep']]
                )

            # --- Inspect the recurrent-inference node cache ---
            print("\n[4. Inspecting Recurrent Node Cache]")
            cache_key_recur = hash_state(next_latent_state.cpu().numpy().flatten())
            cache_index_recur = world_model.past_kv_cache_recurrent_infer.get(cache_key_recur)
            if cache_index_recur is not None:
                recurrent_kv_cache = world_model.shared_pool_recur_infer[cache_index_recur]
                log_row['Recurrent_Cache_Hit'] = 'Stored'
                size, _ = print_cache_summary("Stored Recurrent KV Cache", recurrent_kv_cache, world_model_cfg.context_length)
                log_row['Recurrent_Cache_Size'] = size
            else:
                log_row['Recurrent_Cache_Hit'] = 'Not_Found'
                log_row['Recurrent_Cache_Size'] = 0
                print(" Status: Recurrent Cache Not Found! (This is unexpected).")

            # --- Step the real environment with the explored action ---
            print("\n[5. Stepping Environment]")
            timestep_obj = env.step(action_to_take)

            last_action = action_to_take
            last_obs_for_infer = obs_dict['observation']
            obs_dict = timestep_obj.obs

            # Write the log row in header order; missing fields become ''.
            csv_writer.writerow([log_row.get(h, '') for h in header])

            if timestep_obj.done:
                print("\n" + "="*20 + " Episode Finished " + "="*20)
                break

    world_model.clear_caches()
    print(f"\nTest finished. Log saved to {log_filepath}")
189+
190+
# Allow running this cache-inspection test as a standalone script.
if __name__ == "__main__":
    test_cache_logic()
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
# toy_env.py
2+
import copy
3+
from typing import List
4+
import gym
5+
import numpy as np
6+
from ding.envs import BaseEnv, BaseEnvTimestep
7+
from ding.utils import ENV_REGISTRY
8+
from easydict import EasyDict
9+
10+
@ENV_REGISTRY.register('toy_lightzero')
class ToyEnv(BaseEnv):
    """
    Overview:
        A simple, deterministic toy environment for debugging KV cache and long-sequence processing in UniZero.
        - State: 4-dim vector.
        - Actions: 3 discrete actions (stay, increment, decrement).
        - Episode Length: Fixed at 15 steps.
        - Returns 'timestep' in observation.
    """
    # Class-level default configuration; copied per instance via default_config().
    config = dict(
        env_id='toy-v0',
        env_type='Toy',
        observation_shape=(4,),
        action_space_size=3,
        collect_max_episode_steps=15,
        eval_max_episode_steps=15,
        manager=dict(shared_memory=False),
        stop_value=100,
    )

    @classmethod
    def default_config(cls: type) -> EasyDict:
        """Return a deep copy of the class-level default config as an EasyDict."""
        cfg = EasyDict(copy.deepcopy(cls.config))
        cfg.cfg_type = cls.__name__ + 'Dict'
        return cfg

    def __init__(self, cfg: EasyDict) -> None:
        """Store the config and build observation/action/reward spaces."""
        self.cfg = cfg
        self._init_flag = False
        self._observation_space = gym.spaces.Dict({
            'observation': gym.spaces.Box(low=-np.inf, high=np.inf, shape=self.cfg.observation_shape, dtype=np.float32),
            'action_mask': gym.spaces.Box(low=0, high=1, shape=(self.cfg.action_space_size,), dtype=np.int8),
            'to_play': gym.spaces.Box(low=-1, high=2, shape=(), dtype=np.int8),
            'timestep': gym.spaces.Box(low=0, high=self.cfg.collect_max_episode_steps, shape=(), dtype=np.int32),
        })
        self._action_space = gym.spaces.Discrete(self.cfg.action_space_size)
        self._reward_space = gym.spaces.Box(low=0.0, high=1.0, shape=(1,), dtype=np.float32)

    def reset(self) -> dict:
        """Reset the state to the zero vector and return the initial observation dict."""
        if not self._init_flag:
            self._init_flag = True
        self._state = np.zeros(self.cfg.observation_shape, dtype=np.float32)
        self._episode_steps = 0
        self._eval_episode_return = 0.0
        return self.observe()

    def step(self, action: int) -> BaseEnvTimestep:
        """Apply an action (0: stay, 1: increment, 2: decrement) and return a timestep.

        The reward is a constant 1.0 per step; the episode terminates after
        ``collect_max_episode_steps`` steps.
        """
        if action == 1:
            self._state += 1
        elif action == 2:
            self._state -= 1

        self._episode_steps += 1
        reward = np.array([1.0], dtype=np.float32)
        self._eval_episode_return += reward

        # NOTE(review): both collect and eval episodes terminate via
        # collect_max_episode_steps; eval_max_episode_steps is never read.
        done = self._episode_steps >= self.cfg.collect_max_episode_steps
        info = {}
        if done:
            info['eval_episode_return'] = self._eval_episode_return

        return BaseEnvTimestep(self.observe(), reward, done, info)

    def observe(self) -> dict:
        """Build the observation dict (all actions always legal; single-player -> to_play=-1)."""
        return {
            'observation': self._state.copy(),
            'action_mask': np.ones(self.cfg.action_space_size, dtype=np.int8),
            'to_play': np.array(-1, dtype=np.int8),
            'timestep': np.array(self._episode_steps, dtype=np.int32)
        }

    def seed(self, seed: int, dynamic_seed: bool = True) -> None:
        """Record the seed and seed NumPy's global RNG (env dynamics are deterministic)."""
        self._seed = seed
        self._dynamic_seed = dynamic_seed
        np.random.seed(self._seed)

    def close(self) -> None:
        """Mark the environment as uninitialized; no external resources to release."""
        self._init_flag = False

    @property
    def observation_space(self) -> gym.spaces.Space:
        return self._observation_space

    @property
    def action_space(self) -> gym.spaces.Space:
        return self._action_space

    @property
    def reward_space(self) -> gym.spaces.Space:
        return self._reward_space

    def __repr__(self) -> str:
        return "LightZero Toy Env"

    @staticmethod
    def create_collector_env_cfg(cfg: dict) -> List[dict]:
        """Build one independent config per collector env.

        Each entry is a deep copy so that per-env mutations (e.g. seeds set by
        the env manager) do not leak across environments — the original code
        returned the same mutable dict N times.
        """
        collector_env_num = cfg.pop('collector_env_num')
        return [copy.deepcopy(cfg) for _ in range(collector_env_num)]

    @staticmethod
    def create_evaluator_env_cfg(cfg: dict) -> List[dict]:
        """Build one independent (deep-copied) config per evaluator env."""
        evaluator_env_num = cfg.pop('evaluator_env_num')
        return [copy.deepcopy(cfg) for _ in range(evaluator_env_num)]

0 commit comments

Comments
 (0)