import tensorflow as tf
from deepctr.layers.activation import activation_layer
from deepctr.layers.utils import reduce_max, reduce_mean, reduce_sum, concat_func, div, softmax
-from tensorflow.python.keras.initializers import RandomNormal, Zeros, glorot_normal
+from tensorflow.python.keras.initializers import RandomNormal, Zeros, TruncatedNormal
from tensorflow.python.keras.layers import Layer
from tensorflow.python.keras.regularizers import l2

@@ -103,19 +103,19 @@ def call(self, inputs, training=None, **kwargs):
        weight = tf.pow(weight, self.pow_p)  # [x,k_max,1]

        if len(inputs) == 3:
-            k_user = tf.cast(tf.maximum(
-                1.,
-                tf.minimum(
-                    tf.cast(self.k_max, dtype="float32"),  # k_max
-                    tf.log1p(tf.cast(inputs[2], dtype="float32")) / tf.log(2.)  # hist_len
-                )
-            ), dtype="int64")
+            k_user = inputs[2]
            seq_mask = tf.transpose(tf.sequence_mask(k_user, self.k_max), [0, 2, 1])
            padding = tf.ones_like(seq_mask, dtype=tf.float32) * (-2 ** 32 + 1)  # [x,k_max,1]
            weight = tf.where(seq_mask, weight, padding)

-        weight = softmax(weight, dim=1, name="weight")
-        output = reduce_sum(keys * weight, axis=1)
+        if self.pow_p >= 100:
+            idx = tf.stack(
+                [tf.range(tf.shape(keys)[0]), tf.squeeze(tf.argmax(weight, axis=1, output_type=tf.int32), axis=1)],
+                axis=1)
+            output = tf.gather_nd(keys, idx)
+        else:
+            weight = softmax(weight, dim=1, name="weight")
+            output = tf.reduce_sum(keys * weight, axis=1)

        return output

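Note (not part of the patch): the new branch turns the layer into hard attention when pow_p is 100 or more, keeping only the highest-scoring interest capsule per sample instead of a softmax-weighted sum. A minimal sketch of that selection, assuming TF2 eager execution and illustrative toy tensors:

import tensorflow as tf

keys = tf.constant([[[1., 1.], [2., 2.], [3., 3.]]])   # [batch=1, k_max=3, dim=2]
weight = tf.constant([[[0.1], [0.7], [0.2]]])          # [batch=1, k_max=3, 1]
idx = tf.stack([tf.range(tf.shape(keys)[0]),
                tf.squeeze(tf.argmax(weight, axis=1, output_type=tf.int32), axis=1)],
               axis=1)
print(tf.gather_nd(keys, idx))                         # [[2., 2.]] -- the capsule with the largest weight
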
@@ -172,32 +172,59 @@ def __init__(self, input_units, out_units, max_len, k_max, iteration_times=3,
        super(CapsuleLayer, self).__init__(**kwargs)

    def build(self, input_shape):
-        self.routing_logits = self.add_weight(shape=[1, self.k_max, self.max_len],
-                                              initializer=RandomNormal(stddev=self.init_std),
-                                              trainable=False, name="B", dtype=tf.float32)
        self.bilinear_mapping_matrix = self.add_weight(shape=[self.input_units, self.out_units],
-                                                       initializer=RandomNormal(stddev=self.init_std),
                                                       name="S", dtype=tf.float32)
        super(CapsuleLayer, self).build(input_shape)

    def call(self, inputs, **kwargs):
-        behavior_embddings, seq_len = inputs
-        batch_size = tf.shape(behavior_embddings)[0]
-        seq_len_tile = tf.tile(seq_len, [1, self.k_max])
+
+        behavior_embedding = inputs[0]
+        seq_len = inputs[1]
+        batch_size = tf.shape(behavior_embedding)[0]
+
+        mask = tf.reshape(tf.sequence_mask(seq_len, self.max_len, tf.float32), [-1, self.max_len, 1, 1])
+
+        behavior_embedding_mapping = tf.tensordot(behavior_embedding, self.bilinear_mapping_matrix, axes=1)
+        behavior_embedding_mapping = tf.expand_dims(behavior_embedding_mapping, axis=2)
+
+        behavior_embdding_mapping_ = tf.stop_gradient(behavior_embedding_mapping)  # N,max_len,1,E
+        try:
+            routing_logits = tf.truncated_normal([batch_size, self.max_len, self.k_max, 1], stddev=self.init_std)
+        except AttributeError:
+            routing_logits = tf.compat.v1.truncated_normal([batch_size, self.max_len, self.k_max, 1],
+                                                           stddev=self.init_std)
+        routing_logits = tf.stop_gradient(routing_logits)
+
+        k_user = None
+        if len(inputs) == 3:
+            k_user = inputs[2]
+            interest_mask = tf.sequence_mask(k_user, self.k_max, tf.float32)
+            interest_mask = tf.reshape(interest_mask, [batch_size, 1, self.k_max, 1])
+            interest_mask = tf.tile(interest_mask, [1, self.max_len, 1, 1])
+
+            interest_padding = tf.ones_like(interest_mask) * -2 ** 31
+            interest_mask = tf.cast(interest_mask, tf.bool)

        for i in range(self.iteration_times):
-            mask = tf.sequence_mask(seq_len_tile, self.max_len)
-            pad = tf.ones_like(mask, dtype=tf.float32) * (-2 ** 32 + 1)
-            routing_logits_with_padding = tf.where(mask, tf.tile(self.routing_logits, [batch_size, 1, 1]), pad)
-            weight = tf.nn.softmax(routing_logits_with_padding)
-            behavior_embdding_mapping = tf.tensordot(behavior_embddings, self.bilinear_mapping_matrix, axes=1)
-            Z = tf.matmul(weight, behavior_embdding_mapping)
-            interest_capsules = squash(Z)
-            delta_routing_logits = reduce_sum(
-                tf.matmul(interest_capsules, tf.transpose(behavior_embdding_mapping, perm=[0, 2, 1])),
-                axis=0, keep_dims=True
-            )
-            self.routing_logits.assign_add(delta_routing_logits)
+            if k_user is not None:
+                routing_logits = tf.where(interest_mask, routing_logits, interest_padding)
+            try:
+                weight = softmax(routing_logits, 2) * mask
+            except TypeError:
+                weight = tf.transpose(softmax(tf.transpose(routing_logits, [0, 1, 3, 2])),
+                                      [0, 1, 3, 2]) * mask  # N,max_len,k_max,1
+            if i < self.iteration_times - 1:
+                Z = reduce_sum(tf.matmul(weight, behavior_embdding_mapping_), axis=1, keep_dims=True)  # N,1,k_max,E
+                interest_capsules = squash(Z)
+                delta_routing_logits = reduce_sum(
+                    interest_capsules * behavior_embdding_mapping_,
+                    axis=-1, keep_dims=True
+                )
+                routing_logits += delta_routing_logits
+            else:
+                Z = reduce_sum(tf.matmul(weight, behavior_embedding_mapping), axis=1, keep_dims=True)
+                interest_capsules = squash(Z)
+
        interest_capsules = tf.reshape(interest_capsules, [-1, self.k_max, self.out_units])
        return interest_capsules

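Note (not part of the patch): the rewritten call replaces the shared, trainable routing_logits variable with per-sample logits drawn from a truncated normal on each forward pass, applies stop_gradient to the routing inputs used in the intermediate iterations, and masks capsules beyond a user's interest count when a third input is given. A minimal sketch of one routing iteration on toy shapes, assuming TF2; all names below are illustrative:

import tensorflow as tf

def squash(z):
    sq = tf.reduce_sum(tf.square(z), axis=-1, keepdims=True)
    return sq / (1 + sq) / tf.sqrt(sq + 1e-9) * z

low = tf.random.normal([1, 2, 1, 3])     # N, max_len, 1, E   mapped behavior capsules
logits = tf.random.normal([1, 2, 2, 1])  # N, max_len, k_max, 1   routing logits
w = tf.nn.softmax(logits, axis=2)                                # route each behavior over k_max interests
Z = tf.reduce_sum(tf.matmul(w, low), axis=1, keepdims=True)      # N, 1, k_max, E
high = squash(Z)                                                 # interest capsules
logits += tf.reduce_sum(high * low, axis=-1, keepdims=True)      # agreement update
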
@@ -213,7 +240,7 @@ def get_config(self, ):


def squash(inputs):
    vec_squared_norm = reduce_sum(tf.square(inputs), axis=-1, keep_dims=True)
-    scalar_factor = vec_squared_norm / (1 + vec_squared_norm) / tf.sqrt(vec_squared_norm + 1e-8)
+    scalar_factor = vec_squared_norm / (1 + vec_squared_norm) / tf.sqrt(vec_squared_norm + 1e-9)
    vec_squashed = scalar_factor * inputs
    return vec_squashed

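Note (not part of the patch): squash rescales each capsule vector to length ||v||**2 / (1 + ||v||**2) while preserving its direction, so the epsilon change from 1e-8 to 1e-9 only affects numerical stability for near-zero vectors. A quick check, assuming TF2 eager execution:

import tensorflow as tf

v = tf.constant([[3., 4.]])                                   # norm 5
sq = tf.reduce_sum(tf.square(v), axis=-1, keepdims=True)      # 25
out = sq / (1 + sq) / tf.sqrt(sq + 1e-9) * v
print(tf.norm(out, axis=-1))                                  # ~0.9615 == 25 / 26
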
@@ -235,3 +262,27 @@ def get_config(self, ):
        config = {'index': self.index, }
        base_config = super(EmbeddingIndex, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
+
+
+class MaskUserEmbedding(Layer):
+
+    def __init__(self, k_max, **kwargs):
+        self.k_max = k_max
+        super(MaskUserEmbedding, self).__init__(**kwargs)
+
+    def build(self, input_shape):
+        super(MaskUserEmbedding, self).build(
+            input_shape)  # Be sure to call this somewhere!
+
+    def call(self, x, training=None, **kwargs):
+        user_embedding, interest_num = x
+        if not training:
+            interest_mask = tf.sequence_mask(interest_num, self.k_max, tf.float32)
+            interest_mask = tf.reshape(interest_mask, [-1, self.k_max, 1])
+            user_embedding *= interest_mask
+        return user_embedding
+
+    def get_config(self, ):
+        config = {'k_max': self.k_max, }
+        base_config = super(MaskUserEmbedding, self).get_config()
+        return dict(list(base_config.items()) + list(config.items()))