Skip to content

Commit

Permalink
flux controlnet control_guidance_start and control_guidance_end imple…
Browse files Browse the repository at this point in the history
…ment (#9571)

* flux controlnet control_guidance_start and control_guidance_end implement

* minor fix - added docstrings, consistent controlnet scale flux and SD3
  • Loading branch information
ighoshsubho authored Oct 10, 2024
1 parent e16fd93 commit 38a3e4d
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 6 deletions.
42 changes: 39 additions & 3 deletions src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@
>>> image = pipe(
... prompt,
... control_image=control_image,
... controlnet_conditioning_scale=0.6,
... control_guidance_start=0.2,
... control_guidance_end=0.8,
... controlnet_conditioning_scale=1.0,
... num_inference_steps=28,
... guidance_scale=3.5,
... ).images[0]
Expand Down Expand Up @@ -572,6 +574,8 @@ def __call__(
num_inference_steps: int = 28,
timesteps: List[int] = None,
guidance_scale: float = 7.0,
control_guidance_start: Union[float, List[float]] = 0.0,
control_guidance_end: Union[float, List[float]] = 1.0,
control_image: PipelineImageInput = None,
control_mode: Optional[Union[int, List[int]]] = None,
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
Expand Down Expand Up @@ -614,6 +618,10 @@ def __call__(
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
The percentage of total steps at which the ControlNet starts applying.
control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
The percentage of total steps at which the ControlNet stops applying.
control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
`List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
Expand Down Expand Up @@ -674,6 +682,17 @@ def __call__(
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor

if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
mult = len(self.controlnet.nets) if isinstance(self.controlnet, FluxMultiControlNetModel) else 1
control_guidance_start, control_guidance_end = (
mult * [control_guidance_start],
mult * [control_guidance_end],
)

# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
Expand Down Expand Up @@ -839,7 +858,16 @@ def __call__(
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)

# 6. Denoising loop
# 6. Create tensor stating which controlnets to keep
controlnet_keep = []
for i in range(len(timesteps)):
keeps = [
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
for s, e in zip(control_guidance_start, control_guidance_end)
]
controlnet_keep.append(keeps[0] if isinstance(self.controlnet, FluxControlNetModel) else keeps)

# 7. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
Expand All @@ -856,12 +884,20 @@ def __call__(
guidance = torch.tensor([guidance_scale], device=device) if use_guidance else None
guidance = guidance.expand(latents.shape[0]) if guidance is not None else None

if isinstance(controlnet_keep[i], list):
cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
else:
controlnet_cond_scale = controlnet_conditioning_scale
if isinstance(controlnet_cond_scale, list):
controlnet_cond_scale = controlnet_cond_scale[0]
cond_scale = controlnet_cond_scale * controlnet_keep[i]

# controlnet
controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
hidden_states=latents,
controlnet_cond=control_image,
controlnet_mode=control_mode,
conditioning_scale=controlnet_conditioning_scale,
conditioning_scale=cond_scale,
timestep=timestep / 1000,
guidance=guidance,
pooled_projections=pooled_prompt_embeds,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,9 @@
... prompt,
... image=init_image,
... control_image=control_image,
... controlnet_conditioning_scale=0.6,
... control_guidance_start=0.2,
... control_guidance_end=0.8,
... controlnet_conditioning_scale=1.0,
... strength=0.7,
... num_inference_steps=2,
... guidance_scale=3.5,
Expand Down Expand Up @@ -631,6 +633,8 @@ def __call__(
num_inference_steps: int = 28,
timesteps: List[int] = None,
guidance_scale: float = 7.0,
control_guidance_start: Union[float, List[float]] = 0.0,
control_guidance_end: Union[float, List[float]] = 1.0,
control_mode: Optional[Union[int, List[int]]] = None,
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
num_images_per_prompt: Optional[int] = 1,
Expand Down Expand Up @@ -710,6 +714,17 @@ def __call__(
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor

if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
mult = len(self.controlnet.nets) if isinstance(self.controlnet, FluxMultiControlNetModel) else 1
control_guidance_start, control_guidance_end = (
mult * [control_guidance_start],
mult * [control_guidance_end],
)

self.check_inputs(
prompt,
prompt_2,
Expand Down Expand Up @@ -862,6 +877,14 @@ def __call__(
latents,
)

controlnet_keep = []
for i in range(len(timesteps)):
keeps = [
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
for s, e in zip(control_guidance_start, control_guidance_end)
]
controlnet_keep.append(keeps[0] if isinstance(self.controlnet, FluxControlNetModel) else keeps)

num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)

Expand All @@ -877,11 +900,19 @@ def __call__(
)
guidance = guidance.expand(latents.shape[0]) if guidance is not None else None

if isinstance(controlnet_keep[i], list):
cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
else:
controlnet_cond_scale = controlnet_conditioning_scale
if isinstance(controlnet_cond_scale, list):
controlnet_cond_scale = controlnet_cond_scale[0]
cond_scale = controlnet_cond_scale * controlnet_keep[i]

controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
hidden_states=latents,
controlnet_cond=control_image,
controlnet_mode=control_mode,
conditioning_scale=controlnet_conditioning_scale,
conditioning_scale=cond_scale,
timestep=timestep / 1000,
guidance=guidance,
pooled_projections=pooled_prompt_embeds,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@
... image=init_image,
... mask_image=mask_image,
... control_image=control_image,
... control_guidance_start=0.2,
... control_guidance_end=0.8,
... controlnet_conditioning_scale=0.7,
... strength=0.7,
... num_inference_steps=28,
Expand Down Expand Up @@ -737,6 +739,8 @@ def __call__(
timesteps: List[int] = None,
num_inference_steps: int = 28,
guidance_scale: float = 7.0,
control_guidance_start: Union[float, List[float]] = 0.0,
control_guidance_end: Union[float, List[float]] = 1.0,
control_mode: Optional[Union[int, List[int]]] = None,
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
num_images_per_prompt: Optional[int] = 1,
Expand Down Expand Up @@ -783,6 +787,10 @@ def __call__(
Custom timesteps to use for the denoising process.
guidance_scale (`float`, *optional*, defaults to 7.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
The percentage of total steps at which the ControlNet starts applying.
control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
The percentage of total steps at which the ControlNet stops applying.
control_mode (`int` or `List[int]`, *optional*):
The mode for the ControlNet. If multiple ControlNets are used, this should be a list.
controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
Expand Down Expand Up @@ -826,6 +834,17 @@ def __call__(
global_height = height
global_width = width

if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
mult = len(self.controlnet.nets) if isinstance(self.controlnet, FluxMultiControlNetModel) else 1
control_guidance_start, control_guidance_end = (
mult * [control_guidance_start],
mult * [control_guidance_end],
)

# 1. Check inputs
self.check_inputs(
prompt,
Expand Down Expand Up @@ -1031,6 +1050,14 @@ def __call__(
generator,
)

controlnet_keep = []
for i in range(len(timesteps)):
keeps = [
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
for s, e in zip(control_guidance_start, control_guidance_end)
]
controlnet_keep.append(keeps[0] if isinstance(self.controlnet, FluxControlNetModel) else keeps)

# 9. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
Expand All @@ -1049,11 +1076,19 @@ def __call__(
else:
guidance = None

if isinstance(controlnet_keep[i], list):
cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
else:
controlnet_cond_scale = controlnet_conditioning_scale
if isinstance(controlnet_cond_scale, list):
controlnet_cond_scale = controlnet_cond_scale[0]
cond_scale = controlnet_cond_scale * controlnet_keep[i]

controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
hidden_states=latents,
controlnet_cond=control_image,
controlnet_mode=control_mode,
conditioning_scale=controlnet_conditioning_scale,
conditioning_scale=cond_scale,
timestep=timestep / 1000,
guidance=guidance,
pooled_projections=pooled_prompt_embeds,
Expand Down

0 comments on commit 38a3e4d

Please sign in to comment.