[Core] introduce _no_split_modules to ModelMixin (#6396)
* introduce _no_split_modules.

* unnecessary spaces.

* remove unnecessary kwargs and style

* fix: accelerate imports.

* change to _determine_device_map

* add the blocks that have residual connections.

* add: CrossAttnUpBlock2D

* add: testing

* style

* line-spaces

* quality

* add disk offload test without safetensors.

* checking disk offloading percentages.

* change model split

* add: utility for checking multi-gpu requirement.

* model parallelism test

* splits.

* splits.

* splits

* splits.

* splits.

* splits.

* offload folder to test_disk_offload_with_safetensors

* add _no_split_modules

* fix-copies
sayakpaul authored Apr 30, 2024
1 parent b02e211 commit 3fd31ee
Showing 7 changed files with 221 additions and 5 deletions.
1 change: 1 addition & 0 deletions src/diffusers/models/autoencoders/autoencoder_kl.py
@@ -65,6 +65,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
"""

_supports_gradient_checkpointing = True
_no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D"]

@register_to_config
def __init__(
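With the attribute in place, the VAE can be auto-dispatched. A minimal sketch (illustrative checkpoint id; assumes `accelerate` is installed and at least one CUDA device is visible):

    from diffusers import AutoencoderKL

    # "stabilityai/sd-vae-ft-mse" is just an example AutoencoderKL checkpoint.
    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", device_map="auto")
    print(vae.hf_device_map)  # module-name -> device mapping chosen by accelerate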
92 changes: 87 additions & 5 deletions src/diffusers/models/modeling_utils.py
@@ -57,7 +57,8 @@

if is_accelerate_available():
    import accelerate
    from accelerate.utils import set_module_tensor_to_device
    from accelerate import infer_auto_device_map
    from accelerate.utils import get_balanced_memory, get_max_memory, set_module_tensor_to_device
    from accelerate.utils.versions import is_torch_version


@@ -99,6 +100,29 @@ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
    return first_tuple[1].dtype


# Adapted from `transformers` (see modeling_utils.py)
def _determine_device_map(model: "ModelMixin", device_map, max_memory, torch_dtype):
    if isinstance(device_map, str):
        no_split_modules = model._get_no_split_modules(device_map)
        device_map_kwargs = {"no_split_module_classes": no_split_modules}

        if device_map != "sequential":
            max_memory = get_balanced_memory(
                model,
                dtype=torch_dtype,
                low_zero=(device_map == "balanced_low_0"),
                max_memory=max_memory,
                **device_map_kwargs,
            )
        else:
            max_memory = get_max_memory(max_memory)

        device_map_kwargs["max_memory"] = max_memory
        device_map = infer_auto_device_map(model, dtype=torch_dtype, **device_map_kwargs)

    return device_map

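For intuition, here is a standalone sketch of what the `no_split_module_classes` kwarg changes in `infer_auto_device_map` (toy module and deliberately tiny memory caps; the exact split accelerate picks may differ, and nothing is actually placed on a device by this call):

    import torch
    from accelerate import infer_auto_device_map

    class Block(torch.nn.Module):  # stand-in for e.g. ResnetBlock2D
        def __init__(self):
            super().__init__()
            self.fc1 = torch.nn.Linear(128, 128)
            self.fc2 = torch.nn.Linear(128, 128)

    blocks = torch.nn.ModuleList([Block() for _ in range(4)])
    # Each Block must stay whole on one device, so residual connections inside a
    # block never straddle devices; splits can only happen *between* blocks.
    device_map = infer_auto_device_map(
        blocks,
        max_memory={0: "300KB", "cpu": "1GB"},  # tiny GPU cap to force a split
        no_split_module_classes=["Block"],
    )
    print(device_map)  # e.g. {'0': 0, '1': 0, '2': 'cpu', '3': 'cpu'}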

def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
    """
    Reads a checkpoint file, returning properly formatted errors if they arise.
@@ -201,6 +225,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
    _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
    _supports_gradient_checkpointing = False
    _keys_to_ignore_on_load_unexpected = None
    _no_split_modules = None

    def __init__(self):
        super().__init__()
@@ -560,6 +585,36 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
" dispatching. Please make sure to set `low_cpu_mem_usage=True`."
)

# change device_map into a map if we passed an int, a str or a torch.device
if isinstance(device_map, torch.device):
device_map = {"": device_map}
elif isinstance(device_map, str) and device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]:
try:
device_map = {"": torch.device(device_map)}
except RuntimeError:
raise ValueError(
"When passing device_map as a string, the value needs to be a device name (e.g. cpu, cuda:0) or "
f"'auto', 'balanced', 'balanced_low_0', 'sequential' but found {device_map}."
)
elif isinstance(device_map, int):
if device_map < 0:
raise ValueError(
"You can't pass device_map as a negative int. If you want to put the model on the cpu, pass device_map = 'cpu' "
)
else:
device_map = {"": device_map}

if device_map is not None:
if low_cpu_mem_usage is None:
low_cpu_mem_usage = True
elif not low_cpu_mem_usage:
raise ValueError("Passing along a `device_map` requires `low_cpu_mem_usage=True`")

if low_cpu_mem_usage:
if device_map is not None and not is_torch_version(">=", "1.10"):
# The max memory utils require PyTorch >= 1.10 to have torch.cuda.mem_get_info.
raise ValueError("`low_cpu_mem_usage` and `device_map` require PyTorch >= 1.10.")

# Load config if we don't provide a configuration
config_path = pretrained_model_name_or_path

@@ -582,10 +637,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
            token=token,
            revision=revision,
            subfolder=subfolder,
            device_map=device_map,
            max_memory=max_memory,
            offload_folder=offload_folder,
            offload_state_dict=offload_state_dict,
            user_agent=user_agent,
            **kwargs,
        )
@@ -690,6 +741,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
        else:  # else let accelerate handle loading and dispatching.
            # Load weights and dispatch according to the device_map
            # by default the device_map is None and the weights are loaded on the CPU
            device_map = _determine_device_map(model, device_map, max_memory, torch_dtype)
            try:
                accelerate.load_checkpoint_and_dispatch(
                    model,
@@ -881,6 +933,36 @@ def _find_mismatched_keys(

        return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs

    # Adapted from `transformers` modeling_utils.py
    def _get_no_split_modules(self, device_map: str):
        """
        Get the modules of the model that should not be split when using device_map. We iterate through the modules
        to get the underlying `_no_split_modules`.

        Args:
            device_map (`str`):
                The device map value. Options are ["auto", "balanced", "balanced_low_0", "sequential"]

        Returns:
            `List[str]`: List of modules that should not be split
        """
        _no_split_modules = set()
        modules_to_check = [self]
        while len(modules_to_check) > 0:
            module = modules_to_check.pop(-1)
            # if the module does not appear in _no_split_modules, we also check the children
            if module.__class__.__name__ not in _no_split_modules:
                if isinstance(module, ModelMixin):
                    if module._no_split_modules is None:
                        raise ValueError(
                            f"{module.__class__.__name__} does not support `device_map='{device_map}'`. To implement "
                            "support, the model class needs to implement the `_no_split_modules` attribute."
                        )
                    else:
                        _no_split_modules = _no_split_modules | set(module._no_split_modules)
                modules_to_check += list(module.children())
        return list(_no_split_modules)

    @property
    def device(self) -> torch.device:
        """
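Putting the pieces together, `from_pretrained` now accepts both concrete device specifications and accelerate placement strategies. A sketch of the accepted forms (illustrative repo id; assumes a CUDA machine with `accelerate` installed):

    import torch
    from diffusers import UNet2DConditionModel

    repo = "runwayml/stable-diffusion-v1-5"  # illustrative checkpoint

    # Concrete placements, normalized internally to {"": device}:
    unet = UNet2DConditionModel.from_pretrained(repo, subfolder="unet", device_map="cuda:0")
    unet = UNet2DConditionModel.from_pretrained(repo, subfolder="unet", device_map=0)
    unet = UNet2DConditionModel.from_pretrained(repo, subfolder="unet", device_map=torch.device("cuda:0"))

    # Strategy strings, resolved by _determine_device_map via accelerate:
    unet = UNet2DConditionModel.from_pretrained(repo, subfolder="unet", device_map="auto")
    print(unet.hf_device_map)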
1 change: 1 addition & 0 deletions src/diffusers/models/transformers/transformer_2d.py
@@ -72,6 +72,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
"""

_supports_gradient_checkpointing = True
_no_split_modules = ["BasicTransformerBlock"]

@register_to_config
def __init__(
1 change: 1 addition & 0 deletions src/diffusers/models/unets/unet_2d_condition.py
@@ -161,6 +161,7 @@ class conditioning with `class_embed_type` equal to `None`.
"""

_supports_gradient_checkpointing = True
_no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"]

@register_to_config
def __init__(
1 change: 1 addition & 0 deletions
@@ -363,6 +363,7 @@ class conditioning with `class_embed_type` equal to `None`.
"""

_supports_gradient_checkpointing = True
_no_split_modules = ["BasicTransformerBlock", "ResnetBlockFlat", "CrossAttnUpBlockFlat"]

@register_to_config
def __init__(
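`_no_split_modules` composes across nesting: `_get_no_split_modules` walks the children, so a UNet also inherits the entries of any nested `ModelMixin` (e.g. its `Transformer2DModel` blocks). A quick sketch with a default, randomly initialized config (output indicative):

    from diffusers import UNet2DConditionModel

    unet = UNet2DConditionModel()  # full-size default config; no weights are downloaded
    print(sorted(unet._get_no_split_modules("auto")))
    # e.g. ['BasicTransformerBlock', 'CrossAttnUpBlock2D', 'ResnetBlock2D']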
128 changes: 128 additions & 0 deletions tests/models/test_modeling_common.py
@@ -24,6 +24,7 @@
import numpy as np
import requests_mock
import torch
from accelerate.utils import compute_module_sizes
from huggingface_hub import ModelCard, delete_repo
from huggingface_hub.utils import is_jinja_available
from requests.exceptions import HTTPError
@@ -39,6 +40,7 @@
    require_torch_2,
    require_torch_accelerator_with_training,
    require_torch_gpu,
    require_torch_multi_gpu,
    run_test_in_subprocess,
    torch_device,
)
@@ -200,6 +202,21 @@ class ModelTesterMixin:
    main_input_name = None  # overwrite in model specific tester class
    base_precision = 1e-3
    forward_requires_fresh_args = False
    model_split_percents = [0.5, 0.7, 0.9]

    def check_device_map_is_respected(self, model, device_map):
        for param_name, param in model.named_parameters():
            # Find device in device_map
            while len(param_name) > 0 and param_name not in device_map:
                param_name = ".".join(param_name.split(".")[:-1])
            if param_name not in device_map:
                raise ValueError(f"device map is incomplete, it does not contain any device for {param_name}.")

            param_device = device_map[param_name]
            if param_device in ["cpu", "disk"]:
                self.assertEqual(param.device, torch.device("meta"))
            else:
                self.assertEqual(param.device, torch.device(param_device))
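
    # Worked example (hypothetical names): for a parameter "down_blocks.0.conv.weight"
    # and device_map {"down_blocks.0": 0, "mid_block": "cpu"}, the while-loop above strips
    # trailing name components until "down_blocks.0" matches, so the parameter is expected
    # on GPU 0. Entries mapped to "cpu" or "disk" are offloaded by accelerate and therefore
    # show up on the meta device until a forward pass streams them in.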

    def test_from_save_pretrained(self, expected_max_diff=5e-5):
        if self.forward_requires_fresh_args:
@@ -670,6 +687,117 @@ def test_deprecated_kwargs(self):
" from `_deprecated_kwargs = [<deprecated_argument>]`"
)

    @require_torch_gpu
    def test_cpu_offload(self):
        config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**config).eval()
        model = model.to(torch_device)

        torch.manual_seed(0)
        base_output = model(**inputs_dict)

        model_size = compute_module_sizes(model)[""]
        # We test several splits of sizes to make sure it works.
        max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]]
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.cpu().save_pretrained(tmp_dir)

            for max_size in max_gpu_sizes:
                max_memory = {0: max_size, "cpu": model_size * 2}
                new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
                # Making sure part of the model will actually end up offloaded
                self.assertSetEqual(set(new_model.hf_device_map.values()), {0, "cpu"})

                self.check_device_map_is_respected(new_model, new_model.hf_device_map)
                torch.manual_seed(0)
                new_output = new_model(**inputs_dict)

                self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))

    @require_torch_gpu
    def test_disk_offload_without_safetensors(self):
        config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**config).eval()
        model = model.to(torch_device)

        torch.manual_seed(0)
        base_output = model(**inputs_dict)

        model_size = compute_module_sizes(model)[""]
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.cpu().save_pretrained(tmp_dir, safe_serialization=False)

            with self.assertRaises(ValueError):
                max_size = int(self.model_split_percents[1] * model_size)
                max_memory = {0: max_size, "cpu": max_size}
                # This errors out because it's missing an offload folder
                new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)

            max_size = int(self.model_split_percents[1] * model_size)
            max_memory = {0: max_size, "cpu": max_size}
            new_model = self.model_class.from_pretrained(
                tmp_dir, device_map="auto", max_memory=max_memory, offload_folder=tmp_dir
            )

            self.check_device_map_is_respected(new_model, new_model.hf_device_map)
            torch.manual_seed(0)
            new_output = new_model(**inputs_dict)

            self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))

    @require_torch_gpu
    def test_disk_offload_with_safetensors(self):
        config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**config).eval()
        model = model.to(torch_device)

        torch.manual_seed(0)
        base_output = model(**inputs_dict)

        model_size = compute_module_sizes(model)[""]
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.cpu().save_pretrained(tmp_dir)

            max_size = int(self.model_split_percents[1] * model_size)
            max_memory = {0: max_size, "cpu": max_size}
            new_model = self.model_class.from_pretrained(
                tmp_dir, device_map="auto", offload_folder=tmp_dir, max_memory=max_memory
            )

            self.check_device_map_is_respected(new_model, new_model.hf_device_map)
            torch.manual_seed(0)
            new_output = new_model(**inputs_dict)

            self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))

    @require_torch_multi_gpu
    def test_model_parallelism(self):
        config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**config).eval()
        model = model.to(torch_device)

        torch.manual_seed(0)
        base_output = model(**inputs_dict)

        model_size = compute_module_sizes(model)[""]
        # We test several splits of sizes to make sure it works.
        max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]]
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.cpu().save_pretrained(tmp_dir)

            for max_size in max_gpu_sizes:
                max_memory = {0: max_size, 1: model_size * 2, "cpu": model_size * 2}
                new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
                # Making sure the model is actually split across both GPUs
                self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1})

                self.check_device_map_is_respected(new_model, new_model.hf_device_map)

                torch.manual_seed(0)
                new_output = new_model(**inputs_dict)

                self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))


@is_staging_test
class ModelPushToHubTester(unittest.TestCase):
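The size arithmetic the new tests rely on can be reproduced in isolation. A minimal sketch (toy module; no GPU is needed just to compute sizes):

    import torch
    from accelerate.utils import compute_module_sizes

    model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 64))
    total_bytes = compute_module_sizes(model)[""]  # "" keys the root module's total size
    # Derive per-device caps the way the tests derive them from model_split_percents:
    max_memory = {0: int(0.5 * total_bytes), "cpu": total_bytes * 2}
    print(total_bytes, max_memory)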
2 changes: 2 additions & 0 deletions tests/models/unets/test_models_unet_2d_condition.py
@@ -300,6 +300,8 @@ def create_custom_diffusion_layers(model, mock_weights: bool = True):
class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
    model_class = UNet2DConditionModel
    main_input_name = "sample"
    # We override the default split percentages here because the UNet under test is small.
    model_split_percents = [0.5, 0.3, 0.4]

    @property
    def dummy_input(self):
