-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathCompute_DFM_CIFAR.py
126 lines (95 loc) · 4.81 KB
/
Compute_DFM_CIFAR.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import os
import sys
sys.path.insert(0,'/home/wangs1/dfmX-augmentation/')
from dataset.CIFAR import CIFAR
import torch
import numpy as np
import pickle
from torchvision.transforms import transforms
import matplotlib.pyplot as plt
import numpy.fft as fft
import argparse
from torchmetrics import ConfusionMatrix
import backbone.resnet as resnet
from blocks.resnet.Blocks import BasicBlock,Bottleneck
from train import Model
# Pick the first visible CUDA device when one exists, otherwise run on the CPU,
# and report the choice once at import time.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)
def _clean_confusion_matrix(encoder, dataset, num_classes=10):
    """Return the confusion matrix of ``encoder`` on the unmodified ``dataset``.

    The diagonal entry ``[c, c]`` is the number of correctly classified class-``c``
    images; it is the starting accuracy used for that class by ``main``.
    """
    confmat = ConfusionMatrix(num_classes=num_classes)
    matrix = torch.zeros((num_classes, num_classes))
    loader = torch.utils.data.DataLoader(dataset, batch_size=1000, shuffle=False, num_workers=2)
    with torch.no_grad():
        for x, y in loader:
            y_hat = encoder(x.to(device))
            # Batches are summed, so confmat(...) is assumed to return the matrix of
            # this batch only (torchmetrics forward semantics) -- confirm on upgrade.
            matrix += confmat(y_hat.cpu(), y.cpu())
    return matrix


def _load_images(dataset, batch_size):
    """Load every sample of ``dataset`` in order; return ``(images, labels)`` CPU tensors."""
    loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False)
    images, labels = [], []
    for x, y in loader:
        images.append(x)
        labels.append(y)
    return torch.cat(images), torch.cat(labels)


def _centered_spectrum(images):
    """Return the 2-D FFT of every image plane with the zero frequency moved to the centre.

    Only the last two (spatial) axes are shifted. Shifting the sample/channel axes as
    well would merely roll those axes; since one mask is applied to every plane and the
    inverse shift undoes the roll, the reconstructions are identical either way.
    """
    return fft.fftshift(fft.fft2(images.numpy()), axes=(-2, -1))


def _count_correct(encoder, spectrum, mask, test_class, batch_size):
    """Count images that ``encoder`` still assigns to ``test_class`` after masking.

    ``spectrum`` holds centred spectra of images that all have label ``test_class``.
    ``mask`` is multiplied onto every (sample, channel) plane, so it must broadcast
    against the last two axes (presumably an H x W 0/1 array -- confirm).
    """
    correct = 0
    with torch.no_grad():
        for start in range(0, len(spectrum), batch_size):
            filtered = spectrum[start:start + batch_size] * mask
            recon = np.real(fft.ifft2(fft.ifftshift(filtered, axes=(-2, -1))))
            # Keep the real part explicitly (the imaginary part is discarded) and feed
            # float32, instead of relying on torch.Tensor() to cast complex -> real.
            inputs = torch.from_numpy(recon.astype(np.float32)).to(device)
            _, predicted = torch.max(encoder(inputs), 1)
            correct += (predicted == test_class).sum().item()
    return correct


def _greedy_dfm(importance, evaluate, baseline, t):
    """Greedily drop frequency positions while the correct-count stays within ``t``.

    Positions are visited by decreasing value of ``importance``, one distinct positive
    value ("level") at a time; entries <= 0 are never tested, and zeros count as
    already removed. For each level a 0/1 mask is built that drops that level plus
    every level removed so far, and ``evaluate(mask)`` returns the number of correct
    predictions. The removal is accepted when that number is >= ``current - t``;
    ``current`` starts at ``baseline`` and becomes the count of the last accepted step.

    Returns a copy of ``importance`` with removed positions set to 0 and every
    remaining positive position set to 1.
    """
    dfm = np.array(importance, copy=True)
    current = baseline
    # Walking the distinct positive levels from the highest down visits exactly the
    # values a "zero the current maximum until the sum is 0" loop would, but it also
    # terminates when importance contains negative entries.
    for level in np.unique(dfm[dfm > 0])[::-1]:
        mask = ((dfm != 0) & (dfm != level)).astype(dfm.dtype)
        correct = evaluate(mask)
        if correct >= current - t:
            current = correct
            dfm[dfm == level] = 0
    dfm[dfm > 0] = 1
    return dfm


def main(args):
    """Compute a per-class 0/1 frequency map for a trained CIFAR model and pickle it.

    Uses ``args.model_path`` (checkpoint loaded with ``Model``), ``args.t`` (allowed
    drop in correct predictions per accepted removal step) and ``args.backbone_model``
    (used only in the output file name). Reads ``./DFMs/removal_order.pkl`` and writes
    ``./DFMs/<backbone_model>_<t>.pkl`` containing ``{class_index: map}``.
    """
    num_classes = 10
    batchsize = 64

    model = Model.load_from_checkpoint(args.model_path)
    model.to(device)
    # Eval mode is assumed to make per-sample outputs independent of batch
    # composition (true for standard BatchNorm/Dropout), which lets us batch
    # only the images of the class under test.
    model.eval()
    model.freeze()
    encoder = model.backbone_model

    mean = [0.491400, 0.482158, 0.446531]
    std = [0.247032, 0.243485, 0.261588]
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
    testset = CIFAR('./dataset', train=False, transform=transform)

    matrix_org = _clean_confusion_matrix(encoder, testset, num_classes)
    print(matrix_org)

    # Read once and never modified; _greedy_dfm works on its own copy per class.
    with open('./DFMs/removal_order.pkl', 'rb') as f:
        removal_order = np.asarray(pickle.load(f))

    # Try different t values to keep the accuracy degradation within ~30%.
    t = args.t
    images, labels = _load_images(testset, batchsize)
    result_prediction = {}
    for test_class in range(num_classes):
        # The forward FFT does not depend on the mask, so it is computed once per
        # class instead of on every removal step; only this class's images matter
        # for the count, so the others are not reconstructed or classified.
        spectrum = _centered_spectrum(images[labels == test_class])
        baseline = matrix_org[test_class, test_class].item()
        result_prediction[test_class] = _greedy_dfm(
            removal_order,
            lambda mask, s=spectrum, c=test_class: _count_correct(encoder, s, mask, c, batchsize),
            baseline,
            t,
        )

    with open('./DFMs/' + args.backbone_model + '_' + str(args.t) + '.pkl', 'wb') as f:
        pickle.dump(result_prediction, f)
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Only used by main() to name the output pickle.
    parser.add_argument('--backbone_model', type=str, default='resnet18',
                        help='model ')
    # Required: there is no usable default checkpoint (a default of the string
    # 'None' would only fail later, inside checkpoint loading).
    parser.add_argument('--model_path', type=str, required=True,
                        help='path of the model')
    parser.add_argument('--t', type=int, default=6,
                        help='flexible threshold')
    args = parser.parse_args()
    # main() reads ./DFMs/removal_order.pkl and writes its result into ./DFMs.
    os.makedirs('./DFMs', exist_ok=True)
    main(args)