@@ -410,18 +410,34 @@ def __init__(self, tokenizer, mlm=False, output_raw_keys=False):
         super().__init__(tokenizer=tokenizer, mlm=False)
         self.output_raw_keys = output_raw_keys
 
-    def generate_masking_indices(self, key_lengths, max_length, input_ids):
+    def generate_masking_indices(self, key_lengths, response_lengths, max_length, input_ids):
         batch_size = key_lengths.size(0)
         device = input_ids.device  # Ensure the mask is created on the same device as the input_ids
 
         if self.tokenizer.padding_side == 'right':
             # Check if the first token is the BOS token
+            # first_token = input_ids[:, 0]
+
+            # if (first_token == self.tokenizer.bos_token_id).all():
+            #     mask = torch.arange(max_length, device=device).expand(batch_size, -1) < (key_lengths + 1).unsqueeze(1)
+            # else:
+            #     mask = torch.arange(max_length, device=device).expand(batch_size, -1) < key_lengths.unsqueeze(1)
+
+            # Mask needs to be 1 for 0 to key_length then key_length+response_length+1 to max_length
+
+            # This does not take into account the EOS token at the end of the response (unless response_length is explicitly increased to account for it)
+            all_idx = torch.arange(max_length, device=device).expand(batch_size, -1)
+
+            offset_counter = 0
             first_token = input_ids[:, 0]
 
             if (first_token == self.tokenizer.bos_token_id).all():
-                mask = torch.arange(max_length, device=device).expand(batch_size, -1) < (key_lengths + 1).unsqueeze(1)
-            else:
-                mask = torch.arange(max_length, device=device).expand(batch_size, -1) < key_lengths.unsqueeze(1)
+                offset_counter += 1
+            mask = (all_idx < key_lengths.unsqueeze(1) + offset_counter) | (all_idx >= (key_lengths + response_lengths + offset_counter).unsqueeze(1))
+
+            return mask
+
+
         else:
             # Calculate the pad lengths
             pad_lengths = torch.sum(input_ids == self.tokenizer.pad_token_id, dim=1)
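
For reference, a minimal standalone sketch of what the new right-padding branch computes, using made-up lengths; `offset_counter` is set to 1 here to mimic the case where every sequence starts with a BOS token. Only the response positions keep their labels.

import torch

# Toy illustration of the new masking rule (all values are illustrative only).
key_lengths = torch.tensor([3, 5])        # per-example key/prompt lengths
response_lengths = torch.tensor([4, 2])   # per-example response lengths
max_length = 10
batch_size = key_lengths.size(0)
offset_counter = 1                        # 1 if every sequence starts with a BOS token, else 0

all_idx = torch.arange(max_length).expand(batch_size, -1)
# True for the BOS/key positions and for everything after the response,
# so only the response tokens keep their labels.
mask = (all_idx < key_lengths.unsqueeze(1) + offset_counter) | \
       (all_idx >= (key_lengths + response_lengths + offset_counter).unsqueeze(1))

labels = torch.randint(0, 100, (batch_size, max_length))  # placeholder token ids
labels[mask] = -100                                       # ignored by the loss
# Example 0: positions 4-7 stay unmasked; example 1: positions 6-7 stay unmasked.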
@@ -444,15 +460,17 @@ def __call__(self, batch):
         # A negative label will be ignored by the loss function
         # Get key lengths
         key_lengths = torch.stack([torch.tensor(x['key_length']) for x in batch])
+        response_lengths = torch.stack([torch.tensor(x['response_length']) for x in batch])
 
         # This code will be spaghetti to handle the idiosyncrasies of the tokenizer
 
         # Create a mask for the positions corresponding to the keys
-        mask = self.generate_masking_indices(key_lengths=key_lengths, max_length=labels.size(1), input_ids=input_ids)
+        mask = self.generate_masking_indices(key_lengths=key_lengths, max_length=labels.size(1), input_ids=input_ids, response_lengths=response_lengths)
 
         # Apply the mask to set the corresponding labels to -100
-        labels[mask] = -100
+        labels[mask] = -100
         # Need to account for the EOS token?
+        # print(labels[:, 15:19])
         new_batch['labels'] = labels
         return new_batch
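
After this change, every example handed to `__call__` is expected to carry both a `key_length` and a `response_length` field (counted in tokens) next to the usual tokenizer output. A hedged sketch of what such an example might look like follows; the token ids and the exact field set are placeholders, not taken from the actual dataset.

# Hypothetical dataset example for the updated collator (values are illustrative).
example = {
    "input_ids": [2, 17, 93, 44, 51, 68, 22, 9],   # BOS + 3 key tokens + 3 response tokens + EOS
    "attention_mask": [1, 1, 1, 1, 1, 1, 1, 1],
    "key_length": 3,        # number of key/prompt tokens (excluding BOS)
    "response_length": 3,   # number of response tokens (excluding EOS, per the comment above)
}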