|
1 | 1 | import keras
|
| 2 | +from absl.testing import parameterized |
2 | 3 | from keras import ops
|
3 | 4 | from keras.layers import deserialize
|
4 | 5 | from keras.layers import serialize
|
|
7 | 8 | from keras_rs.src.layers.retrieval import sampling_probability_correction
|
8 | 9 |
|
9 | 10 |
|
10 |
class SamplingProbabilityCorrectionTest(
    testing.TestCase, parameterized.TestCase
):
    """Tests for the `SamplingProbabilityCorrection` retrieval layer."""

    def create_inputs(self, logits_rank=2, probs_rank=1):
        """Create random logits/probabilities as suffix shapes of (15, 20, 10).

        Args:
            logits_rank: Rank of the returned logits tensor (1, 2 or 3).
            probs_rank: Rank of the returned probabilities tensor (1, 2 or 3).

        Returns:
            A `(logits, probs)` tuple of uniform random tensors.
        """
        full_shape = (15, 20, 10)
        seed = keras.random.SeedGenerator(42)
        # Sample logits first, then probs, so draws stay reproducible.
        logits = keras.random.uniform(full_shape[-logits_rank:], seed=seed)
        probs = keras.random.uniform(full_shape[-probs_rank:], seed=seed)
        return logits, probs

    @parameterized.named_parameters(
        # Every (logits_rank, probs_rank) pair where the probabilities'
        # rank does not exceed the logits' rank.
        {
            "testcase_name": f"logits_rank_{lr}_probs_rank_{pr}",
            "logits_rank": lr,
            "probs_rank": pr,
        }
        for lr in (1, 2, 3)
        for pr in (1, 2, 3)
        if pr <= lr
    )
    def test_call(self, logits_rank, probs_rank):
        logits, probs = self.create_inputs(
            logits_rank=logits_rank, probs_rank=probs_rank
        )

        # Verifies logits are always less than corrected logits.
        layer = sampling_probability_correction.SamplingProbabilityCorrection()
        corrected_logits = layer(logits, probs)
        self.assertAllClose(
            ops.less(logits, corrected_logits), ops.ones(logits.shape)
        )

        # Set some of the probabilities to 0 via a random 0/1 mask.
        zero_one_mask = ops.cast(
            ops.greater_equal(keras.random.uniform(probs.shape), 0.5),
            dtype="float32",
        )
        probs_with_zeros = ops.multiply(probs, zero_one_mask)

        # Verifies logits are always less than corrected logits.
        corrected_logits_with_zeros = layer(logits, probs_with_zeros)
        self.assertAllClose(
            ops.less(logits, corrected_logits_with_zeros),
            ops.ones(logits.shape),
        )

    def test_predict(self):
        # Note: for predict, we test with probabilities that have a batch dim.
        logits, probs = self.create_inputs(probs_rank=2)

        layer = sampling_probability_correction.SamplingProbabilityCorrection()
        # Functional model over symbolic inputs; leading dim is the batch.
        in_logits = keras.layers.Input(logits.shape[1:])
        in_probs = keras.layers.Input(probs.shape[1:])
        model = keras.Model(
            [in_logits, in_probs], layer(in_logits, in_probs)
        )

        model.predict([logits, probs], batch_size=4)

    def test_serialization(self):
        layer = sampling_probability_correction.SamplingProbabilityCorrection()
        # Round-trip through serialize/deserialize must keep the config.
        restored = deserialize(serialize(layer))
        self.assertDictEqual(layer.get_config(), restored.get_config())

    def test_model_saving(self):
        logits, probs = self.create_inputs()

        layer = sampling_probability_correction.SamplingProbabilityCorrection()
        # probs is rank 1 here, so it is fed with a fixed batch_shape input.
        in_logits = keras.layers.Input(shape=logits.shape[1:])
        in_probs = keras.layers.Input(batch_shape=probs.shape)
        model = keras.Model(
            [in_logits, in_probs], layer(in_logits, in_probs)
        )

        self.run_model_saving_test(model=model, input_data=[logits, probs])
0 commit comments