Add Ascend NPU support for SDXL fine-tuning and fix the model saving bug when using DeepSpeed. #7816
Conversation
I found some errors in the checks, how can I fix them? It's strange because I didn't modify the code there.

You can do the following:
Force-pushed d2bf131 to 6536fc8.
I've already fixed the code formatting issues in the checks.

@sayakpaul

Thanks, Yiyi. I am alright with the PR because the number of changes is extremely minimal.
```python
if is_torch_npu_available() and query.dtype in (torch.float16, torch.bfloat16):
    # torch_npu fused attention kernel; note that `pre_tockens`/`next_tockens`
    # are spelled this way in the torch_npu API itself.
    hidden_states = torch_npu.npu_fusion_attention(
        query,
        key,
        value,
        attn.heads,
        input_layout="BNSD",  # batch, num_heads, sequence, head_dim
        pse=None,
        atten_mask=attention_mask,
        scale=1.0 / math.sqrt(query.shape[-1]),
        pre_tockens=65536,
        next_tockens=65536,
        keep_prob=1.0,  # no attention dropout
        sync=False,
        inner_precise=0,
    )[0]
```
Hmm, so when Torch NPU is available it will default to using `torch_npu.npu_fusion_attention`, right? But our current library-wide behaviour is that when PyTorch 1.x is available we rely on `AttnProcessor`, and when PyTorch 2.x is available we rely on `AttnProcessor2_0`, which uses `F.scaled_dot_product_attention()`. These are the two default attention processors we use based on the available PyTorch version.

So, with that in mind, I find this slightly problematic because we are moving away from that conceptual understanding. Folks who use the library already know that `F.scaled_dot_product_attention()` is used on PyTorch 2.x unless stated otherwise. Therefore, I think it might be better to have an `AttnProcessorNPU` class and use that instead. In that class we can also do proper error handling, such as erroring out if `query.dtype` is not `torch.float16` or `torch.bfloat16`.
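For illustration, a minimal sketch of the guard being proposed (class and method shape are assumptions here, not the final implementation):

```python
import torch


class AttnProcessorNPU:
    """Sketch of an NPU attention processor: only the dtype guard discussed
    above is shown; the actual fused-attention computation is elided."""

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
        query = attn.to_q(hidden_states)
        if query.dtype not in (torch.float16, torch.bfloat16):
            raise ValueError(
                "torch_npu fused attention only supports torch.float16 and "
                f"torch.bfloat16 inputs, but got {query.dtype}."
            )
        ...  # compute key/value and call torch_npu.npu_fusion_attention as above
```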
I would like to ask @yiyixuxu and @DN6 for their opinions here too.
Alright, I understand your point. It would be a good idea to separate out the torch_npu flash attention into its own module.
Force-pushed e4a39ae to 63c6045.
I've separated the NPU flash attention into its own module and made it switchable via a parameter.
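Usage then looks roughly like this (a sketch; the exact method name on the model is an assumption based on this change, not a confirmed API):

```python
import torch

from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    subfolder="unet",
    torch_dtype=torch.float16,
)
# Opt in explicitly instead of switching silently whenever torch_npu is present.
unet.enable_npu_flash_attention()  # assumed switch name; sets the NPU attention processor
```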
```python
# context: existing eager GEGLU path
return hidden_states * self.gelu(gate)

# added: fused NPU path
if is_torch_npu_available():
    hidden_states = self.proj(hidden_states, *args)
    return torch_npu.npu_geglu(hidden_states, dim=-1, approximate=1)[0]
```
Can we not use the existing `self.gelu()` when using NPU?
Compared to `self.gelu()`, `torch_npu.npu_geglu` runs faster and saves memory on NPU.
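A hypothetical micro-benchmark sketching that comparison (requires an Ascend NPU with `torch_npu` installed; the `torch.npu.synchronize()` calls are assumed to mirror the CUDA API, and the fused op is assumed to use the tanh GELU approximation):

```python
import time

import torch
import torch_npu  # noqa: F401 -- registers the "npu" device

# Simulated projection output: last dim is 2x the hidden size (value + gate).
x = torch.randn(2, 4096, 2560, dtype=torch.float16, device="npu")


def eager_geglu(t):
    # Unfused path: chunk, GELU on the gate, elementwise multiply.
    h, gate = t.chunk(2, dim=-1)
    return h * torch.nn.functional.gelu(gate, approximate="tanh")


for name, fn in [
    ("eager", eager_geglu),
    ("fused", lambda t: torch_npu.npu_geglu(t, dim=-1, approximate=1)[0]),
]:
    torch.npu.synchronize()
    start = time.time()
    for _ in range(100):
        fn(x)
    torch.npu.synchronize()
    print(name, time.time() - start)
```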
Thanks for working on this.
For me, the following would be nice to add before we merge:
- Documentation -- add an entry about the NPU processor to https://huggingface.co/docs/diffusers/main/en/api/attnprocessor
- Test: similar to `test_set_xformers_attn_processor_for_determinism`.
@yiyixuxu could you review the changes introduced to the core modules of the library and comment?
src/diffusers/models/activations.py (outdated)
```python
if is_torch_npu_available():
    hidden_states = self.proj(hidden_states, *args)
    return torch_npu.npu_geglu(hidden_states, dim=-1, approximate=1)[0]
else:
    hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
    return hidden_states * self.gelu(gate)
```
Suggested change:

```diff
-if is_torch_npu_available():
-    hidden_states = self.proj(hidden_states, *args)
-    return torch_npu.npu_geglu(hidden_states, dim=-1, approximate=1)[0]
-else:
-    hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
-    return hidden_states * self.gelu(gate)
+hidden_states = self.proj(hidden_states)
+if is_torch_npu_available():
+    return torch_npu.npu_geglu(hidden_states, dim=-1, approximate=1)[0]
+else:
+    hidden_states, gate = hidden_states.chunk(2, dim=-1)
+    return hidden_states * self.gelu(gate)
```

(This hoists the `self.proj` call out of the branch so both paths share it.)
Sure, I'll add unit tests and documentation later.

I've updated the code. @sayakpaul
```python
@unittest.skipIf(
    torch_device != "npu" or not is_torch_npu_available(),
    reason="torch npu flash attention is only available with NPU and `torch_npu` installed",
)
def test_set_torch_npu_flash_attn_processor_determinism(self):
```
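The body of such a test might look something like this (a sketch modeled on the xformers determinism test; `prepare_init_args_and_inputs_for_common` comes from the shared model tester mixin, and the `enable_npu_flash_attention` call is an assumed name):

```python
def test_set_torch_npu_flash_attn_processor_determinism(self):
    init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
    model = self.model_class(**init_dict).to(torch_device)

    with torch.no_grad():
        # Default processor output as the reference.
        output = model(**inputs_dict)[0]

        # Running the NPU processor twice should be deterministic.
        model.enable_npu_flash_attention()
        output_npu_1 = model(**inputs_dict)[0]
        output_npu_2 = model(**inputs_dict)[0]

    assert torch.allclose(output_npu_1, output_npu_2, atol=0)
    # The fused kernel should stay numerically close to the default path.
    assert torch.allclose(output, output_npu_1, atol=1e-2)
```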
Test seems quite nice to me. Thanks for working on it!
Hi @sayakpaul.
Yes, it needs a review from our core maintainer @yiyixuxu.
thanks!
This PR was referenced in diffusers commit 5823736: Add Ascend NPU support for SDXL fine-tuning and fix the model saving bug when using DeepSpeed. (huggingface/diffusers#7816)
What does this PR do?
Added support for SDXL fine-tuning on Ascend NPU and fixed the bug that caused a hang when saving models with the DeepSpeed distributed framework. DeepSpeed requires saving weights on every device; saving weights only on the main process would cause the other ranks to hang.
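As a rough illustration of that saving pattern (a sketch assuming an `accelerate`-based training loop in the style of the diffusers example scripts, not the exact diff from this PR):

```python
from accelerate import Accelerator
from accelerate.utils import DistributedType


def save_checkpoint(accelerator: Accelerator, output_dir: str) -> None:
    # Under DeepSpeed ZeRO, checkpoint saving is a collective operation: every
    # rank must reach the save call. Gating it on the main process alone leaves
    # the other ranks waiting forever -- the hang described above.
    if accelerator.distributed_type == DistributedType.DEEPSPEED or accelerator.is_main_process:
        accelerator.save_state(output_dir)
```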
Fixes # (issue)
Before submitting
documentation guidelines, and
here are tips on formatting docstrings.
I fine-tuned SDXL on Ascend NPU, and the results are good. I hope diffusers can support more devices.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.