Skip to content

Commit

Permalink
Improve the performance and suitable for NPU computing (#9642)
Browse files Browse the repository at this point in the history
* Improve the performance and suitable for NPU

* Improve the performance and suitable for NPU computing

* Improve the performance and suitable for NPU

* Improve the performance and suitable for NPU

* Improve the performance and suitable for NPU

* Improve the performance and suitable for NPU

---------

Co-authored-by: 蒋硕 <jiangshuo9@h-partners.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
  • Loading branch information
3 people authored Oct 14, 2024
1 parent 8d81564 commit 5956b68
Showing 1 changed file with 14 additions and 4 deletions.
18 changes: 14 additions & 4 deletions examples/text_to_image/train_text_to_image_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@

logger = get_logger(__name__)
if is_torch_npu_available():
import torch_npu

torch.npu.config.allow_internal_format = False

DATASET_NAME_MAPPING = {
Expand Down Expand Up @@ -540,6 +542,9 @@ def compute_vae_encodings(batch, vae):
with torch.no_grad():
model_input = vae.encode(pixel_values).latent_dist.sample()
model_input = model_input * vae.config.scaling_factor

# There might have slightly performance improvement
# by changing model_input.cpu() to accelerator.gather(model_input)
return {"model_input": model_input.cpu()}


Expand Down Expand Up @@ -935,7 +940,10 @@ def preprocess_train(examples):
del compute_vae_encodings_fn, compute_embeddings_fn, text_encoder_one, text_encoder_two
del text_encoders, tokenizers, vae
gc.collect()
torch.cuda.empty_cache()
if is_torch_npu_available():
torch_npu.npu.empty_cache()
elif torch.cuda.is_available():
torch.cuda.empty_cache()

def collate_fn(examples):
model_input = torch.stack([torch.tensor(example["model_input"]) for example in examples])
Expand Down Expand Up @@ -1091,8 +1099,7 @@ def compute_time_ids(original_size, crops_coords_top_left):
# Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
target_size = (args.resolution, args.resolution)
add_time_ids = list(original_size + crops_coords_top_left + target_size)
add_time_ids = torch.tensor([add_time_ids])
add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
add_time_ids = torch.tensor([add_time_ids], device=accelerator.device, dtype=weight_dtype)
return add_time_ids

add_time_ids = torch.cat(
Expand Down Expand Up @@ -1261,7 +1268,10 @@ def compute_time_ids(original_size, crops_coords_top_left):
)

del pipeline
torch.cuda.empty_cache()
if is_torch_npu_available():
torch_npu.npu.empty_cache()
elif torch.cuda.is_available():
torch.cuda.empty_cache()

if args.use_ema:
# Switch back to the original UNet parameters.
Expand Down

0 comments on commit 5956b68

Please sign in to comment.