Skip to content

Commit

Permalink
[deepsparse.benchmark] enable internal kv cache by default (#1335)
Browse files Browse the repository at this point in the history
* [deepsparse.benchmark] enable internal kv cache by default

* remove requirement for sequence length to be set to run in kv cache mode

* add option to disable all kv cache overrides

* add store_true

* argparse fix
  • Loading branch information
bfineran authored Oct 19, 2023
1 parent 869af57 commit eb114ba
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 6 deletions.
28 changes: 22 additions & 6 deletions src/deepsparse/benchmark/benchmark_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@
from deepsparse.utils import (
generate_random_inputs,
has_model_kv_cache,
infer_sequence_length,
model_to_path,
override_onnx_input_shapes,
overwrite_onnx_model_inputs_for_kv_cache_models,
Expand Down Expand Up @@ -268,12 +269,13 @@ def parse_args():
),
)
parser.add_argument(
"--internal-kv-cache",
"--internal_kv_cache",
"--no-internal-kv-cache",
"--no_internal_kv_cache",
help=(
"DeepSparse engine only - If True, and a model with KV cache, "
"DeepSparse engine only - If not present, and model has KV cache, "
"KV Cache state will be managed within the compiled deepsparse "
"model. This is preferred when applicable for best performance"
"model. This is preferred when applicable for best performance. Set "
"flag to disable"
),
action="store_true",
default=False,
Expand All @@ -292,6 +294,16 @@ def parse_args():
type=str,
default=None,
)
parser.add_argument(
"--disable-kv-cache-overrides",
"--disable_kv_cache_overrides",
help=(
"If set, deepsparse.benchmark will not alter the model "
"with kv cache overrides"
),
action="store_true",
default=False,
)

return parser.parse_args()

Expand Down Expand Up @@ -328,6 +340,7 @@ def benchmark_model(
internal_kv_cache: bool = False,
quiet: bool = False,
export_path: Optional[str] = None,
disable_kv_cache_overrides: bool = False,
) -> Dict:
if quiet:
set_logging_level(logging.WARN)
Expand All @@ -345,7 +358,9 @@ def benchmark_model(
model_path = model_to_path(model_path)

cached_outputs = None
if sequence_length and input_ids_length and has_model_kv_cache(model_path):
if not disable_kv_cache_overrides and has_model_kv_cache(model_path):
if not sequence_length:
sequence_length = infer_sequence_length(model_path)
if input_ids_length > sequence_length:
raise ValueError(
f"input_ids_length: {input_ids_length} "
Expand Down Expand Up @@ -474,9 +489,10 @@ def main():
input_ids_length=args.input_ids_length,
thread_pinning=args.thread_pinning,
engine=args.engine,
internal_kv_cache=args.internal_kv_cache,
internal_kv_cache=not args.no_internal_kv_cache,
quiet=args.quiet,
export_path=args.export_path,
disable_kv_cache_overrides=args.disable_kv_cache_overrides,
)

# Results summary
Expand Down
22 changes: 22 additions & 0 deletions src/deepsparse/utils/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
"truncate_onnx_embedding_model",
"overwrite_onnx_model_inputs_for_kv_cache_models",
"default_cached_outputs",
"infer_sequence_length",
"has_model_kv_cache",
"CACHE_INPUT_PREFIX",
"CACHE_OUTPUT_PREFIX",
Expand Down Expand Up @@ -592,3 +593,24 @@ def has_model_kv_cache(model: Union[str, ModelProto]) -> bool:
:return True if the model has a KV cache support, False otherwise.
"""
return bool(any(default_cached_outputs(model)))


def infer_sequence_length(model: Union[str, ModelProto]) -> int:
    """
    Infer the sequence length of an ONNX model from its input shapes.

    Looks for a graph input named ``"attention_mask"`` and returns the size
    of its second dimension; if no such input exists, falls back to the
    first graph input.

    :param model: file path to an ONNX model, or a loaded ModelProto
    :return: inferred sequence length of the model, or 0 if it cannot
        be determined (e.g. dynamic/missing dims)
    """
    if not isinstance(model, ModelProto):
        # weights are not needed to inspect input shapes
        model = onnx.load(model, load_external_data=False)

    # try to find the attention mask input, default to the first input
    target_input_idx = 0
    for idx, inp in enumerate(model.graph.input):
        if inp.name == "attention_mask":
            target_input_idx = idx
            break  # ONNX input names are unique; stop at the first match
    try:
        # sequence length is assumed to be the second dim — TODO confirm
        # (e.g. attention_mask of shape (batch, seq_len)) when static
        target_input = model.graph.input[target_input_idx]
        return target_input.type.tensor_type.shape.dim[1].dim_value
    except Exception:
        return 0  # unable to infer seq len

0 comments on commit eb114ba

Please sign in to comment.