Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Download Support for Custom HuggingFace Models #27

Merged
merged 4 commits into from
Dec 4, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 62 additions & 49 deletions llm/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import uuid
from typing import List
import huggingface_hub as hfh
from huggingface_hub.utils import HfHubHTTPError
from utils.marsgen import get_mar_name, generate_mars
from utils.system_utils import (
check_if_path_exists,
Expand Down Expand Up @@ -176,65 +175,68 @@ def read_config_for_download(gen_model: GenerateDataModel) -> GenerateDataModel:
"repo_version"
]

# Make sure there is HF hub token for LLAMA(2)
if (
gen_model.repo_info.repo_id.startswith("meta-llama")
and gen_model.repo_info.hf_token is None
):
print(
"## Error: HuggingFace Hub token is required for llama download."
" Please specify it using --hf_token=<your token>. "
"Refer https://huggingface.co/docs/hub/security-tokens"
)
sys.exit(1)

# Validate downloaded files
hf_api = hfh.HfApi()
hf_api.list_repo_commits(
repo_id=gen_model.repo_info.repo_id,
revision=gen_model.repo_info.repo_version,
token=gen_model.repo_info.hf_token,
)

# Read handler file name
if not gen_model.mar_utils.handler_path:
gen_model.mar_utils.handler_path = os.path.join(
os.path.dirname(__file__),
models[gen_model.model_name]["handler"],
)

except (KeyError, HfHubHTTPError):
# Validate hf_token
gen_model.validate_hf_token()

# Validate repository info
gen_model.validate_commit_info()

except (KeyError, ValueError):
print(
"## Error: Please check either repo_id, repo_version"
" or HuggingFace ID is not correct\n"
"## There seems to be an error in the model_config.json file. "
"Please check the same."
)
sys.exit(1)

else: # Custom model case
if not gen_model.skip_download:
print(
"## Please check your model name,"
" it should be one of the following : "
)
print(list(models.keys()))
print(
"\n## If you want to use custom model files,"
" use the '--no_download' argument"
)
sys.exit(1)
gen_model.is_custom_model = True
if gen_model.skip_download:
if check_if_folder_empty(gen_model.mar_utils.model_path):
print("## Error: The given model path folder is empty\n")
sys.exit(1)

if check_if_folder_empty(gen_model.mar_utils.model_path):
print("## Error: The given model path folder is empty\n")
sys.exit(1)
if not gen_model.repo_info.repo_version:
gen_model.repo_info.repo_version = "1.0"
gavrissh marked this conversation as resolved.
Show resolved Hide resolved

else:
if not gen_model.repo_info.repo_id:
print(
"## If you want to create a model archive file with the supported models, "
"make sure your model name is present in the list below: "
)
print(list(models.keys()))
print(
"\nIf you want to create a model archive file for"
" a custom model, there are two methods:\n"
"1. If you have already downloaded the custom model"
" files, please include"
" the --no_download flag and provide the model_path "
"directory which contains the model files.\n"
"2. If you need to download the model files, provide "
"the HuggingFace repository ID using 'repo_id'"
" along with an empty model_path directory where the "
"model files will be downloaded.\n"
)
sys.exit(1)

# Validate hf_token
gen_model.validate_hf_token()
gavrissh marked this conversation as resolved.
Show resolved Hide resolved

# Validate repository info
gen_model.validate_commit_info()

if not gen_model.mar_utils.handler_path:
gen_model.mar_utils.handler_path = os.path.join(
os.path.dirname(__file__), "handler.py"
)

if not gen_model.repo_info.repo_version:
gen_model.repo_info.repo_version = "1.0"

gen_model.is_custom_model = True
print(
f"\n## Generating MAR file for "
f"custom model files: {gen_model.model_name}"
Expand All @@ -260,7 +262,9 @@ def run_download(gen_model: GenerateDataModel) -> GenerateDataModel:
GenerateDataModel: An instance of the GenerateDataModel class.
"""
if not check_if_folder_empty(gen_model.mar_utils.model_path):
print("## Make sure the path provided to download model files is empty\n")
print(
"## Make sure the model_path provided to download model files is empty\n"
)
sys.exit(1)

print(
Expand Down Expand Up @@ -290,7 +294,9 @@ def create_mar(gen_model: GenerateDataModel) -> None:
Args:
gen_model (GenerateDataModel): An instance of the GenerateDataModel dataclass
"""
if not gen_model.is_custom_model and not check_if_model_files_exist(gen_model):
if not (
gen_model.is_custom_model and gen_model.skip_download
) and not check_if_model_files_exist(gen_model):
print("## Model files do not match HuggingFace repository files")
sys.exit(1)

Expand Down Expand Up @@ -347,13 +353,20 @@ def run_script(params: argparse.Namespace) -> bool:
type=str,
default="",
required=True,
metavar="mn",
metavar="n",
help="Name of model",
)
parser.add_argument(
"--repo_id",
type=str,
default=None,
metavar="ri",
help="HuggingFace repository ID (In case of custom model download)",
)
parser.add_argument(
"--repo_version",
type=str,
default="",
default=None,
metavar="rv",
help="Commit ID of models repo from HuggingFace repository",
)
Expand All @@ -367,15 +380,15 @@ def run_script(params: argparse.Namespace) -> bool:
type=str,
default="",
required=True,
metavar="mp",
metavar="p",
help="Absolute path of model files (should be empty if downloading)",
)
parser.add_argument(
"--mar_output",
type=str,
default="",
required=True,
metavar="mx",
metavar="a",
help="Absolute path of exported MAR file (.mar)",
)
parser.add_argument(
Expand All @@ -389,7 +402,7 @@ def run_script(params: argparse.Namespace) -> bool:
"--hf_token",
type=str,
default=None,
metavar="hft",
metavar="ht",
help="HuggingFace Hub token to download LLAMA(2) models",
)
parser.add_argument("--debug", action="store_true", help="flag to debug")
Expand Down
73 changes: 67 additions & 6 deletions llm/tests/test_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def cleanup_folders() -> None:

def set_generate_args(
model_name: str = MODEL_NAME,
repo_version: str = "",
repo_version: str = None,
model_path: str = MODEL_PATH,
mar_output: str = MODEL_STORE,
handler_path: str = "",
Expand All @@ -82,6 +82,7 @@ def set_generate_args(
args.model_path = model_path
args.mar_output = mar_output
args.no_download = False
args.repo_id = None
args.repo_version = repo_version
args.handler_path = handler_path
args.debug = False
Expand Down Expand Up @@ -250,16 +251,20 @@ def test_skip_download_success() -> None:
assert result is True


def custom_model_setup() -> None:
def custom_model_setup(download_model: bool = True) -> None:
"""
This function is used to setup custom model case.
It runs download.py to download model files and
deletes the contents of 'model_config.json' after
making a backup.

Args:
download (bool): Set to download model files (defaults to True)
"""
download_setup()
args = set_generate_args()
download.run_script(args)
if download_model:
args = set_generate_args()
download.run_script(args)

# creating a backup of original model_config.json
copy_file(MODEL_CONFIG_PATH, MODEL_TEMP_CONFIG_PATH)
Expand All @@ -277,9 +282,9 @@ def custom_model_restore() -> None:
cleanup_folders()


def test_custom_model_success() -> None:
def test_custom_model_skip_download_success() -> None:
"""
This function tests the custom model case.
This function tests the no download custom model case.
This is done by clearing the 'model_config.json' and
generating the 'GPT2' MAR file.
Expected result: Success.
Expand All @@ -296,6 +301,62 @@ def test_custom_model_success() -> None:
custom_model_restore()


def test_custom_model_download_success() -> None:
    """
    Verify the custom-model download path succeeds.

    Clears 'model_config.json' (via setup) and generates the 'GPT2'
    MAR file by downloading the model through its HuggingFace repo_id.
    Expected result: Success.
    """
    custom_model_setup(download_model=False)
    generate_args = set_generate_args()
    generate_args.repo_id = "gpt2"
    try:
        outcome = download.run_script(generate_args)
    except SystemExit:
        # A clean run must not exit; any SystemExit is a failure.
        assert False
    assert outcome is True
    custom_model_restore()


def test_custom_model_download_wrong_repo_id_throw_error() -> None:
    """
    Verify the custom-model download path rejects a bad repo_id.

    Runs the download script with an invalid HuggingFace repo_id.
    Expected result: Failure (exit code 1).
    """
    custom_model_setup(download_model=False)
    generate_args = set_generate_args()
    generate_args.repo_id = "wrong_repo_id"
    exited = False
    try:
        download.run_script(generate_args)
    except SystemExit as err:
        # The script must bail out with exit code 1 on a bad repo_id.
        exited = True
        assert err.code == 1
    assert exited
    custom_model_restore()


def test_custom_model_download_wrong_repo_version_throw_error() -> None:
    """
    Verify the custom-model download path rejects a bad repo_version.

    Uses a valid repo_id ('gpt2') together with an invalid
    repo_version commit identifier.
    Expected result: Failure (exit code 1).
    """
    custom_model_setup(download_model=False)
    generate_args = set_generate_args()
    generate_args.repo_id = "gpt2"
    generate_args.repo_version = "wrong_repo_version"
    exited = False
    try:
        download.run_script(generate_args)
    except SystemExit as err:
        # The script must bail out with exit code 1 on a bad revision.
        exited = True
        assert err.code == 1
    assert exited
    custom_model_restore()


# Run the tests
if __name__ == "__main__":
pytest.main(["-v", __file__])
24 changes: 22 additions & 2 deletions llm/tests/test_torchserve_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,9 +201,9 @@ def test_inference_json_file_success() -> None:
assert False


def test_custom_model_success() -> None:
def test_custom_model_skip_download_success() -> None:
"""
This function tests custom model with input folder.
This function tests custom model skipping download with input folder.
Expected result: Success.
"""
custom_model_setup()
Expand All @@ -221,6 +221,26 @@ def test_custom_model_success() -> None:
process = subprocess.run(["python3", "cleanup.py"], check=False)


def test_custom_model_download_success() -> None:
    """
    Verify end-to-end TorchServe run for a downloaded custom model.

    Downloads the 'gpt2' custom model, then launches the run command
    against the prepared input folder and checks it exits cleanly.
    Expected result: Success.
    """
    custom_model_setup(download_model=False)
    generate_args = set_generate_args()
    generate_args.repo_id = "gpt2"
    try:
        download.run_script(generate_args)
    except SystemExit:
        # Download must not exit; any SystemExit is a failure.
        assert False

    run_proc = subprocess.run(get_run_cmd(input_path=INPUT_PATH), check=False)
    assert run_proc.returncode == 0

    custom_model_restore()
    # Best-effort cleanup; exit status intentionally ignored (check=False).
    subprocess.run(["python3", "cleanup.py"], check=False)


# Run the tests
if __name__ == "__main__":
pytest.main(["-v", __file__])
Loading