Skip to content

Commit fdd65e3

Browse files
authored
fix bug &update doc (#77)
- Add Scaling factor gamma in DSSM,FM - Fix routing_logits error in MIND #44 #57 #75
1 parent 4c92549 commit fdd65e3

29 files changed

+360
-293
lines changed

.github/ISSUE_TEMPLATE/bug_report.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@ Steps to reproduce the behavior:
1818
4. See error
1919

2020
**Operating environment(运行环境):**
21-
- python version [e.g. 3.6, 3.7]
22-
- tensorflow version [e.g. 1.4.0, 1.14.0, 2.3.0]
23-
- deepmatch version [e.g. 0.2.0,]
21+
- python version [e.g. 3.6, 3.7, 3.8]
22+
- tensorflow version [e.g. 1.4.0, 1.14.0, 2.5.0]
23+
- deepmatch version [e.g. 0.2.1,]
2424

2525
**Additional context**
2626
Add any other context about the problem here.

.github/ISSUE_TEMPLATE/question.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,6 @@ A clear and concise description of what the question is.
1515
Add any other context about the problem here.
1616

1717
**Operating environment(运行环境):**
18-
- python version [e.g. 3.6]
19-
- tensorflow version [e.g. 1.4.0,]
20-
- deepmatch version [e.g. 0.2.0,]
18+
- python version [e.g. 3.6, 3.7, 3.8]
19+
- tensorflow version [e.g. 1.4.0, 1.14.0, 2.5.0]
20+
- deepmatch version [e.g. 0.2.1,]

.github/workflows/ci.yml

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,25 +17,46 @@ jobs:
1717
timeout-minutes: 120
1818
strategy:
1919
matrix:
20-
python-version: [3.5,3.6,3.7]
21-
tf-version: [1.4.0,1.14.0,2.1.0,2.2.0,2.3.0]
20+
python-version: [3.6,3.7,3.8]
21+
tf-version: [1.4.0,1.14.0,2.5.0]
2222

2323
exclude:
2424
- python-version: 3.7
2525
tf-version: 1.4.0
26+
- python-version: 3.7
27+
tf-version: 1.15.0
28+
- python-version: 3.8
29+
tf-version: 1.4.0
30+
- python-version: 3.8
31+
tf-version: 1.14.0
32+
- python-version: 3.8
33+
tf-version: 1.15.0
34+
- python-version: 3.6
35+
tf-version: 2.7.0
36+
- python-version: 3.6
37+
tf-version: 2.8.0
38+
- python-version: 3.6
39+
tf-version: 2.9.0
40+
- python-version: 3.9
41+
tf-version: 1.4.0
42+
- python-version: 3.9
43+
tf-version: 1.15.0
44+
- python-version: 3.9
45+
tf-version: 2.2.0
2646

2747
steps:
2848

29-
- uses: actions/checkout@v1
49+
- uses: actions/checkout@v3
3050

3151
- name: Setup python environment
32-
uses: actions/setup-python@v1
52+
uses: actions/setup-python@v4
3353
with:
3454
python-version: ${{ matrix.python-version }}
3555

3656
- name: Install dependencies
3757
run: |
3858
pip3 install -q tensorflow==${{ matrix.tf-version }}
59+
pip install -q protobuf==3.19.0
3960
pip install -q requests
4061
pip install -e .
4162
- name: Test with pytest
@@ -46,7 +67,7 @@ jobs:
4667
pip install -q python-coveralls
4768
pytest --cov=deepmatch --cov-report=xml
4869
- name: Upload coverage to Codecov
49-
uses: codecov/codecov-action@v1.0.2
70+
uses: codecov/codecov-action@v3.1.0
5071
with:
5172
token: ${{secrets.CODECOV_TOKEN}}
5273
file: ./coverage.xml

README.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,12 @@ Let's [**Get Started!**](https://deepmatch.readthedocs.io/en/latest/Quick-Start.
4747
<a href="https://github.com/wangzhegeek">Wang Zhe</a> ​
4848
<p>Baidu Inc. </p>​
4949
</td>
50+
<td>
51+
​ <a href="https://github.com/clhchtcjj"><img width="70" height="70" src="https://github.com/clhchtcjj.png?s=40" alt="pic"></a><br>
52+
​ <a href="https://github.com/clhchtcjj">Chen Leihui</a> ​
53+
<p>
54+
Alibaba Group </p>​
55+
</td>
5056
<td>
5157
​ <a href="https://github.com/LeoCai"><img width="70" height="70" src="https://github.com/LeoCai.png?s=40" alt="pic"></a><br>
5258
<a href="https://github.com/LeoCai">LeoCai</a>
@@ -57,6 +63,11 @@ Let's [**Get Started!**](https://deepmatch.readthedocs.io/en/latest/Quick-Start.
5763
​ <a href="https://github.com/yangjieyu">Yang Jieyu</a>
5864
<p> Ant Group </p>​
5965
</td>
66+
<td>
67+
​ <a href="https://github.com/zzszmyf"><img width="70" height="70" src="https://github.com/zzszmyf.png?s=40" alt="pic"></a><br>
68+
​ <a href="https://github.com/zzszmyf">Meng Yifan</a>
69+
<p> DeepCTR </p>​
70+
</td>
6071
</tr>
6172
</tbody>
6273
</table>

deepmatch/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
from .utils import check_version
22

3-
__version__ = '0.2.0'
3+
__version__ = '0.2.1'
44
check_version(__version__)

deepmatch/layers/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from deepctr.layers import custom_objects
22
from deepctr.layers.utils import reduce_sum
33

4-
from .core import PoolingLayer, Similarity, LabelAwareAttention, CapsuleLayer, SampledSoftmaxLayer, EmbeddingIndex
4+
from .core import PoolingLayer, Similarity, LabelAwareAttention, CapsuleLayer, SampledSoftmaxLayer, EmbeddingIndex, \
5+
MaskUserEmbedding
56
from .interaction import DotAttention, ConcatAttention, SoftmaxWeightedSum, AttentionSequencePoolingLayer, \
67
SelfAttention, \
78
SelfMultiHeadAttention, UserAttention
@@ -23,7 +24,8 @@
2324
'SelfAttention': SelfAttention,
2425
'SelfMultiHeadAttention': SelfMultiHeadAttention,
2526
'UserAttention': UserAttention,
26-
'DynamicMultiRNN': DynamicMultiRNN
27+
'DynamicMultiRNN': DynamicMultiRNN,
28+
'MaskUserEmbedding': MaskUserEmbedding
2729
}
2830

2931
custom_objects = dict(custom_objects, **_custom_objects)

deepmatch/layers/core.py

Lines changed: 59 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import tensorflow as tf
22
from deepctr.layers.activation import activation_layer
33
from deepctr.layers.utils import reduce_max, reduce_mean, reduce_sum, concat_func, div, softmax
4-
from tensorflow.python.keras.initializers import RandomNormal, Zeros, glorot_normal
4+
from tensorflow.python.keras.initializers import RandomNormal, Zeros, TruncatedNormal
55
from tensorflow.python.keras.layers import Layer
66
from tensorflow.python.keras.regularizers import l2
77

@@ -103,19 +103,19 @@ def call(self, inputs, training=None, **kwargs):
103103
weight = tf.pow(weight, self.pow_p) # [x,k_max,1]
104104

105105
if len(inputs) == 3:
106-
k_user = tf.cast(tf.maximum(
107-
1.,
108-
tf.minimum(
109-
tf.cast(self.k_max, dtype="float32"), # k_max
110-
tf.log1p(tf.cast(inputs[2], dtype="float32")) / tf.log(2.) # hist_len
111-
)
112-
), dtype="int64")
106+
k_user = inputs[2]
113107
seq_mask = tf.transpose(tf.sequence_mask(k_user, self.k_max), [0, 2, 1])
114108
padding = tf.ones_like(seq_mask, dtype=tf.float32) * (-2 ** 32 + 1) # [x,k_max,1]
115109
weight = tf.where(seq_mask, weight, padding)
116110

117-
weight = softmax(weight, dim=1, name="weight")
118-
output = reduce_sum(keys * weight, axis=1)
111+
if self.pow_p >= 100:
112+
idx = tf.stack(
113+
[tf.range(tf.shape(keys)[0]), tf.squeeze(tf.argmax(weight, axis=1, output_type=tf.int32), axis=1)],
114+
axis=1)
115+
output = tf.gather_nd(keys, idx)
116+
else:
117+
weight = softmax(weight, dim=1, name="weight")
118+
output = tf.reduce_sum(keys * weight, axis=1)
119119

120120
return output
121121

@@ -159,6 +159,7 @@ def get_config(self, ):
159159
base_config = super(Similarity, self).get_config()
160160
return dict(list(base_config.items()) + list(config.items()))
161161

162+
162163
class CapsuleLayer(Layer):
163164
def __init__(self, input_units, out_units, max_len, k_max, iteration_times=3,
164165
init_std=1.0, **kwargs):
@@ -171,32 +172,28 @@ def __init__(self, input_units, out_units, max_len, k_max, iteration_times=3,
171172
super(CapsuleLayer, self).__init__(**kwargs)
172173

173174
def build(self, input_shape):
174-
self.routing_logits = self.add_weight(shape=[self.max_len, self.k_max, 1],
175-
initializer=TruncatedNormal(stddev=self.init_std),
176-
trainable=False, name="B", dtype=tf.float32)
177-
# N,T,k_max,1
178175
self.bilinear_mapping_matrix = self.add_weight(shape=[self.input_units, self.out_units],
179176
name="S", dtype=tf.float32)
180177
super(CapsuleLayer, self).build(input_shape)
181178

182179
def call(self, inputs, **kwargs):
183180

184-
behavior_embddings = inputs[0]
181+
behavior_embedding = inputs[0]
185182
seq_len = inputs[1]
186-
batch_size = tf.shape(behavior_embddings)[0]
183+
batch_size = tf.shape(behavior_embedding)[0]
187184

188185
mask = tf.reshape(tf.sequence_mask(seq_len, self.max_len, tf.float32), [-1, self.max_len, 1, 1])
189186

190-
behavior_embdding_mapping = tf.matmul(behavior_embddings, self.bilinear_mapping_matrix)
191-
behavior_embdding_mapping = tf.expand_dims(behavior_embdding_mapping, axis=2)
187+
behavior_embedding_mapping = tf.tensordot(behavior_embedding, self.bilinear_mapping_matrix, axes=1)
188+
behavior_embedding_mapping = tf.expand_dims(behavior_embedding_mapping, axis=2)
192189

193-
behavior_embdding_mapping_ = tf.stop_gradient(behavior_embdding_mapping) # N,max_len,1,E
194-
print(behavior_embdding_mapping_)
190+
behavior_embdding_mapping_ = tf.stop_gradient(behavior_embedding_mapping) # N,max_len,1,E
195191
try:
196192
routing_logits = tf.truncated_normal([batch_size, self.max_len, self.k_max, 1], stddev=self.init_std)
197193
except AttributeError:
198194
routing_logits = tf.compat.v1.truncated_normal([batch_size, self.max_len, self.k_max, 1],
199195
stddev=self.init_std)
196+
routing_logits = tf.stop_gradient(routing_logits)
200197

201198
k_user = None
202199
if len(inputs) == 3:
@@ -208,32 +205,42 @@ def call(self, inputs, **kwargs):
208205
interest_padding = tf.ones_like(interest_mask) * -2 ** 31
209206
interest_mask = tf.cast(interest_mask, tf.bool)
210207

211-
routing_logits = tf.stop_gradient(routing_logits)
212-
self.routing_logits = routing_logits # N,max_len,k_max,1
213-
print(self.routing_logits)
214208
for i in range(self.iteration_times):
215209
if k_user is not None:
216-
self.routing_logits = tf.where(interest_mask, self.routing_logits, interest_padding)
217-
weight = tf.nn.softmax(self.routing_logits, 2) * mask # N,max_len,k_max,1
210+
routing_logits = tf.where(interest_mask, routing_logits, interest_padding)
211+
try:
212+
weight = softmax(routing_logits, 2) * mask
213+
except TypeError:
214+
weight = tf.transpose(softmax(tf.transpose(routing_logits, [0, 1, 3, 2])),
215+
[0, 1, 3, 2]) * mask # N,max_len,k_max,1
218216
if i < self.iteration_times - 1:
219217
Z = reduce_sum(tf.matmul(weight, behavior_embdding_mapping_), axis=1, keep_dims=True) # N,1,k_max,E
220218
interest_capsules = squash(Z)
221219
delta_routing_logits = reduce_sum(
222220
interest_capsules * behavior_embdding_mapping_,
223221
axis=-1, keep_dims=True
224222
)
225-
self.routing_logits += delta_routing_logits
223+
routing_logits += delta_routing_logits
226224
else:
227-
Z = reduce_sum(tf.matmul(weight, behavior_embdding_mapping), axis=1, keep_dims=True)
225+
Z = reduce_sum(tf.matmul(weight, behavior_embedding_mapping), axis=1, keep_dims=True)
228226
interest_capsules = squash(Z)
229227

230228
interest_capsules = tf.reshape(interest_capsules, [-1, self.k_max, self.out_units])
231229
return interest_capsules
232230

231+
def compute_output_shape(self, input_shape):
232+
return (None, self.k_max, self.out_units)
233+
234+
def get_config(self, ):
235+
config = {'input_units': self.input_units, 'out_units': self.out_units, 'max_len': self.max_len,
236+
'k_max': self.k_max, 'iteration_times': self.iteration_times, "init_std": self.init_std}
237+
base_config = super(CapsuleLayer, self).get_config()
238+
return dict(list(base_config.items()) + list(config.items()))
239+
233240

234241
def squash(inputs):
235242
vec_squared_norm = reduce_sum(tf.square(inputs), axis=-1, keep_dims=True)
236-
scalar_factor = vec_squared_norm / (1 + vec_squared_norm) / tf.sqrt(vec_squared_norm + 1e-8)
243+
scalar_factor = vec_squared_norm / (1 + vec_squared_norm) / tf.sqrt(vec_squared_norm + 1e-9)
237244
vec_squashed = scalar_factor * inputs
238245
return vec_squashed
239246

@@ -255,3 +262,27 @@ def get_config(self, ):
255262
config = {'index': self.index, }
256263
base_config = super(EmbeddingIndex, self).get_config()
257264
return dict(list(base_config.items()) + list(config.items()))
265+
266+
267+
class MaskUserEmbedding(Layer):
268+
269+
def __init__(self, k_max, **kwargs):
270+
self.k_max = k_max
271+
super(MaskUserEmbedding, self).__init__(**kwargs)
272+
273+
def build(self, input_shape):
274+
super(MaskUserEmbedding, self).build(
275+
input_shape) # Be sure to call this somewhere!
276+
277+
def call(self, x, training=None, **kwargs):
278+
user_embedding, interest_num = x
279+
if not training:
280+
interest_mask = tf.sequence_mask(interest_num, self.k_max, tf.float32)
281+
interest_mask = tf.reshape(interest_mask, [-1, self.k_max, 1])
282+
user_embedding *= interest_mask
283+
return user_embedding
284+
285+
def get_config(self, ):
286+
config = {'k_max': self.k_max, }
287+
base_config = super(MaskUserEmbedding, self).get_config()
288+
return dict(list(base_config.items()) + list(config.items()))

deepmatch/models/dssm.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515

1616
def DSSM(user_feature_columns, item_feature_columns, user_dnn_hidden_units=(64, 32),
1717
item_dnn_hidden_units=(64, 32),
18-
dnn_activation='tanh', dnn_use_bn=False,
19-
l2_reg_dnn=0, l2_reg_embedding=1e-6, dnn_dropout=0, seed=1024, metric='cos'):
18+
dnn_activation='relu', dnn_use_bn=False,
19+
l2_reg_dnn=0, l2_reg_embedding=1e-6, dnn_dropout=0, gamma=10, seed=1024, metric='cos'):
2020
"""Instantiates the Deep Structured Semantic Model architecture.
2121
2222
:param user_feature_columns: An iterable containing user's features used by the model.
@@ -28,6 +28,7 @@ def DSSM(user_feature_columns, item_feature_columns, user_dnn_hidden_units=(64,
2828
:param l2_reg_dnn: float. L2 regularizer strength applied to DNN
2929
:param l2_reg_embedding: float. L2 regularizer strength applied to embedding vector
3030
:param dnn_dropout: float in [0,1), the probability we will drop out a given DNN coordinate.
31+
:param gamma: float. Scaling factor.
3132
:param seed: integer ,to use as random seed.
3233
:param metric: str, ``"cos"`` for cosine or ``"ip"`` for inner product
3334
:return: A Keras model instance.
@@ -55,12 +56,12 @@ def DSSM(user_feature_columns, item_feature_columns, user_dnn_hidden_units=(64,
5556
item_dnn_input = combined_dnn_input(item_sparse_embedding_list, item_dense_value_list)
5657

5758
user_dnn_out = DNN(user_dnn_hidden_units, dnn_activation, l2_reg_dnn, dnn_dropout,
58-
dnn_use_bn, seed=seed)(user_dnn_input)
59+
dnn_use_bn, output_activation='linear', seed=seed)(user_dnn_input)
5960

6061
item_dnn_out = DNN(item_dnn_hidden_units, dnn_activation, l2_reg_dnn, dnn_dropout,
61-
dnn_use_bn, seed=seed)(item_dnn_input)
62+
dnn_use_bn, output_activation='linear', seed=seed)(item_dnn_input)
6263

63-
score = Similarity(type=metric, gamma = 10)([user_dnn_out, item_dnn_out])
64+
score = Similarity(type=metric, gamma=gamma)([user_dnn_out, item_dnn_out])
6465

6566
output = PredictionLayer("binary", False)(score)
6667

deepmatch/models/fm.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,13 @@
88
from ..layers.core import Similarity
99

1010

11-
def FM(user_feature_columns, item_feature_columns, l2_reg_embedding=1e-6, seed=1024, metric='cos'):
11+
def FM(user_feature_columns, item_feature_columns, l2_reg_embedding=1e-6, gamma=10, seed=1024, metric='cos'):
1212
"""Instantiates the FM architecture.
1313
1414
:param user_feature_columns: An iterable containing user's features used by the model.
1515
:param item_feature_columns: An iterable containing item's features used by the model.
1616
:param l2_reg_embedding: float. L2 regularizer strength applied to embedding vector
17+
:param gamma: float. Scaling factor.
1718
:param seed: integer ,to use as random seed.
1819
:param metric: str, ``"cos"`` for cosine or ``"ip"`` for inner product
1920
:return: A Keras model instance.
@@ -46,7 +47,7 @@ def FM(user_feature_columns, item_feature_columns, l2_reg_embedding=1e-6, seed=1
4647
item_dnn_input = concat_func(item_sparse_embedding_list, axis=1)
4748
item_vector_sum = Lambda(lambda x: reduce_sum(x, axis=1, keep_dims=False))(item_dnn_input)
4849

49-
score = Similarity(type=metric)([user_vector_sum, item_vector_sum])
50+
score = Similarity(type=metric, gamma=gamma)([user_vector_sum, item_vector_sum])
5051

5152
output = PredictionLayer("binary", False)(score)
5253

0 commit comments

Comments
 (0)