-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
83 lines (66 loc) · 2.73 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import torch
from torchvision import transforms
from PIL import Image
import os
import logging
from tqdm import tqdm
from src.denoising_network import DenoisingNetwork
import cProfile
import pstats
from pstats import SortKey
# Process-wide logging: INFO and above, each record prefixed with time and level.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)

# Target size passed to transforms.Resize for every loaded image
# (torchvision reads a sequence as (height, width)).
IMAGE_SIZE = (256, 256)
# Directory scanned (non-recursively) for the test images.
TEST_IMAGES_PATH = 'data/test_images/'
# Saved state_dict restored into DenoisingNetwork before inference.
MODEL_PATH = 'denoising_network.pth'
def load_image(path):
    """Load one image from disk and turn it into a model-ready tensor.

    The image is converted to RGB, resized to ``IMAGE_SIZE`` and passed
    through ``transforms.ToTensor()`` (per torchvision: a float tensor of
    shape 3 x H x W with values scaled to [0, 1]).

    Args:
        path: Filesystem path of the image file.

    Returns:
        The preprocessed tensor, or ``None`` if the file could not be opened,
        decoded or transformed. Such failures are logged (with traceback)
        rather than raised so one bad file does not abort a directory scan.
    """
    # The pipeline depends only on IMAGE_SIZE, not on the file, so build it
    # outside the try: a misconfiguration now raises instead of being logged
    # and swallowed once per image.
    transform = transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
    ])
    try:
        # convert() forces the pixel data to be decoded while the file
        # handle is still open.
        with Image.open(path) as img:
            return transform(img.convert('RGB'))
    except Exception as e:  # deliberate best-effort boundary -> None
        # logging.exception logs at ERROR level like before, but also records
        # the traceback; the message text itself is unchanged.
        logging.exception("Error processing %s: %s", path, e)
        return None
def load_images_from_directory(directory):
    """Load every PNG/JPEG image directly inside ``directory``.

    Filenames are visited in sorted order so repeated runs process and log
    images in the same order (``os.listdir`` returns entries in arbitrary
    order). Files for which ``load_image`` returns ``None`` are skipped.

    The label comes from the filename: names starting with ``'real'`` get
    label 1 (treated as "real" by ``main``), everything else gets 0.
    NOTE(review): the prefix check is case-sensitive while the extension
    check is not (``Real_01.png`` is labelled 0) -- confirm this is intended.

    Args:
        directory: Path of the directory to scan (subdirectories are not
            descended into).

    Returns:
        A list of ``(tensor, filename, label)`` tuples.
    """
    image_extensions = ('.png', '.jpg', '.jpeg')
    images = []
    for filename in sorted(os.listdir(directory)):
        if not filename.lower().endswith(image_extensions):
            continue
        img = load_image(os.path.join(directory, filename))
        if img is None:
            continue  # load_image already logged the failure
        label = 1 if filename.startswith('real') else 0
        images.append((img, filename, label))
    return images
def process_output(output):
    """Reduce a model output tensor to a binary real/generated decision.

    Args:
        output: Tensor produced by the model; all of its elements are averaged.

    Returns:
        ``(prediction, score)``: ``score`` is the mean of ``output`` as a
        Python float, and ``prediction`` is 1 ("real") when that score is
        strictly greater than 0.5, otherwise 0 ("generated").
    """
    score = output.mean().item()
    prediction = int(score > 0.5)
    return prediction, score
def main():
    """Classify every test image with the trained denoising network.

    Picks CUDA when available (else CPU), restores the weights stored at
    ``MODEL_PATH`` into a fresh ``DenoisingNetwork``, loads the images under
    ``TEST_IMAGES_PATH`` and, with autograd disabled, logs for each image the
    predicted class, its score and the label derived from its filename.
    """
    logging.info("Starting main function")

    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logging.info(f"Using device: {target_device}")

    # map_location remaps the stored tensors onto target_device, so a
    # checkpoint saved on a GPU still loads on a CPU-only machine.
    # NOTE(review): depending on the torch version, torch.load may unpickle
    # arbitrary objects -- only point MODEL_PATH at trusted checkpoints.
    network = DenoisingNetwork().to(target_device)
    network.load_state_dict(torch.load(MODEL_PATH, map_location=target_device))
    # Evaluation behaviour (relevant for dropout/batch-norm, if the network
    # has any).
    network.eval()
    logging.info("Model loaded successfully")

    samples = load_images_from_directory(TEST_IMAGES_PATH)
    logging.info(f"Number of test images loaded: {len(samples)}")

    def class_name(label):
        # 1 means "real"; any other value is reported as "generated".
        return 'real' if label == 1 else 'generated'

    with torch.no_grad():  # pure inference: no gradients needed
        for tensor, name, actual in tqdm(samples, desc="Processing images"):
            # unsqueeze(0) prepends a size-1 dimension: a batch of one image.
            batch = tensor.unsqueeze(0).to(target_device)
            predicted, score = process_output(network(batch))
            logging.info(
                f"Image: {name}, Prediction: {class_name(predicted)}, "
                f"Score: {score:.4f}, True Label: {class_name(actual)}"
            )

    logging.info("Main function completed")
def profile_main(output_path='depixel_profile.prof', target=None):
    """Run ``main()`` (or ``target``) under cProfile and save the raw profile.

    The profile is written with ``pstats.Stats.dump_stats`` even when the
    profiled call raises (including KeyboardInterrupt), so an aborted or
    crashing run can still be inspected; the original exception then
    propagates unchanged. ``dump_stats`` stores unsorted statistics, so
    choose the sort order when reading the file back, e.g.
    ``pstats.Stats(path).sort_stats(SortKey.TIME).print_stats(20)``.

    Args:
        output_path: File the profile is written to. Defaults to
            ``'depixel_profile.prof'``, relative to the current working
            directory.
        target: Zero-argument callable to profile; ``None`` profiles ``main``.
    """
    if target is None:
        # Resolved at call time rather than as a default argument.
        target = main
    profiler = cProfile.Profile()
    try:
        with profiler:
            target()
    finally:
        pstats.Stats(profiler).dump_stats(output_path)


if __name__ == "__main__":
    # Running the file as a script always profiles main() and writes the
    # profile to 'depixel_profile.prof'.
    profile_main()