
Jacklanchantin/online training #1124


Draft · wants to merge 92 commits into base: online_training

Commits (92)
59171ec
Add Skywork Reward Model
Mar 15, 2025
73f8e2e
reorder
Mar 15, 2025
0b80b35
working checkpoint with skywork
Mar 18, 2025
15ba70c
unify VllmReward class
Mar 18, 2025
861d3ab
cleanup wrap text
Mar 18, 2025
532b5df
actually not quite working. needs dp gather and scatter
Mar 18, 2025
3fad610
add gather scatter for vllm rm
Mar 19, 2025
bae011f
working checkpoint
Mar 19, 2025
1e41be2
working checkpoint
Mar 19, 2025
0222f95
updates
Mar 20, 2025
d486d78
Instantiate metric recorders on rank 0 only (#1072)
cbalioglu Feb 27, 2025
36df0ca
fix gather bug
Mar 25, 2025
76a8a69
cleanup
Mar 25, 2025
21a9d5d
comment
Mar 25, 2025
94a52c8
add grpo (runnable, but not increasing rewards yet
Mar 26, 2025
26d089f
log outputs
Mar 26, 2025
4855903
cleannup
Mar 27, 2025
72ae45b
rename
Apr 1, 2025
cbbf856
merge
jacklanchantin Apr 1, 2025
5405196
merge
jacklanchantin Apr 1, 2025
a1b279c
fixes
jacklanchantin Apr 1, 2025
58f7c58
testing online dpo after merge
jacklanchantin Apr 2, 2025
2d1570c
bug fix
jacklanchantin Apr 2, 2025
c8cfe28
fixing merge errors
jacklanchantin Apr 2, 2025
9129ac6
fix bugs
jacklanchantin Apr 2, 2025
38b4250
Merge branch 'online_training' of github.com:facebookresearch/fairseq…
jacklanchantin Apr 2, 2025
3e9e3f9
merged with online_training
jacklanchantin Apr 2, 2025
f558efc
update grpo
jacklanchantin Apr 2, 2025
3210199
cleanup
jacklanchantin Apr 2, 2025
6d57442
fix grpo bug
jacklanchantin Apr 2, 2025
5114f29
cleanup
jacklanchantin Apr 2, 2025
2190ce5
cleanup
jacklanchantin Apr 2, 2025
96a7087
isort/black
jacklanchantin Apr 2, 2025
2bc9fb0
move vllm generate_rewards to its own function
jacklanchantin Apr 2, 2025
9ed70f9
refactor how reward models use prompt_batch
jacklanchantin Apr 2, 2025
b558aab
remove breakpoint
jacklanchantin Apr 2, 2025
d206466
working chkpt
jacklanchantin Apr 2, 2025
7104683
remove irrelevant skywork stuff
jacklanchantin Apr 3, 2025
1136b4e
online training for wildchat working
jacklanchantin Apr 9, 2025
7ed08c1
new RM
jacklanchantin Apr 11, 2025
aebc7c8
update logging
jacklanchantin Apr 14, 2025
18c15ad
logging
jacklanchantin Apr 14, 2025
ece03e9
new rm
jacklanchantin Apr 16, 2025
c290e5d
changes
jacklanchantin Apr 18, 2025
9e395ef
revert
jacklanchantin Apr 18, 2025
f7065dc
merge
jacklanchantin Apr 18, 2025
8ddfc54
merge
jacklanchantin Apr 18, 2025
1264cfd
merge
jacklanchantin Apr 18, 2025
838a69f
merge
jacklanchantin Apr 18, 2025
8a737fc
sync at step 0
jacklanchantin Apr 18, 2025
ba549a9
Merge branch 'online_training' of github.com:facebookresearch/fairseq…
jacklanchantin Apr 18, 2025
4e26a98
cleanup
jacklanchantin Apr 19, 2025
36c6c10
updates
jacklanchantin Apr 23, 2025
156ca97
update
jacklanchantin Apr 23, 2025
a2515f8
merge
jacklanchantin Apr 24, 2025
0ed0ac0
revert wandb
jacklanchantin Apr 24, 2025
6aed0dd
starting set_data_epoch
jacklanchantin Apr 25, 2025
2603bef
add log_rollouts function in _online_finetune/_common.py
jacklanchantin Apr 25, 2025
ce0ed96
set default False
jacklanchantin Apr 25, 2025
b353d35
sync before validation
jacklanchantin Apr 29, 2025
e299693
add update_batch_metrics
jacklanchantin Apr 29, 2025
d559bcc
add update_batch_metrics to grpo
jacklanchantin Apr 29, 2025
39589fc
update metrics
jacklanchantin Apr 29, 2025
27b244e
update_batch_metrics in validation
jacklanchantin Apr 29, 2025
6a04935
wandb separate train/valid
jacklanchantin Apr 29, 2025
099d1aa
remove comment
jacklanchantin Apr 29, 2025
1e438f7
remove _sync_vllm_valid_model_every_n_steps in grpo
jacklanchantin Apr 29, 2025
44fa30d
merge with log_rollouts
jacklanchantin Apr 29, 2025
7d3ab9d
merge with log_rollouts
jacklanchantin Apr 29, 2025
411d369
merge
jacklanchantin Apr 29, 2025
32f3d18
name change
jacklanchantin Apr 29, 2025
53cf98a
Merge branch 'online_training' of github.com:facebookresearch/fairseq…
jacklanchantin Apr 29, 2025
63a2bcd
merge
jacklanchantin Apr 30, 2025
823223b
revert imports
jacklanchantin Apr 30, 2025
0fb50f1
rever
jacklanchantin Apr 30, 2025
f3835f9
remove sync vllm_valid_model
jacklanchantin Apr 30, 2025
6c94e16
valid temp
jacklanchantin Apr 30, 2025
aaef6f1
Merge branch 'online_training' of github.com:facebookresearch/fairseq…
jacklanchantin Apr 30, 2025
31111af
merge
jacklanchantin May 1, 2025
3083973
remove valid_n
jacklanchantin May 1, 2025
c4cfa73
revert traininer
jacklanchantin May 1, 2025
54f00f5
athene
jacklanchantin May 1, 2025
89e4b23
athene
jacklanchantin May 1, 2025
b6befe7
fix force sync bug
jacklanchantin May 2, 2025
cee96d7
check if self._step_nr exists when syncing
jacklanchantin May 2, 2025
e6c8bd3
merge
jacklanchantin May 2, 2025
dbab616
Merge branch 'online_training' of github.com:facebookresearch/fairseq…
jacklanchantin May 2, 2025
18dace7
change var name
jacklanchantin May 2, 2025
c6fc816
check if self._step_nr exists when syncing (#1160)
jacklanchantin May 2, 2025
7df6271
force_sync bug fix
jacklanchantin May 6, 2025
408ebd7
merge
jacklanchantin May 9, 2025
6e6fa19
merge
jacklanchantin May 21, 2025
11 changes: 11 additions & 0 deletions src/fairseq2/recipes/lm/_online_finetune/_common.py
@@ -1,3 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
@@ -598,3 +598,14 @@
for rollout in rollouts[0].outputs[:num_rollouts]:
rollout_text = rollout.text
log.info(f"{split_name} Rollout: {rollout_text}")


def get_rollout_lengths(rollouts: List[SequenceData]) -> List[int]:
"""Get the lengths of the rollouts."""
rollout_lengths = []
for rollout in rollouts:
for sample in rollout.outputs:
token_ids = sample.token_ids
token_ids_len = len(token_ids)
rollout_lengths.append(token_ids_len)
return rollout_lengths
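
Illustration only, not part of the diff: a minimal sketch of the input shape get_rollout_lengths expects. SimpleNamespace objects stand in for the vLLM-style rollout outputs, which carry an .outputs list whose entries expose .token_ids.

from types import SimpleNamespace

# One prompt's rollouts: two sampled completions of 3 and 5 tokens.
rollouts = [
    SimpleNamespace(outputs=[
        SimpleNamespace(token_ids=[11, 42, 7]),
        SimpleNamespace(token_ids=[5, 9, 13, 2, 88]),
    ])
]

assert get_rollout_lengths(rollouts) == [3, 5]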
100 changes: 100 additions & 0 deletions src/fairseq2/recipes/lm/_online_finetune/_diversity_metrics.py
@@ -0,0 +1,100 @@
import string as string_lib
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

Check failure on line 2 (GitHub Actions / Lint Python / Lint): Cannot find implementation or library stub for module named "nltk.translate.bleu_score"
import gzip
import torch


def get_compression_ratio(strings):

Check failure on line 7 (GitHub Actions / Lint Python / Lint): Function is missing a type annotation

flattened_generation = " ".join(strings)
original_byte_size = len(bytes(flattened_generation, "UTF-8"))
compressed_bytes_size = len(gzip.compress(bytes(flattened_generation, "UTF-8")))

cr = compressed_bytes_size / original_byte_size
cr_tensor = torch.Tensor([cr])
return cr_tensor
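
# Illustration only (not part of the diff): a quick check of the intuition
# that near-duplicate rollouts compress far better than varied ones, so a
# low ratio flags mode collapse. The sample strings are invented; note that
# gzip's fixed header inflates the ratio on very short inputs.
_repetitive = ["the model repeats this exact sentence"] * 50
_varied = [
    "gzip exploits redundancy across rollouts",
    "unrelated sentences share few byte patterns",
    "low ratios are a warning sign of mode collapse",
]
assert get_compression_ratio(_repetitive) < get_compression_ratio(_varied)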


def get_self_bleu_score(strings):
# Create a translation table to remove punctuation
translator = str.maketrans("", "", string_lib.punctuation)

# Preprocess the strings: convert to lowercase and remove punctuation
cleaned_strings = [s.lower().translate(translator) for s in strings]

# Tokenize the cleaned strings into lists of words
tokenized_strings = [s.split() for s in cleaned_strings]

# Initialize a list to store pairwise BLEU scores
bleu_scores = []

# Calculate BLEU scores for all pairs of strings
for i in range(len(tokenized_strings)):
for j in range(i + 1, len(tokenized_strings)):
# Use smoothing to handle cases where there are no n-grams in common
smoothie = SmoothingFunction().method4
bleu = sentence_bleu(
[tokenized_strings[i]],
tokenized_strings[j],
smoothing_function=smoothie,
)

# Store the BLEU score
bleu_scores.append(bleu)

mean_bleu_score = sum(bleu_scores) / len(bleu_scores) if bleu_scores else 0.0
mean_bleu_score_tensor = torch.Tensor([mean_bleu_score])
return mean_bleu_score_tensor
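
# Usage sketch (illustration only; example strings invented, nltk required).
# Self-BLEU averages BLEU over all unordered pairs of rollouts: after
# lowercasing and punctuation stripping, the three near-duplicates below
# collapse to the same tokens, so the mean pairwise BLEU is 1.0, while
# distinct rollouts share few n-grams and score far lower.
_near_duplicates = ["The answer is 42.", "the answer is 42!", "THE ANSWER IS 42"]
print(get_self_bleu_score(_near_duplicates))  # tensor([1.])
_distinct = [
    "Photosynthesis converts light into chemical energy.",
    "A hash map gives expected constant-time lookups.",
    "The answer is 42.",
]
print(get_self_bleu_score(_distinct))  # well below 1.0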


def get_unique_1grams(strings):

# Initialize an empty set to store unique 1-grams
unique_words = set()
total_words = 0

# Create a translation table to remove punctuation
translator = str.maketrans("", "", string_lib.punctuation)

# Iterate over each string in the list
for string in strings:
# Convert the string to lowercase and remove punctuation
cleaned_string = string.lower().translate(translator)

# Split the cleaned string into words (1-grams) and update the set
words = cleaned_string.split()
total_words += len(words)
unique_words.update(words)

# Return the number of unique 1-grams and the type-token ratio
num_unique_1grams = len(unique_words)
num_unique_1grams_norm = len(unique_words) / total_words if total_words > 0 else 0
num_unique_1grams_tensor = torch.Tensor([num_unique_1grams])
num_unique_1grams_norm = torch.Tensor([num_unique_1grams_norm])
return num_unique_1grams_tensor, num_unique_1grams_norm
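
# Worked example (illustration only, invented strings): the first return
# value counts distinct words after cleaning; the second is the type-token
# ratio, i.e. unique words over total words.
_texts = ["The cat sat.", "The cat sat.", "A dog ran!"]
_unique_count, _unique_frac = get_unique_1grams(_texts)
print(_unique_count)  # tensor([6.])  (the, cat, sat, a, dog, ran)
print(_unique_frac)   # tensor([0.6667]), i.e. 6 unique / 9 total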


def extract_logprobs(data):
logprobs = []
for item in data:
for key, logprob in item.items():
logprobs.append(logprob.logprob)
return logprobs


def get_entropy(rollouts):
batch_sum_logprobs = []
batch_sum_logprobs_per_tok = []
for rollout_idx in range(len(rollouts[0].outputs)):
logprobs = extract_logprobs(rollouts[0].outputs[rollout_idx].logprobs)

sum_logprobs = -sum(logprobs)
sum_logprobs_per_tok = -sum(logprobs) / len(logprobs)

batch_sum_logprobs.append(sum_logprobs)
batch_sum_logprobs_per_tok.append(sum_logprobs_per_tok)

entropy = sum(batch_sum_logprobs) / len(batch_sum_logprobs)
entropy_norm = sum(batch_sum_logprobs_per_tok) / len(batch_sum_logprobs_per_tok)

return entropy, entropy_norm
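
Note that get_entropy averages negative sequence log-probabilities over the first prompt's rollouts, an entropy proxy rather than a true entropy estimate. A minimal sketch (illustration only) with SimpleNamespace standing in for the vLLM logprob records, where each per-token entry maps a token id to an object with a .logprob field:

from types import SimpleNamespace

def lp(value):
    # One per-token logprob record: {token_id: <object with .logprob>}.
    return {0: SimpleNamespace(logprob=value)}

rollouts = [SimpleNamespace(outputs=[
    SimpleNamespace(logprobs=[lp(-0.5), lp(-1.0)]),  # -sum = 1.5, per token 0.75
    SimpleNamespace(logprobs=[lp(-2.0)]),            # -sum = 2.0, per token 2.0
])]

entropy, entropy_norm = get_entropy(rollouts)
print(entropy)       # (1.5 + 2.0) / 2 = 1.75
print(entropy_norm)  # (0.75 + 2.0) / 2 = 1.375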
23 changes: 23 additions & 0 deletions src/fairseq2/recipes/lm/_online_finetune/_grpo.py
@@ -1,3 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
@@ -22,6 +22,7 @@
from fairseq2.recipes.lm._online_finetune._common import (
compute_token_level_entropy,
log_rollouts,
get_rollout_lengths,
)

from fairseq2.context import RuntimeContext
@@ -172,6 +173,14 @@
log_rollouts(prompt_batch, rollouts, "Valid")
reward_output = self._reward.process_rollouts(rollouts, prompt_batch)
avg_reward = torch.tensor(reward_output["rewards"]).float().mean()

rollout_lengths = get_rollout_lengths(rollouts)
avg_rollout_length = torch.tensor(rollout_lengths).float().mean()
avg_reward_len_norm = avg_reward / avg_rollout_length

self._metric_bag.update_avg_rollout_length(avg_rollout_length)
self._metric_bag.update_avg_reward_len_norm(avg_reward_len_norm)

self._metric_bag.update_avg_reward(avg_reward)
self._metric_bag.update_batch_metrics(prompt_batch)
# returning dummy loss since trainer expects it
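
For intuition (numbers invented for illustration): avg_reward_len_norm divides the mean reward by the mean rollout length in tokens, which helps separate genuine reward gains from gains that come only from generating longer rollouts:

import torch

rewards = [0.9, 0.7, 0.8, 0.6]          # one scalar reward per rollout
rollout_lengths = [220, 180, 260, 140]  # token counts from get_rollout_lengths

avg_reward = torch.tensor(rewards).float().mean()                  # 0.75
avg_rollout_length = torch.tensor(rollout_lengths).float().mean()  # 200.0
avg_reward_len_norm = avg_reward / avg_rollout_length              # 0.00375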
@@ -386,6 +395,12 @@
)
self.register_metric("grpo_loss", Mean(device=gang.device), persistent=False)
self.register_metric("avg_reward", Mean(device=gang.device), persistent=False)
self.register_metric(
"avg_rollout_length", Mean(device=gang.device), persistent=False
)
self.register_metric(
"avg_reward_len_norm", Mean(device=gang.device), persistent=False
)
self.register_metric(
"logit_entropy", Mean(device=gang.device), persistent=False
)
@@ -426,6 +441,14 @@
def update_avg_reward(self, avg_reward):
self.avg_reward.update(avg_reward, weight=1)

@torch.inference_mode()
def update_avg_rollout_length(self, avg_rollout_length):
self.avg_rollout_length.update(avg_rollout_length, weight=1)

@torch.inference_mode()
def update_avg_reward_len_norm(self, avg_reward_len_norm):
self.avg_reward_len_norm.update(avg_reward_len_norm, weight=1)

@torch.inference_mode()
def update_batch_metrics(self, batch: PreferenceBatch):
num_examples = batch.batch_size