# LightZero Entry Functions

English | [中文](./README_zh.md)

This directory contains the training and evaluation entry functions for the various algorithms in the LightZero framework. These entry functions serve as the main interfaces for launching different types of reinforcement learning experiments.

## 📁 Directory Structure

### 🎯 Training Entries

#### AlphaZero Family
- **`train_alphazero.py`** - Training entry for the AlphaZero algorithm
  - Suitable for perfect-information board games (e.g., Go, Chess)
  - No environment model needed; learns through self-play
  - Uses Monte Carlo Tree Search (MCTS) for policy improvement

#### MuZero Family
- **`train_muzero.py`** - Standard training entry for the MuZero algorithm
  - Supports the MuZero, EfficientZero, Sampled EfficientZero, and Gumbel MuZero variants
  - Learns an implicit model of the environment (dynamics model)
  - Suitable for single-task reinforcement learning scenarios

- **`train_muzero_segment.py`** - MuZero training with a segment collector and buffer reanalyze
  - Uses `MuZeroSegmentCollector` for data collection
  - Supports the buffer reanalyze trick for improved sample efficiency
  - Supported algorithms: MuZero, EfficientZero, Sampled MuZero, Sampled EfficientZero, Gumbel MuZero, Stochastic MuZero

- **`train_muzero_with_gym_env.py`** - MuZero training adapted for Gym environments
  - Specifically designed for OpenAI Gym-style environments
  - Simplifies environment interface adaptation

- **`train_muzero_with_reward_model.py`** - MuZero training with a reward model
  - Integrates an external reward model
  - Suitable for scenarios that require learning complex reward functions

- **`train_muzero_multitask_segment_ddp.py`** - MuZero multi-task distributed training
  - Supports multi-task learning
  - Uses DDP (Distributed Data Parallel) for distributed training
  - Uses the segment collector

#### UniZero Family
- **`train_unizero.py`** - Training entry for the UniZero algorithm
  - Based on the paper "UniZero: Generalized and Efficient Planning with Scalable Latent World Models"
  - Enhanced planning capabilities for better capture of long-term dependencies
  - Uses scalable latent world models
  - Paper: https://arxiv.org/abs/2406.10667

- **`train_unizero_segment.py`** - UniZero training with a segment collector
  - Uses `MuZeroSegmentCollector` for efficient data collection
  - Supports the buffer reanalyze trick

- **`train_unizero_multitask_segment_ddp.py`** - UniZero multi-task distributed training
  - Supports multi-task learning and distributed training
  - Includes benchmark score definitions (e.g., Atari human-normalized scores)
  - Supports curriculum learning strategies
  - Uses DDP for training acceleration

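For context, the Atari human-normalized score mentioned above follows the standard convention: `(agent - random) / (human - random)`. A minimal sketch (the function name here is illustrative, not the library's exact API):

```python
def human_normalized_score(agent_score: float, random_score: float, human_score: float) -> float:
    # Standard Atari convention: 0.0 corresponds to random play, 1.0 to human level.
    return (agent_score - random_score) / (human_score - random_score)

# An agent scoring halfway between the random and human baselines:
print(human_normalized_score(agent_score=550.0, random_score=100.0, human_score=1000.0))  # 0.5
```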
- **`train_unizero_multitask_balance_segment_ddp.py`** - UniZero balanced multi-task distributed training
  - Implements balanced sampling across tasks in multi-task training
  - Dynamically adjusts batch sizes for different tasks
  - Suitable for scenarios with large variations in task difficulty

- **`train_unizero_multitask_segment_eval.py`** - UniZero multi-task evaluation training
  - Specialized for training with periodic evaluation in multi-task scenarios
  - Includes detailed evaluation metric statistics

- **`train_unizero_with_loss_landscape.py`** - UniZero training with loss landscape visualization
  - Visualizes the loss landscape during training
  - Helps in understanding the model's optimization process and generalization behavior
  - Integrates the `loss_landscapes` library

#### ReZero Family
- **`train_rezero.py`** - Training entry for the ReZero algorithm
  - Supports ReZero-MuZero and ReZero-EfficientZero
  - Improves training efficiency through backward-view buffer reanalyze
  - Paper: https://arxiv.org/pdf/2404.16364

### 🎓 Evaluation Entries

- **`eval_alphazero.py`** - Evaluation entry for AlphaZero
  - Loads trained AlphaZero models for evaluation
  - Can play against other agents for performance testing

- **`eval_muzero.py`** - Evaluation entry for the MuZero family
  - Supports evaluation of all MuZero variants
  - Provides detailed performance statistics

- **`eval_muzero_with_gym_env.py`** - MuZero evaluation for Gym environments
  - Specialized for evaluating models trained in Gym environments

### 🛠️ Utility Modules

- **`utils.py`** - Common utility functions library
  - **Math & Tensor Utilities**:
    - `symlog`, `inv_symlog` - Symmetric logarithm transformations
    - `initialize_zeros_batch`, `initialize_pad_batch` - Batch initialization

  - **LoRA Utilities**:
    - `freeze_non_lora_parameters` - Freeze non-LoRA parameters

  - **Task & Curriculum Learning Utilities**:
    - `compute_task_weights` - Compute task weights
    - `TemperatureScheduler` - Temperature scheduler
    - `tasks_per_stage` - Calculate tasks per stage
    - `compute_unizero_mt_normalized_stats` - Compute normalized statistics
    - `allocate_batch_size` - Dynamically allocate batch sizes

  - **Distributed Training Utilities (DDP)**:
    - `is_ddp_enabled` - Check if DDP is enabled
    - `ddp_synchronize` - DDP synchronization
    - `ddp_all_reduce_sum` - DDP all-reduce sum

  - **RL Workflow Utilities**:
    - `calculate_update_per_collect` - Calculate updates per collection
    - `random_collect` - Random policy data collection
    - `convert_to_batch_for_unizero` - UniZero batch data conversion
    - `create_unizero_loss_metrics` - Create loss metrics function
    - `UniZeroDataLoader` - UniZero data loader

  - **Logging Utilities**:
    - `log_module_trainable_status` - Log module trainable status
    - `log_param_statistics` - Log parameter statistics
    - `log_buffer_memory_usage` - Log buffer memory usage
    - `log_buffer_run_time` - Log buffer runtime

- **`__init__.py`** - Package initialization file
  - Exports all training and evaluation entry functions
  - Exports commonly used functions from the utility modules

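The `symlog`/`inv_symlog` transforms listed above are conventionally defined as `sign(x) * ln(1 + |x|)` and its inverse. A scalar sketch for illustration (the library versions operate on tensors and may differ in detail):

```python
import math

def symlog(x: float) -> float:
    # Symmetric log: sign(x) * ln(1 + |x|); compresses large magnitudes
    # while staying roughly linear near zero.
    return math.copysign(math.log1p(abs(x)), x)

def inv_symlog(y: float) -> float:
    # Exact inverse: sign(y) * (exp(|y|) - 1)
    return math.copysign(math.expm1(abs(y)), y)

# Round-trip sanity check across scales
for v in (-1e6, -1.5, 0.0, 2.0, 1e6):
    assert abs(inv_symlog(symlog(v)) - v) <= 1e-9 * max(1.0, abs(v))
```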
## 📖 Usage Guide

### Basic Usage Pattern

All training entry functions follow a similar calling pattern:

```python
from lzero.entry import train_muzero

# Prepare configuration
cfg = dict(...)  # User configuration
create_cfg = dict(...)  # Creation configuration

# Start training
policy = train_muzero(
    input_cfg=(cfg, create_cfg),
    seed=0,
    model=None,  # Optional: pre-initialized model
    model_path=None,  # Optional: pretrained model path
    max_train_iter=int(1e10),  # Maximum training iterations
    max_env_step=int(1e10),  # Maximum environment steps
)
```

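Beyond these arguments, segment-based entries also decide how many gradient updates to run after each collection phase; that is the role of `calculate_update_per_collect` in `utils.py`. A rough sketch of the common replay-ratio heuristic (function name and config fields here are illustrative, not the helper's exact signature):

```python
def update_per_collect_sketch(num_collected_transitions: int, replay_ratio: float) -> int:
    # Heuristic: scale the number of gradient updates with the amount of
    # newly collected data, never dropping below one update.
    return max(1, int(num_collected_transitions * replay_ratio))

print(update_per_collect_sketch(400, 0.25))  # 100
```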
### Choosing the Right Entry Function

1. **Single-Task Learning**:
   - Board games → `train_alphazero`
   - General RL tasks → `train_muzero` or `train_unizero`
   - Gym environments → `train_muzero_with_gym_env`

2. **Multi-Task Learning**:
   - Standard multi-task → `train_unizero_multitask_segment_ddp`
   - Balanced task sampling → `train_unizero_multitask_balance_segment_ddp`

3. **Distributed Training**:
   - All entry functions with the `_ddp` suffix support distributed training

4. **Special Requirements**:
   - Loss landscape visualization → `train_unizero_with_loss_landscape`
   - External reward model → `train_muzero_with_reward_model`
   - More efficient reanalyze → `train_rezero`

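The multi-task entries anneal a temperature when weighting tasks (see `TemperatureScheduler` in `utils.py`). A minimal linear-decay sketch with an assumed interface, not the library's exact API:

```python
class LinearTemperatureScheduler:
    """Linearly anneal a temperature from `start` to `end` over `total_steps`.

    Sketch only: the real `TemperatureScheduler` may expose different
    parameters and decay schedules.
    """
    def __init__(self, start: float, end: float, total_steps: int):
        self.start, self.end, self.total_steps = start, end, total_steps

    def value(self, step: int) -> float:
        # Clamp progress to [0, 1] so the temperature stays in [end, start].
        frac = min(max(step / self.total_steps, 0.0), 1.0)
        return self.start + (self.end - self.start) * frac

sched = LinearTemperatureScheduler(start=1.0, end=0.25, total_steps=1000)
print(sched.value(0), sched.value(500), sched.value(1000))  # 1.0 0.625 0.25
```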
## 🔗 Related Resources

- **AlphaZero**: [Mastering Chess and Shogi by Self-Play with a General Reinforcement Learning Algorithm](https://arxiv.org/abs/1712.01815)
- **MuZero**: [Mastering Atari, Go, Chess and Shogi by Planning with a Learned Model](https://arxiv.org/abs/1911.08265)
- **EfficientZero**: [Mastering Atari Games with Limited Data](https://arxiv.org/abs/2111.00210)
- **UniZero**: [Generalized and Efficient Planning with Scalable Latent World Models](https://arxiv.org/abs/2406.10667)
- **ReZero**: [Boosting MCTS-based Algorithms by Backward-view and Entire-buffer Reanalyze](https://arxiv.org/abs/2404.16364)

## 💡 Tips

- Start with the standard `train_muzero` or `train_unizero`
- For large-scale experiments, consider the DDP versions for faster training
- The `_segment` versions can achieve better sample efficiency (via the reanalyze trick)
- Check the configuration examples in the `zoo/` directory to learn how to set up each algorithm

## 📝 Notes

1. All path parameters should use **absolute paths**
2. Pretrained model paths typically follow the format `exp_name/ckpt/ckpt_best.pth.tar`
3. When using distributed training, ensure the `CUDA_VISIBLE_DEVICES` environment variable is set correctly
4. Some entry functions require specific algorithm types - check each function's documentation
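The `CUDA_VISIBLE_DEVICES` convention from note 3 can be sanity-checked before launching a distributed run. A small sketch of the parsing rule (helper name is illustrative):

```python
import os

def parse_visible_gpus(env_value):
    """Interpret CUDA_VISIBLE_DEVICES the way CUDA runtimes do (sketch).

    None (unset) means all GPUs are visible; an empty string means none;
    otherwise the value is a comma-separated list of device IDs.
    """
    if env_value is None:
        return "all"
    return [v.strip() for v in env_value.split(",") if v.strip() != ""]

print(parse_visible_gpus(os.environ.get("CUDA_VISIBLE_DEVICES")))
```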