@@ -11,7 +11,6 @@ def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype=torch.float32):
1111 assert (dim % 4 ) == 0 , "feature dimension must be multiple of 4 for sincos emb"
1212 omega = torch .arange (dim // 4 ) / (dim // 4 - 1 )
1313 omega = 1.0 / (temperature ** omega )
14- omega = omega .to (y .device )
1514
1615 y = y .flatten ()[:, None ] * omega [None , :]
1716 x = x .flatten ()[:, None ] * omega [None , :]
@@ -25,27 +24,26 @@ def posemb_sincos_2d_with_gsd(
2524 y , x = torch .meshgrid (torch .arange (h ), torch .arange (w ), indexing = "ij" )
2625 assert (dim % 4 ) == 0 , "feature dimension must be multiple of 4 for sincos emb"
2726
28- omega = torch .arange (dim // 4 , device = gsd .device ) / (dim // 4 - 1 )
27+ gsd = gsd .to (x .device )
28+ omega = torch .arange (dim // 4 ) / (dim // 4 - 1 )
2929 omega = 1.0 / (temperature ** (2 * omega / dim )) * (gsd / 1.0 ) # Adjusted for g
30- omega = omega .to (y .device )
3130
3231 y = y .flatten ()[:, None ] * omega [None , :]
3332 x = x .flatten ()[:, None ] * omega [None , :]
3433 pe = torch .cat ((x .sin (), x .cos (), y .sin (), y .cos ()), dim = 1 )
3534 return pe .type (dtype )
3635
3736
38- def posemb_sincos_1d (pos , dim , temperature : int = 10000 , dtype = torch .float32 ):
37+ def posemb_sincos_1d (waves , dim , temperature : int = 10000 , dtype = torch .float32 ):
3938 assert (
4039 dim % 2 == 0
4140 ), "Feature dimension must be a multiple of 2 for sincos embedding"
42- pos = torch .arange (pos ) if isinstance (pos , int ) else pos
41+ waves = torch .arange (waves ) if isinstance (waves , int ) else waves
4342
44- omega = torch .arange (dim // 2 ) / (dim // 2 - 1 )
43+ omega = torch .arange (dim // 2 , device = waves . device ) / (dim // 2 - 1 )
4544 omega = 1.0 / (temperature ** omega )
46- omega = omega .to (pos .device )
4745
48- scaled_pos = pos [:, None ] * omega [None , :]
49- pe = torch .cat ((scaled_pos .sin (), scaled_pos .cos ()), dim = 1 )
46+ scaled_waves = waves [:, None ] * omega [None , :]
47+ pe = torch .cat ((scaled_waves .sin (), scaled_waves .cos ()), dim = 1 )
5048
5149 return pe .type (dtype )
0 commit comments