11import os
2- from typing import Dict , Tuple
2+ from typing import Dict , Tuple , Optional
33from easydict import EasyDict
44import torch .distributed as dist
55from dataclasses import dataclass
66
7+ # ============================================================================
8+ # Model Configuration Presets
9+ # ============================================================================
# ============================================================================
# Model Configuration Presets
# ============================================================================
# Each row: (key, checkpoint path, vLLM tensor-parallel size,
#            GPU memory utilization fraction, human-readable description).
_MODEL_PRESETS = [
    ("qwen2.5-0.5b",
     "/mnt/shared-storage-user/puyuan/xiongjyu/models/Qwen2.5-0.5B-Instruct",
     1, 0.3, "Qwen2.5-0.5B-Instruct (smallest, fastest)"),
    ("qwen2.5-1.5b",
     "/mnt/shared-storage-user/puyuan/xiongjyu/models/Qwen2.5-1.5B-Instruct",
     1, 0.3, "Qwen2.5-1.5B-Instruct (balanced performance)"),
    ("qwen2.5-3b",
     "/mnt/shared-storage-user/puyuan/model/Qwen2.5-3B-Instruct",
     1, 0.5, "Qwen2.5-3B-Instruct (better quality)"),
    ("qwen2.5-7b",
     "/mnt/shared-storage-user/puyuan/model/Qwen2.5-7B-Instruct",
     2, 0.5, "Qwen2.5-7B-Instruct (high quality, needs 2+ GPUs)"),
    ("qwen2.5-14b",
     "/mnt/shared-storage-user/puyuan/model/Qwen2.5-14B-Instruct",
     4, 0.5, "Qwen2.5-14B-Instruct (best quality, needs 4+ GPUs)"),
]

# Public registry keyed by preset name; consumed by get_model_config().
MODEL_CONFIGS = {
    key: {
        "model_name_or_path": path,
        "vllm_tensor_parallel_size": tp_size,
        "gpu_memory_utilization": mem_util,
        "description": description,
    }
    for key, path, tp_size, mem_util, description in _MODEL_PRESETS
}
42+
def get_available_models():
    """Return the keys of all registered model presets.

    Returns:
        list[str]: Preset names accepted by :func:`get_model_config`,
        in registration order.
    """
    return [*MODEL_CONFIGS]
46+
def get_model_config(model_key: str) -> Dict:
    """Look up a model preset by key.

    Args:
        model_key: One of the registered preset names
            (see :func:`get_available_models`).

    Returns:
        A shallow copy of the preset dict. Returning a copy means a
        caller that tweaks the returned settings cannot silently
        corrupt the shared ``MODEL_CONFIGS`` registry for later calls.

    Raises:
        ValueError: If ``model_key`` is not a registered preset.
    """
    if model_key not in MODEL_CONFIGS:
        available = ", ".join(get_available_models())
        raise ValueError(
            f"Unknown model key: {model_key}\n"
            f"Available models: {available}"
        )
    # Shallow copy: all preset values are immutable scalars/strings,
    # so a one-level copy fully isolates the registry entry.
    return dict(MODEL_CONFIGS[model_key])
56+
def print_available_models():
    """Pretty-print every registered model preset to stdout."""
    divider = "=" * 80
    print("\n" + divider)
    print("Available Model Configurations:")
    print(divider)
    for name, preset in MODEL_CONFIGS.items():
        print(f"\n{name}:")
        print(f"  Path: {preset['model_name_or_path']}")
        print(f"  Tensor Parallel Size: {preset['vllm_tensor_parallel_size']}")
        print(f"  GPU Memory Utilization: {preset['gpu_memory_utilization']}")
        print(f"  Description: {preset['description']}")
    print(divider + "\n")
69+
770@dataclass
871class PriorZeroLLMConfig :
972 local_rank = - 1
@@ -17,9 +80,9 @@ class PriorZeroLLMConfig:
1780 # 模型相关参数
1881 # model_name_or_path: str = "/mnt/afs/wanzunian/niuyazhe/xiongjyu/models/Qwen2.5-0.5B-Instruct"
1982 # model_name_or_path: str = "/mnt/shared-storage-user/puyuan/xiongjyu/models/Qwen2.5-0.5B-Instruct"
20- # model_name_or_path: str = "/mnt/shared-storage-user/puyuan/xiongjyu/models/Qwen2.5-1.5B-Instruct"
83+ model_name_or_path : str = "/mnt/shared-storage-user/puyuan/xiongjyu/models/Qwen2.5-1.5B-Instruct"
2184 # model_name_or_path: str = "/mnt/shared-storage-user/puyuan/model/Qwen2.5-VL-7B-Instruct" # TODO
22- model_name_or_path : str = "/mnt/shared-storage-user/puyuan/model/Qwen2.5-7B-Instruct" # TODO
85+ # model_name_or_path: str = "/mnt/shared-storage-user/puyuan/model/Qwen2.5-7B-Instruct" # TODO
2386 attn_implementation : str = "flash_attention_2"
2487 history_length : int = 5
2588 use_cot : bool = False
@@ -35,7 +98,7 @@ class PriorZeroLLMConfig:
3598 vllm_sync_with_ray : bool = False # 是否使用 ray 来同步 vLLM 参数
3699 # vllm_tensor_parallel_size: int = 1 # 每个vllm engine使用几张GPU张量并行
37100
38- vllm_tensor_parallel_size : int = 8 # 每个vllm engine使用几张GPU张量并行 TODO
101+ vllm_tensor_parallel_size : int = 1 # 每个vllm engine使用几张GPU张量并行 (Fixed: 1.5B model should use 1 GPU)
39102
40103 gpu_memory_utilization : float = 0.3
41104 vllm_enable_sleep : bool = True # 是否可以休眠
@@ -60,10 +123,16 @@ class PriorZeroLLMConfig:
60123 ring_attn_size : int = 1
61124
62125 llm_learn_num_samples : int = 256 # 每次取buffer中最新的256条轨迹训练
63- # train_batch_size: int = 64 # 总的train_size, 结果= micro_batch_size * GPUS * gradient_accumulation_steps
64- train_batch_size : int = 128 # 总的train_size, 结果= micro_batch_size * GPUS * gradient_accumulation_steps
126+ train_batch_size : int = 64 # 总的train_size, 结果= micro_batch_size * GPUS * gradient_accumulation_steps
127+ # train_batch_size: int = 128 # 总的train_size, 结果= micro_batch_size * GPUS * gradient_accumulation_steps
65128 micro_train_batch_size : int = 8
66- gradient_accumulation_steps : int = 8
129+
130+ # debug
131+ # llm_learn_num_samples: int = 64 # 每次取buffer中最新的256条轨迹训练
132+ # train_batch_size: int = 64 # 总的train_size, 结果= micro_batch_size * GPUS * gradient_accumulation_steps
133+ # micro_train_batch_size: int = 4
134+ # gradient_accumulation_steps: int = 2
135+
67136 learning_rate : float = 1e-6
68137 adam_betas : Tuple [float , float ] = (0.9 , 0.95 )
69138 weight_decay : float = 0.01
@@ -80,20 +149,23 @@ def get_priorzero_config(
80149 seed : int = 0 ,
81150 exp_name : str = None ,
82151 use_cot : bool = False ,
152+ model_key : Optional [str ] = None ,
83153) -> Tuple [EasyDict , EasyDict ]:
84154 """
85- Generate complete PriorZero configuration.
155+ Generate complete PriorZero configuration with automatic model configuration .
86156
87157 Args:
88158 env_id: Jericho game ID
89159 seed: Random seed
90160 exp_name: Experiment name (auto-generated if None)
91- enable_llm: Whether to enable LLM policy (if False, degrades to pure UniZero)
92- enable_rft: Whether to enable RFT training (if False, only use SFT)
161+ use_cot: Whether to use Chain-of-Thought reasoning
162+ model_key: Model configuration key (e.g., 'qwen2.5-0.5b', 'qwen2.5-1.5b', 'qwen2.5-7b')
163+ If None, uses default 'qwen2.5-1.5b'
93164
94165 Returns:
95166 main_config: Main configuration dictionary
96167 create_config: Creation configuration for DI-engine components
168+ llm_config: LLM configuration with auto-configured model parameters
97169 """
98170 env_configurations = {
99171 'detective.z5' : (12 , 100 ),
@@ -273,6 +345,24 @@ def get_priorzero_config(
273345 main_config = EasyDict (priorzero_config )
274346 create_config = EasyDict (create_config )
275347 llm_config = PriorZeroLLMConfig (use_cot = use_cot ) # 需要修改 llm 相关的参数,修改以上类即可
348+
349+ # Auto-configure model settings based on model_key
350+ if model_key is None :
351+ model_key = "qwen2.5-1.5b" # Default model
352+ print (f"[Config] Using default model: { model_key } " )
353+
354+ # Apply model configuration
355+ model_config = get_model_config (model_key )
356+ llm_config .model_name_or_path = model_config ["model_name_or_path" ]
357+ llm_config .vllm_tensor_parallel_size = model_config ["vllm_tensor_parallel_size" ]
358+ llm_config .gpu_memory_utilization = model_config ["gpu_memory_utilization" ]
359+
360+ print (f"[Config] Model configuration applied:" )
361+ print (f" - Model: { model_key } " )
362+ print (f" - Path: { llm_config .model_name_or_path } " )
363+ print (f" - Tensor Parallel Size: { llm_config .vllm_tensor_parallel_size } " )
364+ print (f" - GPU Memory Utilization: { llm_config .gpu_memory_utilization } " )
365+
276366 return main_config , create_config , llm_config
277367
278368
@@ -281,21 +371,31 @@ def get_priorzero_debug_config(
281371 seed : int = 0 ,
282372 exp_name : str = None ,
283373 use_cot : bool = False ,
374+ model_key : Optional [str ] = None ,
284375) -> EasyDict :
285-
286- main_config , create_config , llm_config = get_priorzero_config (env_id = env_id , seed = seed , exp_name = exp_name , use_cot = use_cot )
376+
377+ main_config , create_config , llm_config = get_priorzero_config (
378+ env_id = env_id , seed = seed , exp_name = exp_name , use_cot = use_cot , model_key = model_key
379+ )
287380 collector_env_num = 4
288381 evaluator_env_num = 1
289- max_steps = 10
382+ max_steps = 10
290383
291- num_unroll_steps = 5
384+ num_unroll_steps = 4
292385 infer_context_length = 2
293- batch_size = 16
386+ batch_size = 8
294387 collect_num_simulations = 2
295388 eval_num_simulations = 2
296389 num_layers = 1
297- game_segment_length = 20
298-
390+ game_segment_length = 10
391+
392+ llm_config .prompt_max_len = 512
393+ llm_config .generate_max_len = 128
394+ llm_config .llm_learn_num_samples = 16 # 每次取buffer中最新的256条轨迹训练
395+ llm_config .train_batch_size = 16 # 总的train_size, 结果= micro_batch_size * GPUS * gradient_accumulation_steps
396+ llm_config .micro_train_batch_size = 2
397+ llm_config .gradient_accumulation_steps : int = 1
398+
299399 create_config .collector_env_num = collector_env_num
300400 create_config .evaluator_env_num = evaluator_env_num
301401 create_config .max_steps = max_steps
@@ -314,6 +414,5 @@ def get_priorzero_debug_config(
314414 main_config .policy .collector_env_num = collector_env_num
315415 main_config .policy .update_per_collect = 2
316416 main_config .policy .game_segment_length = game_segment_length
317- llm_config .llm_learn_num_samples = 32
318417
319418 return main_config , create_config , llm_config
0 commit comments