Skip to content

Jacklanchantin/normalized rewards #1161

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 44 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
4c9fb9e
online training components wip
Feb 22, 2025
6995cb9
wip
Feb 23, 2025
1a4f79a
Merge branch 'main' into online_training
Feb 23, 2025
8ca6d71
wip
Feb 23, 2025
d776eb5
Merge branch 'main' into online_training
Feb 23, 2025
d9f3bf0
wip changes
Feb 28, 2025
bcf05f4
wip
Feb 28, 2025
e915733
more
Mar 5, 2025
5e9dd9a
changes
Mar 10, 2025
9a39fca
changes
Mar 11, 2025
18f609d
changes
Mar 11, 2025
211dc92
optional weight update group since we dont need it for stationary mod…
Mar 11, 2025
e2569a8
changes
Mar 11, 2025
bf901b3
switch gsm8k rewards back to 0 and 1
Mar 12, 2025
500e28a
change batching logic in grpo loss
Mar 13, 2025
62d6c8e
metric formatters added for grpo
Mar 13, 2025
5bb5c06
more changes to math utils and rewards
Mar 18, 2025
d94ce26
allow original text in the example if needed
Mar 19, 2025
3bbbe68
microbatching wip
Mar 24, 2025
0c3f69f
microbatching and validation implemented
Mar 25, 2025
13e49d3
safetensor loading fix
Mar 26, 2025
33318b9
validation during training, reference model offloading, global actor…
Mar 28, 2025
3229a40
online dpo typo and small fixes
Apr 1, 2025
42f6ee5
formatting
Apr 1, 2025
7bd4c0d
_add_ directive removed
Apr 1, 2025
20e3ede
hf tokenizer support added
Apr 2, 2025
b4537ec
default fix
Apr 2, 2025
d8e45ec
Jacklanchantin/online training (#1101)
jacklanchantin Apr 3, 2025
e3996aa
grpo preset added
Apr 5, 2025
cdb4c24
type fix to allow none value
Apr 10, 2025
2cb30f8
flexible sampling params
Apr 17, 2025
163f776
logit entropy track
Apr 18, 2025
caa6541
force v0 engine with newer vllm versions
Apr 18, 2025
d1db955
make validation unit separately so that total examples counter is not…
Apr 24, 2025
6b3284e
Record rollouts in logs, update num examples, separate train/valid in…
jacklanchantin Apr 29, 2025
dd4ed22
validate before training (#1153)
jacklanchantin Apr 29, 2025
5fb86ab
Add athene rm (#1154)
jacklanchantin Apr 30, 2025
dd03a36
math-verify verifier added
Apr 30, 2025
cf5703b
loss zeroer log, use prompt batch size in loss normalizer
May 1, 2025
128dd3d
GRPO len norm support
May 1, 2025
0610d4e
Add validation sampling params & force sync when starting train unit …
jacklanchantin May 1, 2025
6ec348d
check if self._step_nr exists when syncing (#1160)
jacklanchantin May 2, 2025
18dace7
change var name
jacklanchantin May 2, 2025
c6fc816
check if self._step_nr exists when syncing (#1160)
jacklanchantin May 2, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions src/fairseq2/cli/_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,13 @@
InstructionFinetuneConfig,
LMLossEvalConfig,
POFinetuneConfig,
OnlineFinetuneConfig,
TextGenerateConfig,
load_instruction_finetuner,
load_lm_loss_evaluator,
load_po_finetuner,
load_text_generator,
load_online_finetuner,
)
from fairseq2.recipes.mt import (
MTEvalConfig,
Expand Down Expand Up @@ -235,6 +237,19 @@ def _register_lm_cli(cli: Cli) -> None:
help="generate text",
)

# Online Finetune
online_finetune_handler = RecipeCommandHandler(
loader=load_online_finetuner,
config_kls=OnlineFinetuneConfig,
default_preset="llama3_1_instruct",
)

group.add_command(
name="online_finetune",
handler=online_finetune_handler,
help="online-finetune a language model.",
)


def _register_mt_cli(cli: Cli) -> None:
extra_sweep_keys = {"source_lang", "target_lang"}
Expand Down
144 changes: 144 additions & 0 deletions src/fairseq2/data/text/tokenizers/huggingface_tokenizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from collections.abc import Sequence
from pathlib import Path
from typing import final

import torch
from torch import Tensor
from typing_extensions import override

from fairseq2.data import VocabularyInfo
from fairseq2.data.text.tokenizers import (
TextTokenDecoder,
TextTokenEncoder,
)
from fairseq2.typing import Device
from transformers import AutoTokenizer


@final
class HuggingfaceTokenizerEncoder(TextTokenEncoder):
    """Encodes text into token indices using a Hugging Face tokenizer."""

    _tokenizer: AutoTokenizer
    _prefix_indices: list[int]
    _suffix_indices: list[int]
    _prefix_index_tensor: Tensor | None
    _suffix_index_tensor: Tensor | None
    _device: Device | None
    _pin_memory: bool

    def __init__(
        self,
        tokenizer: AutoTokenizer,
        *,
        prefix_tokens: Sequence[str] | None = None,
        suffix_tokens: Sequence[str] | None = None,
        device: Device | None = None,
        pin_memory: bool = False,
    ) -> None:
        """
        :param tokenizer:
            The Hugging Face :class:`AutoTokenizer` object.
        :param prefix_tokens:
            The prefix tokens to prepend to the encoded text.
        :param suffix_tokens:
            The suffix tokens to append to the encoded text.
        :param device:
            The device on which to construct tensors.
        :param pin_memory:
            If ``True``, uses pinned memory while constructing tensors.
        """
        self._tokenizer = tokenizer

        # Prefix
        if prefix_tokens:
            self._prefix_indices = self._tokenizer.convert_tokens_to_ids(prefix_tokens)

            self._prefix_index_tensor = torch.tensor(
                self._prefix_indices, dtype=torch.int64, device=device
            )
        else:
            self._prefix_indices = []

            self._prefix_index_tensor = None

        # Suffix
        if suffix_tokens:
            self._suffix_indices = self._tokenizer.convert_tokens_to_ids(suffix_tokens)

            self._suffix_index_tensor = torch.tensor(
                self._suffix_indices, dtype=torch.int64, device=device
            )
        else:
            self._suffix_indices = []

            self._suffix_index_tensor = None

        self._device = device
        self._pin_memory = pin_memory

    @override
    def __call__(self, text: str) -> Tensor:
        # fairseq2 tokenizers add special tokens on their own, so disable the
        # Hugging Face tokenizer's automatic special-token insertion.
        indices = self._tokenizer.encode(text, add_special_tokens=False)

        if self._prefix_indices:
            indices = self._prefix_indices + indices

        if self._suffix_indices:
            indices.extend(self._suffix_indices)

        return torch.tensor(
            indices, dtype=torch.int64, device=self._device, pin_memory=self._pin_memory
        )

    @override
    def encode_as_tokens(self, text: str) -> list[str]:
        """Encode ``text`` and return the corresponding token strings."""
        indices = self(text).tolist()

        # Fix: the original called the non-existent `convert_tds_to_tokens`,
        # which would raise `AttributeError`; the correct Hugging Face API is
        # `convert_ids_to_tokens`.
        tokens = self._tokenizer.convert_ids_to_tokens(indices)

        return tokens

    @property
    @override
    def prefix_indices(self) -> Tensor | None:
        return self._prefix_index_tensor

    @property
    @override
    def suffix_indices(self) -> Tensor | None:
        return self._suffix_index_tensor


@final
class HuggingfaceTokenizerDecoder(TextTokenDecoder):
    """Decodes token indices back into text using a Hugging Face tokenizer."""

    _tokenizer: AutoTokenizer

    def __init__(self, tokenizer: AutoTokenizer) -> None:
        """
        :param tokenizer:
            The Hugging Face :class:`AutoTokenizer` object.
        """
        self._tokenizer = tokenizer

    @override
    def __call__(self, token_indices: Tensor) -> str:
        """Decode a one-dimensional tensor of token indices into text.

        :raises ValueError:
            If ``token_indices`` is not one dimensional.
        """
        if token_indices.dim() != 1:
            raise ValueError(
                f"`token_indices` must be one dimensional, but has {token_indices.dim()} dimensions instead."
            )

        return self._tokenizer.decode(token_indices)

    @override
    def decode_from_tokens(self, tokens: Sequence[str]) -> str:
        """Decode a sequence of token strings into text."""
        indices = self._tokenizer.convert_tokens_to_ids(tokens)

        return self._tokenizer.decode(indices)
111 changes: 111 additions & 0 deletions src/fairseq2/data/text/tokenizers/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,97 @@
TiktokenEncoder,
TiktokenModel,
)
from fairseq2.data.text.tokenizers.huggingface_tokenizer import (
HuggingfaceTokenizerEncoder,
HuggingfaceTokenizerDecoder,
)
from fairseq2.typing import Device
from transformers import AutoTokenizer


@final
class LLaMA3TokenizerHuggingFace(TextTokenizer):
"""Represents a HuggingFace version of LLama 3 tokenizer"""

_tokenizer: AutoTokenizer
_bos_token: str
_eos_token: str

def __init__(self, path: Path) -> None:

self._tokenizer = AutoTokenizer.from_pretrained(path)

self._eos_token = self._tokenizer.special_tokens_map["eos_token"]
self._bos_token = self._tokenizer.special_tokens_map["bos_token"]

@override
def create_encoder(
self,
*,
task: str | None = None,
lang: str | None = None,
mode: str | None = None,
device: Device | None = None,
pin_memory: bool = False,
) -> TiktokenEncoder:
if task is not None:
raise ValueError(f"`task` must be `None`, but is '{task}' instead.")

if lang is not None:
raise ValueError(f"`lang` must be `None`, but is '{lang}' instead.")

match mode:
case None | "default":
prefix_tokens = [self._bos_token]
suffix_tokens = [self._eos_token]
case "prompt":
prefix_tokens = [self._bos_token]
# In prompt mode, we expect the generator to finish the sequence.
suffix_tokens = []
case "prompt_response":
prefix_tokens = []
suffix_tokens = [self._eos_token]
case "as_is":
prefix_tokens = []
suffix_tokens = []
case _:
raise ValueError(
f"`mode` must be one of the following values, but is '{mode}' instead: default, prompt, prompt_response, as_is"
)

return HuggingfaceTokenizerEncoder(
self._tokenizer,
prefix_tokens=prefix_tokens,
suffix_tokens=suffix_tokens,
device=device,
pin_memory=pin_memory,
)

@override
def create_raw_encoder(
self, *, device: Device | None = None, pin_memory: bool = False
) -> TiktokenEncoder:
return HuggingfaceTokenizerEncoder(
self._tokenizer, device=device, pin_memory=pin_memory
)

@override
def create_decoder(self) -> TiktokenDecoder:
return HuggingfaceTokenizerDecoder(self._model)

@property
@override
def vocab_info(self) -> VocabularyInfo:
bos_idx = self._tokenizer.convert_tokens_to_ids(self._bos_token)
eos_idx = self._tokenizer.convert_tokens_to_ids(self._eos_token)
vocab_info = VocabularyInfo(
size=len(self._tokenizer),
bos_idx=bos_idx,
eos_idx=eos_idx,
unk_idx=None,
pad_idx=None,
)
return vocab_info


@final
Expand Down Expand Up @@ -139,6 +229,27 @@ def vocab_info(self) -> VocabularyInfo:


def load_llama_tokenizer(path: Path, card: AssetCard) -> TextTokenizer:

# first check if this is HuggingFace tokenizer
try:
use_hf = card.field("use_hf_tokenizer").as_(bool)
except AssetCardFieldNotFoundError:
use_hf = False
except AssetCardError as ex:
raise text_tokenizer_asset_card_error(card.name) from ex

if use_hf:
try:
return LLaMA3TokenizerHuggingFace(path)
except ValueError as ex:
raise TextTokenizerLoadError(
card.name, f"The '{card.name}' asset card does not contain a valid text tokenizer configuration of the '{LLAMA_TOKENIZER_FAMILY}' family. See the nested exception for details." # fmt: skip
) from ex
except RuntimeError as ex:
raise TextTokenizerLoadError(
card.name, f"The '{card.name}' text tokenizer cannot be loaded. See the nested exception for details." # fmt: skip
) from ex

try:
use_v2 = card.field("use_v2_tokenizer").as_(bool)
except AssetCardFieldNotFoundError:
Expand Down
Loading