[docs] AutoPipeline (#7714)
* autopipeline

* edits

* feedback

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
stevhliu and sayakpaul authored Apr 22, 2024
1 parent a9dd860 commit 33b363e
Showing 3 changed files with 56 additions and 132 deletions.
38 changes: 3 additions & 35 deletions docs/source/en/api/pipelines/auto_pipeline.md
@@ -12,42 +12,10 @@ specific language governing permissions and limitations under the License.

# AutoPipeline

`AutoPipeline` is designed to:

1. make it easy for you to load a checkpoint for a task without knowing the specific pipeline class to use
2. use multiple pipelines in your workflow

Based on the task, the `AutoPipeline` class automatically retrieves the relevant pipeline given the name or path to the pretrained weights with the `from_pretrained()` method.

To seamlessly switch between tasks with the same checkpoint without reallocating additional memory, use the `from_pipe()` method to transfer the components from the original pipeline to the new one.

```py
from diffusers import AutoPipelineForText2Image
import torch

pipeline = AutoPipelineForText2Image.from_pretrained(
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True
).to("cuda")
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"

image = pipeline(prompt, num_inference_steps=25).images[0]
```

<Tip>

Check out the [AutoPipeline](../../tutorials/autopipeline) tutorial to learn how to use this API!

</Tip>

`AutoPipeline` supports text-to-image, image-to-image, and inpainting for the following diffusion models:

- [Stable Diffusion](./stable_diffusion/overview)
- [ControlNet](./controlnet)
- [Stable Diffusion XL (SDXL)](./stable_diffusion/stable_diffusion_xl)
- [DeepFloyd IF](./deepfloyd_if)
- [Kandinsky 2.1](./kandinsky)
- [Kandinsky 2.2](./kandinsky_v22)
The `AutoPipeline` is designed to make it easy to load a checkpoint for a task without needing to know the specific pipeline class. Based on the task, the `AutoPipeline` automatically retrieves the correct pipeline class from the checkpoint's `model_index.json` file.

> [!TIP]
> Check out the [AutoPipeline](../../tutorials/autopipeline) tutorial to learn how to use this API!
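
As a quick illustration, here is a minimal sketch of that flow (the checkpoint below is the [dreamlike-art/dreamlike-photoreal-2.0](https://hf.co/dreamlike-art/dreamlike-photoreal-2.0) model used in the tutorial; any supported checkpoint works the same way):

```py
from diffusers import AutoPipelineForText2Image
import torch

# AutoPipelineForText2Image reads the checkpoint's model_index.json, finds a
# "stable-diffusion" pipeline class, and loads the matching StableDiffusionPipeline.
pipeline = AutoPipelineForText2Image.from_pretrained(
    "dreamlike-art/dreamlike-photoreal-2.0", torch_dtype=torch.float16, use_safetensors=True
).to("cuda")

image = pipeline("Astronaut in a jungle, cold color palette, muted colors, detailed, 8k").images[0]
```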
## AutoPipelineForText2Image

145 changes: 49 additions & 96 deletions docs/source/en/tutorials/autopipeline.md
@@ -12,75 +12,74 @@ specific language governing permissions and limitations under the License.

# AutoPipeline

🤗 Diffusers is able to complete many different tasks, and you can often reuse the same pretrained weights for multiple tasks such as text-to-image, image-to-image, and inpainting. If you're new to the library and diffusion models though, it may be difficult to know which pipeline to use for a task. For example, if you're using the [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) checkpoint for text-to-image, you might not know that you could also use it for image-to-image and inpainting by loading the checkpoint with the [`StableDiffusionImg2ImgPipeline`] and [`StableDiffusionInpaintPipeline`] classes respectively.
Diffusers provides many pipelines for basic tasks like generating images, videos, audio, and inpainting. On top of these, there are specialized pipelines for adapters and features like upscaling, super-resolution, and more. Different pipeline classes can even use the same checkpoint because they share the same pretrained model! With so many different pipelines, it can be overwhelming to know which pipeline class to use.

The `AutoPipeline` class is designed to simplify the variety of pipelines in 🤗 Diffusers. It is a generic, *task-first* pipeline that lets you focus on the task. The `AutoPipeline` automatically detects the correct pipeline class to use, which makes it easier to load a checkpoint for a task without knowing the specific pipeline class name.
The [AutoPipeline](../api/pipelines/auto_pipeline) class is designed to simplify the variety of pipelines in Diffusers. It is a generic *task-first* pipeline that lets you focus on a task ([`AutoPipelineForText2Image`], [`AutoPipelineForImage2Image`], and [`AutoPipelineForInpainting`]) without needing to know the specific pipeline class. The [AutoPipeline](../api/pipelines/auto_pipeline) automatically detects the correct pipeline class to use.

<Tip>
For example, let's use the [dreamlike-art/dreamlike-photoreal-2.0](https://hf.co/dreamlike-art/dreamlike-photoreal-2.0) checkpoint.

Take a look at the [AutoPipeline](../api/pipelines/auto_pipeline) reference to see which tasks are supported. Currently, it supports text-to-image, image-to-image, and inpainting.
Under the hood, [AutoPipeline](../api/pipelines/auto_pipeline):

</Tip>
1. Detects a `"stable-diffusion"` class from the [model_index.json](https://hf.co/dreamlike-art/dreamlike-photoreal-2.0/blob/main/model_index.json) file.
2. Depending on the task you're interested in, it loads the [`StableDiffusionPipeline`], [`StableDiffusionImg2ImgPipeline`], or [`StableDiffusionInpaintPipeline`]. Any parameter (`strength`, `num_inference_steps`, etc.) you would pass to these specific pipelines can also be passed to the [AutoPipeline](../api/pipelines/auto_pipeline).
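
For example, you can confirm which concrete classes are resolved by printing the pipeline types. This is a small sketch assuming the same checkpoint; the classes follow the mapping described above.

```py
from diffusers import AutoPipelineForText2Image, AutoPipelineForImage2Image
import torch

# For a "stable-diffusion" checkpoint, the text-to-image task resolves to StableDiffusionPipeline.
pipeline = AutoPipelineForText2Image.from_pretrained(
    "dreamlike-art/dreamlike-photoreal-2.0", torch_dtype=torch.float16, use_safetensors=True
)
print(type(pipeline))
# <class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline'>

# Reusing the same components for image-to-image resolves to StableDiffusionImg2ImgPipeline.
print(type(AutoPipelineForImage2Image.from_pipe(pipeline)))
# <class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline'>
```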

This tutorial shows you how to use an `AutoPipeline` to automatically infer the pipeline class to load for a specific task, given the pretrained weights.

## Choose an AutoPipeline for your task

Start by picking a checkpoint. For example, if you're interested in text-to-image with the [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) checkpoint, use [`AutoPipelineForText2Image`]:
<hfoptions id="autopipeline">
<hfoption id="text-to-image">

```py
from diffusers import AutoPipelineForText2Image
import torch

pipe_txt2img = AutoPipelineForText2Image.from_pretrained(
    "dreamlike-art/dreamlike-photoreal-2.0", torch_dtype=torch.float16, use_safetensors=True
).to("cuda")

prompt = "cinematic photo of Godzilla eating sushi with a cat in an izakaya, 35mm photograph, film, professional, 4k, highly detailed"
generator = torch.Generator(device="cpu").manual_seed(37)
image = pipe_txt2img(prompt, generator=generator).images[0]
image
```

<div class="flex justify-center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/autopipeline-text2img.png" alt="generated image of peasant fighting dragon in wood cutting style"/>
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/autopipeline-text2img.png"/>
</div>

Under the hood, [`AutoPipelineForText2Image`]:

1. automatically detects a `"stable-diffusion"` class from the [`model_index.json`](https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/model_index.json) file
2. loads the corresponding text-to-image [`StableDiffusionPipeline`] based on the `"stable-diffusion"` class name

Likewise, for image-to-image, [`AutoPipelineForImage2Image`] detects a `"stable-diffusion"` checkpoint from the `model_index.json` file and it'll load the corresponding [`StableDiffusionImg2ImgPipeline`] behind the scenes. You can also pass any additional arguments specific to the pipeline class such as `strength`, which determines the amount of noise or variation added to an input image:
</hfoption>
<hfoption id="image-to-image">

```py
from diffusers import AutoPipelineForImage2Image
from diffusers.utils import load_image
import torch

pipe_img2img = AutoPipelineForImage2Image.from_pretrained(
    "dreamlike-art/dreamlike-photoreal-2.0", torch_dtype=torch.float16, use_safetensors=True
).to("cuda")

init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/autopipeline-text2img.png")

prompt = "cinematic photo of Godzilla eating burgers with a cat in a fast food restaurant, 35mm photograph, film, professional, 4k, highly detailed"
generator = torch.Generator(device="cpu").manual_seed(53)
image = pipe_img2img(prompt, image=init_image, generator=generator).images[0]
image
```
Notice how the [dreamlike-art/dreamlike-photoreal-2.0](https://hf.co/dreamlike-art/dreamlike-photoreal-2.0) checkpoint is used for both text-to-image and image-to-image tasks? To save memory and avoid loading the checkpoint twice, use the [`~DiffusionPipeline.from_pipe`] method.

```py
pipe_img2img = AutoPipelineForImage2Image.from_pipe(pipe_txt2img).to("cuda")
image = pipe_img2img(prompt, image=init_image, generator=generator).images[0]
image
```

You can learn more about the [`~DiffusionPipeline.from_pipe`] method in the [Reuse a pipeline](../using-diffusers/loading#reuse-a-pipeline) guide.

<div class="flex justify-center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/autopipeline-img2img.png" alt="generated image of a vermeer portrait of a dog wearing a pearl earring"/>
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/autopipeline-img2img.png"/>
</div>

And if you want to do inpainting, then [`AutoPipelineForInpainting`] loads the underlying [`StableDiffusionInpaintPipeline`] class in the same way:
</hfoption>
<hfoption id="inpainting">

```py
from diffusers import AutoPipelineForInpainting
from diffusers.utils import load_image
import torch

pipeline = AutoPipelineForInpainting.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, use_safetensors=True
).to("cuda")

init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/autopipeline-img2img.png")
mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/autopipeline-mask.png")

prompt = "cinematic photo of an owl, 35mm photograph, film, professional, 4k, highly detailed"
generator = torch.Generator(device="cpu").manual_seed(38)
image = pipeline(prompt, image=init_image, mask_image=mask_image, generator=generator, strength=0.4).images[0]
image
```

<div class="flex justify-center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/autopipeline-inpaint.png" alt="generated image of a tiger sitting on a bench"/>
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/autopipeline-inpaint.png"/>
</div>

</hfoption>
</hfoptions>

## Unsupported checkpoints

The [AutoPipeline](../api/pipelines/auto_pipeline) supports [Stable Diffusion](../api/pipelines/stable_diffusion/overview), [Stable Diffusion XL](../api/pipelines/stable_diffusion/stable_diffusion_xl), [ControlNet](../api/pipelines/controlnet), [Kandinsky 2.1](../api/pipelines/kandinsky), [Kandinsky 2.2](../api/pipelines/kandinsky_v22), and [DeepFloyd IF](../api/pipelines/deepfloyd_if) checkpoints.

If you try to load an unsupported checkpoint, you'll get an error.

```py
from diffusers import AutoPipelineForImage2Image
pipeline = AutoPipelineForImage2Image.from_pretrained(
    ...  # a Shap-E image-to-image checkpoint; its arguments are collapsed in this diff view
)
"ValueError: AutoPipeline can't find a pipeline linked to ShapEImg2ImgPipeline for None"
```

## Use multiple pipelines

For some workflows, or if you're loading many pipelines, it is more memory-efficient to reuse the same components from a checkpoint instead of reloading them, which would unnecessarily consume additional memory. For example, if you're using a checkpoint for text-to-image and you want to use it again for image-to-image, use the [`~AutoPipelineForImage2Image.from_pipe`] method. This method creates a new pipeline from the components of a previously loaded pipeline at no additional memory cost.

The [`~AutoPipelineForImage2Image.from_pipe`] method detects the original pipeline class and maps it to the new pipeline class corresponding to the task you want to do. For example, if you load a `"stable-diffusion"` class pipeline for text-to-image:

```py
from diffusers import AutoPipelineForText2Image, AutoPipelineForImage2Image
import torch

pipeline_text2img = AutoPipelineForText2Image.from_pretrained(
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True
)
print(type(pipeline_text2img))
"<class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline'>"
```

Then [`~AutoPipelineForImage2Image.from_pipe`] maps the original `"stable-diffusion"` pipeline class to [`StableDiffusionImg2ImgPipeline`]:

```py
pipeline_img2img = AutoPipelineForImage2Image.from_pipe(pipeline_text2img)
print(type(pipeline_img2img))
"<class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline'>"
```

If you passed an optional argument (like disabling the safety checker) to the original pipeline, this argument is also passed on to the new pipeline:

```py
from diffusers import AutoPipelineForText2Image, AutoPipelineForImage2Image
import torch

pipeline_text2img = AutoPipelineForText2Image.from_pretrained(
"runwayml/stable-diffusion-v1-5",
torch_dtype=torch.float16,
use_safetensors=True,
requires_safety_checker=False,
).to("cuda")

pipeline_img2img = AutoPipelineForImage2Image.from_pipe(pipeline_text2img)
print(pipeline_img2img.config.requires_safety_checker)
"False"
```

You can overwrite any of the arguments and even the configuration from the original pipeline if you want to change the behavior of the new pipeline. For example, to turn the safety checker back on and add the `strength` argument:

```py
pipeline_img2img = AutoPipelineForImage2Image.from_pipe(pipeline_text2img, requires_safety_checker=True, strength=0.3)
print(pipeline_img2img.config.requires_safety_checker)
"True"
```
5 changes: 4 additions & 1 deletion docs/source/en/using-diffusers/loading.md
@@ -154,7 +154,10 @@ When you load multiple pipelines that share the same model components, it makes
1. You generated an image with the [`StableDiffusionPipeline`] but you want to improve its quality with the [`StableDiffusionSAGPipeline`]. Both of these pipelines share the same pretrained model, so it'd be a waste of memory to load the same model twice.
2. You want to add a model component, like a [`MotionAdapter`](../api/pipelines/animatediff#animatediffpipeline), to [`AnimateDiffPipeline`] which was instantiated from an existing [`StableDiffusionPipeline`]. Again, both pipelines share the same pretrained model, so it'd be a waste of memory to load an entirely new pipeline again.

With the [`DiffusionPipeline.from_pipe`] API, you can switch between multiple pipelines to take advantage of their different features without increasing memory usage. It is similar to turning on and off a feature in your pipeline. To switch between tasks, use the [`~DiffusionPipeline.from_pipe`] method with the [`AutoPipeline`](../api/pipelines/auto_pipeline) class, which automatically identifies the pipeline class based on the task (learn more in the [AutoPipeline](../tutorials/autopipeline) tutorial).
With the [`DiffusionPipeline.from_pipe`] API, you can switch between multiple pipelines to take advantage of their different features without increasing memory usage. It is similar to turning on and off a feature in your pipeline.

> [!TIP]
> To switch between tasks (rather than features), use the [`~DiffusionPipeline.from_pipe`] method with the [AutoPipeline](../api/pipelines/auto_pipeline) class, which automatically identifies the pipeline class based on the task (learn more in the [AutoPipeline](../tutorials/autopipeline) tutorial).
Let's start with a [`StableDiffusionPipeline`] and then reuse the loaded model components to create a [`StableDiffusionSAGPipeline`] to increase generation quality. You'll use the [`StableDiffusionPipeline`] with an [IP-Adapter](./ip_adapter) to generate a bear eating pizza.

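A compact sketch of that reuse pattern is shown below; the checkpoint name and IP-Adapter weights are assumptions for illustration rather than the exact values from the full example in the guide.

```py
from diffusers import StableDiffusionPipeline, StableDiffusionSAGPipeline
import torch

# Load a text-to-image pipeline and attach an IP-Adapter (assumed checkpoint and adapter weights).
pipe_sd = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True
).to("cuda")
pipe_sd.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

# Reuse the already-loaded components in a SAG pipeline; the shared models are not copied,
# so memory usage stays roughly the same as with a single pipeline.
pipe_sag = StableDiffusionSAGPipeline.from_pipe(pipe_sd).to("cuda")
```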
