Skip to content

Commit 88f047b

Browse files
committed
fix(pu): fix some bugs in reuse-collect-cot in training phase
1 parent 19fac8f commit 88f047b

File tree

6 files changed

+307
-52
lines changed

6 files changed

+307
-52
lines changed

lzero/mcts/buffer/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,4 @@
88
from .game_buffer_stochastic_muzero import StochasticMuZeroGameBuffer
99
from .game_buffer_rezero_mz import ReZeroMZGameBuffer
1010
from .game_buffer_rezero_ez import ReZeroEZGameBuffer
11+
from .game_buffer_priorzero import PriorZeroGameBufferOptimized

lzero/mcts/buffer/game_buffer_priorzero.py

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,30 @@ def fetch_latest_batch(self, batch_size: int, policy) -> List[Any]:
3838
[raw_obs_list, history_obs_list, action_logprob_list, batch_target_values, cot_prefix_list]
3939
CoT prefix list is added for CoT reuse optimization.
4040
"""
41+
import torch.distributed as dist
42+
rank = dist.get_rank() if dist.is_initialized() else 0
43+
4144
policy._target_model.to(self._cfg.device)
4245
policy._target_model.eval()
4346

4447
reward_value_context, policy_re_context, policy_non_re_context, current_batch = self._make_batch(
4548
batch_size, self._cfg.reanalyze_ratio, fetch_latest=True
4649
)
4750

48-
obs_list, action_list, bootstrap_action_list, mask_list, batch_index_list, weights_list, make_time_list, timestep_list, raw_obs_list, history_obs_list, action_logprob_list, cot_prefix_list = current_batch
51+
# Robust unpacking with validation
52+
try:
53+
obs_list, action_list, bootstrap_action_list, mask_list, batch_index_list, weights_list, make_time_list, timestep_list, raw_obs_list, history_obs_list, action_logprob_list, cot_prefix_list = current_batch
54+
except ValueError as e:
55+
print(f"[ERROR] Failed to unpack current_batch. Expected 12 elements, got {len(current_batch)}. Error: {e}")
56+
print(f"[DEBUG] current_batch structure: {[type(x).__name__ for x in current_batch]}")
57+
# Add missing cot_prefix_list if needed
58+
if len(current_batch) == 11:
59+
print("[WARNING] current_batch missing cot_prefix_list, adding empty list as fallback")
60+
current_batch.append([[""] * (self._cfg.num_unroll_steps + self._cfg.frame_stack_num) for _ in range(batch_size)])
61+
obs_list, action_list, bootstrap_action_list, mask_list, batch_index_list, weights_list, make_time_list, timestep_list, raw_obs_list, history_obs_list, action_logprob_list, cot_prefix_list = current_batch
62+
else:
63+
raise
64+
4965
# Standard processing
5066
batch_rewards, batch_target_values = self._compute_target_reward_value(
5167
reward_value_context, policy._target_model, current_batch[2], timestep_list
@@ -54,8 +70,17 @@ def fetch_latest_batch(self, batch_size: int, policy) -> List[Any]:
5470
batch_target_policies = self._compute_target_policy_non_reanalyzed(
5571
policy_non_re_context, self.action_space_size
5672
)
73+
5774
# CoT reuse optimization: return cot_prefix_list
58-
return [raw_obs_list, history_obs_list, action_logprob_list, batch_target_values, cot_prefix_list]
75+
# IMPORTANT: Validate return value before returning to ensure broadcast compatibility
76+
result = [raw_obs_list, history_obs_list, action_logprob_list, batch_target_values, cot_prefix_list]
77+
78+
# Comprehensive validation
79+
assert len(result) == 5, f"[CRITICAL] fetch_latest_batch must return EXACTLY 5 elements, got {len(result)}"
80+
assert isinstance(result, list), f"[CRITICAL] result must be list, got {type(result)}"
81+
assert isinstance(cot_prefix_list, list), f"[CRITICAL] cot_prefix_list must be list, got {type(cot_prefix_list)}"
82+
83+
return result
5984

6085
def sample(self, batch_size: int, policy) -> List[Any]:
6186
"""Sample data with game_segments (optimized version)."""
@@ -67,7 +92,8 @@ def sample(self, batch_size: int, policy) -> List[Any]:
6792
batch_size, self._cfg.reanalyze_ratio
6893
)
6994

70-
obs_list, action_list, bootstrap_action_list, mask_list, batch_index_list, weights_list, make_time_list, timestep_list, raw_obs_list, history_obs_list, action_logprob_list = current_batch
95+
# CoT reuse optimization: unpack cot_prefix_list (12 elements total)
96+
obs_list, action_list, bootstrap_action_list, mask_list, batch_index_list, weights_list, make_time_list, timestep_list, raw_obs_list, history_obs_list, action_logprob_list, cot_prefix_list = current_batch
7197
# Standard processing
7298
batch_rewards, batch_target_values = self._compute_target_reward_value(
7399
reward_value_context, policy._target_model, current_batch[2], timestep_list
@@ -158,9 +184,15 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float, fetch_latest: boo
158184
pos_in_game_segment_list[i], num_unroll_steps=self._cfg.num_unroll_steps, padding=True
159185
))
160186
# CoT reuse optimization: extract CoT prefixes
161-
cot_prefix_list.append(game_segment_list[i].get_unroll_cot_prefix(
162-
pos_in_game_segment_list[i], num_unroll_steps=self._cfg.num_unroll_steps, padding=True
163-
))
187+
try:
188+
cot_prefix = game_segment_list[i].get_unroll_cot_prefix(
189+
pos_in_game_segment_list[i], num_unroll_steps=self._cfg.num_unroll_steps, padding=True
190+
)
191+
cot_prefix_list.append(cot_prefix)
192+
except (AttributeError, Exception) as e:
193+
# Fallback: if game_segment doesn't have cot_prefix, use empty strings
194+
print(f"[WARNING] GameSegment missing get_unroll_cot_prefix, using empty CoT prefixes. Error: {e}")
195+
cot_prefix_list.append([""] * (self._cfg.num_unroll_steps + self._cfg.frame_stack_num))
164196

165197
action_list.append(actions_tmp)
166198
mask_list.append(mask_tmp)
@@ -187,6 +219,9 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float, fetch_latest: boo
187219
current_batch.append(action_logprob_list)
188220
current_batch.append(cot_prefix_list) # CoT reuse optimization
189221

222+
# Validate current_batch has exactly 12 elements before returning
223+
# assert len(current_batch) == 12, f"current_batch must have 12 elements, got {len(current_batch)}. Missing: {12 - len(current_batch)} elements"
224+
# print(f"[DEBUG] _make_batch created current_batch with {len(current_batch)} elements (expected 12)")
190225
total_transitions = self.get_num_of_transitions()
191226

192227
reward_value_context = self._prepare_reward_value_context(

zoo/jericho/priorzero/priorzero_config.py

Lines changed: 117 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,72 @@
11
import os
2-
from typing import Dict, Tuple
2+
from typing import Dict, Tuple, Optional
33
from easydict import EasyDict
44
import torch.distributed as dist
55
from dataclasses import dataclass
66

7+
# ============================================================================
8+
# Model Configuration Presets
9+
# ============================================================================
10+
# Presets for the LLM backbones selectable by a short model key.
# Every preset carries the same four fields:
#   - model_name_or_path: filesystem path of the checkpoint
#   - vllm_tensor_parallel_size: tensor-parallel size stored on the LLM config
#   - gpu_memory_utilization: GPU memory utilization stored on the LLM config
#   - description: human-readable summary of the preset
# NOTE(review): the paths are absolute and cluster-specific; confirm they
# exist on the target machine before selecting a preset.
MODEL_CONFIGS = {
    "qwen2.5-0.5b": {
        "model_name_or_path": "/mnt/shared-storage-user/puyuan/xiongjyu/models/Qwen2.5-0.5B-Instruct",
        "vllm_tensor_parallel_size": 1,
        "gpu_memory_utilization": 0.3,
        "description": "Qwen2.5-0.5B-Instruct (smallest, fastest)",
    },
    "qwen2.5-1.5b": {
        "model_name_or_path": "/mnt/shared-storage-user/puyuan/xiongjyu/models/Qwen2.5-1.5B-Instruct",
        "vllm_tensor_parallel_size": 1,
        "gpu_memory_utilization": 0.3,
        "description": "Qwen2.5-1.5B-Instruct (balanced performance)",
    },
    "qwen2.5-3b": {
        "model_name_or_path": "/mnt/shared-storage-user/puyuan/model/Qwen2.5-3B-Instruct",
        "vllm_tensor_parallel_size": 1,
        "gpu_memory_utilization": 0.5,
        "description": "Qwen2.5-3B-Instruct (better quality)",
    },
    "qwen2.5-7b": {
        "model_name_or_path": "/mnt/shared-storage-user/puyuan/model/Qwen2.5-7B-Instruct",
        "vllm_tensor_parallel_size": 2,
        "gpu_memory_utilization": 0.5,
        "description": "Qwen2.5-7B-Instruct (high quality, needs 2+ GPUs)",
    },
    "qwen2.5-14b": {
        "model_name_or_path": "/mnt/shared-storage-user/puyuan/model/Qwen2.5-14B-Instruct",
        "vllm_tensor_parallel_size": 4,
        "gpu_memory_utilization": 0.5,
        "description": "Qwen2.5-14B-Instruct (best quality, needs 4+ GPUs)",
    },
}
42+
43+
def get_available_models():
    """Return the keys of ``MODEL_CONFIGS`` as a new list, in definition order."""
    # Iterating a dict yields its keys, so ``.keys()`` is not needed.
    return list(MODEL_CONFIGS)
46+
47+
def get_model_config(model_key: str) -> Dict:
    """Look up a model preset by key.

    Args:
        model_key: A key of ``MODEL_CONFIGS`` (e.g. ``"qwen2.5-1.5b"``).

    Returns:
        A shallow copy of the preset dict, so a caller that modifies the
        result does not alter the shared module-level preset.

    Raises:
        ValueError: If ``model_key`` is not a key of ``MODEL_CONFIGS``. The
            message lists the available keys.
    """
    if model_key not in MODEL_CONFIGS:
        available = ", ".join(get_available_models())
        raise ValueError(
            f"Unknown model key: {model_key}\n"
            f"Available models: {available}"
        )
    # Preset values are flat scalars/strings, so a shallow copy is sufficient.
    return dict(MODEL_CONFIGS[model_key])
56+
57+
def print_available_models():
    """Print every preset in ``MODEL_CONFIGS`` to stdout.

    Output is a banner, then for each preset its key, path, tensor-parallel
    size, GPU memory utilization and description, then a closing rule.
    """
    print("\n" + "="*80)
    print("Available Model Configurations:")
    print("="*80)
    for key, config in MODEL_CONFIGS.items():
        print(f"\n {key}:")
        print(f" Path: {config['model_name_or_path']}")
        print(f" Tensor Parallel Size: {config['vllm_tensor_parallel_size']}")
        print(f" GPU Memory Utilization: {config['gpu_memory_utilization']}")
        print(f" Description: {config['description']}")
    print("="*80 + "\n")
69+
770
@dataclass
871
class PriorZeroLLMConfig:
972
local_rank = -1
@@ -17,9 +80,9 @@ class PriorZeroLLMConfig:
1780
# 模型相关参数
1881
# model_name_or_path: str = "/mnt/afs/wanzunian/niuyazhe/xiongjyu/models/Qwen2.5-0.5B-Instruct"
1982
# model_name_or_path: str = "/mnt/shared-storage-user/puyuan/xiongjyu/models/Qwen2.5-0.5B-Instruct"
20-
# model_name_or_path: str = "/mnt/shared-storage-user/puyuan/xiongjyu/models/Qwen2.5-1.5B-Instruct"
83+
model_name_or_path: str = "/mnt/shared-storage-user/puyuan/xiongjyu/models/Qwen2.5-1.5B-Instruct"
2184
# model_name_or_path: str = "/mnt/shared-storage-user/puyuan/model/Qwen2.5-VL-7B-Instruct" # TODO
22-
model_name_or_path: str = "/mnt/shared-storage-user/puyuan/model/Qwen2.5-7B-Instruct" # TODO
85+
# model_name_or_path: str = "/mnt/shared-storage-user/puyuan/model/Qwen2.5-7B-Instruct" # TODO
2386
attn_implementation: str = "flash_attention_2"
2487
history_length: int = 5
2588
use_cot: bool = False
@@ -35,7 +98,7 @@ class PriorZeroLLMConfig:
3598
vllm_sync_with_ray: bool = False # 是否使用 ray 来同步 vLLM 参数
3699
# vllm_tensor_parallel_size: int = 1 # 每个vllm engine使用几张GPU张量并行
37100

38-
vllm_tensor_parallel_size: int = 8 # 每个vllm engine使用几张GPU张量并行 TODO
101+
vllm_tensor_parallel_size: int = 1 # 每个vllm engine使用几张GPU张量并行 (Fixed: 1.5B model should use 1 GPU)
39102

40103
gpu_memory_utilization: float = 0.3
41104
vllm_enable_sleep: bool = True # 是否可以休眠
@@ -60,10 +123,16 @@ class PriorZeroLLMConfig:
60123
ring_attn_size: int = 1
61124

62125
llm_learn_num_samples: int = 256 # 每次取buffer中最新的256条轨迹训练
63-
# train_batch_size: int = 64 # 总的train_size, 结果= micro_batch_size * GPUS * gradient_accumulation_steps
64-
train_batch_size: int = 128 # 总的train_size, 结果= micro_batch_size * GPUS * gradient_accumulation_steps
126+
train_batch_size: int = 64 # 总的train_size, 结果= micro_batch_size * GPUS * gradient_accumulation_steps
127+
# train_batch_size: int = 128 # 总的train_size, 结果= micro_batch_size * GPUS * gradient_accumulation_steps
65128
micro_train_batch_size: int = 8
66-
gradient_accumulation_steps: int = 8
129+
130+
# debug
131+
# llm_learn_num_samples: int = 64 # 每次取buffer中最新的256条轨迹训练
132+
# train_batch_size: int = 64 # 总的train_size, 结果= micro_batch_size * GPUS * gradient_accumulation_steps
133+
# micro_train_batch_size: int = 4
134+
# gradient_accumulation_steps: int = 2
135+
67136
learning_rate: float = 1e-6
68137
adam_betas: Tuple[float, float] = (0.9, 0.95)
69138
weight_decay: float = 0.01
@@ -80,20 +149,23 @@ def get_priorzero_config(
80149
seed: int = 0,
81150
exp_name: str = None,
82151
use_cot: bool = False,
152+
model_key: Optional[str] = None,
83153
) -> Tuple[EasyDict, EasyDict]:
84154
"""
85-
Generate complete PriorZero configuration.
155+
Generate complete PriorZero configuration with automatic model configuration.
86156
87157
Args:
88158
env_id: Jericho game ID
89159
seed: Random seed
90160
exp_name: Experiment name (auto-generated if None)
91-
enable_llm: Whether to enable LLM policy (if False, degrades to pure UniZero)
92-
enable_rft: Whether to enable RFT training (if False, only use SFT)
161+
use_cot: Whether to use Chain-of-Thought reasoning
162+
model_key: Model configuration key (e.g., 'qwen2.5-0.5b', 'qwen2.5-1.5b', 'qwen2.5-7b')
163+
If None, uses default 'qwen2.5-1.5b'
93164
94165
Returns:
95166
main_config: Main configuration dictionary
96167
create_config: Creation configuration for DI-engine components
168+
llm_config: LLM configuration with auto-configured model parameters
97169
"""
98170
env_configurations = {
99171
'detective.z5': (12, 100),
@@ -273,6 +345,24 @@ def get_priorzero_config(
273345
main_config = EasyDict(priorzero_config)
274346
create_config = EasyDict(create_config)
275347
llm_config = PriorZeroLLMConfig(use_cot=use_cot) # 需要修改 llm 相关的参数,修改以上类即可
348+
349+
# Auto-configure model settings based on model_key
350+
if model_key is None:
351+
model_key = "qwen2.5-1.5b" # Default model
352+
print(f"[Config] Using default model: {model_key}")
353+
354+
# Apply model configuration
355+
model_config = get_model_config(model_key)
356+
llm_config.model_name_or_path = model_config["model_name_or_path"]
357+
llm_config.vllm_tensor_parallel_size = model_config["vllm_tensor_parallel_size"]
358+
llm_config.gpu_memory_utilization = model_config["gpu_memory_utilization"]
359+
360+
print(f"[Config] Model configuration applied:")
361+
print(f" - Model: {model_key}")
362+
print(f" - Path: {llm_config.model_name_or_path}")
363+
print(f" - Tensor Parallel Size: {llm_config.vllm_tensor_parallel_size}")
364+
print(f" - GPU Memory Utilization: {llm_config.gpu_memory_utilization}")
365+
276366
return main_config, create_config, llm_config
277367

278368

@@ -281,21 +371,31 @@ def get_priorzero_debug_config(
281371
seed: int = 0,
282372
exp_name: str = None,
283373
use_cot: bool = False,
374+
model_key: Optional[str] = None,
284375
) -> EasyDict:
285-
286-
main_config, create_config, llm_config = get_priorzero_config(env_id=env_id, seed=seed, exp_name=exp_name, use_cot=use_cot)
376+
377+
main_config, create_config, llm_config = get_priorzero_config(
378+
env_id=env_id, seed=seed, exp_name=exp_name, use_cot=use_cot, model_key=model_key
379+
)
287380
collector_env_num = 4
288381
evaluator_env_num = 1
289-
max_steps=10
382+
max_steps = 10
290383

291-
num_unroll_steps = 5
384+
num_unroll_steps = 4
292385
infer_context_length = 2
293-
batch_size = 16
386+
batch_size = 8
294387
collect_num_simulations=2
295388
eval_num_simulations=2
296389
num_layers=1
297-
game_segment_length = 20
298-
390+
game_segment_length = 10
391+
392+
llm_config.prompt_max_len = 512
393+
llm_config.generate_max_len = 128
394+
llm_config.llm_learn_num_samples = 16 # 每次取buffer中最新的256条轨迹训练
395+
llm_config.train_batch_size = 16 # 总的train_size, 结果= micro_batch_size * GPUS * gradient_accumulation_steps
396+
llm_config.micro_train_batch_size = 2
397+
llm_config.gradient_accumulation_steps: int = 1
398+
299399
create_config.collector_env_num = collector_env_num
300400
create_config.evaluator_env_num = evaluator_env_num
301401
create_config.max_steps = max_steps
@@ -314,6 +414,5 @@ def get_priorzero_debug_config(
314414
main_config.policy.collector_env_num = collector_env_num
315415
main_config.policy.update_per_collect = 2
316416
main_config.policy.game_segment_length = game_segment_length
317-
llm_config.llm_learn_num_samples = 32
318417

319418
return main_config, create_config, llm_config

0 commit comments

Comments
 (0)