Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix VAE loading issue in train_dreambooth #7632

Merged

Conversation

bssrdf
Copy link
Contributor

@bssrdf bssrdf commented Apr 10, 2024

What does this PR do?

This PR is similar to #3462 but fixes the same issue in train_dreambooth.py.

Fixes #7619

Before submitting

Who can review?

@sayakpaul @standardAI

@yiyixuxu yiyixuxu requested a review from sayakpaul April 10, 2024 21:37
@sayakpaul
Copy link
Member

We need to make sure the code quality tests pass.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@tolgacangoz
Copy link
Contributor

@bssrdf could you run make style && make quality?

@bssrdf
Copy link
Contributor Author

bssrdf commented Apr 11, 2024

@bssrdf could you run make style && make quality?

Done!

@sayakpaul
Copy link
Member

It still doesn't pass :/

@bssrdf
Copy link
Contributor Author

bssrdf commented Apr 12, 2024

It still doesn't pass :/

I don't know what's going on. Any way, I rerun make style and make quality and here are the outputs. Nothing stands out.

$ make style
ruff check examples scripts src tests utils benchmarks setup.py --fix
warning: The top-level linter settings are deprecated in favour of their counterparts in the `lint` section. Please update the following options in `pyproject.toml`:
  - 'ignore' -> 'lint.ignore'
  - 'select' -> 'lint.select'
  - 'isort' -> 'lint.isort'
  - 'per-file-ignores' -> 'lint.per-file-ignores'
All checks passed!
ruff format examples scripts src tests utils benchmarks setup.py
warning: The top-level linter settings are deprecated in favour of their counterparts in the `lint` section. Please update the following options in `pyproject.toml`:
  - 'ignore' -> 'lint.ignore'
  - 'select' -> 'lint.select'
  - 'isort' -> 'lint.isort'
  - 'per-file-ignores' -> 'lint.per-file-ignores'
876 files left unchanged
doc-builder style src/diffusers docs/source --max_len 119
make autogenerate_code
make[1]: Entering directory '/home/isodden/temp/diffusers'
running deps_table_update
updating src/diffusers/dependency_versions_table.py
make[1]: Leaving directory '/home/isodden/temp/diffusers'
make extra_style_checks
make[1]: Entering directory '/home/isodden/temp/diffusers'
python utils/custom_init_isort.py
python utils/check_doc_toc.py --fix_and_overwrite
make[1]: Leaving directory '/home/isodden/temp/diffusers'
$make quality
ruff check examples scripts src tests utils benchmarks setup.py
warning: The top-level linter settings are deprecated in favour of their counterparts in the `lint` section. Please update the following options in `pyproject.toml`:
  - 'ignore' -> 'lint.ignore'
  - 'select' -> 'lint.select'
  - 'isort' -> 'lint.isort'
  - 'per-file-ignores' -> 'lint.per-file-ignores'
All checks passed!
ruff format --check examples scripts src tests utils benchmarks setup.py
warning: The top-level linter settings are deprecated in favour of their counterparts in the `lint` section. Please update the following options in `pyproject.toml`:
  - 'ignore' -> 'lint.ignore'
  - 'select' -> 'lint.select'
  - 'isort' -> 'lint.isort'
  - 'per-file-ignores' -> 'lint.per-file-ignores'
876 files already formatted
doc-builder style src/diffusers docs/source --max_len 119 --check_only
python utils/check_doc_toc.py
$git log
commit ee9b19e53e4ff1b14a001fadc140f19c47959bb1 (HEAD -> fix-train-dreambooth-vae-not-loading, origin/fix-train-dreambooth-vae-not-loading)
Merge: ac68fe51 279de3c3
Author: Sayak Paul <spsayakpaul@gmail.com>
Date:   Fri Apr 12 08:33:05 2024 +0530

    Merge branch 'main' into fix-train-dreambooth-vae-not-loading

commit 279de3c3ffedcb1394518a8f1c950fa30f272390
Author: Sai-Suraj-27 <sai.suraj.27.729@gmail.com>
Date:   Fri Apr 12 01:13:01 2024 +0530

    fix: Replaced deprecated `logger.warn` with `logger.warning` (#7643)

    Fixed deprecated logger.warn with logger.warning.

commit 8e14535708f6af0794148150f5c073c4723dbbae
Author: Yiqin Zhao <yiqinzhao@outlook.com>
Date:   Thu Apr 11 15:08:42 2024 -0400

    Fixed YAML loading. (#7579)

commit 0bee4d336b925b6064eee156f5a155e3ca3b30ab
Author: dg845 <58458699+dg845@users.noreply.github.com>
Date:   Thu Apr 11 10:52:12 2024 -0700

    LCM Distill Scripts Fix Bug when Initializing Target U-Net (#6848)

    * Initialize target_unet from unet rather than teacher_unet so that we correctly add time_embedding.cond_proj if necessary.

    * Use UNet2DConditionModel.from_config to initialize target_unet from unet's config.

    ---------

    Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

commit ac68fe5151a7e1bba83167bd07b65aded716174b
Author: bssrdf <bssrdf@gmail.com>
Date:   Thu Apr 11 09:22:17 2024 -0400

    reformat code with make style && make quality

@tolgacangoz
Copy link
Contributor

tolgacangoz commented Apr 12, 2024

Could you reverse the last changes make style && make quality done, and downgrade your ruff, then run the command again? Actually, install diffusers from the source:

python -m uv pip install -e ".[dev]"

Maybe ruff in the diffusers should be upgraded, Idk 🤔.

@bssrdf bssrdf force-pushed the fix-train-dreambooth-vae-not-loading branch from ee9b19e to f04998c Compare April 12, 2024 12:49
@bssrdf
Copy link
Contributor Author

bssrdf commented Apr 12, 2024

I downgraded ruff and reran the "make style && make quality". Let's give it another push. Yeah, ruff is now on 0.3.5. diffusers uses 0.1.5 which is pretty old.

Copy link
Contributor

@tolgacangoz tolgacangoz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems OK now. Thanks for fixing this!

Comment on lines 926 to 933
try:
vae = AutoencoderKL.from_pretrained(
args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
)
else:
except OSError:
# IF does not have a VAE so let's just set it to None
# We don't have to error out here
vae = None
Copy link
Contributor

@bghira bghira Apr 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can't blindly just set None for the VAE if it fails to load. assuming it's DeepFloyd isn't great, the reason that check was appropriate before is because model_has_vae would check for an actual VAE config in the repo.

Copy link
Contributor Author

@bssrdf bssrdf Apr 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that check model_has_vae is broken, hence this PR and #3462. I am not the first one who encountered this problem. This PR will make train_dreambooth work for users until the check is fixed.

Copy link
Contributor

@bghira bghira Apr 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the OSError version of this check isn't going to work. the model loader throws OSError for other reasons:

                    raise OSError(
                        "You seem to have cloned a repository without having git-lfs installed. Please install "
                        "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
                        "you cloned."
                    )

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I figured out the problem, which is specific only on Windows OS. Please review the updated PR with model_has_vae check augmented.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for being patient with me and looking into that deeper. that is a good fix to find.

@bghira
Copy link
Contributor

bghira commented Apr 25, 2024

i understand that the other script does this, but that change is from May, 2023. if anything we should fix the method used by this script, and enhance the other scripts to make use of the correct method.

def model_has_vae(args):
config_file_name = os.path.join("vae", AutoencoderKL.config_name)
if platform.system() == "Windows":
config_file_name = config_file_name.replace('\\', '/')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ohhhhhhh!!! i never use Windows, so this explains a lot.

the os.path.join is actually something we can remove then! it can be f"vae/{AutoencoderKL.config_name}" on line 761 and then the Windows check can go away.

I guess the Git info response will have Linux-like paths in it always.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I removed os.path.join for config_file_name and used '/' for concatenation. It is OS independent now. Thanks for suggestions.

Comment on lines 764 to 766
if os.path.isdir(args.pretrained_model_name_or_path):
config_file_name = os.path.join(args.pretrained_model_name_or_path, config_file_name)
return os.path.isfile(config_file_name)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these should remain the same using os.*

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These stay the same

Copy link
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for finding this. Left a single comment.

def model_has_vae(args):
config_file_name = os.path.join("vae", AutoencoderKL.config_name)
config_file_name = f"vae/{AutoencoderKL.config_name}"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we make use of config_file_name = Path("vae", AutoencoderKL.config_name).as_posix() here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sayakpaul, this will also work. I can change to using Path and as_posix if we all agree it is proper to do.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that will be better to promote generality.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i didn't even know that exists. learning things!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done! Thanks for the tips. I too learned a thing.

@yiyixuxu
Copy link
Collaborator

@sayakpaul can we merge this now?

@sayakpaul
Copy link
Member

Sorry for the delay. Once this CI is green, will merge.

@sayakpaul sayakpaul merged commit cdda94f into huggingface:main May 14, 2024
8 checks passed
@bssrdf bssrdf deleted the fix-train-dreambooth-vae-not-loading branch May 14, 2024 12:56
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Dreambooth] number of channels error in train_dreambooth.py
7 participants