-
Notifications
You must be signed in to change notification settings - Fork 82
/
Copy pathDefense_ShrinkPad.py
57 lines (45 loc) · 1.48 KB
/
Defense_ShrinkPad.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
'''
This is example code for defending against the BadNets attack.
Dataset: CIFAR-10.
Defense method: ShrinkPad.
'''
import os
import cv2
import torch
import torch.nn as nn
from torch.utils.data import Dataset
import torchvision.models as models
import torchvision
from torchvision.transforms import Compose, ToTensor, PILToTensor, RandomHorizontalFlip, Lambda
import core
shrinkpad = core.ShrinkPad(size_map=32, pad=4)
dataset = torchvision.datasets.DatasetFolder

# The model under test was poisoned and trained with the BadNets attack; its
# trigger is a small white 3x3 square in the bottom-right corner (see
# example.py). ``pattern`` is in uint8 pixel units [0, 255]; ``weight`` marks
# the pixels the trigger overwrites (1.0) versus leaves untouched (0.0).
pattern = torch.zeros((1, 32, 32), dtype=torch.uint8)
pattern[0, -3:, -3:] = 255
weight = torch.zeros((1, 32, 32), dtype=torch.float32)
weight[0, -3:, -3:] = 1.0
# ToTensor() rescales uint8 pixels to floats in [0, 1], so the trigger must be
# rescaled the same way before it is stamped onto the transformed image.
pattern_unit = pattern.to(torch.float32) / 255


def add_trigger(img):
    """Stamp the BadNets trigger onto an image produced by ``ToTensor()``.

    Args:
        img: Float image tensor with values in [0, 1]; presumably shaped
            (3, 32, 32) for CIFAR-10 (the (1, 32, 32) trigger broadcasts
            over the channel axis) -- TODO confirm.

    Returns:
        A new tensor where the trigger pixels are replaced by the rescaled
        trigger value (1.0) and every other pixel is unchanged.
    """
    # Replace, don't add: adding the raw uint8 pattern (255) to a [0, 1]
    # image would push the corner pixels to ~255-256, far outside the range
    # ToTensor() produces.
    return img * (1 - weight) + pattern_unit * weight


def apply_shrinkpad(img):
    """Apply the ShrinkPad input transformation to a single image."""
    # NOTE(review): confirm that ShrinkPad.test does not also call
    # preprocess internally; if it does, images get shrunk-and-padded twice.
    return shrinkpad.preprocess(img)


# Named module-level functions (instead of lambdas) keep the transform
# picklable. Multi-process data loading requires this when workers are
# started with 'spawn' (the default on Windows and macOS).
transform_test = Compose([
    ToTensor(),
    Lambda(add_trigger),
    Lambda(apply_shrinkpad),
])
# CIFAR-10 test split, laid out the way DatasetFolder expects
# (root/<class_name>/<image>.png). cv2.imread decodes every file into a uint8
# numpy array in OpenCV's BGR channel order, which transform_test then
# converts, triggers and shrink-pads. No label remapping or file filtering is
# applied.
testset = dataset(
    root='./data/cifar10/test',
    extensions=('png',),
    loader=cv2.imread,
    is_valid_file=None,
    transform=transform_test,
    target_transform=None,
)
# Evaluation settings handed to shrinkpad.test below. 'test_model' is the
# BadNets-poisoned checkpoint to evaluate. 'ASR_NoTarget' presumably scores
# the attack success rate on samples whose label is not y_target -- confirm
# in core.
schedule = dict(
    test_model='./experiments/train_poisoned_DatasetFolder-CIFAR10_2025-02-24_18:20:13/ckpt_epoch_50.pth',
    save_dir='./experiments',
    CUDA_VISIBLE_DEVICES='0',
    GPU_num=1,
    experiment_name='CIFAR10_test',
    device='GPU',
    metric='ASR_NoTarget',
    y_target=0,
    batch_size=64,
    num_workers=4,
)
if __name__ == '__main__':
    # Run the evaluation only when executed as a script. schedule['num_workers']
    # is 4, so data loading presumably uses worker processes. Under the 'spawn'
    # start method (default on Windows and macOS) those workers re-import this
    # module, and without the guard each one would try to start the whole
    # evaluation again. Multiprocessing rejects that with a bootstrapping
    # RuntimeError.
    # ResNet-18 is assumed to match the architecture of the checkpoint named in
    # schedule['test_model'] -- TODO confirm against the training script.
    model = core.models.ResNet(18)
    predictions = shrinkpad.test(model, testset, schedule=schedule, size_map=32, pad=4)