-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdraw_original_reconstruction.py
executable file
·67 lines (57 loc) · 2.09 KB
/
draw_original_reconstruction.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import os
import numpy as np
from matplotlib import pyplot as plt
# Process-wide default size (width, height in inches) for every figure created
# after this module is imported.
plt.rcParams.update({"figure.figsize": (10, 8)})
def _color_settings(data, decoded_imgs, dataset):
    """Return ``(vmin, vmax, cmap)`` shared by the original and reconstructed panels.

    The range starts as the joint min/max of both arrays and is then overridden
    for known datasets: ``"flow"`` sets ``vmax`` to 70, ``"droplet"`` uses the
    fixed range ``[-1.5, 1.5]`` and a gray colormap. Both ranges are printed.
    """
    vmax = max(data.max(), decoded_imgs.max())
    vmin = min(data.min(), decoded_imgs.min())
    print('range of data:', vmin, vmax)
    cmap = 'viridis'
    if dataset == "flow":
        vmax = 70
    elif dataset == "droplet":
        vmin, vmax = -1.5, 1.5
        cmap = 'gray'
    print('range of colormap:', vmin, vmax)
    return vmin, vmax, cmap


def _frame_indices(n_frames, columns, shift):
    """Return the sample indices to plot, one per column.

    Selects ``shift + 1 .. shift + columns`` whenever enough samples exist;
    otherwise the window is narrowed/moved so it stays inside ``[0, n_frames)``
    instead of raising IndexError.
    """
    columns = min(columns, n_frames)
    start = max(0, min(shift + 1, n_frames - columns))
    return range(start, start + columns)


def _frame(images, idx, temporal):
    """Return the 2-D image for sample ``idx`` (channel 0; time step 1 if ``temporal``)."""
    return images[idx, 1, :, :, 0] if temporal else images[idx, :, :, 0]


def draw_orig_reconstr(data, decoded_imgs, title, dir_res_model, dataset,
                       temporal=False, shift=20, columns=12):
    """Plot original frames above their reconstructions and save the figure.

    Draws a 2 x ``columns`` grid: the top row holds frames from ``data``, the
    bottom row the matching frames from ``decoded_imgs``, all with one shared
    colour scale. The figure is saved to ``<dir_res_model>/orig_reconstr.png``
    (300 dpi, tight bounding box), then shown with ``plt.show()`` and closed.

    Args:
        data: original samples, indexed as ``data[k, :, :, 0]`` (or
            ``data[k, 1, :, :, 0]`` when ``temporal``) -- presumably
            (N, H, W, C) / (N, T, H, W, C); TODO confirm with callers.
        decoded_imgs: reconstructions, indexed the same way as ``data``.
        title: prefix of the figure title; ``'original and reconstructed
            frames'`` is appended directly, without a separator.
        dir_res_model: existing directory the PNG is written into.
        dataset: dataset name; ``"flow"`` and ``"droplet"`` get fixed colour
            ranges (``"droplet"`` also a gray colormap).
        temporal: if True, plot time step 1 of each sample.
        shift: window offset; sample indices ``shift+1 .. shift+columns`` are
            plotted (clamped to the available samples). Default 20 was chosen
            "just to show nice examples".
        columns: number of original/reconstruction pairs to plot.
    """
    vmin, vmax, cmap = _color_settings(data, decoded_imgs, dataset)
    indices = _frame_indices(min(len(data), len(decoded_imgs)), columns, shift)
    columns = len(indices)
    rows = 2

    fig = plt.figure()
    try:
        for col, idx in enumerate(indices, start=1):
            # Row 0: original, row 1: reconstruction of the same sample.
            for row, images in enumerate((data, decoded_imgs)):
                ax = fig.add_subplot(rows, columns, col + row * columns)
                ax.axis('off')
                ax.imshow(_frame(images, idx, temporal),
                          cmap=cmap, vmin=vmin, vmax=vmax)
        fig.suptitle(title + 'original and reconstructed frames', fontsize=15)

        # Resize first, then lay out, then save -- all on `fig` explicitly.
        # Calling plt.show() before plt.tight_layout() made the layout act on
        # pyplot's current figure, which is a new blank one after the window
        # is closed on an interactive backend.
        fig.set_size_inches(12, 9)
        fig.tight_layout()
        # Single save with both options; two saves to the same path meant the
        # second silently overwrote the first.
        fig.savefig(os.path.join(dir_res_model, 'orig_reconstr.png'),
                    bbox_inches='tight', dpi=300)
        plt.show()
    finally:
        # Release the figure even if saving fails.
        plt.close(fig)