Merge branch 'main' into v2/logger
horheynm authored Feb 15, 2024
2 parents 067a38a + c4a7b68 commit 590251f
Showing 32 changed files with 1,125 additions and 596 deletions.
3 changes: 2 additions & 1 deletion examples/benchmark/resnet50_benchmark.py
@@ -47,7 +47,8 @@

import numpy

from deepsparse import benchmark_model, cpu
from deepsparse import cpu
from deepsparse.engine import benchmark_model


CORES_PER_SOCKET, AVX_TYPE, VNNI = cpu.cpu_details()
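The example now imports the engine-level `benchmark_model` from `deepsparse.engine` instead of the package root, which this commit repoints at `deepsparse.benchmark.benchmark_model` (see `src/deepsparse/__init__.py` below). A minimal sketch of the updated pattern, assuming `benchmark_model` accepts a model path plus a list of input arrays as in the original script; the ONNX path is a placeholder:

import numpy

from deepsparse import cpu
from deepsparse.engine import benchmark_model

CORES_PER_SOCKET, AVX_TYPE, VNNI = cpu.cpu_details()

# Placeholder model path; a single random batch stands in for real inputs.
inputs = [numpy.random.rand(1, 3, 224, 224).astype(numpy.float32)]
results = benchmark_model("resnet50.onnx", inputs, batch_size=1)
print(results)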
3 changes: 2 additions & 1 deletion setup.py
@@ -149,6 +149,7 @@ def _parse_requirements_file(file_path):
"datasets<2.16",
"accelerate<0.26",
"seqeval",
"evaluate",
]
_sentence_transformers_integration_deps = ["optimum-deepsparse"] + _torch_deps

@@ -308,7 +309,7 @@ def _setup_entry_points() -> Dict:
f"deepsparse.image_classification.eval={ic_eval}",
"deepsparse.license=deepsparse.license:main",
"deepsparse.validate_license=deepsparse.license:validate_license_cli",
"deepsparse.eval=deepsparse.evaluation.cli:main",
"deepsparse.evaluate=deepsparse.evaluation.cli:main",
]
}

4 changes: 4 additions & 0 deletions src/deepsparse/__init__.py
@@ -37,5 +37,9 @@
from .version import __version__, is_release
from .analytics import deepsparse_analytics as _analytics
from .subgraph_execute import *
from .analyze import analyze
from .evaluation.evaluator import evaluate
from .benchmark.benchmark_model import benchmark_model
from .benchmark.benchmark_pipeline import benchmark_pipeline

_analytics.send_event("python__init")
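With these re-exports, the analyze, evaluate, and benchmark entry points become importable from the package root. Note that the root-level `benchmark_model` now comes from `deepsparse.benchmark.benchmark_model`, which is why the ResNet-50 example above pins the engine-level helper explicitly. A small sketch of the new import surface; the model path is a placeholder:

from deepsparse import analyze, benchmark_model, benchmark_pipeline, evaluate

# analyze() is the new helper added in src/deepsparse/analyze.py and
# re-exported here; "model.onnx" stands in for a real model or stub.
analysis = analyze("model.onnx")
analysis.yaml(file_path="analysis.yaml")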
56 changes: 37 additions & 19 deletions src/deepsparse/analyze.py
@@ -31,7 +31,11 @@
ModelAnalysis,
NodeInferenceResult,
)
from sparsezoo.analyze.cli import analyze_options, analyze_performance_options
from sparsezoo.analyze.cli import (
DEEPSPARSE_ENGINE,
analyze_options,
analyze_performance_options,
)


_LOGGER = logging.getLogger(__name__)
@@ -74,21 +78,11 @@ def main(
)

_LOGGER.info("Starting Analysis ...")
analysis = ModelAnalysis.create(model_path)
_LOGGER.info("Analysis complete, collating results...")
scenario = BenchmarkScenario(
batch_size=batch_size_throughput,
num_cores=None,
engine=benchmark_engine,
)
performance_summary = run_benchmark_and_analysis(
onnx_model=model_to_path(model_path),
scenario=scenario,
)
analysis = analyze(model_path, batch_size_throughput, benchmark_engine)

by_types: bool = convert_to_bool(by_types)
by_layers: bool = convert_to_bool(by_layers)

analysis.benchmark_results = [performance_summary]
summary = analysis.summary(
by_types=by_types,
by_layers=by_layers,
@@ -103,13 +97,9 @@

print("Comparison Analysis:")
for model_to_compare in compare:
compare_model_analysis = ModelAnalysis.create(model_to_compare)
_LOGGER.info(f"Running Performance Analysis on {model_to_compare}")
performance_summary = run_benchmark_and_analysis(
onnx_model=model_to_path(model_to_compare),
scenario=scenario,
compare_model_analysis = analyze(
model_to_compare, batch_size_throughput, benchmark_engine
)
compare_model_analysis.benchmark_results = [performance_summary]
summary_comparison_model = compare_model_analysis.summary(
by_types=by_types,
by_layers=by_layers,
@@ -124,6 +114,34 @@
analysis.yaml(file_path=save)


def analyze(
model_path,
batch_size_throughput: int = 1,
benchmark_engine: str = DEEPSPARSE_ENGINE,
) -> ModelAnalysis:
"""
:param model_path: Local filepath to an ONNX model, or a SparseZoo stub
:param batch_size_throughput: Batch size for throughput benchmark
:param benchmark_engine: Benchmark engine to use, can be 'deepsparse' or
'onnxruntime', defaults to 'deepsparse'
:return: A `ModelAnalysis` object encapsulating the results of the analysis
"""
analysis = ModelAnalysis.create(model_path)
_LOGGER.info("Analysis complete, collating results...")
scenario = BenchmarkScenario(
batch_size=batch_size_throughput,
num_cores=None,
engine=benchmark_engine,
)
performance_summary = run_benchmark_and_analysis(
onnx_model=model_to_path(model_path),
scenario=scenario,
)

analysis.benchmark_results = [performance_summary]
return analysis


def run_benchmark_and_analysis(
onnx_model: str,
scenario: BenchmarkScenario,
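The benchmark-and-collate logic that previously lived inline in `main` is now factored into the reusable `analyze()` helper, which `main` calls both for the primary model and for each model passed for comparison. A sketch of programmatic use; the model path is a placeholder and the summary call assumes the remaining keyword arguments keep their defaults:

from deepsparse.analyze import analyze

# Placeholder ONNX path or SparseZoo stub; batch size and engine mirror the
# helper's defaults.
analysis = analyze(
    "model.onnx",
    batch_size_throughput=1,
    benchmark_engine="deepsparse",
)

# The returned ModelAnalysis already carries benchmark_results, so it can be
# summarized or serialized the same way main() does.
summary = analysis.summary(by_types=False, by_layers=False)
analysis.yaml(file_path="analysis.yaml")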
5 changes: 5 additions & 0 deletions src/deepsparse/benchmark/benchmark_model.py
@@ -411,6 +411,11 @@ def benchmark_model(
if not disable_kv_cache_overrides:
if not sequence_length:
sequence_length = infer_sequence_length(model_path)
if not sequence_length:
raise ValueError(
"Unable to infer sequence length from model. "
"Specify it manually through `sequence_length` argument."
)
if input_ids_length > sequence_length:
raise ValueError(
f"input_ids_length: {input_ids_length} "
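The new check turns a silent failure into an explicit error when the sequence length cannot be read from a text-generation model. A hedged sketch of how a caller sidesteps it by passing the value explicitly; the model path is a placeholder and the keyword names follow the arguments visible in this hunk:

from deepsparse.benchmark.benchmark_model import benchmark_model

# Placeholder model path; sequence_length is given explicitly so the new
# ValueError is never raised when it cannot be inferred from the ONNX graph.
results = benchmark_model(
    model_path="text_generation_model.onnx",
    sequence_length=512,
    input_ids_length=1,
)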
38 changes: 14 additions & 24 deletions src/deepsparse/evaluation/cli.py
@@ -20,7 +20,8 @@
Module for evaluating models on the various evaluation integrations
OPTIONS:
--target TARGET A path to a remote or local directory containing ONNX/torch model
MODEL_PATH
A path to an ONNX model, local directory containing ONNX model
(including all the auxiliary files) or a SparseZoo stub
-d DATASET, --dataset DATASET
The dataset to evaluate on. The user may pass multiple datasets
Expand All @@ -30,9 +31,7 @@
integration name that is registered in the evaluation registry
-e ENGINE_TYPE, --engine_type ENGINE_TYPE
Inference engine to use for the evaluation. The default
is the DeepSparse engine. If the evaluation should be run
without initializing a pipeline (e.g. for the evaluation
of a torch model), the engine type should be set to None
is the DeepSparse engine.
-s SAVE_PATH, --save_path SAVE_PATH
The path to save the evaluation results.
By default the results will be saved in the
@@ -73,7 +72,7 @@

from deepsparse.evaluation.evaluator import evaluate
from deepsparse.evaluation.results import Result, save_result
from deepsparse.evaluation.utils import args_to_dict, get_save_path
from deepsparse.evaluation.utils import get_save_path, parse_kwarg_tuples
from deepsparse.operators.engine_operator import (
DEEPSPARSE_ENGINE,
ORT_ENGINE,
@@ -89,12 +88,10 @@
ignore_unknown_options=True,
)
)
@click.option(
"--target",
@click.argument(
"model_path",
type=click.Path(dir_okay=True, file_okay=True),
required=True,
help="A path to a remote or local directory containing ONNX/torch model "
"(including all the auxiliary files) or a SparseZoo stub",
)
@click.option(
"-d",
@@ -118,9 +115,7 @@
type=click.Choice([DEEPSPARSE_ENGINE, ORT_ENGINE, TORCHSCRIPT_ENGINE]),
default=DEEPSPARSE_ENGINE,
help="The engine to use for the evaluation. The default is the "
"DeepSparse engine. If the evaluation should be run without "
"initializing a pipeline (e.g. for the evaluation of a torch "
"model), the engine type should be set to None",
"DeepSparse engine. ",
)
@click.option(
"-s",
@@ -167,7 +162,7 @@
)
@click.argument("integration_args", nargs=-1, type=click.UNPROCESSED)
def main(
target,
model_path,
dataset,
integration,
engine_type,
@@ -181,16 +176,11 @@ def main(
# join datasets to a list if multiple datasets are passed
datasets = list(dataset) if not isinstance(dataset, str) else dataset
# format kwargs to a dict
integration_args = args_to_dict(integration_args)
integration_args = parse_kwarg_tuples(integration_args)

_LOGGER.info(f"Target to evaluate: {target}")
if engine_type:
_LOGGER.info(f"A pipeline with the engine type: {engine_type} will be created")
else:
_LOGGER.info(
"No engine type specified. The target "
"will be evaluated using the native framework"
)
_LOGGER.info(
f"Creating {engine_type} pipeline to evaluate from model path: {model_path}"
)

_LOGGER.info(
f"Datasets to evaluate on: {datasets}\n"
@@ -201,7 +191,7 @@
)

result: Result = evaluate(
target=target,
model=model_path,
datasets=datasets,
integration=integration,
engine_type=engine_type,
@@ -211,7 +201,7 @@
**integration_args,
)

_LOGGER.info(f"Evaluation done. Results:\n{result}")
_LOGGER.info(f"Evaluation done. Results:\n{result.formatted}")

save_path = get_save_path(
save_path=save_path,
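The command now takes the model as a positional MODEL_PATH argument in place of the removed `--target` option, and the console script is registered as `deepsparse.evaluate` (see the `setup.py` change above). A hedged invocation sketch using click's test runner; the model path and dataset name are placeholders:

from click.testing import CliRunner

from deepsparse.evaluation.cli import main

# Placeholder model path and dataset; MODEL_PATH is now positional and the
# --dataset option follows the docstring shown in this diff.
runner = CliRunner()
result = runner.invoke(main, ["model.onnx", "--dataset", "hellaswag"])
print(result.output)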
35 changes: 22 additions & 13 deletions src/deepsparse/evaluation/evaluator.py
@@ -12,11 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any, List, Optional, Union
from pathlib import Path
from typing import List, Optional, Union

from deepsparse import Pipeline
from deepsparse.evaluation.registry import EvaluationRegistry
from deepsparse.evaluation.results import Result
from deepsparse.evaluation.utils import create_model_from_target
from deepsparse.evaluation.utils import create_pipeline
from deepsparse.operators.engine_operator import (
DEEPSPARSE_ENGINE,
ORT_ENGINE,
@@ -30,32 +32,39 @@


def evaluate(
target: Any,
model: Union[Pipeline, Path, str],
datasets: Union[str, List[str]],
integration: Optional[str] = None,
engine_type: Union[
DEEPSPARSE_ENGINE, ORT_ENGINE, TORCHSCRIPT_ENGINE, None
DEEPSPARSE_ENGINE, ORT_ENGINE, TORCHSCRIPT_ENGINE
] = DEEPSPARSE_ENGINE,
batch_size: int = 1,
splits: Union[List[str], str, None] = None,
metrics: Union[List[str], str, None] = None,
**kwargs,
) -> Result:

# if target is a string, turn it into an appropriate model/pipeline
# otherwise assume it is a model/pipeline
model = (
create_model_from_target(target, engine_type)
if isinstance(target, str)
else target
if isinstance(model, Pipeline):
_LOGGER.info(
"Passed a Pipeline object into evaluate function. This will "
"override the following arguments:"
)
batch_size = model.batch_size
_LOGGER.info(f"batch_size: {batch_size}")
engine_type = engine_type
_LOGGER.info(f"engine_type: {engine_type}")

# if target is a string, turn it into an appropriate pipeline
# otherwise assume it is a pipeline
pipeline = (
create_pipeline(model, engine_type) if isinstance(model, (Path, str)) else model
)

eval_integration = EvaluationRegistry.resolve(model, datasets, integration)
eval_integration = EvaluationRegistry.resolve(pipeline, datasets, integration)

return eval_integration(
model=model,
pipeline=pipeline,
datasets=datasets,
engine_type=engine_type,
batch_size=batch_size,
splits=splits,
metrics=metrics,
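evaluate() now accepts either a ready Pipeline or a path/stub, building the pipeline via `create_pipeline` in the latter case, and hands the pipeline to the resolved integration under the `pipeline` keyword. A sketch of the path-based call; the model path and dataset name are placeholders:

from deepsparse import evaluate
from deepsparse.operators.engine_operator import DEEPSPARSE_ENGINE

# Placeholder model path and dataset name.
result = evaluate(
    model="model.onnx",
    datasets="wikitext",
    engine_type=DEEPSPARSE_ENGINE,
    batch_size=1,
)
print(result.formatted)

# A ready Pipeline can be passed as `model` instead; evaluate() then takes
# batch_size from the pipeline itself, as the new logging above notes.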
6 changes: 3 additions & 3 deletions src/deepsparse/evaluation/integrations/__init__.py
@@ -15,7 +15,7 @@
# flake8: noqa: F401


def try_import_lm_evaluation_harness(raise_error=False):
def try_import_lm_evaluation_harness(raise_error=True):
try:
import lm_eval

@@ -24,11 +24,11 @@ def try_import_lm_evaluation_harness(raise_error=False):
if raise_error:
raise ImportError(
"Unable to import lm_eval. "
"To install run 'pip install "
"git+https://github.com/EleutherAI/lm-evaluation-harness@b018a7d51'"
"To install run 'pip install lm-eval==0.4.0'"
)
return False


if try_import_lm_evaluation_harness(raise_error=False):
from .lm_evaluation_harness import *
from .perplexity import *
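The helper now raises by default and points at the pinned `lm-eval==0.4.0` release, while the module-level guard still calls it with `raise_error=False` so the integrations import stays optional. A small sketch of guarded use:

from deepsparse.evaluation.integrations import try_import_lm_evaluation_harness

# With the new default (raise_error=True) a missing lm-eval raises an
# ImportError suggesting `pip install lm-eval==0.4.0`; passing False keeps
# the boolean-probe behavior used at module import time.
if try_import_lm_evaluation_harness(raise_error=False):
    print("lm-eval is available")
else:
    print("lm-eval is not installed")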