From 102612dd0fcef526746691f2d5f592b113b98c00 Mon Sep 17 00:00:00 2001 From: xiongjyu Date: Sun, 18 Jan 2026 16:08:29 +0800 Subject: [PATCH 1/9] fix existing bugs in episode mode for collecting and adapt jericho to multitask env. --- lzero/entry/__init__.py | 2 + lzero/entry/train_unizero_multitask.py | 356 ++++++++++++++ lzero/entry/train_unizero_multitask_ddp.py | 450 ++++++++++++++++++ lzero/entry/utils.py | 4 +- lzero/mcts/buffer/game_segment.py | 5 +- lzero/model/common.py | 13 +- lzero/model/unizero_model.py | 6 +- lzero/model/unizero_model_multitask.py | 19 +- .../world_model_multitask.py | 6 + lzero/policy/unizero.py | 4 +- lzero/worker/muzero_collector.py | 35 +- zoo/jericho/configs/jericho_unizero_config.py | 4 +- .../configs/jericho_unizero_ddp_config.py | 4 +- .../jericho_unizero_multitask_config.py | 275 +++++++++++ .../jericho_unizero_multitask_ddp_config.py | 288 +++++++++++ 15 files changed, 1440 insertions(+), 31 deletions(-) create mode 100644 lzero/entry/train_unizero_multitask.py create mode 100644 lzero/entry/train_unizero_multitask_ddp.py create mode 100644 zoo/jericho/configs/jericho_unizero_multitask_config.py create mode 100644 zoo/jericho/configs/jericho_unizero_multitask_ddp_config.py diff --git a/lzero/entry/__init__.py b/lzero/entry/__init__.py index ba846e26a..be0a145e0 100644 --- a/lzero/entry/__init__.py +++ b/lzero/entry/__init__.py @@ -14,6 +14,8 @@ from .train_unizero_multitask_segment_eval import train_unizero_multitask_segment_eval from .train_unizero_multitask_balance_segment_ddp import train_unizero_multitask_balance_segment_ddp from .train_unizero_with_loss_landscape import train_unizero_with_loss_landscape +from .train_unizero_multitask import train_unizero_multitask +from .train_unizero_multitask_ddp import train_unizero_multitask_ddp # from .utils import ( # symlog, diff --git a/lzero/entry/train_unizero_multitask.py b/lzero/entry/train_unizero_multitask.py new file mode 100644 index 000000000..0529a5450 --- /dev/null +++ b/lzero/entry/train_unizero_multitask.py @@ -0,0 +1,356 @@ + +import logging +import os +from functools import partial +from typing import Tuple, Optional, List, Dict, Any +import concurrent.futures +import torch +import numpy as np +import torch.nn.functional as F +from tensorboardX import SummaryWriter + +from ding.config import compile_config +from ding.envs import create_env_manager, get_vec_env_setting +from ding.policy import create_policy, Policy +from ding.utils import set_pkg_seed, EasyTimer +from ding.worker import BaseLearner +from lzero.entry.utils import ( + EVALUATION_TIMEOUT, + TemperatureScheduler, + allocate_batch_size, + compute_task_weights, + compute_unizero_mt_normalized_stats, + log_buffer_memory_usage, + safe_eval, + symlog, + inv_symlog, +) + +from lzero.entry.utils import log_buffer_memory_usage, TemperatureScheduler +from lzero.policy import visit_count_temperature +from lzero.worker import MuZeroEvaluator as Evaluator +from lzero.worker import MuZeroCollector as Collector + +# Set timeout (seconds) +timer = EasyTimer() + +def train_unizero_multitask( + input_cfg_list: List[Tuple[int, Tuple[dict, dict]]], + seed: int = 0, + model: Optional[torch.nn.Module] = None, + model_path: Optional[str] = None, + max_train_iter: Optional[int] = int(1e10), + max_env_step: Optional[int] = int(1e10), +) -> 'Policy': + """ + Overview: + Entry point for UniZero multi-task training (non-DDP version). + Args: + - input_cfg_list (:obj:`List[Tuple[int, Tuple[dict, dict]]]`): Configuration list for different tasks. + - seed (:obj:`int`): Random seed. + - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module. + - model_path (:obj:`Optional[str]`): Path to the pre-trained model. + - max_train_iter (:obj:`Optional[int]`): Maximum number of policy update iterations. + - max_env_step (:obj:`Optional[int]`): Maximum number of collected environment interaction steps. + Returns: + - policy (:obj:`Policy`): The converged policy. + """ + # Initialize temperature scheduler (unchanged) + initial_temperature = 10.0 + final_temperature = 1.0 + threshold_steps = int(1e4) + temperature_scheduler = TemperatureScheduler( + initial_temp=initial_temperature, + final_temp=final_temperature, + threshold_steps=threshold_steps, + mode='linear' + ) + + # Handle all tasks in a single process + tasks = input_cfg_list + total_tasks = len(tasks) + print(f"Handling all {total_tasks} tasks in a single process.") + + cfgs = [] + game_buffers = [] + collector_envs = [] + evaluator_envs = [] + collectors = [] + evaluators = [] + + # Ensure at least one task is provided + if not tasks: + logging.error("No task configurations provided. Training cannot proceed.") + return None + + # Use the first task's configuration to create the shared policy and learner + task_id_first, [cfg_first, create_cfg_first] = tasks[0] + + assert create_cfg_first.policy.type in ['unizero_multitask', 'sampled_unizero_multitask'], "train_unizero_multitask entry currently only supports 'unizero_multitask' or 'sampled_unizero_multitask'" + + + GameBuffer = None + if create_cfg_first.policy.type == 'unizero_multitask': + from lzero.mcts import UniZeroGameBuffer as GB + GameBuffer = GB + elif create_cfg_first.policy.type == 'sampled_unizero_multitask': + from lzero.mcts import SampledUniZeroGameBuffer as SGB + GameBuffer = SGB + else: + raise NotImplementedError(f"Policy type {create_cfg_first.policy.type} not fully supported for GameBuffer import.") + + cfg_first.policy.device = 'cuda' if cfg_first.policy.cuda and torch.cuda.is_available() else 'cpu' + logging.info(f'Using device: {cfg_first.policy.device}') + + # Compile the main config (only for creating policy and learner) + # Note: we compile once here, but later re-compile per-task configs + compiled_cfg_first = compile_config(cfg_first, seed=seed, env=None, auto=True, create_cfg=create_cfg_first, save_cfg=True) + + # Create shared policy + policy = create_policy(compiled_cfg_first.policy, model=model, enable_field=['learn', 'collect', 'eval']) + + if model_path is not None: + logging.info(f'Loading pretrained model: {model_path}') + policy.learn_mode.load_state_dict(torch.load(model_path, map_location=compiled_cfg_first.policy.device)) + logging.info(f'Finished loading model: {model_path}') + + log_dir = os.path.join('./{}/log/'.format(compiled_cfg_first.exp_name), 'serial') + tb_logger = SummaryWriter(log_dir) + + # Create shared learner + learner = BaseLearner(compiled_cfg_first.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=compiled_cfg_first.exp_name) + + # Process each task + for local_task_id, (task_id, [cfg, create_cfg]) in enumerate(tasks): + # Set random seed per task + current_seed = seed + task_id + cfg.policy.device = 'cuda' if cfg.policy.cuda and torch.cuda.is_available() else 'cpu' + # Compile per-task config + cfg = compile_config(cfg, seed=current_seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) + # Get policy config + policy_config = cfg.policy + policy_config.task_id = task_id # explicitly set task_id + policy.collect_mode.get_attribute('cfg').n_episode = policy_config.n_episode + policy.eval_mode.get_attribute('cfg').n_episode = policy_config.n_episode + + # Create environments + env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) + collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) + evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + collector_env.seed(current_seed) + evaluator_env.seed(current_seed, dynamic_seed=False) + set_pkg_seed(current_seed, use_cuda=cfg.policy.cuda) + + # Create buffer, collector, and evaluator + replay_buffer = GameBuffer(policy_config) + collector = Collector( + env=collector_env, + policy=policy.collect_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config, + task_id=task_id + ) + evaluator = Evaluator( + eval_freq=cfg.policy.eval_freq, + n_evaluator_episode=cfg.env.n_evaluator_episode, + stop_value=cfg.env.stop_value, + env=evaluator_env, + policy=policy.eval_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config, + task_id=task_id + ) + + cfgs.append(cfg) + game_buffers.append(replay_buffer) + collector_envs.append(collector_env) + evaluator_envs.append(evaluator_env) + collectors.append(collector) + evaluators.append(evaluator) + + learner.call_hook('before_run') + value_priority_tasks = {} + + buffer_reanalyze_count = 0 + train_epoch = 0 + + reanalyze_batch_size = compiled_cfg_first.policy.reanalyze_batch_size + update_per_collect = compiled_cfg_first.policy.update_per_collect + + task_exploitation_weight = None + task_rewards = {} + + while True: + # Iterate over tasks for data collection and evaluation + for idx, (cfg, collector, evaluator, replay_buffer) in enumerate( + zip(cfgs, collectors, evaluators, game_buffers)): + + current_task_id = cfg.policy.task_id + + # Log buffer memory usage + log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger, current_task_id) + + policy_config = cfg.policy + collect_kwargs = { + 'temperature': visit_count_temperature( + policy_config.manual_temperature_decay, + policy_config.fixed_temperature_value, + policy_config.threshold_training_steps_for_final_temperature, + trained_steps=learner.train_iter + ), + 'epsilon': 0.0 + } + update_per_collect = policy_config.update_per_collect + if update_per_collect is None: + update_per_collect = 40 + + if learner.train_iter > 0 or evaluator.should_eval(learner.train_iter): # only for debug + print(f'Evaluating task_id: {current_task_id}...') + # Reset evaluator policy state + evaluator._policy.reset(reset_init_data=True, task_id=current_task_id) + + # Perform safe evaluation (non-DDP version) + stop, reward = safe_eval(evaluator, learner, collector) + if stop is None or reward is None: + print(f"Evaluation failed or timed out, task_id: {current_task_id}, continuing training...") + task_rewards[current_task_id] = float('inf') + else: + try: + eval_mean_reward = reward.get('eval_episode_return_mean', float('inf')) + print(f"Evaluation reward for task {current_task_id}: {eval_mean_reward}") + task_rewards[current_task_id] = eval_mean_reward + except Exception as e: + print(f"Error extracting reward for task {current_task_id}: {e}") + task_rewards[current_task_id] = float('inf') + + print('=' * 20) + print(f'Starting data collection for task_id: {current_task_id}...') + print(f'cfg.policy.task_id={current_task_id}') + + # Reset collector policy state + collector._policy.reset(reset_init_data=True, task_id=current_task_id) + + new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) + + replay_buffer.push_game_segments(new_data) + replay_buffer.remove_oldest_data_to_fit() + + if policy_config.buffer_reanalyze_freq >= 1: + if update_per_collect is None or update_per_collect == 0: + logging.warning("update_per_collect undefined or zero, cannot compute reanalyze_interval") + reanalyze_interval = float('inf') + + else: + reanalyze_interval = update_per_collect // policy_config.buffer_reanalyze_freq + else: + reanalyze_interval = float('inf') + if train_epoch > 0 and policy_config.buffer_reanalyze_freq > 0 and \ + train_epoch % int(1 / policy_config.buffer_reanalyze_freq) == 0 and \ + replay_buffer.get_num_of_transitions() // policy_config.num_unroll_steps > int(reanalyze_batch_size / policy_config.reanalyze_partition): + with timer: + replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy) + buffer_reanalyze_count += 1 + logging.info(f'Buffer reanalyze count: {buffer_reanalyze_count}, time cost: {timer.value}') + + logging.info(f'Finished data collection for task {current_task_id}') + + not_enough_data = any( + game_buffers[idx].get_num_of_transitions() < policy._cfg.batch_size[cfg.policy.task_id] + for idx, (cfg, replay_buffer) in enumerate(zip(cfgs, game_buffers)) + ) + task_weights = None + + if not not_enough_data: + for i in range(update_per_collect): + train_data_multi_task = [] + envstep_this_epoch = 0 + + for idx, (cfg, collector, replay_buffer) in enumerate(zip(cfgs, collectors, game_buffers)): + current_task_id = cfg.policy.task_id + current_batch_size = policy._cfg.batch_size[current_task_id] + + if current_batch_size == 0: + logging.warning(f"Task {current_task_id} batch_size is 0, skipping sampling.") + continue + + if replay_buffer.get_num_of_transitions() >= current_batch_size: + policy_config = cfg.policy + if policy_config.buffer_reanalyze_freq >= 1: + if update_per_collect is not None and update_per_collect > 0: + reanalyze_interval = update_per_collect // policy_config.buffer_reanalyze_freq + if i % reanalyze_interval == 0 and \ + replay_buffer.get_num_of_transitions() // policy_config.num_unroll_steps > int(reanalyze_batch_size / policy_config.reanalyze_partition): + with timer: + replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy) + buffer_reanalyze_count += 1 + logging.info(f'Buffer reanalyze count: {buffer_reanalyze_count}, time cost: {timer.value}') + + train_data = replay_buffer.sample(current_batch_size, policy) + train_data.append(current_task_id) + train_data_multi_task.append(train_data) + envstep_this_epoch += collector.envstep + else: + logging.warning( + f'Not enough data for task {current_task_id}: ' + f'batch_size: {current_batch_size}, buffer: {replay_buffer.get_num_of_transitions()}' + ) + + if train_data_multi_task: + learn_kwargs = {'task_weights': task_weights, "train_iter": learner.train_iter} + log_vars = learner.train(train_data_multi_task, envstep_this_epoch, policy_kwargs=learn_kwargs) + + + if compiled_cfg_first.policy.use_priority: + if log_vars: + for batch_idx, (cfg, replay_buffer) in enumerate(zip(cfgs, game_buffers)): + task_id = cfg.policy.task_id + priority_key = f'value_priority_task{task_id}' + if priority_key in log_vars[0]: + if batch_idx < len(train_data_multi_task): + try: + replay_buffer.update_priority( + train_data_multi_task[batch_idx], + log_vars[0][priority_key] + ) + current_priorities = log_vars[0][priority_key] + mean_priority = np.mean(current_priorities) + std_priority = np.std(current_priorities) + alpha = 0.1 + running_mean_key = f'running_mean_priority_task{task_id}' + if running_mean_key not in value_priority_tasks: + value_priority_tasks[running_mean_key] = mean_priority + else: + value_priority_tasks[running_mean_key] = ( + alpha * mean_priority + + (1 - alpha) * value_priority_tasks[running_mean_key] + ) + running_mean_priority = value_priority_tasks[running_mean_key] + if policy_config.print_task_priority_logs: + print(f"Task {task_id} - Mean priority: {mean_priority:.8f}, " + f"Running mean priority: {running_mean_priority:.8f}, " + f"Std: {std_priority:.8f}") + except Exception as e: + logging.error(f"Error updating priority for task {task_id}: {e}") + else: + logging.warning(f"Cannot update priority for task {task_id}, missing data in train_data_multi_task.") + else: + logging.warning(f"Priority key '{priority_key}' not found for task {task_id} in log_vars[0]") + else: + logging.warning("log_vars is empty, cannot update priorities.") + train_epoch += 1 + # Check termination conditions + local_max_envstep = max(collector.envstep for collector in collectors) if collectors else 0 + max_envstep_reached = local_max_envstep >= max_env_step + max_train_iter_reached = learner.train_iter >= max_train_iter + + if max_envstep_reached or max_train_iter_reached: + logging.info(f'Termination condition reached: env_step ({local_max_envstep}/{max_env_step}) or train_iter ({learner.train_iter}/{max_train_iter})') + break + + if hasattr(policy, 'recompute_pos_emb_diff_and_clear_cache'): + policy.recompute_pos_emb_diff_and_clear_cache() + + learner.call_hook('after_run') + return policy diff --git a/lzero/entry/train_unizero_multitask_ddp.py b/lzero/entry/train_unizero_multitask_ddp.py new file mode 100644 index 000000000..fb28d935f --- /dev/null +++ b/lzero/entry/train_unizero_multitask_ddp.py @@ -0,0 +1,450 @@ +import logging +import os +from collections import defaultdict +from functools import partial +from typing import Tuple, Optional, List, Dict +import concurrent.futures +import torch +import torch.nn.functional as F +import torch.distributed as dist +import numpy as np +from tensorboardX import SummaryWriter + +from ding.config import compile_config +from ding.envs import create_env_manager, get_vec_env_setting +from ding.policy import create_policy +from ding.rl_utils import get_epsilon_greedy_fn +from ding.utils import set_pkg_seed, get_rank, get_world_size, EasyTimer +from ding.worker import BaseLearner +from lzero.entry.utils import ( + EVALUATION_TIMEOUT, + TemperatureScheduler, + allocate_batch_size, + compute_task_weights, + compute_unizero_mt_normalized_stats, + log_buffer_memory_usage, + safe_eval, + symlog, + inv_symlog, +) + +from lzero.entry.utils import log_buffer_memory_usage, TemperatureScheduler, symlog, inv_symlog +from lzero.policy import visit_count_temperature +from lzero.worker import MuZeroEvaluator as Evaluator +from lzero.worker import MuZeroCollector as Collector + +timer = EasyTimer() + +def train_unizero_multitask_ddp( + input_cfg_list: List[Tuple[int, Tuple[dict, dict]]], + seed: int = 0, + model: Optional[torch.nn.Module] = None, + model_path: Optional[str] = None, + max_train_iter: Optional[int] = int(1e10), + max_env_step: Optional[int] = int(1e10), +) -> 'Policy': + """ + Overview: + Entry point for UniZero training. The goal is to improve the planning ability + of reinforcement learning agents by addressing the limitations of MuZero-like + algorithms in environments that require capturing long-term dependencies. + For more details, refer to https://arxiv.org/abs/2406.10667. + Args: + - input_cfg_list (:obj:`List[Tuple[int, Tuple[dict, dict]]]`): Configuration list for different tasks. + - seed (:obj:`int`): Random seed. + - model (:obj:`Optional[torch.nn.Module]`): An instance of torch.nn.Module. + - model_path (:obj:`Optional[str]`): Path to the pretrained model checkpoint file. + - max_train_iter (:obj:`Optional[int]`): Maximum number of policy update iterations during training. + - max_env_step (:obj:`Optional[int]`): Maximum number of collected environment interaction steps. + Returns: + - policy (:obj:`Policy`): The converged policy. + """ + + # Initialize the temperature scheduler for task weighting. + initial_temperature = 10.0 + final_temperature = 1.0 + threshold_steps = int(1e4) + temperature_scheduler = TemperatureScheduler( + initial_temp=initial_temperature, + final_temp=final_temperature, + threshold_steps=threshold_steps, + mode='linear' + ) + + rank = get_rank() + world_size = get_world_size() + + # Task partitioning + total_tasks = len(input_cfg_list) + tasks_per_rank = total_tasks // world_size + remainder = total_tasks % world_size + + if rank < remainder: + start_idx = rank * (tasks_per_rank + 1) + end_idx = start_idx + tasks_per_rank + 1 + num_tasks_for_this_rank = tasks_per_rank + 1 + else: + start_idx = rank * tasks_per_rank + remainder + end_idx = start_idx + tasks_per_rank + num_tasks_for_this_rank = tasks_per_rank + + tasks_for_this_rank = input_cfg_list[start_idx:end_idx] + + # Ensure at least one task is assigned + if len(tasks_for_this_rank) == 0: + logging.warning(f"Rank {rank}: no tasks assigned, continuing execution.") + # Initialize empty lists to avoid errors in later code + cfgs, game_buffers, collector_envs, evaluator_envs, collectors, evaluators = [], [], [], [], [], [] + else: + print(f"Rank {rank}/{world_size}, handling tasks {start_idx} to {end_idx - 1}") + + cfgs = [] + game_buffers = [] + collector_envs = [] + evaluator_envs = [] + collectors = [] + evaluators = [] + + if tasks_for_this_rank: + # Use the first task’s config to create a shared policy + task_id, [cfg, create_cfg] = tasks_for_this_rank[0] + + for config in tasks_for_this_rank: + config[1][0].policy.task_num = num_tasks_for_this_rank + + assert create_cfg.policy.type in ['unizero_multitask', + 'sampled_unizero_multitask'], "train_unizero entry 目前仅支持 'unizero_multitask'" + + if create_cfg.policy.type == 'unizero_multitask': + from lzero.mcts import UniZeroGameBuffer as GameBuffer + if create_cfg.policy.type == 'sampled_unizero_multitask': + from lzero.mcts import SampledUniZeroGameBuffer as GameBuffer + + cfg.policy.device = cfg.policy.model.world_model_cfg.device if torch.cuda.is_available() else 'cpu' + logging.info(f'Configured device: {cfg.policy.device}') + + cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) + # Create shared policy + policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) + print(f"rank {rank} created the policy!") + if model_path is not None: + logging.info(f'Loading pretrained model: {model_path}') + policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device)) + logging.info(f'Finished loading pretrained model: {model_path}') + + log_dir = os.path.join('./{}/log'.format(cfg.exp_name), f'serial_rank_{rank}') + tb_logger = SummaryWriter(log_dir) + + # Create shared learner + learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) + policy_config = cfg.policy + + # Handle each task assigned to this rank + for local_task_id, (task_id, [cfg, create_cfg]) in enumerate(tasks_for_this_rank): + cfg.policy.device = 'cuda' if cfg.policy.cuda and torch.cuda.is_available() else 'cpu' + cfg = compile_config(cfg, seed=seed + task_id, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) + policy_config = cfg.policy + policy.collect_mode.get_attribute('cfg').n_episode = policy_config.n_episode + policy.eval_mode.get_attribute('cfg').n_episode = policy_config.n_episode + + # Create environments + env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) + collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) + evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + collector_env.seed(cfg.seed + task_id) + evaluator_env.seed(cfg.seed + task_id, dynamic_seed=False) + set_pkg_seed(cfg.seed + task_id, use_cuda=cfg.policy.cuda) + + # Create game buffer, collector, and evaluator + replay_buffer = GameBuffer(policy_config) + collector = Collector( + env=collector_env, + policy=policy.collect_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config, + task_id=task_id + ) + evaluator = Evaluator( + eval_freq=cfg.policy.eval_freq, + n_evaluator_episode=cfg.env.n_evaluator_episode, + stop_value=cfg.env.stop_value, + env=evaluator_env, + policy=policy.eval_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config, + task_id=task_id + ) + + cfgs.append(cfg) + replay_buffer.batch_size = cfg.policy.batch_size[task_id] + + game_buffers.append(replay_buffer) + collector_envs.append(collector_env) + evaluator_envs.append(evaluator_env) + collectors.append(collector) + evaluators.append(evaluator) + + learner.call_hook('before_run') + value_priority_tasks = {} + + buffer_reanalyze_count = 0 + train_epoch = 0 + reanalyze_batch_size = cfg.policy.reanalyze_batch_size + update_per_collect = cfg.policy.update_per_collect + + task_exploitation_weight = None + + # Create task reward dictionary + task_rewards = {} # {task_id: reward} + + while True: + # Dynamically adjust batch_size + if cfg.policy.allocated_batch_sizes: + clip_scale = np.clip(1 + (3 * train_epoch / 1000), 1, 4) + allocated_batch_sizes = allocate_batch_size(cfgs, game_buffers, alpha=1.0, clip_scale=clip_scale) + if rank == 0: + print("Allocated batch_sizes: ", allocated_batch_sizes) + for idx, (cfg, collector, evaluator, replay_buffer) in enumerate( + zip(cfgs, collectors, evaluators, game_buffers)): + task_id = cfg.policy.task_id + if isinstance(allocated_batch_sizes, dict): + cfg.policy.batch_size = allocated_batch_sizes.get(task_id, cfg.policy.batch_size) + elif isinstance(allocated_batch_sizes, list): + # Use the index in the list or task_id as fallback + cfg.policy.batch_size = allocated_batch_sizes[idx] if idx < len(allocated_batch_sizes) else cfg.policy.batch_size + else: + logging.warning(f"Unexpected type for allocated_batch_sizes: {type(allocated_batch_sizes)}") + # Also update the policy config (use the full list for compatibility) + policy._cfg.batch_size = allocated_batch_sizes + + # Perform data collection and evaluation for each task on this rank + for idx, (cfg, collector, evaluator, replay_buffer) in enumerate( + zip(cfgs, collectors, evaluators, game_buffers)): + + # Log buffer memory usage + log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger, cfg.policy.task_id) + + collect_kwargs = { + 'temperature': visit_count_temperature( + policy_config.manual_temperature_decay, + policy_config.fixed_temperature_value, + policy_config.threshold_training_steps_for_final_temperature, + trained_steps=learner.train_iter + ), + 'epsilon': 0.0 + } + + if policy_config.eps.eps_greedy_exploration_in_collect: + epsilon_greedy_fn = get_epsilon_greedy_fn( + start=policy_config.eps.start, + end=policy_config.eps.end, + decay=policy_config.eps.decay, + type_=policy_config.eps.type + ) + collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep) + + if learner.train_iter > 10 or evaluator.should_eval(learner.train_iter): + print('=' * 20) + print(f'Rank {rank} evaluating task_id: {cfg.policy.task_id}...') + evaluator._policy.reset(reset_init_data=True, task_id=cfg.policy.task_id) + + # Perform safe evaluation + stop, reward = safe_eval(evaluator, learner, collector, rank, world_size) + if stop is None or reward is None: + print(f"Rank {rank} encountered an issue during evaluation, continuing training...") + task_rewards[cfg.policy.task_id] = float('inf') # Assign max difficulty if evaluation fails + else: + try: + eval_mean_reward = reward.get('eval_episode_return_mean', float('inf')) + print(f"Evaluation reward for task {cfg.policy.task_id}: {eval_mean_reward}") + task_rewards[cfg.policy.task_id] = eval_mean_reward + except Exception as e: + print(f"Error extracting evaluation reward: {e}") + task_rewards[cfg.policy.task_id] = float('inf') # Assign max reward if error occurs + + + print('=' * 20) + print(f'Starting data collection for Rank {rank}, task_id: {cfg.policy.task_id}...') + print(f'Rank {rank}: cfg.policy.task_id={cfg.policy.task_id} ') + + # Reset policy state before each collection (important for multi-task setups) + collector._policy.reset(reset_init_data=True, task_id=cfg.policy.task_id) + new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) + logging.info(f'Rank {rank}: Finished data collection for task {cfg.policy.task_id}, collected {len(new_data[0]) if new_data else 0} segments') + + replay_buffer.push_game_segments(new_data) + replay_buffer.remove_oldest_data_to_fit() + + if cfg.policy.buffer_reanalyze_freq >= 1: + reanalyze_interval = update_per_collect // cfg.policy.buffer_reanalyze_freq + else: + if train_epoch > 0 and train_epoch % int(1 / cfg.policy.buffer_reanalyze_freq) == 0 and \ + replay_buffer.get_num_of_transitions() // cfg.policy.num_unroll_steps > int( + reanalyze_batch_size / cfg.policy.reanalyze_partition): + with timer: + replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy) + buffer_reanalyze_count += 1 + logging.info(f'Buffer reanalyze count: {buffer_reanalyze_count}') + logging.info(f'Buffer reanalyze time cost: {timer.value}') + + logging.info(f'Rank {rank}: Finished data collection for task {cfg.policy.task_id}') + + try: + logging.info(f'Rank {rank}: Waiting at post-collection barrier...') + dist.barrier() + logging.info(f'Rank {rank}: All ranks completed data collection, proceeding...') + except Exception as e: + logging.error(f'Rank {rank}: Post-collection barrier failed, error: {e}') + raise e + + # Check if there is enough data for training + local_not_enough_data = any( + replay_buffer.get_num_of_transitions() < cfgs[0].policy.total_batch_size / world_size + for replay_buffer in game_buffers + ) + logging.info(f"Rank {rank} local_not_enough_data:{local_not_enough_data}") + flag_tensor = torch.tensor(1.0 if local_not_enough_data else 0.0, device=cfg.policy.device) + dist.all_reduce(flag_tensor, op=dist.ReduceOp.MAX) + not_enough_data = (flag_tensor.item() > 0.5) + if rank == 0: + logging.info(f"Global not_enough_data status: {not_enough_data}") + + if not not_enough_data: + for i in range(update_per_collect): + train_data_multi_task = [] + envstep_multi_task = 0 + for idx, (cfg, collector, replay_buffer) in enumerate(zip(cfgs, collectors, game_buffers)): + envstep_multi_task += collector.envstep + if isinstance(cfg.policy.batch_size, (list, tuple)): + batch_size = cfg.policy.batch_size[cfg.policy.task_id] + elif isinstance(cfg.policy.batch_size, dict): + batch_size = cfg.policy.batch_size[cfg.policy.task_id] + else: + batch_size = cfg.policy.batch_size + if replay_buffer.get_num_of_transitions() > batch_size: + if cfg.policy.buffer_reanalyze_freq >= 1: + if i % reanalyze_interval == 0 and \ + replay_buffer.get_num_of_transitions() // cfg.policy.num_unroll_steps > int( + reanalyze_batch_size / cfg.policy.reanalyze_partition): + with timer: + replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy) + buffer_reanalyze_count += 1 + logging.info(f'Buffer reanalyze count: {buffer_reanalyze_count}') + logging.info(f'Buffer reanalyze time cost: {timer.value}') + + + train_data = replay_buffer.sample(batch_size, policy) + train_data.append(cfg.policy.task_id) + train_data_multi_task.append(train_data) + else: + logging.warning( + f'Not enough data in replay buffer to sample a mini-batch: ' + f'batch_size: {batch_size}, replay_buffer: {replay_buffer}' + ) + break + + if train_data_multi_task: + learn_kwargs = {'task_weights': None, "train_iter": learner.train_iter} + log_vars = learner.train(train_data_multi_task, envstep_multi_task, policy_kwargs=learn_kwargs) + + # Compute task_exploitation_weight if needed + if i == 0: + try: + dist.barrier() + if cfg.policy.use_task_exploitation_weight: + all_obs_loss = [None for _ in range(world_size)] + merged_obs_loss_task = {} + for cfg, replay_buffer in zip(cfgs, game_buffers): + task_id = cfg.policy.task_id + if f'noreduce_obs_loss_task{task_id}' in log_vars[0]: + merged_obs_loss_task[task_id] = log_vars[0][f'noreduce_obs_loss_task{task_id}'] + dist.all_gather_object(all_obs_loss, merged_obs_loss_task) + global_obs_loss_task = {} + for obs_loss_task in all_obs_loss: + if obs_loss_task: + global_obs_loss_task.update(obs_loss_task) + if global_obs_loss_task: + task_exploitation_weight = compute_task_weights( + global_obs_loss_task, + option="rank", + temperature=1, + ) + dist.broadcast_object_list([task_exploitation_weight], src=0) + print(f"Rank {rank}, task_exploitation_weight (by task_id): {task_exploitation_weight}") + else: + logging.warning(f"Rank {rank}: Failed to compute global obs_loss task weights, obs_loss data is empty.") + task_exploitation_weight = None + else: + task_exploitation_weight = None + learn_kwargs['task_weight'] = task_exploitation_weight + except Exception as e: + logging.error(f'Rank {rank}: Failed to synchronize task weights, error: {e}') + raise e + + if cfg.policy.use_priority: + for idx, (cfg, replay_buffer) in enumerate(zip(cfgs, game_buffers)): + task_id = cfg.policy.task_id + replay_buffer.update_priority( + train_data_multi_task[idx], + log_vars[0][f'value_priority_task{task_id}'] + ) + + current_priorities = log_vars[0][f'value_priority_task{task_id}'] + mean_priority = np.mean(current_priorities) + std_priority = np.std(current_priorities) + + alpha = 0.1 # smoothing factor + if f'running_mean_priority_task{task_id}' not in value_priority_tasks: + value_priority_tasks[f'running_mean_priority_task{task_id}'] = mean_priority + else: + value_priority_tasks[f'running_mean_priority_task{task_id}'] = ( + alpha * mean_priority + + (1 - alpha) * value_priority_tasks[f'running_mean_priority_task{task_id}'] + ) + + running_mean_priority = value_priority_tasks[f'running_mean_priority_task{task_id}'] + normalized_priorities = (current_priorities - running_mean_priority) / (std_priority + 1e-6) + + if cfg.policy.print_task_priority_logs: + print(f"Task {task_id} - Mean priority: {mean_priority:.8f}, " + f"Running mean priority: {running_mean_priority:.8f}, " + f"Std: {std_priority:.8f}") + + train_epoch += 1 + policy.recompute_pos_emb_diff_and_clear_cache() + + # Synchronize all ranks after training + try: + dist.barrier() + logging.info(f'Rank {rank}: passed training synchronization barrier') + except Exception as e: + logging.error(f'Rank {rank}: synchronization barrier failed, error: {e}') + break + + # Check termination conditions + try: + local_envsteps = [collector.envstep for collector in collectors] + total_envsteps = [None for _ in range(world_size)] + dist.all_gather_object(total_envsteps, local_envsteps) + + all_envsteps = torch.cat([torch.tensor(envsteps, device=cfg.policy.device) for envsteps in total_envsteps]) + max_envstep_reached = torch.all(all_envsteps >= max_env_step) + + global_train_iter = torch.tensor([learner.train_iter], device=cfg.policy.device) + all_train_iters = [torch.zeros_like(global_train_iter) for _ in range(world_size)] + dist.all_gather(all_train_iters, global_train_iter) + + max_train_iter_reached = torch.any(torch.stack(all_train_iters) >= max_train_iter) + + if max_envstep_reached.item() or max_train_iter_reached.item(): + logging.info(f'Rank {rank}: termination condition reached') + dist.barrier() + break + except Exception as e: + logging.error(f'Rank {rank}: termination check failed, error: {e}') + break + + learner.call_hook('after_run') + return policy \ No newline at end of file diff --git a/lzero/entry/utils.py b/lzero/entry/utils.py index dc8dacf0f..a6ed03571 100644 --- a/lzero/entry/utils.py +++ b/lzero/entry/utils.py @@ -608,8 +608,8 @@ def safe_eval( evaluator: Evaluator, learner: BaseLearner, collector: Collector, - rank: int, - world_size: int, + rank: int = 0, + world_size: int = 1, timeout: int = EVALUATION_TIMEOUT ) -> Tuple[Optional[bool], Optional[Any]]: """ diff --git a/lzero/mcts/buffer/game_segment.py b/lzero/mcts/buffer/game_segment.py index 6c2cd1999..a4e432828 100644 --- a/lzero/mcts/buffer/game_segment.py +++ b/lzero/mcts/buffer/game_segment.py @@ -72,7 +72,10 @@ def __init__(self, action_space: int, game_segment_length: int = 200, config: Ea # image obs input, e.g. atari environments self.zero_obs_shape = (config.model.image_channel, config.model.observation_shape_list[task_id][-2], config.model.observation_shape_list[task_id][-1]) else: - self.zero_obs_shape = (config.model.image_channel, config.model.observation_shape[-2], config.model.observation_shape[-1]) + if isinstance(config.model.observation_shape, int) or len(config.model.observation_shape) == 1: + self.zero_obs_shape = config.model.observation_shape + elif len(config.model.observation_shape) == 3: + self.zero_obs_shape = (config.model.image_channel, config.model.observation_shape[-2], config.model.observation_shape[-1]) self.obs_segment = [] self.action_segment = [] diff --git a/lzero/model/common.py b/lzero/model/common.py index 31f254963..d8fd154a7 100644 --- a/lzero/model/common.py +++ b/lzero/model/common.py @@ -505,8 +505,19 @@ def __init__(self, if tokenizer is None: self.tokenizer = AutoTokenizer.from_pretrained(model_path) - if tokenizer is not None: + if tokenizer is None: + # Only rank 0 downloads the tokenizer, and then other processes load it from cache. + if get_rank() == 0: + self.tokenizer = AutoTokenizer.from_pretrained(model_path) + if get_world_size() > 1: + torch.distributed.barrier() + if get_rank() != 0: + self.tokenizer = AutoTokenizer.from_pretrained(model_path) + else: self.tokenizer = tokenizer + + for p in self.pretrained_model.parameters(): + p.requires_grad_(False) self.embedding_size = embedding_size self.embed_proj_head = nn.Linear(self.pretrained_model.config.hidden_size, self.embedding_size) diff --git a/lzero/model/unizero_model.py b/lzero/model/unizero_model.py index b680a6e2d..2270bc05e 100644 --- a/lzero/model/unizero_model.py +++ b/lzero/model/unizero_model.py @@ -129,12 +129,12 @@ def __init__( else: raise ValueError(f"Unsupported encoder option: {kwargs['encoder_option']}") - self.tokenizer = Tokenizer(encoder=self.representation_network, decoder=self.decoder_network, decoder_network_tokenizer=self.decoder_network_tokenizer, - with_lpips=False, projection=projection, encoder_option=kwargs['encoder_option']) + self.tokenizer = Tokenizer(encoder=self.representation_network, decoder=self.decoder_network, + with_lpips=False, obs_type=world_model_cfg.obs_type) self.world_model = WorldModel(config=world_model_cfg, tokenizer=self.tokenizer) # --- Log parameter counts for analysis --- - self._log_model_parameters(obs_type) + self._log_model_parameters(world_model_cfg.obs_type) logging.info(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model') logging.info('==' * 20) diff --git a/lzero/model/unizero_model_multitask.py b/lzero/model/unizero_model_multitask.py index a7356d942..8832303d2 100644 --- a/lzero/model/unizero_model_multitask.py +++ b/lzero/model/unizero_model_multitask.py @@ -6,7 +6,8 @@ from easydict import EasyDict from .common import MZNetworkOutput, RepresentationNetworkUniZero, RepresentationNetworkMLP, LatentDecoder, \ - VectorDecoderForMemoryEnv, LatentEncoderForMemoryEnv, LatentDecoderForMemoryEnv, FeatureAndGradientHook + VectorDecoderForMemoryEnv, LatentEncoderForMemoryEnv, LatentDecoderForMemoryEnv, FeatureAndGradientHook, \ + HFLanguageRepresentationNetwork from .unizero_world_models.tokenizer import Tokenizer from .unizero_world_models.world_model_multitask import WorldModelMT from .vit import ViT, ViTConfig @@ -84,6 +85,8 @@ def __init__( self._init_image_components(world_model_cfg, observation_shape, num_res_blocks, num_channels, obs_act_embed_dim) elif obs_type == 'image_memory': self._init_image_memory_components(world_model_cfg) + elif obs_type == 'text': + self._init_text_components(world_model_cfg, encoder_url=kwargs['encoder_url']) else: raise ValueError(f"Unsupported observation type: {obs_type}") @@ -93,6 +96,20 @@ def __init__( # --- Log parameter counts for analysis --- self._log_model_parameters(obs_type) + def _init_text_components(self, world_model_cfg: EasyDict, encoder_url: str) -> None: + """Initializes components for 'text' observation type.""" + self.representation_network = HFLanguageRepresentationNetwork( + model_path=encoder_url, + embedding_size=world_model_cfg.embed_dim, + final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder + ) + self.tokenizer = Tokenizer( + encoder=self.representation_network, + decoder=None, + with_lpips=False, + obs_type=world_model_cfg.obs_type + ) + def _init_vector_components(self, world_model_cfg: EasyDict, obs_act_embed_dim: int) -> None: """Initializes components for 'vector' observation type.""" self.representation_network = RepresentationNetworkMLP( diff --git a/lzero/model/unizero_world_models/world_model_multitask.py b/lzero/model/unizero_world_models/world_model_multitask.py index 836a463c7..8a6541537 100644 --- a/lzero/model/unizero_world_models/world_model_multitask.py +++ b/lzero/model/unizero_world_models/world_model_multitask.py @@ -1708,7 +1708,13 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar dtype=batch['observations'].dtype) perceptual_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype) + + elif self.obs_type == 'text': + perceptual_loss = torch.tensor(0., device=batch['observations'].device, + dtype=torch.float32) + latent_recon_loss = self.latent_recon_loss + # Action tokens if self.continuous_action_space: act_tokens = batch['actions'] diff --git a/lzero/policy/unizero.py b/lzero/policy/unizero.py index 766012870..e44e321ca 100755 --- a/lzero/policy/unizero.py +++ b/lzero/policy/unizero.py @@ -1629,7 +1629,7 @@ def _reset_collect(self, env_id: int = None, current_steps: int = None, reset_in self._cfg.device, pad_token_id=self.pad_token_id ) - self.last_batch_action = [-1 for _ in range(self._cfg.collector_env_num)] + self.last_batch_action_collect = [-1 for _ in range(self._cfg.collector_env_num)] # We must handle both single int and list of ints for env_id. @@ -1707,7 +1707,7 @@ def _reset_eval(self, env_id: int = None, current_steps: int = None, reset_init_ ) logging.info(f'unizero.py task_id:{task_id} after _reset_eval: last_batch_obs_eval:', self.last_batch_obs_eval.shape) - self.last_batch_action = [-1 for _ in range(self._cfg.evaluator_env_num)] + self.last_batch_action_eval = [-1 for _ in range(self._cfg.evaluator_env_num)] # This logic handles the crucial end-of-episode cache clearing for evaluation. # The evaluator calls `_policy.reset([env_id])` when an episode is done. diff --git a/lzero/worker/muzero_collector.py b/lzero/worker/muzero_collector.py index 1e0a65845..006eaa5b1 100644 --- a/lzero/worker/muzero_collector.py +++ b/lzero/worker/muzero_collector.py @@ -126,7 +126,7 @@ def reset_policy(self, _policy: Optional[namedtuple] = None) -> None: self._logger.debug( f"Set default n_episode mode(n_episode({self._default_n_episode}), env_num({self._env_num}))" ) - self._policy.reset() + self._policy.reset(task_id=self.task_id) def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvManager] = None) -> None: """ @@ -359,7 +359,7 @@ def collect( chance_dict = {i: to_ndarray(init_obs[i]['chance']) for i in range(env_nums)} # Initialize game segments and observation stacks for each environment. - game_segments = [GameSegment(self._env.action_space, game_segment_length=self.policy_config.game_segment_length, config=self.policy_config) for _ in range(env_nums)] + game_segments = [GameSegment(self._env.action_space, game_segment_length=self.policy_config.game_segment_length, config=self.policy_config, task_id=self.task_id) for _ in range(env_nums)] observation_window_stack = [deque(maxlen=self.policy_config.model.frame_stack_num) for _ in range(env_nums)] for env_id in range(env_nums): for _ in range(self.policy_config.model.frame_stack_num): @@ -525,7 +525,7 @@ def collect( last_game_priorities[env_id] = priorities # Start a new game segment. - game_segments[env_id] = GameSegment(self._env.action_space, game_segment_length=self.policy_config.game_segment_length, config=self.policy_config) + game_segments[env_id] = GameSegment(self._env.action_space, game_segment_length=self.policy_config.game_segment_length, config=self.policy_config, task_id=self.task_id) game_segments[env_id].reset(observation_window_stack[env_id]) self._env_info[env_id]['step'] += 1 @@ -570,7 +570,7 @@ def collect( chance_dict[env_id] = to_ndarray(init_obs[env_id]['chance']) # Reset game segment and observation stack. - game_segments[env_id] = GameSegment(self._env.action_space, game_segment_length=self.policy_config.game_segment_length, config=self.policy_config) + game_segments[env_id] = GameSegment(self._env.action_space, game_segment_length=self.policy_config.game_segment_length, config=self.policy_config, task_id=self.task_id) observation_window_stack[env_id].clear() for _ in range(self.policy_config.model.frame_stack_num): observation_window_stack[env_id].append(init_obs[env_id]['observation']) @@ -585,7 +585,7 @@ def collect( completed_value_lst[env_id] = 0 # Reset policy and collector stats for the finished environment. - self._policy.reset([env_id]) + self._policy.reset([env_id], task_id=self.task_id) self._reset_stat(env_id) ready_env_id.remove(env_id) @@ -663,21 +663,18 @@ def _output_log(self, train_iter: int) -> None: self._episode_info.clear() - # Log to console - self._logger.info("Collector Training Summary:\n{}".format('\n'.join([f' {k}: {v}' for k, v in info.items()]))) - + self._logger.info(f"Collector log (rank {self._rank}, task_id {self.task_id}):\n" + '\n'.join([f'{k}: {v}' for k, v in info.items()])) # Log to TensorBoard and WandB for k, v in info.items(): + if k in ['each_reward']: + continue if self.task_id is None: - tb_prefix_iter = f'{self._instance_name}_iter/' - tb_prefix_step = f'{self._instance_name}_step/' + # Log for single-task setting + self._tb_logger.add_scalar(f'{self._instance_name}_iter/{k}', v, train_iter) + if k not in ['total_envstep_count', 'total_episode_count', 'total_duration']: + self._tb_logger.add_scalar(f'{self._instance_name}_step/{k}', v, self._total_envstep_count) else: - tb_prefix_iter = f'{self._instance_name}_iter_task{self.task_id}/' - tb_prefix_step = f'{self._instance_name}_step_task{self.task_id}/' - - self._tb_logger.add_scalar(tb_prefix_iter + k, v, train_iter) - self._tb_logger.add_scalar(tb_prefix_step + k, v, self._total_envstep_count) - - if self.policy_config.use_wandb: - wandb_log_data = {tb_prefix_step + k: v for k, v in info.items()} - wandb.log(wandb_log_data, step=self._total_envstep_count) + # Log for multi-task setting + self._tb_logger.add_scalar(f'{self._instance_name}_iter_task{self.task_id}/{k}', v, train_iter) + if k not in ['total_envstep_count', 'total_episode_count', 'total_duration']: + self._tb_logger.add_scalar(f'{self._instance_name}_step_task{self.task_id}/{k}', v, self._total_envstep_count) \ No newline at end of file diff --git a/zoo/jericho/configs/jericho_unizero_config.py b/zoo/jericho/configs/jericho_unizero_config.py index c1b34da37..8f02db973 100644 --- a/zoo/jericho/configs/jericho_unizero_config.py +++ b/zoo/jericho/configs/jericho_unizero_config.py @@ -135,7 +135,9 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e obs_type="text", env_num=max(collector_env_num, evaluator_env_num), decode_loss_mode=None, # Controls where to compute reconstruction loss: after_backbone, before_backbone, or None. - latent_recon_loss_weight=0.1 + latent_recon_loss_weight=0.1, + game_segment_length=50, + use_priority=False, ), ), update_per_collect=int(collector_env_num*max_steps*replay_ratio ), # Important for DDP diff --git a/zoo/jericho/configs/jericho_unizero_ddp_config.py b/zoo/jericho/configs/jericho_unizero_ddp_config.py index c204148b2..cf9fc33b1 100644 --- a/zoo/jericho/configs/jericho_unizero_ddp_config.py +++ b/zoo/jericho/configs/jericho_unizero_ddp_config.py @@ -141,7 +141,9 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e obs_type="text", # TODO: Modify as needed. env_num=max(collector_env_num, evaluator_env_num), decode_loss_mode=None, # Controls where to compute reconstruction loss: after_backbone, before_backbone, or None. - latent_recon_loss_weight=0.1 # TODO: decoder loss weight + latent_recon_loss_weight=0, # TODO: decoder loss weight + game_segment_length=50, + use_priority=False, ), ), # TODO diff --git a/zoo/jericho/configs/jericho_unizero_multitask_config.py b/zoo/jericho/configs/jericho_unizero_multitask_config.py new file mode 100644 index 000000000..ea3285445 --- /dev/null +++ b/zoo/jericho/configs/jericho_unizero_multitask_config.py @@ -0,0 +1,275 @@ +from easydict import EasyDict + +def create_config(env_id, max_steps, max_action_num, action_space_size, collector_env_num, evaluator_env_num, n_episode, + reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, + buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, total_batch_size, + num_layers, model_name, replay_ratio, norm_type, update_per_collect, + collect_num_simulations, eval_num_simulations): + return EasyDict(dict( + env=dict( + stop_value=int(1e6), + env_id=env_id, + observation_shape=512, + max_steps=max_steps, + max_action_num=max_action_num, + tokenizer_path=model_name, + max_seq_len=512, + game_path=f"./zoo/jericho/envs/z-machine-games-master/jericho-game-suite/{env_id}", + for_unizero=True, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False), + ), + policy=dict( + multi_gpu=False, # Very important for ddp + only_use_moco_stats=False, + collect_num_simulations=collect_num_simulations, + eval_num_simulations=eval_num_simulations, + use_moco=False, # Whether to use MoCo for multi-task gradient adjustments + grad_correct_params=dict( + MoCo_beta=0.5, MoCo_beta_sigma=0.5, MoCo_gamma=0.1, MoCo_gamma_sigma=0.5, MoCo_rho=0, + calpha=0.5, rescale=1, + ), + use_wandb=False, + learn=dict( + learner=dict( + hook=dict( + save_ckpt_after_iter=200000 + ), + ), + ), + total_task_num=len(env_id_list), + task_num=len(env_id_list), + task_id=0, + model=dict( + observation_shape=512, + action_space_size=action_space_size, + encoder_url=model_name, + model_type="mlp", + norm_type=norm_type, + continuous_action_space=False, + world_model_cfg=dict( + final_norm_option_in_obs_head='LayerNorm', + final_norm_option_in_encoder='LayerNorm', + predict_latent_loss_type='mse', + share_head=False, + policy_entropy_weight=5e-2, + continuous_action_space=False, + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, + context_length=2 * infer_context_length, + device="cuda", + action_space_size=action_space_size, + num_layers=num_layers, + num_heads=24, + obs_type="text", + env_num=max(collector_env_num, evaluator_env_num), + task_embed_option=None, + use_task_embed=False, + embed_dim=768, + task_num=len(env_id_list), + use_normal_head=True, + use_softmoe_head=False, + use_moe_head=False, + num_experts_in_moe_head=4, + + moe_in_transformer=False, + multiplication_moe_in_transformer=False, # Whether to use moe in transformers + n_shared_experts=1, + num_experts_per_tok=1, + num_experts_of_moe_in_transformer=8, + + moe_use_lora=False, # Does moe use lora + lora_r= 0, + lora_alpha =1, + lora_dropout= 0.0, + + analysis_dormant_ratio_weight_rank=False, + analysis_dormant_ratio_interval=5000, + game_segment_length=50, + use_priority=False, + ), + ), + optim_type='AdamW', + # (bool) 是否启用自适应策略熵权重 (alpha) + use_adaptive_entropy_weight=False, + + # (float) 自适应alpha优化器的学习率 + adaptive_entropy_alpha_lr=1e-4, + target_entropy_start_ratio =0.98, + # target_entropy_end_ratio =0.9, # TODO===== + # target_entropy_end_ratio =0.7, + # target_entropy_decay_steps = 100000, # 例如,在100k次迭代后达到最终值 + + target_entropy_end_ratio =0.5, # for action_space=18 + target_entropy_decay_steps = 100000, # 例如,在150k次迭代 300k envsteps后达到最终值 + # target_entropy_decay_steps = 150000, # 例如,在150k次迭代 300k envsteps后达到最终值 + + # ==================== START: Encoder-Clip Annealing Config ==================== + # (bool) 是否启用 encoder-clip 值的退火。 + use_encoder_clip_annealing=False, + # (str) 退火类型。可选 'linear' 或 'cosine'。 + encoder_clip_anneal_type='cosine', + # (float) 退火的起始 clip 值 (训练初期,较宽松)。 + encoder_clip_start_value=30.0, + # (float) 退火的结束 clip 值 (训练后期,较严格)。 + encoder_clip_end_value=10.0, + # (int) 完成从起始值到结束值的退火所需的训练迭代步数。 + encoder_clip_anneal_steps=30000, # 例如,在30k次迭代后达到最终值 + + + # ==================== START: label smooth ==================== + policy_ls_eps_start=0.05, #TODO============= good start in Pong and MsPacman + policy_ls_eps_end=0.01, + policy_ls_eps_decay_steps=50000, # 50k + label_smoothing_eps=0, #TODO============= for value + + # ==================== [新增] 范数监控频率 ==================== + # 每隔多少个训练迭代步数,监控一次模型参数的范数。设置为0则禁用。 + monitor_norm_freq=10000, + + + use_task_exploitation_weight=False, + task_complexity_weight=False, + total_batch_size=total_batch_size, + allocated_batch_sizes=False, + use_priority=False, + print_task_priority_logs=False, + cuda=True, + model_path=None, + num_unroll_steps=num_unroll_steps, + update_per_collect=update_per_collect, + action_type="varied_action_space", + replay_ratio=replay_ratio, + batch_size=batch_size, + reanalyze_ratio=reanalyze_ratio, + learning_rate=0.0001, + cos_lr_scheduler=False, + fixed_temperature_value=0.25, + manual_temperature_decay=False, + n_episode=n_episode, + train_start_after_envsteps=int(0), + replay_buffer_size=int(5e5), + eval_freq=int(5e3), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + ), + )) + + +def generate_configs(env_id_list, env_configurations, collector_env_num, n_episode, evaluator_env_num, + reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, + seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, + total_batch_size, num_layers, model_name, replay_ratio, norm_type, collect_num_simulations, eval_num_simulations): + """ + Overview: + Generates a list of configurations for all specified tasks. + + Arguments: + (See arguments for `create_config` function) + - seed (:obj:`int`): The random seed for the experiment. + + Returns: + - (:obj:`List[List[Union[int, List[EasyDict]]]]`): A list where each element contains a task_id + and its corresponding configuration objects. + """ + configs = [] + + exp_name_prefix = f'data_scalezero2/jericho_mt_moe8_{len(env_id_list)}games_tbs{total_batch_size}-nlayer{num_layers}-rr{replay_ratio}_not-share-head_encoder-final-ln_seed{seed}/' + + action_space_size_list = [v[0] for _, v in env_configurations.items()] + max_action_space_size = max(action_space_size_list) + + for task_id, env_id in enumerate(env_id_list): + _, max_steps = env_configurations.get(env_id, (10, 50)) + update_per_collect = 40 # Ensure at least one update per collect + + config = create_config( + env_id=env_id, max_steps=max_steps, max_action_num=max_action_space_size, action_space_size=max_action_space_size, + collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, n_episode=n_episode, + reanalyze_ratio=reanalyze_ratio, batch_size=batch_size, + num_unroll_steps=num_unroll_steps, infer_context_length=infer_context_length, buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, reanalyze_partition=reanalyze_partition, total_batch_size=total_batch_size, + num_layers=num_layers, model_name=model_name, replay_ratio=replay_ratio, + norm_type=norm_type, update_per_collect=update_per_collect, collect_num_simulations=collect_num_simulations, eval_num_simulations=eval_num_simulations, + ) + config.policy.task_id = task_id + config.exp_name = exp_name_prefix + f"{env_id.split('.z5')[0]}_seed{seed}" + configs.append([task_id, [config, create_env_manager()]]) + return configs + +def create_env_manager(): + return EasyDict(dict( + env=dict( + type='jericho', + import_names=['zoo.jericho.envs.jericho_env'], + ), + env_manager=dict(type='base'), + policy=dict( + type='unizero_multitask', + import_names=['lzero.policy.unizero_multitask'], + ), + )) + +if __name__ == "__main__": + """ + Overview: + This script should be executed with GPUs for distributed training. + + Example launch commands: + + cd /path/to/your/project/ + python zoo/jericho/configs/jericho_unizero_multitask_config.py + """ + + from lzero.entry import train_unizero_multitask + import os + os.environ["TOKENIZERS_PARALLELISM"] = "false" + + env_configurations = { + 'detective.z5': (12, 100), + 'omniquest.z5': (25, 100), + 'acorncourt.z5': (45, 50), + 'zork1.z5': (55, 500), + } + env_id_list = list(env_configurations.keys()) + + # Model name or path - configurable according to the predefined model paths or names + model_name: str = 'BAAI/bge-base-en-v1.5' + replay_ratio = 0.1 + norm_type = 'LN' + + collector_env_num = 4 + n_episode = 4 + evaluator_env_num = 2 + max_env_step = int(5e5) + reanalyze_ratio = 0.0 + + total_batch_size =int(64 * len(env_id_list)) + batch_size = [int(total_batch_size / len(env_id_list)) for _ in range(len(env_id_list))] + + num_layers=2 + num_unroll_steps = 10 + infer_context_length = 4 + buffer_reanalyze_freq = 1 / 1000000 + reanalyze_batch_size = 160 + reanalyze_partition = 0.75 + collect_num_simulations = 50 + eval_num_simulations = 50 + + for seed in [0]: + configs = generate_configs( env_id_list=env_id_list, env_configurations=env_configurations, + collector_env_num=collector_env_num, n_episode=n_episode, + evaluator_env_num=evaluator_env_num, + reanalyze_ratio=reanalyze_ratio, batch_size=batch_size, + num_unroll_steps=num_unroll_steps, infer_context_length=infer_context_length, + seed=seed, buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, reanalyze_partition=reanalyze_partition, + total_batch_size=total_batch_size, num_layers=num_layers, + model_name=model_name, replay_ratio=replay_ratio, norm_type=norm_type, + collect_num_simulations=collect_num_simulations, eval_num_simulations=eval_num_simulations) + train_unizero_multitask(configs, seed=seed, max_env_step=max_env_step) diff --git a/zoo/jericho/configs/jericho_unizero_multitask_ddp_config.py b/zoo/jericho/configs/jericho_unizero_multitask_ddp_config.py new file mode 100644 index 000000000..5d8660eda --- /dev/null +++ b/zoo/jericho/configs/jericho_unizero_multitask_ddp_config.py @@ -0,0 +1,288 @@ +from easydict import EasyDict + +def create_config(env_id, max_steps, max_action_num, action_space_size, collector_env_num, evaluator_env_num, n_episode, + reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, + buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, total_batch_size, + num_layers, model_name, replay_ratio, norm_type, update_per_collect, + collect_num_simulations, eval_num_simulations): + return EasyDict(dict( + env=dict( + stop_value=int(1e6), + env_id=env_id, + observation_shape=512, + max_steps=max_steps, + max_action_num=max_action_num, + tokenizer_path=model_name, + max_seq_len=512, + game_path=f"./zoo/jericho/envs/z-machine-games-master/jericho-game-suite/{env_id}", + for_unizero=True, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False), + ), + policy=dict( + multi_gpu=True, # Very important for ddp + only_use_moco_stats=False, + collect_num_simulations=collect_num_simulations, + eval_num_simulations=eval_num_simulations, + use_moco=False, # Whether to use MoCo for multi-task gradient adjustments + grad_correct_params=dict( + MoCo_beta=0.5, MoCo_beta_sigma=0.5, MoCo_gamma=0.1, MoCo_gamma_sigma=0.5, MoCo_rho=0, + calpha=0.5, rescale=1, + ), + use_wandb=False, + learn=dict( + learner=dict( + hook=dict( + save_ckpt_after_iter=200000 + ), + ), + ), + total_task_num=len(env_id_list), + task_num=len(env_id_list), + task_id=0, + model=dict( + observation_shape=512, + action_space_size=action_space_size, + encoder_url=model_name, + model_type="mlp", + norm_type=norm_type, + continuous_action_space=False, + world_model_cfg=dict( + final_norm_option_in_obs_head='LayerNorm', + final_norm_option_in_encoder='LayerNorm', + predict_latent_loss_type='mse', + share_head=False, + policy_entropy_weight=5e-2, + continuous_action_space=False, + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, + context_length=2 * infer_context_length, + device="cuda", + action_space_size=action_space_size, + num_layers=num_layers, + num_heads=24, + obs_type="text", + env_num=max(collector_env_num, evaluator_env_num), + task_embed_option=None, + use_task_embed=False, + embed_dim=768, + task_num=len(env_id_list), + use_normal_head=True, + use_softmoe_head=False, + use_moe_head=False, + num_experts_in_moe_head=4, + + moe_in_transformer=False, + multiplication_moe_in_transformer=False, # Whether to use moe in transformers + n_shared_experts=1, + num_experts_per_tok=1, + num_experts_of_moe_in_transformer=8, + + moe_use_lora=False, # Does moe use lora + lora_r= 0, + lora_alpha =1, + lora_dropout= 0.0, + + analysis_dormant_ratio_weight_rank=False, + analysis_dormant_ratio_interval=5000, + game_segment_length=50, + use_priority=False, + ), + ), + optim_type='AdamW', + # (bool) 是否启用自适应策略熵权重 (alpha) + use_adaptive_entropy_weight=False, + + # (float) 自适应alpha优化器的学习率 + adaptive_entropy_alpha_lr=1e-4, + target_entropy_start_ratio =0.98, + # target_entropy_end_ratio =0.9, # TODO===== + # target_entropy_end_ratio =0.7, + # target_entropy_decay_steps = 100000, # 例如,在100k次迭代后达到最终值 + + target_entropy_end_ratio =0.5, # for action_space=18 + target_entropy_decay_steps = 100000, # 例如,在150k次迭代 300k envsteps后达到最终值 + # target_entropy_decay_steps = 150000, # 例如,在150k次迭代 300k envsteps后达到最终值 + + # ==================== START: Encoder-Clip Annealing Config ==================== + # (bool) 是否启用 encoder-clip 值的退火。 + use_encoder_clip_annealing=False, + # (str) 退火类型。可选 'linear' 或 'cosine'。 + encoder_clip_anneal_type='cosine', + # (float) 退火的起始 clip 值 (训练初期,较宽松)。 + encoder_clip_start_value=30.0, + # (float) 退火的结束 clip 值 (训练后期,较严格)。 + encoder_clip_end_value=10.0, + # (int) 完成从起始值到结束值的退火所需的训练迭代步数。 + encoder_clip_anneal_steps=30000, # 例如,在30k次迭代后达到最终值 + + + # ==================== START: label smooth ==================== + policy_ls_eps_start=0.05, #TODO============= good start in Pong and MsPacman + policy_ls_eps_end=0.01, + policy_ls_eps_decay_steps=50000, # 50k + label_smoothing_eps=0, #TODO============= for value + + # ==================== [新增] 范数监控频率 ==================== + # 每隔多少个训练迭代步数,监控一次模型参数的范数。设置为0则禁用。 + monitor_norm_freq=10000, + + + use_task_exploitation_weight=False, + task_complexity_weight=False, + total_batch_size=total_batch_size, + allocated_batch_sizes=False, + use_priority=False, + print_task_priority_logs=False, + cuda=True, + model_path=None, + num_unroll_steps=num_unroll_steps, + update_per_collect=update_per_collect, + action_type="varied_action_space", + replay_ratio=replay_ratio, + batch_size=batch_size, + reanalyze_ratio=reanalyze_ratio, + learning_rate=0.0001, + cos_lr_scheduler=False, + fixed_temperature_value=0.25, + manual_temperature_decay=False, + n_episode=n_episode, + train_start_after_envsteps=int(0), + replay_buffer_size=int(5e5), + eval_freq=int(5e3), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + ), + )) + + +def generate_configs(env_id_list, env_configurations, collector_env_num, n_episode, evaluator_env_num, + reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, + seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, + total_batch_size, num_layers, model_name, replay_ratio, norm_type, collect_num_simulations, eval_num_simulations): + """ + Overview: + Generates a list of configurations for all specified tasks. + + Arguments: + (See arguments for `create_config` function) + - seed (:obj:`int`): The random seed for the experiment. + + Returns: + - (:obj:`List[List[Union[int, List[EasyDict]]]]`): A list where each element contains a task_id + and its corresponding configuration objects. + """ + configs = [] + exp_name_prefix = f'data_scalezero/jericho_ddp_mt_moe8_{len(env_id_list)}games_tbs{total_batch_size}-nlayer{num_layers}-rr{replay_ratio}_not-share-head_encoder-final-ln_seed{seed}/' + + action_space_size_list = [v[0] for _, v in env_configurations.items()] + max_action_space_size = max(action_space_size_list) + + for task_id, env_id in enumerate(env_id_list): + _, max_steps = env_configurations.get(env_id, (10, 50)) + update_per_collect = 40 # Ensure at least one update per collect + + config = create_config( + env_id=env_id, max_steps=max_steps, max_action_num=max_action_space_size, action_space_size=max_action_space_size, + collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, n_episode=n_episode, + reanalyze_ratio=reanalyze_ratio, batch_size=batch_size, + num_unroll_steps=num_unroll_steps, infer_context_length=infer_context_length, buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, reanalyze_partition=reanalyze_partition, total_batch_size=total_batch_size, + num_layers=num_layers, model_name=model_name, replay_ratio=replay_ratio, + norm_type=norm_type, update_per_collect=update_per_collect, collect_num_simulations=collect_num_simulations, eval_num_simulations=eval_num_simulations, + ) + config.policy.task_id = task_id + config.exp_name = exp_name_prefix + f"{env_id.split('.z5')[0]}_seed{seed}" + configs.append([task_id, [config, create_env_manager()]]) + return configs + +def create_env_manager(): + return EasyDict(dict( + env=dict( + type='jericho', + import_names=['zoo.jericho.envs.jericho_env'], + ), + env_manager=dict(type='base'), + policy=dict( + type='unizero_multitask', + import_names=['lzero.policy.unizero_multitask'], + ), + )) + +if __name__ == "__main__": + """ + Overview: + This script should be executed with GPUs for distributed training. + + Example launch commands: + + export CUDA_VISIBLE_DEVICES=0,1,2,3 + cd /path/to/your/project/ + + torchrun --nproc_per_node=4 zoo/jericho/configs/jericho_unizero_multitask_ddp_config.py + """ + + from lzero.entry import train_unizero_multitask_ddp + from ding.utils import DDPContext + import torch.distributed as dist + import os + os.environ["TOKENIZERS_PARALLELISM"] = "false" + + env_configurations = { + 'detective.z5': (12, 100), + 'omniquest.z5': (25, 100), + # 'acorncourt.z5': (45, 50), + # 'zork1.z5': (55, 500), + } + env_id_list = list(env_configurations.keys()) + + # Model name or path - configurable according to the predefined model paths or names + model_name: str = 'BAAI/bge-base-en-v1.5' + replay_ratio = 0.1 + norm_type = 'LN' + + collector_env_num = 4 + n_episode = 4 + evaluator_env_num = 2 + max_env_step = int(5e5) + reanalyze_ratio = 0.0 + + total_batch_size =int(64 * len(env_id_list)) + batch_size = [int(total_batch_size / len(env_id_list)) for _ in range(len(env_id_list))] + + num_layers=2 + num_unroll_steps = 10 + infer_context_length = 4 + buffer_reanalyze_freq = 1 / 1000000 + reanalyze_batch_size = 160 + reanalyze_partition = 0.75 + collect_num_simulations = 50 + eval_num_simulations = 50 + + + # Set NCCL timeout to prevent watchdog hang due to unbalanced data collection speeds + os.environ.setdefault('NCCL_TIMEOUT', '480') # 60 minutes in seconds + os.environ.setdefault('NCCL_BLOCKING_WAIT', '1') + + for seed in [0]: + configs = generate_configs( env_id_list=env_id_list, env_configurations=env_configurations, + collector_env_num=collector_env_num, n_episode=n_episode, + evaluator_env_num=evaluator_env_num, + reanalyze_ratio=reanalyze_ratio, batch_size=batch_size, + num_unroll_steps=num_unroll_steps, infer_context_length=infer_context_length, + seed=seed, buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, reanalyze_partition=reanalyze_partition, + total_batch_size=total_batch_size, num_layers=num_layers, + model_name=model_name, replay_ratio=replay_ratio, norm_type=norm_type, + collect_num_simulations=collect_num_simulations, eval_num_simulations=eval_num_simulations) + + with DDPContext(): + train_unizero_multitask_ddp(configs, seed=seed, max_env_step=max_env_step) + print(f"Seed: {seed} training finished!") + if dist.is_initialized(): + dist.destroy_process_group() \ No newline at end of file From 228e5ce4c7d70bb8e653482e34a7e2c05de3fdcc Mon Sep 17 00:00:00 2001 From: xiongjyu Date: Thu, 29 Jan 2026 22:48:30 +0800 Subject: [PATCH 2/9] fix a bug in ddp setting for jericho-mt env --- lzero/model/common.py | 6 ------ zoo/jericho/configs/jericho_unizero_multitask_ddp_config.py | 2 +- 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/lzero/model/common.py b/lzero/model/common.py index d8fd154a7..4569090ba 100644 --- a/lzero/model/common.py +++ b/lzero/model/common.py @@ -498,12 +498,6 @@ def __init__(self, torch.distributed.barrier() if get_rank() != 0: self.pretrained_model = AutoModel.from_pretrained(model_path) - - if get_rank() != 0: - logging.info(f"Worker process is loading model from cache: {model_path}") - self.model = AutoModel.from_pretrained(model_path) - if tokenizer is None: - self.tokenizer = AutoTokenizer.from_pretrained(model_path) if tokenizer is None: # Only rank 0 downloads the tokenizer, and then other processes load it from cache. diff --git a/zoo/jericho/configs/jericho_unizero_multitask_ddp_config.py b/zoo/jericho/configs/jericho_unizero_multitask_ddp_config.py index 5d8660eda..deaa060a9 100644 --- a/zoo/jericho/configs/jericho_unizero_multitask_ddp_config.py +++ b/zoo/jericho/configs/jericho_unizero_multitask_ddp_config.py @@ -266,7 +266,7 @@ def create_env_manager(): # Set NCCL timeout to prevent watchdog hang due to unbalanced data collection speeds - os.environ.setdefault('NCCL_TIMEOUT', '480') # 60 minutes in seconds + os.environ.setdefault('NCCL_TIMEOUT', '480') os.environ.setdefault('NCCL_BLOCKING_WAIT', '1') for seed in [0]: From 154e002399dd0a0a37f351fc8ca1b3544fb076dd Mon Sep 17 00:00:00 2001 From: xiongjyu Date: Fri, 30 Jan 2026 14:09:47 +0800 Subject: [PATCH 3/9] fix a bug resulting in the frequent evaluation --- lzero/entry/train_unizero_multitask.py | 2 +- lzero/entry/train_unizero_multitask_ddp.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/lzero/entry/train_unizero_multitask.py b/lzero/entry/train_unizero_multitask.py index 0529a5450..cafa2965b 100644 --- a/lzero/entry/train_unizero_multitask.py +++ b/lzero/entry/train_unizero_multitask.py @@ -206,7 +206,7 @@ def train_unizero_multitask( if update_per_collect is None: update_per_collect = 40 - if learner.train_iter > 0 or evaluator.should_eval(learner.train_iter): # only for debug + if learner.train_iter > 0 and evaluator.should_eval(learner.train_iter): # only for debug print(f'Evaluating task_id: {current_task_id}...') # Reset evaluator policy state evaluator._policy.reset(reset_init_data=True, task_id=current_task_id) diff --git a/lzero/entry/train_unizero_multitask_ddp.py b/lzero/entry/train_unizero_multitask_ddp.py index fb28d935f..0eb1f6d9a 100644 --- a/lzero/entry/train_unizero_multitask_ddp.py +++ b/lzero/entry/train_unizero_multitask_ddp.py @@ -245,7 +245,7 @@ def train_unizero_multitask_ddp( ) collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep) - if learner.train_iter > 10 or evaluator.should_eval(learner.train_iter): + if learner.train_iter > 0 and evaluator.should_eval(learner.train_iter): print('=' * 20) print(f'Rank {rank} evaluating task_id: {cfg.policy.task_id}...') evaluator._policy.reset(reset_init_data=True, task_id=cfg.policy.task_id) From e5b2096724e5fa2b7fce2d9e8bb5f2ebc22be3fb Mon Sep 17 00:00:00 2001 From: xiongjyu Date: Sat, 31 Jan 2026 12:34:53 +0800 Subject: [PATCH 4/9] fix a bug when outputing the collect log --- lzero/worker/muzero_collector.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/lzero/worker/muzero_collector.py b/lzero/worker/muzero_collector.py index 006eaa5b1..a1d49bc24 100644 --- a/lzero/worker/muzero_collector.py +++ b/lzero/worker/muzero_collector.py @@ -628,9 +628,6 @@ def _output_log(self, train_iter: int) -> None: Arguments: - train_iter (:obj:`int`): The current training iteration number, used as the logging step. """ - if self._rank != 0: - return - if (train_iter - self._last_train_iter) >= self._collect_print_freq and len(self._episode_info) > 0: self._last_train_iter = train_iter episode_count = len(self._episode_info) From e5306de047473e41722aefa7b43bc392da24d7a7 Mon Sep 17 00:00:00 2001 From: xiongjyu Date: Mon, 2 Feb 2026 12:02:22 +0800 Subject: [PATCH 5/9] fix a small bug --- lzero/worker/muzero_collector.py | 2 +- .../configs/jericho_unizero_multitask_config.py | 2 +- .../configs/jericho_unizero_multitask_ddp_config.py | 11 ++++++----- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/lzero/worker/muzero_collector.py b/lzero/worker/muzero_collector.py index a1d49bc24..f7956e45b 100644 --- a/lzero/worker/muzero_collector.py +++ b/lzero/worker/muzero_collector.py @@ -87,7 +87,7 @@ def __init__( self._logger, _ = build_logger( path=f'./{self._exp_name}/log/{self._instance_name}', name=self._instance_name, need_tb=False ) - self._tb_logger = None + self._tb_logger = tb_logger self.policy_config = policy_config self.collect_with_pure_policy = self.policy_config.collect_with_pure_policy diff --git a/zoo/jericho/configs/jericho_unizero_multitask_config.py b/zoo/jericho/configs/jericho_unizero_multitask_config.py index ea3285445..c2ad4029f 100644 --- a/zoo/jericho/configs/jericho_unizero_multitask_config.py +++ b/zoo/jericho/configs/jericho_unizero_multitask_config.py @@ -241,7 +241,7 @@ def create_env_manager(): # Model name or path - configurable according to the predefined model paths or names model_name: str = 'BAAI/bge-base-en-v1.5' replay_ratio = 0.1 - norm_type = 'LN' + norm_type = 'BN' collector_env_num = 4 n_episode = 4 diff --git a/zoo/jericho/configs/jericho_unizero_multitask_ddp_config.py b/zoo/jericho/configs/jericho_unizero_multitask_ddp_config.py index deaa060a9..d0b8272d5 100644 --- a/zoo/jericho/configs/jericho_unizero_multitask_ddp_config.py +++ b/zoo/jericho/configs/jericho_unizero_multitask_ddp_config.py @@ -236,15 +236,15 @@ def create_env_manager(): env_configurations = { 'detective.z5': (12, 100), 'omniquest.z5': (25, 100), - # 'acorncourt.z5': (45, 50), - # 'zork1.z5': (55, 500), + 'acorncourt.z5': (45, 50), + 'zork1.z5': (55, 500), } env_id_list = list(env_configurations.keys()) # Model name or path - configurable according to the predefined model paths or names model_name: str = 'BAAI/bge-base-en-v1.5' replay_ratio = 0.1 - norm_type = 'LN' + norm_type = 'BN' collector_env_num = 4 n_episode = 4 @@ -266,8 +266,9 @@ def create_env_manager(): # Set NCCL timeout to prevent watchdog hang due to unbalanced data collection speeds - os.environ.setdefault('NCCL_TIMEOUT', '480') - os.environ.setdefault('NCCL_BLOCKING_WAIT', '1') + os.environ['NCCL_TIMEOUT'] = '600' + os.environ['NCCL_BLOCKING_WAIT'] = '1' + os.environ['NCCL_ASYNC_ERROR_HANDLING'] = '1' for seed in [0]: configs = generate_configs( env_id_list=env_id_list, env_configurations=env_configurations, From e11dcfeb61e533af6df4a036095b5bd5109cca4f Mon Sep 17 00:00:00 2001 From: xiongjyu Date: Mon, 2 Feb 2026 22:00:55 +0800 Subject: [PATCH 6/9] polish saved cfg_name --- zoo/jericho/configs/jericho_unizero_multitask_config.py | 2 +- zoo/jericho/configs/jericho_unizero_multitask_ddp_config.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/zoo/jericho/configs/jericho_unizero_multitask_config.py b/zoo/jericho/configs/jericho_unizero_multitask_config.py index c2ad4029f..4460b9a18 100644 --- a/zoo/jericho/configs/jericho_unizero_multitask_config.py +++ b/zoo/jericho/configs/jericho_unizero_multitask_config.py @@ -179,7 +179,7 @@ def generate_configs(env_id_list, env_configurations, collector_env_num, n_episo """ configs = [] - exp_name_prefix = f'data_scalezero2/jericho_mt_moe8_{len(env_id_list)}games_tbs{total_batch_size}-nlayer{num_layers}-rr{replay_ratio}_not-share-head_encoder-final-ln_seed{seed}/' + exp_name_prefix = f'data_scalezero/jericho_mt_{len(env_id_list)}games_tbs{total_batch_size}-nlayer{num_layers}-rr{replay_ratio}_not-share-head_encoder-final-ln_seed{seed}/' action_space_size_list = [v[0] for _, v in env_configurations.items()] max_action_space_size = max(action_space_size_list) diff --git a/zoo/jericho/configs/jericho_unizero_multitask_ddp_config.py b/zoo/jericho/configs/jericho_unizero_multitask_ddp_config.py index d0b8272d5..01d0faacd 100644 --- a/zoo/jericho/configs/jericho_unizero_multitask_ddp_config.py +++ b/zoo/jericho/configs/jericho_unizero_multitask_ddp_config.py @@ -178,7 +178,7 @@ def generate_configs(env_id_list, env_configurations, collector_env_num, n_episo and its corresponding configuration objects. """ configs = [] - exp_name_prefix = f'data_scalezero/jericho_ddp_mt_moe8_{len(env_id_list)}games_tbs{total_batch_size}-nlayer{num_layers}-rr{replay_ratio}_not-share-head_encoder-final-ln_seed{seed}/' + exp_name_prefix = f'data_scalezero/jericho_ddp_mt_{len(env_id_list)}games_tbs{total_batch_size}-nlayer{num_layers}-rr{replay_ratio}_not-share-head_encoder-final-ln_seed{seed}/' action_space_size_list = [v[0] for _, v in env_configurations.items()] max_action_space_size = max(action_space_size_list) From ee602b5cfd51720717aa959f01cbefdf85024595 Mon Sep 17 00:00:00 2001 From: xiongjyu Date: Thu, 5 Feb 2026 00:12:34 +0800 Subject: [PATCH 7/9] tmp --- lzero/entry/train_unizero_multitask.py | 1 + lzero/entry/train_unizero_multitask_ddp.py | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/lzero/entry/train_unizero_multitask.py b/lzero/entry/train_unizero_multitask.py index cafa2965b..336ef20f1 100644 --- a/lzero/entry/train_unizero_multitask.py +++ b/lzero/entry/train_unizero_multitask.py @@ -233,6 +233,7 @@ def train_unizero_multitask( collector._policy.reset(reset_init_data=True, task_id=current_task_id) new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) + logging.info(f'Finished data collection for task {cfg.policy.task_id}, collected {len(new_data[0]) if new_data else 0} segments') replay_buffer.push_game_segments(new_data) replay_buffer.remove_oldest_data_to_fit() diff --git a/lzero/entry/train_unizero_multitask_ddp.py b/lzero/entry/train_unizero_multitask_ddp.py index 0eb1f6d9a..896a30904 100644 --- a/lzero/entry/train_unizero_multitask_ddp.py +++ b/lzero/entry/train_unizero_multitask_ddp.py @@ -222,7 +222,9 @@ def train_unizero_multitask_ddp( # Perform data collection and evaluation for each task on this rank for idx, (cfg, collector, evaluator, replay_buffer) in enumerate( zip(cfgs, collectors, evaluators, game_buffers)): - + + policy_config = cfg.policy + # Log buffer memory usage log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger, cfg.policy.task_id) From 924eaed40403eeb88b6b39326fbb9beb19d961bc Mon Sep 17 00:00:00 2001 From: xiongjyu Date: Thu, 5 Feb 2026 01:09:07 +0800 Subject: [PATCH 8/9] delete unused config --- .../configs/jericho_unizero_multitask_config.py | 12 +++++------- .../configs/jericho_unizero_multitask_ddp_config.py | 12 +++++------- 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/zoo/jericho/configs/jericho_unizero_multitask_config.py b/zoo/jericho/configs/jericho_unizero_multitask_config.py index 4460b9a18..65b59ac4e 100644 --- a/zoo/jericho/configs/jericho_unizero_multitask_config.py +++ b/zoo/jericho/configs/jericho_unizero_multitask_config.py @@ -3,7 +3,7 @@ def create_config(env_id, max_steps, max_action_num, action_space_size, collector_env_num, evaluator_env_num, n_episode, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, total_batch_size, - num_layers, model_name, replay_ratio, norm_type, update_per_collect, + num_layers, model_name, replay_ratio, update_per_collect, collect_num_simulations, eval_num_simulations): return EasyDict(dict( env=dict( @@ -47,7 +47,6 @@ def create_config(env_id, max_steps, max_action_num, action_space_size, collecto action_space_size=action_space_size, encoder_url=model_name, model_type="mlp", - norm_type=norm_type, continuous_action_space=False, world_model_cfg=dict( final_norm_option_in_obs_head='LayerNorm', @@ -164,7 +163,7 @@ def create_config(env_id, max_steps, max_action_num, action_space_size, collecto def generate_configs(env_id_list, env_configurations, collector_env_num, n_episode, evaluator_env_num, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, - total_batch_size, num_layers, model_name, replay_ratio, norm_type, collect_num_simulations, eval_num_simulations): + total_batch_size, num_layers, model_name, replay_ratio, collect_num_simulations, eval_num_simulations): """ Overview: Generates a list of configurations for all specified tasks. @@ -195,7 +194,7 @@ def generate_configs(env_id_list, env_configurations, collector_env_num, n_episo num_unroll_steps=num_unroll_steps, infer_context_length=infer_context_length, buffer_reanalyze_freq=buffer_reanalyze_freq, reanalyze_batch_size=reanalyze_batch_size, reanalyze_partition=reanalyze_partition, total_batch_size=total_batch_size, num_layers=num_layers, model_name=model_name, replay_ratio=replay_ratio, - norm_type=norm_type, update_per_collect=update_per_collect, collect_num_simulations=collect_num_simulations, eval_num_simulations=eval_num_simulations, + update_per_collect=update_per_collect, collect_num_simulations=collect_num_simulations, eval_num_simulations=eval_num_simulations, ) config.policy.task_id = task_id config.exp_name = exp_name_prefix + f"{env_id.split('.z5')[0]}_seed{seed}" @@ -241,7 +240,6 @@ def create_env_manager(): # Model name or path - configurable according to the predefined model paths or names model_name: str = 'BAAI/bge-base-en-v1.5' replay_ratio = 0.1 - norm_type = 'BN' collector_env_num = 4 n_episode = 4 @@ -270,6 +268,6 @@ def create_env_manager(): seed=seed, buffer_reanalyze_freq=buffer_reanalyze_freq, reanalyze_batch_size=reanalyze_batch_size, reanalyze_partition=reanalyze_partition, total_batch_size=total_batch_size, num_layers=num_layers, - model_name=model_name, replay_ratio=replay_ratio, norm_type=norm_type, - collect_num_simulations=collect_num_simulations, eval_num_simulations=eval_num_simulations) + model_name=model_name, replay_ratio=replay_ratio,collect_num_simulations=collect_num_simulations, + eval_num_simulations=eval_num_simulations) train_unizero_multitask(configs, seed=seed, max_env_step=max_env_step) diff --git a/zoo/jericho/configs/jericho_unizero_multitask_ddp_config.py b/zoo/jericho/configs/jericho_unizero_multitask_ddp_config.py index 01d0faacd..fdc9b8412 100644 --- a/zoo/jericho/configs/jericho_unizero_multitask_ddp_config.py +++ b/zoo/jericho/configs/jericho_unizero_multitask_ddp_config.py @@ -3,7 +3,7 @@ def create_config(env_id, max_steps, max_action_num, action_space_size, collector_env_num, evaluator_env_num, n_episode, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, total_batch_size, - num_layers, model_name, replay_ratio, norm_type, update_per_collect, + num_layers, model_name, replay_ratio, update_per_collect, collect_num_simulations, eval_num_simulations): return EasyDict(dict( env=dict( @@ -47,7 +47,6 @@ def create_config(env_id, max_steps, max_action_num, action_space_size, collecto action_space_size=action_space_size, encoder_url=model_name, model_type="mlp", - norm_type=norm_type, continuous_action_space=False, world_model_cfg=dict( final_norm_option_in_obs_head='LayerNorm', @@ -164,7 +163,7 @@ def create_config(env_id, max_steps, max_action_num, action_space_size, collecto def generate_configs(env_id_list, env_configurations, collector_env_num, n_episode, evaluator_env_num, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, - total_batch_size, num_layers, model_name, replay_ratio, norm_type, collect_num_simulations, eval_num_simulations): + total_batch_size, num_layers, model_name, replay_ratio, collect_num_simulations, eval_num_simulations): """ Overview: Generates a list of configurations for all specified tasks. @@ -194,7 +193,7 @@ def generate_configs(env_id_list, env_configurations, collector_env_num, n_episo num_unroll_steps=num_unroll_steps, infer_context_length=infer_context_length, buffer_reanalyze_freq=buffer_reanalyze_freq, reanalyze_batch_size=reanalyze_batch_size, reanalyze_partition=reanalyze_partition, total_batch_size=total_batch_size, num_layers=num_layers, model_name=model_name, replay_ratio=replay_ratio, - norm_type=norm_type, update_per_collect=update_per_collect, collect_num_simulations=collect_num_simulations, eval_num_simulations=eval_num_simulations, + update_per_collect=update_per_collect, collect_num_simulations=collect_num_simulations, eval_num_simulations=eval_num_simulations, ) config.policy.task_id = task_id config.exp_name = exp_name_prefix + f"{env_id.split('.z5')[0]}_seed{seed}" @@ -244,7 +243,6 @@ def create_env_manager(): # Model name or path - configurable according to the predefined model paths or names model_name: str = 'BAAI/bge-base-en-v1.5' replay_ratio = 0.1 - norm_type = 'BN' collector_env_num = 4 n_episode = 4 @@ -279,8 +277,8 @@ def create_env_manager(): seed=seed, buffer_reanalyze_freq=buffer_reanalyze_freq, reanalyze_batch_size=reanalyze_batch_size, reanalyze_partition=reanalyze_partition, total_batch_size=total_batch_size, num_layers=num_layers, - model_name=model_name, replay_ratio=replay_ratio, norm_type=norm_type, - collect_num_simulations=collect_num_simulations, eval_num_simulations=eval_num_simulations) + model_name=model_name, replay_ratio=replay_ratio, collect_num_simulations=collect_num_simulations, + eval_num_simulations=eval_num_simulations) with DDPContext(): train_unizero_multitask_ddp(configs, seed=seed, max_env_step=max_env_step) From 1b4f4875994fc155065b2775be7ec97e91a0fd5a Mon Sep 17 00:00:00 2001 From: xiongjyu Date: Mon, 9 Feb 2026 20:39:58 +0800 Subject: [PATCH 9/9] Standard format and delete unused configs --- lzero/entry/__init__.py | 27 +------------- lzero/entry/train_unizero_multitask_ddp.py | 5 +-- zoo/jericho/configs/jericho_unizero_config.py | 2 +- .../configs/jericho_unizero_ddp_config.py | 10 ++---- .../jericho_unizero_multitask_config.py | 36 +++++-------------- .../jericho_unizero_multitask_ddp_config.py | 36 +++++-------------- 6 files changed, 24 insertions(+), 92 deletions(-) diff --git a/lzero/entry/__init__.py b/lzero/entry/__init__.py index be0a145e0..aa0c1909c 100644 --- a/lzero/entry/__init__.py +++ b/lzero/entry/__init__.py @@ -15,29 +15,4 @@ from .train_unizero_multitask_balance_segment_ddp import train_unizero_multitask_balance_segment_ddp from .train_unizero_with_loss_landscape import train_unizero_with_loss_landscape from .train_unizero_multitask import train_unizero_multitask -from .train_unizero_multitask_ddp import train_unizero_multitask_ddp - -# from .utils import ( -# symlog, -# inv_symlog, -# initialize_zeros_batch, -# freeze_non_lora_parameters, -# compute_task_weights, -# TemperatureScheduler, -# tasks_per_stage, -# compute_unizero_mt_normalized_stats, -# allocate_batch_size, -# is_ddp_enabled, -# ddp_synchronize, -# ddp_all_reduce_sum, -# calculate_update_per_collect, -# initialize_pad_batch, -# random_collect, -# convert_to_batch_for_unizero, -# create_unizero_loss_metrics, -# UniZeroDataLoader, -# log_module_trainable_status, -# log_param_statistics, -# log_buffer_memory_usage, -# log_buffer_run_time, -# ) +from .train_unizero_multitask_ddp import train_unizero_multitask_ddp \ No newline at end of file diff --git a/lzero/entry/train_unizero_multitask_ddp.py b/lzero/entry/train_unizero_multitask_ddp.py index 896a30904..faf7c6ce0 100644 --- a/lzero/entry/train_unizero_multitask_ddp.py +++ b/lzero/entry/train_unizero_multitask_ddp.py @@ -45,10 +45,7 @@ def train_unizero_multitask_ddp( ) -> 'Policy': """ Overview: - Entry point for UniZero training. The goal is to improve the planning ability - of reinforcement learning agents by addressing the limitations of MuZero-like - algorithms in environments that require capturing long-term dependencies. - For more details, refer to https://arxiv.org/abs/2406.10667. + Entry point for UniZero multi-task training (DDP version). Args: - input_cfg_list (:obj:`List[Tuple[int, Tuple[dict, dict]]]`): Configuration list for different tasks. - seed (:obj:`int`): Random seed. diff --git a/zoo/jericho/configs/jericho_unizero_config.py b/zoo/jericho/configs/jericho_unizero_config.py index 8f02db973..82d2f01d3 100644 --- a/zoo/jericho/configs/jericho_unizero_config.py +++ b/zoo/jericho/configs/jericho_unizero_config.py @@ -109,7 +109,7 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e ), ), ), - accumulation_steps=1, # TODO: Accumulated gradient steps (currently default) + accumulation_steps=1, model=dict( observation_shape=512, action_space_size=action_space_size, diff --git a/zoo/jericho/configs/jericho_unizero_ddp_config.py b/zoo/jericho/configs/jericho_unizero_ddp_config.py index cf9fc33b1..64f4a63ca 100644 --- a/zoo/jericho/configs/jericho_unizero_ddp_config.py +++ b/zoo/jericho/configs/jericho_unizero_ddp_config.py @@ -34,9 +34,6 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e else: raise ValueError(f"Unsupported encoder option: {encoder_option}") - # TODO - # batch_size = batch_size * 2 - # ------------------------------------------------------------------ # Base environment parameters (Note: these values might be adjusted for different env_id) # ------------------------------------------------------------------ @@ -115,7 +112,7 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e ), ), ), - accumulation_steps=accumulation_steps, # TODO: Accumulated gradient steps (currently default) + accumulation_steps=accumulation_steps, model=dict( observation_shape=512, action_space_size=action_space_size, @@ -138,7 +135,7 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e num_layers=num_layers, num_heads=24, embed_dim=embed_dim, - obs_type="text", # TODO: Modify as needed. + obs_type="text", env_num=max(collector_env_num, evaluator_env_num), decode_loss_mode=None, # Controls where to compute reconstruction loss: after_backbone, before_backbone, or None. latent_recon_loss_weight=0, # TODO: decoder loss weight @@ -146,7 +143,6 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e use_priority=False, ), ), - # TODO update_per_collect=int(collector_env_num*max_steps*replay_ratio*accumulation_steps), # Important for DDP action_type="varied_action_space", model_path=None, @@ -161,7 +157,7 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e num_simulations=num_simulations, n_episode=n_episode, - train_start_after_envsteps=0, # TODO: Adjust training start trigger if needed. + train_start_after_envsteps=0, replay_buffer_size=int(5e5), eval_freq=int(3e4), collector_env_num=collector_env_num, diff --git a/zoo/jericho/configs/jericho_unizero_multitask_config.py b/zoo/jericho/configs/jericho_unizero_multitask_config.py index 65b59ac4e..8ab654915 100644 --- a/zoo/jericho/configs/jericho_unizero_multitask_config.py +++ b/zoo/jericho/configs/jericho_unizero_multitask_config.py @@ -91,43 +91,25 @@ def create_config(env_id, max_steps, max_action_num, action_space_size, collecto ), ), optim_type='AdamW', - # (bool) 是否启用自适应策略熵权重 (alpha) + + # (bool) whether enable adaptive policy entropy weights use_adaptive_entropy_weight=False, - - # (float) 自适应alpha优化器的学习率 + # (float) learning rate of adaptive alpha optimizer adaptive_entropy_alpha_lr=1e-4, target_entropy_start_ratio =0.98, - # target_entropy_end_ratio =0.9, # TODO===== - # target_entropy_end_ratio =0.7, - # target_entropy_decay_steps = 100000, # 例如,在100k次迭代后达到最终值 - - target_entropy_end_ratio =0.5, # for action_space=18 - target_entropy_decay_steps = 100000, # 例如,在150k次迭代 300k envsteps后达到最终值 - # target_entropy_decay_steps = 150000, # 例如,在150k次迭代 300k envsteps后达到最终值 + target_entropy_end_ratio =0.5, + target_entropy_decay_steps = 100000, - # ==================== START: Encoder-Clip Annealing Config ==================== - # (bool) 是否启用 encoder-clip 值的退火。 use_encoder_clip_annealing=False, - # (str) 退火类型。可选 'linear' 或 'cosine'。 encoder_clip_anneal_type='cosine', - # (float) 退火的起始 clip 值 (训练初期,较宽松)。 encoder_clip_start_value=30.0, - # (float) 退火的结束 clip 值 (训练后期,较严格)。 encoder_clip_end_value=10.0, - # (int) 完成从起始值到结束值的退火所需的训练迭代步数。 - encoder_clip_anneal_steps=30000, # 例如,在30k次迭代后达到最终值 + encoder_clip_anneal_steps=30000, - - # ==================== START: label smooth ==================== - policy_ls_eps_start=0.05, #TODO============= good start in Pong and MsPacman + policy_ls_eps_start=0.05, policy_ls_eps_end=0.01, - policy_ls_eps_decay_steps=50000, # 50k - label_smoothing_eps=0, #TODO============= for value - - # ==================== [新增] 范数监控频率 ==================== - # 每隔多少个训练迭代步数,监控一次模型参数的范数。设置为0则禁用。 - monitor_norm_freq=10000, - + policy_ls_eps_decay_steps=50000, + label_smoothing_eps=0, use_task_exploitation_weight=False, task_complexity_weight=False, diff --git a/zoo/jericho/configs/jericho_unizero_multitask_ddp_config.py b/zoo/jericho/configs/jericho_unizero_multitask_ddp_config.py index fdc9b8412..5aa480cb9 100644 --- a/zoo/jericho/configs/jericho_unizero_multitask_ddp_config.py +++ b/zoo/jericho/configs/jericho_unizero_multitask_ddp_config.py @@ -91,42 +91,24 @@ def create_config(env_id, max_steps, max_action_num, action_space_size, collecto ), ), optim_type='AdamW', - # (bool) 是否启用自适应策略熵权重 (alpha) + # (bool) whether enable adaptive policy entropy weights use_adaptive_entropy_weight=False, - - # (float) 自适应alpha优化器的学习率 + # (float) learning rate of adaptive alpha optimizer adaptive_entropy_alpha_lr=1e-4, target_entropy_start_ratio =0.98, - # target_entropy_end_ratio =0.9, # TODO===== - # target_entropy_end_ratio =0.7, - # target_entropy_decay_steps = 100000, # 例如,在100k次迭代后达到最终值 - - target_entropy_end_ratio =0.5, # for action_space=18 - target_entropy_decay_steps = 100000, # 例如,在150k次迭代 300k envsteps后达到最终值 - # target_entropy_decay_steps = 150000, # 例如,在150k次迭代 300k envsteps后达到最终值 + target_entropy_end_ratio =0.5, + target_entropy_decay_steps = 100000, - # ==================== START: Encoder-Clip Annealing Config ==================== - # (bool) 是否启用 encoder-clip 值的退火。 use_encoder_clip_annealing=False, - # (str) 退火类型。可选 'linear' 或 'cosine'。 encoder_clip_anneal_type='cosine', - # (float) 退火的起始 clip 值 (训练初期,较宽松)。 encoder_clip_start_value=30.0, - # (float) 退火的结束 clip 值 (训练后期,较严格)。 encoder_clip_end_value=10.0, - # (int) 完成从起始值到结束值的退火所需的训练迭代步数。 - encoder_clip_anneal_steps=30000, # 例如,在30k次迭代后达到最终值 - - - # ==================== START: label smooth ==================== - policy_ls_eps_start=0.05, #TODO============= good start in Pong and MsPacman + encoder_clip_anneal_steps=30000, + + policy_ls_eps_start=0.05, policy_ls_eps_end=0.01, - policy_ls_eps_decay_steps=50000, # 50k - label_smoothing_eps=0, #TODO============= for value - - # ==================== [新增] 范数监控频率 ==================== - # 每隔多少个训练迭代步数,监控一次模型参数的范数。设置为0则禁用。 - monitor_norm_freq=10000, + policy_ls_eps_decay_steps=50000, + label_smoothing_eps=0, use_task_exploitation_weight=False,