Skip to content

Commit f723e41

Browse files
committed
polish(pu): polish comments and code styles in config
1 parent 6190c08 commit f723e41

12 files changed

+66
-225
lines changed

lzero/policy/scaling_transform.py

Lines changed: 23 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,5 @@
11
from typing import Union
22
import torch
3-
import numpy as np
43

54
class DiscreteSupport(object):
65

@@ -106,29 +105,34 @@ def visit_count_temperature(
106105
return fixed_temperature_value
107106

108107

108+
109109
def phi_transform(
110110
discrete_support: DiscreteSupport,
111111
x: torch.Tensor,
112-
label_smoothing_eps: float = 0. # <--- 新增平滑参数
112+
label_smoothing_eps: float = 0.0  # Label-smoothing strength; 0 disables smoothing
113113
) -> torch.Tensor:
114114
"""
115115
Overview:
116-
Map a real-valued scalar to a categorical distribution over a discrete support using linear interpolation (a.k.a. “soft” one-hot).
116+
Map a real-valued scalar to a categorical distribution over a discrete support
117+
using linear interpolation (a.k.a. “soft” one-hot).
117118
118-
For each scalar value the probability mass is split between the two
119+
For each scalar value, the probability mass is split between the two
119120
nearest support atoms so that their weighted sum equals the original
120-
value (MuZero, Appendix F).
121+
value (see MuZero, Appendix F).
121122
122123
Arguments:
123124
- discrete_support : DiscreteSupport
124125
Container with the support values (must be evenly spaced).
125126
- x : torch.Tensor
126127
Input tensor of arbitrary shape ``(...,)`` containing real numbers.
128+
- label_smoothing_eps : float
129+
Epsilon value for label smoothing. If greater than 0, the resulting
130+
distribution is mixed with a uniform distribution. Defaults to 0.
127131
128132
Returns:
129133
- torch.Tensor
130134
Tensor of shape ``(*x.shape, N)`` where ``N = discrete_support.size``.
131-
The last dimension is a probability distribution (sums to 1).
135+
The last dimension represents a probability distribution (sums to 1).
132136
133137
Notes
134138
-----
@@ -141,20 +145,21 @@ def phi_transform(
141145
step = discrete_support.step
142146
size = discrete_support.size
143147

144-
# --- 1. clip to the valid range ----------------------------------------
148+
# --- 1. Clip to the valid range ----------------------------------------
145149
x = x.clamp(min_bound, max_bound)
146150

147-
# --- 2. locate neighbouring indices ------------------------------------
148-
pos = (x - min_bound) / step # continuous position
149-
low_idx_float = torch.floor(pos) # lower index
150-
low_idx_long = low_idx_float.long() # lower index
151-
high_idx = low_idx_long + 1 # upper index (may overflow)
151+
# --- 2. Locate neighbouring indices ------------------------------------
152+
pos = (x - min_bound) / step # Continuous position relative to support
153+
low_idx_float = torch.floor(pos) # Lower index (float)
154+
low_idx_long = low_idx_float.long() # Lower index (long)
155+
high_idx = low_idx_long + 1 # Upper index (may temporarily overflow)
152156

153-
# --- 3. linear interpolation weights -----------------------------------
154-
p_high = pos - low_idx_float # distance to lower atom
155-
p_low = 1.0 - p_high # complementary mass
157+
# --- 3. Linear interpolation weights -----------------------------------
158+
p_high = pos - low_idx_float # Distance to the lower atom (weight for upper)
159+
p_low = 1.0 - p_high # Complementary mass (weight for lower)
156160

157-
# --- 4. stack indices / probs and scatter ------------------------------
161+
# --- 4. Stack indices / probs and scatter ------------------------------
162+
# Clamp high_idx to handle the edge case where x is exactly max_bound
158163
idx = torch.stack([low_idx_long,
159164
torch.clamp(high_idx, max=size - 1)], dim=-1) # (*x, 2)
160165
prob = torch.stack([p_low, p_high], dim=-1) # (*x, 2)
@@ -163,11 +168,10 @@ def phi_transform(
163168
dtype=x.dtype, device=x.device)
164169

165170
target.scatter_add_(-1, idx, prob)
166-
# return target
167171

168-
# --- 5. 应用标签平滑 ---
172+
# --- 5. Apply label smoothing ------------------------------------------
169173
if label_smoothing_eps > 0:
170-
# 将原始的 two-hot 目标与一个均匀分布混合
174+
# Mix the original "two-hot" target with a uniform distribution
171175
smooth_target = (1.0 - label_smoothing_eps) * target + (label_smoothing_eps / size)
172176
return smooth_target
173177
else:

lzero/policy/unizero.py

Lines changed: 29 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -209,11 +209,38 @@ class UniZeroPolicy(MuZeroPolicy):
209209
rope_theta=10000,
210210
# (int) The maximum sequence length for position encoding.
211211
max_seq_len=8192,
212-
lora_r= 0,
212+
# (int) The rank parameter for LoRA (Low-Rank Adaptation). Set to 0 to disable LoRA.
213+
lora_r=0,
214+
# (float) The alpha parameter for LoRA scaling.
215+
lora_alpha=1,
216+
# (float) The dropout probability for LoRA layers.
217+
lora_dropout=0.0,
213218
# Controls where to compute reconstruction loss: 'after_backbone', 'before_backbone', or None.
214219
# - after_backbone: The reconstruction loss is computed after the encoded representation passes through the backbone.
215-
# - before_backbone: The reconstruction loss is computed directly on the encoded representation, without the backbone.
220+
# - before_backbone: The reconstruction loss is computed directly on the encoded representation, without the backbone.
216221
decode_loss_mode=None,
222+
# (str/None) Task embedding option. Set to None to disable task-specific embeddings.
223+
task_embed_option=None,
224+
# (bool) Whether to use task embeddings.
225+
use_task_embed=False,
226+
# (bool) Whether to use normal head (standard prediction heads).
227+
use_normal_head=True,
228+
# (bool) Whether to use Soft Mixture-of-Experts (MoE) head.
229+
use_softmoe_head=False,
230+
# (bool) Whether to use Mixture-of-Experts (MoE) head.
231+
use_moe_head=False,
232+
# (int) Number of experts in the MoE head.
233+
num_experts_in_moe_head=4,
234+
# (bool) Whether to use MoE in the transformer layers.
235+
moe_in_transformer=False,
236+
# (bool) Whether to use multiplicative MoE in the transformer layers.
237+
multiplication_moe_in_transformer=False,
238+
# (int) Number of shared experts in MoE.
239+
n_shared_experts=1,
240+
# (int) Number of experts to use per token in MoE.
241+
num_experts_per_tok=1,
242+
# (int) Total number of experts in the transformer MoE.
243+
num_experts_of_moe_in_transformer=8,
217244
),
218245
),
219246
# ****** common ******

zoo/atari/config/atari_env_action_space_map_v4.py

Lines changed: 0 additions & 33 deletions
This file was deleted.

zoo/atari/config/atari_env_action_space_map_v5.py

Lines changed: 0 additions & 33 deletions
This file was deleted.

zoo/atari/config/atari_unizero_multitask_segment_ddp_config.py

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -315,9 +315,6 @@ def create_env_manager() -> EasyDict:
315315
cd /path/to/your/project/
316316
317317
torchrun --nproc_per_node=4 ./zoo/atari/config/atari_unizero_multitask_segment_ddp_config.py 2>&1 | tee ./logs/202512/atari8_uz_mt.log
318-
319-
/mnt/shared-storage-user/puyuan/lz/bin/python -m torch.distributed.launch --nproc_per_node=4 --master_port=29502 /mnt/shared-storage-user/puyuan/code/LightZero/zoo/atari/config/atari_unizero_multitask_segment_ddp_config.py 2>&1 | tee /mnt/shared-storage-user/puyuan/code/LightZero/logs/202512/atari8_uz_mt.log
320-
321318
"""
322319
from lzero.entry import train_unizero_multitask_segment_ddp
323320
from ding.utils import DDPContext

zoo/dmc2gym/config/dmc2gym_pixels_sampled_unizero_config.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -40,7 +40,7 @@
4040
# ==============================================================
4141

4242
dmc2gym_pixels_cont_sampled_unizero_config = dict(
43-
exp_name=f'data_sampled_unizero_0901/dmc2gym_{env_id}_image_cont_sampled_unizero_ns{num_simulations}_upc{update_per_collect}-rr{replay_ratio}_rer{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_{norm_type}_seed{seed}',
43+
exp_name=f'data_sampled_unizero/dmc2gym_{env_id}_image_cont_sampled_unizero_ns{num_simulations}_upc{update_per_collect}-rr{replay_ratio}_rer{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_{norm_type}_seed{seed}',
4444
env=dict(
4545
env_id='dmc2gym-v0',
4646
continuous=True,
@@ -75,7 +75,6 @@
7575
max_blocks=num_unroll_steps,
7676
max_tokens=2 * num_unroll_steps, # NOTE: each timestep has 2 tokens: obs and action
7777
context_length=2 * infer_context_length,
78-
# device='cpu',
7978
device='cuda',
8079
action_space_size=action_space_size,
8180
num_layers=2,
@@ -116,7 +115,6 @@
116115
type='dmc2gym_lightzero',
117116
import_names=['zoo.dmc2gym.envs.dmc2gym_lightzero_env'],
118117
),
119-
# env_manager=dict(type='subprocess'),
120118
env_manager=dict(type='base'),
121119
policy=dict(
122120
type='sampled_unizero',

zoo/dmc2gym/config/dmc2gym_state_smz_config.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,7 @@
3030
# ==============================================================
3131

3232
dmc2gym_state_cont_sampled_muzero_config = dict(
33-
exp_name=f'/oss/niuyazhe/puyuan/data/data_lz_202505/data_smz/dmc2gym_{env_id}_state_cont_sampled_muzero_k{K}_ns{num_simulations}_upc{update_per_collect}-rr{replay_ratio}_rer{reanalyze_ratio}_{norm_type}_seed{seed}',
33+
exp_name=f'data_smz/dmc2gym_{env_id}_state_cont_sampled_muzero_k{K}_ns{num_simulations}_upc{update_per_collect}-rr{replay_ratio}_rer{reanalyze_ratio}_{norm_type}_seed{seed}',
3434
env=dict(
3535
env_id='dmc2gym-v0',
3636
domain_name=domain_name,

zoo/dmc2gym/config/dmc2gym_state_suz_config.py

Lines changed: 0 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -99,35 +99,6 @@ def main(env_id, seed):
9999
embed_dim=768,
100100
env_num=max(collector_env_num, evaluator_env_num),
101101
rotary_emb=False,
102-
103-
104-
# --- MOE Settings ---
105-
moe_in_transformer=False,
106-
# multiplication_moe_in_transformer=True,
107-
multiplication_moe_in_transformer=False,
108-
num_experts_of_moe_in_transformer=8,
109-
n_shared_experts=1,
110-
num_experts_per_tok=1,
111-
use_normal_head=True,
112-
use_softmoe_head=False,
113-
use_moe_head=False,
114-
num_experts_in_moe_head=4,
115-
116-
# --- LoRA Parameters ---
117-
moe_use_lora=False, # TODO
118-
curriculum_stage_num=3,
119-
lora_target_modules=["attn", "feed_forward"],
120-
lora_r=0,
121-
lora_alpha=1,
122-
lora_dropout=0.0,
123-
124-
# --- Multi-task Settings ---
125-
task_embed_option=None, # TODO: 'concat_task_embed' or None
126-
use_task_embed=False, # TODO
127-
128-
# --- Analysis ---
129-
analysis_dormant_ratio_weight_rank=False, # TODO
130-
analysis_dormant_ratio_interval=5000,
131102
),
132103
),
133104
# (str) The path of the pretrained model. If None, the model will be initialized by the default model.

zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_balance_config.py

Lines changed: 9 additions & 47 deletions
Original file line number | Diff line number | Diff line change
@@ -7,22 +7,9 @@
77
"""
88
from __future__ import annotations
99

10-
import logging
1110
from typing import Any, Dict, List
12-
1311
from easydict import EasyDict
1412
import copy
15-
# ==============================================================
16-
# Global setup: Logging
17-
# ==============================================================
18-
logging.basicConfig(
19-
level=logging.INFO,
20-
format='%(asctime)s - %(message)s',
21-
handlers=[
22-
logging.FileHandler("output.log", encoding="utf-8"), # Log to file
23-
logging.StreamHandler() # Log to console
24-
]
25-
)
2613

2714

2815
def get_base_config(env_id_list: list[str], collector_env_num: int, evaluator_env_num: int,
@@ -58,8 +45,8 @@ def get_base_config(env_id_list: list[str], collector_env_num: int, evaluator_en
5845
),
5946
# Policy-specific settings
6047
policy=dict(
61-
multi_gpu=True, # TODO(user): Enable multi-GPU for DDP.
62-
# TODO(user): Configure MoCo settings.
48+
multi_gpu=True,
49+
# TODO: Configure MoCo settings.
6350
only_use_moco_stats=False,
6451
use_moco=False,
6552
learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=1000000))),
@@ -115,7 +102,6 @@ def get_base_config(env_id_list: list[str], collector_env_num: int, evaluator_en
115102
# TODO(user): For debugging only. Use a smaller model.
116103
# num_layers=1,
117104
num_layers=4,
118-
# num_layers=8,
119105

120106
num_heads=24,
121107
embed_dim=768,
@@ -319,24 +305,9 @@ def generate_experiment_name(num_tasks: int, curriculum_stage_num: int, buffer_r
319305
Returns:
320306
- (:obj:`str`): The generated experiment name prefix.
321307
"""
322-
# NOTE: This is a template for the experiment name.
323-
# Users should customize it to reflect their specific experiment settings.
324-
#
325-
# IMPORTANT: To avoid filesystem path length issues, consider using the simplified version below.
326-
# Uncomment the simplified version and comment out the detailed version if you encounter
327-
# "File name too long" errors.
328-
#
329-
# ===== Simplified Version (RECOMMENDED to avoid path length issues) =====
330-
# return f'data_20251120/dmc_{num_tasks}t_s{curriculum_stage_num}_brf{buffer_reanalyze_freq:.0e}_s{seed}/'
331-
#
332-
# ===== Detailed Version (Current) =====
333-
# return (
334-
# f'data_suz_dmc_mt_balance_20251120/dmc_{num_tasks}tasks_frameskip4-pen-fs8_balance-stage-total-{curriculum_stage_num}'
335-
# f'_stage0-10k-5k_fix-lora-update-stablescale_moe8-uselora_nlayer4_not-share-head'
336-
# f'_brf{buffer_reanalyze_freq}_seed{seed}/'
337-
# )
308+
338309
return (
339-
f'data_suz_dmc_mt_balance_20251120/dmc_{num_tasks}tasks_frameskip4-pen-fs8_balance-stage-total-{curriculum_stage_num}'
310+
f'data_suz_dmc_mt_balance/dmc_{num_tasks}tasks_frameskip4-pen-fs8_balance-stage-total-{curriculum_stage_num}'
340311
f'_stage0-10k-5k_moe8_nlayer4'
341312
f'_brf{buffer_reanalyze_freq}_seed{seed}/'
342313
)
@@ -398,7 +369,6 @@ def generate_all_task_configs(
398369

399370
for task_id, env_id in enumerate(env_id_list):
400371
task_specific_config = create_task_config(
401-
# base_config=base_config.clone(), # Use a clone to avoid modifying the base config
402372
base_config=copy.deepcopy(base_config),
403373
env_id=env_id,
404374
action_space_size_list=action_space_size_list,
@@ -435,15 +405,10 @@ def main():
435405
This script should be executed with <nproc_per_node> GPUs.
436406
437407
Example launch commands:
438-
1. Using `torch.distributed.launch`:
439-
cd <PATH_TO_YOUR_PROJECT>/LightZero/
440-
python -m torch.distributed.launch --nproc_per_node=8 --master_port=29501 \\
441-
./zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_balance_config.py 2>&1 | tee \\
442-
./logs/uz_mt_dmc18_balance_moe8_seed0.log
443-
444-
2. Using `torchrun`:
445-
cd <PATH_TO_YOUR_PROJECT>/LightZero/
446-
torchrun --nproc_per_node=4 ./zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_balance_config.py
408+
409+
cd <PATH_TO_YOUR_PROJECT>/LightZero/
410+
torchrun --nproc_per_node=4 ./zoo/dmc2gym/config/dmc2gym_state_suz_multitask_ddp_balance_config.py 2>&1 | tee \\
411+
./logs/uz_mt_dmc18_balance_moe8_seed0.log
447412
"""
448413
from lzero.entry import train_unizero_multitask_balance_segment_ddp
449414
from ding.utils import DDPContext
@@ -492,10 +457,7 @@ def main():
492457
# batch_size = [3] * len(env_id_list)
493458
# max_env_step = int(1e3)
494459

495-
# Production settings
496-
# curriculum_stage_num = 5
497-
curriculum_stage_num = 3
498-
460+
curriculum_stage_num = 5
499461
collector_env_num = 8
500462
num_segments = 8
501463
n_episode = 8

0 commit comments

Comments
 (0)