Skip to content

Commit 351db62

Browse files
authored
Merge pull request #61 from ribenShiSunZi/feawad
Add Autoencoder pretraining module in FeaWAD
2 parents bb8c20c + 5c1ccb5 commit 351db62

File tree

1 file changed

+15
-4
lines changed

1 file changed

+15
-4
lines changed

deepod/models/tabular/feawad.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,8 @@ class FeaWAD(BaseDeepAD):
6464
random_state: int, optional (default=42)
6565
the seed used by the random
6666
"""
67-
def __init__(self, epochs=100, batch_size=64, lr=1e-3,
67+
68+
def __init__(self, epochs=100, pretrain_epochs=50, batch_size=64, lr=1e-3,
6869
rep_dim=128, hidden_dims='100,50', act='ReLU', bias=False,
6970
margin=5.,
7071
epoch_steps=-1, prt_steps=10, device='cuda',
@@ -76,13 +77,15 @@ def __init__(self, epochs=100, batch_size=64, lr=1e-3,
7677
verbose=verbose, random_state=random_state
7778
)
7879

80+
self.pretrain_epochs = pretrain_epochs
7981
self.margin = margin
8082

8183
self.rep_dim = rep_dim
8284
self.hidden_dims = hidden_dims
8385
self.act = act
8486
self.bias = bias
8587

88+
self.cur_epoch = None
8689
return
8790

8891
def training_prepare(self, X, y):
@@ -107,6 +110,8 @@ def training_prepare(self, X, y):
107110
}
108111
net = FeaWadNet(**network_params).to(self.device)
109112
criterion = FeaWADLoss(margin=self.margin)
113+
self.cur_epoch = 0
114+
110115
if self.verbose >= 2:
111116
print(net)
112117

@@ -123,6 +128,8 @@ def training_forward(self, batch_x, net, criterion):
123128
batch_x = batch_x.float().to(self.device)
124129
batch_y = batch_y.to(self.device)
125130
pred, sub_result = net(batch_x)
131+
if self.cur_epoch <= self.pretrain_epochs:
132+
return torch.nn.functional.mse_loss(batch_x, net.AEmodel(batch_x)[0])
126133
loss = criterion(batch_y, pred, sub_result)
127134
return loss
128135

@@ -133,6 +140,9 @@ def inference_forward(self, batch_x, net, criterion):
133140
batch_z = batch_x
134141
return batch_z, s
135142

143+
def epoch_update(self):
144+
self.cur_epoch += 1
145+
136146

137147
class FeaWadNet(torch.nn.Module):
138148
def __init__(self, n_features, network, n_hidden='500,100', n_hidden2='256,32', n_emb=20,
@@ -143,7 +153,7 @@ def __init__(self, n_features, network, n_hidden='500,100', n_hidden2='256,32',
143153
FWmodel = get_network('MLP')
144154
self.AEmodel = AEmodel_class(n_features, n_hidden=n_hidden, n_emb=n_emb,
145155
activation=activation, bias=bias)
146-
self.LinearModel = FWmodel(n_features+n_emb, n_hidden=n_hidden2, n_output=1,
156+
self.LinearModel = FWmodel(n_features + n_emb, n_hidden=n_hidden2, n_output=1,
147157
activation=activation, bias=bias)
148158

149159
def forward(self, x):
@@ -181,6 +191,7 @@ class FeaWADLoss(torch.nn.Module):
181191
- If ``'sum'``: the output will be summed
182192
183193
"""
194+
184195
def __init__(self, margin=5., reduction='mean'):
185196
super(FeaWADLoss, self).__init__()
186197
self.margin = margin
@@ -192,8 +203,8 @@ def forward(self, y_true, y_pred, sub_result):
192203
inlier_loss = torch.abs(dev)
193204
outlier_loss = torch.abs(torch.maximum(self.margin - dev, torch.tensor(0.)))
194205

195-
sub_nor = torch.norm(sub_result, p=2, dim=1 if len(sub_result.shape)==2 else [1,2])
196-
outlier_sub_loss = torch.abs(torch.maximum(self.margin-sub_nor, torch.tensor(0.)))
206+
sub_nor = torch.norm(sub_result, p=2, dim=1 if len(sub_result.shape) == 2 else [1, 2])
207+
outlier_sub_loss = torch.abs(torch.maximum(self.margin - sub_nor, torch.tensor(0.)))
197208
loss = (1 - y_true) * (inlier_loss + sub_nor) + y_true * (outlier_loss + outlier_sub_loss)
198209

199210
if self.reduction == 'mean':

0 commit comments

Comments
 (0)