-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
117 lines (89 loc) · 3.63 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import os
import numpy as np
from PIL import Image
from tqdm import tqdm
import gc
import argparse
from scipy.stats import binom
import pandas as pd
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import functional as F
from inr import INR
from transformer import Transformer
from database import Database
from hashes.dinohash import dinohash, preprocess
class ImageDataset(Dataset):
    """Dataset that loads images by file name and returns them preprocessed.

    Each item is opened from the module-level ``dataset_folder``, converted
    to RGB, mirrored left-to-right, optionally passed through ``transform``
    and finally handed to ``preprocess``.
    """

    def __init__(self, image_files, transform=None):
        """Store the file names and the optional per-image callable.

        Args:
            image_files: Sequence of file names, joined onto
                ``dataset_folder`` when an item is loaded.
            transform: Optional callable applied to the PIL image before
                ``preprocess``; skipped when falsy.
        """
        self.transform = transform
        self.image_files = image_files

    def __len__(self):
        """Return the number of file names held by the dataset."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load, mirror, optionally transform and preprocess image ``idx``."""
        # NOTE: ``dataset_folder`` is a module-level global looked up at call
        # time, not an attribute of the dataset.
        path = os.path.join(dataset_folder, self.image_files[idx])
        rgb = Image.open(path).convert("RGB")
        # The horizontal flip is applied unconditionally to every image.
        flipped = rgb.transpose(Image.Transpose.FLIP_LEFT_RIGHT)
        result = self.transform(flipped) if self.transform else flipped
        return preprocess(result)
def combined_transform(image):
    """Apply every enabled transformation to ``image`` in order.

    The set of enabled transformation methods is currently empty, so the
    image is returned unchanged. Any method added to it would be applied
    through the module-level ``t`` object via ``t.transform``.
    """
    enabled_methods = ()  # no transformations are enabled at the moment
    result = image
    for method in enabled_methods:
        result = t.transform(result, method=method)
    return result
def generate_roc(matches, bits, out_path=None):
    """Write a TPR/FPR table over all bit thresholds to a CSV file.

    For every threshold ``tau`` in ``0..bits`` the table holds:

    * ``tpr`` - the fraction of entries of ``matches * bits`` that are
      ``>= tau``;
    * ``fpr`` - ``P(X >= tau)`` for ``X ~ Binomial(bits, 0.5)``, i.e. the
      chance that ``bits`` independent fair coin flips agree in at least
      ``tau`` positions;
    * ``tau`` - the threshold itself.

    Args:
        matches: NumPy array of per-sample similarity scores. It is scaled
            by ``bits`` before thresholding (presumably a fraction of
            matching bits in [0, 1] - TODO confirm against
            ``Database.similarity_score``).
        bits: Number of bits in a hash.
        out_path: Destination CSV path. Defaults to
            ``./results/{hasher.__name__}_{transformation}.csv`` built from
            the module-level ``hasher`` and ``transformation``.
    """
    if out_path is None:
        out_path = f"./results/{hasher.__name__}_{transformation}.csv"

    matches = matches * bits
    taus = np.arange(bits + 1)
    tpr = [(matches >= tau).mean() for tau in taus]
    # Use the survival function instead of ``1 - cdf``: in the upper tail the
    # probabilities are far below float epsilon and the subtraction would
    # cancel to exactly 0.
    fpr = binom.sf(taus - 1, bits, 0.5)

    df = pd.DataFrame({
        "tpr": tpr,
        "fpr": fpr,
        "tau": taus,
    })

    # Make sure the target directory exists so a fresh checkout does not fail.
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    df.to_csv(out_path)
# ---------------------------------------------------------------------------
# Benchmark script: hash the images in ``dataset_folder``, compare the hashes
# against a stored database and write accuracy / ROC results to ``./results``.
# ---------------------------------------------------------------------------

hasher = dinohash
dataset_folder = './adversarial_data/train/adv'
# Deterministic order, limited to the first 1000 file names.
image_files = sorted(os.listdir(dataset_folder))[:1_000]

BATCH_SIZE = 128
N_IMAGE_RETRIEVAL = 1

parser = argparse.ArgumentParser(description='Perform retrieval benchmarking.')
parser.add_argument('-r', '--refresh', action='store_true')
parser.add_argument('--defense', dest='defense', type=str, default=None,
                    help='path to defense model')
parser.add_argument('--transform')
args = parser.parse_args()

# ``transformation`` is only used to label the output files below.
transformation = args.transform
t = Transformer()

dataset = ImageDataset(image_files, transform=combined_transform)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

if args.defense:
    defense = INR(device='cuda', pretrain_inr_path=args.defense)

os.makedirs("databases", exist_ok=True)

# Rebuild the database when its ``.npy`` file is missing or ``--refresh`` is
# given; otherwise reuse the stored one (presumably ``Database(None, ...)``
# loads it from ``storedir`` - TODO confirm).
if hasher.__name__ + ".npy" not in os.listdir("databases") or args.refresh:
    print("Creating database for", hasher.__name__)
    original_hashes = []
    for image_batch in tqdm(dataloader):
        original_hashes.extend(hasher(image_batch).cpu())
        gc.collect()
    db = Database(original_hashes, storedir=f"databases/{hasher.__name__}")
else:
    db = Database(None, storedir=f"databases/{hasher.__name__}")

print(f"Computing bit accuracy for {transformation} + {hasher.__name__}...")

modified_hashes = []
for transformed_images in tqdm(dataloader):
    transformed_images = transformed_images.cuda()
    if args.defense:
        transformed_images = defense.forward(transformed_images)
    modified_hashes_batch = hasher(transformed_images).tolist()
    modified_hashes.extend(modified_hashes_batch)

modified_hashes = np.array(modified_hashes)
bits = modified_hashes.shape[-1]

matches = db.similarity_score(modified_hashes)
# Reversing the query order serves as the "random" baseline (presumably
# ``similarity_score`` pairs query i with database entry i - TODO confirm).
inv_matches = db.similarity_score(modified_hashes[::-1])

print(matches.mean(), matches.std())
print(inv_matches.mean(), inv_matches.std())

# The results directory is not part of the checkout; create it so the writes
# below (and the CSV written by generate_roc) do not fail.
os.makedirs("results", exist_ok=True)
with open(f"./results/{hasher.__name__}_{transformation}.txt", "w") as f:
    f.write(f"Bit accuracy: {matches.mean()} / {matches.std()}\n")
    f.write(f"Random accuracy: {inv_matches.mean()} / {inv_matches.std()}\n")

generate_roc(matches, bits=bits)