-
Notifications
You must be signed in to change notification settings - Fork 0
/
image_utils.py
116 lines (86 loc) · 3.17 KB
/
image_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import logging
import os
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from tqdm import tqdm
def check_is_dir(path):
    """Validate that *path* names an existing directory.

    Returns:
        True when the check passes.

    Raises:
        ValueError: if *path* is not a directory.
    """
    if os.path.isdir(path):
        return True
    raise ValueError(f"Provided path: {path} is not a directory")
def filter_images(
    list_of_files,
    valid_extensions=(".jpg", ".png", ".jpeg", ".webp"),
):
    """Return the entries of *list_of_files* whose name has an image extension.

    Matching is case-insensitive, so names such as "IMG_0001.JPG" or "x.PNG"
    are kept. Input order is preserved.

    Args:
        list_of_files: iterable of file names or paths (strings).
        valid_extensions: extensions to accept, each including the leading dot.

    Returns:
        A new list containing the matching entries.
    """
    # Normalise once; str.endswith accepts a tuple, so one call tests them all.
    suffixes = tuple(ext.lower() for ext in valid_extensions)
    return [name for name in list_of_files if name.lower().endswith(suffixes)]
def load_image(image_path):
    """Open the image at *image_path* and return it as an RGB PIL image.

    Args:
        image_path: path of the image file to read.

    Returns:
        A new RGB ``PIL.Image.Image`` that does not depend on the file handle.
    """
    # Image.open is lazy and may keep the file open (e.g. multi-frame formats
    # such as animated WebP); the context manager guarantees it is closed.
    # convert() returns a loaded copy, so the result stays usable after close.
    with Image.open(image_path) as image:
        return image.convert("RGB")
def read_images_from_dir(dir_path):
    """Load every image file found directly inside *dir_path*.

    The file list comes from get_images_from_dir, which validates that
    *dir_path* is a directory and filters entries by image extension. Each
    file is opened with load_image, so the results are RGB PIL images.

    Returns:
        List of loaded images.

    Raises:
        ValueError: if *dir_path* is not a directory.
    """
    # Reuse the shared listing helper instead of duplicating its logic here.
    image_paths = get_images_from_dir(dir_path)
    images = [load_image(image_path) for image_path in tqdm(image_paths)]
    logging.info("Loaded %d images from %s", len(images), dir_path)
    return images
def get_images_from_dir(dir_path):
    """Return the paths of the image files located directly inside *dir_path*.

    Entries are filtered by extension via filter_images and restricted to
    regular files, so a sub-directory named like "x.jpg" is not returned.
    They are sorted by name because os.listdir returns entries in arbitrary
    order.

    Raises:
        ValueError: if *dir_path* is not a directory.
    """
    check_is_dir(dir_path)
    image_files = filter_images(sorted(os.listdir(dir_path)))
    image_paths = [os.path.join(dir_path, name) for name in image_files]
    return [path for path in image_paths if os.path.isfile(path)]
def max_resolution_rescale(image, max_width, max_height):
    """Downscale *image* to fit within max_width x max_height, keeping aspect ratio.

    Images already within both bounds are returned unchanged (the same
    object). Otherwise the image is resized with Lanczos resampling by the
    largest factor that satisfies both bounds.

    Args:
        image: PIL image to rescale.
        max_width, max_height: maximum allowed size in pixels.

    Returns:
        The original image or a resized copy.
    """
    width, height = image.size
    if width <= max_width and height <= max_height:
        return image
    ratio = min(max_width / width, max_height / height)
    # Clamp to 1 px: for extreme aspect ratios the truncated size can be 0,
    # which Pillow's resize rejects.
    new_size = (max(1, int(width * ratio)), max(1, int(height * ratio)))
    return image.resize(new_size, Image.LANCZOS)
def min_resolution_filter(image, min_width, min_height):
    """Tell whether *image* is at least min_width x min_height pixels.

    Args:
        image: object exposing a ``(width, height)`` ``size`` attribute.
        min_width, min_height: minimum accepted dimensions.

    Returns:
        True if both dimensions meet their minimum, else False.
    """
    w, h = image.size
    wide_enough = w >= min_width
    if not wide_enough:
        return wide_enough
    return h >= min_height
def plot_image(image):
    """Show *image* in a matplotlib figure with axes hidden.

    Args:
        image: anything ``plt.imshow`` accepts (e.g. a PIL image or array).
    """
    plt.imshow(image)
    # Hide ticks and frame; only the picture itself is of interest.
    plt.axis("off")
    plt.show()
def center_crop(image, new_width, new_height):
    """Crop a new_width x new_height region from the centre of *image*.

    Args:
        image: object exposing ``size`` and ``crop(box)`` (e.g. a PIL image).
        new_width, new_height: size of the crop in pixels.

    Returns:
        Whatever ``image.crop`` returns for the centred box.
    """
    width, height = image.size
    # Integer offsets: a float box is rounded by Pillow using round-half-to-even,
    # so e.g. cropping 12 -> 9 used to produce an 8-pixel-wide result.
    left = (width - new_width) // 2
    top = (height - new_height) // 2
    right = left + new_width
    bottom = top + new_height
    cropped_image = image.crop((left, top, right, bottom))
    logging.info("Center cropped image to %sx%s", new_width, new_height)
    return cropped_image
def save_image(image, save_path):
    """Write *image* to *save_path* as an RGB picture.

    Numpy arrays are first wrapped with ``Image.fromarray``; any image whose
    mode is not RGB is converted to RGB before saving.

    Args:
        image: numpy array or PIL image.
        save_path: destination file path; its extension selects the format.

    Raises:
        ValueError: if *image* is neither a numpy array nor a PIL image.
    """
    pil_image = Image.fromarray(image) if isinstance(image, np.ndarray) else image
    if not isinstance(pil_image, Image.Image):
        raise ValueError("Input image must be a numpy array or PIL Image")
    if pil_image.mode != "RGB":
        pil_image = pil_image.convert("RGB")
    pil_image.save(save_path)
    logging.info(f"Saved image to {save_path}")
def create_directory(dir_path):
    """Create *dir_path* (including parents) unless it already exists.

    Prints whether the directory was created or already present.

    Raises:
        FileExistsError: if *dir_path* exists but is not a directory.
    """
    # EAFP: attempting the creation avoids the race between checking isdir()
    # and calling makedirs() when another process creates the directory.
    try:
        os.makedirs(dir_path)
    except FileExistsError:
        if not os.path.isdir(dir_path):
            raise
        print(f"Directory already exists: {dir_path}")
    else:
        print(f"Directory created: {dir_path}")
def save_images_to_dir(images, dir_path):
    """Save each image of *images* into *dir_path* as image_1.jpg, image_2.jpg, ...

    The directory is created first if it does not exist.

    Returns:
        True once every image has been written.
    """
    create_directory(dir_path)
    check_is_dir(dir_path)
    # Wrap the iterable itself rather than enumerate(...): enumerate has no
    # len(), so tqdm could not show a total for the progress bar.
    for index, image in enumerate(tqdm(images), start=1):
        save_path = os.path.join(dir_path, f"image_{index}.jpg")
        save_image(image, save_path)
    return True
# NOTE(review): get_images_from_dir is also defined earlier in this module.
# This later definition rebinds the name and is the one callers get; keep the
# two in sync or delete one of them.
def get_images_from_dir(dir_path):
    """Return the paths of the image files located directly inside *dir_path*.

    Entries are filtered by extension via filter_images and restricted to
    regular files, so a sub-directory named like "x.jpg" is not returned.
    They are sorted by name because os.listdir returns entries in arbitrary
    order.

    Raises:
        ValueError: if *dir_path* is not a directory.
    """
    check_is_dir(dir_path)
    image_files = filter_images(sorted(os.listdir(dir_path)))
    image_paths = [os.path.join(dir_path, name) for name in image_files]
    return [path for path in image_paths if os.path.isfile(path)]