diff --git a/Comparisons.md b/Comparisons.md
new file mode 100644
index 00000000..1542d4c0
--- /dev/null
+++ b/Comparisons.md
@@ -0,0 +1,24 @@
+# Comparisons
+
+## Comparisons among different model versions
+
+Note that V1.3 is not always better than V1.2. You may need to try different models based on your purpose and inputs.
+
+| Version | Strengths | Weaknesses |
+| :---: | :---: | :---: |
+|V1.3 | ✓ natural outputs
✓better results on very low-quality inputs
✓ work on relatively high-quality inputs
✓ can have repeated (twice) restorations | ✗ not very sharp
✗ have a slight change on identity |
+|V1.2 | ✓ sharper output
✓ with beauty makeup | ✗ some outputs are unnatural|
+
+For the following images, you may need to **zoom in** for comparing details, or **click the image** to see in the full size.
+
+| Input | V1 | V1.2 | V1.3
+| :---: | :---: | :---: | :---: |
+|![019_Anne_Hathaway_01_00](https://user-images.githubusercontent.com/17445847/153762146-96b25999-4ddd-42a5-a3fe-bb90565f4c4f.png)| ![](https://user-images.githubusercontent.com/17445847/153762256-ef41e749-5a27-495c-8a9c-d8403be55869.png) | ![](https://user-images.githubusercontent.com/17445847/153762297-d41582fc-6253-4e7e-a1ce-4dc237ae3bf3.png) | ![](https://user-images.githubusercontent.com/17445847/153762215-e0535e94-b5ba-426e-97b5-35c00873604d.png) |
+| ![106_Harry_Styles_00_00](https://user-images.githubusercontent.com/17445847/153789040-632c0eda-c15a-43e9-a63c-9ead64f92d4a.png) | ![](https://user-images.githubusercontent.com/17445847/153789172-93cd4980-5318-4633-a07e-1c8f8064ff89.png) | ![](https://user-images.githubusercontent.com/17445847/153789185-f7b268a7-d1db-47b0-ae4a-335e5d657a18.png) | ![](https://user-images.githubusercontent.com/17445847/153789198-7c7f3bca-0ef0-4494-92f0-20aa6f7d7464.png)|
+| ![076_Paris_Hilton_00_00](https://user-images.githubusercontent.com/17445847/153789607-86387770-9db8-441f-b08a-c9679b121b85.png) | ![](https://user-images.githubusercontent.com/17445847/153789619-e56b438a-78a0-425d-8f44-ec4692a43dda.png) | ![](https://user-images.githubusercontent.com/17445847/153789633-5b28f778-3b7f-4e08-8a1d-740ca6e82d8a.png) | ![](https://user-images.githubusercontent.com/17445847/153789645-bc623f21-b32d-4fc3-bfe9-61203407a180.png)|
+| ![008_George_Clooney_00_00](https://user-images.githubusercontent.com/17445847/153790017-0c3ca94d-1c9d-4a0e-b539-ab12d4da98ff.png) | ![](https://user-images.githubusercontent.com/17445847/153790028-fb0d38ab-399d-4a30-8154-2dcd72ca90e8.png) | ![](https://user-images.githubusercontent.com/17445847/153790044-1ef68e34-6120-4439-a5d9-0b6cdbe9c3d0.png) | ![](https://user-images.githubusercontent.com/17445847/153790059-a8d3cece-8989-4e9a-9ffe-903e1690cfd6.png)|
+| ![057_Madonna_01_00](https://user-images.githubusercontent.com/17445847/153790624-2d0751d0-8fb4-4806-be9d-71b833c2c226.png) | ![](https://user-images.githubusercontent.com/17445847/153790639-7eb870e5-26b2-41dc-b139-b698bb40e6e6.png) | ![](https://user-images.githubusercontent.com/17445847/153790651-86899b7a-a1b6-4242-9e8a-77b462004998.png) | ![](https://user-images.githubusercontent.com/17445847/153790655-c8f6c25b-9b4e-4633-b16f-c43da86cff8f.png)|
+| ![044_Amy_Schumer_01_00](https://user-images.githubusercontent.com/17445847/153790811-3fb4fc46-5b4f-45fe-8fcb-a128de2bfa60.png) | ![](https://user-images.githubusercontent.com/17445847/153790817-d45aa4ff-bfc4-4163-b462-75eef9426fab.png) | ![](https://user-images.githubusercontent.com/17445847/153790824-5f93c3a0-fe5a-42f6-8b4b-5a5de8cd0ac3.png) | ![](https://user-images.githubusercontent.com/17445847/153790835-0edf9944-05c7-41c4-8581-4dc5ffc56c9d.png)|
+| ![012_Jackie_Chan_01_00](https://user-images.githubusercontent.com/17445847/153791176-737b016a-e94f-4898-8db7-43e7762141c9.png) | ![](https://user-images.githubusercontent.com/17445847/153791183-2f25a723-56bf-4cd5-aafe-a35513a6d1c5.png) | ![](https://user-images.githubusercontent.com/17445847/153791194-93416cf9-2b58-4e70-b806-27e14c58d4fd.png) | ![](https://user-images.githubusercontent.com/17445847/153791202-aa98659c-b702-4bce-9c47-a2fa5eccc5ae.png)|
+
+
diff --git a/FAQ.md b/FAQ.md
new file mode 100644
index 00000000..e4d5a49c
--- /dev/null
+++ b/FAQ.md
@@ -0,0 +1,7 @@
+# FAQ
+
+1. **How to finetune the GFPGANCleanv1-NoCE-C2 (v1.2) model**
+
+**A:** 1) The GFPGANCleanv1-NoCE-C2 (v1.2) model uses the *clean* architecture, which is more friendly for deploying.
+2) This model is not directly trained. Instead, it is converted from another *bilinear* model.
+3) If you want to finetune the GFPGANCleanv1-NoCE-C2 (v1.2), you need to finetune its original *bilinear* model, and then do the conversion.
diff --git a/PaperModel.md b/PaperModel.md
index aec81d31..e9c8bdc4 100644
--- a/PaperModel.md
+++ b/PaperModel.md
@@ -60,17 +60,17 @@ wget https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/GFPGANv1.pth
- Option 1: Load extensions just-in-time(JIT)
```bash
- BASICSR_JIT=True python inference_gfpgan.py --model_path experiments/pretrained_models/GFPGANv1.pth --test_path inputs/whole_imgs --save_root results --arch original --channel 1
+ BASICSR_JIT=True python inference_gfpgan.py --input inputs/whole_imgs --output results --version 1
# for aligned images
- BASICSR_JIT=True python inference_gfpgan.py --model_path experiments/pretrained_models/GFPGANv1.pth --test_path inputs/cropped_faces --save_root results --arch original --channel 1 --aligned
+ BASICSR_JIT=True python inference_gfpgan.py --input inputs/whole_imgs --output results --version 1 --aligned
```
- Option 2: Have successfully compiled extensions during installation
```bash
- python inference_gfpgan.py --model_path experiments/pretrained_models/GFPGANv1.pth --test_path inputs/whole_imgs --save_root results --arch original --channel 1
+ python inference_gfpgan.py --input inputs/whole_imgs --output results --version 1
# for aligned images
- python inference_gfpgan.py --model_path experiments/pretrained_models/GFPGANv1.pth --test_path inputs/cropped_faces --save_root results --arch original --channel 1 --aligned
+ python inference_gfpgan.py --input inputs/whole_imgs --output results --version 1 --aligned
```
diff --git a/README.md b/README.md
index 09d1ccce..1aa5c58c 100644
--- a/README.md
+++ b/README.md
@@ -19,7 +19,11 @@
GFPGAN aims at developing a **Practical Algorithm for Real-world Face Restoration**.
It leverages rich and diverse priors encapsulated in a pretrained face GAN (*e.g.*, StyleGAN2) for blind face restoration.
+:question: Frequently Asked Questions can be found in [FAQ.md](FAQ.md).
+
:triangular_flag_on_post: **Updates**
+
+- :fire::fire::white_check_mark: Add **[V1.3 model](https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth)**, which produces **more natural** restoration results, and better results on *very low-quality* / *high-quality* inputs. See more in [Model zoo](#european_castle-model-zoo), [Comparisons.md](Comparisons.md)
- :white_check_mark: Integrated to [Huggingface Spaces](https://huggingface.co/spaces) with [Gradio](https://github.com/gradio-app/gradio). See [Gradio Web Demo](https://huggingface.co/spaces/akhaliq/GFPGAN).
- :white_check_mark: Support enhancing non-face regions (background) with [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN).
- :white_check_mark: We provide a *clean* version of GFPGAN, which does not require CUDA extensions.
@@ -88,24 +92,54 @@ If you want to use the original model in our paper, please see [PaperModel.md](P
## :zap: Quick Inference
-Download pre-trained models: [GFPGANCleanv1-NoCE-C2.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.2.0/GFPGANCleanv1-NoCE-C2.pth)
+We take the v1.3 version for an example. More models can be found [here](#european_castle-model-zoo).
+
+Download pre-trained models: [GFPGANv1.3.pth](https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth)
```bash
-wget https://github.com/TencentARC/GFPGAN/releases/download/v0.2.0/GFPGANCleanv1-NoCE-C2.pth -P experiments/pretrained_models
+wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth -P experiments/pretrained_models
```
**Inference!**
```bash
-python inference_gfpgan.py --upscale 2 --test_path inputs/whole_imgs --save_root results
+python inference_gfpgan.py -i inputs/whole_imgs -o results -v 1.3 -s 2
+```
+
+```console
+Usage: python inference_gfpgan.py -i inputs/whole_imgs -o results -v 1.3 -s 2 [options]...
+
+ -h show this help
+ -i input Input image or folder. Default: inputs/whole_imgs
+ -o output Output folder. Default: results
+ -v version GFPGAN model version. Option: 1 | 1.2 | 1.3. Default: 1.3
+ -s upscale The final upsampling scale of the image. Default: 2
+ -bg_upsampler background upsampler. Default: realesrgan
+ -bg_tile Tile size for background sampler, 0 for no tile during testing. Default: 400
+ -suffix Suffix of the restored faces
+ -only_center_face Only restore the center face
+ -aligned Input are aligned faces
+ -ext Image extension. Options: auto | jpg | png, auto means using the same extension as inputs. Default: auto
```
If you want to use the original model in our paper, please see [PaperModel.md](PaperModel.md) for installation and inference.
## :european_castle: Model Zoo
-- [GFPGANCleanv1-NoCE-C2.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.2.0/GFPGANCleanv1-NoCE-C2.pth): No colorization; no CUDA extensions are required. It is still in training. Trained with more data with pre-processing.
-- [GFPGANv1.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/GFPGANv1.pth): The paper model, with colorization.
+| Version | Model Name | Description |
+| :---: | :---: | :---: |
+| V1.3 | [GFPGANv1.3.pth](https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth) | Based on V1.2; **more natural** restoration results; better results on very low-quality / high-quality inputs. |
+| V1.2 | [GFPGANCleanv1-NoCE-C2.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.2.0/GFPGANCleanv1-NoCE-C2.pth) | No colorization; no CUDA extensions are required. Trained with more data with pre-processing. |
+| V1 | [GFPGANv1.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/GFPGANv1.pth) | The paper model, with colorization. |
+
+The comparisons are in [Comparisons.md](Comparisons.md).
+
+Note that V1.3 is not always better than V1.2. You may need to select different models based on your purpose and inputs.
+
+| Version | Strengths | Weaknesses |
+| :---: | :---: | :---: |
+|V1.3 | ✓ natural outputs
✓better results on very low-quality inputs
✓ work on relatively high-quality inputs
✓ can have repeated (twice) restorations | ✗ not very sharp
✗ have a slight change on identity |
+|V1.2 | ✓ sharper output
✓ with beauty makeup | ✗ some outputs are unnatural |
You can find **more models (such as the discriminators)** here: [[Google Drive](https://drive.google.com/drive/folders/17rLiFzcUMoQuhLnptDsKolegHWwJOnHu?usp=sharing)], OR [[Tencent Cloud 腾讯微云](https://share.weiyun.com/ShYoCCoc)]
diff --git a/gfpgan/archs/gfpgan_bilinear_arch.py b/gfpgan/archs/gfpgan_bilinear_arch.py
new file mode 100644
index 00000000..d0537b14
--- /dev/null
+++ b/gfpgan/archs/gfpgan_bilinear_arch.py
@@ -0,0 +1,312 @@
+import math
+import random
+import torch
+from basicsr.archs.stylegan2_bilinear_arch import (ConvLayer, EqualConv2d, EqualLinear, ResBlock, ScaledLeakyReLU,
+ StyleGAN2GeneratorBilinear)
+from basicsr.utils.registry import ARCH_REGISTRY
+from torch import nn
+
+from .gfpganv1_arch import ResUpBlock
+
+
+class StyleGAN2GeneratorBilinearSFT(StyleGAN2GeneratorBilinear):
+ """StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).
+
+ It is the bilinear version. It does not use the complicated UpFirDnSmooth function that is not friendly for
+ deployment. It can be easily converted to the clean version: StyleGAN2GeneratorCSFT.
+
+ Args:
+ out_size (int): The spatial size of outputs.
+ num_style_feat (int): Channel number of style features. Default: 512.
+ num_mlp (int): Layer number of MLP style layers. Default: 8.
+ channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
+ lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
+ narrow (float): The narrow ratio for channels. Default: 1.
+ sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
+ """
+
+ def __init__(self,
+ out_size,
+ num_style_feat=512,
+ num_mlp=8,
+ channel_multiplier=2,
+ lr_mlp=0.01,
+ narrow=1,
+ sft_half=False):
+ super(StyleGAN2GeneratorBilinearSFT, self).__init__(
+ out_size,
+ num_style_feat=num_style_feat,
+ num_mlp=num_mlp,
+ channel_multiplier=channel_multiplier,
+ lr_mlp=lr_mlp,
+ narrow=narrow)
+ self.sft_half = sft_half
+
+ def forward(self,
+ styles,
+ conditions,
+ input_is_latent=False,
+ noise=None,
+ randomize_noise=True,
+ truncation=1,
+ truncation_latent=None,
+ inject_index=None,
+ return_latents=False):
+ """Forward function for StyleGAN2GeneratorBilinearSFT.
+
+ Args:
+ styles (list[Tensor]): Sample codes of styles.
+ conditions (list[Tensor]): SFT conditions to generators.
+ input_is_latent (bool): Whether input is latent style. Default: False.
+ noise (Tensor | None): Input noise or None. Default: None.
+ randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
+ truncation (float): The truncation ratio. Default: 1.
+ truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
+ inject_index (int | None): The injection index for mixing noise. Default: None.
+ return_latents (bool): Whether to return style latents. Default: False.
+ """
+ # style codes -> latents with Style MLP layer
+ if not input_is_latent:
+ styles = [self.style_mlp(s) for s in styles]
+ # noises
+ if noise is None:
+ if randomize_noise:
+ noise = [None] * self.num_layers # for each style conv layer
+ else: # use the stored noise
+ noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
+ # style truncation
+ if truncation < 1:
+ style_truncation = []
+ for style in styles:
+ style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
+ styles = style_truncation
+ # get style latents with injection
+ if len(styles) == 1:
+ inject_index = self.num_latent
+
+ if styles[0].ndim < 3:
+ # repeat latent code for all the layers
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
+ else: # used for encoder with different latent code for each layer
+ latent = styles[0]
+ elif len(styles) == 2: # mixing noises
+ if inject_index is None:
+ inject_index = random.randint(1, self.num_latent - 1)
+ latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
+ latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
+ latent = torch.cat([latent1, latent2], 1)
+
+ # main generation
+ out = self.constant_input(latent.shape[0])
+ out = self.style_conv1(out, latent[:, 0], noise=noise[0])
+ skip = self.to_rgb1(out, latent[:, 1])
+
+ i = 1
+ for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
+ noise[2::2], self.to_rgbs):
+ out = conv1(out, latent[:, i], noise=noise1)
+
+ # the conditions may have fewer levels
+ if i < len(conditions):
+ # SFT part to combine the conditions
+ if self.sft_half: # only apply SFT to half of the channels
+ out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
+ out_sft = out_sft * conditions[i - 1] + conditions[i]
+ out = torch.cat([out_same, out_sft], dim=1)
+ else: # apply SFT to all the channels
+ out = out * conditions[i - 1] + conditions[i]
+
+ out = conv2(out, latent[:, i + 1], noise=noise2)
+ skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space
+ i += 2
+
+ image = skip
+
+ if return_latents:
+ return image, latent
+ else:
+ return image, None
+
+
+@ARCH_REGISTRY.register()
+class GFPGANBilinear(nn.Module):
+ """The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT.
+
+ It is the bilinear version and it does not use the complicated UpFirDnSmooth function that is not friendly for
+ deployment. It can be easily converted to the clean version: GFPGANv1Clean.
+
+
+ Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.
+
+ Args:
+ out_size (int): The spatial size of outputs.
+ num_style_feat (int): Channel number of style features. Default: 512.
+ channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
+ decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
+ fix_decoder (bool): Whether to fix the decoder. Default: True.
+
+ num_mlp (int): Layer number of MLP style layers. Default: 8.
+ lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
+ input_is_latent (bool): Whether input is latent style. Default: False.
+ different_w (bool): Whether to use different latent w for different layers. Default: False.
+ narrow (float): The narrow ratio for channels. Default: 1.
+ sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
+ """
+
+ def __init__(
+ self,
+ out_size,
+ num_style_feat=512,
+ channel_multiplier=1,
+ decoder_load_path=None,
+ fix_decoder=True,
+ # for stylegan decoder
+ num_mlp=8,
+ lr_mlp=0.01,
+ input_is_latent=False,
+ different_w=False,
+ narrow=1,
+ sft_half=False):
+
+ super(GFPGANBilinear, self).__init__()
+ self.input_is_latent = input_is_latent
+ self.different_w = different_w
+ self.num_style_feat = num_style_feat
+
+ unet_narrow = narrow * 0.5 # by default, use a half of input channels
+ channels = {
+ '4': int(512 * unet_narrow),
+ '8': int(512 * unet_narrow),
+ '16': int(512 * unet_narrow),
+ '32': int(512 * unet_narrow),
+ '64': int(256 * channel_multiplier * unet_narrow),
+ '128': int(128 * channel_multiplier * unet_narrow),
+ '256': int(64 * channel_multiplier * unet_narrow),
+ '512': int(32 * channel_multiplier * unet_narrow),
+ '1024': int(16 * channel_multiplier * unet_narrow)
+ }
+
+ self.log_size = int(math.log(out_size, 2))
+ first_out_size = 2**(int(math.log(out_size, 2)))
+
+ self.conv_body_first = ConvLayer(3, channels[f'{first_out_size}'], 1, bias=True, activate=True)
+
+ # downsample
+ in_channels = channels[f'{first_out_size}']
+ self.conv_body_down = nn.ModuleList()
+ for i in range(self.log_size, 2, -1):
+ out_channels = channels[f'{2**(i - 1)}']
+ self.conv_body_down.append(ResBlock(in_channels, out_channels))
+ in_channels = out_channels
+
+ self.final_conv = ConvLayer(in_channels, channels['4'], 3, bias=True, activate=True)
+
+ # upsample
+ in_channels = channels['4']
+ self.conv_body_up = nn.ModuleList()
+ for i in range(3, self.log_size + 1):
+ out_channels = channels[f'{2**i}']
+ self.conv_body_up.append(ResUpBlock(in_channels, out_channels))
+ in_channels = out_channels
+
+ # to RGB
+ self.toRGB = nn.ModuleList()
+ for i in range(3, self.log_size + 1):
+ self.toRGB.append(EqualConv2d(channels[f'{2**i}'], 3, 1, stride=1, padding=0, bias=True, bias_init_val=0))
+
+ if different_w:
+ linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat
+ else:
+ linear_out_channel = num_style_feat
+
+ self.final_linear = EqualLinear(
+ channels['4'] * 4 * 4, linear_out_channel, bias=True, bias_init_val=0, lr_mul=1, activation=None)
+
+ # the decoder: stylegan2 generator with SFT modulations
+ self.stylegan_decoder = StyleGAN2GeneratorBilinearSFT(
+ out_size=out_size,
+ num_style_feat=num_style_feat,
+ num_mlp=num_mlp,
+ channel_multiplier=channel_multiplier,
+ lr_mlp=lr_mlp,
+ narrow=narrow,
+ sft_half=sft_half)
+
+ # load pre-trained stylegan2 model if necessary
+ if decoder_load_path:
+ self.stylegan_decoder.load_state_dict(
+ torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema'])
+ # fix decoder without updating params
+ if fix_decoder:
+ for _, param in self.stylegan_decoder.named_parameters():
+ param.requires_grad = False
+
+ # for SFT modulations (scale and shift)
+ self.condition_scale = nn.ModuleList()
+ self.condition_shift = nn.ModuleList()
+ for i in range(3, self.log_size + 1):
+ out_channels = channels[f'{2**i}']
+ if sft_half:
+ sft_out_channels = out_channels
+ else:
+ sft_out_channels = out_channels * 2
+ self.condition_scale.append(
+ nn.Sequential(
+ EqualConv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0),
+ ScaledLeakyReLU(0.2),
+ EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=1)))
+ self.condition_shift.append(
+ nn.Sequential(
+ EqualConv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0),
+ ScaledLeakyReLU(0.2),
+ EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0)))
+
+ def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True):
+ """Forward function for GFPGANBilinear.
+
+ Args:
+ x (Tensor): Input images.
+ return_latents (bool): Whether to return style latents. Default: False.
+ return_rgb (bool): Whether return intermediate rgb images. Default: True.
+ randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
+ """
+ conditions = []
+ unet_skips = []
+ out_rgbs = []
+
+ # encoder
+ feat = self.conv_body_first(x)
+ for i in range(self.log_size - 2):
+ feat = self.conv_body_down[i](feat)
+ unet_skips.insert(0, feat)
+
+ feat = self.final_conv(feat)
+
+ # style code
+ style_code = self.final_linear(feat.view(feat.size(0), -1))
+ if self.different_w:
+ style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)
+
+ # decode
+ for i in range(self.log_size - 2):
+ # add unet skip
+ feat = feat + unet_skips[i]
+ # ResUpLayer
+ feat = self.conv_body_up[i](feat)
+ # generate scale and shift for SFT layers
+ scale = self.condition_scale[i](feat)
+ conditions.append(scale.clone())
+ shift = self.condition_shift[i](feat)
+ conditions.append(shift.clone())
+ # generate rgb images
+ if return_rgb:
+ out_rgbs.append(self.toRGB[i](feat))
+
+ # decoder
+ image, _ = self.stylegan_decoder([style_code],
+ conditions,
+ return_latents=return_latents,
+ input_is_latent=self.input_is_latent,
+ randomize_noise=randomize_noise)
+
+ return image, out_rgbs
diff --git a/gfpgan/models/gfpgan_model.py b/gfpgan/models/gfpgan_model.py
index c3d51b0b..684fc601 100644
--- a/gfpgan/models/gfpgan_model.py
+++ b/gfpgan/models/gfpgan_model.py
@@ -209,18 +209,18 @@ def feed_data(self, data):
self.loc_right_eyes = data['loc_right_eye']
self.loc_mouths = data['loc_mouth']
- # uncomment to check data
- # import torchvision
- # if self.opt['rank'] == 0:
- # import os
- # os.makedirs('tmp/gt', exist_ok=True)
- # os.makedirs('tmp/lq', exist_ok=True)
- # print(self.idx)
- # torchvision.utils.save_image(
- # self.gt, f'tmp/gt/gt_{self.idx}.png', nrow=4, padding=2, normalize=True, range=(-1, 1))
- # torchvision.utils.save_image(
- # self.lq, f'tmp/lq/lq{self.idx}.png', nrow=4, padding=2, normalize=True, range=(-1, 1))
- # self.idx = self.idx + 1
+ # uncomment to check data
+ # import torchvision
+ # if self.opt['rank'] == 0:
+ # import os
+ # os.makedirs('tmp/gt', exist_ok=True)
+ # os.makedirs('tmp/lq', exist_ok=True)
+ # print(self.idx)
+ # torchvision.utils.save_image(
+ # self.gt, f'tmp/gt/gt_{self.idx}.png', nrow=4, padding=2, normalize=True, range=(-1, 1))
+ # torchvision.utils.save_image(
+ # self.lq, f'tmp/lq/lq{self.idx}.png', nrow=4, padding=2, normalize=True, range=(-1, 1))
+ # self.idx = self.idx + 1
def construct_img_pyramid(self):
"""Construct image pyramid for intermediate restoration loss"""
@@ -300,10 +300,9 @@ def optimize_parameters(self, current_iter):
p.requires_grad = False
# image pyramid loss weight
- if current_iter < self.opt['train'].get('remove_pyramid_loss', float('inf')):
- pyramid_loss_weight = self.opt['train'].get('pyramid_loss_weight', 1)
- else:
- pyramid_loss_weight = 1e-12 # very small loss
+ pyramid_loss_weight = self.opt['train'].get('pyramid_loss_weight', 0)
+ if pyramid_loss_weight > 0 and current_iter > self.opt['train'].get('remove_pyramid_loss', float('inf')):
+ pyramid_loss_weight = 1e-12 # very small weight to avoid unused param error
if pyramid_loss_weight > 0:
self.output, out_rgbs = self.net_g(self.lq, return_rgb=True)
pyramid_gt = self.construct_img_pyramid()
diff --git a/gfpgan/utils.py b/gfpgan/utils.py
index f3e163e9..1cc104d8 100644
--- a/gfpgan/utils.py
+++ b/gfpgan/utils.py
@@ -6,6 +6,7 @@
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from torchvision.transforms.functional import normalize
+from gfpgan.archs.gfpgan_bilinear_arch import GFPGANBilinear
from gfpgan.archs.gfpganv1_arch import GFPGANv1
from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean
@@ -47,7 +48,19 @@ def __init__(self, model_path, upscale=2, arch='clean', channel_multiplier=2, bg
different_w=True,
narrow=1,
sft_half=True)
- else:
+ elif arch == 'bilinear':
+ self.gfpgan = GFPGANBilinear(
+ out_size=512,
+ num_style_feat=512,
+ channel_multiplier=channel_multiplier,
+ decoder_load_path=None,
+ fix_decoder=False,
+ num_mlp=8,
+ input_is_latent=True,
+ different_w=True,
+ narrow=1,
+ sft_half=True)
+ elif arch == 'original':
self.gfpgan = GFPGANv1(
out_size=512,
num_style_feat=512,
diff --git a/inference_gfpgan.py b/inference_gfpgan.py
index a426cfc7..49889fc4 100644
--- a/inference_gfpgan.py
+++ b/inference_gfpgan.py
@@ -10,39 +10,56 @@
def main():
- """Inference demo for GFPGAN.
+ """Inference demo for GFPGAN (for users).
"""
parser = argparse.ArgumentParser()
- parser.add_argument('--upscale', type=int, default=2, help='The final upsampling scale of the image')
- parser.add_argument('--arch', type=str, default='clean', help='The GFPGAN architecture. Option: clean | original')
- parser.add_argument('--channel', type=int, default=2, help='Channel multiplier for large networks of StyleGAN2')
- parser.add_argument('--model_path', type=str, default='experiments/pretrained_models/GFPGANCleanv1-NoCE-C2.pth')
- parser.add_argument('--bg_upsampler', type=str, default='realesrgan', help='background upsampler')
parser.add_argument(
- '--bg_tile', type=int, default=400, help='Tile size for background sampler, 0 for no tile during testing')
- parser.add_argument('--test_path', type=str, default='inputs/whole_imgs', help='Input folder')
+ '-i',
+ '--input',
+ type=str,
+ default='inputs/whole_imgs',
+ help='Input image or folder. Default: inputs/whole_imgs')
+ parser.add_argument('-o', '--output', type=str, default='results', help='Output folder. Default: results')
+ # we use version to select models, which is more user-friendly
+ parser.add_argument(
+ '-v', '--version', type=str, default='1.3', help='GFPGAN model version. Option: 1 | 1.2 | 1.3. Default: 1.3')
+ parser.add_argument(
+ '-s', '--upscale', type=int, default=2, help='The final upsampling scale of the image. Default: 2')
+
+ parser.add_argument(
+ '--bg_upsampler', type=str, default='realesrgan', help='background upsampler. Default: realesrgan')
+ parser.add_argument(
+ '--bg_tile',
+ type=int,
+ default=400,
+ help='Tile size for background sampler, 0 for no tile during testing. Default: 400')
parser.add_argument('--suffix', type=str, default=None, help='Suffix of the restored faces')
parser.add_argument('--only_center_face', action='store_true', help='Only restore the center face')
parser.add_argument('--aligned', action='store_true', help='Input are aligned faces')
- parser.add_argument('--paste_back', action='store_false', help='Paste the restored faces back to images')
- parser.add_argument('--save_root', type=str, default='results', help='Path to save root')
parser.add_argument(
'--ext',
type=str,
default='auto',
- help='Image extension. Options: auto | jpg | png, auto means using the same extension as inputs')
+ help='Image extension. Options: auto | jpg | png, auto means using the same extension as inputs. Default: auto')
args = parser.parse_args()
args = parser.parse_args()
- if args.test_path.endswith('/'):
- args.test_path = args.test_path[:-1]
- os.makedirs(args.save_root, exist_ok=True)
- # background upsampler
+ # ------------------------ input & output ------------------------
+ if args.input.endswith('/'):
+ args.input = args.input[:-1]
+ if os.path.isfile(args.input):
+ img_list = [args.input]
+ else:
+ img_list = sorted(glob.glob(os.path.join(args.input, '*')))
+
+ os.makedirs(args.output, exist_ok=True)
+
+ # ------------------------ set up background upsampler ------------------------
if args.bg_upsampler == 'realesrgan':
if not torch.cuda.is_available(): # CPU
import warnings
- warnings.warn('The unoptimized RealESRGAN is very slow on CPU. We do not use it. '
+ warnings.warn('The unoptimized RealESRGAN is slow on CPU. We do not use it. '
'If you really want to use it, please modify the corresponding codes.')
bg_upsampler = None
else:
@@ -59,15 +76,38 @@ def main():
half=True) # need to set False in CPU mode
else:
bg_upsampler = None
- # set up GFPGAN restorer
+
+ # ------------------------ set up GFPGAN restorer ------------------------
+ if args.version == '1':
+ arch = 'original'
+ channel_multiplier = 1
+ model_name = 'GFPGANv1'
+ elif args.version == '1.2':
+ arch = 'clean'
+ channel_multiplier = 2
+ model_name = 'GFPGANCleanv1-NoCE-C2'
+ elif args.version == '1.3':
+ arch = 'clean'
+ channel_multiplier = 2
+ model_name = 'GFPGANv1.3'
+ else:
+ raise ValueError(f'Wrong model version {args.version}.')
+
+ # determine model paths
+ model_path = os.path.join('experiments/pretrained_models', model_name + '.pth')
+ if not os.path.isfile(model_path):
+ model_path = os.path.join('realesrgan/weights', model_name + '.pth')
+ if not os.path.isfile(model_path):
+ raise ValueError(f'Model {model_name} does not exist.')
+
restorer = GFPGANer(
- model_path=args.model_path,
+ model_path=model_path,
upscale=args.upscale,
- arch=args.arch,
- channel_multiplier=args.channel,
+ arch=arch,
+ channel_multiplier=channel_multiplier,
bg_upsampler=bg_upsampler)
- img_list = sorted(glob.glob(os.path.join(args.test_path, '*')))
+ # ------------------------ restore ------------------------
for img_path in img_list:
# read image
img_name = os.path.basename(img_path)
@@ -77,23 +117,23 @@ def main():
# restore faces and background if necessary
cropped_faces, restored_faces, restored_img = restorer.enhance(
- input_img, has_aligned=args.aligned, only_center_face=args.only_center_face, paste_back=args.paste_back)
+ input_img, has_aligned=args.aligned, only_center_face=args.only_center_face, paste_back=True)
# save faces
for idx, (cropped_face, restored_face) in enumerate(zip(cropped_faces, restored_faces)):
# save cropped face
- save_crop_path = os.path.join(args.save_root, 'cropped_faces', f'{basename}_{idx:02d}.png')
+ save_crop_path = os.path.join(args.output, 'cropped_faces', f'{basename}_{idx:02d}.png')
imwrite(cropped_face, save_crop_path)
# save restored face
if args.suffix is not None:
save_face_name = f'{basename}_{idx:02d}_{args.suffix}.png'
else:
save_face_name = f'{basename}_{idx:02d}.png'
- save_restore_path = os.path.join(args.save_root, 'restored_faces', save_face_name)
+ save_restore_path = os.path.join(args.output, 'restored_faces', save_face_name)
imwrite(restored_face, save_restore_path)
# save comparison image
cmp_img = np.concatenate((cropped_face, restored_face), axis=1)
- imwrite(cmp_img, os.path.join(args.save_root, 'cmp', f'{basename}_{idx:02d}.png'))
+ imwrite(cmp_img, os.path.join(args.output, 'cmp', f'{basename}_{idx:02d}.png'))
# save restored img
if restored_img is not None:
@@ -103,13 +143,12 @@ def main():
extension = args.ext
if args.suffix is not None:
- save_restore_path = os.path.join(args.save_root, 'restored_imgs',
- f'{basename}_{args.suffix}.{extension}')
+ save_restore_path = os.path.join(args.output, 'restored_imgs', f'{basename}_{args.suffix}.{extension}')
else:
- save_restore_path = os.path.join(args.save_root, 'restored_imgs', f'{basename}.{extension}')
+ save_restore_path = os.path.join(args.output, 'restored_imgs', f'{basename}.{extension}')
imwrite(restored_img, save_restore_path)
- print(f'Results are in the [{args.save_root}] folder.')
+ print(f'Results are in the [{args.output}] folder.')
if __name__ == '__main__':
diff --git a/scripts/convert_gfpganv_to_clean.py b/scripts/convert_gfpganv_to_clean.py
new file mode 100644
index 00000000..8fdccb61
--- /dev/null
+++ b/scripts/convert_gfpganv_to_clean.py
@@ -0,0 +1,164 @@
+import argparse
+import math
+import torch
+
+from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean
+
+
+def modify_checkpoint(checkpoint_bilinear, checkpoint_clean):
+ for ori_k, ori_v in checkpoint_bilinear.items():
+ if 'stylegan_decoder' in ori_k:
+ if 'style_mlp' in ori_k: # style_mlp_layers
+ lr_mul = 0.01
+ prefix, name, idx, var = ori_k.split('.')
+ idx = (int(idx) * 2) - 1
+ crt_k = f'{prefix}.{name}.{idx}.{var}'
+ if var == 'weight':
+ _, c_in = ori_v.size()
+ scale = (1 / math.sqrt(c_in)) * lr_mul
+ crt_v = ori_v * scale * 2**0.5
+ else:
+ crt_v = ori_v * lr_mul * 2**0.5
+ checkpoint_clean[crt_k] = crt_v
+ elif 'modulation' in ori_k: # modulation in StyleConv
+ lr_mul = 1
+ crt_k = ori_k
+ var = ori_k.split('.')[-1]
+ if var == 'weight':
+ _, c_in = ori_v.size()
+ scale = (1 / math.sqrt(c_in)) * lr_mul
+ crt_v = ori_v * scale
+ else:
+ crt_v = ori_v * lr_mul
+ checkpoint_clean[crt_k] = crt_v
+ elif 'style_conv' in ori_k:
+ # StyleConv in style_conv1 and style_convs
+ if 'activate' in ori_k: # FusedLeakyReLU
+ # eg. style_conv1.activate.bias
+ # eg. style_convs.13.activate.bias
+ split_rlt = ori_k.split('.')
+ if len(split_rlt) == 4:
+ prefix, name, _, var = split_rlt
+ crt_k = f'{prefix}.{name}.{var}'
+ elif len(split_rlt) == 5:
+ prefix, name, idx, _, var = split_rlt
+ crt_k = f'{prefix}.{name}.{idx}.{var}'
+ crt_v = ori_v * 2**0.5 # 2**0.5 used in FusedLeakyReLU
+ c = crt_v.size(0)
+ checkpoint_clean[crt_k] = crt_v.view(1, c, 1, 1)
+ elif 'modulated_conv' in ori_k:
+ # eg. style_conv1.modulated_conv.weight
+ # eg. style_convs.13.modulated_conv.weight
+ _, c_out, c_in, k1, k2 = ori_v.size()
+ scale = 1 / math.sqrt(c_in * k1 * k2)
+ crt_k = ori_k
+ checkpoint_clean[crt_k] = ori_v * scale
+ elif 'weight' in ori_k:
+ crt_k = ori_k
+ checkpoint_clean[crt_k] = ori_v * 2**0.5
+ elif 'to_rgb' in ori_k: # StyleConv in to_rgb1 and to_rgbs
+ if 'modulated_conv' in ori_k:
+ # eg. to_rgb1.modulated_conv.weight
+ # eg. to_rgbs.5.modulated_conv.weight
+ _, c_out, c_in, k1, k2 = ori_v.size()
+ scale = 1 / math.sqrt(c_in * k1 * k2)
+ crt_k = ori_k
+ checkpoint_clean[crt_k] = ori_v * scale
+ else:
+ crt_k = ori_k
+ checkpoint_clean[crt_k] = ori_v
+ else:
+ crt_k = ori_k
+ checkpoint_clean[crt_k] = ori_v
+ # end of 'stylegan_decoder'
+ elif 'conv_body_first' in ori_k or 'final_conv' in ori_k:
+ # key name
+ name, _, var = ori_k.split('.')
+ crt_k = f'{name}.{var}'
+ # weight and bias
+ if var == 'weight':
+ c_out, c_in, k1, k2 = ori_v.size()
+ scale = 1 / math.sqrt(c_in * k1 * k2)
+ checkpoint_clean[crt_k] = ori_v * scale * 2**0.5
+ else:
+ checkpoint_clean[crt_k] = ori_v * 2**0.5
+ elif 'conv_body' in ori_k:
+ if 'conv_body_up' in ori_k:
+ ori_k = ori_k.replace('conv2.weight', 'conv2.1.weight')
+ ori_k = ori_k.replace('skip.weight', 'skip.1.weight')
+ name1, idx1, name2, _, var = ori_k.split('.')
+ crt_k = f'{name1}.{idx1}.{name2}.{var}'
+ if name2 == 'skip':
+ c_out, c_in, k1, k2 = ori_v.size()
+ scale = 1 / math.sqrt(c_in * k1 * k2)
+ checkpoint_clean[crt_k] = ori_v * scale / 2**0.5
+ else:
+ if var == 'weight':
+ c_out, c_in, k1, k2 = ori_v.size()
+ scale = 1 / math.sqrt(c_in * k1 * k2)
+ checkpoint_clean[crt_k] = ori_v * scale
+ else:
+ checkpoint_clean[crt_k] = ori_v
+ if 'conv1' in ori_k:
+ checkpoint_clean[crt_k] *= 2**0.5
+ elif 'toRGB' in ori_k:
+ crt_k = ori_k
+ if 'weight' in ori_k:
+ c_out, c_in, k1, k2 = ori_v.size()
+ scale = 1 / math.sqrt(c_in * k1 * k2)
+ checkpoint_clean[crt_k] = ori_v * scale
+ else:
+ checkpoint_clean[crt_k] = ori_v
+ elif 'final_linear' in ori_k:
+ crt_k = ori_k
+ if 'weight' in ori_k:
+ _, c_in = ori_v.size()
+ scale = 1 / math.sqrt(c_in)
+ checkpoint_clean[crt_k] = ori_v * scale
+ else:
+ checkpoint_clean[crt_k] = ori_v
+ elif 'condition' in ori_k:
+ crt_k = ori_k
+ if '0.weight' in ori_k:
+ c_out, c_in, k1, k2 = ori_v.size()
+ scale = 1 / math.sqrt(c_in * k1 * k2)
+ checkpoint_clean[crt_k] = ori_v * scale * 2**0.5
+ elif '0.bias' in ori_k:
+ checkpoint_clean[crt_k] = ori_v * 2**0.5
+ elif '2.weight' in ori_k:
+ c_out, c_in, k1, k2 = ori_v.size()
+ scale = 1 / math.sqrt(c_in * k1 * k2)
+ checkpoint_clean[crt_k] = ori_v * scale
+ elif '2.bias' in ori_k:
+ checkpoint_clean[crt_k] = ori_v
+
+ return checkpoint_clean
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--ori_path', type=str, help='Path to the original model')
+ parser.add_argument('--narrow', type=float, default=1)
+ parser.add_argument('--channel_multiplier', type=float, default=2)
+ parser.add_argument('--save_path', type=str)
+ args = parser.parse_args()
+
+ ori_ckpt = torch.load(args.ori_path)['params_ema']
+
+ net = GFPGANv1Clean(
+ 512,
+ num_style_feat=512,
+ channel_multiplier=args.channel_multiplier,
+ decoder_load_path=None,
+ fix_decoder=False,
+ # for stylegan decoder
+ num_mlp=8,
+ input_is_latent=True,
+ different_w=True,
+ narrow=args.narrow,
+ sft_half=True)
+ crt_ckpt = net.state_dict()
+
+ crt_ckpt = modify_checkpoint(ori_ckpt, crt_ckpt)
+ print(f'Save to {args.save_path}.')
+ torch.save(dict(params_ema=crt_ckpt), args.save_path, _use_new_zipfile_serialization=False)