tokenisers.py

import json

import atomInSmiles
from rdkit import Chem
from tokenizers import Tokenizer
from transformers import PreTrainedTokenizer
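
# NOTE: CharLevelTokenizer and AISTokenizer load their vocabulary from a JSON file.
# Based on how the file is used (token -> id lookups plus '[CLS]', '[SEP]', '[PAD]',
# '[UNK]' special tokens), the assumed layout is a flat mapping such as:
#
#   {"[PAD]": 0, "[UNK]": 1, "[CLS]": 2, "[SEP]": 3, "C": 4, "c": 5, "(": 6, ")": 7}
#
# The exact tokens and ids depend on how the vocabulary was built and are not fixed here.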

class CharLevelTokenizer(PreTrainedTokenizer):
    """Character-level tokenizer: every character of the input string is one token."""

    def __init__(self, vocab_file):
        # Load a token -> id mapping from a JSON vocabulary file
        with open(vocab_file, 'r') as f:
            self.vocab = json.load(f)
        self.ids_to_tokens = {id: token for token, id in self.vocab.items()}

    def tokenize(self, text):
        return list(text)  # Tokenize the text into a list of characters

    def convert_tokens_to_ids(self, tokens):
        return [self.vocab.get(token, self.vocab['[UNK]']) for token in tokens]

    def encode(self, text, add_special_tokens=False, max_length=None, truncation=False):
        tokens = self.tokenize(text)
        token_ids = self.convert_tokens_to_ids(tokens)
        if add_special_tokens:
            token_ids = [self.vocab['[CLS]']] + token_ids + [self.vocab['[SEP]']]
        if truncation and max_length:
            token_ids = token_ids[:max_length]
        return token_ids

    def decode(self, token_ids, skip_special_tokens=True):
        special_tokens = [
            self.vocab.get('[CLS]'),
            self.vocab.get('[SEP]'),
            self.vocab.get('[PAD]'),
            self.vocab.get('[UNK]')
        ]
        if skip_special_tokens:
            token_ids = [id for id in token_ids if id not in special_tokens]
        return ''.join(self.ids_to_tokens.get(id, '[UNK]') for id in token_ids)

    def __len__(self):
        return len(self.vocab)
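
# Minimal usage sketch for CharLevelTokenizer (the vocab path is hypothetical; it assumes
# '[CLS]', '[SEP]', '[UNK]' and every character of the input are present in the vocabulary):
#
#   char_tok = CharLevelTokenizer("char_vocab.json")       # hypothetical path
#   ids = char_tok.encode("CCO", add_special_tokens=True)  # per-character ids wrapped in [CLS]/[SEP]
#   assert char_tok.decode(ids) == "CCO"                   # special tokens are skipped by default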

class AISTokenizer(PreTrainedTokenizer):
    """Atom-in-SMILES tokenizer built on the atomInSmiles package and a JSON vocabulary."""

    def __init__(self, vocab_file):
        # Load a token -> id mapping from a JSON vocabulary file
        with open(vocab_file, 'r') as f:
            self.vocab = json.load(f)
        self.ids_to_tokens = {id: token for token, id in self.vocab.items()}

    def tokenize(self, smiles):
        try:
            # Validate the SMILES with RDKit before passing it to atomInSmiles
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                print(f"Warning: Invalid SMILES skipped: {smiles}")
                return []  # Return an empty token list
            # Tokenize using atomInSmiles (returns a space-separated string of tokens)
            ais_token_string = atomInSmiles.encode(smiles)
            return ais_token_string.split()
        except Exception as e:
            print(f"Error tokenizing SMILES '{smiles}': {e}")
            return []  # Return an empty token list instead of crashing

    def convert_tokens_to_ids(self, tokens):
        # Map tokens to their IDs, falling back to [UNK] for unknown tokens
        return [self.vocab.get(token, self.vocab['[UNK]']) for token in tokens]

    def encode(self, smiles, add_special_tokens=True, max_length=None, truncation=False):
        tokens = self.tokenize(smiles)
        token_ids = self.convert_tokens_to_ids(tokens)
        if add_special_tokens:
            token_ids = [self.vocab.get('[CLS]')] + token_ids + [self.vocab.get('[SEP]')]
        if truncation and max_length:
            token_ids = token_ids[:max_length]
        return token_ids

    def decode(self, token_ids, skip_special_tokens=True):
        # Convert token IDs back to tokens, replacing unknown IDs with '[UNK]'
        tokens = [self.ids_to_tokens.get(id, '[UNK]') for id in token_ids]
        if skip_special_tokens:
            special_tokens = {'[CLS]', '[SEP]', '[PAD]', '[UNK]'}
            tokens = [token for token in tokens if token not in special_tokens]
        # Join the tokens into a space-separated string for atomInSmiles decoding
        smiles_token_string = ' '.join(tokens)
        try:
            # Reconstruct the SMILES string from the atom-in-SMILES tokens
            decoded_smiles = atomInSmiles.decode(smiles_token_string)
        except IndexError as e:
            print(f"Decoding error with SMILES string '{smiles_token_string}': {e}. Skipping.")
            return None  # Return None to indicate failure
        return decoded_smiles.replace(' ', '') if decoded_smiles else None

    def __len__(self):
        return len(self.vocab)
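
# Minimal usage sketch for AISTokenizer (the vocab path is hypothetical; the vocabulary is
# assumed to have been built from the same atomInSmiles tokenization, otherwise most tokens
# map to [UNK] and decoding will not round-trip):
#
#   ais_tok = AISTokenizer("ais_vocab.json")   # hypothetical path
#   ids = ais_tok.encode("CCO")                # [CLS] + atom-in-SMILES token ids + [SEP]
#   smiles = ais_tok.decode(ids)               # reconstructed SMILES, or None on failure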

class NPBPETokenizer(PreTrainedTokenizer):
    """BPE tokenizer that wraps a pre-trained `tokenizers.Tokenizer` loaded from a file."""

    def __init__(self, tokenizer_file):
        self.tokenizer = Tokenizer.from_file(tokenizer_file)

    def tokenize(self, text):
        # Tokenize using the pre-trained tokenizer
        encoding = self.tokenizer.encode(text)
        return encoding.tokens  # Return the list of tokens

    def convert_tokens_to_ids(self, tokens):
        # Convert tokens to IDs by re-encoding the space-joined token string
        encoding = self.tokenizer.encode(" ".join(tokens))
        return encoding.ids

    def encode(self, text, add_special_tokens=True, max_length=None, truncation=False):
        encoding = self.tokenizer.encode(text)
        token_ids = encoding.ids
        if add_special_tokens:
            token_ids = [self.tokenizer.token_to_id("[CLS]")] + token_ids + [self.tokenizer.token_to_id("[SEP]")]
        if truncation and max_length:
            token_ids = token_ids[:max_length]
        return token_ids

    def decode(self, token_ids, skip_special_tokens=False):
        if skip_special_tokens:
            # Drop the special token IDs before decoding
            special_tokens = ["[PAD]", "[UNK]", "[CLS]", "[SEP]"]
            special_token_ids = {self.tokenizer.token_to_id(token) for token in special_tokens if self.tokenizer.token_to_id(token) is not None}
            token_ids = [id for id in token_ids if id not in special_token_ids]
        return self.tokenizer.decode(token_ids)

    def __len__(self):
        return self.tokenizer.get_vocab_size()
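

# A minimal smoke-test sketch for NPBPETokenizer: the tokenizer file path is hypothetical,
# and encode() assumes "[CLS]"/"[SEP]" exist in that file (token_to_id returns None for
# missing tokens, which would break decoding).
if __name__ == "__main__":
    bpe_tok = NPBPETokenizer("np_bpe_tokenizer.json")  # hypothetical file path
    example = "CC(=O)Oc1ccccc1C(=O)O"                  # aspirin as a SMILES string
    ids = bpe_tok.encode(example)
    print("token ids :", ids)
    print("decoded   :", bpe_tok.decode(ids, skip_special_tokens=True))
    print("vocab size:", len(bpe_tok))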