@@ -176,11 +176,11 @@ def freeze_backbone(self):
176176
177177
178178@depends_on_timm ()
179- def primus_s (input_channels , output_channels , patch_size , patch_embed_size = ( 8 , 8 , 8 ) , patch_drop_rate = 0.0 ):
179+ def primus_s (input_channels , output_channels , patch_size , patch_embed_size = 8 , patch_drop_rate = 0.0 ):
180180 model = Primus (
181181 input_channels = input_channels ,
182182 embed_dim = 396 ,
183- patch_embed_size = patch_embed_size ,
183+ patch_embed_size = ( patch_embed_size ,) * len ( patch_size ) ,
184184 num_classes = output_channels ,
185185 eva_depth = 12 ,
186186 eva_numheads = 6 ,
@@ -193,11 +193,11 @@ def primus_s(input_channels, output_channels, patch_size, patch_embed_size=(8, 8
193193
194194
195195@depends_on_timm ()
196- def primus_b (input_channels , output_channels , patch_size , patch_embed_size = ( 8 , 8 , 8 ) , patch_drop_rate = 0.0 ):
196+ def primus_b (input_channels , output_channels , patch_size , patch_embed_size = 8 , patch_drop_rate = 0.0 ):
197197 model = Primus (
198198 input_channels = input_channels ,
199199 embed_dim = 792 ,
200- patch_embed_size = patch_embed_size ,
200+ patch_embed_size = ( patch_embed_size ,) * len ( patch_size ) ,
201201 num_classes = output_channels ,
202202 eva_depth = 12 ,
203203 eva_numheads = 12 ,
@@ -211,11 +211,11 @@ def primus_b(input_channels, output_channels, patch_size, patch_embed_size=(8, 8
211211
212212
213213@depends_on_timm ()
214- def primus_m (input_channels , output_channels , patch_size , patch_embed_size = ( 8 , 8 , 8 ) , patch_drop_rate = 0.0 ):
214+ def primus_m (input_channels , output_channels , patch_size , patch_embed_size = 8 , patch_drop_rate = 0.0 ):
215215 model = Primus (
216216 input_channels = input_channels ,
217217 embed_dim = 864 ,
218- patch_embed_size = patch_embed_size ,
218+ patch_embed_size = ( patch_embed_size ,) * len ( patch_size ) ,
219219 num_classes = output_channels ,
220220 eva_depth = 16 ,
221221 eva_numheads = 12 ,
@@ -229,11 +229,11 @@ def primus_m(input_channels, output_channels, patch_size, patch_embed_size=(8, 8
229229
230230
231231@depends_on_timm ()
232- def primus_l (input_channels , output_channels , patch_size , patch_embed_size = ( 8 , 8 , 8 ) , patch_drop_rate = 0.0 ):
232+ def primus_l (input_channels , output_channels , patch_size , patch_embed_size = 8 , patch_drop_rate = 0.0 ):
233233 model = Primus (
234234 input_channels = input_channels ,
235235 embed_dim = 1056 ,
236- patch_embed_size = patch_embed_size ,
236+ patch_embed_size = ( patch_embed_size ,) * len ( patch_size ) ,
237237 num_classes = output_channels ,
238238 eva_depth = 24 ,
239239 eva_numheads = 16 ,
@@ -247,11 +247,11 @@ def primus_l(input_channels, output_channels, patch_size, patch_embed_size=(8, 8
247247
248248
249249@depends_on_timm ()
250- def primus_h (input_channels , output_channels , patch_size , patch_embed_size = ( 8 , 8 , 8 ) , patch_drop_rate = 0.0 ):
250+ def primus_h (input_channels , output_channels , patch_size , patch_embed_size = 8 , patch_drop_rate = 0.0 ):
251251 model = Primus (
252252 input_channels = input_channels ,
253253 embed_dim = 1248 ,
254- patch_embed_size = patch_embed_size ,
254+ patch_embed_size = ( patch_embed_size ,) * len ( patch_size ) ,
255255 num_classes = output_channels ,
256256 eva_depth = 32 ,
257257 eva_numheads = 16 ,
@@ -265,11 +265,11 @@ def primus_h(input_channels, output_channels, patch_size, patch_embed_size=(8, 8
265265
266266
267267@depends_on_timm ()
268- def primus_g (input_channels , output_channels , patch_size , patch_embed_size = ( 8 , 8 , 8 ) , patch_drop_rate = 0.0 ):
268+ def primus_g (input_channels , output_channels , patch_size , patch_embed_size = 8 , patch_drop_rate = 0.0 ):
269269 model = Primus (
270270 input_channels = input_channels ,
271271 embed_dim = 1584 ,
272- patch_embed_size = patch_embed_size ,
272+ patch_embed_size = ( patch_embed_size ,) * len ( patch_size ) ,
273273 num_classes = output_channels ,
274274 eva_depth = 32 ,
275275 eva_numheads = 24 ,
@@ -284,13 +284,13 @@ def primus_g(input_channels, output_channels, patch_size, patch_embed_size=(8, 8
284284
285285@depends_on_timm ()
286286def primus_m_clsreg (
287- input_channels , output_channels , patch_size , patch_embed_size = ( 8 , 8 , 8 ) , dropout_rate = 0.0 , late_fusion : bool = False
287+ input_channels , output_channels , patch_size , patch_embed_size = 8 , dropout_rate = 0.0 , late_fusion : bool = False
288288):
289289 return PrimusCLSREG (
290290 input_channels = input_channels ,
291291 output_channels = output_channels ,
292292 embed_dim = 864 ,
293- patch_embed_size = patch_embed_size ,
293+ patch_embed_size = ( patch_embed_size ,) * len ( patch_size ) ,
294294 eva_depth = 16 ,
295295 eva_numheads = 12 ,
296296 input_shape = patch_size ,
0 commit comments