@@ -41,11 +41,18 @@ def tokenize(self, text):
4141 return tokens
4242
def detokenize(self, token_ids):
    """Convert token ids back into a space-joined text string.

    Accepts either an iterable of integer ids or a single bare id
    (a lone int is wrapped into a one-element list). Ids that are not
    present in ``self.tokens`` are rendered as ``"<unk>"``.

    Args:
        token_ids: iterable of ids, or a single id.

    Returns:
        str: the decoded tokens joined with single spaces.
    """
    # Normalize a scalar id into a list up front (EAFP: probe iterability
    # once) instead of duplicating the whole decode body in an except
    # branch as before.
    try:
        iter(token_ids)
    except TypeError:
        token_ids = [token_ids]
    # Build the reverse map exactly once; the old fallback path built it twice.
    id_to_token = {index: token for token, index in self.tokens.items()}
    tokens = [id_to_token.get(token_id, "<unk>") for token_id in token_ids]
    return " ".join(tokens)
4956 def create_vocab (self , dataset ):
5057 unique_tokens = data2tokens (dataset , vocab_size = ((int (self .vocab_size ))- 4 )) # Way less optimized than the old data.split but WAY more effective for MGPL
5158 self .tokens = {token : idx + len (self .special_tokens )+ 1 for idx , token in enumerate (unique_tokens )}
@@ -312,6 +319,22 @@ def token_to_vector(self, token_id):
312319 return torch .tensor (vector , dtype = torch .float32 ).to (self .device )
313320 self .logger .log (f"Token ID { token_id } not found in embedding table. Returning zero vector." , v = False , Wh = True , mention = True )
314321 return torch .zeros ((self .vector_dim ,), dtype = torch .float32 ).to (self .device )
322+
# <!> This is NOT a real function, just added it for --embedding-test, it will NOT be used in real generation. <!>
def vector_to_token(self, input_vector):
    """Return the token whose embedding is nearest to *input_vector*.

    Performs a linear scan over the embedding table and compares by
    Euclidean distance (``torch.norm`` of the difference). Ties keep the
    first entry encountered; an empty table yields ``None``.

    Args:
        input_vector: a tensor or any sequence convertible via
            ``torch.tensor``.

    Returns:
        The token paired with the closest vector, or ``None``.
    """
    query = input_vector if torch.is_tensor(input_vector) else torch.tensor(input_vector)
    query = query.to(self.device)
    best_token = None
    best_dist = float("inf")
    # NOTE(review): assumes self.embedding_table yields (token, vector)
    # pairs when iterated — confirm against where the table is built.
    for token, vector in self.embedding_table:
        candidate = vector if torch.is_tensor(vector) else torch.tensor(vector)
        candidate = candidate.to(self.device)
        dist = torch.norm(query - candidate).item()
        if dist < best_dist:
            best_dist = dist
            best_token = token
    return best_token
315338
316339class SPE ():
317340 def __init__ (self , device ):
0 commit comments