From df1ba5d1e6aa9a0a1744b7ae3ff37ca114cec7bb Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Tue, 21 Jan 2025 08:25:39 +0800 Subject: [PATCH] Fix bundle download error from ngc source (#8307) Fixes #8306 This previous api has been deprecated, update based on: https://docs.ngc.nvidia.com/api/?urls.primaryName=Private%20Artifacts%20(Models)%20API#/artifact-file-controller/downloadAllArtifactFiles ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/bundle/scripts.py | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 131c78008b..5089f0c045 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -174,7 +174,7 @@ def _get_git_release_url(repo_owner: str, repo_name: str, tag_name: str, filenam def _get_ngc_bundle_url(model_name: str, version: str) -> str: - return f"{NGC_BASE_URL}/{model_name.lower()}/versions/{version}/zip" + return f"{NGC_BASE_URL}/{model_name.lower()}/versions/{version}/files" def _get_ngc_private_base_url(repo: str) -> str: @@ -218,6 +218,21 @@ def _remove_ngc_prefix(name: str, prefix: str = "monai_") -> str: return name +def _get_all_download_files(request_url: str, headers: dict | None = None) -> list[str]: + if not has_requests: + raise ValueError("requests package is required, please install it.") + headers = {} if headers is None else headers + response = requests_get(request_url, headers=headers) + response.raise_for_status() + model_info = json.loads(response.text) + + if not isinstance(model_info, dict) or "modelFiles" not in model_info: + raise ValueError("The data is not a dictionary or it does not have the key 'modelFiles'.") + + model_files = model_info["modelFiles"] + return [f["path"] for f in model_files] + + def _download_from_ngc( download_path: Path, filename: str, @@ -229,12 +244,12 @@ def _download_from_ngc( # ensure prefix is contained filename = _add_ngc_prefix(filename, prefix=prefix) url = _get_ngc_bundle_url(model_name=filename, version=version) - filepath = download_path / f"{filename}_v{version}.zip" if remove_prefix: filename = _remove_ngc_prefix(filename, prefix=remove_prefix) - extract_path = download_path / f"{filename}" - download_url(url=url, filepath=filepath, hash_val=None, progress=progress) - extractall(filepath=filepath, output_dir=extract_path, has_base=True) + filepath = download_path / filename + filepath.mkdir(parents=True, exist_ok=True) + for file in _get_all_download_files(url): + download_url(url=f"{url}/{file}", filepath=f"{filepath}/{file}", hash_val=None, progress=progress) def _download_from_ngc_private(