
Commit ef6acb8

Fixes bug when padding side is right (#7)
1 parent 8add002 commit ef6acb8

1 file changed: +24 -6 lines changed


generate_finetuning_data.py  (+24 -6)
@@ -410,18 +410,34 @@ def __init__(self, tokenizer, mlm=False, output_raw_keys=False):
         super().__init__(tokenizer=tokenizer, mlm=False)
         self.output_raw_keys = output_raw_keys
 
-    def generate_masking_indices(self, key_lengths, max_length, input_ids):
+    def generate_masking_indices(self, key_lengths, response_lengths, max_length, input_ids):
         batch_size = key_lengths.size(0)
         device = input_ids.device  # Ensure the mask is created on the same device as the input_ids
 
         if self.tokenizer.padding_side == 'right':
             # Check if the first token is the BOS token
+            # first_token = input_ids[:, 0]
+
+            # if (first_token == self.tokenizer.bos_token_id).all():
+            #     mask = torch.arange(max_length, device=device).expand(batch_size, -1) < (key_lengths + 1).unsqueeze(1)
+            # else:
+            #     mask = torch.arange(max_length, device=device).expand(batch_size, -1) < key_lengths.unsqueeze(1)
+
+            # Mask needs to be 1 from 0 to key_length, then from key_length + response_length + 1 to max_length
+
+            # This does not take into account the EOS token at the end of the response (unless response_length is explicitly increased to account for it)
+            all_idx = torch.arange(max_length, device=device).expand(batch_size, -1)
+
+            offset_counter = 0
             first_token = input_ids[:, 0]
 
             if (first_token == self.tokenizer.bos_token_id).all():
-                mask = torch.arange(max_length, device=device).expand(batch_size, -1) < (key_lengths + 1).unsqueeze(1)
-            else:
-                mask = torch.arange(max_length, device=device).expand(batch_size, -1) < key_lengths.unsqueeze(1)
+                offset_counter += 1
+            mask = (all_idx < key_lengths.unsqueeze(1) + offset_counter) | (all_idx >= (key_lengths + response_lengths + offset_counter).unsqueeze(1))
+
+            return mask
+
+
         else:
             # Calculate the pad lengths
             pad_lengths = torch.sum(input_ids == self.tokenizer.pad_token_id, dim=1)
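To see what this hunk fixes: with right padding, the old code built the mask from key_lengths alone, so every position after the key, including the pad tokens that follow the response, kept a real label and leaked into the loss. The new mask also covers everything past key_length + response_length. Below is a minimal standalone sketch of the new logic, assuming a free-function version of generate_masking_indices and a hypothetical has_bos flag standing in for the BOS-token check on input_ids:

import torch

def generate_masking_indices(key_lengths, response_lengths, max_length, has_bos):
    # True marks positions to ignore: the key/prompt at the front (shifted
    # by 1 when a BOS token was prepended) and everything at or beyond
    # key_length + response_length, i.e. the right-side padding.
    batch_size = key_lengths.size(0)
    all_idx = torch.arange(max_length).expand(batch_size, -1)
    offset = 1 if has_bos else 0  # hypothetical stand-in for the BOS check
    return (all_idx < key_lengths.unsqueeze(1) + offset) | (
        all_idx >= (key_lengths + response_lengths + offset).unsqueeze(1)
    )

key_lengths = torch.tensor([3, 2])       # per-example prompt lengths
response_lengths = torch.tensor([2, 4])  # per-example response lengths
mask = generate_masking_indices(key_lengths, response_lengths, max_length=8, has_bos=False)
print(mask.int())
# tensor([[1, 1, 1, 0, 0, 1, 1, 1],
#         [1, 1, 0, 0, 0, 0, 1, 1]], dtype=torch.int32)

Only the response positions stay 0, so only they contribute to the loss; the second hunk below wires the new response_lengths argument through the collator.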
@@ -444,15 +460,17 @@ def __call__(self, batch):
         # A negative label will be ignored by the loss function
         # Get key lengths
         key_lengths = torch.stack([torch.tensor(x['key_length']) for x in batch])
+        response_lengths = torch.stack([torch.tensor(x['response_length']) for x in batch])
 
         # This code will be spaghetti to handle the idiosyncrasies of the tokenizer
 
         # Create a mask for the positions corresponding to the keys
-        mask = self.generate_masking_indices(key_lengths=key_lengths, max_length=labels.size(1), input_ids=input_ids)
+        mask = self.generate_masking_indices(key_lengths=key_lengths, max_length=labels.size(1), input_ids=input_ids, response_lengths=response_lengths)
 
         # Apply the mask to set the corresponding labels to -100
-        labels[mask] = -100
+        labels[mask] = -100
         # Need to account for EOS token?
+        # print(labels[:, 15:19])
         new_batch['labels'] = labels
         return new_batch
 
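Continuing the sketch above, this is roughly how __call__ uses the mask: gather the per-example lengths from the batch, build the mask, and overwrite the masked labels with -100, the ignore index of PyTorch's cross-entropy loss. The two-example batch and the labels tensor here are hypothetical stand-ins for the collator's real inputs, and generate_masking_indices is the free-function version from the previous sketch:

import torch

# Hypothetical feature dicts; 'key_length' / 'response_length' mirror the
# fields the collator reads from each example.
batch = [{'key_length': 3, 'response_length': 2},
         {'key_length': 2, 'response_length': 4}]
key_lengths = torch.stack([torch.tensor(x['key_length']) for x in batch])
response_lengths = torch.stack([torch.tensor(x['response_length']) for x in batch])

labels = torch.arange(16).reshape(2, 8)  # stand-in for the tokenized labels
mask = generate_masking_indices(key_lengths, response_lengths,
                                max_length=labels.size(1), has_bos=False)
labels[mask] = -100  # ignored by the loss function
print(labels)
# tensor([[-100, -100, -100,    3,    4, -100, -100, -100],
#         [-100, -100,   10,   11,   12,   13, -100, -100]])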