Jacklanchantin/online training #1124

Draft: wants to merge 57 commits into base branch online_training
Commits (57)
59171ec
Add Skywork Reward Model
Mar 15, 2025
73f8e2e
reorder
Mar 15, 2025
0b80b35
working checkpoint with skywork
Mar 18, 2025
15ba70c
unify VllmReward class
Mar 18, 2025
861d3ab
cleanup wrap text
Mar 18, 2025
532b5df
actually not quite working. needs dp gather and scatter
Mar 18, 2025
3fad610
add gather scatter for vllm rm
Mar 19, 2025
bae011f
working checkpoint
Mar 19, 2025
1e41be2
working checkpoint
Mar 19, 2025
0222f95
updates
Mar 20, 2025
d486d78
Instantiate metric recorders on rank 0 only (#1072)
cbalioglu Feb 27, 2025
36df0ca
fix gather bug
Mar 25, 2025
76a8a69
cleanup
Mar 25, 2025
21a9d5d
comment
Mar 25, 2025
94a52c8
add grpo (runnable, but not increasing rewards yet)
Mar 26, 2025
26d089f
log outputs
Mar 26, 2025
4855903
cleanup
Mar 27, 2025
72ae45b
rename
Apr 1, 2025
cbbf856
merge
jacklanchantin Apr 1, 2025
5405196
merge
jacklanchantin Apr 1, 2025
a1b279c
fixes
jacklanchantin Apr 1, 2025
58f7c58
testing online dpo after merge
jacklanchantin Apr 2, 2025
2d1570c
bug fix
jacklanchantin Apr 2, 2025
c8cfe28
fixing merge errors
jacklanchantin Apr 2, 2025
9129ac6
fix bugs
jacklanchantin Apr 2, 2025
38b4250
Merge branch 'online_training' of github.com:facebookresearch/fairseq…
jacklanchantin Apr 2, 2025
3e9e3f9
merged with online_training
jacklanchantin Apr 2, 2025
f558efc
update grpo
jacklanchantin Apr 2, 2025
3210199
cleanup
jacklanchantin Apr 2, 2025
6d57442
fix grpo bug
jacklanchantin Apr 2, 2025
5114f29
cleanup
jacklanchantin Apr 2, 2025
2190ce5
cleanup
jacklanchantin Apr 2, 2025
96a7087
isort/black
jacklanchantin Apr 2, 2025
2bc9fb0
move vllm generate_rewards to its own function
jacklanchantin Apr 2, 2025
9ed70f9
refactor how reward models use prompt_batch
jacklanchantin Apr 2, 2025
b558aab
remove breakpoint
jacklanchantin Apr 2, 2025
d206466
working chkpt
jacklanchantin Apr 2, 2025
7104683
remove irrelevant skywork stuff
jacklanchantin Apr 3, 2025
1136b4e
online training for wildchat working
jacklanchantin Apr 9, 2025
7ed08c1
new RM
jacklanchantin Apr 11, 2025
aebc7c8
update logging
jacklanchantin Apr 14, 2025
18c15ad
logging
jacklanchantin Apr 14, 2025
ece03e9
new rm
jacklanchantin Apr 16, 2025
c290e5d
changes
jacklanchantin Apr 18, 2025
9e395ef
revert
jacklanchantin Apr 18, 2025
f7065dc
merge
jacklanchantin Apr 18, 2025
8ddfc54
merge
jacklanchantin Apr 18, 2025
1264cfd
merge
jacklanchantin Apr 18, 2025
838a69f
merge
jacklanchantin Apr 18, 2025
8a737fc
sync at step 0
jacklanchantin Apr 18, 2025
ba549a9
Merge branch 'online_training' of github.com:facebookresearch/fairseq…
jacklanchantin Apr 18, 2025
4e26a98
cleanup
jacklanchantin Apr 19, 2025
36c6c10
updates
jacklanchantin Apr 23, 2025
156ca97
update
jacklanchantin Apr 23, 2025
a2515f8
merge
jacklanchantin Apr 24, 2025
0ed0ac0
revert wandb
jacklanchantin Apr 24, 2025
6aed0dd
starting set_data_epoch
jacklanchantin Apr 25, 2025
19 changes: 17 additions & 2 deletions src/fairseq2/metrics/recorders/_wandb.py
@@ -45,6 +45,7 @@ def __init__(
name: str,
output_dir: Path,
metric_descriptors: Provider[MetricDescriptor],
id: str | None = None,
) -> None:
"""
:param project: The W&B project name.
@@ -60,7 +61,7 @@ def __init__(
self._run = None
else:
self._run = wandb.init(
project=project, name=name, dir=output_dir.parent, resume="allow"
project=project, name=name, id=id, dir=output_dir.parent, resume="allow"
)

self._metric_descriptors = metric_descriptors
@@ -77,6 +78,12 @@ def record_metrics(
if self._run is None:
return

# try:
# self._run.log({"_step": step_nr}) # Log to the specific step
# self._run.step = step_nr # Directly update the internal step counter
# except:
# ...

for name, value in values.items():
try:
descriptor = self._metric_descriptors.get(name)
@@ -88,6 +95,8 @@
else:
display_name = descriptor.display_name

display_name = run + "/" + display_name

try:
self._run.log({display_name: value}, step=step_nr)
except RuntimeError as ex:
@@ -112,6 +121,8 @@ class WandbRecorderConfig:

run: str | None = None

id: str | None = None

def validate(self) -> None:
result = ValidationResult()

@@ -151,7 +162,11 @@ def create(self, output_dir: Path, config: object) -> MetricRecorder:
wandb_dir = output_dir.joinpath("wandb")

return WandbRecorder(
config.project, config.run, wandb_dir, self._metric_descriptors
config.project,
config.run,
wandb_dir,
self._metric_descriptors,
id=config.id,
)

@property
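The practical effect of the new id plumbing is that a restarted job can keep logging to the same W&B run. A minimal standalone sketch of the wandb.init behavior this diff relies on; the project, name, id, and dir values here are placeholders, not values from this PR:

import wandb

# With a fixed id and resume="allow", wandb reattaches to the existing run
# if it exists and creates it otherwise, so metrics from a restarted job
# continue on the same curves.
run = wandb.init(
    project="online_training",   # placeholder project name
    name="grpo-wildchat",        # placeholder run name
    id="abc123xy",               # placeholder run id (now exposed via WandbRecorderConfig.id)
    dir="/tmp/wandb",            # placeholder directory
    resume="allow",
)
run.log({"reward/mean": 0.0}, step=1)
run.finish()
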
1 change: 1 addition & 0 deletions src/fairseq2/recipes/common/_trainer.py
@@ -96,6 +96,7 @@ def create_trainer(
valid_data_readers=valid_data_readers,
validate_after_n_steps=regime_section.validate_after_n_steps,
validate_every_n_steps=regime_section.validate_every_n_steps,
validate_step_0=regime_section.validate_step_0,
validate_after_n_data_epochs=regime_section.validate_after_n_data_epochs,
validate_every_n_data_epochs=regime_section.validate_every_n_data_epochs,
checkpoint_manager=checkpoint_manager,
3 changes: 3 additions & 0 deletions src/fairseq2/recipes/config.py
@@ -219,6 +219,9 @@ class RegimeSection:
validate_every_n_steps: int | None = None
"""The step interval at which to validate the model."""

validate_step_0: bool = False
"""Validate before training"""

validate_after_n_data_epochs: int = 0

validate_every_n_data_epochs: int | None = None
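With the new flag, a recipe can request a validation pass before any optimizer step. A sketch assuming RegimeSection is a dataclass that can be constructed directly with keyword arguments; the interval value is a placeholder:

from fairseq2.recipes.config import RegimeSection

# Placeholder values; validate_step_0=True asks the trainer to validate
# once at step 0, before the first training step.
regime = RegimeSection(
    validate_every_n_steps=500,
    validate_step_0=True,
)
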
8 changes: 8 additions & 0 deletions src/fairseq2/recipes/lm/__init__.py
@@ -1,3 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
@@ -161,6 +161,14 @@
from fairseq2.recipes.lm._online_finetune._rewards import (
SkyworkVerifierHandler as SkyworkVerifierHandler,
)

from fairseq2.recipes.lm._online_finetune._rewards import (
AtheneVerifier as AtheneVerifier,
)
from fairseq2.recipes.lm._online_finetune._rewards import (
AtheneVerifierHandler as AtheneVerifierHandler,
)

from fairseq2.recipes.lm._online_finetune._rewards import (
NuminaMathVerifier as NuminaMathVerifier,
)
100 changes: 100 additions & 0 deletions src/fairseq2/recipes/lm/_online_finetune/_diversity_metrics.py
@@ -0,0 +1,100 @@
import string as string_lib
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

[GitHub Actions lint failure, line 2: Cannot find implementation or library stub for module named "nltk.translate.bleu_score"]

import gzip
import torch


def get_compression_ratio(strings):

[GitHub Actions lint failure, line 7: Function is missing a type annotation]

    flattened_generation = " ".join(strings)
    original_byte_size = len(bytes(flattened_generation, "UTF-8"))
    compressed_bytes_size = len(gzip.compress(bytes(flattened_generation, "UTF-8")))

    cr = compressed_bytes_size / original_byte_size
    cr_tensor = torch.Tensor([cr])
    return cr_tensor


def get_self_bleu_score(strings):
    # Create a translation table to remove punctuation
    translator = str.maketrans("", "", string_lib.punctuation)

    # Preprocess the strings: convert to lowercase and remove punctuation
    cleaned_strings = [s.lower().translate(translator) for s in strings]

    # Tokenize the cleaned strings into lists of words
    tokenized_strings = [s.split() for s in cleaned_strings]

    # Initialize a list to store BLEU scores
    bleu_scores = []

    # Calculate BLEU scores for all pairs of strings
    for i in range(len(tokenized_strings)):
        for j in range(i + 1, len(tokenized_strings)):
            # Use smoothing to handle cases where there are no n-grams in common
            smoothie = SmoothingFunction().method4
            bleu = sentence_bleu(
                [tokenized_strings[i]],
                tokenized_strings[j],
                smoothing_function=smoothie,
            )

            # Store the BLEU score
            bleu_scores.append(bleu)

    mean_bleu_score = sum(bleu_scores) / len(bleu_scores) if bleu_scores else 0.0
    mean_bleu_score_tensor = torch.Tensor([mean_bleu_score])
    return mean_bleu_score_tensor


def get_unique_1grams(strings):
    # Initialize an empty set to store unique 1-grams
    unique_words = set()
    total_words = 0

    # Create a translation table to remove punctuation
    translator = str.maketrans("", "", string_lib.punctuation)

    # Iterate over each string in the list
    for string in strings:
        # Convert the string to lowercase and remove punctuation
        cleaned_string = string.lower().translate(translator)

        # Split the cleaned string into words (1-grams) and update the set
        words = cleaned_string.split()
        total_words += len(words)
        unique_words.update(words)

    # Return the number of unique 1-grams, raw and normalized by total word count
    num_unique_1grams = len(unique_words)
    num_unique_1grams_norm = len(unique_words) / total_words if total_words > 0 else 0
    num_unique_1grams_tensor = torch.Tensor([num_unique_1grams])
    num_unique_1grams_norm = torch.Tensor([num_unique_1grams_norm])
    return num_unique_1grams_tensor, num_unique_1grams_norm


def extract_logprobs(data):
    logprobs = []
    for item in data:
        for key, logprob in item.items():
            logprobs.append(logprob.logprob)
    return logprobs


def get_entropy(rollouts):
    batch_sum_logprobs = []
    batch_sum_logprobs_per_tok = []
    for rollout_idx in range(len(rollouts[0].outputs)):
        logprobs = extract_logprobs(rollouts[0].outputs[rollout_idx].logprobs)

        sum_logprobs = -sum(logprobs)
        sum_logprobs_per_tok = -sum(logprobs) / len(logprobs)

        batch_sum_logprobs.append(sum_logprobs)
        batch_sum_logprobs_per_tok.append(sum_logprobs_per_tok)

    entropy = sum(batch_sum_logprobs) / len(batch_sum_logprobs)
    entropy_norm = sum(batch_sum_logprobs_per_tok) / len(batch_sum_logprobs_per_tok)

    return entropy, entropy_norm
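These helpers are meant to be computed over a batch of rollouts sampled for the same prompt. A minimal sketch of how the string-level metrics might be called; the sample strings are placeholders, and get_entropy is omitted because it expects vLLM-style rollout objects carrying per-token logprobs:

from fairseq2.recipes.lm._online_finetune._diversity_metrics import (
    get_compression_ratio,
    get_self_bleu_score,
    get_unique_1grams,
)

# Placeholder rollouts sampled for a single prompt.
samples = [
    "The cat sat on the mat.",
    "A cat is sitting on the mat.",
    "The weather is sunny today.",
]

compression_ratio = get_compression_ratio(samples)  # lower ratio -> more repetitive text
self_bleu = get_self_bleu_score(samples)            # higher score -> samples resemble each other
unique_1grams, unique_1grams_norm = get_unique_1grams(samples)  # vocabulary diversity

print(compression_ratio.item(), self_bleu.item(), unique_1grams_norm.item())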