1313# limitations under the License.
1414
1515import logging
16+
1617from absl .testing import absltest
18+ from absl .testing import parameterized
1719from absl .testing .absltest import mock
1820import numpy as np
19-
2021from tensorflow_privacy .privacy .privacy_tests .membership_inference_attack .data_structures import AttackInputData
2122from tensorflow_privacy .privacy .privacy_tests .membership_inference_attack .data_structures import SingleSliceSpec
2223from tensorflow_privacy .privacy .privacy_tests .membership_inference_attack .data_structures import SlicingFeature
@@ -38,7 +39,7 @@ def _are_lists_equal(lhs, rhs) -> bool:
3839 return True
3940
4041
41- class SingleSliceSpecsTest (absltest .TestCase ):
42+ class SingleSliceSpecsTest (parameterized .TestCase ):
4243 """Tests for get_single_slice_specs."""
4344
4445 ENTIRE_DATASET_SLICE = SingleSliceSpec ()
@@ -95,8 +96,81 @@ def test_slicing_by_multiple_features(self):
9596 output = get_single_slice_specs (input_data , n_classes )
9697 self .assertLen (output , expected_slices )
9798
99+ @parameterized .parameters (
100+ (np .array ([1 , 2 , 1 , 2 ]), np .array ([2 , 2 , 1 , 2 ]), [1 , 2 ]),
101+ (np .array ([0 , - 1 , 2 , - 1 , 2 ]), np .array ([2 , 2 , - 1 , 2 ]), [- 1 , 2 ]),
102+ (np .array ([1 , 2 , 1 , 2 ] + list (range (5000 ))), np .array ([2 , 2 , 1 ]), [1 , 2 ]),
103+ (np .array ([1 , 2 , 1 , 2 ]), np .array ([3 , 4 ]), []),
104+ )
105+ def test_slicing_by_custom_indices_one_pair (self , custom_train_indices ,
106+ custom_test_indices ,
107+ expected_groups ):
108+ input_data = SlicingSpec (
109+ all_custom_train_indices = [custom_train_indices ],
110+ all_custom_test_indices = [custom_test_indices ])
111+ expected = [self .ENTIRE_DATASET_SLICE ] + [
112+ SingleSliceSpec (SlicingFeature .CUSTOM ,
113+ (custom_train_indices , custom_test_indices , g ))
114+ for g in expected_groups
115+ ]
116+ output = get_single_slice_specs (input_data )
117+ self .assertTrue (_are_lists_equal (output , expected ))
98118
99- class GetSliceTest (absltest .TestCase ):
119+ def test_slicing_by_custom_indices_multi_pairs (self ):
120+ all_custom_train_indices = [
121+ np .array ([1 , 2 , 1 , 2 ]),
122+ np .array ([0 , - 1 , 2 , - 1 , 2 ]),
123+ np .array ([1 , 2 , 1 , 2 ] + list (range (5000 ))),
124+ np .array ([1 , 2 , 1 , 2 ])
125+ ]
126+ all_custom_test_indices = [
127+ np .array ([2 , 2 , 1 , 2 ]),
128+ np .array ([2 , 2 , - 1 , 2 ]),
129+ np .array ([2 , 2 , 1 ]),
130+ np .array ([3 , 4 ])
131+ ]
132+ expected_group_values = [[1 , 2 ], [- 1 , 2 ], [1 , 2 ], []]
133+
134+ input_data = SlicingSpec (
135+ all_custom_train_indices = all_custom_train_indices ,
136+ all_custom_test_indices = all_custom_test_indices )
137+ expected = [self .ENTIRE_DATASET_SLICE ]
138+ for custom_train_indices , custom_test_indices , eg in zip (
139+ all_custom_train_indices , all_custom_test_indices ,
140+ expected_group_values ):
141+ expected .extend ([
142+ SingleSliceSpec (SlicingFeature .CUSTOM ,
143+ (custom_train_indices , custom_test_indices , g ))
144+ for g in eg
145+ ])
146+ output = get_single_slice_specs (input_data )
147+ self .assertTrue (_are_lists_equal (output , expected ))
148+
149+ @parameterized .parameters (
150+ ([np .array ([1 , 2 ])], None ),
151+ (None , [np .array ([1 , 2 ])]),
152+ ([], [np .array ([1 , 2 ])]),
153+ ([np .array ([1 , 2 ])], [np .array ([1 , 2 ]),
154+ np .array ([1 , 2 ])]),
155+ )
156+ def test_slicing_by_custom_indices_wrong_indices (self ,
157+ all_custom_train_indices ,
158+ all_custom_test_indices ):
159+ self .assertRaises (
160+ ValueError ,
161+ SlicingSpec ,
162+ all_custom_train_indices = all_custom_train_indices ,
163+ all_custom_test_indices = all_custom_test_indices )
164+
165+ def test_slicing_by_custom_indices_too_many_groups (self ):
166+ input_data = SlicingSpec (
167+ all_custom_train_indices = [np .arange (1001 ),
168+ np .arange (3 )],
169+ all_custom_test_indices = [np .arange (1001 ), np .arange (3 )])
170+ self .assertRaises (ValueError , get_single_slice_specs , input_data )
171+
172+
173+ class GetSliceTest (parameterized .TestCase ):
100174
101175 def __init__ (self , methodname ):
102176 """Initialize the test class."""
@@ -210,6 +284,40 @@ def test_slice_by_correctness(self):
210284 self .assertTrue ((output .labels_train == [0 , 2 ]).all ())
211285 self .assertTrue ((output .labels_test == [1 , 2 , 0 ]).all ())
212286
287+ def test_slice_by_custom_indices (self ):
288+ custom_train_indices = np .array ([2 , 2 , 100 , 4 ])
289+ custom_test_indices = np .array ([100 , 2 , 2 , 2 ])
290+ custom_slice = SingleSliceSpec (
291+ SlicingFeature .CUSTOM , (custom_train_indices , custom_test_indices , 2 ))
292+ output = get_slice (self .input_data , custom_slice )
293+ np .testing .assert_array_equal (output .logits_train ,
294+ np .array ([[0 , 1 , 0 ], [2 , 0 , 3 ]]))
295+ np .testing .assert_array_equal (
296+ output .logits_test , np .array ([[12 , 13 , 0 ], [14 , 15 , 0 ], [0 , 16 , 17 ]]))
297+ np .testing .assert_array_equal (output .probs_train ,
298+ np .array ([[0 , 1 , 0 ], [0.1 , 0 , 0.7 ]]))
299+ np .testing .assert_array_equal (
300+ output .probs_test , np .array ([[0.1 , 0.9 , 0 ], [0.15 , 0.85 , 0 ], [0 , 0 ,
301+ 1 ]]))
302+ np .testing .assert_array_equal (output .labels_train , np .array ([1 , 0 ]))
303+ np .testing .assert_array_equal (output .labels_test , np .array ([2 , 0 , 2 ]))
304+ np .testing .assert_array_equal (output .loss_train , np .array ([2 , 0.25 ]))
305+ np .testing .assert_array_equal (output .loss_test , np .array ([3.5 , 7 , 4.5 ]))
306+ np .testing .assert_array_equal (output .entropy_train , np .array ([0.4 , 8 ]))
307+ np .testing .assert_array_equal (output .entropy_test ,
308+ np .array ([10.5 , 4.5 , 0.3 ]))
309+
310+ @parameterized .parameters (
311+ (np .array ([2 , 2 , 100 ]), np .array ([100 , 2 , 2 ])),
312+ (np .array ([2 , 2 , 100 , 4 ]), np .array ([100 , 2 , 2 ])),
313+ (np .array ([2 , 100 , 4 ]), np .array ([100 , 2 , 2 , 2 ])),
314+ )
315+ def test_slice_by_custom_indices_wrong_size (self , custom_train_indices ,
316+ custom_test_indices ):
317+ custom_slice = SingleSliceSpec (
318+ SlicingFeature .CUSTOM , (custom_train_indices , custom_test_indices , 2 ))
319+ self .assertRaises (ValueError , get_slice , self .input_data , custom_slice )
320+
213321
214322class GetSliceTestForMultilabelData (absltest .TestCase ):
215323
@@ -288,6 +396,26 @@ def test_slice_by_correctness_fails(self):
288396 False )
289397 self .assertRaises (ValueError , get_slice , self .input_data , percentile_slice )
290398
399+ def test_slice_by_custom_indices (self ):
400+ custom_train_indices = np .array ([2 , 2 , 100 , 4 ])
401+ custom_test_indices = np .array ([100 , 2 , 2 , 2 ])
402+ custom_slice = SingleSliceSpec (
403+ SlicingFeature .CUSTOM , (custom_train_indices , custom_test_indices , 2 ))
404+ output = get_slice (self .input_data , custom_slice )
405+ # Check logits.
406+ with self .subTest (msg = 'Check logits' ):
407+ np .testing .assert_array_equal (output .logits_train ,
408+ np .array ([[0 , 1 , 0 ], [2 , 0 , 3 ]]))
409+ np .testing .assert_array_equal (
410+ output .logits_test , np .array ([[12 , 13 , 0 ], [14 , 15 , 0 ], [0 , 16 , 17 ]]))
411+
412+ # Check labels.
413+ with self .subTest (msg = 'Check labels' ):
414+ np .testing .assert_array_equal (output .labels_train ,
415+ np .array ([[0 , 1 , 1 ], [1 , 0 , 1 ]]))
416+ np .testing .assert_array_equal (output .labels_test ,
417+ np .array ([[0 , 1 , 0 ], [0 , 1 , 0 ], [0 , 0 , 1 ]]))
418+
291419
292420if __name__ == '__main__' :
293421 absltest .main ()
0 commit comments