Skip to content

Commit 1f2fcc9

Browse files
committed
Remove files from src, fix utils to run everything on same device
1 parent 73171dd commit 1f2fcc9

File tree

4 files changed

+7
-248
lines changed

4 files changed

+7
-248
lines changed

src/benchmark_encoder.py

Lines changed: 0 additions & 80 deletions
This file was deleted.

src/export.py

Lines changed: 0 additions & 73 deletions
This file was deleted.

src/test_encoder.py

Lines changed: 0 additions & 86 deletions
This file was deleted.

src/utils.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype=torch.float32):
1111
assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
1212
omega = torch.arange(dim // 4) / (dim // 4 - 1)
1313
omega = 1.0 / (temperature**omega)
14-
omega = omega.to(y.device)
1514

1615
y = y.flatten()[:, None] * omega[None, :]
1716
x = x.flatten()[:, None] * omega[None, :]
@@ -25,27 +24,26 @@ def posemb_sincos_2d_with_gsd(
2524
y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
2625
assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
2726

28-
omega = torch.arange(dim // 4, device=gsd.device) / (dim // 4 - 1)
27+
gsd = gsd.to(x.device)
28+
omega = torch.arange(dim // 4) / (dim // 4 - 1)
2929
omega = 1.0 / (temperature ** (2 * omega / dim)) * (gsd / 1.0) # Adjusted for g
30-
omega = omega.to(y.device)
3130

3231
y = y.flatten()[:, None] * omega[None, :]
3332
x = x.flatten()[:, None] * omega[None, :]
3433
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
3534
return pe.type(dtype)
3635

3736

38-
def posemb_sincos_1d(pos, dim, temperature: int = 10000, dtype=torch.float32):
37+
def posemb_sincos_1d(waves, dim, temperature: int = 10000, dtype=torch.float32):
3938
assert (
4039
dim % 2 == 0
4140
), "Feature dimension must be a multiple of 2 for sincos embedding"
42-
pos = torch.arange(pos) if isinstance(pos, int) else pos
41+
waves = torch.arange(waves) if isinstance(waves, int) else waves
4342

44-
omega = torch.arange(dim // 2) / (dim // 2 - 1)
43+
omega = torch.arange(dim // 2, device=waves.device) / (dim // 2 - 1)
4544
omega = 1.0 / (temperature**omega)
46-
omega = omega.to(pos.device)
4745

48-
scaled_pos = pos[:, None] * omega[None, :]
49-
pe = torch.cat((scaled_pos.sin(), scaled_pos.cos()), dim=1)
46+
scaled_waves = waves[:, None] * omega[None, :]
47+
pe = torch.cat((scaled_waves.sin(), scaled_waves.cos()), dim=1)
5048

5149
return pe.type(dtype)

0 commit comments

Comments
 (0)