Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Flux Autoencoder #2098

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open

Conversation

calvinpelletier
Copy link
Contributor

@calvinpelletier calvinpelletier commented Dec 2, 2024

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Please link to any issues this PR addresses.

Changelog

What are the changes made in this PR?

  • Implement the Flux autoencoder model
  • Function to convert weights from HF's format to ours
  • Unit tests

Test plan

Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

Parity: 3.36e-5

Speed: ~10% faster

Mnimal testing code:

import numpy as np
import torch
from PIL import Image
from safetensors.torch import load_file

from torchtune.models.flux import flux_autoencoder
from torchtune.models.flux._convert_weights import flux_ae_hf_to_tune

# First, download the AE weights: https://github.com/black-forest-labs/flux/blob/main/src/flux/util.py#L373
WEIGHTS_PATH = "/home/cpelletier/.cache/huggingface/hub/models--black-forest-labs--FLUX.1-dev/snapshots/0ef5fff789c832c5c7f4e127f94c8b54bbcced44/ae.safetensors"


def to_img(x):
    x = x.clamp(-1, 1).float()
    x = x[0].permute(1, 2, 0)
    return Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())


def to_tensor(img):
    img = img.resize((256, 256), Image.Resampling.LANCZOS)
    x = np.array(img, dtype=np.float32)
    x = (x / 127.5) - 1.0
    x = torch.from_numpy(x)
    x = x.permute(2, 0, 1).unsqueeze(0)
    return x.cuda().to(dtype=torch.bfloat16)


ae = flux_autoencoder()
ae.load_state_dict(flux_ae_hf_to_tune(load_file(WEIGHTS_PATH)))
ae = ae.to(dtype=torch.bfloat16, device="cuda").eval().requires_grad_(False)

x = to_tensor(Image.open("/home/cpelletier/tmp.jpg"))
y = ae(x)
to_img(y).save("/home/cpelletier/out.png")

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example
and a tutorial example

  • I did not change any public API
  • I have added an example to docs or docstrings

Copy link

pytorch-bot bot commented Dec 2, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2098

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit aa8b751 with merge base 32e265d (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Dec 2, 2024
from torchtune.models.flux._autoencoder import FluxAutoencoder


def flux_autoencoder() -> FluxAutoencoder:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be a model builder or component builder?

Also, should it be named flux_v1_autoencoder, anticipating future flux versions? Or just flux_autoencoder, anticipating that the autoencoder part probably won't change in the next version?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes I think the name flux_1_autoencoder would be good. If doesn't make sense to break Autoencoder into smaller components, then there's no need for a component builder.

Copy link
Contributor

@pbontrager pbontrager left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks really good. A few minor things to followup on and I'll approve

from torchtune.models.flux._autoencoder import FluxAutoencoder


def flux_autoencoder() -> FluxAutoencoder:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes I think the name flux_1_autoencoder would be good. If doesn't make sense to break Autoencoder into smaller components, then there's no need for a component builder.

"""
z = z / self.scale_factor + self.shift_factor
h = self.conv_in(z)
h = self.mid(h)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Small thing, we follow the pattern of looping through the list of layers here instead of putting them into Sequential. This makes the forward pass a bit more readable and easier to add breakpoints.

return self.scale_factor * (z - self.shift_factor)


class Decoder(nn.Module):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we name these FluxDecoder/FluxEncoder?

# ch = number of channels (size of the channel dimension)


class FluxAutoencoder(nn.Module):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bit of overkill now, but I'd have the FluxAutoencoder just take in an Enocder and Decoder as inputs and then define a component builder function to define Encoder + Decoder and pass them into Autoencoder. This does nothing for us now, but if we want to add lora support for the autoencoder in the future it'll make it much easier.

assert actual.shape == (BSZ, CH_IN, RESOLUTION, RESOLUTION)

actual = torch.mean(actual, dim=(0, 2, 3))
expected = torch.tensor([0.4286, 0.4276, 0.4054])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are from the small dummy model, did you run these same tests on the full model with weights against the flux codebase?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants