
Commit e1cf00d

feat: upgrading to PyPI reinvent models 0.0.14 version (#1)
* feat: upgrading to pypi reinvent models 0.0.14 version
* feat: fixed the repo url and version in setup.py
1 parent 6eb81f2 commit e1cf00d

32 files changed (+444, -719 lines)

reinvent_models/lib_invent/enums/generative_model_parameters.py

Lines changed: 3 additions & 1 deletion
@@ -1,3 +1,5 @@
+
+
 class GenerativeModelParametersEnum:
     NUMBER_OF_LAYERS = "num_layers"
     NUMBER_OF_DIMENSIONS = "num_dimensions"
@@ -12,4 +14,4 @@ def __getattr__(self, name):
 
     # prohibit any attempt to set any values
     def __setattr__(self, key, value):
-        raise ValueError("No changes allowed.")
+        raise ValueError("No changes allowed.")
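
Editor's aside (not part of the diff): GenerativeModelParametersEnum above overrides __setattr__ to raise, so its parameter keys are effectively read-only. A minimal sketch of that behaviour, assuming the class can be instantiated without arguments, as its definition suggests:

# Sketch only -- illustrates the read-only enum guard kept by this diff; not part of the commit.
from reinvent_models.lib_invent.enums.generative_model_parameters import GenerativeModelParametersEnum

print(GenerativeModelParametersEnum.NUMBER_OF_LAYERS)   # -> "num_layers"

params = GenerativeModelParametersEnum()
try:
    params.NUMBER_OF_LAYERS = "something_else"           # instance assignment is blocked
except ValueError as error:
    print(error)                                         # -> "No changes allowed."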

reinvent_models/lib_invent/enums/generative_model_regime.py

Lines changed: 2 additions & 0 deletions
@@ -1,3 +1,5 @@
+
+
 class GenerativeModelRegimeEnum:
     INFERENCE = "inference"
     TRAINING = "training"

reinvent_models/lib_invent/models/dataset.py

Lines changed: 5 additions & 16 deletions
@@ -27,9 +27,7 @@ def __init__(self, smiles_list, vocabulary, tokenizer):
             self._encoded_list.append(enc)
 
     def __getitem__(self, i):
-        return torch.tensor(
-            self._encoded_list[i], dtype=torch.long
-        )  # pylint: disable=E1102
+        return torch.tensor(self._encoded_list[i], dtype=torch.long)  # pylint: disable=E1102
 
     def __len__(self):
         return len(self._encoded_list)
@@ -47,21 +45,14 @@ def __init__(self, scaffold_decoration_smi_list, vocabulary):
 
         self._encoded_list = []
         for scaffold, dec in scaffold_decoration_smi_list:
-            en_scaff = self.vocabulary.scaffold_vocabulary.encode(
-                self.vocabulary.scaffold_tokenizer.tokenize(scaffold)
-            )
-            en_dec = self.vocabulary.decoration_vocabulary.encode(
-                self.vocabulary.decoration_tokenizer.tokenize(dec)
-            )
+            en_scaff = self.vocabulary.scaffold_vocabulary.encode(self.vocabulary.scaffold_tokenizer.tokenize(scaffold))
+            en_dec = self.vocabulary.decoration_vocabulary.encode(self.vocabulary.decoration_tokenizer.tokenize(dec))
             if en_scaff is not None and en_dec is not None:
                 self._encoded_list.append((en_scaff, en_dec))
 
     def __getitem__(self, i):
         scaff, dec = self._encoded_list[i]
-        return (
-            torch.tensor(scaff, dtype=torch.long),
-            torch.tensor(dec, dtype=torch.long),
-        )  # pylint: disable=E1102
+        return (torch.tensor(scaff, dtype=torch.long), torch.tensor(dec, dtype=torch.long))  # pylint: disable=E1102
 
     def __len__(self):
         return len(self._encoded_list)
@@ -83,9 +74,7 @@ def pad_batch(encoded_seqs):
     :param encoded_seqs: A list of encoded sequences.
     :return: A tensor with the sequences correctly padded.
     """
-    seq_lengths = torch.tensor(
-        [len(seq) for seq in encoded_seqs], dtype=torch.int64
-    )  # pylint: disable=not-callable
+    seq_lengths = torch.tensor([len(seq) for seq in encoded_seqs], dtype=torch.int64)  # pylint: disable=not-callable
     if torch.cuda.is_available():
         return (tnnur.pad_sequence(encoded_seqs, batch_first=True).cuda(), seq_lengths)
     return (tnnur.pad_sequence(encoded_seqs, batch_first=True), seq_lengths)
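
Editor's aside (not part of the diff): the pad_batch change above only re-wraps the seq_lengths line; the function still zero-pads a list of encoded sequences to the longest one and returns the padded tensor together with the original lengths (on the GPU when available). A standalone sketch of the same computation with made-up values:

# Sketch only -- mirrors what pad_batch returns on a CPU-only machine; values are illustrative.
import torch
import torch.nn.utils.rnn as tnnur

encoded_seqs = [torch.tensor([1, 5, 7, 2], dtype=torch.long),
                torch.tensor([1, 9, 2], dtype=torch.long)]

seq_lengths = torch.tensor([len(seq) for seq in encoded_seqs], dtype=torch.int64)
padded = tnnur.pad_sequence(encoded_seqs, batch_first=True)  # pads with zeros to the longest sequence

print(padded)       # tensor([[1, 5, 7, 2],
                    #         [1, 9, 2, 0]])
print(seq_lengths)  # tensor([4, 3])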

reinvent_models/lib_invent/models/decorator.py

Lines changed: 38 additions & 94 deletions
@@ -7,9 +7,7 @@
 import torch.nn as tnn
 import torch.nn.utils.rnn as tnnur
 
-from reinvent_models.lib_invent.enums.generative_model_parameters import (
-    GenerativeModelParametersEnum,
-)
+from reinvent_models.lib_invent.enums.generative_model_parameters import GenerativeModelParametersEnum
 
 
 class Encoder(tnn.Module):
@@ -27,16 +25,10 @@ def __init__(self, num_layers, num_dimensions, vocabulary_size, dropout):
 
         self._embedding = tnn.Sequential(
             tnn.Embedding(self.vocabulary_size, self.num_dimensions),
-            tnn.Dropout(dropout),
-        )
-        self._rnn = tnn.LSTM(
-            self.num_dimensions,
-            self.num_dimensions,
-            self.num_layers,
-            batch_first=True,
-            dropout=self.dropout,
-            bidirectional=True,
+            tnn.Dropout(dropout)
         )
+        self._rnn = tnn.LSTM(self.num_dimensions, self.num_dimensions, self.num_layers,
+                             batch_first=True, dropout=self.dropout, bidirectional=True)
 
     def forward(self, padded_seqs, seq_lengths):  # pylint: disable=arguments-differ
         # FIXME: This fails with a batch of 1 because squeezing looses a dimension with size 1
@@ -53,40 +45,26 @@ def forward(self, padded_seqs, seq_lengths): # pylint: disable=arguments-differ
         padded_seqs = self._embedding(padded_seqs)
         hs_h, hs_c = (hidden_state, hidden_state.clone().detach())
 
-        # FIXME: this is to guard against non compatible `gpu` input for pack_padded_sequence() method in pytorch 1.7
+        #FIXME: this is to guard against non compatible `gpu` input for pack_padded_sequence() method in pytorch 1.7
         seq_lengths = seq_lengths.cpu()
 
-        packed_seqs = tnnur.pack_padded_sequence(
-            padded_seqs, seq_lengths, batch_first=True, enforce_sorted=False
-        )
+        packed_seqs = tnnur.pack_padded_sequence(padded_seqs, seq_lengths, batch_first=True, enforce_sorted=False)
         packed_seqs, (hs_h, hs_c) = self._rnn(packed_seqs, (hs_h, hs_c))
         padded_seqs, _ = tnnur.pad_packed_sequence(packed_seqs, batch_first=True)
 
         # sum up bidirectional layers and collapse
-        hs_h = (
-            hs_h.view(self.num_layers, 2, batch_size, self.num_dimensions)
-            .sum(dim=1)
-            .squeeze()
-        )  # (layers, batch, dim)
-        hs_c = (
-            hs_c.view(self.num_layers, 2, batch_size, self.num_dimensions)
-            .sum(dim=1)
-            .squeeze()
-        )  # (layers, batch, dim)
-        padded_seqs = (
-            padded_seqs.view(batch_size, max_seq_size, 2, self.num_dimensions)
-            .sum(dim=2)
-            .squeeze()
-        )  # (batch, seq, dim)
+        hs_h = hs_h.view(self.num_layers, 2, batch_size, self.num_dimensions)\
+            .sum(dim=1).squeeze()  # (layers, batch, dim)
+        hs_c = hs_c.view(self.num_layers, 2, batch_size, self.num_dimensions)\
+            .sum(dim=1).squeeze()  # (layers, batch, dim)
+        padded_seqs = padded_seqs.view(batch_size, max_seq_size, 2, self.num_dimensions)\
+            .sum(dim=2).squeeze()  # (batch, seq, dim)
 
         return padded_seqs, (hs_h, hs_c)
 
     def _initialize_hidden_state(self, batch_size):
         if torch.cuda.is_available():
-            return torch.zeros(
-                self.num_layers * 2, batch_size, self.num_dimensions
-            ).cuda()
-        return torch.zeros(self.num_layers * 2, batch_size, self.num_dimensions)
+            return torch.zeros(self.num_layers*2, batch_size, self.num_dimensions).cuda()
 
     def get_params(self):
         parameter_enums = GenerativeModelParametersEnum
@@ -98,23 +76,23 @@ def get_params(self):
             parameter_enums.NUMBER_OF_LAYERS: self.num_layers,
             parameter_enums.NUMBER_OF_DIMENSIONS: self.num_dimensions,
             parameter_enums.VOCABULARY_SIZE: self.vocabulary_size,
-            parameter_enums.DROPOUT: self.dropout,
+            parameter_enums.DROPOUT: self.dropout
         }
 
 
 class AttentionLayer(tnn.Module):
+
     def __init__(self, num_dimensions):
         super(AttentionLayer, self).__init__()
 
         self.num_dimensions = num_dimensions
 
         self._attention_linear = tnn.Sequential(
-            tnn.Linear(self.num_dimensions * 2, self.num_dimensions), tnn.Tanh()
+            tnn.Linear(self.num_dimensions*2, self.num_dimensions),
+            tnn.Tanh()
         )
 
-    def forward(
-        self, padded_seqs, encoder_padded_seqs, decoder_mask
-    ):  # pylint: disable=arguments-differ
+    def forward(self, padded_seqs, encoder_padded_seqs, decoder_mask):  # pylint: disable=arguments-differ
         """
         Performs the forward pass.
         :param padded_seqs: A tensor with the output sequences (batch, seq_d, dim).
@@ -124,19 +102,12 @@ def forward(
         """
         # scaled dot-product
         # (batch, seq_d, 1, dim)*(batch, 1, seq_e, dim) => (batch, seq_d, seq_e*)
-        attention_weights = (
-            (padded_seqs.unsqueeze(dim=2) * encoder_padded_seqs.unsqueeze(dim=1))
-            .sum(dim=3)
-            .div(math.sqrt(self.num_dimensions))
+        attention_weights = (padded_seqs.unsqueeze(dim=2)*encoder_padded_seqs.unsqueeze(dim=1))\
+            .sum(dim=3).div(math.sqrt(self.num_dimensions))\
             .softmax(dim=2)
-        )
         # (batch, seq_d, seq_e*)@(batch, seq_e, dim) => (batch, seq_d, dim)
         attention_context = attention_weights.bmm(encoder_padded_seqs)
-        return (
-            self._attention_linear(torch.cat([padded_seqs, attention_context], dim=2))
-            * decoder_mask,
-            attention_weights,
-        )
+        return (self._attention_linear(torch.cat([padded_seqs, attention_context], dim=2))*decoder_mask, attention_weights)
 
 
 class Decoder(tnn.Module):
@@ -154,26 +125,16 @@ def __init__(self, num_layers, num_dimensions, vocabulary_size, dropout):
 
         self._embedding = tnn.Sequential(
             tnn.Embedding(self.vocabulary_size, self.num_dimensions),
-            tnn.Dropout(dropout),
-        )
-        self._rnn = tnn.LSTM(
-            self.num_dimensions,
-            self.num_dimensions,
-            self.num_layers,
-            batch_first=True,
-            dropout=self.dropout,
-            bidirectional=False,
+            tnn.Dropout(dropout)
         )
+        self._rnn = tnn.LSTM(self.num_dimensions, self.num_dimensions, self.num_layers,
+                             batch_first=True, dropout=self.dropout, bidirectional=False)
 
         self._attention = AttentionLayer(self.num_dimensions)
 
-        self._linear = tnn.Linear(
-            self.num_dimensions, self.vocabulary_size
-        )  # just to redimension
+        self._linear = tnn.Linear(self.num_dimensions, self.vocabulary_size)  # just to redimension
 
-    def forward(
-        self, padded_seqs, seq_lengths, encoder_padded_seqs, hidden_states
-    ):  # pylint: disable=arguments-differ
+    def forward(self, padded_seqs, seq_lengths, encoder_padded_seqs, hidden_states):  # pylint: disable=arguments-differ
         """
         Performs the forward pass.
         :param padded_seqs: A tensor with the output sequences (batch, seq_d, dim).
@@ -187,20 +148,13 @@ def forward(
 
         padded_encoded_seqs = self._embedding(padded_seqs)
         packed_encoded_seqs = tnnur.pack_padded_sequence(
-            padded_encoded_seqs, seq_lengths, batch_first=True, enforce_sorted=False
-        )
-        packed_encoded_seqs, hidden_states = self._rnn(
-            packed_encoded_seqs, hidden_states
-        )
-        padded_encoded_seqs, _ = tnnur.pad_packed_sequence(
-            packed_encoded_seqs, batch_first=True
-        )  # (batch, seq, dim)
+            padded_encoded_seqs, seq_lengths, batch_first=True, enforce_sorted=False)
+        packed_encoded_seqs, hidden_states = self._rnn(packed_encoded_seqs, hidden_states)
+        padded_encoded_seqs, _ = tnnur.pad_packed_sequence(packed_encoded_seqs, batch_first=True)  # (batch, seq, dim)
 
         mask = (padded_encoded_seqs[:, :, 0] != 0).unsqueeze(dim=-1).type(torch.float)
-        attn_padded_encoded_seqs, attention_weights = self._attention(
-            padded_encoded_seqs, encoder_padded_seqs, mask
-        )
-        logits = self._linear(attn_padded_encoded_seqs) * mask  # (batch, seq, voc_size)
+        attn_padded_encoded_seqs, attention_weights = self._attention(padded_encoded_seqs, encoder_padded_seqs, mask)
+        logits = self._linear(attn_padded_encoded_seqs)*mask  # (batch, seq, voc_size)
         return logits, hidden_states, attention_weights
 
     def get_params(self):
@@ -213,7 +167,7 @@ def get_params(self):
             parameter_enum.NUMBER_OF_LAYERS: self.num_layers,
             parameter_enum.NUMBER_OF_DIMENSIONS: self.num_dimensions,
             parameter_enum.VOCABULARY_SIZE: self.vocabulary_size,
-            parameter_enum.DROPOUT: self.dropout,
+            parameter_enum.DROPOUT: self.dropout
         }
 
 
@@ -228,9 +182,7 @@ def __init__(self, encoder_params, decoder_params):
         self._encoder = Encoder(**encoder_params)
         self._decoder = Decoder(**decoder_params)
 
-    def forward(
-        self, encoder_seqs, encoder_seq_lengths, decoder_seqs, decoder_seq_lengths
-    ):  # pylint: disable=arguments-differ
+    def forward(self, encoder_seqs, encoder_seq_lengths, decoder_seqs, decoder_seq_lengths):  # pylint: disable=arguments-differ
         """
         Performs the forward pass.
         :param encoder_seqs: A tensor with the output sequences (batch, seq_d, dim).
@@ -239,12 +191,8 @@ def forward(
         :param decoder_seq_lengths: The lengths of the decoder sequences.
         :return : The output logits as a tensor (batch, seq_d, dim).
         """
-        encoder_padded_seqs, hidden_states = self.forward_encoder(
-            encoder_seqs, encoder_seq_lengths
-        )
-        logits, _, _ = self.forward_decoder(
-            decoder_seqs, decoder_seq_lengths, encoder_padded_seqs, hidden_states
-        )
+        encoder_padded_seqs, hidden_states = self.forward_encoder(encoder_seqs, encoder_seq_lengths)
+        logits, _, _ = self.forward_decoder(decoder_seqs, decoder_seq_lengths, encoder_padded_seqs, hidden_states)
         return logits
 
     def forward_encoder(self, padded_seqs, seq_lengths):
@@ -256,19 +204,15 @@ def forward_encoder(self, padded_seqs, seq_lengths):
         """
         return self._encoder(padded_seqs, seq_lengths)
 
-    def forward_decoder(
-        self, padded_seqs, seq_lengths, encoder_padded_seqs, hidden_states
-    ):
+    def forward_decoder(self, padded_seqs, seq_lengths, encoder_padded_seqs, hidden_states):
         """
         Does a forward pass only of the decoder.
         :param hidden_states: The hidden states from the encoder.
         :param padded_seqs: The data to feed to the decoder.
         :param seq_lengths: The length of each sequence in the batch.
         :return : Returns the logits and the hidden state for each element of the sequence passed.
         """
-        return self._decoder(
-            padded_seqs, seq_lengths, encoder_padded_seqs, hidden_states
-        )
+        return self._decoder(padded_seqs, seq_lengths, encoder_padded_seqs, hidden_states)
 
     def get_params(self):
         """
@@ -277,5 +221,5 @@ def get_params(self):
         """
         return {
             "encoder_params": self._encoder.get_params(),
-            "decoder_params": self._decoder.get_params(),
+            "decoder_params": self._decoder.get_params()
         }
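
Editor's aside (not part of the diff): the AttentionLayer rewrite above is purely stylistic; the scaled dot-product attention it computes is unchanged. A standalone shape check of the core expression, using plain torch with arbitrary sizes and random tensors (nothing here is taken from the package):

# Sketch only -- mirrors the attention expression reformatted above; sizes and tensors are arbitrary.
import math
import torch

batch, seq_d, seq_e, dim = 4, 7, 11, 128
padded_seqs = torch.randn(batch, seq_d, dim)          # decoder states
encoder_padded_seqs = torch.randn(batch, seq_e, dim)  # encoder states

# (batch, seq_d, 1, dim)*(batch, 1, seq_e, dim) -> sum over dim -> (batch, seq_d, seq_e)
attention_weights = (padded_seqs.unsqueeze(dim=2)*encoder_padded_seqs.unsqueeze(dim=1))\
    .sum(dim=3).div(math.sqrt(dim))\
    .softmax(dim=2)

# (batch, seq_d, seq_e) @ (batch, seq_e, dim) -> (batch, seq_d, dim)
attention_context = attention_weights.bmm(encoder_padded_seqs)
print(attention_weights.shape, attention_context.shape)
# torch.Size([4, 7, 11]) torch.Size([4, 7, 128])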
