# FGSM Attack.py
# Evaluates a CIFAR-10 PreActResNet18 checkpoint against the FGSM attack
# over a range of epsilon values.
import PreActResNet18
import torch
import torch.nn as nn
from torchvision import datasets, transforms
# Evaluation data: the CIFAR-10 test split, converted to tensors by ToTensor.
test_set = datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms.ToTensor())
# batch_size=1 is required: test() calls .item() on each prediction and target.
test_loader = torch.utils.data.DataLoader(test_set, batch_size=1, shuffle=True)

# Perturbation strengths to sweep; 0 is the unperturbed baseline.
epsilons = [0, 0.01, 0.02, 0.031, 0.04, 0.05, 0.06]

# Use the GPU when one is present, otherwise fall back to the CPU so the
# script still runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

net = PreActResNet18.PreActResNet18()
# map_location remaps tensors saved from a CUDA device onto the device chosen
# above; without it, loading a GPU-saved checkpoint on a CPU-only host fails.
check_point = torch.load("CIFAR10_PreActResNet18.checkpoint", map_location=device)
model = net.to(device)
model.load_state_dict(check_point['state_dict'])
# Inference mode for layers such as dropout / batch norm.
model.eval()

# Loss used by test() to obtain the input gradient for the attack.
criterion = nn.SoftMarginLoss()
def fgsm_attack(image, epsilon, data_grad, clip_min=0.0, clip_max=1.0):
    """Build an FGSM adversarial example from an input and its loss gradient.

    Applies the single-step update ``image + epsilon * sign(data_grad)`` and
    clamps the result element-wise to ``[clip_min, clip_max]``.

    Args:
        image: Tensor to perturb.
        epsilon: Step size of the perturbation. With ``0`` the image is only
            clamped, not shifted.
        data_grad: Gradient of the loss with respect to ``image``; only its
            sign is used.
        clip_min: Lower bound of the valid value range (default ``0.0``).
        clip_max: Upper bound of the valid value range (default ``1.0``).

    Returns:
        The perturbed tensor, clamped to ``[clip_min, clip_max]``.
    """
    sign_data_grad = data_grad.sign()
    perturbed_image = image + epsilon * sign_data_grad
    # Keep the result inside the valid input range so it remains a legal image.
    return torch.clamp(perturbed_image, clip_min, clip_max)
def test(model, device, test_loader, epsilon):
    """Measure model accuracy under an FGSM attack of strength ``epsilon``.

    Each sample that the model initially classifies correctly is perturbed
    with ``fgsm_attack`` and re-classified. Samples that are already
    misclassified are skipped (they count as wrong in the final accuracy).

    Args:
        model: Classifier returning per-class scores of shape (1, num_classes).
        device: Device that inputs and targets are moved to.
        test_loader: Iterable of ``(data, target)`` pairs. Must yield one
            sample per batch, because ``.item()`` is called on the
            predictions and targets.
        epsilon: FGSM step size passed to ``fgsm_attack``.

    Returns:
        ``(final_acc, adv_examples)`` where ``final_acc`` is
        ``correct / len(test_loader)`` (``0.0`` for an empty loader) and
        ``adv_examples`` holds up to five
        ``(initial_pred, final_pred, perturbed_image_ndarray)`` tuples.
    """
    max_saved = 5
    correct = 0
    adv_examples = []

    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        # The attack needs the gradient of the loss w.r.t. the input itself.
        data.requires_grad = True

        output = model(data)
        init_pred = output.max(1, keepdim=True)[1]
        # Nothing to attack if the clean sample is already misclassified.
        if init_pred.item() != target.item():
            continue

        # The module-level criterion is nn.SoftMarginLoss, which expects
        # targets in {-1, +1}. A 0/1 one-hot target makes every non-target
        # class contribute a constant with zero gradient, so use -1 for the
        # wrong classes. full_like also follows the model's class count,
        # dtype and device instead of hard-coding (1, 10).
        prediction = torch.full_like(output, -1.0)
        prediction[0][target.item()] = 1

        loss = criterion(output, prediction)
        model.zero_grad()
        loss.backward()
        data_grad = data.grad.detach()

        perturbed_data = fgsm_attack(data, epsilon, data_grad)

        # Re-classification needs no gradients; skip building the graph.
        with torch.no_grad():
            output = model(perturbed_data)
        final_pred = output.max(1, keepdim=True)[1]

        still_correct = final_pred.item() == target.item()
        if still_correct:
            correct += 1

        # Save successful attacks. With epsilon == 0 nothing is perturbed, so
        # save the (still correct) samples to have baseline examples.
        if (not still_correct or epsilon == 0) and len(adv_examples) < max_saved:
            adv_ex = perturbed_data.squeeze().detach().cpu().numpy()
            adv_examples.append((init_pred.item(), final_pred.item(), adv_ex))

    total = len(test_loader)
    # Guard against an empty loader instead of raising ZeroDivisionError.
    final_acc = correct / float(total) if total else 0.0
    print("Epsilon: {}\tTest Accuracy = {} / {} = {}".format(epsilon, correct, total, final_acc))
    return final_acc, adv_examples
# Run the attack once per epsilon, in order, then split the
# (accuracy, adversarial examples) pairs into two parallel lists.
results = [test(model, device, test_loader, eps) for eps in epsilons]
accuracies = [eps_accuracy for eps_accuracy, _ in results]
examples = [eps_examples for _, eps_examples in results]