@@ -215,7 +215,7 @@ def _test_tracin_regression(self, features: int, mode: int) -> None:
215
215
for i in range (len (idx )):
216
216
self .assertTrue (isSorted (idx [i ]))
217
217
218
- if mode == "check_autograd_hacks " :
218
+ if mode == "sample_wise_trick " :
219
219
220
220
criterion = nn .MSELoss (reduction = "none" )
221
221
@@ -228,39 +228,47 @@ def _test_tracin_regression(self, features: int, mode: int) -> None:
228
228
False ,
229
229
)
230
230
231
- # With autograd hacks
231
+ # With sample-wise trick
232
232
criterion = nn .MSELoss (reduction = "sum" )
233
- tracin_hack = self .tracin_constructor (
233
+ tracin_sample_wise_trick = self .tracin_constructor (
234
234
net , dataset , tmpdir , batch_size , criterion , True
235
235
)
236
236
237
237
train_scores = tracin .influence (train_inputs , train_labels )
238
- train_scores_hack = tracin_hack .influence (train_inputs , train_labels )
239
- assertTensorAlmostEqual (self , train_scores , train_scores_hack )
238
+ train_scores_sample_wise_trick = tracin_sample_wise_trick .influence (
239
+ train_inputs , train_labels
240
+ )
241
+ assertTensorAlmostEqual (
242
+ self , train_scores , train_scores_sample_wise_trick
243
+ )
240
244
241
245
test_scores = tracin .influence (test_inputs , test_labels )
242
- test_scores_hack = tracin_hack .influence (test_inputs , test_labels )
243
- assertTensorAlmostEqual (self , test_scores , test_scores_hack )
246
+ test_scores_sample_wise_trick = tracin_sample_wise_trick .influence (
247
+ test_inputs , test_labels
248
+ )
249
+ assertTensorAlmostEqual (
250
+ self , test_scores , test_scores_sample_wise_trick
251
+ )
244
252
245
253
246
254
class _TestTracInRegression1DCheckIdx (_TestTracInRegression ):
247
255
def test_tracin_regression_1D_check_idx (self ):
248
256
self ._test_tracin_regression (1 , "check_idx" )
249
257
250
258
251
- class _TestTracInRegression1DCheckAutogradHacks (_TestTracInRegression ):
252
- def test_tracin_regression_1D_check_autograd_hacks (self ):
253
- self ._test_tracin_regression (1 , "check_autograd_hacks " )
259
+ class _TestTracInRegression1DCheckSampleWiseTrick (_TestTracInRegression ):
260
+ def test_tracin_regression_1D_check_sample_wise_trick (self ):
261
+ self ._test_tracin_regression (1 , "sample_wise_trick " )
254
262
255
263
256
264
class _TestTracInRegression20DCheckIdx (_TestTracInRegression ):
257
265
def test_tracin_regression_20D_check_idx (self ):
258
266
self ._test_tracin_regression (20 , "check_idx" )
259
267
260
268
261
- class _TestTracInRegression20DCheckAutogradHacks (_TestTracInRegression ):
262
- def test_tracin_regression_20D_check_autograd_hacks (self ):
263
- self ._test_tracin_regression (20 , "check_autograd_hacks " )
269
+ class _TestTracInRegression20DCheckSampleWiseTrick (_TestTracInRegression ):
270
+ def test_tracin_regression_20D_check_sample_wise_trick (self ):
271
+ self ._test_tracin_regression (20 , "sample_wise_tricksample_wise_trick " )
264
272
265
273
266
274
class _TestTracInXOR :
@@ -434,7 +442,7 @@ def _test_tracin_xor(self, mode) -> None:
434
442
influence_labels = dataset .labels [idx [i ][0 :5 ], 0 ]
435
443
self .assertTrue (torch .all (testlabels [i , 0 ] == influence_labels ))
436
444
437
- if mode == "check_autograd_hacks " :
445
+ if mode == "sample_wise_trick " :
438
446
439
447
criterion = nn .MSELoss (reduction = "none" )
440
448
@@ -447,9 +455,9 @@ def _test_tracin_xor(self, mode) -> None:
447
455
False ,
448
456
)
449
457
450
- # With autograd hacks
458
+ # With sample-wise trick
451
459
criterion = nn .MSELoss (reduction = "sum" )
452
- tracin_hack = self .tracin_constructor (
460
+ tracin_sample_wise_trick = self .tracin_constructor (
453
461
net ,
454
462
dataset ,
455
463
tmpdir ,
@@ -459,18 +467,22 @@ def _test_tracin_xor(self, mode) -> None:
459
467
)
460
468
461
469
test_scores = tracin .influence (testset , testlabels )
462
- test_scores_hack = tracin_hack .influence (testset , testlabels )
463
- assertTensorAlmostEqual (self , test_scores , test_scores_hack )
470
+ test_scores_sample_wise_trick = tracin_sample_wise_trick .influence (
471
+ testset , testlabels
472
+ )
473
+ assertTensorAlmostEqual (
474
+ self , test_scores , test_scores_sample_wise_trick
475
+ )
464
476
465
477
466
478
class _TestTracInXORCheckIdx (_TestTracInXOR ):
467
479
def test_tracin_xor_check_idx (self ):
468
480
self ._test_tracin_xor ("check_idx" )
469
481
470
482
471
- class _TestTracInXORCheckAutogradHacks (_TestTracInXOR ):
472
- def test_tracin_xor_check_autograd_hacks (self ):
473
- self ._test_tracin_xor ("check_autograd_hacks " )
483
+ class _TestTracInXORCheckSampleWiseTrick (_TestTracInXOR ):
484
+ def test_tracin_xor_check_sample_wise_trick (self ):
485
+ self ._test_tracin_xor ("sample_wise_trick " )
474
486
475
487
476
488
class _TestTracInIdentityRegression :
@@ -537,7 +549,7 @@ def _test_tracin_identity_regression(self, mode) -> None:
537
549
for i in range (len (idx )):
538
550
self .assertEqual (idx [i ][0 ], i )
539
551
540
- if mode == "check_autograd_hacks " :
552
+ if mode == "sample_wise_trick " :
541
553
542
554
criterion = nn .MSELoss (reduction = "none" )
543
555
@@ -550,9 +562,9 @@ def _test_tracin_identity_regression(self, mode) -> None:
550
562
False ,
551
563
)
552
564
553
- # With autograd hacks
565
+ # With sample-wise trick
554
566
criterion = nn .MSELoss (reduction = "sum" )
555
- tracin_hack = self .tracin_constructor (
567
+ tracin_sample_wise_trick = self .tracin_constructor (
556
568
net ,
557
569
dataset ,
558
570
tmpdir ,
@@ -562,18 +574,22 @@ def _test_tracin_identity_regression(self, mode) -> None:
562
574
)
563
575
564
576
train_scores = tracin .influence (train_inputs , train_labels )
565
- train_scores_hack = tracin_hack .influence (train_inputs , train_labels )
566
- assertTensorAlmostEqual (self , train_scores , train_scores_hack )
577
+ train_scores_tracin_sample_wise_trick = (
578
+ tracin_sample_wise_trick .influence (train_inputs , train_labels )
579
+ )
580
+ assertTensorAlmostEqual (
581
+ self , train_scores , train_scores_tracin_sample_wise_trick
582
+ )
567
583
568
584
569
585
class _TestTracInIdentityRegressionCheckIdx (_TestTracInIdentityRegression ):
570
586
def test_tracin_identity_regression_check_idx (self ):
571
587
self ._test_tracin_identity_regression ("check_idx" )
572
588
573
589
574
- class _TestTracInIdentityRegressionCheckAutogradHacks (_TestTracInIdentityRegression ):
575
- def test_tracin_identity_regression_check_autograd_hacks (self ):
576
- self ._test_tracin_identity_regression ("check_autograd_hacks " )
590
+ class _TestTracInIdentityRegressionCheckSampleWiseTrick (_TestTracInIdentityRegression ):
591
+ def test_tracin_identity_regression_check_sample_wise_trick (self ):
592
+ self ._test_tracin_identity_regression ("sample_wise_trick " )
577
593
578
594
579
595
class _TestTracInRandomProjectionRegression :
0 commit comments