@@ -359,34 +359,27 @@ def update_policy(self, data: DataProto):
                 # vanilla -> verl.trainer.ppo.core_algos.compute_policy_loss_vanilla
                 # gpg -> verl.trainer.ppo.core_algos.compute_policy_loss_gpg
                 # clip_cov -> verl.trainer.ppo.core_algos.compute_policy_loss_clip_cov
+                # policy_loss_fn = get_policy_loss_fn(loss_mode)
+                # pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = policy_loss_fn(
+                #     old_log_prob=old_log_prob,
+                #     log_prob=log_prob,
+                #     advantages=advantages,
+                #     response_mask=response_mask,
+                #     loss_agg_mode=loss_agg_mode,
+                #     config=self.config,
+                #     rollout_log_probs=rollout_log_probs,
+                # )
                 # Compute FlowRL trajectory balance loss
-                # Use environment variable to switch between versions
-                use_ablation = os.getenv("FLOWRL_CLIP_ABLATION", "false").lower() == "true"
-
-                if use_ablation:
-                    # Ablation: only clip, no hard mask
-                    policy_loss, flowrl_metrics = self.compute_flowrl_cispo_clip_ablation(
-                        log_prob=log_prob,
-                        ref_log_prob=ref_log_prob,
-                        old_log_prob=old_log_prob,
-                        log_z=log_z,
-                        reward=advantages,
-                        response_mask=response_mask,
-                        clip_ratio=self.config.clip_ratio,
-                        rollout_log_probs=rollout_log_probs,
-                    )
-                else:
-                    # Default: CISPO with hard mask + clip
-                    policy_loss, flowrl_metrics = self.compute_flowrl_cispo_clip(
-                        log_prob=log_prob,
-                        ref_log_prob=ref_log_prob,
-                        old_log_prob=old_log_prob,
-                        log_z=log_z,
-                        reward=advantages,
-                        response_mask=response_mask,
-                        clip_ratio=self.config.clip_ratio,
-                        rollout_log_probs=rollout_log_probs,
-                    )
+                policy_loss, flowrl_metrics = self.compute_flowrl(
+                    log_prob=log_prob,
+                    ref_log_prob=ref_log_prob,
+                    old_log_prob=old_log_prob,
+                    log_z=log_z,
+                    reward=advantages,
+                    response_mask=response_mask,
+                    clip_ratio=self.config.clip_ratio,
+                    rollout_log_probs=rollout_log_probs,
+                )

                 # if entropy_coeff != 0:
                 #     entropy_loss = agg_loss(
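For context, the commented-out block retained above is verl's registry-based policy-loss dispatch. A minimal sketch of what re-enabling it at this point of `update_policy` would look like is below; the call signature and the `verl.trainer.ppo.core_algos` import path follow the commented lines in this hunk, while the literal `loss_mode` / `loss_agg_mode` values are assumptions rather than part of this diff, and the snippet relies on the surrounding locals (`old_log_prob`, `advantages`, etc.) shown above.

```python
# Hypothetical sketch only: registry dispatch instead of the FlowRL call above.
from verl.trainer.ppo.core_algos import get_policy_loss_fn

loss_mode = "vanilla"          # or "gpg", "clip_cov"; normally read from config (assumed)
loss_agg_mode = "token-mean"   # assumed aggregation mode

policy_loss_fn = get_policy_loss_fn(loss_mode)
pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = policy_loss_fn(
    old_log_prob=old_log_prob,
    log_prob=log_prob,
    advantages=advantages,
    response_mask=response_mask,
    loss_agg_mode=loss_agg_mode,
    config=self.config,
    rollout_log_probs=rollout_log_probs,
)
```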
@@ -438,7 +431,7 @@ def update_policy(self, data: DataProto):
         self.actor_optimizer.zero_grad()
         return metrics

-    def compute_flowrl_cispo_clip(
+    def compute_flowrl(
         self,
         log_prob=None,
         ref_log_prob=None,
@@ -449,37 +442,23 @@ def compute_flowrl_cispo_clip(
         clip_ratio=None,
         rollout_log_probs=None,
     ):
-        log_ratio = log_prob - old_log_prob  # (B, T)
-        ratio = torch.exp(log_ratio)  # (B, T)
-
-        condition_1 = (reward > 0) & (ratio > 1.0 + 0.28)  # (B, T)
-        condition_2 = (reward < 0) & (ratio < 1.0 - 0.2)  # (B, T)
-
-        # CISPO mask
-        cispo_mask = ~(condition_1 | condition_2)
-        cispo_mask = cispo_mask.float()
-        combined_mask = response_mask * cispo_mask
-
         # squeeze log_z to (B,)
         log_z = log_z.squeeze(-1)

         # Average token log-probs & rewards over valid positions
-        avg_log_prob = verl_F.masked_mean(log_prob, combined_mask, axis=1)
-        avg_ref_log_prob = verl_F.masked_mean(ref_log_prob, combined_mask, axis=1)
-        seq_log_reward = verl_F.masked_mean(reward, combined_mask, axis=1)
+        avg_log_prob = verl_F.masked_mean(log_prob, response_mask, axis=1)
+        avg_ref_log_prob = verl_F.masked_mean(ref_log_prob, response_mask, axis=1)
+        seq_log_reward = verl_F.masked_mean(reward, response_mask, axis=1)

         # FlowRL residual: logZ + logpf - β*R - logpref
         delta = log_z + avg_log_prob - self.flowrl_beta_coef * seq_log_reward - avg_ref_log_prob

         # Importance ratio from current vs old policy (product of token ratios)
-        log_w = verl_F.masked_sum(log_prob - old_log_prob, combined_mask, axis=1)
+        log_w = verl_F.masked_sum(log_prob - old_log_prob, response_mask, axis=1)
         imp_w_raw = torch.exp(log_w).detach()
+        imp_w = torch.clamp(imp_w_raw, max=10)

-        # Clamp importance weight for numerical stability (prevent extreme values)
-        # imp_w = torch.clamp(imp_w_raw, max=10.0)
-        imp_w = torch.clamp(imp_w_raw, 1 - 0.2, 1 + 0.28)
-
-        # Loss: weighted squared residual with clipped importance weights
+        # Loss: weighted squared residual with importance weights
         weighted_losses = imp_w * (delta ** 2)
         avg_loss = torch.mean(weighted_losses)

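For reference, the objective the simplified `compute_flowrl` minimizes, reconstructed from the code above (overlines denote masked means over response tokens, the sum runs over response tokens, $\beta$ is `flowrl_beta_coef`, and the advantages passed in as `reward` play the role of $R$):

$$
\delta_i \;=\; \log Z_i \;+\; \overline{\log \pi_\theta(y_i \mid x_i)} \;-\; \beta\,\bar{R}_i \;-\; \overline{\log \pi_{\mathrm{ref}}(y_i \mid x_i)},
\qquad
w_i \;=\; \min\!\Big(\exp\!\Big(\sum_{t \in y_i} \log \tfrac{\pi_\theta}{\pi_{\mathrm{old}}}\Big),\, 10\Big),
$$

$$
\mathcal{L} \;=\; \frac{1}{B} \sum_{i=1}^{B} w_i\, \delta_i^{2},
$$

with $w_i$ detached, so it scales the squared residual but contributes no gradient of its own.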
@@ -491,11 +470,6 @@ def compute_flowrl_cispo_clip(
         approx_kl_ref = log_prob - ref_log_prob
         ref_kl = verl_F.masked_mean(-approx_kl_ref, response_mask)

-        # cispo
-        total_tokens = response_mask.sum()
-        cispo_dropped = (response_mask * (1 - cispo_mask)).sum()
-        cispo_mask_ratio = cispo_dropped / (total_tokens + 1e-8)
-
         # Metrics
         loss_term_dict = {
             "actor/log_prob": verl_F.masked_mean(log_prob, response_mask).detach().item(),
@@ -504,104 +478,9 @@ def compute_flowrl_cispo_clip(
             "actor/log_z": log_z.mean().detach().item(),
             "actor/log_reward": verl_F.masked_mean(reward, response_mask).detach().item(),
             "actor/final_loss": avg_loss.detach().item(),
-            "actor/importance_weight_raw": imp_w_raw.mean().detach().item(),
             "actor/importance_weight": imp_w.mean().detach().item(),
             "actor/ppo_kl": ppo_kl.detach().item(),  # PPO-style KL (current vs old policy)
             "actor/ref_kl": ref_kl.detach().item(),  # KL with reference policy
-            "actor/cispo_mask_ratio": cispo_mask_ratio.detach().item(),  # cispo
-            "actor/cispo_dropped_tokens": cispo_dropped.detach().item(),  # cispo
-            "actor/condition_1_count": (condition_1 * response_mask).sum().detach().item(),  # cispo
-            "actor/condition_2_count": (condition_2 * response_mask).sum().detach().item(),  # cispo
-        }
-
-        return avg_loss, loss_term_dict
-
-    def compute_flowrl_cispo_clip_ablation(
-        self,
-        log_prob=None,
-        ref_log_prob=None,
-        old_log_prob=None,
-        log_z=None,
-        reward=None,
-        response_mask=None,
-        clip_ratio=None,
-        rollout_log_probs=None,
-    ):
-        """
-        Ablation study: Remove hard CISPO mask, only use importance weight clipping.
-        This version uses response_mask only (no condition-based masking).
-        """
-
-        # log_ratio = log_prob - old_log_prob  # (B, T)
-        # ratio = torch.exp(log_ratio)  # (B, T)
-
-        # === Main change: Remove hard mask, only use clip ===
-        # Original version had:
-        # condition_1 = (reward > 0) & (ratio > 1.0 + 0.28)
-        # condition_2 = (reward < 0) & (ratio < 1.0 - 0.2)
-        # cispo_mask = ~(condition_1 | condition_2)
-        # combined_mask = response_mask * cispo_mask
-
-        # New version: Only use response_mask, no hard masking
-        combined_mask = response_mask  # Only keep response_mask
-        # ====================================================
-
-        # squeeze log_z to (B,)
-        log_z = log_z.squeeze(-1)
-
-        # Average token log-probs & rewards over valid positions
-        avg_log_prob = verl_F.masked_mean(log_prob, combined_mask, axis=1)
-        avg_ref_log_prob = verl_F.masked_mean(ref_log_prob, combined_mask, axis=1)
-        seq_log_reward = verl_F.masked_mean(reward, combined_mask, axis=1)
-
-        # FlowRL residual: logZ + logpf - β*R - logpref
-        delta = log_z + avg_log_prob - self.flowrl_beta_coef * seq_log_reward - avg_ref_log_prob
-
-        # Importance ratio from current vs old policy (product of token ratios)
-        log_w = verl_F.masked_sum(log_prob - old_log_prob, combined_mask, axis=1)
-        imp_w_raw = torch.exp(log_w).detach()
-
-        # === Main change: Clipping is the core of CISPO ===
-        # This clipping is what distinguishes this from vanilla FlowRL
-        imp_w = torch.clamp(imp_w_raw, 1 - 0.2, 1 + 0.28)  # Keep this unchanged
-        # ==================================================
-
-        # Loss: weighted squared residual with clipped importance weights
-        weighted_losses = imp_w * (delta ** 2)
-        avg_loss = torch.mean(weighted_losses)
-
-        # PPO KL: negative_approx_kl = log_prob - old_log_prob
-        negative_approx_kl = log_prob - old_log_prob
-        ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask)
-
-        # Reference KL: approx_kl_ref = log_prob - ref_log_prob
-        approx_kl_ref = log_prob - ref_log_prob
-        ref_kl = verl_F.masked_mean(-approx_kl_ref, response_mask)
-
-        # === Updated statistics ===
-        # Since we're using clipping instead of masking, count clipped samples
-        total_tokens = response_mask.sum()
-        clipped_low = ((imp_w_raw < 1.0 - 0.2) & (imp_w_raw > 0)).sum()
-        clipped_high = (imp_w_raw > 1.0 + 0.28).sum()
-        cispo_clipped_count = clipped_low + clipped_high
-        cispo_clip_ratio = cispo_clipped_count / (total_tokens + 1e-8)
-
-        # Metrics
-        loss_term_dict = {
-            "actor/log_prob": verl_F.masked_mean(log_prob, response_mask).detach().item(),
-            "actor/old_log_prob": verl_F.masked_mean(old_log_prob, response_mask).detach().item(),
-            "actor/ref_log_prob": verl_F.masked_mean(ref_log_prob, response_mask).detach().item(),
-            "actor/log_z": log_z.mean().detach().item(),
-            "actor/log_reward": verl_F.masked_mean(reward, response_mask).detach().item(),
-            "actor/final_loss": avg_loss.detach().item(),
-            "actor/importance_weight_raw": imp_w_raw.mean().detach().item(),
-            "actor/importance_weight": imp_w.mean().detach().item(),
-            "actor/ppo_kl": ppo_kl.detach().item(),
-            "actor/ref_kl": ref_kl.detach().item(),
-            "actor/cispo_clip_ratio": cispo_clip_ratio.detach().item(),  # Renamed from mask_ratio
-            "actor/cispo_clipped_count": cispo_clipped_count.detach().item(),  # Renamed from dropped_tokens
-            "actor/clipped_low_count": clipped_low.detach().item(),
-            "actor/clipped_high_count": clipped_high.detach().item(),
-        }
-
-        return avg_loss, loss_term_dict
         }

         return avg_loss, loss_term_dict
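For quick checking of the shapes and reductions outside the trainer, here is a self-contained sketch of the same trajectory-balance loss. `masked_mean` / `masked_sum` are local stand-ins for the `verl_F` helpers, and `beta` stands in for `self.flowrl_beta_coef`; tensor shapes follow the comments in the diff ((B, T) for token-level inputs, (B, 1) for `log_z`). This is an illustrative sketch, not the module's API.

```python
# Minimal stand-alone sketch of the simplified FlowRL trajectory-balance loss.
import torch


def masked_mean(x: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
    # Mean over positions where mask == 1.
    return (x * mask).sum(dim) / mask.sum(dim).clamp(min=1e-8)


def masked_sum(x: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
    # Sum over positions where mask == 1.
    return (x * mask).sum(dim)


def flowrl_tb_loss(log_prob, ref_log_prob, old_log_prob, log_z, reward,
                   response_mask, beta: float = 1.0):
    log_z = log_z.squeeze(-1)            # (B, 1) -> (B,)
    mask = response_mask.float()

    # Sequence-level averages over valid response tokens
    avg_log_prob = masked_mean(log_prob, mask)          # (B,)
    avg_ref_log_prob = masked_mean(ref_log_prob, mask)  # (B,)
    seq_log_reward = masked_mean(reward, mask)          # (B,)

    # Trajectory-balance residual: logZ + logpf - beta*R - logpref
    delta = log_z + avg_log_prob - beta * seq_log_reward - avg_ref_log_prob

    # Sequence-level importance weight (product of token ratios), clamped at 10
    log_w = masked_sum(log_prob - old_log_prob, mask)
    imp_w = torch.exp(log_w).detach().clamp(max=10.0)

    return (imp_w * delta.pow(2)).mean()


if __name__ == "__main__":
    B, T = 4, 16
    mask = torch.ones(B, T)
    loss = flowrl_tb_loss(
        log_prob=torch.randn(B, T) * 0.1 - 1.0,
        ref_log_prob=torch.randn(B, T) * 0.1 - 1.0,
        old_log_prob=torch.randn(B, T) * 0.1 - 1.0,
        log_z=torch.zeros(B, 1),
        reward=torch.randn(B, T),
        response_mask=mask,
        beta=1.0,
    )
    print(loss.item())
```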