43
43
from .extra_timm_models import *
44
44
45
45
46
+
47
+ def rename_all_gamma_to_weight_with_proxy (module ):
48
+ """
49
+ Renames all parameters named 'gamma' in a module (including submodules)
50
+ to 'weight' and sets up a property so that accesses to 'gamma' still work.
51
+ """
52
+ # Recursively iterate through submodules
53
+ for submodule_name , submodule in module .named_modules ():
54
+ # Get all parameters within the current submodule
55
+ for param_name , param in list (submodule .named_parameters (recurse = False )):
56
+ if 'gamma' in param_name :
57
+ # Generate the new name by replacing 'gamma' with 'weight'
58
+ new_name = param_name .replace ('gamma' , 'weight' )
59
+
60
+ # Remove the old parameter and assign it with the new name
61
+ delattr (submodule , param_name )
62
+ setattr (submodule , new_name , nn .Parameter (param .data ))
63
+
64
+ # Define a property to proxy access to the renamed parameter
65
+ def make_property (old_name , new_name ):
66
+ return property (lambda self : getattr (self , new_name ),
67
+ lambda self , value : setattr (self , new_name , value ))
68
+
69
+ # Add the property to the submodule to proxy access to 'gamma'
70
+ setattr (submodule .__class__ , param_name , make_property (param_name , new_name ))
71
+
72
+
46
73
class RADIOConfig (PretrainedConfig ):
47
74
"""Pretrained Hugging Face configuration for RADIO models."""
48
75
@@ -58,6 +85,7 @@ def __init__(
58
85
vitdet_window_size : Optional [int ] = None ,
59
86
feature_normalizer_config : Optional [dict ] = None ,
60
87
inter_feature_normalizer_config : Optional [dict ] = None ,
88
+ rename_gamma_to_weight : bool = False ,
61
89
** kwargs ,
62
90
):
63
91
self .args = args
@@ -79,9 +107,11 @@ def __init__(
79
107
self .vitdet_window_size = vitdet_window_size
80
108
self .feature_normalizer_config = feature_normalizer_config
81
109
self .inter_feature_normalizer_config = inter_feature_normalizer_config
110
+ self .rename_gamma_to_weight = rename_gamma_to_weight
82
111
super ().__init__ (** kwargs )
83
112
84
113
114
+
85
115
class RADIOModel (PreTrainedModel ):
86
116
"""Pretrained Hugging Face model for RADIO.
87
117
@@ -149,6 +179,9 @@ def __init__(self, config: RADIOConfig):
149
179
inter_feature_normalizer = inter_feature_normalizer ,
150
180
)
151
181
182
+ if config .rename_gamma_to_weight :
183
+ rename_all_gamma_to_weight_with_proxy (self .radio_model )
184
+
152
185
@property
153
186
def adaptors (self ) -> nn .ModuleDict :
154
187
return self .radio_model .adaptors
0 commit comments