diff --git a/Comparisons.md b/Comparisons.md new file mode 100644 index 00000000..1542d4c0 --- /dev/null +++ b/Comparisons.md @@ -0,0 +1,24 @@ +# Comparisons + +## Comparisons among different model versions + +Note that V1.3 is not always better than V1.2. You may need to try different models based on your purpose and inputs. + +| Version | Strengths | Weaknesses | +| :---: | :---: | :---: | +|V1.3 | ✓ natural outputs
✓ better results on very low-quality inputs
✓ works on relatively high-quality inputs
✓ supports repeated (twice) restoration | ✗ not very sharp
✗ may slightly change identity |
✓ with beauty makeup | ✗ some outputs are unnatural| + +For the following images, you may need to **zoom in** for comparing details, or **click the image** to see in the full size. + +| Input | V1 | V1.2 | V1.3 +| :---: | :---: | :---: | :---: | +|![019_Anne_Hathaway_01_00](https://user-images.githubusercontent.com/17445847/153762146-96b25999-4ddd-42a5-a3fe-bb90565f4c4f.png)| ![](https://user-images.githubusercontent.com/17445847/153762256-ef41e749-5a27-495c-8a9c-d8403be55869.png) | ![](https://user-images.githubusercontent.com/17445847/153762297-d41582fc-6253-4e7e-a1ce-4dc237ae3bf3.png) | ![](https://user-images.githubusercontent.com/17445847/153762215-e0535e94-b5ba-426e-97b5-35c00873604d.png) | +| ![106_Harry_Styles_00_00](https://user-images.githubusercontent.com/17445847/153789040-632c0eda-c15a-43e9-a63c-9ead64f92d4a.png) | ![](https://user-images.githubusercontent.com/17445847/153789172-93cd4980-5318-4633-a07e-1c8f8064ff89.png) | ![](https://user-images.githubusercontent.com/17445847/153789185-f7b268a7-d1db-47b0-ae4a-335e5d657a18.png) | ![](https://user-images.githubusercontent.com/17445847/153789198-7c7f3bca-0ef0-4494-92f0-20aa6f7d7464.png)| +| ![076_Paris_Hilton_00_00](https://user-images.githubusercontent.com/17445847/153789607-86387770-9db8-441f-b08a-c9679b121b85.png) | ![](https://user-images.githubusercontent.com/17445847/153789619-e56b438a-78a0-425d-8f44-ec4692a43dda.png) | ![](https://user-images.githubusercontent.com/17445847/153789633-5b28f778-3b7f-4e08-8a1d-740ca6e82d8a.png) | ![](https://user-images.githubusercontent.com/17445847/153789645-bc623f21-b32d-4fc3-bfe9-61203407a180.png)| +| ![008_George_Clooney_00_00](https://user-images.githubusercontent.com/17445847/153790017-0c3ca94d-1c9d-4a0e-b539-ab12d4da98ff.png) | ![](https://user-images.githubusercontent.com/17445847/153790028-fb0d38ab-399d-4a30-8154-2dcd72ca90e8.png) | ![](https://user-images.githubusercontent.com/17445847/153790044-1ef68e34-6120-4439-a5d9-0b6cdbe9c3d0.png) | ![](https://user-images.githubusercontent.com/17445847/153790059-a8d3cece-8989-4e9a-9ffe-903e1690cfd6.png)| +| ![057_Madonna_01_00](https://user-images.githubusercontent.com/17445847/153790624-2d0751d0-8fb4-4806-be9d-71b833c2c226.png) | ![](https://user-images.githubusercontent.com/17445847/153790639-7eb870e5-26b2-41dc-b139-b698bb40e6e6.png) | ![](https://user-images.githubusercontent.com/17445847/153790651-86899b7a-a1b6-4242-9e8a-77b462004998.png) | ![](https://user-images.githubusercontent.com/17445847/153790655-c8f6c25b-9b4e-4633-b16f-c43da86cff8f.png)| +| ![044_Amy_Schumer_01_00](https://user-images.githubusercontent.com/17445847/153790811-3fb4fc46-5b4f-45fe-8fcb-a128de2bfa60.png) | ![](https://user-images.githubusercontent.com/17445847/153790817-d45aa4ff-bfc4-4163-b462-75eef9426fab.png) | ![](https://user-images.githubusercontent.com/17445847/153790824-5f93c3a0-fe5a-42f6-8b4b-5a5de8cd0ac3.png) | ![](https://user-images.githubusercontent.com/17445847/153790835-0edf9944-05c7-41c4-8581-4dc5ffc56c9d.png)| +| ![012_Jackie_Chan_01_00](https://user-images.githubusercontent.com/17445847/153791176-737b016a-e94f-4898-8db7-43e7762141c9.png) | ![](https://user-images.githubusercontent.com/17445847/153791183-2f25a723-56bf-4cd5-aafe-a35513a6d1c5.png) | ![](https://user-images.githubusercontent.com/17445847/153791194-93416cf9-2b58-4e70-b806-27e14c58d4fd.png) | ![](https://user-images.githubusercontent.com/17445847/153791202-aa98659c-b702-4bce-9c47-a2fa5eccc5ae.png)| + + diff --git a/FAQ.md b/FAQ.md new file mode 100644 index 00000000..e4d5a49c --- /dev/null 
+++ b/FAQ.md @@ -0,0 +1,7 @@ +# FAQ + +1. **How to finetune the GFPGANCleanv1-NoCE-C2 (v1.2) model** + +**A:** 1) The GFPGANCleanv1-NoCE-C2 (v1.2) model uses the *clean* architecture, which is more friendly for deploying. +2) This model is not directly trained. Instead, it is converted from another *bilinear* model. +3) If you want to finetune the GFPGANCleanv1-NoCE-C2 (v1.2), you need to finetune its original *bilinear* model, and then do the conversion. diff --git a/PaperModel.md b/PaperModel.md index aec81d31..e9c8bdc4 100644 --- a/PaperModel.md +++ b/PaperModel.md @@ -60,17 +60,17 @@ wget https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/GFPGANv1.pth - Option 1: Load extensions just-in-time(JIT) ```bash - BASICSR_JIT=True python inference_gfpgan.py --model_path experiments/pretrained_models/GFPGANv1.pth --test_path inputs/whole_imgs --save_root results --arch original --channel 1 + BASICSR_JIT=True python inference_gfpgan.py --input inputs/whole_imgs --output results --version 1 # for aligned images - BASICSR_JIT=True python inference_gfpgan.py --model_path experiments/pretrained_models/GFPGANv1.pth --test_path inputs/cropped_faces --save_root results --arch original --channel 1 --aligned + BASICSR_JIT=True python inference_gfpgan.py --input inputs/whole_imgs --output results --version 1 --aligned ``` - Option 2: Have successfully compiled extensions during installation ```bash - python inference_gfpgan.py --model_path experiments/pretrained_models/GFPGANv1.pth --test_path inputs/whole_imgs --save_root results --arch original --channel 1 + python inference_gfpgan.py --input inputs/whole_imgs --output results --version 1 # for aligned images - python inference_gfpgan.py --model_path experiments/pretrained_models/GFPGANv1.pth --test_path inputs/cropped_faces --save_root results --arch original --channel 1 --aligned + python inference_gfpgan.py --input inputs/whole_imgs --output results --version 1 --aligned ``` diff --git a/README.md b/README.md index 09d1ccce..1aa5c58c 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,11 @@ GFPGAN aims at developing a **Practical Algorithm for Real-world Face Restoration**.
It leverages rich and diverse priors encapsulated in a pretrained face GAN (*e.g.*, StyleGAN2) for blind face restoration. +:question: Frequently Asked Questions can be found in [FAQ.md](FAQ.md). + :triangular_flag_on_post: **Updates** + +- :fire::fire::white_check_mark: Add **[V1.3 model](https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth)**, which produces **more natural** restoration results, and better results on *very low-quality* / *high-quality* inputs. See more in [Model zoo](#european_castle-model-zoo), [Comparisons.md](Comparisons.md) - :white_check_mark: Integrated to [Huggingface Spaces](https://huggingface.co/spaces) with [Gradio](https://github.com/gradio-app/gradio). See [Gradio Web Demo](https://huggingface.co/spaces/akhaliq/GFPGAN). - :white_check_mark: Support enhancing non-face regions (background) with [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN). - :white_check_mark: We provide a *clean* version of GFPGAN, which does not require CUDA extensions. @@ -88,24 +92,54 @@ If you want to use the original model in our paper, please see [PaperModel.md](P ## :zap: Quick Inference -Download pre-trained models: [GFPGANCleanv1-NoCE-C2.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.2.0/GFPGANCleanv1-NoCE-C2.pth) +We take the v1.3 version for an example. More models can be found [here](#european_castle-model-zoo). + +Download pre-trained models: [GFPGANv1.3.pth](https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth) ```bash -wget https://github.com/TencentARC/GFPGAN/releases/download/v0.2.0/GFPGANCleanv1-NoCE-C2.pth -P experiments/pretrained_models +wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth -P experiments/pretrained_models ``` **Inference!** ```bash -python inference_gfpgan.py --upscale 2 --test_path inputs/whole_imgs --save_root results +python inference_gfpgan.py -i inputs/whole_imgs -o results -v 1.3 -s 2 +``` + +```console +Usage: python inference_gfpgan.py -i inputs/whole_imgs -o results -v 1.3 -s 2 [options]... + + -h show this help + -i input Input image or folder. Default: inputs/whole_imgs + -o output Output folder. Default: results + -v version GFPGAN model version. Option: 1 | 1.2 | 1.3. Default: 1.3 + -s upscale The final upsampling scale of the image. Default: 2 + -bg_upsampler background upsampler. Default: realesrgan + -bg_tile Tile size for background sampler, 0 for no tile during testing. Default: 400 + -suffix Suffix of the restored faces + -only_center_face Only restore the center face + -aligned Input are aligned faces + -ext Image extension. Options: auto | jpg | png, auto means using the same extension as inputs. Default: auto ``` If you want to use the original model in our paper, please see [PaperModel.md](PaperModel.md) for installation and inference. ## :european_castle: Model Zoo -- [GFPGANCleanv1-NoCE-C2.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.2.0/GFPGANCleanv1-NoCE-C2.pth): No colorization; no CUDA extensions are required. It is still in training. Trained with more data with pre-processing. -- [GFPGANv1.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/GFPGANv1.pth): The paper model, with colorization. +| Version | Model Name | Description | +| :---: | :---: | :---: | +| V1.3 | [GFPGANv1.3.pth](https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth) | Based on V1.2; **more natural** restoration results; better results on very low-quality / high-quality inputs. 
| +| V1.2 | [GFPGANCleanv1-NoCE-C2.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.2.0/GFPGANCleanv1-NoCE-C2.pth) | No colorization; no CUDA extensions are required. Trained with more data with pre-processing. | +| V1 | [GFPGANv1.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/GFPGANv1.pth) | The paper model, with colorization. | + +The comparisons are in [Comparisons.md](Comparisons.md). + +Note that V1.3 is not always better than V1.2. You may need to select different models based on your purpose and inputs. + +| Version | Strengths | Weaknesses | +| :---: | :---: | :---: | +|V1.3 | ✓ natural outputs
✓ better results on very low-quality inputs
✓ works on relatively high-quality inputs
✓ supports repeated (twice) restoration | ✗ not very sharp
✗ may slightly change identity |
✓ with beauty makeup | ✗ some outputs are unnatural | You can find **more models (such as the discriminators)** here: [[Google Drive](https://drive.google.com/drive/folders/17rLiFzcUMoQuhLnptDsKolegHWwJOnHu?usp=sharing)], OR [[Tencent Cloud 腾讯微云](https://share.weiyun.com/ShYoCCoc)] diff --git a/gfpgan/archs/gfpgan_bilinear_arch.py b/gfpgan/archs/gfpgan_bilinear_arch.py new file mode 100644 index 00000000..d0537b14 --- /dev/null +++ b/gfpgan/archs/gfpgan_bilinear_arch.py @@ -0,0 +1,312 @@ +import math +import random +import torch +from basicsr.archs.stylegan2_bilinear_arch import (ConvLayer, EqualConv2d, EqualLinear, ResBlock, ScaledLeakyReLU, + StyleGAN2GeneratorBilinear) +from basicsr.utils.registry import ARCH_REGISTRY +from torch import nn + +from .gfpganv1_arch import ResUpBlock + + +class StyleGAN2GeneratorBilinearSFT(StyleGAN2GeneratorBilinear): + """StyleGAN2 Generator with SFT modulation (Spatial Feature Transform). + + It is the bilinear version. It does not use the complicated UpFirDnSmooth function that is not friendly for + deployment. It can be easily converted to the clean version: StyleGAN2GeneratorCSFT. + + Args: + out_size (int): The spatial size of outputs. + num_style_feat (int): Channel number of style features. Default: 512. + num_mlp (int): Layer number of MLP style layers. Default: 8. + channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2. + lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01. + narrow (float): The narrow ratio for channels. Default: 1. + sft_half (bool): Whether to apply SFT on half of the input channels. Default: False. + """ + + def __init__(self, + out_size, + num_style_feat=512, + num_mlp=8, + channel_multiplier=2, + lr_mlp=0.01, + narrow=1, + sft_half=False): + super(StyleGAN2GeneratorBilinearSFT, self).__init__( + out_size, + num_style_feat=num_style_feat, + num_mlp=num_mlp, + channel_multiplier=channel_multiplier, + lr_mlp=lr_mlp, + narrow=narrow) + self.sft_half = sft_half + + def forward(self, + styles, + conditions, + input_is_latent=False, + noise=None, + randomize_noise=True, + truncation=1, + truncation_latent=None, + inject_index=None, + return_latents=False): + """Forward function for StyleGAN2GeneratorBilinearSFT. + + Args: + styles (list[Tensor]): Sample codes of styles. + conditions (list[Tensor]): SFT conditions to generators. + input_is_latent (bool): Whether input is latent style. Default: False. + noise (Tensor | None): Input noise or None. Default: None. + randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True. + truncation (float): The truncation ratio. Default: 1. + truncation_latent (Tensor | None): The truncation latent tensor. Default: None. + inject_index (int | None): The injection index for mixing noise. Default: None. + return_latents (bool): Whether to return style latents. Default: False. 
+ """ + # style codes -> latents with Style MLP layer + if not input_is_latent: + styles = [self.style_mlp(s) for s in styles] + # noises + if noise is None: + if randomize_noise: + noise = [None] * self.num_layers # for each style conv layer + else: # use the stored noise + noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)] + # style truncation + if truncation < 1: + style_truncation = [] + for style in styles: + style_truncation.append(truncation_latent + truncation * (style - truncation_latent)) + styles = style_truncation + # get style latents with injection + if len(styles) == 1: + inject_index = self.num_latent + + if styles[0].ndim < 3: + # repeat latent code for all the layers + latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + else: # used for encoder with different latent code for each layer + latent = styles[0] + elif len(styles) == 2: # mixing noises + if inject_index is None: + inject_index = random.randint(1, self.num_latent - 1) + latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1) + latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1) + latent = torch.cat([latent1, latent2], 1) + + # main generation + out = self.constant_input(latent.shape[0]) + out = self.style_conv1(out, latent[:, 0], noise=noise[0]) + skip = self.to_rgb1(out, latent[:, 1]) + + i = 1 + for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2], + noise[2::2], self.to_rgbs): + out = conv1(out, latent[:, i], noise=noise1) + + # the conditions may have fewer levels + if i < len(conditions): + # SFT part to combine the conditions + if self.sft_half: # only apply SFT to half of the channels + out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1) + out_sft = out_sft * conditions[i - 1] + conditions[i] + out = torch.cat([out_same, out_sft], dim=1) + else: # apply SFT to all the channels + out = out * conditions[i - 1] + conditions[i] + + out = conv2(out, latent[:, i + 1], noise=noise2) + skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space + i += 2 + + image = skip + + if return_latents: + return image, latent + else: + return image, None + + +@ARCH_REGISTRY.register() +class GFPGANBilinear(nn.Module): + """The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT. + + It is the bilinear version and it does not use the complicated UpFirDnSmooth function that is not friendly for + deployment. It can be easily converted to the clean version: GFPGANv1Clean. + + + Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior. + + Args: + out_size (int): The spatial size of outputs. + num_style_feat (int): Channel number of style features. Default: 512. + channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2. + decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None. + fix_decoder (bool): Whether to fix the decoder. Default: True. + + num_mlp (int): Layer number of MLP style layers. Default: 8. + lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01. + input_is_latent (bool): Whether input is latent style. Default: False. + different_w (bool): Whether to use different latent w for different layers. Default: False. + narrow (float): The narrow ratio for channels. Default: 1. + sft_half (bool): Whether to apply SFT on half of the input channels. Default: False. 
+ """ + + def __init__( + self, + out_size, + num_style_feat=512, + channel_multiplier=1, + decoder_load_path=None, + fix_decoder=True, + # for stylegan decoder + num_mlp=8, + lr_mlp=0.01, + input_is_latent=False, + different_w=False, + narrow=1, + sft_half=False): + + super(GFPGANBilinear, self).__init__() + self.input_is_latent = input_is_latent + self.different_w = different_w + self.num_style_feat = num_style_feat + + unet_narrow = narrow * 0.5 # by default, use a half of input channels + channels = { + '4': int(512 * unet_narrow), + '8': int(512 * unet_narrow), + '16': int(512 * unet_narrow), + '32': int(512 * unet_narrow), + '64': int(256 * channel_multiplier * unet_narrow), + '128': int(128 * channel_multiplier * unet_narrow), + '256': int(64 * channel_multiplier * unet_narrow), + '512': int(32 * channel_multiplier * unet_narrow), + '1024': int(16 * channel_multiplier * unet_narrow) + } + + self.log_size = int(math.log(out_size, 2)) + first_out_size = 2**(int(math.log(out_size, 2))) + + self.conv_body_first = ConvLayer(3, channels[f'{first_out_size}'], 1, bias=True, activate=True) + + # downsample + in_channels = channels[f'{first_out_size}'] + self.conv_body_down = nn.ModuleList() + for i in range(self.log_size, 2, -1): + out_channels = channels[f'{2**(i - 1)}'] + self.conv_body_down.append(ResBlock(in_channels, out_channels)) + in_channels = out_channels + + self.final_conv = ConvLayer(in_channels, channels['4'], 3, bias=True, activate=True) + + # upsample + in_channels = channels['4'] + self.conv_body_up = nn.ModuleList() + for i in range(3, self.log_size + 1): + out_channels = channels[f'{2**i}'] + self.conv_body_up.append(ResUpBlock(in_channels, out_channels)) + in_channels = out_channels + + # to RGB + self.toRGB = nn.ModuleList() + for i in range(3, self.log_size + 1): + self.toRGB.append(EqualConv2d(channels[f'{2**i}'], 3, 1, stride=1, padding=0, bias=True, bias_init_val=0)) + + if different_w: + linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat + else: + linear_out_channel = num_style_feat + + self.final_linear = EqualLinear( + channels['4'] * 4 * 4, linear_out_channel, bias=True, bias_init_val=0, lr_mul=1, activation=None) + + # the decoder: stylegan2 generator with SFT modulations + self.stylegan_decoder = StyleGAN2GeneratorBilinearSFT( + out_size=out_size, + num_style_feat=num_style_feat, + num_mlp=num_mlp, + channel_multiplier=channel_multiplier, + lr_mlp=lr_mlp, + narrow=narrow, + sft_half=sft_half) + + # load pre-trained stylegan2 model if necessary + if decoder_load_path: + self.stylegan_decoder.load_state_dict( + torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema']) + # fix decoder without updating params + if fix_decoder: + for _, param in self.stylegan_decoder.named_parameters(): + param.requires_grad = False + + # for SFT modulations (scale and shift) + self.condition_scale = nn.ModuleList() + self.condition_shift = nn.ModuleList() + for i in range(3, self.log_size + 1): + out_channels = channels[f'{2**i}'] + if sft_half: + sft_out_channels = out_channels + else: + sft_out_channels = out_channels * 2 + self.condition_scale.append( + nn.Sequential( + EqualConv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0), + ScaledLeakyReLU(0.2), + EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=1))) + self.condition_shift.append( + nn.Sequential( + EqualConv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0), + 
ScaledLeakyReLU(0.2), + EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0))) + + def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True): + """Forward function for GFPGANBilinear. + + Args: + x (Tensor): Input images. + return_latents (bool): Whether to return style latents. Default: False. + return_rgb (bool): Whether return intermediate rgb images. Default: True. + randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True. + """ + conditions = [] + unet_skips = [] + out_rgbs = [] + + # encoder + feat = self.conv_body_first(x) + for i in range(self.log_size - 2): + feat = self.conv_body_down[i](feat) + unet_skips.insert(0, feat) + + feat = self.final_conv(feat) + + # style code + style_code = self.final_linear(feat.view(feat.size(0), -1)) + if self.different_w: + style_code = style_code.view(style_code.size(0), -1, self.num_style_feat) + + # decode + for i in range(self.log_size - 2): + # add unet skip + feat = feat + unet_skips[i] + # ResUpLayer + feat = self.conv_body_up[i](feat) + # generate scale and shift for SFT layers + scale = self.condition_scale[i](feat) + conditions.append(scale.clone()) + shift = self.condition_shift[i](feat) + conditions.append(shift.clone()) + # generate rgb images + if return_rgb: + out_rgbs.append(self.toRGB[i](feat)) + + # decoder + image, _ = self.stylegan_decoder([style_code], + conditions, + return_latents=return_latents, + input_is_latent=self.input_is_latent, + randomize_noise=randomize_noise) + + return image, out_rgbs diff --git a/gfpgan/models/gfpgan_model.py b/gfpgan/models/gfpgan_model.py index c3d51b0b..684fc601 100644 --- a/gfpgan/models/gfpgan_model.py +++ b/gfpgan/models/gfpgan_model.py @@ -209,18 +209,18 @@ def feed_data(self, data): self.loc_right_eyes = data['loc_right_eye'] self.loc_mouths = data['loc_mouth'] - # uncomment to check data - # import torchvision - # if self.opt['rank'] == 0: - # import os - # os.makedirs('tmp/gt', exist_ok=True) - # os.makedirs('tmp/lq', exist_ok=True) - # print(self.idx) - # torchvision.utils.save_image( - # self.gt, f'tmp/gt/gt_{self.idx}.png', nrow=4, padding=2, normalize=True, range=(-1, 1)) - # torchvision.utils.save_image( - # self.lq, f'tmp/lq/lq{self.idx}.png', nrow=4, padding=2, normalize=True, range=(-1, 1)) - # self.idx = self.idx + 1 + # uncomment to check data + # import torchvision + # if self.opt['rank'] == 0: + # import os + # os.makedirs('tmp/gt', exist_ok=True) + # os.makedirs('tmp/lq', exist_ok=True) + # print(self.idx) + # torchvision.utils.save_image( + # self.gt, f'tmp/gt/gt_{self.idx}.png', nrow=4, padding=2, normalize=True, range=(-1, 1)) + # torchvision.utils.save_image( + # self.lq, f'tmp/lq/lq{self.idx}.png', nrow=4, padding=2, normalize=True, range=(-1, 1)) + # self.idx = self.idx + 1 def construct_img_pyramid(self): """Construct image pyramid for intermediate restoration loss""" @@ -300,10 +300,9 @@ def optimize_parameters(self, current_iter): p.requires_grad = False # image pyramid loss weight - if current_iter < self.opt['train'].get('remove_pyramid_loss', float('inf')): - pyramid_loss_weight = self.opt['train'].get('pyramid_loss_weight', 1) - else: - pyramid_loss_weight = 1e-12 # very small loss + pyramid_loss_weight = self.opt['train'].get('pyramid_loss_weight', 0) + if pyramid_loss_weight > 0 and current_iter > self.opt['train'].get('remove_pyramid_loss', float('inf')): + pyramid_loss_weight = 1e-12 # very small weight to avoid unused param error if pyramid_loss_weight > 0: 
self.output, out_rgbs = self.net_g(self.lq, return_rgb=True) pyramid_gt = self.construct_img_pyramid() diff --git a/gfpgan/utils.py b/gfpgan/utils.py index f3e163e9..1cc104d8 100644 --- a/gfpgan/utils.py +++ b/gfpgan/utils.py @@ -6,6 +6,7 @@ from facexlib.utils.face_restoration_helper import FaceRestoreHelper from torchvision.transforms.functional import normalize +from gfpgan.archs.gfpgan_bilinear_arch import GFPGANBilinear from gfpgan.archs.gfpganv1_arch import GFPGANv1 from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean @@ -47,7 +48,19 @@ def __init__(self, model_path, upscale=2, arch='clean', channel_multiplier=2, bg different_w=True, narrow=1, sft_half=True) - else: + elif arch == 'bilinear': + self.gfpgan = GFPGANBilinear( + out_size=512, + num_style_feat=512, + channel_multiplier=channel_multiplier, + decoder_load_path=None, + fix_decoder=False, + num_mlp=8, + input_is_latent=True, + different_w=True, + narrow=1, + sft_half=True) + elif arch == 'original': self.gfpgan = GFPGANv1( out_size=512, num_style_feat=512, diff --git a/inference_gfpgan.py b/inference_gfpgan.py index a426cfc7..49889fc4 100644 --- a/inference_gfpgan.py +++ b/inference_gfpgan.py @@ -10,39 +10,56 @@ def main(): - """Inference demo for GFPGAN. + """Inference demo for GFPGAN (for users). """ parser = argparse.ArgumentParser() - parser.add_argument('--upscale', type=int, default=2, help='The final upsampling scale of the image') - parser.add_argument('--arch', type=str, default='clean', help='The GFPGAN architecture. Option: clean | original') - parser.add_argument('--channel', type=int, default=2, help='Channel multiplier for large networks of StyleGAN2') - parser.add_argument('--model_path', type=str, default='experiments/pretrained_models/GFPGANCleanv1-NoCE-C2.pth') - parser.add_argument('--bg_upsampler', type=str, default='realesrgan', help='background upsampler') parser.add_argument( - '--bg_tile', type=int, default=400, help='Tile size for background sampler, 0 for no tile during testing') - parser.add_argument('--test_path', type=str, default='inputs/whole_imgs', help='Input folder') + '-i', + '--input', + type=str, + default='inputs/whole_imgs', + help='Input image or folder. Default: inputs/whole_imgs') + parser.add_argument('-o', '--output', type=str, default='results', help='Output folder. Default: results') + # we use version to select models, which is more user-friendly + parser.add_argument( + '-v', '--version', type=str, default='1.3', help='GFPGAN model version. Option: 1 | 1.2 | 1.3. Default: 1.3') + parser.add_argument( + '-s', '--upscale', type=int, default=2, help='The final upsampling scale of the image. Default: 2') + + parser.add_argument( + '--bg_upsampler', type=str, default='realesrgan', help='background upsampler. Default: realesrgan') + parser.add_argument( + '--bg_tile', + type=int, + default=400, + help='Tile size for background sampler, 0 for no tile during testing. Default: 400') parser.add_argument('--suffix', type=str, default=None, help='Suffix of the restored faces') parser.add_argument('--only_center_face', action='store_true', help='Only restore the center face') parser.add_argument('--aligned', action='store_true', help='Input are aligned faces') - parser.add_argument('--paste_back', action='store_false', help='Paste the restored faces back to images') - parser.add_argument('--save_root', type=str, default='results', help='Path to save root') parser.add_argument( '--ext', type=str, default='auto', - help='Image extension. 
Options: auto | jpg | png, auto means using the same extension as inputs') + help='Image extension. Options: auto | jpg | png, auto means using the same extension as inputs. Default: auto') args = parser.parse_args() args = parser.parse_args() - if args.test_path.endswith('/'): - args.test_path = args.test_path[:-1] - os.makedirs(args.save_root, exist_ok=True) - # background upsampler + # ------------------------ input & output ------------------------ + if args.input.endswith('/'): + args.input = args.input[:-1] + if os.path.isfile(args.input): + img_list = [args.input] + else: + img_list = sorted(glob.glob(os.path.join(args.input, '*'))) + + os.makedirs(args.output, exist_ok=True) + + # ------------------------ set up background upsampler ------------------------ if args.bg_upsampler == 'realesrgan': if not torch.cuda.is_available(): # CPU import warnings - warnings.warn('The unoptimized RealESRGAN is very slow on CPU. We do not use it. ' + warnings.warn('The unoptimized RealESRGAN is slow on CPU. We do not use it. ' 'If you really want to use it, please modify the corresponding codes.') bg_upsampler = None else: @@ -59,15 +76,38 @@ def main(): half=True) # need to set False in CPU mode else: bg_upsampler = None - # set up GFPGAN restorer + + # ------------------------ set up GFPGAN restorer ------------------------ + if args.version == '1': + arch = 'original' + channel_multiplier = 1 + model_name = 'GFPGANv1' + elif args.version == '1.2': + arch = 'clean' + channel_multiplier = 2 + model_name = 'GFPGANCleanv1-NoCE-C2' + elif args.version == '1.3': + arch = 'clean' + channel_multiplier = 2 + model_name = 'GFPGANv1.3' + else: + raise ValueError(f'Wrong model version {args.version}.') + + # determine model paths + model_path = os.path.join('experiments/pretrained_models', model_name + '.pth') + if not os.path.isfile(model_path): + model_path = os.path.join('realesrgan/weights', model_name + '.pth') + if not os.path.isfile(model_path): + raise ValueError(f'Model {model_name} does not exist.') + restorer = GFPGANer( - model_path=args.model_path, + model_path=model_path, upscale=args.upscale, - arch=args.arch, - channel_multiplier=args.channel, + arch=arch, + channel_multiplier=channel_multiplier, bg_upsampler=bg_upsampler) - img_list = sorted(glob.glob(os.path.join(args.test_path, '*'))) + # ------------------------ restore ------------------------ for img_path in img_list: # read image img_name = os.path.basename(img_path) @@ -77,23 +117,23 @@ def main(): # restore faces and background if necessary cropped_faces, restored_faces, restored_img = restorer.enhance( - input_img, has_aligned=args.aligned, only_center_face=args.only_center_face, paste_back=args.paste_back) + input_img, has_aligned=args.aligned, only_center_face=args.only_center_face, paste_back=True) # save faces for idx, (cropped_face, restored_face) in enumerate(zip(cropped_faces, restored_faces)): # save cropped face - save_crop_path = os.path.join(args.save_root, 'cropped_faces', f'{basename}_{idx:02d}.png') + save_crop_path = os.path.join(args.output, 'cropped_faces', f'{basename}_{idx:02d}.png') imwrite(cropped_face, save_crop_path) # save restored face if args.suffix is not None: save_face_name = f'{basename}_{idx:02d}_{args.suffix}.png' else: save_face_name = f'{basename}_{idx:02d}.png' - save_restore_path = os.path.join(args.save_root, 'restored_faces', save_face_name) + save_restore_path = os.path.join(args.output, 'restored_faces', save_face_name) imwrite(restored_face, save_restore_path) # save comparison image 
cmp_img = np.concatenate((cropped_face, restored_face), axis=1) - imwrite(cmp_img, os.path.join(args.save_root, 'cmp', f'{basename}_{idx:02d}.png')) + imwrite(cmp_img, os.path.join(args.output, 'cmp', f'{basename}_{idx:02d}.png')) # save restored img if restored_img is not None: @@ -103,13 +143,12 @@ def main(): extension = args.ext if args.suffix is not None: - save_restore_path = os.path.join(args.save_root, 'restored_imgs', - f'{basename}_{args.suffix}.{extension}') + save_restore_path = os.path.join(args.output, 'restored_imgs', f'{basename}_{args.suffix}.{extension}') else: - save_restore_path = os.path.join(args.save_root, 'restored_imgs', f'{basename}.{extension}') + save_restore_path = os.path.join(args.output, 'restored_imgs', f'{basename}.{extension}') imwrite(restored_img, save_restore_path) - print(f'Results are in the [{args.save_root}] folder.') + print(f'Results are in the [{args.output}] folder.') if __name__ == '__main__': diff --git a/scripts/convert_gfpganv_to_clean.py b/scripts/convert_gfpganv_to_clean.py new file mode 100644 index 00000000..8fdccb61 --- /dev/null +++ b/scripts/convert_gfpganv_to_clean.py @@ -0,0 +1,164 @@ +import argparse +import math +import torch + +from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean + + +def modify_checkpoint(checkpoint_bilinear, checkpoint_clean): + for ori_k, ori_v in checkpoint_bilinear.items(): + if 'stylegan_decoder' in ori_k: + if 'style_mlp' in ori_k: # style_mlp_layers + lr_mul = 0.01 + prefix, name, idx, var = ori_k.split('.') + idx = (int(idx) * 2) - 1 + crt_k = f'{prefix}.{name}.{idx}.{var}' + if var == 'weight': + _, c_in = ori_v.size() + scale = (1 / math.sqrt(c_in)) * lr_mul + crt_v = ori_v * scale * 2**0.5 + else: + crt_v = ori_v * lr_mul * 2**0.5 + checkpoint_clean[crt_k] = crt_v + elif 'modulation' in ori_k: # modulation in StyleConv + lr_mul = 1 + crt_k = ori_k + var = ori_k.split('.')[-1] + if var == 'weight': + _, c_in = ori_v.size() + scale = (1 / math.sqrt(c_in)) * lr_mul + crt_v = ori_v * scale + else: + crt_v = ori_v * lr_mul + checkpoint_clean[crt_k] = crt_v + elif 'style_conv' in ori_k: + # StyleConv in style_conv1 and style_convs + if 'activate' in ori_k: # FusedLeakyReLU + # eg. style_conv1.activate.bias + # eg. style_convs.13.activate.bias + split_rlt = ori_k.split('.') + if len(split_rlt) == 4: + prefix, name, _, var = split_rlt + crt_k = f'{prefix}.{name}.{var}' + elif len(split_rlt) == 5: + prefix, name, idx, _, var = split_rlt + crt_k = f'{prefix}.{name}.{idx}.{var}' + crt_v = ori_v * 2**0.5 # 2**0.5 used in FusedLeakyReLU + c = crt_v.size(0) + checkpoint_clean[crt_k] = crt_v.view(1, c, 1, 1) + elif 'modulated_conv' in ori_k: + # eg. style_conv1.modulated_conv.weight + # eg. style_convs.13.modulated_conv.weight + _, c_out, c_in, k1, k2 = ori_v.size() + scale = 1 / math.sqrt(c_in * k1 * k2) + crt_k = ori_k + checkpoint_clean[crt_k] = ori_v * scale + elif 'weight' in ori_k: + crt_k = ori_k + checkpoint_clean[crt_k] = ori_v * 2**0.5 + elif 'to_rgb' in ori_k: # StyleConv in to_rgb1 and to_rgbs + if 'modulated_conv' in ori_k: + # eg. to_rgb1.modulated_conv.weight + # eg. 
to_rgbs.5.modulated_conv.weight + _, c_out, c_in, k1, k2 = ori_v.size() + scale = 1 / math.sqrt(c_in * k1 * k2) + crt_k = ori_k + checkpoint_clean[crt_k] = ori_v * scale + else: + crt_k = ori_k + checkpoint_clean[crt_k] = ori_v + else: + crt_k = ori_k + checkpoint_clean[crt_k] = ori_v + # end of 'stylegan_decoder' + elif 'conv_body_first' in ori_k or 'final_conv' in ori_k: + # key name + name, _, var = ori_k.split('.') + crt_k = f'{name}.{var}' + # weight and bias + if var == 'weight': + c_out, c_in, k1, k2 = ori_v.size() + scale = 1 / math.sqrt(c_in * k1 * k2) + checkpoint_clean[crt_k] = ori_v * scale * 2**0.5 + else: + checkpoint_clean[crt_k] = ori_v * 2**0.5 + elif 'conv_body' in ori_k: + if 'conv_body_up' in ori_k: + ori_k = ori_k.replace('conv2.weight', 'conv2.1.weight') + ori_k = ori_k.replace('skip.weight', 'skip.1.weight') + name1, idx1, name2, _, var = ori_k.split('.') + crt_k = f'{name1}.{idx1}.{name2}.{var}' + if name2 == 'skip': + c_out, c_in, k1, k2 = ori_v.size() + scale = 1 / math.sqrt(c_in * k1 * k2) + checkpoint_clean[crt_k] = ori_v * scale / 2**0.5 + else: + if var == 'weight': + c_out, c_in, k1, k2 = ori_v.size() + scale = 1 / math.sqrt(c_in * k1 * k2) + checkpoint_clean[crt_k] = ori_v * scale + else: + checkpoint_clean[crt_k] = ori_v + if 'conv1' in ori_k: + checkpoint_clean[crt_k] *= 2**0.5 + elif 'toRGB' in ori_k: + crt_k = ori_k + if 'weight' in ori_k: + c_out, c_in, k1, k2 = ori_v.size() + scale = 1 / math.sqrt(c_in * k1 * k2) + checkpoint_clean[crt_k] = ori_v * scale + else: + checkpoint_clean[crt_k] = ori_v + elif 'final_linear' in ori_k: + crt_k = ori_k + if 'weight' in ori_k: + _, c_in = ori_v.size() + scale = 1 / math.sqrt(c_in) + checkpoint_clean[crt_k] = ori_v * scale + else: + checkpoint_clean[crt_k] = ori_v + elif 'condition' in ori_k: + crt_k = ori_k + if '0.weight' in ori_k: + c_out, c_in, k1, k2 = ori_v.size() + scale = 1 / math.sqrt(c_in * k1 * k2) + checkpoint_clean[crt_k] = ori_v * scale * 2**0.5 + elif '0.bias' in ori_k: + checkpoint_clean[crt_k] = ori_v * 2**0.5 + elif '2.weight' in ori_k: + c_out, c_in, k1, k2 = ori_v.size() + scale = 1 / math.sqrt(c_in * k1 * k2) + checkpoint_clean[crt_k] = ori_v * scale + elif '2.bias' in ori_k: + checkpoint_clean[crt_k] = ori_v + + return checkpoint_clean + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--ori_path', type=str, help='Path to the original model') + parser.add_argument('--narrow', type=float, default=1) + parser.add_argument('--channel_multiplier', type=float, default=2) + parser.add_argument('--save_path', type=str) + args = parser.parse_args() + + ori_ckpt = torch.load(args.ori_path)['params_ema'] + + net = GFPGANv1Clean( + 512, + num_style_feat=512, + channel_multiplier=args.channel_multiplier, + decoder_load_path=None, + fix_decoder=False, + # for stylegan decoder + num_mlp=8, + input_is_latent=True, + different_w=True, + narrow=args.narrow, + sft_half=True) + crt_ckpt = net.state_dict() + + crt_ckpt = modify_checkpoint(ori_ckpt, crt_ckpt) + print(f'Save to {args.save_path}.') + torch.save(dict(params_ema=crt_ckpt), args.save_path, _use_new_zipfile_serialization=False)
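
A usage note for the conversion script above, tying it to the finetuning answer in FAQ.md: the sketch below is a hypothetical invocation, assuming you have a finetuned *bilinear* checkpoint saved with a `params_ema` key (which is what `scripts/convert_gfpganv_to_clean.py` loads); the flag names come from the script's argument parser, while both file paths are placeholders.

```bash
# Hypothetical checkpoint paths; adjust to your own finetuned bilinear model.
python scripts/convert_gfpganv_to_clean.py \
    --ori_path experiments/pretrained_models/GFPGANBilinear_finetuned.pth \
    --narrow 1 \
    --channel_multiplier 2 \
    --save_path experiments/pretrained_models/GFPGANClean_finetuned.pth
```

The converted *clean* checkpoint can then be loaded the same way as the released v1.2/v1.3 models, e.g. with `GFPGANer(model_path=..., arch='clean', channel_multiplier=2)`.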