Skip to content

Commit

Permalink
Update DualDn_arch.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Lyricccco authored Nov 9, 2024
1 parent f7b6629 commit b9702bb
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions archs/DualDn_arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,15 @@ def __init__(self,
out_c = 4,
c = 64,
backbone_type = 'Restormer',
bgu_ratio = 8,
bias = False,
LayerNorm_type = 'BiasFree'
):

super(Raw_Dn, self).__init__()

self.with_noise_map= with_noise_map
self.bgu_ratio = bgu_ratio
if self.with_noise_map:
self.num = 2
else:
Expand Down Expand Up @@ -83,7 +85,7 @@ def forward(self, in_raw, colormask, wb_matrix, rgb_xyz_matrix, ref='D65', gamma

if ref_sRGB != None and not (ref_sRGB == 0).all(): # Only for inference, to keep consistent with ref_sRGB's color
in_srgb = torch.clamp(in_srgb, min=1e-6, max=1) # For BGU estimation, there may be overflow, set bound.
bgu_gamma = bguFit(in_srgb, ref_sRGB)
bgu_gamma = bguFit(in_srgb, ref_sRGB, self.bgu_ratio)
bgu_srgb = bguSlice(bgu_gamma, in_srgb)
in_srgb = torch.from_numpy(bgu_srgb).float().cuda().permute(2,0,1).unsqueeze()

Expand All @@ -100,6 +102,7 @@ def __init__(self,
out_c = 3,
c = 64,
backbone_type = 'Restormer',
bgu_ratio = 8,
bias = False,
LayerNorm_type = 'BiasFree'
):
Expand All @@ -109,6 +112,7 @@ def __init__(self,
self.down = nn.PixelUnshuffle(downscale_factor=2)
self.conv_out = nn.Conv2d(c, out_c, kernel_size=(1, 1), stride=(1, 1), bias=bias)
self.conv_in = nn.Conv2d(in_c, c, kernel_size=(1, 1), stride=(1, 1), bias=bias)
self.bgu_ratio = bgu_ratio
if backbone_type == 'Restormer':
self.backbone = Restormer(dim=c, num_blocks = [4,6,6,8], num_refinement_blocks = 4, heads = [1,2,4,8], ffn_expansion_factor = 2.66, bias = bias, LayerNorm_type = LayerNorm_type, dual_pixel_task = False)
elif backbone_type == 'SwinIR':
Expand All @@ -125,7 +129,7 @@ def forward(self, in_raw, colormask, wb_matrix, rgb_xyz_matrix, ref='D65', gamma

if ref_sRGB != None and not (ref_sRGB == 0).all(): # Only for inference, to keep consistent with ref_sRGB's color
in_srgb = torch.clamp(in_srgb, min=1e-6, max=1) # For BGU estimation, there may be overflow, set bound.
bgu_gamma = bguFit(in_srgb, ref_sRGB)
bgu_gamma = bguFit(in_srgb, ref_sRGB, self.bgu_ratio)
bgu_srgb = bguSlice(bgu_gamma, in_srgb)
in_srgb = torch.from_numpy(bgu_srgb).float().cuda().permute(2,0,1).unsqueeze()

Expand All @@ -144,13 +148,15 @@ def __init__(self,
out_c = 3,
c = 64,
backbone_type = 'Restormer',
bgu_ratio = 8,
bias = False,
LayerNorm_type = 'BiasFree'
):

super(DualDn, self).__init__()

self.with_noise_map= with_noise_map
self.bgu_ratio = bgu_ratio
if self.with_noise_map:
self.num = 2
else:
Expand Down Expand Up @@ -200,7 +206,7 @@ def forward(self, raw, colormask, wb_matrix, rgb_xyz_matrix, ref='D65', gamma_ty

if ref_sRGB != None and not (ref_sRGB == 0).all(): # Only for inference, to keep consistent with ref_sRGB's color
skip_srgb = torch.clamp(skip_srgb, min=1e-6, max=1) # For BGU estimation, there may be overflow, set bound.
bgu_gamma = bguFit(skip_srgb, ref_sRGB)
bgu_gamma = bguFit(skip_srgb, ref_sRGB, self.bgu_ratio)
bgu_srgb = bguSlice(bgu_gamma, skip_srgb)
rgb_noise_map = bguSlice(bgu_gamma, rgb_noise_map)
skip_srgb = torch.from_numpy(bgu_srgb).float().cuda().permute(2,0,1).unsqueeze(0)
Expand All @@ -213,7 +219,7 @@ def forward(self, raw, colormask, wb_matrix, rgb_xyz_matrix, ref='D65', gamma_ty

if ref_sRGB != None and not (ref_sRGB == 0).all(): # Only for inference, to keep consistent with ref_sRGB's color
skip_srgb = torch.clamp(skip_srgb, min=1e-6, max=1) # For BGU estimation, there may be overflow, set bound.
bgu_gamma = bguFit(skip_srgb, ref_sRGB)
bgu_gamma = bguFit(skip_srgb, ref_sRGB, self.bgu_ratio)
bgu_srgb = bguSlice(bgu_gamma, skip_srgb)
skip_srgb = torch.from_numpy(bgu_srgb).float().cuda().permute(2,0,1).unsqueeze()

Expand Down

0 comments on commit b9702bb

Please sign in to comment.