# test_world_model_cache.py
import csv
import os

import numpy as np
import torch
import torch.nn as nn
from easydict import EasyDict

# Make sure lzero and toy_env are on the Python path before running.
from lzero.model.unizero_world_models.world_model import WorldModel
from lzero.model.unizero_world_models.utils import hash_state
from toy_env import ToyEnv

# ==============================================================================
# Helper classes and functions for the test
# ==============================================================================
17+
class DummyTokenizer:
    """A minimal tokenizer for vector observations.

    Wraps a single linear layer mapping raw observation vectors to embedding
    vectors, mimicking the encoder interface the world model expects.
    """

    def __init__(self, obs_shape, embed_dim, device):
        # obs_shape is assumed to be a 1-element sequence (vector observations);
        # only obs_shape[0] is used as the input feature size.
        self.encoder = nn.Linear(obs_shape[0], embed_dim).to(device)
        self.device = device

    def encode_to_obs_embeddings(self, obs):
        """Encode a numpy observation array into an embedding tensor.

        Accepts 1D ``(obs,)``, 2D ``(batch, obs)`` or 3D ``(batch, seq, obs)``
        input. A singleton token dimension is inserted so the result is always
        ``(batch, 1, embed)`` or ``(batch, seq, 1, embed)``.

        Raises:
            ValueError: if the input has more than 3 dimensions.
        """
        obs_tensor = torch.from_numpy(obs).float().to(self.device)
        if obs_tensor.dim() == 1:
            # Promote a single observation to a batch of one.
            obs_tensor = obs_tensor.unsqueeze(0)
        if obs_tensor.dim() == 2:
            return self.encoder(obs_tensor).unsqueeze(1)
        elif obs_tensor.dim() == 3:
            return self.encoder(obs_tensor).unsqueeze(2)
        else:
            raise ValueError(f"Unsupported observation tensor shape: {obs_tensor.shape}")
34+
def print_cache_summary(name: str, kv_cache, context_length: int):
    """Print a summary of a KeysValues cache object, highlighting truncation.

    Args:
        name: Label used in the printed line.
        kv_cache: A KeysValues-like object exposing ``size`` and
            ``_keys_values[0]._k_cache._cache.shape``, or ``None``.
        context_length: The model's context window length, used to flag
            imminent truncation.

    Returns:
        A ``(size, description)`` tuple; ``(0, "None")`` when no cache exists.
    """
    if kv_cache is None:
        print(f"  {name}: None")
        return 0, "None"

    size = kv_cache.size
    shape = kv_cache._keys_values[0]._k_cache._cache.shape
    status_msg = ""
    # The model reserves room for the upcoming (act, obs) tokens when it
    # truncates, so warn once the size gets within a few slots of the limit.
    if size >= context_length - 3:
        status_msg = (
            f" (!! Approaching/Exceeded Context Limit of {context_length}."
            f" Truncation will occur.)"
        )

    print(f"  {name}: Size = {size}, Shape = {shape}{status_msg}")
    return size, f"Size={size}"
50+
51+ # ==============================================================================
52+ # Main Test Function
53+ # ==============================================================================
def test_cache_logic():
    """Run one episode against ToyEnv and inspect the world model's KV caches.

    For every environment step this simulates the MCTS usage pattern
    (initial inference to create a root, one recurrent inference step),
    then checks whether the corresponding KV-cache entries were stored,
    logging the results to ``cache_log.csv``.
    """
    # 1. Set up environment and model configuration.
    env_cfg = ToyEnv.default_config()
    env = ToyEnv(env_cfg)

    world_model_cfg = EasyDict(
        dict(
            continuous_action_space=False, num_layers=2, num_heads=4, embed_dim=64,
            context_length=8, max_tokens=100, tokens_per_block=2,
            action_space_size=env.cfg.action_space_size, env_num=1, obs_type='vector',
            device='cuda' if torch.cuda.is_available() else 'cpu', rotary_emb=False,
            policy_entropy_weight=0, predict_latent_loss_type='mse', group_size=8,
            gamma=0.99, dormant_threshold=0.0, analysis_dormant_ratio=False,
            latent_recon_loss_weight=0, perceptual_loss_weight=0, support_size=11,
            max_cache_size=1000, final_norm_option_in_obs_head='SimNorm', norm_type='LN',
            embed_pdrop=0.1, resid_pdrop=0.1, attn_pdrop=0.1, max_blocks=10, gru_gating=False,
        )
    )

    # 2. Instantiate the world model.
    tokenizer = DummyTokenizer(env.cfg.observation_shape, world_model_cfg.embed_dim, world_model_cfg.device)
    world_model = WorldModel(world_model_cfg, tokenizer).to(world_model_cfg.device)
    world_model.eval()

    # 3. Set up the log file.
    log_filename = "cache_log.csv"
    log_filepath = os.path.join(os.getcwd(), log_filename)
    print(f"\nLogging statistics to: {log_filepath}")

    with open(log_filepath, 'w', newline='') as log_file:
        csv_writer = csv.writer(log_file)
        header = [
            'Timestep', 'Action_Taken', 'Current_State',
            'Root_Cache_Hit', 'Root_Cache_Size',
            'Recurrent_Cache_Hit', 'Recurrent_Cache_Size',
            'Comment'
        ]
        csv_writer.writerow(header)

        # 4. Run one episode and inspect the caches at every step.
        obs_dict = env.reset()
        last_action = -1
        last_obs_for_infer = np.zeros_like(obs_dict['observation'])

        for t in range(env.cfg.collect_max_episode_steps):
            print(f"\n{'=' * 25} Timestep {t} {'=' * 25}")
            print(f"Environment State: Obs = {obs_dict['observation']}, Timestep from Env = {obs_dict['timestep']}")

            log_row = {'Timestep': t, 'Current_State': str(obs_dict['observation'])}

            # --- Simulate the start of an MCTS search ---
            obs_act_dict = {
                'obs': last_obs_for_infer, 'action': np.array([last_action]),
                'current_obs': obs_dict['observation']
            }
            print("\n[1. Initial Inference] -> Simulating root node creation for MCTS.")
            print(f"  Inputs: last_obs={obs_act_dict['obs']}, last_action={obs_act_dict['action']}, current_obs={obs_act_dict['current_obs']}")

            with torch.no_grad():
                # NOTE: start_pos must be a list/array to match the model's
                # batched-inference interface.
                _, latent_state, _, _, _ = world_model.forward_initial_inference(
                    obs_act_dict, start_pos=[obs_dict['timestep']]
                )

            # --- Inspect the root node cache ---
            print("\n[2. Inspecting Root Node Cache]")
            cache_key = hash_state(latent_state.cpu().numpy().flatten())
            cache_index = world_model.past_kv_cache_init_infer_envs[0].get(cache_key)

            if cache_index is not None:
                root_kv_cache = world_model.shared_pool_init_infer[0][cache_index]
                log_row['Root_Cache_Hit'] = 'Stored'
                size, _ = print_cache_summary("Stored Root KV Cache", root_kv_cache, world_model_cfg.context_length)
                log_row['Root_Cache_Size'] = size
            else:
                log_row['Root_Cache_Hit'] = 'Not_Found'
                log_row['Root_Cache_Size'] = 0
                print("  Status: Cache Not Found! (This is unexpected after the first step).")

            # --- Simulate one MCTS recurrent-inference step ---
            action_to_take = env.action_space.sample()
            log_row['Action_Taken'] = action_to_take
            print(f"\n[3. Recurrent Inference] -> Simulating one search step from the root.")
            print(f"  Action to explore: {action_to_take}")

            state_action_history = [(latent_state.cpu().numpy(), np.array([action_to_take]))]

            print("  Checking if root cache is available for recurrent step...")
            root_cache_key_for_recur = hash_state(state_action_history[0][0].flatten())
            root_cache_index = world_model.past_kv_cache_init_infer_envs[0].get(root_cache_key_for_recur)
            if root_cache_index is not None:
                log_row['Comment'] = 'Recurrent step found root cache.'
                print("  -> Cache Hit! The recurrent step will build upon the existing root cache.")
            else:
                log_row['Comment'] = 'Recurrent step MISSES root cache!'
                print("  -> Cache Miss! The recurrent step will have to regenerate context. (This indicates a problem)")

            with torch.no_grad():
                # NOTE: start_pos must be a list/array (see above).
                _, next_latent_state, _, _, _ = world_model.forward_recurrent_inference(
                    state_action_history,
                    start_pos=[obs_dict['timestep']]
                )

            # --- Inspect the recurrent-inference node cache ---
            print("\n[4. Inspecting Recurrent Node Cache]")
            cache_key_recur = hash_state(next_latent_state.cpu().numpy().flatten())
            cache_index_recur = world_model.past_kv_cache_recurrent_infer.get(cache_key_recur)
            if cache_index_recur is not None:
                recurrent_kv_cache = world_model.shared_pool_recur_infer[cache_index_recur]
                log_row['Recurrent_Cache_Hit'] = 'Stored'
                size, _ = print_cache_summary("Stored Recurrent KV Cache", recurrent_kv_cache, world_model_cfg.context_length)
                log_row['Recurrent_Cache_Size'] = size
            else:
                log_row['Recurrent_Cache_Hit'] = 'Not_Found'
                log_row['Recurrent_Cache_Size'] = 0
                print("  Status: Recurrent Cache Not Found! (This is unexpected).")

            # --- Step the environment ---
            print("\n[5. Stepping Environment]")
            timestep_obj = env.step(action_to_take)

            last_action = action_to_take
            last_obs_for_infer = obs_dict['observation']
            obs_dict = timestep_obj.obs

            # Write the log row, filling missing columns with empty strings.
            csv_writer.writerow([log_row.get(h, '') for h in header])

            if timestep_obj.done:
                print("\n" + "=" * 20 + " Episode Finished " + "=" * 20)
                break

    world_model.clear_caches()
    print(f"\nTest finished. Log saved to {log_filepath}")
189+
if __name__ == "__main__":
    test_cache_logic()