77
88def _randomized_positions (from_v , to_v ):
99 pos = tf .random .uniform (from_v .shape , minval = 0 , maxval = 1 , dtype = tf .float32 )
10- pos = pos * tf .cast (to_v - from_v , dtype = tf .float32 )
11- pos = tf .cast (pos , dtype = tf .int32 )
12- return pos
10+ pos = pos * (to_v - from_v ).cast (tf .float32 )
11+ return pos .cast (tf .int32 )
1312
1413
1514def _rounded_mean_positions (from_v , to_v ):
16- pos = tf .cast (from_v + to_v , tf .float32 )
17- pos = pos / 2
18- pos = tf .round (pos )
19- return pos
20-
21-
22- def _broadcast (row_pos , col_pos , row_ones , col_ones ):
23- # broadcast (5,) to (20,) with column-axis
24- row_pos = tf .expand_dims (row_pos , 1 )
25- row_pos = tf .matmul (row_pos , col_ones , transpose_b = True )
26- row_pos = tf .reshape (row_pos , (- 1 ,))
27- row_pos = tf .stop_gradient (row_pos )
28-
29- # broadcast (4,) to (20,) with row-axis
30- col_pos = tf .expand_dims (col_pos , 1 )
31- col_pos = tf .matmul (row_ones , col_pos , transpose_b = True )
32- col_pos = tf .reshape (col_pos , (- 1 ,))
33- col_pos = tf .stop_gradient (col_pos )
34-
35- return row_pos , col_pos
15+ pos = (from_v + to_v ).cast (tf .float32 ) / 2.
16+ return pos .round ()
3617
3718
3819class PatchPositionEncoding (layers .Layer ):
@@ -57,7 +38,7 @@ def __init__(self,
5738 self .col_embedding = layers .Embedding (self .discretize_depth , self .embedding_dim , name = 'col_embedding' )
5839
5940 def _discretize (self , pos ):
60- return tf . round (pos * self .discretize_depth )
41+ return (pos * self .discretize_depth ). round ( )
6142
6243 def _discretize_interval (self , interval ):
6344 pos_from , pos_to = interval
@@ -83,12 +64,9 @@ def call(self, inputs, *args, **kwargs):
8364 row_pos = _rounded_mean_positions (row_pos_from , row_pos_to )
8465 col_pos = _rounded_mean_positions (col_pos_from , col_pos_to )
8566
86- col_pos = tf .cast (col_pos , dtype = tf .int32 )
87- row_pos = tf .cast (row_pos , dtype = tf .int32 )
88-
8967 # > Once row and column position encoding are retrieved from the embedding table,
9068 # > they are added onto the token embedding produced by the resnet embedding function.
91- return input_ids + self .row_embedding (row_pos ) + self .col_embedding (col_pos )
69+ return input_ids + self .row_embedding (row_pos . cast ( tf . int32 )) + self .col_embedding (col_pos . cast ( tf . int32 ) )
9270
9371 def get_config (self ):
9472 config = super (PatchPositionEncoding , self ).get_config ()
@@ -127,10 +105,10 @@ def call(self, inputs, *args, **kwargs):
127105
128106 residual = self .conv_proj (self .gn_proj (x ))
129107
130- x = tf . nn . gelu ( self .gn1 (x ))
108+ x = self .gn1 (x ). gelu ( )
131109 x = self .conv1 (x )
132110
133- x = tf . nn . gelu ( self .gn2 (x ))
111+ x = self .gn2 (x ). gelu ( )
134112 x = self .conv2 (x )
135113
136114 return x + residual
@@ -185,7 +163,7 @@ def call(self, inputs, *args, **kwargs):
185163 x = block (x )
186164 if self .conv_proj is not None :
187165 x = self .conv_proj (x )
188- x = tf .reshape (x , shape = (- 1 , inputs .shape [1 ], self .config .layer_width ))
166+ x = x .reshape ((- 1 , inputs .shape [1 ], self .config .layer_width ))
189167 return x
190168
191169 def get_config (self ):
@@ -222,8 +200,7 @@ def call(self, inputs, *args, **kwargs):
222200 embed = self .embedding (obs_pos )
223201
224202 ones = tf .ones ((embed .shape [0 ], 1 , self .config .layer_width ), dtype = tf .float32 )
225- obs_mask = tf .cast (obs_mask , dtype = tf .float32 )
226- obs_mask = tf .matmul (obs_mask , ones , transpose_a = True )
203+ obs_mask = obs_mask .cast (tf .float32 ).transpose ().matmul (ones )
227204 return embed * obs_mask
228205
229206 def get_config (self ):
0 commit comments