Skip to content

Commit 58cfe72

Browse files
committed
Changes to support DINOv2 in HF
1 parent 662e877 commit 58cfe72

File tree

3 files changed

+53
-3
lines changed

3 files changed

+53
-3
lines changed

Diff for: hf_hub.py

+13-3
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from radio.adaptor_base import RadioOutput
2222
from radio.adaptor_registry import adaptor_registry
2323
from radio.adaptor_mlp import get_mlp_info_from_state
24-
from radio.hf_model import RADIOConfig, RADIOModel
24+
from radio.hf_model import RADIOConfig, RADIOModel, rename_all_gamma_to_weight_with_proxy
2525
from test_hf import deterministic_grid_init
2626

2727

@@ -164,7 +164,7 @@ def main():
164164

165165
feat_norm_sd = get_prefix_state_dict(state_dict, '_feature_normalizer.')
166166
feature_normalizer_config = None
167-
if feat_norm_sd is not None:
167+
if feat_norm_sd:
168168
feature_normalizer_config = {
169169
"embed_dim": feat_norm_sd['mean'].shape[0]
170170
}
@@ -219,6 +219,10 @@ def main():
219219
if inter_feat_norm_sd:
220220
radio_model.radio_model.inter_feature_normalizer.load_state_dict(inter_feat_norm_sd)
221221

222+
# Rename "gamma" parameters to "weight"
223+
rename_all_gamma_to_weight_with_proxy(radio_model.radio_model)
224+
radio_config.rename_gamma_to_weight = True
225+
222226
radio_model.eval().cuda()
223227

224228
# Sample inference with deterministic values.
@@ -240,7 +244,7 @@ def main():
240244
hf_summary, hf_features = v.summary, v.features
241245

242246
print(
243-
f"[{k}] Sample inference on tensor shape {x.shape} returned summary ",
247+
f"[{k}] HF inference on tensor shape {x.shape} returned summary ",
244248
f"with shape={hf_summary.shape} and std={hf_summary.std().item():.3}, ",
245249
f"features with shape={hf_features.shape} and std={hf_features.std().item():.3}",
246250
)
@@ -288,6 +292,12 @@ def main():
288292
torchhub_output[k].features,
289293
)
290294

295+
print(
296+
f"[{k}] TorchHub inference on tensor shape {x.shape} returned summary ",
297+
f"with shape={torchhub_summary.shape} and std={torchhub_summary.std().item():.3}, ",
298+
f"features with shape={torchhub_features.shape} and std={torchhub_features.std().item():.3}",
299+
)
300+
291301
# Make sure the shapes are the same.
292302
assert (
293303
hf_summary.shape == torchhub_summary.shape

Diff for: radio/common.py

+7
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,13 @@ class RadioResource:
8080
max_resolution=2048,
8181
preferred_resolution=Resolution(512, 512),
8282
),
83+
# RADIO-DINOv2
84+
"radio_dinov2-g": RadioResource(
85+
None, # TODO: add URL for DINOv2 student.
86+
patch_size=14,
87+
max_resolution=2044,
88+
preferred_resolution=Resolution(518, 518),
89+
),
8390
}
8491

8592
DEFAULT_VERSION = "radio_v2.5-h"

Diff for: radio/hf_model.py

+33
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,33 @@
4343
from .extra_timm_models import *
4444

4545

46+
47+
def rename_all_gamma_to_weight_with_proxy(module):
    """
    Rename every parameter whose name contains 'gamma' so that it contains
    'weight' instead, and install a property under the old name that
    forwards to the renamed parameter.

    The walk covers ``module`` itself and all of its submodules. Only the
    parameters registered directly on each submodule are inspected, so each
    parameter is handled by the submodule that owns it.

    The existing parameter object is re-registered under the new name rather
    than wrapped in a fresh ``nn.Parameter``. This keeps ``requires_grad``
    and object identity intact, so frozen parameters stay frozen and
    anything already holding a reference keeps pointing at the live tensor.

    Args:
        module: Root module to process. It is modified in place.

    Raises:
        ValueError: If a submodule already has an attribute with the new
            name, because renaming would silently overwrite it.

    Notes:
        The proxy property is attached to the submodule's *class*, so it is
        visible on every instance of that class, not only on the instances
        reachable from ``module``.

        NOTE(review): ``nn.Module.__setattr__`` handles parameter assignment
        before a class-level property setter is consulted, so writes through
        the old name may not reach the setter installed here. Confirm before
        relying on assignment via the old name.
    """

    def make_proxy(target_name):
        # A factory gives each property its own binding of target_name;
        # lambdas created directly in the loop would all see the last value.
        return property(
            lambda self: getattr(self, target_name),
            lambda self, value: setattr(self, target_name, value),
        )

    for _, submodule in module.named_modules():
        # Snapshot the direct parameters: the loop body mutates the registry.
        for param_name, param in list(submodule.named_parameters(recurse=False)):
            if 'gamma' not in param_name:
                continue

            new_name = param_name.replace('gamma', 'weight')

            # Refuse to clobber an attribute that already uses the new name.
            if hasattr(submodule, new_name):
                raise ValueError(
                    f"Cannot rename parameter '{param_name}' to '{new_name}' on "
                    f"{type(submodule).__name__}: attribute '{new_name}' already exists"
                )

            # Move the same Parameter object to its new name.
            delattr(submodule, param_name)
            submodule.register_parameter(new_name, param)

            # Keep the old name usable as a proxy to the renamed parameter.
            setattr(type(submodule), param_name, make_proxy(new_name))
71+
72+
4673
class RADIOConfig(PretrainedConfig):
4774
"""Pretrained Hugging Face configuration for RADIO models."""
4875

@@ -58,6 +85,7 @@ def __init__(
5885
vitdet_window_size: Optional[int] = None,
5986
feature_normalizer_config: Optional[dict] = None,
6087
inter_feature_normalizer_config: Optional[dict] = None,
88+
rename_gamma_to_weight: bool = False,
6189
**kwargs,
6290
):
6391
self.args = args
@@ -79,9 +107,11 @@ def __init__(
79107
self.vitdet_window_size = vitdet_window_size
80108
self.feature_normalizer_config = feature_normalizer_config
81109
self.inter_feature_normalizer_config = inter_feature_normalizer_config
110+
self.rename_gamma_to_weight = rename_gamma_to_weight
82111
super().__init__(**kwargs)
83112

84113

114+
85115
class RADIOModel(PreTrainedModel):
86116
"""Pretrained Hugging Face model for RADIO.
87117
@@ -149,6 +179,9 @@ def __init__(self, config: RADIOConfig):
149179
inter_feature_normalizer=inter_feature_normalizer,
150180
)
151181

182+
if config.rename_gamma_to_weight:
183+
rename_all_gamma_to_weight_with_proxy(self.radio_model)
184+
152185
@property
153186
def adaptors(self) -> nn.ModuleDict:
154187
return self.radio_model.adaptors

0 commit comments

Comments
 (0)