Skip to content

Commit 9b5e860

Browse files
committed
init promptsmiles scaffold update
1 parent 4338e6d commit 9b5e860

File tree

8 files changed

+235
-664
lines changed

8 files changed

+235
-664
lines changed

acegen/models/common.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import torch
2+
3+
class Temperature(torch.nn.Module):
4+
"""Implements a temperature layer.
5+
6+
Simple Module that applies a temperature value to the logits for RL inference.
7+
8+
Args:
9+
temperature (float): The temperature value.
10+
"""
11+
12+
def __init__(self):
13+
super().__init__()
14+
15+
def forward(self, logits: torch.Tensor, temperature: torch.tensor) -> torch.Tensor:
16+
return logits / temperature

acegen/models/gru.py

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
import torch
44
from tensordict.nn import TensorDictModule, TensorDictSequential
55
from torchrl.envs import ExplorationType
6-
from torchrl.modules import ActorValueOperator, GRUModule, MLP, ProbabilisticActor
6+
from torchrl.modules import ActorValueOperator, GRUModule, MLP, ProbabilisticActor, MaskedCategorical
7+
8+
from acegen.models.common import Temperature
79

810

911
class Embed(torch.nn.Module):
@@ -34,22 +36,6 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
3436
return out
3537

3638

37-
class Temperature(torch.nn.Module):
38-
"""Implements a temperature layer.
39-
40-
Simple Module that applies a temperature value to the logits for RL inference.
41-
42-
Args:
43-
temperature (float): The temperature value.
44-
"""
45-
46-
def __init__(self):
47-
super().__init__()
48-
49-
def forward(self, logits: torch.Tensor, temperature: torch.tensor) -> torch.Tensor:
50-
return logits / temperature
51-
52-
5339
def create_gru_components(
5440
vocabulary_size: int,
5541
embedding_size: int = 256,
@@ -133,6 +119,7 @@ def create_gru_actor(
133119
return_log_prob=True,
134120
in_key: str = "observation",
135121
out_key: str = "logits",
122+
action_mask_key: str = "action_mask",
136123
recurrent_state: str = "recurrent_state_actor",
137124
python_based: bool = False,
138125
):
@@ -151,6 +138,7 @@ def create_gru_actor(
151138
of the action.
152139
in_key (str): The input key name.
153140
out_key (str):): The output key name.
141+
action_mask_key (str): The action mask key name.
154142
recurrent_state (str): The name of the recurrent state.
155143
python_based (bool): Whether to use the Python-based GRU module.
156144
Default is False, a CuDNN-based GRU module is used.
@@ -181,11 +169,18 @@ def create_gru_actor(
181169
head,
182170
)
183171

172+
if action_mask_key:
173+
inf_keys = {"logits": "logits", "mask": action_mask_key}
174+
inf_dist = MaskedCategorical
175+
else:
176+
inf_keys = ["logits"]
177+
inf_dist = distribution_class
178+
184179
actor_inference_model = ProbabilisticActor(
185180
module=actor_inference_model,
186-
in_keys=["logits"],
181+
in_keys=inf_keys,
187182
out_keys=["action"],
188-
distribution_class=distribution_class,
183+
distribution_class=inf_dist,
189184
return_log_prob=return_log_prob,
190185
default_interaction_type=ExplorationType.RANDOM,
191186
)

acegen/rl_env/token_env.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,10 @@ def __init__(
9191
self.num_envs, self.max_length, device=self.device, dtype=torch.bool
9292
)
9393
self.sequence_mask[:, 0] = True
94+
95+
self.action_mask = torch.ones(
96+
self.num_envs, self.length_vocabulary, device=self.device, dtype=torch.bool
97+
)
9498

9599
self._reset_tensordict = TensorDict(
96100
{
@@ -109,6 +113,7 @@ def __init__(
109113
),
110114
"sequence": self.sequence.clone(),
111115
"sequence_mask": self.sequence_mask.clone(),
116+
"action_mask": self.action_mask.clone()
112117
},
113118
device=self.device,
114119
batch_size=self.batch_size,
@@ -167,6 +172,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
167172
"observation": obs,
168173
"sequence": self.sequence.clone(),
169174
"sequence_mask": self.sequence_mask.clone(),
175+
"action_mask": self.action_mask.clone()
170176
},
171177
device=self.device,
172178
batch_size=self.batch_size,

acegen/rl_env/utils.py

Lines changed: 44 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import re
12
import warnings
3+
from pathlib import Path
24
from functools import partial
35
from typing import Callable, Union
46

@@ -12,6 +14,7 @@
1214
from torchrl.envs.utils import ExplorationType, step_mdp
1315

1416
from acegen.data.utils import smiles_to_tensordict
17+
from acegen.data.smiles_dataset import load_dataset
1518
from acegen.vocabulary import Vocabulary
1619

1720
try:
@@ -134,12 +137,24 @@ def generate_complete_smiles(
134137
vocabulary=vocabulary,
135138
max_length=max_length,
136139
)
137-
# Split fragments into a list if there are multiple
138-
promptsmiles = promptsmiles.split(".")
139-
if len(promptsmiles) == 1:
140-
promptsmiles = promptsmiles[0]
141-
140+
# Deduce type of prompt
142141
if isinstance(promptsmiles, str):
142+
if Path(promptsmiles).exists():
143+
promptsmiles = load_dataset(promptsmiles)
144+
prompt_type = "scaffold"
145+
elif "." in promptsmiles:
146+
promptsmiles = promptsmiles.split(".")
147+
prompt_type = "fragment"
148+
else:
149+
prompt_type = "scaffold"
150+
elif isinstance(promptsmiles, list):
151+
prompt_type = "scaffold"
152+
else:
153+
raise ValueError(
154+
"PromptSMILES must be a string or a list of strings, or a path to a file."
155+
)
156+
157+
if prompt_type == "scaffold":
143158
# We are decorating a Scaffold
144159
PS = ScaffoldDecorator(
145160
scaffold=promptsmiles,
@@ -152,7 +167,7 @@ def generate_complete_smiles(
152167
return_all=True,
153168
)
154169

155-
if isinstance(promptsmiles, list):
170+
if prompt_type == "fragment":
156171
# We are linking fragments
157172
PS = FragmentLinker(
158173
fragments=promptsmiles,
@@ -297,8 +312,20 @@ def generate_complete_smiles(
297312

298313
failed_encodings = []
299314
if prompt:
315+
print(prompt)
300316
if isinstance(prompt, str):
301317
prompt = [prompt] * batch_size[0]
318+
319+
# Add X to vocabulary for substitution
320+
vocabulary.add_characters("X")
321+
free_sample_tokens = torch.tensor([vocabulary["X"], vocabulary.end_token_index]).to(policy_device)
322+
323+
# Create action mask of atoms
324+
atom_patt = re.compile(r"(\[[^\]]*\]|Br|Cl|[a-wyzA-WYZ])")
325+
atom_tokens = torch.tensor([vocabulary[t] for t in vocabulary.chars if atom_patt.fullmatch(t)]).to(policy_device)
326+
atom_mask = torch.zeros(len(vocabulary)-1, dtype=torch.bool).to(policy_device)
327+
atom_mask = atom_mask.scatter(0, atom_tokens, True)
328+
302329
# Encode the prompt(s)
303330
tokens = []
304331
for i, smi in enumerate(prompt):
@@ -353,7 +380,12 @@ def generate_complete_smiles(
353380
tensordict_.set("mask", torch.ones_like(finished))
354381
tensordict_.set(("next", "mask"), torch.ones_like(finished))
355382
if prompt:
356-
enforce_mask = enc_prompts[:, _] != vocabulary.end_token_index
383+
enforce_mask = ~torch.isin(
384+
enc_prompts[:, _], free_sample_tokens
385+
)
386+
# Apply atom mask if prompt is X
387+
if any(enc_prompts[:, _] == vocabulary["X"]):
388+
tensordict_["action_mask"][enc_prompts[:, _] == vocabulary["X"]] = atom_mask
357389

358390
# Define temperature tensor
359391
tensordict_.set("temperature", initial_temperature * temperature)
@@ -365,17 +397,20 @@ def generate_complete_smiles(
365397

366398
# Enforce prompt
367399
if prompt:
368-
new_action = (~enforce_mask * tensordict_.get("action")) + (
400+
prompt_action = (~enforce_mask * tensordict_.get("action")) + (
369401
enforce_mask * enc_prompts[:, _]
370402
).long()
371-
tensordict_.set("action", new_action)
403+
tensordict_.set("action", prompt_action)
372404

373405
# Step forward in the environment
374406
tensordict_ = environment.step(tensordict_)
375407

376408
# Mask out finished environments
377409
if finished.any():
378410
tensordict_.masked_fill_(finished.squeeze(), 0)
411+
# Don't fill action_mask
412+
tensordict_["action_mask"][finished.squeeze()] = 1
413+
tensordict_["next"]["action_mask"][finished.squeeze()] = 1
379414

380415
# Extend list of tensordicts
381416
tensordicts.append(tensordict_.clone())

acegen/vocabulary/vocabulary.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def __init__(
4444
self.special_tokens = [end_token, start_token]
4545
self.special_tokens += list(set(special_tokens))
4646
self.additional_chars = set()
47-
self.chars = self.special_tokens
47+
self.chars = deepcopy(self.special_tokens)
4848
self.vocab_size = len(self.chars)
4949
self.vocab = dict(zip(self.chars, range(len(self.chars))))
5050
self.reversed_vocab = {v: k for k, v in self.vocab.items()}
@@ -104,26 +104,39 @@ def decode(self, encoded_string, ignore_indices=()):
104104
return string
105105

106106
def add_characters(self, chars):
107-
"""Adds characters to the vocabulary.
107+
"""Adds characters to the end of the vocabulary.
108108
109109
Args:
110110
chars (list[str]): A list of characters to add to the vocabulary.
111111
"""
112+
additional_chars = set()
112113
for char in chars:
113114
if char not in self.chars:
114-
self.additional_chars.add(char)
115-
char_list = list(self.additional_chars)
116-
char_list.sort()
117-
self.chars = self.special_tokens + char_list
118-
self.vocab_size = len(self.chars)
119-
self.vocab = dict(zip(self.chars, range(len(self.chars))))
115+
additional_chars.add(char)
116+
additional_chars = list(additional_chars)
117+
additional_chars.sort()
118+
n_prev = len(self.chars)
119+
n_new = n_prev + len(additional_chars)
120+
self.chars += additional_chars
121+
self.additional_chars.update(additional_chars)
122+
self.vocab.update(dict(zip(additional_chars, range(n_prev, n_new))))
120123
self.reversed_vocab = {v: k for k, v in self.vocab.items()}
124+
self.vocab_size = len(self.chars)
121125

122126
def __len__(self):
123127
return len(self.chars)
124128

125129
def __str__(self):
126-
return "Vocabulary containing {} tokens: {}".format(len(self), self.chars)
130+
return f"Vocabulary(len={len(self)}, tokens={self.vocab})"
131+
132+
def __repr__(self):
133+
return f"Vocabulary(len={len(self)}, tokens={self.vocab})"
134+
135+
def __getitem__(self, key):
136+
if isinstance(key, int):
137+
return self.reversed_vocab[key]
138+
if isinstance(key, str):
139+
return self.vocab[key]
127140

128141
@classmethod
129142
def create_from_strings(

0 commit comments

Comments
 (0)