-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain2.py
107 lines (81 loc) · 3.27 KB
/
main2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import os
import numpy as np
from PIL import Image
from tqdm import tqdm
import gc
import argparse
from scipy.stats import binom
import pandas as pd
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import functional as F
from transformer import Transformer
from database import Database
from hashes.neuralhash import neuralhash, preprocess
class ImageDataset(Dataset):
    """Dataset that loads images by file name and returns them preprocessed.

    File names are resolved relative to the module-level ``dataset_folder``.
    Every image is opened with PIL, converted to RGB, optionally passed
    through ``transform`` and finally through ``preprocess``.
    """

    def __init__(self, image_files, transform=None):
        """Store the file names and the optional per-image transform."""
        self.transform = transform
        self.image_files = image_files

    def __len__(self):
        """Return the number of image files in the dataset."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load image ``idx`` and return the output of ``preprocess`` for it."""
        path = os.path.join(dataset_folder, self.image_files[idx])
        rgb = Image.open(path).convert("RGB")
        # Without a transform the RGB image goes straight to preprocessing.
        if not self.transform:
            return preprocess(rgb)
        return preprocess(self.transform(rgb))
def combined_transform(image):
    """Apply the fixed transformation chain to ``image`` and return the result.

    The chain is, in order: ``"screenshot"``, the module-level
    ``transformation`` selected on the command line, ``"erase"`` and
    ``"text"``. Each step is delegated to the module-level ``Transformer``
    instance ``t``.
    """
    pipeline = ("screenshot", transformation, "erase", "text")
    result = image
    for method_name in pipeline:
        result = t.transform(result, method=method_name)
    return result
def generate_roc(matches, bits, out_path=None):
    """Build a ROC table over all bit thresholds and write it to CSV.

    For every threshold ``tau`` in ``0..bits``:

    * ``tpr`` is the fraction of entries of ``matches * bits`` that are
      ``>= tau``;
    * ``fpr`` is ``P(X >= tau)`` for ``X ~ Binomial(bits, 0.5)``, i.e. the
      chance that a hash with independent, uniformly random bits reaches
      the threshold.

    Args:
        matches: Array-like of per-image similarity scores. They are
            multiplied by ``bits`` before thresholding, so this presumably
            expects fractions of matching bits in ``[0, 1]`` -- confirm
            against ``Database.similarity_score``.
        bits: Number of bits in the hash.
        out_path: Destination CSV file. When ``None`` (the default) the
            table is written to
            ``./results/{hasher.__name__}_{transformation}.csv`` using the
            module-level ``hasher`` and ``transformation``.

    Returns:
        The ``pandas.DataFrame`` with columns ``tpr``, ``fpr`` and ``tau``
        that was written to disk.
    """
    if out_path is None:
        out_path = f"./results/{hasher.__name__}_{transformation}.csv"

    scaled = np.asarray(matches) * bits
    taus = np.arange(bits + 1)
    tpr = [(scaled >= tau).mean() for tau in taus]
    # Use the survival function rather than ``1 - cdf``: for large ``tau``
    # the CDF rounds to 1.0 and the subtraction collapses to exactly 0,
    # wiping out the low-FPR end of the curve.
    fpr = binom.sf(taus - 1, bits, 0.5)

    df = pd.DataFrame({
        "tpr": tpr,
        "fpr": fpr,
        "tau": taus,
    })

    # Create the output directory so the write cannot fail on a fresh checkout.
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    df.to_csv(out_path)
    return df
# ---------------------------------------------------------------------------
# Module-level configuration.
#
# These names are read as globals by ImageDataset, combined_transform and
# generate_roc, so they must stay at module level (outside the __main__
# guard) where DataLoader worker processes that re-import this module can
# also see them.
# ---------------------------------------------------------------------------
hasher = neuralhash
dataset_folder = './diffusion_data'
BATCH_SIZE = 64
N_IMAGE_RETRIEVAL = 1

parser = argparse.ArgumentParser(description ='Perform retrieval benchmarking.')
parser.add_argument('-r', '--refresh', action='store_true')
parser.add_argument('--transform')
args = parser.parse_args()
transformation = args.transform
t = Transformer()

# The benchmark starts DataLoader worker processes; guarding the entry point
# keeps those workers from re-running the benchmark when the start method
# re-imports the main module (spawn-based platforms).
if __name__ == "__main__":
    # Sorted listing keeps dataset order deterministic; capped at 1M files.
    image_files = sorted(os.listdir(dataset_folder))[:1_000_000]

    # Create both output directories up front so a missing folder cannot
    # abort the run after all hashes have already been computed.
    os.makedirs("databases", exist_ok=True)
    os.makedirs("./results", exist_ok=True)

    db_dir = f"databases/{hasher.__name__}"
    db_file = os.path.join("databases", hasher.__name__ + ".npy")

    if not os.path.exists(db_file) or args.refresh:
        print("Creating database for", hasher.__name__)
        # The reference database must hold hashes of the UNtransformed
        # images. The previous code iterated the transformed loader here,
        # which stored hashes of already-transformed images as "originals".
        reference_loader = DataLoader(
            ImageDataset(image_files),
            batch_size=BATCH_SIZE,
            shuffle=False,
            num_workers=4,
        )
        original_hashes = []
        for image_batch in tqdm(reference_loader):
            original_hashes.extend(hasher(image_batch).cpu())
            gc.collect()
        db = Database(original_hashes, storedir=db_dir)
    else:
        # Passing None presumably loads the existing database from storedir
        # -- confirm against Database.
        db = Database(None, storedir=db_dir)

    print(f"Computing bit accuracy for {transformation} + {hasher.__name__}...")

    # shuffle=False keeps the transformed hashes in the same order as
    # image_files.
    dataset = ImageDataset(image_files, transform=combined_transform)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

    modified_hashes = []
    for transformed_images in tqdm(dataloader):
        modified_hashes.extend(hasher(transformed_images).tolist())
    modified_hashes = np.array(modified_hashes)
    bits = modified_hashes.shape[-1]

    matches = db.similarity_score(modified_hashes)
    # Baseline: the same hashes in reversed order, reported as "Random
    # accuracy". NOTE(review): this presumably relies on similarity_score
    # comparing row i with database entry i -- confirm; with an odd number of
    # images the middle row would still be paired with itself.
    inv_matches = db.similarity_score(modified_hashes[::-1])

    print(matches.mean(), matches.std())
    print(inv_matches.mean(), inv_matches.std())

    with open(f"./results/{hasher.__name__}_{transformation}.txt", "w") as f:
        f.write(f"Bit accuracy: {matches.mean()} / {matches.std()}\n")
        f.write(f"Random accuracy: {inv_matches.mean()} / {inv_matches.std()}\n")

    generate_roc(matches, bits=bits)