Skip to content

Commit

Permalink
added cutoff_step_index
Browse files Browse the repository at this point in the history
  • Loading branch information
asomoza committed May 7, 2024
1 parent def6340 commit 6273ed2
Showing 1 changed file with 30 additions and 5 deletions.
35 changes: 30 additions & 5 deletions src/diffusers/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,17 @@ class PipelineCallback(ConfigMixin):
config_name = CONFIG_NAME

@register_to_config
def __init__(self, cutoff_step_ratio=1.0):
def __init__(self, cutoff_step_ratio=1.0, cutoff_step_index=None):
super().__init__()

if not isinstance(cutoff_step_ratio, float) or not (0.0 <= cutoff_step_ratio <= 1.0):
if (cutoff_step_ratio is None and cutoff_step_index is None) or (
cutoff_step_ratio is not None and cutoff_step_index is not None
):
raise ValueError("Either cutoff_step_ratio or cutoff_step_index should be provided, not both or none.")

if cutoff_step_ratio is not None and (
not isinstance(cutoff_step_ratio, float) or not (0.0 <= cutoff_step_ratio <= 1.0)
):
raise ValueError("cutoff_step_ratio must be a float between 0.0 and 1.0.")

@property
Expand Down Expand Up @@ -45,8 +52,14 @@ class SDCFGCutoffCallback(PipelineCallback):

def callback_fn(self, pipeline, step_index, timestep, callback_kwargs):
cutoff_step_ratio = self.config.cutoff_step_ratio
cutoff_step_index = self.config.cutoff_step_index

# Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
cutoff_step = (
cutoff_step_index if cutoff_step_index is not None else int(pipeline.num_timesteps * cutoff_step_ratio)
)

if step_index == int(pipeline.num_timesteps * cutoff_step_ratio):
if step_index == cutoff_step:
prompt_embeds = callback_kwargs[self.tensor_inputs[0]]
prompt_embeds = prompt_embeds[-1:]

Expand All @@ -61,8 +74,14 @@ class SDXLCFGCutoffCallback(PipelineCallback):

def callback_fn(self, pipeline, step_index, timestep, callback_kwargs):
cutoff_step_ratio = self.config.cutoff_step_ratio
cutoff_step_index = self.config.cutoff_step_index

if step_index == int(pipeline.num_timesteps * cutoff_step_ratio):
# Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
cutoff_step = (
cutoff_step_index if cutoff_step_index is not None else int(pipeline.num_timesteps * cutoff_step_ratio)
)

if step_index == cutoff_step:
prompt_embeds = callback_kwargs[self.tensor_inputs[0]]
prompt_embeds = prompt_embeds[-1:]

Expand All @@ -85,7 +104,13 @@ class IPAdapterScaleCutoffCallback(PipelineCallback):

def callback_fn(self, pipeline, step_index, timestep, callback_kwargs):
cutoff_step_ratio = self.config.cutoff_step_ratio
cutoff_step_index = self.config.cutoff_step_index

# Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
cutoff_step = (
cutoff_step_index if cutoff_step_index is not None else int(pipeline.num_timesteps * cutoff_step_ratio)
)

if step_index == int(pipeline.num_timesteps * cutoff_step_ratio):
if step_index == cutoff_step:
pipeline.set_ip_adapter_scale(0.0)
return callback_kwargs

0 comments on commit 6273ed2

Please sign in to comment.