Skip to content

Commit 9265d54

Browse files
committed
LeViT safetensors load is broken by conversion code that wasn't deactivated
1 parent 21e75a9 commit 9265d54

File tree

1 file changed

+12
-11
lines changed

1 file changed

+12
-11
lines changed

timm/models/levit.py

+12-11
Original file line numberDiff line numberDiff line change
@@ -763,17 +763,18 @@ def checkpoint_filter_fn(state_dict, model):
763763
# filter out attn biases, should not have been persistent
764764
state_dict = {k: v for k, v in state_dict.items() if 'attention_bias_idxs' not in k}
765765

766-
D = model.state_dict()
767-
out_dict = {}
768-
for ka, kb, va, vb in zip(D.keys(), state_dict.keys(), D.values(), state_dict.values()):
769-
if va.ndim == 4 and vb.ndim == 2:
770-
vb = vb[:, :, None, None]
771-
if va.shape != vb.shape:
772-
# head or first-conv shapes may change for fine-tune
773-
assert 'head' in ka or 'stem.conv1.linear' in ka
774-
out_dict[ka] = vb
775-
776-
return out_dict
766+
# NOTE: old weight conversion code, disabled
767+
# D = model.state_dict()
768+
# out_dict = {}
769+
# for ka, kb, va, vb in zip(D.keys(), state_dict.keys(), D.values(), state_dict.values()):
770+
# if va.ndim == 4 and vb.ndim == 2:
771+
# vb = vb[:, :, None, None]
772+
# if va.shape != vb.shape:
773+
# # head or first-conv shapes may change for fine-tune
774+
# assert 'head' in ka or 'stem.conv1.linear' in ka
775+
# out_dict[ka] = vb
776+
777+
return state_dict
777778

778779

779780
model_cfgs = dict(

0 commit comments

Comments
 (0)