|
10 | 10 |
|
11 | 11 | class RemoveAccidentalHitsTest(testing.TestCase, parameterized.TestCase): |
12 | 12 | def create_inputs(self, logits_rank=2, candidate_ids_rank=1): |
13 | | - num_candidates = 10 |
14 | | - shape_3d = (15, 20, num_candidates) |
| 13 | + shape_3d = (15, 20, 10) |
15 | 14 | shape = shape_3d[-logits_rank:] |
16 | 15 | candidate_ids_shape = shape_3d[-candidate_ids_rank:] |
17 | | - rng = keras.random.SeedGenerator(42) |
| 16 | + num_candidates = shape[-1] |
18 | 17 |
|
| 18 | + rng = keras.random.SeedGenerator(42) |
| 19 | + logits = keras.random.uniform(shape, seed=rng) |
19 | 20 | labels = keras.ops.one_hot( |
20 | 21 | keras.random.randint( |
21 | 22 | shape[:-1], minval=0, maxval=num_candidates, seed=rng |
22 | 23 | ), |
23 | 24 | num_candidates, |
24 | 25 | ) |
25 | | - logits = keras.random.uniform(shape, seed=rng) |
26 | 26 | candidate_ids = keras.random.randint( |
27 | 27 | candidate_ids_shape, minval=0, maxval=num_candidates, seed=rng |
28 | 28 | ) |
29 | | - return labels, logits, candidate_ids |
| 29 | + |
| 30 | + return logits, labels, candidate_ids |
30 | 31 |
|
31 | 32 | @parameterized.named_parameters( |
32 | 33 | [ |
@@ -63,12 +64,12 @@ def create_inputs(self, logits_rank=2, candidate_ids_rank=1): |
63 | 64 | ] |
64 | 65 | ) |
65 | 66 | def test_call(self, logits_rank, candidate_ids_rank): |
66 | | - labels, logits, candidate_ids = self.create_inputs( |
| 67 | + logits, labels, candidate_ids = self.create_inputs( |
67 | 68 | logits_rank=logits_rank, candidate_ids_rank=candidate_ids_rank |
68 | 69 | ) |
69 | 70 |
|
70 | 71 | out_logits = remove_accidental_hits.RemoveAccidentalHits()( |
71 | | - labels, logits, candidate_ids |
| 72 | + logits, labels, candidate_ids |
72 | 73 | ) |
73 | 74 |
|
74 | 75 | # Logits of labels are unchanged. |
@@ -148,36 +149,36 @@ def test_mismatched_labels_candidates_shapes(self): |
148 | 149 |
|
149 | 150 | def test_predict(self): |
150 | 151 | # Note: for predict, we test with probabilities that have a batch dim. |
151 | | - labels, logits, candidate_ids = self.create_inputs(candidate_ids_rank=2) |
| 152 | + logits, labels, candidate_ids = self.create_inputs(candidate_ids_rank=2) |
152 | 153 |
|
153 | 154 | layer = remove_accidental_hits.RemoveAccidentalHits() |
154 | | - in_labels = keras.layers.Input(labels.shape[1:]) |
155 | 155 | in_logits = keras.layers.Input(logits.shape[1:]) |
| 156 | + in_labels = keras.layers.Input(labels.shape[1:]) |
156 | 157 | in_candidate_ids = keras.layers.Input(labels.shape[1:]) |
157 | | - out_logits = layer(in_labels, in_logits, in_candidate_ids) |
| 158 | + out_logits = layer(in_logits, in_labels, in_candidate_ids) |
158 | 159 | model = keras.Model( |
159 | | - [in_labels, in_logits, in_candidate_ids], out_logits |
| 160 | + [in_logits, in_labels, in_candidate_ids], out_logits |
160 | 161 | ) |
161 | 162 |
|
162 | | - model.predict([labels, logits, candidate_ids], batch_size=8) |
| 163 | + model.predict([logits, labels, candidate_ids], batch_size=8) |
163 | 164 |
|
164 | 165 | def test_serialization(self): |
165 | 166 | layer = remove_accidental_hits.RemoveAccidentalHits() |
166 | 167 | restored = deserialize(serialize(layer)) |
167 | 168 | self.assertDictEqual(layer.get_config(), restored.get_config()) |
168 | 169 |
|
169 | 170 | def test_model_saving(self): |
170 | | - labels, logits, candidate_ids = self.create_inputs() |
| 171 | + logits, labels, candidate_ids = self.create_inputs() |
171 | 172 |
|
172 | 173 | layer = remove_accidental_hits.RemoveAccidentalHits() |
173 | | - in_labels = keras.layers.Input(labels.shape[1:]) |
174 | 174 | in_logits = keras.layers.Input(logits.shape[1:]) |
| 175 | + in_labels = keras.layers.Input(labels.shape[1:]) |
175 | 176 | in_candidate_ids = keras.layers.Input(batch_shape=candidate_ids.shape) |
176 | | - out_logits = layer(in_labels, in_logits, in_candidate_ids) |
| 177 | + out_logits = layer(in_logits, in_labels, in_candidate_ids) |
177 | 178 | model = keras.Model( |
178 | | - [in_labels, in_logits, in_candidate_ids], out_logits |
| 179 | + [in_logits, in_labels, in_candidate_ids], out_logits |
179 | 180 | ) |
180 | 181 |
|
181 | 182 | self.run_model_saving_test( |
182 | | - model=model, input_data=[labels, logits, candidate_ids] |
| 183 | + model=model, input_data=[logits, labels, candidate_ids] |
183 | 184 | ) |
0 commit comments