From 5956b68a6927126daffc2c5a6d1a9a189defe288 Mon Sep 17 00:00:00 2001 From: Leo Jiang <74156916+leisuzz@users.noreply.github.com> Date: Mon, 14 Oct 2024 10:09:33 -0600 Subject: [PATCH] Improve the performance and suitable for NPU computing (#9642) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Improve the performance and suitable for NPU * Improve the performance and suitable for NPU computing * Improve the performance and suitable for NPU * Improve the performance and suitable for NPU * Improve the performance and suitable for NPU * Improve the performance and suitable for NPU --------- Co-authored-by: 蒋硕 Co-authored-by: Sayak Paul --- .../text_to_image/train_text_to_image_sdxl.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py index 2ca511c857ae..bcf0fa9eb0ac 100644 --- a/examples/text_to_image/train_text_to_image_sdxl.py +++ b/examples/text_to_image/train_text_to_image_sdxl.py @@ -59,6 +59,8 @@ logger = get_logger(__name__) if is_torch_npu_available(): + import torch_npu + torch.npu.config.allow_internal_format = False DATASET_NAME_MAPPING = { @@ -540,6 +542,9 @@ def compute_vae_encodings(batch, vae): with torch.no_grad(): model_input = vae.encode(pixel_values).latent_dist.sample() model_input = model_input * vae.config.scaling_factor + + # There might be a slight performance improvement + # by changing model_input.cpu() to accelerator.gather(model_input) return {"model_input": model_input.cpu()} @@ -935,7 +940,10 @@ def preprocess_train(examples): del compute_vae_encodings_fn, compute_embeddings_fn, text_encoder_one, text_encoder_two del text_encoders, tokenizers, vae gc.collect() - torch.cuda.empty_cache() + if is_torch_npu_available(): + torch_npu.npu.empty_cache() + elif torch.cuda.is_available(): + torch.cuda.empty_cache() def collate_fn(examples): model_input = 
torch.stack([torch.tensor(example["model_input"]) for example in examples]) @@ -1091,8 +1099,7 @@ def compute_time_ids(original_size, crops_coords_top_left): # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids target_size = (args.resolution, args.resolution) add_time_ids = list(original_size + crops_coords_top_left + target_size) - add_time_ids = torch.tensor([add_time_ids]) - add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype) + add_time_ids = torch.tensor([add_time_ids], device=accelerator.device, dtype=weight_dtype) return add_time_ids add_time_ids = torch.cat( @@ -1261,7 +1268,10 @@ def compute_time_ids(original_size, crops_coords_top_left): ) del pipeline - torch.cuda.empty_cache() + if is_torch_npu_available(): + torch_npu.npu.empty_cache() + elif torch.cuda.is_available(): + torch.cuda.empty_cache() if args.use_ema: # Switch back to the original UNet parameters.