
Adding VQGAN Training script #5483

Merged
merged 50 commits into huggingface:main
May 15, 2024

Conversation

isamu-isozaki
Contributor

@isamu-isozaki isamu-isozaki commented Oct 23, 2023

What does this PR do?

This is a VQGAN training script ported from taming-transformers, lucidrains' muse-maskgit repo, and open-muse. I'm planning to test it on the CIFAR-10 dataset to confirm it works.

Some steps missing/need confirmation are

  • Confirm einops and timm can be external dependencies; if not, convert these ops to native PyTorch
  • Test on CIFAR-10
  • Add tests to test_models_vq and test_models_vae for the slight modification
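For the first item above, converting an einops op to native PyTorch is usually just a reshape plus a permute. A minimal sketch (the function name and shapes are illustrative, not taken from the script):

```python
import torch

# Illustrative: einops.rearrange(x, "b c h w -> b (h w) c")
# expressed with native PyTorch reshape + permute.
def flatten_spatial(x: torch.Tensor) -> torch.Tensor:
    b, c, h, w = x.shape
    return x.reshape(b, c, h * w).permute(0, 2, 1)

x = torch.randn(2, 16, 8, 8)
print(flatten_spatial(x).shape)  # torch.Size([2, 64, 16])
```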

Fixes #4702

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@isamu-isozaki isamu-isozaki marked this pull request as draft October 23, 2023 01:02
@isamu-isozaki
Contributor Author

Once I've confirmed it works with CIFAR-10, I'll remove the draft status.

@isamu-isozaki
Contributor Author

isamu-isozaki commented Oct 30, 2023

I was able to start training with this script, and I removed the einops dependencies. The only additional dependency so far is timm. I plan to run this overnight on CIFAR-10 at 128×128 resolution and then take this PR out of draft. Also, let me know if anyone knows a good VQModel config that's easy and fast to train.

@isamu-isozaki
Contributor Author

isamu-isozaki commented Oct 31, 2023

Ok! Training seems to work. Here's a wandb run on CIFAR-10, using 6 GB of VRAM. The command to run it is:

```shell
accelerate launch train_vqgan.py --dataset_name=cifar10 --image_column=img \
  --validation_images images/bird.jpg images/car.jpg images/dog.jpg images/frog.jpg images/horse.jpg images/ship.jpg \
  --resolution=128 --train_batch_size=2 --gradient_accumulation_steps=8 --report_to=wandb
```

For each validation image provided, the results are displayed like so: the input image on the left and the generated image on the right.
[original vs generated]
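The side-by-side layout could be produced with something like the following PIL sketch (this is not the script's actual logging code; the function name is illustrative):

```python
from PIL import Image

def side_by_side(original: Image.Image, generated: Image.Image) -> Image.Image:
    # Paste the input image (left) and the reconstruction (right) on one canvas.
    w, h = original.size
    canvas = Image.new("RGB", (w * 2, h))
    canvas.paste(original, (0, 0))
    canvas.paste(generated, (w, 0))
    return canvas
```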

The remaining parts that I can think of are

  • Make log_validation support trackers other than wandb
  • Make tqdm updates similar to other examples

I did find a bug where the global step doesn't seem to go above 3000, but once that is fixed I'll open this for review.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@isamu-isozaki
Contributor Author

The main logic is done, so I think it's ready for review. For the 3000-step bug, I'm currently running training to see if it happens again after the fixes.

@isamu-isozaki isamu-isozaki marked this pull request as ready for review October 31, 2023 15:03
@isamu-isozaki
Contributor Author

isamu-isozaki commented Nov 1, 2023

Ok! Seems like it was a hardware issue (I think). Got past step 3100. The script should be ready for review.

@isamu-isozaki isamu-isozaki changed the title WIP: Adding VQGAN Training script Adding VQGAN Training script Nov 1, 2023

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale Issues that haven't received updates label Nov 26, 2023
@github-actions github-actions bot closed this Dec 26, 2023
@yqy2001

yqy2001 commented Feb 24, 2024

Hi there, what is the current status of this PR? It seems that everything works well. Will this be merged?

@isamu-isozaki
Contributor Author

@sayakpaul I tried fixing it by following the @require_torch format and making a @require_timm. Let me know what you think.

```python
    return (dec,)

return DecoderOutput(sample=dec)
if return_loss:
```
Collaborator

Maybe we don't need this special return_loss flag.
I don't think it would break, no?
If it is a tuple, we usually use out[0]; if it is a DecoderOutput, we usually do out.sample. I think just adding the loss to the output should be fine. cc @DN6 to confirm it's non-breaking.

Contributor Author

Hmm, wouldn't it be possible for some people to do out[-1] on the tuple? I think that's the only case where it would break.

Member

Even if they do that, I think the error message would be fairly easy to digest, and I don't think it will be breaking. WDYT? I like the idea of not introducing return_loss.

Contributor Author

Good point. I'll remove return_loss
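A toy illustration of why appending the loss to the output is mostly non-breaking (a simplified stand-in, not diffusers' actual DecoderOutput; the commit_loss field name is hypothetical):

```python
from dataclasses import dataclass
from typing import Optional

# Simplified stand-in for a BaseOutput-style class.
@dataclass
class DecoderOutput:
    sample: str
    commit_loss: Optional[float] = None

    def __getitem__(self, idx):
        # Tuple-like access over the fields that are set, mimicking BaseOutput.
        fields = tuple(v for v in (self.sample, self.commit_loss) if v is not None)
        return fields[idx]

out = DecoderOutput(sample="dec", commit_loss=0.25)
print(out[0])      # dec  -- existing out[0] callers keep working
print(out.sample)  # dec  -- attribute access is unchanged
print(out[-1])     # 0.25 -- out[-1] now returns the loss: the one breaking case
```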

@sayakpaul
Member

@isamu-isozaki I pushed a few things and I hope you don't mind.

  • Moved is_timm_available and require_timm to proper modules.
  • Added timm as a dependency in our workflows.
  • Decorated the trainer test class with require_timm instead of doing it per test method.

@isamu-isozaki
Contributor Author

@sayakpaul No worries, thanks a bunch for doing that. I forgot the proper way to add modules for the tests 😅

@sayakpaul
Member

Ah, all tests passing. A sight for sore eyes, eh!

@isamu-isozaki
Contributor Author

isamu-isozaki commented Apr 30, 2024

@sayakpaul awesome! I just removed return_loss (and hopefully the tests still pass). I did run the tests on my end.

@sayakpaul
Member

@isamu-isozaki could you resolve the conflicts so that it's ready for merging? We would like to include it in our upcoming release. Sorry for the delay on my end.

@yiyixuxu could you give the changes introduced to the library components a look?

@isamu-isozaki
Contributor Author

@sayakpaul Thanks, I think I resolved the conflicts, but let me run the tests to make sure.

@sayakpaul
Member

Okay, the code quality issues should be easy to fix, I think. But LMK if you run into difficulties. What I would do:

  • Create a fresh Python env.
  • From diffusers root, run pip install -e .[quality].
  • Run make style && make quality.
  • Push the changes.

@isamu-isozaki
Contributor Author

isamu-isozaki commented May 15, 2024

@sayakpaul Thanks a bunch. I think I fixed the ruff format error, but one question: when I run the checks locally, the doc-builder step always fails, even in a fresh environment set up with the steps above. But once I fix everything before it, the CI checks usually pass. It might be a bug on my part, but is that a common issue?
The error is:

```
ruff check examples scripts src tests utils benchmarks setup.py
ruff format --check examples scripts src tests utils benchmarks setup.py
918 files left unchanged
doc-builder style src/diffusers docs/source --max_len 119 --check_only
Traceback (most recent call last):
  File "/home/isamu/miniconda3/envs/diffusers/bin/doc-builder", line 8, in <module>
    sys.exit(main())
  File "/home/isamu/miniconda3/envs/diffusers/lib/python3.10/site-packages/doc_builder/commands/doc_builder_cli.py", line 47, in main
    args.func(args)
  File "/home/isamu/miniconda3/envs/diffusers/lib/python3.10/site-packages/doc_builder/commands/style.py", line 28, in style_command
    raise ValueError(f"{len(changed)} files should be restyled!")
ValueError: 284 files should be restyled!
Makefile:43: recipe for target 'quality' failed
make: *** [quality] Error 1
```

It fails locally, but I think it'll pass here.

@sayakpaul
Member

That's weird. Could be a setup related problem :/

@sayakpaul
Member

Alright merging this now!

@sayakpaul sayakpaul merged commit d27e996 into huggingface:main May 15, 2024
15 checks passed
@sayakpaul
Member

Thanks a lot for shipping this super cool script, @isamu-isozaki. Really appreciate your hard work and patience!

@isamu-isozaki
Contributor Author

@sayakpaul np! No worries at all and thanks for the support!

XSE42 added a commit to XSE42/diffusers3d that referenced this pull request Jun 23, 2024
diffusers commit d27e996
    Adding VQGAN Training script huggingface/diffusers#5483
Successfully merging this pull request may close these issues.

Example script to train a VQ-VAE