@@ -231,7 +231,7 @@ def _debug_check_for_stale_pointers(self, env_id: int, current_key: Any, index_t
231231
232232 # 打印详细的调试信息
233233 print ("=" * 60 )
234- print (f"!!! BUG CONDITION DETECTED (Detection #{ self .stale_pointer_detections } ) !!!" )
234+ print (f"!!! INIT BUG CONDITION DETECTED (Detection #{ self .stale_pointer_detections } ) !!!" )
235235 print (f" Environment ID: { env_id } " )
236236 print (f" Pool Index to be overwritten: { index_to_be_written } " )
237237 print (f" New state hash being written: '{ current_key } '" )
@@ -466,7 +466,18 @@ def _initialize_cache_structures(self) -> None:
466466 from collections import defaultdict
467467 # self.past_kv_cache_recurrent_infer = defaultdict(dict)
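468468 # NOTE: an LRUCache previously replaced the defaultdict here; it has since been replaced by a plain dict plus a reverse-index list (pool index -> key), see below.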
469- self .past_kv_cache_recurrent_infer = LRUCache (self .shared_pool_size_recur )
469+
470+ # ========================= 核心修复与注释 (Recurrent Infer) =========================
471+ # 问题: recurrent_infer 缓存同样存在 LRUCache 与环形缓冲区逻辑不匹配的问题。
472+ #
473+ # 修复方案:
474+ # 1. 将 past_kv_cache_recurrent_infer 从 LRUCache 改为标准字典。
475+ # 2. 引入辅助列表 pool_idx_to_key_map_recur_infer 来维护反向映射。
476+ # 这确保了在覆写 recurrent 数据池中的条目时,可以同步删除旧的指针。
477+
478+ self .past_kv_cache_recurrent_infer = {}
479+ self .pool_idx_to_key_map_recur_infer = [None ] * self .shared_pool_size_recur
480+ # ========================== 修复结束 ==========================
470481
471482 # self.past_kv_cache_init_infer_envs = [defaultdict(dict) for _ in range(self.env_num)]
472483
@@ -490,7 +501,23 @@ def _initialize_cache_structures(self) -> None:
490501 # 完全同步。当数据池的索引0被新数据覆盖时,指向旧索引0的指针也已被自动清除。
491502 # 3. 杜绝污染: 从根本上解决了Episode内部的状态哈希碰撞问题。
492503
493- self .past_kv_cache_init_infer_envs = [LRUCache (self .shared_pool_size_init ) for _ in range (self .env_num )]
504+ # self.past_kv_cache_init_infer_envs = [LRUCache(self.shared_pool_size_init-1) for _ in range(self.env_num)]
505+ # ========================== 修复结束 ==========================
506+
507+ # ========================= 核心修复与注释 =========================
508+ # 问题: LRUCache 的淘汰逻辑(基于访问顺序)与环形缓冲区的覆写逻辑(基于写入顺序)不匹配,导致指针过时。
509+ #
510+ # 修复方案:
511+ # 1. 使用一个标准的字典 `past_kv_cache_init_infer_envs` 来存储 {state_hash -> pool_index}。
512+ # 2. 引入一个辅助列表 `pool_idx_to_key_map_init_envs` 来维护反向映射 {pool_index -> state_hash}。
513+ #
514+ # 效果:
515+ # 在向环形缓冲区的某个索引写入新数据之前,我们可以通过辅助列表立即找到即将被覆盖的旧 state_hash,
516+ # 并从主字典中精确地删除这个过时的条目。这确保了字典和数据池的完全同步。
517+
518+ self .past_kv_cache_init_infer_envs = [{} for _ in range (self .env_num )]
519+ # 辅助数据结构,用于反向查找:pool_index -> key
520+ self .pool_idx_to_key_map_init_envs = [[None ] * self .shared_pool_size_init for _ in range (self .env_num )]
494521 # ========================== 修复结束 ==========================
495522
496523 self .keys_values_wm_list = []
@@ -1365,34 +1392,68 @@ def update_cache_context(self, latent_state, is_init_infer=True, simulation_inde
13651392
13661393 if is_init_infer :
13671394 # TODO
1368- # ==================== DEBUG CODE INSERTION ====================
1369- # 在写入之前,先获取将要写入的索引
1395+ # ==================== 主动淘汰修复逻辑 ====================
1396+ # 1. 获取即将被覆写的物理索引
13701397 index_to_write = self .shared_pool_index_init_envs [i ]
1398+
1399+ # 2. 使用辅助列表查找该索引上存储的旧的 key
1400+ old_key_to_evict = self .pool_idx_to_key_map_init_envs [i ][index_to_write ]
1401+
1402+ # 3. 如果存在旧 key,就从主 cache map 中删除它
1403+ if old_key_to_evict is not None :
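1404+ # Guard: only delete when the key is still present, to avoid a KeyError. NOTE(review): if this key was later re-written to a newer pool slot, the entry deleted here is the newer, still-valid one; consider also checking that the stored index equals index_to_write before deleting -- TODO confirm.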
1405+ if old_key_to_evict in self .past_kv_cache_init_infer_envs [i ]:
1406+ del self .past_kv_cache_init_infer_envs [i ][old_key_to_evict ]
1407+
1408+ # 现在可以安全地写入新数据了
1409+ cache_index = self .custom_copy_kv_cache_to_shared_init_envs (self .keys_values_wm_single_env , i )
1410+
1411+ # 4. 在主 cache map 和辅助列表中同时更新新的映射关系
1412+ self .past_kv_cache_init_infer_envs [i ][cache_key ] = cache_index
1413+ self .pool_idx_to_key_map_init_envs [i ][index_to_write ] = cache_key
1414+
13711415 # 调用调试函数进行检查
13721416 self ._debug_check_for_stale_pointers (env_id = i , current_key = cache_key , index_to_be_written = index_to_write )
13731417 # ============================================================
13741418
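13751419 # (Superseded) The init-inference KV cache is now stored above, right after the stale key is evicted; the previous store code is kept commented out below for reference.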
1376- cache_index = self .custom_copy_kv_cache_to_shared_init_envs (self .keys_values_wm_single_env , i )
1377- self .past_kv_cache_init_infer_envs [i ][cache_key ] = cache_index
1420+ # cache_index = self.custom_copy_kv_cache_to_shared_init_envs(self.keys_values_wm_single_env, i)
1421+ # self.past_kv_cache_init_infer_envs[i][cache_key] = cache_index
13781422 else :
13791423 # TODO 获取要存入的cache的某个唯一标识,例如tensor的和
13801424 # cache_to_store = self.keys_values_wm_single_env._keys_values[0]._k_cache._cache
13811425 # cache_sum = torch.sum(cache_to_store).item()
13821426 # cache_shape = cache_to_store.shape
13831427 # print(f"[CACHE WRITE] Storing for key={cache_key}, cache_shape={cache_shape}, cache_sum={cache_sum:.4f}")
13841428
1385- # ==================== DEBUG CODE INSERTION ====================
1386- # 在写入之前,获取将要写入的索引
1429+ # ==================== RECURRENT INFER FIX ====================
1430+ # 1. 获取即将被覆写的物理索引
13871431 index_to_write = self .shared_pool_index
1432+
1433+ # 2. 使用辅助列表查找该索引上存储的旧的 key
1434+ old_key_to_evict = self .pool_idx_to_key_map_recur_infer [index_to_write ]
1435+
1436+ # 3. 如果存在旧 key,就从主 cache map 中删除它
1437+ if old_key_to_evict is not None :
1438+ if old_key_to_evict in self .past_kv_cache_recurrent_infer :
1439+ del self .past_kv_cache_recurrent_infer [old_key_to_evict ]
1440+
1441+ # 4. 现在可以安全地写入新数据了
1442+ cache_index = self .custom_copy_kv_cache_to_shared_recur (self .keys_values_wm_single_env )
1443+
1444+ # 5. 在主 cache map 和辅助列表中同时更新新的映射关系
1445+ self .past_kv_cache_recurrent_infer [cache_key ] = cache_index
1446+ self .pool_idx_to_key_map_recur_infer [index_to_write ] = cache_key
1447+ # ============================================================
1448+
1449+ # ==================== DEBUG CODE INSERTION ====================
13881450 # 调用调试函数进行检查
13891451 self ._debug_check_for_stale_pointers_recur (current_key = cache_key , index_to_be_written = index_to_write )
13901452 # ============================================================
13911453
1392-
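13931454 # (Superseded) The recurrent-inference KV cache is now stored above, right after the stale key is evicted; the previous store code is kept commented out below for reference.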
1394- cache_index = self .custom_copy_kv_cache_to_shared_recur (self .keys_values_wm_single_env )
1395- self .past_kv_cache_recurrent_infer [cache_key ] = cache_index
1455+ # cache_index = self.custom_copy_kv_cache_to_shared_recur(self.keys_values_wm_single_env)
1456+ # self.past_kv_cache_recurrent_infer[cache_key] = cache_index
13961457
13971458
13981459 def retrieve_or_generate_kvcache (self , latent_state : list , ready_env_num : int ,
0 commit comments