Conversation
…ions Implements the ASFT objective (DFT-style reweighting + KL anchoring) with Unsloth trainer/CLI integration, tests, and an ASFT/ASFT+ demo notebook. Credits: https://github.com/zhuchichi56/ASFT
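For orientation, here is a minimal, self-contained sketch of the per-token objective described above — DFT-style confidence reweighting of the cross-entropy plus a KL anchor to a frozen reference model. Tensor names and the `kl_weight` knob are illustrative assumptions; this is not the code in `unsloth/losses/asft.py`, and it omits label shifting and all streaming machinery:

```python
import torch
import torch.nn.functional as F

def asft_loss_sketch(cur_logits, ref_logits, labels, kl_weight = 0.1):
    """Illustrative per-token ASFT-style loss: confidence-weighted CE + KL anchor.

    cur_logits, ref_logits: (B, T, V); labels: (B, T) with -100 marking ignored tokens.
    """
    valid = labels != -100
    safe_labels = labels.clamp_min(0)

    log_probs = F.log_softmax(cur_logits.float(), dim = -1)
    tok_logp  = log_probs.gather(-1, safe_labels.unsqueeze(-1)).squeeze(-1)  # (B, T)

    # DFT-style reweighting: scale CE by the (detached) predicted probability
    # of the target token, so confidently-predicted tokens dominate the gradient.
    weight = tok_logp.detach().exp()
    ce = -(weight * tok_logp)

    # KL(p_ref || p_cur) anchors the current policy to the frozen reference model.
    ref_log_probs = F.log_softmax(ref_logits.float(), dim = -1)
    kl = F.kl_div(log_probs, ref_log_probs, reduction = "none", log_target = True).sum(-1)

    loss = (ce + kl_weight * kl) * valid
    return loss.sum() / valid.sum().clamp_min(1)
```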
Replaces `enabled`/`ref_strategy` with unified `mode` parameter in ASFTStreamingConfig. Adds "auto", "seq", "batch", "hybrid", and "off" modes with automatic fallback logic. Implements seq_kv_cache streaming with KV cache reuse and batch microbatching support. Updates notebook defaults to use `mode="auto"` for optimal VRAM reduction. Adds comprehensive tests for mode routing, fallback behavior, and backward compatibility.
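The routing logic itself is not shown in this excerpt; the following is a purely hypothetical sketch of how an `"auto"` mode with fallbacks could be resolved, included only to illustrate the behavior the commit describes. The function and argument names are assumptions, not the actual ASFTStreamingConfig API:

```python
def resolve_streaming_mode(mode, supports_kv_cache, batch_size, microbatch_size):
    """Hypothetical routing sketch for ASFT reference-streaming modes.

    "seq"    -> sequence chunking with KV cache reuse
    "batch"  -> microbatching over the batch dimension
    "hybrid" -> sequence chunking nested inside batch microbatching
    "off"    -> single full reference forward pass
    "auto"   -> pick the most VRAM-friendly mode the model supports
    """
    if mode != "auto":
        return mode
    if supports_kv_cache:
        return "hybrid" if (microbatch_size and microbatch_size < batch_size) else "seq"
    if batch_size > 1:
        return "batch"
    return "off"
```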
Adds `kl_direction` ("forward"/"reverse") to control the direction of the KL divergence computation and `normalize_by` ("tokens"/"weights") for DFT/ASFT loss normalization. Forward KL (the default) matches the original ASFT code's behavior, despite the paper's terminology; reverse KL enables mode-seeking behavior. Updates `_compute_kl_divergence`, the streaming strategies, `compute_asft_loss`, and `ASFTTrainer` to propagate both parameters. Adds tests for reverse KL computation.
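As a concrete reference for the two directions, a small sketch of per-token KL in both conventions, matching the docstring further down ("forward" = KL(p_ref || p_cur), "reverse" = KL(p_cur || p_ref)). This is illustrative and not the project's `_compute_kl_divergence`:

```python
import torch
import torch.nn.functional as F

def per_token_kl(cur_logits, ref_logits, direction = "forward"):
    """Per-token KL between reference and current distributions, returned as (B, T)."""
    cur_lp = F.log_softmax(cur_logits.float(), dim = -1)
    ref_lp = F.log_softmax(ref_logits.float(), dim = -1)
    if direction == "forward":
        # KL(p_ref || p_cur): mass-covering; matches the original ASFT code's behavior.
        kl = F.kl_div(cur_lp, ref_lp, reduction = "none", log_target = True)
    else:
        # KL(p_cur || p_ref): mode-seeking.
        kl = F.kl_div(ref_lp, cur_lp, reduction = "none", log_target = True)
    return kl.sum(dim = -1)
```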
Detects `packed_seq_lengths` in `forward_inputs` and bypasses seq_kv_cache chunking to avoid KV cache corruption with packed sequences, falling back to batch microbatching (if configured) or a full reference forward pass. Adds a test verifying that the fallback triggers `_compute_kl_batch_micro` with `microbatch_size=1` when packed sequences are present.
Extracts logit parameter resolution into a `_resolve_logit_params` helper to handle model-specific scaling overrides. Adds Granite (`logits_scaling` → `1/logits_scaling`) and Falcon H1 (`lm_head_multiplier`) support alongside the existing `logit_scale`/`logit_scaling` fallback chain. Updates `effective_logits` and `compute_asft_loss` to use the unified resolution logic. Adds tests verifying Granite/Falcon H1 scaling in both `effective_logits` and the ASFT CE loss.
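A rough sketch of the kind of fallback chain this commit describes, using the config attribute names mentioned above; other attribute names and the exact precedence are assumptions, and the real `_resolve_logit_params` in `unsloth/losses/asft.py` may differ:

```python
def resolve_logit_params_sketch(config, logit_softcapping = 0.0, logit_scaling = 0.0):
    """Illustrative resolution of model-specific logit softcapping / scaling.

    Explicitly passed values win; otherwise fall back to model config attributes.
    """
    if logit_softcapping == 0.0:
        # e.g. Gemma-2 style softcapping, if the config exposes it (assumed name).
        logit_softcapping = getattr(config, "final_logit_softcapping", 0.0) or 0.0
    if logit_scaling == 0.0:
        if getattr(config, "logits_scaling", None):        # Granite divides logits
            logit_scaling = 1.0 / config.logits_scaling
        elif getattr(config, "lm_head_multiplier", None):  # Falcon H1 multiplies logits
            logit_scaling = config.lm_head_multiplier
        elif getattr(config, "logit_scale", None):         # existing fallback chain
            logit_scaling = config.logit_scale
        elif getattr(config, "logit_scaling", None):
            logit_scaling = config.logit_scaling
    return logit_softcapping, logit_scaling
```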
Removes unused `TestUnslothTrainingArguments` test class. Simplifies GGUF quantization logic in CLI by eliminating redundant list wrapping and intermediate variable. Removes deprecated `forced_merged_4bit` save method choice from CLI (kept `merged_4bit_forced` alias in save.py for backward compatibility). Fixes `UnslothTrainingArguments` to not store unused `embedding_learning_rate` attribute.
Reformats function signatures and multi-line expressions in `test_asft.py` and `unsloth/losses/asft.py` to improve readability. Splits long function signatures across multiple lines, uses parenthesized context managers for multiple `patch()` calls, and breaks complex conditionals/expressions. No functional changes.
Summary of Changes

Hello @hcsolakoglu, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances the Unsloth training pipeline by integrating Anchored Supervised Fine-Tuning (ASFT). The feature adds flexibility in loss-function design, allowing more nuanced control over fine-tuning through confidence-weighted cross-entropy and Kullback-Leibler divergence regularization against a reference model. The changes are integrated into the CLI and a new trainer, with VRAM-optimized streaming strategies that make these techniques accessible and efficient.
Code Review
This pull request introduces comprehensive support for Anchored Supervised Fine-Tuning (ASFT), including a new loss module, a specialized ASFTTrainer, and corresponding CLI arguments. The implementation is well-structured and includes advanced, VRAM-efficient streaming strategies with robust fallbacks. The addition of extensive test suites covering unit, integration, and backward compatibility aspects is commendable and significantly boosts confidence in the new functionality. My review focuses on improving the maintainability and robustness of the core loss computation logic by addressing some code duplication and suggesting safer exception handling. Overall, this is a high-quality contribution that adds a powerful new feature to the library.
```python
def _compute_kl_seq_kv_cache(
    model: nn.Module,
    cur_logits: torch.Tensor,
    shift_labels: torch.Tensor,
    valid_mask: torch.Tensor,
    ref_forward: Callable,
    forward_inputs: Dict[str, Any],
    seq_chunk_size: int,
    microbatch_size: Optional[int] = None,
    allow_auto_microbatch_fallback: bool = True,
    logit_softcapping: float = 0,
    logit_scaling: float = 0,
    force_fp32: bool = True,
    kl_direction: Literal["forward", "reverse"] = "forward",
) -> torch.Tensor:
    """Compute KL using sequence chunking with KV cache strategy.

    Processes reference forward in sequence chunks to reduce peak VRAM.
    Falls back to full forward if model doesn't support caching.

    Args:
        model: Current model.
        cur_logits: Current model logits (B, T, V).
        shift_labels: Shifted labels (B, T).
        valid_mask: Valid token mask (B, T).
        ref_forward: Reference forward callable.
        forward_inputs: Forward inputs (without labels).
        seq_chunk_size: Size of each sequence chunk.
        microbatch_size: Optional microbatch size for batch dimension.
        allow_auto_microbatch_fallback: Allow automatic microbatch fallback on errors.
        logit_softcapping: Softcapping value.
        logit_scaling: Scaling value.
        force_fp32: Whether to use FP32 for KL.
        kl_direction: "forward" for KL(p_ref || p_cur), "reverse" for KL(p_cur || p_ref).

    Returns:
        KL tensor of shape (B, T).
    """
    batch_size, seq_len, vocab_size = cur_logits.shape
    device = cur_logits.device

    packed_seq_lengths = forward_inputs.get("packed_seq_lengths", None)
    if packed_seq_lengths is not None:
        # Avoid seq_kv_cache with packed sequences; fall back to batch/full reference.
        fallback_microbatch = None
        if microbatch_size is not None and microbatch_size < batch_size:
            fallback_microbatch = microbatch_size
        elif allow_auto_microbatch_fallback:
            fallback_microbatch = max(1, batch_size // _DEFAULT_REF_MICROBATCH_DIVISOR)
            if fallback_microbatch >= batch_size:
                fallback_microbatch = None
        if fallback_microbatch is not None:
            return _compute_kl_batch_micro(
                model,
                cur_logits,
                shift_labels,
                valid_mask,
                ref_forward,
                forward_inputs,
                fallback_microbatch,
                logit_softcapping,
                logit_scaling,
                force_fp32,
                kl_direction,
            )
        ref_outputs = ref_forward(**forward_inputs)
        ref_logits, _ = _unwrap_reference_outputs(ref_outputs)
        kl_full = _compute_kl_divergence(
            cur_logits,
            ref_logits,
            model,
            logit_softcapping,
            logit_scaling,
            force_fp32,
            kl_direction,
        )
        if kl_full.dim() == 1:
            kl_full = kl_full.view(batch_size, seq_len)
        return kl_full

    if microbatch_size is not None:
        microbatch_size = max(1, microbatch_size)
    if microbatch_size is not None and microbatch_size < batch_size:
        kl = torch.zeros(batch_size, seq_len, dtype = torch.float32, device = device)
        for b_start in range(0, batch_size, microbatch_size):
            b_end = min(b_start + microbatch_size, batch_size)
            mb_inputs = _slice_batch_inputs(forward_inputs, batch_size, b_start, b_end)
            kl_mb = _compute_kl_seq_kv_cache(
                model,
                cur_logits[b_start:b_end],
                shift_labels[b_start:b_end],
                valid_mask[b_start:b_end],
                ref_forward,
                mb_inputs,
                seq_chunk_size,
                microbatch_size = None,
                allow_auto_microbatch_fallback = False,
                logit_softcapping = logit_softcapping,
                logit_scaling = logit_scaling,
                force_fp32 = force_fp32,
                kl_direction = kl_direction,
            )
            if kl_mb.dim() == 1:
                mb_batch = b_end - b_start
                kl_mb = kl_mb.view(mb_batch, -1)
            kl[b_start:b_end] = kl_mb
        return kl

    kl = torch.zeros(batch_size, seq_len, dtype = torch.float32, device = device)

    # Process in chunks with KV cache
    past_key_values = None

    for s_start in range(0, seq_len, seq_chunk_size):
        s_end = min(s_start + seq_chunk_size, seq_len)

        # Build chunk inputs
        chunk_inputs = {}
        for key, value in forward_inputs.items():
            if key == "input_ids":
                chunk_inputs[key] = value[:, s_start:s_end]
            elif key == "attention_mask":
                # For chunked processing, need attention mask up to s_end
                chunk_inputs[key] = value[:, :s_end]
            elif key == "position_ids":
                chunk_inputs[key] = value[:, s_start:s_end]
            elif (
                torch.is_tensor(value)
                and value.dim() >= 2
                and value.shape[1] == seq_len
            ):
                chunk_inputs[key] = value[:, s_start:s_end]
            else:
                chunk_inputs[key] = value

        # Add past_key_values if available
        if past_key_values is not None:
            chunk_inputs["past_key_values"] = past_key_values

        chunk_inputs["use_cache"] = True

        try:
            # Get reference logits for chunk
            # Note: ref_forward may not support all these kwargs
            ref_outputs = ref_forward(**chunk_inputs)
            ref_logits_chunk, ref_past_key_values = _unwrap_reference_outputs(
                ref_outputs
            )
            if ref_past_key_values is None and s_end < seq_len:
                # Can't continue without cache; fall back to batch micro if allowed
                fallback_microbatch = None
                if allow_auto_microbatch_fallback:
                    fallback_microbatch = (
                        microbatch_size
                        if microbatch_size is not None
                        else max(1, batch_size // _DEFAULT_REF_MICROBATCH_DIVISOR)
                    )
                if fallback_microbatch is not None and fallback_microbatch < batch_size:
                    return _compute_kl_batch_micro(
                        model,
                        cur_logits,
                        shift_labels,
                        valid_mask,
                        ref_forward,
                        forward_inputs,
                        fallback_microbatch,
                        logit_softcapping,
                        logit_scaling,
                        force_fp32,
                        kl_direction,
                    )
                ref_outputs = ref_forward(**forward_inputs)
                ref_logits, _ = _unwrap_reference_outputs(ref_outputs)
                kl_full = _compute_kl_divergence(
                    cur_logits,
                    ref_logits,
                    model,
                    logit_softcapping,
                    logit_scaling,
                    force_fp32,
                    kl_direction,
                )
                if kl_full.dim() == 1:
                    kl_full = kl_full.view(batch_size, seq_len)
                return kl_full
            past_key_values = ref_past_key_values

            cur_logits_chunk = cur_logits[:, s_start:s_end]

            # Compute KL for this chunk
            kl_chunk = _compute_kl_divergence(
                cur_logits_chunk,
                ref_logits_chunk,
                model,
                logit_softcapping,
                logit_scaling,
                force_fp32,
                kl_direction,
            )

            if kl_chunk.dim() == 1:
                chunk_len = s_end - s_start
                kl_chunk = kl_chunk.view(batch_size, chunk_len)

            kl[:, s_start:s_end] = kl_chunk

            del ref_logits_chunk

        except (RuntimeError, ValueError, KeyError, TypeError) as e:
            # Fallback to batch micro or full forward on KV cache errors
            # These exceptions typically indicate the model doesn't support
            # the chunked KV cache approach (e.g., missing past_key_values support)
            fallback_microbatch = None
            if allow_auto_microbatch_fallback:
                fallback_microbatch = (
                    microbatch_size
                    if microbatch_size is not None
                    else max(1, batch_size // _DEFAULT_REF_MICROBATCH_DIVISOR)
                )
            if fallback_microbatch is not None and fallback_microbatch < batch_size:
                return _compute_kl_batch_micro(
                    model,
                    cur_logits,
                    shift_labels,
                    valid_mask,
                    ref_forward,
                    forward_inputs,
                    fallback_microbatch,
                    logit_softcapping,
                    logit_scaling,
                    force_fp32,
                    kl_direction,
                )
            ref_outputs = ref_forward(**forward_inputs)
            ref_logits, _ = _unwrap_reference_outputs(ref_outputs)
            kl_full = _compute_kl_divergence(
                cur_logits,
                ref_logits,
                model,
                logit_softcapping,
                logit_scaling,
                force_fp32,
                kl_direction,
            )
            if kl_full.dim() == 1:
                kl_full = kl_full.view(batch_size, seq_len)
            return kl_full

    return kl
```
This function contains significant code duplication in its fallback logic. The code blocks for handling fallbacks for packed sequences (lines 635-671), KV cache failure (lines 741-777), and general exceptions (lines 801-839) are nearly identical.
This duplication makes the function harder to maintain, as any changes to the fallback logic need to be applied in three different places.
To improve maintainability, I recommend refactoring this duplicated logic into a separate private helper function. This helper would encapsulate the logic for deciding between batch-micro and full-forward KL computation and would be called from all three fallback points.
For example, you could introduce a helper like _perform_kl_fallback(...) that contains the shared logic.
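For concreteness, one possible shape for that helper, assembled from the three duplicated blocks above. It assumes the module's existing helpers (`_compute_kl_batch_micro`, `_unwrap_reference_outputs`, `_compute_kl_divergence`, `_DEFAULT_REF_MICROBATCH_DIVISOR`) and should be read as a sketch, not the final refactor:

```python
def _perform_kl_fallback(
    model,
    cur_logits,
    shift_labels,
    valid_mask,
    ref_forward,
    forward_inputs,
    microbatch_size,
    allow_auto_microbatch_fallback,
    logit_softcapping,
    logit_scaling,
    force_fp32,
    kl_direction,
):
    """Shared fallback: batch microbatching if it helps, otherwise a full reference forward."""
    batch_size, seq_len, _ = cur_logits.shape

    fallback_microbatch = None
    if microbatch_size is not None and microbatch_size < batch_size:
        fallback_microbatch = microbatch_size
    elif allow_auto_microbatch_fallback:
        fallback_microbatch = max(1, batch_size // _DEFAULT_REF_MICROBATCH_DIVISOR)
        if fallback_microbatch >= batch_size:
            fallback_microbatch = None

    if fallback_microbatch is not None:
        return _compute_kl_batch_micro(
            model, cur_logits, shift_labels, valid_mask, ref_forward, forward_inputs,
            fallback_microbatch, logit_softcapping, logit_scaling, force_fp32, kl_direction,
        )

    # No useful microbatch size: run one full reference forward pass.
    ref_outputs = ref_forward(**forward_inputs)
    ref_logits, _ = _unwrap_reference_outputs(ref_outputs)
    kl_full = _compute_kl_divergence(
        cur_logits, ref_logits, model, logit_softcapping, logit_scaling, force_fp32, kl_direction,
    )
    if kl_full.dim() == 1:
        kl_full = kl_full.view(batch_size, seq_len)
    return kl_full
```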
```python
        except (RuntimeError, ValueError, KeyError, TypeError) as e:
            # Fallback to batch micro or full forward on KV cache errors
            # These exceptions typically indicate the model doesn't support
            # the chunked KV cache approach (e.g., missing past_key_values support)
            fallback_microbatch = None
            if allow_auto_microbatch_fallback:
                fallback_microbatch = (
                    microbatch_size
                    if microbatch_size is not None
                    else max(1, batch_size // _DEFAULT_REF_MICROBATCH_DIVISOR)
                )
            if fallback_microbatch is not None and fallback_microbatch < batch_size:
                return _compute_kl_batch_micro(
                    model,
                    cur_logits,
                    shift_labels,
                    valid_mask,
                    ref_forward,
                    forward_inputs,
                    fallback_microbatch,
                    logit_softcapping,
                    logit_scaling,
                    force_fp32,
                    kl_direction,
                )
            ref_outputs = ref_forward(**forward_inputs)
            ref_logits, _ = _unwrap_reference_outputs(ref_outputs)
            kl_full = _compute_kl_divergence(
                cur_logits,
                ref_logits,
                model,
                logit_softcapping,
                logit_scaling,
                force_fp32,
                kl_direction,
            )
            if kl_full.dim() == 1:
                kl_full = kl_full.view(batch_size, seq_len)
            return kl_full
```
The `except` block here catches a very broad set of exceptions (`RuntimeError`, `ValueError`, `KeyError`, `TypeError`). While the comment explains this is for handling models that don't support KV caching, it also risks masking unrelated bugs. For instance, a `KeyError` could indicate a problem with the `forward_inputs` dictionary construction, but it would be silently caught here, triggering a performance-degrading fallback instead of surfacing the bug.
To make the code more robust and easier to debug, consider one of the following:
- Narrow the exceptions: If possible, catch more specific exceptions that are known to be thrown by incompatible model forward signatures.
- Add logging: Inside the `except` block, log the exception with its traceback at a `WARNING` or `DEBUG` level. This would make it clear why the fallback was triggered and help diagnose unexpected errors.
Example with logging:

```python
import logging

logger = logging.getLogger(__name__)

# ...

except (RuntimeError, ValueError, KeyError, TypeError) as e:
    logger.warning(
        "Unsloth: KV cache streaming for ASFT failed, falling back. Error: %s",
        e, exc_info=True
    )
    # ... fallback logic ...
```

This would greatly improve debuggability without sacrificing the robustness of the fallback mechanism.
I’ve added support for Anchored Supervised Fine-Tuning (ASFT) to the Unsloth CLI and training pipeline. This update gives you more flexibility by enabling alternative loss functions and reference-based regularization during fine-tuning. Paper: https://arxiv.org/abs/2509.23753
What’s new:

- **ASFTTrainer implementation:** A new trainer in `unsloth/trainer.py` supporting the `sft`, `dft`, `sft+kl`, and `asft` modes, plus VRAM-efficient streaming of the reference forward pass (a rough usage sketch follows at the end of this description).
- **CLI upgrades:** Refactored the CLI to handle the new ASFT arguments (mode, KL weight, etc.) and automatically choose the right trainer based on your flags.
- **Streamlined losses:** Loss functions now live in a dedicated `unsloth/losses` module.
- **Testing:** Included a new test suite to verify CLI argument parsing.
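As promised above, a rough usage sketch for the new trainer. The model loading calls are Unsloth's public API, but the `ASFTTrainer` constructor arguments shown here (`asft_mode`, `kl_weight`, `streaming_mode`) are illustrative placeholders for the options described in this PR, not a documented signature; see the ASFT/ASFT+ demo notebook for the real end-to-end example.

```python
from unsloth import FastLanguageModel
from unsloth.trainer import ASFTTrainer  # new trainer added in this PR

model, tokenizer = FastLanguageModel.from_pretrained("unsloth/llama-3-8b-bnb-4bit")
model = FastLanguageModel.get_peft_model(
    model,
    r = 16,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"],
)

# `dataset` is a prepared SFT dataset; training arguments are omitted for brevity.
trainer = ASFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    asft_mode = "asft",       # one of "sft", "dft", "sft+kl", "asft" (placeholder name)
    kl_weight = 0.1,          # strength of the KL anchor to the reference model (placeholder name)
    streaming_mode = "auto",  # VRAM-efficient reference streaming (placeholder name)
)
trainer.train()
```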