From dd4f3a6ffbc0dd52f427c63a582dfe2497301aec Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Thu, 12 Mar 2026 12:51:52 +0100 Subject: [PATCH 01/12] Add Nemotron 3 to tests via tiny model --- scripts/generate_tiny_models.py | 28 ++++++++++++++++++++++++++++ tests/test_dpo_trainer.py | 11 +++++++++++ tests/test_sft_trainer.py | 14 +++++++++++++- 3 files changed, 52 insertions(+), 1 deletion(-) diff --git a/scripts/generate_tiny_models.py b/scripts/generate_tiny_models.py index 51e0444e6a6..51754444e17 100644 --- a/scripts/generate_tiny_models.py +++ b/scripts/generate_tiny_models.py @@ -58,6 +58,8 @@ LlavaForConditionalGeneration, LlavaNextForConditionalGeneration, MistralConfig, + NemotronHConfig, + NemotronHForCausalLM, MistralForCausalLM, OPTConfig, OPTForCausalLM, @@ -227,6 +229,32 @@ def init_weights_tiny_model(model): init_weights_tiny_model(model) push_to_hub(model, tokenizer, generation_config, "tiny", suffix) +# Hybrid Mamba-Attention models +tokenizer = AutoTokenizer.from_pretrained("nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16") +generation_config = GenerationConfig.from_pretrained("nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16") +config = NemotronHConfig( + vocab_size=len(tokenizer.vocab), + hidden_size=8, + num_attention_heads=4, + num_key_value_heads=2, + intermediate_size=32, + layers_block_type=["mamba", "attention"], # 2 layers: one Mamba + one Attention + mamba_num_heads=4, + mamba_head_dim=2, + mamba_n_groups=1, + ssm_state_size=8, + mamba_d_conv=4, + mamba_expand=2, + n_routed_experts=4, + num_experts_per_tok=2, + moe_intermediate_size=32, + moe_shared_expert_intermediate_size=32, + use_mamba_kernels=False, # CPU-friendly for testing +) +model = NemotronHForCausalLM(config).to(dtype=torch.bfloat16) +init_weights_tiny_model(model) +push_to_hub(model, tokenizer, generation_config, "tiny") + # Two slightly bigger models, required for vLLM testing tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-32B-Instruct") generation_config = GenerationConfig.from_pretrained("Qwen/Qwen2.5-32B-Instruct") diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py index dd5f4fa4699..33c66ccc257 100644 --- a/tests/test_dpo_trainer.py +++ b/tests/test_dpo_trainer.py @@ -139,17 +139,28 @@ class TestDPOTrainer(TrlTestCase): "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", "trl-internal-testing/tiny-Qwen3MoeForCausalLM", "trl-internal-testing/tiny-GptOssForCausalLM", + pytest.param( + "trl-internal-testing/tiny-NemotronHForCausalLM", + marks=pytest.mark.skipif( + Version(transformers.__version__) < Version("5.3.0"), + reason="NemotronH models were introduced in transformers-5.3.0", + ), + ), ], ) def test_train(self, model_id): # Get the dataset dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train") + # NemotronH does not support gradient checkpointing + gradient_checkpointing = "NemotronH" not in model_id + # Initialize the trainer training_args = DPOConfig( output_dir=self.tmp_dir, learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates report_to="none", + gradient_checkpointing=gradient_checkpointing, ) trainer = DPOTrainer(model=model_id, args=training_args, train_dataset=dataset) diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 442b9c247ea..b947dd46c26 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -284,14 +284,26 @@ def test_init_with_training_arguments(self): "trl-internal-testing/tiny-GptOssForCausalLM", 
"trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", "trl-internal-testing/tiny-Qwen3MoeForCausalLM", + pytest.param( + "trl-internal-testing/tiny-NemotronHForCausalLM", + marks=pytest.mark.skipif( + Version(transformers.__version__) < Version("5.3.0"), + reason="NemotronH models were introduced in transformers-5.3.0", + ), + ), ], ) def test_train(self, model_id): # Get the dataset dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train") + # NemotronH does not support gradient checkpointing + gradient_checkpointing = "NemotronH" not in model_id + # Initialize the trainer - training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none") + training_args = SFTConfig( + output_dir=self.tmp_dir, report_to="none", gradient_checkpointing=gradient_checkpointing + ) trainer = SFTTrainer(model=model_id, args=training_args, train_dataset=dataset) # Save the initial parameters to compare them later From 529ef04f756d6d5222fc41a56222af04da0b30d7 Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Thu, 12 Mar 2026 12:55:36 +0100 Subject: [PATCH 02/12] Code quality --- scripts/generate_tiny_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/generate_tiny_models.py b/scripts/generate_tiny_models.py index 51754444e17..393b9545f44 100644 --- a/scripts/generate_tiny_models.py +++ b/scripts/generate_tiny_models.py @@ -58,9 +58,9 @@ LlavaForConditionalGeneration, LlavaNextForConditionalGeneration, MistralConfig, + MistralForCausalLM, NemotronHConfig, NemotronHForCausalLM, - MistralForCausalLM, OPTConfig, OPTForCausalLM, PaliGemmaForConditionalGeneration, From b79861fb0957e7bf3b5a331196e853ab2cd1a45e Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Thu, 12 Mar 2026 14:41:32 +0100 Subject: [PATCH 03/12] Updated --- scripts/generate_tiny_models.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scripts/generate_tiny_models.py b/scripts/generate_tiny_models.py index 393b9545f44..2f30ecf28aa 100644 --- a/scripts/generate_tiny_models.py +++ b/scripts/generate_tiny_models.py @@ -234,15 +234,15 @@ def init_weights_tiny_model(model): generation_config = GenerationConfig.from_pretrained("nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16") config = NemotronHConfig( vocab_size=len(tokenizer.vocab), - hidden_size=8, + hidden_size=16, num_attention_heads=4, num_key_value_heads=2, intermediate_size=32, layers_block_type=["mamba", "attention"], # 2 layers: one Mamba + one Attention mamba_num_heads=4, - mamba_head_dim=2, + mamba_head_dim=8, mamba_n_groups=1, - ssm_state_size=8, + ssm_state_size=16, mamba_d_conv=4, mamba_expand=2, n_routed_experts=4, From 4f9b1f88121c0846cb04bfa5df7b8c058aea67d4 Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Thu, 12 Mar 2026 15:35:22 +0100 Subject: [PATCH 04/12] Update --- scripts/generate_tiny_models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/generate_tiny_models.py b/scripts/generate_tiny_models.py index 2f30ecf28aa..b6a295ce68b 100644 --- a/scripts/generate_tiny_models.py +++ b/scripts/generate_tiny_models.py @@ -239,8 +239,8 @@ def init_weights_tiny_model(model): num_key_value_heads=2, intermediate_size=32, layers_block_type=["mamba", "attention"], # 2 layers: one Mamba + one Attention - mamba_num_heads=4, - mamba_head_dim=8, + mamba_num_heads=8, + mamba_head_dim=4, mamba_n_groups=1, ssm_state_size=16, mamba_d_conv=4, From c5550400560b962e67002109f0eda6dbde746d84 Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Fri, 13 Mar 2026 12:37:31 +0100 Subject: 
[PATCH 05/12] Update --- tests/test_dpo_trainer.py | 9 ++++++--- tests/test_sft_trainer.py | 11 ++++++++--- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py index 6c151d5ef6f..3a0d28a5c6f 100644 --- a/tests/test_dpo_trainer.py +++ b/tests/test_dpo_trainer.py @@ -184,15 +184,18 @@ def test_train(self, model_id): # Get the dataset dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train") - # NemotronH does not support gradient checkpointing - gradient_checkpointing = "NemotronH" not in model_id + # NemotronH (hybrid Mamba-Attention) does not support gradient checkpointing. The Mamba CUDA + # kernels require strides to be multiples of 8, which is incompatible with tiny model dimensions. + # Force CPU so that the model uses the pure PyTorch path (works fine on GPU without kernels). + is_nemotron = "NemotronH" in model_id # Initialize the trainer training_args = DPOConfig( output_dir=self.tmp_dir, learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates report_to="none", - gradient_checkpointing=gradient_checkpointing, + gradient_checkpointing=not is_nemotron, + use_cpu=is_nemotron, ) trainer = DPOTrainer(model=model_id, args=training_args, train_dataset=dataset) diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index b947dd46c26..4a19c72652d 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -297,12 +297,17 @@ def test_train(self, model_id): # Get the dataset dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train") - # NemotronH does not support gradient checkpointing - gradient_checkpointing = "NemotronH" not in model_id + # NemotronH (hybrid Mamba-Attention) does not support gradient checkpointing. The Mamba CUDA + # kernels require strides to be multiples of 8, which is incompatible with tiny model dimensions. + # Force CPU so that the model uses the pure PyTorch path (works fine on GPU without kernels). + is_nemotron = "NemotronH" in model_id # Initialize the trainer training_args = SFTConfig( - output_dir=self.tmp_dir, report_to="none", gradient_checkpointing=gradient_checkpointing + output_dir=self.tmp_dir, + report_to="none", + gradient_checkpointing=not is_nemotron, + use_cpu=is_nemotron, ) trainer = SFTTrainer(model=model_id, args=training_args, train_dataset=dataset) From 3c8f9d4f3a7228833827d6ba0086ce319fd8064a Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Fri, 13 Mar 2026 16:28:48 +0100 Subject: [PATCH 06/12] Cursor review --- tests/test_dpo_trainer.py | 8 +++++--- tests/test_sft_trainer.py | 12 +++++------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py index 3a0d28a5c6f..c41fe94d61f 100644 --- a/tests/test_dpo_trainer.py +++ b/tests/test_dpo_trainer.py @@ -187,15 +187,17 @@ def test_train(self, model_id): # NemotronH (hybrid Mamba-Attention) does not support gradient checkpointing. The Mamba CUDA # kernels require strides to be multiples of 8, which is incompatible with tiny model dimensions. # Force CPU so that the model uses the pure PyTorch path (works fine on GPU without kernels). 
- is_nemotron = "NemotronH" in model_id + kwargs = {} + if "NemotronH" in model_id: + kwargs["gradient_checkpointing"] = False + kwargs["use_cpu"] = True # Initialize the trainer training_args = DPOConfig( output_dir=self.tmp_dir, learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates report_to="none", - gradient_checkpointing=not is_nemotron, - use_cpu=is_nemotron, + **kwargs, ) trainer = DPOTrainer(model=model_id, args=training_args, train_dataset=dataset) diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 4a19c72652d..2bd479175dc 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -300,15 +300,13 @@ def test_train(self, model_id): # NemotronH (hybrid Mamba-Attention) does not support gradient checkpointing. The Mamba CUDA # kernels require strides to be multiples of 8, which is incompatible with tiny model dimensions. # Force CPU so that the model uses the pure PyTorch path (works fine on GPU without kernels). - is_nemotron = "NemotronH" in model_id + kwargs = {} + if "NemotronH" in model_id: + kwargs["gradient_checkpointing"] = False + kwargs["use_cpu"] = True # Initialize the trainer - training_args = SFTConfig( - output_dir=self.tmp_dir, - report_to="none", - gradient_checkpointing=not is_nemotron, - use_cpu=is_nemotron, - ) + training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none", **kwargs) trainer = SFTTrainer(model=model_id, args=training_args, train_dataset=dataset) # Save the initial parameters to compare them later From 9eceafa53dd87c7780c8f56b3326a1677186f7ef Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Wed, 18 Mar 2026 15:41:04 +0100 Subject: [PATCH 07/12] Update --- scripts/generate_tiny_models.py | 7 +++++++ tests/test_dpo_trainer.py | 5 +---- tests/test_sft_trainer.py | 5 +---- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/scripts/generate_tiny_models.py b/scripts/generate_tiny_models.py index b6a295ce68b..c74590d324b 100644 --- a/scripts/generate_tiny_models.py +++ b/scripts/generate_tiny_models.py @@ -253,6 +253,13 @@ def init_weights_tiny_model(model): ) model = NemotronHForCausalLM(config).to(dtype=torch.bfloat16) init_weights_tiny_model(model) +# NemotronH keeps mixer.D and mixer.A_log in float32 in the reference model; mirror that here. +for layer in model.model.layers: + if hasattr(layer, "mixer"): + if hasattr(layer.mixer, "D"): + layer.mixer.D.data = layer.mixer.D.data.float() + if hasattr(layer.mixer, "A_log"): + layer.mixer.A_log.data = layer.mixer.A_log.data.float() push_to_hub(model, tokenizer, generation_config, "tiny") # Two slightly bigger models, required for vLLM testing diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py index abefdd9dc44..cdcf83f77fb 100644 --- a/tests/test_dpo_trainer.py +++ b/tests/test_dpo_trainer.py @@ -184,13 +184,10 @@ def test_train(self, model_id): # Get the dataset dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train") - # NemotronH (hybrid Mamba-Attention) does not support gradient checkpointing. The Mamba CUDA - # kernels require strides to be multiples of 8, which is incompatible with tiny model dimensions. - # Force CPU so that the model uses the pure PyTorch path (works fine on GPU without kernels). 
+ # NemotronH (hybrid Mamba-Attention) does not support gradient checkpointing kwargs = {} if "NemotronH" in model_id: kwargs["gradient_checkpointing"] = False - kwargs["use_cpu"] = True # Initialize the trainer training_args = DPOConfig( diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 2bd479175dc..d3916ed3772 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -297,13 +297,10 @@ def test_train(self, model_id): # Get the dataset dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train") - # NemotronH (hybrid Mamba-Attention) does not support gradient checkpointing. The Mamba CUDA - # kernels require strides to be multiples of 8, which is incompatible with tiny model dimensions. - # Force CPU so that the model uses the pure PyTorch path (works fine on GPU without kernels). + # NemotronH (hybrid Mamba-Attention) does not support gradient checkpointing kwargs = {} if "NemotronH" in model_id: kwargs["gradient_checkpointing"] = False - kwargs["use_cpu"] = True # Initialize the trainer training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none", **kwargs) From b965041085a11f05c9a7c25b0974ce7e10c0e69a Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Thu, 19 Mar 2026 11:27:08 +0100 Subject: [PATCH 08/12] Updated --- tests/test_dpo_trainer.py | 23 +++++++++++++++-------- tests/test_sft_trainer.py | 12 +++++++++--- 2 files changed, 24 insertions(+), 11 deletions(-) diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py index cdcf83f77fb..b5d68646cdf 100644 --- a/tests/test_dpo_trainer.py +++ b/tests/test_dpo_trainer.py @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from contextlib import nullcontext +from unittest.mock import patch + import pytest import torch import transformers @@ -184,19 +187,23 @@ def test_train(self, model_id): # Get the dataset dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train") - # NemotronH (hybrid Mamba-Attention) does not support gradient checkpointing + # NemotronH (hybrid Mamba-Attention) does not support gradient checkpointing. + # Workaround: hide kernels package so transformers doesn't unconditionally load mamba CUDA kernels. 
+ # See: https://github.com/huggingface/transformers/pull/44853 kwargs = {} + ctx = patch("transformers.integrations.hub_kernels.lazy_load_kernel", return_value=None) if "NemotronH" in model_id else nullcontext() if "NemotronH" in model_id: kwargs["gradient_checkpointing"] = False # Initialize the trainer - training_args = DPOConfig( - output_dir=self.tmp_dir, - learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates - report_to="none", - **kwargs, - ) - trainer = DPOTrainer(model=model_id, args=training_args, train_dataset=dataset) + with ctx: + training_args = DPOConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates + report_to="none", + **kwargs, + ) + trainer = DPOTrainer(model=model_id, args=training_args, train_dataset=dataset) # Save the initial parameters to compare them later previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index d3916ed3772..286d1eab4b8 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -15,6 +15,8 @@ import gc import json import pathlib +from contextlib import nullcontext +from unittest.mock import patch from unittest.mock import MagicMock, patch import pytest @@ -297,14 +299,18 @@ def test_train(self, model_id): # Get the dataset dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train") - # NemotronH (hybrid Mamba-Attention) does not support gradient checkpointing + # NemotronH (hybrid Mamba-Attention) does not support gradient checkpointing. + # Workaround: hide kernels package so transformers doesn't unconditionally load mamba CUDA kernels. + # See: https://github.com/huggingface/transformers/pull/44853 kwargs = {} + ctx = patch("transformers.integrations.hub_kernels.lazy_load_kernel", return_value=None) if "NemotronH" in model_id else nullcontext() if "NemotronH" in model_id: kwargs["gradient_checkpointing"] = False # Initialize the trainer - training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none", **kwargs) - trainer = SFTTrainer(model=model_id, args=training_args, train_dataset=dataset) + with ctx: + training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none", **kwargs) + trainer = SFTTrainer(model=model_id, args=training_args, train_dataset=dataset) # Save the initial parameters to compare them later previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} From 6ba7b760b0cfb12df041ca6bd3329769dbeb69ba Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Thu, 19 Mar 2026 11:36:56 +0100 Subject: [PATCH 09/12] code quality --- tests/test_dpo_trainer.py | 6 +++++- tests/test_sft_trainer.py | 7 +++++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py index b5d68646cdf..b2a227dcde0 100644 --- a/tests/test_dpo_trainer.py +++ b/tests/test_dpo_trainer.py @@ -191,7 +191,11 @@ def test_train(self, model_id): # Workaround: hide kernels package so transformers doesn't unconditionally load mamba CUDA kernels. 
# See: https://github.com/huggingface/transformers/pull/44853 kwargs = {} - ctx = patch("transformers.integrations.hub_kernels.lazy_load_kernel", return_value=None) if "NemotronH" in model_id else nullcontext() + ctx = ( + patch("transformers.integrations.hub_kernels.lazy_load_kernel", return_value=None) + if "NemotronH" in model_id + else nullcontext() + ) if "NemotronH" in model_id: kwargs["gradient_checkpointing"] = False diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 286d1eab4b8..142581eea38 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -16,7 +16,6 @@ import json import pathlib from contextlib import nullcontext -from unittest.mock import patch from unittest.mock import MagicMock, patch import pytest @@ -303,7 +302,11 @@ def test_train(self, model_id): # Workaround: hide kernels package so transformers doesn't unconditionally load mamba CUDA kernels. # See: https://github.com/huggingface/transformers/pull/44853 kwargs = {} - ctx = patch("transformers.integrations.hub_kernels.lazy_load_kernel", return_value=None) if "NemotronH" in model_id else nullcontext() + ctx = ( + patch("transformers.integrations.hub_kernels.lazy_load_kernel", return_value=None) + if "NemotronH" in model_id + else nullcontext() + ) if "NemotronH" in model_id: kwargs["gradient_checkpointing"] = False From aef3814f6314708f5cad4891459cfa5e17e29849 Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Mon, 20 Apr 2026 10:56:02 +0200 Subject: [PATCH 10/12] Replace hasattr with explicit Mamba layer access in NemotronH tiny model --- scripts/generate_tiny_models.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/scripts/generate_tiny_models.py b/scripts/generate_tiny_models.py index b8d2d1d7f07..da6e1ac32da 100644 --- a/scripts/generate_tiny_models.py +++ b/scripts/generate_tiny_models.py @@ -255,12 +255,10 @@ def init_weights_tiny_model(model): model = NemotronHForCausalLM(config).to(dtype=torch.bfloat16) init_weights_tiny_model(model) # NemotronH keeps mixer.D and mixer.A_log in float32 in the reference model; mirror that here. -for layer in model.model.layers: - if hasattr(layer, "mixer"): - if hasattr(layer.mixer, "D"): - layer.mixer.D.data = layer.mixer.D.data.float() - if hasattr(layer.mixer, "A_log"): - layer.mixer.A_log.data = layer.mixer.A_log.data.float() +# Layer 0 is the Mamba layer per layers_block_type above. 
+mamba_layer = model.model.layers[0] +mamba_layer.mixer.D.data = mamba_layer.mixer.D.data.float() +mamba_layer.mixer.A_log.data = mamba_layer.mixer.A_log.data.float() push_to_hub(model, tokenizer, generation_config, "tiny") # Two slightly bigger models, required for vLLM testing From ee4906bd1aa6e608cf1825eafc129ba9bfc948b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 20 Apr 2026 17:51:45 +0000 Subject: [PATCH 11/12] revert grad chkpt patch --- tests/test_data_utils.py | 7 +++++++ tests/test_dpo_trainer.py | 28 ++++++---------------------- tests/test_sft_trainer.py | 18 ++---------------- 3 files changed, 15 insertions(+), 38 deletions(-) diff --git a/tests/test_data_utils.py b/tests/test_data_utils.py index a2b0d4a117f..cf4111d2a72 100644 --- a/tests/test_data_utils.py +++ b/tests/test_data_utils.py @@ -542,6 +542,13 @@ class TestApplyChatTemplate(TrlTestCase): "trl-internal-testing/tiny-LlamaForCausalLM-3", "trl-internal-testing/tiny-MistralForCausalLM-0.1", "trl-internal-testing/tiny-MistralForCausalLM-0.2", + pytest.param( + "trl-internal-testing/tiny-NemotronHForCausalLM", + marks=pytest.mark.skipif( + Version(transformers.__version__) < Version("5.3.0"), + reason="NemotronH models were introduced in transformers-5.3.0", + ), + ), "trl-internal-testing/tiny-Phi3ForCausalLM", "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", "trl-internal-testing/tiny-Qwen3ForCausalLM", diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py index 980a23e10d9..aae5bf5a74e 100644 --- a/tests/test_dpo_trainer.py +++ b/tests/test_dpo_trainer.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from contextlib import nullcontext -from unittest.mock import patch import pytest import torch @@ -187,27 +185,13 @@ def test_train(self, model_id): # Get the dataset dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train") - # NemotronH (hybrid Mamba-Attention) does not support gradient checkpointing. - # Workaround: hide kernels package so transformers doesn't unconditionally load mamba CUDA kernels. 
- # See: https://github.com/huggingface/transformers/pull/44853 - kwargs = {} - ctx = ( - patch("transformers.integrations.hub_kernels.lazy_load_kernel", return_value=None) - if "NemotronH" in model_id - else nullcontext() - ) - if "NemotronH" in model_id: - kwargs["gradient_checkpointing"] = False - # Initialize the trainer - with ctx: - training_args = DPOConfig( - output_dir=self.tmp_dir, - learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates - report_to="none", - **kwargs, - ) - trainer = DPOTrainer(model=model_id, args=training_args, train_dataset=dataset) + training_args = DPOConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates + report_to="none", + ) + trainer = DPOTrainer(model=model_id, args=training_args, train_dataset=dataset) # Save the initial parameters to compare them later previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 4ec739e8779..4761a4860e9 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -15,7 +15,6 @@ import gc import json import pathlib -from contextlib import nullcontext from unittest.mock import MagicMock, patch import pytest @@ -372,22 +371,9 @@ def test_train(self, model_id): # Get the dataset dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train") - # NemotronH (hybrid Mamba-Attention) does not support gradient checkpointing. - # Workaround: hide kernels package so transformers doesn't unconditionally load mamba CUDA kernels. - # See: https://github.com/huggingface/transformers/pull/44853 - kwargs = {} - ctx = ( - patch("transformers.integrations.hub_kernels.lazy_load_kernel", return_value=None) - if "NemotronH" in model_id - else nullcontext() - ) - if "NemotronH" in model_id: - kwargs["gradient_checkpointing"] = False - # Initialize the trainer - with ctx: - training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none", **kwargs) - trainer = SFTTrainer(model=model_id, args=training_args, train_dataset=dataset) + training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none") + trainer = SFTTrainer(model=model_id, args=training_args, train_dataset=dataset) # Save the initial parameters to compare them later previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} From b12633bf9fbbc5538f11b4e0291e48f298a77fa6 Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Mon, 27 Apr 2026 10:45:30 +0200 Subject: [PATCH 12/12] Bumped transformers version for NemotronH tests --- examples/scripts/sft_nemotron_3.py | 7 ++----- tests/test_data_utils.py | 4 ++-- tests/test_dpo_trainer.py | 4 ++-- tests/test_sft_trainer.py | 4 ++-- 4 files changed, 8 insertions(+), 11 deletions(-) diff --git a/examples/scripts/sft_nemotron_3.py b/examples/scripts/sft_nemotron_3.py index c3f6aa0c938..8fbf6671636 100644 --- a/examples/scripts/sft_nemotron_3.py +++ b/examples/scripts/sft_nemotron_3.py @@ -15,7 +15,7 @@ # /// script # dependencies = [ # "trl[peft,quantization]", -# "transformers>=5.3.0", +# "transformers>=5.7.0", # "trackio", # "mamba_ssm==2.2.5", # "causal_conv1d==1.5.2", @@ -27,7 +27,7 @@ Prerequisites: - pip install "transformers>=5.3.0" + pip install "transformers>=5.7.0" pip install --no-build-isolation mamba_ssm==2.2.5 pip install --no-build-isolation causal_conv1d==1.5.2 @@ -62,9 +62,6 @@ def main(script_args, training_args, 
model_args): - # NemotronH does not support gradient checkpointing - training_args.gradient_checkpointing = False - # Load model model_kwargs = dict( revision=model_args.model_revision, diff --git a/tests/test_data_utils.py b/tests/test_data_utils.py index 98d7b29f2f1..accd47ef750 100644 --- a/tests/test_data_utils.py +++ b/tests/test_data_utils.py @@ -545,8 +545,8 @@ class TestApplyChatTemplate(TrlTestCase): pytest.param( "trl-internal-testing/tiny-NemotronHForCausalLM", marks=pytest.mark.skipif( - Version(transformers.__version__) < Version("5.3.0"), - reason="NemotronH models were introduced in transformers-5.3.0", + Version(transformers.__version__) < Version("5.7.0"), + reason="NemotronH gradient checkpointing requires transformers>=5.7.0 (see transformers#45625)", ), ), "trl-internal-testing/tiny-Phi3ForCausalLM-3", diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py index fd77c0f0f13..238c37145b3 100644 --- a/tests/test_dpo_trainer.py +++ b/tests/test_dpo_trainer.py @@ -175,8 +175,8 @@ class TestDPOTrainer(TrlTestCase): pytest.param( "trl-internal-testing/tiny-NemotronHForCausalLM", marks=pytest.mark.skipif( - Version(transformers.__version__) < Version("5.3.0"), - reason="NemotronH models were introduced in transformers-5.3.0", + Version(transformers.__version__) < Version("5.7.0"), + reason="NemotronH gradient checkpointing requires transformers>=5.7.0 (see transformers#45625)", ), ), ], diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index c198eaed463..00ce9021174 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -361,8 +361,8 @@ def test_init_with_training_arguments(self): pytest.param( "trl-internal-testing/tiny-NemotronHForCausalLM", marks=pytest.mark.skipif( - Version(transformers.__version__) < Version("5.3.0"), - reason="NemotronH models were introduced in transformers-5.3.0", + Version(transformers.__version__) < Version("5.7.0"), + reason="NemotronH gradient checkpointing requires transformers>=5.7.0 (see transformers#45625)", ), ), ],
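A quick way to sanity-check the tiny hybrid config before regenerating and pushing the model is a standalone forward/backward pass on CPU. The sketch below is not part of the patch series: it mirrors the NemotronHConfig values from scripts/generate_tiny_models.py as they stand after PATCH 04, assumes a transformers release that ships NemotronHConfig/NemotronHForCausalLM with exactly these fields (>=5.7.0 per the final patch), and substitutes a hypothetical fixed vocab_size of 64 with random token ids in place of the real Nemotron tokenizer.

import torch
from transformers import NemotronHConfig, NemotronHForCausalLM

# Tiny hybrid Mamba + Attention + MoE config, copied from scripts/generate_tiny_models.py above.
# vocab_size=64 is a hypothetical stand-in for len(tokenizer.vocab).
config = NemotronHConfig(
    vocab_size=64,
    hidden_size=16,
    num_attention_heads=4,
    num_key_value_heads=2,
    intermediate_size=32,
    layers_block_type=["mamba", "attention"],  # 2 layers: one Mamba + one Attention
    mamba_num_heads=8,
    mamba_head_dim=4,
    mamba_n_groups=1,
    ssm_state_size=16,
    mamba_d_conv=4,
    mamba_expand=2,
    n_routed_experts=4,
    num_experts_per_tok=2,
    moe_intermediate_size=32,
    moe_shared_expert_intermediate_size=32,
    use_mamba_kernels=False,  # stay on the pure PyTorch path so this runs on CPU
)
model = NemotronHForCausalLM(config)

# Random batch of token ids; labels=input_ids yields a causal-LM loss we can backprop through.
input_ids = torch.randint(0, config.vocab_size, (2, 16))
out = model(input_ids=input_ids, labels=input_ids)
out.loss.backward()
print(f"loss={out.loss.item():.4f}, logits shape={tuple(out.logits.shape)}")

If the forward and backward passes complete, the same config should be safe to instantiate in generate_tiny_models.py and push to the Hub for the parametrized DPO/SFT tests above.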