diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py
index 0a744264b7a6..c16bd8ac2069 100644
--- a/src/diffusers/pipelines/pipeline_loading_utils.py
+++ b/src/diffusers/pipelines/pipeline_loading_utils.py
@@ -838,3 +838,157 @@ def get_connected_passed_kwargs(prefix):
         )
 
     return init_kwargs
+
+
+def _get_custom_components_and_folders(
+    pretrained_model_name: str,
+    config_dict: Dict[str, Any],
+    filenames: Optional[List[str]] = None,
+    variant_filenames: Optional[List[str]] = None,
+    variant: Optional[str] = None,
+):
+    """
+    Collect the component folders listed in `model_index.json` and detect custom component modules.
+
+    Args:
+        pretrained_model_name (`str`):
+            Name of the pipeline repository; only used in the error message.
+        config_dict (`Dict[str, Any]`):
+            The parsed `model_index.json`. It is only read, never modified.
+        filenames (`List[str]`, *optional*):
+            Filenames available in the repository. `None` is treated as an empty collection.
+        variant_filenames (`List[str]`, *optional*):
+            Filenames that belong to `variant`. `None` is treated as an empty collection.
+        variant (`str`, *optional*):
+            The requested weight variant.
+
+    Returns:
+        A tuple `(custom_components, folder_names)`: a dict mapping a component name to the custom module that
+        defines it, and the list of all config keys (except `_class_name`) whose value is a list.
+
+    Raises:
+        `ValueError`: If a component names a module that is neither a file in `filenames`, an entry of
+        `LOADABLE_CLASSES` nor an attribute of the `pipelines` module, or if `variant` is set but
+        `variant_filenames` is empty.
+    """
+    # Both collections default to `None`; normalize them so the checks below cannot fail with a `TypeError`.
+    filenames = filenames if filenames is not None else []
+    variant_filenames = variant_filenames if variant_filenames is not None else []
+
+    # retrieve all folder_names that contain relevant files
+    folder_names = [k for k, v in config_dict.items() if isinstance(v, list) and k != "_class_name"]
+
+    diffusers_module = importlib.import_module(__name__.split(".")[0])
+    pipelines = diffusers_module.pipelines
+
+    # optionally create a custom component <> custom file mapping
+    custom_components = {}
+    for component in folder_names:
+        # an empty list has no module entry to inspect
+        if not config_dict[component]:
+            continue
+
+        module_candidate = config_dict[component][0]
+        if not isinstance(module_candidate, str):
+            continue
+
+        # We compute candidate file path on the Hub. Do not use `os.path.join`.
+        candidate_file = f"{component}/{module_candidate}.py"
+
+        if candidate_file in filenames:
+            custom_components[component] = module_candidate
+        elif module_candidate not in LOADABLE_CLASSES and not hasattr(pipelines, module_candidate):
+            raise ValueError(
+                f"{candidate_file} as defined in `model_index.json` does not exist in {pretrained_model_name} and is not a module in 'diffusers/pipelines'."
+            )
+
+    if variant is not None and not variant_filenames:
+        raise ValueError(
+            f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are "
+            "available."
+        )
+
+    return custom_components, folder_names
+
+
+def _get_ignore_patterns(
+    passed_components,
+    model_folder_names: List[str],
+    model_filenames: List[str],
+    variant_filenames: List[str],
+    use_safetensors: bool,
+    from_flax: bool,
+    allow_pickle: bool,
+    use_onnx: Optional[bool],
+    is_onnx: bool,
+    variant: Optional[str] = None,
+) -> List[str]:
+    """
+    Build the list of filename patterns that must not be downloaded for a pipeline.
+
+    Args:
+        passed_components:
+            Components passed by the caller; forwarded to `is_safetensors_compatible`.
+        model_folder_names (`List[str]`):
+            Folder names forwarded to `is_safetensors_compatible`.
+        model_filenames (`List[str]`):
+            Filenames of the model files.
+        variant_filenames (`List[str]`):
+            Filenames that belong to `variant`.
+        use_safetensors (`bool`):
+            Whether to select `safetensors` weights when the files are `safetensors` compatible.
+        from_flax (`bool`):
+            Whether Flax weights are loaded; `.bin`, `.safetensors` and ONNX files are then ignored.
+        allow_pickle (`bool`):
+            Whether falling back to `.bin` weights is allowed when the files are not `safetensors` compatible.
+        use_onnx (`bool`, *optional*):
+            Whether to keep ONNX files. `None` falls back to `is_onnx`.
+        is_onnx (`bool`):
+            Default for `use_onnx`.
+        variant (`str`, *optional*):
+            The requested weight variant; only used in messages.
+
+    Returns:
+        `List[str]`: The patterns to ignore.
+
+    Raises:
+        `EnvironmentError`: If `use_safetensors` is set, `allow_pickle` is not set and the files are not
+        `safetensors` compatible.
+    """
+    # Run the compatibility check once; it is skipped when `safetensors` weights are not requested.
+    load_safetensors = use_safetensors and is_safetensors_compatible(
+        model_filenames, passed_components=passed_components, folder_names=model_folder_names
+    )
+
+    if use_safetensors and not allow_pickle and not load_safetensors:
+        raise EnvironmentError(
+            f"Could not find the necessary `safetensors` weights in {model_filenames} (variant={variant})"
+        )
+
+    if from_flax:
+        # ONNX files are always ignored here and no variant check is performed.
+        return ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
+
+    if load_safetensors:
+        ignore_patterns = ["*.bin", "*.msgpack"]
+        weight_suffix = ".safetensors"
+    else:
+        ignore_patterns = ["*.safetensors", "*.msgpack"]
+        weight_suffix = ".bin"
+
+    use_onnx = use_onnx if use_onnx is not None else is_onnx
+    if not use_onnx:
+        ignore_patterns += ["*.onnx", "*.pb"]
+
+    # Warn when the variant covers only part of the weight files of the selected format.
+    variant_weights = {f for f in variant_filenames if f.endswith(weight_suffix)}
+    model_weights = {f for f in model_filenames if f.endswith(weight_suffix)}
+    if variant_weights and model_weights != variant_weights:
+        logger.warning(
+            f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n"
+            f"[{', '.join(sorted(variant_weights))}]\nLoaded non-{variant} filenames:\n"
+            f"[{', '.join(sorted(model_weights - variant_weights))}]\nIf this behavior is not expected, please "
+            f"check your folder structure."
+        )
+
+    return ignore_patterns
diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index 857a13147cfe..2be0c5e7310c 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -71,15 +71,16 @@
     CUSTOM_PIPELINE_FILE_NAME,
     LOADABLE_CLASSES,
     _fetch_class_library_tuple,
+    _get_custom_components_and_folders,
     _get_custom_pipeline_class,
     _get_final_device_map,
+    _get_ignore_patterns,
     _get_pipeline_class,
     _identify_model_variants,
     _maybe_raise_warning_for_inpainting,
     _resolve_custom_pipeline_and_cls,
     _unwrap_model,
     _update_init_kwargs_with_connected_pipeline,
-    is_safetensors_compatible,
     load_sub_model,
     maybe_raise_or_warn,
     variant_compatible_siblings,
@@ -1298,44 +1299,18 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
             config_dict = cls._dict_from_json_file(config_file)
             ignore_filenames = config_dict.pop("_ignore_files", [])
 
-            # retrieve all folder_names that contain relevant files
-            folder_names = [k for k, v in config_dict.items() if isinstance(v, list) and k != "_class_name"]
-
-            diffusers_module = importlib.import_module(__name__.split(".")[0])
-            pipelines = getattr(diffusers_module, "pipelines")
-
-            # optionally create a custom component <> custom file mapping
-            custom_components = {}
-            for component in folder_names:
-                module_candidate = config_dict[component][0]
-
-                if module_candidate is None or not isinstance(module_candidate, str):
-                    continue
-
-                # We compute candidate file path on the Hub. Do not use `os.path.join`.
-                candidate_file = f"{component}/{module_candidate}.py"
-
-                if candidate_file in filenames:
-                    custom_components[component] = module_candidate
-                elif module_candidate not in LOADABLE_CLASSES and not hasattr(pipelines, module_candidate):
-                    raise ValueError(
-                        f"{candidate_file} as defined in `model_index.json` does not exist in {pretrained_model_name} and is not a module in 'diffusers/pipelines'."
-                    )
-
-            if len(variant_filenames) == 0 and variant is not None:
-                error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
-                raise ValueError(error_message)
-
             # remove ignored filenames
             model_filenames = set(model_filenames) - set(ignore_filenames)
             variant_filenames = set(variant_filenames) - set(ignore_filenames)
 
-            # if the whole pipeline is cached we don't have to ping the Hub
             if revision in DEPRECATED_REVISION_ARGS and version.parse(
                 version.parse(__version__).base_version
             ) >= version.parse("0.22.0"):
                 warn_deprecated_model_variant(pretrained_model_name, token, variant, revision, model_filenames)
 
+            custom_components, folder_names = _get_custom_components_and_folders(
+                pretrained_model_name, config_dict, filenames, variant_filenames, variant
+            )
             model_folder_names = {os.path.split(f)[0] for f in model_filenames if os.path.split(f)[0] in folder_names}
 
             custom_class_name = None
@@ -1395,49 +1370,19 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
             expected_components, _ = cls._get_signature_keys(pipeline_class)
             passed_components = [k for k in expected_components if k in kwargs]
 
-            if (
-                use_safetensors
-                and not allow_pickle
-                and not is_safetensors_compatible(
-                    model_filenames, passed_components=passed_components, folder_names=model_folder_names
-                )
-            ):
-                raise EnvironmentError(
-                    f"Could not find the necessary `safetensors` weights in {model_filenames} (variant={variant})"
-                )
-            if from_flax:
-                ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
-            elif use_safetensors and is_safetensors_compatible(
-                model_filenames, passed_components=passed_components, folder_names=model_folder_names
-            ):
-                ignore_patterns = ["*.bin", "*.msgpack"]
-
-                use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx
-                if not use_onnx:
-                    ignore_patterns += ["*.onnx", "*.pb"]
-
-                safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")}
-                safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")}
-                if (
-                    len(safetensors_variant_filenames) > 0
-                    and safetensors_model_filenames != safetensors_variant_filenames
-                ):
-                    logger.warning(
-                        f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not expected, please check your folder structure."
-                    )
-            else:
-                ignore_patterns = ["*.safetensors", "*.msgpack"]
-
-                use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx
-                if not use_onnx:
-                    ignore_patterns += ["*.onnx", "*.pb"]
-
-                bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")}
-                bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")}
-                if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames:
-                    logger.warning(
-                        f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check your folder structure."
-                    )
+            # retrieve all patterns that should not be downloaded and error out when needed
+            ignore_patterns = _get_ignore_patterns(
+                passed_components,
+                model_folder_names,
+                model_filenames,
+                variant_filenames,
+                use_safetensors,
+                from_flax,
+                allow_pickle,
+                use_onnx,
+                pipeline_class._is_onnx,
+                variant,
+            )
 
             # Don't download any objects that are passed
             allow_patterns = [
diff --git a/tests/pipelines/test_pipeline_utils.py b/tests/pipelines/test_pipeline_utils.py
index 697244dcb105..5eedd393c8f8 100644
--- a/tests/pipelines/test_pipeline_utils.py
+++ b/tests/pipelines/test_pipeline_utils.py
@@ -18,7 +18,7 @@
     StableDiffusionPipeline,
     UNet2DConditionModel,
 )
-from diffusers.pipelines.pipeline_utils import is_safetensors_compatible
+from diffusers.pipelines.pipeline_loading_utils import is_safetensors_compatible
 from diffusers.utils.testing_utils import torch_device
 
 