Raise error if disk is full before downloading weights #1903

Open: wants to merge 6 commits into base: main
Changes from 1 commit
38 changes: 35 additions & 3 deletions litgpt/scripts/download.py
@@ -5,6 +5,7 @@
 from contextlib import contextmanager
 import importlib.util
 from pathlib import Path
+import shutil
 from typing import List, Optional, Tuple

 import torch
@@ -62,7 +63,38 @@ def download_from_hub(

     download_files = ["tokenizer*", "generation_config.json", "config.json"]
     if not tokenizer_only:
-        bins, safetensors = find_weight_files(repo_id, access_token)
+        bins, safetensors, info = find_weight_files(repo_id, access_token)
+
+        total_weight_size_bytes = 0
+        if bins:
+            total_weight_size_bytes = sum(
+                (file.size or 0)
+                for file in info.siblings
+                if file.rfilename.endswith(".bin") or file.rfilename.endswith(".bin.index.json")
+            )
+        elif safetensors:
+            total_weight_size_bytes = sum(
+                (file.size or 0)
+                for file in info.siblings
+                if file.rfilename.endswith(".safetensors")
+            )
+        else:
+            raise ValueError(f"Couldn't find weight files for {repo_id}")
+
+        weight_size_gb = total_weight_size_bytes / (1024**3)
+        free_space_bytes = shutil.disk_usage(str(checkpoint_dir)).free
+        free_space_gb = free_space_bytes / (1024**3)
+
+        if weight_size_gb > free_space_gb:
+            if os.getenv("LIGHTNING_CLUSTER_ID") is not None:
+                studio_text = " Please switch to a larger Studio with more disk space."
+            else:
+                studio_text = ""
+            raise RuntimeError(
+                f"Not enough disk space to download {repo_id} weights. "
+                f"Needed: ~{weight_size_gb:.2f} GB, free: ~{free_space_gb:.2f} GB.{studio_text}"
+            )
+
         if bins:
             # covers `.bin` files and `.bin.index.json`
             download_files.append("*.bin*")
@@ -104,11 +136,11 @@ def find_weight_files(repo_id: str, access_token: Optional[str]) -> Tuple[List[s
     from huggingface_hub.utils import filter_repo_objects

     with gated_repo_catcher(repo_id, access_token):
-        info = repo_info(repo_id, token=access_token)
+        info = repo_info(repo_id, token=access_token, files_metadata=True)
     filenames = [f.rfilename for f in info.siblings]
     bins = list(filter_repo_objects(items=filenames, allow_patterns=["*model*.bin*"]))
     safetensors = list(filter_repo_objects(items=filenames, allow_patterns=["*.safetensors*"]))
-    return bins, safetensors
+    return bins, safetensors, info


 @contextmanager
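
For reference, the gist of the new check can be reproduced outside litgpt with only huggingface_hub and the standard library. The snippet below is a simplified sketch, not the PR's exact code path: it sums .bin and .safetensors sizes together instead of preferring one format, omits the LIGHTNING_CLUSTER_ID Studio hint, and the repo id and target directory are placeholders.

import shutil
from pathlib import Path

from huggingface_hub import HfApi


def check_disk_space(repo_id: str, checkpoint_dir: Path) -> None:
    # files_metadata=True asks the Hub for per-file sizes, exposed on info.siblings
    info = HfApi().repo_info(repo_id, files_metadata=True)
    total_bytes = sum(
        (sibling.size or 0)
        for sibling in info.siblings
        if sibling.rfilename.endswith((".bin", ".bin.index.json", ".safetensors"))
    )
    free_bytes = shutil.disk_usage(str(checkpoint_dir)).free
    if total_bytes > free_bytes:
        raise RuntimeError(
            f"Not enough disk space to download {repo_id} weights. "
            f"Needed: ~{total_bytes / 1024**3:.2f} GB, free: ~{free_bytes / 1024**3:.2f} GB."
        )


if __name__ == "__main__":
    # Placeholder arguments for illustration; any Hub repo id and existing directory work.
    check_disk_space("EleutherAI/pythia-14m", Path("."))

Passing files_metadata=True is what populates sibling.size; without it the sizes are None and the total collapses to 0, which is why the PR also changes find_weight_files to request the metadata and return the repo info object alongside the filename lists.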