Skip to content

Commit 8baf20c

Browse files
authored
Merge branch 'main' into ktyang_device_context
2 parents 4a2c058 + ca0ecf1 commit 8baf20c

File tree

9 files changed

+317
-76
lines changed

9 files changed

+317
-76
lines changed

.github/ISSUE_TEMPLATE/bug---issue.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,4 @@ assignees: ''
1818
Put Minimal code to reproduce error here ###Remove Hugging Face token###
1919
```
2020

21-
🦥 You can also ask via our Reddit page: https://www.reddit.com/r/unsloth/
21+
🦥 You can also ask via our Reddit page: https://reddit.com/r/unsloth/

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
repos:
22
- repo: https://github.com/astral-sh/ruff-pre-commit
3-
rev: v0.14.10
3+
rev: v0.14.11
44
hooks:
55
- id: ruff
66
args:

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,9 @@ Use our official [Unsloth Docker image](https://hub.docker.com/r/unsloth/unsloth
5353
For RTX 50x, B200, 6000 GPUs: `pip install unsloth`. Read our [Blackwell Guide](https://unsloth.ai/docs/basics/fine-tuning-llms-with-blackwell-rtx-50-series-and-unsloth) and [DGX Spark Guide](https://unsloth.ai/docs/basics/fine-tuning-llms-with-nvidia-dgx-spark-and-unsloth) for more details.
5454

5555
## 🦥 Unsloth News
56+
- New 7x longer context reinforcement learning vs. all other setups, via our new batching algorithms. [Blog](https://unsloth.ai/docs/new/grpo-long-context)
5657
- New RoPE & MLP **Triton Kernels** & **Padding Free + Packing**: 3x faster training & 30% less VRAM. [Blog](https://unsloth.ai/docs/new/3x-faster-training-packing)
57-
- **New Mistral**: Run Ministral 3 or Devstral 2 and fine-tune with vision/RL sudoku notebooks. [Guide](https://unsloth.ai/docs/models/ministral-3)[Notebooks](https://unsloth.ai/docs/models/ministral-3#fine-tuning-ministral-3)
58+
- **Mistral 3**: Run Ministral 3 or Devstral 2 and fine-tune with vision/RL sudoku notebooks. [Guide](https://unsloth.ai/docs/models/ministral-3)[Notebooks](https://unsloth.ai/docs/models/ministral-3#fine-tuning-ministral-3)
5859
- **500K Context**: Training a 20B model with >500K context is now possible on an 80GB GPU. [Blog](https://unsloth.ai/docs/new/500k-context-length-fine-tuning)
5960
- **FP8 Reinforcement Learning**: You can now do FP8 GRPO on consumer GPUs. [Blog](https://unsloth.ai/docs/new/fp8-reinforcement-learning)[Notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Qwen3_8B_FP8_GRPO.ipynb)
6061
- **DeepSeek-OCR**: Fine-tune to improve language understanding by 89%. [Guide](https://unsloth.ai/docs/models/deepseek-ocr-how-to-run-and-fine-tune)[Notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Deepseek_OCR_(3B).ipynb)

pyproject.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ huggingfacenotorch = [
5151
"sentencepiece>=0.2.0",
5252
"datasets>=3.4.1,!=4.0.*,!=4.1.0,<4.4.0",
5353
"accelerate>=0.34.1",
54-
"peft>=0.7.1,!=0.11.0",
54+
"peft>=0.18.0,!=0.11.0",
5555
"huggingface_hub>=0.34.0",
5656
"hf_transfer",
5757
"diffusers",
@@ -60,7 +60,7 @@ huggingfacenotorch = [
6060
]
6161
huggingface = [
6262
"unsloth[huggingfacenotorch]",
63-
"unsloth_zoo>=2026.1.2",
63+
"unsloth_zoo>=2026.1.3",
6464
"torchvision",
6565
"unsloth[triton]",
6666
]
@@ -523,7 +523,7 @@ colab-ampere-torch220 = [
523523
"flash-attn>=2.6.3 ; ('linux' in sys_platform)",
524524
]
525525
colab-new = [
526-
"unsloth_zoo>=2026.1.2",
526+
"unsloth_zoo>=2026.1.3",
527527
"packaging",
528528
"tyro",
529529
"transformers>=4.51.3,!=4.52.0,!=4.52.1,!=4.52.2,!=4.52.3,!=4.53.0,!=4.54.0,!=4.55.0,!=4.55.1,!=4.57.0,<=4.57.3",
@@ -542,7 +542,7 @@ colab-new = [
542542
colab-no-deps = [
543543
"accelerate>=0.34.1",
544544
"trl>=0.18.2,!=0.19.0,<=0.24.0",
545-
"peft>=0.7.1",
545+
"peft>=0.18.0",
546546
"xformers ; ('linux' in sys_platform or sys_platform == 'win32') and (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
547547
"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0",
548548
"protobuf",

unsloth/kernels/swiglu.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def _DWf_DW_dfg_kernel(
128128

129129

130130
def swiglu_DWf_DW_dfg_kernel(DW, e, g):
131-
batch_seq_len, hd = e.shape
131+
batch_seq_len, hd = e.shape # Flattened to 2D, so 1st dim is bsz * seq_len
132132
n_elements = e.numel()
133133
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
134134
with torch_gpu_device(e.device):

unsloth/models/_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
__version__ = "2026.1.2"
15+
__version__ = "2026.1.3"
1616

1717
__all__ = [
1818
"SUPPORTS_BFLOAT16",

unsloth/models/rl.py

Lines changed: 66 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -231,11 +231,13 @@ def unsloth_prediction_step(
231231
Trainer.prediction_step = unsloth_prediction_step
232232

233233

234+
grpo_selective_log_softmax = RL_REPLACEMENTS["grpo_selective_log_softmax"]
234235
selective_log_softmax = RL_REPLACEMENTS["selective_log_softmax"]
235236
calculate_pad_tokens_in_prompt = RL_REPLACEMENTS["calculate_pad_tokens_in_prompt"]
236237
create_completion_attention_mask = RL_REPLACEMENTS["create_completion_attention_mask"]
237238
left_pack_padding = RL_REPLACEMENTS["left_pack_padding"]
238239
align_logprobs_with_mask = RL_REPLACEMENTS["align_logprobs_with_mask"]
240+
autotune_batch_and_chunks = RL_REPLACEMENTS["grpo_autotune_batch_and_chunks"]
239241

240242
RLTrainer_replacement = '''
241243
import os
@@ -247,7 +249,6 @@ def unsloth_prediction_step(
247249
from contextlib import nullcontext
248250
from torch.nn import functional as F
249251
import inspect
250-
import psutil
251252
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
252253
from transformers.training_args import ParallelMode
253254
@@ -264,17 +265,19 @@ def prepare_for_training_mode(f):
264265
def wrapper(self, *args, **kwargs):
265266
# Enable training mode
266267
_was_training = None
268+
# Get gradient checkpointing setting from training arguments
269+
use_gc = getattr(self.args, 'gradient_checkpointing', True)
267270
if hasattr(self, 'model') and hasattr(self.model, "training"):
268271
_was_training = self.model.training
269272
if hasattr(self, 'model') and hasattr(self.model, "for_training"):
270-
self.model.for_training()
273+
self.model.for_training(use_gradient_checkpointing=use_gc)
271274
output = f(self, *args, **kwargs)
272275
# Restore previous mode when possible
273276
if hasattr(self, 'model') and hasattr(self.model, "for_inference"):
274277
if _was_training is False:
275278
self.model.for_inference()
276279
elif _was_training is True and hasattr(self.model, "for_training"):
277-
self.model.for_training()
280+
self.model.for_training(use_gradient_checkpointing=use_gc)
278281
# Reset gradient checkpointing buffers to free memory while staying ready for next run
279282
try:
280283
reset_unsloth_gradient_checkpointing_buffers()
@@ -298,11 +301,13 @@ def wrapper(self, *args, **kwargs):
298301
"triton.cudagraphs" : False,
299302
}}
300303
304+
{grpo_selective_log_softmax_code}
301305
{selective_log_softmax_code}
302306
{calculate_pad_tokens_in_prompt_code}
303307
{create_completion_attention_mask_code}
304308
{left_pack_padding_code}
305309
{align_logprobs_with_mask_code}
310+
{autotune_batch_and_chunks_code}
306311
307312
{RL_pre}
308313
@@ -319,17 +324,36 @@ class Unsloth{RLConfig_name}({RLConfig_name}):
319324
default = -1,
320325
metadata = {{'help': 'Chunk size to reduce memory usage. -1 is most efficient.'}},
321326
)
327+
unsloth_logit_chunk_multiplier : Optional[int] = field(
328+
default = None,
329+
metadata = {{'help': 'Multiplier for chunked logit computations.'}},
330+
)
331+
unsloth_grpo_mini_batch : Optional[int] = field(
332+
default = None,
333+
metadata = {{'help': 'Mini batch size for GRPO hidden state accumulation. Default is None unless user defines it.'}},
334+
)
322335
{max_seq_length_pre}
323336
def __init__({RLConfig_arguments},
324337
vllm_sampling_params = None,
325338
unsloth_num_chunks = -1,
339+
unsloth_logit_chunk_multiplier = None,
340+
unsloth_grpo_mini_batch = None,
326341
{max_seq_length_call}
327342
**kwargs,
328343
):
329344
{RLConfig_extra_args}
330345
super().__init__({RLConfig_call_args}{RLConfig_kwargs})
331346
self.vllm_sampling_params = vllm_sampling_params
332347
self.unsloth_num_chunks = unsloth_num_chunks
348+
if unsloth_grpo_mini_batch is not None:
349+
if self.generation_batch_size >= unsloth_grpo_mini_batch:
350+
self.unsloth_grpo_mini_batch = unsloth_grpo_mini_batch
351+
else:
352+
raise ValueError(
353+
f"Unsloth GRPO mini batch size needs to be less than or equal to the effective generation batch size, "
354+
f"which is self.per_device_train_batch_size * gradient_accumulation_steps."
355+
)
356+
self.unsloth_logit_chunk_multiplier = unsloth_logit_chunk_multiplier
333357
{max_seq_length_post}
334358
pass
335359
@@ -1027,6 +1051,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
10271051

10281052
# Selective log softmax and other functions
10291053
selective_log_softmax_code = inspect.getsource(selective_log_softmax)
1054+
grpo_selective_log_softmax_code = inspect.getsource(grpo_selective_log_softmax)
10301055
calculate_pad_tokens_in_prompt_code = inspect.getsource(
10311056
calculate_pad_tokens_in_prompt
10321057
)
@@ -1035,6 +1060,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
10351060
)
10361061
left_pack_padding_code = inspect.getsource(left_pack_padding)
10371062
align_logprobs_with_mask_code = inspect.getsource(align_logprobs_with_mask)
1063+
autotune_batch_and_chunks_code = inspect.getsource(autotune_batch_and_chunks)
10381064
# Get final source code
10391065
RLTrainer_source = RLTrainer_replacement.format(
10401066
RLTrainer_name = RLTrainer_name,
@@ -1056,8 +1082,10 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
10561082
max_seq_length_call = max_seq_length_call,
10571083
max_seq_length_post = max_seq_length_post,
10581084
selective_log_softmax_code = selective_log_softmax_code,
1085+
grpo_selective_log_softmax_code = grpo_selective_log_softmax_code,
10591086
calculate_pad_tokens_in_prompt_code = calculate_pad_tokens_in_prompt_code,
10601087
create_completion_attention_mask_code = create_completion_attention_mask_code,
1088+
autotune_batch_and_chunks_code = autotune_batch_and_chunks_code,
10611089
left_pack_padding_code = left_pack_padding_code,
10621090
align_logprobs_with_mask_code = align_logprobs_with_mask_code,
10631091
)
@@ -1166,6 +1194,41 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import
11661194
"model = self._prepare_peft_model(model, peft_config, args)\n", "pass\n"
11671195
)
11681196

1197+
# Skip add_adapter("ref") for reference model computation
1198+
# Unsloth: We comment out the "ref" adapter creation because:
1199+
# 1. We want to use the original BASE MODEL as the reference model, not the SFT/LoRA model
1200+
# 2. PEFT doesn't allow multiple adapters when target_parameters is used (MoE models)
1201+
# When "ref" is not in peft_config, GRPO/RLOO fallback uses disable_adapter()
1202+
# which gives the base model logits - exactly what we want
1203+
add_adapter_block_pattern = (
1204+
r"([ \t]*)" # Capture leading indentation
1205+
r"if\s+is_peft_available\(\)\s+and\s+is_peft_model\(model\)\s+and\s+args\.beta\s*!=\s*0\.0\s*:"
1206+
r"(.*?)" # Match the entire block until ref_param.data.copy_
1207+
r"ref_param\.data\.copy_\(param\.data\)"
1208+
)
1209+
1210+
def comment_out_block(match):
1211+
"""Comment out each line in the matched block, preserving indentation."""
1212+
full_match = match.group(0)
1213+
indent = match.group(1)
1214+
lines = full_match.split("\n")
1215+
commented_lines = []
1216+
# Add explanation comment first
1217+
commented_lines.append(
1218+
f"{indent}# Unsloth: Commented out - use base model as reference, not SFT/LoRA model"
1219+
)
1220+
# Comment out each line - insert # after leading whitespace to preserve indentation
1221+
for line in lines:
1222+
if line.strip():
1223+
stripped = line.lstrip()
1224+
leading_ws = line[: len(line) - len(stripped)]
1225+
commented_lines.append(f"{leading_ws}# {stripped}")
1226+
else:
1227+
commented_lines.append(line)
1228+
return "\n".join(commented_lines)
1229+
1230+
init = re.sub(add_adapter_block_pattern, comment_out_block, init, flags = re.DOTALL)
1231+
11691232
# Set use_vllm if not set
11701233
if "args.use_vllm" in init and "model" in init and "args" in init:
11711234
# .*? matches first match. .+? matches final match.

0 commit comments

Comments
 (0)