Support simple diffusers Flux loras.
comfyanonymous committed Aug 5, 2024
1 parent 7afa985 commit 78e133d
Showing 2 changed files with 61 additions and 0 deletions.
comfy/lora.py: 8 additions, 0 deletions
@@ -288,4 +288,12 @@ def model_lora_keys_unet(model, key_map={}):
                key_lora = k[len("diffusion_model."):-len(".weight")]
                key_map["base_model.model.{}".format(key_lora)] = k #official hunyuan lora format

    if isinstance(model, comfy.model_base.Flux): #Diffusers lora Flux
        diffusers_keys = comfy.utils.flux_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
        for k in diffusers_keys:
            if k.endswith(".weight"):
                to = diffusers_keys[k]
                key_lora = "transformer.{}".format(k[:-len(".weight")]) #simpletrainer and probably regular diffusers flux lora format
                key_map[key_lora] = to

    return key_map
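
For reference, a rough sketch (not part of the commit) of what this branch does with a single key; the block index and the 3072 hidden size are illustrative assumptions taken from the published Flux configs:

# One entry as returned by comfy.utils.flux_to_diffusers: a diffusers-style key,
# paired with the fused ComfyUI tensor and a (dim, offset, size) slice.
k = "transformer_blocks.0.attn.to_q.weight"
to = ("diffusion_model.double_blocks.0.img_attn.qkv.weight", (0, 0, 3072))

# The branch rewrites the diffusers key into the name used by diffusers-style
# Flux loras and records where that weight lives inside ComfyUI's Flux model.
key_lora = "transformer.{}".format(k[:-len(".weight")])
print(key_lora)  # transformer.transformer_blocks.0.attn.to_q
key_map = {key_lora: to}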
comfy/utils.py: 53 additions, 0 deletions
@@ -415,6 +415,59 @@ def auraflow_to_diffusers(mmdit_config, output_prefix=""):

    return key_map

def flux_to_diffusers(mmdit_config, output_prefix=""):
    n_double_layers = mmdit_config.get("depth", 0)
    n_single_layers = mmdit_config.get("depth_single_blocks", 0)
    hidden_size = mmdit_config.get("hidden_size", 0)

    key_map = {}
    for index in range(n_double_layers):
        prefix_from = "transformer_blocks.{}".format(index)
        prefix_to = "{}double_blocks.{}".format(output_prefix, index)

        for end in ("weight", "bias"):
            k = "{}.attn.".format(prefix_from)
            qkv = "{}.img_attn.qkv.{}".format(prefix_to, end)
            key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, hidden_size))
            key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
            key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))

        block_map = {"attn.to_out.0.weight": "img_attn.proj.weight",
                     "attn.to_out.0.bias": "img_attn.proj.bias",
                     }

        for k in block_map:
            key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, block_map[k])

    for index in range(n_single_layers):
        prefix_from = "single_transformer_blocks.{}".format(index)
        prefix_to = "{}single_blocks.{}".format(output_prefix, index)

        for end in ("weight", "bias"):
            k = "{}.attn.".format(prefix_from)
            qkv = "{}.linear1.{}".format(prefix_to, end)
            key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, hidden_size))
            key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
            key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))
            key_map["{}proj_mlp.{}".format(k, end)] = (qkv, (0, hidden_size * 3, hidden_size))

        block_map = {#TODO
                    }

        for k in block_map:
            key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, block_map[k])

    MAP_BASIC = { #TODO
    }

    for k in MAP_BASIC:
        if len(k) > 2:
            key_map[k[1]] = ("{}{}".format(output_prefix, k[0]), None, k[2])
        else:
            key_map[k[1]] = "{}{}".format(output_prefix, k[0])

    return key_map

def repeat_to_batch_size(tensor, batch_size, dim=0):
    if tensor.shape[dim] > batch_size:
        return tensor.narrow(dim, 0, batch_size)
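To see what the new flux_to_diffusers helper produces, here is a small usage sketch (not part of the commit); depth=19, depth_single_blocks=38, and hidden_size=3072 match the published Flux configs, but any config dict with those keys works:

from comfy.utils import flux_to_diffusers  # added to comfy/utils.py by this commit

# Build the diffusers -> ComfyUI key map for a Flux-sized model.
unet_config = {"depth": 19, "depth_single_blocks": 38, "hidden_size": 3072}
key_map = flux_to_diffusers(unet_config, output_prefix="diffusion_model.")

# Separate q/k/v weights in the diffusers naming scheme all point at the same
# fused tensor in ComfyUI, with a (dim, offset, size) tuple marking which rows
# of that tensor the diffusers weight corresponds to.
print(key_map["transformer_blocks.0.attn.to_q.weight"])
# ('diffusion_model.double_blocks.0.img_attn.qkv.weight', (0, 0, 3072))
print(key_map["single_transformer_blocks.0.attn.to_v.weight"])
# ('diffusion_model.single_blocks.0.linear1.weight', (0, 6144, 3072))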
