@@ -64,7 +64,8 @@ class FeaWAD(BaseDeepAD):
     random_state: int, optional (default=42)
         the seed used by the random number generator
     """
-    def __init__(self, epochs=100, batch_size=64, lr=1e-3,
+
+    def __init__(self, epochs=100, pretrain_epochs=50, batch_size=64, lr=1e-3,
                  rep_dim=128, hidden_dims='100,50', act='ReLU', bias=False,
                  margin=5.,
                  epoch_steps=-1, prt_steps=10, device='cuda',
@@ -76,13 +77,15 @@ def __init__(self, epochs=100, batch_size=64, lr=1e-3,
             verbose=verbose, random_state=random_state
         )
 
+        self.pretrain_epochs = pretrain_epochs
         self.margin = margin
 
         self.rep_dim = rep_dim
         self.hidden_dims = hidden_dims
         self.act = act
         self.bias = bias
 
+        self.cur_epoch = None
         return
 
     def training_prepare(self, X, y):
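
For orientation, the new `pretrain_epochs` knob slots into the usual DeepOD fit/score workflow. A minimal usage sketch, assuming the standard `fit`/`decision_function` API of `BaseDeepAD` (the synthetic data and `device='cpu'` are illustrative, not from this patch):

```python
# Hypothetical usage sketch of the new pretrain_epochs argument.
import numpy as np

rng = np.random.default_rng(42)
X_train = rng.normal(size=(1000, 20)).astype(np.float32)
y_train = np.zeros(1000, dtype=np.float32)
y_train[:10] = 1.0  # weak supervision: only a handful of labeled anomalies

model = FeaWAD(epochs=100, pretrain_epochs=50, batch_size=64, lr=1e-3, device='cpu')
model.fit(X_train, y_train)                # warm-up epochs train the AE only
scores = model.decision_function(X_train)  # higher score = more anomalous
```
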
@@ -107,6 +110,8 @@ def training_prepare(self, X, y):
         }
         net = FeaWadNet(**network_params).to(self.device)
         criterion = FeaWADLoss(margin=self.margin)
+        self.cur_epoch = 0
+
         if self.verbose >= 2:
             print(net)
 
@@ -123,6 +128,8 @@ def training_forward(self, batch_x, net, criterion):
         batch_x = batch_x.float().to(self.device)
         batch_y = batch_y.to(self.device)
         pred, sub_result = net(batch_x)
+        if self.cur_epoch <= self.pretrain_epochs:
+            return torch.nn.functional.mse_loss(batch_x, net.AEmodel(batch_x)[0])
         loss = criterion(batch_y, pred, sub_result)
         return loss
 
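
This gate is the core of the change: since `cur_epoch` starts at 0 and the check is `<=`, the first `pretrain_epochs + 1` epochs optimize a plain reconstruction MSE against `net.AEmodel(batch_x)[0]` (the decoder output) and ignore labels entirely; only afterwards does the weakly-supervised `FeaWADLoss` take over. A self-contained sketch of the same switch, with a toy autoencoder standing in for `AEmodel`:

```python
# Toy illustration of the two-phase loss schedule (ToyAE is hypothetical;
# it only mirrors AEmodel's (reconstruction, embedding) return convention).
import torch

class ToyAE(torch.nn.Module):
    def __init__(self, n_features, n_emb=8):
        super().__init__()
        self.enc = torch.nn.Linear(n_features, n_emb)
        self.dec = torch.nn.Linear(n_emb, n_features)

    def forward(self, x):
        z = self.enc(x)
        return self.dec(z), z  # (reconstruction, embedding)

ae = ToyAE(n_features=20)
x = torch.randn(64, 20)
cur_epoch, pretrain_epochs = 3, 50
if cur_epoch <= pretrain_epochs:
    loss = torch.nn.functional.mse_loss(x, ae(x)[0])  # phase 1: unsupervised warm-up
else:
    pass  # phase 2: FeaWADLoss on (pred, sub_result), as in the diff
print(loss.item())
```
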
@@ -133,6 +140,9 @@ def inference_forward(self, batch_x, net, criterion):
         batch_z = batch_x
         return batch_z, s
 
+    def epoch_update(self):
+        self.cur_epoch += 1
+
 
 class FeaWadNet(torch.nn.Module):
     def __init__(self, n_features, network, n_hidden='500,100', n_hidden2='256,32', n_emb=20,
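
`epoch_update` is the hook that advances the counter; presumably the `BaseDeepAD` training loop (not part of this diff) calls it once per epoch after iterating the batches. A tiny self-contained mock of that interaction, which also shows that the `<=` comparison gives `pretrain_epochs + 1` warm-up epochs:

```python
# Mock of the epoch counter driving the loss switch (Sketch is hypothetical;
# the real loop lives in BaseDeepAD).
class Sketch:
    def __init__(self, epochs=4, pretrain_epochs=2):
        self.epochs, self.pretrain_epochs, self.cur_epoch = epochs, pretrain_epochs, 0

    def training_forward(self):
        return 'mse' if self.cur_epoch <= self.pretrain_epochs else 'feawad'

    def epoch_update(self):
        self.cur_epoch += 1

m = Sketch()
phases = []
for _ in range(m.epochs):
    phases.append(m.training_forward())  # loss used during this epoch
    m.epoch_update()                     # advance the counter afterwards
print(phases)  # ['mse', 'mse', 'mse', 'feawad'] -> 3 warm-up epochs for pretrain_epochs=2
```
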
@@ -143,7 +153,7 @@ def __init__(self, n_features, network, n_hidden='500,100', n_hidden2='256,32',
         FWmodel = get_network('MLP')
         self.AEmodel = AEmodel_class(n_features, n_hidden=n_hidden, n_emb=n_emb,
                                      activation=activation, bias=bias)
-        self.LinearModel = FWmodel(n_features+n_emb, n_hidden=n_hidden2, n_output=1,
+        self.LinearModel = FWmodel(n_features + n_emb, n_hidden=n_hidden2, n_output=1,
                                    activation=activation, bias=bias)
 
     def forward(self, x):
@@ -181,6 +191,7 @@ class FeaWADLoss(torch.nn.Module):
         - If ``'sum'``: the output will be summed
 
     """
+
     def __init__(self, margin=5., reduction='mean'):
         super(FeaWADLoss, self).__init__()
         self.margin = margin
@@ -192,8 +203,8 @@ def forward(self, y_true, y_pred, sub_result):
         inlier_loss = torch.abs(dev)
         outlier_loss = torch.abs(torch.maximum(self.margin - dev, torch.tensor(0.)))
 
-        sub_nor = torch.norm(sub_result, p=2, dim=1 if len(sub_result.shape)==2 else [1,2])
-        outlier_sub_loss = torch.abs(torch.maximum(self.margin-sub_nor, torch.tensor(0.)))
+        sub_nor = torch.norm(sub_result, p=2, dim=1 if len(sub_result.shape) == 2 else [1, 2])
+        outlier_sub_loss = torch.abs(torch.maximum(self.margin - sub_nor, torch.tensor(0.)))
         loss = (1 - y_true) * (inlier_loss + sub_nor) + y_true * (outlier_loss + outlier_sub_loss)
 
         if self.reduction == 'mean':
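
Written out, with `dev` the scalar deviation score derived from `y_pred` just above this hunk, `e = sub_result` the reconstruction residual, and `m` the margin, these lines implement

    loss = (1 - y) * (|dev| + ||e||_2) + y * (max(m - dev, 0) + max(m - ||e||_2, 0))

so inliers are pulled toward zero deviation and zero residual norm, while labeled anomalies are pushed to at least `m` on both terms (the outer `torch.abs` is redundant, since the hinge terms are already non-negative). A numeric spot-check with made-up values; `torch.clamp(min=0)` is equivalent to the `abs(maximum(..., 0))` pattern in the diff:

```python
# Spot-check of the FeaWAD loss on illustrative values.
import torch

margin = 5.
y_true = torch.tensor([0., 1.])     # one inlier, one labeled anomaly
dev = torch.tensor([0.2, 6.0])      # deviation scores from the network
sub_nor = torch.tensor([0.1, 7.0])  # L2 norms of the reconstruction residuals

inlier_loss = torch.abs(dev)
outlier_loss = torch.clamp(margin - dev, min=0.)
outlier_sub_loss = torch.clamp(margin - sub_nor, min=0.)
loss = (1 - y_true) * (inlier_loss + sub_nor) + y_true * (outlier_loss + outlier_sub_loss)
print(loss)  # tensor([0.3000, 0.0000])
```
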