
Commit 6d31004

QM9 trainer up and running
1 parent 8d91f8b commit 6d31004

File tree

3 files changed: +40 -20 lines

applications/FLASK/Transformer/datasets/QM9.py

Lines changed: 16 additions & 4 deletions
@@ -7,6 +7,7 @@
 
 import numpy as np
 from pretokenize.SMILES_tokenizer import MolTokenizer
+from pretokenize.data_utils import random_zero_array
 
 sequence_length = int(os.getenv("QM9_SEQUENCE_LENGTH", default="32"))
 
@@ -20,10 +21,12 @@
 tokenizer = MolTokenizer(os.path.join(data_dir, "QM9_vocab.json"))
 tokenizer.load_vocab_file()
 
-dataset_train = np.load(os.path.join(data_dir, "QM9_Pretokenized.npy"), allow_pickle=True)
+dataset_train = np.load(os.path.join(data_dir, "QM9_Pretokenized.npy"))
 
+# dataset_train = np.zeros((140000, 32), dtype=np.float32)
 _vocab_size = 46
-
+pad_index = tokenizer.token_to_id("<pad>")
+sep_index = tokenizer.token_to_id("<eos>")
 
 # ----------------------------------------------
 # Sample access functions
@@ -36,7 +39,16 @@ def num_train_samples():
 
 def get_train_sample(i):
     data = dataset_train[i]
-    return data
+
+    boundary = np.where(data == sep_index)[0][0]
+    masked_data = random_zero_array(
+        data[:boundary], 0.15, tokenizer.token_to_id(tokenizer.mask_token)
+    )
+    output = np.zeros((2 * sequence_length), dtype=np.int32)
+    output[0:boundary] = masked_data
+    output[boundary] = sep_index
+    output[sequence_length:] = data
+    return output
 
 
 def sample_dims():
@@ -50,4 +62,4 @@ def vocab_size():
 if __name__ == "__main__":
     print("Training samples:", num_train_samples())
     print("Training sample 101:")
-    print(get_train_sample(101))
+    print(get_train_sample(0))
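
The new get_train_sample builds a BERT-style masked-language-modeling pair in a single vector: masked tokens up to the <eos> boundary in the first half, the untouched original sequence in the second half. random_zero_array itself is not part of this commit; below is a minimal sketch of what a helper with that signature plausibly does (an assumption for illustration; the real one lives in pretokenize/data_utils.py and may differ):

import numpy as np


def random_zero_array(arr, prob, mask_value):
    # Hypothetical re-implementation: replace roughly `prob` of the
    # entries with `mask_value`, as in BERT-style token masking.
    arr = np.array(arr, copy=True)
    mask = np.random.random(arr.shape) < prob
    arr[mask] = mask_value
    return arr

With sequence_length = 32, each sample is then a length-64 int32 vector: masked token ids in [0, boundary), the <eos> id at boundary, zeros up to index 31, and the original token ids in [32, 64).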
Lines changed: 22 additions & 14 deletions
@@ -1,26 +1,34 @@
 import numpy as np
 from SMILES_tokenizer import MolTokenizer
 from data_utils import random_zero_array
+import os
+import os.path
 
 
 def main():
-    tokenizer = MolTokenizer("SMILES_vocab.json")
+    data_dir = os.getenv("QM9_DATA_DIR", "/p/vast1/lbann/datasets/FLASK/QM9")
+
+    tokenizer = MolTokenizer(os.path.join(data_dir, "QM9_vocab.json"))
     tokenizer.load_vocab_file()
-    with open("QM9_smiles.txt", 'r') as smiles_data:
-        smiles_data = smiles_data.readlines()
-    num_samples = len(smiles_data)
-    max_length = 32
 
-    tokenized_data = np.ones((num_samples, max_length)) * tokenizer.encode(tokenizer.pad_token)
-    tokenized_data[:, 0] = tokenizer.encode(tokenizer.sep_token)
+    data_file = os.path.join(data_dir, "QM9_smiles.txt")
+    with open(data_file, "r") as smiles_data:
+        smiles_data = smiles_data.readlines()
+    num_samples = len(smiles_data)
+    max_length = 32
+
+    tokenized_data = np.ones((num_samples, max_length)) * tokenizer.encode(
+        tokenizer.pad_token
+    )
+    tokenized_data[:, 0] = tokenizer.encode(tokenizer.sep_token)
 
-    for i, smiles in enumerate(smiles_data, start=1):
-        tokens = tokenizer.tokenize(smiles)
-        tokens = random_zero_array(tokens, 0.15, tokenizer.encode(tokenizer.mask_token))
-        tokenized_data[i, :len(tokens)] = tokens
-        tokenized_data[i, len(tokens)] = tokenizer.encode(tokenizer.cls_token)
+    for i, smiles in enumerate(smiles_data, start=0):
+        tokens = tokenizer.tokenize(smiles)
+        tokenized_data[i, : len(tokens)] = tokens
+        tokenized_data[i, len(tokens)] = tokenizer.encode(tokenizer.sep_token)
+    save_file_loc = os.path.join(data_dir, "QM9_Pretokenized.npy")
+    np.save(save_file_loc, tokenized_data)
 
-    np.save('QM9_Pretokenized.npy', tokenized_data)
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
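
Note that this commit moves the random masking out of the offline pretokenization step (the random_zero_array call is dropped here) and into get_train_sample above, so a fresh mask is drawn every time a sample is fetched. A quick sanity check of the saved array (a hypothetical snippet; the path and shape are taken from the code above):

import os

import numpy as np

# Default path taken from the script; override with QM9_DATA_DIR.
data_dir = os.getenv("QM9_DATA_DIR", "/p/vast1/lbann/datasets/FLASK/QM9")

data = np.load(os.path.join(data_dir, "QM9_Pretokenized.npy"))
print(data.shape)  # expected (num_samples, 32)
print(data[0])     # token ids, a trailing separator id, then pad ids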

applications/FLASK/Transformer/network.py

Lines changed: 2 additions & 2 deletions
@@ -144,11 +144,11 @@ def _add_input_encoding(
 
     # Apply encoder
     if encoder_input is not None:
-        encoder_input = positional_encoder(
+        encoder_input = positional_encoder.apply_input(
             encoder_input, encoder_sequence_length, **kwargs
         )
     if decoder_input is not None:
-        decoder_input = positional_encoder(
+        decoder_input = positional_encoder.apply_input(
             decoder_input, decoder_sequence_length, **kwargs
         )
 
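
The call-site change suggests the positional encoder now exposes an explicit apply_input method instead of being invoked directly. For intuition only, here is a self-contained stand-in with the same method shape; the class name and the sinusoidal scheme are assumptions, not the real implementation in applications/FLASK/Transformer:

import numpy as np


class SinusoidalPositionalEncoder:
    # Illustrative stand-in for the positional_encoder object used in
    # network.py; the real class lives elsewhere in the application.
    def __init__(self, embed_dim):
        self.embed_dim = embed_dim

    def apply_input(self, x, sequence_length, **kwargs):
        # x: (sequence_length, embed_dim) token embeddings.
        positions = np.arange(sequence_length)[:, None]
        dims = np.arange(self.embed_dim)[None, :]
        angles = positions / np.power(10000.0, (2 * (dims // 2)) / self.embed_dim)
        encoding = np.where(dims % 2 == 0, np.sin(angles), np.cos(angles))
        return x + encoding


# Shape-preserving: adding position information leaves (32, 16) unchanged.
enc = SinusoidalPositionalEncoder(embed_dim=16)
out = enc.apply_input(np.zeros((32, 16)), sequence_length=32)
assert out.shape == (32, 16)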
