Skip to content

Commit

Permalink
PixArt-Sigma Implementation (#7654)
Browse files Browse the repository at this point in the history
* support PixArt-DMD

---------

Co-authored-by: jschen <chenjunsong4@h-partners.com>
Co-authored-by: badayvedat <badayvedat@gmail.com>
Co-authored-by: Vedat Baday <54285744+badayvedat@users.noreply.github.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: yiyixuxu <yixu310@gmail,com>
  • Loading branch information
7 people authored Apr 24, 2024
1 parent 9ef43f3 commit 39215aa
Show file tree
Hide file tree
Showing 12 changed files with 1,640 additions and 43 deletions.
223 changes: 223 additions & 0 deletions scripts/convert_pixart_sigma_to_diffusers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
import argparse
import os

import torch
from transformers import T5EncoderModel, T5Tokenizer

from diffusers import AutoencoderKL, DPMSolverMultistepScheduler, PixArtSigmaPipeline, Transformer2DModel


ckpt_id = "PixArt-alpha"
# https://github.com/PixArt-alpha/PixArt-sigma/blob/dd087141864e30ec44f12cb7448dd654be065e88/scripts/inference.py#L158
interpolation_scale = {256: 0.5, 512: 1, 1024: 2, 2048: 4}


def main(args):
all_state_dict = torch.load(args.orig_ckpt_path)
state_dict = all_state_dict.pop("state_dict")
converted_state_dict = {}

# Patch embeddings.
converted_state_dict["pos_embed.proj.weight"] = state_dict.pop("x_embedder.proj.weight")
converted_state_dict["pos_embed.proj.bias"] = state_dict.pop("x_embedder.proj.bias")

# Caption projection.
converted_state_dict["caption_projection.linear_1.weight"] = state_dict.pop("y_embedder.y_proj.fc1.weight")
converted_state_dict["caption_projection.linear_1.bias"] = state_dict.pop("y_embedder.y_proj.fc1.bias")
converted_state_dict["caption_projection.linear_2.weight"] = state_dict.pop("y_embedder.y_proj.fc2.weight")
converted_state_dict["caption_projection.linear_2.bias"] = state_dict.pop("y_embedder.y_proj.fc2.bias")

# AdaLN-single LN
converted_state_dict["adaln_single.emb.timestep_embedder.linear_1.weight"] = state_dict.pop(
"t_embedder.mlp.0.weight"
)
converted_state_dict["adaln_single.emb.timestep_embedder.linear_1.bias"] = state_dict.pop("t_embedder.mlp.0.bias")
converted_state_dict["adaln_single.emb.timestep_embedder.linear_2.weight"] = state_dict.pop(
"t_embedder.mlp.2.weight"
)
converted_state_dict["adaln_single.emb.timestep_embedder.linear_2.bias"] = state_dict.pop("t_embedder.mlp.2.bias")

if args.micro_condition:
# Resolution.
converted_state_dict["adaln_single.emb.resolution_embedder.linear_1.weight"] = state_dict.pop(
"csize_embedder.mlp.0.weight"
)
converted_state_dict["adaln_single.emb.resolution_embedder.linear_1.bias"] = state_dict.pop(
"csize_embedder.mlp.0.bias"
)
converted_state_dict["adaln_single.emb.resolution_embedder.linear_2.weight"] = state_dict.pop(
"csize_embedder.mlp.2.weight"
)
converted_state_dict["adaln_single.emb.resolution_embedder.linear_2.bias"] = state_dict.pop(
"csize_embedder.mlp.2.bias"
)
# Aspect ratio.
converted_state_dict["adaln_single.emb.aspect_ratio_embedder.linear_1.weight"] = state_dict.pop(
"ar_embedder.mlp.0.weight"
)
converted_state_dict["adaln_single.emb.aspect_ratio_embedder.linear_1.bias"] = state_dict.pop(
"ar_embedder.mlp.0.bias"
)
converted_state_dict["adaln_single.emb.aspect_ratio_embedder.linear_2.weight"] = state_dict.pop(
"ar_embedder.mlp.2.weight"
)
converted_state_dict["adaln_single.emb.aspect_ratio_embedder.linear_2.bias"] = state_dict.pop(
"ar_embedder.mlp.2.bias"
)
# Shared norm.
converted_state_dict["adaln_single.linear.weight"] = state_dict.pop("t_block.1.weight")
converted_state_dict["adaln_single.linear.bias"] = state_dict.pop("t_block.1.bias")

for depth in range(28):
# Transformer blocks.
converted_state_dict[f"transformer_blocks.{depth}.scale_shift_table"] = state_dict.pop(
f"blocks.{depth}.scale_shift_table"
)
# Attention is all you need 🤘

# Self attention.
q, k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.attn.qkv.weight"), 3, dim=0)
q_bias, k_bias, v_bias = torch.chunk(state_dict.pop(f"blocks.{depth}.attn.qkv.bias"), 3, dim=0)
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.bias"] = q_bias
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.bias"] = k_bias
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.bias"] = v_bias
# Projection.
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = state_dict.pop(
f"blocks.{depth}.attn.proj.weight"
)
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.bias"] = state_dict.pop(
f"blocks.{depth}.attn.proj.bias"
)
if args.qk_norm:
converted_state_dict[f"transformer_blocks.{depth}.attn1.q_norm.weight"] = state_dict.pop(
f"blocks.{depth}.attn.q_norm.weight"
)
converted_state_dict[f"transformer_blocks.{depth}.attn1.q_norm.bias"] = state_dict.pop(
f"blocks.{depth}.attn.q_norm.bias"
)
converted_state_dict[f"transformer_blocks.{depth}.attn1.k_norm.weight"] = state_dict.pop(
f"blocks.{depth}.attn.k_norm.weight"
)
converted_state_dict[f"transformer_blocks.{depth}.attn1.k_norm.bias"] = state_dict.pop(
f"blocks.{depth}.attn.k_norm.bias"
)

# Feed-forward.
converted_state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.weight"] = state_dict.pop(
f"blocks.{depth}.mlp.fc1.weight"
)
converted_state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.bias"] = state_dict.pop(
f"blocks.{depth}.mlp.fc1.bias"
)
converted_state_dict[f"transformer_blocks.{depth}.ff.net.2.weight"] = state_dict.pop(
f"blocks.{depth}.mlp.fc2.weight"
)
converted_state_dict[f"transformer_blocks.{depth}.ff.net.2.bias"] = state_dict.pop(
f"blocks.{depth}.mlp.fc2.bias"
)

# Cross-attention.
q = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.weight")
q_bias = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.bias")
k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.weight"), 2, dim=0)
k_bias, v_bias = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.bias"), 2, dim=0)

converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.weight"] = q
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.bias"] = q_bias
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.weight"] = k
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.bias"] = k_bias
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.weight"] = v
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.bias"] = v_bias

converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.weight"] = state_dict.pop(
f"blocks.{depth}.cross_attn.proj.weight"
)
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.bias"] = state_dict.pop(
f"blocks.{depth}.cross_attn.proj.bias"
)

# Final block.
converted_state_dict["proj_out.weight"] = state_dict.pop("final_layer.linear.weight")
converted_state_dict["proj_out.bias"] = state_dict.pop("final_layer.linear.bias")
converted_state_dict["scale_shift_table"] = state_dict.pop("final_layer.scale_shift_table")

# PixArt XL/2
transformer = Transformer2DModel(
sample_size=args.image_size // 8,
num_layers=28,
attention_head_dim=72,
in_channels=4,
out_channels=8,
patch_size=2,
attention_bias=True,
num_attention_heads=16,
cross_attention_dim=1152,
activation_fn="gelu-approximate",
num_embeds_ada_norm=1000,
norm_type="ada_norm_single",
norm_elementwise_affine=False,
norm_eps=1e-6,
caption_channels=4096,
interpolation_scale=interpolation_scale[args.image_size],
use_additional_conditions=args.micro_condition,
)
transformer.load_state_dict(converted_state_dict, strict=True)

assert transformer.pos_embed.pos_embed is not None
try:
state_dict.pop("y_embedder.y_embedding")
state_dict.pop("pos_embed")
except Exception as e:
print(f"Skipping {str(e)}")
pass
assert len(state_dict) == 0, f"State dict is not empty, {state_dict.keys()}"

num_model_params = sum(p.numel() for p in transformer.parameters())
print(f"Total number of transformer parameters: {num_model_params}")

if args.only_transformer:
transformer.save_pretrained(os.path.join(args.dump_path, "transformer"))
else:
# pixart-Sigma vae link: https://huggingface.co/PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers/tree/main/vae
vae = AutoencoderKL.from_pretrained(f"{ckpt_id}/pixart_sigma_sdxlvae_T5_diffusers", subfolder="vae")

scheduler = DPMSolverMultistepScheduler()

tokenizer = T5Tokenizer.from_pretrained(f"{ckpt_id}/pixart_sigma_sdxlvae_T5_diffusers", subfolder="tokenizer")
text_encoder = T5EncoderModel.from_pretrained(
f"{ckpt_id}/pixart_sigma_sdxlvae_T5_diffusers", subfolder="text_encoder"
)

pipeline = PixArtSigmaPipeline(
tokenizer=tokenizer, text_encoder=text_encoder, transformer=transformer, vae=vae, scheduler=scheduler
)

pipeline.save_pretrained(args.dump_path)


if __name__ == "__main__":
parser = argparse.ArgumentParser()

parser.add_argument(
"--micro_condition", action="store_true", help="If use Micro-condition in PixArtMS structure during training."
)
parser.add_argument("--qk_norm", action="store_true", help="If use qk norm during training.")
parser.add_argument(
"--orig_ckpt_path", default=None, type=str, required=False, help="Path to the checkpoint to convert."
)
parser.add_argument(
"--image_size",
default=1024,
type=int,
choices=[256, 512, 1024, 2048],
required=False,
help="Image size of pretrained model, 256, 512, 1024, or 2048.",
)
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
parser.add_argument("--only_transformer", default=True, type=bool, required=True)

args = parser.parse_args()
main(args)
2 changes: 2 additions & 0 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@
"PaintByExamplePipeline",
"PIAPipeline",
"PixArtAlphaPipeline",
"PixArtSigmaPipeline",
"SemanticStableDiffusionPipeline",
"ShapEImg2ImgPipeline",
"ShapEPipeline",
Expand Down Expand Up @@ -637,6 +638,7 @@
PaintByExamplePipeline,
PIAPipeline,
PixArtAlphaPipeline,
PixArtSigmaPipeline,
SemanticStableDiffusionPipeline,
ShapEImg2ImgPipeline,
ShapEPipeline,
Expand Down
74 changes: 74 additions & 0 deletions src/diffusers/image_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -994,3 +994,77 @@ def downsample(mask: torch.FloatTensor, batch_size: int, num_queries: int, value
)

return mask_downsample


class PixArtImageProcessor(VaeImageProcessor):
"""
Image processor for PixArt image resize and crop.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
`height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method.
vae_scale_factor (`int`, *optional*, defaults to `8`):
VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
resample (`str`, *optional*, defaults to `lanczos`):
Resampling filter to use when resizing the image.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether to normalize the image to [-1,1].
do_binarize (`bool`, *optional*, defaults to `False`):
Whether to binarize the image to 0/1.
do_convert_rgb (`bool`, *optional*, defaults to be `False`):
Whether to convert the images to RGB format.
do_convert_grayscale (`bool`, *optional*, defaults to be `False`):
Whether to convert the images to grayscale format.
"""

@register_to_config
def __init__(
self,
do_resize: bool = True,
vae_scale_factor: int = 8,
resample: str = "lanczos",
do_normalize: bool = True,
do_binarize: bool = False,
do_convert_grayscale: bool = False,
):
super().__init__(
do_resize=do_resize,
vae_scale_factor=vae_scale_factor,
resample=resample,
do_normalize=do_normalize,
do_binarize=do_binarize,
do_convert_grayscale=do_convert_grayscale,
)

@staticmethod
def classify_height_width_bin(height: int, width: int, ratios: dict) -> Tuple[int, int]:
"""Returns binned height and width."""
ar = float(height / width)
closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar))
default_hw = ratios[closest_ratio]
return int(default_hw[0]), int(default_hw[1])

@staticmethod
def resize_and_crop_tensor(samples: torch.Tensor, new_width: int, new_height: int) -> torch.Tensor:
orig_height, orig_width = samples.shape[2], samples.shape[3]

# Check if resizing is needed
if orig_height != new_height or orig_width != new_width:
ratio = max(new_height / orig_height, new_width / orig_width)
resized_width = int(orig_width * ratio)
resized_height = int(orig_height * ratio)

# Resize
samples = F.interpolate(
samples, size=(resized_height, resized_width), mode="bilinear", align_corners=False
)

# Center Crop
start_x = (resized_width - new_width) // 2
end_x = start_x + new_width
start_y = (resized_height - new_height) // 2
end_y = start_y + new_height
samples = samples[:, :, start_y:end_y, start_x:end_x]

return samples
9 changes: 7 additions & 2 deletions src/diffusers/models/transformers/transformer_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def __init__(
attention_type: str = "default",
caption_channels: int = None,
interpolation_scale: float = None,
use_additional_conditions: Optional[bool] = None,
):
super().__init__()

Expand All @@ -124,6 +125,12 @@ def __init__(
self.in_channels = in_channels
self.out_channels = in_channels if out_channels is None else out_channels
self.gradient_checkpointing = False
if use_additional_conditions is None:
if norm_type == "ada_norm_single" and sample_size == 128:
use_additional_conditions = True
else:
use_additional_conditions = False
self.use_additional_conditions = use_additional_conditions

# 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
# Define whether input is continuous or discrete depending on configuration
Expand Down Expand Up @@ -305,9 +312,7 @@ def _init_patched_inputs(self, norm_type):

# PixArt-Alpha blocks.
self.adaln_single = None
self.use_additional_conditions = False
if self.config.norm_type == "ada_norm_single":
self.use_additional_conditions = self.config.sample_size == 128
# TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
# additional conditions until we find better name
self.adaln_single = AdaLayerNormSingle(
Expand Down
4 changes: 2 additions & 2 deletions src/diffusers/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@
_import_structure["musicldm"] = ["MusicLDMPipeline"]
_import_structure["paint_by_example"] = ["PaintByExamplePipeline"]
_import_structure["pia"] = ["PIAPipeline"]
_import_structure["pixart_alpha"] = ["PixArtAlphaPipeline"]
_import_structure["pixart_alpha"] = ["PixArtAlphaPipeline", "PixArtSigmaPipeline"]
_import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"]
_import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"]
_import_structure["stable_cascade"] = [
Expand Down Expand Up @@ -450,7 +450,7 @@
from .musicldm import MusicLDMPipeline
from .paint_by_example import PaintByExamplePipeline
from .pia import PIAPipeline
from .pixart_alpha import PixArtAlphaPipeline
from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
from .stable_cascade import (
Expand Down
9 changes: 8 additions & 1 deletion src/diffusers/pipelines/pixart_alpha/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_pixart_alpha"] = ["PixArtAlphaPipeline"]
_import_structure["pipeline_pixart_sigma"] = ["PixArtSigmaPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
Expand All @@ -32,7 +33,13 @@
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_pixart_alpha import PixArtAlphaPipeline
from .pipeline_pixart_alpha import (
ASPECT_RATIO_256_BIN,
ASPECT_RATIO_512_BIN,
ASPECT_RATIO_1024_BIN,
PixArtAlphaPipeline,
)
from .pipeline_pixart_sigma import ASPECT_RATIO_2048_BIN, PixArtSigmaPipeline

else:
import sys
Expand Down
Loading

0 comments on commit 39215aa

Please sign in to comment.