 import pytest

 @pytest.mark.parametrize('f', [nlp.model.NCEDense, nlp.model.SparseNCEDense])
-def test_nce_loss(f):
+@pytest.mark.parametrize('cls_dtype', ['float32', 'int32'])
+@pytest.mark.parametrize('count_dtype', ['float32', 'int32'])
+def test_nce_loss(f, cls_dtype, count_dtype):
     ctx = mx.cpu()
     batch_size = 2
     num_sampled = 3
@@ -40,9 +42,9 @@ def test_nce_loss(f):
     trainer = mx.gluon.Trainer(model.collect_params(), 'sgd')
     x = mx.nd.ones((batch_size, num_hidden))
     y = mx.nd.ones((batch_size,))
-    sampled_cls = mx.nd.ones((num_sampled,))
-    sampled_cls_cnt = mx.nd.ones((num_sampled,))
-    true_cls_cnt = mx.nd.ones((batch_size,))
+    sampled_cls = mx.nd.ones((num_sampled,), dtype=cls_dtype)
+    sampled_cls_cnt = mx.nd.ones((num_sampled,), dtype=count_dtype)
+    true_cls_cnt = mx.nd.ones((batch_size,), dtype=count_dtype)
     samples = (sampled_cls, sampled_cls_cnt, true_cls_cnt)
     with mx.autograd.record():
         pred, new_y = model(x, samples, y)
@@ -53,7 +55,9 @@ def test_nce_loss(f):
     mx.nd.waitall()

 @pytest.mark.parametrize('f', [nlp.model.ISDense, nlp.model.SparseISDense])
-def test_is_softmax_loss(f):
+@pytest.mark.parametrize('cls_dtype', ['float32', 'int32'])
+@pytest.mark.parametrize('count_dtype', ['float32', 'int32'])
+def test_is_softmax_loss(f, cls_dtype, count_dtype):
     ctx = mx.cpu()
     batch_size = 2
     num_sampled = 3
@@ -66,9 +70,9 @@ def test_is_softmax_loss(f):
     trainer = mx.gluon.Trainer(model.collect_params(), 'sgd')
     x = mx.nd.ones((batch_size, num_hidden))
     y = mx.nd.ones((batch_size,))
-    sampled_cls = mx.nd.ones((num_sampled,))
-    sampled_cls_cnt = mx.nd.ones((num_sampled,))
-    true_cls_cnt = mx.nd.ones((batch_size,))
+    sampled_cls = mx.nd.ones((num_sampled,), dtype=cls_dtype)
+    sampled_cls_cnt = mx.nd.ones((num_sampled,), dtype=count_dtype)
+    true_cls_cnt = mx.nd.ones((batch_size,), dtype=count_dtype)
     samples = (sampled_cls, sampled_cls_cnt, true_cls_cnt)
     with mx.autograd.record():
         pred, new_y = model(x, samples, y)