Modularize instruct_pix2pix SD inferencing during and after training in examples (#7603)

* Modularize instruct_pix2pix code

* quality check

* quality check

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
satani99 and sayakpaul authored Apr 10, 2024
1 parent a402431 commit 37e9d69
Showing 1 changed file with 63 additions and 70 deletions.
133 changes: 63 additions & 70 deletions examples/instruct_pix2pix/train_instruct_pix2pix.py
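
The diff below folds the two inline validation-inference blocks (mid-training and end-of-training) into a single log_validation helper. As a rough sketch of the resulting call pattern — not part of the diff itself — using names (args, unet, text_encoder, vae, weight_dtype, accelerator, generator) taken from the training script:

# Build an inference pipeline from the (unwrapped) trained modules ...
pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
    args.pretrained_model_name_or_path,
    text_encoder=unwrap_model(text_encoder),
    vae=unwrap_model(vae),
    unet=unwrap_model(unet),
    torch_dtype=weight_dtype,
)

# ... then run args.num_validation_images edits of args.val_image_url with
# args.validation_prompt and log the results as a wandb table.
log_validation(pipeline, args, accelerator, generator)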
@@ -53,6 +53,9 @@
from diffusers.utils.torch_utils import is_compiled_module


if is_wandb_available():
    import wandb

# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.28.0.dev0")

@@ -64,6 +67,48 @@
WANDB_TABLE_COL_NAMES = ["original_image", "edited_image", "edit_prompt"]


def log_validation(
    pipeline,
    args,
    accelerator,
    generator,
):
    logger.info(
        f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
        f" {args.validation_prompt}."
    )
    pipeline = pipeline.to(accelerator.device)
    pipeline.set_progress_bar_config(disable=True)

    # run inference
    original_image = download_image(args.val_image_url)
    edited_images = []
    if torch.backends.mps.is_available():
        autocast_ctx = nullcontext()
    else:
        autocast_ctx = torch.autocast(accelerator.device.type)

    with autocast_ctx:
        for _ in range(args.num_validation_images):
            edited_images.append(
                pipeline(
                    args.validation_prompt,
                    image=original_image,
                    num_inference_steps=20,
                    image_guidance_scale=1.5,
                    guidance_scale=7,
                    generator=generator,
                ).images[0]
            )

    for tracker in accelerator.trackers:
        if tracker.name == "wandb":
            wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
            for edited_image in edited_images:
                wandb_table.add_data(wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt)
            tracker.log({"validation": wandb_table})


def parse_args():
    parser = argparse.ArgumentParser(description="Simple example of a training script for InstructPix2Pix.")
    parser.add_argument(
@@ -411,11 +456,6 @@ def main():

    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)

    if args.report_to == "wandb":
        if not is_wandb_available():
            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
        import wandb

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -517,7 +557,8 @@ def save_model_hook(models, weights, output_dir):
                model.save_pretrained(os.path.join(output_dir, "unet"))

                # make sure to pop weight so that corresponding model is not saved again
                weights.pop()
                if weights:
                    weights.pop()

        def load_model_hook(models, input_dir):
            if args.use_ema:
@@ -923,11 +964,6 @@ def collate_fn(examples):
                and (args.validation_prompt is not None)
                and (epoch % args.validation_epochs == 0)
            ):
                logger.info(
                    f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
                    f" {args.validation_prompt}."
                )
                # create pipeline
                if args.use_ema:
                    # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
                    ema_unet.store(unet.parameters())
@@ -942,38 +978,14 @@ def collate_fn(examples):
                    variant=args.variant,
                    torch_dtype=weight_dtype,
                )
                pipeline = pipeline.to(accelerator.device)
                pipeline.set_progress_bar_config(disable=True)

                # run inference
                original_image = download_image(args.val_image_url)
                edited_images = []
                if torch.backends.mps.is_available():
                    autocast_ctx = nullcontext()
                else:
                    autocast_ctx = torch.autocast(accelerator.device.type)

                with autocast_ctx:
                    for _ in range(args.num_validation_images):
                        edited_images.append(
                            pipeline(
                                args.validation_prompt,
                                image=original_image,
                                num_inference_steps=20,
                                image_guidance_scale=1.5,
                                guidance_scale=7,
                                generator=generator,
                            ).images[0]
                        )

                for tracker in accelerator.trackers:
                    if tracker.name == "wandb":
                        wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
                        for edited_image in edited_images:
                            wandb_table.add_data(
                                wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt
                            )
                        tracker.log({"validation": wandb_table})

                log_validation(
                    pipeline,
                    args,
                    accelerator,
                    generator,
                )

                if args.use_ema:
                    # Switch back to the original UNet parameters.
                    ema_unet.restore(unet.parameters())
@@ -984,15 +996,14 @@ def collate_fn(examples):
    # Create the pipeline using the trained modules and save it.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        unet = unwrap_model(unet)
        if args.use_ema:
            ema_unet.copy_to(unet.parameters())

        pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
            args.pretrained_model_name_or_path,
            text_encoder=unwrap_model(text_encoder),
            vae=unwrap_model(vae),
            unet=unet,
            unet=unwrap_model(unet),
            revision=args.revision,
            variant=args.variant,
        )
@@ -1006,31 +1017,13 @@ def collate_fn(examples):
                ignore_patterns=["step_*", "epoch_*"],
            )

        if args.validation_prompt is not None:
            edited_images = []
            pipeline = pipeline.to(accelerator.device)
            with torch.autocast(str(accelerator.device).replace(":0", "")):
                for _ in range(args.num_validation_images):
                    edited_images.append(
                        pipeline(
                            args.validation_prompt,
                            image=original_image,
                            num_inference_steps=20,
                            image_guidance_scale=1.5,
                            guidance_scale=7,
                            generator=generator,
                        ).images[0]
                    )

            for tracker in accelerator.trackers:
                if tracker.name == "wandb":
                    wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
                    for edited_image in edited_images:
                        wandb_table.add_data(
                            wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt
                        )
                    tracker.log({"test": wandb_table})

        if (args.val_image_url is not None) and (args.validation_prompt is not None):
            log_validation(
                pipeline,
                args,
                accelerator,
                generator,
            )
    accelerator.end_training()

