diff --git a/archs/DualDn_arch.py b/archs/DualDn_arch.py index 98cc975..f52b989 100644 --- a/archs/DualDn_arch.py +++ b/archs/DualDn_arch.py @@ -37,6 +37,7 @@ def __init__(self, out_c = 4, c = 64, backbone_type = 'Restormer', + bgu_ratio = 8, bias = False, LayerNorm_type = 'BiasFree' ): @@ -44,6 +45,7 @@ def __init__(self, super(Raw_Dn, self).__init__() self.with_noise_map= with_noise_map + self.bgu_ratio = bgu_ratio if self.with_noise_map: self.num = 2 else: @@ -83,7 +85,7 @@ def forward(self, in_raw, colormask, wb_matrix, rgb_xyz_matrix, ref='D65', gamma if ref_sRGB != None and not (ref_sRGB == 0).all(): # Only for inference, to keep consistent with ref_sRGB's color in_srgb = torch.clamp(in_srgb, min=1e-6, max=1) # For BGU estimation, there may be overflow, set bound. - bgu_gamma = bguFit(in_srgb, ref_sRGB) + bgu_gamma = bguFit(in_srgb, ref_sRGB, self.bgu_ratio) bgu_srgb = bguSlice(bgu_gamma, in_srgb) in_srgb = torch.from_numpy(bgu_srgb).float().cuda().permute(2,0,1).unsqueeze() @@ -100,6 +102,7 @@ def __init__(self, out_c = 3, c = 64, backbone_type = 'Restormer', + bgu_ratio = 8, bias = False, LayerNorm_type = 'BiasFree' ): @@ -109,6 +112,7 @@ def __init__(self, self.down = nn.PixelUnshuffle(downscale_factor=2) self.conv_out = nn.Conv2d(c, out_c, kernel_size=(1, 1), stride=(1, 1), bias=bias) self.conv_in = nn.Conv2d(in_c, c, kernel_size=(1, 1), stride=(1, 1), bias=bias) + self.bgu_ratio = bgu_ratio if backbone_type == 'Restormer': self.backbone = Restormer(dim=c, num_blocks = [4,6,6,8], num_refinement_blocks = 4, heads = [1,2,4,8], ffn_expansion_factor = 2.66, bias = bias, LayerNorm_type = LayerNorm_type, dual_pixel_task = False) elif backbone_type == 'SwinIR': @@ -125,7 +129,7 @@ def forward(self, in_raw, colormask, wb_matrix, rgb_xyz_matrix, ref='D65', gamma if ref_sRGB != None and not (ref_sRGB == 0).all(): # Only for inference, to keep consistent with ref_sRGB's color in_srgb = torch.clamp(in_srgb, min=1e-6, max=1) # For BGU estimation, there may be overflow, set bound. - bgu_gamma = bguFit(in_srgb, ref_sRGB) + bgu_gamma = bguFit(in_srgb, ref_sRGB, self.bgu_ratio) bgu_srgb = bguSlice(bgu_gamma, in_srgb) in_srgb = torch.from_numpy(bgu_srgb).float().cuda().permute(2,0,1).unsqueeze() @@ -144,6 +148,7 @@ def __init__(self, out_c = 3, c = 64, backbone_type = 'Restormer', + bgu_ratio = 8, bias = False, LayerNorm_type = 'BiasFree' ): @@ -151,6 +156,7 @@ def __init__(self, super(DualDn, self).__init__() self.with_noise_map= with_noise_map + self.bgu_ratio = bgu_ratio if self.with_noise_map: self.num = 2 else: @@ -200,7 +206,7 @@ def forward(self, raw, colormask, wb_matrix, rgb_xyz_matrix, ref='D65', gamma_ty if ref_sRGB != None and not (ref_sRGB == 0).all(): # Only for inference, to keep consistent with ref_sRGB's color skip_srgb = torch.clamp(skip_srgb, min=1e-6, max=1) # For BGU estimation, there may be overflow, set bound. - bgu_gamma = bguFit(skip_srgb, ref_sRGB) + bgu_gamma = bguFit(skip_srgb, ref_sRGB, self.bgu_ratio) bgu_srgb = bguSlice(bgu_gamma, skip_srgb) rgb_noise_map = bguSlice(bgu_gamma, rgb_noise_map) skip_srgb = torch.from_numpy(bgu_srgb).float().cuda().permute(2,0,1).unsqueeze(0) @@ -213,7 +219,7 @@ def forward(self, raw, colormask, wb_matrix, rgb_xyz_matrix, ref='D65', gamma_ty if ref_sRGB != None and not (ref_sRGB == 0).all(): # Only for inference, to keep consistent with ref_sRGB's color skip_srgb = torch.clamp(skip_srgb, min=1e-6, max=1) # For BGU estimation, there may be overflow, set bound. - bgu_gamma = bguFit(skip_srgb, ref_sRGB) + bgu_gamma = bguFit(skip_srgb, ref_sRGB, self.bgu_ratio) bgu_srgb = bguSlice(bgu_gamma, skip_srgb) skip_srgb = torch.from_numpy(bgu_srgb).float().cuda().permute(2,0,1).unsqueeze()