@@ -117,9 +117,9 @@ def __init__(
117117 self .refresh_batch_size = refresh_batch_size
118118 self .out_batch_size = out_batch_size
119119 self .device = device
120- self .remove_bos = remove_bos
121120 self .add_special_tokens = add_special_tokens
122121 self .tokenizer = AutoTokenizer .from_pretrained (model .name_or_path )
122+ self .remove_bos = remove_bos and (self .tokenizer .bos_token_id is not None )
123123
124124 if not self .tokenizer .pad_token :
125125 self .tokenizer .pad_token = self .tokenizer .eos_token
@@ -192,11 +192,11 @@ def refresh(self):
192192 with t .no_grad ():
193193 input = self .tokenized_batch ()
194194 hidden_states = collect_activations (self .model , self .submodule , input )
195- attn_mask = input ["attention_mask" ]
195+ mask = ( input ["attention_mask" ] != 0 )
196196 if self .remove_bos :
197- hidden_states = hidden_states [:, 1 :, :]
198- attn_mask = attn_mask [:, 1 :]
199- hidden_states = hidden_states [attn_mask != 0 ]
197+ bos_mask = ( input [ "input_ids" ] == self . tokenizer . bos_token_id )
198+ mask = mask & ~ bos_mask
199+ hidden_states = hidden_states [mask ]
200200
201201 remaining_space = self .activation_buffer_size - current_idx
202202 assert remaining_space > 0
0 commit comments