
Commit 75b3c12

[bugfix] Backport int32 fix for sampled block and bump version (#1109)

* [BugFix] support int32 for sampled blocks (#1106)
* support int32 for sampled blocks
* Fix lint
* Update version
1 parent 434187e commit 75b3c12
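
For context, a minimal sketch of what this backport enables, modeled on the updated unit tests further down: sampled candidates and expected counts supplied as int32 instead of float32, which previously could trigger dtype mismatches inside the sampled blocks. The one-valued inputs are illustrative, and the NCEDense constructor order (num_classes, num_sampled, in_unit) is an assumption here, not shown in this diff:

import mxnet as mx
import gluonnlp as nlp

batch_size, num_sampled, num_hidden, num_classes = 2, 3, 5, 10
# constructor order assumed: (num_classes, num_sampled, in_unit)
model = nlp.model.NCEDense(num_classes, num_sampled, num_hidden)
model.initialize()
x = mx.nd.ones((batch_size, num_hidden))
y = mx.nd.ones((batch_size,))
# int32 ids and counts, as a candidate sampler may produce them
sampled_cls = mx.nd.ones((num_sampled,), dtype='int32')
sampled_cls_cnt = mx.nd.ones((num_sampled,), dtype='int32')
true_cls_cnt = mx.nd.ones((batch_size,), dtype='int32')
pred, new_y = model(x, (sampled_cls, sampled_cls_cnt, true_cls_cnt), y)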

File tree: 3 files changed (+21 -17)


src/gluonnlp/__init__.py (+1 -1)

@@ -31,7 +31,7 @@
 from . import initializer
 from .vocab import Vocab
 
-__version__ = '0.8.2'
+__version__ = '0.8.3'
 
 __all__ = ['data',
            'model',
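
After installing the bumped package, the change can be verified from a Python session (illustrative):

>>> import gluonnlp
>>> gluonnlp.__version__
'0.8.3'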

src/gluonnlp/model/sampled_block.py (+8 -8)

@@ -70,15 +70,15 @@ def hybrid_forward(self, F, x, sampled_values, label, w_all, b_all):
 
         # remove accidental hits
         if self._remove_accidental_hits:
-            label_vec = F.reshape(label, (-1, 1))
-            sample_vec = F.reshape(sampled_candidates, (1, -1))
-            mask = F.broadcast_equal(label_vec, sample_vec) * -1e37
+            label_vec = F.reshape(label, (-1, 1)).astype('int32')
+            sample_vec = F.reshape(sampled_candidates, (1, -1)).astype('int32')
+            mask = F.broadcast_equal(label_vec, sample_vec).astype('float32') * -1e37
             pred_sampled = pred_sampled + mask
 
         # subtract log(q)
-        expected_count_sampled = F.reshape(expected_count_sampled,
-                                           shape=(1, self._num_sampled))
-        expected_count_true = expected_count_true.reshape((-1,))
+        expected_count_sampled = expected_count_sampled.astype('float32')
+        expected_count_sampled = expected_count_sampled.reshape(shape=(1, self._num_sampled))
+        expected_count_true = expected_count_true.astype('float32').reshape((-1,))
         pred_true = pred_true - F.log(expected_count_true)
         pred_true = pred_true.reshape((-1, 1))
         pred_sampled = F.broadcast_sub(pred_sampled, F.log(expected_count_sampled))
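
The first hunk is easiest to see in isolation. A standalone sketch of the accidental-hit mask with the new casts, using made-up label and candidate ids: the equality test now happens on exact int32 ids, and the result is cast back to float32 so the -1e37 penalty is representable (float32 tops out near 3.4e38):

import mxnet as mx

label_vec = mx.nd.array([[1], [2]], dtype='int32')    # (batch_size, 1)
sample_vec = mx.nd.array([[2, 3, 1]], dtype='int32')  # (1, num_sampled)
# compare ids exactly in int32, then move to float32 for the penalty
mask = mx.nd.broadcast_equal(label_vec, sample_vec).astype('float32') * -1e37
# mask[0, 2] and mask[1, 0] are -1e37 (label coincides with a sampled id);
# every other entry is 0, leaving those logits untouched
print(mask.asnumpy())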
@@ -174,7 +174,7 @@ def hybrid_forward(self, F, x, sampled_values, label, weight, bias):
         # (batch_size,)
         label = F.reshape(label, shape=(-1,))
         # (num_sampled+batch_size,)
-        ids = F.concat(sampled_candidates, label, dim=0)
+        ids = F.concat(sampled_candidates.astype('int32'), label.astype('int32'), dim=0)
         # lookup weights and biases
         # (num_sampled+batch_size, dim)
         w_all = F.Embedding(data=ids, weight=weight,

@@ -477,7 +477,7 @@ def forward(self, x, sampled_values, label): # pylint: disable=arguments-differ
         # (batch_size,)
         label = label.reshape(shape=(-1,))
         # (num_sampled+batch_size,)
-        ids = nd.concat(sampled_candidates, label, dim=0)
+        ids = nd.concat(sampled_candidates.astype('int32'), label.astype('int32'), dim=0)
         # lookup weights and biases
         weight = self.weight.row_sparse_data(ids)
         bias = self.bias.data(ids.context)
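
Both remaining hunks apply the same pattern to the id concatenation. A standalone sketch with made-up ids; the float32 starting dtype on one side is just one way the mismatch can arise:

import mxnet as mx

sampled_candidates = mx.nd.array([4, 7, 9])   # defaults to float32
label = mx.nd.array([1, 2], dtype='int32')
# casting both operands to int32 avoids a dtype mismatch in concat and
# yields integer ids for the embedding / row-sparse weight lookup
ids = mx.nd.concat(sampled_candidates.astype('int32'), label.astype('int32'), dim=0)
print(ids.asnumpy())  # [4 7 9 1 2]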

tests/unittest/test_sampled_logits.py (+12 -8)

@@ -27,7 +27,9 @@
 import pytest
 
 @pytest.mark.parametrize('f', [nlp.model.NCEDense, nlp.model.SparseNCEDense])
-def test_nce_loss(f):
+@pytest.mark.parametrize('cls_dtype', ['float32', 'int32'])
+@pytest.mark.parametrize('count_dtype', ['float32', 'int32'])
+def test_nce_loss(f, cls_dtype, count_dtype):
     ctx = mx.cpu()
     batch_size = 2
     num_sampled = 3

@@ -40,9 +42,9 @@ def test_nce_loss(f):
     trainer = mx.gluon.Trainer(model.collect_params(), 'sgd')
     x = mx.nd.ones((batch_size, num_hidden))
     y = mx.nd.ones((batch_size,))
-    sampled_cls = mx.nd.ones((num_sampled,))
-    sampled_cls_cnt = mx.nd.ones((num_sampled,))
-    true_cls_cnt = mx.nd.ones((batch_size,))
+    sampled_cls = mx.nd.ones((num_sampled,), dtype=cls_dtype)
+    sampled_cls_cnt = mx.nd.ones((num_sampled,), dtype=count_dtype)
+    true_cls_cnt = mx.nd.ones((batch_size,), dtype=count_dtype)
     samples = (sampled_cls, sampled_cls_cnt, true_cls_cnt)
     with mx.autograd.record():
         pred, new_y = model(x, samples, y)

@@ -53,7 +55,9 @@ def test_nce_loss(f):
     mx.nd.waitall()
 
 @pytest.mark.parametrize('f', [nlp.model.ISDense, nlp.model.SparseISDense])
-def test_is_softmax_loss(f):
+@pytest.mark.parametrize('cls_dtype', ['float32', 'int32'])
+@pytest.mark.parametrize('count_dtype', ['float32', 'int32'])
+def test_is_softmax_loss(f, cls_dtype, count_dtype):
     ctx = mx.cpu()
     batch_size = 2
     num_sampled = 3

@@ -66,9 +70,9 @@ def test_is_softmax_loss(f):
     trainer = mx.gluon.Trainer(model.collect_params(), 'sgd')
     x = mx.nd.ones((batch_size, num_hidden))
     y = mx.nd.ones((batch_size,))
-    sampled_cls = mx.nd.ones((num_sampled,))
-    sampled_cls_cnt = mx.nd.ones((num_sampled,))
-    true_cls_cnt = mx.nd.ones((batch_size,))
+    sampled_cls = mx.nd.ones((num_sampled,), dtype=cls_dtype)
+    sampled_cls_cnt = mx.nd.ones((num_sampled,), dtype=count_dtype)
+    true_cls_cnt = mx.nd.ones((batch_size,), dtype=count_dtype)
     samples = (sampled_cls, sampled_cls_cnt, true_cls_cnt)
     with mx.autograd.record():
         pred, new_y = model(x, samples, y)
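
With the two added parametrize decorators, each test now runs over all four (cls_dtype, count_dtype) combinations for both the dense and sparse variants, so the int32 paths patched in sampled_block.py are exercised directly.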
