diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 4c2e2bd3d..93e79081b 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -53,7 +53,7 @@ jobs: matrix: os: - 'ubuntu-20.04' - - 'macos-13' + - 'macos-15' python: - '3.7' - '3.8' @@ -73,11 +73,11 @@ jobs: architecture: x86 - os: ubuntu-20.04 architecture: AMD64 - - os: macos-13 + - os: macos-15 architecture: aarch64 - - os: macos-13 + - os: macos-15 architecture: x86 - - os: macos-13 + - os: macos-15 architecture: AMD64 steps: @@ -167,25 +167,25 @@ jobs: name: build-artifacts-wheels-ubuntu-20.04-3.11-aarch64 path: aggregated_wheels_all - - name: Download wheel macos-13, 3.7, x86_64 + - name: Download wheel macos-15, 3.7, x86_64 uses: actions/download-artifact@v4 with: - name: build-artifacts-wheels-macos-13-3.7-x86_64 + name: build-artifacts-wheels-macos-15-3.7-x86_64 path: aggregated_wheels_all - - name: Download wheel macos-13, 3.8, x86_64 + - name: Download wheel macos-15, 3.8, x86_64 uses: actions/download-artifact@v4 with: - name: build-artifacts-wheels-macos-13-3.8-x86_64 + name: build-artifacts-wheels-macos-15-3.8-x86_64 path: aggregated_wheels_all - - name: Download wheel macos-13, 3.7, arm64 + - name: Download wheel macos-15, 3.7, arm64 uses: actions/download-artifact@v4 with: - name: build-artifacts-wheels-macos-13-3.7-arm64 + name: build-artifacts-wheels-macos-15-3.7-arm64 path: aggregated_wheels_all - - name: Download wheel macos-13, 3.8, arm64 + - name: Download wheel macos-15, 3.8, arm64 uses: actions/download-artifact@v4 with: - name: build-artifacts-wheels-macos-13-3.8-arm64 + name: build-artifacts-wheels-macos-15-3.8-arm64 path: aggregated_wheels_all - name: Upload unified wheels artifact diff --git a/.github/workflows/release_test.yml b/.github/workflows/release_test.yml index 70ea8bd35..c86c21def 100644 --- a/.github/workflows/release_test.yml +++ b/.github/workflows/release_test.yml @@ -56,7 +56,7 @@ jobs: matrix: os: - 'ubuntu-20.04' - - 'macos-13' + - 'macos-15' python: - '3.7.17' - '3.8.17' @@ -76,11 +76,11 @@ jobs: architecture: x86 - os: ubuntu-20.04 architecture: AMD64 - - os: macos-13 + - os: macos-15 architecture: aarch64 - - os: macos-13 + - os: macos-15 architecture: x86 - - os: macos-13 + - os: macos-15 architecture: AMD64 - python: '3.7.17' architecture: arm64 diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 1db1c0c35..78e2d1536 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -20,7 +20,7 @@ jobs: matrix: os: - 'self-hosted' - - 'macos-13' + - 'macos-15' python-version: - '3.8' - '3.9' @@ -61,7 +61,7 @@ jobs: if: ${{ env.OS_NAME == 'MacOS' }} shell: bash run: | - brew install tree cloc wget curl make zip graphviz + brew install tree cloc wget curl make zip graphviz swig brew install llvm # Install llvm (which includes clang) brew install opencv # Install OpenCV echo 'export PATH="/usr/local/opt/llvm/bin:$PATH"' >> $GITHUB_ENV # update PATH diff --git a/.gitignore b/.gitignore index 90d30a1b4..96cfabd26 100644 --- a/.gitignore +++ b/.gitignore @@ -1453,4 +1453,4 @@ events.* !/assets/pooltool/** lzero/mcts/ctree/ctree_alphazero/pybind11 -zoo/jericho/envs/z-machine-games-master \ No newline at end of file +zoo/jericho/envs/z-machine-games-master diff --git a/lzero/entry/README.md b/lzero/entry/README.md new file mode 100644 index 000000000..096d7f725 --- /dev/null +++ b/lzero/entry/README.md @@ -0,0 +1,156 @@ +# LightZero Entry Functions + +English | [中文](./README_zh.md) + +This directory contains the 
training and evaluation entry functions for various algorithms in the LightZero framework. These entry functions serve as the main interfaces for launching different types of reinforcement learning experiments. + +## 📁 Directory Structure + +### 🎯 Training Entries + +#### AlphaZero Family +- **`train_alphazero.py`** - Training entry for AlphaZero algorithm + - Suitable for perfect information board games (e.g., Go, Chess) + - No environment model needed, learns through self-play + - Uses Monte Carlo Tree Search (MCTS) for policy improvement + +#### MuZero Family +- **`train_muzero.py`** - Standard training entry for MuZero algorithm + - Supports MuZero, EfficientZero, Sampled EfficientZero, Gumbel MuZero variants + - Learns an implicit model of the environment (dynamics model) + - Suitable for single-task reinforcement learning scenarios + +- **`train_muzero_segment.py`** - MuZero training with segment collector and buffer reanalyze + - Uses `MuZeroSegmentCollector` for data collection + - Supports buffer reanalyze trick for improved sample efficiency + - Supported algorithms: MuZero, EfficientZero, Sampled MuZero, Sampled EfficientZero, Gumbel MuZero, StochasticMuZero + +- **`train_muzero_with_gym_env.py`** - MuZero training adapted for Gym environments + - Specifically designed for OpenAI Gym-style environments + - Simplifies environment interface adaptation + +- **`train_muzero_with_reward_model.py`** - MuZero training with reward model + - Integrates external Reward Model + - Suitable for scenarios requiring learning complex reward functions + +- **`train_muzero_multitask_segment_ddp.py`** - MuZero multi-task distributed training + - Supports multi-task learning + - Uses DDP (Distributed Data Parallel) for distributed training + - Uses Segment Collector + +#### UniZero Family +- **`train_unizero.py`** - Training entry for UniZero algorithm + - Based on paper "UniZero: Generalized and Efficient Planning with Scalable Latent World Models" + - Enhanced planning capabilities for better long-term dependency capture + - Uses scalable latent world models + - Paper: https://arxiv.org/abs/2406.10667 + +- **`train_unizero_segment.py`** - UniZero training with segment collector + - Uses `MuZeroSegmentCollector` for efficient data collection + - Supports buffer reanalyze trick + +- **`train_unizero_multitask_segment_ddp.py`** - UniZero/ScaleZero multi-task distributed training + - Supports multi-task learning and distributed training + - Includes benchmark score definitions (e.g., Atari human-normalized scores) + - Supports curriculum learning strategies + - Uses DDP for training acceleration + +- **`train_unizero_multitask_balance_segment_ddp.py`** - UniZero/ScaleZero balanced multi-task distributed training + - Implements balanced sampling across tasks in multi-task training + - Dynamically adjusts batch sizes for different tasks + - Suitable for scenarios with large task difficulty variations + +- **`train_unizero_multitask_segment_eval.py`** - UniZero/ScaleZero multi-task evaluation training + - Specialized for training and periodic evaluation in multi-task scenarios + - Includes detailed evaluation metric statistics + +- **`train_unizero_with_loss_landscape.py`** - UniZero training with loss landscape visualization + - For training with loss landscape visualization + - Helps understand model optimization process and generalization performance + - Integrates `loss_landscapes` library + +#### ReZero Family +- **`train_rezero.py`** - Training entry for ReZero algorithm + - Supports ReZero-MuZero 
and ReZero-EfficientZero + - Improves training stability through residual connections + - Paper: https://arxiv.org/pdf/2404.16364 + +### 🎓 Evaluation Entries + +- **`eval_alphazero.py`** - Evaluation entry for AlphaZero + - Loads trained AlphaZero models for evaluation + - Can play against other agents for performance testing + +- **`eval_muzero.py`** - Evaluation entry for MuZero family + - Supports evaluation of all MuZero variants + - Provides detailed performance statistics + +- **`eval_muzero_with_gym_env.py`** - MuZero evaluation for Gym environments (not recently maintained) + - Specialized for evaluating models trained in Gym environments + + +## 📖 Usage Guide + +### Basic Usage Pattern + +All training entry functions follow a similar calling pattern: + +```python +from lzero.entry import train_muzero + +# Prepare configuration +cfg = dict(...) # User configuration +create_cfg = dict(...) # Creation configuration + +# Start training +policy = train_muzero( + input_cfg=(cfg, create_cfg), + seed=0, + model=None, # Optional: pre-initialized model + model_path=None, # Optional: pretrained model path + max_train_iter=int(1e10), # Maximum training iterations + max_env_step=int(1e10), # Maximum environment steps +) +``` + +### Choosing the Right Entry Function + +1. **Single-Task Learning**: + - Board games → `train_alphazero` + - General RL tasks → `train_muzero` or `train_unizero` + - Gym environments → `train_muzero_with_gym_env` (not recently maintained) + +2. **Multi-Task Learning**: + - Standard multi-task → `train_unizero_multitask_segment_ddp` + - Balanced task sampling → `train_unizero_multitask_balance_segment_ddp` + +3. **Distributed Training**: + - All entry functions with `_ddp` suffix support distributed training + +4. **Special Requirements**: + - Loss landscape visualization → `train_unizero_with_loss_landscape` + - External reward model → `train_muzero_with_reward_model` + - Improved training stability → `train_rezero` + +## 🔗 Related Resources + +- **AlphaZero**: [Mastering Chess and Shogi by Self-Play with a General Reinforcement Learning Algorithm](https://arxiv.org/abs/1712.01815) +- **MuZero**: [Mastering Atari, Go, Chess and Shogi by Planning with a Learned Model](https://arxiv.org/abs/1911.08265) +- **EfficientZero**: [Mastering Atari Games with Limited Data](https://arxiv.org/abs/2111.00210) +- **UniZero**: [Generalized and Efficient Planning with Scalable Latent World Models](https://arxiv.org/abs/2406.10667) +- **ReZero**: [Boosting MCTS-based Algorithms by Reconstructing the Terminal Reward](https://arxiv.org/abs/2404.16364) +- **ScaleZero**: [One Model for All Tasks: Leveraging Efficient World Models in Multi-Task Planning](https://arxiv.org/abs/2509.07945) + +## 💡 Tips + +- Recommended to start with standard `train_muzero` or `train_unizero` +- For large-scale experiments, consider using DDP versions for faster training +- Using `_segment` versions can achieve better sample efficiency (via reanalyze trick) +- Check configuration examples in `zoo/` directory to learn how to set up each algorithm + +## 📝 Notes + +1. All path parameters should use **absolute paths** +2. Pretrained model paths typically follow format: `exp_name/ckpt/ckpt_best.pth.tar` +3. When using distributed training, ensure `CUDA_VISIBLE_DEVICES` environment variable is set correctly +4. 
Some entry functions have specific algorithm type requirements - check function documentation diff --git a/lzero/entry/README_zh.md b/lzero/entry/README_zh.md new file mode 100644 index 000000000..502622f26 --- /dev/null +++ b/lzero/entry/README_zh.md @@ -0,0 +1,155 @@ +# LightZero 入口函数说明 + +[English](./README.md) | 中文 + +本目录包含了 LightZero 框架中各种算法的训练和评估入口函数。这些入口函数是启动不同类型强化学习实验的主要接口。 + +## 📁 目录结构 + +### 🎯 训练入口 (Training Entries) + +#### AlphaZero 系列 +- **`train_alphazero.py`** - AlphaZero 算法的训练入口 + - 适用于完美信息的棋类游戏(如五子棋、中国象棋等) + - 不需要环境模型,直接通过自我对弈学习 + - 使用蒙特卡洛树搜索(MCTS)进行策略改进 + +#### MuZero 系列 +- **`train_muzero.py`** - MuZero 算法的标准训练入口 + - 支持 MuZero、EfficientZero、Sampled EfficientZero、Gumbel MuZero 等变体 + - 学习环境的隐式模型(dynamics model) + - 适用于单任务强化学习场景 + +- **`train_muzero_segment.py`** - MuZero 带分段收集器和缓冲区重分析技巧的训练入口 + - 使用 `MuZeroSegmentCollector` 进行数据收集 + - 支持缓冲区重分析(reanalyze)技巧提高样本效率 + - 支持的算法:MuZero, EfficientZero, Sampled MuZero, Sampled EfficientZero, Gumbel MuZero, StochasticMuZero + +- **`train_muzero_with_gym_env.py`** - 适配 Gym 环境的 MuZero 训练入口 + - 专门为 OpenAI Gym 风格的环境设计 + - 简化了环境接口的适配过程 + +- **`train_muzero_with_reward_model.py`** - 带奖励模型的 MuZero 训练入口 + - 集成外部奖励模型(Reward Model) + - 适用于需要学习复杂奖励函数的场景 + +- **`train_muzero_multitask_segment_ddp.py`** - MuZero 多任务分布式训练入口 + - 支持多任务学习(Multi-task Learning) + - 使用 DDP (Distributed Data Parallel) 进行分布式训练 + - 使用分段收集器(Segment Collector) + +#### UniZero 系列 +- **`train_unizero.py`** - UniZero 算法的训练入口 + - 基于论文 "UniZero: Generalized and Efficient Planning with Scalable Latent World Models" + - 增强的规划能力,能更好地捕获长期依赖 + - 使用可扩展的隐式世界模型 + - 论文链接:https://arxiv.org/abs/2406.10667 + +- **`train_unizero_segment.py`** - UniZero 带分段收集器的训练入口 + - 使用 `MuZeroSegmentCollector` 进行高效数据收集 + - 支持缓冲区重分析技巧 + +- **`train_unizero_multitask_segment_ddp.py`** - UniZero/ScaleZero 多任务分布式训练入口 + - 支持多任务学习和分布式训练 + - 包含基准测试分数定义(如 Atari 的人类归一化分数) + - 支持课程学习(Curriculum Learning)策略 + - 使用 DDP 加速训练 + +- **`train_unizero_multitask_balance_segment_ddp.py`** - UniZero/ScaleZero 多任务均衡分布式训练入口 + - 在多任务训练中实现任务间的均衡采样 + - 动态调整不同任务的批次大小 + - 适用于任务难度差异较大的场景 + +- **`train_unizero_multitask_segment_eval.py`** - UniZero/ScaleZero 多任务评估训练入口 + - 专门用于多任务场景的训练和周期性评估 + - 包含详细的评估指标统计 + +- **`train_unizero_with_loss_landscape.py`** - UniZero 损失地形可视化训练入口 + - 用于训练的同时进行损失地形(Loss Landscape)可视化 + - 帮助理解模型的优化过程和泛化性能 + - 集成 `loss_landscapes` 库 + +#### ReZero 系列 +- **`train_rezero.py`** - ReZero 算法的训练入口 + - 支持 ReZero-MuZero 和 ReZero-EfficientZero + - 通过残差连接改进训练稳定性 + - 论文链接:https://arxiv.org/pdf/2404.16364 + +### 🎓 评估入口 (Evaluation Entries) + +- **`eval_alphazero.py`** - AlphaZero 算法的评估入口 + - 加载训练好的 AlphaZero 模型进行评估 + - 可以与其他智能体对弈测试性能 + +- **`eval_muzero.py`** - MuZero 系列算法的评估入口 + - 支持所有 MuZero 变体的评估 + - 提供详细的性能统计 + +- **`eval_muzero_with_gym_env.py`** - Gym 环境下的 MuZero 评估入口 (最近没有维护此入口) + - 专门用于评估在 Gym 环境中训练的模型 + + +## 📖 使用指南 + +### 基本使用模式 + +所有训练入口函数遵循相似的调用模式: + +```python +from lzero.entry import train_muzero + +# 准备配置 +cfg = dict(...) # 用户配置 +create_cfg = dict(...) # 创建配置 + +# 开始训练 +policy = train_muzero( + input_cfg=(cfg, create_cfg), + seed=0, + model=None, # 可选:预初始化模型 + model_path=None, # 可选:预训练模型路径 + max_train_iter=int(1e10), # 最大训练迭代次数 + max_env_step=int(1e10), # 最大环境步数 +) +``` + +### 选择合适的入口函数 + +1. **单任务学习**: + - 棋类游戏 → `train_alphazero` + - 一般 RL 任务 → `train_muzero` 或 `train_unizero` + - Gym 环境 → `train_muzero_with_gym_env` (最近没有维护此入口) + +2. **多任务学习**: + - 标准多任务 → `train_unizero_multitask_segment_ddp` + - 任务均衡采样 → `train_unizero_multitask_balance_segment_ddp` + +3. 
**分布式训练**: + - 所有带 `_ddp` 后缀的入口函数都支持数据并行分布式训练 + +4. **特殊需求**: + - 损失地形可视化 → `train_unizero_with_loss_landscape` + - 外部奖励模型 → `train_muzero_with_reward_model` + - 改进训练稳定性 → `train_rezero` + +## 🔗 相关资源 + +- **AlphaZero**: [Mastering Chess and Shogi by Self-Play with a General Reinforcement Learning Algorithm](https://arxiv.org/abs/1712.01815) +- **MuZero**: [Mastering Atari, Go, Chess and Shogi by Planning with a Learned Model](https://arxiv.org/abs/1911.08265) +- **EfficientZero**: [Mastering Atari Games with Limited Data](https://arxiv.org/abs/2111.00210) +- **UniZero**: [Generalized and Efficient Planning with Scalable Latent World Models](https://arxiv.org/abs/2406.10667) +- **ReZero**: [Boosting MCTS-based Algorithms by Reconstructing the Terminal Reward](https://arxiv.org/abs/2404.16364) +- **ScaleZero**: [One Model for All Tasks: Leveraging Efficient World Models in Multi-Task Planning](https://arxiv.org/abs/2509.07945) + +## 💡 提示 + +- 建议从标准的 `train_muzero` 或 `train_unizero` 开始 +- 对于大规模实验,考虑使用 DDP 版本以提高训练速度 +- 使用 `_segment` 版本可以获得更好的样本效率 +- 查看 `zoo/` 目录下的配置示例以了解如何设置各个算法 + +## 📝 注意事项 + +1. 所有路径参数建议使用**绝对路径** +2. 预训练模型路径通常格式为 `exp_name/ckpt/ckpt_best.pth.tar` +3. 使用分布式训练时,确保正确设置 `CUDA_VISIBLE_DEVICES` 环境变量 diff --git a/lzero/entry/__init__.py b/lzero/entry/__init__.py index e6b84f7c1..ba846e26a 100644 --- a/lzero/entry/__init__.py +++ b/lzero/entry/__init__.py @@ -1,6 +1,5 @@ from .eval_alphazero import eval_alphazero from .eval_muzero import eval_muzero - from .eval_muzero_with_gym_env import eval_muzero_with_gym_env from .train_alphazero import train_alphazero from .train_muzero import train_muzero @@ -10,5 +9,33 @@ from .train_rezero import train_rezero from .train_unizero import train_unizero from .train_unizero_segment import train_unizero_segment +from .train_muzero_multitask_segment_ddp import train_muzero_multitask_segment_ddp +from .train_unizero_multitask_segment_ddp import train_unizero_multitask_segment_ddp +from .train_unizero_multitask_segment_eval import train_unizero_multitask_segment_eval +from .train_unizero_multitask_balance_segment_ddp import train_unizero_multitask_balance_segment_ddp from .train_unizero_with_loss_landscape import train_unizero_with_loss_landscape -from .utils import * + +# from .utils import ( +# symlog, +# inv_symlog, +# initialize_zeros_batch, +# freeze_non_lora_parameters, +# compute_task_weights, +# TemperatureScheduler, +# tasks_per_stage, +# compute_unizero_mt_normalized_stats, +# allocate_batch_size, +# is_ddp_enabled, +# ddp_synchronize, +# ddp_all_reduce_sum, +# calculate_update_per_collect, +# initialize_pad_batch, +# random_collect, +# convert_to_batch_for_unizero, +# create_unizero_loss_metrics, +# UniZeroDataLoader, +# log_module_trainable_status, +# log_param_statistics, +# log_buffer_memory_usage, +# log_buffer_run_time, +# ) diff --git a/lzero/entry/eval_muzero.py b/lzero/entry/eval_muzero.py index 6f87c656e..dcb2a1af8 100644 --- a/lzero/entry/eval_muzero.py +++ b/lzero/entry/eval_muzero.py @@ -1,7 +1,7 @@ +from ditk import logging import os from functools import partial from typing import Optional, Tuple -import logging import numpy as np import torch @@ -14,7 +14,6 @@ from ding.utils import set_pkg_seed from ding.worker import BaseLearner from lzero.worker import MuZeroEvaluator -from lzero.entry.utils import initialize_zeros_batch def eval_muzero( diff --git a/lzero/entry/train_alphazero.py b/lzero/entry/train_alphazero.py index 8aa31be06..4e1975d81 100644 --- a/lzero/entry/train_alphazero.py +++ b/lzero/entry/train_alphazero.py 
@@ -1,19 +1,17 @@ -import logging import os from functools import partial from typing import Optional, Tuple import torch from ding.config import compile_config -from ding.envs import create_env_manager -from ding.envs import get_vec_env_setting +from ding.envs import create_env_manager, get_vec_env_setting from ding.policy import create_policy from ding.utils import set_pkg_seed from ding.worker import BaseLearner, create_buffer -from tensorboardX import SummaryWriter - +from ditk import logging from lzero.policy import visit_count_temperature from lzero.worker import AlphaZeroCollector, AlphaZeroEvaluator +from tensorboardX import SummaryWriter def train_alphazero( diff --git a/lzero/entry/train_muzero.py b/lzero/entry/train_muzero.py index f33521086..b84e80c6d 100644 --- a/lzero/entry/train_muzero.py +++ b/lzero/entry/train_muzero.py @@ -1,4 +1,3 @@ -import logging import os from functools import partial from typing import Optional, Tuple @@ -6,20 +5,20 @@ import torch import wandb from ding.config import compile_config -from ding.envs import create_env_manager -from ding.envs import get_vec_env_setting +from ding.envs import create_env_manager, get_vec_env_setting from ding.policy import create_policy from ding.rl_utils import get_epsilon_greedy_fn -from ding.utils import set_pkg_seed, get_rank +from ding.utils import get_rank, set_pkg_seed from ding.worker import BaseLearner -from tensorboardX import SummaryWriter - +from ditk import logging from lzero.entry.utils import log_buffer_memory_usage, log_buffer_run_time from lzero.policy import visit_count_temperature from lzero.policy.random_policy import LightZeroRandomPolicy from lzero.worker import MuZeroCollector as Collector from lzero.worker import MuZeroEvaluator as Evaluator -from .utils import random_collect, calculate_update_per_collect +from tensorboardX import SummaryWriter + +from .utils import calculate_update_per_collect, random_collect def train_muzero( diff --git a/lzero/entry/train_muzero_multitask_segment_ddp.py b/lzero/entry/train_muzero_multitask_segment_ddp.py new file mode 100644 index 000000000..f693e6ca4 --- /dev/null +++ b/lzero/entry/train_muzero_multitask_segment_ddp.py @@ -0,0 +1,460 @@ +import concurrent.futures +from ditk import logging +import os +from functools import partial +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import torch +import torch.distributed as dist +from ding.config import compile_config +from ding.envs import create_env_manager, get_vec_env_setting +from ding.policy import Policy, create_policy +from ding.rl_utils import get_epsilon_greedy_fn +from ding.utils import EasyTimer, set_pkg_seed, get_rank, get_world_size +from ding.worker import BaseLearner +from tensorboardX import SummaryWriter + +from lzero.entry.utils import log_buffer_memory_usage, allocate_batch_size, EVALUATION_TIMEOUT, safe_eval +from lzero.mcts import MuZeroGameBuffer as GameBuffer +from lzero.policy import visit_count_temperature +from lzero.worker import MuZeroCollector as Collector +from lzero.worker import MuZeroEvaluator as Evaluator + +# ========================== +# Global Constants +# ========================== +# Note: This file uses a shorter timeout (1 hour) compared to other multitask files (200 minutes) +# You can adjust this value or use EVALUATION_TIMEOUT from utils.py instead +EVALUATION_TIMEOUT_SECONDS: int = 3600 # 1 hour +MAX_TRAIN_ITER_INF: int = int(1e10) +MAX_ENV_STEP_INF: int = int(1e10) + + +class MuZeroMultiTaskTrainer: + """ + Overview: + A trainer class to 
manage the multi-task training loop for MuZero. + It encapsulates the state and logic for initialization, data collection, + evaluation, training, and termination. + """ + + def __init__( + self, + input_cfg_list: List[Tuple[int, Tuple[dict, dict]]], + seed: int, + model: Optional[torch.nn.Module], + model_path: Optional[str], + max_train_iter: int, + max_env_step: int, + ) -> None: + """ + Overview: + Initializes the multi-task trainer. + Arguments: + - input_cfg_list (:obj:`List[Tuple[int, Tuple[dict, dict]]]`): Configs for all tasks. + - seed (:obj:`int`): The base random seed. + - model (:obj:`Optional[torch.nn.Module]`): An optional pre-defined model. + - model_path (:obj:`Optional[str]`): Path to a pre-trained model checkpoint. + - max_train_iter (:obj:`int`): Maximum training iterations. + - max_env_step (:obj:`int`): Maximum environment steps. + """ + self.max_train_iter = max_train_iter + self.max_env_step = max_env_step + self.seed = seed + self.rank = get_rank() + self.world_size = get_world_size() + self.timer = EasyTimer() + + # State variables + self.train_epoch = 0 + self.buffer_reanalyze_count = 0 + self.value_priority_tasks = {} + + # Task partitioning + self.tasks_for_this_rank = self._partition_tasks(input_cfg_list) + if not self.tasks_for_this_rank: + logging.warning(f"Rank {self.rank}: No tasks assigned. Process will run without tasks.") + self.is_active = False + return + self.is_active = True + + # Initialize shared components (Policy, Learner) + self.policy, self.learner, self.tb_logger = self._initialize_shared_components(model, model_path) + + # Initialize task-specific components + ( + self.cfgs, self.game_buffers, self.collectors, self.evaluators + ) = self._initialize_task_specific_components() + + self.update_per_collect = self.cfgs[0].policy.update_per_collect + + def _partition_tasks(self, input_cfg_list: List[Tuple[int, Tuple[dict, dict]]]) -> List[ + Tuple[int, Tuple[dict, dict]]]: + """Partitions tasks among distributed processes.""" + total_tasks = len(input_cfg_list) + tasks_per_rank = total_tasks // self.world_size + remainder = total_tasks % self.world_size + + if self.rank < remainder: + start_idx = self.rank * (tasks_per_rank + 1) + end_idx = start_idx + tasks_per_rank + 1 + else: + start_idx = self.rank * tasks_per_rank + remainder + end_idx = start_idx + tasks_per_rank + + logging.info(f"Rank {self.rank}/{self.world_size} is assigned tasks from index {start_idx} to {end_idx - 1}.") + return input_cfg_list[start_idx:end_idx] + + def _initialize_shared_components(self, model: Optional[torch.nn.Module], model_path: Optional[str]) -> Tuple[ + Policy, BaseLearner, SummaryWriter]: + """Initializes components shared across all tasks on this rank.""" + _, [cfg, create_cfg] = self.tasks_for_this_rank[0] + + # Set task_num for the shared policy + for task_config in self.tasks_for_this_rank: + task_config[1][0].policy.task_num = len(self.tasks_for_this_rank) + + cfg.policy.device = 'cuda' if torch.cuda.is_available() else 'cpu' + compiled_cfg = compile_config(cfg, seed=self.seed, auto=True, create_cfg=create_cfg, save_cfg=True) + + policy = create_policy(compiled_cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) + + if model_path: + logging.info(f'Loading model from {model_path}...') + policy.learn_mode.load_state_dict(torch.load(model_path, map_location=compiled_cfg.policy.device)) + logging.info(f'Model loaded successfully from {model_path}.') + + log_dir = os.path.join(f'./{compiled_cfg.exp_name}/log', f'serial_rank_{self.rank}') + 
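+        # Each rank writes its TensorBoard events to a per-rank sub-directory (serial_rank_<rank>), so concurrent DDP processes do not overwrite each other's logs.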
tb_logger = SummaryWriter(log_dir) + learner = BaseLearner(compiled_cfg.policy.learn.learner, policy.learn_mode, tb_logger, + exp_name=compiled_cfg.exp_name) + return policy, learner, tb_logger + + def _initialize_task_specific_components(self) -> Tuple[List, List, List, List]: + """Initializes components for each task assigned to this rank.""" + cfgs, game_buffers, collectors, evaluators = [], [], [], [] + + for local_task_id, (task_id, [cfg, create_cfg]) in enumerate(self.tasks_for_this_rank): + task_seed = self.seed + task_id + cfg.policy.device = 'cuda' if cfg.policy.cuda and torch.cuda.is_available() else 'cpu' + compiled_cfg = compile_config(cfg, seed=task_seed, auto=True, create_cfg=create_cfg, save_cfg=True) + + # Create environments + env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(compiled_cfg.env) + collector_env = create_env_manager(compiled_cfg.env.manager, + [partial(env_fn, cfg=c) for c in collector_env_cfg]) + evaluator_env = create_env_manager(compiled_cfg.env.manager, + [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + collector_env.seed(task_seed) + evaluator_env.seed(task_seed, dynamic_seed=False) + set_pkg_seed(task_seed, use_cuda=compiled_cfg.policy.cuda) + + # Create buffer, collector, and evaluator + replay_buffer = GameBuffer(compiled_cfg.policy) + # Set initial batch size from config + replay_buffer.batch_size = compiled_cfg.policy.batch_size[task_id] + + collector = Collector( + env=collector_env, + policy=self.policy.collect_mode, + tb_logger=self.tb_logger, + exp_name=compiled_cfg.exp_name, + policy_config=compiled_cfg.policy, + task_id=task_id + ) + evaluator = Evaluator( + eval_freq=compiled_cfg.policy.eval_freq, + n_evaluator_episode=compiled_cfg.env.n_evaluator_episode, + stop_value=compiled_cfg.env.stop_value, + env=evaluator_env, + policy=self.policy.eval_mode, + tb_logger=self.tb_logger, + exp_name=compiled_cfg.exp_name, + policy_config=compiled_cfg.policy, + task_id=task_id + ) + + cfgs.append(compiled_cfg) + game_buffers.append(replay_buffer) + collectors.append(collector) + evaluators.append(evaluator) + + return cfgs, game_buffers, collectors, evaluators + + def run(self) -> Policy: + """ + Overview: + The main training loop. Executes collection, evaluation, and training steps + until a termination condition is met. + Returns: + - (:obj:`Policy`): The trained policy. + """ + if not self.is_active: + # This rank has no tasks, so it should wait for others to finish. + self._wait_for_termination() + return self.policy + + self.learner.call_hook('before_run') + + while True: + torch.cuda.empty_cache() + + self._update_dynamic_batch_sizes() + self._collect_and_evaluate() + + if self._is_training_ready(): + dist.barrier() + self._train_iteration() + dist.barrier() + else: + logging.warning(f"Rank {self.rank}: Not enough data for training, skipping training step.") + + if self._check_termination_conditions(): + dist.barrier() # Final barrier to ensure all processes stop together. + break + + self.learner.call_hook('after_run') + return self.policy + + def _update_dynamic_batch_sizes(self) -> None: + """Dynamically allocates batch sizes if enabled in the config.""" + if not self.cfgs[0].policy.get('allocated_batch_sizes', False): + return + + # Linearly increase clip_scale from 1 to 4 as train_epoch goes from 0 to 1000. 
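+            # (Assumption: the 1000-epoch horizon and the [1, 4] clip range are tuning heuristics; adjust them together with allocate_batch_size() if a different schedule is desired.)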
+ clip_scale = np.clip(1 + (3 * self.train_epoch / 1000), 1, 4) + allocated_sizes = allocate_batch_size(self.cfgs, self.game_buffers, alpha=1.0, clip_scale=clip_scale) + + # Distribute the allocated sizes to the tasks on the current rank. + # This requires knowing the global task distribution. + total_tasks = self.world_size * len(self.tasks_for_this_rank) # Approximation, needs exact count + # This part is tricky in a distributed setting without global knowledge of task indices. + # Assuming the allocation order matches the task_id order. + for i, cfg in enumerate(self.cfgs): + task_id = cfg.policy.task_id + if task_id < len(allocated_sizes): + batch_size = allocated_sizes[task_id] + cfg.policy.batch_size = batch_size + # Also update the batch size in the shared policy config if necessary + self.policy._cfg.batch_size[task_id] = batch_size + + + def _collect_and_evaluate(self) -> None: + """Runs the data collection and evaluation loop for each assigned task.""" + for i, (cfg, collector, evaluator, replay_buffer) in enumerate( + zip(self.cfgs, self.collectors, self.evaluators, self.game_buffers)): + log_buffer_memory_usage(self.learner.train_iter, replay_buffer, self.tb_logger, cfg.policy.task_id) + + # Evaluation step + if evaluator.should_eval(self.learner.train_iter): + safe_eval(evaluator, self.learner, collector, self.rank, self.world_size, + timeout=EVALUATION_TIMEOUT_SECONDS) + + # Collection step + self._collect_data_for_task(cfg, collector, replay_buffer) + + def _collect_data_for_task(self, cfg: Any, collector: Collector, replay_buffer: GameBuffer) -> None: + """Collects data for a single task and pushes it to the replay buffer.""" + policy_config = cfg.policy + collect_kwargs = { + 'temperature': visit_count_temperature( + policy_config.manual_temperature_decay, + policy_config.fixed_temperature_value, + policy_config.threshold_training_steps_for_final_temperature, + trained_steps=self.learner.train_iter + ), + 'epsilon': 0.0 + } + if policy_config.eps.eps_greedy_exploration_in_collect: + epsilon_fn = get_epsilon_greedy_fn( + start=policy_config.eps.start, end=policy_config.eps.end, + decay=policy_config.eps.decay, type_=policy_config.eps.type + ) + collect_kwargs['epsilon'] = epsilon_fn(collector.envstep) + + logging.info(f'Rank {self.rank}: Collecting data for task {cfg.policy.task_id}...') + new_data = collector.collect(train_iter=self.learner.train_iter, policy_kwargs=collect_kwargs) + replay_buffer.push_game_segments(new_data) + replay_buffer.remove_oldest_data_to_fit() + logging.info(f'Rank {self.rank}: Finished data collection for task {cfg.policy.task_id}.') + + # Periodic reanalysis of the buffer + self._reanalyze_buffer_if_needed(cfg, replay_buffer, is_during_training=False) + + def _reanalyze_buffer_if_needed(self, cfg: Any, replay_buffer: GameBuffer, is_during_training: bool, + train_loop_idx: int = 0) -> None: + """Handles the logic for reanalyzing the game buffer.""" + policy_config = cfg.policy + reanalyze_freq = policy_config.buffer_reanalyze_freq + reanalyze_batch_size = policy_config.reanalyze_batch_size + reanalyze_partition = policy_config.reanalyze_partition + update_per_collect = policy_config.update_per_collect + + should_reanalyze = False + if reanalyze_freq >= 1: + reanalyze_interval = update_per_collect // reanalyze_freq + if is_during_training and train_loop_idx % reanalyze_interval == 0: + should_reanalyze = True + else: # reanalyze_freq is a fraction, e.g., 0.1 + if not is_during_training and self.train_epoch % int(1 / reanalyze_freq) == 0: + 
should_reanalyze = True + + if should_reanalyze and replay_buffer.get_num_of_transitions() // policy_config.num_unroll_steps > int(reanalyze_batch_size / reanalyze_partition): + with self.timer: + replay_buffer.reanalyze_buffer(reanalyze_batch_size, self.policy) + self.buffer_reanalyze_count += 1 + logging.info(f'Buffer reanalyze count: {self.buffer_reanalyze_count}, Time: {self.timer.value:.2f}s') + + def _is_training_ready(self) -> bool: + """Checks if there is enough data in all buffers to start training.""" + for cfg, buffer in zip(self.cfgs, self.game_buffers): + if buffer.get_num_of_transitions() < cfg.policy.batch_size[cfg.policy.task_id]: + logging.warning(f"Rank {self.rank}, Task {cfg.policy.task_id}: Not enough data. " + f"Required: {cfg.policy.batch_size[cfg.policy.task_id]}, " + f"Available: {buffer.get_num_of_transitions()}") + return False + return True + + def _train_iteration(self) -> None: + """Performs one full training iteration, consisting of multiple updates.""" + for i in range(self.update_per_collect): + train_data_multi_task = [] + envstep_multi_task = 0 + + for idx, (cfg, collector, replay_buffer) in enumerate( + zip(self.cfgs, self.collectors, self.game_buffers)): + envstep_multi_task += collector.envstep + batch_size = cfg.policy.batch_size[cfg.policy.task_id] + + if replay_buffer.get_num_of_transitions() > batch_size: + self._reanalyze_buffer_if_needed(cfg, replay_buffer, is_during_training=True, train_loop_idx=i) + train_data = replay_buffer.sample(batch_size, self.policy) + train_data.append(cfg.policy.task_id) # Append task_id for multi-task loss + train_data_multi_task.append(train_data) + else: + # This case should ideally be prevented by _is_training_ready + logging.warning(f"Skipping sample for task {cfg.policy.task_id} due to insufficient data.") + train_data_multi_task.clear() # Invalidate the whole batch if one task fails + break + + if train_data_multi_task: + log_vars = self.learner.train(train_data_multi_task, envstep_multi_task) + if self.cfgs[0].policy.use_priority: + self._update_priorities(train_data_multi_task, log_vars) + + self.train_epoch += 1 + + def _update_priorities(self, train_data_multi_task: List, log_vars: List[Dict]) -> None: + """Updates the priorities in the replay buffers after a training step.""" + for idx, (cfg, replay_buffer) in enumerate(zip(self.cfgs, self.game_buffers)): + task_id = cfg.policy.task_id + priority_key = f'value_priority_task{task_id}' + + if priority_key in log_vars[0]: + priorities = log_vars[0][priority_key] + replay_buffer.update_priority(train_data_multi_task[idx], priorities) + + # Log priority statistics + if cfg.policy.get('print_task_priority_logs', False): + mean_priority = np.mean(priorities) + std_priority = np.std(priorities) + + # Update running mean of priority + running_mean_key = f'running_mean_priority_task{task_id}' + alpha = 0.1 # Smoothing factor for running average + if running_mean_key not in self.value_priority_tasks: + self.value_priority_tasks[running_mean_key] = mean_priority + else: + self.value_priority_tasks[running_mean_key] = \ + alpha * mean_priority + (1 - alpha) * self.value_priority_tasks[running_mean_key] + + running_mean_priority = self.value_priority_tasks[running_mean_key] + logging.info( + f"Task {task_id} - Priority Stats: Mean={mean_priority:.6f}, " + f"Running Mean={running_mean_priority:.6f}, Std={std_priority:.6f}" + ) + + def _check_termination_conditions(self) -> bool: + """Checks if the training should be terminated based on env steps or train iterations.""" + 
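+        # NOTE: the checks below rely on collective ops (dist.all_gather_object / dist.all_gather), so every rank must call this method in lockstep; returning early on a single rank would deadlock the others.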
try: + # Check max_env_step + local_envsteps = [collector.envstep for collector in self.collectors] + all_ranks_envsteps = [None for _ in range(self.world_size)] + dist.all_gather_object(all_ranks_envsteps, local_envsteps) + + # Flatten and check if all tasks have reached the step limit + all_envsteps = [step for rank_steps in all_ranks_envsteps for step in rank_steps] + if all(step >= self.max_env_step for step in all_envsteps): + logging.info(f"Rank {self.rank}: All tasks reached max_env_step ({self.max_env_step}). Terminating.") + return True + + # Check max_train_iter + local_train_iter = torch.tensor([self.learner.train_iter], device=self.policy.device) + all_train_iters = [torch.zeros_like(local_train_iter) for _ in range(self.world_size)] + dist.all_gather(all_train_iters, local_train_iter) + + if any(it.item() >= self.max_train_iter for it in all_train_iters): + logging.info(f"Rank {self.rank}: A process reached max_train_iter ({self.max_train_iter}). Terminating.") + return True + + except Exception as e: + logging.error(f'Rank {self.rank}: Failed during termination check. Error: {e}', exc_info=True) + return True # Terminate on error to prevent hanging + + return False + + def _wait_for_termination(self) -> None: + """ + For inactive ranks, this method blocks and waits for a termination signal + (e.g., another rank finishing) by participating in barriers and termination checks. + """ + while True: + # Participate in barriers to stay in sync + dist.barrier() # Pre-train barrier + dist.barrier() # Post-train barrier + + if self._check_termination_conditions(): + dist.barrier() # Final barrier + break + +def train_muzero_multitask_segment_ddp( + input_cfg_list: List[Tuple[int, Tuple[dict, dict]]], + seed: int = 0, + model: Optional[torch.nn.Module] = None, + model_path: Optional[str] = None, + max_train_iter: Optional[int] = MAX_TRAIN_ITER_INF, + max_env_step: Optional[int] = MAX_ENV_STEP_INF, +) -> Policy: + """ + Overview: + The main entry point for multi-task MuZero training using Distributed Data Parallel (DDP). + This function sets up the distributed environment, partitions tasks, and launches the training process, + which is managed by the MuZeroMultiTaskTrainer class. + Arguments: + - input_cfg_list (:obj:`List[Tuple[int, Tuple[dict, dict]]]`): A list of tuples, where each tuple contains + a task ID and its corresponding configuration dictionaries (main_config, create_config). + - seed (:obj:`int`): The base random seed for reproducibility. Defaults to 0. + - model (:obj:`Optional[torch.nn.Module]`): An optional pre-defined model instance. If provided, + it will be used instead of creating a new one from the config. Defaults to None. + - model_path (:obj:`Optional[str]`): Path to a pre-trained model checkpoint file. If provided, + the model weights will be loaded before training starts. Defaults to None. + - max_train_iter (:obj:`Optional[int]`): The maximum number of training iterations. + Training will stop if any process reaches this limit. Defaults to a very large number. + - max_env_step (:obj:`Optional[int]`): The maximum number of environment steps for each task. + Training will stop when all tasks have reached this limit. Defaults to a very large number. + Returns: + - (:obj:`Policy`): The final trained policy instance from the primary rank. + """ + # Initialize the trainer, which handles all the complex setup and logic internally. 
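+    # Typical usage (hypothetical config names), launched once per rank, e.g. via torchrun:
+    #   train_muzero_multitask_segment_ddp([(0, (cfg_a, create_a)), (1, (cfg_b, create_b))], seed=0)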
+ trainer = MuZeroMultiTaskTrainer( + input_cfg_list=input_cfg_list, + seed=seed, + model=model, + model_path=model_path, + max_train_iter=max_train_iter, + max_env_step=max_env_step, + ) + + # Run the training loop and return the trained policy. + return trainer.run() diff --git a/lzero/entry/train_muzero_segment.py b/lzero/entry/train_muzero_segment.py index 4e9809e05..80474c0f3 100644 --- a/lzero/entry/train_muzero_segment.py +++ b/lzero/entry/train_muzero_segment.py @@ -1,25 +1,23 @@ -import logging import os from functools import partial from typing import Optional, Tuple import torch from ding.config import compile_config -from ding.envs import create_env_manager -from ding.envs import get_vec_env_setting +from ding.envs import create_env_manager, get_vec_env_setting from ding.policy import create_policy from ding.rl_utils import get_epsilon_greedy_fn -from ding.utils import EasyTimer -from ding.utils import set_pkg_seed, get_rank +from ding.utils import EasyTimer, get_rank, set_pkg_seed from ding.worker import BaseLearner -from tensorboardX import SummaryWriter - +from ditk import logging from lzero.entry.utils import log_buffer_memory_usage, log_buffer_run_time from lzero.policy import visit_count_temperature from lzero.policy.random_policy import LightZeroRandomPolicy from lzero.worker import MuZeroEvaluator as Evaluator from lzero.worker import MuZeroSegmentCollector as Collector -from .utils import random_collect, calculate_update_per_collect +from tensorboardX import SummaryWriter + +from .utils import calculate_update_per_collect, random_collect timer = EasyTimer() diff --git a/lzero/entry/train_muzero_with_gym_env.py b/lzero/entry/train_muzero_with_gym_env.py index a0c771d07..7ac418c34 100644 --- a/lzero/entry/train_muzero_with_gym_env.py +++ b/lzero/entry/train_muzero_with_gym_env.py @@ -1,20 +1,18 @@ -import logging import os -from typing import Optional -from typing import Tuple +from typing import Optional, Tuple import torch -from tensorboardX import SummaryWriter - from ding.config import compile_config -from ding.envs import DingEnvWrapper, BaseEnvManager +from ding.envs import BaseEnvManager, DingEnvWrapper from ding.policy import create_policy from ding.rl_utils import get_epsilon_greedy_fn from ding.utils import set_pkg_seed from ding.worker import BaseLearner +from ditk import logging from lzero.envs.get_wrapped_env import get_wrappered_env from lzero.policy import visit_count_temperature from lzero.worker import MuZeroCollector, MuZeroEvaluator +from tensorboardX import SummaryWriter def train_muzero_with_gym_env( diff --git a/lzero/entry/train_muzero_with_reward_model.py b/lzero/entry/train_muzero_with_reward_model.py index 028160430..cb53eb10f 100644 --- a/lzero/entry/train_muzero_with_reward_model.py +++ b/lzero/entry/train_muzero_with_reward_model.py @@ -1,4 +1,4 @@ -import logging +from ditk import logging import os from functools import partial from typing import Optional, Tuple diff --git a/lzero/entry/train_rezero.py b/lzero/entry/train_rezero.py index 131a1684a..694102193 100644 --- a/lzero/entry/train_rezero.py +++ b/lzero/entry/train_rezero.py @@ -1,4 +1,3 @@ -import logging import os from functools import partial from typing import Optional, Tuple @@ -8,16 +7,17 @@ from ding.envs import create_env_manager, get_vec_env_setting from ding.policy import create_policy from ding.rl_utils import get_epsilon_greedy_fn -from ding.utils import set_pkg_seed, get_rank +from ding.utils import get_rank, set_pkg_seed from ding.worker import BaseLearner -from 
tensorboardX import SummaryWriter - +from ditk import logging from lzero.entry.utils import log_buffer_memory_usage, log_buffer_run_time from lzero.policy import visit_count_temperature from lzero.policy.random_policy import LightZeroRandomPolicy from lzero.worker import MuZeroCollector as Collector from lzero.worker import MuZeroEvaluator as Evaluator -from .utils import random_collect, calculate_update_per_collect +from tensorboardX import SummaryWriter + +from .utils import calculate_update_per_collect, random_collect def train_rezero( @@ -227,4 +227,4 @@ def perform_offline_evaluation(cfg, learner, policy, evaluator, eval_train_iter_ stop, reward = evaluator.eval(learner.save_checkpoint, train_iter, collector_envstep) logging.info(f'Offline eval at iter: {train_iter}, steps: {collector_envstep}, reward: {reward}') - logging.info('Offline evaluation completed') \ No newline at end of file + logging.info('Offline evaluation completed') diff --git a/lzero/entry/train_unizero.py b/lzero/entry/train_unizero.py index dfdd65487..9d41734d1 100644 --- a/lzero/entry/train_unizero.py +++ b/lzero/entry/train_unizero.py @@ -1,28 +1,26 @@ -import logging import os from functools import partial -from typing import Tuple, Optional +from typing import Optional, Tuple import torch +import torch.distributed as dist import wandb from ding.config import compile_config -from ding.envs import create_env_manager -from ding.envs import get_vec_env_setting +from ding.envs import create_env_manager, get_vec_env_setting from ding.policy import create_policy from ding.rl_utils import get_epsilon_greedy_fn -from ding.utils import set_pkg_seed, get_rank +from ding.utils import get_rank, get_world_size, set_pkg_seed from ding.worker import BaseLearner -from tensorboardX import SummaryWriter -from torch.utils.tensorboard import SummaryWriter - +from ditk import logging from lzero.entry.utils import log_buffer_memory_usage from lzero.policy import visit_count_temperature from lzero.policy.random_policy import LightZeroRandomPolicy -from lzero.worker import MuZeroEvaluator as Evaluator from lzero.worker import MuZeroCollector as Collector -from .utils import random_collect, calculate_update_per_collect -import torch.distributed as dist -from ding.utils import set_pkg_seed, get_rank, get_world_size +from lzero.worker import MuZeroEvaluator as Evaluator +from tensorboardX import SummaryWriter +from torch.utils.tensorboard import SummaryWriter + +from .utils import calculate_update_per_collect, random_collect def train_unizero( @@ -220,12 +218,8 @@ def train_unizero( if os.environ.get('DEBUG', '').lower() == 'true': import pudb; pudb.set_trace() - log_vars = learner.train(train_data, collector.envstep) - - - if cfg.policy.use_priority: replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig']) @@ -240,4 +234,4 @@ def train_unizero( if cfg.policy.use_wandb: wandb.finish() logging.info("===== Training Completed =====") - return policy \ No newline at end of file + return policy diff --git a/lzero/entry/train_unizero_multitask_balance_segment_ddp.py b/lzero/entry/train_unizero_multitask_balance_segment_ddp.py new file mode 100644 index 000000000..0f9174c65 --- /dev/null +++ b/lzero/entry/train_unizero_multitask_balance_segment_ddp.py @@ -0,0 +1,572 @@ +import concurrent.futures +import math +import os +from collections import defaultdict +from functools import partial +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import torch +import torch.distributed as dist +import 
torch.nn.functional as F +from ding.config import compile_config +from ding.envs import create_env_manager, get_vec_env_setting +from ding.policy import create_policy +from ding.rl_utils import get_epsilon_greedy_fn +from ding.utils import EasyTimer, get_rank, get_world_size, set_pkg_seed +from ding.worker import BaseLearner +from ditk import logging +from lzero.entry.utils import TemperatureScheduler, log_buffer_memory_usage +from lzero.model.unizero_world_models.transformer import (CurriculumLoRALinear, + set_curriculum_stage) +from lzero.policy import visit_count_temperature +from lzero.worker import MuZeroEvaluator as Evaluator +from lzero.worker import MuZeroSegmentCollector as Collector +from tensorboardX import SummaryWriter + +from .utils import (allocate_batch_size, compute_task_weights, + compute_unizero_mt_normalized_stats, + EVALUATION_TIMEOUT, safe_eval, + freeze_non_lora_parameters, log_module_trainable_status, + log_param_statistics, tasks_per_stage) + +# A global dictionary to store the most recent evaluation return for each task. +# Format: {task_id: eval_episode_return_mean} +GLOBAL_EVAL_RETURNS: Dict[int, float] = defaultdict(lambda: None) + + +class CurriculumController: + """ + Overview: + Manages the curriculum learning stages for a multi-task policy. + It tracks the number of solved tasks and training iterations to decide when to transition + to the next curriculum stage, which typically involves freezing parts of the model + and activating new LoRA adapters. + """ + + def __init__(self, cfg: 'EasyDict', policy: 'Policy') -> None: + """ + Overview: + Initializes the CurriculumController. + Arguments: + - cfg (:obj:`EasyDict`): The experiment configuration. + - policy (:obj:`Policy`): The policy being trained. + """ + world_model_cfg = cfg.policy.model.world_model_cfg + self.stage_num: int = world_model_cfg.curriculum_stage_num + self.min_stage0_iters: int = world_model_cfg.min_stage0_iters + self.max_stage_iters: int = world_model_cfg.max_stage_iters + self.policy: 'Policy' = policy + + # Flag to determine if curriculum learning should also be applied to the encoder. + # Defaults to False for backward compatibility. + self.apply_curriculum_to_encoder: bool = getattr(world_model_cfg, 'apply_curriculum_to_encoder', False) + logging.info(f"[CurriculumController] Initialized. Curriculum will be applied to Encoder: {self.apply_curriculum_to_encoder}") + + self.stage: int = 0 + self.last_switch_iter: int = 0 + self.last_solved_count: int = 0 # Snapshot of the last count of solved tasks + + def step(self, solved_count: int, unsolved_count: int, train_iter: int) -> bool: + """ + Overview: + Checks if the curriculum should transition to the next stage and performs the switch if needed. + This method should be called at the end of each training loop. + Arguments: + - solved_count (:obj:`int`): The current total number of solved tasks. + - unsolved_count (:obj:`int`): The current number of tasks yet to be solved. + - train_iter (:obj:`int`): The current training iteration. + Returns: + - bool: True if a stage switch occurred, False otherwise. + """ + # --- Stage 0 is a mandatory training phase for a minimum number of iterations --- + if self.stage == 0 and train_iter < self.min_stage0_iters: + return False + + # --- Determine if a stage switch is necessary --- + should_switch = False + + # 1. 
Trigger based on task progress + newly_solved = solved_count - self.last_solved_count + remaining_lora_stages = self.stage_num - 1 - self.stage # Stage 0 doesn't use LoRA + if remaining_lora_stages > 0: + # Calculate tasks per stage (tps) for the remaining unsolved tasks + tps = tasks_per_stage(unsolved_count, remaining_lora_stages) + if newly_solved >= tps: + should_switch = True + + # 2. Trigger based on maximum iterations per stage + if train_iter - self.last_switch_iter >= self.max_stage_iters: + should_switch = True + + # --- Execute the stage switch --- + if should_switch and self.stage < self.stage_num - 1: + is_entering_stage1 = (self.stage == 0) + self.stage += 1 + + world_model = self.policy._learn_model.world_model + vit_encoder = world_model.tokenizer.encoder + transformer_backbone = world_model.transformer + + # --- Apply curriculum stage update and freeze parameters accordingly --- + + # 1. Conditionally apply to ViT Encoder based on configuration + if self.apply_curriculum_to_encoder: + logging.info(f"[Curriculum] Applying curriculum stage {self.stage} to ViT Encoder.") + set_curriculum_stage(vit_encoder, self.stage) + if is_entering_stage1: + logging.info("[Curriculum] Entering Stage 1. Freezing non-LoRA parameters in ViT Encoder.") + freeze_non_lora_parameters(vit_encoder, freeze=True, verbose=True) + log_module_trainable_status(vit_encoder, "ViT Encoder") + else: + logging.info("[Curriculum] Skipping curriculum stage update for ViT Encoder as per configuration.") + log_module_trainable_status(vit_encoder, "ViT Encoder (Curriculum Not Applied)") + + # 2. Always apply to Transformer Decoder + logging.info(f"[Curriculum] Applying curriculum stage {self.stage} to Transformer Backbone.") + set_curriculum_stage(transformer_backbone, self.stage) + if is_entering_stage1: + logging.info("[Curriculum] Entering Stage 1. Freezing non-LoRA parameters in Transformer Backbone.") + freeze_non_lora_parameters(transformer_backbone, freeze=True, verbose=True) + log_module_trainable_status(transformer_backbone, "Transformer Backbone") + + logging.info( + f'[Curriculum] Switched to stage {self.stage} ' + f'(solved={solved_count}, unsolved={unsolved_count}, iter={train_iter})' + ) + + # Log parameter statistics after the switch + updated_params = sum(p.requires_grad for p in self.policy._learn_model.world_model.parameters()) + total_params = sum(1 for _ in self.policy._learn_model.world_model.parameters()) + logging.info(f'{updated_params}/{total_params} parameters in the world model will be optimized.') + log_param_statistics(self.policy._learn_model.world_model) + + self.last_solved_count = solved_count + self.last_switch_iter = train_iter + return True + + return False + + +def train_unizero_multitask_balance_segment_ddp( + input_cfg_list: List[Tuple[int, Tuple[dict, dict]]], + seed: int = 0, + model: Optional[torch.nn.Module] = None, + model_path: Optional[str] = None, + max_train_iter: Optional[int] = int(1e10), + max_env_step: Optional[int] = int(1e10), + benchmark_name: str = "atari" +) -> 'Policy': + """ + Overview: + The main training entry point for UniZero in a multi-task, curriculum-based setting using DDP. + This function orchestrates distributed data collection, training, and evaluation across multiple tasks. + The curriculum learning strategy involves: + - Defining a `target_return` for each task. + - Moving tasks to a `solved_task_pool` once they achieve their target return, excluding them from + further training and collection. 
+ - Progressing through curriculum stages where the model's backbone is frozen, and only specialized + modules (like LoRA) are trained on harder, unsolved tasks. + This allows the model to first learn general features and then specialize on difficult tasks without + catastrophic forgetting. + Arguments: + - input_cfg_list (:obj:`List[Tuple[int, Tuple[dict, dict]]]`): A list of configurations for each task. + - seed (:obj:`int`): The random seed. + - model (:obj:`Optional[torch.nn.Module]`): An optional pre-existing model instance. + - model_path (:obj:`Optional[str]`): Path to a pre-trained model checkpoint file. + - max_train_iter (:obj:`Optional[int]`): The maximum number of training iterations. + - max_env_step (:obj:`Optional[int]`): The maximum number of environment steps. + - benchmark_name (:obj:`str`): The name of the benchmark (e.g., "atari", "dmc") to load normalization scores. + Returns: + - Policy: The trained policy. + """ + # --- Initialization and DDP Setup --- + logging.basicConfig(level=logging.INFO) + rank = get_rank() + world_size = get_world_size() + timer = EasyTimer() + + # --- Benchmark Score Initialization --- + if benchmark_name == "atari": + RANDOM_SCORES = np.array([ + 227.8, 5.8, 222.4, 210.0, 14.2, 2360.0, 0.1, 1.7, 811.0, 10780.5, + 152.1, 0.0, 65.2, 257.6, 1027.0, 29.0, 52.0, 1598.0, 258.5, 307.3, + -20.7, 24.9, 163.9, 11.5, 68.4, 533.4 + ]) + HUMAN_SCORES = np.array([ + 7127.7, 1719.5, 742.0, 8503.3, 753.1, 37187.5, 12.1, 30.5, 7387.8, 35829.4, + 1971.0, 29.6, 4334.7, 2412.5, 30826.4, 302.8, 3035.0, 2665.5, 22736.3, 6951.6, + 14.6, 69571.3, 13455.0, 7845.0, 42054.7, 11693.2 + ]) + new_order = [ + 20, 19, 24, 6, 0, 8, 14, 23, 1, 2, 3, 4, 5, 9, 10, 11, 12, 13, 15, 16, 17, 18, 21, 25, 22, 7 + ] + new_RANDOM_SCORES = RANDOM_SCORES[new_order] + new_HUMAN_SCORES = HUMAN_SCORES[new_order] + elif benchmark_name == "dmc": + new_RANDOM_SCORES = np.zeros(26) + new_HUMAN_SCORES = np.ones(26) * 1000 + else: + raise ValueError(f"Unsupported benchmark_name: {benchmark_name}") + + # --- Task Distribution Across Ranks --- + total_tasks = len(input_cfg_list) + tasks_per_rank = total_tasks // world_size + remainder = total_tasks % world_size + start_idx = rank * tasks_per_rank + min(rank, remainder) + end_idx = start_idx + tasks_per_rank + (1 if rank < remainder else 0) + tasks_for_this_rank = input_cfg_list[start_idx:end_idx] + + # Ensure at least one task is assigned. + if len(tasks_for_this_rank) == 0: + logging.error(f"Rank {rank}: No tasks assigned, continuing execution.") + else: + logging.info(f"Rank {rank}/{world_size} processing tasks {start_idx} to {end_idx - 1}") + + # --- Environment, Policy, and Worker Initialization --- + task_configs, replay_buffers, collectors, evaluators = [], [], [], [] + + # Use the first task's config to create the shared policy and learner + _, [main_cfg, main_create_cfg] = tasks_for_this_rank[0] + for _, [cfg, _] in tasks_for_this_rank: + cfg.policy.task_num = len(tasks_for_this_rank) + + # Ensure main_cfg has a valid exp_name before calling compile_config. + # If exp_name is missing, None, or too long, set a safe default. + if not hasattr(main_cfg, 'exp_name') or main_cfg.exp_name is None or len(str(main_cfg.exp_name)) > 200: + # Use a simplified experiment name for the main config + safe_exp_name = f'data_unizero_multitask_balance/multitask_seed{seed}' + logging.warning( + f"Rank {rank}: main_cfg.exp_name is missing, None, or too long. 
" + f"Setting to safe default: {safe_exp_name}" + ) + main_cfg.exp_name = safe_exp_name + else: + logging.info(f"Rank {rank}: Using exp_name from config: {main_cfg.exp_name}") + + assert main_create_cfg.policy.type in ['unizero_multitask', 'sampled_unizero_multitask'], \ + "This entry only supports 'unizero_multitask' or 'sampled_unizero_multitask' policies." + + GameBuffer = None + if main_create_cfg.policy.type == 'unizero_multitask': + from lzero.mcts import UniZeroGameBuffer as GameBuffer + elif main_create_cfg.policy.type == 'sampled_unizero_multitask': + from lzero.mcts import SampledUniZeroGameBuffer as GameBuffer + + main_cfg.policy.device = 'cuda' if torch.cuda.is_available() else 'cpu' + compiled_cfg = compile_config(main_cfg, seed=seed, auto=True, create_cfg=main_create_cfg, save_cfg=True) + + policy = create_policy(compiled_cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) + + # Log initial model architecture info BEFORE loading checkpoint + if rank == 0: + num_layers_config = compiled_cfg.policy.model.world_model_cfg.num_layers + initial_params = sum(p.numel() for p in policy._learn_model.world_model.parameters()) + initial_trainable = sum(p.numel() for p in policy._learn_model.world_model.parameters() if p.requires_grad) + logging.info(f"=" * 80) + logging.info(f"Model Architecture Configuration:") + logging.info(f" - num_layers from config: {num_layers_config}") + logging.info(f" - Total parameters (before checkpoint load): {initial_params:,}") + logging.info(f" - Trainable parameters (before checkpoint load): {initial_trainable:,}") + logging.info(f"=" * 80) + + if model_path: + logging.info(f'Loading pre-trained model from: {model_path}') + policy.learn_mode.load_state_dict(torch.load(model_path, map_location=compiled_cfg.policy.device)) + logging.info('Model loading complete.') + if rank == 0: + loaded_params = sum(p.numel() for p in policy._learn_model.world_model.parameters()) + loaded_trainable = sum(p.numel() for p in policy._learn_model.world_model.parameters() if p.requires_grad) + logging.info(f"Model Parameters After Loading Checkpoint:") + logging.info(f" - Total parameters (after checkpoint load): {loaded_params:,}") + logging.info(f" - Trainable parameters (after checkpoint load): {loaded_trainable:,}") + if initial_params != loaded_params: + logging.warning(f"⚠️ WARNING: Parameter count mismatch!") + logging.warning(f" Config specifies {initial_params:,} params, but loaded model has {loaded_params:,} params") + logging.warning(f" This usually means the checkpoint was trained with different num_layers!") + logging.warning(f" The loaded checkpoint architecture will override your config settings.") + + + tb_logger = SummaryWriter(os.path.join(f'./{compiled_cfg.exp_name}/log', f'rank_{rank}')) + learner = BaseLearner(compiled_cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=compiled_cfg.exp_name) + learner.call_hook('before_run') + + # Initialize components for each assigned task + for local_task_id, (task_id, [cfg, create_cfg]) in enumerate(tasks_for_this_rank): + task_seed = seed + task_id + cfg.policy.device = 'cuda' if cfg.policy.cuda and torch.cuda.is_available() else 'cpu' + + # ==================== START: Robust exp_name Fix for Task Config ==================== + # Ensure each task config has a valid exp_name before calling compile_config + if not hasattr(cfg, 'exp_name') or cfg.exp_name is None: + # Extract env_id from config if available, otherwise use task_id + env_id = getattr(cfg.env, 'env_id', f'task{task_id}') + 
cfg.exp_name = f'data_unizero_mt_balance/task_{env_id}_seed{task_seed}' + logging.warning( + f"Rank {rank}: Task {task_id} config missing exp_name. " + f"Setting to: {cfg.exp_name}" + ) + # ==================== END: Robust exp_name Fix for Task Config ==================== + + compiled_task_cfg = compile_config(cfg, seed=task_seed, auto=True, create_cfg=create_cfg, save_cfg=True) + + env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(compiled_task_cfg.env) + collector_env = create_env_manager(compiled_task_cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) + evaluator_env = create_env_manager(compiled_task_cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + collector_env.seed(task_seed) + evaluator_env.seed(task_seed, dynamic_seed=False) + set_pkg_seed(task_seed, use_cuda=compiled_task_cfg.policy.cuda) + + replay_buffers.append(GameBuffer(compiled_task_cfg.policy)) + collectors.append(Collector( + collect_print_freq=100, + env=collector_env, + policy=policy.collect_mode, + tb_logger=tb_logger, + exp_name=compiled_task_cfg.exp_name, + instance_name=f'collector_task{task_id}', + policy_config=compiled_task_cfg.policy, + task_id=task_id + )) + evaluators.append(Evaluator( + eval_freq=compiled_task_cfg.policy.eval_freq, + n_evaluator_episode=compiled_task_cfg.env.n_evaluator_episode, + stop_value=compiled_task_cfg.env.stop_value, + env=evaluator_env, + policy=policy.eval_mode, + tb_logger=tb_logger, + exp_name=compiled_task_cfg.exp_name, + instance_name=f'evaluator_task{task_id}', + policy_config=compiled_task_cfg.policy, + task_id=task_id + )) + task_configs.append(compiled_task_cfg) + + # --- Curriculum and Training Loop Initialization --- + solved_task_pool = set() + curriculum_controller = CurriculumController(compiled_cfg, policy) + temperature_scheduler = TemperatureScheduler(initial_temp=10.0, final_temp=1.0, threshold_steps=int(1e4), mode='linear') + + train_epoch = 0 + buffer_reanalyze_count = 0 + + logging.info(f"Rank {rank}: Initial trainable parameters in world model: {sum(p.requires_grad for p in policy._learn_model.world_model.parameters())}/{sum(1 for _ in policy._learn_model.world_model.parameters())}") + + # ============================================================================================ + # Main Training Loop + # ============================================================================================ + while True: + # --- 1. Dynamic Batch Size Allocation (Optional) --- + if compiled_cfg.policy.allocated_batch_sizes: + clip_scale = np.clip(1 + (3 * train_epoch / 1000), 1, 4) + allocated_batch_sizes = allocate_batch_size(task_configs, replay_buffers, alpha=1.0, clip_scale=clip_scale) + if rank == 0: + logging.info(f"Dynamically allocated batch sizes: {allocated_batch_sizes}") + # Assign the corresponding batch size to each task config + for i, cfg in enumerate(task_configs): + task_id = cfg.policy.task_id + if isinstance(allocated_batch_sizes, dict): + cfg.policy.batch_size = allocated_batch_sizes.get(task_id, cfg.policy.batch_size) + elif isinstance(allocated_batch_sizes, list): + # Use the index in the list or task_id as fallback + cfg.policy.batch_size = allocated_batch_sizes[i] if i < len(allocated_batch_sizes) else cfg.policy.batch_size + else: + logging.warning(f"Unexpected type for allocated_batch_sizes: {type(allocated_batch_sizes)}") + # Also update the policy config (use the full list for compatibility) + policy._cfg.batch_size = allocated_batch_sizes + + # --- 2. 
Data Collection and Evaluation for each task on this rank --- + local_task_returns = {} + for i, (cfg, collector, evaluator, replay_buffer) in enumerate(zip(task_configs, collectors, evaluators, replay_buffers)): + task_id = cfg.policy.task_id + if task_id in solved_task_pool: + continue + + # Evaluate policy if it's time + if learner.train_iter > 10 and evaluator.should_eval(learner.train_iter): + logging.info(f'Rank {rank} evaluating task_id: {task_id}...') + evaluator._policy.reset(reset_init_data=True, task_id=task_id) + stop_flag, reward_dict = safe_eval(evaluator, learner, collector, rank, world_size) + + if reward_dict is not None: + eval_mean_reward = reward_dict.get('eval_episode_return_mean', float('-inf')) + logging.info(f"Task {task_id} evaluation reward: {eval_mean_reward}") + local_task_returns[task_id] = eval_mean_reward + if eval_mean_reward >= cfg.policy.target_return: + logging.info(f"Task {task_id} has reached its target return of {cfg.policy.target_return}. Adding to solved pool.") + solved_task_pool.add(task_id) + else: + logging.warning(f"Evaluation failed or timed out for task {task_id}. Assigning a low score.") + local_task_returns[task_id] = float('-inf') + + # Collect new data + logging.info(f'Rank {rank} collecting data for task_id: {task_id}...') + collect_kwargs = {'temperature': visit_count_temperature(cfg.policy.manual_temperature_decay, cfg.policy.fixed_temperature_value, cfg.policy.threshold_training_steps_for_final_temperature, learner.train_iter)} + if cfg.policy.eps.eps_greedy_exploration_in_collect: + epsilon_fn = get_epsilon_greedy_fn(cfg.policy.eps.start, cfg.policy.eps.end, cfg.policy.eps.decay, cfg.policy.eps.type) + collect_kwargs['epsilon'] = epsilon_fn(collector.envstep) + + collector._policy.reset(reset_init_data=True, task_id=task_id) + new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) + replay_buffer.push_game_segments(new_data) + replay_buffer.remove_oldest_data_to_fit() + logging.info(f'Rank {rank}: Data collection finished for task {task_id}.') + + # --- 3. 
DDP Synchronization of Task Status and Weights --- + dist.barrier() + # Gather solved tasks from all ranks + all_solved_pools = [None for _ in range(world_size)] + dist.all_gather_object(all_solved_pools, solved_task_pool) + global_solved_task_pool = set().union(*[pool for pool in all_solved_pools if pool is not None]) + solved_task_pool = global_solved_task_pool # Sync local pool with global + global_solved_count = len(solved_task_pool) + + # Gather evaluation returns and compute task weights + task_weights = None + if learner.train_iter > 10 and learner.train_iter % compiled_cfg.policy.eval_freq == 0: + all_task_returns = [None for _ in range(world_size)] + dist.all_gather_object(all_task_returns, local_task_returns) + + merged_task_returns = {k: v for d in all_task_returns if d for k, v in d.items()} + for tid, ret in merged_task_returns.items(): + GLOBAL_EVAL_RETURNS[tid] = ret # Update global tracker + + unsolved_task_returns = {tid: ret for tid, ret in merged_task_returns.items() if tid not in solved_task_pool} + + if rank == 0: + logging.info(f"Global unsolved task returns for weight calculation: {unsolved_task_returns}") + if compiled_cfg.policy.task_complexity_weight and unsolved_task_returns: + temp = temperature_scheduler.get_temperature(learner.train_iter) + task_weights = compute_task_weights(unsolved_task_returns, option="rank", temperature=temp) + logging.info(f"Computed task weights: {task_weights}") + + # Log UniZero-MT normalized stats + # Convert arrays to dictionaries with task_id as keys + human_scores_dict = {i: new_HUMAN_SCORES[i] for i in range(len(new_HUMAN_SCORES))} + random_scores_dict = {i: new_RANDOM_SCORES[i] for i in range(len(new_RANDOM_SCORES))} + mean_norm, median_norm = compute_unizero_mt_normalized_stats( + GLOBAL_EVAL_RETURNS, human_scores_dict, random_scores_dict + ) + if mean_norm is not None: + tb_logger.add_scalar('UniZero-MT/NormalizedMean', mean_norm, learner.train_iter) + tb_logger.add_scalar('UniZero-MT/NormalizedMedian', median_norm, learner.train_iter) + logging.info(f"UniZero-MT Normalized Mean={mean_norm:.4f}, Median={median_norm:.4f}") + + # Broadcast weights from rank 0 to all other ranks + broadcast_objects = [task_weights] + dist.broadcast_object_list(broadcast_objects, src=0) + task_weights = broadcast_objects[0] + + # --- 4. Curriculum Stage Update --- + unsolved_count = total_tasks - global_solved_count + switched = curriculum_controller.step(global_solved_count, unsolved_count, learner.train_iter) + + if rank == 0: + tb_logger.add_scalar('Curriculum/Stage', curriculum_controller.stage, learner.train_iter) + tb_logger.add_scalar('Curriculum/GlobalSolvedTasks', global_solved_count, learner.train_iter) + + # Log alpha scaling factors for curriculum LoRA modules + try: + transformer = policy._learn_model.world_model.transformer + for module_name, module in transformer.named_modules(): + if isinstance(module, CurriculumLoRALinear): + # Check if the module has base_weight_scale attribute + if hasattr(module, 'base_weight_scale') and module.base_weight_scale is not None: + # Log base weight scaling factor (alpha_0) + tb_logger.add_scalar( + f'Curriculum/alpha_scales/{module_name}/alpha_0_base_weight', + module.base_weight_scale().item(), + global_step=learner.train_iter + ) + + # Check if the module has adapter_scales attribute + if hasattr(module, 'adapter_scales') and module.adapter_scales is not None: + # Iterate and log scaling factors for all adapters (alpha_1, alpha_2, ...) 
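+                            # Assumed interface (a sketch, not guaranteed by this config): each entry of
+                            # `module.adapter_scales` is a small callable module whose forward() returns a
+                            # scalar tensor, so `scale_param().item()` below yields the adapter's alpha, e.g.:
+                            #   class AdapterScale(nn.Module):
+                            #       def __init__(self):
+                            #           super().__init__()
+                            #           self.logit = nn.Parameter(torch.zeros(()))
+                            #       def forward(self):
+                            #           return torch.sigmoid(self.logit)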
+ for adapter_idx, scale_param in enumerate(module.adapter_scales): + # adapter_idx starts from 0, corresponding to alpha_{idx+1} + tb_logger.add_scalar( + f'Curriculum/alpha_scales/{module_name}/alpha_{adapter_idx + 1}', + scale_param().item(), + global_step=learner.train_iter + ) + except Exception as e: + logging.warning(f"Failed to log alpha scales: {e}") + + # Ensure all processes are aware of a potential stage switch + dist.barrier() + + # --- 5. Training Step --- + unsolved_buffers = [rb for cfg, rb in zip(task_configs, replay_buffers) if cfg.policy.task_id not in solved_task_pool] + unsolved_cfgs = [cfg for cfg in task_configs if cfg.policy.task_id not in solved_task_pool] + + if not unsolved_buffers: + logging.info(f"Rank {rank}: All assigned tasks are solved. Performing dummy training to maintain DDP sync.") + # When all local tasks are solved, we must still participate in DDP. + # A dummy forward/backward pass with zeroed gradients can ensure this. + # The current implementation uses a minimal batch from solved tasks with `ignore_grad=True`. + for _ in range(compiled_cfg.policy.update_per_collect): + train_data_list = [] + for cfg, replay_buffer in zip(task_configs, replay_buffers): # Use original buffers + batch_size = 2 # Minimal batch size for sync + if replay_buffer.get_num_of_transitions() >= batch_size: + train_data = replay_buffer.sample(batch_size, policy) + train_data.append(cfg.policy.task_id) + train_data_list.append(train_data) + + if train_data_list: + learner.train(train_data_list, collector.envstep, policy_kwargs={'task_weights': None, "ignore_grad": True}) + + else: + for _ in range(compiled_cfg.policy.update_per_collect): + train_data_list = [] + total_envstep = sum(c.envstep for c in collectors) + for cfg, replay_buffer in zip(unsolved_cfgs, unsolved_buffers): + # Handle batch_size whether it's an int, list, or dict + if isinstance(cfg.policy.batch_size, (list, tuple)): + batch_size = cfg.policy.batch_size[cfg.policy.task_id] + elif isinstance(cfg.policy.batch_size, dict): + batch_size = cfg.policy.batch_size[cfg.policy.task_id] + else: + # batch_size is already an integer + batch_size = cfg.policy.batch_size + + if replay_buffer.get_num_of_transitions() >= batch_size: + train_data = replay_buffer.sample(batch_size, policy) + train_data.append(cfg.policy.task_id) + train_data_list.append(train_data) + else: + logging.warning(f"Skipping training for task {cfg.policy.task_id}: not enough data in buffer.") + + if train_data_list: + learn_kwargs = {'task_weights': task_weights, "ignore_grad": False} + learner.train(train_data_list, total_envstep, policy_kwargs=learn_kwargs) + + train_epoch += 1 + policy.recompute_pos_emb_diff_and_clear_cache() + + # --- 6. Synchronization and Termination Check --- + dist.barrier() # Ensure all ranks complete the training step + + # Check for termination conditions + max_iter_reached = torch.tensor([learner.train_iter >= max_train_iter], dtype=torch.bool, device=compiled_cfg.policy.device) + dist.all_reduce(max_iter_reached, op=dist.ReduceOp.SUM) + + # For env_step, gather from all collectors on all ranks + local_env_steps = torch.tensor([c.envstep for c in collectors], dtype=torch.long, device=compiled_cfg.policy.device) + all_env_steps = [torch.zeros_like(local_env_steps) for _ in range(world_size)] + # Note: all_gather requires all tensors to be the same size. This assumes each rank has the same number of collectors. + # If not, a more complex gathering method (e.g., all_gather_object) is needed. 
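+    # Sketch of the object-based fallback mentioned above (uses only the standard
+    # torch.distributed API; not wired into this entry):
+    #   local_steps = [c.envstep for c in collectors]
+    #   gathered = [None for _ in range(world_size)]
+    #   dist.all_gather_object(gathered, local_steps)
+    #   flat_steps = [s for per_rank in gathered for s in per_rank]
+    #   max_step_reached = bool(flat_steps) and min(flat_steps) >= max_env_step
+    # This tolerates ranks owning different numbers of collectors, at the cost of pickling overhead.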
+ try: + dist.all_gather(all_env_steps, local_env_steps) + max_step_reached = (torch.cat(all_env_steps).min() >= max_env_step) if all_env_steps else False + except RuntimeError: # If tensor sizes mismatch + max_step_reached = False # Fallback, consider logging an error + logging.warning("Could not gather env_steps due to tensor size mismatch across ranks. Termination check may be inaccurate.") + + if max_iter_reached.item() or max_step_reached: + logging.info(f"Rank {rank}: Termination condition met. Stopping training.") + break + + # --- Finalization --- + learner.call_hook('after_run') + return policy diff --git a/lzero/entry/train_unizero_multitask_segment_ddp.py b/lzero/entry/train_unizero_multitask_segment_ddp.py new file mode 100644 index 000000000..0eb8d9606 --- /dev/null +++ b/lzero/entry/train_unizero_multitask_segment_ddp.py @@ -0,0 +1,592 @@ +import concurrent.futures +import os +from collections import defaultdict +from functools import partial +from typing import Dict, List, Optional, Tuple + +import numpy as np +import torch +import torch.distributed as dist +import torch.nn.functional as F +from ding.config import compile_config +from ding.envs import create_env_manager, get_vec_env_setting +from ding.policy import Policy, create_policy +from ding.rl_utils import get_epsilon_greedy_fn +from ding.utils import EasyTimer, get_rank, get_world_size, set_pkg_seed +from ding.worker import BaseLearner +from ditk import logging +from lzero.entry.utils import ( + EVALUATION_TIMEOUT, + TemperatureScheduler, + allocate_batch_size, + compute_task_weights, + compute_unizero_mt_normalized_stats, + log_buffer_memory_usage, + safe_eval, + symlog, + inv_symlog, +) +# NOTE: The following imports are for type hinting purposes. +# The actual GameBuffer is selected dynamically based on the policy type. +from lzero.mcts import UniZeroGameBuffer +from lzero.policy import visit_count_temperature +from lzero.worker import MuZeroEvaluator as Evaluator +from lzero.worker import MuZeroSegmentCollector as Collector +from tensorboardX import SummaryWriter + +# ==================================================================================================================== +# Note: Benchmark score definitions are initialized dynamically within the `train_unizero_multitask_segment_ddp` +# function based on the `benchmark_name` argument to ensure correct score assignment. +# ==================================================================================================================== + +# Stores the latest evaluation returns: {task_id: eval_episode_return_mean} +GLOBAL_EVAL_RETURNS: Dict[int, float] = defaultdict(lambda: None) + +timer = EasyTimer() + + +def train_unizero_multitask_segment_ddp( + input_cfg_list: List[Tuple[int, Tuple[dict, dict]]], + seed: int = 0, + model: Optional[torch.nn.Module] = None, + model_path: Optional[str] = None, + max_train_iter: Optional[int] = int(1e10), + max_env_step: Optional[int] = int(1e10), + benchmark_name: str = "atari" +) -> 'Policy': + """ + Overview: + The training entry point for UniZero, designed to enhance the planning capabilities of reinforcement learning agents + by addressing the limitations of MuZero-like algorithms in environments requiring long-term dependency capture. + For more details, please refer to https://arxiv.org/abs/2406.10667. + + Arguments: + - input_cfg_list (:obj:`List[Tuple[int, Tuple[dict, dict]]]`): A list of configurations for different tasks. + - seed (:obj:`int`): The random seed. 
+ - model (:obj:`Optional[torch.nn.Module]`): An instance of torch.nn.Module. + - model_path (:obj:`Optional[str]`): The path to a pre-trained model checkpoint file. + - max_train_iter (:obj:`Optional[int]`): The maximum number of policy update iterations during training. + - max_env_step (:obj:`Optional[int]`): The maximum number of environment interaction steps to collect. + - benchmark_name (:obj:`str`): The name of the benchmark, e.g., "atari" or "dmc". + + Returns: + - policy (:obj:`Policy`): The converged policy. + """ + # ------------------------------------------------------------------------------------ + # ====== UniZero-MT Benchmark Scores (corresponding to 26 Atari100k task IDs) ====== + # Original RANDOM_SCORES and HUMAN_SCORES + if benchmark_name == "atari": + RANDOM_SCORES = np.array([ + 227.8, 5.8, 222.4, 210.0, 14.2, 2360.0, 0.1, 1.7, 811.0, 10780.5, + 152.1, 0.0, 65.2, 257.6, 1027.0, 29.0, 52.0, 1598.0, 258.5, 307.3, + -20.7, 24.9, 163.9, 11.5, 68.4, 533.4 + ]) + HUMAN_SCORES = np.array([ + 7127.7, 1719.5, 742.0, 8503.3, 753.1, 37187.5, 12.1, 30.5, 7387.8, 35829.4, + 1971.0, 29.6, 4334.7, 2412.5, 30826.4, 302.8, 3035.0, 2665.5, 22736.3, 6951.6, + 14.6, 69571.3, 13455.0, 7845.0, 42054.7, 11693.2 + ]) + elif benchmark_name == "dmc": + RANDOM_SCORES = np.zeros(26) + HUMAN_SCORES = np.ones(26) * 1000 + else: + raise ValueError(f"Unsupported BENCHMARK_NAME: {benchmark_name}") + + # New order to original index mapping + # New order: [Pong, MsPacman, Seaquest, Boxing, Alien, ChopperCommand, Hero, RoadRunner, + # Amidar, Assault, Asterix, BankHeist, BattleZone, CrazyClimber, DemonAttack, + # Freeway, Frostbite, Gopher, Jamesbond, Kangaroo, Krull, KungFuMaster, + # PrivateEye, UpNDown, Qbert, Breakout] + # Mapping to indices in the original array (0-based) + new_order = [ + 20, 19, 24, 6, 0, 8, 14, 23, 1, 2, 3, 4, 5, 9, 10, 11, 12, 13, 15, 16, 17, 18, 21, 25, 22, 7 + ] + global new_RANDOM_SCORES, new_HUMAN_SCORES + # Generate new arrays based on new_order + new_RANDOM_SCORES = RANDOM_SCORES[new_order] + new_HUMAN_SCORES = HUMAN_SCORES[new_order] + # Log the reordered results + logging.info("Reordered RANDOM_SCORES:") + logging.info(new_RANDOM_SCORES) + logging.info("\nReordered HUMAN_SCORES:") + logging.info(new_HUMAN_SCORES) + # ------------------------------------------------------------------------------------ + + # Initialize the temperature scheduler for task weighting. + initial_temperature = 10.0 + final_temperature = 1.0 + threshold_steps = int(1e4) # Temperature drops to 1.0 after 10k training steps. + temperature_scheduler = TemperatureScheduler( + initial_temp=initial_temperature, + final_temp=final_temperature, + threshold_steps=threshold_steps, + mode='linear' # or 'exponential' + ) + + # Get the current process rank and total world size. + rank = get_rank() + world_size = get_world_size() + + # Task partitioning among ranks. + total_tasks = len(input_cfg_list) + tasks_per_rank = total_tasks // world_size + remainder = total_tasks % world_size + + # 1. Precisely calculate the number of tasks assigned to the current rank. + if rank < remainder: + start_idx = rank * (tasks_per_rank + 1) + end_idx = start_idx + tasks_per_rank + 1 + num_tasks_for_this_rank = tasks_per_rank + 1 + else: + start_idx = rank * tasks_per_rank + remainder + end_idx = start_idx + tasks_per_rank + num_tasks_for_this_rank = tasks_per_rank + + tasks_for_this_rank = input_cfg_list[start_idx:end_idx] + + # Ensure at least one task is assigned. 
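+    # The remainder-aware split above gives the first `total_tasks % world_size` ranks one
+    # extra task. A minimal check of that arithmetic (hypothetical helper, for illustration only):
+    #   def task_slice(total_tasks: int, world_size: int, rank: int) -> tuple:
+    #       per, rem = divmod(total_tasks, world_size)
+    #       start = rank * per + min(rank, rem)
+    #       return start, start + per + (1 if rank < rem else 0)
+    #   # e.g. 26 tasks on 8 ranks: rank 0 -> (0, 4), rank 1 -> (4, 8), rank 7 -> (23, 26)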
+ if len(tasks_for_this_rank) == 0: + logging.warning(f"Rank {rank}: No tasks assigned, continuing execution.") + # Initialize empty lists to avoid errors later. + cfgs, game_buffers, collector_envs, evaluator_envs, collectors, evaluators = [], [], [], [], [], [] + else: + logging.info(f"Rank {rank}/{world_size} processing tasks {start_idx} to {end_idx - 1}") + + cfgs = [] + game_buffers = [] + collector_envs = [] + evaluator_envs = [] + collectors = [] + evaluators = [] + + if tasks_for_this_rank: + # Use the config of the first task to create a shared policy. + task_id, [cfg, create_cfg] = tasks_for_this_rank[0] + + # ==================== START: Critical Fix ==================== + # 2. Set the correct task count to *all* related configurations. + # Configuration must be correct before creating the Policy instance. + for config_tuple in tasks_for_this_rank: + # config_tuple is (task_id, [cfg_obj, create_cfg_obj]) + config_tuple[1][0].policy.task_num = num_tasks_for_this_rank + + # 3. Ensure the cfg object used to create the Policy also has the correct task_num. + cfg.policy.task_num = num_tasks_for_this_rank + # ==================== END: Critical Fix ==================== + + # Ensure the specified policy type is supported. + assert create_cfg.policy.type in ['unizero_multitask', 'sampled_unizero_multitask'], \ + "train_unizero entry currently only supports 'unizero_multitask' or 'sampled_unizero_multitask'" + + if create_cfg.policy.type == 'unizero_multitask': + from lzero.mcts import UniZeroGameBuffer as GameBuffer + if create_cfg.policy.type == 'sampled_unizero_multitask': + from lzero.mcts import SampledUniZeroGameBuffer as GameBuffer + + # Set device based on CUDA availability. + cfg.policy.device = cfg.policy.model.world_model_cfg.device if torch.cuda.is_available() else 'cpu' + logging.info(f'Configured device: {cfg.policy.device}') + + # Compile the configuration. + cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) + # Create the shared policy. + policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) + + # Load a pre-trained model if a path is provided. + if model_path is not None: + logging.info(f'Starting to load model: {model_path}') + policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device)) + logging.info(f'Finished loading model: {model_path}') + + # Create a TensorBoard logger. + log_dir = os.path.join('./{}/log'.format(cfg.exp_name), f'serial_rank_{rank}') + tb_logger = SummaryWriter(log_dir) + + # Create the shared learner. + learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) + + policy_config = cfg.policy + + # Process each task assigned to the current rank. + for local_task_id, (task_id, [cfg, create_cfg]) in enumerate(tasks_for_this_rank): + # Set a unique random seed for each task. + cfg.policy.device = 'cuda' if cfg.policy.device == 'cuda' and torch.cuda.is_available() else 'cpu' + cfg = compile_config(cfg, seed=seed + task_id, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) + policy_config = cfg.policy + policy.collect_mode.get_attribute('cfg').n_episode = policy_config.n_episode + policy.eval_mode.get_attribute('cfg').n_episode = policy_config.n_episode + + # Create environments. 
+ env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) + collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) + evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + collector_env.seed(cfg.seed + task_id) + evaluator_env.seed(cfg.seed + task_id, dynamic_seed=False) + set_pkg_seed(cfg.seed + task_id, use_cuda=cfg.policy.cuda) + + # Create task-specific game buffers, collectors, and evaluators. + replay_buffer = GameBuffer(policy_config) + collector = Collector( + env=collector_env, + policy=policy.collect_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config, + task_id=task_id + ) + evaluator = Evaluator( + eval_freq=cfg.policy.eval_freq, + n_evaluator_episode=cfg.env.n_evaluator_episode, + stop_value=cfg.env.stop_value, + env=evaluator_env, + policy=policy.eval_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config, + task_id=task_id + ) + + cfgs.append(cfg) + # Handle batch_size robustly - it might be a list or already an integer + if isinstance(cfg.policy.batch_size, (list, tuple)): + replay_buffer.batch_size = cfg.policy.batch_size[task_id] + elif isinstance(cfg.policy.batch_size, dict): + replay_buffer.batch_size = cfg.policy.batch_size[task_id] + else: + replay_buffer.batch_size = cfg.policy.batch_size + + game_buffers.append(replay_buffer) + collector_envs.append(collector_env) + evaluator_envs.append(evaluator_env) + collectors.append(collector) + evaluators.append(evaluator) + + # Call the learner's before_run hook. + learner.call_hook('before_run') + value_priority_tasks = {} + + buffer_reanalyze_count = 0 + train_epoch = 0 + reanalyze_batch_size = cfg.policy.reanalyze_batch_size + update_per_collect = cfg.policy.update_per_collect + + task_exploitation_weight = None + + # Dictionary to store task rewards. + task_returns = {} # {task_id: reward} + + while True: + # Dynamically adjust batch sizes. + if cfg.policy.allocated_batch_sizes: + clip_scale = np.clip(1 + (3 * train_epoch / 1000), 1, 4) + allocated_batch_sizes = allocate_batch_size(cfgs, game_buffers, alpha=1.0, clip_scale=clip_scale) + if rank == 0: + logging.info("Allocated batch_sizes: ", allocated_batch_sizes) + # Assign the corresponding batch size to each task config + for idx, (cfg, collector, evaluator, replay_buffer) in enumerate( + zip(cfgs, collectors, evaluators, game_buffers)): + task_id = cfg.policy.task_id + if isinstance(allocated_batch_sizes, dict): + cfg.policy.batch_size = allocated_batch_sizes.get(task_id, cfg.policy.batch_size) + elif isinstance(allocated_batch_sizes, list): + # Use the index in the list or task_id as fallback + cfg.policy.batch_size = allocated_batch_sizes[idx] if idx < len(allocated_batch_sizes) else cfg.policy.batch_size + else: + logging.warning(f"Unexpected type for allocated_batch_sizes: {type(allocated_batch_sizes)}") + # Also update the policy config (use the full list for compatibility) + policy._cfg.batch_size = allocated_batch_sizes + + # For each task on the current rank, perform data collection and evaluation. + for idx, (cfg, collector, evaluator, replay_buffer) in enumerate( + zip(cfgs, collectors, evaluators, game_buffers)): + + # Log buffer memory usage. 
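+            # With the allocation schedule above, clip_scale = np.clip(1 + 3 * train_epoch / 1000, 1, 4)
+            # rises linearly from 1.0 at epoch 0 to 4.0 at epoch 1000 and stays capped afterwards, so
+            # allocate_batch_size is allowed to skew per-task batch sizes more as training matures.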
+ log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger, cfg.policy.task_id) + + collect_kwargs = { + 'temperature': visit_count_temperature( + policy_config.manual_temperature_decay, + policy_config.fixed_temperature_value, + policy_config.threshold_training_steps_for_final_temperature, + trained_steps=learner.train_iter + ), + 'epsilon': 0.0 # Default epsilon value. + } + + if policy_config.eps.eps_greedy_exploration_in_collect: + epsilon_greedy_fn = get_epsilon_greedy_fn( + start=policy_config.eps.start, + end=policy_config.eps.end, + decay=policy_config.eps.decay, + type_=policy_config.eps.type + ) + collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep) + + # Check if it's time for evaluation. + if learner.train_iter > 10 and learner.train_iter % cfg.policy.eval_freq == 0: + # if learner.train_iter == 0 or learner.train_iter % cfg.policy.eval_freq == 0: # TODO: Only for debug + + logging.info('=' * 20) + logging.info(f'Rank {rank} evaluating task_id: {cfg.policy.task_id}...') + + # TODO: Ensure policy reset logic is optimal for multi-task settings. + evaluator._policy.reset(reset_init_data=True, task_id=cfg.policy.task_id) + + # Perform safe evaluation. + stop, reward = safe_eval(evaluator, learner, collector, rank, world_size) + # Check if evaluation was successful. + if stop is None or reward is None: + logging.warning(f"Rank {rank} encountered issues during evaluation, continuing training...") + task_returns[cfg.policy.task_id] = float('inf') # Set task difficulty to max if evaluation fails. + else: + # Extract 'eval_episode_return_mean' from the reward dictionary. + try: + eval_mean_reward = reward.get('eval_episode_return_mean', float('inf')) + logging.info(f"Task {cfg.policy.task_id} evaluation reward: {eval_mean_reward}") + task_returns[cfg.policy.task_id] = eval_mean_reward + except Exception as e: + logging.error(f"Error extracting evaluation reward: {e}") + task_returns[cfg.policy.task_id] = float('inf') # Set reward to max on error. + + logging.info('=' * 20) + logging.info(f'Starting collection for Rank {rank} task_id: {cfg.policy.task_id}...') + logging.info(f'Rank {rank}: cfg.policy.task_id={cfg.policy.task_id} ') + logging.info(f'Rank {rank}: Starting data collection for task {cfg.policy.task_id} at train_iter {learner.train_iter}') + + # Reset initial data before each collection, crucial for multi-task settings. + collector._policy.reset(reset_init_data=True, task_id=cfg.policy.task_id) + # Collect data. + new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) + logging.info(f'Rank {rank}: Finished data collection for task {cfg.policy.task_id}, collected {len(new_data[0]) if new_data else 0} segments') + + # Update the replay buffer. + replay_buffer.push_game_segments(new_data) + replay_buffer.remove_oldest_data_to_fit() + + # Periodically reanalyze the buffer. + if cfg.policy.buffer_reanalyze_freq >= 1: + reanalyze_interval = update_per_collect // cfg.policy.buffer_reanalyze_freq + else: + if train_epoch > 0 and train_epoch % int(1 / cfg.policy.buffer_reanalyze_freq) == 0 and \ + replay_buffer.get_num_of_transitions() // cfg.policy.num_unroll_steps > int( + reanalyze_batch_size / cfg.policy.reanalyze_partition): + with timer: + replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy) + buffer_reanalyze_count += 1 + logging.info(f'Buffer reanalysis count: {buffer_reanalyze_count}') + logging.info(f'Buffer reanalysis time: {timer.value}') + + # Log after data collection. 
+ logging.info(f'Rank {rank}: Completed data collection for task {cfg.policy.task_id}') + + # ========== Synchronize all ranks after data collection ========== + # Wait for all ranks to complete their data collection before proceeding. + # This prevents fast-collecting ranks from reaching barriers/all_gather calls + # while slow-collecting ranks are still in the collection loop. + try: + logging.info(f'Rank {rank}: Waiting at post-collection barrier...') + dist.barrier() + logging.info(f'Rank {rank}: All ranks completed data collection, proceeding...') + except Exception as e: + logging.error(f'Rank {rank}: Post-collection barrier failed, error: {e}') + raise e + # =============================================================================== + + # Check if there is enough data for training. + local_not_enough_data = any( + replay_buffer.get_num_of_transitions() < cfgs[0].policy.total_batch_size / world_size + for replay_buffer in game_buffers + ) + logging.info(f"Rank {rank} local_not_enough_data:{local_not_enough_data}") + flag_tensor = torch.tensor(1.0 if local_not_enough_data else 0.0, device=cfg.policy.device) + dist.all_reduce(flag_tensor, op=dist.ReduceOp.MAX) + not_enough_data = (flag_tensor.item() > 0.5) + if rank == 0: + logging.info(f"Global not_enough_data status: {not_enough_data}") + + # Get the current temperature for task weighting. + current_temperature_task_weight = temperature_scheduler.get_temperature(learner.train_iter) + + if learner.train_iter > 10 and learner.train_iter % cfg.policy.eval_freq == 0: + # Calculate task weights. + try: + # Gather task rewards. + logging.info(f'Rank {rank}: Entering evaluation synchronization barrier at train_iter {learner.train_iter}') + dist.barrier() + logging.info(f'Rank {rank}: Passed evaluation barrier, gathering task returns') + all_task_returns = [None for _ in range(world_size)] + dist.all_gather_object(all_task_returns, task_returns) + # Merge task rewards. + merged_task_returns = {} + for returns in all_task_returns: + if returns: + merged_task_returns.update(returns) + + logging.warning(f"Rank {rank}: merged_task_returns: {merged_task_returns}") + + # Calculate global task weights. + task_weights = compute_task_weights(merged_task_returns, temperature=current_temperature_task_weight) + + # ---------- Maintain UniZero-MT global evaluation results ---------- + for tid, ret in merged_task_returns.items(): + GLOBAL_EVAL_RETURNS[tid] = ret # Update even for solved tasks. + + # Calculate Human-Normalized Mean / Median. + # Convert arrays to dictionaries with task_id as keys + human_scores_dict = {i: new_HUMAN_SCORES[i] for i in range(len(new_HUMAN_SCORES))} + random_scores_dict = {i: new_RANDOM_SCORES[i] for i in range(len(new_RANDOM_SCORES))} + uni_mean, uni_median = compute_unizero_mt_normalized_stats( + GLOBAL_EVAL_RETURNS, human_scores_dict, random_scores_dict + ) + + if uni_mean is not None: # At least one task has been evaluated. + if rank == 0: # Only write to TensorBoard on rank 0 to avoid duplication. + tb_logger.add_scalar('UniZero-MT/NormalizedMean', uni_mean, global_step=learner.train_iter) + tb_logger.add_scalar('UniZero-MT/NormalizedMedian', uni_median, global_step=learner.train_iter) + logging.info(f"Rank {rank}: UniZero-MT Norm Mean={uni_mean:.4f}, Median={uni_median:.4f}") + else: + logging.info(f"Rank {rank}: No data available to compute UniZero-MT normalized metrics") + + # Synchronize task weights. 
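+                # Note: dist.broadcast_object_list writes the received objects back into the list it is
+                # given, so receivers only see the value if the same list is kept and unpacked, e.g. (sketch):
+                #   buf = [task_weights]
+                #   dist.broadcast_object_list(buf, src=0)
+                #   task_weights = buf[0]
+                # Passing a throwaway literal list discards the broadcast result on non-src ranks.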
+ dist.broadcast_object_list([task_weights], src=0) + except Exception as e: + logging.error(f'Rank {rank}: Failed to synchronize task weights, error: {e}') + break + + # ---------------- Sampling done, preparing for backward pass ---------------- + # dist.barrier() # ★★★ Critical synchronization point ★★★ + + # Learn policy. + if not not_enough_data: + for i in range(update_per_collect): + train_data_multi_task = [] + envstep_multi_task = 0 + for idx, (cfg, collector, replay_buffer) in enumerate(zip(cfgs, collectors, game_buffers)): + envstep_multi_task += collector.envstep + # Handle batch_size robustly - it might be a list or already an integer + if isinstance(cfg.policy.batch_size, (list, tuple)): + batch_size = cfg.policy.batch_size[cfg.policy.task_id] + elif isinstance(cfg.policy.batch_size, dict): + batch_size = cfg.policy.batch_size[cfg.policy.task_id] + else: + batch_size = cfg.policy.batch_size + + if replay_buffer.get_num_of_transitions() > batch_size: + if cfg.policy.buffer_reanalyze_freq >= 1: + if i % reanalyze_interval == 0 and \ + replay_buffer.get_num_of_transitions() // cfg.policy.num_unroll_steps > int( + reanalyze_batch_size / cfg.policy.reanalyze_partition): + with timer: + replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy) + buffer_reanalyze_count += 1 + logging.info(f'Buffer reanalysis count: {buffer_reanalyze_count}') + logging.info(f'Buffer reanalysis time: {timer.value}') + + train_data = replay_buffer.sample(batch_size, policy) + train_data.append(cfg.policy.task_id) # Append task_id to differentiate tasks. + train_data_multi_task.append(train_data) + else: + logging.warning( + f'Insufficient data in replay buffer to sample mini-batch: ' + f'batch_size: {batch_size}, replay_buffer: {replay_buffer}' + ) + break + + if train_data_multi_task: + learn_kwargs = {'task_weights': None,"train_iter":learner.train_iter} + + # DDP automatically synchronizes gradients and parameters during training. + log_vars = learner.train(train_data_multi_task, envstep_multi_task, policy_kwargs=learn_kwargs) + + # Check if task_exploitation_weight needs to be calculated. + if i == 0: + # Calculate task weights. + try: + dist.barrier() # Wait for all processes to synchronize. + if cfg.policy.use_task_exploitation_weight: # Use obs loss now, new polish. + # Gather obs_loss from all tasks. + all_obs_loss = [None for _ in range(world_size)] + # Build obs_loss data for the current process's tasks. + merged_obs_loss_task = {} + for cfg, replay_buffer in zip(cfgs, game_buffers): + task_id = cfg.policy.task_id + if f'noreduce_obs_loss_task{task_id}' in log_vars[0]: + merged_obs_loss_task[task_id] = log_vars[0][ + f'noreduce_obs_loss_task{task_id}'] + # Gather obs_loss data from all processes. + dist.all_gather_object(all_obs_loss, merged_obs_loss_task) + # Merge obs_loss data from all processes. + global_obs_loss_task = {} + for obs_loss_task in all_obs_loss: + if obs_loss_task: + global_obs_loss_task.update(obs_loss_task) + # Calculate global task weights. + if global_obs_loss_task: + task_exploitation_weight = compute_task_weights( + global_obs_loss_task, + option="rank", + # TODO: Decide whether to use the temperature scheduler here. + temperature=1, + ) + # Broadcast task weights to all processes. 
+ dist.broadcast_object_list([task_exploitation_weight], src=0) + logging.info( + f"rank{rank}, task_exploitation_weight (sorted by task_id): {task_exploitation_weight}") + else: + logging.warning(f"Rank {rank}: Unable to compute global obs_loss task weights, obs_loss data is empty.") + task_exploitation_weight = None + else: + task_exploitation_weight = None + # Update training parameters to include the calculated task weights. + learn_kwargs['task_weight'] = task_exploitation_weight + except Exception as e: + logging.error(f'Rank {rank}: Failed to synchronize task weights, error: {e}') + raise e # Re-raise the exception for external capture and analysis. + + if cfg.policy.use_priority: + for idx, (cfg, replay_buffer) in enumerate(zip(cfgs, game_buffers)): + # Update task-specific replay buffer priorities. + task_id = cfg.policy.task_id + replay_buffer.update_priority( + train_data_multi_task[idx], + log_vars[0][f'noreduce_value_priority_task{task_id}'] + ) + + train_epoch += 1 + policy.recompute_pos_emb_diff_and_clear_cache() + + # Synchronize all ranks to ensure they have completed training. + try: + dist.barrier() + logging.info(f'Rank {rank}: Passed synchronization barrier after training') + except Exception as e: + logging.error(f'Rank {rank}: Synchronization barrier failed, error: {e}') + break + + # Check for termination conditions. + try: + local_envsteps = [collector.envstep for collector in collectors] + total_envsteps = [None for _ in range(world_size)] + dist.all_gather_object(total_envsteps, local_envsteps) + + all_envsteps = torch.cat([torch.tensor(envsteps, device=cfg.policy.device) for envsteps in total_envsteps]) + max_envstep_reached = torch.all(all_envsteps >= max_env_step) + + # Gather train_iter from all processes. + global_train_iter = torch.tensor([learner.train_iter], device=cfg.policy.device) + all_train_iters = [torch.zeros_like(global_train_iter) for _ in range(world_size)] + dist.all_gather(all_train_iters, global_train_iter) + + max_train_iter_reached = torch.any(torch.stack(all_train_iters) >= max_train_iter) + + if max_envstep_reached.item() or max_train_iter_reached.item(): + logging.info(f'Rank {rank}: Termination condition reached') + dist.barrier() # Ensure all processes synchronize before exiting. + break + except Exception as e: + logging.error(f'Rank {rank}: Termination check failed, error: {e}') + break + + # Call the learner's after_run hook. 
+ learner.call_hook('after_run') + return policy diff --git a/lzero/entry/train_unizero_multitask_segment_eval.py b/lzero/entry/train_unizero_multitask_segment_eval.py new file mode 100644 index 000000000..dbf4a891f --- /dev/null +++ b/lzero/entry/train_unizero_multitask_segment_eval.py @@ -0,0 +1,300 @@ +from ditk import logging +import os +import concurrent.futures +from functools import partial +from typing import Tuple, Optional, List, Dict, Any, Type + +import torch +import torch.distributed as dist +import numpy as np +from ding.config import compile_config +from ding.envs import create_env_manager, get_vec_env_setting +from ding.policy import create_policy, Policy +from ding.rl_utils import get_epsilon_greedy_fn +from ding.utils import set_pkg_seed, get_rank, get_world_size, EasyTimer +from ding.worker import BaseLearner +from tensorboardX import SummaryWriter + +from lzero.entry.utils import log_buffer_memory_usage, safe_eval, allocate_batch_size +from lzero.policy import visit_count_temperature +from lzero.mcts import UniZeroGameBuffer as GameBuffer +from lzero.worker import MuZeroEvaluator as Evaluator +from lzero.worker import MuZeroSegmentCollector as Collector + +# Configure basic logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', +) + + +def train_unizero_multitask_segment_eval( + input_cfg_list: List[Tuple[int, Tuple[Dict[str, Any], Dict[str, Any]]]], + seed: int = 0, + model: Optional[torch.nn.Module] = None, + model_path: Optional[str] = None, + max_train_iter: Optional[int] = int(1e10), + max_env_step: Optional[int] = int(1e10), +) -> 'Policy': + """ + Overview: + The main training entry point for UniZero, as proposed in the paper "UniZero: Generalized and Efficient Planning + with Scalable Latent World Models" (https://arxiv.org/abs/2406.10667). This function sets up a distributed + multi-task training environment where multiple reinforcement learning tasks are trained in parallel using a + single shared model. It handles task distribution, component initialization (policy, learner, buffers, etc.), + and the main training loop orchestration. + Arguments: + - input_cfg_list (:obj:`List[Tuple[int, Tuple[Dict, Dict]]]`): A list of configurations for each task. Each + element is a tuple containing the task ID and its corresponding configuration dictionaries. + - seed (:obj:`int`): The master random seed for reproducibility. + - model (:obj:`Optional[torch.nn.Module]`): An optional pre-existing model instance. If None, a new model is + created based on the config. + - model_path (:obj:`Optional[str]`): An optional path to a pre-trained model checkpoint. + - max_train_iter (:obj:`Optional[int]`): The maximum number of training iterations before termination. + - max_env_step (:obj:`Optional[int]`): The maximum number of environment steps before termination. + Returns: + - (:obj:`'Policy'`): The trained policy instance after the training loop has converged or terminated. + """ + # ============================================================== + # 1. Initialization + # ============================================================== + + # 1.1. 
Distributed Setup & Task Partitioning + rank = get_rank() + world_size = get_world_size() + + total_tasks = len(input_cfg_list) + tasks_per_rank = total_tasks // world_size + remainder = total_tasks % world_size + + if rank < remainder: + start_idx = rank * (tasks_per_rank + 1) + end_idx = start_idx + tasks_per_rank + 1 + else: + start_idx = rank * tasks_per_rank + remainder + end_idx = start_idx + tasks_per_rank + + tasks_for_this_rank = input_cfg_list[start_idx:end_idx] + + if not tasks_for_this_rank: + logging.warning(f"Rank {rank}: No tasks assigned. This rank will be idle.") + # Keep the process alive to participate in collective communications. + dist.barrier() + return + + logging.info(f"Rank {rank}/{world_size}: Handling tasks from index {start_idx} to {end_idx - 1}.") + + # 1.2. Shared Policy, Learner, and Logger Initialization + # Use the configuration of the first task on this rank to create the shared components. + _, (first_cfg, first_create_cfg) = tasks_for_this_rank[0] + + # Set task_num for learner logging purposes. + for _, (cfg, _) in tasks_for_this_rank: + cfg.policy.task_num = tasks_per_rank + + assert first_create_cfg.policy.type in ['unizero_multitask'], \ + "This entry point currently only supports 'unizero_multitask' policy type." + + first_cfg.policy.device = 'cuda' if torch.cuda.is_available() else 'cpu' + logging.info(f'Shared policy device: {first_cfg.policy.device}') + + # Compile the main configuration. + cfg = compile_config(first_cfg, seed=seed, auto=True, create_cfg=first_create_cfg, save_cfg=True) + + # Create the shared policy. + policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) + + # Load a pre-trained model if a path is provided. + if model_path is not None: + logging.info(f'Loading pre-trained model from: {model_path}') + policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device)) + logging.info('Model loading complete.') + + # Create a TensorBoard logger for this rank. + log_dir = os.path.join(f'./{cfg.exp_name}/log', f'serial_rank_{rank}') + tb_logger = SummaryWriter(log_dir) + + # Create the shared learner instance. + learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) + + # 1.3. Task-Specific Components Initialization + cfgs, game_buffers, collectors, evaluators = [], [], [], [] + for task_id, (task_cfg, task_create_cfg) in tasks_for_this_rank: + # Set a unique seed for each task to ensure diversity in data collection. + task_seed = seed + task_id + task_cfg.policy.device = 'cuda' if task_cfg.policy.cuda and torch.cuda.is_available() else 'cpu' + task_cfg = compile_config(task_cfg, seed=task_seed, auto=True, create_cfg=task_create_cfg, save_cfg=True) + + policy.collect_mode.get_attribute('cfg').n_episode = task_cfg.policy.n_episode + policy.eval_mode.get_attribute('cfg').n_episode = task_cfg.policy.n_episode + + # Create environment managers for collection and evaluation. + env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(task_cfg.env) + collector_env = create_env_manager(task_cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) + evaluator_env = create_env_manager(task_cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + collector_env.seed(task_seed) + evaluator_env.seed(task_seed, dynamic_seed=False) + set_pkg_seed(task_seed, use_cuda=task_cfg.policy.cuda) + + # Create task-specific buffers, collectors, and evaluators. 
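+        # The assignment below assumes `task_cfg.policy.batch_size` is a per-task sequence (or dict)
+        # indexed by the global task_id. If it may also be a plain int, a defensive variant (sketch):
+        #   bs = task_cfg.policy.batch_size
+        #   replay_buffer.batch_size = bs[task_id] if isinstance(bs, (list, tuple, dict)) else bs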
+ replay_buffer = GameBuffer(task_cfg.policy) + replay_buffer.batch_size = task_cfg.policy.batch_size[task_id] + + collector = Collector( + env=collector_env, policy=policy.collect_mode, tb_logger=tb_logger, exp_name=task_cfg.exp_name, + policy_config=task_cfg.policy, task_id=task_id + ) + evaluator = Evaluator( + eval_freq=task_cfg.policy.eval_freq, n_evaluator_episode=task_cfg.env.n_evaluator_episode, + stop_value=task_cfg.env.stop_value, env=evaluator_env, policy=policy.eval_mode, + tb_logger=tb_logger, exp_name=task_cfg.exp_name, policy_config=task_cfg.policy, task_id=task_id + ) + + cfgs.append(task_cfg) + game_buffers.append(replay_buffer) + collectors.append(collector) + evaluators.append(evaluator) + + learner.call_hook('before_run') + + # ============================================================== + # 2. Main Training Loop + # ============================================================== + buffer_reanalyze_count = 0 + train_epoch = 0 + while True: + if learner.train_iter >= max_train_iter or collector.envstep >= max_env_step: + break + + # 2.1. Dynamic Batch Size Allocation (Optional) + if cfg.policy.allocated_batch_sizes: + # As training progresses, allow for a larger divergence in batch sizes. + clip_scale = np.clip(1 + (3 * train_epoch / 1000), 1, 4) + allocated_batch_sizes = allocate_batch_size(cfgs, game_buffers, alpha=1.0, clip_scale=clip_scale) + if rank == 0: + logging.info(f"Allocated batch sizes: {allocated_batch_sizes}") + for task_cfg, replay_buffer in zip(cfgs, game_buffers): + task_cfg.policy.batch_size = allocated_batch_sizes + policy._cfg.batch_size = allocated_batch_sizes + + # 2.2. Collection and Evaluation Phase + for task_cfg, collector, evaluator, replay_buffer in zip(cfgs, collectors, evaluators, game_buffers): + log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger, task_cfg.policy.task_id) + + # Determine exploration parameters for collection. + collect_kwargs = { + 'temperature': visit_count_temperature( + task_cfg.policy.manual_temperature_decay, task_cfg.policy.fixed_temperature_value, + task_cfg.policy.threshold_training_steps_for_final_temperature, trained_steps=learner.train_iter + ), + 'epsilon': 0.0 + } + if task_cfg.policy.eps.eps_greedy_exploration_in_collect: + epsilon_fn = get_epsilon_greedy_fn( + start=task_cfg.policy.eps.start, end=task_cfg.policy.eps.end, + decay=task_cfg.policy.eps.decay, type_=task_cfg.policy.eps.type + ) + collect_kwargs['epsilon'] = epsilon_fn(collector.envstep) + + # Evaluate the policy periodically. + if evaluator.should_eval(learner.train_iter): + logging.info(f'Rank {rank} evaluating task_id: {task_cfg.policy.task_id}...') + stop, reward = safe_eval(evaluator, learner, collector, rank, world_size) + if stop is None or reward is None: + logging.warning(f"Rank {rank} evaluation for task {task_cfg.policy.task_id} failed or timed out.") + else: + logging.info(f"Evaluation successful for task {task_cfg.policy.task_id}: stop={stop}, reward={reward}") + + # Collect new data. + logging.info(f'Rank {rank} collecting for task_id: {task_cfg.policy.task_id}...') + # NOTE: Resetting initial data is crucial in multi-task settings to avoid state leakage. + collector._policy.reset(reset_init_data=True) + new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) + + # Update the replay buffer. + replay_buffer.push_game_segments(new_data) + replay_buffer.remove_oldest_data_to_fit() + + # Periodically reanalyze the buffer to update value/policy targets with a more recent model. 
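+            # Worked example of the frequency semantics used below (numbers are illustrative):
+            #   buffer_reanalyze_freq = 0.5 -> reanalyze once every int(1 / 0.5) = 2 training epochs;
+            #   buffer_reanalyze_freq = 2 with update_per_collect = 10 -> reanalyze_interval = 10 // 2 = 5,
+            #   i.e. reanalysis fires at update steps 0 and 5 of each collect/train cycle.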
+ # This logic handles two cases for `buffer_reanalyze_freq`: + # Case 1: freq < 1 (e.g., 0.5) -> Reanalyze every `1/freq` training epochs. + if 0 < task_cfg.policy.buffer_reanalyze_freq < 1: + if (train_epoch % int(1 / task_cfg.policy.buffer_reanalyze_freq) == 0 and + replay_buffer.get_num_of_transitions() // task_cfg.policy.num_unroll_steps > + int(task_cfg.policy.reanalyze_batch_size / task_cfg.policy.reanalyze_partition)): + with EasyTimer() as timer: + replay_buffer.reanalyze_buffer(task_cfg.policy.reanalyze_batch_size, policy) + buffer_reanalyze_count += 1 + logging.info(f'Buffer reanalyze count: {buffer_reanalyze_count}, Time: {timer.value:.2f}s') + + logging.info(f'Rank {rank}: Data collection complete for task {task_cfg.policy.task_id}') + + # 2.3. Pre-Training Synchronization and Data Check + # Check if any buffer has insufficient data for training. + not_enough_data = any( + rb.get_num_of_transitions() < cfg.policy.total_batch_size / world_size for rb in game_buffers + ) + + try: + dist.barrier() + except Exception as e: + logging.error(f'Rank {rank}: Barrier failed before training with error {e}', exc_info=True) + break + + # 2.4. Training Phase + if not not_enough_data: + update_per_collect = cfg.policy.update_per_collect + for i in range(update_per_collect): + train_data_multi_task = [] + envstep_multi_task = sum(c.envstep for c in collectors) + + for task_cfg, replay_buffer in zip(cfgs, game_buffers): + batch_size = task_cfg.policy.batch_size[task_cfg.policy.task_id] + if replay_buffer.get_num_of_transitions() > batch_size: + # Case 2: freq >= 1 -> Reanalyze `freq` times per collection cycle (spread across updates). + if task_cfg.policy.buffer_reanalyze_freq >= 1: + reanalyze_interval = update_per_collect // task_cfg.policy.buffer_reanalyze_freq + if (i % reanalyze_interval == 0 and + replay_buffer.get_num_of_transitions() // task_cfg.policy.num_unroll_steps > + int(task_cfg.policy.reanalyze_batch_size / task_cfg.policy.reanalyze_partition)): + with EasyTimer() as timer: + replay_buffer.reanalyze_buffer(task_cfg.policy.reanalyze_batch_size, policy) + buffer_reanalyze_count += 1 + logging.info(f'Buffer reanalyze count: {buffer_reanalyze_count}, Time: {timer.value:.2f}s') + + # Sample data and append task_id for multi-task learning. + train_data = replay_buffer.sample(batch_size, policy) + train_data.append(task_cfg.policy.task_id) + train_data_multi_task.append(train_data) + else: + logging.warning( + f"Skipping training for task {task_cfg.policy.task_id}: insufficient data. " + f"Required: {batch_size}, Available: {replay_buffer.get_num_of_transitions()}" + ) + + if train_data_multi_task: + # DDP handles gradient synchronization automatically. + learner.train(train_data_multi_task, envstep_multi_task) + + # Synchronize after each training step to maintain consistency. + try: + dist.barrier() + except Exception as e: + logging.error(f'Rank {rank}: Barrier failed during training step with error {e}', exc_info=True) + break + else: + logging.warning(f"Rank {rank}: Skipping training cycle due to insufficient data in one or more buffers.") + + train_epoch += 1 + policy.recompute_pos_emb_diff_and_clear_cache() + + # 2.5. 
Post-Training Synchronization and Termination Check + try: + dist.barrier() + except Exception as e: + logging.error(f'Rank {rank}: Barrier failed after training cycle with error {e}', exc_info=True) + break + + learner.call_hook('after_run') + logging.info(f"Rank {rank}: Training finished.") + return policy \ No newline at end of file diff --git a/lzero/entry/train_unizero_segment.py b/lzero/entry/train_unizero_segment.py index c1ed74b16..0a33b789b 100644 --- a/lzero/entry/train_unizero_segment.py +++ b/lzero/entry/train_unizero_segment.py @@ -1,27 +1,25 @@ -import logging import os from functools import partial -from typing import Tuple, Optional +from typing import Optional, Tuple import torch import wandb from ding.config import compile_config -from ding.envs import create_env_manager -from ding.envs import get_vec_env_setting +from ding.envs import create_env_manager, get_vec_env_setting from ding.policy import create_policy from ding.rl_utils import get_epsilon_greedy_fn -from ding.utils import EasyTimer -from ding.utils import set_pkg_seed, get_rank, get_world_size +from ding.utils import EasyTimer, get_rank, get_world_size, set_pkg_seed from ding.worker import BaseLearner -from tensorboardX import SummaryWriter -from torch.utils.tensorboard import SummaryWriter - +from ditk import logging from lzero.entry.utils import log_buffer_memory_usage from lzero.policy import visit_count_temperature from lzero.policy.random_policy import LightZeroRandomPolicy from lzero.worker import MuZeroEvaluator as Evaluator from lzero.worker import MuZeroSegmentCollector as Collector -from .utils import random_collect, calculate_update_per_collect +from tensorboardX import SummaryWriter +from torch.utils.tensorboard import SummaryWriter + +from .utils import calculate_update_per_collect, random_collect timer = EasyTimer() @@ -76,7 +74,6 @@ def train_unizero_segment( evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) collector_env.seed(cfg.seed) - # collector_env.seed(cfg.seed, dynamic_seed=False) evaluator_env.seed(cfg.seed, dynamic_seed=False) set_pkg_seed(cfg.seed, use_cuda=torch.cuda.is_available()) @@ -154,7 +151,7 @@ def train_unizero_segment( collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep) # Evaluate policy performance - if evaluator.should_eval(learner.train_iter): + if learner.train_iter == 0 or evaluator.should_eval(learner.train_iter): stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) if stop: break diff --git a/lzero/entry/train_unizero_with_loss_landscape.py b/lzero/entry/train_unizero_with_loss_landscape.py index 356802263..c87ec2a98 100644 --- a/lzero/entry/train_unizero_with_loss_landscape.py +++ b/lzero/entry/train_unizero_with_loss_landscape.py @@ -1,4 +1,4 @@ -import logging +from ditk import logging import os from functools import partial from typing import Tuple, Optional diff --git a/lzero/entry/utils.py b/lzero/entry/utils.py index 40a90a387..dc8dacf0f 100644 --- a/lzero/entry/utils.py +++ b/lzero/entry/utils.py @@ -1,252 +1,668 @@ +# -*- coding: utf-8 -*- +""" +### 🛠️ Utility Modules + +- **`utils.py`** - Common utility functions library + - **Math & Tensor Utilities**: + - `symlog`, `inv_symlog` - Symmetric logarithm transformations + - `initialize_zeros_batch`, `initialize_pad_batch` - Batch initialization + + - **LoRA Utilities**: + - `freeze_non_lora_parameters` - Freeze non-LoRA parameters + + - **Task & Curriculum Learning Utilities**: + - `compute_task_weights` 
- Compute task weights + - `TemperatureScheduler` - Temperature scheduler + - `tasks_per_stage` - Calculate tasks per stage + - `compute_unizero_mt_normalized_stats` - Compute normalized statistics + - `allocate_batch_size` - Dynamically allocate batch sizes + + - **Distributed Training Utilities (DDP)**: + - `is_ddp_enabled` - Check if DDP is enabled + - `ddp_synchronize` - DDP synchronization + - `ddp_all_reduce_sum` - DDP all-reduce sum + + - **RL Workflow Utilities**: + - `calculate_update_per_collect` - Calculate updates per collection + - `random_collect` - Random policy data collection + - `convert_to_batch_for_unizero` - UniZero batch data conversion + - `create_unizero_loss_metrics` - Create loss metrics function + - `UniZeroDataLoader` - UniZero data loader + + - **Logging Utilities**: + - `log_module_trainable_status` - Log module trainable status + - `log_param_statistics` - Log parameter statistics + - `log_buffer_memory_usage` - Log buffer memory usage + - `log_buffer_run_time` - Log buffer runtime + +- **`__init__.py`** - Package initialization file + - Exports all training and evaluation entry functions + - Exports commonly used functions from utility modules + +""" + +# ============================================================================== +# Imports +# ============================================================================== +from __future__ import annotations + +import concurrent.futures +from ditk import logging +import math import os -from typing import Optional, Callable, Union, List, Tuple +import re +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +import numpy as np import psutil import torch import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F from pympler.asizeof import asizeof from tensorboardX import SummaryWriter -import torch -import torch.distributed as dist -def is_ddp_enabled(): +# ============================================================================== +# Placeholder Types for External Dependencies +# +# To ensure type hints work without having the full definitions of these complex +# external classes, we define them as `Any`. +# ============================================================================== +EasyDict = Any +Policy = Any +RandomPolicy = Any +ISerialCollector = Any +BaseEnvManager = Any +IBuffer = Any +GameBuffer = Any +BaseLearner = Any +Evaluator = Any +Collector = Any + + +# ============================================================================== +# Global Constants +# ============================================================================== + +# Timeout for evaluation process in seconds (200 minutes) +EVALUATION_TIMEOUT = 12000 + + +# ============================================================================== +# Mathematical & Tensor Utilities +# ============================================================================== + +def symlog(x: torch.Tensor) -> torch.Tensor: """ - Check if Distributed Data Parallel (DDP) is enabled by verifying if - PyTorch's distributed package is available and initialized. + Overview: + Applies the symlog transformation to a tensor, which is useful for + normalizing target values with large magnitude differences. + The transformation is defined as: symlog(x) = sign(x) * log(|x| + 1). + + Arguments: + - x (:obj:`torch.Tensor`): The input tensor. + + Returns: + - torch.Tensor: The tensor after applying the symlog transformation. 
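+
+    Examples (illustrative):
+        >>> x = torch.tensor([-10.0, 0.0, 10.0])
+        >>> symlog(x)              # ~ tensor([-2.3979, 0.0000, 2.3979])
+        >>> inv_symlog(symlog(x))  # recovers x up to floating-point error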
""" - return dist.is_available() and dist.is_initialized() + return torch.sign(x) * torch.log(torch.abs(x) + 1) + -def ddp_synchronize(): +def inv_symlog(x: torch.Tensor) -> torch.Tensor: """ - Perform a barrier synchronization across all processes in DDP mode. - Ensures all processes reach this point before continuing. + Overview: + Applies the inverse of the symlog transformation to a tensor, restoring + the original scale of the values. + The transformation is defined as: inv_symlog(x) = sign(x) * (exp(|x|) - 1). + + Arguments: + - x (:obj:`torch.Tensor`): The input tensor in symlog space. + + Returns: + - torch.Tensor: The tensor restored to its original scale. """ - if is_ddp_enabled(): - dist.barrier() + return torch.sign(x) * (torch.exp(torch.abs(x)) - 1) -def ddp_all_reduce_sum(tensor: torch.Tensor) -> torch.Tensor: + +# ============================================================================== +# LoRA (Low-Rank Adaptation) Utilities +# ============================================================================== + +# A compiled regex pattern to efficiently detect LoRA-related parameters. +# It matches parameter names ending with: +# - .lora_A or .lora_B (for LoRA weights) +# - .adapter_scales.{digit}.logit (for learnable scale parameters) +_LORA_PAT = re.compile(r"\.(?:lora_[AB]|adapter_scales\.\d+\.logit)$") + + +def _is_lora_param(name: str) -> bool: + """A helper function to check if a parameter name matches the LoRA pattern.""" + return bool(_LORA_PAT.search(name)) + + +def freeze_non_lora_parameters( + module: nn.Module, + freeze: bool = True, + *, + verbose: bool = False, +) -> Tuple[int, int]: """ - Perform an all-reduce operation (sum) on the given tensor across - all processes in DDP mode. Returns the reduced tensor. + Overview: + Freezes or un-freezes all parameters in a module that are not identified + as LoRA-related parameters. This is useful for curriculum learning stages + where the backbone model is frozen and only LoRA adapters are trained. Arguments: - - tensor (:obj:`torch.Tensor`): The input tensor to be reduced. + - module (:obj:`nn.Module`): The PyTorch module to process (e.g., a transformer). + - freeze (:obj:`bool`): If True, sets `requires_grad=False` for non-LoRA parameters. + If False, sets `requires_grad=True` for non-LoRA parameters. + - verbose (:obj:`bool`): If True, prints a summary of trainable and frozen parameters. Returns: - - torch.Tensor: The reduced tensor, summed across all processes. + - Tuple[int, int]: A tuple containing the number of frozen parameters and trainable parameters. """ - if is_ddp_enabled(): - dist.all_reduce(tensor, op=dist.ReduceOp.SUM) - return tensor - -def calculate_update_per_collect(cfg: 'EasyDict', new_data: List[List[torch.Tensor]], world_size: int = 1) -> int: + n_frozen = 0 + n_trainable = 0 + + for name, param in module.named_parameters(): + if _is_lora_param(name): + # LoRA-related parameters should always be trainable. + param.requires_grad = True + n_trainable += 1 + else: + # All other parameters are frozen or unfrozen based on the `freeze` flag. + param.requires_grad = not freeze + if param.requires_grad: + n_trainable += 1 + else: + n_frozen += 1 + + if verbose: + total = n_frozen + n_trainable + # Ensure total is not zero to avoid division by zero error. 
+ percentage_trainable = (n_trainable / total * 100) if total > 0 else 0 + print( + f"[freeze_non_lora] Trainable: {n_trainable}/{total} ({percentage_trainable:.1f}%), " + f"Frozen: {n_frozen}" + ) + return n_frozen, n_trainable + + +# ============================================================================== +# Task & Curriculum Learning Utilities +# ============================================================================== + +def compute_task_weights( + task_returns: Dict[str, float], + option: str = "symlog", + epsilon: float = 1e-6, + temperature: float = 1.0, + use_softmax: bool = False, + reverse: bool = False, + clip_min: float = 1e-2, + clip_max: float = 1.0, +) -> Dict[str, float]: """ - Calculate the number of updates to perform per data collection in a - Distributed Data Parallel (DDP) setting. This ensures that all GPUs - compute the same `update_per_collect` value, synchronized across processes. + Overview: + Calculates sampling weights for different tasks based on their returns (e.g., rewards or losses). + This function supports various normalization methods, softmax-based distribution, + proportional/inverse weighting, and weight clipping. Arguments: - - cfg: Configuration object containing policy settings. - - new_data (List[List[torch.Tensor]]): The newly collected data segments. - - world_size (int): The total number of processes. + - task_returns (:obj:`Dict[str, float]`): A dictionary mapping task IDs to their return values. + - option (:obj:`str`): Normalization method. One of ["symlog", "max-min", "run-max-min", "rank", "none"]. + - epsilon (:obj:`float`): A small value to prevent division by zero. + - temperature (:obj:`float`): A temperature parameter to control the sharpness of the weight distribution. + - use_softmax (:obj:`bool`): If True, use softmax to compute weights; otherwise, use direct normalization. + - reverse (:obj:`bool`): If True, weights are inversely proportional to returns; otherwise, directly proportional. + - clip_min (:obj:`float`): The minimum value to clip the final weights to. + - clip_max (:obj:`float`): The maximum value to clip the final weights to. Returns: - - int: The number of updates to perform per collection. + - Dict[str, float]: A dictionary mapping task IDs to their computed weights. """ - # Retrieve the update_per_collect setting from the configuration - update_per_collect = cfg.policy.update_per_collect + if not task_returns: + return {} + + task_ids = list(task_returns.keys()) + returns_tensor = torch.tensor(list(task_returns.values()), dtype=torch.float32) + + # Step 1: Normalize the returns based on the chosen option. + scaled_returns: torch.Tensor + if option == "symlog": + scaled_returns = symlog(returns_tensor) + elif option == "max-min": + min_val, max_val = returns_tensor.min(), returns_tensor.max() + scaled_returns = (returns_tensor - min_val) / (max_val - min_val + epsilon) + elif option == "run-max-min": + # Use function attributes to maintain state across calls, avoiding global variables. + compute_task_weights.RUNNING_MAX = max(compute_task_weights.RUNNING_MAX, returns_tensor.max().item()) + compute_task_weights.RUNNING_MIN = min(compute_task_weights.RUNNING_MIN, returns_tensor.min().item()) + scaled_returns = (returns_tensor - compute_task_weights.RUNNING_MIN) / \ + (compute_task_weights.RUNNING_MAX - compute_task_weights.RUNNING_MIN + epsilon) + elif option == "rank": + sorted_indices = torch.argsort(returns_tensor) + ranks = torch.empty_like(returns_tensor) + # Ranks are from 1 to N. 
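+        # Rank-based scaling depends only on the ordering of returns, so it is insensitive to outliers and scale.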
+ ranks[sorted_indices] = torch.arange(1, len(returns_tensor) + 1, dtype=torch.float32) + scaled_returns = ranks + elif option == "none": + scaled_returns = returns_tensor + else: + raise ValueError(f"Unsupported normalization option: {option}") + + # Step 2: Determine if weights should be proportional or inversely proportional to returns. + if reverse: + # Inverse proportion: smaller return -> higher weight. + raw_weights = 1.0 / (scaled_returns + epsilon) + else: + # Direct proportion: higher return -> higher weight. + raw_weights = scaled_returns + + # Step 3: Calculate final weights using either softmax or direct normalization. + final_weights: np.ndarray + safe_temperature = max(temperature, epsilon) + if use_softmax: + # Softmax provides a smooth distribution, often used with inverse weights. + # A higher beta (lower temperature) makes the distribution sharper. + beta = 1.0 / safe_temperature + # The sign depends on whether we want to favor high or low raw_weights. + # If reverse=True, raw_weights are high for low returns. We want to sample these more. + # Softmax(logits) gives higher probability to higher logits. + # So, logits should be proportional to the desired sampling probability. + logits = raw_weights if reverse else -raw_weights + final_weights = F.softmax(logits * beta, dim=0).numpy() + else: + # Direct normalization with temperature scaling. + scaled_weights = raw_weights**(1 / safe_temperature) + total_weight = scaled_weights.sum() + normalized_weights = scaled_weights / (total_weight + epsilon) + final_weights = normalized_weights.numpy() - if update_per_collect is None: - # If update_per_collect is not explicitly set, calculate it based on - # the number of collected transitions and the replay ratio. + # Step 4: Clip weights to the desired range and create the result dictionary. + weights_dict = { + task_id: np.clip(weight, clip_min, clip_max) + for task_id, weight in zip(task_ids, final_weights) + } - # The length of game_segment (i.e., len(game_segment.action_segment)) can be smaller than cfg.policy.game_segment_length if it represents the final segment of the game. - # On the other hand, its length will be less than cfg.policy.game_segment_length + padding_length when it is not the last game segment. Typically, padding_length is the sum of unroll_steps and td_steps. - collected_transitions_num = sum( - min(len(game_segment), cfg.policy.game_segment_length) - for game_segment in new_data[0] - ) + return weights_dict - if torch.cuda.is_available() and world_size > 1: - # Convert the collected transitions count to a GPU tensor for DDP operations. - collected_transitions_tensor = torch.tensor( - collected_transitions_num, dtype=torch.int64, device='cuda' - ) +# Initialize state for the 'run-max-min' option as function attributes. +compute_task_weights.RUNNING_MAX = -float('inf') +compute_task_weights.RUNNING_MIN = float('inf') - # Synchronize the collected transitions count across all GPUs using all-reduce. - total_collected_transitions = ddp_all_reduce_sum( - collected_transitions_tensor - ).item() - # Calculate update_per_collect based on the total synchronized transitions count. - update_per_collect = int(total_collected_transitions * cfg.policy.replay_ratio) +class TemperatureScheduler: + """ + Overview: + A scheduler to gradually adjust a temperature value over a specified number + of training steps. This can be used for exploration or weighting schemes. - # Ensure the computed update_per_collect is positive. 
- assert update_per_collect > 0, "update_per_collect must be positive" - else: - # If not using DDP, calculate update_per_collect directly from the local count. - update_per_collect = int(collected_transitions_num * cfg.policy.replay_ratio) + Arguments: + - initial_temp (:obj:`float`): The starting temperature. + - final_temp (:obj:`float`): The target temperature to be reached after `threshold_steps`. + - threshold_steps (:obj:`int`): The number of steps over which the temperature will anneal. + - mode (:obj:`str`): The annealing mode, either 'linear' or 'exponential'. + """ - return update_per_collect + def __init__(self, initial_temp: float, final_temp: float, threshold_steps: int, mode: str = 'linear'): + if mode not in ['linear', 'exponential']: + raise ValueError("Mode must be 'linear' or 'exponential'.") + self.initial_temp = initial_temp + self.final_temp = final_temp + self.threshold_steps = max(1, threshold_steps) # Avoid division by zero + self.mode = mode + + def get_temperature(self, current_step: int) -> float: + """ + Overview: + Calculates the temperature for the given training step. -def initialize_zeros_batch(observation_shape: Union[int, List[int], Tuple[int]], batch_size: int, device: str) -> torch.Tensor: + Arguments: + - current_step (:obj:`int`): The current training step. + + Returns: + - float: The calculated temperature for the current step. + """ + if current_step >= self.threshold_steps: + return self.final_temp + + progress = current_step / self.threshold_steps + + if self.mode == 'linear': + return self.initial_temp - (self.initial_temp - self.final_temp) * progress + else: # 'exponential' + # Exponential decay from initial_temp to final_temp + # T(t) = T_initial * (T_final / T_initial)^(t / N) + if self.initial_temp <= 0: + raise ValueError("Initial temperature must be positive for exponential decay.") + scale = self.final_temp / self.initial_temp + return self.initial_temp * (scale**progress) + + +def tasks_per_stage(unsolved: int, remain_lora: int) -> int: """ Overview: - Initialize a zeros tensor for batch observations based on the shape. This function is used to initialize the UniZero model input. + Calculates the number of tasks to assign per LoRA adapter stage. + It's the ceiling of the division of unsolved tasks by remaining adapters. + Arguments: - - observation_shape (:obj:`Union[int, List[int], Tuple[int]]`): The shape of the observation tensor. - - batch_size (:obj:`int`): The batch size. - - device (:obj:`str`): The device to store the tensor. + - unsolved (:obj:`int`): The number of tasks yet to be solved. + - remain_lora (:obj:`int`): The number of available LoRA adapters. + Returns: - - zeros (:obj:`torch.Tensor`): The zeros tensor. + - int: The number of tasks to be handled in the current stage, at least 1. 
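+
+    Examples:
+        An illustrative sketch with hypothetical values: 5 unsolved tasks and 2 remaining adapters
+        give ceil(5 / 2) = 3 tasks for the current stage.
+
+        >>> tasks_per_stage(5, 2)
+        3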
""" - if isinstance(observation_shape, (list, tuple)): - shape = [batch_size, *observation_shape] - elif isinstance(observation_shape, int): - shape = [batch_size, observation_shape] - else: - raise TypeError(f"observation_shape must be either an int, a list, or a tuple, but got {type(observation_shape).__name__}") + return max(1, math.ceil(unsolved / max(remain_lora, 1))) - return torch.zeros(shape).to(device) -def initialize_pad_batch(observation_shape: Union[int, List[int], Tuple[int]], batch_size: int, device: str, pad_token_id: int = 0) -> torch.Tensor: +def compute_unizero_mt_normalized_stats( + eval_returns: Dict[int, float], + human_scores: Dict[int, float], + random_scores: Dict[int, float] +) -> Tuple[Optional[float], Optional[float]]: """ Overview: - Initialize a tensor filled with `pad_token_id` for batch observations. - This function is designed to be flexible and can handle both textual - and non-textual observations: - - - For textual observations: it initializes `input_ids` with padding tokens, - ensuring consistent sequence lengths within a batch. - - For non-textual observations: it provides a convenient way to fill - observation tensors with a default of 0, - ensuring shape compatibility and preventing uninitialized values. + Calculates the Human-Normalized Mean and Median for a set of evaluation returns. + If no valid returns are provided, it returns (None, None). + Arguments: - - observation_shape (:obj:`Union[int, List[int], Tuple[int]]`): The shape of the observation tensor. - - batch_size (:obj:`int`): The batch size. - - device (:obj:`str`): The device to store the tensor. - - pad_token_id (:obj:`int`): The token ID (or placeholder value) used for padding. + - eval_returns (:obj:`Dict[int, float]`): A dictionary of evaluation returns per task ID. + - human_scores (:obj:`Dict[int, float]`): A dictionary of human expert scores per task ID. + - random_scores (:obj:`Dict[int, float]`): A dictionary of random policy scores per task ID. + Returns: - - padded_tensor (:obj:`torch.Tensor`): A tensor of the given shape, - filled with `pad_token_id`. + - Tuple[Optional[float], Optional[float]]: A tuple containing the human-normalized mean and median. + """ + normalized = [] + for tid, ret in eval_returns.items(): + if ret is None or tid not in human_scores or tid not in random_scores: + continue + denom = human_scores[tid] - random_scores[tid] + if denom == 0: + continue + normalized.append((ret - random_scores[tid]) / denom) + + if not normalized: + return None, None + + arr = np.asarray(normalized, dtype=np.float32) + return float(arr.mean()), float(np.median(arr)) + + +def allocate_batch_size( + cfgs: List[EasyDict], + game_buffers: List[GameBuffer], + alpha: float = 1.0, + clip_scale: int = 1 +) -> List[int]: """ - if isinstance(observation_shape, (list, tuple)): - shape = [batch_size, *observation_shape] - elif isinstance(observation_shape, int): - shape = [batch_size, observation_shape] + Overview: + Allocates batch sizes for different tasks inversely proportional to the + number of collected episodes for each task. It also dynamically clips + the batch size range to improve training stability. + + Arguments: + - cfgs (:obj:`List[EasyDict]`): A list of configuration objects for each task. + - game_buffers (:obj:`List[GameBuffer]`): A list of replay buffer instances for each task. + - alpha (:obj:`float`): A hyperparameter to control the degree of inverse proportionality. + - clip_scale (:obj:`int`): A scaling factor to determine the min/max batch size clip range. 
+ + Returns: + - List[int]: A list of allocated batch sizes for each task. + """ + # This function assumes a DDP environment. + if not dist.is_available() or not dist.is_initialized(): + # Fallback for non-DDP environment if needed, though the logic is DDP-centric. + logging.warning("allocate_batch_size is designed for DDP and may not work as expected.") + world_size = 1 + rank = 0 else: - raise TypeError(f"observation_shape must be int, list, or tuple, but got {type(observation_shape).__name__}") + world_size = dist.get_world_size() + rank = dist.get_rank() - return torch.full(shape, fill_value=pad_token_id, dtype=torch.float32, device=device) if pad_token_id == -1 else torch.full(shape, fill_value=pad_token_id, dtype=torch.long, device=device) + # Extract the number of collected episodes from each local buffer. + local_episodes = [buffer.num_of_collected_episodes for buffer in game_buffers] -def random_collect( - policy_cfg: 'EasyDict', # noqa - policy: 'Policy', # noqa - RandomPolicy: 'Policy', # noqa - collector: 'ISerialCollector', # noqa - collector_env: 'BaseEnvManager', # noqa - replay_buffer: 'IBuffer', # noqa - postprocess_data_fn: Optional[Callable] = None -) -> None: # noqa - assert policy_cfg.random_collect_episode_num > 0 + # Gather episode counts from all ranks. + all_task_episodes_list = [None for _ in range(world_size)] + dist.all_gather_object(all_task_episodes_list, local_episodes) - random_policy = RandomPolicy(cfg=policy_cfg, action_space=collector_env.env_ref.action_space) - # set the policy to random policy - collector.reset_policy(random_policy.collect_mode) + # Flatten the list of lists into a single list of episode counts for all tasks. + all_task_episodes = [ep for sublist in all_task_episodes_list for ep in sublist] - # set temperature for visit count distributions according to the train_iter, - # please refer to Appendix D in MuZero paper for details. - collect_kwargs = {'temperature': 1, 'epsilon': 0.0} + if rank == 0: + logging.info(f'All task collected episodes: {all_task_episodes}') - # Collect data by default config n_sample/n_episode. - new_data = collector.collect(n_episode=policy_cfg.random_collect_episode_num, train_iter=0, - policy_kwargs=collect_kwargs) + # Calculate weights inversely proportional to episode counts. + # Add 1 to avoid division by zero for new tasks. + inv_episodes = np.array([1.0 / (episodes + 1) for episodes in all_task_episodes]) + inv_sum = np.sum(inv_episodes) - if postprocess_data_fn is not None: - new_data = postprocess_data_fn(new_data) + # Total batch size is assumed to be consistent across configs. + total_batch_size = cfgs[0].policy.total_batch_size - # save returned new_data collected by the collector - replay_buffer.push_game_segments(new_data) - # remove the oldest data if the replay buffer is full. - replay_buffer.remove_oldest_data_to_fit() + # Define dynamic clipping range for batch sizes. + avg_batch_size = total_batch_size / len(all_task_episodes) + min_batch_size = avg_batch_size / clip_scale + max_batch_size = avg_batch_size * clip_scale - # restore the policy - collector.reset_policy(policy.collect_mode) + # Calculate batch sizes based on weights, apply alpha for smoothing. + task_weights = (inv_episodes / inv_sum)**alpha + batch_sizes = total_batch_size * task_weights + + # Clip and convert to integers. 
+ batch_sizes = np.clip(batch_sizes, min_batch_size, max_batch_size) + batch_sizes = [int(size) for size in batch_sizes] + return batch_sizes -def log_buffer_memory_usage(train_iter: int, buffer: "GameBuffer", writer: SummaryWriter) -> None: + +# ============================================================================== +# Distributed Data Parallel (DDP) Utilities +# ============================================================================== + +def is_ddp_enabled() -> bool: """ Overview: - Log the memory usage of the buffer and the current process to TensorBoard. - Arguments: - - train_iter (:obj:`int`): The current training iteration. - - buffer (:obj:`GameBuffer`): The game buffer. - - writer (:obj:`SummaryWriter`): The TensorBoard writer. + Checks if the environment is set up for Distributed Data Parallel (DDP) training. + + Returns: + - bool: True if `torch.distributed` is available and initialized, False otherwise. """ - # "writer is None" means we are in a slave process in the DDP setup. - if writer is not None: - writer.add_scalar('Buffer/num_of_all_collected_episodes', buffer.num_of_collected_episodes, train_iter) - writer.add_scalar('Buffer/num_of_game_segments', len(buffer.game_segment_buffer), train_iter) - writer.add_scalar('Buffer/num_of_transitions', len(buffer.game_segment_game_pos_look_up), train_iter) + return dist.is_available() and dist.is_initialized() - game_segment_buffer = buffer.game_segment_buffer - # Calculate the amount of memory occupied by self.game_segment_buffer (in bytes). - buffer_memory_usage = asizeof(game_segment_buffer) +def ddp_synchronize() -> None: + """ + Overview: + Performs a barrier synchronization across all processes in a DDP group. + This ensures that all processes reach this point before any of them proceed. + """ + if is_ddp_enabled(): + dist.barrier() - # Convert buffer_memory_usage to megabytes (MB). - buffer_memory_usage_mb = buffer_memory_usage / (1024 * 1024) - # Record the memory usage of self.game_segment_buffer to TensorBoard. - writer.add_scalar('Buffer/memory_usage/game_segment_buffer', buffer_memory_usage_mb, train_iter) +def ddp_all_reduce_sum(tensor: torch.Tensor) -> torch.Tensor: + """ + Overview: + Performs an all-reduce operation (sum) on a given tensor across all + processes in the DDP group. - # Get the amount of memory currently used by the process (in bytes). - process = psutil.Process(os.getpid()) - process_memory_usage = process.memory_info().rss + Arguments: + - tensor (:obj:`torch.Tensor`): The tensor to be reduced. - # Convert process_memory_usage to megabytes (MB). - process_memory_usage_mb = process_memory_usage / (1024 * 1024) + Returns: + - torch.Tensor: The reduced tensor, with values summed across all processes. + """ + if is_ddp_enabled(): + dist.all_reduce(tensor, op=dist.ReduceOp.SUM) + return tensor - # Record the memory usage of the process to TensorBoard. - writer.add_scalar('Buffer/memory_usage/process', process_memory_usage_mb, train_iter) +# ============================================================================== +# Reinforcement Learning Workflow Utilities +# ============================================================================== -def log_buffer_run_time(train_iter: int, buffer: "GameBuffer", writer: SummaryWriter) -> None: +def calculate_update_per_collect( + cfg: EasyDict, + new_data: List[List[torch.Tensor]], + world_size: int = 1 +) -> int: """ Overview: - Log the average runtime metrics of the buffer to TensorBoard. 
+ Calculates the number of training updates to perform per data collection cycle. + In a DDP setting, it synchronizes transition counts across all GPUs to ensure + a consistent `update_per_collect` value. + Arguments: - - train_iter (:obj:`int`): The current training iteration. - - buffer (:obj:`GameBuffer`): The game buffer containing runtime metrics. - - writer (:obj:`SummaryWriter`): The TensorBoard writer for logging metrics. + - cfg (:obj:`EasyDict`): The configuration object containing policy settings. + It's expected to have `cfg.policy.update_per_collect`, + `cfg.policy.replay_ratio`, etc. + - new_data (:obj:`List[List[torch.Tensor]]`): The newly collected data segments. + - world_size (:obj:`int`): The total number of DDP processes. - .. note:: - "writer is None" indicates that the function is being called in a slave process in the DDP setup. + Returns: + - int: The number of updates to perform. """ - if writer is not None: - sample_times = buffer.sample_times + update_per_collect = cfg.policy.get('update_per_collect') + + if update_per_collect is not None: + return update_per_collect - if sample_times == 0: - return + # If not explicitly set, calculate based on replay ratio. + # Note: A game segment's length can be less than `game_segment_length` if it's the + # final segment of an episode. + collected_transitions_num = sum( + min(len(game_segment), cfg.policy.game_segment_length) + for game_segment in new_data[0] + ) - # Calculate and log average reanalyze time. - average_reanalyze_time = buffer.compute_target_re_time / sample_times - writer.add_scalar('Buffer/average_reanalyze_time', average_reanalyze_time, train_iter) + if torch.cuda.is_available() and world_size > 1: + # In DDP, synchronize the transition count across all GPUs. + collected_transitions_tensor = torch.tensor( + collected_transitions_num, dtype=torch.int64, device='cuda' + ) + total_collected_transitions = ddp_all_reduce_sum( + collected_transitions_tensor + ).item() + updates = int(total_collected_transitions * cfg.policy.replay_ratio) + else: + # In a single-process setup. + updates = int(collected_transitions_num * cfg.policy.replay_ratio) - # Calculate and log average origin search time. - average_origin_search_time = buffer.origin_search_time / sample_times - writer.add_scalar('Buffer/average_origin_search_time', average_origin_search_time, train_iter) + return max(1, updates) # Ensure at least one update. - # Calculate and log average reuse search time. - average_reuse_search_time = buffer.reuse_search_time / sample_times - writer.add_scalar('Buffer/average_reuse_search_time', average_reuse_search_time, train_iter) - # Calculate and log average active root number. - average_active_root_num = buffer.active_root_num / sample_times - writer.add_scalar('Buffer/average_active_root_num', average_active_root_num, train_iter) - # Reset the time records in the buffer. - buffer.reset_runtime_metrics() +def random_collect( + policy_cfg: EasyDict, + policy: Policy, + RandomPolicy: Callable, + collector: ISerialCollector, + collector_env: BaseEnvManager, + replay_buffer: IBuffer, + postprocess_data_fn: Optional[Callable] = None +) -> None: + """ + Overview: + Performs an initial data collection phase using a random policy to populate + the replay buffer before training begins. + + Arguments: + - policy_cfg (:obj:`EasyDict`): Configuration for the policy. + - policy (:obj:`Policy`): The main training policy instance. + - RandomPolicy (:obj:`Callable`): A constructor or class for creating a random policy. 
+ - collector (:obj:`ISerialCollector`): The data collector instance. + - collector_env (:obj:`BaseEnvManager`): The environment manager. + - replay_buffer (:obj:`IBuffer`): The replay buffer to store collected data. + - postprocess_data_fn (:obj:`Optional[Callable]`): An optional function to process data after collection. + """ + random_collect_episode_num = policy_cfg.get('random_collect_episode_num', 0) + if random_collect_episode_num <= 0: + return + + random_policy = RandomPolicy(cfg=policy_cfg, action_space=collector_env.env_ref.action_space) + collector.reset_policy(random_policy.collect_mode) + + # Use neutral MCTS parameters for random collection. + collect_kwargs = {'temperature': 1.0, 'epsilon': 0.0} + + new_data = collector.collect( + n_episode=random_collect_episode_num, + train_iter=0, + policy_kwargs=collect_kwargs + ) + + if postprocess_data_fn: + new_data = postprocess_data_fn(new_data) + + replay_buffer.push_game_segments(new_data) + replay_buffer.remove_oldest_data_to_fit() + + # Restore the original policy to the collector. + collector.reset_policy(policy.collect_mode) + + +def safe_eval( + evaluator: Evaluator, + learner: BaseLearner, + collector: Collector, + rank: int, + world_size: int, + timeout: int = EVALUATION_TIMEOUT +) -> Tuple[Optional[bool], Optional[Any]]: + """ + Overview: + Safely executes an evaluation task with a timeout to prevent hangs. + This function runs the evaluation in a separate thread and enforces a timeout + to ensure the training process doesn't get stuck during evaluation. + + Arguments: + - evaluator (:obj:`Evaluator`): The evaluator instance. + - learner (:obj:`BaseLearner`): The learner instance, used for saving checkpoints. + - collector (:obj:`Collector`): The data collector instance, used to get current envstep. + - rank (:obj:`int`): The rank of the current process in distributed training. + - world_size (:obj:`int`): The total number of processes in distributed training. + - timeout (:obj:`int`): The maximum time (in seconds) to wait for evaluation. Defaults to EVALUATION_TIMEOUT. + + Returns: + - Tuple[Optional[bool], Optional[Any]]: A tuple containing: + - stop_flag: Boolean indicating if training should stop (None if timeout/error) + - reward_dict: Dictionary containing evaluation metrics (None if timeout/error) + """ + try: + logging.info(f"========= Evaluation starting on Rank {rank}/{world_size} =========") + # Reset the stop_event to ensure it is not set before each evaluation. + evaluator.stop_event.clear() + + with concurrent.futures.ThreadPoolExecutor() as executor: + # Submit the evaluation task to run in a separate thread. + future = executor.submit( + evaluator.eval, + learner.save_checkpoint, + learner.train_iter, + collector.envstep + ) + try: + stop_flag, reward = future.result(timeout=timeout) + logging.info(f"====== Evaluation finished on Rank {rank}/{world_size} ======") + return stop_flag, reward + except concurrent.futures.TimeoutError: + # If a timeout occurs, set the stop_event to signal the evaluation thread to stop. + evaluator.stop_event.set() + logging.error( + f"Evaluation timed out on Rank {rank}/{world_size} after {timeout} seconds. " + f"Continuing training." 
+ ) + return None, None + + except Exception as e: + logging.error( + f"An error occurred during evaluation on Rank {rank}/{world_size}: {e}", + exc_info=True + ) + return None, None def convert_to_batch_for_unizero(batch_data, policy_cfg, reward_support, value_support): @@ -332,7 +748,7 @@ def create_unizero_loss_metrics(policy): Returns: - compute_metrics (:obj:`Callable`): Function that computes losses for a batch of data """ - import logging + from ditk import logging # Get reward_support and value_support from policy reward_support = policy.reward_support @@ -426,3 +842,205 @@ def __iter__(self): def __len__(self): """Return the total number of batches""" return self.num_batches + +# ============================================================================== +# Logging Utilities +# ============================================================================== + +def log_module_trainable_status( + module: nn.Module, + module_name: str, + logger: logging.Logger +) -> None: + """ + Overview: + Logs the detailed trainable/frozen status of all parameters within a given module. + + Arguments: + - module (:obj:`nn.Module`): The module to inspect (e.g., a ViT Encoder). + - module_name (:obj:`str`): The name of the module for logging purposes. + - logger (:obj:`logging.Logger`): The logger instance to use for output. + """ + logger.info(f"--- Parameter Status Details for Module: '{module_name}' ---") + + total_params = 0 + trainable_params = 0 + + param_list = list(module.named_parameters()) + if not param_list: + logger.info(" - No parameters found in this module.") + return + + for name, param in param_list: + total_params += param.numel() + status = "Trainable" if param.requires_grad else "Frozen" + logger.info(f" - {name:<60} | Shape: {str(param.shape):<25} | Status: {status}") + if param.requires_grad: + trainable_params += param.numel() + + logger.info(f"--- Summary for Module: '{module_name}' ---") + logger.info(f" - Total Parameters: {total_params:,}") + logger.info(f" - Trainable Parameters: {trainable_params:,}") + if total_params > 0: + percentage = 100 * trainable_params / total_params + logger.info(f" - Trainable Percentage: {percentage:.4f}%") + logger.info("-" * (len(module_name) + 40)) + + +def log_param_statistics(model: nn.Module, logger: logging.Logger) -> None: + """ + Overview: + Logs a concise summary of the number and size of trainable versus total + parameters in a model. + + Arguments: + - model (:obj:`nn.Module`): The model to analyze. + - logger (:obj:`logging.Logger`): The logger instance for output. + """ + n_tensors_total = sum(1 for _ in model.parameters()) + n_tensors_train = sum(1 for p in model.parameters() if p.requires_grad) + + n_elems_total = sum(p.numel() for p in model.parameters()) + n_elems_train = sum(p.numel() for p in model.parameters() if p.requires_grad) + + logger.info( + f'Trainable Parameters: ' + f'{n_tensors_train}/{n_tensors_total} tensors | ' + f'{n_elems_train:,}/{n_elems_total:,} elements ' + f'({n_elems_train/1e6:.2f}M / {n_elems_total/1e6:.2f}M)' + ) + + +def log_buffer_memory_usage( + train_iter: int, + buffer: GameBuffer, + writer: SummaryWriter, + task_id: int = 0 +) -> None: + """ + Overview: + Logs the memory usage of the replay buffer and the current process to TensorBoard. + + Arguments: + - train_iter (:obj:`int`): The current training iteration. + - buffer (:obj:`GameBuffer`): The replay buffer instance. + - writer (:obj:`SummaryWriter`): The TensorBoard writer. 
+ - task_id (:obj:`int`): An optional ID to distinguish logs for different tasks. + """ + # In DDP, only the main process should write to TensorBoard. + if writer is None: + return + + prefix = f"Buffer/Task_{task_id}" + writer.add_scalar(f'{prefix}/num_collected_episodes', buffer.num_of_collected_episodes, train_iter) + writer.add_scalar(f'{prefix}/num_game_segments', len(buffer.game_segment_buffer), train_iter) + writer.add_scalar(f'{prefix}/num_transitions', len(buffer.game_segment_game_pos_look_up), train_iter) + + # Calculate and log memory usage of the main buffer component. + buffer_memory_bytes = asizeof(buffer.game_segment_buffer) + buffer_memory_mb = buffer_memory_bytes / (1024 * 1024) + writer.add_scalar(f'{prefix}/memory_usage_mb/game_segment_buffer', buffer_memory_mb, train_iter) + + # Get and log total memory usage of the current process. + process = psutil.Process(os.getpid()) + process_memory_bytes = process.memory_info().rss + process_memory_mb = process_memory_bytes / (1024 * 1024) + writer.add_scalar(f'{prefix}/memory_usage_mb/process', process_memory_mb, train_iter) + + +def log_buffer_run_time(train_iter: int, buffer: GameBuffer, writer: SummaryWriter) -> None: + """ + Overview: + Logs average runtime metrics related to buffer operations (e.g., sampling, search) + to TensorBoard. + + Arguments: + - train_iter (:obj:`int`): The current training iteration. + - buffer (:obj:`GameBuffer`): The buffer instance containing runtime metrics. + - writer (:obj:`SummaryWriter`): The TensorBoard writer. + """ + if writer is None or buffer.sample_times == 0: + return + + sample_times = buffer.sample_times + writer.add_scalar('Buffer/avg_reanalyze_time_ms', (buffer.compute_target_re_time / sample_times) * 1000, train_iter) + writer.add_scalar('Buffer/avg_origin_search_time_ms', (buffer.origin_search_time / sample_times) * 1000, train_iter) + writer.add_scalar('Buffer/avg_reuse_search_time_ms', (buffer.reuse_search_time / sample_times) * 1000, train_iter) + writer.add_scalar('Buffer/avg_active_root_num', buffer.active_root_num / sample_times, train_iter) + + # Reset metrics after logging to prepare for the next interval. 
+ buffer.reset_runtime_metrics() + + +# ============================================================================== +# Example Usage +# ============================================================================== +if __name__ == '__main__': + # Configure a basic logger to see output from functions with `verbose=True` + logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') + + print("\n--- Example for `compute_task_weights` ---") + task_rewards_list = [ + {"task1": 10, "task2": 100, "task3": 1000, "task4": 500, "task5": 300}, + {"task1": 1, "task2": 10, "task3": 100, "task4": 1000, "task5": 10000}, + {"task1": 0.1, "task2": 0.5, "task3": 0.9, "task4": 5, "task5": 10}, + ] + + for i, task_rewards in enumerate(task_rewards_list, start=1): + print(f"\n--- Case {i} ---") + print(f"Original Rewards: {task_rewards}") + + # Example 1: Using 'none' normalization (proportional to raw values) + weights_none = compute_task_weights(task_rewards, option="none", use_softmax=False) + print(f"Weights (proportional to raw values): {weights_none}") + + # Example 2: Using 'symlog' normalization + weights_symlog = compute_task_weights(task_rewards, option="symlog", use_softmax=False) + print(f"Weights (with symlog normalization): {weights_symlog}") + + # Example 3: Using 'rank' normalization and softmax with inverse proportion + weights_rank_softmax = compute_task_weights(task_rewards, option="rank", use_softmax=True, reverse=True) + print(f"Weights (inverse rank with softmax): {weights_rank_softmax}") + + print("\n--- Example for `freeze_non_lora` ---") + + # ========================================================================== + # FIX: The nn.Parameter must be wrapped in an nn.Module subclass to be + # placed inside an nn.ModuleDict. + # ========================================================================== + class AdapterScale(nn.Module): + """A simple nn.Module wrapper for a single learnable parameter.""" + def __init__(self): + super().__init__() + self.logit = nn.Parameter(torch.randn(1)) + + # Create a dummy model to demonstrate freezing + class DummyModel(nn.Module): + def __init__(self): + super().__init__() + self.backbone = nn.Linear(10, 10) + self.layer1 = nn.Linear(10, 10) + # Simulate LoRA parameters with correct naming + self.layer1.lora_A = nn.Parameter(torch.randn(10, 2)) + self.layer1.lora_B = nn.Parameter(torch.randn(2, 10)) + + # Correctly structure the adapter_scales using the wrapper module. + # This ensures that the value associated with key '0' is a valid nn.Module. 
+        self.adapter_scales = nn.ModuleDict({
+            '0': AdapterScale()
+        })
+
+    model = DummyModel()
+    print("Initial parameter status:")
+    log_module_trainable_status(model, "DummyModel", logging.getLogger())
+
+    print("\nFreezing non-LoRA parameters...")
+    freeze_non_lora_parameters(model, freeze=True, verbose=True)
+    print("\nParameter status after freezing:")
+    log_module_trainable_status(model, "DummyModel", logging.getLogger())
+
+    print("\nUn-freezing non-LoRA parameters...")
+    freeze_non_lora_parameters(model, freeze=False, verbose=True)
+    print("\nParameter status after un-freezing:")
+    log_module_trainable_status(model, "DummyModel", logging.getLogger())
+
diff --git a/lzero/mcts/buffer/game_buffer.py b/lzero/mcts/buffer/game_buffer.py
index 7ab1fee8a..7cd4308b2 100644
--- a/lzero/mcts/buffer/game_buffer.py
+++ b/lzero/mcts/buffer/game_buffer.py
@@ -102,22 +102,23 @@ def _make_batch(self, orig_data: Any, reanalyze_ratio: float) -> Tuple[Any]:
         """
         pass
 
-    def _sample_orig_data(self, batch_size: int) -> Tuple:
+    def _sample_orig_data(self, batch_size: int, print_priority_logs: bool = False) -> Tuple:
         """
         Overview:
-            sample orig_data that contains:
-                game_segment_list: a list of game segments
-                pos_in_game_segment_list: transition index in game (relative index)
-                batch_index_list: the index of start transition of sampled minibatch in replay buffer
-                weights_list: the weight concerning the priority
-                make_time: the time the batch is made (for correctly updating replay buffer when data is deleted)
+            Sample original data which includes:
+                - game_segment_list: A list of game segments.
+                - pos_in_game_segment_list: Transition index in the game (relative index).
+                - batch_index_list: The index of the start transition of the sampled mini-batch in the replay buffer.
+                - weights_list: The weight concerning the priority.
+                - make_time: The time the batch is made (for correctly updating the replay buffer when data is deleted).
         Arguments:
-            - batch_size (:obj:`int`): batch size
-            - beta: float the parameter in PER for calculating the priority
+            - batch_size (:obj:`int`): The size of the batch.
+            - print_priority_logs (:obj:`bool`): Whether to print logs related to priority statistics, defaults to False.
""" - assert self._beta > 0 + assert self._beta > 0, "Beta should be greater than 0" num_of_transitions = self.get_num_of_transitions() - if self._cfg.use_priority is False: + if not self._cfg.use_priority: + # If priority is not used, set all priorities to 1 self.game_pos_priorities = np.ones_like(self.game_pos_priorities) # +1e-6 for numerical stability @@ -126,20 +127,21 @@ def _sample_orig_data(self, batch_size: int) -> Tuple: # sample according to transition index batch_index_list = np.random.choice(num_of_transitions, batch_size, p=probs, replace=False) - - if self._cfg.reanalyze_outdated is True: - # NOTE: used in reanalyze part + + if self._cfg.reanalyze_outdated: + # Sort the batch indices if reanalyze is enabled batch_index_list.sort() - + + # Calculate weights for the sampled transitions weights_list = (num_of_transitions * probs[batch_index_list]) ** (-self._beta) - weights_list /= weights_list.max() + weights_list /= weights_list.max() # Normalize weights game_segment_list = [] pos_in_game_segment_list = [] for idx in batch_index_list: game_segment_idx, pos_in_game_segment = self.game_segment_game_pos_look_up[idx] - game_segment_idx -= self.base_idx + game_segment_idx -= self.base_idx # Adjust index based on base index game_segment = self.game_segment_buffer[game_segment_idx] game_segment_list.append(game_segment) @@ -205,7 +207,13 @@ def _sample_orig_data(self, batch_size: int) -> Tuple: if pos_in_game_segment >= self._cfg.game_segment_length: pos_in_game_segment = np.random.choice(self._cfg.game_segment_length, 1).item() - segment_len = len(game_segment.action_segment) + # Compatibility handling for both GameSegment objects and list data (for unittests) + try: + segment_len = len(game_segment.action_segment) + except (AttributeError, TypeError): + # For unittest compatibility: when game_segment is a list instead of GameSegment object + segment_len = len(game_segment) + if pos_in_game_segment >= segment_len - 1: # If the segment is very short (length 0 or 1), we can't randomly sample a position # before the last one. The only safe position is 0. @@ -220,115 +228,152 @@ def _sample_orig_data(self, batch_size: int) -> Tuple: pos_in_game_segment_list.append(pos_in_game_segment) - make_time = [time.time() for _ in range(len(batch_index_list))] - orig_data = (game_segment_list, pos_in_game_segment_list, batch_index_list, weights_list, make_time) - return orig_data - - def _sample_orig_reanalyze_batch(self, batch_size: int) -> Tuple: - """ - Overview: - This function samples a batch of game segments for reanalysis from the replay buffer. - It uses priority sampling based on the `reanalyze_time` of each game segment, with segments - that have been reanalyzed more frequently receiving lower priority. - - The function returns a tuple containing information about the sampled game segments, - including their positions within each segment and the time the batch was created. - Arguments: - - batch_size (:obj:`int`): - The number of samples to draw in this batch. - - Returns: - - Tuple: - A tuple containing the following elements: - - game_segment_list: A list of the sampled game segments. - - pos_in_game_segment_list: A list of indices representing the position of each transition - within its corresponding game segment. - - batch_index_list: The indices of the sampled game segments in the replay buffer. - - make_time: A list of timestamps (set to `0` in this implementation) indicating when - the batch was created. - - Key Details: - 1. 
**Priority Sampling**: - Game segments are sampled based on a probability distribution calculated using - the `reanalyze_time` of each segment. Segments that have been reanalyzed more frequently - are less likely to be selected. - 2. **Segment Slicing**: - Each selected game segment is sampled at regular intervals determined by the - `num_unroll_steps` parameter. Up to `samples_per_segment` transitions are sampled - from each selected segment. - 3. **Handling Extra Samples**: - If the `batch_size` is not perfectly divisible by the number of samples per segment, - additional segments are sampled to make up the difference. - 4. **Reanalyze Time Update**: - The `reanalyze_time` attribute of each sampled game segment is incremented to reflect - that it has been selected for reanalysis again. - Raises: - - ValueError: - If the `game_segment_length` is too small to accommodate the `num_unroll_steps`. - """ - train_sample_num = len(self.game_segment_buffer) - assert self._cfg.reanalyze_partition <= 0.75, "The reanalyze partition should be less than 0.75." - valid_sample_num = int(train_sample_num * self._cfg.reanalyze_partition) - - # Calculate the number of samples per segment - samples_per_segment = self._cfg.game_segment_length // self._cfg.num_unroll_steps - - # Make sure that the batch size can be divided by the number of samples per segment - if samples_per_segment == 0: - raise ValueError("The game segment length is too small for num_unroll_steps.") - - # Calculate the number of samples per segment - batch_size_per_segment = batch_size // samples_per_segment - - # If the batch size cannot be divided, process the remainder part - extra_samples = batch_size % samples_per_segment - - # We use the reanalyze_time in the game_segment_buffer to generate weights - reanalyze_times = np.array([segment.reanalyze_time for segment in self.game_segment_buffer[:valid_sample_num]]) - - # Calculate weights: the larger the reanalyze_time, the smaller the weight (use exp(-reanalyze_time)) - base_decay_rate = 100 - decay_rate = base_decay_rate / valid_sample_num - weights = np.exp(-decay_rate * reanalyze_times) - - # Normalize the weights to a probability distribution - probabilities = weights / np.sum(weights) - - # Sample game segments according to the probabilities - selected_game_segments = np.random.choice(valid_sample_num, batch_size_per_segment, replace=False, - p=probabilities) - - # If there are extra samples to be allocated, randomly select some game segments and sample again - if extra_samples > 0: - extra_game_segments = np.random.choice(valid_sample_num, extra_samples, replace=False, p=probabilities) - selected_game_segments = np.concatenate((selected_game_segments, extra_game_segments)) - - game_segment_list = [] - pos_in_game_segment_list = [] - batch_index_list = [] - - for game_segment_idx in selected_game_segments: - game_segment_idx -= self.base_idx - game_segment = self.game_segment_buffer[game_segment_idx] - - # Update reanalyze_time only once - game_segment.reanalyze_time += 1 - - # The sampling position should be 0, 0 + num_unroll_steps, ... 
(integer multiples of num_unroll_steps) - for i in range(samples_per_segment): - game_segment_list.append(game_segment) - pos_in_game_segment = i * self._cfg.num_unroll_steps - if pos_in_game_segment >= len(game_segment): - pos_in_game_segment = np.random.choice(len(game_segment), 1).item() - pos_in_game_segment_list.append(pos_in_game_segment) - batch_index_list.append(game_segment_idx) + # make_time = [time.time() for _ in range(len(batch_index_list))] # Set the make_time for each sample (set to 0 for now, but can be the actual time if needed). make_time = [0. for _ in range(len(batch_index_list))] - orig_data = (game_segment_list, pos_in_game_segment_list, batch_index_list, [], make_time) + orig_data = (game_segment_list, pos_in_game_segment_list, batch_index_list, weights_list, make_time) + + if print_priority_logs: + print(f"Sampled batch indices: {batch_index_list}") + print(f"Sampled priorities: {self.game_pos_priorities[batch_index_list]}") + print(f"Sampled weights: {weights_list}") + return orig_data + def _sample_orig_reanalyze_batch(self, batch_size: int) -> Tuple: + """ + Overview: + This function samples a batch of game segments for reanalysis from the replay buffer. + It uses priority sampling based on the `reanalyze_time` of each game segment, with segments + that have been reanalyzed more frequently receiving lower priority. + + The function returns a tuple containing information about the sampled game segments, + including their positions within each segment and the time the batch was created. + Arguments: + - batch_size (:obj:`int`): + The number of samples to draw in this batch. + + Returns: + - Tuple: + A tuple containing the following elements: + - game_segment_list: A list of the sampled game segments. + - pos_in_game_segment_list: A list of indices representing the position of each transition + within its corresponding game segment. + - batch_index_list: The indices of the sampled game segments in the replay buffer. + - make_time: A list of timestamps (set to `0` in this implementation) indicating when + the batch was created. + + Key Details: + 1. **Priority Sampling**: + Game segments are sampled based on a probability distribution calculated using + the `reanalyze_time` of each segment. Segments that have been reanalyzed more frequently + are less likely to be selected. + 2. **Segment Slicing**: + Each selected game segment is sampled at regular intervals determined by the + `num_unroll_steps` parameter. Up to `samples_per_segment` transitions are sampled + from each selected segment. + 3. **Handling Extra Samples**: + If the `batch_size` is not perfectly divisible by the number of samples per segment, + additional segments are sampled to make up the difference. + 4. **Reanalyze Time Update**: + The `reanalyze_time` attribute of each sampled game segment is incremented to reflect + that it has been selected for reanalysis again. + Raises: + - ValueError: + If the `game_segment_length` is too small to accommodate the `num_unroll_steps`. + """ + train_sample_num = len(self.game_segment_buffer) + assert self._cfg.reanalyze_partition <= 0.75, "The reanalyze partition should be less than 0.75." 
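+        # For example, with 1000 stored game segments and reanalyze_partition = 0.5, only the first
+        # 500 segments of game_segment_buffer are candidates for reanalysis sampling.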
+ valid_sample_num = int(train_sample_num * self._cfg.reanalyze_partition) + + # Calculate the number of samples per segment + samples_per_segment = self._cfg.game_segment_length // self._cfg.num_unroll_steps + + # Make sure that the batch size can be divided by the number of samples per segment + if samples_per_segment == 0: + raise ValueError("The game segment length is too small for num_unroll_steps.") + + # Calculate the number of samples per segment + batch_size_per_segment = batch_size // samples_per_segment + + # If the batch size cannot be divided, process the remainder part + extra_samples = batch_size % samples_per_segment + + # We use the reanalyze_time in the game_segment_buffer to generate weights + reanalyze_times = np.array([segment.reanalyze_time for segment in self.game_segment_buffer[:valid_sample_num]]) + + # Calculate weights: the larger the reanalyze_time, the smaller the weight (use exp(-reanalyze_time)) + base_decay_rate = 100 + # Add a small epsilon to avoid division by zero if valid_sample_num is 0 + decay_rate = base_decay_rate / (valid_sample_num + 1e-6) + weights = np.exp(-decay_rate * reanalyze_times) + + # Normalize the weights to a probability distribution, handle case where sum is zero + sum_weights = np.sum(weights) + if sum_weights > 0: + probabilities = weights / sum_weights + else: + # If all weights are zero, use a uniform distribution + probabilities = np.ones(valid_sample_num) / valid_sample_num + + # Sample game segments according to the probabilities + # Ensure valid_sample_num is not zero before sampling + if valid_sample_num == 0: + return ([], [], [], [], []) + + selected_game_segments = np.random.choice(valid_sample_num, batch_size_per_segment, replace=False, + p=probabilities) + + # If there are extra samples to be allocated, randomly select some game segments and sample again + if extra_samples > 0: + # We need to handle the case where we might sample the same segment again. + # A simple way is to allow replacement for extra samples or sample from remaining ones. + # For simplicity, let's stick to the original logic but ensure it's safe. + remaining_segments = np.setdiff1d(np.arange(valid_sample_num), selected_game_segments) + if len(remaining_segments) < extra_samples: + # If not enough unique segments left, sample with replacement from all valid segments + extra_game_segments = np.random.choice(valid_sample_num, extra_samples, replace=True, p=probabilities) + else: + # Sample from the remaining unique segments + remaining_probs = probabilities[remaining_segments] + remaining_probs /= np.sum(remaining_probs) + extra_game_segments = np.random.choice(remaining_segments, extra_samples, replace=False, p=remaining_probs) + + selected_game_segments = np.concatenate((selected_game_segments, extra_game_segments)) + + game_segment_list = [] + pos_in_game_segment_list = [] + batch_index_list = [] + print(f"selected_game_segments:{selected_game_segments}") + for game_segment_idx in selected_game_segments: + # ========================================================================= + # FIX: The line below is the source of the error and has been removed. + # `game_segment_idx` is already a valid physical index for `game_segment_buffer`. + # game_segment_idx -= self.base_idx + # ========================================================================= + game_segment = self.game_segment_buffer[game_segment_idx] + + # Update reanalyze_time only once + game_segment.reanalyze_time += 1 + + # The sampling position should be 0, 0 + num_unroll_steps, ... 
(integer multiples of num_unroll_steps) + for i in range(samples_per_segment): + game_segment_list.append(game_segment) + pos_in_game_segment = i * self._cfg.num_unroll_steps + if pos_in_game_segment >= len(game_segment): + pos_in_game_segment = np.random.choice(len(game_segment), 1).item() + pos_in_game_segment_list.append(pos_in_game_segment) + # NOTE: We should append the physical index here, as it corresponds to the sampled segment. + batch_index_list.append(game_segment_idx) + + # Set the make_time for each sample (set to 0 for now, but can be the actual time if needed). + make_time = [0. for _ in range(len(batch_index_list))] + + orig_data = (game_segment_list, pos_in_game_segment_list, batch_index_list, [], make_time) + return orig_data + def _sample_orig_reanalyze_data(self, batch_size: int) -> Tuple: """ Overview: @@ -645,7 +690,8 @@ def remove_oldest_data_to_fit(self) -> None: Overview: remove some oldest data if the replay buffer is full. """ - assert self.replay_buffer_size > self._cfg.batch_size, "replay buffer size should be larger than batch size" + if isinstance(self._cfg.batch_size, int): + assert self.replay_buffer_size > self._cfg.batch_size, "replay buffer size should be larger than batch size" nums_of_game_segments = self.get_num_of_game_segments() total_transition = self.get_num_of_transitions() if total_transition > self.replay_buffer_size: @@ -657,8 +703,15 @@ def remove_oldest_data_to_fit(self) -> None: # find the max game_segment index to keep in the buffer index = i break - if total_transition >= self._cfg.batch_size: - self._remove(index + 1) + if isinstance(self._cfg.batch_size, int): + if total_transition >= self._cfg.batch_size: + self._remove(index + 1) + else: + try: + if total_transition >= self._cfg.batch_size[0]: + self._remove(index + 1) + except Exception as e: + print(e) def _remove(self, excess_game_segment_index: List[int]) -> None: """ diff --git a/lzero/mcts/buffer/game_buffer_muzero.py b/lzero/mcts/buffer/game_buffer_muzero.py index faf0155a0..8cd7dfb51 100644 --- a/lzero/mcts/buffer/game_buffer_muzero.py +++ b/lzero/mcts/buffer/game_buffer_muzero.py @@ -3,7 +3,6 @@ import numpy as np import torch from ding.utils import BUFFER_REGISTRY, EasyTimer -# from line_profiler import line_profiler from lzero.mcts.tree_search.mcts_ctree import MuZeroMCTSCtree as MCTSCtree from lzero.mcts.tree_search.mcts_ptree import MuZeroMCTSPtree as MCTSPtree @@ -61,6 +60,18 @@ def __init__(self, cfg: dict): self.sample_times = 0 self.active_root_num = 0 + if hasattr(self._cfg, 'task_id'): + self.task_id = self._cfg.task_id + print(f"Task ID is set to {self.task_id}.") + try: + self.action_space_size = self._cfg.model.action_space_size_list[self.task_id] + except Exception as e: + self.action_space_size = self._cfg.model.action_space_size + + else: + self.task_id = None + print("No task_id found in configuration. 
Task ID is set to None.") + self.action_space_size = self._cfg.model.action_space_size self.value_support = DiscreteSupport(*self._cfg.model.value_support_range) self.reward_support = DiscreteSupport(*self._cfg.model.reward_support_range) @@ -149,7 +160,7 @@ def sample( self.compute_target_re_time += self._compute_target_timer.value batch_target_policies_non_re = self._compute_target_policy_non_reanalyzed( - policy_non_re_context, self._cfg.model.action_space_size + policy_non_re_context, self.action_space_size ) # fusion of batch_target_policies_re and batch_target_policies_non_re to batch_target_policies @@ -469,17 +480,20 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A end_index = self._cfg.mini_infer_size * (i + 1) m_obs = torch.from_numpy(value_obs_list[beg_index:end_index]).to(self._cfg.device) # calculate the target value - m_output = model.initial_inference(m_obs) - - if not model.training: - # if not in training, obtain the scalars of the value/reward - [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( - [ - m_output.latent_state, - inverse_scalar_transform(m_output.value, self.value_support), - m_output.policy_logits - ] - ) + if self.task_id is not None: + m_output = model.initial_inference(m_obs, task_id=self.task_id) + else: + m_output = model.initial_inference(m_obs) + + + # if not in training, obtain the scalars of the value/reward + [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( + [ + m_output.latent_state, + inverse_scalar_transform(m_output.value, self.value_support), + m_output.policy_logits + ] + ) network_output.append(m_output) @@ -594,17 +608,19 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: beg_index = self._cfg.mini_infer_size * i end_index = self._cfg.mini_infer_size * (i + 1) m_obs = torch.from_numpy(policy_obs_list[beg_index:end_index]).to(self._cfg.device) - m_output = model.initial_inference(m_obs) - - if not model.training: - # if not in training, obtain the scalars of the value/reward - [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( - [ - m_output.latent_state, - inverse_scalar_transform(m_output.value, self.value_support), - m_output.policy_logits - ] - ) + if self.task_id is not None: + m_output = model.initial_inference(m_obs, task_id=self.task_id) + else: + m_output = model.initial_inference(m_obs) + + # if not in training, obtain the scalars of the value/reward + [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( + [ + m_output.latent_state, + inverse_scalar_transform(m_output.value, self.value_support), + m_output.policy_logits + ] + ) network_output.append(m_output) @@ -612,7 +628,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: reward_pool = reward_pool.squeeze().tolist() policy_logits_pool = policy_logits_pool.tolist() noises = [ - np.random.dirichlet([self._cfg.root_dirichlet_alpha] * self._cfg.model.action_space_size + np.random.dirichlet([self._cfg.root_dirichlet_alpha] * self.action_space_size ).astype(np.float32).tolist() for _ in range(transition_batch_size) ] if self._cfg.mcts_ctree: @@ -624,7 +640,11 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: roots.prepare_no_noise(reward_pool, policy_logits_pool, to_play) # do MCTS for a new policy with the recent target model with self._origin_search_timer: - MCTSCtree(self._cfg).search(roots, model, 
latent_state_roots, to_play) + if self.task_id is not None: + MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play, task_id=self.task_id) + else: + MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play) + self.origin_search_time += self._origin_search_timer.value else: # python mcts_tree @@ -634,7 +654,11 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: else: roots.prepare_no_noise(reward_pool, policy_logits_pool, to_play) # do MCTS for a new policy with the recent target model - MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play) + if self.task_id is not None: + MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play, task_id=self.task_id) + else: + MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play) + roots_legal_actions_list = legal_actions roots_distributions = roots.get_distributions() @@ -650,7 +674,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: if policy_mask[policy_index] == 0: # NOTE: the invalid padding target policy, O is to make sure the corresponding cross_entropy_loss=0 - target_policies.append([0 for _ in range(self._cfg.model.action_space_size)]) + target_policies.append([0 for _ in range(self.action_space_size)]) else: # NOTE: It is very important to use the latest MCTS visit count distribution. sum_visits = sum(distributions) @@ -659,7 +683,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: if distributions is None: # if at some obs, the legal_action is None, add the fake target_policy target_policies.append( - list(np.ones(self._cfg.model.action_space_size) / self._cfg.model.action_space_size) + list(np.ones(self.action_space_size) / self.action_space_size) ) else: # Update the data in game segment: @@ -676,7 +700,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: target_policies.append(policy) else: # for board games that have two players and legal_actions is dy - policy_tmp = [0 for _ in range(self._cfg.model.action_space_size)] + policy_tmp = [0 for _ in range(self.action_space_size)] # to make sure target_policies have the same dimension sum_visits = sum(distributions) policy = [visit_count / sum_visits for visit_count in distributions] @@ -705,7 +729,7 @@ def _compute_target_policy_non_reanalyzed( - game_segment_lens - action_mask_segment - to_play_segment - - policy_shape: self._cfg.model.action_space_size + - policy_shape: self.action_space_size Returns: - batch_target_policies_non_re """ @@ -728,7 +752,7 @@ def _compute_target_policy_non_reanalyzed( ] # NOTE: in continuous action space env: we set all legal_actions as -1 legal_actions = [ - [-1 for _ in range(self._cfg.model.action_space_size)] for _ in range(transition_batch_size) + [-1 for _ in range(self.action_space_size)] for _ in range(transition_batch_size) ] else: legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)] @@ -778,6 +802,7 @@ def update_priority(self, train_data: List[np.ndarray], batch_priorities: Any) - NOTE: train_data = [current_batch, target_batch] current_batch = [obs_list, action_list, improved_policy_list(only in Gumbel MuZero), mask_list, batch_index_list, weights, make_time_list] + target_batch = [batch_rewards, batch_target_values, batch_target_policies] """ indices = train_data[0][-3] metas = {'make_time': train_data[0][-1], 'batch_priorities': batch_priorities} diff --git 
a/lzero/mcts/buffer/game_buffer_sampled_unizero.py b/lzero/mcts/buffer/game_buffer_sampled_unizero.py index f91b7f08a..da09fc311 100644 --- a/lzero/mcts/buffer/game_buffer_sampled_unizero.py +++ b/lzero/mcts/buffer/game_buffer_sampled_unizero.py @@ -48,9 +48,18 @@ def __init__(self, cfg: dict): self.game_segment_buffer = [] self.game_pos_priorities = [] self.game_segment_game_pos_look_up = [] - # self.task_id = self._cfg.task_id self.sample_type = self._cfg.sample_type # 'transition' or 'episode' + if hasattr(self._cfg, 'task_id'): + self.task_id = self._cfg.task_id + print(f"Task ID is set to {self.task_id}.") + self.action_space_size = self._cfg.model.action_space_size_list[self.task_id] + + else: + self.task_id = None + print("No task_id found in configuration. Task ID is set to None.") + self.action_space_size = self._cfg.model.action_space_size + self.value_support = DiscreteSupport(*self._cfg.model.value_support_range) self.reward_support = DiscreteSupport(*self._cfg.model.reward_support_range) @@ -115,21 +124,22 @@ def _make_batch_for_reanalyze(self, batch_size: int, reanalyze_ratio: float) -> mask_tmp = [1. for i in range(len(root_sampled_actions_tmp))] mask_tmp += [0. for _ in range(self._cfg.num_unroll_steps + 1 - len(mask_tmp))] + # pad random action if self._cfg.model.continuous_action_space: actions_tmp += [ - np.random.randn(self._cfg.model.action_space_size) + np.random.randn(self.action_space_size) for _ in range(self._cfg.num_unroll_steps - len(actions_tmp)) ] root_sampled_actions_tmp += [ - np.random.rand(self._cfg.model.num_of_sampled_actions, self._cfg.model.action_space_size) + np.random.rand(self._cfg.model.num_of_sampled_actions, self.action_space_size) for _ in range(self._cfg.num_unroll_steps + 1 - len(root_sampled_actions_tmp)) ] else: # generate random `padded actions_tmp` actions_tmp += generate_random_actions_discrete( self._cfg.num_unroll_steps - len(actions_tmp), - self._cfg.model.action_space_size, + self.action_space_size, 1 # Number of sampled actions for actions_tmp is 1 ) @@ -138,7 +148,7 @@ def _make_batch_for_reanalyze(self, batch_size: int, reanalyze_ratio: float) -> reshape = True if self._cfg.mcts_ctree else False root_sampled_actions_tmp += generate_random_actions_discrete( self._cfg.num_unroll_steps + 1 - len(root_sampled_actions_tmp), - self._cfg.model.action_space_size, + self.action_space_size, self._cfg.model.num_of_sampled_actions, reshape=reshape ) @@ -277,18 +287,18 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: # pad random action if self._cfg.model.continuous_action_space: actions_tmp += [ - np.random.randn(self._cfg.model.action_space_size) + np.random.randn(self.action_space_size) for _ in range(self._cfg.num_unroll_steps - len(actions_tmp)) ] root_sampled_actions_tmp += [ - np.random.rand(self._cfg.model.num_of_sampled_actions, self._cfg.model.action_space_size) + np.random.rand(self._cfg.model.num_of_sampled_actions, self.action_space_size) for _ in range(self._cfg.num_unroll_steps + 1 - len(root_sampled_actions_tmp)) ] else: # generate random `padded actions_tmp` actions_tmp += generate_random_actions_discrete( self._cfg.num_unroll_steps - len(actions_tmp), - self._cfg.model.action_space_size, + self.action_space_size, 1 # Number of sampled actions for actions_tmp is 1 ) @@ -297,7 +307,7 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: reshape = True if self._cfg.mcts_ctree else False root_sampled_actions_tmp += generate_random_actions_discrete( 
self._cfg.num_unroll_steps + 1 - len(root_sampled_actions_tmp), - self._cfg.model.action_space_size, + self.action_space_size, self._cfg.model.num_of_sampled_actions, reshape=reshape ) @@ -326,7 +336,7 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: if self._cfg.model.continuous_action_space: # pad random action bootstrap_action_tmp += [ - np.random.randn(self._cfg.model.action_space_size) + np.random.randn(self.action_space_size) for _ in range(self._cfg.num_unroll_steps - len(bootstrap_action_tmp)) ] bootstrap_action_list.append(bootstrap_action_tmp) @@ -489,6 +499,12 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: # calculate the target value # batch_action.shape (32, 10) # batch_obs.shape torch.Size([352, 3, 64, 64]) 32*11=352 + + if self.task_id is not None: + m_output = model.initial_inference(batch_obs, batch_action[:self.reanalyze_num], task_id=self.task_id) # NOTE: :self.reanalyze_num + else: + m_output = model.initial_inference(batch_obs, batch_action[:self.reanalyze_num]) # NOTE: :self.reanalyze_num + m_output = model.initial_inference(batch_obs, batch_action[:self.reanalyze_num]) # NOTE: :self.reanalyze_num # ======================================================================= @@ -514,18 +530,24 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: # cpp mcts_tree # roots = MCTSCtree.roots(transition_batch_size, legal_actions) roots = MCTSCtree.roots( - transition_batch_size, legal_actions, self._cfg.model.action_space_size, + transition_batch_size, legal_actions, self.action_space_size, self._cfg.model.num_of_sampled_actions, self._cfg.model.continuous_action_space ) roots.prepare(self._cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play) # do MCTS for a new policy with the recent target model - MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play) + if self.task_id is not None: + MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play, task_id=self.task_id) + else: + MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play) else: # python mcts_tree roots = MCTSPtree.roots(transition_batch_size, legal_actions) roots.prepare(self._cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play) # do MCTS for a new policy with the recent target model - MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play) + if self.task_id is not None: + MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play, task_id=self.task_id) + else: + MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play) roots_legal_actions_list = legal_actions roots_distributions = roots.get_distributions() @@ -629,7 +651,7 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A if self._cfg.model.continuous_action_space is True: # when the action space of the environment is continuous, action_mask[:] is None. 
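
The buffer changes above repeat the same guard around every model call: `task_id` is forwarded to `initial_inference` and to the MCTS `search` only when it is set, so single-task checkpoints keep their original signatures. Below is a minimal sketch of that dispatch; `call_with_optional_task_id` is a hypothetical helper name, not part of this diff.

```python
from typing import Any, Callable, Optional


def call_with_optional_task_id(fn: Callable[..., Any], *args: Any,
                               task_id: Optional[int] = None, **kwargs: Any) -> Any:
    # Multi-task setting: forward the task identifier so the model can select
    # the matching heads / action space for this task.
    if task_id is not None:
        return fn(*args, task_id=task_id, **kwargs)
    # Single-task setting: call the model with its original signature.
    return fn(*args, **kwargs)


# e.g. m_output = call_with_optional_task_id(model.initial_inference, batch_obs, batch_action,
#                                            task_id=buffer.task_id)
```
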
action_mask = [ - list(np.ones(self._cfg.model.action_space_size, dtype=np.int8)) for _ in range(transition_batch_size) + list(np.ones(self.action_space_size, dtype=np.int8)) for _ in range(transition_batch_size) ] # NOTE: in continuous action space env: we set all legal_actions as -1 legal_actions = [ @@ -647,7 +669,12 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A # =============== NOTE: The key difference with MuZero ================= # calculate the target value # batch_obs.shape torch.Size([352, 3, 64, 64]) 32*11 = 352 - m_output = model.initial_inference(batch_obs, batch_action, start_pos=batch_timestep) + if self.task_id is not None: + # m_output = model.initial_inference(batch_obs, batch_action, start_pos=batch_timestep, task_id=self.task_id) + + m_output = model.initial_inference(batch_obs, batch_action, task_id=self.task_id) + else: + m_output = model.initial_inference(batch_obs, batch_action, start_pos=batch_timestep) # ====================================================================== # if not in training, obtain the scalars of the value/reward @@ -658,6 +685,7 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A m_output.policy_logits ] ) + network_output.append(m_output) if self._cfg.use_root_value: diff --git a/lzero/mcts/buffer/game_buffer_unizero.py b/lzero/mcts/buffer/game_buffer_unizero.py index b8998acb9..f4652e1cf 100644 --- a/lzero/mcts/buffer/game_buffer_unizero.py +++ b/lzero/mcts/buffer/game_buffer_unizero.py @@ -11,6 +11,7 @@ if TYPE_CHECKING: from lzero.policy import MuZeroPolicy, EfficientZeroPolicy, SampledEfficientZeroPolicy +from line_profiler import line_profiler @BUFFER_REGISTRY.register('game_buffer_unizero') @@ -48,9 +49,22 @@ def __init__(self, cfg: dict): self.game_segment_game_pos_look_up = [] self.sample_type = self._cfg.sample_type # 'transition' or 'episode' + if hasattr(self._cfg, 'task_id'): + self.task_id = self._cfg.task_id + print(f"Task ID is set to {self.task_id}.") + try: + self.action_space_size = self._cfg.model.action_space_size_list[self.task_id] + except Exception as e: + self.action_space_size = self._cfg.model.action_space_size + else: + self.task_id = None + print("No task_id found in configuration. 
Task ID is set to None.") + self.action_space_size = self._cfg.model.action_space_size + self.value_support = DiscreteSupport(*self._cfg.model.value_support_range) self.reward_support = DiscreteSupport(*self._cfg.model.reward_support_range) + #@profile def sample( self, batch_size: int, policy: Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy"] ) -> List[Any]: @@ -81,7 +95,7 @@ def sample( # target policy batch_target_policies_re = self._compute_target_policy_reanalyzed(policy_re_context, policy._target_model, current_batch[1], current_batch[-1]) # current_batch[1] is batch_action batch_target_policies_non_re = self._compute_target_policy_non_reanalyzed( - policy_non_re_context, self._cfg.model.action_space_size + policy_non_re_context, self.action_space_size ) # fusion of batch_target_policies_re and batch_target_policies_non_re to batch_target_policies @@ -98,6 +112,7 @@ def sample( train_data = [current_batch, target_batch] return train_data + #@profile def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: """ Overview: @@ -133,9 +148,6 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: self._cfg.num_unroll_steps].tolist() timestep_tmp = game.timestep_segment[pos_in_game_segment:pos_in_game_segment + self._cfg.num_unroll_steps].tolist() - # add mask for invalid actions (out of trajectory), 1 for valid, 0 for invalid - # mask_tmp = [1. for i in range(len(actions_tmp))] - # mask_tmp += [0. for _ in range(self._cfg.num_unroll_steps + 1 - len(mask_tmp))] # TODO: the child_visits after position in the segment (with padded part) may not be updated # So the corresponding position should not be used in the training @@ -278,9 +290,6 @@ def _make_batch_for_reanalyze(self, batch_size: int) -> Tuple[Any]: mask_tmp += [0. for _ in range(self._cfg.num_unroll_steps + 1 - len(mask_tmp))] timestep_tmp = game.timestep_segment[pos_in_game_segment:pos_in_game_segment + self._cfg.num_unroll_steps].tolist() - # TODO: original buffer mask - # mask_tmp = [1. for i in range(min(len(actions_tmp), self._cfg.game_segment_length - pos_in_game_segment))] - # mask_tmp += [0. for _ in range(self._cfg.num_unroll_steps + 1 - len(mask_tmp))] # pad random action actions_tmp += [ @@ -415,11 +424,11 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: if self._cfg.model.continuous_action_space is True: # when the action space of the environment is continuous, action_mask[:] is None. 
action_mask = [ - list(np.ones(self._cfg.model.action_space_size, dtype=np.int8)) for _ in range(transition_batch_size) + list(np.ones(self.action_space_size, dtype=np.int8)) for _ in range(transition_batch_size) ] # NOTE: in continuous action space env: we set all legal_actions as -1 legal_actions = [ - [-1 for _ in range(self._cfg.model.action_space_size)] for _ in range(transition_batch_size) + [-1 for _ in range(self.action_space_size)] for _ in range(transition_batch_size) ] else: legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)] @@ -435,18 +444,25 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: # =============== NOTE: The key difference with MuZero ================= # To obtain the target policy from MCTS guided by the recent target model # TODO: batch_obs (policy_obs_list) is at timestep t, batch_action is at timestep t - m_output = model.initial_inference(batch_obs, batch_action[:self.reanalyze_num], start_pos=batch_timestep[:self.reanalyze_num]) # NOTE: :self.reanalyze_num + + if self.task_id is not None: + # TODO: support RoPE + # m_output = model.initial_inference(batch_obs, batch_action[:self.reanalyze_num], start_pos=batch_timestep[:self.reanalyze_num], task_id=self.task_id) # NOTE: :self.reanalyze_num + m_output = model.initial_inference(batch_obs, batch_action[:self.reanalyze_num], task_id=self.task_id) # NOTE: :self.reanalyze_num + + else: + m_output = model.initial_inference(batch_obs, batch_action[:self.reanalyze_num], start_pos=batch_timestep[:self.reanalyze_num]) # NOTE: :self.reanalyze_num + # ======================================================================= - if not model.training: - # if not in training, obtain the scalars of the value/reward - [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( - [ - m_output.latent_state, - inverse_scalar_transform(m_output.value, self.value_support), - m_output.policy_logits - ] - ) + # if not in training, obtain the scalars of the value/reward + [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( + [ + m_output.latent_state, + inverse_scalar_transform(m_output.value, self.value_support), + m_output.policy_logits + ] + ) network_output.append(m_output) @@ -454,7 +470,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: reward_pool = reward_pool.squeeze().tolist() policy_logits_pool = policy_logits_pool.tolist() noises = [ - np.random.dirichlet([self._cfg.root_dirichlet_alpha] * self._cfg.model.action_space_size + np.random.dirichlet([self._cfg.root_dirichlet_alpha] * self.action_space_size ).astype(np.float32).tolist() for _ in range(transition_batch_size) ] if self._cfg.mcts_ctree: @@ -462,13 +478,21 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: roots = MCTSCtree.roots(transition_batch_size, legal_actions) roots.prepare(self._cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play) # do MCTS for a new policy with the recent target model - MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play, batch_timestep[:self.reanalyze_num]) + if self.task_id is not None: + MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play, task_id=self.task_id) + # TODO: adapt unizero multitask to timestep in rope + # MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play, batch_timestep[:self.reanalyze_num], task_id=self.task_id) + else: + 
MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play, batch_timestep[:self.reanalyze_num]) else: # python mcts_tree roots = MCTSPtree.roots(transition_batch_size, legal_actions) roots.prepare(self._cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play) # do MCTS for a new policy with the recent target model - MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play, batch_timestep[:self.reanalyze_num]) + if self.task_id is not None: + MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play, batch_timestep[:self.reanalyze_num], task_id=self.task_id) + else: + MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play, batch_timestep[:self.reanalyze_num]) roots_legal_actions_list = legal_actions roots_distributions = roots.get_distributions() @@ -479,7 +503,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: distributions = roots_distributions[policy_index] if policy_mask[policy_index] == 0: # NOTE: the invalid padding target policy, O is to make sure the corresponding cross_entropy_loss=0 - target_policies.append([0 for _ in range(self._cfg.model.action_space_size)]) + target_policies.append([0 for _ in range(self.action_space_size)]) else: # NOTE: It is very important to use the latest MCTS visit count distribution. sum_visits = sum(distributions) @@ -488,7 +512,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: if distributions is None: # if at some obs, the legal_action is None, add the fake target_policy target_policies.append( - list(np.ones(self._cfg.model.action_space_size) / self._cfg.model.action_space_size) + list(np.ones(self.action_space_size) / self.action_space_size) ) else: if self._cfg.env_type == 'not_board_games': @@ -498,7 +522,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: target_policies.append(policy) else: # for board games that have two players and legal_actions is dy - policy_tmp = [0 for _ in range(self._cfg.model.action_space_size)] + policy_tmp = [0 for _ in range(self.action_space_size)] # to make sure target_policies have the same dimension sum_visits = sum(distributions) policy = [visit_count / sum_visits for visit_count in distributions] @@ -543,7 +567,13 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A # =============== NOTE: The key difference with MuZero ================= # calculate the bootstrapped value and target value # NOTE: batch_obs(value_obs_list) is at t+td_steps, batch_action is at timestep t+td_steps - m_output = model.initial_inference(batch_obs, batch_action, start_pos=batch_timestep) + if self.task_id is not None: + # m_output = model.initial_inference(batch_obs, batch_action, start_pos=batch_timestep, task_id=self.task_id) + m_output = model.initial_inference(batch_obs, batch_action, task_id=self.task_id) + + else: + m_output = model.initial_inference(batch_obs, batch_action, start_pos=batch_timestep) + # ====================================================================== # if not in training, obtain the scalars of the value/reward @@ -630,3 +660,32 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A batch_target_values = np.asarray(batch_target_values) return batch_rewards, batch_target_values + + def update_priority(self, train_data: List[np.ndarray], batch_priorities: np.ndarray) -> None: + """ + Overview: + Update the priority of training data. 
+ Arguments: + - train_data (:obj:`List[np.ndarray]`): training data to be updated priority. + - batch_priorities (:obj:`np.ndarray`): priorities to update to. + NOTE: + train_data = [current_batch, target_batch] + current_batch = [obs_list, action_list, bootstrap_action_list, mask_list, batch_index_list, weights_list, make_time_list, timestep_list] + """ + # TODO: NOTE: -4 is batch_index_list + indices = train_data[0][-4] + metas = {'make_time': train_data[0][-1], 'batch_priorities': batch_priorities} + # only update the priorities for data still in replay buffer + for i in range(len(indices)): + + # Handle ValueError by using the first timestamp of the segment for comparison. + first_transition_time = metas['make_time'][i][0] + + if first_transition_time > self.clear_time: + # Handle IndexError by converting the float index to an integer before use. + idx = int(indices[i]) + prio = metas['batch_priorities'][i] + + # Now, idx is a valid integer index. + self.game_pos_priorities[idx] = prio + diff --git a/lzero/mcts/buffer/game_segment.py b/lzero/mcts/buffer/game_segment.py index ad216d196..6c2cd1999 100644 --- a/lzero/mcts/buffer/game_segment.py +++ b/lzero/mcts/buffer/game_segment.py @@ -1,5 +1,5 @@ import copy -from typing import List, Tuple +from typing import List, Tuple, Optional import numpy as np from easydict import EasyDict @@ -31,13 +31,15 @@ class GameSegment: - store_search_stats """ - def __init__(self, action_space: int, game_segment_length: int = 200, config: EasyDict = None) -> None: + def __init__(self, action_space: int, game_segment_length: int = 200, config: EasyDict = None, task_id: Optional[int] = None) -> None: """ Overview: Init the ``GameSegment`` according to the provided arguments. Arguments: - action_space (:obj:`int`): action space + - action_space (:obj:`int`): action space - game_segment_length (:obj:`int`): the transition number of one ``GameSegment`` block + - task_id (:obj:`Optional[int]`): The identifier for the task, used to select the correct obs and act space in multi-task settings. Defaults to None. + """ self.action_space = action_space self.game_segment_length = game_segment_length @@ -45,19 +47,32 @@ def __init__(self, action_space: int, game_segment_length: int = 200, config: Ea self.td_steps = config.td_steps self.frame_stack_num = config.model.frame_stack_num self.discount_factor = config.discount_factor - self.action_space_size = config.model.action_space_size + if not hasattr(config.model, "action_space_size_list"): + # for single-task setting or fixed action space in multi-task setting + self.action_space_size = config.model.action_space_size self.gray_scale = config.gray_scale self.transform2string = config.transform2string self.sampled_algo = config.sampled_algo self.gumbel_algo = config.gumbel_algo self.use_ture_chance_label_in_chance_encoder = config.use_ture_chance_label_in_chance_encoder - if isinstance(config.model.observation_shape, int) or len(config.model.observation_shape) == 1: - # for vector obs input, e.g. classical control and box2d environments - self.zero_obs_shape = config.model.observation_shape - elif len(config.model.observation_shape) == 3: - # image obs input, e.g. atari environments - self.zero_obs_shape = (config.model.image_channel, config.model.observation_shape[-2], config.model.observation_shape[-1]) + if task_id is None: + if isinstance(config.model.observation_shape, int) or len(config.model.observation_shape) == 1: + # for vector obs input, e.g. 
classical control and box2d environments + self.zero_obs_shape = config.model.observation_shape + elif len(config.model.observation_shape) == 3: + # image obs input, e.g. atari environments + self.zero_obs_shape = (config.model.image_channel, config.model.observation_shape[-2], config.model.observation_shape[-1]) + else: + if hasattr(config.model, "observation_shape_list"): + if isinstance(config.model.observation_shape_list[task_id], int) or len(config.model.observation_shape_list[task_id]) == 1: + # for vector obs input, e.g. classical control and box2d environments + self.zero_obs_shape = config.model.observation_shape_list[task_id] + elif len(config.model.observation_shape_list[task_id]) == 3: + # image obs input, e.g. atari environments + self.zero_obs_shape = (config.model.image_channel, config.model.observation_shape_list[task_id][-2], config.model.observation_shape_list[task_id][-1]) + else: + self.zero_obs_shape = (config.model.image_channel, config.model.observation_shape[-2], config.model.observation_shape[-1]) self.obs_segment = [] self.action_segment = [] diff --git a/lzero/mcts/tests/test_game_buffer.py b/lzero/mcts/tests/test_game_buffer.py index ea02dc5a2..5fbdd8f47 100644 --- a/lzero/mcts/tests/test_game_buffer.py +++ b/lzero/mcts/tests/test_game_buffer.py @@ -16,6 +16,11 @@ use_priority=True, action_type='fixed_action_space', game_segment_length=20, + model=dict( + action_space_size=6, + value_support_range=(-10, 10, 1), + reward_support_range=(-10, 10, 1), + ), ) ) diff --git a/lzero/mcts/tree_search/mcts_ctree.py b/lzero/mcts/tree_search/mcts_ctree.py index 4e238a6b3..34abd3049 100644 --- a/lzero/mcts/tree_search/mcts_ctree.py +++ b/lzero/mcts/tree_search/mcts_ctree.py @@ -1,5 +1,5 @@ import copy -from typing import TYPE_CHECKING, List, Any, Union +from typing import TYPE_CHECKING, List, Any, Union, Optional import numpy as np import torch @@ -72,14 +72,13 @@ def roots(cls: int, active_collect_env_num: int, legal_actions: List[Any]) -> "m from lzero.mcts.ctree.ctree_muzero import mz_tree as ctree return ctree.Roots(active_collect_env_num, legal_actions) - # @profile def search( self, roots: Any, model: torch.nn.Module, latent_state_roots: List[Any], to_play_batch: Union[int, - List[Any]], timestep: Union[int, List[Any]] + List[Any]], timestep: Union[int, List[Any]] = None, task_id: Optional[int] = None ) -> dict: """ Overview: - Perform Monte Carlo Tree Search (MCTS) for a batch of root nodes in parallel. + Perform Monte Carlo Tree Search (MCTS) for a batch of root nodes in parallel. This method utilizes the C++ implementation of the tree search for efficiency. Arguments: @@ -88,6 +87,7 @@ def search( - latent_state_roots (:obj:`List[Any]`): The hidden states of the root nodes. - to_play_batch (:obj:`Union[int, List[Any]]`): The list of players in self-play mode. - timestep (:obj:`Union[int, List[Any]]`): The step index of the environment in one episode. + - task_id (:obj:`Optional[int]`): The global task ID for the current environments. 
""" with torch.no_grad(): model.eval() @@ -138,6 +138,7 @@ def search( latent_states.append(latent_state_batch_in_search_path[ix][iy]) latent_states = torch.from_numpy(np.asarray(latent_states)).to(self._cfg.device) + # TODO: .long() is only for discrete action last_actions = torch.from_numpy(np.asarray(last_actions)).to(self._cfg.device).long() @@ -154,7 +155,23 @@ def search( # search_depth is used for rope in UniZero search_depth = results.get_search_len() # print(f'simulation_index:{simulation_index}, search_depth:{search_depth}, latent_state_index_in_search_path:{latent_state_index_in_search_path}') - network_output = model.recurrent_inference(state_action_history, simulation_index, search_depth, timestep) + if timestep is None: + # for UniZero + if task_id is not None: + # multi task setting + network_output = model.recurrent_inference(state_action_history, simulation_index, search_depth, task_id=task_id) + else: + # single task setting + network_output = model.recurrent_inference(state_action_history, simulation_index, search_depth) + else: + # for UniZero using RoPE + if task_id is not None: + # multi task setting + # network_output = model.recurrent_inference(state_action_history, simulation_index, search_depth, timestep, task_id=task_id) # TODO: support RoPE + network_output = model.recurrent_inference(state_action_history, simulation_index, search_depth, task_id=task_id) + else: + # single task setting + network_output = model.recurrent_inference(state_action_history, simulation_index, search_depth, timestep) network_output.latent_state = to_detach_cpu_numpy(network_output.latent_state) network_output.policy_logits = to_detach_cpu_numpy(network_output.policy_logits) @@ -245,10 +262,9 @@ def roots(cls: int, active_collect_env_num: int, legal_actions: List[Any]) -> "m from lzero.mcts.ctree.ctree_muzero import mz_tree as ctree return ctree.Roots(active_collect_env_num, legal_actions) - # @profile def search( self, roots: Any, model: torch.nn.Module, latent_state_roots: List[Any], to_play_batch: Union[int, - List[Any]] + List[Any]], task_id: Optional[int] = None ) -> None: """ Overview: @@ -258,6 +274,7 @@ def search( - roots (:obj:`Any`): a batch of expanded root nodes - latent_state_roots (:obj:`list`): the hidden states of the roots - to_play_batch (:obj:`list`): the to_play_batch list used in in self-play-mode board games + - task_id (:obj:`Optional[int]`): The global task ID for the current environments. 
""" with torch.no_grad(): model.eval() @@ -318,6 +335,13 @@ def search( """ network_output = model.recurrent_inference(latent_states, last_actions) # for classic muzero + if task_id is not None: + # multi task setting + network_output = model.recurrent_inference(latent_states, last_actions, task_id=task_id) + else: + # single task setting + network_output = model.recurrent_inference(latent_states, last_actions) + network_output.latent_state = to_detach_cpu_numpy(network_output.latent_state) network_output.policy_logits = to_detach_cpu_numpy(network_output.policy_logits) network_output.value = to_detach_cpu_numpy(self.value_inverse_scalar_transform_handle(network_output.value)) @@ -516,7 +540,6 @@ def roots(cls: int, active_collect_env_num: int, legal_actions: List[Any]) -> "e """ return tree_muzero.Roots(active_collect_env_num, legal_actions) - # @profile def search( self, roots: Any, model: torch.nn.Module, latent_state_roots: List[Any], world_model_latent_history_roots: List[Any], to_play_batch: Union[int, List[Any]], ready_env_id=None, diff --git a/lzero/mcts/tree_search/mcts_ctree_sampled.py b/lzero/mcts/tree_search/mcts_ctree_sampled.py index 02f591a1f..15b5f928d 100644 --- a/lzero/mcts/tree_search/mcts_ctree_sampled.py +++ b/lzero/mcts/tree_search/mcts_ctree_sampled.py @@ -83,11 +83,11 @@ def roots( # @profile def search( self, roots: Any, model: torch.nn.Module, latent_state_roots: List[Any], to_play_batch: Union[int, - List[Any]], timestep: Union[int, List[Any]] + List[Any]], timestep: Union[int, List[Any]], task_id=None ) -> None: """ Overview: - Perform Monte Carlo Tree Search (MCTS) for a batch of root nodes in parallel. + Perform Monte Carlo Tree Search (MCTS) for a batch of root nodes in parallel. This method utilizes the C++ implementation of the tree search for efficiency. Arguments: @@ -96,6 +96,7 @@ def search( - latent_state_roots (:obj:`List[Any]`): The hidden states of the root nodes. - to_play_batch (:obj:`Union[int, List[Any]]`): The list of players in self-play mode. - timestep (:obj:`Union[int, List[Any]]`): The step index of the environment in one episode. + - task_id (:obj:`int`, optional): The global task ID for the current environments. """ with torch.no_grad(): model.eval() @@ -142,6 +143,7 @@ def search( latent_states.append(latent_state_batch_in_search_path[ix][iy]) latent_states = torch.from_numpy(np.asarray(latent_states)).to(self._cfg.device) + if self._cfg.model.continuous_action_space is True: # continuous action last_actions = torch.from_numpy(np.asarray(last_actions)).to(self._cfg.device) @@ -159,9 +161,16 @@ def search( MCTS stage 3: Backup At the end of the simulation, the statistics along the trajectory are updated. 
""" + # search_depth is used for rope in UniZero + search_depth = results.get_search_len() # for Sampled UniZero - network_output = model.recurrent_inference(state_action_history, simulation_index, - latent_state_index_in_search_path, timestep) + # TODO: support RoPE + if task_id is not None: + # multi task setting + network_output = model.recurrent_inference(state_action_history, simulation_index, search_depth, task_id=task_id) + else: + # single task setting + network_output = model.recurrent_inference(state_action_history, simulation_index, search_depth, timestep) network_output.latent_state = to_detach_cpu_numpy(network_output.latent_state) network_output.policy_logits = to_detach_cpu_numpy(network_output.policy_logits) diff --git a/lzero/model/common.py b/lzero/model/common.py index 7b1bbeeae..31f254963 100644 --- a/lzero/model/common.py +++ b/lzero/model/common.py @@ -1,25 +1,26 @@ """ Overview: - In this Python file, we provide a collection of reusable model templates designed to streamline the development + This Python file provides a collection of reusable model templates designed to streamline the development process for various custom algorithms. By utilizing these pre-built model templates, users can quickly adapt and - customize their custom algorithms, ensuring efficient and effective development. - BTW, users can refer to the unittest of these model templates to learn how to use them. + customize their algorithms, ensuring efficient and effective development. + Users can refer to the unittest of these model templates to learn how to use them. """ import math from dataclasses import dataclass -from typing import Callable, List, Optional -from typing import Tuple +from typing import Callable, List, Optional, Tuple, Sequence import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import torch.nn.init as init -from transformers import AutoModelForCausalLM, AutoTokenizer +from ditk import logging +# Assuming these imports are valid in the user's environment. +# If they are not, they should be replaced with the correct ones. from ding.torch_utils import MLP, ResBlock from ding.torch_utils.network.normalization import build_normalization -from ding.utils import SequenceType -from ditk import logging +from ding.utils import SequenceType, get_rank, get_world_size +from transformers import AutoModelForCausalLM, AutoTokenizer from ding.utils import set_pkg_seed, get_rank, get_world_size @@ -28,7 +29,7 @@ def MLP_V2( in_channels: int, hidden_channels: List[int], out_channels: int, - layer_fn: Callable = None, + layer_fn: Callable = nn.Linear, activation: Optional[nn.Module] = None, norm_type: Optional[str] = None, use_dropout: bool = False, @@ -36,118 +37,122 @@ def MLP_V2( output_activation: bool = True, output_norm: bool = True, last_linear_layer_init_zero: bool = False, -): +) -> nn.Sequential: """ Overview: - Create a multi-layer perceptron (MLP) using a list of hidden dimensions. Each layer consists of a fully + Creates a multi-layer perceptron (MLP) using a list of hidden dimensions. Each layer consists of a fully connected block with optional activation, normalization, and dropout. The final layer is configurable - to include or exclude activation, normalization, and dropout based on user preferences. - + to include or exclude activation and normalization. Arguments: - in_channels (:obj:`int`): Number of input channels (dimensionality of the input tensor). 
- hidden_channels (:obj:`List[int]`): A list specifying the number of channels for each hidden layer. - For example, [512, 256, 128] means the MLP will have three hidden layers with 512, 256, and 128 units, respectively. - out_channels (:obj:`int`): Number of output channels (dimensionality of the output tensor). - - layer_fn (:obj:`Callable`, optional): Layer function to construct layers (default is `nn.Linear`). - - activation (:obj:`nn.Module`, optional): Activation function to use after each layer - (e.g., `nn.ReLU`, `nn.Sigmoid`). Default is None (no activation). - - norm_type (:obj:`str`, optional): Type of normalization to apply after each layer. - If None, no normalization is applied. Supported values depend on the implementation of `build_normalization`. - - use_dropout (:obj:`bool`, optional): Whether to apply dropout after each layer. Default is False. - - dropout_probability (:obj:`float`, optional): The probability of setting elements to zero in dropout. Default is 0.5. - - output_activation (:obj:`bool`, optional): Whether to apply activation to the output layer. Default is True. - - output_norm (:obj:`bool`, optional): Whether to apply normalization to the output layer. Default is True. - - last_linear_layer_init_zero (:obj:`bool`, optional): Whether to initialize the weights and biases of the - last linear layer to zeros. This is commonly used in reinforcement learning for stable initial outputs. - + - layer_fn (:obj:`Callable`): The function to construct layers, defaults to `nn.Linear`. + - activation (:obj:`Optional[nn.Module]`): Activation function to use after each layer, defaults to None. + - norm_type (:obj:`Optional[str]`): Type of normalization to apply. If None, no normalization is applied. + - use_dropout (:obj:`bool`): Whether to apply dropout after each layer, defaults to False. + - dropout_probability (:obj:`float`): The probability for dropout, defaults to 0.5. + - output_activation (:obj:`bool`): Whether to apply activation to the output layer, defaults to True. + - output_norm (:obj:`bool`): Whether to apply normalization to the output layer, defaults to True. + - last_linear_layer_init_zero (:obj:`bool`): Whether to initialize the last linear layer's weights and biases to zero. Returns: - block (:obj:`nn.Sequential`): A PyTorch `nn.Sequential` object containing the layers of the MLP. - - Notes: - - The final layer's normalization, activation, and dropout are controlled by `output_activation`, - `output_norm`, and `use_dropout`. - - If `last_linear_layer_init_zero` is True, the weights and biases of the last linear layer are initialized to 0. """ - assert len(hidden_channels) > 0, "The hidden_channels list must contain at least one element." 
- if layer_fn is None: - layer_fn = nn.Linear - - # Initialize the MLP block - block = [] - channels = [in_channels] + hidden_channels + [out_channels] - - # Build all layers except the final layer - for i, (in_channels, out_channels) in enumerate(zip(channels[:-2], channels[1:-1])): - block.append(layer_fn(in_channels, out_channels)) - if norm_type is not None: - block.append(build_normalization(norm_type, dim=1)(out_channels)) - if activation is not None: - block.append(activation) - if use_dropout: - block.append(nn.Dropout(dropout_probability)) - - # Build the final layer - in_channels = channels[-2] - out_channels = channels[-1] - block.append(layer_fn(in_channels, out_channels)) - - # Add optional normalization and activation for the final layer - if output_norm and norm_type is not None: - block.append(build_normalization(norm_type, dim=1)(out_channels)) - if output_activation and activation is not None: - block.append(activation) - if use_dropout: - block.append(nn.Dropout(dropout_probability)) - - # Initialize the weights and biases of the last linear layer to zero if specified + if not hidden_channels: + logging.warning("hidden_channels is empty, creating a single-layer MLP.") + + layers = [] + all_channels = [in_channels] + hidden_channels + [out_channels] + num_layers = len(all_channels) - 1 + + for i in range(num_layers): + is_last_layer = (i == num_layers - 1) + layers.append(layer_fn(all_channels[i], all_channels[i+1])) + + if not is_last_layer: + # Intermediate layers + if norm_type: + layers.append(build_normalization(norm_type, dim=1)(all_channels[i+1])) + if activation: + layers.append(activation) + if use_dropout: + layers.append(nn.Dropout(dropout_probability)) + else: + # Last layer + if output_norm and norm_type: + layers.append(build_normalization(norm_type, dim=1)(all_channels[i+1])) + if output_activation and activation: + layers.append(activation) + # Note: Dropout on the final output is usually not recommended unless for specific regularization purposes. + # The original logic applied it, so we keep it for consistency. + if use_dropout: + layers.append(nn.Dropout(dropout_probability)) + + # Initialize the last linear layer to zero if specified if last_linear_layer_init_zero: - for layer in reversed(block): + for layer in reversed(layers): if isinstance(layer, nn.Linear): nn.init.zeros_(layer.weight) nn.init.zeros_(layer.bias) break - return nn.Sequential(*block) + return nn.Sequential(*layers) + + +# --- Data-structures for Network Outputs --- -# use dataclass to make the output of network more convenient to use @dataclass class MZRNNNetworkOutput: - # output format of the MuZeroRNN model + """ + Overview: + Data structure for the output of the MuZeroRNN model. + """ value: torch.Tensor value_prefix: torch.Tensor policy_logits: torch.Tensor latent_state: torch.Tensor predict_next_latent_state: torch.Tensor - reward_hidden_state: Tuple[torch.Tensor] + reward_hidden_state: Tuple[torch.Tensor, torch.Tensor] @dataclass class EZNetworkOutput: - # output format of the EfficientZero model + """ + Overview: + Data structure for the output of the EfficientZero model. + """ value: torch.Tensor value_prefix: torch.Tensor policy_logits: torch.Tensor latent_state: torch.Tensor - reward_hidden_state: Tuple[torch.Tensor] + reward_hidden_state: Tuple[torch.Tensor, torch.Tensor] @dataclass class MZNetworkOutput: - # output format of the MuZero model + """ + Overview: + Data structure for the output of the MuZero model. 
+ """ value: torch.Tensor reward: torch.Tensor policy_logits: torch.Tensor latent_state: torch.Tensor +# --- Core Network Components --- + class SimNorm(nn.Module): + """ + Overview: + Implements Simplicial Normalization as described in the paper: https://arxiv.org/abs/2204.00616. + It groups features and applies softmax to each group. + """ def __init__(self, simnorm_dim: int) -> None: """ - Overview: - Simplicial normalization. Adapted from https://arxiv.org/abs/2204.00616. Arguments: - - simnorm_dim (:obj:`int`): The dimension for simplicial normalization. + - simnorm_dim (:obj:`int`): The size of each group (simplex) to apply softmax over. """ super().__init__() self.dim = simnorm_dim @@ -155,185 +160,177 @@ def __init__(self, simnorm_dim: int) -> None: def forward(self, x: torch.Tensor) -> torch.Tensor: """ Overview: - Forward pass of the SimNorm layer. + Forward pass for SimNorm. Arguments: - - x (:obj:`torch.Tensor`): The input tensor to normalize. + - x (:obj:`torch.Tensor`): The input tensor. Returns: - - x (:obj:`torch.Tensor`): The normalized tensor. + - (:obj:`torch.Tensor`): The tensor after applying Simplicial Normalization. """ - shp = x.shape - # Ensure that there is at least one simplex to normalize across. - if shp[1] != 0: - x = x.view(*shp[:-1], -1, self.dim) - x = F.softmax(x, dim=-1) - return x.view(*shp) - else: + if x.shape[1] == 0: return x + # Reshape to (batch, groups, dim) + x_reshaped = x.view(*x.shape[:-1], -1, self.dim) + # Apply softmax over the last dimension (the simplex) + x_softmax = F.softmax(x_reshaped, dim=-1) + # Reshape back to the original tensor shape + return x_softmax.view(*x.shape) def __repr__(self) -> str: - """ - Overview: - String representation of the SimNorm layer. - Returns: - - output (:obj:`str`): The string representation. - """ return f"SimNorm(dim={self.dim})" -def AvgL1Norm(x, eps=1e-8): +def AvgL1Norm(x: torch.Tensor, eps: float = 1e-8) -> torch.Tensor: """ Overview: - Normalize the input tensor by the L1 norm. + Normalizes a tensor by the mean of its absolute values (L1 norm) along the last dimension. Arguments: - x (:obj:`torch.Tensor`): The input tensor to normalize. - - eps (:obj:`float`): The epsilon value to prevent division by zero. + - eps (:obj:`float`): A small epsilon value to prevent division by zero. Returns: - - :obj:`torch.Tensor`: The normalized tensor. + - (:obj:`torch.Tensor`): The normalized tensor. """ - return x / x.abs().mean(-1, keepdim=True).clamp(min=eps) + return x / (x.abs().mean(dim=-1, keepdim=True) + eps) class FeatureAndGradientHook: + """ + Overview: + A utility class to capture and analyze features and gradients of a specific module during + the forward and backward passes. This is useful for debugging and understanding model dynamics. + """ - def __init__(self): + def __init__(self, module: nn.Module): """ - Overview: - Class to capture features and gradients at SimNorm. + Arguments: + - module (:obj:`nn.Module`): The PyTorch module to attach the hooks to. 
""" self.features_before = [] self.features_after = [] self.grads_before = [] self.grads_after = [] + self.forward_handler = module.register_forward_hook(self._forward_hook) + self.backward_handler = module.register_full_backward_hook(self._backward_hook) - def setup_hooks(self, model): - # Hooks to capture features and gradients at SimNorm - self.forward_handler = model.sim_norm.register_forward_hook(self.forward_hook) - self.backward_handler = model.sim_norm.register_full_backward_hook(self.backward_hook) - - def forward_hook(self, module, input, output): + def _forward_hook(self, module: nn.Module, inputs: Tuple[torch.Tensor], output: torch.Tensor) -> None: + """Hook to capture input and output features during the forward pass.""" with torch.no_grad(): - self.features_before.append(input[0]) - self.features_after.append(output) + self.features_before.append(inputs[0].clone().detach()) + self.features_after.append(output.clone().detach()) - def backward_hook(self, module, grad_input, grad_output): + def _backward_hook(self, module: nn.Module, grad_inputs: Tuple[torch.Tensor], grad_outputs: Tuple[torch.Tensor]) -> None: + """Hook to capture input and output gradients during the backward pass.""" with torch.no_grad(): - self.grads_before.append(grad_input[0] if grad_input[0] is not None else None) - self.grads_after.append(grad_output[0] if grad_output[0] is not None else None) + self.grads_before.append(grad_inputs[0].clone().detach() if grad_inputs[0] is not None else None) + self.grads_after.append(grad_outputs[0].clone().detach() if grad_outputs[0] is not None else None) - def analyze(self): - # Calculate L2 norms of features - l2_norm_before = torch.mean(torch.stack([torch.norm(f, p=2, dim=1).mean() for f in self.features_before])) - l2_norm_after = torch.mean(torch.stack([torch.norm(f, p=2, dim=1).mean() for f in self.features_after])) + def analyze(self) -> Tuple[float, float, float, float]: + """ + Overview: + Analyzes the captured features and gradients by computing their average L2 norms. + This method clears the stored data after analysis to free memory. + Returns: + - (:obj:`Tuple[float, float, float, float]`): A tuple containing the L2 norms of + (features_before, features_after, grads_before, grads_after). 
+ """ + if not self.features_before: + return 0.0, 0.0, 0.0, 0.0 - # Calculate norms of gradients - grad_norm_before = torch.mean( - torch.stack([torch.norm(g, p=2, dim=1).mean() for g in self.grads_before if g is not None])) - grad_norm_after = torch.mean( - torch.stack([torch.norm(g, p=2, dim=1).mean() for g in self.grads_after if g is not None])) + l2_norm_before = torch.mean(torch.stack([torch.norm(f, p=2) for f in self.features_before])).item() + l2_norm_after = torch.mean(torch.stack([torch.norm(f, p=2) for f in self.features_after])).item() - # Clear stored data and delete tensors to free memory - self.clear_data() + valid_grads_before = [g for g in self.grads_before if g is not None] + grad_norm_before = torch.mean(torch.stack([torch.norm(g, p=2) for g in valid_grads_before])).item() if valid_grads_before else 0.0 - # Optionally clear CUDA cache - if torch.cuda.is_available(): - torch.cuda.empty_cache() + valid_grads_after = [g for g in self.grads_after if g is not None] + grad_norm_after = torch.mean(torch.stack([torch.norm(g, p=2) for g in valid_grads_after])).item() if valid_grads_after else 0.0 + self.clear_data() return l2_norm_before, l2_norm_after, grad_norm_before, grad_norm_after - def clear_data(self): - del self.features_before[:] - del self.features_after[:] - del self.grads_before[:] - del self.grads_after[:] + def clear_data(self) -> None: + """Clears all stored feature and gradient tensors to free up memory.""" + self.features_before.clear() + self.features_after.clear() + self.grads_before.clear() + self.grads_after.clear() + if torch.cuda.is_available(): + torch.cuda.empty_cache() - def remove_hooks(self): + def remove_hooks(self) -> None: + """Removes the registered forward and backward hooks.""" self.forward_handler.remove() self.backward_handler.remove() class DownSample(nn.Module): + """ + Overview: + A convolutional network for downsampling image-based observations, commonly used in Atari environments. + It consists of a series of convolutional, normalization, and residual blocks. + """ - def __init__(self, observation_shape: SequenceType, out_channels: int, - activation: nn.Module = nn.ReLU(inplace=True), - norm_type: Optional[str] = 'BN', - num_resblocks: int = 1, - ) -> None: + def __init__( + self, + observation_shape: Sequence[int], + out_channels: int, + activation: nn.Module = nn.ReLU(inplace=True), + norm_type: str = 'BN', + num_resblocks: int = 1, + ) -> None: """ - Overview: - Define downSample convolution network. Encode the observation into hidden state. - This network is often used in video games like Atari. In board games like go and chess, - we don't need this module. Arguments: - - observation_shape (:obj:`SequenceType`): The shape of observation space, e.g. [C, W, H]=[12, 96, 96] - for video games like atari, RGB 3 channel times stack 4 frames. - - out_channels (:obj:`int`): The output channels of output hidden state. - - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.ReLU(inplace=True). \ - Use the inplace operation to speed up. - - norm_type (:obj:`Optional[str]`): The normalization type used in network, defaults to 'BN'. - - num_resblocks (:obj:`int`): The number of residual blocks. Defaults to 1. + - observation_shape (:obj:`Sequence[int]`): The shape of the input observation, e.g., (C, H, W). + - out_channels (:obj:`int`): The number of output channels. + - activation (:obj:`nn.Module`): The activation function to use. + - norm_type (:obj:`str`): The type of normalization ('BN' or 'LN'). 
+ - num_resblocks (:obj:`int`): The number of residual blocks in each stage. """ super().__init__() - assert norm_type in ['BN', 'LN'], "norm_type must in ['BN', 'LN']" + if norm_type not in ['BN', 'LN']: + raise ValueError(f"Unsupported norm_type: {norm_type}. Must be 'BN' or 'LN'.") + # The original design was fixed to 1 resblock per stage. + if num_resblocks != 1: + logging.warning(f"DownSample is designed for num_resblocks=1, but got {num_resblocks}.") self.observation_shape = observation_shape - self.conv1 = nn.Conv2d( - observation_shape[0], - out_channels // 2, - kernel_size=3, - stride=2, - padding=1, - bias=False, # disable bias for better convergence - ) - if norm_type == 'BN': - self.norm1 = nn.BatchNorm2d(out_channels // 2) - elif norm_type == 'LN': - self.norm1 = nn.LayerNorm([out_channels // 2, observation_shape[-2] // 2, observation_shape[-1] // 2], - eps=1e-5) + self.activation = activation - self.resblocks1 = nn.ModuleList( - [ - ResBlock( - in_channels=out_channels // 2, - activation=activation, - norm_type=norm_type, - res_type='basic', - bias=False - ) for _ in range(num_resblocks) - ] - ) - self.downsample_block = ResBlock( - in_channels=out_channels // 2, - out_channels=out_channels, - activation=activation, - norm_type=norm_type, - res_type='downsample', - bias=False - ) - self.resblocks2 = nn.ModuleList( - [ - ResBlock( - in_channels=out_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False - ) for _ in range(num_resblocks) - ] - ) + # Initial convolution: stride 2 + self.conv1 = nn.Conv2d(observation_shape[0], out_channels // 2, kernel_size=3, stride=2, padding=1, bias=False) + self.norm1 = build_normalization(norm_type, dim=2)(out_channels // 2) + + # Stage 1 with residual blocks + self.resblocks1 = nn.ModuleList([ + ResBlock(in_channels=out_channels // 2, activation=activation, norm_type=norm_type, res_type='basic', bias=False) + for _ in range(num_resblocks) + ]) + + # Downsample block: stride 2 + self.downsample_block = ResBlock(in_channels=out_channels // 2, out_channels=out_channels, activation=activation, norm_type=norm_type, res_type='downsample', bias=False) + + # Stage 2 with residual blocks + self.resblocks2 = nn.ModuleList([ + ResBlock(in_channels=out_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False) + for _ in range(num_resblocks) + ]) + + # Pooling 1: stride 2 self.pooling1 = nn.AvgPool2d(kernel_size=3, stride=2, padding=1) - self.resblocks3 = nn.ModuleList( - [ - ResBlock( - in_channels=out_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False - ) for _ in range(1) - ] - ) + + # Stage 3 with residual blocks + self.resblocks3 = nn.ModuleList([ + ResBlock(in_channels=out_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False) + for _ in range(num_resblocks) + ]) + + # Final pooling for specific input sizes self.pooling2 = nn.AvgPool2d(kernel_size=3, stride=2, padding=1) - self.activation = activation def forward(self, x: torch.Tensor) -> torch.Tensor: """ Shapes: - - x (:obj:`torch.Tensor`): :math:`(B, C_in, W, H)`, where B is batch size, C_in is channel, W is width, \ - H is height. - - output (:obj:`torch.Tensor`): :math:`(B, C_out, W_, H_)`, where B is batch size, C_out is channel, W_ is \ - output width, H_ is output height. 
+ - x (:obj:`torch.Tensor`): (B, C_in, H, W) + - output (:obj:`torch.Tensor`): (B, C_out, H_out, W_out) """ x = self.conv1(x) x = self.norm1(x) @@ -341,27 +338,27 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: for block in self.resblocks1: x = block(x) + x = self.downsample_block(x) for block in self.resblocks2: x = block(x) + x = self.pooling1(x) for block in self.resblocks3: x = block(x) - # 64, 84, 96 are the most common observation shapes in Atari games. - if self.observation_shape[1] == 64: - output = x - elif self.observation_shape[1] == 84: - x = self.pooling2(x) - output = x - elif self.observation_shape[1] == 96: - x = self.pooling2(x) - output = x + # This part handles specific Atari resolutions. A more general approach might be desirable, + # but we maintain original behavior. + obs_height = self.observation_shape[1] + if obs_height == 64: + return x + elif obs_height in [84, 96]: + return self.pooling2(x) else: - raise NotImplementedError(f"DownSample for observation shape {self.observation_shape} is not implemented now. " - f"You should transform the observation shape to 64 or 96 in the env.") - - return output + raise NotImplementedError( + f"DownSample for observation height {obs_height} is not implemented. " + f"Supported heights are 64, 84, 96." + ) class QwenNetwork(nn.Module): def __init__(self, @@ -482,10 +479,6 @@ def __init__(self, final_norm_option_in_encoder: str = "layernorm", tokenizer=None): """ - Overview: - This class defines a language representation network that utilizes a pretrained Hugging Face model. - The network outputs embeddings with the specified dimension and can optionally use SimNorm or LayerNorm - for normalization at the final stage to ensure training stability. Arguments: - model_path (str): The path to the pretrained Hugging Face model. Default is 'google-bert/bert-base-uncased'. - embedding_size (int): The dimension of the output embeddings. Default is 768. @@ -494,11 +487,9 @@ def __init__(self, - tokenizer (Optional): An instance of a tokenizer. If None, the tokenizer will be loaded from the pretrained model. """ super().__init__() - from transformers import AutoModel, AutoTokenizer - logging.info(f"Loading model from: {model_path}") - # In distributed training, only the rank 0 process downloads the model, and other processes load from cache to speed up startup. + # In distributed settings, ensure only rank 0 downloads the model/tokenizer. if get_rank() == 0: self.pretrained_model = AutoModel.from_pretrained(model_path) @@ -508,22 +499,19 @@ def __init__(self, if get_rank() != 0: self.pretrained_model = AutoModel.from_pretrained(model_path) - if tokenizer is None: - # Only rank 0 downloads the tokenizer, and then other processes load it from cache. - if get_rank() == 0: - self.tokenizer = AutoTokenizer.from_pretrained(model_path) - if get_world_size() > 1: - torch.distributed.barrier() - if get_rank() != 0: + if get_rank() != 0: + logging.info(f"Worker process is loading model from cache: {model_path}") + self.model = AutoModel.from_pretrained(model_path) + if tokenizer is None: self.tokenizer = AutoTokenizer.from_pretrained(model_path) - else: + + if tokenizer is not None: self.tokenizer = tokenizer - # Set the embedding dimension. A linear projection is added (the dimension remains unchanged here but can be extended for other mappings). 
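# NOTE: The rank-aware loading above follows a common "rank 0 downloads first" pattern.
# A minimal, self-contained sketch of that pattern, shown here with plain torch.distributed
# instead of the repo's get_rank()/get_world_size() helpers (an illustrative assumption only):
import torch.distributed as dist
from transformers import AutoModel, AutoTokenizer

def load_pretrained_rank0_first(model_path: str):
    rank = dist.get_rank() if dist.is_initialized() else 0
    if rank == 0:
        # Rank 0 populates the local Hugging Face cache (download or cache hit).
        model = AutoModel.from_pretrained(model_path)
        tokenizer = AutoTokenizer.from_pretrained(model_path)
    if dist.is_initialized() and dist.get_world_size() > 1:
        dist.barrier()  # other ranks wait until the cache is ready
    if rank != 0:
        # Non-zero ranks now load from the already-populated local cache.
        model = AutoModel.from_pretrained(model_path)
        tokenizer = AutoTokenizer.from_pretrained(model_path)
    return model, tokenizer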
self.embedding_size = embedding_size self.embed_proj_head = nn.Linear(self.pretrained_model.config.hidden_size, self.embedding_size) - # # Select the normalization method based on the final_norm_option_in_encoder parameter. + # Select the normalization method based on the final_norm_option_in_encoder parameter. if final_norm_option_in_encoder.lower() == "simnorm": self.norm = SimNorm(simnorm_dim=group_size) elif final_norm_option_in_encoder.lower() == "layernorm": @@ -534,22 +522,18 @@ def __init__(self, def forward(self, x: torch.Tensor, no_grad: bool = True) -> torch.Tensor: """ - Forward Propagation: - Compute the language representation based on the input token sequence. - The [CLS] token’s representation is extracted from the output of the pretrained model, - then passed through a linear projection and final normalization layer (SimNorm or LayerNorm). - + Overview: + Computes language representation from input token IDs. Arguments: - - x (torch.Tensor): Input token sequence of shape [batch_size, seq_len]. - - no_grad (bool): Whether to run in no-gradient mode for memory efficiency. Default is True. + - x (:obj:`torch.Tensor`): Input token sequence of shape (B, seq_len). + - no_grad (:obj:`bool`): If True, run the transformer model in `torch.no_grad()` context. Returns: - - torch.Tensor: The processed language embedding with shape [batch_size, embedding_size]. + - (:obj:`torch.Tensor`): The final language embedding of shape (B, embedding_size). """ # Construct the attention mask to exclude padding tokens. attention_mask = x != self.tokenizer.pad_token_id - # Use no_grad context if specified to disable gradient computation. if no_grad: with torch.no_grad(): x = x.long() # Ensure the input tensor is of type long. @@ -561,9 +545,7 @@ def forward(self, x: torch.Tensor, no_grad: bool = True) -> torch.Tensor: outputs = self.pretrained_model(x, attention_mask=attention_mask) cls_embedding = outputs.last_hidden_state[:, 0, :] - # Apply linear projection to obtain the desired output dimension. cls_embedding = self.embed_proj_head(cls_embedding) - # Normalize the embeddings using the selected normalization layer (SimNorm or LayerNorm) to ensure training stability. 
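# The forward pass above reduces to masked encoding plus [CLS] pooling. A compact sketch,
# assuming `encoder` is any Hugging Face encoder, `pad_token_id` its padding id, and
# `proj`/`final_norm` the projection and normalization modules created in __init__:
import torch
import torch.nn as nn

def cls_pool(encoder: nn.Module, token_ids: torch.Tensor, pad_token_id: int,
             proj: nn.Linear, final_norm: nn.Module) -> torch.Tensor:
    attention_mask = (token_ids != pad_token_id).long()   # ignore padding positions
    outputs = encoder(token_ids.long(), attention_mask=attention_mask)
    cls_embedding = outputs.last_hidden_state[:, 0, :]    # representation of the [CLS] token
    return final_norm(proj(cls_embedding))                # project to embedding_size, then normalize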
cls_embedding = self.norm(cls_embedding) return cls_embedding @@ -605,8 +587,12 @@ def __init__( """ super().__init__() assert norm_type in ['BN', 'LN'], "norm_type must in ['BN', 'LN']" - logging.info(f"Using norm type: {norm_type}") - logging.info(f"Using activation type: {activation}") + + # Only log from rank 0 to avoid excessive output in distributed training + from ding.utils import get_rank + if get_rank() == 0: + logging.info(f"Using norm type: {norm_type}") + logging.info(f"Using activation type: {activation}") self.observation_shape = observation_shape self.downsample = downsample @@ -640,20 +626,36 @@ def __init__( self.activation = activation self.embedding_dim = embedding_dim + # ==================== Modification Start ==================== if self.observation_shape[1] == 64: - self.last_linear = nn.Linear(64 * 8 * 8, self.embedding_dim, bias=False) + # Fix: Replace hardcoded 64 with num_channels + self.last_linear = nn.Linear(num_channels * 8 * 8, self.embedding_dim, bias=False) elif self.observation_shape[1] in [84, 96]: - self.last_linear = nn.Linear(64 * 6 * 6, self.embedding_dim, bias=False) + # Fix: Replace hardcoded 64 with num_channels + self.last_linear = nn.Linear(num_channels * 6 * 6, self.embedding_dim, bias=False) + # ==================== Modification End ==================== - self.final_norm_option_in_encoder = final_norm_option_in_encoder - if self.final_norm_option_in_encoder == 'LayerNorm': + self.final_norm_option_in_encoder=final_norm_option_in_encoder + # Initialize final_norm uniformly in __init__ + if self.final_norm_option_in_encoder in ['LayerNorm', 'LayerNorm_Tanh']: self.final_norm = nn.LayerNorm(self.embedding_dim, eps=1e-5) + elif self.final_norm_option_in_encoder == 'LayerNormNoAffine': + self.final_norm = nn.LayerNorm( + self.embedding_dim, eps=1e-5, elementwise_affine=False + ) elif self.final_norm_option_in_encoder == 'SimNorm': + # Ensure SimNorm is defined self.final_norm = SimNorm(simnorm_dim=group_size) + elif self.final_norm_option_in_encoder == 'L2Norm': + # Directly instantiate our custom L2Norm module + self.final_norm = L2Norm(eps=1e-6) + elif self.final_norm_option_in_encoder is None: + # If no normalization is needed, set to nn.Identity() or None + self.final_norm = nn.Identity() else: raise ValueError(f"Unsupported final_norm_option_in_encoder: {self.final_norm_option_in_encoder}") - + def forward(self, x: torch.Tensor) -> torch.Tensor: """ Shapes: @@ -679,90 +681,75 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = x.view(-1, self.embedding_dim) # NOTE: very important for training stability. - x = self.final_norm(x) + # x = self.final_norm(x) + + # Uniformly call self.final_norm in forward + # This structure is clearer and more extensible + if self.final_norm is not None: + x = self.final_norm(x) + + # Special handling for LayerNorm_Tanh + if self.final_norm_option_in_encoder == 'LayerNorm_Tanh': + x = torch.tanh(x) return x class RepresentationNetwork(nn.Module): - + """ + Overview: + The standard representation network used in MuZero. It encodes a 2D image observation + into a latent state, which retains its spatial dimensions. 
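# The encoders above select their final normalization via `final_norm_option_in_encoder`.
# A condensed sketch of that mapping (the SimNorm/L2Norm branches are elided because they
# are repo-specific modules; 'LayerNorm_Tanh' additionally applies torch.tanh in forward):
from typing import Optional
import torch.nn as nn

def build_final_norm(option: Optional[str], embedding_dim: int) -> nn.Module:
    if option in ('LayerNorm', 'LayerNorm_Tanh'):
        return nn.LayerNorm(embedding_dim, eps=1e-5)
    if option == 'LayerNormNoAffine':
        return nn.LayerNorm(embedding_dim, eps=1e-5, elementwise_affine=False)
    if option is None:
        return nn.Identity()  # no final normalization
    raise ValueError(f"Unsupported final_norm_option_in_encoder: {option}")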
+ """ def __init__( self, - observation_shape: SequenceType = (4, 96, 96), + observation_shape: Sequence[int] = (4, 96, 96), num_res_blocks: int = 1, num_channels: int = 64, downsample: bool = True, activation: nn.Module = nn.ReLU(inplace=True), norm_type: str = 'BN', - embedding_dim: int = 256, - group_size: int = 8, use_sim_norm: bool = False, + group_size: int = 8, ) -> None: """ - Overview: - Representation network used in MuZero and derived algorithms. Encode the 2D image obs into latent state. - Currently, the network only supports obs images with both a width and height of 96. Arguments: - - observation_shape (:obj:`SequenceType`): The shape of observation space, e.g. [C, W, H]=[4, 96, 96] - for video games like atari, 1 gray channel times stack 4 frames. + - observation_shape (:obj:`Sequence[int]`): Shape of the input observation (C, H, W). - num_res_blocks (:obj:`int`): The number of residual blocks. - - num_channels (:obj:`int`): The channel of output hidden state. - - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``, \ - defaults to True. This option is often used in video games like Atari. In board games like go, \ - we don't need this module. - - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.ReLU(inplace=True). \ - Use the inplace operation to speed up. - - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. - - embedding_dim (:obj:`int`): The dimension of the output hidden state. - - group_size (:obj:`int`): The size of group in the SimNorm layer. - - use_sim_norm (:obj:`bool`): Whether to use SimNorm layer, defaults to False. + - num_channels (:obj:`int`): The number of channels in the convolutional layers. + - downsample (:obj:`bool`): Whether to use the `DownSample` module. + - activation (:obj:`nn.Module`): The activation function to use. + - norm_type (:obj:`str`): Normalization type ('BN' or 'LN'). + - use_sim_norm (:obj:`bool`): Whether to apply a final `SimNorm` layer. + - group_size (:obj:`int`): Group size for `SimNorm`. """ super().__init__() - assert norm_type in ['BN', 'LN'], "norm_type must in ['BN', 'LN']" + if norm_type not in ['BN', 'LN']: + raise ValueError(f"Unsupported norm_type: {norm_type}. 
Must be 'BN' or 'LN'.") self.downsample = downsample + self.activation = activation + if self.downsample: - self.downsample_net = DownSample( - observation_shape, - num_channels, - activation=activation, - norm_type=norm_type, - ) + self.downsample_net = DownSample(observation_shape, num_channels, activation, norm_type) else: self.conv = nn.Conv2d(observation_shape[0], num_channels, kernel_size=3, stride=1, padding=1, bias=False) + self.norm = build_normalization(norm_type, dim=2)(num_channels) - if norm_type == 'BN': - self.norm = nn.BatchNorm2d(num_channels) - elif norm_type == 'LN': - if downsample: - self.norm = nn.LayerNorm( - [num_channels, math.ceil(observation_shape[-2] / 16), math.ceil(observation_shape[-1] / 16)], - eps=1e-5) - else: - self.norm = nn.LayerNorm([num_channels, observation_shape[-2], observation_shape[-1]], eps=1e-5) - - self.resblocks = nn.ModuleList( - [ - ResBlock( - in_channels=num_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False - ) for _ in range(num_res_blocks) - ] - ) - self.activation = activation + self.resblocks = nn.ModuleList([ + ResBlock(in_channels=num_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False) + for _ in range(num_res_blocks) + ]) self.use_sim_norm = use_sim_norm - if self.use_sim_norm: - self.embedding_dim = embedding_dim self.sim_norm = SimNorm(simnorm_dim=group_size) def forward(self, x: torch.Tensor) -> torch.Tensor: """ Shapes: - - x (:obj:`torch.Tensor`): :math:`(B, C_in, W, H)`, where B is batch size, C_in is channel, W is width, \ - H is height. - - output (:obj:`torch.Tensor`): :math:`(B, C_out, W_, H_)`, where B is batch size, C_out is channel, W_ is \ - output width, H_ is output height. + - x (:obj:`torch.Tensor`): (B, C_in, H, W) + - output (:obj:`torch.Tensor`): (B, C_out, H_out, W_out) """ if self.downsample: x = self.downsample_net(x) @@ -775,56 +762,55 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = block(x) if self.use_sim_norm: - # NOTE: very important. - # for atari 64,8,8 = 4096 -> 768 - x = self.sim_norm(x) - + # Flatten the spatial dimensions, apply SimNorm, and then reshape back. + b, c, h, w = x.shape + x_flat = x.view(b, c * h * w) + x_norm = self.sim_norm(x_flat) + x = x_norm.view(b, c, h, w) + return x class RepresentationNetworkMLP(nn.Module): - + """ + Overview: + An MLP-based representation network for encoding vector observations into a latent state. + """ def __init__( self, - observation_shape: int, + observation_dim: int, hidden_channels: int = 64, - layer_num: int = 2, + num_layers: int = 2, activation: nn.Module = nn.GELU(approximate='tanh'), norm_type: Optional[str] = 'BN', group_size: int = 8, final_norm_option_in_encoder: str = 'LayerNorm', # TODO ) -> torch.Tensor: """ - Overview: - Representation network used in MuZero and derived algorithms. Encode the vector obs into latent state \ - with Multi-Layer Perceptron (MLP). Arguments: - - observation_shape (:obj:`int`): The shape of vector observation space, e.g. N = 10. - - num_res_blocks (:obj:`int`): The number of residual blocks. - - hidden_channels (:obj:`int`): The channel of output hidden state. - - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``, \ - defaults to True. This option is often used in video games like Atari. In board games like go, \ - we don't need this module. - - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.ReLU(inplace=True). 
\ - Use the inplace operation to speed up. - - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. + - observation_dim (:obj:`int`): The dimension of the input vector observation. + - hidden_channels (:obj:`int`): The number of neurons in the hidden and output layers. + - num_layers (:obj:`int`): The total number of layers in the MLP. + - activation (:obj:`nn.Module`): The activation function to use. + - norm_type (:obj:`Optional[str]`): The type of normalization ('BN', 'LN', or None). + - group_size (:obj:`int`): The group size for the final `SimNorm` layer. """ super().__init__() - self.fc_representation = MLP( - in_channels=observation_shape, - hidden_channels=hidden_channels, + # Creating hidden layers list for MLP_V2 + hidden_layers = [hidden_channels] * (num_layers - 1) if num_layers > 1 else [] + + self.fc_representation = MLP_V2( + in_channels=observation_dim, + hidden_channels=hidden_layers, out_channels=hidden_channels, - layer_num=layer_num, activation=activation, norm_type=norm_type, - # don't use activation and norm in the last layer of representation network is important for convergence. output_activation=False, output_norm=False, - # last_linear_layer_init_zero=True is beneficial for convergence speed. last_linear_layer_init_zero=True, ) - # # Select the normalization method based on the final_norm_option_in_encoder parameter. + # Select the normalization method based on the final_norm_option_in_encoder parameter. if final_norm_option_in_encoder.lower() == "simnorm": self.norm = SimNorm(simnorm_dim=group_size) elif final_norm_option_in_encoder.lower() == "layernorm": @@ -836,8 +822,8 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: """ Shapes: - - x (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size, N is the length of vector observation. - - output (:obj:`torch.Tensor`): :math:`(B, hidden_channels)`, where B is batch size. + - x (:obj:`torch.Tensor`): (B, observation_dim) + - output (:obj:`torch.Tensor`): (B, hidden_channels) """ x = self.fc_representation(x) x = self.norm(x) @@ -846,228 +832,232 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class LatentDecoder(nn.Module): - - def __init__(self, embedding_dim: int, output_shape: SequenceType, num_channels: int = 64, activation: nn.Module = nn.GELU(approximate='tanh')): + """ + Overview: + A decoder network that reconstructs a 2D image from a 1D latent embedding. + It acts as the inverse of a representation network like `RepresentationNetworkUniZero`. + """ + def __init__( + self, + embedding_dim: int, + output_shape: Tuple[int, int, int], + num_channels: int = 64, + activation: nn.Module = nn.GELU(approximate='tanh') + ): """ - Overview: - Decoder network used in UniZero. Decode the latent state into 2D image obs. Arguments: - - embedding_dim (:obj:`int`): The dimension of the latent state. - - output_shape (:obj:`SequenceType`): The shape of observation space, e.g. [C, W, H]=[3, 64, 64] - for video games like atari, RGB 3 channel times stack 4 frames. - - num_channels (:obj:`int`): The channel of output hidden state. - - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.GELU(approximate='tanh'). + - embedding_dim (:obj:`int`): The dimension of the input latent embedding. + - output_shape (:obj:`Tuple[int, int, int]`): The shape of the target output image (C, H, W). + - num_channels (:obj:`int`): The base number of channels for the initial upsampling stage. 
+ - activation (:obj:`nn.Module`): The activation function to use. """ super().__init__() self.embedding_dim = embedding_dim - self.output_shape = output_shape # (C, H, W) - self.num_channels = num_channels - self.activation = activation - - # Assuming that the output shape is (C, H, W) = (12, 96, 96) and embedding_dim is 256 - # We will reverse the process of the representation network - self.initial_size = ( - num_channels, output_shape[1] // 8, output_shape[2] // 8) # This should match the last layer of the encoder - self.fc = nn.Linear(self.embedding_dim, np.prod(self.initial_size)) + self.output_shape = output_shape + + # This should match the spatial size of the encoder's feature map before flattening. + # Assuming a total downsampling factor of 8 (e.g., for a 64x64 -> 8x8 encoder). + self.initial_h = output_shape[1] // 8 + self.initial_w = output_shape[2] // 8 + self.initial_size = (num_channels, self.initial_h, self.initial_w) + + self.fc = nn.Linear(embedding_dim, np.prod(self.initial_size)) - # Upsampling blocks - self.conv_blocks = nn.ModuleList([ - # Block 1: (num_channels, H/8, W/8) -> (num_channels//2, H/4, W/4) + self.deconv_blocks = nn.Sequential( + # Block 1: (C, H/8, W/8) -> (C/2, H/4, W/4) nn.ConvTranspose2d(num_channels, num_channels // 2, kernel_size=3, stride=2, padding=1, output_padding=1), - self.activation, + activation, nn.BatchNorm2d(num_channels // 2), - # Block 2: (num_channels//2, H/4, W/4) -> (num_channels//4, H/2, W/2) - nn.ConvTranspose2d(num_channels // 2, num_channels // 4, kernel_size=3, stride=2, padding=1, - output_padding=1), - self.activation, + # Block 2: (C/2, H/4, W/4) -> (C/4, H/2, W/2) + nn.ConvTranspose2d(num_channels // 2, num_channels // 4, kernel_size=3, stride=2, padding=1, output_padding=1), + activation, nn.BatchNorm2d(num_channels // 4), - # Block 3: (num_channels//4, H/2, W/2) -> (output_shape[0], H, W) - nn.ConvTranspose2d(num_channels // 4, output_shape[0], kernel_size=3, stride=2, padding=1, - output_padding=1), - ]) - # TODO: last layer use sigmoid? + # Block 3: (C/4, H/2, W/2) -> (output_C, H, W) + nn.ConvTranspose2d(num_channels // 4, output_shape[0], kernel_size=3, stride=2, padding=1, output_padding=1), + # A final activation like Sigmoid or Tanh is often used if pixel values are in a fixed range [0,1] or [-1,1]. + # We omit it here to maintain consistency with the original code. + ) def forward(self, embeddings: torch.Tensor) -> torch.Tensor: - # Map embeddings back to the image space - x = self.fc(embeddings) # (B, embedding_dim) -> (B, C*H/8*W/8) - x = x.view(-1, *self.initial_size) # (B, C*H/8*W/8) -> (B, C, H/8, W/8) - - # Apply conv blocks - for block in self.conv_blocks: - x = block(x) # Upsample progressively - - # The output x should have the shape of (B, output_shape[0], output_shape[1], output_shape[2]) + """ + Shapes: + - embeddings (:obj:`torch.Tensor`): (B, embedding_dim) + - output (:obj:`torch.Tensor`): (B, C, H, W) + """ + x = self.fc(embeddings) + x = x.view(-1, *self.initial_size) + x = self.deconv_blocks(x) return x -class LatentEncoderForMemoryEnv(nn.Module): +# --- Networks for MemoryEnv --- +class LatentEncoderForMemoryEnv(nn.Module): + """ + Overview: + An encoder for the MemoryEnv, converting a small image observation into a latent embedding. + It uses a series of convolutions followed by adaptive average pooling. 
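# Shape sanity-check for the `LatentDecoder` above, assuming it is exported from
# lzero/model/common.py as in the rest of this diff (64x64 target, factor-8 upsampling):
import torch
from lzero.model.common import LatentDecoder

decoder = LatentDecoder(embedding_dim=256, output_shape=(3, 64, 64), num_channels=64)
latent = torch.randn(8, 256)
reconstruction = decoder(latent)   # fc -> (64, 8, 8) -> three stride-2 deconvs
print(reconstruction.shape)        # expected: torch.Size([8, 3, 64, 64])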
+ """ def __init__( self, - image_shape=(3, 5, 5), - embedding_size=100, - channels=[16, 32, 64], - kernel_sizes=[3, 3, 3], - strides=[1, 1, 1], + image_shape: Tuple[int, int, int] = (3, 5, 5), + embedding_size: int = 100, + channels: List[int] = [16, 32, 64], + kernel_sizes: List[int] = [3, 3, 3], + strides: List[int] = [1, 1, 1], activation: nn.Module = nn.GELU(approximate='tanh'), - normalize_pixel=False, + normalize_pixel: bool = False, group_size: int = 8, - **kwargs, ): """ - Overview: - Encoder network used in UniZero in MemoryEnv. Encode the 2D image obs into latent state. Arguments: - - image_shape (:obj:`SequenceType`): The shape of observation space, e.g. [C, W, H]=[3, 64, 64] - for video games like atari, RGB 3 channel times stack 4 frames. - - embedding_size (:obj:`int`): The dimension of the latent state. - - channels (:obj:`List[int]`): The channel of output hidden state. - - kernel_sizes (:obj:`List[int]`): The kernel size of convolution layers. - - strides (:obj:`List[int]`): The stride of convolution layers. - - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.GELU(approximate='tanh'). \ - Use the inplace operation to speed up. - - normalize_pixel (:obj:`bool`): Whether to normalize the pixel values to [0, 1], defaults to False. - - group_size (:obj:`int`): The dimension for simplicial normalization + - image_shape (:obj:`Tuple[int, int, int]`): Shape of the input image (C, H, W). + - embedding_size (:obj:`int`): Dimension of the output latent embedding. + - channels (:obj:`List[int]`): List of output channels for each convolutional layer. + - kernel_sizes (:obj:`List[int]`): List of kernel sizes for each convolutional layer. + - strides (:obj:`List[int]`): List of strides for each convolutional layer. + - activation (:obj:`nn.Module`): Activation function to use. + - normalize_pixel (:obj:`bool`): Whether to normalize input pixel values to [0, 1]. + - group_size (:obj:`int`): Group size for the final `SimNorm` layer. 
""" - super(LatentEncoderForMemoryEnv, self).__init__() - self.shape = image_shape - self.channels = [image_shape[0]] + list(channels) + super().__init__() + self.normalize_pixel = normalize_pixel + all_channels = [image_shape[0]] + channels layers = [] - for i in range(len(self.channels) - 1): - layers.append( - nn.Conv2d( - self.channels[i], self.channels[i + 1], kernel_sizes[i], strides[i], - padding=kernel_sizes[i] // 2 # keep the same size of feature map - ) - ) - layers.append(nn.BatchNorm2d(self.channels[i + 1])) - layers.append(activation) - + for i in range(len(channels)): + layers.extend([ + nn.Conv2d(all_channels[i], all_channels[i+1], kernel_sizes[i], strides[i], padding=kernel_sizes[i]//2), + nn.BatchNorm2d(all_channels[i+1]), + activation + ]) layers.append(nn.AdaptiveAvgPool2d(1)) - self.cnn = nn.Sequential(*layers) - self.linear = nn.Sequential( - nn.Linear(self.channels[-1], embedding_size, bias=False), - ) - init.kaiming_normal_(self.linear[0].weight, mode='fan_out', nonlinearity='relu') + + self.linear = nn.Linear(channels[-1], embedding_size, bias=False) + init.kaiming_normal_(self.linear.weight, mode='fan_out', nonlinearity='relu') - self.normalize_pixel = normalize_pixel self.sim_norm = SimNorm(simnorm_dim=group_size) - def forward(self, image): + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Shapes: + - image (:obj:`torch.Tensor`): (B, C, H, W) + - output (:obj:`torch.Tensor`): (B, embedding_size) + """ if self.normalize_pixel: - image = image / 255.0 - x = self.cnn(image.float()) # (B, C, 1, 1) - x = torch.flatten(x, start_dim=1) # (B, C) - x = self.linear(x) # (B, embedding_size) + image = image.float() / 255.0 + + x = self.cnn(image.float()) + x = torch.flatten(x, start_dim=1) + x = self.linear(x) x = self.sim_norm(x) return x class LatentDecoderForMemoryEnv(nn.Module): - + """ + Overview: + A decoder for the MemoryEnv, reconstructing a small image from a latent embedding. + It uses a linear layer followed by a series of transposed convolutions. + """ def __init__( self, - image_shape=(3, 5, 5), - embedding_size=256, - channels=[64, 32, 16], - kernel_sizes=[3, 3, 3], - strides=[1, 1, 1], + image_shape: Tuple[int, int, int] = (3, 5, 5), + embedding_size: int = 256, + channels: List[int] = [64, 32, 16], + kernel_sizes: List[int] = [3, 3, 3], + strides: List[int] = [1, 1, 1], activation: nn.Module = nn.LeakyReLU(negative_slope=0.01), - **kwargs, ): """ - Overview: - Decoder network used in UniZero in MemoryEnv. Decode the latent state into 2D image obs. Arguments: - - image_shape (:obj:`SequenceType`): The shape of observation space, e.g. [C, W, H]=[3, 64, 64] - for video games like atari, RGB 3 channel times stack 4 frames. - - embedding_size (:obj:`int`): The dimension of the latent state. - - channels (:obj:`List[int]`): The channel of output hidden state. - - kernel_sizes (:obj:`List[int]`): The kernel size of convolution layers. - - strides (:obj:`List[int]`): The stride of convolution layers. - - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.LeakyReLU(). \ - Use the inplace operation to speed up. + - image_shape (:obj:`Tuple[int, int, int]`): Shape of the target output image (C, H, W). + - embedding_size (:obj:`int`): Dimension of the input latent embedding. + - channels (:obj:`List[int]`): List of channels for each deconvolutional layer. + - kernel_sizes (:obj:`List[int]`): List of kernel sizes. + - strides (:obj:`List[int]`): List of strides. 
+ - activation (:obj:`nn.Module`): Activation function for intermediate layers. """ - super(LatentDecoderForMemoryEnv, self).__init__() + super().__init__() self.shape = image_shape - self.channels = list(channels) + [image_shape[0]] - + self.deconv_channels = channels + [image_shape[0]] + self.linear = nn.Linear(embedding_size, channels[0] * image_shape[1] * image_shape[2]) layers = [] - for i in range(len(self.channels) - 1): + for i in range(len(self.deconv_channels) - 1): layers.append( nn.ConvTranspose2d( - self.channels[i], self.channels[i + 1], kernel_sizes[i], strides[i], - padding=kernel_sizes[i] // 2, output_padding=strides[i] - 1 + self.deconv_channels[i], self.deconv_channels[i+1], kernel_sizes[i], strides[i], + padding=kernel_sizes[i]//2, output_padding=strides[i]-1 ) ) - if i < len(self.channels) - 2: - layers.append(nn.BatchNorm2d(self.channels[i + 1])) - layers.append(activation) + if i < len(self.deconv_channels) - 2: + layers.extend([nn.BatchNorm2d(self.deconv_channels[i+1]), activation]) else: + # Final layer uses Sigmoid to output pixel values in [0, 1]. layers.append(nn.Sigmoid()) - self.deconv = nn.Sequential(*layers) - def forward(self, embedding): + def forward(self, embedding: torch.Tensor) -> torch.Tensor: + """ + Shapes: + - embedding (:obj:`torch.Tensor`): (B, embedding_size) + - output (:obj:`torch.Tensor`): (B, C, H, W) + """ x = self.linear(embedding) - x = x.view(-1, self.channels[0], self.shape[1], self.shape[2]) - x = self.deconv(x) # (B, C, H, W) + x = x.view(-1, self.deconv_channels[0], self.shape[1], self.shape[2]) + x = self.deconv(x) return x class VectorDecoderForMemoryEnv(nn.Module): - + """ + Overview: + An MLP-based decoder for MemoryEnv, reconstructing a vector observation from a latent embedding. + """ def __init__( self, embedding_dim: int, - output_shape: SequenceType, + output_dim: int, hidden_channels: int = 64, - layer_num: int = 2, - activation: nn.Module = nn.LeakyReLU(negative_slope=0.01), # TODO + num_layers: int = 2, + activation: nn.Module = nn.LeakyReLU(negative_slope=0.01), norm_type: Optional[str] = 'BN', - ) -> torch.Tensor: + ) -> None: """ - Overview: - Decoder network used in UniZero in MemoryEnv. Decode the latent state into vector obs. Arguments: - - observation_shape (:obj:`int`): The shape of vector observation space, e.g. N = 10. - - num_res_blocks (:obj:`int`): The number of residual blocks. - - hidden_channels (:obj:`int`): The channel of output hidden state. - - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``, \ - defaults to True. This option is often used in video games like Atari. In board games like go, \ - we don't need this module. - - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.ReLU(). \ - Use the inplace operation to speed up. - - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. + - embedding_dim (:obj:`int`): Dimension of the input latent embedding. + - output_dim (:obj:`int`): Dimension of the target output vector. + - hidden_channels (:obj:`int`): Number of neurons in the hidden layers. + - num_layers (:obj:`int`): Total number of layers in the MLP. + - activation (:obj:`nn.Module`): Activation function to use. + - norm_type (:obj:`Optional[str]`): Normalization type ('BN', 'LN', or None). 
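# Usage sketch for `VectorDecoderForMemoryEnv` (import path assumed; the output_dim=25 used
# here is a hypothetical value, simply illustrating decoding a 256-d latent to a 25-d vector):
import torch
from lzero.model.common import VectorDecoderForMemoryEnv

decoder = VectorDecoderForMemoryEnv(embedding_dim=256, output_dim=25, hidden_channels=64, num_layers=2)
latent = torch.randn(4, 256)
recon = decoder(latent)
print(recon.shape)   # expected: torch.Size([4, 25])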
""" super().__init__() - self.fc_representation = MLP( + hidden_layers = [hidden_channels] * (num_layers - 1) if num_layers > 1 else [] + + self.fc_decoder = MLP_V2( in_channels=embedding_dim, - hidden_channels=hidden_channels, - out_channels=output_shape, - layer_num=layer_num, + hidden_channels=hidden_layers, + out_channels=output_dim, activation=activation, norm_type=norm_type, - # don't use activation and norm in the last layer of representation network is important for convergence. output_activation=False, output_norm=False, - # last_linear_layer_init_zero=True is beneficial for convergence speed. last_linear_layer_init_zero=True, ) def forward(self, x: torch.Tensor) -> torch.Tensor: """ Shapes: - - x (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size, N is the length of vector observation. - - output (:obj:`torch.Tensor`): :math:`(B, hidden_channels)`, where B is batch size. + - x (:obj:`torch.Tensor`): (B, embedding_dim) + - output (:obj:`torch.Tensor`): (B, output_dim) """ - x = self.fc_representation(x) - return x + return self.fc_decoder(x) +# --- Prediction Networks --- class PredictionNetwork(nn.Module): @@ -1207,232 +1197,77 @@ def forward(self, latent_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tenso class PredictionNetworkMLP(nn.Module): - + """ + Overview: + An MLP-based prediction network that predicts policy and value from a 1D latent state. + """ def __init__( self, - action_space_size, - num_channels, + action_space_size: int, + num_channels: int, common_layer_num: int = 2, - value_head_hidden_channels: SequenceType = [32], - policy_head_hidden_channels: SequenceType = [32], + value_head_hidden_channels: List[int] = [32], + policy_head_hidden_channels: List[int] = [32], output_support_size: int = 601, last_linear_layer_init_zero: bool = True, - activation: Optional[nn.Module] = nn.ReLU(inplace=True), + activation: nn.Module = nn.ReLU(inplace=True), norm_type: Optional[str] = 'BN', ): """ - Overview: - The definition of policy and value prediction network with Multi-Layer Perceptron (MLP), - which is used to predict value and policy by the given latent state. Arguments: - - action_space_size: (:obj:`int`): Action space size, usually an integer number. For discrete action \ - space, it is the number of discrete actions. - - num_channels (:obj:`int`): The channels of latent states. - - value_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head). - - policy_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head). - - output_support_size (:obj:`int`): The size of categorical value output. - - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of \ - dynamics/prediction mlp, default sets it to True. - - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ - operation to speedup, e.g. ReLU(inplace=True). - - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. + - action_space_size: (:obj:`int`): The size of the action space. + - num_channels (:obj:`int`): The dimension of the input latent state. + - common_layer_num (:obj:`int`): Number of layers in the shared backbone MLP. + - value_head_hidden_channels (:obj:`List[int]`): Hidden layer sizes for the value MLP head. + - policy_head_hidden_channels (:obj:`List[int]`): Hidden layer sizes for the policy MLP head. 
+ - output_support_size (:obj:`int`): The size of the categorical value distribution. + - last_linear_layer_init_zero (:obj:`bool`): Whether to initialize the last layer of heads to zero. + - activation (:obj:`nn.Module`): The activation function. + - norm_type (:obj:`Optional[str]`): The normalization type. """ super().__init__() - self.num_channels = num_channels - - # ******* common backbone ****** - self.fc_prediction_common = MLP( - in_channels=self.num_channels, - hidden_channels=self.num_channels, - out_channels=self.num_channels, - layer_num=common_layer_num, + + common_hidden = [num_channels] * (common_layer_num - 1) if common_layer_num > 1 else [] + self.fc_prediction_common = MLP_V2( + in_channels=num_channels, + hidden_channels=common_hidden, + out_channels=num_channels, activation=activation, norm_type=norm_type, output_activation=True, output_norm=True, - # last_linear_layer_init_zero=False is important for convergence last_linear_layer_init_zero=False, ) - # ******* value and policy head ****** self.fc_value_head = MLP_V2( - in_channels=self.num_channels, + in_channels=num_channels, hidden_channels=value_head_hidden_channels, out_channels=output_support_size, activation=activation, norm_type=norm_type, output_activation=False, output_norm=False, - # last_linear_layer_init_zero=True is beneficial for convergence speed. last_linear_layer_init_zero=last_linear_layer_init_zero ) self.fc_policy_head = MLP_V2( - in_channels=self.num_channels, + in_channels=num_channels, hidden_channels=policy_head_hidden_channels, out_channels=action_space_size, activation=activation, norm_type=norm_type, output_activation=False, output_norm=False, - # last_linear_layer_init_zero=True is beneficial for convergence speed. - last_linear_layer_init_zero=last_linear_layer_init_zero - ) - - def forward(self, latent_state: torch.Tensor): - """ - Overview: - Forward computation of the prediction network. - Arguments: - - latent_state (:obj:`torch.Tensor`): input tensor with shape (B, latent_state_dim). - Returns: - - policy (:obj:`torch.Tensor`): policy tensor with shape (B, action_space_size). - - value (:obj:`torch.Tensor`): value tensor with shape (B, output_support_size). - """ - x_prediction_common = self.fc_prediction_common(latent_state) - - value = self.fc_value_head(x_prediction_common) - policy = self.fc_policy_head(x_prediction_common) - return policy, value - - -class PredictionHiddenNetwork(nn.Module): - - def __init__( - self, - observation_shape: SequenceType, - action_space_size: int, - num_res_blocks: int, - num_channels: int, - value_head_channels: int, - policy_head_channels: int, - value_head_hidden_channels: int, - policy_head_hidden_channels: int, - output_support_size: int, - flatten_input_size_for_value_head: int, - flatten_input_size_for_policy_head: int, - downsample: bool = False, - last_linear_layer_init_zero: bool = True, - activation: nn.Module = nn.ReLU(inplace=True), - norm_type: Optional[str] = 'BN', - gru_hidden_size: int = 512, - ) -> None: - """ - Overview: - The definition of policy and value prediction network, which is used to predict value and policy by the - given latent state. - Arguments: - - observation_shape (:obj:`SequenceType`): The shape of observation space, e.g. (C, H, W) for image. - - action_space_size: (:obj:`int`): Action space size, usually an integer number for discrete action space. - - num_res_blocks (:obj:`int`): The number of res blocks in AlphaZero model. - - num_channels (:obj:`int`): The channels of hidden states. 
- - value_head_channels (:obj:`int`): The channels of value head. - - policy_head_channels (:obj:`int`): The channels of policy head. - - value_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head). - - policy_head_hidden_channels (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head). - - output_support_size (:obj:`int`): The size of categorical value output. - - self_supervised_learning_loss (:obj:`bool`): Whether to use self_supervised_learning related networks \ - - flatten_input_size_for_value_head (:obj:`int`): The size of flatten hidden states, i.e. the input size \ - of the value head. - - flatten_input_size_for_policy_head (:obj:`int`): The size of flatten hidden states, i.e. the input size \ - of the policy head. - - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``. - - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of \ - dynamics/prediction mlp, default sets it to True. - - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ - operation to speedup, e.g. ReLU(inplace=True). - - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. - """ - super(PredictionHiddenNetwork, self).__init__() - assert norm_type in ['BN', 'LN'], "norm_type must in ['BN', 'LN']" - - self.observation_shape = observation_shape - self.gru_hidden_size = gru_hidden_size - self.resblocks = nn.ModuleList( - [ - ResBlock( - in_channels=num_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False - ) for _ in range(num_res_blocks) - ] - ) - - self.conv1x1_value = nn.Conv2d(num_channels, value_head_channels, 1) - self.conv1x1_policy = nn.Conv2d(num_channels, policy_head_channels, 1) - - if norm_type == 'BN': - self.norm_value = nn.BatchNorm2d(value_head_channels) - self.norm_policy = nn.BatchNorm2d(policy_head_channels) - elif norm_type == 'LN': - if downsample: - self.norm_value = nn.LayerNorm( - [value_head_channels, math.ceil(observation_shape[-2] / 16), math.ceil(observation_shape[-1] / 16)], - eps=1e-5) - self.norm_policy = nn.LayerNorm([policy_head_channels, math.ceil(observation_shape[-2] / 16), - math.ceil(observation_shape[-1] / 16)], eps=1e-5) - else: - self.norm_value = nn.LayerNorm([value_head_channels, observation_shape[-2], observation_shape[-1]], - eps=1e-5) - self.norm_policy = nn.LayerNorm([policy_head_channels, observation_shape[-2], observation_shape[-1]], - eps=1e-5) - - self.flatten_input_size_for_value_head = flatten_input_size_for_value_head - self.flatten_input_size_for_policy_head = flatten_input_size_for_policy_head - - self.activation = activation - - self.fc_value = MLP( - in_channels=self.flatten_input_size_for_value_head + self.gru_hidden_size, - hidden_channels=value_head_hidden_channels[0], - out_channels=output_support_size, - layer_num=len(value_head_hidden_channels) + 1, - activation=self.activation, - norm_type=norm_type, - output_activation=False, - output_norm=False, - # last_linear_layer_init_zero=True is beneficial for convergence speed. 
- last_linear_layer_init_zero=last_linear_layer_init_zero - ) - self.fc_policy = MLP( - in_channels=self.flatten_input_size_for_policy_head + self.gru_hidden_size, - hidden_channels=policy_head_hidden_channels[0], - out_channels=action_space_size, - layer_num=len(policy_head_hidden_channels) + 1, - activation=self.activation, - norm_type=norm_type, - output_activation=False, - output_norm=False, - # last_linear_layer_init_zero=True is beneficial for convergence speed. last_linear_layer_init_zero=last_linear_layer_init_zero ) - def forward(self, latent_state: torch.Tensor, world_model_latent_history: torch.Tensor) -> Tuple[ - torch.Tensor, torch.Tensor]: + def forward(self, latent_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ - Overview: - Forward computation of the prediction network. - Arguments: - - latent_state (:obj:`torch.Tensor`): input tensor with shape (B, latent_state_dim). - Returns: - - policy (:obj:`torch.Tensor`): policy tensor with shape (B, action_space_size). - - value (:obj:`torch.Tensor`): value tensor with shape (B, output_support_size). + Shapes: + - latent_state (:obj:`torch.Tensor`): (B, num_channels) + - policy_logits (:obj:`torch.Tensor`): (B, action_space_size) + - value (:obj:`torch.Tensor`): (B, output_support_size) """ - for res_block in self.resblocks: - latent_state = res_block(latent_state) - - value = self.conv1x1_value(latent_state) - value = self.norm_value(value) - value = self.activation(value) - - policy = self.conv1x1_policy(latent_state) - policy = self.norm_policy(policy) - policy = self.activation(policy) - - latent_state_value = value.reshape(-1, self.flatten_input_size_for_value_head) - latent_state_policy = policy.reshape(-1, self.flatten_input_size_for_policy_head) - - # TODO: world_model_latent_history.squeeze(0) shape: (num_layers * num_directions, batch_size, hidden_size) -> ( batch_size, hidden_size) - latent_history_value = torch.cat([latent_state_value, world_model_latent_history.squeeze(0)], dim=1) - latent_history_policy = torch.cat([latent_state_policy, world_model_latent_history.squeeze(0)], dim=1) - - value = self.fc_value(latent_history_value) - policy = self.fc_policy(latent_history_policy) - return policy, value \ No newline at end of file + x = self.fc_prediction_common(latent_state) + value = self.fc_value_head(x) + policy_logits = self.fc_policy_head(x) + return policy_logits, value \ No newline at end of file diff --git a/lzero/model/efficientzero_model.py b/lzero/model/efficientzero_model.py index 3448fe5b8..162f8910f 100644 --- a/lzero/model/efficientzero_model.py +++ b/lzero/model/efficientzero_model.py @@ -12,7 +12,7 @@ from numpy import ndarray from .common import RepresentationNetwork, PredictionNetwork, EZNetworkOutput -from .utils import renormalize, get_params_mean, get_dynamic_mean, get_reward_mean +from .utils import renormalize, get_params_mean # use ModelRegistry to register the model, for more details about ModelRegistry, please refer to DI-engine's document. 
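# The refactored `PredictionNetworkMLP` earlier in this diff can be exercised directly.
# A minimal sketch (import path assumed; default MLP heads and categorical support size):
import torch
from lzero.model.common import PredictionNetworkMLP

net = PredictionNetworkMLP(action_space_size=6, num_channels=64, output_support_size=601)
latent = torch.randn(8, 64)
policy_logits, value = net(latent)
print(policy_logits.shape, value.shape)   # expected: torch.Size([8, 6]) torch.Size([8, 601])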
@@ -567,9 +567,3 @@ def forward(self, state_action_encoding: torch.Tensor, value_prefix = self.fc_reward_head(value_prefix) return next_latent_state, next_reward_hidden_state, value_prefix - - def get_dynamic_mean(self) -> float: - return get_dynamic_mean(self) - - def get_reward_mean(self) -> Tuple[ndarray, float]: - return get_reward_mean(self) diff --git a/lzero/model/efficientzero_model_mlp.py b/lzero/model/efficientzero_model_mlp.py index 862f6417c..69432712a 100644 --- a/lzero/model/efficientzero_model_mlp.py +++ b/lzero/model/efficientzero_model_mlp.py @@ -7,7 +7,7 @@ from numpy import ndarray from .common import EZNetworkOutput, RepresentationNetworkMLP, PredictionNetworkMLP -from .utils import renormalize, get_params_mean, get_dynamic_mean, get_reward_mean +from .utils import renormalize, get_params_mean @MODEL_REGISTRY.register('EfficientZeroModelMLP') @@ -467,9 +467,3 @@ def forward(self, state_action_encoding: torch.Tensor, reward_hidden_state): value_prefix = self.fc_reward_head(value_prefix.squeeze(0)) return next_latent_state, next_reward_hidden_state, value_prefix - - def get_dynamic_mean(self) -> float: - return get_dynamic_mean(self) - - def get_reward_mean(self) -> Tuple[ndarray, float]: - return get_reward_mean(self) diff --git a/lzero/model/muzero_model.py b/lzero/model/muzero_model.py index 75680ac06..482818114 100644 --- a/lzero/model/muzero_model.py +++ b/lzero/model/muzero_model.py @@ -12,7 +12,7 @@ from numpy import ndarray from .common import MZNetworkOutput, RepresentationNetwork, PredictionNetwork, FeatureAndGradientHook, MLP_V2 -from .utils import renormalize, get_params_mean, get_dynamic_mean, get_reward_mean +from .utils import renormalize, get_params_mean # use ModelRegistry to register the model, for more details about ModelRegistry, please refer to DI-engine's document. @@ -536,9 +536,3 @@ def forward(self, state_action_encoding: torch.Tensor) -> Tuple[torch.Tensor, to reward = self.fc_reward_head(x) return next_latent_state, reward - - def get_dynamic_mean(self) -> float: - return get_dynamic_mean(self) - - def get_reward_mean(self) -> Tuple[ndarray, float]: - return get_reward_mean(self) diff --git a/lzero/model/muzero_model_mlp.py b/lzero/model/muzero_model_mlp.py index 17565b018..8dba756d4 100644 --- a/lzero/model/muzero_model_mlp.py +++ b/lzero/model/muzero_model_mlp.py @@ -6,7 +6,7 @@ from ding.utils import MODEL_REGISTRY, SequenceType from .common import MZNetworkOutput, RepresentationNetworkMLP, PredictionNetworkMLP, MLP_V2 -from .utils import renormalize, get_params_mean, get_dynamic_mean, get_reward_mean +from .utils import renormalize, get_params_mean @MODEL_REGISTRY.register('MuZeroModelMLP') @@ -440,9 +440,3 @@ def forward(self, state_action_encoding: torch.Tensor) -> Tuple[torch.Tensor, to reward = self.fc_reward_head(next_latent_state_encoding) return next_latent_state, reward - - def get_dynamic_mean(self) -> float: - return get_dynamic_mean(self) - - def get_reward_mean(self) -> float: - return get_reward_mean(self) diff --git a/lzero/model/muzero_model_multitask.py b/lzero/model/muzero_model_multitask.py new file mode 100644 index 000000000..d18267b53 --- /dev/null +++ b/lzero/model/muzero_model_multitask.py @@ -0,0 +1,531 @@ +from typing import Optional, Tuple, Sequence, List + +import math +import torch +import torch.nn as nn +from ding.torch_utils import MLP, ResBlock +from ding.utils import MODEL_REGISTRY, SequenceType +from numpy import ndarray + +# The following imports are assumed to be from the same project directory. 
+# To maintain API consistency, their internal logic is not modified. +from .common import MZNetworkOutput, RepresentationNetwork, PredictionNetwork, FeatureAndGradientHook +from .utils import renormalize, get_params_mean + + +@MODEL_REGISTRY.register('MuZeroMTModel') +class MuZeroMTModel(nn.Module): + """ + Overview: + The Multi-Task MuZero model, which is a variant of the original MuZero model adapted for multi-task learning. + This model features a shared representation network and dynamics network, but utilizes separate, task-specific + prediction networks. This architecture allows the model to learn shared dynamics while specializing its + policy and value predictions for each individual task. + """ + # Default configuration for the model. + # This structure is recommended over using cfg.get('key', default_value) inside the code. + config = dict( + observation_shape=(12, 96, 96), + action_space_size=6, + num_res_blocks=1, + num_channels=64, + reward_head_channels=16, + value_head_channels=16, + policy_head_channels=16, + fc_reward_layers=[32], + fc_value_layers=[32], + fc_policy_layers=[32], + reward_support_size=601, + value_support_size=601, + proj_hid=1024, + proj_out=1024, + pred_hid=512, + pred_out=1024, + self_supervised_learning_loss=False, + categorical_distribution=True, + activation=nn.ReLU(inplace=True), + last_linear_layer_init_zero=True, + state_norm=False, + downsample=False, + norm_type='BN', + discrete_action_encoding_type='one_hot', + analysis_sim_norm=False, + task_num=1, + ) + + def __init__( + self, + observation_shape: SequenceType = (12, 96, 96), + action_space_size: int = 6, + num_res_blocks: int = 1, + num_channels: int = 64, + reward_head_channels: int = 16, + value_head_channels: int = 16, + policy_head_channels: int = 16, + fc_reward_layers: List[int] = [32], + fc_value_layers: List[int] = [32], + fc_policy_layers: List[int] = [32], + reward_support_size: int = 601, + value_support_size: int = 601, + proj_hid: int = 1024, + proj_out: int = 1024, + pred_hid: int = 512, + pred_out: int = 1024, + self_supervised_learning_loss: bool = False, + categorical_distribution: bool = True, + activation: Optional[nn.Module] = None, + last_linear_layer_init_zero: bool = True, + state_norm: bool = False, + downsample: bool = False, + norm_type: Optional[str] = 'BN', + discrete_action_encoding_type: str = 'one_hot', + analysis_sim_norm: bool = False, + task_num: int = 1, + *args, + **kwargs + ) -> None: + """ + Overview: + Constructor for the MuZeroMTModel. + Arguments: + - observation_shape (:obj:`SequenceType`): The shape of the input observation, e.g., (12, 96, 96). + - action_space_size (:obj:`int`): The size of the action space, applicable for discrete action spaces. + - num_res_blocks (:obj:`int`): The number of residual blocks in the representation, dynamics, and prediction networks. + - num_channels (:obj:`int`): The number of channels in the latent state. + - reward_head_channels (:obj:`int`): The number of channels in the reward head. + - value_head_channels (:obj:`int`): The number of channels in the value head. + - policy_head_channels (:obj:`int`): The number of channels in the policy head. + - fc_reward_layers (:obj:`List[int]`): The hidden layer sizes of the reward MLP. + - fc_value_layers (:obj:`List[int]`): The hidden layer sizes of the value MLP. + - fc_policy_layers (:obj:`List[int]`): The hidden layer sizes of the policy MLP. + - reward_support_size (:obj:`int`): The support size for categorical reward distribution. 
+ - value_support_size (:obj:`int`): The support size for categorical value distribution. + - proj_hid (:obj:`int`): The hidden size of the projection network for SSL. + - proj_out (:obj:`int`): The output size of the projection network for SSL. + - pred_hid (:obj:`int`): The hidden size of the prediction head for SSL. + - pred_out (:obj:`int`): The output size of the prediction head for SSL. + - self_supervised_learning_loss (:obj:`bool`): Whether to use self-supervised learning loss. + - categorical_distribution (:obj:`bool`): Whether to use categorical distribution for value and reward. + - activation (:obj:`Optional[nn.Module]`): The activation function to use. Defaults to nn.ReLU(inplace=True). + - last_linear_layer_init_zero (:obj:`bool`): Whether to initialize the last linear layer to zero. + - state_norm (:obj:`bool`): Whether to apply re-normalization to the latent state. + - downsample (:obj:`bool`): Whether to downsample the observation image. + - norm_type (:obj:`Optional[str]`): The type of normalization to use, either 'BN' (BatchNorm) or 'LN' (LayerNorm). + - discrete_action_encoding_type (:obj:`str`): The encoding type for discrete actions, 'one_hot' or 'not_one_hot'. + - analysis_sim_norm (:obj:`bool`): A flag for analysis, enables hooks for SimNorm analysis. + - task_num (:obj:`int`): The total number of tasks for the multi-task setup. + """ + super(MuZeroMTModel, self).__init__() + if activation is None: + activation = nn.ReLU(inplace=True) + + # --- Store configuration --- + self.action_space_size = action_space_size + self.categorical_distribution = categorical_distribution + self.self_supervised_learning_loss = self_supervised_learning_loss + self.state_norm = state_norm + self.downsample = downsample + self.task_num = task_num + self.discrete_action_encoding_type = discrete_action_encoding_type + + if self.categorical_distribution: + self.reward_support_size = reward_support_size + self.value_support_size = value_support_size + else: + self.reward_support_size = 1 + self.value_support_size = 1 + + # --- Prepare observation shape and action encoding dimension --- + if isinstance(observation_shape, int) or len(observation_shape) == 1: + # For 1D vector observations (e.g., classic control), wrap them into a 2D image-like format [C, W, H] + # to be compatible with the convolutional networks. + observation_shape = (1, observation_shape[0], 1) if isinstance(observation_shape, tuple) else (1, observation_shape, 1) + + if self.discrete_action_encoding_type == 'one_hot': + self.action_encoding_dim = self.action_space_size + elif self.discrete_action_encoding_type == 'not_one_hot': + self.action_encoding_dim = 1 + else: + raise ValueError(f"Unsupported discrete_action_encoding_type: {self.discrete_action_encoding_type}") + + latent_size = self._get_latent_size(observation_shape, self.downsample) + + # --- Initialize Network Components --- + + # 1. Shared Representation Network + self.representation_network = RepresentationNetwork( + observation_shape=observation_shape, + num_res_blocks=num_res_blocks, + num_channels=num_channels, + downsample=self.downsample, + activation=activation, + norm_type=norm_type + ) + + # 2. 
Shared Dynamics Network + self.dynamics_network = DynamicsNetwork( + observation_shape=observation_shape, + action_encoding_dim=self.action_encoding_dim, + num_res_blocks=num_res_blocks, + num_channels=num_channels + self.action_encoding_dim, + reward_head_channels=reward_head_channels, + fc_reward_layers=fc_reward_layers, + output_support_size=self.reward_support_size, + flatten_output_size_for_reward_head=reward_head_channels * latent_size, + downsample=self.downsample, + last_linear_layer_init_zero=last_linear_layer_init_zero, + activation=activation, + norm_type=norm_type + ) + + # 3. Task-Specific Prediction Networks + self.prediction_networks = nn.ModuleList([ + PredictionNetwork( + observation_shape=observation_shape, + action_space_size=self.action_space_size, + num_res_blocks=num_res_blocks, + num_channels=num_channels, + value_head_channels=value_head_channels, + policy_head_channels=policy_head_channels, + fc_value_layers=fc_value_layers, + fc_policy_layers=fc_policy_layers, + output_support_size=self.value_support_size, + flatten_output_size_for_value_head=value_head_channels * latent_size, + flatten_output_size_for_policy_head=policy_head_channels * latent_size, + downsample=self.downsample, + last_linear_layer_init_zero=last_linear_layer_init_zero, + activation=activation, + norm_type=norm_type + ) for _ in range(self.task_num) + ]) + + # 4. Optional Self-Supervised Learning (SSL) Components + if self.self_supervised_learning_loss: + self.projection_network = nn.Sequential( + nn.Linear(num_channels * latent_size, proj_hid), + nn.BatchNorm1d(proj_hid), + activation, + nn.Linear(proj_hid, proj_hid), + nn.BatchNorm1d(proj_hid), + activation, + nn.Linear(proj_hid, proj_out), + nn.BatchNorm1d(proj_out) + ) + self.prediction_head = nn.Sequential( + nn.Linear(proj_out, pred_hid), + nn.BatchNorm1d(pred_hid), + activation, + nn.Linear(pred_hid, pred_out), + ) + + # 5. Optional Hook for Analysis + if analysis_sim_norm: + self.encoder_hook = FeatureAndGradientHook() + self.encoder_hook.setup_hooks(self.representation_network) + + @staticmethod + def _get_latent_size(observation_shape: SequenceType, downsample: bool) -> int: + """ + Overview: + Helper function to calculate the flattened size of the latent space based on observation shape and downsampling. + Arguments: + - observation_shape (:obj:`SequenceType`): The shape of the input observation. + - downsample (:obj:`bool`): Whether downsampling is enabled. + Returns: + - int: The flattened size (height * width) of the latent space. + """ + if downsample: + # With downsampling, the spatial dimensions are reduced by a factor of 16 (2^4). + return math.ceil(observation_shape[-2] / 16) * math.ceil(observation_shape[-1] / 16) + else: + return observation_shape[-2] * observation_shape[-1] + + def initial_inference(self, obs: torch.Tensor, task_id: int = 0) -> MZNetworkOutput: + """ + Overview: + Performs the initial inference from a raw observation. It encodes the observation into a latent state + and then uses the task-specific prediction network to compute the policy and value. + Arguments: + - obs (:obj:`torch.Tensor`): The raw observation tensor. + - task_id (:obj:`int`): The identifier for the current task, used to select the correct prediction network. + Returns: + - MZNetworkOutput: A dataclass containing the predicted value, reward (initially zero), policy logits, and latent state. + Shapes: + - obs (:obj:`torch.Tensor`): :math:`(B, C, H, W)`, where B is batch size. + - task_id (:obj:`int`): Scalar. 
+ - Return.value: :math:`(B, value_support_size)`. + - Return.reward: :math:`(B, reward_support_size)`. + - Return.policy_logits: :math:`(B, action_space_size)`. + - Return.latent_state: :math:`(B, num_channels, H', W')`. + """ + batch_size = obs.size(0) + latent_state = self.representation_network(obs) + if self.state_norm: + latent_state = renormalize(latent_state) + + # Select the prediction network based on the task ID. + assert 0 <= task_id < self.task_num, f"Task ID {task_id} is out of range [0, {self.task_num-1}]" + prediction_net = self.prediction_networks[task_id] + policy_logits, value = prediction_net(latent_state) + + return MZNetworkOutput( + value=value, + reward=[0. for _ in range(batch_size)], # Initial reward is always zero. + policy_logits=policy_logits, + latent_state=latent_state, + ) + + def recurrent_inference(self, latent_state: torch.Tensor, action: torch.Tensor, task_id: int = 0) -> MZNetworkOutput: + """ + Overview: + Performs recurrent inference from a latent state and an action. It uses the dynamics network to predict + the next latent state and reward, and then uses the task-specific prediction network to compute the + policy and value for the next state. + Arguments: + - latent_state (:obj:`torch.Tensor`): The current latent state. + - action (:obj:`torch.Tensor`): The action taken in the current state. + - task_id (:obj:`int`): The identifier for the current task. + Returns: + - MZNetworkOutput: A dataclass containing the predicted value, reward, policy logits, and the next latent state. + Shapes: + - latent_state (:obj:`torch.Tensor`): :math:`(B, num_channels, H', W')`. + - action (:obj:`torch.Tensor`): :math:`(B, )`. + - task_id (:obj:`int`): Scalar. + - Return.value: :math:`(B, value_support_size)`. + - Return.reward: :math:`(B, reward_support_size)`. + - Return.policy_logits: :math:`(B, action_space_size)`. + - Return.latent_state: :math:`(B, num_channels, H', W')`. + """ + next_latent_state, reward = self._dynamics(latent_state, action) + + if self.state_norm: + next_latent_state = renormalize(next_latent_state) + + # Select the prediction network based on the task ID. + assert 0 <= task_id < self.task_num, f"Task ID {task_id} is out of range [0, {self.task_num-1}]" + prediction_net = self.prediction_networks[task_id] + policy_logits, value = prediction_net(next_latent_state) + + return MZNetworkOutput(value, reward, policy_logits, next_latent_state) + + def _dynamics(self, latent_state: torch.Tensor, action: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Overview: + Applies the dynamics function by concatenating the latent state with the encoded action and passing it + through the dynamics network to predict the next latent state and reward. + Arguments: + - latent_state (:obj:`torch.Tensor`): The encoding latent state of the input state. + - action (:obj:`torch.Tensor`): The action to rollout. + Returns: + - Tuple[torch.Tensor, torch.Tensor]: A tuple containing the predicted next latent state and reward. + Shapes: + - latent_state (:obj:`torch.Tensor`): :math:`(B, C, H', W')`. + - action (:obj:`torch.Tensor`): :math:`(B, )`. + - next_latent_state (:obj:`torch.Tensor`): :math:`(B, C, H', W')`. + - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`. + """ + # Encode the action and expand it to match the spatial dimensions of the latent state. + if self.discrete_action_encoding_type == 'one_hot': + # Convert action indices to one-hot vectors. 
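+ # Illustrative example (the concrete values below are assumptions, not from the original code):
+ # with action_space_size=3 and action = tensor([2, 0]),
+ # F.one_hot(action.long(), num_classes=3).float() -> [[0., 0., 1.], [1., 0., 0.]]  # shape (B, A)
+ # After the unsqueeze/expand below, each one-hot entry becomes a constant (H', W') plane,
+ # giving an action encoding of shape (B, A, H', W').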
+ action_one_hot = F.one_hot(action.long(), num_classes=self.action_space_size).float() + # Reshape for broadcasting: (B, A) -> (B, A, 1, 1) + action_encoding_tmp = action_one_hot.unsqueeze(-1).unsqueeze(-1) + # Expand to (B, A, H', W') + action_encoding = action_encoding_tmp.expand( + latent_state.shape[0], self.action_space_size, latent_state.shape[2], latent_state.shape[3] + ) + elif self.discrete_action_encoding_type == 'not_one_hot': + # Encode action as a single channel, normalized by action space size. + # Reshape for broadcasting: (B,) -> (B, 1, 1, 1) + action_encoding_tmp = action.float().view(-1, 1, 1, 1) + # Normalize and expand to (B, 1, H', W') + action_encoding = action_encoding_tmp / self.action_space_size + action_encoding = action_encoding.expand( + latent_state.shape[0], 1, latent_state.shape[2], latent_state.shape[3] + ) + + # Concatenate latent state and action encoding along the channel dimension. + state_action_encoding = torch.cat((latent_state, action_encoding), dim=1) + + # Predict next state and reward. + next_latent_state, reward = self.dynamics_network(state_action_encoding) + + if self.state_norm: + next_latent_state = renormalize(next_latent_state) + + return next_latent_state, reward + + def project(self, latent_state: torch.Tensor, with_grad: bool = True) -> torch.Tensor: + """ + Overview: + Projects the latent state into a different space for self-supervised learning (e.g., BYOL, SimSiam). + This involves a projection network and an optional prediction head. + Arguments: + - latent_state (:obj:`torch.Tensor`): The latent state to project. + - with_grad (:obj:`bool`): If False, detach the output of the projection network to stop gradients. + This is typically used for the target network in SSL. + Returns: + - torch.Tensor: The projected (and possibly predicted) representation. + """ + if not self.self_supervised_learning_loss: + raise NotImplementedError("The 'project' method requires 'self_supervised_learning_loss' to be enabled.") + + # Flatten the latent state from (B, C, H, W) to (B, C*H*W). + latent_state = latent_state.reshape(latent_state.shape[0], -1) + + proj = self.projection_network(latent_state) + + if with_grad: + # Return the output of the prediction head, with gradients flowing. + return self.prediction_head(proj) + else: + # Return the output of the projection network, detached from the graph. + return proj.detach() + + def get_params_mean(self) -> float: + """ + Overview: + Computes the mean of all model parameters. Useful for debugging and monitoring training. + Returns: + - float: The mean value of all parameters. + """ + return get_params_mean(self) + + +class DynamicsNetwork(nn.Module): + """ + Overview: + The dynamics network of the MuZero model. It takes a state-action encoding as input and predicts + the next latent state and the reward for the transition. This network is shared across all tasks + in the multi-task setup. + """ + + def __init__( + self, + observation_shape: SequenceType, + action_encoding_dim: int = 2, + num_res_blocks: int = 1, + num_channels: int = 64, + reward_head_channels: int = 64, + fc_reward_layers: List[int] = [32], + output_support_size: int = 601, + flatten_output_size_for_reward_head: int = 64, + downsample: bool = False, + last_linear_layer_init_zero: bool = True, + activation: Optional[nn.Module] = None, + norm_type: Optional[str] = 'BN', + ) -> None: + """ + Overview: + Constructor for the DynamicsNetwork. 
+ Arguments: + - observation_shape (:obj:`SequenceType`): The shape of the original input observation. + - action_encoding_dim (:obj:`int`): The dimension of the encoded action. + - num_res_blocks (:obj:`int`): The number of residual blocks. + - num_channels (:obj:`int`): The number of channels in the input (latent_state + action_encoding). + - reward_head_channels (:obj:`int`): The number of channels for the reward head's convolutional layer. + - fc_reward_layers (:obj:`List[int]`): The hidden layer sizes of the reward MLP. + - output_support_size (:obj:`int`): The support size for the categorical reward distribution. + - flatten_output_size_for_reward_head (:obj:`int`): The flattened input size for the reward MLP. + - downsample (:obj:`bool`): Whether downsampling is used, affecting LayerNorm shapes. + - last_linear_layer_init_zero (:obj:`bool`): Whether to initialize the last linear layer to zero. + - activation (:obj:`Optional[nn.Module]`): The activation function. Defaults to nn.ReLU(inplace=True). + - norm_type (:obj:`Optional[str]`): The type of normalization, 'BN' or 'LN'. + """ + super().__init__() + if activation is None: + activation = nn.ReLU(inplace=True) + + assert norm_type in ['BN', 'LN'], f"norm_type must be 'BN' or 'LN', but got {norm_type}" + # The input channels to the first conv layer is num_channels, which includes the original latent channels + # and the action encoding channels. The output should be the number of channels for the latent state. + latent_channels = num_channels - action_encoding_dim + assert latent_channels > 0, f"num_channels ({num_channels}) must be greater than action_encoding_dim ({action_encoding_dim})" + + self.action_encoding_dim = action_encoding_dim + self.activation = activation + + # Convolutional layer to process the combined state-action encoding. + self.conv = nn.Conv2d(num_channels, latent_channels, kernel_size=3, stride=1, padding=1, bias=False) + + # Normalization layer for the main path. + if norm_type == 'BN': + self.norm_common = nn.BatchNorm2d(latent_channels) + elif norm_type == 'LN': + if downsample: + ln_shape = [latent_channels, math.ceil(observation_shape[-2] / 16), math.ceil(observation_shape[-1] / 16)] + else: + ln_shape = [latent_channels, observation_shape[-2], observation_shape[-1]] + self.norm_common = nn.LayerNorm(ln_shape) + + # A series of residual blocks to deepen the network. + self.resblocks = nn.ModuleList( + [ResBlock(in_channels=latent_channels, activation=activation, norm_type='BN', res_type='basic', bias=False) + for _ in range(num_res_blocks)] + ) + + # --- Reward Head --- + # 1x1 convolution to create an input for the reward MLP. + self.conv1x1_reward = nn.Conv2d(latent_channels, reward_head_channels, 1) + + # Normalization for the reward head. + if norm_type == 'BN': + self.norm_reward = nn.BatchNorm2d(reward_head_channels) + elif norm_type == 'LN': + if downsample: + ln_shape_reward = [reward_head_channels, math.ceil(observation_shape[-2] / 16), math.ceil(observation_shape[-1] / 16)] + else: + ln_shape_reward = [reward_head_channels, observation_shape[-2], observation_shape[-1]] + self.norm_reward = nn.LayerNorm(ln_shape_reward) + + # MLP to predict the reward value from the processed features. 
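+ # Its input width is flatten_output_size_for_reward_head (reward_head_channels * H' * W' of the
+ # latent state, as passed in from MuZeroMTModel) and its output width is output_support_size,
+ # i.e. the size of the categorical reward support.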
+ self.fc_reward_head = MLP( + in_channels=flatten_output_size_for_reward_head, + hidden_channels=fc_reward_layers[0], + out_channels=output_support_size, + layer_num=len(fc_reward_layers) + 1, + activation=activation, + norm_type=norm_type, + output_activation=False, + output_norm=False, + last_linear_layer_init_zero=last_linear_layer_init_zero + ) + + def forward(self, state_action_encoding: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Overview: + Forward pass for the dynamics network. + Arguments: + - state_action_encoding (:obj:`torch.Tensor`): The concatenated latent state and action encoding. + Returns: + - Tuple[torch.Tensor, torch.Tensor]: A tuple containing the next latent state and the predicted reward. + Shapes: + - state_action_encoding (:obj:`torch.Tensor`): :math:`(B, C_latent + C_action, H', W')`. + - next_latent_state (:obj:`torch.Tensor`): :math:`(B, C_latent, H', W')`. + - reward (:obj:`torch.Tensor`): :math:`(B, output_support_size)`. + """ + # The original latent state is part of the input, used for the residual connection. + state_encoding = state_action_encoding[:, : -self.action_encoding_dim, :, :] + + # Main path for predicting the next latent state. + x = self.conv(state_action_encoding) + x = self.norm_common(x) + + # Add residual connection from the original latent state. + x += state_encoding + x = self.activation(x) + + for block in self.resblocks: + x = block(x) + next_latent_state = x + + # --- Reward Prediction Path --- + # Process the next latent state to predict the reward. + reward_x = self.conv1x1_reward(next_latent_state) + reward_x = self.norm_reward(reward_x) + reward_x = self.activation(reward_x) + # Flatten the features before passing to the MLP. + reward_x = reward_x.view(reward_x.shape[0], -1) + reward = self.fc_reward_head(reward_x) + + return next_latent_state, reward \ No newline at end of file diff --git a/lzero/model/sampled_unizero_model.py b/lzero/model/sampled_unizero_model.py index 20787838c..e178c14d9 100644 --- a/lzero/model/sampled_unizero_model.py +++ b/lzero/model/sampled_unizero_model.py @@ -2,7 +2,7 @@ import torch import torch.nn as nn -from ding.utils import MODEL_REGISTRY, SequenceType +from ding.utils import MODEL_REGISTRY, SequenceType, get_rank from easydict import EasyDict from .common import MZNetworkOutput, RepresentationNetworkUniZero, RepresentationNetworkMLP, LatentDecoder, \ @@ -73,16 +73,17 @@ def __init__( self.representation_network = RepresentationNetworkMLP( observation_shape, hidden_channels=world_model_cfg.embed_dim, - layer_num=2, + num_layers=2, activation=self.activation, norm_type=norm_type, group_size=world_model_cfg.group_size, final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder ) # TODO: only for MemoryEnv now - self.decoder_network = VectorDecoderForMemoryEnv(embedding_dim=world_model_cfg.embed_dim, output_shape=25, norm_type=norm_type) + # self.decoder_network = VectorDecoderForMemoryEnv(embedding_dim=world_model_cfg.embed_dim, output_shape=25, norm_type=norm_type) + self.decoder_network = None self.tokenizer = Tokenizer(encoder=self.representation_network, - decoder_network=self.decoder_network, with_lpips=False) + decoder=self.decoder_network, with_lpips=False, obs_type=world_model_cfg.obs_type) self.world_model = WorldModel(config=world_model_cfg, tokenizer=self.tokenizer) print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model') print('==' * 20) @@ -110,7 +111,7 @@ def __init__( 
self.encoder_hook.setup_hooks(self.representation_network) self.tokenizer = Tokenizer(encoder=self.representation_network, - decoder_network=self.decoder_network, with_lpips=True,) + decoder=self.decoder_network, with_lpips=True, obs_type=world_model_cfg.obs_type) self.world_model = WorldModel(config=world_model_cfg, tokenizer=self.tokenizer) print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model') print(f'{sum(p.numel() for p in self.world_model.parameters()) - sum(p.numel() for p in self.tokenizer.decoder_network.parameters()) - sum(p.numel() for p in self.tokenizer.lpips.parameters())} parameters in agent.world_model - (decoder_network and lpips)') @@ -145,8 +146,8 @@ def __init__( self.encoder_hook = FeatureAndGradientHook() self.encoder_hook.setup_hooks(self.representation_network) - self.tokenizer = Tokenizer(with_lpips=True, encoder=self.representation_network, - decoder_network=self.decoder_network) + self.tokenizer = Tokenizer(encoder=self.representation_network, + decoder=self.decoder_network, with_lpips=True, obs_type=world_model_cfg.obs_type) self.world_model = WorldModel(config=world_model_cfg, tokenizer=self.tokenizer) print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model') print(f'{sum(p.numel() for p in self.world_model.parameters()) - sum(p.numel() for p in self.tokenizer.decoder_network.parameters()) - sum(p.numel() for p in self.tokenizer.lpips.parameters())} parameters in agent.world_model - (decoder_network and lpips)') @@ -157,6 +158,153 @@ def __init__( print(f'{sum(p.numel() for p in self.tokenizer.decoder_network.parameters())} parameters in agent.tokenizer.decoder_network') print('==' * 20) + # --- Log parameter counts for analysis --- + self._log_model_parameters(world_model_cfg.obs_type) + + def _log_model_parameters(self, obs_type: str) -> None: + """ + Overview: + Logs detailed parameter counts for all model components with a comprehensive breakdown. + Includes encoder, transformer, prediction heads, and other components. + Arguments: + - obs_type (:obj:`str`): The type of observation ('vector', 'image', or 'image_memory'). + """ + # Only print from rank 0 to avoid duplicate logs in DDP + if get_rank() != 0: + return + + print('=' * 80) + print('MODEL PARAMETER STATISTICS'.center(80)) + print('=' * 80) + + # --- Total Model Parameters --- + total_params = sum(p.numel() for p in self.parameters()) + total_trainable = sum(p.numel() for p in self.parameters() if p.requires_grad) + print(f'\n{"TOTAL MODEL":<40} {total_params:>15,} parameters') + print(f'{" └─ Trainable":<40} {total_trainable:>15,} parameters') + print(f'{" └─ Frozen":<40} {total_params - total_trainable:>15,} parameters') + + # --- World Model Components --- + print(f'\n{"-" * 80}') + print(f'{"WORLD MODEL BREAKDOWN":<40}') + print(f'{"-" * 80}') + + wm_params = sum(p.numel() for p in self.world_model.parameters()) + wm_trainable = sum(p.numel() for p in self.world_model.parameters() if p.requires_grad) + print(f'{"World Model Total":<40} {wm_params:>15,} parameters') + print(f'{" └─ Trainable":<40} {wm_trainable:>15,} parameters ({100*wm_trainable/wm_params:.1f}%)') + + # --- Encoder --- + encoder_params = sum(p.numel() for p in self.tokenizer.encoder.parameters()) + encoder_trainable = sum(p.numel() for p in self.tokenizer.encoder.parameters() if p.requires_grad) + print(f'\n{"1. 
ENCODER (Tokenizer)":<40} {encoder_params:>15,} parameters') + print(f'{" └─ Trainable":<40} {encoder_trainable:>15,} parameters ({100*encoder_trainable/encoder_params:.1f}%)') + + # --- Transformer Backbone --- + transformer_params = sum(p.numel() for p in self.world_model.transformer.parameters()) + transformer_trainable = sum(p.numel() for p in self.world_model.transformer.parameters() if p.requires_grad) + print(f'\n{"2. TRANSFORMER BACKBONE":<40} {transformer_params:>15,} parameters') + print(f'{" └─ Trainable":<40} {transformer_trainable:>15,} parameters ({100*transformer_trainable/transformer_params:.1f}%)') + + # --- Prediction Heads (Detailed Breakdown) --- + print(f'\n{"3. PREDICTION HEADS":<40}') + + # Access head_dict from world_model + if hasattr(self.world_model, 'head_dict'): + head_dict = self.world_model.head_dict + + # Calculate total heads parameters + total_heads_params = sum(p.numel() for module in head_dict.values() for p in module.parameters()) + total_heads_trainable = sum(p.numel() for module in head_dict.values() for p in module.parameters() if p.requires_grad) + print(f'{" Total (All Heads)":<40} {total_heads_params:>15,} parameters') + print(f'{" └─ Trainable":<40} {total_heads_trainable:>15,} parameters ({100*total_heads_trainable/total_heads_params:.1f}%)') + + # Breakdown by head type + head_names_map = { + 'head_policy': 'Policy Head', + 'head_value': 'Value Head', + 'head_rewards': 'Reward Head', + 'head_observations': 'Next Latent (Obs) Head' + } + + print(f'\n{" Breakdown by Head Type:":<40}') + for head_key, head_name in head_names_map.items(): + if head_key in head_dict: + head_module = head_dict[head_key] + head_params = sum(p.numel() for p in head_module.parameters()) + head_trainable = sum(p.numel() for p in head_module.parameters() if p.requires_grad) + + # Count number of task-specific heads (for ModuleList) + if isinstance(head_module, nn.ModuleList): + num_heads = len(head_module) + params_per_head = head_params // num_heads if num_heads > 0 else 0 + print(f'{" ├─ " + head_name:<38} {head_params:>15,} parameters') + print(f'{" └─ " + f"{num_heads} task-specific heads":<38} {params_per_head:>15,} params/head') + else: + print(f'{" ├─ " + head_name:<38} {head_params:>15,} parameters') + print(f'{" └─ Shared across tasks":<38}') + + # --- Positional & Task Embeddings --- + print(f'\n{"4. 
EMBEDDINGS":<40}') + + if hasattr(self.world_model, 'pos_emb'): + pos_emb_params = sum(p.numel() for p in self.world_model.pos_emb.parameters()) + pos_emb_trainable = sum(p.numel() for p in self.world_model.pos_emb.parameters() if p.requires_grad) + print(f'{" ├─ Positional Embedding":<40} {pos_emb_params:>15,} parameters') + if pos_emb_trainable == 0: + print(f'{" └─ (Frozen)":<40}') + + if hasattr(self.world_model, 'task_emb') and self.world_model.task_emb is not None: + task_emb_params = sum(p.numel() for p in self.world_model.task_emb.parameters()) + task_emb_trainable = sum(p.numel() for p in self.world_model.task_emb.parameters() if p.requires_grad) + print(f'{" ├─ Task Embedding":<40} {task_emb_params:>15,} parameters') + print(f'{" └─ Trainable":<40} {task_emb_trainable:>15,} parameters') + + if hasattr(self.world_model, 'act_embedding_table'): + act_emb_params = sum(p.numel() for p in self.world_model.act_embedding_table.parameters()) + act_emb_trainable = sum(p.numel() for p in self.world_model.act_embedding_table.parameters() if p.requires_grad) + print(f'{" └─ Action Embedding":<40} {act_emb_params:>15,} parameters') + print(f'{" └─ Trainable":<40} {act_emb_trainable:>15,} parameters') + + # --- Decoder (if applicable) --- + if self.tokenizer.decoder_network is not None: + print(f'\n{"5. DECODER":<40}') + decoder_params = sum(p.numel() for p in self.tokenizer.decoder_network.parameters()) + decoder_trainable = sum(p.numel() for p in self.tokenizer.decoder_network.parameters() if p.requires_grad) + print(f'{" Decoder Network":<40} {decoder_params:>15,} parameters') + print(f'{" └─ Trainable":<40} {decoder_trainable:>15,} parameters') + + if hasattr(self.tokenizer, 'lpips') and self.tokenizer.lpips is not None: + lpips_params = sum(p.numel() for p in self.tokenizer.lpips.parameters()) + print(f'{" LPIPS Loss Network":<40} {lpips_params:>15,} parameters') + + # Calculate world model params excluding decoder and LPIPS + params_without_decoder = wm_params - decoder_params - lpips_params + print(f'\n{" World Model (exc. 
Decoder & LPIPS)":<40} {params_without_decoder:>15,} parameters') + + # --- Summary Table --- + print(f'\n{"=" * 80}') + print(f'{"SUMMARY":<40}') + print(f'{"=" * 80}') + print(f'{"Component":<30} {"Total Params":>15} {"Trainable":>15} {"% of Total":>15}') + print(f'{"-" * 80}') + + components = [ + ("Encoder", encoder_params, encoder_trainable), + ("Transformer", transformer_params, transformer_trainable), + ] + + if hasattr(self.world_model, 'head_dict'): + components.append(("Prediction Heads", total_heads_params, total_heads_trainable)) + + for name, total, trainable in components: + pct = 100 * total / total_params if total_params > 0 else 0 + print(f'{name:<30} {total:>15,} {trainable:>15,} {pct:>14.1f}%') + + print(f'{"=" * 80}') + print(f'{"TOTAL":<30} {total_params:>15,} {total_trainable:>15,} {"100.0%":>15}') + print(f'{"=" * 80}\n') + def initial_inference(self, obs_batch: torch.Tensor, action_batch: Optional[torch.Tensor] = None, current_obs_batch: Optional[torch.Tensor] = None, start_pos: int = 0) -> MZNetworkOutput: """ diff --git a/lzero/model/sampled_unizero_model_multitask.py b/lzero/model/sampled_unizero_model_multitask.py new file mode 100644 index 000000000..49c7077d3 --- /dev/null +++ b/lzero/model/sampled_unizero_model_multitask.py @@ -0,0 +1,436 @@ +from typing import Optional, List, Sequence + +import torch +import torch.nn as nn +from ding.torch_utils import MLP +from ding.utils import MODEL_REGISTRY, get_rank +from easydict import EasyDict + +from .common import MZNetworkOutput, RepresentationNetworkUniZero, LatentDecoder, \ + FeatureAndGradientHook, SimNorm +from .unizero_world_models.tokenizer import Tokenizer +from .unizero_world_models.world_model_multitask import WorldModelMT + +class RepresentationNetworkMLPMT(nn.Module): + """ + Overview: + A multi-task representation network that encodes vector observations into a latent state + using a Multi-Layer Perceptron (MLP). It supports task-specific encoders and an optional + shared projection layer to map representations into a common embedding space. + """ + + def __init__( + self, + observation_shape_list: List[int], + hidden_channels: int = 64, + layer_num: int = 2, + activation: nn.Module = nn.GELU(approximate='tanh'), + norm_type: Optional[str] = 'BN', + embedding_dim: int = 256, + group_size: int = 8, + use_shared_projection: bool = False, + shared_projection_dim: Optional[int] = None, + final_norm_option_in_encoder: str = 'LayerNorm', # TODO: Further investigate norm options + ) -> None: + """ + Arguments: + - observation_shape_list (:obj:`List[int]`): A list of observation feature dimensions, one for each task. + - hidden_channels (:obj:`int`): The number of hidden channels in the task-specific MLPs. + - layer_num (:obj:`int`): The number of layers in each MLP. + - activation (:obj:`nn.Module`): The activation function to use in the MLPs. Defaults to nn.GELU(approximate='tanh'). + - norm_type (:obj:`str`): The type of normalization to use within the MLPs. Defaults to 'BN'. + - embedding_dim (:obj:`int`): The dimension of the final output embedding. + - group_size (:obj:`int`): The group size for SimNorm if it is used. + - use_shared_projection (:obj:`bool`): Whether to use a shared projection layer after task-specific encoding. Defaults to False. + - shared_projection_dim (:obj:`Optional[int]`): The dimension of the shared projection layer. If None, it defaults to `hidden_channels`. + - final_norm_option_in_encoder (:obj:`str`): The final normalization layer type ('LayerNorm' or 'SimNorm'). 
Defaults to 'LayerNorm'. + """ + super().__init__() + self.env_num = len(observation_shape_list) + self.use_shared_projection = use_shared_projection + self.hidden_channels = hidden_channels + self.shared_projection_dim = shared_projection_dim or hidden_channels + self.embedding_dim = embedding_dim + self.final_norm_option_in_encoder = final_norm_option_in_encoder + + # Task-specific representation networks + self.fc_representation = nn.ModuleList([ + MLP( + in_channels=obs_shape, + hidden_channels=hidden_channels, + out_channels=hidden_channels, + layer_num=layer_num, + activation=activation, + norm_type=norm_type, + # No activation or norm in the last layer is important for convergence. + output_activation=False, + output_norm=False, + # Initializing the last linear layer to zero can be beneficial for convergence speed. + last_linear_layer_init_zero=True, + ) + for obs_shape in observation_shape_list + ]) + + # Final normalization layer before projection + if self.final_norm_option_in_encoder == 'LayerNorm': + self.final_norm = nn.LayerNorm(self.embedding_dim, eps=1e-5) + elif self.final_norm_option_in_encoder == 'SimNorm': + self.final_norm = SimNorm(simnorm_dim=group_size) + else: + raise ValueError(f"Unsupported final_norm_option_in_encoder: {self.final_norm_option_in_encoder}") + + # Optional shared projection layer + if self.use_shared_projection: + self.shared_projection = nn.Linear(hidden_channels, self.shared_projection_dim) + # Using SimNorm for the shared space projection + self.projection_norm = SimNorm(simnorm_dim=group_size) + + def forward(self, x: torch.Tensor, task_id: int) -> torch.Tensor: + """ + Shapes: + - x (:obj:`torch.Tensor`): The input tensor of shape :math:`(B, N)`, where B is the batch size and N is the length of the vector observation. + - task_id (:obj:`int`): The identifier for the current task, used to select the appropriate encoder. + - output (:obj:`torch.Tensor`): The output latent state. Its shape is :math:`(B, embedding_dim)` if shared projection is not used, otherwise :math:`(B, shared_projection_dim)`. + """ + # Encode observation using the task-specific MLP + x = self.fc_representation[task_id](x) + # Apply final normalization + x = self.final_norm(x) + + # Apply the shared projection layer if enabled + if self.use_shared_projection: + x = self.shared_projection(x) + x = self.projection_norm(x) + return x + + +@MODEL_REGISTRY.register('SampledUniZeroMTModel') +class SampledUniZeroMTModel(nn.Module): + """ + Overview: + The main model for Sampled UniZero in a multi-task setting. It integrates a representation + network, a tokenizer, and a world model to perform initial and recurrent inference, + which are essential for MuZero-style planning algorithms. The model is designed to handle + both vector and image-based observations across multiple tasks. + """ + + def __init__( + self, + observation_shape_list: List[Sequence], + action_space_size_list: List[int], + num_res_blocks: int = 1, + num_channels: int = 64, + activation: nn.Module = nn.GELU(approximate='tanh'), + downsample: bool = True, + norm_type: Optional[str] = 'LN', + world_model_cfg: EasyDict = None, + *args, + **kwargs + ): + """ + Arguments: + - observation_shape_list (:obj:`List[Sequence]`): A list of observation space shapes for each task (e.g., `[C, W, H]` for images or `[D]` for vectors). + - action_space_size_list (:obj:`List[int]`): A list of action space sizes for each task. + - num_res_blocks (:obj:`int`): The number of residual blocks in the image representation network. 
+ - num_channels (:obj:`int`): The number of channels in the hidden states of the image representation network. + - activation (:obj:`nn.Module`): The activation function used throughout the network. + - downsample (:obj:`bool`): Whether to downsample observations in the image representation network. + - norm_type (:obj:`str`): The type of normalization to use in networks. Defaults to 'LN'. + - world_model_cfg (:obj:`EasyDict`): A single configuration object for the world model, shared across all tasks. + """ + super(SampledUniZeroMTModel, self).__init__() + self.task_num = len(observation_shape_list) + self.activation = activation + self.downsample = downsample + + # Determine the embedding dimension for observations and actions + if world_model_cfg.task_embed_option == "concat_task_embed": + obs_act_embed_dim = world_model_cfg.embed_dim - world_model_cfg.task_embed_dim if hasattr(world_model_cfg, "task_embed_dim") else 96 + else: + obs_act_embed_dim = world_model_cfg.embed_dim + + world_model_cfg.norm_type = norm_type + assert world_model_cfg.max_tokens == 2 * world_model_cfg.max_blocks, \ + 'max_tokens should be 2 * max_blocks, as each timestep consists of an observation and an action token.' + + # Initialize networks based on observation type + if world_model_cfg.obs_type == 'vector': + # A single representation network capable of handling multiple tasks via task_id + self.representation_network = RepresentationNetworkMLPMT( + observation_shape_list=observation_shape_list, + hidden_channels=obs_act_embed_dim, + layer_num=2, + activation=self.activation, + norm_type=norm_type, + embedding_dim=obs_act_embed_dim, + group_size=world_model_cfg.group_size, + use_shared_projection=world_model_cfg.use_shared_projection, + final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder, + ) + # FIXED: Tokenizer parameter name is 'decoder', not 'decoder_network' + self.tokenizer = Tokenizer(encoder=self.representation_network, decoder=None, with_lpips=False) + self.world_model = WorldModelMT(config=world_model_cfg, tokenizer=self.tokenizer) + + elif world_model_cfg.obs_type == 'image': + self.representation_network = nn.ModuleList() + # TODO: Currently uses a single shared encoder for all image-based tasks. + # This can be extended to support multiple independent encoders if needed. + for _ in range(1): + self.representation_network.append(RepresentationNetworkUniZero( + observation_shape_list[0], # Assuming shared encoder uses the shape of the first task + num_res_blocks, + num_channels, + self.downsample, + activation=self.activation, + norm_type=norm_type, + embedding_dim=obs_act_embed_dim, + group_size=world_model_cfg.group_size, + final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder, + )) + # TODO: The world model and tokenizer for the 'image' case should be initialized here. + # self.tokenizer = Tokenizer(...) + # self.world_model = WorldModelMT(...) 
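+ # A minimal sketch of how this branch could be completed, mirroring the 'vector' branch
+ # above; taking the single shared encoder at index 0 is an assumption:
+ # self.tokenizer = Tokenizer(encoder=self.representation_network[0], decoder=None, with_lpips=False)
+ # self.world_model = WorldModelMT(config=world_model_cfg, tokenizer=self.tokenizer)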
+ + # Print model parameter counts for verification + print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model') + print('==' * 20) + print(f'{sum(p.numel() for p in self.world_model.transformer.parameters())} parameters in agent.world_model.transformer') + if hasattr(self.tokenizer, 'encoder') and self.tokenizer.encoder is not None: + print(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters())} parameters in agent.tokenizer.encoder') + print('==' * 20) + + # --- Log parameter counts for analysis --- + self._log_model_parameters(world_model_cfg.obs_type) + + def _log_model_parameters(self, obs_type: str) -> None: + """ + Overview: + Logs detailed parameter counts for all model components with a comprehensive breakdown. + Includes encoder, transformer, prediction heads, and other components. + This version is adapted for multi-task models. + Arguments: + - obs_type (:obj:`str`): The type of observation ('vector' or 'image'). + """ + # Only print from rank 0 to avoid duplicate logs in DDP + if get_rank() != 0: + return + + print('=' * 80) + print('MODEL PARAMETER STATISTICS (Multi-Task)'.center(80)) + print('=' * 80) + + # --- Total Model Parameters --- + total_params = sum(p.numel() for p in self.parameters()) + total_trainable = sum(p.numel() for p in self.parameters() if p.requires_grad) + print(f'\n{"TOTAL MODEL":<40} {total_params:>15,} parameters') + print(f'{" └─ Trainable":<40} {total_trainable:>15,} parameters') + print(f'{" └─ Frozen":<40} {total_params - total_trainable:>15,} parameters') + print(f'{" └─ Number of Tasks":<40} {self.task_num:>15,}') + + # --- World Model Components --- + print(f'\n{"-" * 80}') + print(f'{"WORLD MODEL BREAKDOWN":<40}') + print(f'{"-" * 80}') + + wm_params = sum(p.numel() for p in self.world_model.parameters()) + wm_trainable = sum(p.numel() for p in self.world_model.parameters() if p.requires_grad) + print(f'{"World Model Total":<40} {wm_params:>15,} parameters') + print(f'{" └─ Trainable":<40} {wm_trainable:>15,} parameters ({100*wm_trainable/wm_params:.1f}%)') + + # --- Encoder --- + if hasattr(self.tokenizer, 'encoder') and self.tokenizer.encoder is not None: + encoder_params = sum(p.numel() for p in self.tokenizer.encoder.parameters()) + encoder_trainable = sum(p.numel() for p in self.tokenizer.encoder.parameters() if p.requires_grad) + print(f'\n{"1. 
ENCODER (Tokenizer)":<40} {encoder_params:>15,} parameters') + print(f'{" └─ Trainable":<40} {encoder_trainable:>15,} parameters ({100*encoder_trainable/encoder_params:.1f}%)') + + # For multi-task encoder, show per-task breakdown + if isinstance(self.tokenizer.encoder, nn.ModuleList): + print(f'{" └─ Multi-Task Encoders":<40} {len(self.tokenizer.encoder):>15,} tasks') + for i, enc in enumerate(self.tokenizer.encoder): + task_params = sum(p.numel() for p in enc.parameters()) + print(f'{" ├─ Task " + str(i):<38} {task_params:>15,} parameters') + elif hasattr(self.tokenizer.encoder, 'fc_representation'): + # RepresentationNetworkMLPMT case + print(f'{" └─ Task-Specific Encoders":<40} {len(self.tokenizer.encoder.fc_representation):>15,} tasks') + for i, enc in enumerate(self.tokenizer.encoder.fc_representation): + task_params = sum(p.numel() for p in enc.parameters()) + print(f'{" ├─ Task " + str(i):<38} {task_params:>15,} parameters') + + # Show shared projection if exists + if hasattr(self.tokenizer.encoder, 'use_shared_projection') and self.tokenizer.encoder.use_shared_projection: + shared_proj_params = sum(p.numel() for p in self.tokenizer.encoder.shared_projection.parameters()) + print(f'{" └─ Shared Projection Layer":<40} {shared_proj_params:>15,} parameters') + + # --- Transformer Backbone --- + transformer_params = sum(p.numel() for p in self.world_model.transformer.parameters()) + transformer_trainable = sum(p.numel() for p in self.world_model.transformer.parameters() if p.requires_grad) + print(f'\n{"2. TRANSFORMER BACKBONE":<40} {transformer_params:>15,} parameters') + print(f'{" └─ Trainable":<40} {transformer_trainable:>15,} parameters ({100*transformer_trainable/transformer_params:.1f}%)') + + # --- Prediction Heads (Detailed Breakdown) --- + print(f'\n{"3. 
PREDICTION HEADS":<40}') + + # Access head_dict from world_model + if hasattr(self.world_model, 'head_dict'): + head_dict = self.world_model.head_dict + + # Calculate total heads parameters + total_heads_params = sum(p.numel() for module in head_dict.values() for p in module.parameters()) + total_heads_trainable = sum(p.numel() for module in head_dict.values() for p in module.parameters() if p.requires_grad) + print(f'{" Total (All Heads)":<40} {total_heads_params:>15,} parameters') + print(f'{" └─ Trainable":<40} {total_heads_trainable:>15,} parameters ({100*total_heads_trainable/total_heads_params:.1f}%)') + + # Breakdown by head type + head_names_map = { + 'head_policy_multi_task': 'Policy Head (Multi-Task)', + 'head_value_multi_task': 'Value Head (Multi-Task)', + 'head_rewards_multi_task': 'Reward Head (Multi-Task)', + 'head_observations_multi_task': 'Next Latent Head (Multi-Task)' + } + + print(f'\n{" Breakdown by Head Type:":<40}') + for head_key, head_name in head_names_map.items(): + if head_key in head_dict: + head_module = head_dict[head_key] + head_params = sum(p.numel() for p in head_module.parameters()) + head_trainable = sum(p.numel() for p in head_module.parameters() if p.requires_grad) + + # Count number of task-specific heads (for ModuleList) + if isinstance(head_module, nn.ModuleList): + num_heads = len(head_module) + params_per_head = head_params // num_heads if num_heads > 0 else 0 + print(f'{" ├─ " + head_name:<38} {head_params:>15,} parameters') + print(f'{" └─ " + f"{num_heads} task-specific heads":<38} {params_per_head:>15,} params/head') + else: + print(f'{" ├─ " + head_name:<38} {head_params:>15,} parameters') + print(f'{" └─ Shared across tasks":<38}') + + # --- Positional & Task Embeddings --- + print(f'\n{"4. EMBEDDINGS":<40}') + + if hasattr(self.world_model, 'pos_emb'): + pos_emb_params = sum(p.numel() for p in self.world_model.pos_emb.parameters()) + pos_emb_trainable = sum(p.numel() for p in self.world_model.pos_emb.parameters() if p.requires_grad) + print(f'{" ├─ Positional Embedding":<40} {pos_emb_params:>15,} parameters') + if pos_emb_trainable == 0: + print(f'{" └─ (Frozen)":<40}') + + if hasattr(self.world_model, 'task_emb') and self.world_model.task_emb is not None: + task_emb_params = sum(p.numel() for p in self.world_model.task_emb.parameters()) + task_emb_trainable = sum(p.numel() for p in self.world_model.task_emb.parameters() if p.requires_grad) + print(f'{" ├─ Task Embedding":<40} {task_emb_params:>15,} parameters') + print(f'{" └─ Trainable":<40} {task_emb_trainable:>15,} parameters') + print(f'{" └─ Num Embeddings":<40} {self.task_num:>15,} tasks') + + if hasattr(self.world_model, 'act_embedding_table'): + act_emb_params = sum(p.numel() for p in self.world_model.act_embedding_table.parameters()) + act_emb_trainable = sum(p.numel() for p in self.world_model.act_embedding_table.parameters() if p.requires_grad) + print(f'{" └─ Action Embedding":<40} {act_emb_params:>15,} parameters') + print(f'{" └─ Trainable":<40} {act_emb_trainable:>15,} parameters') + + # --- Decoder (if applicable) --- + if hasattr(self.tokenizer, 'decoder_network') and self.tokenizer.decoder_network is not None: + print(f'\n{"5. 
DECODER":<40}') + decoder_params = sum(p.numel() for p in self.tokenizer.decoder_network.parameters()) + decoder_trainable = sum(p.numel() for p in self.tokenizer.decoder_network.parameters() if p.requires_grad) + print(f'{" Decoder Network":<40} {decoder_params:>15,} parameters') + print(f'{" └─ Trainable":<40} {decoder_trainable:>15,} parameters') + + if hasattr(self.tokenizer, 'lpips') and self.tokenizer.lpips is not None: + lpips_params = sum(p.numel() for p in self.tokenizer.lpips.parameters()) + print(f'{" LPIPS Loss Network":<40} {lpips_params:>15,} parameters') + + # Calculate world model params excluding decoder and LPIPS + params_without_decoder = wm_params - decoder_params - lpips_params + print(f'\n{" World Model (exc. Decoder & LPIPS)":<40} {params_without_decoder:>15,} parameters') + + # --- Summary Table --- + print(f'\n{"=" * 80}') + print(f'{"SUMMARY":<40}') + print(f'{"=" * 80}') + print(f'{"Component":<30} {"Total Params":>15} {"Trainable":>15} {"% of Total":>15}') + print(f'{"-" * 80}') + + components = [] + + if hasattr(self.tokenizer, 'encoder') and self.tokenizer.encoder is not None: + encoder_params = sum(p.numel() for p in self.tokenizer.encoder.parameters()) + encoder_trainable = sum(p.numel() for p in self.tokenizer.encoder.parameters() if p.requires_grad) + components.append(("Encoder", encoder_params, encoder_trainable)) + + components.append(("Transformer", transformer_params, transformer_trainable)) + + if hasattr(self.world_model, 'head_dict'): + components.append(("Prediction Heads", total_heads_params, total_heads_trainable)) + + for name, total, trainable in components: + pct = 100 * total / total_params if total_params > 0 else 0 + print(f'{name:<30} {total:>15,} {trainable:>15,} {pct:>14.1f}%') + + print(f'{"=" * 80}') + print(f'{"TOTAL":<30} {total_params:>15,} {total_trainable:>15,} {"100.0%":>15}') + print(f'{"=" * 80}\n') + + def initial_inference(self, obs_batch: torch.Tensor, action_batch: Optional[torch.Tensor] = None, current_obs_batch: Optional[torch.Tensor] = None, task_id: Optional[int] = None) -> MZNetworkOutput: + """ + Overview: + Performs the initial inference step of the UniZero model. It takes an observation + and produces a latent state, a value prediction, and an initial policy. + Arguments: + - obs_batch (:obj:`torch.Tensor`): The initial batch of observations. + - action_batch (:obj:`Optional[torch.Tensor]`): An optional batch of actions. + - current_obs_batch (:obj:`Optional[torch.Tensor]`): An optional batch of current observations. + - task_id (:obj:`Optional[int]`): The identifier for the current task. + Returns (MZNetworkOutput): + An object containing the predicted value, initial reward (zero), policy logits, and latent state. + Shapes: + - obs_batch (:obj:`torch.Tensor`): :math:`(B, ...)` where B is the batch size. + - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`. + - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`. + - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`. + - latent_state (:obj:`torch.Tensor`): :math:`(B, embedding_dim)`. + """ + batch_size = obs_batch.size(0) + obs_act_dict = {'obs': obs_batch, 'action': action_batch, 'current_obs': current_obs_batch} + _, obs_token, logits_rewards, logits_policy, logits_value = self.world_model.forward_initial_inference(obs_act_dict, task_id=task_id) + + latent_state = obs_token + policy_logits = logits_policy.squeeze(1) + value = logits_value.squeeze(1) + + return MZNetworkOutput( + value=value, + reward=[0. 
for _ in range(batch_size)], # Initial reward is always zero + policy_logits=policy_logits, + latent_state=latent_state, + ) + + def recurrent_inference(self, state_action_history: torch.Tensor, simulation_index: int = 0, search_depth: List[int] = [], task_id: int = 0) -> MZNetworkOutput: + """ + Overview: + Performs the recurrent inference step (the dynamics function). Given a history of + latent states and actions, it predicts the next latent state, reward, value, and policy. + Arguments: + - state_action_history (:obj:`torch.Tensor`): A history of states and actions. + - simulation_index (:obj:`int`): The index of the current simulation step in MCTS. + - search_depth (:obj:`List[int]`): The indices of latent states in the current search path. + - task_id (:obj:`int`): The identifier for the current task. + Returns (MZNetworkOutput): + An object containing the predicted value, reward, policy logits, and the next latent state. + Shapes: + - state_action_history (:obj:`torch.Tensor`): :math:`(B, L, D)`, where L is sequence length. + - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`. + - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`. + - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`. + - next_latent_state (:obj:`torch.Tensor`): :math:`(B, embedding_dim)`. + """ + _, logits_observations, logits_rewards, logits_policy, logits_value = self.world_model.forward_recurrent_inference( + state_action_history, simulation_index, search_depth, task_id=task_id) + + next_latent_state = logits_observations + reward = logits_rewards.squeeze(1) + policy_logits = logits_policy.squeeze(1) + value = logits_value.squeeze(1) + + return MZNetworkOutput(value, reward, policy_logits, next_latent_state) \ No newline at end of file diff --git a/lzero/model/stochastic_muzero_model.py b/lzero/model/stochastic_muzero_model.py index 7aa7ce678..f93ce072c 100644 --- a/lzero/model/stochastic_muzero_model.py +++ b/lzero/model/stochastic_muzero_model.py @@ -7,7 +7,7 @@ from ding.utils import MODEL_REGISTRY, SequenceType from .common import MZNetworkOutput, RepresentationNetwork, PredictionNetwork -from .utils import renormalize, get_params_mean, get_dynamic_mean, get_reward_mean +from .utils import renormalize, get_params_mean # use ModelRegistry to register the model, for more details about ModelRegistry, please refer to DI-engine's document. @@ -572,12 +572,6 @@ def forward(self, state_action_encoding: torch.Tensor) -> Tuple[torch.Tensor, to return next_latent_state, reward - def get_dynamic_mean(self) -> float: - return get_dynamic_mean(self) - - def get_reward_mean(self) -> float: - return get_reward_mean(self) - # TODO(pu): customize different afterstate dynamics network AfterstateDynamicsNetwork = DynamicsNetwork diff --git a/lzero/model/tests/test_moe.py b/lzero/model/tests/test_moe.py new file mode 100644 index 000000000..fd90c1329 --- /dev/null +++ b/lzero/model/tests/test_moe.py @@ -0,0 +1,233 @@ +""" +test_moe.py + +Overview: + A pytest test suite to verify the functional equivalence between a standard Transformer's feed-forward network (FFN) + and a Mixture-of-Experts (MoE) layer configured with a single expert. This test demonstrates that + the MoE layer correctly specializes to a standard FFN when num_experts is 1, ensuring backward + compatibility and correct routing logic. 
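+ The suite can be run with, for example: pytest lzero/model/tests/test_moe.py -v -s,
+ mirroring the pytest.main invocation in the __main__ block.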
+""" +import dataclasses +import pytest +import torch +import torch.nn as nn +import torch.nn.functional as F + +from lzero.model.unizero_world_models.moe import MoELayer + + +@dataclasses.dataclass +class TransformerConfig: + """ + Overview: + Configuration for the Transformer block and its potential MoE layer. + + Arguments: + - embed_dim (int): The embedding dimension for the model. + - resid_pdrop (float): The dropout probability for the residual connections. + - moe_in_transformer (bool): If True, use an MoE layer for the feed-forward part. Otherwise, use a standard MLP. + - num_experts (int): The total number of experts in the MoE layer. + - num_experts_per_tok (int): The number of experts to route each token to (top-k routing). + - moe_use_lora (bool): Whether to use LoRA in the MoE layer. + - n_shared_experts (int): Number of shared experts (optional). + """ + embed_dim: int = 64 + resid_pdrop: float = 0.1 + moe_in_transformer: bool = False + num_experts: int = 1 + num_experts_per_tok: int = 1 + moe_use_lora: bool = False + n_shared_experts: int = 0 + + +class TransformerBlock(nn.Module): + """ + Overview: + A simplified Transformer block that contains a feed-forward network (FFN). + The FFN can be either a standard MLP or a Mixture-of-Experts (MoE) layer, + controlled by the configuration. + """ + def __init__(self, config: TransformerConfig): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(config.embed_dim, 4 * config.embed_dim), + nn.GELU(approximate='tanh'), + nn.Linear(4 * config.embed_dim, config.embed_dim), + nn.Dropout(config.resid_pdrop), + ) + + if config.moe_in_transformer: + # Create experts - in the single expert case, we share the same MLP + experts = [self.mlp for _ in range(config.num_experts)] + gate = nn.Linear(config.embed_dim, config.num_experts, bias=False) + + # Use MoELayer from moe.py (note the different signature) + self.feed_forward = MoELayer( + config=config, + experts=experts, + gate=gate, + num_experts_per_tok=config.num_experts_per_tok, + ) + print("=" * 40) + print("TransformerBlock initialized with MoE layer.") + print("=" * 40) + else: + self.feed_forward = self.mlp + print("-" * 40) + print("TransformerBlock initialized with standard MLP.") + print("-" * 40) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.feed_forward(x) + + +class TestMoELayer: + """Test suite for MoE layer functionality.""" + + @pytest.fixture + def embed_dim(self): + """Embedding dimension for tests.""" + return 64 + + @pytest.fixture + def batch_size(self): + """Batch size for tests.""" + return 10 + + @pytest.fixture + def seq_len(self): + """Sequence length for tests.""" + return 5 + + def test_single_expert_moe_equivalence(self, embed_dim, batch_size, seq_len): + """ + Test that an MoE layer with a single expert produces an output identical + to that of a standard MLP layer, given that they share the same weights. + """ + torch.manual_seed(42) + + config_mlp = TransformerConfig(embed_dim=embed_dim, moe_in_transformer=False) + config_moe = TransformerConfig( + embed_dim=embed_dim, + moe_in_transformer=True, + num_experts=1, + num_experts_per_tok=1, + moe_use_lora=False, + n_shared_experts=0 + ) + + # 1. Create the standard MLP block first. + transformer_block_mlp = TransformerBlock(config_mlp) + + # 2. Create the MoE block. + transformer_block_moe = TransformerBlock(config_moe) + + # 3. CRITICAL: Load the MLP's weights into the MoE's expert MLP. + # This guarantees that the underlying expert has the exact same weights as the standalone MLP. 
+ transformer_block_moe.mlp.load_state_dict(transformer_block_mlp.mlp.state_dict()) + + inputs = torch.randn(batch_size, seq_len, embed_dim) + + print("\nRunning forward pass for standard MLP block...") + output_mlp = transformer_block_mlp(inputs) + + print("\nRunning forward pass for MoE block...") + output_moe = transformer_block_moe(inputs) + + is_close = torch.allclose(output_moe, output_mlp, atol=1e-6) + mse_difference = F.mse_loss(output_moe, output_mlp).item() + + print("\n" + "=" * 25 + " TEST RESULTS " + "=" * 25) + print(f"Outputs are close: {is_close}") + print(f"Mean Squared Error (MSE) between outputs: {mse_difference:.10f}") + + assert is_close, f"Test failed: Outputs of single-expert MoE and MLP are not identical. MSE: {mse_difference}" + print("\n✅ Test Passed: Single-expert MoE layer behaves identically to a standard MLP.") + print("=" * 64 + "\n") + + def test_moe_output_shape(self, embed_dim, batch_size, seq_len): + """ + Test that MoE layer preserves the input shape. + """ + torch.manual_seed(42) + + config_moe = TransformerConfig( + embed_dim=embed_dim, + moe_in_transformer=True, + num_experts=4, + num_experts_per_tok=2, + moe_use_lora=False, + n_shared_experts=0 + ) + + transformer_block_moe = TransformerBlock(config_moe) + inputs = torch.randn(batch_size, seq_len, embed_dim) + + output = transformer_block_moe(inputs) + + assert output.shape == inputs.shape, \ + f"Expected output shape {inputs.shape}, but got {output.shape}" + print(f"✅ Test Passed: MoE layer preserves input shape: {inputs.shape}") + + def test_moe_with_multiple_experts(self, embed_dim, batch_size, seq_len): + """ + Test that MoE layer works correctly with multiple experts. + """ + torch.manual_seed(42) + + num_experts = 8 + num_experts_per_tok = 2 + + config_moe = TransformerConfig( + embed_dim=embed_dim, + moe_in_transformer=True, + num_experts=num_experts, + num_experts_per_tok=num_experts_per_tok, + moe_use_lora=False, + n_shared_experts=0 + ) + + transformer_block_moe = TransformerBlock(config_moe) + inputs = torch.randn(batch_size, seq_len, embed_dim) + + output = transformer_block_moe(inputs) + + assert output.shape == inputs.shape, \ + f"Expected output shape {inputs.shape}, but got {output.shape}" + assert not torch.isnan(output).any(), "Output contains NaN values" + assert not torch.isinf(output).any(), "Output contains Inf values" + + print(f"✅ Test Passed: MoE layer with {num_experts} experts and top-{num_experts_per_tok} routing works correctly") + + +if __name__ == "__main__": + if PYTEST_AVAILABLE: + pytest.main([__file__, "-v", "-s"]) + else: + # Run tests directly without pytest + print("Pytest not available. Running tests directly...\n") + test_suite = TestMoELayer() + + # Set up fixtures + embed_dim = 64 + batch_size = 10 + seq_len = 5 + + print("\n" + "=" * 60) + print("Test 1: Single Expert MoE Equivalence") + print("=" * 60) + test_suite.test_single_expert_moe_equivalence(embed_dim, batch_size, seq_len) + + print("\n" + "=" * 60) + print("Test 2: MoE Output Shape") + print("=" * 60) + test_suite.test_moe_output_shape(embed_dim, batch_size, seq_len) + + print("\n" + "=" * 60) + print("Test 3: MoE with Multiple Experts") + print("=" * 60) + test_suite.test_moe_with_multiple_experts(embed_dim, batch_size, seq_len) + + print("\n" + "=" * 60) + print("All tests passed! 
✅") + print("=" * 60) diff --git a/lzero/model/unizero_model.py b/lzero/model/unizero_model.py index 9d57b3c5f..b680a6e2d 100644 --- a/lzero/model/unizero_model.py +++ b/lzero/model/unizero_model.py @@ -1,17 +1,20 @@ from typing import Optional - import torch import torch.nn as nn -from ding.utils import MODEL_REGISTRY, SequenceType +from ding.utils import (ENV_REGISTRY, MODEL_REGISTRY, SequenceType, get_rank, + get_world_size, set_pkg_seed) +from ditk import logging from easydict import EasyDict -# from transformers import T5ForConditionalGeneration, T5Tokenizer - -from .common import MZNetworkOutput, RepresentationNetworkUniZero, RepresentationNetworkMLP, LatentDecoder, \ - VectorDecoderForMemoryEnv, LatentEncoderForMemoryEnv, LatentDecoderForMemoryEnv, FeatureAndGradientHook, \ - HFLanguageRepresentationNetwork, QwenNetwork +from .common import (FeatureAndGradientHook, HFLanguageRepresentationNetwork, + LatentDecoder, LatentDecoderForMemoryEnv, + LatentEncoderForMemoryEnv, MZNetworkOutput, QwenNetwork, + RepresentationNetworkMLP, RepresentationNetworkUniZero, + VectorDecoderForMemoryEnv) from .unizero_world_models.tokenizer import Tokenizer from .unizero_world_models.world_model import WorldModel -from ding.utils import ENV_REGISTRY, set_pkg_seed, get_rank, get_world_size +from .vit import ViT, ViTConfig + +# from transformers import T5ForConditionalGeneration, T5Tokenizer # use ModelRegistry to register the model, for more details about ModelRegistry, please refer to DI-engine's document. @@ -88,13 +91,13 @@ def __init__( # TODO: only for MemoryEnv now self.decoder_network = VectorDecoderForMemoryEnv(embedding_dim=world_model_cfg.embed_dim, output_shape=25) self.tokenizer = Tokenizer(encoder=self.representation_network, - decoder_network=self.decoder_network, with_lpips=False) + decoder=self.decoder_network, with_lpips=False, obs_type=world_model_cfg.obs_type) self.world_model = WorldModel(config=world_model_cfg, tokenizer=self.tokenizer) - print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model') - print('==' * 20) - print(f'{sum(p.numel() for p in self.world_model.transformer.parameters())} parameters in agent.world_model.transformer') - print(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters())} parameters in agent.tokenizer.encoder') - print('==' * 20) + logging.info(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model') + logging.info('==' * 20) + logging.info(f'{sum(p.numel() for p in self.world_model.transformer.parameters())} parameters in agent.world_model.transformer') + logging.info(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters())} parameters in agent.tokenizer.encoder') + logging.info('==' * 20) elif world_model_cfg.obs_type == 'text': if kwargs['encoder_option'] == 'legacy': self.representation_network = HFLanguageRepresentationNetwork(model_path=kwargs['encoder_url'], embedding_size=world_model_cfg.embed_dim, final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder) @@ -125,38 +128,72 @@ def __init__( self.decoder_network_tokenizer = None else: raise ValueError(f"Unsupported encoder option: {kwargs['encoder_option']}") - self.tokenizer = Tokenizer(encoder=self.representation_network, decoder_network=self.decoder_network, decoder_network_tokenizer=self.decoder_network_tokenizer, - with_lpips=False, projection=projection, encoder_option=kwargs['encoder_option']) + + self.tokenizer = Tokenizer(encoder=self.representation_network, decoder=self.decoder_network, 
decoder_network_tokenizer=self.decoder_network_tokenizer, + with_lpips=False, projection=projection, encoder_option=kwargs['encoder_option']) self.world_model = WorldModel(config=world_model_cfg, tokenizer=self.tokenizer) - print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model') - print('==' * 20) - print(f'{sum(p.numel() for p in self.world_model.transformer.parameters())} parameters in agent.world_model.transformer') - print(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters())} parameters in agent.tokenizer.encoder') - print('==' * 20) + + # --- Log parameter counts for analysis --- + self._log_model_parameters(obs_type) + + logging.info(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model') + logging.info('==' * 20) + logging.info(f'{sum(p.numel() for p in self.world_model.transformer.parameters())} parameters in agent.world_model.transformer') + logging.info(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters())} parameters in agent.tokenizer.encoder') + logging.info('==' * 20) elif world_model_cfg.obs_type == 'image': - self.representation_network = RepresentationNetworkUniZero( - observation_shape, - num_res_blocks, - num_channels, - self.downsample, - activation=self.activation, - norm_type=norm_type, - embedding_dim=world_model_cfg.embed_dim, - group_size=world_model_cfg.group_size, - final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder - ) + if world_model_cfg.encoder_type == "resnet": + self.representation_network = RepresentationNetworkUniZero( + observation_shape, + num_res_blocks, + num_channels, + self.downsample, + activation=self.activation, + norm_type=norm_type, + embedding_dim=world_model_cfg.embed_dim, + group_size=world_model_cfg.group_size, + final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder, + ) + elif world_model_cfg.encoder_type == "vit": + # vit base + vit_config = ViTConfig( + image_size=observation_shape[1], + patch_size=8, + num_classes=world_model_cfg.embed_dim, + dim=768, + depth=12, + heads=12, + mlp_dim=3072, + dropout=0.1, + emb_dropout=0.1, + final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder, + lora_config=world_model_cfg, + ) + self.representation_network = ViT(config=vit_config) # ====== for analysis ====== if world_model_cfg.analysis_sim_norm: self.encoder_hook = FeatureAndGradientHook() self.encoder_hook.setup_hooks(self.representation_network) - self.tokenizer = Tokenizer(encoder=self.representation_network, decoder_network=None, with_lpips=False,) + + if world_model_cfg.latent_recon_loss_weight == 0: + self.tokenizer = Tokenizer(encoder=self.representation_network, decoder=None, with_lpips=False, obs_type=world_model_cfg.obs_type) + else: + # TODO: customize LatentDecoder + self.decoder_network = LatentDecoder( + embedding_dim=world_model_cfg.embed_dim, + output_shape=[3, 64, 64], + num_channels = 64, + activation=self.activation, + ) + self.tokenizer = Tokenizer(encoder=self.representation_network, decoder=self.decoder_network, with_lpips=True, obs_type=world_model_cfg.obs_type) + self.world_model = WorldModel(config=world_model_cfg, tokenizer=self.tokenizer) - print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model') - print('==' * 20) - print(f'{sum(p.numel() for p in self.world_model.transformer.parameters())} parameters in agent.world_model.transformer') - print(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters())} parameters in 
agent.tokenizer.encoder') - print('==' * 20) + logging.info(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model') + logging.info('==' * 20) + logging.info(f'{sum(p.numel() for p in self.world_model.transformer.parameters())} parameters in agent.world_model.transformer') + logging.info(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters())} parameters in agent.tokenizer.encoder') + logging.info('==' * 20) elif world_model_cfg.obs_type == 'image_memory': self.representation_network = LatentEncoderForMemoryEnv( image_shape=(3, 5, 5), @@ -181,17 +218,170 @@ def __init__( self.encoder_hook = FeatureAndGradientHook() self.encoder_hook.setup_hooks(self.representation_network) - self.tokenizer = Tokenizer(encoder=self.representation_network, decoder_network=self.decoder_network) + self.tokenizer = Tokenizer(encoder=self.representation_network, decoder=self.decoder_network, obs_type=world_model_cfg.obs_type) self.world_model = WorldModel(config=world_model_cfg, tokenizer=self.tokenizer) - print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model') - print(f'{sum(p.numel() for p in self.world_model.parameters()) - sum(p.numel() for p in self.tokenizer.decoder_network.parameters()) - sum(p.numel() for p in self.tokenizer.lpips.parameters())} parameters in agent.world_model - (decoder_network and lpips)') - print('==' * 20) - print(f'{sum(p.numel() for p in self.world_model.transformer.parameters())} parameters in agent.world_model.transformer') - print(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters())} parameters in agent.tokenizer.encoder') - print(f'{sum(p.numel() for p in self.tokenizer.decoder_network.parameters())} parameters in agent.tokenizer.decoder_network') - print('==' * 20) + + logging.info(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model') + logging.info(f'{sum(p.numel() for p in self.world_model.parameters()) - sum(p.numel() for p in self.tokenizer.decoder_network.parameters()) - sum(p.numel() for p in self.tokenizer.lpips.parameters())} parameters in agent.world_model - (decoder_network and lpips)') + + logging.info('==' * 20) + logging.info(f'{sum(p.numel() for p in self.world_model.transformer.parameters())} parameters in agent.world_model.transformer') + logging.info(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters())} parameters in agent.tokenizer.encoder') + logging.info(f'{sum(p.numel() for p in self.tokenizer.decoder_network.parameters())} parameters in agent.tokenizer.decoder_network') + logging.info('==' * 20) + + # --- Log parameter counts for analysis --- + self._log_model_parameters(world_model_cfg.obs_type) + + + def _log_model_parameters(self, obs_type: str) -> None: + """ + Overview: + Logs detailed parameter counts for all model components with a comprehensive breakdown. + Includes encoder, transformer, prediction heads, and other components. + Arguments: + - obs_type (:obj:`str`): The type of observation ('vector', 'image', or 'image_memory'). 
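+        .. note::
+            Under DDP, only rank 0 emits this report (all other ranks return immediately), so the
+            breakdown is logged exactly once per run.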
+ """ + from ding.utils import get_rank + + # Only print from rank 0 to avoid duplicate logs in DDP + if get_rank() != 0: + return + + logging.info('=' * 80) + logging.info('MODEL PARAMETER STATISTICS'.center(80)) + logging.info('=' * 80) + + # --- Total Model Parameters --- + total_params = sum(p.numel() for p in self.parameters()) + total_trainable = sum(p.numel() for p in self.parameters() if p.requires_grad) + logging.info(f'\n{"TOTAL MODEL":<40} {total_params:>15,} parameters') + logging.info(f'{" └─ Trainable":<40} {total_trainable:>15,} parameters') + logging.info(f'{" └─ Frozen":<40} {total_params - total_trainable:>15,} parameters') + + # --- World Model Components --- + logging.info(f'\n{"-" * 80}') + logging.info(f'{"WORLD MODEL BREAKDOWN":<40}') + logging.info(f'{"-" * 80}') + + wm_params = sum(p.numel() for p in self.world_model.parameters()) + wm_trainable = sum(p.numel() for p in self.world_model.parameters() if p.requires_grad) + logging.info(f'{"World Model Total":<40} {wm_params:>15,} parameters') + logging.info(f'{" └─ Trainable":<40} {wm_trainable:>15,} parameters ({100*wm_trainable/wm_params:.1f}%)') + + # --- Encoder --- + encoder_params = sum(p.numel() for p in self.tokenizer.encoder.parameters()) + encoder_trainable = sum(p.numel() for p in self.tokenizer.encoder.parameters() if p.requires_grad) + logging.info(f'\n{"1. ENCODER (Tokenizer)":<40} {encoder_params:>15,} parameters') + logging.info(f'{" └─ Trainable":<40} {encoder_trainable:>15,} parameters ({100*encoder_trainable/encoder_params:.1f}%)') + + # --- Transformer Backbone --- + transformer_params = sum(p.numel() for p in self.world_model.transformer.parameters()) + transformer_trainable = sum(p.numel() for p in self.world_model.transformer.parameters() if p.requires_grad) + logging.info(f'\n{"2. TRANSFORMER BACKBONE":<40} {transformer_params:>15,} parameters') + logging.info(f'{" └─ Trainable":<40} {transformer_trainable:>15,} parameters ({100*transformer_trainable/transformer_params:.1f}%)') + + # --- Prediction Heads (Detailed Breakdown) --- + logging.info(f'\n{"3. 
PREDICTION HEADS":<40}') + + # Access head_dict from world_model + if hasattr(self.world_model, 'head_dict'): + head_dict = self.world_model.head_dict + + # Calculate total heads parameters + total_heads_params = sum(p.numel() for module in head_dict.values() for p in module.parameters()) + total_heads_trainable = sum(p.numel() for module in head_dict.values() for p in module.parameters() if p.requires_grad) + logging.info(f'{" Total (All Heads)":<40} {total_heads_params:>15,} parameters') + logging.info(f'{" └─ Trainable":<40} {total_heads_trainable:>15,} parameters ({100*total_heads_trainable/total_heads_params:.1f}%)') + + # Breakdown by head type + head_names_map = { + 'head_policy_multi_task': 'Policy Head', + 'head_value_multi_task': 'Value Head', + 'head_rewards_multi_task': 'Reward Head', + 'head_observations_multi_task': 'Next Latent (Obs) Head' + } + + logging.info(f'\n{" Breakdown by Head Type:":<40}') + for head_key, head_name in head_names_map.items(): + if head_key in head_dict: + head_module = head_dict[head_key] + head_params = sum(p.numel() for p in head_module.parameters()) + head_trainable = sum(p.numel() for p in head_module.parameters() if p.requires_grad) + + # Count number of task-specific heads (for ModuleList) + if isinstance(head_module, nn.ModuleList): + num_heads = len(head_module) + params_per_head = head_params // num_heads if num_heads > 0 else 0 + logging.info(f'{" ├─ " + head_name:<38} {head_params:>15,} parameters') + logging.info(f'{" └─ " + f"{num_heads} task-specific heads":<38} {params_per_head:>15,} params/head') + else: + logging.info(f'{" ├─ " + head_name:<38} {head_params:>15,} parameters') + logging.info(f'{" └─ Shared across tasks":<38}') + + # --- Positional & Task Embeddings --- + logging.info(f'\n{"4. EMBEDDINGS":<40}') + + if hasattr(self.world_model, 'pos_emb'): + pos_emb_params = sum(p.numel() for p in self.world_model.pos_emb.parameters()) + pos_emb_trainable = sum(p.numel() for p in self.world_model.pos_emb.parameters() if p.requires_grad) + logging.info(f'{" ├─ Positional Embedding":<40} {pos_emb_params:>15,} parameters') + if pos_emb_trainable == 0: + logging.info(f'{" └─ (Frozen)":<40}') + + if hasattr(self.world_model, 'task_emb') and self.world_model.task_emb is not None: + task_emb_params = sum(p.numel() for p in self.world_model.task_emb.parameters()) + task_emb_trainable = sum(p.numel() for p in self.world_model.task_emb.parameters() if p.requires_grad) + logging.info(f'{" ├─ Task Embedding":<40} {task_emb_params:>15,} parameters') + logging.info(f'{" └─ Trainable":<40} {task_emb_trainable:>15,} parameters') + + if hasattr(self.world_model, 'act_embedding_table'): + act_emb_params = sum(p.numel() for p in self.world_model.act_embedding_table.parameters()) + act_emb_trainable = sum(p.numel() for p in self.world_model.act_embedding_table.parameters() if p.requires_grad) + logging.info(f'{" └─ Action Embedding":<40} {act_emb_params:>15,} parameters') + logging.info(f'{" └─ Trainable":<40} {act_emb_trainable:>15,} parameters') + + # --- Decoder (if applicable) --- + if obs_type in ['vector', 'image_memory'] and self.tokenizer.decoder_network is not None: + logging.info(f'\n{"5. 
DECODER":<40}') + decoder_params = sum(p.numel() for p in self.tokenizer.decoder_network.parameters()) + decoder_trainable = sum(p.numel() for p in self.tokenizer.decoder_network.parameters() if p.requires_grad) + logging.info(f'{" Decoder Network":<40} {decoder_params:>15,} parameters') + logging.info(f'{" └─ Trainable":<40} {decoder_trainable:>15,} parameters') + + if obs_type == 'image_memory' and hasattr(self.tokenizer, 'lpips'): + lpips_params = sum(p.numel() for p in self.tokenizer.lpips.parameters()) + logging.info(f'{" LPIPS Loss Network":<40} {lpips_params:>15,} parameters') + + # Calculate world model params excluding decoder and LPIPS + params_without_decoder = wm_params - decoder_params - lpips_params + logging.info(f'\n{" World Model (exc. Decoder & LPIPS)":<40} {params_without_decoder:>15,} parameters') + + # --- Summary Table --- + logging.info(f'\n{"=" * 80}') + logging.info(f'{"SUMMARY":<40}') + logging.info(f'{"=" * 80}') + logging.info(f'{"Component":<30} {"Total Params":>15} {"Trainable":>15} {"% of Total":>15}') + logging.info(f'{"-" * 80}') + + components = [ + ("Encoder", encoder_params, encoder_trainable), + ("Transformer", transformer_params, transformer_trainable), + ] + + if hasattr(self.world_model, 'head_dict'): + components.append(("Prediction Heads", total_heads_params, total_heads_trainable)) + + for name, total, trainable in components: + pct = 100 * total / total_params if total_params > 0 else 0 + logging.info(f'{name:<30} {total:>15,} {trainable:>15,} {pct:>14.1f}%') + + logging.info(f'{"=" * 80}') + logging.info(f'{"TOTAL":<30} {total_params:>15,} {total_trainable:>15,} {"100.0%":>15}') + logging.info(f'{"=" * 80}\n') + def initial_inference(self, obs_batch: torch.Tensor, action_batch: Optional[torch.Tensor] = None, current_obs_batch: Optional[torch.Tensor] = None, start_pos: int = 0) -> MZNetworkOutput: """ @@ -277,4 +467,4 @@ def recurrent_inference(self, state_action_history: torch.Tensor, simulation_ind policy_logits = logits_policy.squeeze(1) value = logits_value.squeeze(1) - return MZNetworkOutput(value=value, reward=reward, policy_logits=policy_logits, latent_state=next_latent_state) \ No newline at end of file + return MZNetworkOutput(value=value, reward=reward, policy_logits=policy_logits, latent_state=next_latent_state) diff --git a/lzero/model/unizero_model_multitask.py b/lzero/model/unizero_model_multitask.py new file mode 100644 index 000000000..a7356d942 --- /dev/null +++ b/lzero/model/unizero_model_multitask.py @@ -0,0 +1,413 @@ +from typing import Optional, Sequence, Dict, Any, List + +import torch +import torch.nn as nn +from ding.utils import MODEL_REGISTRY, SequenceType +from easydict import EasyDict + +from .common import MZNetworkOutput, RepresentationNetworkUniZero, RepresentationNetworkMLP, LatentDecoder, \ + VectorDecoderForMemoryEnv, LatentEncoderForMemoryEnv, LatentDecoderForMemoryEnv, FeatureAndGradientHook +from .unizero_world_models.tokenizer import Tokenizer +from .unizero_world_models.world_model_multitask import WorldModelMT +from .vit import ViT, ViTConfig + + +@MODEL_REGISTRY.register('UniZeroMTModel') +class UniZeroMTModel(nn.Module): + """ + Overview: + The main model for UniZero, a multi-task agent based on a scalable latent world model. + This class orchestrates the representation network, world model, and prediction heads. + It provides two primary interfaces: + - `initial_inference`: Encodes an observation to produce an initial latent state and predictions (value, policy). 
+ - `recurrent_inference`: Simulates dynamics by taking a history of latent states and actions to predict the next + latent state, reward, value, and policy. + """ + + def __init__( + self, + observation_shape: SequenceType = (4, 64, 64), + action_space_size: int = 6, + num_res_blocks: int = 1, + num_channels: int = 64, + activation: nn.Module = nn.GELU(approximate='tanh'), + downsample: bool = True, + norm_type: str = 'BN', + world_model_cfg: EasyDict = None, + task_num: int = 1, + *args: Any, + **kwargs: Any + ) -> None: + """ + Overview: + Initializes the UniZeroMTModel, setting up the representation network, tokenizer, and world model + based on the provided configuration. + Arguments: + - observation_shape (:obj:`SequenceType`): The shape of the input observation, e.g., (C, H, W). + - action_space_size (:obj:`int`): The size of the discrete action space. + - num_res_blocks (:obj:`int`): The number of residual blocks in the ResNet-based representation network. + - num_channels (:obj:`int`): The number of channels in the ResNet-based representation network. + - activation (:obj:`nn.Module`): The activation function to use throughout the network. + - downsample (:obj:`bool`): Whether to downsample the observation in the representation network. + - norm_type (:obj:`str`): The type of normalization to use, e.g., 'BN' for BatchNorm. + - world_model_cfg (:obj:`EasyDict`): Configuration for the world model and its components. + - task_num (:obj:`int`): The number of tasks for multi-task learning. + """ + super().__init__() + print(f'========== Initializing UniZeroMTModel (num_res_blocks: {num_res_blocks}, num_channels: {num_channels}) ==========') + + # --- Basic attribute setup --- + self.task_num = task_num + self.activation = activation + self.downsample = downsample + world_model_cfg.norm_type = norm_type + + # NOTE: The action_space_size passed as an argument is immediately overridden. + # This might be intentional for specific experiments but is not a general practice. + self.action_space_size = 18 + + assert world_model_cfg.max_tokens == 2 * world_model_cfg.max_blocks, \ + "max_tokens should be 2 * max_blocks, as each timestep consists of an observation and an action token." 
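+        # For example, with max_blocks=8 the sequence holds max_tokens=16:
+        # each of the 8 timesteps contributes one observation token and one action token.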
+ + # --- Determine embedding dimensions --- + if world_model_cfg.task_embed_option == "concat_task_embed": + task_embed_dim = world_model_cfg.get("task_embed_dim", 32) # Default task_embed_dim to 32 if not specified + obs_act_embed_dim = world_model_cfg.embed_dim - task_embed_dim + else: + obs_act_embed_dim = world_model_cfg.embed_dim + + # --- Initialize model components based on observation type --- + obs_type = world_model_cfg.obs_type + if obs_type == 'vector': + self._init_vector_components(world_model_cfg, obs_act_embed_dim) + elif obs_type == 'image': + self._init_image_components(world_model_cfg, observation_shape, num_res_blocks, num_channels, obs_act_embed_dim) + elif obs_type == 'image_memory': + self._init_image_memory_components(world_model_cfg) + else: + raise ValueError(f"Unsupported observation type: {obs_type}") + + # --- Initialize world model and tokenizer --- + self.world_model = WorldModelMT(config=world_model_cfg, tokenizer=self.tokenizer) + + # --- Log parameter counts for analysis --- + self._log_model_parameters(obs_type) + + def _init_vector_components(self, world_model_cfg: EasyDict, obs_act_embed_dim: int) -> None: + """Initializes components for 'vector' observation type.""" + self.representation_network = RepresentationNetworkMLP( + observation_shape=world_model_cfg.observation_shape, + hidden_channels=obs_act_embed_dim, + layer_num=2, + activation=self.activation, + group_size=world_model_cfg.group_size, + ) + # TODO: This is currently specific to MemoryEnv. Generalize if needed. + self.decoder_network = VectorDecoderForMemoryEnv(embedding_dim=world_model_cfg.embed_dim, output_shape=25) + self.tokenizer = Tokenizer( + encoder=self.representation_network, + decoder=self.decoder_network, + with_lpips=False, + obs_type=world_model_cfg.obs_type + ) + + def _init_image_components(self, world_model_cfg: EasyDict, observation_shape: SequenceType, num_res_blocks: int, + num_channels: int, obs_act_embed_dim: int) -> None: + """Initializes components for 'image' observation type.""" + self.representation_network = nn.ModuleList() + encoder_type = world_model_cfg.encoder_type + + # NOTE: Using a single shared encoder. The original code used a loop `for _ in range(1):`. + # To support N independent encoders, this logic would need to be modified. 
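+        # A hypothetical per-task variant would instead append one encoder per task, e.g.
+        #   for _ in range(self.task_num): self.representation_network.append(build_encoder())
+        # where `build_encoder()` is a placeholder; the current implementation shares a single encoder.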
+ if encoder_type == "resnet": + encoder = RepresentationNetworkUniZero( + observation_shape=observation_shape, + num_res_blocks=num_res_blocks, + num_channels=num_channels, + downsample=self.downsample, + activation=self.activation, + norm_type=world_model_cfg.norm_type, + embedding_dim=obs_act_embed_dim, + group_size=world_model_cfg.group_size, + final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder, + ) + self.representation_network.append(encoder) + elif encoder_type == "vit": + vit_configs = { + 'small': {'dim': 768, 'depth': 6, 'heads': 6, 'mlp_dim': 2048}, + 'base': {'dim': 768, 'depth': 12, 'heads': 12, 'mlp_dim': 3072}, + 'large': {'dim': 1024, 'depth': 24, 'heads': 16, 'mlp_dim': 4096}, + } + vit_size = 'base' if self.task_num > 8 else 'small' + selected_vit_config = vit_configs[vit_size] + + vit_params = { + 'image_size': observation_shape[1], + 'patch_size': 8, + 'num_classes': obs_act_embed_dim, + 'dropout': 0.1, + 'emb_dropout': 0.1, + 'final_norm_option_in_encoder': world_model_cfg.final_norm_option_in_encoder, + 'lora_config': world_model_cfg, + **selected_vit_config + } + vit_config = ViTConfig(**vit_params) + encoder = ViT(config=vit_config) + + self.representation_network.append(encoder) + else: + raise ValueError(f"Unsupported encoder type for image observations: {encoder_type}") + + # For image observations, the decoder is currently not used for reconstruction during training. + self.decoder_network = None + self.tokenizer = Tokenizer( + encoder=self.representation_network, + decoder=self.decoder_network, + with_lpips=False, + obs_type=world_model_cfg.obs_type + ) + if world_model_cfg.analysis_sim_norm: + self.encoder_hook = FeatureAndGradientHook() + self.encoder_hook.setup_hooks(self.representation_network) + + def _init_image_memory_components(self, world_model_cfg: EasyDict) -> None: + """Initializes components for 'image_memory' observation type.""" + # TODO: The 'concat_task_embed' option needs to be fully implemented for this obs_type. + self.representation_network = LatentEncoderForMemoryEnv( + image_shape=(3, 5, 5), + embedding_size=world_model_cfg.embed_dim, + channels=[16, 32, 64], + kernel_sizes=[3, 3, 3], + strides=[1, 1, 1], + activation=self.activation, + group_size=world_model_cfg.group_size, + ) + self.decoder_network = LatentDecoderForMemoryEnv( + image_shape=(3, 5, 5), + embedding_size=world_model_cfg.embed_dim, + channels=[64, 32, 16], + kernel_sizes=[3, 3, 3], + strides=[1, 1, 1], + activation=self.activation, + ) + self.tokenizer = Tokenizer( + encoder=self.representation_network, + decoder=self.decoder_network, + with_lpips=True, + obs_type=world_model_cfg.obs_type + ) + if world_model_cfg.analysis_sim_norm: + self.encoder_hook = FeatureAndGradientHook() + self.encoder_hook.setup_hooks(self.representation_network) + + def _log_model_parameters(self, obs_type: str) -> None: + """ + Overview: + Logs detailed parameter counts for all model components with a comprehensive breakdown. + Includes encoder, transformer, prediction heads, and other components. + Arguments: + - obs_type (:obj:`str`): The type of observation ('vector', 'image', or 'image_memory'). 
+ """ + from ding.utils import get_rank + + # Only print from rank 0 to avoid duplicate logs in DDP + if get_rank() != 0: + return + + print('=' * 80) + print('MODEL PARAMETER STATISTICS'.center(80)) + print('=' * 80) + + # --- Total Model Parameters --- + total_params = sum(p.numel() for p in self.parameters()) + total_trainable = sum(p.numel() for p in self.parameters() if p.requires_grad) + print(f'\n{"TOTAL MODEL":<40} {total_params:>15,} parameters') + print(f'{" └─ Trainable":<40} {total_trainable:>15,} parameters') + print(f'{" └─ Frozen":<40} {total_params - total_trainable:>15,} parameters') + + # --- World Model Components --- + print(f'\n{"-" * 80}') + print(f'{"WORLD MODEL BREAKDOWN":<40}') + print(f'{"-" * 80}') + + wm_params = sum(p.numel() for p in self.world_model.parameters()) + wm_trainable = sum(p.numel() for p in self.world_model.parameters() if p.requires_grad) + print(f'{"World Model Total":<40} {wm_params:>15,} parameters') + print(f'{" └─ Trainable":<40} {wm_trainable:>15,} parameters ({100*wm_trainable/wm_params:.1f}%)') + + # --- Encoder --- + encoder_params = sum(p.numel() for p in self.tokenizer.encoder.parameters()) + encoder_trainable = sum(p.numel() for p in self.tokenizer.encoder.parameters() if p.requires_grad) + print(f'\n{"1. ENCODER (Tokenizer)":<40} {encoder_params:>15,} parameters') + print(f'{" └─ Trainable":<40} {encoder_trainable:>15,} parameters ({100*encoder_trainable/encoder_params:.1f}%)') + + # --- Transformer Backbone --- + transformer_params = sum(p.numel() for p in self.world_model.transformer.parameters()) + transformer_trainable = sum(p.numel() for p in self.world_model.transformer.parameters() if p.requires_grad) + print(f'\n{"2. TRANSFORMER BACKBONE":<40} {transformer_params:>15,} parameters') + print(f'{" └─ Trainable":<40} {transformer_trainable:>15,} parameters ({100*transformer_trainable/transformer_params:.1f}%)') + + # --- Prediction Heads (Detailed Breakdown) --- + print(f'\n{"3. 
PREDICTION HEADS":<40}') + + # Access head_dict from world_model + if hasattr(self.world_model, 'head_dict'): + head_dict = self.world_model.head_dict + + # Calculate total heads parameters + total_heads_params = sum(p.numel() for module in head_dict.values() for p in module.parameters()) + total_heads_trainable = sum(p.numel() for module in head_dict.values() for p in module.parameters() if p.requires_grad) + print(f'{" Total (All Heads)":<40} {total_heads_params:>15,} parameters') + print(f'{" └─ Trainable":<40} {total_heads_trainable:>15,} parameters ({100*total_heads_trainable/total_heads_params:.1f}%)') + + # Breakdown by head type + head_names_map = { + 'head_policy_multi_task': 'Policy Head', + 'head_value_multi_task': 'Value Head', + 'head_rewards_multi_task': 'Reward Head', + 'head_observations_multi_task': 'Next Latent (Obs) Head' + } + + print(f'\n{" Breakdown by Head Type:":<40}') + for head_key, head_name in head_names_map.items(): + if head_key in head_dict: + head_module = head_dict[head_key] + head_params = sum(p.numel() for p in head_module.parameters()) + head_trainable = sum(p.numel() for p in head_module.parameters() if p.requires_grad) + + # Count number of task-specific heads (for ModuleList) + if isinstance(head_module, nn.ModuleList): + num_heads = len(head_module) + params_per_head = head_params // num_heads if num_heads > 0 else 0 + print(f'{" ├─ " + head_name:<38} {head_params:>15,} parameters') + print(f'{" └─ " + f"{num_heads} task-specific heads":<38} {params_per_head:>15,} params/head') + else: + print(f'{" ├─ " + head_name:<38} {head_params:>15,} parameters') + print(f'{" └─ Shared across tasks":<38}') + + # --- Positional & Task Embeddings --- + print(f'\n{"4. EMBEDDINGS":<40}') + + if hasattr(self.world_model, 'pos_emb'): + pos_emb_params = sum(p.numel() for p in self.world_model.pos_emb.parameters()) + pos_emb_trainable = sum(p.numel() for p in self.world_model.pos_emb.parameters() if p.requires_grad) + print(f'{" ├─ Positional Embedding":<40} {pos_emb_params:>15,} parameters') + if pos_emb_trainable == 0: + print(f'{" └─ (Frozen)":<40}') + + if hasattr(self.world_model, 'task_emb') and self.world_model.task_emb is not None: + task_emb_params = sum(p.numel() for p in self.world_model.task_emb.parameters()) + task_emb_trainable = sum(p.numel() for p in self.world_model.task_emb.parameters() if p.requires_grad) + print(f'{" ├─ Task Embedding":<40} {task_emb_params:>15,} parameters') + print(f'{" └─ Trainable":<40} {task_emb_trainable:>15,} parameters') + + if hasattr(self.world_model, 'act_embedding_table'): + act_emb_params = sum(p.numel() for p in self.world_model.act_embedding_table.parameters()) + act_emb_trainable = sum(p.numel() for p in self.world_model.act_embedding_table.parameters() if p.requires_grad) + print(f'{" └─ Action Embedding":<40} {act_emb_params:>15,} parameters') + print(f'{" └─ Trainable":<40} {act_emb_trainable:>15,} parameters') + + # --- Decoder (if applicable) --- + if obs_type in ['vector', 'image_memory'] and self.tokenizer.decoder_network is not None: + print(f'\n{"5. 
DECODER":<40}') + decoder_params = sum(p.numel() for p in self.tokenizer.decoder_network.parameters()) + decoder_trainable = sum(p.numel() for p in self.tokenizer.decoder_network.parameters() if p.requires_grad) + print(f'{" Decoder Network":<40} {decoder_params:>15,} parameters') + print(f'{" └─ Trainable":<40} {decoder_trainable:>15,} parameters') + + if obs_type == 'image_memory' and hasattr(self.tokenizer, 'lpips'): + lpips_params = sum(p.numel() for p in self.tokenizer.lpips.parameters()) + print(f'{" LPIPS Loss Network":<40} {lpips_params:>15,} parameters') + + # Calculate world model params excluding decoder and LPIPS + params_without_decoder = wm_params - decoder_params - lpips_params + print(f'\n{" World Model (exc. Decoder & LPIPS)":<40} {params_without_decoder:>15,} parameters') + + # --- Summary Table --- + print(f'\n{"=" * 80}') + print(f'{"SUMMARY":<40}') + print(f'{"=" * 80}') + print(f'{"Component":<30} {"Total Params":>15} {"Trainable":>15} {"% of Total":>15}') + print(f'{"-" * 80}') + + components = [ + ("Encoder", encoder_params, encoder_trainable), + ("Transformer", transformer_params, transformer_trainable), + ] + + if hasattr(self.world_model, 'head_dict'): + components.append(("Prediction Heads", total_heads_params, total_heads_trainable)) + + for name, total, trainable in components: + pct = 100 * total / total_params if total_params > 0 else 0 + print(f'{name:<30} {total:>15,} {trainable:>15,} {pct:>14.1f}%') + + print(f'{"=" * 80}') + print(f'{"TOTAL":<30} {total_params:>15,} {total_trainable:>15,} {"100.0%":>15}') + print(f'{"=" * 80}\n') + + def initial_inference(self, obs_batch: torch.Tensor, action_batch: Optional[torch.Tensor] = None, + current_obs_batch: Optional[torch.Tensor] = None, task_id: Optional[Any] = None) -> MZNetworkOutput: + """ + Overview: + Performs the initial inference step of the model, corresponding to the representation function `h` in MuZero. + It takes an observation and produces a latent state and initial predictions. + Arguments: + - obs_batch (:obj:`torch.Tensor`): A batch of initial observations. + - action_batch (:obj:`Optional[torch.Tensor]`): A batch of actions (if available, context-dependent). + - current_obs_batch (:obj:`Optional[torch.Tensor]`): A batch of current observations (if different from obs_batch). + - task_id (:obj:`Optional[Any]`): Identifier for the current task in a multi-task setting. + Returns: + - MZNetworkOutput: An object containing the predicted value, policy logits, and the initial latent state. + The reward is set to a zero tensor, as it's not predicted at the initial step. + """ + batch_size = obs_batch.size(0) + obs_act_dict = {'obs': obs_batch, 'action': action_batch, 'current_obs': current_obs_batch} + + _, obs_token, logits_rewards, logits_policy, logits_value = self.world_model.forward_initial_inference( + obs_act_dict, task_id=task_id + ) + + # The world model returns tokens and logits; map them to the standard MZNetworkOutput format. 
+ latent_state = obs_token + policy_logits = logits_policy.squeeze(1) + value = logits_value.squeeze(1) + + return MZNetworkOutput( + value=value, + reward=torch.zeros(batch_size, device=value.device), # Reward is 0 at initial inference + policy_logits=policy_logits, + latent_state=latent_state, + ) + + def recurrent_inference(self, state_action_history: torch.Tensor, simulation_index: int = 0, + search_depth: List = [], task_id: Optional[Any] = None) -> MZNetworkOutput: + """ + Overview: + Performs a recurrent inference step, corresponding to the dynamics function `g` and prediction + function `f` in MuZero. It predicts the next latent state, reward, policy, and value based on a + history of latent states and actions. + Arguments: + - state_action_history (:obj:`torch.Tensor`): A tensor representing the history of latent states and actions. + - simulation_index (:obj:`int`): The index of the current simulation step within MCTS. + - search_depth (:obj:`List`): Information about the search depth, used for positional embeddings. + - task_id (:obj:`Optional[Any]`): Identifier for the current task in a multi-task setting. + Returns: + - MZNetworkOutput: An object containing the predicted value, reward, policy logits, and the next latent state. + """ + _, logits_observations, logits_rewards, logits_policy, logits_value = self.world_model.forward_recurrent_inference( + state_action_history, simulation_index, search_depth, task_id=task_id + ) + + # Map the world model outputs to the standard MZNetworkOutput format. + next_latent_state = logits_observations + reward = logits_rewards.squeeze(1) + policy_logits = logits_policy.squeeze(1) + value = logits_value.squeeze(1) + + return MZNetworkOutput( + value=value, + reward=reward, + policy_logits=policy_logits, + latent_state=next_latent_state, + ) \ No newline at end of file diff --git a/lzero/model/unizero_world_models/kv_cache_manager.py b/lzero/model/unizero_world_models/kv_cache_manager.py new file mode 100644 index 000000000..f2717466a --- /dev/null +++ b/lzero/model/unizero_world_models/kv_cache_manager.py @@ -0,0 +1,468 @@ +""" +KV Cache Manager for UniZero World Model +========================================= + +This module provides a unified, robust, and extensible KV cache management system +for the UniZero world model. It replaces the scattered cache logic with a clean, +well-tested abstraction. 
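+
+A minimal usage sketch (``world_model_cfg``, ``key`` and ``fresh_kv`` are placeholders)::
+
+    manager = KVCacheManager(config=world_model_cfg, env_num=8)
+    kv = manager.hierarchical_get(env_id=0, cache_key=key)  # per-env init pool first, then the recurrent pool
+    if kv is None:
+        # compute a fresh KeysValues for this context and register it for env 0
+        manager.set_init_cache(env_id=0, cache_key=key, kv_cache=fresh_kv)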
+ +""" + +import logging +from typing import Dict, List, Optional, Tuple, Any, Callable +from dataclasses import dataclass, field +from enum import Enum +import torch +from collections import OrderedDict + +# Assuming kv_caching is in the same directory or accessible +from .kv_caching import KeysValues + + +logger = logging.getLogger(__name__) + + +class EvictionStrategy(Enum): + """Cache eviction strategies.""" + FIFO = "fifo" # First In First Out (circular overwrite) + LRU = "lru" # Least Recently Used + PRIORITY = "priority" # Priority-based + + +@dataclass +class CacheStats: + """Statistics for cache performance monitoring.""" + hits: int = 0 + misses: int = 0 + evictions: int = 0 + total_queries: int = 0 + + @property + def hit_rate(self) -> float: + """Calculate hit rate.""" + if self.total_queries == 0: + return 0.0 + return self.hits / self.total_queries + + @property + def miss_rate(self) -> float: + """Calculate miss rate.""" + return 1.0 - self.hit_rate + + def reset(self): + """Reset all statistics.""" + self.hits = 0 + self.misses = 0 + self.evictions = 0 + self.total_queries = 0 + + def __repr__(self) -> str: + return (f"CacheStats(hits={self.hits}, misses={self.misses}, " + f"evictions={self.evictions}, hit_rate={self.hit_rate:.2%})") + + +class KVCachePool: + """ + A fixed-size pool for storing KeysValues objects. + + This class manages a pre-allocated pool of KeysValues objects and provides + efficient storage and retrieval mechanisms with configurable eviction strategies. + + Args: + pool_size: Maximum number of KV caches to store + eviction_strategy: Strategy for cache eviction + enable_stats: Whether to collect statistics + name: Name for this cache pool (for logging) + """ + + def __init__( + self, + pool_size: int, + eviction_strategy: EvictionStrategy = EvictionStrategy.FIFO, + enable_stats: bool = True, + name: str = "default" + ): + if pool_size <= 0: + raise ValueError(f"pool_size must be positive, got {pool_size}") + + self.pool_size = pool_size + self.eviction_strategy = eviction_strategy + self.enable_stats = enable_stats + self.name = name + + # Core data structures + self._pool: List[Optional[KeysValues]] = [None] * pool_size + self._key_to_index: Dict[int, int] = {} # cache_key -> pool_index + self._index_to_key: List[Optional[int]] = [None] * pool_size # pool_index -> cache_key + + # Eviction strategy specific data + self._next_index: int = 0 # For FIFO + self._access_order: OrderedDict = OrderedDict() # For LRU + self._priorities: Dict[int, float] = {} # For PRIORITY + + # Statistics + self.stats = CacheStats() if enable_stats else None + + logger.info(f"Initialized KVCachePool '{name}' with size={pool_size}, " + f"strategy={eviction_strategy.value}") + + def get(self, cache_key: int) -> Optional[KeysValues]: + """ + Retrieve a cached KeysValues object. 
+ + Args: + cache_key: The hash key for the cache + + Returns: + The cached KeysValues object if found, None otherwise + """ + if self.enable_stats: + self.stats.total_queries += 1 + + pool_index = self._key_to_index.get(cache_key) + + if pool_index is not None: + # Cache hit + if self.enable_stats: + self.stats.hits += 1 + + # Update access order for LRU + if self.eviction_strategy == EvictionStrategy.LRU: + self._access_order.move_to_end(cache_key) + + logger.debug(f"[{self.name}] Cache HIT for key={cache_key}, index={pool_index}") + return self._pool[pool_index] + else: + # Cache miss + if self.enable_stats: + self.stats.misses += 1 + + logger.debug(f"[{self.name}] Cache MISS for key={cache_key}") + return None + + def set(self, cache_key: int, kv_cache: KeysValues) -> int: + """ + Store a KeysValues object in the cache. + + Args: + cache_key: The hash key for the cache + kv_cache: The KeysValues object to store + + Returns: + The pool index where the cache was stored + """ + # ==================== BUG FIX: Defensive Deep Copy ==================== + # CRITICAL: Always clone the input to prevent cache corruption. + # This provides an additional layer of protection in case the caller + # forgets to clone. The clone operation ensures that the stored cache + # is independent from the caller's object, preventing unintended mutations. + kv_cache_copy = kv_cache.clone() + # ======================================================================= + + # Check if key already exists + if cache_key in self._key_to_index: + # Update existing entry + pool_index = self._key_to_index[cache_key] + self._pool[pool_index] = kv_cache_copy # Store cloned copy + + if self.eviction_strategy == EvictionStrategy.LRU: + self._access_order.move_to_end(cache_key) + + logger.debug(f"[{self.name}] Updated cache for key={cache_key} at index={pool_index}") + return pool_index + + # Find a slot for new entry + pool_index = self._find_slot_for_new_entry(cache_key) + + # Evict old entry if necessary + old_key = self._index_to_key[pool_index] + if old_key is not None: + self._evict(old_key, pool_index) + + # Store new entry (already cloned above) + self._pool[pool_index] = kv_cache_copy + self._key_to_index[cache_key] = pool_index + self._index_to_key[pool_index] = cache_key + + # Update access tracking for LRU + if self.eviction_strategy == EvictionStrategy.LRU: + self._access_order[cache_key] = True + + logger.debug(f"[{self.name}] Stored cache for key={cache_key} at index={pool_index}") + return pool_index + + def _find_slot_for_new_entry(self, cache_key: int) -> int: + """Find an appropriate slot for a new cache entry based on eviction strategy.""" + if self.eviction_strategy == EvictionStrategy.FIFO: + # Simple circular buffer + pool_index = self._next_index + self._next_index = (self._next_index + 1) % self.pool_size + return pool_index + + elif self.eviction_strategy == EvictionStrategy.LRU: + # Find LRU slot + if len(self._key_to_index) < self.pool_size: + # Pool not full, find first empty slot + for i in range(self.pool_size): + if self._index_to_key[i] is None: + return i + + # Evict LRU (first item in OrderedDict) + lru_key = next(iter(self._access_order)) + return self._key_to_index[lru_key] + + elif self.eviction_strategy == EvictionStrategy.PRIORITY: + # Find lowest priority slot + if len(self._key_to_index) < self.pool_size: + # Pool not full + for i in range(self.pool_size): + if self._index_to_key[i] is None: + return i + + # Evict lowest priority + min_priority_key = min(self._priorities, 
key=self._priorities.get) + return self._key_to_index[min_priority_key] + + else: + raise ValueError(f"Unknown eviction strategy: {self.eviction_strategy}") + + def _evict(self, cache_key: int, pool_index: int): + """Evict a cache entry.""" + if self.enable_stats: + self.stats.evictions += 1 + + # Remove from tracking structures + del self._key_to_index[cache_key] + self._index_to_key[pool_index] = None + + if self.eviction_strategy == EvictionStrategy.LRU: + self._access_order.pop(cache_key, None) + + if self.eviction_strategy == EvictionStrategy.PRIORITY: + self._priorities.pop(cache_key, None) + + logger.debug(f"[{self.name}] Evicted key={cache_key} from index={pool_index}") + + def clear(self): + """Clear all cache entries.""" + self._pool = [None] * self.pool_size + self._key_to_index.clear() + self._index_to_key = [None] * self.pool_size + self._next_index = 0 + self._access_order.clear() + self._priorities.clear() + + if self.enable_stats: + # Don't reset stats on clear, user can call stats.reset() explicitly + pass + + def __len__(self) -> int: + """Return the number of cached entries.""" + return len(self._key_to_index) + + def __repr__(self) -> str: + stats_str = f", {self.stats}" if self.enable_stats else "" + return (f"KVCachePool(name='{self.name}', size={len(self)}/{self.pool_size}, " + f"strategy={self.eviction_strategy.value}{stats_str})") + + +class KVCacheManager: + """ + Unified KV Cache Manager for World Model. + + This class manages multiple cache pools for different inference scenarios: + - Initial inference caches (per-environment) + - Recurrent inference caches (for MCTS) + - World model caches (temporary batch caches) + + Args: + config: World model configuration + env_num: Number of environments + enable_stats: Whether to enable statistics collection + clear_recur_log_freq: How often to log 'clear_recur_cache' calls. + clear_all_log_freq: How often to log 'clear_all' calls. 
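+
+    Example (``cfg`` is a placeholder config exposing ``game_segment_length``):
+
+        >>> manager = KVCacheManager(config=cfg, env_num=4)
+        >>> manager.get_stats_summary()['stats_enabled']
+        True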
+ """ + + def __init__( + self, + config, + env_num: int, + enable_stats: bool = True, + clear_recur_log_freq: int = 1000, + clear_all_log_freq: int = 100 + ): + self.config = config + self.env_num = env_num + self.enable_stats = enable_stats + + # Throttling parameters and counters for logging control + self.clear_recur_log_freq = clear_recur_log_freq + self.clear_all_log_freq = clear_all_log_freq + self._clear_recur_counter = 0 + self._clear_all_counter = 0 + + # Initialize cache pools + self._init_cache_pools() + + # These lists store KeysValues objects, not integers + # Used in world model's trim_and_pad_kv_cache for batch processing + self.keys_values_wm_list: List[KeysValues] = [] + self.keys_values_wm_size_list: List[int] = [] + + logger.info(f"Initialized KVCacheManager for {env_num} environments") + + def _init_cache_pools(self): + """Initialize all cache pools.""" + # Initial inference pools (one per environment) + init_pool_size = int(self.config.game_segment_length) + self.init_pools: List[KVCachePool] = [] + for env_id in range(self.env_num): + pool = KVCachePool( + pool_size=init_pool_size, + eviction_strategy=EvictionStrategy.FIFO, + enable_stats=self.enable_stats, + name=f"init_env{env_id}" + ) + self.init_pools.append(pool) + + # Recurrent inference pool (shared across all environments) + num_simulations = getattr(self.config, 'num_simulations', 50) + recur_pool_size = int(num_simulations * self.env_num) + self.recur_pool = KVCachePool( + pool_size=recur_pool_size, + eviction_strategy=EvictionStrategy.FIFO, + enable_stats=self.enable_stats, + name="recurrent" + ) + + # World model pool (temporary) + wm_pool_size = self.env_num + self.wm_pool = KVCachePool( + pool_size=wm_pool_size, + eviction_strategy=EvictionStrategy.FIFO, + enable_stats=self.enable_stats, + name="world_model" + ) + + def get_init_cache(self, env_id: int, cache_key: int) -> Optional[KeysValues]: + """Get cache from initial inference pool.""" + if env_id < 0 or env_id >= self.env_num: + raise ValueError(f"Invalid env_id: {env_id}, must be in [0, {self.env_num})") + return self.init_pools[env_id].get(cache_key) + + def set_init_cache(self, env_id: int, cache_key: int, kv_cache: KeysValues) -> int: + """Set cache in initial inference pool.""" + if env_id < 0 or env_id >= self.env_num: + raise ValueError(f"Invalid env_id: {env_id}, must be in [0, {self.env_num})") + return self.init_pools[env_id].set(cache_key, kv_cache) + + def get_recur_cache(self, cache_key: int) -> Optional[KeysValues]: + """Get cache from recurrent inference pool.""" + return self.recur_pool.get(cache_key) + + def set_recur_cache(self, cache_key: int, kv_cache: KeysValues) -> int: + """Set cache in recurrent inference pool.""" + return self.recur_pool.set(cache_key, kv_cache) + + def get_wm_cache(self, cache_key: int) -> Optional[KeysValues]: + """Get cache from world model pool.""" + return self.wm_pool.get(cache_key) + + def set_wm_cache(self, cache_key: int, kv_cache: KeysValues) -> int: + """Set cache in world model pool.""" + return self.wm_pool.set(cache_key, kv_cache) + + def hierarchical_get(self, env_id: int, cache_key: int) -> Optional[KeysValues]: + """ + Perform hierarchical cache lookup: init_pool -> recur_pool. + + This method encapsulates the two-level lookup strategy: + 1. First try to find in environment-specific init_infer cache + 2. 
If not found, fallback to global recurrent_infer cache + + Arguments: + - env_id (:obj:`int`): Environment ID for init cache lookup + - cache_key (:obj:`int`): Cache key to lookup + + Returns: + - kv_cache (:obj:`Optional[KeysValues]`): Found cache or None + """ + # Step 1: Try init_infer cache first (per-environment) + kv_cache = self.get_init_cache(env_id, cache_key) + if kv_cache is not None: + return kv_cache + + # Step 2: If not found, try recurrent_infer cache (global) + return self.get_recur_cache(cache_key) + + def clear_all(self): + """Clear all cache pools with throttled logging.""" + # Core clearing actions always execute + for pool in self.init_pools: + pool.clear() + self.recur_pool.clear() + self.wm_pool.clear() + self.keys_values_wm_list.clear() + self.keys_values_wm_size_list.clear() + + # Throttled logging logic + self._clear_all_counter += 1 + if self.clear_all_log_freq > 0 and self._clear_all_counter % self.clear_all_log_freq == 0: + logger.info( + f"Cleared all KV caches (this message appears every " + f"{self.clear_all_log_freq} calls, total calls: {self._clear_all_counter})" + ) + + def clear_init_caches(self): + """Clear only initial inference caches.""" + for pool in self.init_pools: + pool.clear() + logger.info("Cleared initial inference caches") + + def clear_recur_cache(self): + """Clear only recurrent inference cache with throttled logging.""" + # The core cache clearing action always executes + self.recur_pool.clear() + + # Throttled logging logic: only log if frequency is positive and counter is a multiple of the frequency + self._clear_recur_counter += 1 + if self.clear_recur_log_freq > 0 and self._clear_recur_counter % self.clear_recur_log_freq == 0: + logger.info( + f"Cleared recurrent inference cache (this message appears every " + f"{self.clear_recur_log_freq} calls, total calls: {self._clear_recur_counter})" + ) + + def get_stats_summary(self) -> Dict[str, Any]: + """Get statistics summary for all pools.""" + if not self.enable_stats: + return {"stats_enabled": False} + + summary = { + "stats_enabled": True, + "init_pools": {}, + "recur_pool": str(self.recur_pool.stats), + "wm_pool": str(self.wm_pool.stats), + } + + for env_id, pool in enumerate(self.init_pools): + summary["init_pools"][f"env_{env_id}"] = str(pool.stats) + + return summary + + def reset_stats(self): + """Reset statistics for all pools.""" + if not self.enable_stats: + return + + for pool in self.init_pools: + pool.stats.reset() + self.recur_pool.stats.reset() + self.wm_pool.stats.reset() + logger.info("Reset all cache statistics") + + def __repr__(self) -> str: + init_sizes = [len(pool) for pool in self.init_pools] + return (f"KVCacheManager(env_num={self.env_num}, " + f"init_caches={init_sizes}, " + f"recur_cache={len(self.recur_pool)}/{self.recur_pool.pool_size}, " + f"wm_cache={len(self.wm_pool)}/{self.wm_pool.pool_size})") \ No newline at end of file diff --git a/lzero/model/unizero_world_models/kv_caching.py b/lzero/model/unizero_world_models/kv_caching.py index 28b7b0ba2..b5ea8106d 100644 --- a/lzero/model/unizero_world_models/kv_caching.py +++ b/lzero/model/unizero_world_models/kv_caching.py @@ -1,110 +1,254 @@ -# Modified from https://github.com/eloialonso/iris/blob/main/src/models/kv_caching.py +# -*- coding: utf-8 -*- +""" +This script is a refactored version of the key-value caching mechanism from: +https://github.com/eloialonso/iris/blob/main/src/models/kv_caching.py -from typing import Tuple +The optimization focuses on improving clarity, documentation, and adherence to 
modern coding standards +while strictly preserving the original functionality and external API. +""" +from typing import Tuple, Optional import numpy as np import torch +class AssignWithoutInplaceCheck(torch.autograd.Function): + """ + Overview: + A custom autograd function to perform an in-place-like assignment on a tensor slice + without triggering PyTorch's version counter checks. This is useful for updating + buffers or caches within a computation graph. + + Reference: + Inspired by discussions on the PyTorch forums, such as: + https://discuss.pytorch.org/t/disable-in-place-correctness-version-check-any-other-workaround/90738/4 + + .. warning:: + This function is unsafe if the same slice of the input tensor is overwritten + multiple times, as it can lead to incorrect gradient calculations. + """ + + @staticmethod + def _get_slice(dim: int, start: int, stop: int) -> Tuple[slice, ...]: + """ + Overview: + Creates a slice tuple for indexing a tensor at a specific dimension. + Arguments: + - dim (:obj:`int`): The dimension to slice along. + - start (:obj:`int`): The starting index for the slice. + - stop (:obj:`int`): The ending index for the slice. + Returns: + - slice_tuple (:obj:`Tuple[slice, ...]`): A tuple of slice objects for indexing. + """ + return (slice(None),) * dim + (slice(start, stop),) + + @staticmethod + def forward( + ctx, + input_tensor: torch.Tensor, + value: torch.Tensor, + dim: int, + start: int, + stop: int + ) -> torch.Tensor: + """ + Overview: + The forward pass assigns the `value` tensor to a slice of the `input_tensor`. + Arguments: + - ctx: The context object for storing information for the backward pass. + - input_tensor (:obj:`torch.Tensor`): The tensor to be modified. + - value (:obj:`torch.Tensor`): The tensor to assign to the slice. + - dim (:obj:`int`): The dimension along which to perform the assignment. + - start (:obj:`int`): The starting index of the slice. + - stop (:obj:`int`): The ending index of the slice. + Returns: + - modified_tensor (:obj:`torch.Tensor`): The `input_tensor` after modification. + """ + ctx.dim = dim + ctx.start = start + ctx.stop = stop + # Directly modify the data of the input tensor to bypass version checks. + input_tensor.data[AssignWithoutInplaceCheck._get_slice(dim, start, stop)] = value + return input_tensor + + @staticmethod + def backward(ctx, grad_output: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...]: + """ + Overview: + The backward pass computes gradients for the inputs of the forward pass. + Arguments: + - ctx: The context object with saved information from the forward pass. + - grad_output (:obj:`torch.Tensor`): The gradient of the output tensor. + Returns: + - grad_input_tensor (:obj:`torch.Tensor`): The gradient with respect to `input_tensor`. + - grad_value (:obj:`torch.Tensor`): The gradient with respect to `value`. + - None, None, None: Gradients for `dim`, `start`, and `stop`, which are not needed. + """ + # The gradient for the original input tensor is the same as the output gradient. + grad_input_tensor = grad_output + # The gradient for the value tensor is the slice of the output gradient. + grad_value = grad_output[AssignWithoutInplaceCheck._get_slice(ctx.dim, ctx.start, ctx.stop)] + return grad_input_tensor, grad_value, None, None, None + + class Cache: + """ + Overview: + A cache for storing a single type of intermediate tensor (e.g., keys or values) + in a Transformer-like model. It handles dynamic updates and size management. 
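+
+    .. note::
+        The backing buffer has shape ``(num_samples, num_heads, max_tokens, embed_dim // num_heads)``.
+        When an update would overflow ``max_tokens``, the oldest tokens are shifted out in pairs so that
+        (observation, action) token pairs are never split.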
+ """ + def __init__(self, num_samples: int, num_heads: int, max_tokens: int, embed_dim: int, device: torch.device) -> None: """ Overview: - Cache for storing intermediate results in a transformer model. + Initializes the cache. Arguments: - - num_samples (:obj:`int`): The number of samples to cache. + - num_samples (:obj:`int`): The number of samples (batch size) to cache. - num_heads (:obj:`int`): The number of attention heads. - - max_tokens (:obj:`int`): The maximum number of tokens. - - embed_dim (:obj:`int`): The dimension of the embeddings. - - device (:obj:`torch.device`): The device on which to store the cache. + - max_tokens (:obj:`int`): The maximum number of tokens the cache can hold. + - embed_dim (:obj:`int`): The total dimension of the embeddings. + - device (:obj:`torch.device`): The device on which to store the cache tensor. """ - assert embed_dim % num_heads == 0 - self._num_samples, self._cache, self._size = num_samples, None, None - self._reset = lambda n: torch.empty(n, num_heads, max_tokens, embed_dim // num_heads, device=device) # (B, nh, T, hs) + if embed_dim % num_heads != 0: + raise ValueError(f"Embedding dimension ({embed_dim}) must be divisible by the number of heads ({num_heads}).") + + self._num_samples = num_samples + self._num_heads = num_heads + self._max_tokens = max_tokens + self._head_dim = embed_dim // num_heads + self._device = device + + self._cache: torch.Tensor = self._create_cache_tensor(self._num_samples) + self._size: int = 0 self.reset() + def _create_cache_tensor(self, num_samples: int) -> torch.Tensor: + """ + Overview: + Creates an empty tensor with the correct shape and device for the cache. + Arguments: + - num_samples (:obj:`int`): The number of samples for which to create the cache. + Returns: + - empty_cache (:obj:`torch.Tensor`): An uninitialized tensor for the cache. + """ + return torch.empty( + num_samples, self._num_heads, self._max_tokens, self._head_dim, device=self._device + ) # Shape: (B, nh, T, hs) + @property def shape(self) -> Tuple[int, int, int, int]: """ Overview: - Get the shape of the cache. + Gets the effective shape of the cache's content. Returns: - - shape (:obj:`Tuple[int, int, int, int]`): The shape of the cache. + - shape (:obj:`Tuple[int, int, int, int]`): A tuple representing (num_samples, num_heads, current_size, head_dim). """ - n, num_heads, _, head_dim = self._cache.shape - return n, num_heads, self._size, head_dim + return self._num_samples, self._num_heads, self._size, self._head_dim def reset(self) -> None: """ Overview: - Reset the cache to its initial state. + Resets the cache to an empty state. """ - self._cache = self._reset(self._num_samples) + self._cache = self._create_cache_tensor(self._num_samples) self._size = 0 def prune(self, mask: np.ndarray) -> None: """ Overview: - Prune the cache based on a mask. + Prunes the cache along the sample dimension using a boolean mask. Arguments: - - mask (:obj:`np.ndarray`): A boolean mask indicating which samples to keep. + - mask (:obj:`np.ndarray`): A 1D boolean array where `True` indicates which samples to keep. """ - assert mask.ndim == 1 and mask.shape[0] == self.shape[0] + if not (mask.ndim == 1 and mask.shape[0] == self._num_samples): + raise ValueError("Mask must be a 1D numpy array with length equal to the number of samples.") self._cache = self._cache[mask] self._num_samples = self._cache.shape[0] def get(self) -> torch.Tensor: """ Overview: - Get the current contents of the cache. + Retrieves the current contents of the cache. 
Returns: - - cache (:obj:`torch.Tensor`): The current contents of the cache. + - cache_content (:obj:`torch.Tensor`): A tensor containing the valid data in the cache. """ return self._cache[:, :, :self._size, :] def update(self, x: torch.Tensor, tokens: int) -> None: """ Overview: - Update the cache with new values. + Updates the cache with new tensor values. If the cache is full, it discards the oldest + tokens to make space. Arguments: - - x (:obj:`torch.Tensor`): The new values to update the cache with. - - tokens (:obj:`int`): The number of tokens to update. + - x (:obj:`torch.Tensor`): The new tensor data to add to the cache. + - tokens (:obj:`int`): The number of tokens being added (sequence length of `x`). """ - # assert (x.ndim == self._cache.ndim) and all([x.size(i) == self._cache.size(i) for i in (0, 1, 3)]) - # assert self._size + tokens <= self._cache.shape[2] # TODO - self._cache = AssignWithoutInplaceCheck.apply(self._cache, x, 2, self._size, self._size + tokens) + required_capacity = self._size + tokens + + # If the new tokens exceed the cache's maximum capacity, shift existing data to make room. + if required_capacity > self._max_tokens: + shift_amount = required_capacity - self._max_tokens + + # This logic is crucial for models like MuZero where tokens are added in (state, action) pairs. + # To maintain the integrity of these pairs, an even number of tokens must be discarded. + if shift_amount % 2 != 0: + shift_amount += 1 + + if shift_amount >= self._size: + # If the required shift is larger than the current cache size, it's more efficient to reset. + self._cache.zero_() + self._size = 0 + else: + # Shift the existing cache content to the left, discarding the oldest tokens. + self._cache[:, :, :self._size - shift_amount, :] = self._cache[:, :, shift_amount:self._size, :] + self._size -= shift_amount + # NOTE: Shifting the cache invalidates absolute positional embeddings. + # The parent model must handle positional encoding adjustments. For example, if positional + # embeddings are calculated based on `prev_steps`, this shift means `prev_steps` may no + # longer correspond to the true start, potentially causing discontinuities. + + # Use the custom autograd function to assign the new data without inplace errors. + self._cache = AssignWithoutInplaceCheck.apply( + self._cache, x, 2, self._size, self._size + tokens + ) self._size += tokens class KVCache: - def __init__(self, n: int, num_heads: int, max_tokens: int, embed_dim: int, device: torch.device) -> None: + """ + Overview: + A container for a pair of caches: one for keys (K) and one for values (V), + typically used in a single attention layer of a Transformer. + """ + + def __init__(self, num_samples: int, num_heads: int, max_tokens: int, embed_dim: int, device: torch.device) -> None: """ Overview: - Cache for storing key and value tensors in a transformer model. + Initializes the Key-Value cache pair. Arguments: - - n (:obj:`int`): The number of samples to cache. + - num_samples (:obj:`int`): The number of samples (batch size) to cache. - num_heads (:obj:`int`): The number of attention heads. - - max_tokens (:obj:`int`): The maximum number of tokens. - - embed_dim (:obj:`int`): The dimension of the embeddings. - - device (:obj:`torch.device`): The device on which to store the cache. + - max_tokens (:obj:`int`): The maximum number of tokens the cache can hold. + - embed_dim (:obj:`int`): The total dimension of the embeddings. + - device (:obj:`torch.device`): The device on which to store the cache tensors. 
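+
+        .. note::
+            The key and value caches are kept in lockstep: ``update`` infers the number of new tokens
+            from the sequence dimension of ``k`` and applies it to both caches.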
""" - self._k_cache = Cache(n, num_heads, max_tokens, embed_dim, device) - self._v_cache = Cache(n, num_heads, max_tokens, embed_dim, device) + self._k_cache = Cache(num_samples, num_heads, max_tokens, embed_dim, device) + self._v_cache = Cache(num_samples, num_heads, max_tokens, embed_dim, device) @property def shape(self) -> Tuple[int, int, int, int]: """ Overview: - Get the shape of the key cache. + Gets the effective shape of the key cache's content. Returns: - - shape (:obj:`Tuple[int, int, int, int]`): The shape of the key cache. + - shape (:obj:`Tuple[int, int, int, int]`): Shape of the key cache (num_samples, num_heads, current_size, head_dim). """ return self._k_cache.shape def reset(self) -> None: """ Overview: - Reset both key and value caches to their initial states. + Resets both the key and value caches to their empty states. """ self._k_cache.reset() self._v_cache.reset() @@ -112,9 +256,9 @@ def reset(self) -> None: def prune(self, mask: np.ndarray) -> None: """ Overview: - Prune both key and value caches based on a mask. + Prunes both key and value caches based on a boolean mask. Arguments: - - mask (:obj:`np.ndarray`): A boolean mask indicating which samples to keep. + - mask (:obj:`np.ndarray`): A 1D boolean array indicating which samples to keep. """ self._k_cache.prune(mask) self._v_cache.prune(mask) @@ -122,74 +266,94 @@ def prune(self, mask: np.ndarray) -> None: def get(self) -> Tuple[torch.Tensor, torch.Tensor]: """ Overview: - Get the current contents of the key and value caches. + Retrieves the current contents of the key and value caches. Returns: - key_cache (:obj:`torch.Tensor`): The current contents of the key cache. - value_cache (:obj:`torch.Tensor`): The current contents of the value cache. """ return self._k_cache.get(), self._v_cache.get() - def update(self, k: torch.Tensor, v: torch.Tensor): + def update(self, k: torch.Tensor, v: torch.Tensor) -> None: """ Overview: - Update both key and value caches with new values. + Updates both key and value caches with new tensors. Arguments: - - k (:obj:`torch.Tensor`): The new values to update the key cache with. - - v (:obj:`torch.Tensor`): The new values to update the value cache with. + - k (:obj:`torch.Tensor`): The new key tensor to add. + - v (:obj:`torch.Tensor`): The new value tensor to add. """ - self._k_cache.update(k, k.size(2)) - self._v_cache.update(v, v.size(2)) + # The number of tokens is inferred from the sequence dimension (dim 2). + num_tokens = k.size(2) + self._k_cache.update(k, num_tokens) + self._v_cache.update(v, num_tokens) class KeysValues: - def __init__(self, n: int, num_heads: int, max_tokens: int, embed_dim: int, num_layers: int, device: torch.device) -> None: + """ + Overview: + Manages a collection of KVCache objects, one for each layer in a Transformer model. + """ + + def __init__( + self, + num_samples: int, + num_heads: int, + max_tokens: int, + embed_dim: int, + num_layers: int, + device: torch.device + ) -> None: """ Overview: - Class for managing multiple layers of key and value caches in a transformer model. + Initializes KV caches for all layers. Arguments: - - n (:obj:`int`): The number of samples to cache. + - num_samples (:obj:`int`): The number of samples (batch size). - num_heads (:obj:`int`): The number of attention heads. - - max_tokens (:obj:`int`): The maximum number of tokens. + - max_tokens (:obj:`int`): The maximum number of tokens per cache. - embed_dim (:obj:`int`): The dimension of the embeddings. 
- - num_layers (:obj:`int`): The number of layers in the transformer model. - - device (:obj:`torch.device`): The device on which to store the caches. + - num_layers (:obj:`int`): The number of layers in the Transformer model. + - device (:obj:`torch.device`): The device for storing cache tensors. """ - self._keys_values = tuple([KVCache(n, num_heads, max_tokens, embed_dim, device) for _ in range(num_layers)]) + self._keys_values = tuple([ + KVCache(num_samples, num_heads, max_tokens, embed_dim, device) for _ in range(num_layers) + ]) - def __getitem__(self, index: int) -> KVCache: + def __getitem__(self, layer_index: int) -> KVCache: """ Overview: - Get the key and value cache for a specific layer. + Retrieves the KVCache for a specific layer. Arguments: - - index (:obj:`int`): The layer index. + - layer_index (:obj:`int`): The index of the layer. Returns: - - kv_cache (:obj:`KVCache`): The key and value cache for the specified layer. + - kv_cache (:obj:`KVCache`): The key-value cache for the specified layer. """ - return self._keys_values[index] + return self._keys_values[layer_index] - def __len__(self): + def __len__(self) -> int: """ Overview: - Get the number of layers in the transformer model. + Gets the number of layers. Returns: - - length (:obj:`int`): The number of layers. + - num_layers (:obj:`int`): The number of layers being managed. """ return len(self._keys_values) @property - def size(self): + def size(self) -> int: """ Overview: - Get the size of the tokens in the cache. + Gets the current number of tokens stored in the caches. Returns: - - size (:obj:`int`): The size of the tokens in the cache. + - size (:obj:`int`): The number of tokens in the cache (assumes all layers have the same size). """ + # All layer caches are synchronized, so we can check the size of the first one. + if not self._keys_values: + return 0 return self._keys_values[0].shape[2] def reset(self) -> None: """ Overview: - Reset all key and value caches to their initial states. + Resets the KV caches for all layers. """ for kv_cache in self._keys_values: kv_cache.reset() @@ -197,70 +361,72 @@ def reset(self) -> None: def prune(self, mask: np.ndarray) -> None: """ Overview: - Prune all key and value caches based on a mask. + Prunes the KV caches for all layers based on a mask. Arguments: - mask (:obj:`np.ndarray`): A boolean mask indicating which samples to keep. """ for kv_cache in self._keys_values: kv_cache.prune(mask) - -class AssignWithoutInplaceCheck(torch.autograd.Function): - """ - Overview: - Custom autograd function to perform in-place assignment without triggering version checks. - Inspired from: - https://discuss.pytorch.org/t/disable-in-place-correctness-version-check-any-other-workaround/90738/4 - - .. warning: - Do not use it to overwrite a slice twice. - """ - - @staticmethod - def get_slice(dim: int, start: int, stop: int) -> Tuple[slice]: + def remove_register_tokens(self, register_token_num: int) -> None: """ Overview: - Get the slice object for the given dimension and range. + Removes the last `register_token_num` tokens from the active view of the cache + in each layer by adjusting the internal size pointer. This does not delete the data + but makes it invisible to subsequent `get` and `update` calls. + This is typically called after an inference step that used temporary tokens + (e.g., register tokens) to ensure they are not part of the ongoing context. Arguments: - - dim (:obj:`int`): The dimension along which to slice. - - start (:obj:`int`): The start index of the slice. 
- - stop (:obj:`int`): The stop index of the slice. - Returns: - - slice (:obj:`Tuple[slice]`): The slice object. + - register_token_num (:obj:`int`): The number of tokens to remove from the end of the cache view. """ - return tuple([slice(None), ] * dim + [slice(start, stop)]) + if register_token_num <= 0: + return + for kv_cache in self._keys_values: + # Decrement the size pointer for both K and V caches. + kv_cache._k_cache._size = max(0, kv_cache._k_cache._size - register_token_num) + kv_cache._v_cache._size = max(0, kv_cache._v_cache._size - register_token_num) - @staticmethod - def forward(ctx, input: torch.Tensor, value: torch.Tensor, dim: int, start: int, stop: int) -> torch.Tensor: + def clone(self) -> "KeysValues": """ Overview: - Forward pass of the custom autograd function. - Arguments: - - ctx: The context object to store information for backward computation. - - input (:obj:`torch.Tensor`): The input tensor to be modified. - - value (:obj:`torch.Tensor`): The value tensor to assign to the input. - - dim (:obj:`int`): The dimension along which to assign the value. - - start (:obj:`int`): The start index of the assignment. - - stop (:obj:`int`): The stop index of the assignment. - Returns: - - output (:obj:`torch.Tensor`): The modified input tensor. - """ - ctx.dim = dim - ctx.start = start - ctx.stop = stop - input.data[AssignWithoutInplaceCheck.get_slice(dim, start, stop)] = value - return input + Creates a deep copy of this KeysValues object. + + This method is critical for preventing cache corruption. When a cached KeysValues object + is retrieved and used in transformer forward passes, the transformer modifies it in-place. + Without cloning, this would pollute the original cache, causing incorrect predictions. - @staticmethod - def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor]: - """ - Overview: - Backward pass of the custom autograd function. - Arguments: - - ctx: The context object storing information from forward computation. - - grad_out (:obj:`torch.Tensor`): The gradient of the output tensor. Returns: - - grad_input (:obj:`torch.Tensor`): The gradient of the input tensor. - - grad_value (:obj:`torch.Tensor`): The gradient of the value tensor. + - cloned_kv (:obj:`KeysValues`): A new KeysValues object with copied data. 
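+        Examples:
+            A minimal illustrative sketch (CPU tensors, two layers); the snapshot can be passed to a
+            forward pass without polluting the original cache:
+
+            >>> kv = KeysValues(2, 4, 16, 32, num_layers=2, device=torch.device('cpu'))
+            >>> kv_snapshot = kv.clone()
+            >>> kv_snapshot.size == kv.size
+            True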
""" - return grad_out, grad_out[AssignWithoutInplaceCheck.get_slice(ctx.dim, ctx.start, ctx.stop)], None, None, None \ No newline at end of file + if not self._keys_values: + # Handle empty case + raise ValueError("Cannot clone an empty KeysValues object") + + # Get parameters from the first layer's cache + first_kv_cache = self._keys_values[0] + num_samples, num_heads, _, head_dim = first_kv_cache.shape + max_tokens = first_kv_cache._k_cache._max_tokens + embed_dim = num_heads * head_dim + num_layers = len(self._keys_values) + device = first_kv_cache._k_cache._device + + # Create a new KeysValues object with the same structure + cloned_kv = KeysValues( + num_samples=num_samples, + num_heads=num_heads, + max_tokens=max_tokens, + embed_dim=embed_dim, + num_layers=num_layers, + device=device + ) + + # Deep copy each layer's cache data + for src_layer, dst_layer in zip(self._keys_values, cloned_kv._keys_values): + # Copy the key and value cache tensors using torch.copy_() for efficient data transfer + dst_layer._k_cache._cache.copy_(src_layer._k_cache._cache) + dst_layer._v_cache._cache.copy_(src_layer._v_cache._cache) + # Copy the size information + dst_layer._k_cache._size = src_layer._k_cache._size + dst_layer._v_cache._size = src_layer._v_cache._size + + return cloned_kv \ No newline at end of file diff --git a/lzero/model/unizero_world_models/lpips.py b/lzero/model/unizero_world_models/lpips.py index c6ee6426c..16237df20 100644 --- a/lzero/model/unizero_world_models/lpips.py +++ b/lzero/model/unizero_world_models/lpips.py @@ -4,14 +4,14 @@ import hashlib import os -from collections import namedtuple -from pathlib import Path - import requests import torch import torch.nn as nn from torchvision import models from tqdm import tqdm +from collections import namedtuple +from pathlib import Path +from ditk import logging class LPIPS(nn.Module): @@ -20,21 +20,61 @@ def __init__(self, use_dropout: bool = True): super().__init__() self.scaling_layer = ScalingLayer() self.chns = [64, 128, 256, 512, 512] # vg16 features + # Comment out the following line if you don't need perceptual loss - # self.net = vgg16(pretrained=True, requires_grad=False) + # This line will now automatically use the path specified by TORCH_HOME + self.net = vgg16(pretrained=True, requires_grad=False) self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) - # Comment out the following line if you don't need perceptual loss - # self.load_from_pretrained() - # for param in self.parameters(): - # param.requires_grad = False + self.load_from_pretrained() + for param in self.parameters(): + param.requires_grad = False def load_from_pretrained(self) -> None: - ckpt = get_ckpt_path(name="vgg_lpips", root=Path.home() / ".cache/iris/tokenizer_pretrained_vgg") # Download VGG if necessary - self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) + """ + Load LPIPS linear layer weights (vgg.pth) from TORCH_HOME directory. + + Raises: + EnvironmentError: If TORCH_HOME is not set or invalid. + FileNotFoundError: If checkpoint file cannot be loaded. + RuntimeError: If state dict loading fails. 
+ """ + # Get TORCH_HOME from environment variable + torch_home = os.environ.get('TORCH_HOME') + + if torch_home is None: + error_msg = ( + "TORCH_HOME environment variable is not set. " + "Please set TORCH_HOME to specify the directory for pretrained models. " + "Example: export TORCH_HOME=/path/to/torch/home" + ) + logging.error(error_msg) + raise EnvironmentError(error_msg) + + try: + logging.info(f"Loading LPIPS pretrained weights from TORCH_HOME: {torch_home}") + ckpt = get_ckpt_path(name="vgg_lpips", root=torch_home) + + if not os.path.exists(ckpt): + error_msg = f"Checkpoint file not found: {ckpt}" + logging.error(error_msg) + raise FileNotFoundError(error_msg) + + logging.info(f"Loading checkpoint from: {ckpt}") + state_dict = torch.load(ckpt, map_location=torch.device("cpu")) + self.load_state_dict(state_dict, strict=False) + logging.info(f"Successfully loaded LPIPS pretrained weights from: {ckpt}") + + except FileNotFoundError as e: + logging.error(f"Failed to load LPIPS checkpoint: {e}") + raise + except Exception as e: + error_msg = f"Failed to load LPIPS pretrained weights: {e}" + logging.error(error_msg) + raise RuntimeError(error_msg) from e def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) @@ -74,7 +114,10 @@ def __init__(self, chn_in: int, chn_out: int = 1, use_dropout: bool = False) -> class vgg16(torch.nn.Module): def __init__(self, requires_grad: bool = False, pretrained: bool = True) -> None: super(vgg16, self).__init__() + # With TORCH_HOME set, pretrained=True will search or download the model in the specified directory + logging.info("Loading vgg16 backbone...") vgg_pretrained_features = models.vgg16(pretrained=pretrained).features + logging.info("vgg16 backbone loaded.") self.slice1 = torch.nn.Sequential() self.slice2 = torch.nn.Sequential() self.slice3 = torch.nn.Sequential() @@ -160,6 +203,7 @@ def md5_hash(path: str) -> str: def get_ckpt_path(name: str, root: str, check: bool = False) -> str: assert name in URL_MAP + # This function is used for loading vgg.pth, and the path is correct path = os.path.join(root, CKPT_MAP[name]) if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) diff --git a/lzero/model/unizero_world_models/moe.py b/lzero/model/unizero_world_models/moe.py new file mode 100644 index 000000000..17ce0605b --- /dev/null +++ b/lzero/model/unizero_world_models/moe.py @@ -0,0 +1,272 @@ +import dataclasses +from typing import List, Any + +import torch +import torch.nn.functional as F +from simple_parsing.helpers import Serializable +from torch import nn + +from lzero.model.unizero_world_models.transformer import _maybe_wrap_linear + +# Note: The following lines are examples of how _maybe_wrap_linear might be used. 
+# _maybe_wrap_linear(nn.Linear(config.embed_dim, 4 * config.embed_dim), config, "feed_forward") + +# This implementation is inspired by the following sources: +# https://github.com/mistralai/mistral-inference/blob/main/src/mistral_inference/moe.py +# https://github.com/mistralai/mistral-inference/blob/main/src/mistral_inference/transformer_layers.py#L149 +# Modified from https://github.com/mistralai/mistral-inference/blob/main/src/mistral_inference/transformer.py#L108 + + +class MultiplicationFeedForward(nn.Module): + """ + Overview: + Implements the SwiGLU (Swish-Gated Linear Unit) feed-forward layer, a variant of a transformer feed-forward network + that uses element-wise multiplication of two linear projections, one of which is passed through a SiLU activation. + This is often expressed as: FFN_SwiGLU(x) = (SiLU(x @ W1) * (x @ W3)) @ W2. + """ + + def __init__(self, config: Any) -> None: + """ + Overview: + Initializes the MultiplicationFeedForward layer. + Arguments: + - config (:obj:`Any`): A configuration object containing model hyperparameters. + It is expected to have `embed_dim` (int) and `moe_use_lora` (bool). + """ + super().__init__() + hidden_dim = 4 * config.embed_dim + if config.moe_use_lora: + self.w1 = _maybe_wrap_linear(nn.Linear(config.embed_dim, hidden_dim, bias=False), config, "feed_forward") + self.w2 = _maybe_wrap_linear(nn.Linear(hidden_dim, config.embed_dim, bias=False), config, "feed_forward") + self.w3 = _maybe_wrap_linear(nn.Linear(config.embed_dim, hidden_dim, bias=False), config, "feed_forward") + else: + self.w1 = nn.Linear(config.embed_dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, config.embed_dim, bias=False) + self.w3 = nn.Linear(config.embed_dim, hidden_dim, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Overview: + Performs the forward pass of the SwiGLU layer. + Arguments: + - x (:obj:`torch.Tensor`): The input tensor. + Returns: + - torch.Tensor: The output tensor after applying the SwiGLU transformation. + """ + return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x)) + + +@dataclasses.dataclass +class MoeArgs(Serializable): + """ + Overview: + Dataclass for storing Mixture-of-Experts (MoE) configuration arguments. + """ + num_experts: int # The total number of experts in the MoE layer. + num_experts_per_tok: int # The number of experts to route each token to (k). + + +class MoELayer(nn.Module): + """ + Overview: + A straightforward implementation of a Mixture-of-Experts (MoE) layer. + This version iterates through each expert and processes the tokens routed to it. + While clear and easy to understand, it can be less efficient than vectorized approaches. + + The process is as follows: + 1. The input tensor `x` is flattened from [B, T, D] to [N, D], where N = B * T. + 2. A gating network calculates logits for each token to determine expert assignment. + 3. For each token, the top-k experts are selected based on the logits. + 4. The layer iterates through each expert, gathers all tokens assigned to it, + and computes their outputs. + 5. The outputs are weighted by the gating scores and summed up. + 6. An optional shared expert can be applied to all tokens. + 7. The final tensor is reshaped to its original shape [B, T, D]. + + Attributes: + - dim (:obj:`int`): The dimension of the input features. + - num_experts (:obj:`int`): The total number of experts. + - num_experts_per_tok (:obj:`int`): The number of experts activated per token (top-k). 
+        - gate (:obj:`nn.Module`): The gating network that produces routing logits.
+        - experts (:obj:`nn.ModuleList`): A list of expert networks.
+        - shared_expert (:obj:`nn.Module` or `None`): An optional shared expert applied to all tokens.
+    """
+
+    def __init__(self, config: Any, experts: List[nn.Module], gate: nn.Module, num_experts_per_tok: int = 1) -> None:
+        """
+        Overview:
+            Initializes the MoELayer.
+        Arguments:
+            - config (:obj:`Any`): A configuration object. Expected to have `embed_dim` and optionally `n_shared_experts`.
+            - experts (:obj:`List[nn.Module]`): A list of PyTorch modules representing the experts.
+            - gate (:obj:`nn.Module`): The gating module for routing tokens.
+            - num_experts_per_tok (:obj:`int`): The number of experts to use for each token.
+        """
+        super().__init__()
+        self.dim = config.embed_dim
+        self.num_experts = len(experts)
+        self.num_experts_per_tok = num_experts_per_tok
+        self.gate = gate
+        self.experts = nn.ModuleList(experts)
+
+        # If specified in the config, create a shared expert branch.
+        if hasattr(config, "n_shared_experts") and config.n_shared_experts > 0:
+            # TODO: The architecture of the shared expert could be made more configurable.
+            self.shared_expert = nn.Sequential(
+                nn.Linear(self.dim, config.n_shared_experts * (4 * self.dim)),
+                nn.GELU(),
+                nn.Linear(config.n_shared_experts * (4 * self.dim), self.dim)
+            )
+        else:
+            self.shared_expert = None
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Overview:
+            Performs the forward pass for the MoE layer.
+        Arguments:
+            - x (:obj:`torch.Tensor`): The input tensor of shape [batch_size, seq_len, dim].
+        Returns:
+            - torch.Tensor: The output tensor with the same shape as the input.
+        """
+        # Store original shape and flatten input to 2D: [batch_size * seq_len, dim]
+        original_shape = x.size()
+        x = x.view(-1, self.dim)
+
+        # Compute gate logits, shape: [num_tokens, num_experts]
+        gate_logits = self.gate(x)
+        # Select top-k experts for each token.
+        weights, indices = torch.topk(gate_logits, self.num_experts_per_tok, dim=1)
+        # Normalize the weights of selected experts using softmax.
+        weights = F.softmax(weights, dim=1).to(x.dtype)
+
+        # Initialize the output tensor for expert computations.
+        expert_output = torch.zeros_like(x)
+
+        # Iterate over each expert to compute outputs for the tokens routed to it.
+        for expert_id in range(self.num_experts):
+            # Find the tokens that have this expert in their top-k list.
+            batch_idx, expert_tok_idx = torch.where(indices == expert_id)
+            if batch_idx.numel() == 0:
+                continue
+
+            # Select the subset of tokens for the current expert.
+            token_subset = x[batch_idx]  # Shape: [num_tokens_for_expert, dim]
+            # Compute the output from the current expert.
+            output_expert = self.experts[expert_id](token_subset)
+            # Get the corresponding weights for these tokens.
+            token_weights = weights[batch_idx, expert_tok_idx].unsqueeze(-1)
+            # Apply weights and accumulate the output.
+            expert_output[batch_idx] += output_expert * token_weights
+
+        # If a shared expert exists, add its output.
+        if self.shared_expert is not None:
+            shared_output = self.shared_expert(x)
+            output = expert_output + shared_output
+        else:
+            output = expert_output
+
+        # Restore the original tensor shape and return.
+        return output.view(original_shape)
+
+
+class MoELayerOptimized(nn.Module):
+    """
+    Overview:
+        An optimized implementation of the Mixture-of-Experts (MoE) layer that maintains the same API as `MoELayer`.
+ This version avoids loops over experts by using a vectorized scatter-gather approach, which is significantly + more efficient on modern hardware. The forward pass complexity is O(N_tokens + ΣE_i), where ΣE_i is the + total number of tokens processed across all experts. + + The process is as follows: + 1. **Routing**: Get top-k experts and their weights for each token. + 2. **Flattening**: Create a flat list of (token_index, expert_index, weight) tuples. + 3. **Sorting**: Sort these tuples by expert_index. This groups all tokens destined for the same expert together. + 4. **Batch Forward**: Process the tokens for each expert in a single, contiguous batch, avoiding Python loops. + 5. **Weighted Scatter**: Apply gating weights to the expert outputs and scatter-add them back to a buffer + indexed by the original token positions. + 6. **Shared Expert**: If configured, add the output from the shared expert. + 7. **Reshape**: Reshape the final output tensor to its original 3D shape. + """ + + def __init__(self, config: Any, experts: List[nn.Module], gate: nn.Module, num_experts_per_tok: int = 1) -> None: + """ + Overview: + Initializes the MoELayerOptimized. + Arguments: + - config (:obj:`Any`): A configuration object. Expected to have `embed_dim` and optionally `n_shared_experts`. + - experts (:obj:`List[nn.Module]`): A list of PyTorch modules representing the experts. + - gate (:obj:`nn.Module`): The gating module for routing tokens. + - num_experts_per_tok (:obj:`int`): The number of experts to use for each token. + """ + super().__init__() + self.dim = config.embed_dim + self.num_experts = len(experts) + self.num_experts_per_tok = num_experts_per_tok + self.gate = gate + self.experts = nn.ModuleList(experts) + + self.use_shared = getattr(config, "n_shared_experts", 0) > 0 + if self.use_shared: + # TODO: The architecture of the shared expert could be made more configurable. + self.shared_expert = nn.Sequential( + nn.Linear(self.dim, config.n_shared_experts * (4 * self.dim)), + nn.GELU(), + nn.Linear(config.n_shared_experts * (4 * self.dim), self.dim), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Overview: + Performs the optimized forward pass for the MoE layer. + Arguments: + - x (:obj:`torch.Tensor`): The input tensor of shape [B, T, D]. + Returns: + - torch.Tensor: The output tensor with the same shape as the input. + """ + B, T, D = x.shape + x_flat = x.reshape(-1, D) # [N, D]; N = B*T + + # 1. Routing: Get top-k experts and weights. + gate_logits = self.gate(x_flat) # [N, E] + weights, topk_idx = torch.topk(gate_logits, self.num_experts_per_tok, dim=1) # [N, k] + weights = F.softmax(weights, dim=1).to(x.dtype) # [N, k] + + # 2. Flatten token-expert pairs. + N, k = weights.shape + flat_token_idx = torch.arange(N, device=x.device).repeat_interleave(k) # [N*k] + flat_expert_idx = topk_idx.reshape(-1) # [N*k] + flat_weight = weights.reshape(-1, 1) # [N*k, 1] + flat_input = x_flat[flat_token_idx] # [N*k, D] + + # 3. Sort by expert index to group tokens for batch processing. + sort_order = torch.argsort(flat_expert_idx) # [N*k] + flat_expert_idx = flat_expert_idx[sort_order] + flat_token_idx = flat_token_idx[sort_order] + flat_weight = flat_weight[sort_order] + flat_input = flat_input[sort_order] + + # Count how many tokens each expert will process. + counts = torch.bincount(flat_expert_idx, minlength=self.num_experts) # [E] + + # Prepare output buffer. + out_buffer = torch.zeros_like(flat_input) # [N*k, D] + + # 4. 
Perform forward pass for each expert on its batch of tokens. + ptr = 0 + for eid, num in enumerate(counts.tolist()): + if num == 0: + continue + seg = slice(ptr, ptr + num) + out_buffer[seg] = self.experts[eid](flat_input[seg]) + ptr += num + + # 5. Apply weights and scatter-add results back to token-indexed buffer. + out_buffer.mul_(flat_weight) # In-place multiplication by weights. + token_output = torch.zeros_like(x_flat) # [N, D] + token_output.index_add_(0, flat_token_idx, out_buffer) + + # 6. Add shared expert output if it exists. + if self.use_shared: + token_output.add_(self.shared_expert(x_flat)) + + return token_output.reshape(B, T, D) \ No newline at end of file diff --git a/lzero/model/unizero_world_models/tokenizer.py b/lzero/model/unizero_world_models/tokenizer.py index e5e18461f..734b8d8f5 100644 --- a/lzero/model/unizero_world_models/tokenizer.py +++ b/lzero/model/unizero_world_models/tokenizer.py @@ -1,116 +1,129 @@ """ Modified from https://github.com/CompVis/taming-transformers +This module provides an autoencoder-style tokenizer for encoding observations into latent embeddings and decoding them back. """ +import inspect from dataclasses import dataclass +from typing import Any, Dict, Optional, List import torch import torch.nn as nn -from einops import rearrange from torch.nn import functional as F -from typing import Optional, List from transformers.modeling_outputs import BaseModelOutput -class LossWithIntermediateLosses: - def __init__(self, **kwargs): - """Initialize with various loss components.""" - self.loss_total = sum(kwargs.values()) - self.intermediate_losses = {k: v.item() for k, v in kwargs.items()} - - def __truediv__(self, value): - """Divide all loss components by a given value.""" - for k, v in self.intermediate_losses.items(): - self.intermediate_losses[k] = v / value - self.loss_total = self.loss_total / value - return self - @dataclass class TokenizerEncoderOutput: + """ + Overview: + A data structure to hold the various outputs from a VQ-VAE style encoder, + including continuous and quantized latent representations, and discrete tokens. + """ + # Continuous latent representation from the encoder. z: torch.FloatTensor + # Quantized latent representation. z_quantized: torch.FloatTensor + # Discrete integer tokens corresponding to the codebook entries. tokens: torch.LongTensor class Tokenizer(nn.Module): """ Overview: - Tokenizer model that encodes and decodes observations. - Can operate on visual or textual data, supporting optional LPIPS perceptual loss. - It optionally includes a linear projection layer and can be paired with a decoder tokenizer. + An autoencoder model that encodes high-dimensional observations (like images or state vectors) + into low-dimensional latent embeddings and decodes them back. It can also compute reconstruction + and perceptual losses. This implementation does not include the quantization step (Vector Quantization) + but serves as the encoder-decoder backbone. """ - def __init__(self, encoder=None, decoder_network=None, decoder_network_tokenizer=None, with_lpips: bool = False, projection: list = None, encoder_option='legacy') -> None: - """Initialize the Tokenizer. + def __init__( + self, + encoder: nn.Module, + decoder: nn.Module, + with_lpips: bool = False, + obs_type: str = 'image' + ) -> None: + """ + Overview: + Initializes the Tokenizer (Autoencoder). Arguments: - encoder (nn.Module, optional): Encoder network to transform raw inputs into embeddings. 
- decoder_network (nn.Module, optional): Decoder network used for observation reconstruction or text generation. - decoder_network_tokenizer (PreTrainedTokenizer, optional): Tokenizer compatible with the decoder network (e.g., T5 tokenizer). - with_lpips (bool, optional): If True, enable perceptual loss computation via LPIPS. Defaults to False. - projection (list[int], optional): If provided, defines a linear projection layer from projection[0] → projection[1]. - If None, an identity layer is used. - encoder_option (str, optional): Option to specify the encoder type, e.g., 'legacy' for T5 decoder or 'qwen' for Qwen decoder. Defaults to 'legacy'. + - encoder (:obj:`nn.Module`): The network responsible for encoding observations into latent embeddings. It can be a single module or an nn.ModuleList for multi-task scenarios. + - decoder (:obj:`nn.Module`): The network responsible for decoding latent embeddings back into observations. + - with_lpips (:obj:`bool`): If True, initializes the LPIPS model to compute perceptual loss. Defaults to False. + - obs_type (:obj:`str`): The type of observation, e.g., 'image' or 'vector'. This can inform model architecture choices. Defaults to 'image'. """ super().__init__() + self.encoder = encoder + self.decoder_network = decoder + self.obs_type = obs_type + self.lpips: Optional[nn.Module] = None if with_lpips: + # Lazily import LPIPS as it's an optional dependency. from lzero.model.unizero_world_models.lpips import LPIPS self.lpips = LPIPS().eval() - else: - self.lpips = None - self.encoder = encoder - self.decoder_network = decoder_network - self.decoder_network_tokenizer = decoder_network_tokenizer - self.encoder_option = encoder_option - - if projection is None: - self.projection_layer = nn.Identity() - else: - self.projection_layer = nn.Linear(projection[0], projection[1]) - - def encode_to_obs_embeddings(self, x: torch.Tensor) -> torch.Tensor: + def encode_to_obs_embeddings(self, x: torch.Tensor, task_id: int = 0) -> torch.Tensor: """ - Encode observations to embeddings. - + Overview: + Encodes a batch of observations into latent embeddings, handling various input shapes and multi-task encoders. Arguments: - x (torch.Tensor): Input tensor of shape (B, ...). - + - x (:obj:`torch.Tensor`): The input tensor of observations. Shape can be (B, E), (B, T, E), (B, C, H, W), or (B, T, C, H, W). + - task_id (:obj:`int`): The identifier for the task, used to select the correct encoder from an nn.ModuleList in multi-task settings. Defaults to 0. Returns: - torch.Tensor: Encoded embeddings of shape (B, 1, E). + - torch.Tensor: The encoded latent embeddings with a consistent shape of (B, 1, E), where B is the effective batch size. 
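+        Examples:
+            An illustrative shape sketch, assuming an image encoder whose latent dimension E is 256:
+
+            >>> obs = torch.randn(8, 5, 3, 64, 64)            # (B, T, C, H, W)
+            >>> emb = tokenizer.encode_to_obs_embeddings(obs)  # encoder is applied to (B*T, C, H, W)
+            >>> emb.shape
+            torch.Size([40, 1, 256])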
""" - shape = x.shape - # Process input tensor based on its dimensionality - if len(shape) == 2: - # Case when input is 2D (B, E) - obs_embeddings = self.encoder(x) - obs_embeddings = rearrange(obs_embeddings, 'b e -> b 1 e') - elif len(shape) == 3: - # Case when input is 3D (B, T, E) - x = x.contiguous().view(-1, shape[-1]) # Flatten the last two dimensions (B * T, E) - obs_embeddings = self.encoder(x) - obs_embeddings = rearrange(obs_embeddings, 'b e -> b 1 e') - elif len(shape) == 4: - # Case when input is 4D (B, C, H, W) - obs_embeddings = self.encoder(x) - obs_embeddings = rearrange(obs_embeddings, 'b e -> b 1 e') - elif len(shape) == 5: - # Case when input is 5D (B, T, C, H, W) - x = x.contiguous().view(-1, *shape[-3:]) # Flatten the first two dimensions (B * T, C, H, W) - obs_embeddings = self.encoder(x) - obs_embeddings = rearrange(obs_embeddings, 'b e -> b 1 e') + # Step 1: Select the appropriate encoder module. + # This handles both single-task (a single nn.Module) and multi-task (an nn.ModuleList) scenarios. + if isinstance(self.encoder, nn.ModuleList): + if not 0 <= task_id < len(self.encoder): + encoder_module = self.encoder[0] + else: + encoder_module = self.encoder[task_id] else: - raise ValueError(f"Invalid input shape: {shape}") + encoder_module = self.encoder + + # Step 2: Pre-process and reshape the input tensor based on its dimensions. + # The goal is to transform the input into a 2D or 4D tensor that the encoder can process. + original_shape = x.shape + if len(original_shape) == 5: # Batch of sequences of images: (B, T, C, H, W) + # Flatten the batch and time dimensions to create a batch of images. + x = x.contiguous().view(-1, *original_shape[-3:]) # Shape: (B*T, C, H, W) + elif len(original_shape) == 3: # Batch of sequences of vectors: (B, T, E) + # Flatten the batch and time dimensions to create a batch of vectors. + x = x.contiguous().view(-1, original_shape[-1]) # Shape: (B*T, E) + # Note: 2D (B, E) and 4D (B, C, H, W) inputs are processed directly without reshaping. + + # Step 3: Pass the processed tensor through the encoder. + # Some encoders (like RepresentationNetworkMLPMT) require task_id as a parameter, + # while others do not. We use inspect to check the signature and pass task_id only if needed. + sig = inspect.signature(encoder_module.forward) + if 'task_id' in sig.parameters: + # Encoder requires task_id (e.g., RepresentationNetworkMLPMT) + obs_embeddings = encoder_module(x, task_id=task_id) + else: + # Encoder does not require task_id (e.g., standard CNN/MLP encoders) + obs_embeddings = encoder_module(x) + if len(obs_embeddings.shape) != 2: + raise RuntimeError( + f"Encoder output was expected to be 2D (batch, embedding_dim), but got shape {obs_embeddings.shape}." + ) + + # Step 4: Reshape the output to a consistent sequence format (B', 1, E). + # The '1' represents a sequence length of one, making it compatible with sequence models. + obs_embeddings = obs_embeddings.unsqueeze(1) + return obs_embeddings def decode_to_obs(self, embeddings: torch.Tensor) -> torch.Tensor: - """Decode embeddings to observations. - + """ + Overview: + Decodes a batch of latent embeddings back into the observation space. Arguments: - embeddings (:obj:`torch.Tensor`): Input embeddings. - + - embeddings (:obj:`torch.Tensor`): The latent embeddings to decode. Returns: - torch.Tensor: Decoded observations. + - torch.Tensor: The reconstructed observations. 
""" return self.decoder_network(embeddings) @@ -268,36 +281,43 @@ def decode_to_plain_text( @staticmethod def reconstruction_loss(original_images: torch.Tensor, reconstructed_images: torch.Tensor) -> torch.Tensor: - """Calculate the reconstruction loss. - + """ + Overview: + Calculates the reconstruction loss between original and reconstructed observations. + It uses L2 (MSE) loss for vector-based observations and L1 (MAE) loss for image-based observations. Arguments: - - original_images (:obj:`torch.Tensor`): Original images. - - reconstructed_images (:obj:`torch.Tensor`): Reconstructed images. - + - original_images (:obj:`torch.Tensor`): The ground-truth observations. + - reconstructed_images (:obj:`torch.Tensor`): The observations reconstructed by the decoder. Returns: - - torch.Tensor: Computed reconstruction loss. + - torch.Tensor: A scalar tensor representing the computed reconstruction loss. """ if len(original_images.shape) == 2: - # For memory environment vector observations - loss = F.mse_loss(original_images, reconstructed_images) # L2 loss + # Use Mean Squared Error (L2 loss) for vector-based observations. + return F.mse_loss(reconstructed_images, original_images) else: - # For Atari image environment - loss = torch.abs(original_images - reconstructed_images).mean() # L1 loss - return loss + # Use Mean Absolute Error (L1 loss) for image-based observations, which is often more robust to outliers. + return torch.abs(original_images - reconstructed_images).mean() def perceptual_loss(self, original_images: torch.Tensor, reconstructed_images: torch.Tensor) -> torch.Tensor: - """Calculate the perceptual loss using LPIPS. - + """ + Overview: + Calculates the perceptual loss (LPIPS) between original and reconstructed images. + This loss is designed to better align with human perception of image similarity. Arguments: - original_images (:obj:`torch.Tensor`): Original images. - reconstructed_images (:obj:`torch.Tensor`): Reconstructed images. - + - original_images (:obj:`torch.Tensor`): The ground-truth images. + - reconstructed_images (:obj:`torch.Tensor`): The images reconstructed by the decoder. Returns: - torch.Tensor: Computed perceptual loss. + - torch.Tensor: A scalar tensor representing the computed perceptual loss. """ + if self.lpips is None: + raise RuntimeError("LPIPS model was not initialized. Please set `with_lpips=True` during Tokenizer instantiation.") return torch.mean(self.lpips(original_images, reconstructed_images)) def __repr__(self) -> str: - return "Tokenizer" \ No newline at end of file + """ + Overview: + Provides a string representation of the Tokenizer module. + """ + return f"Tokenizer(obs_type='{self.obs_type}', with_lpips={self.lpips is not None})" \ No newline at end of file diff --git a/lzero/model/unizero_world_models/transformer.py b/lzero/model/unizero_world_models/transformer.py index c2feb8497..399f98929 100644 --- a/lzero/model/unizero_world_models/transformer.py +++ b/lzero/model/unizero_world_models/transformer.py @@ -1,372 +1,665 @@ """ -The following code is modified from https://github.com/karpathy/nanoGPT. +This script is an extension of the original transformer.py from karpathy/nanoGPT. +It incorporates LoRA (Low-Rank Adaptation) for fine-tuning and introduces a +Curriculum Learning mechanism that activates different LoRA adapters sequentially. + +Key features: +- Adds `CurriculumLoRALinear`, a custom linear layer with multiple LoRA adapters. 
+- Controls which modules to apply LoRA to via configuration (e.g., attention and feed-forward layers). +- Maintains the extensibility and readability of the original nanoGPT codebase. """ -import numpy as np import math +import logging from dataclasses import dataclass -from typing import Optional, Tuple +from typing import Optional import torch import torch.nn as nn -import torch.nn as nn -from torch.nn import functional as F from ding.torch_utils.network import GRUGatingUnit from einops import rearrange +from torch.nn import functional as F from .kv_caching import KeysValues +from lzero.model.common import SimNorm -@dataclass -class TransformerConfig: - tokens_per_block: int - max_blocks: int - attention: str +class LearnableScale(nn.Module): + """ + A learnable scalar parameter constrained within a specific range. - num_layers: int - num_heads: int - embed_dim: int + The formula `s = offset + scale * tanh(ŝ)` maps an unbounded logit `ŝ` + to the range (offset - scale, offset + scale). Using tanh can sometimes + provide more stable gradients than sigmoid. - embed_pdrop: float - resid_pdrop: float - attn_pdrop: float - - # for RoPE - rope_theta: float - max_seq_len: int - rotary_emb: bool = False + For example, to achieve a range of (0.8, 1.2), one would use + `init=1.0` and `s_range=0.2`. + """ - @property - def max_tokens(self): - return self.tokens_per_block * self.max_blocks + def __init__(self, init: float = 1.0, s_range: float = 0.2) -> None: + """ + Overview: + Initializes the LearnableScale module. + Arguments: + - init (:obj:`float`): The initial value of the scalar, which also serves as the center of the range. + - s_range (:obj:`float`): The scale factor that determines the range (init - s_range, init + s_range). + """ + super().__init__() + assert s_range > 0, "The scaling range must be positive." + self.offset = init + self.scale = s_range + # Initialize the logit to 0, so the initial output is exactly `init`. + self.logit = nn.Parameter(torch.tensor(0.0)) + # TODO: Initially frozen, activated by a CurriculumController. + self.logit.requires_grad = False -def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): - """ - Precompute the frequency components for the rotary positional embeddings. + def forward(self) -> torch.Tensor: + """ + Overview: + Computes the scaled value. + Returns: + - torch.Tensor: The learnable scalar, constrained to the specified range. + """ + return self.offset + self.scale * torch.tanh(self.logit) - Arguments: - - dim (int): The dimension of the embedding. - - end (int): The length of the sequence for which frequencies are computed. - - theta (float): A scaling factor for the frequencies, default is 10000.0. +############################################## +# Optimized CurriculumLoRALinear Implementation +############################################## - Returns: - - freqs_cis (torch.Tensor): A tensor of complex numbers representing the precomputed frequencies. +class CurriculumLoRALinear(nn.Module): """ - freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) - t = torch.arange(end, device=freqs.device, dtype=torch.float32) - freqs = torch.outer(t, freqs) - freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 - return freqs_cis - + Optimized CurriculumLoRALinear. 
+ + Effective weight at stage s: + W_eff = α₀*W₀ + Σ_{j=1 to s} αⱼ*Δθⱼ -def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + Optimization logic at stage s (s >= 1): + - Train: Δθₛ, α₀, and {αⱼ | 1 <= j < s} + - Freeze: W₀, {Δθⱼ | 1 <= j < s}, and αₛ + + This avoids the redundancy of training αₛ alongside Δθₛ. """ - Reshape the frequency components for broadcasting with the input tensor. - Arguments: - - freqs_cis (torch.Tensor): The frequency components tensor. - - x (torch.Tensor): The input tensor to which the frequencies will be applied. + def __init__(self, in_features: int, out_features: int, bias: bool = True, + r: int = 0, lora_alpha: int = 1, lora_dropout: float = 0.0, + curriculum_stage_num: int = 1, lora_scale_init: float = 1.0) -> None: + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.r = r + self.lora_alpha = lora_alpha + self.scaling = lora_alpha / r if r > 0 else 1.0 + self.lora_dropout = nn.Dropout(p=lora_dropout) if lora_dropout > 0.0 else nn.Identity() + self.curriculum_stage_num = curriculum_stage_num + self.curriculum_stage = 0 + + # Base weights (W₀ and bias) + self.weight = nn.Parameter(torch.empty(out_features, in_features)) + if bias: + self.bias = nn.Parameter(torch.empty(out_features)) + else: + self.register_parameter('bias', None) + nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + if self.bias is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + nn.init.uniform_(self.bias, -bound, bound) + + # Learnable scale for the base weight (α₀) + self.base_weight_scale = LearnableScale(init=1.0, s_range=0.2) + + # A scale for each adapter (α₁, α₂, ...) + self.adapters = nn.ModuleList() + self.adapter_scales = nn.ModuleList() + + if r > 0 and (curriculum_stage_num - 1) > 0: + for _ in range(curriculum_stage_num - 1): + adapter = nn.ParameterDict({ + 'lora_A': nn.Parameter(torch.randn(r, in_features) * 0.01), + 'lora_B': nn.Parameter(torch.zeros(out_features, r)) + }) + self.adapters.append(adapter) + self.adapter_scales.append(LearnableScale(lora_scale_init, s_range=0.2)) + else: + self.adapters = None - Returns: - - torch.Tensor: The reshaped frequency components tensor. 
- """ - # Reference: https://github.com/meta-llama/llama3/blob/main/llama/model.py#L61 - ndim = x.ndim - shape = [d if i in (0, 2, ndim - 1) else 1 for i, d in enumerate(x.shape)] - return freqs_cis.view(*shape) + self.set_curriculum_stage(0) + def set_curriculum_stage(self, stage: int) -> None: + assert 0 <= stage < self.curriculum_stage_num, f"Stage must be within [0, {self.curriculum_stage_num-1}]" + self.curriculum_stage = stage + module_id = f"({self.in_features}x{self.out_features})" + + # --- Stage 0: Base Training --- + if stage == 0: + self.weight.requires_grad = True + if self.bias is not None: self.bias.requires_grad = True + + # Freeze everything else + self.base_weight_scale.logit.requires_grad = False + if self.adapters: + for adapter in self.adapters: + adapter['lora_A'].requires_grad = False + adapter['lora_B'].requires_grad = False + for scale in self.adapter_scales: + scale.logit.requires_grad = False + + # Log only from rank 0 to avoid excessive output + from ding.utils import get_rank + if get_rank() == 0: + logging.info(f"[CurriculumLoRALinear {module_id}] Stage 0: Base layer trainable.") + + # --- Stage >= 1: Adaptation --- + else: + # Freeze base model + self.weight.requires_grad = False + if self.bias is not None: self.bias.requires_grad = False + + # α₀ is trainable from stage 1 onwards + self.base_weight_scale.logit.requires_grad = True + + if self.adapters: + # Set trainability for LoRA adapters + for idx, adapter in enumerate(self.adapters): + is_current_adapter = (idx == stage - 1) + adapter['lora_A'].requires_grad = is_current_adapter + adapter['lora_B'].requires_grad = is_current_adapter + + # --- OPTIMIZED LOGIC FOR SCALES --- + # Set trainability for adapter scales {α_j} + for idx, scale in enumerate(self.adapter_scales): + # A scale α_j is trainable if it belongs to a *previous* stage (j < s). + # The current stage's scale α_s (idx = stage - 1) is NOT trained. + is_previous_scale = (idx < stage - 1) + scale.logit.requires_grad = is_previous_scale + + # Log only from rank 0 to avoid excessive output + from ding.utils import get_rank + if get_rank() == 0: + logging.info(f"[CurriculumLoRALinear {module_id}] Stage {stage}: Activating adapter {stage - 1} and scales for stages < {stage - 1}.") + + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Apply scaling to base weight if in an adaptation stage + if self.curriculum_stage > 0: + alpha_0 = self.base_weight_scale() + scaled_weight = self.weight * alpha_0 + baseline_out = F.linear(x, scaled_weight, self.bias) + else: + baseline_out = F.linear(x, self.weight, self.bias) + + if self.curriculum_stage == 0 or self.adapters is None: + return baseline_out + + adapter_out = 0 + # Iterate through all adapters up to the current stage + for idx in range(self.curriculum_stage): + if idx >= len(self.adapters): + break + + adapter = self.adapters[idx] + scale = self.adapter_scales[idx]() + + lora_x = self.lora_dropout(x) + out = F.linear(lora_x, adapter['lora_A']) + out = F.linear(out, adapter['lora_B']) + + # The forward pass is a simple sum. The magic happens in `set_curriculum_stage` + # which controls `requires_grad`. No need for `.detach()` here. + # Gradients will naturally flow only to parameters with `requires_grad=True`. 
+ adapter_out = adapter_out + self.scaling * out * scale + + return baseline_out + adapter_out + -def apply_rotary_emb( - xq: torch.Tensor, - xk: torch.Tensor, - freqs_cis: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Apply rotary positional embeddings to the query and key tensors. - Arguments: - - xq (torch.Tensor): The query tensor. - - xk (torch.Tensor): The key tensor. - - freqs_cis (torch.Tensor): The precomputed frequency components. +############################################## +# Helper function to wrap linear layers +############################################## +def _maybe_wrap_linear(linear: nn.Linear, config, module_label: str) -> nn.Module: + """ + Overview: + A helper function that wraps an `nn.Linear` layer with `CurriculumLoRALinear` + if LoRA and curriculum learning are enabled for the specified module. + Arguments: + - linear (:obj:`nn.Linear`): The original linear layer to be potentially wrapped. + - config: The model configuration object. + - module_label (:obj:`str`): A label identifying the module type (e.g., "attn", "feed_forward"). Returns: - - Tuple[torch.Tensor, torch.Tensor]: The transformed query and key tensors. - - Note: - For more information on rotary positional embeddings, refer to the blog post: - https://spaces.ac.cn/archives/8265/ or paper https://arxiv.org/abs/2104.09864 + - nn.Module: The wrapped `CurriculumLoRALinear` layer or the original `nn.Linear` layer. """ - xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) - xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) - freqs_cis = reshape_for_broadcast(freqs_cis, xq_) - xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2) - xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2) - return xq_out.type_as(xq), xk_out.type_as(xk) + use_curriculum_lora = ( + config.lora_r > 0 and + module_label in config.lora_target_modules and + getattr(config, "curriculum_stage_num", 1) > 1 + ) + if use_curriculum_lora: + new_linear = CurriculumLoRALinear( + in_features=linear.in_features, + out_features=linear.out_features, + bias=(linear.bias is not None), + r=config.lora_r, + lora_alpha=config.lora_alpha, + lora_dropout=config.lora_dropout, + curriculum_stage_num=config.curriculum_stage_num, + lora_scale_init=config.lora_scale_init + ) + new_linear.weight.data.copy_(linear.weight.data) + if linear.bias is not None: + new_linear.bias.data.copy_(linear.bias.data) + return new_linear + else: + return linear -class Transformer(nn.Module): - """ - Transformer model class. +############################################## +# Helper function to set curriculum stage +############################################## +def set_curriculum_stage(model: nn.Module, stage: int) -> None: + """ + Overview: + Recursively traverses all submodules of a given model, finds all instances + of `CurriculumLoRALinear`, and calls their `set_curriculum_stage` method. + This function is generic and can be applied to any model structure. Arguments: - - config (:obj:`TransformerConfig`): Configuration for the Transformer model. + - model (:obj:`nn.Module`): The model to update (e.g., a Transformer or Vision Transformer). + - stage (:obj:`int`): The curriculum stage to set. 
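+    Examples:
+        An illustrative sketch; `world_model.transformer` is a placeholder for any module that
+        contains CurriculumLoRALinear layers. Moving to stage 1 freezes the base weights and
+        starts training the first LoRA adapter:
+
+        >>> set_curriculum_stage(world_model.transformer, 1)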
+ """ + count = 0 + for module in model.modules(): + if isinstance(module, CurriculumLoRALinear): + module.set_curriculum_stage(stage) + count += 1 + + # Log only from rank 0 to avoid excessive output + from ding.utils import get_rank + if count > 0 and get_rank() == 0: + logging.info(f"[Curriculum] Updated {count} CurriculumLoRALinear modules in {type(model).__name__} to stage {stage}.") + +# Alias for backward compatibility +set_curriculum_stage_for_transformer = set_curriculum_stage + + +############################################## +# Transformer Configuration +############################################## +@dataclass +class TransformerConfig: + """Configuration for the Transformer model.""" + tokens_per_block: int + max_blocks: int + attention: str + + num_layers: int + num_heads: int + embed_dim: int + + embed_pdrop: float + resid_pdrop: float + attn_pdrop: float + + # LoRA parameters + lora_r: int = 0 + lora_alpha: int = 1 + lora_dropout: float = 0.0 + lora_target_modules: list = None + + # Curriculum Learning parameters + # `curriculum_stage_num` is the total number of stages (e.g., 3 means stages 0, 1, 2) + curriculum_stage_num: int = 1 # 1 (base) + number of available LoRA adapters + min_stage0_iters: int = 10_000 # Minimum iterations for stage 0 + max_stage_iters: int = 20_000 # Maximum iterations per stage + lora_scale_init: float = 1.0 # Initial value for learnable adapter scales + + # Other configurations + task_embed_option: str = "none" + register_token_num: int = 4 + register_token_shared: bool = True + + gru_gating: bool = False + moe_in_transformer: bool = False + multiplication_moe_in_transformer: bool = False + num_experts_of_moe_in_transformer: int = 1 - Attributes: - - config (:obj:`TransformerConfig`): Configuration object. - - drop (:obj:`nn.Dropout`): Dropout layer for embedding dropout. - - blocks (:obj:`nn.ModuleList`): List of Transformer blocks. - - ln_f (:obj:`nn.LayerNorm`): Layer normalization applied to the final output. + @property + def max_tokens(self) -> int: + """Maximum number of tokens the model can handle.""" + return self.tokens_per_block * self.max_blocks + + +class Transformer(nn.Module): + """ + A Transformer model implementation. """ - def __init__(self, config: TransformerConfig) -> None: + def __init__(self, config: TransformerConfig, task_embed: Optional[nn.Module] = None) -> None: + """ + Overview: + Initializes the Transformer model. + Arguments: + - config (:obj:`TransformerConfig`): The configuration object for the model. + - task_embed (:obj:`Optional[nn.Module]`): An optional module for generating task embeddings. + """ super().__init__() self.config = config self.drop = nn.Dropout(config.embed_pdrop) self.blocks = nn.ModuleList([Block(config) for _ in range(config.num_layers)]) self.ln_f = nn.LayerNorm(config.embed_dim) - if self.config.rotary_emb: - freqs_cis = precompute_freqs_cis( - self.config.embed_dim // self.config.num_heads, - self.config.max_seq_len * 2, - self.config.rope_theta, - ) - self.register_buffer("freqs_cis", freqs_cis) + self.task_embed = task_embed + self.task_embed_option = self.config.task_embed_option + self.use_register_token = (self.task_embed_option == "register_task_embed") + + if self.use_register_token: + self.register_token_num = getattr(config, "register_token_num", 4) + self.register_token_shared = getattr(config, "register_token_shared", True) + + if self.register_token_shared: + # Shared mode: all tasks use the same register_tokens parameter. 
+ self.register_tokens = nn.Parameter(torch.empty(self.register_token_num, config.embed_dim)) + nn.init.xavier_uniform_(self.register_tokens) + else: + # Non-shared mode: relies on the external `task_embed` module to generate + # task-specific embeddings, which are then normalized and expanded. + self.task_embed = task_embed + self.sim_norm = SimNorm(simnorm_dim=config.embed_dim) - def generate_empty_keys_values(self, n: int, max_tokens: int) -> KeysValues: + def add_register_tokens(self, sequences: torch.Tensor, task_id: int) -> torch.Tensor: """ - Generate a placeholder for keys and values. + Overview: + Prepends or appends register tokens to the input sequences. + Arguments: + - sequences (:obj:`torch.Tensor`): The input sequences, with shape (B, T, C). + - task_id (:obj:`int`): The ID of the current task. + Returns: + - torch.Tensor: The sequences with register tokens concatenated, shape (B, T + register_token_num, C). + """ + B = sequences.size(0) + device = sequences.device + + if self.register_token_shared: + # Shared mode: use the same set of register tokens for all batches. + register_tokens = self.register_tokens.unsqueeze(0).expand(B, -1, -1) + else: + # Non-shared mode: dynamically generate task embedding and expand it. + task_embedding = self.task_embed(torch.tensor([task_id], device=device)) + task_embedding = self.sim_norm(task_embedding.view(1, -1)).view(-1) + register_tokens = task_embedding.unsqueeze(0).expand(self.register_token_num, -1) + register_tokens = register_tokens.unsqueeze(0).expand(B, -1, -1) + # Concatenate register tokens at the end of the sequence. + new_sequences = torch.cat([sequences, register_tokens], dim=1) + return new_sequences + + def remove_register_tokens_from_kv(self, past_keys_values: Optional[KeysValues]) -> None: + """ + Overview: + Removes the register tokens from the key-value cache of all layers. + This is called at the end of the forward pass during inference. Arguments: - - n (:obj:`int`): Batch size. - - max_tokens (:obj:`int`): Maximum number of tokens in the sequence. + - past_keys_values (:obj:`Optional[KeysValues]`): The key-value cache. + """ + if past_keys_values is not None: + past_keys_values.remove_register_tokens(self.register_token_num) + def generate_empty_keys_values(self, n: int, max_tokens: int) -> KeysValues: + """ + Overview: + Generates a placeholder for the key-value cache. + Arguments: + - n (:obj:`int`): The batch size. + - max_tokens (:obj:`int`): The maximum number of tokens in the sequence. Returns: - - KeysValues: An object containing empty keys and values. + - KeysValues: An object containing empty tensors for keys and values. """ - device = self.ln_f.weight.device # Assumption: All submodules are on the same device + device = self.ln_f.weight.device return KeysValues(n, self.config.num_heads, max_tokens, self.config.embed_dim, self.config.num_layers, device) - def forward(self, sequences: torch.Tensor, past_keys_values: Optional[KeysValues] = None, - valid_context_lengths: Optional[torch.Tensor] = None, start_pos: int = 0) -> torch.Tensor: + def forward( + self, + sequences: torch.Tensor, + past_keys_values: Optional[KeysValues] = None, + valid_context_lengths: Optional[torch.Tensor] = None, + task_id: int = 0, + start_pos: int = 0 + ) -> torch.Tensor: """ - Forward pass of the Transformer model. - + Overview: + Performs the forward pass of the Transformer model. Arguments: - - sequences (:obj:`torch.Tensor`): Input tensor of shape (batch_size, seq_length, embed_dim). 
- - past_keys_values (:obj:`Optional[KeysValues]`): Precomputed keys and values for faster generation (default: None). - - valid_context_lengths (:obj:`Optional[torch.Tensor]`): Valid lengths of context for masking (default: None). - - start_pos (:obj:`int`): Starting position for rotary embeddings (default: 0). - + - sequences (:obj:`torch.Tensor`): The input tensor of shape (B, T, C). + - past_keys_values (:obj:`Optional[KeysValues]`): An optional cache for keys and values to speed up inference. + - valid_context_lengths (:obj:`Optional[torch.Tensor]`): Tensor indicating the valid length of the context for each sample. + - task_id (:obj:`int`): The ID of the current task. + - start_pos (:obj:`int`): The starting position for the current sequence (used with kv-caching). Returns: - - torch.Tensor: Output tensor of shape (batch_size, seq_length, embed_dim). + - torch.Tensor: The output tensor of shape (B, T, C). """ - seqlen = sequences.shape[1] - # If using Rotary Position Embeddings (RoPE), slice the frequency components accordingly - if self.config.rotary_emb: - if isinstance(start_pos, (int, float, np.integer)): - # In the reanalyze_phase or reset stage in collection/evaluation phase, create a tensor filled with start_pos, expanded to match the batch size, and adjust for sequence type, e.g., start_pos=2. - start_pos_tensor = torch.full((sequences.shape[0],), int(start_pos), device=sequences.device) - elif isinstance(start_pos, (list, np.ndarray, torch.Tensor)): - if isinstance(start_pos[0], (np.ndarray, torch.Tensor, list)): - # In the training phase, flatten start_pos, take the first element, convert to tensor, e.g., start_pos=[array([ 8, 10, 12, 14, 16]), array([12, 14, 16, 18, 20])] - start_pos_tensor = torch.as_tensor( - [x.reshape(-1)[0].item() for x in start_pos], # Force flatten and take the first element - device=sequences.device - ) - elif isinstance(start_pos[0], (int, float, np.integer)): - # In the collection/evaluation phase, e.g., start_pos = [0, 0, 0, 0, 0, 0, 0, 0] - start_pos_tensor = torch.as_tensor([int(x) for x in start_pos], device=sequences.device) - else: - raise ValueError("start_pos must be an int, float, list, numpy array or torch.Tensor.") - - # TODO: Determine how to handle cases when episode length exceeds max_seq_len - # Use modulo operation to ensure start_pos does not exceed max_seq_len - start_pos_tensor = torch.remainder(start_pos_tensor, self.config.max_seq_len) - # Convert each sample's start_pos to a list - start_pos_list = start_pos_tensor.tolist() - # For each sample, slice the corresponding range of freqs_cis based on start_pos - freqs_cis_slices = [self.freqs_cis[int(pos): int(pos) + seqlen] for pos in start_pos_list] - freqs_cis = torch.stack(freqs_cis_slices) - - if freqs_cis.ndim == 3 and freqs_cis.shape[1] == 1: - # Convert shape [seq_len, 1, num_pairs] to [seq_len, num_pairs] - freqs_cis = freqs_cis.squeeze(1) - else: - freqs_cis = None + if self.use_register_token: + sequences = self.add_register_tokens(sequences, task_id) - # print(f"freqs_cis.shape:{freqs_cis.shape}") - - # Ensure past keys and values match the number of transformer blocks - assert past_keys_values is None or len(past_keys_values) == len(self.blocks) - # Apply dropout to the input sequences x = self.drop(sequences) - # Pass through each transformer block + for i, block in enumerate(self.blocks): - x = block(x, None if past_keys_values is None else past_keys_values[i], valid_context_lengths, freqs_cis) - # Apply final layer normalization + kv_cache_layer = None if 
past_keys_values is None else past_keys_values[i] + x = block(x, kv_cache_layer, valid_context_lengths) + x = self.ln_f(x) + + if self.use_register_token: + # During inference, remove register tokens from the KV cache to maintain consistency + # for external logic that does not expect them. + if past_keys_values is not None: + self.remove_register_tokens_from_kv(past_keys_values) + + # TODO: Remove register tokens from the final output to match the input sequence length. + x = x[:, :-self.register_token_num, :] + return x class Block(nn.Module): """ - Transformer block class. - - Arguments: - config (:obj:`TransformerConfig`): Configuration for the Transformer block. - - Attributes: - - gru_gating (:obj:`bool`): Flag to use GRU gating mechanism. - - gru_bias (:obj:`float`): Bias for the GRU gating mechanism. - - gate1 (:obj:`Optional[GRUGatingUnit]`): First GRU gating unit (if GRU gating is enabled). - - gate2 (:obj:`Optional[GRUGatingUnit]`): Second GRU gating unit (if GRU gating is enabled). - - ln1 (:obj:`nn.LayerNorm`): Layer normalization before the attention layer. - - ln2 (:obj:`nn.LayerNorm`): Layer normalization before the MLP. - - attn (:obj:`SelfAttention`): Self-attention mechanism. - - mlp (:obj:`nn.Sequential`): Multi-layer perceptron. + A single Transformer block, consisting of self-attention and a feed-forward network. """ def __init__(self, config: TransformerConfig) -> None: + """ + Overview: + Initializes a Transformer block. + Arguments: + - config (:obj:`TransformerConfig`): The configuration object for the block. + """ super().__init__() - # NOTE: GRU gating as in GTrXL self.gru_gating = config.gru_gating - self.gru_bias = 2.0 if self.gru_gating: - self.gate1 = GRUGatingUnit(config.embed_dim, self.gru_bias) - self.gate2 = GRUGatingUnit(config.embed_dim, self.gru_bias) + # As in GTrXL, for stabilizing training with recurrence + self.gate1 = GRUGatingUnit(config.embed_dim, bias_init=2.0) + self.gate2 = GRUGatingUnit(config.embed_dim, bias_init=2.0) self.ln1 = nn.LayerNorm(config.embed_dim) self.ln2 = nn.LayerNorm(config.embed_dim) self.attn = SelfAttention(config) - self.mlp = nn.Sequential( - nn.Linear(config.embed_dim, 4 * config.embed_dim), - nn.GELU(approximate='tanh'), - nn.Linear(4 * config.embed_dim, config.embed_dim), - nn.Dropout(config.resid_pdrop), - ) + + if config.moe_in_transformer: + from .moe import MoELayer + # Create multiple independent MLP instances as experts + self.experts = nn.ModuleList([ + nn.Sequential( + nn.Linear(config.embed_dim, 4 * config.embed_dim), + nn.GELU(approximate='tanh'), + nn.Linear(4 * config.embed_dim, config.embed_dim), + nn.Dropout(config.resid_pdrop), + ) for _ in range(config.num_experts_of_moe_in_transformer) + ]) + self.feed_forward = MoELayer( + config, + experts=self.experts, + gate=nn.Linear(config.embed_dim, config.num_experts_of_moe_in_transformer, bias=False), + num_experts_per_tok=config.num_experts_per_tok, + ) + # Log only from rank 0 to avoid excessive output + from ding.utils import get_rank + if get_rank() == 0: + logging.info(f"Using MoE in transformer feed-forward with {config.num_experts_of_moe_in_transformer} experts.") + elif config.multiplication_moe_in_transformer: + from .moe import MoELayer, MultiplicationFeedForward + # Create multiple FeedForward instances for multiplication-based MoE + self.experts = nn.ModuleList([ + MultiplicationFeedForward(config) for _ in range(config.num_experts_of_moe_in_transformer) + ]) + self.feed_forward = MoELayer( + config, + experts=self.experts, + 
gate=nn.Linear(config.embed_dim, config.num_experts_of_moe_in_transformer, bias=False), + num_experts_per_tok=config.num_experts_per_tok, + ) + # Log only from rank 0 to avoid excessive output + from ding.utils import get_rank + if get_rank() == 0: + logging.info(f"Using Multiplication MoE in transformer feed-forward with {config.num_experts_of_moe_in_transformer} experts.") + else: + # Standard MLP, with linear layers potentially wrapped for LoRA. + self.feed_forward = nn.Sequential( + _maybe_wrap_linear(nn.Linear(config.embed_dim, 4 * config.embed_dim), config, "feed_forward"), + nn.GELU(approximate='tanh'), + _maybe_wrap_linear(nn.Linear(4 * config.embed_dim, config.embed_dim), config, "feed_forward"), + nn.Dropout(config.resid_pdrop), + ) def forward(self, x: torch.Tensor, past_keys_values: Optional[KeysValues] = None, - valid_context_lengths: Optional[torch.Tensor] = None, freqs_cis: torch.Tensor = None) -> torch.Tensor: + valid_context_lengths: Optional[torch.Tensor] = None) -> torch.Tensor: """ - Forward pass of the Transformer block. - + Overview: + Performs the forward pass of the Transformer block. Arguments: - x (:obj:`torch.Tensor`): Input tensor of shape (batch_size, seq_length, embed_dim). - - past_keys_values (:obj:`Optional[KeysValues]`): Precomputed keys and values for faster generation (default: None). - - valid_context_lengths (:obj:`Optional[torch.Tensor]`): Valid lengths of context for masking (default: None). - - freqs_cis (:obj:`torch.Tensor`): Frequency components for rotary position embeddings, used to modulate the attention mechanism (default: None). - + - past_keys_values (:obj:`Optional[KeysValues]`): Precomputed keys and values for faster generation. + - valid_context_lengths (:obj:`Optional[torch.Tensor]`): Valid lengths of context for masking. Returns: - torch.Tensor: Output tensor of shape (batch_size, seq_length, embed_dim). """ - x_attn = self.attn(self.ln1(x), past_keys_values, valid_context_lengths, freqs_cis) + attn_output = self.attn(self.ln1(x), past_keys_values, valid_context_lengths) if self.gru_gating: - x = self.gate1(x, x_attn) - x = self.gate2(x, self.mlp(self.ln2(x))) + x = self.gate1(x, attn_output) + ff_output = self.feed_forward(self.ln2(x)) + x = self.gate2(x, ff_output) else: - x = x + x_attn - x = x + self.mlp(self.ln2(x)) - + x = x + attn_output + x = x + self.feed_forward(self.ln2(x)) return x class SelfAttention(nn.Module): """ - Implements self-attention mechanism for transformers. - - Arguments: - config (:obj:`TransformerConfig`): Configuration object containing hyperparameters. - - Attributes: - - config (:obj:`TransformerConfig`): Stores the configuration for the self-attention module. - - num_heads (:obj:`int`): Number of attention heads. - - key (:obj:`nn.Linear`): Linear layer to project input to key vectors. - - query (:obj:`nn.Linear`): Linear layer to project input to query vectors. - - value (:obj:`nn.Linear`): Linear layer to project input to value vectors. - - attn_drop (:obj:`nn.Dropout`): Dropout layer for attention weights. - - resid_drop (:obj:`nn.Dropout`): Dropout layer for residual connection. - - proj (:obj:`nn.Linear`): Final linear layer for projection. - - mask (:obj:`torch.Tensor`): Mask tensor for causal or block-causal attention. + Implements the self-attention mechanism for a Transformer. """ + def __init__(self, config: TransformerConfig) -> None: + """ + Overview: + Initializes the SelfAttention module. + Arguments: + - config (:obj:`TransformerConfig`): The configuration object for the attention module. 
+ """ super().__init__() assert config.embed_dim % config.num_heads == 0, "Embedding dimension must be divisible by number of heads." self.config = config self.num_heads = config.num_heads + + self.task_embed_option = self.config.task_embed_option + self.use_register_token = (self.task_embed_option == "register_task_embed") + if self.use_register_token: + self.register_token_num = getattr(config, "register_token_num", 4) - self.key = nn.Linear(config.embed_dim, config.embed_dim) - self.query = nn.Linear(config.embed_dim, config.embed_dim) - self.value = nn.Linear(config.embed_dim, config.embed_dim) + # Wrap linear layers if LoRA is enabled for the attention module + self.key = _maybe_wrap_linear(nn.Linear(config.embed_dim, config.embed_dim), config, "attn") + self.query = _maybe_wrap_linear(nn.Linear(config.embed_dim, config.embed_dim), config, "attn") + self.value = _maybe_wrap_linear(nn.Linear(config.embed_dim, config.embed_dim), config, "attn") + self.proj = _maybe_wrap_linear(nn.Linear(config.embed_dim, config.embed_dim), config, "attn") self.attn_drop = nn.Dropout(config.attn_pdrop) self.resid_drop = nn.Dropout(config.resid_pdrop) - self.proj = nn.Linear(config.embed_dim, config.embed_dim) - causal_mask = torch.tril(torch.ones(config.max_tokens, config.max_tokens)) + # TODO: The mask size is conservatively large to accommodate register tokens. + # This could be made more dynamic. + mask_size = config.max_tokens + if self.use_register_token: + mask_size += self.register_token_num * 5 + causal_mask = torch.tril(torch.ones(mask_size, mask_size)) self.register_buffer('mask', causal_mask) def forward(self, x: torch.Tensor, kv_cache: Optional[KeysValues] = None, - valid_context_lengths: Optional[torch.Tensor] = None, freqs_cis: torch.Tensor = None) -> torch.Tensor: + valid_context_lengths: Optional[torch.Tensor] = None) -> torch.Tensor: """ - Forward pass for the self-attention mechanism. - + Overview: + Performs the forward pass for the self-attention mechanism. Arguments: - - x (:obj:`torch.Tensor`): Input tensor of shape (B, T, C) where B is batch size, - T is sequence length, and C is embedding dimension. + - x (:obj:`torch.Tensor`): Input tensor of shape (B, T, C). - kv_cache (:obj:`Optional[KeysValues]`): Optional key-value cache for faster inference. - valid_context_lengths (:obj:`Optional[torch.Tensor]`): Optional tensor containing valid context lengths. - - freqs_cis (:obj:`torch.Tensor`): Frequency components for rotary position embeddings, used to modulate the attention mechanism (default: None). - Returns: - torch.Tensor: Output tensor of shape (B, T, C). """ B, T, C = x.size() + head_size = C // self.num_heads + + past_len = 0 if kv_cache is not None: - b, nh, L, c = kv_cache.shape - assert nh == self.num_heads and b == B and c * nh == C, "Cache dimensions do not match input dimensions." 
- else: - L = 0 + past_len = kv_cache.shape[2] - q = self.query(x).view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, num_heads, T, head_size) - k = self.key(x).view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, num_heads, T, head_size) - v = self.value(x).view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, num_heads, T, head_size) - - if self.config.rotary_emb: - q, k = apply_rotary_emb(q, k, freqs_cis=freqs_cis) + q = self.query(x).view(B, T, self.num_heads, head_size).transpose(1, 2) + k = self.key(x).view(B, T, self.num_heads, head_size).transpose(1, 2) + v = self.value(x).view(B, T, self.num_heads, head_size).transpose(1, 2) if kv_cache is not None: - kv_cache.update(k, v) # time occupancy 21% - k, v = kv_cache.get() # time occupancy 5% + kv_cache.update(k, v) + k, v = kv_cache.get() + current_len = k.size(2) att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + # Construct the attention mask + mask = self.mask[past_len:past_len + T, :current_len] + if valid_context_lengths is not None: - # Final mask.shape: (B, T, L + T) - # L is the context length, T is the current input length, - # valid_context_lengths is the valid length at the end of the context. - mask = torch.zeros(B, T, L + T, device=att.device) - # For each sample, set the invalid parts to 0 based on its valid length. + # This logic is for a specific use case and may need adjustment. + # It creates a custom mask for each item in the batch. + batch_mask = torch.zeros(B, T, current_len, device=att.device) for i in range(B): - mask[i] = self.mask[L:L + T, :L + T].clone() - mask[i, :, :(L - valid_context_lengths[i])] = 0 # Set invalid parts to 0. - # Adjust mask dimensions to match the last two dimensions of att. - # (B, T, L + T) -> (B, 1, T, L + T) -> (B, num_heads, T, L + T) - mask = mask.unsqueeze(1).expand(-1, att.size(1), -1, -1) - else: - # mask.shape: (T, L + T) - mask = self.mask[L:L + T, :L + T] + batch_mask[i] = mask.clone() + # Zero out attention to invalid past context + batch_mask[i, :, :(past_len - valid_context_lengths[i])] = 0 + mask = batch_mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1) + + # Adjust mask for register tokens if they are in use + if self.use_register_token and self.register_token_num > 0: + # Allow all positions to attend to register tokens and vice-versa + register_mask = mask.clone() + # Register tokens are at the end of the sequence + register_indices_start = current_len - self.register_token_num + register_mask[..., register_indices_start:] = 1 # All can see registers + # This part is more complex if T is not the full sequence length + if T > self.register_token_num: + # Only the actual register tokens in the current input `x` can see everything + register_mask[..., -self.register_token_num:, :] = 1 + mask = register_mask + + if kv_cache is not None: + # Ensure mask dimensions match the potentially smaller KV cache length + new_L = kv_cache.shape[2] + mask = mask[..., :new_L] - # att.shape: (B, num_heads, T, L + T) att = att.masked_fill(mask == 0, float('-inf')) - att = F.softmax(att, dim=-1) att = self.attn_drop(att) - y = att @ v # (B, num_heads, T, L + T) x (B, num_heads, L + T, head_size) -> (B, num_heads, T, head_size) - y = rearrange(y, 'b h t e -> b t (h e)') # Combine the heads back together (B, T, embed_dim) + y = att @ v + y = rearrange(y, 'b h t e -> b t (h e)') y = self.resid_drop(self.proj(y)) return y @@ -375,48 +668,41 @@ def forward(self, x: torch.Tensor, kv_cache: Optional[KeysValues] = None, def 
get_attention_map(self, x: torch.Tensor, kv_cache: Optional[KeysValues] = None, valid_context_lengths: Optional[torch.Tensor] = None) -> torch.Tensor: """ - Compute the attention map for the input sequence. This is useful for visualization purposes. - More details can be found in visualizing_utils.py. - + Overview: + Computes the attention map for visualization, without computing the final output. Arguments: - x (:obj:`torch.Tensor`): Input sequence with shape (B, T, C). - - kv_cache (:obj:`Optional[KeysValues]`): Cached keys and values for supporting long sequence inference. - - valid_context_lengths (:obj:`Optional[torch.Tensor]`): Valid context lengths for handling variable-length contexts. - + - kv_cache (:obj:`Optional[KeysValues]`): Cached keys and values for long sequence inference. + - valid_context_lengths (:obj:`Optional[torch.Tensor]`): Valid context lengths for variable-length inputs. Returns: - - torch.Tensor: Attention map with shape (B, nh, T, L + T), representing the distribution of attention. + - torch.Tensor: Attention map of shape (B, num_heads, T, L + T). """ B, T, C = x.size() + head_size = C // self.num_heads + + past_len = 0 if kv_cache is not None: - b, nh, L, c = kv_cache.shape - assert nh == self.num_heads and b == B and c * nh == C, "Cache dimensions are inconsistent with input dimensions." - else: - L = 0 + past_len = kv_cache.shape[2] - # Compute query, key, and value projections - q = self.query(x).view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, nh, T, hs) - k = self.key(x).view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, nh, T, hs) - v = self.value(x).view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, nh, T, hs) + q = self.query(x).view(B, T, self.num_heads, head_size).transpose(1, 2) + k = self.key(x).view(B, T, self.num_heads, head_size).transpose(1, 2) + v = self.value(x).view(B, T, self.num_heads, head_size).transpose(1, 2) if kv_cache is not None: - # Update the kv_cache with the new keys and values kv_cache.update(k, v) k, v = kv_cache.get() - # Compute the attention scores + current_len = k.size(2) att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + mask = self.mask[past_len:past_len + T, :current_len] if valid_context_lengths is not None: - mask = torch.zeros(B, T, L + T, device=att.device) + batch_mask = torch.zeros(B, T, current_len, device=att.device) for i in range(B): - # Create attention mask for each batch - mask[i] = self.mask[L:L + T, :L + T].clone() - mask[i, :, :(L - valid_context_lengths[i])] = 0 - mask = mask.unsqueeze(1).expand(-1, att.size(1), -1, -1) - else: - mask = self.mask[L:L + T, :L + T] + batch_mask[i] = mask.clone() + batch_mask[i, :, :(past_len - valid_context_lengths[i])] = 0 + mask = batch_mask.unsqueeze(1).expand(-1, self.num_heads, -1, -1) - # Apply the attention mask att = att.masked_fill(mask == 0, float('-inf')) att = F.softmax(att, dim=-1) diff --git a/lzero/model/unizero_world_models/utils.py b/lzero/model/unizero_world_models/utils.py index 99c841cbe..795e85f91 100644 --- a/lzero/model/unizero_world_models/utils.py +++ b/lzero/model/unizero_world_models/utils.py @@ -54,9 +54,6 @@ def custom_copy_kv_cache_to_dict_speed(src_kv: KeysValues, dst_dict: dict, cache print(f"Cache copy time: {copy_time:.6f} seconds") print(f"Total time: {shape_time + copy_time:.6f} seconds") - # print(f"Cache key '{cache_key}' has been copied to the destination dictionary.") - # print(f"Dictionary size: {len(dst_dict)}") - def custom_copy_kv_cache_to_dict(src_kv: 
KeysValues, dst_dict: dict, cache_key: str, reuse_cache: bool = True) -> None: """ @@ -106,8 +103,7 @@ def custom_copy_kv_cache(src_kv: KeysValues) -> KeysValues: len(src_kv), # num_layers src_kv._keys_values[0]._k_cache._cache.device, # device ) - - # with torch.no_grad(): + for src_layer, dst_layer in zip(src_kv._keys_values, dst_kv._keys_values): # Copy the key and value caches using torch.copy_() dst_layer._k_cache._cache.copy_(src_layer._k_cache._cache) @@ -179,17 +175,33 @@ def calculate_cuda_memory_gb(past_keys_values_cache, num_layers: int): total_memory_gb = total_memory_bytes / (1024 ** 3) return total_memory_gb -def hash_state(state): + +def hash_state(state: np.ndarray) -> int: """ - Hash the state vector. + Overview: + Computes a fast and robust hash for a NumPy array state. + + Why this is optimal: + 1. Algorithm (`xxhash.xxh64`): Uses one of the fastest non-cryptographic hash + functions available, ideal for performance-critical applications like caching. + 2. Input Preparation (`state.tobytes()`): Ensures correctness by creating a + canonical byte representation of the array. This guarantees that two + logically identical arrays will produce the same hash, regardless of their + internal memory layout (e.g., C-contiguous, F-contiguous, or strided views). + 3. Output Format (`.intdigest()`): Directly produces an integer hash value, + which is the most efficient key type for Python dictionaries, avoiding the + overhead of string keys. Arguments: - state: The state vector to be hashed. + - state (np.ndarray): The state array to be hashed. + Returns: - The hash value of the state vector. + - int: A 64-bit integer hash of the state. """ + # Note: `.tobytes()` always returns the array data in C order (copying if necessary), + # so two logically identical arrays produce identical bytes and no explicit + # contiguity handling is needed before hashing. + return xxhash.xxh64(state.tobytes()).intdigest() @dataclass class WorldModelOutput: @@ -201,22 +213,33 @@ class WorldModelOutput: logits_value: torch.FloatTensor -def init_weights(module, norm_type='BN'): +def init_weights(module, norm_type='BN', liner_weight_zero=False): """ Initialize the weights of the module based on the specified normalization type. - Arguments: module (nn.Module): The module to initialize. norm_type (str): The type of normalization to use ('BN' for BatchNorm, 'LN' for LayerNorm).
""" - if isinstance(module, (nn.Linear, nn.Embedding)): + if isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=0.02) - if isinstance(module, nn.Linear) and module.bias is not None: + elif isinstance(module, nn.Linear): + # Now this branch can be executed correctly + if norm_type == 'BN': + nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu') + print("Init Linear using kaiming normal for BN") + elif norm_type == 'LN': + # For Transformer structures, Xavier/Glorot initialization is more common + nn.init.xavier_uniform_(module.weight) + print("Init Linear using xavier uniform for LN") + + if module.bias is not None: module.bias.data.zero_() + elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): print(f"Init {module} using zero bias, 1 weight") - module.bias.data.zero_() module.weight.data.fill_(1.0) + module.bias.data.zero_() + elif isinstance(module, nn.BatchNorm2d): print(f"Init nn.BatchNorm2d using zero bias, 1 weight") module.weight.data.fill_(1.0) @@ -228,13 +251,6 @@ def init_weights(module, norm_type='BN'): elif norm_type == 'LN': nn.init.xavier_uniform_(module.weight) print(f"Init nn.Conv2d using xavier uniform for LN") - elif isinstance(module, nn.Linear): - if norm_type == 'BN': - nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu') - print("Init Linear using kaiming normal for BN") - elif norm_type == 'LN': - nn.init.xavier_uniform_(module.weight) - print("Init Linear using xavier uniform for LN") class LossWithIntermediateLosses: @@ -294,7 +310,7 @@ def __init__(self, latent_recon_loss_weight=0, perceptual_loss_weight=0, continu self.loss_total += self.perceptual_loss_weight * v self.intermediate_losses = { - k: v if isinstance(v, dict) or isinstance(v, torch.Tensor) else (v if isinstance(v, float) else v.item()) + k: v if isinstance(v, dict) or isinstance(v, np.ndarray) or isinstance(v, torch.Tensor) else (v if isinstance(v, float) else v.item()) for k, v in kwargs.items() } diff --git a/lzero/model/unizero_world_models/world_model.py b/lzero/model/unizero_world_models/world_model.py old mode 100644 new mode 100755 index d36358456..5d34e9fe6 --- a/lzero/model/unizero_world_models/world_model.py +++ b/lzero/model/unizero_world_models/world_model.py @@ -1,20 +1,30 @@ +import datetime import logging -from typing import Dict, Union, Optional, List, Tuple, Any +import os +from collections import OrderedDict, defaultdict +from typing import Any, Dict, List, Optional, Tuple, Union +import matplotlib.pyplot as plt import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange -from torch.distributions import Categorical, Independent, Normal, TransformedDistribution, TanhTransform - from lzero.model.common import SimNorm -from lzero.model.utils import cal_dormant_ratio +from lzero.model.utils import (calculate_dormant_ratio, + compute_average_weight_magnitude, + compute_effective_rank) +from matplotlib.offsetbox import AnnotationBbox, OffsetImage +from sklearn.manifold import TSNE +from torch.distributions import (Categorical, Independent, Normal, + TanhTransform, TransformedDistribution) + from .kv_caching import KeysValues from .slicer import Head, PolicyHeadCont from .tokenizer import Tokenizer from .transformer import Transformer, TransformerConfig -from .utils import LossWithIntermediateLosses, init_weights, WorldModelOutput, hash_state +from .utils import (LossWithIntermediateLosses, WorldModelOutput, hash_state, + init_weights) 
logging.getLogger().setLevel(logging.DEBUG) @@ -41,8 +51,11 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: super().__init__() self.tokenizer = tokenizer self.config = config - self.transformer = Transformer(self.config) + self.task_embed_option = self.config.task_embed_option # Strategy for task embeddings + self.transformer = Transformer(self.config) + self.task_num = 1 + self.env_num = self.config.env_num if self.config.device == 'cpu': self.device = torch.device('cpu') else: @@ -51,6 +64,8 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: logging.info(f"self.device: {self.device}") self.to(self.device) + self.task_embed_dim = config.task_embed_dim if hasattr(config, "task_embed_dim") else 96 + # Initialize configuration parameters self._initialize_config_parameters() @@ -65,6 +80,11 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: self.precompute_pos_emb_diff_kv() print(f"self.pos_emb.weight.device: {self.pos_emb.weight.device}") + self.register_token_num = config.register_token_num if hasattr(config, "register_token_num") else 4 + if self.task_embed_option == "concat_task_embed": + self.obs_per_embdding_dim = self.config.embed_dim - self.task_embed_dim + else: + self.obs_per_embdding_dim = self.config.embed_dim self.continuous_action_space = self.config.continuous_action_space # Initialize action embedding table @@ -82,7 +102,7 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: # Head modules self.head_rewards = self._create_head(self.act_tokens_pattern, self.support_size) - self.head_observations = self._create_head(self.all_but_last_latent_state_pattern, self.obs_per_embdding_dim, \ + self.head_observations = self._create_head_for_latent(self.all_but_last_latent_state_pattern, self.obs_per_embdding_dim, \ self._get_final_norm(self.final_norm_option_in_obs_head) # NOTE: using the specified normalization method for observations head ) if self.continuous_action_space: @@ -93,6 +113,13 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: self.head_policy = self._create_head(self.value_policy_tokens_pattern, self.action_space_size) self.head_value = self._create_head(self.value_policy_tokens_pattern, self.support_size) + self.head_dict = {} + for name, module in self.named_children(): + if name.startswith("head_"): + self.head_dict[name] = module + if self.head_dict: + self.head_dict = nn.ModuleDict(self.head_dict) + # Build the set of modules to skip during re-initialization. # This is compatible with cases where self.tokenizer.encoder does not have 'pretrained_model', # or self.tokenizer does not have 'decoder_network'. @@ -115,9 +142,6 @@ def custom_init(module): self._initialize_last_layer() - # Cache structures - self._initialize_cache_structures() - # Projection input dimension self._initialize_projection_input_dim() @@ -130,18 +154,25 @@ def custom_init(module): self.latent_recon_loss = torch.tensor(0., device=self.device) self.perceptual_loss = torch.tensor(0., device=self.device) + # Set to game_segment_length first to keep self.shared_pool_init_infer valid + # TODO: Very important, should be changed to match segment_length + self.shared_pool_size_init = int(self.config.game_segment_length) # NOTE: Will having too many cause incorrect retrieval of the kv cache? + # TODO: check the size of the shared pool # for self.kv_cache_recurrent_infer # If needed, recurrent_infer should store the results of the one MCTS search. 
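+        # With these defaults the recurrent-inference pool below holds num_simulations * env_num
+        # KeysValues slots, i.e. roughly one reusable cache entry per expanded MCTS node per environment.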
self.num_simulations = getattr(self.config, 'num_simulations', 50) - self.shared_pool_size = int(self.num_simulations*self.env_num) - self.shared_pool_recur_infer = [None] * self.shared_pool_size + + + self.shared_pool_size_recur = int(self.num_simulations*self.env_num) + self.shared_pool_recur_infer = [None] * self.shared_pool_size_recur self.shared_pool_index = 0 + # Cache structures + self._initialize_cache_structures() + # for self.kv_cache_init_infer # In contrast, init_infer only needs to retain the results of the most recent step. - # self.shared_pool_size_init = int(2*self.env_num) - self.shared_pool_size_init = int(2) # NOTE: Will having too many cause incorrect retrieval of the kv cache? self.shared_pool_init_infer = [[None] * self.shared_pool_size_init for _ in range(self.env_num)] self.shared_pool_index_init_envs = [0 for _ in range(self.env_num)] @@ -152,6 +183,237 @@ def custom_init(module): self.reanalyze_phase = False + def _initialize_cache_structures(self) -> None: + """Initialize cache structures for past keys and values.""" + from collections import defaultdict + + # ==================== Parallel KV Cache Systems ==================== + # Check if we should use the new KV cache manager + self.use_new_cache_manager = getattr(self.config, 'use_new_cache_manager', False) + + if self.use_new_cache_manager: + # Use new unified KV cache manager + from .kv_cache_manager import KVCacheManager + self.kv_cache_manager = KVCacheManager( + config=self.config, + env_num=self.env_num, + enable_stats=True, + clear_recur_log_freq=1000, # MCTS recurrent clearing log, print every 1000 times + clear_all_log_freq=100 # Episode reset clearing log, print every 100 times + ) + # Keep backward compatibility references + self.keys_values_wm_list = self.kv_cache_manager.keys_values_wm_list + self.keys_values_wm_size_list = self.kv_cache_manager.keys_values_wm_size_list + + # ==================== BUG FIX: Complete Refactoring ==================== + # DO NOT initialize old system attributes when using new cache manager. + # Any code that depends on these old attributes must be refactored to use + # kv_cache_manager instead. 
+ # + # Old attributes that are NO LONGER available in new system: + # - self.past_kv_cache_recurrent_infer + # - self.pool_idx_to_key_map_recur_infer + # - self.past_kv_cache_init_infer_envs + # - self.pool_idx_to_key_map_init_envs + # + # Migration guide: + # - For accessing init cache: use kv_cache_manager.get_init_cache(env_id, key) + # - For accessing recur cache: use kv_cache_manager.get_recur_cache(key) + # - For hierarchical lookup: use kv_cache_manager.hierarchical_get(env_id, key) + # ====================================================================== + + logging.info("✓ Using NEW KVCacheManager for cache management") + else: + # Use old cache system (original implementation) + self.past_kv_cache_recurrent_infer = {} + self.pool_idx_to_key_map_recur_infer = [None] * self.shared_pool_size_recur + self.past_kv_cache_init_infer_envs = [{} for _ in range(self.env_num)] + # Auxiliary data structure for reverse lookup: pool_index -> key + self.pool_idx_to_key_map_init_envs = [[None] * self.shared_pool_size_init for _ in range(self.env_num)] + + self.keys_values_wm_list = [] + self.keys_values_wm_size_list = [] + logging.info("Using OLD cache system (original implementation)") + # ============================================================================= + + def _inspect_and_log_head_params(self, head_name: str, head_module: nn.Module, status: str): + """ + Inspect and log parameter statistics for the specified Head module. + + Args: + head_name (str): The name of the Head to inspect (e.g., "Value Head"). + head_module (nn.Module): The actual nn.Sequential module of the Head. + status (str): A string describing the current status (e.g., "Before Re-init"). + """ + logging.info(f"--- Inspecting {head_name} parameters ({status}) ---") + with torch.no_grad(): + for param_name, param in head_module.named_parameters(): + if param.numel() > 0: + stats = { + "mean": param.mean().item(), + "std": param.std().item(), + "abs_mean": param.abs().mean().item(), + "max": param.max().item(), + "min": param.min().item(), + } + logging.info( + f" -> {param_name:<20} | " + f"Mean: {stats['mean']:.4f}, Std: {stats['std']:.4f}, " + f"AbsMean: {stats['abs_mean']:.4f}, " + f"Max: {stats['max']:.4f}, Min: {stats['min']:.4f}" + ) + logging.info("-" * (23 + len(head_name) + len(status))) + + def reinit_prediction_heads(self, heads_to_reinit: List[str] = ['value', 'reward']) -> None: + """ + Reinitialize the parameters of specified prediction heads (e.g., Value Head and Reward Head). + Parameter statistics are logged before and after reinitialization for analysis. + + Args: + heads_to_reinit (List[str]): A list containing the names of the heads to reinitialize. + Defaults to ['value', 'reward']. + """ + logging.info(f"Starting reinitialization of prediction heads: {heads_to_reinit}") + + head_map = { + 'value': self.head_value, + 'reward': self.head_rewards, + 'policy': self.head_policy, + } + + def _init_weights_for_head(module): + # TODO + init_weights(module, norm_type=self.config.norm_type, liner_weight_zero=True) + + for head_name in heads_to_reinit: + if head_name in head_map and hasattr(head_map[head_name], 'head_module'): + head_instance = head_map[head_name] + capitalized_name = head_name.capitalize() + " Head" + + # 1. Inspect parameters before reinitialization + self._inspect_and_log_head_params(capitalized_name, head_instance.head_module, "Before Re-init") + + # 2. 
Apply reinitialization + logging.info(f"Reinitializing {capitalized_name}...") + head_instance.head_module.apply(_init_weights_for_head) + + # 3. Inspect parameters again after reinitialization + self._inspect_and_log_head_params(capitalized_name, head_instance.head_module, "After Re-init") + + logging.info(f"{capitalized_name} parameters successfully reinitialized.") + else: + logging.warning(f"Prediction head named '{head_name}' or its 'head_module' not found. Skipping.") + + logging.info("Reinitialization of all specified prediction heads completed.") + + def _analyze_latent_representation( + self, + latent_states: torch.Tensor, + timesteps: torch.Tensor, + game_states: torch.Tensor, + predicted_values: torch.Tensor, + predicted_rewards: torch.Tensor, + step_counter: int + ): + """ + Analyze and log statistics of latent states with t-SNE visualization. + [New feature]: Display corresponding game images on t-SNE plot with predicted Value and Reward annotations. + [Modified]: If the save path already exists, append a timestamp to the filename. + + Args: + latent_states (torch.Tensor): Encoder output, shape (B*L, 1, E) + timesteps (torch.Tensor): Corresponding timesteps, shape (B, L) + game_states (torch.Tensor): Original game observations, shape (B, L, C, H, W) + predicted_values (torch.Tensor): Predicted scalar Values, shape (B*L,) + predicted_rewards (torch.Tensor): Predicted scalar Rewards, shape (B*L,) + step_counter (int): Global training step count + """ + # Ensure latent_states and game_states have shape (N, ...) + if latent_states.dim() > 2: + latent_states = latent_states.reshape(-1, latent_states.shape[-1]) + num_c, num_h, num_w = game_states.shape[-3:] + game_states = game_states.reshape(-1, num_c, num_h, num_w) + + with torch.no_grad(): + l2_norm = torch.norm(latent_states, p=2, dim=1).mean() + mean = latent_states.mean() + std = latent_states.std() + print(f"[Step {step_counter}] Latent Stats | L2 Norm: {l2_norm:.4f}, Mean: {mean:.4f}, Std: {std:.4f}") + + # t-SNE visualization with images and V/R values + if step_counter >= 0: + print(f"[Step {step_counter}] Performing t-SNE analysis with images, values, and rewards...") + + # Convert data to CPU + latents_np = latent_states.detach().cpu().numpy() + images_np = game_states.detach().cpu().numpy() + values_np = predicted_values.detach().cpu().numpy() + rewards_np = predicted_rewards.detach().cpu().numpy() + + tsne = TSNE(n_components=2, perplexity=30, n_iter=300, random_state=42) + tsne_results = tsne.fit_transform(latents_np) + + # Draw scatter plot with images and annotations + + # Reduce number of images to keep clarity + num_points_to_plot = min(len(latents_np), 70) # Reduce to 70 points + indices = np.random.choice(len(latents_np), num_points_to_plot, replace=False) + + fig, ax = plt.subplots(figsize=(20, 18)) # Increase canvas size + + # First draw all points as background scatter plot + ax.scatter(tsne_results[:, 0], tsne_results[:, 1], c=values_np, cmap='viridis', alpha=0.3, s=10) + + for i in indices: + x, y = tsne_results[i] + img = images_np[i].transpose(1, 2, 0) + img = np.clip(img, 0, 1) + + # Place image + im = OffsetImage(img, zoom=0.7) # Slightly enlarge image + ab = AnnotationBbox(im, (x, y), frameon=True, pad=0.0, bboxprops=dict(edgecolor='none')) + ax.add_artist(ab) + + # Add text annotation below image + text_label = f"V:{values_np[i]:.1f} R:{rewards_np[i]:.1f}" + ax.text(x, y - 1.0, text_label, ha='center', va='top', fontsize=8, color='red', + bbox=dict(boxstyle='round,pad=0.2', fc='yellow', alpha=0.5)) + 
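+            # Recompute the axes' data limits from the raw t-SNE coordinates so that the autoscale()
+            # call below includes every embedded point (AnnotationBbox artists added via add_artist
+            # do not update the data limits on their own).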
+ ax.update_datalim(tsne_results) + ax.autoscale() + + ax.set_title(f't-SNE of Latent States (Value as Color) at Step {step_counter}', fontsize=16) + ax.set_xlabel('t-SNE dimension 1', fontsize=12) + ax.set_ylabel('t-SNE dimension 2', fontsize=12) + + # Add colorbar to explain background point colors + norm = plt.Normalize(values_np.min(), values_np.max()) + sm = plt.cm.ScalarMappable(cmap='viridis', norm=norm) + sm.set_array([]) + fig.colorbar(sm, ax=ax, label='Predicted Value') + + # Modified section: Check if file exists, add timestamp if it does + base_save_path = ( + f'/mnt/nfs/zhangjinouwen/puyuan/LightZero/zoo/atari/unizero_mspacman_analyze/' + f'tsne_with_vr_{self.config.optim_type}_step_{step_counter}.png' + ) + + # Check if file exists and determine final save path + if os.path.exists(base_save_path): + # If file already exists, generate timestamp and append to filename + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + path_root, path_ext = os.path.splitext(base_save_path) + save_path = f"{path_root}_{timestamp}{path_ext}" + print(f"File '{base_save_path}' already exists. Saving to new path with timestamp.") + else: + # If file doesn't exist, use original path + save_path = base_save_path + + # Save image + plt.savefig(save_path) + plt.close(fig) # Explicitly close figure object + print(f"t-SNE plot with V/R annotations saved to {save_path}") + def _get_final_norm(self, norm_option: str) -> nn.Module: """ Return the corresponding normalization module based on the specified normalization option. @@ -209,7 +471,7 @@ def custom_copy_kv_cache_to_shared_wm(self, src_kv: KeysValues) -> int: - index (:obj:`int`): The index in the shared pool where the KeysValues object is stored. """ src_kv_shape = src_kv._keys_values[0]._k_cache._cache.shape - + if self.shared_pool_wm[self.shared_pool_index_wm] is None: self.shared_pool_wm[self.shared_pool_index_wm] = KeysValues( src_kv_shape[0], # Number of elements (n) @@ -221,7 +483,7 @@ def custom_copy_kv_cache_to_shared_wm(self, src_kv: KeysValues) -> int: ) dst_kv = self.shared_pool_wm[self.shared_pool_index_wm] - + for src_layer, dst_layer in zip(src_kv._keys_values, dst_kv._keys_values): # Copy the key and value caches using torch.copy_() for efficient data transfer dst_layer._k_cache._cache.copy_(src_layer._k_cache._cache) @@ -264,7 +526,7 @@ def custom_copy_kv_cache_to_shared_recur(self, src_kv: KeysValues) -> int: dst_layer._v_cache._size = src_layer._v_cache._size index = self.shared_pool_index - self.shared_pool_index = (self.shared_pool_index + 1) % self.shared_pool_size + self.shared_pool_index = (self.shared_pool_index + 1) % self.shared_pool_size_recur return index @@ -280,7 +542,7 @@ def _initialize_config_parameters(self) -> None: self.gamma = self.config.gamma self.context_length = self.config.context_length self.dormant_threshold = self.config.dormant_threshold - self.analysis_dormant_ratio = self.config.analysis_dormant_ratio + self.analysis_dormant_ratio_weight_rank = self.config.analysis_dormant_ratio_weight_rank self.num_observations_tokens = self.config.tokens_per_block - 1 self.latent_recon_loss_weight = self.config.latent_recon_loss_weight self.perceptual_loss_weight = self.config.perceptual_loss_weight @@ -289,9 +551,52 @@ def _initialize_config_parameters(self) -> None: self.max_cache_size = self.config.max_cache_size self.env_num = self.config.env_num self.num_layers = self.config.num_layers - self.obs_per_embdding_dim = self.config.embed_dim self.sim_norm = SimNorm(simnorm_dim=self.group_size) + # 
==================== [NEW] Policy Stability Fix Options ==================== + # Load fix options from config (with defaults for backward compatibility) + self.use_policy_logits_clip = getattr(self.config, 'use_policy_logits_clip', False) + self.policy_logits_clip_method = getattr(self.config, 'policy_logits_clip_method', 'normalize_max') + self.policy_logits_clip_min = getattr(self.config, 'policy_logits_clip_min', -10.0) + self.policy_logits_clip_max = getattr(self.config, 'policy_logits_clip_max', 10.0) + self.policy_logits_soft_beta = getattr(self.config, 'policy_logits_soft_beta', 1.0) + self.policy_logits_adaptive_percentile = getattr(self.config, 'policy_logits_adaptive_percentile', 95) + + # Running statistics for adaptive clipping + if self.policy_logits_clip_method == 'adaptive': + self.register_buffer('policy_logits_running_max', torch.tensor(10.0)) + self.register_buffer('policy_logits_running_min', torch.tensor(-10.0)) + self.policy_logits_momentum = 0.99 + + # [NEW] Fix5: Temperature scaling for policy loss + self.use_policy_loss_temperature = getattr(self.config, 'use_policy_loss_temperature', False) + self.policy_loss_temperature = getattr(self.config, 'policy_loss_temperature', 1.0) + + # [NEW] Fix3: Check if target policy re-smooth is enabled (now deprecated in favor of Fix2) + use_target_policy_resmooth = getattr(self.config, 'use_target_policy_resmooth', False) + if use_target_policy_resmooth: + logging.warning( + "[DEPRECATED] use_target_policy_resmooth=True is deprecated! " + "Policy label smoothing should now be controlled by 'continuous_ls_eps' in policy config. " + "Fix3 (use_target_policy_resmooth) creates redundant smoothing with Fix2. " + "Please set use_target_policy_resmooth=False and use continuous_ls_eps instead." + ) + + # [NEW] Debug: Print configuration on initialization + if self.use_policy_logits_clip: + logging.info( + f"[Policy Logits Control] ENABLED\n" + f" Method: {self.policy_logits_clip_method}\n" + f" Range: [{self.policy_logits_clip_min}, {self.policy_logits_clip_max}]\n" + f" Soft Beta: {self.policy_logits_soft_beta if 'soft' in self.policy_logits_clip_method else 'N/A'}" + ) + else: + logging.warning(f"[Policy Logits Control] DISABLED! Logits may grow unbounded.") + + if self.use_policy_loss_temperature and self.policy_loss_temperature != 1.0: + logging.info(f"[Policy Loss Temperature] ENABLED: temperature={self.policy_loss_temperature}") + # ============================================================================= + def _initialize_patterns(self) -> None: """Initialize patterns for block masks.""" self.all_but_last_latent_state_pattern = torch.ones(self.config.tokens_per_block) @@ -301,12 +606,144 @@ def _initialize_patterns(self) -> None: self.value_policy_tokens_pattern = torch.zeros(self.config.tokens_per_block) self.value_policy_tokens_pattern[-2] = 1 + def _apply_policy_logits_control(self, logits_policy: torch.Tensor) -> torch.Tensor: + """ + Apply policy logits control using various methods to prevent explosion. + + This method implements multiple strategies to constrain policy logits: + 1. 'hard': Hard clamp (torch.clamp) - Simple but gradients die at boundaries + 2. 'soft_tanh': Soft clamp using tanh - Smooth, gradients never zero + 3. 'soft_sigmoid': Soft clamp using sigmoid - Similar to tanh but different curve + 4. 'normalize_max': Subtract max then clamp - Preserves relative order, safer + 5. 'normalize_mean': Subtract mean then clamp - Centers distribution + 6. 'adaptive': Adaptive clipping based on running statistics + 7. 
'none': No clipping + + Arguments: + - logits_policy (:obj:`torch.Tensor`): Raw policy logits from head_policy + Shape: [batch_size, num_steps, action_dim] or [batch_size * num_steps, action_dim] + + Returns: + - torch.Tensor: Controlled policy logits with the same shape + + Examples: + >>> logits = torch.randn(32, 10, 6) * 20 # Large logits + >>> controlled = self._apply_policy_logits_control(logits) + >>> assert controlled.abs().max() <= self.policy_logits_clip_max + """ + if not self.use_policy_logits_clip or self.policy_logits_clip_method == 'none': + return logits_policy + + method = self.policy_logits_clip_method + clip_min = self.policy_logits_clip_min + clip_max = self.policy_logits_clip_max + + # ==================== Method 1: Hard Clamp ==================== + if method == 'hard': + # Simple hard clipping + # Pros: Simple, fast + # Cons: Gradients become zero outside [clip_min, clip_max] + return torch.clamp(logits_policy, min=clip_min, max=clip_max) + + # ==================== Method 2: Soft Tanh Clamp ==================== + elif method == 'soft_tanh': + # Soft clamp using tanh function: clip_max * tanh(x / clip_max) + # Pros: Gradients never zero, smooth transition + # Cons: Slightly more computation + # When x is small: tanh(x) ≈ x, so output ≈ x (unchanged) + # When x is large: tanh(x) → 1, so output → clip_max (smoothly saturates) + C = clip_max # Use positive bound as scale + beta = self.policy_logits_soft_beta # Smoothness parameter + return C * torch.tanh(logits_policy / (C * beta)) + + # ==================== Method 3: Soft Sigmoid Clamp ==================== + elif method == 'soft_sigmoid': + # Soft clamp using sigmoid: maps (-∞, ∞) to (clip_min, clip_max) + # Formula: clip_min + (clip_max - clip_min) * sigmoid(x / beta) + # Pros: Smooth, bounded + # Cons: Compresses entire range, may lose relative ordering + beta = self.policy_logits_soft_beta + range_size = clip_max - clip_min + return clip_min + range_size * torch.sigmoid(logits_policy / beta) + + # ==================== Method 4: Normalize Max + Hard Clamp ==================== + elif method == 'normalize_max': + # Subtract max value first (exploits softmax translation invariance) + # softmax(x) = softmax(x - c) for any constant c + # By subtracting max, we ensure the largest logit is 0, others are negative + # Then apply hard clamp (mainly affects the negative tail) + # Pros: Preserves relative ordering, safer than pure hard clamp + # Cons: Still has gradient issues for very negative values + logits_normalized = logits_policy - logits_policy.max(dim=-1, keepdim=True)[0].detach() + return torch.clamp(logits_normalized, min=clip_min, max=clip_max) + + # ==================== Method 5: Normalize Mean + Hard Clamp ==================== + elif method == 'normalize_mean': + # Subtract mean (centers the distribution) + # Pros: Centers logits around 0, prevents drift + # Cons: May change relative probabilities more than normalize_max + logits_normalized = logits_policy - logits_policy.mean(dim=-1, keepdim=True).detach() + return torch.clamp(logits_normalized, min=clip_min, max=clip_max) + + # ==================== Method 6: Adaptive Clipping ==================== + elif method == 'adaptive': + # Dynamically adjust clipping thresholds based on running statistics + # Update running stats (only during training) + if self.training: + with torch.no_grad(): + # Compute percentile-based bounds + flat_logits = logits_policy.view(-1) + percentile = self.policy_logits_adaptive_percentile + current_max = torch.quantile(flat_logits, percentile 
/ 100.0) + current_min = torch.quantile(flat_logits, (100 - percentile) / 100.0) + + # Update running statistics with momentum + self.policy_logits_running_max = ( + self.policy_logits_momentum * self.policy_logits_running_max + + (1 - self.policy_logits_momentum) * current_max + ) + self.policy_logits_running_min = ( + self.policy_logits_momentum * self.policy_logits_running_min + + (1 - self.policy_logits_momentum) * current_min + ) + + # Use running stats for clipping + adaptive_max = torch.clamp(self.policy_logits_running_max, max=clip_max) + adaptive_min = torch.clamp(self.policy_logits_running_min, min=clip_min) + return torch.clamp(logits_policy, min=adaptive_min, max=adaptive_max) + + else: + raise ValueError( + f"Unknown policy_logits_clip_method: {method}. " + f"Valid options: 'hard', 'soft_tanh', 'soft_sigmoid', 'normalize_max', " + f"'normalize_mean', 'adaptive', 'none'" + ) + def _create_head(self, block_mask: torch.Tensor, output_dim: int, norm_layer=None) -> Head: """Create head modules for the transformer.""" modules = [ - nn.Linear(self.config.embed_dim, self.config.embed_dim), + nn.LayerNorm(self.config.embed_dim), # Core optimization! # TODO + nn.Linear(self.config.embed_dim, self.config.embed_dim*4), + nn.LayerNorm(self.config.embed_dim*4), # 2. New! Stabilize internal activations + nn.GELU(approximate='tanh'), + nn.Linear(self.config.embed_dim*4, output_dim) + ] + if norm_layer: + modules.append(norm_layer) + return Head( + max_blocks=self.config.max_blocks, + block_mask=block_mask, + head_module=nn.Sequential(*modules) + ) + + def _create_head_for_latent(self, block_mask: torch.Tensor, output_dim: int, norm_layer=None) -> Head: + """Create head modules for the transformer.""" + modules = [ + nn.LayerNorm(self.config.embed_dim), # Core optimization! # TODO + nn.Linear(self.config.embed_dim, self.config.embed_dim*4), + nn.LayerNorm(self.config.embed_dim*4), # 2. New! 
Stabilize internal activations nn.GELU(approximate='tanh'), - nn.Linear(self.config.embed_dim, output_dim) + nn.Linear(self.config.embed_dim*4, output_dim) ] if norm_layer: modules.append(norm_layer) @@ -351,21 +788,22 @@ def _initialize_last_layer(self) -> None: nn.init.zeros_(layer.bias) break - def _initialize_cache_structures(self) -> None: - """Initialize cache structures for past keys and values.""" - from collections import defaultdict - self.past_kv_cache_recurrent_infer = defaultdict(dict) - self.past_kv_cache_init_infer_envs = [defaultdict(dict) for _ in range(self.env_num)] - self.keys_values_wm_list = [] - self.keys_values_wm_size_list = [] def _initialize_projection_input_dim(self) -> None: """Initialize the projection input dimension based on the number of observation tokens.""" if self.num_observations_tokens == 16: self.projection_input_dim = 128 elif self.num_observations_tokens == 1: - self.projection_input_dim = self.obs_per_embdding_dim + # self.projection_input_dim = self.config.embed_dim + if self.task_embed_option == "concat_task_embed": + self.projection_input_dim = self.config.embed_dim - self.task_embed_dim + elif self.task_embed_option == "register_task_embed": + self.projection_input_dim = self.config.embed_dim + elif self.task_embed_option == "add_task_embed": + self.projection_input_dim = self.config.embed_dim + else: + self.projection_input_dim = self.config.embed_dim def _initialize_statistics(self) -> None: """Initialize counters for hit count and query count statistics.""" @@ -421,6 +859,7 @@ def precompute_pos_emb_diff_kv(self): self.pos_emb_diff_k.append(layer_pos_emb_diff_k) self.pos_emb_diff_v.append(layer_pos_emb_diff_v) + #@profile def _get_positional_embedding(self, layer, attn_type) -> torch.Tensor: """ Helper function to get positional embedding for a given layer and attention type. @@ -626,11 +1065,20 @@ def forward( logits_observations = self.head_observations(x, num_steps=num_steps, prev_steps=prev_steps) logits_rewards = self.head_rewards(x, num_steps=num_steps, prev_steps=prev_steps) logits_policy = self.head_policy(x, num_steps=num_steps, prev_steps=prev_steps) + + # ==================== [NEW] Advanced Policy Logits Control ==================== + # Apply configurable policy logits control to prevent explosion + # Multiple methods available: hard, soft_tanh, soft_sigmoid, normalize_max, etc. + if self.use_policy_logits_clip: + logits_policy = self._apply_policy_logits_control(logits_policy) + # ================================================================================ + logits_value = self.head_value(x, num_steps=num_steps, prev_steps=prev_steps) # The 'logits_ends' is intentionally set to None. return WorldModelOutput(x, logits_observations, logits_rewards, None, logits_policy, logits_value) + #@profile def _add_position_embeddings(self, embeddings, prev_steps, num_steps, kvcache_independent, is_init_infer, valid_context_lengths): """ @@ -659,6 +1107,7 @@ def _add_position_embeddings(self, embeddings, prev_steps, num_steps, kvcache_in valid_context_lengths + torch.arange(num_steps, device=self.device)).unsqueeze(1) return embeddings + position_embeddings + #@profile def _process_obs_act_combined_cont(self, obs_embeddings_or_act_tokens, prev_steps): """ Process combined observation embeddings and action tokens. 
@@ -698,6 +1147,7 @@ def _process_obs_act_combined_cont(self, obs_embeddings_or_act_tokens, prev_step return_result += self.pos_emb(prev_steps + torch.arange(num_steps, device=self.device)) return return_result, num_steps + #@profile def _process_obs_act_combined(self, obs_embeddings_or_act_tokens, prev_steps): """ Process combined observation embeddings and action tokens. @@ -750,6 +1200,7 @@ def _transformer_pass(self, sequences, past_keys_values, kvcache_independent, va else: return self.transformer(sequences, past_keys_values, valid_context_lengths=valid_context_lengths, start_pos=start_pos) + #@profile @torch.no_grad() def reset_for_initial_inference(self, obs_act_dict: torch.FloatTensor, start_pos: int = 0) -> torch.FloatTensor: """ @@ -784,6 +1235,7 @@ def reset_for_initial_inference(self, obs_act_dict: torch.FloatTensor, start_pos return outputs_wm, self.latent_state + #@profile @torch.no_grad() def wm_forward_for_initial_infererence(self, last_obs_embeddings: torch.LongTensor, batch_action=None, @@ -832,19 +1284,34 @@ def wm_forward_for_initial_infererence(self, last_obs_embeddings: torch.LongTens # Compute hash value using latent state for a single environment cache_key = hash_state(state_single_env.view(-1).cpu().numpy()) # last_obs_embeddings[i] is torch.Tensor + # ==================== Storage Layer Integration ==================== # Retrieve cached value - cache_index = self.past_kv_cache_init_infer_envs[i].get(cache_key) - if cache_index is not None: - matched_value = self.shared_pool_init_infer[i][cache_index] + if self.use_new_cache_manager: + # NEW SYSTEM: Use KVCacheManager + matched_value = self.kv_cache_manager.get_init_cache(env_id=i, cache_key=cache_key) else: - matched_value = None + # OLD SYSTEM: Use legacy cache dictionaries + cache_index = self.past_kv_cache_init_infer_envs[i].get(cache_key) + if cache_index is not None: + matched_value = self.shared_pool_init_infer[i][cache_index] + else: + matched_value = None + # ============================================================================= self.root_total_query_cnt += 1 if matched_value is not None: # If a matching value is found, add it to the list self.root_hit_cnt += 1 - # NOTE: deepcopy is needed because forward modifies matched_value in place - self.keys_values_wm_list.append(self.custom_copy_kv_cache_to_shared_wm(matched_value)) + # ==================== BUG FIX: Cache Corruption Prevention ==================== + # Perform a deep copy because the transformer's forward pass modifies matched_value in-place. 
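+                    # Without a private copy, every entry appended to keys_values_wm_list would alias the
+                    # same pooled tensors, and the next in-place cache update would silently corrupt
+                    # previously collected root caches.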
+ if self.use_new_cache_manager: + # NEW SYSTEM: Use KeysValues.clone() for deep copy + cached_copy = matched_value.clone() + self.keys_values_wm_list.append(cached_copy) + else: + # OLD SYSTEM: Use custom_copy_kv_cache_to_shared_wm + self.keys_values_wm_list.append(self.custom_copy_kv_cache_to_shared_wm(matched_value)) + # ============================================================================= self.keys_values_wm_size_list.append(matched_value.size) else: # Reset using zero values @@ -891,7 +1358,7 @@ def wm_forward_for_initial_infererence(self, last_obs_embeddings: torch.LongTens # ================ calculate the target value in Train phase or calculate the target policy in reanalyze phase ================ # [192, 16, 64] -> [32, 6, 16, 64] last_obs_embeddings = last_obs_embeddings.contiguous().view(batch_action.shape[0], -1, num_observations_tokens, - self.obs_per_embdding_dim) # (BL, K) for unroll_step=1 + self.config.embed_dim) # (BL, K) for unroll_step=1 last_obs_embeddings = last_obs_embeddings[:, :-1, :] batch_action = torch.from_numpy(batch_action).to(last_obs_embeddings.device) @@ -922,6 +1389,7 @@ def wm_forward_for_initial_infererence(self, last_obs_embeddings: torch.LongTens return outputs_wm + #@profile @torch.no_grad() def forward_initial_inference(self, obs_act_dict, start_pos: int = 0): """ @@ -934,11 +1402,20 @@ def forward_initial_inference(self, obs_act_dict, start_pos: int = 0): """ # UniZero has context in the root node outputs_wm, latent_state = self.reset_for_initial_inference(obs_act_dict, start_pos) - self.past_kv_cache_recurrent_infer.clear() + + # ==================== BUG FIX: Clear Cache Using Correct API ==================== + if self.use_new_cache_manager: + # NEW SYSTEM: Clear recurrent cache using KVCacheManager + self.kv_cache_manager.clear_recur_cache() + else: + # OLD SYSTEM: Clear using legacy attribute + self.past_kv_cache_recurrent_infer.clear() + # ============================================================================= return (outputs_wm.output_sequence, latent_state, outputs_wm.logits_rewards, outputs_wm.logits_policy, outputs_wm.logits_value) + #@profile @torch.no_grad() def forward_recurrent_inference(self, state_action_history, simulation_index=0, search_depth=[], start_pos: int = 0): @@ -1025,6 +1502,7 @@ def forward_recurrent_inference(self, state_action_history, simulation_index=0, return (outputs_wm.output_sequence, self.latent_state, reward, outputs_wm.logits_policy, outputs_wm.logits_value) + #@profile def trim_and_pad_kv_cache(self, is_init_infer=True) -> list: """ Adjusts the key-value cache for each environment to ensure they all have the same size. 
@@ -1077,6 +1555,7 @@ def trim_and_pad_kv_cache(self, is_init_infer=True) -> list: return self.keys_values_wm_size_list + #@profile def update_cache_context(self, latent_state, is_init_infer=True, simulation_index=0, search_depth=[], valid_context_lengths=None): """ @@ -1210,16 +1689,72 @@ def update_cache_context(self, latent_state, is_init_infer=True, simulation_inde self.keys_values_wm_single_env._keys_values[layer]._k_cache._size = context_length - 3 self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = context_length - 3 - if is_init_infer: - # Store the latest key-value cache for initial inference - cache_index = self.custom_copy_kv_cache_to_shared_init_envs(self.keys_values_wm_single_env, i) - self.past_kv_cache_init_infer_envs[i][cache_key] = cache_index + if self.use_new_cache_manager: + # NEW SYSTEM: Use KVCacheManager for cache storage + # ==================== BUG FIX: Deep Copy Before Storage ==================== + # CRITICAL: Must clone before storing to prevent cache corruption. + # self.keys_values_wm_single_env is a shared object that gets modified. + # Without cloning, all cache entries would point to the same object, + # causing incorrect KV retrieval and training divergence. + kv_cache_to_store = self.keys_values_wm_single_env.clone() + # ============================================================================= + + if is_init_infer: + # Store to per-environment init cache pool + # Note: KVCacheManager automatically handles eviction logic (FIFO/LRU) + self.kv_cache_manager.set_init_cache( + env_id=i, + cache_key=cache_key, + kv_cache=kv_cache_to_store # Store cloned copy, not reference + ) + else: + # Store to global recurrent cache pool + self.kv_cache_manager.set_recur_cache( + cache_key=cache_key, + kv_cache=kv_cache_to_store # Store cloned copy, not reference + ) else: - # Store the latest key-value cache for recurrent inference - cache_index = self.custom_copy_kv_cache_to_shared_recur(self.keys_values_wm_single_env) - self.past_kv_cache_recurrent_infer[cache_key] = cache_index + # OLD SYSTEM: Use legacy cache with manual eviction + if is_init_infer: + # ==================== Active Eviction Fix Logic ==================== + # 1. Get the physical index that will be overwritten + index_to_write = self.shared_pool_index_init_envs[i] + # 2. Use auxiliary list to find the old key stored at this index + old_key_to_evict = self.pool_idx_to_key_map_init_envs[i][index_to_write] + # 3. If old key exists, delete it from the main cache map + if old_key_to_evict is not None: + # Ensure the key to be deleted actually exists to avoid unexpected errors + if old_key_to_evict in self.past_kv_cache_init_infer_envs[i]: + del self.past_kv_cache_init_infer_envs[i][old_key_to_evict] + + # Now it's safe to write new data + cache_index = self.custom_copy_kv_cache_to_shared_init_envs(self.keys_values_wm_single_env, i) + + # 4. Update both the main cache map and auxiliary list with new mapping + self.past_kv_cache_init_infer_envs[i][cache_key] = cache_index + self.pool_idx_to_key_map_init_envs[i][index_to_write] = cache_key + else: + # ==================== RECURRENT INFER FIX ==================== + # 1. Get the physical index that will be overwritten + index_to_write = self.shared_pool_index + # 2. Use auxiliary list to find the old key stored at this index + old_key_to_evict = self.pool_idx_to_key_map_recur_infer[index_to_write] + # 3. 
If old key exists, delete it from the main cache map + if old_key_to_evict is not None: + if old_key_to_evict in self.past_kv_cache_recurrent_infer: + del self.past_kv_cache_recurrent_infer[old_key_to_evict] + + # 4. Now it's safe to write new data + cache_index = self.custom_copy_kv_cache_to_shared_recur(self.keys_values_wm_single_env) + # 5. Update both the main cache map and auxiliary list with new mapping + self.past_kv_cache_recurrent_infer[cache_key] = cache_index + self.pool_idx_to_key_map_recur_infer[index_to_write] = cache_key + # ============================================================================= + + + #@profile def retrieve_or_generate_kvcache(self, latent_state: list, ready_env_num: int, simulation_index: int = 0, start_pos: int = 0) -> list: """ @@ -1245,22 +1780,47 @@ def retrieve_or_generate_kvcache(self, latent_state: list, ready_env_num: int, # TODO: check if this is correct matched_value = None else: - # Try to retrieve the cached value from past_kv_cache_init_infer_envs - cache_index = self.past_kv_cache_init_infer_envs[index].get(cache_key) - if cache_index is not None: - matched_value = self.shared_pool_init_infer[index][cache_index] + if self.use_new_cache_manager: + # NEW SYSTEM: Use KVCacheManager's hierarchical_get for unified lookup + matched_value = self.kv_cache_manager.hierarchical_get(env_id=index, cache_key=cache_key) + + # Log cache miss (statistics are automatically handled by KVCacheManager) + if matched_value is None: + logging.debug(f"[NEW CACHE MISS] Not found for key={cache_key} in both init and recurrent cache.") else: - matched_value = None + # OLD SYSTEM: Use legacy cache dictionaries and pools + # Try to retrieve the cached value from past_kv_cache_init_infer_envs + cache_index = self.past_kv_cache_init_infer_envs[index].get(cache_key) + if cache_index is not None: + matched_value = self.shared_pool_init_infer[index][cache_index] + else: + matched_value = None - # If not found, try to retrieve from past_kv_cache_recurrent_infer - if matched_value is None: - matched_value = self.shared_pool_recur_infer[self.past_kv_cache_recurrent_infer.get(cache_key)] + # Only try to find from recurrent_infer cache if not found in init_infer + if matched_value is None: + # Safely get the index from dictionary, it may return None + recur_cache_index = self.past_kv_cache_recurrent_infer.get(cache_key) + # Only use it to retrieve value from physical pool if the index is valid (not None) + if recur_cache_index is not None: + matched_value = self.shared_pool_recur_infer[recur_cache_index] + + if recur_cache_index is None: + logging.debug(f"[OLD CACHE MISS] Not found for key={cache_key} in recurrent infer. Generating new cache.") + # ============================================================================= if matched_value is not None: # If a matching cache is found, add it to the lists self.hit_count += 1 - # Perform a deep copy because the transformer's forward pass might modify matched_value in-place - self.keys_values_wm_list.append(self.custom_copy_kv_cache_to_shared_wm(matched_value)) + # Perform a deep copy because the transformer's forward pass modifies matched_value in-place. + # Without cloning, the original cache in init_pool or recur_pool would be polluted, + # causing incorrect predictions in subsequent queries. 
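The active-eviction fix above pairs each fixed-size pool with a reverse index-to-key list, so the dictionary entry pointing at a slot can be removed before the slot is overwritten and stale keys can never resolve to recycled physical storage. A self-contained sketch of that bookkeeping; the class and method names are hypothetical:

```python
class FixedSizeKVPool:
    """Ring-buffer pool with a key -> slot map plus a slot -> key reverse map,
    mirroring the eviction bookkeeping in the hunk above (names are illustrative)."""

    def __init__(self, capacity: int):
        self.pool = [None] * capacity          # physical storage
        self.key_to_index = {}                 # main cache map
        self.index_to_key = [None] * capacity  # reverse lookup used for eviction
        self.write_index = 0

    def put(self, key, value):
        idx = self.write_index
        # 1. Evict whatever key currently points at this slot.
        old_key = self.index_to_key[idx]
        if old_key is not None and old_key in self.key_to_index:
            del self.key_to_index[old_key]
        # 2. Write the new entry and update both maps.
        self.pool[idx] = value
        self.key_to_index[key] = idx
        self.index_to_key[idx] = key
        self.write_index = (idx + 1) % len(self.pool)

    def get(self, key):
        idx = self.key_to_index.get(key)
        return None if idx is None else self.pool[idx]
```

The per-environment init pools and the global recurrent pool in the patch follow this same write path, with `shared_pool_index_init_envs[i]` and `shared_pool_index` playing the role of `write_index`.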
+ if self.use_new_cache_manager: + # NEW SYSTEM: Use KeysValues.clone() for deep copy + cached_copy = matched_value.clone() + self.keys_values_wm_list.append(cached_copy) + else: + # OLD SYSTEM: Use custom_copy_kv_cache_to_shared_wm + self.keys_values_wm_list.append(self.custom_copy_kv_cache_to_shared_wm(matched_value)) self.keys_values_wm_size_list.append(matched_value.size) else: # If no matching cache is found, generate a new one using zero reset @@ -1295,27 +1855,58 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar # Encode observations into latent state representations obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch['observations']) - # ========= for visual analysis ========= - # Uncomment the lines below for visual analysis in Pong - # self.plot_latent_tsne_each_and_all_for_pong(obs_embeddings, suffix='pong_H10_H4_tsne') - # self.save_as_image_with_timestep(batch['observations'], suffix='pong_H10_H4_tsne') - # Uncomment the lines below for visual analysis in visual match - # self.plot_latent_tsne_each_and_all(obs_embeddings, suffix='visual_match_memlen1-60-15_tsne') - # self.save_as_image_with_timestep(batch['observations'], suffix='visual_match_memlen1-60-15_tsne') + # ======================== Logging for Analysis ======================== + # This block calculates various metrics for model analysis if the corresponding config flag is enabled. + # These metrics help in debugging and understanding model behavior during training. + if self.analysis_dormant_ratio_weight_rank: + # --- Dormant Ratio Calculation --- + # Calculate the dormant ratio of the encoder to monitor neuron activity. + shape = batch['observations'].shape # Original shape, e.g., (B, T, C, H, W) + # Reshape observations to create a single large batch for the encoder. + # E.g., (32, 5, 3, 64, 64) -> (160, 3, 64, 64) + inputs = batch['observations'].contiguous().view(-1, *shape[-3:]) + + dormant_ratio_encoder_dict = calculate_dormant_ratio( + self.tokenizer.encoder, inputs.detach(), dormant_threshold=self.dormant_threshold + ) + dormant_ratio_encoder = dormant_ratio_encoder_dict['global'] + + # --- Average Weight Magnitude Calculation --- + # Calculate the global average absolute weight magnitude for different model components. + # This is a useful metric for monitoring training stability. + avg_weight_mag_encoder = compute_average_weight_magnitude(self.tokenizer.encoder) + avg_weight_mag_transformer = compute_average_weight_magnitude(self.transformer) + avg_weight_mag_head = compute_average_weight_magnitude(self.head_dict) + + # --- Effective Rank Calculation --- + # Calculate the effective rank of representations from specific layers in the encoder. + # This metric helps analyze the dimensionality and information content of the learned features. + # The 'representation_layer_name' argument specifies the target layer within the model's named modules. + + # Effective rank for the final linear layer of the encoder. + e_rank_last_linear = compute_effective_rank( + self.tokenizer.encoder, inputs, representation_layer_name="last_linear" + ) + # Effective rank for the SimNorm layer of the encoder. 
+ e_rank_sim_norm = compute_effective_rank( + self.tokenizer.encoder, inputs, representation_layer_name="sim_norm" + ) - - # ========= logging for analysis ========= - if self.analysis_dormant_ratio: - # Calculate dormant ratio of the encoder - shape = batch['observations'].shape # (..., C, H, W) - inputs = batch['observations'].contiguous().view(-1, *shape[-3:]) # (32,5,3,64,64) -> (160,3,64,64) - dormant_ratio_encoder = cal_dormant_ratio(self.tokenizer.representation_network, inputs.detach(), - percentage=self.dormant_threshold) - self.past_kv_cache_recurrent_infer.clear() + # ==================== Clear Cache Using Correct API ==================== + if self.use_new_cache_manager: + self.kv_cache_manager.clear_recur_cache() + else: + self.past_kv_cache_recurrent_infer.clear() + # ============================================================================= self.keys_values_wm_list.clear() torch.cuda.empty_cache() else: dormant_ratio_encoder = torch.tensor(0.) + avg_weight_mag_encoder = torch.tensor(0.) + avg_weight_mag_transformer = torch.tensor(0.) + avg_weight_mag_head = torch.tensor(0.) + e_rank_last_linear = torch.tensor(0.) + e_rank_sim_norm = torch.tensor(0.) # Calculate the L2 norm of the latent state roots latent_state_l2_norms = torch.norm(obs_embeddings, p=2, dim=2).mean() @@ -1329,37 +1920,69 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar # Forward pass to obtain predictions for observations, rewards, and policies outputs = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)}, start_pos=start_pos) + # Get intermediate tensor x from model output and detach computation graph + intermediate_tensor_x = outputs.output_sequence.detach() + + global_step = kwargs.get('global_step', 0) + if global_step > 0 and global_step % 100000000000 == 0: # TODO + + with torch.no_grad(): + # Convert logits to scalar values + # Note: outputs shape is (B, L, E), we need to reshape + batch_size, seq_len = batch['actions'].shape[0], batch['actions'].shape[1] + + pred_val_logits = outputs.logits_value.view(batch_size * seq_len, -1) + pred_rew_logits = outputs.logits_rewards.view(batch_size * seq_len, -1) + + scalar_values = inverse_scalar_transform_handle(pred_val_logits).squeeze(-1) + scalar_rewards = inverse_scalar_transform_handle(pred_rew_logits).squeeze(-1) + + self._analyze_latent_representation( + latent_states=obs_embeddings, + timesteps=batch['timestep'], + game_states=batch['observations'], + predicted_values=scalar_values, + predicted_rewards=scalar_rewards, + step_counter=global_step + ) + + if self.config.use_priority: + # Calculate value_priority, similar to MuZero. + with torch.no_grad(): + # 1. Get the predicted value logits for the first step of the sequence (t=0). + # The shape is (B, support_size). + predicted_value_logits_step0 = outputs.logits_value[:, 0, :] + + # 2. Convert the categorical prediction to a scalar value. + # The shape becomes (B, 1). + predicted_scalar_value_step0 = inverse_scalar_transform_handle(predicted_value_logits_step0) + + # 3. Get the target scalar value for the first step from the batch. + # The shape is (B, num_unroll_steps), so we take the first column. + target_scalar_value_step0 = batch['scalar_target_value'][:, 0] + + # 4. Calculate the L1 loss (absolute difference) between prediction and target. + # This is the priority. We use reduction='none' to get per-sample priorities. 
+ value_priority = F.l1_loss(predicted_scalar_value_step0.squeeze(-1), target_scalar_value_step0, reduction='none') + else: + value_priority = torch.tensor(0.) + if self.obs_type == 'image': - # Reconstruct observations from latent state representations - # reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings) - - # ========== for visualization ========== - # Uncomment the lines below for visual analysis - # original_images, reconstructed_images = batch['observations'], reconstructed_images - # target_policy = batch['target_policy'] - # target_predict_value = inverse_scalar_transform_handle(batch['target_value'].reshape(-1, 101)).reshape( - # batch['observations'].shape[0], batch['observations'].shape[1], 1) - # true_rewards = inverse_scalar_transform_handle(batch['rewards'].reshape(-1, 101)).reshape( - # batch['observations'].shape[0], batch['observations'].shape[1], 1) - # ========== for visualization ========== - - # ========== Calculate reconstruction loss and perceptual loss ============ - # latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # NOTE: for stack=1 - # perceptual_loss = self.tokenizer.perceptual_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # NOTE: for stack=1 - - latent_recon_loss = self.latent_recon_loss - perceptual_loss = self.perceptual_loss + if self.config.latent_recon_loss_weight > 0: + # Reconstruct observations from latent state representations + reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings) + + # Calculate reconstruction loss and perceptual loss + latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # NOTE: for stack=1 + perceptual_loss = self.tokenizer.perceptual_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # NOTE: for stack=1 + else: + # TODO: + latent_recon_loss = self.latent_recon_loss + perceptual_loss = self.perceptual_loss elif self.obs_type == 'vector': perceptual_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype) - - # Reconstruct observations from latent state representations - # reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings.reshape(-1, self.embed_dim)) - - # # Calculate reconstruction loss - # latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 25), - # reconstructed_images) latent_recon_loss = self.latent_recon_loss elif self.obs_type == 'text': @@ -1390,49 +2013,29 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar latent_recon_loss = self.latent_recon_loss elif self.obs_type == 'image_memory': - # Reconstruct observations from latent state representations - # reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings) - # original_images, reconstructed_images = batch['observations'], reconstructed_images - - # ========== for visualization ========== - # Uncomment the lines below for visual analysis - # target_policy = batch['target_policy'] - # target_predict_value = inverse_scalar_transform_handle(batch['target_value'].reshape(-1, 101)).reshape( - # batch['observations'].shape[0], batch['observations'].shape[1], 1) - # true_rewards = inverse_scalar_transform_handle(batch['rewards'].reshape(-1, 101)).reshape( - # batch['observations'].shape[0], batch['observations'].shape[1], 1) - # ========== for visualization ========== - - # Calculate reconstruction loss and perceptual loss 
- # latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 3, 5, 5), - # reconstructed_images) latent_recon_loss = self.latent_recon_loss perceptual_loss = self.perceptual_loss - # ========= logging for analysis ========= - if self.analysis_dormant_ratio: + # ========= Logging for analysis ========= + if self.analysis_dormant_ratio_weight_rank: # Calculate dormant ratio of the world model - dormant_ratio_world_model = cal_dormant_ratio(self, { + dormant_ratio_world_model = calculate_dormant_ratio(self, { 'obs_embeddings_and_act_tokens': (obs_embeddings.detach(), act_tokens.detach())}, - percentage=self.dormant_threshold) - self.past_kv_cache_recurrent_infer.clear() + dormant_threshold=self.dormant_threshold) + dormant_ratio_transformer = dormant_ratio_world_model['transformer'] + dormant_ratio_head = dormant_ratio_world_model['head'] + + # ==================== Clear Cache Using Correct API ==================== + if self.use_new_cache_manager: + self.kv_cache_manager.clear_recur_cache() + else: + self.past_kv_cache_recurrent_infer.clear() + # ============================================================================= self.keys_values_wm_list.clear() torch.cuda.empty_cache() else: - dormant_ratio_world_model = torch.tensor(0.) - - # ========== for visualization ========== - # Uncomment the lines below for visualization - # predict_policy = outputs.logits_policy - # predict_policy = F.softmax(outputs.logits_policy, dim=-1) - # predict_value = inverse_scalar_transform_handle(outputs.logits_value.reshape(-1, 101)).reshape(batch['observations'].shape[0], batch['observations'].shape[1], 1) - # predict_rewards = inverse_scalar_transform_handle(outputs.logits_rewards.reshape(-1, 101)).reshape(batch['observations'].shape[0], batch['observations'].shape[1], 1) - # import pdb; pdb.set_trace() - # visualize_reward_value_img_policy(original_images, reconstructed_images, target_predict_value, true_rewards, target_policy, predict_value, predict_rewards, predict_policy, not_plot_timesteps=[], suffix='pong_H10_H4_0613') - - # visualize_reward_value_img_policy(original_images, reconstructed_images, target_predict_value, true_rewards, target_policy, predict_value, predict_rewards, predict_policy, not_plot_timesteps=list(np.arange(4,60)), suffix='visual_match_memlen1-60-15/one_success_episode') - # visualize_reward_value_img_policy(original_images, reconstructed_images, target_predict_value, true_rewards, target_policy, predict_value, predict_rewards, predict_policy, not_plot_timesteps=list(np.arange(4,60)), suffix='visual_match_memlen1-60-15/one_fail_episode') - # ========== for visualization ========== + dormant_ratio_transformer = torch.tensor(0.) + dormant_ratio_head = torch.tensor(0.) 
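The analysis block above records three health metrics per component: a dormant ratio from `calculate_dormant_ratio`, an average weight magnitude from `compute_average_weight_magnitude`, and an effective rank from `compute_effective_rank`. Those helpers live in `lzero.model.utils` and their internals are not shown in this diff; the sketch below only illustrates the usual definitions of the latter two (mean absolute parameter value, and the entropy-based effective rank of Roy & Vetterli), so the real implementations may differ:

```python
import torch
import torch.nn as nn

def average_weight_magnitude(module: nn.Module) -> torch.Tensor:
    # Mean absolute value over every parameter of the module.
    params = [p.detach().abs().flatten() for p in module.parameters()]
    return torch.cat(params).mean() if params else torch.tensor(0.)

def effective_rank(features: torch.Tensor) -> torch.Tensor:
    # Effective rank of a (batch, dim) feature matrix: exp of the entropy of
    # the normalized singular-value distribution.
    s = torch.linalg.svdvals(features.detach().float())
    p = s / s.sum().clamp_min(1e-12)
    return torch.exp(-(p * p.clamp_min(1e-12).log()).sum())
```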
# For training stability, use target_tokenizer to compute the true next latent state representations with torch.no_grad(): @@ -1468,15 +2071,29 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar # assert not torch.isinf(loss_obs).any(), "loss_obs contains Inf values" # for name, param in self.tokenizer.encoder.named_parameters(): # print('name, param.mean(), param.std():', name, param.mean(), param.std()) + elif self.predict_latent_loss_type == 'cos_sim': + # Cosine Similarity Loss + # print("predict_latent_loss_type == 'cos_sim'") + cosine_sim_loss = 1 - F.cosine_similarity(logits_observations, labels_observations, dim=-1) + loss_obs = cosine_sim_loss # Apply mask to loss_obs mask_padding_expanded = batch['mask_padding'][:, 1:].contiguous().view(-1) loss_obs = (loss_obs * mask_padding_expanded) - # Compute labels for policy and value - labels_policy, labels_value = self.compute_labels_world_model_value_policy(batch['target_value'], - batch['target_policy'], - batch['mask_padding']) + # ==================== [NEW] Fix3: Load re-smooth options from config ==================== + use_target_policy_resmooth = getattr(self.config, 'use_target_policy_resmooth', False) + target_policy_resmooth_eps = getattr(self.config, 'target_policy_resmooth_eps', 0.05) + # ====================================================================================== + + # Compute labels for policy and value (with optional re-smoothing) + labels_policy, labels_value = self.compute_labels_world_model_value_policy( + batch['target_value'], + batch['target_policy'], + batch['mask_padding'], + use_target_policy_resmooth=use_target_policy_resmooth, + target_policy_resmooth_eps=target_policy_resmooth_eps + ) # Compute losses for rewards, policy, and value loss_rewards = self.compute_cross_entropy_loss(outputs, labels_rewards, batch, element='rewards') @@ -1497,10 +2114,6 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar loss_value = self.compute_cross_entropy_loss(outputs, labels_value, batch, element='value') - # ==== TODO: calculate the new priorities for each transition. 
==== - # value_priority = L1Loss(reduction='none')(labels_value.squeeze(-1), outputs['logits_value'][:, 0]) - # value_priority = value_priority.data.cpu().numpy() + 1e-6 - # Compute timesteps timesteps = torch.arange(batch['actions'].shape[1], device=batch['actions'].device) # Compute discount coefficients for each timestep @@ -1552,6 +2165,10 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar discounted_orig_policy_loss = (orig_policy_loss.view(-1, batch['actions'].shape[1]) * discounts).sum()/ batch['mask_padding'].sum() discounted_policy_entropy = (policy_entropy.view(-1, batch['actions'].shape[1]) * discounts).sum()/ batch['mask_padding'].sum() + # Add encoder output to return dictionary for external training loop access + # Using .detach() because this tensor is only used for subsequent clip operations and should not affect gradient computation + detached_obs_embeddings = obs_embeddings.detach() + if self.continuous_action_space: return LossWithIntermediateLosses( latent_recon_loss_weight=self.latent_recon_loss_weight, @@ -1569,11 +2186,24 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar middle_step_losses=middle_step_losses, last_step_losses=last_step_losses, dormant_ratio_encoder=dormant_ratio_encoder, - dormant_ratio_world_model=dormant_ratio_world_model, + dormant_ratio_transformer=dormant_ratio_transformer, + dormant_ratio_head=dormant_ratio_head, + avg_weight_mag_encoder = avg_weight_mag_encoder, + avg_weight_mag_transformer = avg_weight_mag_transformer, + avg_weight_mag_head = avg_weight_mag_head, + e_rank_last_linear = e_rank_last_linear, + e_rank_sim_norm = e_rank_sim_norm, latent_state_l2_norms=latent_state_l2_norms, policy_mu=mu, policy_sigma=sigma, target_sampled_actions=target_sampled_actions, + + value_priority=value_priority, + intermediate_tensor_x=intermediate_tensor_x, + obs_embeddings=detached_obs_embeddings, + logits_value=outputs.logits_value.detach(), + logits_reward=outputs.logits_rewards.detach(), + logits_policy=outputs.logits_policy.detach(), ) else: return LossWithIntermediateLosses( @@ -1592,8 +2222,20 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar middle_step_losses=middle_step_losses, last_step_losses=last_step_losses, dormant_ratio_encoder=dormant_ratio_encoder, - dormant_ratio_world_model=dormant_ratio_world_model, + dormant_ratio_transformer=dormant_ratio_transformer, + dormant_ratio_head=dormant_ratio_head, + avg_weight_mag_encoder = avg_weight_mag_encoder, + avg_weight_mag_transformer = avg_weight_mag_transformer, + avg_weight_mag_head = avg_weight_mag_head, + e_rank_last_linear = e_rank_last_linear, + e_rank_sim_norm = e_rank_sim_norm, latent_state_l2_norms=latent_state_l2_norms, + value_priority=value_priority, + intermediate_tensor_x=intermediate_tensor_x, + obs_embeddings=detached_obs_embeddings, + logits_value=outputs.logits_value.detach(), + logits_reward=outputs.logits_rewards.detach(), + logits_policy=outputs.logits_policy.detach(), ) @@ -1659,7 +2301,7 @@ def _calculate_policy_loss_cont_simple(self, outputs, batch: dict): return policy_loss, policy_entropy_loss, target_policy_entropy, target_sampled_actions, mu, sigma - def _calculate_policy_loss_cont(self, outputs, batch: dict) -> Tuple[torch.Tensor, torch.Tensor, float, torch.Tensor, torch.Tensor, torch.Tensor]: + def _calculate_policy_loss_cont(self, outputs, batch: dict, task_id=None) -> Tuple[torch.Tensor, torch.Tensor, float, torch.Tensor, torch.Tensor, torch.Tensor]: """ Calculate 
the policy loss for continuous actions. @@ -1674,9 +2316,12 @@ def _calculate_policy_loss_cont(self, outputs, batch: dict) -> Tuple[torch.Tenso - mu (:obj:`torch.Tensor`): The mean of the normal distribution. - sigma (:obj:`torch.Tensor`): The standard deviation of the normal distribution. """ - batch_size, num_unroll_steps, action_space_size = outputs.logits_policy.shape[ + if task_id is None: + batch_size, num_unroll_steps, action_space_size = outputs.logits_policy.shape[ 0], self.config.num_unroll_steps, self.config.action_space_size - + else: + batch_size, num_unroll_steps, action_space_size = outputs.logits_policy.shape[ + 0], self.config.num_unroll_steps, self.config.action_space_size_list[task_id] policy_logits_all = outputs.logits_policy mask_batch = batch['mask_padding'] child_sampled_actions_batch = batch['child_sampled_actions'] @@ -1718,6 +2363,8 @@ def _calculate_policy_loss_cont(self, outputs, batch: dict) -> Tuple[torch.Tenso # KL as projector target_log_prob_sampled_actions = torch.log(target_normalized_visit_count + 1e-6) + + # KL as projector policy_loss = -torch.sum( torch.exp(target_log_prob_sampled_actions.detach()) * log_prob_sampled_actions, 1 ) * mask_batch @@ -1738,9 +2385,15 @@ def compute_cross_entropy_loss(self, outputs, labels, batch, element='rewards'): logits = getattr(outputs, f'logits_{element}') + # ==================== TODO: Temperature Scaling for Policy ==================== + if element == 'policy' and self.use_policy_loss_temperature and self.policy_loss_temperature != 1.0: + # Apply temperature scaling to soften the distribution + logits = logits / self.policy_loss_temperature + # =================================================================================== + if torch.isnan(logits).any(): raise ValueError(f"NaN detected in outputs for batch {batch} and element '{element}'") - + if torch.isnan(labels).any(): raise ValueError(f"NaN detected in labels_value for batch {batch} and element '{element}'") @@ -1767,6 +2420,7 @@ def compute_cross_entropy_loss(self, outputs, labels, batch, element='rewards'): return loss + #@profile def compute_policy_entropy_loss(self, logits, mask): # Compute entropy of the policy probs = torch.softmax(logits, dim=1) @@ -1776,6 +2430,7 @@ def compute_policy_entropy_loss(self, logits, mask): entropy_loss = (entropy * mask) return entropy_loss + #@profile def compute_labels_world_model(self, obs_embeddings: torch.Tensor, rewards: torch.Tensor, ends: torch.Tensor, mask_padding: torch.BoolTensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # assert torch.all(ends.sum(dim=1) <= 1) # Each sequence sample should have at most one 'done' flag @@ -1795,11 +2450,23 @@ def compute_labels_world_model(self, obs_embeddings: torch.Tensor, rewards: torc return labels_observations, labels_rewards.view(-1, self.support_size), None + #@profile def compute_labels_world_model_value_policy(self, target_value: torch.Tensor, target_policy: torch.Tensor, - mask_padding: torch.BoolTensor) -> Tuple[torch.Tensor, torch.Tensor]: + mask_padding: torch.BoolTensor, + use_target_policy_resmooth: bool = False, + target_policy_resmooth_eps: float = 0.05) -> Tuple[torch.Tensor, torch.Tensor]: """ Compute labels for value and policy predictions. 
""" mask_fill = torch.logical_not(mask_padding) + # ==================== [NEW] Fix3: Re-smooth Target Policy ==================== + # Re-smooth target_policy to prevent extreme distributions in buffer + if use_target_policy_resmooth and target_policy_resmooth_eps > 0: + num_actions = target_policy.shape[-1] + uniform_dist = torch.ones_like(target_policy) / num_actions + target_policy = (1 - target_policy_resmooth_eps) * target_policy + \ + target_policy_resmooth_eps * uniform_dist + # ============================================================================= + # Fill the masked areas of policy mask_fill_policy = mask_fill.unsqueeze(-1).expand_as(target_policy) labels_policy = target_policy.masked_fill(mask_fill_policy, -100) @@ -1817,11 +2484,23 @@ def clear_caches(self): """ Clears the caches of the world model. """ - for kv_cache_dict_env in self.past_kv_cache_init_infer_envs: - kv_cache_dict_env.clear() - self.past_kv_cache_recurrent_infer.clear() - self.keys_values_wm_list.clear() - print(f'Cleared {self.__class__.__name__} past_kv_cache.') + if self.use_new_cache_manager: + # Use new KV cache manager's clear method + self.kv_cache_manager.clear_all() + print(f'Cleared {self.__class__.__name__} KV caches (NEW system).') + + # Optionally print stats before clearing + if hasattr(self.kv_cache_manager, 'get_stats_summary'): + stats = self.kv_cache_manager.get_stats_summary() + if stats.get('stats_enabled'): + logging.debug(f'Cache stats before clear: {stats}') + else: + # Use old cache clearing logic + for kv_cache_dict_env in self.past_kv_cache_init_infer_envs: + kv_cache_dict_env.clear() + self.past_kv_cache_recurrent_infer.clear() + self.keys_values_wm_list.clear() + print(f'Cleared {self.__class__.__name__} past_kv_cache (OLD system).') def __repr__(self) -> str: return "transformer-based latent world_model of UniZero" diff --git a/lzero/model/unizero_world_models/world_model_multitask.py b/lzero/model/unizero_world_models/world_model_multitask.py new file mode 100644 index 000000000..836a463c7 --- /dev/null +++ b/lzero/model/unizero_world_models/world_model_multitask.py @@ -0,0 +1,2069 @@ +import collections +import logging +import math +import os +from typing import Any, Dict, Optional, Tuple, Union + +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from ding.utils import get_rank +from einops import rearrange +from matplotlib.offsetbox import AnnotationBbox, OffsetImage +from matplotlib.patches import Patch +from sklearn.manifold import TSNE + +from lzero.model.common import SimNorm +from lzero.model.unizero_world_models.world_model import WorldModel +from lzero.model.utils import ( + calculate_dormant_ratio, + calculate_effective_rank, + compute_average_weight_magnitude, +) + +from .slicer import Head +from .tokenizer import Tokenizer +from .transformer import Transformer, TransformerConfig +from .utils import LossWithIntermediateLosses, WorldModelOutput, hash_state, init_weights + +# Set the logging level for the root logger +logging.getLogger().setLevel(logging.DEBUG) + + +class WorldModelMT(WorldModel): + """ + Overview: + The WorldModel class for the multi-task UniZero model. It is responsible for + predicting the next latent state, reward, policy, and value based on the + current latent state and action. This model is a scalable latent world model + composed of three main parts: a tokenizer, a transformer, and prediction heads. 
+ """ + + def __init__(self, config: TransformerConfig, tokenizer: Tokenizer) -> None: + """ + Overview: + Initializes the multi-task WorldModel. + Arguments: + - config (:obj:`TransformerConfig`): The configuration object for the transformer and world model. + - tokenizer (:obj:`Tokenizer`): The tokenizer for encoding observations. + """ + super().__init__(config, tokenizer) + self.tokenizer = tokenizer + self.config = config + + self.continuous_action_space = self.config.continuous_action_space + self.task_num = config.task_num + self.env_num = self.config.env_num + + # Whether to share prediction heads across tasks. + self.share_head = config.share_head + + self.device = torch.device('cuda' if torch.cuda.is_available() and self.config.device != 'cpu' else 'cpu') + print(f"self.device: {self.device}") + + # Positional embedding layer. + self.pos_emb = nn.Embedding(config.max_tokens, self.config.embed_dim, device=self.device) + print(f"self.pos_emb.weight.device: {self.pos_emb.weight.device}") + + # Task embedding setup. + self.use_task_embed = config.use_task_embed + self.task_embed_option = self.config.task_embed_option + self.task_embed_dim = config.task_embed_dim if hasattr(config, "task_embed_dim") else 96 + self.register_token_num = config.register_token_num if hasattr(config, "register_token_num") else 4 + + if self.task_embed_option == "register_task_embed": + # When using "register_task_embed", the positional encoding is not adjusted. + # Use a non-trainable, zero-initialized nn.Embedding for positional embeddings. + self.pos_emb = nn.Embedding(config.max_tokens, self.config.embed_dim, device=self.device) + nn.init.constant_(self.pos_emb.weight, 0.0) # Initialize with all zeros. + self.pos_emb.weight.requires_grad = False # Disable updates. + + # Precompute positional embedding differences for efficient inference. + self.precompute_pos_emb_diff_kv() + + self.sim_norm = SimNorm(simnorm_dim=self.config.group_size) + + # Configure embedding dimensions based on the task embedding strategy. + if self.task_embed_option == "concat_task_embed": + # TODO: Currently, with "concat_task_embed", self.pos_emb needs to be fixed at 0. + self.task_emb = nn.Embedding(self.task_num, self.task_embed_dim, max_norm=1) # TDMPC2 suggests max_norm=1. + self.obs_act_embed_dim = config.embed_dim - self.task_embed_dim + self.register_token_num = 0 + elif self.task_embed_option == "register_task_embed": + self.task_emb = nn.Embedding(self.task_num, config.embed_dim, max_norm=1) + self.obs_act_embed_dim = config.embed_dim + elif self.task_embed_option == "add_task_embed": + self.task_emb = nn.Embedding(self.task_num, config.embed_dim, max_norm=1) + self.obs_act_embed_dim = config.embed_dim + else: + self.task_emb = None + self.obs_act_embed_dim = config.embed_dim + self.register_token_num = 0 + + self.transformer = Transformer(self.config, self.task_emb) + + # --- Analysis and Logging Setup --- + self.analysis_dormant_ratio_interval = self.config.get('analysis_dormant_ratio_interval', 100) + self._analysis_step_counter = 0 + self.do_analysis = self.config.analysis_dormant_ratio_weight_rank + + self.analysis_tsne = self.config.get('analysis_tsne', False) + if self.analysis_tsne: + self.env_id_list = self.config.env_id_list + # Automatically generate short names for environments. + self.env_short_names = { + env_id: env_id.replace('NoFrameskip-v4', '') + for env_id in self.config.env_id_list + } + # Color mapping to ensure each task has a fixed color. 
+ self.num_tasks = len(self.env_id_list) + self.colors = self._generate_colors(self.num_tasks) + + # --- Prediction Head Initialization --- + self.head_policy_multi_task = nn.ModuleList() + self.head_value_multi_task = nn.ModuleList() + self.head_rewards_multi_task = nn.ModuleList() + self.head_observations_multi_task = nn.ModuleList() + + self.num_experts_in_moe_head = config.num_experts_in_moe_head + self.use_normal_head = config.use_normal_head + self.use_moe_head = config.use_moe_head + self.use_softmoe_head = config.use_softmoe_head + + self.to(self.device) + + # Initialize configuration parameters from the config object. + self._initialize_config_parameters() + self._initialize_patterns() + + self.hidden_size = config.embed_dim // config.num_heads + + # Initialize action embedding table based on action space type. + if self.continuous_action_space: + self.act_embedding_table = nn.ModuleList([ + nn.Sequential( + nn.Linear(config.action_space_size_list[task_id], self.obs_act_embed_dim, device=self.device, bias=False), + SimNorm(simnorm_dim=self.group_size) + ) for task_id in range(self.task_num) + ]) + else: + # For discrete action space. + self.act_embedding_table = nn.Embedding(config.action_space_size, self.obs_act_embed_dim, device=self.device) + print(f"self.act_embedding_table.weight.device: {self.act_embedding_table.weight.device}") + print(f'=' * 20) + print(f"self.obs_act_embed_dim: {self.obs_act_embed_dim}") + print(f'=' * 20) + + # ==================== [NEW] Policy Stability Fix Options ==================== + # Load fix options from config (with defaults for backward compatibility) + self.use_policy_logits_clip = getattr(self.config, 'use_policy_logits_clip', False) + self.policy_logits_clip_method = getattr(self.config, 'policy_logits_clip_method', 'normalize_max') + self.policy_logits_clip_min = getattr(self.config, 'policy_logits_clip_min', -10.0) + self.policy_logits_clip_max = getattr(self.config, 'policy_logits_clip_max', 10.0) + self.policy_logits_soft_beta = getattr(self.config, 'policy_logits_soft_beta', 1.0) + self.policy_logits_adaptive_percentile = getattr(self.config, 'policy_logits_adaptive_percentile', 95) + + assert self.num_experts_in_moe_head > 0 + if self.use_normal_head: + self.final_norm_option_in_obs_head = getattr(config, 'final_norm_option_in_obs_head', 'LayerNorm') + print('We use normal head') + for task_id in range(self.task_num): + if self.continuous_action_space: + self.sigma_type = self.config.sigma_type + self.bound_type = self.config.bound_type + head_policy = self._create_head_cont(self.value_policy_tokens_pattern, self.config.action_space_size_list[task_id]) + else: + head_policy = self._create_head(self.value_policy_tokens_pattern, self.action_space_size) + + if not self.share_head or task_id == 0: + self.head_policy_multi_task.append(head_policy) + + head_value = self._create_head(self.value_policy_tokens_pattern, self.support_size) + if not self.share_head or task_id == 0: + self.head_value_multi_task.append(head_value) + + head_rewards = self._create_head(self.act_tokens_pattern, self.support_size) + if not self.share_head or task_id == 0: + self.head_rewards_multi_task.append(head_rewards) + + head_observations = self._create_head( + self.all_but_last_latent_state_pattern, + self.config.embed_dim, + self._get_final_norm(self.final_norm_option_in_obs_head) # Use the specified normalization method. 
+ ) + if not self.share_head or task_id == 0: + self.head_observations_multi_task.append(head_observations) + + elif self.use_softmoe_head: + print(f'We use softmoe head, self.num_experts_in_moe_head is {self.num_experts_in_moe_head}') + self.soft_moe_instances = {} + self.create_head_modules_softmoe() + self.head_policy_multi_task.append(self.head_policy) + self.head_value_multi_task.append(self.head_value) + self.head_rewards_multi_task.append(self.head_rewards) + self.head_observations_multi_task.append(self.head_observations) + elif self.use_moe_head: + print(f'We use moe head, self.num_experts_in_moe_head is {self.num_experts_in_moe_head}') + self.moe_instances = {} + self.create_head_modules_moe() + self.head_policy_multi_task.append(self.head_policy) + self.head_value_multi_task.append(self.head_value) + self.head_rewards_multi_task.append(self.head_rewards) + self.head_observations_multi_task.append(self.head_observations) + + # Group all head modules into a ModuleDict for easier management. + self.head_dict = nn.ModuleDict({ + name: module for name, module in self.named_children() + if name.startswith("head_") and name.endswith("_multi_task") + }) + print("=" * 20) + print(f"self.head_dict:{self.head_dict}") + + # Apply weight initialization. The order of initialization is important. + self.apply(lambda module: init_weights(module, norm_type=self.config.norm_type)) + self._initialize_last_layer_mt() + + # --- Cache and State Initialization --- + self._initialize_cache_structures() + self._initialize_projection_input_dim() + self._initialize_statistics() + self._initialize_transformer_keys_values() + + self.latent_recon_loss = torch.tensor(0., device=self.device) + self.perceptual_loss = torch.tensor(0., device=self.device) + + # Initially set to game_segment_length to ensure all KVs in self.shared_pool_init_infer are valid. + # TODO: Critical. This should be changed to match segment_length. + self.shared_pool_size_init = int(self.config.game_segment_length) # NOTE: Will having too many cause incorrect retrieval of the kv cache? + + self.shared_pool_size_recur = int(self.num_simulations*self.env_num) + self.shared_pool_recur_infer = [None] * self.shared_pool_size_recur + self.shared_pool_index = 0 + + # For init_infer, it only needs to retain the results of the most recent step. + # NOTE: A large pool size might cause incorrect retrieval of the kv cache. + self.shared_pool_init_infer = [[None] * self.shared_pool_size_init for _ in range(self.env_num)] + self.shared_pool_index_init_envs = [0 for _ in range(self.env_num)] + + # For wm (world model) forward passes during training. + self.shared_pool_size_wm = int(self.env_num) + self.shared_pool_wm = [None] * self.shared_pool_size_wm + self.shared_pool_index_wm = 0 + + self.reanalyze_phase = False + self._rank = get_rank() + + def _scale_grad(self, grad: torch.Tensor) -> torch.Tensor: + """ + Overview: + Scales the gradient. This hook is registered to encoder parameters + to stabilize multi-task training. + Arguments: + - grad (:obj:`torch.Tensor`): The original gradient. + Returns: + - (:obj:`torch.Tensor`): The scaled gradient. + """ + # Scale by 1/sqrt(k) for a conservative approach, where k is the number of tasks. + return grad / math.sqrt(self.task_num) + + def _generate_colors(self, num_colors: int) -> list: + """ + Overview: + Generates a list of unique colors for visualization purposes, + suitable for a large number of categories. + Arguments: + - num_colors (:obj:`int`): The desired number of unique colors. 
+ Returns: + - (:obj:`list`): A list of colors. + """ + # Concatenate multiple discrete colormaps from matplotlib to get more colors. + color_maps = ['tab20', 'tab20b', 'tab20c'] + colors = [] + for cmap_name in color_maps: + cmap = plt.get_cmap(cmap_name) + colors.extend([cmap(i) for i in range(cmap.N)]) + if len(colors) >= num_colors: + break + # Generate additional colors if needed. + if len(colors) < num_colors: + additional_colors = plt.cm.get_cmap('hsv', num_colors - len(colors)) + colors.extend([additional_colors(i) for i in range(num_colors - len(colors))]) + return colors[:num_colors] + + def _initialize_config_parameters(self) -> None: + """Initializes model attributes from the configuration object.""" + self.policy_entropy_weight = self.config.policy_entropy_weight + self.predict_latent_loss_type = self.config.predict_latent_loss_type + self.group_size = self.config.group_size + self.num_groups = self.config.embed_dim // self.group_size + self.obs_type = self.config.obs_type + self.embed_dim = self.config.embed_dim + self.num_heads = self.config.num_heads + self.gamma = self.config.gamma + self.context_length = self.config.context_length + self.dormant_threshold = self.config.dormant_threshold + self.analysis_dormant_ratio_weight_rank = self.config.analysis_dormant_ratio_weight_rank + self.num_observations_tokens = self.config.tokens_per_block - 1 + self.latent_recon_loss_weight = self.config.latent_recon_loss_weight + self.perceptual_loss_weight = self.config.perceptual_loss_weight + self.support_size = self.config.support_size + self.action_space_size = self.config.action_space_size + self.max_cache_size = self.config.max_cache_size + self.num_layers = self.config.num_layers + + def _initialize_patterns(self) -> None: + """Initializes patterns (masks) for selecting specific tokens for prediction heads.""" + self.all_but_last_latent_state_pattern = torch.ones(self.config.tokens_per_block) + self.all_but_last_latent_state_pattern[-2] = 0 + self.act_tokens_pattern = torch.zeros(self.config.tokens_per_block) + self.act_tokens_pattern[-1] = 1 + self.value_policy_tokens_pattern = torch.zeros(self.config.tokens_per_block) + self.value_policy_tokens_pattern[-2] = 1 + + def _get_final_norm(self, norm_option: str) -> nn.Module: + """Returns the specified normalization module.""" + if norm_option == 'LayerNorm': + return nn.LayerNorm(self.config.embed_dim, eps=1e-5) + elif norm_option == 'SimNorm': + return SimNorm(simnorm_dim=self.config.group_size) + else: + raise ValueError(f"Unsupported final_norm_option_in_obs_head: {norm_option}") + + def _create_head(self, block_mask: torch.Tensor, output_dim: int, norm_layer: Optional[nn.Module] = None) -> Head: + """Creates a standard prediction head.""" + modules = [ + nn.LayerNorm(self.config.embed_dim), # TODO + nn.Linear(self.config.embed_dim, self.config.embed_dim), + nn.LayerNorm(self.config.embed_dim), + nn.GELU(approximate='tanh'), + nn.Linear(self.config.embed_dim, output_dim) + ] + if norm_layer: + modules.append(norm_layer) + return Head( + max_blocks=self.config.max_blocks, + block_mask=block_mask, + head_module=nn.Sequential(*modules) + ) + + def _create_head_moe(self, block_mask: torch.Tensor, output_dim: int, norm_layer: Optional[nn.Module] = None, moe: Optional[nn.Module] = None) -> Head: + """Creates a prediction head with a Mixture-of-Experts (MoE) layer.""" + modules = [ + nn.LayerNorm(self.config.embed_dim), # TODO + moe, + nn.Linear(self.config.embed_dim, output_dim) + ] + if norm_layer: + modules.append(norm_layer) + return 
Head( + max_blocks=self.config.max_blocks, + block_mask=block_mask, + head_module=nn.Sequential(*modules) + ) + + def get_moe(self, name: str) -> nn.Module: + """Gets or creates a MoE instance by name.""" + from .moe import MoELayer, MultiplicationFeedForward + + if name not in self.moe_instances: + # Create multiple FeedForward instances for multiplication-based MoE. + experts = nn.ModuleList([ + MultiplicationFeedForward(self.config) for _ in range(self.config.num_experts_of_moe_in_transformer) + ]) + self.moe_instances[name] = MoELayer( + experts=experts, + gate=nn.Linear(self.config.embed_dim, self.config.num_experts_of_moe_in_transformer, bias=False), + num_experts_per_tok=1, + ) + return self.moe_instances[name] + + def create_head_modules_moe(self) -> None: + """Creates all MoE prediction head modules.""" + self.head_rewards = self._create_head_moe(self.act_tokens_pattern, self.support_size, moe=self.get_moe("rewards_moe")) + self.head_observations = self._create_head_moe(self.all_but_last_latent_state_pattern, self.embed_dim, norm_layer=self.sim_norm, moe=self.get_moe("observations_moe")) + self.head_policy = self._create_head_moe(self.value_policy_tokens_pattern, self.action_space_size, moe=self.get_moe("policy_moe")) + self.head_value = self._create_head_moe(self.value_policy_tokens_pattern, self.support_size, moe=self.get_moe("value_moe")) + + def _create_head_softmoe(self, block_mask: torch.Tensor, output_dim: int, norm_layer: Optional[nn.Module] = None, soft_moe: Optional[nn.Module] = None) -> Head: + """Creates a prediction head with a Soft-MoE layer.""" + modules = [ + soft_moe, + nn.Linear(self.config.embed_dim, output_dim) + ] + if norm_layer: + modules.append(norm_layer) + return Head( + max_blocks=self.config.max_blocks, + block_mask=block_mask, + head_module=nn.Sequential(*modules) + ) + + def get_soft_moe(self, name: str) -> nn.Module: + """Gets or creates a Soft-MoE instance by name.""" + from soft_moe_pytorch import DynamicSlotsSoftMoE as SoftMoE + if name not in self.soft_moe_instances: + self.soft_moe_instances[name] = SoftMoE( + dim=self.embed_dim, + num_experts=self.num_experts_in_moe_head, + geglu=True + ) + return self.soft_moe_instances[name] + + def create_head_modules_softmoe(self) -> None: + """Creates all Soft-MoE prediction head modules.""" + self.head_rewards = self._create_head_softmoe(self.act_tokens_pattern, self.support_size, soft_moe=self.get_soft_moe("rewards_soft_moe")) + self.head_observations = self._create_head_softmoe(self.all_but_last_latent_state_pattern, self.config.embed_dim, norm_layer=self.sim_norm, soft_moe=self.get_soft_moe("observations_soft_moe")) + self.head_policy = self._create_head_softmoe(self.value_policy_tokens_pattern, self.action_space_size, soft_moe=self.get_soft_moe("policy_soft_moe")) + self.head_value = self._create_head_softmoe(self.value_policy_tokens_pattern, self.support_size, soft_moe=self.get_soft_moe("value_soft_moe")) + + def _initialize_last_layer_mt(self) -> None: + """Initializes the last linear layer of prediction heads to zero for training stability.""" + last_linear_layer_init_zero = True + print(f'world_model_mt.py:self.task_num:{self.task_num}') + if last_linear_layer_init_zero: + if self.continuous_action_space: + # For continuous actions, policy head might have a different initialization strategy. 
+ module_to_initialize = self.head_value_multi_task + self.head_rewards_multi_task + self.head_observations_multi_task + else: + module_to_initialize = self.head_policy_multi_task + self.head_value_multi_task + self.head_rewards_multi_task + self.head_observations_multi_task + + for head in module_to_initialize: + for layer in reversed(head.head_module): + if isinstance(layer, nn.Linear): + nn.init.zeros_(layer.weight) + if layer.bias is not None: + nn.init.zeros_(layer.bias) + break + + def _initialize_cache_structures(self) -> None: + """Initializes cache structures for storing past keys and values during inference.""" + + self.past_kv_cache_recurrent_infer = {} + self.pool_idx_to_key_map_recur_infer = [None] * self.shared_pool_size_recur + self.past_kv_cache_init_infer_envs = [{} for _ in range(self.env_num)] + # Auxiliary data structure for reverse lookup: pool_index -> key + self.pool_idx_to_key_map_init_envs = [[None] * self.shared_pool_size_init for _ in range(self.env_num)] + + self.keys_values_wm_list = [] + self.keys_values_wm_size_list = [] + + def _initialize_projection_input_dim(self) -> None: + """Initializes the input dimension for the projection based on observation tokenization.""" + if self.num_observations_tokens == 16: + self.projection_input_dim = 128 + elif self.num_observations_tokens == 1: + if self.task_embed_option in ["concat_task_embed", "register_task_embed", "add_task_embed"]: + self.projection_input_dim = self.config.embed_dim + if self.task_embed_option == "concat_task_embed": + self.projection_input_dim -= self.task_embed_dim + else: + self.projection_input_dim = self.config.embed_dim + + def _initialize_statistics(self) -> None: + """Initializes counters for cache hit rates and other statistics.""" + self.hit_count = 0 + self.total_query_count = 0 + self.length_largethan_maxminus5_context_cnt = 0 + self.length_largethan_maxminus7_context_cnt = 0 + self.root_hit_cnt = 0 + self.root_total_query_cnt = 0 + + def _initialize_transformer_keys_values(self) -> None: + """Initializes empty key-value cache structures for the transformer.""" + self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values(n=1, max_tokens=self.context_length) + self.keys_values_wm = self.transformer.generate_empty_keys_values(n=self.env_num, max_tokens=self.context_length) + + def precompute_pos_emb_diff_kv(self) -> None: + """ + Overview: + Precomputes positional embedding differences for keys and values. This is an + optimization to speed up KV cache updates during recurrent inference by avoiding + re-computation of positional embeddings. + """ + if self.context_length <= 2: + return # No context to precompute for. + + # Precompute positional embedding matrices for all layers. + self.positional_embedding_k = [self._get_positional_embedding(layer, 'key') for layer in range(self.config.num_layers)] + self.positional_embedding_v = [self._get_positional_embedding(layer, 'value') for layer in range(self.config.num_layers)] + + # Precompute all possible positional embedding differences. + self.pos_emb_diff_k = [] + self.pos_emb_diff_v = [] + + for layer in range(self.config.num_layers): + layer_pos_emb_diff_k = {} + layer_pos_emb_diff_v = {} + + # This is for the case when context window is full and we shift it. + # TODO: Generalize for different start/end points if necessary. 
+ for start in [2]: + for end in [self.context_length - 1]: + original_pos_emb_k = self.positional_embedding_k[layer][:, :, start:end, :] + new_pos_emb_k = self.positional_embedding_k[layer][:, :, :end - start, :] + layer_pos_emb_diff_k[(start, end)] = new_pos_emb_k - original_pos_emb_k + + original_pos_emb_v = self.positional_embedding_v[layer][:, :, start:end, :] + new_pos_emb_v = self.positional_embedding_v[layer][:, :, :end - start, :] + layer_pos_emb_diff_v[(start, end)] = new_pos_emb_v - original_pos_emb_v + + self.pos_emb_diff_k.append(layer_pos_emb_diff_k) + self.pos_emb_diff_v.append(layer_pos_emb_diff_v) + + def _get_positional_embedding(self, layer: int, attn_type: str) -> torch.Tensor: + """ + Overview: + Helper function to get positional embedding for a given layer and attention type. + Arguments: + - layer (:obj:`int`): The layer index. + - attn_type (:obj:`str`): The attention type, either 'key' or 'value'. + Returns: + - (:obj:`torch.Tensor`): The positional embedding tensor, detached from the graph. + """ + # TODO: Review the use of detach(). It's used here to prevent gradients from flowing back + # through the positional embeddings during this pre-computation phase. + attn_func = getattr(self.transformer.blocks[layer].attn, attn_type) + pos_emb = attn_func(self.pos_emb.weight).view( + 1, self.config.max_tokens, self.num_heads, self.embed_dim // self.num_heads + ).transpose(1, 2) + return pos_emb.to(self.device).detach() + + def forward( + self, + obs_embeddings_or_act_tokens: Dict[str, Union[torch.Tensor, tuple]], + past_keys_values: Optional[torch.Tensor] = None, + kvcache_independent: bool = False, + is_init_infer: bool = True, + valid_context_lengths: Optional[torch.Tensor] = None, + task_id: int = 0 + ) -> WorldModelOutput: + """ + Overview: + Main forward pass for the world model. It processes either observation embeddings, + action tokens, or a combination of both, and passes them through the transformer + to generate predictions. + Arguments: + - obs_embeddings_or_act_tokens (:obj:`Dict`): A dictionary containing input tensors. + Can be 'obs_embeddings', 'act_tokens', or 'obs_embeddings_and_act_tokens'. + - past_keys_values (:obj:`Optional[torch.Tensor]`): The KV cache from previous steps. + - kvcache_independent (:obj:`bool`): Whether to use independent KV caching per item in the batch. + - is_init_infer (:obj:`bool`): Flag indicating if this is an initial inference step. + - valid_context_lengths (:obj:`Optional[torch.Tensor]`): Tensor of valid context lengths for each item. + - task_id (:obj:`int`): The ID of the current task. + Returns: + - (:obj:`WorldModelOutput`): An object containing the transformer output and logits for + observations, rewards, policy, and value. + """ + if self.use_task_embed: + self.task_embeddings = self.task_emb(torch.tensor(task_id, device=self.device)) + self.task_embeddings = self.sim_norm(self.task_embeddings.view(1, -1)).view(-1) + else: + # Use a zero tensor if task embeddings are disabled. 
+ self.task_embeddings = torch.zeros(self.config.embed_dim, device=self.device) + + prev_steps = 0 if past_keys_values is None else past_keys_values.size + if kvcache_independent: + prev_steps = torch.tensor([0 if past_keys_values is None else past_kv.size for past_kv in past_keys_values], device=self.device) + + if is_init_infer: + valid_context_lengths = None + + # --- Branch 1: Inference Phase (Collect/Eval) - Process observation embeddings --- + if 'obs_embeddings' in obs_embeddings_or_act_tokens: + obs_embeddings = obs_embeddings_or_act_tokens['obs_embeddings'] + if len(obs_embeddings.shape) == 2: + obs_embeddings = obs_embeddings.unsqueeze(1) + + # Apply task embeddings based on the chosen strategy. + if self.task_embed_option == "add_task_embed": + obs_embeddings = obs_embeddings + self.task_embeddings + elif self.task_embed_option == "concat_task_embed": + if is_init_infer and not self.reanalyze_phase: + # Concatenate task embeddings only during initial inference. + task_emb_expanded = self.task_embeddings.view(1, 1, -1).expand(obs_embeddings.shape[0], obs_embeddings.shape[1], -1) + obs_embeddings = torch.cat([obs_embeddings, task_emb_expanded], dim=-1) + + num_steps = obs_embeddings.size(1) + sequences = self._add_position_embeddings(obs_embeddings, prev_steps, num_steps, kvcache_independent, is_init_infer, valid_context_lengths) + + # --- Branch 2: Inference Phase (Collect/Eval) - Process action tokens --- + elif 'act_tokens' in obs_embeddings_or_act_tokens: + act_tokens = obs_embeddings_or_act_tokens['act_tokens'] + if self.continuous_action_space: + num_steps = 1 + act_tokens = act_tokens.float() + if len(act_tokens.shape) == 2: + act_tokens = act_tokens.unsqueeze(1) + else: + if len(act_tokens.shape) == 3: + act_tokens = act_tokens.squeeze(1) + num_steps = act_tokens.size(1) + + # Get action embeddings from the task-specific or shared table. + if self.task_num >= 1 and self.continuous_action_space: + act_embeddings = self.act_embedding_table[task_id](act_tokens) + else: + act_embeddings = self.act_embedding_table(act_tokens) + + # Apply task embeddings. + if self.task_embed_option == "concat_task_embed": + task_emb_expanded = self.task_embeddings.view(1, 1, -1).expand(act_embeddings.shape[0], act_embeddings.shape[1], -1) + act_embeddings = torch.cat([act_embeddings, task_emb_expanded], dim=-1) + + sequences = self._add_position_embeddings(act_embeddings, prev_steps, num_steps, kvcache_independent, is_init_infer, valid_context_lengths) + + # --- Branch 3: Training Phase - Process combined observation embeddings and action tokens --- + else: + if self.continuous_action_space: + sequences, num_steps = self._process_obs_act_combined_cont(obs_embeddings_or_act_tokens, prev_steps, task_id=task_id) + else: + sequences, num_steps = self._process_obs_act_combined(obs_embeddings_or_act_tokens, prev_steps) + + # Pass sequences through the transformer. + x = self._transformer_pass(sequences, past_keys_values, kvcache_independent, valid_context_lengths, task_id=task_id) + + # Generate logits using shared, task-specific, or MoE heads. 
+ head_index = 0 if self.share_head else task_id + if self.use_moe_head or self.use_softmoe_head: + logits_observations = self.head_observations(x, num_steps=num_steps, prev_steps=prev_steps) + logits_rewards = self.head_rewards(x, num_steps=num_steps, prev_steps=prev_steps) + logits_policy = self.head_policy(x, num_steps=num_steps, prev_steps=prev_steps) + logits_value = self.head_value(x, num_steps=num_steps, prev_steps=prev_steps) + else: + logits_observations = self.head_observations_multi_task[head_index](x, num_steps=num_steps, prev_steps=prev_steps) + logits_rewards = self.head_rewards_multi_task[head_index](x, num_steps=num_steps, prev_steps=prev_steps) + logits_policy = self.head_policy_multi_task[head_index](x, num_steps=num_steps, prev_steps=prev_steps) + logits_value = self.head_value_multi_task[head_index](x, num_steps=num_steps, prev_steps=prev_steps) + + # ==================== [NEW] Advanced Policy Logits Control ==================== + # Apply configurable policy logits control to prevent explosion + # Multiple methods available: hard, soft_tanh, soft_sigmoid, normalize_max, etc. + self.use_policy_logits_clip=True # TODO + if self.use_policy_logits_clip: + logits_policy = self._apply_policy_logits_control(logits_policy) + # ================================================================================ + + return WorldModelOutput(x, logits_observations, logits_rewards, None, logits_policy, logits_value) + + def _add_position_embeddings( + self, + embeddings: torch.Tensor, + prev_steps: Union[int, torch.Tensor], + num_steps: int, + kvcache_independent: bool, + is_init_infer: bool, + valid_context_lengths: Optional[torch.Tensor] + ) -> torch.Tensor: + """ + Overview: + Adds positional embeddings to the input embeddings. + Arguments: + - embeddings (:obj:`torch.Tensor`): Input embeddings. + - prev_steps (:obj:`Union[int, torch.Tensor]`): Number of previous steps in the cache. + - num_steps (:obj:`int`): Number of new steps being added. + - kvcache_independent (:obj:`bool`): Flag for independent KV caching. + - is_init_infer (:obj:`bool`): Flag for initial inference. + - valid_context_lengths (:obj:`Optional[torch.Tensor]`): Valid context lengths for each sequence. + Returns: + - (:obj:`torch.Tensor`): Embeddings with added positional information. + """ + if kvcache_independent: + steps_indices = prev_steps.unsqueeze(1) + torch.arange(num_steps, device=embeddings.device) + position_embeddings = self.pos_emb(steps_indices) + return embeddings + position_embeddings + else: + if is_init_infer: + # For initial inference, positions are sequential from the previous step count. + pos_indices = prev_steps + torch.arange(num_steps, device=self.device) + return embeddings + self.pos_emb(pos_indices) + else: + # For recurrent steps, use valid_context_lengths to get correct positions. + valid_context_lengths = torch.tensor(self.keys_values_wm_size_list_current, device=self.device) + pos_indices = valid_context_lengths.unsqueeze(1) + torch.arange(num_steps, device=self.device) + position_embeddings = self.pos_emb(pos_indices) + return embeddings + position_embeddings + + def _process_obs_act_combined_cont(self, obs_embeddings_or_act_tokens: dict, prev_steps: int, task_id: int = 0) -> Tuple[torch.Tensor, int]: + """ + Overview: + Processes and combines observation embeddings and continuous action tokens for training. + Arguments: + - obs_embeddings_or_act_tokens (:obj:`dict`): Dictionary with 'obs_embeddings_and_act_tokens'. + - prev_steps (:obj:`int`): Number of previous steps. 
+ - task_id (:obj:`int`): The current task ID. + Returns: + - (:obj:`Tuple[torch.Tensor, int]`): A tuple of the combined sequence tensor and the number of steps. + """ + obs_embeddings, act_tokens = obs_embeddings_or_act_tokens['obs_embeddings_and_act_tokens'] + if len(obs_embeddings.shape) == 3: + obs_embeddings = obs_embeddings.view(act_tokens.shape[0], act_tokens.shape[1], self.num_observations_tokens, -1) + + num_steps = int(obs_embeddings.size(1) * (obs_embeddings.size(2) + 1)) + act_tokens = act_tokens.float() + if len(act_tokens.shape) == 2: + act_tokens = act_tokens.unsqueeze(-1) + + act_embeddings = self.act_embedding_table[task_id](act_tokens) + + B, L, K, E_obs = obs_embeddings.size() + obs_act_embeddings = torch.empty(B, L * (K + 1), self.config.embed_dim, device=self.device) + + if self.task_embed_option == "concat_task_embed": + task_emb_expanded = self.task_embeddings.view(1, 1, -1).expand(B, 1, -1) + + for i in range(L): + obs = obs_embeddings[:, i, :, :] + if self.task_embed_option == "add_task_embed": + obs = obs + self.task_embeddings + elif self.task_embed_option == "concat_task_embed": + obs = torch.cat([obs, task_emb_expanded.expand(B, K, -1)], dim=-1) + + act = act_embeddings[:, i, :].unsqueeze(1) + if self.task_embed_option == "concat_task_embed": + act = torch.cat([act, task_emb_expanded], dim=-1) + + obs_act = torch.cat([obs, act], dim=1) + obs_act_embeddings[:, i * (K + 1):(i + 1) * (K + 1), :] = obs_act + + pos_indices = prev_steps + torch.arange(num_steps, device=self.device) + return obs_act_embeddings + self.pos_emb(pos_indices), num_steps + + def _process_obs_act_combined(self, obs_embeddings_or_act_tokens: dict, prev_steps: int, task_id: int = 0) -> Tuple[torch.Tensor, int]: + """ + Overview: + Processes and combines observation embeddings and discrete action tokens for training. + Arguments: + - obs_embeddings_or_act_tokens (:obj:`dict`): Dictionary with 'obs_embeddings_and_act_tokens'. + - prev_steps (:obj:`int`): Number of previous steps. + - task_id (:obj:`int`): The current task ID. + Returns: + - (:obj:`Tuple[torch.Tensor, int]`): A tuple of the combined sequence tensor and the number of steps. 
+ """ + obs_embeddings, act_tokens = obs_embeddings_or_act_tokens['obs_embeddings_and_act_tokens'] + if len(obs_embeddings.shape) == 3: + obs_embeddings = obs_embeddings.view(act_tokens.shape[0], act_tokens.shape[1], self.num_observations_tokens, -1) + + num_steps = int(obs_embeddings.size(1) * (obs_embeddings.size(2) + 1)) + act_embeddings = self.act_embedding_table(act_tokens) + + B, L, K, E_obs = obs_embeddings.size() + obs_act_embeddings = torch.empty(B, L * (K + 1), self.config.embed_dim, device=self.device) + + if self.task_embed_option == "concat_task_embed": + task_emb_expanded = self.task_embeddings.view(1, 1, -1).expand(B, 1, -1) + + for i in range(L): + obs = obs_embeddings[:, i, :, :] + if self.task_embed_option == "add_task_embed": + obs = obs + self.task_embeddings + elif self.task_embed_option == "concat_task_embed": + obs = torch.cat([obs, task_emb_expanded.expand(B, K, -1)], dim=-1) + + act = act_embeddings[:, i, 0, :].unsqueeze(1) + if self.task_embed_option == "concat_task_embed": + act = torch.cat([act, task_emb_expanded], dim=-1) + + obs_act = torch.cat([obs, act], dim=1) + obs_act_embeddings[:, i * (K + 1):(i + 1) * (K + 1), :] = obs_act + + pos_indices = prev_steps + torch.arange(num_steps, device=self.device) + return obs_act_embeddings + self.pos_emb(pos_indices), num_steps + + def _transformer_pass( + self, + sequences: torch.Tensor, + past_keys_values: Optional[torch.Tensor], + kvcache_independent: bool, + valid_context_lengths: Optional[torch.Tensor], + task_id: int = 0 + ) -> torch.Tensor: + """ + Overview: + Passes sequences through the transformer, handling different KV cache modes. + Arguments: + - sequences (:obj:`torch.Tensor`): Input sequences. + - past_keys_values (:obj:`Optional[torch.Tensor]`): The KV cache from previous steps. + - kvcache_independent (:obj:`bool`): Flag for independent KV caching. + - valid_context_lengths (:obj:`Optional[torch.Tensor]`): Tensor of valid context lengths. + - task_id (:obj:`int`): The current task ID. + Returns: + - (:obj:`torch.Tensor`): The output from the transformer. + """ + if kvcache_independent: + x = [ + self.transformer(sequences[k].unsqueeze(0), past_kv, valid_context_lengths=valid_context_lengths[k].unsqueeze(0)) + for k, past_kv in enumerate(past_keys_values) + ] + return torch.cat(x, dim=0) + else: + return self.transformer(sequences, past_keys_values, valid_context_lengths=valid_context_lengths) + + @torch.no_grad() + def reset_for_initial_inference(self, obs_act_dict: dict, task_id: int = 0) -> Tuple[WorldModelOutput, torch.Tensor]: + """ + Overview: + Resets the model state for the beginning of an episode or a new inference sequence. + It processes the initial observations and actions to create the first latent state + and populate the KV cache. + Arguments: + - obs_act_dict (:obj:`dict`): A dictionary containing 'obs', 'action', and 'current_obs'. + - task_id (:obj:`int`): The ID of the current task. + Returns: + - (:obj:`Tuple[WorldModelOutput, torch.Tensor]`): A tuple containing the world model output + and the initial latent state. 
+ """ + if self.use_task_embed: + self.task_embeddings = self.task_emb(torch.tensor(task_id, device=self.device)) + self.task_embeddings = self.sim_norm(self.task_embeddings.view(1, -1)).view(-1) + else: + self.task_embeddings = torch.zeros(self.config.embed_dim, device=self.device) + + batch_obs = obs_act_dict['obs'] + batch_action = obs_act_dict['action'] + batch_current_obs = obs_act_dict['current_obs'] + + obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch_obs, task_id=task_id) + + if batch_current_obs is not None: + # --- Collect and Evaluation Phase --- + current_obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch_current_obs, task_id=task_id) + + # The latent state is the combination of observation embedding and task embedding. + if self.use_task_embed: + if self.task_embed_option == "add_task_embed": + self.latent_state = current_obs_embeddings + self.task_embeddings + elif self.task_embed_option == "concat_task_embed": + task_emb_expanded = self.task_embeddings.view(1, 1, -1).expand(current_obs_embeddings.shape[0], current_obs_embeddings.shape[1], -1) + self.latent_state = torch.cat([current_obs_embeddings, task_emb_expanded], dim=-1) + else: # "register_task_embed" or other cases + self.latent_state = current_obs_embeddings + else: + self.latent_state = current_obs_embeddings + + outputs_wm = self.wm_forward_for_initial_inference(obs_embeddings, batch_action, current_obs_embeddings, task_id=task_id) + else: + # --- Training Phase (for calculating target values) --- + if self.use_task_embed: + if self.task_embed_option == "add_task_embed": + self.latent_state = obs_embeddings + self.task_embeddings + elif self.task_embed_option == "concat_task_embed": + task_emb_expanded = self.task_embeddings.view(1, 1, -1).expand(obs_embeddings.shape[0], obs_embeddings.shape[1], -1) + self.latent_state = torch.cat([obs_embeddings, task_emb_expanded], dim=-1) + else: + self.latent_state = obs_embeddings + else: + self.latent_state = obs_embeddings + + outputs_wm = self.wm_forward_for_initial_inference(obs_embeddings, batch_action, None, task_id=task_id) + + return outputs_wm, self.latent_state + + + #@profile + @torch.no_grad() + def wm_forward_for_initial_inference(self, last_obs_embeddings: torch.LongTensor, + batch_action=None, + current_obs_embeddings=None, task_id = 0) -> torch.FloatTensor: + """ + Refresh key-value pairs with the initial latent state for inference. + + Arguments: + - latent_state (:obj:`torch.LongTensor`): The latent state embeddings. + - batch_action (optional): Actions taken. + - current_obs_embeddings (optional): Current observation embeddings. + Returns: + - torch.FloatTensor: The outputs from the world model. 
+ """ + n, num_observations_tokens, _ = last_obs_embeddings.shape + if n <= self.env_num and current_obs_embeddings is not None: + # ================ Collect and Evaluation Phase ================ + if current_obs_embeddings is not None: + if self.continuous_action_space: + first_step_flag = not isinstance(batch_action[0], np.ndarray) + else: + first_step_flag = max(batch_action) == -1 + if first_step_flag: + # First step in an episode + self.keys_values_wm = self.transformer.generate_empty_keys_values(n=current_obs_embeddings.shape[0], + max_tokens=self.context_length) + # print(f"current_obs_embeddings.device: {current_obs_embeddings.device}") + outputs_wm = self.forward({'obs_embeddings': current_obs_embeddings}, + past_keys_values=self.keys_values_wm, is_init_infer=True, task_id=task_id) + + if self.use_task_embed and self.task_embed_option in ["concat_task_embed", "add_task_embed"]: + # Copy and store keys_values_wm for a single environment + self.update_cache_context(self.latent_state, is_init_infer=True) + else: + # Copy and store keys_values_wm for a single environment + self.update_cache_context(current_obs_embeddings, is_init_infer=True) + else: + # Assume latest_state is the new latent_state, containing information from ready_env_num environments + ready_env_num = current_obs_embeddings.shape[0] + self.keys_values_wm_list = [] + self.keys_values_wm_size_list = [] + for i in range(ready_env_num): + # Retrieve latent state for a single environment + state_single_env = last_obs_embeddings[i] + # Compute hash value using latent state for a single environment + cache_key = hash_state( + state_single_env.view(-1).cpu().numpy()) # last_obs_embeddings[i] is torch.Tensor + + # Retrieve cached value + cache_index = self.past_kv_cache_init_infer_envs[i].get(cache_key) + if cache_index is not None: + matched_value = self.shared_pool_init_infer[i][cache_index] + else: + matched_value = None + + self.root_total_query_cnt += 1 + if matched_value is not None: + # If a matching value is found, add it to the list + self.root_hit_cnt += 1 + # deepcopy is needed because forward modifies matched_value in place + self.keys_values_wm_list.append(self.custom_copy_kv_cache_to_shared_wm(matched_value)) + self.keys_values_wm_size_list.append(matched_value.size) + else: + # Reset using zero values + self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values(n=1, max_tokens=self.context_length) + outputs_wm = self.forward({'obs_embeddings': state_single_env.unsqueeze(0)}, + past_keys_values=self.keys_values_wm_single_env, + is_init_infer=True, task_id=task_id) + self.keys_values_wm_list.append(self.keys_values_wm_single_env) + self.keys_values_wm_size_list.append(1) + + # Input self.keys_values_wm_list, output self.keys_values_wm + self.keys_values_wm_size_list_current = self.trim_and_pad_kv_cache(is_init_infer=True) + + batch_action = batch_action[:ready_env_num] + # if ready_env_num < self.env_num: + # print(f'init inference ready_env_num: {ready_env_num} < env_num: {self.env_num}') + if self.continuous_action_space: + act_tokens = torch.from_numpy(np.array(batch_action)).to(last_obs_embeddings.device).unsqueeze(1) + else: + act_tokens = torch.from_numpy(np.array(batch_action)).to(last_obs_embeddings.device).unsqueeze(-1) + outputs_wm = self.forward({'act_tokens': act_tokens}, past_keys_values=self.keys_values_wm, + is_init_infer=True, task_id=task_id) + + outputs_wm = self.forward({'obs_embeddings': current_obs_embeddings}, + past_keys_values=self.keys_values_wm, is_init_infer=True, 
task_id=task_id) + + # Copy and store keys_values_wm for a single environment + if self.use_task_embed and self.task_embed_option in ["concat_task_embed", "add_task_embed"]: + # Copy and store keys_values_wm for a single environment + self.update_cache_context(self.latent_state, is_init_infer=True) + else: + # import ipdb; ipdb.set_trace() + # Copy and store keys_values_wm for a single environment + self.update_cache_context(current_obs_embeddings, is_init_infer=True) + + elif batch_action is not None and current_obs_embeddings is None: + # elif n > self.env_num and batch_action is not None and current_obs_embeddings is None: + # ================ calculate the target value in Train phase ================ + # [192, 16, 64] -> [32, 6, 16, 64] + last_obs_embeddings = last_obs_embeddings.contiguous().view(batch_action.shape[0], -1, num_observations_tokens, + self.obs_act_embed_dim) # (BL, K) for unroll_step=1 + + last_obs_embeddings = last_obs_embeddings[:, :-1, :] + batch_action = torch.from_numpy(batch_action).to(last_obs_embeddings.device) + + if self.continuous_action_space: + act_tokens = batch_action + else: + act_tokens = rearrange(batch_action, 'b l -> b l 1') + + # select the last timestep for each sample + # This will select the last column while keeping the dimensions unchanged, and the target policy/value in the final step itself is not used. + last_steps_act = act_tokens[:, -1:, :] + act_tokens = torch.cat((act_tokens, last_steps_act), dim=1) + + outputs_wm = self.forward({'obs_embeddings_and_act_tokens': (last_obs_embeddings, act_tokens)}, task_id=task_id) + + # select the last timestep for each sample + last_steps_value = outputs_wm.logits_value[:, -1:, :] + outputs_wm.logits_value = torch.cat((outputs_wm.logits_value, last_steps_value), dim=1) + + last_steps_policy = outputs_wm.logits_policy[:, -1:, :] + outputs_wm.logits_policy = torch.cat((outputs_wm.logits_policy, last_steps_policy), dim=1) + + # Reshape your tensors + # outputs_wm.logits_value.shape (B, H, 101) = (B*H, 101) + outputs_wm.logits_value = rearrange(outputs_wm.logits_value, 'b t e -> (b t) e') + outputs_wm.logits_policy = rearrange(outputs_wm.logits_policy, 'b t e -> (b t) e') + + return outputs_wm + + + #@profile + @torch.no_grad() + def forward_initial_inference(self, obs_act_dict, task_id = 0): + """ + Perform initial inference based on the given observation-action dictionary. + + Arguments: + - obs_act_dict (:obj:`dict`): Dictionary containing observations and actions. + Returns: + - tuple: A tuple containing output sequence, latent state, logits rewards, logits policy, and logits value. + """ + # UniZero has context in the root node + outputs_wm, latent_state = self.reset_for_initial_inference(obs_act_dict, task_id=task_id) + self.past_kv_cache_recurrent_infer.clear() + + return (outputs_wm.output_sequence, latent_state, outputs_wm.logits_rewards, + outputs_wm.logits_policy, outputs_wm.logits_value) + + #@profile + @torch.no_grad() + def forward_recurrent_inference(self, state_action_history, simulation_index=0, + latent_state_index_in_search_path=[], task_id = 0): + """ + Perform recurrent inference based on the state-action history. + + Arguments: + - state_action_history (:obj:`list`): List containing tuples of state and action history. + - simulation_index (:obj:`int`, optional): Index of the current simulation. Defaults to 0. + - latent_state_index_in_search_path (:obj:`list`, optional): List containing indices of latent states in the search path. Defaults to []. 
+ Returns: + - tuple: A tuple containing output sequence, updated latent state, reward, logits policy, and logits value. + """ + latest_state, action = state_action_history[-1] + ready_env_num = latest_state.shape[0] + + self.keys_values_wm_list = [] + self.keys_values_wm_size_list = [] + self.keys_values_wm_size_list = self.retrieve_or_generate_kvcache(latest_state, ready_env_num, simulation_index, task_id=task_id) + + latent_state_list = [] + if not self.continuous_action_space: + token = action.reshape(-1, 1) + else: + token = action.reshape(-1, self.config.action_space_size_list[task_id]) + + # ======= Print statistics for debugging ============= + # min_size = min(self.keys_values_wm_size_list) + # if min_size >= self.config.max_tokens - 5: + # self.length_largethan_maxminus5_context_cnt += len(self.keys_values_wm_size_list) + # if min_size >= self.config.max_tokens - 7: + # self.length_largethan_maxminus7_context_cnt += len(self.keys_values_wm_size_list) + # if self.total_query_count > 0 and self.total_query_count % 10000 == 0: + # self.hit_freq = self.hit_count / self.total_query_count + # print('total_query_count:', self.total_query_count) + # length_largethan_maxminus5_context_cnt_ratio = self.length_largethan_maxminus5_context_cnt / self.total_query_count + # print('recurrent largethan_maxminus5_context:', self.length_largethan_maxminus5_context_cnt) + # print('recurrent largethan_maxminus5_context_ratio:', length_largethan_maxminus5_context_cnt_ratio) + # length_largethan_maxminus7_context_cnt_ratio = self.length_largethan_maxminus7_context_cnt / self.total_query_count + # print('recurrent largethan_maxminus7_context_ratio:', length_largethan_maxminus7_context_cnt_ratio) + # print('recurrent largethan_maxminus7_context:', self.length_largethan_maxminus7_context_cnt) + + # Trim and pad kv_cache + self.keys_values_wm_size_list = self.trim_and_pad_kv_cache(is_init_infer=False) + self.keys_values_wm_size_list_current = self.keys_values_wm_size_list + + for k in range(2): + # action_token obs_token + if k == 0: + obs_embeddings_or_act_tokens = {'act_tokens': token} + else: + obs_embeddings_or_act_tokens = {'obs_embeddings': token} + + # Perform forward pass + outputs_wm = self.forward( + obs_embeddings_or_act_tokens, + past_keys_values=self.keys_values_wm, + kvcache_independent=False, + is_init_infer=False, + task_id = task_id + ) + + self.keys_values_wm_size_list_current = [i + 1 for i in self.keys_values_wm_size_list_current] + + if k == 0: + reward = outputs_wm.logits_rewards # (B,) + + if k < self.num_observations_tokens: + token = outputs_wm.logits_observations + if len(token.shape) != 3: + token = token.unsqueeze(1) # (8,1024) -> (8,1,1024) + # print(f'token.shape:{token.shape}') + + latent_state_list.append(token) + + del self.latent_state # Very important to minimize cuda memory usage + self.latent_state = torch.cat(latent_state_list, dim=1) # (B, K) + + self.update_cache_context( + self.latent_state, + is_init_infer=False, + simulation_index=simulation_index, + latent_state_index_in_search_path=latent_state_index_in_search_path + ) + + return (outputs_wm.output_sequence, self.latent_state, reward, outputs_wm.logits_policy, outputs_wm.logits_value) + + def trim_and_pad_kv_cache(self, is_init_infer=True) -> list: + """ + Adjusts the key-value cache for each environment to ensure they all have the same size. + + In a multi-environment setting, the key-value cache (kv_cache) for each environment is stored separately. 
+ During recurrent inference, the kv_cache sizes may vary across environments. This method pads each kv_cache + to match the largest size found among them, facilitating batch processing in the transformer forward pass. + + Arguments: + - is_init_infer (:obj:`bool`): Indicates if this is an initial inference. Default is True. + Returns: + - list: Updated sizes of the key-value caches. + """ + # Find the maximum size among all key-value caches + max_size = max(self.keys_values_wm_size_list) + + # Iterate over each layer of the transformer + for layer in range(self.num_layers): + kv_cache_k_list = [] + kv_cache_v_list = [] + + # Enumerate through each environment's key-value pairs + for idx, keys_values in enumerate(self.keys_values_wm_list): + k_cache = keys_values[layer]._k_cache._cache + v_cache = keys_values[layer]._v_cache._cache + + effective_size = self.keys_values_wm_size_list[idx] + pad_size = max_size - effective_size + + # If padding is required, trim the end and pad the beginning of the cache + if pad_size > 0: + k_cache_trimmed = k_cache[:, :, :-pad_size, :] + v_cache_trimmed = v_cache[:, :, :-pad_size, :] + k_cache_padded = F.pad(k_cache_trimmed, (0, 0, pad_size, 0), "constant", 0) + v_cache_padded = F.pad(v_cache_trimmed, (0, 0, pad_size, 0), "constant", 0) + else: + k_cache_padded = k_cache + v_cache_padded = v_cache + + kv_cache_k_list.append(k_cache_padded) + kv_cache_v_list.append(v_cache_padded) + + # Stack the caches along a new dimension and remove any extra dimensions + self.keys_values_wm._keys_values[layer]._k_cache._cache = torch.stack(kv_cache_k_list, dim=0).squeeze(1) + self.keys_values_wm._keys_values[layer]._v_cache._cache = torch.stack(kv_cache_v_list, dim=0).squeeze(1) + + # Update the cache size to the maximum size + self.keys_values_wm._keys_values[layer]._k_cache._size = max_size + self.keys_values_wm._keys_values[layer]._v_cache._size = max_size + + return self.keys_values_wm_size_list + + #@profile + def update_cache_context(self, latent_state, is_init_infer=True, simulation_index=0, + latent_state_index_in_search_path=[], valid_context_lengths=None): + """ + Update the cache context with the given latent state. + + Arguments: + - latent_state (:obj:`torch.Tensor`): The latent state tensor. + - is_init_infer (:obj:`bool`): Flag to indicate if this is the initial inference. + - simulation_index (:obj:`int`): Index of the simulation. + - latent_state_index_in_search_path (:obj:`list`): List of indices in the search path. + - valid_context_lengths (:obj:`list`): List of valid context lengths. + """ + if self.context_length <= 2: + # No context to update if the context length is less than or equal to 2. 
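+ # (Reasoning: the trimming further below keeps at most `context_length - 3` steps after
+ # dropping the first two, which is degenerate for a context of length <= 2, so the cache
+ # update is skipped entirely in that case.)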
+ return + for i in range(latent_state.size(0)): + # ============ Iterate over each environment ============ + cache_key = hash_state(latent_state[i].view(-1).cpu().numpy()) # latent_state[i] is torch.Tensor + + context_length = self.context_length + + if not is_init_infer: + # ============ Internal Node ============ + # Retrieve KV from global KV cache self.keys_values_wm to single environment KV cache self.keys_values_wm_single_env, ensuring correct positional encoding + current_max_context_length = max(self.keys_values_wm_size_list_current) + trim_size = current_max_context_length - self.keys_values_wm_size_list_current[i] + for layer in range(self.num_layers): + # ============ Apply trimming and padding to each layer of kv_cache ============ + # cache shape [batch_size, num_heads, sequence_length, features] + k_cache_current = self.keys_values_wm._keys_values[layer]._k_cache._cache[i] + v_cache_current = self.keys_values_wm._keys_values[layer]._v_cache._cache[i] + + if trim_size > 0: + # Trim invalid leading zeros as per effective length + # Remove the first trim_size zero kv items + k_cache_trimmed = k_cache_current[:, trim_size:, :] + v_cache_trimmed = v_cache_current[:, trim_size:, :] + # If effective length < current_max_context_length, pad the end of cache with 'trim_size' zeros + k_cache_padded = F.pad(k_cache_trimmed, (0, 0, 0, trim_size), "constant", + 0) # Pad with 'trim_size' zeros at end of cache + v_cache_padded = F.pad(v_cache_trimmed, (0, 0, 0, trim_size), "constant", 0) + else: + k_cache_padded = k_cache_current + v_cache_padded = v_cache_current + + # Update cache of self.keys_values_wm_single_env + self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache = k_cache_padded.unsqueeze(0) + self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache = v_cache_padded.unsqueeze(0) + # Update size of self.keys_values_wm_single_env + self.keys_values_wm_single_env._keys_values[layer]._k_cache._size = \ + self.keys_values_wm_size_list_current[i] + self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = \ + self.keys_values_wm_size_list_current[i] + + # ============ NOTE: Very Important ============ + if self.keys_values_wm_single_env._keys_values[layer]._k_cache._size >= context_length - 1: + # import ipdb; ipdb.set_trace() + + # Keep only the last self.context_length-3 timesteps of context + # For memory environments, training is for H steps, recurrent_inference might exceed H steps + # Assuming cache dimension is [batch_size, num_heads, sequence_length, features] + k_cache_current = self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache + v_cache_current = self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache + + # Remove the first 2 steps, keep the last self.context_length-3 steps + k_cache_trimmed = k_cache_current[:, :, 2:context_length - 1, :].squeeze(0) + v_cache_trimmed = v_cache_current[:, :, 2:context_length - 1, :].squeeze(0) + + # Index pre-computed positional encoding differences + # import ipdb; ipdb.set_trace() + pos_emb_diff_k = self.pos_emb_diff_k[layer][(2, context_length - 1)] + pos_emb_diff_v = self.pos_emb_diff_v[layer][(2, context_length - 1)] + # ============ NOTE: Very Important ============ + # Apply positional encoding correction to k and v + k_cache_trimmed += pos_emb_diff_k.squeeze(0) + v_cache_trimmed += pos_emb_diff_v.squeeze(0) + + # Pad the last 3 steps along the third dimension with zeros + # F.pad parameters (0, 0, 0, 3) specify padding amounts for each dimension: (left, right, top, bottom). 
For 3D tensor, they correspond to (dim2 left, dim2 right, dim1 left, dim1 right). + padding_size = (0, 0, 0, 3) + k_cache_padded = F.pad(k_cache_trimmed, padding_size, 'constant', 0) + v_cache_padded = F.pad(v_cache_trimmed, padding_size, 'constant', 0) + # Update single environment cache + self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache = k_cache_padded.unsqueeze(0) + self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache = v_cache_padded.unsqueeze(0) + + self.keys_values_wm_single_env._keys_values[layer]._k_cache._size = context_length - 3 + self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = context_length - 3 + + else: + # ============ Root Node ============ + # Retrieve KV from global KV cache self.keys_values_wm to single environment KV cache self.keys_values_wm_single_env, ensuring correct positional encoding + + for layer in range(self.num_layers): + # ============ Apply trimming and padding to each layer of kv_cache ============ + + if self.keys_values_wm._keys_values[layer]._k_cache._size < context_length - 1: # Keep only the last self.context_length-1 timesteps of context + self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache = self.keys_values_wm._keys_values[layer]._k_cache._cache[i].unsqueeze(0) # Shape torch.Size([2, 100, 512]) + self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache = self.keys_values_wm._keys_values[layer]._v_cache._cache[i].unsqueeze(0) + self.keys_values_wm_single_env._keys_values[layer]._k_cache._size = self.keys_values_wm._keys_values[layer]._k_cache._size + self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = self.keys_values_wm._keys_values[layer]._v_cache._size + else: + # import ipdb; ipdb.set_trace() + + # Assuming cache dimension is [batch_size, num_heads, sequence_length, features] + k_cache_current = self.keys_values_wm._keys_values[layer]._k_cache._cache[i] + v_cache_current = self.keys_values_wm._keys_values[layer]._v_cache._cache[i] + + # Remove the first 2 steps, keep the last self.context_length-3 steps + k_cache_trimmed = k_cache_current[:, 2:context_length - 1, :] + v_cache_trimmed = v_cache_current[:, 2:context_length - 1, :] + + # Index pre-computed positional encoding differences + pos_emb_diff_k = self.pos_emb_diff_k[layer][(2, context_length - 1)] + pos_emb_diff_v = self.pos_emb_diff_v[layer][(2, context_length - 1)] + # ============ NOTE: Very Important ============ + # Apply positional encoding correction to k and v + k_cache_trimmed += pos_emb_diff_k.squeeze(0) + v_cache_trimmed += pos_emb_diff_v.squeeze(0) + + # Pad the last 3 steps along the third dimension with zeros + # F.pad parameters (0, 0, 0, 3) specify padding amounts for each dimension: (left, right, top, bottom). For 3D tensor, they correspond to (dim2 left, dim2 right, dim1 left, dim1 right). 
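+ # Hedged numeric example of this convention (comment only): F.pad pads from the last
+ # dimension backwards, so for a cache slice of shape (num_heads, seq_len, head_dim)
+ #   F.pad(torch.zeros(8, 5, 64), (0, 0, 0, 3)).shape  ->  torch.Size([8, 8, 64])
+ # i.e. three zero timesteps are appended along the sequence dimension.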
+ padding_size = (0, 0, 0, 3) + k_cache_padded = F.pad(k_cache_trimmed, padding_size, 'constant', 0) + v_cache_padded = F.pad(v_cache_trimmed, padding_size, 'constant', 0) + # Update cache of self.keys_values_wm_single_env + self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache = k_cache_padded.unsqueeze(0) + self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache = v_cache_padded.unsqueeze(0) + # Update size of self.keys_values_wm_single_env + self.keys_values_wm_single_env._keys_values[layer]._k_cache._size = context_length - 3 + self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = context_length - 3 + + # ORIGNAL + # if is_init_infer: + # # Store the latest key-value cache for initial inference + # cache_index = self.custom_copy_kv_cache_to_shared_init_envs(self.keys_values_wm_single_env, i) + # self.past_kv_cache_init_infer_envs[i][cache_key] = cache_index + # else: + # # Store the latest key-value cache for recurrent inference + # cache_index = self.custom_copy_kv_cache_to_shared_recur(self.keys_values_wm_single_env) + # self.past_kv_cache_recurrent_infer[cache_key] = cache_index + + if is_init_infer: + # TODO + # ==================== Active Eviction Logic ==================== + # 1. Retrieve the physical index that is about to be overwritten. + index_to_write = self.shared_pool_index_init_envs[i] + # 2. Use the auxiliary map to identify the old key stored at this index. + old_key_to_evict = self.pool_idx_to_key_map_init_envs[i][index_to_write] + # 3. If an old key exists, remove it from the main cache map. + if old_key_to_evict is not None: + # Ensure the key to be deleted actually exists to avoid unexpected errors. + if old_key_to_evict in self.past_kv_cache_init_infer_envs[i]: + del self.past_kv_cache_init_infer_envs[i][old_key_to_evict] + + # Now it is safe to write the new data. + cache_index = self.custom_copy_kv_cache_to_shared_init_envs(self.keys_values_wm_single_env, i) + + # 4. Update the new mapping in both the main cache map and the auxiliary map. + self.past_kv_cache_init_infer_envs[i][cache_key] = cache_index + self.pool_idx_to_key_map_init_envs[i][index_to_write] = cache_key + else: + # ==================== RECURRENT INFER FIX ==================== + # 1. Retrieve the physical index that is about to be overwritten. + index_to_write = self.shared_pool_index + # 2. Use the auxiliary map to identify the old key stored at this index. + old_key_to_evict = self.pool_idx_to_key_map_recur_infer[index_to_write] + # 3. If an old key exists, remove it from the main cache map. + if old_key_to_evict is not None: + if old_key_to_evict in self.past_kv_cache_recurrent_infer: + del self.past_kv_cache_recurrent_infer[old_key_to_evict] + + # 4. Now it is safe to write the new data. + cache_index = self.custom_copy_kv_cache_to_shared_recur(self.keys_values_wm_single_env) + + # 5. Update the new mapping in both the main cache map and the auxiliary map. + self.past_kv_cache_recurrent_infer[cache_key] = cache_index + self.pool_idx_to_key_map_recur_infer[index_to_write] = cache_key + + + #@profile + def retrieve_or_generate_kvcache(self, latent_state: list, ready_env_num: int, + simulation_index: int = 0, task_id = 0) -> list: + """ + Retrieves or generates key-value caches for each environment based on the latent state. + + For each environment, this method either retrieves a matching cache from the predefined + caches if available, or generates a new cache if no match is found. The method updates + the internal lists with these caches and their sizes. 
+ + Arguments: + - latent_state (:obj:`list`): List of latent states for each environment. + - ready_env_num (:obj:`int`): Number of environments ready for processing. + - simulation_index (:obj:`int`, optional): Index for simulation tracking. Default is 0. + Returns: + - list: Sizes of the key-value caches for each environment. + """ + for i in range(ready_env_num): + self.total_query_count += 1 + state_single_env = latent_state[i] # latent_state[i] is np.array + cache_key = hash_state(state_single_env) + + if self.reanalyze_phase: + # TODO: check if this is correct + matched_value = None + else: + # Try to retrieve the cached value from past_kv_cache_init_infer_envs + cache_index = self.past_kv_cache_init_infer_envs[i].get(cache_key) + if cache_index is not None: + matched_value = self.shared_pool_init_infer[i][cache_index] + else: + matched_value = None + + # Only try to find from recurrent_infer cache if not found in init_infer + if matched_value is None: + # Safely get the index from dictionary, it may return None + recur_cache_index = self.past_kv_cache_recurrent_infer.get(cache_key) + # Only use it to retrieve value from physical pool if the index is valid (not None) + if recur_cache_index is not None: + matched_value = self.shared_pool_recur_infer[recur_cache_index] + + if recur_cache_index is None: + logging.debug(f"[OLD CACHE MISS] Not found for key={cache_key} in recurrent infer. Generating new cache.") + + + if matched_value is not None: + # If a matching cache is found, add it to the lists + self.hit_count += 1 + # Perform a deep copy because the transformer's forward pass might modify matched_value in-place + self.keys_values_wm_list.append(self.custom_copy_kv_cache_to_shared_wm(matched_value)) + self.keys_values_wm_size_list.append(matched_value.size) + else: + # If no matching cache is found, generate a new one using zero reset + self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values( + n=1, max_tokens=self.context_length + ) + self.forward( + {'obs_embeddings': torch.from_numpy(state_single_env).unsqueeze(0).to(self.device)}, + past_keys_values=self.keys_values_wm_single_env, is_init_infer=True, task_id=task_id + ) + self.keys_values_wm_list.append(self.keys_values_wm_single_env) + self.keys_values_wm_size_list.append(1) + + return self.keys_values_wm_size_list + + def plot_embeddings( + self, + tsne_results: np.ndarray, + task_ids: np.ndarray, + observations: Union[np.ndarray, torch.Tensor], + samples_per_task: int = 5, + save_dir: str = 'tsne_plots_26games' + ) -> None: + """ + Overview: + Generates a t-SNE visualization plot and annotates it with a specified number of + randomly selected observation images for each task. + + Arguments: + - tsne_results (:obj:`np.ndarray`): The t-SNE dimensionality reduction results (N x 2 array). + - task_ids (:obj:`np.ndarray`): An array of environment task IDs, used for coloring the points (N array). + - observations (:obj:`Union[np.ndarray, torch.Tensor]`): The corresponding observation samples (N x C x H x W tensor or array). + - samples_per_task (:obj:`int`): The number of samples to select for image annotation per task. Defaults to 5. + - save_dir (:obj:`str`): The directory path where the plot will be saved. Defaults to 'tsne_plots_26games'. + """ + # Create the save directory if it doesn't exist. + os.makedirs(save_dir, exist_ok=True) + print(f"[INFO] Save directory created or already exists: {save_dir}") + + # Create the t-SNE plot. 
+ print("[INFO] Starting to draw the t-SNE scatter plot...") + plt.figure(figsize=(18, 10)) # Increase figure width to accommodate the legend on the right. + + # Scatter plot of the t-SNE results. + scatter = plt.scatter( + tsne_results[:, 0], + tsne_results[:, 1], + c=[self.colors[tid] for tid in task_ids], + alpha=0.6, + edgecolor='w', + linewidth=0.5 + ) + + # Create a custom legend for the tasks. + legend_elements = [] + for idx, env_id in enumerate(self.env_id_list): + short_name = self.env_short_names.get(env_id, env_id) + color = self.colors[idx] + legend_elements.append( + Patch(facecolor=color, edgecolor='w', label=f"{idx}: {short_name}") + ) + + # Place the legend on the right side of the plot, with each item on a new line. + plt.legend( + handles=legend_elements, + title="Environment IDs", + loc='center left', + bbox_to_anchor=(1, 0.5), # Position the legend in the center-right of the plot area. + fontsize=10, + title_fontsize=12, + ncol=1, + frameon=False # Remove the legend border for a cleaner look. + ) + + # Set the title and axis labels. + plt.title("t-SNE of Latent States across Environments", fontsize=16) + plt.xlabel("t-SNE Dimension 1", fontsize=14) + plt.ylabel("t-SNE Dimension 2", fontsize=14) + plt.xticks(fontsize=12) + plt.yticks(fontsize=12) + plt.grid(True, linestyle='--', alpha=0.5) + print(f"[INFO] t-SNE scatter plot completed with {len(tsne_results)} points.") + + # Select a specified number of samples per task for image annotation. + print(f"[INFO] Starting to select {samples_per_task} samples per task for image annotation...") + for task_id in range(len(self.env_id_list)): + # Find all indices for the current task. + task_indices = np.where(task_ids == task_id)[0] + if len(task_indices) == 0: + print(f"[WARNING] No samples found for task ID {task_id}.") + continue + + # If the number of samples is less than required, select all of them. + if len(task_indices) < samples_per_task: + selected_indices = task_indices + print(f"[INFO] Task ID {task_id} has fewer samples ({len(task_indices)}) than required ({samples_per_task}). Selecting all.") + else: + selected_indices = np.random.choice(task_indices, size=samples_per_task, replace=False) + print(f"[INFO] Randomly selecting {samples_per_task} samples for task ID {task_id} for annotation.") + + for idx in selected_indices: + img = observations[idx] + if isinstance(img, torch.Tensor): + img = img.cpu().numpy() + + # Handle channel-first (C, H, W) format for grayscale or RGB images. + if img.shape[0] == 1 or img.shape[0] == 3: + img = np.transpose(img, (1, 2, 0)) + else: + raise ValueError(f"Unsupported image shape: {img.shape}") + + # Normalize the image to the [0, 1] range for correct display. + img_min, img_max = img.min(), img.max() + if img_max - img_min > 1e-5: + img = (img - img_min) / (img_max - img_min) + else: + img = np.zeros_like(img) + + imagebox = OffsetImage(img, zoom=0.5) + ab = AnnotationBbox( + imagebox, + (tsne_results[idx, 0], tsne_results[idx, 1]), + frameon=False, + pad=0.3 + ) + plt.gca().add_artist(ab) + print(f"[INFO] Added image annotation: Task ID {task_id}, point index {idx}, t-SNE coords ({tsne_results[idx, 0]:.2f}, {tsne_results[idx, 1]:.2f})") + + # Adjust layout to prevent the legend from being cut off. + plt.tight_layout(rect=[0, 0, 0.9, 1]) # Reserve space for the legend on the right. + + # Save the figure in both PNG and PDF formats with high resolution. 
+ save_path_png = os.path.join(save_dir, 'tsne_plot.png') + save_path_pdf = os.path.join(save_dir, 'tsne_plot.pdf') + plt.savefig(save_path_png, dpi=300, bbox_inches='tight') + plt.savefig(save_path_pdf, dpi=300, bbox_inches='tight') + print(f"[INFO] t-SNE visualization plot saved to: {save_path_png} and {save_path_pdf}") + plt.close() + + @torch.no_grad() + def gather_and_plot( + self, + local_embeddings: torch.Tensor, + local_task_ids: torch.Tensor, + local_observations: torch.Tensor + ) -> None: + """ + Overview: + Gathers embeddings, task IDs, and observations from all distributed processes. + On the main process (rank 0), it performs t-SNE and plots the results. + + Arguments: + - local_embeddings (:obj:`torch.Tensor`): The embedding tensor from the current process. + - local_task_ids (:obj:`torch.Tensor`): The task ID tensor from the current process. + - local_observations (:obj:`torch.Tensor`): The observation tensor from the current process. + """ + world_size = dist.get_world_size() + rank = dist.get_rank() + + # Prepare lists to receive CUDA tensors from all processes. + embeddings_list = [torch.zeros_like(local_embeddings) for _ in range(world_size)] + task_ids_list = [torch.zeros_like(local_task_ids) for _ in range(world_size)] + + # Prepare a list to receive CPU objects (observations) from all processes. + observations_list = [None for _ in range(world_size)] + + try: + # Gather CUDA tensors: embeddings and task_ids. + dist.all_gather(embeddings_list, local_embeddings) + dist.all_gather(task_ids_list, local_task_ids) + + # Gather CPU objects: observations (must be moved to CPU and converted first). + local_observations_cpu = local_observations.cpu().numpy().tolist() + dist.all_gather_object(observations_list, local_observations_cpu) + except RuntimeError as e: + print(f"Rank {rank}: all_gather failed with error: {e}") + return + + if rank == 0: + # Concatenate all embeddings and task_ids on the main process. + all_embeddings = torch.cat(embeddings_list, dim=0).cpu().numpy() + all_task_ids = torch.cat(task_ids_list, dim=0).cpu().numpy() + + # Concatenate all observations. + all_observations_list = [] + for obs in observations_list: + all_observations_list.extend(obs) + all_observations = np.array(all_observations_list) + + print(f"Shape of all_embeddings: {all_embeddings.shape}") + all_embeddings = all_embeddings.reshape(-1, all_embeddings.shape[-1]) + print(f"Shape of all_observations: {all_observations.shape}") + all_observations = all_observations.reshape(-1, *all_observations.shape[-3:]) + + # Perform t-SNE dimensionality reduction. + tsne = TSNE(n_components=2, random_state=42) + tsne_results = tsne.fit_transform(all_embeddings) + + # Plot and save the resulting image. 
+ self.plot_embeddings(tsne_results, all_task_ids, all_observations, save_dir=f'tsne_plots_{self.num_tasks}games') + + #@profile + def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar_transform_handle=None, task_id = 0, **kwargs: Any) -> LossWithIntermediateLosses: + # Encode observations into latent state representations + obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch['observations'], task_id=task_id) + + if self.analysis_tsne: + # =========== tsne analysis =========== + if not obs_embeddings.is_cuda: + obs_embeddings = obs_embeddings.cuda() + obs_embeddings = obs_embeddings.contiguous() + local_embeddings = obs_embeddings.detach() + local_task_ids = torch.full((local_embeddings.size(0),), task_id, dtype=torch.long, device=local_embeddings.device) + local_observations = batch['observations'].detach().cpu() + self.gather_and_plot(local_embeddings, local_task_ids, local_observations) + + # ========= logging for analysis ========= + if self.analysis_dormant_ratio_weight_rank: + self._analysis_step_counter += 1 + self.do_analysis = ( + self.analysis_dormant_ratio_weight_rank # 总开关 + and self._analysis_step_counter % self.analysis_dormant_ratio_interval == 0 + ) + + # ========= logging for analysis ========= + if self.do_analysis: + # Calculate dormant ratio of the encoder + shape = batch['observations'].shape # (..., C, H, W) + inputs = batch['observations'].contiguous().view(-1, *shape[-3:]) # (32,5,3,64,64) -> (160,3,64,64) + if self.continuous_action_space: + encoder_index = task_id + else: + encoder_index = 0 + dormant_ratio_encoder_dict = calculate_dormant_ratio(self.tokenizer.encoder[encoder_index], inputs.detach(), + dormant_threshold=self.dormant_threshold) + + dormant_ratio_encoder = dormant_ratio_encoder_dict['global'] + + avg_weight_mag_encoder = compute_average_weight_magnitude(self.tokenizer.encoder[encoder_index]) + avg_weight_mag_transformer = compute_average_weight_magnitude(self.transformer) + avg_weight_mag_head = compute_average_weight_magnitude(self.head_dict) + + e_rank_last_linear = calculate_effective_rank(self.tokenizer.encoder[encoder_index], inputs, representation_layer_name="last_linear") + try: + e_rank_sim_norm = calculate_effective_rank(self.tokenizer.encoder[encoder_index], inputs, representation_layer_name="final_norm") + except Exception as e: + e_rank_sim_norm = torch.tensor(0.) + + for kv_cache_dict_env in self.past_kv_cache_init_infer_envs: + kv_cache_dict_env.clear() + self.past_kv_cache_recurrent_infer.clear() + self.keys_values_wm_list.clear() + torch.cuda.empty_cache() + else: + dormant_ratio_encoder = torch.tensor(0.) + avg_weight_mag_encoder = torch.tensor(0.) + avg_weight_mag_transformer = torch.tensor(0.) + avg_weight_mag_head = torch.tensor(0.) + e_rank_last_linear = torch.tensor(0.) + e_rank_sim_norm = torch.tensor(0.) 
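+ # For reference (sketch, mirroring compute_effective_rank added to lzero/model/utils.py in
+ # this change): with singular values normalised to p_i, effective_rank = exp(-sum_i p_i * log p_i),
+ # so larger values mean the representation spreads its energy over more directions.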
+ # dormant_ratio_encoder = None + + + # Calculate the L2 norm of the latent state roots + latent_state_l2_norms = torch.norm(obs_embeddings, p=2, dim=2).mean() + + if self.obs_type == 'image': + # Reconstruct observations from latent state representations + # reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings) + + # ========== for visualization ========== + # Uncomment the lines below for visual analysis + # original_images, reconstructed_images = batch['observations'], reconstructed_images + # target_policy = batch['target_policy'] + # ==== for value priority ==== + # target_predict_value = inverse_scalar_transform_handle(batch['target_value'].reshape(-1, 101)).reshape( + # batch['observations'].shape[0], batch['observations'].shape[1], 1) + # true_rewards = inverse_scalar_transform_handle(batch['rewards'].reshape(-1, 101)).reshape( + # batch['observations'].shape[0], batch['observations'].shape[1], 1) + # ========== for visualization ========== + + # Calculate reconstruction loss and perceptual loss + # latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # NOTE: for stack=1 + # perceptual_loss = self.tokenizer.perceptual_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # NOTE: for stack=1 + latent_recon_loss = torch.tensor(0., device=batch['observations'].device, + dtype=batch['observations'].dtype) + perceptual_loss = torch.tensor(0., device=batch['observations'].device, + dtype=batch['observations'].dtype) + + elif self.obs_type == 'vector': + perceptual_loss = torch.tensor(0., device=batch['observations'].device, + dtype=batch['observations'].dtype) + + # Reconstruct observations from latent state representations + # reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings.reshape(-1, self.embed_dim)) + # # Calculate reconstruction loss + # latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 25), + # reconstructed_images) + latent_recon_loss = torch.tensor(0., device=batch['observations'].device, + dtype=batch['observations'].dtype) + + elif self.obs_type == 'image_memory': + # Reconstruct observations from latent state representations + # reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings) + # original_images, reconstructed_images = batch['observations'], reconstructed_images + + # ========== for visualization ========== + # Uncomment the lines below for visual analysis + # target_policy = batch['target_policy'] + # target_predict_value = inverse_scalar_transform_handle(batch['target_value'].reshape(-1, 101)).reshape( + # batch['observations'].shape[0], batch['observations'].shape[1], 1) + # true_rewards = inverse_scalar_transform_handle(batch['rewards'].reshape(-1, 101)).reshape( + # batch['observations'].shape[0], batch['observations'].shape[1], 1) + # ========== for visualization ========== + + # Calculate reconstruction loss and perceptual loss + # latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 3, 5, 5), + # reconstructed_images) + + latent_recon_loss = torch.tensor(0., device=batch['observations'].device, + dtype=batch['observations'].dtype) + perceptual_loss = torch.tensor(0., device=batch['observations'].device, + dtype=batch['observations'].dtype) + + # Action tokens + if self.continuous_action_space: + act_tokens = batch['actions'] + else: + act_tokens = rearrange(batch['actions'], 'b l -> b l 1') + + # Forward pass to obtain predictions for observations, 
rewards, and policies + outputs = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)}, task_id=task_id) + + if self.config.use_priority: + # ==================== START MODIFICATION 5 ==================== + # Calculate value_priority, similar to MuZero. + with torch.no_grad(): + # 1. Get the predicted value logits for the first step of the sequence (t=0). + # The shape is (B, support_size). + predicted_value_logits_step0 = outputs.logits_value[:, 0, :] + + # 2. Convert the categorical prediction to a scalar value. + # The shape becomes (B, 1). + predicted_scalar_value_step0 = inverse_scalar_transform_handle(predicted_value_logits_step0) + + # 3. Get the target scalar value for the first step from the batch. + # The shape is (B, num_unroll_steps), so we take the first column. + target_scalar_value_step0 = batch['scalar_target_value'][:, 0] + + # 4. Calculate the L1 loss (absolute difference) between prediction and target. + # This is the priority. We use reduction='none' to get per-sample priorities. + value_priority = F.l1_loss(predicted_scalar_value_step0.squeeze(-1), target_scalar_value_step0, reduction='none') + # ===================== END MODIFICATION 5 ===================== + else: + value_priority = torch.tensor(0.) + + # ========= logging for analysis ========= + # if self.analysis_dormant_ratio_weight_rank: + if self.do_analysis: + # Calculate dormant ratio of the world model + dormant_ratio_world_model = calculate_dormant_ratio(self, { + 'obs_embeddings_and_act_tokens': (obs_embeddings.detach(), act_tokens.detach())}, + dormant_threshold=self.dormant_threshold) + dormant_ratio_transformer = dormant_ratio_world_model['transformer'] + dormant_ratio_head = dormant_ratio_world_model['head'] + for kv_cache_dict_env in self.past_kv_cache_init_infer_envs: + kv_cache_dict_env.clear() + self.past_kv_cache_recurrent_infer.clear() + self.keys_values_wm_list.clear() + torch.cuda.empty_cache() + else: + dormant_ratio_transformer = torch.tensor(0.) + dormant_ratio_head = torch.tensor(0.) 
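+ # Sketch of the convention assumed for `calculate_dormant_ratio`: it is expected to return a
+ # dict of dormant fractions keyed per module group (e.g. 'global', 'transformer', 'head', as
+ # indexed here), where a unit counts as dormant when its normalized activation score falls
+ # below `dormant_threshold`.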
+ + # ========== for visualization ========== + # Uncomment the lines below for visualization + # predict_policy = outputs.logits_policy + # predict_policy = F.softmax(outputs.logits_policy, dim=-1) + # predict_value = inverse_scalar_transform_handle(outputs.logits_value.reshape(-1, 101)).reshape(batch['observations'].shape[0], batch['observations'].shape[1], 1) + # predict_rewards = inverse_scalar_transform_handle(outputs.logits_rewards.reshape(-1, 101)).reshape(batch['observations'].shape[0], batch['observations'].shape[1], 1) + # import pdb; pdb.set_trace() + # visualize_reward_value_img_policy(original_images, reconstructed_images, target_predict_value, true_rewards, target_policy, predict_value, predict_rewards, predict_policy, not_plot_timesteps=[], suffix='pong_H10_H4_0613') + + # visualize_reward_value_img_policy(original_images, reconstructed_images, target_predict_value, true_rewards, target_policy, predict_value, predict_rewards, predict_policy, not_plot_timesteps=list(np.arange(4,60)), suffix='visual_match_memlen1-60-15/one_success_episode') + # visualize_reward_value_img_policy(original_images, reconstructed_images, target_predict_value, true_rewards, target_policy, predict_value, predict_rewards, predict_policy, not_plot_timesteps=list(np.arange(4,60)), suffix='visual_match_memlen1-60-15/one_fail_episode') + # ========== for visualization ========== + + # For training stability, use target_tokenizer to compute the true next latent state representations + with torch.no_grad(): + target_obs_embeddings = target_tokenizer.encode_to_obs_embeddings(batch['observations'], task_id=task_id) + + # Compute labels for observations, rewards, and ends + labels_observations, labels_rewards, labels_ends = self.compute_labels_world_model(target_obs_embeddings, + batch['rewards'], + batch['ends'], + batch['mask_padding']) + + # Reshape the logits and labels for observations + logits_observations = rearrange(outputs.logits_observations[:, :-1], 'b t o -> (b t) o') + labels_observations = labels_observations.reshape(-1, self.projection_input_dim) + + if self.use_task_embed and self.task_embed_option == "concat_task_embed": + # Expand task embeddings to match the sequence shape + self.task_embeddings = self.task_emb(torch.tensor(task_id, device=self.device)) # NOTE: TODO + self.task_embeddings = self.sim_norm(self.task_embeddings.view(1,-1)).view(-1) # TODO + task_emb_expanded = self.task_embeddings.expand(labels_observations.shape[0], -1) + labels_observations = torch.cat([labels_observations, task_emb_expanded.detach()], dim=-1) # NOTE: detach() + + # Compute prediction loss for observations. 
Options: MSE and Group KL + if self.predict_latent_loss_type == 'mse': + # MSE loss, directly compare logits and labels + loss_obs = torch.nn.functional.mse_loss(logits_observations, labels_observations, reduction='none').mean( + -1) + elif self.predict_latent_loss_type == 'group_kl': + # Group KL loss, group features and calculate KL divergence within each group + batch_size, num_features = logits_observations.shape + epsilon = 1e-6 + logits_reshaped = logits_observations.reshape(batch_size, self.num_groups, self.group_size) + epsilon + labels_reshaped = labels_observations.reshape(batch_size, self.num_groups, self.group_size) + epsilon + + loss_obs = F.kl_div(logits_reshaped.log(), labels_reshaped, reduction='none').sum(dim=-1).mean(dim=-1) + + # ========== for debugging ========== + # assert not torch.isnan(logits_reshaped).any(), "logits_reshaped contains NaN values" + # assert not torch.isnan(labels_reshaped).any(), "labels_reshaped contains NaN values" + # print('loss_obs:', loss_obs.mean()) + # for name, param in self.tokenizer.encoder.named_parameters(): + # print('name, param.mean(), param.std():', name, param.mean(), param.std()) + # logits_grad = torch.autograd.grad(loss_obs.mean(), logits_observations, retain_graph=True)[0] + # print(f"logits_grad (min, max, mean): {logits_grad.min()}, {logits_grad.max()}, {logits_grad.mean()}") + + # Apply mask to loss_obs + mask_padding_expanded = batch['mask_padding'][:, 1:].contiguous().view(-1) + loss_obs = (loss_obs * mask_padding_expanded) + + # Compute labels for policy and value + labels_policy, labels_value = self.compute_labels_world_model_value_policy(batch['target_value'], + batch['target_policy'], + batch['mask_padding']) + + # Compute losses for rewards, policy, and value + loss_rewards = self.compute_cross_entropy_loss(outputs, labels_rewards, batch, element='rewards') + + if not self.continuous_action_space: + loss_policy, orig_policy_loss, policy_entropy = self.compute_cross_entropy_loss(outputs, labels_policy, + batch, + element='policy') + else: + # NOTE: for continuous action space + if self.config.policy_loss_type == 'simple': + orig_policy_loss, policy_entropy_loss, target_policy_entropy, target_sampled_actions, mu, sigma = self._calculate_policy_loss_cont_simple( + outputs, batch) + else: + orig_policy_loss, policy_entropy_loss, target_policy_entropy, target_sampled_actions, mu, sigma = self._calculate_policy_loss_cont( + outputs, batch, task_id=task_id) + + loss_policy = orig_policy_loss + self.policy_entropy_weight * policy_entropy_loss + policy_entropy = - policy_entropy_loss + + loss_value = self.compute_cross_entropy_loss(outputs, labels_value, batch, element='value') + + # Compute timesteps + timesteps = torch.arange(batch['actions'].shape[1], device=batch['actions'].device) + # Compute discount coefficients for each timestep + discounts = self.gamma ** timesteps + + if batch['mask_padding'].sum() == 0: + assert False, "mask_padding is all zeros" + + # Group losses into first step, middle step, and last step + first_step_losses = {} + middle_step_losses = {} + last_step_losses = {} + # batch['mask_padding'] indicates mask status for future H steps, exclude masked losses to maintain accurate mean statistics + # Group losses for each loss item + for loss_name, loss_tmp in zip( + ['loss_obs', 'loss_rewards', 'loss_value', 'loss_policy', 'orig_policy_loss', 'policy_entropy'], + [loss_obs, loss_rewards, loss_value, loss_policy, orig_policy_loss, policy_entropy] + ): + if loss_name == 'loss_obs': + seq_len = 
batch['actions'].shape[1] - 1 + # Get the corresponding mask_padding + mask_padding = batch['mask_padding'][:, 1:seq_len] + else: + seq_len = batch['actions'].shape[1] + # Get the corresponding mask_padding + mask_padding = batch['mask_padding'][:, :seq_len] + + # Adjust loss shape to (batch_size, seq_len) + loss_tmp = loss_tmp.view(-1, seq_len) + + # First step loss + first_step_mask = mask_padding[:, 0] + first_step_losses[loss_name] = loss_tmp[:, 0][first_step_mask].mean() + + # Middle step loss + middle_step_index = seq_len // 2 + middle_step_mask = mask_padding[:, middle_step_index] + middle_step_losses[loss_name] = loss_tmp[:, middle_step_index][middle_step_mask].mean() + + # Last step loss + last_step_mask = mask_padding[:, -1] + last_step_losses[loss_name] = loss_tmp[:, -1][last_step_mask].mean() + + # Discount reconstruction loss and perceptual loss + discounted_latent_recon_loss = latent_recon_loss + discounted_perceptual_loss = perceptual_loss + + # Calculate overall discounted loss + discounted_loss_obs = (loss_obs.view(-1, batch['actions'].shape[1] - 1) * discounts[1:]).sum()/ batch['mask_padding'][:,1:].sum() + discounted_loss_rewards = (loss_rewards.view(-1, batch['actions'].shape[1]) * discounts).sum()/ batch['mask_padding'].sum() + discounted_loss_value = (loss_value.view(-1, batch['actions'].shape[1]) * discounts).sum()/ batch['mask_padding'].sum() + discounted_loss_policy = (loss_policy.view(-1, batch['actions'].shape[1]) * discounts).sum()/ batch['mask_padding'].sum() + discounted_orig_policy_loss = (orig_policy_loss.view(-1, batch['actions'].shape[1]) * discounts).sum()/ batch['mask_padding'].sum() + discounted_policy_entropy = (policy_entropy.view(-1, batch['actions'].shape[1]) * discounts).sum()/ batch['mask_padding'].sum() + + detached_obs_embeddings = obs_embeddings.detach() + + if self.continuous_action_space: + return LossWithIntermediateLosses( + latent_recon_loss_weight=self.latent_recon_loss_weight, + perceptual_loss_weight=self.perceptual_loss_weight, + continuous_action_space=True, + loss_obs=discounted_loss_obs, + loss_rewards=discounted_loss_rewards, + loss_value=discounted_loss_value, + loss_policy=discounted_loss_policy, + latent_recon_loss=discounted_latent_recon_loss, + perceptual_loss=discounted_perceptual_loss, + orig_policy_loss=discounted_orig_policy_loss, + policy_entropy=discounted_policy_entropy, + first_step_losses=first_step_losses, + middle_step_losses=middle_step_losses, + last_step_losses=last_step_losses, + dormant_ratio_encoder=dormant_ratio_encoder, + dormant_ratio_transformer=dormant_ratio_transformer, + dormant_ratio_head=dormant_ratio_head, + avg_weight_mag_encoder = avg_weight_mag_encoder, + avg_weight_mag_transformer = avg_weight_mag_transformer, + avg_weight_mag_head = avg_weight_mag_head, + e_rank_last_linear = e_rank_last_linear, + e_rank_sim_norm = e_rank_sim_norm, + latent_state_l2_norms=latent_state_l2_norms, + policy_mu=mu, + policy_sigma=sigma, + target_sampled_actions=target_sampled_actions, + + value_priority=value_priority, + obs_embeddings=detached_obs_embeddings, + logits_value=outputs.logits_value.detach(), + logits_reward=outputs.logits_rewards.detach(), + logits_policy=outputs.logits_policy.detach(), + + ) + else: + return LossWithIntermediateLosses( + latent_recon_loss_weight=self.latent_recon_loss_weight, + perceptual_loss_weight=self.perceptual_loss_weight, + continuous_action_space=False, + loss_obs=discounted_loss_obs, + loss_rewards=discounted_loss_rewards, + loss_value=discounted_loss_value, + 
loss_policy=discounted_loss_policy, + latent_recon_loss=discounted_latent_recon_loss, + perceptual_loss=discounted_perceptual_loss, + orig_policy_loss=discounted_orig_policy_loss, + policy_entropy=discounted_policy_entropy, + first_step_losses=first_step_losses, + middle_step_losses=middle_step_losses, + last_step_losses=last_step_losses, + dormant_ratio_encoder=dormant_ratio_encoder, + dormant_ratio_transformer=dormant_ratio_transformer, + dormant_ratio_head=dormant_ratio_head, + avg_weight_mag_encoder = avg_weight_mag_encoder, + avg_weight_mag_transformer = avg_weight_mag_transformer, + avg_weight_mag_head = avg_weight_mag_head, + e_rank_last_linear = e_rank_last_linear, + e_rank_sim_norm = e_rank_sim_norm, + latent_state_l2_norms=latent_state_l2_norms, + + value_priority=value_priority, + obs_embeddings=detached_obs_embeddings, + + logits_value=outputs.logits_value.detach(), + logits_reward=outputs.logits_rewards.detach(), + logits_policy=outputs.logits_policy.detach(), + + + ) + + #@profile + def compute_cross_entropy_loss(self, outputs, labels, batch, element='rewards'): + # Assume outputs is an object with logits attributes like 'rewards', 'policy', and 'value'. + # labels is a target tensor for comparison. batch is a dictionary with a mask indicating valid timesteps. + + logits = getattr(outputs, f'logits_{element}') + + # Reshape your tensors + logits = rearrange(logits, 'b t e -> (b t) e') + labels = labels.reshape(-1, labels.shape[-1]) # Assume labels initially have shape [batch, time, dim] + + # Reshape your mask. True indicates valid data. + mask_padding = rearrange(batch['mask_padding'], 'b t -> (b t)') + + # Compute cross-entropy loss + loss = -(torch.log_softmax(logits, dim=1) * labels).sum(1) + loss = (loss * mask_padding) + + # if torch.isnan(loss).any(): + # raise ValueError(f"NaN detected in outputs for batch {batch} and element '{element}'") + + if element == 'policy': + # Compute policy entropy loss + policy_entropy = self.compute_policy_entropy_loss(logits, mask_padding) + # Combine losses with specified weight + combined_loss = loss - self.policy_entropy_weight * policy_entropy + return combined_loss, loss, policy_entropy + + return loss + + #@profile + def compute_policy_entropy_loss(self, logits, mask): + # Compute entropy of the policy + probs = torch.softmax(logits, dim=1) + log_probs = torch.log_softmax(logits, dim=1) + entropy = -(probs * log_probs).sum(1) + # Apply mask and return average entropy loss + entropy_loss = (entropy * mask) + return entropy_loss + + #@profile + def compute_labels_world_model(self, obs_embeddings: torch.Tensor, rewards: torch.Tensor, ends: torch.Tensor, + mask_padding: torch.BoolTensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # assert torch.all(ends.sum(dim=1) <= 1) # Each sequence sample should have at most one 'done' flag + mask_fill = torch.logical_not(mask_padding) + + # Prepare observation labels + labels_observations = obs_embeddings.contiguous().view(rewards.shape[0], -1, self.projection_input_dim)[:, 1:] + + # Fill the masked areas of rewards + mask_fill_rewards = mask_fill.unsqueeze(-1).expand_as(rewards) + labels_rewards = rewards.masked_fill(mask_fill_rewards, -100) + + # Fill the masked areas of ends + # labels_ends = ends.masked_fill(mask_fill, -100) + + # return labels_observations, labels_rewards.reshape(-1, self.support_size), labels_ends.reshape(-1) + return labels_observations, labels_rewards.view(-1, self.support_size), None + + #@profile + def compute_labels_world_model_value_policy(self, 
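# NOTE: illustration only, not part of the patch. compute_cross_entropy_loss and
# compute_policy_entropy_loss above amount to a soft-label cross entropy with an
# entropy bonus on the valid timesteps; `policy_entropy_weight` here is a placeholder.
import torch

def soft_ce_with_entropy_bonus(logits, soft_labels, mask, policy_entropy_weight=5e-3):
    log_probs = torch.log_softmax(logits, dim=1)
    ce = -(log_probs * soft_labels).sum(1) * mask        # per-sample cross entropy on valid steps
    probs = torch.softmax(logits, dim=1)
    entropy = -(probs * log_probs).sum(1) * mask         # per-sample policy entropy on valid steps
    return ce - policy_entropy_weight * entropy          # higher entropy lowers the combined loss

logits = torch.randn(6, 4)
labels = torch.softmax(torch.randn(6, 4), dim=1)
mask = torch.ones(6)
print(soft_ce_with_entropy_bonus(logits, labels, mask).shape)  # torch.Size([6])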
target_value: torch.Tensor, target_policy: torch.Tensor, + mask_padding: torch.BoolTensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ Compute labels for value and policy predictions. """ + mask_fill = torch.logical_not(mask_padding) + + # Fill the masked areas of policy + mask_fill_policy = mask_fill.unsqueeze(-1).expand_as(target_policy) + labels_policy = target_policy.masked_fill(mask_fill_policy, -100) + + # Fill the masked areas of value + mask_fill_value = mask_fill.unsqueeze(-1).expand_as(target_value) + labels_value = target_value.masked_fill(mask_fill_value, -100) + + if self.continuous_action_space: + return None, labels_value.reshape(-1, self.support_size) + else: + return labels_policy.reshape(-1, self.action_space_size), labels_value.reshape(-1, self.support_size) + + #@profile + def clear_caches(self): + """ + Clears the caches of the world model. + """ + for kv_cache_dict_env in self.past_kv_cache_init_infer_envs: + kv_cache_dict_env.clear() + self.past_kv_cache_recurrent_infer.clear() + self.keys_values_wm_list.clear() + + print(f'rank {self._rank} Cleared {self.__class__.__name__} past_kv_cache.') + + def __repr__(self) -> str: + return "transformer-based latent world_model of UniZero" diff --git a/lzero/model/utils.py b/lzero/model/utils.py index 70a89d3b0..1204070f9 100644 --- a/lzero/model/utils.py +++ b/lzero/model/utils.py @@ -1,163 +1,319 @@ """ Overview: - In this file, we provide a set of utility functions for probing network parameters and gradients, - which can be helpful in analyzing and debugging the inner workings of various models. + This file provides a set of utility functions for probing network parameters and gradients. + These tools are helpful for analyzing and debugging the inner workings of various models. """ -from typing import List, Tuple +from typing import List, Tuple, Union, Dict, Type, Optional import numpy as np import torch import torch.nn as nn -class LinearOutputHook: +def compute_average_weight_magnitude(model: nn.Module) -> float: """ Overview: - Hook to capture the output of linear layers. + Calculates the average absolute magnitude of all parameters in a given model. + + Arguments: + - model (:obj:`nn.Module`): The model to be evaluated. + + Returns: + - float: The average absolute magnitude of the model's weights. + """ + num_weights = 0 + # Use the device of the model's first parameter to ensure consistency. + device = next(model.parameters()).device + sum_weight_magnitude = torch.tensor(0.0, device=device) + + for p in model.parameters(): + num_weights += p.numel() + sum_weight_magnitude += torch.sum(torch.abs(p)) + + if num_weights == 0: + return 0.0 + return sum_weight_magnitude.cpu().item() / num_weights + + +def compute_effective_rank(singular_values: np.ndarray) -> float: """ + Overview: + Computes the effective rank from an array of singular values. The formula is: + effective_rank = exp(-sum_i [p_i * log(p_i)]), where p_i is the normalized singular value. + + Arguments: + - singular_values (:obj:`np.ndarray`): An array of singular values. + + Returns: + - float: The calculated effective rank. + """ + # Normalize singular values to form a probability distribution. + norm_sv = singular_values / np.sum(np.abs(singular_values)) + entropy = 0.0 + for p in norm_sv: + if p > 1e-8: # Avoid log(0) + entropy -= p * np.log(p) + return np.exp(entropy) + +class IntermediateOutputHook: + """ + Overview: + A hook class to capture and store the output tensors from a specific nn.Module during a forward pass. 
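# NOTE: illustration only, not part of the patch. A quick sanity check of
# compute_average_weight_magnitude above: a model whose weights are all 0.5
# should report exactly 0.5.
import torch
import torch.nn as nn
from lzero.model.utils import compute_average_weight_magnitude

model = nn.Linear(4, 2, bias=False)
with torch.no_grad():
    model.weight.fill_(0.5)
print(compute_average_weight_magnitude(model))  # 0.5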
+ """ def __init__(self): + self.outputs: List[torch.Tensor] = [] + + def __call__(self, module: nn.Module, inputs: Tuple[torch.Tensor, ...], output: torch.Tensor) -> None: """ Overview: - Initialize the hook. + This method is called by PyTorch when the hooked module completes its forward pass. """ - self.outputs: List[torch.Tensor] = [] + # Detach the tensor from the computation graph and move to CPU to save memory. + self.outputs.append(output.detach().cpu()) - def __call__(self, module: nn.Module, input: Tuple[torch.Tensor], output: torch.Tensor) -> None: + def clear(self) -> None: """ Overview: - Capture the output of the module. - Arguments: - - module: The module being hooked. - - input: The input to the module (unused in this hook). - - output: The output from the module. + Clears the list of captured outputs. """ - self.outputs.append(output) + self.outputs.clear() -def cal_dormant_ratio(model: nn.Module, *inputs: torch.Tensor, percentage: float = 0.025) -> float: +def calculate_effective_rank( + model: nn.Module, + inputs: Union[torch.Tensor, List[torch.Tensor]], + representation_layer_name: str, +) -> float: """ Overview: - Calculate the dormant neuron ratio in the model. A neuron is considered dormant if its output is less than a - specified percentage of the average output of the layer. This function is useful for analyzing the sparsity of the model. - More details can be found in the paper https://arxiv.org/abs/2302.12902. + Calculates the effective rank of a specified intermediate layer's output (representation) + by using a forward hook to capture the activations. + Arguments: - - model: The model to evaluate. - - inputs: The inputs to the model. - - percentage: The threshold percentage to consider a neuron dormant, defaults to 0.025. + - model (:obj:`nn.Module`): The model to be evaluated. + - inputs (:obj:`Union[torch.Tensor, List[torch.Tensor]]`): The inputs for the model's forward pass. + - representation_layer_name (:obj:`str`): The name of the representation layer, which must be + findable within `model.named_modules()`. + Returns: - - float: The ratio of dormant neurons in the model. + - float: The effective rank of the representation layer's output. 
""" - # List to store hooks and their handlers - hooks: List[LinearOutputHook] = [] - hook_handlers: List[torch.utils.hooks.RemovableHandle] = [] - total_neurons: int = 0 - dormant_neurons: int = 0 + module_dict = dict(model.named_modules()) + if representation_layer_name not in module_dict: + raise KeyError(f"Representation layer '{representation_layer_name}' not found in model.named_modules().") + representation_module = module_dict[representation_layer_name] - # Register hooks to capture outputs of specific layers - for _, module in model.named_modules(): - if isinstance(module, (nn.Linear, nn.Conv2d, nn.LSTM)): - hook = LinearOutputHook() - hooks.append(hook) - hook_handlers.append(module.register_forward_hook(hook)) + hook = IntermediateOutputHook() + handle = representation_module.register_forward_hook(hook) + model.eval() with torch.no_grad(): - # Forward pass to capture outputs - model(*inputs) - - # Analyze the captured outputs - for module, hook in zip((module for module in model.modules() if isinstance(module, (nn.Linear, nn.Conv2d, nn.LSTM))), hooks): - with torch.no_grad(): - for output_data in hook.outputs: - mean_output = output_data.abs().mean(0) - avg_neuron_output = mean_output.mean() - dormant_indices = (mean_output < avg_neuron_output * percentage).nonzero(as_tuple=True)[0] - - if isinstance(module, nn.Linear): - # Calculate total and dormant neurons for Linear layers - total_neurons += module.weight.shape[0] * output_data.shape[0] - dormant_neurons += len(dormant_indices) - elif isinstance(module, nn.Conv2d): - # Calculate total and dormant neurons for Conv2D layers - total_neurons += module.weight.shape[0] * output_data.shape[0] * output_data.shape[2] * output_data.shape[3] - dormant_neurons += len(dormant_indices) - elif isinstance(module, nn.LSTM): - # Calculate total and dormant neurons for LSTM layers - total_neurons += module.hidden_size * module.num_layers * output_data.shape[0] * output_data.shape[1] - dormant_neurons += len(dormant_indices) - - # Clean up hooks - for hook in hooks: - hook.outputs.clear() - del hook.outputs - - for hook_handler in hook_handlers: - hook_handler.remove() - del hook_handler - - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - return dormant_neurons / total_neurons + if isinstance(inputs, (list, tuple)): + _ = model(*inputs) + else: + _ = model(inputs) -def renormalize(inputs: torch.Tensor, first_dim: int = 1) -> torch.Tensor: + # Always remove the hook to prevent memory leaks. + handle.remove() + + if not hook.outputs: + raise RuntimeError("No outputs were captured from the representation layer.") + + # Concatenate all captured outputs along the batch dimension. + rep_tensor = torch.cat(hook.outputs, dim=0) if len(hook.outputs) > 1 else hook.outputs[0] + + # Reshape the representation to a 2D matrix (samples, features). + rep_tensor = rep_tensor.view(rep_tensor.size(0), -1) + + # Compute singular values using SVD. + singular_values = np.linalg.svd(rep_tensor.cpu().numpy(), full_matrices=False, compute_uv=False) + + # Calculate the effective rank. + e_rank = compute_effective_rank(singular_values) + + hook.clear() + return e_rank + + +def compute_dormant_stats(outputs: List[torch.Tensor], threshold: float) -> Tuple[int, int]: """ Overview: - Normalize the input data using the max-min-normalization. + Computes element-wise statistics for a list of output tensors from a layer. + Arguments: - - inputs (:obj:`torch.Tensor`): The input data needs to be normalized. 
- - first_dim (:obj:`int`): The first dimension of flattening the input data. + - outputs (:obj:`List[torch.Tensor]`): A list of tensors, each representing an output from a forward pass. + - threshold (:obj:`float`): The activation threshold below which a neuron is considered dormant. + Returns: - - output (:obj:`torch.Tensor`): The normalized data. + - Tuple[int, int]: A tuple containing the total number of elements and the number of dormant elements. """ - if first_dim < 0: - first_dim = len(inputs.shape) + first_dim - flat_input = inputs.view(*inputs.shape[:first_dim], -1) - max_val = torch.max(flat_input, first_dim, keepdim=True).values - min_val = torch.min(flat_input, first_dim, keepdim=True).values - flat_input = (flat_input - min_val) / (max_val - min_val) + layer_total = 0 + layer_dormant = 0 + for out in outputs: + flattened = out.view(-1) + layer_total += flattened.numel() + layer_dormant += torch.sum(flattened <= threshold).item() + return layer_total, layer_dormant + + +def calculate_dormant_ratio( + model: nn.Module, + inputs: Union[torch.Tensor, List[torch.Tensor]], + dormant_threshold: float = 1e-2, + target_modules: Tuple[Type[nn.Module], ...] = (nn.Conv2d, nn.Linear), +) -> Dict[str, float]: + """ + Overview: + Calculates the dormant ratio (percentage of neurons with activation below a threshold) for + different parts of a model (e.g., encoder, transformer, head). It assumes the model has + attributes like `encoder`, `transformer`, or `head_dict`. + + Arguments: + - model (:obj:`nn.Module`): The model to evaluate, expected to have `encoder`, `transformer`, or `head_dict` attributes. + - inputs (:obj:`Union[torch.Tensor, List[torch.Tensor]]`): The inputs for the model's forward pass. + - dormant_threshold (:obj:`float`): The activation threshold for defining a dormant neuron. Defaults to 1e-2. + - target_modules (:obj:`Tuple[Type[nn.Module], ...]`): A tuple of module types to attach hooks to. - return flat_input.view(*inputs.shape) + Returns: + - Dict[str, float]: A dictionary containing the dormant ratios for each model part and a global ratio. + """ + parts = {} + if hasattr(model, "encoder"): + parts["encoder"] = model.encoder + if hasattr(model, "transformer"): + parts["transformer"] = model.transformer + if hasattr(model, "head_dict"): + parts["head"] = model.head_dict + # Fallback for models that don't have the standard part attributes. + if not parts: + parts["model"] = model -def get_dynamic_mean(model: nn.Module) -> float: - dynamic_mean = np.abs(model.conv.weight.detach().cpu().numpy().reshape(-1)).tolist() + hooks_dict = {part: [] for part in parts} + hook_handles = [] - for block in model.resblocks: - for name, param in block.named_parameters(): - dynamic_mean += np.abs(param.detach().cpu().numpy().reshape(-1)).tolist() - dynamic_mean = sum(dynamic_mean) / len(dynamic_mean) - return dynamic_mean + # Register a forward hook for each target module in each part. 
+ for part_name, submodule in parts.items(): + for name, module in submodule.named_modules(): + if isinstance(module, target_modules): + hook = IntermediateOutputHook() + full_name = f"{part_name}/{name}" + hooks_dict[part_name].append((full_name, hook)) + handle = module.register_forward_hook(hook) + hook_handles.append(handle) + model.eval() + with torch.no_grad(): + if isinstance(inputs, (list, tuple)): + _ = model(*inputs) + else: + _ = model(inputs) -def get_reward_mean(model: nn.Module) -> Tuple[np.ndarray, float]: - reward_w_dist = model.conv1x1_reward.weight.detach().cpu().numpy().reshape(-1) + results = {} + total_global = 0 + dormant_global = 0 - for name, param in model.fc.named_parameters(): - temp_weights = param.detach().cpu().numpy().reshape(-1) - reward_w_dist = np.concatenate((reward_w_dist, temp_weights)) - reward_mean = np.abs(reward_w_dist).mean() - return reward_w_dist, reward_mean + # Calculate dormant stats from captured outputs. + for part, hooks in hooks_dict.items(): + part_total = 0 + part_dormant = 0 + for full_name, hook in hooks: + layer_total, layer_dormant = compute_dormant_stats(hook.outputs, dormant_threshold) + part_total += layer_total + part_dormant += layer_dormant + + results[part] = (part_dormant / part_total) * 100.0 if part_total > 0 else 0.0 + total_global += part_total + dormant_global += part_dormant + results["global"] = (dormant_global / total_global) * 100.0 if total_global > 0 else 0.0 -def get_params_mean(model: nn.Module) -> Tuple[np.ndarray, float, float, float]: - representation_mean = model.representation_network.get_param_mean() - dynamic_mean = model.dynamics_network.get_dynamic_mean() - reward_w_dist, reward_mean = model.dynamics_network.get_reward_mean() + # Clean up all hooks. + for handle in hook_handles: + handle.remove() + for hooks in hooks_dict.values(): + for _, hook in hooks: + hook.clear() - return reward_w_dist, representation_mean, dynamic_mean, reward_mean + return results -def get_gradients(model: nn.Module) -> List[torch.Tensor]: - grads = [] - for p in model.parameters(): - grad = None if p.grad is None else p.grad.detach() - grads.append(grad) - return grads +def renormalize(inputs: torch.Tensor, first_dim: int = 1) -> torch.Tensor: + """ + Overview: + Normalizes the input tensor using min-max scaling. The normalization is applied + over all dimensions starting from `first_dim`. + + Arguments: + - inputs (:obj:`torch.Tensor`): The input tensor to be normalized. + - first_dim (:obj:`int`): The first dimension from which to flatten the tensor for normalization. + + Returns: + - torch.Tensor: The min-max normalized tensor. + """ + if first_dim < 0: + first_dim = inputs.dim() + first_dim + + shape = inputs.shape + flat_input = inputs.view(*shape[:first_dim], -1) + + max_val, _ = torch.max(flat_input, dim=first_dim, keepdim=True) + min_val, _ = torch.min(flat_input, dim=first_dim, keepdim=True) + + # Add a small epsilon to avoid division by zero. + denominator = max_val - min_val + denominator[denominator < 1e-8] = 1e-8 + + normalized_flat = (flat_input - min_val) / denominator + + return normalized_flat.view(*shape) + + +def get_params_mean(model: nn.Module) -> float: + """ + Overview: + Calculates the mean of the absolute values of all parameters in a model. This is an alias + for `compute_average_weight_magnitude`. + Arguments: + - model (:obj:`nn.Module`): The model to be evaluated. + + Returns: + - float: The mean of the absolute parameter values. 
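# NOTE: illustration only, not part of the patch. A toy call of calculate_dormant_ratio
# above; giving the model an `encoder` attribute lets the function group results per
# part, otherwise everything falls under the 'model' key.
import torch
import torch.nn as nn
from lzero.model.utils import calculate_dormant_ratio

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))

    def forward(self, x):
        return self.encoder(x)

ratios = calculate_dormant_ratio(ToyModel(), torch.randn(32, 8), dormant_threshold=1e-2)
print(ratios)  # e.g. {'encoder': ..., 'global': ...}, percentages of near-zero activations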
+ """ + return compute_average_weight_magnitude(model) + + +def get_gradients(model: nn.Module) -> List[Optional[torch.Tensor]]: + """ + Overview: + Retrieves the gradients of all parameters in a model. + + Arguments: + - model (:obj:`nn.Module`): The model from which to get gradients. + + Returns: + - List[Optional[torch.Tensor]]: A list of gradient tensors. If a parameter has no gradient, + the corresponding list entry is None. + """ + return [p.grad.detach() if p.grad is not None else None for p in model.parameters()] + + +def set_gradients(model: nn.Module, gradients: List[Optional[torch.Tensor]]) -> None: + """ + Overview: + Sets the gradients for all parameters in a model. + + Arguments: + - model (:obj:`nn.Module`): The model whose gradients are to be set. + - gradients (:obj:`List[Optional[torch.Tensor]]`): A list of gradients to assign to the model's parameters. + """ + params = list(model.parameters()) + if len(gradients) != len(params): + raise ValueError(f"Number of gradients ({len(gradients)}) does not match number of model parameters ({len(params)}).") -def set_gradients(model: nn.Module, gradients: List[torch.Tensor]) -> None: - # TODO due to the drawback of zip operation, we have to check whether gradients match model's parameters - for g, p in zip(gradients, model.parameters()): + for g, p in zip(gradients, params): if g is not None: - p.grad = g + # Ensure the gradient is on the same device as the parameter. + p.grad = g.to(p.device) \ No newline at end of file diff --git a/lzero/model/vit.py b/lzero/model/vit.py new file mode 100644 index 000000000..c70282b7e --- /dev/null +++ b/lzero/model/vit.py @@ -0,0 +1,351 @@ +# -*- coding: utf-8 -*- +""" +Optimized Vision Transformer (ViT) Model. + +This script provides an optimized implementation of the Vision Transformer (ViT) architecture. +It includes improvements in code structure, clarity, and adherence to modern Python coding standards, +including comprehensive type hinting and documentation. The implementation also supports +integration with Low-Rank Adaptation (LoRA) through a flexible configuration system. + +""" + +from typing import Optional, Tuple, Type, Union + +import torch +from einops import rearrange, repeat +from einops.layers.torch import Rearrange +from lzero.model.common import SimNorm +from lzero.model.unizero_world_models.transformer import (TransformerConfig, + _maybe_wrap_linear) +from torch import nn + +class ViTConfig: + """ + Overview: + Configuration class for the Vision Transformer (ViT) model. + This class centralizes all hyperparameters, making the model easier to configure and manage. + """ + def __init__(self, **kwargs): + """ + Overview: + Initializes the ViTConfig object. + Arguments: + - **kwargs: Arbitrary keyword arguments to override default settings. 
+ """ + # Image and Patch Dimensions + self.image_size: Union[int, Tuple[int, int]] = 64 + self.patch_size: Union[int, Tuple[int, int]] = 8 + self.channels: int = 3 + + # Model Architecture + self.num_classes: int = 768 + self.dim: int = 768 + self.depth: int = 12 + self.heads: int = 12 + self.mlp_dim: int = 3072 + self.dim_head: int = 64 + + # Pooling and Normalization + self.pool: str = 'cls' # 'cls' or 'mean' + self.final_norm_option_in_encoder: str = 'LayerNorm' # 'LayerNorm' or 'SimNorm' + + # Dropout Rates + self.dropout: float = 0.1 + self.emb_dropout: float = 0.1 + + # LoRA Configuration + self.lora_config: Optional[TransformerConfig] = None + + # Update attributes with any provided keyword arguments + for key, value in kwargs.items(): + if hasattr(self, key): + setattr(self, key, value) + else: + print(f"Warning: Ignoring unknown config parameter '{key}'") + + +# ==================== Helper Functions ==================== + +def pair(t: Union[int, Tuple[int, int]]) -> Tuple[int, int]: + """ + Overview: + Converts an integer to a tuple of two identical integers. If the input is already a tuple, it is returned as is. + This is useful for handling kernel sizes, strides, etc., which can be specified as a single number or a tuple. + Arguments: + - t (:obj:`Union[int, Tuple[int, int]]`): The input value. + Returns: + - (:obj:`Tuple[int, int]`): A tuple of two integers. + """ + return t if isinstance(t, tuple) else (t, t) + + +# ==================== Core Modules ==================== + +class FeedForward(nn.Module): + """ + Overview: + A standard feed-forward network block used in Transformer architectures. + It consists of two linear layers with a GELU activation in between. + """ + def __init__( + self, + dim: int, + hidden_dim: int, + dropout: float = 0.0, + config: Optional[TransformerConfig] = None + ): + """ + Overview: + Initializes the FeedForward module. + Arguments: + - dim (:obj:`int`): The input and output dimension. + - hidden_dim (:obj:`int`): The dimension of the hidden layer. + - dropout (:obj:`float`): The dropout rate. + - config (:obj:`Optional[TransformerConfig]`): Configuration for LoRA wrapping. + """ + super().__init__() + self.net = nn.Sequential( + nn.LayerNorm(dim), + _maybe_wrap_linear(nn.Linear(dim, hidden_dim), config, "feed_forward"), + nn.GELU(), + nn.Dropout(dropout), + _maybe_wrap_linear(nn.Linear(hidden_dim, dim), config, "feed_forward"), + nn.Dropout(dropout) + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Overview: + Forward pass for the FeedForward block. + Arguments: + - x (:obj:`torch.Tensor`): The input tensor of shape (batch_size, num_tokens, dim). + Returns: + - (:obj:`torch.Tensor`): The output tensor of the same shape as input. + """ + return self.net(x) + + +class Attention(nn.Module): + """ + Overview: + Multi-Head Self-Attention (MHSA) module. + It computes scaled dot-product attention across multiple heads. + """ + def __init__( + self, + dim: int, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + config: Optional[TransformerConfig] = None + ): + """ + Overview: + Initializes the Attention module. + Arguments: + - dim (:obj:`int`): The input and output dimension. + - heads (:obj:`int`): The number of attention heads. + - dim_head (:obj:`int`): The dimension of each attention head. + - dropout (:obj:`float`): The dropout rate for attention weights and output. + - config (:obj:`Optional[TransformerConfig]`): Configuration for LoRA wrapping. 
+ """ + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.norm = nn.LayerNorm(dim) + self.attend = nn.Softmax(dim=-1) + self.dropout = nn.Dropout(dropout) + + # Linear layer to project input to Q, K, V. Potentially wrapped for LoRA. + self.to_qkv = _maybe_wrap_linear(nn.Linear(dim, inner_dim * 3, bias=False), config, "attn") + + # Output projection layer. + if project_out: + # Wrap the linear layer inside the sequential module for LoRA. + wrapped_linear = _maybe_wrap_linear(nn.Linear(inner_dim, dim), config, "attn") + self.to_out = nn.Sequential( + wrapped_linear, + nn.Dropout(dropout) + ) + else: + self.to_out = nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Overview: + Forward pass for the Attention module. + Arguments: + - x (:obj:`torch.Tensor`): Input tensor of shape (batch_size, num_tokens, dim). + Returns: + - (:obj:`torch.Tensor`): Output tensor of the same shape as input. + """ + x = self.norm(x) + + # Project to Q, K, V and split. + qkv = self.to_qkv(x).chunk(3, dim=-1) + # Rearrange for multi-head attention: b n (h d) -> b h n d + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv) + + # Scaled dot-product attention. + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + attn = self.attend(dots) + attn = self.dropout(attn) + + # Apply attention to values. + out = torch.matmul(attn, v) + # Rearrange back to original shape: b h n d -> b n (h d) + out = rearrange(out, 'b h n d -> b n (h d)') + + return self.to_out(out) + + +class Transformer(nn.Module): + """ + Overview: + A stack of Transformer blocks, each containing a multi-head self-attention + layer and a feed-forward network. + """ + def __init__( + self, + dim: int, + depth: int, + heads: int, + dim_head: int, + mlp_dim: int, + dropout: float = 0.0, + config: Optional[TransformerConfig] = None + ): + """ + Overview: + Initializes the Transformer module. + Arguments: + - dim (:obj:`int`): The dimension of the token embeddings. + - depth (:obj:`int`): The number of Transformer blocks. + - heads (:obj:`int`): The number of attention heads. + - dim_head (:obj:`int`): The dimension of each attention head. + - mlp_dim (:obj:`int`): The hidden dimension of the feed-forward network. + - dropout (:obj:`float`): The dropout rate. + - config (:obj:`Optional[TransformerConfig]`): Configuration for LoRA. + """ + super().__init__() + self.norm = nn.LayerNorm(dim) + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout, config=config), + FeedForward(dim, mlp_dim, dropout=dropout, config=config) + ])) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Overview: + Forward pass for the Transformer stack. + Arguments: + - x (:obj:`torch.Tensor`): Input tensor of shape (batch_size, num_tokens, dim). + Returns: + - (:obj:`torch.Tensor`): Output tensor of the same shape. + """ + for attn, ff in self.layers: + x = attn(x) + x # Apply attention and residual connection + x = ff(x) + x # Apply feed-forward and residual connection + return self.norm(x) + + +class ViT(nn.Module): + """ + Overview: + Vision Transformer (ViT) model. This model applies the Transformer architecture + to sequences of image patches for image classification tasks. 
+ """ + def __init__(self, config: ViTConfig): + """ + Overview: + Initializes the ViT model using a configuration object. + Arguments: + - config (:obj:`ViTConfig`): A configuration object containing all model hyperparameters. + """ + super().__init__() + self.config = config + + image_height, image_width = pair(config.image_size) + patch_height, patch_width = pair(config.patch_size) + + assert image_height % patch_height == 0 and image_width % patch_width == 0, \ + 'Image dimensions must be divisible by the patch size.' + + num_patches = (image_height // patch_height) * (image_width // patch_width) + patch_dim = config.channels * patch_height * patch_width + assert config.pool in {'cls', 'mean'}, 'pool type must be either "cls" or "mean"' + + # Patch embedding layer + self.to_patch_embedding = nn.Sequential( + Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width), + nn.LayerNorm(patch_dim), + nn.Linear(patch_dim, config.dim), + nn.LayerNorm(config.dim), + ) + + # Positional embedding and CLS token + self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, config.dim)) + self.cls_token = nn.Parameter(torch.randn(1, 1, config.dim)) + self.dropout = nn.Dropout(config.emb_dropout) + + # Transformer encoder stack + self.transformer = Transformer( + dim=config.dim, + depth=config.depth, + heads=config.heads, + dim_head=config.dim_head, + mlp_dim=config.mlp_dim, + dropout=config.dropout, + config=config.lora_config + ) + + self.pool = config.pool + self.last_linear = nn.Linear(config.dim, config.num_classes) + + # Final normalization layer + if config.final_norm_option_in_encoder == 'LayerNorm': + self.final_norm = nn.LayerNorm(config.num_classes, eps=1e-5) + elif config.final_norm_option_in_encoder == 'SimNorm': + group_size = 8 # As specified in original code + self.final_norm = SimNorm(simnorm_dim=group_size) + else: + raise ValueError(f"Unsupported final_norm_option_in_encoder: {config.final_norm_option_in_encoder}") + + def forward(self, img: torch.Tensor) -> torch.Tensor: + """ + Overview: + Forward pass for the ViT model. + Arguments: + - img (:obj:`torch.Tensor`): Input image tensor of shape (batch_size, channels, height, width). + Returns: + - (:obj:`torch.Tensor`): Output logits tensor of shape (batch_size, num_classes). + """ + # 1. Patch embedding + x = self.to_patch_embedding(img) + b, n, _ = x.shape + + # 2. Prepend CLS token + cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b=b) + x = torch.cat((cls_tokens, x), dim=1) + + # 3. Add positional embedding + x += self.pos_embedding[:, :(n + 1)] + x = self.dropout(x) + + # 4. Pass through Transformer encoder + x = self.transformer(x) + + # 5. Pooling + x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0] + + # 6. Final classification head + x = self.last_linear(x) + x = self.final_norm(x) + + return x \ No newline at end of file diff --git a/lzero/policy/head_clip_manager.py b/lzero/policy/head_clip_manager.py new file mode 100644 index 000000000..647286013 --- /dev/null +++ b/lzero/policy/head_clip_manager.py @@ -0,0 +1,471 @@ +""" +Head Clip Manager - Dynamic Head Clipping implementation consistent with Encoder-Clip principles + +This module provides dynamic Head Clipping functionality similar to Encoder-Clip: +1. Monitor the range of head outputs (logits) +2. Scale all weights of the entire head module when exceeding the threshold +3. Support annealing (threshold gradually becomes stricter from loose) +4. 
Support independent configuration for multiple heads + +Differences from previous Head Weight Scaling: +- Before: Static scaling once during initialization +- Now: Dynamic monitoring and scaling during training (consistent with Encoder-Clip) + +""" + +import logging +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional + +import numpy as np +import torch +import torch.nn as nn +from torch.nn.utils.convert_parameters import (parameters_to_vector, + vector_to_parameters) + +logging.getLogger().setLevel(logging.INFO) + + +@dataclass +class HeadClipConfig: + """Clip configuration for a single Head""" + # Fixed threshold (if annealing is not enabled) + clip_threshold: float = 15.0 + + # Whether to enable annealing + use_annealing: bool = True + + # Annealing configuration + anneal_type: str = 'cosine' # 'cosine' or 'linear' + start_value: float = 30.0 # Loose in the early phase + end_value: float = 10.0 # Strict in the later phase + anneal_steps: int = 500000 + + def __post_init__(self): + """Validate configuration""" + if self.clip_threshold <= 0: + raise ValueError(f"clip_threshold must be positive, got {self.clip_threshold}") + if self.use_annealing: + if self.start_value <= 0 or self.end_value <= 0: + raise ValueError("start_value and end_value must be positive") + if self.anneal_steps <= 0: + raise ValueError("anneal_steps must be positive") + if self.anneal_type not in ['cosine', 'linear']: + raise ValueError(f"anneal_type must be 'cosine' or 'linear', got {self.anneal_type}") + + +class HeadClipManager: + """ + Head Clip Manager - Dynamically monitor and clip Head outputs + + Working principle (consistent with Encoder-Clip): + 1. After each training iteration, monitor the outputs (logits) of each head + 2. Calculate max(|logits|) + 3. If exceeding the current threshold, scale the weights of the entire head module + 4. 
Threshold supports annealing (from loose to strict) + """ + + def __init__( + self, + enabled: bool = True, + enabled_heads: Optional[List[str]] = None, + head_configs: Optional[Dict[str, HeadClipConfig]] = None, + monitor_freq: int = 1, + log_freq: int = 1000, + ): + """ + Initialize Head Clip Manager + + Args: + enabled (bool): Whether to enable Head Clip + enabled_heads (List[str], optional): List of heads that need clipping + Example: ['policy', 'value', 'rewards'] + If None, no head will be enabled + head_configs (Dict[str, HeadClipConfig], optional): Configuration for each head + Example: {'policy': HeadClipConfig(...), 'value': HeadClipConfig(...)} + If a head is not in this dictionary, use default configuration + monitor_freq (int): Monitoring frequency (check every N iterations) + log_freq (int): Log printing frequency + """ + self.enabled = enabled + self.enabled_heads = enabled_heads or [] + self.head_configs = head_configs or {} + self.monitor_freq = monitor_freq + self.log_freq = log_freq + + # Statistical information + self.scaling_history = {head: [] for head in self.enabled_heads} + self.iteration_count = 0 + + # Log mapping + self.logits_key_mapping = { + 'policy': 'logits_policy', + 'value': 'logits_value', + 'reward': 'logits_reward', + 'rewards': 'logits_reward', # Compatible with both naming conventions + 'observations': 'logits_observations', + } + + self.head_module_mapping = { + 'policy': 'head_policy', + 'value': 'head_value', + 'reward': 'head_rewards', + 'rewards': 'head_rewards', + 'observations': 'head_observations', + } + + if self.enabled and self.enabled_heads: + logging.info("=" * 60) + logging.info(">>> Head Clip Manager Enabled <<<") + logging.info(f" Enabled heads: {self.enabled_heads}") + logging.info(f" Monitor freq: {self.monitor_freq}") + logging.info(f" Log freq: {self.log_freq}") + for head_name in self.enabled_heads: + config = self.get_head_config(head_name) + if config.use_annealing: + logging.info( + f" {head_name}: annealing {config.start_value:.1f} → {config.end_value:.1f} " + f"over {config.anneal_steps} steps ({config.anneal_type})" + ) + else: + logging.info(f" {head_name}: fixed threshold = {config.clip_threshold:.1f}") + logging.info("=" * 60) + + def get_head_config(self, head_name: str) -> HeadClipConfig: + """ + Get the configuration for the specified head. 
If it doesn't exist, return the default configuration + + Args: + head_name (str): Name of the head + + Returns: + HeadClipConfig: Configuration object + """ + if head_name in self.head_configs: + return self.head_configs[head_name] + else: + # Return default configuration + return HeadClipConfig() + + def compute_current_threshold( + self, + head_name: str, + train_iter: int + ) -> float: + """ + Compute the threshold for the current training step (considering annealing) + + Args: + head_name (str): Name of the head + train_iter (int): Current training iteration count + + Returns: + float: Current threshold + """ + config = self.get_head_config(head_name) + + if not config.use_annealing: + return config.clip_threshold + + # Calculate annealing progress + progress = min(1.0, train_iter / config.anneal_steps) + + if config.anneal_type == 'cosine': + # Cosine schedule: smooth transition from 1 to 0 + cosine_progress = 0.5 * (1.0 + np.cos(np.pi * progress)) + current_value = config.end_value + \ + (config.start_value - config.end_value) * cosine_progress + else: # 'linear' + current_value = config.start_value * (1 - progress) + \ + config.end_value * progress + + return current_value + + def apply_head_clip( + self, + world_model: nn.Module, + losses: Any, # LossWithIntermediateLosses + train_iter: int + ) -> Dict[str, Dict]: + """ + Apply Head Clip (main function) + + Workflow: + 1. Iterate through all enabled heads + 2. Get the output (logits) of each head + 3. Calculate max(|logits|) + 4. If exceeding the current threshold, scale the entire head module + + Args: + world_model (nn.Module): WorldModel instance + losses (LossWithIntermediateLosses): Loss object containing intermediate outputs + train_iter (int): Current training iteration count + + Returns: + Dict[str, Dict]: Scaling information for each head + Example: { + 'policy': { + 'max_logits': 25.5, + 'threshold': 15.0, + 'scale_factor': 0.588, + 'scaled': True + } + } + """ + if not self.enabled: + return {} + + # Only check at specified frequency + if train_iter % self.monitor_freq != 0: + return {} + + self.iteration_count = train_iter + results = {} + + for head_name in self.enabled_heads: + # 1. Get logits + logits = self._get_head_logits(losses, head_name) + if logits is None: + continue + + # 2. Calculate current threshold + current_threshold = self.compute_current_threshold(head_name, train_iter) + + # 3. Calculate maximum absolute value of logits + max_logits = logits.abs().max().item() + + # 4. Determine if scaling is needed + scaled = False + scale_factor = 1.0 + + if max_logits > current_threshold: + scale_factor = current_threshold / max_logits + + # Get head module + head_module = self._get_head_module(world_model, head_name) + if head_module is not None: + # Scale all weights of the entire head module + success = self._scale_module_weights(head_module, scale_factor) + scaled = success + + if success: + # Record history + self.scaling_history[head_name].append({ + 'iteration': train_iter, + 'max_logits': max_logits, + 'threshold': current_threshold, + 'scale_factor': scale_factor, + }) + + # 5. Record results + results[head_name] = { + 'max_logits': max_logits, + 'threshold': current_threshold, + 'scale_factor': scale_factor, + 'scaled': scaled, + } + + # 6. 
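# NOTE: illustration only, not part of the patch. The cosine-annealed clip threshold
# computed by compute_current_threshold above: it moves from start_value (loose) to
# end_value (strict) over anneal_steps.
import numpy as np

def cosine_annealed_threshold(train_iter, start_value=30.0, end_value=10.0, anneal_steps=500000):
    progress = min(1.0, train_iter / anneal_steps)
    cosine_progress = 0.5 * (1.0 + np.cos(np.pi * progress))
    return end_value + (start_value - end_value) * cosine_progress

for it in (0, 250000, 500000):
    print(it, round(cosine_annealed_threshold(it), 2))  # 30.0 -> 20.0 -> 10.0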
Print log + if scaled and train_iter % self.log_freq == 0: + logging.info( + f"[Head-Clip] Iter {train_iter}: {head_name} head - " + f"max_logits={max_logits:.2f} > threshold={current_threshold:.2f}, " + f"scaling by {scale_factor:.4f}" + ) + + return results + + def _get_head_logits( + self, + losses: Any, + head_name: str + ) -> Optional[torch.Tensor]: + """ + Get the logits of the specified head from the losses object + + Args: + losses (LossWithIntermediateLosses): Loss object + head_name (str): Name of the head + + Returns: + Optional[torch.Tensor]: Logits tensor, returns None if not found + """ + if not hasattr(losses, 'intermediate_losses'): + return None + + logits_key = self.logits_key_mapping.get(head_name) + if logits_key is None: + return None + + return losses.intermediate_losses.get(logits_key) + + def _get_head_module( + self, + world_model: nn.Module, + head_name: str + ) -> Optional[nn.Module]: + """ + Get the module of the specified head + + Args: + world_model (nn.Module): WorldModel instance + head_name (str): Name of the head + + Returns: + Optional[nn.Module]: Head module, returns None if not found + """ + module_name = self.head_module_mapping.get(head_name) + if module_name is None: + return None + + if hasattr(world_model, module_name): + return getattr(world_model, module_name) + else: + return None + + def _scale_module_weights( + self, + module: nn.Module, + scale_factor: float + ) -> bool: + """ + Scale all weights of the module (consistent with scale_module_weights_vectorized) + + Args: + module (nn.Module): Module to be scaled + scale_factor (float): Scaling factor + + Returns: + bool: Whether the operation was successful + """ + if not (0.0 < scale_factor < 1.0): + return False + + try: + # 1. Flatten all parameters of the module into a single vector + params_vec = parameters_to_vector(module.parameters()) + + # 2. Perform multiplication operation on this vector + params_vec.data.mul_(scale_factor) + + # 3. 
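# NOTE: illustration only, not part of the patch. Uniformly scaling every weight of a
# head module with parameters_to_vector / vector_to_parameters, as _scale_module_weights
# does once max(|logits|) exceeds the current threshold.
import torch.nn as nn
from torch.nn.utils.convert_parameters import parameters_to_vector, vector_to_parameters

head = nn.Linear(4, 4)
max_logits, threshold = 25.0, 10.0
scale_factor = threshold / max_logits            # 0.4
vec = parameters_to_vector(head.parameters())
vec.data.mul_(scale_factor)
vector_to_parameters(vec, head.parameters())     # all head weights shrunk by a factor of 0.4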
Copy the scaled vector back to the individual parameters of the module + vector_to_parameters(params_vec, module.parameters()) + + return True + except Exception as e: + logging.error(f"Error scaling module weights: {e}") + return False + + def get_statistics(self) -> Dict: + """ + Get statistical information + + Returns: + Dict: Statistical information + """ + stats = { + 'enabled': self.enabled, + 'total_iterations': self.iteration_count, + 'scaling_history': {}, + } + + for head_name in self.enabled_heads: + history = self.scaling_history.get(head_name, []) + if history: + stats['scaling_history'][head_name] = { + 'total_scalings': len(history), + 'last_scaling': history[-1], + 'average_scale_factor': sum(h['scale_factor'] for h in history) / len(history), + 'total_cumulative_scaling': np.prod([h['scale_factor'] for h in history]), + } + + return stats + + +def create_head_clip_manager_from_dict(config_dict: Dict) -> HeadClipManager: + """ + Create HeadClipManager from a configuration dictionary + + Args: + config_dict (Dict): Configuration dictionary + + Returns: + HeadClipManager: Manager instance + + Example: + config_dict = { + 'enabled': True, + 'enabled_heads': ['policy', 'value'], + 'head_configs': { + 'policy': { + 'use_annealing': True, + 'start_value': 30.0, + 'end_value': 10.0, + 'anneal_steps': 500000, + 'anneal_type': 'cosine', + }, + 'value': { + 'clip_threshold': 20.0, + 'use_annealing': False, + }, + }, + 'monitor_freq': 1, + 'log_freq': 1000, + } + """ + enabled = config_dict.get('enabled', True) + enabled_heads = config_dict.get('enabled_heads', []) + monitor_freq = config_dict.get('monitor_freq', 1) + log_freq = config_dict.get('log_freq', 1000) + + # Parse head_configs + head_configs = {} + head_configs_dict = config_dict.get('head_configs', {}) + for head_name, head_config_dict in head_configs_dict.items(): + head_configs[head_name] = HeadClipConfig(**head_config_dict) + + return HeadClipManager( + enabled=enabled, + enabled_heads=enabled_heads, + head_configs=head_configs, + monitor_freq=monitor_freq, + log_freq=log_freq, + ) + + +if __name__ == "__main__": + # Usage example + print("=" * 60) + print("Head Clip Manager Usage Example") + print("=" * 60) + + # Example 1: Basic configuration + print("\nExample 1: Basic configuration") + config_dict = { + 'enabled': True, + 'enabled_heads': ['policy'], + 'head_configs': { + 'policy': { + 'use_annealing': True, + 'start_value': 30.0, + 'end_value': 10.0, + 'anneal_steps': 500000, + 'anneal_type': 'cosine', + }, + }, + 'monitor_freq': 1, + 'log_freq': 1000, + } + + manager = create_head_clip_manager_from_dict(config_dict) + print(f"Manager created successfully, enabled heads: {manager.enabled_heads}") + + # Example 2: Compute current threshold + print("\nExample 2: Compute current threshold") + for iter in [0, 100000, 250000, 500000]: + threshold = manager.compute_current_threshold('policy', iter) + print(f" Iter {iter}: threshold = {threshold:.2f}") + + print("\n" + "=" * 60) + print("All examples ran successfully!") + print("=" * 60) diff --git a/lzero/policy/muzero.py b/lzero/policy/muzero.py index 7bd2e8d2b..bf13543f8 100644 --- a/lzero/policy/muzero.py +++ b/lzero/policy/muzero.py @@ -11,11 +11,10 @@ from ding.utils import POLICY_REGISTRY from torch.nn import L1Loss -from lzero.entry.utils import initialize_zeros_batch from lzero.mcts import MuZeroMCTSCtree as MCTSCtree from lzero.mcts import MuZeroMCTSPtree as MCTSPtree from lzero.model import ImageTransforms -from lzero.model.utils import cal_dormant_ratio +from 
lzero.model.utils import calculate_dormant_ratio from lzero.policy import scalar_transform, InverseScalarTransform, cross_entropy_loss, phi_transform, \ DiscreteSupport, to_torch_float_tensor, mz_network_output_unpack, select_action, negative_cosine_similarity, \ prepare_obs, configure_optimizers @@ -71,7 +70,7 @@ class MuZeroPolicy(Policy): norm_type='BN', # (bool) Whether to analyze simulation normalization. analysis_sim_norm=False, - # (bool) Whether to analyze dormant ratio. + # (bool) Whether to analyze dormant ratio. More details can be found in https://proceedings.mlr.press/v202/sokar23a/sokar23a.pdf. analysis_dormant_ratio=False, # (bool) Whether to use HarmonyDream to balance weights between different losses. Default to False. # More details can be found in https://arxiv.org/abs/2310.00344. @@ -113,7 +112,7 @@ class MuZeroPolicy(Policy): # This is done by setting the parameter learn.learner.hook.save_ckpt_after_iter to the same value as eval_freq in the train_muzero.py automatically. eval_offline=False, # (bool) Whether to calculate the dormant ratio. - cal_dormant_ratio=False, + calculate_dormant_ratio=False, # (bool) Whether to analyze simulation normalization. analysis_sim_norm=False, # (bool) Whether to analyze dormant ratio. @@ -423,8 +422,8 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in # ========= logging for analysis ========= # calculate dormant ratio of encoder - if self._cfg.cal_dormant_ratio: - self.dormant_ratio_encoder = cal_dormant_ratio(self._learn_model.representation_network, obs_batch.detach(), + if self._cfg.calculate_dormant_ratio: + self.dormant_ratio_encoder = calculate_dormant_ratio(self._learn_model.representation_network, obs_batch.detach(), percentage=self._cfg.dormant_threshold) # calculate L2 norm of latent state latent_state_l2_norms = torch.norm(latent_state.view(latent_state.shape[0], -1), p=2, dim=1).mean() @@ -470,7 +469,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in latent_state, reward, value, policy_logits = mz_network_output_unpack(network_output) # ========= logging for analysis =============== - if step_k == self._cfg.num_unroll_steps - 1 and self._cfg.cal_dormant_ratio: + if step_k == self._cfg.num_unroll_steps - 1 and self._cfg.calculate_dormant_ratio: # calculate dormant ratio of encoder action_tmp = action_batch[:, step_k] if len(action_tmp.shape) == 1: @@ -486,7 +485,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in latent_state.shape[0], policy_logits.shape[-1], latent_state.shape[2], latent_state.shape[3] ) state_action_encoding = torch.cat((latent_state, action_encoding), dim=1) - self.dormant_ratio_dynamics = cal_dormant_ratio(self._learn_model.dynamics_network, + self.dormant_ratio_dynamics = calculate_dormant_ratio(self._learn_model.dynamics_network, state_action_encoding.detach(), percentage=self._cfg.dormant_threshold) # ========= logging for analysis =============== @@ -941,12 +940,13 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: List = [ return output - def _reset_collect(self, data_id: Optional[List[int]] = None) -> None: + def _reset_collect(self, data_id: Optional[List[int]] = None, task_id: int = None) -> None: """ Overview: Reset the observation and action for the collector environment. Arguments: - data_id (`Optional[List[int]]`): List of data ids to reset (not used in this implementation). + - task_id (:obj:`int`): The task id. 
Default is None, which means MuZero is in the single-task mode. """ if self._cfg.model.model_type in ["conv_context"]: self.last_batch_obs = initialize_zeros_batch( @@ -956,12 +956,13 @@ def _reset_collect(self, data_id: Optional[List[int]] = None) -> None: ) self.last_batch_action = [-1 for _ in range(self._cfg.collector_env_num)] - def _reset_eval(self, data_id: Optional[List[int]] = None) -> None: + def _reset_eval(self, data_id: Optional[List[int]] = None, task_id: int = None) -> None: """ Overview: Reset the observation and action for the evaluator environment. Arguments: - data_id (:obj:`Optional[List[int]]`): List of data ids to reset (not used in this implementation). + - task_id (:obj:`int`): The task id. Default is None, which means MuZero is in the single-task mode. """ if self._cfg.model.model_type in ["conv_context"]: self.last_batch_obs = initialize_zeros_batch( @@ -970,6 +971,7 @@ def _reset_eval(self, data_id: Optional[List[int]] = None) -> None: self._cfg.device ) self.last_batch_action = [-1 for _ in range(self._cfg.evaluator_env_num)] + def _monitor_vars_learn(self) -> List[str]: """ Overview: diff --git a/lzero/policy/muzero_multitask.py b/lzero/policy/muzero_multitask.py new file mode 100644 index 000000000..91614769e --- /dev/null +++ b/lzero/policy/muzero_multitask.py @@ -0,0 +1,895 @@ +import copy +from typing import List, Dict, Tuple, Union, Optional + +import numpy as np +import torch +import torch.optim as optim +from ding.model import model_wrap +from ding.torch_utils import to_tensor +from ding.utils import POLICY_REGISTRY + +from lzero.mcts import MuZeroMCTSCtree as MCTSCtree +from lzero.model import ImageTransforms +from lzero.model.utils import cal_dormant_ratio +from lzero.policy import ( + scalar_transform, + InverseScalarTransform, + cross_entropy_loss, + phi_transform, + DiscreteSupport, + to_torch_float_tensor, + mz_network_output_unpack, + select_action, + negative_cosine_similarity, + prepare_obs, +) +from lzero.policy.muzero import MuZeroPolicy + + +def generate_task_loss_dict(multi_task_losses: List[float], task_name_template: str, task_id: int) -> Dict[str, float]: + """ + Overview: + Generates a dictionary for the losses of each task. + Arguments: + - multi_task_losses (:obj:`List[float]`): A list containing the loss for each task. + - task_name_template (:obj:`str`): A template for the task name, e.g., 'loss_task{}'. + - task_id (:obj:`int`): The starting global task ID for the current rank. Used to offset task indices when generating task names. + Returns: + - task_loss_dict (:obj:`Dict[str, float]`): A dictionary containing the loss for each task. + """ + task_loss_dict = {} + for task_idx, task_loss in enumerate(multi_task_losses): + task_name = task_name_template.format(task_idx + task_id) + try: + # Ensure the loss is a scalar value for logging. + task_loss_dict[task_name] = task_loss.item() if hasattr(task_loss, 'item') else task_loss + except Exception: + task_loss_dict[task_name] = task_loss + return task_loss_dict + +class WrappedModelV2: + """ + Overview: + A wrapper class to bundle different parts of a model (tokenizer, transformer, embeddings) + for easier management of parameters and gradients. 
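# NOTE: illustration only, not part of the patch. generate_task_loss_dict above maps a
# rank-local list of losses to globally indexed log keys, offset by the rank's starting
# task_id; 'loss_task{}' is the template from the docstring example.
from lzero.policy.muzero_multitask import generate_task_loss_dict

print(generate_task_loss_dict([0.5, 0.7, 0.9], 'loss_task{}', task_id=4))
# {'loss_task4': 0.5, 'loss_task5': 0.7, 'loss_task6': 0.9}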
+ """ + def __init__(self, tokenizer, transformer, pos_emb, task_emb, act_embedding_table): + self.tokenizer = tokenizer + self.transformer = transformer + self.pos_emb = pos_emb + self.task_emb = task_emb + self.act_embedding_table = act_embedding_table + + def parameters(self) -> List[torch.nn.Parameter]: + """ + Overview: + Returns a list of all parameters from the tokenizer, transformer, and all embedding layers. + """ + return ( + list(self.tokenizer.parameters()) + + list(self.transformer.parameters()) + + list(self.pos_emb.parameters()) + + list(self.task_emb.parameters()) + + list(self.act_embedding_table.parameters()) + ) + + def zero_grad(self, set_to_none: bool = False) -> None: + """ + Overview: + Sets the gradients of all parameters in the tokenizer, transformer, and embedding layers to zero. + Arguments: + - set_to_none (:obj:`bool`): Whether to set gradients to None instead of zero. + """ + self.tokenizer.zero_grad(set_to_none=set_to_none) + self.transformer.zero_grad(set_to_none=set_to_none) + self.pos_emb.zero_grad(set_to_none=set_to_none) + self.task_emb.zero_grad(set_to_none=set_to_none) + self.act_embedding_table.zero_grad(set_to_none=set_to_none) + +@POLICY_REGISTRY.register('muzero_multitask') +class MuZeroMTPolicy(MuZeroPolicy): + """ + Overview: + The multi-task policy for MuZero, extending MuZeroPolicy. It supports training multiple tasks + simultaneously by separating the loss for each task and optimizing them jointly. + """ + + # Default configuration for MuZeroMTPolicy. + config = dict( + type='muzero_multitask', + model=dict( + model_type='conv', # options={'mlp', 'conv'} + continuous_action_space=False, + observation_shape=(4, 96, 96), # example shape + self_supervised_learning_loss=False, + categorical_distribution=True, + image_channel=1, + frame_stack_num=1, + num_res_blocks=1, + num_channels=64, + support_scale=300, + bias=True, + discrete_action_encoding_type='one_hot', + res_connection_in_dynamics=True, + norm_type='BN', + analysis_sim_norm=False, + analysis_dormant_ratio=False, + harmony_balance=False, + ), + # ****** common ****** + use_rnd_model=False, + multi_gpu=False, + sampled_algo=False, + gumbel_algo=False, + mcts_ctree=True, + cuda=True, + collector_env_num=8, + evaluator_env_num=3, + env_type='not_board_games', + action_type='fixed_action_space', + battle_mode='play_with_bot_mode', + monitor_extra_statistics=True, + game_segment_length=200, + eval_offline=False, + cal_dormant_ratio=False, + analysis_sim_norm=False, + analysis_dormant_ratio=False, + + # ****** observation ****** + transform2string=False, + gray_scale=False, + use_augmentation=False, + augmentation=['shift', 'intensity'], + + # ******* learn ****** + use_wandb=False, + ignore_done=False, + update_per_collect=None, + replay_ratio=0.25, + batch_size=256, + optim_type='SGD', + learning_rate=0.2, + target_update_freq=100, + target_update_freq_for_intrinsic_reward=1000, + weight_decay=1e-4, + momentum=0.9, + grad_clip_value=10, + n_episode=8, + num_segments=8, + num_simulations=50, + discount_factor=0.997, + td_steps=5, + num_unroll_steps=5, + reward_loss_weight=1, + value_loss_weight=0.25, + policy_loss_weight=1, + policy_entropy_weight=0, + ssl_loss_weight=0, + lr_piecewise_constant_decay=True, + threshold_training_steps_for_final_lr=int(5e4), + manual_temperature_decay=False, + threshold_training_steps_for_final_temperature=int(1e5), + fixed_temperature_value=0.25, + use_ture_chance_label_in_chance_encoder=False, + + # ****** Priority ****** + use_priority=False, + 
priority_prob_alpha=0.6, + priority_prob_beta=0.4, + + # ****** UCB ****** + root_dirichlet_alpha=0.3, + root_noise_weight=0.25, + + # ****** Explore by random collect ****** + random_collect_episode_num=0, + + # ****** Explore by eps greedy ****** + eps=dict( + eps_greedy_exploration_in_collect=False, + type='linear', + start=1., + end=0.05, + decay=int(1e5), + ), + + # ****** Multi-task related ****** + task_num=2, # Number of tasks, adjust as needed. + task_id=0, # The starting ID of the current task. + ) + + def default_model(self) -> Tuple[str, List[str]]: + """ + Overview: + Returns the default model configuration for this algorithm. + Returns: + - model_info (:obj:`Tuple[str, List[str]]`): A tuple containing the model name and a list of import paths. + """ + return 'MuZeroMTModel', ['lzero.model.muzero_model_multitask'] + + def _init_learn(self) -> None: + """ + Overview: + Initializes the learning mode. This method sets up the learning model, optimizer, and MCTS utilities. + """ + super()._init_learn() + + assert self._cfg.optim_type in ['SGD', 'Adam', 'AdamW'], self._cfg.optim_type + # NOTE: In board games, for a fixed learning rate of 0.003, 'Adam' performs better than 'SGD'. + if self._cfg.optim_type == 'SGD': + self._optimizer = optim.SGD( + self._model.parameters(), + lr=self._cfg.learning_rate, + momentum=self._cfg.momentum, + weight_decay=self._cfg.weight_decay, + ) + elif self._cfg.optim_type == 'Adam': + self._optimizer = optim.Adam( + self._model.parameters(), lr=self._cfg.learning_rate, weight_decay=self._cfg.weight_decay + ) + elif self._cfg.optim_type == 'AdamW': + self._optimizer = configure_optimizers(model=self._model, weight_decay=self._cfg.weight_decay, + learning_rate=self._cfg.learning_rate, device_type=self._cfg.device) + + # Learning rate scheduler + if self._cfg.lr_piecewise_constant_decay: + from torch.optim.lr_scheduler import LambdaLR + max_step = self._cfg.threshold_training_steps_for_final_lr + # NOTE: 1, 0.1, 0.01 are decay rates, not the learning rate itself. + lr_lambda = lambda step: 1 if step < max_step * 0.5 else (0.1 if step < max_step else 0.01) # noqa + self.lr_scheduler = LambdaLR(self._optimizer, lr_lambda=lr_lambda) + + # Use model_wrapper for specialized demands of different modes. + self._target_model = copy.deepcopy(self._model) + self._target_model = model_wrap( + self._target_model, + wrapper_name='target', + update_type='assign', + update_kwargs={'freq': self._cfg.target_update_freq} + ) + self._learn_model = self._model + + # Image augmentation + if self._cfg.use_augmentation: + self.image_transforms = ImageTransforms( + self._cfg.augmentation, + image_shape=(self._cfg.model.observation_shape[1], self._cfg.model.observation_shape[2]) + ) + + # Support for categorical distribution + self.value_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) + self.reward_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) + self.inverse_scalar_transform_handle = InverseScalarTransform( + self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution + ) + + # ============================================================== + # HarmonyDream (learnable weights for different losses) + # ============================================================== + if self._cfg.model.harmony_balance: + # List of parameter names. 
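# NOTE: illustration only, not part of the patch. The piecewise-constant LR schedule
# configured in _init_learn above: full learning rate for the first half of
# threshold_training_steps_for_final_lr, then x0.1, then x0.01.
max_step = int(5e4)
lr_lambda = lambda step: 1 if step < max_step * 0.5 else (0.1 if step < max_step else 0.01)
print([lr_lambda(s) for s in (0, 30000, 60000)])  # [1, 0.1, 0.01]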
+ harmony_names = ["harmony_dynamics", "harmony_policy", "harmony_value", "harmony_reward", "harmony_entropy"] + # Initialize and name each parameter. + for name in harmony_names: + param = torch.nn.Parameter(-torch.log(torch.tensor(1.0))) + setattr(self, name, param) + + # RND model for intrinsic reward + if self._cfg.use_rnd_model: + if self._cfg.target_model_for_intrinsic_reward_update_type == 'assign': + self._target_model_for_intrinsic_reward = model_wrap( + self._target_model, + wrapper_name='target', + update_type='assign', + update_kwargs={'freq': self._cfg.target_update_freq_for_intrinsic_reward} + ) + elif self._cfg.target_model_for_intrinsic_reward_update_type == 'momentum': + self._target_model_for_intrinsic_reward = model_wrap( + self._target_model, + wrapper_name='target', + update_type='momentum', + update_kwargs={'theta': self._cfg.target_update_theta_for_intrinsic_reward} + ) + + # ========= Logging for analysis ========= + self.l2_norm_before = 0. + self.l2_norm_after = 0. + self.grad_norm_before = 0. + self.grad_norm_after = 0. + self.dormant_ratio_encoder = 0. + self.dormant_ratio_dynamics = 0. + + # Initialize multi-task related parameters. + self.task_num_for_current_rank = self._cfg.task_num + self.task_id = self._cfg.task_id + + def _forward_learn(self, data: List[Tuple[torch.Tensor, torch.Tensor, int]]) -> Dict[str, Union[float, int]]: + """ + Overview: + The forward function for learning, which is the core of the learning process. + Data is sampled from the replay buffer, and the loss is calculated and backpropagated + to update the model. + Arguments: + - data (:obj:`List[Tuple[torch.Tensor, torch.Tensor, int]]`): A list of data tuples for each task, + where each tuple contains (current_batch, target_batch, task_id). + Returns: + - info_dict (:obj:`Dict[str, Union[float, int]]`): A dictionary of information for logging, + including the current learning loss and other learning statistics. + """ + self._learn_model.train() + self._target_model.train() + + # Initialize lists for multi-task losses. + reward_loss_multi_task = [] + policy_loss_multi_task = [] + value_loss_multi_task = [] + consistency_loss_multi_task = [] + policy_entropy_multi_task = [] + lambd_multi_task = [] + value_priority_multi_task = [] + value_priority_mean_multi_task = [] + + weighted_total_loss = 0.0 # Initialize to zero. + losses_list = [] # To store the loss for each task. + + for task_idx, (current_batch, target_batch, task_id) in enumerate(data): + obs_batch_ori, action_batch, mask_batch, indices, weights, make_time = current_batch + target_reward, target_value, target_policy = target_batch + + obs_batch, obs_target_batch = prepare_obs(obs_batch_ori, self._cfg) + + # Data augmentation. + if self._cfg.use_augmentation: + obs_batch = self.image_transforms.transform(obs_batch) + if self._cfg.model.self_supervised_learning_loss: + obs_target_batch = self.image_transforms.transform(obs_target_batch) + + # Prepare action batch and convert to tensor. 
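+            # Expected shapes (assumed from the unroll loop below): action_batch arrives as a numpy
+            # array of shape (batch_size, num_unroll_steps) and becomes a
+            # (batch_size, num_unroll_steps, 1) LongTensor on the training device.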
+ action_batch = torch.from_numpy(action_batch).to(self._cfg.device).unsqueeze(-1).long() + data_list = [mask_batch, target_reward, target_value, target_policy, weights] + mask_batch, target_reward, target_value, target_policy, weights = to_torch_float_tensor( + data_list, self._cfg.device + ) + + target_reward = target_reward.view(self._cfg.batch_size[task_idx], -1) + target_value = target_value.view(self._cfg.batch_size[task_idx], -1) + + assert obs_batch.size(0) == self._cfg.batch_size[task_idx] == target_reward.size(0) + + # Transform rewards and values to scaled representation. + transformed_target_reward = scalar_transform(target_reward) + transformed_target_value = scalar_transform(target_value) + + # Convert to categorical distribution. + target_reward_categorical = phi_transform(self.reward_support, transformed_target_reward) + target_value_categorical = phi_transform(self.value_support, transformed_target_value) + + # Initial inference. + network_output = self._learn_model.initial_inference(obs_batch, task_id=task_id) + + latent_state, reward, value, policy_logits = mz_network_output_unpack(network_output) + + # Log Dormant Ratio and L2 Norm. + if self._cfg.cal_dormant_ratio: + self.dormant_ratio_encoder = cal_dormant_ratio( + self._learn_model.representation_network, obs_batch.detach(), + percentage=self._cfg.dormant_threshold + ) + latent_state_l2_norms = torch.norm(latent_state.view(latent_state.shape[0], -1), p=2, dim=1).mean() + + # Inverse transform value. + original_value = self.inverse_scalar_transform_handle(value) + + # Initialize predicted values and policies. + predicted_rewards = [] + if self._cfg.monitor_extra_statistics: + predicted_values, predicted_policies = original_value.detach().cpu(), torch.softmax( + policy_logits, dim=1 + ).detach().cpu() + + # Calculate priority. + value_priority = torch.nn.L1Loss(reduction='none')(original_value.squeeze(-1), target_value[:, 0]) + value_priority = value_priority.data.cpu().numpy() + 1e-6 + + # Calculate policy and value loss for the first step. + policy_loss = cross_entropy_loss(policy_logits, target_policy[:, 0]) + value_loss = cross_entropy_loss(value, target_value_categorical[:, 0]) + + prob = torch.softmax(policy_logits, dim=-1) + entropy = -(prob * torch.log(prob + 1e-9)).sum(-1) + policy_entropy_loss = -entropy + + reward_loss = torch.zeros(self._cfg.batch_size[task_idx], device=self._cfg.device) + consistency_loss = torch.zeros(self._cfg.batch_size[task_idx], device=self._cfg.device) + target_policy_entropy = 0 + + # Unroll loop for multiple steps. + for step_k in range(self._cfg.num_unroll_steps): + # Recurrent inference using the dynamics function. + network_output = self._learn_model.recurrent_inference(latent_state, action_batch[:, step_k]) + latent_state, reward, value, policy_logits = mz_network_output_unpack(network_output) + + # Log Dormant Ratio for the dynamics network. + if step_k == self._cfg.num_unroll_steps - 1 and self._cfg.cal_dormant_ratio: + action_tmp = action_batch[:, step_k] + if len(action_tmp.shape) == 1: + action_tmp = action_tmp.unsqueeze(-1) + # Convert action to one-hot encoding. 
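+                    # The one-hot action is expanded over the latent spatial dims (H, W) and
+                    # concatenated with the latent state along the channel dim, matching the
+                    # dynamics-network input that is probed for the dormant ratio.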
+ action_one_hot = torch.zeros(action_tmp.shape[0], policy_logits.shape[-1], device=action_tmp.device) + action_tmp = action_tmp.long() + action_one_hot.scatter_(1, action_tmp, 1) + action_encoding_tmp = action_one_hot.unsqueeze(-1).unsqueeze(-1) + action_encoding = action_encoding_tmp.expand( + latent_state.shape[0], policy_logits.shape[-1], latent_state.shape[2], latent_state.shape[3] + ) + state_action_encoding = torch.cat((latent_state, action_encoding), dim=1) + self.dormant_ratio_dynamics = cal_dormant_ratio( + self._learn_model.dynamics_network, + state_action_encoding.detach(), + percentage=self._cfg.dormant_threshold + ) + + # Inverse transform value. + original_value = self.inverse_scalar_transform_handle(value) + + # Calculate consistency loss (self-supervised learning). + if self._cfg.model.self_supervised_learning_loss and self._cfg.ssl_loss_weight > 0: + beg_index, end_index = self._get_target_obs_index_in_step_k(step_k) + network_output = self._learn_model.initial_inference(obs_target_batch[:, beg_index:end_index], task_id=task_id) + + latent_state = to_tensor(latent_state) + representation_state = to_tensor(network_output.latent_state) + + dynamic_proj = self._learn_model.project(latent_state, with_grad=True) + observation_proj = self._learn_model.project(representation_state, with_grad=False) + temp_loss = negative_cosine_similarity(dynamic_proj, observation_proj) * mask_batch[:, step_k] + consistency_loss += temp_loss + + # Calculate policy and value losses. + policy_loss += cross_entropy_loss(policy_logits, target_policy[:, step_k + 1]) + value_loss += cross_entropy_loss(value, target_value_categorical[:, step_k + 1]) + reward_loss += cross_entropy_loss(reward, target_reward_categorical[:, step_k]) + + # Calculate policy entropy loss. + prob = torch.softmax(policy_logits, dim=-1) + entropy = -(prob * torch.log(prob + 1e-9)).sum(-1) + policy_entropy_loss += -entropy + + # Calculate target policy entropy (for debugging purposes only). + target_normalized_visit_count = target_policy[:, step_k + 1] + non_masked_indices = torch.nonzero(mask_batch[:, step_k + 1]).squeeze(-1) + if len(non_masked_indices) > 0: + target_normalized_visit_count_masked = torch.index_select( + target_normalized_visit_count, 0, non_masked_indices + ) + target_policy_entropy += -( + (target_normalized_visit_count_masked + 1e-6) * + torch.log(target_normalized_visit_count_masked + 1e-6) + ).sum(-1).mean() + else: + target_policy_entropy += torch.log( + torch.tensor(target_normalized_visit_count.shape[-1], device=self._cfg.device) + ) + + # Log predicted values and rewards if monitoring extra statistics. + if self._cfg.monitor_extra_statistics: + original_rewards = self.inverse_scalar_transform_handle(reward) + original_rewards_cpu = original_rewards.detach().cpu() + + predicted_values = torch.cat( + (predicted_values, self.inverse_scalar_transform_handle(value).detach().cpu()) + ) + predicted_rewards.append(original_rewards_cpu) + predicted_policies = torch.cat( + (predicted_policies, torch.softmax(policy_logits, dim=1).detach().cpu()) + ) + + # Core learning model update step. + weighted_loss = self._cfg.policy_loss_weight * policy_loss + \ + self._cfg.value_loss_weight * value_loss + \ + self._cfg.reward_loss_weight * reward_loss + \ + self._cfg.ssl_loss_weight * consistency_loss + \ + self._cfg.policy_entropy_weight * policy_entropy_loss + + # Accumulate losses from multiple tasks. + weighted_total_loss += weighted_loss.mean() + + # Store per-task losses for logging. 
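+            # Only detached Python scalars (via .item()) are appended, so these logging lists
+            # never keep the autograd graph of any task alive.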
+ reward_loss_multi_task.append(reward_loss.mean().item()) + policy_loss_multi_task.append(policy_loss.mean().item()) + value_loss_multi_task.append(value_loss.mean().item()) + consistency_loss_multi_task.append(consistency_loss.mean().item()) + policy_entropy_multi_task.append(policy_entropy_loss.mean().item()) + # TODO: Adjust if using gradient correction. + lambd_multi_task.append(torch.tensor(0., device=self._cfg.device).item()) + value_priority_multi_task.append(value_priority.mean().item()) + value_priority_mean_multi_task.append(value_priority.mean().item()) + losses_list.append(weighted_loss.mean().item()) + + # Zero the optimizer's gradients. + self._optimizer.zero_grad() + + # Backward pass. + weighted_total_loss.backward() + + # Gradient clipping. + total_grad_norm_before_clip_wm = torch.nn.utils.clip_grad_norm_( + self._learn_model.parameters(), + self._cfg.grad_clip_value + ) + + # Sync gradients for multi-GPU training. + if self._cfg.multi_gpu: + self.sync_gradients(self._learn_model) + + # Update optimizer. + self._optimizer.step() + if self._cfg.lr_piecewise_constant_decay: + self.lr_scheduler.step() + + # Update target model. + self._target_model.update(self._learn_model.state_dict()) + + # Get GPU memory usage. + if torch.cuda.is_available(): + torch.cuda.synchronize() + current_memory_allocated = torch.cuda.memory_allocated() + max_memory_allocated = torch.cuda.max_memory_allocated() + current_memory_allocated_gb = current_memory_allocated / (1024 ** 3) + max_memory_allocated_gb = max_memory_allocated / (1024 ** 3) + else: + current_memory_allocated_gb = 0.0 + max_memory_allocated_gb = 0.0 + + # Build the return loss dictionary. + return_loss_dict = { + 'Current_GPU': current_memory_allocated_gb, + 'Max_GPU': max_memory_allocated_gb, + 'collect_mcts_temperature': self._collect_mcts_temperature, + 'collect_epsilon': self.collect_epsilon, + 'cur_lr_world_model': self._optimizer.param_groups[0]['lr'], + 'weighted_total_loss': weighted_total_loss.item(), + 'total_grad_norm_before_clip_wm': total_grad_norm_before_clip_wm.item(), + } + + # Generate task-specific loss dictionaries, prefixing each with "noreduce_". + multi_task_loss_dicts = { + **generate_task_loss_dict(consistency_loss_multi_task, 'noreduce_consistency_loss_task{}', task_id=self.task_id), + **generate_task_loss_dict(reward_loss_multi_task, 'noreduce_reward_loss_task{}', task_id=self.task_id), + **generate_task_loss_dict(policy_loss_multi_task, 'noreduce_policy_loss_task{}', task_id=self.task_id), + **generate_task_loss_dict(value_loss_multi_task, 'noreduce_value_loss_task{}', task_id=self.task_id), + **generate_task_loss_dict(policy_entropy_multi_task, 'noreduce_policy_entropy_task{}', task_id=self.task_id), + **generate_task_loss_dict(lambd_multi_task, 'noreduce_lambd_task{}', task_id=self.task_id), + **generate_task_loss_dict(value_priority_multi_task, 'noreduce_value_priority_task{}', task_id=self.task_id), + **generate_task_loss_dict(value_priority_mean_multi_task, 'noreduce_value_priority_mean_task{}', task_id=self.task_id), + } + + # Merge the dictionaries. + return_loss_dict.update(multi_task_loss_dicts) + + # Return the final loss dictionary. + return return_loss_dict + + def _reset_collect(self, data_id: Optional[List[int]] = None, task_id: int = None) -> None: + """ + Overview: + Reset the observation and action for the collector environment. + Arguments: + - data_id (:obj:`Optional[List[int]]`): List of data ids to reset (not used in this implementation). + - task_id (:obj:`int`): The global task ID. 
+ """ + if self._cfg.model.model_type in ["conv_context"]: + self.last_batch_obs = initialize_zeros_batch( + self._cfg.model.observation_shape, + self._cfg.collector_env_num, + self._cfg.device + ) + self.last_batch_action = [-1 for _ in range(self._cfg.collector_env_num)] + + def _reset_eval(self, data_id: Optional[List[int]] = None, task_id: int = None) -> None: + """ + Overview: + Reset the observation and action for the evaluator environment. + Arguments: + - data_id (:obj:`Optional[List[int]]`): List of data ids to reset (not used in this implementation). + - task_id (:obj:`int`): The global task ID. + """ + if self._cfg.model.model_type in ["conv_context"]: + self.last_batch_obs = initialize_zeros_batch( + self._cfg.model.observation_shape, + self._cfg.evaluator_env_num, + self._cfg.device + ) + self.last_batch_action = [-1 for _ in range(self._cfg.evaluator_env_num)] + + + def _monitor_vars_learn(self, num_tasks: int = None) -> List[str]: + """ + Overview: + Registers variables to be monitored during the learning phase. The registered variables + will be recorded to TensorBoard based on the return value of `_forward_learn`. + If `num_tasks` is provided, it generates monitoring variables for each task. + Arguments: + - num_tasks (:obj:`int`, optional): The number of tasks. + Returns: + - monitored_vars (:obj:`List[str]`): A list of variable names to be monitored. + """ + # Basic monitoring variables. + monitored_vars = [ + 'Current_GPU', + 'Max_GPU', + 'collect_epsilon', + 'collect_mcts_temperature', + 'cur_lr_world_model', + 'weighted_total_loss', + 'total_grad_norm_before_clip_wm', + ] + + # Task-specific monitoring variables. + task_specific_vars = [ + 'noreduce_consistency_loss', + 'noreduce_reward_loss', + 'noreduce_policy_loss', + 'noreduce_value_loss', + 'noreduce_policy_entropy', + 'noreduce_lambd', + 'noreduce_value_priority', + 'noreduce_value_priority_mean', + ] + + # Use self.task_num_for_current_rank as the number of tasks for the current rank. + num_tasks = self.task_num_for_current_rank + print(f'self.task_num_for_current_rank: {self.task_num_for_current_rank}') + if num_tasks is not None: + for var in task_specific_vars: + for task_idx in range(num_tasks): + monitored_vars.append(f'{var}_task{self.task_id + task_idx}') + else: + monitored_vars.extend(task_specific_vars) + + return monitored_vars + + def _init_collect(self) -> None: + """ + Overview: + Collect mode init method. Called by ``self.__init__``. Initialize the collect model and MCTS utils. + """ + self._collect_model = self._model + if self._cfg.mcts_ctree: + self._mcts_collect = MCTSCtree(self._cfg) + else: + self._mcts_collect = MCTSPtree(self._cfg) + self._collect_mcts_temperature = 1. + self.collect_epsilon = 0.0 + if self._cfg.model.model_type == 'conv_context': + self.last_batch_obs = torch.zeros([8, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device) + self.last_batch_action = [-1 for i in range(8)] + + def _forward_collect( + self, + data: torch.Tensor, + action_mask: list = None, + temperature: float = 1, + to_play: List = [-1], + epsilon: float = 0.25, + ready_env_id: np.array = None, + task_id: int = None, + ) -> Dict: + """ + Overview: + The forward function for collecting data in collect mode. Use model to execute MCTS search. + Choosing the action through sampling during the collect mode. + Arguments: + - data (:obj:`torch.Tensor`): The input data, i.e. the observation. + - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected. 
+ - temperature (:obj:`float`): The temperature of the policy. + - to_play (:obj:`int`): The player to play. + - epsilon (:obj:`float`): The epsilon of the eps greedy exploration. + - ready_env_id (:obj:`list`): The id of the env that is ready to collect. + - task_id (:obj:`int`): The global task ID for the current environments. + Returns: + - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \ + ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. + """ + self._collect_model.eval() + self._collect_mcts_temperature = temperature + self.collect_epsilon = epsilon + active_collect_env_num = data.shape[0] + if ready_env_id is None: + ready_env_id = np.arange(active_collect_env_num) + output = {i: None for i in ready_env_id} + with torch.no_grad(): + if self._cfg.model.model_type in ["conv", "mlp"]: + network_output = self._collect_model.initial_inference(data, task_id=task_id) + elif self._cfg.model.model_type == "conv_context": + network_output = self._collect_model.initial_inference(self.last_batch_obs, self.last_batch_action, + data, task_id=task_id) + + latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) + + pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() + latent_state_roots = latent_state_roots.detach().cpu().numpy() + policy_logits = policy_logits.detach().cpu().numpy().tolist() + + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_collect_env_num)] + if not self._cfg.collect_with_pure_policy: + # The only difference between collect and eval is the dirichlet noise. + noises = [ + np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(sum(action_mask[j])) + ).astype(np.float32).tolist() for j in range(active_collect_env_num) + ] + if self._cfg.mcts_ctree: + # C++ MCTS tree. + roots = MCTSCtree.roots(active_collect_env_num, legal_actions) + else: + # Python MCTS tree. + roots = MCTSPtree.roots(active_collect_env_num, legal_actions) + + roots.prepare(self._cfg.root_noise_weight, noises, reward_roots, policy_logits, to_play) + self._mcts_collect.search(roots, self._collect_model, latent_state_roots, to_play, task_id=task_id) + + # List of lists, shape: ``{list: batch_size} -> {list: action_space_size}`` + roots_visit_count_distributions = roots.get_distributions() + roots_values = roots.get_values() # shape: {list: batch_size} + + batch_action = [] + for i, env_id in enumerate(ready_env_id): + distributions, value = roots_visit_count_distributions[i], roots_values[i] + if self._cfg.eps.eps_greedy_exploration_in_collect: + # Epsilon-greedy exploration for collection. + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=self._collect_mcts_temperature, deterministic=True + ) + action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + if np.random.rand() < self.collect_epsilon: + action = np.random.choice(legal_actions[i]) + else: + # Normal collection. + # NOTE: Only legal actions possess visit counts, so ``action_index_in_legal_action_set`` represents + # the index within the legal action set, not the entire action set. + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=self._collect_mcts_temperature, deterministic=False + ) + # NOTE: Convert ``action_index_in_legal_action_set`` to the corresponding ``action`` in the entire action set. 
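+                        # Illustrative example (hypothetical values): action_mask[i] = [0, 1, 0, 1]
+                        # gives legal actions [1, 3]; an index of 1 in the legal set therefore maps
+                        # to the global action 3.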
+ action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + output[env_id] = { + 'action': action, + 'visit_count_distributions': distributions, + 'visit_count_distribution_entropy': visit_count_distribution_entropy, + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], + } + if self._cfg.model.model_type in ["conv_context"]: + batch_action.append(action) + + if self._cfg.model.model_type in ["conv_context"]: + self.last_batch_obs = data + self.last_batch_action = batch_action + else: + # Pure policy collection (without MCTS). + for i, env_id in enumerate(ready_env_id): + policy_values = torch.softmax(torch.tensor([policy_logits[i][a] for a in legal_actions[i]]), + dim=0).tolist() + policy_values = policy_values / np.sum(policy_values) + action_index_in_legal_action_set = np.random.choice(len(legal_actions[i]), p=policy_values) + action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + output[env_id] = { + 'action': action, + 'searched_value': pred_values[i], + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], + } + + return output + + def _get_target_obs_index_in_step_k(self, step: int) -> Tuple[int, int]: + """ + Overview: + Get the begin and end indices of the target observation at step k. + Arguments: + - step (:obj:`int`): The current step k. + Returns: + - beg_index (:obj:`int`): The beginning index of the target observation. + - end_index (:obj:`int`): The ending index of the target observation. + """ + if self._cfg.model.model_type in ['conv', 'conv_context']: + beg_index = self._cfg.model.image_channel * step + end_index = self._cfg.model.image_channel * (step + self._cfg.model.frame_stack_num) + elif self._cfg.model.model_type in ['mlp', 'mlp_context']: + beg_index = self._cfg.model.observation_shape * step + end_index = self._cfg.model.observation_shape * (step + self._cfg.model.frame_stack_num) + return beg_index, end_index + + def _init_eval(self) -> None: + """ + Overview: + Evaluate mode init method. Called by ``self.__init__``. Initialize the eval model and MCTS utils. + """ + self._eval_model = self._model + if self._cfg.mcts_ctree: + self._mcts_eval = MCTSCtree(self._cfg) + else: + self._mcts_eval = MCTSPtree(self._cfg) + if self._cfg.model.model_type == 'conv_context': + self.last_batch_obs = torch.zeros([3, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device) + self.last_batch_action = [-1 for _ in range(3)] + + def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, + ready_env_id: np.array = None, task_id: int = None) -> Dict: + """ + Overview: + The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search. + Choosing the action with the highest value (argmax) rather than sampling during the eval mode. + Arguments: + - data (:obj:`torch.Tensor`): The input data, i.e. the observation. + - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected. + - to_play (:obj:`int`): The player to play. + - ready_env_id (:obj:`list`): The id of the env that is ready to collect. + - task_id (:obj:`int`): The global task ID for the current environments. + Returns: + - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \ + ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. 
+ """ + self._eval_model.eval() + active_eval_env_num = data.shape[0] + if ready_env_id is None: + ready_env_id = np.arange(active_eval_env_num) + output = {i: None for i in ready_env_id} + with torch.no_grad(): + if self._cfg.model.model_type in ["conv", "mlp"]: + network_output = self._eval_model.initial_inference(data, task_id=task_id) + elif self._cfg.model.model_type == "conv_context": + network_output = self._eval_model.initial_inference(self.last_batch_obs, self.last_batch_action, data, task_id=task_id) + latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) + + if not self._eval_model.training: + # If not in training, obtain the scalar values of the value/reward. + pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() # shape (B, 1) + latent_state_roots = latent_state_roots.detach().cpu().numpy() + policy_logits = policy_logits.detach().cpu().numpy().tolist() # list shape (B, A) + + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_eval_env_num)] + if self._cfg.mcts_ctree: + # C++ MCTS tree. + roots = MCTSCtree.roots(active_eval_env_num, legal_actions) + else: + # Python MCTS tree. + roots = MCTSPtree.roots(active_eval_env_num, legal_actions) + roots.prepare_no_noise(reward_roots, policy_logits, to_play) + self._mcts_eval.search(roots, self._eval_model, latent_state_roots, to_play, task_id=task_id) + + # List of lists, shape: ``{list: batch_size} -> {list: action_space_size}`` + roots_visit_count_distributions = roots.get_distributions() + roots_values = roots.get_values() # shape: {list: batch_size} + + batch_action = [] + for i, env_id in enumerate(ready_env_id): + distributions, value = roots_visit_count_distributions[i], roots_values[i] + # NOTE: Only legal actions possess visit counts, so ``action_index_in_legal_action_set`` represents + # the index within the legal action set, not the entire action set. + # Setting deterministic=True implies choosing the action with the highest value (argmax) + # rather than sampling during the evaluation phase. + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=1, deterministic=True + ) + # NOTE: Convert ``action_index_in_legal_action_set`` to the corresponding ``action`` in the + # entire action set. 
+ action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + + output[env_id] = { + 'action': action, + 'visit_count_distributions': distributions, + 'visit_count_distribution_entropy': visit_count_distribution_entropy, + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], + } + if self._cfg.model.model_type in ["conv_context"]: + batch_action.append(action) + + if self._cfg.model.model_type in ["conv_context"]: + self.last_batch_obs = data + self.last_batch_action = batch_action + + return output \ No newline at end of file diff --git a/lzero/policy/sampled_muzero.py b/lzero/policy/sampled_muzero.py index 3548c03be..a62d99ff0 100644 --- a/lzero/policy/sampled_muzero.py +++ b/lzero/policy/sampled_muzero.py @@ -1,24 +1,24 @@ import copy -from typing import List, Dict, Any, Tuple, Union +from typing import Any, Dict, List, Tuple, Union import numpy as np import torch -import wandb import torch.optim as optim +import wandb from ding.model import model_wrap from ding.torch_utils import to_tensor from ding.utils import POLICY_REGISTRY from ditk import logging +from lzero.mcts import SampledMuZeroMCTSCtree as MCTSCtree +from lzero.model import ImageTransforms +from lzero.policy import (DiscreteSupport, InverseScalarTransform, + cross_entropy_loss, from, import, lzero.policy, + mz_network_output_unpack, negative_cosine_similarity, + phi_transform, prepare_obs, scalar_transform, + select_action, to_torch_float_tensor) from torch.distributions import Categorical, Independent, Normal from torch.nn import L1Loss -from lzero.mcts import SampledMuZeroMCTSCtree as MCTSCtree -# from lzero.mcts import SampledMuZeroMCTSPtree as MCTSPtree -from lzero.model import ImageTransforms -from lzero.policy import scalar_transform, InverseScalarTransform, cross_entropy_loss, phi_transform, \ - DiscreteSupport, to_torch_float_tensor, mz_network_output_unpack, select_action, negative_cosine_similarity, \ - prepare_obs -from lzero.policy.muzero import MuZeroPolicy from .utils import configure_optimizers_nanogpt diff --git a/lzero/policy/sampled_unizero.py b/lzero/policy/sampled_unizero.py index b0485ca4a..a29c00505 100644 --- a/lzero/policy/sampled_unizero.py +++ b/lzero/policy/sampled_unizero.py @@ -1,22 +1,21 @@ import copy import logging from collections import defaultdict -from typing import List, Dict, Tuple, Union +from typing import Dict, List, Tuple, Union import numpy as np import torch import wandb from ding.model import model_wrap from ding.utils import POLICY_REGISTRY - from lzero.mcts import SampledUniZeroMCTSCtree as MCTSCtree from lzero.model import ImageTransforms -from lzero.policy import scalar_transform, InverseScalarTransform, phi_transform, \ - DiscreteSupport, to_torch_float_tensor, mz_network_output_unpack, select_action, prepare_obs, \ - prepare_obs_stack_for_unizero +from lzero.policy import (DiscreteSupport, InverseScalarTransform, + mz_network_output_unpack, phi_transform, prepare_obs, + prepare_obs_stack_for_unizero, scalar_transform, + select_action, to_torch_float_tensor) from lzero.policy.unizero import UniZeroPolicy from .utils import configure_optimizers_nanogpt -from lzero.entry.utils import initialize_zeros_batch def get_action(roots_sampled_actions, i, action): @@ -333,6 +332,7 @@ def _init_learn(self) -> None: if self._cfg.cos_lr_scheduler: from torch.optim.lr_scheduler import CosineAnnealingLR + # TODO: check the total training steps self.lr_scheduler = CosineAnnealingLR(self._optimizer_world_model, 1e5, 
eta_min=0, last_epoch=-1) @@ -377,9 +377,24 @@ def _init_learn(self) -> None: self.l2_norm_after = 0. self.grad_norm_before = 0. self.grad_norm_after = 0. - self.pad_token_id = 0 # for compatibility - + if self._cfg.model.model_type == 'conv': + # for image-input env + self.pad_token_id = -1 + else: + # for text-input env and vector-input env + # Retrieve the tokenizer from the encoder module if it exists + encoder_tokenizer = getattr(self._model.tokenizer.encoder, 'tokenizer', None) + + # Extract the padding token ID from the tokenizer if available, otherwise use 0 as default. Used in _reset_collect() + # The pad_token_id is used to identify padding tokens in sequences, which is essential for: + # 1. Masking padded positions during attention computation to prevent them from affecting the output + # 2. Properly handling variable-length sequences in batch processing + # 3. Distinguishing between actual tokens and padding in loss calculation + # Default value 0 is a common convention when no specific padding token is defined + self.pad_token_id = encoder_tokenizer.pad_token_id if encoder_tokenizer is not None else 0 + + # @profile def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]: """ diff --git a/lzero/policy/sampled_unizero_multitask.py b/lzero/policy/sampled_unizero_multitask.py new file mode 100644 index 000000000..5b2bc162e --- /dev/null +++ b/lzero/policy/sampled_unizero_multitask.py @@ -0,0 +1,986 @@ +import copy +import logging +# Please add the path to your LibMTL library. +# For example: sys.path.append('/path/to/your/LibMTL/') +import sys +from collections import defaultdict +from typing import Any, Dict, List, Tuple, Union + +import numpy as np +import torch +import torch.distributed as dist +import torch.nn.functional as F +import wandb +from ding.model import model_wrap +from ding.utils import POLICY_REGISTRY, get_rank, get_world_size, set_pkg_seed +# sys.path.append('/path/to/your/LibMTL/') # Template path +from LibMTL.weighting.MoCo_unizero import MoCo as GradCorrect +from lzero.entry.utils import initialize_zeros_batch +from lzero.mcts import SampledUniZeroMCTSCtree as MCTSCtree +from lzero.model import ImageTransforms +from lzero.policy import (DiscreteSupport, InverseScalarTransform, + mz_network_output_unpack, phi_transform, prepare_obs, + prepare_obs_stack_for_unizero, scalar_transform, + select_action, to_torch_float_tensor) +from lzero.policy.unizero import UniZeroPolicy + +from .utils import configure_optimizers_nanogpt + + +def generate_task_loss_dict(multi_task_losses: List[Union[torch.Tensor, float]], task_name_template: str, task_id: int) -> Dict[str, float]: + """ + Overview: + Generates a dictionary for losses of each task. + Arguments: + - multi_task_losses (:obj:`List[Union[torch.Tensor, float]]`): A list containing the loss for each task. + - task_name_template (:obj:`str`): A template for the task name, e.g., 'obs_loss_task{}'. + - task_id (:obj:`int`): The starting global task ID for the current rank. Used to offset task indices when generating task names. + Returns: + - (:obj:`Dict[str, float]`): A dictionary containing the loss for each task. + """ + task_loss_dict = {} + for task_idx, task_loss in enumerate(multi_task_losses): + task_name = task_name_template.format(task_idx + task_id) + try: + # Convert tensor to float if it has .item(), otherwise cast to float. 
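+            # e.g. torch.tensor(0.123) -> 0.123, while plain ints/floats are simply cast via float().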
+ task_loss_dict[task_name] = task_loss.item() if hasattr(task_loss, 'item') else float(task_loss) + except Exception as e: + # Fallback for cases where conversion fails. + task_loss_dict[task_name] = task_loss + return task_loss_dict + + +class WrappedModelV2: + """ + Overview: + A wrapper class to conveniently manage different parts of a larger model, + such as the tokenizer, transformer, and various embedding layers. This allows for + easier handling of parameters and gradients for these components. + """ + def __init__(self, tokenizer: torch.nn.Module, transformer: torch.nn.Module, pos_emb: torch.nn.Module, task_emb: torch.nn.Module, act_embedding_table: torch.nn.Module): + """ + Overview: + Initializes the WrappedModelV2 with model components. + Arguments: + - tokenizer (:obj:`torch.nn.Module`): The tokenizer module. + - transformer (:obj:`torch.nn.Module`): The main transformer module. + - pos_emb (:obj:`torch.nn.Module`): The positional embedding layer. + - task_emb (:obj:`torch.nn.Module`): The task embedding layer. + - act_embedding_table (:obj:`torch.nn.Module`): The action embedding table. + """ + self.tokenizer = tokenizer + self.transformer = transformer + self.pos_emb = pos_emb + self.task_emb = task_emb + self.act_embedding_table = act_embedding_table + + def parameters(self) -> List[torch.Tensor]: + """ + Overview: + Collects and returns all parameters from the wrapped model components. + Returns: + - (:obj:`List[torch.Tensor]`): A list of all parameters. + """ + return ( + list(self.tokenizer.parameters()) + + list(self.transformer.parameters()) + + list(self.pos_emb.parameters()) + + # list(self.task_emb.parameters()) + + list(self.act_embedding_table.parameters()) + ) + + def zero_grad(self, set_to_none: bool = False) -> None: + """ + Overview: + Sets the gradients of all wrapped model components to zero. + Arguments: + - set_to_none (:obj:`bool`): Whether to set gradients to None instead of zero. Defaults to False. + """ + self.tokenizer.zero_grad(set_to_none=set_to_none) + self.transformer.zero_grad(set_to_none=set_to_none) + self.pos_emb.zero_grad(set_to_none=set_to_none) + # self.task_emb.zero_grad(set_to_none=set_to_none) + self.act_embedding_table.zero_grad(set_to_none=set_to_none) + + def get_group_parameters(self) -> Dict[str, List[torch.Tensor]]: + """ + Overview: + Returns a dictionary where keys are module names (or finer-grained layers) + and values are the corresponding parameter lists. The order of parameters in the + returned dictionary's values should be consistent with the `parameters()` method. + Returns: + - (:obj:`Dict[str, List[torch.Tensor]]`): A dictionary of grouped parameters. + """ + groups = {} + groups['tokenizer'] = list(self.tokenizer.parameters()) + groups['transformer'] = list(self.transformer.parameters()) + groups['pos_emb'] = list(self.pos_emb.parameters()) + groups['act_embedding_table'] = list(self.act_embedding_table.parameters()) + + # Example of how to add parameters from sub-layers within the transformer. + # This is for demonstration; ensure the order in parameters() is consistent if used. + if hasattr(self.transformer, 'blocks'): + for i, layer in enumerate(self.transformer.blocks): + groups[f'transformer_layer_{i}'] = list(layer.parameters()) + return groups + + +@POLICY_REGISTRY.register('sampled_unizero_multitask') +class SampledUniZeroMTPolicy(UniZeroPolicy): + """ + Overview: + The policy class for Sampled UniZero Multitask, combining multi-task learning with sampled-based MCTS. 
+ This implementation extends the UniZeroPolicy to handle multiple tasks simultaneously while utilizing + sampled MCTS for action selection. It ensures scalability and correctness in multi-task environments. + """ + + # The default_config for Sampled UniZero Multitask policy. + config = dict( + type='sampled_unizero_multitask', + model=dict( + model_type='conv', # options={'mlp', 'conv'} + continuous_action_space=False, + observation_shape=(3, 64, 64), + self_supervised_learning_loss=True, + categorical_distribution=True, + image_channel=3, + frame_stack_num=1, + num_res_blocks=1, + num_channels=64, + support_scale=50, + bias=True, + res_connection_in_dynamics=True, + norm_type='LN', + analysis_sim_norm=False, + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=10000, ), ), ), + world_model_cfg=dict( + tokens_per_block=2, + max_blocks=10, + max_tokens=20, + context_length=8, + gru_gating=False, + device='cpu', + analysis_sim_norm=False, + analysis_dormant_ratio=False, + action_space_size=6, + group_size=8, + attention='causal', + num_layers=2, + num_heads=8, + embed_dim=768, + embed_pdrop=0.1, + resid_pdrop=0.1, + attn_pdrop=0.1, + support_size=101, + max_cache_size=5000, + env_num=8, + latent_recon_loss_weight=0., + perceptual_loss_weight=0., + policy_entropy_weight=5e-3, + predict_latent_loss_type='group_kl', + obs_type='image', + gamma=1, + dormant_threshold=0.01, + policy_loss_type='kl', + ), + ), + use_rnd_model=False, + multi_gpu=True, + sampled_algo=True, + gumbel_algo=False, + mcts_ctree=True, + cuda=True, + collector_env_num=8, + evaluator_env_num=3, + env_type='not_board_games', + action_type='fixed_action_space', + battle_mode='play_with_bot_mode', + monitor_extra_statistics=True, + game_segment_length=400, + analysis_sim_norm=False, + collect_with_pure_policy=False, + eval_freq=int(5e3), + sample_type='transition', + + transform2string=False, + gray_scale=False, + use_augmentation=False, + augmentation=['shift', 'intensity'], + + ignore_done=False, + update_per_collect=None, + replay_ratio=0.25, + batch_size=256, + optim_type='AdamW', + learning_rate=0.0001, + init_w=3e-3, + target_update_freq=100, + target_update_theta=0.05, + target_update_freq_for_intrinsic_reward=1000, + weight_decay=1e-4, + momentum=0.9, + grad_clip_value=5, + n_episode=8, + num_simulations=50, + discount_factor=0.997, + td_steps=5, + num_unroll_steps=10, + reward_loss_weight=1, + value_loss_weight=0.25, + policy_loss_weight=1, + ssl_loss_weight=0, + cos_lr_scheduler=False, + piecewise_decay_lr_scheduler=False, + threshold_training_steps_for_final_lr=int(5e4), + manual_temperature_decay=False, + threshold_training_steps_for_final_temperature=int(1e5), + fixed_temperature_value=0.25, + use_ture_chance_label_in_chance_encoder=False, + + use_priority=False, + priority_prob_alpha=0.6, + priority_prob_beta=0.4, + train_start_after_envsteps=0, + + root_dirichlet_alpha=0.3, + root_noise_weight=0.25, + + random_collect_episode_num=0, + + eps=dict( + eps_greedy_exploration_in_collect=False, + type='linear', + start=1., + end=0.05, + decay=int(1e5), + ), + ) + + def default_model(self) -> Tuple[str, List[str]]: + """ + Overview: + Return this algorithm's default model setting for demonstration. + Returns: + - (:obj:`Tuple[str, List[str]]`): A tuple containing the model name and the import paths. + """ + return 'SampledUniZeroMTModel', ['lzero.model.sampled_unizero_model_multitask'] + + def _init_learn(self) -> None: + """ + Overview: + Initializes the learning mode. 
This method sets up the learn model, optimizer, + target model, and other utilities required for training, such as LR schedulers + and gradient correction methods (e.g., MoCo). + """ + # Configure optimizer for the world model using NanoGPT's configuration utility. + self._optimizer_world_model = configure_optimizers_nanogpt( + model=self._model.world_model, + learning_rate=self._cfg.learning_rate, + weight_decay=self._cfg.weight_decay, + device_type=self._cfg.device, + betas=(0.9, 0.95), + ) + + # Initialize learning rate schedulers if configured. + if self._cfg.cos_lr_scheduler or self._cfg.piecewise_decay_lr_scheduler: + from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR + + if self._cfg.cos_lr_scheduler: + self.lr_scheduler = CosineAnnealingLR( + self._optimizer_world_model, T_max=int(1e5), eta_min=0, last_epoch=-1 + ) + elif self._cfg.piecewise_decay_lr_scheduler: + self.lr_scheduler = StepLR( + self._optimizer_world_model, step_size=int(5e4), gamma=0.1 + ) + + # Initialize weights for continuous action spaces. + if self._cfg.model.continuous_action_space: + init_w = self._cfg.init_w + self._model.world_model.fc_policy_head.mu.weight.data.uniform_(-init_w, init_w) + self._model.world_model.fc_policy_head.mu.bias.data.uniform_(-init_w, init_w) + try: + self._model.world_model.fc_policy_head.log_sigma_layer.weight.data.uniform_(-init_w, init_w) + self._model.world_model.fc_policy_head.log_sigma_layer.bias.data.uniform_(-init_w, init_w) + except Exception as exception: + logging.warning(exception) + + # Initialize and compile the target model. + self._target_model = copy.deepcopy(self._model) + assert int(''.join(filter(str.isdigit, torch.__version__))) >= 200, "Torch version 2.0 or higher is required." + self._model = torch.compile(self._model) + self._target_model = torch.compile(self._target_model) + + # Wrap the target model for soft updates (momentum-based). + self._target_model = model_wrap( + self._target_model, + wrapper_name='target', + update_type='momentum', + update_kwargs={'theta': self._cfg.target_update_theta} + ) + self._learn_model = self._model + + # Initialize utilities for loss calculation and transformations. + + self.value_support = DiscreteSupport(*self._cfg.model.value_support_range, self._cfg.device) + self.reward_support = DiscreteSupport(*self._cfg.model.reward_support_range, self._cfg.device) + + self.inverse_scalar_transform_handle = InverseScalarTransform(self.value_support, self._cfg.model.categorical_distribution) + + self.intermediate_losses = defaultdict(float) + self.l2_norm_before = 0. + self.l2_norm_after = 0. + self.grad_norm_before = 0. + self.grad_norm_after = 0. + + self.task_id = self._cfg.task_id + self.task_num_for_current_rank = self._cfg.task_num + print(f'self._cfg.only_use_moco_stats:{self._cfg.only_use_moco_stats}') + + # Initialize gradient correction method (MoCo) if enabled. + if self._cfg.use_moco or self._cfg.only_use_moco_stats: + # Wrap model components for gradient correction. Note: Heads are not included. + wrapped_model = WrappedModelV2( + self._learn_model.world_model.tokenizer.encoder, # TODO: This might contain one or multiple encoders. + self._learn_model.world_model.transformer, + self._learn_model.world_model.pos_emb, + self._learn_model.world_model.task_emb, + self._learn_model.world_model.act_embedding_table, + ) + + # TODO: The GradCorrect class might need adjustments for multi-GPU training compatibility. + # Initialize the gradient correction mechanism. 
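+            # GradCorrect wraps LibMTL's MoCo weighting: per-task gradients on the shared backbone
+            # (tokenizer / transformer / embeddings) are combined to reduce inter-task conflict
+            # rather than being naively summed; `total_task_num` is assumed to be the global task
+            # count across all ranks.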
+ self.grad_correct = GradCorrect(wrapped_model, self._cfg.total_task_num, self._cfg.device, self._cfg.multi_gpu) + + self.grad_correct.init_param() + self.grad_correct.rep_grad = False + + + encoder_tokenizer = getattr(self._model.tokenizer.encoder, 'tokenizer', None) + self.pad_token_id = encoder_tokenizer.pad_token_id if encoder_tokenizer is not None else 0 + + + def _forward_learn(self, data: Tuple[torch.Tensor], task_weights: Any = None, ignore_grad: bool = False) -> Dict[str, Union[float, int]]: + """ + Overview: + The forward pass for training. This method processes a batch of data for multiple tasks, + computes losses, and updates the model weights. + Arguments: + - data (:obj:`Tuple[torch.Tensor]`): A tuple of data batches, one for each task. + - task_weights (:obj:`Any`): Weights for each task's loss. Defaults to None. + - ignore_grad (:obj:`bool`): If True, gradients are zeroed out after computation, effectively skipping the update. Defaults to False. + Returns: + - (:obj:`Dict[str, Union[float, int]]`): A dictionary containing various loss values and training statistics. + """ + self._learn_model.train() + self._target_model.train() + + # Initialize lists to store losses and metrics for each task. + task_weight_multi_task, obs_loss_multi_task, reward_loss_multi_task = [], [], [] + policy_loss_multi_task, orig_policy_loss_multi_task, policy_entropy_multi_task = [], [], [] + value_loss_multi_task, latent_recon_loss_multi_task, perceptual_loss_multi_task = [], [], [] + latent_state_l2_norms_multi_task, average_target_policy_entropy_multi_task = [], [] + value_priority_multi_task, value_priority_mean_multi_task = [], [] + + weighted_total_loss = 0.0 + losses_list = [] # Stores the individual loss tensor for each task. + + for task_id, data_one_task in enumerate(data): + # Unpack data for the current task. + current_batch, target_batch, task_id = data_one_task + obs_batch_ori, action_batch, child_sampled_actions_batch, target_action_batch, mask_batch, indices, weights, make_time, timestep_batch = current_batch + target_reward, target_value, target_policy = target_batch + + # Prepare observations. + if self._cfg.model.frame_stack_num == 4: + obs_batch, obs_target_batch = prepare_obs_stack_for_unizero(obs_batch_ori, self._cfg) + else: + obs_batch, obs_target_batch = prepare_obs(obs_batch_ori, self._cfg, task_id) + + # Apply data augmentation if enabled. + if self._cfg.use_augmentation: + obs_batch = self.image_transforms.transform(obs_batch) + if self._cfg.model.self_supervised_learning_loss: + obs_target_batch = self.image_transforms.transform(obs_target_batch) + + # Prepare actions and convert data to torch tensors. + action_batch = torch.from_numpy(action_batch).to(self._cfg.device).unsqueeze(-1) + if not self._cfg.model.continuous_action_space: + action_batch = action_batch.long() + + data_list = [mask_batch, target_reward.astype('float32'), target_value.astype('float32'), target_policy, weights] + mask_batch, target_reward, target_value, target_policy, weights = to_torch_float_tensor(data_list, self._cfg.device) + + cur_batch_size = target_reward.size(0) + target_reward = target_reward.view(cur_batch_size, -1) + target_value = target_value.view(cur_batch_size, -1) + + # Transform scalar targets to their categorical representation. 
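+            # scalar_transform applies the MuZero-style invertible value rescaling
+            # h(x) = sign(x) * (sqrt(|x| + 1) - 1) + eps * x, and phi_transform then distributes
+            # the rescaled scalar over the two adjacent atoms of the discrete support.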
+ transformed_target_reward = scalar_transform(target_reward) + transformed_target_value = scalar_transform(target_value) + target_reward_categorical = phi_transform(self.reward_support, transformed_target_reward) + target_value_categorical = phi_transform(self.value_support, transformed_target_value) + + # Prepare the batch for the GPT-based world model. + batch_for_gpt = {} + if isinstance(self._cfg.model.observation_shape_list[task_id], int) or len(self._cfg.model.observation_shape_list[task_id]) == 1: + batch_for_gpt['observations'] = torch.cat((obs_batch, obs_target_batch), dim=1).reshape(cur_batch_size, -1, self._cfg.model.observation_shape_list[task_id]) + else: + batch_for_gpt['observations'] = torch.cat((obs_batch, obs_target_batch), dim=1).reshape(cur_batch_size, -1, *self._cfg.model.observation_shape_list[task_id]) + + batch_for_gpt['actions'] = action_batch.squeeze(-1) + batch_for_gpt['child_sampled_actions'] = torch.from_numpy(child_sampled_actions_batch).to(self._cfg.device)[:, :-1] + batch_for_gpt['rewards'] = target_reward_categorical[:, :-1] + batch_for_gpt['mask_padding'] = (mask_batch == 1.0)[:, :-1] # 0 indicates invalid padding data. + batch_for_gpt['observations'] = batch_for_gpt['observations'][:, :-1] + batch_for_gpt['ends'] = torch.zeros(batch_for_gpt['mask_padding'].shape, dtype=torch.long, device=self._cfg.device) + batch_for_gpt['target_value'] = target_value_categorical[:, :-1] + batch_for_gpt['target_policy'] = target_policy[:, :-1] + + # Compute target policy entropy for monitoring. + valid_target_policy = batch_for_gpt['target_policy'][batch_for_gpt['mask_padding']] + target_policy_entropy = -torch.sum(valid_target_policy * torch.log(valid_target_policy + 1e-9), dim=-1) + average_target_policy_entropy = target_policy_entropy.mean().item() + + # Compute losses using the world model. + losses = self._learn_model.world_model.compute_loss( + batch_for_gpt, self._target_model.world_model.tokenizer, self.inverse_scalar_transform_handle, task_id=task_id + ) + + # Accumulate weighted total loss. + current_task_weight = task_weights[task_id] if task_weights is not None else 1 + weighted_total_loss += losses.loss_total * current_task_weight + losses_list.append(losses.loss_total * current_task_weight) + task_weight_multi_task.append(current_task_weight) + + # Store intermediate losses for logging. + for loss_name, loss_value in losses.intermediate_losses.items(): + self.intermediate_losses[f"{loss_name}"] = loss_value + + # Collect individual losses for the current task. 
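+            # `.get(..., 0.0) or 0.0` guards against missing or None entries, so logging keeps
+            # working when a particular loss head (e.g. reconstruction or perceptual) is disabled.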
+ obs_loss_multi_task.append(self.intermediate_losses.get('loss_obs', 0.0) or 0.0) + reward_loss_multi_task.append(self.intermediate_losses.get('loss_rewards', 0.0) or 0.0) + policy_loss_multi_task.append(self.intermediate_losses.get('loss_policy', 0.0) or 0.0) + orig_policy_loss_multi_task.append(self.intermediate_losses.get('orig_policy_loss', 0.0) or 0.0) + policy_entropy_multi_task.append(self.intermediate_losses.get('policy_entropy', 0.0) or 0.0) + value_loss_multi_task.append(self.intermediate_losses.get('loss_value', 0.0) or 0.0) + latent_recon_loss_multi_task.append(self.intermediate_losses.get('latent_recon_loss', 0.0) or 0.0) + perceptual_loss_multi_task.append(self.intermediate_losses.get('perceptual_loss', 0.0) or 0.0) + latent_state_l2_norms_multi_task.append(self.intermediate_losses.get('latent_state_l2_norms', 0.0) or 0.0) + average_target_policy_entropy_multi_task.append(average_target_policy_entropy) + value_priority = torch.tensor(0., device=self._cfg.device) # Placeholder + value_priority_multi_task.append(value_priority) + value_priority_mean_multi_task.append(value_priority.mean().item()) + + # --- Model Update Step --- + self._optimizer_world_model.zero_grad() + + # Perform backward pass, either with or without gradient correction. + if self._cfg.use_moco: + # Use MoCo for gradient correction and backpropagation. + lambd, stats = self.grad_correct.backward(losses=losses_list, **self._cfg.grad_correct_params) + elif self._cfg.only_use_moco_stats: + # Compute MoCo stats but perform standard backpropagation. + lambd, stats = self.grad_correct.backward(losses=losses_list, **self._cfg.grad_correct_params) + weighted_total_loss.backward() + else: + # Standard backpropagation without gradient correction. + lambd = torch.tensor([0. for _ in range(self.task_num_for_current_rank)], device=self._cfg.device) + weighted_total_loss.backward() + + # Clip gradients to prevent exploding gradients. + total_grad_norm_before_clip_wm = torch.nn.utils.clip_grad_norm_(self._learn_model.world_model.parameters(), self._cfg.grad_clip_value) + + # NOTE: If ignore_grad is True, zero out gradients. This is useful for DDP synchronization + # when a GPU has finished all its tasks but still needs to participate in the training step. + if ignore_grad: + self._optimizer_world_model.zero_grad() + + # Synchronize gradients across GPUs in multi-GPU setup. + if self._cfg.multi_gpu: + if not self._cfg.use_moco: + # TODO: Investigate if a barrier is needed here for synchronization. + # dist.barrier() + self.sync_gradients(self._learn_model) + + # Update model parameters. + self._optimizer_world_model.step() + + # Step the learning rate scheduler. + if self._cfg.cos_lr_scheduler or self._cfg.piecewise_decay_lr_scheduler: + self.lr_scheduler.step() + + # Update the target model using a soft update rule. + self._target_model.update(self._learn_model.state_dict()) + + # Monitor GPU memory usage. + if torch.cuda.is_available(): + torch.cuda.synchronize() + current_memory_allocated_gb = torch.cuda.memory_allocated() / (1024 ** 3) + max_memory_allocated_gb = torch.cuda.max_memory_allocated() / (1024 ** 3) + else: + current_memory_allocated_gb, max_memory_allocated_gb = 0., 0. 
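+        # The per-task dictionaries built below rely on generate_task_loss_dict. Illustrative
+        # sketch (hypothetical values): with self.task_id == 2 and two tasks on this rank,
+        # generate_task_loss_dict([0.5, 0.7], 'noreduce_obs_loss_task{}', 2) returns
+        # {'noreduce_obs_loss_task2': 0.5, 'noreduce_obs_loss_task3': 0.7}.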
+ + # --- Logging and Return --- + return_loss_dict = { + 'Current_GPU': current_memory_allocated_gb, + 'Max_GPU': max_memory_allocated_gb, + 'collect_mcts_temperature': self._collect_mcts_temperature, + 'collect_epsilon': self._collect_epsilon, + 'cur_lr_world_model': self._optimizer_world_model.param_groups[0]['lr'], + 'weighted_total_loss': weighted_total_loss.item(), + 'total_grad_norm_before_clip_wm': total_grad_norm_before_clip_wm.item(), + } + + # Generate and merge task-specific loss dictionaries. + # The "noreduce_" prefix indicates these are per-rank values before DDP reduction. + multi_task_loss_dicts = { + **generate_task_loss_dict(task_weight_multi_task, 'noreduce_task_weight_task{}', self.task_id), + **generate_task_loss_dict(obs_loss_multi_task, 'noreduce_obs_loss_task{}', self.task_id), + **generate_task_loss_dict(latent_recon_loss_multi_task, 'noreduce_latent_recon_loss_task{}', self.task_id), + **generate_task_loss_dict(perceptual_loss_multi_task, 'noreduce_perceptual_loss_task{}', self.task_id), + **generate_task_loss_dict(latent_state_l2_norms_multi_task, 'noreduce_latent_state_l2_norms_task{}', self.task_id), + **generate_task_loss_dict(policy_loss_multi_task, 'noreduce_policy_loss_task{}', self.task_id), + **generate_task_loss_dict(orig_policy_loss_multi_task, 'noreduce_orig_policy_loss_task{}', self.task_id), + **generate_task_loss_dict(policy_entropy_multi_task, 'noreduce_policy_entropy_task{}', self.task_id), + **generate_task_loss_dict(reward_loss_multi_task, 'noreduce_reward_loss_task{}', self.task_id), + **generate_task_loss_dict(value_loss_multi_task, 'noreduce_value_loss_task{}', self.task_id), + **generate_task_loss_dict(average_target_policy_entropy_multi_task, 'noreduce_target_policy_entropy_task{}', self.task_id), + **generate_task_loss_dict(lambd, 'noreduce_lambd_task{}', self.task_id), + **generate_task_loss_dict(value_priority_multi_task, 'noreduce_value_priority_task{}', self.task_id), + **generate_task_loss_dict(value_priority_mean_multi_task, 'noreduce_value_priority_mean_task{}', self.task_id), + } + return_loss_dict.update(multi_task_loss_dicts) + + # Log to wandb if enabled. + if self._cfg.use_wandb: + wandb.log({'learner_step/' + k: v for k, v in return_loss_dict.items()}, step=self.env_step) + wandb.log({"learner_iter_vs_env_step": self.train_iter}, step=self.env_step) + + return return_loss_dict + + def _monitor_vars_learn(self, num_tasks: int = 2) -> List[str]: + """ + Overview: + Specifies the variables to be monitored during training. These variables will be logged + (e.g., to TensorBoard) based on the dictionary returned by `_forward_learn`. + Arguments: + - num_tasks (:obj:`int`): The number of tasks to generate monitored variables for. This argument is for API consistency and is overridden by `self.task_num_for_current_rank`. + Returns: + - (:obj:`List[str]`): A list of variable names to monitor. + """ + # Basic monitored variables, independent of the number of tasks. + monitored_vars = [ + 'Current_GPU', 'Max_GPU', 'collect_epsilon', 'collect_mcts_temperature', + 'cur_lr_world_model', 'weighted_total_loss', 'total_grad_norm_before_clip_wm', + ] + + # Task-specific variables. 
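+        # Each base name below is expanded to '<name>_task<global_task_id>' for every task handled
+        # by this rank, matching the keys produced in `_forward_learn`.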
+ task_specific_vars = [ + 'noreduce_task_weight', 'noreduce_obs_loss', 'noreduce_orig_policy_loss', + 'noreduce_policy_loss', 'noreduce_latent_recon_loss', 'noreduce_policy_entropy', + 'noreduce_target_policy_entropy', 'noreduce_reward_loss', 'noreduce_value_loss', + 'noreduce_perceptual_loss', 'noreduce_latent_state_l2_norms', 'noreduce_lambd', + 'noreduce_value_priority_mean', + ] + + # The number of tasks handled by the current rank. + num_tasks_on_rank = self.task_num_for_current_rank + + # Generate full variable names for each task on the current rank. + if num_tasks_on_rank is not None: + for var in task_specific_vars: + for task_idx in range(num_tasks_on_rank): + # The task ID is offset by the base task ID for this rank. + monitored_vars.append(f'{var}_task{self.task_id + task_idx}') + else: + monitored_vars.extend(task_specific_vars) + + return monitored_vars + + def monitor_weights_and_grads(self, model: torch.nn.Module) -> None: + """ + Overview: + A utility function to monitor and print the statistics (mean, std) of model weights and their gradients. + Arguments: + - model (:obj:`torch.nn.Module`): The model to inspect. + """ + for name, param in model.named_parameters(): + if param.requires_grad and param.grad is not None: + print(f"Layer: {name} | " + f"Weight mean: {param.data.mean():.4f} | " + f"Weight std: {param.data.std():.4f} | " + f"Grad mean: {param.grad.mean():.4f} | " + f"Grad std: {param.grad.std():.4f}") + + def _init_collect(self) -> None: + """ + Overview: + Initializes the collection mode. This method sets up the collect model, MCTS utilities, + and initial states for the collector environments. + """ + self._collect_model = self._model + + if self._cfg.mcts_ctree: + self._mcts_collect = MCTSCtree(self._cfg) + else: + self._mcts_collect = MCTSPtree(self._cfg) + self._collect_mcts_temperature = 1. + self._task_weight_temperature = 10. + self._collect_epsilon = 0.0 + self.collector_env_num = self._cfg.collector_env_num + + # Initialize placeholders for the last observation and action batches. + if self._cfg.model.model_type == 'conv': + obs_shape = [self.collector_env_num, self._cfg.model.observation_shape[0], 64, 64] + self.last_batch_obs = torch.zeros(obs_shape, device=self._cfg.device) + elif self._cfg.model.model_type == 'mlp': + obs_shape = [self.collector_env_num, self._cfg.model.observation_shape_list[0]] + self.last_batch_obs = torch.zeros(obs_shape, device=self._cfg.device) + self.last_batch_action = [-1 for _ in range(self.collector_env_num)] + + def _forward_collect( + self, + data: torch.Tensor, + action_mask: List = None, + temperature: float = 1.0, + to_play: List[int] = [-1], + epsilon: float = 0.25, + ready_env_id: np.ndarray = None, + timestep: List[int] = [0], + task_id: int = None, + ) -> Dict[int, Dict[str, Any]]: + """ + Overview: + The forward pass for data collection. It uses MCTS to select actions for the current states. + Arguments: + - data (:obj:`torch.Tensor`): The current batch of observations. + - action_mask (:obj:`List`): A list of action masks for each environment. + - temperature (:obj:`float`): The temperature parameter for MCTS action selection. + - to_play (:obj:`List[int]`): A list indicating the current player for each environment. + - epsilon (:obj:`float`): The exploration noise parameter. + - ready_env_id (:obj:`np.ndarray`): An array of environment IDs that are ready for action. + - timestep (:obj:`List[int]`): The current timestep for each environment. 
+ - task_id (:obj:`int`): The global task ID for the current environments. + Returns: + - (:obj:`Dict[int, Dict[str, Any]]`): A dictionary mapping environment IDs to action selection results. + """ + self._collect_model.eval() + self._collect_mcts_temperature = temperature + self._collect_epsilon = epsilon + active_collect_env_num = data.shape[0] + if ready_env_id is None: + ready_env_id = np.arange(active_collect_env_num) + output = {i: None for i in ready_env_id} + + with torch.no_grad(): + # 1. Initial inference to get root information. + network_output = self._collect_model.initial_inference(self.last_batch_obs, self.last_batch_action, data, task_id=task_id) + latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) + + pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() + latent_state_roots = latent_state_roots.detach().cpu().numpy() + policy_logits = policy_logits.detach().cpu().numpy().tolist() + + # 2. Prepare MCTS roots. + if not self._cfg.model.continuous_action_space: + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_collect_env_num)] + else: + legal_actions = [[-1] * self._cfg.model.world_model_cfg.num_of_sampled_actions for _ in range(active_collect_env_num)] + + noises = [np.random.dirichlet([self._cfg.root_dirichlet_alpha] * self._cfg.model.world_model_cfg.num_of_sampled_actions).astype(np.float32).tolist() for _ in range(active_collect_env_num)] + + if self._cfg.mcts_ctree: + roots = MCTSCtree.roots(active_collect_env_num, legal_actions, self._cfg.model.world_model_cfg.action_space_size, self._cfg.model.world_model_cfg.num_of_sampled_actions, self._cfg.model.continuous_action_space) + else: + roots = MCTSPtree.roots(active_collect_env_num, legal_actions) + + roots.prepare(self._cfg.root_noise_weight, noises, reward_roots, policy_logits, to_play) + + # 3. MCTS search. + self._mcts_collect.search(roots, self._collect_model, latent_state_roots, to_play, timestep=timestep, task_id=task_id) + + # 4. Get results from MCTS and select actions. + roots_visit_count_distributions = roots.get_distributions() + roots_values = roots.get_values() + roots_sampled_actions = roots.get_sampled_actions() + + batch_action = [] + for i, env_id in enumerate(ready_env_id): + distributions, value = roots_visit_count_distributions[i], roots_values[i] + root_sampled_actions = np.array([getattr(action, 'value', action) for action in roots_sampled_actions[i]]) + + # Select action based on visit counts, with temperature for exploration. + action_idx, visit_count_distribution_entropy = select_action(distributions, temperature=self._collect_mcts_temperature, deterministic=False) + action = root_sampled_actions[action_idx] + if not self._cfg.model.continuous_action_space: + action = int(action.item()) + + output[env_id] = { + 'action': action, + 'visit_count_distributions': distributions, + 'root_sampled_actions': root_sampled_actions, + 'visit_count_distribution_entropy': visit_count_distribution_entropy, + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], + } + batch_action.append(action) + + # 5. Update state for the next step. + self.last_batch_obs = data + self.last_batch_action = batch_action + + # Reset collector if the number of active environments is less than expected. 
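+        # NOTE: the reset below re-initializes last_batch_obs/last_batch_action via _reset_collect
+        # (reset_init_data=True), so the cached batch shapes stay aligned with collector_env_num.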
+ if active_collect_env_num < self.collector_env_num: + logging.warning(f'Number of active envs ({active_collect_env_num}) is less than collector_env_num ({self.collector_env_num}). Resetting collector.') + self._reset_collect(reset_init_data=True, task_id=task_id) + + return output + + def _init_eval(self) -> None: + """ + Overview: + Initializes the evaluation mode. This method sets up the evaluation model, MCTS utilities, + and initial states for the evaluator environments. + """ + self._eval_model = self._model + if self._cfg.mcts_ctree: + self._mcts_eval = MCTSCtree(self._cfg) + else: + self._mcts_eval = MCTSPtree(self._cfg) + self.evaluator_env_num = self._cfg.evaluator_env_num + + self.task_id_for_eval = self._cfg.task_id + self.task_num_for_current_rank = self._cfg.task_num + + # Initialize placeholders for the last observation and action batches for evaluation. + if self._cfg.model.model_type == 'conv': + obs_shape = [self.evaluator_env_num, self._cfg.model.observation_shape[0], 64, 64] + self.last_batch_obs_eval = torch.zeros(obs_shape, device=self._cfg.device) + elif self._cfg.model.model_type == 'mlp': + # TODO: Ensure observation_shape_list is correctly indexed for the evaluation task. + obs_shape = [self.evaluator_env_num, self._cfg.model.observation_shape_list[self.task_id_for_eval]] + self.last_batch_obs_eval = torch.zeros(obs_shape, device=self._cfg.device) + print(f'rank {get_rank()} last_batch_obs_eval shape: {self.last_batch_obs_eval.shape}') + self.last_batch_action = [-1 for _ in range(self.evaluator_env_num)] + + def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, ready_env_id: np.ndarray = None, timestep: List[int] = [0], task_id: int = None) -> Dict[int, Dict[str, Any]]: + """ + Overview: + The forward pass for evaluation. It uses MCTS to select actions deterministically. + Arguments: + - data (:obj:`torch.Tensor`): The current batch of observations. + - action_mask (:obj:`List`): A list of action masks for each environment. + - to_play (:obj:`int`): The current player. + - ready_env_id (:obj:`np.ndarray`): An array of environment IDs that are ready for action. + - timestep (:obj:`List[int]`): The current timestep for each environment. + - task_id (:obj:`int`): The global task ID for the current environments. + Returns: + - (:obj:`Dict[int, Dict[str, Any]]`): A dictionary mapping environment IDs to action selection results. + """ + self._eval_model.eval() + active_eval_env_num = data.shape[0] + if ready_env_id is None: + ready_env_id = np.arange(active_eval_env_num) + output = {i: None for i in ready_env_id} + + with torch.no_grad(): + # 1. Initial inference. + network_output = self._eval_model.initial_inference(self.last_batch_obs_eval, self.last_batch_action, data, task_id=task_id) + latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) + + pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() + latent_state_roots = latent_state_roots.detach().cpu().numpy() + policy_logits = policy_logits.detach().cpu().numpy().tolist() + + # 2. Prepare MCTS roots without noise for deterministic evaluation. 
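+            # Unlike _forward_collect, no Dirichlet noise is added at the root (prepare_no_noise below)
+            # and the action is later picked greedily from visit counts (deterministic=True).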
+ if not self._cfg.model.continuous_action_space: + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_eval_env_num)] + else: + legal_actions = [[-1] * self._cfg.model.world_model_cfg.num_of_sampled_actions for _ in range(active_eval_env_num)] + + if self._cfg.mcts_ctree: + roots = MCTSCtree.roots(active_eval_env_num, legal_actions, self._cfg.model.world_model_cfg.action_space_size, self._cfg.model.world_model_cfg.num_of_sampled_actions, self._cfg.model.continuous_action_space) + else: + roots = MCTSPtree.roots(active_eval_env_num, legal_actions) + + roots.prepare_no_noise(reward_roots, policy_logits, to_play) + + # 3. MCTS search. + self._mcts_eval.search(roots, self._eval_model, latent_state_roots, to_play, timestep=timestep, task_id=task_id) + + # 4. Get results and select actions deterministically. + roots_visit_count_distributions = roots.get_distributions() + roots_values = roots.get_values() + roots_sampled_actions = roots.get_sampled_actions() + + batch_action = [] + for i, env_id in enumerate(ready_env_id): + distributions, value = roots_visit_count_distributions[i], roots_values[i] + root_sampled_actions = np.array([getattr(action, 'value', action) for action in roots_sampled_actions[i]]) + + # Select action deterministically (greedy selection from visit counts). + action_idx, visit_count_distribution_entropy = select_action(distributions, temperature=1, deterministic=True) + action = root_sampled_actions[action_idx] + if not self._cfg.model.continuous_action_space: + action = int(action.item()) + + output[env_id] = { + 'action': action, + 'visit_count_distributions': distributions, + 'root_sampled_actions': root_sampled_actions, + 'visit_count_distribution_entropy': visit_count_distribution_entropy, + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], + } + batch_action.append(action) + + # 5. Update state for the next evaluation step. + self.last_batch_obs_eval = data + self.last_batch_action = batch_action + + return output + + def _reset_collect(self, env_id: int = None, current_steps: int = 0, reset_init_data: bool = True, task_id: int = None) -> None: + """ + Overview: + Resets the collector state. This can be a full reset of initial data or a periodic + clearing of model caches to manage memory. + Arguments: + - env_id (:obj:`int`, optional): The ID of the environment to reset. If None, applies to all. + - current_steps (:obj:`int`): The current number of steps, used for periodic cache clearing. + - reset_init_data (:obj:`bool`): Whether to reset the initial observation and action batches. + - task_id (:obj:`int`, optional): The global task ID, used to determine observation shape. + """ + if reset_init_data: + obs_shape = self._cfg.model.observation_shape_list[task_id] if task_id is not None else self._cfg.model.observation_shape + self.last_batch_obs = initialize_zeros_batch(obs_shape, self._cfg.collector_env_num, self._cfg.device) + self.last_batch_action = [-1 for _ in range(self._cfg.collector_env_num)] + logging.info(f'Collector: last_batch_obs and last_batch_action have been reset. Shape: {self.last_batch_obs.shape}') + + if env_id is None or isinstance(env_id, list): + return + + # Periodically clear model caches to free up memory. 
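+        # The interval depends on the sampling mode (see clear_interval below); a clear drops the
+        # init-infer and recurrent-infer KV caches and keys_values_wm_list of the collect model,
+        # releases cached GPU memory, and resets the target model's caches as well.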
+ clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else 200 + if current_steps > 0 and current_steps % clear_interval == 0: + logging.info(f'Clearing model caches at step {current_steps}.') + world_model = self._collect_model.world_model + world_model.past_kv_cache_init_infer.clear() + for kv_cache_dict_env in world_model.past_kv_cache_init_infer_envs: + kv_cache_dict_env.clear() + world_model.past_kv_cache_recurrent_infer.clear() + world_model.keys_values_wm_list.clear() + torch.cuda.empty_cache() + logging.info('Collector: collect_model caches cleared.') + self._reset_target_model() + + def _reset_target_model(self) -> None: + """ + Overview: + Resets the caches of the target model to free up GPU memory. + """ + world_model = self._target_model.world_model + world_model.past_kv_cache_init_infer.clear() + for kv_cache_dict_env in world_model.past_kv_cache_init_infer_envs: + kv_cache_dict_env.clear() + world_model.past_kv_cache_recurrent_infer.clear() + world_model.keys_values_wm_list.clear() + torch.cuda.empty_cache() + logging.info('Collector: target_model caches cleared.') + + def _state_dict_learn(self) -> Dict[str, Any]: + """ + Overview: + Returns the state dictionary of the learning components. + Returns: + - (:obj:`Dict[str, Any]`): A dictionary containing the state of the model, target model, and optimizer. + """ + return { + 'model': self._learn_model.state_dict(), + 'target_model': self._target_model.state_dict(), + 'optimizer_world_model': self._optimizer_world_model.state_dict(), + } + + def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: + """ + Overview: + Loads the state dictionary into the learning components. + Arguments: + - state_dict (:obj:`Dict[str, Any]`): The state dictionary to load. + """ + self._learn_model.load_state_dict(state_dict['model']) + self._target_model.load_state_dict(state_dict['target_model']) + self._optimizer_world_model.load_state_dict(state_dict['optimizer_world_model']) + + # TODO: The following is a version for pretrain-finetune workflow, which only loads backbone parameters. + # def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: + # """ + # Overview: + # Loads a state_dict into the policy's learn mode, but excludes parameters related to + # multi-task heads and task embeddings. This is useful for fine-tuning a pre-trained model + # on a new set of tasks. + # Arguments: + # - state_dict (:obj:`Dict[str, Any]`): The dict of the policy learn state saved previously. + # """ + # # Define prefixes of parameters to exclude (e.g., multi-task heads, task embeddings). + # exclude_prefixes = [ + # '_orig_mod.world_model.head_policy_multi_task.', + # '_orig_mod.world_model.head_value_multi_task.', + # '_orig_mod.world_model.head_rewards_multi_task.', + # '_orig_mod.world_model.head_observations_multi_task.', + # '_orig_mod.world_model.task_emb.' + # ] + + # # Define specific keys to exclude if they don't fit a prefix pattern. + # exclude_keys = [ + # '_orig_mod.world_model.task_emb.weight', + # '_orig_mod.world_model.task_emb.bias', + # ] + + # def filter_state_dict(state_dict_loader: Dict[str, Any], exclude_prefixes: list, exclude_keys: list = []) -> Dict[str, Any]: + # """ + # Filters out parameters that should not be loaded. 
+ # """ + # filtered = {} + # for k, v in state_dict_loader.items(): + # if any(k.startswith(prefix) for prefix in exclude_prefixes) or k in exclude_keys: + # print(f"Excluding parameter from loading: {k}") + # continue + # filtered[k] = v + # return filtered + + # # Filter and load state_dict for the main model. + # if 'model' in state_dict: + # model_state_dict = state_dict['model'] + # filtered_model_state_dict = filter_state_dict(model_state_dict, exclude_prefixes, exclude_keys) + # missing, unexpected = self._learn_model.load_state_dict(filtered_model_state_dict, strict=False) + # if missing: + # print(f"Missing keys when loading _learn_model: {missing}") + # if unexpected: + # print(f"Unexpected keys when loading _learn_model: {unexpected}") + # else: + # print("Warning: 'model' key not found in the state_dict.") + + # # Filter and load state_dict for the target model. + # if 'target_model' in state_dict: + # target_model_state_dict = state_dict['target_model'] + # filtered_target_model_state_dict = filter_state_dict(target_model_state_dict, exclude_prefixes, exclude_keys) + # missing, unexpected = self._target_model.load_state_dict(filtered_target_model_state_dict, strict=False) + # if missing: + # print(f"Missing keys when loading _target_model: {missing}") + # if unexpected: + # print(f"Unexpected keys when loading _target_model: {unexpected}") + # else: + # print("Warning: 'target_model' key not found in the state_dict.") + + # # Load optimizer state_dict. This is often skipped during fine-tuning, but included here for completeness. + # if 'optimizer_world_model' in state_dict: + # try: + # self._optimizer_world_model.load_state_dict(state_dict['optimizer_world_model']) + # except Exception as e: + # print(f"Could not load optimizer state_dict: {e}. This may be expected during fine-tuning.") + # else: + # print("Warning: 'optimizer_world_model' key not found in the state_dict.") diff --git a/lzero/policy/scaling_transform.py b/lzero/policy/scaling_transform.py index 19a852f56..ac3305332 100644 --- a/lzero/policy/scaling_transform.py +++ b/lzero/policy/scaling_transform.py @@ -1,4 +1,5 @@ from typing import Union + import torch @@ -11,7 +12,6 @@ def __init__(self, start: float, stop: float, step: float = 1., device: Union[st assert self.size > 0, "DiscreteSupport size must be greater than 0" self.step = step - def scalar_transform(x: torch.Tensor, epsilon: float = 0.001, delta: float = 1.) -> torch.Tensor: """ Overview: @@ -107,33 +107,44 @@ def visit_count_temperature( return fixed_temperature_value + def phi_transform( discrete_support: DiscreteSupport, x: torch.Tensor, + label_smoothing_eps: float = 0.0 ) -> torch.Tensor: """ Overview: - Map a real-valued scalar to a categorical distribution over a discrete support using linear interpolation (a.k.a. “soft” one-hot). + Map a real-valued scalar to a categorical distribution over a discrete support + using linear interpolation (a.k.a. "soft" one-hot). - For each scalar value the probability mass is split between the two + For each scalar value, the probability mass is split between the two nearest support atoms so that their weighted sum equals the original - value (MuZero, Appendix F). + value (see MuZero, Appendix F). Arguments: - discrete_support : DiscreteSupport Container with the support values (must be evenly spaced). - x : torch.Tensor Input tensor of arbitrary shape ``(...,)`` containing real numbers. + - label_smoothing_eps : float + Epsilon value for label smoothing (default: 0). 
When > 0, mixes the target + distribution with a uniform distribution to: + - Prevent overconfidence and overfitting to discrete support atoms + - Improve generalization through smoother value/reward representations + - Enhance numerical stability during training + + Formula: smooth_target = (1 - ε) * target + ε / N, where N = support size. Returns: - torch.Tensor Tensor of shape ``(*x.shape, N)`` where ``N = discrete_support.size``. - The last dimension is a probability distribution (sums to 1). + The last dimension represents a probability distribution (sums to 1). Notes ----- - • No in-place ops on the input are used, improving autograd safety. - • Only one `scatter_add_` kernel is launched for efficiency. + • No in-place ops on the input are used, improving autograd safety. + • Only one `scatter_add_` kernel is launched for efficiency. """ # --- constants ---------------------------------------------------------- min_bound = discrete_support.arange[0, 0] @@ -141,20 +152,21 @@ def phi_transform( step = discrete_support.step size = discrete_support.size - # --- 1. clip to the valid range ---------------------------------------- + # --- 1. Clip to the valid range ---------------------------------------- x = x.clamp(min_bound, max_bound) - # --- 2. locate neighbouring indices ------------------------------------ - pos = (x - min_bound) / step # continuous position - low_idx_float = torch.floor(pos) # lower index - low_idx_long = low_idx_float.long() # lower index - high_idx = low_idx_long + 1 # upper index (may overflow) + # --- 2. Locate neighbouring indices ------------------------------------ + pos = (x - min_bound) / step # Continuous position relative to support + low_idx_float = torch.floor(pos) # Lower index (float) + low_idx_long = low_idx_float.long() # Lower index (long) + high_idx = low_idx_long + 1 # Upper index (may temporarily overflow) - # --- 3. linear interpolation weights ----------------------------------- - p_high = pos - low_idx_float # distance to lower atom - p_low = 1.0 - p_high # complementary mass + # --- 3. Linear interpolation weights ----------------------------------- + p_high = pos - low_idx_float # Distance to the lower atom (weight for upper) + p_low = 1.0 - p_high # Complementary mass (weight for lower) - # --- 4. stack indices / probs and scatter ------------------------------ + # --- 4. Stack indices / probs and scatter ------------------------------ + # Clamp high_idx to handle the edge case where x is exactly max_bound idx = torch.stack([low_idx_long, torch.clamp(high_idx, max=size - 1)], dim=-1) # (*x, 2) prob = torch.stack([p_low, p_high], dim=-1) # (*x, 2) @@ -163,7 +175,14 @@ def phi_transform( dtype=x.dtype, device=x.device) target.scatter_add_(-1, idx, prob) - return target + + # --- 5. 
Apply label smoothing ------------------------------------------ + if label_smoothing_eps > 0: + # Mix the original "two-hot" target with a uniform distribution + # smooth_target + return (1.0 - label_smoothing_eps) * target + (label_smoothing_eps / size) + else: + return target def cross_entropy_loss(prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor: diff --git a/lzero/policy/tests/config/atari_muzero_config_for_test.py b/lzero/policy/tests/config/atari_muzero_config_for_test.py index b5fbee948..31b1e4686 100644 --- a/lzero/policy/tests/config/atari_muzero_config_for_test.py +++ b/lzero/policy/tests/config/atari_muzero_config_for_test.py @@ -50,6 +50,8 @@ self_supervised_learning_loss=True, # default is False discrete_action_encoding_type='one_hot', norm_type='BN', + value_support_range=(-300., 301., 1.), + reward_support_range=(-300., 301., 1.), ), cuda=True, env_type='not_board_games', diff --git a/lzero/policy/tests/config/cartpole_muzero_config_for_test.py b/lzero/policy/tests/config/cartpole_muzero_config_for_test.py index 0c899608e..0c587d06a 100644 --- a/lzero/policy/tests/config/cartpole_muzero_config_for_test.py +++ b/lzero/policy/tests/config/cartpole_muzero_config_for_test.py @@ -30,12 +30,14 @@ model=dict( observation_shape=4, action_space_size=2, - model_type='mlp', + model_type='mlp', lstm_hidden_size=128, latent_state_dim=128, self_supervised_learning_loss=True, # NOTE: default is False. discrete_action_encoding_type='one_hot', - norm_type='BN', + norm_type='BN', + value_support_range=(-300., 301., 1.), + reward_support_range=(-300., 301., 1.), ), cuda=True, env_type='not_board_games', diff --git a/lzero/policy/unizero.py b/lzero/policy/unizero.py old mode 100644 new mode 100755 index 8450a2ac5..766012870 --- a/lzero/policy/unizero.py +++ b/lzero/policy/unizero.py @@ -1,23 +1,92 @@ import copy +import logging from collections import defaultdict -from typing import List, Dict, Any, Tuple, Union +from typing import Any, Dict, List, Tuple, Union import numpy as np import torch +import torch.nn.functional as F import wandb from ding.model import model_wrap from ding.utils import POLICY_REGISTRY - -from lzero.entry.utils import initialize_zeros_batch, initialize_pad_batch from lzero.mcts import UniZeroMCTSCtree as MCTSCtree from lzero.model import ImageTransforms -from lzero.policy import scalar_transform, InverseScalarTransform, phi_transform, \ - DiscreteSupport, to_torch_float_tensor, mz_network_output_unpack, select_action, prepare_obs, \ - prepare_obs_stack_for_unizero +from lzero.policy import (DiscreteSupport, InverseScalarTransform, + mz_network_output_unpack, phi_transform, prepare_obs, + prepare_obs_stack_for_unizero, scalar_transform, + select_action, to_torch_float_tensor) +from lzero.policy.head_clip_manager import (HeadClipConfig, HeadClipManager, + create_head_clip_manager_from_dict) from lzero.policy.muzero import MuZeroPolicy +from lzero.policy.utils import initialize_pad_batch +from torch.nn.utils.convert_parameters import (parameters_to_vector, + vector_to_parameters) + from .utils import configure_optimizers_nanogpt +def scale_module_weights_vectorized(module: torch.nn.Module, scale_factor: float): + """ + Efficiently scale all weights of a module using vectorized operations. + """ + if not (0.0 < scale_factor < 1.0): + return # Do nothing if the scaling factor is invalid + + # 1. Flatten all parameters of the module into a single vector + params_vec = parameters_to_vector(module.parameters()) + + # 2. 
Perform multiplication operation on this vector + params_vec.data.mul_(scale_factor) + + # 3. Copy the scaled vector back to the individual parameters of the module + vector_to_parameters(params_vec, module.parameters()) + + +def configure_optimizer_unizero(model, learning_rate, weight_decay, device_type, betas): + """ + Configure optimizer with differentiated learning rates and weight decay for encoder/backbone/head of UniZero model. + """ + # 1. Define parameters that need special handling + param_dict = {pn: p for pn, p in model.named_parameters() if p.requires_grad} + + # 2. Divide parameters into three groups: Transformer backbone, Tokenizer, and Heads + transformer_params = {pn: p for pn, p in param_dict.items() if 'transformer' in pn} + tokenizer_params = {pn: p for pn, p in param_dict.items() if 'tokenizer' in pn} + + # Head parameters are those that belong to neither transformer nor tokenizer + head_params = { + pn: p for pn, p in param_dict.items() + if 'transformer' not in pn and 'tokenizer' not in pn + } + + # 3. Set different optimizer parameters for each group (especially learning rate) + # We still use AdamW here, but with more reasonable learning rate settings + optim_groups = [ + { + 'params': list(tokenizer_params.values()), + 'lr': learning_rate, # Tokenizer uses base learning rate, e.g., 1e-4 + 'weight_decay': weight_decay + }, + { + 'params': list(transformer_params.values()), + 'lr': learning_rate, # Tokenizer uses base learning rate, e.g., 1e-4 + 'weight_decay': weight_decay + }, + { + 'params': list(head_params.values()), + 'lr': learning_rate, # Heads also use base learning rate, e.g., 1e-4 + 'weight_decay': weight_decay + + } + ] + + logging.info("--- Optimizer Groups ---") + logging.info(f"Transformer LR: {learning_rate}") + logging.info(f"Tokenizer/Heads LR: {learning_rate}") + + optimizer = torch.optim.AdamW(optim_groups, betas=betas) + return optimizer + @POLICY_REGISTRY.register('unizero') class UniZeroPolicy(MuZeroPolicy): """ @@ -65,6 +134,8 @@ class UniZeroPolicy(MuZeroPolicy): # (int) The save interval of the model. learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=10000, ), ), ), world_model_cfg=dict( + # (str) The encoder type, e.g., 'resnet' or 'vit'. + encoder_type='resnet', # (bool) If True, the action space of the environment is continuous, otherwise discrete. continuous_action_space=False, # (int) The number of tokens per block. @@ -81,8 +152,8 @@ class UniZeroPolicy(MuZeroPolicy): device='cpu', # (bool) Whether to analyze simulation normalization. analysis_sim_norm=False, - # (bool) Whether to analyze dormant ratio. - analysis_dormant_ratio=False, + # (bool) Whether to analyze dormant ratio, average_weight_magnitude of net, effective_rank of latent. + analysis_dormant_ratio_weight_rank=False, # (int) The shape of the action space. action_space_size=6, # (int) The size of the group, related to simulation normalization. @@ -139,13 +210,131 @@ class UniZeroPolicy(MuZeroPolicy): rope_theta=10000, # (int) The maximum sequence length for position encoding. max_seq_len=8192, + # (int) The rank parameter for LoRA (Low-Rank Adaptation). Set to 0 to disable LoRA. + lora_r=0, + # (float) The alpha parameter for LoRA scaling. + lora_alpha=1, + # (float) The dropout probability for LoRA layers. + lora_dropout=0.0, # Controls where to compute reconstruction loss: 'after_backbone', 'before_backbone', or None. # - after_backbone: The reconstruction loss is computed after the encoded representation passes through the backbone. 
- # - before_backbone: The reconstruction loss is computed directly on the encoded representation, without the backbone. + # - before_backbone: The reconstruction loss is computed directly on the encoded representation, without the backbone. decode_loss_mode=None, + # (str/None) Task embedding option. Set to None to disable task-specific embeddings. Options are ['concat_task_embed', 'add_task_embed', 'register_task_embed']. + # Please note that "register_task_embed" has not yet been fully tested. + task_embed_option=None, + # (bool) Whether to use task embeddings. + use_task_embed=False, + # TODO: optimize the following configs. + # (bool) Whether to use normal head (standard prediction heads). + use_normal_head=True, + # (bool) Whether to use Soft Mixture-of-Experts (MoE) head. + use_softmoe_head=False, + # (bool) Whether to use Mixture-of-Experts (MoE) head. + use_moe_head=False, + # (int) Number of experts in the MoE head. + num_experts_in_moe_head=4, + # (bool) Whether to use MoE in the transformer layers. + moe_in_transformer=False, + # (bool) Whether to use multiplicative MoE in the transformer layers. + multiplication_moe_in_transformer=False, + # (int) Number of shared experts in MoE. + n_shared_experts=1, + # (int) Number of experts to use per token in MoE. + num_experts_per_tok=1, + # (int) Total number of experts in the transformer MoE. + num_experts_of_moe_in_transformer=8, + # ****** Priority ****** + # (bool) Whether to use priority when sampling training data from the buffer. + use_priority=False, ), ), # ****** common ****** + # (bool) Whether to enable adaptive policy entropy weight (alpha) + use_adaptive_entropy_weight=True, + # (float) Learning rate for adaptive alpha optimizer + adaptive_entropy_alpha_lr=1e-3, + # (float) Target entropy ratio at the start of training (higher = more exploration) + target_entropy_start_ratio=0.98, + # (float) Target entropy ratio at the end of training (lower = more exploitation) + target_entropy_end_ratio=0.05, + # (int) Number of training steps to decay target entropy from start to end ratio + target_entropy_decay_steps=500000, + + # ==================== START: Encoder-Clip Annealing Config ==================== + # (bool) Whether to enable annealing for encoder-clip values. + use_encoder_clip_annealing=True, + # (str) Annealing type. Options: 'linear' or 'cosine'. + encoder_clip_anneal_type='cosine', + # (float) Starting clip value for annealing (looser in early training). + encoder_clip_start_value=30.0, + # (float) Ending clip value for annealing (stricter in later training). + encoder_clip_end_value=10.0, + # (int) Training iteration steps required to complete annealing from start to end value. 
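+        # With the 'cosine' schedule the clip threshold follows
+        #     clip(t) = end + (start - end) * 0.5 * (1 + cos(pi * min(t / anneal_steps, 1))),
+        # while the 'linear' option interpolates start -> end directly (see _forward_learn).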
+ encoder_clip_anneal_steps=100000, # e.g., reach final value after 100k iterations + # (float) Fixed latent norm clip threshold (used when encoder_clip_annealing is disabled) + latent_norm_clip_threshold=20.0, + # ===================== END: Encoder-Clip Annealing Config ===================== + + # ==================== START: Head-Clip Annealing Config ==================== + # NOTE: The usage and implementation of Head-Clip may need to be optimized + # (bool) Whether to enable head-clip (dynamically clip head output range) + use_head_clip=False, # Disabled by default + # Detailed Head-Clip configuration + head_clip_config=dict( + enabled=False, + # Specify heads that need clipping (optional, defaults to empty list) + enabled_heads=[], # Example: ['policy', 'value', 'rewards'] + # Detailed configuration for each head (optional) + head_configs={ + # 'policy': { + # 'use_annealing': True, + # 'anneal_type': 'cosine', # 'cosine' or 'linear' + # 'start_value': 30.0, # Loose in early phase + # 'end_value': 10.0, # Strict in later phase + # 'anneal_steps': 500000, + # }, + # 'value': { + # 'clip_threshold': 20.0, + # 'use_annealing': False, + # }, + }, + # Monitoring configuration + monitor_freq=1, # Check every iteration + log_freq=1000, # Print log every 1000 iterations + ), + # ===================== END: Head-Clip Annealing Config ===================== + + # ==================== START: Policy Label Smoothing Config ==================== + # (float) Starting epsilon value for policy label smoothing (higher = more smoothing) + policy_ls_eps_start=0.05, + # (float) Ending epsilon value for policy label smoothing (lower = less smoothing) + policy_ls_eps_end=0.01, + # (int) Number of training steps to decay label smoothing epsilon from start to end + policy_ls_eps_decay_steps=50000, + + label_smoothing_eps=0.1, # TODO: For value + + # (bool) Whether to use continuous (fixed) label smoothing throughout training + use_continuous_label_smoothing=False, + # (float) Fixed epsilon value for continuous label smoothing (only used when use_continuous_label_smoothing=True) + continuous_ls_eps=0.05, + # ===================== END: Policy Label Smoothing Config ===================== + + # ==================== START: Learning Rate Scheduler Config ==================== + # (int) Total training iterations for cosine annealing LR scheduler (only used when cos_lr_scheduler=True) + total_iterations=500000, + # (float) Final learning rate for cosine annealing LR scheduler (only used when cos_lr_scheduler=True) + final_learning_rate=4e-5, + # ===================== END: Learning Rate Scheduler Config ===================== + + # ==================== START: Monitoring Config ==================== + # (int) Frequency of monitoring model parameter and gradient norms (in training iterations). Set to 0 to disable. + monitor_norm_freq=5000, + # (bool) Whether to enable enhanced policy monitoring (logits statistics, target policy entropy, etc.) + use_enhanced_policy_monitoring=False, + # ===================== END: Monitoring Config ===================== + # (bool) whether to use rnd model. use_rnd_model=False, # (bool) Whether to use multi-gpu training. @@ -178,7 +367,7 @@ class UniZeroPolicy(MuZeroPolicy): # (bool) Whether to use the pure policy to collect data. collect_with_pure_policy=False, # (int) The evaluation frequency. - eval_freq=int(2e3), + eval_freq=int(5e3), # (str) The sample type. Options are ['episode', 'transition']. 
sample_type='transition',
         # ****** observation ******
@@ -227,8 +416,12 @@
         n_episode=8,
         # (int) The number of num_segments in each collecting stage when use muzero_segment_collector.
         num_segments=8,
-        # (int) the number of simulations in MCTS.
+        # (int) the number of simulations in MCTS for reanalyze.
         num_simulations=50,
+        # (int) The number of simulations in MCTS for the collect phase.
+        collect_num_simulations=25,
+        # (int) The number of simulations in MCTS for the eval phase.
+        eval_num_simulations=50,
         # (float) Discount factor (gamma) for returns.
         discount_factor=0.997,
         # (int) The number of steps for calculating target q_value.
@@ -273,6 +466,8 @@
         priority_prob_beta=0.4,
         # (int) The initial Env Steps for training.
         train_start_after_envsteps=int(0),
+        # (bool) Whether to use task_exploitation_weight.
+        use_task_exploitation_weight=False,

         # ****** UCB ******
         # (float) The alpha value used in the Dirichlet distribution for exploration at the root node of search tree.
@@ -283,7 +478,7 @@
         # ****** Explore by random collect ******
         # (int) The number of episodes to collect data randomly before training.
         random_collect_episode_num=0,
-
+        
         # ****** Explore by eps greedy ******
         eps=dict(
            # (bool) Whether to use eps greedy exploration in collecting data.
@@ -313,24 +508,139 @@ def default_model(self) -> Tuple[str, List[str]]:
         """
         return 'UniZeroModel', ['lzero.model.unizero_model']

+
+    # ==================== Model Norm Monitoring Function ====================
+    def _monitor_model_norms(self) -> Dict[str, float]:
+        """
+        Overview:
+            Calculate and return parameter matrix norms for key model components (Encoder, Transformer, Heads).
+            This function should be called within a torch.no_grad() context for efficiency.
+        Returns:
+            - norm_metrics (:obj:`Dict[str, float]`): Dictionary containing all norm metrics for logging.
+        """
+        world_model = self._learn_model.world_model
+        norm_metrics = {}
+
+        # Define module groups to monitor
+        module_groups = {
+            'encoder': world_model.tokenizer.encoder,
+            'transformer': world_model.transformer,
+            'head_value': world_model.head_value,
+            'head_reward': world_model.head_rewards,
+            'head_policy': world_model.head_policy,
+        }
+
+        for group_name, group_module in module_groups.items():
+            total_norm_sq = 0.0
+            for param_name, param in group_module.named_parameters():
+                if param.requires_grad:
+                    # Calculate L2 norm for single layer parameters
+                    param_norm = param.data.norm(2).item()
+                    # Replace dots to display correctly as hierarchy in TensorBoard
+                    log_name = f'norm/{group_name}/{param_name.replace(".", "/")}'
+                    norm_metrics[log_name] = param_norm
+                    total_norm_sq += param_norm ** 2
+
+            # Calculate total norm for entire module
+            total_group_norm = np.sqrt(total_norm_sq)
+            norm_metrics[f'norm/{group_name}/_total_norm'] = total_group_norm
+
+        return norm_metrics
+
+    def _monitor_gradient_norms(self) -> Dict[str, float]:
+        """
+        Overview:
+            Calculate and return gradient norms for key model components.
+            This function should be called after gradient computation and before parameter updates.
+        Returns:
+            - grad_metrics (:obj:`Dict[str, float]`): Dictionary containing all gradient norm metrics for logging.
+ """ + world_model = self._learn_model.world_model + grad_metrics = {} + + # Define module groups to monitor + module_groups = { + 'encoder': world_model.tokenizer.encoder, + 'transformer': world_model.transformer, + 'head_value': world_model.head_value, + 'head_reward': world_model.head_rewards, + 'head_policy': world_model.head_policy, + } + + for group_name, group_module in module_groups.items(): + total_grad_norm_sq = 0.0 + num_params_with_grad = 0 + + for param_name, param in group_module.named_parameters(): + if param.requires_grad and param.grad is not None: + # Calculate L2 norm for single layer parameter gradients + grad_norm = param.grad.data.norm(2).item() + # Replace dots to display correctly as hierarchy in TensorBoard + log_name = f'grad/{group_name}/{param_name.replace(".", "/")}' + grad_metrics[log_name] = grad_norm + total_grad_norm_sq += grad_norm ** 2 + num_params_with_grad += 1 + + # Calculate total gradient norm for entire module + if num_params_with_grad > 0: + total_group_grad_norm = np.sqrt(total_grad_norm_sq) + grad_metrics[f'grad/{group_name}/_total_norm'] = total_group_grad_norm + else: + grad_metrics[f'grad/{group_name}/_total_norm'] = 0.0 + + return grad_metrics + # ================================================================= + def _init_learn(self) -> None: """ Overview: Learn mode init method. Called by ``self.__init__``. Initialize the learn model, optimizer and MCTS utils. """ - # NOTE: nanoGPT optimizer - self._optimizer_world_model = configure_optimizers_nanogpt( - model=self._model.world_model, - learning_rate=self._cfg.learning_rate, - weight_decay=self._cfg.weight_decay, - device_type=self._cfg.device, - betas=(0.9, 0.95), - ) + if self._cfg.optim_type == 'SGD': + # Configure SGD optimizer + self._optimizer_world_model = torch.optim.SGD( + self._model.world_model.parameters(), + lr=self._cfg.learning_rate, + momentum=self._cfg.momentum, + weight_decay=self._cfg.weight_decay + ) + elif self._cfg.optim_type == 'AdamW': + # NOTE: nanoGPT optimizer + self._optimizer_world_model = configure_optimizers_nanogpt( + model=self._model.world_model, + learning_rate=self._cfg.learning_rate, + weight_decay=self._cfg.weight_decay, + device_type=self._cfg.device, + betas=(0.9, 0.95), + ) + elif self._cfg.optim_type == 'AdamW_mix_lr_wdecay': + self._optimizer_world_model = configure_optimizer_unizero( + model=self._model.world_model, + learning_rate=self._cfg.learning_rate, + weight_decay=self._cfg.weight_decay, + device_type=self._cfg.device, + betas=(0.9, 0.95), + ) if self._cfg.cos_lr_scheduler: from torch.optim.lr_scheduler import CosineAnnealingLR - # TODO: check the total training steps - self.lr_scheduler = CosineAnnealingLR(self._optimizer_world_model, 1e5, eta_min=0, last_epoch=-1) + total_iters = self._cfg.total_iterations + final_lr = self._cfg.final_learning_rate + + self.lr_scheduler = CosineAnnealingLR( + self._optimizer_world_model, + T_max=total_iters, + eta_min=final_lr + ) + logging.info(f"CosineAnnealingLR enabled: T_max={total_iters}, eta_min={final_lr}") + + + if self._cfg.piecewise_decay_lr_scheduler: + from torch.optim.lr_scheduler import LambdaLR + max_step = self._cfg.threshold_training_steps_for_final_lr + # NOTE: the 1, 0.1, 0.01 is the decay rate, not the lr. 
+ lr_lambda = lambda step: 1 if step < max_step * 0.5 else (0.1 if step < max_step else 0.01) # noqa + self.lr_scheduler = LambdaLR(self._optimizer_world_model, lr_lambda=lr_lambda) # use model_wrapper for specialized demands of different modes self._target_model = copy.deepcopy(self._model) @@ -364,19 +674,113 @@ def _init_learn(self) -> None: self.grad_norm_after = 0. if self._cfg.model.model_type == 'conv': + # for image-input env self.pad_token_id = -1 else: + # for text-input env and vector-input env + # Retrieve the tokenizer from the encoder module if it exists encoder_tokenizer = getattr(self._model.tokenizer.encoder, 'tokenizer', None) + + # Extract the padding token ID from the tokenizer if available, otherwise use 0 as default. Used in _reset_collect() + # The pad_token_id is used to identify padding tokens in sequences, which is essential for: + # 1. Masking padded positions during attention computation to prevent them from affecting the output + # 2. Properly handling variable-length sequences in batch processing + # 3. Distinguishing between actual tokens and padding in loss calculation + # Default value 0 is a common convention when no specific padding token is defined self.pad_token_id = encoder_tokenizer.pad_token_id if encoder_tokenizer is not None else 0 - if self._cfg.use_wandb: # TODO: add the model to wandb wandb.watch(self._learn_model.representation_network, log="all") self.accumulation_steps = self._cfg.accumulation_steps - # @profile + # ==================== START: Target Entropy Regularization Initialization ==================== + # Read whether to enable adaptive alpha from config, and provide a default value + self.use_adaptive_entropy_weight = self._cfg.use_adaptive_entropy_weight + + # Add configuration in _init_learn + self.target_entropy_start_ratio = self._cfg.target_entropy_start_ratio + self.target_entropy_end_ratio = self._cfg.target_entropy_end_ratio + self.target_entropy_decay_steps = self._cfg.target_entropy_decay_steps # e.g., complete annealing within 200k steps (2M envsteps) + + if self.use_adaptive_entropy_weight: + # 1. Set target entropy. For discrete action spaces, a common heuristic is the negative logarithm + # of action space dimension multiplied by a coefficient. + # This coefficient (e.g., 0.98) can be used as a hyperparameter. + action_space_size = self._cfg.model.action_space_size + self.target_entropy = -np.log(1.0 / action_space_size) * 0.98 + + # 2. Initialize a learnable log_alpha parameter. + # Initialized to 0, meaning initial alpha = exp(0) = 1.0. + self.log_alpha = torch.nn.Parameter(torch.zeros(1, device=self._cfg.device), requires_grad=True) + + # 3. Create a dedicated optimizer for log_alpha. + # Using a smaller learning rate (e.g., 1e-4) different from the main optimizer is usually more stable. 
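+            # Sketch of the update applied later in _forward_learn:
+            #   alpha_loss = log_alpha * (policy_entropy.detach() - target_entropy)
+            # so alpha = exp(log_alpha) grows when the policy entropy falls below the target
+            # (strengthening the entropy bonus) and shrinks when the policy is more random than required.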
+ alpha_lr = self._cfg.adaptive_entropy_alpha_lr + self.alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=alpha_lr) + + logging.info("="*20) + logging.info(">>> Target Entropy Regularization (Adaptive Alpha) Enabled <<<") + logging.info(f" Target Entropy: {self.target_entropy:.4f}") + logging.info(f" Alpha Optimizer Learning Rate: {alpha_lr:.2e}") + logging.info("="*20) + # ===================== END: Target Entropy Regularization Initialization ===================== + + # ==================== START: Initialize Encoder-Clip Annealing Parameters ==================== + self.use_encoder_clip_annealing = self._cfg.use_encoder_clip_annealing + self.latent_norm_clip_threshold = self._cfg.latent_norm_clip_threshold # TODO + if self.use_encoder_clip_annealing: + self.encoder_clip_anneal_type = self._cfg.encoder_clip_anneal_type + self.encoder_clip_start = self._cfg.encoder_clip_start_value + self.encoder_clip_end = self._cfg.encoder_clip_end_value + self.encoder_clip_anneal_steps = self._cfg.encoder_clip_anneal_steps + + logging.info("="*20) + logging.info(">>> Encoder-Clip Annealing Enabled <<<") + logging.info(f" Type: {self.encoder_clip_anneal_type}") + logging.info(f" Range: {self.encoder_clip_start} -> {self.encoder_clip_end}") + logging.info(f" Steps: {self.encoder_clip_anneal_steps}") + logging.info("="*20) + else: + # If annealing is not enabled, use a fixed clip threshold + self.latent_norm_clip_threshold = self._cfg.latent_norm_clip_threshold + # ===================== END: Initialize Encoder-Clip Annealing Parameters ===================== + + # ==================== START: Initialize Head-Clip Manager ==================== + self.use_head_clip = self._cfg.use_head_clip + + if self.use_head_clip: + head_clip_config_dict = self._cfg.head_clip_config + # Ensure enabled is consistent with top-level configuration + head_clip_config_dict['enabled'] = self.use_head_clip + + # Create HeadClipManager + self.head_clip_manager = create_head_clip_manager_from_dict(head_clip_config_dict) + + logging.info("=" * 60) + logging.info(">>> Head-Clip Manager Initialized <<<") + logging.info(f" Enabled heads: {self.head_clip_manager.enabled_heads}") + for head_name in self.head_clip_manager.enabled_heads: + config = self.head_clip_manager.get_head_config(head_name) + if config.use_annealing: + logging.info( + f" {head_name}: annealing {config.start_value:.1f} → {config.end_value:.1f} " + f"over {config.anneal_steps} steps ({config.anneal_type})" + ) + else: + logging.info(f" {head_name}: fixed threshold = {config.clip_threshold:.1f}") + logging.info("=" * 60) + else: + self.head_clip_manager = None + # ===================== END: Initialize Head-Clip Manager ===================== + + # Policy Label Smoothing Parameters + self.policy_ls_eps_start = self._cfg.policy_ls_eps_start + self.policy_ls_eps_end = self._cfg.policy_ls_eps_end + self.policy_ls_eps_decay_steps = self._cfg.policy_ls_eps_decay_steps + logging.info(f"self.policy_ls_eps_start: {self.policy_ls_eps_start}") + def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]: """ Overview: @@ -397,11 +801,26 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in obs_batch_ori, action_batch, target_action_batch, mask_batch, indices, weights, make_time, timestep_batch = current_batch target_reward, target_value, target_policy = target_batch + # Calculate current epsilon for policy label smoothing + # ==================== Continuous Label Smoothing ==================== + 
use_continuous_label_smoothing = self._cfg.use_continuous_label_smoothing + if use_continuous_label_smoothing: + # Use fixed high epsilon throughout training + current_policy_label_eps = self._cfg.continuous_ls_eps + else: + # Use original decay schedule + if self.policy_ls_eps_start > 0: + progress = min(1.0, train_iter / self.policy_ls_eps_decay_steps) + current_policy_label_eps = self.policy_ls_eps_start * (1 - progress) + self.policy_ls_eps_end * progress + else: + current_policy_label_eps = 0.0 + # ================================================================================ + # Prepare observations based on frame stack number if self._cfg.model.frame_stack_num > 1: obs_batch, obs_target_batch = prepare_obs_stack_for_unizero(obs_batch_ori, self._cfg) else: - obs_batch, obs_target_batch = prepare_obs(obs_batch_ori, self._cfg) # TODO: optimize + obs_batch, obs_target_batch = prepare_obs(obs_batch_ori, self._cfg) # Apply augmentations if needed if self._cfg.use_augmentation: @@ -425,8 +844,8 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in transformed_target_value = scalar_transform(target_value) # Convert to categorical distributions - target_reward_categorical = phi_transform(self.reward_support, transformed_target_reward) - target_value_categorical = phi_transform(self.value_support, transformed_target_value) + target_reward_categorical = phi_transform(self.reward_support, transformed_target_reward, label_smoothing_eps= self._cfg.label_smoothing_eps) + target_value_categorical = phi_transform(self.value_support, transformed_target_value, label_smoothing_eps=self._cfg.label_smoothing_eps) # Prepare batch for GPT model batch_for_gpt = {} @@ -447,7 +866,19 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in batch_for_gpt['ends'] = torch.zeros(batch_for_gpt['mask_padding'].shape, dtype=torch.long, device=self._cfg.device) batch_for_gpt['target_value'] = target_value_categorical[:, :-1] - batch_for_gpt['target_policy'] = target_policy[:, :-1] + + # ==================== Apply Policy Label Smoothing ==================== + # This was previously computed but never applied. Now we actually smooth the target_policy. + smoothed_target_policy = target_policy[:, :-1] + if current_policy_label_eps > 0: + num_actions = smoothed_target_policy.shape[-1] + uniform_dist = torch.ones_like(smoothed_target_policy) / num_actions + smoothed_target_policy = (1.0 - current_policy_label_eps) * smoothed_target_policy + \ + current_policy_label_eps * uniform_dist + batch_for_gpt['target_policy'] = smoothed_target_policy + # =================================================================================== + + batch_for_gpt['scalar_target_value'] = target_value # Extract valid target policy data and compute entropy valid_target_policy = batch_for_gpt['target_policy'][batch_for_gpt['mask_padding']] @@ -456,13 +887,114 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in # Update world model losses = self._learn_model.world_model.compute_loss( - batch_for_gpt, self._target_model.world_model.tokenizer, self.value_inverse_scalar_transform_handle - ) # NOTE : compute_loss third argument is now a dead argument. If this changes, it could need adaptation between value_inverse and reward_inverse. 
+ batch_for_gpt, self._target_model.world_model.tokenizer, self.value_inverse_scalar_transform_handle, global_step=train_iter, current_policy_label_eps=current_policy_label_eps, + ) + + # ==================== Integrate norm monitoring logic ==================== + norm_log_dict = {} + # Check if monitoring frequency is reached + if self._cfg.monitor_norm_freq > 0 and (train_iter == 0 or (train_iter % self._cfg.monitor_norm_freq == 0)): + with torch.no_grad(): + # 1. Monitor model parameter norms + param_norm_metrics = self._monitor_model_norms() + norm_log_dict.update(param_norm_metrics) + + # 2. Monitor intermediate tensor x (Transformer output) + intermediate_x = losses.intermediate_losses.get('intermediate_tensor_x') + if intermediate_x is not None: + # x shape is (B, T, E) + # Calculate L2 norm for each token + token_norms = intermediate_x.norm(p=2, dim=-1) + + # Record statistics of these norms + norm_log_dict['norm/x_token/mean'] = token_norms.mean().item() + norm_log_dict['norm/x_token/std'] = token_norms.std().item() + norm_log_dict['norm/x_token/max'] = token_norms.max().item() + norm_log_dict['norm/x_token/min'] = token_norms.min().item() + + # 3. Monitor detailed statistics of logits (Value, Policy, Reward) + logits_value = losses.intermediate_losses.get('logits_value') + if logits_value is not None: + norm_log_dict['logits/value/mean'] = logits_value.mean().item() + norm_log_dict['logits/value/std'] = logits_value.std().item() + norm_log_dict['logits/value/max'] = logits_value.max().item() + norm_log_dict['logits/value/min'] = logits_value.min().item() + norm_log_dict['logits/value/abs_max'] = logits_value.abs().max().item() + + logits_policy = losses.intermediate_losses.get('logits_policy') + if logits_policy is not None: + norm_log_dict['logits/policy/mean'] = logits_policy.mean().item() + norm_log_dict['logits/policy/std'] = logits_policy.std().item() + norm_log_dict['logits/policy/max'] = logits_policy.max().item() + norm_log_dict['logits/policy/min'] = logits_policy.min().item() + norm_log_dict['logits/policy/abs_max'] = logits_policy.abs().max().item() + + logits_reward = losses.intermediate_losses.get('logits_reward') + if logits_reward is not None: + norm_log_dict['logits/reward/mean'] = logits_reward.mean().item() + norm_log_dict['logits/reward/std'] = logits_reward.std().item() + norm_log_dict['logits/reward/max'] = logits_reward.max().item() + norm_log_dict['logits/reward/min'] = logits_reward.min().item() + norm_log_dict['logits/reward/abs_max'] = logits_reward.abs().max().item() + + # 4. Monitor obs_embeddings (Encoder output) statistics + obs_embeddings = losses.intermediate_losses.get('obs_embeddings') + if obs_embeddings is not None: + # Calculate L2 norm for each embedding + emb_norms = obs_embeddings.norm(p=2, dim=-1) + norm_log_dict['embeddings/obs/norm_mean'] = emb_norms.mean().item() + norm_log_dict['embeddings/obs/norm_std'] = emb_norms.std().item() + norm_log_dict['embeddings/obs/norm_max'] = emb_norms.max().item() + norm_log_dict['embeddings/obs/norm_min'] = emb_norms.min().item() + + # ==================== Early Warning System ==================== + # Detect potential training instability and issue warnings + warnings_issued = [] + + # Check 1: Policy logits explosion (should be caught by clip, but warn anyway) + if 'logits/policy/abs_max' in norm_log_dict: + policy_abs_max = norm_log_dict['logits/policy/abs_max'] + if policy_abs_max > 8.0: + warnings_issued.append(f"⚠️ CRITICAL: Policy logits explosion detected! 
abs_max={policy_abs_max:.2f} (threshold: 8.0)") + elif policy_abs_max > 5.0: + warnings_issued.append(f"⚠️ WARNING: Policy logits getting large! abs_max={policy_abs_max:.2f} (threshold: 5.0)") + + # Check 2: Embedding norm explosion + if 'embeddings/obs/norm_std' in norm_log_dict: + emb_norm_std = norm_log_dict['embeddings/obs/norm_std'] + if emb_norm_std > 10.0: + warnings_issued.append(f"⚠️ CRITICAL: Embedding norm std explosion! std={emb_norm_std:.2f} (threshold: 10.0)") + elif emb_norm_std > 5.0: + warnings_issued.append(f"⚠️ WARNING: Embedding norm std increasing! std={emb_norm_std:.2f} (threshold: 5.0)") + + # Check 3: X token norm collapse + if 'norm/x_token/std' in norm_log_dict: + x_token_std = norm_log_dict['norm/x_token/std'] + if x_token_std < 0.1: + warnings_issued.append(f"⚠️ CRITICAL: X token norm collapse! std={x_token_std:.4f} (threshold: 0.1)") + elif x_token_std < 0.5: + warnings_issued.append(f"⚠️ WARNING: X token norm decreasing! std={x_token_std:.4f} (threshold: 0.5)") + + # Log warnings if any + if warnings_issued: + logging.warning(f"\n{'='*80}\n[TRAINING STABILITY] Iteration {train_iter}:\n" + "\n".join(warnings_issued) + f"\n{'='*80}") + norm_log_dict['stability/warning_count'] = float(len(warnings_issued)) + else: + norm_log_dict['stability/warning_count'] = 0.0 + # ==================================================================== + # ================================================================= + + # Extract the calculated value_priority from the returned losses. + value_priority_tensor = losses.intermediate_losses['value_priority'] + # Convert to numpy array for the replay buffer, adding a small epsilon. + value_priority_np = value_priority_tensor.detach().cpu().numpy() + 1e-6 + + weighted_total_loss = (weights * losses.loss_total).mean() - weighted_total_loss = losses.loss_total for loss_name, loss_value in losses.intermediate_losses.items(): self.intermediate_losses[f"{loss_name}"] = loss_value + # Extract losses from intermediate_losses dictionary obs_loss = self.intermediate_losses['loss_obs'] reward_loss = self.intermediate_losses['loss_rewards'] policy_loss = self.intermediate_losses['loss_policy'] @@ -475,8 +1007,19 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in middle_step_losses = self.intermediate_losses['middle_step_losses'] last_step_losses = self.intermediate_losses['last_step_losses'] dormant_ratio_encoder = self.intermediate_losses['dormant_ratio_encoder'] - dormant_ratio_world_model = self.intermediate_losses['dormant_ratio_world_model'] + dormant_ratio_transformer = self.intermediate_losses['dormant_ratio_transformer'] + dormant_ratio_head = self.intermediate_losses['dormant_ratio_head'] + avg_weight_mag_encoder = self.intermediate_losses['avg_weight_mag_encoder'] + avg_weight_mag_transformer = self.intermediate_losses['avg_weight_mag_transformer'] + avg_weight_mag_head = self.intermediate_losses['avg_weight_mag_head'] + e_rank_last_linear = self.intermediate_losses['e_rank_last_linear'] + e_rank_sim_norm = self.intermediate_losses['e_rank_sim_norm'] latent_state_l2_norms = self.intermediate_losses['latent_state_l2_norms'] + latent_action_l2_norms = self.intermediate_losses['latent_action_l2_norms'] + + temperature_value=self.intermediate_losses['temperature_value'] + temperature_reward=self.intermediate_losses['temperature_reward'] + temperature_policy=self.intermediate_losses['temperature_policy'] assert not torch.isnan(losses.loss_total).any(), "Loss contains NaN values" assert not 
torch.isinf(losses.loss_total).any(), "Loss contains Inf values" @@ -486,19 +1029,129 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in if (train_iter % self.accumulation_steps) == 0: self._optimizer_world_model.zero_grad() + + # ==================== START: Target Entropy Regularization Update Logic ==================== + alpha_loss = None + current_alpha = self._cfg.model.world_model_cfg.policy_entropy_weight # Default to fixed value + if self.use_adaptive_entropy_weight: + # Dynamically calculate target entropy (this logic is correct and preserved) + progress = min(1.0, train_iter / self.target_entropy_decay_steps) + current_ratio = self.target_entropy_start_ratio * (1 - progress) + self.target_entropy_end_ratio * progress + action_space_size = self._cfg.model.action_space_size + # Note: We define target_entropy as a positive number, which is more intuitive + current_target_entropy = -np.log(1.0 / action_space_size) * current_ratio + + # Calculate alpha_loss (corrected sign) + # This is the core correction: removed the negative sign at the front + # detach() is still critical to ensure alpha_loss gradient only flows to log_alpha + alpha_loss = (self.log_alpha * (policy_entropy.detach() - current_target_entropy)).mean() + + # Update log_alpha + self.alpha_optimizer.zero_grad() + alpha_loss.backward() + self.alpha_optimizer.step() + # [Optimization suggestion] Add log_alpha clipping as a safety measure + with torch.no_grad(): + # Limit alpha to a range, e.g., [1e-4, 10.0] + self.log_alpha.clamp_(np.log(5e-2), np.log(10.0)) + + # Use current updated alpha (with gradient flow truncated) + current_alpha = self.log_alpha.exp().detach() + + # Recalculate weighted policy loss and total loss + # Note: policy_entropy here is already an average value of a batch + weighted_policy_loss = orig_policy_loss - current_alpha * policy_entropy + # Rebuild total loss (not using losses.loss_total) + # Ensure the weights here are consistent with the calculation in LossWithIntermediateLosses class + self.obs_loss_weight = 2 + self.value_loss_weight = 0.5 + self.reward_loss_weight = 1. + self.policy_loss_weight = 1. + self.ends_loss_weight = 0. + + self.latent_recon_loss_weight = self._cfg.model.world_model_cfg.latent_recon_loss_weight + self.perceptual_loss_weight = self._cfg.model.world_model_cfg.perceptual_loss_weight + + if self.latent_recon_loss_weight>0: + total_loss = ( + self.reward_loss_weight * reward_loss + + self.value_loss_weight * value_loss + + self.policy_loss_weight * weighted_policy_loss + + self.obs_loss_weight * obs_loss + + self.latent_recon_loss_weight * latent_recon_loss+ + self.perceptual_loss_weight*perceptual_loss + ) + else: + + total_loss = ( + self.reward_loss_weight * reward_loss + + self.value_loss_weight * value_loss + + self.policy_loss_weight * weighted_policy_loss + + self.obs_loss_weight * obs_loss + + ) + weighted_total_loss = (weights * total_loss).mean() + # ===================== END: Target Entropy Regularization Update Logic ===================== + # Scale the loss by the number of accumulation steps weighted_total_loss = weighted_total_loss / self.accumulation_steps weighted_total_loss.backward() + # Still executed within torch.no_grad() context + # ================================================================= + with torch.no_grad(): + # 1. 
Encoder-Clip + # ==================== START: Dynamically calculate current Clip threshold ==================== + current_clip_value = self.latent_norm_clip_threshold # Default to fixed value + if self.use_encoder_clip_annealing: + progress = min(1.0, train_iter / self.encoder_clip_anneal_steps) + + if self.encoder_clip_anneal_type == 'cosine': + # Cosine schedule: smoothly transition from 1 to 0 + cosine_progress = 0.5 * (1.0 + np.cos(np.pi * progress)) + current_clip_value = self.encoder_clip_end + \ + (self.encoder_clip_start - self.encoder_clip_end) * cosine_progress + else: # Default to linear schedule + current_clip_value = self.encoder_clip_start * (1 - progress) + \ + self.encoder_clip_end * progress + # ===================== END: Dynamically calculate current Clip threshold ===================== + + # 1. Encoder-Clip (using dynamically calculated current_clip_value) + if current_clip_value > 0 and 'obs_embeddings' in losses.intermediate_losses: + obs_embeddings = losses.intermediate_losses['obs_embeddings'] + if obs_embeddings is not None: + max_latent_norm = obs_embeddings.norm(p=2, dim=-1).max() + if max_latent_norm > current_clip_value: + scale_factor = current_clip_value / max_latent_norm.item() + # No longer print frequently, or can be changed to print every N steps + if train_iter % 1000 == 0: + logging.info(f"[Encoder-Clip Annealing] Iter {train_iter}: Max latent norm {max_latent_norm.item():.2f} > {current_clip_value:.2f}. Scaling by {scale_factor:.4f}.") + scale_module_weights_vectorized(self._model.world_model.tokenizer.encoder, scale_factor) + + if self.use_head_clip and self.head_clip_manager is not None: + head_clip_results = self.head_clip_manager.apply_head_clip( + self._learn_model.world_model, + losses, + train_iter + ) + + # Check if the current iteration completes an accumulation cycle if (train_iter + 1) % self.accumulation_steps == 0: + # ==================== [NEW] Monitor gradient norms ==================== + # Monitor gradient norms before gradient clipping to diagnose gradient explosion/vanishing issues + if self._cfg.monitor_norm_freq > 0 and (train_iter == 0 or (train_iter % self._cfg.monitor_norm_freq == 0)): + grad_norm_metrics = self._monitor_gradient_norms() + norm_log_dict.update(grad_norm_metrics) + # ================================================================= + # Analyze gradient norms if simulation normalization analysis is enabled if self._cfg.analysis_sim_norm: # Clear previous analysis results to prevent memory overflow del self.l2_norm_before, self.l2_norm_after, self.grad_norm_before, self.grad_norm_after self.l2_norm_before, self.l2_norm_after, self.grad_norm_before, self.grad_norm_after = self._learn_model.encoder_hook.analyze() self._target_model.encoder_hook.clear_data() - + # Clip gradients to prevent exploding gradients total_grad_norm_before_clip_wm = torch.nn.utils.clip_grad_norm_( self._learn_model.world_model.parameters(), self._cfg.grad_clip_value @@ -565,21 +1218,91 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in 'target_policy_entropy': average_target_policy_entropy.item(), 'reward_loss': reward_loss.item(), 'value_loss': value_loss.item(), - # 'value_priority_orig': np.zeros(self._cfg.batch_size), # TODO + # Add value_priority to the log dictionary. 
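The encoder-clip threshold applied above is annealed from a loose starting value to a stricter final value, following either a cosine or a linear schedule over `encoder_clip_anneal_steps` iterations. A minimal standalone sketch of that schedule; the function name and default values are illustrative only:

```python
import numpy as np

def annealed_clip_value(train_iter: int, start: float = 30.0, end: float = 10.0,
                        anneal_steps: int = 100000, anneal_type: str = 'cosine') -> float:
    """Illustrative restatement of the clip-threshold schedule used above."""
    progress = min(1.0, train_iter / anneal_steps)
    if anneal_type == 'cosine':
        # Cosine factor decays smoothly from 1 to 0 as progress goes from 0 to 1.
        cosine_progress = 0.5 * (1.0 + np.cos(np.pi * progress))
        return end + (start - end) * cosine_progress
    # Linear schedule: straight interpolation from start to end.
    return start * (1 - progress) + end * progress

# Early in training the threshold is loose, later it tightens.
print(annealed_clip_value(0))        # -> 30.0
print(annealed_clip_value(100000))   # -> 10.0
```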
+ 'value_priority': value_priority_np.mean().item(), + 'value_priority_orig': value_priority_np, 'target_reward': target_reward.mean().item(), 'target_value': target_value.mean().item(), 'transformed_target_reward': transformed_target_reward.mean().item(), 'transformed_target_value': transformed_target_value.mean().item(), 'total_grad_norm_before_clip_wm': total_grad_norm_before_clip_wm.item(), - 'analysis/dormant_ratio_encoder': dormant_ratio_encoder.item(), - 'analysis/dormant_ratio_world_model': dormant_ratio_world_model.item(), + 'analysis/dormant_ratio_encoder': dormant_ratio_encoder, + 'analysis/dormant_ratio_transformer': dormant_ratio_transformer, + 'analysis/dormant_ratio_head': dormant_ratio_head, + + 'analysis/avg_weight_mag_encoder': avg_weight_mag_encoder, + 'analysis/avg_weight_mag_transformer': avg_weight_mag_transformer, + 'analysis/avg_weight_mag_head': avg_weight_mag_head, + 'analysis/e_rank_last_linear': e_rank_last_linear, + 'analysis/e_rank_sim_norm': e_rank_sim_norm, + 'analysis/latent_state_l2_norms': latent_state_l2_norms.item(), + 'analysis/latent_action_l2_norms': latent_action_l2_norms, 'analysis/l2_norm_before': self.l2_norm_before, 'analysis/l2_norm_after': self.l2_norm_after, 'analysis/grad_norm_before': self.grad_norm_before, 'analysis/grad_norm_after': self.grad_norm_after, + + "temperature_value":temperature_value, + "temperature_reward":temperature_reward, + "temperature_policy":temperature_policy, + + "current_policy_label_eps":current_policy_label_eps, } - + + if norm_log_dict: + return_log_dict.update(norm_log_dict) + + use_enhanced_policy_monitoring = self._cfg.use_enhanced_policy_monitoring + if use_enhanced_policy_monitoring: + # Monitor policy logits statistics + with torch.no_grad(): + logits_policy = losses.intermediate_losses.get('logits_policy') + if logits_policy is not None: + return_log_dict['policy_logits/norm'] = logits_policy.norm(dim=-1).mean().item() + return_log_dict['policy_logits/max'] = logits_policy.max().item() + return_log_dict['policy_logits/min'] = logits_policy.min().item() + return_log_dict['policy_logits/std'] = logits_policy.std().item() + + # [NEW] Also monitor Value and Reward logits + logits_value = losses.intermediate_losses.get('logits_value') + if logits_value is not None: + return_log_dict['value_logits/abs_max'] = logits_value.abs().max().item() + return_log_dict['value_logits/norm'] = logits_value.norm(dim=-1).mean().item() + + logits_reward = losses.intermediate_losses.get('logits_reward') + if logits_reward is not None: + return_log_dict['reward_logits/abs_max'] = logits_reward.abs().max().item() + return_log_dict['reward_logits/norm'] = logits_reward.norm(dim=-1).mean().item() + + # Monitor target_policy entropy statistics (minimum entropy indicates extreme distributions) + valid_target_policy = batch_for_gpt['target_policy'][batch_for_gpt['mask_padding']] + target_policy_entropies = -torch.sum( + valid_target_policy * torch.log(valid_target_policy + 1e-9), dim=-1 + ) + return_log_dict['target_policy_entropy/mean'] = target_policy_entropies.mean().item() + return_log_dict['target_policy_entropy/min'] = target_policy_entropies.min().item() + return_log_dict['target_policy_entropy/max'] = target_policy_entropies.max().item() + return_log_dict['target_policy_entropy/std'] = target_policy_entropies.std().item() + # ================================================================================ + + if self.use_adaptive_entropy_weight: + return_log_dict['adaptive_alpha'] = current_alpha.item() + 
return_log_dict['adaptive_target_entropy_ratio'] = current_ratio + return_log_dict['alpha_loss'] = alpha_loss.item() + + if self.use_encoder_clip_annealing: + return_log_dict['current_encoder_clip_value'] = current_clip_value + + if self.use_head_clip and self.head_clip_manager is not None: + # Add head clip results to log (if any) + if head_clip_results: + for head_name, info in head_clip_results.items(): + return_log_dict[f'head_clip/{head_name}/max_logits'] = info['max_logits'] + return_log_dict[f'head_clip/{head_name}/threshold'] = info['threshold'] + if info['scaled']: + return_log_dict[f'head_clip/{head_name}/scale_factor'] = info['scale_factor'] + if self._cfg.use_wandb: wandb.log({'learner_step/' + k: v for k, v in return_log_dict.items()}, step=self.env_step) wandb.log({"learner_iter_vs_env_step": self.train_iter}, step=self.env_step) @@ -589,7 +1312,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in def monitor_weights_and_grads(self, model): for name, param in model.named_parameters(): if param.requires_grad: - print(f"Layer: {name} | " + logging.info(f"Layer: {name} | " f"Weight mean: {param.data.mean():.4f} | " f"Weight std: {param.data.std():.4f} | " f"Grad mean: {param.grad.mean():.4f} | " @@ -601,24 +1324,25 @@ def _init_collect(self) -> None: Collect mode init method. Called by ``self.__init__``. Initialize the collect model and MCTS utils. """ self._collect_model = self._model - + # Create a configuration copy for collect MCTS and set specific simulation count + mcts_collect_cfg = copy.deepcopy(self._cfg) + mcts_collect_cfg.num_simulations = self._cfg.collect_num_simulations if self._cfg.mcts_ctree: - self._mcts_collect = MCTSCtree(self._cfg) + self._mcts_collect = MCTSCtree(mcts_collect_cfg) else: - self._mcts_collect = MCTSPtree(self._cfg) + self._mcts_collect = MCTSPtree(mcts_collect_cfg) self._collect_mcts_temperature = 1. self._collect_epsilon = 0.0 self.collector_env_num = self._cfg.collector_env_num if self._cfg.model.model_type == 'conv': self.last_batch_obs = torch.zeros([self.collector_env_num, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device) - self.last_batch_action = [-1 for i in range(self.collector_env_num)] + self.last_batch_action_collect = [-1 for i in range(self.collector_env_num)] elif self._cfg.model.model_type == 'mlp': self.last_batch_obs = torch.full( [self.collector_env_num, self._cfg.model.observation_shape], fill_value=self.pad_token_id, ).to(self._cfg.device) - self.last_batch_action = [-1 for i in range(self.collector_env_num)] + self.last_batch_action_collect = [-1 for i in range(self.collector_env_num)] - # @profile def _forward_collect( self, data: torch.Tensor, @@ -626,8 +1350,9 @@ def _forward_collect( temperature: float = 1, to_play: List = [-1], epsilon: float = 0.25, - ready_env_id: np.ndarray = None, - timestep: List = [0] + ready_env_id: np.array = None, + timestep: List = [0], + task_id: int = None, ) -> Dict: """ Overview: @@ -640,6 +1365,7 @@ def _forward_collect( - to_play (:obj:`int`): The player to play. - ready_env_id (:obj:`list`): The id of the env that is ready to collect. - timestep (:obj:`list`): The step index of the env in one episode. + - task_id (:obj:`int`): The task id. Default is None, which means UniZero is in the single-task mode. 
Shape: - data (:obj:`torch.Tensor`): - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ @@ -664,7 +1390,7 @@ def _forward_collect( output = {i: None for i in ready_env_id} with torch.no_grad(): - network_output = self._collect_model.initial_inference(self.last_batch_obs, self.last_batch_action, data, timestep) + network_output = self._collect_model.initial_inference(self.last_batch_obs, self.last_batch_action_collect, data, timestep) latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) pred_values = self.value_inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() @@ -687,7 +1413,7 @@ def _forward_collect( roots.prepare(self._cfg.root_noise_weight, noises, reward_roots, policy_logits, to_play) next_latent_state_with_env = self._mcts_collect.search(roots, self._collect_model, latent_state_roots, to_play, timestep) - + # list of list, shape: ``{list: batch_size} -> {list: action_space_size}`` roots_visit_count_distributions = roots.get_distributions() roots_values = roots.get_values() # shape: {list: batch_size} @@ -696,7 +1422,7 @@ def _forward_collect( batch_action = [] for i, env_id in enumerate(ready_env_id): distributions, value = roots_visit_count_distributions[i], roots_values[i] - + if self._cfg.eps.eps_greedy_exploration_in_collect: # eps greedy collect action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( @@ -716,20 +1442,13 @@ def _forward_collect( action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] next_latent_state = next_latent_state_with_env[i][action] - + if self._cfg.model.world_model_cfg.obs_type == 'text' and self._cfg.model.world_model_cfg.decode_loss_mode is not None and self._cfg.model.world_model_cfg.decode_loss_mode.lower() != 'none': # Output the plain text content decoded by the decoder from the next latent state predicted_next = self._collect_model.tokenizer.decode_to_plain_text(embeddings=next_latent_state, max_length=256) else: predicted_next = None - # ============== TODO: only for visualize ============== - # action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( - # distributions, temperature=self._collect_mcts_temperature, deterministic=True - # ) - # action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] - # ============== TODO: only for visualize ============== - output[env_id] = { 'action': action, 'visit_count_distributions': distributions, @@ -743,15 +1462,27 @@ def _forward_collect( batch_action.append(action) self.last_batch_obs = data - self.last_batch_action = batch_action + self.last_batch_action_collect = batch_action - # ========= TODO: for muzero_segment_collector now ========= + # This logic is a temporary workaround specific to the muzero_segment_collector. if active_collect_env_num < self.collector_env_num: - print('==========collect_forward============') - print(f'len(self.last_batch_obs) < self.collector_env_num, {active_collect_env_num}<{self.collector_env_num}') + # When an environment finishes an episode ('done'), the length of `self.last_batch_obs` passed back + # becomes smaller than the total number of collector environments. + # Handling this dynamic batch size is complex, as the transformer's KV cache retrieval + # requires a stable environment ID for correct indexing. A mismatch would cause retrieval errors. + # + # Therefore, as a simpler solution, we reset the collection state for ALL environments. 
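As noted in the selection step above, `select_action` returns an index into the set of legal actions (only legal actions carry visit counts), which is then mapped back to the global action id through the environment's `action_mask`. A tiny standalone illustration of that mapping, with made-up values:

```python
import numpy as np

# Suppose the full action space has 6 actions but only actions 1, 3 and 4 are legal.
action_mask = np.array([0., 1., 0., 1., 1., 0.])

# MCTS visit counts exist only for legal actions, so the selection routine
# returns an index into the legal action set (here: index 2 -> third legal action).
action_index_in_legal_action_set = 2

# Map the legal-set index back to the global action id, as done in _forward_collect.
legal_actions = np.where(action_mask == 1.0)[0]            # -> array([1, 3, 4])
action = legal_actions[action_index_in_legal_action_set]   # -> 4
print(action)
```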
+ # By resetting `self.last_batch_action` to -1 for all `self.collector_env_num` environments, + # we force the transformer to start its context from scratch, avoiding incorrect cache lookups. + logging.info('========== collect_forward ============') + logging.info(f'An environment has finished. Active envs: {active_collect_env_num} < Total envs: {self.collector_env_num}. Resetting all.') + self._reset_collect(reset_init_data=True) + + # If the sampling type is 'episode', it's unexpected for the number of active environments to drop, + # as this suggests an inconsistent state or a potential issue in the collection logic. if getattr(self._cfg, 'sample_type', '') == 'episode': - print('BUG: sample_type is episode, but len(self.last_batch_obs) < self.collector_env_num') + logging.warning('Inconsistent state detected. `sample_type` is "episode", but the number of active environments has changed.') return output @@ -761,23 +1492,29 @@ def _init_eval(self) -> None: Evaluate mode init method. Called by ``self.__init__``. Initialize the eval model and MCTS utils. """ self._eval_model = self._model + + # Create a configuration copy for eval MCTS and set specific simulation count + mcts_eval_cfg = copy.deepcopy(self._cfg) + mcts_eval_cfg.num_simulations = self._cfg.eval_num_simulations + if self._cfg.mcts_ctree: - self._mcts_eval = MCTSCtree(self._cfg) + self._mcts_eval = MCTSCtree(mcts_eval_cfg) else: - self._mcts_eval = MCTSPtree(self._cfg) + self._mcts_eval = MCTSPtree(mcts_eval_cfg) + self.evaluator_env_num = self._cfg.evaluator_env_num if self._cfg.model.model_type == 'conv': self.last_batch_obs = torch.zeros([self.collector_env_num, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device) - self.last_batch_action = [-1 for i in range(self.collector_env_num)] + self.last_batch_action_eval = [-1 for i in range(self.collector_env_num)] elif self._cfg.model.model_type == 'mlp': self.last_batch_obs = torch.full( [self.collector_env_num, self._cfg.model.observation_shape], fill_value=self.pad_token_id, ).to(self._cfg.device) - self.last_batch_action = [-1 for i in range(self.collector_env_num)] + self.last_batch_action_eval = [-1 for i in range(self.collector_env_num)] - def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: List = [-1], - ready_env_id: np.array = None, timestep: List = [0]) -> Dict: + def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, + ready_env_id: np.array = None, timestep: List = [0], task_id: int = None,) -> Dict: """ Overview: The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search. @@ -788,6 +1525,7 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: List = [ - to_play (:obj:`int`): The player to play. - ready_env_id (:obj:`list`): The id of the env that is ready to eval. - timestep (:obj:`list`): The step index of the env in one episode. + - task_id (:obj:`int`): The task id. Default is None, which means UniZero is in the single-task mode. 
Shape: - data (:obj:`torch.Tensor`): - For Atari, :math:`(N, C*S, H, W)`, where N is the number of eval_env, C is the number of channels, \ @@ -808,7 +1546,7 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: List = [ ready_env_id = np.arange(active_eval_env_num) output = {i: None for i in ready_env_id} with torch.no_grad(): - network_output = self._eval_model.initial_inference(self.last_batch_obs, self.last_batch_action, data, timestep) + network_output = self._eval_model.initial_inference(self.last_batch_obs_eval, self.last_batch_action_eval, data, timestep) latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) # if not in training, obtain the scalars of the value/reward @@ -831,10 +1569,9 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: List = [ roots_values = roots.get_values() # shape: {list: batch_size} batch_action = [] - + for i, env_id in enumerate(ready_env_id): distributions, value = roots_visit_count_distributions[i], roots_values[i] - # print("roots_visit_count_distributions:", distributions, "root_value:", value) # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents # the index within the legal action set, rather than the index in the entire action set. @@ -868,12 +1605,12 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: List = [ } batch_action.append(action) - self.last_batch_obs = data - self.last_batch_action = batch_action + self.last_batch_obs_eval = data + self.last_batch_action_eval = batch_action return output - def _reset_collect(self, env_id: int = None, current_steps: int = None, reset_init_data: bool = True) -> None: + def _reset_collect(self, env_id: int = None, current_steps: int = None, reset_init_data: bool = True, task_id: int = None) -> None: """ Overview: This method resets the collection process for a specific environment. It clears caches and memory @@ -894,31 +1631,52 @@ def _reset_collect(self, env_id: int = None, current_steps: int = None, reset_in ) self.last_batch_action = [-1 for _ in range(self._cfg.collector_env_num)] - # Return immediately if env_id is None or a list - if env_id is None or isinstance(env_id, list): - return + + # We must handle both single int and list of ints for env_id. + if env_id is not None: + if isinstance(env_id, int): + env_ids_to_reset = [env_id] + else: # Assumes it's a list + env_ids_to_reset = env_id + + # The key condition: `current_steps` is None only on the end-of-episode reset call from the collector. + if current_steps is None: + world_model = self._collect_model.world_model + for eid in env_ids_to_reset: + # ==================== BUG FIX: Refactored Cache Clearing ==================== + # Clear the specific environment's initial inference cache. 
+ if hasattr(world_model, 'use_new_cache_manager') and world_model.use_new_cache_manager: + # NEW SYSTEM: Use KVCacheManager to clear per-environment cache + if eid < world_model.env_num: + world_model.kv_cache_manager.init_pools[eid].clear() + logging.info(f'>>> [Collector] Cleared KV cache for env_id: {eid} at episode end (NEW system).') + else: + # OLD SYSTEM: Use legacy cache dictionary + if eid < len(world_model.past_kv_cache_init_infer_envs): + world_model.past_kv_cache_init_infer_envs[eid].clear() + logging.info(f'>>> [Collector] Cleared KV cache for env_id: {eid} at episode end (OLD system).') + # ============================================================================= # Determine the clear interval based on the environment's sample type - clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else 200 + clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else self._cfg.game_segment_length # Clear caches if the current steps are a multiple of the clear interval - if current_steps % clear_interval == 0: - print(f'clear_interval: {clear_interval}') + if current_steps is not None and current_steps % clear_interval == 0: + logging.info(f'clear_interval: {clear_interval}') # Clear various caches in the collect model's world model world_model = self._collect_model.world_model - for kv_cache_dict_env in world_model.past_kv_cache_init_infer_envs: - kv_cache_dict_env.clear() - world_model.past_kv_cache_recurrent_infer.clear() - world_model.keys_values_wm_list.clear() + # ==================== Phase 1.5: Use unified clear_caches() method ==================== + # This automatically handles both old and new cache systems + world_model.clear_caches() + # ====================================================================================== # Free up GPU memory torch.cuda.empty_cache() - print('collector: collect_model clear()') - print(f'eps_steps_lst[{env_id}]: {current_steps}') + logging.info(f'eps_steps_lst[{env_id}]: {current_steps}, collector: collect_model clear()') - def _reset_eval(self, env_id: int = None, current_steps: int = None, reset_init_data: bool = True) -> None: + def _reset_eval(self, env_id: int = None, current_steps: int = None, reset_init_data: bool = True, task_id: int = None) -> None: """ Overview: This method resets the evaluation process for a specific environment. It clears caches and memory @@ -931,37 +1689,80 @@ def _reset_eval(self, env_id: int = None, current_steps: int = None, reset_init_ - reset_init_data (:obj:`bool`, optional): Whether to reset the initial data. If True, the initial data will be reset. 
""" if reset_init_data: - self.last_batch_obs = initialize_pad_batch( - self._cfg.model.observation_shape, - self._cfg.evaluator_env_num, - self._cfg.device, - pad_token_id=self.pad_token_id - ) + if task_id is not None: + self.last_batch_obs_eval = initialize_pad_batch( + self._cfg.model.observation_shape_list[task_id], + self._cfg.evaluator_env_num, + self._cfg.device, + pad_token_id=self.pad_token_id + ) + logging.info(f'unizero.py task_id:{task_id} after _reset_eval: last_batch_obs_eval:', self.last_batch_obs_eval.shape) + + else: + self.last_batch_obs_eval = initialize_pad_batch( + self._cfg.model.observation_shape, + self._cfg.evaluator_env_num, + self._cfg.device, + pad_token_id=self.pad_token_id + ) + logging.info(f'unizero.py task_id:{task_id} after _reset_eval: last_batch_obs_eval:', self.last_batch_obs_eval.shape) + self.last_batch_action = [-1 for _ in range(self._cfg.evaluator_env_num)] - # Return immediately if env_id is None or a list - if env_id is None or isinstance(env_id, list): - return + # This logic handles the crucial end-of-episode cache clearing for evaluation. + # The evaluator calls `_policy.reset([env_id])` when an episode is done. + if env_id is not None: + if isinstance(env_id, int): + env_ids_to_reset = [env_id] + else: # Assumes it's a list + env_ids_to_reset = env_id + + # The key condition: `current_steps` is None only on the end-of-episode reset call from the evaluator. + if current_steps is None: + world_model = self._eval_model.world_model + for eid in env_ids_to_reset: + # ==================== BUG FIX: Refactored Cache Clearing ==================== + # Clear the specific environment's initial inference cache. + if hasattr(world_model, 'use_new_cache_manager') and world_model.use_new_cache_manager: + # NEW SYSTEM: Use KVCacheManager to clear per-environment cache + if eid < world_model.env_num: + world_model.kv_cache_manager.init_pools[eid].clear() + logging.info(f'>>> [Evaluator] Cleared KV cache for env_id: {eid} at episode end (NEW system).') + else: + # OLD SYSTEM: Use legacy cache dictionary + if eid < len(world_model.past_kv_cache_init_infer_envs): + world_model.past_kv_cache_init_infer_envs[eid].clear() + logging.info(f'>>> [Evaluator] Cleared KV cache for env_id: {eid} at episode end (OLD system).') + # ============================================================================= + + # The recurrent cache is global. 
+ # ==================== Phase 1.5: Use unified clear_caches() method ==================== + # This automatically handles both old and new cache systems + world_model.clear_caches() + # ====================================================================================== + + world_model.keys_values_wm_list.clear() + torch.cuda.empty_cache() + return # Determine the clear interval based on the environment's sample type - clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else 200 - + clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else self._cfg.game_segment_length # Clear caches if the current steps are a multiple of the clear interval - if current_steps % clear_interval == 0: - print(f'clear_interval: {clear_interval}') + if current_steps is not None and current_steps % clear_interval == 0: + logging.info(f'clear_interval: {clear_interval}') # Clear various caches in the eval model's world model world_model = self._eval_model.world_model - for kv_cache_dict_env in world_model.past_kv_cache_init_infer_envs: - kv_cache_dict_env.clear() - world_model.past_kv_cache_recurrent_infer.clear() - world_model.keys_values_wm_list.clear() + # ==================== Phase 1.5: Use unified clear_caches() method ==================== + # This automatically handles both old and new cache systems + world_model.clear_caches() + # ====================================================================================== # Free up GPU memory torch.cuda.empty_cache() - print('evaluator: eval_model clear()') - print(f'eps_steps_lst[{env_id}]: {current_steps}') + logging.info('evaluator: eval_model clear()') + logging.info(f'eps_steps_lst[{env_id}]: {current_steps}') def _monitor_vars_learn(self) -> List[str]: """ @@ -969,57 +1770,158 @@ def _monitor_vars_learn(self) -> List[str]: Register the variables to be monitored in learn mode. The registered variables will be logged in tensorboard according to the return value ``_forward_learn``. 
""" - return [ + base_vars = [ + # ==================== Analysis Metrics ==================== 'analysis/dormant_ratio_encoder', - 'analysis/dormant_ratio_world_model', + 'analysis/dormant_ratio_transformer', + 'analysis/dormant_ratio_head', + 'analysis/avg_weight_mag_encoder', + 'analysis/avg_weight_mag_transformer', + 'analysis/avg_weight_mag_head', + 'analysis/e_rank_last_linear', + 'analysis/e_rank_sim_norm', 'analysis/latent_state_l2_norms', + 'analysis/latent_action_l2_norms', 'analysis/l2_norm_before', 'analysis/l2_norm_after', 'analysis/grad_norm_before', 'analysis/grad_norm_after', + # ==================== Step-wise Loss Analysis ==================== 'analysis/first_step_loss_value', 'analysis/first_step_loss_policy', 'analysis/first_step_loss_rewards', 'analysis/first_step_loss_obs', - 'analysis/middle_step_loss_value', 'analysis/middle_step_loss_policy', 'analysis/middle_step_loss_rewards', 'analysis/middle_step_loss_obs', - 'analysis/last_step_loss_value', 'analysis/last_step_loss_policy', 'analysis/last_step_loss_rewards', 'analysis/last_step_loss_obs', + # ==================== System Metrics ==================== 'Current_GPU', 'Max_GPU', 'collect_epsilon', 'collect_mcts_temperature', 'cur_lr_world_model', - 'cur_lr_tokenizer', + # ==================== Core Losses ==================== 'weighted_total_loss', 'obs_loss', 'policy_loss', 'orig_policy_loss', 'policy_entropy', 'latent_recon_loss', + 'perceptual_loss', 'target_policy_entropy', 'reward_loss', 'value_loss', - 'consistency_loss', 'value_priority', 'target_reward', 'target_value', + 'transformed_target_reward', + 'transformed_target_value', + + # ==================== Gradient Norms ==================== 'total_grad_norm_before_clip_wm', - # tokenizer - 'commitment_loss', - 'reconstruction_loss', - 'perceptual_loss', + + # ==================== Temperature Parameters ==================== + 'temperature_value', + 'temperature_reward', + 'temperature_policy', + + # ==================== Training Configuration ==================== + 'current_policy_label_eps', + 'adaptive_alpha', + 'adaptive_target_entropy_ratio', + 'alpha_loss', + 'current_encoder_clip_value', + ] + + # ==================== [NEW] Norm and Intermediate Tensor Monitoring Variables ==================== + norm_vars = [ + # Module total norms (parameter norms) + 'norm/encoder/_total_norm', + 'norm/transformer/_total_norm', + 'norm/head_value/_total_norm', + 'norm/head_reward/_total_norm', + 'norm/head_policy/_total_norm', + + # Module total norms (gradient norms) + 'grad/encoder/_total_norm', + 'grad/transformer/_total_norm', + 'grad/head_value/_total_norm', + 'grad/head_reward/_total_norm', + 'grad/head_policy/_total_norm', + + # Intermediate tensor x (Transformer output) statistics + 'norm/x_token/mean', + 'norm/x_token/std', + 'norm/x_token/max', + 'norm/x_token/min', + + # Detailed logits statistics (Value) + 'logits/value/mean', + 'logits/value/std', + 'logits/value/max', + 'logits/value/min', + 'logits/value/abs_max', + + # Detailed logits statistics (Policy) + 'logits/policy/mean', + 'logits/policy/std', + 'logits/policy/max', + 'logits/policy/min', + 'logits/policy/abs_max', + + # Detailed logits statistics (Reward) + 'logits/reward/mean', + 'logits/reward/std', + 'logits/reward/max', + 'logits/reward/min', + 'logits/reward/abs_max', + + # Embeddings statistics + 'embeddings/obs/norm_mean', + 'embeddings/obs/norm_std', + 'embeddings/obs/norm_max', + 'embeddings/obs/norm_min', + + ] + + head_clip_vars = [] + # Check if head_clip is enabled and manager exists 
+ if getattr(self, 'use_head_clip', False) and getattr(self, 'head_clip_manager', None) is not None: + # Iterate through all enabled heads and generate corresponding monitoring keys + for head_name in self.head_clip_manager.enabled_heads: + head_clip_vars.append(f'head_clip/{head_name}/max_logits') + head_clip_vars.append(f'head_clip/{head_name}/threshold') + head_clip_vars.append(f'head_clip/{head_name}/scale_factor') + + + enhanced_policy_vars = [ + # Policy logits statistics + 'policy_logits/norm', + 'policy_logits/max', + 'policy_logits/min', + 'policy_logits/std', + # Target policy entropy statistics + 'target_policy_entropy/mean', + 'target_policy_entropy/min', + 'target_policy_entropy/max', + 'target_policy_entropy/std', ] + stability_vars = [ + 'stability/warning_count', # Number of warnings issued in current check + ] + + return base_vars + norm_vars+ head_clip_vars + enhanced_policy_vars + stability_vars + + def _state_dict_learn(self) -> Dict[str, Any]: """ Overview: @@ -1027,11 +1929,16 @@ def _state_dict_learn(self) -> Dict[str, Any]: Returns: - state_dict (:obj:`Dict[str, Any]`): The dict of current policy learn state, for saving and restoring. """ - return { + state_dict = { 'model': self._learn_model.state_dict(), 'target_model': self._target_model.state_dict(), 'optimizer_world_model': self._optimizer_world_model.state_dict(), } + # ==================== START: Save Alpha Optimizer State ==================== + if self.use_adaptive_entropy_weight: + state_dict['alpha_optimizer'] = self.alpha_optimizer.state_dict() + # ===================== END: Save Alpha Optimizer State ===================== + return state_dict def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: """ @@ -1042,7 +1949,6 @@ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: """ self._learn_model.load_state_dict(state_dict['model']) self._target_model.load_state_dict(state_dict['target_model']) - self._optimizer_world_model.load_state_dict(state_dict['optimizer_world_model']) def recompute_pos_emb_diff_and_clear_cache(self) -> None: """ @@ -1054,4 +1960,4 @@ def recompute_pos_emb_diff_and_clear_cache(self) -> None: # If rotary_emb is False, nn.Embedding is used for absolute position encoding. model.world_model.precompute_pos_emb_diff_kv() model.world_model.clear_caches() - torch.cuda.empty_cache() \ No newline at end of file + torch.cuda.empty_cache() diff --git a/lzero/policy/unizero_multitask.py b/lzero/policy/unizero_multitask.py new file mode 100644 index 000000000..64a52d8af --- /dev/null +++ b/lzero/policy/unizero_multitask.py @@ -0,0 +1,1978 @@ +import copy +import sys +from collections import defaultdict +from typing import Any, Dict, List, Tuple, Union + +import numpy as np +import torch +from ding.model import model_wrap +from ding.utils import POLICY_REGISTRY +from lzero.mcts import UniZeroMCTSCtree as MCTSCtree +from lzero.model import ImageTransforms +from lzero.policy import (DiscreteSupport, InverseScalarTransform, + mz_network_output_unpack, phi_transform, prepare_obs, + prepare_obs_stack_for_unizero, scalar_transform, + select_action, to_torch_float_tensor) +from lzero.policy.unizero import UniZeroPolicy, scale_module_weights_vectorized + +from .utils import configure_optimizers_nanogpt, initialize_zeros_batch + +# Please replace the path with the actual location of your LibMTL library. 
+sys.path.append('/path/to/your/LibMTL') + +import torch.distributed as dist +from LibMTL.weighting.moco_fast_mem_eff import FastMoCoMemEff as FastMoCo +from LibMTL.weighting.moco_fast_mem_eff import MoCoCfg +from LibMTL.weighting.MoCo_unizero import MoCo as GradCorrect + + +def generate_task_loss_dict(multi_task_losses: List[Union[torch.Tensor, float]], task_name_template: str, task_id: int) -> Dict[str, float]: + """ + Overview: + Generates a dictionary for the losses of each task. + Arguments: + - multi_task_losses (:obj:`List[Union[torch.Tensor, float]]`): A list containing the loss for each task. + - task_name_template (:obj:`str`): The template for the task name, e.g., 'obs_loss_task{}'. + - task_id (:obj:`int`): The starting global task ID for the current rank. Used to offset task indices when generating task names. + Returns: + - task_loss_dict (:obj:`Dict[str, float]`): A dictionary where keys are formatted task names and values are the corresponding losses. + """ + task_loss_dict = {} + for task_idx, task_loss in enumerate(multi_task_losses): + task_name = task_name_template.format(task_idx + task_id) + try: + # Get the scalar value of the loss if it's a tensor. + task_loss_dict[task_name] = task_loss.item() if hasattr(task_loss, 'item') else task_loss + except Exception as e: + task_loss_dict[task_name] = task_loss + return task_loss_dict + + +class WrappedModel: + """ + Overview: + A wrapper for specific components of the world model. + This version is designed to group parameters that are considered "shared" + across tasks for gradient correction methods like MoCo, excluding the prediction heads. + """ + def __init__(self, tokenizer: torch.nn.Module, transformer: torch.nn.Module, pos_emb: torch.nn.Module, task_emb: torch.nn.Module, act_embedding_table: torch.nn.Module): + """ + Arguments: + - tokenizer (:obj:`torch.nn.Module`): The tokenizer module. + - transformer (:obj:`torch.nn.Module`): The transformer backbone. + - pos_emb (:obj:`torch.nn.Module`): The positional embedding module. + - task_emb (:obj:`torch.nn.Module`): The task embedding module. + - act_embedding_table (:obj:`torch.nn.Module`): The action embedding table. + """ + self.tokenizer = tokenizer + self.transformer = transformer + self.pos_emb = pos_emb + self.task_emb = task_emb + self.act_embedding_table = act_embedding_table + + def parameters(self) -> iter: + """ + Overview: + Returns an iterator over the parameters of the wrapped components (tokenizer, transformer, embeddings). + These are typically the shared parts of the model whose gradients need to be managed for multi-task learning. + """ + return (list(self.tokenizer.parameters()) + + list(self.transformer.parameters()) + + list(self.pos_emb.parameters()) + + # list(self.task_emb.parameters()) + # TODO: Decide whether to include task embeddings in shared parameters. + list(self.act_embedding_table.parameters())) + + def zero_grad(self, set_to_none: bool = False) -> None: + """ + Overview: + Sets the gradients of all wrapped components to zero. + Arguments: + - set_to_none (:obj:`bool`): Whether to set gradients to None instead of zero. + """ + self.tokenizer.zero_grad(set_to_none=set_to_none) + self.transformer.zero_grad(set_to_none=set_to_none) + self.pos_emb.zero_grad(set_to_none=set_to_none) + # self.task_emb.zero_grad(set_to_none=set_to_none) # TODO: Match the decision made in the parameters() method. 
+ self.act_embedding_table.zero_grad(set_to_none=set_to_none) + + + +def configure_optimizer_unizero(model, learning_rate, weight_decay, device_type, betas): + """ + Configure optimizer with differentiated learning rates for UniZero model. + (Corrected version ensuring parameter groups are mutually exclusive) + """ + # 1. Create empty parameter lists for grouping + transformer_params = [] + tokenizer_params = [] + head_params = [] + + # 2. Iterate through all trainable parameters, using if/elif/else structure to ensure each parameter is assigned to only one group + for name, param in model.named_parameters(): + if not param.requires_grad: + continue + + if 'transformer' in name: + transformer_params.append(param) + elif 'tokenizer' in name: + tokenizer_params.append(param) + else: + head_params.append(param) + + # 3. Set different optimizer parameters for each group + # We still use AdamW here, but with more reasonable learning rate settings + optim_groups = [ + { + 'params': transformer_params, + 'lr': learning_rate, # 1e-4 + 'weight_decay': weight_decay + }, + { + 'params': tokenizer_params, + 'lr': learning_rate, # Tokenizer uses base learning rate, e.g., 1e-4 + 'weight_decay': weight_decay + }, + { + 'params': head_params, + 'lr': learning_rate, # Heads also use base learning rate, e.g., 1e-4 + 'weight_decay': weight_decay + + } + ] + + print("--- Optimizer Groups ---") + # Print parameter count for each group for debugging + print(f"Transformer params: {len(transformer_params)}") + print(f"Tokenizer params: {len(tokenizer_params)}") + print(f"Head params: {len(head_params)}") + print(f"Transformer LR: {learning_rate}") + print(f"Tokenizer/Heads LR: {learning_rate}") + + optimizer = torch.optim.AdamW(optim_groups, betas=betas) + return optimizer + +@POLICY_REGISTRY.register('unizero_multitask') +class UniZeroMTPolicy(UniZeroPolicy): + """ + Overview: + The policy class for multi-task UniZero, an official implementation for the paper "UniZero: Generalized and Efficient Planning + with Scalable Latent World Models". UniZero aims to enhance the planning capabilities of reinforcement learning agents + by addressing the limitations of MuZero-style algorithms, particularly in environments requiring the + capture of long-term dependencies. More details can be found at: https://arxiv.org/abs/2406.10667. + """ + + # The default_config for UniZero multi-task policy. + config = dict( + type='unizero_multitask', + model=dict( + # (str) The model type. For 1-dimensional vector obs, we use mlp model. For the image obs, we use conv model. + model_type='conv', # options={'mlp', 'conv'} + # (bool) If True, the action space of the environment is continuous, otherwise discrete. + continuous_action_space=False, + # (tuple) The obs shape. + observation_shape=(3, 64, 64), + # (bool) Whether to use the self-supervised learning loss. + self_supervised_learning_loss=True, + # (bool) Whether to use discrete support to represent categorical distribution for value/reward/value_prefix. + categorical_distribution=True, + # (int) The image channel in image observation. + image_channel=3, + # (int) The number of frames to stack together. + frame_stack_num=1, + # (int) The number of res blocks in MuZero model. + num_res_blocks=1, + # (int) The number of channels of hidden states in MuZero model. + num_channels=64, + # (int) The scale of supports used in categorical distribution. + # This variable is only effective when ``categorical_distribution=True``. 
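`configure_optimizer_unizero` above routes each trainable parameter into exactly one of three groups (transformer, tokenizer, heads) by substring match on the parameter name; all groups currently share the base learning rate but can be tuned independently. A brief usage sketch, with a stand-in module rather than the real world model:

```python
import torch
import torch.nn as nn

# Stand-in module whose parameter names contain 'transformer' and 'tokenizer',
# so the grouping logic above routes them into separate optimizer groups.
class ToyWorldModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.transformer = nn.Linear(8, 8)
        self.tokenizer = nn.Linear(8, 8)
        self.head_policy = nn.Linear(8, 4)   # falls into the 'head' group

model = ToyWorldModel()
optimizer = configure_optimizer_unizero(
    model, learning_rate=1e-4, weight_decay=1e-4,
    device_type='cpu', betas=(0.9, 0.95),
)
print(len(optimizer.param_groups))  # -> 3 groups: transformer, tokenizer, heads
```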
+ support_scale=50, + # (bool) whether to learn bias in the last linear layer in value and policy head. + bias=True, + # (bool) whether to use res connection in dynamics. + res_connection_in_dynamics=True, + # (str) The type of normalization in MuZero model. Options are ['BN', 'LN']. Default to 'BN'. + norm_type='LN', # NOTE: LayerNorm is used in the transformer-based world model. + # (bool) Whether to analyze simulation normalization. + analysis_sim_norm=False, + # (int) The save interval of the model. + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=10000, ), ), ), + world_model_cfg=dict( + # (int) The number of tokens per block. + tokens_per_block=2, + # (int) The maximum number of blocks. + max_blocks=10, + # (int) The maximum number of tokens, calculated as tokens per block multiplied by max blocks. + max_tokens=2 * 10, + # (int) The context length, usually calculated as twice the number of some base unit. + context_length=2 * 4, + # (bool) Whether to use GRU gating mechanism. + gru_gating=False, + # (str) The device to be used for computation, e.g., 'cpu' or 'cuda'. + device='cpu', + # (bool) Whether to analyze simulation normalization. + analysis_sim_norm=False, + # (bool) Whether to analyze dormant ratio. + analysis_dormant_ratio=False, + # (int) The shape of the action space. + action_space_size=6, + # (int) The size of the group, related to simulation normalization. + group_size=8, # NOTE: for sim_norm + # (str) The type of attention mechanism used. Options could be ['causal']. + attention='causal', + # (int) The number of layers in the model. + num_layers=2, + # (int) The number of attention heads. + num_heads=8, + # (int) The dimension of the embedding. + embed_dim=768, + # (float) The dropout probability for the embedding layer. + embed_pdrop=0.1, + # (float) The dropout probability for the residual connections. + resid_pdrop=0.1, + # (float) The dropout probability for the attention mechanism. + attn_pdrop=0.1, + # (int) The size of the support set for value and reward heads. + support_size=101, + # (int) The maximum size of the cache. + max_cache_size=5000, + # (int) The number of environments. + env_num=8, + # (float) The weight of the latent reconstruction loss. + latent_recon_loss_weight=0., + # (float) The weight of the perceptual loss. + perceptual_loss_weight=0., + # (float) The weight of the policy entropy. + policy_entropy_weight=1e-4, + # (str) The normalization type for the final layer in both the head and the encoder. + # This option must be the same for both 'final_norm_option_in_head' and 'final_norm_option_in_encoder'. + # Valid options are 'LayerNorm' and 'SimNorm'. + # When set to 'LayerNorm', the 'predict_latent_loss_type' should be 'mse'. + # When set to 'SimNorm', the 'predict_latent_loss_type' should be 'group_kl'. + final_norm_option_in_head="LayerNorm", + final_norm_option_in_encoder="LayerNorm", + # (str) The type of loss function for predicting latent variables. + # Options are 'mse' (Mean Squared Error) or 'group_kl' (Group Kullback-Leibler divergence). + # This choice is dependent on the normalization method selected above. + predict_latent_loss_type='mse', + # (str) The type of observation. Options are ['image', 'vector']. + obs_type='image', + # (float) The discount factor for future rewards. + gamma=1, + # (bool) Whether to analyze dormant ratio, average_weight_magnitude of net, effective_rank of latent. + analysis_dormant_ratio_weight_rank=False, + # (float) The threshold for a dormant neuron. 
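The comments above couple the final normalization option to the latent prediction loss: 'LayerNorm' pairs with 'mse', 'SimNorm' pairs with 'group_kl', and the head and encoder must use the same option. A tiny validation helper expressing that constraint; no such check is part of the patch itself:

```python
def check_latent_loss_config(world_model_cfg: dict) -> None:
    """Assert the norm-option / latent-loss pairing described in the config comments."""
    norm_head = world_model_cfg['final_norm_option_in_head']
    norm_encoder = world_model_cfg['final_norm_option_in_encoder']
    loss_type = world_model_cfg['predict_latent_loss_type']

    assert norm_head == norm_encoder, "head and encoder must use the same final norm option"
    expected = {'LayerNorm': 'mse', 'SimNorm': 'group_kl'}
    assert expected[norm_head] == loss_type, (
        f"predict_latent_loss_type should be '{expected[norm_head]}' when using {norm_head}"
    )

check_latent_loss_config({
    'final_norm_option_in_head': 'LayerNorm',
    'final_norm_option_in_encoder': 'LayerNorm',
    'predict_latent_loss_type': 'mse',
})
```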
+ dormant_threshold=0.01, + share_head=False, + ), + ), + # ****** common ****** + # (bool) Whether to enable adaptive policy entropy weight (alpha) + use_adaptive_entropy_weight=True, + # (float) Learning rate for adaptive alpha optimizer + adaptive_entropy_alpha_lr=1e-3, + # (float) Target entropy ratio at the start of training (higher = more exploration) + target_entropy_start_ratio=0.98, + # (float) Target entropy ratio at the end of training (lower = more exploitation) + target_entropy_end_ratio=0.1, + # (int) Number of training steps to decay target entropy from start to end ratio + target_entropy_decay_steps=150000, + + # ==================== START: Encoder-Clip Annealing Config ==================== + # (bool) Whether to enable annealing for encoder-clip values. + use_encoder_clip_annealing=True, + # (str) Annealing type. Options: 'linear' or 'cosine'. + encoder_clip_anneal_type='cosine', + # (float) Starting clip value for annealing (looser in early training). + encoder_clip_start_value=30.0, + # (float) Ending clip value for annealing (stricter in later training). + encoder_clip_end_value=10.0, + # (int) Training iteration steps required to complete annealing from start to end value. + encoder_clip_anneal_steps=100000, # e.g., reach final value after 100k iterations + # (float) Fixed latent norm clip threshold (used when encoder_clip_annealing is disabled) + latent_norm_clip_threshold=30.0, + # ===================== END: Encoder-Clip Annealing Config ===================== + + # ==================== START: Policy Label Smoothing Config ==================== + # (float) Starting epsilon value for policy label smoothing (higher = more smoothing) + policy_ls_eps_start=0.05, + # (float) Ending epsilon value for policy label smoothing (lower = less smoothing) + policy_ls_eps_end=0.01, + # (int) Number of training steps to decay label smoothing epsilon from start to end + policy_ls_eps_decay_steps=50000, + # ===================== END: Policy Label Smoothing Config ===================== + + # ==================== START: Learning Rate Scheduler Config ==================== + # (int) Total training iterations for cosine annealing LR scheduler (only used when cos_lr_scheduler=True) + total_iterations=500000, + # (float) Final learning rate for cosine annealing LR scheduler (only used when cos_lr_scheduler=True) + final_learning_rate=1e-6, + # ===================== END: Learning Rate Scheduler Config ===================== + + # ==================== START: Monitoring Config ==================== + # (int) Frequency of monitoring model parameter and gradient norms (in training iterations). Set to 0 to disable. + monitor_norm_freq=10000, + # ===================== END: Monitoring Config ===================== + + # (bool) whether to use rnd model. + use_rnd_model=False, + # (bool) Whether to use multi-gpu training. + multi_gpu=True, + # (bool) Whether to enable the sampled-based algorithm (e.g. Sampled EfficientZero) + # this variable is used in ``collector``. + sampled_algo=False, + # (bool) Whether to enable the gumbel-based algorithm (e.g. Gumbel Muzero) + gumbel_algo=False, + # (bool) Whether to use C++ MCTS in policy. If False, use Python implementation. + mcts_ctree=True, + # (bool) Whether to use cuda for network. + cuda=True, + # (int) The number of environments used in collecting data. + collector_env_num=8, + # (int) The number of environments used in evaluating policy. + evaluator_env_num=3, + # (str) The type of environment. Options are ['not_board_games', 'board_games']. 
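The adaptive entropy settings above feed the alpha update performed earlier in `_forward_learn`: the target entropy is a decaying fraction of log(A), the maximum entropy of a uniform policy over A actions, and `log_alpha` rises when the measured policy entropy falls below that target. A condensed sketch of that update under those assumptions, with stand-in values:

```python
import numpy as np
import torch

action_space_size = 6
log_alpha = torch.zeros(1, requires_grad=True)             # learnable log of the entropy weight
alpha_optimizer = torch.optim.Adam([log_alpha], lr=1e-3)   # adaptive_entropy_alpha_lr

def target_entropy_at(train_iter: int, start_ratio=0.98, end_ratio=0.1, decay_steps=150000) -> float:
    progress = min(1.0, train_iter / decay_steps)
    ratio = start_ratio * (1 - progress) + end_ratio * progress
    # Maximum entropy of a uniform policy is -log(1/A) = log(A); scale it by the ratio.
    return -np.log(1.0 / action_space_size) * ratio

policy_entropy = torch.tensor(0.9)   # stand-in for the batch-mean policy entropy
target = target_entropy_at(train_iter=10000)

# Alpha rises when entropy is below target (encouraging exploration) and falls otherwise.
alpha_loss = (log_alpha * (policy_entropy.detach() - target)).mean()
alpha_optimizer.zero_grad()
alpha_loss.backward()
alpha_optimizer.step()
current_alpha = log_alpha.exp().detach()
print(float(current_alpha))
```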
+ env_type='not_board_games', + # (str) The type of action space. Options are ['fixed_action_space', 'varied_action_space']. + action_type='fixed_action_space', + # (str) The type of battle mode. Options are ['play_with_bot_mode', 'self_play_mode']. + battle_mode='play_with_bot_mode', + # (bool) Whether to monitor extra statistics in tensorboard. + monitor_extra_statistics=True, + # (int) The transition number of one ``GameSegment``. + game_segment_length=400, + # (bool) Whether to analyze simulation normalization. + analysis_sim_norm=False, + # (bool) Whether to use the pure policy to collect data. + collect_with_pure_policy=False, + # (int) The evaluation frequency. + eval_freq=int(5e3), + # (str) The sample type. Options are ['episode', 'transition']. + sample_type='transition', + + # ****** observation ****** + # (bool) Whether to transform image to string to save memory. + transform2string=False, + # (bool) Whether to use gray scale image. + gray_scale=False, + # (bool) Whether to use data augmentation. + use_augmentation=False, + # (list) The style of augmentation. + augmentation=['shift', 'intensity'], + + # ******* learn ****** + # (bool) Whether to ignore the done flag in the training data. Typically, this value is set to False. + # However, for some environments with a fixed episode length, to ensure the accuracy of Q-value calculations, + # we should set it to True to avoid the influence of the done flag. + ignore_done=False, + # (int) How many updates(iterations) to train after collector's one collection. + # Bigger "update_per_collect" means bigger off-policy. + # collect data -> update policy-> collect data -> ... + # For different env, we have different episode_length, + # we usually set update_per_collect = collector_env_num * episode_length / batch_size * reuse_factor. + # If we set update_per_collect=None, we will set update_per_collect = collected_transitions_num * cfg.policy.replay_ratio automatically. + update_per_collect=None, + # (float) The ratio of the collected data used for training. Only effective when ``update_per_collect`` is not None. + replay_ratio=0.25, + # (int) Minibatch size for one gradient descent. + batch_size=256, + # (str) Optimizer for training policy network. + optim_type='AdamW', + # (float) Learning rate for training policy network. Initial lr for manually decay schedule. + learning_rate=0.0001, + # (int) Frequency of hard target network update. + target_update_freq=100, + # (int) Frequency of soft target network update. + target_update_theta=0.05, + # (int) Frequency of target network update. + target_update_freq_for_intrinsic_reward=1000, + # (float) Weight decay for training policy network. + weight_decay=1e-4, + # (float) One-order Momentum in optimizer, which stabilizes the training process (gradient direction). + momentum=0.9, + # (float) The maximum constraint value of gradient norm clipping. + grad_clip_value=5, + # (int) The number of episodes in each collecting stage when use muzero_collector. + n_episode=8, + # (int) The number of num_segments in each collecting stage when use muzero_segment_collector. + num_segments=8, + # # (int) the number of simulations in MCTS for renalyze. + num_simulations=50, + # (int) The number of simulations in MCTS for the collect phase. + collect_num_simulations=25, + # (int) The number of simulations in MCTS for the eval phase. + eval_num_simulations=50, + # (float) Discount factor (gamma) for returns. + discount_factor=0.997, + # (int) The number of steps for calculating target q_value. 
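A worked example of the `update_per_collect` rule described above, for the case where it is left as `None` and derived from the newly collected transitions and `replay_ratio`; the numbers are illustrative:

```python
# Newly collected transitions in one collection phase, e.g. 8 envs * 400-step segments.
collector_env_num = 8
game_segment_length = 400
collected_transitions_num = collector_env_num * game_segment_length  # 3200

replay_ratio = 0.25
# With update_per_collect=None, the number of gradient updates per collection
# is derived from the collected data volume and the replay ratio.
update_per_collect = int(collected_transitions_num * replay_ratio)
print(update_per_collect)  # -> 800
```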
+ td_steps=5, + # (int) The number of unroll steps in dynamics network. + num_unroll_steps=10, + # (float) The weight of reward loss. + reward_loss_weight=1, + # (float) The weight of value loss. + value_loss_weight=0.25, + # (float) The weight of policy loss. + policy_loss_weight=1, + # (float) The weight of ssl (self-supervised learning) loss. + ssl_loss_weight=0, + cos_lr_scheduler=False, + piecewise_decay_lr_scheduler=False, + # (bool) Whether to use piecewise constant learning rate decay. + # i.e. lr: 0.2 -> 0.02 -> 0.002 + lr_piecewise_constant_decay=False, + # (int) The number of final training iterations to control lr decay, which is only used for manually decay. + threshold_training_steps_for_final_lr=int(5e4), + # (bool) Whether to use manually decayed temperature. + manual_temperature_decay=False, + # (int) The number of final training iterations to control temperature, which is only used for manually decay. + threshold_training_steps_for_final_temperature=int(1e5), + # (float) The fixed temperature value for MCTS action selection, which is used to control the exploration. + # The larger the value, the more exploration. This value is only used when manual_temperature_decay=False. + fixed_temperature_value=0.25, + # (bool) Whether to use the true chance in MCTS in some environments with stochastic dynamics, such as 2048. + use_ture_chance_label_in_chance_encoder=False, + + # ****** Priority ****** + # (bool) Whether to use priority when sampling training data from the buffer. + use_priority=False, + # (float) The degree of prioritization to use. A value of 0 means no prioritization, + # while a value of 1 means full prioritization. + priority_prob_alpha=0.6, + # (float) The degree of correction to use. A value of 0 means no correction, + # while a value of 1 means full correction. + priority_prob_beta=0.4, + # (int) The initial Env Steps for training. + train_start_after_envsteps=int(0), + + # ****** UCB ****** + # (float) The alpha value used in the Dirichlet distribution for exploration at the root node of search tree. + root_dirichlet_alpha=0.3, + # (float) The noise weight at the root node of the search tree. + root_noise_weight=0.25, + # (bool) Whether to use allocated batch size. + allocated_batch_sizes=False, + + # ****** Explore by random collect ****** + # (int) The number of episodes to collect data randomly before training. + random_collect_episode_num=0, + + # ****** Explore by eps greedy ****** + eps=dict( + # (bool) Whether to use eps greedy exploration in collecting data. + eps_greedy_exploration_in_collect=False, + # (str) The type of decaying epsilon. Options are 'linear', 'exp'. + type='linear', + # (float) The start value of eps. + start=1., + # (float) The end value of eps. + end=0.05, + # (int) The decay steps from start to end eps. + decay=int(1e5), + ), + ) + + def default_model(self) -> Tuple[str, List[str]]: + """ + Overview: + Return this algorithm's default model setting for demonstration. + Returns: + - model_info (:obj:`Tuple[str, List[str]]`): A tuple containing the model name and a list of import paths. + - model_type (:obj:`str`): The model type used in this algorithm, registered in ModelRegistry. + - import_names (:obj:`List[str]`): The list of model class paths used in this algorithm. + .. note:: + Users can define and use customized network models, but they must adhere to the same interface definition + as indicated by the import_names path. For multi-task UniZero, this is ``lzero.model.unizero_model_multitask.UniZeroMTModel``. 
+ """ + # NOTE: This specifies the default multi-task model. + return 'UniZeroMTModel', ['lzero.model.unizero_model_multitask'] + + # ==================== Model Norm Monitoring Function ==================== + def _monitor_model_norms(self) -> Dict[str, float]: + """ + Overview: + Calculate and return parameter matrix norms for key model components (Encoder, Transformer, Heads). + This function should be called within a torch.no_grad() context for efficiency. + Returns: + - norm_metrics (:obj:`Dict[str, float]`): Dictionary containing all norm metrics for logging. + """ + world_model = self._learn_model.world_model + norm_metrics = {} + + # Define module groups to monitor + module_groups = { + 'encoder': world_model.tokenizer.encoder, + 'transformer': world_model.transformer, + 'head_value': world_model.head_value_multi_task, # Note: multi-task uses head_value (plural) + 'head_reward': world_model.head_rewards_multi_task, + 'head_policy': world_model.head_policy_multi_task, # Note: multi-task uses head_policies (plural) + } + + for group_name, group_module in module_groups.items(): + # Handle ModuleList (for multi-task heads) + if isinstance(group_module, torch.nn.ModuleList): + for task_idx, task_module in enumerate(group_module): + total_norm_sq = 0.0 + for param_name, param in task_module.named_parameters(): + if param.requires_grad: + param_norm = param.data.norm(2).item() + log_name = f'norm/{group_name}_task{task_idx}/{param_name.replace(".", "/")}' + norm_metrics[log_name] = param_norm + total_norm_sq += param_norm ** 2 + total_group_norm = np.sqrt(total_norm_sq) + norm_metrics[f'norm/{group_name}_task{task_idx}/_total_norm'] = total_group_norm + else: + # Handle single module + total_norm_sq = 0.0 + for param_name, param in group_module.named_parameters(): + if param.requires_grad: + param_norm = param.data.norm(2).item() + log_name = f'norm/{group_name}/{param_name.replace(".", "/")}' + norm_metrics[log_name] = param_norm + total_norm_sq += param_norm ** 2 + total_group_norm = np.sqrt(total_norm_sq) + norm_metrics[f'norm/{group_name}/_total_norm'] = total_group_norm + + return norm_metrics + + def _monitor_gradient_norms(self) -> Dict[str, float]: + """ + Overview: + Calculate and return gradient norms for key model components. + This function should be called after gradient computation and before parameter updates. + Returns: + - grad_metrics (:obj:`Dict[str, float]`): Dictionary containing all gradient norm metrics for logging. 
+ """ + world_model = self._learn_model.world_model + grad_metrics = {} + + # Define module groups to monitor + module_groups = { + 'encoder': world_model.tokenizer.encoder, + 'transformer': world_model.transformer, + 'head_value': world_model.head_value_multi_task, + 'head_reward': world_model.head_rewards_multi_task, + 'head_policy': world_model.head_policy_multi_task, + } + + for group_name, group_module in module_groups.items(): + # Handle ModuleList (for multi-task heads) + if isinstance(group_module, torch.nn.ModuleList): + for task_idx, task_module in enumerate(group_module): + total_grad_norm_sq = 0.0 + num_params_with_grad = 0 + for param_name, param in task_module.named_parameters(): + if param.requires_grad and param.grad is not None: + grad_norm = param.grad.data.norm(2).item() + log_name = f'grad/{group_name}_task{task_idx}/{param_name.replace(".", "/")}' + grad_metrics[log_name] = grad_norm + total_grad_norm_sq += grad_norm ** 2 + num_params_with_grad += 1 + if num_params_with_grad > 0: + total_group_grad_norm = np.sqrt(total_grad_norm_sq) + grad_metrics[f'grad/{group_name}_task{task_idx}/_total_norm'] = total_group_grad_norm + else: + grad_metrics[f'grad/{group_name}_task{task_idx}/_total_norm'] = 0.0 + else: + # Handle single module + total_grad_norm_sq = 0.0 + num_params_with_grad = 0 + for param_name, param in group_module.named_parameters(): + if param.requires_grad and param.grad is not None: + grad_norm = param.grad.data.norm(2).item() + log_name = f'grad/{group_name}/{param_name.replace(".", "/")}' + grad_metrics[log_name] = grad_norm + total_grad_norm_sq += grad_norm ** 2 + num_params_with_grad += 1 + if num_params_with_grad > 0: + total_group_grad_norm = np.sqrt(total_grad_norm_sq) + grad_metrics[f'grad/{group_name}/_total_norm'] = total_group_grad_norm + else: + grad_metrics[f'grad/{group_name}/_total_norm'] = 0.0 + + return grad_metrics + # ================================================================= + + def _init_learn(self) -> None: + """ + Overview: + Initializes the learn mode. This method is called by ``self.__init__``. + It sets up the learn model, optimizer, target model, and other utilities required for training. 
+ """ + if self._cfg.optim_type == 'SGD': + # Configure SGD optimizer + self._optimizer_world_model = torch.optim.SGD( + self._model.world_model.parameters(), + lr=self._cfg.learning_rate, + momentum=self._cfg.momentum, + weight_decay=self._cfg.weight_decay + ) + elif self._cfg.optim_type == 'AdamW': + # NOTE: nanoGPT optimizer + self._optimizer_world_model = configure_optimizers_nanogpt( + model=self._model.world_model, + learning_rate=self._cfg.learning_rate, + weight_decay=self._cfg.weight_decay, + device_type=self._cfg.device, + betas=(0.9, 0.95), + ) + elif self._cfg.optim_type == 'AdamW_mix_lr_wdecay': + self._optimizer_world_model = configure_optimizer_unizero( + model=self._model.world_model, + learning_rate=self._cfg.learning_rate, + weight_decay=self._cfg.weight_decay, + device_type=self._cfg.device, + betas=(0.9, 0.95), + ) + + if self._cfg.cos_lr_scheduler: + from torch.optim.lr_scheduler import CosineAnnealingLR + total_iters = self._cfg.total_iterations # 500k iter + final_lr = self._cfg.final_learning_rate + + self.lr_scheduler = CosineAnnealingLR( + self._optimizer_world_model, + T_max=total_iters, + eta_min=final_lr + ) + print(f"CosineAnnealingLR enabled: T_max={total_iters}, eta_min={final_lr}") + + + if self._cfg.piecewise_decay_lr_scheduler: + from torch.optim.lr_scheduler import LambdaLR + max_step = self._cfg.threshold_training_steps_for_final_lr + # NOTE: the 1, 0.1, 0.01 is the decay rate, not the lr. + lr_lambda = lambda step: 1 if step < max_step * 0.5 else (0.1 if step < max_step else 0.01) # noqa + self.lr_scheduler = LambdaLR(self._optimizer_world_model, lr_lambda=lr_lambda) + + + # Use a deep copy for the target model. + self._target_model = copy.deepcopy(self._model) + # Ensure that the installed torch version is >= 2.0 for torch.compile. + assert int(''.join(filter(str.isdigit, torch.__version__))) >= 200, "We need torch version >= 2.0" + self._model = torch.compile(self._model) + self._target_model = torch.compile(self._target_model) + + # Wrap the target model for soft updates (momentum-based). + self._target_model = model_wrap( + self._target_model, + wrapper_name='target', + update_type='momentum', + update_kwargs={'theta': self._cfg.target_update_theta} + ) + self._learn_model = self._model + + if self._cfg.use_augmentation: + self.image_transforms = ImageTransforms( + self._cfg.augmentation, + image_shape=(self._cfg.model.observation_shape[1], self._cfg.model.observation_shape[2]) + ) + + self.value_support = DiscreteSupport(*self._cfg.model.value_support_range, self._cfg.device) + self.reward_support = DiscreteSupport(*self._cfg.model.reward_support_range, self._cfg.device) + self.value_inverse_scalar_transform_handle = InverseScalarTransform(self.value_support, self._cfg.model.categorical_distribution) + self.reward_inverse_scalar_transform_handle = InverseScalarTransform(self.reward_support, self._cfg.model.categorical_distribution) + + self.intermediate_losses = defaultdict(float) + self.l2_norm_before = 0. + self.l2_norm_after = 0. + self.grad_norm_before = 0. + self.grad_norm_after = 0. + + + + self.task_id = self._cfg.task_id + self.task_num_for_current_rank = self._cfg.task_num + + print(f'self._cfg.only_use_moco_stats:{self._cfg.only_use_moco_stats}') + if self._cfg.use_moco or self._cfg.only_use_moco_stats: + # The prediction heads' gradients are not corrected. + self.wrapped_model = WrappedModel( + # TODO: This assumes the tokenizer has an encoder attribute which is a list. This might need to be more robust. 
+ self._learn_model.world_model.tokenizer.encoder[0], + self._learn_model.world_model.transformer, + self._learn_model.world_model.pos_emb, + self._learn_model.world_model.task_emb, + self._learn_model.world_model.act_embedding_table, + ) + + # Pass the wrapped_model as `shared_module` to the gradient correction method. + # ========= Initialize MoCo/CAGrad parameters ========= + if self._cfg.moco_version=="v0": + # This version is only compatible with single-GPU training. + self.grad_correct = GradCorrect(self.wrapped_model, self._cfg.total_task_num, self._cfg.device, self._cfg.multi_gpu) + self.grad_correct.init_param() + self.grad_correct.rep_grad = False + elif self._cfg.moco_version=="v1": + cfg_moco = MoCoCfg( + beta0=0.9, beta_sigma=0.95, + gamma0=0.1, gamma_sigma=0.95, + rho=0.01, stat_interval=10000) + self.grad_correct = FastMoCo( + shared_module=self.wrapped_model, + world_task_num=self._cfg.total_task_num, # Total number of tasks globally + device=self._cfg.device, + multi_gpu=self._cfg.multi_gpu, + cfg=cfg_moco, + ) + + # Cache for plasticity-related metrics from the previous frame. + self._prev_plasticity_metrics = dict( + dormant_ratio_encoder = 0.0, + dormant_ratio_transformer = 0.0, + dormant_ratio_head = 0.0, + avg_weight_mag_encoder = 0.0, + avg_weight_mag_transformer = 0.0, + avg_weight_mag_head = 0.0, + e_rank_last_linear = 0.0, + e_rank_sim_norm = 0.0, + ) + + # ==================== START: Target Entropy Regularization Initialization ==================== + # Read whether to enable adaptive alpha from config, and provide a default value + self.use_adaptive_entropy_weight = self._cfg.use_adaptive_entropy_weight + + # Add configuration in _init_learn + self.target_entropy_start_ratio = self._cfg.target_entropy_start_ratio + self.target_entropy_end_ratio = self._cfg.target_entropy_end_ratio + self.target_entropy_decay_steps = self._cfg.target_entropy_decay_steps # e.g., complete annealing within 200k steps (2M envsteps) + + if self.use_adaptive_entropy_weight: + # 1. Set target entropy. For discrete action spaces, a common heuristic is the negative logarithm + # of action space dimension multiplied by a coefficient. + # This coefficient (e.g., 0.98) can be used as a hyperparameter. + action_space_size = self._cfg.model.action_space_size + self.target_entropy = -np.log(1.0 / action_space_size) * 0.98 + + # 2. Initialize a learnable log_alpha parameter. + # Initialized to 0, meaning initial alpha = exp(0) = 1.0. + self.log_alpha = torch.nn.Parameter(torch.zeros(1, device=self._cfg.device), requires_grad=True) + + # 3. Create a dedicated optimizer for log_alpha. + # Using a smaller learning rate (e.g., 1e-4) different from the main optimizer is usually more stable. 
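+ # NOTE: log_alpha itself is updated in ``_forward_learn`` by minimizing roughly
+ # ``log_alpha * (policy_entropy.detach() - target_entropy)``: when the policy entropy falls below the
+ # target, alpha = exp(log_alpha) is pushed up to strengthen the entropy bonus, and vice versa.
+ # E.g. with action_space_size=18 and ratio 0.98, the initial target entropy is about 0.98 * ln(18) ≈ 2.83.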
+ alpha_lr = self._cfg.adaptive_entropy_alpha_lr + self.alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=alpha_lr) + + print("="*20) + print(">>> Target Entropy Regularization (Adaptive Alpha) Enabled <<<") + print(f" Target Entropy: {self.target_entropy:.4f}") + print(f" Alpha Optimizer Learning Rate: {alpha_lr:.2e}") + print("="*20) + # ===================== END: Target Entropy Regularization Initialization ===================== + + self.latent_norm_clip_threshold = self._cfg.latent_norm_clip_threshold + # ==================== START: Initialize Encoder-Clip Annealing Parameters ==================== + self.use_encoder_clip_annealing = self._cfg.use_encoder_clip_annealing + if self.use_encoder_clip_annealing: + self.encoder_clip_anneal_type = self._cfg.encoder_clip_anneal_type + self.encoder_clip_start = self._cfg.encoder_clip_start_value + self.encoder_clip_end = self._cfg.encoder_clip_end_value + self.encoder_clip_anneal_steps = self._cfg.encoder_clip_anneal_steps + + print("="*20) + print(">>> Encoder-Clip Annealing Enabled <<<") + print(f" Type: {self.encoder_clip_anneal_type}") + print(f" Range: {self.encoder_clip_start} -> {self.encoder_clip_end}") + print(f" Steps: {self.encoder_clip_anneal_steps}") + print("="*20) + else: + # If annealing is not enabled, use a fixed clip threshold + self.latent_norm_clip_threshold = self._cfg.latent_norm_clip_threshold + # ===================== END: Initialize Encoder-Clip Annealing Parameters ===================== + + # Policy Label Smoothing Parameters + self.policy_ls_eps_start = self._cfg.policy_ls_eps_start # TODO policy_label_smoothing_eps_start: larger action space requires larger eps + self.policy_ls_eps_end = self._cfg.policy_ls_eps_end # TODO policy_label_smoothing_eps_start + self.policy_ls_eps_decay_steps = self._cfg.policy_ls_eps_decay_steps # TODO 50k + print(f"self.policy_ls_eps_start:{self.policy_ls_eps_start}") + + @staticmethod + def _is_zero(x: Union[float, torch.Tensor], eps: float = 1e-8) -> bool: + """ + Overview: + Checks if a scalar or a 0-D tensor can be considered zero within a small tolerance. + Arguments: + - x (:obj:`Union[float, torch.Tensor]`): The input value to check. + - eps (:obj:`float`): The tolerance for checking against zero. + Returns: + - (:obj:`bool`): True if the value is close to zero, False otherwise. + """ + if isinstance(x, torch.Tensor): + return torch.all(torch.abs(x) < eps).item() + return abs(x) < eps + + def _retain_prev_if_zero(self, name: str, + value: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]: + """ + Overview: + If the current `value` is close to zero, returns the cached value from the previous frame. + Otherwise, it updates the cache with the current value and returns it. This is useful for + metrics that are computed intermittently. + Arguments: + - name (:obj:`str`): The name of the metric to cache. + - value (:obj:`Union[float, torch.Tensor]`): The current value of the metric. + Returns: + - (:obj:`Union[float, torch.Tensor]`): The retained or current value. + """ + if self._is_zero(value): + # Directly return the previous value (can be float or tensor). + return self._prev_plasticity_metrics[name] + else: + # Update the cache and return the current value. + self._prev_plasticity_metrics[name] = value + return value + + + def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None, train_iter=None, ignore_grad=False) -> Dict[str, Union[float, int]]: + """ + Overview: + The forward function for learning in the policy. 
This is the core of the training process. + Data is sampled from the replay buffer, losses are calculated, and the model is updated via backpropagation. + Arguments: + - data (:obj:`Tuple[torch.Tensor]`): A tuple of data batches, where each element corresponds to a different task. + - task_weights (:obj:`Any`, optional): Optional weights for each task's loss. Not currently used. + - ignore_grad (:obj:`bool`): If True, gradients are zeroed out after computation, effectively skipping the update. + Returns: + - info_dict (:obj:`Dict[str, Union[float, int]]`): A dictionary containing current learning losses and statistics for logging. + """ + self._learn_model.train() + self._target_model.train() + + # Lists to store metrics for each task within the batch. + obs_loss_multi_task = [] + reward_loss_multi_task = [] + policy_loss_multi_task = [] + value_loss_multi_task = [] + latent_recon_loss_multi_task = [] + perceptual_loss_multi_task = [] + orig_policy_loss_multi_task = [] + policy_entropy_multi_task = [] + weighted_total_loss = 0.0 # Initialize to 0.0 to avoid in-place operations. + total_alpha_loss = 0.0 + + latent_state_l2_norms_multi_task = [] + average_target_policy_entropy_multi_task = [] + value_priority_multi_task = [] + value_priority_mean_multi_task = [] + + # Metrics for network plasticity analysis. + dormant_ratio_encoder_multi_task = [] + dormant_ratio_transformer_multi_task = [] + dormant_ratio_head_multi_task = [] + avg_weight_mag_encoder_multi_task = [] + avg_weight_mag_transformer_multi_task = [] + avg_weight_mag_head_multi_task = [] + e_rank_last_linear_multi_task = [] + e_rank_sim_norm_multi_task = [] + + current_policy_label_eps = 0.01 + + # Add a list to collect real global IDs of all tasks in the current batch + global_task_ids_in_batch = [] + alpha_loss = None + + + # New lists for Alpha logging + alpha_loss_multi_task = [] + target_entropy_multi_task = [] + + # Pre-fetch current alpha value only when adaptive alpha is enabled, ensuring consistency across all tasks in a single iteration + current_alpha = self._cfg.model.world_model_cfg.policy_entropy_weight + if self.use_adaptive_entropy_weight: + current_alpha = self.log_alpha.exp().detach() + + losses_list = [] # Used to store the loss tensor for each task, required by gradient correction methods. + for task_id, data_one_task in enumerate(data): + current_batch, target_batch, task_id = data_one_task # task_id is the real global ID + + # Add the real global ID to the list + global_task_ids_in_batch.append(task_id) + + # TODO: Adapt RoPE for multitask settings (using timestep_batch). + obs_batch_ori, action_batch, target_action_batch, mask_batch, indices, weights, make_time, timestep_batch = current_batch + target_reward, target_value, target_policy = target_batch + + # Prepare observations based on frame stack number. + if self._cfg.model.frame_stack_num == 4: + obs_batch, obs_target_batch = prepare_obs_stack_for_unizero(obs_batch_ori, self._cfg) + else: + obs_batch, obs_target_batch = prepare_obs(obs_batch_ori, self._cfg) + + # Apply augmentations if needed. + if self._cfg.use_augmentation: + obs_batch = self.image_transforms.transform(obs_batch) + if self._cfg.model.self_supervised_learning_loss: + obs_target_batch = self.image_transforms.transform(obs_target_batch) + + # Prepare action batch and convert to a torch tensor. + action_batch = torch.from_numpy(action_batch).to(self._cfg.device).unsqueeze( + -1).long() # For discrete action space. 
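+ # NOTE: The block below converts the numpy targets to tensors and then applies the usual MuZero-style
+ # value/reward treatment: ``scalar_transform`` rescales the scalar targets (approximately
+ # sign(x) * (sqrt(|x| + 1) - 1) + eps * x), and ``phi_transform`` projects them onto the categorical
+ # support used by the value/reward heads (optionally with label smoothing).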
+ data_list = [mask_batch, target_reward.astype('float32'), target_value.astype('float32'), target_policy, + weights] + mask_batch, target_reward, target_value, target_policy, weights = to_torch_float_tensor(data_list, + self._cfg.device) + + cur_batch_size = target_reward.size(0) # Run-time batch size. + + target_reward = target_reward.view(cur_batch_size, -1) + target_value = target_value.view(cur_batch_size, -1) + + # Transform scalar rewards and values to their scaled representations. + transformed_target_reward = scalar_transform(target_reward) + transformed_target_value = scalar_transform(target_value) + + # Convert scaled representations to categorical distributions. + target_reward_categorical = phi_transform(self.reward_support, transformed_target_reward, label_smoothing_eps= self._cfg.label_smoothing_eps) + target_value_categorical = phi_transform(self.value_support, transformed_target_value, label_smoothing_eps=self._cfg.label_smoothing_eps) + + + # Prepare the batch for the transformer-based world model. + batch_for_gpt = {} + if isinstance(self._cfg.model.observation_shape, int) or len(self._cfg.model.observation_shape) == 1: + batch_for_gpt['observations'] = torch.cat((obs_batch, obs_target_batch), dim=1).reshape( + cur_batch_size, -1, self._cfg.model.observation_shape) + elif len(self._cfg.model.observation_shape) == 3: + batch_for_gpt['observations'] = torch.cat((obs_batch, obs_target_batch), dim=1).reshape( + cur_batch_size, -1, *self._cfg.model.observation_shape) + + batch_for_gpt['actions'] = action_batch.squeeze(-1) + batch_for_gpt['rewards'] = target_reward_categorical[:, :-1] + batch_for_gpt['mask_padding'] = mask_batch == 1.0 # 0 means invalid padding data. + batch_for_gpt['mask_padding'] = batch_for_gpt['mask_padding'][:, :-1] + batch_for_gpt['observations'] = batch_for_gpt['observations'][:, :-1] + batch_for_gpt['ends'] = torch.zeros(batch_for_gpt['mask_padding'].shape, dtype=torch.long, + device=self._cfg.device) + batch_for_gpt['target_value'] = target_value_categorical[:, :-1] + batch_for_gpt['target_policy'] = target_policy[:, :-1] + batch_for_gpt['scalar_target_value'] = target_value + + # Extract valid target policy data and compute its entropy. + valid_target_policy = batch_for_gpt['target_policy'][batch_for_gpt['mask_padding']] + target_policy_entropy = -torch.sum(valid_target_policy * torch.log(valid_target_policy + 1e-9), dim=-1) + average_target_policy_entropy = target_policy_entropy.mean().item() + + # Update world model and compute losses. + intermediate_losses = defaultdict(float) + + losses = self._learn_model.world_model.compute_loss( + batch_for_gpt, self._target_model.world_model.tokenizer, self.value_inverse_scalar_transform_handle, current_policy_label_eps=current_policy_label_eps, task_id=task_id + ) + + # Extract the calculated value_priority from the returned losses. + value_priority_tensor = losses.intermediate_losses['value_priority'] + # Convert to numpy array for the replay buffer, adding a small epsilon. + value_priority_np = value_priority_tensor.detach().cpu().numpy() + 1e-6 + + + # TODO: Accumulate the weighted total loss. This assumes the loss from `compute_loss` is already weighted. + weighted_total_loss += losses.loss_total # NOTE:+= + + # TODO: Append the total loss for this task, used by MoCo. 
+ losses_list.append(losses.loss_total) + + for loss_name, loss_value in losses.intermediate_losses.items(): + intermediate_losses[f"{loss_name}"] = loss_value + + + + obs_loss = intermediate_losses['loss_obs'] + reward_loss = intermediate_losses['loss_rewards'] + policy_loss = intermediate_losses['loss_policy'] + orig_policy_loss = intermediate_losses['orig_policy_loss'] + policy_entropy = intermediate_losses['policy_entropy'] + value_loss = intermediate_losses['loss_value'] + latent_recon_loss = intermediate_losses['latent_recon_loss'] + perceptual_loss = intermediate_losses['perceptual_loss'] + latent_state_l2_norms = intermediate_losses['latent_state_l2_norms'] + + # ==================== START: Target Entropy Regularization Update Logic ==================== + current_alpha = self._cfg.model.world_model_cfg.policy_entropy_weight # Default to fixed value + if self.use_adaptive_entropy_weight: + + # Dynamically calculate target entropy (this logic is correct and preserved) + progress = min(1.0, train_iter / self.target_entropy_decay_steps) + current_ratio = self.target_entropy_start_ratio * (1 - progress) + self.target_entropy_end_ratio * progress + action_space_size = self._cfg.model.action_space_size + # Note: We define target_entropy as a positive number, which is more intuitive + current_target_entropy = -np.log(1.0 / action_space_size) * current_ratio + + # Calculate alpha_loss (corrected sign) + # This is the core correction: removed the negative sign at the front + # detach() is still critical to ensure alpha_loss gradient only flows to log_alpha + alpha_loss_task = (self.log_alpha * (policy_entropy.detach() - current_target_entropy)).mean() # NOTE:= + + + # Accumulate alpha_loss + total_alpha_loss += alpha_loss_task + # Collect each task's alpha_loss and target_entropy for logging + alpha_loss_multi_task.append(alpha_loss_task) + target_entropy_multi_task.append(current_target_entropy) + + # [Optimization suggestion] Add log_alpha clipping as a safety measure + with torch.no_grad(): + # Limit alpha to a range, e.g., [1e-4, 10.0] + self.log_alpha.clamp_(np.log(5e-3), np.log(10.0)) + + + # Use current updated alpha (with gradient flow truncated) + current_alpha = self.log_alpha.exp().detach() + + # Recalculate weighted policy loss and total loss + # Note: policy_entropy here is already an average value of a batch + weighted_policy_loss = orig_policy_loss - current_alpha * policy_entropy + # Rebuild total loss (not using losses.loss_total) + # Ensure the weights here are consistent with the calculation in LossWithIntermediateLosses class + self.obs_loss_weight = 10 + self.value_loss_weight = 0.5 + self.reward_loss_weight = 1. + self.policy_loss_weight = 1. + self.ends_loss_weight = 0. + total_loss = ( + self.reward_loss_weight * reward_loss + + self.value_loss_weight * value_loss + + self.policy_loss_weight * weighted_policy_loss + + self.obs_loss_weight * obs_loss + ) + weighted_total_loss += (weights * total_loss).mean() # NOTE:+= + # ===================== END: Target Entropy Regularization Update Logic ===================== + + # Metrics related to network plasticity. + # Use the helper function to retain the previous value if the current one is zero. 
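+ # These plasticity diagnostics (dormant-neuron ratios, average weight magnitudes, effective ranks) are
+ # typically only computed by the world model on some iterations; on the remaining iterations they come
+ # back as zeros, so ``_retain_prev_if_zero`` substitutes the cached values from the last frame in which
+ # they were actually measured.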
+ dormant_ratio_encoder = self._retain_prev_if_zero( + 'dormant_ratio_encoder', + intermediate_losses['dormant_ratio_encoder']) + dormant_ratio_transformer = self._retain_prev_if_zero( + 'dormant_ratio_transformer', + intermediate_losses['dormant_ratio_transformer']) + dormant_ratio_head = self._retain_prev_if_zero( + 'dormant_ratio_head', + intermediate_losses['dormant_ratio_head']) + avg_weight_mag_encoder = self._retain_prev_if_zero( + 'avg_weight_mag_encoder', + intermediate_losses['avg_weight_mag_encoder']) + avg_weight_mag_transformer = self._retain_prev_if_zero( + 'avg_weight_mag_transformer', + intermediate_losses['avg_weight_mag_transformer']) + avg_weight_mag_head = self._retain_prev_if_zero( + 'avg_weight_mag_head', + intermediate_losses['avg_weight_mag_head']) + e_rank_last_linear = self._retain_prev_if_zero( + 'e_rank_last_linear', + intermediate_losses['e_rank_last_linear']) + e_rank_sim_norm = self._retain_prev_if_zero( + 'e_rank_sim_norm', + intermediate_losses['e_rank_sim_norm']) + + # Append all metrics for this task to their respective lists. + obs_loss_multi_task.append(obs_loss) + reward_loss_multi_task.append(reward_loss) + policy_loss_multi_task.append(policy_loss) + orig_policy_loss_multi_task.append(orig_policy_loss) + policy_entropy_multi_task.append(policy_entropy) + value_loss_multi_task.append(value_loss) + latent_recon_loss_multi_task.append(latent_recon_loss) + perceptual_loss_multi_task.append(perceptual_loss) + latent_state_l2_norms_multi_task.append(latent_state_l2_norms) + value_priority_multi_task.append(value_priority_tensor) + value_priority_mean_multi_task.append(value_priority_tensor.mean().item()) + + # Append plasticity metrics. + dormant_ratio_encoder_multi_task.append(dormant_ratio_encoder) + dormant_ratio_transformer_multi_task.append(dormant_ratio_transformer) + dormant_ratio_head_multi_task.append(dormant_ratio_head) + avg_weight_mag_encoder_multi_task.append(avg_weight_mag_encoder) + avg_weight_mag_transformer_multi_task.append(avg_weight_mag_transformer) + avg_weight_mag_head_multi_task.append(avg_weight_mag_head) + e_rank_last_linear_multi_task.append(e_rank_last_linear) + e_rank_sim_norm_multi_task.append(e_rank_sim_norm) + + + # ==================== Integrate norm monitoring logic ==================== + norm_log_dict = {} + # Check if monitoring frequency is reached + if self._cfg.monitor_norm_freq > 0 and (train_iter == 0 or (train_iter % self._cfg.monitor_norm_freq == 0)): + with torch.no_grad(): + # 1. Monitor model parameter norms + param_norm_metrics = self._monitor_model_norms() + norm_log_dict.update(param_norm_metrics) + + # 2. Monitor intermediate tensor x (Transformer output) + intermediate_x = losses.intermediate_losses.get('intermediate_tensor_x') + if intermediate_x is not None: + # x shape is (B, T, E) + # Calculate L2 norm for each token + token_norms = intermediate_x.norm(p=2, dim=-1) + + # Record statistics of these norms + norm_log_dict['norm/x_token/mean'] = token_norms.mean().item() + norm_log_dict['norm/x_token/std'] = token_norms.std().item() + norm_log_dict['norm/x_token/max'] = token_norms.max().item() + norm_log_dict['norm/x_token/min'] = token_norms.min().item() + + # 3. 
Monitor detailed statistics of logits (Value, Policy, Reward) + logits_value = losses.intermediate_losses.get('logits_value') + if logits_value is not None: + norm_log_dict['logits/value/mean'] = logits_value.mean().item() + norm_log_dict['logits/value/std'] = logits_value.std().item() + norm_log_dict['logits/value/max'] = logits_value.max().item() + norm_log_dict['logits/value/min'] = logits_value.min().item() + norm_log_dict['logits/value/abs_max'] = logits_value.abs().max().item() + + logits_policy = losses.intermediate_losses.get('logits_policy') + if logits_policy is not None: + norm_log_dict['logits/policy/mean'] = logits_policy.mean().item() + norm_log_dict['logits/policy/std'] = logits_policy.std().item() + norm_log_dict['logits/policy/max'] = logits_policy.max().item() + norm_log_dict['logits/policy/min'] = logits_policy.min().item() + norm_log_dict['logits/policy/abs_max'] = logits_policy.abs().max().item() + + logits_reward = losses.intermediate_losses.get('logits_reward') + if logits_reward is not None: + norm_log_dict['logits/reward/mean'] = logits_reward.mean().item() + norm_log_dict['logits/reward/std'] = logits_reward.std().item() + norm_log_dict['logits/reward/max'] = logits_reward.max().item() + norm_log_dict['logits/reward/min'] = logits_reward.min().item() + norm_log_dict['logits/reward/abs_max'] = logits_reward.abs().max().item() + + # 4. Monitor obs_embeddings (Encoder output) statistics + obs_embeddings = losses.intermediate_losses.get('obs_embeddings') + if obs_embeddings is not None: + # Calculate L2 norm for each embedding + emb_norms = obs_embeddings.norm(p=2, dim=-1) + norm_log_dict['embeddings/obs/norm_mean'] = emb_norms.mean().item() + norm_log_dict['embeddings/obs/norm_std'] = emb_norms.std().item() + norm_log_dict['embeddings/obs/norm_max'] = emb_norms.max().item() + norm_log_dict['embeddings/obs/norm_min'] = emb_norms.min().item() + # ================================================================= + + # Core learn model update step. + self._optimizer_world_model.zero_grad() + + if self.use_adaptive_entropy_weight: + self.alpha_optimizer.zero_grad() + # 2. Calculate final alpha loss (average after accumulation) + final_alpha_loss = None + if self.use_adaptive_entropy_weight: + if len(data) > 0: + final_alpha_loss = total_alpha_loss / len(data) + else: # Defensive programming to avoid division by zero + final_alpha_loss = torch.tensor(0.0, device=self._cfg.device) + + # Assuming losses_list is a list of tensors with gradients, e.g., [loss1, loss2, ...]. + if self._cfg.use_moco: + # Call MoCo's backward method, which handles gradient correction internally. + if self._cfg.moco_version=="v0": + lambd, stats = self.grad_correct.backward(losses=losses_list, **self._cfg.grad_correct_params) + elif self._cfg.moco_version=="v1": + lambd, stats = self.grad_correct.backward(losses_list) + + # Separate backward pass for alpha loss + if self.use_adaptive_entropy_weight: + final_alpha_loss.backward() + + elif self._cfg.only_use_moco_stats: + # Only compute MoCo stats without applying gradient correction. 
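+ # NOTE: Per the comment above, this call is intended only to obtain MoCo statistics (e.g. the per-task
+ # lambd values) for logging; the gradients that drive the parameter update come from the standard
+ # backward pass on ``weighted_total_loss`` below.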
+ lambd, stats = self.grad_correct.backward(losses=losses_list, **self._cfg.grad_correct_params) + + # If adaptive alpha is enabled, add alpha loss to main loss for joint backward pass + if self.use_adaptive_entropy_weight: + (weighted_total_loss + final_alpha_loss).backward() + elif weighted_total_loss != 0.0: # Ensure there is loss to backpropagate + weighted_total_loss.backward() + + else: + # If not using gradient correction, each rank performs standard backpropagation. + lambd = torch.tensor([0. for _ in range(self.task_num_for_current_rank)], device=self._cfg.device) + + # If adaptive alpha is enabled, add alpha loss to main loss for joint backward pass + if self.use_adaptive_entropy_weight: + (weighted_total_loss + final_alpha_loss).backward() + elif weighted_total_loss != 0.0: # Ensure there is loss to backpropagate + weighted_total_loss.backward() + + # Still executed within torch.no_grad() context + # ================================================================= + with torch.no_grad(): + # 1. Encoder-Clip + # ==================== START: Dynamically calculate current Clip threshold ==================== + current_clip_value = self.latent_norm_clip_threshold # Default to fixed value + if self.use_encoder_clip_annealing: + progress = min(1.0, train_iter / self.encoder_clip_anneal_steps) + + if self.encoder_clip_anneal_type == 'cosine': + # Cosine schedule: smoothly transition from 1 to 0 + cosine_progress = 0.5 * (1.0 + np.cos(np.pi * progress)) + current_clip_value = self.encoder_clip_end + \ + (self.encoder_clip_start - self.encoder_clip_end) * cosine_progress + else: # Default to linear schedule + current_clip_value = self.encoder_clip_start * (1 - progress) + \ + self.encoder_clip_end * progress + # ===================== END: Dynamically calculate current Clip threshold ===================== + + # 1. Encoder-Clip (using dynamically calculated current_clip_value) + if current_clip_value > 0 and 'obs_embeddings' in losses.intermediate_losses: + obs_embeddings = losses.intermediate_losses['obs_embeddings'] + if obs_embeddings is not None: + max_latent_norm = obs_embeddings.norm(p=2, dim=-1).max() + if max_latent_norm > current_clip_value: + scale_factor = current_clip_value / max_latent_norm.item() + # No longer print frequently, or can be changed to print every N steps + if train_iter % 1000 == 0: + print(f"[Encoder-Clip Annealing] Iter {train_iter}: Max latent norm {max_latent_norm.item():.2f} > {current_clip_value:.2f}. 
Scaling by {scale_factor:.4f}.") + scale_module_weights_vectorized(self._model.world_model.tokenizer.encoder, scale_factor) + + + # ==================== Monitor gradient norms ==================== + # Monitor gradient norms before gradient clipping to diagnose gradient explosion/vanishing issues + if self._cfg.monitor_norm_freq > 0 and (train_iter == 0 or (train_iter % self._cfg.monitor_norm_freq == 0)): + grad_norm_metrics = self._monitor_gradient_norms() + norm_log_dict.update(grad_norm_metrics) + # ================================================================= + + if self._cfg.analysis_sim_norm: + del self.l2_norm_before, self.l2_norm_after, self.grad_norm_before, self.grad_norm_after + self.l2_norm_before, self.l2_norm_after, self.grad_norm_before, self.grad_norm_after = self._learn_model.encoder_hook.analyze() + self._target_model.encoder_hook.clear_data() + + total_grad_norm_before_clip_wm = torch.nn.utils.clip_grad_norm_(self._learn_model.world_model.parameters(), + self._cfg.grad_clip_value) + + if ignore_grad: + # NOTE: For cases where all tasks on a GPU are solved, `train` is still called for DDP synchronization, + # but gradients should be zeroed out to prevent updates. + self._optimizer_world_model.zero_grad() + + if self._cfg.multi_gpu: + # If not using a gradient correction method that handles it, sync gradients manually. + if not self._cfg.use_moco: + self.sync_gradients(self._learn_model) + + self._optimizer_world_model.step() + + # 4. Update Alpha optimizer + if self.use_adaptive_entropy_weight: + self.alpha_optimizer.step() + # Clip log_alpha to ensure stability + with torch.no_grad(): + self.log_alpha.clamp_(np.log(1e-4), np.log(10.0)) + + if self._cfg.cos_lr_scheduler or self._cfg.piecewise_decay_lr_scheduler: + self.lr_scheduler.step() + + # Core target model update step. + self._target_model.update(self._learn_model.state_dict()) + + if torch.cuda.is_available(): + torch.cuda.synchronize() + current_memory_allocated = torch.cuda.memory_allocated() + max_memory_allocated = torch.cuda.max_memory_allocated() + current_memory_allocated_gb = current_memory_allocated / (1024 ** 3) + max_memory_allocated_gb = max_memory_allocated / (1024 ** 3) + else: + current_memory_allocated_gb = 0. + max_memory_allocated_gb = 0. + + # Build the dictionary of return values for logging. + return_log_dict = { + 'Current_GPU': current_memory_allocated_gb, + 'Max_GPU': max_memory_allocated_gb, + 'collect_mcts_temperature': self._collect_mcts_temperature, + 'collect_epsilon': self._collect_epsilon, + 'cur_lr_world_model': self._optimizer_world_model.param_groups[0]['lr'], + 'weighted_total_loss': weighted_total_loss.item(), + 'total_grad_norm_before_clip_wm': total_grad_norm_before_clip_wm.item(), + } + + # ==================== START: Add new log items ==================== + if self.use_adaptive_entropy_weight: + return_log_dict['adaptive_alpha'] = current_alpha.item() + return_log_dict['adaptive_target_entropy_ratio'] = current_ratio + return_log_dict['final_alpha_loss'] = final_alpha_loss.item() + # ===================== END: Add new log items ===================== + + # Generate task-related loss dictionaries and prefix each task-related loss with "noreduce_". 
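+ # ``generate_task_loss_dict`` presumably formats one key per task handled by this rank, using the global
+ # task id as the suffix. For example, with ``self.task_id == 2`` and two tasks on this rank, the obs-loss
+ # entries become ``noreduce_obs_loss_task2`` and ``noreduce_obs_loss_task3``, matching the names
+ # registered in ``_monitor_vars_learn``.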
+ multi_task_loss_dicts = { + **generate_task_loss_dict(obs_loss_multi_task, 'noreduce_obs_loss_task{}', task_id=self.task_id), + **generate_task_loss_dict(latent_recon_loss_multi_task, 'noreduce_latent_recon_loss_task{}', task_id=self.task_id), + **generate_task_loss_dict(perceptual_loss_multi_task, 'noreduce_perceptual_loss_task{}', task_id=self.task_id), + **generate_task_loss_dict(latent_state_l2_norms_multi_task, 'noreduce_latent_state_l2_norms_task{}', task_id=self.task_id), + **generate_task_loss_dict(dormant_ratio_head_multi_task, 'noreduce_dormant_ratio_head_task{}', task_id=self.task_id), + + **generate_task_loss_dict(policy_loss_multi_task, 'noreduce_policy_loss_task{}', task_id=self.task_id), + **generate_task_loss_dict(orig_policy_loss_multi_task, 'noreduce_orig_policy_loss_task{}', task_id=self.task_id), + **generate_task_loss_dict(policy_entropy_multi_task, 'noreduce_policy_entropy_task{}', task_id=self.task_id), + **generate_task_loss_dict(reward_loss_multi_task, 'noreduce_reward_loss_task{}', task_id=self.task_id), + **generate_task_loss_dict(value_loss_multi_task, 'noreduce_value_loss_task{}', task_id=self.task_id), + **generate_task_loss_dict(average_target_policy_entropy_multi_task, 'noreduce_target_policy_entropy_task{}', task_id=self.task_id), + **generate_task_loss_dict(lambd, 'noreduce_lambd_task{}', task_id=self.task_id), + **generate_task_loss_dict(value_priority_multi_task, 'noreduce_value_priority_task{}', task_id=self.task_id), + **generate_task_loss_dict(value_priority_mean_multi_task, 'noreduce_value_priority_mean_task{}', task_id=self.task_id), + + # Add alpha related logs + **generate_task_loss_dict(alpha_loss_multi_task, 'noreduce_alpha_loss_task{}', self.task_id), + **generate_task_loss_dict(target_entropy_multi_task, 'noreduce_target_entropy_task{}', self.task_id), + } + return_log_dict.update(multi_task_loss_dicts) + + + if self._learn_model.world_model.do_analysis: + # Include plasticity metrics if analysis is enabled. + plasticity_loss_dicts = { + **generate_task_loss_dict(dormant_ratio_encoder_multi_task, 'noreduce_dormant_ratio_encoder_task{}', task_id=self.task_id), + **generate_task_loss_dict(dormant_ratio_transformer_multi_task, 'noreduce_dormant_ratio_transformer_task{}', task_id=self.task_id), + **generate_task_loss_dict(dormant_ratio_head_multi_task, 'noreduce_dormant_ratio_head_task{}', task_id=self.task_id), + **generate_task_loss_dict(avg_weight_mag_encoder_multi_task, 'noreduce_avg_weight_mag_encoder_task{}', task_id=self.task_id), + **generate_task_loss_dict(avg_weight_mag_transformer_multi_task, 'noreduce_avg_weight_mag_transformer_task{}', task_id=self.task_id), + **generate_task_loss_dict(avg_weight_mag_head_multi_task, 'noreduce_avg_weight_mag_head_task{}', task_id=self.task_id), + **generate_task_loss_dict(e_rank_last_linear_multi_task, 'noreduce_e_rank_last_linear_task{}', task_id=self.task_id), + **generate_task_loss_dict(e_rank_sim_norm_multi_task, 'noreduce_e_rank_sim_norm_task{}', task_id=self.task_id), + } + # Merge the dictionaries. + return_log_dict.update(plasticity_loss_dicts) + + # Merge norm monitoring results into the log + if norm_log_dict: + return_log_dict.update(norm_log_dict) + + # Return the final loss dictionary. + return return_log_dict + + def monitor_weights_and_grads(self, model: torch.nn.Module) -> None: + """ + Overview: + A utility function to print the mean and standard deviation of weights and their gradients for each layer in a model. 
+ Useful for debugging training issues like exploding or vanishing gradients. + Arguments: + - model (:obj:`torch.nn.Module`): The model to monitor. + """ + for name, param in model.named_parameters(): + if param.requires_grad: + print(f"Layer: {name} | " + f"Weight mean: {param.data.mean():.4f} | " + f"Weight std: {param.data.std():.4f} | " + f"Grad mean: {param.grad.mean():.4f} | " + f"Grad std: {param.grad.std():.4f}") + + def _init_collect(self) -> None: + """ + Overview: + Initializes the collect mode. This method is called by ``self.__init__``. + It sets up the collect model and MCTS utilities for data collection. + """ + self._collect_model = self._model + + # Create a copy of the configuration for collect MCTS and set a specific number of simulations. + mcts_collect_cfg = copy.deepcopy(self._cfg) + mcts_collect_cfg.num_simulations = self._cfg.collect_num_simulations + + if self._cfg.mcts_ctree: + self._mcts_collect = MCTSCtree(mcts_collect_cfg) + else: + self._mcts_collect = MCTSPtree(mcts_collect_cfg) + + self._collect_mcts_temperature = 1. + self._collect_epsilon = 0.0 + self.collector_env_num = self._cfg.collector_env_num + if self._cfg.model.model_type == 'conv': + self.last_batch_obs = torch.zeros([self.collector_env_num, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device) + self.last_batch_action = [-1 for i in range(self.collector_env_num)] + elif self._cfg.model.model_type == 'mlp': + self.last_batch_obs = torch.zeros([self.collector_env_num, self._cfg.model.observation_shape]).to(self._cfg.device) + self.last_batch_action = [-1 for i in range(self.collector_env_num)] + + # TODO: The num_tasks parameter is hardcoded. It should ideally be derived from the config. + def _monitor_vars_learn(self, num_tasks: int = 2) -> List[str]: + """ + Overview: + Registers variables to be monitored during training. These variables will be logged in TensorBoard. + It dynamically creates variable names for each task if `num_tasks` is provided. + Arguments: + - num_tasks (:obj:`int`): The number of tasks being trained on the current rank. + Returns: + - monitored_vars (:obj:`List[str]`): A list of strings, where each string is the name of a variable to be logged. + """ + # Basic monitored variables that do not depend on the number of tasks. 
+ monitored_vars = [ + 'Current_GPU', + 'Max_GPU', + 'collect_epsilon', + 'collect_mcts_temperature', + 'cur_lr_world_model', + 'weighted_total_loss', + 'total_grad_norm_before_clip_wm', + + 'adaptive_alpha', + "adaptive_target_entropy_ratio", + 'final_alpha_loss', + ] + + # ==================== Norm and Intermediate Tensor Monitoring Variables ==================== + # These variables are shared across all tasks (not per-task) + norm_vars = [ + # Module total norms (parameter norms) - shared modules + 'norm/encoder/_total_norm', + 'norm/transformer/_total_norm', + + # Module total norms (gradient norms) - shared modules + 'grad/encoder/_total_norm', + 'grad/transformer/_total_norm', + + # Intermediate tensor x (Transformer output) statistics + 'norm/x_token/mean', + 'norm/x_token/std', + 'norm/x_token/max', + 'norm/x_token/min', + + # Detailed logits statistics (Value) + 'logits/value/mean', + 'logits/value/std', + 'logits/value/max', + 'logits/value/min', + 'logits/value/abs_max', + + # Detailed logits statistics (Policy) + 'logits/policy/mean', + 'logits/policy/std', + 'logits/policy/max', + 'logits/policy/min', + 'logits/policy/abs_max', + + # Detailed logits statistics (Reward) + 'logits/reward/mean', + 'logits/reward/std', + 'logits/reward/max', + 'logits/reward/min', + 'logits/reward/abs_max', + + # Embeddings statistics + 'embeddings/obs/norm_mean', + 'embeddings/obs/norm_std', + 'embeddings/obs/norm_max', + 'embeddings/obs/norm_min', + ] + monitored_vars.extend(norm_vars) + # ======================================================================== + + + + # Task-specific variables to be monitored. + task_specific_vars = [ + 'noreduce_obs_loss', + 'noreduce_orig_policy_loss', + 'noreduce_policy_loss', + 'noreduce_latent_recon_loss', + 'noreduce_policy_entropy', + 'noreduce_target_policy_entropy', + 'noreduce_reward_loss', + 'noreduce_value_loss', + 'noreduce_perceptual_loss', + 'noreduce_latent_state_l2_norms', + 'noreduce_lambd', + 'noreduce_value_priority_mean', + # Metrics related to network plasticity. + 'noreduce_dormant_ratio_encoder', + 'noreduce_dormant_ratio_transformer', + 'noreduce_dormant_ratio_head', + 'noreduce_avg_weight_mag_encoder', + 'noreduce_avg_weight_mag_transformer', + 'noreduce_avg_weight_mag_head', + 'noreduce_e_rank_last_linear', + 'noreduce_e_rank_sim_norm', + "noreduce_alpha_loss", + "noreduce_target_entropy", + + ] + + # Use self.task_num_for_current_rank as the number of tasks for the current rank. + num_tasks = self.task_num_for_current_rank + # If the number of tasks is provided, extend the monitored variables list with task-specific variable names. + if num_tasks is not None: + for var in task_specific_vars: + for task_idx in range(num_tasks): + monitored_vars.append(f'{var}_task{self.task_id+task_idx}') + else: + # If num_tasks is not provided, assume a single task and use the original variable names. + monitored_vars.extend(task_specific_vars) + + return monitored_vars + + def _forward_collect( + self, + data: torch.Tensor, + action_mask: list = None, + temperature: float = 1, + to_play: List = [-1], + epsilon: float = 0.25, + ready_env_id: np.array = None, + timestep: List = [0], + task_id: int = None, + ) -> Dict: + """ + Overview: + The forward function for collecting data. It uses the model to perform MCTS search and + selects actions via sampling to encourage exploration. + Arguments: + - data (:obj:`torch.Tensor`): The input data, i.e., the current observation. 
+ - action_mask (:obj:`list`, optional): A list of action masks for each environment. + - temperature (:obj:`float`, optional): The temperature for MCTS action selection. + - to_play (:obj:`List`, optional): A list of player IDs for each environment. + - epsilon (:obj:`float`, optional): The probability for epsilon-greedy exploration. + - ready_env_id (:obj:`np.array`, optional): An array of IDs for environments that are ready for a new action. + - timestep (:obj:`List`, optional): The current timestep in each environment. + - task_id (:obj:`int`, optional): The global task ID for the current environments. + Returns: + - output (:obj:`Dict`): A dictionary where keys are environment IDs and values are dictionaries + containing the selected action and other MCTS statistics. + """ + self._collect_model.eval() + + self._collect_mcts_temperature = temperature + self._collect_epsilon = epsilon + active_collect_env_num = data.shape[0] + if ready_env_id is None: + ready_env_id = np.arange(active_collect_env_num) + output = {i: None for i in ready_env_id} + + with torch.no_grad(): + network_output = self._collect_model.initial_inference(self.last_batch_obs, self.last_batch_action, data, task_id=task_id) + latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) + + pred_values = self.value_inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() + latent_state_roots = latent_state_roots.detach().cpu().numpy() + + # Core fix: C++ binding requires a list, even though it represents rewards in MuZero. + reward_roots = reward_roots.detach().cpu().numpy().tolist() + + policy_logits = policy_logits.detach().cpu().numpy().tolist() + + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_collect_env_num)] + # The main difference between collect and eval is the addition of Dirichlet noise at the root. + noises = [ + np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(sum(action_mask[j])) + ).astype(np.float32).tolist() for j in range(active_collect_env_num) + ] + if self._cfg.mcts_ctree: + # C++ MCTS tree implementation. + roots = MCTSCtree.roots(active_collect_env_num, legal_actions) + else: + # Python MCTS tree implementation. + roots = MCTSPtree.roots(active_collect_env_num, legal_actions) + + + roots.prepare(self._cfg.root_noise_weight, noises, reward_roots, policy_logits, to_play) + self._mcts_collect.search(roots, self._collect_model, latent_state_roots, to_play, timestep= timestep, task_id=task_id) + + roots_visit_count_distributions = roots.get_distributions() + roots_values = roots.get_values() + + batch_action = [] + for i, env_id in enumerate(ready_env_id): + distributions, value = roots_visit_count_distributions[i], roots_values[i] + + if self._cfg.eps.eps_greedy_exploration_in_collect: + # Epsilon-greedy collection strategy. + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=self._collect_mcts_temperature, deterministic=True + ) + action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + if np.random.rand() < self._collect_epsilon: + action = np.random.choice(legal_actions[i]) + else: + # Standard collection strategy (sampling from MCTS policy). + # NOTE: `action_index_in_legal_action_set` is the index within the set of legal actions. 
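+ # For example, with ``action_mask[i] == [0, 1, 0, 1]`` the legal actions are [1, 3]; an index of 1 in the
+ # legal-action set is mapped back to env action 3 by the ``np.where(action_mask[i] == 1.0)[0][...]``
+ # lookup below.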
+ action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=self._collect_mcts_temperature, deterministic=False + ) + # Convert the index back to the action in the full action space. + action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + + output[env_id] = { + 'action': action, + 'visit_count_distributions': distributions, + 'visit_count_distribution_entropy': visit_count_distribution_entropy, + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], + } + batch_action.append(action) + + self.last_batch_obs = data + self.last_batch_action = batch_action + + # TODO: This logic is currently for the `muzero_segment_collector`. + if active_collect_env_num < self.collector_env_num: + # When one environment in `collect_env` finishes early, the length of `self.last_batch_obs` is reduced. + # The transformer needs the `env_id` to retrieve from the KV cache, which is complex to manage with a dynamic batch size. + # Therefore, we reset `self.last_batch_action` for all environments to -1, forcing the transformer + # to start from scratch and avoid retrieval errors. + print('==========collect_forward============') + print(f'len(self.last_batch_obs) < self.collector_env_num, {active_collect_env_num}<{self.collector_env_num}') + self._reset_collect(reset_init_data=True, task_id=task_id) + if getattr(self._cfg, 'sample_type', '') == 'episode': + print('BUG: sample_type is episode, but len(self.last_batch_obs) < self.collector_env_num') + + return output + + def _init_eval(self) -> None: + """ + Overview: + Initializes the eval mode. This method is called by ``self.__init__``. + It sets up the eval model and MCTS utilities for evaluation. + """ + self._eval_model = self._model + + # Create a copy of the configuration for eval MCTS and set a specific number of simulations. + mcts_eval_cfg = copy.deepcopy(self._cfg) + mcts_eval_cfg.num_simulations = self._cfg.eval_num_simulations + + if self._cfg.mcts_ctree: + self._mcts_eval = MCTSCtree(mcts_eval_cfg) + else: + self._mcts_eval = MCTSPtree(mcts_eval_cfg) + + self.evaluator_env_num = self._cfg.evaluator_env_num + + if self._cfg.model.model_type == 'conv': + self.last_batch_obs = torch.zeros([self.evaluator_env_num, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device) + self.last_batch_action = [-1 for _ in range(self.evaluator_env_num)] + elif self._cfg.model.model_type == 'mlp': + self.last_batch_obs = torch.zeros([self.evaluator_env_num, self._cfg.model.observation_shape]).to(self._cfg.device) + self.last_batch_action = [-1 for _ in range(self.evaluator_env_num)] + + def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, + ready_env_id: np.array = None, timestep: List = [0], task_id: int = None) -> Dict: + """ + Overview: + The forward function for evaluating the policy. It uses the model to perform MCTS search and + selects actions deterministically (choosing the one with the highest visit count). + Arguments: + - data (:obj:`torch.Tensor`): The input data, i.e., the current observation. + - action_mask (:obj:`list`): A list of action masks for each environment. + - to_play (:obj:`int`, optional): The player ID for the current turn. + - ready_env_id (:obj:`np.array`, optional): An array of IDs for environments that are ready for a new action. + - timestep (:obj:`List`, optional): The current timestep in each environment. 
+ - task_id (:obj:`int`, optional): The global task ID for the current environments. + Returns: + - output (:obj:`Dict`): A dictionary where keys are environment IDs and values are dictionaries + containing the selected action and other MCTS statistics. + """ + self._eval_model.eval() + active_eval_env_num = data.shape[0] + if ready_env_id is None: + ready_env_id = np.arange(active_eval_env_num) + output = {i: None for i in ready_env_id} + with torch.no_grad(): + network_output = self._eval_model.initial_inference(self.last_batch_obs_eval, self.last_batch_action, data, task_id=task_id) + latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) + + pred_values = self.value_inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() + latent_state_roots = latent_state_roots.detach().cpu().numpy() + policy_logits = policy_logits.detach().cpu().numpy().tolist() + + # Core fix: C++ binding requires a list, even though it represents rewards in MuZero. + reward_roots = reward_roots.detach().cpu().numpy().tolist() # TODO============================= + + + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_eval_env_num)] + if self._cfg.mcts_ctree: + # C++ MCTS tree implementation. + roots = MCTSCtree.roots(active_eval_env_num, legal_actions) + else: + # Python MCTS tree implementation. + roots = MCTSPtree.roots(active_eval_env_num, legal_actions) + + # During evaluation, no noise is added to the root policy. + roots.prepare_no_noise(reward_roots, policy_logits, to_play) + self._mcts_eval.search(roots, self._eval_model, latent_state_roots, to_play, timestep= timestep, task_id=task_id) + + roots_visit_count_distributions = roots.get_distributions() + roots_values = roots.get_values() + + batch_action = [] + + for i, env_id in enumerate(ready_env_id): + distributions, value = roots_visit_count_distributions[i], roots_values[i] + + # NOTE: `deterministic=True` means we select the action with the highest visit count (argmax) + # rather than sampling, which is standard for evaluation. + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=1, deterministic=True + ) + # Convert the index back to the action in the full action space. + action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + + output[env_id] = { + 'action': action, + 'visit_count_distributions': distributions, + 'visit_count_distribution_entropy': visit_count_distribution_entropy, + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], + } + batch_action.append(action) + + self.last_batch_obs_eval = data + self.last_batch_action = batch_action + + return output + + def _reset_collect(self, env_id: int = None, current_steps: int = 0, reset_init_data: bool = True, task_id: int = None) -> None: + """ + Overview: + Resets the collection process for a specific environment or all environments. + It can clear caches and reset initial data to ensure optimal performance and prevent state leakage. + Arguments: + - env_id (:obj:`int`, optional): The ID of the environment to reset. If None, the reset applies more broadly. Defaults to None. + - current_steps (:obj:`int`, optional): The current step count in the environment, used to trigger periodic cache clearing. Defaults to 0. + - reset_init_data (:obj:`bool`, optional): If True, resets the initial observation and action buffers. Defaults to True. 
+ - task_id (:obj:`int`, optional): The global task ID. Can be used to handle different observation shapes per task. Defaults to None. + """ + if reset_init_data: + self.last_batch_obs = initialize_zeros_batch( + self._cfg.model.observation_shape, + self._cfg.collector_env_num, + self._cfg.device + ) + self.last_batch_action = [-1 for _ in range(self._cfg.collector_env_num)] + + # We must handle both single int and list of ints for env_id. + if env_id is not None: + if isinstance(env_id, int): + env_ids_to_reset = [env_id] + else: # Assumes it's a list + env_ids_to_reset = env_id + + # The key condition: `current_steps` is None only on the end-of-episode reset call from the collector. + if current_steps is None: + world_model = self._collect_model.world_model + for eid in env_ids_to_reset: + # Clear the specific environment's initial inference cache. + if eid < len(world_model.past_kv_cache_init_infer_envs): + world_model.past_kv_cache_init_infer_envs[eid].clear() + + print(f'>>> [Collector] Cleared KV cache for env_id: {eid} at episode end.') + + + # Determine the clear interval based on the environment's sample type. + clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else self._cfg.game_segment_length + + # Clear caches periodically to manage memory. + if current_steps is not None and current_steps % clear_interval == 0: + + print(f'clear_interval: {clear_interval}') + + # Clear various KV caches in the collect model's world model. + world_model = self._collect_model.world_model + for kv_cache_dict_env in world_model.past_kv_cache_init_infer_envs: + kv_cache_dict_env.clear() + world_model.past_kv_cache_recurrent_infer.clear() + world_model.keys_values_wm_list.clear() + + # Free up unused GPU memory. + torch.cuda.empty_cache() + + print(f'Collector: Caches cleared for collect_model at step {current_steps} for env {env_id}.') + + # TODO: Check if resetting the target model here is correct and necessary. + self._reset_target_model() + + def _reset_target_model(self) -> None: + """ + Overview: + Resets the target model by clearing its internal caches. This is crucial for managing memory, + especially when using transformer-based models with KV caching. + """ + # Clear various KV caches in the target model's world model. + world_model = self._target_model.world_model + for kv_cache_dict_env in world_model.past_kv_cache_init_infer_envs: + kv_cache_dict_env.clear() + world_model.past_kv_cache_recurrent_infer.clear() + world_model.keys_values_wm_list.clear() + + # Free up unused GPU memory. + torch.cuda.empty_cache() + print('Collector: Target model past_kv_cache cleared.') + + def _reset_eval(self, env_id: int = None, current_steps: int = 0, reset_init_data: bool = True, task_id: int = None) -> None: + """ + Overview: + Resets the evaluation process for a specific environment or all environments. + Clears caches and resets initial data to ensure clean evaluation runs. + Arguments: + - env_id (:obj:`int`, optional): The ID of the environment to reset. Defaults to None. + - current_steps (:obj:`int`, optional): The current step count, used for periodic cache clearing. Defaults to 0. + - reset_init_data (:obj:`bool`, optional): If True, resets the initial observation and action buffers. Defaults to True. + - task_id (:obj:`int`, optional): The global task ID. Can be used to handle different observation shapes per task. Defaults to None. 
+ """ + if reset_init_data: + self.last_batch_obs_eval = initialize_zeros_batch( + self._cfg.model.observation_shape, + self._cfg.evaluator_env_num, + self._cfg.device + ) + + self.last_batch_action = [-1 for _ in range(self._cfg.evaluator_env_num)] + + + # This logic handles the crucial end-of-episode cache clearing for evaluation. + # The evaluator calls `_policy.reset([env_id])` when an episode is done. + if env_id is not None: + if isinstance(env_id, int): + env_ids_to_reset = [env_id] + else: # Assumes it's a list + env_ids_to_reset = env_id + + # The key condition: `current_steps` is None only on the end-of-episode reset call from the evaluator. + if current_steps is None: + world_model = self._eval_model.world_model + for eid in env_ids_to_reset: + # Clear the specific environment's initial inference cache. + if eid < len(world_model.past_kv_cache_init_infer_envs): + world_model.past_kv_cache_init_infer_envs[eid].clear() + + print(f'>>> [Evaluator] Cleared KV cache for env_id: {eid} at episode end.') + + # The recurrent cache is global. + world_model.past_kv_cache_recurrent_infer.clear() + + if hasattr(world_model, 'keys_values_wm_list'): + world_model.keys_values_wm_list.clear() + + torch.cuda.empty_cache() + return + + # Determine the clear interval. + clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else self._cfg.game_segment_length + + # Clear caches periodically. + if current_steps is not None and current_steps % clear_interval == 0: + + print(f'clear_interval: {clear_interval}') + + # Clear various KV caches in the eval model's world model. + world_model = self._eval_model.world_model + for kv_cache_dict_env in world_model.past_kv_cache_init_infer_envs: + kv_cache_dict_env.clear() + world_model.past_kv_cache_recurrent_infer.clear() + world_model.keys_values_wm_list.clear() + + # Free up unused GPU memory. + torch.cuda.empty_cache() + + print(f'Evaluator: Caches cleared for eval_model at step {current_steps} for env {env_id}.') + + + def recompute_pos_emb_diff_and_clear_cache(self) -> None: + """ + Overview: + Clears all KV caches and precomputes positional embedding matrices in the model. + This is typically called when the maximum sequence length changes. + """ + # NOTE: This must be done for both the collect and target models. + for model in [self._collect_model, self._target_model]: + model.world_model.precompute_pos_emb_diff_kv() + model.world_model.clear_caches() + torch.cuda.empty_cache() + + def _state_dict_learn(self) -> Dict[str, Any]: + """ + Overview: + Returns the state dictionary of the learn mode. + This typically includes the model, target model, and optimizer states, + which are necessary for saving and resuming training. + Returns: + - state_dict (:obj:`Dict[str, Any]`): The state dictionary for the current learning progress. + """ + return { + 'model': self._learn_model.state_dict(), + 'target_model': self._target_model.state_dict(), + 'optimizer_world_model': self._optimizer_world_model.state_dict(), + } + + # ========== NOTE: This is a pretrain-finetune version that selectively loads parameters and freezes layers. ========== + def _load_state_dict_learn(self, state_dict: Dict[str, Any], finetune_components: List[str] = []) -> None: + """ + Overview: + Loads a state_dict for fine-tuning. It excludes multi-task specific parameters + and can freeze parts of the model (e.g., encoder, transformer) based on `finetune_components`. + Arguments: + - state_dict (:obj:`Dict[str, Any]`): The state dictionary from a pre-trained model. 
+ - finetune_components (:obj:`List[str]`, optional): A list of component names (e.g., "encoder", "transformer") + that will remain trainable. Components not in this list will have their parameters frozen. + """ + # Example configurations for fine-tuning: + # finetune_components = [] # Loads encoder & transformer, fine-tunes only heads. + # finetune_components = ['transformer'] # Loads encoder & transformer, fine-tunes transformer & heads. + finetune_components = ["representation_network", "encoder"] # Loads encoder & transformer, fine-tunes encoder & heads. + + # Define prefixes of parameters to be excluded from loading (typically multi-task heads). + exclude_prefixes = [ + '_orig_mod.world_model.head_policy_multi_task.', + '_orig_mod.world_model.head_value_multi_task.', + '_orig_mod.world_model.head_rewards_multi_task.', + '_orig_mod.world_model.head_observations_multi_task.', + '_orig_mod.world_model.task_emb.' + ] + + # Define specific parameter keys to be excluded (for special cases like task embeddings). + exclude_keys = [ + '_orig_mod.world_model.task_emb.weight', + '_orig_mod.world_model.task_emb.bias', + ] + + def filter_state_dict(state_dict_loader: Dict[str, Any], exclude_prefixes: list, exclude_keys: list = []) -> Dict[str, Any]: + """ + Filters out parameters from a state_dict based on prefixes and specific keys. + """ + filtered = {} + for k, v in state_dict_loader.items(): + if any(k.startswith(prefix) for prefix in exclude_prefixes): + print(f"Excluding parameter: {k}") # For debugging + continue + if k in exclude_keys: + print(f"Excluding specific parameter: {k}") # For debugging + continue + filtered[k] = v + return filtered + + # Filter and load the 'model' state_dict. + if 'model' in state_dict: + model_state_dict = state_dict['model'] + filtered_model_state_dict = filter_state_dict(model_state_dict, exclude_prefixes, exclude_keys) + missing_keys, unexpected_keys = self._learn_model.load_state_dict(filtered_model_state_dict, strict=False) + if missing_keys: + print(f"Missing keys when loading _learn_model: {missing_keys}") + if unexpected_keys: + print(f"Unexpected keys when loading _learn_model: {unexpected_keys}") + else: + print("No 'model' key found in the state_dict.") + + # Filter and load the 'target_model' state_dict. + if 'target_model' in state_dict: + target_model_state_dict = state_dict['target_model'] + filtered_target_model_state_dict = filter_state_dict(target_model_state_dict, exclude_prefixes, exclude_keys) + missing_keys, unexpected_keys = self._target_model.load_state_dict(filtered_target_model_state_dict, strict=False) + if missing_keys: + print(f"Missing keys when loading _target_model: {missing_keys}") + if unexpected_keys: + print(f"Unexpected keys when loading _target_model: {unexpected_keys}") + else: + print("No 'target_model' key found in the state_dict.") + + # Handle freezing/unfreezing of parameters in _learn_model based on finetune_components. + # This assumes a naming convention where component names are present in parameter names. + for name, param in self._learn_model.named_parameters(): + # Freeze the encoder if "encoder" is not in finetune_components. + if "encoder" in name and "encoder" not in finetune_components: + param.requires_grad = False + print(f"Freezing parameter: {name}") + # Freeze the representation network if "representation_network" is not in finetune_components. 
+ elif "representation_network" in name and "representation_network" not in finetune_components: + param.requires_grad = False + print(f"Freezing parameter: {name}") + # Freeze the transformer if "transformer" is not in finetune_components. + elif "transformer" in name and "transformer" not in finetune_components: + param.requires_grad = False + print(f"Freezing parameter: {name}") + else: + # Other parameters remain trainable by default. + print(f"Parameter remains trainable: {name}") diff --git a/lzero/policy/utils.py b/lzero/policy/utils.py index 8b25c98b7..d2bdaa6c8 100644 --- a/lzero/policy/utils.py +++ b/lzero/policy/utils.py @@ -1,8 +1,5 @@ import inspect import logging -from typing import List, Dict, Union -from typing import Tuple - import matplotlib.pyplot as plt import numpy as np import torch @@ -10,9 +7,70 @@ from easydict import EasyDict from scipy.stats import entropy from torch.nn import functional as F +from typing import List, Dict, Union, Tuple import nltk from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction +def initialize_pad_batch(observation_shape: Union[int, List[int], Tuple[int]], batch_size: int, device: str, pad_token_id: int = 0) -> torch.Tensor: + """ + Overview: + Initialize a tensor filled with `pad_token_id` for batch observations. + This function is designed to be flexible and can handle both textual + and non-textual observations: + + - For textual observations: it initializes `input_ids` with padding tokens, + ensuring consistent sequence lengths within a batch. + - For non-textual observations: it provides a convenient way to fill + observation tensors with a default of 0, + ensuring shape compatibility and preventing uninitialized values. + Arguments: + - observation_shape (:obj:`Union[int, List[int], Tuple[int]]`): The shape of the observation tensor. + - batch_size (:obj:`int`): The batch size. + - device (:obj:`str`): The device to store the tensor. + - pad_token_id (:obj:`int`): The token ID (or placeholder value) used for padding. + Returns: + - padded_tensor (:obj:`torch.Tensor`): A tensor of the given shape, + filled with `pad_token_id`. + """ + if isinstance(observation_shape, (list, tuple)): + shape = [batch_size, *observation_shape] + elif isinstance(observation_shape, int): + shape = [batch_size, observation_shape] + else: + raise TypeError(f"observation_shape must be int, list, or tuple, but got {type(observation_shape).__name__}") + + return torch.full(shape, fill_value=pad_token_id, dtype=torch.float32, device=device) if pad_token_id == -1 else torch.full(shape, fill_value=pad_token_id, dtype=torch.long, device=device) + + +def initialize_zeros_batch( + observation_shape: Union[int, List[int], Tuple[int, ...]], + batch_size: int, + device: str +) -> torch.Tensor: + """ + Overview: + Initializes a zeros tensor for a batch of observations based on the + provided shape. This is commonly used to prepare initial input for models + like UniZero. + + Arguments: + - observation_shape (:obj:`Union[int, List[int], Tuple[int, ...]]`): The shape of a single observation. + - batch_size (:obj:`int`): The number of observations in the batch. + - device (:obj:`str`): The device to store the tensor on (e.g., 'cpu', 'cuda'). + + Returns: + - torch.Tensor: A zeros tensor with the shape [batch_size, *observation_shape]. 
+ """ + if isinstance(observation_shape, (list, tuple)): + shape = (batch_size, *observation_shape) + elif isinstance(observation_shape, int): + shape = (batch_size, observation_shape) + else: + raise TypeError( + f"observation_shape must be an int, list, or tuple, but got {type(observation_shape).__name__}" + ) + return torch.zeros(shape, device=device) + def compute_bleu(reference: str, prediction: str) -> float: """ Compute sentence-level BLEU-4 score with smoothing and scale it to 0–1. @@ -211,29 +269,69 @@ def forward(self, input): return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5) -# modified from https://github.com/karpathy/nanoGPT/blob/master/model.py#L263 -def configure_optimizers_nanogpt(model, weight_decay, learning_rate, betas, device_type): - # start with all of the candidate parameters +# The following code is modified from the original implementation at: +# https://github.com/karpathy/nanoGPT/blob/master/model.py#L263 + +def configure_optimizers_nanogpt( + model: nn.Module, + weight_decay: float, + learning_rate: float, + betas: Tuple[float, float], + device_type: str +) -> torch.optim.AdamW: + """ + Overview: + Configures the AdamW optimizer for the nanoGPT model. This function separates model + parameters into two groups: one that will be subject to weight decay and one that will not. + Typically, 2D and higher-dimensional tensors (e.g., weights of linear layers) are decayed, + while 1D tensors (e.g., biases and LayerNorm weights) are not. + + Arguments: + - model (:obj:`nn.Module`): The model for which to configure optimizers. + - weight_decay (:obj:`float`): The weight decay coefficient to apply. + - learning_rate (:obj:`float`): The learning rate for the optimizer. + - betas (:obj:`Tuple[float, float]`): The beta coefficients for the AdamW optimizer (e.g., (0.9, 0.95)). + - device_type (:obj:`str`): The type of device being used, e.g., 'cuda' or 'cpu'. + + Returns: + (:obj:`torch.optim.AdamW`): The configured AdamW optimizer instance. + """ + # Start with all of the candidate parameters from the model. param_dict = {pn: p for pn, p in model.named_parameters()} - # filter out those that do not require grad - param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad} - # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no. - # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't. + + # TODO: The following code is commented out, which is crucial for a balanced pipeline. + # We do not filter out parameters with `requires_grad=False` because their `requires_grad` + # attribute might be set to `True` at a later stage during training. + # param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad} + + # Create optimizer parameter groups. Any parameter that is 2D or higher will be weight decayed, + # otherwise no. i.e. all weight tensors in matrix multiplications and embeddings will be decayed, + # while all biases and layernorm weights will not. 
decay_params = [p for n, p in param_dict.items() if p.dim() >= 2] nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2] optim_groups = [ {'params': decay_params, 'weight_decay': weight_decay}, {'params': nodecay_params, 'weight_decay': 0.0} ] + num_decay_params = sum(p.numel() for p in decay_params) num_nodecay_params = sum(p.numel() for p in nodecay_params) print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters") print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters") - # Create AdamW optimizer and use the fused version if it is available + + # Create the AdamW optimizer. + # Check if a fused version of AdamW is available in the current PyTorch installation. fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters + + # Note: The current logic creates a standard AdamW optimizer on CUDA-enabled systems. + # The 'fused' version is only considered on non-CUDA systems, where it will ultimately not be used + # because `device_type` would not be 'cuda'. if torch.cuda.is_available(): + # On a CUDA-enabled system, create a standard AdamW optimizer. optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas) else: + # On a non-CUDA system, check if the fused optimizer can be used. + # This will be False if device_type is not 'cuda'. use_fused = fused_available and device_type == 'cuda' extra_args = dict(fused=True) if use_fused else dict() optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args) @@ -372,7 +470,7 @@ def prepare_obs_stack_for_unizero(obs_batch_ori: np.ndarray, cfg: EasyDict) -> T return obs_batch, obs_target_batch -def prepare_obs(obs_batch_ori: np.ndarray, cfg: EasyDict) -> Tuple[torch.Tensor, torch.Tensor]: +def prepare_obs(obs_batch_ori: np.ndarray, cfg: EasyDict, task_id = None) -> Tuple[torch.Tensor, torch.Tensor]: """ Overview: Prepare the observations for the model by converting the original batch of observations @@ -382,6 +480,7 @@ def prepare_obs(obs_batch_ori: np.ndarray, cfg: EasyDict) -> Tuple[torch.Tensor, Arguments: - obs_batch_ori (:obj:`np.ndarray`): The original observations in a batch style. - cfg (:obj:`EasyDict`): The configuration dictionary containing model settings. + - task_id (:obj:`int`, optional): The global task ID, used in multitask settings to select the appropriate observation shape. Returns: - obs_batch (:obj:`torch.Tensor`): The tensor containing the observations for the initial inference. @@ -395,9 +494,12 @@ def prepare_obs(obs_batch_ori: np.ndarray, cfg: EasyDict) -> Tuple[torch.Tensor, # Calculate the dimension size to slice based on the model configuration. # For convolutional models ('conv'), use the number of frames to stack times the number of channels. # For multi-layer perceptron models ('mlp'), use the number of frames to stack times the size of the observation space. - stack_dim = cfg.model.frame_stack_num * ( + if task_id is None: + stack_dim = cfg.model.frame_stack_num * ( cfg.model.image_channel if cfg.model.model_type in ['conv', 'conv_context'] else cfg.model.observation_shape) - + else: + stack_dim = cfg.model.frame_stack_num * ( + cfg.model.image_channel if cfg.model.model_type in ['conv', 'conv_context'] else cfg.model.observation_shape_list[task_id]) # Slice the original observation tensor to obtain the batch for the initial inference. 
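+    # For instance, in a conv model with `frame_stack_num=4` and `image_channel=1`, `stack_dim` is 4, so only
+    # the first 4 stacked frames are kept for the initial inference.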
obs_batch = obs_batch_ori[:, :stack_dim] @@ -408,7 +510,10 @@ def prepare_obs(obs_batch_ori: np.ndarray, cfg: EasyDict) -> Tuple[torch.Tensor, # Determine the starting dimension to exclude based on the model type. # For 'conv', exclude the first 'image_channel' dimensions. # For 'mlp', exclude the first 'observation_shape' dimensions. - exclude_dim = cfg.model.image_channel if cfg.model.model_type in ['conv', 'conv_context'] else cfg.model.observation_shape + if task_id is None: + exclude_dim = cfg.model.image_channel if cfg.model.model_type in ['conv', 'conv_context'] else cfg.model.observation_shape + else: + exclude_dim = cfg.model.image_channel if cfg.model.model_type in ['conv', 'conv_context'] else cfg.model.observation_shape_list[task_id] # Slice the original observation tensor to obtain the batch for consistency loss calculation. obs_target_batch = obs_batch_ori[:, exclude_dim:] diff --git a/lzero/worker/README.md b/lzero/worker/README.md new file mode 100644 index 000000000..cbbfb4ac7 --- /dev/null +++ b/lzero/worker/README.md @@ -0,0 +1,71 @@ +# Worker Module + +This directory contains the worker components for LightZero's reinforcement learning algorithms, including data collectors and evaluators. + +## File Overview + +### Collector Files +Collectors are responsible for gathering experience data during training through environment interaction. + +| File | Algorithm | Collection Mode | Description | +|------|-----------|----------------|-------------| +| `alphazero_collector.py` | AlphaZero | Episode-based | Collects complete game episodes for AlphaZero algorithm. Designed for perfect information games (e.g., board games). | +| `muzero_collector.py` | MuZero/EfficientZero/Gumbel MuZero | Episode-based | Collects complete game episodes for MCTS+RL algorithms. Supports both perfect and imperfect information environments. | +| `muzero_segment_collector.py` | MuZero/EfficientZero/Gumbel MuZero | Segment-based | Collects a specified number of game segments rather than complete episodes. Provides greater flexibility and extensibility. | + +### Evaluator Files +Evaluators assess the performance of trained policies during the training process. + +| File | Algorithm | Description | +|------|-----------|-------------| +| `alphazero_evaluator.py` | AlphaZero | Evaluates AlphaZero policy performance on test environments. | +| `muzero_evaluator.py` | MuZero/EfficientZero | Evaluates MuZero-based policy performance with support for multi-task scenarios. 
| + +## Key Differences + +### AlphaZero vs MuZero +- **AlphaZero**: Specifically designed for perfect information games where the full game state is observable (e.g., Go, Chess) +- **MuZero**: General-purpose algorithm supporting both perfect and imperfect information environments, with learned dynamics models + +### Collector vs Evaluator +- **Collector**: Gathers training data through self-play or environment interaction +- **Evaluator**: Assesses policy performance at regular intervals during training + +### MuZeroCollector vs MuZeroSegmentCollector +- **MuZeroCollector**: Returns data after collecting complete game episodes +- **MuZeroSegmentCollector**: Returns data after collecting a specified number of game segments, offering more fine-grained control over data collection + +## Common Features + +All workers support: +- Distributed training (multi-process/multi-GPU) +- TensorBoard logging +- Multi-task learning scenarios (via `task_id` parameter) +- Configurable collection/evaluation frequencies +- Environment and policy reset capabilities + +## Usage Example + +```python +from lzero.worker import MuZeroCollector, MuZeroEvaluator + +# Initialize collector +collector = MuZeroCollector( + collect_print_freq=100, + env=env_manager, + policy=policy, + tb_logger=tb_logger, + exp_name='my_experiment', + policy_config=policy_config +) + +# Initialize evaluator +evaluator = MuZeroEvaluator( + eval_freq=1000, + n_evaluator_episode=10, + env=eval_env, + policy=policy, + tb_logger=tb_logger, + exp_name='my_experiment' +) +``` diff --git a/lzero/worker/README_zh.md b/lzero/worker/README_zh.md new file mode 100644 index 000000000..170c532a4 --- /dev/null +++ b/lzero/worker/README_zh.md @@ -0,0 +1,71 @@ +# Worker 模块 + +本目录包含 LightZero 强化学习算法的工作组件,包括数据收集器和评估器。 + +## 文件概览 + +### 收集器文件 +收集器负责在训练过程中通过环境交互收集经验数据。 + +| 文件 | 算法 | 收集模式 | 说明 | +|------|------|---------|------| +| `alphazero_collector.py` | AlphaZero | 基于回合 | 为 AlphaZero 算法收集完整的游戏回合。专为完全信息博弈设计(如棋类游戏)。 | +| `muzero_collector.py` | MuZero/EfficientZero/Gumbel MuZero | 基于回合 | 为 MCTS+RL 算法收集完整的游戏回合。支持完全和不完全信息环境。 | +| `muzero_segment_collector.py` | MuZero/EfficientZero/Gumbel MuZero | 基于片段 | 收集指定数量的游戏片段而非完整回合。提供更大的灵活性和可扩展性。 | + +### 评估器文件 +评估器在训练过程中评估已训练策略的性能。 + +| 文件 | 算法 | 说明 | +|------|------|------| +| `alphazero_evaluator.py` | AlphaZero | 在测试环境中评估 AlphaZero 策略性能。 | +| `muzero_evaluator.py` | MuZero/EfficientZero | 评估基于 MuZero 的策略性能,支持多任务场景。 | + +## 主要差异 + +### AlphaZero vs MuZero +- **AlphaZero**:专为完全信息博弈设计,游戏状态完全可观察(如围棋、象棋) +- **MuZero**:通用算法,支持完全和不完全信息环境,具有学习的动力学模型 + +### Collector vs Evaluator +- **Collector(收集器)**:通过自我对弈或环境交互收集训练数据 +- **Evaluator(评估器)**:在训练期间定期评估策略性能 + +### MuZeroCollector vs MuZeroSegmentCollector +- **MuZeroCollector**:收集完整游戏回合后返回数据 +- **MuZeroSegmentCollector**:收集指定数量的游戏片段后返回数据,提供更细粒度的数据收集控制 + +## 共同特性 + +所有工作组件都支持: +- 分布式训练(多进程/多GPU) +- TensorBoard 日志记录 +- 多任务学习场景(通过 `task_id` 参数) +- 可配置的收集/评估频率 +- 环境和策略重置功能 + +## 使用示例 + +```python +from lzero.worker import MuZeroCollector, MuZeroEvaluator + +# 初始化收集器 +collector = MuZeroCollector( + collect_print_freq=100, + env=env_manager, + policy=policy, + tb_logger=tb_logger, + exp_name='my_experiment', + policy_config=policy_config +) + +# 初始化评估器 +evaluator = MuZeroEvaluator( + eval_freq=1000, + n_evaluator_episode=10, + env=eval_env, + policy=policy, + tb_logger=tb_logger, + exp_name='my_experiment' +) +``` diff --git a/lzero/worker/muzero_collector.py b/lzero/worker/muzero_collector.py index 4d3b1b740..856685c0d 100644 --- a/lzero/worker/muzero_collector.py +++ 
b/lzero/worker/muzero_collector.py @@ -1,7 +1,6 @@ -import os import time from collections import deque, namedtuple -from typing import Optional, Any, List +from typing import Optional, Any, List, Dict, Set import numpy as np import torch @@ -16,70 +15,77 @@ from lzero.mcts.buffer.game_segment import GameSegment from lzero.mcts.utils import prepare_observation -from lzero.policy.utils import compute_bleu @SERIAL_COLLECTOR_REGISTRY.register('episode_muzero') class MuZeroCollector(ISerialCollector): """ Overview: - The Episode Collector for MCTS+RL algorithms, including MuZero, EfficientZero, Sampled EfficientZero, Gumbel MuZero. - It manages the data collection process for training these algorithms using a serial mechanism. + The episode-based collector for MCTS-based reinforcement learning algorithms, + including MuZero, EfficientZero, Sampled EfficientZero, and Gumbel MuZero. + It orchestrates the data collection process in a serial manner, managing interactions + between the policy and the environment to generate game segments for training. Interfaces: - ``__init__``, ``reset``, ``reset_env``, ``reset_policy``, ``_reset_stat``, ``envstep``, ``__del__``, ``_compute_priorities``, - ``pad_and_save_last_trajectory``, ``collect``, ``_output_log``, ``close`` + ``__init__``, ``reset``, ``reset_env``, ``reset_policy``, ``_reset_stat``, ``collect``, + ``_compute_priorities``, ``pad_and_save_last_trajectory``, ``_output_log``, ``close``, ``__del__``. Properties: - ``envstep`` + ``envstep``. """ - # TO be compatible with ISerialCollector + # Default configuration for the collector. To be compatible with ISerialCollector. config = dict() def __init__( self, collect_print_freq: int = 100, - env: BaseEnvManager = None, - policy: namedtuple = None, + env: Optional[BaseEnvManager] = None, + policy: Optional[namedtuple] = None, tb_logger: 'SummaryWriter' = None, # noqa - exp_name: Optional[str] = 'default_experiment', - instance_name: Optional[str] = 'collector', + exp_name: str = 'default_experiment', + instance_name: str = 'collector', policy_config: 'policy_config' = None, # noqa + task_id: Optional[int] = None, ) -> None: """ Overview: - Initialize the MuZeroCollector with the given parameters. + Initializes the MuZeroCollector with the given configuration. Arguments: - - collect_print_freq (:obj:`int`): Frequency (in training steps) at which to print collection information. - - env (:obj:`Optional[BaseEnvManager]`): Instance of the subclass of vectorized environment manager. - - policy (:obj:`Optional[namedtuple]`): namedtuple of the collection mode policy API. - - tb_logger (:obj:`Optional[SummaryWriter]`): TensorBoard logger instance. - - exp_name (:obj:`str`): Name of the experiment, used for logging and saving purposes. - - instance_name (:obj:`str`): Unique identifier for this collector instance. - - policy_config (:obj:`Optional[policy_config]`): Configuration object for the policy. + - collect_print_freq (:obj:`int`): The frequency (in training iterations) at which to print collection statistics. + - env (:obj:`Optional[BaseEnvManager]`): An instance of a vectorized environment manager. + - policy (:obj:`Optional[namedtuple]`): A namedtuple containing the policy's forward pass and other methods. + - tb_logger (:obj:`Optional[SummaryWriter]`): A TensorBoard logger instance for logging metrics. + - exp_name (:obj:`str`): The name of the experiment, used for organizing logs. + - instance_name (:obj:`str`): A unique name for this collector instance. 
+ - policy_config (:obj:`'policy_config'`): The configuration object for the policy. + - task_id (:obj:`Optional[int]`): The identifier for the current task in a multi-task setting. If None, operates in single-task mode. """ + self.task_id = task_id self._exp_name = exp_name self._instance_name = instance_name self._collect_print_freq = collect_print_freq self._timer = EasyTimer() self._end_flag = False + # Get distributed training info self._rank = get_rank() self._world_size = get_world_size() + + # Logger setup: only rank 0 creates the main logger and TensorBoard logger. if self._rank == 0: if tb_logger is not None: self._logger, _ = build_logger( - path='./{}/log/{}'.format(self._exp_name, self._instance_name), + path=f'./{self._exp_name}/log/{self._instance_name}', name=self._instance_name, need_tb=False ) self._tb_logger = tb_logger else: self._logger, self._tb_logger = build_logger( - path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name + path=f'./{self._exp_name}/log/{self._instance_name}', name=self._instance_name ) else: self._logger, _ = build_logger( - path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name, need_tb=False + path=f'./{self._exp_name}/log/{self._instance_name}', name=self._instance_name, need_tb=False ) self._tb_logger = None @@ -91,12 +97,11 @@ def __init__( def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None: """ Overview: - Reset or replace the environment managed by this collector. - If _env is None, reset the old environment. - If _env is not None, replace the old environment in the collector with the new passed \ - in environment and launch. + Resets or replaces the environment managed by the collector. + If `_env` is None, it resets the existing environment. Otherwise, it replaces the old + environment with the new one and launches it. Arguments: - - env (:obj:`Optional[BaseEnvManager]`): New environment to manage, if provided. + - _env (:obj:`Optional[BaseEnvManager]`): The new environment to be used. If None, resets the current environment. """ if _env is not None: self._env = _env @@ -108,42 +113,39 @@ def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None: def reset_policy(self, _policy: Optional[namedtuple] = None) -> None: """ Overview: - Reset or replace the policy used by this collector. - If _policy is None, reset the old policy. - If _policy is not None, replace the old policy in the collector with the new passed in policy. + Resets or replaces the policy used by the collector. + If `_policy` is None, it resets the existing policy. Otherwise, it replaces the old + policy with the new one. Arguments: - - policy (:obj:`Optional[namedtuple]`): the api namedtuple of collect_mode policy + - _policy (:obj:`Optional[namedtuple]`): The new policy to be used. """ - assert hasattr(self, '_env'), "please set env first" + assert hasattr(self, '_env'), "Please set env first before resetting policy." if _policy is not None: self._policy = _policy self._default_n_episode = _policy.get_attribute('cfg').get('n_episode', None) self._logger.debug( - 'Set default n_episode mode(n_episode({}), env_num({}))'.format(self._default_n_episode, self._env_num) + f"Set default n_episode mode(n_episode({self._default_n_episode}), env_num({self._env_num}))" ) self._policy.reset() def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvManager] = None) -> None: """ Overview: - Reset the collector with the given policy and/or environment. 
- If _env is None, reset the old environment. - If _env is not None, replace the old environment in the collector with the new passed \ - in environment and launch. - If _policy is None, reset the old policy. - If _policy is not None, replace the old policy in the collector with the new passed in policy. + Resets the collector, including the environment and policy. Also re-initializes + internal state variables for tracking collection progress. Arguments: - - policy (:obj:`Optional[namedtuple]`): the api namedtuple of collect_mode policy - - env (:obj:`Optional[BaseEnvManager]`): instance of the subclass of vectorized \ - env_manager(BaseEnvManager) + - _policy (:obj:`Optional[namedtuple]`): The new policy to use. + - _env (:obj:`Optional[BaseEnvManager]`): The new environment to use. """ if _env is not None: self.reset_env(_env) if _policy is not None: self.reset_policy(_policy) - self._env_info = {env_id: {'time': 0., 'step': 0, 'text_bleu': 0.} for env_id in range(self._env_num)} + # Initialize per-environment tracking info + self._env_info = {env_id: {'time': 0., 'step': 0} for env_id in range(self._env_num)} + # Reset overall statistics self._episode_info = [] self._total_envstep_count = 0 self._total_episode_count = 0 @@ -151,36 +153,35 @@ def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvMana self._last_train_iter = 0 self._end_flag = False - # A game_segment_pool implementation based on the deque structure. + # A pool to store completed game segments, implemented using a deque. self.game_segment_pool = deque(maxlen=int(1e6)) self.unroll_plus_td_steps = self.policy_config.num_unroll_steps + self.policy_config.td_steps def _reset_stat(self, env_id: int) -> None: """ Overview: - Reset the collector's state. Including reset the traj_buffer, obs_pool, policy_output_pool \ - and env_info. Reset these states according to env_id. You can refer to base_serial_collector\ - to get more messages. + Resets the statistics for a specific environment, identified by `env_id`. + This is typically called when an episode in that environment ends. Arguments: - - env_id (:obj:`int`): the id where we need to reset the collector's state + - env_id (:obj:`int`): The ID of the environment to reset statistics for. """ - self._env_info[env_id] = {'time': 0., 'step': 0, 'text_bleu': 0.} + self._env_info[env_id] = {'time': 0., 'step': 0} @property def envstep(self) -> int: """ Overview: - Get the total number of environment steps collected. + Returns the total number of environment steps collected since the last reset. Returns: - - envstep (:obj:`int`): Total number of environment steps collected. + - envstep (:obj:`int`): The total environment step count. """ return self._total_envstep_count def close(self) -> None: """ Overview: - Close the collector. If end_flag is False, close the environment, flush the tb_logger \ - and close the tb_logger. + Closes the collector, including the environment and any loggers. + Ensures that all resources are properly released. """ if self._end_flag: return @@ -193,627 +194,455 @@ def close(self) -> None: def __del__(self) -> None: """ Overview: - Execute the close command and close the collector. __del__ is automatically called to \ - destroy the collector instance when the collector finishes its work + Destructor for the collector instance, ensuring that `close` is called + to clean up resources. 
""" self.close() # ============================================================== - # MCTS+RL related core code + # MCTS+RL Core Collection Logic # ============================================================== - def _compute_priorities(self, i: int, pred_values_lst: List[float], search_values_lst: List[float]) -> np.ndarray: + def _compute_priorities(self, i: int, pred_values_lst: List[float], search_values_lst: List[float]) -> Optional[np.ndarray]: """ Overview: - Compute the priorities for transitions based on prediction and search value discrepancies. + Computes priorities for experience replay based on the discrepancy between + predicted values and MCTS search values. Arguments: - - i (:obj:`int`): Index of the values in the list to compute the priority for. - - pred_values_lst (:obj:`List[float]`): List of predicted values. - - search_values_lst (:obj:`List[float]`): List of search values obtained from MCTS. + - i (:obj:`int`): The index of the environment's data in the lists. + - pred_values_lst (:obj:`List[float]`): A list containing lists of predicted values for each environment. + - search_values_lst (:obj:`List[float]`): A list containing lists of search values from MCTS for each environment. Returns: - - priorities (:obj:`np.ndarray`): Array of computed priorities. + - priorities (:obj:`Optional[np.ndarray]`): An array of priorities for the transitions. Returns None if priority is not used. """ if self.policy_config.use_priority: - # Calculate priorities. The priorities are the L1 losses between the predicted - # values and the search values. We use 'none' as the reduction parameter, which - # means the loss is calculated for each element individually, instead of being summed or averaged. - # A small constant (1e-6) is added to the results to avoid zero priorities. This - # is done because zero priorities could potentially cause issues in some scenarios. + # Calculate priorities as the L1 loss between predicted values and search values. + # 'reduction=none' ensures the loss is calculated for each element individually. pred_values = torch.from_numpy(np.array(pred_values_lst[i])).to(self.policy_config.device).float().view(-1) - search_values = torch.from_numpy(np.array(search_values_lst[i])).to(self.policy_config.device - ).float().view(-1) - priorities = L1Loss(reduction='none' - )(pred_values, - search_values).detach().cpu().numpy() + 1e-6 + search_values = torch.from_numpy(np.array(search_values_lst[i])).to(self.policy_config.device).float().view(-1) + + # A small epsilon is added to avoid zero priorities. + priorities = L1Loss(reduction='none')(pred_values, search_values).detach().cpu().numpy() + 1e-6 else: - # priorities is None -> use the max priority for all newly collected data + # If priority is not used, return None. The replay buffer will use max priority for new data. priorities = None return priorities - def pad_and_save_last_trajectory(self, i: int, last_game_segments: List[GameSegment], - last_game_priorities: List[np.ndarray], - game_segments: List[GameSegment], done: np.ndarray) -> None: + def pad_and_save_last_trajectory( + self, i: int, last_game_segments: List[Optional[GameSegment]], + last_game_priorities: List[Optional[np.ndarray]], + game_segments: List[GameSegment], done: np.ndarray + ) -> None: """ Overview: - Save the game segment to the pool if the current game is finished, padding it if necessary. + Pads the end of the `last_game_segment` with data from the start of the current `game_segment`. 
+ This is necessary to compute target values for the final transitions of a segment. After padding, + the completed segment is stored in the `game_segment_pool`. Arguments: - - i (:obj:`int`): Index of the current game segment. - - last_game_segments (:obj:`List[GameSegment]`): List of the last game segments to be padded and saved. - - last_game_priorities (:obj:`List[np.ndarray]`): List of priorities of the last game segments. - - game_segments (:obj:`List[GameSegment]`): List of the current game segments. - - done (:obj:`np.ndarray`): Array indicating whether each game is done. + - i (:obj:`int`): The index of the environment being processed. + - last_game_segments (:obj:`List[Optional[GameSegment]]`): List of game segments from the previous collection chunk. + - last_game_priorities (:obj:`List[Optional[np.ndarray]]`): List of priorities corresponding to the last game segments. + - game_segments (:obj:`List[GameSegment]`): List of game segments from the current collection chunk. + - done (:obj:`np.ndarray`): Array indicating if the episode has terminated for each environment. Note: - (last_game_segments[i].obs_segment[-4:][j] == game_segments[i].obs_segment[:4][j]).all() is True + An implicit assumption is that the start of the new segment's observation history overlaps with the + end of the last segment's, e.g., `(last_game_segments[i].obs_segment[-4:][j] == game_segments[i].obs_segment[:4][j]).all()` is True. """ - # pad over last segment trajectory - beg_index = self.policy_config.model.frame_stack_num - end_index = beg_index + self.policy_config.num_unroll_steps + self.policy_config.td_steps - - # the start obs is init zero obs, so we take the - # [ : +] obs as the pad obs - # e.g. the start 4 obs is init zero obs, the num_unroll_steps is 5, so we take the [4:9] obs as the pad obs - pad_obs_lst = game_segments[i].obs_segment[beg_index:end_index] - - # NOTE: for unizero - beg_index = 0 - end_index = beg_index + self.policy_config.num_unroll_steps + self.policy_config.td_steps - pad_action_lst = game_segments[i].action_segment[beg_index:end_index] - - # NOTE: for unizero - pad_child_visits_lst = game_segments[i].child_visit_segment[ - :self.policy_config.num_unroll_steps + self.policy_config.td_steps] - - # EfficientZero original repo bug: - # pad_child_visits_lst = game_segments[i].child_visit_segment[beg_index:end_index] - - beg_index = 0 - end_index = beg_index + self.unroll_plus_td_steps - 1 + # --- Prepare padding data from the current game segment --- + # Observations for padding are taken from the start of the new segment. + beg_index_obs = self.policy_config.model.frame_stack_num + end_index_obs = beg_index_obs + self.policy_config.num_unroll_steps + self.policy_config.td_steps + pad_obs_lst = game_segments[i].obs_segment[beg_index_obs:end_index_obs] + + # Actions for padding. + beg_index_ac = 0 + end_index_ac = beg_index_ac + self.policy_config.num_unroll_steps + self.policy_config.td_steps + pad_action_lst = game_segments[i].action_segment[beg_index_ac:end_index_ac] + + # Child visits for padding. + pad_child_visits_lst = game_segments[i].child_visit_segment[:self.policy_config.num_unroll_steps + self.policy_config.td_steps] + + # Rewards for padding. + beg_index_rew = 0 + end_index_rew = beg_index_rew + self.unroll_plus_td_steps - 1 + pad_reward_lst = game_segments[i].reward_segment[beg_index_rew:end_index_rew] + + # Root values for padding. 
+ beg_index_val = 0 + end_index_val = beg_index_val + self.unroll_plus_td_steps + pad_root_values_lst = game_segments[i].root_value_segment[beg_index_val:end_index_val] - pad_reward_lst = game_segments[i].reward_segment[beg_index:end_index] if self.policy_config.use_ture_chance_label_in_chance_encoder: - chance_lst = game_segments[i].chance_segment[beg_index:end_index] - - beg_index = 0 - end_index = beg_index + self.unroll_plus_td_steps - - pad_root_values_lst = game_segments[i].root_value_segment[beg_index:end_index] - + chance_lst = game_segments[i].chance_segment[beg_index_rew:end_index_rew] + if self.policy_config.gumbel_algo: - pad_improved_policy_prob = game_segments[i].improved_policy_probs[beg_index:end_index] + pad_improved_policy_prob = game_segments[i].improved_policy_probs[beg_index_val:end_index_val] - # pad over and save + # --- Pad the last game segment and save it --- if self.policy_config.gumbel_algo: - last_game_segments[i].pad_over(pad_obs_lst, pad_reward_lst, pad_action_lst, pad_root_values_lst, pad_child_visits_lst, - next_segment_improved_policy=pad_improved_policy_prob) + last_game_segments[i].pad_over( + pad_obs_lst, pad_reward_lst, pad_action_lst, pad_root_values_lst, + pad_child_visits_lst, next_segment_improved_policy=pad_improved_policy_prob + ) else: if self.policy_config.use_ture_chance_label_in_chance_encoder: - last_game_segments[i].pad_over(pad_obs_lst, pad_reward_lst, pad_action_lst, pad_root_values_lst, pad_child_visits_lst, - next_chances=chance_lst) + last_game_segments[i].pad_over( + pad_obs_lst, pad_reward_lst, pad_action_lst, pad_root_values_lst, + pad_child_visits_lst, next_chances=chance_lst + ) else: - last_game_segments[i].pad_over(pad_obs_lst, pad_reward_lst, pad_action_lst, pad_root_values_lst, pad_child_visits_lst) - """ - Note: - game_segment element shape: - obs: game_segment_length + stack + num_unroll_steps, 20+4 +5 - rew: game_segment_length + stack + num_unroll_steps + td_steps -1 20 +5+3-1 - action: game_segment_length -> 20 - root_values: game_segment_length + num_unroll_steps + td_steps -> 20 +5+3 - child_visits: game_segment_length + num_unroll_steps -> 20 +5 - to_play: game_segment_length -> 20 - action_mask: game_segment_length -> 20 - """ - + last_game_segments[i].pad_over( + pad_obs_lst, pad_reward_lst, pad_action_lst, pad_root_values_lst, pad_child_visits_lst + ) + + # Convert the segment's lists to NumPy arrays for efficient storage. last_game_segments[i].game_segment_to_array() - # put the game segment into the pool + # Add the completed game segment and its associated data to the pool. self.game_segment_pool.append((last_game_segments[i], last_game_priorities[i], done[i])) - # reset last game_segments + # Reset the placeholder for the last game segment. last_game_segments[i] = None last_game_priorities[i] = None - return None - - def collect(self, - n_episode: Optional[int] = None, - train_iter: int = 0, - policy_kwargs: Optional[dict] = None, - collect_with_pure_policy: bool = False) -> List[Any]: + def collect( + self, + n_episode: Optional[int] = None, + train_iter: int = 0, + policy_kwargs: Optional[Dict] = None, + collect_with_pure_policy: bool = False + ) -> List[Any]: """ Overview: - Collect `n_episode` episodes of data with policy_kwargs, trained for `train_iter` iterations. + Collects `n_episode` episodes of data. It manages the entire lifecycle of an episode, + from getting actions from the policy, stepping the environment, storing transitions, + and saving completed game segments. 
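+            .. note::
+                Completed segments are staged in ``self.game_segment_pool`` as ``(game_segment, priorities, done)``
+                tuples, which form the basis of the returned data.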
Arguments: - - n_episode (:obj:`Optional[int]`): Number of episodes to collect. - - train_iter (:obj:`int`): Number of training iterations completed so far. - - policy_kwargs (:obj:`Optional[dict]`): Additional keyword arguments for the policy. - - collect_with_pure_policy (:obj:`bool`): Whether to collect data using pure policy without MCTS. + - n_episode (:obj:`Optional[int]`): The number of episodes to collect. If None, uses the default from the policy config. + - train_iter (:obj:`int`): The current training iteration, used for logging. + - policy_kwargs (:obj:`Optional[Dict]`): Additional keyword arguments to pass to the policy's forward method, like temperature for exploration. + - collect_with_pure_policy (:obj:`bool`): If True, collects data using a pure policy (e.g., greedy action) without MCTS. Returns: - - return_data (:obj:`List[Any]`): Collected data in the form of a list. + - return_data (:obj:`List[Any]`): A list containing the collected game segments and metadata. """ - # TODO: collect_with_pure_policy as a separate collector + # TODO(author): Consider implementing `collect_with_pure_policy` as a separate, more streamlined collector for clarity and modularity. if n_episode is None: if self._default_n_episode is None: - raise RuntimeError("Please specify collect n_episode") + raise RuntimeError("Please specify `n_episode` for collection.") else: n_episode = self._default_n_episode - assert n_episode >= self._env_num, "Please make sure n_episode >= env_num{}/{}".format(n_episode, self._env_num) + assert n_episode >= self._env_num, f"Please ensure n_episode ({n_episode}) >= env_num ({self._env_num})." + if policy_kwargs is None: policy_kwargs = {} - temperature = policy_kwargs['temperature'] - epsilon = policy_kwargs['epsilon'] + temperature = policy_kwargs.get('temperature', 1.0) + epsilon = policy_kwargs.get('epsilon', 0.0) + # --- Initializations --- collected_episode = 0 - collected_step = 0 env_nums = self._env_num retry_waiting_time = 0.05 - # initializations + # Wait for all environments to be ready and get initial observations. init_obs = self._env.ready_obs while len(init_obs.keys()) != self._env_num: - # To be compatible with subprocess env_manager, in which sometimes self._env_num is not equal to - # len(self._env.ready_obs), especially in tictactoe env. - self._logger.info('The current init_obs.keys() is {}'.format(init_obs.keys())) - self._logger.info('Before sleeping, the _env_states is {}'.format(self._env._env_states)) + self._logger.warning(f"Waiting for all environments to reset. Ready envs: {list(init_obs.keys())}") time.sleep(retry_waiting_time) - self._logger.info('=' * 10 + 'Wait for all environments (subprocess) to finish resetting.' + '=' * 10) - self._logger.info( - 'After sleeping {}s, the current _env_states is {}'.format(retry_waiting_time, self._env._env_states) - ) init_obs = self._env.ready_obs + # Prepare initial state dictionaries from observations. action_mask_dict = {i: to_ndarray(init_obs[i]['action_mask']) for i in range(env_nums)} to_play_dict = {i: to_ndarray(init_obs[i]['to_play']) for i in range(env_nums)} - - timestep_dict = {} - for i in range(env_nums): - if 'timestep' not in init_obs[i]: - if self._policy.get_attribute('cfg').type in ['unizero', 'sampled_unizero']: - print(f"Warning: 'timestep' key is missing in init_obs[{i}]. Assigning value -1. 
Please note that the unizero algorithm may require the 'timestep' key in init_obs.") - timestep_dict[i] = to_ndarray(init_obs[i].get('timestep', -1)) - + timestep_dict = {i: to_ndarray(init_obs[i].get('timestep', -1)) for i in range(env_nums)} if self.policy_config.use_ture_chance_label_in_chance_encoder: chance_dict = {i: to_ndarray(init_obs[i]['chance']) for i in range(env_nums)} - game_segments = [ - GameSegment( - self._env.action_space, - game_segment_length=self.policy_config.game_segment_length, - config=self.policy_config - ) for _ in range(env_nums) - ] - # stacked observation windows in reset stage for init game_segments - observation_window_stack = [[] for _ in range(env_nums)] + # Initialize game segments and observation stacks for each environment. + game_segments = [GameSegment(self._env.action_space, game_segment_length=self.policy_config.game_segment_length, config=self.policy_config) for _ in range(env_nums)] + observation_window_stack = [deque(maxlen=self.policy_config.model.frame_stack_num) for _ in range(env_nums)] for env_id in range(env_nums): - observation_window_stack[env_id] = deque( - [to_ndarray(init_obs[env_id]['observation']) for _ in range(self.policy_config.model.frame_stack_num)], - maxlen=self.policy_config.model.frame_stack_num - ) + for _ in range(self.policy_config.model.frame_stack_num): + observation_window_stack[env_id].append(to_ndarray(init_obs[env_id]['observation'])) game_segments[env_id].reset(observation_window_stack[env_id]) + # State tracking variables for the collection loop. dones = np.array([False for _ in range(env_nums)]) - last_game_segments = [None for _ in range(env_nums)] - last_game_priorities = [None for _ in range(env_nums)] - # for priorities in self-play + last_game_segments: List[Optional[GameSegment]] = [None for _ in range(env_nums)] + last_game_priorities: List[Optional[np.ndarray]] = [None for _ in range(env_nums)] + + # Buffers for priority calculation. search_values_lst = [[] for _ in range(env_nums)] pred_values_lst = [[] for _ in range(env_nums)] if self.policy_config.gumbel_algo: improved_policy_lst = [[] for _ in range(env_nums)] - # some logs - eps_steps_lst, visit_entropies_lst = np.zeros(env_nums), np.zeros(env_nums) + # Logging variables. + eps_steps_lst = np.zeros(env_nums) + visit_entropies_lst = np.zeros(env_nums) if self.policy_config.gumbel_algo: completed_value_lst = np.zeros(env_nums) - self_play_moves = 0. - self_play_episodes = 0. - self_play_moves_max = 0 - self_play_visit_entropy = [] - total_transitions = 0 - ready_env_id = set() + ready_env_id: Set[int] = set() remain_episode = n_episode if collect_with_pure_policy: - temp_visit_list = [0.0 for i in range(self._env.action_space.n)] + # Dummy visit counts for pure policy collection. + temp_visit_list = [0.0 for _ in range(self._env.action_space.n)] + # --- Main Collection Loop --- while True: with self._timer: - # Get current ready env obs. + # Get observations from ready environments. obs = self._env.ready_obs - new_available_env_id = set(obs.keys()).difference(ready_env_id) - ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode])) + ready_env_id.update(list(new_available_env_id)[:remain_episode]) remain_episode -= min(len(new_available_env_id), remain_episode) - - # NOTE: If waiting for N environments to synchronize, it may result in some environments not being completed (done) by the time of return. 
- # However, the current muzero_collector does not properly maintain the global self.last_game_segments, leading to some data not being collected. - - stack_obs = {env_id: game_segments[env_id].get_obs() for env_id in ready_env_id} - stack_obs = list(stack_obs.values()) - - action_mask_dict = {env_id: action_mask_dict[env_id] for env_id in ready_env_id} - to_play_dict = {env_id: to_play_dict[env_id] for env_id in ready_env_id} - timestep_dict = {env_id: timestep_dict[env_id] for env_id in ready_env_id} + # Prepare policy inputs. + stack_obs_list = [game_segments[env_id].get_obs() for env_id in ready_env_id] action_mask = [action_mask_dict[env_id] for env_id in ready_env_id] to_play = [to_play_dict[env_id] for env_id in ready_env_id] timestep = [timestep_dict[env_id] for env_id in ready_env_id] - if self.policy_config.use_ture_chance_label_in_chance_encoder: - chance_dict = {env_id: chance_dict[env_id] for env_id in ready_env_id} - - stack_obs = to_ndarray(stack_obs) - # return stack_obs shape: [B, S*C, W, H] e.g. [8, 4*1, 96, 96] - stack_obs = prepare_observation(stack_obs, self.policy_config.model.model_type) - stack_obs = torch.from_numpy(stack_obs).to(self.policy_config.device) + stack_obs_array = to_ndarray(stack_obs_list) + stack_obs_tensor = prepare_observation(stack_obs_array, self.policy_config.model.model_type) + stack_obs_tensor = torch.from_numpy(stack_obs_tensor).to(self.policy_config.device) # ============================================================== - # Key policy forward step + # Policy Forward Pass # ============================================================== - # print(f'ready_env_id:{ready_env_id}') - policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon, ready_env_id=ready_env_id, timestep=timestep) - - pred_next_text_with_env_id = {k: v['predicted_next_text'] if 'predicted_next_text' in v else -1 for k, v in policy_output.items()} - - # Extract relevant policy outputs - actions_with_env_id = {k: v['action'] for k, v in policy_output.items()} - value_dict_with_env_id = {k: v['searched_value'] for k, v in policy_output.items()} - pred_value_dict_with_env_id = {k: v['predicted_value'] for k, v in policy_output.items()} - timestep_dict_with_env_id = { - k: v['timestep'] if 'timestep' in v else -1 for k, v in policy_output.items() + policy_input = { + 'x': stack_obs_tensor, + 'action_mask': action_mask, + 'temperature': temperature, + 'to_play': to_play, + 'epsilon': epsilon, + 'ready_env_id': ready_env_id, + 'timestep': timestep } - + if self.task_id is not None: + policy_input['task_id'] = self.task_id + + policy_output = self._policy.forward(**policy_input) + + # --- Unpack policy outputs --- + actions, value_dict, pred_value_dict = {}, {}, {} + distributions_dict, visit_entropy_dict = {}, {} if self.policy_config.sampled_algo: - root_sampled_actions_dict_with_env_id = { - k: v['root_sampled_actions'] for k, v in policy_output.items() - } - - if not collect_with_pure_policy: - distributions_dict_with_env_id = {k: v['visit_count_distributions'] for k, v in - policy_output.items()} - visit_entropy_dict_with_env_id = {k: v['visit_count_distribution_entropy'] for k, v in - policy_output.items()} + root_sampled_actions_dict = {} + if self.policy_config.gumbel_algo: + improved_policy_dict, completed_value_dict = {}, {} - if self.policy_config.gumbel_algo: - improved_policy_dict_with_env_id = {k: v['improved_policy_probs'] for k, v in - policy_output.items()} - completed_value_with_env_id = {k: v['roots_completed_value'] for k, v in 
policy_output.items()} - - # Initialize dictionaries to store results - actions = {} - value_dict = {} - pred_value_dict = {} - timestep_dict = {} - pred_next_text = {} - - if not collect_with_pure_policy: - distributions_dict = {} - visit_entropy_dict = {} - - if self.policy_config.sampled_algo: - root_sampled_actions_dict = {} - - if self.policy_config.gumbel_algo: - improved_policy_dict = {} - completed_value_dict = {} - - # Populate the result dictionaries for env_id in ready_env_id: - actions[env_id] = actions_with_env_id.pop(env_id) - value_dict[env_id] = value_dict_with_env_id.pop(env_id) - pred_value_dict[env_id] = pred_value_dict_with_env_id.pop(env_id) - timestep_dict[env_id] = timestep_dict_with_env_id.pop(env_id) - pred_next_text[env_id] = pred_next_text_with_env_id.pop(env_id) - + output = policy_output[env_id] + actions[env_id] = output['action'] + value_dict[env_id] = output['searched_value'] + pred_value_dict[env_id] = output['predicted_value'] + if not collect_with_pure_policy: - distributions_dict[env_id] = distributions_dict_with_env_id.pop(env_id) - + distributions_dict[env_id] = output['visit_count_distributions'] + visit_entropy_dict[env_id] = output['visit_count_distribution_entropy'] if self.policy_config.sampled_algo: - root_sampled_actions_dict[env_id] = root_sampled_actions_dict_with_env_id.pop(env_id) - - visit_entropy_dict[env_id] = visit_entropy_dict_with_env_id.pop(env_id) - + root_sampled_actions_dict[env_id] = output['root_sampled_actions'] if self.policy_config.gumbel_algo: - improved_policy_dict[env_id] = improved_policy_dict_with_env_id.pop(env_id) - completed_value_dict[env_id] = completed_value_with_env_id.pop(env_id) - + improved_policy_dict[env_id] = output['improved_policy_probs'] + completed_value_dict[env_id] = output['roots_completed_value'] + # ============================================================== - # Interact with the environment + # Environment Interaction # ============================================================== timesteps = self._env.step(actions) - interaction_duration = self._timer.value / len(timesteps) - - groundtrut_next_text = {} + interaction_duration = self._timer.value / len(timesteps) if timesteps else 0 + for env_id, episode_timestep in timesteps.items(): with self._timer: + # Handle abnormal timesteps by resetting the environment and policy state. if episode_timestep.info.get('abnormal', False): - # If there is an abnormal episode_timestep, reset all the related variables(including this env). 
- # suppose there is no reset param, reset this env self._env.reset({env_id: None}) self._policy.reset([env_id]) self._reset_stat(env_id) - self._logger.info('Env{} returns a abnormal step, its info is {}'.format(env_id, episode_timestep.info)) + self._logger.info(f"Environment {env_id} returned an abnormal step, info: {episode_timestep.info}") continue + obs, reward, done, info = episode_timestep.obs, episode_timestep.reward, episode_timestep.done, episode_timestep.info - - if "world_model_cfg" in self.policy_config.model and self.policy_config.model.world_model_cfg.obs_type == 'text': - obs_input_ids = torch.tensor(obs['observation'], dtype=torch.long) # shape: [L] - obs_attn_mask = torch.tensor(obs['obs_attn_mask'][0], dtype=torch.long) - valid_input_ids = obs_input_ids[obs_attn_mask == 1].tolist() - - groundtrut_next_text[env_id] = self._env._envs[env_id].tokenizer.decode(valid_input_ids, skip_special_tokens=True) - text_bleu = compute_bleu(reference=groundtrut_next_text[env_id], prediction=pred_next_text[env_id]) - # Whether to output text comparisons with high BLEU scores to evaluate the effectiveness of decoding the next latent. - if text_bleu > 0.85: - os.makedirs("./log", exist_ok=True) - with open("./log/bleu_match.txt", "a", encoding="utf-8") as f: - f.write(f"pred_text={pred_next_text[env_id]}\ngroundtruth_text={groundtrut_next_text[env_id]}\ntext_bleu={text_bleu:.4f}\n\n") - + # Store MCTS search statistics. if collect_with_pure_policy: game_segments[env_id].store_search_stats(temp_visit_list, 0) else: if self.policy_config.sampled_algo: - game_segments[env_id].store_search_stats( - distributions_dict[env_id], value_dict[env_id], root_sampled_actions_dict[env_id] - ) + game_segments[env_id].store_search_stats(distributions_dict[env_id], value_dict[env_id], root_sampled_actions_dict[env_id]) elif self.policy_config.gumbel_algo: - game_segments[env_id].store_search_stats(distributions_dict[env_id], value_dict[env_id], - improved_policy=improved_policy_dict[env_id]) + game_segments[env_id].store_search_stats(distributions_dict[env_id], value_dict[env_id], improved_policy=improved_policy_dict[env_id]) else: game_segments[env_id].store_search_stats(distributions_dict[env_id], value_dict[env_id]) - # append a transition tuple, including a_t, o_{t+1}, r_{t}, action_mask_{t}, to_play_{t} - # in ``game_segments[env_id].init``, we have appended o_{t} in ``self.obs_segment`` + # Append the current transition to the game segment. + append_args = (actions[env_id], to_ndarray(obs['observation']), reward, action_mask_dict[env_id], to_play_dict[env_id]) if self.policy_config.use_ture_chance_label_in_chance_encoder: - game_segments[env_id].append( - actions[env_id], to_ndarray(obs['observation']), reward, action_mask_dict[env_id], - to_play_dict[env_id], timestep_dict[env_id], chance_dict[env_id] - ) - else: - game_segments[env_id].append( - actions[env_id], to_ndarray(obs['observation']), reward, action_mask_dict[env_id], - to_play_dict[env_id], timestep_dict[env_id] - ) + append_args += (chance_dict[env_id],) + append_args += (timestep_dict[env_id],) + game_segments[env_id].append(*append_args) - # NOTE: the position of code snippet is very important. - # the obs['action_mask'] and obs['to_play'] are corresponding to the next action + # Update state dictionaries for the next step. 
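+                    # NOTE: obs['action_mask'] and obs['to_play'] correspond to the *next* action, so they must only be
+                    # updated here, after the current transition has been appended above.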
action_mask_dict[env_id] = to_ndarray(obs['action_mask']) to_play_dict[env_id] = to_ndarray(obs['to_play']) timestep_dict[env_id] = to_ndarray(obs.get('timestep', -1)) if self.policy_config.use_ture_chance_label_in_chance_encoder: chance_dict[env_id] = to_ndarray(obs['chance']) - if self.policy_config.ignore_done: - dones[env_id] = False - else: - dones[env_id] = done - + dones[env_id] = done if not self.policy_config.ignore_done else False + + # Update logging and priority data. if not collect_with_pure_policy: visit_entropies_lst[env_id] += visit_entropy_dict[env_id] if self.policy_config.gumbel_algo: completed_value_lst[env_id] += np.mean(np.array(completed_value_dict[env_id])) - + eps_steps_lst[env_id] += 1 - if self._policy.get_attribute('cfg').type in ['unizero', 'sampled_unizero']: - # only for UniZero now - self._policy.reset(env_id=env_id, current_steps=eps_steps_lst[env_id], reset_init_data=False) - - total_transitions += 1 - if self.policy_config.use_priority: pred_values_lst[env_id].append(pred_value_dict[env_id]) search_values_lst[env_id].append(value_dict[env_id]) - if self.policy_config.gumbel_algo and not collect_with_pure_policy: - improved_policy_lst[env_id].append(improved_policy_dict[env_id]) - # append the newest obs + # Update the observation window with the new observation. observation_window_stack[env_id].append(to_ndarray(obs['observation'])) # ============================================================== - # we will save a game segment if it is the end of the game or the next game segment is finished. + # Game Segment Saving Logic # ============================================================== - - # if game segment is full, we will save the last game segment + # If a segment is full, pad and save the previous segment. if game_segments[env_id].is_full(): - # pad over last segment trajectory if last_game_segments[env_id] is not None: - # TODO(pu): return the one game segment - self.pad_and_save_last_trajectory( - env_id, last_game_segments, last_game_priorities, game_segments, dones - ) + self.pad_and_save_last_trajectory(env_id, last_game_segments, last_game_priorities, game_segments, dones) - # calculate priority + # Calculate priorities for the now-completed `last_game_segment`. priorities = self._compute_priorities(env_id, pred_values_lst, search_values_lst) - pred_values_lst[env_id] = [] - search_values_lst[env_id] = [] - if self.policy_config.gumbel_algo and not collect_with_pure_policy: - improved_policy_lst[env_id] = [] + pred_values_lst[env_id], search_values_lst[env_id] = [], [] - # the current game_segments become last_game_segment + # The current segment becomes the `last_game_segment`. last_game_segments[env_id] = game_segments[env_id] last_game_priorities[env_id] = priorities - # create new GameSegment - game_segments[env_id] = GameSegment( - self._env.action_space, - game_segment_length=self.policy_config.game_segment_length, - config=self.policy_config - ) + # Start a new game segment. 
+ game_segments[env_id] = GameSegment(self._env.action_space, game_segment_length=self.policy_config.game_segment_length, config=self.policy_config) game_segments[env_id].reset(observation_window_stack[env_id]) self._env_info[env_id]['step'] += 1 - if "world_model_cfg" in self.policy_config.model and self.policy_config.model.world_model_cfg.obs_type == 'text': - self._env_info[env_id]['text_bleu'] += text_bleu - collected_step += 1 self._env_info[env_id]['time'] += self._timer.value + interaction_duration - if episode_timestep.done: - reward = episode_timestep.info['eval_episode_return'] - info = { - 'reward': reward, - 'time': self._env_info[env_id]['time'], - 'step': self._env_info[env_id]['step'], - } - if "world_model_cfg" in self.policy_config.model and self.policy_config.model.world_model_cfg.obs_type == 'text': - info.update({'text_bleu':self._env_info[env_id]['text_bleu'] / self._env_info[env_id]['step']}) - + + # --- Episode Termination Handling --- + if done: + collected_episode += 1 + reward = info['eval_episode_return'] + log_info = {'reward': reward, 'time': self._env_info[env_id]['time'], 'step': self._env_info[env_id]['step']} if not collect_with_pure_policy: - info['visit_entropy'] = visit_entropies_lst[env_id] / eps_steps_lst[env_id] + log_info['visit_entropy'] = visit_entropies_lst[env_id] / eps_steps_lst[env_id] if eps_steps_lst[env_id] > 0 else 0 if self.policy_config.gumbel_algo: - info['completed_value'] = completed_value_lst[env_id] / eps_steps_lst[env_id] + log_info['completed_value'] = completed_value_lst[env_id] / eps_steps_lst[env_id] if eps_steps_lst[env_id] > 0 else 0 + self._episode_info.append(log_info) - collected_episode += 1 - self._episode_info.append(info) - - # ============================================================== - # if it is the end of the game, we will save the game segment - # ============================================================== - - # NOTE: put the penultimate game segment in one episode into the trajectory_pool - # pad over 2th last game_segment using the last game_segment + # Pad and save the segment before the final one. if last_game_segments[env_id] is not None: - self.pad_and_save_last_trajectory( - env_id, last_game_segments, last_game_priorities, game_segments, dones - ) - - # store current segment trajectory + self.pad_and_save_last_trajectory(env_id, last_game_segments, last_game_priorities, game_segments, dones) + + # Process and save the final segment of the episode. priorities = self._compute_priorities(env_id, pred_values_lst, search_values_lst) - - # NOTE: put the last game segment in one episode into the trajectory_pool game_segments[env_id].game_segment_to_array() - - # assert len(game_segments[env_id]) == len(priorities) - # NOTE: save the last game segment in one episode into the trajectory_pool if it's not null - if len(game_segments[env_id].reward_segment) != 0: + if len(game_segments[env_id].reward_segment) > 0: self.game_segment_pool.append((game_segments[env_id], priorities, dones[env_id])) - # print(game_segments[env_id].reward_segment) - # reset the finished env and init game_segments + # Reset environment-specific states for a new episode. if n_episode > self._env_num: - # Get current ready env obs. + # Re-initialize the state for this env_id. init_obs = self._env.ready_obs - retry_waiting_time = 0.001 - while len(init_obs.keys()) != self._env_num: - # To be compatible with subprocess env_manager, in which sometimes self._env_num is not equal to - # len(self._env.ready_obs), especially in tictactoe env. 
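# --- Illustrative sketch (not part of this patch) ---
# Rough shape of the priority computation used by the segment-rotation logic above:
# priorities are the absolute gap between the model's predicted value and the MCTS
# searched value, plus a small epsilon so no transition gets zero priority.
# `sketch_priorities` is a hypothetical stand-in for self._compute_priorities().
import numpy as np

def sketch_priorities(pred_values, search_values):
    return np.abs(np.asarray(pred_values, dtype=np.float32) - np.asarray(search_values, dtype=np.float32)) + 1e-6

print(sketch_priorities([0.40, 0.10], [0.42, 0.05]))  # ~[0.02, 0.05]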
- self._logger.info('The current init_obs.keys() is {}'.format(init_obs.keys())) - self._logger.info('Before sleeping, the _env_states is {}'.format(self._env._env_states)) + while env_id not in init_obs: + self._logger.warning(f"Waiting for env {env_id} to reset...") time.sleep(retry_waiting_time) - self._logger.info( - '=' * 10 + 'Wait for all environments (subprocess) to finish resetting.' + '=' * 10 - ) - self._logger.info( - 'After sleeping {}s, the current _env_states is {}'.format( - retry_waiting_time, self._env._env_states - ) - ) init_obs = self._env.ready_obs - - new_available_env_id = set(init_obs.keys()).difference(ready_env_id) - ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode])) - remain_episode -= min(len(new_available_env_id), remain_episode) - + action_mask_dict[env_id] = to_ndarray(init_obs[env_id]['action_mask']) to_play_dict[env_id] = to_ndarray(init_obs[env_id]['to_play']) timestep_dict[env_id] = to_ndarray(init_obs[env_id].get('timestep', -1)) - if self.policy_config.use_ture_chance_label_in_chance_encoder: - chance_dict[env_id] = to_ndarray(init_obs[env_id]['chance']) - - game_segments[env_id] = GameSegment( - self._env.action_space, - game_segment_length=self.policy_config.game_segment_length, - config=self.policy_config - ) - observation_window_stack[env_id] = deque( - [init_obs[env_id]['observation'] for _ in range(self.policy_config.model.frame_stack_num)], - maxlen=self.policy_config.model.frame_stack_num - ) + chance_dict[env_id] = to_ndarray(init_obs[env_id]['chance']) + + # Reset game segment and observation stack. + game_segments[env_id] = GameSegment(self._env.action_space, game_segment_length=self.policy_config.game_segment_length, config=self.policy_config) + observation_window_stack[env_id].clear() + for _ in range(self.policy_config.model.frame_stack_num): + observation_window_stack[env_id].append(init_obs[env_id]['observation']) game_segments[env_id].reset(observation_window_stack[env_id]) last_game_segments[env_id] = None last_game_priorities[env_id] = None - # log - self_play_moves_max = max(self_play_moves_max, eps_steps_lst[env_id]) - if not collect_with_pure_policy: - self_play_visit_entropy.append(visit_entropies_lst[env_id] / eps_steps_lst[env_id]) - self_play_moves += eps_steps_lst[env_id] - self_play_episodes += 1 - - pred_values_lst[env_id] = [] - search_values_lst[env_id] = [] - eps_steps_lst[env_id] = 0 - visit_entropies_lst[env_id] = 0 + # Reset tracking and logging variables. + pred_values_lst[env_id], search_values_lst[env_id] = [], [] + eps_steps_lst[env_id], visit_entropies_lst[env_id] = 0, 0 + if self.policy_config.gumbel_algo: + completed_value_lst[env_id] = 0 - # Env reset is done by env_manager automatically - self._policy.reset([env_id]) # NOTE: reset the policy for the env_id. Default reset_init_data=True. + # Reset policy and collector stats for the finished environment. + self._policy.reset([env_id]) self._reset_stat(env_id) ready_env_id.remove(env_id) + # --- Check for Collection Completion --- if collected_episode >= n_episode: - # [data, meta_data] - return_data = [self.game_segment_pool[i][0] for i in range(len(self.game_segment_pool))], [ - { - 'priorities': self.game_segment_pool[i][1], - 'done': self.game_segment_pool[i][2], + # Prepare data for returning. 
+ return_data = [ + [item[0] for item in self.game_segment_pool], + [{ + 'priorities': item[1], + 'done': item[2], 'unroll_plus_td_steps': self.unroll_plus_td_steps - } for i in range(len(self.game_segment_pool)) + } for item in self.game_segment_pool] ] self.game_segment_pool.clear() break - + + # --- Finalize and Log --- collected_duration = sum([d['time'] for d in self._episode_info]) - # reduce data when enables DDP - if self._world_size > 1: - # Before allreduce - self._logger.info(f"Rank {self._rank} before allreduce: collected_step={collected_step}, collected_episode={collected_episode}") - collected_step = allreduce_data(collected_step, 'sum') - collected_episode = allreduce_data(collected_episode, 'sum') - collected_duration = allreduce_data(collected_duration, 'sum') - # After allreduce - self._logger.info(f"Rank {self._rank} after allreduce: collected_step={collected_step}, collected_episode={collected_episode}") + # NOTE: Only for usual DDP not for unizero_multitask pipeline. + # In DDP, aggregate statistics across all processes. + # if self._world_size > 1: + # collected_step = allreduce_data(collected_step, 'sum') + # collected_episode = allreduce_data(collected_episode, 'sum') + # collected_duration = allreduce_data(collected_duration, 'sum') self._total_envstep_count += collected_step self._total_episode_count += collected_episode self._total_duration += collected_duration - # log self._output_log(train_iter) return return_data def _output_log(self, train_iter: int) -> None: """ Overview: - Log the collector's data and output the log information. + Aggregates and logs collection statistics to the console, TensorBoard, and WandB. + This method is only executed by the rank 0 process in a distributed setup. Arguments: - - train_iter (:obj:`int`): Current training iteration number for logging context. + - train_iter (:obj:`int`): The current training iteration number, used as the logging step. 
""" if self._rank != 0: return + if (train_iter - self._last_train_iter) >= self._collect_print_freq and len(self._episode_info) > 0: self._last_train_iter = train_iter episode_count = len(self._episode_info) envstep_count = sum([d['step'] for d in self._episode_info]) duration = sum([d['time'] for d in self._episode_info]) episode_reward = [d['reward'] for d in self._episode_info] - if "world_model_cfg" in self.policy_config.model and self.policy_config.model.world_model_cfg.obs_type == 'text': - episode_bleu = [d['text_bleu'] for d in self._episode_info] - - if not self.collect_with_pure_policy: - visit_entropy = [d['visit_entropy'] for d in self._episode_info] - else: - visit_entropy = [0.0] - if self.policy_config.gumbel_algo: - completed_value = [d['completed_value'] for d in self._episode_info] - self._total_duration += duration + info = { 'episode_count': episode_count, 'envstep_count': envstep_count, 'avg_envstep_per_episode': envstep_count / episode_count, - 'avg_envstep_per_sec': envstep_count / duration, - 'avg_episode_per_sec': episode_count / duration, + 'avg_envstep_per_sec': envstep_count / duration if duration > 0 else 0, + 'avg_episode_per_sec': episode_count / duration if duration > 0 else 0, 'collect_time': duration, 'reward_mean': np.mean(episode_reward), 'reward_std': np.std(episode_reward), @@ -822,22 +651,32 @@ def _output_log(self, train_iter: int) -> None: 'total_envstep_count': self._total_envstep_count, 'total_episode_count': self._total_episode_count, 'total_duration': self._total_duration, - 'visit_entropy': np.mean(visit_entropy), } - if "world_model_cfg" in self.policy_config.model and self.policy_config.model.world_model_cfg.obs_type == 'text': - info.update({'text_avg_bleu':np.mean(episode_bleu)}) + + if not self.collect_with_pure_policy: + visit_entropy = [d['visit_entropy'] for d in self._episode_info] + info['visit_entropy_mean'] = np.mean(visit_entropy) if self.policy_config.gumbel_algo: - info['completed_value'] = np.mean(completed_value) + completed_value = [d['completed_value'] for d in self._episode_info] + info['completed_value_mean'] = np.mean(completed_value) + self._episode_info.clear() - self._logger.info("collect end:\n{}".format('\n'.join(['{}: {}'.format(k, v) for k, v in info.items()]))) + # Log to console + self._logger.info("Collector Training Summary:\n{}".format('\n'.join([f' {k}: {v}' for k, v in info.items()]))) + + # Log to TensorBoard and WandB for k, v in info.items(): - if k in ['each_reward']: - continue - self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter) - if k in ['total_envstep_count']: - continue - self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, self._total_envstep_count) - + if self.task_id is None: + tb_prefix_iter = f'{self._instance_name}_iter/' + tb_prefix_step = f'{self._instance_name}_step/' + else: + tb_prefix_iter = f'{self._instance_name}_iter_task{self.task_id}/' + tb_prefix_step = f'{self._instance_name}_step_task{self.task_id}/' + + self._tb_logger.add_scalar(tb_prefix_iter + k, v, train_iter) + self._tb_logger.add_scalar(tb_prefix_step + k, v, self._total_envstep_count) + if self.policy_config.use_wandb: - wandb.log({'{}_step/'.format(self._instance_name) + k: v for k, v in info.items()}, step=self._total_envstep_count) + wandb_log_data = {tb_prefix_step + k: v for k, v in info.items()} + wandb.log(wandb_log_data, step=self._total_envstep_count) \ No newline at end of file diff --git a/lzero/worker/muzero_evaluator.py 
b/lzero/worker/muzero_evaluator.py index 2a70feea5..5fc680d97 100644 --- a/lzero/worker/muzero_evaluator.py +++ b/lzero/worker/muzero_evaluator.py @@ -1,18 +1,20 @@ import copy +import threading import time from collections import namedtuple -from typing import Optional, Callable, Tuple, Dict, Any +from typing import Any, Callable, Dict, Optional, Tuple import numpy as np import torch import wandb from ding.envs import BaseEnvManager -from ding.torch_utils import to_ndarray, to_item, to_tensor -from ding.utils import build_logger, EasyTimer -from ding.utils import get_world_size, get_rank, broadcast_object_list -from ding.worker.collector.base_serial_evaluator import ISerialEvaluator, VectorEvalMonitor +from ding.torch_utils import to_item, to_ndarray, to_tensor +from ding.utils import (EasyTimer, broadcast_object_list, build_logger, + get_rank, get_world_size) +from ding.worker.collector.base_serial_evaluator import (ISerialEvaluator, + VectorEvalMonitor) +from ditk import logging from easydict import EasyDict - from lzero.mcts.buffer.game_segment import GameSegment from lzero.mcts.utils import prepare_observation @@ -20,95 +22,101 @@ class MuZeroEvaluator(ISerialEvaluator): """ Overview: - The Evaluator class for MCTS+RL algorithms, such as MuZero, EfficientZero, and Sampled EfficientZero. + The Evaluator for MCTS-based reinforcement learning algorithms, such as MuZero, EfficientZero, and Sampled EfficientZero. Interfaces: __init__, reset, reset_policy, reset_env, close, should_eval, eval Properties: env, policy """ + # Default configuration for the MuZeroEvaluator. + config = dict( + # The frequency of evaluation, measured in training iterations. + eval_freq=5000, + ) + @classmethod def default_config(cls: type) -> EasyDict: """ Overview: - Retrieve the default configuration for the evaluator by merging evaluator-specific defaults with other - defaults and any user-provided configuration. + Get the default configuration of the MuZeroEvaluator. Returns: - - cfg (:obj:`EasyDict`): The default configuration for the evaluator. + - cfg (:obj:`EasyDict`): An EasyDict object representing the default configuration. """ cfg = EasyDict(copy.deepcopy(cls.config)) cfg.cfg_type = cls.__name__ + 'Dict' return cfg - config = dict( - # Evaluate every "eval_freq" training iterations. - eval_freq=50, - ) - def __init__( self, eval_freq: int = 1000, n_evaluator_episode: int = 3, - stop_value: int = 1e6, - env: BaseEnvManager = None, - policy: namedtuple = None, - tb_logger: 'SummaryWriter' = None, # noqa - exp_name: Optional[str] = 'default_experiment', - instance_name: Optional[str] = 'evaluator', - policy_config: 'policy_config' = None, # noqa + stop_value: float = 1e6, + env: Optional[BaseEnvManager] = None, + policy: Optional[namedtuple] = None, + tb_logger: Optional['SummaryWriter'] = None, + exp_name: str = 'default_experiment', + instance_name: str = 'evaluator', + policy_config: Optional[EasyDict] = None, + task_id: Optional[int] = None, ) -> None: """ Overview: - Initialize the evaluator with configuration settings for various components such as logger helper and timer. + Initializes the MuZeroEvaluator. This evaluator is compatible with MuZero, Sampled MuZero, Gumbel MuZero, EfficientZero, UniZero, and Sampled UniZero (i.e., all algorithms except AlphaZero). Arguments: - - eval_freq (:obj:`int`): Evaluation frequency in terms of training steps. - - n_evaluator_episode (:obj:`int`): Number of episodes to evaluate in total. 
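# --- Illustrative sketch (not part of this patch) ---
# Typical consumption of the class-level `config` above: `default_config()` deep-copies it
# into an EasyDict (adding `cfg_type`) that callers may override before constructing the
# evaluator. The override value below is hypothetical.
from easydict import EasyDict

cfg = EasyDict({'eval_freq': 5000, 'cfg_type': 'MuZeroEvaluatorDict'})  # what MuZeroEvaluator.default_config() yields
cfg.eval_freq = 2000  # e.g., evaluate every 2000 training iterations instead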
- - stop_value (:obj:`float`): A reward threshold above which the training is considered converged. - - env (:obj:`Optional[BaseEnvManager]`): An optional instance of a subclass of BaseEnvManager. - - policy (:obj:`Optional[namedtuple]`): An optional API namedtuple defining the policy for evaluation. - - tb_logger (:obj:`Optional[SummaryWriter]`): Optional TensorBoard logger instance. - - exp_name (:obj:`str`): Name of the experiment, used to determine output directory. - - instance_name (:obj:`str`): Name of this evaluator instance. - - policy_config (:obj:`Optional[dict]`): Optional configuration for the game policy. + - eval_freq (:obj:`int`): The frequency, in training iterations, at which to run evaluation. + - n_evaluator_episode (:obj:`int`): The total number of episodes to run during each evaluation. + - stop_value (:obj:`float`): The reward threshold at which training is considered converged and will stop. + - env (:obj:`Optional[BaseEnvManager]`): An optional environment manager for evaluation. + - policy (:obj:`Optional[namedtuple]`): An optional policy for evaluation. + - tb_logger (:obj:`Optional['SummaryWriter']`): An optional TensorBoard logger. + - exp_name (:obj:`str`): The name of the experiment, used for logging. + - instance_name (:obj:`str`): The name of this evaluator instance. + - policy_config (:obj:`Optional[EasyDict]`): Configuration for the policy. + - task_id (:obj:`Optional[int]`): The unique identifier for the task. If None, the evaluator operates in single-task mode. In a multi-task setting, each task corresponds to a specific evaluator instance. """ + self.stop_event = threading.Event() # Event to signal a stop, e.g., due to a timeout. + self.task_id = task_id self._eval_freq = eval_freq self._exp_name = exp_name self._instance_name = instance_name + self._rank = get_rank() - # Logger (Monitor will be initialized in policy setter) - # Only rank == 0 learner needs monitor and tb_logger, others only need text_logger to display terminal output. - if get_rank() == 0: + # Initialize logger. Only rank 0 needs a full logger with TensorBoard. + if self._rank == 0: if tb_logger is not None: self._logger, _ = build_logger( - './{}/log/{}'.format(self._exp_name, self._instance_name), self._instance_name, need_tb=False + f'./{self._exp_name}/log/{self._instance_name}', self._instance_name, need_tb=False ) self._tb_logger = tb_logger else: self._logger, self._tb_logger = build_logger( - './{}/log/{}'.format(self._exp_name, self._instance_name), self._instance_name + f'./{self._exp_name}/log/{self._instance_name}', self._instance_name ) else: - self._logger, self._tb_logger = None, None # for close elegantly + if tb_logger is not None: + self._logger, _ = build_logger( + f'./{self._exp_name}/log/{self._instance_name}', self._instance_name, need_tb=False + ) + self._tb_logger = tb_logger + logging.info(f'rank {self._rank}, self.task_id: {self.task_id}') self.reset(policy, env) - self._timer = EasyTimer() self._default_n_episode = n_evaluator_episode self._stop_value = stop_value # ============================================================== - # MCTS+RL related core code + # MCTS+RL related core properties # ============================================================== self.policy_config = policy_config def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None: """ Overview: - Reset the environment for the evaluator, optionally replacing it with a new environment. - If _env is None, reset the old environment. 
If _env is not None, replace the old environment - in the evaluator with the new passed in environment and launch. + Reset the environment. If a new environment is provided, it replaces the old one. Arguments: - - _env (:obj:`Optional[BaseEnvManager]`): An optional new environment instance to replace the existing one. + - _env (:obj:`Optional[BaseEnvManager]`): New environment manager to use. If None, resets the existing environment. """ if _env is not None: self._env = _env @@ -120,29 +128,22 @@ def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None: def reset_policy(self, _policy: Optional[namedtuple] = None) -> None: """ Overview: - Reset the policy for the evaluator, optionally replacing it with a new policy. - If _policy is None, reset the old policy. - If _policy is not None, replace the old policy in the evaluator with the new passed in policy. + Reset the policy. If a new policy is provided, it replaces the old one. Arguments: - - _policy (:obj:`Optional[namedtuple]`): An optional new policy namedtuple to replace the existing one. + - _policy (:obj:`Optional[namedtuple]`): New policy to use. If None, resets the existing policy. """ - assert hasattr(self, '_env'), "please set env first" + assert hasattr(self, '_env'), "Please set environment first." if _policy is not None: self._policy = _policy - self._policy.reset() + self._policy.reset(task_id=self.task_id) def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvManager] = None) -> None: """ Overview: - Reset both the policy and environment for the evaluator, optionally replacing them. - If _env is None, reset the old environment. - If _env is not None, replace the old environment in the evaluator with the new passed in \ - environment and launch. - If _policy is None, reset the old policy. - If _policy is not None, replace the old policy in the evaluator with the new passed in policy. + Reset both the policy and the environment. Arguments: - - _policy (:obj:`Optional[namedtuple]`): An optional new policy namedtuple to replace the existing one. - - _env (:obj:`Optional[BaseEnvManager]`): An optional new environment instance to replace the existing one. + - _policy (:obj:`Optional[namedtuple]`): New policy to use. + - _env (:obj:`Optional[BaseEnvManager]`): New environment manager to use. """ if _env is not None: self.reset_env(_env) @@ -151,37 +152,36 @@ def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvMana self._max_episode_return = float("-inf") self._last_eval_iter = 0 self._end_flag = False - def close(self) -> None: """ Overview: - Close the evaluator, the environment, flush and close the TensorBoard logger if applicable. + Close the evaluator, including the environment and the TensorBoard logger. """ if self._end_flag: return self._end_flag = True - self._env.close() + if hasattr(self, '_env'): + self._env.close() if self._tb_logger: self._tb_logger.flush() self._tb_logger.close() - def __del__(self): + def __del__(self) -> None: """ Overview: - Execute the close command and close the evaluator. __del__ is automatically called \ - to destroy the evaluator instance when the evaluator finishes its work + Destructor that ensures `close` is called to clean up resources. """ self.close() def should_eval(self, train_iter: int) -> bool: """ Overview: - Determine whether to initiate evaluation based on the training iteration count and evaluation frequency. + Determine whether it's time to run an evaluation based on the training iteration. 
Arguments: - - train_iter (:obj:`int`): The current count of training iterations. + - train_iter (:obj:`int`): The current training iteration. Returns: - - (:obj:`bool`): `True` if evaluation should be initiated, otherwise `False`. + - (:obj:`bool`): True if evaluation should be run, otherwise False. """ if train_iter == self._last_eval_iter: return False @@ -192,54 +192,64 @@ def should_eval(self, train_iter: int) -> bool: def eval( self, - save_ckpt_fn: Callable = None, + save_ckpt_fn: Optional[Callable] = None, train_iter: int = -1, envstep: int = -1, n_episode: Optional[int] = None, return_trajectory: bool = False, - ) -> Tuple[bool, float]: + ) -> Tuple[bool, Dict[str, Any]]: """ Overview: - Evaluate the current policy, storing the best policy if it achieves the highest historical reward. + Run a full evaluation process. It will evaluate the current policy, log the results, + and save a checkpoint if a new best performance is achieved. Arguments: - - save_ckpt_fn (:obj:`Optional[Callable]`): Optional function to save a checkpoint when a new best reward is achieved. - - train_iter (:obj:`int`): The current training iteration count. - - envstep (:obj:`int`): The current environment step count. - - n_episode (:obj:`Optional[int]`): Optional number of evaluation episodes; defaults to the evaluator's setting. - - return_trajectory (:obj:`bool`): Return the evaluated trajectory `game_segments` in `episode_info` if True. + - save_ckpt_fn (:obj:`Optional[Callable]`): A function to save a checkpoint. Called when a new best reward is achieved. + - train_iter (:obj:`int`): The current training iteration. + - envstep (:obj:`int`): The current total environment steps. + - n_episode (:obj:`Optional[int]`): The number of episodes to evaluate. Defaults to the value set in `__init__`. + - return_trajectory (:obj:`bool`): Whether to return the collected `game_segments` in the result dictionary. Returns: - - stop_flag (:obj:`bool`): Indicates whether the training can be stopped based on the stop value. - - episode_info (:obj:`Dict[str, Any]`): A dictionary containing information about the evaluation episodes. + - stop_flag (:obj:`bool`): A flag indicating whether the training should stop (e.g., if the stop value is reached). + - episode_info (:obj:`Dict[str, Any]`): A dictionary containing evaluation results, such as rewards and episode lengths. """ - # the evaluator only works on rank0 + if torch.cuda.is_available() and self.task_id is not None: + # NOTE: important for unizero_multitask pipeline. + self._logger.info(f"=========in eval() Rank {get_rank()} ===========") + device = torch.cuda.current_device() + self._logger.info(f"before set device: {device}") + torch.cuda.set_device(get_rank()) + self._logger.info(f"after set device: {get_rank()}") + episode_info = None stop_flag = False - if get_rank() == 0: + if self.task_id is not None and get_rank() >= 0: + # In a multi-task setting, each task corresponds to a specific evaluator instance. + eval_flag = True + elif self.task_id is None and get_rank() == 0: + # In a single-task setting, only evaluate rank 0. + eval_flag = True + else: + eval_flag = False + + if eval_flag: if n_episode is None: n_episode = self._default_n_episode - assert n_episode is not None, "please indicate eval n_episode" + assert n_episode is not None, "Please specify the number of evaluation episodes (n_episode)." 
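# --- Illustrative sketch (not part of this patch) ---
# The `eval_flag` gating above reduces to: in multi-task mode (task_id is not None) every
# rank evaluates its own task; in single-task mode only rank 0 evaluates.
def sketch_should_this_rank_eval(task_id, rank):
    return True if task_id is not None else rank == 0

assert sketch_should_this_rank_eval(task_id=None, rank=0) is True
assert sketch_should_this_rank_eval(task_id=None, rank=1) is False
assert sketch_should_this_rank_eval(task_id=3, rank=1) is True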
envstep_count = 0 eval_monitor = VectorEvalMonitor(self._env.env_num, n_episode) env_nums = self._env.env_num self._env.reset() - self._policy.reset() + self._policy.reset(task_id=self.task_id) - # initializations + # Initializations init_obs = self._env.ready_obs + # Wait for all environments to be ready, especially in subprocess-based environment managers. retry_waiting_time = 0.001 while len(init_obs.keys()) != self._env_num: - # To be compatible with subprocess env_manager, in which sometimes self._env_num is not equal to - # len(self._env.ready_obs), especially in tictactoe env. - self._logger.info('The current init_obs.keys() is {}'.format(init_obs.keys())) - self._logger.info('Before sleeping, the _env_states is {}'.format(self._env._env_states)) + self._logger.info(f"Waiting for all environments to reset. Current ready envs: {list(init_obs.keys())}") time.sleep(retry_waiting_time) - self._logger.info('=' * 10 + 'Wait for all environments (subprocess) to finish resetting.' + '=' * 10) - self._logger.info( - 'After sleeping {}s, the current _env_states is {}'.format(retry_waiting_time, - self._env._env_states) - ) init_obs = self._env.ready_obs action_mask_dict = {i: to_ndarray(init_obs[i]['action_mask']) for i in range(env_nums)} @@ -248,20 +258,17 @@ def eval( timestep_dict = {} for i in range(env_nums): if 'timestep' not in init_obs[i]: - if self._policy.get_attribute('cfg').type in ['unizero', 'sampled_unizero']: - print(f"Warning: 'timestep' key is missing in init_obs[{i}]. Assigning value -1. Please note that the unizero algorithm may require the 'timestep' key in init_obs.") + self._logger.warning(f"'timestep' key is missing in init_obs[{i}], assigning value -1") timestep_dict[i] = to_ndarray(init_obs[i].get('timestep', -1)) - if self.policy_config.use_ture_chance_label_in_chance_encoder: - chance_dict = {i: to_ndarray(init_obs[i]['chance']) for i in range(env_nums)} - dones = np.array([False for _ in range(env_nums)]) game_segments = [ GameSegment( self._env.action_space, game_segment_length=self.policy_config.game_segment_length, - config=self.policy_config + config=self.policy_config, + task_id=self.task_id ) for _ in range(env_nums) ] for i in range(env_nums): @@ -272,73 +279,55 @@ def eval( ready_env_id = set() remain_episode = n_episode eps_steps_lst = np.zeros(env_nums) - with self._timer: while not eval_monitor.is_finished(): - # Get current ready env obs. + # Check if a timeout has occurred. + if self.stop_event.is_set(): + # self.stop_event may be set in the safe_eval() method in lzero/entry/utils.py + self._logger.info("[EVALUATOR]: Evaluation aborted due to timeout.") + break + + # Get observations from ready environments. obs = self._env.ready_obs new_available_env_id = set(obs.keys()).difference(ready_env_id) ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode])) remain_episode -= min(len(new_available_env_id), remain_episode) - # In a parallel evaluation setting, it's possible for all active environments to finish their - # episodes simultaneously. This can leave `ready_env_id` temporarily empty while the environments - # are being reset by the manager. - # To prevent processing an empty batch, which would cause an IndexError or other errors downstream, - # we check if `ready_env_id` is empty. If so, we sleep briefly to prevent a busy-wait, - # and `continue` to the next loop iteration to wait for newly reset environments to become available.
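# --- Illustrative sketch (not part of this patch) ---
# The wait loop above follows this retry pattern; `get_ready_obs` is a hypothetical
# stand-in for `self._env.ready_obs`, which subprocess env managers may fill gradually.
import time

def sketch_wait_for_ready(get_ready_obs, env_num, retry_waiting_time=0.001):
    obs = get_ready_obs()
    while len(obs) != env_num:
        time.sleep(retry_waiting_time)  # brief sleep instead of busy-waiting
        obs = get_ready_obs()
    return obs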
- if not ready_env_id: - time.sleep(0.01) - continue - + # Prepare stacked observations and other inputs for the policy. stack_obs = {env_id: game_segments[env_id].get_obs() for env_id in ready_env_id} stack_obs = list(stack_obs.values()) - - action_mask_dict = {env_id: action_mask_dict[env_id] for env_id in ready_env_id} - to_play_dict = {env_id: to_play_dict[env_id] for env_id in ready_env_id} - timestep_dict = {env_id: timestep_dict[env_id] for env_id in ready_env_id} action_mask = [action_mask_dict[env_id] for env_id in ready_env_id] to_play = [to_play_dict[env_id] for env_id in ready_env_id] timestep = [timestep_dict[env_id] for env_id in ready_env_id] - if self.policy_config.use_ture_chance_label_in_chance_encoder: - chance_dict = {env_id: chance_dict[env_id] for env_id in ready_env_id} - stack_obs = to_ndarray(stack_obs) stack_obs = prepare_observation(stack_obs, self.policy_config.model.model_type) stack_obs = torch.from_numpy(stack_obs).to(self.policy_config.device).float() # ============================================================== - # policy forward + # Policy Forward Pass # ============================================================== - policy_output = self._policy.forward(stack_obs, action_mask, to_play, ready_env_id=ready_env_id, timestep=timestep) - + if self.task_id is None: + # Single-task setting + policy_output = self._policy.forward(stack_obs, action_mask, to_play, ready_env_id=ready_env_id, timestep=timestep) + else: + # Multi-task setting + policy_output = self._policy.forward(stack_obs, action_mask, to_play, ready_env_id=ready_env_id, timestep=timestep, task_id=self.task_id) + + # Unpack policy outputs. actions_with_env_id = {k: v['action'] for k, v in policy_output.items()} distributions_dict_with_env_id = {k: v['visit_count_distributions'] for k, v in policy_output.items()} if self.policy_config.sampled_algo: - root_sampled_actions_dict_with_env_id = { - k: v['root_sampled_actions'] - for k, v in policy_output.items() - } - + root_sampled_actions_dict_with_env_id = {k: v['root_sampled_actions'] for k, v in policy_output.items()} value_dict_with_env_id = {k: v['searched_value'] for k, v in policy_output.items()} pred_value_dict_with_env_id = {k: v['predicted_value'] for k, v in policy_output.items()} - timestep_dict_with_env_id = { - k: v['timestep'] if 'timestep' in v else -1 for k, v in policy_output.items() - } - visit_entropy_dict_with_env_id = { - k: v['visit_count_distribution_entropy'] - for k, v in policy_output.items() - } - - actions = {} - distributions_dict = {} + timestep_dict_with_env_id = {k: v.get('timestep', -1) for k, v in policy_output.items()} + visit_entropy_dict_with_env_id = {k: v['visit_count_distribution_entropy'] for k, v in policy_output.items()} + + # Remap outputs from policy's internal IDs to environment IDs. + actions, distributions_dict, value_dict, pred_value_dict, timestep_dict, visit_entropy_dict = {}, {}, {}, {}, {}, {} if self.policy_config.sampled_algo: root_sampled_actions_dict = {} - value_dict = {} - pred_value_dict = {} - timestep_dict = {} - visit_entropy_dict = {} for index, env_id in enumerate(ready_env_id): actions[env_id] = actions_with_env_id.pop(env_id) @@ -351,45 +340,30 @@ def eval( visit_entropy_dict[env_id] = visit_entropy_dict_with_env_id.pop(env_id) # ============================================================== - # Interact with env. 
+ # Environment Interaction # ============================================================== timesteps = self._env.step(actions) timesteps = to_tensor(timesteps, dtype=torch.float32) - for env_id, episode_timestep in timesteps.items(): obs, reward, done, info = episode_timestep.obs, episode_timestep.reward, episode_timestep.done, episode_timestep.info - # obs_input_ids = obs['observation'].long() - # obs_attn_mask = obs['obs_attn_mask'][0].long() - # valid_input_ids = obs_input_ids[obs_attn_mask == 1].tolist() - eps_steps_lst[env_id] += 1 + # This reset logic is specific to UniZero-like models. if self._policy.get_attribute('cfg').type in ['unizero', 'sampled_unizero']: - # only for UniZero now - self._policy.reset(env_id=env_id, current_steps=eps_steps_lst[env_id], reset_init_data=False) + self._policy.reset(env_id=env_id, current_steps=eps_steps_lst[env_id], reset_init_data=False, task_id=self.task_id) - if self.policy_config.use_ture_chance_label_in_chance_encoder: - game_segments[env_id].append( - actions[env_id], to_ndarray(obs['observation']), reward, action_mask_dict[env_id], - to_play_dict[env_id], timestep_dict[env_id], chance_dict[env_id] - ) - else: - game_segments[env_id].append( - actions[env_id], to_ndarray(obs['observation']), reward, action_mask_dict[env_id], - to_play_dict[env_id], timestep_dict[env_id] - ) + game_segments[env_id].append( + actions[env_id], to_ndarray(obs['observation']), reward, action_mask_dict[env_id], + to_play_dict[env_id], timestep_dict[env_id] + ) - # NOTE: the position of code snippet is very important. - # the obs['action_mask'] and obs['to_play'] are corresponding to next action + # IMPORTANT: The action_mask and to_play from the new observation correspond to the *next* state. action_mask_dict[env_id] = to_ndarray(obs['action_mask']) to_play_dict[env_id] = to_ndarray(obs['to_play']) timestep_dict[env_id] = to_ndarray(obs.get('timestep', -1)) - if self.policy_config.use_ture_chance_label_in_chance_encoder: - chance_dict[env_id] = to_ndarray(obs['chance']) dones[env_id] = done if episode_timestep.done: - # Env reset is done by env_manager automatically. self._policy.reset([env_id]) reward = episode_timestep.info['eval_episode_return'] saved_info = {'eval_episode_return': episode_timestep.info['eval_episode_return']} @@ -398,115 +372,105 @@ def eval( eval_monitor.update_info(env_id, saved_info) eval_monitor.update_reward(env_id, reward) self._logger.info( - "[EVALUATOR]env {} finish episode, final reward: {}, current episode: {}".format( - env_id, eval_monitor.get_latest_reward(env_id), eval_monitor.get_current_episode() - ) + f"[EVALUATOR] env {env_id} finished episode, final reward: {eval_monitor.get_latest_reward(env_id)}, " + f"current episode count: {eval_monitor.get_current_episode()}" ) - # reset the finished env and init game_segments + # If there are more episodes to run than available environments, reset and reuse this one. if n_episode > self._env_num: - # Get current ready env obs. init_obs = self._env.ready_obs - retry_waiting_time = 0.001 + # Wait for the environment to be ready again. while len(init_obs.keys()) != self._env_num: - # In order to be compatible with subprocess env_manager, in which sometimes self._env_num is not equal to - # len(self._env.ready_obs), especially in tictactoe env. - self._logger.info('The current init_obs.keys() is {}'.format(init_obs.keys())) - self._logger.info( - 'Before sleeping, the _env_states is {}'.format(self._env._env_states) - ) + self._logger.info(f"Waiting for env {env_id} to reset. 
Current ready envs: {list(init_obs.keys())}") time.sleep(retry_waiting_time) - self._logger.info( - '=' * 10 + 'Wait for all environments (subprocess) to finish resetting.' + '=' * 10 - ) - self._logger.info( - 'After sleeping {}s, the current _env_states is {}'.format( - retry_waiting_time, self._env._env_states - ) - ) init_obs = self._env.ready_obs new_available_env_id = set(init_obs.keys()).difference(ready_env_id) ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode])) remain_episode -= min(len(new_available_env_id), remain_episode) + # Re-initialize state for the new episode. action_mask_dict[env_id] = to_ndarray(init_obs[env_id]['action_mask']) to_play_dict[env_id] = to_ndarray(init_obs[env_id]['to_play']) timestep_dict[env_id] = to_ndarray(init_obs[env_id].get('timestep', -1)) - if self.policy_config.use_ture_chance_label_in_chance_encoder: - chance_dict[env_id] = to_ndarray(init_obs[env_id]['chance']) - game_segments[env_id] = GameSegment( self._env.action_space, game_segment_length=self.policy_config.game_segment_length, - config=self.policy_config + config=self.policy_config, + task_id=self.task_id ) - game_segments[env_id].reset( - [ - init_obs[env_id]['observation'] - for _ in range(self.policy_config.model.frame_stack_num) - ] + [init_obs[env_id]['observation'] for _ in range(self.policy_config.model.frame_stack_num)] ) eps_steps_lst[env_id] = 0 - - # Env reset is done by env_manager automatically. - self._policy.reset([env_id]) # NOTE: reset the policy for the env_id. Default reset_init_data=True. + # NOTE: Reset the policy state for this env_id. `reset_init_data` defaults to True. + self._policy.reset([env_id]) ready_env_id.remove(env_id) envstep_count += 1 - + duration = self._timer.value episode_return = eval_monitor.get_episode_return() info = { 'train_iter': train_iter, - 'ckpt_name': 'iteration_{}.pth.tar'.format(train_iter), + 'ckpt_name': f'iteration_{train_iter}.pth.tar', 'episode_count': n_episode, 'envstep_count': envstep_count, - 'avg_envstep_per_episode': envstep_count / n_episode, + 'avg_envstep_per_episode': envstep_count / n_episode if n_episode > 0 else 0, 'evaluate_time': duration, - 'avg_envstep_per_sec': envstep_count / duration, - 'avg_time_per_episode': n_episode / duration, + 'avg_envstep_per_sec': envstep_count / duration if duration > 0 else 0, + 'avg_time_per_episode': n_episode / duration if duration > 0 else 0, 'reward_mean': np.mean(episode_return), 'reward_std': np.std(episode_return), 'reward_max': np.max(episode_return), - 'reward_min': np.min(episode_return) - # 'each_reward': episode_return, + 'reward_min': np.min(episode_return), } episode_info = eval_monitor.get_episode_info() if episode_info is not None: info.update(episode_info) + + logging.info(f'rank {self._rank}, self.task_id: {self.task_id}') self._logger.info(self._logger.get_tabulate_vars_hor(info)) + + # Log to TensorBoard and WandB. 
for k, v in info.items(): - if k in ['train_iter', 'ckpt_name', 'each_reward']: + if k in ['train_iter', 'ckpt_name', 'each_reward'] or not np.isscalar(v): continue - if not np.isscalar(v): - continue - self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter) - self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, envstep) + if self.task_id is None: + self._tb_logger.add_scalar(f'{self._instance_name}_iter/{k}', v, train_iter) + self._tb_logger.add_scalar(f'{self._instance_name}_step/{k}', v, envstep) + else: + self._tb_logger.add_scalar(f'{self._instance_name}_iter_task{self.task_id}/{k}', v, train_iter) + self._tb_logger.add_scalar(f'{self._instance_name}_step_task{self.task_id}/{k}', v, envstep) if self.policy_config.use_wandb: - wandb.log({'{}_step/'.format(self._instance_name) + k: v}, step=envstep) + wandb.log({f'{self._instance_name}_step/{k}': v}, step=envstep) - episode_return = np.mean(episode_return) - if episode_return > self._max_episode_return: + # Check for new best performance and save checkpoint. + mean_episode_return = np.mean(episode_return) + if mean_episode_return > self._max_episode_return: if save_ckpt_fn: save_ckpt_fn('ckpt_best.pth.tar') - self._max_episode_return = episode_return - stop_flag = episode_return >= self._stop_value and train_iter > 0 + self._max_episode_return = mean_episode_return + + # Check if the stop condition is met. + stop_flag = mean_episode_return >= self._stop_value and train_iter > 0 if stop_flag: self._logger.info( - "[LightZero serial pipeline] " + - "Current episode_return: {} is greater than stop_value: {}".format(episode_return, - self._stop_value) + - ", so your MCTS/RL agent is converged, you can refer to 'log/evaluator/evaluator_logger.txt' for details." + f"[LightZero serial pipeline] Current episode_return: {mean_episode_return} is greater than " + f"stop_value: {self._stop_value}. The agent is considered converged." ) - if get_world_size() > 1: - objects = [stop_flag, episode_info] - broadcast_object_list(objects, src=0) - stop_flag, episode_info = objects + # NOTE: Only for usual DDP not for unizero_multitask pipeline. + # Finalize DDP synchronization for evaluation results. + # if get_world_size() > 1: + # objects = [stop_flag, episode_info] + # print(f'rank {self._rank}, self.task_id: {self.task_id}') + # print('before broadcast_object_list') + # broadcast_object_list(objects, src=0) + # print('evaluator after broadcast_object_list') + # stop_flag, episode_info = objects episode_info = to_item(episode_info) if return_trajectory: diff --git a/lzero/worker/muzero_segment_collector.py b/lzero/worker/muzero_segment_collector.py index 46cc016bc..807829ced 100644 --- a/lzero/worker/muzero_segment_collector.py +++ b/lzero/worker/muzero_segment_collector.py @@ -1,7 +1,7 @@ import logging import time from collections import deque, namedtuple -from typing import Optional, Any, List +from typing import Optional, Any, List, Dict import numpy as np import torch @@ -20,21 +20,20 @@ class MuZeroSegmentCollector(ISerialCollector): """ Overview: - MuZeroSegmentCollector is a data collector for MCTS+RL algorithms, including MuZero, EfficientZero, Sampled EfficientZero, and Gumbel MuZero. - It manages the data collection process for training these algorithms using a serial mechanism. - - The main difference from MuZeroCollector is that MuZeroSegmentCollector returns after collecting a specified number of segments, - whereas MuZeroCollector returns after collecting a complete game. 
This provides more extensibility and flexibility in data collection. + MuZeroSegmentCollector is a data collector for MCTS+RL algorithms, including MuZero, EfficientZero, + Sampled EfficientZero, and Gumbel MuZero. It manages the data collection process for training these + algorithms using a serial mechanism. + The main difference from MuZeroCollector is that MuZeroSegmentCollector returns after collecting a + specified number of segments, whereas MuZeroCollector returns after collecting a complete game. + This provides more extensibility and flexibility in data collection. Interfaces: - ``__init__``, ``reset``, ``reset_env``, ``reset_policy``, ``_reset_stat``, ``envstep``, ``__del__``, ``_compute_priorities``, - ``pad_and_save_last_trajectory``, ``collect``, ``_output_log``, ``close`` - + ``__init__``, ``reset``, ``reset_env``, ``reset_policy``, ``_reset_stat``, ``collect``, ``close``, ``__del__`` Properties: - ``envstep``: Counter for the current number of environment steps. + - envstep (:obj:`int`): The total number of environment steps collected. """ - # To be compatible with ISerialCollector + # To be compatible with ISerialCollector. config = dict() def __init__( @@ -46,19 +45,22 @@ def __init__( exp_name: Optional[str] = 'default_experiment', instance_name: Optional[str] = 'collector', policy_config: 'policy_config' = None, # noqa + task_id: int = None, ) -> None: """ Overview: - Initialize the MuZeroSegmentCollector with the given parameters. + Initializes the MuZeroSegmentCollector. Arguments: - - collect_print_freq (:obj:`int`): Frequency (in training steps) at which to print collection information. - - env (:obj:`Optional[BaseEnvManager]`): Instance of the subclass of vectorized environment manager. - - policy (:obj:`Optional[namedtuple]`): Namedtuple of the collection mode policy API. - - tb_logger (:obj:`Optional[SummaryWriter]`): TensorBoard logger instance. - - exp_name (:obj:`str`): Name of the experiment, used for logging and saving purposes. - - instance_name (:obj:`str`): Unique identifier for this collector instance. - - policy_config (:obj:`Optional[policy_config]`): Configuration object for the policy. + - collect_print_freq (:obj:`int`): The frequency (in training steps) at which to print collection information. + - env (:obj:`Optional[BaseEnvManager]`): An instance of a vectorized environment manager. + - policy (:obj:`Optional[namedtuple]`): A namedtuple containing the collect mode policy API. + - tb_logger (:obj:`Optional[SummaryWriter]`): A TensorBoard logger instance. + - exp_name (:obj:`str`): The name of the experiment, used for logging and saving. + - instance_name (:obj:`str`): A unique identifier for this collector instance. + - policy_config (:obj:`Optional[policy_config]`): The configuration object for the policy. + - task_id (:obj:`int`): The ID of the task, used in multi-task learning settings. 
""" + self.task_id = task_id self._exp_name = exp_name self._instance_name = instance_name self._collect_print_freq = collect_print_freq @@ -67,23 +69,23 @@ def __init__( self._rank = get_rank() self._world_size = get_world_size() + if self._rank == 0: if tb_logger is not None: self._logger, _ = build_logger( - path='./{}/log/{}'.format(self._exp_name, self._instance_name), - name=self._instance_name, - need_tb=False + path=f'./{self._exp_name}/log/{self._instance_name}', name=self._instance_name, need_tb=False ) self._tb_logger = tb_logger else: self._logger, self._tb_logger = build_logger( - path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name + path=f'./{self._exp_name}/log/{self._instance_name}', name=self._instance_name ) else: self._logger, _ = build_logger( - path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name, need_tb=False + path=f'./{self._exp_name}/log/{self._instance_name}', name=self._instance_name, need_tb=False ) - self._tb_logger = None + # TODO(author): This part is for UniZero multi-task DDP v2 compatibility. + self._tb_logger = tb_logger self.policy_config = policy_config self.collect_with_pure_policy = self.policy_config.collect_with_pure_policy @@ -93,12 +95,11 @@ def __init__( def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None: """ Overview: - Reset or replace the environment managed by this collector. - If _env is None, reset the old environment. - If _env is not None, replace the old environment in the collector with the new passed \ - in environment and launch. + Resets or replaces the environment managed by the collector. + If `_env` is None, it resets the existing environment. Otherwise, it replaces the old + environment with the new one and launches it. Arguments: - - env (:obj:`Optional[BaseEnvManager]`): New environment to manage, if provided. + - _env (:obj:`Optional[BaseEnvManager]`): The new environment to be used. Defaults to None. """ if _env is not None: self._env = _env @@ -110,35 +111,28 @@ def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None: def reset_policy(self, _policy: Optional[namedtuple] = None) -> None: """ Overview: - Reset or replace the policy used by this collector. - If _policy is None, reset the old policy. - If _policy is not None, replace the old policy in the collector with the new passed in policy. + Resets or replaces the policy used by the collector. + If `_policy` is None, it resets the existing policy. Otherwise, it replaces the old + policy with the new one. Arguments: - - policy (:obj:`Optional[namedtuple]`): the api namedtuple of collect_mode policy + - _policy (:obj:`Optional[namedtuple]`): The new policy's API in a namedtuple format. Defaults to None. """ - assert hasattr(self, '_env'), "please set env first" + assert hasattr(self, '_env'), "Please set env before resetting policy." 
if _policy is not None: self._policy = _policy - - self._default_num_segments = _policy.get_attribute('cfg').get('num_segments', None) + self._default_num_segments = self._policy.get_attribute('cfg').get('num_segments', None) self._logger.debug( - 'Set default num_segments mode(num_segments({}), env_num({}))'.format(self._default_num_segments, self._env_num) + f'Set default num_segments mode(num_segments({self._default_num_segments}), env_num({self._env_num}))' ) - self._policy.reset() + self._policy.reset(task_id=self.task_id) def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvManager] = None) -> None: """ Overview: - Reset the collector with the given policy and/or environment. - If _env is None, reset the old environment. - If _env is not None, replace the old environment in the collector with the new passed \ - in environment and launch. - If _policy is None, reset the old policy. - If _policy is not None, replace the old policy in the collector with the new passed in policy. + Resets the collector state, including the environment and policy. Arguments: - - policy (:obj:`Optional[namedtuple]`): the api namedtuple of collect_mode policy - - env (:obj:`Optional[BaseEnvManager]`): instance of the subclass of vectorized \ - env_manager(BaseEnvManager) + - _policy (:obj:`Optional[namedtuple]`): The new policy to use. Defaults to None. + - _env (:obj:`Optional[BaseEnvManager]`): The new environment to use. Defaults to None. """ if _env is not None: self.reset_env(_env) @@ -147,13 +141,12 @@ def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvMana self._env_info = {env_id: {'time': 0., 'step': 0} for env_id in range(self._env_num)} - # Initialize action_mask_dict, to_play_dict, and chance_dict here to ensure they contain values for all env_id + # Initialize dictionaries to store environment-specific states. self.action_mask_dict = {i: None for i in range(self._env_num)} self.to_play_dict = {i: None for i in range(self._env_num)} + self.timestep_dict = {i: None for i in range(self._env_num)} if self.policy_config.use_ture_chance_label_in_chance_encoder: self.chance_dict = {i: None for i in range(self._env_num)} - - self.timestep_dict = {i: None for i in range(self._env_num)} self.dones = np.array([False for _ in range(self._env_num)]) self.last_game_segments = [None for _ in range(self._env_num)] @@ -166,18 +159,16 @@ def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvMana self._last_train_iter = 0 self._end_flag = False - # A game_segment_pool implementation based on the deque structure. + # A deque-based pool for storing game segments. self.game_segment_pool = deque(maxlen=int(1e6)) self.unroll_plus_td_steps = self.policy_config.num_unroll_steps + self.policy_config.td_steps def _reset_stat(self, env_id: int) -> None: """ Overview: - Reset the collector's state. Including reset the traj_buffer, obs_pool, policy_output_pool \ - and env_info. Reset these states according to env_id. You can refer to base_serial_collector\ - to get more messages. + Resets the statistics for a specific environment. Arguments: - - env_id (:obj:`int`): the id where we need to reset the collector's state + - env_id (:obj:`int`): The ID of the environment to reset. """ self._env_info[env_id] = {'time': 0., 'step': 0} @@ -185,17 +176,16 @@ def _reset_stat(self, env_id: int) -> None: def envstep(self) -> int: """ Overview: - Get the total number of environment steps collected. + Returns the total number of environment steps collected. 
Returns: - - envstep (:obj:`int`): Total number of environment steps collected. + - envstep (:obj:`int`): The total count of environment steps. """ return self._total_envstep_count def close(self) -> None: """ Overview: - Close the collector. If end_flag is False, close the environment, flush the tb_logger \ - and close the tb_logger. + Closes the collector, including the environment and the TensorBoard logger. """ if self._end_flag: return @@ -208,79 +198,63 @@ def close(self) -> None: def __del__(self) -> None: """ Overview: - Execute the close command and close the collector. __del__ is automatically called to \ - destroy the collector instance when the collector finishes its work + Ensures that the `close` method is called when the collector instance is deleted. """ self.close() - # ============================================================== - # MCTS+RL related core code - # ============================================================== - def _compute_priorities(self, i: int, pred_values_lst: List[float], search_values_lst: List[float]) -> np.ndarray: + def _compute_priorities(self, i: int, pred_values_lst: List[float], search_values_lst: List[float]) -> Optional[np.ndarray]: """ Overview: - Compute the priorities for transitions based on prediction and search value discrepancies. + Computes priorities for transitions based on the discrepancy between predicted and search values. Arguments: - - i (:obj:`int`): Index of the values in the list to compute the priority for. - - pred_values_lst (:obj:`List[float]`): List of predicted values. - - search_values_lst (:obj:`List[float]`): List of search values obtained from MCTS. + - i (:obj:`int`): The index of the values list to process. + - pred_values_lst (:obj:`List[float]`): A list containing lists of predicted values. + - search_values_lst (:obj:`List[float]`): A list containing lists of search values from MCTS. Returns: - - priorities (:obj:`np.ndarray`): Array of computed priorities. + - priorities (:obj:`Optional[np.ndarray]`): An array of computed priorities, or None if priority is disabled. """ if self.policy_config.use_priority: - # Calculate priorities. The priorities are the L1 losses between the predicted - # values and the search values. We use 'none' as the reduction parameter, which - # means the loss is calculated for each element individually, instead of being summed or averaged. - # A small constant (1e-6) is added to the results to avoid zero priorities. This - # is done because zero priorities could potentially cause issues in some scenarios. + # Calculate priorities as the L1 loss between predicted and search values. + # The reduction is 'none' to get per-element losses. + # A small epsilon (1e-6) is added to prevent zero priorities. pred_values = torch.from_numpy(np.array(pred_values_lst[i])).to(self.policy_config.device).float().view(-1) - search_values = torch.from_numpy(np.array(search_values_lst[i])).to(self.policy_config.device - ).float().view(-1) - priorities = L1Loss(reduction='none' - )(pred_values, - search_values).detach().cpu().numpy() + 1e-6 + search_values = torch.from_numpy(np.array(search_values_lst[i])).to(self.policy_config.device).float().view(-1) + priorities = L1Loss(reduction='none')(pred_values, search_values).detach().cpu().numpy() + 1e-6 else: - # priorities is None -> use the max priority for all newly collected data + # If not using priority, all new data will use the maximum priority in the replay buffer. 
priorities = None return priorities - def pad_and_save_last_trajectory(self, i: int, last_game_segments: List[GameSegment], - last_game_priorities: List[np.ndarray], - game_segments: List[GameSegment], done: np.ndarray) -> None: + def pad_and_save_last_trajectory( + self, i: int, last_game_segments: List[GameSegment], last_game_priorities: List[np.ndarray], + game_segments: List[GameSegment], done: np.ndarray + ) -> None: """ Overview: - Save the game segment to the pool if the current game is finished, padding it if necessary. + Pads the last game segment with data from the current segment and saves it to the pool. + This is done when a game ends or a segment becomes full. Arguments: - - i (:obj:`int`): Index of the current game segment. - - last_game_segments (:obj:`List[GameSegment]`): List of the last game segments to be padded and saved. - - last_game_priorities (:obj:`List[np.ndarray]`): List of priorities of the last game segments. - - game_segments (:obj:`List[GameSegment]`): List of the current game segments. - - done (:obj:`np.ndarray`): Array indicating whether each game is done. - Note: - (last_game_segments[i].obs_segment[-4:][j] == game_segments[i].obs_segment[:4][j]).all() is True + - i (:obj:`int`): The index of the current game segment (and environment). + - last_game_segments (:obj:`List[GameSegment]`): The list of previous game segments to be padded. + - last_game_priorities (:obj:`List[np.ndarray]`): The list of priorities for the previous game segments. + - game_segments (:obj:`List[GameSegment]`): The list of current game segments, used for padding data. + - done (:obj:`np.ndarray`): An array indicating whether each game has terminated. """ - # pad over last segment trajectory + # Pad the trajectory of the last segment. beg_index = self.policy_config.model.frame_stack_num end_index = beg_index + self.policy_config.num_unroll_steps + self.policy_config.td_steps - # the start obs is init zero obs, so we take the - # [ : +] obs as the pad obs - # e.g. the start 4 obs is init zero obs, the num_unroll_steps is 5, so we take the [4:9] obs as the pad obs + # The initial observations are zero-padded, so we take observations from + # [ : + ] for padding. pad_obs_lst = game_segments[i].obs_segment[beg_index:end_index] - # NOTE: for unizero + # NOTE: Specific padding logic for UniZero. pad_action_lst = game_segments[i].action_segment[:self.policy_config.num_unroll_steps + self.policy_config.td_steps] - - # NOTE: for unizero pad_child_visits_lst = game_segments[i].child_visit_segment[:self.policy_config.num_unroll_steps + self.policy_config.td_steps] - # EfficientZero original repo bug: - # pad_child_visits_lst = game_segments[i].child_visit_segment[beg_index:end_index] - beg_index = 0 end_index = beg_index + self.unroll_plus_td_steps - 1 - pad_reward_lst = game_segments[i].reward_segment[beg_index:end_index] if self.policy_config.use_ture_chance_label_in_chance_encoder: @@ -288,101 +262,87 @@ def pad_and_save_last_trajectory(self, i: int, last_game_segments: List[GameSegm beg_index = 0 end_index = beg_index + self.unroll_plus_td_steps - pad_root_values_lst = game_segments[i].root_value_segment[beg_index:end_index] if self.policy_config.gumbel_algo: pad_improved_policy_prob = game_segments[i].improved_policy_probs[beg_index:end_index] - # pad over and save + # Pad and finalize the last game segment. 
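# --- Illustrative sketch (not part of the diff): the padding-index arithmetic described above,
# with assumed values frame_stack_num=4, num_unroll_steps=5, td_steps=5. The observations used
# to pad the previous segment are taken from the current segment's
# obs_segment[frame_stack_num : frame_stack_num + num_unroll_steps + td_steps];
# rewards use a slightly shorter window.
frame_stack_num, num_unroll_steps, td_steps = 4, 5, 5

obs_beg = frame_stack_num
obs_end = obs_beg + num_unroll_steps + td_steps
assert (obs_beg, obs_end) == (4, 14)          # pad_obs_lst = obs_segment[4:14]

unroll_plus_td_steps = num_unroll_steps + td_steps
reward_beg, reward_end = 0, unroll_plus_td_steps - 1
assert (reward_beg, reward_end) == (0, 9)     # pad_reward_lst = reward_segment[0:9]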
if self.policy_config.gumbel_algo: - last_game_segments[i].pad_over(pad_obs_lst, pad_reward_lst, pad_action_lst, pad_root_values_lst, pad_child_visits_lst, - next_segment_improved_policy=pad_improved_policy_prob) + last_game_segments[i].pad_over( + pad_obs_lst, pad_reward_lst, pad_action_lst, pad_root_values_lst, pad_child_visits_lst, + next_segment_improved_policy=pad_improved_policy_prob + ) else: if self.policy_config.use_ture_chance_label_in_chance_encoder: - last_game_segments[i].pad_over(pad_obs_lst, pad_reward_lst, pad_action_lst, pad_root_values_lst, pad_child_visits_lst, - next_chances=chance_lst) + last_game_segments[i].pad_over( + pad_obs_lst, pad_reward_lst, pad_action_lst, pad_root_values_lst, pad_child_visits_lst, + next_chances=chance_lst + ) else: - last_game_segments[i].pad_over(pad_obs_lst, pad_reward_lst, pad_action_lst, pad_root_values_lst, pad_child_visits_lst) - """ - Note: - game_segment element shape: - obs: game_segment_length + stack + num_unroll_steps, 20+4 +5 - rew: game_segment_length + stack + num_unroll_steps + td_steps -1 20 +5+3-1 - action: game_segment_length + num_unroll_steps + td_steps -> 20 +5+3 - root_values: game_segment_length + num_unroll_steps + td_steps -> 20 +5+3 - child_visits: game_segment_length + num_unroll_steps -> 20 +5 - to_play: game_segment_length -> 20 - action_mask: game_segment_length -> 20 - """ + last_game_segments[i].pad_over( + pad_obs_lst, pad_reward_lst, pad_action_lst, pad_root_values_lst, pad_child_visits_lst + ) last_game_segments[i].game_segment_to_array() - # put the game segment into the pool + # Add the completed game segment to the pool. self.game_segment_pool.append((last_game_segments[i], last_game_priorities[i], done[i])) - # reset last game_segments and last game_priorities for the next collection + # Reset placeholders for the next collection cycle. last_game_segments[i] = None last_game_priorities[i] = None - return None - - def collect(self, - num_segments: Optional[int] = None, - train_iter: int = 0, - policy_kwargs: Optional[dict] = None, - collect_with_pure_policy: bool = False) -> List[Any]: + def collect( + self, + num_segments: Optional[int] = None, + train_iter: int = 0, + policy_kwargs: Optional[dict] = None, + collect_with_pure_policy: bool = False + ) -> List[Any]: """ Overview: - Collect `num_segments` segments of data with policy_kwargs, trained for `train_iter` iterations. + Collects a specified number of game segments using the policy. Arguments: - - num_segments (:obj:`Optional[int]`): Number of segments to collect. - - train_iter (:obj:`int`): Number of training iterations completed so far. - - policy_kwargs (:obj:`Optional[dict]`): Additional keyword arguments for the policy. - - collect_with_pure_policy (:obj:`bool`): Whether to collect data using pure policy without MCTS. + - num_segments (:obj:`Optional[int]`): The number of segments to collect. If None, uses the default. + - train_iter (:obj:`int`): The current training iteration, used for logging. + - policy_kwargs (:obj:`Optional[dict]`): Additional arguments for the policy forward pass. + - collect_with_pure_policy (:obj:`bool`): If True, collects data using a pure policy without MCTS. Returns: - - return_data (:obj:`List[Any]`): Collected data in the form of a list. + - return_data (:obj:`List[Any]`): A list containing the collected game segments and their metadata. 
""" if num_segments is None: if self._default_num_segments is None: - raise RuntimeError("Please specify collect num_segments") + raise RuntimeError("Please specify num_segments for collection.") else: num_segments = self._default_num_segments - assert num_segments == self._env_num, "Please make sure num_segments == env_num{}/{}".format(num_segments, self._env_num) + assert num_segments == self._env_num, f"num_segments({num_segments}) must be equal to env_num({self._env_num})." if policy_kwargs is None: policy_kwargs = {} - temperature = policy_kwargs['temperature'] - epsilon = policy_kwargs['epsilon'] + temperature = policy_kwargs.get('temperature', 1.0) + epsilon = policy_kwargs.get('epsilon', 0.0) + # Initializations collected_episode = 0 collected_step = 0 env_nums = self._env_num - - # initializations init_obs = self._env.ready_obs + # Wait for all environments to be ready, especially in a subprocess setup. retry_waiting_time = 0.05 - while len(init_obs.keys()) != self._env_num: - # To be compatible with subprocess env_manager, in which sometimes self._env_num is not equal to - # len(self._env.ready_obs), especially in tictactoe env. - self._logger.info('The current init_obs.keys() is {}'.format(init_obs.keys())) - self._logger.info('Before sleeping, the _env_states is {}'.format(self._env._env_states)) + while len(init_obs.keys()) != env_nums: + self._logger.info(f'Waiting for all environments to reset. Ready envs: {list(init_obs.keys())}') time.sleep(retry_waiting_time) - self._logger.info('=' * 10 + 'Wait for all environments (subprocess) to finish resetting.' + '=' * 10) - self._logger.info( - 'After sleeping {}s, the current _env_states is {}'.format(retry_waiting_time, self._env._env_states) - ) init_obs = self._env.ready_obs for env_id in range(env_nums): - if env_id in init_obs.keys(): + if env_id in init_obs: self.action_mask_dict[env_id] = to_ndarray(init_obs[env_id]['action_mask']) self.to_play_dict[env_id] = to_ndarray(init_obs[env_id]['to_play']) - if 'timestep' not in init_obs[env_id]: - print(f"Warning: 'timestep' key is missing in init_obs[{env_id}], assigning value -1") self.timestep_dict[env_id] = to_ndarray(init_obs[env_id].get('timestep', -1)) - + if 'timestep' not in init_obs[env_id]: + self._logger.warning(f"'timestep' key missing in init_obs[{env_id}], assigning default -1.") if self.policy_config.use_ture_chance_label_in_chance_encoder: self.chance_dict[env_id] = to_ndarray(init_obs[env_id]['chance']) @@ -390,151 +350,95 @@ def collect(self, GameSegment( self._env.action_space, game_segment_length=self.policy_config.game_segment_length, - config=self.policy_config + config=self.policy_config, + task_id=self.task_id ) for _ in range(env_nums) ] - # stacked observation windows in reset stage for init game_segments - observation_window_stack = [[] for _ in range(env_nums)] - for env_id in range(env_nums): - observation_window_stack[env_id] = deque( - [to_ndarray(init_obs[env_id]['observation']) for _ in range(self.policy_config.model.frame_stack_num)], - maxlen=self.policy_config.model.frame_stack_num - ) + # Stacked observation windows for initializing game segments. 
+ observation_window_stack = [deque(maxlen=self.policy_config.model.frame_stack_num) for _ in range(env_nums)] + for env_id in range(env_nums): + initial_frames = [to_ndarray(init_obs[env_id]['observation']) for _ in range(self.policy_config.model.frame_stack_num)] + observation_window_stack[env_id].extend(initial_frames) game_segments[env_id].reset(observation_window_stack[env_id]) - # for priorities in self-play + # Lists for storing values for priority calculation. search_values_lst = [[] for _ in range(env_nums)] pred_values_lst = [[] for _ in range(env_nums)] if self.policy_config.gumbel_algo: improved_policy_lst = [[] for _ in range(env_nums)] - # some logs + # Logging variables. eps_steps_lst, visit_entropies_lst = np.zeros(env_nums), np.zeros(env_nums) if self.policy_config.gumbel_algo: completed_value_lst = np.zeros(env_nums) - self_play_moves = 0. - self_play_episodes = 0. - self_play_moves_max = 0 - self_play_visit_entropy = [] - total_transitions = 0 if collect_with_pure_policy: - temp_visit_list = [0.0 for i in range(self._env.action_space.n)] + temp_visit_list = [0.0 for _ in range(self._env.action_space.n)] while True: with self._timer: - # Get current ready env obs. + # Get observations from ready environments. obs = self._env.ready_obs ready_env_id = set(obs.keys()) if len(ready_env_id) < self._env_num: - logging.info(f'muzero_segment_collector: len(ready_env_id) < self._env_num, ready_env_id: {ready_env_id}, self._env_num: {self._env_num}') - - # TODO: For UniZero, during the init-infer process, it is necessary to retrieve the current kv_cache from the kv_cache_dict corresponding to each env_id. - # In theory, this requires waiting for all environments to be ready. However, in practice, - # waiting for all environments to be ready can have a significant negative impact on UniZero's performance, - # whereas the impact on MuZero is relatively small. + self._logger.debug(f'Only {len(ready_env_id)}/{self._env_num} envs are ready.') + + # TODO: For UniZero, waiting for all environments to be ready can negatively impact performance. + # This wait loop is currently commented out, but its impact should be considered. # while len(obs.keys()) != self._env_num: - # # To be compatible with subprocess env_manager, in which sometimes self._env_num is not equal to - # # len(self._env.ready_obs), especially in tictactoe env. - # self._logger.info('The current init_obs.keys() is {}'.format(obs.keys())) - # self._logger.info('Before sleeping, the _env_states is {}'.format(self._env._env_states)) # time.sleep(retry_waiting_time) - # self._logger.info('=' * 10 + 'Wait for all environments (subprocess) to finish resetting.' + '=' * 10) - # self._logger.info( - # 'After sleeping {}s, the current _env_states is {}'.format(retry_waiting_time, self._env._env_states) - # ) # obs = self._env.ready_obs # ready_env_id = set(obs.keys()) - stack_obs = {env_id: game_segments[env_id].get_obs() for env_id in ready_env_id} - stack_obs = list(stack_obs.values()) + # Prepare stacked observations for the policy network. 
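# --- Illustrative sketch (not part of the diff): the deque-based observation window built above.
# A deque with maxlen=frame_stack_num keeps only the most recent frames; appending a new
# observation silently evicts the oldest one. The (3, 3) frame shape is an assumption for the demo.
from collections import deque
import numpy as np

frame_stack_num = 4
window = deque(maxlen=frame_stack_num)
window.extend(np.zeros((3, 3)) for _ in range(frame_stack_num))  # initial (zero) frames
window.append(np.ones((3, 3)))                                   # newest frame evicts the oldest
stacked = np.stack(list(window))
assert stacked.shape == (frame_stack_num, 3, 3)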
+ stack_obs_dict = {env_id: game_segments[env_id].get_obs() for env_id in ready_env_id} + stack_obs_list = [stack_obs_dict[env_id] for env_id in sorted(list(ready_env_id))] self.action_mask_dict_tmp = {env_id: self.action_mask_dict[env_id] for env_id in ready_env_id} self.to_play_dict_tmp = {env_id: self.to_play_dict[env_id] for env_id in ready_env_id} self.timestep_dict_tmp = {env_id: self.timestep_dict[env_id] for env_id in ready_env_id} - - action_mask = [self.action_mask_dict_tmp[env_id] for env_id in ready_env_id] - to_play = [self.to_play_dict_tmp[env_id] for env_id in ready_env_id] - timestep = [self.timestep_dict_tmp[env_id] for env_id in ready_env_id] + + action_mask = [self.action_mask_dict_tmp[env_id] for env_id in sorted(list(ready_env_id))] + to_play = [self.to_play_dict_tmp[env_id] for env_id in sorted(list(ready_env_id))] + timestep = [self.timestep_dict_tmp[env_id] for env_id in sorted(list(ready_env_id))] if self.policy_config.use_ture_chance_label_in_chance_encoder: self.chance_dict_tmp = {env_id: self.chance_dict[env_id] for env_id in ready_env_id} - stack_obs = to_ndarray(stack_obs) - # return stack_obs shape: [B, S*C, W, H] e.g. [8, 4*1, 96, 96] - stack_obs = prepare_observation(stack_obs, self.policy_config.model.model_type) - stack_obs = torch.from_numpy(stack_obs).to(self.policy_config.device) + stack_obs_array = to_ndarray(stack_obs_list) + stack_obs_tensor = prepare_observation(stack_obs_array, self.policy_config.model.model_type) + stack_obs_tensor = torch.from_numpy(stack_obs_tensor).to(self.policy_config.device) # ============================================================== - # Key policy forward step + # Perform a forward pass with the policy. # ============================================================== - # logging.info(f'ready_env_id:{ready_env_id}') - # logging.info(f'timestep:{timestep}') - policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon, ready_env_id=ready_env_id, timestep=timestep) + policy_args = (stack_obs_tensor, action_mask, temperature, to_play, epsilon) + policy_kwargs_forward = {'ready_env_id': sorted(list(ready_env_id)), 'timestep': timestep} + if self.task_id is not None: + policy_kwargs_forward['task_id'] = self.task_id + + policy_output = self._policy.forward(*policy_args, **policy_kwargs_forward) - # Extract relevant policy outputs + # Extract policy outputs. 
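# --- Illustrative sketch (not part of the diff): why the refactor sorts ready_env_id before
# batching. A set has no guaranteed iteration order, so sorting gives one explicit order shared
# by the observation batch and every per-env metadata list. The toy dicts are assumptions.
ready_env_id = {3, 0, 2}
obs_by_env = {0: 'obs0', 2: 'obs2', 3: 'obs3'}
mask_by_env = {0: 'mask0', 2: 'mask2', 3: 'mask3'}

order = sorted(ready_env_id)                       # [0, 2, 3], deterministic
stack_obs_list = [obs_by_env[e] for e in order]    # row i of the batch belongs to order[i]
action_mask = [mask_by_env[e] for e in order]      # stays index-aligned with stack_obs_list
assert order == [0, 2, 3] and stack_obs_list == ['obs0', 'obs2', 'obs3']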
actions_with_env_id = {k: v['action'] for k, v in policy_output.items()} value_dict_with_env_id = {k: v['searched_value'] for k, v in policy_output.items()} pred_value_dict_with_env_id = {k: v['predicted_value'] for k, v in policy_output.items()} - timestep_dict_with_env_id = { - k: v['timestep'] if 'timestep' in v else -1 for k, v in policy_output.items() - } - - if self.policy_config.sampled_algo: - root_sampled_actions_dict_with_env_id = { - k: v['root_sampled_actions'] for k, v in policy_output.items() - } if not collect_with_pure_policy: - distributions_dict_with_env_id = {k: v['visit_count_distributions'] for k, v in - policy_output.items()} - visit_entropy_dict_with_env_id = {k: v['visit_count_distribution_entropy'] for k, v in - policy_output.items()} - - if self.policy_config.gumbel_algo: - improved_policy_dict_with_env_id = {k: v['improved_policy_probs'] for k, v in - policy_output.items()} - completed_value_with_env_id = {k: v['roots_completed_value'] for k, v in policy_output.items()} - - # Initialize dictionaries to store results - actions = {} - value_dict = {} - pred_value_dict = {} - timestep_dict = {} - - if not collect_with_pure_policy: - distributions_dict = {} - visit_entropy_dict = {} - + distributions_dict_with_env_id = {k: v['visit_count_distributions'] for k, v in policy_output.items()} + visit_entropy_dict_with_env_id = {k: v['visit_count_distribution_entropy'] for k, v in policy_output.items()} if self.policy_config.sampled_algo: - root_sampled_actions_dict = {} - + root_sampled_actions_dict_with_env_id = {k: v['root_sampled_actions'] for k, v in policy_output.items()} if self.policy_config.gumbel_algo: - improved_policy_dict = {} - completed_value_dict = {} - - # Populate the result dictionaries - for env_id in ready_env_id: - actions[env_id] = actions_with_env_id.pop(env_id) - value_dict[env_id] = value_dict_with_env_id.pop(env_id) - pred_value_dict[env_id] = pred_value_dict_with_env_id.pop(env_id) - timestep_dict[env_id] = timestep_dict_with_env_id.pop(env_id) - - if not collect_with_pure_policy: - distributions_dict[env_id] = distributions_dict_with_env_id.pop(env_id) - - if self.policy_config.sampled_algo: - root_sampled_actions_dict[env_id] = root_sampled_actions_dict_with_env_id.pop(env_id) + improved_policy_dict_with_env_id = {k: v['improved_policy_probs'] for k, v in policy_output.items()} + completed_value_with_env_id = {k: v['roots_completed_value'] for k, v in policy_output.items()} - visit_entropy_dict[env_id] = visit_entropy_dict_with_env_id.pop(env_id) - - if self.policy_config.gumbel_algo: - improved_policy_dict[env_id] = improved_policy_dict_with_env_id.pop(env_id) - completed_value_dict[env_id] = completed_value_with_env_id.pop(env_id) + # Populate the result dictionaries, mapping outputs to original env_ids. + actions: Dict[int, Any] = {env_id: actions_with_env_id.pop(env_id) for env_id in ready_env_id} # ============================================================== - # Interact with the environment + # Step the environments with the chosen actions. # ============================================================== timesteps = self._env.step(actions) @@ -542,108 +446,98 @@ def collect(self, for env_id, episode_timestep in timesteps.items(): with self._timer: + # Handle abnormal timesteps by resetting the environment and policy state. if episode_timestep.info.get('abnormal', False): - # If there is an abnormal episode_timestep, reset all the related variables(including this env). 
- # suppose there is no reset param, reset this env self._env.reset({env_id: None}) self._policy.reset([env_id]) self._reset_stat(env_id) - self._logger.info('Env{} returns a abnormal step, its info is {}'.format(env_id, episode_timestep.info)) + self._logger.info(f'Env {env_id} had an abnormal step, info: {episode_timestep.info}') continue + obs, reward, done, info = episode_timestep.obs, episode_timestep.reward, episode_timestep.done, episode_timestep.info + # Store search statistics in the game segment. if collect_with_pure_policy: game_segments[env_id].store_search_stats(temp_visit_list, 0) else: if self.policy_config.sampled_algo: game_segments[env_id].store_search_stats( - distributions_dict[env_id], value_dict[env_id], root_sampled_actions_dict[env_id] + distributions_dict_with_env_id[env_id], value_dict_with_env_id[env_id], root_sampled_actions_dict_with_env_id[env_id] ) elif self.policy_config.gumbel_algo: - game_segments[env_id].store_search_stats(distributions_dict[env_id], value_dict[env_id], - improved_policy=improved_policy_dict[env_id]) + game_segments[env_id].store_search_stats( + distributions_dict_with_env_id[env_id], value_dict_with_env_id[env_id], + improved_policy=improved_policy_dict_with_env_id[env_id] + ) else: - game_segments[env_id].store_search_stats(distributions_dict[env_id], value_dict[env_id]) + game_segments[env_id].store_search_stats(distributions_dict_with_env_id[env_id], value_dict_with_env_id[env_id]) - # append a transition tuple, including a_t, o_{t+1}, r_{t}, action_mask_{t}, to_play_{t} - # in ``game_segments[env_id].init``, we have appended o_{t} in ``self.obs_segment`` + # Append the new transition to the game segment. + append_kwargs = {'timestep': to_ndarray(obs.get('timestep', -1))} if self.policy_config.use_ture_chance_label_in_chance_encoder: - game_segments[env_id].append( - actions[env_id], to_ndarray(obs['observation']), reward, self.action_mask_dict_tmp[env_id], - self.to_play_dict_tmp[env_id], timestep=to_ndarray(obs['timestep']), chance=self.chance_dict_tmp[env_id] - ) - else: - game_segments[env_id].append( - actions[env_id], to_ndarray(obs['observation']), reward, self.action_mask_dict_tmp[env_id], - self.to_play_dict_tmp[env_id], timestep=to_ndarray(obs['timestep']) - ) - - # NOTE: the position of code snippet is very important. - # the obs['action_mask'] and obs['to_play'] are corresponding to the next action - self.action_mask_dict_tmp[env_id] = to_ndarray(obs['action_mask']) - self.to_play_dict_tmp[env_id] = to_ndarray(obs['to_play']) - # self.timestep_dict_tmp[env_id] = to_ndarray(obs['timestep']) - self.timestep_dict_tmp[env_id] = to_ndarray(obs.get('timestep', -1)) - - + append_kwargs['chance'] = self.chance_dict_tmp[env_id] + + game_segments[env_id].append( + actions[env_id], to_ndarray(obs['observation']), reward, + self.action_mask_dict_tmp[env_id], self.to_play_dict_tmp[env_id], **append_kwargs + ) + + # NOTE: This position is crucial. The action_mask and to_play from the new observation correspond to the *next* state. 
+ self.action_mask_dict[env_id] = to_ndarray(obs['action_mask']) + self.to_play_dict[env_id] = to_ndarray(obs['to_play']) + self.timestep_dict[env_id] = to_ndarray(obs.get('timestep', -1)) if self.policy_config.use_ture_chance_label_in_chance_encoder: - self.chance_dict_tmp[env_id] = to_ndarray(obs['chance']) + self.chance_dict[env_id] = to_ndarray(obs['chance']) - if self.policy_config.ignore_done: - self.dones[env_id] = False - else: - self.dones[env_id] = done + self.dones[env_id] = False if self.policy_config.ignore_done else done if not collect_with_pure_policy: - visit_entropies_lst[env_id] += visit_entropy_dict[env_id] + visit_entropies_lst[env_id] += visit_entropy_dict_with_env_id[env_id] if self.policy_config.gumbel_algo: - completed_value_lst[env_id] += np.mean(np.array(completed_value_dict[env_id])) + completed_value_lst[env_id] += np.mean(np.array(completed_value_with_env_id[env_id])) eps_steps_lst[env_id] += 1 + + # NOTE: Specific reset logic for UniZero. if self._policy.get_attribute('cfg').type in ['unizero', 'sampled_unizero']: - # ============ only for UniZero now ============ self._policy.reset(env_id=env_id, current_steps=eps_steps_lst[env_id], reset_init_data=False) - total_transitions += 1 - if self.policy_config.use_priority: - pred_values_lst[env_id].append(pred_value_dict[env_id]) - search_values_lst[env_id].append(value_dict[env_id]) + pred_values_lst[env_id].append(pred_value_dict_with_env_id[env_id]) + search_values_lst[env_id].append(value_dict_with_env_id[env_id]) if self.policy_config.gumbel_algo and not collect_with_pure_policy: - improved_policy_lst[env_id].append(improved_policy_dict[env_id]) + improved_policy_lst[env_id].append(improved_policy_dict_with_env_id[env_id]) - # append the newest obs + # Append the newest observation to the observation window. observation_window_stack[env_id].append(to_ndarray(obs['observation'])) # ============================================================== - # we will save a game segment if it is the end of the game or the next game segment is finished. + # Save a game segment if it is full or the game has ended. # ============================================================== - - # if game segment is full, we will save the last game segment if game_segments[env_id].is_full(): - # pad over last segment trajectory + # If there's a previous segment, pad and save it. if self.last_game_segments[env_id] is not None: - # TODO(pu): return the one game segment + # TODO(pu): This logic pads and saves one game segment at a time. self.pad_and_save_last_trajectory( env_id, self.last_game_segments, self.last_game_priorities, game_segments, self.dones ) - # calculate priority + # Calculate priorities for the collected transitions. priorities = self._compute_priorities(env_id, pred_values_lst, search_values_lst) - pred_values_lst[env_id] = [] - search_values_lst[env_id] = [] + pred_values_lst[env_id], search_values_lst[env_id] = [], [] if self.policy_config.gumbel_algo and not collect_with_pure_policy: improved_policy_lst[env_id] = [] - # the current game_segments become last_game_segment + # The current segment now becomes the 'last' segment for the next padding operation. self.last_game_segments[env_id] = game_segments[env_id] self.last_game_priorities[env_id] = priorities - # create new GameSegment + # Create a new game segment to continue collection. 
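# --- Illustrative sketch (not part of the diff): the segment hand-off above in miniature.
# When a segment fills up, the previous segment is padded with the head of the current one,
# pushed into a deque pool, and the current segment becomes the new "last" segment. Plain
# lists stand in for GameSegment; segment length 20 and a 10-step pad are assumptions.
from collections import deque

game_segment_pool = deque(maxlen=int(1e6))
last_segment, current_segment = list(range(20)), list(range(20, 40))

num_unroll_steps, td_steps = 5, 5
pad = current_segment[:num_unroll_steps + td_steps]   # head of the new segment pads the old one
game_segment_pool.append(last_segment + pad)          # finalized segment enters the pool
last_segment = current_segment                        # hand-off for the next fill-up
assert len(game_segment_pool[0]) == 30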
game_segments[env_id] = GameSegment( self._env.action_space, game_segment_length=self.policy_config.game_segment_length, - config=self.policy_config + config=self.policy_config, + task_id=self.task_id ) game_segments[env_id].reset(observation_window_stack[env_id]) @@ -652,94 +546,84 @@ def collect(self, self._env_info[env_id]['time'] += self._timer.value + interaction_duration if episode_timestep.done: - logging.info(f'========env {env_id} done!========') + self._logger.info(f'======== Environment {env_id} episode finished! ========') self._total_episode_count += 1 - reward = episode_timestep.info['eval_episode_return'] info = { - 'reward': reward, + 'reward': episode_timestep.info['eval_episode_return'], 'time': self._env_info[env_id]['time'], 'step': self._env_info[env_id]['step'], } if not collect_with_pure_policy: - info['visit_entropy'] = visit_entropies_lst[env_id] / eps_steps_lst[env_id] + info['visit_entropy'] = visit_entropies_lst[env_id] / eps_steps_lst[env_id] if eps_steps_lst[env_id] > 0 else 0 if self.policy_config.gumbel_algo: - info['completed_value'] = completed_value_lst[env_id] / eps_steps_lst[env_id] - + info['completed_value'] = completed_value_lst[env_id] / eps_steps_lst[env_id] if eps_steps_lst[env_id] > 0 else 0 collected_episode += 1 self._episode_info.append(info) # ============================================================== - # if it is the end of the game, we will save the game segment + # At the end of a game, save all remaining game segments. # ============================================================== - - # NOTE: put the penultimate game segment in one episode into the trajectory_pool - # pad over 2th last game_segment using the last game_segment + # NOTE: Store the second-to-last game segment of the episode. if self.last_game_segments[env_id] is not None: self.pad_and_save_last_trajectory( env_id, self.last_game_segments, self.last_game_priorities, game_segments, self.dones ) - # store current segment trajectory + # Calculate priorities for the final segment. priorities = self._compute_priorities(env_id, pred_values_lst, search_values_lst) - # NOTE: put the last game segment in one episode into the trajectory_pool + # NOTE: Store the final game segment of the episode. game_segments[env_id].game_segment_to_array() - - # assert len(game_segments[env_id]) == len(priorities) - # NOTE: save the last game segment in one episode into the trajectory_pool if it's not null - if len(game_segments[env_id].reward_segment) != 0: + if len(game_segments[env_id].reward_segment) > 0: self.game_segment_pool.append((game_segments[env_id], priorities, self.dones[env_id])) - # log - self_play_moves_max = max(self_play_moves_max, eps_steps_lst[env_id]) - if not collect_with_pure_policy: - self_play_visit_entropy.append(visit_entropies_lst[env_id] / eps_steps_lst[env_id]) - self_play_moves += eps_steps_lst[env_id] - self_play_episodes += 1 - - pred_values_lst[env_id] = [] - search_values_lst[env_id] = [] - eps_steps_lst[env_id] = 0 - visit_entropies_lst[env_id] = 0 - - # Env reset is done by env_manager automatically - # NOTE: ============ reset the policy for the env_id. Default reset_init_data=True. ================ - self._policy.reset([env_id]) + # Reset lists and stats for the new episode. + pred_values_lst[env_id], search_values_lst[env_id] = [], [] + eps_steps_lst[env_id], visit_entropies_lst[env_id] = 0, 0 + + # Environment reset is handled by the env_manager automatically. + # NOTE: Reset the policy state for the completed environment. 
+ self._policy.reset([env_id], task_id=self.task_id) self._reset_stat(env_id) - ready_env_id.remove(env_id) - # ===== NOTE: if one episode done and not return, we should init its game_segments[env_id] ======= - # create new GameSegment - game_segments[env_id] = GameSegment( - self._env.action_space, - game_segment_length=self.policy_config.game_segment_length, - config=self.policy_config - ) + # NOTE: If an episode finishes but collection continues, re-initialize its game segment. + game_segments[env_id] = GameSegment( + self._env.action_space, + game_segment_length=self.policy_config.game_segment_length, + config=self.policy_config, + task_id=self.task_id + ) game_segments[env_id].reset(observation_window_stack[env_id]) - - # NOTE: must after the for loop to make sure all env_id's data are collected + # Check if the required number of segments has been collected. if len(self.game_segment_pool) >= self._default_num_segments: - logging.info(f'env {env_id} collected {len(self.game_segment_pool)} segments now!') - - # [data, meta_data] - return_data = [self.game_segment_pool[i][0] for i in range(len(self.game_segment_pool))], [ - { - 'priorities': self.game_segment_pool[i][1], - 'done': self.game_segment_pool[i][2], - 'unroll_plus_td_steps': self.unroll_plus_td_steps - } for i in range(len(self.game_segment_pool)) + self._logger.info(f'Collected {len(self.game_segment_pool)} segments, reaching the target of {self._default_num_segments}.') + + # Format data for returning: [game_segments, metadata_list] + return_data = [ + [self.game_segment_pool[i][0] for i in range(len(self.game_segment_pool))], + [ + { + 'priorities': self.game_segment_pool[i][1], + 'done': self.game_segment_pool[i][2], + 'unroll_plus_td_steps': self.unroll_plus_td_steps + } for i in range(len(self.game_segment_pool)) + ] ] self.game_segment_pool.clear() break + collected_duration = sum([d['time'] for d in self._episode_info]) + + # NOTE: Only for usual DDP not for unizero_multitask pipeline. # reduce data when enables DDP - if self._world_size > 1: - collected_step = allreduce_data(collected_step, 'sum') - collected_episode = allreduce_data(collected_episode, 'sum') - collected_duration = allreduce_data(collected_duration, 'sum') + # if self._world_size > 1: + # collected_step = allreduce_data(collected_step, 'sum') + # collected_episode = allreduce_data(collected_episode, 'sum') + # collected_duration = allreduce_data(collected_duration, 'sum') + self._total_envstep_count += collected_step self._total_episode_count += collected_episode self._total_duration += collected_duration @@ -751,31 +635,28 @@ def collect(self, def _output_log(self, train_iter: int) -> None: """ Overview: - Log the collector's data and output the log information. + Logs collection statistics to the console and TensorBoard. Arguments: - - train_iter (:obj:`int`): Current training iteration number for logging context. + - train_iter (:obj:`int`): The current training iteration for logging context. 
""" - if self._rank != 0: - return if (train_iter - self._last_train_iter) >= self._collect_print_freq and len(self._episode_info) > 0: self._last_train_iter = train_iter episode_count = len(self._episode_info) envstep_count = sum([d['step'] for d in self._episode_info]) duration = sum([d['time'] for d in self._episode_info]) episode_reward = [d['reward'] for d in self._episode_info] + if not self.collect_with_pure_policy: - visit_entropy = [d['visit_entropy'] for d in self._episode_info] + visit_entropy = [d.get('visit_entropy', 0.0) for d in self._episode_info] else: visit_entropy = [0.0] - if self.policy_config.gumbel_algo: - completed_value = [d['completed_value'] for d in self._episode_info] - self._total_duration += duration + info = { 'episode_count': episode_count, 'envstep_count': envstep_count, 'avg_envstep_per_episode': envstep_count / episode_count, - 'avg_envstep_per_sec': envstep_count / duration, - 'avg_episode_per_sec': episode_count / duration, + 'avg_envstep_per_sec': envstep_count / duration if duration > 0 else 0, + 'avg_episode_per_sec': episode_count / duration if duration > 0 else 0, 'collect_time': duration, 'reward_mean': np.mean(episode_reward), 'reward_std': np.std(episode_reward), @@ -784,16 +665,25 @@ def _output_log(self, train_iter: int) -> None: 'total_envstep_count': self._total_envstep_count, 'total_episode_count': self._total_episode_count, 'total_duration': self._total_duration, - 'visit_entropy': np.mean(visit_entropy), + 'visit_entropy_mean': np.mean(visit_entropy), } if self.policy_config.gumbel_algo: - info['completed_value'] = np.mean(completed_value) + completed_value = [d.get('completed_value', 0.0) for d in self._episode_info] + info['completed_value_mean'] = np.mean(completed_value) + self._episode_info.clear() - self._logger.info("collect end:\n{}".format('\n'.join(['{}: {}'.format(k, v) for k, v in info.items()]))) + + self._logger.info(f"Collector log (rank {self._rank}, task_id {self.task_id}):\n" + '\n'.join([f'{k}: {v}' for k, v in info.items()])) for k, v in info.items(): if k in ['each_reward']: continue - self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter) - if k in ['total_envstep_count']: - continue - self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, self._total_envstep_count) \ No newline at end of file + if self.task_id is None: + # Log for single-task setting + self._tb_logger.add_scalar(f'{self._instance_name}_iter/{k}', v, train_iter) + if k not in ['total_envstep_count', 'total_episode_count', 'total_duration']: + self._tb_logger.add_scalar(f'{self._instance_name}_step/{k}', v, self._total_envstep_count) + else: + # Log for multi-task setting + self._tb_logger.add_scalar(f'{self._instance_name}_iter_task{self.task_id}/{k}', v, train_iter) + if k not in ['total_envstep_count', 'total_episode_count', 'total_duration']: + self._tb_logger.add_scalar(f'{self._instance_name}_step_task{self.task_id}/{k}', v, self._total_envstep_count) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index faca7a062..c253988e5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ DI-engine>=0.5.3 -gymnasium[atari]==0.28.0 +gymnasium[atari]>=1.0.0 +ale-py>=0.10.1 numpy>=1.24.1,<2 transformers pympler @@ -12,4 +13,5 @@ einops openai nltk pyecharts -numba \ No newline at end of file +numba +simple_parsing # for Jericho env \ No newline at end of file diff --git a/zoo/atari/config/atari_efficientzero_ddp_config.py 
b/zoo/atari/config/atari_efficientzero_ddp_config.py index bf1c42a0e..75a33ac64 100644 --- a/zoo/atari/config/atari_efficientzero_ddp_config.py +++ b/zoo/atari/config/atari_efficientzero_ddp_config.py @@ -23,7 +23,7 @@ exp_name=f'data_ez/{env_id[:-14]}_efficientzero_ns{num_simulations}_upc{update_per_collect}_rer{reanalyze_ratio}_ddp_{gpu_num}gpu_seed0', env=dict( env_id=env_id, - obs_shape=(4, 96, 96), + obs_shape=(4, 64, 64), collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, n_evaluator_episode=evaluator_env_num, @@ -31,7 +31,7 @@ ), policy=dict( model=dict( - observation_shape=(4, 96, 96), + observation_shape=(4, 64, 64), frame_stack_num=4, action_space_size=action_space_size, downsample=True, diff --git a/zoo/atari/config/atari_env_action_space_map.py b/zoo/atari/config/atari_env_action_space_map.py index e2090586d..af3a54cbf 100644 --- a/zoo/atari/config/atari_env_action_space_map.py +++ b/zoo/atari/config/atari_env_action_space_map.py @@ -1,30 +1,34 @@ from easydict import EasyDict +# ale-py==0.10.1, gymnasium==1.2.1 atari_env_action_space_map = EasyDict({ - 'AlienNoFrameskip-v4': 18, - 'AmidarNoFrameskip-v4': 10, - 'AssaultNoFrameskip-v4': 7, - 'AsterixNoFrameskip-v4': 9, - 'BankHeistNoFrameskip-v4': 18, - 'BattleZoneNoFrameskip-v4': 18, - 'ChopperCommandNoFrameskip-v4': 18, - 'CrazyClimberNoFrameskip-v4': 9, - 'DemonAttackNoFrameskip-v4': 6, - 'FreewayNoFrameskip-v4': 3, - 'FrostbiteNoFrameskip-v4': 18, - 'GopherNoFrameskip-v4': 8, - 'HeroNoFrameskip-v4': 18, - 'JamesbondNoFrameskip-v4': 18, - 'KangarooNoFrameskip-v4': 18, - 'KrullNoFrameskip-v4': 18, - 'KungFuMasterNoFrameskip-v4': 14, - 'PrivateEyeNoFrameskip-v4': 18, - 'RoadRunnerNoFrameskip-v4': 18, - 'UpNDownNoFrameskip-v4': 6, - 'PongNoFrameskip-v4': 6, - 'MsPacmanNoFrameskip-v4': 9, - 'QbertNoFrameskip-v4': 6, - 'SeaquestNoFrameskip-v4': 18, - 'BoxingNoFrameskip-v4': 18, - 'BreakoutNoFrameskip-v4': 4, + 'ALE/Alien-v5': 18, + 'ALE/Amidar-v5': 10, + 'ALE/Assault-v5': 7, + 'ALE/Asterix-v5': 9, + 'ALE/BankHeist-v5': 18, + 'ALE/BattleZone-v5': 18, + 'ALE/ChopperCommand-v5': 18, + 'ALE/CrazyClimber-v5': 9, + 'ALE/DemonAttack-v5': 6, + 'ALE/Freeway-v5': 3, + 'ALE/Frostbite-v5': 18, + 'ALE/Gopher-v5': 8, + 'ALE/Hero-v5': 18, + 'ALE/Jamesbond-v5': 18, + 'ALE/Kangaroo-v5': 18, + 'ALE/Krull-v5': 18, + 'ALE/KungFuMaster-v5': 14, + 'ALE/PrivateEye-v5': 18, + 'ALE/RoadRunner-v5': 18, + 'ALE/UpNDown-v5': 6, + 'ALE/Pong-v5': 6, + 'ALE/MsPacman-v5': 9, + 'ALE/Qbert-v5': 6, + 'ALE/Seaquest-v5': 18, + 'ALE/Boxing-v5': 18, + 'ALE/Breakout-v5': 4, + 'ALE/SpaceInvaders-v5': 6, + 'ALE/BeamRider-v5': 9, + 'ALE/Gravitar-v5': 18, }) \ No newline at end of file diff --git a/zoo/atari/config/atari_gumbel_muzero_config.py b/zoo/atari/config/atari_gumbel_muzero_config.py index e2175b45f..92f858dc4 100644 --- a/zoo/atari/config/atari_gumbel_muzero_config.py +++ b/zoo/atari/config/atari_gumbel_muzero_config.py @@ -23,7 +23,7 @@ env=dict( stop_value=int(1e6), env_id=env_id, - obs_shape=(4, 96, 96), + obs_shape=(4, 64, 64), collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, n_evaluator_episode=evaluator_env_num, @@ -31,7 +31,7 @@ ), policy=dict( model=dict( - observation_shape=(4, 96, 96), + observation_shape=(4, 64, 64), frame_stack_num=4, action_space_size=action_space_size, downsample=True, diff --git a/zoo/atari/config/atari_muzero_ddp_config.py b/zoo/atari/config/atari_muzero_ddp_config.py index 2d19b6e8c..53394af05 100644 --- a/zoo/atari/config/atari_muzero_ddp_config.py +++ 
b/zoo/atari/config/atari_muzero_ddp_config.py @@ -27,7 +27,7 @@ env=dict( stop_value=int(1e6), env_id=env_id, - obs_shape=(4, 96, 96), + obs_shape=(4, 64, 64), collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, n_evaluator_episode=evaluator_env_num, @@ -36,7 +36,7 @@ policy=dict( model_path=None, model=dict( - observation_shape=(4, 96, 96), + observation_shape=(4, 64, 64), frame_stack_num=4, action_space_size=action_space_size, downsample=True, diff --git a/zoo/atari/config/atari_muzero_multitask_segment_ddp_config.py b/zoo/atari/config/atari_muzero_multitask_segment_ddp_config.py new file mode 100644 index 000000000..fe8fa3142 --- /dev/null +++ b/zoo/atari/config/atari_muzero_multitask_segment_ddp_config.py @@ -0,0 +1,330 @@ +""" +Overview: + Configuration generation script for multi-task MuZero training on Atari environments. + This script defines and generates the necessary configuration files for a distributed training setup. +""" +from easydict import EasyDict +from copy import deepcopy +from typing import List, Union, Dict, Any + +# The 'atari_env_action_space_map' was not used in the original code, so it has been removed. + +class AtariMuZeroMultitaskConfig: + """ + Overview: + A class to generate and manage configurations for multi-task MuZero experiments on Atari. + It encapsulates the entire configuration logic, providing a clean and extensible interface. + """ + + def __init__( + self, + env_id_list: List[str], + seed: int, + num_unroll_steps: int, + num_simulations: int, + collector_env_num: int, + evaluator_env_num: int, + max_env_step: int, + batch_size: Union[List[int], int], + norm_type: str, + buffer_reanalyze_freq: float, + reanalyze_batch_size: int, + reanalyze_partition: float, + num_segments: int, + exp_path_prefix: str = 'YOUR_EXPERIMENT_PATH_PREFIX/data_muzero_mt_atari', + ) -> None: + """ + Overview: + Initializes the multi-task configuration generator. + Arguments: + - env_id_list (:obj:`List[str]`): A list of Atari environment IDs to be trained on. + - seed (:obj:`int`): The random seed for the experiment. + - num_unroll_steps (:obj:`int`): The number of steps to unroll the model during training. + - num_simulations (:obj:`int`): The number of simulations to run in the MCTS search. + - collector_env_num (:obj:`int`): The number of environments for data collection. + - evaluator_env_num (:obj:`int`): The number of environments for evaluation. + - max_env_step (:obj:`int`): The total number of environment steps to train for. + - batch_size (:obj:`Union[List[int], int]`): The batch size for training. Can be a list for per-task sizes or a single int. + - norm_type (:obj:`str`): The type of normalization to use in the model (e.g., 'BN', 'LN'). + - buffer_reanalyze_freq (:obj:`float`): The frequency at which to reanalyze the replay buffer. + - reanalyze_batch_size (:obj:`int`): The batch size for reanalysis. + - reanalyze_partition (:obj:`float`): The partition ratio for reanalysis. + - num_segments (:obj:`int`): The number of segments for the replay buffer. + - exp_path_prefix (:obj:`str`): A template for the experiment's output path. 
+ """ + self.env_id_list = env_id_list + self.seed = seed + self.num_unroll_steps = num_unroll_steps + self.num_simulations = num_simulations + self.collector_env_num = collector_env_num + self.evaluator_env_num = evaluator_env_num + self.max_env_step = max_env_step + self.batch_size = batch_size + self.norm_type = norm_type + self.buffer_reanalyze_freq = buffer_reanalyze_freq + self.reanalyze_batch_size = reanalyze_batch_size + self.reanalyze_partition = reanalyze_partition + self.num_segments = num_segments + self.exp_path_prefix = exp_path_prefix + + # --- Derived attributes --- + self.num_tasks = len(self.env_id_list) + self.action_space_size = 18 # Default full action space for Atari + + def _create_base_config(self) -> EasyDict: + """ + Overview: + Creates the base configuration dictionary with shared settings for all tasks. + Returns: + - (:obj:`EasyDict`): A dictionary containing the base configuration. + """ + return EasyDict(dict( + env=dict( + stop_value=int(self.max_env_step), + observation_shape=(4, 64, 64), + frame_stack_num=4, + gray_scale=True, + collector_env_num=self.collector_env_num, + evaluator_env_num=self.evaluator_env_num, + n_evaluator_episode=self.evaluator_env_num, + manager=dict(shared_memory=False), + full_action_space=True, + collect_max_episode_steps=int(5e3), + eval_max_episode_steps=int(5e3), + ), + policy=dict( + multi_gpu=True, # Very important for DDP + learn=dict( + learner=dict( + hook=dict(save_ckpt_after_iter=200000), + ), + ), + grad_correct_params=dict(), + task_num=self.num_tasks, + model=dict( + device='cuda', + num_res_blocks=2, + num_channels=256, + reward_head_channels=16, + value_head_channels=16, + policy_head_channels=16, + fc_reward_layers=[32], + fc_value_layers=[32], + fc_policy_layers=[32], + observation_shape=(4, 64, 64), + frame_stack_num=4, + gray_scale=True, + action_space_size=self.action_space_size, + norm_type=self.norm_type, + model_type='conv', + image_channel=1, + downsample=True, + self_supervised_learning_loss=True, + discrete_action_encoding_type='one_hot', + use_sim_norm=True, + use_sim_norm_kl_loss=False, + task_num=self.num_tasks, + ), + allocated_batch_sizes=False, + cuda=True, + env_type='not_board_games', + train_start_after_envsteps=2000, + # train_start_after_envsteps=0, # TODO: debug + game_segment_length=20, + random_collect_episode_num=0, + use_augmentation=True, + use_priority=False, + replay_ratio=0.25, + num_unroll_steps=self.num_unroll_steps, + update_per_collect=80, + optim_type='SGD', + td_steps=5, + lr_piecewise_constant_decay=True, + manual_temperature_decay=False, + learning_rate=0.2, + target_update_freq=100, + num_segments=self.num_segments, + num_simulations=self.num_simulations, + policy_entropy_weight=5e-3, # TODO: Fine-tune this weight. + ssl_loss_weight=2, + eval_freq=int(5e3), + replay_buffer_size=int(5e5), + collector_env_num=self.collector_env_num, + evaluator_env_num=self.evaluator_env_num, + # ============= Reanalyze Parameters ============= + buffer_reanalyze_freq=self.buffer_reanalyze_freq, + reanalyze_batch_size=self.reanalyze_batch_size, + reanalyze_partition=self.reanalyze_partition, + ), + )) + + def _get_exp_name(self, env_id: str) -> str: + """ + Overview: + Generates a formatted experiment name for a given task. + Arguments: + - env_id (:obj:`str`): The environment ID for the specific task. + Returns: + - (:obj:`str`): The formatted experiment name. 
+ """ + # TODO: debug name + prefix = ( + f'{self.exp_path_prefix}/{self.num_tasks}games_brf{self.buffer_reanalyze_freq}/' + f'{self.num_tasks}games_brf{self.buffer_reanalyze_freq}_1-encoder-{self.norm_type}-res2-channel256_gsl20_' + f'{self.num_tasks}-pred-head_mbs-512_upc80_H{self.num_unroll_steps}_seed{self.seed}/' + ) + env_name = env_id.split('NoFrameskip')[0] + return f"{prefix}{env_name}_muzero-mt_seed{self.seed}" + + def generate_configs(self) -> List[List[Union[int, List[Any]]]]: + """ + Overview: + Generates the final list of configurations for all specified tasks, + ready to be used by the training entry point. + Returns: + - (:obj:`List[List[Union[int, List[Any]]]]`): A list where each element corresponds to a task, + containing the task_id and a list with the task's config and env_manager config. + """ + base_config = self._create_base_config() + env_manager_config = self._create_env_manager_config() + + configs = [] + for task_id, env_id in enumerate(self.env_id_list): + task_config = deepcopy(base_config) + + # --- Apply task-specific settings --- + task_config.env.env_id = env_id + task_config.policy.task_id = task_id + + # Handle per-task batch size if provided as a list + if isinstance(self.batch_size, list): + task_config.policy.batch_size = self.batch_size[task_id] + else: + task_config.policy.batch_size = self.batch_size + + task_config.exp_name = self._get_exp_name(env_id) + + configs.append([task_id, [task_config, env_manager_config]]) + + return configs + + @staticmethod + def _create_env_manager_config() -> EasyDict: + """ + Overview: + Creates a static configuration for the environment and policy managers. + Returns: + - (:obj:`EasyDict`): A dictionary containing manager configurations. + """ + return EasyDict(dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='muzero_multitask', + import_names=['lzero.policy.muzero_multitask'], + ), + )) + + +if __name__ == "__main__": + # ============================================================== + # Hyperparameters for Multi-Task Training + # ============================================================== + + # --- List of Atari environments for multi-task learning --- + env_id_list = [ + 'PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4', + 'BoxingNoFrameskip-v4', 'AlienNoFrameskip-v4', 'ChopperCommandNoFrameskip-v4', + 'HeroNoFrameskip-v4', 'RoadRunnerNoFrameskip-v4', 'AmidarNoFrameskip-v4', + 'AssaultNoFrameskip-v4', 'AsterixNoFrameskip-v4', 'BankHeistNoFrameskip-v4', + 'BattleZoneNoFrameskip-v4', 'CrazyClimberNoFrameskip-v4', 'DemonAttackNoFrameskip-v4', + 'FreewayNoFrameskip-v4', 'FrostbiteNoFrameskip-v4', 'GopherNoFrameskip-v4', + 'JamesbondNoFrameskip-v4', 'KangarooNoFrameskip-v4', 'KrullNoFrameskip-v4', + 'KungFuMasterNoFrameskip-v4', 'PrivateEyeNoFrameskip-v4', 'UpNDownNoFrameskip-v4', + 'QbertNoFrameskip-v4', 'BreakoutNoFrameskip-v4', + ] + + # --- Core Experiment Settings --- + seed = 0 + max_env_step = int(5e5) + + # --- Training & Model Parameters --- + num_unroll_steps = 5 + num_simulations = 50 + norm_type = 'BN' # 'BN' (Batch Normalization) or 'LN' (Layer Normalization) + + # --- Environment & Collector Settings --- + collector_env_num = 8 + evaluator_env_num = 3 + num_segments = 8 + + # --- Batch Size Configuration --- + # The batch size is dynamically calculated per task to not exceed a maximum total batch size. 
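# --- Illustrative sketch (not part of the diff): the per-task batch-size rule stated above,
# i.e. each task gets min(64, max_batch_size / num_tasks). num_tasks=26 matches the length of
# env_id_list in this config; max_batch_size=512 matches the script.
max_batch_size = 512
num_tasks = 26
per_task_batch_size = int(min(64, max_batch_size / num_tasks))  # int(512 / 26) -> 19
batch_size = [per_task_batch_size] * num_tasks
assert per_task_batch_size == 19 and sum(batch_size) == 494     # total stays under 512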
+ max_batch_size = 512 + per_task_batch_size = int(min(64, max_batch_size / len(env_id_list))) + batch_size = [per_task_batch_size] * len(env_id_list) + + # --- Reanalyze Buffer Settings --- + buffer_reanalyze_freq = 1 / 50 + reanalyze_batch_size = 160 + reanalyze_partition = 0.75 + + # --- (Optional) Debug Settings --- + # To use debug settings, uncomment the following lines. + # collector_env_num = 2 + # evaluator_env_num = 2 + # num_segments = 2 + # num_simulations = 3 + # debug_batch_size = int(min(2, max_batch_size / len(env_id_list))) + # batch_size = [debug_batch_size] * len(env_id_list) + # print("--- RUNNING IN DEBUG MODE ---") + + print(f'=========== Batch size per task: {batch_size[0]} ===========') + + # ============================================================== + # Configuration Generation and Training Launch + # ============================================================== + + # --- Instantiate and generate configurations --- + experiment_config = AtariMuZeroMultitaskConfig( + env_id_list=env_id_list, + seed=seed, + max_env_step=max_env_step, + num_unroll_steps=num_unroll_steps, + num_simulations=num_simulations, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + batch_size=batch_size, + norm_type=norm_type, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + num_segments=num_segments, + # Note: Update this path to your desired location. + exp_path_prefix='YOUR_EXPERIMENT_PATH_PREFIX/data_muzero_mt_atari_20250228' + ) + + configs_to_run = experiment_config.generate_configs() + + # --- Launch Distributed Training --- + """ + Overview: + This script should be executed with GPUs. + Set the NCCL timeout and launch the script using one of the following commands. 
+ + Command using torch.distributed.launch: + export NCCL_TIMEOUT=3600000 + python -m torch.distributed.launch --nproc_per_node=4 --master_port=29501 ./path/to/this/script.py + + Command using torchrun: + export NCCL_TIMEOUT=3600000 + torchrun --nproc_per_node=4 --master_port=29501 ./path/to/this/script.py + """ + from lzero.entry import train_muzero_multitask_segment_ddp + from ding.utils import DDPContext + + with DDPContext(): + train_muzero_multitask_segment_ddp(configs_to_run, seed=seed, max_env_step=max_env_step) \ No newline at end of file diff --git a/zoo/atari/config/atari_muzero_segment_config.py b/zoo/atari/config/atari_muzero_segment_config.py index 4289fb957..f713096d9 100644 --- a/zoo/atari/config/atari_muzero_segment_config.py +++ b/zoo/atari/config/atari_muzero_segment_config.py @@ -43,7 +43,7 @@ def main(env_id, seed): env=dict( stop_value=int(1e6), env_id=env_id, - observation_shape=(4, 96, 96), + observation_shape=(4, 64, 64), frame_stack_num=4, gray_scale=True, collector_env_num=collector_env_num, @@ -59,7 +59,7 @@ def main(env_id, seed): analysis_sim_norm=False, cal_dormant_ratio=False, model=dict( - observation_shape=(4, 96, 96), + observation_shape=(4, 64, 64), image_channel=1, frame_stack_num=4, gray_scale=True, @@ -123,7 +123,7 @@ def main(env_id, seed): # ============ use muzero_segment_collector instead of muzero_collector ============= from lzero.entry import train_muzero_segment - main_config.exp_name = f'data_muzero/{env_id[:-14]}/{env_id[:-14]}_mz_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}_bs{batch_size}_seed{seed}' + main_config.exp_name = f'data_lz_muzero/{env_id[:-14]}/{env_id[:-14]}_mz_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}_bs{batch_size}_seed{seed}' train_muzero_segment([main_config, create_config], seed=seed, max_env_step=max_env_step) if __name__ == "__main__": diff --git a/zoo/atari/config/atari_rezero_ez_config.py b/zoo/atari/config/atari_rezero_ez_config.py index e589041a7..ac22901fd 100644 --- a/zoo/atari/config/atari_rezero_ez_config.py +++ b/zoo/atari/config/atari_rezero_ez_config.py @@ -25,7 +25,7 @@ exp_name=f'data_rezero_ez/{env_id[:-14]}_rezero_efficientzero_ns{num_simulations}_upc{update_per_collect}_brf{buffer_reanalyze_freq}_seed0', env=dict( env_id=env_id, - obs_shape=(4, 96, 96), + obs_shape=(4, 64, 64), collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, n_evaluator_episode=evaluator_env_num, @@ -33,7 +33,7 @@ ), policy=dict( model=dict( - observation_shape=(4, 96, 96), + observation_shape=(4, 64, 64), frame_stack_num=4, action_space_size=action_space_size, downsample=True, diff --git a/zoo/atari/config/atari_rezero_mz_config.py b/zoo/atari/config/atari_rezero_mz_config.py index c7787831b..5bf21fb0c 100644 --- a/zoo/atari/config/atari_rezero_mz_config.py +++ b/zoo/atari/config/atari_rezero_mz_config.py @@ -18,6 +18,7 @@ reuse_search = True collect_with_pure_policy = True buffer_reanalyze_freq = 1 + # ============================================================== # end of the most frequently changed config specified by the user # ============================================================== @@ -27,7 +28,7 @@ env=dict( stop_value=int(1e6), env_id=env_id, - obs_shape=(4, 96, 96), + obs_shape=(4, 64, 64), collector_env_num=collector_env_num, 
evaluator_env_num=evaluator_env_num, n_evaluator_episode=evaluator_env_num, @@ -35,7 +36,7 @@ ), policy=dict( model=dict( - observation_shape=(4, 96, 96), + observation_shape=(4, 64, 64), frame_stack_num=4, action_space_size=action_space_size, downsample=True, diff --git a/zoo/atari/config/atari_sampled_efficientzero_config.py b/zoo/atari/config/atari_sampled_efficientzero_config.py index b5a551c3b..a96692756 100644 --- a/zoo/atari/config/atari_sampled_efficientzero_config.py +++ b/zoo/atari/config/atari_sampled_efficientzero_config.py @@ -24,7 +24,7 @@ exp_name=f'data_sez/{env_id[:-14]}_sampled_efficientzero_k{K}_ns{num_simulations}_upc{update_per_collect}_rer{reanalyze_ratio}_seed0', env=dict( env_id=env_id, - obs_shape=(4, 96, 96), + obs_shape=(4, 64, 64), collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, n_evaluator_episode=evaluator_env_num, @@ -32,7 +32,7 @@ ), policy=dict( model=dict( - observation_shape=(4, 96, 96), + observation_shape=(4, 64, 64), frame_stack_num=4, action_space_size=action_space_size, downsample=True, diff --git a/zoo/atari/config/atari_stochastic_muzero_config.py b/zoo/atari/config/atari_stochastic_muzero_config.py index b338680de..2fb13e44e 100644 --- a/zoo/atari/config/atari_stochastic_muzero_config.py +++ b/zoo/atari/config/atari_stochastic_muzero_config.py @@ -24,7 +24,7 @@ env=dict( stop_value=int(1e6), env_id=env_id, - obs_shape=(4, 96, 96), + obs_shape=(4, 64, 64), collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, n_evaluator_episode=evaluator_env_num, @@ -32,7 +32,7 @@ ), policy=dict( model=dict( - observation_shape=(4, 96, 96), + observation_shape=(4, 64, 64), frame_stack_num=4, action_space_size=action_space_size, chance_space_size=chance_space_size, diff --git a/zoo/atari/config/atari_unizero_ddp_config.py b/zoo/atari/config/atari_unizero_ddp_config.py index d64332d58..a290f4b6b 100644 --- a/zoo/atari/config/atari_unizero_ddp_config.py +++ b/zoo/atari/config/atari_unizero_ddp_config.py @@ -21,11 +21,6 @@ infer_context_length = 4 seed = 0 -# ====== only for debug ===== -# num_simulations = 2 -# max_env_step = int(2e5) -# batch_size = 2 -# num_unroll_steps = 10 # ============================================================== # end of the most frequently changed config specified by the user # ============================================================== @@ -55,7 +50,6 @@ max_tokens=2 * num_unroll_steps, # NOTE: each timestep has 2 tokens: obs and action context_length=2 * infer_context_length, device='cuda', - # device='cpu', action_space_size=action_space_size, num_layers=2, num_heads=8, diff --git a/zoo/atari/config/atari_unizero_multitask_segment_ddp_balance_config.py b/zoo/atari/config/atari_unizero_multitask_segment_ddp_balance_config.py new file mode 100644 index 000000000..6ca1c8601 --- /dev/null +++ b/zoo/atari/config/atari_unizero_multitask_segment_ddp_balance_config.py @@ -0,0 +1,550 @@ +# -*- coding: utf-8 -*- +""" +Overview: + This script contains the configuration generation logic for a multi-task UniZero agent + designed for Atari environments. It sets up experiment parameters, computes batch sizes + for distributed training, and generates the final configuration objects required to + launch the training process. + +Execution Command Example: + To run this script using distributed training with GPUs, use the following command. + Replace with the number of GPUs per node (e.g., 8) and adjust paths and log files as needed. 
+ + cd /path/to/your/project/LightZero + python -m torch.distributed.launch --nproc_per_node= --master_port= \ + /path/to/this/script.py 2>&1 | tee /path/to/your/logs/training.log +""" +import math +from typing import List, Tuple, Dict, Any + +from easydict import EasyDict +from ding.utils import DDPContext +# It is recommended to place entry point imports within the main execution block +# to avoid circular dependencies or premature initializations. +# from lzero.entry import train_unizero_multitask_balance_segment_ddp + + +# ============================================================== +# Configuration Computation and Generation +# ============================================================== + +def compute_batch_config( + env_id_list: List[str], + effective_batch_size: int, + gpus_per_node: int = 8, + max_micro_batch_per_gpu: int = 400 +) -> Tuple[List[int], int]: + """ + Overview: + Computes the micro-batch size for each environment and the number of gradient accumulation steps. + This is designed to balance the load across multiple environments and GPUs while respecting + memory constraints (max_micro_batch_per_gpu). + + Arguments: + - env_id_list (:obj:`List[str]`): A list of environment IDs. + - effective_batch_size (:obj:`int`): The target total batch size after gradient accumulation. + - gpus_per_node (:obj:`int`): The number of GPUs available for training. Defaults to 8. + - max_micro_batch_per_gpu (:obj:`int`): The maximum micro-batch size that can fit on a single GPU. Defaults to 400. + + Returns: + - (:obj:`Tuple[List[int], int]`): A tuple containing: + - A list of micro-batch sizes, one for each environment. + - The number of gradient accumulation steps required. + """ + num_envs = len(env_id_list) + if num_envs == 0: + return [], 1 + + # To avoid division by zero, assume at least one environment is processed per GPU group. + envs_per_gpu_group = max(1, num_envs // gpus_per_node) + + # Calculate the maximum micro-batch size per environment based on GPU memory limits. + max_micro_batch_per_env = int(max_micro_batch_per_gpu / envs_per_gpu_group) + + # Calculate the theoretical batch size per environment if distributed evenly. + theoretical_env_batch = effective_batch_size / num_envs + + if theoretical_env_batch > max_micro_batch_per_env: + # If the theoretical batch size exceeds the per-environment limit, + # cap the micro-batch size at the maximum allowed value. + micro_batch_size = max_micro_batch_per_env + # Calculate gradient accumulation steps needed to reach the effective batch size. + grad_accumulate_steps = math.ceil(theoretical_env_batch / max_micro_batch_per_env) + else: + # If the theoretical batch size is within limits, use it directly. + micro_batch_size = int(theoretical_env_batch) + grad_accumulate_steps = 1 + + # Assign the same computed micro-batch size to all environments. + batch_sizes = [micro_batch_size] * num_envs + + # Logging for debugging purposes. 
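    # Illustrative check (assumed example values, not enforced defaults): with 8 environments,
    # effective_batch_size=512, gpus_per_node=8 and max_micro_batch_per_gpu=400, the theoretical
    # per-environment batch is 512 / 8 = 64, well below the 400 cap, so micro_batch_size=64 and
    # grad_accumulate_steps=1, i.e. the function returns ([64] * 8, 1).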
+ print(f"Number of environments: {num_envs}") + print(f"Effective total batch size: {effective_batch_size}") + print(f"Theoretical batch size per environment: {theoretical_env_batch:.2f}") + print(f"Micro-batch size per environment: {micro_batch_size}") + print(f"Gradient accumulation steps: {grad_accumulate_steps}") + + return batch_sizes, grad_accumulate_steps + + +def create_config( + env_id: str, + action_space_size: int, + collector_env_num: int, + evaluator_env_num: int, + n_episode: int, + num_simulations: int, + reanalyze_ratio: float, + batch_size: int, + num_unroll_steps: int, + infer_context_length: int, + norm_type: str, + buffer_reanalyze_freq: float, + reanalyze_batch_size: int, + reanalyze_partition: float, + num_segments: int, + total_batch_size: int, + target_return: int, + curriculum_stage_num: int, + num_envs: int, +) -> EasyDict: + """ + Overview: + Creates the main configuration dictionary for a single UniZero task. + + Arguments: + - env_id (:obj:`str`): The ID of the environment (e.g., 'PongNoFrameskip-v4'). + - action_space_size (:obj:`int`): The size of the action space. + - collector_env_num (:obj:`int`): Number of environments for data collection. + - evaluator_env_num (:obj:`int`): Number of environments for evaluation. + - n_episode (:obj:`int`): Number of episodes to run for collection. + - num_simulations (:obj:`int`): Number of simulations for MCTS. + - reanalyze_ratio (:obj:`float`): The ratio of reanalyzed data in a batch. + - batch_size (:obj:`int`): The micro-batch size for training. + - num_unroll_steps (:obj:`int`): The number of steps to unroll the model dynamics. + - infer_context_length (:obj:`int`): The context length for inference. + - norm_type (:obj:`str`): The type of normalization layer to use (e.g., 'LN'). + - buffer_reanalyze_freq (:obj:`float`): Frequency of reanalyzing the replay buffer. + - reanalyze_batch_size (:obj:`int`): Batch size for reanalysis. + - reanalyze_partition (:obj:`float`): Partition ratio for reanalysis. + - num_segments (:obj:`int`): Number of segments for game episodes. + - total_batch_size (:obj:`int`): The effective total batch size. + - target_return (:obj:`int`): The target return for the environment. + - curriculum_stage_num (:obj:`int`): The number of stages in curriculum learning. + - num_envs (:obj:`int`): The total number of environments in the multi-task setup. + + Returns: + - (:obj:`EasyDict`): A configuration object for the agent. 
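    Example:
        An illustrative call mirroring the defaults in main() below; all values are
        placeholders for demonstration rather than recommended settings:

        >>> cfg = create_config(
        ...     env_id='PongNoFrameskip-v4', action_space_size=18, collector_env_num=8,
        ...     evaluator_env_num=3, n_episode=8, num_simulations=50, reanalyze_ratio=0.0,
        ...     batch_size=64, num_unroll_steps=10, infer_context_length=4, norm_type='LN',
        ...     buffer_reanalyze_freq=1 / 50, reanalyze_batch_size=160, reanalyze_partition=0.75,
        ...     num_segments=8, total_batch_size=512, target_return=20,
        ...     curriculum_stage_num=5, num_envs=8,
        ... )
        >>> cfg.env.env_id
        'PongNoFrameskip-v4'
        >>> cfg.policy.task_id  # placeholder, overridden per task in generate_configs()
        0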
+ """ + return EasyDict(dict( + env=dict( + stop_value=int(1e6), + env_id=env_id, + observation_shape=(3, 64, 64), + gray_scale=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False), + full_action_space=True, + collect_max_episode_steps=int(5e3), + eval_max_episode_steps=int(5e3), + ), + policy=dict( + multi_gpu=True, # Crucial for DDP + model=dict( + observation_shape=(3, 64, 64), + action_space_size=action_space_size, + norm_type=norm_type, + num_res_blocks=2, + num_channels=256, + continuous_action_space=False, + world_model_cfg=dict( + use_global_pooling=False, + final_norm_option_in_obs_head='LayerNorm', + final_norm_option_in_encoder='LayerNorm', + predict_latent_loss_type='mse', + share_head=False, + analysis_dormant_ratio_weight_rank=False, + dormant_threshold=0.025, + continuous_action_space=False, + task_embed_option=None, + use_task_embed=False, + use_shared_projection=False, + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, + context_length=2 * infer_context_length, + device='cuda', + action_space_size=action_space_size, + num_layers=4, + num_heads=24, + embed_dim=768, + obs_type='image', + env_num=num_envs, + task_num=num_envs, + encoder_type='vit', + use_normal_head=True, + use_softmoe_head=False, + use_moe_head=False, + num_experts_in_moe_head=4, + moe_in_transformer=False, + multiplication_moe_in_transformer=True, + n_shared_experts=1, + num_experts_per_tok=1, + num_experts_of_moe_in_transformer=8, + moe_use_lora=True, + curriculum_stage_num=curriculum_stage_num, + lora_target_modules=["attn", "feed_forward"], + lora_r=64, + lora_alpha=32, + lora_dropout=0.1, + lora_scale_init=1, + min_stage0_iters=50000, + max_stage_iters=20000, + apply_curriculum_to_encoder=False, + ), + ), + # --- Task and Learning Settings --- + total_task_num=num_envs, + task_num=num_envs, + task_id=0, # This will be overridden for each task. + target_return=target_return, + use_task_exploitation_weight=False, + task_complexity_weight=True, + balance_pipeline=True, + # --- Training Settings --- + cuda=True, + total_batch_size=total_batch_size, + allocated_batch_sizes=False, + batch_size=batch_size, + num_unroll_steps=num_unroll_steps, + update_per_collect=80, + replay_ratio=0.25, + optim_type='AdamW', + cos_lr_scheduler=False, + train_start_after_envsteps=int(0), + # --- Replay Buffer and Reanalysis --- + replay_buffer_size=int(5e5), + num_segments=num_segments, + use_priority=False, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + reanalyze_ratio=reanalyze_ratio, + # --- MCTS Settings --- + num_simulations=num_simulations, + collect_num_simulations=num_simulations, + eval_num_simulations=50, + # --- Collector and Evaluator Settings --- + n_episode=n_episode, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + eval_freq=int(1e4), + # --- Miscellaneous --- + print_task_priority_logs=False, + model_path=None, + game_segment_length=20, + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=200000))), + ), + )) + + +def _generate_experiment_name( + base_path_prefix: str, + num_envs: int, + curriculum_stage_num: int, + buffer_reanalyze_freq: float, + seed: int, + env_id: str +) -> str: + """ + Overview: + Helper function to generate a standardized experiment name. 
+ + Arguments: + - base_path_prefix (:obj:`str`): The prefix for the experiment path, e.g., 'data_unizero_atari_mt_balance_YYYYMMDD'. + - num_envs (:obj:`int`): The total number of environments. + - curriculum_stage_num (:obj:`int`): The number of curriculum stages. + - buffer_reanalyze_freq (:obj:`float`): The buffer reanalyze frequency. + - seed (:obj:`int`): The random seed for the experiment. + - env_id (:obj:`str`): The environment ID for this specific task. + + Returns: + - (:obj:`str`): The generated experiment name. + """ + # Template for the experiment's parent directory. + brf_str = str(buffer_reanalyze_freq).replace('.', '') + parent_dir = ( + f"{base_path_prefix}/atari_{num_envs}games_balance-total-stage{curriculum_stage_num}_" + f"stage-50k-20k_vit-small-ln_trans-nlayer4-moe8_backbone-attn-mlp-lora_no-lora-scale_" + f"brf{brf_str}_not-share-head_seed{seed}/" + ) + + # Clean the environment ID for the final part of the name. + env_name_part = env_id.split('NoFrameskip')[0] + + return f"{parent_dir}{env_name_part}_seed{seed}" + + +def generate_configs( + env_id_list: List[str], + action_space_size: int, + collector_env_num: int, + n_episode: int, + evaluator_env_num: int, + num_simulations: int, + reanalyze_ratio: float, + batch_sizes: List[int], + num_unroll_steps: int, + infer_context_length: int, + norm_type: str, + seed: int, + buffer_reanalyze_freq: float, + reanalyze_batch_size: int, + reanalyze_partition: float, + num_segments: int, + total_batch_size: int, + target_return_dict: Dict[str, int], + curriculum_stage_num: int, +) -> List[Tuple[int, List[Any]]]: + """ + Overview: + Generates a list of configuration tuples, one for each task/environment. + + Returns: + - (:obj:`List[Tuple[int, List[Any]]]`): A list where each element is a tuple containing + the task_id and a list with the main config and the environment manager config. + """ + configs = [] + exp_name_base_prefix = 'data_unizero_mt_balance_atari' # YYYYMMDD format + + for task_id, env_id in enumerate(env_id_list): + config = create_config( + env_id=env_id, + action_space_size=action_space_size, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_episode=n_episode, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + batch_size=batch_sizes[task_id], + num_unroll_steps=num_unroll_steps, + infer_context_length=infer_context_length, + norm_type=norm_type, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + num_segments=num_segments, + total_batch_size=total_batch_size, + target_return=target_return_dict[env_id], + curriculum_stage_num=curriculum_stage_num, + num_envs=len(env_id_list), + ) + config.policy.task_id = task_id + config.exp_name = _generate_experiment_name( + base_path_prefix=exp_name_base_prefix, + num_envs=len(env_id_list), + curriculum_stage_num=curriculum_stage_num, + buffer_reanalyze_freq=buffer_reanalyze_freq, + seed=seed, + env_id=env_id + ) + configs.append([task_id, [config, create_env_manager()]]) + return configs + + +def create_env_manager() -> EasyDict: + """ + Overview: + Creates the environment manager configuration, specifying the types of environment, + policy, and manager to be used. + + Returns: + - (:obj:`EasyDict`): A configuration object for the environment manager. 
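    Example:
        With the defaults in main() below (8 games, seed 0, curriculum_stage_num=5,
        buffer_reanalyze_freq=1/50), the returned list has 8 entries; the Pong task gets
        task_id 0 and an experiment name of the form
        'data_unizero_mt_balance_atari/atari_8games_balance-total-stage5_..._brf002_not-share-head_seed0/Pong_seed0'
        (middle segment elided here), as produced by _generate_experiment_name above.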
+ """ + return EasyDict(dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='unizero_multitask', + import_names=['lzero.policy.unizero_multitask'], + ), + )) + + +def get_atari_target_return_dict(ratio: float = 1.0) -> Dict[str, int]: + """ + Overview: + Calculates the target return for each Atari game based on a predefined score + and a scaling ratio. + + Arguments: + - ratio (:obj:`float`): A scaling factor for the target returns. Defaults to 1.0. + + Returns: + - (:obj:`Dict[str, int]`): A dictionary mapping environment IDs to their calculated target returns. + """ + # Pre-defined target scores for various Atari games. + target_scores = { + 'PongNoFrameskip-v4': 20, + 'MsPacmanNoFrameskip-v4': 6951.6, + 'SeaquestNoFrameskip-v4': 42054.7, + 'BoxingNoFrameskip-v4': 12.1, + 'AlienNoFrameskip-v4': 7127.7, + 'ChopperCommandNoFrameskip-v4': 7387.8, + 'HeroNoFrameskip-v4': 30826.4, + 'RoadRunnerNoFrameskip-v4': 7845.0, + 'AmidarNoFrameskip-v4': 100.5, + 'AssaultNoFrameskip-v4': 742.0, + 'AsterixNoFrameskip-v4': 1503.3, + 'BankHeistNoFrameskip-v4': 753.1, + 'BattleZoneNoFrameskip-v4': 12187.5, + 'CrazyClimberNoFrameskip-v4': 15829.4, + 'DemonAttackNoFrameskip-v4': 1971.0, + 'FreewayNoFrameskip-v4': 29.6, + 'FrostbiteNoFrameskip-v4': 334.7, + 'GopherNoFrameskip-v4': 2412.5, + 'JamesbondNoFrameskip-v4': 302.8, + 'KangarooNoFrameskip-v4': 3035.0, + 'KrullNoFrameskip-v4': 2665.5, + 'KungFuMasterNoFrameskip-v4': 12736.3, + 'PrivateEyeNoFrameskip-v4': 1001.3, + 'UpNDownNoFrameskip-v4': 11693.2, + 'QbertNoFrameskip-v4': 13455.0, + 'BreakoutNoFrameskip-v4': 30.5, + } + return {env: int(round(score * ratio)) for env, score in target_scores.items()} + + +def get_env_id_list(num_games: int) -> List[str]: + """ + Overview: + Returns a list of Atari environment IDs based on the specified number of games. + + Arguments: + - num_games (:obj:`int`): The number of games to include (e.g., 8 or 26). + + Returns: + - (:obj:`List[str]`): A list of environment ID strings. + """ + games_8 = [ + 'PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4', 'BoxingNoFrameskip-v4', + 'AlienNoFrameskip-v4', 'ChopperCommandNoFrameskip-v4', 'HeroNoFrameskip-v4', 'RoadRunnerNoFrameskip-v4', + ] + games_26 = games_8 + [ + 'AmidarNoFrameskip-v4', 'AssaultNoFrameskip-v4', 'AsterixNoFrameskip-v4', 'BankHeistNoFrameskip-v4', + 'BattleZoneNoFrameskip-v4', 'CrazyClimberNoFrameskip-v4', 'DemonAttackNoFrameskip-v4', + 'FreewayNoFrameskip-v4', + 'FrostbiteNoFrameskip-v4', 'GopherNoFrameskip-v4', 'JamesbondNoFrameskip-v4', 'KangarooNoFrameskip-v4', + 'KrullNoFrameskip-v4', 'KungFuMasterNoFrameskip-v4', 'PrivateEyeNoFrameskip-v4', 'UpNDownNoFrameskip-v4', + 'QbertNoFrameskip-v4', 'BreakoutNoFrameskip-v4', + ] + if num_games == 3: + return ['PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4'] + elif num_games == 8: + return games_8 + elif num_games == 26: + return games_26 + else: + raise ValueError(f"Unsupported number of games: {num_games}. Supported values are 3, 8, 26.") + + +def main(): + """ + Overview: + Main function to configure and launch the multi-task training process. 
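    Note:
        For reference (derived from the helpers above, not new settings): get_env_id_list(8)
        returns the 8-game subset beginning with 'PongNoFrameskip-v4', and
        get_atari_target_return_dict(ratio=0.5) halves every target, e.g. Pong's 20 becomes
        int(round(20 * 0.5)) == 10.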
+ """ + # ============================================================== + # Primary Hyperparameters + # ============================================================== + # --- Experiment --- + num_games = 8 # Options: 3, 8, 26 + seeds = [0] + max_env_step = int(4e5) + benchmark_name = "atari" + + # --- Curriculum --- + curriculum_stage_num = 5 + + # --- Environment and Agent --- + action_space_size = 18 + num_simulations = 50 + num_unroll_steps = 10 + infer_context_length = 4 + norm_type = 'LN' + + # --- Collector and Evaluator --- + collector_env_num = 8 + evaluator_env_num = 3 + n_episode = 8 + num_segments = 8 + + # --- Reanalysis --- + reanalyze_ratio = 0.0 + buffer_reanalyze_freq = 1 / 50 + reanalyze_batch_size = 160 + reanalyze_partition = 0.75 + + # ============================================================== + # Derived Configurations + # ============================================================== + env_id_list = get_env_id_list(num_games) + target_return_dict = get_atari_target_return_dict(ratio=1.0) + + # --- Batch Size Calculation --- + if num_games == 8: + effective_batch_size = 512 + elif num_games == 26: + effective_batch_size = 512 # For ViT-Base encoder + else: + # Default or other cases + effective_batch_size = 512 + + batch_sizes, grad_acc_steps = compute_batch_config(env_id_list, effective_batch_size) + # Note: `total_batch_size` is passed to the config but `effective_batch_size` is used for calculation. + # This maintains consistency with the original script's logic. + total_batch_size = effective_batch_size + + # ============================================================== + # Launch Training + # ============================================================== + from lzero.entry import train_unizero_multitask_balance_segment_ddp + + for seed in seeds: + configs = generate_configs( + env_id_list=env_id_list, + action_space_size=action_space_size, + collector_env_num=collector_env_num, + n_episode=n_episode, + evaluator_env_num=evaluator_env_num, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + batch_sizes=batch_sizes, + num_unroll_steps=num_unroll_steps, + infer_context_length=infer_context_length, + norm_type=norm_type, + seed=seed, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + num_segments=num_segments, + total_batch_size=total_batch_size, + target_return_dict=target_return_dict, + curriculum_stage_num=curriculum_stage_num + ) + + with DDPContext(): + train_unizero_multitask_balance_segment_ddp( + configs, + seed=seed, + max_env_step=max_env_step, + benchmark_name=benchmark_name + ) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/zoo/atari/config/atari_unizero_multitask_segment_ddp_config.py b/zoo/atari/config/atari_unizero_multitask_segment_ddp_config.py new file mode 100644 index 000000000..08636286c --- /dev/null +++ b/zoo/atari/config/atari_unizero_multitask_segment_ddp_config.py @@ -0,0 +1,313 @@ +from easydict import EasyDict +import math +from typing import List, Tuple, Any, Dict, Union + +# ------------------------------------------------- +# 1. 
Refactored compute_batch_config +# ------------------------------------------------- +def compute_batch_config( + env_id_list: List[str], + effective_batch_size: int, + gpu_num: int = 8, + max_micro_batch_one_gpu: int = 400, +) -> Tuple[List[int], int]: + """ + Overview: + Calculate the micro-batch size for each environment and the number of gradient accumulation steps + to approach a target effective batch size across multiple GPUs and environments. + + Arguments: + - env_id_list (:obj:`List[str]`): A list of environment IDs for all tasks. + - effective_batch_size (:obj:`int`): The target global batch size for one backward pass. + - gpu_num (:obj:`int`): The number of GPUs actually used. Defaults to 8. + - max_micro_batch_one_gpu (:obj:`int`): The maximum micro-batch size a single GPU can handle. Defaults to 400. + + Returns: + - batch_sizes (:obj:`List[int]`): A list of micro-batch sizes for each environment. + - grad_acc_steps (:obj:`int`): The number of gradient accumulation steps. + """ + n_env = len(env_id_list) + # Number of environments that each GPU needs to handle simultaneously. + envs_per_gpu = max(1, math.ceil(n_env / gpu_num)) + # Reduce the micro-batch limit if multiple environments share one GPU. + max_micro_batch = max(1, max_micro_batch_one_gpu // envs_per_gpu) + + # First, calculate a candidate micro-batch by distributing the effective batch size evenly. + candidate = max(1, effective_batch_size // n_env) + micro_batch = min(candidate, max_micro_batch) + + # Gradient accumulation steps = ceil(global_batch / (micro_batch * n_env)). + grad_acc_steps = max(1, math.ceil(effective_batch_size / (micro_batch * n_env))) + + # Fine-tune the micro-batch downwards to ensure: + # micro_batch * n_env * grad_acc_steps <= effective_batch_size + # This aims to get as close as possible to the target without exceeding it. + while micro_batch * n_env * grad_acc_steps > effective_batch_size: + micro_batch -= 1 + if micro_batch == 0: # Defensive check, should not happen in theory. + micro_batch = 1 + break + + batch_sizes = [micro_batch] * n_env + + # --- Debug Information --- # + real_total_batch_size = micro_batch * n_env * grad_acc_steps + print( + f"[BatchConfig] Envs={n_env}, TargetTotalBS={effective_batch_size}, " + f"MicroBS={micro_batch}, GradAccSteps={grad_acc_steps}, RealTotalBS={real_total_batch_size}" + ) + + return batch_sizes, grad_acc_steps + +def create_config( + env_id: str, action_space_size: int, collector_env_num: int, evaluator_env_num: int, n_episode: int, + num_simulations: int, reanalyze_ratio: float, batch_size: int, num_unroll_steps: int, + infer_context_length: int, norm_type: str, buffer_reanalyze_freq: float, reanalyze_batch_size: int, + reanalyze_partition: float, num_segments: int, total_batch_size: int, num_layers: int +) -> EasyDict: + """ + Overview: + Creates the main configuration structure for a single training task. + + Arguments: + - env_id (:obj:`str`): The environment ID. + - action_space_size (:obj:`int`): The size of the action space. + - collector_env_num (:obj:`int`): Number of environments for data collection. + - evaluator_env_num (:obj:`int`): Number of environments for evaluation. + - n_episode (:obj:`int`): Number of episodes to run for evaluation. + - num_simulations (:obj:`int`): Number of simulations in MCTS. + - reanalyze_ratio (:obj:`float`): The ratio of reanalyzed samples in a batch. + - batch_size (:obj:`int`): The batch size for training. + - num_unroll_steps (:obj:`int`): The number of steps to unroll the model dynamics. 
+ - infer_context_length (:obj:`int`): The context length for inference. + - norm_type (:obj:`str`): The type of normalization layer to use (e.g., 'LN'). + - buffer_reanalyze_freq (:obj:`float`): Frequency of reanalyzing the replay buffer. + - reanalyze_batch_size (:obj:`int`): Batch size for reanalysis. + - reanalyze_partition (:obj:`float`): Partition ratio for reanalysis. + - num_segments (:obj:`int`): Number of segments for data collection. + - total_batch_size (:obj:`int`): The total effective batch size. + - num_layers (:obj:`int`): Number of layers in the transformer model. + + Returns: + - (:obj:`EasyDict`): A configuration object. + """ + return EasyDict(dict( + env=dict( + stop_value=int(1e6), + env_id=env_id, + observation_shape=(3, 64, 64), + gray_scale=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False), + full_action_space=True, + ), + policy=dict( + multi_gpu=True, + only_use_moco_stats=False, + use_moco=False, + moco_version="v1", + total_task_num=len(env_id_list), + task_num=len(env_id_list), + task_id=0, # This will be overridden for each task + model=dict( + observation_shape=(3, 64, 64), + action_space_size=action_space_size, + norm_type=norm_type, + num_res_blocks=2, + num_channels=256, + num_layers=num_layers, + world_model_cfg=dict( + norm_type=norm_type, + action_space_size=action_space_size, + num_layers=num_layers, + num_heads=8, + embed_dim=768, + env_num=len(env_id_list), + task_num=len(env_id_list), + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, + context_length=2 * infer_context_length, + final_norm_option_in_obs_head='LayerNorm', + final_norm_option_in_encoder='LayerNorm', + predict_latent_loss_type='mse', + encoder_type='vit', + device='cuda', + game_segment_length=20, + ), + ), + device='cuda', + game_segment_length=20, + update_per_collect=80, # Corresponds to replay_ratio=0.5 for 8 games (20*8*0.5=80) + learning_rate=0.0001, + weight_decay=1e-2, + batch_size=batch_size, + num_unroll_steps=num_unroll_steps, + num_segments=num_segments, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + total_batch_size=total_batch_size, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + replay_buffer_size=int(5e5), + eval_freq=int(1e4), + ), + )) + +def generate_configs( + env_id_list: List[str], action_space_size: int, collector_env_num: int, n_episode: int, + evaluator_env_num: int, num_simulations: int, reanalyze_ratio: float, batch_size: List[int], + num_unroll_steps: int, infer_context_length: int, norm_type: str, seed: int, + buffer_reanalyze_freq: float, reanalyze_batch_size: int, reanalyze_partition: float, + num_segments: int, total_batch_size: int, num_layers: int +) -> List[List[Union[int, List[EasyDict]]]]: + """ + Overview: + Generates a list of configurations for all specified tasks. + + Arguments: + (See arguments for `create_config` function) + - seed (:obj:`int`): The random seed for the experiment. + + Returns: + - (:obj:`List[List[Union[int, List[EasyDict]]]]`): A list where each element contains a task_id + and its corresponding configuration objects. 
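    Example:
        An illustrative case using the settings in the __main__ block below (8 games,
        num_layers=2, hence total_batch_size=1024, seed=0): the Pong task receives task_id 0
        and exp_name 'data_unizero_mt/atari_8games_vit_nlayer2_tbs1024_seed0/Pong_seed0',
        where the game name is extracted from the 'ALE/Pong-v5' style ID.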
+ """ + configs = [] + + # --- Experiment Name Template --- + benchmark_tag = "data_unizero_mt" + model_tag = f"vit_nlayer{num_layers}_tbs{total_batch_size}" + exp_name_prefix = f'{benchmark_tag}/atari_{len(env_id_list)}games_{model_tag}_seed{seed}/' + + for task_id, env_id in enumerate(env_id_list): + config = create_config( + env_id, action_space_size, collector_env_num, evaluator_env_num, n_episode, num_simulations, + reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, + buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, num_segments, total_batch_size, num_layers + ) + config.policy.task_id = task_id + # Correctly extract the game name from 'ALE/GameName-v5' format. + game_name = env_id.split('/')[1].split('-')[0] + config.exp_name = exp_name_prefix + f"{game_name}_seed{seed}" + configs.append([task_id, [config, create_env_manager()]]) + return configs + +def create_env_manager() -> EasyDict: + """ + Overview: + Creates the environment manager configuration, specifying the types of environment, + policy, and their import paths. + + Returns: + - (:obj:`EasyDict`): A configuration object for the environment manager. + """ + return EasyDict(dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='unizero_multitask', + import_names=['lzero.policy.unizero_multitask'], + ), + )) + +if __name__ == "__main__": + """ + Overview: + This script should be executed with GPUs for distributed training. + + Example launch commands: + + export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 + cd /path/to/your/project/ + + torchrun --nproc_per_node=4 /mnt/shared-storage-user/puyuan/code/LightZero/zoo/atari/config/atari_unizero_multitask_segment_ddp_config.py + """ + from lzero.entry import train_unizero_multitask_segment_ddp + from ding.utils import DDPContext + import torch.distributed as dist + import os + + # ==================== Main Experiment Settings ==================== + num_games = 8 # Options: 3, 8, 26 + num_layers = 2 + action_space_size = 18 + collector_env_num = 8 + num_segments = 8 + n_episode = 8 + evaluator_env_num = 3 + num_simulations = 50 + max_env_step = int(5e6) + reanalyze_ratio = 0.0 + + # ==================== Environment Configuration ==================== + if num_games == 3: + env_id_list = ['ALE/Pong-v5', 'ALE/MsPacman-v5', 'ALE/Seaquest-v5'] + elif num_games == 8: + env_id_list = [ + 'ALE/Pong-v5', 'ALE/MsPacman-v5', 'ALE/Seaquest-v5', 'ALE/Boxing-v5', + 'ALE/Alien-v5', 'ALE/ChopperCommand-v5', 'ALE/Hero-v5', 'ALE/RoadRunner-v5', + ] + elif num_games == 26: + env_id_list = [ + 'ALE/Pong-v5', 'ALE/MsPacman-v5', 'ALE/Seaquest-v5', 'ALE/Boxing-v5', + 'ALE/Alien-v5', 'ALE/ChopperCommand-v5', 'ALE/Hero-v5', 'ALE/RoadRunner-v5', + 'ALE/Amidar-v5', 'ALE/Assault-v5', 'ALE/Asterix-v5', 'ALE/BankHeist-v5', + 'ALE/BattleZone-v5', 'ALE/CrazyClimber-v5', 'ALE/DemonAttack-v5', 'ALE/Freeway-v5', + 'ALE/Frostbite-v5', 'ALE/Gopher-v5', 'ALE/Jamesbond-v5', 'ALE/Kangaroo-v5', + 'ALE/Krull-v5', 'ALE/KungFuMaster-v5', 'ALE/PrivateEye-v5', 'ALE/UpNDown-v5', + 'ALE/Qbert-v5', 'ALE/Breakout-v5', + ] + else: + raise ValueError(f"Unsupported number of environments: {num_games}") + + # ==================== Batch Size Calculation ==================== + if len(env_id_list) == 8: + if num_layers in [2, 4]: + effective_batch_size = 1024 + elif num_layers == 8: + effective_batch_size = 512 + elif len(env_id_list) == 26: + effective_batch_size = 512 + elif 
len(env_id_list) == 3: + effective_batch_size = 10 # For debugging + else: + raise ValueError(f"Batch size not configured for {len(env_id_list)} environments.") + + batch_sizes, grad_acc_steps = compute_batch_config(env_id_list, effective_batch_size, gpu_num=4) + total_batch_size = effective_batch_size + + # ==================== Model and Training Settings ==================== + num_unroll_steps = 10 + infer_context_length = 4 + norm_type = 'LN' + buffer_reanalyze_freq = 1 / 100000000 # Effectively disable buffer reanalyze + reanalyze_batch_size = 160 + reanalyze_partition = 0.75 + + # ==================== Training Loop ==================== + # Set NCCL timeout to prevent watchdog hang due to unbalanced data collection speeds + os.environ.setdefault('NCCL_TIMEOUT', '3600') # 60 minutes in seconds + os.environ.setdefault('NCCL_BLOCKING_WAIT', '1') + + for seed in [0]: + configs = generate_configs( + env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, + num_simulations, reanalyze_ratio, batch_sizes, num_unroll_steps, infer_context_length, + norm_type, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, + num_segments, total_batch_size, num_layers + ) + + with DDPContext(): + train_unizero_multitask_segment_ddp(configs, seed=seed, max_env_step=max_env_step, benchmark_name="atari") + print(f"Seed: {seed} training finished!") + if dist.is_initialized(): + dist.destroy_process_group() diff --git a/zoo/atari/config/atari_unizero_multitask_segment_eval_config.py b/zoo/atari/config/atari_unizero_multitask_segment_eval_config.py new file mode 100644 index 000000000..b7973ff87 --- /dev/null +++ b/zoo/atari/config/atari_unizero_multitask_segment_eval_config.py @@ -0,0 +1,333 @@ +from easydict import EasyDict +from typing import List, Any, Dict + +# ============================================================== +# Environment and Policy Manager Configuration +# ============================================================== + +def create_env_manager() -> EasyDict: + """ + Overview: + Creates the configuration for the environment and policy managers. + This config specifies the types and import paths for core components + like the environment wrapper and the policy definition. + Returns: + - manager_config (:obj:`EasyDict`): A dictionary containing the types and import names + for the environment and policy managers. + """ + return EasyDict(dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='unizero_multitask', + import_names=['lzero.policy.unizero_multitask'], + ), + )) + +# ============================================================== +# Main Configuration Generation +# ============================================================== + +def create_config( + env_id: str, + action_space_size: int, + collector_env_num: int, + evaluator_env_num: int, + n_episode: int, + num_simulations: int, + reanalyze_ratio: float, + batch_size: List[int], + num_unroll_steps: int, + infer_context_length: int, + norm_type: str, + buffer_reanalyze_freq: float, + reanalyze_batch_size: int, + reanalyze_partition: float, + num_segments: int, + total_batch_size: int, + env_id_list: List[str], +) -> EasyDict: + """ + Overview: + Creates the main configuration dictionary for a single task in a multi-task setup. + Arguments: + - env_id (:obj:`str`): The ID of the environment for this specific task. 
+ - action_space_size (:obj:`int`): The size of the action space for the model. + - collector_env_num (:obj:`int`): The number of environments for the data collector. + - evaluator_env_num (:obj:`int`): The number of environments for the evaluator. + - n_episode (:obj:`int`): The number of episodes to run for collection. + - num_simulations (:obj:`int`): The number of simulations for the MCTS algorithm. + - reanalyze_ratio (:obj:`float`): The ratio of reanalyzed data in the replay buffer. + - batch_size (:obj:`List[int]`): The batch size for training, specified per task. + - num_unroll_steps (:obj:`int`): The number of steps to unroll the model during training. + - infer_context_length (:obj:`int`): The context length for inference. + - norm_type (:obj:`str`): The type of normalization to use (e.g., 'LN' for LayerNorm). + - buffer_reanalyze_freq (:obj:`float`): The frequency at which to reanalyze the buffer. + - reanalyze_batch_size (:obj:`int`): The batch size for reanalyzing data. + - reanalyze_partition (:obj:`float`): The partition ratio for reanalyzing data. + - num_segments (:obj:`int`): The number of segments for game data. + - total_batch_size (:obj:`int`): The total batch size across all tasks. + - env_id_list (:obj:`List[str]`): The list of all environment IDs in the multi-task setup. + Returns: + - config (:obj:`EasyDict`): The complete configuration for a single training task. + """ + return EasyDict(dict( + env=dict( + stop_value=int(1e6), + env_id=env_id, + observation_shape=(3, 64, 64), + gray_scale=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False), + full_action_space=True, + collect_max_episode_steps=int(5e3), + eval_max_episode_steps=int(5e3), + ), + policy=dict( + multi_gpu=True, # Enable multi-GPU for DDP + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=50000))), + grad_correct_params=dict( + MoCo_beta=0.5, MoCo_beta_sigma=0.5, MoCo_gamma=0.1, MoCo_gamma_sigma=0.5, + MoCo_rho=0, calpha=0.5, rescale=1, + ), + task_num=len(env_id_list), + task_id=0, # Placeholder, will be set in generate_configs + model=dict( + observation_shape=(3, 64, 64), + action_space_size=action_space_size, + norm_type=norm_type, + num_res_blocks=2, + num_channels=256, + world_model_cfg=dict( + env_id_list=env_id_list, + # TODO: Implement and verify the t-SNE analysis functionality. 
+ analysis_tsne=True, + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, + context_length=2 * infer_context_length, + device='cuda', + action_space_size=action_space_size, + num_layers=8, # Transformer layers + num_heads=8, + embed_dim=768, + obs_type='image', + env_num=len(env_id_list), + task_num=len(env_id_list), + use_normal_head=True, + use_softmoe_head=False, + use_moe_head=False, + moe_in_transformer=False, + num_experts_of_moe_in_transformer=4, + ), + ), + total_batch_size=total_batch_size, + allocated_batch_sizes=False, + train_start_after_envsteps=int(0), + use_priority=False, + print_task_priority_logs=False, + cuda=True, + model_path=None, + num_unroll_steps=num_unroll_steps, + game_segment_length=20, + update_per_collect=80, + replay_ratio=0.25, + batch_size=batch_size, + optim_type='AdamW', + num_segments=num_segments, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + replay_buffer_size=int(5e5), + eval_freq=int(2e4), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + ), + )) + + +def _generate_exp_name_prefix( + exp_base_path: str, + num_games: int, + buffer_reanalyze_freq: float, + norm_type: str, + seed: int +) -> str: + """ + Overview: + Generates a standardized prefix for the experiment name based on key hyperparameters. + Arguments: + - exp_base_path (:obj:`str`): The base directory for the experiment logs. + - num_games (:obj:`int`): The number of games in the multi-task setup. + - buffer_reanalyze_freq (:obj:`float`): The frequency of buffer reanalysis. + - norm_type (:obj:`str`): The normalization type used in the model. + - seed (:obj:`int`): The random seed for the experiment. + Returns: + - (:obj:`str`): The generated experiment name prefix. + """ + # NOTE: This name is constructed based on a specific convention to encode hyperparameters. + # It includes details about the model architecture, training parameters, and environment setup. + return ( + f'{exp_base_path}/{num_games}games_brf{buffer_reanalyze_freq}_' + f'1-encoder-{norm_type}-res2-channel256_gsl20_{num_games}-pred-head_' + f'nlayer8-nh24-lsd768_seed{seed}/' + ) + + +def generate_configs( + env_id_list: List[str], + action_space_size: int, + collector_env_num: int, + n_episode: int, + evaluator_env_num: int, + num_simulations: int, + reanalyze_ratio: float, + batch_size: List[int], + num_unroll_steps: int, + infer_context_length: int, + norm_type: str, + seed: int, + buffer_reanalyze_freq: float, + reanalyze_batch_size: int, + reanalyze_partition: float, + num_segments: int, + total_batch_size: int, + exp_base_path: str, +) -> List[List[Any]]: + """ + Overview: + Generates a list of configurations for each task in a multi-task training setup. + Each configuration is paired with an environment manager config. + Arguments: + - (All arguments from create_config, plus): + - seed (:obj:`int`): The random seed for the experiment, used for naming. + - exp_base_path (:obj:`str`): The base path for saving experiment results. + Returns: + - configs (:obj:`List[List[Any]]`): A list where each item contains + [task_id, [task_specific_config, env_manager_config]]. 
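    Example:
        An illustrative sketch using the main_hyperparams defined below: with the 8-game
        env_id_list and batch_size_per_task=4, the call receives batch_size=[4] * 8 and
        total_batch_size=32, returns 8 entries, and names the Pong task
        '.../8games_brf0.02_..._seed0/Pong_unizero-mt_seed0' (prefix segments elided here).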
+ """ + configs = [] + exp_name_prefix = _generate_exp_name_prefix( + exp_base_path, len(env_id_list), buffer_reanalyze_freq, norm_type, seed + ) + + for task_id, env_id in enumerate(env_id_list): + config = create_config( + env_id, action_space_size, collector_env_num, evaluator_env_num, + n_episode, num_simulations, reanalyze_ratio, batch_size, + num_unroll_steps, infer_context_length, norm_type, + buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, + num_segments, total_batch_size, env_id_list + ) + # Assign the specific task ID for this configuration + config.policy.task_id = task_id + # Set the full experiment name for logging and checkpointing + env_name = env_id.split('NoFrameskip')[0] + config.exp_name = exp_name_prefix + f"{env_name}_unizero-mt_seed{seed}" + + configs.append([task_id, [config, create_env_manager()]]) + + return configs + +# ============================================================== +# Main execution block +# ============================================================== + +if __name__ == "__main__": + """ + Overview: + This program is designed to obtain the t-SNE of the latent states in multi-task learning + across a set of Atari games (e.g., 8 games). + + This script should be executed with GPUs for Distributed Data Parallel (DDP) training. + Run one of the following commands to launch the script: + + Using `torch.distributed.launch` (deprecated): + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29501 ./path/to/this/script.py + + Using `torchrun` (recommended): + torchrun --nproc_per_node=8 ./path/to/this/script.py + """ + from lzero.entry import train_unizero_multitask_segment_eval + from ding.utils import DDPContext + + # --- Basic Environment and Model Setup --- + env_id_list = [ + 'PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4', + 'BoxingNoFrameskip-v4', 'AlienNoFrameskip-v4', 'ChopperCommandNoFrameskip-v4', + 'HeroNoFrameskip-v4', 'RoadRunnerNoFrameskip-v4', + ] + action_space_size = 18 # Standard action space size for Atari games + + # --- Hyperparameter Configuration --- + # Grouping hyperparameters for better readability and management. + main_hyperparams = { + 'seed': 0, + 'collector_env_num': 2, + 'evaluator_env_num': 2, + 'n_episode': 2, + 'num_simulations': 50, + 'max_env_step': int(4e5), + 'reanalyze_ratio': 0.0, + 'num_segments': 2, + 'num_unroll_steps': 10, + 'infer_context_length': 4, + 'norm_type': 'LN', + 'buffer_reanalyze_freq': 1/50, + 'reanalyze_batch_size': 160, + 'reanalyze_partition': 0.75, + 'total_batch_size': int(4 * len(env_id_list)), + 'batch_size_per_task': 4, + # --- Path for experiment logs and pretrained model --- + # NOTE: Please update these paths to your local directory structure. + 'exp_base_path': 'data/unizero_mt_ddp-8gpu_eval-latent_state_tsne', + # Example for an 8-game pretrained model + 'pretrained_model_path': '/path/to/your/pretrained_model.pth.tar', + # Example for a 26-game pretrained model + # 'pretrained_model_path': '/path/to/your/26_game_model.pth.tar', + } + + # --- Generate Configurations for each seed --- + # This loop allows running experiments with multiple seeds easily. + for seed in [main_hyperparams['seed']]: + # The batch size is a list, with one entry per task. 
+ batch_size_list = [main_hyperparams['batch_size_per_task']] * len(env_id_list) + + # Generate the list of configurations for the trainer + configs = generate_configs( + env_id_list=env_id_list, + action_space_size=action_space_size, + collector_env_num=main_hyperparams['collector_env_num'], + n_episode=main_hyperparams['n_episode'], + evaluator_env_num=main_hyperparams['evaluator_env_num'], + num_simulations=main_hyperparams['num_simulations'], + reanalyze_ratio=main_hyperparams['reanalyze_ratio'], + batch_size=batch_size_list, + num_unroll_steps=main_hyperparams['num_unroll_steps'], + infer_context_length=main_hyperparams['infer_context_length'], + norm_type=main_hyperparams['norm_type'], + seed=seed, + buffer_reanalyze_freq=main_hyperparams['buffer_reanalyze_freq'], + reanalyze_batch_size=main_hyperparams['reanalyze_batch_size'], + reanalyze_partition=main_hyperparams['reanalyze_partition'], + num_segments=main_hyperparams['num_segments'], + total_batch_size=main_hyperparams['total_batch_size'], + exp_base_path=main_hyperparams['exp_base_path'], + ) + + # --- Launch Training --- + # Use DDPContext to manage the distributed training environment. + with DDPContext(): + train_unizero_multitask_segment_eval( + configs, + seed=seed, + model_path=main_hyperparams['pretrained_model_path'], + max_env_step=main_hyperparams['max_env_step'] + ) \ No newline at end of file diff --git a/zoo/atari/config/atari_unizero_multitask_segment_finetune_config.py b/zoo/atari/config/atari_unizero_multitask_segment_finetune_config.py new file mode 100644 index 000000000..7f3a01636 --- /dev/null +++ b/zoo/atari/config/atari_unizero_multitask_segment_finetune_config.py @@ -0,0 +1,409 @@ +from easydict import EasyDict +from typing import List, Tuple, Union, Any, Dict + +class UniZeroAtariConfig: + """ + Overview: + Default configuration class for UniZero Atari experiments. + This class centralizes all default parameters, making it easier to manage and extend. + """ + def __init__(self) -> None: + self.exp_name: str = '' + self.env: EasyDict = self._get_default_env_config() + self.policy: EasyDict = self._get_default_policy_config() + + @staticmethod + def _get_default_env_config() -> EasyDict: + """ + Overview: + Returns the default environment configuration. + """ + return EasyDict(dict( + stop_value=int(1e6), + env_id='PongNoFrameskip-v4', + observation_shape=(3, 64, 64), + gray_scale=False, + collector_env_num=8, + evaluator_env_num=3, + n_evaluator_episode=3, + manager=dict(shared_memory=False), + full_action_space=True, + collect_max_episode_steps=int(5e3), + eval_max_episode_steps=int(5e3), + )) + + @staticmethod + def _get_default_policy_config() -> EasyDict: + """ + Overview: + Returns the default policy configuration. 
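        Note that these are baseline defaults: create_config() below overwrites the per-task
        fields (e.g. batch_size, num_unroll_steps, the world-model max_blocks / max_tokens /
        context_length, and the action space sizes) on top of this dictionary.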
+ """ + return EasyDict(dict( + multi_gpu=True, + # ==============TODO============== + use_moco=False, + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=50000))), + grad_correct_params=dict( + MoCo_beta=0.5, + MoCo_beta_sigma=0.5, + MoCo_gamma=0.1, + MoCo_gamma_sigma=0.5, + MoCo_rho=0, + calpha=0.5, + rescale=1, + ), + task_num=1, + task_id=0, + model=dict( + observation_shape=(3, 64, 64), + action_space_size=18, + norm_type='LN', + num_res_blocks=2, + num_channels=256, + world_model_cfg=dict( + # TODO: for latent state layer_norm + final_norm_option_in_obs_head='LayerNorm', + final_norm_option_in_encoder='LayerNorm', + predict_latent_loss_type='mse', + # TODO: only for latent state sim_norm + # final_norm_option_in_obs_head='SimNorm', + # final_norm_option_in_encoder='SimNorm', + # predict_latent_loss_type='group_kl', + share_head=False, # TODO + analysis_dormant_ratio_weight_rank=False, # TODO + dormant_threshold=0.025, + continuous_action_space=False, + # ==============TODO: none ============== + task_embed_option=None, + use_task_embed=False, + # ==============TODO============== + # task_embed_option='concat_task_embed', + # use_task_embed=True, + # task_embed_dim=96, + # task_embed_dim=128, + use_shared_projection=False, + max_blocks=10, # num_unroll_steps + max_tokens=20, # 2 * num_unroll_steps + context_length=8, # 2 * infer_context_length + device='cuda', + action_space_size=18, + num_layers=8, + num_heads=24, + embed_dim=768, + obs_type='image', + env_num=8, + task_num=1, + use_normal_head=True, + use_softmoe_head=False, + use_moe_head=False, + num_experts_in_moe_head=4, + moe_in_transformer=False, + multiplication_moe_in_transformer=False, + num_experts_of_moe_in_transformer=4, + # LoRA parameters (enable LoRA by setting lora_r > 0) + lora_r=0, + # lora_r=8, + lora_alpha=32, + lora_dropout=0.1, + # Default target modules: attn and feed_forward + lora_target_modules=["attn", "feed_forward"], + ), + ), + # TODO + use_task_exploitation_weight=False, + task_complexity_weight=False, + total_batch_size=512, + allocated_batch_sizes=False, + train_start_after_envsteps=int(0), + use_priority=False, + print_task_priority_logs=False, + cuda=True, + model_path=None, + num_unroll_steps=10, + game_segment_length=20, + update_per_collect=80, + replay_ratio=0.25, + batch_size=64, + optim_type='AdamW', + cos_lr_scheduler=True, + num_segments=8, + num_simulations=50, + reanalyze_ratio=0.0, + n_episode=8, + replay_buffer_size=int(5e5), + eval_freq=int(2e4), + collector_env_num=8, + evaluator_env_num=3, + buffer_reanalyze_freq=1 / 10000000, + reanalyze_batch_size=160, + reanalyze_partition=0.75, + )) + +def create_config( + env_id: str, + action_space_size: int, + collector_env_num: int, + evaluator_env_num: int, + n_episode: int, + num_simulations: int, + reanalyze_ratio: float, + batch_size: Union[int, List[int]], + num_unroll_steps: int, + infer_context_length: int, + norm_type: str, + buffer_reanalyze_freq: float, + reanalyze_batch_size: int, + reanalyze_partition: float, + num_segments: int, + total_batch_size: int, + task_num: int +) -> EasyDict: + """ + Overview: + Creates and customizes a configuration for a specific Atari environment task. + + Arguments: + - env_id (:obj:`str`): The ID of the Atari environment. + - action_space_size (:obj:`int`): The size of the action space. + - collector_env_num (:obj:`int`): Number of environments for collecting data. + - evaluator_env_num (:obj:`int`): Number of environments for evaluation. 
+ - n_episode (:obj:`int`): Number of episodes to run for each collection. + - num_simulations (:obj:`int`): Number of simulations in the MCTS. + - reanalyze_ratio (:obj:`float`): The ratio of reanalyzed samples in the replay buffer. + - batch_size (:obj:`Union[int, List[int]]`): The batch size for training. + - num_unroll_steps (:obj:`int`): The number of steps to unroll the model. + - infer_context_length (:obj:`int`): The context length for inference. + - norm_type (:obj:`str`): The type of normalization to use. + - buffer_reanalyze_freq (:obj:`float`): Frequency of reanalyzing the buffer. + - reanalyze_batch_size (:obj:`int`): Batch size for reanalyzing. + - reanalyze_partition (:obj:`float`): Partition ratio for reanalyzing. + - num_segments (:obj:`int`): Number of segments for each game. + - total_batch_size (:obj:`int`): The total batch size across all tasks. + - task_num (:obj:`int`): The total number of tasks. + + Returns: + - (:obj:`EasyDict`): A fully configured EasyDict object for the experiment. + """ + cfg = UniZeroAtariConfig() + + # == Update Environment Config == + cfg.env.env_id = env_id + cfg.env.collector_env_num = collector_env_num + cfg.env.evaluator_env_num = evaluator_env_num + cfg.env.n_evaluator_episode = evaluator_env_num + + # == Update Policy Config == + policy = cfg.policy + policy.task_num = task_num + policy.action_space_size = action_space_size + policy.n_episode = n_episode + policy.num_simulations = num_simulations + policy.reanalyze_ratio = reanalyze_ratio + policy.batch_size = batch_size + policy.total_batch_size = total_batch_size + policy.num_unroll_steps = num_unroll_steps + policy.collector_env_num = collector_env_num + policy.evaluator_env_num = evaluator_env_num + policy.buffer_reanalyze_freq = buffer_reanalyze_freq + policy.reanalyze_batch_size = reanalyze_batch_size + policy.reanalyze_partition = reanalyze_partition + policy.num_segments = num_segments + + # == Update Model Config == + model = policy.model + model.action_space_size = action_space_size + model.norm_type = norm_type + + # == Update World Model Config == + world_model = model.world_model_cfg + world_model.max_blocks = num_unroll_steps + world_model.max_tokens = 2 * num_unroll_steps + world_model.context_length = 2 * infer_context_length + world_model.action_space_size = action_space_size + world_model.task_num = task_num + + return EasyDict(cfg) + + +def generate_experiment_configs( + env_id_list: List[str], + action_space_size: int, + collector_env_num: int, + n_episode: int, + evaluator_env_num: int, + num_simulations: int, + reanalyze_ratio: float, + batch_size: Union[int, List[int]], + num_unroll_steps: int, + infer_context_length: int, + norm_type: str, + seed: int, + buffer_reanalyze_freq: float, + reanalyze_batch_size: int, + reanalyze_partition: float, + num_segments: int, + total_batch_size: int +) -> List[Tuple[int, List[Union[EasyDict, Any]]]]: + """ + Overview: + Generates a list of configurations for multi-task experiments. + + Arguments: + - env_id_list (:obj:`List[str]`): List of environment IDs for the tasks. + - ... (same as create_config): Other experiment parameters. + - seed (:obj:`int`): The random seed for the experiment. + + Returns: + - (:obj:`List[Tuple[int, List[Union[EasyDict, Any]]]]`): A list where each element contains a task_id and its + corresponding configuration and environment manager setup. + """ + configs = [] + task_num = len(env_id_list) + + # --- Experiment Name Prefix --- + # This prefix defines the storage path for experiment data and logs. 
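    # For illustration only (values taken from the __main__ defaults below, with the template
    # rendered exactly as written here): task_num=1, brf=1e-07, norm='LN', seed=0 gives
    # '/data_unizero_atari_mt_finetune_20250308/experiment_name/1games_brf1e-07_1-encoder-LN-res2-channel256_gsl20_lsd768-nlayer8-nh8_upc80_seed0/'.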
+ # Please replace `` with your actual data storage path. + exp_name_prefix_template = ( + "/data_unizero_atari_mt_finetune_{timestamp}/" + "experiment_name/{task_num}games_brf{brf}_1-encoder-{norm}-res2-channel256_" + "gsl20_lsd768-nlayer8-nh8_upc80_seed{seed}/" + ) + exp_name_prefix = exp_name_prefix_template.format( + timestamp="20250308", + task_num=task_num, + brf=buffer_reanalyze_freq, + norm=norm_type, + seed=seed + ) + + for task_id, env_id in enumerate(env_id_list): + config = create_config( + env_id, action_space_size, collector_env_num, evaluator_env_num, + n_episode, num_simulations, reanalyze_ratio, batch_size, + num_unroll_steps, infer_context_length, norm_type, + buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition, + num_segments, total_batch_size, task_num + ) + config.policy.task_id = task_id + config.exp_name = exp_name_prefix + f"{env_id.split('NoFrameskip')[0]}_unizero-mt_seed{seed}" + configs.append([task_id, [config, create_env_manager()]]) + return configs + + +def create_env_manager() -> EasyDict: + """ + Overview: + Creates the environment and policy manager configuration. + This specifies the types and import paths for the environment and policy used in the experiment. + + Returns: + - (:obj:`EasyDict`): An EasyDict object containing manager configurations. + """ + return EasyDict(dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='unizero_multitask', + import_names=['lzero.policy.unizero_multitask'], + ), + )) + + +if __name__ == "__main__": + """ + Overview: + This script should be executed with GPUs. + Run one of the following commands to launch the script: + - Using torch.distributed.launch: + python -m torch.distributed.launch --nproc_per_node=8 --master_port=29507 ./path/to/this/script.py + - Using torchrun: + torchrun --nproc_per_node=8 ./path/to/this/script.py + """ + from lzero.entry import train_unizero_multitask_segment_ddp + from ding.utils import DDPContext + import os + + # --- Main Experiment Settings --- + # Use DEBUG mode for fast iteration and debugging. 
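    # Illustrative summary of the two branches below: DEBUG=True shrinks the run to 2 collector
    # envs, 2 MCTS simulations, a total batch size of 32 and ~1e3 env steps for quick smoke
    # tests, while DEBUG=False uses the standard 8-env / 50-simulation setup with
    # total_batch_size=512 and 4e5 env steps.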
+ DEBUG = False + + # --- Environment and Task Settings --- + env_id_list = ['AmidarNoFrameskip-v4'] + action_space_size = 18 + + # --- Distributed Training Settings --- + os.environ["NCCL_TIMEOUT"] = "3600000000" + + # --- Loop over seeds for multiple runs --- + for seed in [0]: + # --- Core Algorithm Parameters --- + if DEBUG: + # Settings for quick debugging + collector_env_num = 2 + num_segments = 2 + n_episode = 2 + evaluator_env_num = 2 + num_simulations = 2 + total_batch_size = 32 + batch_size = [int(total_batch_size / len(env_id_list))] * len(env_id_list) + reanalyze_batch_size = 4 + max_env_step = int(1e3) + else: + # Standard experiment settings + collector_env_num = 8 + num_segments = 8 + n_episode = 8 + evaluator_env_num = 3 + num_simulations = 50 + total_batch_size = 512 + batch_size = [int(min(64, total_batch_size / len(env_id_list)))] * len(env_id_list) + reanalyze_batch_size = 160 + max_env_step = int(4e5) + + # --- Shared Parameters --- + reanalyze_ratio = 0.0 + num_unroll_steps = 10 + infer_context_length = 4 + norm_type = 'LN' + buffer_reanalyze_freq = 1 / 10000000 # Effectively disabled + reanalyze_partition = 0.75 + + # --- Generate Configurations --- + configs = generate_experiment_configs( + env_id_list=env_id_list, + action_space_size=action_space_size, + collector_env_num=collector_env_num, + n_episode=n_episode, + evaluator_env_num=evaluator_env_num, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + batch_size=batch_size, + num_unroll_steps=num_unroll_steps, + infer_context_length=infer_context_length, + norm_type=norm_type, + seed=seed, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + num_segments=num_segments, + total_batch_size=total_batch_size + ) + + # --- Pretrained Model Path --- + # Please replace `` with the actual path to your model. + pretrained_model_path = ( + "/data_unizero_mt_atari/" + "atari_8games_brf0.02_not-share-head_final-ln_seed0/Pong_seed0/ckpt/ckpt_best.pth.tar" + ) + + # --- Start Training --- + with DDPContext(): + train_unizero_multitask_segment_ddp( + configs, + seed=seed, + model_path=pretrained_model_path, + max_env_step=max_env_step + ) \ No newline at end of file diff --git a/zoo/atari/config/atari_unizero_segment_config.py b/zoo/atari/config/atari_unizero_segment_config.py old mode 100644 new mode 100755 index a9a713656..261fb6a32 --- a/zoo/atari/config/atari_unizero_segment_config.py +++ b/zoo/atari/config/atari_unizero_segment_config.py @@ -1,7 +1,6 @@ from easydict import EasyDict from zoo.atari.config.atari_env_action_space_map import atari_env_action_space_map - def main(env_id, seed): action_space_size = atari_env_action_space_map[env_id] @@ -10,29 +9,28 @@ def main(env_id, seed): # ============================================================== collector_env_num = 8 num_segments = 8 + evaluator_env_num = 3 + game_segment_length = 20 - evaluator_env_num = 10 - num_simulations = 50 - max_env_step = int(5e5) - batch_size = 64 - num_layers = 2 - replay_ratio = 0.25 num_unroll_steps = 10 infer_context_length = 4 - # Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, 2 means reanalyze once every two epochs. 
- buffer_reanalyze_freq = 1/50 - # Each reanalyze process will reanalyze sequences ( transitions per sequence) + num_simulations = 50 + batch_size = 128 + replay_ratio = 0.25 + + num_layers = 2 + norm_type = "LN" + + if env_id == 'ALE/Pong-v5': + max_env_step = int(5e5) + else: + max_env_step = int(10e6) + + # Reanalyze settings + buffer_reanalyze_freq = 1/5000000000 reanalyze_batch_size = 160 - # The partition of reanalyze. E.g., 1 means reanalyze_batch samples from the whole buffer, 0.5 means samples from the first half of the buffer. reanalyze_partition = 0.75 - - # ====== only for debug ===== - # collector_env_num = 2 - # num_segments = 2 - # evaluator_env_num = 2 - # num_simulations = 10 - # batch_size = 5 # ============================================================== # end of the most frequently changed config specified by the user # ============================================================== @@ -47,64 +45,61 @@ def main(env_id, seed): evaluator_env_num=evaluator_env_num, n_evaluator_episode=evaluator_env_num, manager=dict(shared_memory=False, ), - # TODO: only for debug - # collect_max_episode_steps=int(50), - # eval_max_episode_steps=int(50), ), policy=dict( - learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=10000, ), ), ), # default is 10000 model=dict( observation_shape=(3, 64, 64), action_space_size=action_space_size, reward_support_range=(-300., 301., 1.), value_support_range=(-300., 301., 1.), + norm_type=norm_type, + num_res_blocks=2, + num_channels=128, world_model_cfg=dict( + latent_recon_loss_weight=0.1, + perceptual_loss_weight=0.1, + norm_type=norm_type, support_size=601, policy_entropy_weight=5e-3, - continuous_action_space=False, max_blocks=num_unroll_steps, - max_tokens=2 * num_unroll_steps, # NOTE: each timestep has 2 tokens: obs and action + max_tokens=2 * num_unroll_steps, context_length=2 * infer_context_length, - device='cuda', action_space_size=action_space_size, num_layers=num_layers, num_heads=8, embed_dim=768, - obs_type='image', env_num=max(collector_env_num, evaluator_env_num), num_simulations=num_simulations, - rotary_emb=False, + game_segment_length=game_segment_length, + device='cuda', + use_priority=True, ), ), - # (str) The path of the pretrained model. If None, the model will be initialized by the default model. - model_path=None, - use_augmentation=False, - manual_temperature_decay=False, - threshold_training_steps_for_final_temperature=int(2.5e4), - use_priority=False, - num_unroll_steps=num_unroll_steps, - update_per_collect=None, - replay_ratio=replay_ratio, - batch_size=batch_size, - optim_type='AdamW', + # Learning settings learning_rate=0.0001, - num_simulations=num_simulations, + weight_decay=1e-2, + batch_size=batch_size, + replay_ratio=replay_ratio, + num_unroll_steps=num_unroll_steps, num_segments=num_segments, - td_steps=5, - train_start_after_envsteps=0, game_segment_length=game_segment_length, - grad_clip_value=5, - replay_buffer_size=int(1e6), - eval_freq=int(5e3), - collector_env_num=collector_env_num, - evaluator_env_num=evaluator_env_num, - # ============= The key different params for reanalyze ============= - # Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, 2 means reanalyze once every two epochs. 
+ num_simulations=num_simulations, + + # Priority settings + use_priority=True, + priority_prob_alpha=1, + priority_prob_beta=1, + + # Reanalyze settings buffer_reanalyze_freq=buffer_reanalyze_freq, - # Each reanalyze process will reanalyze sequences ( transitions per sequence) reanalyze_batch_size=reanalyze_batch_size, - # The partition of reanalyze. E.g., 1 means reanalyze_batch samples from the whole buffer, 0.5 means samples from the first half of the buffer. reanalyze_partition=reanalyze_partition, + + # Environment settings + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + eval_freq=int(1e4), + replay_buffer_size=int(5e5), ), ) atari_unizero_config = EasyDict(atari_unizero_config) @@ -126,8 +121,9 @@ def main(env_id, seed): # ============ use muzero_segment_collector instead of muzero_collector ============= from lzero.entry import train_unizero_segment - main_config.exp_name = f'data_lz/data_unizero/{env_id[:-14]}/{env_id[:-14]}_uz_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}' - train_unizero_segment([main_config, create_config], seed=seed, model_path=main_config.policy.model_path, max_env_step=max_env_step) + main_config.exp_name = f'data_unizero/{env_id[3:-3]}/{env_id[3:-3]}_uz_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}' + + train_unizero_segment([main_config, create_config], seed=seed, model_path=None, max_env_step=max_env_step) if __name__ == "__main__": @@ -137,4 +133,7 @@ def main(env_id, seed): parser.add_argument('--seed', type=int, help='The seed to use', default=0) args = parser.parse_args() + # Test environments from atari8 base set + args.env = 'ALE/Pong-v5' + main(args.env, args.seed) diff --git a/zoo/atari/config/atari_unizero_segment_ddp_config.py b/zoo/atari/config/atari_unizero_segment_ddp_config.py index 2031f6ddf..a2c78aafb 100644 --- a/zoo/atari/config/atari_unizero_segment_ddp_config.py +++ b/zoo/atari/config/atari_unizero_segment_ddp_config.py @@ -43,7 +43,7 @@ def main(env_id, seed): env=dict( stop_value=int(1e6), env_id=env_id, - observation_shape=(3, 96, 96), + observation_shape=(3, 64, 64), gray_scale=False, collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, @@ -56,7 +56,7 @@ def main(env_id, seed): policy=dict( learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=1000000, ), ), ), # default is 10000 model=dict( - observation_shape=(3, 96, 96), + observation_shape=(3, 64, 64), action_space_size=action_space_size, reward_support_range=(-300., 301., 1.), value_support_range=(-300., 301., 1.), diff --git a/zoo/atari/envs/atari_lightzero_env.py b/zoo/atari/envs/atari_lightzero_env.py index 8bc491674..4648ce32f 100644 --- a/zoo/atari/envs/atari_lightzero_env.py +++ b/zoo/atari/envs/atari_lightzero_env.py @@ -2,7 +2,8 @@ from ditk import logging from typing import List -import gym +import gym +import ale_py import numpy as np from ding.envs import BaseEnv, BaseEnvTimestep from ding.torch_utils import to_ndarray @@ -24,14 +25,14 @@ class AtariEnvLightZero(BaseEnv): _reward_space, obs, _eval_episode_return, has_reset, _seed, _dynamic_seed """ config = dict( + # (bool) Whether to use the full action space of the environment. Default is False. 
If set to True, the action space size is 18 for Atari. + full_action_space=False, # (int) The number of environment instances used for data collection. collector_env_num=8, # (int) The number of environment instances used for evaluator. evaluator_env_num=3, # (int) The number of episodes to evaluate during each evaluation period. n_evaluator_episode=3, - # (str) The name of the Atari game environment. - # env_id='PongNoFrameskip-v4', # (str) The type of the environment, here it's Atari. env_type='Atari', # (tuple) The shape of the observation space, which is a stacked frame of 4 images each of 96x96 pixels. @@ -49,7 +50,7 @@ class AtariEnvLightZero(BaseEnv): replay_path=None, # (bool) If set to True, the game screen is converted to grayscale, reducing the complexity of the observation space. gray_scale=True, - # (int) Specifies the number of consecutive frames to stack after collecting environment data. + # (int) Specifies the number of consecutive frames to stack after collecting environment data. # The stacking process is applied within the collector and evaluator modules. frame_stack_num=1, # (int) The number of frames to skip between each action. Higher values result in faster simulation. @@ -121,12 +122,14 @@ def reset(self) -> dict: self.cfg.observation_shape[2] ) + self._action_space = self._env.action_space + self._observation_space = gym.spaces.Dict({ 'observation': gym.spaces.Box( low=0, high=1, shape=observation_space_before_stack, dtype=np.float32 ), 'action_mask': gym.spaces.Box( - low=0, high=1, shape=(self._env.env.action_space.n,), dtype=np.int8 + low=0, high=1, shape=(self._action_space.n,), dtype=np.int8 ), 'to_play': gym.spaces.Box( low=-1, high=2, shape=(), dtype=np.int8 @@ -136,18 +139,20 @@ def reset(self) -> dict: ), }) - self._action_space = self._env.env.action_space + # self._reward_space = gym.spaces.Box( + # low=self._env.env.reward_range[0], high=self._env.env.reward_range[1], shape=(1,), dtype=np.float32 + # ) self._reward_space = gym.spaces.Box( - low=self._env.env.reward_range[0], high=self._env.env.reward_range[1], shape=(1,), dtype=np.float32 + low=-9999, high=9999, shape=(1,), dtype=np.float32 ) self._init_flag = True if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed: np_seed = 100 * np.random.randint(1, 1000) - self._env.env.seed(self._seed + np_seed) + self._env.seed(self._seed + np_seed) elif hasattr(self, '_seed'): - self._env.env.seed(self._seed) + self._env.seed(self._seed) result = self._env.reset() if isinstance(result, tuple): @@ -175,11 +180,14 @@ def step(self, action: int) -> BaseEnvTimestep: self.reward = np.array(reward).astype(np.float32) self._eval_episode_return += self.reward self._timestep += 1 - # logging.info(f'self._timestep: {self._timestep}') + # if self._timestep % 200 == 0: + # logging.info(f'self._timestep: {self._timestep}') observation = self.observe() if done: logging.info(f'one episode done! 
total episode length is: {self._timestep}') info['eval_episode_return'] = self._eval_episode_return + logging.debug(f'one episode of {self.cfg.env_id} done') + return BaseEnvTimestep(observation, self.reward, done, info) def observe(self) -> dict: @@ -200,7 +208,6 @@ def observe(self) -> dict: return {'observation': observation, 'action_mask': action_mask, 'to_play': np.array(-1), 'timestep': np.array(self._timestep)} - @property def legal_actions(self): return np.arange(self._action_space.n) @@ -265,4 +272,4 @@ def create_evaluator_env_cfg(cfg: dict) -> List[dict]: cfg.max_episode_steps = cfg.eval_max_episode_steps cfg.episode_life = False cfg.clip_rewards = False - return [cfg for _ in range(evaluator_env_num)] \ No newline at end of file + return [cfg for _ in range(evaluator_env_num)] diff --git a/zoo/atari/envs/atari_wrappers.py b/zoo/atari/envs/atari_wrappers.py index f38aa24d6..e009c4531 100644 --- a/zoo/atari/envs/atari_wrappers.py +++ b/zoo/atari/envs/atari_wrappers.py @@ -3,20 +3,20 @@ from typing import Optional import cv2 -# import gymnasium as gym -import gymnasium -import gym +import gym # For legacy API wrapper base class +import gymnasium # For creating environments +import ale_py import numpy as np from ding.envs import NoopResetWrapper, MaxAndSkipWrapper, EpisodicLifeWrapper, FireResetWrapper, WarpFrameWrapper, \ ScaledFloatFrameWrapper, \ - ClipRewardWrapper, FrameStackWrapper + ClipRewardWrapper, FrameStackWrapper, TimeLimitWrapper from ding.utils.compression_helper import jpeg_data_compressor from easydict import EasyDict -# from gymnasium.wrappers import RecordVideo -from gym.wrappers import RecordVideo +from gymnasium.wrappers import RecordVideo # only for reference now +# Note: If these functions are to be used with new environments, they also need similar gym/gymnasium compatibility modifications. def wrap_deepmind(env_id, episode_life=True, clip_rewards=True, frame_stack=4, scale=True, warp_frame=True): """Configure environment for DeepMind-style Atari. The observation is channel-first: (c, h, w) instead of (h, w, c). @@ -29,10 +29,11 @@ def wrap_deepmind(env_id, episode_life=True, clip_rewards=True, frame_stack=4, s :param bool warp_frame: wrap the grayscale + resize observation wrapper. :return: the wrapped atari environment. """ - assert 'NoFrameskip' in env_id - env = gym.make(env_id) + # assert 'NoFrameskip' in env_id + env = gymnasium.make(env_id) + env = GymnasiumToGymWrapper(env) # Add compatibility layer env = NoopResetWrapper(env, noop_max=30) - env = MaxAndSkipWrapper(env, skip=4) + env = MaxAndSkipWrapper(env, skip=1) if episode_life: env = EpisodicLifeWrapper(env) if 'FIRE' in env.unwrapped.get_action_meanings(): @@ -61,10 +62,11 @@ def wrap_deepmind_mr(env_id, episode_life=True, clip_rewards=True, frame_stack=4 :param bool warp_frame: wrap the grayscale + resize observation wrapper. :return: the wrapped atari environment. """ - assert 'MontezumaRevenge' in env_id - env = gym.make(env_id) + # assert 'MontezumaRevenge' in env_id + env = gymnasium.make(env_id) + env = GymnasiumToGymWrapper(env) # Add compatibility layer env = NoopResetWrapper(env, noop_max=30) - env = MaxAndSkipWrapper(env, skip=4) + env = MaxAndSkipWrapper(env, skip=1) if episode_life: env = EpisodicLifeWrapper(env) if 'FIRE' in env.unwrapped.get_action_meanings(): @@ -80,6 +82,30 @@ def wrap_deepmind_mr(env_id, episode_life=True, clip_rewards=True, frame_stack=4 return env +# This TimeLimit class can be replaced by ding.envs.TimeLimitWrapper for better consistency. 
+# However, if it needs to be retained, it now works correctly because it wraps the output of GymnasiumToGymWrapper. +class TimeLimit(gym.Wrapper): + """ + Overview: + A wrapper that limits the maximum number of steps in an episode. + """ + def __init__(self, env: gym.Env, max_episode_steps: Optional[int] = None): + super(TimeLimit, self).__init__(env) + self._max_episode_steps = max_episode_steps + self._elapsed_steps = 0 + + def step(self, ac): + observation, reward, done, info = self.env.step(ac) + self._elapsed_steps += 1 + # Only enforce the limit when a maximum is configured; otherwise the episode runs unbounded. + if self._max_episode_steps is not None and self._elapsed_steps >= self._max_episode_steps: + done = True + info['TimeLimit.truncated'] = True + return observation, reward, done, info + + def reset(self, **kwargs): + self._elapsed_steps = 0 + return self.env.reset(**kwargs) + def wrap_lightzero(config: EasyDict, episode_life: bool, clip_rewards: bool) -> gym.Env: """ Overview: @@ -92,11 +118,13 @@ def wrap_lightzero(config: EasyDict, episode_life: bool, clip_rewards: bool) -> Return: - env (:obj:`gym.Env`): The wrapped Atari environment with the given configurations. """ + # Step 1: Create base environment using gymnasium if config.render_mode_human: - env = gym.make(config.env_id, render_mode='human') + env = gymnasium.make(config.env_id, render_mode='human', full_action_space=config.full_action_space) else: - env = gym.make(config.env_id, render_mode='rgb_array') - assert 'NoFrameskip' in env.spec.id + env = gymnasium.make(config.env_id, render_mode='rgb_array', full_action_space=config.full_action_space) + + # (Optional) Apply gymnasium native wrappers if hasattr(config, 'save_replay') and config.save_replay \ and hasattr(config, 'replay_path') and config.replay_path is not None: timestamp = datetime.now().strftime("%Y%m%d%H%M%S") @@ -108,12 +136,17 @@ def wrap_lightzero(config: EasyDict, episode_life: bool, clip_rewards: bool) -> name_prefix=video_name ) - # env = GymnasiumToGymWrapper(env) + # Step 2: Add compatibility layer to convert gymnasium environment to gym interface + env = GymnasiumToGymWrapper(env) + + # Step 3: Now safely apply all ding and legacy gym-style wrappers env = NoopResetWrapper(env, noop_max=30) env = MaxAndSkipWrapper(env, skip=config.frame_skip) if episode_life: env = EpisodicLifeWrapper(env) + env = TimeLimit(env, max_episode_steps=config.max_episode_steps) + if config.warp_frame: # we must set WarpFrame before ScaledFloatFrameWrapper env = WarpFrame(env, width=config.observation_shape[1], height=config.observation_shape[2], grayscale=config.gray_scale) @@ -129,35 +162,6 @@ def wrap_lightzero(config: EasyDict, episode_life: bool, clip_rewards: bool) -> return env -class TimeLimit(gym.Wrapper): - """ - Overview: - A wrapper that limits the maximum number of steps in an episode. - """ - - def __init__(self, env: gym.Env, max_episode_steps: Optional[int] = None): - """ - Arguments: - - env (:obj:`gym.Env`): The environment to wrap. - - max_episode_steps (:obj:`Optional[int]`): Maximum number of steps per episode. If None, no limit is applied. 
- """ - super(TimeLimit, self).__init__(env) - self._max_episode_steps = max_episode_steps - self._elapsed_steps = 0 - - def step(self, ac): - observation, reward, done, info = self.env.step(ac) - self._elapsed_steps += 1 - if self._elapsed_steps >= self._max_episode_steps: - done = True - info['TimeLimit.truncated'] = True - return observation, reward, done, info - - def reset(self, **kwargs): - self._elapsed_steps = 0 - return self.env.reset(**kwargs) - - class WarpFrame(gym.ObservationWrapper): """ Overview: @@ -265,53 +269,41 @@ def __init__(self, env: gym.Env): def legal_actions(self): return [_ for _ in range(self.env.action_space.n)] + +# This is the key compatibility wrapper class GymnasiumToGymWrapper(gym.Wrapper): """ Overview: A wrapper class that adapts a Gymnasium environment to the Gym interface. - Interface: - ``__init__``, ``reset``, ``seed`` - Properties: - - _seed (:obj:`int` or None): The seed value for the environment. """ - def __init__(self, env): - """ - Overview: - Initializes the GymnasiumToGymWrapper. - Arguments: - - env (:obj:`gymnasium.Env`): The Gymnasium environment to be wrapped. - """ - - assert isinstance(env, gymnasium.Env), type(env) + # Ensure the input is a gymnasium environment + assert isinstance(env, gymnasium.Env), f"Expected env to be a `gymnasium.Env` but got {type(env)}" super().__init__(env) self._seed = None def seed(self, seed): - """ - Overview: - Sets the seed value for the environment. - Arguments: - - seed (:obj:`int`): The seed value to use for random number generation. - """ self._seed = seed + # Call gymnasium's new seeder + self.env.reset(seed=seed) def reset(self, **kwargs): - """ - Overview: - Resets the environment and returns the initial observation. - Returns: - - observation (:obj:`Any`): The initial observation of the environment. 
- """ + # If seed is in kwargs, use it with priority + if self._seed is not None: + kwargs['seed'] = self._seed + self._seed = None # Seed only takes effect on first reset + + # Call gymnasium's reset, which returns (obs, info) result = self.env.reset(**kwargs) - if isinstance(result, tuple): - obs, info = result - else: - obs = result + obs, info = result + # Only return obs to match legacy gym API return obs def step(self, action): + # Call gymnasium's step, which returns (obs, rew, terminated, truncated, info) obs, rew, terminated, truncated, info = self.env.step(action) + # Merge terminated and truncated into done done = terminated or truncated - return obs, rew, done, info \ No newline at end of file + # Return 4 values to match legacy gym API + return obs, rew, done, info diff --git a/zoo/dmc2gym/config/dmc2gym_pixels_suz_config.py b/zoo/dmc2gym/config/dmc2gym_pixels_suz_config.py new file mode 100644 index 000000000..855c2bea1 --- /dev/null +++ b/zoo/dmc2gym/config/dmc2gym_pixels_suz_config.py @@ -0,0 +1,130 @@ +from easydict import EasyDict +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== + +from zoo.dmc2gym.config.dmc_state_env_space_map import dmc_state_env_action_space_map, dmc_state_env_obs_space_map + +env_id = 'cartpole-swingup' # You can specify any DMC task here +action_space_size = dmc_state_env_action_space_map[env_id] +obs_space_size = dmc_state_env_obs_space_map[env_id] +print(f'env_id: {env_id}, action_space_size: {action_space_size}, obs_space_size: {obs_space_size}') + +domain_name = env_id.split('-')[0] +task_name = env_id.split('-')[1] + +continuous_action_space = True +K = 20 # num_of_sampled_actions +collector_env_num = 8 +n_episode = 8 +evaluator_env_num = 3 +num_simulations = 50 +update_per_collect = None +replay_ratio = 0.25 +max_env_step = int(1e6) +reanalyze_ratio = 0 +batch_size = 64 +num_unroll_steps = 10 +infer_context_length = 4 +norm_type = 'LN' +seed = 0 + +# for debug +# collector_env_num = 2 +# n_episode = 2 +# evaluator_env_num = 1 +# num_simulations = 2 +# batch_size = 2 +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +dmc2gym_pixels_cont_sampled_unizero_config = dict( + exp_name=f'data_sampled_unizero/dmc2gym_{env_id}_image_cont_sampled_unizero_ns{num_simulations}_upc{update_per_collect}-rr{replay_ratio}_rer{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_{norm_type}_seed{seed}', + env=dict( + env_id='dmc2gym-v0', + continuous=True, + domain_name=domain_name, + task_name=task_name, + from_pixels=True, # pixel/image obs + frame_skip=2, + warp_frame=True, + scale=True, + frame_stack_num=1, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + model=dict( + observation_shape=(3, 84, 84), + action_space_size=action_space_size, + continuous_action_space=continuous_action_space, + num_of_sampled_actions=K, + world_model_cfg=dict( + obs_type='image', + num_unroll_steps=num_unroll_steps, + policy_entropy_loss_weight=5e-3, + continuous_action_space=continuous_action_space, + num_of_sampled_actions=K, + sigma_type='conditioned', + fixed_sigma_value=0.3, + bound_type=None, + model_type='conv', + max_blocks=num_unroll_steps, + 
max_tokens=2 * num_unroll_steps, # NOTE: each timestep has 2 tokens: obs and action + context_length=2 * infer_context_length, + device='cuda', + action_space_size=action_space_size, + num_layers=2, + num_heads=8, + embed_dim=768, + env_num=max(collector_env_num, evaluator_env_num), + ), + ), + # (str) The path of the pretrained model. If None, the model will be initialized by the default model. + model_path=None, + num_unroll_steps=num_unroll_steps, + cuda=True, + use_augmentation=False, + env_type='not_board_games', + game_segment_length=100, + replay_ratio=replay_ratio, + batch_size=batch_size, + optim_type='AdamW', + lr_piecewise_constant_decay=False, + learning_rate=0.0001, + target_update_freq=100, + grad_clip_value=5, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + eval_freq=int(2e3), + replay_buffer_size=int(1e6), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), +) + +dmc2gym_pixels_cont_sampled_unizero_config = EasyDict(dmc2gym_pixels_cont_sampled_unizero_config) +main_config = dmc2gym_pixels_cont_sampled_unizero_config + +dmc2gym_pixels_cont_sampled_unizero_create_config = dict( + env=dict( + type='dmc2gym_lightzero', + import_names=['zoo.dmc2gym.envs.dmc2gym_lightzero_env'], + ), + env_manager=dict(type='base'), + policy=dict( + type='sampled_unizero', + import_names=['lzero.policy.sampled_unizero'], + ), +) +dmc2gym_pixels_cont_sampled_unizero_create_config = EasyDict(dmc2gym_pixels_cont_sampled_unizero_create_config) +create_config = dmc2gym_pixels_cont_sampled_unizero_create_config + +if __name__ == "__main__": + from lzero.entry import train_unizero + + train_unizero([main_config, create_config], seed=seed, max_env_step=max_env_step) diff --git a/zoo/dmc2gym/config/dmc2gym_state_suz_segment_config.py b/zoo/dmc2gym/config/dmc2gym_state_suz_config.py similarity index 97% rename from zoo/dmc2gym/config/dmc2gym_state_suz_segment_config.py rename to zoo/dmc2gym/config/dmc2gym_state_suz_config.py index f25eead6e..ec549b298 100644 --- a/zoo/dmc2gym/config/dmc2gym_state_suz_segment_config.py +++ b/zoo/dmc2gym/config/dmc2gym_state_suz_config.py @@ -62,6 +62,7 @@ def main(env_id, seed): evaluator_env_num=evaluator_env_num, n_evaluator_episode=evaluator_env_num, manager=dict(shared_memory=False, ), + game_segment_length=game_segment_length, # TODO: only for debug # collect_max_episode_steps=int(20), # eval_max_episode_steps=int(20), @@ -74,7 +75,9 @@ def main(env_id, seed): continuous_action_space=continuous_action_space, num_of_sampled_actions=K, model_type='mlp', + num_layers=num_layers, world_model_cfg=dict( + game_segment_length=game_segment_length, policy_loss_type='kl', obs_type='vector', num_unroll_steps=num_unroll_steps, diff --git a/zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_balance_config.py b/zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_balance_config.py new file mode 100644 index 000000000..39efbee62 --- /dev/null +++ b/zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_balance_config.py @@ -0,0 +1,516 @@ +# -*- coding: utf-8 -*- +""" +Overview: + This script defines the configuration for a multi-task reinforcement learning experiment + using the UniZero model on DeepMind Control Suite (DMC) environments. + It is designed to be launched with PyTorch's Distributed Data Parallel (DDP) for multi-GPU training. 
+""" +from __future__ import annotations + +from typing import Any, Dict, List +from easydict import EasyDict +import copy + + +def get_base_config(env_id_list: list[str], collector_env_num: int, evaluator_env_num: int, + num_unroll_steps: int, infer_context_length: int, curriculum_stage_num: int) -> EasyDict: + """ + Overview: + Creates the base configuration EasyDict with default settings for the experiment. + These settings are shared across all tasks but can be overridden. + + Arguments: + - env_id_list (:obj:`list[str]`): A list of environment IDs for all tasks. + - collector_env_num (:obj:`int`): The number of environments for data collection. + - evaluator_env_num (:obj:`int`): The number of environments for evaluation. + - num_unroll_steps (:obj:`int`): The number of game steps to unroll in the model. + - infer_context_length (:obj:`int`): The context length for inference. + - curriculum_stage_num (:obj:`int`): The number of stages in the curriculum learning. + + Returns: + - (:obj:`EasyDict`): A dictionary containing the base configuration. + """ + return EasyDict(dict( + # Environment-specific settings + env=dict( + stop_value=int(5e5), + from_pixels=False, + continuous=True, # Assuming all DMC tasks use continuous action spaces + manager=dict(shared_memory=False), + game_segment_length=100, + # TODO(user): For debugging only. Uncomment to use smaller segments and episodes. + # game_segment_length=10, + # collect_max_episode_steps=int(40), + # eval_max_episode_steps=int(40), + ), + # Policy-specific settings + policy=dict( + multi_gpu=True, + # TODO: Configure MoCo settings. + only_use_moco_stats=False, + use_moco=False, + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=1000000))), + grad_correct_params=dict( + MoCo_beta=0.5, MoCo_beta_sigma=0.5, MoCo_gamma=0.1, MoCo_gamma_sigma=0.5, MoCo_rho=0, + calpha=0.5, rescale=1, + ), + total_task_num=len(env_id_list), + task_num=len(env_id_list), + # Model configuration + model=dict( + reward_support_range=(-50., 51., 1.), + value_support_range=(-50., 51., 1.), + continuous_action_space=True, + num_of_sampled_actions=20, + model_type='mlp', + world_model_cfg=dict( + game_segment_length=100, + + final_norm_option_in_obs_head='LayerNorm', + final_norm_option_in_encoder='LayerNorm', + predict_latent_loss_type='mse', # TODO(user): Loss type for latent state with LayerNorm. + + share_head=False, # TODO(user): Whether to share the prediction head across tasks. + use_shared_projection=False, + + # TODO(user): analysis_dormant_ratio needs to be corrected for the DMC encoder. + analysis_dormant_ratio_weight_rank=False, + analysis_dormant_ratio_interval=5000, + # analysis_dormant_ratio_interval=20, # For debugging + + # TODO(user): Configure task embedding options. + task_embed_option=None, + use_task_embed=False, + # task_embed_option='concat_task_embed', + # use_task_embed=True, + # task_embed_dim=128, + + policy_loss_type='kl', + obs_type='vector', + policy_entropy_weight=5e-2, + continuous_action_space=True, + num_of_sampled_actions=20, + sigma_type='conditioned', + fixed_sigma_value=0.5, + bound_type=None, + model_type='mlp', + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, # Each timestep has 2 tokens: obs and action + context_length=2 * infer_context_length, + device='cuda', + + # TODO(user): For debugging only. Use a smaller model. 
+ # num_layers=1, + num_layers=4, + + num_heads=24, + embed_dim=768, + env_num=max(collector_env_num, evaluator_env_num), + task_num=len(env_id_list), + + # Mixture of Experts (MoE) head configuration + use_normal_head=True, + use_softmoe_head=False, + use_moe_head=False, + num_experts_in_moe_head=4, + + # MoE in Transformer configuration + moe_in_transformer=False, + multiplication_moe_in_transformer=True, + n_shared_experts=1, + num_experts_per_tok=1, + num_experts_of_moe_in_transformer=8, + + # LoRA (Low-Rank Adaptation) parameters + # TODO(user): Enable or disable LoRA for MoE layers. + moe_use_lora=True, + lora_target_modules=["attn", "feed_forward"], + lora_r=64, + lora_alpha=1, + lora_dropout=0.0, + lora_scale_init=1, + + # Curriculum learning stage iteration counts + curriculum_stage_num=curriculum_stage_num, + min_stage0_iters=10000, # Corresponds to 400k envsteps, 40k iters + max_stage_iters=5000, + + # TODO(user): For debugging only. Use very short stage iterations. + # min_stage0_iters=2, + # max_stage_iters=5, + ), + ), + # TODO(user): Enable or disable task exploitation weight. + use_task_exploitation_weight=False, + balance_pipeline=True, + # TODO(user): Enable or disable task complexity weight. + task_complexity_weight=True, + allocated_batch_sizes=False, + # TODO(user): Set the number of environment steps to collect before training starts. + train_start_after_envsteps=int(0), + use_priority=False, + print_task_priority_logs=False, + cuda=True, + model_path=None, + + # TODO(user): For debugging only. Set a smaller update_per_collect. + # update_per_collect=3, + update_per_collect=200, # e.g., 8 envs * 100 steps/env * 0.25 replay_ratio = 200 + replay_buffer_size=int(1e6), + eval_freq=int(4e3), + grad_clip_value=5, + learning_rate=1e-4, + discount_factor=0.99, + td_steps=5, + piecewise_decay_lr_scheduler=False, + manual_temperature_decay=True, + threshold_training_steps_for_final_temperature=int(2.5e4), + cos_lr_scheduler=True, + ), + )) + + +def create_task_config( + base_config: EasyDict, + env_id: str, + observation_shape_list: list[int], + action_space_size_list: list[int], + target_return_dict: dict[str, int], + collector_env_num: int, + evaluator_env_num: int, + n_episode: int, + num_simulations: int, + reanalyze_ratio: float, + batch_size: int, + num_unroll_steps: int, + infer_context_length: int, + norm_type: str, + buffer_reanalyze_freq: float, + reanalyze_batch_size: int, + reanalyze_partition: float, + num_segments: int, + total_batch_size: int +) -> EasyDict: + """ + Overview: + Creates a specialized configuration for a single task by updating the base config. + + Arguments: + - base_config (:obj:`EasyDict`): The base configuration dictionary. + - env_id (:obj:`str`): The ID of the environment for this specific task. + - observation_shape_list (:obj:`list[int]`): List of observation shapes for all tasks. + - action_space_size_list (:obj:`list[int]`): List of action space sizes for all tasks. + - target_return_dict (:obj:`dict[str, int]`): A dictionary mapping env_id to its target return. + - collector_env_num (:obj:`int`): The number of collector environments. + - evaluator_env_num (:obj:`int`): The number of evaluator environments. + - n_episode (:obj:`int`): The number of episodes to run for collection. + - num_simulations (:obj:`int`): The number of simulations in MCTS. + - reanalyze_ratio (:obj:`float`): The ratio of reanalyzed data in a batch. + - batch_size (:obj:`int`): The batch size for training this task. 
+ - num_unroll_steps (:obj:`int`): The number of steps to unroll the model. + - norm_type (:obj:`str`): The type of normalization to use (e.g., 'LN'). + - buffer_reanalyze_freq (:obj:`float`): Frequency of buffer reanalysis. + - reanalyze_batch_size (:obj:`int`): Batch size for reanalysis. + - reanalyze_partition (:obj:`float`): Partition ratio for reanalysis. + - num_segments (:obj:`int`): The number of segments in the replay buffer. + - total_batch_size (:obj:`int`): The total batch size across all tasks. + + Returns: + - (:obj:`EasyDict`): The final configuration for the specified task. + """ + domain_name, task_name = env_id.split('-', 1) + frame_skip = 8 if domain_name == "pendulum" else 4 + + config = base_config + + # Update environment settings + config.env.update(dict( + env_id=env_id, + domain_name=domain_name, + task_name=task_name, + observation_shape_list=observation_shape_list, + action_space_size_list=action_space_size_list, + frame_skip=frame_skip, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + )) + + # Update model settings + config.policy.model.update(dict( + observation_shape_list=observation_shape_list, + action_space_size_list=action_space_size_list, + )) + config.policy.model.world_model_cfg.update(dict( + observation_shape_list=observation_shape_list, + action_space_size_list=action_space_size_list, + num_unroll_steps=num_unroll_steps, + norm_type=norm_type, + context_length=2 * infer_context_length, + )) + + # Update policy settings + config.policy.update(dict( + target_return=target_return_dict.get(env_id), + total_batch_size=total_batch_size, + num_unroll_steps=num_unroll_steps, + replay_ratio=reanalyze_ratio, + batch_size=batch_size, + num_segments=num_segments, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + )) + + return config + + +def create_env_manager_config() -> EasyDict: + """ + Overview: + Creates the configuration for the environment manager and policy type. + + Returns: + - (:obj:`EasyDict`): A dictionary with environment manager and policy import settings. + """ + return EasyDict(dict( + env=dict( + type='dmc2gym_lightzero', + import_names=['zoo.dmc2gym.envs.dmc2gym_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='sampled_unizero_multitask', + import_names=['lzero.policy.sampled_unizero_multitask'], + ), + )) + + +def generate_experiment_name(num_tasks: int, curriculum_stage_num: int, buffer_reanalyze_freq: float, seed: int) -> str: + """ + Overview: + Generates a descriptive name for the experiment. + + Arguments: + - num_tasks (:obj:`int`): Number of tasks in the experiment. + - curriculum_stage_num (:obj:`int`): Number of curriculum stages. + - buffer_reanalyze_freq (:obj:`float`): Frequency of buffer reanalysis. + - seed (:obj:`int`): The random seed for the experiment. + + Returns: + - (:obj:`str`): The generated experiment name prefix. 
+ """ + + return ( + f'data_suz_dmc_mt_balance/dmc_{num_tasks}tasks_frameskip4-pen-fs8_balance-stage-total-{curriculum_stage_num}' + f'_stage0-10k-5k_moe8_nlayer4' + f'_brf{buffer_reanalyze_freq}_seed{seed}/' + ) + + +def generate_all_task_configs( + env_id_list: list[str], + target_return_dict: dict[str, int], + action_space_size_list: list[int], + observation_shape_list: list[int], + curriculum_stage_num: int, + collector_env_num: int, + n_episode: int, + evaluator_env_num: int, + num_simulations: int, + reanalyze_ratio: float, + batch_size: list[int], + num_unroll_steps: int, + infer_context_length: int, + norm_type: str, + seed: int, + buffer_reanalyze_freq: float, + reanalyze_batch_size: int, + reanalyze_partition: float, + num_segments: int, + total_batch_size: int +) -> list[tuple[int, list[EasyDict, EasyDict]]]: + """ + Overview: + Generates a list of configurations, one for each task in the experiment. + + Arguments: + - env_id_list (:obj:`list[str]`): A list of all environment IDs. + - target_return_dict (:obj:`dict[str, int]`): Mapping from env_id to target return. + - action_space_size_list (:obj:`list[int]`): List of action space sizes for all tasks. + - observation_shape_list (:obj:`list[int]`): List of observation shapes for all tasks. + - curriculum_stage_num (:obj:`int`): The number of curriculum stages. + - (other args): Hyperparameters for the experiment. See `create_task_config` for details. + + Returns: + - (:obj:`list`): A list where each element is `[task_id, [task_config, env_manager_config]]`. + """ + configs = [] + exp_name_prefix = generate_experiment_name( + num_tasks=len(env_id_list), + curriculum_stage_num=curriculum_stage_num, + buffer_reanalyze_freq=buffer_reanalyze_freq, + seed=seed + ) + + base_config = get_base_config( + env_id_list=env_id_list, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + num_unroll_steps=num_unroll_steps, + infer_context_length=infer_context_length, + curriculum_stage_num=curriculum_stage_num + ) + + for task_id, env_id in enumerate(env_id_list): + task_specific_config = create_task_config( + base_config=copy.deepcopy(base_config), + env_id=env_id, + action_space_size_list=action_space_size_list, + observation_shape_list=observation_shape_list, + target_return_dict=target_return_dict, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_episode=n_episode, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + batch_size=batch_size[task_id], + num_unroll_steps=num_unroll_steps, + infer_context_length=infer_context_length, + norm_type=norm_type, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + num_segments=num_segments, + total_batch_size=total_batch_size, + ) + task_specific_config.policy.task_id = task_id + task_specific_config.exp_name = exp_name_prefix + f"{env_id}_seed{seed}" + + env_manager_cfg = create_env_manager_config() + configs.append([task_id, [task_specific_config, env_manager_cfg]]) + + return configs + + +def main(): + """ + Overview: + Main function to set up and launch the multi-task UniZero training experiment. + This script should be executed with GPUs. 
+ + Example launch commands: + + cd /LightZero/ + torchrun --nproc_per_node=4 ./zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_balance_config.py 2>&1 | tee \\ + ./logs/uz_mt_dmc18_balance_moe8_seed0.log + """ + from lzero.entry import train_unizero_multitask_balance_segment_ddp + from ding.utils import DDPContext + import torch.distributed as dist + from zoo.dmc2gym.config.dmc_state_env_space_map import dmc_state_env_action_space_map, dmc_state_env_obs_space_map + + # ============================================================== + # Experiment-level settings + # ============================================================== + # NOTE: You can switch between different sets of environments by uncommenting them. + # DMC 8-task benchmark + # env_id_list = [ + # 'acrobot-swingup', 'cartpole-balance', 'cartpole-balance_sparse', + # 'cartpole-swingup', 'cartpole-swingup_sparse', 'cheetah-run', + # "ball_in_cup-catch", "finger-spin", + # ] + # target_return_dict = { + # 'acrobot-swingup': 500, 'cartpole-balance': 950, 'cartpole-balance_sparse': 950, + # 'cartpole-swingup': 800, 'cartpole-swingup_sparse': 750, 'cheetah-run': 650, + # "ball_in_cup-catch": 950, "finger-spin": 800, + # } + + # DMC 18-task benchmark + env_id_list = [ + 'acrobot-swingup', 'cartpole-balance', 'cartpole-balance_sparse', 'cartpole-swingup', + 'cartpole-swingup_sparse', 'cheetah-run', "ball_in_cup-catch", "finger-spin", + "finger-turn_easy", "finger-turn_hard", 'hopper-hop', 'hopper-stand', + 'pendulum-swingup', 'reacher-easy', 'reacher-hard', 'walker-run', + 'walker-stand', 'walker-walk', + ] + target_return_dict = { + 'acrobot-swingup': 500, 'cartpole-balance': 900, 'cartpole-balance_sparse': 950, + 'cartpole-swingup': 750, 'cartpole-swingup_sparse': 750, 'cheetah-run': 550, + "ball_in_cup-catch": 950, "finger-spin": 800, "finger-turn_easy": 950, + "finger-turn_hard": 950, 'hopper-hop': 150, 'hopper-stand': 600, + 'pendulum-swingup': 800, 'reacher-easy': 900, 'reacher-hard': 900, + 'walker-run': 500, 'walker-stand': 900, 'walker-walk': 900, + } + + # ============================================================== + # Hyperparameters + # ============================================================== + # NOTE: For debugging, you can use smaller values. 
+ # collector_env_num, num_segments, n_episode = 2, 2, 2 + # evaluator_env_num, num_simulations, total_batch_size = 2, 1, 8 + # batch_size = [3] * len(env_id_list) + # max_env_step = int(1e3) + + curriculum_stage_num = 5 + collector_env_num = 8 + num_segments = 8 + n_episode = 8 + evaluator_env_num = 3 + num_simulations = 50 + max_env_step = int(4e5) + reanalyze_ratio = 0.0 + total_batch_size = 512 + batch_size = [int(min(64, total_batch_size / len(env_id_list)))] * len(env_id_list) + num_unroll_steps = 5 + infer_context_length = 2 + norm_type = 'LN' + buffer_reanalyze_freq = 1 / 100000 + reanalyze_batch_size = 160 + reanalyze_partition = 0.75 + seed = 0 # You can iterate over multiple seeds if needed + + # Fetch observation and action space info from predefined maps + action_space_size_list = [dmc_state_env_action_space_map[env_id] for env_id in env_id_list] + observation_shape_list = [dmc_state_env_obs_space_map[env_id] for env_id in env_id_list] + + # ============================================================== + # Generate configurations and start training + # ============================================================== + configs = generate_all_task_configs( + env_id_list=env_id_list, + target_return_dict=target_return_dict, + action_space_size_list=action_space_size_list, + observation_shape_list=observation_shape_list, + curriculum_stage_num=curriculum_stage_num, + collector_env_num=collector_env_num, + n_episode=n_episode, + evaluator_env_num=evaluator_env_num, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + batch_size=batch_size, + num_unroll_steps=num_unroll_steps, + infer_context_length=infer_context_length, + norm_type=norm_type, + seed=seed, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + num_segments=num_segments, + total_batch_size=total_batch_size, + ) + + with DDPContext(): + # To train only a subset of tasks for debugging, you can slice the configs list. + # e.g., train_unizero_multitask_balance_segment_ddp(configs[:1], ...) + train_unizero_multitask_balance_segment_ddp(configs, seed=seed, max_env_step=max_env_step, benchmark_name="dmc") + dist.destroy_process_group() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_config.py b/zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_config.py new file mode 100644 index 000000000..61b346d8b --- /dev/null +++ b/zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_config.py @@ -0,0 +1,459 @@ +from easydict import EasyDict +from typing import List, Any, Dict, Tuple + + +def create_config( + env_id: str, + env_id_list: List[str], + target_return_dict: Dict[str, int], + observation_shape_list: List[Tuple[int, ...]], + action_space_size_list: List[int], + collector_env_num: int, + evaluator_env_num: int, + n_episode: int, + num_simulations: int, + reanalyze_ratio: float, + batch_size: List[int], + num_unroll_steps: int, + infer_context_length: int, + norm_type: str, + buffer_reanalyze_freq: float, + reanalyze_batch_size: int, + reanalyze_partition: float, + num_segments: int, + total_batch_size: int, +) -> EasyDict: + """ + Overview: + Create a configuration EasyDict for a single reinforcement learning task. + + Arguments: + - env_id (:obj:`str`): The ID of the environment, e.g., 'cartpole-swingup'. + - env_id_list (:obj:`List[str]`): A list of all environment IDs for the multi-task setup. 
+ - target_return_dict (:obj:`Dict[str, int]`): A dictionary mapping environment IDs to their target return values. + - observation_shape_list (:obj:`List[Tuple[int, ...]]`): List of observation shapes for all tasks. + - action_space_size_list (:obj:`List[int]`): List of action space sizes for all tasks. + - collector_env_num (:obj:`int`): Number of environments for data collection. + - evaluator_env_num (:obj:`int`): Number of environments for evaluation. + - n_episode (:obj:`int`): Number of episodes to run for collection. + - num_simulations (:obj:`int`): Number of simulations in the MCTS search. + - reanalyze_ratio (:obj:`float`): The ratio of reanalyzed data in a batch. + - batch_size (:obj:`List[int]`): Batch size for training per task. + - num_unroll_steps (:obj:`int`): Number of steps to unroll the model during training. + - infer_context_length (:obj:`int`): The context length for inference. + - norm_type (:obj:`str`): The type of normalization to use (e.g., 'LN'). + - buffer_reanalyze_freq (:obj:`float`): Frequency of reanalyzing the buffer. + - reanalyze_batch_size (:obj:`int`): Batch size for reanalyzing. + - reanalyze_partition (:obj:`float`): Partition ratio for reanalyzing. + - num_segments (:obj:`int`): Number of segments for the replay buffer. + - total_batch_size (:obj:`int`): The total batch size across all tasks. + + Returns: + - (:obj:`EasyDict`): A configuration object for the specified task. + """ + domain_name, task_name = env_id.split('-') + + # Specific frame_skip settings for certain domains. + if domain_name == "pendulum": + frame_skip = 8 + else: + frame_skip = 4 + + # --- Environment Configuration --- + env_cfg = dict( + stop_value=int(5e5), + env_id=env_id, + domain_name=domain_name, + task_name=task_name, + observation_shape_list=observation_shape_list, + action_space_size_list=action_space_size_list, + from_pixels=False, + frame_skip=frame_skip, + continuous=True, # Assuming all DMC tasks use continuous action spaces + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False), + game_segment_length=100, + # TODO: Settings for debugging purposes. 
+ # game_segment_length=10, + # collect_max_episode_steps=int(40), + # eval_max_episode_steps=int(40), + ) + + # --- World Model Configuration --- + world_model_cfg = dict( + game_segment_length=100, + + # --- Normalization and Loss --- + final_norm_option_in_obs_head='LayerNorm', + final_norm_option_in_encoder='LayerNorm', + predict_latent_loss_type='mse', # TODO: for latent state layer_norm + + # --- Architecture --- + share_head=False, # TODO + use_shared_projection=False, + obs_type='vector', + model_type='mlp', + continuous_action_space=True, + num_of_sampled_actions=20, + sigma_type='conditioned', + fixed_sigma_value=0.5, + bound_type=None, + norm_type=norm_type, + device='cuda', + + # --- Transformer/MOE Settings --- + num_layers=4, # TODO: 8 for standard, 1 for debug + num_heads=24, + embed_dim=768, + moe_in_transformer=False, + multiplication_moe_in_transformer=True, + num_experts_of_moe_in_transformer=8, + n_shared_experts=1, + num_experts_per_tok=1, + use_normal_head=True, + use_softmoe_head=False, + use_moe_head=False, + num_experts_in_moe_head=4, + + # --- LoRA Parameters --- + moe_use_lora=False, # TODO + curriculum_stage_num=3, + lora_target_modules=["attn", "feed_forward"], + lora_r=0, + lora_alpha=1, + lora_dropout=0.0, + + # --- Multi-task Settings --- + task_embed_option=None, # TODO: 'concat_task_embed' or None + use_task_embed=False, # TODO + task_num=len(env_id_list), + + # --- Analysis --- + analysis_dormant_ratio_weight_rank=False, # TODO + analysis_dormant_ratio_interval=5000, + + # --- Dynamic Properties --- + observation_shape_list=observation_shape_list, + action_space_size_list=action_space_size_list, + num_unroll_steps=num_unroll_steps, + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, # Each timestep has 2 tokens: obs and action + context_length=2 * infer_context_length, + env_num=max(collector_env_num, evaluator_env_num), + + # --- Loss Weights --- + policy_loss_type='kl', + policy_entropy_weight=5e-2, + ) + + # --- Policy Configuration --- + policy_cfg = dict( + # --- Hardware & Distribution --- + multi_gpu=True, # TODO: enable multi-GPU for DDP + cuda=True, + + # --- Model --- + model=dict( + observation_shape_list=observation_shape_list, + action_space_size_list=action_space_size_list, + continuous_action_space=True, + num_of_sampled_actions=20, + model_type='mlp', + world_model_cfg=world_model_cfg, + ), + + # --- Learning --- + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=1000000))), + optim_type='AdamW', + learning_rate=1e-4, + grad_clip_value=5, + cos_lr_scheduler=True, + piecewise_decay_lr_scheduler=False, + + # --- Training Loop --- + train_start_after_envsteps=int(0), # TODO: 2e3 for standard, 0 for quick debug + update_per_collect=200, + replay_ratio=reanalyze_ratio, + + # --- Batch Sizes --- + batch_size=batch_size, + total_batch_size=total_batch_size, + allocated_batch_sizes=False, + + # --- Replay Buffer --- + replay_buffer_size=int(1e6), + num_segments=num_segments, + use_priority=False, + + # --- Reanalyze --- + reanalyze_ratio=reanalyze_ratio, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + + # --- Algorithm Hyperparameters --- + num_simulations=num_simulations, + num_unroll_steps=num_unroll_steps, + td_steps=5, + discount_factor=0.99, + manual_temperature_decay=True, + threshold_training_steps_for_final_temperature=int(2.5e4), + + # --- MoCo (Momentum Contrast) --- + use_moco=False, # TODO + only_use_moco_stats=False, + 
grad_correct_params=dict( + MoCo_beta=0.5, MoCo_beta_sigma=0.5, MoCo_gamma=0.1, MoCo_gamma_sigma=0.5, MoCo_rho=0, + calpha=0.5, rescale=1, + ), + + # --- Multi-task Specific --- + total_task_num=len(env_id_list), + task_num=len(env_id_list), + task_id=0, # To be set per task + target_return=target_return_dict.get(env_id), + use_task_exploitation_weight=False, # TODO + task_complexity_weight=True, # TODO + balance_pipeline=True, + print_task_priority_logs=False, + + # --- Environment Interaction --- + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_episode=n_episode, + eval_freq=int(4e3), + + # --- Checkpointing --- + model_path=None, + ) + + # --- Combine configurations into the final EasyDict object --- + main_config = EasyDict(dict( + env=env_cfg, + policy=policy_cfg, + )) + + return main_config + + +def generate_configs( + env_id_list: List[str], + target_return_dict: Dict[str, int], + collector_env_num: int, + n_episode: int, + evaluator_env_num: int, + num_simulations: int, + reanalyze_ratio: float, + batch_size: List[int], + num_unroll_steps: int, + infer_context_length: int, + norm_type: str, + seed: int, + buffer_reanalyze_freq: float, + reanalyze_batch_size: int, + reanalyze_partition: float, + num_segments: int, + total_batch_size: int, + dmc_state_env_action_space_map: Dict[str, int], + dmc_state_env_obs_space_map: Dict[str, Tuple[int, ...]], +) -> List[Tuple[int, List[Any]]]: + """ + Overview: + Generate a list of configurations for all specified multi-task environments. + + Arguments: + - env_id_list (:obj:`List[str]`): A list of all environment IDs for the multi-task setup. + - target_return_dict (:obj:`Dict[str, int]`): A dictionary mapping environment IDs to their target return values. + - collector_env_num (:obj:`int`): Number of environments for data collection. + - n_episode (:obj:`int`): Number of episodes to run for collection. + - evaluator_env_num (:obj:`int`): Number of environments for evaluation. + - num_simulations (:obj:`int`): Number of simulations in the MCTS search. + - reanalyze_ratio (:obj:`float`): The ratio of reanalyzed data in a batch. + - batch_size (:obj:`List[int]`): Batch size for training per task. + - num_unroll_steps (:obj:`int`): Number of steps to unroll the model during training. + - infer_context_length (:obj:`int`): The context length for inference. + - norm_type (:obj:`str`): The type of normalization to use (e.g., 'LN'). + - seed (:obj:`int`): The random seed. + - buffer_reanalyze_freq (:obj:`float`): Frequency of reanalyzing the buffer. + - reanalyze_batch_size (:obj:`int`): Batch size for reanalyzing. + - reanalyze_partition (:obj:`float`): Partition ratio for reanalyzing. + - num_segments (:obj:`int`): Number of segments for the replay buffer. + - total_batch_size (:obj:`int`): The total batch size across all tasks. + - dmc_state_env_action_space_map (:obj:`Dict[str, int]`): Map from env_id to action space size. + - dmc_state_env_obs_space_map (:obj:`Dict[str, Tuple[int, ...]]`): Map from env_id to observation shape. + + Returns: + - (:obj:`List[Tuple[int, List[Any]]]`): A list where each element contains the task ID and its corresponding + configuration objects. + """ + configs = [] + + exp_name_prefix = ( + f'data_suz_dmc_mt/dmc_{len(env_id_list)}tasks_frameskip4-pendulum-skip8_ln-mse' + f'_nlayer8_trans-moe8_brf{buffer_reanalyze_freq}_seed{seed}/' + ) + + # Get action_space_size and observation_shape for each environment. 
+ action_space_size_list = [dmc_state_env_action_space_map[env_id] for env_id in env_id_list] + observation_shape_list = [dmc_state_env_obs_space_map[env_id] for env_id in env_id_list] + + for task_id, env_id in enumerate(env_id_list): + config = create_config( + env_id=env_id, + env_id_list=env_id_list, + target_return_dict=target_return_dict, + action_space_size_list=action_space_size_list, + observation_shape_list=observation_shape_list, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_episode=n_episode, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + batch_size=batch_size, + num_unroll_steps=num_unroll_steps, + infer_context_length=infer_context_length, + norm_type=norm_type, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + num_segments=num_segments, + total_batch_size=total_batch_size, + ) + config.policy.task_id = task_id + config.exp_name = exp_name_prefix + f"{env_id}_seed{seed}" + configs.append([task_id, [config, create_env_manager()]]) + return configs + + +def create_env_manager() -> EasyDict: + """ + Overview: + Create the environment and policy manager configuration. This specifies the types + of environment, policy, and their import paths. + + Returns: + - (:obj:`EasyDict`): A configuration object for the environment and policy managers. + """ + return EasyDict(dict( + env=dict( + type='dmc2gym_lightzero', + import_names=['zoo.dmc2gym.envs.dmc2gym_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='sampled_unizero_multitask', + import_names=['lzero.policy.sampled_unizero_multitask'], + ), + )) + + +if __name__ == "__main__": + """ + Overview: + Main script to configure and launch a multi-task training session for DeepMind Control Suite (DMC) + environments using Distributed Data Parallel (DDP). + + Usage: + This script should be executed with GPUs. + Navigate to the project root directory and run the launch command. + + Example command: + cd + torchrun --nproc_per_node=8 /dmc2gym_state_suz_multitask_ddp_config.py 2>&1 | tee \\ + /uz_mt_dmc18_train.log + """ + # --- Import necessary components for training --- + # It's good practice to place imports inside the main guard + # if they are only used for script execution. + from lzero.entry import train_unizero_multitask_segment_ddp + from ding.utils import DDPContext + import torch.distributed as dist + from zoo.dmc2gym.config.dmc_state_env_space_map import dmc_state_env_action_space_map, dmc_state_env_obs_space_map + + # --- Experiment constants --- + BENCHMARK_NAME = 'dmc' + + # --- Environment and Task Definitions --- + # Target return values for each DMC task, used for evaluation and potential curriculum. + target_return_dict = { + 'acrobot-swingup': 500, + 'cartpole-balance': 950, + 'cartpole-balance_sparse': 950, + 'cartpole-swingup': 800, + 'cartpole-swingup_sparse': 750, + 'cheetah-run': 650, + "ball_in_cup-catch": 950, + "finger-spin": 800, + "finger-turn_easy": 950, + "finger-turn_hard": 950, + 'hopper-hop': 150, + 'hopper-stand': 600, + 'pendulum-swingup': 800, + 'reacher-easy': 950, + 'reacher-hard': 950, + 'walker-run': 600, + 'walker-stand': 950, + 'walker-walk': 950, + } + + # List of DMC environments to be used in the multi-task setup. 
+ env_id_list = list(target_return_dict.keys()) + + # --- Hyperparameters for the training session --- + # Environment and Collector settings + collector_env_num = 8 + evaluator_env_num = 3 + n_episode = 8 + max_env_step = int(4e5) + + # Replay Buffer and Reanalyze settings + num_segments = 8 + reanalyze_ratio = 0.0 + buffer_reanalyze_freq = 1 / 100000 + reanalyze_batch_size = 160 + reanalyze_partition = 0.75 + + # Model and Training settings + total_batch_size = 512 + # Allocate batch size per task, ensuring a minimum of 64 or distributing the total size. + batch_size = [int(min(64, total_batch_size / len(env_id_list))) for _ in range(len(env_id_list))] + num_unroll_steps = 5 + infer_context_length = 2 + norm_type = 'LN' + num_simulations = 50 + + # --- Main training loop --- + # Iterate over different random seeds for multiple runs. + for seed in [1, 2]: + # Generate the specific configurations for each task for the current run. + configs = generate_configs( + env_id_list=env_id_list, + target_return_dict=target_return_dict, + collector_env_num=collector_env_num, + n_episode=n_episode, + evaluator_env_num=evaluator_env_num, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + batch_size=batch_size, + num_unroll_steps=num_unroll_steps, + infer_context_length=infer_context_length, + norm_type=norm_type, + seed=seed, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + num_segments=num_segments, + total_batch_size=total_batch_size, + dmc_state_env_action_space_map=dmc_state_env_action_space_map, + dmc_state_env_obs_space_map=dmc_state_env_obs_space_map, + ) + + with DDPContext(): + train_unizero_multitask_segment_ddp(configs, seed=seed, max_env_step=max_env_step, + benchmark_name=BENCHMARK_NAME) + # If you only want to train a subset of tasks, you can slice the configs list. + # For example, to train only the first four tasks: + # train_unizero_multitask_segment_ddp(configs[:4], seed=seed, max_env_step=max_env_step, benchmark_name=BENCHMARK_NAME) + dist.destroy_process_group() \ No newline at end of file diff --git a/zoo/dmc2gym/envs/dmc2gym_lightzero_env.py b/zoo/dmc2gym/envs/dmc2gym_lightzero_env.py index 4fcfb209a..4ab70ce9c 100644 --- a/zoo/dmc2gym/envs/dmc2gym_lightzero_env.py +++ b/zoo/dmc2gym/envs/dmc2gym_lightzero_env.py @@ -18,6 +18,8 @@ from gym.spaces import Box from matplotlib import animation import imageio +import logging + def dmc2gym_observation_space(dim, minimum=-np.inf, maximum=np.inf, dtype=np.float32) -> Callable: def observation_space(from_pixels=True, height=84, width=84, channels_first=True) -> Box: @@ -268,6 +270,8 @@ def __init__(self, cfg: dict = {}) -> None: self._save_replay_gif = cfg.save_replay_gif self._replay_path_gif = cfg.replay_path_gif self._save_replay_count = 0 + self._timestep = 0 + self._max_episode_steps = cfg.max_episode_steps def reset(self) -> Dict[str, np.ndarray]: """ @@ -409,11 +413,12 @@ def step(self, action: Union[int, np.ndarray]) -> BaseEnvTimestep: if self._save_replay_gif: self._frames.append(image_obs) - - if self._timestep > self._cfg.max_episode_steps: + + if self._timestep > self._max_episode_steps: done = True if done: + logging.info(f'one episode done! 
episode return: {self._eval_episode_return}, episode_steps:{self._timestep}') info['eval_episode_return'] = self._eval_episode_return if self._save_replay_gif: @@ -422,7 +427,8 @@ def step(self, action: Union[int, np.ndarray]) -> BaseEnvTimestep: timestamp = datetime.now().strftime("%Y%m%d%H%M%S") path = os.path.join( self._replay_path_gif, - '{}_episode_{}_seed{}_{}.gif'.format(f'{self._cfg["domain_name"]}_{self._cfg["task_name"]}', self._save_replay_count, self._seed, timestamp) + '{}_episode_{}_seed{}_{}.gif'.format(f'{self._cfg["domain_name"]}_{self._cfg["task_name"]}', + self._save_replay_count, self._seed, timestamp) ) self.display_frames_as_gif(self._frames, path) print(f'save episode {self._save_replay_count} in {self._replay_path_gif}!') @@ -487,7 +493,7 @@ def __repr__(self) -> str: String representation of the environment. """ return "LightZero DMC2Gym Env({}:{})".format(self._cfg["domain_name"], self._cfg["task_name"]) - + @staticmethod def create_collector_env_cfg(cfg: dict) -> List[dict]: collector_env_num = cfg.pop('collector_env_num') @@ -502,4 +508,4 @@ def create_evaluator_env_cfg(cfg: dict) -> List[dict]: cfg = copy.deepcopy(cfg) cfg.max_episode_steps = cfg.eval_max_episode_steps cfg.is_eval = True - return [cfg for _ in range(evaluator_env_num)] + return [cfg for _ in range(evaluator_env_num)] \ No newline at end of file diff --git a/zoo/jericho/configs/jericho_unizero_config.py b/zoo/jericho/configs/jericho_unizero_config.py index d9f50180d..c1b34da37 100644 --- a/zoo/jericho/configs/jericho_unizero_config.py +++ b/zoo/jericho/configs/jericho_unizero_config.py @@ -151,7 +151,7 @@ def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e manual_temperature_decay=False, num_simulations=num_simulations, n_episode=n_episode, - train_start_after_envsteps=0, + train_start_after_envsteps=0, # TODO: Adjust training start trigger if needed. replay_buffer_size=int(5e5), eval_freq=int(3e4), collector_env_num=collector_env_num,