From 7c3404311f7a114361ca0877906ebc3379fa86d7 Mon Sep 17 00:00:00 2001
From: angelayi
Date: Thu, 11 Apr 2024 23:20:50 -0700
Subject: [PATCH] add packaging to aoti

Add a PT2 archive packaging flow for AOTInductor: _package_aoti.py
introduces a PT2ArchiveWriter/PT2ArchiveReader pair plus aoti_compile()
and aoti_load() entry points that bundle the generated .so/.cpp/.cubin
artifacts into a single archive, _pt2_archive_constants.py codifies the
archive layout, and export_aoti.py/generate.py are switched to the new
APIs.
---
 .gitignore                |   1 +
 _package_aoti.py          | 192 ++++++++++++++++++++++++++++++++++++++
 _pt2_archive_constants.py |  36 +++++++
 export_aoti.py            |  16 ++--
 generate.py               |   5 +-
 5 files changed, 241 insertions(+), 9 deletions(-)
 create mode 100644 .gitignore
 create mode 100644 _package_aoti.py
 create mode 100644 _pt2_archive_constants.py

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 000000000..c18dd8d83
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+__pycache__/
diff --git a/_package_aoti.py b/_package_aoti.py
new file mode 100644
index 000000000..4689aa67a
--- /dev/null
+++ b/_package_aoti.py
@@ -0,0 +1,192 @@
+import glob
+import os
+import pathlib
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+import torch
+import torch._inductor
+import torch.utils._pytree as pytree
+from torch.export._tree_utils import reorder_kwargs
+from torch.export import ExportedProgram
+from torch._export.serde.serialize import deserialize, serialize, SerializedArtifact
+
+
+from _pt2_archive_constants import (
+    AOTINDUCTOR_DIR,
+    ARCHIVE_ROOT_NAME,
+    CONSTANTS_DIR,
+    MODELS_FILENAME_FORMAT,
+    SAMPLE_INPUTS_DIR,
+    WEIGHTS_DIR,
+)
+
+
+ARCHIVE_VERSION = 0
+
+class PT2ArchiveWriter:
+    def __init__(self, archive_path: str):
+        self.archive_file = torch._C.PyTorchFileWriter(archive_path)
+        self.archive_file.set_min_version(ARCHIVE_VERSION)
+        self.write_string("archive_format", "pt2")
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *args):
+        self.close()
+
+    def write_bytes(self, name: str, data: bytes) -> None:
+        assert isinstance(data, bytes), f"Expected bytes but got {type(data)}"
+        self.archive_file.write_record(name, data, len(data))
+
+    def write_string(self, name: str, data: str) -> None:
+        assert isinstance(data, str), f"Expected string but got {type(data)}"
+        data_bytes = data.encode()
+        self.write_bytes(name, data_bytes)
+
+    def write_file(self, name: str, file_path: str) -> None:
+        """
+        Copy a file into the archive.
+        name: The destination file inside the archive.
+        file_path: The source file on disk.
+        """
+        assert os.path.isfile(file_path), f"{file_path} is not a valid file path"
+
+        with open(file_path, "rb") as f:
+            file_bytes = f.read()
+            self.write_bytes(name, file_bytes)
+
+    def close(self) -> None:
+        self.archive_file.write_end_of_file()
+
+
+class PT2ArchiveReader:
+    def __init__(self, archive_path: str):
+        self.archive_file = torch._C.PyTorchFileReader(archive_path)
+        assert self.read_string("archive_format") == "pt2", "Invalid archive format"
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *args):
+        # torch._C.PyTorchFileReader doesn't have a close method
+        pass
+
+    def read_bytes(self, name: str) -> bytes:
+        return self.archive_file.get_record(name)
+
+    def read_string(self, name: str) -> str:
+        data = self.read_bytes(name)
+        return data.decode()
+
+    def get_file_names(self) -> List[str]:
+        return self.archive_file.get_all_records()
+
+
+def _package_exported_program(
+    archive_writer: PT2ArchiveWriter, exported_program: ExportedProgram
+) -> None:
+    exported_artifact: SerializedArtifact = serialize(exported_program)
+    archive_writer.write_bytes(MODELS_FILENAME_FORMAT.format("model"), exported_artifact.exported_program)
+    archive_writer.write_bytes(os.path.join(WEIGHTS_DIR, "weights.pt"), exported_artifact.state_dict)
+    archive_writer.write_bytes(os.path.join(CONSTANTS_DIR, "constants.pt"), exported_artifact.constants)
+    archive_writer.write_bytes(os.path.join(SAMPLE_INPUTS_DIR, "example_inputs.pt"), exported_artifact.example_inputs)
+
+
+def _package_aoti_files(archive_writer: PT2ArchiveWriter, so_path: str):
+    cpp_file_path = so_path[:-3] + ".cpp"
+    extern_nodes_file_path = so_path[:-3] + ".json"
+    work_dir = pathlib.Path(so_path).parent
+    cubin_file_paths = glob.glob(f"{work_dir}/*.cubin")
+
+    package_files = [so_path, cpp_file_path]
+    package_files.extend(cubin_file_paths)
+
+    if os.path.isfile(extern_nodes_file_path):
+        package_files.append(extern_nodes_file_path)
+
+    for path in package_files:
+        filename = os.path.basename(path)
+        archive_writer.write_file(f"{AOTINDUCTOR_DIR}{filename}", path)
+
+
+def _extract_exported_program(archive_reader: PT2ArchiveReader) -> ExportedProgram:
+    exported_program_bytes = archive_reader.read_bytes(MODELS_FILENAME_FORMAT.format("model"))
+    state_dict_bytes = archive_reader.read_bytes(os.path.join(WEIGHTS_DIR, "weights.pt"))
+    constants_bytes = archive_reader.read_bytes(os.path.join(CONSTANTS_DIR, "constants.pt"))
+    example_inputs_bytes = archive_reader.read_bytes(os.path.join(SAMPLE_INPUTS_DIR, "example_inputs.pt"))
+
+    artifact: SerializedArtifact = SerializedArtifact(
+        exported_program_bytes,
+        state_dict_bytes,
+        constants_bytes,
+        example_inputs_bytes,
+    )
+
+    deserialized_exported_program = deserialize(artifact)
+    return deserialized_exported_program
+
+
+def _extract_so(archive_reader: PT2ArchiveReader, device: str) -> Callable:
+    tmp_output_dir = pathlib.Path("/tmp/aotinductor_loaded_model")
+    tmp_output_dir.mkdir(exist_ok=True)
+
+    file_names = archive_reader.get_file_names()
+    aoti_files = [file for file in file_names if file.startswith(AOTINDUCTOR_DIR)]
+
+    so_path = None
+    for file in aoti_files:
+        filename = os.path.basename(file)
+        with open(tmp_output_dir / filename, 'wb') as f:
+            f.write(archive_reader.read_bytes(file))
+        if file.endswith('.so'):
+            assert so_path is None
+            so_path = tmp_output_dir / filename
+    assert so_path is not None
+    so_path = str(so_path)
+
+    if device == "cpu":
+        runner = torch._C._aoti.AOTIModelContainerRunnerCpu(so_path, 1)  # type: ignore[call-arg]
+    elif device == "cuda" or device.startswith("cuda:"):
+        runner = torch._C._aoti.AOTIModelContainerRunnerCuda(so_path, 1, device)  # type: ignore[assignment, call-arg]
+    else:
+        raise RuntimeError("Unsupported device " + device)
+
+    def optimized(*args, **kwargs):
+        call_spec = runner.get_call_spec()  # type: ignore[attr-defined]
+        in_spec = pytree.treespec_loads(call_spec[0])
+        out_spec = pytree.treespec_loads(call_spec[1])
+        flat_inputs = pytree.tree_flatten((args, reorder_kwargs(kwargs, in_spec)))[0]
+        flat_outputs = runner.run(flat_inputs)  # type: ignore[attr-defined]
+        return pytree.tree_unflatten(flat_outputs, out_spec)
+
+    return optimized
+
+
+def aoti_compile(
+    exported_program: ExportedProgram,
+    args: Tuple[Any],
+    kwargs: Optional[Dict[str, Any]] = None,
+    *,
+    options: Optional[Dict[str, Any]] = None,
+):
+    archive_path = options["aot_inductor.output_path"]
+    options["aot_inductor.output_path"] = ""
+
+    so_path = torch._inductor.aot_compile(
+        exported_program.module(), args, kwargs, options=options
+    )
+
+    with PT2ArchiveWriter(archive_path) as archive_writer:
+        # _package_exported_program(archive_writer, exported_program)
+        _package_aoti_files(archive_writer, so_path)
+
+    return archive_path
+
+
+def aoti_load(path: str, device: str):
+    with PT2ArchiveReader(path) as archive_reader:
+        # exported_program = _extract_exported_program(archive_reader)
+        optimized = _extract_so(archive_reader, device)
+
+    return optimized
diff --git a/_pt2_archive_constants.py b/_pt2_archive_constants.py
new file mode 100644
index 000000000..1c9a7743f
--- /dev/null
+++ b/_pt2_archive_constants.py
@@ -0,0 +1,36 @@
+# This file codifies the PT2 Inference Archive Spec
+# https://docs.google.com/document/d/1jLPp8MN8Whs0-VW9PmJ93Yg02W85tpujvHrTa1pc5x8/edit?usp=sharing
+
+# Naming convention
+# *_DIR: path to a folder, e.g. "data/aotinductor/"
+# *_PATH: absolute path to a file, e.g. "models/merge.json"
+# *_FORMAT: naming format of a file, e.g. "models/{}.json"
+
+ARCHIVE_ROOT_NAME: str = "package"
+
+# Archive format
+ARCHIVE_FORMAT_PATH: str = "archive_format"
+
+# Model definitions
+MODELS_DIR: str = "models/"
+MODELS_FILENAME_FORMAT: str = "models/{}.json"  # {model_name}
+
+# AOTInductor artifacts
+AOTINDUCTOR_DIR: str = "data/aotinductor/"
+
+# weights, including parameters and buffers
+WEIGHTS_DIR: str = "data/weights/"
+WEIGHT_FILENAME_PREFIX: str = "weight_"
+
+# constants, including tensor_constants, non-persistent buffers and script objects
+CONSTANTS_DIR: str = "data/constants/"
+TENSOR_CONSTANT_FILENAME_PREFIX: str = "tensor_"
+CUSTOM_OBJ_FILENAME_PREFIX: str = "custom_obj_"
+
+# sample inputs
+SAMPLE_INPUTS_DIR: str = "data/sample_inputs/"
+SAMPLE_INPUTS_FILENAME_FORMAT: str = "data/sample_inputs/{}.pt"  # {model_name}
+
+# extra folder
+EXTRA_DIR: str = "extra/"
+MODULE_INFO_PATH: str = "extra/module_info.json"
diff --git a/export_aoti.py b/export_aoti.py
index 7a5306b5b..fb83b5484 100644
--- a/export_aoti.py
+++ b/export_aoti.py
@@ -19,6 +19,8 @@
 
 from model import Transformer
 
+from _package_aoti import aoti_compile
+
 default_device = "cpu"  # 'cuda' if torch.cuda.is_available() else 'cpu'
 
 
@@ -47,11 +49,11 @@ def export_model(model: nn.Module, device, output_path, args=None):
     # Specify that the first dimension of each input is that batch size
     dynamic_shapes = {"idx": {1: seq}, "input_pos": {0: seq}}
 
-    so = torch._export.aot_compile(
-        model,
-        args=input,
-        options={"aot_inductor.output_path": output_path},
-        dynamic_shapes=dynamic_shapes,
+    ep = torch.export.export(
+        model, args=input, dynamic_shapes=dynamic_shapes,
+    )
+    package_path = aoti_compile(
+        ep, input, options={"aot_inductor.output_path": output_path}
     )
-    print(f"The generated DSO model can be found at: {so}")
-    return so
+    print(f"The generated PT2 model can be found at: {package_path}")
+    return package_path
diff --git a/generate.py b/generate.py
index 14b47bb6e..5178217a7 100644
--- a/generate.py
+++ b/generate.py
@@ -362,7 +362,8 @@ def main(
             # attributes will NOT be seen on by AOTI-compiled forward
             # function, e.g. calling model.setup_cache will NOT touch
             # AOTI compiled and maintained model buffers such as kv_cache.
-            model.forward = torch._export.aot_load(str(dso_path.absolute()), device)
+            from _package_aoti import aoti_load
+            model.forward = aoti_load(str(dso_path.absolute()), device)
         except:
             raise RuntimeError(f"Failed to load AOTI compiled {dso_path}")
     elif pte_path:
@@ -387,7 +388,7 @@ def main(
     # dtype:
     if model_dtype:
         model.to(dtype=model_dtype)
-    
+
     if is_speculative:
         draft_model = _load_model(draft_checkpoint_path, device, precision, use_tp)
    else: