Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/experimental/test_async_grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def test_init_minimal(self):
rollout_worker=_StubRolloutWorker(AutoTokenizer.from_pretrained(model_id), dataset, num_generations=3),
)

def test_training(self):
def test_train(self):
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_completion", split="train")

Expand Down
14 changes: 7 additions & 7 deletions tests/experimental/test_bco_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def test_train(self, config_name):

assert trainer.state.log_history[-1]["train_loss"] is not None

# Check that the parameters have changed
# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
if param.sum() != 0: # ignore 0 biases
Expand Down Expand Up @@ -111,7 +111,7 @@ def test_train_with_precompute(self):

assert trainer.state.log_history[-1]["train_loss"] is not None

# Check that the parameters have changed
# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
if param.sum() != 0: # ignore 0 biases
Expand Down Expand Up @@ -249,7 +249,7 @@ def test_train_without_providing_ref_model(self):

assert trainer.state.log_history[-1]["train_loss"] is not None

# Check that the parameters have changed
# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
if param.sum() != 0: # ignore 0 biases
Expand Down Expand Up @@ -298,7 +298,7 @@ def embed_prompt(input_ids, attention_mask, model):

assert trainer.state.log_history[-1]["train_loss"] is not None

# Check that the parameters have changed
# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
if param.sum() != 0: # ignore 0 biases
Expand Down Expand Up @@ -335,7 +335,7 @@ def test_train_without_providing_ref_model_with_lora(self):

assert trainer.state.log_history[-1]["train_loss"] is not None

# Check that the parameters have changed
# Check that the params have changed
for n, param in previous_trainable_params.items():
if "lora" in n:
new_param = trainer.model.get_parameter(n)
Expand Down Expand Up @@ -381,7 +381,7 @@ def test_lora_train_and_save(self):
lora_config = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, task_type="CAUSAL_LM")
tokenizer = AutoTokenizer.from_pretrained(model_id)

dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference")
dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference", split="train")

training_args = BCOConfig(
output_dir=self.tmp_dir,
Expand All @@ -393,7 +393,7 @@ def test_lora_train_and_save(self):
model=model,
args=training_args,
processing_class=tokenizer,
train_dataset=dataset["train"],
train_dataset=dataset,
peft_config=lora_config,
)

Expand Down
28 changes: 14 additions & 14 deletions tests/experimental/test_cpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def test_cpo_trainer(self, name, loss_type, config_name):
report_to="none",
)

dummy_dataset = load_dataset("trl-internal-testing/zen", config_name)
dataset = load_dataset("trl-internal-testing/zen", config_name)

if name == "qwen":
model = self.model
Expand All @@ -74,8 +74,8 @@ def test_cpo_trainer(self, name, loss_type, config_name):
model=model,
args=training_args,
processing_class=tokenizer,
train_dataset=dummy_dataset["train"],
eval_dataset=dummy_dataset["test"],
train_dataset=dataset["train"],
eval_dataset=dataset["test"],
)

previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
Expand All @@ -84,7 +84,7 @@ def test_cpo_trainer(self, name, loss_type, config_name):

assert trainer.state.log_history[-1]["train_loss"] is not None

# Check that the parameters have changed
# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
if param.sum() != 0: # ignore 0 biases
Expand Down Expand Up @@ -124,14 +124,14 @@ def test_cpo_trainer_with_lora(self, config_name):
report_to="none",
)

dummy_dataset = load_dataset("trl-internal-testing/zen", config_name)
dataset = load_dataset("trl-internal-testing/zen", config_name)

trainer = CPOTrainer(
model=self.model,
args=training_args,
processing_class=self.tokenizer,
train_dataset=dummy_dataset["train"],
eval_dataset=dummy_dataset["test"],
train_dataset=dataset["train"],
eval_dataset=dataset["test"],
peft_config=lora_config,
)

Expand All @@ -141,15 +141,15 @@ def test_cpo_trainer_with_lora(self, config_name):

assert trainer.state.log_history[-1]["train_loss"] is not None

# Check that the parameters have changed
# Check that the params have changed
for n, param in previous_trainable_params.items():
if "lora" in n:
new_param = trainer.model.get_parameter(n)
if param.sum() != 0: # ignore 0 biases
assert not torch.equal(param, new_param)

def test_compute_metrics(self):
dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference")
dataset = load_dataset("trl-internal-testing/zen", "standard_preference")

def dummy_compute_metrics(*args, **kwargs):
return {"test": 0.0}
Expand All @@ -169,8 +169,8 @@ def dummy_compute_metrics(*args, **kwargs):
model=self.model,
args=training_args,
processing_class=self.tokenizer,
train_dataset=dummy_dataset["train"],
eval_dataset=dummy_dataset["test"],
train_dataset=dataset["train"],
eval_dataset=dataset["test"],
compute_metrics=dummy_compute_metrics,
)

Expand All @@ -194,14 +194,14 @@ def test_alphapo_trainer(self):
report_to="none",
)

dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference")
dataset = load_dataset("trl-internal-testing/zen", "standard_preference")

trainer = CPOTrainer(
model=self.model,
args=training_args,
processing_class=self.tokenizer,
train_dataset=dummy_dataset["train"],
eval_dataset=dummy_dataset["test"],
train_dataset=dataset["train"],
eval_dataset=dataset["test"],
)

previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
Expand Down
4 changes: 2 additions & 2 deletions tests/experimental/test_dppo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def test_topk_tv_requires_topk_inputs(self):
@pytest.mark.low_priority
class TestDPPOTrainer(TrlTestCase):
@pytest.mark.parametrize("divergence_type", ["binary_tv", "binary_kl"])
def test_training_binary(self, divergence_type):
def test_train_binary(self, divergence_type):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

training_args = DPPOConfig(
Expand Down Expand Up @@ -161,7 +161,7 @@ def test_training_binary(self, divergence_type):
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

@pytest.mark.parametrize("config_name", ["standard_prompt_only", "conversational_prompt_only"])
def test_training_conversational(self, config_name):
def test_train_conversational(self, config_name):
dataset = load_dataset("trl-internal-testing/zen", config_name, split="train")

training_args = DPPOConfig(
Expand Down
16 changes: 8 additions & 8 deletions tests/experimental/test_gkd_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,14 +218,14 @@ def test_gkd_trainer(self):
per_device_eval_batch_size=2,
report_to="none",
)
dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling")
dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling")

trainer = GKDTrainer(
model=self.model_id,
teacher_model=self.model_id,
args=training_args,
train_dataset=dummy_dataset["train"],
eval_dataset=dummy_dataset["test"],
train_dataset=dataset["train"],
eval_dataset=dataset["test"],
processing_class=self.tokenizer,
)

Expand All @@ -243,13 +243,13 @@ def test_gkd_trainer_with_liger(self):
report_to="none",
use_liger_kernel=True,
)
dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling")
dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling", split="train")

trainer = GKDTrainer(
model=self.model_id,
teacher_model=self.model_id,
args=training_args,
train_dataset=dummy_dataset["train"],
train_dataset=dataset,
processing_class=self.tokenizer,
)

Expand All @@ -264,14 +264,14 @@ def test_gkd_trainer_with_liger(self):

def test_generation_config_init(self):
training_args = GKDConfig(output_dir=self.tmp_dir)
dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling")
dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling")

trainer = GKDTrainer(
model=self.model_id,
teacher_model=self.model_id,
args=training_args,
train_dataset=dummy_dataset["train"],
eval_dataset=dummy_dataset["test"],
train_dataset=dataset["train"],
eval_dataset=dataset["test"],
processing_class=self.tokenizer,
)

Expand Down
2 changes: 1 addition & 1 deletion tests/experimental/test_grpo_with_replay_buffer_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def test_update_with_inputs_different_seq_len(self):
@pytest.mark.low_priority
@pytest.mark.parametrize("scale_rewards", ["batch", "group"])
class TestGRPOWithReplayBufferTrainer(TrlTestCase):
def test_training_with_replay_buffer(self, scale_rewards):
def test_train_with_replay_buffer(self, scale_rewards):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

# Guarantee that some rewards have 0 std
Expand Down
2 changes: 1 addition & 1 deletion tests/experimental/test_gspo_token_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@


class TestGSPOTokenTrainer(TrlTestCase):
def test_training(self):
def test_train(self):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

training_args = GRPOConfig(
Expand Down
Loading
Loading