
Add Ascend NPU support for SDXL fine-tuning and fix the model saving bug when using DeepSpeed. #7816

Merged
merged 6 commits into huggingface:main on May 3, 2024

Conversation

HelloWorldBeginner
Contributor

@HelloWorldBeginner commented Apr 29, 2024

What does this PR do?

Added support for SDXL fine-tuning on Ascend NPU and fixed a bug that caused a hang when saving models with the DeepSpeed distributed framework. DeepSpeed requires saving weights on every device; saving weights only on the main process causes the other processes to hang.
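(As an illustration of the fix described above, the sketch below shows the kind of saving guard involved. It is not necessarily the exact diff in this PR; it uses Accelerate's public DistributedType and Accelerator.is_main_process, and the helper name should_call_save is hypothetical.)

```python
# Hedged sketch of the checkpoint-saving guard described above. With DeepSpeed,
# saving is a collective operation across ranks, so gating it on
# `is_main_process` alone leaves the other ranks hanging.
from accelerate import Accelerator
from accelerate.utils import DistributedType


def should_call_save(accelerator: Accelerator) -> bool:
    if accelerator.distributed_type == DistributedType.DEEPSPEED:
        # Every process participates in DeepSpeed checkpointing.
        return True
    # Other backends: only the main process writes the checkpoint.
    return accelerator.is_main_process
```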


Before submitting

I fine-tuned SDXL on AscendNPU, and the results are good. I hope diffusers can support more devices.

[Screenshot: SDXL fine-tuning results on Ascend NPU]

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@sayakpaul requested a review from yiyixuxu on April 29, 2024 11:01
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@HelloWorldBeginner
Contributor Author

I found some errors in the checks; how can I fix them?

examples/controlnet/train_controlnet_sdxl.py:16:1: I001 [*] Import block is un-sorted or un-formatted
examples/text_to_image/train_text_to_image_lora_sdxl.py:18:1: I001 [*] Import block is un-sorted or un-formatted
src/diffusers/models/activations.py:16:1: I001 [*] Import block is un-sorted or un-formatted
src/diffusers/models/attention_processor.py:14:1: I001 [*] Import block is un-sorted or un-formatted

It's strange because I didn't modify the code here.

@sayakpaul
Member

You can do the following:

  • Create a fresh Python environment.
  • Run pip install -e ".[quality]" from the root of diffusers.
  • Run make style && make quality.

@HelloWorldBeginner
Contributor Author

I've already fixed the code formatting issues in the checks.

@yiyixuxu
Collaborator

@sayakpaul
I'm ok with this PR if you think it is needed :)

@sayakpaul
Member

Thanks, Yiyi.

I am alright with the PR because the number of changes is extremely minimal.

Comment on lines 1280 to 1291
if is_torch_npu_available() and query.dtype in (torch.float16, torch.bfloat16):
    hidden_states = torch_npu.npu_fusion_attention(
        query,
        key,
        value,
        attn.heads,
        input_layout="BNSD",
        pse=None,
        atten_mask=attention_mask,
        scale=1.0 / math.sqrt(query.shape[-1]),
        pre_tockens=65536,
        next_tockens=65536,
        keep_prob=1.0,
        sync=False,
        inner_precise=0,
    )[0]
Member

Hmm, so when Torch NPU is available it will default to using torch_npu.npu_fusion_attention, right? But our current library-wide behaviour is that when PyTorch 1.x is available we rely on AttnProcessor, and when PyTorch 2.x is available we rely on AttnProcessor2_0, which uses F.scaled_dot_product_attention(). These are the two default attention processors we choose between based on the available PyTorch version.

So, with that in mind, I find this slightly problematic because it moves away from that conceptual understanding. Folks who use the library already know that F.scaled_dot_product_attention() is used when they're on PyTorch 2.x unless stated otherwise. Therefore, I think it would be better to have an AttnProcessorNPU class and use that instead. In that class we could also do proper error handling, such as erroring out if query.dtype is not torch.float16 or torch.bfloat16.

I would like to ask @yiyixuxu and @DN6 for their opinions here too.
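(For context, here is a minimal sketch of what such an AttnProcessorNPU class could look like. This is an illustration rather than the exact class that was eventually merged: it reuses the npu_fusion_attention call from the diff above, the attn helpers to_q, to_k, to_v, to_out, and heads follow diffusers' Attention module, and the import paths are assumptions.)

```python
import math

import torch
from diffusers.utils import is_torch_npu_available  # import path assumed

if is_torch_npu_available():
    import torch_npu


class AttnProcessorNPU:
    """Sketch of an attention processor backed by the fused Ascend NPU kernel."""

    def __init__(self):
        if not is_torch_npu_available():
            raise ImportError("AttnProcessorNPU requires an Ascend NPU and the `torch_npu` package.")

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # Error out explicitly instead of silently falling back: the fused
        # kernel only supports half precision.
        if query.dtype not in (torch.float16, torch.bfloat16):
            raise ValueError("npu_fusion_attention only supports float16/bfloat16 inputs.")

        batch_size = hidden_states.shape[0]
        head_dim = query.shape[-1] // attn.heads

        # Reshape to the BNSD layout expected by the fused kernel:
        # (batch, num_heads, seq_len, head_dim).
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        hidden_states = torch_npu.npu_fusion_attention(
            query,
            key,
            value,
            attn.heads,
            input_layout="BNSD",
            pse=None,
            atten_mask=attention_mask,
            scale=1.0 / math.sqrt(head_dim),
            pre_tockens=65536,
            next_tockens=65536,
            keep_prob=1.0,
            sync=False,
            inner_precise=0,
        )[0]

        # Back to (batch, seq_len, num_heads * head_dim), then output projection.
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states
```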

Contributor Author

Alright, I understand your point. It would be a good idea to separate the torch_npu flash attention into its own module.

@HelloWorldBeginner
Contributor Author

HelloWorldBeginner commented Apr 30, 2024

I've separated the NPU flash attention into its own module and added a parameter to switch it on and off. I've tested it and it works; a usage sketch follows below.
[Screenshot: test run output on Ascend NPU]
@sayakpaul
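(A hypothetical usage sketch of the switch described above. The method name enable_npu_flash_attention and the checkpoint id are assumptions for illustration, not a confirmed API.)

```python
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
)

# Opt in to the fused NPU attention path; the default processors are left
# untouched unless this switch is flipped.
unet.enable_npu_flash_attention()  # assumed method name
```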

if is_torch_npu_available():
    hidden_states = self.proj(hidden_states, *args)
    return torch_npu.npu_geglu(hidden_states, dim=-1, approximate=1)[0]
else:
    hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
    return hidden_states * self.gelu(gate)
Member

Can we not use the existing self.gelu() when using NPU?

Contributor Author

Compared to self.gelu(), using torch_npu.npu_geglu can run faster and save memory on NPU.
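(To make the comparison concrete, here is an illustrative sketch, assuming torch_npu is installed on an Ascend NPU: npu_geglu folds the chunk, GELU, and multiply of the eager path into a single fused op, with approximate=1 presumably selecting the tanh GELU approximation.)

```python
import torch
import torch.nn.functional as F


def geglu_eager(projected: torch.Tensor) -> torch.Tensor:
    # Eager GEGLU: split the projection in two and gate one half with GELU.
    hidden_states, gate = projected.chunk(2, dim=-1)
    return hidden_states * F.gelu(gate)


def geglu_fused_npu(projected: torch.Tensor) -> torch.Tensor:
    # Fused Ascend kernel: chunk + GELU + multiply in one call; the first
    # element of the returned tuple is the activation output.
    import torch_npu  # only available on Ascend NPU builds
    return torch_npu.npu_geglu(projected, dim=-1, approximate=1)[0]
```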

Member

@sayakpaul left a comment

Thanks for working on this.

For me, the following would be nice to add before we merge:

@yiyixuxu could you review the changes introduced to the core modules of the library and comment?

Comment on lines 105 to 110
if is_torch_npu_available():
    hidden_states = self.proj(hidden_states, *args)
    return torch_npu.npu_geglu(hidden_states, dim=-1, approximate=1)[0]
else:
    hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
    return hidden_states * self.gelu(gate)
Collaborator

Suggested change
-if is_torch_npu_available():
-    hidden_states = self.proj(hidden_states, *args)
-    return torch_npu.npu_geglu(hidden_states, dim=-1, approximate=1)[0]
-else:
-    hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
-    return hidden_states * self.gelu(gate)
+hidden_states = self.proj(hidden_states)
+if is_torch_npu_available():
+    return torch_npu.npu_geglu(hidden_states, dim=-1, approximate=1)[0]
+else:
+    hidden_states, gate = hidden_states.chunk(2, dim=-1)
+    return hidden_states * self.gelu(gate)

@HelloWorldBeginner
Contributor Author

> Thanks for working on this.
>
> For me, the following would be nice to add before we merge:
>
> @yiyixuxu could you review the changes introduced to the core modules of the library and comment?

Sure, I'll add unit tests and documentation later.

@HelloWorldBeginner
Contributor Author

I've updated the code. @sayakpaul

Comment on lines +308 to +312
@unittest.skipIf(
    torch_device != "npu" or not is_torch_npu_available(),
    reason="torch npu flash attention is only available with NPU and `torch_npu` installed",
)
def test_set_torch_npu_flash_attn_processor_determinism(self):
Member

Test seems quite nice to me. Thanks for working on it!

@HelloWorldBeginner
Contributor Author

Hi @sayakpaul.
I noticed the PR is still open. Does the code still need review from others?

@sayakpaul requested a review from yiyixuxu on May 3, 2024 15:54
@sayakpaul
Member

Yes, it needs a review from our core maintainer @yiyixuxu.

Collaborator

@yiyixuxu left a comment

thanks!

@yiyixuxu merged commit 5823736 into huggingface:main on May 3, 2024
15 checks passed
XSE42 added a commit to XSE42/diffusers3d that referenced this pull request May 13, 2024
diffusers commit 5823736
    Add Ascend NPU support for SDXL fine-tuning and fix the model saving
    bug when using DeepSpeed.
    (huggingface/diffusers#7816)