@@ -562,8 +562,17 @@ def _initialize_config_parameters(self) -> None:
562562 # ==================== [NEW] Policy Stability Fix Options ====================
563563 # Load fix options from config (with defaults for backward compatibility)
564564 self .use_policy_logits_clip = getattr (self .config , 'use_policy_logits_clip' , False )
565+ self .policy_logits_clip_method = getattr (self .config , 'policy_logits_clip_method' , 'soft_tanh' )
565566 self .policy_logits_clip_min = getattr (self .config , 'policy_logits_clip_min' , - 10.0 )
566567 self .policy_logits_clip_max = getattr (self .config , 'policy_logits_clip_max' , 10.0 )
568+ self .policy_logits_soft_beta = getattr (self .config , 'policy_logits_soft_beta' , 1.0 )
569+ self .policy_logits_adaptive_percentile = getattr (self .config , 'policy_logits_adaptive_percentile' , 95 )
570+
571+ # Running statistics for adaptive clipping
572+ if self .policy_logits_clip_method == 'adaptive' :
573+ self .register_buffer ('policy_logits_running_max' , torch .tensor (10.0 ))
574+ self .register_buffer ('policy_logits_running_min' , torch .tensor (- 10.0 ))
575+ self .policy_logits_momentum = 0.99
567576
568577 # [NEW] Fix5: Temperature scaling for policy loss
569578 self .use_policy_loss_temperature = getattr (self .config , 'use_policy_loss_temperature' , False )
@@ -581,9 +590,14 @@ def _initialize_config_parameters(self) -> None:
581590
582591 # [NEW] Debug: Print configuration on initialization
583592 if self .use_policy_logits_clip :
584- logging .info (f"[Policy Logits Clip] ENABLED: range=[{ self .policy_logits_clip_min } , { self .policy_logits_clip_max } ]" )
593+ logging .info (
594+ f"[Policy Logits Control] ENABLED\n "
595+ f" Method: { self .policy_logits_clip_method } \n "
596+ f" Range: [{ self .policy_logits_clip_min } , { self .policy_logits_clip_max } ]\n "
597+ f" Soft Beta: { self .policy_logits_soft_beta if 'soft' in self .policy_logits_clip_method else 'N/A' } "
598+ )
585599 else :
586- logging .warning (f"[Policy Logits Clip ] DISABLED! Using default values ." )
600+ logging .warning (f"[Policy Logits Control ] DISABLED! Logits may grow unbounded ." )
587601
588602 if self .use_policy_loss_temperature and self .policy_loss_temperature != 1.0 :
589603 logging .info (f"[Policy Loss Temperature] ENABLED: temperature={ self .policy_loss_temperature } " )
@@ -598,6 +612,119 @@ def _initialize_patterns(self) -> None:
598612 self .value_policy_tokens_pattern = torch .zeros (self .config .tokens_per_block )
599613 self .value_policy_tokens_pattern [- 2 ] = 1
600614
def _apply_policy_logits_control(self, logits_policy: torch.Tensor) -> torch.Tensor:
    """
    Constrain policy logits with the configured method to prevent them from exploding.

    The method is selected by ``self.policy_logits_clip_method``:
        1. 'hard': ``torch.clamp`` to [clip_min, clip_max]. Gradients are zero outside the range.
        2. 'soft_tanh': Smooth saturation towards clip_min / clip_max using tanh. For a symmetric
           range this is ``clip_max * tanh(x / (clip_max * beta))``; for an asymmetric range the
           tanh is centred on the midpoint so that both bounds are respected.
        3. 'soft_sigmoid': ``clip_min + (clip_max - clip_min) * sigmoid(x / beta)``. Monotonic, but
           not identity-like near zero: x = 0 maps to the midpoint of the range.
        4. 'normalize_max': Subtract the (detached) per-row max over the last dim, then hard clamp.
        5. 'normalize_mean': Subtract the (detached) per-row mean over the last dim, then hard clamp.
        6. 'adaptive': Hard clamp to running percentile bounds, themselves limited to
           [clip_min, clip_max]. Running bounds are only updated while ``self.training`` is True.
        7. 'none': Return the input unchanged.

    Arguments:
        - logits_policy (:obj:`torch.Tensor`): Raw policy logits. The last dimension is treated as
          the action dimension by the 'normalize_*' methods.

    Returns:
        - torch.Tensor: Controlled policy logits with the same shape as the input.

    Raises:
        - ValueError: If ``self.policy_logits_clip_method`` is not one of the options above.

    Examples:
        >>> logits = torch.randn(32, 10, 6) * 20  # Large logits
        >>> controlled = self._apply_policy_logits_control(logits)
        >>> assert controlled.abs().max() <= self.policy_logits_clip_max
    """
    if not self.use_policy_logits_clip or self.policy_logits_clip_method == 'none':
        return logits_policy

    method = self.policy_logits_clip_method
    clip_min = self.policy_logits_clip_min
    clip_max = self.policy_logits_clip_max

    # ==================== Method 1: Hard Clamp ====================
    if method == 'hard':
        return torch.clamp(logits_policy, min=clip_min, max=clip_max)

    # ==================== Method 2: Soft Tanh Clamp ====================
    if method == 'soft_tanh':
        # Centre the tanh on the midpoint of the range so that clip_min is honoured as well.
        # With symmetric bounds (center == 0, half_range == clip_max) this reduces exactly to
        # clip_max * tanh(x / (clip_max * beta)).
        center = 0.5 * (clip_max + clip_min)
        half_range = 0.5 * (clip_max - clip_min)
        beta = self.policy_logits_soft_beta
        return center + half_range * torch.tanh((logits_policy - center) / (half_range * beta))

    # ==================== Method 3: Soft Sigmoid Clamp ====================
    if method == 'soft_sigmoid':
        beta = self.policy_logits_soft_beta
        range_size = clip_max - clip_min
        return clip_min + range_size * torch.sigmoid(logits_policy / beta)

    # ==================== Method 4: Normalize Max + Hard Clamp ====================
    if method == 'normalize_max':
        # softmax(x) == softmax(x - c), so shifting by the row max leaves the policy unchanged.
        # The shift is detached so it acts as a constant in the backward pass.
        row_max = logits_policy.max(dim=-1, keepdim=True)[0].detach()
        return torch.clamp(logits_policy - row_max, min=clip_min, max=clip_max)

    # ==================== Method 5: Normalize Mean + Hard Clamp ====================
    if method == 'normalize_mean':
        row_mean = logits_policy.mean(dim=-1, keepdim=True).detach()
        return torch.clamp(logits_policy - row_mean, min=clip_min, max=clip_max)

    # ==================== Method 6: Adaptive Clipping ====================
    if method == 'adaptive':
        if self.training:
            with torch.no_grad():
                # reshape (not view) so non-contiguous inputs work; cast to float32 because
                # torch.quantile only accepts float/double tensors.
                flat_logits = logits_policy.detach().reshape(-1).float()
                percentile = self.policy_logits_adaptive_percentile
                current_max = torch.quantile(flat_logits, percentile / 100.0)
                current_min = torch.quantile(flat_logits, (100 - percentile) / 100.0)

                # Exponential moving average, updated in place so the running tensors keep
                # their identity, dtype and device.
                momentum = self.policy_logits_momentum
                self.policy_logits_running_max.mul_(momentum).add_(
                    current_max.to(self.policy_logits_running_max), alpha=1 - momentum
                )
                self.policy_logits_running_min.mul_(momentum).add_(
                    current_min.to(self.policy_logits_running_min), alpha=1 - momentum
                )

        # Running bounds may never exceed the configured hard limits. torch.clamp returns new
        # tensors here, so later in-place updates of the running stats cannot touch them.
        adaptive_max = torch.clamp(self.policy_logits_running_max, max=clip_max)
        adaptive_min = torch.clamp(self.policy_logits_running_min, min=clip_min)
        # Match the logits' dtype/device so the clamp does not promote low-precision logits.
        adaptive_max = adaptive_max.to(device=logits_policy.device, dtype=logits_policy.dtype)
        adaptive_min = adaptive_min.to(device=logits_policy.device, dtype=logits_policy.dtype)
        return torch.clamp(logits_policy, min=adaptive_min, max=adaptive_max)

    raise ValueError(
        f"Unknown policy_logits_clip_method: {method}. "
        f"Valid options: 'hard', 'soft_tanh', 'soft_sigmoid', 'normalize_max', "
        f"'normalize_mean', 'adaptive', 'none'"
    )
727+
601728 def _create_head (self , block_mask : torch .Tensor , output_dim : int , norm_layer = None ) -> Head :
602729 """Create head modules for the transformer."""
603730 modules = [
@@ -945,15 +1072,12 @@ def forward(
9451072 logits_rewards = self .head_rewards (x , num_steps = num_steps , prev_steps = prev_steps )
9461073 logits_policy = self .head_policy (x , num_steps = num_steps , prev_steps = prev_steps )
9471074
948- # ==================== [NEW] Fix1: Clip Policy Logits ====================
949- # Prevent policy logits from exploding, which can cause gradient issues
1075+ # ==================== [NEW] Advanced Policy Logits Control ====================
1076+ # Apply configurable policy logits control to prevent explosion
1077+ # Multiple methods available: hard, soft_tanh, soft_sigmoid, normalize_max, etc.
9501078 if self .use_policy_logits_clip :
951- logits_policy = torch .clamp (
952- logits_policy ,
953- min = self .policy_logits_clip_min ,
954- max = self .policy_logits_clip_max
955- )
956- # ========================================================================
1079+ logits_policy = self ._apply_policy_logits_control (logits_policy )
1080+ # ================================================================================
9571081
9581082 logits_value = self .head_value (x , num_steps = num_steps , prev_steps = prev_steps )
9591083
0 commit comments