Commit 4338e6d

add temperature sampling to grus
1 parent 2d451f1 commit 4338e6d

File tree

5 files changed (+322 −53 lines)

acegen/models/gru.py

Lines changed: 29 additions & 6 deletions
@@ -34,6 +34,22 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         return out


+class Temperature(torch.nn.Module):
+    """Implements a temperature layer.
+
+    Simple Module that applies a temperature value to the logits for RL inference.
+
+    Args:
+        temperature (torch.Tensor): The temperature value (passed to ``forward``, not ``__init__``).
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, logits: torch.Tensor, temperature: torch.Tensor) -> torch.Tensor:
+        return logits / temperature
+
+
 def create_gru_components(
     vocabulary_size: int,
     embedding_size: int = 256,
@@ -47,7 +63,7 @@ def create_gru_components(
     recurrent_state: str = "recurrent_state",
     python_based: bool = False,
 ):
-    """Create all GRU model components: embedding, GRU, and head.
+    """Create all GRU model components: embedding, GRU, head and temperature.

     These modules handle the case of having a time dimension (RL training)
     and not having it (RL inference).
@@ -97,8 +113,13 @@ def create_gru_components(
         in_keys=["features"],
         out_keys=[out_key],
     )
+    temperature = TensorDictModule(
+        Temperature(),
+        in_keys=[out_key, "temperature"],
+        out_keys=[out_key],
+    )

-    return embedding_module, gru_module, head
+    return embedding_module, gru_module, head, temperature


 def create_gru_actor(
@@ -139,7 +160,7 @@ def create_gru_actor(
         training_actor, inference_actor = create_gru_actor(10)
         ```
     """
-    embedding, gru, head = create_gru_components(
+    embedding, gru, head, temperature = create_gru_components(
         vocabulary_size,
         embedding_size,
         hidden_size,
@@ -153,7 +174,7 @@ def create_gru_actor(
         python_based,
     )

-    actor_inference_model = TensorDictSequential(embedding, gru, head)
+    actor_inference_model = TensorDictSequential(embedding, gru, head, temperature)
     actor_training_model = TensorDictSequential(
         embedding,
         gru.set_recurrent_mode(True),
@@ -217,7 +238,7 @@ def create_gru_critic(
     output_size = vocabulary_size if critic_value_per_action else 1
     out_key = "action_value" if critic_value_per_action else "state_value"

-    embedding, gru, head = create_gru_components(
+    embedding, gru, head, _ = create_gru_components(
         vocabulary_size,
         embedding_size,
         hidden_size,
@@ -281,7 +302,7 @@ def create_gru_actor_critic(
         inference_critic) = create_gru_actor_critic(10)
         ```
     """
-    embedding, gru, actor_head = create_gru_components(
+    embedding, gru, actor_head, temperature = create_gru_components(
         vocabulary_size,
         embedding_size,
         hidden_size,
@@ -295,6 +316,8 @@ def create_gru_actor_critic(
         python_based,
     )

+    actor_head = TensorDictSequential(actor_head, temperature)
+
     actor_head = ProbabilisticActor(
         module=actor_head,
         in_keys=["logits"],
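
For context, the new `Temperature` module simply divides the logits by the supplied temperature before they reach the `ProbabilisticActor`: values below 1 sharpen the sampling distribution and values above 1 flatten it. A minimal illustrative sketch of that effect (not part of the commit):

```python
# Illustrative sketch: effect of dividing logits by a temperature before softmax.
import torch

logits = torch.tensor([2.0, 1.0, 0.1])

for temperature in (0.5, 1.0, 2.0):
    probs = torch.softmax(logits / temperature, dim=-1)
    print(f"T={temperature}: {probs.tolist()}")

# T < 1 concentrates probability on the highest logit (greedier sampling);
# T > 1 spreads probability more evenly (more exploration).
```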

acegen/rl_env/token_env.py

Lines changed: 8 additions & 18 deletions
@@ -2,12 +2,7 @@

 import torch
 from tensordict.tensordict import TensorDict, TensorDictBase
-from torchrl.data import (
-    Composite,
-    Categorical,
-    OneHotDiscreteTensorSpec,
-    Unbounded,
-)
+from torchrl.data import Categorical, Composite, OneHotDiscreteTensorSpec, Unbounded
 from torchrl.data.utils import DEVICE_TYPING
 from torchrl.envs import EnvBase

@@ -109,6 +104,9 @@ def __init__(
                 "terminated": torch.zeros(
                     self.num_envs, 1, device=self.device, dtype=torch.bool
                 ),
+                "temperature": torch.ones(
+                    self.num_envs, 1, device=self.device, dtype=torch.float32
+                ),
                 "sequence": self.sequence.clone(),
                 "sequence_mask": self.sequence_mask.clone(),
             },
@@ -181,9 +179,7 @@ def _set_seed(self, seed: Optional[int] = -1) -> None:

     def _set_specs(self) -> None:
         obs_spec = (
-            OneHotDiscreteTensorSpec
-            if self.one_hot_obs_encoding
-            else Categorical
+            OneHotDiscreteTensorSpec if self.one_hot_obs_encoding else Categorical
         )
         self.observation_spec = Composite(
             {
@@ -220,9 +216,7 @@ def _set_specs(self) -> None:
             }
         ).expand(self.num_envs)
         action_spec = (
-            OneHotDiscreteTensorSpec
-            if self.one_hot_action_encoding
-            else Categorical
+            OneHotDiscreteTensorSpec if self.one_hot_action_encoding else Categorical
         )
         self.action_spec = Composite(
             {
@@ -246,12 +240,8 @@ def _set_specs(self) -> None:
         self.done_spec = (
             Composite(
                 {
-                    "done": Categorical(
-                        n=2, dtype=torch.bool, device=self.device
-                    ),
-                    "truncated": Categorical(
-                        n=2, dtype=torch.bool, device=self.device
-                    ),
+                    "done": Categorical(n=2, dtype=torch.bool, device=self.device),
+                    "truncated": Categorical(n=2, dtype=torch.bool, device=self.device),
                     "terminated": Categorical(
                         n=2, dtype=torch.bool, device=self.device
                     ),
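
The environment now also seeds every reset observation with a per-env `temperature` entry initialised to ones, so downstream policies that read the key see plain, unscaled logits unless the caller overrides it. A small hypothetical sketch of that default (shapes assumed, not taken from the repo):

```python
# Hypothetical sketch of the default "temperature" entry added at reset.
import torch
from tensordict import TensorDict

num_envs = 4
reset_td = TensorDict(
    {
        "observation": torch.zeros(num_envs, dtype=torch.int64),
        "temperature": torch.ones(num_envs, 1, dtype=torch.float32),
    },
    batch_size=[num_envs],
)
print(reset_td["temperature"].squeeze(-1))  # tensor([1., 1., 1., 1.]) -> logits left unchanged
```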

acegen/rl_env/utils.py

Lines changed: 6 additions & 0 deletions
@@ -37,6 +37,7 @@ def generate_complete_smiles(
     prompt: Union[str, list] = None,
     end_of_episode_key: str = "done",
     exploration_type: ExplorationType = ExplorationType.RANDOM,
+    temperature: float | torch.Tensor = 1.0,
     promptsmiles: str = None,
     promptsmiles_optimize: bool = True,
     promptsmiles_shuffle: bool = True,
@@ -68,6 +69,7 @@ def generate_complete_smiles(
             indicates the end of an episode. Defaults to "done".
         exploration_type (ExplorationType, optional): Exploration type to use. Defaults to
             :class:`~torchrl.envs.utils.ExplorationType.RANDOM`.
+        temperature (float, optional): Temperature to use when sampling actions from the policy.
         promptsmiles (str, optional): SMILES string of scaffold with attachment points or fragments separated
             by "." with one attachment point each.
         promptsmiles_optimize (bool, optional): Optimize the prompt for the model being used.
@@ -335,6 +337,7 @@ def generate_complete_smiles(

     initial_observation = initial_observation.to(policy_device)
     tensordict_ = initial_observation
+    initial_temperature = tensordict_["temperature"].clone()
     finished = (
         torch.zeros(batch_size, dtype=torch.bool).unsqueeze(-1).to(policy_device)
     )
@@ -352,6 +355,9 @@ def generate_complete_smiles(
             if prompt:
                 enforce_mask = enc_prompts[:, _] != vocabulary.end_token_index

+            # Define temperature tensor
+            tensordict_.set("temperature", initial_temperature * temperature)
+
             # Execute policy
             tensordict_ = tensordict_.to(policy_device)
             policy_sample(tensordict_)
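
At every generation step, the ones-tensor cloned at reset is multiplied by the user-supplied `temperature`, so the argument can be a plain float (same temperature for every environment) or a per-env tensor. A minimal sketch of that broadcast, with hypothetical shapes:

```python
# Minimal sketch (hypothetical shapes): broadcasting the user-supplied temperature
# over the per-env ones-tensor stored at reset.
import torch

num_envs = 3
initial_temperature = torch.ones(num_envs, 1)  # as cloned from the reset tensordict

scalar = initial_temperature * 0.7                                    # one value for all envs
per_env = initial_temperature * torch.tensor([[0.5], [1.0], [2.0]])  # one value per env

print(scalar.squeeze(-1))   # tensor([0.7000, 0.7000, 0.7000])
print(per_env.squeeze(-1))  # tensor([0.5000, 1.0000, 2.0000])
```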
