-
Notifications
You must be signed in to change notification settings - Fork 0
/
wasserstein_distance.py
41 lines (29 loc) · 1.36 KB
/
wasserstein_distance.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
#!/usr/bin/env python3
from PIL import Image
import numpy as np
from math import sqrt
import argparse
def wasserstein_distance(original: np.ndarray, compressed: np.ndarray) -> float:
    """Return the 1-D Wasserstein-1 (earth mover's) distance between two arrays.

    Both arrays are flattened and their values treated as equally weighted
    samples of a 1-D distribution. For two samples of equal size, W1 is the
    mean absolute difference between their sorted values, which is what is
    computed here. Spatial arrangement is ignored: any permutation of the
    same values yields a distance of 0.

    Args:
        original: Reference values (e.g. pixel channel values of an image).
        compressed: Values to compare against ``original``; must have the
            same shape.

    Returns:
        The mean absolute difference of the sorted values, as a float.
        Returns 0.0 when both arrays are empty.

    Raises:
        ValueError: If the two arrays have different shapes.
    """
    original = np.asarray(original)
    compressed = np.asarray(compressed)
    if original.shape != compressed.shape:
        raise ValueError("Original and compressed array have different shapes")
    if original.size == 0:
        # Avoid a 0/0 -> nan division for empty inputs.
        return 0.0

    # np.sort with axis=None flattens before sorting. Cast to float64 before
    # subtracting: image data is typically uint8, and unsigned subtraction
    # wraps around (e.g. 5 - 10 == 251) instead of going negative, which
    # would silently inflate the distance.
    sorted_original = np.sort(original, axis=None).astype(np.float64)
    sorted_compressed = np.sort(compressed, axis=None).astype(np.float64)

    cumulative_difference = np.sum(np.abs(sorted_original - sorted_compressed))
    return float(cumulative_difference / original.size)
def main(argv=None):
    """Command-line entry point: print the Wasserstein distance between two images.

    Both images are converted to RGB and compared with ``wasserstein_distance``
    over all of their channel values.

    Args:
        argv: Optional list of command-line arguments. When None, argparse
            reads ``sys.argv[1:]``, so existing ``main()`` calls are unchanged.
    """
    parser = argparse.ArgumentParser(
        prog="wasserstein_distance",
        description="Compute the Wasserstein distance between two images.",
    )
    # Both paths are mandatory; without them Image.open(None) would fail with
    # an obscure error instead of a usage message.
    parser.add_argument("-o", "--original", required=True, help="original image")
    parser.add_argument("-c", "--compressed", required=True, help="compressed image")
    args = parser.parse_args(argv)

    print("Original image = [{}]".format(args.original))
    print("Compressed image = [{}]".format(args.compressed))

    # Context managers close the underlying file handles once pixel data
    # has been copied into numpy arrays.
    with Image.open(args.original) as img:
        original_image = np.array(img.convert("RGB"))
    with Image.open(args.compressed) as img:
        compressed_image = np.array(img.convert("RGB"))

    wasserstein_value = wasserstein_distance(
        original=original_image, compressed=compressed_image
    )
    print("Wasserstein distance = [{:.2f}]".format(wasserstein_value))
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()