@@ -150,29 +150,43 @@ class IdentityDict(Dictionary, nn.Module):
150150 An identity dictionary, i.e. the identity function.
151151 """
152152
153- def __init__ (self , activation_dim = None ):
153+ def __init__ (self , activation_dim = None , dtype = None , device = None ):
154154 super ().__init__ ()
155155 self .activation_dim = activation_dim
156156 self .dict_size = activation_dim
157+ self .device = device
158+ self .dtype = dtype
157159
158160 def encode (self , x ):
161+ if self .device is not None :
162+ x = x .to (self .device )
163+ if self .dtype is not None :
164+ x = x .to (self .dtype )
159165 return x
160166
161167 def decode (self , f ):
168+ if self .device is not None :
169+ f = f .to (self .device )
170+ if self .dtype is not None :
171+ f = f .to (self .dtype )
162172 return f
163173
164174 def forward (self , x , output_features = False , ghost_mask = None ):
175+ if self .device is not None :
176+ x = x .to (self .device )
177+ if self .dtype is not None :
178+ x = x .to (self .dtype )
165179 if output_features :
166180 return x , x
167181 else :
168182 return x
169183
170184 @classmethod
171- def from_pretrained (cls , path , dtype = t . float , device = None ):
185+ def from_pretrained (cls , activation_dim , path , dtype = None , device = None ):
172186 """
173187 Load a pretrained dictionary from a file.
174188 """
175- return cls (None )
189+ return cls (activation_dim , device = device , dtype = dtype )
176190
177191
178192class GatedAutoEncoder (Dictionary , nn .Module ):
0 commit comments