O(n*m) to O(n) for finding no target names #2372

Open
wants to merge 2 commits into
base: main

@AllenHW AllenHW commented Feb 10, 2025

The changed code finds the modules that are not targeted by the LoRA adapter.

The previous implementation is a double for-loop over the model's modules and the LoRA targets, so it has an O(n*m) runtime, where n can be up to 1000 and m can be up to 500 depending on the LoRA. The logic is meant to find the model modules whose name does not end in any LoRA target, where a match must either be the whole name or start right after a '.'.
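The matching rule described above can be sketched as follows. This is a minimal illustration and not the code from the PR: the function name `names_without_target_naive` and the example module names are hypothetical.

```python
def names_without_target_naive(key_list, target_modules):
    """Return the module names that match no target, using O(n*m) comparisons.

    A name matches a target if it equals the target or ends with
    "." + target, so "attn.q_proj" matches "q_proj" but "attn.xq_proj" does not.
    """
    return [
        name
        for name in key_list
        if not any(name == target or name.endswith("." + target) for target in target_modules)
    ]


# Hypothetical module names, for illustration only.
key_list = ["blocks.0.attn.q_proj", "blocks.0.attn.xq_proj", "blocks.0.mlp.fc1", "q_proj"]
target_modules = ["q_proj", "mlp.fc1"]
no_target = names_without_target_naive(key_list, target_modules)
```

Every module name is compared against every target, which is where the n*m factor comes from.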

Instead of a double for-loop, we can split each module name on '.' to generate all of its candidate suffixes and check whether any of them is in the set of LoRA targets, which has been turned into a lookup table. Each module name splits into fewer than 10 suffixes, so this is effectively an O(n) operation.
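A rough sketch of the suffix-lookup idea, assuming the targets are plain strings. The name `names_without_target` and the example data are hypothetical, and the actual diff may be structured differently.

```python
def names_without_target(key_list, target_modules):
    """Return the module names that match no target, using set lookups."""
    targets = set(target_modules)  # lookup table with O(1) membership tests
    result = []
    for name in key_list:
        parts = name.split(".")
        # All dot-aligned suffixes: "a.b.c" yields "a.b.c", "b.c", "c".
        suffixes = (".".join(parts[i:]) for i in range(len(parts)))
        if not any(suffix in targets for suffix in suffixes):
            result.append(name)
    return result


# Hypothetical module names, for illustration only.
key_list = ["blocks.0.attn.q_proj", "blocks.0.attn.xq_proj", "blocks.0.mlp.fc1", "q_proj"]
target_modules = ["q_proj", "mlp.fc1"]
no_target = names_without_target(key_list, target_modules)
```

The work per module name now depends on how many dot-separated parts the name has, not on the number of targets.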

This change reduces the latency of load_lora_weights() by around 0.6 seconds on an Azure A100 machine for a 300MB Flux adapter (kishlaykumar1995/blinky-flux-lora-32). When the LoRA state_dict is already loaded on the GPU, load_lora_weights() used to take around 1.1 seconds, so this is roughly a 50% reduction in the latency of applying the LoRA.

Member

@BenjaminBossan BenjaminBossan left a comment

Thanks for suggesting this optimization. Could you please share the snippet you used to measure the time improvement?

While reviewing your code, I noticed that it's very similar to the code that is already in _find_minimal_target_modules:

def generate_suffixes(s):
    parts = s.split(".")
    return [".".join(parts[i:]) for i in range(len(parts))][::-1]

# Create a reverse lookup for other_module_names to quickly check suffix matches
other_module_suffixes = {suffix for item in other_module_names for suffix in generate_suffixes(item)}
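To make the behaviour of this helper concrete, here is a small worked example. The module names are hypothetical. The helper is copied from the snippet above.

```python
def generate_suffixes(s):
    parts = s.split(".")
    return [".".join(parts[i:]) for i in range(len(parts))][::-1]


# Hypothetical module names, for illustration only.
other_module_names = ["encoder.layer.0.dense", "encoder.layer.1.dense"]

# Suffixes come out shortest first:
# ["dense", "0.dense", "layer.0.dense", "encoder.layer.0.dense"]
suffixes_of_one = generate_suffixes("encoder.layer.0.dense")

# The set merges the suffixes of all names; "dense" is shared, so it appears once.
other_module_suffixes = {suffix for item in other_module_names for suffix in generate_suffixes(item)}
```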

Therefore, I wonder whether we could eliminate the generation of names_no_target altogether, since what we really need are the suffixes, not the full module names. For this, we would need to pass the key_list to _find_minimal_target_modules instead of the no longer required names_no_target, and then derive the suffixes directly.
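One way this could look, purely as a sketch: the helper `other_module_suffixes_from_key_list` is hypothetical and is not part of PEFT. It assumes that a module counts as targeted when one of its dot-aligned suffixes is in target_modules, and it builds the suffix set in a single pass without an intermediate names_no_target list.

```python
def generate_suffixes(s):
    parts = s.split(".")
    return [".".join(parts[i:]) for i in range(len(parts))][::-1]


def other_module_suffixes_from_key_list(target_modules, key_list):
    """Hypothetical helper: suffixes of all modules in key_list that are not targeted."""
    targets = set(target_modules)
    other_suffixes = set()
    for name in key_list:
        suffixes = generate_suffixes(name)
        # Assumption: a module is targeted if any of its suffixes is a target.
        if targets.isdisjoint(suffixes):
            other_suffixes.update(suffixes)
    return other_suffixes


# Hypothetical module names, for illustration only.
key_list = ["a.b.q", "a.b.k", "a.c.q"]
target_modules = ["b.q"]
other_suffixes = other_module_suffixes_from_key_list(target_modules, key_list)
```

Here only "a.b.q" is targeted, so the result contains the suffixes of "a.b.k" and "a.c.q".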

In principle, changing the signature like this is fine, since it's a fully private function, but it would mean rewriting the unit tests. I think the rewrite would not be too hard; it would come down to:

    def test_find_minimal_target_modules(self, target_modules, other_module_names, expected):
        # check all possible combinations of list and set
-       result = find_minimal_target_modules(target_modules, other_module_names)
+       all_module_names = target_modules + other_module_names
+       result = find_minimal_target_modules(target_modules, all_module_names)
        assert result == expected

(and then also adjusting the set vs list tests)

What do you think about this potential further optimization?

And also, please run make style on your changes to satisfy the linter.
