Skip to content

Commit 61ac634

Browse files
authored
add handling of device and dtype to IdentityDict
1 parent b2bb4b5 commit 61ac634

File tree

1 file changed

+17
-3
lines changed

1 file changed

+17
-3
lines changed

dictionary_learning/dictionary.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -150,29 +150,43 @@ class IdentityDict(Dictionary, nn.Module):
150150
An identity dictionary, i.e. the identity function.
151151
"""
152152

153-
def __init__(self, activation_dim=None):
153+
def __init__(self, activation_dim=None, dtype=None, device=None):
154154
super().__init__()
155155
self.activation_dim = activation_dim
156156
self.dict_size = activation_dim
157+
self.device = device
158+
self.dtype = dtype
157159

158160
def encode(self, x):
161+
if self.device is not None:
162+
x = x.to(self.device)
163+
if self.dtype is not None:
164+
x = x.to(self.dtype)
159165
return x
160166

161167
def decode(self, f):
168+
if self.device is not None:
169+
f = f.to(self.device)
170+
if self.dtype is not None:
171+
f = f.to(self.dtype)
162172
return f
163173

164174
def forward(self, x, output_features=False, ghost_mask=None):
175+
if self.device is not None:
176+
x = x.to(self.device)
177+
if self.dtype is not None:
178+
x = x.to(self.dtype)
165179
if output_features:
166180
return x, x
167181
else:
168182
return x
169183

170184
@classmethod
171-
def from_pretrained(cls, path, dtype=t.float, device=None):
185+
def from_pretrained(cls, activation_dim, path, dtype=None, device=None):
172186
"""
173187
Load a pretrained dictionary from a file.
174188
"""
175-
return cls(None)
189+
return cls(activation_dim, device=device, dtype=dtype)
176190

177191

178192
class GatedAutoEncoder(Dictionary, nn.Module):

0 commit comments

Comments
 (0)