Skip to content

Commit

Permalink
Multi-image masking for single IP Adapter (#7499)
Browse files Browse the repository at this point in the history
* Support multiimage masking

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
  • Loading branch information
3 people authored Apr 9, 2024
1 parent a341b53 commit a0cf607
Show file tree
Hide file tree
Showing 3 changed files with 208 additions and 50 deletions.
178 changes: 128 additions & 50 deletions src/diffusers/models/attention_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
import inspect
from importlib import import_module
from typing import Callable, Optional, Union
from typing import Callable, List, Optional, Union

import torch
import torch.nn.functional as F
Expand Down Expand Up @@ -2195,42 +2195,78 @@ def __call__(
hidden_states = attn.batch_to_head_dim(hidden_states)

if ip_adapter_masks is not None:
if not isinstance(ip_adapter_masks, torch.Tensor) or ip_adapter_masks.ndim != 4:
if not isinstance(ip_adapter_masks, List):
# for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
raise ValueError(
" ip_adapter_mask should be a tensor with shape [num_ip_adapter, 1, height, width]."
" Please use `IPAdapterMaskProcessor` to preprocess your mask"
)
if len(ip_adapter_masks) != len(self.scale):
raise ValueError(
f"Number of ip_adapter_masks ({len(ip_adapter_masks)}) must match number of IP-Adapters ({len(self.scale)})"
f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match "
f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states "
f"({len(ip_hidden_states)})"
)
else:
for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)):
if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
raise ValueError(
"Each element of the ip_adapter_masks array should be a tensor with shape "
"[1, num_images_for_ip_adapter, height, width]."
" Please use `IPAdapterMaskProcessor` to preprocess your mask"
)
if mask.shape[1] != ip_state.shape[1]:
raise ValueError(
f"Number of masks ({mask.shape[1]}) does not match "
f"number of ip images ({ip_state.shape[1]}) at index {index}"
)
if isinstance(scale, list) and not len(scale) == mask.shape[1]:
raise ValueError(
f"Number of masks ({mask.shape[1]}) does not match "
f"number of scales ({len(scale)}) at index {index}"
)
else:
ip_adapter_masks = [None] * len(self.scale)

# for ip-adapter
for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
):
ip_key = to_k_ip(current_ip_hidden_states)
ip_value = to_v_ip(current_ip_hidden_states)

ip_key = attn.head_to_batch_dim(ip_key)
ip_value = attn.head_to_batch_dim(ip_value)
if mask is not None:
if not isinstance(scale, list):
scale = [scale]

current_num_images = mask.shape[1]
for i in range(current_num_images):
ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])

ip_key = attn.head_to_batch_dim(ip_key)
ip_value = attn.head_to_batch_dim(ip_value)

ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
_current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
_current_ip_hidden_states = attn.batch_to_head_dim(_current_ip_hidden_states)

mask_downsample = IPAdapterMaskProcessor.downsample(
mask[:, i, :, :],
batch_size,
_current_ip_hidden_states.shape[1],
_current_ip_hidden_states.shape[2],
)

ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)
mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)

if mask is not None:
mask_downsample = IPAdapterMaskProcessor.downsample(
mask, batch_size, current_ip_hidden_states.shape[1], current_ip_hidden_states.shape[2]
)
hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
else:
ip_key = to_k_ip(current_ip_hidden_states)
ip_value = to_v_ip(current_ip_hidden_states)

mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
ip_key = attn.head_to_batch_dim(ip_key)
ip_value = attn.head_to_batch_dim(ip_value)

current_ip_hidden_states = current_ip_hidden_states * mask_downsample
ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)

hidden_states = hidden_states + scale * current_ip_hidden_states
hidden_states = hidden_states + scale * current_ip_hidden_states

# linear proj
hidden_states = attn.to_out[0](hidden_states)
Expand Down Expand Up @@ -2369,49 +2405,91 @@ def __call__(
hidden_states = hidden_states.to(query.dtype)

if ip_adapter_masks is not None:
if not isinstance(ip_adapter_masks, torch.Tensor) or ip_adapter_masks.ndim != 4:
raise ValueError(
" ip_adapter_mask should be a tensor with shape [num_ip_adapter, 1, height, width]."
" Please use `IPAdapterMaskProcessor` to preprocess your mask"
)
if len(ip_adapter_masks) != len(self.scale):
if not isinstance(ip_adapter_masks, List):
# for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
raise ValueError(
f"Number of ip_adapter_masks ({len(ip_adapter_masks)}) must match number of IP-Adapters ({len(self.scale)})"
f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match "
f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states "
f"({len(ip_hidden_states)})"
)
else:
for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)):
if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
raise ValueError(
"Each element of the ip_adapter_masks array should be a tensor with shape "
"[1, num_images_for_ip_adapter, height, width]."
" Please use `IPAdapterMaskProcessor` to preprocess your mask"
)
if mask.shape[1] != ip_state.shape[1]:
raise ValueError(
f"Number of masks ({mask.shape[1]}) does not match "
f"number of ip images ({ip_state.shape[1]}) at index {index}"
)
if isinstance(scale, list) and not len(scale) == mask.shape[1]:
raise ValueError(
f"Number of masks ({mask.shape[1]}) does not match "
f"number of scales ({len(scale)}) at index {index}"
)
else:
ip_adapter_masks = [None] * len(self.scale)

# for ip-adapter
for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
):
ip_key = to_k_ip(current_ip_hidden_states)
ip_value = to_v_ip(current_ip_hidden_states)
if mask is not None:
if not isinstance(scale, list):
scale = [scale]

ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
current_num_images = mask.shape[1]
for i in range(current_num_images):
ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])

# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
current_ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
batch_size, -1, attn.heads * head_dim
)
current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
_current_ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)

if mask is not None:
mask_downsample = IPAdapterMaskProcessor.downsample(
mask, batch_size, current_ip_hidden_states.shape[1], current_ip_hidden_states.shape[2]
)
_current_ip_hidden_states = _current_ip_hidden_states.transpose(1, 2).reshape(
batch_size, -1, attn.heads * head_dim
)
_current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype)

mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
mask_downsample = IPAdapterMaskProcessor.downsample(
mask[:, i, :, :],
batch_size,
_current_ip_hidden_states.shape[1],
_current_ip_hidden_states.shape[2],
)

current_ip_hidden_states = current_ip_hidden_states * mask_downsample
mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
else:
ip_key = to_k_ip(current_ip_hidden_states)
ip_value = to_v_ip(current_ip_hidden_states)

ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
current_ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)

current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
batch_size, -1, attn.heads * head_dim
)
current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)

hidden_states = hidden_states + scale * current_ip_hidden_states
hidden_states = hidden_states + scale * current_ip_hidden_states

# linear proj
hidden_states = attn.to_out[0](hidden_states)
Expand Down
30 changes: 30 additions & 0 deletions tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,3 +544,33 @@ def test_ip_adapter_multiple_masks(self):

max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
assert max_diff < 5e-4

def test_ip_adapter_multiple_masks_one_adapter(self):
image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
pipeline = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
image_encoder=image_encoder,
torch_dtype=self.dtype,
)
pipeline.enable_model_cpu_offload()
pipeline.load_ip_adapter(
"h94/IP-Adapter", subfolder="sdxl_models", weight_name=["ip-adapter-plus-face_sdxl_vit-h.safetensors"]
)
pipeline.set_ip_adapter_scale([[0.7, 0.7]])

inputs = self.get_dummy_inputs(for_masks=True)
masks = inputs["cross_attention_kwargs"]["ip_adapter_masks"]
processor = IPAdapterMaskProcessor()
masks = processor.preprocess(masks)
masks = masks.reshape(1, masks.shape[0], masks.shape[2], masks.shape[3])
inputs["cross_attention_kwargs"]["ip_adapter_masks"] = [masks]
ip_images = inputs["ip_adapter_image"]
inputs["ip_adapter_image"] = [[image[0] for image in ip_images]]
images = pipeline(**inputs).images
image_slice = images[0, :3, :3, -1].flatten()
expected_slice = np.array(
[0.79474676, 0.7977683, 0.8013954, 0.7988008, 0.7970615, 0.8029355, 0.80614823, 0.8050743, 0.80627424]
)

max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
assert max_diff < 5e-4
50 changes: 50 additions & 0 deletions tests/pipelines/test_pipelines_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,11 @@ def test_pipeline_signature(self):
def _get_dummy_image_embeds(self, cross_attention_dim: int = 32):
return torch.randn((2, 1, cross_attention_dim), device=torch_device)

def _get_dummy_masks(self, input_size: int = 64):
_masks = torch.zeros((1, 1, input_size, input_size), device=torch_device)
_masks[0, :, :, : int(input_size / 2)] = 1
return _masks

def _modify_inputs_for_ip_adapter_test(self, inputs: Dict[str, Any]):
parameters = inspect.signature(self.pipeline_class.__call__).parameters
if "image" in parameters.keys() and "strength" in parameters.keys():
Expand Down Expand Up @@ -365,6 +370,51 @@ def test_ip_adapter_cfg(self, expected_max_diff: float = 1e-4):

assert out_cfg.shape == out_no_cfg.shape

def test_ip_adapter_masks(self, expected_max_diff: float = 1e-4):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components).to(torch_device)
pipe.set_progress_bar_config(disable=None)
cross_attention_dim = pipe.unet.config.get("cross_attention_dim", 32)
sample_size = pipe.unet.config.get("sample_size", 32)
block_out_channels = pipe.vae.config.get("block_out_channels", [128, 256, 512, 512])
input_size = sample_size * (2 ** (len(block_out_channels) - 1))

# forward pass without ip adapter
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
output_without_adapter = pipe(**inputs)[0]
output_without_adapter = output_without_adapter[0, -3:, -3:, -1].flatten()

adapter_state_dict = create_ip_adapter_state_dict(pipe.unet)
pipe.unet._load_ip_adapter_weights(adapter_state_dict)

# forward pass with single ip adapter and masks, but scale=0 which should have no effect
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
inputs["cross_attention_kwargs"] = {"ip_adapter_masks": [self._get_dummy_masks(input_size)]}
pipe.set_ip_adapter_scale(0.0)
output_without_adapter_scale = pipe(**inputs)[0]
output_without_adapter_scale = output_without_adapter_scale[0, -3:, -3:, -1].flatten()

# forward pass with single ip adapter and masks, but with scale of adapter weights
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
inputs["cross_attention_kwargs"] = {"ip_adapter_masks": [self._get_dummy_masks(input_size)]}
pipe.set_ip_adapter_scale(42.0)
output_with_adapter_scale = pipe(**inputs)[0]
output_with_adapter_scale = output_with_adapter_scale[0, -3:, -3:, -1].flatten()

max_diff_without_adapter_scale = np.abs(output_without_adapter_scale - output_without_adapter).max()
max_diff_with_adapter_scale = np.abs(output_with_adapter_scale - output_without_adapter).max()

self.assertLess(
max_diff_without_adapter_scale,
expected_max_diff,
"Output without ip-adapter must be same as normal inference",
)
self.assertGreater(
max_diff_with_adapter_scale, 1e-3, "Output with ip-adapter must be different from normal inference"
)


class PipelineLatentTesterMixin:
"""
Expand Down

0 comments on commit a0cf607

Please sign in to comment.