diff --git a/src/diffusers/callbacks.py b/src/diffusers/callbacks.py index 3050f565b3cc..8b32c03edb79 100644 --- a/src/diffusers/callbacks.py +++ b/src/diffusers/callbacks.py @@ -8,10 +8,17 @@ class PipelineCallback(ConfigMixin): config_name = CONFIG_NAME @register_to_config - def __init__(self, cutoff_step_ratio=1.0): + def __init__(self, cutoff_step_ratio=1.0, cutoff_step_index=None): super().__init__() - if not isinstance(cutoff_step_ratio, float) or not (0.0 <= cutoff_step_ratio <= 1.0): + if (cutoff_step_ratio is None and cutoff_step_index is None) or ( + cutoff_step_ratio is not None and cutoff_step_index is not None + ): + raise ValueError("Either cutoff_step_ratio or cutoff_step_index should be provided, not both or none.") + + if cutoff_step_ratio is not None and ( + not isinstance(cutoff_step_ratio, float) or not (0.0 <= cutoff_step_ratio <= 1.0) + ): raise ValueError("cutoff_step_ratio must be a float between 0.0 and 1.0.") @property @@ -45,8 +52,14 @@ class SDCFGCutoffCallback(PipelineCallback): def callback_fn(self, pipeline, step_index, timestep, callback_kwargs): cutoff_step_ratio = self.config.cutoff_step_ratio + cutoff_step_index = self.config.cutoff_step_index + + # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio + cutoff_step = ( + cutoff_step_index if cutoff_step_index is not None else int(pipeline.num_timesteps * cutoff_step_ratio) + ) - if step_index == int(pipeline.num_timesteps * cutoff_step_ratio): + if step_index == cutoff_step: prompt_embeds = callback_kwargs[self.tensor_inputs[0]] prompt_embeds = prompt_embeds[-1:] @@ -61,8 +74,14 @@ class SDXLCFGCutoffCallback(PipelineCallback): def callback_fn(self, pipeline, step_index, timestep, callback_kwargs): cutoff_step_ratio = self.config.cutoff_step_ratio + cutoff_step_index = self.config.cutoff_step_index - if step_index == int(pipeline.num_timesteps * cutoff_step_ratio): + # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio + cutoff_step = ( + cutoff_step_index if cutoff_step_index is not None else int(pipeline.num_timesteps * cutoff_step_ratio) + ) + + if step_index == cutoff_step: prompt_embeds = callback_kwargs[self.tensor_inputs[0]] prompt_embeds = prompt_embeds[-1:] @@ -85,7 +104,13 @@ class IPAdapterScaleCutoffCallback(PipelineCallback): def callback_fn(self, pipeline, step_index, timestep, callback_kwargs): cutoff_step_ratio = self.config.cutoff_step_ratio + cutoff_step_index = self.config.cutoff_step_index + + # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio + cutoff_step = ( + cutoff_step_index if cutoff_step_index is not None else int(pipeline.num_timesteps * cutoff_step_ratio) + ) - if step_index == int(pipeline.num_timesteps * cutoff_step_ratio): + if step_index == cutoff_step: pipeline.set_ip_adapter_scale(0.0) return callback_kwargs