Skip to content

Commit 8b5b068

Browse files
🔨 Modified activation in TimeEmbed function
1 parent 3c796e4 commit 8b5b068

1 file changed

Lines changed: 8 additions & 2 deletions

File tree

ice_station_zebra/models/common/timeembed.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,20 @@
11
import torch.nn as nn
22
from torch import Tensor
33

4+
from .activations import get_activation
5+
46

57
class TimeEmbed(nn.Module):
6-
def __init__(self, dim: int = 256) -> None:
8+
def __init__(self, dim: int = 256,
9+
activation: str = "SiLU",) -> None:
710
super().__init__()
811

12+
def act():
13+
return get_activation(activation)
14+
915
self.model = nn.Sequential(
1016
nn.Linear(dim, dim * 4),
11-
nn.SiLU(),
17+
act(),
1218
nn.Linear(dim * 4, dim),
1319
)
1420

0 commit comments

Comments
 (0)