-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy patheval_super_resolution.py
144 lines (123 loc) · 4.62 KB
/
eval_super_resolution.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
"""
Evaluation script for the super-resolution experiment with the Stanford bunny.
Loads the training mesh from the data_generation folder and performs mesh
subdivision using PyMesh.
Automatically aligns the eigenvectors using the KL divergence of the histograms.
Arguments:
- checkpoint: path to a Pytorch Lightning checkpoint file
Note: requires the --dataset_dir flag to be specified as well.
"""
from argparse import ArgumentParser
import matplotlib.pyplot as plt
import numpy as np
import pymesh
import pytorch_lightning as pl
import torch
from scipy.spatial.transform import Rotation as R
from sklearn.metrics import r2_score
from src.data.graph_dataset import GraphDataset
from src.models.graph_inr import GraphINR
from src.plotting.figures import draw_mesh, draw_pc
from src.utils.data_generation import get_fourier, load_mesh, mesh_to_graph
from src.utils.eigenvectors import align_eigenvectors_kl
from src.utils.get_predictions import get_batched_predictions
# Read arguments
parser = ArgumentParser()
parser.add_argument("checkpoint", type=str)
parser = pl.Trainer.add_argparse_args(parser)
parser = GraphINR.add_model_specific_args(parser)
parser = GraphDataset.add_dataset_specific_args(parser)
args = parser.parse_args()

# Shared plotting settings: a fixed rotation of the bunny plus two camera eyes,
# an overview one and a zoomed-in one (lower zoom is more zoomed).
rot = R.from_euler("xyz", [90, 0, 145], degrees=True).as_matrix()
zoom = 0.6
default_eye = dict(x=1.1, y=1.1, z=0.2)
zoomed_eye = dict(x=zoom, y=zoom, z=0.2)


def _predict(model, u):
    """Return the model's predictions for the node inputs ``u``.

    ``u`` is a NumPy array with one row per node. It is converted to a float32
    tensor and passed to ``get_batched_predictions``, whose second return value
    (the predictions) is returned.
    """
    inputs = torch.from_numpy(u).float()
    _, pred = get_batched_predictions(model, inputs, 0)
    return pred


def _show_mesh(mesh, html_path=None, **draw_kwargs):
    """Draw ``mesh`` with the shared rotation and overview camera, then show it.

    Extra keyword arguments (``intensity``, ``colorscale``, ``cmin``, ...) are
    forwarded to ``draw_mesh``. When ``html_path`` is given, the figure is also
    written to that HTML file before being shown. Returns the figure.
    """
    fig = draw_mesh(mesh, rot=rot, **draw_kwargs)
    fig.update_layout(scene_camera=dict(eye=default_eye))
    if html_path is not None:
        fig.write_html(html_path)
    fig.show()
    return fig


def _show_zoomed_point_cloud(mesh, values, html_path):
    """Overlay per-vertex ``values`` as a point cloud on a black ``mesh``.

    Uses the zoomed-in camera, writes the figure to ``html_path`` and shows it.
    """
    fig = draw_mesh(mesh, rot=rot, color="black")
    fig.update_layout(scene_camera=dict(eye=zoomed_eye))
    # Vertices are scaled by 1.001, presumably so the points sit just outside
    # the surface and are not hidden by it.
    pc_trace = draw_pc(
        mesh.vertices * 1.001,
        color=values,
        colorscale="Reds",
        rot=rot,
        marker_size=1.5,
    ).data[0]
    fig.add_trace(pc_trace)
    fig.write_html(html_path)
    fig.show()


# Data
dataset = GraphDataset(**vars(args))
mesh_train = load_mesh("data_generation/bunny/reconstruction/bun_zipper.ply")
u_train = dataset.get_inputs(0).numpy()
y_train = dataset.get_target(0).numpy()
n_train = u_train.shape[0]

# Plot training signal
_show_mesh(mesh_train, intensity=y_train[:, 0], colorscale="Reds")

# Model. Switch to inference mode, as recommended after load_from_checkpoint,
# so that any dropout / batch-norm layers behave deterministically.
model = GraphINR.load_from_checkpoint(args.checkpoint)
model.eval()

# Predictions on the training mesh (computed once and reused below)
pred_train = _predict(model, u_train)
_show_mesh(mesh_train, intensity=pred_train, colorscale="Reds")

# Get test data and align eigenvectors to training ones
mesh_test = pymesh.subdivide(mesh_train, order=1)
_, adj_test = mesh_to_graph(mesh_test)
u_test = get_fourier(adj_test, k=args.n_fourier)
u_test = align_eigenvectors_kl(u_train, u_test)

# Predictions on the subdivided mesh (computed once and reused below), plotted
# on the same color range as the training signal.
pred_test = _predict(model, u_test)
_show_mesh(
    mesh_test,
    intensity=pred_test,
    colorscale="Reds",
    cmin=y_train.min(),
    cmax=y_train.max(),
)

# Plot zoomed-in point clouds (take screenshots here!)
_show_zoomed_point_cloud(
    mesh_train, pred_train[:, 0], "super_resolution_bunny_original.html"
)
_show_zoomed_point_cloud(
    mesh_test, pred_test[:, 0], "super_resolution_bunny_superresolved.html"
)

# Evaluate the super-resolved predictions on the original vertices. This takes
# the first n_train rows of the test predictions, i.e. it relies on
# pymesh.subdivide keeping the original vertices first and in the same order.
# NOTE(review): confirm this holds for the subdivision method in use.
pred_on_train = pred_test[:n_train]

# Squared error per node (not averaged)
squared_error = (y_train - pred_on_train) ** 2

# Plot error per node
_show_mesh(
    mesh_train,
    intensity=squared_error,
    colorscale="Reds",
    html_path="super_resolution_error.html",
)

# Compute r2
score = r2_score(y_train, pred_on_train)
print(f"R2 score: {score}")

# Compute r2 after discarding the worst nodes, i.e. those whose squared error is
# at or above the q-th percentile.
for q in (90, 95):
    mask = squared_error < np.percentile(squared_error, q)
    r2_score_adjusted = r2_score(y_train[mask], pred_on_train[mask])
    print(f"R2 score adjusted ({q}th percentile): {r2_score_adjusted}")

# Plot distribution of squared error
plt.figure(figsize=(2.2, 2.2))
plt.hist(squared_error, bins=10, density=True)
plt.yscale("log")
plt.xlabel("Squared error")
plt.ylabel("Density")
plt.tight_layout()
plt.savefig("super_resolution_error_density.pdf", bbox_inches="tight")