[Flux] Add advanced training script + support textual inversion inference #9434
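
This PR adds an advanced DreamBooth LoRA training script for Flux (including pivotal tuning of the CLIP encoder and, optionally, the T5 encoder) and adds TextualInversionLoaderMixin support to the Flux pipelines so the trained embeddings can be used at inference. A minimal sketch of the intended inference flow, assuming artifacts saved by the training script (the repo id, output paths, and placeholder tokens below are illustrative, not taken from this diff):

import torch
from safetensors.torch import load_file

from diffusers import FluxPipeline

# Load the base Flux model.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Load the LoRA weights produced by the training script.
pipe.load_lora_weights("output_dir", weight_name="pytorch_lora_weights.safetensors")

# Load the pivotal tuning embeddings saved alongside the LoRA weights
# (an `_emb.safetensors` file keyed per text encoder, e.g. "clip_l").
state_dict = load_file("output_dir_emb.safetensors")
pipe.load_textual_inversion(
    state_dict["clip_l"],
    token=["<s0>", "<s1>"],  # hypothetical placeholder tokens used during training
    text_encoder=pipe.text_encoder,
    tokenizer=pipe.tokenizer,
)

image = pipe("a photo of <s0><s1>", num_inference_steps=25).images[0]
image.save("sample.png")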

Merged

Commits (129 total; changes shown from 121 commits)
90686c2
add ostris trainer to README & add cache latents of vae
linoytsaban Aug 12, 2024
7b12ed2
add ostris trainer to README & add cache latents of vae
linoytsaban Aug 12, 2024
17dca18
style
linoytsaban Aug 12, 2024
de24a4f
Merge branch 'main' into dreambooth-lora
linoytsaban Aug 13, 2024
8b314e9
readme
linoytsaban Aug 13, 2024
a59b063
Merge branch 'main' into dreambooth-lora
linoytsaban Aug 14, 2024
df54cd8
add test for latent caching
linoytsaban Aug 14, 2024
e0e0319
add ostris noise scheduler
linoytsaban Aug 14, 2024
18aa369
style
linoytsaban Aug 14, 2024
f97d53d
fix import
linoytsaban Aug 14, 2024
0156bec
style
linoytsaban Aug 14, 2024
c4c2c48
fix tests
linoytsaban Aug 14, 2024
d514c7b
style
linoytsaban Aug 14, 2024
7ee6041
Merge branch 'main' into dreambooth-lora
linoytsaban Aug 14, 2024
d5c2a36
--change upcasting of transformer?
linoytsaban Aug 16, 2024
e760cda
Merge branch 'main' into dreambooth-lora
linoytsaban Aug 21, 2024
f78ba77
Merge branch 'main' into dreambooth-lora
sayakpaul Aug 22, 2024
1b19593
Merge branch 'main' into dreambooth-lora
sayakpaul Aug 22, 2024
fbacbb5
update readme according to main
linoytsaban Sep 11, 2024
23f0636
Merge branch 'main' into dreambooth-lora
linoytsaban Sep 11, 2024
44c534e
add pivotal tuning for CLIP
linoytsaban Sep 11, 2024
087b982
fix imports, encode_prompt call,add TextualInversionLoaderMixin to Fl…
linoytsaban Sep 11, 2024
d9c3e45
TextualInversionLoaderMixin support for FluxPipeline for inference
linoytsaban Sep 11, 2024
5f9b74f
Merge branch 'huggingface:main' into dreambooth-lora-flux-exploration
linoytsaban Sep 12, 2024
f14617b
Merge branch 'huggingface:main' into dreambooth-lora-flux-exploration
linoytsaban Sep 13, 2024
b4328f8
move changes to advanced flux script, revert canonical
linoytsaban Sep 13, 2024
6254d04
Merge remote-tracking branch 'origin/dreambooth-lora-flux-exploration…
linoytsaban Sep 13, 2024
7b7a671
add latent caching to canonical script
linoytsaban Sep 13, 2024
2bb4ce1
revert changes to canonical script to keep it separate from https://g…
linoytsaban Sep 13, 2024
dc9be5b
revert changes to canonical script to keep it separate from https://g…
linoytsaban Sep 13, 2024
7098295
Merge branch 'main' into dreambooth-lora-flux-exploration
linoytsaban Sep 13, 2024
a25cb90
style
linoytsaban Sep 13, 2024
fd75b67
remove redundant line and change code block placement to align with l…
linoytsaban Sep 13, 2024
238ed70
add initializer_token arg
linoytsaban Sep 16, 2024
4bf3a13
add transformer frac for range support from pure textual inversion to…
linoytsaban Sep 16, 2024
62b8ab8
support pure textual inversion - wip
linoytsaban Sep 17, 2024
30cc651
adjustments to support pure textual inversion and transformer optimiz…
linoytsaban Sep 17, 2024
35ac0f7
fix logic when using initializer token
linoytsaban Sep 17, 2024
7e91489
fix pure_textual_inversion_condition
linoytsaban Sep 17, 2024
8775bea
fix ti/pivotal loading of last validation run
linoytsaban Sep 17, 2024
1cfcfc1
Merge branch 'main' into dreambooth-lora-flux-exploration
linoytsaban Sep 17, 2024
00bbb58
remove embeddings loading for ti in final training run (to avoid addi…
linoytsaban Sep 18, 2024
67e1bf7
support pivotal for t5
linoytsaban Sep 18, 2024
e00b30f
adapt pivotal for T5 encoder
linoytsaban Sep 18, 2024
dd327f3
Merge branch 'main' into dreambooth-lora-flux-exploration
linoytsaban Sep 18, 2024
f4d6e9a
adapt pivotal for T5 encoder and support in flux pipeline
linoytsaban Sep 18, 2024
6f5460b
Merge remote-tracking branch 'origin/dreambooth-lora-flux-exploration…
linoytsaban Sep 18, 2024
a01e566
t5 pivotal support + support for pivotal for clip only or both
linoytsaban Sep 19, 2024
887cc9d
Merge branch 'main' into dreambooth-lora-flux-exploration
linoytsaban Sep 19, 2024
f74c4be
fix param chaining
linoytsaban Sep 20, 2024
c597bd8
fix param chaining
linoytsaban Sep 20, 2024
250025f
Merge branch 'main' into dreambooth-lora-flux-exploration
linoytsaban Sep 20, 2024
81fe407
README first draft
linoytsaban Sep 21, 2024
c086a14
readme
linoytsaban Sep 23, 2024
549d3d0
readme
linoytsaban Sep 23, 2024
f356a43
Merge branch 'main' into dreambooth-lora-flux-exploration
linoytsaban Sep 23, 2024
983bab8
readme
linoytsaban Sep 23, 2024
e253478
Merge remote-tracking branch 'origin/dreambooth-lora-flux-exploration…
linoytsaban Sep 23, 2024
c969419
Merge branch 'main' into dreambooth-lora-flux-exploration
linoytsaban Sep 30, 2024
d42379d
Merge remote-tracking branch 'origin/dreambooth-lora-flux-exploration…
linoytsaban Sep 30, 2024
aefa48a
style
linoytsaban Sep 30, 2024
d966f05
Merge remote-tracking branch 'origin/dreambooth-lora-flux-exploration…
linoytsaban Sep 30, 2024
ae10674
fix import
linoytsaban Sep 30, 2024
dcd0e71
style
linoytsaban Sep 30, 2024
99b7521
add fix from https://github.com/huggingface/diffusers/pull/9419
linoytsaban Sep 30, 2024
46943ab
Merge remote-tracking branch 'origin/dreambooth-lora-flux-exploration…
linoytsaban Sep 30, 2024
57fb65b
add to readme, change function names
linoytsaban Sep 30, 2024
4faa6cf
Merge branch 'main' into dreambooth-lora-flux-exploration
linoytsaban Oct 1, 2024
571e49c
te lr changes
linoytsaban Oct 1, 2024
94dbe85
readme
linoytsaban Oct 1, 2024
5e751d4
change concept tokens logic
linoytsaban Oct 2, 2024
d9ed2b1
fix indices
linoytsaban Oct 2, 2024
9e5c04f
Merge branch 'main' into dreambooth-lora-flux-exploration
linoytsaban Oct 2, 2024
b686d04
change arg name
linoytsaban Oct 2, 2024
3881dbb
style
linoytsaban Oct 2, 2024
b3e5caa
dummy test
linoytsaban Oct 2, 2024
7db1798
Merge remote-tracking branch 'origin/dreambooth-lora-flux-exploration…
linoytsaban Oct 2, 2024
ca668b6
revert dummy test
linoytsaban Oct 2, 2024
7fcdc0d
reorder pivoting
linoytsaban Oct 2, 2024
8247860
Merge branch 'main' into dreambooth-lora-flux-exploration
linoytsaban Oct 3, 2024
83bcec0
Merge branch 'main' into dreambooth-lora-flux-exploration
linoytsaban Oct 4, 2024
8cb7971
add warning in case the token abstraction is not the instance prompt
linoytsaban Oct 4, 2024
b75b3e6
experimental - wip - specific block training
linoytsaban Oct 4, 2024
9a83f27
Merge branch 'main' into dreambooth-lora-flux-exploration
sayakpaul Oct 7, 2024
f110e4e
Merge branch 'main' into dreambooth-lora-flux-exploration
linoytsaban Oct 7, 2024
cc6f7a0
Merge remote-tracking branch 'origin/dreambooth-lora-flux-exploration…
linoytsaban Oct 8, 2024
03a6b5b
fix documentation and token abstraction processing
linoytsaban Oct 8, 2024
749e857
remove transformer block specification feature (for now)
linoytsaban Oct 8, 2024
de3e2a5
Merge branch 'main' into dreambooth-lora-flux-exploration
linoytsaban Oct 8, 2024
b791e13
Merge branch 'main' into dreambooth-lora-flux-exploration
linoytsaban Oct 9, 2024
c8ddd83
style
linoytsaban Oct 10, 2024
defac21
Merge branch 'main' into dreambooth-lora-flux-exploration
linoytsaban Oct 10, 2024
43c2cd5
fix copies
linoytsaban Oct 10, 2024
2ac6898
fix indexing issue when --initializer_concept has different amounts
linoytsaban Oct 10, 2024
d2cd0bf
add if TextualInversionLoaderMixin to all flux pipelines
linoytsaban Oct 10, 2024
f1879bf
style
linoytsaban Oct 10, 2024
20762bc
fix import
linoytsaban Oct 10, 2024
7399abd
Merge remote-tracking branch 'origin/dreambooth-lora-flux-exploration…
linoytsaban Oct 10, 2024
7366127
Merge branch 'main' into dreambooth-lora-flux-exploration
linoytsaban Oct 10, 2024
0e6d31e
fix imports
linoytsaban Oct 10, 2024
cd32792
Merge remote-tracking branch 'origin/dreambooth-lora-flux-exploration…
linoytsaban Oct 10, 2024
08aafc0
address review comments - remove unnecessary prints & comments, use pin…
linoytsaban Oct 10, 2024
c5b2422
style
linoytsaban Oct 10, 2024
4b11719
logger info fix
linoytsaban Oct 10, 2024
9373c0a
Merge remote-tracking branch 'origin/dreambooth-lora-flux-exploration…
linoytsaban Oct 10, 2024
6e2cb75
make lora target modules configurable and change the default
linoytsaban Oct 11, 2024
717b5ad
make lora target modules configurable and change the default
linoytsaban Oct 11, 2024
0fde49a
style
linoytsaban Oct 11, 2024
fbbe9c4
Merge branch 'main' into dreambooth-lora-flux-exploration
sayakpaul Oct 14, 2024
366a35e
Merge branch 'main' into dreambooth-lora-flux-exploration
linoytsaban Oct 14, 2024
e4fe609
make lora target modules configurable and change the default, add not…
linoytsaban Oct 14, 2024
3d0955b
style
linoytsaban Oct 14, 2024
452cef4
Merge branch 'main' into dreambooth-lora-flux-exploration
linoytsaban Oct 14, 2024
f62af61
Merge branch 'main' into dreambooth-lora-flux-exploration
sayakpaul Oct 15, 2024
a4429e0
Merge branch 'main' into dreambooth-lora-flux-exploration
linoytsaban Oct 15, 2024
c41dfff
Merge branch 'main' into dreambooth-lora-flux-exploration
linoytsaban Oct 15, 2024
cb265ad
add tests
linoytsaban Oct 15, 2024
61426f0
style
linoytsaban Oct 15, 2024
03f19f6
fix repo id
linoytsaban Oct 15, 2024
b1b2128
Merge branch 'main' into dreambooth-lora-flux-exploration
yiyixuxu Oct 16, 2024
bfb0741
Merge branch 'main' into dreambooth-lora-flux-exploration
sayakpaul Oct 16, 2024
bd2be32
add updated requirements for advanced flux
linoytsaban Oct 16, 2024
69d28b5
fix indices of t5 pivotal tuning embeddings
linoytsaban Oct 16, 2024
450f072
Merge branch 'main' into dreambooth-lora-flux-exploration
linoytsaban Oct 16, 2024
31de752
fix path in test
linoytsaban Oct 16, 2024
5dfd685
remove `pin_memory`
linoytsaban Oct 16, 2024
9bdb6a1
Merge branch 'main' into dreambooth-lora-flux-exploration
linoytsaban Oct 16, 2024
9093a4b
fix filename of embedding
linoytsaban Oct 16, 2024
f1b08cb
fix filename of embedding
linoytsaban Oct 16, 2024
353 changes: 353 additions & 0 deletions examples/advanced_diffusion_training/README_flux.md

Large diffs are not rendered by default.

@@ -0,0 +1,283 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import sys
import tempfile

import safetensors


sys.path.append("..")
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402


logging.basicConfig(level=logging.DEBUG)

logger = logging.getLogger()
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)


class DreamBoothLoRAFluxAdvanced(ExamplesTestsAccelerate):
    instance_data_dir = "docs/source/en/imgs"
    instance_prompt = "photo"
    pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-pipe"
    script_path = "examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py"

    def test_dreambooth_lora_flux(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
                --instance_data_dir {self.instance_data_dir}
                --instance_prompt {self.instance_prompt}
                --resolution 64
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

            # make sure the state_dict has the correct naming in the parameters.
            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

            # when not training the text encoder, all the parameters in the state dict should start
            # with `"transformer"` in their names.
            starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
            self.assertTrue(starts_with_transformer)

    def test_dreambooth_lora_text_encoder_flux(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
                --instance_data_dir {self.instance_data_dir}
                --instance_prompt {self.instance_prompt}
                --resolution 64
                --train_batch_size 1
                --train_text_encoder
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

            # make sure the state_dict has the correct naming in the parameters.
            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

            starts_with_expected_prefix = all(
                (key.startswith("transformer") or key.startswith("text_encoder")) for key in lora_state_dict.keys()
            )
            self.assertTrue(starts_with_expected_prefix)

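    # The pivotal tuning tests below assert on the `_emb.safetensors` file written by
    # --train_text_encoder_ti: its state dict is keyed per text encoder ("clip_l" for
    # the CLIP encoder, plus "t5" when --enable_t5_ti is also passed).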
    def test_dreambooth_lora_pivotal_tuning_flux_clip(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
                --instance_data_dir {self.instance_data_dir}
                --instance_prompt {self.instance_prompt}
                --resolution 64
                --train_batch_size 1
                --train_text_encoder_ti
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
            # make sure embeddings were also saved
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, f"{tmpdir}_emb.safetensors")))

            # make sure the state_dict has the correct naming in the parameters.
            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

            # make sure the state_dict has the correct naming in the parameters.
            textual_inversion_state_dict = safetensors.torch.load_file(
                os.path.join(tmpdir, f"{tmpdir}_emb.safetensors")
            )
            is_clip = all("clip_l" in k for k in textual_inversion_state_dict.keys())
            self.assertTrue(is_clip)

            # when performing pivotal tuning, all the parameters in the state dict should start
            # with `"transformer"` in their names.
            starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
            self.assertTrue(starts_with_transformer)

    def test_dreambooth_lora_pivotal_tuning_flux_clip_t5(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
                --instance_data_dir {self.instance_data_dir}
                --instance_prompt {self.instance_prompt}
                --resolution 64
                --train_batch_size 1
                --train_text_encoder_ti
                --enable_t5_ti
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
            # make sure embeddings were also saved
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, f"{tmpdir}_emb.safetensors")))

            # make sure the state_dict has the correct naming in the parameters.
            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

            # make sure the state_dict has the correct naming in the parameters.
            textual_inversion_state_dict = safetensors.torch.load_file(
                os.path.join(tmpdir, f"{tmpdir}_emb.safetensors")
            )
            is_te = all(("clip_l" in k or "t5" in k) for k in textual_inversion_state_dict.keys())
            self.assertTrue(is_te)

            # when performing pivotal tuning, all the parameters in the state dict should start
            # with `"transformer"` in their names.
            starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
            self.assertTrue(starts_with_transformer)

    def test_dreambooth_lora_latent_caching(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
                --instance_data_dir {self.instance_data_dir}
                --instance_prompt {self.instance_prompt}
                --resolution 64
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --cache_latents
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

            # make sure the state_dict has the correct naming in the parameters.
            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

            # when not training the text encoder, all the parameters in the state dict should start
            # with `"transformer"` in their names.
            starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
            self.assertTrue(starts_with_transformer)

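    # The checkpointing tests below rely on this arithmetic: with checkpointing_steps=2,
    # checkpoints land at steps 2, 4, 6, ...; checkpoints_total_limit=2 then deletes the
    # oldest ones, so e.g. a 6-step run ends with only {checkpoint-4, checkpoint-6}.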
    def test_dreambooth_lora_flux_checkpointing_checkpoints_total_limit(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
                --instance_data_dir={self.instance_data_dir}
                --output_dir={tmpdir}
                --instance_prompt={self.instance_prompt}
                --resolution=64
                --train_batch_size=1
                --gradient_accumulation_steps=1
                --max_train_steps=6
                --checkpoints_total_limit=2
                --checkpointing_steps=2
                """.split()

            run_command(self._launch_args + test_args)

            self.assertEqual(
                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
                {"checkpoint-4", "checkpoint-6"},
            )

    def test_dreambooth_lora_flux_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
                --instance_data_dir={self.instance_data_dir}
                --output_dir={tmpdir}
                --instance_prompt={self.instance_prompt}
                --resolution=64
                --train_batch_size=1
                --gradient_accumulation_steps=1
                --max_train_steps=4
                --checkpointing_steps=2
                """.split()

            run_command(self._launch_args + test_args)

            self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})

            resume_run_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
                --instance_data_dir={self.instance_data_dir}
                --output_dir={tmpdir}
                --instance_prompt={self.instance_prompt}
                --resolution=64
                --train_batch_size=1
                --gradient_accumulation_steps=1
                --max_train_steps=8
                --checkpointing_steps=2
                --resume_from_checkpoint=checkpoint-4
                --checkpoints_total_limit=2
                """.split()

            run_command(self._launch_args + resume_run_args)

            self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
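
For reference, the naming conventions these tests assert can be checked directly on any training run's output with a few lines of Python (file names here are assumptions matching what the tests expect, not fixed by the script):

from safetensors.torch import load_file

# LoRA keys should all contain "lora" and be prefixed by the module they target:
# "transformer", plus "text_encoder" when --train_text_encoder is used.
lora_state_dict = load_file("pytorch_lora_weights.safetensors")
assert all("lora" in k for k in lora_state_dict)
assert all(k.startswith(("transformer", "text_encoder")) for k in lora_state_dict)

# With --train_text_encoder_ti, embeddings are saved per text encoder:
# "clip_l" for CLIP, plus "t5" when --enable_t5_ti is passed.
emb_state_dict = load_file("output_dir_emb.safetensors")  # assumed file name
assert set(emb_state_dict) <= {"clip_l", "t5"}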