import concurrent.futures
import logging
import os
from functools import partial
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from tensorboardX import SummaryWriter

from ding.config import compile_config
from ding.envs import create_env_manager, get_vec_env_setting
from ding.policy import create_policy, Policy
from ding.utils import set_pkg_seed, EasyTimer
from ding.worker import BaseLearner
from lzero.entry.utils import (
    EVALUATION_TIMEOUT,
    TemperatureScheduler,
    allocate_batch_size,
    compute_task_weights,
    compute_unizero_mt_normalized_stats,
    log_buffer_memory_usage,
    safe_eval,
    symlog,
    inv_symlog,
)
from lzero.policy import visit_count_temperature
from lzero.worker import MuZeroEvaluator as Evaluator
from lzero.worker import MuZeroCollector as Collector

# Shared timer used to measure how long each buffer reanalyze pass takes.
timer = EasyTimer()


def train_unizero_multitask(
        input_cfg_list: List[Tuple[int, Tuple[dict, dict]]],
        seed: int = 0,
        model: Optional[torch.nn.Module] = None,
        model_path: Optional[str] = None,
        max_train_iter: Optional[int] = int(1e10),
        max_env_step: Optional[int] = int(1e10),
) -> 'Policy':
    """
    Overview:
        Entry point for UniZero multi-task training (non-DDP version). One shared policy and learner \
        are created from the first task's config; every task gets its own environments, collector, \
        evaluator and replay buffer. Each outer iteration collects (and periodically evaluates) every \
        task, then runs ``update_per_collect`` joint updates once every buffer holds at least one batch.
    Args:
        - input_cfg_list (:obj:`List[Tuple[int, Tuple[dict, dict]]]`): Configuration list for different tasks, \
            each item being ``(task_id, [cfg, create_cfg])``.
        - seed (:obj:`int`): Random seed. Task ``t`` is seeded with ``seed + t``.
        - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
        - model_path (:obj:`Optional[str]`): Path to the pre-trained model.
        - max_train_iter (:obj:`Optional[int]`): Maximum number of policy update iterations.
        - max_env_step (:obj:`Optional[int]`): Maximum number of collected environment interaction steps.
    Returns:
        - policy (:obj:`Policy`): The trained policy, or ``None`` when ``input_cfg_list`` is empty.
    """
    # Handle all tasks in a single process
    tasks = input_cfg_list
    print(f"Handling all {len(tasks)} tasks in a single process.")

    # Ensure at least one task is provided
    if not tasks:
        logging.error("No task configurations provided. Training cannot proceed.")
        return None

    cfgs = []
    game_buffers = []
    collector_envs = []
    evaluator_envs = []
    collectors = []
    evaluators = []

    # Use the first task's configuration to create the shared policy and learner
    _, (cfg_first, create_cfg_first) = tasks[0]

    assert create_cfg_first.policy.type in ['unizero_multitask', 'sampled_unizero_multitask'], "train_unizero_multitask entry currently only supports 'unizero_multitask' or 'sampled_unizero_multitask'"

    # The explicit ``else`` keeps the failure clear even when asserts are stripped (``python -O``).
    if create_cfg_first.policy.type == 'unizero_multitask':
        from lzero.mcts import UniZeroGameBuffer as GameBuffer
    elif create_cfg_first.policy.type == 'sampled_unizero_multitask':
        from lzero.mcts import SampledUniZeroGameBuffer as GameBuffer
    else:
        raise NotImplementedError(f"Policy type {create_cfg_first.policy.type} not fully supported for GameBuffer import.")

    cfg_first.policy.device = 'cuda' if cfg_first.policy.cuda and torch.cuda.is_available() else 'cpu'
    logging.info(f'Using device: {cfg_first.policy.device}')

    # Compile the main config (only for creating policy and learner);
    # per-task configs are compiled again inside the task loop below.
    compiled_cfg_first = compile_config(
        cfg_first, seed=seed, env=None, auto=True, create_cfg=create_cfg_first, save_cfg=True
    )

    # Create shared policy
    policy = create_policy(compiled_cfg_first.policy, model=model, enable_field=['learn', 'collect', 'eval'])

    if model_path is not None:
        logging.info(f'Loading pretrained model: {model_path}')
        policy.learn_mode.load_state_dict(torch.load(model_path, map_location=compiled_cfg_first.policy.device))
        logging.info(f'Finished loading model: {model_path}')

    log_dir = os.path.join('./{}/log/'.format(compiled_cfg_first.exp_name), 'serial')
    tb_logger = SummaryWriter(log_dir)

    # Create shared learner
    learner = BaseLearner(
        compiled_cfg_first.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=compiled_cfg_first.exp_name
    )

    # Build the per-task environments, buffer, collector and evaluator
    for task_id, (cfg, create_cfg) in tasks:
        current_seed = seed + task_id
        cfg.policy.device = 'cuda' if cfg.policy.cuda and torch.cuda.is_available() else 'cpu'
        cfg = compile_config(cfg, seed=current_seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)
        policy_config = cfg.policy
        policy_config.task_id = task_id  # explicitly set task_id
        policy.collect_mode.get_attribute('cfg').n_episode = policy_config.n_episode
        policy.eval_mode.get_attribute('cfg').n_episode = policy_config.n_episode

        # Create environments
        env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
        collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
        evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
        collector_env.seed(current_seed)
        evaluator_env.seed(current_seed, dynamic_seed=False)
        set_pkg_seed(current_seed, use_cuda=cfg.policy.cuda)

        # Create buffer, collector, and evaluator
        replay_buffer = GameBuffer(policy_config)
        collector = Collector(
            env=collector_env,
            policy=policy.collect_mode,
            tb_logger=tb_logger,
            exp_name=cfg.exp_name,
            policy_config=policy_config,
            task_id=task_id
        )
        evaluator = Evaluator(
            eval_freq=cfg.policy.eval_freq,
            n_evaluator_episode=cfg.env.n_evaluator_episode,
            stop_value=cfg.env.stop_value,
            env=evaluator_env,
            policy=policy.eval_mode,
            tb_logger=tb_logger,
            exp_name=cfg.exp_name,
            policy_config=policy_config,
            task_id=task_id
        )

        cfgs.append(cfg)
        game_buffers.append(replay_buffer)
        collector_envs.append(collector_env)
        evaluator_envs.append(evaluator_env)
        collectors.append(collector)
        evaluators.append(evaluator)

    learner.call_hook('before_run')
    value_priority_tasks = {}

    buffer_reanalyze_count = 0
    train_epoch = 0

    reanalyze_batch_size = compiled_cfg_first.policy.reanalyze_batch_size
    update_per_collect = compiled_cfg_first.policy.update_per_collect

    # Latest evaluation return per task; ``inf`` marks a failed or timed-out evaluation.
    task_rewards = {}

    while True:
        # ------------------------------------------------------------------
        # Phase 1: evaluation and data collection, task by task
        # ------------------------------------------------------------------
        for cfg, collector, evaluator, replay_buffer in zip(cfgs, collectors, evaluators, game_buffers):
            current_task_id = cfg.policy.task_id
            policy_config = cfg.policy

            # Log buffer memory usage
            log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger, current_task_id)

            collect_kwargs = {
                'temperature': visit_count_temperature(
                    policy_config.manual_temperature_decay,
                    policy_config.fixed_temperature_value,
                    policy_config.threshold_training_steps_for_final_temperature,
                    trained_steps=learner.train_iter
                ),
                'epsilon': 0.0
            }
            update_per_collect = policy_config.update_per_collect
            if update_per_collect is None:
                update_per_collect = 40

            if learner.train_iter > 0 and evaluator.should_eval(learner.train_iter):
                print(f'Evaluating task_id: {current_task_id}...')
                # Reset evaluator policy state
                evaluator._policy.reset(reset_init_data=True, task_id=current_task_id)

                # Perform safe evaluation (non-DDP version)
                stop, reward = safe_eval(evaluator, learner, collector)
                if stop is None or reward is None:
                    print(f"Evaluation failed or timed out, task_id: {current_task_id}, continuing training...")
                    task_rewards[current_task_id] = float('inf')
                else:
                    try:
                        eval_mean_reward = reward.get('eval_episode_return_mean', float('inf'))
                        print(f"Evaluation reward for task {current_task_id}: {eval_mean_reward}")
                        task_rewards[current_task_id] = eval_mean_reward
                    except Exception as e:
                        print(f"Error extracting reward for task {current_task_id}: {e}")
                        task_rewards[current_task_id] = float('inf')

            print('=' * 20)
            print(f'Starting data collection for task_id: {current_task_id}...')
            print(f'cfg.policy.task_id={current_task_id}')

            # Reset collector policy state
            collector._policy.reset(reset_init_data=True, task_id=current_task_id)

            new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
            logging.info(f'Finished data collection for task {cfg.policy.task_id}, collected {len(new_data[0]) if new_data else 0} segments')

            replay_buffer.push_game_segments(new_data)
            replay_buffer.remove_oldest_data_to_fit()

            # Sparse reanalyze (0 < freq < 1): once every ``1 / freq`` training epochs.
            # Frequencies >= 1 are handled inside the update loop below; restricting this branch to
            # freq < 1 also keeps ``int(1 / freq)`` from becoming 0.
            reanalyze_freq = policy_config.buffer_reanalyze_freq
            if 0 < reanalyze_freq < 1 and train_epoch > 0 and \
                    train_epoch % int(1 / reanalyze_freq) == 0 and \
                    replay_buffer.get_num_of_transitions() // policy_config.num_unroll_steps > int(reanalyze_batch_size / policy_config.reanalyze_partition):
                with timer:
                    replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy)
                buffer_reanalyze_count += 1
                logging.info(f'Buffer reanalyze count: {buffer_reanalyze_count}, time cost: {timer.value}')

            logging.info(f'Finished data collection for task {current_task_id}')

        # ------------------------------------------------------------------
        # Phase 2: joint training, only once every buffer can supply a batch
        # ------------------------------------------------------------------
        not_enough_data = any(
            task_buffer.get_num_of_transitions() < policy._cfg.batch_size[task_cfg.policy.task_id]
            for task_cfg, task_buffer in zip(cfgs, game_buffers)
        )
        task_weights = None

        if not not_enough_data:
            for i in range(update_per_collect):
                train_data_multi_task = []
                # (cfg, buffer, batch) for every task that contributed a batch this step. Kept
                # explicitly because skipped tasks break positional alignment with ``cfgs``.
                sampled_tasks = []
                envstep_this_epoch = 0

                for cfg, collector, replay_buffer in zip(cfgs, collectors, game_buffers):
                    current_task_id = cfg.policy.task_id
                    current_batch_size = policy._cfg.batch_size[current_task_id]

                    if current_batch_size == 0:
                        logging.warning(f"Task {current_task_id} batch_size is 0, skipping sampling.")
                        continue

                    if replay_buffer.get_num_of_transitions() < current_batch_size:
                        logging.warning(
                            f'Not enough data for task {current_task_id}: '
                            f'batch_size: {current_batch_size}, buffer: {replay_buffer.get_num_of_transitions()}'
                        )
                        continue

                    policy_config = cfg.policy
                    if policy_config.buffer_reanalyze_freq >= 1 and update_per_collect is not None and update_per_collect > 0:
                        # Clamp to 1 so a frequency larger than update_per_collect cannot cause ``i % 0``.
                        reanalyze_interval = max(1, update_per_collect // policy_config.buffer_reanalyze_freq)
                        if i % reanalyze_interval == 0 and \
                                replay_buffer.get_num_of_transitions() // policy_config.num_unroll_steps > int(reanalyze_batch_size / policy_config.reanalyze_partition):
                            with timer:
                                replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy)
                            buffer_reanalyze_count += 1
                            logging.info(f'Buffer reanalyze count: {buffer_reanalyze_count}, time cost: {timer.value}')

                    train_data = replay_buffer.sample(current_batch_size, policy)
                    train_data.append(current_task_id)
                    train_data_multi_task.append(train_data)
                    sampled_tasks.append((cfg, replay_buffer, train_data))
                    envstep_this_epoch += collector.envstep

                if not train_data_multi_task:
                    continue

                learn_kwargs = {'task_weights': task_weights, "train_iter": learner.train_iter}
                log_vars = learner.train(train_data_multi_task, envstep_this_epoch, policy_kwargs=learn_kwargs)

                if not compiled_cfg_first.policy.use_priority:
                    continue
                if not log_vars:
                    logging.warning("log_vars is empty, cannot update priorities.")
                    continue

                for task_cfg, task_buffer, task_batch in sampled_tasks:
                    task_id = task_cfg.policy.task_id
                    priority_key = f'value_priority_task{task_id}'
                    if priority_key not in log_vars[0]:
                        logging.warning(f"Priority key '{priority_key}' not found for task {task_id} in log_vars[0]")
                        continue
                    try:
                        current_priorities = log_vars[0][priority_key]
                        task_buffer.update_priority(task_batch, current_priorities)

                        mean_priority = np.mean(current_priorities)
                        std_priority = np.std(current_priorities)
                        alpha = 0.1  # smoothing factor of the running mean
                        running_mean_key = f'running_mean_priority_task{task_id}'
                        if running_mean_key not in value_priority_tasks:
                            value_priority_tasks[running_mean_key] = mean_priority
                        else:
                            value_priority_tasks[running_mean_key] = (
                                alpha * mean_priority +
                                (1 - alpha) * value_priority_tasks[running_mean_key]
                            )
                        running_mean_priority = value_priority_tasks[running_mean_key]
                        # Read the flag from this task's own config.
                        if task_cfg.policy.print_task_priority_logs:
                            print(f"Task {task_id} - Mean priority: {mean_priority:.8f}, "
                                  f"Running mean priority: {running_mean_priority:.8f}, "
                                  f"Std: {std_priority:.8f}")
                    except Exception as e:
                        # Best effort: a failed priority update must not abort training.
                        logging.error(f"Error updating priority for task {task_id}: {e}")

        train_epoch += 1

        # Check termination conditions
        local_max_envstep = max(collector.envstep for collector in collectors) if collectors else 0
        max_envstep_reached = local_max_envstep >= max_env_step
        max_train_iter_reached = learner.train_iter >= max_train_iter

        if max_envstep_reached or max_train_iter_reached:
            logging.info(f'Termination condition reached: env_step ({local_max_envstep}/{max_env_step}) or train_iter ({learner.train_iter}/{max_train_iter})')
            break

        if hasattr(policy, 'recompute_pos_emb_diff_and_clear_cache'):
            policy.recompute_pos_emb_diff_and_clear_cache()

    learner.call_hook('after_run')
    return policy
def train_unizero_multitask_ddp(
        input_cfg_list: List[Tuple[int, Tuple[dict, dict]]],
        seed: int = 0,
        model: Optional[torch.nn.Module] = None,
        model_path: Optional[str] = None,
        max_train_iter: Optional[int] = int(1e10),
        max_env_step: Optional[int] = int(1e10),
) -> 'Policy':
    """
    Overview:
        Entry point for UniZero multi-task training (DDP version). The task list is split into \
        contiguous slices, one per rank. Each rank builds one shared policy/learner plus per-task \
        environments, collector, evaluator and replay buffer, then alternates collection and joint \
        updates, synchronising with the other ranks through ``torch.distributed`` collectives.
    Args:
        - input_cfg_list (:obj:`List[Tuple[int, Tuple[dict, dict]]]`): Configuration list for different tasks, \
            each item being ``(task_id, [cfg, create_cfg])``.
        - seed (:obj:`int`): Random seed.
        - model (:obj:`Optional[torch.nn.Module]`): An instance of torch.nn.Module.
        - model_path (:obj:`Optional[str]`): Path to the pretrained model checkpoint file.
        - max_train_iter (:obj:`Optional[int]`): Maximum number of policy update iterations during training.
        - max_env_step (:obj:`Optional[int]`): Maximum number of collected environment interaction steps.
    Returns:
        - policy (:obj:`Policy`): The trained policy of this rank.
    Raises:
        - ValueError: If there are fewer tasks than ranks, which would leave a rank with no task.
    """
    # DDP-only dependencies are imported locally so the entry point is self-contained.
    import torch.distributed as dist
    from ding.rl_utils import get_epsilon_greedy_fn
    from ding.utils import get_rank, get_world_size

    rank = get_rank()
    world_size = get_world_size()

    # Task partitioning: the first ``remainder`` ranks take one extra task.
    total_tasks = len(input_cfg_list)
    if total_tasks < world_size:
        # A rank without tasks cannot build a policy/learner and would stall every collective
        # below. Every rank evaluates this same condition, so all of them fail fast together.
        raise ValueError(
            f"Rank {rank}: {total_tasks} task(s) cannot be distributed over {world_size} ranks; "
            f"launch at most one rank per task."
        )
    tasks_per_rank, remainder = divmod(total_tasks, world_size)
    if rank < remainder:
        start_idx = rank * (tasks_per_rank + 1)
        num_tasks_for_this_rank = tasks_per_rank + 1
    else:
        start_idx = rank * tasks_per_rank + remainder
        num_tasks_for_this_rank = tasks_per_rank
    end_idx = start_idx + num_tasks_for_this_rank

    tasks_for_this_rank = input_cfg_list[start_idx:end_idx]
    print(f"Rank {rank}/{world_size}, handling tasks {start_idx} to {end_idx - 1}")

    cfgs = []
    game_buffers = []
    collector_envs = []
    evaluator_envs = []
    collectors = []
    evaluators = []

    # Use the first task's config to create a shared policy
    task_id, [cfg, create_cfg] = tasks_for_this_rank[0]

    for _, (task_cfg, _) in tasks_for_this_rank:
        task_cfg.policy.task_num = num_tasks_for_this_rank

    assert create_cfg.policy.type in ['unizero_multitask',
                                      'sampled_unizero_multitask'], "train_unizero entry 目前仅支持 'unizero_multitask'"

    # The explicit ``else`` keeps GameBuffer from being unbound when asserts are stripped (``python -O``).
    if create_cfg.policy.type == 'unizero_multitask':
        from lzero.mcts import UniZeroGameBuffer as GameBuffer
    elif create_cfg.policy.type == 'sampled_unizero_multitask':
        from lzero.mcts import SampledUniZeroGameBuffer as GameBuffer
    else:
        raise NotImplementedError(f"Policy type {create_cfg.policy.type} not fully supported for GameBuffer import.")

    cfg.policy.device = cfg.policy.model.world_model_cfg.device if torch.cuda.is_available() else 'cpu'
    logging.info(f'Configured device: {cfg.policy.device}')

    cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)
    # Create shared policy
    policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval'])
    print(f"rank {rank} created the policy!")
    if model_path is not None:
        logging.info(f'Loading pretrained model: {model_path}')
        policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device))
        logging.info(f'Finished loading pretrained model: {model_path}')

    log_dir = os.path.join('./{}/log'.format(cfg.exp_name), f'serial_rank_{rank}')
    tb_logger = SummaryWriter(log_dir)

    # Create shared learner
    learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
    policy_config = cfg.policy

    # Handle each task assigned to this rank
    for task_id, (cfg, create_cfg) in tasks_for_this_rank:
        cfg.policy.device = 'cuda' if cfg.policy.cuda and torch.cuda.is_available() else 'cpu'
        cfg = compile_config(cfg, seed=seed + task_id, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)
        policy_config = cfg.policy
        policy.collect_mode.get_attribute('cfg').n_episode = policy_config.n_episode
        policy.eval_mode.get_attribute('cfg').n_episode = policy_config.n_episode

        # Create environments
        env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
        collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
        evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
        collector_env.seed(cfg.seed + task_id)
        evaluator_env.seed(cfg.seed + task_id, dynamic_seed=False)
        set_pkg_seed(cfg.seed + task_id, use_cuda=cfg.policy.cuda)

        # Create game buffer, collector, and evaluator
        replay_buffer = GameBuffer(policy_config)
        collector = Collector(
            env=collector_env,
            policy=policy.collect_mode,
            tb_logger=tb_logger,
            exp_name=cfg.exp_name,
            policy_config=policy_config,
            task_id=task_id
        )
        evaluator = Evaluator(
            eval_freq=cfg.policy.eval_freq,
            n_evaluator_episode=cfg.env.n_evaluator_episode,
            stop_value=cfg.env.stop_value,
            env=evaluator_env,
            policy=policy.eval_mode,
            tb_logger=tb_logger,
            exp_name=cfg.exp_name,
            policy_config=policy_config,
            task_id=task_id
        )

        cfgs.append(cfg)
        replay_buffer.batch_size = cfg.policy.batch_size[task_id]

        game_buffers.append(replay_buffer)
        collector_envs.append(collector_env)
        evaluator_envs.append(evaluator_env)
        collectors.append(collector)
        evaluators.append(evaluator)

    learner.call_hook('before_run')
    value_priority_tasks = {}

    buffer_reanalyze_count = 0
    train_epoch = 0
    reanalyze_batch_size = cfg.policy.reanalyze_batch_size
    update_per_collect = cfg.policy.update_per_collect

    task_exploitation_weight = None

    # Latest evaluation return per task: {task_id: reward}; ``inf`` marks a failed evaluation.
    task_rewards = {}

    while True:
        # Dynamically adjust batch_size
        if cfg.policy.allocated_batch_sizes:
            clip_scale = np.clip(1 + (3 * train_epoch / 1000), 1, 4)
            allocated_batch_sizes = allocate_batch_size(cfgs, game_buffers, alpha=1.0, clip_scale=clip_scale)
            if rank == 0:
                print("Allocated batch_sizes: ", allocated_batch_sizes)
            for idx, cfg in enumerate(cfgs):
                task_id = cfg.policy.task_id
                if isinstance(allocated_batch_sizes, dict):
                    cfg.policy.batch_size = allocated_batch_sizes.get(task_id, cfg.policy.batch_size)
                elif isinstance(allocated_batch_sizes, list):
                    # Positional lookup; keep the previous value when the list is too short.
                    cfg.policy.batch_size = allocated_batch_sizes[idx] if idx < len(allocated_batch_sizes) else cfg.policy.batch_size
                else:
                    logging.warning(f"Unexpected type for allocated_batch_sizes: {type(allocated_batch_sizes)}")
            # Also update the policy config (use the full list for compatibility)
            policy._cfg.batch_size = allocated_batch_sizes

        # Perform data collection and evaluation for each task on this rank
        for cfg, collector, evaluator, replay_buffer in zip(cfgs, collectors, evaluators, game_buffers):
            policy_config = cfg.policy

            # Log buffer memory usage
            log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger, cfg.policy.task_id)

            collect_kwargs = {
                'temperature': visit_count_temperature(
                    policy_config.manual_temperature_decay,
                    policy_config.fixed_temperature_value,
                    policy_config.threshold_training_steps_for_final_temperature,
                    trained_steps=learner.train_iter
                ),
                'epsilon': 0.0
            }

            if policy_config.eps.eps_greedy_exploration_in_collect:
                epsilon_greedy_fn = get_epsilon_greedy_fn(
                    start=policy_config.eps.start,
                    end=policy_config.eps.end,
                    decay=policy_config.eps.decay,
                    type_=policy_config.eps.type
                )
                collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep)

            if learner.train_iter > 0 and evaluator.should_eval(learner.train_iter):
                print('=' * 20)
                print(f'Rank {rank} evaluating task_id: {cfg.policy.task_id}...')
                evaluator._policy.reset(reset_init_data=True, task_id=cfg.policy.task_id)

                # Perform safe evaluation
                stop, reward = safe_eval(evaluator, learner, collector, rank, world_size)
                if stop is None or reward is None:
                    print(f"Rank {rank} encountered an issue during evaluation, continuing training...")
                    task_rewards[cfg.policy.task_id] = float('inf')
                else:
                    try:
                        eval_mean_reward = reward.get('eval_episode_return_mean', float('inf'))
                        print(f"Evaluation reward for task {cfg.policy.task_id}: {eval_mean_reward}")
                        task_rewards[cfg.policy.task_id] = eval_mean_reward
                    except Exception as e:
                        print(f"Error extracting evaluation reward: {e}")
                        task_rewards[cfg.policy.task_id] = float('inf')

            print('=' * 20)
            print(f'Starting data collection for Rank {rank}, task_id: {cfg.policy.task_id}...')
            print(f'Rank {rank}: cfg.policy.task_id={cfg.policy.task_id} ')

            # Reset policy state before each collection (important for multi-task setups)
            collector._policy.reset(reset_init_data=True, task_id=cfg.policy.task_id)
            new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
            logging.info(f'Rank {rank}: Finished data collection for task {cfg.policy.task_id}, collected {len(new_data[0]) if new_data else 0} segments')

            replay_buffer.push_game_segments(new_data)
            replay_buffer.remove_oldest_data_to_fit()

            # Sparse reanalyze (0 < freq < 1): once every ``1 / freq`` training epochs.
            # Frequencies >= 1 are handled inside the update loop; a frequency of 0 disables
            # reanalyze instead of raising ZeroDivisionError on ``1 / 0``.
            reanalyze_freq = cfg.policy.buffer_reanalyze_freq
            if 0 < reanalyze_freq < 1 and train_epoch > 0 and \
                    train_epoch % int(1 / reanalyze_freq) == 0 and \
                    replay_buffer.get_num_of_transitions() // cfg.policy.num_unroll_steps > int(
                        reanalyze_batch_size / cfg.policy.reanalyze_partition):
                with timer:
                    replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy)
                buffer_reanalyze_count += 1
                logging.info(f'Buffer reanalyze count: {buffer_reanalyze_count}')
                logging.info(f'Buffer reanalyze time cost: {timer.value}')

            logging.info(f'Rank {rank}: Finished data collection for task {cfg.policy.task_id}')

        try:
            logging.info(f'Rank {rank}: Waiting at post-collection barrier...')
            dist.barrier()
            logging.info(f'Rank {rank}: All ranks completed data collection, proceeding...')
        except Exception as e:
            logging.error(f'Rank {rank}: Post-collection barrier failed, error: {e}')
            raise

        # Check if there is enough data for training. The decision is made globally (MAX over
        # ranks) so that either every rank enters the update loop or none does.
        local_not_enough_data = any(
            replay_buffer.get_num_of_transitions() < cfgs[0].policy.total_batch_size / world_size
            for replay_buffer in game_buffers
        )
        logging.info(f"Rank {rank} local_not_enough_data:{local_not_enough_data}")
        flag_tensor = torch.tensor(1.0 if local_not_enough_data else 0.0, device=cfg.policy.device)
        dist.all_reduce(flag_tensor, op=dist.ReduceOp.MAX)
        not_enough_data = (flag_tensor.item() > 0.5)
        if rank == 0:
            logging.info(f"Global not_enough_data status: {not_enough_data}")

        if not not_enough_data:
            for i in range(update_per_collect):
                train_data_multi_task = []
                envstep_multi_task = 0
                for cfg, collector, replay_buffer in zip(cfgs, collectors, game_buffers):
                    envstep_multi_task += collector.envstep
                    batch_size = cfg.policy.batch_size
                    if isinstance(batch_size, (list, tuple, dict)):
                        batch_size = batch_size[cfg.policy.task_id]

                    if replay_buffer.get_num_of_transitions() <= batch_size:
                        logging.warning(
                            f'Not enough data in replay buffer to sample a mini-batch: '
                            f'batch_size: {batch_size}, replay_buffer: {replay_buffer}'
                        )
                        # Stop sampling: the batches gathered so far form a prefix of ``cfgs``.
                        break

                    if cfg.policy.buffer_reanalyze_freq >= 1:
                        # Clamp to 1 so a frequency larger than update_per_collect cannot cause ``i % 0``.
                        reanalyze_interval = max(1, update_per_collect // cfg.policy.buffer_reanalyze_freq)
                        if i % reanalyze_interval == 0 and \
                                replay_buffer.get_num_of_transitions() // cfg.policy.num_unroll_steps > int(
                                    reanalyze_batch_size / cfg.policy.reanalyze_partition):
                            with timer:
                                replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy)
                            buffer_reanalyze_count += 1
                            logging.info(f'Buffer reanalyze count: {buffer_reanalyze_count}')
                            logging.info(f'Buffer reanalyze time cost: {timer.value}')

                    train_data = replay_buffer.sample(batch_size, policy)
                    train_data.append(cfg.policy.task_id)
                    train_data_multi_task.append(train_data)

                if train_data_multi_task:
                    learn_kwargs = {'task_weights': None, "train_iter": learner.train_iter}
                    log_vars = learner.train(train_data_multi_task, envstep_multi_task, policy_kwargs=learn_kwargs)

                    # Compute task_exploitation_weight if needed
                    if i == 0:
                        try:
                            dist.barrier()
                            if cfg.policy.use_task_exploitation_weight:
                                all_obs_loss = [None for _ in range(world_size)]
                                merged_obs_loss_task = {}
                                for cfg, replay_buffer in zip(cfgs, game_buffers):
                                    task_id = cfg.policy.task_id
                                    if f'noreduce_obs_loss_task{task_id}' in log_vars[0]:
                                        merged_obs_loss_task[task_id] = log_vars[0][f'noreduce_obs_loss_task{task_id}']
                                dist.all_gather_object(all_obs_loss, merged_obs_loss_task)
                                global_obs_loss_task = {}
                                for obs_loss_task in all_obs_loss:
                                    if obs_loss_task:
                                        global_obs_loss_task.update(obs_loss_task)
                                if global_obs_loss_task:
                                    task_exploitation_weight = compute_task_weights(
                                        global_obs_loss_task,
                                        option="rank",
                                        temperature=1,
                                    )
                                    # broadcast_object_list fills the given list in place on
                                    # non-source ranks, so keep a handle and read the result back.
                                    weight_holder = [task_exploitation_weight]
                                    dist.broadcast_object_list(weight_holder, src=0)
                                    task_exploitation_weight = weight_holder[0]
                                    print(f"Rank {rank}, task_exploitation_weight (by task_id): {task_exploitation_weight}")
                                else:
                                    logging.warning(f"Rank {rank}: Failed to compute global obs_loss task weights, obs_loss data is empty.")
                                    task_exploitation_weight = None
                            else:
                                task_exploitation_weight = None
                            # NOTE(review): learn_kwargs is rebuilt before every learner.train call and
                            # this key differs from 'task_weights', so this assignment currently has no
                            # effect on training -- confirm the intended wiring.
                            learn_kwargs['task_weight'] = task_exploitation_weight
                        except Exception as e:
                            logging.error(f'Rank {rank}: Failed to synchronize task weights, error: {e}')
                            raise

                    if cfg.policy.use_priority:
                        # zip() stops at the shortest input, so only tasks that actually produced a
                        # batch are updated (the sampling loop may have stopped early).
                        for cfg, replay_buffer, task_batch in zip(cfgs, game_buffers, train_data_multi_task):
                            task_id = cfg.policy.task_id
                            current_priorities = log_vars[0][f'value_priority_task{task_id}']
                            replay_buffer.update_priority(task_batch, current_priorities)

                            mean_priority = np.mean(current_priorities)
                            std_priority = np.std(current_priorities)

                            alpha = 0.1  # smoothing factor
                            running_mean_key = f'running_mean_priority_task{task_id}'
                            if running_mean_key not in value_priority_tasks:
                                value_priority_tasks[running_mean_key] = mean_priority
                            else:
                                value_priority_tasks[running_mean_key] = (
                                    alpha * mean_priority +
                                    (1 - alpha) * value_priority_tasks[running_mean_key]
                                )
                            running_mean_priority = value_priority_tasks[running_mean_key]

                            if cfg.policy.print_task_priority_logs:
                                print(f"Task {task_id} - Mean priority: {mean_priority:.8f}, "
                                      f"Running mean priority: {running_mean_priority:.8f}, "
                                      f"Std: {std_priority:.8f}")

        train_epoch += 1
        policy.recompute_pos_emb_diff_and_clear_cache()

        # Synchronize all ranks after training
        try:
            dist.barrier()
            logging.info(f'Rank {rank}: passed training synchronization barrier')
        except Exception as e:
            logging.error(f'Rank {rank}: synchronization barrier failed, error: {e}')
            break

        # Check termination conditions: stop when every task on every rank has reached
        # max_env_step, or when any rank has reached max_train_iter.
        try:
            local_envsteps = [collector.envstep for collector in collectors]
            total_envsteps = [None for _ in range(world_size)]
            dist.all_gather_object(total_envsteps, local_envsteps)

            all_envsteps = torch.cat([torch.tensor(envsteps, device=cfg.policy.device) for envsteps in total_envsteps])
            max_envstep_reached = torch.all(all_envsteps >= max_env_step)

            global_train_iter = torch.tensor([learner.train_iter], device=cfg.policy.device)
            all_train_iters = [torch.zeros_like(global_train_iter) for _ in range(world_size)]
            dist.all_gather(all_train_iters, global_train_iter)

            max_train_iter_reached = torch.any(torch.stack(all_train_iters) >= max_train_iter)

            if max_envstep_reached.item() or max_train_iter_reached.item():
                logging.info(f'Rank {rank}: termination condition reached')
                dist.barrier()
                break
        except Exception as e:
            logging.error(f'Rank {rank}: termination check failed, error: {e}')
            break

    learner.call_hook('after_run')
    return policy
AutoTokenizer.from_pretrained(model_path) - if tokenizer is not None: + if tokenizer is None: + # Only rank 0 downloads the tokenizer, and then other processes load it from cache. + if get_rank() == 0: + self.tokenizer = AutoTokenizer.from_pretrained(model_path) + if get_world_size() > 1: + torch.distributed.barrier() + if get_rank() != 0: + self.tokenizer = AutoTokenizer.from_pretrained(model_path) + else: self.tokenizer = tokenizer + + for p in self.pretrained_model.parameters(): + p.requires_grad_(False) self.embedding_size = embedding_size self.embed_proj_head = nn.Linear(self.pretrained_model.config.hidden_size, self.embedding_size) diff --git a/lzero/model/unizero_model.py b/lzero/model/unizero_model.py index b680a6e2d..2270bc05e 100644 --- a/lzero/model/unizero_model.py +++ b/lzero/model/unizero_model.py @@ -129,12 +129,12 @@ def __init__( else: raise ValueError(f"Unsupported encoder option: {kwargs['encoder_option']}") - self.tokenizer = Tokenizer(encoder=self.representation_network, decoder=self.decoder_network, decoder_network_tokenizer=self.decoder_network_tokenizer, - with_lpips=False, projection=projection, encoder_option=kwargs['encoder_option']) + self.tokenizer = Tokenizer(encoder=self.representation_network, decoder=self.decoder_network, + with_lpips=False, obs_type=world_model_cfg.obs_type) self.world_model = WorldModel(config=world_model_cfg, tokenizer=self.tokenizer) # --- Log parameter counts for analysis --- - self._log_model_parameters(obs_type) + self._log_model_parameters(world_model_cfg.obs_type) logging.info(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model') logging.info('==' * 20) diff --git a/lzero/model/unizero_model_multitask.py b/lzero/model/unizero_model_multitask.py index a7356d942..8832303d2 100644 --- a/lzero/model/unizero_model_multitask.py +++ b/lzero/model/unizero_model_multitask.py @@ -6,7 +6,8 @@ from easydict import EasyDict from .common import MZNetworkOutput, 
RepresentationNetworkUniZero, RepresentationNetworkMLP, LatentDecoder, \ - VectorDecoderForMemoryEnv, LatentEncoderForMemoryEnv, LatentDecoderForMemoryEnv, FeatureAndGradientHook + VectorDecoderForMemoryEnv, LatentEncoderForMemoryEnv, LatentDecoderForMemoryEnv, FeatureAndGradientHook, \ + HFLanguageRepresentationNetwork from .unizero_world_models.tokenizer import Tokenizer from .unizero_world_models.world_model_multitask import WorldModelMT from .vit import ViT, ViTConfig @@ -84,6 +85,8 @@ def __init__( self._init_image_components(world_model_cfg, observation_shape, num_res_blocks, num_channels, obs_act_embed_dim) elif obs_type == 'image_memory': self._init_image_memory_components(world_model_cfg) + elif obs_type == 'text': + self._init_text_components(world_model_cfg, encoder_url=kwargs['encoder_url']) else: raise ValueError(f"Unsupported observation type: {obs_type}") @@ -93,6 +96,20 @@ def __init__( # --- Log parameter counts for analysis --- self._log_model_parameters(obs_type) + def _init_text_components(self, world_model_cfg: EasyDict, encoder_url: str) -> None: + """Initializes components for 'text' observation type.""" + self.representation_network = HFLanguageRepresentationNetwork( + model_path=encoder_url, + embedding_size=world_model_cfg.embed_dim, + final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder + ) + self.tokenizer = Tokenizer( + encoder=self.representation_network, + decoder=None, + with_lpips=False, + obs_type=world_model_cfg.obs_type + ) + def _init_vector_components(self, world_model_cfg: EasyDict, obs_act_embed_dim: int) -> None: """Initializes components for 'vector' observation type.""" self.representation_network = RepresentationNetworkMLP( diff --git a/lzero/model/unizero_world_models/world_model_multitask.py b/lzero/model/unizero_world_models/world_model_multitask.py index 836a463c7..8a6541537 100644 --- a/lzero/model/unizero_world_models/world_model_multitask.py +++ 
b/lzero/model/unizero_world_models/world_model_multitask.py @@ -1708,7 +1708,13 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar dtype=batch['observations'].dtype) perceptual_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype) + + elif self.obs_type == 'text': + perceptual_loss = torch.tensor(0., device=batch['observations'].device, + dtype=torch.float32) + latent_recon_loss = self.latent_recon_loss + # Action tokens if self.continuous_action_space: act_tokens = batch['actions'] diff --git a/lzero/policy/unizero.py b/lzero/policy/unizero.py index 766012870..e44e321ca 100755 --- a/lzero/policy/unizero.py +++ b/lzero/policy/unizero.py @@ -1629,7 +1629,7 @@ def _reset_collect(self, env_id: int = None, current_steps: int = None, reset_in self._cfg.device, pad_token_id=self.pad_token_id ) - self.last_batch_action = [-1 for _ in range(self._cfg.collector_env_num)] + self.last_batch_action_collect = [-1 for _ in range(self._cfg.collector_env_num)] # We must handle both single int and list of ints for env_id. @@ -1707,7 +1707,7 @@ def _reset_eval(self, env_id: int = None, current_steps: int = None, reset_init_ ) logging.info(f'unizero.py task_id:{task_id} after _reset_eval: last_batch_obs_eval:', self.last_batch_obs_eval.shape) - self.last_batch_action = [-1 for _ in range(self._cfg.evaluator_env_num)] + self.last_batch_action_eval = [-1 for _ in range(self._cfg.evaluator_env_num)] # This logic handles the crucial end-of-episode cache clearing for evaluation. # The evaluator calls `_policy.reset([env_id])` when an episode is done. 
diff --git a/lzero/worker/muzero_collector.py b/lzero/worker/muzero_collector.py index 1e0a65845..f7956e45b 100644 --- a/lzero/worker/muzero_collector.py +++ b/lzero/worker/muzero_collector.py @@ -87,7 +87,7 @@ def __init__( self._logger, _ = build_logger( path=f'./{self._exp_name}/log/{self._instance_name}', name=self._instance_name, need_tb=False ) - self._tb_logger = None + self._tb_logger = tb_logger self.policy_config = policy_config self.collect_with_pure_policy = self.policy_config.collect_with_pure_policy @@ -126,7 +126,7 @@ def reset_policy(self, _policy: Optional[namedtuple] = None) -> None: self._logger.debug( f"Set default n_episode mode(n_episode({self._default_n_episode}), env_num({self._env_num}))" ) - self._policy.reset() + self._policy.reset(task_id=self.task_id) def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvManager] = None) -> None: """ @@ -359,7 +359,7 @@ def collect( chance_dict = {i: to_ndarray(init_obs[i]['chance']) for i in range(env_nums)} # Initialize game segments and observation stacks for each environment. - game_segments = [GameSegment(self._env.action_space, game_segment_length=self.policy_config.game_segment_length, config=self.policy_config) for _ in range(env_nums)] + game_segments = [GameSegment(self._env.action_space, game_segment_length=self.policy_config.game_segment_length, config=self.policy_config, task_id=self.task_id) for _ in range(env_nums)] observation_window_stack = [deque(maxlen=self.policy_config.model.frame_stack_num) for _ in range(env_nums)] for env_id in range(env_nums): for _ in range(self.policy_config.model.frame_stack_num): @@ -525,7 +525,7 @@ def collect( last_game_priorities[env_id] = priorities # Start a new game segment. 
- game_segments[env_id] = GameSegment(self._env.action_space, game_segment_length=self.policy_config.game_segment_length, config=self.policy_config) + game_segments[env_id] = GameSegment(self._env.action_space, game_segment_length=self.policy_config.game_segment_length, config=self.policy_config, task_id=self.task_id) game_segments[env_id].reset(observation_window_stack[env_id]) self._env_info[env_id]['step'] += 1 @@ -570,7 +570,7 @@ def collect( chance_dict[env_id] = to_ndarray(init_obs[env_id]['chance']) # Reset game segment and observation stack. - game_segments[env_id] = GameSegment(self._env.action_space, game_segment_length=self.policy_config.game_segment_length, config=self.policy_config) + game_segments[env_id] = GameSegment(self._env.action_space, game_segment_length=self.policy_config.game_segment_length, config=self.policy_config, task_id=self.task_id) observation_window_stack[env_id].clear() for _ in range(self.policy_config.model.frame_stack_num): observation_window_stack[env_id].append(init_obs[env_id]['observation']) @@ -585,7 +585,7 @@ def collect( completed_value_lst[env_id] = 0 # Reset policy and collector stats for the finished environment. - self._policy.reset([env_id]) + self._policy.reset([env_id], task_id=self.task_id) self._reset_stat(env_id) ready_env_id.remove(env_id) @@ -628,9 +628,6 @@ def _output_log(self, train_iter: int) -> None: Arguments: - train_iter (:obj:`int`): The current training iteration number, used as the logging step. 
""" - if self._rank != 0: - return - if (train_iter - self._last_train_iter) >= self._collect_print_freq and len(self._episode_info) > 0: self._last_train_iter = train_iter episode_count = len(self._episode_info) @@ -663,21 +660,18 @@ def _output_log(self, train_iter: int) -> None: self._episode_info.clear() - # Log to console - self._logger.info("Collector Training Summary:\n{}".format('\n'.join([f' {k}: {v}' for k, v in info.items()]))) - + self._logger.info(f"Collector log (rank {self._rank}, task_id {self.task_id}):\n" + '\n'.join([f'{k}: {v}' for k, v in info.items()])) # Log to TensorBoard and WandB for k, v in info.items(): + if k in ['each_reward']: + continue if self.task_id is None: - tb_prefix_iter = f'{self._instance_name}_iter/' - tb_prefix_step = f'{self._instance_name}_step/' + # Log for single-task setting + self._tb_logger.add_scalar(f'{self._instance_name}_iter/{k}', v, train_iter) + if k not in ['total_envstep_count', 'total_episode_count', 'total_duration']: + self._tb_logger.add_scalar(f'{self._instance_name}_step/{k}', v, self._total_envstep_count) else: - tb_prefix_iter = f'{self._instance_name}_iter_task{self.task_id}/' - tb_prefix_step = f'{self._instance_name}_step_task{self.task_id}/' - - self._tb_logger.add_scalar(tb_prefix_iter + k, v, train_iter) - self._tb_logger.add_scalar(tb_prefix_step + k, v, self._total_envstep_count) - - if self.policy_config.use_wandb: - wandb_log_data = {tb_prefix_step + k: v for k, v in info.items()} - wandb.log(wandb_log_data, step=self._total_envstep_count) + # Log for multi-task setting + self._tb_logger.add_scalar(f'{self._instance_name}_iter_task{self.task_id}/{k}', v, train_iter) + if k not in ['total_envstep_count', 'total_episode_count', 'total_duration']: + self._tb_logger.add_scalar(f'{self._instance_name}_step_task{self.task_id}/{k}', v, self._total_envstep_count) \ No newline at end of file diff --git a/zoo/jericho/configs/jericho_unizero_config.py b/zoo/jericho/configs/jericho_unizero_config.py 
index c1b34da37..82d2f01d3 100644 --- a/zoo/jericho/configs/jericho_unizero_config.py +++ b/zoo/jericho/configs/jericho_unizero_config.py @@ -109,7 +109,7 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e ), ), ), - accumulation_steps=1, # TODO: Accumulated gradient steps (currently default) + accumulation_steps=1, model=dict( observation_shape=512, action_space_size=action_space_size, @@ -135,7 +135,9 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e obs_type="text", env_num=max(collector_env_num, evaluator_env_num), decode_loss_mode=None, # Controls where to compute reconstruction loss: after_backbone, before_backbone, or None. - latent_recon_loss_weight=0.1 + latent_recon_loss_weight=0.1, + game_segment_length=50, + use_priority=False, ), ), update_per_collect=int(collector_env_num*max_steps*replay_ratio ), # Important for DDP diff --git a/zoo/jericho/configs/jericho_unizero_ddp_config.py b/zoo/jericho/configs/jericho_unizero_ddp_config.py index c204148b2..64f4a63ca 100644 --- a/zoo/jericho/configs/jericho_unizero_ddp_config.py +++ b/zoo/jericho/configs/jericho_unizero_ddp_config.py @@ -34,9 +34,6 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e else: raise ValueError(f"Unsupported encoder option: {encoder_option}") - # TODO - # batch_size = batch_size * 2 - # ------------------------------------------------------------------ # Base environment parameters (Note: these values might be adjusted for different env_id) # ------------------------------------------------------------------ @@ -115,7 +112,7 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e ), ), ), - accumulation_steps=accumulation_steps, # TODO: Accumulated gradient steps (currently default) + accumulation_steps=accumulation_steps, model=dict( observation_shape=512, action_space_size=action_space_size, @@ -138,13 +135,14 @@ def main(env_id: str = 'detective.z5', seed: 
int = 0, max_env_step: int = int(1e num_layers=num_layers, num_heads=24, embed_dim=embed_dim, - obs_type="text", # TODO: Modify as needed. + obs_type="text", env_num=max(collector_env_num, evaluator_env_num), decode_loss_mode=None, # Controls where to compute reconstruction loss: after_backbone, before_backbone, or None. - latent_recon_loss_weight=0.1 # TODO: decoder loss weight + latent_recon_loss_weight=0, # TODO: decoder loss weight + game_segment_length=50, + use_priority=False, ), ), - # TODO update_per_collect=int(collector_env_num*max_steps*replay_ratio*accumulation_steps), # Important for DDP action_type="varied_action_space", model_path=None, @@ -159,7 +157,7 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e num_simulations=num_simulations, n_episode=n_episode, - train_start_after_envsteps=0, # TODO: Adjust training start trigger if needed. + train_start_after_envsteps=0, replay_buffer_size=int(5e5), eval_freq=int(3e4), collector_env_num=collector_env_num, diff --git a/zoo/jericho/configs/jericho_unizero_multitask_config.py b/zoo/jericho/configs/jericho_unizero_multitask_config.py new file mode 100644 index 000000000..8ab654915 --- /dev/null +++ b/zoo/jericho/configs/jericho_unizero_multitask_config.py @@ -0,0 +1,255 @@ +from easydict import EasyDict + +def create_config(env_id, max_steps, max_action_num, action_space_size, collector_env_num, evaluator_env_num, n_episode, + reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, + buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, total_batch_size, + num_layers, model_name, replay_ratio, update_per_collect, + collect_num_simulations, eval_num_simulations): + return EasyDict(dict( + env=dict( + stop_value=int(1e6), + env_id=env_id, + observation_shape=512, + max_steps=max_steps, + max_action_num=max_action_num, + tokenizer_path=model_name, + max_seq_len=512, + game_path=f"./zoo/jericho/envs/z-machine-games-master/jericho-game-suite/{env_id}", + 
for_unizero=True, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False), + ), + policy=dict( + multi_gpu=False, # Very important for ddp + only_use_moco_stats=False, + collect_num_simulations=collect_num_simulations, + eval_num_simulations=eval_num_simulations, + use_moco=False, # Whether to use MoCo for multi-task gradient adjustments + grad_correct_params=dict( + MoCo_beta=0.5, MoCo_beta_sigma=0.5, MoCo_gamma=0.1, MoCo_gamma_sigma=0.5, MoCo_rho=0, + calpha=0.5, rescale=1, + ), + use_wandb=False, + learn=dict( + learner=dict( + hook=dict( + save_ckpt_after_iter=200000 + ), + ), + ), + total_task_num=len(env_id_list), + task_num=len(env_id_list), + task_id=0, + model=dict( + observation_shape=512, + action_space_size=action_space_size, + encoder_url=model_name, + model_type="mlp", + continuous_action_space=False, + world_model_cfg=dict( + final_norm_option_in_obs_head='LayerNorm', + final_norm_option_in_encoder='LayerNorm', + predict_latent_loss_type='mse', + share_head=False, + policy_entropy_weight=5e-2, + continuous_action_space=False, + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, + context_length=2 * infer_context_length, + device="cuda", + action_space_size=action_space_size, + num_layers=num_layers, + num_heads=24, + obs_type="text", + env_num=max(collector_env_num, evaluator_env_num), + task_embed_option=None, + use_task_embed=False, + embed_dim=768, + task_num=len(env_id_list), + use_normal_head=True, + use_softmoe_head=False, + use_moe_head=False, + num_experts_in_moe_head=4, + + moe_in_transformer=False, + multiplication_moe_in_transformer=False, # Whether to use moe in transformers + n_shared_experts=1, + num_experts_per_tok=1, + num_experts_of_moe_in_transformer=8, + + moe_use_lora=False, # Does moe use lora + lora_r= 0, + lora_alpha =1, + lora_dropout= 0.0, + + analysis_dormant_ratio_weight_rank=False, + 
analysis_dormant_ratio_interval=5000, + game_segment_length=50, + use_priority=False, + ), + ), + optim_type='AdamW', + + # (bool) whether enable adaptive policy entropy weights + use_adaptive_entropy_weight=False, + # (float) learning rate of adaptive alpha optimizer + adaptive_entropy_alpha_lr=1e-4, + target_entropy_start_ratio =0.98, + target_entropy_end_ratio =0.5, + target_entropy_decay_steps = 100000, + + use_encoder_clip_annealing=False, + encoder_clip_anneal_type='cosine', + encoder_clip_start_value=30.0, + encoder_clip_end_value=10.0, + encoder_clip_anneal_steps=30000, + + policy_ls_eps_start=0.05, + policy_ls_eps_end=0.01, + policy_ls_eps_decay_steps=50000, + label_smoothing_eps=0, + + use_task_exploitation_weight=False, + task_complexity_weight=False, + total_batch_size=total_batch_size, + allocated_batch_sizes=False, + use_priority=False, + print_task_priority_logs=False, + cuda=True, + model_path=None, + num_unroll_steps=num_unroll_steps, + update_per_collect=update_per_collect, + action_type="varied_action_space", + replay_ratio=replay_ratio, + batch_size=batch_size, + reanalyze_ratio=reanalyze_ratio, + learning_rate=0.0001, + cos_lr_scheduler=False, + fixed_temperature_value=0.25, + manual_temperature_decay=False, + n_episode=n_episode, + train_start_after_envsteps=int(0), + replay_buffer_size=int(5e5), + eval_freq=int(5e3), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + ), + )) + + +def generate_configs(env_id_list, env_configurations, collector_env_num, n_episode, evaluator_env_num, + reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, + seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, + total_batch_size, num_layers, model_name, replay_ratio, collect_num_simulations, eval_num_simulations): + """ + Overview: + Generates a list of configurations 
for all specified tasks. + + Arguments: + (See arguments for `create_config` function) + - seed (:obj:`int`): The random seed for the experiment. + + Returns: + - (:obj:`List[List[Union[int, List[EasyDict]]]]`): A list where each element contains a task_id + and its corresponding configuration objects. + """ + configs = [] + + exp_name_prefix = f'data_scalezero/jericho_mt_{len(env_id_list)}games_tbs{total_batch_size}-nlayer{num_layers}-rr{replay_ratio}_not-share-head_encoder-final-ln_seed{seed}/' + + action_space_size_list = [v[0] for _, v in env_configurations.items()] + max_action_space_size = max(action_space_size_list) + + for task_id, env_id in enumerate(env_id_list): + _, max_steps = env_configurations.get(env_id, (10, 50)) + update_per_collect = 40 # Ensure at least one update per collect + + config = create_config( + env_id=env_id, max_steps=max_steps, max_action_num=max_action_space_size, action_space_size=max_action_space_size, + collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, n_episode=n_episode, + reanalyze_ratio=reanalyze_ratio, batch_size=batch_size, + num_unroll_steps=num_unroll_steps, infer_context_length=infer_context_length, buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, reanalyze_partition=reanalyze_partition, total_batch_size=total_batch_size, + num_layers=num_layers, model_name=model_name, replay_ratio=replay_ratio, + update_per_collect=update_per_collect, collect_num_simulations=collect_num_simulations, eval_num_simulations=eval_num_simulations, + ) + config.policy.task_id = task_id + config.exp_name = exp_name_prefix + f"{env_id.split('.z5')[0]}_seed{seed}" + configs.append([task_id, [config, create_env_manager()]]) + return configs + +def create_env_manager(): + return EasyDict(dict( + env=dict( + type='jericho', + import_names=['zoo.jericho.envs.jericho_env'], + ), + env_manager=dict(type='base'), + policy=dict( + type='unizero_multitask', + 
import_names=['lzero.policy.unizero_multitask'], + ), + )) + +if __name__ == "__main__": + """ + Overview: + This script should be executed with GPUs for distributed training. + + Example launch commands: + + cd /path/to/your/project/ + python zoo/jericho/configs/jericho_unizero_multitask_config.py + """ + + from lzero.entry import train_unizero_multitask + import os + os.environ["TOKENIZERS_PARALLELISM"] = "false" + + env_configurations = { + 'detective.z5': (12, 100), + 'omniquest.z5': (25, 100), + 'acorncourt.z5': (45, 50), + 'zork1.z5': (55, 500), + } + env_id_list = list(env_configurations.keys()) + + # Model name or path - configurable according to the predefined model paths or names + model_name: str = 'BAAI/bge-base-en-v1.5' + replay_ratio = 0.1 + + collector_env_num = 4 + n_episode = 4 + evaluator_env_num = 2 + max_env_step = int(5e5) + reanalyze_ratio = 0.0 + + total_batch_size =int(64 * len(env_id_list)) + batch_size = [int(total_batch_size / len(env_id_list)) for _ in range(len(env_id_list))] + + num_layers=2 + num_unroll_steps = 10 + infer_context_length = 4 + buffer_reanalyze_freq = 1 / 1000000 + reanalyze_batch_size = 160 + reanalyze_partition = 0.75 + collect_num_simulations = 50 + eval_num_simulations = 50 + + for seed in [0]: + configs = generate_configs( env_id_list=env_id_list, env_configurations=env_configurations, + collector_env_num=collector_env_num, n_episode=n_episode, + evaluator_env_num=evaluator_env_num, + reanalyze_ratio=reanalyze_ratio, batch_size=batch_size, + num_unroll_steps=num_unroll_steps, infer_context_length=infer_context_length, + seed=seed, buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, reanalyze_partition=reanalyze_partition, + total_batch_size=total_batch_size, num_layers=num_layers, + model_name=model_name, replay_ratio=replay_ratio,collect_num_simulations=collect_num_simulations, + eval_num_simulations=eval_num_simulations) + train_unizero_multitask(configs, seed=seed, 
max_env_step=max_env_step) diff --git a/zoo/jericho/configs/jericho_unizero_multitask_ddp_config.py b/zoo/jericho/configs/jericho_unizero_multitask_ddp_config.py new file mode 100644 index 000000000..5aa480cb9 --- /dev/null +++ b/zoo/jericho/configs/jericho_unizero_multitask_ddp_config.py @@ -0,0 +1,269 @@ +from easydict import EasyDict + +def create_config(env_id, max_steps, max_action_num, action_space_size, collector_env_num, evaluator_env_num, n_episode, + reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, + buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, total_batch_size, + num_layers, model_name, replay_ratio, update_per_collect, + collect_num_simulations, eval_num_simulations): + return EasyDict(dict( + env=dict( + stop_value=int(1e6), + env_id=env_id, + observation_shape=512, + max_steps=max_steps, + max_action_num=max_action_num, + tokenizer_path=model_name, + max_seq_len=512, + game_path=f"./zoo/jericho/envs/z-machine-games-master/jericho-game-suite/{env_id}", + for_unizero=True, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False), + ), + policy=dict( + multi_gpu=True, # Very important for ddp + only_use_moco_stats=False, + collect_num_simulations=collect_num_simulations, + eval_num_simulations=eval_num_simulations, + use_moco=False, # Whether to use MoCo for multi-task gradient adjustments + grad_correct_params=dict( + MoCo_beta=0.5, MoCo_beta_sigma=0.5, MoCo_gamma=0.1, MoCo_gamma_sigma=0.5, MoCo_rho=0, + calpha=0.5, rescale=1, + ), + use_wandb=False, + learn=dict( + learner=dict( + hook=dict( + save_ckpt_after_iter=200000 + ), + ), + ), + total_task_num=len(env_id_list), + task_num=len(env_id_list), + task_id=0, + model=dict( + observation_shape=512, + action_space_size=action_space_size, + encoder_url=model_name, + model_type="mlp", + continuous_action_space=False, + world_model_cfg=dict( + 
final_norm_option_in_obs_head='LayerNorm', + final_norm_option_in_encoder='LayerNorm', + predict_latent_loss_type='mse', + share_head=False, + policy_entropy_weight=5e-2, + continuous_action_space=False, + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, + context_length=2 * infer_context_length, + device="cuda", + action_space_size=action_space_size, + num_layers=num_layers, + num_heads=24, + obs_type="text", + env_num=max(collector_env_num, evaluator_env_num), + task_embed_option=None, + use_task_embed=False, + embed_dim=768, + task_num=len(env_id_list), + use_normal_head=True, + use_softmoe_head=False, + use_moe_head=False, + num_experts_in_moe_head=4, + + moe_in_transformer=False, + multiplication_moe_in_transformer=False, # Whether to use moe in transformers + n_shared_experts=1, + num_experts_per_tok=1, + num_experts_of_moe_in_transformer=8, + + moe_use_lora=False, # Does moe use lora + lora_r= 0, + lora_alpha =1, + lora_dropout= 0.0, + + analysis_dormant_ratio_weight_rank=False, + analysis_dormant_ratio_interval=5000, + game_segment_length=50, + use_priority=False, + ), + ), + optim_type='AdamW', + # (bool) whether enable adaptive policy entropy weights + use_adaptive_entropy_weight=False, + # (float) learning rate of adaptive alpha optimizer + adaptive_entropy_alpha_lr=1e-4, + target_entropy_start_ratio =0.98, + target_entropy_end_ratio =0.5, + target_entropy_decay_steps = 100000, + + use_encoder_clip_annealing=False, + encoder_clip_anneal_type='cosine', + encoder_clip_start_value=30.0, + encoder_clip_end_value=10.0, + encoder_clip_anneal_steps=30000, + + policy_ls_eps_start=0.05, + policy_ls_eps_end=0.01, + policy_ls_eps_decay_steps=50000, + label_smoothing_eps=0, + + + use_task_exploitation_weight=False, + task_complexity_weight=False, + total_batch_size=total_batch_size, + allocated_batch_sizes=False, + use_priority=False, + print_task_priority_logs=False, + cuda=True, + model_path=None, + num_unroll_steps=num_unroll_steps, + 
update_per_collect=update_per_collect, + action_type="varied_action_space", + replay_ratio=replay_ratio, + batch_size=batch_size, + reanalyze_ratio=reanalyze_ratio, + learning_rate=0.0001, + cos_lr_scheduler=False, + fixed_temperature_value=0.25, + manual_temperature_decay=False, + n_episode=n_episode, + train_start_after_envsteps=int(0), + replay_buffer_size=int(5e5), + eval_freq=int(5e3), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + ), + )) + + +def generate_configs(env_id_list, env_configurations, collector_env_num, n_episode, evaluator_env_num, + reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, + seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, + total_batch_size, num_layers, model_name, replay_ratio, collect_num_simulations, eval_num_simulations): + """ + Overview: + Generates a list of configurations for all specified tasks. + + Arguments: + (See arguments for `create_config` function) + - seed (:obj:`int`): The random seed for the experiment. + + Returns: + - (:obj:`List[List[Union[int, List[EasyDict]]]]`): A list where each element contains a task_id + and its corresponding configuration objects. 
+ """ + configs = [] + exp_name_prefix = f'data_scalezero/jericho_ddp_mt_{len(env_id_list)}games_tbs{total_batch_size}-nlayer{num_layers}-rr{replay_ratio}_not-share-head_encoder-final-ln_seed{seed}/' + + action_space_size_list = [v[0] for _, v in env_configurations.items()] + max_action_space_size = max(action_space_size_list) + + for task_id, env_id in enumerate(env_id_list): + _, max_steps = env_configurations.get(env_id, (10, 50)) + update_per_collect = 40 # Ensure at least one update per collect + + config = create_config( + env_id=env_id, max_steps=max_steps, max_action_num=max_action_space_size, action_space_size=max_action_space_size, + collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, n_episode=n_episode, + reanalyze_ratio=reanalyze_ratio, batch_size=batch_size, + num_unroll_steps=num_unroll_steps, infer_context_length=infer_context_length, buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, reanalyze_partition=reanalyze_partition, total_batch_size=total_batch_size, + num_layers=num_layers, model_name=model_name, replay_ratio=replay_ratio, + update_per_collect=update_per_collect, collect_num_simulations=collect_num_simulations, eval_num_simulations=eval_num_simulations, + ) + config.policy.task_id = task_id + config.exp_name = exp_name_prefix + f"{env_id.split('.z5')[0]}_seed{seed}" + configs.append([task_id, [config, create_env_manager()]]) + return configs + +def create_env_manager(): + return EasyDict(dict( + env=dict( + type='jericho', + import_names=['zoo.jericho.envs.jericho_env'], + ), + env_manager=dict(type='base'), + policy=dict( + type='unizero_multitask', + import_names=['lzero.policy.unizero_multitask'], + ), + )) + +if __name__ == "__main__": + """ + Overview: + This script should be executed with GPUs for distributed training. 
+ + Example launch commands: + + export CUDA_VISIBLE_DEVICES=0,1,2,3 + cd /path/to/your/project/ + + torchrun --nproc_per_node=4 zoo/jericho/configs/jericho_unizero_multitask_ddp_config.py + """ + + from lzero.entry import train_unizero_multitask_ddp + from ding.utils import DDPContext + import torch.distributed as dist + import os + os.environ["TOKENIZERS_PARALLELISM"] = "false" + + env_configurations = { + 'detective.z5': (12, 100), + 'omniquest.z5': (25, 100), + 'acorncourt.z5': (45, 50), + 'zork1.z5': (55, 500), + } + env_id_list = list(env_configurations.keys()) + + # Model name or path - configurable according to the predefined model paths or names + model_name: str = 'BAAI/bge-base-en-v1.5' + replay_ratio = 0.1 + + collector_env_num = 4 + n_episode = 4 + evaluator_env_num = 2 + max_env_step = int(5e5) + reanalyze_ratio = 0.0 + + total_batch_size =int(64 * len(env_id_list)) + batch_size = [int(total_batch_size / len(env_id_list)) for _ in range(len(env_id_list))] + + num_layers=2 + num_unroll_steps = 10 + infer_context_length = 4 + buffer_reanalyze_freq = 1 / 1000000 + reanalyze_batch_size = 160 + reanalyze_partition = 0.75 + collect_num_simulations = 50 + eval_num_simulations = 50 + + + # Set NCCL timeout to prevent watchdog hang due to unbalanced data collection speeds + os.environ['NCCL_TIMEOUT'] = '600' + os.environ['NCCL_BLOCKING_WAIT'] = '1' + os.environ['NCCL_ASYNC_ERROR_HANDLING'] = '1' + + for seed in [0]: + configs = generate_configs( env_id_list=env_id_list, env_configurations=env_configurations, + collector_env_num=collector_env_num, n_episode=n_episode, + evaluator_env_num=evaluator_env_num, + reanalyze_ratio=reanalyze_ratio, batch_size=batch_size, + num_unroll_steps=num_unroll_steps, infer_context_length=infer_context_length, + seed=seed, buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, reanalyze_partition=reanalyze_partition, + total_batch_size=total_batch_size, num_layers=num_layers, + 
model_name=model_name, replay_ratio=replay_ratio, collect_num_simulations=collect_num_simulations, + eval_num_simulations=eval_num_simulations) + + with DDPContext(): + train_unizero_multitask_ddp(configs, seed=seed, max_env_step=max_env_step) + print(f"Seed: {seed} training finished!") + if dist.is_initialized(): + dist.destroy_process_group() \ No newline at end of file