| 1 | +"""
| 2 | +
| 3 | +Author:
| 4 | +
| 5 | +
| 6 | +"""
| 7 | +
| 8 | +import numpy as np
1 | 9 | import tensorflow as tf
2 | | -from deepctr.layers.activation import activation_layer
3 | 10 | from deepctr.layers.utils import reduce_max, reduce_mean, reduce_sum, concat_func, div, softmax
4 | | -from tensorflow.python.keras.initializers import RandomNormal, Zeros, TruncatedNormal
| 11 | +from tensorflow.python.keras.initializers import Zeros
5 | 12 | from tensorflow.python.keras.layers import Layer
6 | | -from tensorflow.python.keras.regularizers import l2
7 | 13 |
8 | 14 |
9 | 15 | class PoolingLayer(Layer):

@@ -45,45 +51,103 @@ def get_config(self, ):

45 | 51 |
46 | 52 |
47 | 53 | class SampledSoftmaxLayer(Layer):
48 | | -    def __init__(self, num_sampled=5, **kwargs):
49 | | -        self.num_sampled = num_sampled
| 54 | +    def __init__(self, sampler_config, temperature=1.0, **kwargs):
| 55 | +        self.sampler_config = sampler_config
| 56 | +        self.temperature = temperature
| 57 | +        self.sampler = self.sampler_config['sampler']
| 58 | +        self.item_count = self.sampler_config['item_count']
| 59 | +
50 | 60 |         super(SampledSoftmaxLayer, self).__init__(**kwargs)
51 | 61 |
52 | 62 |     def build(self, input_shape):
53 | | -        self.size = input_shape[0][0]
54 | | -        self.zero_bias = self.add_weight(shape=[self.size],
| 63 | +        self.vocabulary_size = input_shape[0][0]
| 64 | +        self.zero_bias = self.add_weight(shape=[self.vocabulary_size],
55 | 65 |                                          initializer=Zeros,
56 | 66 |                                          dtype=tf.float32,
57 | 67 |                                          trainable=False,
58 | 68 |                                          name="bias")
59 | 69 |         super(SampledSoftmaxLayer, self).build(input_shape)
60 | 70 |
61 | | -    def call(self, inputs_with_label_idx, training=None, **kwargs):
62 | | -        """
63 | | -        The first input should be the model as it were, and the second the
64 | | -        target (i.e., a repeat of the training data) to compute the labels
65 | | -        argument
66 | | -        """
67 | | -        embeddings, inputs, label_idx = inputs_with_label_idx
68 | | -
69 | | -        loss = tf.nn.sampled_softmax_loss(weights=embeddings,  # self.item_embedding.
70 | | -                                          biases=self.zero_bias,
71 | | -                                          labels=label_idx,
72 | | -                                          inputs=inputs,
73 | | -                                          num_sampled=self.num_sampled,
74 | | -                                          num_classes=self.size,  # self.target_song_size
75 | | -                                          )
| 71 | +    def call(self, inputs_with_item_idx, training=None, **kwargs):
| 72 | +        item_embeddings, user_vec, item_idx = inputs_with_item_idx
| 73 | +        if item_idx.dtype != tf.int64:
| 74 | +            item_idx = tf.cast(item_idx, tf.int64)
| 75 | +        user_vec /= self.temperature
| 76 | +        if self.sampler == "inbatch":
| 77 | +            item_vec = tf.gather(item_embeddings, tf.squeeze(item_idx, axis=1))
| 78 | +            logits = tf.matmul(user_vec, item_vec, transpose_b=True)
| 79 | +            loss = inbatch_softmax_cross_entropy_with_logits(logits, self.item_count, item_idx)
| 80 | +
| 81 | +        else:
| 82 | +            num_sampled = self.sampler_config['num_sampled']
| 83 | +            if self.sampler == "frequency":
| 84 | +                sampled_values = tf.nn.fixed_unigram_candidate_sampler(item_idx, 1, num_sampled, True,
| 85 | +                                                                       self.vocabulary_size,
| 86 | +                                                                       distortion=self.sampler_config['distortion'],
| 87 | +                                                                       unigrams=np.maximum(self.item_count, 1).tolist(),
| 88 | +                                                                       seed=None,
| 89 | +                                                                       name=None)
| 90 | +            elif self.sampler == "adaptive":
| 91 | +                sampled_values = tf.nn.learned_unigram_candidate_sampler(item_idx, 1, num_sampled, True,
| 92 | +                                                                         self.vocabulary_size, seed=None, name=None)
| 93 | +            elif self.sampler == "uniform":
| 94 | +                try:
| 95 | +                    sampled_values = tf.nn.uniform_candidate_sampler(item_idx, 1, num_sampled, True,
| 96 | +                                                                     self.vocabulary_size, seed=None, name=None)
| 97 | +                except AttributeError:
| 98 | +                    sampled_values = tf.random.uniform_candidate_sampler(item_idx, 1, num_sampled, True,
| 99 | +                                                                         self.vocabulary_size, seed=None, name=None)
| 100 | +            else:
| 101 | +                raise ValueError(' `%s` sampler is not supported ' % self.sampler)
| 102 | +
| 103 | +            loss = tf.nn.sampled_softmax_loss(weights=item_embeddings,
| 104 | +                                              biases=self.zero_bias,
| 105 | +                                              labels=item_idx,
| 106 | +                                              inputs=user_vec,
| 107 | +                                              num_sampled=num_sampled,
| 108 | +                                              num_classes=self.vocabulary_size,
| 109 | +                                              sampled_values=sampled_values
| 110 | +                                              )
76 | 111 |         return tf.expand_dims(loss, axis=1)
77 | 112 |
78 | 113 |     def compute_output_shape(self, input_shape):
79 | 114 |         return (None, 1)
80 | 115 |
81 | 116 |     def get_config(self, ):
82 | | -        config = {'num_sampled': self.num_sampled}
| 117 | +        config = {'sampler_config': self.sampler_config, 'temperature': self.temperature}
83 | 118 |         base_config = super(SampledSoftmaxLayer, self).get_config()
84 | 119 |         return dict(list(base_config.items()) + list(config.items()))
85 | 120 |
86 | 121 |
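Side note (not part of the commit): the new `call` above expects `sampler_config` to carry at least `sampler` and `item_count`, plus `num_sampled` (and `distortion` for the `"frequency"` sampler) when negatives are drawn outside the batch. A minimal sketch of such a config; the counts and variable names here are made up:

```python
import numpy as np

# Hypothetical per-item occurrence counts over the training data, indexed by item id.
item_count = np.array([120, 3, 57, 9], dtype="int64")

sampler_config = {
    "sampler": "frequency",    # one of "inbatch", "frequency", "adaptive", "uniform"
    "item_count": item_count,  # unigram counts; also drives the in-batch log-Q correction
    "num_sampled": 2,          # negatives per positive; unused by the "inbatch" branch
    "distortion": 0.75,        # frequency flattening, read only by the "frequency" sampler
}

# SampledSoftmaxLayer(sampler_config, temperature=0.05) would then be applied to
# [item_embedding_matrix, user_vector, item_index] and return a (batch_size, 1) loss.
```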
| 122 | +class InBatchSoftmaxLayer(Layer):
| 123 | +    def __init__(self, sampler_config, temperature=1.0, **kwargs):
| 124 | +        self.sampler_config = sampler_config
| 125 | +        self.temperature = temperature
| 126 | +        self.item_count = self.sampler_config['item_count']
| 127 | +
| 128 | +        super(InBatchSoftmaxLayer, self).__init__(**kwargs)
| 129 | +
| 130 | +    def build(self, input_shape):
| 131 | +        super(InBatchSoftmaxLayer, self).build(input_shape)
| 132 | +
| 133 | +    def call(self, inputs_with_item_idx, training=None, **kwargs):
| 134 | +        user_vec, item_vec, item_idx = inputs_with_item_idx
| 135 | +        if item_idx.dtype != tf.int64:
| 136 | +            item_idx = tf.cast(item_idx, tf.int64)
| 137 | +        user_vec /= self.temperature
| 138 | +        logits = tf.matmul(user_vec, item_vec, transpose_b=True)
| 139 | +        loss = inbatch_softmax_cross_entropy_with_logits(logits, self.item_count, item_idx)
| 140 | +        return tf.expand_dims(loss, axis=1)
| 141 | +
| 142 | +    def compute_output_shape(self, input_shape):
| 143 | +        return (None, 1)
| 144 | +
| 145 | +    def get_config(self, ):
| 146 | +        config = {'sampler_config': self.sampler_config, 'temperature': self.temperature}
| 147 | +        base_config = super(InBatchSoftmaxLayer, self).get_config()
| 148 | +        return dict(list(base_config.items()) + list(config.items()))
| 149 | +
| 150 | +
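A brief usage sketch (again, not from the commit): unlike `SampledSoftmaxLayer`, this layer receives the already-gathered in-batch item vectors instead of the full embedding table, so it would be wired roughly as below; the tensor names are hypothetical:

```python
# user_dnn_out: (batch_size, dim) user-tower output
# item_dnn_out: (batch_size, dim) item-tower output for the positives in the batch
# item_idx:     (batch_size, 1)   integer ids of those positive items
loss = InBatchSoftmaxLayer(sampler_config, temperature=0.05)(
    [user_dnn_out, item_dnn_out, item_idx])  # -> (batch_size, 1) per-example loss
```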
87 | 151 | class LabelAwareAttention(Layer):
88 | 152 |     def __init__(self, k_max, pow_p=1, **kwargs):
89 | 153 |         self.k_max = k_max

@@ -128,38 +192,6 @@ def get_config(self, ):

128 | 192 |         return dict(list(base_config.items()) + list(config.items()))
129 | 193 |
130 | 194 |
131 | | -class Similarity(Layer):
132 | | -
133 | | -    def __init__(self, gamma=1, axis=-1, type='cos', **kwargs):
134 | | -        self.gamma = gamma
135 | | -        self.axis = axis
136 | | -        self.type = type
137 | | -        super(Similarity, self).__init__(**kwargs)
138 | | -
139 | | -    def build(self, input_shape):
140 | | -        # Be sure to call this somewhere!
141 | | -        super(Similarity, self).build(input_shape)
142 | | -
143 | | -    def call(self, inputs, **kwargs):
144 | | -        query, candidate = inputs
145 | | -        if self.type == "cos":
146 | | -            query_norm = tf.norm(query, axis=self.axis)
147 | | -            candidate_norm = tf.norm(candidate, axis=self.axis)
148 | | -        cosine_score = reduce_sum(tf.multiply(query, candidate), -1)
149 | | -        if self.type == "cos":
150 | | -            cosine_score = div(cosine_score, query_norm * candidate_norm + 1e-8)
151 | | -        cosine_score = tf.clip_by_value(cosine_score, -1, 1.0) * self.gamma
152 | | -        return cosine_score
153 | | -
154 | | -    def compute_output_shape(self, input_shape):
155 | | -        return (None, 1)
156 | | -
157 | | -    def get_config(self, ):
158 | | -        config = {'gamma': self.gamma, 'axis': self.axis, 'type': self.type}
159 | | -        base_config = super(Similarity, self).get_config()
160 | | -        return dict(list(base_config.items()) + list(config.items()))
161 | | -
162 | | -
163 | 195 | class CapsuleLayer(Layer):
164 | 196 |     def __init__(self, input_units, out_units, max_len, k_max, iteration_times=3,
165 | 197 |                  init_std=1.0, **kwargs):

@@ -245,6 +277,23 @@ def squash(inputs):

245 | 277 |     return vec_squashed
246 | 278 |
247 | 279 |
| 280 | +def inbatch_softmax_cross_entropy_with_logits(logits, item_count, item_idx):
| 281 | +    Q = tf.gather(tf.constant(item_count / np.sum(item_count), 'float32'),
| 282 | +                  tf.squeeze(item_idx, axis=1))
| 283 | +    try:
| 284 | +        logQ = tf.reshape(tf.math.log(Q), (1, -1))
| 285 | +        logits -= logQ  # subtract_log_q
| 286 | +        labels = tf.linalg.diag(tf.ones_like(logits[0]))
| 287 | +    except AttributeError:
| 288 | +        logQ = tf.reshape(tf.log(Q), (1, -1))
| 289 | +        logits -= logQ  # subtract_log_q
| 290 | +        labels = tf.diag(tf.ones_like(logits[0]))
| 291 | +
| 292 | +    loss = tf.nn.softmax_cross_entropy_with_logits(
| 293 | +        labels=labels, logits=logits)
| 294 | +    return loss
| 295 | +
| 296 | +
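For readers unfamiliar with the correction applied above (an explanatory aside, not part of the commit): with in-batch negatives, popular items appear as negatives far more often than a plain softmax assumes, so each logit is debiased by subtracting log Q(item), where Q is the item's share of the global frequency counts; the label matrix is the identity because example i's positive item sits in column i of the in-batch logit matrix. A minimal NumPy re-implementation of the same arithmetic, with made-up counts and logits:

```python
import numpy as np

def inbatch_softmax_xent_np(logits, item_count, item_idx):
    # Q: sampling probability of each in-batch item under the global frequency distribution
    q = item_count[item_idx] / item_count.sum()
    corrected = logits - np.log(q)[None, :]      # the "subtract_log_q" step
    labels = np.eye(len(item_idx))               # row i's positive is column i
    log_prob = corrected - np.log(np.exp(corrected).sum(axis=1, keepdims=True))
    return -(labels * log_prob).sum(axis=1)      # per-example cross-entropy

item_count = np.array([100.0, 10.0, 1.0])        # hypothetical corpus-wide item counts
logits = np.array([[2.0, 0.5],
                   [0.3, 1.8]])                  # user x in-batch-item scores
print(inbatch_softmax_xent_np(logits, item_count, np.array([0, 2])))
```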
248 | 297 | class EmbeddingIndex(Layer):
249 | 298 |
250 | 299 |     def __init__(self, index, **kwargs):