-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy path test_ebb.py
99 lines (75 loc) · 3.45 KB
/
test_ebb.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import argparse
import math
import os
import subprocess

import lightning.pytorch as pl
import numpy as np
import torch
import torch.nn as nn
from pytorch_msssim import ssim
from torch.utils.data import DataLoader
from tqdm import tqdm

from net.model import bokeh
from utils.dataset_utils import BokehDataset_ebb as BokehDataset
from utils.image_io import save_image_tensor
from utils.val_utils import AverageMeter, compute_psnr_ssim
def ssim_loss(recoverd, clean):
    """Return the structural-dissimilarity loss ``1 - SSIM(recoverd, clean)``.

    SSIM is computed by ``pytorch_msssim.ssim`` with ``data_range=1`` (pixel
    values are assumed to span a range of 1, e.g. [0, 1]) and
    ``size_average=True`` (a single value averaged over the batch).

    Args:
        recoverd: predicted image tensor.
        clean: ground-truth image tensor; must have the same shape as ``recoverd``.
    """
    assert recoverd.shape == clean.shape
    similarity = ssim(recoverd, clean, data_range=1, size_average=True)
    return 1 - similarity
class BokehModel(pl.LightningModule):
    """LightningModule wrapping the ``bokeh`` network.

    In this script the class is only restored via ``load_from_checkpoint`` and
    used for inference. The training hooks define an L1 + (1 - SSIM) objective
    optimized with AdamW under a linear warm-up / cosine-decay LR schedule.
    """

    def __init__(self):
        super().__init__()
        self.net = bokeh()
        self.loss_fn = nn.L1Loss()

    def forward(self, x, x1, x2):
        """Run the wrapped network, passing the three inputs through positionally.

        Callers in this file pass (degraded image, depth map, mask).
        """
        return self.net(x, x1, x2)

    def training_step(self, batch, batch_idx):
        """Compute and log the L1 + (1 - SSIM) loss for one training batch.

        ``batch`` is unpacked as ``([names], degraded, clean, depth, mask)``;
        the names entry is not used here.
        """
        ([_clean_names], degrad_patch, clean_patch, depth_patch, mask_patch) = batch
        restored = self.net(degrad_patch, depth_patch, mask_patch)
        loss = self.loss_fn(restored, clean_patch) + ssim_loss(restored, clean_patch)
        # Logged through Lightning's default logger (TensorBoard if installed).
        self.log("train_loss", loss)
        return loss

    def lr_scheduler_step(self, scheduler, metric):
        """Set the scheduler to the current epoch; ``metric`` is ignored.

        Passing the epoch explicitly makes the LR a pure function of the epoch
        index rather than of how many times ``step`` has been called.
        """
        scheduler.step(self.current_epoch)

    def configure_optimizers(self):
        """Create AdamW (lr=2e-4) with a 5-epoch linear warm-up then cosine decay.

        The multiplier on the base LR rises linearly from 0 at epoch 0 to 1 at
        epoch ``warmup_epochs - 1``, then follows a half cosine that reaches 0
        at ``max_epochs``. This is the closed form of
        ``LinearWarmupCosineAnnealingLR`` with ``warmup_start_lr=0`` and
        ``eta_min=0``, built on torch's ``LambdaLR`` so that no extra scheduler
        dependency is needed.
        """
        warmup_epochs, max_epochs = 5, 100

        def warmup_cosine(epoch):
            # Multiplicative factor applied to the optimizer's base LR.
            if epoch < warmup_epochs:
                return epoch / max(1, warmup_epochs - 1)
            progress = (epoch - warmup_epochs) / (max_epochs - warmup_epochs)
            return 0.5 * (1.0 + math.cos(math.pi * progress))

        optimizer = torch.optim.AdamW(self.parameters(), lr=2e-4)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_cosine)
        return [optimizer], [scheduler]
def test_Bokeh(net, dataset, task="bokeh", output_root=None):
    """Run ``net`` over ``dataset``, save every prediction as PNG and report PSNR/SSIM.

    Args:
        net: model called as ``net(degraded, depth, mask)``. Inputs are moved
            to the current CUDA device with ``.cuda()``.
        dataset: test dataset. ``dataset.set_dataset(task)`` is called before
            iterating it.
        task: name of the output sub-directory; also passed to ``set_dataset``.
        output_root: base output directory. Defaults to the module-level
            ``testopt.output_path`` parsed in the ``__main__`` block.

    Returns:
        ``(psnr_avg, ssim_avg)`` -- the running averages also printed at the end.
    """
    if output_root is None:
        output_root = testopt.output_path  # global set by the CLI entry point
    # os.path.join works whether or not output_root ends with a separator.
    output_path = os.path.join(output_root, task)
    os.makedirs(output_path, exist_ok=True)

    dataset.set_dataset(task)
    testloader = DataLoader(dataset, batch_size=1, pin_memory=True, shuffle=False, num_workers=1)

    # Named *_meter so the module-level ``ssim`` function is not shadowed.
    psnr_meter = AverageMeter()
    ssim_meter = AverageMeter()

    with torch.no_grad():
        for ([degraded_name], degrad_patch, clean_patch, depth_patch, mask_patch) in tqdm(testloader):
            degrad_patch = degrad_patch.cuda()
            clean_patch = clean_patch.cuda()
            depth_patch = depth_patch.cuda()
            mask_patch = mask_patch.cuda()

            restored = net(degrad_patch, depth_patch, mask_patch)

            temp_psnr, temp_ssim, n = compute_psnr_ssim(restored, clean_patch)
            psnr_meter.update(temp_psnr, n)
            ssim_meter.update(temp_ssim, n)

            # batch_size=1, so the first name belongs to the only image in the batch.
            save_image_tensor(restored, os.path.join(output_path, degraded_name[0] + '.png'))

    print("PSNR: %.2f, SSIM: %.4f" % (psnr_meter.avg, ssim_meter.avg))
    return psnr_meter.avg, ssim_meter.avg
if __name__ == '__main__':
    # Command-line entry point: load a checkpoint and evaluate it on the EBB test set.
    parser = argparse.ArgumentParser()
    # Input Parameters
    parser.add_argument('--cuda', type=int, default=0, help='index of the CUDA device to use')
    parser.add_argument('--bokeh_path', type=str, default='./data/EBB/test/', help='directory of the EBB test images')
    parser.add_argument('--output_path', type=str, default="output/", help='output save path')
    parser.add_argument('--ckpt_name', type=str, default="", help='path of the checkpoint to load')
    # Kept as a module-level global: test_Bokeh reads testopt.output_path.
    testopt = parser.parse_args()
    if not testopt.ckpt_name:
        # Fail with a clear usage message instead of an opaque load error.
        parser.error("--ckpt_name is required (path of the checkpoint to load)")

    np.random.seed(0)
    torch.manual_seed(0)
    torch.cuda.set_device(testopt.cuda)

    ckpt_path = testopt.ckpt_name
    print("CKPT name : {}".format(ckpt_path))
    net = BokehModel.load_from_checkpoint(ckpt_path).cuda()
    net.eval()

    print('Start testing bokeh rendering...')
    bokeh_set = BokehDataset(testopt, train=False)
    test_Bokeh(net, bokeh_set, task="bokeh")