Commit f84ce67

Test SamplingProbabilityCorrection layer with different dimensions. (#43)
The `SamplingProbabilityCorrection` layer already supports logits and probabilities of different dimensions; this commit adds test coverage for those combinations. See #39
1 parent 6081ac3 commit f84ce67
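
For context on why the tests can assert that corrected logits are always strictly greater than the inputs: the usual sampled-softmax correction subtracts the log of the candidate sampling probability, and since probabilities lie in (0, 1) that subtraction can only increase the logits. The sketch below illustrates that idea with a hypothetical correct_logits helper; it is an assumption about the general technique, not the keras_rs implementation (which is not part of this diff).

import keras
from keras import ops

def correct_logits(logits, candidate_sampling_probability, epsilon=1e-6):
    # Hypothetical helper, not the keras_rs layer: clip the probabilities away
    # from zero, then subtract their log. Since log(p) < 0 for p in (0, 1),
    # the corrected logits are always larger than the originals.
    probs = ops.clip(candidate_sampling_probability, epsilon, 1.0)
    return logits - ops.log(probs)

# Rank-3 logits with rank-1 probabilities: probs broadcasts over the last axis,
# mirroring the (15, 20, 10) shapes used by the new parameterized tests.
logits = keras.random.uniform((15, 20, 10))
probs = keras.random.uniform((10,))
corrected = correct_logits(logits, probs)
print(bool(ops.all(ops.less(logits, corrected))))  # True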

1 file changed: +68 −24 lines changed
@@ -1,4 +1,5 @@
 import keras
+from absl.testing import parameterized
 from keras import ops
 from keras.layers import deserialize
 from keras.layers import serialize
@@ -7,62 +8,105 @@
 from keras_rs.src.layers.retrieval import sampling_probability_correction
 
 
-class SamplingProbabilityCorrectionTest(testing.TestCase):
-    def setUp(self):
-        shape = (10, 20)  # (num_queries, num_candidates)
+class SamplingProbabilityCorrectionTest(
+    testing.TestCase, parameterized.TestCase
+):
+    def create_inputs(self, logits_rank=2, probs_rank=1):
+        shape_3d = (15, 20, 10)
+        logits_shape = shape_3d[-logits_rank:]
+        probs_shape = shape_3d[-probs_rank:]
+
         rng = keras.random.SeedGenerator(42)
-        self.logits = keras.random.uniform(shape, seed=rng)
-        self.probs_1d = keras.random.uniform(shape[1:], seed=rng)
-        self.probs_2d = keras.random.uniform(shape, seed=rng)
+        logits = keras.random.uniform(logits_shape, seed=rng)
+        probs = keras.random.uniform(probs_shape, seed=rng)
+        return logits, probs
+
+    @parameterized.named_parameters(
+        [
+            {
+                "testcase_name": "logits_rank_1_probs_rank_1",
+                "logits_rank": 1,
+                "probs_rank": 1,
+            },
+            {
+                "testcase_name": "logits_rank_2_probs_rank_1",
+                "logits_rank": 2,
+                "probs_rank": 1,
+            },
+            {
+                "testcase_name": "logits_rank_2_probs_rank_2",
+                "logits_rank": 2,
+                "probs_rank": 2,
+            },
+            {
+                "testcase_name": "logits_rank_3_probs_rank_1",
+                "logits_rank": 3,
+                "probs_rank": 1,
+            },
+            {
+                "testcase_name": "logits_rank_3_probs_rank_2",
+                "logits_rank": 3,
+                "probs_rank": 2,
+            },
+            {
+                "testcase_name": "logits_rank_3_probs_rank_3",
+                "logits_rank": 3,
+                "probs_rank": 3,
+            },
+        ]
+    )
+    def test_call(self, logits_rank, probs_rank):
+        logits, probs = self.create_inputs(
+            logits_rank=logits_rank, probs_rank=probs_rank
+        )
 
-    def test_call(self):
         # Verifies logits are always less than corrected logits.
         layer = sampling_probability_correction.SamplingProbabilityCorrection()
-        corrected_logits = layer(self.logits, self.probs_1d)
+        corrected_logits = layer(logits, probs)
         self.assertAllClose(
-            ops.less(self.logits, corrected_logits), ops.ones(self.logits.shape)
+            ops.less(logits, corrected_logits), ops.ones(logits.shape)
         )
 
         # Set some of the probabilities to 0.
         probs_with_zeros = ops.multiply(
-            self.probs_1d,
+            probs,
             ops.cast(
-                ops.greater_equal(
-                    keras.random.uniform(self.probs_1d.shape), 0.5
-                ),
+                ops.greater_equal(keras.random.uniform(probs.shape), 0.5),
                 dtype="float32",
            ),
        )
 
         # Verifies logits are always less than corrected logits.
-        corrected_logits_with_zeros = layer(self.logits, probs_with_zeros)
+        corrected_logits_with_zeros = layer(logits, probs_with_zeros)
         self.assertAllClose(
-            ops.less(self.logits, corrected_logits_with_zeros),
-            ops.ones(self.logits.shape),
+            ops.less(logits, corrected_logits_with_zeros),
+            ops.ones(logits.shape),
        )
 
     def test_predict(self):
         # Note: for predict, we test with probabilities that have a batch dim.
+        logits, probs = self.create_inputs(probs_rank=2)
+
         layer = sampling_probability_correction.SamplingProbabilityCorrection()
-        in_logits = keras.layers.Input(self.logits.shape[1:])
-        in_probs = keras.layers.Input(self.probs_2d.shape[1:])
+        in_logits = keras.layers.Input(logits.shape[1:])
+        in_probs = keras.layers.Input(probs.shape[1:])
         out_logits = layer(in_logits, in_probs)
         model = keras.Model([in_logits, in_probs], out_logits)
 
-        model.predict([self.logits, self.probs_2d], batch_size=4)
+        model.predict([logits, probs], batch_size=4)
 
     def test_serialization(self):
         layer = sampling_probability_correction.SamplingProbabilityCorrection()
         restored = deserialize(serialize(layer))
         self.assertDictEqual(layer.get_config(), restored.get_config())
 
     def test_model_saving(self):
+        logits, probs = self.create_inputs()
+
         layer = sampling_probability_correction.SamplingProbabilityCorrection()
-        in_logits = keras.layers.Input(shape=self.logits.shape[1:])
-        in_probs = keras.layers.Input(batch_shape=self.probs_1d.shape)
+        in_logits = keras.layers.Input(shape=logits.shape[1:])
+        in_probs = keras.layers.Input(batch_shape=probs.shape)
         out_logits = layer(in_logits, in_probs)
         model = keras.Model([in_logits, in_probs], out_logits)
 
-        self.run_model_saving_test(
-            model=model, input_data=[self.logits, self.probs_1d]
-        )
+        self.run_model_saving_test(model=model, input_data=[logits, probs])
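
For reference, a minimal standalone sketch of the layer as the new tests exercise it, assuming keras and keras_rs are installed. It pairs rank-3 logits with rank-1 probabilities, one of the newly covered combinations, where the probabilities line up with the trailing axes of the logits (as in the test's shape_3d[-probs_rank:] setup).

import keras
from keras import ops

from keras_rs.src.layers.retrieval import sampling_probability_correction

# Shapes follow the test's (15, 20, 10) convention: rank-3 logits, rank-1 probs.
rng = keras.random.SeedGenerator(42)
logits = keras.random.uniform((15, 20, 10), seed=rng)
probs = keras.random.uniform((10,), seed=rng)

layer = sampling_probability_correction.SamplingProbabilityCorrection()
corrected_logits = layer(logits, probs)

# The correction only pushes logits up, which is what test_call asserts.
print(bool(ops.all(ops.less(logits, corrected_logits))))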
