Skip to content

Commit b7c7016

Browse files
committed
fix(pu): fix init_recur kv share pool index
1 parent 22cba86 commit b7c7016

File tree

4 files changed

+86
-23
lines changed

4 files changed

+86
-23
lines changed

lzero/model/unizero_world_models/world_model.py

Lines changed: 73 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ def _debug_check_for_stale_pointers(self, env_id: int, current_key: Any, index_t
231231

232232
# 打印详细的调试信息
233233
print("="*60)
234-
print(f"!!! BUG CONDITION DETECTED (Detection #{self.stale_pointer_detections}) !!!")
234+
print(f"!!! INIT BUG CONDITION DETECTED (Detection #{self.stale_pointer_detections}) !!!")
235235
print(f" Environment ID: {env_id}")
236236
print(f" Pool Index to be overwritten: {index_to_be_written}")
237237
print(f" New state hash being written: '{current_key}'")
@@ -466,7 +466,18 @@ def _initialize_cache_structures(self) -> None:
466466
from collections import defaultdict
467467
# self.past_kv_cache_recurrent_infer = defaultdict(dict)
468468
# 使用 LRUCache 替换 defaultdict,并同步容量
469-
self.past_kv_cache_recurrent_infer = LRUCache(self.shared_pool_size_recur)
469+
470+
# ========================= 核心修复与注释 (Recurrent Infer) =========================
471+
# 问题: recurrent_infer 缓存同样存在 LRUCache 与环形缓冲区逻辑不匹配的问题。
472+
#
473+
# 修复方案:
474+
# 1. 将 past_kv_cache_recurrent_infer 从 LRUCache 改为标准字典。
475+
# 2. 引入辅助列表 pool_idx_to_key_map_recur_infer 来维护反向映射。
476+
# 这确保了在覆写 recurrent 数据池中的条目时,可以同步删除旧的指针。
477+
478+
self.past_kv_cache_recurrent_infer = {}
479+
self.pool_idx_to_key_map_recur_infer = [None] * self.shared_pool_size_recur
480+
# ========================== 修复结束 ==========================
470481

471482
# self.past_kv_cache_init_infer_envs = [defaultdict(dict) for _ in range(self.env_num)]
472483

@@ -490,7 +501,23 @@ def _initialize_cache_structures(self) -> None:
490501
# 完全同步。当数据池的索引0被新数据覆盖时,指向旧索引0的指针也已被自动清除。
491502
# 3. 杜绝污染: 从根本上解决了Episode内部的状态哈希碰撞问题。
492503

493-
self.past_kv_cache_init_infer_envs = [LRUCache(self.shared_pool_size_init) for _ in range(self.env_num)]
504+
# self.past_kv_cache_init_infer_envs = [LRUCache(self.shared_pool_size_init-1) for _ in range(self.env_num)]
505+
# ========================== 修复结束 ==========================
506+
507+
# ========================= 核心修复与注释 =========================
508+
# 问题: LRUCache 的淘汰逻辑(基于访问顺序)与环形缓冲区的覆写逻辑(基于写入顺序)不匹配,导致指针过时。
509+
#
510+
# 修复方案:
511+
# 1. 使用一个标准的字典 `past_kv_cache_init_infer_envs` 来存储 {state_hash -> pool_index}。
512+
# 2. 引入一个辅助列表 `pool_idx_to_key_map_init_envs` 来维护反向映射 {pool_index -> state_hash}。
513+
#
514+
# 效果:
515+
# 在向环形缓冲区的某个索引写入新数据之前,我们可以通过辅助列表立即找到即将被覆盖的旧 state_hash,
516+
# 并从主字典中精确地删除这个过时的条目。这确保了字典和数据池的完全同步。
517+
518+
self.past_kv_cache_init_infer_envs = [{} for _ in range(self.env_num)]
519+
# 辅助数据结构,用于反向查找:pool_index -> key
520+
self.pool_idx_to_key_map_init_envs = [[None] * self.shared_pool_size_init for _ in range(self.env_num)]
494521
# ========================== 修复结束 ==========================
495522

496523
self.keys_values_wm_list = []
@@ -1365,34 +1392,68 @@ def update_cache_context(self, latent_state, is_init_infer=True, simulation_inde
13651392

13661393
if is_init_infer:
13671394
# TODO
1368-
# ==================== DEBUG CODE INSERTION ====================
1369-
# 在写入之前,先获取将要写入的索引
1395+
# ==================== 主动淘汰修复逻辑 ====================
1396+
# 1. 获取即将被覆写的物理索引
13701397
index_to_write = self.shared_pool_index_init_envs[i]
1398+
1399+
# 2. 使用辅助列表查找该索引上存储的旧的 key
1400+
old_key_to_evict = self.pool_idx_to_key_map_init_envs[i][index_to_write]
1401+
1402+
# 3. 如果存在旧 key,就从主 cache map 中删除它
1403+
if old_key_to_evict is not None:
1404+
# 确保要删除的键确实存在,避免意外错误
1405+
if old_key_to_evict in self.past_kv_cache_init_infer_envs[i]:
1406+
del self.past_kv_cache_init_infer_envs[i][old_key_to_evict]
1407+
1408+
# 现在可以安全地写入新数据了
1409+
cache_index = self.custom_copy_kv_cache_to_shared_init_envs(self.keys_values_wm_single_env, i)
1410+
1411+
# 4. 在主 cache map 和辅助列表中同时更新新的映射关系
1412+
self.past_kv_cache_init_infer_envs[i][cache_key] = cache_index
1413+
self.pool_idx_to_key_map_init_envs[i][index_to_write] = cache_key
1414+
13711415
# 调用调试函数进行检查
13721416
self._debug_check_for_stale_pointers(env_id=i, current_key=cache_key, index_to_be_written=index_to_write)
13731417
# ============================================================
13741418

13751419
# Store the latest key-value cache for initial inference
1376-
cache_index = self.custom_copy_kv_cache_to_shared_init_envs(self.keys_values_wm_single_env, i)
1377-
self.past_kv_cache_init_infer_envs[i][cache_key] = cache_index
1420+
# cache_index = self.custom_copy_kv_cache_to_shared_init_envs(self.keys_values_wm_single_env, i)
1421+
# self.past_kv_cache_init_infer_envs[i][cache_key] = cache_index
13781422
else:
13791423
# TODO 获取要存入的cache的某个唯一标识,例如tensor的和
13801424
# cache_to_store = self.keys_values_wm_single_env._keys_values[0]._k_cache._cache
13811425
# cache_sum = torch.sum(cache_to_store).item()
13821426
# cache_shape = cache_to_store.shape
13831427
# print(f"[CACHE WRITE] Storing for key={cache_key}, cache_shape={cache_shape}, cache_sum={cache_sum:.4f}")
13841428

1385-
# ==================== DEBUG CODE INSERTION ====================
1386-
# 在写入之前,获取将要写入的索引
1429+
# ==================== RECURRENT INFER FIX ====================
1430+
# 1. 获取即将被覆写的物理索引
13871431
index_to_write = self.shared_pool_index
1432+
1433+
# 2. 使用辅助列表查找该索引上存储的旧的 key
1434+
old_key_to_evict = self.pool_idx_to_key_map_recur_infer[index_to_write]
1435+
1436+
# 3. 如果存在旧 key,就从主 cache map 中删除它
1437+
if old_key_to_evict is not None:
1438+
if old_key_to_evict in self.past_kv_cache_recurrent_infer:
1439+
del self.past_kv_cache_recurrent_infer[old_key_to_evict]
1440+
1441+
# 4. 现在可以安全地写入新数据了
1442+
cache_index = self.custom_copy_kv_cache_to_shared_recur(self.keys_values_wm_single_env)
1443+
1444+
# 5. 在主 cache map 和辅助列表中同时更新新的映射关系
1445+
self.past_kv_cache_recurrent_infer[cache_key] = cache_index
1446+
self.pool_idx_to_key_map_recur_infer[index_to_write] = cache_key
1447+
# ============================================================
1448+
1449+
# ==================== DEBUG CODE INSERTION ====================
13881450
# 调用调试函数进行检查
13891451
self._debug_check_for_stale_pointers_recur(current_key=cache_key, index_to_be_written=index_to_write)
13901452
# ============================================================
13911453

1392-
13931454
# Store the latest key-value cache for recurrent inference
1394-
cache_index = self.custom_copy_kv_cache_to_shared_recur(self.keys_values_wm_single_env)
1395-
self.past_kv_cache_recurrent_infer[cache_key] = cache_index
1455+
# cache_index = self.custom_copy_kv_cache_to_shared_recur(self.keys_values_wm_single_env)
1456+
# self.past_kv_cache_recurrent_infer[cache_key] = cache_index
13961457

13971458

13981459
def retrieve_or_generate_kvcache(self, latent_state: list, ready_env_num: int,

lzero/policy/unizero.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -955,7 +955,6 @@ def _reset_collect(self, env_id: int = None, current_steps: int = None, reset_in
955955

956956
clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else self._cfg.game_segment_length
957957

958-
959958

960959
# Clear caches if the current steps are a multiple of the clear interval
961960
if current_steps is not None and current_steps % clear_interval == 0:
@@ -971,8 +970,7 @@ def _reset_collect(self, env_id: int = None, current_steps: int = None, reset_in
971970
# Free up GPU memory
972971
torch.cuda.empty_cache()
973972

974-
print('collector: collect_model clear()')
975-
print(f'eps_steps_lst[{env_id}]: {current_steps}')
973+
print(f'eps_steps_lst[{env_id}]: {current_steps}, collector: collect_model clear()')
976974

977975

978976
def _reset_eval(self, env_id: int = None, current_steps: int = None, reset_init_data: bool = True) -> None:

lzero/worker/muzero_segment_collector.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -603,9 +603,9 @@ def collect(self,
603603

604604
# ============ TODO(pu): only for UniZero now ============
605605
if self._policy.get_attribute('cfg').type in ['unizero', 'sampled_unizero']:
606-
if eps_steps_lst[env_id]>=self.policy_config.game_segment_length:
606+
if eps_steps_lst[env_id]>self.policy_config.game_segment_length:
607607
self._policy.reset(env_id=env_id, current_steps=eps_steps_lst[env_id], reset_init_data=False)
608-
print("eps_steps_lst[env_id]>=self.policy_config.game_segment_length")
608+
print(f"eps_steps_lst[env_id]>self.policy_config.game_segment_length:{eps_steps_lst[env_id]}>{self.policy_config.game_segment_length}")
609609

610610
# if self._policy.get_attribute('cfg').type in ['unizero', 'sampled_unizero']:
611611
# self._policy.reset(env_id=env_id, current_steps=eps_steps_lst[env_id], reset_init_data=False)

zoo/atari/config/atari_unizero_segment_config.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ def main(env_id, seed):
2626
# batch_size = 256
2727
batch_size = 64 # debug
2828
num_layers = 2
29-
# replay_ratio = 0.25
30-
replay_ratio = 0.1
29+
replay_ratio = 0.25
30+
# replay_ratio = 0.1
3131

3232
game_segment_length = 20
3333
num_unroll_steps = 10
@@ -192,7 +192,11 @@ def main(env_id, seed):
192192

193193
# ============ use muzero_segment_collector instead of muzero_collector =============
194194
from lzero.entry import train_unizero_segment
195-
main_config.exp_name = f'data_unizero_longrun_20250819/{env_id[:-14]}/{env_id[:-14]}_uz_lrucache-init-recur_clear20_muzerolossweight_spsi20_envnum{collector_env_num}_encoder-head-l2norm_soft-target-005_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_c25_seed{seed}'
195+
main_config.exp_name = f'data_unizero_longrun_20250819/{env_id[:-14]}/{env_id[:-14]}_uz_fix-init-recur_clear20_muzerolossweight_spsi20_envnum{collector_env_num}_encoder-head-l2norm_soft-target-005_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_c25_seed{seed}'
196+
197+
# main_config.exp_name = f'data_unizero_longrun_20250819/{env_id[:-14]}/{env_id[:-14]}_uz_lrucache-init-recur_clear20_muzerolossweight_spsi20_envnum{collector_env_num}_encoder-head-ln_soft-target-005_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_c25_seed{seed}'
198+
199+
# main_config.exp_name = f'data_unizero_longrun_20250819/{env_id[:-14]}/{env_id[:-14]}_uz_lrucache-init-recur_clear20_muzerolossweight_spsi20_envnum{collector_env_num}_encoder-head-l2norm_soft-target-005_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_c25_seed{seed}'
196200
# main_config.exp_name = f'data_unizero_longrun_20250819_debug/{env_id[:-14]}/{env_id[:-14]}_uz_lrucache_clear20_muzerolossweight_spsi20_envnum{collector_env_num}_encoder-head-l2norm_soft-target-005_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_c25_seed{seed}'
197201

198202
# main_config.exp_name = f'data_unizero_longrun_20250819/{env_id[:-14]}/{env_id[:-14]}_uz_clear40_muzerolossweight_spsi20_envnum8_encoder-head-l2norm_soft-target-005_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_c25_seed{seed}'
@@ -226,13 +230,13 @@ def main(env_id, seed):
226230

227231
args.env = 'MsPacmanNoFrameskip-v4'
228232
# args.env = 'QbertNoFrameskip-v4'
233+
# args.env = 'SeaquestNoFrameskip-v4'
229234

230235
# args.env = 'SpaceInvadersNoFrameskip-v4'
236+
231237
# args.env = 'BeamRiderNoFrameskip-v4'
232238
# args.env = 'GravitarNoFrameskip-v4'
233239

234-
235-
# args.env = 'SeaquestNoFrameskip-v4'
236240
# args.env = 'BreakoutNoFrameskip-v4'
237241

238242

@@ -242,7 +246,7 @@ def main(env_id, seed):
242246
main(args.env, args.seed)
243247

244248
"""
245-
export CUDA_VISIBLE_DEVICES=6
249+
export CUDA_VISIBLE_DEVICES=4
246250
cd /fs-computility/niuyazhe/puyuan/code/LightZero
247251
python /fs-computility/niuyazhe/puyuan/code/LightZero/zoo/atari/config/atari_unizero_segment_config.py
248252
"""

0 commit comments

Comments (0)