fix VAE loading issue in train_dreambooth #7632
We need to make sure the code quality tests pass.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@bssrdf could you run
Done!
It still doesn't pass :/
I don't know what's going on. Anyway, I reran
Could you reverse the last changes and run python -m uv pip install -e ".[dev]"? Maybe
ee9b19e to f04998c
I downgraded ruff and reran "make style && make quality". Let's give it another push. Yeah, ruff is now on 0.3.5. diffusers uses 0.1.5, which is pretty old.
It seems OK now. Thanks for fixing this!
+    try:
         vae = AutoencoderKL.from_pretrained(
             args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
         )
-    else:
+    except OSError:
         # IF does not have a VAE so let's just set it to None
         # We don't have to error out here
         vae = None
we can't blindly just set None for the VAE if it fails to load. assuming it's DeepFloyd isn't great, the reason that check was appropriate before is because model_has_vae would check for an actual VAE config in the repo.
That check, model_has_vae, is broken, hence this PR and #3462. I am not the first one to encounter this problem. This PR will make train_dreambooth work for users until the check is fixed.
the OSError version of this check isn't going to work. the model loader throws OSError for other reasons:
raise OSError(
"You seem to have cloned a repository without having git-lfs installed. Please install "
"git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
"you cloned."
)
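To make the concern concrete, here is a minimal sketch, not the diffusers loader itself. Both `fake_from_pretrained` and `load_vae_or_none` are hypothetical stand-ins, and the "no VAE" error message is invented for illustration. The only behavior taken from the thread is that the loader raises OSError for more than one reason. The sketch shows that a bare `except OSError` cannot tell a repo without a VAE from a broken git-lfs checkout.

```python
# Hypothetical stand-in for a model loader; this is NOT the real diffusers API.
def fake_from_pretrained(path, failure=None):
    if failure == "no_vae":
        # Invented message, used only to represent a missing VAE config.
        raise OSError(f"{path} does not appear to have a vae config file")
    if failure == "git_lfs":
        raise OSError("You seem to have cloned a repository without having git-lfs installed.")
    return {"loaded_from": path}


def load_vae_or_none(path, failure=None):
    """Mimics the try/except pattern under review."""
    try:
        return fake_from_pretrained(path, failure)
    except OSError:
        # Every OSError is treated as "this model has no VAE".
        return None


results = {
    reason: load_vae_or_none("hypothetical-org/hypothetical-model", reason)
    for reason in (None, "no_vae", "git_lfs")
}

for reason, value in results.items():
    print(reason, "->", value)
```

Both failure modes collapse to None here, so a user with a broken checkout would silently train without a VAE instead of seeing the git-lfs hint.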
I figured out the problem, which is specific to Windows. Please review the updated PR with the model_has_vae check augmented.
thanks for being patient with me and looking into that deeper. that is a good fix to find.
i understand that the other script does this, but that change is from May 2023. if anything we should fix the method used by this script, and enhance the other scripts to make use of the correct method.
… windows os to make match work
def model_has_vae(args):
    config_file_name = os.path.join("vae", AutoencoderKL.config_name)
    if platform.system() == "Windows":
        config_file_name = config_file_name.replace('\\', '/')
ohhhhhhh!!! i never use Windows, so this explains a lot.
the os.path.join is actually something we can remove then! it can be f"vae/{AutoencoderKL.config_name}" on line 761 and then the Windows check can go away.
I guess the Git info response will have Linux-like paths in it always.
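The mismatch can be reproduced on any OS with the standard-library `ntpath` and `posixpath` modules, which are the Windows and POSIX implementations behind `os.path`. This is a minimal sketch, assuming the config file is named `config.json` (the thread only refers to it as `AutoencoderKL.config_name`) and that the repo file listing uses forward slashes, as guessed above. The `repo_files` list is hypothetical.

```python
import ntpath      # the implementation os.path uses on Windows
import posixpath   # the implementation os.path uses on Linux and macOS

CONFIG_NAME = "config.json"  # assumption: stands in for AutoencoderKL.config_name

windows_style = ntpath.join("vae", CONFIG_NAME)
posix_style = posixpath.join("vae", CONFIG_NAME)
fstring_style = f"vae/{CONFIG_NAME}"

# Hypothetical repo listing with forward-slash paths.
repo_files = ["model_index.json", "unet/config.json", "vae/config.json"]

print(repr(windows_style), windows_style in repo_files)
print(repr(posix_style), posix_style in repo_files)
print(repr(fstring_style), fstring_style in repo_files)
```

Only the Windows-style join fails the membership test, which matches the symptom of the check misfiring on Windows alone.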
I removed os.path.join for config_file_name and used '/' for concatenation. It is OS independent now. Thanks for the suggestions.
    if os.path.isdir(args.pretrained_model_name_or_path):
        config_file_name = os.path.join(args.pretrained_model_name_or_path, config_file_name)
        return os.path.isfile(config_file_name)
these should remain the same using os.*
These stay the same
Thanks for finding this. Left a single comment.
 def model_has_vae(args):
-    config_file_name = os.path.join("vae", AutoencoderKL.config_name)
+    config_file_name = f"vae/{AutoencoderKL.config_name}"
Could we make use of config_file_name = Path("vae", AutoencoderKL.config_name).as_posix() here?
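For reference, a short sketch of what `as_posix()` does, using `PureWindowsPath` and `PurePosixPath` so the Windows behavior can be seen on any OS. The file name `config.json` is an assumption standing in for `AutoencoderKL.config_name`.

```python
from pathlib import Path, PurePosixPath, PureWindowsPath

CONFIG_NAME = "config.json"  # assumption: stands in for AutoencoderKL.config_name

as_windows = PureWindowsPath("vae", CONFIG_NAME)
as_posix = PurePosixPath("vae", CONFIG_NAME)

windows_str = str(as_windows)            # native Windows rendering uses a backslash
windows_posix = as_windows.as_posix()    # as_posix() always renders forward slashes
posix_posix = as_posix.as_posix()
native_posix = Path("vae", CONFIG_NAME).as_posix()  # same result on whatever OS runs this

print(repr(windows_str), repr(windows_posix), repr(posix_posix), repr(native_posix))
```

Compared with the f-string, this keeps path construction in `pathlib` while still producing a forward-slash string on every platform.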
@sayakpaul, this will also work. I can change to using Path and as_posix if we all agree it is proper to do.
I think that will be better to promote generality.
i didn't even know that exists. learning things!
Done! Thanks for the tips. I too learned a thing.
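Putting the review points together, the check might look roughly like the sketch below. This is not the merged code: `has_vae_config` and its `repo_files` parameter are hypothetical, with `repo_files` standing in for whatever file names a Hub lookup would return, so no network call or diffusers import is needed. It assumes the config file is named `config.json`.

```python
import os
import tempfile
from pathlib import Path

CONFIG_NAME = "config.json"  # assumption: stands in for AutoencoderKL.config_name


def has_vae_config(model_name_or_path, repo_files=()):
    """Return True if a VAE config exists locally or in a repo file listing.

    repo_files is a hypothetical stand-in for the file names of a remote repo.
    """
    config_file_name = Path("vae", CONFIG_NAME).as_posix()
    if os.path.isdir(model_name_or_path):
        # Local directory: keep using os.path, as requested in the review.
        return os.path.isfile(os.path.join(model_name_or_path, config_file_name))
    # Remote repo: compare against forward-slash file names.
    return any(name == config_file_name for name in repo_files)


with tempfile.TemporaryDirectory() as root:
    local_without_vae = has_vae_config(root)
    os.makedirs(os.path.join(root, "vae"))
    Path(root, "vae", CONFIG_NAME).write_text("{}")
    local_with_vae = has_vae_config(root)

remote_with_vae = has_vae_config(
    "hypothetical-org/hypothetical-model", ["unet/config.json", "vae/config.json"]
)
remote_without_vae = has_vae_config(
    "hypothetical-org/hypothetical-model", ["unet/config.json"]
)

print(local_without_vae, local_with_vae, remote_with_vae, remote_without_vae)
```

The local branch joins with the native separator because it touches the filesystem, while the remote branch compares plain strings and therefore needs the forward-slash form.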
@sayakpaul can we merge this now?
Sorry for the delay. Once this CI is green, will merge.
What does this PR do?
This PR is similar to #3462 but fixes the same issue in train_dreambooth.py.
Fixes #7619
Who can review?
@sayakpaul @standardAI