Skip to content

Commit

Permalink
[refactor] DiffusionPipeline.download (#9557)
Browse files Browse the repository at this point in the history
* update

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
  • Loading branch information
3 people authored Oct 17, 2024
1 parent 9a7f824 commit 5704376
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 74 deletions.
105 changes: 105 additions & 0 deletions src/diffusers/pipelines/pipeline_loading_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -838,3 +838,108 @@ def get_connected_passed_kwargs(prefix):
)

return init_kwargs


def _get_custom_components_and_folders(
pretrained_model_name: str,
config_dict: Dict[str, Any],
filenames: Optional[List[str]] = None,
variant_filenames: Optional[List[str]] = None,
variant: Optional[str] = None,
):
config_dict = config_dict.copy()

# retrieve all folder_names that contain relevant files
folder_names = [k for k, v in config_dict.items() if isinstance(v, list) and k != "_class_name"]

diffusers_module = importlib.import_module(__name__.split(".")[0])
pipelines = getattr(diffusers_module, "pipelines")

# optionally create a custom component <> custom file mapping
custom_components = {}
for component in folder_names:
module_candidate = config_dict[component][0]

if module_candidate is None or not isinstance(module_candidate, str):
continue

# We compute candidate file path on the Hub. Do not use `os.path.join`.
candidate_file = f"{component}/{module_candidate}.py"

if candidate_file in filenames:
custom_components[component] = module_candidate
elif module_candidate not in LOADABLE_CLASSES and not hasattr(pipelines, module_candidate):
raise ValueError(
f"{candidate_file} as defined in `model_index.json` does not exist in {pretrained_model_name} and is not a module in 'diffusers/pipelines'."
)

if len(variant_filenames) == 0 and variant is not None:
error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
raise ValueError(error_message)

return custom_components, folder_names


def _get_ignore_patterns(
passed_components,
model_folder_names: List[str],
model_filenames: List[str],
variant_filenames: List[str],
use_safetensors: bool,
from_flax: bool,
allow_pickle: bool,
use_onnx: bool,
is_onnx: bool,
variant: Optional[str] = None,
) -> List[str]:
if (
use_safetensors
and not allow_pickle
and not is_safetensors_compatible(
model_filenames, passed_components=passed_components, folder_names=model_folder_names
)
):
raise EnvironmentError(
f"Could not find the necessary `safetensors` weights in {model_filenames} (variant={variant})"
)

if from_flax:
ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]

elif use_safetensors and is_safetensors_compatible(
model_filenames, passed_components=passed_components, folder_names=model_folder_names
):
ignore_patterns = ["*.bin", "*.msgpack"]

use_onnx = use_onnx if use_onnx is not None else is_onnx
if not use_onnx:
ignore_patterns += ["*.onnx", "*.pb"]

safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")}
safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")}
if len(safetensors_variant_filenames) > 0 and safetensors_model_filenames != safetensors_variant_filenames:
logger.warning(
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n"
f"[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n"
f"[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not "
f"expected, please check your folder structure."
)

else:
ignore_patterns = ["*.safetensors", "*.msgpack"]

use_onnx = use_onnx if use_onnx is not None else is_onnx
if not use_onnx:
ignore_patterns += ["*.onnx", "*.pb"]

bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")}
bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")}
if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames:
logger.warning(
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n"
f"[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n"
f"[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check "
f"your folder structure."
)

return ignore_patterns
91 changes: 18 additions & 73 deletions src/diffusers/pipelines/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,16 @@
CUSTOM_PIPELINE_FILE_NAME,
LOADABLE_CLASSES,
_fetch_class_library_tuple,
_get_custom_components_and_folders,
_get_custom_pipeline_class,
_get_final_device_map,
_get_ignore_patterns,
_get_pipeline_class,
_identify_model_variants,
_maybe_raise_warning_for_inpainting,
_resolve_custom_pipeline_and_cls,
_unwrap_model,
_update_init_kwargs_with_connected_pipeline,
is_safetensors_compatible,
load_sub_model,
maybe_raise_or_warn,
variant_compatible_siblings,
Expand Down Expand Up @@ -1298,44 +1299,18 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
config_dict = cls._dict_from_json_file(config_file)
ignore_filenames = config_dict.pop("_ignore_files", [])

# retrieve all folder_names that contain relevant files
folder_names = [k for k, v in config_dict.items() if isinstance(v, list) and k != "_class_name"]

diffusers_module = importlib.import_module(__name__.split(".")[0])
pipelines = getattr(diffusers_module, "pipelines")

# optionally create a custom component <> custom file mapping
custom_components = {}
for component in folder_names:
module_candidate = config_dict[component][0]

if module_candidate is None or not isinstance(module_candidate, str):
continue

# We compute candidate file path on the Hub. Do not use `os.path.join`.
candidate_file = f"{component}/{module_candidate}.py"

if candidate_file in filenames:
custom_components[component] = module_candidate
elif module_candidate not in LOADABLE_CLASSES and not hasattr(pipelines, module_candidate):
raise ValueError(
f"{candidate_file} as defined in `model_index.json` does not exist in {pretrained_model_name} and is not a module in 'diffusers/pipelines'."
)

if len(variant_filenames) == 0 and variant is not None:
error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
raise ValueError(error_message)

# remove ignored filenames
model_filenames = set(model_filenames) - set(ignore_filenames)
variant_filenames = set(variant_filenames) - set(ignore_filenames)

# if the whole pipeline is cached we don't have to ping the Hub
if revision in DEPRECATED_REVISION_ARGS and version.parse(
version.parse(__version__).base_version
) >= version.parse("0.22.0"):
warn_deprecated_model_variant(pretrained_model_name, token, variant, revision, model_filenames)

custom_components, folder_names = _get_custom_components_and_folders(
pretrained_model_name, config_dict, filenames, variant_filenames, variant
)
model_folder_names = {os.path.split(f)[0] for f in model_filenames if os.path.split(f)[0] in folder_names}

custom_class_name = None
Expand Down Expand Up @@ -1395,49 +1370,19 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
expected_components, _ = cls._get_signature_keys(pipeline_class)
passed_components = [k for k in expected_components if k in kwargs]

if (
use_safetensors
and not allow_pickle
and not is_safetensors_compatible(
model_filenames, passed_components=passed_components, folder_names=model_folder_names
)
):
raise EnvironmentError(
f"Could not find the necessary `safetensors` weights in {model_filenames} (variant={variant})"
)
if from_flax:
ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
elif use_safetensors and is_safetensors_compatible(
model_filenames, passed_components=passed_components, folder_names=model_folder_names
):
ignore_patterns = ["*.bin", "*.msgpack"]

use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx
if not use_onnx:
ignore_patterns += ["*.onnx", "*.pb"]

safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")}
safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")}
if (
len(safetensors_variant_filenames) > 0
and safetensors_model_filenames != safetensors_variant_filenames
):
logger.warning(
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not expected, please check your folder structure."
)
else:
ignore_patterns = ["*.safetensors", "*.msgpack"]

use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx
if not use_onnx:
ignore_patterns += ["*.onnx", "*.pb"]

bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")}
bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")}
if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames:
logger.warning(
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check your folder structure."
)
# retrieve all patterns that should not be downloaded and error out when needed
ignore_patterns = _get_ignore_patterns(
passed_components,
model_folder_names,
model_filenames,
variant_filenames,
use_safetensors,
from_flax,
allow_pickle,
use_onnx,
pipeline_class._is_onnx,
variant,
)

# Don't download any objects that are passed
allow_patterns = [
Expand Down
2 changes: 1 addition & 1 deletion tests/pipelines/test_pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
StableDiffusionPipeline,
UNet2DConditionModel,
)
from diffusers.pipelines.pipeline_utils import is_safetensors_compatible
from diffusers.pipelines.pipeline_loading_utils import is_safetensors_compatible
from diffusers.utils.testing_utils import torch_device


Expand Down

0 comments on commit 5704376

Please sign in to comment.