Skip to content

Commit

Permalink
Add prompt scheduling callback to community scripts (#9718)
Browse files Browse the repository at this point in the history
  • Loading branch information
hlky authored Oct 19, 2024
1 parent 5d3e7bd commit 89565e9
Showing 1 changed file with 84 additions and 1 deletion.
85 changes: 84 additions & 1 deletion examples/community/README_community_scripts.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ If a community script doesn't work as expected, please open an issue and ping th
|:--------------------------------------------------------------------------------------------------------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------:|
| Using IP-Adapter with negative noise | Using negative noise with IP-adapter to better control the generation (see the [original post](https://github.com/huggingface/diffusers/discussions/7167) on the forum for more details) | [IP-Adapter Negative Noise](#ip-adapter-negative-noise) | | [Álvaro Somoza](https://github.com/asomoza)|
| asymmetric tiling |configure seamless image tiling independently for the X and Y axes | [Asymmetric Tiling](#asymmetric-tiling ) | | [alexisrolland](https://github.com/alexisrolland)|
| Prompt scheduling callback |Allows changing prompts during a generation | [Prompt Scheduling](#prompt-scheduling ) | | [hlky](https://github.com/hlky)|


## Example usages
Expand Down Expand Up @@ -229,4 +230,86 @@ seamless_tiling(pipeline=pipeline, x_axis=False, y_axis=False)

torch.cuda.empty_cache()
image.save('image.png')
```
```

### Prompt Scheduling callback

Prompt scheduling callback allows changing prompts during a generation, like [prompt editing in A1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#prompt-editing)

```python
from diffusers import StableDiffusionPipeline
from diffusers.callbacks import PipelineCallback, MultiPipelineCallbacks
from diffusers.configuration_utils import register_to_config
import torch
from typing import Any, Dict, Optional


pipeline: StableDiffusionPipeline = StableDiffusionPipeline.from_pretrained(
"stable-diffusion-v1-5/stable-diffusion-v1-5",
torch_dtype=torch.float16,
variant="fp16",
use_safetensors=True,
).to("cuda")
pipeline.safety_checker = None
pipeline.requires_safety_checker = False


class SDPromptScheduleCallback(PipelineCallback):
@register_to_config
def __init__(
self,
prompt: str,
negative_prompt: Optional[str] = None,
num_images_per_prompt: int = 1,
cutoff_step_ratio=1.0,
cutoff_step_index=None,
):
super().__init__(
cutoff_step_ratio=cutoff_step_ratio, cutoff_step_index=cutoff_step_index
)

tensor_inputs = ["prompt_embeds"]

def callback_fn(
self, pipeline, step_index, timestep, callback_kwargs
) -> Dict[str, Any]:
cutoff_step_ratio = self.config.cutoff_step_ratio
cutoff_step_index = self.config.cutoff_step_index

# Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
cutoff_step = (
cutoff_step_index
if cutoff_step_index is not None
else int(pipeline.num_timesteps * cutoff_step_ratio)
)

if step_index == cutoff_step:
prompt_embeds, negative_prompt_embeds = pipeline.encode_prompt(
prompt=self.config.prompt,
negative_prompt=self.config.negative_prompt,
device=pipeline._execution_device,
num_images_per_prompt=self.config.num_images_per_prompt,
do_classifier_free_guidance=pipeline.do_classifier_free_guidance,
)
if pipeline.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
callback_kwargs[self.tensor_inputs[0]] = prompt_embeds
return callback_kwargs

callback = MultiPipelineCallbacks(
[
SDPromptScheduleCallback(
prompt="Official portrait of a smiling world war ii general, female, cheerful, happy, detailed face, 20th century, highly detailed, cinematic lighting, digital art painting by Greg Rutkowski",
negative_prompt="Deformed, ugly, bad anatomy",
cutoff_step_ratio=0.25,
)
]
)

image = pipeline(
prompt="Official portrait of a smiling world war ii general, male, cheerful, happy, detailed face, 20th century, highly detailed, cinematic lighting, digital art painting by Greg Rutkowski",
negative_prompt="Deformed, ugly, bad anatomy",
callback_on_step_end=callback,
callback_on_step_end_tensor_inputs=["prompt_embeds"],
).images[0]
```

0 comments on commit 89565e9

Please sign in to comment.