From 2b62194a37be6162ae62e9e72547a6af403b1128 Mon Sep 17 00:00:00 2001 From: Ruikang Li <54311152+Lyricccco@users.noreply.github.com> Date: Mon, 11 Nov 2024 14:22:04 +0800 Subject: [PATCH] Update DualDn_model.py --- models/DualDn_model.py | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/models/DualDn_model.py b/models/DualDn_model.py index d23dce6..4b77eba 100644 --- a/models/DualDn_model.py +++ b/models/DualDn_model.py @@ -343,8 +343,6 @@ def dnd_test(self, current_iter, data_path): self.net_g.eval() out_dict = OrderedDict() - out_dict['lq_sRGB'] = self.lq_sRGB.detach().cpu() - out_dict['ref_sRGB'] = self.ref_sRGB.detach().cpu() if self.is_train: save_path = osp.join(self.opt['path']['visualization'], str(current_iter)) else: @@ -357,7 +355,7 @@ def dnd_test(self, current_iter, data_path): loc_w_l = int(loc_w_l) - 1 top, left = 0, 0 self.out_Raw = torch.zeros(1,1,512,512).to(self.lq_Raw.device) - self.out_sRGB = torch.zeros(1,3,512,512).to(self.lq_sRGB.device) + self.out_sRGB = torch.zeros(1,3,512,512).to(self.lq_Raw.device) exit_flag = False for top in range(loc_h_l-8, loc_h_l): for left in range(loc_w_l-8, loc_w_l): @@ -372,7 +370,7 @@ def dnd_test(self, current_iter, data_path): out_Rawpatch, out_sRGBpatch = self.net_g(self.lq_Raw[:,:,top:top+528,left:left+528], self.color_mask[:,:,top:top+528,left:left+528], self.wb_matrix, self.rgb_xyz_matrix, self.ref, self.gamma_type, self.demosaic_type, alpha = self.alpha, final_stage = self.final_stage) self.out_Raw, self.out_sRGB = out_Rawpatch[:,:,loc_h_l-top:loc_h_l-top+512,loc_w_l-left:loc_w_l-left+512], out_sRGBpatch[:,:,loc_h_l-top:loc_h_l-top+512,loc_w_l-left:loc_w_l-left+512] - ##* intermediate output + ##* intermediate sRGB output # self.out_sRGB = run_pipeline(self.out_Raw, {'color_mask':color_mask, 'wb_matrix':self.wb_matrix, 'color_desc':'RGBG', 'rgb_xyz_matrix':self.rgb_xyz_matrix, 'ref': self.ref, 'alpha':self.alpha}, False, 'normal', self.final_stage) out_dict['out_Raw'] = self.out_Raw.detach().cpu() @@ -391,19 +389,9 @@ def dnd_test(self, current_iter, data_path): save_out_sRGB_path = osp.join(save_path,'visuals', '{}_{:0=2}_ours.png'.format(img_ind, i+1)) imwrite(out_sRGB, save_out_sRGB_path) - self.net_g.train() - lq_sRGB = tensor2img([out_dict['lq_sRGB']], rgb2bgr=True) - save_lq_sRGB_path = osp.join(save_path,'visuals', '{}_lq.png'.format(img_ind)) - imwrite(lq_sRGB, save_lq_sRGB_path) - ref_sRGB = tensor2img([out_dict['ref_sRGB']], rgb2bgr=False) - save_ref_sRGB_path = osp.join(save_path,'visuals', '{}_ref.png'.format(img_ind)) - imwrite(ref_sRGB, save_ref_sRGB_path) - del self.lq_Raw - del self.lq_sRGB del self.out_Raw del self.out_sRGB - del self.ref_sRGB torch.cuda.empty_cache() def dist_validation(self, dataloader, current_iter, tb_logger):