@@ -474,6 +474,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
474474 dormant_ratio_encoder = self .intermediate_losses ['dormant_ratio_encoder' ]
475475 dormant_ratio_world_model = self .intermediate_losses ['dormant_ratio_world_model' ]
476476 latent_state_l2_norms = self .intermediate_losses ['latent_state_l2_norms' ]
477+ latent_action_l2_norms = self .intermediate_losses ['latent_action_l2_norms' ]
477478
478479 assert not torch .isnan (losses .loss_total ).any (), "Loss contains NaN values"
479480 assert not torch .isinf (losses .loss_total ).any (), "Loss contains Inf values"
@@ -586,6 +587,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
586587 'analysis/dormant_ratio_encoder' : dormant_ratio_encoder .item (),
587588 'analysis/dormant_ratio_world_model' : dormant_ratio_world_model .item (),
588589 'analysis/latent_state_l2_norms' : latent_state_l2_norms .item (),
590+ 'analysis/latent_action_l2_norms' : latent_action_l2_norms .item (),
589591 'analysis/l2_norm_before' : self .l2_norm_before ,
590592 'analysis/l2_norm_after' : self .l2_norm_after ,
591593 'analysis/grad_norm_before' : self .grad_norm_before ,
@@ -620,6 +622,7 @@ def _init_collect(self) -> None:
620622 self ._mcts_collect = MCTSCtree (mcts_collect_cfg )
621623 else :
622624 self ._mcts_collect = MCTSPtree (mcts_collect_cfg )
625+
623626 self ._collect_mcts_temperature = 1.
624627 self ._collect_epsilon = 0.0
625628 self .collector_env_num = self ._cfg .collector_env_num
@@ -908,29 +911,59 @@ def _reset_collect(self, env_id: int = None, current_steps: int = None, reset_in
908911 )
909912 self .last_batch_action = [- 1 for _ in range (self ._cfg .collector_env_num )]
910913
911- # Return immediately if env_id is None or a list
912- if env_id is None or isinstance (env_id , list ):
913- return
914+ # --- BEGIN ROBUST FIX ---
915+ # This logic handles the crucial end-of-episode cache clearing.
916+ # The collector calls `_policy.reset([env_id])` when an episode is done,
917+ # which results in `current_steps` being None and `env_id` being a list.
918+
919+ # We must handle both single int and list of ints for env_id.
920+ if env_id is not None :
921+ if isinstance (env_id , int ):
922+ env_ids_to_reset = [env_id ]
923+ else : # Assumes it's a list
924+ env_ids_to_reset = env_id
925+
926+ # The key condition: `current_steps` is None only on the end-of-episode reset call from the collector.
927+ if current_steps is None :
928+ world_model = self ._collect_model .world_model
929+ for eid in env_ids_to_reset :
930+ # Clear the specific environment's initial inference cache.
931+ if eid < len (world_model .past_kv_cache_init_infer_envs ):
932+ world_model .past_kv_cache_init_infer_envs [eid ].clear ()
933+
934+ print (f'>>> [Collector] Cleared KV cache for env_id: { eid } at episode end.' )
935+
936+ # The recurrent cache is global, which is problematic.
937+ # A full clear is heavy-handed but safer than leaving stale entries.
938+ world_model .past_kv_cache_recurrent_infer .clear ()
939+
940+ if hasattr (world_model , 'keys_values_wm_list' ):
941+ world_model .keys_values_wm_list .clear ()
942+
943+ torch .cuda .empty_cache ()
944+
945+ # --- END ROBUST FIX ---
946+
947+ # # Determine the clear interval based on the environment's sample type
948+ # clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else 200
914949
915- # Determine the clear interval based on the environment's sample type
916- clear_interval = 2000 if getattr (self ._cfg , 'sample_type' , '' ) == 'episode' else 200
950+ # # Clear caches if the current steps are a multiple of the clear interval
951+ # if current_steps % clear_interval == 0:
952+ # print(f'clear_interval: {clear_interval}')
917953
918- # Clear caches if the current steps are a multiple of the clear interval
919- if current_steps % clear_interval == 0 :
920- print (f'clear_interval: { clear_interval } ' )
954+ # # Clear various caches in the collect model's world model
955+ # world_model = self._collect_model.world_model
956+ # for kv_cache_dict_env in world_model.past_kv_cache_init_infer_envs:
957+ # kv_cache_dict_env.clear()
958+ # world_model.past_kv_cache_recurrent_infer.clear()
959+ # world_model.keys_values_wm_list.clear()
921960
922- # Clear various caches in the collect model's world model
923- world_model = self ._collect_model .world_model
924- for kv_cache_dict_env in world_model .past_kv_cache_init_infer_envs :
925- kv_cache_dict_env .clear ()
926- world_model .past_kv_cache_recurrent_infer .clear ()
927- world_model .keys_values_wm_list .clear ()
961+ # # Free up GPU memory
962+ # torch.cuda.empty_cache()
928963
929- # Free up GPU memory
930- torch . cuda . empty_cache ( )
964+ # print('collector: collect_model clear()')
965+ # print(f'eps_steps_lst[{env_id}]: {current_steps}' )
931966
932- print ('collector: collect_model clear()' )
933- print (f'eps_steps_lst[{ env_id } ]: { current_steps } ' )
934967
935968 def _reset_eval (self , env_id : int = None , current_steps : int = None , reset_init_data : bool = True ) -> None :
936969 """
@@ -952,29 +985,54 @@ def _reset_eval(self, env_id: int = None, current_steps: int = None, reset_init_
952985 )
953986 self .last_batch_action = [- 1 for _ in range (self ._cfg .evaluator_env_num )]
954987
955- # Return immediately if env_id is None or a list
956- if env_id is None or isinstance (env_id , list ):
957- return
988+ # --- BEGIN ROBUST FIX ---
989+ # This logic handles the crucial end-of-episode cache clearing for evaluation.
990+ # The evaluator calls `_policy.reset([env_id])` when an episode is done.
991+ if env_id is not None :
992+ if isinstance (env_id , int ):
993+ env_ids_to_reset = [env_id ]
994+ else : # Assumes it's a list
995+ env_ids_to_reset = env_id
996+
997+ # The key condition: `current_steps` is None only on the end-of-episode reset call from the evaluator.
998+ if current_steps is None :
999+ world_model = self ._eval_model .world_model
1000+ for eid in env_ids_to_reset :
1001+ # Clear the specific environment's initial inference cache.
1002+ if eid < len (world_model .past_kv_cache_init_infer_envs ):
1003+ world_model .past_kv_cache_init_infer_envs [eid ].clear ()
1004+
1005+ print (f'>>> [Evaluator] Cleared KV cache for env_id: { eid } at episode end.' )
1006+
1007+ # The recurrent cache is global.
1008+ world_model .past_kv_cache_recurrent_infer .clear ()
1009+
1010+ if hasattr (world_model , 'keys_values_wm_list' ):
1011+ world_model .keys_values_wm_list .clear ()
1012+
1013+ torch .cuda .empty_cache ()
1014+ return
1015+ # --- END ROBUST FIX ---
9581016
959- # Determine the clear interval based on the environment's sample type
960- clear_interval = 2000 if getattr (self ._cfg , 'sample_type' , '' ) == 'episode' else 200
1017+ # # # Determine the clear interval based on the environment's sample type
1018+ # clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else 200
9611019
962- # Clear caches if the current steps are a multiple of the clear interval
963- if current_steps % clear_interval == 0 :
964- print (f'clear_interval: { clear_interval } ' )
1020+ # # # Clear caches if the current steps are a multiple of the clear interval
1021+ # if current_steps % clear_interval == 0:
1022+ # print(f'clear_interval: {clear_interval}')
9651023
966- # Clear various caches in the eval model's world model
967- world_model = self ._eval_model .world_model
968- for kv_cache_dict_env in world_model .past_kv_cache_init_infer_envs :
969- kv_cache_dict_env .clear ()
970- world_model .past_kv_cache_recurrent_infer .clear ()
971- world_model .keys_values_wm_list .clear ()
1024+ # # Clear various caches in the eval model's world model
1025+ # world_model = self._eval_model.world_model
1026+ # for kv_cache_dict_env in world_model.past_kv_cache_init_infer_envs:
1027+ # kv_cache_dict_env.clear()
1028+ # world_model.past_kv_cache_recurrent_infer.clear()
1029+ # world_model.keys_values_wm_list.clear()
9721030
973- # Free up GPU memory
974- torch .cuda .empty_cache ()
1031+ # # Free up GPU memory
1032+ # torch.cuda.empty_cache()
9751033
976- print ('evaluator: eval_model clear()' )
977- print (f'eps_steps_lst[{ env_id } ]: { current_steps } ' )
1034+ # print('evaluator: eval_model clear()')
1035+ # print(f'eps_steps_lst[{env_id}]: {current_steps}')
9781036
9791037 def _monitor_vars_learn (self ) -> List [str ]:
9801038 """
@@ -986,6 +1044,8 @@ def _monitor_vars_learn(self) -> List[str]:
9861044 'analysis/dormant_ratio_encoder' ,
9871045 'analysis/dormant_ratio_world_model' ,
9881046 'analysis/latent_state_l2_norms' ,
1047+ 'analysis/latent_action_l2_norms' ,
1048+
9891049 'analysis/l2_norm_before' ,
9901050 'analysis/l2_norm_after' ,
9911051 'analysis/grad_norm_before' ,
0 commit comments