Skip to content

Commit d9af97d

Browse files
committed
fix: 🐛 fix clipiqa model error
1 parent 8129564 commit d9af97d

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

pyiqa/archs/clipiqa_arch.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,11 +139,13 @@ def __init__(self,
139139
for p in self.clip_model[0].parameters():
140140
p.requires_grad = False
141141

142-
if pretrained:
142+
if pretrained and 'clipiqa+' in model_type:
143143
if model_type == 'clipiqa+' and backbone == 'RN50':
144144
self.prompt_learner.ctx.data = torch.load(load_file_from_url(default_model_urls['clipiqa+']))
145-
else:
145+
elif model_type in default_model_urls.keys():
146146
load_pretrained_network(self, default_model_urls[model_type], True, 'params')
147+
else:
148+
raise(f'No pretrained model for {model_type}')
147149

148150
def forward(self, x):
149151
# preprocess image

tests/IAA_benchmark_results.csv

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
Metric name,ava(PLCC/SRCC/KRCC)
22
nima-vgg16-ava,0.6624/0.657/0.4719
33
nima,0.7172/0.7126/0.5213
4+
clipiqa,0.3576/0.3383/0.2301

0 commit comments

Comments
 (0)