Skip to content

Commit 2d85600

Browse files
authored
Merge pull request #102 from UT-Austin-RPL/zs
minor changes to enable new methods
2 parents 9b45b8c + cb55f00 commit 2d85600

3 files changed

Lines changed: 34 additions & 3 deletions

File tree

amago/nets/ff.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,11 @@ class Normalization(nn.Module):
1717
1818
Args:
1919
method: Normalization method to use. Options are: "layer", "batch",
20-
"rmsnorm", "unitball", "unitball-detach", "none". "unitball" is
21-
(x / ||x||), "unitball-detach" is (x / ||x||.detach()). "none" is a
22-
no-op and the rest are standard LayerNorm, BatchNorm, RMSNorm.
20+
"rmsnorm", "unitball", "unitball-detach", "hypersphere", "none".
21+
"unitball" is (x / ||x||), "unitball-detach" is (x / ||x||.detach()),
22+
"hypersphere" is (x / ||x|| * sqrt(d)), projecting onto the sphere S^{d-1}
23+
of radius sqrt(d) (the output has norm sqrt(d)). "none" is a no-op and the rest are standard
24+
LayerNorm, BatchNorm, RMSNorm.
2325
d_model: Expected dimension of the input to normalize (scalar). Operates
2426
on the last dimensions of the input sequence.
2527
"""
@@ -43,6 +45,8 @@ def __init__(self, method: Optional[str], d_model: int):
4345
torch.linalg.vector_norm(x, ord=2, dim=-1, keepdim=True) + 1e-5
4446
).detach()
4547
)
48+
elif method == "hypersphere":
49+
self.norm = lambda x: F.normalize(x, dim=-1) * (x.shape[-1] ** 0.5)
4650
elif method == "rmsnorm":
4751
self.norm = _RMSNorm(size=d_model)
4852
elif method == "simnorm":

amago/nets/traj_encoders.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,26 @@ def forward(
164164
pass
165165

166166

167+
@gin.configurable
@register_traj_encoder("identity")
class IdentityTrajEncoder(TrajEncoder):
    """Trajectory encoder that performs no sequence modeling.

    The timestep embeddings are handed back exactly as they were received, so
    the output embedding size equals ``self.tstep_dim``.
    """

    @property
    def emb_dim(self) -> int:
        """Size of the output embedding; identical to ``self.tstep_dim``."""
        return self.tstep_dim

    def forward(
        self,
        seq: torch.Tensor,
        time_idxs: torch.Tensor,
        hidden_state: Optional[Any] = None,
        log_dict: Optional[dict] = None,
    ) -> Tuple[torch.Tensor, Optional[Any]]:
        """Return ``seq`` untouched together with the incoming hidden state.

        Args:
            seq: Timestep embeddings; returned as-is.
            time_idxs: Not used by this encoder.
            hidden_state: Passed straight through (``None`` stays ``None``).
            log_dict: Not used by this encoder.

        Returns:
            The tuple ``(seq, hidden_state)``.
        """
        # The hidden state is forwarded verbatim: there is nothing to update.
        return seq, hidden_state
185+
186+
167187
@gin.configurable
168188
@register_traj_encoder("ff")
169189
class FFTrajEncoder(TrajEncoder):

amago/nets/tstep_encoders.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,9 @@ class FFTstepEncoder(TstepEncoder):
175175
space. If None, every key in the observation is used. Multi-modal
176176
observations are handled by flattening and concatenating values in a
177177
consistent order (alphabetical by key). Defaults to None.
178+
auto_scale: If True, d_hidden and d_output are clamped to be at least
179+
as large as the network input dimension (flattened observation features plus the rl2 features), preventing bottlenecks
180+
when observation keys contribute many features. Defaults to False.
178181
"""
179182

180183
def __init__(
@@ -190,6 +193,7 @@ def __init__(
190193
hide_rl2s: bool = False,
191194
normalize_inputs: bool = True,
192195
specify_obs_keys: Optional[list[str]] = None,
196+
auto_scale: bool = False,
193197
):
194198
super().__init__(obs_space=obs_space, rl2_space=rl2_space, hide_rl2s=hide_rl2s)
195199
if specify_obs_keys is None:
@@ -200,6 +204,9 @@ def __init__(
200204
math.prod(self.obs_space[key].shape) for key in self.obs_keys
201205
)
202206
in_dim = flat_obs_shape + self.rl2_space.shape[-1]
207+
if auto_scale:
208+
d_hidden = max(d_hidden, in_dim)
209+
d_output = max(d_output, in_dim)
203210
self.in_norm = InputNorm(in_dim, skip=not normalize_inputs)
204211
self.base = ff.MLP(
205212
d_inp=in_dim,

0 commit comments

Comments
 (0)