@@ -125,24 +125,25 @@ def __str__(self):
125
125
126
126
127
127
TEST_MODELS_DESC = [
128
- ModelDesc ("convnext_small" , models .convnext_small , [1 , 3 , 64 , 64 ]),
129
- ModelDesc ("densenet121" , models .densenet121 , [1 , 3 , 64 , 64 ]),
130
- ModelDesc ("efficientnet_b0" , models .efficientnet_b0 , [1 , 3 , 64 , 64 ]),
131
- ModelDesc ("inception_v3" , partial (models .inception_v3 , init_weights = False ), [1 , 3 , 300 , 300 ]),
132
- ModelDesc ("mobilenet_v2" , models .mobilenet_v2 , [1 , 3 , 64 , 64 ]),
133
- ModelDesc ("mobilenet_v3_small" , models .mobilenet_v3_small , [1 , 3 , 64 , 64 ]),
134
- ModelDesc ("resnet18" , models .resnet18 , [1 , 3 , 64 , 64 ]),
135
- ModelDesc ("resnext50_32x4d" , models .resnext50_32x4d , [1 , 3 , 64 , 64 ]),
136
- ModelDesc ("shufflenet_v2_x0_5" , models .shufflenet_v2_x0_5 , [1 , 3 , 224 , 224 ]),
137
- ModelDesc ("squeezenet1_0" , models .squeezenet1_0 , [1 , 3 , 64 , 64 ]),
138
- ModelDesc ("swin_v2_b" , models .swin_v2_b , [1 , 3 , 64 , 64 ]),
139
- ModelDesc ("vgg16" , models .vgg16 , [1 , 3 , 32 , 32 ]),
128
+ ModelDesc ("convnext_small" , partial (models .convnext_small , weights = None ), [1 , 3 , 64 , 64 ]),
129
+ ModelDesc ("densenet121" , partial (models .densenet121 , weights = None ), [1 , 3 , 64 , 64 ]),
130
+ ModelDesc ("efficientnet_b0" , partial (models .efficientnet_b0 , weights = None ), [1 , 3 , 64 , 64 ]),
131
+ ModelDesc ("inception_v3" , partial (models .inception_v3 , init_weights = False , weights = None ), [1 , 3 , 300 , 300 ]),
132
+ ModelDesc ("mobilenet_v2" , partial (models .mobilenet_v2 , weights = None ), [1 , 3 , 64 , 64 ]),
133
+ ModelDesc ("mobilenet_v3_small" , partial (models .mobilenet_v3_small , weights = None ), [1 , 3 , 64 , 64 ]),
134
+ ModelDesc ("resnet18" , partial (models .resnet18 , weights = None ), [1 , 3 , 64 , 64 ]),
135
+ ModelDesc ("resnext50_32x4d" , partial (models .resnext50_32x4d , weights = None ), [1 , 3 , 64 , 64 ]),
136
+ ModelDesc ("shufflenet_v2_x0_5" , partial (models .shufflenet_v2_x0_5 , weights = None ), [1 , 3 , 224 , 224 ]),
137
+ ModelDesc ("squeezenet1_0" , partial (models .squeezenet1_0 , weights = None ), [1 , 3 , 64 , 64 ]),
138
+ ModelDesc ("swin_v2_b" , partial (models .swin_v2_b , weights = None ), [1 , 3 , 64 , 64 ]),
139
+ ModelDesc ("vgg16" , partial (models .vgg16 , weights = None ), [1 , 3 , 32 , 32 ]),
140
+ ModelDesc ("gru" , helpers .ModelGRU , [1 , 3 , 3 ]),
140
141
]
141
142
142
143
143
144
@pytest .mark .parametrize ("desc" , TEST_MODELS_DESC , ids = str )
144
145
def test_model_graph (desc : ModelDesc , regen_ref_data : bool ):
145
- model : torch .nn .Module = desc .model_builder (weights = None )
146
+ model : torch .nn .Module = desc .model_builder ()
146
147
model = model .eval ()
147
148
model = wrap_model (model )
148
149
nncf_graph = build_nncf_graph (model , torch .randn (desc .inputs_info ))
0 commit comments