diff --git a/Makefile b/Makefile index 4383e66f8f..180fd2c55b 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,7 @@ .PHONY: build docs test BUILDDIR := $(PWD) +BUILD_ARGS := # set nightly to build nightly release CHECKDIRS := examples tests src utils notebooks setup.py PYCHECKGLOBS := 'examples/**/*.py' 'scripts/**/*.py' 'src/**/*.py' 'tests/**/*.py' 'utils/**/*.py' setup.py DOCDIR := docs @@ -43,7 +44,7 @@ docs: # creates wheel file build: - python3 setup.py sdist bdist_wheel + python3 setup.py sdist bdist_wheel $(BUILD_ARGS) # clean package clean: diff --git a/README.md b/README.md index f841f4827c..a63c937031 100644 --- a/README.md +++ b/README.md @@ -16,14 +16,13 @@ limitations under the License. # ![icon for DeepSparse](https://raw.githubusercontent.com/neuralmagic/deepsparse/main/docs/source/icon-deepsparse.png) DeepSparse Engine -### CPU inference engine that delivers unprecedented performance for sparse models +### Neural network inference engine that delivers GPU-class performance for sparsified models on CPUs -

GitHub - GitHub + GitHub Documentation @@ -47,20 +46,29 @@ limitations under the License. ## Overview -The DeepSparse Engine is a CPU runtime that delivers unprecedented performance by taking advantage of natural sparsity within neural networks to reduce compute required as well as accelerate memory bound workloads. It is focused on model deployment and scaling machine learning pipelines, fitting seamlessly into your existing deployments as an inference backend. +The DeepSparse Engine is a CPU runtime that delivers GPU-class performance by taking advantage of sparsity within neural networks to reduce compute required as well as accelerate memory bound workloads. +It is focused on model deployment and scaling machine learning pipelines, fitting seamlessly into your existing deployments as an inference backend. -This repository includes package APIs along with examples to quickly get started learning about and actually running sparse models. +This repository includes package APIs along with examples to quickly get started benchmarking and inferencing sparse models. -### Related Products +## Sparsification -- [SparseZoo](https://github.com/neuralmagic/sparsezoo): - Neural network model repository for highly sparse models and optimization recipes -- [SparseML](https://github.com/neuralmagic/sparseml): - Libraries for state-of-the-art deep neural network optimization algorithms, - enabling simple pipelines integration with a few lines of code -- [Sparsify](https://github.com/neuralmagic/sparsify): - Easy-to-use autoML interface to optimize deep neural networks for - better inference performance and a smaller footprint +Sparsification is the process of taking a trained deep learning model and removing redundant information from the overprecise and over-parameterized network resulting in a faster and smaller model. +Techniques for sparsification are all encompassing including everything from inducing sparsity using [pruning](https://neuralmagic.com/blog/pruning-overview/) and [quantization](https://arxiv.org/abs/1609.07061) to enabling naturally occurring sparsity using [activation sparsity](http://proceedings.mlr.press/v119/kurtz20a.html) or [winograd/FFT](https://arxiv.org/abs/1509.09308). +When implemented correctly, these techniques result in significantly more performant and smaller models with limited to no effect on the baseline metrics. +For example, pruning plus quantization can give noticeable improvements in performance while recovering to nearly the same baseline accuracy. + +The Deep Sparse product suite builds on top of sparsification enabling you to easily apply the techniques to your datasets and models using recipe-driven approaches. +Recipes encode the directions for how to sparsify a model into a simple, easily editable format. +- Download a sparsification recipe and sparsified model from the [SparseZoo](https://github.com/neuralmagic/sparsezoo). +- Alternatively, create a recipe for your model using [Sparsify](https://github.com/neuralmagic/sparsify). +- Apply your recipe with only a few lines of code using [SparseML](https://github.com/neuralmagic/sparseml). +- Finally, for GPU-level performance on CPUs, deploy your sparse-quantized model with the [DeepSparse Engine](https://github.com/neuralmagic/deepsparse). + + +**Full Deep Sparse product flow:** + + ## Compatibility @@ -68,21 +76,22 @@ The DeepSparse Engine ingests models in the [ONNX](https://onnx.ai/) format, all ## Quick Tour -To expedite inference and benchmarking on real models, we include the `sparsezoo` package. 
[SparseZoo](https://github.com/neuralmagic/sparsezoo) hosts inference optimized models, trained on repeatable optimization recipes using state-of-the-art techniques from [SparseML](https://github.com/neuralmagic/sparseml). +To expedite inference and benchmarking on real models, we include the `sparsezoo` package. [SparseZoo](https://github.com/neuralmagic/sparsezoo) hosts inference-optimized models, trained on repeatable sparsification recipes using state-of-the-art techniques from [SparseML](https://github.com/neuralmagic/sparseml). ### Quickstart with SparseZoo ONNX Models -**MobileNetV1 Dense** +**ResNet-50 Dense** -Here is how to quickly perform inference with DeepSparse Engine on a pre-trained dense MobileNetV1 from SparseZoo. +Here is how to quickly perform inference with DeepSparse Engine on a pre-trained dense ResNet-50 from SparseZoo. ```python from deepsparse import compile_model from sparsezoo.models import classification + batch_size = 64 # Download model and compile as optimized executable for your machine -model = classification.mobilenet_v1() +model = classification.resnet_50() engine = compile_model(model, batch_size=batch_size) # Fetch sample input and predict output using engine @@ -90,44 +99,68 @@ inputs = model.data_inputs.sample_batch(batch_size=batch_size) outputs, inference_time = engine.timed_run(inputs) ``` -**MobileNetV1 Optimized** +**ResNet-50 Sparsified** When exploring available optimized models, you can use the `Zoo.search_optimized_models` utility to find models that share a base. -Let us try this on the dense MobileNetV1 to see what is available. +Try this on the dense ResNet-50 to see what is available: ```python from sparsezoo import Zoo from sparsezoo.models import classification -print(Zoo.search_optimized_models(classification.mobilenet_v1())) + +model = classification.resnet_50() +print(Zoo.search_optimized_models(model)) ``` Output: ```shell -[Model(stub=cv/classification/mobilenet_v1-1.0/pytorch/sparseml/imagenet/base-none), - Model(stub=cv/classification/mobilenet_v1-1.0/pytorch/sparseml/imagenet/pruned-conservative), - Model(stub=cv/classification/mobilenet_v1-1.0/pytorch/sparseml/imagenet/pruned-moderate), - Model(stub=cv/classification/mobilenet_v1-1.0/pytorch/sparseml/imagenet/pruned_quant-moderate)] +[ + Model(stub=cv/classification/resnet_v1-50/pytorch/sparseml/imagenet/base-none), + Model(stub=cv/classification/resnet_v1-50/pytorch/sparseml/imagenet/pruned-conservative), + Model(stub=cv/classification/resnet_v1-50/pytorch/sparseml/imagenet/pruned-moderate), + Model(stub=cv/classification/resnet_v1-50/pytorch/sparseml/imagenet/pruned_quant-moderate), + Model(stub=cv/classification/resnet_v1-50/pytorch/sparseml/imagenet-augmented/pruned_quant-aggressive) +] ``` -Great. We can see there are two pruned versions targeting FP32, `conservative` at 100% and `moderate` at >= 99% of baseline accuracy. There is also a `pruned_quant` variant targetting INT8. +We can see there are two pruned versions targeting FP32 and two pruned, quantized versions targeting INT8. +The `conservative`, `moderate`, and `aggressive` tags recover to 100%, >=99%, and <99% of baseline accuracy respectively. -Let's say you want to evaluate best performance on FP32 and are okay with a small drop in accuracy, so we can choose `pruned-moderate` over `pruned-conservative`. +For a version of ResNet-50 that recovers close to the baseline and is very performant, choose the pruned_quant-moderate model. 
+This model will run [nearly 7x faster](https://neuralmagic.com/blog/benchmark-resnet50-with-deepsparse) than the baseline model on a compatible CPU (with the VNNI instruction set enabled). +For hardware compatibility, see the Hardware Support section. ```python from deepsparse import compile_model -from sparsezoo.models import classification -batch_size = 64 - -model = classification.mobilenet_v1(optim_name="pruned", optim_category="moderate") -engine = compile_model(model, batch_size=batch_size) +import numpy -inputs = model.data_inputs.sample_batch(batch_size=batch_size) -outputs, inference_time = engine.timed_run(inputs) +batch_size = 64 +sample_inputs = [numpy.random.randn(batch_size, 3, 224, 224).astype(numpy.float32)] + +# run baseline benchmarking +engine_base = compile_model( + model="zoo:cv/classification/resnet_v1-50/pytorch/sparseml/imagenet/base-none", + batch_size=batch_size, +) +benchmarks_base = engine_base.benchmark(sample_inputs) +print(benchmarks_base) + +# run sparse benchmarking +engine_sparse = compile_model( + model="zoo:cv/classification/resnet_v1-50/pytorch/sparseml/imagenet/pruned_quant-moderate", + batch_size=batch_size, +) +if not engine_sparse.cpu_vnni: + print("WARNING: VNNI instructions not detected, quantization speedup not well supported") +benchmarks_sparse = engine_sparse.benchmark(sample_inputs) +print(benchmarks_sparse) + +print(f"Speedup: {benchmarks_sparse.items_per_second / benchmarks_base.items_per_second:.2f}x") ``` -### Quickstart with custom ONNX models +### Quickstart with Custom ONNX Models We accept ONNX files for custom models, too. Simply plug in your model to compare performance with other solutions. diff --git a/docs/source/conf.py b/docs/source/conf.py index 7ae82eb6cc..ad5a4bb2d6 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -84,7 +84,7 @@ # a list of builtin themes. # html_theme = "sphinx_rtd_theme" -html_logo = "icon-engine.png" +html_logo = "icon-deepsparse.png" html_theme_options = { 'analytics_id': 'UA-128364174-1', # Provided by Google in your dashboard diff --git a/docs/source/icon-engine.png b/docs/source/icon-deepsparse.png similarity index 100% rename from docs/source/icon-engine.png rename to docs/source/icon-deepsparse.png diff --git a/docs/source/index.rst b/docs/source/index.rst index d140b6f814..ef78b48a2a 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -17,13 +17,16 @@ DeepSparse |version| ==================== -CPU inference engine that delivers unprecedented performance for sparse models. +Neural network inference engine that delivers GPU-class performance for sparsified models on CPUs .. raw:: html

- - GitHub + + GitHub + + + GitHub Documentation @@ -48,54 +51,59 @@ CPU inference engine that delivers unprecedented performance for sparse models. Overview ======== -The DeepSparse Engine is a CPU runtime that delivers unprecedented performance by taking advantage of -natural sparsity within neural networks to reduce compute required as well as accelerate memory bound workloads. -It is focused on model deployment and scaling machine learning pipelines, -fitting seamlessly into your existing deployments as an inference backend. +The DeepSparse Engine is a CPU runtime that delivers GPU-class performance by taking advantage of sparsity within neural networks to reduce compute required as well as accelerate memory bound workloads. +It is focused on model deployment and scaling machine learning pipelines, fitting seamlessly into your existing deployments as an inference backend. + +`This repository `_ includes package APIs along with examples to quickly get started benchmarking and inferencing sparse models. + +Sparsification +============== + +Sparsification is the process of taking a trained deep learning model and removing redundant information from the overprecise and over-parameterized network resulting in a faster and smaller model. +Techniques for sparsification are all encompassing including everything from inducing sparsity using `pruning `_ and `quantization `_ to enabling naturally occurring sparsity using `activation sparsity `_ or `winograd/FFT `_. +When implemented correctly, these techniques result in significantly more performant and smaller models with limited to no effect on the baseline metrics. +For example, pruning plus quantization can give noticeable improvements in performance while recovering to nearly the same baseline accuracy. + +The Deep Sparse product suite builds on top of sparsification enabling you to easily apply the techniques to your datasets and models using recipe-driven approaches. +Recipes encode the directions for how to sparsify a model into a simple, easily editable format. +- Download a sparsification recipe and sparsified model from the `SparseZoo `_. +- Alternatively, create a recipe for your model using `Sparsify `_. +- Apply your recipe with only a few lines of code using `SparseML `_. +- Finally, for GPU-level performance on CPUs, deploy your sparse-quantized model with the `DeepSparse Engine `_. + -`This GitHub repository `_ includes package APIs along with examples to quickly get started learning about and -actually running sparse models. +**Full Deep Sparse product flow:** + + Compatibility ============= -The DeepSparse Engine ingests models in the `ONNX `_ format, -allowing for compatibility with `PyTorch `_, -`TensorFlow `_, `Keras `_, -and `many other frameworks `_ that support it. +The DeepSparse Engine ingests models in the `ONNX `_ format, +allowing for compatibility with `PyTorch `_, +`TensorFlow `_, `Keras `_, +and `many other frameworks `_ that support it. This reduces the extra work of preparing your trained model for inference to just one step of exporting. 
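For example, a PyTorch model can typically be handed to the engine after a one-line ONNX export. A minimal sketch (assuming `torch` and `torchvision` are installed; the ResNet-50, input shape, and opset below are illustrative and not part of this change):

```python
import torch
import torchvision

# load a stock torchvision ResNet-50 and switch it to inference mode
model = torchvision.models.resnet50(pretrained=True).eval()
dummy_input = torch.randn(1, 3, 224, 224)  # example NCHW input

# export to ONNX; the resulting file can then be compiled by the DeepSparse Engine
torch.onnx.export(model, dummy_input, "resnet50.onnx", opset_version=11)
```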
-Related Products -================ - -- `SparseZoo `_: - Neural network model repository for highly sparse models and optimization recipes -- `SparseML `_: - Libraries for state-of-the-art deep neural network optimization algorithms, - enabling simple pipelines integration with a few lines of code -- `Sparsify `_: - Easy-to-use autoML interface to optimize deep neural networks for - better inference performance and a smaller footprint - Resources and Learning More =========================== -- `SparseZoo Documentation `_ -- `SparseML Documentation `_ -- `Sparsify Documentation `_ -- `Neural Magic Blog `_, - `Resources `_, - `Website `_ +- `SparseZoo Documentation `_ +- `SparseML Documentation `_ +- `Sparsify Documentation `_ +- `Neural Magic Blog `_, + `Resources `_, + `Website `_ Release History =============== Official builds are hosted on PyPi -- stable: `deepsparse `_ -- nightly (dev): `deepsparse-nightly `_ +- stable: `deepsparse `_ +- nightly (dev): `deepsparse-nightly `_ Additionally, more information can be found via -`GitHub Releases `_. +`GitHub Releases `_. .. toctree:: :maxdepth: 3 @@ -118,8 +126,9 @@ Additionally, more information can be found via api/deepsparse .. toctree:: - :maxdepth: 2 - :caption: Help and Support + :maxdepth: 3 + :caption: Help Bugs, Feature Requests - Support, General Q&A \ No newline at end of file + Support, General Q&A + Neural Magic Docs diff --git a/docs/source/quicktour.md b/docs/source/quicktour.md index 432c267f70..7bb3c94a1a 100644 --- a/docs/source/quicktour.md +++ b/docs/source/quicktour.md @@ -16,24 +16,22 @@ limitations under the License. ## Quick Tour -To expedite inference and benchmarking on real models, we include the `sparsezoo` package. -[SparseZoo](https://github.com/neuralmagic/sparsezoo) hosts inference optimized models, -trained on repeatable optimization recipes using state-of-the-art techniques from -[SparseML](https://github.com/neuralmagic/sparseml). +To expedite inference and benchmarking on real models, we include the `sparsezoo` package. [SparseZoo](https://github.com/neuralmagic/sparsezoo) hosts inference-optimized models, trained on repeatable sparsification recipes using state-of-the-art techniques from [SparseML](https://github.com/neuralmagic/sparseml). ### Quickstart with SparseZoo ONNX Models -**MobileNetV1 Dense** +**ResNet-50 Dense** -Here is how to quickly perform inference with DeepSparse Engine on a pre-trained dense MobileNetV1 from SparseZoo. +Here is how to quickly perform inference with DeepSparse Engine on a pre-trained dense ResNet-50 from SparseZoo. ```python from deepsparse import compile_model from sparsezoo.models import classification + batch_size = 64 # Download model and compile as optimized executable for your machine -model = classification.mobilenet_v1() +model = classification.resnet_50() engine = compile_model(model, batch_size=batch_size) # Fetch sample input and predict output using engine @@ -41,46 +39,68 @@ inputs = model.data_inputs.sample_batch(batch_size=batch_size) outputs, inference_time = engine.timed_run(inputs) ``` -**MobileNetV1 Optimized** +**ResNet-50 Sparsified** -When exploring available optimized models, you can use the `Zoo.search_optimized_models` -utility to find models that share a base. +When exploring available optimized models, you can use the `Zoo.search_optimized_models` utility to find models that share a base. -Let us try this on the dense MobileNetV1 to see what is available. 
+Try this on the dense ResNet-50 to see what is available: ```python from sparsezoo import Zoo from sparsezoo.models import classification -print(Zoo.search_optimized_models(classification.mobilenet_v1())) + +model = classification.resnet_50() +print(Zoo.search_optimized_models(model)) ``` + Output: -``` -[Model(stub=cv/classification/mobilenet_v1-1.0/pytorch/sparseml/imagenet/base-none), - Model(stub=cv/classification/mobilenet_v1-1.0/pytorch/sparseml/imagenet/pruned-conservative), - Model(stub=cv/classification/mobilenet_v1-1.0/pytorch/sparseml/imagenet/pruned-moderate), - Model(stub=cv/classification/mobilenet_v1-1.0/pytorch/sparseml/imagenet/pruned_quant-moderate)] + +```shell +[ + Model(stub=cv/classification/resnet_v1-50/pytorch/sparseml/imagenet/base-none), + Model(stub=cv/classification/resnet_v1-50/pytorch/sparseml/imagenet/pruned-conservative), + Model(stub=cv/classification/resnet_v1-50/pytorch/sparseml/imagenet/pruned-moderate), + Model(stub=cv/classification/resnet_v1-50/pytorch/sparseml/imagenet/pruned_quant-moderate), + Model(stub=cv/classification/resnet_v1-50/pytorch/sparseml/imagenet-augmented/pruned_quant-aggressive) +] ``` -Great. We can see there are two pruned versions targeting FP32, -`conservative` at 100% and `moderate` at >= 99% of baseline accuracy. -There is also a `pruned_quant` variant targeting INT8. +We can see there are two pruned versions targeting FP32 and two pruned, quantized versions targeting INT8. +The `conservative`, `moderate`, and `aggressive` tags recover to 100%, >=99%, and <99% of baseline accuracy respectively. -Let's say you want to evaluate best performance on FP32 and are okay with a small drop in accuracy, -so we can choose `pruned-moderate` over `pruned-conservative`. +For a version of ResNet-50 that recovers close to the baseline and is very performant, choose the pruned_quant-moderate model. +This model will run [nearly 7x faster](https://neuralmagic.com/blog/benchmark-resnet50-with-deepsparse) than the baseline model on a compatible CPU (with the VNNI instruction set enabled). +For hardware compatibility, see the Hardware Support section. ```python from deepsparse import compile_model -from sparsezoo.models import classification -batch_size = 64 - -model = classification.mobilenet_v1(optim_name="pruned", optim_category="moderate") -engine = compile_model(model, batch_size=batch_size) +import numpy -inputs = model.data_inputs.sample_batch(batch_size=batch_size) -outputs, inference_time = engine.timed_run(inputs) +batch_size = 64 +sample_inputs = [numpy.random.randn(batch_size, 3, 224, 224).astype(numpy.float32)] + +# run baseline benchmarking +engine_base = compile_model( + model="zoo:cv/classification/resnet_v1-50/pytorch/sparseml/imagenet/base-none", + batch_size=batch_size, +) +benchmarks_base = engine_base.benchmark(sample_inputs) +print(benchmarks_base) + +# run sparse benchmarking +engine_sparse = compile_model( + model="zoo:cv/classification/resnet_v1-50/pytorch/sparseml/imagenet/pruned_quant-moderate", + batch_size=batch_size, +) +if not engine_sparse.cpu_vnni: + print("WARNING: VNNI instructions not detected, quantization speedup not well supported") +benchmarks_sparse = engine_sparse.benchmark(sample_inputs) +print(benchmarks_sparse) + +print(f"Speedup: {benchmarks_sparse.items_per_second / benchmarks_base.items_per_second:.2f}x") ``` -### Quickstart with custom ONNX models +### Quickstart with Custom ONNX Models We accept ONNX files for custom models, too. Simply plug in your model to compare performance with other solutions. 
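As a minimal sketch of that custom-model path, using the `compile_model` and `Engine.run` APIs shown in this diff (the ONNX path and input shape are placeholders for your own model):

```python
import numpy
from deepsparse import compile_model

batch_size = 1
# "model.onnx" is a placeholder path to your exported ONNX file
engine = compile_model("model.onnx", batch_size=batch_size)

# inputs are passed as a list of numpy arrays, in the ONNX graph's input order
inputs = [numpy.random.randn(batch_size, 3, 224, 224).astype(numpy.float32)]
outputs = engine.run(inputs)
```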
diff --git a/examples/benchmark/README.md b/examples/benchmark/README.md index ddb56c0e9a..de9f8715d7 100644 --- a/examples/benchmark/README.md +++ b/examples/benchmark/README.md @@ -14,9 +14,9 @@ See the License for the specific language governing permissions and limitations under the License. --> -# Benchmarking and Correctness Examples +# Benchmarking Examples -This directory holds examples for comparing inference on an ONNX model, both for performance and correctness. +This directory holds examples for comparing inference on ONNX models, both for performance and correctness. ## Installation @@ -24,7 +24,16 @@ Install DeepSparse with `pip install deepsparse` and the additional external req ## Execution -### Benchmark +### ResNet-50 Benchmark + +`resnet50_benchmark.py` is a script for benchmarking all of the sparsified ResNet50 V1 models hosted on SparseZoo, on the DeepSparse engine. + +Example command for ResNet50 benchmarks with batch size 128 and 16 cores used: +```bash +python resnet50_benchmark.py --batch_size 128 --num_cores 16 +``` + +### ONNX Benchmark `run_benchmark.py` is a script for benchmarking an ONNX model over random inputs and using both the DeepSparse Engine and ONNXRuntime, comparing results. diff --git a/examples/benchmark/resnet50_benchmark.py b/examples/benchmark/resnet50_benchmark.py new file mode 100644 index 0000000000..35cd6f05d3 --- /dev/null +++ b/examples/benchmark/resnet50_benchmark.py @@ -0,0 +1,161 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Example script for benchmarking sparsified ResNet50 models from +the SparseZoo on the DeepSparse Engine. 
+ +########## +Command help: +usage: resnet50_benchmark.py [-h] [-s BATCH_SIZE] [-j NUM_CORES] [-b NUM_ITERATIONS] [-w NUM_WARMUP_ITERATIONS] + +Benchmark sparsified ResNet50 models from the SparseZoo + +optional arguments: + -h, --help show this help message and exit + -s BATCH_SIZE, --batch_size BATCH_SIZE + The batch size to run the analysis for + -j NUM_CORES, --num_cores NUM_CORES + The number of physical cores to run the analysis on, defaults to all physical cores available on the system + -b NUM_ITERATIONS, --num_iterations NUM_ITERATIONS + The number of times the benchmark will be run + -w NUM_WARMUP_ITERATIONS, --num_warmup_iterations NUM_WARMUP_ITERATIONS + The number of warmup runs that will be executed before the actual benchmarking + +########## +Example command for ResNet50 benchmarks with batch size 128 and 16 cores used: +python resnet50_benchmark.py \ + --batch_size 128 \ + --num_cores 16 +""" + +import argparse + +import numpy + +from deepsparse import benchmark_model, cpu + + +CORES_PER_SOCKET, AVX_TYPE, _ = cpu.cpu_details() + + +def parse_args(): + parser = argparse.ArgumentParser( + description=("Benchmark sparsified ResNet50 models from the SparseZoo") + ) + + parser.add_argument( + "-s", + "--batch_size", + type=int, + default=64, + help="The batch size to run the analysis for", + ) + parser.add_argument( + "-j", + "--num_cores", + type=int, + default=CORES_PER_SOCKET, + help=( + "The number of physical cores to run the analysis on, " + "defaults to all physical cores available on the system" + ), + ) + parser.add_argument( + "-b", + "--num_iterations", + help="The number of times the benchmark will be run", + type=int, + default=50, + ) + parser.add_argument( + "-w", + "--num_warmup_iterations", + help=( + "The number of warmup runs that will be executed before the actual" + " benchmarking" + ), + type=int, + default=10, + ) + + return parser.parse_args() + + +def main(): + args = parse_args() + batch_size = args.batch_size + num_cores = args.num_cores + num_iterations = args.num_iterations + num_warmup_iterations = args.num_warmup_iterations + + sample_inputs = [numpy.random.randn(batch_size, 3, 224, 224).astype(numpy.float32)] + + print( + f"Starting DeepSparse benchmarks using batch size {batch_size} and {num_cores} cores" + ) + + results = benchmark_model( + "zoo:cv/classification/resnet_v1-50/pytorch/sparseml/imagenet/base-none", + sample_inputs, + batch_size=batch_size, + num_cores=num_cores, + num_iterations=num_iterations, + num_warmup_iterations=num_warmup_iterations, + ) + print(f"ResNet-50 v1 Dense FP32 {results}") + + results = benchmark_model( + "zoo:cv/classification/resnet_v1-50/pytorch/sparseml/imagenet/pruned-conservative", + sample_inputs, + batch_size=batch_size, + num_cores=num_cores, + num_iterations=num_iterations, + num_warmup_iterations=num_warmup_iterations, + ) + print(f"ResNet-50 v1 Pruned Conservative FP32 {results}") + + results = benchmark_model( + "zoo:cv/classification/resnet_v1-50/pytorch/sparseml/imagenet/pruned-moderate", + sample_inputs, + batch_size=batch_size, + num_cores=num_cores, + num_iterations=num_iterations, + num_warmup_iterations=num_warmup_iterations, + ) + print(f"ResNet-50 v1 Pruned Moderate FP32 {results}") + + results = benchmark_model( + "zoo:cv/classification/resnet_v1-50/pytorch/sparseml/imagenet/pruned_quant-moderate", + sample_inputs, + batch_size=batch_size, + num_cores=num_cores, + num_iterations=num_iterations, + num_warmup_iterations=num_warmup_iterations, + ) + print(f"ResNet-50 v1 Pruned Moderate INT8 
{results}") + + results = benchmark_model( + "zoo:cv/classification/resnet_v1-50/pytorch/sparseml/imagenet-augmented/pruned_quant-aggressive", + sample_inputs, + batch_size=batch_size, + num_cores=num_cores, + num_iterations=num_iterations, + num_warmup_iterations=num_warmup_iterations, + ) + print(f"ResNet-50 v1 Pruned Aggressive INT8 {results}") + + +if __name__ == "__main__": + main() diff --git a/notebooks/classification.ipynb b/notebooks/classification.ipynb index 7b090b4926..fce1d9324e 100644 --- a/notebooks/classification.ipynb +++ b/notebooks/classification.ipynb @@ -63,11 +63,11 @@ "source": [ "## Gathering the Model and Data\n", "\n", - "By default, you will download a MobileNetV1 model trained on the ImageNet dataset.\n", + "By default, you will download a sparsified ResNet-50 model trained on the ImageNet dataset.\n", "The model's pretrained weights and exported ONNX file are downloaded from the SparseZoo model repo.\n", "The sample batch of data is downloaded from SparseZoo as well.\n", "\n", - "If you want to try different architectures replace `mobilenet_v1()` with your choice, for example: `resnet50()` or `efficientnet_b0()`.\n", + "If you want to try different architectures replace `resnet50()` with your choice, for example: `mobilenet_v1()` or `efficientnet_b0()`.\n", "\n", "You may also want to try different batch sizes to evaluate accuracy and performance for your task." ] @@ -95,7 +95,7 @@ "# Define your model below\n", "# =====================================================\n", "print(\"Downloading model...\")\n", - "model = classification.mobilenet_v1()\n", + "model = classification.resnet_50(optim_name=\"pruned_quant\", optim_category=\"moderate\")\n", "\n", "# Gather sample batch of data for inference and visualization\n", "batch = model.sample_batch(batch_size=batch_size)\n", @@ -276,9 +276,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.9" + "version": "3.6.8" } }, "nbformat": 4, "nbformat_minor": 4 -} +} \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index 80621586b9..0dbc6fe79a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -5,7 +5,7 @@ ensure_newline_before_comments = True force_grid_wrap = 0 include_trailing_comma = True known_first_party = deepsparse,sparsezoo -known_third_party = numpy,onnx,requests,onnxruntime,flask,flask_cors +known_third_party = numpy,onnx,requests,onnxruntime,flask,flask_cors,tqdm sections = FUTURE,STDLIB,THIRDPARTY,FIRSTPARTY,LOCALFOLDER line_length = 88 diff --git a/setup.py b/setup.py index 20f6bc56bc..a8a1c88849 100644 --- a/setup.py +++ b/setup.py @@ -13,6 +13,8 @@ # limitations under the License. import os +import sys +from datetime import date from distutils import log from fnmatch import fnmatch from typing import Dict, List, Tuple @@ -21,11 +23,28 @@ from setuptools.command.install import install +_PACKAGE_NAME = "deepsparse" +_VERSION = "0.1.1" +_VERSION_MAJOR, _VERSION_MINOR, _VERSION_BUG = _VERSION.split(".") +_VERSION_MAJOR_MINOR = f"{_VERSION_MAJOR}.{_VERSION_MINOR}" +_NIGHTLY = "nightly" in sys.argv + +if _NIGHTLY: + _PACKAGE_NAME += "-nightly" + _VERSION += "." 
+ date.today().strftime("%Y%m%d") + # remove nightly param so it does not break bdist_wheel + sys.argv.remove("nightly") + + # File regexes for binaries to include in package_data binary_regexes = ["*/*.so", "*/*.so.*", "*.bin", "*/*.bin"] -_deps = ["numpy>=1.16.3", "onnx>=1.5.0,<1.8.0", "requests>=2.0.0", "sparsezoo>=0.1.0"] +_deps = ["numpy>=1.16.3", "onnx>=1.5.0,<1.8.0", "requests>=2.0.0", "tqdm>=4.0.0"] + +_nm_deps = [ + f"{'sparsezoo-nightly' if _NIGHTLY else 'sparsezoo'}~={_VERSION_MAJOR_MINOR}" +] _dev_deps = [ "black>=20.8b1", @@ -75,7 +94,7 @@ def _setup_package_data() -> Dict: def _setup_install_requires() -> List: - return _deps + return _nm_deps + _deps def _setup_extras() -> Dict: @@ -91,11 +110,14 @@ def _setup_long_description() -> Tuple[str, str]: setup( - name="deepsparse", - version="0.1.0", + name=_PACKAGE_NAME, + version=_VERSION, author="Neuralmagic, Inc.", author_email="support@neuralmagic.com", - description="CPU runtime that delivers unprecedented performance for sparse models", + description=( + "Neural network inference engine that delivers GPU-class performance " + "for sparsified models on CPUs" + ), long_description=_setup_long_description()[0], long_description_content_type=_setup_long_description()[1], keywords=( diff --git a/src/deepsparse/engine.py b/src/deepsparse/engine.py index a794b2ecd9..1335299fb7 100644 --- a/src/deepsparse/engine.py +++ b/src/deepsparse/engine.py @@ -21,19 +21,25 @@ from typing import Dict, Iterable, List, Optional, Tuple, Union import numpy +from tqdm.auto import tqdm from deepsparse.benchmark import BenchmarkResults try: + from sparsezoo import Zoo from sparsezoo.objects import File, Model -except Exception: + + sparsezoo_import_error = None +except Exception as sparsezoo_err: + Zoo = None Model = object File = object + sparsezoo_import_error = sparsezoo_err try: # flake8: noqa - from deepsparse.cpu import cpu_details + from deepsparse.cpu import cpu_architecture from deepsparse.lib import init_deepsparse_lib from deepsparse.version import * except ImportError: @@ -46,7 +52,11 @@ __all__ = ["Engine", "compile_model", "benchmark_model", "analyze_model"] -CORES_PER_SOCKET, AVX_TYPE, VNNI = cpu_details() +ARCH = cpu_architecture() +CORES_PER_SOCKET = ARCH.available_cores_per_socket +NUM_SOCKETS = ARCH.available_sockets +AVX_TYPE = ARCH.isa +VNNI = ARCH.vnni LIB = init_deepsparse_lib() @@ -55,9 +65,13 @@ def _model_to_path(model: Union[str, Model, File]) -> str: if not model: raise ValueError("model must be a path, sparsezoo.Model, or sparsezoo.File") - if isinstance(model, str): - pass - elif Model is not object and isinstance(model, Model): + if isinstance(model, str) and model.startswith("zoo:"): + # load SparseZoo Model from stub + if sparsezoo_import_error is not None: + raise sparsezoo_import_error + model = Zoo.load_model_from_stub(model) + + if Model is not object and isinstance(model, Model): # default to the main onnx file for the model model = model.onnx_file.downloaded_path() elif File is not object and isinstance(model, File): @@ -90,6 +104,16 @@ def _validate_num_cores(num_cores: Union[None, int]) -> int: return num_cores +def _validate_num_sockets(num_sockets: Union[None, int]) -> int: + if not num_sockets: + num_sockets = NUM_SOCKETS + + if num_sockets < 1: + raise ValueError("num_sockets must be greater than 0") + + return num_sockets + + class Engine(object): """ Create a new DeepSparse Engine that compiles the given onnx file @@ -105,19 +129,29 @@ class Engine(object): | # create an engine for batch size 1 on all 
available cores | engine = Engine("path/to/onnx", batch_size=1, num_cores=None) - :param model: Either a path to the model's onnx file, a sparsezoo Model object, - or a sparsezoo ONNX File object that defines the neural network + :param model: Either a path to the model's onnx file, a SparseZoo model stub + prefixed by 'zoo:', a SparseZoo Model object, or a SparseZoo ONNX File + object that defines the neural network :param batch_size: The batch size of the inputs to be used with the engine :param num_cores: The number of physical cores to run the model on. Pass None or 0 to run on the max number of cores in one socket for the current machine, default None + :param num_sockets: The number of physical sockets to run the model on. + Pass None or 0 to run on the max number of sockets for the + current machine, default None """ - def __init__(self, model: Union[str, Model, File], batch_size: int, num_cores: int): + def __init__( + self, + model: Union[str, Model, File], + batch_size: int, + num_cores: int, + num_sockets: int = None, + ): self._model_path = _model_to_path(model) self._batch_size = _validate_batch_size(batch_size) self._num_cores = _validate_num_cores(num_cores) - self._num_sockets = 1 # only single socket is supported currently + self._num_sockets = _validate_num_sockets(num_sockets) self._cpu_avx_type = AVX_TYPE self._cpu_vnni = VNNI self._eng_net = LIB.deepsparse_engine( @@ -324,6 +358,7 @@ def benchmark( num_warmup_iterations: int = 5, include_inputs: bool = False, include_outputs: bool = False, + show_progress: bool = False, ) -> BenchmarkResults: """ A convenience function for quickly benchmarking the instantiated model @@ -342,6 +377,7 @@ def benchmark( will be added to the results. Default is False :param include_outputs: If True, outputs from forward passes during benchmarking will be added to the results. Default is False + :param show_progress: If True, will display a progress bar. Default is False :return: the results of benchmarking """ # define data loader @@ -355,6 +391,7 @@ def _infinite_loader(): num_warmup_iterations=num_warmup_iterations, include_inputs=include_inputs, include_outputs=include_outputs, + show_progress=show_progress, ) def benchmark_loader( @@ -364,6 +401,7 @@ def benchmark_loader( num_warmup_iterations: int = 5, include_inputs: bool = False, include_outputs: bool = False, + show_progress: bool = False, ) -> BenchmarkResults: """ A convenience function for quickly benchmarking the instantiated model @@ -382,6 +420,7 @@ def benchmark_loader( will be added to the results. Default is False :param include_outputs: If True, outputs from forward passes during benchmarking will be added to the results. Default is False + :param show_progress: If True, will display a progress bar. 
Default is False :return: the results of benchmarking """ assert num_iterations >= 1 and num_warmup_iterations >= 0, ( @@ -391,13 +430,15 @@ def benchmark_loader( completed_iterations = 0 results = BenchmarkResults() + if show_progress: + progress_bar = tqdm(total=num_iterations) + while completed_iterations < num_warmup_iterations + num_iterations: for batch in loader: # run benchmark start = time.time() out = self.run(batch) end = time.time() - completed_iterations += 1 if completed_iterations >= num_warmup_iterations: # update results if warmup iterations are completed @@ -408,10 +449,17 @@ def benchmark_loader( inputs=batch if include_inputs else None, outputs=out if include_outputs else None, ) + if show_progress: + progress_bar.update(1) + + completed_iterations += 1 if completed_iterations >= num_warmup_iterations + num_iterations: break + if show_progress: + progress_bar.close() + return results def _validate_inputs(self, inp: List[numpy.ndarray]): @@ -445,7 +493,10 @@ def _properties_dict(self) -> Dict: def compile_model( - model: Union[str, Model, File], batch_size: int = 1, num_cores: int = None + model: Union[str, Model, File], + batch_size: int = 1, + num_cores: int = None, + num_sockets: int = None, ) -> Engine: """ Convenience function to compile a model in the DeepSparse Engine @@ -453,15 +504,19 @@ def compile_model( Gives defaults of batch_size == 1 and num_cores == None (will use all physical cores available on a single socket). - :param model: Either a path to the model's onnx file, a sparsezoo Model object, - or a sparsezoo ONNX File object that defines the neural network + :param model: Either a path to the model's onnx file, a SparseZoo model stub + prefixed by 'zoo:', a SparseZoo Model object, or a SparseZoo ONNX File + object that defines the neural network :param batch_size: The batch size of the inputs to be used with the model :param num_cores: The number of physical cores to run the model on. Pass None or 0 to run on the max number of cores in one socket for the current machine, default None + :param num_sockets: The number of physical sockets to run the model on. + Pass None or 0 to run on the max number of sockets for the + current machine, default None :return: The created Engine after compiling the model """ - return Engine(model, batch_size, num_cores) + return Engine(model, batch_size, num_cores, num_sockets) def benchmark_model( @@ -473,6 +528,8 @@ def benchmark_model( num_warmup_iterations: int = 5, include_inputs: bool = False, include_outputs: bool = False, + show_progress: bool = False, + num_sockets: int = None, ) -> BenchmarkResults: """ Convenience function to benchmark a model in the DeepSparse Engine @@ -480,8 +537,9 @@ def benchmark_model( Gives defaults of batch_size == 1 and num_cores == None (will use all physical cores available on a single socket). - :param model: Either a path to the model's onnx file, a sparsezoo Model object, - or a sparsezoo ONNX File object that defines the neural network + :param model: Either a path to the model's onnx file, a SparseZoo model stub + prefixed by 'zoo:', a SparseZoo Model object, or a SparseZoo ONNX File + object that defines the neural network :param batch_size: The batch size of the inputs to be used with the model :param num_cores: The number of physical cores to run the model on. Pass None or 0 to run on the max number of cores @@ -498,12 +556,21 @@ def benchmark_model( will be added to the results. 
Default is False :param include_outputs: If True, outputs from forward passes during benchmarking will be added to the results. Default is False + :param show_progress: If True, will display a progress bar. Default is False + :param num_sockets: The number of physical sockets to run the model on. + Pass None or 0 to run on the max number of sockets for the + current machine, default None :return: the results of benchmarking """ - model = compile_model(model, batch_size, num_cores) + model = compile_model(model, batch_size, num_cores, num_sockets) return model.benchmark( - inp, num_iterations, num_warmup_iterations, include_inputs, include_outputs + inp, + num_iterations, + num_warmup_iterations, + include_inputs, + include_outputs, + show_progress, ) @@ -517,6 +584,7 @@ def analyze_model( optimization_level: int = 1, imposed_as: Optional[float] = None, imposed_ks: Optional[float] = None, + num_sockets: int = None, ) -> dict: """ Function to analyze a model's performance in the DeepSparse Engine. @@ -524,9 +592,9 @@ def analyze_model( Gives defaults of batch_size == 1 and num_cores == None (will use all physical cores available on a single socket). - :param model: Either a path to the model's onnx file, a sparsezoo Model object, - or a sparsezoo ONNX File object that defines the neural network - graph definition to analyze + :param model: Either a path to the model's onnx file, a SparseZoo model stub + prefixed by 'zoo:', a SparseZoo Model object, or a SparseZoo ONNX File + object that defines the neural network graph definition to analyze :param inp: The list of inputs to pass to the engine for analyzing inference. The expected order is the inputs order as defined in the ONNX graph. :param batch_size: The batch size of the inputs to be used with the model @@ -547,12 +615,15 @@ def analyze_model( Will force all prunable layers in the graph to have weights with this desired sparsity level (percentage of 0's in the tensor). Beneficial for seeing how pruning affects the performance of the model. + :param num_sockets: The number of physical sockets to run the model on. + Pass None or 0 to run on the max number of sockets for the + current machine, default None :return: the analysis structure containing the performance details of each layer """ model = _model_to_path(model) num_cores = _validate_num_cores(num_cores) batch_size = _validate_batch_size(batch_size) - num_sockets = 1 + num_sockets = _validate_num_sockets(num_sockets) eng_net = LIB.deepsparse_engine(model, batch_size, num_cores, num_sockets) return eng_net.benchmark(
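For reference, a minimal sketch exercising the `zoo:` stub loading, `num_sockets`, and `show_progress` options introduced above (the stub and input shape mirror the ResNet-50 examples earlier in this diff; actual speedups depend on your CPU):

```python
import numpy
from deepsparse import benchmark_model

batch_size = 16
sample_inputs = [numpy.random.randn(batch_size, 3, 224, 224).astype(numpy.float32)]

results = benchmark_model(
    "zoo:cv/classification/resnet_v1-50/pytorch/sparseml/imagenet/pruned_quant-moderate",
    sample_inputs,
    batch_size=batch_size,
    num_cores=None,      # None/0 -> all physical cores on one socket
    num_sockets=None,    # None/0 -> all available sockets (new in this change)
    show_progress=True,  # tqdm-backed progress bar (new in this change)
)
print(results)
```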