
Official callbacks #7761

Merged · 28 commits · May 13, 2024

Changes from 22 commits

Commits (28):
7b321cf
initial draft
asomoza Apr 24, 2024
91e4896
Merge branch 'huggingface:main' into official-callbacks
asomoza Apr 24, 2024
9bc2640
Merge branch 'main' into official-callbacks
asomoza Apr 26, 2024
e78eed1
automatic inputs from official callbacks
asomoza Apr 29, 2024
e9d77cf
more decriptive name for callback class
asomoza Apr 29, 2024
7e490e7
added list of callbacks support
asomoza Apr 30, 2024
f988880
Merge branch 'main' into official-callbacks
asomoza May 2, 2024
20d759f
new design
asomoza May 2, 2024
540fb87
fix callback_on_step_end type hint
asomoza May 2, 2024
d3d19b0
Merge branch 'main' into official-callbacks
asomoza May 6, 2024
7be58c1
suggestions
asomoza May 6, 2024
def6340
Merge branch 'main' into official-callbacks
asomoza May 7, 2024
6273ed2
added cutoff_step_index
asomoza May 7, 2024
46045ce
propagate changes to stable difussion pipelines
asomoza May 7, 2024
19ca5d7
added docstrings
asomoza May 9, 2024
55c6e75
Merge branch 'main' into official-callbacks
asomoza May 9, 2024
15bd9ea
added documentation
asomoza May 10, 2024
8ceeed8
Merge branch 'main' into official-callbacks
asomoza May 10, 2024
9bae8fc
quality
asomoza May 10, 2024
e23edfe
Merge branch 'main' into official-callbacks
asomoza May 11, 2024
2ab2ba4
apply suggestions
asomoza May 11, 2024
7b30390
missed some dicts
asomoza May 11, 2024
7b28343
comment for conditional embeddings
asomoza May 11, 2024
9884828
added controlnet note
asomoza May 11, 2024
b4597a2
fixed docstring
asomoza May 11, 2024
f9c70d4
changed to SDXL example and added images
asomoza May 11, 2024
4a71be6
Merge branch 'main' into official-callbacks
asomoza May 13, 2024
6ce63d7
doc suggestions applied
asomoza May 13, 2024
43 changes: 40 additions & 3 deletions docs/source/en/using-diffusers/callback.md
Expand Up @@ -19,13 +19,50 @@ The denoising loop of a pipeline can be modified with custom defined functions u

This guide will demonstrate how callbacks work by walking you through a few features you can implement with them.

## Official callbacks

We provide a list of callbacks you can plug into an existing pipeline to modify its denoising loop. These are the current official callbacks:

- `SDCFGCutoffCallback`: Disables the CFG after a certain number of steps for SD 1.5 pipelines.
- `SDXLCFGCutoffCallback`: Disables the CFG after a certain number of steps for SDXL pipelines.
- `IPAdapterScaleCutoffCallback`: Disables the IP Adapter after a certain number of steps.

To set up a callback, you need to specify the number of steps after which the callback comes into effect. You can do so with either of these two arguments:

- `cutoff_step_ratio`: Float between 0.0 and 1.0 giving the cutoff point as a fraction of the total number of steps.
- `cutoff_step_index`: Integer giving the exact index of the step at which the cutoff happens.

For example, to disable CFG at 40% of the inference steps with a Stable Diffusion 1.5 pipeline:
```python
import torch
from diffusers import StableDiffusionPipeline
from diffusers.callbacks import SDCFGCutoffCallback

callback = SDCFGCutoffCallback(cutoff_step_ratio=0.4)
# can also be used with cutoff_step_index
# callback = SDCFGCutoffCallback(cutoff_step_ratio=None, cutoff_step_index=10)

pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipeline = pipeline.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"

generator = torch.Generator(device="cuda").manual_seed(1)
out = pipeline(
prompt,
generator=generator,
callback_on_step_end=callback,
)

out.images[0].save("official_callback.png")
```
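
Official callbacks can also be combined with `MultiPipelineCallbacks` (added in this PR), which forwards the union of the individual `tensor_inputs` to the pipeline. The following is a minimal sketch, not taken from this PR: it assumes an IP Adapter has been loaded, and the `h94/IP-Adapter` checkpoint and the image URL are placeholders chosen for illustration.

```python
import torch
from diffusers import StableDiffusionPipeline
from diffusers.callbacks import (
    IPAdapterScaleCutoffCallback,
    MultiPipelineCallbacks,
    SDCFGCutoffCallback,
)
from diffusers.utils import load_image

pipeline = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Placeholder IP Adapter checkpoint and reference image.
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
ip_image = load_image("https://example.com/reference.png")

# Turn CFG off at 40% of the steps and the IP Adapter off at 60%.
callbacks = MultiPipelineCallbacks(
    [SDCFGCutoffCallback(cutoff_step_ratio=0.4), IPAdapterScaleCutoffCallback(cutoff_step_ratio=0.6)]
)

generator = torch.Generator(device="cuda").manual_seed(1)
image = pipeline(
    "a photo of an astronaut riding a horse on mars",
    ip_adapter_image=ip_image,
    generator=generator,
    callback_on_step_end=callbacks,
).images[0]
```

Because both callbacks subclass `PipelineCallback`, passing the combined object also overrides `callback_on_step_end_tensor_inputs` automatically, as shown by the pipeline changes further down in this PR.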

## Dynamic classifier-free guidance

Dynamic classifier-free guidance (CFG) is a feature that allows you to disable CFG after a certain number of inference steps, which can help you save compute with minimal cost to performance. The callback function for this should have the following arguments:

* `pipeline` (or the pipeline instance) provides access to important properties such as `num_timesteps` and `guidance_scale`. You can modify these properties by updating the underlying attributes. For this example, you'll disable CFG by setting `pipeline._guidance_scale=0.0`.
* `step_index` and `timestep` tell you where you are in the denoising loop. Use `step_index` to turn off CFG after reaching 40% of `num_timesteps`.
* `callback_kwargs` is a dict that contains tensor variables you can modify during the denoising loop. It only includes variables specified in the `callback_on_step_end_tensor_inputs` argument, which is passed to the pipeline's `__call__` method. Different pipelines may use different sets of variables, so please check a pipeline's `_callback_tensor_inputs` attribute for the list of variables you can modify. Some common variables include `latents` and `prompt_embeds`. For this function, change the batch size of `prompt_embeds` after setting `guidance_scale=0.0` in order for it to work properly.
- `pipeline` (or the pipeline instance) provides access to important properties such as `num_timesteps` and `guidance_scale`. You can modify these properties by updating the underlying attributes. For this example, you'll disable CFG by setting `pipeline._guidance_scale=0.0`.
- `step_index` and `timestep` tell you where you are in the denoising loop. Use `step_index` to turn off CFG after reaching 40% of `num_timesteps`.
- `callback_kwargs` is a dict that contains tensor variables you can modify during the denoising loop. It only includes variables specified in the `callback_on_step_end_tensor_inputs` argument, which is passed to the pipeline's `__call__` method. Different pipelines may use different sets of variables, so please check a pipeline's `_callback_tensor_inputs` attribute for the list of variables you can modify. Some common variables include `latents` and `prompt_embeds`. For this function, change the batch size of `prompt_embeds` after setting `guidance_scale=0.0` in order for it to work properly.
Comment on lines +87 to +89 (Member Author):

these changes were made by `make quality`. Also asked @stevhliu for a review of the documentation.

Your callback function should look something like this:
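
A minimal sketch consistent with the arguments described above (not necessarily the exact snippet from the guide, which is in the collapsed lines below):

```python
def callback_dynamic_cfg(pipeline, step_index, timestep, callback_kwargs):
    # Once 40% of the denoising steps are done, turn CFG off.
    if step_index == int(pipeline.num_timesteps * 0.4):
        prompt_embeds = callback_kwargs["prompt_embeds"]
        # With CFG on, the batch is [unconditional, conditional]; keep only the
        # conditional embeddings so the batch size matches after CFG is disabled.
        prompt_embeds = prompt_embeds[-1:]

        pipeline._guidance_scale = 0.0
        callback_kwargs["prompt_embeds"] = prompt_embeds
    return callback_kwargs
```

It would be passed as `callback_on_step_end=callback_dynamic_cfg` together with `callback_on_step_end_tensor_inputs=["prompt_embeds"]`.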

Expand Down
156 changes: 156 additions & 0 deletions src/diffusers/callbacks.py
@@ -0,0 +1,156 @@
from typing import Any, Dict, List

from .configuration_utils import ConfigMixin, register_to_config
from .utils import CONFIG_NAME


class PipelineCallback(ConfigMixin):
"""
Base class for all the official callbacks used in a pipeline. This class provides a structure for implementing
custom callbacks and ensures that all callbacks have a consistent interface.

Please implement the following:
`tensor_inputs`: This should return a list of tensor inputs specific to your callback. You will only be able to
include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class.
`callback_fn`: This method defines the core functionality of your callback.
"""

config_name = CONFIG_NAME

@register_to_config
def __init__(self, cutoff_step_ratio=1.0, cutoff_step_index=None):
super().__init__()

if (cutoff_step_ratio is None and cutoff_step_index is None) or (
cutoff_step_ratio is not None and cutoff_step_index is not None
):
raise ValueError("Either cutoff_step_ratio or cutoff_step_index should be provided, not both or none.")

if cutoff_step_ratio is not None and (
not isinstance(cutoff_step_ratio, float) or not (0.0 <= cutoff_step_ratio <= 1.0)
):
raise ValueError("cutoff_step_ratio must be a float between 0.0 and 1.0.")

@property
def tensor_inputs(self) -> List[str]:
raise NotImplementedError(f"You need to set the attribute `tensor_inputs` for {self.__class__}")

def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
raise NotImplementedError(f"You need to implement the method `callback_fn` for {self.__class__}")

def __call__(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
return self.callback_fn(pipeline, step_index, timestep, callback_kwargs)


class MultiPipelineCallbacks:
"""
This class is designed to handle multiple pipeline callbacks. It accepts a list of PipelineCallback objects and
provides a unified interface for calling all of them.
"""

def __init__(self, callbacks: List[PipelineCallback]):
self.callbacks = callbacks

@property
def tensor_inputs(self) -> List[str]:
return [input for callback in self.callbacks for input in callback.tensor_inputs]

def __call__(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
"""
Calls all the callbacks in order with the given arguments and returns the final callback_kwargs.
"""
for callback in self.callbacks:
callback_kwargs = callback(pipeline, step_index, timestep, callback_kwargs)

return callback_kwargs


class SDCFGCutoffCallback(PipelineCallback):
"""
Callback function for Stable Diffusion Pipelines. After a certain number of steps (set by `cutoff_step_ratio` or
`cutoff_step_index`), this callback will disable the CFG.

Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step.
"""

tensor_inputs = ["prompt_embeds"]

def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
cutoff_step_ratio = self.config.cutoff_step_ratio
cutoff_step_index = self.config.cutoff_step_index

# Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
cutoff_step = (
cutoff_step_index if cutoff_step_index is not None else int(pipeline.num_timesteps * cutoff_step_ratio)
)

if step_index == cutoff_step:
prompt_embeds = callback_kwargs[self.tensor_inputs[0]]
prompt_embeds = prompt_embeds[-1:]

pipeline._guidance_scale = 0.0

callback_kwargs[self.tensor_inputs[0]] = prompt_embeds
return callback_kwargs


class SDXLCFGCutoffCallback(PipelineCallback):
"""
Callback function for Stable Diffusion XL Pipelines. After a certain number of steps (set by `cutoff_step_ratio` or
`cutoff_step_index`), this callback will disable the CFG.

Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step.
"""

tensor_inputs = ["prompt_embeds", "add_text_embeds", "add_time_ids"]

def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
cutoff_step_ratio = self.config.cutoff_step_ratio
cutoff_step_index = self.config.cutoff_step_index

# Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
cutoff_step = (
cutoff_step_index if cutoff_step_index is not None else int(pipeline.num_timesteps * cutoff_step_ratio)
)

if step_index == cutoff_step:
prompt_embeds = callback_kwargs[self.tensor_inputs[0]]
prompt_embeds = prompt_embeds[-1:]

add_text_embeds = callback_kwargs[self.tensor_inputs[1]]
add_text_embeds = add_text_embeds[-1:]

add_time_ids = callback_kwargs[self.tensor_inputs[2]]
add_time_ids = add_time_ids[-1:]

pipeline._guidance_scale = 0.0

callback_kwargs[self.tensor_inputs[0]] = prompt_embeds
callback_kwargs[self.tensor_inputs[1]] = add_text_embeds
callback_kwargs[self.tensor_inputs[2]] = add_time_ids
return callback_kwargs


class IPAdapterScaleCutoffCallback(PipelineCallback):
"""
Callback function for any pipeline that inherits `IPAdapterMixin`. After a certain number of steps (set by
`cutoff_step_ratio` or `cutoff_step_index`), this callback will set the IP Adapter scale to `0.0`.

Note: This callback mutates the IP Adapter attention processors by setting the scale to 0.0 after the cutoff step.
"""

tensor_inputs = []

def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
cutoff_step_ratio = self.config.cutoff_step_ratio
cutoff_step_index = self.config.cutoff_step_index

# Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
cutoff_step = (
cutoff_step_index if cutoff_step_index is not None else int(pipeline.num_timesteps * cutoff_step_ratio)
)

if step_index == cutoff_step:
pipeline.set_ip_adapter_scale(0.0)
return callback_kwargs
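
As the base class docstring notes, a custom callback only needs to define `tensor_inputs` and `callback_fn`. Below is a hedged illustration (not part of this PR) that lowers the guidance scale to a fixed value after the cutoff step, reusing the same cutoff logic as the callbacks above; the class name is hypothetical.

```python
from typing import Any, Dict

from diffusers.callbacks import PipelineCallback


class GuidanceScaleStepDownCallback(PipelineCallback):
    """Hypothetical callback: lower the guidance scale to 2.0 after the cutoff step.

    The new scale stays above 1.0, so CFG remains enabled and no tensor inputs
    need to be modified.
    """

    tensor_inputs = []

    def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
        cutoff_step_ratio = self.config.cutoff_step_ratio
        cutoff_step_index = self.config.cutoff_step_index

        # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
        cutoff_step = (
            cutoff_step_index if cutoff_step_index is not None else int(pipeline.num_timesteps * cutoff_step_ratio)
        )

        if step_index == cutoff_step:
            pipeline._guidance_scale = 2.0

        return callback_kwargs
```

It can be passed like any official callback, for example `pipeline(prompt, callback_on_step_end=GuidanceScaleStepDownCallback(cutoff_step_ratio=0.5))`.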
8 changes: 7 additions & 1 deletion src/diffusers/pipelines/controlnet/pipeline_controlnet.py
Expand Up @@ -22,6 +22,7 @@
import torch.nn.functional as F
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
Expand Down Expand Up @@ -926,7 +927,9 @@ def __call__(
control_guidance_start: Union[float, List[float]] = 0.0,
control_guidance_end: Union[float, List[float]] = 1.0,
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
**kwargs,
):
Expand Down Expand Up @@ -1055,6 +1058,9 @@ def __call__(
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
)

if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet

# align format for control guidance
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import torch.nn.functional as F
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
Expand Down Expand Up @@ -917,7 +918,9 @@ def __call__(
control_guidance_start: Union[float, List[float]] = 0.0,
control_guidance_end: Union[float, List[float]] = 1.0,
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
**kwargs,
):
Expand Down Expand Up @@ -1040,6 +1043,9 @@ def __call__(
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
)

if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet

# align format for control guidance
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import torch.nn.functional as F
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
Expand Down Expand Up @@ -1134,7 +1135,9 @@ def __call__(
control_guidance_start: Union[float, List[float]] = 0.0,
control_guidance_end: Union[float, List[float]] = 1.0,
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
**kwargs,
):
Expand Down Expand Up @@ -1275,6 +1278,9 @@ def __call__(
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
)

if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet

# align format for control guidance
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
CLIPVisionModelWithProjection,
)

from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import (
FromSingleFileMixin,
Expand Down Expand Up @@ -1178,7 +1179,9 @@ def __call__(
aesthetic_score: float = 6.0,
negative_aesthetic_score: float = 2.5,
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
**kwargs,
):
Expand Down Expand Up @@ -1351,6 +1354,9 @@ def __call__(
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
)

if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet

# align format for control guidance
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

from diffusers.utils.import_utils import is_invisible_watermark_available

from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import (
FromSingleFileMixin,
Expand Down Expand Up @@ -969,7 +970,9 @@ def __call__(
negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
negative_target_size: Optional[Tuple[int, int]] = None,
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
**kwargs,
):
Expand Down Expand Up @@ -1133,6 +1136,9 @@ def __call__(
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
)

if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet

# align format for control guidance
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

from diffusers.utils.import_utils import is_invisible_watermark_available

from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import (
FromSingleFileMixin,
Expand Down Expand Up @@ -1105,7 +1106,9 @@ def __call__(
aesthetic_score: float = 6.0,
negative_aesthetic_score: float = 2.5,
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
**kwargs,
):
Expand Down Expand Up @@ -1288,6 +1291,9 @@ def __call__(
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
)

if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet

# align format for control guidance
Expand Down