Add loss like Rho-1 #260

Open · wants to merge 4 commits into main
42 changes: 42 additions & 0 deletions open_lm/main.py
@@ -1,4 +1,5 @@
import atexit
import copy
import logging
import os
import re
@@ -599,6 +600,46 @@ def main(args):
    # Put the shard shuffle seed back into args (this is done for compatibility with older, non shuffling versions)
    args.shard_shuffle_seed = shard_shuffle_seed

    ref_model = None
    if args.reference_model is not None:
        # TODO: currently assumes that the reference model is the same architecture as the training model.
        with torch.device("meta" if args.experimental_meta_device and args.fsdp else args.device):
            ref_model = create_model(args)

        if args.distributed:
            if args.fsdp:
                if args.rank == 0:
                    print(f"Before FSDP parameter num: {sum(p.numel() for p in ref_model.parameters()):,}")
                    print(f"Before FSDP {torch.cuda.memory_allocated()/1024**3:.3} GB")

                # Reuse FSDP parameters from above.
                ref_model = FSDP(
                    ref_model,
                    auto_wrap_policy=transformer_auto_wrapper_policy,
                    device_id=device,
                    mixed_precision=mp_policy,
                    cpu_offload=CPUOffload(offload_params=args.fsdp_cpu_offload),
                    use_orig_params=args.fsdp_use_orig_params,
                    limit_all_gathers=args.fsdp_limit_all_gathers,
                    **fsdp_kwargs,
                )

                print(f"After FSDP parameter num: {sum(p.numel() for p in ref_model.parameters()):,} on rank {args.rank}")
                print(f"After FSDP {torch.cuda.memory_allocated()/1024**3:.3} GB on rank {args.rank}")
            else:
                ddp_args = {}
                if args.ddp_static_graph:
                    # this doesn't exist in older PyTorch, arg only added if enabled
                    ddp_args["static_graph"] = True
                ref_model = torch.nn.parallel.DistributedDataParallel(ref_model, device_ids=[device], **ddp_args)

        temp_args = copy.deepcopy(args)
        temp_args.resume = args.reference_model
        _ = load_model(temp_args, ref_model, different_seed=True)
        for p in ref_model.parameters():
            p.requires_grad = False


    if requires_training and global_step is None:
        raise ValueError("Key 'step' not found in checkpoint, but required for training.")

@@ -809,6 +850,7 @@ def main(args):
            total_steps=total_steps,
            args=args,
            tb_writer=writer,
            ref_model=ref_model,
        )

    if args.distributed:
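Note (illustration, not part of the diff): the block above follows the usual frozen-reference-model pattern. It builds a second copy of the architecture, loads the reference checkpoint into it, and freezes it so it is only used to score tokens. A minimal single-device sketch of that pattern is below; the `setup_reference_model` helper and the checkpoint format are hypothetical, while open_lm itself goes through `create_model`/`load_model` and the FSDP/DDP wrappers shown above and does not call `.eval()`.

import copy

import torch


def setup_reference_model(model: torch.nn.Module, checkpoint_path: str, device: str = "cuda"):
    """Hypothetical helper: clone the training model, load reference weights, and freeze the copy."""
    ref_model = copy.deepcopy(model).to(device)  # assumes the reference shares the training architecture
    state = torch.load(checkpoint_path, map_location=device)
    ref_model.load_state_dict(state.get("state_dict", state))  # assumes a plain or {"state_dict": ...} checkpoint
    ref_model.eval()  # extra step not in the diff: disables dropout for deterministic reference losses
    for p in ref_model.parameters():
        p.requires_grad = False  # the reference model is never updated
    return ref_model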
13 changes: 13 additions & 0 deletions open_lm/params.py
@@ -245,6 +245,11 @@ def check_args(args):
    if args.experimental_meta_device:
        print("WARNING: Meta device initialization requested, but this is not currently fully tested.")

    if args.rho1_k is not None:
        assert args.rho1_k >= 0 and args.rho1_k <= 1, "--rho1-k should be a float in [0,1]"
        assert args.rho1_switch >= 0 and args.rho1_switch <= 1, "--rho1-switch should be a float in [0,1]"
        assert args.reference_model is not None, "a reference model is needed for Rho-1 loss"


def parse_args(args):
    parser = argparse.ArgumentParser()
@@ -761,6 +766,14 @@ def parse_args(args):
        default=0,
        help="Whether to log the average model training loss. if not 0, it will log the average loss over the specified number of steps.",
    )
    parser.add_argument("--rho1-k", type=float, default=None, help="Fraction of tokens (in [0,1]) to keep when using Rho-1 loss.")
    parser.add_argument("--rho1-switch", type=float, default=None, help="Fraction of training after which to switch to Rho-1 loss.")
    parser.add_argument(
        "--reference-model",
        type=str,
        default=None,
        help="Reference model checkpoint to use when using Rho-1 loss.",
    )
    add_model_args(parser)

    config = maybe_load_config(parser, args)
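Usage note (not part of the diff): with these flags, a run might enable the selective loss as sketched below. The checkpoint path is a placeholder, and the usual open_lm data/model arguments would still have to be supplied for `parse_args` and `check_args` to succeed.

from open_lm.params import parse_args

# Hypothetical invocation: keep the 60% of tokens with the largest excess loss,
# switching to the Rho-1 objective halfway through training.
args = parse_args(
    [
        "--rho1-k", "0.6",
        "--rho1-switch", "0.5",
        "--reference-model", "/path/to/reference_checkpoint.pt",
        # ... remaining open_lm training arguments go here ...
    ]
)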
35 changes: 30 additions & 5 deletions open_lm/train.py
@@ -9,6 +9,7 @@
import torch.distributed as dist
from torch.distributed.distributed_c10d import ReduceOp
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
import torch.nn.functional as F

try:
    from megablocks.layers.moe import batched_load_balancing_loss, clear_load_balancing_loss
@@ -42,7 +43,7 @@ def backward(total_loss, scaler):


def train_one_epoch(
    model, data, loss, epoch, step, optimizer, scaler, scheduler, total_steps, args, tb_writer=None, averagers=None
    model, data, loss, epoch, step, optimizer, scaler, scheduler, total_steps, args, tb_writer=None, averagers=None, ref_model=None
):
    """Trains model for one epoch on the provided data.

@@ -122,11 +123,23 @@ def train_one_epoch(
            with autocast():
                inputs, targets = sample_chunk(texts, args)
                out, _, _ = model(inputs)

                if args.log_logit_mean:
                    logit_m.update(torch.mean(out).item())

                total_lm_loss = loss(out.reshape(-1, args.vocab_size), targets.reshape(-1))
                if ref_model is not None and step >= args.rho1_switch * total_steps:
                    ref_out, _, _ = ref_model(inputs)
                    with torch.no_grad():
                        loss_cur = F.cross_entropy(out.reshape(-1, args.vocab_size), targets.reshape(-1), reduction="none")
                        loss_ref = F.cross_entropy(ref_out.reshape(-1, args.vocab_size), targets.reshape(-1), reduction="none")
                        loss_diff = loss_cur - loss_ref
                        loss_diff = loss_diff.reshape(inputs.shape[0], inputs.shape[1])
                        ref_mask = loss_diff < torch.topk(loss_diff, int(np.ceil(args.rho1_k * loss_diff.shape[1])), dim=-1)[0][..., -1].reshape(loss_diff.shape[0], -1)
                        loss_targets = targets.detach().clone()
                        loss_targets[ref_mask] = -100  # This index is ignored by the loss
                else:
                    loss_targets = targets

                total_lm_loss = loss(out.reshape(-1, args.vocab_size), loss_targets.reshape(-1))
                total_loss = total_lm_loss
                if args.moe_freq > 0:
                    total_load_balancing_loss = batched_load_balancing_loss(moe_args)
@@ -161,12 +174,24 @@
                        break
                    targets_ii = targets[ii * per_batch : (ii + 1) * per_batch]
                    out, _, _ = model(inputs_ii)

                    if args.log_logit_mean:
                        logit_m.update(torch.mean(out).item())

                    if ref_model is not None and step >= args.rho1_switch * total_steps:
                        ref_out_ii, _, _ = ref_model(inputs_ii)
                        with torch.no_grad():
                            loss_cur = F.cross_entropy(out.reshape(-1, args.vocab_size), targets_ii.reshape(-1), reduction="none")
                            loss_ref = F.cross_entropy(ref_out_ii.reshape(-1, args.vocab_size), targets_ii.reshape(-1), reduction="none")
                            loss_diff = loss_cur - loss_ref
                            loss_diff = loss_diff.reshape(inputs_ii.shape[0], inputs_ii.shape[1])
                            ref_mask = loss_diff < torch.topk(loss_diff, int(np.ceil(args.rho1_k * loss_diff.shape[1])), dim=-1)[0][..., -1].reshape(loss_diff.shape[0], -1)
                            loss_targets = targets_ii.detach().clone()
                            loss_targets[ref_mask] = -100  # This index is ignored by the loss
                    else:
                        loss_targets = targets_ii

                    local_lm_loss = (
                        loss(out.reshape(-1, args.vocab_size), targets_ii.reshape(-1))
                        loss(out.reshape(-1, args.vocab_size), loss_targets.reshape(-1))
                        * inputs_ii.shape[0]
                        / inputs.shape[0]
                    )
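For readers new to Rho-1 style selective language modeling, the masking added in train.py can be read in isolation: compute per-token cross-entropy under both the training model and the frozen reference model, keep the ceil(rho1_k * seq_len) tokens per sequence whose excess loss (training minus reference) is largest, and set every other target to the ignore index -100 so the loss skips them. Below is a standalone sketch with dummy tensors; the `rho1_targets` helper is illustrative and not open_lm code.

import math

import torch
import torch.nn.functional as F


def rho1_targets(out, ref_out, targets, k, vocab_size, ignore_index=-100):
    """Keep the top ceil(k * seq_len) tokens per sequence by (training loss - reference loss)."""
    with torch.no_grad():
        loss_cur = F.cross_entropy(out.reshape(-1, vocab_size), targets.reshape(-1), reduction="none")
        loss_ref = F.cross_entropy(ref_out.reshape(-1, vocab_size), targets.reshape(-1), reduction="none")
        loss_diff = (loss_cur - loss_ref).reshape(targets.shape)  # [batch, seq_len]
        n_keep = math.ceil(k * loss_diff.shape[1])
        threshold = torch.topk(loss_diff, n_keep, dim=-1).values[..., -1:]  # k-th largest value per sequence
        drop_mask = loss_diff < threshold  # True for tokens below the cutoff
    masked = targets.clone()
    masked[drop_mask] = ignore_index  # cross-entropy ignores these positions by default
    return masked


# Toy check with random logits: roughly half the positions end up ignored.
batch, seq_len, vocab = 2, 8, 50
out = torch.randn(batch, seq_len, vocab)
ref_out = torch.randn(batch, seq_len, vocab)
targets = torch.randint(0, vocab, (batch, seq_len))
masked = rho1_targets(out, ref_out, targets, k=0.5, vocab_size=vocab)

With k=0.5 and seq_len=8, four tokens per sequence keep their targets and the rest are set to -100; ties at the threshold can keep slightly more, which matches the behavior of the comparison used in the PR.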