Skip to content

Commit 84ca42c

Browse files
committed
fix(pu): fix pad bug in buffer
1 parent bed92f5 commit 84ca42c

File tree

5 files changed

+199
-15
lines changed

5 files changed

+199
-15
lines changed

lzero/mcts/buffer/game_buffer_priorzero.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,12 +214,65 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float, fetch_latest: boo
214214
else:
215215
policy_non_re_context = None
216216

217+
# [LAYER 4 MONITORING] Track padding statistics
218+
if not fetch_latest and hasattr(self, '_padding_warning_count') and self._padding_warning_count > 0:
219+
print(
220+
f"[Padding Monitor] {self._padding_warning_count} segments required padding adjustment in this batch. "
221+
f"Consider increasing game_segment_length or reducing num_unroll_steps."
222+
)
223+
self._padding_warning_count = 0 # Reset counter
224+
217225
return reward_value_context, policy_re_context, policy_non_re_context, current_batch
218226

219227
def _clear(self):
220228
self.game_pos_priorities = []
221229
self.game_segment_buffer = []
222230
self.game_segment_game_pos_look_up = []
231+
232+
def _adjust_pos_to_avoid_padding(self, game_segment, pos_in_game_segment):
233+
"""
234+
[CRITICAL FIX] Adjust position to ensure no padding is needed.
235+
236+
This prevents misalignment between raw_obs and cot_prefix caused by padding.
237+
238+
Args:
239+
game_segment: The game segment to sample from
240+
pos_in_game_segment: The initial sampled position
241+
242+
Returns:
243+
Adjusted position that guarantees sufficient data length
244+
"""
245+
# Calculate required length for unrolling
246+
required_len = self._cfg.model.frame_stack_num + self._cfg.num_unroll_steps
247+
248+
# Check actual available data length
249+
# CRITICAL: Must check raw_obs_segment, not action_segment!
250+
# Because cot_prefix aligns with raw_obs in structure
251+
actual_obs_len = len(game_segment.raw_obs_segment)
252+
actual_cot_len = len(game_segment.cot_prefix_segment)
253+
254+
# Use the minimum of both to be safe
255+
actual_len = min(actual_obs_len, actual_cot_len)
256+
257+
# If segment is too short, we can't avoid padding entirely
258+
if actual_len < required_len:
259+
# Log warning for monitoring
260+
if not hasattr(self, '_padding_warning_count'):
261+
self._padding_warning_count = 0
262+
self._padding_warning_count += 1
263+
264+
# Return position 0 and accept minimal padding
265+
# This is better than random position with more padding
266+
return 0
267+
268+
# Ensure position doesn't exceed safe range
269+
max_safe_pos = actual_len - required_len
270+
271+
if pos_in_game_segment > max_safe_pos:
272+
# Clamp to safe range
273+
pos_in_game_segment = np.random.randint(0, max_safe_pos + 1)
274+
275+
return pos_in_game_segment
223276

224277

225278
def _fetch_latest_orig_data(self, batch_size: int) -> Tuple:
@@ -271,7 +324,10 @@ def _fetch_latest_orig_data(self, batch_size: int) -> Tuple:
271324
# Indices exceeding `game_segment_length` are padded with the next segment and are not updated
272325
# in the current implementation. Therefore, we need to sample `pos_in_game_segment` within
273326
# [0, game_segment_length - num_unroll_steps] to avoid padded data.
274-
327+
328+
# [LAYER 1 FIX] Adjust position to avoid padding-induced misalignment
329+
pos_in_game_segment = self._adjust_pos_to_avoid_padding(game_segment, pos_in_game_segment)
330+
275331
if self._cfg.action_type == 'varied_action_space':
276332
# For some environments (e.g., Jericho), the action space size may be different.
277333
# To ensure we can always unroll `num_unroll_steps` steps starting from the sampled position (without exceeding segment length),

lzero/worker/muzero_evaluator.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,9 @@ def __init__(
100100
self._tb_logger = tb_logger
101101
else:
102102
self._tb_logger = None
103+
self._logger, _ = build_logger(
104+
f'./{self._exp_name}/log/{self._instance_name}', self._instance_name, need_tb=False
105+
)
103106

104107
self._rank = get_rank()
105108
print(f'rank {self._rank}, self.task_id: {self.task_id}')

zoo/jericho/priorzero/priorzero_config.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,11 @@ class PriorZeroLLMConfig:
157157
# 需要注意的是,buffer中取一条经验是 10个样本,因为包含10次交互; num_unroll_steps = 10
158158
train_batch_size: int = 640 # 总的train_size, 结果= micro_batch_size * GPUS * gradient_accumulation_steps
159159
# micro_train_batch_size: int = 16 # 一次micro_train_batch_size 用来计算梯度;只有一次 train_batch_size 才会更新参数
160-
micro_train_batch_size: int = 4 # 一次micro_train_batch_size 用来计算梯度;只有一次 train_batch_size 才会更新参数
160+
# micro_train_batch_size: int = 4 # 一次micro_train_batch_size 用来计算梯度;只有一次 train_batch_size 才会更新参数
161+
162+
# 2卡 1.5b mbs=2
163+
micro_train_batch_size: int = 2 # 一次micro_train_batch_size 用来计算梯度;只有一次 train_batch_size 才会更新参数
164+
161165
broadcast_every: int = 1 # 每次训练多少次 train_batch_size 才同步 vllm 参数;也就是说 vllm 中的模型 off 多少次参数更新
162166

163167
learning_rate: float = 1e-6
@@ -188,6 +192,7 @@ class PriorZeroLLMConfig:
188192
# entropy_loss_coef: Optional[float] = None # None = disabled, typical values: 0.001-0.01
189193

190194
# LLM Prior Mixing Configuration
195+
191196
# ===== baseline root policy-head-logits =====
192197
# prior_mixing_cfg: Optional[EasyDict] = field(default_factory=lambda: EasyDict({
193198
# 'enable_soft_mixing': True, # Enable soft mixing instead of hard override
@@ -207,7 +212,7 @@ class PriorZeroLLMConfig:
207212
# 'enable_soft_mixing': True, # Enable soft mixing instead of hard override
208213
'enable_soft_mixing': False, # Enable soft mixing instead of hard override
209214
# 'mixing_alpha': 0.5, # Weight for LLM prior (0=network only, 1=LLM only)
210-
'mixing_alpha': 0., # Weight for LLM prior (0=network only, 1=LLM only)
215+
'mixing_alpha': 1., # Weight for LLM prior (0=network only, 1=LLM only)
211216
'alpha_schedule': None, # 'linear', 'cosine', 'exponential', or None (fixed)
212217
# 'alpha_schedule': 'cosine', # Smooth decay
213218
'alpha_init': 0.8, # Initial alpha (high LLM influence)

zoo/jericho/priorzero/priorzero_datafactory.py

Lines changed: 123 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -297,16 +297,65 @@ def build_chat_context(self, user_prompt: str) -> str:
297297
add_generation_prompt=True,
298298
)
299299

300+
def _detect_padding_sample(self, raw_obs_list, cot_prefix_list, action_logprob_list, b, t):
301+
"""
302+
[LAYER 3 DEFENSE] Detect if a sample contains padding data.
303+
304+
Padding detection heuristics:
305+
1. Check for consecutive duplicate observations (unlikely in real gameplay)
306+
2. Check for None values in cot_prefix (initial state marker)
307+
3. Check for duplicate cot_prefix with different obs (clear misalignment)
308+
309+
Args:
310+
raw_obs_list: List of raw observations
311+
cot_prefix_list: List of CoT prefixes
312+
action_logprob_list: List of action logprobs
313+
b: Batch index
314+
t: Time index
315+
316+
Returns:
317+
True if padding is detected, False otherwise
318+
"""
319+
T = len(raw_obs_list[b])
320+
321+
# Heuristic 1: Check for consecutive duplicates (strong padding indicator)
322+
# If both obs and cot_prefix are duplicated at the same time, very likely padding
323+
if t + 1 < T and raw_obs_list[b][t] == raw_obs_list[b][t + 1]:
324+
# Check if cot_prefix is also duplicated
325+
if cot_prefix_list is not None and t + 2 < len(cot_prefix_list[b]):
326+
if cot_prefix_list[b][t + 1] == cot_prefix_list[b][t + 2]:
327+
# Double duplication is a strong signal of padding
328+
return True
329+
330+
# Heuristic 2: Check for None cot_prefix (should only be at t=0)
331+
if t > 0 and cot_prefix_list is not None and t + 1 < len(cot_prefix_list[b]):
332+
if cot_prefix_list[b][t + 1] is None:
333+
return True
334+
335+
# Heuristic 3: Check if action_logprob is empty or None
336+
if action_logprob_list is not None and t + 1 < len(action_logprob_list[b]):
337+
logprob = action_logprob_list[b][t + 1]
338+
if logprob is None or (isinstance(logprob, dict) and len(logprob) == 0):
339+
return True
340+
341+
# Heuristic 4: Check for triple+ consecutive duplicates (very strong signal)
342+
if t + 2 < T:
343+
if (raw_obs_list[b][t] == raw_obs_list[b][t + 1] == raw_obs_list[b][t + 2]):
344+
return True
345+
346+
return False
347+
348+
300349
def build_llm_samples(self,
301350
raw_obs_list: List[List[str]],
302351
history_obs_list: List[List[List[Tuple[str, str, float]]]],
303352
action_logprob_list: Optional[List[List[Any]]] = None,
304-
pred_values: Optional[torch.Tensor] = None, # [B, T-1]
305-
target_values: Optional[torch.Tensor] = None, # [B, T-1]
353+
pred_values: Optional[torch.Tensor] = None, # [B, T-1]
354+
target_values: Optional[torch.Tensor] = None, # [B, T-1]
306355
cot_prefix_list: Optional[List[List[str]]] = None, # CoT reuse optimization
307356
) -> List[Dict[str, Any]]:
308357
"""
309-
Build training samples from collected data.
358+
[ENHANCED] Build training samples with padding detection and filtering.
310359
311360
Args:
312361
raw_obs_list: Raw observations
@@ -324,14 +373,48 @@ def build_llm_samples(self,
324373
return samples
325374
T = len(raw_obs_list[0])
326375

376+
# [LAYER 3] Statistics for monitoring
377+
total_samples = 0
378+
filtered_samples = 0
379+
filtered_reasons = {
380+
'padding': 0,
381+
'empty_action': 0,
382+
'extreme_logprob': 0,
383+
'nan_value': 0,
384+
}
385+
327386
for b in range(B):
328387
for t in range(T - 1):
388+
total_samples += 1
389+
390+
# [LAYER 3 DEFENSE] Detect and skip padding samples
391+
if self._detect_padding_sample(raw_obs_list, cot_prefix_list, action_logprob_list, b, t):
392+
filtered_samples += 1
393+
filtered_reasons['padding'] += 1
394+
continue
395+
329396
current_obs = raw_obs_list[b][t]
330397
current_hist = history_obs_list[b][t]
331398
next_hist = history_obs_list[b][t + 1]
332399

333-
_, true_action, reward_value = next_hist[-1]
400+
# Validate history structure
401+
if not next_hist or len(next_hist) == 0:
402+
filtered_samples += 1
403+
filtered_reasons['empty_action'] += 1
404+
continue
405+
406+
try:
407+
_, true_action, reward_value = next_hist[-1]
408+
except (ValueError, IndexError) as e:
409+
if self.rank == 0:
410+
self._logger.warning(f"Unexpected history structure at b={b}, t={t}: {next_hist[-1]}")
411+
filtered_samples += 1
412+
filtered_reasons['empty_action'] += 1
413+
continue
414+
334415
if not true_action:
416+
filtered_samples += 1
417+
filtered_reasons['empty_action'] += 1
335418
continue
336419

337420
instruction = self.build_llm_prompt(
@@ -341,18 +424,28 @@ def build_llm_samples(self,
341424
prompt = self.build_chat_context(instruction)
342425
old_logprob = None
343426
if action_logprob_list is not None:
344-
old_logprob = action_logprob_list[b][t + 1][true_action]
427+
logprob_dict = action_logprob_list[b][t + 1]
428+
if isinstance(logprob_dict, dict):
429+
old_logprob = logprob_dict.get(true_action, None)
430+
else:
431+
old_logprob = None
345432

346433
# FIX: Filter samples with extreme logprobs to prevent ratio explosion
347434
if old_logprob is not None:
348435
# Skip if empty
349436
if len(old_logprob) == 0:
437+
filtered_samples += 1
438+
filtered_reasons['extreme_logprob'] += 1
350439
continue
351440
# Skip if contains extreme values (< -50 indicates very low probability)
352441
if min(old_logprob) < -50.0:
442+
filtered_samples += 1
443+
filtered_reasons['extreme_logprob'] += 1
353444
continue
354445
# Skip if contains NaN/Inf
355446
if any(math.isnan(x) or math.isinf(x) for x in old_logprob):
447+
filtered_samples += 1
448+
filtered_reasons['nan_value'] += 1
356449
continue
357450

358451
target_value = None
@@ -365,8 +458,12 @@ def build_llm_samples(self,
365458

366459
# FIX: Skip samples with NaN/Inf values
367460
if target_value is not None and (math.isnan(target_value) or math.isinf(target_value)):
461+
filtered_samples += 1
462+
filtered_reasons['nan_value'] += 1
368463
continue
369464
if pred_value is not None and (math.isnan(pred_value) or math.isinf(pred_value)):
465+
filtered_samples += 1
466+
filtered_reasons['nan_value'] += 1
370467
continue
371468

372469
# CoT reuse optimization: get CoT prefix from stored data
@@ -375,6 +472,9 @@ def build_llm_samples(self,
375472
prefix_cot = None
376473
if self.use_cot and cot_prefix_list is not None:
377474
prefix_cot = cot_prefix_list[b][t+1]
475+
# [FIX] Handle None prefix (initial state or padding)
476+
if prefix_cot is None:
477+
prefix_cot = ""
378478

379479
samples.append(
380480
{
@@ -388,6 +488,24 @@ def build_llm_samples(self,
388488
"prefix_cot": prefix_cot, # CoT reuse optimization
389489
}
390490
)
491+
492+
# [LAYER 3 MONITORING] Log filtering statistics
493+
if self.rank == 0 and total_samples > 0:
494+
filter_rate = (filtered_samples / total_samples) * 100
495+
self._logger.info(
496+
f"[Sample Filtering] Total: {total_samples} | Filtered: {filtered_samples} ({filter_rate:.2f}%) | "
497+
f"padding={filtered_reasons['padding']}, empty={filtered_reasons['empty_action']}, "
498+
f"extreme_logprob={filtered_reasons['extreme_logprob']}, nan={filtered_reasons['nan_value']}"
499+
)
500+
501+
# WARNING: If too many samples are filtered due to padding, something is wrong
502+
if filtered_reasons['padding'] > total_samples * 0.1: # >10% padding
503+
self._logger.warning(
504+
f"⚠️ High padding rate detected ({filtered_reasons['padding']}/{total_samples} = "
505+
f"{filtered_reasons['padding']/total_samples*100:.1f}%)! "
506+
f"Check sampling strategy in game buffer."
507+
)
508+
391509
return samples
392510

393511
def make_llm_train_samples(self, priorzero_batch, ddp: bool = False) -> List[Dict[str, Any]]:

zoo/jericho/priorzero/priorzero_entry_sync_ddp.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -205,13 +205,15 @@ def train_priorzero(
205205
while True:
206206
cmd = 0 # 0 表示当前循环continue, 1 表示继续,2 表示break
207207
priorzero_batch = None
208-
if learner.train_iter > 0 and evaluator.should_eval(learner.train_iter):
209-
logger.info(f"\n[Evaluation] Rank {rank} | Iter {learner.train_iter}")
210-
stop, reward = evaluator.eval(
211-
save_ckpt_fn=learner.save_checkpoint,
212-
train_iter=learner.train_iter,
213-
envstep=collector.envstep
214-
)
208+
209+
# TODO: priorzero evaluator
210+
# if learner.train_iter > 0 and evaluator.should_eval(learner.train_iter):
211+
# logger.info(f"\n[Evaluation] Rank {rank} | Iter {learner.train_iter}")
212+
# stop, reward = evaluator.eval(
213+
# save_ckpt_fn=learner.save_checkpoint,
214+
# train_iter=learner.train_iter,
215+
# envstep=collector.envstep
216+
# )
215217

216218
if llm_cfg.vllm_enable_sleep and vllm_engine is not None:
217219
vllm_engine.wake_up()

0 commit comments

Comments
 (0)