[WIP] Attention across documents. #213
base: main
Changes from all commits
d7afc8c
119df71
e24738c
3e88907
d24b533
af1a4cd
3e0036b
e4a5bac
7234b31
427291f
f28d984
3521ce1
```diff
@@ -742,6 +742,11 @@ def parse_args(args):
         action="store_true",
         help="If set, allow model to do multiple data passes over our dataset, in order to reach the desired number of tokens.",
     )
+    parser.add_argument(
+        "--mask-across-documents",
+        action="store_true",
+        help="If set, then tokens in the same sequence will be masked across EOT.",
+    )
 
     add_model_args(parser)
 
```

> **Review comment** on `"--mask-across-documents"`: i think this should be an
>
> **Reply:** Makes sense - will update the parameter.
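As a quick sanity check of the flag wiring, here is a minimal standalone argparse sketch; the real parser is built inside open_lm's `parse_args`, and everything else in this snippet is illustrative:

```python
import argparse

# Minimal standalone sketch of the new flag (illustrative only; the real parser
# lives in open_lm's parse_args and has many more arguments).
parser = argparse.ArgumentParser()
parser.add_argument(
    "--mask-across-documents",
    action="store_true",
    help="If set, then tokens in the same sequence will be masked across EOT.",
)

args = parser.parse_args(["--mask-across-documents"])
assert args.mask_across_documents is True  # argparse maps dashes to underscores; defaults to False when omitted
```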
```diff
@@ -22,6 +22,7 @@
     wandb = None
 
 from open_lm.data import sample_chunk
+from open_lm.datapreprocess.ray.tokenize_shuffle import SpecialTokens
 from open_lm.distributed import is_master
 from open_lm.precision import get_autocast
 from open_lm.meters import AverageMeter
```
```diff
@@ -41,6 +42,34 @@ def backward(total_loss, scaler):
         total_loss.backward()
 
 
+def get_document_seqlens(inputs, args):
+    """Get list of document sequence lengths.
+
+    Return a list of lists. The length of the outer list is equal to the batch size, while the length of the inner list
+    is equal to the number of distinct documents (recognized by EOT tokens). Each element of the inner lists is the
+    length of the corresponding document.
+    """
+    if args.mask_across_documents:
+        document_seqlens = []
+        for idx in range(inputs.shape[0]):
+            eot_idx = torch.nonzero(inputs[idx] == SpecialTokens.END_OF_TEXT.value)
+            if eot_idx.shape[0] == 0:
+                # All tokens come from the same document.
+                document_seqlens.append([args.seq_len])
+            else:
+                start_idx = 0
+                seqlens = []
+                for k in range(eot_idx.shape[0]):
+                    seqlens.append(eot_idx[k].item() - start_idx + 1)
+                    start_idx = eot_idx[k].item() + 1
+                if start_idx < args.seq_len:
+                    seqlens.append(args.seq_len - start_idx)
+                document_seqlens.append(seqlens)
+    else:
+        document_seqlens = None
+    return document_seqlens
+
+
 def train_one_epoch(model, data, loss, epoch, step, optimizer, scaler, scheduler, total_steps, args, tb_writer=None):
     """Trains model for one epoch on the provided data.
 
```
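To make the bookkeeping concrete, here is a small self-contained sketch of the same splitting logic on a toy packed sample. The EOT id of 0 and the sequence length of 8 are made up for illustration; the real code reads them from `SpecialTokens.END_OF_TEXT` and `args.seq_len`, and also handles the no-EOT case.

```python
import torch

# Toy packed sample: two EOT tokens (id 0, assumed) split the 8-token window into
# documents of length 3, 3, and a trailing partial document of length 2.
EOT = 0
seq_len = 8
inputs = torch.tensor([[11, 12, EOT, 21, 22, EOT, 31, 32]])

document_seqlens = []
for row in inputs:
    eot_idx = torch.nonzero(row == EOT)
    start_idx, seqlens = 0, []
    for k in range(eot_idx.shape[0]):
        seqlens.append(eot_idx[k].item() - start_idx + 1)  # the EOT token is counted with its document
        start_idx = eot_idx[k].item() + 1
    if start_idx < seq_len:
        seqlens.append(seq_len - start_idx)  # trailing document with no closing EOT
    document_seqlens.append(seqlens)

print(document_seqlens)  # [[3, 3, 2]]
```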
```diff
@@ -109,13 +138,21 @@ def train_one_epoch(model, data, loss, epoch, step, optimizer, scaler, scheduler
 
         (texts,) = batch
         texts = torch.LongTensor(texts).to(device)
 
         data_time_m.update(time.time() - end)
         optimizer.zero_grad()
 
         if args.accum_freq == 1:
             with autocast():
                 inputs, targets = sample_chunk(texts, args)
-                out, _, _ = model(inputs)
+                document_seqlens = get_document_seqlens(inputs, args)
+                if args.mask_across_documents:
+                    # Some input samples contain EOT as the final token. The prediction after that is meaningless, so it
+                    # should not contribute to the loss.
+                    ignore_indices = torch.nonzero(inputs == SpecialTokens.END_OF_TEXT.value, as_tuple=True)
+                    targets[ignore_indices] = loss.ignore_index
 
+                out, _, _ = model(inputs, document_seqlens=document_seqlens)
 
                 if args.log_logit_mean:
                     logit_m.update(torch.mean(out).item())
```

> **Review comment** on the `SpecialTokens.END_OF_TEXT` line: i prefer not to hard code our EOT to keep open_lm tokenizer agnostic
>
> **Reply:** Agreed - I'll change it so that it uses the user defined EOT token.
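The model-side handling of `document_seqlens` is not shown in this diff. As a rough illustration of the intended attention pattern, the sketch below builds an equivalent block-diagonal causal mask in plain PyTorch; the real implementation may instead hand the sequence lengths to a fused attention kernel.

```python
import torch

def block_diagonal_causal_mask(seqlens, device=None):
    """Illustrative only: allow each token to attend only to earlier tokens
    within its own document. Not the actual open_lm model change."""
    total = sum(seqlens)
    block = torch.zeros(total, total, dtype=torch.bool, device=device)
    start = 0
    for n in seqlens:
        block[start : start + n, start : start + n] = True  # same-document positions
        start += n
    causal = torch.tril(torch.ones(total, total, dtype=torch.bool, device=device))
    return block & causal  # combine document blocks with the usual causal mask

# For one sample split into documents of length [3, 3, 2]:
print(block_diagonal_causal_mask([3, 3, 2]).int())
```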
```diff
@@ -135,6 +172,12 @@ def train_one_epoch(model, data, loss, epoch, step, optimizer, scaler, scheduler
             per_batch = args.per_gpu_batch_size // args.accum_freq
 
             inputs, targets = sample_chunk(texts, args)
+            if args.mask_across_documents:
+                # Some input samples contain EOT as the final token. The prediction after that is meaningless, so it
+                # should not contribute to the loss.
+                ignore_indices = torch.nonzero(inputs == SpecialTokens.END_OF_TEXT.value, as_tuple=True)
+                targets = targets.detach().clone()  # Clone this because it shares mem with input!
+                targets[ignore_indices] = loss.ignore_index
 
             for ii in range(args.accum_freq):
                 maybe_no_sync = nullcontext
```

> **Review comment** on the `targets.detach().clone()` line: Interesting, is the detach necessary here? When args.mask_across_documents is False, should we also add a detach()?
>
> **Reply:** Detach is not necessary, but clone is: because the targets and the input share the underlying tensor, if the target is explicitly set then the input is also affected. When args.mask_across_documents is False, this is not an issue - neither the target nor the input are explicitly changed.
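A minimal sketch of the memory-sharing issue described in that reply, assuming `sample_chunk` returns shifted views of the same buffer (e.g. `inputs = chunk[:, :-1]` and `targets = chunk[:, 1:]`):

```python
import torch

# Assumed setup: inputs and targets are overlapping views of one underlying tensor.
chunk = torch.arange(10).reshape(1, 10)
inputs, targets = chunk[:, :-1], chunk[:, 1:]

masked_targets = targets.clone()   # without clone(), the in-place write below would also change `inputs`
masked_targets[0, 3] = -100        # e.g. CrossEntropyLoss(ignore_index=-100)

assert inputs[0, 4].item() == 4    # inputs are untouched because we wrote to a clone
```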
```diff
@@ -147,7 +190,8 @@ def train_one_epoch(model, data, loss, epoch, step, optimizer, scaler, scheduler
                     if inputs_ii.shape[0] == 0:
                         break
                     targets_ii = targets[ii * per_batch : (ii + 1) * per_batch]
-                    out, _, _ = model(inputs_ii)
+                    document_seqlens = get_document_seqlens(inputs_ii, args)
+                    out, _, _ = model(inputs_ii, document_seqlens=document_seqlens)
 
                     if args.log_logit_mean:
                         logit_m.update(torch.mean(out).item())
```
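Following up on the review thread about not hard-coding EOT: a hedged sketch of what a tokenizer-agnostic version might look like. The parameter name `args.eot_token_id` and the helper below are hypothetical, not part of this diff.

```python
# Hypothetical sketch only: `args.eot_token_id` is an assumed parameter name, not part of this PR.
# The idea is to resolve the EOT id once (from a CLI flag or the tokenizer) and pass it around,
# instead of referencing SpecialTokens.END_OF_TEXT directly in the training loop.

def get_eot_token_id(args, tokenizer=None):
    if getattr(args, "eot_token_id", None) is not None:
        return args.eot_token_id  # user-supplied id wins
    if tokenizer is not None and getattr(tokenizer, "eos_token_id", None) is not None:
        return tokenizer.eos_token_id  # fall back to the tokenizer's own EOT/EOS id
    raise ValueError("No EOT token id available; pass an id or a tokenizer.")

# Usage inside the training loop would then be, e.g.:
#   eot = get_eot_token_id(args, tokenizer)
#   ignore_indices = torch.nonzero(inputs == eot, as_tuple=True)
```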