Skip to content

Commit 7cf1e2d

Browse files
committed
fix(pu): add policy_logits_clip_method option
1 parent 5ed77bf commit 7cf1e2d

File tree

4 files changed

+196
-29
lines changed

4 files changed

+196
-29
lines changed

lzero/model/unizero_world_models/lpips.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,11 @@
1313
from torchvision import models
1414
from tqdm import tqdm
1515

# Redirect the Hugging Face and torch-hub caches to shared storage so the
# pretrained VGG weights used by LPIPS are found without re-downloading.
# Use setdefault so a user-provided HF_HOME / TORCH_HOME in the environment
# is respected instead of being silently clobbered at import time.
# NOTE(review): hardcoded cluster-specific path — consider making this
# configurable (e.g. via the experiment config) before wider use.
os.environ.setdefault('HF_HOME', '/mnt/shared-storage-user/puyuan/code/LightZero/tokenizer_pretrained_vgg')
custom_torch_home = "/mnt/shared-storage-user/puyuan/code/LightZero/tokenizer_pretrained_vgg"
os.environ.setdefault('TORCH_HOME', custom_torch_home)
# Pre-create the checkpoints directory torch.hub downloads into.
os.makedirs(os.path.join(os.environ['TORCH_HOME'], 'hub', 'checkpoints'), exist_ok=True)
1621
class LPIPS(nn.Module):
1722
# Learned perceptual metric
1823
def __init__(self, use_dropout: bool = True):

lzero/model/unizero_world_models/world_model.py

Lines changed: 134 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -562,8 +562,17 @@ def _initialize_config_parameters(self) -> None:
562562
# ==================== [NEW] Policy Stability Fix Options ====================
563563
# Load fix options from config (with defaults for backward compatibility)
564564
self.use_policy_logits_clip = getattr(self.config, 'use_policy_logits_clip', False)
565+
self.policy_logits_clip_method = getattr(self.config, 'policy_logits_clip_method', 'soft_tanh')
565566
self.policy_logits_clip_min = getattr(self.config, 'policy_logits_clip_min', -10.0)
566567
self.policy_logits_clip_max = getattr(self.config, 'policy_logits_clip_max', 10.0)
568+
self.policy_logits_soft_beta = getattr(self.config, 'policy_logits_soft_beta', 1.0)
569+
self.policy_logits_adaptive_percentile = getattr(self.config, 'policy_logits_adaptive_percentile', 95)
570+
571+
# Running statistics for adaptive clipping
572+
if self.policy_logits_clip_method == 'adaptive':
573+
self.register_buffer('policy_logits_running_max', torch.tensor(10.0))
574+
self.register_buffer('policy_logits_running_min', torch.tensor(-10.0))
575+
self.policy_logits_momentum = 0.99
567576

568577
# [NEW] Fix5: Temperature scaling for policy loss
569578
self.use_policy_loss_temperature = getattr(self.config, 'use_policy_loss_temperature', False)
@@ -581,9 +590,14 @@ def _initialize_config_parameters(self) -> None:
581590

582591
# [NEW] Debug: Print configuration on initialization
583592
if self.use_policy_logits_clip:
584-
logging.info(f"[Policy Logits Clip] ENABLED: range=[{self.policy_logits_clip_min}, {self.policy_logits_clip_max}]")
593+
logging.info(
594+
f"[Policy Logits Control] ENABLED\n"
595+
f" Method: {self.policy_logits_clip_method}\n"
596+
f" Range: [{self.policy_logits_clip_min}, {self.policy_logits_clip_max}]\n"
597+
f" Soft Beta: {self.policy_logits_soft_beta if 'soft' in self.policy_logits_clip_method else 'N/A'}"
598+
)
585599
else:
586-
logging.warning(f"[Policy Logits Clip] DISABLED! Using default values.")
600+
logging.warning(f"[Policy Logits Control] DISABLED! Logits may grow unbounded.")
587601

588602
if self.use_policy_loss_temperature and self.policy_loss_temperature != 1.0:
589603
logging.info(f"[Policy Loss Temperature] ENABLED: temperature={self.policy_loss_temperature}")
@@ -598,6 +612,119 @@ def _initialize_patterns(self) -> None:
598612
self.value_policy_tokens_pattern = torch.zeros(self.config.tokens_per_block)
599613
self.value_policy_tokens_pattern[-2] = 1
600614

615+
def _apply_policy_logits_control(self, logits_policy: torch.Tensor) -> torch.Tensor:
    """
    Apply policy logits control using various methods to prevent explosion.

    This method implements multiple strategies to constrain policy logits:
    1. 'hard': Hard clamp (torch.clamp) - Simple but gradients die at boundaries
    2. 'soft_tanh': Soft clamp using tanh - Smooth, gradients never zero
    3. 'soft_sigmoid': Soft clamp using sigmoid - Similar to tanh but different curve
    4. 'normalize_max': Subtract max then clamp - Preserves relative order, safer
    5. 'normalize_mean': Subtract mean then clamp - Centers distribution
    6. 'adaptive': Adaptive clipping based on running statistics
    7. 'none': No clipping

    Arguments:
        - logits_policy (:obj:`torch.Tensor`): Raw policy logits from head_policy
            Shape: [batch_size, num_steps, action_dim] or [batch_size * num_steps, action_dim]

    Returns:
        - torch.Tensor: Controlled policy logits with the same shape

    Examples:
        >>> logits = torch.randn(32, 10, 6) * 20  # Large logits
        >>> controlled = self._apply_policy_logits_control(logits)
        >>> # For clipping methods (not 'none') the output lies within
        >>> # [clip_min, clip_max]; with symmetric bounds this implies
        >>> # controlled.abs().max() <= self.policy_logits_clip_max
    """
    # Fast path: feature disabled or explicitly configured off.
    if not self.use_policy_logits_clip or self.policy_logits_clip_method == 'none':
        return logits_policy

    method = self.policy_logits_clip_method
    clip_min = self.policy_logits_clip_min
    clip_max = self.policy_logits_clip_max

    # ==================== Method 1: Hard Clamp ====================
    if method == 'hard':
        # Simple hard clipping
        # Pros: Simple, fast
        # Cons: Gradients become zero outside [clip_min, clip_max]
        return torch.clamp(logits_policy, min=clip_min, max=clip_max)

    # ==================== Method 2: Soft Tanh Clamp ====================
    elif method == 'soft_tanh':
        # Soft clamp: clip_max * tanh(x / (clip_max * beta))
        # Pros: Gradients never zero, smooth transition
        # Cons: Slightly more computation
        # When x is small: output ≈ x / beta (≈ x only when beta == 1)
        # When x is large: tanh → ±1, so output saturates smoothly at ±clip_max
        C = clip_max  # Use positive bound as scale
        beta = self.policy_logits_soft_beta  # Smoothness parameter
        return C * torch.tanh(logits_policy / (C * beta))

    # ==================== Method 3: Soft Sigmoid Clamp ====================
    elif method == 'soft_sigmoid':
        # Soft clamp using sigmoid: maps (-∞, ∞) to (clip_min, clip_max)
        # Formula: clip_min + (clip_max - clip_min) * sigmoid(x / beta)
        # Pros: Smooth, bounded
        # Cons: Compresses entire range, may lose relative ordering
        beta = self.policy_logits_soft_beta
        range_size = clip_max - clip_min
        return clip_min + range_size * torch.sigmoid(logits_policy / beta)

    # ==================== Method 4: Normalize Max + Hard Clamp ====================
    elif method == 'normalize_max':
        # Subtract max value first (exploits softmax translation invariance)
        # softmax(x) = softmax(x - c) for any constant c
        # By subtracting max, we ensure the largest logit is 0, others are negative
        # Then apply hard clamp (mainly affects the negative tail)
        # Pros: Preserves relative ordering, safer than pure hard clamp
        # Cons: Still has gradient issues for very negative values
        logits_normalized = logits_policy - logits_policy.max(dim=-1, keepdim=True)[0].detach()
        return torch.clamp(logits_normalized, min=clip_min, max=clip_max)

    # ==================== Method 5: Normalize Mean + Hard Clamp ====================
    elif method == 'normalize_mean':
        # Subtract mean (centers the distribution)
        # Pros: Centers logits around 0, prevents drift
        # Cons: May change relative probabilities more than normalize_max
        logits_normalized = logits_policy - logits_policy.mean(dim=-1, keepdim=True).detach()
        return torch.clamp(logits_normalized, min=clip_min, max=clip_max)

    # ==================== Method 6: Adaptive Clipping ====================
    elif method == 'adaptive':
        # Dynamically adjust clipping thresholds based on running statistics.
        # Requires the running-stat buffers registered in
        # _initialize_config_parameters when clip_method == 'adaptive'.
        # Update running stats (only during training)
        if self.training:
            with torch.no_grad():
                # Compute percentile-based bounds.
                # Use reshape(-1) instead of view(-1): view fails on
                # non-contiguous tensors (e.g. outputs of transposes/slices).
                flat_logits = logits_policy.reshape(-1)
                percentile = self.policy_logits_adaptive_percentile
                current_max = torch.quantile(flat_logits, percentile / 100.0)
                current_min = torch.quantile(flat_logits, (100 - percentile) / 100.0)

                # Update running statistics with momentum (EMA).
                self.policy_logits_running_max = (
                    self.policy_logits_momentum * self.policy_logits_running_max +
                    (1 - self.policy_logits_momentum) * current_max
                )
                self.policy_logits_running_min = (
                    self.policy_logits_momentum * self.policy_logits_running_min +
                    (1 - self.policy_logits_momentum) * current_min
                )

        # Use running stats for clipping, never exceeding the configured bounds.
        adaptive_max = torch.clamp(self.policy_logits_running_max, max=clip_max)
        adaptive_min = torch.clamp(self.policy_logits_running_min, min=clip_min)
        return torch.clamp(logits_policy, min=adaptive_min, max=adaptive_max)

    else:
        raise ValueError(
            f"Unknown policy_logits_clip_method: {method}. "
            f"Valid options: 'hard', 'soft_tanh', 'soft_sigmoid', 'normalize_max', "
            f"'normalize_mean', 'adaptive', 'none'"
        )
601728
def _create_head(self, block_mask: torch.Tensor, output_dim: int, norm_layer=None) -> Head:
602729
"""Create head modules for the transformer."""
603730
modules = [
@@ -945,15 +1072,12 @@ def forward(
9451072
logits_rewards = self.head_rewards(x, num_steps=num_steps, prev_steps=prev_steps)
9461073
logits_policy = self.head_policy(x, num_steps=num_steps, prev_steps=prev_steps)
9471074

948-
# ==================== [NEW] Fix1: Clip Policy Logits ====================
949-
# Prevent policy logits from exploding, which can cause gradient issues
1075+
# ==================== [NEW] Advanced Policy Logits Control ====================
1076+
# Apply configurable policy logits control to prevent explosion
1077+
# Multiple methods available: hard, soft_tanh, soft_sigmoid, normalize_max, etc.
9501078
if self.use_policy_logits_clip:
951-
logits_policy = torch.clamp(
952-
logits_policy,
953-
min=self.policy_logits_clip_min,
954-
max=self.policy_logits_clip_max
955-
)
956-
# ========================================================================
1079+
logits_policy = self._apply_policy_logits_control(logits_policy)
1080+
# ================================================================================
9571081

9581082
logits_value = self.head_value(x, num_steps=num_steps, prev_steps=prev_steps)
9591083

zoo/atari/config/atari_unizero_multitask_segment_ddp_config.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -184,8 +184,12 @@ def create_config(
184184
# (float) Learning rate for adaptive alpha optimizer
185185
adaptive_entropy_alpha_lr=1e-3,
186186
target_entropy_start_ratio=0.98,
187-
target_entropy_end_ratio=0.5, # for action_space=18
188-
target_entropy_decay_steps=100000, # e.g., reach final value after 150k iterations (300k envsteps)
187+
# target_entropy_end_ratio=0.5, # for action_space=18
188+
# target_entropy_decay_steps=100000, # e.g., reach final value after 150k iterations (300k envsteps)
189+
190+
target_entropy_end_ratio=0.1, # for action_space=18
191+
target_entropy_decay_steps=150000, # Complete decay after 150k iterations (needs coordination with replay ratio)
192+
189193

190194
# ==================== Encoder-Clip Annealing Configuration ====================
191195
# (bool) Whether to enable encoder-clip value annealing.
@@ -230,7 +234,8 @@ def create_config(
230234
num_simulations=num_simulations,
231235
reanalyze_ratio=reanalyze_ratio,
232236
n_episode=n_episode,
233-
replay_buffer_size=int(5e5),
237+
# replay_buffer_size=int(5e5),
238+
replay_buffer_size=int(1e5), # TODO
234239
eval_freq=int(1e4), # Evaluation frequency for 8 games
235240
collector_env_num=collector_env_num,
236241
evaluator_env_num=evaluator_env_num,
@@ -263,8 +268,9 @@ def generate_configs(
263268

264269
# --- Experiment Name Template ---
265270
# Replace placeholders like [BENCHMARK_TAG] and [MODEL_TAG] to define the experiment name.
266-
benchmark_tag = "data_unizero_mt_1226"
267-
model_tag = f"vit_tran-nlayer{num_layers}_moe8_encoder-100k-30-10-true_alpha-100k-098-05_prior_adamw-wd1e-2-all_tbs512_brf{buffer_reanalyze_freq}_label-smooth_head-inner-ln"
271+
benchmark_tag = "data_unizero_mt_1229"
272+
model_tag = f"vit_tran-nlayer{num_layers}_moe8_encoder-100k-30-10_alpha-150k-098-01_tbs1024_brf{buffer_reanalyze_freq}_label-smooth_head-inner-ln"
273+
# model_tag = f"vit_tran-nlayer{num_layers}_moe8_encoder-100k-30-10-true_alpha-100k-098-05_prior_adamw-wd1e-2-all_tbs512_brf{buffer_reanalyze_freq}_label-smooth_head-inner-ln"
268274
exp_name_prefix = f'{benchmark_tag}/atari_{len(env_id_list)}games_{model_tag}_seed{seed}/'
269275

270276
for task_id, env_id in enumerate(env_id_list):
@@ -360,7 +366,9 @@ def create_env_manager() -> EasyDict:
360366
# to fit within GPU memory constraints.
361367
if len(env_id_list) == 8:
362368
if num_layers in [2, 4]:
363-
effective_batch_size = 512
369+
# effective_batch_size = 512
370+
effective_batch_size = 1024 # TODO 128*8=2048
371+
# effective_batch_size = 2048 # TODO 256*8=2048
364372
elif num_layers == 8:
365373
effective_batch_size = 512
366374
elif len(env_id_list) == 26:
@@ -372,7 +380,9 @@ def create_env_manager() -> EasyDict:
372380
else:
373381
raise ValueError(f"Batch size not configured for {len(env_id_list)} environments.")
374382

375-
batch_sizes, grad_acc_steps = compute_batch_config(env_id_list, effective_batch_size, gpu_num=6) # TODO
383+
# batch_sizes, grad_acc_steps = compute_batch_config(env_id_list, effective_batch_size, gpu_num=6) # TODO
384+
batch_sizes, grad_acc_steps = compute_batch_config(env_id_list, effective_batch_size, gpu_num=4) # TODO
385+
376386
total_batch_size = effective_batch_size # Currently for logging purposes
377387

378388
# ==================== Model and Training Settings ====================

0 commit comments

Comments
 (0)