[deepsparse.benchmark] enable internal kv cache by default #1335

Merged
merged 5 commits on Oct 19, 2023
28 changes: 22 additions & 6 deletions src/deepsparse/benchmark/benchmark_model.py
@@ -132,6 +132,7 @@
 from deepsparse.utils import (
     generate_random_inputs,
     has_model_kv_cache,
+    infer_sequence_length,
     model_to_path,
     override_onnx_input_shapes,
     overwrite_onnx_model_inputs_for_kv_cache_models,
@@ -268,12 +269,13 @@ def parse_args():
         ),
     )
     parser.add_argument(
-        "--internal-kv-cache",
-        "--internal_kv_cache",
+        "--no-internal-kv-cache",
+        "--no_internal_kv_cache",
         help=(
-            "DeepSparse engine only - If True, and a model with KV cache, "
+            "DeepSparse engine only - If not present, and model has KV cache, "
             "KV Cache state will be managed within the compiled deepsparse "
-            "model. This is preferred when applicable for best performance"
+            "model. This is preferred when applicable for best performance. Set "
+            "flag to disable"
         ),
         action="store_true",
         default=False,
@@ -292,6 +294,16 @@
         type=str,
         default=None,
     )
+    parser.add_argument(
+        "--disable-kv-cache-overrides",
+        "--disable_kv_cache_overrides",
+        help=(
+            "If set, deepsparse.benchmark will not alter the model "
+            "with kv cache overrides"
+        ),
+        action="store_true",
+        default=False,
+    )

     return parser.parse_args()
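Note the pattern in the hunk above: the flag is inverted rather than removed. Passing --no-internal-kv-cache stores True, and the caller negates it, so internal KV cache stays enabled by default while remaining opt-out. A self-contained sketch of that pattern using plain argparse (independent of deepsparse):

    import argparse

    parser = argparse.ArgumentParser()
    # store_true on the negative flag: flag absent -> False -> feature enabled
    parser.add_argument(
        "--no-internal-kv-cache",
        "--no_internal_kv_cache",
        action="store_true",
        default=False,
    )

    args = parser.parse_args([])  # no flag passed
    internal_kv_cache = not args.no_internal_kv_cache
    assert internal_kv_cache is True  # enabled by default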
@@ -328,6 +340,7 @@ def benchmark_model(
     internal_kv_cache: bool = False,
     quiet: bool = False,
     export_path: Optional[str] = None,
+    disable_kv_cache_overrides: bool = False,
 ) -> Dict:
     if quiet:
         set_logging_level(logging.WARN)
@@ -345,7 +358,9 @@
     model_path = model_to_path(model_path)

     cached_outputs = None
-    if sequence_length and input_ids_length and has_model_kv_cache(model_path):
+    if not disable_kv_cache_overrides and has_model_kv_cache(model_path):
+        if not sequence_length:
+            sequence_length = infer_sequence_length(model_path)
         if input_ids_length > sequence_length:
             raise ValueError(
                 f"input_ids_length: {input_ids_length} "
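Since the KV cache overrides now run whenever the model has a KV cache, benchmark_model() can also be driven directly from Python. A minimal, hedged sketch of such a call; "model.onnx" is a hypothetical path, and only keyword arguments visible in this diff are used:

    from deepsparse.benchmark.benchmark_model import benchmark_model

    # internal_kv_cache=True mirrors the new CLI default;
    # disable_kv_cache_overrides=True would skip rewriting the ONNX graph.
    results = benchmark_model(
        model_path="model.onnx",  # hypothetical model path
        internal_kv_cache=True,
        disable_kv_cache_overrides=False,
    )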
@@ -474,9 +489,10 @@ def main():
         input_ids_length=args.input_ids_length,
         thread_pinning=args.thread_pinning,
         engine=args.engine,
-        internal_kv_cache=args.internal_kv_cache,
+        internal_kv_cache=not args.no_internal_kv_cache,
         quiet=args.quiet,
         export_path=args.export_path,
+        disable_kv_cache_overrides=args.disable_kv_cache_overrides,
     )

     # Results summary
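With the new defaults, typical invocations of the deepsparse.benchmark entrypoint would look like the following (the model path is hypothetical; the flags come from the diff above):

    deepsparse.benchmark model.onnx
    deepsparse.benchmark model.onnx --no-internal-kv-cache
    deepsparse.benchmark model.onnx --disable-kv-cache-overrides

The first run manages KV cache state inside the compiled model (the new default), the second opts out of the internal cache, and the third benchmarks the model as-is, with no KV cache shape overrides applied.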
22 changes: 22 additions & 0 deletions src/deepsparse/utils/onnx.py
@@ -52,6 +52,7 @@
     "truncate_onnx_embedding_model",
     "overwrite_onnx_model_inputs_for_kv_cache_models",
     "default_cached_outputs",
+    "infer_sequence_length",
     "has_model_kv_cache",
     "CACHE_INPUT_PREFIX",
     "CACHE_OUTPUT_PREFIX",
@@ -592,3 +593,24 @@ def has_model_kv_cache(model: Union[str, ModelProto]) -> bool:
     :return True if the model has a KV cache support, False otherwise.
     """
     return bool(any(default_cached_outputs(model)))
+
+
+def infer_sequence_length(model: Union[str, ModelProto]) -> int:
+    """
+    :param model: model
+    :return: inferred sequence length of the model
+    """
+    if not isinstance(model, ModelProto):
+        model = onnx.load(model, load_external_data=False)
+
+    # try to find attention mask dim, default to 0
+    target_input_idx = 0
+    for idx, inp in enumerate(model.graph.input):
+        if inp.name == "attention_mask":
+            target_input_idx = idx
+    try:
+        # return shape of second dim if possible
+        target_input = model.graph.input[target_input_idx]
+        return target_input.type.tensor_type.shape.dim[1].dim_value
+    except Exception:
+        return 0  # unable to infer seq len
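A hedged usage sketch of the new helper: infer_sequence_length() falls back to the first graph input when no attention_mask input is found, and returns 0 when the second dimension cannot be read (for example, a dynamic dim), so callers should be prepared to supply an explicit value. The model path below is hypothetical:

    import onnx
    from deepsparse.utils import has_model_kv_cache, infer_sequence_length

    model_path = "deployment/model.onnx"  # hypothetical path
    model = onnx.load(model_path, load_external_data=False)

    if has_model_kv_cache(model):
        seq_len = infer_sequence_length(model)
        if seq_len == 0:
            seq_len = 512  # could not infer; fall back to an explicit choice
        print(f"benchmarking with sequence_length={seq_len}")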