Skip to content

Commit 2cd84f3

Browse files
authored
fix some bugs
- Fix model save error when using acc metric model save issue - Fix model error when all feature columns are dense
2 parents b4d8181 + 300f115 commit 2cd84f3

File tree

14 files changed

+114
-96
lines changed

14 files changed

+114
-96
lines changed

.github/workflows/ci.yml

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,25 @@ jobs:
1717
timeout-minutes: 120
1818
strategy:
1919
matrix:
20-
python-version: [3.6,3.7]
21-
torch-version: [1.1.0,1.2.0,1.3.0,1.4.0,1.5.0,1.6.0,1.7.0,1.8.1]
20+
python-version: [3.6,3.7,3.8]
21+
torch-version: [1.1.0,1.2.0,1.3.0,1.4.0,1.5.0,1.6.0,1.7.1,1.8.1,1.9.0,1.10.2,1.11.0]
2222

23-
# exclude:
24-
# - python-version: 3.5
25-
# tf-version: 1.1.0
23+
exclude:
24+
- python-version: 3.6
25+
torch-version: 1.11.0
26+
- python-version: 3.8
27+
torch-version: 1.1.0
28+
- python-version: 3.8
29+
torch-version: 1.2.0
30+
- python-version: 3.8
31+
torch-version: 1.3.0
2632

2733
steps:
2834

29-
- uses: actions/checkout@v1
35+
- uses: actions/checkout@v3
3036

3137
- name: Setup python environment
32-
uses: actions/setup-python@v1
38+
uses: actions/setup-python@v4
3339
with:
3440
python-version: ${{ matrix.python-version }}
3541

@@ -47,7 +53,7 @@ jobs:
4753
pip install -q sklearn
4854
pytest --cov=deepctr_torch --cov-report=xml
4955
- name: Upload coverage to Codecov
50-
uses: codecov/codecov-action@v1.0.2
56+
uses: codecov/codecov-action@v3.1.0
5157
with:
5258
token: ${{secrets.CODECOV_TOKEN}}
5359
file: ./coverage.xml

README.md

Lines changed: 21 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -47,34 +47,19 @@ Let's [**Get Started!**](https://deepctr-torch.readthedocs.io/en/latest/Quick-St
4747

4848
## DisscussionGroup & Related Projects
4949

50-
<html>
51-
<table style="margin-left: 20px; margin-right: auto;">
52-
<tr>
53-
<td>
54-
公众号:<b>浅梦学习笔记</b><br><br>
55-
<a href="https://github.com/shenweichen/deepctr-torch">
56-
<img align="center" src="./docs/pics/code.png" />
57-
</a>
58-
</td>
59-
<td>
60-
微信:<b>deepctrbot</b><br><br>
61-
<a href="https://github.com/shenweichen/deepctr-torch">
62-
<img align="center" src="./docs/pics/deepctrbot.png" />
63-
</a>
64-
</td>
65-
<td>
66-
<ul>
67-
<li><a href="https://github.com/shenweichen/AlgoNotes">AlgoNotes</a></li>
68-
<li><a href="https://github.com/shenweichen/DeepCTR">DeepCTR</a></li>
69-
<li><a href="https://github.com/shenweichen/DeepMatch">DeepMatch</a></li>
70-
<li><a href="https://github.com/shenweichen/GraphEmbedding">GraphEmbedding</a></li>
71-
</ul>
72-
</td>
73-
</tr>
74-
</table>
75-
</html>
50+
- [Github Discussions](https://github.com/shenweichen/DeepCTR/discussions)
51+
- Wechat Discussions
7652

53+
|公众号:浅梦学习笔记|微信:deepctrbot|学习小组 [加入](https://t.zsxq.com/026UJEuzv) [主题集合](https://mp.weixin.qq.com/mp/appmsgalbum?__biz=MjM5MzY4NzE3MA==&action=getalbum&album_id=1361647041096843265&scene=126#wechat_redirect)|
54+
|:--:|:--:|:--:|
55+
| [![公众号](./docs/pics/code.png)](https://github.com/shenweichen/AlgoNotes)| [![微信](./docs/pics/deepctrbot.png)](https://github.com/shenweichen/AlgoNotes)|[![学习小组](./docs/pics/planet_github.png)](https://t.zsxq.com/026UJEuzv)|
7756

57+
- Related Projects
58+
59+
- [AlgoNotes](https://github.com/shenweichen/AlgoNotes)
60+
- [DeepCTR](https://github.com/shenweichen/DeepCTR)
61+
- [DeepMatch](https://github.com/shenweichen/DeepMatch)
62+
- [GraphEmbedding](https://github.com/shenweichen/GraphEmbedding)
7863

7964
## Main Contributors([welcome to join us!](./CONTRIBUTING.md))
8065

@@ -84,59 +69,58 @@ Let's [**Get Started!**](https://deepctr-torch.readthedocs.io/en/latest/Quick-St
8469
<td>
8570
​ <a href="https://github.com/shenweichen"><img width="70" height="70" src="https://github.com/shenweichen.png?s=40" alt="pic"></a><br>
8671
​ <a href="https://github.com/shenweichen">Shen Weichen</a> ​
87-
<p>Core Dev<br> Zhejiang Unversity <br> <br> </p>​
72+
<p> Alibaba Group </p>​
8873
</td>
8974
<td>
9075
​ <a href="https://github.com/zanshuxun"><img width="70" height="70" src="https://github.com/zanshuxun.png?s=40" alt="pic"></a><br>
9176
​ <a href="https://github.com/zanshuxun">Zan Shuxun</a>
92-
<p>Core Dev<br> Beijing University <br> of Posts and <br> Telecommunications</p>​
77+
<p> Alibaba Group </p>​
9378
</td>
9479
<td>
9580
<a href="https://github.com/weberrr"><img width="70" height="70" src="https://github.com/weberrr.png?s=40" alt="pic"></a><br>
9681
<a href="https://github.com/weberrr">Wang Ze</a> ​
97-
<p>Core Dev<br> Beihang University <br> <br> </p>​
82+
<p> Meituan </p>​
9883
</td>
9984
<td>
10085
​ <a href="https://github.com/wutongzhang"><img width="70" height="70" src="https://github.com/wutongzhang.png?s=40" alt="pic"></a><br>
10186
<a href="https://github.com/wutongzhang">Zhang Wutong</a>
102-
<p>Core Dev<br> Beijing University <br> of Posts and <br> Telecommunications</p>​
87+
<p> Tencent </p>​
10388
</td>
10489
<td>
10590
​ <a href="https://github.com/ZhangYuef"><img width="70" height="70" src="https://github.com/ZhangYuef.png?s=40" alt="pic"></a><br>
10691
​ <a href="https://github.com/ZhangYuef">Zhang Yuefeng</a>
107-
<p>Core Dev<br>
108-
Peking University <br> <br> </p>​
92+
<p> Peking University </p>​
10993
</td>
11094
</tr>
11195
<tr align="center">
11296
<td>
11397
​ <a href="https://github.com/JyiHUO"><img width="70" height="70" src="https://github.com/JyiHUO.png?s=40" alt="pic"></a><br>
11498
​ <a href="https://github.com/JyiHUO">Huo Junyi</a>
115-
<p>Core Dev<br>
99+
<p>
116100
University of Southampton <br> <br> </p>​
117101
</td>
118102
<td>
119103
​ <a href="https://github.com/Zengai"><img width="70" height="70" src="https://github.com/Zengai.png?s=40" alt="pic"></a><br>
120104
​ <a href="https://github.com/Zengai">Zeng Kai</a> ​
121-
<p>Dev<br>
105+
<p>
122106
SenseTime <br> <br> </p>​
123107
</td>
124108
<td>
125109
​ <a href="https://github.com/chenkkkk"><img width="70" height="70" src="https://github.com/chenkkkk.png?s=40" alt="pic"></a><br>
126110
​ <a href="https://github.com/chenkkkk">Chen K</a> ​
127-
<p>Dev<br>
111+
<p>
128112
NetEase <br> <br> </p>​
129113
</td>
130114
<td>
131115
​ <a href="https://github.com/WeiyuCheng"><img width="70" height="70" src="https://github.com/WeiyuCheng.png?s=40" alt="pic"></a><br>
132116
​ <a href="https://github.com/WeiyuCheng">Cheng Weiyu</a> ​
133-
<p>Dev<br>
117+
<p>
134118
Shanghai Jiao Tong University</p>​
135119
</td>
136120
<td>
137121
​ <a href="https://github.com/tangaqi"><img width="70" height="70" src="https://github.com/tangaqi.png?s=40" alt="pic"></a><br>
138122
​ <a href="https://github.com/tangaqi">Tang</a>
139-
<p>Test<br>
123+
<p>
140124
Tongji University <br> <br> </p>​
141125
</td>
142126
</tr>

deepctr_torch/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,5 @@
22
from . import models
33
from .utils import check_version
44

5-
__version__ = '0.2.7'
5+
__version__ = '0.2.8'
66
check_version(__version__)

deepctr_torch/models/basemodel.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
44
Author:
55
Weichen Shen,[email protected]
6+
zanshuxun, [email protected]
67
78
"""
89
from __future__ import print_function
@@ -75,7 +76,7 @@ def forward(self, X, sparse_feat_refine_weight=None):
7576

7677
sparse_embedding_list += varlen_embedding_list
7778

78-
linear_logit = torch.zeros([X.shape[0], 1]).to(sparse_embedding_list[0].device)
79+
linear_logit = torch.zeros([X.shape[0], 1]).to(self.device)
7980
if len(sparse_embedding_list) > 0:
8081
sparse_embedding_cat = torch.cat(sparse_embedding_list, dim=-1)
8182
if sparse_feat_refine_weight is not None:
@@ -476,6 +477,10 @@ def _log_loss(self, y_true, y_pred, eps=1e-7, normalize=True, sample_weight=None
476477
sample_weight,
477478
labels)
478479

480+
@staticmethod
481+
def _accuracy_score(y_true, y_pred):
482+
return accuracy_score(y_true, np.where(y_pred > 0.5, 1, 0))
483+
479484
def _get_metrics(self, metrics, set_eps=False):
480485
metrics_ = {}
481486
if metrics:
@@ -490,8 +495,7 @@ def _get_metrics(self, metrics, set_eps=False):
490495
if metric == "mse":
491496
metrics_[metric] = mean_squared_error
492497
if metric == "accuracy" or metric == "acc":
493-
metrics_[metric] = lambda y_true, y_pred: accuracy_score(
494-
y_true, np.where(y_pred > 0.5, 1, 0))
498+
metrics_[metric] = self._accuracy_score
495499
self.metrics_names.append(metric)
496500
return metrics_
497501

docs/pics/code2.jpg

52.2 KB
Loading

docs/pics/planet_github.png

8.11 KB
Loading

docs/requirements.readthedocs.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
Cython>=0.28.5
2-
tensorflow==1.15.4
2+
tensorflow==2.7.2
3+
scikit-learn==1.0

docs/source/FAQ.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ To save/load weights:
66

77
```python
88
import torch
9-
model = DeepFM()
9+
model = DeepFM(...)
1010
torch.save(model.state_dict(), 'DeepFM_weights.h5')
1111
model.load_state_dict(torch.load('DeepFM_weights.h5'))
1212
```
@@ -15,7 +15,7 @@ To save/load models:
1515

1616
```python
1717
import torch
18-
model = DeepFM()
18+
model = DeepFM(...)
1919
torch.save(model, 'DeepFM.h5')
2020
model = torch.load('DeepFM.h5')
2121
```

docs/source/History.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# History
2-
- 06/14/2021 : [v0.2.7](https://github.com/shenweichen/DeepCTR-Torch/releases/tag/v0.2.6) released.Add [AFN](./Features.html#afn-adaptive-factorization-network-learning-adaptive-order-feature-interactions) and fix some bugs.
2+
- 06/19/2022 : [v0.2.8](https://github.com/shenweichen/DeepCTR-Torch/releases/tag/v0.2.8) released.Fix some bugs.
3+
- 06/14/2021 : [v0.2.7](https://github.com/shenweichen/DeepCTR-Torch/releases/tag/v0.2.7) released.Add [AFN](./Features.html#afn-adaptive-factorization-network-learning-adaptive-order-feature-interactions) and fix some bugs.
34
- 04/04/2021 : [v0.2.6](https://github.com/shenweichen/DeepCTR-Torch/releases/tag/v0.2.6) released.Add [IFM](./Features.html#ifm-input-aware-factorization-machine) and [DIFM](./Features.html#difm-dual-input-aware-factorization-machine);Support multi-gpus running([example](./FAQ.html#how-to-run-the-demo-with-multiple-gpus)).
45
- 02/12/2021 : [v0.2.5](https://github.com/shenweichen/DeepCTR-Torch/releases/tag/v0.2.5) released.Fix bug in DCN-M.
56
- 12/05/2020 : [v0.2.4](https://github.com/shenweichen/DeepCTR-Torch/releases/tag/v0.2.4) released.Imporve compatibility & fix issues.Add History callback.([example](https://deepctr-torch.readthedocs.io/en/latest/FAQ.html#set-learning-rate-and-use-earlystopping)).

docs/source/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
# The short X.Y version
2727
version = ''
2828
# The full version, including alpha/beta/rc tags
29-
release = '0.2.7'
29+
release = '0.2.8'
3030

3131

3232
# -- General configuration ---------------------------------------------------

docs/source/index.rst

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,18 +34,21 @@ You can read the latest code at https://github.com/shenweichen/DeepCTR-Torch and
3434

3535
News
3636
-----
37+
06/19/2022 : Fix some bugs. `Changelog <https://github.com/shenweichen/DeepCTR-Torch/releases/tag/v0.2.8>`_
38+
3739
06/14/2021 : Add `AFN <./Features.html#afn-adaptive-factorization-network-learning-adaptive-order-feature-interactions>`_ and fix some bugs. `Changelog <https://github.com/shenweichen/DeepCTR-Torch/releases/tag/v0.2.7>`_
3840

3941
04/04/2021 : Add `IFM <./Features.html#ifm-input-aware-factorization-machine>`_ and `DIFM <./Features.html#difm-dual-input-aware-factorization-machine>`_ . Support multi-gpus running(`example <./FAQ.html#how-to-run-the-demo-with-multiple-gpus>`_). `Changelog <https://github.com/shenweichen/DeepCTR-Torch/releases/tag/v0.2.6>`_
4042

41-
02/12/2021 : Fix bug in DCN-M. `Changelog <https://github.com/shenweichen/DeepCTR-Torch/releases/tag/v0.2.4>`_
4243

4344
DisscussionGroup
4445
-----------------------
4546

46-
公众号:**浅梦学习笔记** wechat ID: **deepctrbot**
47+
公众号:**浅梦学习笔记** wechat ID: **deepctrbot**
48+
49+
`Discussions <https://github.com/shenweichen/DeepCTR/discussions>`_ `学习小组主题集合 <https://mp.weixin.qq.com/mp/appmsgalbum?__biz=MjM5MzY4NzE3MA==&action=getalbum&album_id=1361647041096843265&scene=126#wechat_redirect>`_
4750

48-
.. image:: ../pics/code.png
51+
.. image:: ../pics/code2.jpg
4952

5053
.. toctree::
5154
:maxdepth: 2

setup.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44
long_description = fh.read()
55

66
REQUIRED_PACKAGES = [
7-
'torch>=1.1.0', 'tqdm', 'sklearn', 'tensorflow'
7+
'torch>=1.1.0', 'tqdm', 'scikit-learn', 'tensorflow'
88
]
99

1010
setuptools.setup(
1111
name="deepctr-torch",
12-
version="0.2.7",
12+
version="0.2.8",
1313
author="Weichen Shen",
1414
author_email="[email protected]",
1515
description="Easy-to-use,Modular and Extendible package of deep learning based CTR(Click Through Rate) prediction models with PyTorch",
@@ -37,6 +37,7 @@
3737
'Programming Language :: Python :: 3.5',
3838
'Programming Language :: Python :: 3.6',
3939
'Programming Language :: Python :: 3.7',
40+
'Programming Language :: Python :: 3.8',
4041
'Topic :: Scientific/Engineering',
4142
'Topic :: Scientific/Engineering :: Artificial Intelligence',
4243
'Topic :: Software Development',

tests/models/DeepFM_test.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,33 @@
66

77

88
@pytest.mark.parametrize(
9-
'use_fm,hidden_size,sparse_feature_num',
10-
[(True, (32,), 3),
11-
(False, (32,), 3),
12-
(False, (32,), 2), (False, (32,), 1), (True, (), 1), (False, (), 2)
9+
'use_fm,hidden_size,sparse_feature_num,dense_feature_num',
10+
[(True, (32,), 3, 3),
11+
(False, (32,), 3, 3),
12+
(False, (32,), 2, 2),
13+
(False, (32,), 1, 1),
14+
(True, (), 1, 1),
15+
(False, (), 2, 2),
16+
(True, (32,), 0, 3),
17+
(True, (32,), 3, 0),
18+
(False, (32,), 0, 3),
19+
(False, (32,), 3, 0),
1320
]
1421
)
15-
def test_DeepFM(use_fm, hidden_size, sparse_feature_num):
22+
def test_DeepFM(use_fm, hidden_size, sparse_feature_num, dense_feature_num):
1623
model_name = "DeepFM"
1724
sample_size = SAMPLE_SIZE
1825
x, y, feature_columns = get_test_data(
19-
sample_size, sparse_feature_num=sparse_feature_num, dense_feature_num=sparse_feature_num)
26+
sample_size, sparse_feature_num=sparse_feature_num, dense_feature_num=dense_feature_num)
2027

2128
model = DeepFM(feature_columns, feature_columns, use_fm=use_fm,
2229
dnn_hidden_units=hidden_size, dnn_dropout=0.5, device=get_device())
2330
check_model(model, model_name, x, y)
2431

32+
# no linear part
33+
model = DeepFM([], feature_columns, use_fm=use_fm,
34+
dnn_hidden_units=hidden_size, dnn_dropout=0.5, device=get_device())
35+
check_model(model, model_name + '_no_linear', x, y)
36+
2537
if __name__ == "__main__":
2638
pass

0 commit comments

Comments
 (0)