Skip to content

Commit fc83e9a

Browse files
committed
polish(pu): polish unizero config
1 parent 2eb6d05 commit fc83e9a

File tree

6 files changed

+274
-97
lines changed

6 files changed

+274
-97
lines changed

lzero/mcts/buffer/game_buffer.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -697,16 +697,16 @@ def _remove(self, excess_game_segment_index: int) -> None:
697697
f" - base_idx: {base_idx_after}\n"
698698
f"------------------------------\n\n"
699699
)
700-
700+
# TODO
701701
# 5. Print to console and write to file
702-
print(log_message)
702+
# print(log_message)
703703

704-
log_filename = f"game_buffer_remove_log_{timestamp.strftime('%Y%m%d_%H%M%S')}.txt"
705-
try:
706-
with open(log_filename, 'a', encoding='utf-8') as f:
707-
f.write(log_message)
708-
except Exception as e:
709-
print(f"[ERROR] Failed to write to log file {log_filename}: {e}")
704+
# log_filename = f"game_buffer_remove_log_{timestamp.strftime('%Y%m%d_%H%M%S')}.txt"
705+
# try:
706+
# with open(log_filename, 'a', encoding='utf-8') as f:
707+
# f.write(log_message)
708+
# except Exception as e:
709+
# print(f"[ERROR] Failed to write to log file {log_filename}: {e}")
710710

711711
# --- End of logging modification ---
712712

lzero/model/unizero_world_models/utils.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -263,18 +263,18 @@ def __init__(self, latent_recon_loss_weight=0, perceptual_loss_weight=0, continu
263263
# NOTE: Define the weights for each loss type
264264
if not continuous_action_space:
265265
# like EZV2, for atari and memory
266-
self.obs_loss_weight = 10
267-
self.value_loss_weight = 0.5
268-
self.reward_loss_weight = 1.
269-
self.policy_loss_weight = 1.
270-
self.ends_loss_weight = 0.
266+
# self.obs_loss_weight = 10
267+
# self.value_loss_weight = 0.5
268+
# self.reward_loss_weight = 1.
269+
# self.policy_loss_weight = 1.
270+
# self.ends_loss_weight = 0.
271271

272272
# muzero loss weight
273-
# self.obs_loss_weight = 2
274-
# self.value_loss_weight = 0.25
275-
# self.reward_loss_weight = 1
276-
# self.policy_loss_weight = 1
277-
# self.ends_loss_weight = 0.
273+
self.obs_loss_weight = 2
274+
self.value_loss_weight = 0.25
275+
self.reward_loss_weight = 1
276+
self.policy_loss_weight = 1
277+
self.ends_loss_weight = 0.
278278

279279
# like TD-MPC2 for DMC
280280
# self.obs_loss_weight = 10

0 commit comments

Comments (0)