import tensorflow as tf
from deepctr.layers.activation import activation_layer
from deepctr.layers.utils import reduce_max, reduce_mean, reduce_sum, concat_func, div, softmax
-from tensorflow.python.keras.initializers import RandomNormal, Zeros, glorot_normal
+from tensorflow.python.keras.initializers import RandomNormal, Zeros, TruncatedNormal
from tensorflow.python.keras.layers import Layer
from tensorflow.python.keras.regularizers import l2

@@ -103,19 +103,19 @@ def call(self, inputs, training=None, **kwargs):
        weight = tf.pow(weight, self.pow_p)  # [x,k_max,1]

        if len(inputs) == 3:
-            k_user = tf.cast(tf.maximum(
-                1.,
-                tf.minimum(
-                    tf.cast(self.k_max, dtype="float32"),  # k_max
-                    tf.log1p(tf.cast(inputs[2], dtype="float32")) / tf.log(2.)  # hist_len
-                )
-            ), dtype="int64")
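+            # inputs[2] now carries the per-sample interest count directly (previously derived from hist_len)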
+            k_user = inputs[2]
            seq_mask = tf.transpose(tf.sequence_mask(k_user, self.k_max), [0, 2, 1])
            padding = tf.ones_like(seq_mask, dtype=tf.float32) * (-2 ** 32 + 1)  # [x,k_max,1]
            weight = tf.where(seq_mask, weight, padding)

-        weight = softmax(weight, dim=1, name="weight")
-        output = reduce_sum(keys * weight, axis=1)
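+        # with a very large pow_p the softmax weighting is effectively an argmax, so pick the single highest-weight capsule directly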
+        if self.pow_p >= 100:
+            idx = tf.stack(
+                [tf.range(tf.shape(keys)[0]), tf.squeeze(tf.argmax(weight, axis=1, output_type=tf.int32), axis=1)],
+                axis=1)
+            output = tf.gather_nd(keys, idx)
+        else:
+            weight = softmax(weight, dim=1, name="weight")
+            output = tf.reduce_sum(keys * weight, axis=1)

        return output

@@ -172,32 +172,59 @@ def __init__(self, input_units, out_units, max_len, k_max, iteration_times=3,
        super(CapsuleLayer, self).__init__(**kwargs)

    def build(self, input_shape):
-        self.routing_logits = self.add_weight(shape=[1, self.k_max, self.max_len],
-                                              initializer=RandomNormal(stddev=self.init_std),
-                                              trainable=False, name="B", dtype=tf.float32)
        self.bilinear_mapping_matrix = self.add_weight(shape=[self.input_units, self.out_units],
-                                                       initializer=RandomNormal(stddev=self.init_std),
                                                        name="S", dtype=tf.float32)
        super(CapsuleLayer, self).build(input_shape)

    def call(self, inputs, **kwargs):
-        behavior_embddings, seq_len = inputs
-        batch_size = tf.shape(behavior_embddings)[0]
-        seq_len_tile = tf.tile(seq_len, [1, self.k_max])
+
+        behavior_embedding = inputs[0]
+        seq_len = inputs[1]
+        batch_size = tf.shape(behavior_embedding)[0]
+
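+        # mask of valid behavior positions, shaped [N, max_len, 1, 1] so it broadcasts over the routing weights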
+        mask = tf.reshape(tf.sequence_mask(seq_len, self.max_len, tf.float32), [-1, self.max_len, 1, 1])
+
+        behavior_embedding_mapping = tf.tensordot(behavior_embedding, self.bilinear_mapping_matrix, axes=1)
+        behavior_embedding_mapping = tf.expand_dims(behavior_embedding_mapping, axis=2)
+
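+        # detached copy of the mapped behaviors: routing runs on it, so gradients only flow through the final iteration's output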
+        behavior_embdding_mapping_ = tf.stop_gradient(behavior_embedding_mapping)  # N,max_len,1,E
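+        # routing logits are re-drawn from a truncated normal on every call (no longer a persistent non-trainable weight) and kept out of the gradient path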
+        try:
+            routing_logits = tf.truncated_normal([batch_size, self.max_len, self.k_max, 1], stddev=self.init_std)
+        except AttributeError:
+            routing_logits = tf.compat.v1.truncated_normal([batch_size, self.max_len, self.k_max, 1],
+                                                           stddev=self.init_std)
+        routing_logits = tf.stop_gradient(routing_logits)
+
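+        # if the per-user interest count is supplied, build a mask so logits of unused interest capsules can be suppressed in the routing loop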
+        k_user = None
+        if len(inputs) == 3:
+            k_user = inputs[2]
+            interest_mask = tf.sequence_mask(k_user, self.k_max, tf.float32)
+            interest_mask = tf.reshape(interest_mask, [batch_size, 1, self.k_max, 1])
+            interest_mask = tf.tile(interest_mask, [1, self.max_len, 1, 1])
+
+            interest_padding = tf.ones_like(interest_mask) * -2 ** 31
+            interest_mask = tf.cast(interest_mask, tf.bool)

        for i in range(self.iteration_times):
-            mask = tf.sequence_mask(seq_len_tile, self.max_len)
-            pad = tf.ones_like(mask, dtype=tf.float32) * (-2 ** 32 + 1)
-            routing_logits_with_padding = tf.where(mask, tf.tile(self.routing_logits, [batch_size, 1, 1]), pad)
-            weight = tf.nn.softmax(routing_logits_with_padding)
-            behavior_embdding_mapping = tf.tensordot(behavior_embddings, self.bilinear_mapping_matrix, axes=1)
-            Z = tf.matmul(weight, behavior_embdding_mapping)
-            interest_capsules = squash(Z)
-            delta_routing_logits = reduce_sum(
-                tf.matmul(interest_capsules, tf.transpose(behavior_embdding_mapping, perm=[0, 2, 1])),
-                axis=0, keep_dims=True
-            )
-            self.routing_logits.assign_add(delta_routing_logits)
+            if k_user is not None:
+                routing_logits = tf.where(interest_mask, routing_logits, interest_padding)
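+            # the softmax util may not accept an axis argument; fall back to transposing so the k_max axis is the one normalized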
+            try:
+                weight = softmax(routing_logits, 2) * mask
+            except TypeError:
+                weight = tf.transpose(softmax(tf.transpose(routing_logits, [0, 1, 3, 2])),
+                                      [0, 1, 3, 2]) * mask  # N,max_len,k_max,1
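+            # intermediate iterations only refine the (detached) routing logits; the last iteration produces the capsules that receive gradients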
+            if i < self.iteration_times - 1:
+                Z = reduce_sum(tf.matmul(weight, behavior_embdding_mapping_), axis=1, keep_dims=True)  # N,1,k_max,E
+                interest_capsules = squash(Z)
+                delta_routing_logits = reduce_sum(
+                    interest_capsules * behavior_embdding_mapping_,
+                    axis=-1, keep_dims=True
+                )
+                routing_logits += delta_routing_logits
+            else:
+                Z = reduce_sum(tf.matmul(weight, behavior_embedding_mapping), axis=1, keep_dims=True)
+                interest_capsules = squash(Z)
+
        interest_capsules = tf.reshape(interest_capsules, [-1, self.k_max, self.out_units])
        return interest_capsules

@@ -213,7 +240,7 @@ def get_config(self, ):

def squash(inputs):
    vec_squared_norm = reduce_sum(tf.square(inputs), axis=-1, keep_dims=True)
-    scalar_factor = vec_squared_norm / (1 + vec_squared_norm) / tf.sqrt(vec_squared_norm + 1e-8)
+    scalar_factor = vec_squared_norm / (1 + vec_squared_norm) / tf.sqrt(vec_squared_norm + 1e-9)
    vec_squashed = scalar_factor * inputs
    return vec_squashed

@@ -235,3 +262,27 @@ def get_config(self, ):
        config = {'index': self.index, }
        base_config = super(EmbeddingIndex, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
+
+
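+# at inference time, zero out the interest capsules beyond each user's interest_num; during training all k_max capsules are kept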
+class MaskUserEmbedding(Layer):
+
+    def __init__(self, k_max, **kwargs):
+        self.k_max = k_max
+        super(MaskUserEmbedding, self).__init__(**kwargs)
+
+    def build(self, input_shape):
+        super(MaskUserEmbedding, self).build(input_shape)  # Be sure to call this somewhere!
+
+    def call(self, x, training=None, **kwargs):
+        user_embedding, interest_num = x
+        if not training:
+            interest_mask = tf.sequence_mask(interest_num, self.k_max, tf.float32)
+            interest_mask = tf.reshape(interest_mask, [-1, self.k_max, 1])
+            user_embedding *= interest_mask
+        return user_embedding
+
+    def get_config(self, ):
+        config = {'k_max': self.k_max, }
+        base_config = super(MaskUserEmbedding, self).get_config()
+        return dict(list(base_config.items()) + list(config.items()))