Skip to content

Commit 50bd3c0

Browse files
committed
fix(pu): fix unizero reset_collect/eval kv_cache bug!!!
1 parent bb0845d commit 50bd3c0

File tree

9 files changed

+1912
-64
lines changed

9 files changed

+1912
-64
lines changed

lzero/mcts/buffer/game_buffer_bkp20250818.py

Lines changed: 668 additions & 0 deletions
Large diffs are not rendered by default.

lzero/model/unizero_world_models/utils.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,14 +268,28 @@ def __init__(self, latent_recon_loss_weight=0, perceptual_loss_weight=0, continu
268268
self.reward_loss_weight = 1.
269269
self.policy_loss_weight = 1.
270270
self.ends_loss_weight = 0.
271+
272+
# muzero loss weight
273+
# self.obs_loss_weight = 2
274+
# self.value_loss_weight = 0.25
275+
# self.reward_loss_weight = 1
276+
# self.policy_loss_weight = 1
277+
# self.ends_loss_weight = 0.
278+
279+
# like TD-MPC2 for DMC
280+
# self.obs_loss_weight = 10
281+
# self.value_loss_weight = 0.1
282+
# self.reward_loss_weight = 0.1
283+
# self.policy_loss_weight = 0.1
284+
# self.ends_loss_weight = 0.
271285
else:
272286
# like TD-MPC2 for DMC
273287
self.obs_loss_weight = 10
274288
self.value_loss_weight = 0.1
275289
self.reward_loss_weight = 0.1
276290
self.policy_loss_weight = 0.1
277291
self.ends_loss_weight = 0.
278-
292+
279293
self.latent_recon_loss_weight = latent_recon_loss_weight
280294
self.perceptual_loss_weight = perceptual_loss_weight
281295

lzero/model/unizero_world_models/world_model.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None:
6262
# Position embedding
6363
if not self.config.rotary_emb:
6464
self.pos_emb = nn.Embedding(config.max_tokens, config.embed_dim, device=self.device)
65+
# TODO(pu)
66+
# self.pos_emb = nn.Embedding(config.max_tokens, config.embed_dim, device=self.device, max_norm=1.0)
6567
self.precompute_pos_emb_diff_kv()
6668
print(f"self.pos_emb.weight.device: {self.pos_emb.weight.device}")
6769

@@ -76,6 +78,9 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None:
7678
else:
7779
# for discrete action space
7880
self.act_embedding_table = nn.Embedding(config.action_space_size, config.embed_dim, device=self.device)
81+
# TODO(pu)
82+
# self.act_embedding_table = nn.Embedding(config.action_space_size, config.embed_dim, device=self.device, max_norm=1.0)
83+
7984
logging.info(f"self.act_embedding_table.weight.device: {self.act_embedding_table.weight.device}")
8085

8186
self.final_norm_option_in_obs_head = getattr(config, 'final_norm_option_in_obs_head', 'SimNorm')
@@ -1324,18 +1329,21 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar
13241329
else:
13251330
dormant_ratio_encoder = torch.tensor(0.)
13261331

1327-
# Calculate the L2 norm of the latent state roots
1328-
latent_state_l2_norms = torch.norm(obs_embeddings, p=2, dim=2).mean()
1329-
13301332
# Action tokens
13311333
if self.continuous_action_space:
13321334
act_tokens = batch['actions']
13331335
else:
13341336
act_tokens = rearrange(batch['actions'], 'b l -> b l 1')
13351337

1338+
with torch.no_grad():
1339+
# Calculate the L2 norm of the latent state roots
1340+
latent_state_l2_norms = torch.norm(obs_embeddings, p=2, dim=2).mean()
1341+
# Calculate the L2 norm of the latent action
1342+
latent_action_l2_norms = torch.norm(self.act_embedding_table(act_tokens), p=2, dim=2).mean()
1343+
13361344
# Forward pass to obtain predictions for observations, rewards, and policies
13371345
outputs = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)}, start_pos=start_pos)
1338-
1346+
13391347
if self.obs_type == 'image':
13401348
# Reconstruct observations from latent state representations
13411349
# reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings)
@@ -1481,7 +1489,7 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar
14811489
loss_obs = (loss_obs * mask_padding_expanded)
14821490

14831491
# Compute labels for policy and value
1484-
labels_policy, labels_value = self.compute_labels_world_model_value_policy(batch['target_value'],
1492+
labels_value, labels_policy = self.compute_labels_world_model_value_policy(batch['target_value'],
14851493
batch['target_policy'],
14861494
batch['mask_padding'])
14871495

@@ -1582,6 +1590,7 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar
15821590
dormant_ratio_encoder=dormant_ratio_encoder,
15831591
dormant_ratio_world_model=dormant_ratio_world_model,
15841592
latent_state_l2_norms=latent_state_l2_norms,
1593+
latent_action_l2_norms=latent_action_l2_norms,
15851594
policy_mu=mu,
15861595
policy_sigma=sigma,
15871596
target_sampled_actions=target_sampled_actions,
@@ -1605,6 +1614,8 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar
16051614
dormant_ratio_encoder=dormant_ratio_encoder,
16061615
dormant_ratio_world_model=dormant_ratio_world_model,
16071616
latent_state_l2_norms=latent_state_l2_norms,
1617+
latent_action_l2_norms=latent_action_l2_norms,
1618+
16081619
)
16091620

16101621

@@ -1821,9 +1832,9 @@ def compute_labels_world_model_value_policy(self, target_value: torch.Tensor, ta
18211832
labels_value = target_value.masked_fill(mask_fill_value, -100)
18221833

18231834
if self.continuous_action_space:
1824-
return None, labels_value.reshape(-1, self.support_size)
1835+
return labels_value.reshape(-1, self.support_size), None
18251836
else:
1826-
return labels_policy.reshape(-1, self.action_space_size), labels_value.reshape(-1, self.support_size)
1837+
return labels_value.reshape(-1, self.support_size), labels_policy.reshape(-1, self.action_space_size)
18271838

18281839
def clear_caches(self):
18291840
"""

lzero/policy/sampled_unizero.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
410410
# Prepare action batch and convert to torch tensor
411411
if self._cfg.model.continuous_action_space:
412412
action_batch = torch.from_numpy(action_batch).to(self._cfg.device).unsqueeze(
413-
-1) # For discrete action space
413+
-1) # For continuous action space
414414
else:
415415
action_batch = torch.from_numpy(action_batch).to(self._cfg.device).unsqueeze(
416416
-1).long() # For discrete action space

lzero/policy/unizero.py

Lines changed: 96 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -474,6 +474,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
474474
dormant_ratio_encoder = self.intermediate_losses['dormant_ratio_encoder']
475475
dormant_ratio_world_model = self.intermediate_losses['dormant_ratio_world_model']
476476
latent_state_l2_norms = self.intermediate_losses['latent_state_l2_norms']
477+
latent_action_l2_norms = self.intermediate_losses['latent_action_l2_norms']
477478

478479
assert not torch.isnan(losses.loss_total).any(), "Loss contains NaN values"
479480
assert not torch.isinf(losses.loss_total).any(), "Loss contains Inf values"
@@ -586,6 +587,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
586587
'analysis/dormant_ratio_encoder': dormant_ratio_encoder.item(),
587588
'analysis/dormant_ratio_world_model': dormant_ratio_world_model.item(),
588589
'analysis/latent_state_l2_norms': latent_state_l2_norms.item(),
590+
'analysis/latent_action_l2_norms': latent_action_l2_norms.item(),
589591
'analysis/l2_norm_before': self.l2_norm_before,
590592
'analysis/l2_norm_after': self.l2_norm_after,
591593
'analysis/grad_norm_before': self.grad_norm_before,
@@ -620,6 +622,7 @@ def _init_collect(self) -> None:
620622
self._mcts_collect = MCTSCtree(mcts_collect_cfg)
621623
else:
622624
self._mcts_collect = MCTSPtree(mcts_collect_cfg)
625+
623626
self._collect_mcts_temperature = 1.
624627
self._collect_epsilon = 0.0
625628
self.collector_env_num = self._cfg.collector_env_num
@@ -908,29 +911,59 @@ def _reset_collect(self, env_id: int = None, current_steps: int = None, reset_in
908911
)
909912
self.last_batch_action = [-1 for _ in range(self._cfg.collector_env_num)]
910913

911-
# Return immediately if env_id is None or a list
912-
if env_id is None or isinstance(env_id, list):
913-
return
914+
# --- BEGIN ROBUST FIX ---
915+
# This logic handles the crucial end-of-episode cache clearing.
916+
# The collector calls `_policy.reset([env_id])` when an episode is done,
917+
# which results in `current_steps` being None and `env_id` being a list.
918+
919+
# We must handle both single int and list of ints for env_id.
920+
if env_id is not None:
921+
if isinstance(env_id, int):
922+
env_ids_to_reset = [env_id]
923+
else: # Assumes it's a list
924+
env_ids_to_reset = env_id
925+
926+
# The key condition: `current_steps` is None only on the end-of-episode reset call from the collector.
927+
if current_steps is None:
928+
world_model = self._collect_model.world_model
929+
for eid in env_ids_to_reset:
930+
# Clear the specific environment's initial inference cache.
931+
if eid < len(world_model.past_kv_cache_init_infer_envs):
932+
world_model.past_kv_cache_init_infer_envs[eid].clear()
933+
934+
print(f'>>> [Collector] Cleared KV cache for env_id: {eid} at episode end.')
935+
936+
# The recurrent cache is global, which is problematic.
937+
# A full clear is heavy-handed but safer than leaving stale entries.
938+
world_model.past_kv_cache_recurrent_infer.clear()
939+
940+
if hasattr(world_model, 'keys_values_wm_list'):
941+
world_model.keys_values_wm_list.clear()
942+
943+
torch.cuda.empty_cache()
944+
945+
# --- END ROBUST FIX ---
946+
947+
# # Determine the clear interval based on the environment's sample type
948+
# clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else 200
914949

915-
# Determine the clear interval based on the environment's sample type
916-
clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else 200
950+
# # Clear caches if the current steps are a multiple of the clear interval
951+
# if current_steps % clear_interval == 0:
952+
# print(f'clear_interval: {clear_interval}')
917953

918-
# Clear caches if the current steps are a multiple of the clear interval
919-
if current_steps % clear_interval == 0:
920-
print(f'clear_interval: {clear_interval}')
954+
# # Clear various caches in the collect model's world model
955+
# world_model = self._collect_model.world_model
956+
# for kv_cache_dict_env in world_model.past_kv_cache_init_infer_envs:
957+
# kv_cache_dict_env.clear()
958+
# world_model.past_kv_cache_recurrent_infer.clear()
959+
# world_model.keys_values_wm_list.clear()
921960

922-
# Clear various caches in the collect model's world model
923-
world_model = self._collect_model.world_model
924-
for kv_cache_dict_env in world_model.past_kv_cache_init_infer_envs:
925-
kv_cache_dict_env.clear()
926-
world_model.past_kv_cache_recurrent_infer.clear()
927-
world_model.keys_values_wm_list.clear()
961+
# # Free up GPU memory
962+
# torch.cuda.empty_cache()
928963

929-
# Free up GPU memory
930-
torch.cuda.empty_cache()
964+
# print('collector: collect_model clear()')
965+
# print(f'eps_steps_lst[{env_id}]: {current_steps}')
931966

932-
print('collector: collect_model clear()')
933-
print(f'eps_steps_lst[{env_id}]: {current_steps}')
934967

935968
def _reset_eval(self, env_id: int = None, current_steps: int = None, reset_init_data: bool = True) -> None:
936969
"""
@@ -952,29 +985,54 @@ def _reset_eval(self, env_id: int = None, current_steps: int = None, reset_init_
952985
)
953986
self.last_batch_action = [-1 for _ in range(self._cfg.evaluator_env_num)]
954987

955-
# Return immediately if env_id is None or a list
956-
if env_id is None or isinstance(env_id, list):
957-
return
988+
# --- BEGIN ROBUST FIX ---
989+
# This logic handles the crucial end-of-episode cache clearing for evaluation.
990+
# The evaluator calls `_policy.reset([env_id])` when an episode is done.
991+
if env_id is not None:
992+
if isinstance(env_id, int):
993+
env_ids_to_reset = [env_id]
994+
else: # Assumes it's a list
995+
env_ids_to_reset = env_id
996+
997+
# The key condition: `current_steps` is None only on the end-of-episode reset call from the evaluator.
998+
if current_steps is None:
999+
world_model = self._eval_model.world_model
1000+
for eid in env_ids_to_reset:
1001+
# Clear the specific environment's initial inference cache.
1002+
if eid < len(world_model.past_kv_cache_init_infer_envs):
1003+
world_model.past_kv_cache_init_infer_envs[eid].clear()
1004+
1005+
print(f'>>> [Evaluator] Cleared KV cache for env_id: {eid} at episode end.')
1006+
1007+
# The recurrent cache is global.
1008+
world_model.past_kv_cache_recurrent_infer.clear()
1009+
1010+
if hasattr(world_model, 'keys_values_wm_list'):
1011+
world_model.keys_values_wm_list.clear()
1012+
1013+
torch.cuda.empty_cache()
1014+
return
1015+
# --- END ROBUST FIX ---
9581016

959-
# Determine the clear interval based on the environment's sample type
960-
clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else 200
1017+
# # # Determine the clear interval based on the environment's sample type
1018+
# clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else 200
9611019

962-
# Clear caches if the current steps are a multiple of the clear interval
963-
if current_steps % clear_interval == 0:
964-
print(f'clear_interval: {clear_interval}')
1020+
# # # Clear caches if the current steps are a multiple of the clear interval
1021+
# if current_steps % clear_interval == 0:
1022+
# print(f'clear_interval: {clear_interval}')
9651023

966-
# Clear various caches in the eval model's world model
967-
world_model = self._eval_model.world_model
968-
for kv_cache_dict_env in world_model.past_kv_cache_init_infer_envs:
969-
kv_cache_dict_env.clear()
970-
world_model.past_kv_cache_recurrent_infer.clear()
971-
world_model.keys_values_wm_list.clear()
1024+
# # Clear various caches in the eval model's world model
1025+
# world_model = self._eval_model.world_model
1026+
# for kv_cache_dict_env in world_model.past_kv_cache_init_infer_envs:
1027+
# kv_cache_dict_env.clear()
1028+
# world_model.past_kv_cache_recurrent_infer.clear()
1029+
# world_model.keys_values_wm_list.clear()
9721030

973-
# Free up GPU memory
974-
torch.cuda.empty_cache()
1031+
# # Free up GPU memory
1032+
# torch.cuda.empty_cache()
9751033

976-
print('evaluator: eval_model clear()')
977-
print(f'eps_steps_lst[{env_id}]: {current_steps}')
1034+
# print('evaluator: eval_model clear()')
1035+
# print(f'eps_steps_lst[{env_id}]: {current_steps}')
9781036

9791037
def _monitor_vars_learn(self) -> List[str]:
9801038
"""
@@ -986,6 +1044,8 @@ def _monitor_vars_learn(self) -> List[str]:
9861044
'analysis/dormant_ratio_encoder',
9871045
'analysis/dormant_ratio_world_model',
9881046
'analysis/latent_state_l2_norms',
1047+
'analysis/latent_action_l2_norms',
1048+
9891049
'analysis/l2_norm_before',
9901050
'analysis/l2_norm_after',
9911051
'analysis/grad_norm_before',

0 commit comments

Comments
 (0)