Skip to content

Commit afc5321

Browse files
committed
Update
[ghstack-poisoned]
1 parent 213fe63 commit afc5321

1 file changed

Lines changed: 36 additions & 10 deletions

File tree

torchrl/modules/models/model_based.py

Lines changed: 36 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -18,12 +18,25 @@
1818

1919
# from torchrl.modules.tensordict_module.rnn import GRUCell
2020
from torch.nn import GRUCell
21+
from torchrl._utils import _maybe_record_function_decorator
2122

2223
from torchrl.modules.models.models import MLP
2324

2425
UNSQUEEZE_RNN_INPUT = version.parse(torch.__version__) < version.parse("1.11")
2526

2627

28+
class _Contiguous(nn.Module):
29+
"""Helper module that makes a tensor contiguous.
30+
31+
This is useful inside nn.Sequential for torch.compile inductor compatibility.
32+
Inductor sometimes needs explicit contiguous() calls after reshape operations
33+
for efficient memory layout.
34+
"""
35+
36+
def forward(self, x):
37+
return x.contiguous()
38+
39+
2740
class DreamerActor(nn.Module):
2841
"""Dreamer actor network.
2942
@@ -129,16 +142,18 @@ def __init__(
129142
k = k * 2
130143
self.encoder = nn.Sequential(*layers)
131144

145+
@_maybe_record_function_decorator("ObsEncoder.forward")
def forward(self, observation):
    """Encode observations into flat latent vectors.

    Args:
        observation: tensor whose last three dimensions are unpacked as
            ``(C, H, W)``; any leading dimensions are treated as batch
            dimensions.

    Returns:
        A contiguous tensor of shape ``(*batch_sizes, -1)``: the encoder
        output with the original batch dimensions restored and every
        remaining dimension flattened into one.
    """
    # The unpacking doubles as a rank check: fewer than 3 dims raises.
    *batch_sizes, _, _, _ = observation.shape
    if batch_sizes:
        # Collapse all batch dimensions into a single leading one so the
        # encoder sees one batch axis.
        observation = observation.flatten(0, len(batch_sizes) - 1)
    # Explicit contiguous() kept for inductor compatibility.
    obs_encoded = self.encoder(observation.contiguous())
    # A single reshape both restores the batch dims and flattens the
    # encoded features; no intermediate unflatten is needed.
    return obs_encoded.reshape(*batch_sizes, -1).contiguous()
142157

143158

144159
class ObsDecoder(nn.Module):
@@ -238,14 +253,25 @@ def __init__(
238253
self.decoder = nn.Sequential(*layers)
239254
self._depth = channels
240255

256+
@_maybe_record_function_decorator("ObsDecoder.forward")
def forward(self, state, rnn_hidden):
    """Decode a state / hidden-state pair into an observation.

    Args:
        state: tensor concatenated with ``rnn_hidden`` along the last
            dimension.
        rnn_hidden: tensor with the same leading (batch) dimensions as
            ``state``.

    Returns:
        A contiguous tensor of shape ``(*batch_sizes, C, H, W)``, where
        ``batch_sizes`` are the leading dimensions of the projected latent
        and ``(C, H, W)`` are the trailing dimensions of the decoder output.
    """
    # Concatenate and project to latent space.
    latent = self.state_to_latent(torch.cat([state, rnn_hidden], dim=-1))
    *batch_sizes, D = latent.shape
    if batch_sizes:
        # Merge every batch dimension into a single leading one.
        latent = latent.flatten(0, len(batch_sizes) - 1)
    else:
        # No batch dims: add a singleton one so the decoder always gets a
        # 4-D input and the 4-way shape unpacking below stays valid.
        latent = latent.unsqueeze(0)
    # (N, D) -> (N, D, 1, 1); contiguous() kept for inductor compatibility.
    latent = latent.unsqueeze(-1).unsqueeze(-1).contiguous()
    obs_decoded = self.decoder(latent)
    _, C, H, W = obs_decoded.shape
    if batch_sizes:
        # Restore the original batch dimensions.
        obs_decoded = obs_decoded.unflatten(0, batch_sizes)
    else:
        # Drop the singleton batch dim added above.
        obs_decoded = obs_decoded.squeeze(0)
    return obs_decoded.contiguous()
249275

250276

251277
class RSSMRollout(TensorDictModuleBase):

0 commit comments

Comments
 (0)