
Official callbacks #7761

Merged · 28 commits · May 13, 2024
Conversation

asomoza (Member)

@asomoza asomoza commented Apr 24, 2024

What does this PR do?

Initial draft to support official callbacks.

This is the most basic implementation I could think of without needing to modify the pipelines.

After this, we need to discuss whether we're going to modify the pipelines to support additional functionality:

On step begin

For this issue, for example, the proposal is to start CFG after a certain step and stop it after another. To enable CFG on step begin we would need to add an additional on_step_begin callback if we want to do it in the callbacks instead of manually building the embeds and passing them to the pipelines. The same would be needed for differential diffusion.

Automatic callback_on_step_end_tensor_inputs

With the current implementation the user needs to know what to add to the callback_on_step_end_tensor_inputs list; for example, for the SDXL implementation of the CFG cutoff we need to add prompt_embeds, add_text_embeds, and add_time_ids or it won't work. If we want to do this automatically I'll need to modify the pipelines; if not, I can add an error message indicating which values are missing.

The user already needs to know the args for each callback, so maybe it's better to just document this in a README covering all the callbacks.
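To make a missing-input failure explicit rather than silent, a callback could validate its own required tensor inputs. This is a hypothetical sketch (the class name and the exact tensor_inputs list are assumptions for illustration, not the PR's merged API):

```python
# Hypothetical sketch: a callback that checks it received the tensor inputs
# it needs and raises a descriptive error otherwise. The names below mirror
# the SDXL CFG cutoff discussion but are assumptions, not the merged API.
class CutoffCallbackSketch:
    tensor_inputs = ["latents", "prompt_embeds", "add_text_embeds", "add_time_ids"]

    def __call__(self, pipeline, step_index, timestep, callback_kwargs):
        missing = [k for k in self.tensor_inputs if k not in callback_kwargs]
        if missing:
            raise ValueError(
                f"Callback is missing inputs {missing}; add them to the "
                "callback_on_step_end_tensor_inputs argument of the pipeline."
            )
        return callback_kwargs
```

With a check like this, forgetting to pass the tensor inputs produces an actionable error at the first step instead of the cutoff silently not working.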

Chain callbacks

Should we add the ability to chain callbacks? For example, accepting a list of callbacks so the CFG and IP adapter cutoffs can be used at the same time? The alternative is to create another callback that does both.

  • Draft
  • Automatic callback_on_step_end_tensor_inputs
  • Chain callbacks
  • On step begin?
  • clean code
  • review

Fixes #7736

Example usage:

# for SD 1.5
from diffusers.callbacks import SDCFGCutoffCallback

callback = SDCFGCutoffCallback(cutoff_step_ratio=0.4)
# can also be used with cutoff_step_index
# callback = SDCFGCutoffCallback(cutoff_step_ratio=None, cutoff_step_index=10)

image = pipe(
    prompt=prompt,
    negative_prompt="",
    guidance_scale=6.5,
    num_inference_steps=25,
    generator=generator,
    callback_on_step_end=callback,
).images[0]
Results at cutoff_step_ratio 0.2, 0.4, 0.8, and 1.0 (sample images omitted)
# for SDXL
from diffusers.callbacks import SDXLCFGCutoffCallback

callback = SDXLCFGCutoffCallback(cutoff_step_ratio=0.4)

image = pipe(
    prompt=prompt,
    negative_prompt="",
    guidance_scale=6.5,
    num_inference_steps=25,
    generator=generator,
    callback_on_step_end=callback,
).images[0]
Results at cutoff_step_ratio 0.2, 0.4, 0.8, and 1.0 (sample images omitted)
# IP Adapter cutout
from diffusers.callbacks import IPAdapterScaleCutoffCallback

callback = IPAdapterScaleCutoffCallback(cutoff_step_ratio=0.4)

image = pipe(
    prompt=prompt,
    negative_prompt="",
    guidance_scale=6.5,
    num_inference_steps=25,
    generator=generator,
    ip_adapter_image=ip_image,
    callback_on_step_end=callback,
).images[0]
IP image plus results at cutoff_step_ratio 0.3, 0.5, and 1.0 (sample images omitted)
# Callback list
from diffusers.callbacks import IPAdapterScaleCutoffCallback, MultiPipelineCallbacks, SDXLCFGCutoffCallback

ip_callback = IPAdapterScaleCutoffCallback(cutoff_step_ratio=0.5)
cfg_callback = SDXLCFGCutoffCallback(cutoff_step_ratio=None, cutoff_step_index=10)

callbacks = MultiPipelineCallbacks([ip_callback, cfg_callback])

image = pipe(
    prompt=prompt,
    negative_prompt="",
    guidance_scale=6.5,
    num_inference_steps=25,
    generator=generator,
    ip_adapter_image=ip_image,
    callback_on_step_end=callbacks,
).images[0]
Results for IP 1.0 / CFG 1.0, IP 1.0 / CFG 0.4, and IP 0.5 / CFG 0.4 (sample images omitted)

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@yiyixuxu (Collaborator)

cc @a-r-r-o-w here too

@yiyixuxu (Collaborator)

I like this - very nice and simple

for the next steps, let's see a proposal for this?

Automatic callback_on_step_end_tensor_inputs

after that, we can play around with on_step_begin; I think we don't have to consider chain callbacks for now, but open to it if we see many use cases for it in the future

@yiyixuxu (Collaborator)

actually would be nice to support list of callbacks since now we provide official ones that user can mix and match

@AmericanPresidentJimmyCarter (Contributor)

actually would be nice to support list of callbacks since now we provide official ones that user can mix and match

Yeah, I think this is the right way to do it. In fact I would say not to even use "callbacks" but rather a pure function for each sampling step, default_sampling_function, into which we pass all the things possibly required for a single sampling step.

Basically we have the sampling loop (SD pipeline as an example)

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    timestep_cond=timestep_cond,
                    cross_attention_kwargs=self.cross_attention_kwargs,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                )[0]

                # perform guidance
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        callback(step_idx, t, latents)

In my opinion, we should replace everything underneath for i, t in enumerate(timesteps): with a default_sampling_function defined elsewhere in the pipeline. The default_sampling_function should do the normal sampling loop and take a SamplingInput dataclass consisting of the things required to perform the sampling step, returning a SamplingOutput which consists of the things that are normally returned from the sampling step. The SamplingOutput can then be re-fed as a SamplingInput with the next timestep into the default_sampling_function.

We can add an argument sampling_functions: list[Callable]=[default_sampling_function] into the __call__ as a new, backwards compatible kwarg.

In this way we can finally get complete control over the sampling loop and chain multiple functions together to process the output of the sampling loop in the order of the sampling functions.

@bghira (Contributor)

bghira commented Apr 26, 2024

the specifics of that i'll have to wrap my head around but the initial idea of decoupling the logic inside __call__ so it can be more effectively monkeypatched downstream sounds ideal. both things can be done, really

@bghira (Contributor)

bghira commented Apr 26, 2024

there's the concept library on the hf hub from back in the day. for the uninitiated, it is/was a collection of dreambooths others had done, to make it easier to find eg. a backpack checkpoint or some other oddly specific item you reliably needed to work.

i know it's a security nightmare, but the idea of hub-supported callbacks "calls to me" as something worth bringing up.

on the other hand, having community callbacks in this repo is time-consuming, but that allows thorough review of any callbacks that are included. unlike dreambooths, callbacks seem like they'd be rarely created, whereas there are a billion potential concepts for a dreambooth.

listing the available callbacks is quite trivial in either case, where a diffusers:callbacks tag or something can be used to differentiate them. scanning these for safety issues with an LLM would possibly help sanitise any obvious issues?

@AmericanPresidentJimmyCarter (Contributor)

AmericanPresidentJimmyCarter commented Apr 26, 2024

It's probably easier if I write it out in some pseudocode. Writing it down, I think SamplingOutput is probably redundant, so maybe we could come up with a better name for that dataclass.

class SamplingInput:
    def __init__(self, img, text_embedding, unet, timestep=None, **kwargs):
        self.img = img
        self.text_embedding = text_embedding
        self.unet = unet
        self.timestep = timestep
        # keep any extra per-pipeline fields (e.g. added_cond_kwargs)
        self.__dict__.update(kwargs)

# ... lots of other code ...

inp = SamplingInput(img, text_embedding, unet)
with self.progress_bar(total=num_inference_steps) as progress_bar:
    for i, t in enumerate(timesteps):
        inp.timestep = t
        for sampling_function in self.sampling_functions:
            inp = sampling_function(inp)
        progress_bar.update()

output_img = inp.img

Then we can do something like sampling_functions=[default_sampling_function, report_image_on_step], where the second function is just a sampling function that does nothing with the input, ships the in-progress image to an API somewhere to do realtime updates of inference, then returns the original input.
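A runnable toy version of that chaining idea, under stated assumptions: all names here (SamplingInput, default_sampling_function, report_image_on_step) are illustrative, and the default step is a placeholder rather than a real UNet forward and scheduler step.

```python
from dataclasses import dataclass
from typing import Any, Callable, List

@dataclass
class SamplingInput:
    img: Any
    timestep: int = 0

def default_sampling_function(inp: SamplingInput) -> SamplingInput:
    # placeholder for a real denoising step (UNet forward + scheduler.step)
    inp.img = inp.img + 1
    return inp

def report_image_on_step(inp: SamplingInput) -> SamplingInput:
    # side-effect-only function: e.g. ship inp.img to an API for live previews,
    # then return the input unchanged
    print(f"step {inp.timestep}: img={inp.img}")
    return inp

def run(inp: SamplingInput, sampling_functions: List[Callable], timesteps) -> SamplingInput:
    # chain every sampling function, in order, at every timestep
    for t in timesteps:
        inp.timestep = t
        for fn in sampling_functions:
            inp = fn(inp)
    return inp
```

The point of the sketch is the composition: each function receives the previous function's output, so observers like report_image_on_step slot in without touching the default step.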

@Beinsezii (Contributor)

I really like the array of function pointers. It makes composition easy and clearly signals that the methods are designed to be changeable.

@yiyixuxu (Collaborator)

ohh thanks!
for this PR we will keep it simple and support official callbacks with minimum change to our pipelines
what you proposed will introduce a pretty drastic change to our design and I think it is outside the scope of this PR so maybe it is better to open a new discussion https://github.com/huggingface/diffusers/discussions instead?

@vladmandic (Contributor)

a bit late to the party here, but adding one use-case: modifying or skipping steps.
right now, the loop is fixed and no matter what happens in callbacks, they cannot influence it:

            for i, t in enumerate(timesteps):

big use case is for a callback to actually modify timesteps in some sense - perhaps we want to skip a step? perhaps force an early end since the callback function determined it got what it needed and there is no point in running all the remaining steps to completion?

@AmericanPresidentJimmyCarter (Contributor)

AmericanPresidentJimmyCarter commented Apr 28, 2024

a bit late to the party here, but adding one use-case: modifying or skipping steps. right now, loop is fixed and no matter what happens in callbacks, they cannot influence it:

            for i, t in enumerate(timesteps):

big use case is for callback to actually modify timesteps in some sense - perhaps we want to skip a step? perhaps force an early end since callback function determined it got what it needed and there is no point of running all the remaining steps to completion?

The scheduler controls which timesteps are in the timestep list, so the determination of which timesteps to run lives there. You can simply write your own scheduler that excludes some timesteps.
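As a toy illustration of that point (plain Python, no actual diffusers scheduler classes; both class names are assumptions for the sketch), truncating or skipping timesteps can live entirely in a scheduler subclass's set_timesteps:

```python
# Toy base class standing in for a scheduler; only set_timesteps matters here.
class BaseSchedulerSketch:
    def set_timesteps(self, num_inference_steps):
        # descending "timesteps", e.g. [9, 8, ..., 0] for 10 steps
        self.timesteps = list(range(num_inference_steps - 1, -1, -1))

class EarlyStopSchedulerSketch(BaseSchedulerSketch):
    """Keeps only the first stop_ratio fraction of the timesteps."""

    def __init__(self, stop_ratio=0.8):
        self.stop_ratio = stop_ratio

    def set_timesteps(self, num_inference_steps):
        super().set_timesteps(num_inference_steps)
        keep = int(len(self.timesteps) * self.stop_ratio)
        self.timesteps = self.timesteps[:keep]
```

Because the pipeline only iterates over scheduler.timesteps, trimming the list here changes which steps run without touching the pipeline loop.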

@bghira (Contributor)

bghira commented Apr 28, 2024

or i guess a scheduler wrapper that takes in its own callbacks, in the case of SD.Next

@AmericanPresidentJimmyCarter (Contributor)

AmericanPresidentJimmyCarter commented Apr 28, 2024

ohh thanks! for this PR we will keep it simple and support official callbacks with minimum change to our pipelines what you proposed will introduce a pretty drastic change to our design and I think it is outside the scope of this PR so maybe it is better to open a new discussion https://github.com/huggingface/diffusers/discussions instead?

It's a continuation of #7736, but engineering a proper solution rather than a half-baked one will save time in the long run.

For example, right now for determining timesteps we have schedulers -- the scheduler is effectively a function you can pass into the pipeline that is relatively pure and, for the most part, just determines which timesteps you are supposed to run. Ideally we extend such functional designs to the sampling loop as well, and in this case, extend the ability to run multiple sampling functions in sequence.

I believe this solves every current and previous deficit that hacks like

        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],

have failed to fully address.

In my opinion this is poorly engineered, and now that it exists in the codebase it will need to be supported with backwards compatibility for the rest of time, whereas I believe my proposed solution is (1) clean, (2) consistent with the engineering of schedulers, and (3) will not result in technical debt, though it will incur a large one-time cost to support across the various pipelines. For now, just a few of the most used pipelines could be done and the rest stubbed with NotImplementedError.

@vladmandic (Contributor)

The scheduler modifies which timesteps are in the timestep list, so determination of timesteps to run lives there. You can very simply just write your own scheduler to exclude some timesteps.

@AmericanPresidentJimmyCarter i get that, but i don't want to monkey-patch all schedulers existing in diffusers.

example use case - there are some experimental sd15 models popping-up that are only finetuned on high-noise or low-noise - with idea behind them very similar to sdxl-refiner, but stabilityai never did refiner for sd15 and there is no pipeline for it.
so the use case would be to allow the initial run to "stop early" (e.g. at 80% of timesteps) so another model can continue from 80%+ of its timesteps (meaning we need to be able to set the initial timestep).
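One way to approximate this "stop early" behavior with the callback mechanism already in the loop quoted earlier is to set the pipeline's interrupt flag. The sketch below assumes the pipeline exposes num_timesteps and an _interrupt attribute (as in the SD pipeline loop shown above, which checks self.interrupt); treat those names as assumptions rather than a guaranteed API.

```python
# Sketch: a callback_on_step_end-style callable that flags the pipeline to
# stop denoising once a given fraction of the steps has run. Assumes the
# denoising loop checks `self.interrupt` (as in the pipeline code quoted
# earlier) and skips the remaining iterations when it is set.
class StopAfterRatio:
    def __init__(self, stop_ratio=0.8):
        self.stop_ratio = stop_ratio

    def __call__(self, pipeline, step_index, timestep, callback_kwargs):
        if step_index >= int(pipeline.num_timesteps * self.stop_ratio):
            pipeline._interrupt = True  # remaining iterations become no-ops
        return callback_kwargs
```

This only skips the tail of the loop; handing the intermediate latents to a second model would still need the second pipeline to accept a starting timestep, which is the gap being discussed.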

or i guess a scheduler wrapper that takes in its own callbacks, in teh case of SD.Next

@bghira i might as well need to do that, i thought since we're talking about callbacks design here this would be a place to address future needs.

@bghira (Contributor)

bghira commented Apr 28, 2024

that's a good point vlad. i was just thinking a preliminary attempt at a scheduler wrapper might result in some lessons being discovered that might help make a better upstream (diffusers) design. but maybe you already have a concrete idea? :P

also #4355 for your SD 1.x refiner needs.

@AmericanPresidentJimmyCarter (Contributor)

Yeah, this would require even more re-engineering. You would need sampling functions to be a part of the scheduler, and all of them would need to be passed to the scheduler instead of the pipeline. The net effect is more or less the same.

So for every pipeline, we would have a default sampling function which we pass to the default scheduler, and we could also pass multiple of these as I proposed. Then the only difference is in the sampling loop we self.scheduler.sample(...).

@vladmandic (Contributor)

or a very simple hack using the existing callback concept:

  • allow modification of timesteps array from a callback
  • add something like this:
for i, t in enumerate(timesteps):
  if t <= 0:
    continue

@AmericanPresidentJimmyCarter (Contributor)

Why wouldn't you just subclass the scheduler and override its timestep-setting method? That seems trivial?

@yiyixuxu (Collaborator)

@vladmandic I think it's what you proposed here, already implemented

add something like this:

for i, t in enumerate(timesteps):
    if t <= 0:
        continue

@yiyixuxu (Collaborator)

@AmericanPresidentJimmyCarter
like I said, this PR's scope is

for this PR we will keep it simple and support official callbacks with minimum change to our pipelines

feel free to open another issue or discussion

@AmericanPresidentJimmyCarter (Contributor)

Opened as #7808

@asomoza (Member Author)

asomoza commented May 9, 2024

thanks @bghira, I really appreciate your comment.

@asomoza asomoza requested a review from stevhliu May 10, 2024 08:24
Comment on lines +63 to +65
- `pipeline` (or the pipeline instance) provides access to important properties such as `num_timesteps` and `guidance_scale`. You can modify these properties by updating the underlying attributes. For this example, you'll disable CFG by setting `pipeline._guidance_scale=0.0`.
- `step_index` and `timestep` tell you where you are in the denoising loop. Use `step_index` to turn off CFG after reaching 40% of `num_timesteps`.
- `callback_kwargs` is a dict that contains tensor variables you can modify during the denoising loop. It only includes variables specified in the `callback_on_step_end_tensor_inputs` argument, which is passed to the pipeline's `__call__` method. Different pipelines may use different sets of variables, so please check a pipeline's `_callback_tensor_inputs` attribute for the list of variables you can modify. Some common variables include `latents` and `prompt_embeds`. For this function, change the batch size of `prompt_embeds` after setting `guidance_scale=0.0` in order for it to work properly.
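Putting the three points above together, a disable-CFG-at-40% callback might look like the following sketch. It is pure Python for illustration: prompt_embeds is treated as any sliceable batch, and the attribute names (num_timesteps, _guidance_scale) follow the documentation text above rather than being verified against every pipeline.

```python
# Sketch of the documented pattern: once 40% of the steps have run, zero out
# guidance and keep only the conditional half of the prompt embeddings
# (CFG doubles the batch, so the unconditional half is no longer needed).
def cutoff_cfg(pipeline, step_index, timestep, callback_kwargs):
    if step_index == int(pipeline.num_timesteps * 0.4):
        prompt_embeds = callback_kwargs["prompt_embeds"]
        prompt_embeds = prompt_embeds[-1:]  # conditional embeddings only
        pipeline._guidance_scale = 0.0
        callback_kwargs["prompt_embeds"] = prompt_embeds
    return callback_kwargs
```

Trimming the batch is the step that is easy to forget: with guidance off, the loop no longer duplicates the latents, so a double-batch prompt_embeds would no longer match.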
asomoza (Member Author)

these changes were made by make quality. Also asked @stevhliu for a review of the documentation.

@yiyixuxu (Collaborator) left a comment

I think the tests fail because you spelled Dict as dict; can we fix that so the tests pass?

src/diffusers/callbacks.py (2 resolved review threads)
@yiyixuxu (Collaborator) left a comment

thanks! let's get this merged soon :)

docs/source/en/using-diffusers/callback.md (resolved review thread)
@yiyixuxu yiyixuxu requested a review from sayakpaul May 10, 2024 19:10
@yiyixuxu (Collaborator)

cc @sayakpaul for a final review too

@sayakpaul (Member) left a comment

The design looks very clean to me. I love it!

My main comments are mostly on the documentation side. Additionally, I think having some fast tests would be nice.

docs/source/en/using-diffusers/callback.md (4 review threads)
src/diffusers/callbacks.py (2 review threads)
@sayakpaul (Member)

@asomoza I think we can merge this PR and introduce a test suite in a future PR. Up to you how you want to tackle it.

@asomoza (Member Author)

asomoza commented May 12, 2024

yeah, I prefer to write the tests in a different PR after merging this one

@yiyixuxu (Collaborator) left a comment

left some final feedback on the docs, I think we are ready to merge :)

docs/source/en/using-diffusers/callback.md (4 resolved review threads)
@yiyixuxu yiyixuxu merged commit fdb05f5 into huggingface:main May 13, 2024
15 checks passed
@asomoza asomoza deleted the official-callbacks branch May 13, 2024 03:25
@asomoza asomoza mentioned this pull request May 13, 2024
Successfully merging this pull request may close these issues: Add Diffusers supported callbacks.

9 participants