
Asft plus #3918 (Draft)

hcsolakoglu wants to merge 9 commits into unslothai:main from hcsolakoglu:asft-plus

Conversation

@hcsolakoglu

I’ve added support for Anchored Supervised Fine-Tuning (ASFT) to the Unsloth CLI and training pipeline. This update gives you more flexibility by enabling alternative loss functions and reference-based regularization during fine-tuning. Paper: https://arxiv.org/abs/2509.23753

What's new:

  • ASFTTrainer Implementation: A new trainer in unsloth/trainer.py supporting the sft, dft, sft+kl, and asft loss modes (a minimal sketch of these modes follows below), plus VRAM-efficient streaming.
  • CLI Upgrades: Refactored the CLI to handle the new ASFT arguments (mode, KL weight, etc.) and automatically choose the right trainer based on your flags.
  • Streamlined Losses: Moved loss-function handling into a dedicated unsloth/losses module.
  • Testing: Included a new test suite to verify CLI argument parsing.

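For readers new to these objectives, here is a minimal sketch of how the four loss modes differ. The function and argument names are placeholders rather than Unsloth's actual API; the DFT weighting shown (cross-entropy scaled by the detached target-token probability) is one common formulation, and masking of padding tokens is omitted for brevity.

import torch
import torch.nn.functional as F

def per_token_loss_sketch(mode, cur_logits, ref_logits, labels, kl_weight = 0.1):
    # Per-token cross-entropy; (B, T, V) logits against (B, T) labels.
    ce = F.cross_entropy(cur_logits.transpose(1, 2), labels, reduction = "none")  # (B, T)
    if mode in ("dft", "asft"):
        # DFT-style reweighting: scale CE by the model's detached confidence in the target token.
        probs = F.softmax(cur_logits, dim = -1)
        conf = probs.gather(-1, labels.clamp_min(0).unsqueeze(-1)).squeeze(-1)
        ce = ce * conf.detach()
    loss = ce
    if mode in ("sft+kl", "asft"):
        # KL anchoring against a frozen reference model's token distributions.
        cur_logp = F.log_softmax(cur_logits, dim = -1)
        ref_logp = F.log_softmax(ref_logits, dim = -1)
        kl = (ref_logp.exp() * (ref_logp - cur_logp)).sum(dim = -1)  # forward KL(p_ref || p_cur)
        loss = loss + kl_weight * kl
    return loss.mean()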

Implements the ASFT objective (DFT-style reweighting + KL anchoring) with Unsloth trainer/CLI integration, tests, and an ASFT/ASFT+ demo notebook.

Credits: https://github.com/zhuchichi56/ASFT
Replaces `enabled`/`ref_strategy` with unified `mode` parameter in ASFTStreamingConfig. Adds "auto", "seq", "batch", "hybrid", and "off" modes with automatic fallback logic. Implements seq_kv_cache streaming with KV cache reuse and batch microbatching support. Updates notebook defaults to use `mode="auto"` for optimal VRAM reduction. Adds comprehensive tests for mode routing, fallback behavior, and backward compatibility.
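As a rough illustration of what the unified mode parameter can express (everything here other than the mode names is an assumption for illustration, not the PR's actual ASFTStreamingConfig fields or fallback logic):

from dataclasses import dataclass
from typing import Literal, Optional

@dataclass
class StreamingConfigSketch:
    # Unified `mode` replaces the old `enabled`/`ref_strategy` pair.
    mode: Literal["auto", "seq", "batch", "hybrid", "off"] = "auto"
    seq_chunk_size: int = 512                 # assumed default
    microbatch_size: Optional[int] = None     # assumed default

def resolve_strategy(cfg, supports_kv_cache):
    # Toy routing: "auto" prefers sequence-chunked KV-cache streaming and falls
    # back to batch microbatching when the reference model cannot reuse a cache.
    if cfg.mode == "off":
        return "full_forward"
    if cfg.mode == "auto":
        return "seq_kv_cache" if supports_kv_cache else "batch_micro"
    return {"seq": "seq_kv_cache", "batch": "batch_micro", "hybrid": "hybrid"}[cfg.mode]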
Adds `kl_direction` ("forward"/"reverse") to control the direction of the KL divergence computation and `normalize_by` ("tokens"/"weights") for DFT/ASFT loss normalization. Forward KL (the default) matches the original ASFT code's behavior despite the paper's terminology. Reverse KL enables mode-seeking behavior. Updates `_compute_kl_divergence`, the streaming strategies, `compute_asft_loss`, and `ASFTTrainer` to propagate both parameters. Adds tests for reverse KL computation.
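For reference, a standalone illustration of the two directions (unlike the PR's `_compute_kl_divergence`, this ignores logit softcapping and scaling):

import torch
import torch.nn.functional as F

def token_kl(cur_logits, ref_logits, direction = "forward"):
    # Per-token KL over the vocabulary dimension, computed in fp32 for stability.
    cur_logp = F.log_softmax(cur_logits.float(), dim = -1)
    ref_logp = F.log_softmax(ref_logits.float(), dim = -1)
    if direction == "forward":
        # KL(p_ref || p_cur): mass-covering; matches the original ASFT code's default.
        return (ref_logp.exp() * (ref_logp - cur_logp)).sum(dim = -1)
    # KL(p_cur || p_ref): mode-seeking.
    return (cur_logp.exp() * (cur_logp - ref_logp)).sum(dim = -1)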
Detects `packed_seq_lengths` in forward_inputs and bypasses seq_kv_cache chunking to avoid KV cache corruption with packed sequences. Falls back to batch microbatching (if configured) or full reference forward pass. Adds test verifying fallback triggers `_compute_kl_batch_micro` with microbatch_size=1 when packed sequences present.
Extracts logit parameter resolution into a `_resolve_logit_params` helper to handle model-specific scaling overrides. Adds Granite (`logits_scaling` → `1/logits_scaling`) and Falcon H1 (`lm_head_multiplier`) support alongside the existing `logit_scale`/`logit_scaling` fallback chain. Updates `effective_logits` and `compute_asft_loss` to use the unified resolution logic. Adds tests verifying Granite/Falcon H1 scaling in both `effective_logits` and the ASFT CE loss.
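The gist of that resolution order, sketched with assumed config-attribute checks (the actual `_resolve_logit_params` implementation may differ):

def resolve_logit_scale_sketch(config):
    # Granite divides final logits by `logits_scaling`, so the effective multiplier is its inverse.
    if getattr(config, "model_type", "") == "granite" and getattr(config, "logits_scaling", None):
        return 1.0 / config.logits_scaling
    # Falcon H1 exposes a direct `lm_head_multiplier`.
    if getattr(config, "lm_head_multiplier", None):
        return config.lm_head_multiplier
    # Generic fallback chain used by other models.
    return getattr(config, "logit_scale", None) or getattr(config, "logit_scaling", None) or 0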
Removes unused `TestUnslothTrainingArguments` test class. Simplifies GGUF quantization logic in CLI by eliminating redundant list wrapping and intermediate variable. Removes deprecated `forced_merged_4bit` save method choice from CLI (kept `merged_4bit_forced` alias in save.py for backward compatibility). Fixes `UnslothTrainingArguments` to not store unused `embedding_learning_rate` attribute.
Reformats function signatures and multi-line expressions in `test_asft.py` and `unsloth/losses/asft.py` to improve readability. Splits long function signatures across multiple lines, uses parenthesized context managers for multiple `patch()` calls, and breaks complex conditionals/expressions. No functional changes.
@gemini-code-assist (Contributor)

Summary of Changes

Hello @hcsolakoglu, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the Unsloth training pipeline by integrating Anchored Supervised Fine-Tuning (ASFT). This feature provides greater flexibility in loss function design, allowing for more nuanced control over the fine-tuning process through confidence-weighted cross-entropy and Kullback-Leibler divergence regularization against a reference model. The changes are seamlessly integrated into the CLI and a new trainer, offering VRAM-optimized strategies to make these advanced techniques accessible and efficient for users.

Highlights

  • ASFT (Anchored Supervised Fine-Tuning) Support: Introduced comprehensive support for ASFT, enabling alternative loss functions and reference-based regularization during fine-tuning. This includes new loss modes: SFT (standard CE), DFT (CE weighted by confidence), SFT+KL (CE + KL divergence), and ASFT (DFT + KL).
  • New ASFTTrainer Implementation: A dedicated ASFTTrainer class has been added, inheriting from UnslothTrainer, to manage the ASFT loss computation. It integrates various ASFT parameters and handles the creation of reference models.
  • CLI Integration for ASFT: The unsloth-cli.py has been updated to include new arguments for configuring ASFT, such as --asft, --asft_mode, --kl_weight, --reference_policy, and --asft_streaming, allowing users to easily enable and customize ASFT from the command line.
  • VRAM-Efficient Streaming Strategies: Implemented VRAM-efficient streaming strategies (batch_micro, seq_kv_cache, hybrid) for computing KL divergence with a reference model, reducing peak memory usage during training.
  • Dedicated Loss Module: Loss-related functions, including ASFT components, have been refactored into a new unsloth/losses module for better organization and modularity.
  • Extensive Testing: A new test suite (tests/test_asft.py) has been added to thoroughly validate the ASFT loss computations, streaming mechanisms, and backward compatibility, ensuring robustness and correctness.



@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces comprehensive support for Anchored Supervised Fine-Tuning (ASFT), including a new loss module, a specialized ASFTTrainer, and corresponding CLI arguments. The implementation is well-structured and includes advanced, VRAM-efficient streaming strategies with robust fallbacks. The addition of extensive test suites covering unit, integration, and backward compatibility aspects is commendable and significantly boosts confidence in the new functionality. My review focuses on improving the maintainability and robustness of the core loss computation logic by addressing some code duplication and suggesting safer exception handling. Overall, this is a high-quality contribution that adds a powerful new feature to the library.

Comment on lines +593 to +841
def _compute_kl_seq_kv_cache(
    model: nn.Module,
    cur_logits: torch.Tensor,
    shift_labels: torch.Tensor,
    valid_mask: torch.Tensor,
    ref_forward: Callable,
    forward_inputs: Dict[str, Any],
    seq_chunk_size: int,
    microbatch_size: Optional[int] = None,
    allow_auto_microbatch_fallback: bool = True,
    logit_softcapping: float = 0,
    logit_scaling: float = 0,
    force_fp32: bool = True,
    kl_direction: Literal["forward", "reverse"] = "forward",
) -> torch.Tensor:
    """Compute KL using sequence chunking with KV cache strategy.

    Processes reference forward in sequence chunks to reduce peak VRAM.
    Falls back to full forward if model doesn't support caching.

    Args:
        model: Current model.
        cur_logits: Current model logits (B, T, V).
        shift_labels: Shifted labels (B, T).
        valid_mask: Valid token mask (B, T).
        ref_forward: Reference forward callable.
        forward_inputs: Forward inputs (without labels).
        seq_chunk_size: Size of each sequence chunk.
        microbatch_size: Optional microbatch size for batch dimension.
        allow_auto_microbatch_fallback: Allow automatic microbatch fallback on errors.
        logit_softcapping: Softcapping value.
        logit_scaling: Scaling value.
        force_fp32: Whether to use FP32 for KL.
        kl_direction: "forward" for KL(p_ref || p_cur), "reverse" for KL(p_cur || p_ref).

    Returns:
        KL tensor of shape (B, T).
    """
    batch_size, seq_len, vocab_size = cur_logits.shape
    device = cur_logits.device

    packed_seq_lengths = forward_inputs.get("packed_seq_lengths", None)
    if packed_seq_lengths is not None:
        # Avoid seq_kv_cache with packed sequences; fall back to batch/full reference.
        fallback_microbatch = None
        if microbatch_size is not None and microbatch_size < batch_size:
            fallback_microbatch = microbatch_size
        elif allow_auto_microbatch_fallback:
            fallback_microbatch = max(1, batch_size // _DEFAULT_REF_MICROBATCH_DIVISOR)
            if fallback_microbatch >= batch_size:
                fallback_microbatch = None
        if fallback_microbatch is not None:
            return _compute_kl_batch_micro(
                model,
                cur_logits,
                shift_labels,
                valid_mask,
                ref_forward,
                forward_inputs,
                fallback_microbatch,
                logit_softcapping,
                logit_scaling,
                force_fp32,
                kl_direction,
            )
        ref_outputs = ref_forward(**forward_inputs)
        ref_logits, _ = _unwrap_reference_outputs(ref_outputs)
        kl_full = _compute_kl_divergence(
            cur_logits,
            ref_logits,
            model,
            logit_softcapping,
            logit_scaling,
            force_fp32,
            kl_direction,
        )
        if kl_full.dim() == 1:
            kl_full = kl_full.view(batch_size, seq_len)
        return kl_full

    if microbatch_size is not None:
        microbatch_size = max(1, microbatch_size)
    if microbatch_size is not None and microbatch_size < batch_size:
        kl = torch.zeros(batch_size, seq_len, dtype = torch.float32, device = device)
        for b_start in range(0, batch_size, microbatch_size):
            b_end = min(b_start + microbatch_size, batch_size)
            mb_inputs = _slice_batch_inputs(forward_inputs, batch_size, b_start, b_end)
            kl_mb = _compute_kl_seq_kv_cache(
                model,
                cur_logits[b_start:b_end],
                shift_labels[b_start:b_end],
                valid_mask[b_start:b_end],
                ref_forward,
                mb_inputs,
                seq_chunk_size,
                microbatch_size = None,
                allow_auto_microbatch_fallback = False,
                logit_softcapping = logit_softcapping,
                logit_scaling = logit_scaling,
                force_fp32 = force_fp32,
                kl_direction = kl_direction,
            )
            if kl_mb.dim() == 1:
                mb_batch = b_end - b_start
                kl_mb = kl_mb.view(mb_batch, -1)
            kl[b_start:b_end] = kl_mb
        return kl

    kl = torch.zeros(batch_size, seq_len, dtype = torch.float32, device = device)

    # Process in chunks with KV cache
    past_key_values = None

    for s_start in range(0, seq_len, seq_chunk_size):
        s_end = min(s_start + seq_chunk_size, seq_len)

        # Build chunk inputs
        chunk_inputs = {}
        for key, value in forward_inputs.items():
            if key == "input_ids":
                chunk_inputs[key] = value[:, s_start:s_end]
            elif key == "attention_mask":
                # For chunked processing, need attention mask up to s_end
                chunk_inputs[key] = value[:, :s_end]
            elif key == "position_ids":
                chunk_inputs[key] = value[:, s_start:s_end]
            elif (
                torch.is_tensor(value)
                and value.dim() >= 2
                and value.shape[1] == seq_len
            ):
                chunk_inputs[key] = value[:, s_start:s_end]
            else:
                chunk_inputs[key] = value

        # Add past_key_values if available
        if past_key_values is not None:
            chunk_inputs["past_key_values"] = past_key_values

        chunk_inputs["use_cache"] = True

        try:
            # Get reference logits for chunk
            # Note: ref_forward may not support all these kwargs
            ref_outputs = ref_forward(**chunk_inputs)
            ref_logits_chunk, ref_past_key_values = _unwrap_reference_outputs(
                ref_outputs
            )
            if ref_past_key_values is None and s_end < seq_len:
                # Can't continue without cache; fall back to batch micro if allowed
                fallback_microbatch = None
                if allow_auto_microbatch_fallback:
                    fallback_microbatch = (
                        microbatch_size
                        if microbatch_size is not None
                        else max(1, batch_size // _DEFAULT_REF_MICROBATCH_DIVISOR)
                    )
                if fallback_microbatch is not None and fallback_microbatch < batch_size:
                    return _compute_kl_batch_micro(
                        model,
                        cur_logits,
                        shift_labels,
                        valid_mask,
                        ref_forward,
                        forward_inputs,
                        fallback_microbatch,
                        logit_softcapping,
                        logit_scaling,
                        force_fp32,
                        kl_direction,
                    )
                ref_outputs = ref_forward(**forward_inputs)
                ref_logits, _ = _unwrap_reference_outputs(ref_outputs)
                kl_full = _compute_kl_divergence(
                    cur_logits,
                    ref_logits,
                    model,
                    logit_softcapping,
                    logit_scaling,
                    force_fp32,
                    kl_direction,
                )
                if kl_full.dim() == 1:
                    kl_full = kl_full.view(batch_size, seq_len)
                return kl_full
            past_key_values = ref_past_key_values

            cur_logits_chunk = cur_logits[:, s_start:s_end]

            # Compute KL for this chunk
            kl_chunk = _compute_kl_divergence(
                cur_logits_chunk,
                ref_logits_chunk,
                model,
                logit_softcapping,
                logit_scaling,
                force_fp32,
                kl_direction,
            )

            if kl_chunk.dim() == 1:
                chunk_len = s_end - s_start
                kl_chunk = kl_chunk.view(batch_size, chunk_len)

            kl[:, s_start:s_end] = kl_chunk

            del ref_logits_chunk

        except (RuntimeError, ValueError, KeyError, TypeError) as e:
            # Fallback to batch micro or full forward on KV cache errors
            # These exceptions typically indicate the model doesn't support
            # the chunked KV cache approach (e.g., missing past_key_values support)
            fallback_microbatch = None
            if allow_auto_microbatch_fallback:
                fallback_microbatch = (
                    microbatch_size
                    if microbatch_size is not None
                    else max(1, batch_size // _DEFAULT_REF_MICROBATCH_DIVISOR)
                )
            if fallback_microbatch is not None and fallback_microbatch < batch_size:
                return _compute_kl_batch_micro(
                    model,
                    cur_logits,
                    shift_labels,
                    valid_mask,
                    ref_forward,
                    forward_inputs,
                    fallback_microbatch,
                    logit_softcapping,
                    logit_scaling,
                    force_fp32,
                    kl_direction,
                )
            ref_outputs = ref_forward(**forward_inputs)
            ref_logits, _ = _unwrap_reference_outputs(ref_outputs)
            kl_full = _compute_kl_divergence(
                cur_logits,
                ref_logits,
                model,
                logit_softcapping,
                logit_scaling,
                force_fp32,
                kl_direction,
            )
            if kl_full.dim() == 1:
                kl_full = kl_full.view(batch_size, seq_len)
            return kl_full

    return kl
Severity: medium

This function contains significant code duplication in its fallback logic. The code blocks for handling fallbacks for packed sequences (lines 635-671), KV cache failure (lines 741-777), and general exceptions (lines 801-839) are nearly identical.

This duplication makes the function harder to maintain, as any changes to the fallback logic need to be applied in three different places.

To improve maintainability, I recommend refactoring this duplicated logic into a separate private helper function. This helper would encapsulate the logic for deciding between batch-micro and full-forward KL computation and would be called from all three fallback points.

For example, you could introduce a helper like _perform_kl_fallback(...) that contains the shared logic.
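One possible shape for such a helper, assembled from the duplicated blocks quoted above and reusing this module's existing helpers; a sketch, not a drop-in patch (note the packed-sequence branch would additionally prefer an explicitly configured microbatch_size before the auto heuristic):

def _perform_kl_fallback(
    model,
    cur_logits,
    shift_labels,
    valid_mask,
    ref_forward,
    forward_inputs,
    microbatch_size,
    allow_auto_microbatch_fallback,
    logit_softcapping,
    logit_scaling,
    force_fp32,
    kl_direction,
):
    """Shared fallback: batch-micro KL when a microbatch size applies, else a full reference forward."""
    batch_size, seq_len, _ = cur_logits.shape
    fallback_microbatch = None
    if allow_auto_microbatch_fallback:
        fallback_microbatch = (
            microbatch_size
            if microbatch_size is not None
            else max(1, batch_size // _DEFAULT_REF_MICROBATCH_DIVISOR)
        )
    if fallback_microbatch is not None and fallback_microbatch < batch_size:
        return _compute_kl_batch_micro(
            model, cur_logits, shift_labels, valid_mask, ref_forward,
            forward_inputs, fallback_microbatch, logit_softcapping,
            logit_scaling, force_fp32, kl_direction,
        )
    ref_outputs = ref_forward(**forward_inputs)
    ref_logits, _ = _unwrap_reference_outputs(ref_outputs)
    kl_full = _compute_kl_divergence(
        cur_logits, ref_logits, model, logit_softcapping,
        logit_scaling, force_fp32, kl_direction,
    )
    if kl_full.dim() == 1:
        kl_full = kl_full.view(batch_size, seq_len)
    return kl_full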

Comment on lines +801 to +839
        except (RuntimeError, ValueError, KeyError, TypeError) as e:
            # Fallback to batch micro or full forward on KV cache errors
            # These exceptions typically indicate the model doesn't support
            # the chunked KV cache approach (e.g., missing past_key_values support)
            fallback_microbatch = None
            if allow_auto_microbatch_fallback:
                fallback_microbatch = (
                    microbatch_size
                    if microbatch_size is not None
                    else max(1, batch_size // _DEFAULT_REF_MICROBATCH_DIVISOR)
                )
            if fallback_microbatch is not None and fallback_microbatch < batch_size:
                return _compute_kl_batch_micro(
                    model,
                    cur_logits,
                    shift_labels,
                    valid_mask,
                    ref_forward,
                    forward_inputs,
                    fallback_microbatch,
                    logit_softcapping,
                    logit_scaling,
                    force_fp32,
                    kl_direction,
                )
            ref_outputs = ref_forward(**forward_inputs)
            ref_logits, _ = _unwrap_reference_outputs(ref_outputs)
            kl_full = _compute_kl_divergence(
                cur_logits,
                ref_logits,
                model,
                logit_softcapping,
                logit_scaling,
                force_fp32,
                kl_direction,
            )
            if kl_full.dim() == 1:
                kl_full = kl_full.view(batch_size, seq_len)
            return kl_full
Severity: medium

The except block here catches a very broad set of exceptions (RuntimeError, ValueError, KeyError, TypeError). While the comment explains this is for handling models that don't support KV caching, it also risks masking unrelated bugs. For instance, a KeyError could indicate a problem with the forward_inputs dictionary construction, but it would be silently caught here, triggering a performance-degrading fallback instead of surfacing the bug.

To make the code more robust and easier to debug, consider one of the following:

  1. Narrow the exceptions: If possible, catch more specific exceptions that are known to be thrown by incompatible model forward signatures.
  2. Add logging: Inside the except block, log the exception with its traceback at a WARNING or DEBUG level. This would make it clear why the fallback was triggered and help diagnose unexpected errors.

Example with logging:

import logging
logger = logging.getLogger(__name__)
...
        except (RuntimeError, ValueError, KeyError, TypeError) as e:
            logger.warning(
                "Unsloth: KV cache streaming for ASFT failed, falling back. Error: %s",
                e, exc_info=True
            )
            # ... fallback logic ...

This would greatly improve debuggability without sacrificing the robustness of the fallback mechanism.

