@@ -89,12 +89,23 @@ def _move_special_tokens(self):
8989 self .state ["added_tokens" ][i ]["id" ] += (
9090 len (self .state ["model" ]["vocab" ]) - self .initial_length )
9191
92- for i in range (len (self .state ["post_processor" ]["processors" ])):
93- if 'special_tokens' in self .state ["post_processor" ]["processors" ][i ].keys ():
94- for k in self .state ["post_processor" ]["processors" ][i ]["special_tokens" ].keys ():
95- for j in tqdm (range (len (self .state ["post_processor" ]["processors" ][i ]["special_tokens" ][k ]['ids' ])), desc = "Moving special tokens" ):
96- self .state ["post_processor" ]["processors" ][i ]["special_tokens" ][k ]["ids" ][j ] += (
97- len (self .state ["model" ]["vocab" ]) - self .initial_length )
def process_special_tokens(obj):
    """Shift every special-token id found anywhere inside ``obj``.

    Walks nested dicts and lists. Whenever a dict key named
    ``"special_tokens"`` maps to a dict, each entry of that dict which is
    itself a dict holding an ``"ids"`` list has every id increased, in
    place, by ``len(self.state["model"]["vocab"]) - self.initial_length``.
    Every other value is searched recursively.

    Args:
        obj: Any fragment of the tokenizer state. Values that are neither
            dicts nor lists are ignored.
    """
    if isinstance(obj, dict):
        for key, value in obj.items():
            if key == "special_tokens" and isinstance(value, dict):
                # Nothing in this walk touches the vocabulary, so the
                # shift is the same for every id: compute it once here
                # instead of once per id.
                offset = (
                    len(self.state["model"]["vocab"]) - self.initial_length)
                for entry in value.values():
                    # Only mapping entries can carry an "ids" list; skip
                    # anything else rather than failing on the lookup.
                    if not isinstance(entry, dict) or "ids" not in entry:
                        continue
                    ids = entry["ids"]
                    # Slice assignment updates the existing list object,
                    # so other references to it see the shifted ids.
                    ids[:] = [
                        token_id + offset
                        for token_id in tqdm(ids, desc="Moving special tokens")
                    ]
            else:
                process_special_tokens(value)
    elif isinstance(obj, list):
        for item in obj:
            process_special_tokens(item)
107+
108+ process_special_tokens (self .state .get ("post_processor" , {}))
98109
99110 def _process_and_add_tokens (self , merge : list ):
100111 processed_merge = '' .join (merge ).replace (' ' , '' )
@@ -470,7 +481,7 @@ def updated_tokenizer(self):
470481 """
471482 self .__is_tokenizer ()
472483
473- if self .initial_length < len (self .state ["model" ]["vocab" ]):
484+ if self .initial_length != len (self .state ["model" ]["vocab" ]):
474485 self ._move_special_tokens ()
475486
476487 backend_tokenizer = Tokenizer .from_str (json .dumps (self .state ))
0 commit comments