
Commit 6e5c713

v1.1
1 parent c80b650 commit 6e5c713


48 files changed (+5961, -2 lines)

.DS_Store (8 KB, binary file not shown)

Figs/overview.jpg (118 KB, binary image file)

README.md (+68, -2)
@@ -1,3 +1,69 @@

The placeholder README (`# Backdoor-LTH` / `Codes are coming soon!`) is replaced with the following content:

# Quarantine: Sparsity Can Uncover the Trojan Attack Trigger for Free

[![License: MIT](https://img.shields.io/badge/License-MIT-green.svg)](https://opensource.org/licenses/MIT)
Code for the paper **Quarantine: Sparsity Can Uncover the Trojan Attack Trigger for Free** [CVPR 2022]

Tianlong Chen\*, Zhenyu Zhang\*, Yihua Zhang\*, Shiyu Chang, Sijia Liu, Zhangyang Wang

## Overview:
Trojan attacks threaten deep neural networks (DNNs) by poisoning them to behave normally on most samples, yet produce manipulated results for inputs stamped with a particular trigger. Several works attempt to detect whether a given DNN has been injected with a specific trigger during training. In a parallel line of research, the lottery ticket hypothesis reveals the existence of sparse subnetworks that can match the performance of the dense network after independent training. Connecting these two dots, we investigate the problem of Trojan DNN detection from the brand-new lens of sparsity, even when no clean training data is available. Our crucial observation is that the Trojan features are significantly more stable to network pruning than benign features. Leveraging that, we propose a novel Trojan network detection regime: first locating a "winning Trojan lottery ticket" which preserves nearly full Trojan information yet only chance-level performance on clean inputs; then recovering the trigger embedded in this already isolated subnetwork.

<img src = "Figs/overview.jpg" align = "center" width="60%" height="60%">
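To make the crucial observation concrete, here is a minimal, hypothetical sketch (not the repository's actual pipeline) of how one might probe it: globally magnitude-prune a poisoned network at increasing sparsity levels and compare clean accuracy against the attack success rate (ASR) on triggered inputs. `model`, `clean_loader`, and `triggered_loader` are assumed to be provided by the caller.

```
import copy
import torch
import torch.nn as nn

def global_magnitude_prune(model, sparsity):
    """Return a copy of `model` with the smallest-magnitude weights zeroed."""
    pruned = copy.deepcopy(model)
    all_weights = torch.cat([m.weight.detach().abs().flatten()
                             for m in pruned.modules()
                             if isinstance(m, (nn.Conv2d, nn.Linear))])
    threshold = torch.quantile(all_weights, sparsity)
    for m in pruned.modules():
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            m.weight.data.mul_((m.weight.detach().abs() > threshold).float())
    return pruned

@torch.no_grad()
def accuracy(model, loader, device="cuda"):
    model.eval().to(device)
    correct = total = 0
    for x, y in loader:
        correct += (model(x.to(device)).argmax(1) == y.to(device)).sum().item()
        total += y.size(0)
    return correct / total

# Expected pattern: clean accuracy collapses well before the attack
# success rate (ASR) on triggered inputs does.
for sparsity in (0.5, 0.8, 0.9, 0.95):
    pruned = global_magnitude_prune(model, sparsity)
    print(f"sparsity={sparsity:.2f}  "
          f"clean={accuracy(pruned, clean_loader):.3f}  "
          f"ASR={accuracy(pruned, triggered_loader):.3f}")
```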
## Prerequisites

```
pytorch >= 1.4
torchvision
advertorch
```
## Usage

1. Iterative magnitude pruning on CIFAR-10 with ResNet-20 and an RGB trigger:

```
bash script/imp_cifar10_resnet20_color_trigger.sh [data-path]
```

2. Calculate the Trojan score:

```
bash script/linear_mode_cifar10_res20_color_trigger.sh [data-path] [model-path]
```

3. Recover the trigger and run detection (a conceptual sketch of this step follows the list):

```
bash script/reverse_trigger_cifar10_resnet20.sh [data-path] [model-file]
```
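For intuition about step 3, below is a hedged sketch of trigger reverse-engineering in the spirit of approaches such as Neural Cleanse: optimize a small mask and patch so that stamped inputs are classified as the target label by the isolated subnetwork. It is illustrative only; the actual procedure lives in `script/reverse_trigger_cifar10_resnet20.sh`. `subnet`, `loader`, and `target` are assumed given, with the loader yielding images in [0, 1].

```
import torch
import torch.nn.functional as F

def reverse_trigger(subnet, loader, target, steps=500, lr=0.1, device="cuda"):
    subnet.eval().to(device)
    # Learnable soft mask and trigger patch, both full-image sized.
    mask = torch.zeros(1, 1, 32, 32, device=device, requires_grad=True)
    patch = torch.rand(1, 3, 32, 32, device=device, requires_grad=True)
    opt = torch.optim.Adam([mask, patch], lr=lr)
    data = iter(loader)
    for _ in range(steps):
        try:
            x, _ = next(data)
        except StopIteration:
            data = iter(loader)
            x, _ = next(data)
        x = x.to(device)
        m = torch.sigmoid(mask)  # soft mask in [0, 1]
        stamped = (1 - m) * x + m * torch.clamp(patch, 0, 1)
        y = torch.full((x.size(0),), target, device=device, dtype=torch.long)
        # Push stamped inputs toward the target label; L1 penalty keeps
        # the recovered mask small and localized.
        loss = F.cross_entropy(subnet(stamped), y) + 1e-3 * m.sum()
        opt.zero_grad()
        loss.backward()
        opt.step()
    return torch.sigmoid(mask).detach(), torch.clamp(patch, 0, 1).detach()
```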
## Pretrained models

**CIFAR-10, ResNet-20, RGB trigger:** Pretrained_model/cifar10_res20_rgb_trigger

More models are coming soon...
## Citation

```
@article{chen2022quarantine,
  title={Quarantine: Sparsity Can Uncover the Trojan Attack Trigger for Free},
  author={Chen, Tianlong and Zhang, Zhenyu and Zhang, Yihua and Chang, Shiyu and Liu, Sijia and Wang, Zhangyang},
  journal={arXiv preprint arXiv:2205.11819},
  year={2022}
}
```

dataset/.DS_Store (6 KB, binary file not shown)

dataset/__init__.py (whitespace-only changes)

dataset/clean_label_cifar10.py (+138)

@@ -0,0 +1,138 @@
1+
from PIL import Image
2+
from torch.utils import data
3+
from torchvision import transforms
4+
from torchvision.datasets import CIFAR10
5+
import numpy as np
6+
import torch
7+
import random
8+
from dataset.pgd_attack import PgdAttack
9+
10+
11+
class CleanLabelPoisonedCIFAR10(data.Dataset):
12+
13+
def __init__(self, root,
14+
transform=None,
15+
poison_ratio=0.1,
16+
target=0,
17+
patch_size=5,
18+
random_loc=False,
19+
upper_right=True,
20+
bottom_left=False,
21+
augmentation=True,
22+
black_trigger=False,
23+
pgd_alpha: float = 2 / 255,
24+
pgd_eps: float = 8 / 255,
25+
pgd_iter=7,
26+
robust_model=None):
27+
28+
self.root = root
29+
self.poison_ratio = poison_ratio
30+
self.target_label = target
31+
self.patch_size = patch_size
32+
self.random_loc = random_loc
33+
self.upper_right = upper_right
34+
self.bottom_left = bottom_left
35+
self.pgd_alpha = pgd_alpha
36+
self.pgd_eps = pgd_eps
37+
self.pgd_iter = pgd_iter
38+
self.model = robust_model
39+
self.attacker = PgdAttack(self.model, self.pgd_eps, self.pgd_iter, self.pgd_alpha)
40+
41+
if random_loc:
42+
print('Using random location')
43+
if upper_right:
44+
print('Using fixed location of Upper Right')
45+
if bottom_left:
46+
print('Using fixed location of Bottom Left')
47+
48+
# init trigger
49+
trans_trigger = transforms.Compose(
50+
[transforms.Resize((patch_size, patch_size)), transforms.ToTensor()]
51+
)
52+
trigger = Image.open("dataset/triggers/htbd.png").convert("RGB")
53+
if black_trigger:
54+
print('Using black trigger')
55+
trigger = Image.open("dataset/triggers/clbd.png").convert("RGB")
56+
self.trigger = trans_trigger(trigger)
57+
58+
normalize = transforms.Normalize(mean = (0.4914, 0.4822, 0.4465), std = (0.2470, 0.2435, 0.2616))
59+
60+
if pgd_alpha is None:
61+
pgd_alpha = 1.5 * pgd_eps / pgd_iter
62+
self.pgd_alpha: float = pgd_alpha
63+
self.pgd_eps: float = pgd_eps
64+
self.pgd_iter: int = pgd_iter
65+
66+
if augmentation:
67+
self.transform = transforms.Compose([
68+
transforms.ToPILImage(),
69+
transforms.RandomCrop(32, padding=4),
70+
transforms.RandomHorizontalFlip(),
71+
transforms.ToTensor(),
72+
normalize
73+
])
74+
else:
75+
self.transform = transforms.Compose([
76+
transforms.ToPILImage(),
77+
transforms.ToTensor(),
78+
normalize
79+
])
80+
81+
dataset = CIFAR10(root, train=True, download=True)
82+
83+
self.imgs = dataset.data
84+
self.labels = dataset.targets
85+
self.image_size = self.imgs.shape[1]
86+
87+
if self.poison_ratio != 0.0:
88+
self.imgs = torch.tensor(np.transpose(self.imgs, (0, 3, 1, 2)), dtype=torch.float32) / 255.
89+
target_index, other_index = self.separate_img()
90+
self.poison_num = int(len(target_index) * self.poison_ratio)
91+
target_imgs = self.imgs[target_index[:self.poison_num]]
92+
target_imgs = self.attacker(target_imgs, self.target_label * torch.ones(len(target_imgs), dtype=torch.long)) # (N,3,32,32)
93+
target_imgs = self.add_trigger(target_imgs)
94+
self.imgs[target_index[:self.poison_num]] = target_imgs
95+
print('poison images = {}'.format(self.poison_num))
96+
else:
97+
print("Point ratio is zero!")
98+
99+
def __getitem__(self, index):
100+
img = self.transform(self.imgs[index])
101+
return img, self.labels[index]
102+
103+
def __len__(self):
104+
return len(self.imgs)
105+
106+
def separate_img(self):
107+
"""
108+
Collect all the images, which belong to the target class
109+
"""
110+
dataset = CIFAR10(self.root, train=True, download=True)
111+
target_img_index = []
112+
other_img_index = []
113+
all_data = dataset.data
114+
all_label = dataset.targets
115+
for i in range(len(all_data)):
116+
if self.target_label == all_label[i]:
117+
target_img_index.append(i)
118+
else:
119+
other_img_index.append(i)
120+
return torch.tensor(target_img_index), torch.tensor(other_img_index)
121+
122+
def add_trigger(self, img):
123+
124+
if self.random_loc:
125+
start_x = random.randint(0, self.image_size - self.patch_size)
126+
start_y = random.randint(0, self.image_size - self.patch_size)
127+
elif self.upper_right:
128+
start_x = self.image_size - self.patch_size - 3
129+
start_y = self.image_size - self.patch_size - 3
130+
elif self.bottom_left:
131+
start_x = 3
132+
start_y = 3
133+
else:
134+
assert False
135+
136+
img[:, :, start_x: start_x + self.patch_size, start_y: start_y + self.patch_size] = self.trigger
137+
return img
138+
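For completeness, a hypothetical usage sketch of `CleanLabelPoisonedCIFAR10` as a drop-in training set; `robust_model` is assumed to be a pretrained adversarially robust CIFAR-10 classifier required by the clean-label PGD step, and `[data-path]` mirrors the placeholder used in the scripts above.

```
from torch.utils.data import DataLoader
from dataset.clean_label_cifar10 import CleanLabelPoisonedCIFAR10

# robust_model: assumed pretrained robust CIFAR-10 classifier for the PGD step
poisoned_train = CleanLabelPoisonedCIFAR10(
    root='[data-path]', poison_ratio=0.1, target=0, patch_size=5,
    upper_right=True, robust_model=robust_model)
loader = DataLoader(poisoned_train, batch_size=128, shuffle=True, num_workers=4)

for images, labels in loader:
    ...  # train the to-be-poisoned network as usual
```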
