We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 1373ffb commit b19faf6Copy full SHA for b19faf6
model/embedding.py
@@ -0,0 +1,21 @@
1
+import torch
2
+from torch import nn
3
+
4
5
class TimeEmbedding(nn.Module):
    """Sinusoidal timestep embedding (as used in Transformers / diffusion models).

    Maps a batch of scalar timesteps to `embedding_dim`-dimensional vectors:
    the first half of each vector holds sin(t / 10000^(i/(d/2))) terms, the
    second half the matching cos terms.
    """

    def __init__(self, embedding_dim: int):
        """
        Args:
            embedding_dim: Size of the output embedding. Assumed even —
                the output is a concat of two halves of size embedding_dim // 2.
        """
        super().__init__()
        self.embedding_dim = embedding_dim
        self.half_embedding_dim = embedding_dim // 2

    @torch.no_grad()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: 1-D tensor of timesteps, shape (batch,).

        Returns:
            Tensor of shape (batch, 2 * (embedding_dim // 2)) with sinusoidal
            embeddings; no gradients flow through this module.
        """
        # Frequency exponents i / (d/2) for i in [0, d/2). Built directly on
        # x's device to avoid a CPU->device copy.
        # NOTE: fixed typo — original read `self.halv_embedding_dim`, which
        # raised AttributeError at runtime.
        exponents = (
            torch.arange(
                0, self.half_embedding_dim, dtype=torch.float32, device=x.device
            )
            / self.half_embedding_dim
        )
        factors = 10_000 ** exponents

        # Broadcast each timestep across all frequencies: (batch, d/2).
        embedding_arguments = x[:, None].repeat(1, self.half_embedding_dim) / factors
        embeddings = torch.cat(
            [torch.sin(embedding_arguments), torch.cos(embedding_arguments)], dim=-1
        )
        return embeddings
0 commit comments