Skip to content

Commit

Permalink
Update DualDn_model.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Lyricccco authored Nov 11, 2024
1 parent ecc0812 commit 2b62194
Showing 1 changed file with 2 additions and 14 deletions.
16 changes: 2 additions & 14 deletions models/DualDn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,8 +343,6 @@ def dnd_test(self, current_iter, data_path):

self.net_g.eval()
out_dict = OrderedDict()
out_dict['lq_sRGB'] = self.lq_sRGB.detach().cpu()
out_dict['ref_sRGB'] = self.ref_sRGB.detach().cpu()
if self.is_train:
save_path = osp.join(self.opt['path']['visualization'], str(current_iter))
else:
Expand All @@ -357,7 +355,7 @@ def dnd_test(self, current_iter, data_path):
loc_w_l = int(loc_w_l) - 1
top, left = 0, 0
self.out_Raw = torch.zeros(1,1,512,512).to(self.lq_Raw.device)
self.out_sRGB = torch.zeros(1,3,512,512).to(self.lq_sRGB.device)
self.out_sRGB = torch.zeros(1,3,512,512).to(self.lq_Raw.device)
exit_flag = False
for top in range(loc_h_l-8, loc_h_l):
for left in range(loc_w_l-8, loc_w_l):
Expand All @@ -372,7 +370,7 @@ def dnd_test(self, current_iter, data_path):
out_Rawpatch, out_sRGBpatch = self.net_g(self.lq_Raw[:,:,top:top+528,left:left+528], self.color_mask[:,:,top:top+528,left:left+528], self.wb_matrix, self.rgb_xyz_matrix, self.ref, self.gamma_type, self.demosaic_type, alpha = self.alpha, final_stage = self.final_stage)
self.out_Raw, self.out_sRGB = out_Rawpatch[:,:,loc_h_l-top:loc_h_l-top+512,loc_w_l-left:loc_w_l-left+512], out_sRGBpatch[:,:,loc_h_l-top:loc_h_l-top+512,loc_w_l-left:loc_w_l-left+512]

##* intermediate output
##* intermediate sRGB output
# self.out_sRGB = run_pipeline(self.out_Raw, {'color_mask':color_mask, 'wb_matrix':self.wb_matrix, 'color_desc':'RGBG', 'rgb_xyz_matrix':self.rgb_xyz_matrix, 'ref': self.ref, 'alpha':self.alpha}, False, 'normal', self.final_stage)

out_dict['out_Raw'] = self.out_Raw.detach().cpu()
Expand All @@ -391,19 +389,9 @@ def dnd_test(self, current_iter, data_path):
save_out_sRGB_path = osp.join(save_path,'visuals', '{}_{:0=2}_ours.png'.format(img_ind, i+1))
imwrite(out_sRGB, save_out_sRGB_path)

self.net_g.train()
lq_sRGB = tensor2img([out_dict['lq_sRGB']], rgb2bgr=True)
save_lq_sRGB_path = osp.join(save_path,'visuals', '{}_lq.png'.format(img_ind))
imwrite(lq_sRGB, save_lq_sRGB_path)
ref_sRGB = tensor2img([out_dict['ref_sRGB']], rgb2bgr=False)
save_ref_sRGB_path = osp.join(save_path,'visuals', '{}_ref.png'.format(img_ind))
imwrite(ref_sRGB, save_ref_sRGB_path)

del self.lq_Raw
del self.lq_sRGB
del self.out_Raw
del self.out_sRGB
del self.ref_sRGB
torch.cuda.empty_cache()

def dist_validation(self, dataloader, current_iter, tb_logger):
Expand Down

0 comments on commit 2b62194

Please sign in to comment.