-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathnominal-mnist.py
59 lines (47 loc) · 1.88 KB
/
nominal-mnist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from Architectures.CNN4_4 import *
from Model import *
from torch.utils.data.dataloader import DataLoader
import torchvision
from torchvision.transforms import ToTensor
import copy
import torchattacks
import os
# Settings
# Hyperparameters and run options, handed as a single dict to `Model` below.
# NOTE(review): keys such as 'n_steps', 'thres', 'attack', 'attackkwargs' and
# 'model_log' are consumed inside `Model`, which is not visible here; their
# exact meaning should be documented where they are read.
config = {
    'batch_size': 50,
    'device': torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    'criterion': nn.CrossEntropyLoss(),
    # Optimizer *class*; presumably instantiated with 'optimkwargs' inside Model.
    'optimizer': torch.optim.Adam,
    'optimkwargs': {},
    'scheduler': None,
    'schedulerkwargs': {},
    'epochs': 20,
    'in_channels': 1,  # MNIST is grayscale
    'num_classes': 10,
    'n_steps': 10,
    'kernel_size': 3,
    'experiment_name': "nominal-mnist",
    'model_log': None,
    'attack': None,
    'thres': 0,
    'attackkwargs': {},
    'comments': None,
}

# Per-experiment log directory: <cwd>/Experiment_logs/<experiment_name>/
# The trailing separator (the final "" component) is kept because downstream
# code may append file names directly to this string (TODO confirm in Model).
config['experiment_path'] = os.path.join(
    os.getcwd(), "Experiment_logs", config['experiment_name'], ""
)
# makedirs also creates the "Experiment_logs" parent on a fresh checkout
# (a plain os.mkdir raises FileNotFoundError there) and is a no-op when the
# directory already exists.
os.makedirs(config['experiment_path'], exist_ok=True)
# Load dataset
# The official MNIST training split is used for training. The official test
# split is halved into two disjoint subsets: `test_data` and `val_data`.
train_data = torchvision.datasets.MNIST("../data", train=True, download=True, transform=ToTensor())
test_data = torchvision.datasets.MNIST("../data", train=False, download=True, transform=ToTensor())

# Seed the split so the test/validation membership is identical on every run.
# An unseeded random_split re-draws it each time, which makes reported metrics
# non-reproducible and lets a reloaded model be scored on different samples.
SPLIT_SEED = 0
n_test_half = int(len(test_data) * 0.5)
test_data, val_data = torch.utils.data.random_split(
    test_data,
    [n_test_half, len(test_data) - n_test_half],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_loader = DataLoader(train_data, batch_size=config['batch_size'], shuffle=True, num_workers=0)
val_loader = DataLoader(val_data, batch_size=config['batch_size'], shuffle=True, num_workers=0)
test_loader = DataLoader(test_data, batch_size=config['batch_size'], shuffle=True, num_workers=0)
# Build an all-zeros tensor with the shape of one input batch drawn from
# train_loader and move it to the configured device. If the loader yields
# nothing, no names are bound (same as the original for/break loop).
# NOTE(review): `baselines` is not referenced anywhere else in this script.
first_batch = next(iter(train_loader), None)
if first_batch is not None:
    x, y = first_batch
    baselines = torch.zeros_like(x).to(config['device'])
# Build the CNN and wrap it in the project's Model training/evaluation driver.
cnn = CNN4_4(in_channels=config['in_channels'], num_classes=config['num_classes'])
defense = Model(cnn, config, config['experiment_name'])

# Fit on the training split. `val_loader` (half of the MNIST test set) must
# stay out of training so the final `validate` call below measures held-out
# accuracy; passing it here would leave `train_loader` entirely unused.
# NOTE(review): assumes the first argument of Model.train is the loader it
# fits on — confirm against Model's signature.
defense.train(train_loader, test_loader)

# Final evaluation. `None` in the attack list presumably means "clean, no
# attack" for this nominal run — confirm in Model.validate.
list_of_attacks = [None]
defense.validate(val_loader, list_of_attacks=list_of_attacks)