Skip to content

Commit 961f3be

Browse files
committed
fix(pu): fix recur kv pool index compatibility
1 parent b7c7016 commit 961f3be

File tree

2 files changed

+43
-19
lines changed

2 files changed

+43
-19
lines changed

lzero/model/unizero_world_models/world_model.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ def custom_init(module):
167167

168168
self.num_simulations = getattr(self.config, 'num_simulations', 50)
169169

170+
# TODO: recur kv pool是否应该分成不同的环境有不同的pool呢
170171
self.shared_pool_size_recur = int(self.num_simulations*self.env_num)
171172

172173
# self.shared_pool_size_init = int(50) # NOTE: Will having too many cause incorrect retrieval of the kv cache?
@@ -1497,9 +1498,22 @@ def retrieve_or_generate_kvcache(self, latent_state: list, ready_env_num: int,
14971498
matched_value = None
14981499

14991500
# If not found, try to retrieve from past_kv_cache_recurrent_infer
1501+
# if matched_value is None:
1502+
# matched_value = self.shared_pool_recur_infer[self.past_kv_cache_recurrent_infer.get(cache_key)]
1503+
1504+
# ==================== 核心修复 ====================
1505+
# 步骤 2: 仅当在 init_infer 中未找到时,才尝试从 recurrent_infer 缓存中查找
15001506
if matched_value is None:
1501-
matched_value = self.shared_pool_recur_infer[self.past_kv_cache_recurrent_infer.get(cache_key)]
1507+
# 2.1 安全地从字典中获取索引,它可能返回 None
1508+
recur_cache_index = self.past_kv_cache_recurrent_infer.get(cache_key)
1509+
# 2.2 只有在索引有效(不是 None)的情况下,才使用它来从物理池中检索值
1510+
if recur_cache_index is not None:
1511+
matched_value = self.shared_pool_recur_infer[recur_cache_index]
1512+
1513+
if recur_cache_index is None:
1514+
print(f"[CACHE MISS] Not found for key={cache_key} in recurrent infer. Generating new cache.")
15021515

1516+
# =================================================
15031517
# # TODO
15041518
# retrieved_cache = matched_value._keys_values[0]._k_cache._cache
15051519
# retrieved_sum = torch.sum(retrieved_cache).item()

zoo/atari/config/atari_unizero_segment_config.py

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -9,25 +9,25 @@ def main(env_id, seed):
99
# ==============================================================
1010
# begin of the most frequently changed config specified by the user
1111
# ==============================================================
12-
# collector_env_num = 8
13-
# num_segments = 8
14-
# evaluator_env_num = 3
12+
collector_env_num = 8
13+
num_segments = 8
14+
evaluator_env_num = 3
1515

16-
collector_env_num = 1
17-
num_segments = 1
18-
evaluator_env_num = 1
16+
# collector_env_num = 1
17+
# num_segments = 1
18+
# evaluator_env_num = 1
1919

2020
num_simulations = 50
2121
collect_num_simulations = 25
2222
# collect_num_simulations = 50
2323
eval_num_simulations = 50
2424
# max_env_step = int(5e5)
2525
max_env_step = int(50e6)
26-
# batch_size = 256
27-
batch_size = 64 # debug
26+
batch_size = 256
27+
# batch_size = 64 # debug
2828
num_layers = 2
29-
replay_ratio = 0.25
30-
# replay_ratio = 0.1
29+
# replay_ratio = 0.25
30+
replay_ratio = 0.1
3131

3232
game_segment_length = 20
3333
num_unroll_steps = 10
@@ -114,14 +114,14 @@ def main(env_id, seed):
114114
# final_norm_option_in_obs_head="LayerNorm",
115115
# predict_latent_loss_type='mse',
116116

117-
final_norm_option_in_encoder='L2Norm',
118-
final_norm_option_in_obs_head="L2Norm",
119-
predict_latent_loss_type='mse',
120-
121-
# final_norm_option_in_encoder="LayerNorm",
122-
# final_norm_option_in_obs_head="LayerNorm",
117+
# final_norm_option_in_encoder='L2Norm',
118+
# final_norm_option_in_obs_head="L2Norm",
123119
# predict_latent_loss_type='mse',
124120

121+
final_norm_option_in_encoder="LayerNorm",
122+
final_norm_option_in_obs_head="LayerNorm",
123+
predict_latent_loss_type='mse',
124+
125125
# final_norm_option_in_encoder="SimNorm",
126126
# final_norm_option_in_obs_head="SimNorm",
127127
# predict_latent_loss_type='group_kl',
@@ -192,7 +192,17 @@ def main(env_id, seed):
192192

193193
# ============ use muzero_segment_collector instead of muzero_collector =============
194194
from lzero.entry import train_unizero_segment
195-
main_config.exp_name = f'data_unizero_longrun_20250819/{env_id[:-14]}/{env_id[:-14]}_uz_fix-init-recur_clear20_muzerolossweight_spsi20_envnum{collector_env_num}_encoder-head-l2norm_soft-target-005_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_c25_seed{seed}'
195+
196+
main_config.exp_name = f'data_unizero_longrun_20250819/{env_id[:-14]}/{env_id[:-14]}_uz_fix-init-recur_clear20_mulossweight_spsi20_envnum{collector_env_num}_encoder-head-ln_soft-target-005_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_c25_seed{seed}'
197+
198+
199+
# main_config.exp_name = f'data_unizero_longrun_20250819/{env_id[:-14]}/{env_id[:-14]}_uz_fix-init-recur_clear20_origlossweight_spsi20_envnum{collector_env_num}_encoder-head-l2norm_soft-target-005_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_c25_seed{seed}'
200+
201+
# main_config.exp_name = f'data_unizero_longrun_20250819/{env_id[:-14]}/{env_id[:-14]}_uz_fix-init-recur_clear20_origlossweight_spsi20_envnum{collector_env_num}_encoder-head-ln_soft-target-005_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_c25_seed{seed}'
202+
203+
# main_config.exp_name = f'data_unizero_longrun_20250819/{env_id[:-14]}/{env_id[:-14]}_uz_fix-init-recur_clear20_muzerolossweight_spsi20_envnum{collector_env_num}_encoder-head-ln_soft-target-005_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_c25_seed{seed}'
204+
205+
# main_config.exp_name = f'data_unizero_longrun_20250819/{env_id[:-14]}/{env_id[:-14]}_uz_fix-init-recur_clear20_origlossweight_spsi20_envnum{collector_env_num}_encoder-head-ln_soft-target-005_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_c25_seed{seed}'
196206

197207
# main_config.exp_name = f'data_unizero_longrun_20250819/{env_id[:-14]}/{env_id[:-14]}_uz_lrucache-init-recur_clear20_muzerolossweight_spsi20_envnum{collector_env_num}_encoder-head-ln_soft-target-005_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_c25_seed{seed}'
198208

@@ -246,7 +256,7 @@ def main(env_id, seed):
246256
main(args.env, args.seed)
247257

248258
"""
249-
export CUDA_VISIBLE_DEVICES=4
259+
export CUDA_VISIBLE_DEVICES=0
250260
cd /fs-computility/niuyazhe/puyuan/code/LightZero
251261
python /fs-computility/niuyazhe/puyuan/code/LightZero/zoo/atari/config/atari_unizero_segment_config.py
252262
"""

0 commit comments

Comments
 (0)