From f457b596a801f801ce8992f4665cf85fe9dec74f Mon Sep 17 00:00:00 2001
From: sailesh
Date: Wed, 29 Nov 2023 10:20:17 +0000
Subject: [PATCH 1/4] Download Custom HuggingFace Models and Progress Bar

---
 llm/download.py                  | 114 ++++++++++++++++++-------------
 llm/tests/test_download.py       |  73 ++++++++++++++++++--
 llm/tests/test_torchserve_run.py |  24 ++++++-
 llm/utils/generate_data_model.py |  42 ++++++++++++
 llm/utils/inference_utils.py     |  21 ++++--
 llm/utils/marsgen.py             |  82 ++++++++++++++++++++--
 llm/utils/tsutils.py             |   6 +-
 7 files changed, 291 insertions(+), 71 deletions(-)

diff --git a/llm/download.py b/llm/download.py
index 2472804..e77c676 100644
--- a/llm/download.py
+++ b/llm/download.py
@@ -15,7 +15,6 @@
 import uuid
 from typing import List
 import huggingface_hub as hfh
-from huggingface_hub.utils import HfHubHTTPError
 from utils.marsgen import get_mar_name, generate_mars
 from utils.system_utils import (
     check_if_path_exists,
@@ -176,26 +175,6 @@ def read_config_for_download(gen_model: GenerateDataModel) -> GenerateDataModel:
                     "repo_version"
                 ]
 
-            # Make sure there is HF hub token for LLAMA(2)
-            if (
-                gen_model.repo_info.repo_id.startswith("meta-llama")
-                and gen_model.repo_info.hf_token is None
-            ):
-                print(
-                    "## Error: HuggingFace Hub token is required for llama download."
-                    " Please specify it using --hf_token=. "
-                    "Refer https://huggingface.co/docs/hub/security-tokens"
-                )
-                sys.exit(1)
-
-            # Validate downloaded files
-            hf_api = hfh.HfApi()
-            hf_api.list_repo_commits(
-                repo_id=gen_model.repo_info.repo_id,
-                revision=gen_model.repo_info.repo_version,
-                token=gen_model.repo_info.hf_token,
-            )
-
             # Read handler file name
             if not gen_model.mar_utils.handler_path:
                 gen_model.mar_utils.handler_path = os.path.join(
@@ -203,38 +182,64 @@ def read_config_for_download(gen_model: GenerateDataModel) -> GenerateDataModel:
                     models[gen_model.model_name]["handler"],
                 )
 
-        except (KeyError, HfHubHTTPError):
+            # Validate hf_token
+            gen_model.validate_hf_token()
+
+            # Validate repository info
+            gen_model.get_latest_commit_id()
+
+        except KeyError:
             print(
-                "## Error: Please check either repo_id, repo_version"
-                " or HuggingFace ID is not correct\n"
+                "## There seems to be an error in the model_config.json file. "
+                "Please check the same."
             )
             sys.exit(1)
+
     else:  # Custom model case
-        if not gen_model.skip_download:
-            print(
-                "## Please check your model name,"
-                " it should be one of the following : "
-            )
-            print(list(models.keys()))
-            print(
-                "\n## If you want to use custom model files,"
-                " use the '--no_download' argument"
-            )
-            sys.exit(1)
+        gen_model.is_custom_model = True
+        if gen_model.skip_download:
+            if check_if_folder_empty(gen_model.mar_utils.model_path):
+                print("## Error: The given model path folder is empty\n")
+                sys.exit(1)
 
-        if check_if_folder_empty(gen_model.mar_utils.model_path):
-            print("## Error: The given model path folder is empty\n")
-            sys.exit(1)
+            if not gen_model.repo_info.repo_version:
+                gen_model.repo_info.repo_version = "1.0"
+
+        else:
+            if not gen_model.repo_info.repo_id:
+                print(
+                    "## If you want to create a model archive file with the supported models, "
+                    "make sure your model name is present in the list below: "
+                )
+                print(list(models.keys()))
+                print(
+                    "\nIf you want to create a model archive file for"
+                    " a custom model, there are two methods:\n"
+                    "1. If you have already downloaded the custom model"
+                    " files, please include"
+                    " the --no_download flag and provide the model_path "
+                    "directory which contains the model files.\n"
+                    "2. If you need to download the model files, provide "
+                    "the HuggingFace repository ID using 'repo_id'"
+                    " along with an empty model_path directory where the "
+                    "model files will be downloaded.\n"
+                )
+                sys.exit(1)
+
+            # Validate hf_token
+            gen_model.validate_hf_token()
+
+            # Validate repository info
+            lastest_commit_id = gen_model.get_latest_commit_id()
+
+            if not gen_model.repo_info.repo_version:
+                gen_model.repo_info.repo_version = lastest_commit_id
 
         if not gen_model.mar_utils.handler_path:
             gen_model.mar_utils.handler_path = os.path.join(
                 os.path.dirname(__file__), "handler.py"
             )
 
-        if not gen_model.repo_info.repo_version:
-            gen_model.repo_info.repo_version = "1.0"
-
-        gen_model.is_custom_model = True
         print(
             f"\n## Generating MAR file for "
             f"custom model files: {gen_model.model_name}"
@@ -260,7 +265,9 @@ def run_download(gen_model: GenerateDataModel) -> GenerateDataModel:
         GenerateDataModel: An instance of the GenerateDataModel class.
     """
     if not check_if_folder_empty(gen_model.mar_utils.model_path):
-        print("## Make sure the path provided to download model files is empty\n")
+        print(
+            "## Make sure the model_path provided to download model files is empty\n"
+        )
         sys.exit(1)
 
     print(
@@ -290,7 +297,9 @@ def create_mar(gen_model: GenerateDataModel) -> None:
     Args:
         gen_model (GenerateDataModel): An instance of the GenerateDataModel dataclass
     """
-    if not gen_model.is_custom_model and not check_if_model_files_exist(gen_model):
+    if not (
+        gen_model.is_custom_model and gen_model.skip_download
+    ) and not check_if_model_files_exist(gen_model):
         print("## Model files do not match HuggingFace repository files")
         sys.exit(1)
 
@@ -347,13 +356,20 @@
         type=str,
         default="",
         required=True,
-        metavar="mn",
+        metavar="n",
         help="Name of model",
     )
+    parser.add_argument(
+        "--repo_id",
+        type=str,
+        default=None,
+        metavar="ri",
+        help="HuggingFace repository ID (in case of custom model download)",
+    )
     parser.add_argument(
         "--repo_version",
         type=str,
-        default="",
+        default=None,
         metavar="rv",
         help="Commit ID of models repo from HuggingFace repository",
     )
@@ -367,7 +383,7 @@
         type=str,
         default="",
         required=True,
-        metavar="mp",
+        metavar="p",
         help="Absolute path of model files (should be empty if downloading)",
     )
     parser.add_argument(
@@ -375,7 +391,7 @@
         type=str,
         default="",
         required=True,
-        metavar="mx",
+        metavar="a",
        help="Absolute path of exported MAR file (.mar)",
     )
     parser.add_argument(
@@ -389,7 +405,7 @@
         "--hf_token",
         type=str,
         default=None,
-        metavar="hft",
+        metavar="ht",
         help="HuggingFace Hub token to download LLAMA(2) models",
     )
     parser.add_argument("--debug", action="store_true", help="flag to debug")
diff --git a/llm/tests/test_download.py b/llm/tests/test_download.py
index 69f99bc..50089e1 100644
--- a/llm/tests/test_download.py
+++ b/llm/tests/test_download.py
@@ -60,7 +60,7 @@ def cleanup_folders() -> None:
 
 def set_generate_args(
     model_name: str = MODEL_NAME,
-    repo_version: str = "",
+    repo_version: str = None,
     model_path: str = MODEL_PATH,
     mar_output: str = MODEL_STORE,
     handler_path: str = "",
@@ -82,6 +82,7 @@
     args.model_path = model_path
     args.mar_output = mar_output
     args.no_download = False
+    args.repo_id = None
     args.repo_version = repo_version
     args.handler_path = handler_path
     args.debug = False
@@ -250,16 +251,20 @@ def test_skip_download_success() -> None:
assert result is True -def custom_model_setup() -> None: +def custom_model_setup(download: bool = True) -> None: """ This function is used to setup custom model case. It runs download.py to download model files and deletes the contents of 'model_config.json' after making a backup. + + Args: + download (bool): Set to download model files (defaults to True) """ download_setup() - args = set_generate_args() - download.run_script(args) + if download: + args = set_generate_args() + download.run_script(args) # creating a backup of original model_config.json copy_file(MODEL_CONFIG_PATH, MODEL_TEMP_CONFIG_PATH) @@ -277,9 +282,9 @@ def custom_model_restore() -> None: cleanup_folders() -def test_custom_model_success() -> None: +def test_custom_model_skip_download_success() -> None: """ - This function tests the custom model case. + This function tests the no download custom model case. This is done by clearing the 'model_config.json' and generating the 'GPT2' MAR file. Expected result: Success. @@ -296,6 +301,62 @@ def test_custom_model_success() -> None: custom_model_restore() +def test_custom_model_download_success() -> None: + """ + This function tests the download custom model case. + This is done by clearing the 'model_config.json' and + generating the 'GPT2' MAR file. + Expected result: Success. + """ + custom_model_setup(download = False) + args = set_generate_args() + args.repo_id = "gpt2" + try: + result = download.run_script(args) + except SystemExit: + assert False + else: + assert result is True + custom_model_restore() + + +def test_custom_model_download_wrong_repo_id_throw_error() -> None: + """ + This function tests the download custom model case and + passes a wrong repo_id. + Expected result: Failure. + """ + custom_model_setup(download = False) + args = set_generate_args() + args.repo_id = "wrong_repo_id" + try: + download.run_script(args) + except SystemExit as e: + assert e.code == 1 + else: + assert False + custom_model_restore() + + +def test_custom_model_download_wrong_repo_version_throw_error() -> None: + """ + This function tests the download custom model case and + passes a correct repo_id but wrong repo_version. + Expected result: Failure. + """ + custom_model_setup(download = False) + args = set_generate_args() + args.repo_id = "gpt2" + args.repo_version = "wrong_repo_version" + try: + download.run_script(args) + except SystemExit as e: + assert e.code == 1 + else: + assert False + custom_model_restore() + + # Run the tests if __name__ == "__main__": pytest.main(["-v", __file__]) diff --git a/llm/tests/test_torchserve_run.py b/llm/tests/test_torchserve_run.py index e044210..dd4d7a7 100644 --- a/llm/tests/test_torchserve_run.py +++ b/llm/tests/test_torchserve_run.py @@ -201,9 +201,9 @@ def test_inference_json_file_success() -> None: assert False -def test_custom_model_success() -> None: +def test_custom_model_skip_download_success() -> None: """ - This function tests custom model with input folder. + This function tests custom model skipping download with input folder. Expected result: Success. """ custom_model_setup() @@ -221,6 +221,26 @@ def test_custom_model_success() -> None: process = subprocess.run(["python3", "cleanup.py"], check=False) +def test_custom_model_download_success() -> None: + """ + This function tests download custom model input folder. + Expected result: Success. 
+    """
+    custom_model_setup(download=False)
+    args = set_generate_args()
+    args.repo_id = "gpt2"
+    try:
+        download.run_script(args)
+    except SystemExit:
+        assert False
+
+    process = subprocess.run(get_run_cmd(input_path=INPUT_PATH), check=False)
+    assert process.returncode == 0
+
+    custom_model_restore()
+    process = subprocess.run(["python3", "cleanup.py"], check=False)
+
+
 # Run the tests
 if __name__ == "__main__":
     pytest.main(["-v", __file__])
diff --git a/llm/utils/generate_data_model.py b/llm/utils/generate_data_model.py
index 3a951d7..7b7db88 100644
--- a/llm/utils/generate_data_model.py
+++ b/llm/utils/generate_data_model.py
@@ -7,6 +7,8 @@
 import os
 import dataclasses
 import sys
+import huggingface_hub as hfh
+from huggingface_hub.utils import HfHubHTTPError, HFValidationError
 
 
 @dataclasses.dataclass
@@ -86,6 +88,7 @@ def set_values(self, params: argparse.Namespace) -> None:
 
         self.debug = params.debug
         self.repo_info.hf_token = params.hf_token
+        self.repo_info.repo_id = params.repo_id
         self.repo_info.repo_version = params.repo_version
 
         self.mar_utils.handler_path = params.handler_path
@@ -107,3 +110,42 @@ def check_if_mar_exists(self) -> None:
                 f"{self.repo_info.repo_version}\n"
             )
             sys.exit(1)
+
+    def validate_hf_token(self) -> None:
+        """
+        This method makes sure a HuggingFace token is provided
+        for the Meta models (Llama models)
+        """
+        # Make sure there is HF hub token for LLAMA(2)
+        if (
+            self.repo_info.repo_id.startswith("meta-llama")
+            and self.repo_info.hf_token is None
+        ):
+            print(
+                "## Error: HuggingFace Hub token is required for llama download."
+                " Please specify it using --hf_token=. "
+                "Refer https://huggingface.co/docs/hub/security-tokens"
+            )
+            sys.exit(1)
+
+    def get_latest_commit_id(self) -> None:
+        """
+        This method validates the HuggingFace repository information and
+        gets the latest commit ID of the model.
+        """
+        # Validate downloaded files
+        try:
+            hf_api = hfh.HfApi()
+            commit_info = hf_api.list_repo_commits(
+                repo_id=self.repo_info.repo_id,
+                revision=self.repo_info.repo_version,
+                token=self.repo_info.hf_token,
+            )
+            self.repo_info.repo_version = commit_info[0].commit_id
+
+        except (HfHubHTTPError, HFValidationError):
+            print(
+                "## Error: Please check either repo_id, repo_version"
+                " or HuggingFace ID is not correct\n"
+            )
+            sys.exit(1)
diff --git a/llm/utils/inference_utils.py b/llm/utils/inference_utils.py
index e7934cd..a61fd61 100644
--- a/llm/utils/inference_utils.py
+++ b/llm/utils/inference_utils.py
@@ -7,6 +7,7 @@
 import traceback
 from typing import List, Dict
 import json
+import tqdm
 import requests
 import utils.tsutils as ts
 import utils.system_utils as su
@@ -41,9 +42,10 @@ def start_ts_server(ts_data: TorchserveStartData, debug: bool) -> None:
         sys.exit(1)
 
 
-def ts_health_check(model_name: str, model_timeout: int = 1200) -> None:
+def ts_health_check(model_name: str, model_timeout: int = 1500) -> None:
     """
-    This function checks if the model is registered or not.
+    This function checks if the model is registered or not. It also displays
+    a progress bar while waiting for the model to be registered.
 
     Args:
         model_name (str): The name of the model that is being registered.
         deploy_name (str): The name of the server where the model is registered.
@@ -55,6 +57,13 @@ def ts_health_check(model_name: str, model_timeout: int = 1200) -> None: retry_count = 0 sleep_time = 15 success = False + total_tries = int(model_timeout / sleep_time) + progress_bar = tqdm.tqdm( + total=total_tries, + unit="check", + desc="Waiting for Model to be ready", + bar_format="{desc}: |{bar}| {n_fmt}/{total_fmt} checks", + ) while not success and retry_count * sleep_time < model_timeout: try: success = ts.run_health_check(model_name) @@ -63,11 +72,15 @@ def ts_health_check(model_name: str, model_timeout: int = 1200) -> None: if not success: time.sleep(sleep_time) retry_count += 1 + progress_bar.update(1) if success: - print("## Health check passed. Model registered.\n") + progress_bar.update(total_tries - retry_count) + progress_bar.close() + print("\n## Health check passed. Model registered.\n") else: + progress_bar.close() print( - f"## Failed health check after multiple retries for model - {model_name} \n" + f"\n## Failed health check after multiple retries for model - {model_name} \n" ) sys.exit(1) diff --git a/llm/utils/marsgen.py b/llm/utils/marsgen.py index d62376b..e0c76cc 100644 --- a/llm/utils/marsgen.py +++ b/llm/utils/marsgen.py @@ -6,8 +6,11 @@ """ import os import sys +import time +import threading import subprocess -from typing import Dict +from typing import List, Dict +import tqdm from utils.system_utils import check_if_path_exists, get_all_files_in_directory from utils.generate_data_model import GenerateDataModel @@ -15,6 +18,61 @@ MAR_NAME_LEN = 7 +def monitor_marfile_size( + file_path: str, approx_marfile_size: float, stop_monitoring: threading.Event +) -> None: + """ + Monitor the generation of a Model Archive File and display progress. + + Args: + file_path (str): The path to the Model Archive File. + approx_marfile_size (int): The approximate size of the Model Archive File in bytes. + stop_monitoring (threading.Event): Threading Event to stop progress bar. + """ + print("Model Archive File is Generating...\n") + previous_file_size = 0 + progress_bar = tqdm.tqdm( + total=approx_marfile_size, + unit="B", + unit_scale=True, + desc="Creating Model Archive", + ) + while not stop_monitoring.is_set(): + try: + current_file_size = os.path.getsize(file_path) + except FileNotFoundError: + current_file_size = 0 + size_change = current_file_size - previous_file_size + previous_file_size = current_file_size + progress_bar.update(size_change) + time.sleep(2) + progress_bar.update(approx_marfile_size - current_file_size) + progress_bar.close() + print( + f"\nModel Archive file size: {os.path.getsize(file_path) / (1024 ** 3):.2f} GB\n" + ) + + +def get_files_sizes(file_paths: List) -> float: + """ + Calculate the total size of the specified files. + + Args: + file_paths (list): A list of file paths for which the sizes should be calculated. + + Returns: + total_size (float): The sum of sizes (in bytes) of all the specified files. + """ + total_size = 0 + for file_path in file_paths: + try: + size = os.path.getsize(file_path) + total_size += size + except FileNotFoundError: + print(f"File not found: {file_path}") + return total_size + + def get_mar_name( model_name: str, repo_version: str, is_custom_model: str = False ) -> str: @@ -45,7 +103,8 @@ def generate_mars( ) -> None: """ This function runs Torch Model Archiver command to generate MAR file. It calls the - model_archiver_command_builder function to generate the command which it then runs + model_archiver_command_builder function to generate the command which it then runs. 
+ It also starts a thread for the progress bar of Model Archive file generation. Args: gen_model (GenerateDataModel): Dataclass that contains data required to generate MAR file. @@ -90,12 +149,21 @@ def generate_mars( print(f"## In directory: {os.getcwd()} | Executing command: {cmd}\n") try: + stop_monitoring = threading.Event() + approx_marfile_size = get_files_sizes(extra_files_list) / 1.15 + mar_progress_thread = threading.Thread( + target=monitor_marfile_size, + args=( + os.path.join(model_store_dir, f"{gen_model.model_name}.mar"), + approx_marfile_size, + stop_monitoring, + ), + ) + mar_progress_thread.start() subprocess.check_call(cmd, shell=True) - if debug: - print( - f"## Model {gen_model.model_name} with version " - f"{gen_model.repo_info.repo_version} is generated.\n" - ) + stop_monitoring.set() + mar_progress_thread.join() + print(f"## {gen_model.model_name}.mar is generated.\n") except subprocess.CalledProcessError as exc: print("## Creation failed !\n") if debug: diff --git a/llm/utils/tsutils.py b/llm/utils/tsutils.py index 0dd6407..7880f81 100644 --- a/llm/utils/tsutils.py +++ b/llm/utils/tsutils.py @@ -52,11 +52,11 @@ def generate_ts_start_cmd(ts_data: TorchserveStartData, ncs: bool, debug: bool) if ts_data.ts_log_config: cmd += f" --log-config {ts_data.ts_log_config}" if ts_data.ts_log_file: - print(f"## Console logs redirected to file: {ts_data.ts_log_file} \n") - dirpath = os.path.dirname(ts_data.ts_log_file) - cmd += f" >> {os.path.join(dirpath,ts_data.ts_log_file)}" + cmd += f" >> {ts_data.ts_log_file}" if debug: print(f"## In directory: {os.getcwd()} | Executing command: {cmd} \n") + log_file = os.path.join(os.path.dirname(ts_data.ts_log_file), "ts_log.log") + print(f"## TorchServe Inference Server logs can be found at: {log_file} \n") return cmd From 6dea92ff3299b637221ebad09c04fedb186bfe9b Mon Sep 17 00:00:00 2001 From: sailesh Date: Wed, 29 Nov 2023 10:24:33 +0000 Subject: [PATCH 2/4] Minor linting fix --- llm/tests/test_download.py | 10 +++++----- llm/tests/test_torchserve_run.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/llm/tests/test_download.py b/llm/tests/test_download.py index 50089e1..e8f3538 100644 --- a/llm/tests/test_download.py +++ b/llm/tests/test_download.py @@ -251,7 +251,7 @@ def test_skip_download_success() -> None: assert result is True -def custom_model_setup(download: bool = True) -> None: +def custom_model_setup(download_model: bool = True) -> None: """ This function is used to setup custom model case. It runs download.py to download model files and @@ -262,7 +262,7 @@ def custom_model_setup(download: bool = True) -> None: download (bool): Set to download model files (defaults to True) """ download_setup() - if download: + if download_model: args = set_generate_args() download.run_script(args) @@ -308,7 +308,7 @@ def test_custom_model_download_success() -> None: generating the 'GPT2' MAR file. Expected result: Success. """ - custom_model_setup(download = False) + custom_model_setup(download_model=False) args = set_generate_args() args.repo_id = "gpt2" try: @@ -326,7 +326,7 @@ def test_custom_model_download_wrong_repo_id_throw_error() -> None: passes a wrong repo_id. Expected result: Failure. """ - custom_model_setup(download = False) + custom_model_setup(download_model=False) args = set_generate_args() args.repo_id = "wrong_repo_id" try: @@ -344,7 +344,7 @@ def test_custom_model_download_wrong_repo_version_throw_error() -> None: passes a correct repo_id but wrong repo_version. Expected result: Failure. 
""" - custom_model_setup(download = False) + custom_model_setup(download_model=False) args = set_generate_args() args.repo_id = "gpt2" args.repo_version = "wrong_repo_version" diff --git a/llm/tests/test_torchserve_run.py b/llm/tests/test_torchserve_run.py index dd4d7a7..3a4aec1 100644 --- a/llm/tests/test_torchserve_run.py +++ b/llm/tests/test_torchserve_run.py @@ -226,7 +226,7 @@ def test_custom_model_download_success() -> None: This function tests download custom model input folder. Expected result: Success. """ - custom_model_setup(download=False) + custom_model_setup(download_model=False) args = set_generate_args() args.repo_id = "gpt2" try: From 923bbc437a4b19a5fe0ebf9c7f11ebcce5e9233b Mon Sep 17 00:00:00 2001 From: sailesh Date: Mon, 4 Dec 2023 05:49:15 +0000 Subject: [PATCH 3/4] minor repo_version change --- llm/download.py | 9 +++------ llm/utils/generate_data_model.py | 10 ++++++---- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/llm/download.py b/llm/download.py index e77c676..9ccb8ca 100644 --- a/llm/download.py +++ b/llm/download.py @@ -186,9 +186,9 @@ def read_config_for_download(gen_model: GenerateDataModel) -> GenerateDataModel: gen_model.validate_hf_token() # Validate repository info - gen_model.get_latest_commit_id() + gen_model.validate_commit_info() - except KeyError: + except (KeyError, ValueError): print( "## There seems to be an error in the model_config.json file. " "Please check the same." @@ -230,10 +230,7 @@ def read_config_for_download(gen_model: GenerateDataModel) -> GenerateDataModel: gen_model.validate_hf_token() # Validate repository info - lastest_commit_id = gen_model.get_latest_commit_id() - - if not gen_model.repo_info.repo_version: - gen_model.repo_info.repo_version = lastest_commit_id + gen_model.validate_commit_info() if not gen_model.mar_utils.handler_path: gen_model.mar_utils.handler_path = os.path.join( diff --git a/llm/utils/generate_data_model.py b/llm/utils/generate_data_model.py index 7b7db88..19e1273 100644 --- a/llm/utils/generate_data_model.py +++ b/llm/utils/generate_data_model.py @@ -128,12 +128,11 @@ def validate_hf_token(self) -> None: ) sys.exit(1) - def get_latest_commit_id(self) -> None: + def validate_commit_info(self) -> str: """ This method validates the HuggingFace repository information and - gets the latest commit ID of the model. + sets the latest commit ID of the model if repo_version is None. 
""" - # Validate downloaded files try: hf_api = hfh.HfApi() commit_info = hf_api.list_repo_commits( @@ -141,7 +140,10 @@ def get_latest_commit_id(self) -> None: revision=self.repo_info.repo_version, token=self.repo_info.hf_token, ) - self.repo_info.repo_version = commit_info[0].commit_id + + # Set repo_version to latest commit ID if it is None + if not self.repo_info.repo_version: + self.repo_info.repo_version = commit_info[0].commit_id except (HfHubHTTPError, HFValidationError): print( From cbb87ad619dc86bd7b5df135959d56c80ad662f6 Mon Sep 17 00:00:00 2001 From: sailesh Date: Mon, 4 Dec 2023 07:16:49 +0000 Subject: [PATCH 4/4] added comments and moved get_files_sizes --- llm/utils/inference_utils.py | 2 ++ llm/utils/marsgen.py | 33 +++++++++++---------------------- llm/utils/system_utils.py | 21 +++++++++++++++++++++ 3 files changed, 34 insertions(+), 22 deletions(-) diff --git a/llm/utils/inference_utils.py b/llm/utils/inference_utils.py index a61fd61..eb35910 100644 --- a/llm/utils/inference_utils.py +++ b/llm/utils/inference_utils.py @@ -58,6 +58,8 @@ def ts_health_check(model_name: str, model_timeout: int = 1500) -> None: sleep_time = 15 success = False total_tries = int(model_timeout / sleep_time) + + # health check progress bar progress_bar = tqdm.tqdm( total=total_tries, unit="check", diff --git a/llm/utils/marsgen.py b/llm/utils/marsgen.py index e0c76cc..a859d02 100644 --- a/llm/utils/marsgen.py +++ b/llm/utils/marsgen.py @@ -9,9 +9,13 @@ import time import threading import subprocess -from typing import List, Dict +from typing import Dict import tqdm -from utils.system_utils import check_if_path_exists, get_all_files_in_directory +from utils.system_utils import ( + check_if_path_exists, + get_all_files_in_directory, + get_files_sizes, +) from utils.generate_data_model import GenerateDataModel # MAR_NAME_LEN - Number of characters to include from repo_version in MAR name @@ -53,26 +57,6 @@ def monitor_marfile_size( ) -def get_files_sizes(file_paths: List) -> float: - """ - Calculate the total size of the specified files. - - Args: - file_paths (list): A list of file paths for which the sizes should be calculated. - - Returns: - total_size (float): The sum of sizes (in bytes) of all the specified files. - """ - total_size = 0 - for file_path in file_paths: - try: - size = os.path.getsize(file_path) - total_size += size - except FileNotFoundError: - print(f"File not found: {file_path}") - return total_size - - def get_mar_name( model_name: str, repo_version: str, is_custom_model: str = False ) -> str: @@ -149,8 +133,13 @@ def generate_mars( print(f"## In directory: {os.getcwd()} | Executing command: {cmd}\n") try: + # Event to stop the thread from monitoring output file size. stop_monitoring = threading.Event() + + # Approximate size of output Model Archive file. approx_marfile_size = get_files_sizes(extra_files_list) / 1.15 + + # Creating a thread to monitor MAR file size while generation and show progress bar mar_progress_thread = threading.Thread( target=monitor_marfile_size, args=( diff --git a/llm/utils/system_utils.py b/llm/utils/system_utils.py index e2c9270..ab9411d 100644 --- a/llm/utils/system_utils.py +++ b/llm/utils/system_utils.py @@ -6,6 +6,7 @@ """ import os import sys +from typing import List from pathlib import Path nvidia_smi_cmd = { @@ -89,3 +90,23 @@ def get_all_files_in_directory(directory): if file.is_file() ] return output + + +def get_files_sizes(file_paths: List) -> float: + """ + Calculate the total size of the specified files. 
+ + Args: + file_paths (list): A list of file paths for which the sizes should be calculated. + + Returns: + total_size (float): The sum of sizes (in bytes) of all the specified files. + """ + total_size = 0 + for file_path in file_paths: + try: + size = os.path.getsize(file_path) + total_size += size + except FileNotFoundError: + print(f"File not found: {file_path}") + return total_size
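
As a usage illustration (the model name and paths below are placeholders, not values taken
from the patches), the custom-model download flow added in PATCH 1/4 can be driven
programmatically the same way the added tests drive it:

    # Sketch: build the argparse.Namespace that download.py's CLI would produce and
    # run the script directly, mirroring the added tests. All names/paths are placeholders.
    import argparse

    import download  # llm/download.py

    args = argparse.Namespace(
        model_name="my_custom_model",           # a name not present in model_config.json
        repo_id="gpt2",                         # HuggingFace repository to download
        repo_version=None,                      # None -> resolved to the latest commit ID
        no_download=False,                      # download the model files
        model_path="/abs/path/to/empty_dir",    # must be an empty directory
        mar_output="/abs/path/to/model_store",  # where the .mar file is written
        handler_path="",                        # empty -> default handler.py is used
        hf_token=None,                          # only needed for meta-llama repositories
        debug=False,
    )
    download.run_script(args)                   # downloads files and generates the MAR

The same flow is available from the command line through the new --repo_id argument together
with --model_name, --model_path and --mar_output; passing --no_download instead packages model
files that are already present in model_path.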