From eb114ba9d9a63e2e33bbb6b8ba68a42b491f0a91 Mon Sep 17 00:00:00 2001 From: Benjamin Fineran Date: Thu, 19 Oct 2023 12:45:31 -0400 Subject: [PATCH] [deepsparse.benchmark] enable internal kv cache by default (#1335) * [deepsparse.benchmark] enable internal kv cache by default * remove requirement for sequence length to be set to run in kv cache mode * add option to disable all kv cache overrides * add store_true * argparse fix --- src/deepsparse/benchmark/benchmark_model.py | 28 ++++++++++++++++----- src/deepsparse/utils/onnx.py | 22 ++++++++++++++++ 2 files changed, 44 insertions(+), 6 deletions(-) diff --git a/src/deepsparse/benchmark/benchmark_model.py b/src/deepsparse/benchmark/benchmark_model.py index 8636d0831e..16d921f6f3 100644 --- a/src/deepsparse/benchmark/benchmark_model.py +++ b/src/deepsparse/benchmark/benchmark_model.py @@ -132,6 +132,7 @@ from deepsparse.utils import ( generate_random_inputs, has_model_kv_cache, + infer_sequence_length, model_to_path, override_onnx_input_shapes, overwrite_onnx_model_inputs_for_kv_cache_models, @@ -268,12 +269,13 @@ def parse_args(): ), ) parser.add_argument( - "--internal-kv-cache", - "--internal_kv_cache", + "--no-internal-kv-cache", + "--no_internal_kv_cache", help=( - "DeepSparse engine only - If True, and a model with KV cache, " + "DeepSparse engine only - If not present, and model has KV cache, " "KV Cache state will be managed within the compiled deepsparse " - "model. This is preferred when applicable for best performance" + "model. This is preferred when applicable for best performance. Set " + "flag to disable" ), action="store_true", default=False, @@ -292,6 +294,16 @@ def parse_args(): type=str, default=None, ) + parser.add_argument( + "--disable-kv-cache-overrides", + "--disable_kv_cache_overrides", + help=( + "If set, deepsparse.benchmark will not alter the model " + "with kv cache overrides" + ), + action="store_true", + default=False, + ) return parser.parse_args() @@ -328,6 +340,7 @@ def benchmark_model( internal_kv_cache: bool = False, quiet: bool = False, export_path: Optional[str] = None, + disable_kv_cache_overrides: bool = False, ) -> Dict: if quiet: set_logging_level(logging.WARN) @@ -345,7 +358,9 @@ def benchmark_model( model_path = model_to_path(model_path) cached_outputs = None - if sequence_length and input_ids_length and has_model_kv_cache(model_path): + if not disable_kv_cache_overrides and has_model_kv_cache(model_path): + if not sequence_length: + sequence_length = infer_sequence_length(model_path) if input_ids_length > sequence_length: raise ValueError( f"input_ids_length: {input_ids_length} " @@ -474,9 +489,10 @@ def main(): input_ids_length=args.input_ids_length, thread_pinning=args.thread_pinning, engine=args.engine, - internal_kv_cache=args.internal_kv_cache, + internal_kv_cache=not args.no_internal_kv_cache, quiet=args.quiet, export_path=args.export_path, + disable_kv_cache_overrides=args.disable_kv_cache_overrides, ) # Results summary diff --git a/src/deepsparse/utils/onnx.py b/src/deepsparse/utils/onnx.py index ca43cc112d..f5af4c0718 100644 --- a/src/deepsparse/utils/onnx.py +++ b/src/deepsparse/utils/onnx.py @@ -52,6 +52,7 @@ "truncate_onnx_embedding_model", "overwrite_onnx_model_inputs_for_kv_cache_models", "default_cached_outputs", + "infer_sequence_length", "has_model_kv_cache", "CACHE_INPUT_PREFIX", "CACHE_OUTPUT_PREFIX", @@ -592,3 +593,24 @@ def has_model_kv_cache(model: Union[str, ModelProto]) -> bool: :return True if the model has a KV cache support, False otherwise. """ return bool(any(default_cached_outputs(model))) + + +def infer_sequence_length(model: Union[str, ModelProto]) -> int: + """ + :param model: model + :return: inferred sequence length of the model + """ + if not isinstance(model, ModelProto): + model = onnx.load(model, load_external_data=False) + + # try to find attention mask dim, default to 0 + target_input_idx = 0 + for idx, inp in enumerate(model.graph.input): + if inp.name == "attention_mask": + target_input_idx = idx + try: + # return shape of second dim if possible + target_input = model.graph.input[target_input_idx] + return target_input.type.tensor_type.shape.dim[1].dim_value + except Exception: + return 0 # unable to infer seq len