diff --git a/.bazelrc b/.bazelrc index 97915ab3..014e1911 100644 --- a/.bazelrc +++ b/.bazelrc @@ -1,16 +1,28 @@ -# By default, the linux x86 platform is targeted. +# By default, the host platform is targeted with the default c toolchain, +# likely using gcc/libstdc++ on Linux. +# # To select the android ARM64 platform, build with `--config=android_arm64` # -# Both platforms require an external toolchain (NDK or clang/libc++) that +# The Android platform and the clang toolchain requires an external +# toolchain (NDK with clang/libc++) that # needs setup by the user. See README.md for instructions. +# +# If clang/libc++ is installed to /usr/local/ it can be used with linux builds +# instead of the default (gcc) by building with `--config=clang_toolchain`. +# Since this will provide no specific advantages for most users, and the process +# of installing a specific clang toolchain is a bit involved, the documentation +# for clang toolchain setup is in toolchain/ -build --crosstool_top=//toolchain:clang_suite -build --cpu=k8 - +build --cxxopt=-std=gnu++17 +build --linkopt=-lm +build --cxxopt=-Wno-sign-compare # Use the default C++ toolchain to build the tools used during the # build. build --host_crosstool_top=@bazel_tools//tools/cpp:toolchain +build:clang_toolchain --crosstool_top=//toolchain:clang_suite +build:clang_toolchain --cpu=k8 + # Android build:android_arm64 --cpu=arm64-v8a build:android_arm64 --fat_apk_cpu=arm64-v8a diff --git a/BUILD b/BUILD index 81ddde36..977c36ae 100644 --- a/BUILD +++ b/BUILD @@ -1,7 +1,6 @@ # [internal] load cc_fuzz_target.bzl # [internal] load cc_proto_library.bzl # [internal] load android_cc_test:def.bzl -# [internal] load open_source_rules.bzl package(default_visibility = [":__subpackages__"]) @@ -48,7 +47,7 @@ cc_library( name = "layer_wrapper_interface", hdrs = ["layer_wrapper_interface.h"], deps = [ - ":sparse_inference_matrixvector", + "//sparse_matmul", ], ) @@ -58,7 +57,7 @@ cc_library( deps = [ ":dsp_util", ":layer_wrapper_interface", - ":sparse_inference_matrixvector", + "//sparse_matmul", "@com_google_glog//:glog", ], ) @@ -68,7 +67,7 @@ cc_library( hdrs = ["conv1d_layer_wrapper.h"], deps = [ ":layer_wrapper", - ":sparse_inference_matrixvector", + "//sparse_matmul", "@com_google_absl//absl/memory", "@com_google_glog//:glog", ], @@ -79,7 +78,7 @@ cc_library( hdrs = ["dilated_convolutional_layer_wrapper.h"], deps = [ ":layer_wrapper", - ":sparse_inference_matrixvector", + "//sparse_matmul", "@com_google_absl//absl/memory", "@com_google_glog//:glog", ], @@ -90,7 +89,7 @@ cc_library( hdrs = ["transpose_convolutional_layer_wrapper.h"], deps = [ ":layer_wrapper", - ":sparse_inference_matrixvector", + "//sparse_matmul", "@com_google_absl//absl/memory", "@com_google_glog//:glog", ], @@ -114,7 +113,7 @@ cc_library( ":dsp_util", ":layer_wrappers_lib", ":lyra_types", - ":sparse_inference_matrixvector", + "//sparse_matmul", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", @@ -128,14 +127,12 @@ cc_library( hdrs = ["benchmark_decode_lib.h"], deps = [ ":architecture_utils", + ":dsp_util", ":generative_model_interface", ":log_mel_spectrogram_extractor_impl", ":lyra_config", ":wavegru_model_impl", "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/flags:flag", - "@com_google_absl//absl/flags:parse", - "@com_google_absl//absl/flags:usage", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", @@ -246,7 +243,7 @@ cc_library( ":generative_model_interface", ":lyra_types", ":lyra_wavegru", - ":sparse_inference_matrixvector", + "//sparse_matmul", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/time", @@ -275,7 +272,7 @@ cc_library( ":generative_model_interface", ":lyra_types", ":lyra_wavegru", - ":sparse_inference_matrixvector", + "//sparse_matmul", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/time", @@ -656,7 +653,7 @@ cc_library( copts = ["-O3"], deps = [ ":layer_wrapper", - ":sparse_inference_matrixvector", + "//sparse_matmul", ], ) @@ -690,8 +687,8 @@ cc_library( ], data = glob(["wavegru/**"]), deps = [ - ":sparse_inference_matrixvector", ":vector_quantizer_interface", + "//sparse_matmul", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/types:optional", @@ -733,7 +730,7 @@ cc_library( ":layer_wrappers_lib", ":lyra_types", ":project_and_sample", - ":sparse_inference_matrixvector", + "//sparse_matmul", "@com_google_absl//absl/memory", "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", @@ -750,7 +747,7 @@ cc_library( copts = ["-O3"], deps = [ ":lyra_types", - ":sparse_inference_matrixvector", + "//sparse_matmul", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", @@ -827,36 +824,6 @@ cc_test( ], ) -cc_test( - name = "sparse_inference_matrixvector_test", - size = "small", - timeout = "short", - srcs = ["sparse_inference_matrixvector_test.cc"], - deps = [ - ":sparse_inference_matrixvector", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "sparse_inference_matrixvector", - srcs = select({ - ":android_config": ["lib/android_arm64/libsparse_inference.so"], - "//conditions:default": ["lib/linux_x86_64/libsparse_inference.so"], - }), - hdrs = ["sparse_inference_matrixvector.h"], - defines = [ - "ACCURATE_TRANSCENDENTAL_APPROX", - "FAST_SAMPLING", - "FAST_TRANSCENDENTALS", - "SIGMOID_AS_TANH", - ], - deps = [ - "@com_google_absl//absl/status", - "@com_google_glog//:glog", - ], -) - cc_binary( name = "encoder_main", srcs = [ @@ -926,7 +893,7 @@ cc_test( ":exported_layers_test", ":lyra_config", ":lyra_wavegru", - ":sparse_inference_matrixvector", + "//sparse_matmul", "@com_google_absl//absl/strings:str_format", "@com_google_googletest//:gtest_main", "@gulrak_filesystem//:filesystem", @@ -945,7 +912,7 @@ cc_test( deps = [ ":lyra_config", ":lyra_wavegru", - ":sparse_inference_matrixvector", + "//sparse_matmul", "@com_google_absl//absl/strings:str_format", "@com_google_googletest//:gtest_main", "@gulrak_filesystem//:filesystem", @@ -964,7 +931,7 @@ cc_test( deps = [ ":lyra_config", ":lyra_wavegru", - ":sparse_inference_matrixvector", + "//sparse_matmul", "@com_google_absl//absl/strings:str_format", "@com_google_googletest//:gtest_main", "@gulrak_filesystem//:filesystem", @@ -998,7 +965,7 @@ cc_test( ":exported_layers_test", ":lyra_types", ":project_and_sample", - ":sparse_inference_matrixvector", + "//sparse_matmul", "@com_google_absl//absl/strings:str_format", "@com_google_googletest//:gtest_main", "@gulrak_filesystem//:filesystem", @@ -1095,7 +1062,7 @@ cc_library( deps = [ ":layer_wrappers_lib", ":lyra_types", - ":sparse_inference_matrixvector", + "//sparse_matmul", "@com_google_absl//absl/random", "@com_google_googletest//:gtest", "@gulrak_filesystem//:filesystem", @@ -1341,7 +1308,7 @@ cc_test( ":exported_layers_test", ":lyra_config", ":lyra_types", - ":sparse_inference_matrixvector", + "//sparse_matmul", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", "@gulrak_filesystem//:filesystem", @@ -1357,7 +1324,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":layer_wrappers_lib", - ":sparse_inference_matrixvector", + "//sparse_matmul", "@com_google_absl//absl/memory", "@com_google_googletest//:gtest", ], @@ -1381,7 +1348,7 @@ cc_test( ":conv1d_layer_wrapper", ":layer_wrapper", ":layer_wrapper_test_common", - ":sparse_inference_matrixvector", + "//sparse_matmul", "@com_google_googletest//:gtest_main", "@gulrak_filesystem//:filesystem", ], @@ -1405,7 +1372,7 @@ cc_test( ":dilated_convolutional_layer_wrapper", ":layer_wrapper", ":layer_wrapper_test_common", - ":sparse_inference_matrixvector", + "//sparse_matmul", "@com_google_googletest//:gtest_main", "@gulrak_filesystem//:filesystem", ], @@ -1428,8 +1395,8 @@ cc_test( deps = [ ":layer_wrapper", ":layer_wrapper_test_common", - ":sparse_inference_matrixvector", ":transpose_convolutional_layer_wrapper", + "//sparse_matmul", "@com_google_googletest//:gtest_main", "@gulrak_filesystem//:filesystem", ], @@ -1483,7 +1450,7 @@ cc_library( ], hdrs = ["dsp_util.h"], deps = [ - ":sparse_inference_matrixvector", + "//sparse_matmul", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", "@com_google_audio_dsp//audio/dsp:signal_vector_util", diff --git a/README.md b/README.md index 06e87043..2ada9af9 100644 --- a/README.md +++ b/README.md @@ -30,7 +30,7 @@ parallel multiple signals in different frequency ranges that it later combines into a single output signal at the desired sample rate. This trick, plus 64-bit ARM optimizations, enables Lyra to not only run on cloud servers, but also on-device on mid-range phones, such as Pixel phones, in real time (with a -processing latency of 90ms). This generative model is then trained on thousands +processing latency of 100ms). This generative model is then trained on thousands of hours of speech data with speakers in over 70 languages and optimized to accurately recreate the input audio. @@ -51,44 +51,6 @@ Lyra can be built from linux using bazel for an arm android target, or a linux target. The android target is optimized for realtime performance. The linux target is typically used for development and debugging. -You will also need to install some tools (which may already be on your system). -You can install them with: - -```shell -sudo apt update -sudo apt install ninja-build git cmake clang python -``` - -### Linux requirements - -The instructions below are for Ubuntu and have been verified on 20.04. - -You will need to install a certain version of clang to ensure ABI compatibility. - -```shell -git clone https://github.com/llvm/llvm-project.git -cd llvm-project -git checkout 96ef4f307df2 - -mkdir build_clang -cd build_clang -cmake -G Ninja -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DLLVM_ENABLE_PROJECTS="clang" -DCMAKE_BUILD_TYPE=release ../llvm -ninja -sudo $(which ninja) install - -cd .. -mkdir build_libcxx -cd build_libcxx -cmake -G Ninja -DCMAKE_C_COMPILER=/usr/local/bin/clang -DCMAKE_CXX_COMPILER=/usr/local/bin/clang++ -DLLVM_ENABLE_PROJECTS="libcxx;libcxxabi" -DCMAKE_BUILD_TYPE=release ../llvm -ninja -sudo $(which ninja) install - -sudo ldconfig -``` - -Note: the above will install a particular version of libc++ to /usr/local/lib, -and clang to /usr/local/bin, which the toolchain depends on. - ### Android requirements Building on android requires downloading a specific version of the android NDK @@ -158,6 +120,11 @@ bazel build -c opt :decoder_main bazel-bin/decoder_main --model_path=wavegru --output_dir=$HOME/temp/ --encoded_path=$HOME/temp/16khz_sample_000001.lyra ``` +Note: the default Bazel toolchain is automatically configured and likely uses +gcc/libstdc++ on Linux. This should be satisfactory for most users, but will +differ from the NDK toolchain, which uses clang/libc++. To use a custom clang +toolchain on Linux, see toolchain/README.md and .bazelrc. + ### Building for Android #### Android App @@ -184,8 +151,8 @@ Press "Record from microphone", say a few words (be sure to have your microphone near your mouth), and then press "Encode and decode to speaker". You should hear your voice being played back after being coded with Lyra. -If you press 'Benchmark', you should you should see something like the following -in logcat on a Pixel 4 when running the benchmark: +If you press 'Benchmark', you should see something like the following in logcat +on a Pixel 4 when running the benchmark: ```shell I Starting benchmarkDecode() @@ -209,7 +176,8 @@ with `--copt=-DUSE_FIXED16`, although there may be some loss of quality. To build your own android app, you can either use the cc_library target outputs to create a .so that you can use in your own build system. Or you can use it -with an [`android_binary`](https://docs.bazel.build/versions/master/be/android.html) +with an +[`android_binary`](https://docs.bazel.build/versions/master/be/android.html) rule within bazel to create an .apk file as in this example. There is a tutorial on building for android with Bazel in the @@ -232,13 +200,11 @@ This builds an executable binary that can be run on android 64-bit arm devices a binary through the shell. ```shell -# Push the binary and the data it needs, including the model, .wav, and .so files: +# Push the binary and the data it needs, including the model and .wav files: adb push bazel-bin/encoder_main /data/local/tmp/ adb push bazel-bin/decoder_main /data/local/tmp/ adb push wavegru/ /data/local/tmp/ adb push testdata/ /data/local/tmp/ -adb shell mkdir -p /data/local/tmp/_U_S_S_Csparse_Uinference_Umatrixvector___Ulib_Sandroid_Uarm64 -adb push bazel-bin/_solib_arm64-v8a/_U_S_S_Csparse_Uinference_Umatrixvector___Ulib_Sandroid_Uarm64/libsparse_inference.so /data/local/tmp/_U_S_S_Csparse_Uinference_Umatrixvector___Ulib_Sandroid_Uarm64 adb shell cd /data/local/tmp @@ -325,10 +291,10 @@ class LyraDecoder : public LyraDecoderInterface { Once again, the static `Create` method instantiates a `LyraDecoder` with the desired sample rate in Hertz, number of channels and bitrate, as long as those -parameters are supported. Else it returns a `nullptr`. These parameters don't need -to be the same as the ones in `LyraEncoder`. And once again, the `Create` method -also needs to know where the model weights are stored. It also checks that these -weights exist and are compatible with the current Lyra version. +parameters are supported. Else it returns a `nullptr`. These parameters don't +need to be the same as the ones in `LyraEncoder`. And once again, the `Create` +method also needs to know where the model weights are stored. It also checks +that these weights exist and are compatible with the current Lyra version. Given a `LyraDecoder`, any packet can be decoded by first feeding it into `SetEncodedPacket`, which returns true if the provided span of bytes is a valid @@ -350,19 +316,24 @@ The rest of the `LyraDecoder` methods are just getters for the different predetermined parameters. For an example on how to use `LyraEncoder` and `LyraDecoder` to encode and -decode a stream of audio, please refer to the [integration -test](lyra_integration_test.cc). +decode a stream of audio, please refer to the +[integration test](lyra_integration_test.cc). + +## Sparse Matrix Multiplication Library +Lyra uses a library in the `sparse_matmul` directory that enables fast execution +of sparse Matrix-Vector multiplication ops on mobile and desktop CPU platforms +(ARM and AVX2) to allow for real-time operation on phones. This library was +created by DeepMind for their implementation of WaveRNN with sparsity [[4]](#4), +which gave a huge improvement in complexity over WaveNet. + +A generic kernel is also provided, which enables debugging on non-optimized +platforms. Contributions for other platforms are welcome. ## License Use of this source code is governed by a Apache v2.0 license that can be found in the LICENSE file. -Please note that there is a closed-source kernel used for math operations that -is linked via a shared object called libsparse_inference.so. We provide the -libsparse_inference.so library to be linked, but are unable to provide source -for it. This is the reason that a specific toolchain/compiler is required. - ## Papers 1. Kleijn, W. B., Lim, F. S., Luebs, A., Skoglund, J., Stimberg, F., Wang, Q., & @@ -376,3 +347,7 @@ for it. This is the reason that a specific toolchain/compiler is required. & Yeh, H. (2021). [Generative Speech Coding with Predictive Variance Regularization](https://arxiv.org/pdf/2102.09660). arXiv preprint arXiv:2102.09660. +4. Kalchbrenner, N., Elsen, E., Simonyan, K., Noury, S., + Casagrande, N., Lockhart, E., ... & Kavukcuoglu, K. (2018, July). + [Efficient neural audio synthesis](https://arxiv.org/abs/1802.08435). + In International Conference on Machine Learning (pp. 2410-2419). PMLR. diff --git a/architecture_utils.h b/architecture_utils.h index 355690e8..bde4f88a 100644 --- a/architecture_utils.h +++ b/architecture_utils.h @@ -17,7 +17,7 @@ #ifndef LYRA_CODEC_ARCHITECTURE_UTILS_H_ #define LYRA_CODEC_ARCHITECTURE_UTILS_H_ -// placeholder for get runfiles header. +// Placeholder for get runfiles header. #include "include/ghc/filesystem.hpp" namespace chromemedia { diff --git a/benchmark_decode_lib.cc b/benchmark_decode_lib.cc index 27533293..0648bc78 100644 --- a/benchmark_decode_lib.cc +++ b/benchmark_decode_lib.cc @@ -34,6 +34,7 @@ #include "absl/types/span.h" #include "architecture_utils.h" #include "audio/dsp/signal_vector_util.h" +#include "dsp_util.h" #include "generative_model_interface.h" #include "glog/logging.h" #include "include/ghc/filesystem.hpp" @@ -113,7 +114,8 @@ int benchmark_decode(const int num_cond_vectors, chromemedia::codec::GetNumSamplesPerHop( chromemedia::codec::kInternalSampleRateHz), chromemedia::codec::kNumFeatures, - chromemedia::codec::kNumFramesPerPacket, model_path); + chromemedia::codec::kNumFramesPerPacket, + LogMelSpectrogramExtractorImpl::GetSilenceValue(), model_path); const int num_samples_per_hop = chromemedia::codec::GetNumSamplesPerHop( chromemedia::codec::kInternalSampleRateHz); @@ -124,14 +126,18 @@ int benchmark_decode(const int num_cond_vectors, chromemedia::codec::LogMelSpectrogramExtractorImpl::Create( chromemedia::codec::kInternalSampleRateHz, kNumFeatures, num_samples_per_hop, num_samples_per_frame); - std::uniform_real_distribution distribution( - std::numeric_limits::min(), std::numeric_limits::max()); + // Generate a random signal. + // The characteristics of the signal are not so important, since this is + // testing benchmarking. But it should have some variance since silent + // signals could potentially be handled differently. + std::uniform_real_distribution distribution(-1.0, 1.0); std::default_random_engine generator; std::vector random_audio(num_samples_per_hop); for (int i = 0; i < num_cond_vectors; ++i) { - std::generate(random_audio.begin(), random_audio.end(), - [&]() { return distribution(generator); }); + std::generate(random_audio.begin(), random_audio.end(), [&]() { + return UnitFloatToInt16Scalar(distribution(generator)); + }); auto features_or = feature_extractor->Extract(absl::MakeConstSpan(random_audio)); if (!features_or.has_value()) { diff --git a/causal_convolutional_conditioning.h b/causal_convolutional_conditioning.h index 1d596971..5a7d84bd 100644 --- a/causal_convolutional_conditioning.h +++ b/causal_convolutional_conditioning.h @@ -27,7 +27,7 @@ #include "glog/logging.h" #include "layer_wrappers_lib.h" #include "lyra_types.h" -#include "sparse_inference_matrixvector.h" +#include "sparse_matmul/sparse_matmul.h" namespace chromemedia { namespace codec { @@ -96,7 +96,7 @@ class CausalConvolutionalConditioning { CausalConvolutionalConditioning(int feature_depth, int num_cond_hiddens, int num_hiddens, int num_samples_per_hop, int num_frames_per_packet, int num_threads, - const std::string& path, + float silence_value, const std::string& path, const std::string& prefix) : feature_depth_(feature_depth), num_hiddens_(num_hiddens), @@ -106,8 +106,7 @@ class CausalConvolutionalConditioning { num_threads_(num_threads), path_(path), prefix_(prefix), - num_precomputed_frames_(0), - spin_barrier_(num_threads_) { + num_precomputed_frames_(0) { // Crash ok. CHECK_LE(num_threads_, num_cond_hiddens) << "Number of threads must be <= the number of hidden layers " @@ -121,6 +120,7 @@ class CausalConvolutionalConditioning { CreateLayers(); PrepareOutput(); + WarmUp(silence_value); } ~CausalConvolutionalConditioning() {} @@ -147,7 +147,7 @@ class CausalConvolutionalConditioning { auto f = [this](csrblocksparse::SpinBarrier* barrier, int tid) { ComputeFunction(barrier, tid); }; - LaunchOnThreadsWithBarrier(num_threads_, f); + csrblocksparse::LaunchOnThreadsWithBarrier(num_threads_, f); } int num_samples() const { @@ -349,6 +349,21 @@ class CausalConvolutionalConditioning { conditioning_.FillZero(); } + void WarmUp(float silence_value) { + csrblocksparse::CacheAlignedVector silence_vector( + feature_depth_ * kCondInputNumTimesteps); + silence_vector.FillWith(silence_value); + const csrblocksparse::FatCacheAlignedVector silence_input( + silence_vector, feature_depth_); + const int kNumPaddingFrames = (kConv1DKernel - 1) / 2; + csrblocksparse::SpinBarrier spin_barrier(0); + for (int i = 0; i < kNumPaddingFrames; ++i) { + InsertNewInput(silence_input); + conv1d_layer_->Run(0, &spin_barrier, + dilated_conv_layer_0_->InputViewToUpdate()); + } + } + void InsertNewInput( const csrblocksparse::FatCacheAlignedVector& input) { // This conversion might not always be necessary, will @@ -435,7 +450,6 @@ class CausalConvolutionalConditioning { const std::string prefix_; int num_precomputed_frames_; - csrblocksparse::SpinBarrier spin_barrier_; std::unique_ptr conv1d_layer_; std::unique_ptr dilated_conv_layer_0_; diff --git a/causal_convolutional_conditioning_test.cc b/causal_convolutional_conditioning_test.cc index bce456b4..9f1b1bbf 100644 --- a/causal_convolutional_conditioning_test.cc +++ b/causal_convolutional_conditioning_test.cc @@ -19,7 +19,7 @@ #include #include -// placeholder for get runfiles header. +// Placeholder for get runfiles header. #include "absl/types/span.h" #include "exported_layers_test.h" #include "gmock/gmock.h" @@ -27,7 +27,7 @@ #include "include/ghc/filesystem.hpp" #include "lyra_config.h" #include "lyra_types.h" -#include "sparse_inference_matrixvector.h" +#include "sparse_matmul/sparse_matmul.h" namespace chromemedia { namespace codec { @@ -86,7 +86,7 @@ class CausalConvolutionalConditioningPeer { const std::string& prefix) : conditioning_stack_(feature_depth, num_cond_hiddens, num_hiddens, num_samples_per_hop, num_frames_per_packet, - num_threads, path, prefix) {} + num_threads, 0.0f, path, prefix) {} void Precompute(const csrblocksparse::FatCacheAlignedVector& input, int num_threads) { @@ -241,10 +241,10 @@ TYPED_TEST(CausalConvolutionalConditioningTest, ConditioningType no_multithreading( kFeatures.at(0).size(), kNumCondHiddens, kNumHiddens, kNumSamplesPerHop, - kNumFramesPerPacket, 1, this->testdata_dir_.string(), "lyra"); + kNumFramesPerPacket, 1, 0.0f, this->testdata_dir_.string(), "lyra"); ConditioningType two_threads( kFeatures.at(0).size(), kNumCondHiddens, kNumHiddens, kNumSamplesPerHop, - kNumFramesPerPacket, 2, this->testdata_dir_.string(), "lyra"); + kNumFramesPerPacket, 2, 0.0f, this->testdata_dir_.string(), "lyra"); csrblocksparse::FatCacheAlignedVector input(kFeatures.at(0).size(), 1); for (int i = 0; i < kFeatures.size(); ++i) { std::copy(kFeatures.at(i).begin(), kFeatures.at(i).end(), input.data()); @@ -278,12 +278,12 @@ TYPED_TEST(CausalConvolutionalConditioningTest, {0.0f, 0.0f, 0.0f}}; ConditioningType one_frame_conditioning( kFeatures.at(0).size(), kNumCondHiddens, kNumHiddens, kNumSamplesPerHop, - /*num_frames_per_packet=*/1, kNumThreads, this->testdata_dir_.string(), - "lyra"); + /*num_frames_per_packet=*/1, kNumThreads, 0.0f, + this->testdata_dir_.string(), "lyra"); ConditioningType two_frames_conditioning( kFeatures.at(0).size(), kNumCondHiddens, kNumHiddens, kNumSamplesPerHop, - /*num_frames_per_packet=*/2, kNumThreads, this->testdata_dir_.string(), - "lyra"); + /*num_frames_per_packet=*/2, kNumThreads, 0.0f, + this->testdata_dir_.string(), "lyra"); csrblocksparse::FatCacheAlignedVector input(kFeatures.at(0).size(), 1); std::vector one_frame_output_to_compare; diff --git a/conv1d_layer_wrapper.h b/conv1d_layer_wrapper.h index 429e52c3..d2a2cc65 100644 --- a/conv1d_layer_wrapper.h +++ b/conv1d_layer_wrapper.h @@ -25,7 +25,7 @@ #include "absl/memory/memory.h" #include "glog/logging.h" #include "layer_wrapper.h" -#include "sparse_inference_matrixvector.h" +#include "sparse_matmul/sparse_matmul.h" namespace chromemedia { namespace codec { diff --git a/conv1d_layer_wrapper_test.cc b/conv1d_layer_wrapper_test.cc index 5e86461d..75538dce 100644 --- a/conv1d_layer_wrapper_test.cc +++ b/conv1d_layer_wrapper_test.cc @@ -17,13 +17,13 @@ #include #include -// placeholder for get runfiles header. +// Placeholder for get runfiles header. #include "gmock/gmock.h" #include "gtest/gtest.h" #include "include/ghc/filesystem.hpp" #include "layer_wrapper.h" #include "layer_wrapper_test_common.h" -#include "sparse_inference_matrixvector.h" +#include "sparse_matmul/sparse_matmul.h" namespace chromemedia { namespace codec { diff --git a/decoder_main_lib_test.cc b/decoder_main_lib_test.cc index bbdaecee..54485f4e 100644 --- a/decoder_main_lib_test.cc +++ b/decoder_main_lib_test.cc @@ -20,9 +20,9 @@ #include // NOLINT(build/c++11) #include -// placeholder for get runfiles header. +// Placeholder for get runfiles header. #include "gmock/gmock.h" -// placeholder for testing header. +// Placeholder for testing header. #include "absl/flags/flag.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" diff --git a/dilated_convolutional_layer_wrapper.h b/dilated_convolutional_layer_wrapper.h index 44a661bc..6583f2af 100644 --- a/dilated_convolutional_layer_wrapper.h +++ b/dilated_convolutional_layer_wrapper.h @@ -25,7 +25,7 @@ #include "absl/memory/memory.h" #include "glog/logging.h" #include "layer_wrapper.h" -#include "sparse_inference_matrixvector.h" +#include "sparse_matmul/sparse_matmul.h" namespace chromemedia { namespace codec { diff --git a/dilated_convolutional_layer_wrapper_test.cc b/dilated_convolutional_layer_wrapper_test.cc index 40bca63f..59a3d853 100644 --- a/dilated_convolutional_layer_wrapper_test.cc +++ b/dilated_convolutional_layer_wrapper_test.cc @@ -18,13 +18,13 @@ #include #include -// placeholder for get runfiles header. +// Placeholder for get runfiles header. #include "gmock/gmock.h" #include "gtest/gtest.h" #include "include/ghc/filesystem.hpp" #include "layer_wrapper.h" #include "layer_wrapper_test_common.h" -#include "sparse_inference_matrixvector.h" +#include "sparse_matmul/sparse_matmul.h" namespace chromemedia { namespace codec { diff --git a/dsp_util.cc b/dsp_util.cc index a572571e..74a32d7b 100644 --- a/dsp_util.cc +++ b/dsp_util.cc @@ -50,5 +50,31 @@ int16_t ClipToInt16(float value) { static_cast(std::numeric_limits::max())); } +int16_t UnitFloatToInt16Scalar(float unit_float) { + // First, scale unit_float linearly to int16 ranges. + // The unary negation is used here to scale by the negative min int16_t value, + // which has a greater absolute value than the max. + float int16_range_float = + unit_float * (-std::numeric_limits().min()); + // If unit_float was outside the [-1, 1), clip to the min/max value. + return ClipToInt16(int16_range_float); +} + +std::vector UnitFloatToInt16(absl::Span input) { + std::vector output; + output.reserve(input.size()); + std::transform(input.begin(), input.end(), std::back_inserter(output), + UnitFloatToInt16Scalar); + return output; +} + +std::vector Int16ToUnitFloat(absl::Span input) { + std::vector output(input.size()); + std::transform(input.begin(), input.end(), output.begin(), [](int16_t x) { + return -static_cast(x) / std::numeric_limits().min(); + }); + return output; +} + } // namespace codec } // namespace chromemedia diff --git a/dsp_util.h b/dsp_util.h index 10c42674..c3dd11c3 100644 --- a/dsp_util.h +++ b/dsp_util.h @@ -21,7 +21,7 @@ #include "absl/types/optional.h" #include "absl/types/span.h" -#include "sparse_inference_matrixvector.h" +#include "sparse_matmul/sparse_matmul.h" namespace chromemedia { namespace codec { @@ -33,8 +33,22 @@ absl::optional LogSpectralDistance( const absl::Span second_log_spectrum); // Clip values above max value or below min value for int16_t. +// The quantization scheme uses native c rounding (non-centered, decimal +// truncation) int16_t ClipToInt16(float value); +// Converts from a unit-float to a 16-bit integer. +// If |unit_float| is in the [-1, 1) interval it will scale linearly to the +// int16_t limits. Values outside the interval are clipped to the limits. +// The clipping, rounding, and quantization follows ClipToInt16(). +int16_t UnitFloatToInt16Scalar(float unit_float); + +// Converts from a Span of unit-floats to a vector of 16-bit integers. +std::vector UnitFloatToInt16(absl::Span input); + +// Converts from a Span of 16-bit integers to a vector of unit-floats. +std::vector Int16ToUnitFloat(absl::Span input); + #if defined __aarch64__ // We do not provide fixed16 to fixed32 casting as there is no use case so far. diff --git a/dsp_util_test.cc b/dsp_util_test.cc index 2e0dc026..0748ebaa 100644 --- a/dsp_util_test.cc +++ b/dsp_util_test.cc @@ -40,13 +40,60 @@ TEST(DspUtilTest, LogSpectralDistanceTest) { EXPECT_NEAR(log_spectral_distance_or.value(), 10.0f, 0.0001); } -TEST(DspUtilTest, ClipTest) { - const float kMax = ClipToInt16(10000000); - EXPECT_EQ(kMax, std::numeric_limits::max()); - const float kZero = ClipToInt16(0); +TEST(ClipToInt16Test, ClipsExtremeValues) { + const int16_t kMaxExceeded = ClipToInt16(10000000); + EXPECT_EQ(kMaxExceeded, std::numeric_limits::max()); + const int16_t kMinExceeded = ClipToInt16(-10000000); + EXPECT_EQ(kMinExceeded, std::numeric_limits::min()); +} + +TEST(ClipToInt16Test, TruncatesDecimal) { + const int16_t kJustAboveZero = ClipToInt16(.0001); + EXPECT_EQ(kJustAboveZero, 0); + + const int16_t kJustBelowOne = ClipToInt16(.999); + EXPECT_EQ(kJustBelowOne, 0); + + const int16_t kShouldTruncateNegativeDecimalToZero = ClipToInt16(-.0001); + EXPECT_EQ(kShouldTruncateNegativeDecimalToZero, 0); +} + +TEST(ClipToInt16Test, BoundaryIdentity) { + const int16_t kZero = ClipToInt16(0); EXPECT_EQ(kZero, 0); - const float kMin = ClipToInt16(-10000000); - EXPECT_EQ(kMin, std::numeric_limits::min()); + + const int16_t kMaxBoundary = ClipToInt16(std::numeric_limits::max()); + EXPECT_EQ(kMaxBoundary, std::numeric_limits::max()); + + const int16_t kMinBoundary = ClipToInt16(std::numeric_limits::min()); + EXPECT_EQ(kMinBoundary, std::numeric_limits::min()); +} + +TEST(UnitFloatToInt16ScalarTest, ExtremeValues) { + const int16_t kMaxExceeded = UnitFloatToInt16Scalar(100000.0); + EXPECT_EQ(kMaxExceeded, std::numeric_limits::max()); + + const int16_t kMinExceeded = UnitFloatToInt16Scalar(-100000.0); + EXPECT_EQ(kMinExceeded, std::numeric_limits::min()); +} + +TEST(UnitFloatToInt16ScalarTest, RoundsTowardsZero) { + const int16_t kShouldRoundDownToZero = UnitFloatToInt16Scalar(1e-10); + EXPECT_EQ(kShouldRoundDownToZero, 0); + + const int16_t kShouldRoundNegativeUpToZero = UnitFloatToInt16Scalar(-1e-10); + EXPECT_EQ(kShouldRoundNegativeUpToZero, 0); +} + +TEST(UnitFloatToInt16ScalarTest, BoundariesMapToLimits) { + const int16_t kZero = UnitFloatToInt16Scalar(0.0); + EXPECT_EQ(kZero, 0); + + const int16_t kMaxBoundary = UnitFloatToInt16Scalar(1.0); + EXPECT_EQ(kMaxBoundary, std::numeric_limits::max()); + + const int16_t kMinBoundary = UnitFloatToInt16Scalar(-1.0); + EXPECT_EQ(kMinBoundary, std::numeric_limits::min()); } // Pair of input and output types to be tested for casting and their diff --git a/encoder_main_lib_test.cc b/encoder_main_lib_test.cc index 9d34b505..ef797196 100644 --- a/encoder_main_lib_test.cc +++ b/encoder_main_lib_test.cc @@ -17,9 +17,9 @@ #include #include // NOLINT(build/c++11) -// placeholder for get runfiles header. +// Placeholder for get runfiles header. #include "gmock/gmock.h" -// placeholder for testing header. +// Placeholder for testing header. #include "absl/flags/flag.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" diff --git a/exported_layers_test.h b/exported_layers_test.h index 863aaa61..398b9671 100644 --- a/exported_layers_test.h +++ b/exported_layers_test.h @@ -21,14 +21,14 @@ #include #include -// placeholder for get runfiles header. +// Placeholder for get runfiles header. #include "absl/random/random.h" #include "gmock/gmock.h" #include "gtest/gtest.h" #include "include/ghc/filesystem.hpp" #include "layer_wrappers_lib.h" #include "lyra_types.h" -#include "sparse_inference_matrixvector.h" +#include "sparse_matmul/sparse_matmul.h" namespace chromemedia { namespace codec { diff --git a/layer_wrapper.h b/layer_wrapper.h index 54be74d7..94c9d858 100644 --- a/layer_wrapper.h +++ b/layer_wrapper.h @@ -18,6 +18,7 @@ #define LYRA_CODEC_LAYER_WRAPPER_H_ #include +#include #include #include #include @@ -26,7 +27,7 @@ #include "dsp_util.h" #include "glog/logging.h" #include "layer_wrapper_interface.h" -#include "sparse_inference_matrixvector.h" +#include "sparse_matmul/sparse_matmul.h" namespace chromemedia { namespace codec { @@ -204,6 +205,21 @@ class LayerWrapper : public LayerWrapperInterface +std::ostream& operator<<( + std::ostream& out_stream, + const std::unique_ptr>& layer_ptr) { + return out_stream << layer_ptr.get(); +} + } // namespace codec } // namespace chromemedia diff --git a/layer_wrapper_interface.h b/layer_wrapper_interface.h index 5b89fc53..e8226ef7 100644 --- a/layer_wrapper_interface.h +++ b/layer_wrapper_interface.h @@ -20,7 +20,7 @@ #include #include -#include "sparse_inference_matrixvector.h" +#include "sparse_matmul/sparse_matmul.h" namespace chromemedia { namespace codec { diff --git a/layer_wrapper_test_common.h b/layer_wrapper_test_common.h index 8eb93f73..2c18fc51 100644 --- a/layer_wrapper_test_common.h +++ b/layer_wrapper_test_common.h @@ -26,7 +26,7 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" #include "layer_wrappers_lib.h" -#include "sparse_inference_matrixvector.h" +#include "sparse_matmul/sparse_matmul.h" namespace chromemedia { namespace codec { @@ -140,7 +140,7 @@ void VerifyMultipleThreadsYeldSameResults( auto f = [&](csrblocksparse::SpinBarrier* spin_barrier, int tid) { layer->Run(tid, spin_barrier, output_view); }; - LaunchOnThreadsWithBarrier(params.num_threads, f); + csrblocksparse::LaunchOnThreadsWithBarrier(params.num_threads, f); std::vector saved_output(output_view.data(), output_view.data() + output_view.rows()); diff --git a/lib/android_arm64/libsparse_inference.so b/lib/android_arm64/libsparse_inference.so deleted file mode 100755 index 496089ad..00000000 Binary files a/lib/android_arm64/libsparse_inference.so and /dev/null differ diff --git a/lib/linux_x86_64/libsparse_inference.so b/lib/linux_x86_64/libsparse_inference.so deleted file mode 100755 index 7e3ba9b8..00000000 Binary files a/lib/linux_x86_64/libsparse_inference.so and /dev/null differ diff --git a/lyra_components.cc b/lyra_components.cc index 759cf873..77aebfb9 100644 --- a/lyra_components.cc +++ b/lyra_components.cc @@ -61,8 +61,9 @@ std::unique_ptr CreateQuantizer( std::unique_ptr CreateGenerativeModel( int num_samples_per_hop, int num_output_features, int num_frames_per_packet, const ghc::filesystem::path& model_path) { - return WavegruModelImpl::Create(num_samples_per_hop, num_output_features, - num_frames_per_packet, model_path); + return WavegruModelImpl::Create( + num_samples_per_hop, num_output_features, num_frames_per_packet, + LogMelSpectrogramExtractorImpl::GetSilenceValue(), model_path); } std::unique_ptr CreateFeatureExtractor( diff --git a/lyra_config.cc b/lyra_config.cc index 550660b2..3bb8a4f1 100644 --- a/lyra_config.cc +++ b/lyra_config.cc @@ -31,7 +31,7 @@ const int kVersionMajor = 0; // |identifier| field needs to be set in lyra_config.textproto to match this. const int kVersionMinor = 0; // The micro version is for other things like a release of bugfixes. -const int kVersionMicro = 1; +const int kVersionMicro = 2; const int kNumFeatures = 160; const int kNumExpectedOutputFeatures = 160; @@ -50,6 +50,120 @@ const int kPacketSize = 15; const int kBitrate = kPacketSize * CHAR_BIT * kFrameRate * kNumChannels; +absl::Status AreParamsSupported(int sample_rate_hz, int num_channels, + int bitrate, + const ghc::filesystem::path& model_path) { + constexpr absl::string_view kAssets[] = { + "lyra_16khz_ar_to_gates_bias.raw.gz", + "lyra_16khz_ar_to_gates_mask.raw.gz", + "lyra_16khz_ar_to_gates_weights.raw.gz", + "lyra_16khz_conditioning_stack_0_bias.raw.gz", + "lyra_16khz_conditioning_stack_0_mask.raw.gz", + "lyra_16khz_conditioning_stack_0_weights.raw.gz", + "lyra_16khz_conditioning_stack_1_bias.raw.gz", + "lyra_16khz_conditioning_stack_1_mask.raw.gz", + "lyra_16khz_conditioning_stack_1_weights.raw.gz", + "lyra_16khz_conditioning_stack_2_bias.raw.gz", + "lyra_16khz_conditioning_stack_2_mask.raw.gz", + "lyra_16khz_conditioning_stack_2_weights.raw.gz", + "lyra_16khz_conv1d_bias.raw.gz", + "lyra_16khz_conv1d_mask.raw.gz", + "lyra_16khz_conv1d_weights.raw.gz", + "lyra_16khz_conv_cond_bias.raw.gz", + "lyra_16khz_conv_cond_mask.raw.gz", + "lyra_16khz_conv_cond_weights.raw.gz", + "lyra_16khz_conv_to_gates_bias.raw.gz", + "lyra_16khz_conv_to_gates_mask.raw.gz", + "lyra_16khz_conv_to_gates_weights.raw.gz", + "lyra_16khz_gru_layer_bias.raw.gz", + "lyra_16khz_gru_layer_mask.raw.gz", + "lyra_16khz_gru_layer_weights.raw.gz", + "lyra_16khz_means_bias.raw.gz", + "lyra_16khz_means_mask.raw.gz", + "lyra_16khz_means_weights.raw.gz", + "lyra_16khz_mix_bias.raw.gz", + "lyra_16khz_mix_mask.raw.gz", + "lyra_16khz_mix_weights.raw.gz", + "lyra_16khz_proj_bias.raw.gz", + "lyra_16khz_proj_mask.raw.gz", + "lyra_16khz_proj_weights.raw.gz", + "lyra_16khz_quant_codebook_dimensions.gz", + "lyra_16khz_quant_code_vectors.gz", + "lyra_16khz_quant_mean_vectors.gz", + "lyra_16khz_quant_transmat.gz", + "lyra_16khz_scales_bias.raw.gz", + "lyra_16khz_scales_mask.raw.gz", + "lyra_16khz_scales_weights.raw.gz", + "lyra_16khz_transpose_0_bias.raw.gz", + "lyra_16khz_transpose_0_mask.raw.gz", + "lyra_16khz_transpose_0_weights.raw.gz", + "lyra_16khz_transpose_1_bias.raw.gz", + "lyra_16khz_transpose_1_mask.raw.gz", + "lyra_16khz_transpose_1_weights.raw.gz", + "lyra_16khz_transpose_2_bias.raw.gz", + "lyra_16khz_transpose_2_mask.raw.gz", + "lyra_16khz_transpose_2_weights.raw.gz"}; + + if (!IsSampleRateSupported(sample_rate_hz)) { + return absl::InvalidArgumentError(absl::StrFormat( + "Sample rate %d Hz is not supported by codec.", sample_rate_hz)); + } + if (num_channels != kNumChannels) { + return absl::InvalidArgumentError(absl::StrFormat( + "Number of channels %d is not supported by codec. It needs to be %d.", + num_channels, kNumChannels)); + } + if (bitrate != kBitrate) { + return absl::InvalidArgumentError(absl::StrFormat( + "Bitrate %d bps is not supported by codec. It needs to be %d bps.", + bitrate, kBitrate)); + } + for (auto asset : kAssets) { + std::error_code error; + const bool exists = + ghc::filesystem::exists(model_path / std::string(asset), error); + if (error) { + return absl::UnknownError( + absl::StrFormat("Error when probing for asset %s in %s: %s", asset, + model_path, error.message())); + } + if (!exists) { + return absl::InvalidArgumentError( + absl::StrFormat("Asset %s does not exist in %s.", asset, model_path)); + } + } + const ghc::filesystem::path lyra_config_proto = + model_path / "lyra_config.textproto"; + std::error_code error; + const bool exists = ghc::filesystem::exists(lyra_config_proto, error); + if (error) { + return absl::UnknownError( + absl::StrFormat("Error when probing for asset %s: %s", + lyra_config_proto, error.message())); + } + third_party::lyra_codec::LyraConfig lyra_config; + if (exists) { + std::ifstream lyra_config_stream(lyra_config_proto.string()); + const std::string lyra_config_string{ + std::istreambuf_iterator(lyra_config_stream), + std::istreambuf_iterator()}; + // Even though LyraConfig is a subclass of Message, the reinterpreting is + // necessary for the mobile proto library. + if (!google::protobuf::TextFormat::ParseFromString( + lyra_config_string, + reinterpret_cast(&lyra_config))) { + return absl::UnknownError(absl::StrFormat( + "Error when parsing %s: %s", lyra_config_proto, error.message())); + } + } + if (lyra_config.identifier() != kVersionMinor) { + return absl::InvalidArgumentError(absl::StrFormat( + "Weights identifier (%d) is not compatible with code identifier (%d).", + lyra_config.identifier(), kVersionMinor)); + } + return absl::OkStatus(); +} + const std::string& GetVersionString() { static const std::string kVersionString = [] { return absl::StrCat(kVersionMajor, ".", kVersionMinor, ".", kVersionMicro); diff --git a/lyra_config.h b/lyra_config.h index 20644277..1f770ac9 100644 --- a/lyra_config.h +++ b/lyra_config.h @@ -56,126 +56,15 @@ inline constexpr int kSupportedSampleRates[] = {8000, 16000, 32000, 48000}; inline constexpr int kInternalSampleRateHz = 16000; inline constexpr int kNumQuantizationBits = 120; -inline constexpr absl::string_view kAssets[] = { - "lyra_16khz_ar_to_gates_bias.raw.gz", - "lyra_16khz_ar_to_gates_mask.raw.gz", - "lyra_16khz_ar_to_gates_weights.raw.gz", - "lyra_16khz_conditioning_stack_0_bias.raw.gz", - "lyra_16khz_conditioning_stack_0_mask.raw.gz", - "lyra_16khz_conditioning_stack_0_weights.raw.gz", - "lyra_16khz_conditioning_stack_1_bias.raw.gz", - "lyra_16khz_conditioning_stack_1_mask.raw.gz", - "lyra_16khz_conditioning_stack_1_weights.raw.gz", - "lyra_16khz_conditioning_stack_2_bias.raw.gz", - "lyra_16khz_conditioning_stack_2_mask.raw.gz", - "lyra_16khz_conditioning_stack_2_weights.raw.gz", - "lyra_16khz_conv1d_bias.raw.gz", - "lyra_16khz_conv1d_mask.raw.gz", - "lyra_16khz_conv1d_weights.raw.gz", - "lyra_16khz_conv_cond_bias.raw.gz", - "lyra_16khz_conv_cond_mask.raw.gz", - "lyra_16khz_conv_cond_weights.raw.gz", - "lyra_16khz_conv_to_gates_bias.raw.gz", - "lyra_16khz_conv_to_gates_mask.raw.gz", - "lyra_16khz_conv_to_gates_weights.raw.gz", - "lyra_16khz_gru_layer_bias.raw.gz", - "lyra_16khz_gru_layer_mask.raw.gz", - "lyra_16khz_gru_layer_weights.raw.gz", - "lyra_16khz_means_bias.raw.gz", - "lyra_16khz_means_mask.raw.gz", - "lyra_16khz_means_weights.raw.gz", - "lyra_16khz_mix_bias.raw.gz", - "lyra_16khz_mix_mask.raw.gz", - "lyra_16khz_mix_weights.raw.gz", - "lyra_16khz_proj_bias.raw.gz", - "lyra_16khz_proj_mask.raw.gz", - "lyra_16khz_proj_weights.raw.gz", - "lyra_16khz_quant_codebook_dimensions.gz", - "lyra_16khz_quant_code_vectors.gz", - "lyra_16khz_quant_mean_vectors.gz", - "lyra_16khz_quant_transmat.gz", - "lyra_16khz_scales_bias.raw.gz", - "lyra_16khz_scales_mask.raw.gz", - "lyra_16khz_scales_weights.raw.gz", - "lyra_16khz_transpose_0_bias.raw.gz", - "lyra_16khz_transpose_0_mask.raw.gz", - "lyra_16khz_transpose_0_weights.raw.gz", - "lyra_16khz_transpose_1_bias.raw.gz", - "lyra_16khz_transpose_1_mask.raw.gz", - "lyra_16khz_transpose_1_weights.raw.gz", - "lyra_16khz_transpose_2_bias.raw.gz", - "lyra_16khz_transpose_2_mask.raw.gz", - "lyra_16khz_transpose_2_weights.raw.gz"}; -inline constexpr absl::string_view kLyraConfigProto = "lyra_config.textproto"; - inline bool IsSampleRateSupported(int sample_rate_hz) { return std::find(std::begin(kSupportedSampleRates), std::end(kSupportedSampleRates), sample_rate_hz) != std::end(kSupportedSampleRates); } -inline absl::Status AreParamsSupported( - int sample_rate_hz, int num_channels, int bitrate, - const ghc::filesystem::path& model_path) { - if (!IsSampleRateSupported(sample_rate_hz)) { - return absl::InvalidArgumentError(absl::StrFormat( - "Sample rate %d Hz is not supported by codec.", sample_rate_hz)); - } - if (num_channels != kNumChannels) { - return absl::InvalidArgumentError(absl::StrFormat( - "Number of channels %d is not supported by codec. It needs to be %d.", - num_channels, kNumChannels)); - } - if (bitrate != kBitrate) { - return absl::InvalidArgumentError(absl::StrFormat( - "Bitrate %d bps is not supported by codec. It needs to be %d bps.", - bitrate, kBitrate)); - } - for (auto asset : kAssets) { - std::error_code error; - const bool exists = - ghc::filesystem::exists(model_path / std::string(asset), error); - if (error) { - return absl::UnknownError( - absl::StrFormat("Error when probing for asset %s in %s: %s", asset, - model_path, error.message())); - } - if (!exists) { - return absl::InvalidArgumentError( - absl::StrFormat("Asset %s does not exist in %s.", asset, model_path)); - } - } - const ghc::filesystem::path lyra_config_proto = - model_path / std::string(kLyraConfigProto); - std::error_code error; - const bool exists = ghc::filesystem::exists(lyra_config_proto, error); - if (error) { - return absl::UnknownError( - absl::StrFormat("Error when probing for asset %s: %s", - lyra_config_proto, error.message())); - } - third_party::lyra_codec::LyraConfig lyra_config; - if (exists) { - std::ifstream lyra_config_stream(lyra_config_proto.string()); - const std::string lyra_config_string{ - std::istreambuf_iterator(lyra_config_stream), - std::istreambuf_iterator()}; - // Even though LyraConfig is a subclass of Message, the reinterpreting is - // necessary for the mobile proto library. - if (!google::protobuf::TextFormat::ParseFromString( - lyra_config_string, - reinterpret_cast(&lyra_config))) { - return absl::UnknownError(absl::StrFormat( - "Error when parsing %s: %s", lyra_config_proto, error.message())); - } - } - if (lyra_config.identifier() != kVersionMinor) { - return absl::InvalidArgumentError(absl::StrFormat( - "Weights identifier (%d) is not compatible with code identifier (%d).", - lyra_config.identifier(), kVersionMinor)); - } - return absl::OkStatus(); -} +absl::Status AreParamsSupported(int sample_rate_hz, int num_channels, + int bitrate, + const ghc::filesystem::path& model_path); // Returns a string of form "|kVersionMajor|.|kVersionMinor|.|kVersionMicro|". const std::string& GetVersionString(); diff --git a/lyra_decoder_test.cc b/lyra_decoder_test.cc index e5660e07..3e0b4143 100644 --- a/lyra_decoder_test.cc +++ b/lyra_decoder_test.cc @@ -26,7 +26,7 @@ #include #include -// placeholder for get runfiles header. +// Placeholder for get runfiles header. #include "absl/memory/memory.h" #include "absl/random/random.h" #include "absl/strings/string_view.h" diff --git a/lyra_encoder_test.cc b/lyra_encoder_test.cc index 750696b2..00336454 100644 --- a/lyra_encoder_test.cc +++ b/lyra_encoder_test.cc @@ -26,7 +26,7 @@ #include #include -// placeholder for get runfiles header. +// Placeholder for get runfiles header. #include "absl/memory/memory.h" // IWYU pragma: keep #include "absl/types/optional.h" // IWYU pragma: keep #include "absl/types/span.h" diff --git a/lyra_integration_test.cc b/lyra_integration_test.cc index c088b35b..276d439b 100644 --- a/lyra_integration_test.cc +++ b/lyra_integration_test.cc @@ -21,7 +21,7 @@ #include #include "glog/logging.h" -// placeholder for get runfiles header. +// Placeholder for get runfiles header. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" diff --git a/lyra_types.h b/lyra_types.h index 0fe15d7b..7cb5fd74 100644 --- a/lyra_types.h +++ b/lyra_types.h @@ -20,7 +20,7 @@ #include #include "layer_wrapper.h" -#include "sparse_inference_matrixvector.h" +#include "sparse_matmul/sparse_matmul.h" namespace chromemedia { namespace codec { diff --git a/lyra_wavegru.h b/lyra_wavegru.h index 36dd526b..d2b9c4ff 100644 --- a/lyra_wavegru.h +++ b/lyra_wavegru.h @@ -38,7 +38,7 @@ #include "layer_wrappers_lib.h" #include "lyra_types.h" #include "project_and_sample.h" -#include "sparse_inference_matrixvector.h" +#include "sparse_matmul/sparse_matmul.h" namespace chromemedia { namespace codec { diff --git a/lyra_wavegru_test.cc b/lyra_wavegru_test.cc index 629a2823..a6041a9e 100644 --- a/lyra_wavegru_test.cc +++ b/lyra_wavegru_test.cc @@ -22,12 +22,12 @@ #include "exported_layers_test.h" #endif // !defined(USE_FIXED16) && !defined(USE_BFLOAT16) -// placeholder for get runfiles header. +// Placeholder for get runfiles header. #include "absl/strings/str_format.h" #include "gtest/gtest.h" #include "include/ghc/filesystem.hpp" #include "lyra_config.h" -#include "sparse_inference_matrixvector.h" // IWYU pragma: keep +#include "sparse_matmul/sparse_matmul.h" // IWYU pragma: keep namespace chromemedia { namespace codec { diff --git a/project_and_sample.h b/project_and_sample.h index 32a1bb12..5dac3e06 100644 --- a/project_and_sample.h +++ b/project_and_sample.h @@ -32,7 +32,7 @@ #include "absl/time/time.h" #include "glog/logging.h" #include "lyra_types.h" -#include "sparse_inference_matrixvector.h" +#include "sparse_matmul/sparse_matmul.h" namespace chromemedia { namespace codec { diff --git a/project_and_sample_test.cc b/project_and_sample_test.cc index bf8e26ce..367e8e24 100644 --- a/project_and_sample_test.cc +++ b/project_and_sample_test.cc @@ -20,14 +20,14 @@ #include #include -// placeholder for get runfiles header. +// Placeholder for get runfiles header. #include "absl/strings/str_format.h" #include "exported_layers_test.h" #include "gmock/gmock.h" #include "gtest/gtest.h" #include "include/ghc/filesystem.hpp" #include "lyra_types.h" -#include "sparse_inference_matrixvector.h" +#include "sparse_matmul/sparse_matmul.h" namespace chromemedia { namespace codec { @@ -143,7 +143,7 @@ TYPED_TEST(ProjectAndSampleTest, GetSamplesReturnGoldenValues) { &this->scratch_space_, kNumSplitBands, actual_samples.data()); barrier->barrier(); }; - LaunchOnThreadsWithBarrier(num_threads, f); + csrblocksparse::LaunchOnThreadsWithBarrier(num_threads, f); EXPECT_THAT(actual_samples, testing::Pointwise(IsRelativelyClose(TestFixture::kTolerance), expected_samples)); diff --git a/sparse_inference_matrixvector.h b/sparse_inference_matrixvector.h deleted file mode 100644 index c51e6a19..00000000 --- a/sparse_inference_matrixvector.h +++ /dev/null @@ -1,985 +0,0 @@ -/* - * Copyright 2021 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef LYRA_CODEC_SPARSE_INFERENCE_MATRIXVECTOR_H_ -#define LYRA_CODEC_SPARSE_INFERENCE_MATRIXVECTOR_H_ - -#include "absl/status/status.h" -#include "glog/logging.h" - -// [internal] Start of sparse_inference_matrixvector declarations. - -#if defined __aarch64__ -#include -#endif -#include -#if defined(__x86_64__) || defined(__i386__) || defined(_WIN32) -#include -#endif -#include -#include -#include // NOLINT -#include - -namespace csrblocksparse { - -enum class ARInputsMode { - k0ARInputs, - k2ARInputs, - k3ARInputs, -}; - -class SpinBarrier { - public: - explicit SpinBarrier(int num_threads) - : num_threads_(num_threads), threads_at_barrier_(0), barrier_step_(0) {} - - void barrier(); - - private: - const int num_threads_; - std::atomic threads_at_barrier_; - std::atomic barrier_step_; -}; - -class ProducerConsumer { - public: - ProducerConsumer(int num_producers, int num_consumers); - - inline void produce(); - - inline void consume(); - int num_producers() const; - int num_consumers() const; - - private: - const int num_producers_; - const int num_consumers_; - std::atomic producers_ready_; - std::atomic consumers_passed_; -}; - -using Thread = std::thread; - -class fixed16_type {}; -class fixed32_type {}; - -template -class fixed16 : fixed16_type { - public: - static constexpr int kExponentBits = ExponentBits; - static constexpr int kMantissaBits = 16 - ExponentBits - 1; - - fixed16() = default; - explicit fixed16(float x); - explicit fixed16(int16_t x); - - explicit operator float() const; - - int raw_val() const; - - private: - inline float fixed16_to_float(int16_t x) const; - - inline int16_t float_to_fixed16(float x) const; - - int16_t val_; -}; - -template -class fixed32 : fixed32_type { - public: - static constexpr int kExponentBits = ExponentBits; - static constexpr int kMantissaBits = 32 - ExponentBits - 1; - - fixed32() = default; - explicit fixed32(float x); - explicit fixed32(int32_t x); - - explicit operator float() const; - - int raw_val() const; - - private: - inline float fixed32_to_float(int32_t x) const; - - inline int32_t float_to_fixed32(float x) const; - - int32_t val_; -}; - -class bfloat16 { - public: - bfloat16() = default; - explicit bfloat16(float x); - explicit bfloat16(uint16_t x); - static constexpr int kMantissaBits = 8; - - explicit operator float() const; - - private: - inline uint16_t float_to_bfloat16(float x) const; - - inline float bfloat16_to_float(uint32_t as_int) const; - - uint16_t val_; -}; - -template -struct IsCustomFloatType - : std::integral_constant::value> {}; - -template -struct IsAnyFloatType - : std::integral_constant::value || - IsCustomFloatType::value> {}; - -template -struct IsFixed16Type - : std::integral_constant::value> {}; - -template -struct IsFixed32Type - : std::integral_constant::value> {}; - -template -struct IsFixedType : std::integral_constant::value || - IsFixed32Type::value> { -}; - -template -struct TypeOfProduct {}; - -template -struct TypeOfProduct< - LhsType, RhsType, - typename std::enable_if::value && - IsAnyFloatType::value>::type> { - using type = float; -}; - -template -struct TypeOfProduct< - LhsType, RhsType, - typename std::enable_if::value && - IsFixed16Type::value>::type> { - static_assert(LhsType::kMantissaBits + RhsType::kMantissaBits < 31, - "Sum of mantissa bits must not exceed 31."); - using type = fixed32<31 - LhsType::kMantissaBits - RhsType::kMantissaBits>; -}; - -template -struct MantissaBitsOf { - static constexpr int value = 1; -}; - -namespace detail { - -#if defined __AVX__ - -#if defined __AVX2__ - -template -struct IsAddableFixedTypes - : std::integral_constant::value || - IsFixed16Type::value> {}; -template -struct ShouldEnableGenericAdd - : std::integral_constant::value> {}; - -#else // No AVX2. - -template -struct ShouldEnableGenericAdd : std::true_type {}; - -#endif // __AVX2__ - -template -typename std::enable_if::value>::type SumVectors( - int start, int end, const Type* add1, const Type* add2, Type* result); - -template -typename std::enable_if::value>::type SumVectors( - int start, int end, const Type* add1, const Type* add2, Type* result); - -#elif defined __aarch64__ - -template -struct IsAddableFixedTypes - : std::integral_constant::value || - IsFixed16Type::value> {}; -template -struct ShouldEnableGenericAdd - : std::integral_constant::value> {}; - -template -typename std::enable_if::value>::type SumVectors( - int start, int end, const Type* add1, const Type* add2, Type* result); - -template -typename std::enable_if::value>::type SumVectors( - int start, int end, const Type* add1, const Type* add2, Type* result); - -#else // defined __aarch64__ - -template -struct ShouldEnableGenericAdd : std::true_type {}; - -#endif // defined __AVX__ - -template -typename std::enable_if::value>::type SumVectors( - int start, int end, const Type* add1, const Type* add2, Type* result); - -} // namespace detail - -template -class MutableVectorView; -template -class VectorView; - -template -class CacheAlignedVector { - public: - using value_type = DataType; - - explicit CacheAlignedVector(std::size_t size); - - explicit CacheAlignedVector(const std::vector& input); - - template - explicit CacheAlignedVector(const std::vector& input); - - CacheAlignedVector(const DataType* input, int size); - - template - explicit CacheAlignedVector(const InputType* input, int size); - - CacheAlignedVector(); - - ~CacheAlignedVector(); - - CacheAlignedVector(CacheAlignedVector const& other); - CacheAlignedVector(CacheAlignedVector const& other, int start, int end); - - void operator=(CacheAlignedVector const& other); - - CacheAlignedVector(CacheAlignedVector&& other); - - CacheAlignedVector& operator=(CacheAlignedVector&& other); - - VectorView AsView() const; - - MutableVectorView AsMutableView(); - - void PrepareForThreads(const std::vector& split_points, - int block_height); - - void FillRandom(float min = -10.f, float max = 10.f); - - void FillZero(); - - void FillOnes(); - - void FillWith(const DataType& value); - - template - typename std::enable_if::value, int>::type Sample( - float temperature = 1.f); - -#if defined __aarch64__ - template - typename std::enable_if::value, int>::type Sample( - float temperature, std::minstd_rand* gen, - CacheAlignedVector* scratch) const; - - template - static inline int32x4_t vmul_temp_fixed(int32x4_t x, int32x2_t inv_temp); - - template - static inline int float_to_fixed(float x); - - template - static inline float fixed_to_float(int x); - - template - typename std::enable_if::value, int>::type Sample( - float temperature, std::minstd_rand* gen, - CacheAlignedVector* scratch) const; -#endif // defined __aarch64__ - - template -#if defined __aarch64__ - typename std::enable_if< - !std::is_same::value && !IsFixed32Type::value, int>::type -#else - int -#endif - Sample(float temperature, std::minstd_rand* gen, - CacheAlignedVector* scratch, int tid = 0, - SpinBarrier* barrier = nullptr) const; - - int ScalarSample(float temperature, std::minstd_rand* gen, - CacheAlignedVector* scratch, int tid = 0, - const int mindex = 0, const int maxdex = -1, - SpinBarrier* barrier = nullptr) const; - -#if defined __AVX2__ - inline int ThreadMax(int t_start, int t_end) const; - - template - inline float ApplyExpAndSum(int max_value, float* scratch_ptr); - - inline void FindSamplePoint(const float* scratch_ptr, float* random_target, - int* start, int* end); -#endif // __AVX2__ code - - template - typename std::enable_if::value, int>::type ThreadMax( - int tid) const; - - template - typename std::enable_if::value, int>::type ReducingSample( - std::minstd_rand* gen, CacheAlignedVector* scratch, int tid = 0, - float temperature = 1.0f, SpinBarrier* barrier = nullptr); - - template - typename std::enable_if::value, int>::type ReducingSample( - std::minstd_rand* gen, CacheAlignedVector* scratch, int tid = 0, - float temperature = 1.0f, SpinBarrier* barrier = nullptr); - - template - typename std::enable_if::value, void>::type Exp(); - - template - typename std::enable_if::value, void>::type Sigmoid(); - - template - typename std::enable_if< - IsFixed32Type::value && IsFixed32Type::value, void>::type - Sigmoid(const int32_t* sigmoid_table, CacheAlignedVector* result); - - template - typename std::enable_if::value, void>::type Tanh(); - - template - typename std::enable_if< - IsFixed32Type::value && IsFixed32Type::value, void>::type - Tanh(const int32_t* tanh_table, CacheAlignedVector* result); - - template - typename std::enable_if::value, const int32_t*>::type - cast_data() const; - template - typename std::enable_if::value, const int16_t*>::type - cast_data() const; - template - typename std::enable_if::value || IsFixed16Type::value), - const Q*>::type - cast_data() const; - const DataType* begin() const; - const DataType* end() const; - const DataType* data() const; - DataType* data(); - - const DataType& operator[](int pos) const; - DataType& operator[](int pos); - - std::size_t size() const; - bool empty() const; - std::size_t bytes() const; - - int rows() const; - int cols() const; - - int col_stride() const; - - void Print() const; - - float maximum() const; - - private: - void resize(std::size_t size); - - std::size_t size_; - DataType* data_; - std::vector maxes_; - std::vector thread_starts_; -#if defined __AVX__ || defined __AVX2__ - static constexpr int kCacheLineSize = 64; - static constexpr int kSIMDWidth = 8; -#else - static constexpr int kCacheLineSize = 128; - static constexpr int kSIMDWidth = 4; -#endif // __AVX__ - std::unique_ptr gen_; -}; - -template -class FatCacheAlignedVector { - public: - FatCacheAlignedVector(); - FatCacheAlignedVector(int rows, int cols); - FatCacheAlignedVector(const CacheAlignedVector& vector, int rows); - template - explicit FatCacheAlignedVector(const FatCacheAlignedVector& vector); - FatCacheAlignedVector(CacheAlignedVector&& vector, int rows); - - VectorView slice(const int col) const; - MutableVectorView slice(const int col); - - const T* data() const; - T* data(); - template - typename std::enable_if::value, const int16_t*>::type - cast_data() const; - template - typename std::enable_if::value, const int32_t*>::type - cast_data() const; - template - typename std::enable_if::value || IsFixed32Type::value), - const Q*>::type - cast_data() const; - - int rows() const; - int cols() const; - int size() const; - bool empty() const; - std::size_t bytes() const; - - void reshape(int rows, int cols); - - float maximum() const; - - int col_stride() const; - - void FillOnes(); - void FillZero(); - void FillRandom(float min = -10.f, float max = 10.f); - - const T& operator[](int pos) const; - T& operator[](int pos); - - private: - CacheAlignedVector vector_; - int rows_; - int cols_; -}; - -template -class MutableVectorView { - public: - using value_type = T; - - explicit MutableVectorView(T* data = nullptr, int rows = 0, int cols = 0, - int col_stride = 0); - - explicit MutableVectorView(CacheAlignedVector* vector); - - explicit MutableVectorView(CacheAlignedVector* vector, int pos = 0, - int rows = 0); - - explicit MutableVectorView(FatCacheAlignedVector* vector); - - MutableVectorView(FatCacheAlignedVector* vector, int pos, int rows); - - T* data(); - const T* data() const; - - template - typename std::enable_if::value, const int32_t*>::type - cast_data() const; - template - typename std::enable_if::value, const int16_t*>::type - cast_data() const; - template - typename std::enable_if::value || IsFixed16Type::value), - const Q*>::type - cast_data() const; - - int cols() const; - - int rows() const; - - bool empty() const; - - int col_stride() const; - - std::size_t bytes() const; - - void reshape(int rows, int cols); - - const T& operator[](int pos) const; - T& operator[](int pos); - - protected: - T* data_; - int rows_; - int cols_; - int col_stride_; -}; - -template -class VectorView : public MutableVectorView { - public: - using value_type = T; - - explicit VectorView(const MutableVectorView& other); - - explicit VectorView(const T* data = nullptr, int rows = 0, int cols = 0, - int col_stride = 0); - - explicit VectorView(const CacheAlignedVector& vector); - - explicit VectorView(const CacheAlignedVector& vector, int pos = 0, - int rows = 0); - - explicit VectorView(const FatCacheAlignedVector& vector); - - VectorView(const FatCacheAlignedVector& vector, int pos, int rows); - - VectorView& operator=(const MutableVectorView& other); -}; - -template -class MaskedSparseMatrix; - -class ThreadBounds { - public: - ThreadBounds(); - - void PrepareForThreads(int block_width, int block_height, int num_threads, - int reduced_rows_per_cache_row, int reduced_rows, - const int* nnz_per_row); - - template - const WeightType* OffsetWeights(const WeightType* weights, int tid) const; - template - const RhsIndType* OffsetRhsIndices(const RhsIndType* rhs_indices, - int tid) const; - template - const BiasType* OffsetBias(const BiasType* bias, int tid) const; - template - OutType* OffsetOutput(OutType* output, int tid) const; - int StartRow(int tid) const; - const std::vector& row_starts() const; - - private: - void ComputeThreadSplitPoints(int num_threads, int reduced_rows_per_cache_row, - int reduced_rows, const int* nnz_per_row); - - int block_width_; - int block_height_; - std::vector row_starts_; - std::vector weight_starts_; - std::vector rhs_indices_starts_; - std::vector bias_starts_; -}; - -class MatmulBase { - public: - MatmulBase() { -#if defined(__x86_64__) || defined(__i386__) || defined(_WIN32) - unsigned int eax, ebx, ecx, edx; - if (__get_cpuid(1, &eax, &ebx, &ecx, &edx) != 0) { - using_avx_ = (ecx & bit_AVX) != 0; - if (using_avx_) { - __get_cpuid_count(7, 0, &eax, &ebx, &ecx, &edx); - using_avx2_ = (ebx & bit_AVX2) != 0; - using_avx512_ = (ebx & bit_AVX512F) != 0 && (ebx & bit_AVX512DQ) && - (ebx & bit_AVX512BW) != 0; - VLOG(2) << "avx2 flag=" << using_avx2_ << " 512=" << using_avx512_; - } else { - LOG(ERROR) << "AVX not found at all!"; - } - } -#else - using_aarch64_ = true; -#endif - } - - protected: - bool using_avx512_ = false; - bool using_avx2_ = false; - bool using_avx_ = false; - bool using_aarch64_ = false; -}; - -constexpr int kGenericSIMDWidth = 4; - -template -class GruGates : public MatmulBase { - public: - using SampleWeightType = float; - static constexpr int kSIMDWidth = kGenericSIMDWidth; - - template - void GruWithARInput(int start, int end, int state_size, - const InputType* gru_recurrent_ptr, - const InputType* input_ptr, GruStateType* gru_state_ptr, - const SampleType* ar_sample0 = nullptr, - const SampleType* ar_sample1 = nullptr, - const SampleWeightType* ar_01_weights = nullptr, - int num_replicas = 1, int replica_stride = 0, - const SampleType* ar_sample2 = nullptr, - const SampleWeightType* ar_2_weights = nullptr, - const InputType* gru_recurrent_other_ptr = nullptr); - - void PlainGru(int start, int end, int state_size, - const InputType* gru_recurrent_ptr, const InputType* input_ptr, - GruStateType* gru_state_ptr); -}; - -#if defined __ARM_NEON || defined __aarch64__ -static constexpr int kNeonSIMDWidth = 4; - -template <> -class GruGates : public MatmulBase { - public: - static constexpr int kSIMDWidth = kNeonSIMDWidth; - - template - void GruWithARInput(int start, int end, int state_size, - const float* gru_recurrent_data, const float* input_data, - float* gru_state_data, const float* ar_sample0 = nullptr, - const float* ar_sample1 = nullptr, - const float* ar_01_weights = nullptr, - int num_replicas = 1, int replica_stride = 0, - const float* ar_sample2 = nullptr, - const float* ar_2_weights = nullptr, - const float* gru_recurrent_other_data = nullptr); -}; -#endif // defined __ARM_NEON || defined __aarch64__ - -template -class GruGates, fixed32, - fixed16> : public MatmulBase { - public: -#if defined __ARM_NEON || defined __aarch64__ - static constexpr int kSIMDWidth = kNeonSIMDWidth; -#elif defined __AVX2__ - static constexpr int kSIMDWidth = kAVX2SIMDWidth * 2; -#else // Generic case. - static constexpr int kSIMDWidth = kGenericSIMDWidth; -#endif // __ARM_NEON || defined __aarch64__ / __AVX2__ - - using GruStateType = fixed16; - using InputType = fixed32; - using SampleType = fixed16; - using SampleWeightType = float; - static constexpr int kInputMantissaBits = InputType::kMantissaBits; - static constexpr int kSampleMantissaBits = SampleType::kMantissaBits; - static constexpr int kStateMantissaBits = GruStateType::kMantissaBits; - template - void GruWithARInput(int start, int end, int state_size, - const InputType* gru_recurrent_data, - const InputType* input_data, GruStateType* gru_state_data, - const SampleType* ar_sample0 = nullptr, - const SampleType* ar_sample1 = nullptr, - const SampleWeightType* ar_01_weights = nullptr, - int num_replicas = 1, int replica_stride = 0, - const SampleType* ar_sample2 = nullptr, - const SampleWeightType* ar_2_weights = nullptr, - const InputType* gru_recurrent_other_data = nullptr); -}; - -template -class Matmul : public MatmulBase { - public: - template - void MatVec4x4(const WeightType* weights, const RhsType* rhs, - const typename TypeOfProduct::type* bias, - const int32_t* nnz_per_row, const int16_t* rhs_indices, - int start_row, int end_row, bool relu, int replicas, - int stride, OutType* output); - template - void MatVec8x4(const WeightType* weights, const RhsType* rhs, - const typename TypeOfProduct::type* bias, - const int32_t* nnz_per_row, const int16_t* rhs_indices, - int start_row, int end_row, bool relu, int replicas, - int stride, OutType* output); -}; - -template <> -class Matmul : public MatmulBase { - public: - void MatVec4x4(const float* weights, const float* rhs, const float* bias, - const int32_t* nnz_per_row, const int16_t* rhs_indices, - int start_row, int end_row, bool relu, int replicas, - int stride, float* output); - void MatVec8x4(const float* weights, const float* rhs, const float* bias, - const int32_t* nnz_per_row, const int16_t* rhs_indices, - int start_row, int end_row, bool relu, int replicas, - int stride, float* output); -}; - -template -class Matmul, fixed16> : public MatmulBase { - public: - using WeightType = fixed16; - using RhsType = fixed16; - - template - void MatVec4x4(const int16_t* weights, const int16_t* rhs, - const int32_t* bias, const int32_t* nnz_per_row, - const int16_t* rhs_indices, int start_row, int end_row, - bool relu, int replicas, int stride, OutType* output); - - template - void MatVec8x4(const int16_t* weights, const int16_t* rhs, - const int32_t* bias, const int32_t* nnz_per_row, - const int16_t* rhs_indices, int start_row, int end_row, - bool relu, int replicas, int stride, OutType* output); -}; - -template -class CsrBlockSparseMatrix { - public: - CsrBlockSparseMatrix(); - - CsrBlockSparseMatrix(const uint8_t* const& buffer, const std::size_t& len); - - template - CsrBlockSparseMatrix(const MaskedSparseMatrix& masked_matrix); - - CsrBlockSparseMatrix( - const CsrBlockSparseMatrix& src_matrix, - const std::vector& new_weights, - const std::vector& new_deltas, const std::vector& new_nnz, - int cols); - - CsrBlockSparseMatrix SplitByColumn(int start_col, int end_col, - bool keep_rhs_size = false) const; - - CsrBlockSparseMatrix SplitByRow(int start_row, int end_row) const; - - void DoubleBlockHeight(); - - std::size_t WriteToFlatBuffer(std::string* csr_flatbuffer); - - void ReadFromFlatBuffer(const uint8_t* const& bytes, const std::size_t& len); - - template - void SpMM_bias(const RhsClass& rhs, const BiasClass& bias, OutClass* out, - bool relu = false, int tid = 0, - SpinBarrier* barrier = nullptr) const; - template - void MatVec(const MVRhsType* rhs, const MVBiasType* bias, bool relu, int tid, - int replicas, int output_stride, OutType* output); - - int rows() const; - int cols() const; - int block_height() const; - int block_width() const; - float sparsity() const; - int num_threads() const; - const ThreadBounds& thread_bounds() const; - const CacheAlignedVector& rhs_indices() const; - const std::string& name() const; - void set_name(const std::string& name); - const std::vector& split_points() const; - - std::size_t bytes() const; - - template - typename std::enable_if::value, int>::type - SpMM_bias_Sample(const RhsClass& rhs, const BiasClass& bias, OutClass* out, - float temperature, int tid, SpinBarrier* barrier, - std::minstd_rand* gen, - CacheAlignedVector* scratch) const; - template - typename std::enable_if::value, int>::type - SpMM_bias_Sample(const RhsClass& rhs, const BiasClass& bias, OutClass* out, - float temperature, int tid, SpinBarrier* barrier, - std::minstd_rand* gen, - CacheAlignedVector* scratch) const; - - void Print() const; - - template - int PrepareForThreads(int num_threads, int cache_line_size = -1); - - void ComputeRHSIndices(); - - void ComputeColDeltas(); - - std::vector CumulativeColDeltas() const; - - private: - constexpr std::size_t FixedParameterSize() const; - template - void DetermineBlockSize(const MaskedSparseMatrix& masked_matrix); - - template - void MakeColumnsMultiple(const std::vector& row_offsets, - std::vector* reduced_mask, - std::vector* weights); - - template - void MaskAndWeightsToCsr(const std::vector& mask, - const std::vector& weights, - std::vector* nnz_per_row, - std::vector* col_indices, - std::vector* weights_csr); - - template - int ReducedRowsPerCacheLine(int override_cache_line_size = -1) const; - - int col_multiple_; - int rows_; - int cols_; - int reduced_rows_; - int reduced_cols_; - float sparsity_; - int block_width_; - int block_height_; - int num_threads_; - std::string name_; - - CacheAlignedVector weights_; - CacheAlignedVector col_deltas_; - CacheAlignedVector nnz_per_row_; - CacheAlignedVector rhs_indices_; - Matmul matmul_; - ThreadBounds thread_bounds_; - static constexpr int kCacheLineSize = 64; -}; - -template ::type, - typename DeltaType = int16_t> -class SparseLinearLayer { - public: - SparseLinearLayer(); - - SparseLinearLayer(CsrBlockSparseMatrix&& sparse_matrix, - CacheAlignedVector&& bias); - SparseLinearLayer( - const SparseLinearLayer& src); - SparseLinearLayer& operator=( - const SparseLinearLayer& src); - - template - void SpMM_bias(const RhsClassType& rhs, OutType* out, bool relu = false, - int tid = 0, SpinBarrier* barrier = nullptr) const; - template - int SpMM_bias_Sample(const RhsClassType& rhs, OutType* out, float temperature, - int tid, SpinBarrier* barrier, std::minstd_rand* gen, - CacheAlignedVector* scratch) const; - template - void MatVec(const RhsClassType& rhs, bool relu, int tid, int replicas, - int output_stride, OutType* output, - SpinBarrier* barrier = nullptr); - - int rows() const; - int cols() const; - float sparsity() const; - int block_width() const; - int block_height() const; - int num_threads() const; - const CacheAlignedVector& bias() const; - const std::vector& split_points() const; - bool IsSplit() const; - - std::size_t bytes() const; - void Print() const; - - void DoubleBlockHeight(); - - int PrepareForThreads(int num_threads, int cache_line_size = -1); - - void SliceForThreads(const std::vector& split_points); - - void SplitInputs( - SparseLinearLayer* part1, - SparseLinearLayer* part2); - - void SplitOutputs( - SparseLinearLayer* part1, - SparseLinearLayer* part2); - - private: - struct PartLinearLayer { - PartLinearLayer(const CsrBlockSparseMatrix& matrix, - const CacheAlignedVector& bias, - const CacheAlignedVector& bias_4, int tid, - int start_col, int end_col); - CsrBlockSparseMatrix self_matrix; - CacheAlignedVector full_bias; - CacheAlignedVector quarter_bias; - CsrBlockSparseMatrix other_matrix; - }; - CsrBlockSparseMatrix sparse_matrix_; - CacheAlignedVector bias_; - CacheAlignedVector full_bias_; - CacheAlignedVector mid_output_; - std::vector thread_layers_; - std::unique_ptr split_pc_; - int num_threads_ = 0; -}; - -template -SparseLinearLayer CreateConstantLayer( - int rows, int cols, float sparsity, float constant = 1.f); - -template -absl::Status LoadLogitLayer( - const std::string& prefix, bool zipped, const std::string& path, - SparseLinearLayer* sparse_linear_layer); - -template -absl::Status LoadSparseLayer( - const std::string& prefix, bool zipped, - SparseLinearLayer* sparse_linear_layer, - const std::string& path); - -template -typename std::enable_if::value, - absl::Status>::type -ReadArrayFromFile(const std::string& file_name, std::vector* array, - const std::string& path = "/data/local/tmp/"); -template -typename std::enable_if::value && - csrblocksparse::IsFixed16Type::value, - absl::Status>::type -ReadArrayFromFile(const std::string& file_name, std::vector* array, - const std::string& path = "/data/local/tmp/"); - -} // namespace csrblocksparse - -// [internal] End of sparse_inference_matrixvector declarations. - -namespace chromemedia { -namespace codec { - -typedef std::function Function; -void LaunchOnThreadsWithBarrier(int num_threads, Function&& func); - -} // namespace codec -} // namespace chromemedia - -#endif // LYRA_CODEC_SPARSE_INFERENCE_MATRIXVECTOR_H_ diff --git a/sparse_inference_matrixvector_test.cc b/sparse_inference_matrixvector_test.cc deleted file mode 100644 index 99a6eeda..00000000 --- a/sparse_inference_matrixvector_test.cc +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright 2021 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Tests for sparse_inference_matrixvector library. -// This can be useful for checking runtime ABI compatibility on a unit. -#include "sparse_inference_matrixvector.h" - -#include "gtest/gtest.h" - -namespace chromemedia { -namespace codec { -namespace { - -TEST(SparseInferenceTest, ScalarSample) { - // The vector size must be a multiple of 8. - constexpr int kOutputBins = 96; - // The scratch size must be at least as big as the vector size. - csrblocksparse::CacheAlignedVector scratch_space(kOutputBins); - - const std::minstd_rand::result_type kSeed = 42; - std::minstd_rand gen{std::minstd_rand(kSeed)}; - - csrblocksparse::CacheAlignedVector mat = - csrblocksparse::CacheAlignedVector(kOutputBins); - mat.ScalarSample(1.0, &gen, &scratch_space); - // If we've reached this point, there hasn't been anything so mismatched - // as to cause a segfault. - // This is useful for testing our library ABI. - SUCCEED(); -} - -} // namespace -} // namespace codec -} // namespace chromemedia diff --git a/sparse_matmul/BUILD b/sparse_matmul/BUILD new file mode 100644 index 00000000..03c147cf --- /dev/null +++ b/sparse_matmul/BUILD @@ -0,0 +1,21 @@ +# [internal] load placeholder + +licenses(["notice"]) + +cc_library( + name = "sparse_matmul", + hdrs = [ + "sparse_matmul.h", + ], + visibility = ["//visibility:public"], + deps = [ + "//sparse_matmul/compute:gru_gates", + "//sparse_matmul/layers:layer", + "//sparse_matmul/layers:matrix", + "//sparse_matmul/layers:utils", + "//sparse_matmul/numerics:fast_transcendentals", + "//sparse_matmul/numerics:types", + "//sparse_matmul/os:coop_threads", + "//sparse_matmul/vector:cache_aligned_vector", + ], # internal :sparse_matmul deps placeholder +) diff --git a/sparse_matmul/compute/BUILD b/sparse_matmul/compute/BUILD new file mode 100644 index 00000000..41f5b3f6 --- /dev/null +++ b/sparse_matmul/compute/BUILD @@ -0,0 +1,88 @@ +# Low-level computation code, including generic and architecture-specific +# variants. + +licenses(["notice"]) + +cc_library( + name = "gru_gates", + srcs = [ + "ar_inputs.h", + "gru_gates_arm.h", + "gru_gates_avx_fixed.h", + "gru_gates_generic.h", + ], + hdrs = ["gru_gates.h"], + visibility = [ + "//visibility:public", + ], + deps = [ + ":matmul", + "//sparse_matmul/numerics:fast_transcendentals", + "//sparse_matmul/numerics:types", + "//sparse_matmul/vector:cache_aligned_vector", + ], +) + +cc_library( + name = "kernels", + srcs = [ + "kernels_arm.h", + "kernels_avx.h", + ], + hdrs = [ + "kernels_generic.h", + ], + visibility = [ + "//sparse_matmul:__subpackages__", + ], + deps = [ + "//sparse_matmul/numerics:fast_transcendentals", + "//sparse_matmul/numerics:types", + ], +) + +cc_library( + name = "matmul", + srcs = [ + "matmul_fixed_avx2.cc", + "matmul_fixed_avx2.h", + "matmul_generic.cc", + "matmul_generic.h", + ], + hdrs = [ + "matmul.h", + ], + visibility = [ + "//sparse_matmul:__subpackages__", + ], + deps = [ + "//sparse_matmul/numerics:types", + "@com_google_absl//absl/time", + ], +) + +cc_library( + name = "thread_bounds", + srcs = ["thread_bounds.cc"], + hdrs = ["thread_bounds.h"], + visibility = [ + "//sparse_matmul:__subpackages__", + ], + deps = [ + "@com_google_glog//:glog", + ], +) + +cc_test( + name = "gru_gates_test", + size = "small", + srcs = [ + "gru_gates_test.cc", + ], + deps = [ + ":gru_gates", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/sparse_matmul/compute/ar_inputs.h b/sparse_matmul/compute/ar_inputs.h new file mode 100644 index 00000000..d10e2d96 --- /dev/null +++ b/sparse_matmul/compute/ar_inputs.h @@ -0,0 +1,37 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_AR_INPUTS_H_ +#define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_AR_INPUTS_H_ + +namespace csrblocksparse { + +// Possible numbers of Autoregressive inputs. +// TODO(b/188702959): Generalize to any non-negative integer value? +enum class ARInputsMode { + // There are no autoregressive inputs. Inputs to the GRU gates are strictly + // from the gate-recurrent matmul and other unrelated inputs. + k0ARInputs, + // Two autoregressive inputs, such as coarse and fine for WaveRNN. + k2ARInputs, + // Three autoregressive inputs, such as prev coarse and fine plus current + // coarse for WaveRNN. + k3ARInputs, +}; + +} // namespace csrblocksparse + +#endif // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_AR_INPUTS_H_ diff --git a/sparse_matmul/compute/gru_gates.h b/sparse_matmul/compute/gru_gates.h new file mode 100644 index 00000000..7b8cd489 --- /dev/null +++ b/sparse_matmul/compute/gru_gates.h @@ -0,0 +1,214 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_H_ +#define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_H_ + +#include +#include + +// IWYU pragma: begin_exports +#include "sparse_matmul/compute/ar_inputs.h" +#include "sparse_matmul/compute/gru_gates_arm.h" +#include "sparse_matmul/compute/gru_gates_avx_fixed.h" +#include "sparse_matmul/compute/gru_gates_generic.h" +#include "sparse_matmul/compute/matmul.h" +#include "sparse_matmul/numerics/fixed_types.h" +#include "sparse_matmul/numerics/type_utils.h" +#include "sparse_matmul/vector/cache_aligned_vector.h" +// IWYU pragma: end_exports + +namespace csrblocksparse { + +// The master template is really a catch-all for the unimplemented cases to +// run the generics. +template +class GruGates : public MatmulBase { + public: + using SampleWeightType = float; + static constexpr int kSIMDWidth = kGenericSIMDWidth; + + // Generic GRU function covers all uses for WaveRNN-like architectures and + // conditioning. + // Controlled by template parameters thus: + // - |kInputsMode| == |k0ARInputs|: There are no autoregressive inputs so + // |ar_sample0|, |ar_sample1|, |ar_sample2|, |ar_01_weights|, + // |ar_2_weights| are ignored. + // - |kInputsMode| == |k2ARInputs|: |ar_sample0|, |ar_sample1| are multiplied + // by |ar_01_weights| and added to the (conditioning) input. + // - |kInputsMode| == |k3ARInputs|: |ar_sample2| is multiplied by + // |ar_2_weights| and added to the other two |ar_inputs| (and added to the + // conditioning input). + // - If |kSplitGates| is true: The |*gru_recurrent_other_ptr| is secondary + // recurrent input that must be added to |*gru_recurrent_ptr|. + // - |num_replicas| determines the number of duplicates of the output to be + // written, separated by |replica_stride|. + // - |start|, |end| are |rows| in [0, |state_size|] to be processed by this + // thread. + // + // Previous state is read from |*gru_state_ptr| and the new state is written + // to *(|gru_state_ptr| + i * |replica_stride| for i in [0, |num_replicas|)). + template + void GruWithARInput(int start, int end, int state_size, + const InputType* gru_recurrent_ptr, + const InputType* input_ptr, GRUStateType* gru_state_ptr, + const SampleType* ar_sample0 = nullptr, + const SampleType* ar_sample1 = nullptr, + const SampleWeightType* ar_01_weights = nullptr, + int num_replicas = 1, int replica_stride = 0, + const SampleType* ar_sample2 = nullptr, + const SampleWeightType* ar_2_weights = nullptr, + const InputType* gru_recurrent_other_ptr = nullptr) { + CHECK_EQ(num_replicas, 1) << "Generic code should always have 1 replica"; + GoThroughGates( + start, end, ar_01_weights, gru_recurrent_ptr, gru_recurrent_other_ptr, + input_ptr, gru_state_ptr, ar_2_weights, state_size, ar_sample0, + ar_sample1, ar_sample2); + } + + // No AR inputs, no split gates, no batching, no replicated outputs. + // TODO(b/188702959): Redirect conditioning GRU here, removing code from + // gru_layer.h. + // Copy to specializations. + void PlainGru(int start, int end, int state_size, + const InputType* gru_recurrent_ptr, const InputType* input_ptr, + GRUStateType* gru_state_ptr) { + GruWithARInput( + start, end, state_size, gru_recurrent_ptr, input_ptr, gru_state_ptr); + } +}; + +#if defined __ARM_NEON || defined __aarch64__ +// Partial specialization for float. +template <> +class GruGates : public MatmulBase { + public: + static constexpr int kSIMDWidth = kNeonSIMDWidth; + + // Generic GRU function covers all uses for WaveRNN-like architectures and + // conditioning. + template + void GruWithARInput(int start, int end, int state_size, + const float* gru_recurrent_data, const float* input_data, + float* gru_state_data, const float* ar_sample0 = nullptr, + const float* ar_sample1 = nullptr, + const float* ar_01_weights = nullptr, + int num_replicas = 1, int replica_stride = 0, + const float* ar_sample2 = nullptr, + const float* ar_2_weights = nullptr, + const float* gru_recurrent_other_data = nullptr) { + DCHECK_EQ(num_replicas, 1) << "ARM code should always have 1 replica"; + GoThroughGatesFloat( + start, end, ar_01_weights, gru_recurrent_data, gru_recurrent_other_data, + input_data, gru_state_data, ar_2_weights, state_size, ar_sample0, + ar_sample1, ar_sample2); + } +}; +#endif // defined __ARM_NEON || defined __aarch64__ + +// Partial specialization for fixed types. The sample weights are always float +// whatever the fixed type of the other weights. +template +class GruGates, fixed32, + fixed16> : public MatmulBase { + public: +#if defined __ARM_NEON || defined __aarch64__ + static constexpr int kSIMDWidth = kNeonSIMDWidth; +#elif defined __AVX2__ + static constexpr int kSIMDWidth = kAVX2SIMDWidth * 2; +#else // Generic case. + static constexpr int kSIMDWidth = kGenericSIMDWidth; +#endif // __ARM_NEON || defined __aarch64__ / __AVX2__ + + using GRUStateType = fixed16; + using InputType = fixed32; + using SampleType = fixed16; + using SampleWeightType = float; + static constexpr int kInputMantissaBits = InputType::kMantissaBits; + static constexpr int kSampleMantissaBits = SampleType::kMantissaBits; + static constexpr int kStateMantissaBits = GRUStateType::kMantissaBits; + // Generic GRU function covers all uses for WaveRNN-like architectures and + // conditioning. + template + void GruWithARInput(int start, int end, int state_size, + const InputType* gru_recurrent_data, + const InputType* input_data, GRUStateType* gru_state_data, + const SampleType* ar_sample0 = nullptr, + const SampleType* ar_sample1 = nullptr, + const SampleWeightType* ar_01_weights = nullptr, + int num_replicas = 1, int replica_stride = 0, + const SampleType* ar_sample2 = nullptr, + const SampleWeightType* ar_2_weights = nullptr, + const InputType* gru_recurrent_other_data = nullptr) { +#if defined __ARM_NEON || defined __aarch64__ || defined __AVX2__ + const int32_t* gru_recurrent_ptr = + reinterpret_cast(gru_recurrent_data); + const int32_t* gru_recurrent_other_ptr = + reinterpret_cast(gru_recurrent_other_data); + const int32_t* input_ptr = reinterpret_cast(input_data); + int16_t* gru_state_ptr = reinterpret_cast(gru_state_data); +#if defined __AVX2__ + // The samples are fixed16, but we scale them up here and convert to float + // so that the product with the QR weights is always on the same scale as + // InputType, so we don't have to do any more scaling inside. + const float sample_factor = static_cast(1 << kInputMantissaBits); +#else + const float sample_factor = 1.0f; +#endif + // AR sample 0 and 1 are packed into a pair because the QR weights are + // formatted with the weights interleaved for sample 0 and 1. + std::pair ar_sample01; + float ar_sample2_float = 0.0f; + if (kInputsMode == ARInputsMode::k2ARInputs || + kInputsMode == ARInputsMode::k3ARInputs) { + ar_sample01 = {static_cast(*ar_sample0) * sample_factor, + static_cast(*ar_sample1) * sample_factor}; + if (kInputsMode == ARInputsMode::k3ARInputs) { + ar_sample2_float = static_cast(*ar_sample2) * sample_factor; + } + } +#if defined __AVX2__ + CHECK(using_avx2_) << "Compiled for AVX2, but cpu flag not set!"; + GruGatesAVXFixed( + start, end, state_size, gru_recurrent_ptr, input_ptr, &ar_sample01, + ar_01_weights, num_replicas, replica_stride, &ar_sample2_float, + ar_2_weights, gru_recurrent_other_ptr, gru_state_ptr); +#else // ARM. + DCHECK_EQ(num_replicas, 1) << "ARM code should always have 1 replica"; + GoThroughGatesFixed( + start, end, ar_01_weights, gru_recurrent_ptr, gru_recurrent_other_ptr, + input_ptr, gru_state_ptr, ar_2_weights, state_size, &ar_sample01, + &ar_sample2_float); +#endif // __AVX2__ / ARM. +#else // Generic case. + CHECK_EQ(num_replicas, 1) << "Generic code should always have 1 replica"; + GoThroughGates( + start, end, ar_01_weights, gru_recurrent_data, gru_recurrent_other_data, + input_data, gru_state_data, ar_2_weights, state_size, ar_sample0, + ar_sample1, ar_sample2); +#endif // __ARM_NEON || defined __aarch64__ / __AVX2__ + } +}; + +} // namespace csrblocksparse + +#endif // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_H_ diff --git a/sparse_matmul/compute/gru_gates_arm.h b/sparse_matmul/compute/gru_gates_arm.h new file mode 100644 index 00000000..d95805da --- /dev/null +++ b/sparse_matmul/compute/gru_gates_arm.h @@ -0,0 +1,288 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_ARM_H_ +#define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_ARM_H_ + +#if defined __ARM_NEON || defined __aarch64__ +#include +#endif +#include + +#include "sparse_matmul/compute/ar_inputs.h" +#include "sparse_matmul/numerics/fast_transcendentals.h" + +namespace csrblocksparse { + +static constexpr int kNeonSIMDWidth = 4; + +// ------ Scalar calculation -------- +// See "Efficient Neural Audio Synthesis" for a description of the calculation. +// https://arxiv.org/abs/1802.08435 +// +// NOTE: +// |sample| = (|coarse_at_sminus1|, |fine_at_sminus1|, +// |coarse_at_sminus1|, |fine_at_sminus1|) +// |w_sample| = (|coarse_at_s|, |coarse_at_s|, |coarse_at_s|, |coarse_at_s|) +// +// CHEATSHEET: +// vld1q_f32 = load 4 32-bit floats +// vmulq_f32(a, b) : return a * b; +// vaddq_f32(a, b) : return a + b; +// vmlaq_f32(c, a, b) : return c + a * b; +// vpaddq_f32(a, b) : return (a0 + a1, a2 + a3, b0 + b1, b2 + b3) +// vsubq_f32(a, b) : return a - b; +// vst1q_f32 = store 4 32-bit floats +#if defined __ARM_NEON || defined __aarch64__ + +#if !defined __aarch64__ +// Backport of vpaddq_f32 to ARM32. +inline float32x4_t vpaddq_f32(float32x4_t a, float32x4_t b) { + float32x2_t a10 = vget_low_f32(a); + float32x2_t a32 = vget_high_f32(a); + float32x2_t b10 = vget_low_f32(b); + float32x2_t b32 = vget_high_f32(b); + return vcombine_f32(vpadd_f32(a10, a32), vpadd_f32(b10, b32)); +} +#endif + +template +void GoThroughGatesFloat(int start, int end, const float* qr_ptr, + const float* gru_gates_ptr, + const float* gru_gates_other_ptr, + const float* conditioning_ptr, float* gru_h_ptr, + const float* w_hat, int proj_size, + const float* coarse_at_sminus1, + const float* fine_at_sminus1, + const float* coarse_at_s) { + // Increment all the pointers to save on pointer arithmetic in the loop. + conditioning_ptr += start; + gru_h_ptr += start; + gru_gates_ptr += start; + if (SplitGates) { + DCHECK_NE(gru_gates_other_ptr, nullptr); + gru_gates_other_ptr += start; + } + if (kInputsMode != ARInputsMode::k0ARInputs) { + DCHECK_NE(qr_ptr, nullptr); + qr_ptr += 2 * start; + DCHECK_NE(coarse_at_sminus1, nullptr); + DCHECK_NE(fine_at_sminus1, nullptr); + if (kInputsMode == ARInputsMode::k3ARInputs) { + DCHECK_NE(w_hat, nullptr); + DCHECK_NE(coarse_at_s, nullptr); + w_hat += start; + } + } + for (int i = start; i < end; i += kNeonSIMDWidth) { + float32x4_t reset = vld1q_f32(gru_gates_ptr); + float32x4_t update = vld1q_f32(gru_gates_ptr + proj_size); + float32x4_t cell = vld1q_f32(gru_gates_ptr + 2 * proj_size); + float32x4_t qr_cell; + if (SplitGates) { + reset = vaddq_f32(reset, vld1q_f32(gru_gates_other_ptr)); + update = vaddq_f32(update, vld1q_f32(gru_gates_other_ptr + proj_size)); + cell = vaddq_f32(cell, vld1q_f32(gru_gates_other_ptr + 2 * proj_size)); + } + if (kInputsMode != ARInputsMode::k0ARInputs) { + // Setup the sample vector. + float32x4_t sample = vdupq_n_f32(*coarse_at_sminus1); + sample = vsetq_lane_f32(*fine_at_sminus1, sample, 1); + sample = vsetq_lane_f32(*fine_at_sminus1, sample, 3); + + // All auto types are float32x4_t, auto used to fit statements on one line + // for readability. Do two rows of QR at once. + auto qr_reset_0 = vmulq_f32(vld1q_f32(qr_ptr), sample); + auto qr_reset_1 = vmulq_f32(vld1q_f32(qr_ptr + 4), sample); + auto qr_reset = vpaddq_f32(qr_reset_0, qr_reset_1); + + auto qr_update_0 = vmulq_f32(vld1q_f32(qr_ptr + 2 * proj_size), sample); + auto qr_update_1 = + vmulq_f32(vld1q_f32(qr_ptr + 4 + 2 * proj_size), sample); + auto qr_update = vpaddq_f32(qr_update_0, qr_update_1); + + auto qr_cell_0 = vmulq_f32(vld1q_f32(qr_ptr + 4 * proj_size), sample); + auto qr_cell_1 = vmulq_f32(vld1q_f32(qr_ptr + 4 + 4 * proj_size), sample); + qr_cell = vpaddq_f32(qr_cell_0, qr_cell_1); + + if (kInputsMode == ARInputsMode::k3ARInputs) { + float32x4_t w_sample = vdupq_n_f32(*coarse_at_s); + qr_reset = vmlaq_f32(qr_reset, vld1q_f32(w_hat), w_sample); + qr_update = + vmlaq_f32(qr_update, vld1q_f32(w_hat + proj_size), w_sample); + qr_cell = + vmlaq_f32(qr_cell, vld1q_f32(w_hat + 2 * proj_size), w_sample); + } + reset = vaddq_f32(reset, qr_reset); + update = vaddq_f32(update, qr_update); + } + auto reset_conditioning = vld1q_f32(conditioning_ptr); + auto update_conditioning = vld1q_f32(conditioning_ptr + proj_size); + auto cell_conditioning = vld1q_f32(conditioning_ptr + 2 * proj_size); + + reset = fast_sigmoid(vaddq_f32(reset, reset_conditioning)); + update = fast_sigmoid(vaddq_f32(update, update_conditioning)); + if (kInputsMode == ARInputsMode::k0ARInputs) { + cell = vmulq_f32(reset, cell); + } else { + cell = vmlaq_f32(qr_cell, reset, cell); + } + auto hbar = fast_tanh(vaddq_f32(cell, cell_conditioning)); + + auto prev_h = vld1q_f32(gru_h_ptr); + auto diff = vsubq_f32(prev_h, hbar); + auto new_h = vmlaq_f32(hbar, diff, update); + + vst1q_f32(gru_h_ptr, new_h); + // Increment all the pointers. + conditioning_ptr += kNeonSIMDWidth; + gru_h_ptr += kNeonSIMDWidth; + gru_gates_ptr += kNeonSIMDWidth; + if (SplitGates) gru_gates_other_ptr += kNeonSIMDWidth; + if (kInputsMode != ARInputsMode::k0ARInputs) { + qr_ptr += 2 * kNeonSIMDWidth; + if (kInputsMode == ARInputsMode::k3ARInputs) w_hat += kNeonSIMDWidth; + } + } +} + +// This version should only be used if all of the 32-bit fixed point +// representations have the same number of mantissa bits. +// |ar_at_sminus1| packs sample 0 and 1 into a pair because the QR weights are +// formatted with the weights interleaved for sample 0 and 1. The two samples +// represent coarse and fine for WaveRNN. +template +void GoThroughGatesFixed(int start, int end, const float* qr_ptr, + const int32_t* gru_gates_ptr, + const int32_t* gru_gates_other_ptr, + const int32_t* conditioning_ptr, int16_t* gru_h_ptr, + const float* w_hat, int proj_size, + const std::pair* ar_at_sminus1, + const float* coarse_at_s) { + // Increment all the pointers to save on pointer arithmetic in the loop. + conditioning_ptr += start; + gru_h_ptr += start; + gru_gates_ptr += start; + if (SplitGates) { + DCHECK_NE(gru_gates_other_ptr, nullptr); + gru_gates_other_ptr += start; + } + float32x4_t sample01; + float32x4_t w_sample; + if (kInputsMode != ARInputsMode::k0ARInputs) { + DCHECK_NE(qr_ptr, nullptr); + qr_ptr += 2 * start; + DCHECK_NE(ar_at_sminus1, nullptr); + sample01 = vdupq_n_f32(ar_at_sminus1->first); + sample01 = vsetq_lane_f32(ar_at_sminus1->second, sample01, 1); + sample01 = vsetq_lane_f32(ar_at_sminus1->second, sample01, 3); + if (kInputsMode == ARInputsMode::k3ARInputs) { + DCHECK_NE(w_hat, nullptr); + DCHECK_NE(coarse_at_s, nullptr); + w_hat += start; + w_sample = vdupq_n_f32(*coarse_at_s); + } + } + for (int i = start; i < end; i += kNeonSIMDWidth) { + auto reset = vld1q_s32(gru_gates_ptr); + auto update = vld1q_s32(gru_gates_ptr + proj_size); + // vcvtq_n_f32_s32 = convert 32-bit fixed point to fp32 + auto cell_int = vld1q_s32(gru_gates_ptr + 2 * proj_size); + if (SplitGates) { + reset = vaddq_s32(reset, vld1q_s32(gru_gates_other_ptr)); + update = vaddq_s32(update, vld1q_s32(gru_gates_other_ptr + proj_size)); + cell_int = + vaddq_s32(cell_int, vld1q_s32(gru_gates_other_ptr + 2 * proj_size)); + } + float32x4_t cell = + vcvtq_n_f32_s32(cell_int, GRUMatMulOutType::kMantissaBits); + float32x4_t qr_cell; + if (kInputsMode != ARInputsMode::k0ARInputs) { + // Do two rows of QR at once. + float32x4_t qr_reset_0 = vmulq_f32(vld1q_f32(qr_ptr), sample01); + float32x4_t qr_reset_1 = vmulq_f32(vld1q_f32(qr_ptr + 4), sample01); + float32x4_t qr_reset = vpaddq_f32(qr_reset_0, qr_reset_1); + + float32x4_t qr_update_0 = + vmulq_f32(vld1q_f32(qr_ptr + 2 * proj_size), sample01); + float32x4_t qr_update_1 = + vmulq_f32(vld1q_f32(qr_ptr + 4 + 2 * proj_size), sample01); + float32x4_t qr_update = vpaddq_f32(qr_update_0, qr_update_1); + + float32x4_t qr_cell_0 = + vmulq_f32(vld1q_f32(qr_ptr + 4 * proj_size), sample01); + float32x4_t qr_cell_1 = + vmulq_f32(vld1q_f32(qr_ptr + 4 + 4 * proj_size), sample01); + qr_cell = vpaddq_f32(qr_cell_0, qr_cell_1); + if (kInputsMode == ARInputsMode::k3ARInputs) { + float32x4_t w_sample = vdupq_n_f32(*coarse_at_s); + qr_reset = vmlaq_f32(qr_reset, vld1q_f32(w_hat), w_sample); + qr_update = + vmlaq_f32(qr_update, vld1q_f32(w_hat + proj_size), w_sample); + qr_cell = + vmlaq_f32(qr_cell, vld1q_f32(w_hat + 2 * proj_size), w_sample); + } + reset = vaddq_s32( + reset, vcvtq_n_s32_f32(qr_reset, GRUMatMulOutType::kMantissaBits)); + update = vaddq_s32( + update, vcvtq_n_s32_f32(qr_update, GRUMatMulOutType::kMantissaBits)); + } + + auto reset_conditioning = vld1q_s32(conditioning_ptr); + auto update_conditioning = vld1q_s32(conditioning_ptr + proj_size); + float32x4_t cell_conditioning = + vcvtq_n_f32_s32(vld1q_s32(conditioning_ptr + 2 * proj_size), + GRUMatMulOutType::kMantissaBits); + + float32x4_t reset_f32 = fast_sigmoid( + vaddq_s32(reset, reset_conditioning)); + float32x4_t update_f32 = fast_sigmoid( + vaddq_s32(update, update_conditioning)); + if (kInputsMode == ARInputsMode::k0ARInputs) { + cell = vmulq_f32(reset_f32, cell); + } else { + cell = vmlaq_f32(qr_cell, reset_f32, cell); + } + float32x4_t hbar = fast_tanh(vaddq_f32(cell, cell_conditioning)); + + float32x4_t prev_h = vcvtq_n_f32_s32(vmovl_s16(vld1_s16(gru_h_ptr)), + GRUStateType::kMantissaBits); + float32x4_t diff = vsubq_f32(prev_h, hbar); + float32x4_t new_h = vmlaq_f32(hbar, diff, update_f32); + + // vcvtq_n_s32_f32 = convert fp32 to signed 32-bit fixed point + // vqrshrn_n_s32 = saturating, rounding, narrowing right shift - used to + // convert a 32-bit fixed point value to a 16-bit fixed point value + vst1_s16(gru_h_ptr, + vqrshrn_n_s32( + vcvtq_n_s32_f32(new_h, GRUStateType::kMantissaBits + 16), 16)); + // Increment all the pointers. + conditioning_ptr += kNeonSIMDWidth; + gru_h_ptr += kNeonSIMDWidth; + gru_gates_ptr += kNeonSIMDWidth; + if (SplitGates) gru_gates_other_ptr += kNeonSIMDWidth; + if (kInputsMode != ARInputsMode::k0ARInputs) { + qr_ptr += 2 * kNeonSIMDWidth; + if (kInputsMode == ARInputsMode::k3ARInputs) w_hat += kNeonSIMDWidth; + } + } +} +#endif // defined __ARM_NEON || defined __aarch64__ + +} // namespace csrblocksparse + +#endif // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_ARM_H_ diff --git a/sparse_matmul/compute/gru_gates_avx_fixed.h b/sparse_matmul/compute/gru_gates_avx_fixed.h new file mode 100644 index 00000000..0703020c --- /dev/null +++ b/sparse_matmul/compute/gru_gates_avx_fixed.h @@ -0,0 +1,348 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_AVX_FIXED_H_ +#define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_AVX_FIXED_H_ + +#include +#if defined __AVX2__ +#include +#endif +#include + +#include "sparse_matmul/compute/ar_inputs.h" +#include "sparse_matmul/numerics/fast_transcendentals.h" + +namespace csrblocksparse { + +#if defined __AVX2__ + +constexpr int kAVX2SIMDWidth = 8; + +// Loads 8x fixed32 from |ptr0| and adds to |input|. +// If |kTwoInputs|, also loads from |ptr1| and adds that as well. +// Returns the 2 or 3-way sum. +template +inline __m256i LoadAndAddFixed32(const int32_t* ptr0, const int32_t* ptr1, + const __m256i& input) { + __m256i data0 = _mm256_load_si256(reinterpret_cast(ptr0)); + if (kTwoInputs) { + __m256i data1 = _mm256_load_si256(reinterpret_cast(ptr1)); + data0 = _mm256_add_epi32(data0, data1); + } + return _mm256_add_epi32(data0, input); +} + +// Loads 8x fixed32 from ptr0. +// If |kTwoInputs|, also loads from |ptr1| and adds. +// Multiplies the loaded values by the factor and adds to |input|, which also +// is converted to float. +// Returns the sum. +template +inline __m256 LoadMultiplyAddToFloat(const int32_t* ptr0, const int32_t* ptr1, + const __m256& float_factor, + const __m256& input) { + __m256i data0 = _mm256_load_si256(reinterpret_cast(ptr0)); + if (kTwoInputs) { + __m256i data1 = _mm256_load_si256(reinterpret_cast(ptr1)); + data0 = _mm256_add_epi32(data0, data1); + } + __m256 float_result = _mm256_cvtepi32_ps(data0); + float_result = _mm256_mul_ps(float_result, float_factor); + return _mm256_add_ps(float_result, input); +} + +// Loads 16x float in 2x 8x registers from |ptr0_1| and multiplies by +// |input_pairs|, likewise formatted as 8x floats, alternating between the two +// AR inputs and sums each pair of results, making 8x float results. +// If |kThreeInputs|, also loads 8x float from |ptr2| and multiplies by +// |third_input|, which must be formatted as 8x float. The second product is +// added to the previous result. +// Returns the sum added to |accumulator|. +template +inline __m256 MultiplyAddFloat(const __m256& input_pairs, + const __m256& third_input, const float* ptr0_1, + const float* ptr2, const __m256& accumulator) { + __m256 data_pair0 = _mm256_load_ps(ptr0_1); + __m256 data_pair1 = _mm256_load_ps(ptr0_1 + 8); + data_pair0 = _mm256_mul_ps(data_pair0, input_pairs); + data_pair1 = _mm256_mul_ps(data_pair1, input_pairs); + data_pair0 = _mm256_hadd_ps(data_pair0, data_pair1); + // Swap the middle 2 64 bit pairs to correct the hadd result. + data_pair0 = _mm256_permute4x64_pd(data_pair0, 0xd8); + if (kThreeInputs) { + // Load 256 bits (8 x float) of data, then multiply-accumulate. + data_pair1 = _mm256_load_ps(ptr2); + data_pair1 = _mm256_mul_ps(data_pair1, third_input); + data_pair0 = _mm256_add_ps(data_pair0, data_pair1); + } + // Add conditioning. + return _mm256_add_ps(data_pair0, accumulator); +} + +// Processes the tanh and the final combination, returns the new GRU state. +template +inline __m256i GRUComputeState(const __m256& cell0, const __m256& cell1, + const __m256& reset0, const __m256& reset1, + const __m256& update0, const __m256& update1, + const int32_t* gate_ptr, + const int32_t* gate_other_ptr, + const void* gru_h_ptr) { + // Multiply the cell gru output and the reset. + __m256 float_gru0 = LoadMultiplyAddToFloat( + gate_ptr, gate_other_ptr, reset0, cell0); + __m256 float_gru1 = LoadMultiplyAddToFloat( + gate_ptr + kAVX2SIMDWidth, gate_other_ptr + kAVX2SIMDWidth, reset1, + cell1); + // Compute tanh on the result. + __m256 hbar0, hbar1; + float_tanh_float(float_gru0, float_gru1, + hbar0, hbar1); + // Load the 16-bit previous gru state and update. + __m256i gru = _mm256_load_si256(reinterpret_cast<__m256i const*>(gru_h_ptr)); + __m256 state_factor = + _mm256_set1_ps(1.0f / (static_cast(1 << kStateMantissaBits))); + float_gru0 = + _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(gru))); + float_gru1 = _mm256_cvtepi32_ps( + _mm256_cvtepi16_epi32(_mm256_extractf128_si256(gru, 1))); + float_gru0 = _mm256_mul_ps(float_gru0, state_factor); + float_gru1 = _mm256_mul_ps(float_gru1, state_factor); + float_gru0 = _mm256_sub_ps(float_gru0, hbar0); + float_gru1 = _mm256_sub_ps(float_gru1, hbar1); + float_gru0 = _mm256_mul_ps(float_gru0, update0); + float_gru1 = _mm256_mul_ps(float_gru1, update1); + state_factor = _mm256_set1_ps(static_cast(1 << kStateMantissaBits)); + float_gru0 = _mm256_add_ps(float_gru0, hbar0); + float_gru1 = _mm256_add_ps(float_gru1, hbar1); + float_gru0 = _mm256_mul_ps(float_gru0, state_factor); + float_gru1 = _mm256_mul_ps(float_gru1, state_factor); + return PackFloatsToFixed16(float_gru0, float_gru1); +} + +// According to |kInputsMode|, processes 0, 2 or 3 autoregressive inputs and +// combines with |input| and |gates*|. +// With 2 AR inputs, loads 8x pairs of float from |pair_weights| and multiplies +// by |paired_ar|, likewise formatted as 8x float, but scaled such that the +// product with pair_weights is on the same scale as |*input| and |*gates0|, +// and sums each pair result, making 8x float results. +// If 3 AR inputs, also loads 8x float from |third_weights| and multiplies by +// |third_ar|, which must be formatted as 8x scaled floats. The second product +// is added to the previous result. +// Inputs, 8x fixed32 are loaded from |input|, and added to the total. +// Finally 8x fixed32 from |gates0| (and |gates1| if |kTwoGates|) are added as +// well. +// Returns the total sum as a float, but on the scale of |*input|. +template +inline __m256i GruInput32ToFloat(const __m256& paired_ar, + const __m256& third_ar, + const float* pair_weights, + const float* third_weights, + const int32_t* gates0, const int32_t* gates1, + const int32_t* input) { + __m256i data32 = _mm256_load_si256(reinterpret_cast<__m256i const*>(input)); + data32 = LoadAndAddFixed32(gates0, gates1, data32); + __m256 float_data = _mm256_cvtepi32_ps(data32); + if (kInputsMode != ARInputsMode::k0ARInputs) { + float_data = MultiplyAddFloat( + paired_ar, third_ar, pair_weights, third_weights, float_data); + } + return float_data; +} + +// Generic GRU gates function controlled by template parameters thus: +// - |kInputBits|: the mantissa bits in |*input_ptr|, |*gru_recurrent_ptr|. +// - |kStateBits|: the mantissa_bits in |*gru_state_ptr|. +// - |kInputsMode == |k0ARInputs|: There are no autoregressive inputs so +// |ar_sample, |ar_sample1|, |ar_sample2|, |ar_01_weights|, |ar_2_weights| are +// ignored. +// - |kInputsMode| == |k2ARInputs|: |ar_sample0|, |ar_sample1| are multiplied by +// |ar_01_weights| and added to the (conditioning) input. +// - |kInputsMode| == |k3ARInputs|: |ar_sample2| is multiplied by |ar_2_weights| +// and added to the other two AR inputs (and added to the conditioning input). +// - |kReplicas| determines the number of duplicates of the output to be +// written, separated by |replica_stride|. If zero, then the number of +// replicas is variable and taken from the |replicas| argument. +// - If |kSplitGates| is true: The |*gru_recurrent_other_ptr| is secondary +// recurrent input that must be added to |*gru_recurrent_ptr|. +// - |start|, |end| are |rows| in [0, |state_size|] to be processed by this +// thread. +// +// Previous state is read from |*gru_state_ptr| and the new state is written to +// *(|gru_state_ptr| + i * |replica_stride| for i in [0, |kReplicas|]). +template +inline void GruGatesTemplate( + int start, int end, int state_size, int replicas, int replica_stride, + const int32_t* gru_recurrent_ptr, const int32_t* input_ptr, + const std::pair* ar_sample01, const float* ar_01_weights, + const float* ar_sample2, const float* ar_2_weights, + const int32_t* gru_recurrent_other_ptr, int16_t* gru_state_ptr) { + constexpr int kQRIncrement = kAVX2SIMDWidth; + // Increment all the pointers to save on pointer arithmetic in the loop. + input_ptr += start; + gru_state_ptr += start; + gru_recurrent_ptr += start; + if (kSplitGates) gru_recurrent_other_ptr += start; + __m256 ar_2_inputs, ar_3rd_input; + if (kInputsMode != ARInputsMode::k0ARInputs) { + ar_01_weights += 2 * start; + ar_2_inputs = _mm256_castsi256_ps( + _mm256_set1_epi64x(*reinterpret_cast(ar_sample01))); + if (kInputsMode == ARInputsMode::k3ARInputs) { + ar_2_weights += start; + ar_3rd_input = _mm256_set1_ps(*ar_sample2); + } else { + ar_3rd_input = {}; + } + } else { + ar_2_inputs = {}; + ar_3rd_input = {}; + } + // The transcendentals handle 2x registers of data at once, so we have to do + // everything in duplicate. + for (int i = start; i < end; i += kQRIncrement * 2) { + // Load 8 pairs of fixed16s for each of reset, update and cell. + __m256 reset0 = GruInput32ToFloat( + ar_2_inputs, ar_3rd_input, ar_01_weights, ar_2_weights, + gru_recurrent_ptr, gru_recurrent_other_ptr, input_ptr); + __m256 reset1 = GruInput32ToFloat( + ar_2_inputs, ar_3rd_input, ar_01_weights + 2 * kQRIncrement, + ar_2_weights + kQRIncrement, gru_recurrent_ptr + kAVX2SIMDWidth, + gru_recurrent_other_ptr + kAVX2SIMDWidth, input_ptr + kAVX2SIMDWidth); + float_sigmoid_float(reset0, reset1); + __m256 update0 = GruInput32ToFloat( + ar_2_inputs, ar_3rd_input, ar_01_weights + 2 * state_size, + ar_2_weights + state_size, gru_recurrent_ptr + state_size, + gru_recurrent_other_ptr + state_size, input_ptr + state_size); + __m256 update1 = GruInput32ToFloat( + ar_2_inputs, ar_3rd_input, + ar_01_weights + 2 * state_size + 2 * kQRIncrement, + ar_2_weights + state_size + kQRIncrement, + gru_recurrent_ptr + state_size + kAVX2SIMDWidth, + gru_recurrent_other_ptr + state_size + kAVX2SIMDWidth, + input_ptr + state_size + kAVX2SIMDWidth); + float_sigmoid_float(update0, update1); + __m256 cell0 = _mm256_cvtepi32_ps(_mm256_load_si256( + reinterpret_cast<__m256i const*>(input_ptr + 2 * state_size))); + __m256 cell1 = + _mm256_cvtepi32_ps(_mm256_load_si256(reinterpret_cast<__m256i const*>( + input_ptr + 2 * state_size + kAVX2SIMDWidth))); + if (kInputsMode != ARInputsMode::k0ARInputs) { + cell0 = MultiplyAddFloat( + ar_2_inputs, ar_3rd_input, ar_01_weights + 4 * state_size, + ar_2_weights + 2 * state_size, cell0); + cell1 = MultiplyAddFloat( + ar_2_inputs, ar_3rd_input, + ar_01_weights + 4 * state_size + 2 * kQRIncrement, + ar_2_weights + 2 * state_size + kQRIncrement, cell1); + } + __m256i gru_state = GRUComputeState( + cell0, cell1, reset0, reset1, update0, update1, + gru_recurrent_ptr + 2 * state_size, + gru_recurrent_other_ptr + 2 * state_size, gru_state_ptr); + if (kReplicas > 0) { + // With |kReplicas| a template parameter, the compiler will unroll the + // loop. + for (int j = 0; j < kReplicas; ++j) { + _mm256_store_si256( + reinterpret_cast<__m256i*>(gru_state_ptr + j * replica_stride), + gru_state); + } + } else { + // This loop will not unroll as replicas is variable. + for (int j = 0; j < replicas; ++j) { + _mm256_store_si256( + reinterpret_cast<__m256i*>(gru_state_ptr + j * replica_stride), + gru_state); + } + } + // Increment all the pointers. + input_ptr += 2 * kAVX2SIMDWidth; + gru_state_ptr += 2 * kAVX2SIMDWidth; + gru_recurrent_ptr += 2 * kAVX2SIMDWidth; + if (kSplitGates) gru_recurrent_other_ptr += 2 * kAVX2SIMDWidth; + if (kInputsMode != ARInputsMode::k0ARInputs) { + ar_01_weights += 4 * kQRIncrement; + if (kInputsMode == ARInputsMode::k3ARInputs) + ar_2_weights += 2 * kQRIncrement; + } + } +} + +// Dispatches calls to the GruGatesTemplate function above converting the +// replicas variable argument to a template parameter to allow the compiler to +// unroll the write loop. +// |ar_sample01| packs sample 0 and 1 into a pair because the QR weights are +// formatted with the weights interleaved for sample 0 and 1. The two samples +// represent coarse and fine for WaveRNN. +template +inline void GruGatesAVXFixed( + int start, int end, int state_size, const int32_t* gru_recurrent_ptr, + const int32_t* input_ptr, const std::pair* ar_sample01, + const float* ar_01_weights, int num_replicas, int replica_stride, + const float* ar_sample2, const float* ar_2_weights, + const int32_t* gru_recurrent_other_ptr, int16_t* gru_state_ptr) { + // Convert the number of replicas from a variable to a template parameter + // with a switch. This enables the compiler to unroll the loop for + // the write, making it faster for common numbers of threads. + switch (num_replicas) { + case 1: + GruGatesTemplate( + start, end, state_size, num_replicas, replica_stride, + gru_recurrent_ptr, input_ptr, ar_sample01, ar_01_weights, ar_sample2, + ar_2_weights, gru_recurrent_other_ptr, gru_state_ptr); + break; + case 2: + GruGatesTemplate( + start, end, state_size, num_replicas, replica_stride, + gru_recurrent_ptr, input_ptr, ar_sample01, ar_01_weights, ar_sample2, + ar_2_weights, gru_recurrent_other_ptr, gru_state_ptr); + break; + case 4: + GruGatesTemplate( + start, end, state_size, num_replicas, replica_stride, + gru_recurrent_ptr, input_ptr, ar_sample01, ar_01_weights, ar_sample2, + ar_2_weights, gru_recurrent_other_ptr, gru_state_ptr); + break; + case 6: + GruGatesTemplate( + start, end, state_size, num_replicas, replica_stride, + gru_recurrent_ptr, input_ptr, ar_sample01, ar_01_weights, ar_sample2, + ar_2_weights, gru_recurrent_other_ptr, gru_state_ptr); + break; + default: + // Zero |kReplicas| tells the function to use the |num_replicas| variable. + GruGatesTemplate( + start, end, state_size, num_replicas, replica_stride, + gru_recurrent_ptr, input_ptr, ar_sample01, ar_01_weights, ar_sample2, + ar_2_weights, gru_recurrent_other_ptr, gru_state_ptr); + } +} + +#endif // __AVX2__ + +} // namespace csrblocksparse + +#endif // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_AVX_FIXED_H_ diff --git a/sparse_matmul/compute/gru_gates_generic.h b/sparse_matmul/compute/gru_gates_generic.h new file mode 100644 index 00000000..691efb1f --- /dev/null +++ b/sparse_matmul/compute/gru_gates_generic.h @@ -0,0 +1,97 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_GENERIC_H_ +#define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_GENERIC_H_ + +#include "sparse_matmul/compute/ar_inputs.h" +#include "sparse_matmul/numerics/fast_transcendentals.h" + +namespace csrblocksparse { + +constexpr int kGenericSIMDWidth = 4; + +// TODO(b/188702959): Rename arguments to match gru_gates.h. +template +void GoThroughGates(int start, int end, const QR_W_Type* qr_ptr, + const GRUMatMulOutType* gru_gates_ptr, + const GRUMatMulOutType* gru_gates_other_ptr, + const GRUMatMulOutType* conditioning_ptr, + GRUStateType* gru_h_ptr, const QR_W_Type* w_hat, + int proj_size, const SampleType* coarse_at_sminus1, + const SampleType* fine_at_sminus1, + const SampleType* coarse_at_s = nullptr) { + float qr_cell = 0.0f, reset, update, cell; + for (int i = start; i < end; ++i) { + if (kInputsMode == ARInputsMode::k0ARInputs) { + reset = static_cast(gru_gates_ptr[i]); + update = static_cast(gru_gates_ptr[proj_size + i]); + } else { + float qr_c_reset = static_cast(qr_ptr[2 * i + 0]); + float qr_f_reset = static_cast(qr_ptr[2 * i + 1]); + float qr_c_update = static_cast(qr_ptr[2 * proj_size + 2 * i + 0]); + float qr_f_update = static_cast(qr_ptr[2 * proj_size + 2 * i + 1]); + float qr_c_cell = static_cast(qr_ptr[4 * proj_size + 2 * i + 0]); + float qr_f_cell = static_cast(qr_ptr[4 * proj_size + 2 * i + 1]); + float w_hat_i_reset = 0.0f; + float w_hat_i_update = 0.0f; + float w_hat_i_cell = 0.0f; + if (kInputsMode == ARInputsMode::k3ARInputs) { + w_hat_i_reset = static_cast(w_hat[i]); + w_hat_i_update = static_cast(w_hat[proj_size + i]); + w_hat_i_cell = static_cast(w_hat[2 * proj_size + i]); + } + float coarse = static_cast(coarse_at_sminus1[0]); + float fine = static_cast(fine_at_sminus1[0]); + reset = qr_c_reset * coarse + qr_f_reset * fine; + update = qr_c_update * coarse + qr_f_update * fine; + qr_cell = qr_c_cell * coarse + qr_f_cell * fine; + if (kInputsMode == ARInputsMode::k3ARInputs) { + float coarse = static_cast(coarse_at_s[0]); + reset += w_hat_i_reset * coarse; + update += w_hat_i_update * coarse; + qr_cell += w_hat_i_cell * coarse; + } + reset += static_cast(gru_gates_ptr[i]); + update += static_cast(gru_gates_ptr[proj_size + i]); + } + cell = static_cast(gru_gates_ptr[2 * proj_size + i]); + if (SplitGates) { + reset += static_cast(gru_gates_other_ptr[i]); + update += static_cast(gru_gates_other_ptr[proj_size + i]); + cell += static_cast(gru_gates_other_ptr[2 * proj_size + i]); + } + float reset_conditioning = static_cast(conditioning_ptr[i]); + float update_conditioning = + static_cast(conditioning_ptr[proj_size + i]); + float cell_conditioning = + static_cast(conditioning_ptr[2 * proj_size + i]); + reset = fast_sigmoid(reset + reset_conditioning); + update = fast_sigmoid(update + update_conditioning); + float hbar = fast_tanh(qr_cell + reset * cell + cell_conditioning); + int h_index = i; + float prev_h = static_cast(gru_h_ptr[h_index]); + float diff = prev_h - hbar; + float new_h = hbar + diff * update; + gru_h_ptr[h_index] = static_cast(new_h); + } +} + +} // namespace csrblocksparse + +#endif // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_GENERIC_H_ diff --git a/sparse_matmul/compute/gru_gates_test.cc b/sparse_matmul/compute/gru_gates_test.cc new file mode 100644 index 00000000..4f626c98 --- /dev/null +++ b/sparse_matmul/compute/gru_gates_test.cc @@ -0,0 +1,164 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "sparse_matmul/compute/gru_gates.h" + +#include +#include +#include + +#include "absl/memory/memory.h" +#include "absl/types/span.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +namespace { + +using csrblocksparse::ARInputsMode; + +template +csrblocksparse::CacheAlignedVector TestGruGates() { + using SampleWeightType = float; + constexpr int kStateSize = 16; + csrblocksparse::CacheAlignedVector qr(6 * kStateSize); + csrblocksparse::CacheAlignedVector w(3 * kStateSize); + csrblocksparse::CacheAlignedVector gru_gates(3 * kStateSize); + csrblocksparse::CacheAlignedVector gru_other_gates(3 * kStateSize); + csrblocksparse::CacheAlignedVector conditioning(3 * kStateSize); + csrblocksparse::CacheAlignedVector gru_h(kStateSize); + csrblocksparse::GruGates gru_gates_impl; + const SampleType kCoarseAtSMinus1(0.03f); + const SampleType kFineAtSMinus1(0.07f); + const SampleType kCoarseAtS(-0.02f); + + qr.FillOnes(); + w.FillOnes(); + gru_gates.FillRandom(); + gru_other_gates.FillRandom(); + conditioning.FillRandom(); + gru_h.FillZero(); + + gru_gates_impl.template GruWithARInput( + /*start=*/0, /*end=*/kStateSize, kStateSize, gru_gates.data(), + conditioning.data(), gru_h.data(), &kCoarseAtSMinus1, &kFineAtSMinus1, + qr.data(), + /*num_replicas=*/1, /*replica_stride=*/0, &kCoarseAtS, w.data(), + gru_other_gates.data()); + return gru_h; +} + +TEST(GruGates, FloatWaveRNNCoarseMatchesGolden) { + // If the RNG in csrblocksparse::CacheAlignedVector changes, these numbers + // will also need to change. + const std::vector kGoldenValues = { + 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.746f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.970f, 0.0f, 0.0f, 1.0f, 0.0f, -0.993f}; + csrblocksparse::CacheAlignedVector gru_h = + TestGruGates(); + + ASSERT_EQ(kGoldenValues.size(), gru_h.size()); + for (int i = 0; i < gru_h.size(); ++i) { + EXPECT_NEAR(kGoldenValues[i], gru_h[i], 1e-3) << "i=" << i; + } +} + +TEST(GruGates, FloatWaveRNNFineMatchesGolden) { + // If the RNG in csrblocksparse::CacheAlignedVector changes, these numbers + // will also need to change. + const std::vector kGoldenValues = { + 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.737f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.969f, 0.0f, 0.0f, 1.0f, 0.0f, -0.994f}; + csrblocksparse::CacheAlignedVector gru_h = + TestGruGates(); + + ASSERT_EQ(kGoldenValues.size(), gru_h.size()); + for (int i = 0; i < gru_h.size(); ++i) { + EXPECT_NEAR(kGoldenValues[i], gru_h[i], 1e-3) << "i=" << i; + } +} + +TEST(GruGates, FloatTwoArInputsNonSplitGateMatchesGolden) { + // If the RNG in csrblocksparse::CacheAlignedVector changes, these numbers + // will also need to change. + const std::vector kGoldenValues = { + 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.714f, 0.0f, -0.002f, + 0.0f, 0.0f, 0.970f, 0.0f, 0.0f, 1.0f, 0.0f, -0.965f}; + csrblocksparse::CacheAlignedVector gru_h = + TestGruGates(); + + ASSERT_EQ(kGoldenValues.size(), gru_h.size()); + for (int i = 0; i < gru_h.size(); ++i) { + EXPECT_NEAR(kGoldenValues[i], gru_h[i], 1e-3) << "i=" << i; + } +} + +TEST(GruGates, FixedWaveRNNCoarseMatchesFloat) { + using GRUMatMulOutType = csrblocksparse::fixed32<11>; + using GRUStateType = csrblocksparse::fixed16<2>; + using SampleType = csrblocksparse::fixed16<0>; + csrblocksparse::CacheAlignedVector float_gru_h = + TestGruGates(); + csrblocksparse::CacheAlignedVector fixed_gru_h = + TestGruGates(); + + ASSERT_EQ(float_gru_h.size(), fixed_gru_h.size()); + for (int i = 0; i < fixed_gru_h.size(); ++i) { + EXPECT_NEAR(float_gru_h[i], static_cast(fixed_gru_h[i]), 1e-3) + << "i=" << i; + } +} + +TEST(GruGates, FixedWaveRNNFineMatchesFloat) { + using GRUMatMulOutType = csrblocksparse::fixed32<11>; + using GRUStateType = csrblocksparse::fixed16<2>; + using SampleType = csrblocksparse::fixed16<0>; + csrblocksparse::CacheAlignedVector float_gru_h = + TestGruGates(); + csrblocksparse::CacheAlignedVector fixed_gru_h = + TestGruGates(); + + ASSERT_EQ(float_gru_h.size(), fixed_gru_h.size()); + for (int i = 0; i < fixed_gru_h.size(); ++i) { + EXPECT_NEAR(float_gru_h[i], static_cast(fixed_gru_h[i]), 1e-3) + << "i=" << i; + } +} + +TEST(GruGates, FixedTwoArInputsNonSplitGateMatchesFloat) { + using GRUMatMulOutType = csrblocksparse::fixed32<11>; + using GRUStateType = csrblocksparse::fixed16<2>; + using SampleType = csrblocksparse::fixed16<0>; + csrblocksparse::CacheAlignedVector float_gru_h = + TestGruGates(); + csrblocksparse::CacheAlignedVector fixed_gru_h = + TestGruGates(); + + ASSERT_EQ(float_gru_h.size(), fixed_gru_h.size()); + for (int i = 0; i < fixed_gru_h.size(); ++i) { + EXPECT_NEAR(float_gru_h[i], static_cast(fixed_gru_h[i]), 1e-3) + << "i=" << i; + } +} + +} // namespace diff --git a/sparse_matmul/compute/kernels_arm.h b/sparse_matmul/compute/kernels_arm.h new file mode 100644 index 00000000..494430fe --- /dev/null +++ b/sparse_matmul/compute/kernels_arm.h @@ -0,0 +1,2886 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_KERNELS_ARM_H_ +#define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_KERNELS_ARM_H_ + +#if defined __aarch64__ + +#include + +#include + +#include "sparse_matmul/numerics/fixed_types.h" +#include "sparse_matmul/numerics/float16_types.h" +#include "sparse_matmul/numerics/type_utils.h" + +#define LABEL_COL_LOOP "1" +#define LABEL_ROW_LOOP "2" +#define LABEL_SKIP_COL_LOOP "3" +#define LABEL_TOP_LOOP "4" + +namespace csrblocksparse { +namespace detail { + +template +struct IsFloatOrBfloat + : std::integral_constant::value || + std::is_same::value> {}; + +template +struct IsAllowableFloatTypes + : std::integral_constant::value && + std::is_same::value && + std::is_same::value> {}; + +// 16-bit inputs, 32-bit output exponent matches sum of input exponents +// OR +// 16-bit inputs, 16-bit output - will shift to match exponent +template +struct IsAllowableFixedTypes + : std::integral_constant::value && + IsFixed16Type::value) && + (IsFixed32Type::value || + IsFixed16Type::value)> {}; + +template +struct ShouldEnableGenericKernel + : std::integral_constant< + bool, + !IsAllowableFloatTypes::value && + !IsAllowableFixedTypes::value> {}; + +template +struct ShouldEnableGenericSpMV_4x4 + : ShouldEnableGenericKernel {}; +template +struct ShouldEnableGenericSpMM5_4x4 + : ShouldEnableGenericKernel {}; +template +struct ShouldEnableGenericSpMV_1x1 : std::true_type {}; +template +struct ShouldEnableGenericSpMM5_1x1 : std::true_type {}; +template +struct IsAddableFixedTypes + : std::integral_constant::value || + IsFixed16Type::value> {}; +template +struct ShouldEnableGenericAdd + : std::integral_constant::value> {}; + +// The computational routines do NO error checking for speed. It is assumed +// that this has been handled by CSRBlockSparseMatrix. + +// Performs the calculation y = A * x + b where A is a sparse matrix with a 4x4 +// blocked pattern, x is a vector and b is vector. Weights are stored for this +// routine by making each 4x4 block contiguous. Blocks are ordered in standard +// row-major format. column indices are converted to deltas and then multiplied +// by 2 to convert to bytes, so that the value can be used directly to offset +// the pointer into the rhs vector. +// +// NOTE: The bias is expected to have be multiplied by .25f prior to calling +// this function. This is automatically taken care of in SparseLinearLayer. +// The bias is reconstructed through horizontal additions, leads to a small +// speedup by reducing latencies at the end of the loop. +template +typename std::enable_if::value && + std::is_same::value && + std::is_same::value>::type +SpMV_4x4(const bfloat16* weights_ptr, const int16_t* col_deltas_bytes, + const int32_t* nnz_per_row, const float* rhs_ptr, + const float* bias_ptr, float* out_ptr, int64_t assigned_rows, + int64_t rows /* only used in SpMM variants */, + int64_t cols /* only used in SpMM variants */, int relu) { + /* This instrinsic version exists for reference, note that in the + intrinsic version col_deltas_bytes should NOT actually be in bytes, + but rather elements. Intrinsics are 25-35% slower than the + assembly version. + + for (int r = 0; r < rows; r += 4) { + int reduced_col_count = nnz_per_row[r / 4]; + float32x4_t accum0 = vdupq_n_f32(bias_ptr + r); + float32x4_t accum1 = vdupq_n_f32(bias_ptr + r + 1); + float32x4_t accum2 = vdupq_n_f32(bias_ptr + r + 2); + float32x4_t accum3 = vdupq_n_f32(bias_ptr + r + 3); + for (int c = 0; c < reduced_col_count; ++c) { + int32_t offset = *col_deltas_bytes; col_deltas_bytes++; + rhs_ptr += offset; + float32x4_t rhs = vld1q_f32(rhs_ptr); + + uint16x4_t lhs0_int = vld1_u16(weights_ptr); weights_ptr += 4; + uint16x4_t lhs1_int = vld1_u16(weights_ptr); weights_ptr += 4; + uint16x4_t lhs2_int = vld1_u16(weights_ptr); weights_ptr += 4; + uint16x4_t lhs3_int = vld1_u16(weights_ptr); weights_ptr += 4; + + float32x4_t lhs0 = vreinterpretq_f32_u32(vshll_n_u16(lhs0_int, 16)); + float32x4_t lhs1 = vreinterpretq_f32_u32(vshll_n_u16(lhs1_int, 16)); + float32x4_t lhs2 = vreinterpretq_f32_u32(vshll_n_u16(lhs2_int, 16)); + float32x4_t lhs3 = vreinterpretq_f32_u32(vshll_n_u16(lhs3_int, 16)); + + accum0 = vmlaq_f32(accum0, lhs0, rhs); + accum1 = vmlaq_f32(accum1, lhs1, rhs); + accum2 = vmlaq_f32(accum2, lhs2, rhs); + accum3 = vmlaq_f32(accum3, lhs3, rhs); + } + + float32x4_t reduce0 = vpaddq_f32(accum0, accum1); + float32x4_t reduce1 = vpaddq_f32(accum2, accum3); + float32x4_t reduce2 = vpaddq_f32(reduce0, reduce1); + vst1q_f32(out_ptr + r, reduce2); + } */ + + // If the relu is handled in the routine with a comparison and vbit (insert + // if true), or by branching, then it is slightly, but noticeably slower + // ~5%, the outer branch avoids that penalty. + if (relu) { + asm( + // Load the first two column deltas. + "ldrsh x7, [%[col_deltas_bytes]], #2\n" + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + // ld1 doesn't support pre-index, so we do the first addition here. + "add %[rhs_ptr], %[rhs_ptr], x7\n" + + "movi v25.4s, #0\n" + + LABEL_ROW_LOOP + ":\n" + + // Load the bias. + "ld1 {v27.4s}, [%[bias_ptr]], #16\n" + + // Zero out local accumulators. + "dup v28.4s, v27.s[0]\n" // accum_0 = 0 + "dup v29.4s, v27.s[1]\n" // accum_1 = 0 + "dup v30.4s, v27.s[2]\n" // accum_2 = 0 + "dup v31.4s, v27.s[3]\n" // accum_3 = 0 + + // Update the stopping condition for this set of rows. + "ldr w6, [%[nnz_per_row]], #4\n" + "cmp w6, #0\n" + // Skip the body if there isn't anything in this row. + "beq " LABEL_SKIP_COL_LOOP "f\n" + + LABEL_COL_LOOP + ":\n" + // Load 1 Rhs vectors of size 1x4 each. + "ld1 {v0.4s}, [%[rhs_ptr]], x8\n" + + // Start this load now, which we won't need until the end of the loop. + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + + // Load 16 Lhs cells corresponding to a 4x4 block. + "ld1 {v2.8h, v3.8h}, [%[weights_ptr]], #32\n" + + // Convert bfloat16 -> float32. + "shll v4.4s, v2.4h, #16\n" + "shll2 v5.4s, v2.8h, #16\n" + "shll v6.4s, v3.4h, #16\n" + "shll2 v7.4s, v3.8h, #16\n" + + // Multiply-accumulate. + "fmla v28.4s, v4.4s, v0.4s\n" + "fmla v29.4s, v5.4s, v0.4s\n" + "fmla v30.4s, v6.4s, v0.4s\n" + "fmla v31.4s, v7.4s, v0.4s\n" + + // Loop. Decrement loop index. + "subs w6, w6, #1\n" // decrement (reduced) columns left + "bne " LABEL_COL_LOOP "b\n" + + LABEL_SKIP_COL_LOOP + ":\n" + + // Horizontally add accumulators and store result + "faddp v28.4s, v28.4s, v29.4s\n" + "faddp v30.4s, v30.4s, v31.4s\n" + "faddp v28.4s, v28.4s, v30.4s\n" + + // Do relu if requested. + "fmax v28.4s, v28.4s, v25.4s\n" + + // Store accumulators. + "st1 {v28.4s}, [%[out_ptr]], #16\n" + + // Decrement rows remaining. + "subs %[assigned_rows], %[assigned_rows], #1\n" + "bne " LABEL_ROW_LOOP "b\n" + + // clang-format off + : // outputs + [out_ptr] "+r"(out_ptr), + [weights_ptr] "+r"(weights_ptr), + [col_deltas_bytes] "+r"(col_deltas_bytes), + [bias_ptr] "+r"(bias_ptr), + [nnz_per_row] "+r"(nnz_per_row), + [assigned_rows] "+r"(assigned_rows), + [rhs_ptr] "+r"(rhs_ptr) + : // inputs + : // clobbers + "cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", + "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", + "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31"); + // clang-format on + } else { + asm( + // Load the first two column deltas. + "ldrsh x7, [%[col_deltas_bytes]], #2\n" + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + // ld1 doesn't support pre-index, so we do the first addition here. + "add %[rhs_ptr], %[rhs_ptr], x7\n" + + LABEL_ROW_LOOP + ":\n" + + // Load the bias. + "ld1 {v27.4s}, [%[bias_ptr]], #16\n" + + // Zero out local accumulators. + "dup v28.4s, v27.s[0]\n" // accum_0 = 0 + "dup v29.4s, v27.s[1]\n" // accum_1 = 0 + "dup v30.4s, v27.s[2]\n" // accum_2 = 0 + "dup v31.4s, v27.s[3]\n" // accum_3 = 0 + + // Update the stopping condition for this set of rows. + "ldr w6, [%[nnz_per_row]], #4\n" + "cmp w6, #0\n" + // Skip the body if there isn't anything in this row. + "beq " LABEL_SKIP_COL_LOOP "f\n" + + LABEL_COL_LOOP + ":\n" + // Load 1 Rhs vectors of size 1x4 each + "ld1 {v0.4s}, [%[rhs_ptr]], x8\n" + + // Start this load now, which we won't need until the end of the loop. + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + + // Load 16 Lhs cells corresponding to a 4x4 block. + "ld1 {v2.8h, v3.8h}, [%[weights_ptr]], #32\n" + + // Convert bfloat16 -> float32. + "shll v4.4s, v2.4h, #16\n" + "shll2 v5.4s, v2.8h, #16\n" + "shll v6.4s, v3.4h, #16\n" + "shll2 v7.4s, v3.8h, #16\n" + + // Multiply-accumulate. + "fmla v28.4s, v4.4s, v0.4s\n" + "fmla v29.4s, v5.4s, v0.4s\n" + "fmla v30.4s, v6.4s, v0.4s\n" + "fmla v31.4s, v7.4s, v0.4s\n" + + // Loop. Decrement loop index. + "subs w6, w6, #1\n" // decrement (reduced) columns left + "bne " LABEL_COL_LOOP "b\n" + + LABEL_SKIP_COL_LOOP + ":\n" + + // Horizontally add accumulators and store result. + "faddp v28.4s, v28.4s, v29.4s\n" + "faddp v30.4s, v30.4s, v31.4s\n" + "faddp v28.4s, v28.4s, v30.4s\n" + + // Store accumulators. + "st1 {v28.4s}, [%[out_ptr]], #16\n" + + // Decrement rows remaining. + "subs %[assigned_rows], %[assigned_rows], #1\n" + "bne " LABEL_ROW_LOOP "b\n" + + // clang-format off + : // outputs + [out_ptr] "+r"(out_ptr), + [weights_ptr] "+r"(weights_ptr), + [col_deltas_bytes] "+r"(col_deltas_bytes), + [bias_ptr] "+r"(bias_ptr), + [nnz_per_row] "+r"(nnz_per_row), + [assigned_rows] "+r"(assigned_rows), + [rhs_ptr] "+r"(rhs_ptr) + : // inputs + : // clobbers + "cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", + "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", + "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31"); + // clang-format on + } +} + +// Performs the calculation y = A * x + b where A is a sparse matrix with a 4x4 +// blocked pattern, x is a fat vector with 5 columns and b is vector. b is +// broadcast. Weights are stored for this routine by making each 4x4 block +// contiguous. Blocks are ordered in standard row-major format. column indices +// are converted to deltas and then multiplied by 2 to convert to bytes, so +// that the value can be used directly to offset the pointer into the rhs +// vector. +// +// NOTE: The bias is expected to have be multiplied by .25f prior to calling +// this function. This is automatically taken care of in SparseLinearLayer. +// The bias is reconstructed through horizontal additions, leads to a small +// speedup by reducing latencies at the end of the loop. +template +typename std::enable_if::value && + std::is_same::value && + std::is_same::value>::type +SpMM5_4x4(const bfloat16* weights_ptr, const int16_t* col_deltas_bytes, + const int32_t* nnz_per_row, const float* rhs_ptr, + const float* bias_ptr, float* out_ptr, int64_t assigned_rows, + int64_t rows, int64_t cols, int relu) { + /* This instrinsic version exists for reference, note that in the + intrinsic version col_deltas_bytes should NOT actually be in bytes, + but rather elements. Intrinsics are 25-35% slower than the + assembly version. + + for (int r = 0; r < rows; r += 4) { + int reduced_col_count = nnz_per_row[r / 4]; + float32x4_t accum0 = vdupq_n_f32(bias_ptr + r); + float32x4_t accum1 = vdupq_n_f32(bias_ptr + r + 1); + float32x4_t accum2 = vdupq_n_f32(bias_ptr + r + 2); + float32x4_t accum3 = vdupq_n_f32(bias_ptr + r + 3); + float32x4_t accum4 = vdupq_n_f32(bias_ptr + r); + float32x4_t accum5 = vdupq_n_f32(bias_ptr + r + 1); + float32x4_t accum6 = vdupq_n_f32(bias_ptr + r + 2); + float32x4_t accum7 = vdupq_n_f32(bias_ptr + r + 3); + ... + for (int c = 0; c < reduced_col_count; ++c) { + int32_t offset = *col_deltas_bytes; col_deltas_bytes++; + rhs_ptr += offset; + float32x4_t rhs = vld1q_f32(rhs_ptr); + float32x4_t rhs2 = vld1q_f32(rhs2_ptr); + float32x4_t rhs3 = vld1q_f32(rhs3_ptr); + float32x4_t rhs4 = vld1q_f32(rhs4_ptr); + float32x4_t rhs5 = vld1q_f32(rhs5_ptr); + + uint16x4_t lhs0_int = vld1_u16(weights_ptr); weights_ptr += 4; + uint16x4_t lhs1_int = vld1_u16(weights_ptr); weights_ptr += 4; + uint16x4_t lhs2_int = vld1_u16(weights_ptr); weights_ptr += 4; + uint16x4_t lhs3_int = vld1_u16(weights_ptr); weights_ptr += 4; + + float32x4_t lhs0 = vreinterpretq_f32_u32(vshll_n_u16(lhs0_int, 16)); + float32x4_t lhs1 = vreinterpretq_f32_u32(vshll_n_u16(lhs1_int, 16)); + float32x4_t lhs2 = vreinterpretq_f32_u32(vshll_n_u16(lhs2_int, 16)); + float32x4_t lhs3 = vreinterpretq_f32_u32(vshll_n_u16(lhs3_int, 16)); + + accum0 = vmlaq_f32(accum0, lhs0, rhs); + accum1 = vmlaq_f32(accum1, lhs1, rhs); + accum2 = vmlaq_f32(accum2, lhs2, rhs); + accum3 = vmlaq_f32(accum3, lhs3, rhs); + accum4 = vmlaq_f32(accum0, lhs0, rhs2); + accum5 = vmlaq_f32(accum1, lhs1, rhs2); + accum6 = vmlaq_f32(accum2, lhs2, rhs2); + accum7 = vmlaq_f32(accum3, lhs3, rhs2); + ... + } + + float32x4_t reduce0 = vpaddq_f32(accum0, accum1); + float32x4_t reduce1 = vpaddq_f32(accum2, accum3); + float32x4_t reduce2 = vpaddq_f32(reduce0, reduce1); + vst1q_f32(out_ptr + r, reduce2); + + float32x4_t reduce0 = vpaddq_f32(accum4, accum5); + float32x4_t reduce1 = vpaddq_f32(accum6, accum7); + float32x4_t reduce2 = vpaddq_f32(reduce0, reduce1); + vst1q_f32(out2_ptr + r, reduce2); + + ... + } */ + + // If the relu is handled in the routine with a comparison and vbit (insert + // if true), or by branching, then it is slightly, but noticeably slower + // ~5%, the outer branch avoids that penalty. + // + // Pointers to the columns. + const float* rhs2_ptr = rhs_ptr + cols; + float* out2_ptr = out_ptr + rows; + const float* rhs3_ptr = rhs_ptr + 2 * cols; + float* out3_ptr = out_ptr + 2 * rows; + const float* rhs4_ptr = rhs_ptr + 3 * cols; + float* out4_ptr = out_ptr + 3 * rows; + const float* rhs5_ptr = rhs_ptr + 4 * cols; + float* out5_ptr = out_ptr + 4 * rows; + if (relu) { + asm( + // Load the first two column deltas. + "ldrsh x7, [%[col_deltas_bytes]], #2\n" + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + // ld1 doesn't support pre-index, so we do the first addition here. + "add %[rhs_ptr], %[rhs_ptr], x7\n" + "add %[rhs2_ptr], %[rhs2_ptr], x7\n" + "add %[rhs3_ptr], %[rhs3_ptr], x7\n" + "add %[rhs4_ptr], %[rhs4_ptr], x7\n" + "add %[rhs5_ptr], %[rhs5_ptr], x7\n" + + LABEL_ROW_LOOP + ":\n" + + // Load the bias. + "ld1 {v27.4s}, [%[bias_ptr]], #16\n" + + // Zero out local accumulators. + "dup v28.4s, v27.s[0]\n" // for 1st column + "dup v29.4s, v27.s[1]\n" // for 1st column + "dup v30.4s, v27.s[2]\n" // for 1st column + "dup v31.4s, v27.s[3]\n" // for 1st column + "dup v23.4s, v27.s[0]\n" // for 2nd column + "dup v24.4s, v27.s[1]\n" // for 2nd column + "dup v25.4s, v27.s[2]\n" // for 2nd column + "dup v26.4s, v27.s[3]\n" // for 2nd column + "dup v19.4s, v27.s[0]\n" // for 3rd column + "dup v20.4s, v27.s[1]\n" // for 3rd column + "dup v21.4s, v27.s[2]\n" // for 3rd column + "dup v22.4s, v27.s[3]\n" // for 3rd column + "dup v15.4s, v27.s[0]\n" // for 4th column + "dup v16.4s, v27.s[1]\n" // for 4th column + "dup v17.4s, v27.s[2]\n" // for 4th column + "dup v18.4s, v27.s[3]\n" // for 4th column + "dup v11.4s, v27.s[0]\n" // for 5th column + "dup v12.4s, v27.s[1]\n" // for 5th column + "dup v13.4s, v27.s[2]\n" // for 5th column + "dup v14.4s, v27.s[3]\n" // for 5th column + + // Update the stopping condition for this set of rows. + "ldr w6, [%[nnz_per_row]], #4\n" + "cmp w6, #0\n" + // Skip the body if there isn't anything in this row. + "beq " LABEL_SKIP_COL_LOOP "f\n" + + LABEL_COL_LOOP + ":\n" + // Load 1 Rhs vectors of size 1x4 each. + "ld1 {v0.4s}, [%[rhs_ptr]], x8\n" + "ld1 {v1.4s}, [%[rhs2_ptr]], x8\n" + "ld1 {v8.4s}, [%[rhs3_ptr]], x8\n" + "ld1 {v9.4s}, [%[rhs4_ptr]], x8\n" + "ld1 {v10.4s}, [%[rhs5_ptr]], x8\n" + + // Start this load now, which we won't need until the end of the loop. + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + + // Load 16 Lhs cells corresponding to a 4x4 block. + "ld1 {v2.8h, v3.8h}, [%[weights_ptr]], #32\n" + + // Convert bfloat16 -> float32. + "shll v4.4s, v2.4h, #16\n" + "shll2 v5.4s, v2.8h, #16\n" + "shll v6.4s, v3.4h, #16\n" + "shll2 v7.4s, v3.8h, #16\n" + + // Multiply-accumulate. + "fmla v28.4s, v4.4s, v0.4s\n" // for 1st column + "fmla v29.4s, v5.4s, v0.4s\n" // for 1st column + "fmla v30.4s, v6.4s, v0.4s\n" // for 1st column + "fmla v31.4s, v7.4s, v0.4s\n" // for 1st column + "fmla v23.4s, v4.4s, v1.4s\n" // for 2nd column + "fmla v24.4s, v5.4s, v1.4s\n" // for 2nd column + "fmla v25.4s, v6.4s, v1.4s\n" // for 2nd column + "fmla v26.4s, v7.4s, v1.4s\n" // for 2nd column + "fmla v19.4s, v4.4s, v8.4s\n" // for 3rd column + "fmla v20.4s, v5.4s, v8.4s\n" // for 3rd column + "fmla v21.4s, v6.4s, v8.4s\n" // for 3rd column + "fmla v22.4s, v7.4s, v8.4s\n" // for 3rd column + "fmla v15.4s, v4.4s, v9.4s\n" // for 4th column + "fmla v16.4s, v5.4s, v9.4s\n" // for 4th column + "fmla v17.4s, v6.4s, v9.4s\n" // for 4th column + "fmla v18.4s, v7.4s, v9.4s\n" // for 4th column + "fmla v11.4s, v4.4s, v10.4s\n" // for 5th column + "fmla v12.4s, v5.4s, v10.4s\n" // for 5th column + "fmla v13.4s, v6.4s, v10.4s\n" // for 5th column + "fmla v14.4s, v7.4s, v10.4s\n" // for 5th column + + // Loop. Decrement loop index. + "subs w6, w6, #1\n" // decrement (reduced) columns left + "bne " LABEL_COL_LOOP "b\n" + + LABEL_SKIP_COL_LOOP + ":\n" + + "movi v0.4s, #0\n" + "faddp v28.4s, v28.4s, v29.4s\n" // 1st column + "faddp v23.4s, v23.4s, v24.4s\n" // 2nd column + "faddp v19.4s, v19.4s, v20.4s\n" // 3rd column + "faddp v15.4s, v15.4s, v16.4s\n" // 4th column + "faddp v11.4s, v11.4s, v12.4s\n" // 5th column + + "faddp v30.4s, v30.4s, v31.4s\n" // 1st column + "faddp v25.4s, v25.4s, v26.4s\n" // 2nd column + "faddp v21.4s, v21.4s, v22.4s\n" // 3rd column + "faddp v17.4s, v17.4s, v18.4s\n" // 4th column + "faddp v13.4s, v13.4s, v14.4s\n" // 5th column + + "faddp v28.4s, v28.4s, v30.4s\n" // 1st column + "faddp v23.4s, v23.4s, v25.4s\n" // 2nd column + "faddp v19.4s, v19.4s, v21.4s\n" // 3rd column + "faddp v15.4s, v15.4s, v17.4s\n" // 4th column + "faddp v11.4s, v11.4s, v13.4s\n" // 5th column + + // Do relu as requested. + "fmax v28.4s, v28.4s, v0.4s\n" + "fmax v23.4s, v23.4s, v0.4s\n" + "fmax v19.4s, v19.4s, v0.4s\n" + "fmax v15.4s, v15.4s, v0.4s\n" + "fmax v11.4s, v11.4s, v0.4s\n" + + // Store accumulators. + "st1 {v28.4s}, [%[out_ptr]], #16\n" + "st1 {v23.4s}, [%[out2_ptr]], #16\n" + "st1 {v19.4s}, [%[out3_ptr]], #16\n" + "st1 {v15.4s}, [%[out4_ptr]], #16\n" + "st1 {v11.4s}, [%[out5_ptr]], #16\n" + + // Decrement rows remaining. + "subs %[assigned_rows], %[assigned_rows], #1\n" + "bne " LABEL_ROW_LOOP "b\n" + + // clang-format off + : // outputs + [out_ptr] "+r"(out_ptr), + [out2_ptr] "+r"(out2_ptr), + [out3_ptr] "+r"(out3_ptr), + [out4_ptr] "+r"(out4_ptr), + [out5_ptr] "+r"(out5_ptr), + [weights_ptr] "+r"(weights_ptr), + [col_deltas_bytes] "+r"(col_deltas_bytes), + [bias_ptr] "+r"(bias_ptr), + [nnz_per_row] "+r"(nnz_per_row), + [assigned_rows] "+r"(assigned_rows), + [rhs_ptr] "+r"(rhs_ptr), + [rhs2_ptr] "+r"(rhs2_ptr), + [rhs3_ptr] "+r"(rhs3_ptr), + [rhs4_ptr] "+r"(rhs4_ptr), + [rhs5_ptr] "+r"(rhs5_ptr) + : // inputs + : // clobbers + "cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", + "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", + "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31"); + // clang-format on + } else { + asm( + // Load the first two column deltas. + "ldrsh x7, [%[col_deltas_bytes]], #2\n" + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + // ld1 doesn't support pre-index, so we do the first addition here. + "add %[rhs_ptr], %[rhs_ptr], x7\n" + "add %[rhs2_ptr], %[rhs2_ptr], x7\n" + "add %[rhs3_ptr], %[rhs3_ptr], x7\n" + "add %[rhs4_ptr], %[rhs4_ptr], x7\n" + "add %[rhs5_ptr], %[rhs5_ptr], x7\n" + + LABEL_ROW_LOOP + ":\n" + + // Load the bias. + "ld1 {v27.4s}, [%[bias_ptr]], #16\n" + + // Zero out local accumulators. + "dup v28.4s, v27.s[0]\n" // for 1st column + "dup v29.4s, v27.s[1]\n" // for 1st column + "dup v30.4s, v27.s[2]\n" // for 1st column + "dup v31.4s, v27.s[3]\n" // for 1st column + "dup v23.4s, v27.s[0]\n" // for 2nd column + "dup v24.4s, v27.s[1]\n" // for 2nd column + "dup v25.4s, v27.s[2]\n" // for 2nd column + "dup v26.4s, v27.s[3]\n" // for 2nd column + "dup v19.4s, v27.s[0]\n" // for 3rd column + "dup v20.4s, v27.s[1]\n" // for 3rd column + "dup v21.4s, v27.s[2]\n" // for 3rd column + "dup v22.4s, v27.s[3]\n" // for 3rd column + "dup v15.4s, v27.s[0]\n" // for 4th column + "dup v16.4s, v27.s[1]\n" // for 4th column + "dup v17.4s, v27.s[2]\n" // for 4th column + "dup v18.4s, v27.s[3]\n" // for 4th column + "dup v11.4s, v27.s[0]\n" // for 5th column + "dup v12.4s, v27.s[1]\n" // for 5th column + "dup v13.4s, v27.s[2]\n" // for 5th column + "dup v14.4s, v27.s[3]\n" // for 5th column + + // Update the stopping condition for this set of rows. + "ldr w6, [%[nnz_per_row]], #4\n" + "cmp w6, #0\n" + // Skip the body if there isn't anything in this row. + "beq " LABEL_SKIP_COL_LOOP "f\n" + + LABEL_COL_LOOP + ":\n" + // Load 1 Rhs vectors of size 1x4 each. + "ld1 {v0.4s}, [%[rhs_ptr]], x8\n" + "ld1 {v1.4s}, [%[rhs2_ptr]], x8\n" + "ld1 {v8.4s}, [%[rhs3_ptr]], x8\n" + "ld1 {v9.4s}, [%[rhs4_ptr]], x8\n" + "ld1 {v10.4s}, [%[rhs5_ptr]], x8\n" + + // Start this load now, which we won't need until the end of the loop. + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + + // Load 16 Lhs cells corresponding to a 4x4 block. + "ld1 {v2.8h, v3.8h}, [%[weights_ptr]], #32\n" + + // Convert bfloat16 -> float32. + "shll v4.4s, v2.4h, #16\n" + "shll2 v5.4s, v2.8h, #16\n" + "shll v6.4s, v3.4h, #16\n" + "shll2 v7.4s, v3.8h, #16\n" + + // Multiply-accumulate. + "fmla v28.4s, v4.4s, v0.4s\n" // for 1st column + "fmla v29.4s, v5.4s, v0.4s\n" // for 1st column + "fmla v30.4s, v6.4s, v0.4s\n" // for 1st column + "fmla v31.4s, v7.4s, v0.4s\n" // for 1st column + "fmla v23.4s, v4.4s, v1.4s\n" // for 2nd column + "fmla v24.4s, v5.4s, v1.4s\n" // for 2nd column + "fmla v25.4s, v6.4s, v1.4s\n" // for 2nd column + "fmla v26.4s, v7.4s, v1.4s\n" // for 2nd column + "fmla v19.4s, v4.4s, v8.4s\n" // for 3rd column + "fmla v20.4s, v5.4s, v8.4s\n" // for 3rd column + "fmla v21.4s, v6.4s, v8.4s\n" // for 3rd column + "fmla v22.4s, v7.4s, v8.4s\n" // for 3rd column + "fmla v15.4s, v4.4s, v9.4s\n" // for 4th column + "fmla v16.4s, v5.4s, v9.4s\n" // for 4th column + "fmla v17.4s, v6.4s, v9.4s\n" // for 4th column + "fmla v18.4s, v7.4s, v9.4s\n" // for 4th column + "fmla v11.4s, v4.4s, v10.4s\n" // for 5th column + "fmla v12.4s, v5.4s, v10.4s\n" // for 5th column + "fmla v13.4s, v6.4s, v10.4s\n" // for 5th column + "fmla v14.4s, v7.4s, v10.4s\n" // for 5th column + + // Loop. Decrement loop index. + "subs w6, w6, #1\n" // decrement (reduced) columns left + "bne " LABEL_COL_LOOP "b\n" + + LABEL_SKIP_COL_LOOP + ":\n" + + // Horizontally add accumulators and store result. + "faddp v28.4s, v28.4s, v29.4s\n" // 1st column + "faddp v23.4s, v23.4s, v24.4s\n" // 2nd column + "faddp v19.4s, v19.4s, v20.4s\n" // 3rd column + "faddp v15.4s, v15.4s, v16.4s\n" // 4th column + "faddp v11.4s, v11.4s, v12.4s\n" // 5th column + + "faddp v30.4s, v30.4s, v31.4s\n" // 1st column + "faddp v25.4s, v25.4s, v26.4s\n" // 2nd column + "faddp v21.4s, v21.4s, v22.4s\n" // 3rd column + "faddp v17.4s, v17.4s, v18.4s\n" // 4th column + "faddp v13.4s, v13.4s, v14.4s\n" // 5th column + + "faddp v28.4s, v28.4s, v30.4s\n" // 1st column + "faddp v23.4s, v23.4s, v25.4s\n" // 2nd column + "faddp v19.4s, v19.4s, v21.4s\n" // 3rd column + "faddp v15.4s, v15.4s, v17.4s\n" // 4th column + "faddp v11.4s, v11.4s, v13.4s\n" // 5th column + + // Store accumulators. + "st1 {v28.4s}, [%[out_ptr]], #16\n" + "st1 {v23.4s}, [%[out2_ptr]], #16\n" + "st1 {v19.4s}, [%[out3_ptr]], #16\n" + "st1 {v15.4s}, [%[out4_ptr]], #16\n" + "st1 {v11.4s}, [%[out5_ptr]], #16\n" + + // Decrement rows remaining. + "subs %[assigned_rows], %[assigned_rows], #1\n" + "bne " LABEL_ROW_LOOP "b\n" + + // clang-format off + : // outputs + [out_ptr] "+r"(out_ptr), + [out2_ptr] "+r"(out2_ptr), + [out3_ptr] "+r"(out3_ptr), + [out4_ptr] "+r"(out4_ptr), + [out5_ptr] "+r"(out5_ptr), + [weights_ptr] "+r"(weights_ptr), + [col_deltas_bytes] "+r"(col_deltas_bytes), + [bias_ptr] "+r"(bias_ptr), + [nnz_per_row] "+r"(nnz_per_row), + [assigned_rows] "+r"(assigned_rows), + [rhs_ptr] "+r"(rhs_ptr), + [rhs2_ptr] "+r"(rhs2_ptr), + [rhs3_ptr] "+r"(rhs3_ptr), + [rhs4_ptr] "+r"(rhs4_ptr), + [rhs5_ptr] "+r"(rhs5_ptr) + : // inputs + : // clobbers + "cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", + "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", + "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31"); + // clang-format on + } +} + +// float implementations below the line. + +template +typename std::enable_if::value && + std::is_same::value && + std::is_same::value>::type +SpMV_4x4(const float* weights_ptr, const int16_t* col_deltas_bytes, + const int32_t* nnz_per_row, const float* rhs_ptr, + const float* bias_ptr, float* out_ptr, int64_t assigned_rows, + int64_t rows /* only used in SpMM variants */, + int64_t cols /* only used in SpMM variants */, int relu) { + /* This instrinsic version exists for reference, note that in the + intrinsic version col_deltas_bytes should NOT actually be in bytes, + but rather elements. Intrinsics are 25-35% slower than the + assembly version. + + for (int r = 0; r < rows; r += 4) { + int reduced_col_count = nnz_per_row[r / 4]; + float32x4_t accum0 = vdupq_n_f32(bias_ptr + r); + float32x4_t accum1 = vdupq_n_f32(bias_ptr + r + 1); + float32x4_t accum2 = vdupq_n_f32(bias_ptr + r + 2); + float32x4_t accum3 = vdupq_n_f32(bias_ptr + r + 3); + for (int c = 0; c < reduced_col_count; ++c) { + int32_t offset = *col_deltas_bytes; col_deltas_bytes++; + rhs_ptr += offset; + float32x4_t rhs = vld1q_f32(rhs_ptr); + + uint16x4_t lhs0_int = vld1_u16(weights_ptr); weights_ptr += 4; + uint16x4_t lhs1_int = vld1_u16(weights_ptr); weights_ptr += 4; + uint16x4_t lhs2_int = vld1_u16(weights_ptr); weights_ptr += 4; + uint16x4_t lhs3_int = vld1_u16(weights_ptr); weights_ptr += 4; + + float32x4_t lhs0 = vreinterpretq_f32_u32(vshll_n_u16(lhs0_int, 16)); + float32x4_t lhs1 = vreinterpretq_f32_u32(vshll_n_u16(lhs1_int, 16)); + float32x4_t lhs2 = vreinterpretq_f32_u32(vshll_n_u16(lhs2_int, 16)); + float32x4_t lhs3 = vreinterpretq_f32_u32(vshll_n_u16(lhs3_int, 16)); + + accum0 = vmlaq_f32(accum0, lhs0, rhs); + accum1 = vmlaq_f32(accum1, lhs1, rhs); + accum2 = vmlaq_f32(accum2, lhs2, rhs); + accum3 = vmlaq_f32(accum3, lhs3, rhs); + } + + float32x4_t reduce0 = vpaddq_f32(accum0, accum1); + float32x4_t reduce1 = vpaddq_f32(accum2, accum3); + float32x4_t reduce2 = vpaddq_f32(reduce0, reduce1); + vst1q_f32(out_ptr + r, reduce2); + } */ + + // If the relu is handled in the routine with a comparison and vbit (insert + // if true), or by branching, then it is slightly, but noticeably slower + // ~5%, the outer branch avoids that penalty. + if (relu) { + asm( + // Load the first two column deltas. + "ldrsh x7, [%[col_deltas_bytes]], #2\n" + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + // ld1 doesn't support pre-index, so we do the first addition here. + "add %[rhs_ptr], %[rhs_ptr], x7\n" + + "movi v25.4s, #0\n" + + LABEL_ROW_LOOP + ":\n" + + // Load the bias. + "ld1 {v27.4s}, [%[bias_ptr]], #16\n" + + // Zero out local accumulators. + "dup v28.4s, v27.s[0]\n" // accum_0 = 0 + "dup v29.4s, v27.s[1]\n" // accum_1 = 0 + "dup v30.4s, v27.s[2]\n" // accum_2 = 0 + "dup v31.4s, v27.s[3]\n" // accum_3 = 0 + + // Update the stopping condition for this set of rows. + "ldr w6, [%[nnz_per_row]], #4\n" + "cmp w6, #0\n" + // Skip the body if there isn't anything in this row. + "beq " LABEL_SKIP_COL_LOOP "f\n" + + LABEL_COL_LOOP + ":\n" + // Load 1 Rhs vectors of size 1x4 each. + "ld1 {v0.4s}, [%[rhs_ptr]], x8\n" + + // Start this load now, which we won't need until the end of the loop. + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + + // Load 16 Lhs cells corresponding to a 4x4 block. + "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[weights_ptr]], #64\n" + + // Multiply-accumulate. + "fmla v28.4s, v4.4s, v0.4s\n" + "fmla v29.4s, v5.4s, v0.4s\n" + "fmla v30.4s, v6.4s, v0.4s\n" + "fmla v31.4s, v7.4s, v0.4s\n" + + // Loop. Decrement loop index. + "subs w6, w6, #1\n" // decrement (reduced) columns left + "bne " LABEL_COL_LOOP "b\n" + + LABEL_SKIP_COL_LOOP + ":\n" + + // Horizontally add accumulators and store result. + "faddp v28.4s, v28.4s, v29.4s\n" + "faddp v30.4s, v30.4s, v31.4s\n" + "faddp v28.4s, v28.4s, v30.4s\n" + + // Do relu as requested. + "fmax v28.4s, v28.4s, v25.4s\n" + + // Store accumulators. + "st1 {v28.4s}, [%[out_ptr]], #16\n" + + // Decrement rows remaining. + "subs %[assigned_rows], %[assigned_rows], #1\n" + "bne " LABEL_ROW_LOOP "b\n" + + // clang-format off + : // outputs + [out_ptr] "+r"(out_ptr), + [weights_ptr] "+r"(weights_ptr), + [col_deltas_bytes] "+r"(col_deltas_bytes), + [bias_ptr] "+r"(bias_ptr), + [nnz_per_row] "+r"(nnz_per_row), + [assigned_rows] "+r"(assigned_rows), + [rhs_ptr] "+r"(rhs_ptr) + : // inputs + : // clobbers + "cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", + "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", + "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31"); + // clang-format on + } else { + asm( + // Load the first two column deltas. + "ldrsh x7, [%[col_deltas_bytes]], #2\n" + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + // ld1 doesn't support pre-index, so we do the first addition here. + "add %[rhs_ptr], %[rhs_ptr], x7\n" + + LABEL_ROW_LOOP + ":\n" + + // Load the bias. + "ld1 {v27.4s}, [%[bias_ptr]], #16\n" + + // Zero out local accumulators. + "dup v28.4s, v27.s[0]\n" // accum_0 = 0 + "dup v29.4s, v27.s[1]\n" // accum_1 = 0 + "dup v30.4s, v27.s[2]\n" // accum_2 = 0 + "dup v31.4s, v27.s[3]\n" // accum_3 = 0 + + // Update the stopping condition for this set of rows. + "ldr w6, [%[nnz_per_row]], #4\n" + "cmp w6, #0\n" + // Skip the body if there isn't anything in this row. + "beq " LABEL_SKIP_COL_LOOP "f\n" + + LABEL_COL_LOOP + ":\n" + // Load 1 Rhs vectors of size 1x4 each. + "ld1 {v0.4s}, [%[rhs_ptr]], x8\n" + + // Start this load now, which we won't need until the end of the loop. + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + + // Load 16 Lhs cells corresponding to a 4x4 block. + "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[weights_ptr]], #64\n" + + // Multiply-accumulate. + "fmla v28.4s, v4.4s, v0.4s\n" + "fmla v29.4s, v5.4s, v0.4s\n" + "fmla v30.4s, v6.4s, v0.4s\n" + "fmla v31.4s, v7.4s, v0.4s\n" + + // Loop. Decrement loop index. + "subs w6, w6, #1\n" // decrement (reduced) columns left + "bne " LABEL_COL_LOOP "b\n" + + LABEL_SKIP_COL_LOOP + ":\n" + + // Horizontally add accumulators and store result. + "faddp v28.4s, v28.4s, v29.4s\n" + "faddp v30.4s, v30.4s, v31.4s\n" + "faddp v28.4s, v28.4s, v30.4s\n" + + // Store accumulators. + "st1 {v28.4s}, [%[out_ptr]], #16\n" + + // Decrement rows remaining. + "subs %[assigned_rows], %[assigned_rows], #1\n" + "bne " LABEL_ROW_LOOP "b\n" + + // clang-format off + : // outputs + [out_ptr] "+r"(out_ptr), + [weights_ptr] "+r"(weights_ptr), + [col_deltas_bytes] "+r"(col_deltas_bytes), + [bias_ptr] "+r"(bias_ptr), + [nnz_per_row] "+r"(nnz_per_row), + [assigned_rows] "+r"(assigned_rows), + [rhs_ptr] "+r"(rhs_ptr) + : // inputs + : // clobbers + "cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", + "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", + "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31"); + // clang-format on + } +} + +// Performs the calculation y = A * x + b where A is a sparse matrix with a 4x4 +// blocked pattern, x is a fat vector with 5 columns and b is vector. b is +// broadcast. Weights are stored for this routine by making each 4x4 block +// contiguous. Blocks are ordered in standard row-major format. column indices +// are converted to deltas and then multiplied by 2 to convert to bytes, so +// that the value can be used directly to offset the pointer into the rhs +// vector. +// +// NOTE: The bias is expected to have be multiplied by .25f prior to calling +// this function. This is automatically taken care of in sparse_linear_layer. +// The bias is reconstructed through horizontal additions, leads to a small +// speedup by reducing latencies at the end of the loop. +template +typename std::enable_if::value && + std::is_same::value && + std::is_same::value>::type +SpMM5_4x4(const float* weights_ptr, const int16_t* col_deltas_bytes, + const int32_t* nnz_per_row, const float* rhs_ptr, + const float* bias_ptr, float* out_ptr, int64_t assigned_rows, + int64_t rows, int64_t cols, int relu) { + /* This instrinsic version exists for reference, note that in the + intrinsic version col_deltas_bytes should NOT actually be in bytes, + but rather elements. Intrinsics are 25-35% slower than the + assembly version. + + for (int r = 0; r < rows; r += 4) { + int reduced_col_count = nnz_per_row[r / 4]; + float32x4_t accum0 = vdupq_n_f32(bias_ptr + r); + float32x4_t accum1 = vdupq_n_f32(bias_ptr + r + 1); + float32x4_t accum2 = vdupq_n_f32(bias_ptr + r + 2); + float32x4_t accum3 = vdupq_n_f32(bias_ptr + r + 3); + float32x4_t accum4 = vdupq_n_f32(bias_ptr + r); + float32x4_t accum5 = vdupq_n_f32(bias_ptr + r + 1); + float32x4_t accum6 = vdupq_n_f32(bias_ptr + r + 2); + float32x4_t accum7 = vdupq_n_f32(bias_ptr + r + 3); + ... + for (int c = 0; c < reduced_col_count; ++c) { + int32_t offset = *col_deltas_bytes; col_deltas_bytes++; + rhs_ptr += offset; + float32x4_t rhs = vld1q_f32(rhs_ptr); + float32x4_t rhs2 = vld1q_f32(rhs2_ptr); + float32x4_t rhs3 = vld1q_f32(rhs3_ptr); + float32x4_t rhs4 = vld1q_f32(rhs4_ptr); + float32x4_t rhs5 = vld1q_f32(rhs5_ptr); + + uint16x4_t lhs0_int = vld1_u16(weights_ptr); weights_ptr += 4; + uint16x4_t lhs1_int = vld1_u16(weights_ptr); weights_ptr += 4; + uint16x4_t lhs2_int = vld1_u16(weights_ptr); weights_ptr += 4; + uint16x4_t lhs3_int = vld1_u16(weights_ptr); weights_ptr += 4; + + float32x4_t lhs0 = vreinterpretq_f32_u32(vshll_n_u16(lhs0_int, 16)); + float32x4_t lhs1 = vreinterpretq_f32_u32(vshll_n_u16(lhs1_int, 16)); + float32x4_t lhs2 = vreinterpretq_f32_u32(vshll_n_u16(lhs2_int, 16)); + float32x4_t lhs3 = vreinterpretq_f32_u32(vshll_n_u16(lhs3_int, 16)); + + accum0 = vmlaq_f32(accum0, lhs0, rhs); + accum1 = vmlaq_f32(accum1, lhs1, rhs); + accum2 = vmlaq_f32(accum2, lhs2, rhs); + accum3 = vmlaq_f32(accum3, lhs3, rhs); + accum4 = vmlaq_f32(accum0, lhs0, rhs2); + accum5 = vmlaq_f32(accum1, lhs1, rhs2); + accum6 = vmlaq_f32(accum2, lhs2, rhs2); + accum7 = vmlaq_f32(accum3, lhs3, rhs2); + ... + } + + float32x4_t reduce0 = vpaddq_f32(accum0, accum1); + float32x4_t reduce1 = vpaddq_f32(accum2, accum3); + float32x4_t reduce2 = vpaddq_f32(reduce0, reduce1); + vst1q_f32(out_ptr + r, reduce2); + + float32x4_t reduce0 = vpaddq_f32(accum4, accum5); + float32x4_t reduce1 = vpaddq_f32(accum6, accum7); + float32x4_t reduce2 = vpaddq_f32(reduce0, reduce1); + vst1q_f32(out2_ptr + r, reduce2); + + ... + } */ + + // If the relu is handled in the routine with a comparison and vbit (insert + // if true), or by branching, then it is slightly, but noticeably slower + // ~5%, the outer branch avoids that penalty. + // + // Pointers to the columns. + const float* rhs2_ptr = rhs_ptr + cols; + float* out2_ptr = out_ptr + rows; + const float* rhs3_ptr = rhs_ptr + 2 * cols; + float* out3_ptr = out_ptr + 2 * rows; + const float* rhs4_ptr = rhs_ptr + 3 * cols; + float* out4_ptr = out_ptr + 3 * rows; + const float* rhs5_ptr = rhs_ptr + 4 * cols; + float* out5_ptr = out_ptr + 4 * rows; + if (relu) { + asm( + // Load the first two column deltas. + "ldrsh x7, [%[col_deltas_bytes]], #2\n" + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + // ld1 doesn't support pre-index, so we do the first addition here. + "add %[rhs_ptr], %[rhs_ptr], x7\n" + "add %[rhs2_ptr], %[rhs2_ptr], x7\n" + "add %[rhs3_ptr], %[rhs3_ptr], x7\n" + "add %[rhs4_ptr], %[rhs4_ptr], x7\n" + "add %[rhs5_ptr], %[rhs5_ptr], x7\n" + + LABEL_ROW_LOOP + ":\n" + // Load the bias. + "ld1 {v27.4s}, [%[bias_ptr]], #16\n" + + // Zero out local accumulators. + "dup v28.4s, v27.s[0]\n" // for 1st column + "dup v29.4s, v27.s[1]\n" // for 1st column + "dup v30.4s, v27.s[2]\n" // for 1st column + "dup v31.4s, v27.s[3]\n" // for 1st column + "dup v23.4s, v27.s[0]\n" // for 2nd column + "dup v24.4s, v27.s[1]\n" // for 2nd column + "dup v25.4s, v27.s[2]\n" // for 2nd column + "dup v26.4s, v27.s[3]\n" // for 2nd column + "dup v19.4s, v27.s[0]\n" // for 3rd column + "dup v20.4s, v27.s[1]\n" // for 3rd column + "dup v21.4s, v27.s[2]\n" // for 3rd column + "dup v22.4s, v27.s[3]\n" // for 3rd column + "dup v15.4s, v27.s[0]\n" // for 4th column + "dup v16.4s, v27.s[1]\n" // for 4th column + "dup v17.4s, v27.s[2]\n" // for 4th column + "dup v18.4s, v27.s[3]\n" // for 4th column + "dup v11.4s, v27.s[0]\n" // for 5th column + "dup v12.4s, v27.s[1]\n" // for 5th column + "dup v13.4s, v27.s[2]\n" // for 5th column + "dup v14.4s, v27.s[3]\n" // for 5th column + + // Update the stopping condition for this set of rows. + "ldr w6, [%[nnz_per_row]], #4\n" + "cmp w6, #0\n" + // Skip the body if there isn't anything in this row. + "beq " LABEL_SKIP_COL_LOOP "f\n" + + LABEL_COL_LOOP + ":\n" + // Load 1 Rhs vectors of size 1x4 each. + "ld1 {v0.4s}, [%[rhs_ptr]], x8\n" + "ld1 {v1.4s}, [%[rhs2_ptr]], x8\n" + "ld1 {v8.4s}, [%[rhs3_ptr]], x8\n" + "ld1 {v9.4s}, [%[rhs4_ptr]], x8\n" + "ld1 {v10.4s}, [%[rhs5_ptr]], x8\n" + + // Start this load now, which we won't need until the end of the loop. + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + + // Load 16 Lhs cells corresponding to a 4x4 block. + "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[weights_ptr]], #64\n" + + // Multiply-accumulate. + "fmla v28.4s, v4.4s, v0.4s\n" // for 1st column + "fmla v29.4s, v5.4s, v0.4s\n" // for 1st column + "fmla v30.4s, v6.4s, v0.4s\n" // for 1st column + "fmla v31.4s, v7.4s, v0.4s\n" // for 1st column + "fmla v23.4s, v4.4s, v1.4s\n" // for 2nd column + "fmla v24.4s, v5.4s, v1.4s\n" // for 2nd column + "fmla v25.4s, v6.4s, v1.4s\n" // for 2nd column + "fmla v26.4s, v7.4s, v1.4s\n" // for 2nd column + "fmla v19.4s, v4.4s, v8.4s\n" // for 3rd column + "fmla v20.4s, v5.4s, v8.4s\n" // for 3rd column + "fmla v21.4s, v6.4s, v8.4s\n" // for 3rd column + "fmla v22.4s, v7.4s, v8.4s\n" // for 3rd column + "fmla v15.4s, v4.4s, v9.4s\n" // for 4th column + "fmla v16.4s, v5.4s, v9.4s\n" // for 4th column + "fmla v17.4s, v6.4s, v9.4s\n" // for 4th column + "fmla v18.4s, v7.4s, v9.4s\n" // for 4th column + "fmla v11.4s, v4.4s, v10.4s\n" // for 5th column + "fmla v12.4s, v5.4s, v10.4s\n" // for 5th column + "fmla v13.4s, v6.4s, v10.4s\n" // for 5th column + "fmla v14.4s, v7.4s, v10.4s\n" // for 5th column + + // Loop. Decrement loop index. + "subs w6, w6, #1\n" // decrement (reduced) columns left + "bne " LABEL_COL_LOOP "b\n" + + LABEL_SKIP_COL_LOOP + ":\n" + + "movi v0.4s, #0\n" + "faddp v28.4s, v28.4s, v29.4s\n" // 1st column + "faddp v23.4s, v23.4s, v24.4s\n" // 2nd column + "faddp v19.4s, v19.4s, v20.4s\n" // 3rd column + "faddp v15.4s, v15.4s, v16.4s\n" // 4th column + "faddp v11.4s, v11.4s, v12.4s\n" // 5th column + + "faddp v30.4s, v30.4s, v31.4s\n" // 1st column + "faddp v25.4s, v25.4s, v26.4s\n" // 2nd column + "faddp v21.4s, v21.4s, v22.4s\n" // 3rd column + "faddp v17.4s, v17.4s, v18.4s\n" // 4th column + "faddp v13.4s, v13.4s, v14.4s\n" // 5th column + + "faddp v28.4s, v28.4s, v30.4s\n" // 1st column + "faddp v23.4s, v23.4s, v25.4s\n" // 2nd column + "faddp v19.4s, v19.4s, v21.4s\n" // 3rd column + "faddp v15.4s, v15.4s, v17.4s\n" // 4th column + "faddp v11.4s, v11.4s, v13.4s\n" // 5th column + + // Do relu as requested. + "fmax v28.4s, v28.4s, v0.4s\n" + "fmax v23.4s, v23.4s, v0.4s\n" + "fmax v19.4s, v19.4s, v0.4s\n" + "fmax v15.4s, v15.4s, v0.4s\n" + "fmax v11.4s, v11.4s, v0.4s\n" + + // Store accumulators. + "st1 {v28.4s}, [%[out_ptr]], #16\n" + "st1 {v23.4s}, [%[out2_ptr]], #16\n" + "st1 {v19.4s}, [%[out3_ptr]], #16\n" + "st1 {v15.4s}, [%[out4_ptr]], #16\n" + "st1 {v11.4s}, [%[out5_ptr]], #16\n" + + // Decrement rows remaining. + "subs %[assigned_rows], %[assigned_rows], #1\n" + "bne " LABEL_ROW_LOOP "b\n" + + // clang-format off + : // outputs + [out_ptr] "+r"(out_ptr), + [out2_ptr] "+r"(out2_ptr), + [out3_ptr] "+r"(out3_ptr), + [out4_ptr] "+r"(out4_ptr), + [out5_ptr] "+r"(out5_ptr), + [weights_ptr] "+r"(weights_ptr), + [col_deltas_bytes] "+r"(col_deltas_bytes), + [bias_ptr] "+r"(bias_ptr), + [nnz_per_row] "+r"(nnz_per_row), + [assigned_rows] "+r"(assigned_rows), + [rhs_ptr] "+r"(rhs_ptr), + [rhs2_ptr] "+r"(rhs2_ptr), + [rhs3_ptr] "+r"(rhs3_ptr), + [rhs4_ptr] "+r"(rhs4_ptr), + [rhs5_ptr] "+r"(rhs5_ptr) + : // inputs + : // clobbers + "cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", + "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", + "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31"); + // clang-format on + } else { + asm( + // Load the first two column deltas. + "ldrsh x7, [%[col_deltas_bytes]], #2\n" + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + // ld1 doesn't support pre-index, so we do the first addition here. + "add %[rhs_ptr], %[rhs_ptr], x7\n" + "add %[rhs2_ptr], %[rhs2_ptr], x7\n" + "add %[rhs3_ptr], %[rhs3_ptr], x7\n" + "add %[rhs4_ptr], %[rhs4_ptr], x7\n" + "add %[rhs5_ptr], %[rhs5_ptr], x7\n" + + LABEL_ROW_LOOP + ":\n" + + // Load the bias. + "ld1 {v27.4s}, [%[bias_ptr]], #16\n" + + // Zero out local accumulators. + "dup v28.4s, v27.s[0]\n" // for 1st column + "dup v29.4s, v27.s[1]\n" // for 1st column + "dup v30.4s, v27.s[2]\n" // for 1st column + "dup v31.4s, v27.s[3]\n" // for 1st column + "dup v23.4s, v27.s[0]\n" // for 2nd column + "dup v24.4s, v27.s[1]\n" // for 2nd column + "dup v25.4s, v27.s[2]\n" // for 2nd column + "dup v26.4s, v27.s[3]\n" // for 2nd column + "dup v19.4s, v27.s[0]\n" // for 3rd column + "dup v20.4s, v27.s[1]\n" // for 3rd column + "dup v21.4s, v27.s[2]\n" // for 3rd column + "dup v22.4s, v27.s[3]\n" // for 3rd column + "dup v15.4s, v27.s[0]\n" // for 4th column + "dup v16.4s, v27.s[1]\n" // for 4th column + "dup v17.4s, v27.s[2]\n" // for 4th column + "dup v18.4s, v27.s[3]\n" // for 4th column + "dup v11.4s, v27.s[0]\n" // for 5th column + "dup v12.4s, v27.s[1]\n" // for 5th column + "dup v13.4s, v27.s[2]\n" // for 5th column + "dup v14.4s, v27.s[3]\n" // for 5th column + + // Update the stopping condition for this set of rows. + "ldr w6, [%[nnz_per_row]], #4\n" + "cmp w6, #0\n" + // Skip the body if there isn't anything in this row. + "beq " LABEL_SKIP_COL_LOOP "f\n" + + LABEL_COL_LOOP + ":\n" + // Load 1 Rhs vectors of size 1x4 each. + "ld1 {v0.4s}, [%[rhs_ptr]], x8\n" + "ld1 {v1.4s}, [%[rhs2_ptr]], x8\n" + "ld1 {v8.4s}, [%[rhs3_ptr]], x8\n" + "ld1 {v9.4s}, [%[rhs4_ptr]], x8\n" + "ld1 {v10.4s}, [%[rhs5_ptr]], x8\n" + + // Start this load now, which we won't need until the end of the loop. + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + + // Load 16 Lhs cells corresponding to a 4x4 block. + "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[weights_ptr]], #64\n" + + // Multiply-accumulate. + "fmla v28.4s, v4.4s, v0.4s\n" // for 1st column + "fmla v29.4s, v5.4s, v0.4s\n" // for 1st column + "fmla v30.4s, v6.4s, v0.4s\n" // for 1st column + "fmla v31.4s, v7.4s, v0.4s\n" // for 1st column + "fmla v23.4s, v4.4s, v1.4s\n" // for 2nd column + "fmla v24.4s, v5.4s, v1.4s\n" // for 2nd column + "fmla v25.4s, v6.4s, v1.4s\n" // for 2nd column + "fmla v26.4s, v7.4s, v1.4s\n" // for 2nd column + "fmla v19.4s, v4.4s, v8.4s\n" // for 3rd column + "fmla v20.4s, v5.4s, v8.4s\n" // for 3rd column + "fmla v21.4s, v6.4s, v8.4s\n" // for 3rd column + "fmla v22.4s, v7.4s, v8.4s\n" // for 3rd column + "fmla v15.4s, v4.4s, v9.4s\n" // for 4th column + "fmla v16.4s, v5.4s, v9.4s\n" // for 4th column + "fmla v17.4s, v6.4s, v9.4s\n" // for 4th column + "fmla v18.4s, v7.4s, v9.4s\n" // for 4th column + "fmla v11.4s, v4.4s, v10.4s\n" // for 5th column + "fmla v12.4s, v5.4s, v10.4s\n" // for 5th column + "fmla v13.4s, v6.4s, v10.4s\n" // for 5th column + "fmla v14.4s, v7.4s, v10.4s\n" // for 5th column + + // Loop. Decrement loop index. + "subs w6, w6, #1\n" // decrement (reduced) columns left + "bne " LABEL_COL_LOOP "b\n" + + LABEL_SKIP_COL_LOOP + ":\n" + + // Horizontally add accumulators and store result. + "faddp v28.4s, v28.4s, v29.4s\n" // 1st column + "faddp v23.4s, v23.4s, v24.4s\n" // 2nd column + "faddp v19.4s, v19.4s, v20.4s\n" // 3rd column + "faddp v15.4s, v15.4s, v16.4s\n" // 4th column + "faddp v11.4s, v11.4s, v12.4s\n" // 5th column + + "faddp v30.4s, v30.4s, v31.4s\n" // 1st column + "faddp v25.4s, v25.4s, v26.4s\n" // 2nd column + "faddp v21.4s, v21.4s, v22.4s\n" // 3rd column + "faddp v17.4s, v17.4s, v18.4s\n" // 4th column + "faddp v13.4s, v13.4s, v14.4s\n" // 5th column + + "faddp v28.4s, v28.4s, v30.4s\n" // 1st column + "faddp v23.4s, v23.4s, v25.4s\n" // 2nd column + "faddp v19.4s, v19.4s, v21.4s\n" // 3rd column + "faddp v15.4s, v15.4s, v17.4s\n" // 4th column + "faddp v11.4s, v11.4s, v13.4s\n" // 5th column + + // Store accumulators. + "st1 {v28.4s}, [%[out_ptr]], #16\n" + "st1 {v23.4s}, [%[out2_ptr]], #16\n" + "st1 {v19.4s}, [%[out3_ptr]], #16\n" + "st1 {v15.4s}, [%[out4_ptr]], #16\n" + "st1 {v11.4s}, [%[out5_ptr]], #16\n" + + // Decrement rows remaining. + "subs %[assigned_rows], %[assigned_rows], #1\n" + "bne " LABEL_ROW_LOOP "b\n" + + // clang-format off + : // outputs + [out_ptr] "+r"(out_ptr), + [out2_ptr] "+r"(out2_ptr), + [out3_ptr] "+r"(out3_ptr), + [out4_ptr] "+r"(out4_ptr), + [out5_ptr] "+r"(out5_ptr), + [weights_ptr] "+r"(weights_ptr), + [col_deltas_bytes] "+r"(col_deltas_bytes), + [bias_ptr] "+r"(bias_ptr), + [nnz_per_row] "+r"(nnz_per_row), + [assigned_rows] "+r"(assigned_rows), + [rhs_ptr] "+r"(rhs_ptr), + [rhs2_ptr] "+r"(rhs2_ptr), + [rhs3_ptr] "+r"(rhs3_ptr), + [rhs4_ptr] "+r"(rhs4_ptr), + [rhs5_ptr] "+r"(rhs5_ptr) + : // inputs + : // clobbers + "cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", + "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", + "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31"); + // clang-format on + } +} + +// Note that the number of exponent bits in the output must exactly match +// the sum of the input and rhs types. +template +typename std::enable_if< + IsFixed16Type::value && IsFixed16Type::value && + std::is_same::type>::value>::type +SpMV_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes, + const int32_t* nnz_per_row, const RhsType* rhs_ptr, + const typename TypeOfProduct::type* bias_ptr, + OutType* out_ptr, int64_t assigned_rows, + int64_t rows /* only used in SpMM variants */, + int64_t cols /* only used in SpMM variants */, int relu) { + if (relu) { + asm( + // Load the first two column deltas. + "ldrsh x7, [%[col_deltas_bytes]], #2\n" + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + // ld1 doesn't support pre-index, so we do the first addition here. + "add %[rhs_ptr], %[rhs_ptr], x7\n" + + "movi v25.4s, #0\n" + + LABEL_ROW_LOOP + ":\n" + + // Load the bias. + "ld1 {v27.4s}, [%[bias_ptr]], #16\n" + + // Zero out local accumulators. + "dup v28.4s, v27.s[0]\n" // accum_0 = 0 + "dup v29.4s, v27.s[1]\n" // accum_1 = 0 + "dup v30.4s, v27.s[2]\n" // accum_2 = 0 + "dup v31.4s, v27.s[3]\n" // accum_3 = 0 + + // Update the stopping condition for this set of rows. + "ldr w6, [%[nnz_per_row]], #4\n" + "cmp w6, #0\n" + // Skip the body if there isn't anything in this row. + "beq " LABEL_SKIP_COL_LOOP "f\n" + + LABEL_COL_LOOP + ":\n" + // Load 1 Rhs vectors of size 1x4 each. + "ld1 {v0.4h}, [%[rhs_ptr]], x8\n" + // Duplicate the lower half into the upper half. + "mov v0.d[1], v0.d[0]\n" + + // Start this load now, which we won't need until the end of the loop. + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + + // Load 16 Lhs cells corresponding to a 4x4 block. + "ld1 {v2.8h, v3.8h}, [%[weights_ptr]], #32\n" + + // Multiply-accumulate. + "smlal v28.4s, v2.4h, v0.4h\n" + "smlal2 v29.4s, v2.8h, v0.8h\n" + "smlal v30.4s, v3.4h, v0.4h\n" + "smlal2 v31.4s, v3.8h, v0.8h\n" + + // Loop. Decrement loop index. + "subs w6, w6, #1\n" // decrement (reduced) columns left + "bne " LABEL_COL_LOOP "b\n" + + LABEL_SKIP_COL_LOOP + ":\n" + + // Horizontally add accumulators and store result. + "addp v28.4s, v28.4s, v29.4s\n" + "addp v30.4s, v30.4s, v31.4s\n" + "addp v28.4s, v28.4s, v30.4s\n" + + // Do relu if requested. + "smax v28.4s, v28.4s, v25.4s\n" + + // Store accumulators. + "st1 {v28.4s}, [%[out_ptr]], #16\n" + + // Decrement rows remaining. + "subs %[assigned_rows], %[assigned_rows], #1\n" + "bne " LABEL_ROW_LOOP "b\n" + + // clang-format off + : // outputs + [out_ptr] "+r"(out_ptr), [weights_ptr] "+r"(weights_ptr), + [col_deltas_bytes] "+r"(col_deltas_bytes), [bias_ptr] "+r"(bias_ptr), + [nnz_per_row] "+r"(nnz_per_row), [assigned_rows] "+r"(assigned_rows), + [rhs_ptr] "+r"(rhs_ptr) + : // inputs + : // clobbers + "cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", + "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", + "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31"); + // clang-format on + } else { + asm( + // Load the first two column deltas. + "ldrsh x7, [%[col_deltas_bytes]], #2\n" + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + // ld1 doesn't support pre-index, so we do the first addition here. + "add %[rhs_ptr], %[rhs_ptr], x7\n" + + "movi v25.4s, #0\n" + + LABEL_ROW_LOOP + ":\n" + + // Load the bias. + "ld1 {v27.4s}, [%[bias_ptr]], #16\n" + + // Zero out local accumulators. + "dup v28.4s, v27.s[0]\n" // accum_0 = 0 + "dup v29.4s, v27.s[1]\n" // accum_1 = 0 + "dup v30.4s, v27.s[2]\n" // accum_2 = 0 + "dup v31.4s, v27.s[3]\n" // accum_3 = 0 + + // Update the stopping condition for this set of rows. + "ldr w6, [%[nnz_per_row]], #4\n" + "cmp w6, #0\n" + // Skip the body if there isn't anything in this row. + "beq " LABEL_SKIP_COL_LOOP "f\n" + + LABEL_COL_LOOP + ":\n" + // Load 1 Rhs vectors of size 1x4 each. + "ld1 {v0.4h}, [%[rhs_ptr]], x8\n" + // Duplicate the lower half into the upper half. + "mov v0.d[1], v0.d[0]\n" + + // Start this load now, which we won't need until the end of the loop. + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + + // Load 16 Lhs cells corresponding to a 4x4 block. + "ld1 {v2.8h, v3.8h}, [%[weights_ptr]], #32\n" + + // Multiply-accumulate. + "smlal v28.4s, v2.4h, v0.4h\n" + "smlal2 v29.4s, v2.8h, v0.8h\n" + "smlal v30.4s, v3.4h, v0.4h\n" + "smlal2 v31.4s, v3.8h, v0.8h\n" + + // Loop. Decrement loop index. + "subs w6, w6, #1\n" // decrement (reduced) columns left + "bne " LABEL_COL_LOOP "b\n" + + LABEL_SKIP_COL_LOOP + ":\n" + + // Horizontally add accumulators and store result. + "addp v28.4s, v28.4s, v29.4s\n" + "addp v30.4s, v30.4s, v31.4s\n" + "addp v28.4s, v28.4s, v30.4s\n" + + // Store accumulators. + "st1 {v28.4s}, [%[out_ptr]], #16\n" + + // Decrement rows remaining. + "subs %[assigned_rows], %[assigned_rows], #1\n" + "bne " LABEL_ROW_LOOP "b\n" + + // clang-format off + : // outputs + [out_ptr] "+r"(out_ptr), [weights_ptr] "+r"(weights_ptr), + [col_deltas_bytes] "+r"(col_deltas_bytes), [bias_ptr] "+r"(bias_ptr), + [nnz_per_row] "+r"(nnz_per_row), [assigned_rows] "+r"(assigned_rows), + [rhs_ptr] "+r"(rhs_ptr) + : // inputs + : // clobbers + "cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", + "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", + "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31"); + // clang-format on + } +} + +// Note that the number of exponent bits in the output must exactly match +// the sum of the input and rhs types. +template +typename std::enable_if< + IsFixed16Type::value && IsFixed16Type::value && + std::is_same::type>::value>::type +SpMM5_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes, + const int32_t* nnz_per_row, const RhsType* rhs_ptr, + const typename TypeOfProduct::type* bias_ptr, + OutType* out_ptr, int64_t assigned_rows, int64_t rows, int64_t cols, + int relu) { + // Pointers to the columns. + const RhsType* rhs2_ptr = rhs_ptr + cols; + OutType* out2_ptr = out_ptr + rows; + const RhsType* rhs3_ptr = rhs_ptr + 2 * cols; + OutType* out3_ptr = out_ptr + 2 * rows; + const RhsType* rhs4_ptr = rhs_ptr + 3 * cols; + OutType* out4_ptr = out_ptr + 3 * rows; + const RhsType* rhs5_ptr = rhs_ptr + 4 * cols; + OutType* out5_ptr = out_ptr + 4 * rows; + if (relu) { + asm( + // Load the first two column deltas. + "ldrsh x7, [%[col_deltas_bytes]], #2\n" + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + // ld1 doesn't support pre-index, so we do the first addition here. + "add %[rhs_ptr], %[rhs_ptr], x7\n" + "add %[rhs2_ptr], %[rhs2_ptr], x7\n" + "add %[rhs3_ptr], %[rhs3_ptr], x7\n" + "add %[rhs4_ptr], %[rhs4_ptr], x7\n" + "add %[rhs5_ptr], %[rhs5_ptr], x7\n" + + LABEL_ROW_LOOP + ":\n" + + // Load the bias. + "ld1 {v27.4s}, [%[bias_ptr]], #16\n" + + // Zero out local accumulators. + "dup v28.4s, v27.s[0]\n" // for 1st column + "dup v29.4s, v27.s[1]\n" // for 1st column + "dup v30.4s, v27.s[2]\n" // for 1st column + "dup v31.4s, v27.s[3]\n" // for 1st column + "dup v23.4s, v27.s[0]\n" // for 2nd column + "dup v24.4s, v27.s[1]\n" // for 2nd column + "dup v25.4s, v27.s[2]\n" // for 2nd column + "dup v26.4s, v27.s[3]\n" // for 2nd column + "dup v19.4s, v27.s[0]\n" // for 3rd column + "dup v20.4s, v27.s[1]\n" // for 3rd column + "dup v21.4s, v27.s[2]\n" // for 3rd column + "dup v22.4s, v27.s[3]\n" // for 3rd column + "dup v15.4s, v27.s[0]\n" // for 4th column + "dup v16.4s, v27.s[1]\n" // for 4th column + "dup v17.4s, v27.s[2]\n" // for 4th column + "dup v18.4s, v27.s[3]\n" // for 4th column + "dup v11.4s, v27.s[0]\n" // for 5th column + "dup v12.4s, v27.s[1]\n" // for 5th column + "dup v13.4s, v27.s[2]\n" // for 5th column + "dup v14.4s, v27.s[3]\n" // for 5th column + + // Update the stopping condition for this set of rows. + "ldr w6, [%[nnz_per_row]], #4\n" + "cmp w6, #0\n" + // Skip the body if there isn't anything in this row. + "beq " LABEL_SKIP_COL_LOOP "f\n" + + LABEL_COL_LOOP + ":\n" + // Load 1 Rhs vectors of size 1x4 each and duplicate into upper half. + "ld1 {v0.4h}, [%[rhs_ptr]], x8\n" + "mov v0.d[1], v0.d[0]\n" + "ld1 {v1.4h}, [%[rhs2_ptr]], x8\n" + "mov v1.d[1], v1.d[0]\n" + "ld1 {v8.4h}, [%[rhs3_ptr]], x8\n" + "mov v8.d[1], v8.d[0]\n" + "ld1 {v9.4h}, [%[rhs4_ptr]], x8\n" + "mov v9.d[1], v9.d[0]\n" + "ld1 {v10.4h}, [%[rhs5_ptr]], x8\n" + "mov v10.d[1], v10.d[0]\n" + + // Start this load now, which we won't need until the end of the loop. + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + + // Load 16 Lhs cells corresponding to a 4x4 block. + "ld1 {v2.8h, v3.8h}, [%[weights_ptr]], #32\n" + + // Multiply-accumulate. + "smlal v28.4s, v2.4h, v0.4h\n" // for 1st column + "smlal2 v29.4s, v2.8h, v0.8h\n" // for 1st column + "smlal v30.4s, v3.4h, v0.4h\n" // for 1st column + "smlal2 v31.4s, v3.8h, v0.8h\n" // for 1st columh + "smlal v23.4s, v2.4h, v1.4h\n" // for 2nd column + "smlal2 v24.4s, v2.8h, v1.8h\n" // for 2nd column + "smlal v25.4s, v3.4h, v1.4h\n" // for 2nd column + "smlal2 v26.4s, v3.8h, v1.8h\n" // for 2nd column + "smlal v19.4s, v2.4h, v8.4h\n" // for 3rd column + "smlal2 v20.4s, v2.8h, v8.8h\n" // for 3rd column + "smlal v21.4s, v3.4h, v8.4h\n" // for 3rd column + "smlal2 v22.4s, v3.8h, v8.8h\n" // for 3rd column + "smlal v15.4s, v2.4h, v9.4h\n" // for 4th column + "smlal2 v16.4s, v2.8h, v9.8h\n" // for 4th column + "smlal v17.4s, v3.4h, v9.4h\n" // for 4th column + "smlal2 v18.4s, v3.8h, v9.8h\n" // for 4th column + "smlal v11.4s, v2.4h, v10.4h\n" // for 5th column + "smlal2 v12.4s, v2.8h, v10.8h\n" // for 5th column + "smlal v13.4s, v3.4h, v10.4h\n" // for 5th column + "smlal2 v14.4s, v3.8h, v10.8h\n" // for 5th column + + // Loop. Decrement loop index. + "subs w6, w6, #1\n" // decrement (reduced) columns left + "bne " LABEL_COL_LOOP "b\n" + + LABEL_SKIP_COL_LOOP + ":\n" + + "movi v0.4s, #0\n" + "addp v28.4s, v28.4s, v29.4s\n" // 1st column + "addp v23.4s, v23.4s, v24.4s\n" // 2nd column + "addp v19.4s, v19.4s, v20.4s\n" // 3rd column + "addp v15.4s, v15.4s, v16.4s\n" // 4th column + "addp v11.4s, v11.4s, v12.4s\n" // 5th column + + "addp v30.4s, v30.4s, v31.4s\n" // 1st column + "addp v25.4s, v25.4s, v26.4s\n" // 2nd column + "addp v21.4s, v21.4s, v22.4s\n" // 3rd column + "addp v17.4s, v17.4s, v18.4s\n" // 4th column + "addp v13.4s, v13.4s, v14.4s\n" // 5th column + + "addp v28.4s, v28.4s, v30.4s\n" // 1st column + "addp v23.4s, v23.4s, v25.4s\n" // 2nd column + "addp v19.4s, v19.4s, v21.4s\n" // 3rd column + "addp v15.4s, v15.4s, v17.4s\n" // 4th column + "addp v11.4s, v11.4s, v13.4s\n" // 5th column + + // Do relu as requested. + "smax v28.4s, v28.4s, v0.4s\n" + "smax v23.4s, v23.4s, v0.4s\n" + "smax v19.4s, v19.4s, v0.4s\n" + "smax v15.4s, v15.4s, v0.4s\n" + "smax v11.4s, v11.4s, v0.4s\n" + + // Store accumulators. + "st1 {v28.4s}, [%[out_ptr]], #16\n" + "st1 {v23.4s}, [%[out2_ptr]], #16\n" + "st1 {v19.4s}, [%[out3_ptr]], #16\n" + "st1 {v15.4s}, [%[out4_ptr]], #16\n" + "st1 {v11.4s}, [%[out5_ptr]], #16\n" + + // Decrement rows remaining. + "subs %[assigned_rows], %[assigned_rows], #1\n" + "bne " LABEL_ROW_LOOP "b\n" + + // clang-format off + : // outputs + [out_ptr] "+r"(out_ptr), [out2_ptr] "+r"(out2_ptr), + [out3_ptr] "+r"(out3_ptr), [out4_ptr] "+r"(out4_ptr), + [out5_ptr] "+r"(out5_ptr), [weights_ptr] "+r"(weights_ptr), + [col_deltas_bytes] "+r"(col_deltas_bytes), [bias_ptr] "+r"(bias_ptr), + [nnz_per_row] "+r"(nnz_per_row), [assigned_rows] "+r"(assigned_rows), + [rhs_ptr] "+r"(rhs_ptr), [rhs2_ptr] "+r"(rhs2_ptr), + [rhs3_ptr] "+r"(rhs3_ptr), [rhs4_ptr] "+r"(rhs4_ptr), + [rhs5_ptr] "+r"(rhs5_ptr) + : // inputs + : // clobbers + "cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", + "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", + "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31"); + // clang-format on + } else { + asm( + // Load the first two column deltas. + "ldrsh x7, [%[col_deltas_bytes]], #2\n" + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + // ld1 doesn't support pre-index, so we do the first addition here. + "add %[rhs_ptr], %[rhs_ptr], x7\n" + "add %[rhs2_ptr], %[rhs2_ptr], x7\n" + "add %[rhs3_ptr], %[rhs3_ptr], x7\n" + "add %[rhs4_ptr], %[rhs4_ptr], x7\n" + "add %[rhs5_ptr], %[rhs5_ptr], x7\n" + + LABEL_ROW_LOOP + ":\n" + + // Load the bias. + "ld1 {v27.4s}, [%[bias_ptr]], #16\n" + + // Zero out local accumulators. + "dup v28.4s, v27.s[0]\n" // for 1st column + "dup v29.4s, v27.s[1]\n" // for 1st column + "dup v30.4s, v27.s[2]\n" // for 1st column + "dup v31.4s, v27.s[3]\n" // for 1st column + "dup v23.4s, v27.s[0]\n" // for 2nd column + "dup v24.4s, v27.s[1]\n" // for 2nd column + "dup v25.4s, v27.s[2]\n" // for 2nd column + "dup v26.4s, v27.s[3]\n" // for 2nd column + "dup v19.4s, v27.s[0]\n" // for 3rd column + "dup v20.4s, v27.s[1]\n" // for 3rd column + "dup v21.4s, v27.s[2]\n" // for 3rd column + "dup v22.4s, v27.s[3]\n" // for 3rd column + "dup v15.4s, v27.s[0]\n" // for 4th column + "dup v16.4s, v27.s[1]\n" // for 4th column + "dup v17.4s, v27.s[2]\n" // for 4th column + "dup v18.4s, v27.s[3]\n" // for 4th column + "dup v11.4s, v27.s[0]\n" // for 5th column + "dup v12.4s, v27.s[1]\n" // for 5th column + "dup v13.4s, v27.s[2]\n" // for 5th column + "dup v14.4s, v27.s[3]\n" // for 5th column + + // Update the stopping condition for this set of rows. + "ldr w6, [%[nnz_per_row]], #4\n" + "cmp w6, #0\n" + // Skip the body if there isn't anything in this row. + "beq " LABEL_SKIP_COL_LOOP "f\n" + + LABEL_COL_LOOP + ":\n" + // Load 1 Rhs vectors of size 1x4 each and duplicate into upper half. + "ld1 {v0.4h}, [%[rhs_ptr]], x8\n" + "mov v0.d[1], v0.d[0]\n" + "ld1 {v1.4h}, [%[rhs2_ptr]], x8\n" + "mov v1.d[1], v1.d[0]\n" + "ld1 {v8.4h}, [%[rhs3_ptr]], x8\n" + "mov v8.d[1], v8.d[0]\n" + "ld1 {v9.4h}, [%[rhs4_ptr]], x8\n" + "mov v9.d[1], v9.d[0]\n" + "ld1 {v10.4h}, [%[rhs5_ptr]], x8\n" + "mov v10.d[1], v10.d[0]\n" + + // Start this load now, which we won't need until the end of the loop. + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + + // Load 16 Lhs cells corresponding to a 4x4 block. + "ld1 {v2.8h, v3.8h}, [%[weights_ptr]], #32\n" + + // Multiply-accumulate. + "smlal v28.4s, v2.4h, v0.4h\n" // for 1st column + "smlal2 v29.4s, v2.8h, v0.8h\n" // for 1st column + "smlal v30.4s, v3.4h, v0.4h\n" // for 1st column + "smlal2 v31.4s, v3.8h, v0.8h\n" // for 1st columh + "smlal v23.4s, v2.4h, v1.4h\n" // for 2nd column + "smlal2 v24.4s, v2.8h, v1.8h\n" // for 2nd column + "smlal v25.4s, v3.4h, v1.4h\n" // for 2nd column + "smlal2 v26.4s, v3.8h, v1.8h\n" // for 2nd column + "smlal v19.4s, v2.4h, v8.4h\n" // for 3rd column + "smlal2 v20.4s, v2.8h, v8.8h\n" // for 3rd column + "smlal v21.4s, v3.4h, v8.4h\n" // for 3rd column + "smlal2 v22.4s, v3.8h, v8.8h\n" // for 3rd column + "smlal v15.4s, v2.4h, v9.4h\n" // for 4th column + "smlal2 v16.4s, v2.8h, v9.8h\n" // for 4th column + "smlal v17.4s, v3.4h, v9.4h\n" // for 4th column + "smlal2 v18.4s, v3.8h, v9.8h\n" // for 4th column + "smlal v11.4s, v2.4h, v10.4h\n" // for 5th column + "smlal2 v12.4s, v2.8h, v10.8h\n" // for 5th column + "smlal v13.4s, v3.4h, v10.4h\n" // for 5th column + "smlal2 v14.4s, v3.8h, v10.8h\n" // for 5th column + + // Loop. Decrement loop index. + "subs w6, w6, #1\n" // decrement (reduced) columns left + "bne " LABEL_COL_LOOP "b\n" + + LABEL_SKIP_COL_LOOP + ":\n" + + "addp v28.4s, v28.4s, v29.4s\n" // 1st column + "addp v23.4s, v23.4s, v24.4s\n" // 2nd column + "addp v19.4s, v19.4s, v20.4s\n" // 3rd column + "addp v15.4s, v15.4s, v16.4s\n" // 4th column + "addp v11.4s, v11.4s, v12.4s\n" // 5th column + + "addp v30.4s, v30.4s, v31.4s\n" // 1st column + "addp v25.4s, v25.4s, v26.4s\n" // 2nd column + "addp v21.4s, v21.4s, v22.4s\n" // 3rd column + "addp v17.4s, v17.4s, v18.4s\n" // 4th column + "addp v13.4s, v13.4s, v14.4s\n" // 5th column + + "addp v28.4s, v28.4s, v30.4s\n" // 1st column + "addp v23.4s, v23.4s, v25.4s\n" // 2nd column + "addp v19.4s, v19.4s, v21.4s\n" // 3rd column + "addp v15.4s, v15.4s, v17.4s\n" // 4th column + "addp v11.4s, v11.4s, v13.4s\n" // 5th column + + // Store accumulators. + "st1 {v28.4s}, [%[out_ptr]], #16\n" + "st1 {v23.4s}, [%[out2_ptr]], #16\n" + "st1 {v19.4s}, [%[out3_ptr]], #16\n" + "st1 {v15.4s}, [%[out4_ptr]], #16\n" + "st1 {v11.4s}, [%[out5_ptr]], #16\n" + + // Decrement rows remaining. + "subs %[assigned_rows], %[assigned_rows], #1\n" + "bne " LABEL_ROW_LOOP "b\n" + + // clang-format off + : // outputs + [out_ptr] "+r"(out_ptr), [out2_ptr] "+r"(out2_ptr), + [out3_ptr] "+r"(out3_ptr), [out4_ptr] "+r"(out4_ptr), + [out5_ptr] "+r"(out5_ptr), [weights_ptr] "+r"(weights_ptr), + [col_deltas_bytes] "+r"(col_deltas_bytes), [bias_ptr] "+r"(bias_ptr), + [nnz_per_row] "+r"(nnz_per_row), [assigned_rows] "+r"(assigned_rows), + [rhs_ptr] "+r"(rhs_ptr), [rhs2_ptr] "+r"(rhs2_ptr), + [rhs3_ptr] "+r"(rhs3_ptr), [rhs4_ptr] "+r"(rhs4_ptr), + [rhs5_ptr] "+r"(rhs5_ptr) + : // inputs + : // clobbers + "cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", + "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", + "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31"); + // clang-format on + } +} + +// Note that the number of exponent bits in the bias must exactly match +// the sum of the input and rhs types. +template +typename std::enable_if::value && + IsFixed16Type::value && + IsFixed16Type::value>::type +SpMV_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes, + const int32_t* nnz_per_row, const RhsType* rhs_ptr, + const typename TypeOfProduct::type* bias_ptr, + OutType* out_ptr, int64_t assigned_rows, + int64_t rows /* only used in SpMM variants */, + int64_t cols /* only used in SpMM variants */, int relu) { + constexpr int kShiftAmount = 15 - WeightType::kExponentBits - + RhsType::kExponentBits + OutType::kExponentBits; + if (relu) { + asm( + // Load the first two column deltas. + "ldrsh x7, [%[col_deltas_bytes]], #2\n" + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + // ld1 doesn't support pre-index, so we do the first addition here. + "add %[rhs_ptr], %[rhs_ptr], x7\n" + + "movi v25.4s, #0\n" + + LABEL_ROW_LOOP + ":\n" + + // Load the bias. + "ld1 {v27.4s}, [%[bias_ptr]], #16\n" + + // Zero out local accumulators. + "dup v28.4s, v27.s[0]\n" // accum_0 = 0 + "dup v29.4s, v27.s[1]\n" // accum_1 = 0 + "dup v30.4s, v27.s[2]\n" // accum_2 = 0 + "dup v31.4s, v27.s[3]\n" // accum_3 = 0 + + // Update the stopping condition for this set of rows. + "ldr w6, [%[nnz_per_row]], #4\n" + "cmp w6, #0\n" + // Skip the body if there isn't anything in this row. + "beq " LABEL_SKIP_COL_LOOP "f\n" + + LABEL_COL_LOOP + ":\n" + // Load 1 Rhs vectors of size 1x4 each. + "ld1 {v0.4h}, [%[rhs_ptr]], x8\n" + // Duplicate the lower half into the upper half. + "mov v0.d[1], v0.d[0]\n" + + // Start this load now, which we won't need until the end of the loop. + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + + // Load 16 Lhs cells corresponding to a 4x4 block. + "ld1 {v2.8h, v3.8h}, [%[weights_ptr]], #32\n" + + // Multiply-accumulate. + "smlal v28.4s, v2.4h, v0.4h\n" + "smlal2 v29.4s, v2.8h, v0.8h\n" + "smlal v30.4s, v3.4h, v0.4h\n" + "smlal2 v31.4s, v3.8h, v0.8h\n" + + // Loop. Decrement loop index. + "subs w6, w6, #1\n" // decrement (reduced) columns left + "bne " LABEL_COL_LOOP "b\n" + + LABEL_SKIP_COL_LOOP + ":\n" + + // Horizontally add accumulators and store result. + "addp v28.4s, v28.4s, v29.4s\n" + "addp v30.4s, v30.4s, v31.4s\n" + "addp v28.4s, v28.4s, v30.4s\n" + + // Do relu if requested. + "smax v28.4s, v28.4s, v25.4s\n" + "sqrshrn v26.4h, v28.4s, %[shift_amount]\n" + + // Store accumulators. + "st1 {v26.4h}, [%[out_ptr]], #8\n" + + // Decrement rows remaining. + "subs %[assigned_rows], %[assigned_rows], #1\n" + "bne " LABEL_ROW_LOOP "b\n" + + // clang-format off + : // outputs + [out_ptr] "+r"(out_ptr), [weights_ptr] "+r"(weights_ptr), + [col_deltas_bytes] "+r"(col_deltas_bytes), [bias_ptr] "+r"(bias_ptr), + [nnz_per_row] "+r"(nnz_per_row), [assigned_rows] "+r"(assigned_rows), + [rhs_ptr] "+r"(rhs_ptr) + : // inputs + [shift_amount] "I"(kShiftAmount) + : // clobbers + "cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", + "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", + "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31"); + // clang-format on + } else { + asm( + // Load the first two column deltas. + "ldrsh x7, [%[col_deltas_bytes]], #2\n" + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + // ld1 doesn't support pre-index, so we do the first addition here. + "add %[rhs_ptr], %[rhs_ptr], x7\n" + + "movi v25.4s, #0\n" + + LABEL_ROW_LOOP + ":\n" + + // Load the bias. + "ld1 {v27.4s}, [%[bias_ptr]], #16\n" + + // Zero out local accumulators. + "dup v28.4s, v27.s[0]\n" // accum_0 = 0 + "dup v29.4s, v27.s[1]\n" // accum_1 = 0 + "dup v30.4s, v27.s[2]\n" // accum_2 = 0 + "dup v31.4s, v27.s[3]\n" // accum_3 = 0 + + // Update the stopping condition for this set of rows. + "ldr w6, [%[nnz_per_row]], #4\n" + "cmp w6, #0\n" + // Skip the body if there isn't anything in this row. + "beq " LABEL_SKIP_COL_LOOP "f\n" + + LABEL_COL_LOOP + ":\n" + // Load 1 Rhs vectors of size 1x4 each. + "ld1 {v0.4h}, [%[rhs_ptr]], x8\n" + // Duplicate the lower half into the upper half. + "mov v0.d[1], v0.d[0]\n" + + // Start this load now, which we won't need until the end of the loop. + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + + // Load 16 Lhs cells corresponding to a 4x4 block. + "ld1 {v2.8h, v3.8h}, [%[weights_ptr]], #32\n" + + // Multiply-accumulate. + "smlal v28.4s, v2.4h, v0.4h\n" + "smlal2 v29.4s, v2.8h, v0.8h\n" + "smlal v30.4s, v3.4h, v0.4h\n" + "smlal2 v31.4s, v3.8h, v0.8h\n" + + // Loop. Decrement loop index. + "subs w6, w6, #1\n" // decrement (reduced) columns left + "bne " LABEL_COL_LOOP "b\n" + + LABEL_SKIP_COL_LOOP + ":\n" + + // Horizontally add accumulators and store result. + "addp v28.4s, v28.4s, v29.4s\n" + "addp v30.4s, v30.4s, v31.4s\n" + "addp v28.4s, v28.4s, v30.4s\n" + "sqrshrn v26.4h, v28.4s, %[shift_amount]\n" + + // Store accumulators. + "st1 {v26.4h}, [%[out_ptr]], #8\n" + + // Decrement rows remaining. + "subs %[assigned_rows], %[assigned_rows], #1\n" + "bne " LABEL_ROW_LOOP "b\n" + + // clang-format off + : // outputs + [out_ptr] "+r"(out_ptr), [weights_ptr] "+r"(weights_ptr), + [col_deltas_bytes] "+r"(col_deltas_bytes), [bias_ptr] "+r"(bias_ptr), + [nnz_per_row] "+r"(nnz_per_row), [assigned_rows] "+r"(assigned_rows), + [rhs_ptr] "+r"(rhs_ptr) + : // inputs + [shift_amount] "I"(kShiftAmount) + : // clobbers + "cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", + "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", + "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31"); + // clang-format on + } +} + +// Note that the number of exponent bits in the output must exactly match +// the sum of the input and rhs types. +template +typename std::enable_if::value && + IsFixed16Type::value && + IsFixed16Type::value>::type +SpMM5_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes, + const int32_t* nnz_per_row, const RhsType* rhs_ptr, + const typename TypeOfProduct::type* bias_ptr, + OutType* out_ptr, int64_t assigned_rows, int64_t rows, int64_t cols, + int relu) { + constexpr int kShiftAmount = 15 - WeightType::kExponentBits - + RhsType::kExponentBits + OutType::kExponentBits; + // Pointers to the columns. + const RhsType* rhs2_ptr = rhs_ptr + cols; + OutType* out2_ptr = out_ptr + rows; + const RhsType* rhs3_ptr = rhs_ptr + 2 * cols; + OutType* out3_ptr = out_ptr + 2 * rows; + const RhsType* rhs4_ptr = rhs_ptr + 3 * cols; + OutType* out4_ptr = out_ptr + 3 * rows; + const RhsType* rhs5_ptr = rhs_ptr + 4 * cols; + OutType* out5_ptr = out_ptr + 4 * rows; + if (relu) { + asm( + // Load the first two column deltas. + "ldrsh x7, [%[col_deltas_bytes]], #2\n" + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + // ld1 doesn't support pre-index, so we do the first addition here. + "add %[rhs_ptr], %[rhs_ptr], x7\n" + "add %[rhs2_ptr], %[rhs2_ptr], x7\n" + "add %[rhs3_ptr], %[rhs3_ptr], x7\n" + "add %[rhs4_ptr], %[rhs4_ptr], x7\n" + "add %[rhs5_ptr], %[rhs5_ptr], x7\n" + + LABEL_ROW_LOOP + ":\n" + + // Load the bias. + "ld1 {v27.4s}, [%[bias_ptr]], #16\n" + + // Zero out local accumulators. + "dup v28.4s, v27.s[0]\n" // for 1st column + "dup v29.4s, v27.s[1]\n" // for 1st column + "dup v30.4s, v27.s[2]\n" // for 1st column + "dup v31.4s, v27.s[3]\n" // for 1st column + "dup v23.4s, v27.s[0]\n" // for 2nd column + "dup v24.4s, v27.s[1]\n" // for 2nd column + "dup v25.4s, v27.s[2]\n" // for 2nd column + "dup v26.4s, v27.s[3]\n" // for 2nd column + "dup v19.4s, v27.s[0]\n" // for 3rd column + "dup v20.4s, v27.s[1]\n" // for 3rd column + "dup v21.4s, v27.s[2]\n" // for 3rd column + "dup v22.4s, v27.s[3]\n" // for 3rd column + "dup v15.4s, v27.s[0]\n" // for 4th column + "dup v16.4s, v27.s[1]\n" // for 4th column + "dup v17.4s, v27.s[2]\n" // for 4th column + "dup v18.4s, v27.s[3]\n" // for 4th column + "dup v11.4s, v27.s[0]\n" // for 5th column + "dup v12.4s, v27.s[1]\n" // for 5th column + "dup v13.4s, v27.s[2]\n" // for 5th column + "dup v14.4s, v27.s[3]\n" // for 5th column + + // Update the stopping condition for this set of rows. + "ldr w6, [%[nnz_per_row]], #4\n" + "cmp w6, #0\n" + // Skip the body if there isn't anything in this row. + "beq " LABEL_SKIP_COL_LOOP "f\n" + + LABEL_COL_LOOP + ":\n" + // Load 1 Rhs vectors of size 1x4 each and duplicate into upper half. + "ld1 {v0.4h}, [%[rhs_ptr]], x8\n" + "mov v0.d[1], v0.d[0]\n" + "ld1 {v1.4h}, [%[rhs2_ptr]], x8\n" + "mov v1.d[1], v1.d[0]\n" + "ld1 {v8.4h}, [%[rhs3_ptr]], x8\n" + "mov v8.d[1], v8.d[0]\n" + "ld1 {v9.4h}, [%[rhs4_ptr]], x8\n" + "mov v9.d[1], v9.d[0]\n" + "ld1 {v10.4h}, [%[rhs5_ptr]], x8\n" + "mov v10.d[1], v10.d[0]\n" + + // Start this load now, which we won't need until the end of the loop. + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + + // Load 16 Lhs cells corresponding to a 4x4 block. + "ld1 {v2.8h, v3.8h}, [%[weights_ptr]], #32\n" + + // Multiply-accumulate. + "smlal v28.4s, v2.4h, v0.4h\n" // for 1st column + "smlal2 v29.4s, v2.8h, v0.8h\n" // for 1st column + "smlal v30.4s, v3.4h, v0.4h\n" // for 1st column + "smlal2 v31.4s, v3.8h, v0.8h\n" // for 1st columh + "smlal v23.4s, v2.4h, v1.4h\n" // for 2nd column + "smlal2 v24.4s, v2.8h, v1.8h\n" // for 2nd column + "smlal v25.4s, v3.4h, v1.4h\n" // for 2nd column + "smlal2 v26.4s, v3.8h, v1.8h\n" // for 2nd column + "smlal v19.4s, v2.4h, v8.4h\n" // for 3rd column + "smlal2 v20.4s, v2.8h, v8.8h\n" // for 3rd column + "smlal v21.4s, v3.4h, v8.4h\n" // for 3rd column + "smlal2 v22.4s, v3.8h, v8.8h\n" // for 3rd column + "smlal v15.4s, v2.4h, v9.4h\n" // for 4th column + "smlal2 v16.4s, v2.8h, v9.8h\n" // for 4th column + "smlal v17.4s, v3.4h, v9.4h\n" // for 4th column + "smlal2 v18.4s, v3.8h, v9.8h\n" // for 4th column + "smlal v11.4s, v2.4h, v10.4h\n" // for 5th column + "smlal2 v12.4s, v2.8h, v10.8h\n" // for 5th column + "smlal v13.4s, v3.4h, v10.4h\n" // for 5th column + "smlal2 v14.4s, v3.8h, v10.8h\n" // for 5th column + + // Loop. Decrement loop index. + "subs w6, w6, #1\n" // decrement (reduced) columns left + "bne " LABEL_COL_LOOP "b\n" + + LABEL_SKIP_COL_LOOP + ":\n" + + "movi v0.4s, #0\n" + "addp v28.4s, v28.4s, v29.4s\n" // 1st column + "addp v23.4s, v23.4s, v24.4s\n" // 2nd column + "addp v19.4s, v19.4s, v20.4s\n" // 3rd column + "addp v15.4s, v15.4s, v16.4s\n" // 4th column + "addp v11.4s, v11.4s, v12.4s\n" // 5th column + + "addp v30.4s, v30.4s, v31.4s\n" // 1st column + "addp v25.4s, v25.4s, v26.4s\n" // 2nd column + "addp v21.4s, v21.4s, v22.4s\n" // 3rd column + "addp v17.4s, v17.4s, v18.4s\n" // 4th column + "addp v13.4s, v13.4s, v14.4s\n" // 5th column + + "addp v28.4s, v28.4s, v30.4s\n" // 1st column + "addp v23.4s, v23.4s, v25.4s\n" // 2nd column + "addp v19.4s, v19.4s, v21.4s\n" // 3rd column + "addp v15.4s, v15.4s, v17.4s\n" // 4th column + "addp v11.4s, v11.4s, v13.4s\n" // 5th column + + // Do relu as requested. + "smax v28.4s, v28.4s, v0.4s\n" + "smax v23.4s, v23.4s, v0.4s\n" + "smax v19.4s, v19.4s, v0.4s\n" + "smax v15.4s, v15.4s, v0.4s\n" + "smax v11.4s, v11.4s, v0.4s\n" + "sqrshrn v26.4h, v28.4s, %[shift_amount]\n" + "sqrshrn v22.4h, v23.4s, %[shift_amount]\n" + "sqrshrn v18.4h, v19.4s, %[shift_amount]\n" + "sqrshrn v14.4h, v15.4s, %[shift_amount]\n" + "sqrshrn v10.4h, v11.4s, %[shift_amount]\n" + + // Store accumulators. + "st1 {v26.4h}, [%[out_ptr]], #8\n" + "st1 {v22.4h}, [%[out2_ptr]], #8\n" + "st1 {v18.4h}, [%[out3_ptr]], #8\n" + "st1 {v14.4h}, [%[out4_ptr]], #8\n" + "st1 {v10.4h}, [%[out5_ptr]], #8\n" + + // Decrement rows remaining. + "subs %[assigned_rows], %[assigned_rows], #1\n" + "bne " LABEL_ROW_LOOP "b\n" + + // clang-format off + : // outputs + [out_ptr] "+r"(out_ptr), [out2_ptr] "+r"(out2_ptr), + [out3_ptr] "+r"(out3_ptr), [out4_ptr] "+r"(out4_ptr), + [out5_ptr] "+r"(out5_ptr), [weights_ptr] "+r"(weights_ptr), + [col_deltas_bytes] "+r"(col_deltas_bytes), [bias_ptr] "+r"(bias_ptr), + [nnz_per_row] "+r"(nnz_per_row), [assigned_rows] "+r"(assigned_rows), + [rhs_ptr] "+r"(rhs_ptr), [rhs2_ptr] "+r"(rhs2_ptr), + [rhs3_ptr] "+r"(rhs3_ptr), [rhs4_ptr] "+r"(rhs4_ptr), + [rhs5_ptr] "+r"(rhs5_ptr) + : // inputs + [shift_amount] "I"(kShiftAmount) + : // clobbers + "cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", + "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", + "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31"); + // clang-format on + } else { + asm( + // Load the first two column deltas. + "ldrsh x7, [%[col_deltas_bytes]], #2\n" + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + // ld1 doesn't support pre-index, so we do the first addition here. + "add %[rhs_ptr], %[rhs_ptr], x7\n" + "add %[rhs2_ptr], %[rhs2_ptr], x7\n" + "add %[rhs3_ptr], %[rhs3_ptr], x7\n" + "add %[rhs4_ptr], %[rhs4_ptr], x7\n" + "add %[rhs5_ptr], %[rhs5_ptr], x7\n" + + LABEL_ROW_LOOP + ":\n" + + // Load the bias. + "ld1 {v27.4s}, [%[bias_ptr]], #16\n" + + // Zero out local accumulators. + "dup v28.4s, v27.s[0]\n" // for 1st column + "dup v29.4s, v27.s[1]\n" // for 1st column + "dup v30.4s, v27.s[2]\n" // for 1st column + "dup v31.4s, v27.s[3]\n" // for 1st column + "dup v23.4s, v27.s[0]\n" // for 2nd column + "dup v24.4s, v27.s[1]\n" // for 2nd column + "dup v25.4s, v27.s[2]\n" // for 2nd column + "dup v26.4s, v27.s[3]\n" // for 2nd column + "dup v19.4s, v27.s[0]\n" // for 3rd column + "dup v20.4s, v27.s[1]\n" // for 3rd column + "dup v21.4s, v27.s[2]\n" // for 3rd column + "dup v22.4s, v27.s[3]\n" // for 3rd column + "dup v15.4s, v27.s[0]\n" // for 4th column + "dup v16.4s, v27.s[1]\n" // for 4th column + "dup v17.4s, v27.s[2]\n" // for 4th column + "dup v18.4s, v27.s[3]\n" // for 4th column + "dup v11.4s, v27.s[0]\n" // for 5th column + "dup v12.4s, v27.s[1]\n" // for 5th column + "dup v13.4s, v27.s[2]\n" // for 5th column + "dup v14.4s, v27.s[3]\n" // for 5th column + + // Update the stopping condition for this set of rows. + "ldr w6, [%[nnz_per_row]], #4\n" + "cmp w6, #0\n" + // Skip the body if there isn't anything in this row. + "beq " LABEL_SKIP_COL_LOOP "f\n" + + LABEL_COL_LOOP + ":\n" + // Load 1 Rhs vectors of size 1x4 each and duplicate into upper half. + "ld1 {v0.4h}, [%[rhs_ptr]], x8\n" + "mov v0.d[1], v0.d[0]\n" + "ld1 {v1.4h}, [%[rhs2_ptr]], x8\n" + "mov v1.d[1], v1.d[0]\n" + "ld1 {v8.4h}, [%[rhs3_ptr]], x8\n" + "mov v8.d[1], v8.d[0]\n" + "ld1 {v9.4h}, [%[rhs4_ptr]], x8\n" + "mov v9.d[1], v9.d[0]\n" + "ld1 {v10.4h}, [%[rhs5_ptr]], x8\n" + "mov v10.d[1], v10.d[0]\n" + + // Start this load now, which we won't need until the end of the loop. + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + + // Load 16 Lhs cells corresponding to a 4x4 block. + "ld1 {v2.8h, v3.8h}, [%[weights_ptr]], #32\n" + + // Multiply-accumulate. + "smlal v28.4s, v2.4h, v0.4h\n" // for 1st column + "smlal2 v29.4s, v2.8h, v0.8h\n" // for 1st column + "smlal v30.4s, v3.4h, v0.4h\n" // for 1st column + "smlal2 v31.4s, v3.8h, v0.8h\n" // for 1st columh + "smlal v23.4s, v2.4h, v1.4h\n" // for 2nd column + "smlal2 v24.4s, v2.8h, v1.8h\n" // for 2nd column + "smlal v25.4s, v3.4h, v1.4h\n" // for 2nd column + "smlal2 v26.4s, v3.8h, v1.8h\n" // for 2nd column + "smlal v19.4s, v2.4h, v8.4h\n" // for 3rd column + "smlal2 v20.4s, v2.8h, v8.8h\n" // for 3rd column + "smlal v21.4s, v3.4h, v8.4h\n" // for 3rd column + "smlal2 v22.4s, v3.8h, v8.8h\n" // for 3rd column + "smlal v15.4s, v2.4h, v9.4h\n" // for 4th column + "smlal2 v16.4s, v2.8h, v9.8h\n" // for 4th column + "smlal v17.4s, v3.4h, v9.4h\n" // for 4th column + "smlal2 v18.4s, v3.8h, v9.8h\n" // for 4th column + "smlal v11.4s, v2.4h, v10.4h\n" // for 5th column + "smlal2 v12.4s, v2.8h, v10.8h\n" // for 5th column + "smlal v13.4s, v3.4h, v10.4h\n" // for 5th column + "smlal2 v14.4s, v3.8h, v10.8h\n" // for 5th column + + // Loop. Decrement loop index. + "subs w6, w6, #1\n" // decrement (reduced) columns left + "bne " LABEL_COL_LOOP "b\n" + + LABEL_SKIP_COL_LOOP + ":\n" + + "addp v28.4s, v28.4s, v29.4s\n" // 1st column + "addp v23.4s, v23.4s, v24.4s\n" // 2nd column + "addp v19.4s, v19.4s, v20.4s\n" // 3rd column + "addp v15.4s, v15.4s, v16.4s\n" // 4th column + "addp v11.4s, v11.4s, v12.4s\n" // 5th column + + "addp v30.4s, v30.4s, v31.4s\n" // 1st column + "addp v25.4s, v25.4s, v26.4s\n" // 2nd column + "addp v21.4s, v21.4s, v22.4s\n" // 3rd column + "addp v17.4s, v17.4s, v18.4s\n" // 4th column + "addp v13.4s, v13.4s, v14.4s\n" // 5th column + + "addp v28.4s, v28.4s, v30.4s\n" // 1st column + "addp v23.4s, v23.4s, v25.4s\n" // 2nd column + "addp v19.4s, v19.4s, v21.4s\n" // 3rd column + "addp v15.4s, v15.4s, v17.4s\n" // 4th column + "addp v11.4s, v11.4s, v13.4s\n" // 5th column + + "sqrshrn v26.4h, v28.4s, %[shift_amount]\n" + "sqrshrn v22.4h, v23.4s, %[shift_amount]\n" + "sqrshrn v18.4h, v19.4s, %[shift_amount]\n" + "sqrshrn v14.4h, v15.4s, %[shift_amount]\n" + "sqrshrn v10.4h, v11.4s, %[shift_amount]\n" + + // Store accumulators. + "st1 {v26.4h}, [%[out_ptr]], #8\n" + "st1 {v22.4h}, [%[out2_ptr]], #8\n" + "st1 {v18.4h}, [%[out3_ptr]], #8\n" + "st1 {v14.4h}, [%[out4_ptr]], #8\n" + "st1 {v10.4h}, [%[out5_ptr]], #8\n" + + // Decrement rows remaining. + "subs %[assigned_rows], %[assigned_rows], #1\n" + "bne " LABEL_ROW_LOOP "b\n" + + // clang-format off + : // outputs + [out_ptr] "+r"(out_ptr), [out2_ptr] "+r"(out2_ptr), + [out3_ptr] "+r"(out3_ptr), [out4_ptr] "+r"(out4_ptr), + [out5_ptr] "+r"(out5_ptr), [weights_ptr] "+r"(weights_ptr), + [col_deltas_bytes] "+r"(col_deltas_bytes), [bias_ptr] "+r"(bias_ptr), + [nnz_per_row] "+r"(nnz_per_row), [assigned_rows] "+r"(assigned_rows), + [rhs_ptr] "+r"(rhs_ptr), [rhs2_ptr] "+r"(rhs2_ptr), + [rhs3_ptr] "+r"(rhs3_ptr), [rhs4_ptr] "+r"(rhs4_ptr), + [rhs5_ptr] "+r"(rhs5_ptr) + : // inputs + [shift_amount] "I"(kShiftAmount) + : // clobbers + "cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", + "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", + "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31"); + // clang-format on + } +} + +// Note that the number of exponent bits in the output must exactly match +// the sum of the input and rhs types. +template +typename std::enable_if< + IsFixed16Type::value && IsFixed16Type::value && + IsFixed32Type::value && + !std::is_same::type>::value>::type +SpMV_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes, + const int32_t* nnz_per_row, const RhsType* rhs_ptr, + const typename TypeOfProduct::type* bias_ptr, + OutType* out_ptr, int64_t assigned_rows, + int64_t rows /* only used in SpMM variants */, + int64_t cols /* only used in SpMM variants */, int relu) { + constexpr int kShiftAmount = + TypeOfProduct::type::kMantissaBits - + OutType::kMantissaBits; + static_assert(kShiftAmount > 0, + "Result must have fewer mantissa bits than product"); + if (relu) { + asm( + // Load the first two column deltas. + "ldrsh x7, [%[col_deltas_bytes]], #2\n" + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + // ld1 doesn't support pre-index, so we do the first addition here. + "add %[rhs_ptr], %[rhs_ptr], x7\n" + + "movi v25.4s, #0\n" + + LABEL_ROW_LOOP + ":\n" + + // Load the bias. + "ld1 {v27.4s}, [%[bias_ptr]], #16\n" + + // Zero out local accumulators. + "dup v28.4s, v27.s[0]\n" // accum_0 = 0 + "dup v29.4s, v27.s[1]\n" // accum_1 = 0 + "dup v30.4s, v27.s[2]\n" // accum_2 = 0 + "dup v31.4s, v27.s[3]\n" // accum_3 = 0 + + // Update the stopping condition for this set of rows. + "ldr w6, [%[nnz_per_row]], #4\n" + "cmp w6, #0\n" + // Skip the body if there isn't anything in this row. + "beq " LABEL_SKIP_COL_LOOP "f\n" + + LABEL_COL_LOOP + ":\n" + // Load 1 Rhs vectors of size 1x4 each. + "ld1 {v0.4h}, [%[rhs_ptr]], x8\n" + // Duplicate the lower half into the upper half. + "mov v0.d[1], v0.d[0]\n" + + // Start this load now, which we won't need until the end of the loop. + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + + // Load 16 Lhs cells corresponding to a 4x4 block. + "ld1 {v2.8h, v3.8h}, [%[weights_ptr]], #32\n" + + // Multiply-accumulate. + "smlal v28.4s, v2.4h, v0.4h\n" + "smlal2 v29.4s, v2.8h, v0.8h\n" + "smlal v30.4s, v3.4h, v0.4h\n" + "smlal2 v31.4s, v3.8h, v0.8h\n" + + // Loop. Decrement loop index. + "subs w6, w6, #1\n" // decrement (reduced) columns left + "bne " LABEL_COL_LOOP "b\n" + + LABEL_SKIP_COL_LOOP + ":\n" + + // Horizontally add accumulators and store result. + "addp v28.4s, v28.4s, v29.4s\n" + "addp v30.4s, v30.4s, v31.4s\n" + "addp v28.4s, v28.4s, v30.4s\n" + + // Do relu if requested. + "smax v28.4s, v28.4s, v25.4s\n" + "srshr v28.4s, v28.4s, %[shift_amount]\n" + + // Store accumulators. + "st1 {v28.4s}, [%[out_ptr]], #16\n" + + // Decrement rows remaining. + "subs %[assigned_rows], %[assigned_rows], #1\n" + "bne " LABEL_ROW_LOOP "b\n" + + // clang-format off + : // outputs + [out_ptr] "+r"(out_ptr), [weights_ptr] "+r"(weights_ptr), + [col_deltas_bytes] "+r"(col_deltas_bytes), [bias_ptr] "+r"(bias_ptr), + [nnz_per_row] "+r"(nnz_per_row), [assigned_rows] "+r"(assigned_rows), + [rhs_ptr] "+r"(rhs_ptr) + : // inputs + [shift_amount] "I"(kShiftAmount) + : // clobbers + "cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", + "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", + "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31"); + // clang-format on + } else { + asm( + // Load the first two column deltas. + "ldrsh x7, [%[col_deltas_bytes]], #2\n" + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + // ld1 doesn't support pre-index, so we do the first addition here. + "add %[rhs_ptr], %[rhs_ptr], x7\n" + + "movi v25.4s, #0\n" + + LABEL_ROW_LOOP + ":\n" + + // Load the bias. + "ld1 {v27.4s}, [%[bias_ptr]], #16\n" + + // Zero out local accumulators. + "dup v28.4s, v27.s[0]\n" // accum_0 = 0 + "dup v29.4s, v27.s[1]\n" // accum_1 = 0 + "dup v30.4s, v27.s[2]\n" // accum_2 = 0 + "dup v31.4s, v27.s[3]\n" // accum_3 = 0 + + // Update the stopping condition for this set of rows. + "ldr w6, [%[nnz_per_row]], #4\n" + "cmp w6, #0\n" + // Skip the body if there isn't anything in this row. + "beq " LABEL_SKIP_COL_LOOP "f\n" + + LABEL_COL_LOOP + ":\n" + // Load 1 Rhs vectors of size 1x4 each. + "ld1 {v0.4h}, [%[rhs_ptr]], x8\n" + // Duplicate the lower half into the upper half. + "mov v0.d[1], v0.d[0]\n" + + // Start this load now, which we won't need until the end of the loop. + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + + // Load 16 Lhs cells corresponding to a 4x4 block. + "ld1 {v2.8h, v3.8h}, [%[weights_ptr]], #32\n" + + // Multiply-accumulate. + "smlal v28.4s, v2.4h, v0.4h\n" + "smlal2 v29.4s, v2.8h, v0.8h\n" + "smlal v30.4s, v3.4h, v0.4h\n" + "smlal2 v31.4s, v3.8h, v0.8h\n" + + // Loop. Decrement loop index. + "subs w6, w6, #1\n" // decrement (reduced) columns left + "bne " LABEL_COL_LOOP "b\n" + + LABEL_SKIP_COL_LOOP + ":\n" + + // Horizontally add accumulators and store result. + "addp v28.4s, v28.4s, v29.4s\n" + "addp v30.4s, v30.4s, v31.4s\n" + "addp v28.4s, v28.4s, v30.4s\n" + + "srshr v28.4s, v28.4s, %[shift_amount]\n" + // Store accumulators. + "st1 {v28.4s}, [%[out_ptr]], #16\n" + + // Decrement rows remaining. + "subs %[assigned_rows], %[assigned_rows], #1\n" + "bne " LABEL_ROW_LOOP "b\n" + + // clang-format off + : // outputs + [out_ptr] "+r"(out_ptr), [weights_ptr] "+r"(weights_ptr), + [col_deltas_bytes] "+r"(col_deltas_bytes), [bias_ptr] "+r"(bias_ptr), + [nnz_per_row] "+r"(nnz_per_row), [assigned_rows] "+r"(assigned_rows), + [rhs_ptr] "+r"(rhs_ptr) + : // inputs + [shift_amount] "I"(kShiftAmount) + : // clobbers + "cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", + "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", + "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31"); + // clang-format on + } +} + +// Note that the number of exponent bits in the output must exactly match +// the sum of the input and rhs types. +template +typename std::enable_if< + IsFixed16Type::value && IsFixed16Type::value && + IsFixed32Type::value && + !std::is_same::type>::value>::type +SpMM5_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes, + const int32_t* nnz_per_row, const RhsType* rhs_ptr, + const typename TypeOfProduct::type* bias_ptr, + OutType* out_ptr, int64_t assigned_rows, int64_t rows, int64_t cols, + int relu) { + constexpr int kShiftAmount = + TypeOfProduct::type::kMantissaBits - + OutType::kMantissaBits; + static_assert(kShiftAmount > 0, + "Result must have fewer mantissa bits than product"); + // Pointers to the columns. + const RhsType* rhs2_ptr = rhs_ptr + cols; + OutType* out2_ptr = out_ptr + rows; + const RhsType* rhs3_ptr = rhs_ptr + 2 * cols; + OutType* out3_ptr = out_ptr + 2 * rows; + const RhsType* rhs4_ptr = rhs_ptr + 3 * cols; + OutType* out4_ptr = out_ptr + 3 * rows; + const RhsType* rhs5_ptr = rhs_ptr + 4 * cols; + OutType* out5_ptr = out_ptr + 4 * rows; + if (relu) { + asm( + // Load the first two column deltas. + "ldrsh x7, [%[col_deltas_bytes]], #2\n" + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + // ld1 doesn't support pre-index, so we do the first addition here. + "add %[rhs_ptr], %[rhs_ptr], x7\n" + "add %[rhs2_ptr], %[rhs2_ptr], x7\n" + "add %[rhs3_ptr], %[rhs3_ptr], x7\n" + "add %[rhs4_ptr], %[rhs4_ptr], x7\n" + "add %[rhs5_ptr], %[rhs5_ptr], x7\n" + + LABEL_ROW_LOOP + ":\n" + + // Load the bias. + "ld1 {v27.4s}, [%[bias_ptr]], #16\n" + + // Zero out local accumulators. + "dup v28.4s, v27.s[0]\n" // for 1st column + "dup v29.4s, v27.s[1]\n" // for 1st column + "dup v30.4s, v27.s[2]\n" // for 1st column + "dup v31.4s, v27.s[3]\n" // for 1st column + "dup v23.4s, v27.s[0]\n" // for 2nd column + "dup v24.4s, v27.s[1]\n" // for 2nd column + "dup v25.4s, v27.s[2]\n" // for 2nd column + "dup v26.4s, v27.s[3]\n" // for 2nd column + "dup v19.4s, v27.s[0]\n" // for 3rd column + "dup v20.4s, v27.s[1]\n" // for 3rd column + "dup v21.4s, v27.s[2]\n" // for 3rd column + "dup v22.4s, v27.s[3]\n" // for 3rd column + "dup v15.4s, v27.s[0]\n" // for 4th column + "dup v16.4s, v27.s[1]\n" // for 4th column + "dup v17.4s, v27.s[2]\n" // for 4th column + "dup v18.4s, v27.s[3]\n" // for 4th column + "dup v11.4s, v27.s[0]\n" // for 5th column + "dup v12.4s, v27.s[1]\n" // for 5th column + "dup v13.4s, v27.s[2]\n" // for 5th column + "dup v14.4s, v27.s[3]\n" // for 5th column + + // Update the stopping condition for this set of rows. + "ldr w6, [%[nnz_per_row]], #4\n" + "cmp w6, #0\n" + // Skip the body if there isn't anything in this row. + "beq " LABEL_SKIP_COL_LOOP "f\n" + + LABEL_COL_LOOP + ":\n" + // Load 1 Rhs vectors of size 1x4 each and duplicate into upper half. + "ld1 {v0.4h}, [%[rhs_ptr]], x8\n" + "mov v0.d[1], v0.d[0]\n" + "ld1 {v1.4h}, [%[rhs2_ptr]], x8\n" + "mov v1.d[1], v1.d[0]\n" + "ld1 {v8.4h}, [%[rhs3_ptr]], x8\n" + "mov v8.d[1], v8.d[0]\n" + "ld1 {v9.4h}, [%[rhs4_ptr]], x8\n" + "mov v9.d[1], v9.d[0]\n" + "ld1 {v10.4h}, [%[rhs5_ptr]], x8\n" + "mov v10.d[1], v10.d[0]\n" + + // Start this load now, which we won't need until the end of the loop. + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + + // Load 16 Lhs cells corresponding to a 4x4 block. + "ld1 {v2.8h, v3.8h}, [%[weights_ptr]], #32\n" + + // Multiply-accumulate. + "smlal v28.4s, v2.4h, v0.4h\n" // for 1st column + "smlal2 v29.4s, v2.8h, v0.8h\n" // for 1st column + "smlal v30.4s, v3.4h, v0.4h\n" // for 1st column + "smlal2 v31.4s, v3.8h, v0.8h\n" // for 1st columh + "smlal v23.4s, v2.4h, v1.4h\n" // for 2nd column + "smlal2 v24.4s, v2.8h, v1.8h\n" // for 2nd column + "smlal v25.4s, v3.4h, v1.4h\n" // for 2nd column + "smlal2 v26.4s, v3.8h, v1.8h\n" // for 2nd column + "smlal v19.4s, v2.4h, v8.4h\n" // for 3rd column + "smlal2 v20.4s, v2.8h, v8.8h\n" // for 3rd column + "smlal v21.4s, v3.4h, v8.4h\n" // for 3rd column + "smlal2 v22.4s, v3.8h, v8.8h\n" // for 3rd column + "smlal v15.4s, v2.4h, v9.4h\n" // for 4th column + "smlal2 v16.4s, v2.8h, v9.8h\n" // for 4th column + "smlal v17.4s, v3.4h, v9.4h\n" // for 4th column + "smlal2 v18.4s, v3.8h, v9.8h\n" // for 4th column + "smlal v11.4s, v2.4h, v10.4h\n" // for 5th column + "smlal2 v12.4s, v2.8h, v10.8h\n" // for 5th column + "smlal v13.4s, v3.4h, v10.4h\n" // for 5th column + "smlal2 v14.4s, v3.8h, v10.8h\n" // for 5th column + + // Loop. Decrement loop index. + "subs w6, w6, #1\n" // decrement (reduced) columns left + "bne " LABEL_COL_LOOP "b\n" + + LABEL_SKIP_COL_LOOP + ":\n" + + "movi v0.4s, #0\n" + "addp v28.4s, v28.4s, v29.4s\n" // 1st column + "addp v23.4s, v23.4s, v24.4s\n" // 2nd column + "addp v19.4s, v19.4s, v20.4s\n" // 3rd column + "addp v15.4s, v15.4s, v16.4s\n" // 4th column + "addp v11.4s, v11.4s, v12.4s\n" // 5th column + + "addp v30.4s, v30.4s, v31.4s\n" // 1st column + "addp v25.4s, v25.4s, v26.4s\n" // 2nd column + "addp v21.4s, v21.4s, v22.4s\n" // 3rd column + "addp v17.4s, v17.4s, v18.4s\n" // 4th column + "addp v13.4s, v13.4s, v14.4s\n" // 5th column + + "addp v28.4s, v28.4s, v30.4s\n" // 1st column + "addp v23.4s, v23.4s, v25.4s\n" // 2nd column + "addp v19.4s, v19.4s, v21.4s\n" // 3rd column + "addp v15.4s, v15.4s, v17.4s\n" // 4th column + "addp v11.4s, v11.4s, v13.4s\n" // 5th column + + // Do relu as requested. + "smax v28.4s, v28.4s, v0.4s\n" + "smax v23.4s, v23.4s, v0.4s\n" + "smax v19.4s, v19.4s, v0.4s\n" + "smax v15.4s, v15.4s, v0.4s\n" + "smax v11.4s, v11.4s, v0.4s\n" + + "srshr v28.4s, v28.4s, %[shift_amount]\n" + "srshr v23.4s, v23.4s, %[shift_amount]\n" + "srshr v19.4s, v19.4s, %[shift_amount]\n" + "srshr v15.4s, v15.4s, %[shift_amount]\n" + "srshr v11.4s, v11.4s, %[shift_amount]\n" + + // Store accumulators. + "st1 {v28.4s}, [%[out_ptr]], #16\n" + "st1 {v23.4s}, [%[out2_ptr]], #16\n" + "st1 {v19.4s}, [%[out3_ptr]], #16\n" + "st1 {v15.4s}, [%[out4_ptr]], #16\n" + "st1 {v11.4s}, [%[out5_ptr]], #16\n" + + // Decrement rows remaining. + "subs %[assigned_rows], %[assigned_rows], #1\n" + "bne " LABEL_ROW_LOOP "b\n" + + // clang-format off + : // outputs + [out_ptr] "+r"(out_ptr), [out2_ptr] "+r"(out2_ptr), + [out3_ptr] "+r"(out3_ptr), [out4_ptr] "+r"(out4_ptr), + [out5_ptr] "+r"(out5_ptr), [weights_ptr] "+r"(weights_ptr), + [col_deltas_bytes] "+r"(col_deltas_bytes), [bias_ptr] "+r"(bias_ptr), + [nnz_per_row] "+r"(nnz_per_row), [assigned_rows] "+r"(assigned_rows), + [rhs_ptr] "+r"(rhs_ptr), [rhs2_ptr] "+r"(rhs2_ptr), + [rhs3_ptr] "+r"(rhs3_ptr), [rhs4_ptr] "+r"(rhs4_ptr), + [rhs5_ptr] "+r"(rhs5_ptr) + : // inputs + [shift_amount] "I"(kShiftAmount) + : // clobbers + "cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", + "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", + "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31"); + // clang-format on + } else { + asm( + // Load the first two column deltas. + "ldrsh x7, [%[col_deltas_bytes]], #2\n" + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + // ld1 doesn't support pre-index, so we do the first addition here. + "add %[rhs_ptr], %[rhs_ptr], x7\n" + "add %[rhs2_ptr], %[rhs2_ptr], x7\n" + "add %[rhs3_ptr], %[rhs3_ptr], x7\n" + "add %[rhs4_ptr], %[rhs4_ptr], x7\n" + "add %[rhs5_ptr], %[rhs5_ptr], x7\n" + + LABEL_ROW_LOOP + ":\n" + + // Load the bias. + "ld1 {v27.4s}, [%[bias_ptr]], #16\n" + + // Zero out local accumulators. + "dup v28.4s, v27.s[0]\n" // for 1st column + "dup v29.4s, v27.s[1]\n" // for 1st column + "dup v30.4s, v27.s[2]\n" // for 1st column + "dup v31.4s, v27.s[3]\n" // for 1st column + "dup v23.4s, v27.s[0]\n" // for 2nd column + "dup v24.4s, v27.s[1]\n" // for 2nd column + "dup v25.4s, v27.s[2]\n" // for 2nd column + "dup v26.4s, v27.s[3]\n" // for 2nd column + "dup v19.4s, v27.s[0]\n" // for 3rd column + "dup v20.4s, v27.s[1]\n" // for 3rd column + "dup v21.4s, v27.s[2]\n" // for 3rd column + "dup v22.4s, v27.s[3]\n" // for 3rd column + "dup v15.4s, v27.s[0]\n" // for 4th column + "dup v16.4s, v27.s[1]\n" // for 4th column + "dup v17.4s, v27.s[2]\n" // for 4th column + "dup v18.4s, v27.s[3]\n" // for 4th column + "dup v11.4s, v27.s[0]\n" // for 5th column + "dup v12.4s, v27.s[1]\n" // for 5th column + "dup v13.4s, v27.s[2]\n" // for 5th column + "dup v14.4s, v27.s[3]\n" // for 5th column + + // Update the stopping condition for this set of rows. + "ldr w6, [%[nnz_per_row]], #4\n" + "cmp w6, #0\n" + // Skip the body if there isn't anything in this row. + "beq " LABEL_SKIP_COL_LOOP "f\n" + + LABEL_COL_LOOP + ":\n" + // Load 1 Rhs vectors of size 1x4 each and duplicate into upper half. + "ld1 {v0.4h}, [%[rhs_ptr]], x8\n" + "mov v0.d[1], v0.d[0]\n" + "ld1 {v1.4h}, [%[rhs2_ptr]], x8\n" + "mov v1.d[1], v1.d[0]\n" + "ld1 {v8.4h}, [%[rhs3_ptr]], x8\n" + "mov v8.d[1], v8.d[0]\n" + "ld1 {v9.4h}, [%[rhs4_ptr]], x8\n" + "mov v9.d[1], v9.d[0]\n" + "ld1 {v10.4h}, [%[rhs5_ptr]], x8\n" + "mov v10.d[1], v10.d[0]\n" + + // Start this load now, which we won't need until the end of the loop. + "ldrsh x8, [%[col_deltas_bytes]], #2\n" + + // Load 16 Lhs cells corresponding to a 4x4 block. + "ld1 {v2.8h, v3.8h}, [%[weights_ptr]], #32\n" + + // Multiply-accumulate. + "smlal v28.4s, v2.4h, v0.4h\n" // for 1st column + "smlal2 v29.4s, v2.8h, v0.8h\n" // for 1st column + "smlal v30.4s, v3.4h, v0.4h\n" // for 1st column + "smlal2 v31.4s, v3.8h, v0.8h\n" // for 1st columh + "smlal v23.4s, v2.4h, v1.4h\n" // for 2nd column + "smlal2 v24.4s, v2.8h, v1.8h\n" // for 2nd column + "smlal v25.4s, v3.4h, v1.4h\n" // for 2nd column + "smlal2 v26.4s, v3.8h, v1.8h\n" // for 2nd column + "smlal v19.4s, v2.4h, v8.4h\n" // for 3rd column + "smlal2 v20.4s, v2.8h, v8.8h\n" // for 3rd column + "smlal v21.4s, v3.4h, v8.4h\n" // for 3rd column + "smlal2 v22.4s, v3.8h, v8.8h\n" // for 3rd column + "smlal v15.4s, v2.4h, v9.4h\n" // for 4th column + "smlal2 v16.4s, v2.8h, v9.8h\n" // for 4th column + "smlal v17.4s, v3.4h, v9.4h\n" // for 4th column + "smlal2 v18.4s, v3.8h, v9.8h\n" // for 4th column + "smlal v11.4s, v2.4h, v10.4h\n" // for 5th column + "smlal2 v12.4s, v2.8h, v10.8h\n" // for 5th column + "smlal v13.4s, v3.4h, v10.4h\n" // for 5th column + "smlal2 v14.4s, v3.8h, v10.8h\n" // for 5th column + + // Loop. Decrement loop index. + "subs w6, w6, #1\n" // decrement (reduced) columns left + "bne " LABEL_COL_LOOP "b\n" + + LABEL_SKIP_COL_LOOP + ":\n" + + "addp v28.4s, v28.4s, v29.4s\n" // 1st column + "addp v23.4s, v23.4s, v24.4s\n" // 2nd column + "addp v19.4s, v19.4s, v20.4s\n" // 3rd column + "addp v15.4s, v15.4s, v16.4s\n" // 4th column + "addp v11.4s, v11.4s, v12.4s\n" // 5th column + + "addp v30.4s, v30.4s, v31.4s\n" // 1st column + "addp v25.4s, v25.4s, v26.4s\n" // 2nd column + "addp v21.4s, v21.4s, v22.4s\n" // 3rd column + "addp v17.4s, v17.4s, v18.4s\n" // 4th column + "addp v13.4s, v13.4s, v14.4s\n" // 5th column + + "addp v28.4s, v28.4s, v30.4s\n" // 1st column + "addp v23.4s, v23.4s, v25.4s\n" // 2nd column + "addp v19.4s, v19.4s, v21.4s\n" // 3rd column + "addp v15.4s, v15.4s, v17.4s\n" // 4th column + "addp v11.4s, v11.4s, v13.4s\n" // 5th column + + "srshr v28.4s, v28.4s, %[shift_amount]\n" + "srshr v23.4s, v23.4s, %[shift_amount]\n" + "srshr v19.4s, v19.4s, %[shift_amount]\n" + "srshr v15.4s, v15.4s, %[shift_amount]\n" + "srshr v11.4s, v11.4s, %[shift_amount]\n" + + // Store accumulators. + "st1 {v28.4s}, [%[out_ptr]], #16\n" + "st1 {v23.4s}, [%[out2_ptr]], #16\n" + "st1 {v19.4s}, [%[out3_ptr]], #16\n" + "st1 {v15.4s}, [%[out4_ptr]], #16\n" + "st1 {v11.4s}, [%[out5_ptr]], #16\n" + + // Decrement rows remaining. + "subs %[assigned_rows], %[assigned_rows], #1\n" + "bne " LABEL_ROW_LOOP "b\n" + + // clang-format off + : // outputs + [out_ptr] "+r"(out_ptr), [out2_ptr] "+r"(out2_ptr), + [out3_ptr] "+r"(out3_ptr), [out4_ptr] "+r"(out4_ptr), + [out5_ptr] "+r"(out5_ptr), [weights_ptr] "+r"(weights_ptr), + [col_deltas_bytes] "+r"(col_deltas_bytes), [bias_ptr] "+r"(bias_ptr), + [nnz_per_row] "+r"(nnz_per_row), [assigned_rows] "+r"(assigned_rows), + [rhs_ptr] "+r"(rhs_ptr), [rhs2_ptr] "+r"(rhs2_ptr), + [rhs3_ptr] "+r"(rhs3_ptr), [rhs4_ptr] "+r"(rhs4_ptr), + [rhs5_ptr] "+r"(rhs5_ptr) + : // inputs + [shift_amount] "I"(kShiftAmount) + : // clobbers + "cc", "memory", "x6", "x7", "x8", "v0", "v1", "v2", "v3", "v4", "v5", + "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", + "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31"); + // clang-format on + } +} + +template +typename std::enable_if::value>::type SumVectors( + int start, int end, const Type* add1, const Type* add2, Type* result) { + constexpr int kSIMDWidth = 4; + for (int i = start; i < end; i += kSIMDWidth) { + int32x4_t add1_int = vld1q_s32(reinterpret_cast(add1 + i)); + int32x4_t add2_int = vld1q_s32(reinterpret_cast(add2 + i)); + int32x4_t result_int = vqaddq_s32(add1_int, add2_int); + vst1q_s32(reinterpret_cast(result + i), result_int); + } +} + +template +typename std::enable_if::value>::type SumVectors( + int start, int end, const Type* add1, const Type* add2, Type* result) { + constexpr int kSIMDWidth = 8; + for (int i = start; i < end; i += kSIMDWidth) { + int16x8_t add1_int = vld1q_s16(reinterpret_cast(add1 + i)); + int16x8_t add2_int = vld1q_s16(reinterpret_cast(add2 + i)); + int16x8_t result_int = vqaddq_s16(add1_int, add2_int); + vst1q_s16(reinterpret_cast(result + i), result_int); + } +} + +} // namespace detail +} // namespace csrblocksparse + +#undef LABEL_COL_LOOP +#undef LABEL_ROW_LOOP +#undef LABEL_SKIP_COL_LOOP +#undef LABEL_TOP_LOOP + +#endif // defined __aarch64__ +#endif // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_KERNELS_ARM_H_ diff --git a/sparse_matmul/compute/kernels_avx.h b/sparse_matmul/compute/kernels_avx.h new file mode 100644 index 00000000..a56fb9cd --- /dev/null +++ b/sparse_matmul/compute/kernels_avx.h @@ -0,0 +1,601 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_KERNELS_AVX_H_ +#define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_KERNELS_AVX_H_ + +#if defined __AVX__ +#include + +#include +#include +// TODO(b/188702959): Remove fast_transcendentals with GRU refactor. +#include "sparse_matmul/numerics/fast_transcendentals.h" +#include "sparse_matmul/numerics/fixed_types.h" +#include "sparse_matmul/numerics/float16_types.h" +#include "sparse_matmul/numerics/type_utils.h" + +namespace csrblocksparse { +namespace detail { + +template +struct IsAllowableFloatTypes + : std::integral_constant::value && + std::is_same::value && + std::is_same::value> {}; + +#if defined __AVX2__ +// 16-bit inputs, 32-bit output exponent matches sum of input exponents +// OR +// 16-bit inputs, 16-bit output - will shift to match exponent +template +struct IsAllowableFixedTypes + : std::integral_constant::value && + IsFixed16Type::value) && + (IsFixed32Type::value || + IsFixed16Type::value)> {}; + +template +struct ShouldEnableGenericKernel + : std::integral_constant< + bool, + !IsAllowableFloatTypes::value && + !IsAllowableFixedTypes::value> {}; + +template +struct IsAddableFixedTypes + : std::integral_constant::value || + IsFixed16Type::value> {}; +template +struct ShouldEnableGenericAdd + : std::integral_constant::value> {}; + +#else // No AVX2. + +template +struct ShouldEnableGenericKernel + : std::integral_constant< + bool, !IsAllowableFloatTypes::value> {}; + +template +struct ShouldEnableGenericAdd : std::true_type {}; +#endif // __AVX2__ + +template +struct ShouldEnableGenericSpMV_4x4 + : ShouldEnableGenericKernel {}; +template +struct ShouldEnableGenericSpMM5_4x4 + : ShouldEnableGenericKernel {}; +template +struct ShouldEnableGenericSpMV_1x1 : std::true_type {}; +template +struct ShouldEnableGenericSpMM5_1x1 : std::true_type {}; + +// The computational routines do NO error checking for speed. It is assumed +// that this has been handled by CSRBlockSparseMatrix. + +// In-line function to extract results from a pair of registers and store in +// memory. Note that the non-const references are registers, and are modified +// by this function! +inline void Extract4Results(bool relu, __m256& sum1, __m256& sum2, + float** out_ptr) { + // Horizontally add the results. We have 2 registers, |sum1| and |sum2| that + // each contain 2 sets of 4 values that need to be added. + sum1 = _mm256_hadd_ps(sum1, sum2); + sum1 = _mm256_hadd_ps(sum1, sum1); + // Now |sum1| contains [|res0|, |res2|, |res0|, |res2|, |res1|, |res3|, + // |res1|, |res3|] + if (relu) { + sum1 = _mm256_max_ps(sum1, _mm256_setzero_ps()); + } + // It is really hard in AVX to cross the 128 bit 'lanes' and this is the + // *only* way to do it. + // Get the top half of |sum1| in to bottom of |sum2|. + sum2 = _mm256_permute2f128_ps(sum1, sum1, 1); + // Interleave the values between the two registers. + sum1 = _mm256_unpacklo_ps(sum1, sum2); + // Save the lower 128 bits (4 floats). + __m128 result = _mm256_extractf128_ps(sum1, 0); + _mm_store_ps(*out_ptr, result); + *out_ptr += 4; +} + +// Performs the calculation y = A * x + b where A is a sparse matrix with a 4x4 +// blocked pattern, x is a vector and b is vector. Weights are stored for this +// routine by making each 4x4 block contiguous. Blocks are ordered in standard +// row-major format. column indices are converted to deltas and then multiplied +// by 2 to convert to bytes, so that the value can be used directly to offset +// the pointer into the rhs vector. +// +// NOTE: The bias is expected to have be multiplied by .25f prior to calling +// this function. This is automatically taken care of in SparseLinearLayer. +// The bias is reconstructed through horizontal additions, leads to a small +// speedup by reducing latencies at the end of the loop. +template +typename std::enable_if::value && + std::is_same::value && + std::is_same::value>::type +SpMV_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes, + const int32_t* nnz_per_row, const RhsType* rhs_ptr, + const typename TypeOfProduct::type* bias_ptr, + OutType* out_ptr, int64_t assigned_rows, + int64_t rows /* only used in SpMM variants */, + int64_t cols /* only used in SpMM variants */, int relu) { + for (int reduced_row = 0; reduced_row < assigned_rows; ++reduced_row) { + // Broadcast the biases by 4 to undo the division by 4 in the input biases. + __m256 sum1 = _mm256_set_m128(_mm_broadcast_ss(bias_ptr + 1), + _mm_broadcast_ss(bias_ptr)); + bias_ptr += 2; + __m256 sum2 = _mm256_set_m128(_mm_broadcast_ss(bias_ptr + 1), + _mm_broadcast_ss(bias_ptr)); + bias_ptr += 2; + + int reduced_col_count = *nnz_per_row++; + for (int c = 0; c < reduced_col_count; ++c) { + int col_delta = *col_deltas_bytes++ / sizeof(RhsType); + rhs_ptr += col_delta; + // Multiply this 4x4 block. + __m256 rhs = + _mm256_broadcast_ps(reinterpret_cast(rhs_ptr)); + __m256 weights1 = _mm256_load_ps(weights_ptr); + weights_ptr += 8; + sum1 = _mm256_add_ps(sum1, _mm256_mul_ps(weights1, rhs)); + __m256 weights2 = _mm256_load_ps(weights_ptr); + weights_ptr += 8; + sum2 = _mm256_add_ps(sum2, _mm256_mul_ps(weights2, rhs)); + } + Extract4Results(relu, sum1, sum2, &out_ptr); + } +} + +// Performs the calculation y = A * x + b where A is a sparse matrix with a 4x4 +// blocked pattern, x is a fat vector with 5 columns and b is vector. b is +// broadcast. Weights are stored for this routine by making each 4x4 block +// contiguous. Blocks are ordered in standard row-major format. column indices +// are converted to deltas and then multiplied by 2 to convert to bytes, so +// that the value can be used directly to offset the pointer into the rhs +// vector. +// +// NOTE: The bias is expected to have be multiplied by .25f prior to calling +// this function. This is automatically taken care of in SparseLinearLayer. +// The bias is reconstructed through horizontal additions, leads to a small +// speedup by reducing latencies at the end of the loop. +template +typename std::enable_if::value && + std::is_same::value && + std::is_same::value>::type +SpMM5_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes, + const int32_t* nnz_per_row, const RhsType* rhs_ptr, + const typename TypeOfProduct::type* bias_ptr, + OutType* out_ptr, int64_t assigned_rows, int64_t rows, int64_t cols, + int relu) { + const RhsType* rhs_ptrs[5]; + for (int i = 0; i < 5; ++i) rhs_ptrs[i] = rhs_ptr + i * cols; + + OutType* out_ptrs[5]; + for (int i = 0; i < 5; ++i) out_ptrs[i] = out_ptr + i * rows; + + for (int reduced_row = 0; reduced_row < assigned_rows; ++reduced_row) { + // We will acumulate the results in 10 registers, |sum1_0| to |sum2_4|. + // Broadcast the biases by 4 to undo the division by 4 in the input biases. + __m256 sum1_0 = _mm256_set_m128(_mm_broadcast_ss(bias_ptr + 1), + _mm_broadcast_ss(bias_ptr)); + bias_ptr += 2; + __m256 sum2_0 = _mm256_set_m128(_mm_broadcast_ss(bias_ptr + 1), + _mm_broadcast_ss(bias_ptr)); + bias_ptr += 2; + __m256 sum1_1 = sum1_0; + __m256 sum2_1 = sum2_0; + __m256 sum1_2 = sum1_0; + __m256 sum2_2 = sum2_0; + __m256 sum1_3 = sum1_0; + __m256 sum2_3 = sum2_0; + __m256 sum1_4 = sum1_0; + __m256 sum2_4 = sum2_0; + + int reduced_col_count = *nnz_per_row++; + for (int c = 0; c < reduced_col_count; ++c) { + int col_delta = *col_deltas_bytes++ / sizeof(RhsType); + for (int k = 0; k < 5; ++k) rhs_ptrs[k] += col_delta; + + // Multiply this 4x4 block. + __m256 rhs = + _mm256_broadcast_ps(reinterpret_cast(rhs_ptrs[0])); + __m256 weights1 = _mm256_load_ps(weights_ptr); + weights_ptr += 8; + sum1_0 = _mm256_add_ps(sum1_0, _mm256_mul_ps(weights1, rhs)); + __m256 weights2 = _mm256_load_ps(weights_ptr); + weights_ptr += 8; + sum2_0 = _mm256_add_ps(sum2_0, _mm256_mul_ps(weights2, rhs)); + rhs = _mm256_broadcast_ps(reinterpret_cast(rhs_ptrs[1])); + sum1_1 = _mm256_add_ps(sum1_1, _mm256_mul_ps(weights1, rhs)); + sum2_1 = _mm256_add_ps(sum2_1, _mm256_mul_ps(weights2, rhs)); + rhs = _mm256_broadcast_ps(reinterpret_cast(rhs_ptrs[2])); + sum1_2 = _mm256_add_ps(sum1_2, _mm256_mul_ps(weights1, rhs)); + sum2_2 = _mm256_add_ps(sum2_2, _mm256_mul_ps(weights2, rhs)); + rhs = _mm256_broadcast_ps(reinterpret_cast(rhs_ptrs[3])); + sum1_3 = _mm256_add_ps(sum1_3, _mm256_mul_ps(weights1, rhs)); + sum2_3 = _mm256_add_ps(sum2_3, _mm256_mul_ps(weights2, rhs)); + rhs = _mm256_broadcast_ps(reinterpret_cast(rhs_ptrs[4])); + sum1_4 = _mm256_add_ps(sum1_4, _mm256_mul_ps(weights1, rhs)); + sum2_4 = _mm256_add_ps(sum2_4, _mm256_mul_ps(weights2, rhs)); + } + + Extract4Results(relu, sum1_0, sum2_0, &out_ptrs[0]); + Extract4Results(relu, sum1_1, sum2_1, &out_ptrs[1]); + Extract4Results(relu, sum1_2, sum2_2, &out_ptrs[2]); + Extract4Results(relu, sum1_3, sum2_3, &out_ptrs[3]); + Extract4Results(relu, sum1_4, sum2_4, &out_ptrs[4]); + } +} + +#ifdef __AVX2__ + +// In-line function to finish the computation of the result as 4x int32 in +// |sum|. +inline void Compute4Results(bool relu, int kShiftAmount, __m256i& sum) { + // Horizontally add the results. We have 1 register that contains results + // [0 0 1 1 2 2 3 3], but hadd (and almost no other AVX instruction) will not + // cross lanes, so we end up with [0 1 0 1 2 3 2 3] + sum = _mm256_hadd_epi32(sum, sum); + // Permutes the middle two pairs to get the answers together. + sum = _mm256_permute4x64_epi64(sum, 0xd8); + if (kShiftAmount > 0) { + // Shift right with rounding to get the right number of mantissa bits. + __m256i rounding = _mm256_set1_epi32(1 << (kShiftAmount - 1)); + sum = _mm256_add_epi32(sum, rounding); + sum = _mm256_srai_epi32(sum, kShiftAmount); + } + // Now |sum| contains [|res0|, |res1|, |res2|, |res3|, |res0|, |res1|, + // |res2|, |res3|] + if (relu) { + sum = _mm256_max_epi32(sum, _mm256_setzero_si256()); + } +} + +// In-line function to extract the 4x int32 results from |sum| to memory. +// Non-const reference for |sum| as it is a register. +inline void Extract4xint32(bool relu, int kShiftAmount, __m256i& sum, + int32_t** out_ptr) { + Compute4Results(relu, kShiftAmount, sum); + // Save the lower 128 bits (4x int32). + __m128i result = _mm256_extractf128_si256(sum, 0); + _mm_store_si128(reinterpret_cast<__m128i*>(*out_ptr), result); + *out_ptr += 4; +} + +// In-line function to extract the 4x int32 results from sum to 4x int16 in +// memory. +// Non-const reference for |sum| as it is a register. +inline void Extract4xint16(bool relu, int kShiftAmount, __m256i& sum, + int16_t** out_ptr) { + Compute4Results(relu, kShiftAmount, sum); + // Clip to 16 bit range (with saturation) and pack in the bottom 64 bits. + // Converts the lower 4x int32 in bottom 128 bits to 4x int16 in bottom 64 + // bits, replicated in the next 64 bits. + sum = _mm256_packs_epi32(sum, sum); + // Save 4x int 16 from the bottom 64 bits. + *reinterpret_cast(*out_ptr) = _mm256_extract_epi64(sum, 0); + *out_ptr += 4; +} + +// Performs the calculation y = A * x + b where A is a sparse matrix with a 4x4 +// blocked pattern, x is a vector and b is vector. Weights are stored for this +// routine by making each 4x4 block contiguous. Blocks are ordered in standard +// row-major format. column indices are converted to deltas and then multiplied +// by 2 to convert to bytes, so that the value can be used directly to offset +// the pointer into the rhs vector. +// +// NOTE: The bias is expected to have be multiplied by .25f prior to calling +// this function. This is automatically taken care of in SparseLinearLayer. +// The bias is reconstructed through horizontal additions, leads to a small +// speedup by reducing latencies at the end of the loop. +template +typename std::enable_if< + IsFixed16Type::value && IsFixed16Type::value && + (IsFixed32Type::value || IsFixed16Type::value)>::type +SpMV_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes, + const int32_t* nnz_per_row, const RhsType* rhs_ptr, + const typename TypeOfProduct::type* bias_ptr, + OutType* out_ptr, int64_t assigned_rows, + int64_t rows /* only used in SpMM variants */, + int64_t cols /* only used in SpMM variants */, int relu) { + constexpr int kShiftAmount = + TypeOfProduct::type::kMantissaBits - + OutType::kMantissaBits; + static_assert(kShiftAmount >= 0, + "Result must have fewer mantissa bits than product"); + for (int reduced_row = 0; reduced_row < assigned_rows; ++reduced_row) { + // Load the biases duplicated into a 256 bit register [0 1 2 3 0 1 2 3]. + __m128i bias = _mm_load_si128(reinterpret_cast<__m128i const*>(bias_ptr)); + __m256i biases = _mm256_set_m128i(bias, bias); + bias_ptr += 4; + // Swap the top two pairs: [0 1 2 3 2 3 0 1] + // TODO(b/188702959): consider |_mm256_permutevar8x32|, and set the index + // register outside the row loop. + biases = _mm256_permute4x64_epi64(biases, 0xb4); + // Duplicate the low pairs in each lane: [0 0 1 1 2 2 3 3]. + biases = _mm256_unpacklo_epi32(biases, biases); + // Double the results to make up for the division by 4. + // TODO(b/188702959): consider moving this to where the biases are computed. + __m256i sum = _mm256_add_epi32(biases, biases); + + // TODO(b/188702959): People don't like the old-fashioned, close-to-the- + // metal notation of *|nnz_per_row|++, so measure the effect of putting the + // increment in the for loop. + int reduced_col_count = *nnz_per_row; + ++nnz_per_row; + for (int c = 0; c < reduced_col_count; ++c) { + int col_delta = *col_deltas_bytes++ / sizeof(RhsType); + rhs_ptr += col_delta; + // Multiply this 4x4 block. + // Get the 4x int16 into the bottom of rhs_64. + __m128i rhs_64 = + _mm_loadl_epi64(reinterpret_cast<__m128i const*>(rhs_ptr)); + // Load all 16 weights. + __m256i weights = + _mm256_load_si256(reinterpret_cast<__m256i const*>(weights_ptr)); + // Broadcast the rhs, pretending that each is a 64-bit unit: + // [0123 0123 0123 0123]. + __m256i rhs = _mm256_broadcastq_epi64(rhs_64); + weights_ptr += 16; + // |_mm256_madd_epi16| does 16x16x16=16x32 bit multiply and horizontally + // adds adjacent pairs to make 8x32 bit results. Add these to the sum. + sum = _mm256_add_epi32(sum, _mm256_madd_epi16(weights, rhs)); + } + static_assert( + IsFixed16Type::value || IsFixed32Type::value, + "AVX2 kernel only supports fixed16 and fixed32 types"); + // The only significant difference between fixed16 and fixed32 is the size + // of the storage unit. The registers have to be repacked accordingly. + if (IsFixed32Type::value) { + Extract4xint32(relu, kShiftAmount, sum, + reinterpret_cast(&out_ptr)); + } else { + Extract4xint16(relu, kShiftAmount, sum, + reinterpret_cast(&out_ptr)); + } + } +} + +// Performs the calculation y = A * x + b where A is a sparse matrix with a 4x4 +// blocked pattern, x is a fat vector with 5 columns and b is vector. b is +// broadcast. Weights are stored for this routine by making each 4x4 block +// contiguous. Blocks are ordered in standard row-major format. column indices +// are converted to deltas and then multiplied by 2 to convert to bytes, so +// that the value can be used directly to offset the pointer into the rhs +// vector. +// +// NOTE: The bias is expected to have be multiplied by .25f prior to calling +// this function. This is automatically taken care of in SparseLinearLayer. +// The bias is reconstructed through horizontal additions, leads to a small +// speedup by reducing latencies at the end of the loop. +template +typename std::enable_if< + IsFixed16Type::value && IsFixed16Type::value && + (IsFixed32Type::value || IsFixed16Type::value)>::type +SpMM5_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes, + const int32_t* nnz_per_row, const RhsType* rhs_ptr, + const typename TypeOfProduct::type* bias_ptr, + OutType* out_ptr, int64_t assigned_rows, int64_t rows, int64_t cols, + int relu) { + constexpr int kShiftAmount = + TypeOfProduct::type::kMantissaBits - + OutType::kMantissaBits; + static_assert(kShiftAmount >= 0, + "Result must have fewer mantissa bits than product"); + const RhsType* rhs_ptrs[5]; + for (int i = 0; i < 5; ++i) rhs_ptrs[i] = rhs_ptr + i * cols; + + OutType* out_ptrs[5]; + for (int i = 0; i < 5; ++i) out_ptrs[i] = out_ptr + i * rows; + + for (int reduced_row = 0; reduced_row < assigned_rows; ++reduced_row) { + // We will acumulate the results in 5 registers, sum_0 to sum_4. + // Load the biases duplicated into a 256 bit register [0 1 2 3 0 1 2 3]. + __m128i bias = _mm_load_si128(reinterpret_cast<__m128i const*>(bias_ptr)); + __m256i biases = _mm256_set_m128i(bias, bias); + bias_ptr += 4; + // Swap the top two pairs: [0 1 2 3 2 3 0 1] + biases = _mm256_permute4x64_epi64(biases, 0xb4); + // Duplicate the low pairs in each lane: [0 0 1 1 2 2 3 3]. + biases = _mm256_unpacklo_epi32(biases, biases); + // Double the results to make up for the division by 4. + __m256i sum_0 = _mm256_add_epi32(biases, biases); + __m256i sum_1 = sum_0; + __m256i sum_2 = sum_0; + __m256i sum_3 = sum_0; + __m256i sum_4 = sum_0; + + int reduced_col_count = *nnz_per_row; + ++nnz_per_row; + for (int c = 0; c < reduced_col_count; ++c) { + int col_delta = *col_deltas_bytes++ / sizeof(RhsType); + for (int k = 0; k < 5; ++k) rhs_ptrs[k] += col_delta; + // Multiply this 4x4 block. + // Get the 4x int16 into the bottom of |rhs_64|. + __m128i rhs_64 = + _mm_loadl_epi64(reinterpret_cast<__m128i const*>(rhs_ptrs[0])); + // Load all 16 weights. + __m256i weights = + _mm256_load_si256(reinterpret_cast<__m256i const*>(weights_ptr)); + // Broadcast the rhs, pretending that each is a 64-bit unit: + // [0123 0123 0123 0123]. + __m256i rhs = _mm256_broadcastq_epi64(rhs_64); + weights_ptr += 16; + // |_mm256_madd_epi16| does 16x16x16=16x32 bit multiply and horizontally + // adds adjacent pairs to make 8x32 bit results. Add these to the sum. + sum_0 = _mm256_add_epi32(sum_0, _mm256_madd_epi16(weights, rhs)); + rhs_64 = _mm_loadl_epi64(reinterpret_cast<__m128i const*>(rhs_ptrs[1])); + rhs = _mm256_broadcastq_epi64(rhs_64); + sum_1 = _mm256_add_epi32(sum_1, _mm256_madd_epi16(weights, rhs)); + rhs_64 = _mm_loadl_epi64(reinterpret_cast<__m128i const*>(rhs_ptrs[2])); + rhs = _mm256_broadcastq_epi64(rhs_64); + sum_2 = _mm256_add_epi32(sum_2, _mm256_madd_epi16(weights, rhs)); + rhs_64 = _mm_loadl_epi64(reinterpret_cast<__m128i const*>(rhs_ptrs[3])); + rhs = _mm256_broadcastq_epi64(rhs_64); + sum_3 = _mm256_add_epi32(sum_3, _mm256_madd_epi16(weights, rhs)); + rhs_64 = _mm_loadl_epi64(reinterpret_cast<__m128i const*>(rhs_ptrs[4])); + rhs = _mm256_broadcastq_epi64(rhs_64); + sum_4 = _mm256_add_epi32(sum_4, _mm256_madd_epi16(weights, rhs)); + } + static_assert( + IsFixed16Type::value || IsFixed32Type::value, + "AVX2 kernel only supports fixed16 and fixed32 types"); + // The only significant difference between fixed16 and fixed32 is the size + // of the storage unit. The registers have to be repacked accordingly. + if (IsFixed32Type::value) { + Extract4xint32(relu, kShiftAmount, sum_0, + reinterpret_cast(&out_ptrs[0])); + Extract4xint32(relu, kShiftAmount, sum_1, + reinterpret_cast(&out_ptrs[1])); + Extract4xint32(relu, kShiftAmount, sum_2, + reinterpret_cast(&out_ptrs[2])); + Extract4xint32(relu, kShiftAmount, sum_3, + reinterpret_cast(&out_ptrs[3])); + Extract4xint32(relu, kShiftAmount, sum_4, + reinterpret_cast(&out_ptrs[4])); + } else { + Extract4xint16(relu, kShiftAmount, sum_0, + reinterpret_cast(&out_ptrs[0])); + Extract4xint16(relu, kShiftAmount, sum_1, + reinterpret_cast(&out_ptrs[1])); + Extract4xint16(relu, kShiftAmount, sum_2, + reinterpret_cast(&out_ptrs[2])); + Extract4xint16(relu, kShiftAmount, sum_3, + reinterpret_cast(&out_ptrs[3])); + Extract4xint16(relu, kShiftAmount, sum_4, + reinterpret_cast(&out_ptrs[4])); + } + } +} + +// Processes one GRU gate input with sigmoid. +template +inline __m256i GRUGateSigmoid(const void* gate_ptr, const void* gate_other_ptr, + const __m256i& input, + const int32_t* sigmoid_table) { + __m256i gate = _mm256_loadu_si256(reinterpret_cast(gate_ptr)); + if (SplitGates) { + __m256i other = + _mm256_loadu_si256(reinterpret_cast(gate_other_ptr)); + gate = _mm256_add_epi32(gate, other); + } + gate = _mm256_add_epi32(gate, input); + // Compute sigmoids on reset and update. + return csrblocksparse::fixed32_sigmoid_fixed16( + sigmoid_table, gate); +} + +// Processes the tanh and the final combination, returning the new GRU state. +template +inline __m256i GRUGateState(const __m256i& cell, const __m256i& reset, + const __m256i& update, + const __m256i& rounding_offset, + const void* gate_ptr, const void* gate_other_ptr, + const void* gru_h_ptr, const int32_t* tanh_table) { + // Multiply the cell GRU output and the reset. There is a slight danger of + // loss of precision here, so use 32x32=64 bit and shift back after. + __m256i gru = _mm256_loadu_si256(reinterpret_cast<__m256i const*>(gate_ptr)); + if (SplitGates) { + __m256i other_gru = + _mm256_loadu_si256(reinterpret_cast<__m256i const*>(gate_other_ptr)); + gru = _mm256_add_epi32(gru, other_gru); + } + // This only computes the products of the low-order 32 bits of each pair. + __m256i gru_lo = _mm256_mul_epi32(gru, reset); + // Swap odd and even 32-bit units and do it again to get the high products. + gru = _mm256_shuffle_epi32(gru, 0xb1); + __m256i gru_hi = _mm256_mul_epi32(gru, _mm256_shuffle_epi32(reset, 0xb1)); + // Now shift right to compensate for the multiply and re-interleave the + // 32-bit results. + // NOTE: There is no shift right arithmetic for 64 bit values until AVX512! + // Fortunately it doesn't matter, as the results are being truncated to 32 + // bits and we aren't shifting right by more than 32 bits here. + gru_lo = _mm256_srli_epi64(gru_lo, StateMantissaBits); + // The upper results are shifted LEFT, so we can use blend to recombine in + // a single instruction. + gru_hi = _mm256_slli_epi64(gru_hi, 32 - StateMantissaBits); + // Recombine the 32 bit results from lo and hi, alternating. + gru = _mm256_blend_epi32(gru_lo, gru_hi, 0xaa); + gru = _mm256_add_epi32(cell, gru); + // Compute tanh on the result. Although this instantly discards a bunch of + // bits, there were only 7 surplus bits for the multiply, which isn't enough + // to do it as 16x16=32. + __m256i hbar = + csrblocksparse::fixed32_tanh_fixed16(tanh_table, gru); + // Load the 16-bit previous GRU state and sign-extend to 32 bits. + gru = _mm256_cvtepi16_epi32( + _mm_load_si128(reinterpret_cast<__m128i const*>(gru_h_ptr))); + gru = _mm256_sub_epi32(gru, hbar); + // Since |gru| is 16 bit sign-extended to 32, and |update| is the output of + // sigmoid, it is always contained within 16 bits and never negative, we can + // use |madd_epi16| to do 16x16=32 multiply with horizontal adding as the + // addend will always be zero, and this is twice as fast as full blown + // 32x32=32. The only possible problem is if the subtract above caused + // overflow. + gru = _mm256_madd_epi16(gru, update); + // Renormalize to fixed16. This time rounding is critical, as this is the + // output GRU state. + gru = _mm256_add_epi32(gru, rounding_offset); + gru = _mm256_srai_epi32(gru, StateMantissaBits); + return _mm256_add_epi32(gru, hbar); +} + +template +typename std::enable_if::value>::type SumVectors( + int start, int end, const Type* add1, const Type* add2, Type* result) { + constexpr int kSIMDWidth = 8; + for (int i = start; i < end; i += kSIMDWidth) { + __m256i data1 = + _mm256_load_si256(reinterpret_cast<__m256i const*>(add1 + i)); + __m256i data2 = + _mm256_load_si256(reinterpret_cast<__m256i const*>(add2 + i)); + data1 = _mm256_add_epi32(data1, data2); + _mm256_store_si256(reinterpret_cast<__m256i*>(result + i), data1); + } +} + +template +typename std::enable_if::value>::type SumVectors( + int start, int end, const Type* add1, const Type* add2, Type* result) { + constexpr int kSIMDWidth = 16; + for (int i = start; i < end; i += kSIMDWidth) { + __m256i data1 = + _mm256_load_si256(reinterpret_cast<__m256i const*>(add1 + i)); + __m256i data2 = + _mm256_load_si256(reinterpret_cast<__m256i const*>(add2 + i)); + data1 = _mm256_add_epi16(data1, data2); + _mm256_store_si256(reinterpret_cast<__m256i*>(result + i), data1); + } +} + +#endif // __AVX2__ + +} // namespace detail +} // namespace csrblocksparse + +#undef LABEL_COL_LOOP +#undef LABEL_ROW_LOOP +#undef LABEL_SKIP_COL_LOOP +#undef LABEL_TOP_LOOP + +#endif // __AVX__ + +#endif // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_KERNELS_AVX_H_ diff --git a/sparse_matmul/compute/kernels_generic.h b/sparse_matmul/compute/kernels_generic.h new file mode 100644 index 00000000..2ff9c7ec --- /dev/null +++ b/sparse_matmul/compute/kernels_generic.h @@ -0,0 +1,273 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_KERNELS_GENERIC_H_ +#define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_KERNELS_GENERIC_H_ + +#include +#include + +#include "sparse_matmul/numerics/fixed_types.h" +#include "sparse_matmul/numerics/float16_types.h" +#include "sparse_matmul/numerics/type_utils.h" + +// Separate out the assembly kernels for readability. Eventually this will +// become an ifdef switch on the architecture type. +#if defined __aarch64__ +#include "sparse_matmul/compute/kernels_arm.h" +#elif defined __AVX__ +#include "sparse_matmul/compute/kernels_avx.h" +#else // defined __AVX__ +// If there is no architecture-specific implementation, then always use generic. +template +struct ShouldEnableGenericSpMV_4x4 : std::true_type {}; +template +struct ShouldEnableGenericSpMM5_4x4 : std::true_type {}; +template +struct ShouldEnableGenericSpMV_1x1 : std::true_type {}; +template +struct ShouldEnableGenericSpMM5_1x1 : std::true_type {}; +template +struct ShouldEnableGenericAdd : std::true_type {}; +#endif // defined __arch64__ + +namespace csrblocksparse { +namespace detail { + +// The computational routines do NO error checking for speed. It is assumed +// that this has been handled by CSRBlockSparseMatrix. + +// Performs the calculation y = A * x + b where A is a sparse matrix with a 4x4 +// blocked pattern, x is a vector and b is vector. Weights are stored for this +// routine by making each 4x4 block contiguous. Blocks are ordered in standard +// row-major format. column indices are converted to deltas and then multiplied +// by 2 to convert to bytes, so that the value can be used directly to offset +// the pointer into the rhs vector. +// +// NOTE: The bias is expected to have be multiplied by .25f prior to calling +// this function. This is automatically taken care of in SparseLinearLayer. +// The bias is reconstructed through horizontal additions, leads to a small +// speedup by reducing latencies at the end of the loop. +template +typename std::enable_if< + ShouldEnableGenericSpMV_4x4::value>::type +SpMV_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes, + const int32_t* nnz_per_row, const RhsType* rhs_ptr, + const typename TypeOfProduct::type* bias_ptr, + OutType* out_ptr, int64_t assigned_rows, + int64_t rows /* only used in SpMM variants */, + int64_t cols /* only used in SpMM variants */, int relu) { + for (int reduced_row = 0; reduced_row < assigned_rows; ++reduced_row) { + float accumulators[4]; + // Undo the divion by the happens for the assembly version. + for (int i = 0; i < 4; ++i) + accumulators[i] = 4.f * static_cast(*bias_ptr++); + + int reduced_col_count = *nnz_per_row++; + for (int c = 0; c < reduced_col_count; ++c) { + int col_delta = *col_deltas_bytes++ / sizeof(RhsType); + rhs_ptr += col_delta; + + // Multiply this 4x4 block. + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 4; ++j) { + accumulators[i] += static_cast(*weights_ptr++) * + static_cast(rhs_ptr[j]); + } + } + } + + for (int i = 0; i < 4; ++i) + *out_ptr++ = static_cast(relu ? std::max(accumulators[i], 0.f) + : accumulators[i]); + } +} + +// Performs the calculation y = A * x + b where A is a sparse matrix with a 4x4 +// blocked pattern, x is a fat vector with 5 columns and b is vector. b is +// broadcast. Weights are stored for this routine by making each 4x4 block +// contiguous. Blocks are ordered in standard row-major format. column indices +// are converted to deltas and then multiplied by 2 to convert to bytes, so +// that the value can be used directly to offset the pointer into the rhs +// vector. +// +// NOTE: The bias is expected to have be multiplied by .25f prior to calling +// this function. This is automatically taken care of in SparseLinearLayer. +// The bias is reconstructed through horizontal additions, leads to a small +// speedup by reducing latencies at the end of the loop. +template +typename std::enable_if< + ShouldEnableGenericSpMM5_4x4::value>::type +SpMM5_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes, + const int32_t* nnz_per_row, const RhsType* rhs_ptr, + const typename TypeOfProduct::type* bias_ptr, + OutType* out_ptr, int64_t assigned_rows, int64_t rows, int64_t cols, + int relu) { + const RhsType* rhs_ptrs[5]; + for (int i = 0; i < 5; ++i) rhs_ptrs[i] = rhs_ptr + i * cols; + + OutType* out_ptrs[5]; + for (int i = 0; i < 5; ++i) out_ptrs[i] = out_ptr + i * rows; + + for (int reduced_row = 0; reduced_row < assigned_rows; ++reduced_row) { + float accumulators[4][5]; + // Undo the divion by the happens for the assembly version. + for (int i = 0; i < 4; ++i) { + for (int k = 0; k < 5; ++k) { + accumulators[i][k] = 4.f * static_cast(*bias_ptr); + } + ++bias_ptr; + } + + int reduced_col_count = *nnz_per_row++; + for (int c = 0; c < reduced_col_count; ++c) { + int col_delta = *col_deltas_bytes++ / sizeof(RhsType); + for (int k = 0; k < 5; ++k) rhs_ptrs[k] += col_delta; + + // multiply this 4x4 block + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 4; ++j) { + for (int k = 0; k < 5; ++k) { + accumulators[i][k] += static_cast(*weights_ptr) * + static_cast(rhs_ptrs[k][j]); + } + weights_ptr++; + } + } + } + + for (int k = 0; k < 5; ++k) { + for (int i = 0; i < 4; ++i) { + out_ptrs[k][0] = static_cast( + relu ? std::max(accumulators[i][k], 0.f) : accumulators[i][k]); + out_ptrs[k]++; + } + } + } +} + +// Performs the calculation y = A * x + b where A is a sparse matrix with +// a 1x1 blocked pattern (ie unstructured), x is a +// vector and b is vector. +// Weights are stored for this routine in standard CSR format. Each row must +// have a multiple of 8 columns. +// column indices are converted to deltas and then multiplied by 2 to convert +// to bytes, so that the value can be used directly to offset the pointer +// into the rhs vector. +// NOTE: The bias is expected to have be multiplied by .25f prior to calling +// this function. This is automatically taken care of in SparseLinearLayer. +// The bias is reconstructed through horizontal additions, leads to a small +// speedup by reducing latencies at the end of the loop. +template +typename std::enable_if< + ShouldEnableGenericSpMV_1x1::value>::type +SpMV_1x1(const WeightType* weights_ptr, const int16_t* col_deltas_bytes, + const int32_t* nnz_per_row, const RhsType* rhs_ptr, + const typename TypeOfProduct::type* bias_ptr, + OutType* out_ptr, int64_t assigned_rows, + int64_t rows /* only used in SpMM variants */, + int64_t cols /* only used in SpMM variants */, int relu) { + for (int row = 0; row < assigned_rows; ++row) { + // Undo the divion by the happens for the assembly version. + float accumulator = 4.f * static_cast(*bias_ptr++); + + int col_count = *nnz_per_row++; + for (int c = 0; c < col_count; ++c) { + int col_delta = *col_deltas_bytes++ / sizeof(RhsType); + rhs_ptr += col_delta; + + accumulator += + static_cast(*weights_ptr++) * static_cast(*rhs_ptr); + } + + *out_ptr++ = + static_cast(relu ? std::max(accumulator, 0.f) : accumulator); + } +} + +// Performs the calculation y = A * x + b where A is a sparse matrix with +// a 1x1 blocked pattern (ie unstructured), x is a +// vector and b is vector. +// Weights are stored for this routine in standard CSR format. Each row must +// have a multiple of 8 columns. +// column indices are converted to deltas and then multiplied by 2 to convert +// to bytes, so that the value can be used directly to offset the pointer +// into the rhs vector. +// NOTE: The bias is expected to have be multiplied by .25f prior to calling +// this function. This is automatically taken care of in SparseLinearLayer. +// The bias is reconstructed through horizontal additions, leads to a small +// speedup by reducing latencies at the end of the loop. +template +typename std::enable_if< + ShouldEnableGenericSpMM5_1x1::value>::type +SpMM5_1x1(const WeightType* weights_ptr, const int16_t* col_deltas_bytes, + const int32_t* nnz_per_row, const RhsType* rhs_ptr, + const typename TypeOfProduct::type* bias_ptr, + OutType* out_ptr, int64_t assigned_rows, int64_t rows, int64_t cols, + int relu) { + const RhsType* rhs_ptrs[5]; + for (int i = 0; i < 5; ++i) rhs_ptrs[i] = rhs_ptr + i * cols; + + OutType* out_ptrs[5]; + for (int i = 0; i < 5; ++i) out_ptrs[i] = out_ptr + i * rows; + + for (int row = 0; row < assigned_rows; ++row) { + // Undo the divion by the happens for the assembly version. + float accumulator[5]; + for (int i = 0; i < 5; ++i) + accumulator[i] = 4.f * static_cast(*bias_ptr); + + ++bias_ptr; + + int col_count = *nnz_per_row++; + for (int c = 0; c < col_count; ++c) { + int col_delta = *col_deltas_bytes++ / sizeof(RhsType); + for (int i = 0; i < 5; ++i) { + rhs_ptrs[i] += col_delta; + accumulator[i] += static_cast(*weights_ptr) * + static_cast(rhs_ptrs[i][0]); + } + weights_ptr++; + } + + for (int i = 0; i < 5; ++i) { + out_ptrs[i][0] = static_cast(relu ? std::max(accumulator[i], 0.f) + : accumulator[i]); + out_ptrs[i]++; + } + } +} + +template +typename std::enable_if::value>::type SumVectors( + int start, int end, const Type* add1, const Type* add2, Type* result) { + LOG_FIRST_N(WARNING, 1) << "SumVectors: using generic kernel!"; + for (int i = start; i < end; ++i) { + Type sum = static_cast(static_cast(add1[i]) + + static_cast(add2[i])); + result[i] = sum; + } +} + +} // namespace detail +} // namespace csrblocksparse + +#undef LABEL_COL_LOOP +#undef LABEL_ROW_LOOP +#undef LABEL_SKIP_COL_LOOP +#undef LABEL_TOP_LOOP + +#endif // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_KERNELS_GENERIC_H_ diff --git a/sparse_matmul/compute/matmul.h b/sparse_matmul/compute/matmul.h new file mode 100644 index 00000000..442164de --- /dev/null +++ b/sparse_matmul/compute/matmul.h @@ -0,0 +1,199 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_MATMUL_H_ +#define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_MATMUL_H_ + +#include +#include + +#include "absl/time/time.h" +#include "sparse_matmul/compute/matmul_fixed_avx2.h" +#include "sparse_matmul/compute/matmul_generic.h" +#include "sparse_matmul/numerics/fixed_types.h" +#include "sparse_matmul/numerics/type_utils.h" +#if defined(__x86_64__) || defined(__i386__) || defined(_WIN32) +#include +#endif + +namespace csrblocksparse { + +// The number of elements in a block. +constexpr int kBlockSize = 4; + +// Base class for Matmul containing the members that are non type-specicfic. +class MatmulBase { + public: + // Constructor initializes the flags that determine which implementation to + // use at run-time, constrained by both compiler flags and cpuid. + MatmulBase() { +#if defined(__x86_64__) || defined(__i386__) || defined(_WIN32) + // Code tested to work on Linux systems and multiple Android emulators. + unsigned int eax, ebx, ecx, edx; + if (__get_cpuid(1, &eax, &ebx, &ecx, &edx) != 0) { + using_avx_ = (ecx & bit_AVX) != 0; + if (using_avx_) { + __get_cpuid_count(7, 0, &eax, &ebx, &ecx, &edx); + using_avx2_ = (ebx & bit_AVX2) != 0; + using_avx512_ = (ebx & bit_AVX512F) != 0 && (ebx & bit_AVX512DQ) && + (ebx & bit_AVX512BW) != 0; + VLOG(2) << "avx2 flag=" << using_avx2_ << " 512=" << using_avx512_; + } else { + LOG(ERROR) << "AVX not found at all!"; + } + } +#else + using_aarch64_ = true; +#endif + } + + protected: + // Flags that define what (runtime) architectures are available. Flags that + // are set are limited by both the compiler flags and runtime environment. + bool using_avx512_ = false; + bool using_avx2_ = false; + bool using_avx_ = false; + bool using_aarch64_ = false; +}; + +// The master template is really a catch-all for the unimplmented cases to +// report an error. +template +class Matmul : public MatmulBase { + public: + // Sparse inputs, outputs replicated strided for each thread. + template + void MatVec4x4(const WeightType* weights, const RhsType* rhs, + const typename TypeOfProduct::type* bias, + const int32_t* nnz_per_row, const int16_t* rhs_indices, + int start_row, int end_row, bool relu, int replicas, + int stride, OutType* output) { + // The specializations should take care of every real case. + CHECK(false) << "Unsupported combination of types used!"; + } + template + void MatVec8x4(const WeightType* weights, const RhsType* rhs, + const typename TypeOfProduct::type* bias, + const int32_t* nnz_per_row, const int16_t* rhs_indices, + int start_row, int end_row, bool relu, int replicas, + int stride, OutType* output) { + // The specializations should take care of every real case. + CHECK(false) << "Unsupported combination of types used!"; + } +}; + +// Full specialization for float. +template <> +class Matmul : public MatmulBase { + public: + void MatVec4x4(const float* weights, const float* rhs, const float* bias, + const int32_t* nnz_per_row, const int16_t* rhs_indices, + int start_row, int end_row, bool relu, int replicas, + int stride, float* output) { + detail::MatVecFloatGeneric(weights, rhs, bias, nnz_per_row, rhs_indices, + start_row, end_row, /*block_height=*/4, + /*block_width=*/4, relu, replicas, stride, + output); + } + void MatVec8x4(const float* weights, const float* rhs, const float* bias, + const int32_t* nnz_per_row, const int16_t* rhs_indices, + int start_row, int end_row, bool relu, int replicas, + int stride, float* output) { + detail::MatVecFloatGeneric(weights, rhs, bias, nnz_per_row, rhs_indices, + start_row, end_row, /*block_height=*/8, + /*block_width=*/4, relu, replicas, stride, + output); + } +}; + +// Partial specialization for fixed types. Covers fixed16xfixed16 = OutType, +// where OutType should be fixed16 or fixed32. The mantissa bits don't have +// to match. +template +class Matmul, fixed16> : public MatmulBase { + public: + using WeightType = fixed16; + using RhsType = fixed16; + + template + void MatVec4x4(const int16_t* weights, const int16_t* rhs, + const int32_t* bias, const int32_t* nnz_per_row, + const int16_t* rhs_indices, int start_row, int end_row, + bool relu, int replicas, int stride, OutType* output) { + constexpr int kShiftAmount = + TypeOfProduct::type::kMantissaBits - + OutType::kMantissaBits; + static_assert(kShiftAmount >= 0, + "OutType must not have more mantissa bits than inputs"); +#if defined __AVX2__ + CHECK(using_avx2_) << "Compiled for AVX2, but cpu flag not set!"; + if (sizeof(*output) == 4) { + int32_t* out32 = reinterpret_cast(output); + detail::MatVec4x4FixedAVX2(weights, rhs, bias, nnz_per_row, rhs_indices, + start_row, end_row, relu, kShiftAmount, + replicas, stride, out32); + } else { + int16_t* out16 = reinterpret_cast(output); + detail::MatVec4x4FixedAVX2(weights, rhs, bias, nnz_per_row, rhs_indices, + start_row, end_row, relu, kShiftAmount, + replicas, stride, out16); + } +#elif defined __aarch64__ + if (using_aarch64_) { + LOG(FATAL) << "Fixed16 MatVec4x4 not yet implemented!"; + } + +#else + detail::MatVecFixedGeneric(weights, rhs, bias, nnz_per_row, rhs_indices, + start_row, end_row, /*block_height=*/4, + /*block_width=*/4, relu, sizeof(*output), + kShiftAmount, replicas, stride, output); +#endif // __AVX2__ + } + + template + void MatVec8x4(const int16_t* weights, const int16_t* rhs, + const int32_t* bias, const int32_t* nnz_per_row, + const int16_t* rhs_indices, int start_row, int end_row, + bool relu, int replicas, int stride, OutType* output) { + constexpr int kShiftAmount = + TypeOfProduct::type::kMantissaBits - + OutType::kMantissaBits; + static_assert(kShiftAmount >= 0, + "OutType must not have more mantissa bits than inputs"); +#if defined __AVX2__ + CHECK(replicas == 1 && sizeof(*output) == 4) + << "Only replicas == 1 and fixed32 output are implemented for AVX2!"; + CHECK(using_avx2_) << "Compiled for AVX2, but cpu flag not set!"; + int32_t* out32 = reinterpret_cast(output); + detail::MatVec8x4FixedAVX2(weights, rhs, bias, nnz_per_row, rhs_indices, + start_row, end_row, relu, kShiftAmount, out32); +#elif defined __aarch64__ + if (using_aarch64_) { + LOG(FATAL) << "Fixed16 MatVec8x4 not yet implemented!"; + } +#else + detail::MatVecFixedGeneric(weights, rhs, bias, nnz_per_row, rhs_indices, + start_row, end_row, /*block_height=*/8, + /*block_width=*/4, relu, sizeof(*output), + kShiftAmount, replicas, stride, output); +#endif // __AVX2__ + } +}; + +} // namespace csrblocksparse + +#endif // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_MATMUL_H_ diff --git a/sparse_matmul/compute/matmul_fixed_avx2.cc b/sparse_matmul/compute/matmul_fixed_avx2.cc new file mode 100644 index 00000000..f1e0905f --- /dev/null +++ b/sparse_matmul/compute/matmul_fixed_avx2.cc @@ -0,0 +1,232 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "sparse_matmul/compute/matmul_fixed_avx2.h" + +#include + +#if defined __AVX__ +#include +#endif + +#include "sparse_matmul/compute/matmul.h" + +namespace csrblocksparse { +namespace detail { + +#if defined __AVX2__ +// In-line function computes and returns the result of one row (of blocks) as +// 4x int32_t. |weights_ptr| is a non-const reference so it can easily be +// interpreted as belonging to the caller. +inline __m256i ComputeRowResults(const __m128i& bias128, const int16_t* rhs, + const int16_t* rhs_indices, int nnz, + int16_t const*& weights_ptr) { + // Expand bias to 64 bits in a 256 bit register [0 z 1 z 2 z 3 z], where z is + // Zero and 0-3 are the 4x32 bit bias values. + __m256i sum = _mm256_cvtepu32_epi64(bias128); + + for (int c = 0; c < nnz; ++c) { + int rhs_index = rhs_indices[c]; + // Load all 16 weights. + __m256i weights = + _mm256_load_si256(reinterpret_cast<__m256i const*>(weights_ptr)); + // Get the 4x int16_t into the bottom of |rhs_64|. + __m128i rhs_64 = _mm_loadl_epi64( + reinterpret_cast<__m128i const*>(rhs + rhs_index * kBlockSize)); + // Broadcast the rhs, pretending that each is a 64-bit unit: + // [0123 0123 0123 0123]. + __m256i rhs_value = _mm256_broadcastq_epi64(rhs_64); + weights_ptr += 16; + sum = _mm256_add_epi32(sum, _mm256_madd_epi16(weights, rhs_value)); + } + // Horizontally add the results. We have 1 register that contains results + // [0 0 1 1 2 2 3 3], but hadd (and almost no other AVX instruction) will not + // cross lanes, so we end up with [0 1 0 1 2 3 2 3] + sum = _mm256_hadd_epi32(sum, sum); + // Permutes the middle two pairs to get the answers together. + return _mm256_permute4x64_epi64(sum, 0xd8); +} + +// Template that allows any fixed combination of OutType and replicas, plus +// variable |relu|, |shift_out|. Note that |kReplicas| is a template arg as +// well as a function arg so we can hard-code a limited amount of unrolling. +template +void MatVec4x4FixedAVX2Template(const int16_t* weights_ptr, const int16_t* rhs, + const int32_t* bias, const int32_t* nnz_per_row, + const int16_t* rhs_indices, int start_row, + int end_row, bool relu, int shift_out, + int replicas, int stride, OutType* output) { + int rounding_addon = shift_out > 0 ? (1 << (shift_out - 1)) : 0; + __m256i rounding = _mm256_set1_epi32(rounding_addon); + __m256i zero = relu ? _mm256_setzero_si256() : _mm256_set1_epi32(kint32min); + for (int row_block = start_row; row_block < end_row; ++row_block) { + // Load 4 biases [0 1 2 3]. + __m128i bias128 = _mm_load_si128(reinterpret_cast<__m128i const*>(bias)); + bias += kBlockSize; + int nnz = nnz_per_row[row_block]; + __m256i sum = + ComputeRowResults(bias128, rhs, rhs_indices, nnz, weights_ptr); + rhs_indices += nnz; + // Shift right with rounding to get the right number of mantissa bits. + sum = _mm256_add_epi32(sum, rounding); + sum = _mm256_srai_epi32(sum, shift_out); + // Now sum contains [res0, res1, res2, res3, res0, res1, res2, res3] + sum = _mm256_max_epi32(sum, zero); + if (sizeof(OutType) == 2) { + // Clip to 16 bit range (with saturation) and pack in the bottom 64 + // bits. The 64 bit result is replicated across the whole 256 bit + // register. [0123 0123 0123 0123] + sum = _mm256_packs_epi32(sum, sum); + int64_t result = _mm256_extract_epi64(sum, 0); + *reinterpret_cast(output) = result; + if (kReplicas > 1) { + *reinterpret_cast(output + stride) = result; + if (kReplicas > 2) { + for (int r = 2; r < replicas; ++r) { + *reinterpret_cast(output + r * stride) = result; + } + } + } + } else { + // Save the lower 128 bits (4x int32_t). + __m128i result = _mm256_extractf128_si256(sum, 0); + _mm_store_si128(reinterpret_cast<__m128i*>(output), result); + if (kReplicas > 1) { + _mm_store_si128(reinterpret_cast<__m128i*>(output + stride), result); + if (kReplicas > 2) { + for (int r = 2; r < replicas; ++r) { + _mm_store_si128(reinterpret_cast<__m128i*>(output + r * stride), + result); + } + } + } + } + output += kBlockSize; + } +} + +// Version that covers all possible combinations of the variable conditions: +// |relu|, |shift_out|, |replicas|, with int16_t |output|. +void MatVec4x4FixedAVX2(const int16_t* weights_ptr, const int16_t* rhs, + const int32_t* bias, const int32_t* nnz_per_row, + const int16_t* rhs_indices, int start_row, int end_row, + bool relu, int shift_out, int replicas, int stride, + int16_t* output) { + if (replicas <= 1) { + MatVec4x4FixedAVX2Template(weights_ptr, rhs, bias, nnz_per_row, + rhs_indices, start_row, end_row, + relu, shift_out, 1, stride, output); + } else if (replicas == 2) { + MatVec4x4FixedAVX2Template(weights_ptr, rhs, bias, nnz_per_row, + rhs_indices, start_row, end_row, + relu, shift_out, 2, stride, output); + } else { + MatVec4x4FixedAVX2Template( + weights_ptr, rhs, bias, nnz_per_row, rhs_indices, start_row, end_row, + relu, shift_out, replicas, stride, output); + } +} + +// Version that covers all possible combinations of the variable conditions: +// |relu|, |shift_out|, |replicas|, with int32_t |output|. +void MatVec4x4FixedAVX2(const int16_t* weights_ptr, const int16_t* rhs, + const int32_t* bias, const int32_t* nnz_per_row, + const int16_t* rhs_indices, int start_row, int end_row, + bool relu, int shift_out, int replicas, int stride, + int32_t* output) { + if (replicas <= 1) { + MatVec4x4FixedAVX2Template(weights_ptr, rhs, bias, nnz_per_row, + rhs_indices, start_row, end_row, + relu, shift_out, 1, stride, output); + } else if (replicas == 2) { + MatVec4x4FixedAVX2Template(weights_ptr, rhs, bias, nnz_per_row, + rhs_indices, start_row, end_row, + relu, shift_out, 2, stride, output); + } else { + MatVec4x4FixedAVX2Template( + weights_ptr, rhs, bias, nnz_per_row, rhs_indices, start_row, end_row, + relu, shift_out, replicas, stride, output); + } +} + +// In-line function computes and returns the result of one row (of blocks) as +// 8x int32_t. weights_ptr is a non-const reference so it can easily be +// interpreted as belonging to the caller. +inline __m256i Compute8RowResults(const __m256i& bias256, const int16_t* rhs, + const int16_t* rhs_indices, int nnz, + int16_t const*& weights_ptr) { + // Expand bias to 64 bits in a 256 bit register [0 z 1 z 2 z 3 z], where z is + // Zero and 0-3 are the 4x32 bit bias values from 128 bit half of the input. + __m256i sum1 = _mm256_cvtepu32_epi64(_mm256_castsi256_si128(bias256)); + // Plus 4 more in another sum register from the upper 128 bit half. + __m256i sum2 = _mm256_cvtepu32_epi64(_mm256_extractf128_si256(bias256, 1)); + + for (int c = 0; c < nnz; ++c) { + int rhs_index = rhs_indices[c]; + // Load all 16 weights. + __m256i weights = + _mm256_load_si256(reinterpret_cast<__m256i const*>(weights_ptr)); + // Get the 4x int16_t into the bottom of |rhs_64|. + __m128i rhs_64 = _mm_loadl_epi64( + reinterpret_cast<__m128i const*>(rhs + rhs_index * kBlockSize)); + // Broadcast the rhs, pretending that each is a 64-bit unit: + // [0123 0123 0123 0123]. + __m256i rhs_value = _mm256_broadcastq_epi64(rhs_64); + weights_ptr += 16; + sum1 = _mm256_add_epi32(sum1, _mm256_madd_epi16(weights, rhs_value)); + // Same again for the other 4 results, re-using the same rhs value. + weights = _mm256_load_si256(reinterpret_cast<__m256i const*>(weights_ptr)); + weights_ptr += 16; + sum2 = _mm256_add_epi32(sum2, _mm256_madd_epi16(weights, rhs_value)); + } + // Horizontally add the results. We have 2 registers that contain results + // [0 0 1 1 2 2 3 3], and [4 4 5 5 6 6 7 7] but hadd (and almost no other AVX + // instruction) will not cross lanes, so we end up with [0 1 4 5 2 3 6 7] + sum1 = _mm256_hadd_epi32(sum1, sum2); + // Permutes the middle two pairs to get the answers in the right order. + return _mm256_permute4x64_epi64(sum1, 0xd8); +} + +// Version that covers the main conditions used with 8x4: +// |relu|, |shift_out|, with int32_t |output|. +void MatVec8x4FixedAVX2(const int16_t* weights_ptr, const int16_t* rhs, + const int32_t* bias, const int32_t* nnz_per_row, + const int16_t* rhs_indices, int start_row, int end_row, + bool relu, int shift_out, int32_t* output) { + int rounding_addon = shift_out > 0 ? (1 << (shift_out - 1)) : 0; + __m256i rounding = _mm256_set1_epi32(rounding_addon); + __m256i zero = relu ? _mm256_setzero_si256() : _mm256_set1_epi32(kint32min); + for (int row_block = start_row; row_block < end_row; ++row_block) { + // Load 4 biases [0 1 2 3 4 5 6 7]. + __m256i bias256 = _mm256_load_si256(reinterpret_cast<__m256i const*>(bias)); + bias += kBlockSize * 2; + int nnz = nnz_per_row[row_block]; + __m256i sum = + Compute8RowResults(bias256, rhs, rhs_indices, nnz, weights_ptr); + rhs_indices += nnz; + // Shift right with rounding to get the right number of mantissa bits. + sum = _mm256_add_epi32(sum, rounding); + sum = _mm256_srai_epi32(sum, shift_out); + // Now sum contains [res0, res1, res2, res3, res0, res1, res2, res3] + sum = _mm256_max_epi32(sum, zero); + // Save the all 256 bits (8x int32_t). + _mm256_store_si256(reinterpret_cast<__m256i*>(output), sum); + output += kBlockSize * 2; + } +} + +#endif + +} // namespace detail +} // namespace csrblocksparse diff --git a/sparse_matmul/compute/matmul_fixed_avx2.h b/sparse_matmul/compute/matmul_fixed_avx2.h new file mode 100644 index 00000000..59e7d0ea --- /dev/null +++ b/sparse_matmul/compute/matmul_fixed_avx2.h @@ -0,0 +1,49 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_MATMUL_FIXED_AVX2_H_ +#define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_MATMUL_FIXED_AVX2_H_ + +#include + +namespace csrblocksparse { +namespace detail { + +// Version that covers all possible combinations of the variable conditions: +// |relu|, |shift_out|, |replicas|, with int16 output. +void MatVec4x4FixedAVX2(const int16_t* weights_ptr, const int16_t* rhs, + const int32_t* bias, const int32_t* nnz_per_row, + const int16_t* rhs_indices, int start_row, int end_row, + bool relu, int shift_out, int replicas, int stride, + int16_t* output); +// Version that covers all possible combinations of the variable conditions: +// |relu|, |shift_out|, |replicas|, with int32 output. +void MatVec4x4FixedAVX2(const int16_t* weights_ptr, const int16_t* rhs, + const int32_t* bias, const int32_t* nnz_per_row, + const int16_t* rhs_indices, int start_row, int end_row, + bool relu, int shift_out, int replicas, int stride, + int32_t* output); +// Version that covers the main conditions used with 8x4: +// |relu|, |shift_out|, with int32 output. +void MatVec8x4FixedAVX2(const int16_t* weights_ptr, const int16_t* rhs, + const int32_t* bias, const int32_t* nnz_per_row, + const int16_t* rhs_indices, int start_row, int end_row, + bool relu, int shift_out, int32_t* output); + +} // namespace detail +} // namespace csrblocksparse + +#endif // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_MATMUL_FIXED_AVX2_H_ diff --git a/sparse_matmul/compute/matmul_generic.cc b/sparse_matmul/compute/matmul_generic.cc new file mode 100644 index 00000000..1cf4fe53 --- /dev/null +++ b/sparse_matmul/compute/matmul_generic.cc @@ -0,0 +1,122 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "sparse_matmul/compute/matmul_generic.h" + +#include +#include + +#include "sparse_matmul/compute/matmul.h" + +namespace csrblocksparse { +namespace detail { + +void MatVecFloatGeneric(const float* weights, const float* rhs, + const float* bias, const int32_t* nnz_per_row, + const int16_t* rhs_indices, int start_row, int end_row, + int block_height, int block_width, bool relu, + int replicas, int stride, float* output) { + int weight_index = 0; + int bias_index = 0; + std::vector accumulators(block_height); + for (int row_block = start_row; row_block < end_row; + ++row_block, output += block_height) { + int nnz = nnz_per_row[row_block]; + // Biases are now stored and used directly without pre-division. + for (int i = 0; i < block_height; ++i) accumulators[i] = bias[bias_index++]; + + for (int c = 0; c < nnz; ++c) { + int rhs_index = rhs_indices[c]; + const float* block_rhs = rhs + rhs_index * block_width; + // Multiply this |block_height| x |block_width| block. + for (int i = 0; i < block_height; ++i) { + for (int j = 0; j < block_width; ++j) { + accumulators[i] += weights[weight_index++] * block_rhs[j]; + } + } + } + rhs_indices += nnz; + // Apply relu if desired. + if (relu) { + for (int i = 0; i < block_height; ++i) { + if (accumulators[i] < 0) accumulators[i] = 0; + } + } + for (int r = 0; r < replicas; ++r) { + for (int i = 0; i < block_height; ++i) { + output[i + r * stride] = accumulators[i]; + } + } + } +} + +void MatVecFixedGeneric(const int16_t* weights, const int16_t* rhs, + const int32_t* bias, const int32_t* nnz_per_row, + const int16_t* rhs_indices, int start_row, int end_row, + int block_height, int block_width, bool relu, + int bytes_out, int shift_out, int replicas, int stride, + void* output) { + int weight_index = 0; + int bias_index = 0; + std::vector accumulators(block_height); + for (int row_block = start_row; row_block < end_row; ++row_block) { + int nnz = nnz_per_row[row_block]; + // Biases are now stored and used directly without pre-division. + for (int i = 0; i < block_height; ++i) accumulators[i] = bias[bias_index++]; + + for (int c = 0; c < nnz; ++c) { + int rhs_index = rhs_indices[c]; + const int16_t* block_rhs = rhs + rhs_index * block_width; + // Multiply this |block_height| x |block_width| block. + for (int i = 0; i < block_height; ++i) { + for (int j = 0; j < block_width; ++j) { + accumulators[i] += weights[weight_index++] * block_rhs[j]; + } + } + } + rhs_indices += nnz; + // Apply relu if desired. + if (relu) { + for (int i = 0; i < block_height; ++i) { + if (accumulators[i] < 0) accumulators[i] = 0; + } + } + // Output shift. + if (shift_out > 0) { + for (int i = 0; i < block_height; ++i) { + accumulators[i] >>= shift_out; + } + } + if (bytes_out == 2) { + int16_t* out16 = reinterpret_cast(output); + output = out16 + block_height; + for (int r = 0; r < replicas; ++r, out16 += stride) { + for (int i = 0; i < block_height; ++i) { + out16[i] = accumulators[i]; + } + } + } else { + int32_t* out32 = reinterpret_cast(output); + output = out32 + block_height; + for (int r = 0; r < replicas; ++r, out32 += stride) { + for (int i = 0; i < block_height; ++i) { + out32[i] = accumulators[i]; + } + } + } + } +} + +} // namespace detail +} // namespace csrblocksparse diff --git a/sparse_matmul/compute/matmul_generic.h b/sparse_matmul/compute/matmul_generic.h new file mode 100644 index 00000000..415d71cd --- /dev/null +++ b/sparse_matmul/compute/matmul_generic.h @@ -0,0 +1,41 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_MATMUL_GENERIC_H_ +#define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_MATMUL_GENERIC_H_ + +#include + +namespace csrblocksparse { +namespace detail { + +// Generic version uses plain C++ code. +void MatVecFloatGeneric(const float* weights, const float* rhs, + const float* bias, const int32_t* nnz_per_row, + const int16_t* rhs_indices, int start_row, int end_row, + int block_height, int block_width, bool relu, + int replicas, int stride, float* output); +void MatVecFixedGeneric(const int16_t* weights, const int16_t* rhs, + const int32_t* bias, const int32_t* nnz_per_row, + const int16_t* rhs_indices, int start_row, int end_row, + int block_height, int block_width, bool relu, + int bytes_out, int shift_out, int replicas, int stride, + void* output); + +} // namespace detail +} // namespace csrblocksparse + +#endif // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_MATMUL_GENERIC_H_ diff --git a/sparse_matmul/compute/thread_bounds.cc b/sparse_matmul/compute/thread_bounds.cc new file mode 100644 index 00000000..e37a395e --- /dev/null +++ b/sparse_matmul/compute/thread_bounds.cc @@ -0,0 +1,106 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "sparse_matmul/compute/thread_bounds.h" + +#include + +#include "glog/logging.h" + +namespace csrblocksparse { + +void ThreadBounds::PrepareForThreads(int block_width, int block_height, + int num_threads, + int reduced_rows_per_cache_row, + int reduced_rows, const int* nnz_per_row) { + CHECK_GT(num_threads, 0); + block_width_ = block_width; + block_height_ = block_height; + ComputeThreadSplitPoints(num_threads, reduced_rows_per_cache_row, + reduced_rows, nnz_per_row); + weight_starts_.clear(); + rhs_indices_starts_.clear(); + bias_starts_.clear(); + weight_starts_.reserve(row_starts_.size()); + rhs_indices_starts_.reserve(row_starts_.size()); + bias_starts_.reserve(row_starts_.size()); + + // Compute the start indices of each of the types, given what we know about + // padding, and number of |nnz_per_row|. + int weight_index = 0; + int rhs_indices_index = 0; + int bias_index = 0; + int row = 0; + for (int start : row_starts_) { + while (row < start) { + weight_index += nnz_per_row[row] * block_width_ * block_height_; + rhs_indices_index += nnz_per_row[row]; + bias_index += block_height_; + ++row; + } + weight_starts_.push_back(weight_index); + rhs_indices_starts_.push_back(rhs_indices_index); + bias_starts_.push_back(bias_index); + } +} + +// Computes the block row (reduced) index of the start of each thread. +void ThreadBounds::ComputeThreadSplitPoints(int num_threads, + int reduced_rows_per_cache_row, + int reduced_rows, + const int* nnz_per_row) { + row_starts_.assign(/*n=*/1, /*val=*/0); + // Break the rule if the matrix is too small to allow one per thread, which + // occurs only during tests. + if (reduced_rows_per_cache_row * num_threads > reduced_rows) + reduced_rows_per_cache_row = std::max(reduced_rows / num_threads, 1); + int cache_rows = (reduced_rows + reduced_rows_per_cache_row - 1) / + reduced_rows_per_cache_row; + + // Compute exclusive prefix sum of the amount of work per row. + std::vector work_upto_row(cache_rows + 1, 0); + int extra_row_work = 2 * reduced_rows_per_cache_row; + for (int i = 0; i < cache_rows; ++i) { + int new_nnz = 0; + for (int j = 0; j < reduced_rows_per_cache_row; ++j) { + // if |reduced_rows_per_cache_row| isn't an exact multiple of the + // matrix size, then we need to be careful here. + int index = i * reduced_rows_per_cache_row + j; + if (index < reduced_rows) new_nnz += nnz_per_row[index]; + } + work_upto_row[i + 1] = new_nnz + extra_row_work + work_upto_row[i]; + } + int total_work = work_upto_row.back(); + // Find the split point point based on assigned approximately equal amount + // of work for each thread. + int prev_split = 0; + for (int i = 1; i <= num_threads; ++i) { + int split = std::distance( + work_upto_row.begin(), + std::lower_bound(work_upto_row.begin(), work_upto_row.end(), + i * total_work / num_threads)); + int split_row = split * reduced_rows_per_cache_row; + if (i == num_threads) { + split_row = reduced_rows; + } + + VLOG(2) << "tid=" << i - 1 << " num rows=" << split_row - row_starts_.back() + << " work=" << work_upto_row[split] - work_upto_row[prev_split]; + row_starts_.push_back(split_row); + prev_split = split; + } + VLOG(2) << "total rows=" << reduced_rows << " total work=" << total_work; +} + +} // namespace csrblocksparse diff --git a/sparse_matmul/compute/thread_bounds.h b/sparse_matmul/compute/thread_bounds.h new file mode 100644 index 00000000..fd8a7d2b --- /dev/null +++ b/sparse_matmul/compute/thread_bounds.h @@ -0,0 +1,74 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_THREAD_BOUNDS_H_ +#define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_THREAD_BOUNDS_H_ + +#include + +namespace csrblocksparse { + +// Class to compute and store the bounds of each thread used in a computation, +// and to provide corresponding spans of vectors. +class ThreadBounds { + public: + ThreadBounds() : block_width_(0), block_height_(0) {} + + void PrepareForThreads(int block_width, int block_height, int num_threads, + int reduced_rows_per_cache_row, int reduced_rows, + const int* nnz_per_row); + + // Functions that offset the appropriate type to the start of the data + // needed by the given thread id (|tid|). + template + const WeightType* OffsetWeights(const WeightType* weights, int tid) const { + return weights + weight_starts_[tid]; + } + template + const RhsIndType* OffsetRhsIndices(const RhsIndType* rhs_indices, + int tid) const { + return rhs_indices + rhs_indices_starts_[tid]; + } + template + const BiasType* OffsetBias(const BiasType* bias, int tid) const { + return bias + bias_starts_[tid]; + } + template + OutType* OffsetOutput(OutType* output, int tid) const { + return output + block_height_ * row_starts_[tid]; + } + int StartRow(int tid) const { return row_starts_[tid]; } + const std::vector& row_starts() const { return row_starts_; } + + private: + // Computes the block row (reduced) index of the start of each thread. + void ComputeThreadSplitPoints(int num_threads, int reduced_rows_per_cache_row, + int reduced_rows, const int* nnz_per_row); + + // Sizes of a sparse block. + int block_width_; + int block_height_; + // Start indices of each data type by thread-id with an extra value at the + // end. + std::vector row_starts_; + std::vector weight_starts_; + std::vector rhs_indices_starts_; + std::vector bias_starts_; +}; + +} // namespace csrblocksparse + +#endif // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_THREAD_BOUNDS_H_ diff --git a/sparse_matmul/layers/BUILD b/sparse_matmul/layers/BUILD new file mode 100644 index 00000000..7c4ed36d --- /dev/null +++ b/sparse_matmul/layers/BUILD @@ -0,0 +1,146 @@ +# Sparse/Masked Matrix and Layer. + +# [internal] load android_library_selector +# [internal] load android_cc_test:def.bzl + +licenses(["notice"]) + +cc_library( + name = "layer", + hdrs = [ + "sparse_linear_layer.h", + ], + visibility = [ + "//sparse_matmul:__subpackages__", + ], + deps = [ + ":matrix", + "//sparse_matmul/numerics:types", + "//sparse_matmul/os:coop_threads", + "//sparse_matmul/vector:cache_aligned_vector", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings:str_format", + "@com_google_glog//:glog", + ], +) + +cc_library( + name = "matrix", + hdrs = [ + "csr_blocksparse_matrix.h", + "masked_sparse_matrix.h", + ], + visibility = [ + "//sparse_matmul:__subpackages__", + ], + deps = [ + "//sparse_matmul/compute:kernels", + "//sparse_matmul/compute:matmul", + "//sparse_matmul/compute:thread_bounds", + "//sparse_matmul/numerics:types", + "//sparse_matmul/os:coop_threads", + "//sparse_matmul/vector:cache_aligned_vector", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings:str_format", + "@com_google_glog//:glog", + ], +) + +cc_library( + name = "utils", + srcs = [ + "utils.cc", + ], + hdrs = [ + "read_array_ifstream.h", + "utils.h", + ], + visibility = [ + "//sparse_matmul:__subpackages__", + ], + deps = [ + ":layer", + ":matrix", + ":status", + "//sparse_matmul/numerics:types", + "//sparse_matmul/vector:cache_aligned_vector", + "//sparse_matmul/zlib_wrapper", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + "@gulrak_filesystem//:filesystem", + ], +) + +cc_library( + name = "status", + srcs = [ + "errno_mapping.cc", + ], + hdrs = [ + "errno_mapping.h", + "status_macros.h", + ], + deps = [ + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + ], +) + +cc_test( + name = "csrblocksparse_test", + size = "small", + srcs = [ + "csrblocksparse_test.cc", + ], + data = glob(["testdata/*"]), + linkopts = select({ + "@bazel_tools//platforms:android": ["-landroid"], + "//conditions:default": [], + }), + shard_count = 10, + deps = [ + ":status", + ":utils", + "//sparse_matmul/compute:matmul", + "//sparse_matmul/numerics:test_utils", + "//sparse_matmul/os:coop_threads", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + "@gulrak_filesystem//:filesystem", + ], +) + +cc_test( + name = "sparse_linear_layer_test", + srcs = [ + "sparse_linear_layer_test.cc", + ], + deps = [ + ":layer", + "//sparse_matmul/numerics:test_utils", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "utils_test", + srcs = ["utils_test.cc"], + deps = [ + ":layer", + ":matrix", + ":status", + ":utils", + "//sparse_matmul/numerics:fast_transcendentals", + "//sparse_matmul/numerics:test_utils", + "//sparse_matmul/numerics:types", + "//sparse_matmul/vector:cache_aligned_vector", + "@com_google_absl//absl/flags:flag", + "@com_google_googletest//:gtest_main", + "@gulrak_filesystem//:filesystem", + ], +) diff --git a/sparse_matmul/layers/csr_blocksparse_matrix.h b/sparse_matmul/layers/csr_blocksparse_matrix.h new file mode 100644 index 00000000..be515735 --- /dev/null +++ b/sparse_matmul/layers/csr_blocksparse_matrix.h @@ -0,0 +1,835 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LYRA_CODEC_SPARSE_MATMUL_LAYERS_CSR_BLOCKSPARSE_MATRIX_H_ +#define LYRA_CODEC_SPARSE_MATMUL_LAYERS_CSR_BLOCKSPARSE_MATRIX_H_ + +#include +#include +#include +#include +#include +#include + +#include "glog/logging.h" +// IWYU pragma: begin_exports +#include "sparse_matmul/compute/kernels_generic.h" +#include "sparse_matmul/compute/matmul.h" +#include "sparse_matmul/compute/thread_bounds.h" +#include "sparse_matmul/layers/masked_sparse_matrix.h" +#include "sparse_matmul/numerics/fixed_types.h" +#include "sparse_matmul/numerics/float16_types.h" +#include "sparse_matmul/os/coop_threads.h" +#include "sparse_matmul/vector/cache_aligned_vector.h" +// IWYU pragma: end_exports +#include "absl/memory/memory.h" + +namespace csrblocksparse { +// CsrBlockSparseMatrix stores a modified block compressed sparse row +// representation of a sparse matrix. The ordering of the weights is modified +// in the 16x1 and 1x1 cases so that a certain number (4 and 8 respectively) +// of columns of weights are stored contiguously before moving on to the next +// row. The 4x4 case stores each block contiguously. +// +// Currently it is constructed from a MaskedSparseMatrix which usees a dense +// binary mask representation. The construction generates the compressed +// representation. Further iterations will support a direct serialization +// of the compressed representation. +// +// MaskedSparseMatrix masked_matrix(rows, cols, existing_mask, existing_values) +// CsrBlockSparseMatrix matrix(masked_matrix) +// +// matrix.SpMV_bias(rhs, bias, &out); +// +// This class is thread compatible. +template +class CsrBlockSparseMatrix { + public: + CsrBlockSparseMatrix() {} + + // Reference used to indicate that this is an input and not an output. + CsrBlockSparseMatrix(const uint8_t* const& buffer, const std::size_t& len) { + ReadFromFlatBuffer(buffer, len); + ComputeRHSIndices(); + } + + template + CsrBlockSparseMatrix(const MaskedSparseMatrix& masked_matrix) { + sparsity_ = masked_matrix.sparsity(); + rows_ = masked_matrix.rows(); + cols_ = masked_matrix.cols(); + + DetermineBlockSize(masked_matrix); + + if (block_width_ == 1 && block_height_ == 1) + col_multiple_ = 8; + else + col_multiple_ = 1; + + std::vector weights(masked_matrix.values().begin(), + masked_matrix.values().end()); + + reduced_rows_ = (rows_ + block_height_ - 1) / block_height_; + rows_ = reduced_rows_ * block_height_; + reduced_cols_ = cols_ / block_width_; + + // Calculate the reduced CSR representation of the matrix. + std::vector reduced_mask(reduced_rows_ * reduced_cols_); + std::vector row_offsets = {0}; + int nnz = 0; + const auto& mask = masked_matrix.mask(); + for (int r = 0; r < reduced_rows_; ++r) { + for (int c = 0; c < reduced_cols_; ++c) { + int mask_val = mask[r * block_height_ * cols_ + c * block_width_]; + reduced_mask[r * reduced_cols_ + c] = mask_val; + nnz += mask_val; + } + row_offsets.push_back(nnz); + } + + // Make sure the reduced representation has the correct number of columns. + MakeColumnsMultiple(row_offsets, &reduced_mask, &weights); + + std::vector col_indices; + std::vector weights_csr; + std::vector nnz_per_row; + MaskAndWeightsToCsr(reduced_mask, weights, &nnz_per_row, &col_indices, + &weights_csr); + + // Generate column deltas from |col_indices|. + std::vector col_deltas; + for (int i = 0; i < col_indices.size(); ++i) { + // |col_indices| are used to index the RHS vector which is always float. + int64_t diff = sizeof(RhsType); + if (i == 0) + diff *= block_width_ * (col_indices[i]); + else + diff *= block_width_ * (col_indices[i] - col_indices[i - 1]); + + CHECK(diff < std::numeric_limits::max()) + << "delta between column indices in bytes " << diff + << " exceeded the maximum size of the DeltaType " + << std::numeric_limits::max(); + col_deltas.push_back(static_cast(diff)); + } + + // Because of pre-fetching we need some extra values at the end. + col_deltas.insert(col_deltas.end(), std::max(2, col_multiple_ + 1), 0); + nnz_per_row.insert(nnz_per_row.end(), 2, nnz_per_row.back()); + + weights_ = CacheAlignedVector(weights_csr); + col_deltas_ = CacheAlignedVector(col_deltas); + nnz_per_row_ = CacheAlignedVector(nnz_per_row); + ComputeRHSIndices(); + + num_threads_ = 0; + PrepareForThreads(1); + } + + // Constructor makes a matrix from the given weights, deltas and nnz, taking + // the other parameters from |src_matrix|. |cols| is the number of raw columns + // (NOT blocks) of the new matrix. + CsrBlockSparseMatrix( + const CsrBlockSparseMatrix& src_matrix, + const std::vector& new_weights, + const std::vector& new_deltas, const std::vector& new_nnz, + int cols) { + num_threads_ = 0; + col_multiple_ = src_matrix.col_multiple_; + block_width_ = src_matrix.block_width_; + block_height_ = src_matrix.block_height_; + reduced_rows_ = new_nnz.size(); + rows_ = reduced_rows_ * block_height_; + cols_ = cols; + reduced_cols_ = cols_ / block_width_; + weights_ = CacheAlignedVector(new_weights); + col_deltas_ = CacheAlignedVector(new_deltas); + nnz_per_row_ = CacheAlignedVector(new_nnz); + sparsity_ = 1.0f - static_cast(new_weights.size()) / (rows_ * cols_); + ComputeRHSIndices(); + name_ = src_matrix.name_; + PrepareForThreads(1); + } + + // Factory method takes a column slice out of *this and returns a sparse + // matrix that takes as inputs [|start_col|, |end_col|) of *this, and + // returns the same number of outputs, but only a partial result. + // If |keep_rhs_size|, then the new matrix takes the same rhs as the current + // matrix, but uses a subset of it, instead of expecting just the reduced rhs. + // If |start_col| > |end_col|, then we slice out the complement of the defined + // interval, ie [0, |end_col|) + [|start_col|, current end). + // NOTE That |start_col| and |end_col| are in raw column coordinates, NOT + // block units. + CsrBlockSparseMatrix SplitByColumn(int start_col, int end_col, + bool keep_rhs_size = false) const { + int weight_index = 0; + int delta_index = 0; + std::vector new_deltas; + std::vector new_weights; + std::vector new_nnz(reduced_rows_); + int col = 0; + int prev_col = keep_rhs_size ? 0 : start_col; + for (int r = 0; r < reduced_rows_; ++r) { + int reduced_col_count = nnz_per_row_[r]; + for (int c = 0; c < reduced_col_count; ++c, ++delta_index) { + col += col_deltas_[delta_index] / sizeof(RhsType); + if ((start_col < end_col && start_col <= col && col < end_col) || + (start_col > end_col && (col < end_col || col >= start_col))) { + ++new_nnz[r]; + new_deltas.push_back((col - prev_col) * sizeof(RhsType)); + prev_col = col; + for (int i = 0; i < block_width_ * block_height_; + ++i, ++weight_index) { + new_weights.push_back(weights_[weight_index]); + } + } else { + weight_index += block_width_ * block_height_; + } + } + } + int new_cols = keep_rhs_size ? cols_ : end_col - start_col; + return CsrBlockSparseMatrix(*this, new_weights, new_deltas, new_nnz, + new_cols); + } + + // Factory method takes a row slice out of *this and returns a sparse + // matrix that takes the sampe inputs as *this, and returns the outputs for + // the range [|start_row|, |end_row|). + // NOTE That |start_row| and |end_row| are in raw column coordinates, NOT + // block units. + CsrBlockSparseMatrix SplitByRow(int start_row, int end_row) const { + int start_reduced = start_row / block_height_; + int end_reduced = end_row / block_height_; + std::vector new_nnz(nnz_per_row_.data() + start_reduced, + nnz_per_row_.data() + end_reduced); + int weight_start = 0; + for (int r = 0; r < start_reduced; ++r) { + weight_start += nnz_per_row_[r]; + } + int weight_end = weight_start; + for (int r = start_reduced; r < end_reduced; ++r) { + weight_end += nnz_per_row_[r]; + } + int delta_start = 0; + for (int i = 0; i < weight_start; ++i) { + delta_start += col_deltas_[i]; + } + std::vector new_deltas(col_deltas_.data() + weight_start, + col_deltas_.data() + weight_end); + new_deltas[0] += delta_start; + int block_size = block_height_ * block_width_; + std::vector new_weights( + weights_.data() + weight_start * block_size, + weights_.data() + weight_end * block_size); + return CsrBlockSparseMatrix(*this, new_weights, new_deltas, new_nnz, cols_); + } + + // Combines adjacent row blocks, doubling the block height. + // This necessarily involves adding zero weights where the blocks don't align + // across adjacent pairs of rows, so use with caution, as the resulting matrix + // is most likely to run slower if very sparse to begin with. + // In the few cases where the blocks do mostly align, the resulting matmul + // could be much faster, as the number of reads of the rhs will be halved. + void DoubleBlockHeight() { + int new_rows = reduced_rows_ / 2; + std::vector new_nnz(new_rows); + std::vector new_rhs_indices; + std::vector new_weights; + int rhs_index1 = 0; + int rhs_index2 = 0; + int block_size = block_height_ * block_width_; + for (int r = 0; r < new_rows; ++r) { + int start_nnz = new_rhs_indices.size(); + rhs_index2 += nnz_per_row_[r * 2]; + int end1 = rhs_index1 + nnz_per_row_[r * 2]; + int end2 = rhs_index2 + nnz_per_row_[r * 2 + 1]; + // Run over a pair of rows with 2 iterators, combining blocks as we go, or + // padding with zeros where the block positions don't match. + while (rhs_index1 < end1 || rhs_index2 < end2) { + int col1 = rhs_index1 < end1 ? rhs_indices_[rhs_index1] : reduced_cols_; + int col2 = rhs_index2 < end2 ? rhs_indices_[rhs_index2] : reduced_cols_; + if (col1 < col2) { + // Need zero weights for row2 to pad out weights block. + new_rhs_indices.push_back(col1); + new_weights.insert(new_weights.end(), + weights_.data() + rhs_index1 * block_size, + weights_.data() + (rhs_index1 + 1) * block_size); + new_weights.insert(new_weights.end(), block_size, + static_cast(0.0f)); + ++rhs_index1; + } else if (col1 > col2) { + // Need zero weights for row1 to pad out weights block. + new_rhs_indices.push_back(col2); + new_weights.insert(new_weights.end(), block_size, + static_cast(0.0f)); + new_weights.insert(new_weights.end(), + weights_.data() + rhs_index2 * block_size, + weights_.data() + (rhs_index2 + 1) * block_size); + ++rhs_index2; + } else { + // Combine weights for both row1 and row2. + new_rhs_indices.push_back(col1); + new_weights.insert(new_weights.end(), + weights_.data() + rhs_index1 * block_size, + weights_.data() + (rhs_index1 + 1) * block_size); + new_weights.insert(new_weights.end(), + weights_.data() + rhs_index2 * block_size, + weights_.data() + (rhs_index2 + 1) * block_size); + ++rhs_index1; + ++rhs_index2; + } + } + rhs_index1 = rhs_index2; + new_nnz[r] = new_rhs_indices.size() - start_nnz; + } + block_height_ *= 2; + reduced_rows_ /= 2; + weights_ = CacheAlignedVector(new_weights); + rhs_indices_ = CacheAlignedVector(new_rhs_indices); + nnz_per_row_ = CacheAlignedVector(new_nnz); + sparsity_ = 1.0f - static_cast(new_weights.size()) / (rows_ * cols_); + ComputeColDeltas(); + if (num_threads_ > 0) { + int num_threads = num_threads_; + num_threads_ = 0; + PrepareForThreads(num_threads); + } + } + + // Allocates memory and fills buffer. + // Caller is responsible for the memory de-allocation. + // TODO(b/189958858): Both Read and Write need to eventually handle the + // different possible HalfType and DeltaType values, but punting for now as + // there is only one supported combination. + std::size_t WriteToFlatBuffer(std::string* csr_flatbuffer) { + std::size_t bytes = 0; + bytes += FixedParameterSize(); + bytes += weights_.size() * sizeof(WeightType); + bytes += col_deltas_.size() * sizeof(DeltaType); + bytes += nnz_per_row_.size() * sizeof(int); + + uint8_t* bytes_ptr_ptr = + reinterpret_cast(CHECK_NOTNULL(malloc(bytes))); + + int* int_bytes_ptr = reinterpret_cast(bytes_ptr_ptr); + + *int_bytes_ptr++ = rows_; + *int_bytes_ptr++ = cols_; + *int_bytes_ptr++ = reduced_rows_; + *int_bytes_ptr++ = reduced_cols_; + *int_bytes_ptr++ = block_width_; + *int_bytes_ptr++ = block_height_; + *int_bytes_ptr++ = col_multiple_; + *int_bytes_ptr++ = num_threads_; + *int_bytes_ptr++ = weights_.size(); + *int_bytes_ptr++ = col_deltas_.size(); + *int_bytes_ptr++ = nnz_per_row_.size(); + + float* float_bytes_ptr = reinterpret_cast(int_bytes_ptr); + *float_bytes_ptr++ = sparsity_; + + uint8_t* bytes_ptr = reinterpret_cast(float_bytes_ptr); + + memcpy(bytes_ptr, weights_.data(), weights_.size() * sizeof(WeightType)); + bytes_ptr += weights_.size() * sizeof(WeightType); + + memcpy(bytes_ptr, col_deltas_.data(), + col_deltas_.size() * sizeof(DeltaType)); + bytes_ptr += col_deltas_.size() * sizeof(DeltaType); + + memcpy(bytes_ptr, nnz_per_row_.data(), nnz_per_row_.size() * sizeof(int)); + bytes_ptr += nnz_per_row_.size() * sizeof(int); + + csr_flatbuffer->resize(bytes); + csr_flatbuffer->assign(reinterpret_cast(bytes_ptr_ptr), bytes); + free(bytes_ptr_ptr); + + return bytes; + } + + void ReadFromFlatBuffer(const uint8_t* const& bytes, const std::size_t& len) { + CHECK_GE(len, FixedParameterSize()); + + const int* int_bytes_ptr = reinterpret_cast(bytes); + rows_ = *int_bytes_ptr++; + cols_ = *int_bytes_ptr++; + reduced_rows_ = *int_bytes_ptr++; + reduced_cols_ = *int_bytes_ptr++; + block_width_ = *int_bytes_ptr++; + block_height_ = *int_bytes_ptr++; + col_multiple_ = *int_bytes_ptr++; + int num_threads = *int_bytes_ptr++; + int32_t weights_size = *int_bytes_ptr++; + int32_t col_deltas_size = *int_bytes_ptr++; + int32_t nnz_per_row_size = *int_bytes_ptr++; + + // Make sure negative sizes don't mess things up. + weights_size = std::max(0, weights_size); + col_deltas_size = std::max(0, col_deltas_size); + nnz_per_row_size = std::max(0, nnz_per_row_size); + + const float* float_bytes_ptr = + reinterpret_cast(int_bytes_ptr); + sparsity_ = *float_bytes_ptr++; + + std::size_t total_bytes = + FixedParameterSize() + weights_size * sizeof(WeightType) + + col_deltas_size * sizeof(DeltaType) + nnz_per_row_size * sizeof(int); + + CHECK_EQ(total_bytes, len) + << "total bytes: " << total_bytes << ", actual len given: " << len; + + const uint8_t* bytes_ptr = + reinterpret_cast(float_bytes_ptr); + std::vector weights_raw(weights_size); + memcpy(weights_raw.data(), bytes_ptr, weights_size * sizeof(WeightType)); + weights_ = CacheAlignedVector(weights_raw); + bytes_ptr += weights_size * sizeof(WeightType); + + std::vector deltas_raw(col_deltas_size); + memcpy(deltas_raw.data(), bytes_ptr, col_deltas_size * sizeof(DeltaType)); + col_deltas_ = CacheAlignedVector(deltas_raw); + bytes_ptr += col_deltas_size * sizeof(DeltaType); + + std::vector nnz_raw(nnz_per_row_size); + memcpy(nnz_raw.data(), bytes_ptr, nnz_per_row_size * sizeof(int)); + nnz_per_row_ = CacheAlignedVector(nnz_raw); + num_threads_ = 0; + PrepareForThreads(num_threads); + } + + // Multiply a Sparse matrix by a possibly dense matrix. Often the matrix is + // a vector with a small number of columns, hence the term "fat vector". + // 1x1 and 4x4 have specializations for output columns (ie fatness) > 5, + // and often achieve twice as many GFlops when multiplying a right hand side + // that has 5 or more columns. (Best is a multiple of 5). + // 16x1 doesn't have enough registers and just loops over the width 1 kernel. + // + // |rhs| and |out| are COLUMN MAJOR. + + // Fast Tuples WeightType, BiasType, RhsType, OutType are: + // (float, float, float, float) + // (bfloat16, float, float, float) + // and only on ARM64. All other cases use a slow generic implementation. + template + void SpMM_bias(const RhsClass& rhs, const BiasClass& bias, OutClass* out, + bool relu = false, int tid = 0, + SpinBarrier* barrier = nullptr) const { + static_assert(std::is_same::value, + "Rhs types must match"); + CHECK_LT(tid, num_threads_); + CHECK_EQ(rhs.cols(), out->cols()); + CHECK_EQ(rhs.rows(), cols_); + CHECK_GE(out->rows(), rows_); + int cols_to_go = out->cols(); + int rhs_index = *thread_bounds_.OffsetRhsIndices(rhs_indices_.data(), tid); + const RhsType* rhs_ptr = rhs.data() + rhs_index * block_height_; + OutType* out_ptr = thread_bounds_.OffsetOutput(out->data(), tid); + const WeightType* weights_ptr = + thread_bounds_.OffsetWeights(weights_.data(), tid); + const DeltaType* delta_ptr = + thread_bounds_.OffsetRhsIndices(col_deltas_.data(), tid); + int offset = *delta_ptr / sizeof(RhsType); + rhs_ptr -= offset; + const int* nnz_ptr = nnz_per_row_.data() + thread_bounds_.StartRow(tid); + int assigned_rows = + thread_bounds_.StartRow(tid + 1) - thread_bounds_.StartRow(tid); + const BiasType* bias_ptr = thread_bounds_.OffsetBias(bias.data(), tid); + + while (cols_to_go > 0) { + if (block_width_ == 4 && block_height_ == 4) { + if (cols_to_go >= 5) { + detail::SpMM5_4x4( + weights_ptr, delta_ptr, nnz_ptr, rhs_ptr, bias_ptr, out_ptr, + assigned_rows, out->col_stride(), rhs.col_stride(), relu); + } else { + detail::SpMV_4x4( + weights_ptr, delta_ptr, nnz_ptr, rhs_ptr, bias_ptr, out_ptr, + assigned_rows, out->col_stride(), rhs.col_stride(), relu); + } + } else { + if (cols_to_go >= 5) { + detail::SpMM5_1x1( + weights_ptr, delta_ptr, nnz_ptr, rhs_ptr, bias_ptr, out_ptr, + assigned_rows, out->col_stride(), rhs.col_stride(), relu); + } else { + detail::SpMV_1x1( + weights_ptr, delta_ptr, nnz_ptr, rhs_ptr, bias_ptr, out_ptr, + assigned_rows, out->col_stride(), rhs.col_stride(), relu); + } + } + + if (cols_to_go >= 5) { + cols_to_go -= 5; + rhs_ptr += rhs.col_stride() * 5; + out_ptr += out->col_stride() * 5; + } else { + cols_to_go--; + rhs_ptr += rhs.col_stride(); + out_ptr += out->col_stride(); + } + if (barrier) barrier->barrier(); + } + } + template + void MatVec(const MVRhsType* rhs, const MVBiasType* bias, bool relu, int tid, + int replicas, int output_stride, OutType* output) { + CHECK_LT(tid, num_threads_); + CHECK_EQ(block_width_, 4) << "Block width must be 4!"; + if (block_height_ == 8) { + matmul_.MatVec8x4( + thread_bounds_.OffsetWeights(weights_.cast_data(), tid), rhs, + thread_bounds_.OffsetBias(bias, tid), nnz_per_row_.data(), + thread_bounds_.OffsetRhsIndices(rhs_indices_.data(), tid), + thread_bounds_.StartRow(tid), thread_bounds_.StartRow(tid + 1), relu, + replicas, output_stride, thread_bounds_.OffsetOutput(output, tid)); + } else { + CHECK_EQ(block_height_, 4) << "Block height must be 4 or 8!"; + matmul_.MatVec4x4( + thread_bounds_.OffsetWeights(weights_.cast_data(), tid), rhs, + thread_bounds_.OffsetBias(bias, tid), nnz_per_row_.data(), + thread_bounds_.OffsetRhsIndices(rhs_indices_.data(), tid), + thread_bounds_.StartRow(tid), thread_bounds_.StartRow(tid + 1), relu, + replicas, output_stride, thread_bounds_.OffsetOutput(output, tid)); + } + } + + int rows() const { return rows_; } + int cols() const { return cols_; } + int block_height() const { return block_height_; } + int block_width() const { return block_width_; } + float sparsity() const { return sparsity_; } + int num_threads() const { return num_threads_; } + const ThreadBounds& thread_bounds() const { return thread_bounds_; } + const CacheAlignedVector& rhs_indices() const { + return rhs_indices_; + } + const std::string& name() const { return name_; } + void set_name(const std::string& name) { name_ = name; } + const std::vector& split_points() const { + return thread_bounds_.row_starts(); + } + + std::size_t bytes() const { + return weights_.size() * sizeof(WeightType) + + col_deltas_.size() * sizeof(DeltaType) + + nnz_per_row_.size() * sizeof(int); + } + + // Multiplies a sparse matrix by a possibly dense matrix, as SpMM_bias above, + // and then samples from the output (softmax distribution) layer. + template + typename std::enable_if::value, int>::type + SpMM_bias_Sample(const RhsClass& rhs, const BiasClass& bias, OutClass* out, + float temperature, int tid, SpinBarrier* barrier, + std::minstd_rand* gen, + CacheAlignedVector* scratch) const { + SpMM_bias(rhs, bias, out, /*relu=*/false, tid, barrier); + return out->Sample(temperature, gen, scratch); + } + // Fixed32 version. + template + typename std::enable_if::value, int>::type + SpMM_bias_Sample(const RhsClass& rhs, const BiasClass& bias, OutClass* out, + float temperature, int tid, SpinBarrier* barrier, + std::minstd_rand* gen, + CacheAlignedVector* scratch) const { + // We don't pass the barrier on, as we have more work to do. + SpMM_bias(rhs, bias, out, /*relu=*/false, tid); + return out->ReducingSample(gen, scratch, tid, temperature, barrier); + } + + void Print() const { + std::cout << "Weights\n"; + weights_.Print(); + std::cout << std::endl; + std::cout << "Deltas\n"; + col_deltas_.Print(); + std::cout << std::endl; + std::cout << "nnz\n"; + nnz_per_row_.Print(); + std::cout << std::endl; + } + + // Split the computation amongst threads by rows based on the number of + // non zeros, with the addition of a constant to account for the work of the + // bias and the horizontal add at the end, and also guarantees that each + // thread writes only whole cache lines, based on the size of OutType. + // The |cache_line_size| arg is used only for testing. Normally it is provided + // through the architecture #defines. + // Each thread gets a contiguous row range (|split_points|). + // Thread t does rows [ split_points[t], split_points[t + 1] ) + // Each thread also needs to know how many non zeros were before it to skip + // (|nnz_to_skip|). And finally it also needs to know what the offset into + // the rhs vector would have been at the split point (|rhs_to_skip|). + // + // Some tricky corner cases where the number of non-zeros doesn't split + // nicely amongst the number of requested threads are not handled and default + // to one thread; these cases are only going to happen in tests and not in + // the matrices that correspond in real models. + // + // Returns the maximum number of threads that can be used; <= |num_threads|. + template + int PrepareForThreads(int num_threads, int cache_line_size = -1) { + CHECK_GT(num_threads, 0); + // we've already prepared for this number of threads, nothing to do + if (num_threads == num_threads_) return num_threads_; + + num_threads_ = num_threads; + thread_bounds_.PrepareForThreads( + block_width_, block_height_, num_threads_, + ReducedRowsPerCacheLine(cache_line_size), reduced_rows_, + nnz_per_row_.data()); + return num_threads_; + } + + // Computes and stores the |rhs_indices_| from the |col_deltas_|. + void ComputeRHSIndices() { + std::vector cumulative_deltas = CumulativeColDeltas(); + std::vector rhs_indices(cumulative_deltas.size() + + reduced_rows_); + int total_indices = 0; + int delta_index = 0; + for (int r = 0; r < reduced_rows_; ++r) { + for (int n = 0; n < nnz_per_row_[r]; ++n, ++delta_index) { + rhs_indices[total_indices++] = + cumulative_deltas[delta_index] / block_width_; + } + } + rhs_indices_ = CacheAlignedVector(rhs_indices); + } + + // Computes and stores the |col_deltas_| from the |rhs_indices_|. + void ComputeColDeltas() { + std::vector col_deltas(rhs_indices_.size()); + int prev_index = 0; + for (int i = 0; i < rhs_indices_.size(); ++i) { + int offset = rhs_indices_[i] - prev_index; + prev_index = rhs_indices_[i]; + col_deltas[i] = offset * block_width_ * sizeof(RhsType); + } + col_deltas_ = CacheAlignedVector(col_deltas); + } + + // Computes and returns the inclusive prefix sum of the deltas, ie absolute + // positions. + std::vector CumulativeColDeltas() const { + std::vector cum_col_deltas(col_deltas_.size()); + for (int i = 0; i < col_deltas_.size(); ++i) { + cum_col_deltas[i] = col_deltas_[i] / sizeof(RhsType); + if (i > 0) cum_col_deltas[i] += cum_col_deltas[i - 1]; + } + return cum_col_deltas; + } + + private: + constexpr std::size_t FixedParameterSize() const { + return sizeof(int) // rows + + sizeof(int) // cols + + sizeof(int) // reduced_rows + + sizeof(int) // reduced_cols + + sizeof(int) // block_width + + sizeof(int) // block_height + + sizeof(float) // sparsity + + sizeof(int) // col_multiple + + sizeof(int) // num_threads_ + + sizeof(int) // weights_.size() + + sizeof(int) // col_deltas_.size() + + sizeof(int); // nnz_per_row_.size() + } + // Possible block sizes are only those that are supported by the computation + // default is 1x1, other options are 4x4 and 16x1. + template + void DetermineBlockSize(const MaskedSparseMatrix& masked_matrix) { + const std::vector> kPreferredOrder = {{4, 4}}; + int rows = masked_matrix.rows(); + int cols = masked_matrix.cols(); + + for (const auto& block_size : kPreferredOrder) { + int block_height, block_width; + std::tie(block_height, block_width) = block_size; + if (cols % block_width != 0) continue; + + int reduced_rows = (rows + block_height - 1) / block_height; + int reduced_cols = cols / block_width; + + // For each possible block, confirm that it is either all 0s or all 1s. + bool all_same = true; + const auto& mask = masked_matrix.mask(); + for (int r = 0; r < reduced_rows; ++r) { + for (int c = 0; c < reduced_cols; ++c) { + int val = mask[r * block_height * cols + c * block_width]; + for (int i = 0; i < block_height; ++i) { + for (int j = 0; j < block_width; ++j) { + int index = (r * block_height + i) * cols + c * block_width + j; + if (index < masked_matrix.mask().size()) { + all_same &= (masked_matrix.mask()[index] == val); + } + } + } + } + } + + // If this block configuration is possible, accept it. + if (all_same) { + block_height_ = block_height; + block_width_ = block_width; + return; + } + } + + // No large blocks were found, default to 1x1. + block_height_ = 1; + block_width_ = 1; + } + + // CSR descriptors are for the reduced matrix, weights is the full matrix. + template + void MakeColumnsMultiple(const std::vector& row_offsets, + std::vector* reduced_mask, + std::vector* weights) { + if (col_multiple_ > 0) { + // Make sure each row has a number of columns that is a multiple of + // |col_multiple|. + for (int r = 1; r < row_offsets.size(); ++r) { + int num_row = row_offsets[r] - row_offsets[r - 1]; + int num_needed = col_multiple_ - num_row % col_multiple_; + if (num_needed < col_multiple_) { + // Find gaps in the columns where we can insert a column of 0 weights. + int num_added = 0; + for (int c = 0; c < reduced_cols_; ++c) { + if ((*reduced_mask)[(r - 1) * reduced_cols_ + c] == 0) { + (*reduced_mask)[(r - 1) * reduced_cols_ + c] = 1; + + // Zero out the weights that correspond to this block. + for (int i = 0; i < block_height_; ++i) { + for (int j = 0; j < block_width_; ++j) { + (*weights)[((r - 1) * block_height_ + i) * cols_ + + block_width_ * c + j] = InputType(0.f); + } + } + num_added++; + } + + if (num_added == num_needed) break; + } + } + } + } + } + + // Given the final dense mask and weights, convert to the compressed + // block CSR representation. + template + void MaskAndWeightsToCsr(const std::vector& mask, + const std::vector& weights, + std::vector* nnz_per_row, + std::vector* col_indices, + std::vector* weights_csr) { + std::vector row_offsets = {0}; + int nnz = 0; + // Standard CSR format. + if (block_width_ == 1 && block_height_ == 1) { + for (int r = 0; r < rows_; ++r) { + for (int c = 0; c < cols_; ++c) { + if (mask[r * cols_ + c] == 1) { + nnz++; + col_indices->push_back(c); + weights_csr->push_back(WeightType(weights[r * cols_ + c])); + } + } + row_offsets.push_back(nnz); + } + } else if (block_width_ == 4 && block_height_ == 4) { + // Weights are stored contiguously for each block in this case. + for (int r = 0; r < reduced_rows_; ++r) { + for (int c = 0; c < reduced_cols_; ++c) { + if (mask[r * reduced_cols_ + c] == 1) { + col_indices->push_back(c); + nnz++; + for (int i = 0; i < block_height_; ++i) { + for (int j = 0; j < block_width_; ++j) { + int row_index = (block_height_ * r + i) * cols_; + int w_index = row_index + block_width_ * c + j; + WeightType weight = w_index < weights.size() + ? WeightType(weights[w_index]) + : WeightType(0.0f); + weights_csr->push_back(weight); + } + } + } + } + row_offsets.push_back(nnz); + } + } + for (int i = 1; i < row_offsets.size(); ++i) + nnz_per_row->push_back(row_offsets[i] - row_offsets[i - 1]); + } + + // Returns the number of block rows per cache line. This is the minimum unit + // into which the calculation is broken for threads. + template + int ReducedRowsPerCacheLine(int override_cache_line_size = -1) const { + int line_size = kCacheLineSize; + if (override_cache_line_size >= 1) line_size = override_cache_line_size; + return std::max(line_size / (block_height_ * sizeof(OutType)), 1); + } + + int col_multiple_; + int rows_; + int cols_; + int reduced_rows_; + int reduced_cols_; + float sparsity_; + int block_width_; + int block_height_; + int num_threads_; + std::string name_; + + CacheAlignedVector weights_; + CacheAlignedVector col_deltas_; + CacheAlignedVector nnz_per_row_; + // |thread_bounds_| and |rhs_indices_| don't need to be serialized as they are + // always recalculated from serialized data. + CacheAlignedVector rhs_indices_; + Matmul matmul_; + ThreadBounds thread_bounds_; + static constexpr int kCacheLineSize = 64; +}; + +// Converts a sparse matrix represented with (|mask|, |weights|, |size|) into +// the CSR format, and returns that as a serialized string. +template +std::string ConvertDenseToSparseRepresentation_Int16Deltas( + const std::vector& mask, const std::vector& weights, + const int rows, const int cols) { + MaskedSparseMatrix masked_weights(rows, cols, mask.data(), + weights.data()); + CsrBlockSparseMatrix + sparse_masked_weights(masked_weights); + std::string buffer; + sparse_masked_weights.WriteToFlatBuffer(&buffer); + return buffer; +} + +} // namespace csrblocksparse +#endif // LYRA_CODEC_SPARSE_MATMUL_LAYERS_CSR_BLOCKSPARSE_MATRIX_H_ diff --git a/sparse_matmul/layers/csrblocksparse_test.cc b/sparse_matmul/layers/csrblocksparse_test.cc new file mode 100644 index 00000000..08a42ca3 --- /dev/null +++ b/sparse_matmul/layers/csrblocksparse_test.cc @@ -0,0 +1,977 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +// Placeholder for get runfiles header. +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "gtest/gtest.h" +#include "include/ghc/filesystem.hpp" +#include "sparse_matmul/compute/matmul.h" +#include "sparse_matmul/layers/utils.h" +#include "sparse_matmul/numerics/test_utils.h" +#include "sparse_matmul/os/coop_threads.h" + +namespace csrblocksparse { +namespace { + +inline constexpr absl::string_view kTestdataPath = "layers/testdata"; + +TEST(CSRBlockSparseMatrix, FlatBufferSerialization) { + const int kRows = 8; + const int kCols = 8; + std::vector mask = {1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, + 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, + 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, + 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1}; + std::vector values(kRows * kCols, 1.f); + values[1] = 2.f; + values[3] = 3.f; + values[36] = -1.f; + values[45] = -2.f; + + csrblocksparse::CacheAlignedVector bias(kRows); + csrblocksparse::CacheAlignedVector rhs(kCols); + csrblocksparse::CacheAlignedVector out_ref(kRows); + csrblocksparse::CacheAlignedVector out_test(kRows); + + bias.FillZero(); + rhs.FillOnes(); + + csrblocksparse::MaskedSparseMatrix matrix(kRows, kCols, mask.data(), + values.data()); + + matrix.SpMM_bias(rhs, bias, &out_ref); + + csrblocksparse::CsrBlockSparseMatrix + block_sparse_matrix(matrix); + + std::string buffer; + std::size_t num_bytes = block_sparse_matrix.WriteToFlatBuffer(&buffer); + + csrblocksparse::CsrBlockSparseMatrix + new_block_sparse_matrix(reinterpret_cast(buffer.c_str()), + num_bytes); + + new_block_sparse_matrix.SpMM_bias(rhs, bias, &out_test); + + CheckResult(out_ref, out_test, kCols); +} + +template +void CorrectnessCheckBlockSpMM(int rows, int cols, int block_height, + int block_width, float sparsity, + bool use_relu = false, int num_threads = 1, + int fatness = 1, bool test_matmul = false) { + using BiasType = typename TypeOfProduct::type; + MaskedSparseMatrix matrix(rows, cols, sparsity, block_height, + block_width); + matrix.CastWeights(); + FatCacheAlignedVector rhs(cols, fatness); + CacheAlignedVector bias(rows); + FatCacheAlignedVector out(rows, fatness); + + bias.FillRandom(); + rhs.FillRandom(); + out.FillZero(); + FatCacheAlignedVector out_reference = out; + + matrix.SpMM_bias(rhs, bias, &out_reference, use_relu); + + CsrBlockSparseMatrix sparse_matrix(matrix); + + SparseLinearLayer sparse_linear_layer( + std::move(sparse_matrix), std::move(bias)); + num_threads = sparse_linear_layer.PrepareForThreads(num_threads); + + // Checks that the result of applying each thread's portion serially is + // correct. + for (int thread_id = 0; thread_id < num_threads; ++thread_id) { + sparse_linear_layer.SpMM_bias(rhs, &out, use_relu, thread_id); + } + + CheckResult(out_reference, out, sparse_linear_layer.cols()); + + if (test_matmul) { + for (int thread_id = 0; thread_id < num_threads; ++thread_id) { + sparse_linear_layer.MatVec(rhs, use_relu, thread_id, + /*replicas=*/1, /*output_stride=*/0, &out); + } + + CheckResult(out_reference, out, sparse_linear_layer.cols()); + } +} + +// Does: +// y = Ax + b; +// x = Ay + b; +// y = Ax + b; +// +// to make sure that dependent multiplies are correct. +template +void ThreadBody( + SpinBarrier* spin_barrier, int tid, + const SparseLinearLayer& sparse_linear_layer, + FatCacheAlignedVector* rhs, FatCacheAlignedVector* out, + bool use_relu) { + sparse_linear_layer.SpMM_bias(*rhs, out, use_relu, tid); + spin_barrier->barrier(); + sparse_linear_layer.SpMM_bias(*out, rhs, use_relu, tid); + spin_barrier->barrier(); + sparse_linear_layer.SpMM_bias(*rhs, out, use_relu, tid); +} + +template +void CorrectnessCheckBlockSpMM_MultiThread(int rows, int cols, int block_height, + int block_width, float sparsity, + bool use_relu = false, + int num_threads = 1, + int fatness = 1) { + typedef typename TypeOfProduct::type BiasType; + CHECK(rows == cols); + MaskedSparseMatrix matrix(rows, cols, sparsity, block_height, + block_width); + matrix.CastWeights(); + FatCacheAlignedVector rhs(cols, fatness); + FatCacheAlignedVector rhs_mt(cols, fatness); + CacheAlignedVector bias(rows); + FatCacheAlignedVector out(rows, fatness); + + bias.FillOnes(); + rhs.FillOnes(); + rhs_mt.FillOnes(); + out.FillZero(); + FatCacheAlignedVector out_reference = out; + + matrix.SpMM_bias(rhs, bias, &out_reference, use_relu); + matrix.SpMM_bias(out_reference, bias, &rhs, use_relu); + matrix.SpMM_bias(rhs, bias, &out_reference, use_relu); + + CsrBlockSparseMatrix sparse_matrix(matrix); + + num_threads = sparse_matrix.PrepareForThreads(num_threads, + /*cache_line_size=*/1); + + SparseLinearLayer sparse_linear_layer( + std::move(sparse_matrix), std::move(bias)); + + csrblocksparse::LaunchOnThreadsWithBarrier( + num_threads, ThreadBody, + sparse_linear_layer, &rhs_mt, &out, use_relu); + + CheckResult(out_reference, out, cols); +} + +} // namespace + +TEST(MaskedSparseCorrectness, HandCoded) { + const int kRows = 8; + const int kCols = 8; + // clang-format off + std::vector mask = {1, 1, 0, 0, 0, 1, 1, 1, + 0, 1, 0, 1, 0, 1, 0, 1, + 1, 0, 0, 1, 1, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, + 0, 0, 0, 0, 1, 1, 0, 0, + 1, 1, 0, 0, 1, 1, 0, 0, + 1, 0, 0, 0, 0, 1, 0, 1}; + // clang-format on + std::vector values(kRows * kCols, 1.f); + + std::vector answer = {6.f, 5.f, 6.f, 1.f, 9.f, 3.f, 5.f, 4.f}; + + MaskedSparseMatrix matrix(kRows, kCols, mask.data(), values.data()); + CacheAlignedVector rhs(kCols); + CacheAlignedVector bias(kRows); + CacheAlignedVector out(kRows); + + bias.FillOnes(); + rhs.FillOnes(); + out.FillZero(); + + MaskedLinearLayer masked_linear_layer(std::move(matrix), + std::move(bias)); + + masked_linear_layer.SpMM_bias(rhs, &out); + + for (int i = 0; i < kRows; ++i) { + EXPECT_EQ(answer[i], out[i]); + } +} + +TEST(MaskedSparseCorrectness, HandCodedFatVector) { + const int kRows = 8; + const int kCols = 8; + // clang-format off + std::vector mask = {1, 1, 0, 0, 0, 1, 1, 1, + 0, 1, 0, 1, 0, 1, 0, 1, + 1, 0, 0, 1, 1, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, + 0, 0, 0, 0, 1, 1, 0, 0, + 1, 1, 0, 0, 1, 1, 0, 0, + 1, 0, 0, 0, 0, 1, 0, 1}; + // clang-format on + + std::vector values(kRows * kCols, 1.f); + std::vector answer = {6.f, 5.f, 6.f, 1.f, 9.f, 3.f, 5.f, 4.f}; + + MaskedSparseMatrix matrix(kRows, kCols, mask.data(), values.data()); + const int kMaxWidth = 5; + for (int width = 5; width <= kMaxWidth; ++width) { + FatCacheAlignedVector rhs(kCols, width); + CacheAlignedVector bias(kRows); + FatCacheAlignedVector out(kRows, width); + + bias.FillOnes(); + rhs.FillOnes(); + out.FillZero(); + + MaskedLinearLayer masked_linear_layer(std::move(matrix), + std::move(bias)); + + masked_linear_layer.SpMM_bias(rhs, &out); + + for (int i = 0; i < kRows; ++i) { + for (int width = 0; width < kMaxWidth; ++width) { + EXPECT_EQ(answer[i], out[i + width * kRows]); + } + } + } +} + +TEST(CsrBlockSparseMatrix, HandCodedMultiThread) { + const int kRows = 8; + const int kCols = 8; + // clang-format off + std::vector mask = {1, 1, 0, 0, 0, 1, 1, 1, + 0, 1, 0, 1, 0, 1, 0, 1, + 1, 0, 0, 1, 1, 1, 1, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, + 0, 0, 0, 0, 1, 1, 0, 0, + 1, 1, 0, 0, 1, 1, 0, 0, + 1, 0, 0, 0, 0, 1, 0, 1}; + // clang-format on + std::vector values(kRows * kCols, 1.f); + + std::vector answer = {6.f, 5.f, 6.f, 1.f, 9.f, 3.f, 5.f, 4.f}; + + MaskedSparseMatrix matrix(kRows, kCols, mask.data(), values.data()); + CacheAlignedVector rhs(kCols); + CacheAlignedVector bias(kRows); + CacheAlignedVector out(kRows); + + bias.FillOnes(); + rhs.FillOnes(); + out.FillZero(); + + CacheAlignedVector bias_csr = bias; + + CsrBlockSparseMatrix sparse_matrix(matrix); + + MaskedLinearLayer masked_linear_layer(std::move(matrix), + std::move(bias)); + + masked_linear_layer.SpMM_bias(rhs, &out); + + SparseLinearLayer sparse_linear_layer( + std::move(sparse_matrix), std::move(bias_csr)); + sparse_linear_layer.PrepareForThreads(2, /*cache_line_size=*/1); + + CacheAlignedVector out_tmp(kRows); + const bool kUseRelu = false; + sparse_linear_layer.SpMM_bias(rhs, &out_tmp, kUseRelu, /*tid=*/0); + sparse_linear_layer.SpMM_bias(rhs, &out_tmp, kUseRelu, /*tid=*/1); + + for (int i = 0; i < kRows; ++i) { + EXPECT_EQ(answer[i], out_tmp[i]); + } +} + +TEST(TestCasts, TestBfloat16) { + const int kRows = 1000; + const int kCols = 100; + const float kSparsity = 0.f; + + MaskedSparseMatrix matrix(kRows, kCols, kSparsity); + MaskedSparseMatrix matrix_bfloat16(kRows, kCols, matrix.mask().data(), + matrix.values().data()); + + matrix_bfloat16.CastWeights(); + + CheckResult(matrix.values(), matrix_bfloat16.values(), kCols); +} + +TEST(TestCasts, TestFP16) { + const int kRows = 1000; + const int kCols = 100; + const float kSparsity = 0.f; + + MaskedSparseMatrix matrix(kRows, kCols, kSparsity); +#if !defined __arm__ && !defined __aarch64__ + // Conversion doesn't handle denormals, so flush denormals to zero first. + for (int i = 0; i < matrix.values().size(); ++i) { + if (matrix.data()[i] < 1. / static_cast(1 << 14)) + matrix.data()[i] = 0.f; + } +#endif + MaskedSparseMatrix matrix_fp16(kRows, kCols, matrix.mask().data(), + matrix.values().data()); + + matrix_fp16.CastWeights(); + + CheckResult(matrix.values(), matrix_fp16.values(), kCols); +} + +TEST(TestCasts, TestFixed16) { + const int kRows = 100000; + const int kCols = 1; + const float kSparsity = 0.f; + + MaskedSparseMatrix matrix(kRows, kCols, kSparsity); + + // Relative error for fixed point is high near 0. + for (int i = 0; i < matrix.values().size(); ++i) { + // 1.1e-3 is based on the max error of .013 and a grid spacing of 1 / 2**16 + // == 3e-5. 3e-5 / .013 / 2 = 1.1e-3. + if (std::abs(matrix.data()[i]) < 1.1e-3) { + matrix.data()[i] = 0.f; + } + } + + MaskedSparseMatrix matrix_fixed16 = matrix; + + matrix_fixed16.CastWeights>(); + + CheckResult(matrix.values(), matrix_fixed16.values(), kCols); +} + +TEST(TestCasts, TestFixed32) { + const int kRows = 100000; + const int kCols = 1; + const float kSparsity = 0.f; + + MaskedSparseMatrix matrix(kRows, kCols, kSparsity); + MaskedSparseMatrix matrix_fixed32 = matrix; + + matrix_fixed32.CastWeights>(); + + CheckResult(matrix.values(), matrix_fixed32.values(), kCols); +} + +template +void TestSpMM(int block_width, int block_height, int fatness, + bool test_matmul = false) { + std::array use_relu = {false, true}; + std::vector sparsity_levels = {.5, .8, .9, .95, .98}; + std::vector> sizes = {{8, 8}, {128, 128}, {128, 64}, + {256, 192}, {512, 512}, {1024, 512}, + {384, 384}, {512, 384}}; + for (int num_threads = 1; num_threads < 2 + test_matmul; ++num_threads) { + for (const auto& relu : use_relu) { + for (const auto& sparsity : sparsity_levels) { + for (const auto& size : sizes) { + int rows, cols; + std::tie(rows, cols) = size; + CorrectnessCheckBlockSpMM( + rows, cols, block_height, block_width, sparsity, relu, + num_threads, fatness, test_matmul); + } + } + } + } +} + +template +void TestSpMM_MultiThread(int block_width, int block_height, int fatness) { + std::array use_relu = {false, true}; + std::vector sparsity_levels = {.5, .8, .9, .95, .98}; + std::vector> sizes = { + {48, 48}, {128, 128}, {512, 512}, {384, 384}}; + for (int num_threads = 1; num_threads < 5; ++num_threads) { + for (const auto& relu : use_relu) { + for (const auto& sparsity : sparsity_levels) { + for (const auto& size : sizes) { + int rows, cols; + std::tie(rows, cols) = size; + CorrectnessCheckBlockSpMM_MultiThread( + rows, cols, block_height, block_width, sparsity, relu, + num_threads, fatness); + } + } + } + } +} + +template +void TestSumVectors(int start = 0, int end = -1, int size = 6) { + std::vector values; + std::vector answer; + + for (int i = 1; i < size + 1; ++i) { + const float x = static_cast(i); + values.push_back(static_cast(x)); + answer.push_back(static_cast(x * 2)); + } + + if (end == -1) { + end = values.size(); + } + + csrblocksparse::CacheAlignedVector result(values.size()); + csrblocksparse::CacheAlignedVector values_aligned(values); + detail::SumVectors(start, end, values_aligned.data(), values_aligned.data(), + result.data()); + for (int i = start; i < end; ++i) { + EXPECT_EQ(static_cast(answer[i]), static_cast(result[i])); + } +} + +TEST(CsrBlockSparseMatrix, SumVectors_Generic) { + TestSumVectors(); + TestSumVectors(1); + TestSumVectors(1, 4); +} + +TEST(CsrBlockSparseMatrix, SumVectors_Bfloat16) { + TestSumVectors(); + TestSumVectors(1); + TestSumVectors(1, 4); +} + +// For SIMD-optimized SumVectors, the memory of the vector should be at least +// |kSIMDWidth * sizeof(float)| long, and the start position has to be an +// aligned memory location. So setting |size| to be 100 to be safe and +// |start| to be 0 (|start| == 1 is not aligned). +TEST(CsrBlockSparseMatrix, SumVectors_Fixed16) { + TestSumVectors>(0, -1, 100); + TestSumVectors>(0, 4, 100); +} + +TEST(CsrBlockSparseMatrix, SumVectors_Fixed32) { + TestSumVectors>(0, -1, 100); + TestSumVectors>(0, 4, 100); +} + +TEST(CsrBlockSparseMatrix, SpMM_Block4x4_Bfloat16) { + TestSpMM(/*block_width=*/4, + /*block_height=*/4, + /*fatness=*/7); +} + +// This actually uses multiple threads, and uses the output as the input for +// multiple steps to test that synchronization and memory visibility is +// working correctly.Requires square matrices. +TEST(CsrBlockSparseMatrix, SpMV_4x4MultiThreading_Bfloat16) { + TestSpMM_MultiThread( + /*block_width=*/4, + /*block_height=*/4, + /*fatness=*/1); +} + +TEST(CsrBlockSparseMatrix, SpMM_4x4MultiThreading_Bfloat16) { + TestSpMM_MultiThread( + /*block_width=*/4, + /*block_height=*/4, + /*fatness=*/7); +} + +TEST(CsrBlockSparseMatrix, SpMV_Block1x1_Bfloat16) { + TestSpMM(/*block_width=*/1, + /*block_height=*/1, + /*fatness=*/1); +} + +TEST(CsrBlockSparseMatrix, SpMM_Block1x1_Bfloat16) { + TestSpMM(/*block_width=*/1, + /*block_height=*/1, + /*fatness=*/7); +} + +// This actually uses multiple threads, and uses the output as the input for +// multiple steps to test that synchronization and memory visibility is +// working correctly.Requires square matrices. +TEST(CsrBlockSparseMatrix, SpMV_1x1MultiThreading_Bfloat16) { + TestSpMM_MultiThread( + /*block_width=*/1, + /*block_height=*/1, + /*fatness=*/1); +} + +TEST(CsrBlockSparseMatrix, SpMM_1x1MultiThreading_Bfloat16) { + TestSpMM_MultiThread( + /*block_width=*/1, + /*block_height=*/1, + /*fatness=*/7); +} + +TEST(CsrBlockSparseMatrix, SpMV_Block4x4_float) { + TestSpMM(/*block_width=*/4, + /*block_height=*/4, + /*fatness=*/1, + /*test_matmul=*/true); +} + +TEST(CsrBlockSparseMatrix, SpMM_Block4x4_float) { + TestSpMM(/*block_width=*/4, + /*block_height=*/4, + /*fatness=*/7); +} + +// This actually uses multiple threads, and uses the output as the input for +// multiple steps to test that synchronization and memory visibility is +// working correctly.Requires square matrices. +TEST(CsrBlockSparseMatrix, SpMV_4x4MultiThreading_float) { + TestSpMM_MultiThread(/*block_width=*/4, + /*block_height=*/4, + /*fatness=*/1); +} + +TEST(CsrBlockSparseMatrix, SpMM_4x4MultiThreading_float) { + TestSpMM_MultiThread(/*block_width=*/4, + /*block_height=*/4, + /*fatness=*/7); +} + +TEST(CsrBlockSparseMatrix, SpMV_Block1x1_float) { + TestSpMM(/*block_width=*/1, + /*block_height=*/1, + /*fatness=*/1); +} + +TEST(CsrBlockSparseMatrix, SpMM_Block1x1_float) { + TestSpMM(/*block_width=*/1, + /*block_height=*/1, + /*fatness=*/7); +} + +// This actually uses multiple threads, and uses the output as the input for +// multiple steps to test that synchronization and memory visibility is +// working correctly.Requires square matrices. +TEST(CsrBlockSparseMatrix, SpMV_1x1MultiThreading_float) { + TestSpMM_MultiThread(/*block_width=*/1, + /*block_height=*/1, + /*fatness=*/1); +} + +TEST(CsrBlockSparseMatrix, SpMM_1x1MultiThreading_float) { + TestSpMM_MultiThread(/*block_width=*/1, + /*block_height=*/1, + /*fatness=*/7); +} + +TEST(CsrBlockSparseMatrix, SpMV_Block4x4_fixed16x16_32) { + TestSpMM, csrblocksparse::fixed16<4>, + typename csrblocksparse::TypeOfProduct< + csrblocksparse::fixed16<4>, csrblocksparse::fixed16<4>>::type>( + /*block_width=*/4, + /*block_height=*/4, + /*fatness=*/1, + /*test_matmul=*/true); +} + +TEST(CsrBlockSparseMatrix, SpMM_Block4x4_fixed16x16_32) { + TestSpMM, csrblocksparse::fixed16<4>, + typename csrblocksparse::TypeOfProduct< + csrblocksparse::fixed16<4>, csrblocksparse::fixed16<4>>::type>( + /*block_width=*/4, + /*block_height=*/4, + /*fatness=*/7); +} + +TEST(CsrBlockSparseMatrix, SpMV_Block1x1_fixed16x16_32) { + TestSpMM, csrblocksparse::fixed16<4>, + typename csrblocksparse::TypeOfProduct< + csrblocksparse::fixed16<4>, csrblocksparse::fixed16<4>>::type>( + /*block_width=*/1, + /*block_height=*/1, + /*fatness=*/1); +} + +TEST(CsrBlockSparseMatrix, SpMM_Block1x1_fixed16x16_32) { + TestSpMM, csrblocksparse::fixed16<4>, + typename csrblocksparse::TypeOfProduct< + csrblocksparse::fixed16<4>, csrblocksparse::fixed16<4>>::type>( + /*block_width=*/1, + /*block_height=*/1, + /*fatness=*/7); +} + +TEST(CsrBlockSparseMatrix, SpMV_Block4x4_fixed16x16_16) { + TestSpMM, csrblocksparse::fixed16<5>, + csrblocksparse::fixed16<8>>( + /*block_width=*/4, + /*block_height=*/4, + /*fatness=*/1, + /*test_matmul=*/true); +} + +TEST(CsrBlockSparseMatrix, SpMM_Block4x4_fixed16x16_16) { + TestSpMM, csrblocksparse::fixed16<5>, + csrblocksparse::fixed16<8>>( + /*block_width=*/4, + /*block_height=*/4, + /*fatness=*/7); +} + +TEST(CsrBlockSparseMatrix, SpMV_Block1x1_fixed16x16_16) { + TestSpMM, csrblocksparse::fixed16<5>, + csrblocksparse::fixed16<8>>( + /*block_width=*/1, + /*block_height=*/1, + /*fatness=*/1); +} + +TEST(CsrBlockSparseMatrix, SpMM_Block1x1_fixed16x16_16) { + TestSpMM, csrblocksparse::fixed16<5>, + csrblocksparse::fixed16<8>>( + /*block_width=*/1, + /*block_height=*/1, + /*fatness=*/7); +} + +TEST(CsrBlockSparseMatrix, SpMV_Block4x4_fixed16x16_32_unmatched) { + TestSpMM, csrblocksparse::fixed16<5>, + csrblocksparse::fixed32<13>>( + /*block_width=*/4, + /*block_height=*/4, + /*fatness=*/1, + /*test_matmul=*/true); +} + +TEST(CsrBlockSparseMatrix, SpMM_Block4x4_fixed16x16_32_unmatched) { + TestSpMM, csrblocksparse::fixed16<5>, + csrblocksparse::fixed32<13>>( + /*block_width=*/4, + /*block_height=*/4, + /*fatness=*/7); +} + +TEST(CsrBlockSparseMatrix, SpMV_Block1x1_fixed16x16_32_unmatched) { + TestSpMM, csrblocksparse::fixed16<5>, + csrblocksparse::fixed32<13>>( + /*block_width=*/1, + /*block_height=*/1, + /*fatness=*/1); +} + +TEST(CsrBlockSparseMatrix, SpMM_Block1x1_fixed16x16_32_unmatched) { + TestSpMM, csrblocksparse::fixed16<5>, + csrblocksparse::fixed32<13>>( + /*block_width=*/1, + /*block_height=*/1, + /*fatness=*/7); +} + +TEST(CsrBlockSparseMatrix, RhsIndicesDeltasRoundTrip) { + MaskedSparseMatrix matrix(/*rows=*/256, /*cols=*/256, + /*sparsity=*/0.9, /*block_height=*/4, + /*block_width=*/4); + CsrBlockSparseMatrix sparse_matrix(matrix); + CacheAlignedVector copy_indices = sparse_matrix.rhs_indices(); + sparse_matrix.ComputeColDeltas(); + sparse_matrix.ComputeRHSIndices(); + // They get padded when created, so the newer one could be bigger. + EXPECT_LE(copy_indices.size(), sparse_matrix.rhs_indices().size()); + for (int i = 0; i < copy_indices.size(); ++i) { + EXPECT_EQ(copy_indices[i], sparse_matrix.rhs_indices()[i]) << "i=" << i; + } +} + +// Tests that a Layer that is split into 2 by columns (inputs) computes the same +// result as the original layer. +TEST(CsrBlockSparseMatrix, SplitByCol) { + int kRows = 1024; + int kCols = 1024; + MaskedSparseMatrix matrix(kRows, kCols, 0.95, /*block_height=*/4, + /*block_width=*/4); + FatCacheAlignedVector rhs(kCols, /*cols=*/1); + CacheAlignedVector bias(kRows); + FatCacheAlignedVector out1(kRows, /*cols=*/1); + FatCacheAlignedVector out2(kRows, /*cols=*/1); + + bias.FillRandom(); + rhs.FillRandom(); + out1.FillZero(); + out2.FillZero(); + FatCacheAlignedVector out_reference = out1; + + CsrBlockSparseMatrix sparse_matrix(matrix); + + SparseLinearLayer sparse_linear_layer(std::move(sparse_matrix), + std::move(bias)); + sparse_linear_layer.PrepareForThreads(1); + sparse_linear_layer.SpMM_bias(rhs, &out_reference, /*relu=*/false, + /*tid=*/0); + // Split the layer into 2 parts. + SparseLinearLayer part1, part2; + sparse_linear_layer.SplitInputs(&part1, &part2); + part1.PrepareForThreads(1); + part2.PrepareForThreads(1); + EXPECT_EQ(kRows, part1.rows()); + EXPECT_EQ(kCols / 2, part1.cols()); + EXPECT_EQ(kRows, part2.rows()); + EXPECT_EQ(kCols / 2, part2.cols()); + MutableVectorView rhs1(&rhs, 0, kCols / 2); + MutableVectorView rhs2(&rhs, kCols / 2, kCols / 2); + for (int i = 0; i < kCols / 2; ++i) { + EXPECT_FLOAT_EQ(rhs[i], rhs1.data()[i]); + EXPECT_FLOAT_EQ(rhs[i + kCols / 2], rhs2.data()[i]); + } + part1.SpMM_bias(rhs1, &out1, /*relu=*/false, /*tid=*/0); + part2.SpMM_bias(rhs2, &out2, /*relu=*/false, /*tid=*/0); + // Check that out1 + out2 = out_reference. + for (int i = 0; i < kRows; ++i) { + EXPECT_NEAR(out_reference[i], out1[i] + out2[i], 2e-5) + << " i=" << i << " out1=" << out1[i] << " out2=" << out2[i]; + } +} +// Tests that a Layer that is split into 2 by rows (outputs) computes the same +// result as the original layer. +TEST(CsrBlockSparseMatrix, SplitByRow) { + int kRows = 1024; + int kCols = 1024; + MaskedSparseMatrix matrix(kRows, kCols, 0.95, /*block_height=*/4, + /*block_width=*/4); + FatCacheAlignedVector rhs(kCols, /*cols=*/1); + CacheAlignedVector bias(kRows); + FatCacheAlignedVector out1(kRows, /*cols=*/1); + FatCacheAlignedVector out2(kRows, /*cols=*/1); + + bias.FillRandom(); + rhs.FillRandom(); + out1.FillZero(); + out2.FillZero(); + FatCacheAlignedVector out_reference = out1; + + CsrBlockSparseMatrix sparse_matrix(matrix); + + SparseLinearLayer sparse_linear_layer(std::move(sparse_matrix), + std::move(bias)); + sparse_linear_layer.PrepareForThreads(1); + sparse_linear_layer.SpMM_bias(rhs, &out_reference, /*relu=*/false, + /*tid=*/0); + // Split the layer into 2 parts. + SparseLinearLayer part1, part2; + sparse_linear_layer.SplitOutputs(&part1, &part2); + part1.PrepareForThreads(1); + part2.PrepareForThreads(1); + EXPECT_EQ(kRows / 2, part1.rows()); + EXPECT_EQ(kCols, part1.cols()); + EXPECT_EQ(kRows / 2, part2.rows()); + EXPECT_EQ(kCols, part2.cols()); + MutableVectorView out2a(&out2, 0, kRows / 2); + MutableVectorView out2b(&out2, kRows / 2, kRows / 2); + part1.SpMM_bias(rhs, &out2a, /*relu=*/false, /*tid=*/0); + part2.SpMM_bias(rhs, &out2b, /*relu=*/false, /*tid=*/0); + // Check that out2 = out_reference. + for (int i = 0; i < kRows; ++i) { + EXPECT_NEAR(out_reference[i], out2[i], 2e-5) + << " i=" << i << " out1=" << out_reference[i] << " out2=" << out2[i]; + } +} + +TEST(CsrBlockSparseMatrix, MutableVectorView) { + const int kRows = 1024; + const int kCols = 1024; + const int kFatness = 2; + + std::vector values(kRows * kCols, 1.f); + std::vector mask(kRows * kCols); + for (int i = 0; i < mask.size(); ++i) mask[i] = i % 2; + + auto masked_matrix = + MaskedSparseMatrix(kRows, kCols, mask.data(), values.data()); + auto sparse_matrix = CsrBlockSparseMatrix(masked_matrix); + FatCacheAlignedVector x(kCols, kFatness); + x.FillOnes(); + + CacheAlignedVector bias(kRows); + bias.FillZero(); + + // First check that we can use spans as output. Split a multiplication + // into upper and lower halves times the full vector: + // --------------- x t + // | | x t + // | | x t + // --------------- = + // | | x b + // | | x b + // --------------- x b + + FatCacheAlignedVector out(kRows, kFatness); + FatCacheAlignedVector out_view(kRows, kFatness); + + MutableVectorView out_view_top(&out_view, 0, kRows / 2); + MutableVectorView out_view_bottom(&out_view, kRows / 2, kRows / 2); + + sparse_matrix.SpMM_bias(x, bias, &out); + + auto masked_matrix_top = + MaskedSparseMatrix(kRows / 2, kCols, mask.data(), values.data()); + auto masked_matrix_bottom = MaskedSparseMatrix( + kRows / 2, kCols, mask.data() + kRows * kCols / 2, + values.data() + kRows * kCols / 2); + auto sparse_matrix_top = + CsrBlockSparseMatrix(masked_matrix_top); + auto sparse_matrix_bottom = + CsrBlockSparseMatrix(masked_matrix_bottom); + + sparse_matrix_top.SpMM_bias(x, bias, &out_view_top); + sparse_matrix_bottom.SpMM_bias(x, bias, &out_view_bottom); + + CheckResult(out, out_view, kCols); + + // Check that we can use a span as an input vector. Multiply upper left + // portion of the matrix by the top half of the vector. + // --------------- + // |oooooo | x q + // |oooooo | x q + // | | = + // | | + // --------------- + + auto masked_matrix_quarter = MaskedSparseMatrix( + kRows / 2, kCols / 2, mask.data(), values.data()); + auto sparse_matrix_quarter = + CsrBlockSparseMatrix(masked_matrix_quarter); + + MutableVectorView x_top(&x, 0, kCols / 2); + FatCacheAlignedVector out_correct(kRows / 2, /*cols=*/2); + + for (int i = 0; i < kFatness * (kRows / 2); ++i) out_correct[i] = 256.f; + + MutableVectorView bias_top(&bias, 0, kRows / 2); + FatCacheAlignedVector out_quarter(kRows / 2, kFatness); + + sparse_matrix_quarter.SpMM_bias(x_top, bias_top, &out_quarter); + + CheckResult(out_correct, out_quarter, kCols / 2); +} + +namespace { + +bool skip_test(const absl::Status& status, absl::string_view msg) { + if (!status.ok()) { + LOG(INFO) << "Couldn't load " << msg << ", skipping test " << status; + return true; + } + + return false; +} + +} // namespace + +TEST(CsrBlockSparseMatrix, ModelMatrices_Bfloat16) { + std::vector names = { + "768_512_95_4x4_wavernn_gru_", "768_512_95_4x4_coarseproj_", + "768_512_95_4x4_coarselogit_", "768_512_95_4x4_fineproj_", + "768_512_95_4x4_finelogit_", "lyra_conv1d_"}; + const std::string kPath = +#if defined __arm__ || defined __aarch64__ + "/data/local/tmp/"; +#else + (ghc::filesystem::current_path() / kTestdataPath).string(); +#endif + for (auto& layer_name : names) { + SparseLinearLayer sparse_linear_layer; + auto status = LoadSparseLayer(layer_name, /*zipped=*/true, + &sparse_linear_layer, kPath); + // If the files don't exist on the device we're running on, just skip this + // test and log that it was skipped. + if (skip_test(status, layer_name)) return; + + int rows = sparse_linear_layer.rows(); + int cols = sparse_linear_layer.cols(); + + MaskedLinearLayer masked_linear_layer; + status = LoadMaskedLayer(layer_name, /*zipped=*/true, + &masked_linear_layer, kPath); + if (skip_test(status, layer_name)) return; + masked_linear_layer.CastWeights(); + + CacheAlignedVector rhs(cols); + CacheAlignedVector out_ref(rows); + CacheAlignedVector out_spmv(rows); + + rhs.FillRandom(); + out_ref.FillZero(); + out_spmv.FillZero(); + + std::array use_relus = {false, true}; + for (bool use_relu : use_relus) { + masked_linear_layer.SpMM_bias(rhs, &out_ref, use_relu); + sparse_linear_layer.SpMM_bias(rhs, &out_spmv, use_relu); + + CheckResult(out_ref, out_spmv, cols); + } + } +} + +TEST(CsrBlockSparseMatrix, ModelMatrices_float) { + std::vector names = { + "768_512_95_4x4_wavernn_gru_", "768_512_95_4x4_coarseproj_", + "768_512_95_4x4_coarselogit_", "768_512_95_4x4_fineproj_", + "768_512_95_4x4_finelogit_", "lyra_conv1d_"}; + const std::string kPath = +#if defined __arm__ || defined __aarch64__ + "/data/local/tmp/"; +#else + (ghc::filesystem::current_path() / kTestdataPath).string(); +#endif + for (auto& layer_name : names) { + SparseLinearLayer sparse_linear_layer; + auto status = LoadSparseLayer(layer_name, /*zipped=*/true, + &sparse_linear_layer, kPath); + // If the files don't exist on the device we're running on, just skip this + // test and log that it was skipped. + if (skip_test(status, layer_name)) return; + + int rows = sparse_linear_layer.rows(); + int cols = sparse_linear_layer.cols(); + + MaskedLinearLayer masked_linear_layer; + status = LoadMaskedLayer(layer_name, /*zipped=*/true, + &masked_linear_layer, kPath); + if (skip_test(status, layer_name)) return; + + CacheAlignedVector rhs(cols); + CacheAlignedVector out_ref(rows); + CacheAlignedVector out_spmv(rows); + + rhs.FillRandom(); + out_ref.FillZero(); + out_spmv.FillZero(); + + std::array use_relus = {false, true}; + for (bool use_relu : use_relus) { + masked_linear_layer.SpMM_bias(rhs, &out_ref, use_relu); + sparse_linear_layer.SpMM_bias(rhs, &out_spmv, use_relu); + + CheckResult(out_ref, out_spmv, cols); + } + } +} + +#undef SKIP_TEST + +} // namespace csrblocksparse diff --git a/sparse_matmul/layers/errno_mapping.cc b/sparse_matmul/layers/errno_mapping.cc new file mode 100644 index 00000000..558abb33 --- /dev/null +++ b/sparse_matmul/layers/errno_mapping.cc @@ -0,0 +1,195 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "sparse_matmul/layers/errno_mapping.h" + +#include + +#include "absl/strings/str_cat.h" + +namespace csrblocksparse { + +namespace { + +absl::StatusCode ErrnoToCode(int error_number) { + switch (error_number) { + case 0: + return absl::StatusCode::kOk; + case EINVAL: // Invalid argument + case ENAMETOOLONG: // Filename too long + case E2BIG: // Argument list too long + case EDESTADDRREQ: // Destination address required + case EDOM: // Mathematics argument out of domain of function + case EFAULT: // Bad address + case EILSEQ: // Illegal byte sequence + case ENOPROTOOPT: // Protocol not available + case ENOSTR: // Not a STREAM + case ENOTSOCK: // Not a socket + case ENOTTY: // Inappropriate I/O control operation + case EPROTOTYPE: // Protocol wrong type for socket + case ESPIPE: // Invalid seek + return absl::StatusCode::kInvalidArgument; + case ETIMEDOUT: // Connection timed out + case ETIME: // Timer expired + return absl::StatusCode::kDeadlineExceeded; + case ENODEV: // No such device + case ENOENT: // No such file or directory +#ifdef ENOMEDIUM + case ENOMEDIUM: // No medium found +#endif + case ENXIO: // No such device or address + case ESRCH: // No such process + return absl::StatusCode::kNotFound; + case EEXIST: // File exists + case EADDRNOTAVAIL: // Address not available + case EALREADY: // Connection already in progress +#ifdef ENOTUNIQ + case ENOTUNIQ: // Name not unique on network +#endif + return absl::StatusCode::kAlreadyExists; + case EPERM: // Operation not permitted + case EACCES: // Permission denied +#ifdef ENOKEY + case ENOKEY: // Required key not available +#endif + case EROFS: // Read only file system + return absl::StatusCode::kPermissionDenied; + case ENOTEMPTY: // Directory not empty + case EISDIR: // Is a directory + case ENOTDIR: // Not a directory + case EADDRINUSE: // Address already in use + case EBADF: // Invalid file descriptor +#ifdef EBADFD + case EBADFD: // File descriptor in bad state +#endif + case EBUSY: // Device or resource busy + case ECHILD: // No child processes + case EISCONN: // Socket is connected +#ifdef EISNAM + case EISNAM: // Is a named type file +#endif +#ifdef ENOTBLK + case ENOTBLK: // Block device required +#endif + case ENOTCONN: // The socket is not connected + case EPIPE: // Broken pipe +#ifdef ESHUTDOWN + case ESHUTDOWN: // Cannot send after transport endpoint shutdown +#endif + case ETXTBSY: // Text file busy +#ifdef EUNATCH + case EUNATCH: // Protocol driver not attached +#endif + return absl::StatusCode::kFailedPrecondition; + case ENOSPC: // No space left on device +#ifdef EDQUOT + case EDQUOT: // Disk quota exceeded +#endif + case EMFILE: // Too many open files + case EMLINK: // Too many links + case ENFILE: // Too many open files in system + case ENOBUFS: // No buffer space available + case ENODATA: // No message is available on the STREAM read queue + case ENOMEM: // Not enough space + case ENOSR: // No STREAM resources +#ifdef EUSERS + case EUSERS: // Too many users +#endif + return absl::StatusCode::kResourceExhausted; +#ifdef ECHRNG + case ECHRNG: // Channel number out of range +#endif + case EFBIG: // File too large + case EOVERFLOW: // Value too large to be stored in data type + case ERANGE: // Result too large + return absl::StatusCode::kOutOfRange; +#ifdef ENOPKG + case ENOPKG: // Package not installed +#endif + case ENOSYS: // Function not implemented + case ENOTSUP: // Operation not supported + case EAFNOSUPPORT: // Address family not supported +#ifdef EPFNOSUPPORT + case EPFNOSUPPORT: // Protocol family not supported +#endif + case EPROTONOSUPPORT: // Protocol not supported +#ifdef ESOCKTNOSUPPORT + case ESOCKTNOSUPPORT: // Socket type not supported +#endif + case EXDEV: // Improper link + return absl::StatusCode::kUnimplemented; + case EAGAIN: // Resource temporarily unavailable +#ifdef ECOMM + case ECOMM: // Communication error on send +#endif + case ECONNREFUSED: // Connection refused + case ECONNABORTED: // Connection aborted + case ECONNRESET: // Connection reset + case EINTR: // Interrupted function call +#ifdef EHOSTDOWN + case EHOSTDOWN: // Host is down +#endif + case EHOSTUNREACH: // Host is unreachable + case ENETDOWN: // Network is down + case ENETRESET: // Connection aborted by network + case ENETUNREACH: // Network unreachable + case ENOLCK: // No locks available + case ENOLINK: // Link has been severed +#ifdef ENONET + case ENONET: // Machine is not on the network +#endif + return absl::StatusCode::kUnavailable; + case EDEADLK: // Resource deadlock avoided +#ifdef ESTALE + case ESTALE: // Stale file handle +#endif + return absl::StatusCode::kAborted; + case ECANCELED: // Operation cancelled + return absl::StatusCode::kCancelled; + default: + return absl::StatusCode::kUnknown; + } +} + +// POSIX `strerror_r()` returns `int`. +ABSL_ATTRIBUTE_UNUSED std::string StrErrorResult(int result, const char* buffer, + int error_code) { + if (ABSL_PREDICT_FALSE(result != 0)) { + return absl::StrCat("Unknown error ", error_code); + } + return buffer; +} + +// GNU `strerror_r()` returns `char*`. +ABSL_ATTRIBUTE_UNUSED std::string StrErrorResult(char* result, + const char* buffer, + int error_code) { + return result; +} + +std::string StrError(int error_code) { + char message[256]; + return StrErrorResult(strerror_r(error_code, message, sizeof(message)), + message, error_code); +} + +} // namespace + +absl::Status ErrnoToCanonicalStatus(int error_number, + absl::string_view message) { + return absl::Status(ErrnoToCode(error_number), + absl::StrCat(message, ": ", StrError(error_number))); +} + +} // namespace csrblocksparse diff --git a/sparse_matmul/layers/errno_mapping.h b/sparse_matmul/layers/errno_mapping.h new file mode 100644 index 00000000..747d3b4d --- /dev/null +++ b/sparse_matmul/layers/errno_mapping.h @@ -0,0 +1,29 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_LYRA_CODEC_SPARSE_MATMUL_LAYERS_ERRNO_MAPPING_H_ +#define THIRD_PARTY_LYRA_CODEC_SPARSE_MATMUL_LAYERS_ERRNO_MAPPING_H_ + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" + +namespace csrblocksparse { + +// Converts |error_number| value to absl::Status. +absl::Status ErrnoToCanonicalStatus(int error_number, + absl::string_view message); + +} // namespace csrblocksparse + +#endif // THIRD_PARTY_LYRA_CODEC_SPARSE_MATMUL_LAYERS_ERRNO_MAPPING_H_ diff --git a/sparse_matmul/layers/masked_sparse_matrix.h b/sparse_matmul/layers/masked_sparse_matrix.h new file mode 100644 index 00000000..a905ba4b --- /dev/null +++ b/sparse_matmul/layers/masked_sparse_matrix.h @@ -0,0 +1,206 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LYRA_CODEC_SPARSE_MATMUL_LAYERS_MASKED_SPARSE_MATRIX_H_ +#define LYRA_CODEC_SPARSE_MATMUL_LAYERS_MASKED_SPARSE_MATRIX_H_ + +#include +#include +#include +#include + +#include "absl/strings/str_format.h" +#include "sparse_matmul/vector/cache_aligned_vector.h" + +namespace csrblocksparse { + +// MaskedSparseMatrix serves two purposes: +// 1) It is useful as a reference implementation of SpMV for correctness +// checking the much more complicated implementations in CSRBlockSparseMatrix +// 2) This is the format that sparse matrices are represented after pruning +// in TF. This class provides a bridge to getting these parameters into +// a compressed form suitable for computation and serialization. +// +// MaskedSparseMatrix matrix(rows, cols, mask_from_tf, values_from_tf); +// CSRBlockSparseMatrix csr_matrix(matrix); +// csr_matrix.Multiply(rhs, bias, &out); +template +class MaskedSparseMatrix { + public: + MaskedSparseMatrix() {} + + // Construct a MaskedSparseMatrix of the given size, sparsity and block size. + // This is mainly useful for testing. + MaskedSparseMatrix(int rows, int cols, float sparsity, int block_height = 1, + int block_width = 1, float constant = 1.f, + bool random = true) + : rows_(rows), cols_(cols), sparsity_(sparsity) { + CHECK_EQ(rows % block_height, 0); + CHECK_EQ(cols % block_width, 0); + + init(sparsity, block_height, block_width, constant, random); + } + + // Construct from an existing mask and values (most likely from a TF model). + template + MaskedSparseMatrix(int rows, int cols, const MaskType* mask, const T* values) + : rows_(rows), cols_(cols) { + mask_.resize(rows * cols); + values_.resize(rows * cols); + std::copy_n(mask, rows * cols, mask_.begin()); + std::copy_n(values, rows * cols, values_.begin()); + sparsity_ = + 1.f - std::accumulate(mask_.begin(), mask_.end(), 0.f) / mask_.size(); + } + + const std::vector& mask() const { return mask_; } + const std::vector& values() const { return values_; } + T* data() { return values_.data(); } + const T* data() const { return values_.data(); } + + int rows() const { return rows_; } + int cols() const { return cols_; } + float sparsity() const { return sparsity_; } + + void Print() const { + absl::PrintF("-------Values---------\n"); + for (int r = 0; r < rows_; ++r) { + for (int c = 0; c < cols_; ++c) { + absl::PrintF("%+6.3f ", static_cast(values_[r * cols_ + c])); + } + absl::PrintF("\n"); + } + absl::PrintF("-------Mask---------\n"); + for (int r = 0; r < rows_; ++r) { + for (int c = 0; c < cols_; ++c) { + printf("%2d ", mask_[r * cols_ + c]); + } + absl::PrintF("\n"); + } + } + + // This routine is useful for rounding the possibly higher precision values + // stored in this class to a lower precision, so that correctness checks + // between this class and CSRBlockSparseMatrix can have a tighter tolerance. + template + void CastWeights() { + for (int i = 0; i < values_.size(); ++i) { + values_[i] = static_cast(U(values_[i])); + } + } + + // Only meant for correctness checking. + // RhsClassType is meant to be either CacheAlignedVector OR + // FatCacheAlignedVector. + // The weight matrix is ROW MAJOR and RhsClassType is COLUMN MAJOR. + // |bias| is broadcast if |rhs| has more than one column. + template + void SpMM_bias(const RhsClassType& rhs, + const CacheAlignedVector& bias, OutClassType* out, + bool relu = false) { + for (int r = 0; r < rows_; ++r) { + for (int n = 0; n < rhs.cols(); ++n) { + float sum = 0.f; + const RhsType* rhs_ptr = rhs.data() + n * rhs.rows(); + OutType* out_ptr = out->data() + n * out->rows(); + const int* mask_ptr = mask_.data() + r * cols_; + const T* value_ptr = values_.data() + r * cols_; + for (int c = 0; c < cols_; ++c) { + sum += mask_ptr[c] * static_cast(value_ptr[c]) * + static_cast(rhs_ptr[c]); + } + out_ptr[r] = static_cast( + relu ? std::max(sum + static_cast(bias[r]), 0.f) + : sum + static_cast(bias[r])); + } + } + } + + private: + // Generate a random matrix with the specified sparsity. + // Useful for testing. + void init(float sparsity, int block_height, int block_width, float constant, + bool random = true) { + int reduced_rows = rows_ / block_height; + int reduced_cols = cols_ / block_width; + mask_.resize(rows_ * cols_, 0); + + // Fill with non-zero value to make sure masking works. + values_.resize(rows_ * cols_, static_cast(2.f)); + + std::mt19937 generator(0); + std::uniform_real_distribution dist_sparsity; + std::uniform_real_distribution dist_value(-1.f, 1.f); + int nnz = 0; + while (nnz == 0) { + for (int r = 0; r < reduced_rows; ++r) { + for (int c = 0; c < reduced_cols; ++c) { + if (dist_sparsity(generator) > sparsity) { + nnz++; + for (int i = 0; i < block_height; ++i) { + for (int j = 0; j < block_width; ++j) { + mask_[(r * block_height + i) * cols_ + block_width * c + j] = 1; + values_[(r * block_height + i) * cols_ + block_width * c + j] = + static_cast(random ? dist_value(generator) : constant); + } + } + } + } + } + } + } + + std::vector mask_; + std::vector values_; + int rows_; + int cols_; + float sparsity_; +}; + +template +class MaskedLinearLayer { + public: + MaskedLinearLayer(MaskedSparseMatrix&& weights, + CacheAlignedVector&& bias) + : weights_(std::move(weights)), bias_(std::move(bias)) {} + + MaskedLinearLayer() {} + + template + void CastWeights() { + weights_.template CastWeights(); + } + + // Does Ax + b where A is a masked sparse ROW MAJOR matrix and + // x is a COLUMN MAJOR dense vector or matrix. Bias is a vector that is + // broadcast is rhs has more than one column. + template + void SpMM_bias(const FatVector& rhs, FatVector* out, bool relu = false) { + static_assert(std::is_same::value, + "FatVector value_type must match masked_linear_layer type"); + weights_.SpMM_bias(rhs, bias_, out, relu); + } + + private: + MaskedSparseMatrix weights_; + CacheAlignedVector bias_; +}; + +} // namespace csrblocksparse + +#endif // LYRA_CODEC_SPARSE_MATMUL_LAYERS_MASKED_SPARSE_MATRIX_H_ diff --git a/sparse_matmul/layers/read_array_ifstream.h b/sparse_matmul/layers/read_array_ifstream.h new file mode 100644 index 00000000..3ea2bd13 --- /dev/null +++ b/sparse_matmul/layers/read_array_ifstream.h @@ -0,0 +1,66 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Low-level array reading function using std::ifstream. + +#ifndef LYRA_CODEC_SPARSE_MATMUL_LAYERS_READ_ARRAY_IFSTREAM_H_ +#define LYRA_CODEC_SPARSE_MATMUL_LAYERS_READ_ARRAY_IFSTREAM_H_ + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/substitute.h" +#include "include/ghc/filesystem.hpp" + +namespace csrblocksparse { +namespace detail { + +template +absl::Status ReadArrayIfstream(const std::string& file_name, + const std::string& path, std::vector* array, + int64_t* length) { + ghc::filesystem::path complete_path(path); + complete_path /= file_name; + std::ifstream in_stream(complete_path.u8string(), std::ios::binary); + if (!in_stream.is_open()) { + return absl::UnknownError( + absl::Substitute("Error opening $0", complete_path.string())); + } + + std::stringstream buffer; + buffer << in_stream.rdbuf(); + if (buffer.str().empty()) { + LOG(ERROR) << "File " << complete_path << " was empty."; + return absl::UnknownError( + absl::Substitute("File $0 was empty", complete_path.string())); + } + std::string contents = buffer.str(); + *length = contents.length(); + int64_t elem = (*length + sizeof(T) - 1) / sizeof(T); + array->resize(elem); + std::move(contents.begin(), contents.end(), + reinterpret_cast(array->data())); + + return absl::OkStatus(); +} + +} // namespace detail +} // namespace csrblocksparse + +#endif // LYRA_CODEC_SPARSE_MATMUL_LAYERS_READ_ARRAY_IFSTREAM_H_ diff --git a/sparse_matmul/layers/sparse_linear_layer.h b/sparse_matmul/layers/sparse_linear_layer.h new file mode 100644 index 00000000..9363f301 --- /dev/null +++ b/sparse_matmul/layers/sparse_linear_layer.h @@ -0,0 +1,365 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LYRA_CODEC_SPARSE_MATMUL_LAYERS_SPARSE_LINEAR_LAYER_H_ +#define LYRA_CODEC_SPARSE_MATMUL_LAYERS_SPARSE_LINEAR_LAYER_H_ + +#include + +#include "absl/memory/memory.h" +#include "glog/logging.h" +#include "sparse_matmul/layers/csr_blocksparse_matrix.h" +#include "sparse_matmul/layers/masked_sparse_matrix.h" +#include "sparse_matmul/numerics/type_utils.h" +#include "sparse_matmul/os/coop_threads.h" +#include "sparse_matmul/vector/cache_aligned_vector.h" + +namespace csrblocksparse { + +template ::type, + typename DeltaType = int16_t> +class SparseLinearLayer { + public: + SparseLinearLayer() {} + + SparseLinearLayer(CsrBlockSparseMatrix&& sparse_matrix, + CacheAlignedVector&& bias) + : sparse_matrix_(std::move(sparse_matrix)), full_bias_(std::move(bias)) { + CHECK_EQ(sparse_matrix_.rows(), full_bias_.size()); + // Some kernels expect that the bias is divided by 4, so we store a second + // copy of a quarter of the bias. + // TODO(b/189958858): Remove the quartered bias if it can be done without + // loss of speed, and rename the |full_bias_| member back to |bias_|. + bias_ = full_bias_; + for (int i = 0; i < bias_.size(); ++i) { + bias_[i] = static_cast(.25f * static_cast(bias_[i])); + } + } + SparseLinearLayer( + const SparseLinearLayer& src) { + *this = src; + } + SparseLinearLayer& operator=( + const SparseLinearLayer& src) { + sparse_matrix_ = src.sparse_matrix_; + bias_ = src.bias_; + full_bias_ = src.full_bias_; + mid_output_ = src.mid_output_; + thread_layers_ = src.thread_layers_; + num_threads_ = src.num_threads_; + if (src.split_pc_) { + split_pc_ = absl::make_unique( + src.split_pc_->num_producers(), src.split_pc_->num_consumers()); + } + return *this; + } + + // Does Ax + b where A is a block sparse compressed sparse row matrix and + // x is a COLUMN MAJOR dense vector or matrix. Bias is a vector that is + // broadcast if rhs has more than one column. + template + void SpMM_bias(const RhsClassType& rhs, OutType* out, bool relu = false, + int tid = 0, SpinBarrier* barrier = nullptr) const { + static_assert( + std::is_same::value, ""); + sparse_matrix_.SpMM_bias(rhs, bias_, out, relu, tid, barrier); + } + // Multiplies a sparse matrix by a possibly dense matrix, as SpMM_bias above, + // and then samples from the output (softmax distribution) layer. + template + int SpMM_bias_Sample(const RhsClassType& rhs, OutType* out, float temperature, + int tid, SpinBarrier* barrier, std::minstd_rand* gen, + CacheAlignedVector* scratch) const { + static_assert( + std::is_same::value, ""); + return sparse_matrix_.SpMM_bias_Sample(rhs, bias_, out, temperature, tid, + barrier, gen, scratch); + } + template + void MatVec(const RhsClassType& rhs, bool relu, int tid, int replicas, + int output_stride, OutType* output, + SpinBarrier* barrier = nullptr) { + static_assert( + std::is_same::value, ""); +#ifdef __AVX2__ + if (block_width() == 4 && (block_height() == 4 || block_height() == 8) && + !IsCustomFloatType::value) { + if (!IsSplit()) { + sparse_matrix_.MatVec(rhs.cast_data(), full_bias_.cast_data(), relu, + tid, replicas, output_stride, output->data()); + if (barrier != nullptr) barrier->barrier(); + return; + } + // NOTE: Until the quartered bias is removed it is a bad idea to split + // for ARM in the same way, as we would have to quarter the output of + // the first part of the split before running the second part. + // Signal completion of the previous MatVec. + split_pc_->produce(); + PartLinearLayer& thread_part = thread_layers_[tid]; + auto offset_output = + sparse_matrix_.thread_bounds().OffsetOutput(output->data(), tid); + auto mid_output = + sparse_matrix_.thread_bounds().OffsetOutput(mid_output_.data(), tid); + auto offset_bias = sparse_matrix_.thread_bounds().OffsetOutput( + mid_output_.cast_data(), tid); + // We can continue to consume the data that this thread produced and + // compute just the |self_matrix| part. + // No |relu| or |replicas|, as this is only a partial matmul. + // |tid| is always zero because the matrix has been split by tid. + thread_part.self_matrix.MatVec( + rhs.cast_data(), thread_part.full_bias.cast_data(), /*relu=*/false, + /*tid=*/0, /*replicas=*/1, output_stride, mid_output); + // We have to wait for the other threads to finish working on the previous + // MatMul before consuming the rest of |rhs|. + split_pc_->consume(); + thread_part.other_matrix.MatVec(rhs.cast_data(), offset_bias, relu, + /*tid=*/0, replicas, output_stride, + offset_output); + return; + } +#endif + DCHECK_EQ(replicas, 1) << "Must have single replica for SpMM API"; + if (IsSplit()) { + // Generics aren't setup to use a split matrix. This will be inefficient. + split_pc_->produce(); + split_pc_->consume(); + } + if (block_height() == 8) { + // We are currently forced to use MatVec generics for this case. + LOG(WARNING) << "Need to implement MatVec for 8x4 for non-AVX2 targets!!"; + sparse_matrix_.MatVec(rhs.cast_data(), full_bias_.cast_data(), relu, tid, + replicas, output_stride, output->data()); + if (barrier != nullptr) barrier->barrier(); + } else { + sparse_matrix_.SpMM_bias(rhs, bias_, output, relu, tid, barrier); + } + } + + int rows() const { return sparse_matrix_.rows(); } + int cols() const { return sparse_matrix_.cols(); } + float sparsity() const { return sparse_matrix_.sparsity(); } + int block_width() const { return sparse_matrix_.block_width(); } + int block_height() const { return sparse_matrix_.block_height(); } + int num_threads() const { return sparse_matrix_.num_threads(); } + const CacheAlignedVector& bias() const { return bias_; } + const std::vector& split_points() const { + return sparse_matrix_.split_points(); + } + bool IsSplit() const { + return !thread_layers_.empty() && split_pc_ != nullptr; + } + + std::size_t bytes() const { return sparse_matrix_.bytes() + bias_.bytes(); } + void Print() const { + printf("Matrix\n"); + sparse_matrix_.Print(); + printf("Bias\n"); + bias_.Print(); + } + + // Combines adjacent row blocks, doubling the block height. + // This necessarily involves adding zero weights where the blocks don't align + // across adjacent pairs of rows, so use with caution, as the resulting matrix + // is most likely to run slower if very sparse to begin with. + // In the few cases where the blocks do mostly align, the resulting matmul + // could be much faster, as the number of reads of the rhs will be halved. + void DoubleBlockHeight() { sparse_matrix_.DoubleBlockHeight(); } + + // Cache_line_size is provided only for testing. Normally uses a value for + // the current architecture. + int PrepareForThreads(int num_threads, int cache_line_size = -1) { + num_threads_ = num_threads; + if (num_threads_ > 1) { + split_pc_ = + absl::make_unique(num_threads_, num_threads_); + } else { + split_pc_.reset(nullptr); + } + return sparse_matrix_.PrepareForThreads(num_threads, cache_line_size); + } + + // Partitions the matrix into pieces by thread. + // In this matrix, we can go ahead and calculate the part that only depends + // on rhs inputs that were generated by this thread in the previous matvec, + // without having to use any thread synchronization, and only after that do we + // have to wait for the other threads to finish the previous matvec. + // So we split the matrix using the |split_points| from the previous matrix + // into 2 * |num_threads_| pieces: self and other for each thread, being the + // parts that can be calculated before and after the other threads have + // completed their calculation of the previous matvec. + // We then have to use a ProducerConsumer lock instead of a SpinBarrier to + // synchronize the data produced by the other threads. + void SliceForThreads(const std::vector& split_points) { + thread_layers_.clear(); + thread_layers_.reserve(num_threads_); + LOG(INFO) << "Slicing " << rows() << "x" << cols() << " matrix for " + << num_threads_ << " threads"; + for (int tid = 0; tid < num_threads_; ++tid) { + thread_layers_.emplace_back( + sparse_matrix_, full_bias_, bias_, tid, + split_points[tid] * sparse_matrix_.block_height(), + split_points[tid + 1] * sparse_matrix_.block_height()); + } + mid_output_ = + std::move(csrblocksparse::CacheAlignedVector(rows())); + mid_output_.FillZero(); + } + + // Splits the layer by inputs into 2 equal pieces. Each of the resulting + // layers should be computed independently on the first and second halves of + // the inputs respectively and the results added to achieve the same effect + // as the original layer. + void SplitInputs( + SparseLinearLayer* part1, + SparseLinearLayer* part2) { + CsrBlockSparseMatrix matrix1( + sparse_matrix_.SplitByColumn(0, sparse_matrix_.cols() / 2)); + CsrBlockSparseMatrix matrix2( + sparse_matrix_.SplitByColumn(sparse_matrix_.cols() / 2, + sparse_matrix_.cols())); + *part1 = + std::move(SparseLinearLayer( + std::move(matrix1), + std::move(CacheAlignedVector(full_bias_)))); + CacheAlignedVector bias2(sparse_matrix_.rows()); + bias2.FillZero(); + *part2 = + std::move(SparseLinearLayer( + std::move(matrix2), std::move(bias2))); + } + + // Splits the layer by outputs into 2 equal pieces. Each of the resulting + // layers should be computed independently on the full inputs and the results + // concatenated to achieve the same effect as the original layer. + void SplitOutputs( + SparseLinearLayer* part1, + SparseLinearLayer* part2) { + LOG(INFO) << "input rows=" << sparse_matrix_.rows() + << ", cols=" << sparse_matrix_.cols(); + CsrBlockSparseMatrix matrix1( + sparse_matrix_.SplitByRow(0, sparse_matrix_.rows() / 2)); + CsrBlockSparseMatrix matrix2(sparse_matrix_.SplitByRow( + sparse_matrix_.rows() / 2, sparse_matrix_.rows())); + CacheAlignedVector bias1(full_bias_, 0, full_bias_.size() / 2); + *part1 = + std::move(SparseLinearLayer( + std::move(matrix1), std::move(bias1))); + CacheAlignedVector bias2(full_bias_, full_bias_.size() / 2, + full_bias_.size()); + *part2 = + std::move(SparseLinearLayer( + std::move(matrix2), std::move(bias2))); + } + + private: + // Simple struct to hold a partitioned layer. + struct PartLinearLayer { + // The original matrix is first split by row to generate only the outputs + // for the given tid. The |row_sub_matrix| is then split by column into two + // partitions: + // self is the part for which the rhs elements in [|start_col|, |end_col|) + // were generated by this thread in some previous matmul. + // |other| is the rest of the columns that require rhs elements from other + // threads. + // NOTE that| start_col|, |end_col| are in raw columns, not blocks. + PartLinearLayer(const CsrBlockSparseMatrix& matrix, + const CacheAlignedVector& bias, + const CacheAlignedVector& bias_4, int tid, + int start_col, int end_col) { + int block_height = matrix.block_height(); + // Split the input matrix by row, selecting only the rows relevant to + // thread tid. + int start_row = matrix.split_points()[tid] * block_height; + int end_row = matrix.split_points()[tid + 1] * block_height; + LOG(INFO) << "input cols [" << start_col << "," << end_col << ") rows [" + << start_row << "," << end_row << ")"; + CsrBlockSparseMatrix row_sub_matrix = + matrix.SplitByRow(start_row, end_row); + // Partition into the columns that use rhs elements that thread tid + // produced in a previous matmul, and the other rhs elements. + // NOTE that we |keep_rhs_size|=true so that each matrix can operate on + // the same rhs input vector. The self matrix just guarantees not to + // access any of the elements that are generated by another thread. + self_matrix = std::move(row_sub_matrix.SplitByColumn( + start_col, end_col, /*keep_rhs_size=*/true)); + self_matrix.PrepareForThreads(1); + // The reversed start and end slice out the complement of [start, end). + other_matrix = std::move(row_sub_matrix.SplitByColumn( + end_col, start_col, /*keep_rhs_size=*/true)); + other_matrix.PrepareForThreads(1); + full_bias = + std::move(CacheAlignedVector(bias, start_row, end_row)); + // TODO(b/189958858): Eliminate the quarter bias from all the code. + quarter_bias = + std::move(CacheAlignedVector(bias_4, start_row, end_row)); + } + // The part of the matrix that only depends on this thread for rhs inputs. + CsrBlockSparseMatrix self_matrix; + CacheAlignedVector full_bias; + CacheAlignedVector quarter_bias; + // The part of the matrix that uses rhs inputs from other threads. + CsrBlockSparseMatrix other_matrix; + }; + CsrBlockSparseMatrix sparse_matrix_; + CacheAlignedVector bias_; + CacheAlignedVector full_bias_; + // Output from the self_matrix that will be given to |other_matrix| as bias. + CacheAlignedVector mid_output_; + // One partitioned pair of matrices for each thread. + std::vector thread_layers_; + // Producer-consumer lock used to wait between computing |self_matrix| and + // |other_matrix| for the other threads to finish the *previous* matvec. + std::unique_ptr split_pc_; + int num_threads_ = 0; +}; + +template +SparseLinearLayer CreateRandomLayer(int rows, int cols, + float sparsity, + int block_height = 1, + int block_width = 1) { + typedef typename TypeOfProduct::type BiasType; + CacheAlignedVector bias(rows); + bias.FillRandom(); + + auto masked_matrix = MaskedSparseMatrix(rows, cols, sparsity, + block_height, block_width); + auto sparse_matrix = CsrBlockSparseMatrix(masked_matrix); + + return SparseLinearLayer(std::move(sparse_matrix), + std::move(bias)); +} + +template +SparseLinearLayer CreateConstantLayer( + int rows, int cols, float sparsity, float constant = 1.f) { + typedef typename TypeOfProduct::type BiasType; + CacheAlignedVector bias(rows); + bias.FillOnes(); + + MaskedSparseMatrix masked_matrix(rows, cols, sparsity, + /*block_height=*/1, /*block_width=*/1, + constant, /*random=*/false); + CsrBlockSparseMatrix sparse_matrix(masked_matrix); + + return SparseLinearLayer(std::move(sparse_matrix), + std::move(bias)); +} + +} // namespace csrblocksparse + +#endif // LYRA_CODEC_SPARSE_MATMUL_LAYERS_SPARSE_LINEAR_LAYER_H_ diff --git a/sparse_matmul/layers/sparse_linear_layer_test.cc b/sparse_matmul/layers/sparse_linear_layer_test.cc new file mode 100644 index 00000000..bb256ec0 --- /dev/null +++ b/sparse_matmul/layers/sparse_linear_layer_test.cc @@ -0,0 +1,187 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "sparse_matmul/layers/sparse_linear_layer.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "sparse_matmul/numerics/test_utils.h" + +namespace csrblocksparse { +namespace { + +constexpr int kBlockSize = 4; +constexpr int kSize = 256; +constexpr int kNumThreads = 4; +constexpr int kCols = 1; + +void SlicedThreadBody(SpinBarrier* spin_barrier, int tid, + const FatCacheAlignedVector& rhs, + SparseLinearLayer* sparse_linear_layer, + FatCacheAlignedVector* out, bool use_relu) { + sparse_linear_layer->MatVec(rhs, use_relu, tid, /*replicas=*/1, + /*output_stride=*/0, out); + spin_barrier->barrier(); +} + +// Tests that a Layer that has been SliceForThreads computes the same result as +// the original layer. This is a basic test that all the slicing didn't mess up +// any of the computations. +TEST(CsrBlockSparseMatrix, SliceForThreads) { + MaskedSparseMatrix matrix(kSize, kSize, 0.95, kBlockSize, kBlockSize); + FatCacheAlignedVector rhs(kSize, kCols); + CacheAlignedVector bias(kSize); + FatCacheAlignedVector out1(kSize, kCols); + + bias.FillRandom(); + rhs.FillRandom(); + out1.FillZero(); + FatCacheAlignedVector out_reference = out1; + CsrBlockSparseMatrix sparse_matrix(matrix); + SparseLinearLayer sparse_linear_layer(std::move(sparse_matrix), + std::move(bias)); + sparse_linear_layer.PrepareForThreads(1); + sparse_linear_layer.MatVec(rhs, /*relu=*/true, /*tid=*/0, /*replicas=*/1, + /*output_stride=*/0, &out_reference); + std::vector fake_split_points = {0, 48 / kBlockSize, 128 / kBlockSize, + 208 / kBlockSize, kSize / kBlockSize}; + sparse_linear_layer.PrepareForThreads(kNumThreads); + sparse_linear_layer.SliceForThreads(fake_split_points); + csrblocksparse::LaunchOnThreadsWithBarrier(kNumThreads, SlicedThreadBody, rhs, + &sparse_linear_layer, &out1, + /*relu=*/true); + + CheckResult(out_reference, out1, kCols); +} + +void LayersThreadBody(SpinBarrier* spin_barrier, int tid, + const FatCacheAlignedVector& rhs, + SparseLinearLayer* sparse_linear_layer1, + SparseLinearLayer* sparse_linear_layer2, + FatCacheAlignedVector* out1, + FatCacheAlignedVector* out2, bool use_relu) { + sparse_linear_layer1->MatVec(rhs, use_relu, tid, /*replicas=*/1, + /*output_stride=*/0, out1); + // NOTE no barrier here! + sparse_linear_layer2->MatVec(*out1, use_relu, tid, /*replicas=*/1, + /*output_stride=*/0, out2); + spin_barrier->barrier(); +} + +// Tests that a pair of layers computes the same result whether or not the +// second layer has been SliceForThreads. This is a more critical test that +// the replacement of barriers with producer-consumer locks works. +// Must be run with tsan to really test it properly. +TEST(CsrBlockSparseMatrix, SliceForThreadsLayers) { + MaskedSparseMatrix matrix1(kSize, kSize, 0.95, kBlockSize, kBlockSize); + FatCacheAlignedVector rhs(kSize, kCols); + CacheAlignedVector bias1(kSize); + FatCacheAlignedVector out1(kSize, kCols); + MaskedSparseMatrix matrix2(kSize, kSize, 0.95, kBlockSize, kBlockSize); + CacheAlignedVector bias2(kSize); + FatCacheAlignedVector out2(kSize, kCols); + + bias1.FillRandom(); + rhs.FillRandom(); + bias2.FillRandom(); + out1.FillZero(); + out2.FillZero(); + FatCacheAlignedVector out_reference = out2; + CsrBlockSparseMatrix sparse_matrix1(matrix1); + SparseLinearLayer layer1(std::move(sparse_matrix1), + std::move(bias1)); + CsrBlockSparseMatrix sparse_matrix2(matrix2); + SparseLinearLayer layer2(std::move(sparse_matrix2), + std::move(bias2)); + layer1.PrepareForThreads(1); + layer2.PrepareForThreads(1); + layer1.MatVec(rhs, /*relu=*/true, /*tid=*/0, /*replicas=*/1, + /*output_stride=*/0, &out1); + layer2.MatVec(out1, /*relu=*/true, /*tid=*/0, /*replicas=*/1, + /*output_stride=*/0, &out_reference); + layer1.PrepareForThreads(kNumThreads); + layer2.PrepareForThreads(kNumThreads); + layer2.SliceForThreads(layer1.split_points()); + csrblocksparse::LaunchOnThreadsWithBarrier(kNumThreads, LayersThreadBody, rhs, + &layer1, &layer2, &out1, &out2, + /*relu=*/true); + + CheckResult(out_reference, out2, kCols); +} + +// Tests that a Layer that has been DoubleBlockHeight()-ed computes the same +// result as original layer. (Float compute type). +TEST(CsrBlockSparseMatrix, Float8x4) { + using ComputeType = float; + using RhsType = float; + using BiasType = float; + MaskedSparseMatrix matrix(kSize, kSize, 0.95, kBlockSize, kBlockSize); + matrix.CastWeights(); + FatCacheAlignedVector rhs(kSize, kCols); + CacheAlignedVector bias(kSize); + FatCacheAlignedVector out1(kSize, kCols); + + bias.FillRandom(); + rhs.FillRandom(); + out1.FillZero(); + FatCacheAlignedVector out_reference = out1; + CsrBlockSparseMatrix sparse_matrix(matrix); + SparseLinearLayer sparse_linear_layer( + std::move(sparse_matrix), std::move(bias)); + sparse_linear_layer.PrepareForThreads(1); + sparse_linear_layer.MatVec(rhs, /*relu=*/true, /*tid=*/0, /*replicas=*/1, + /*output_stride=*/0, &out_reference); + sparse_linear_layer.DoubleBlockHeight(); + sparse_linear_layer.PrepareForThreads(1); + sparse_linear_layer.MatVec(rhs, /*relu=*/true, /*tid=*/0, /*replicas=*/1, + /*output_stride=*/0, &out1); + CheckResult(out_reference, out1, kCols); +} + +// Tests that a Layer that has been DoubleBlockHeight()-ed computes the same +// result as original layer. (Fixed16 compute type). +TEST(CsrBlockSparseMatrix, Fixed8x4) { + using ComputeType = csrblocksparse::fixed16<4>; + using RhsType = csrblocksparse::fixed16<4>; + using BiasType = typename TypeOfProduct::type; + MaskedSparseMatrix matrix(kSize, kSize, 0.95, kBlockSize, kBlockSize); + matrix.CastWeights(); + FatCacheAlignedVector rhs(kSize, kCols); + CacheAlignedVector bias(kSize); + FatCacheAlignedVector out1(kSize, kCols); + + bias.FillRandom(); + rhs.FillRandom(); + out1.FillZero(); + FatCacheAlignedVector out_reference = out1; + CsrBlockSparseMatrix sparse_matrix(matrix); + SparseLinearLayer sparse_linear_layer( + std::move(sparse_matrix), std::move(bias)); + sparse_linear_layer.PrepareForThreads(1); + sparse_linear_layer.MatVec(rhs, /*relu=*/false, /*tid=*/0, /*replicas=*/1, + /*output_stride=*/0, &out_reference); + sparse_linear_layer.DoubleBlockHeight(); + sparse_linear_layer.PrepareForThreads(1); + sparse_linear_layer.MatVec(rhs, /*relu=*/false, /*tid=*/0, /*replicas=*/1, + /*output_stride=*/0, &out1); + CheckResult(out_reference, out1, kCols); +} + +TEST(SparseLinearLayerTest, PrintCompiles) { + SparseLinearLayer sparse_linear_layer; + sparse_linear_layer.Print(); +} + +} // namespace +} // namespace csrblocksparse diff --git a/sparse_matmul/layers/status_macros.h b/sparse_matmul/layers/status_macros.h new file mode 100644 index 00000000..d2ebeaa7 --- /dev/null +++ b/sparse_matmul/layers/status_macros.h @@ -0,0 +1,34 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_LYRA_CODEC_SPARSE_MATMUL_LAYERS_STATUS_MACROS_H_ +#define THIRD_PARTY_LYRA_CODEC_SPARSE_MATMUL_LAYERS_STATUS_MACROS_H_ + +#include "absl/status/status.h" +#include "absl/status/statusor.h" + +#define SPARSE_MATMUL_RETURN_IF_ERROR(expr) \ + do { \ + const absl::Status _status = (expr); \ + if (!_status.ok()) return _status; \ + } while (0) +template +absl::Status DoAssignOrReturn(T& lhs, absl::StatusOr result) { + if (result.ok()) { + lhs = result.value(); + } + return result.status(); +} + +#endif // THIRD_PARTY_LYRA_CODEC_SPARSE_MATMUL_LAYERS_STATUS_MACROS_H_ diff --git a/sparse_matmul/layers/testdata/768_512_95_4x4_QRhat_weights.raw.gz b/sparse_matmul/layers/testdata/768_512_95_4x4_QRhat_weights.raw.gz new file mode 100644 index 00000000..745fd041 Binary files /dev/null and b/sparse_matmul/layers/testdata/768_512_95_4x4_QRhat_weights.raw.gz differ diff --git a/sparse_matmul/layers/testdata/768_512_95_4x4_What_weights.raw.gz b/sparse_matmul/layers/testdata/768_512_95_4x4_What_weights.raw.gz new file mode 100644 index 00000000..32535e32 Binary files /dev/null and b/sparse_matmul/layers/testdata/768_512_95_4x4_What_weights.raw.gz differ diff --git a/sparse_matmul/layers/testdata/768_512_95_4x4_coarselogit_bias.raw.gz b/sparse_matmul/layers/testdata/768_512_95_4x4_coarselogit_bias.raw.gz new file mode 100644 index 00000000..b976bf4f Binary files /dev/null and b/sparse_matmul/layers/testdata/768_512_95_4x4_coarselogit_bias.raw.gz differ diff --git a/sparse_matmul/layers/testdata/768_512_95_4x4_coarselogit_mask.raw.gz b/sparse_matmul/layers/testdata/768_512_95_4x4_coarselogit_mask.raw.gz new file mode 100644 index 00000000..804e3b58 Binary files /dev/null and b/sparse_matmul/layers/testdata/768_512_95_4x4_coarselogit_mask.raw.gz differ diff --git a/sparse_matmul/layers/testdata/768_512_95_4x4_coarselogit_weights.raw.gz b/sparse_matmul/layers/testdata/768_512_95_4x4_coarselogit_weights.raw.gz new file mode 100644 index 00000000..e2944e99 Binary files /dev/null and b/sparse_matmul/layers/testdata/768_512_95_4x4_coarselogit_weights.raw.gz differ diff --git a/sparse_matmul/layers/testdata/768_512_95_4x4_coarseproj_bias.raw.gz b/sparse_matmul/layers/testdata/768_512_95_4x4_coarseproj_bias.raw.gz new file mode 100644 index 00000000..5870e3c8 Binary files /dev/null and b/sparse_matmul/layers/testdata/768_512_95_4x4_coarseproj_bias.raw.gz differ diff --git a/sparse_matmul/layers/testdata/768_512_95_4x4_coarseproj_mask.raw.gz b/sparse_matmul/layers/testdata/768_512_95_4x4_coarseproj_mask.raw.gz new file mode 100644 index 00000000..5ccb2ee7 Binary files /dev/null and b/sparse_matmul/layers/testdata/768_512_95_4x4_coarseproj_mask.raw.gz differ diff --git a/sparse_matmul/layers/testdata/768_512_95_4x4_coarseproj_weights.raw.gz b/sparse_matmul/layers/testdata/768_512_95_4x4_coarseproj_weights.raw.gz new file mode 100644 index 00000000..03c07675 Binary files /dev/null and b/sparse_matmul/layers/testdata/768_512_95_4x4_coarseproj_weights.raw.gz differ diff --git a/sparse_matmul/layers/testdata/768_512_95_4x4_finelogit_bias.raw.gz b/sparse_matmul/layers/testdata/768_512_95_4x4_finelogit_bias.raw.gz new file mode 100644 index 00000000..272cd774 Binary files /dev/null and b/sparse_matmul/layers/testdata/768_512_95_4x4_finelogit_bias.raw.gz differ diff --git a/sparse_matmul/layers/testdata/768_512_95_4x4_finelogit_mask.raw.gz b/sparse_matmul/layers/testdata/768_512_95_4x4_finelogit_mask.raw.gz new file mode 100644 index 00000000..cfcd6774 Binary files /dev/null and b/sparse_matmul/layers/testdata/768_512_95_4x4_finelogit_mask.raw.gz differ diff --git a/sparse_matmul/layers/testdata/768_512_95_4x4_finelogit_weights.raw.gz b/sparse_matmul/layers/testdata/768_512_95_4x4_finelogit_weights.raw.gz new file mode 100644 index 00000000..91f20f55 Binary files /dev/null and b/sparse_matmul/layers/testdata/768_512_95_4x4_finelogit_weights.raw.gz differ diff --git a/sparse_matmul/layers/testdata/768_512_95_4x4_fineproj_bias.raw.gz b/sparse_matmul/layers/testdata/768_512_95_4x4_fineproj_bias.raw.gz new file mode 100644 index 00000000..5acaecbd Binary files /dev/null and b/sparse_matmul/layers/testdata/768_512_95_4x4_fineproj_bias.raw.gz differ diff --git a/sparse_matmul/layers/testdata/768_512_95_4x4_fineproj_mask.raw.gz b/sparse_matmul/layers/testdata/768_512_95_4x4_fineproj_mask.raw.gz new file mode 100644 index 00000000..d989cab2 Binary files /dev/null and b/sparse_matmul/layers/testdata/768_512_95_4x4_fineproj_mask.raw.gz differ diff --git a/sparse_matmul/layers/testdata/768_512_95_4x4_fineproj_weights.raw.gz b/sparse_matmul/layers/testdata/768_512_95_4x4_fineproj_weights.raw.gz new file mode 100644 index 00000000..1366f1bf Binary files /dev/null and b/sparse_matmul/layers/testdata/768_512_95_4x4_fineproj_weights.raw.gz differ diff --git a/sparse_matmul/layers/testdata/768_512_95_4x4_wavernn_gru_bias.raw.gz b/sparse_matmul/layers/testdata/768_512_95_4x4_wavernn_gru_bias.raw.gz new file mode 100644 index 00000000..8fb70348 Binary files /dev/null and b/sparse_matmul/layers/testdata/768_512_95_4x4_wavernn_gru_bias.raw.gz differ diff --git a/sparse_matmul/layers/testdata/768_512_95_4x4_wavernn_gru_mask.raw.gz b/sparse_matmul/layers/testdata/768_512_95_4x4_wavernn_gru_mask.raw.gz new file mode 100644 index 00000000..17bc2c03 Binary files /dev/null and b/sparse_matmul/layers/testdata/768_512_95_4x4_wavernn_gru_mask.raw.gz differ diff --git a/sparse_matmul/layers/testdata/768_512_95_4x4_wavernn_gru_weights.raw.gz b/sparse_matmul/layers/testdata/768_512_95_4x4_wavernn_gru_weights.raw.gz new file mode 100644 index 00000000..f8334c5a Binary files /dev/null and b/sparse_matmul/layers/testdata/768_512_95_4x4_wavernn_gru_weights.raw.gz differ diff --git a/sparse_matmul/layers/testdata/lyra_conv1d_bias.raw.gz b/sparse_matmul/layers/testdata/lyra_conv1d_bias.raw.gz new file mode 100644 index 00000000..d24c03bd Binary files /dev/null and b/sparse_matmul/layers/testdata/lyra_conv1d_bias.raw.gz differ diff --git a/sparse_matmul/layers/testdata/lyra_conv1d_mask.raw.gz b/sparse_matmul/layers/testdata/lyra_conv1d_mask.raw.gz new file mode 100644 index 00000000..8b72f388 Binary files /dev/null and b/sparse_matmul/layers/testdata/lyra_conv1d_mask.raw.gz differ diff --git a/sparse_matmul/layers/testdata/lyra_conv1d_weights.raw.gz b/sparse_matmul/layers/testdata/lyra_conv1d_weights.raw.gz new file mode 100644 index 00000000..cf26d5b2 Binary files /dev/null and b/sparse_matmul/layers/testdata/lyra_conv1d_weights.raw.gz differ diff --git a/sparse_matmul/layers/utils.cc b/sparse_matmul/layers/utils.cc new file mode 100644 index 00000000..0a8d5796 --- /dev/null +++ b/sparse_matmul/layers/utils.cc @@ -0,0 +1,129 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Source for various utility functions related to reading and writing files +// and vectors. Would be much simpler if Android and Windows supported File. + +#include "sparse_matmul/layers/utils.h" + +#ifdef _WIN32 +#include + +#include +#include // NOLINT +#else +#include +#endif // _WIN32 + +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/substitute.h" + +namespace csrblocksparse { + +namespace { + +// Helper to test if a filename is "." or "..". +template +bool IsDotOrDotDot(const CharType* filename) { + if (filename[0] == '.') { + if (filename[1] == '\0') { + return true; + } + if (filename[1] == '.' && filename[2] == '\0') { + return true; + } + } + + return false; +} + +#ifdef _WIN32 // We only define these conversion routines on Win32. +static std::mutex g_converter_mutex; +static std::wstring_convert> g_converter; + +std::string Narrow(const std::wstring& wide) { + std::lock_guard auto_lock(g_converter_mutex); + return g_converter.to_bytes(wide); +} + +std::wstring Widen(const std::string& narrow) { + std::lock_guard auto_lock(g_converter_mutex); + return g_converter.from_bytes(narrow); +} + +inline constexpr char kLongPathPrefix[] = R"(\\?\)"; + +std::wstring ConvertToWindowsPathFormat(const std::string& path, + int max_path_length = MAX_PATH) { + if (path.length() + 1 > max_path_length && + !absl::StartsWith(path, kLongPathPrefix)) { + return Widen(absl::StrCat(kLongPathPrefix, path)); + } + return Widen(path); +} +#endif // _WIN32 + +} // namespace + +// Return all files in a given directory. +absl::Status FilesInDirectory(const std::string& path, + const std::string& must_contain, + std::vector* result) { +#ifdef _WIN32 + WIN32_FIND_DATAW child_data; + HANDLE find_handle = FindFirstFileW( + ConvertToWindowsPathFormat(absl::StrCat(path, "\\*")).c_str(), + &child_data); + if (find_handle == INVALID_HANDLE_VALUE) { + return absl::UnknownError( + absl::Substitute("Couldn't open: $0 (error $1)", path, GetLastError())); + } + do { + if (IsDotOrDotDot(child_data.cFileName)) continue; + const std::string name = Narrow(child_data.cFileName); + if (name.find(must_contain) == std::string::npos) continue; + result->push_back(name); + } while (FindNextFileW(find_handle, &child_data) != 0); + const auto err = GetLastError(); + FindClose(find_handle); + if (err != ERROR_NO_MORE_FILES) + return absl::UnknownError( + absl::Substitute("Error in FindNextFileW: $0", err)); +#else + DIR* dirp = opendir(path.c_str()); + if (dirp == nullptr) { + return absl::UnknownError(absl::Substitute("Couldn't open: $0", path)); + } + + dirent* dp; + errno = 0; + while ((dp = readdir(dirp)) != nullptr) { + if (IsDotOrDotDot(dp->d_name)) continue; + const std::string name(dp->d_name); + if (name.find(must_contain) == std::string::npos) continue; + result->push_back(name); + } + closedir(dirp); + if (errno != 0) + return absl::UnknownError(absl::Substitute("Error in readdir: $0", errno)); +#endif // _WIN32 + + return absl::OkStatus(); +} + +} // namespace csrblocksparse diff --git a/sparse_matmul/layers/utils.h b/sparse_matmul/layers/utils.h new file mode 100644 index 00000000..e10b1b95 --- /dev/null +++ b/sparse_matmul/layers/utils.h @@ -0,0 +1,338 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Various utility functions related to reading and writing files, vectors, etc. +// Would be much simpler if Android supported File. + +#ifndef LYRA_CODEC_SPARSE_MATMUL_LAYERS_UTILS_H_ +#define LYRA_CODEC_SPARSE_MATMUL_LAYERS_UTILS_H_ + +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/cord.h" +#include "absl/strings/substitute.h" +#include "include/ghc/filesystem.hpp" +#include "sparse_matmul/layers/errno_mapping.h" +#include "sparse_matmul/layers/masked_sparse_matrix.h" +#include "sparse_matmul/layers/read_array_ifstream.h" +#include "sparse_matmul/layers/sparse_linear_layer.h" +#include "sparse_matmul/layers/status_macros.h" +#include "sparse_matmul/numerics/fixed_types.h" +#include "sparse_matmul/numerics/float16_types.h" +#include "sparse_matmul/numerics/type_utils.h" +#include "sparse_matmul/vector/cache_aligned_vector.h" +#include "sparse_matmul/zlib_wrapper/zlibwrapper.h" + +namespace csrblocksparse { + +template +void unzip(int64_t st_size, std::vector* array) { + ZLib z; + z.SetGzipHeaderMode(); + if (z.HasGzipHeader(reinterpret_cast(array->data()), st_size)) { + const std::size_t kMaxBufferSize = 1 << 27; // 128MB + + Bytef* dest; + uLongf dest_len = kMaxBufferSize; + CHECK_EQ(z.UncompressGzipAndAllocate(&dest, &dest_len, + (Bytef*)array->data(), st_size), + Z_OK); + CHECK_EQ(dest_len % sizeof(T), 0); + array->assign(reinterpret_cast(dest), + reinterpret_cast(dest + dest_len)); + free(dest); + } else { + CHECK_EQ(st_size % sizeof(T), 0); + } +} + +// Reads a file that contains an array of a single POD type. Eventually we +// will replace serializiation with protos, but for now this is the easiest way +// to interface with the rest of the pipeline. +// +// StatusOr might be preferred but does not compile on ARM. +// |DiskType| and |ElemType| template types have no effect in this function +// version and are only used to handle fixed_type disk storage. +template +typename std::enable_if::value, + absl::Status>::type +ReadArrayFromFile(const std::string& file_name, std::vector* array, + const std::string& path = "/data/local/tmp/") { + int64_t length = 0; + const absl::Status status = + detail::ReadArrayIfstream(file_name, path, array, &length); + if (!status.ok()) { + return status; + } + unzip(length, array); + + return absl::OkStatus(); +} + +// If the metatype |DiskType| is of fixed16_type, we load int16_ts from disk and +// construct |ElemType| from them. |ElemType| is necessary because we need to +// know the mantissa/exponent bit split before casting to float. We need a +// separate function template for fixed rather than an if block because the +// compiler will complain bfloat not having an int16_t constructor. +template +typename std::enable_if::value && + csrblocksparse::IsFixed16Type::value, + absl::Status>::type +ReadArrayFromFile(const std::string& file_name, std::vector* array, + const std::string& path = "/data/local/tmp/") { + std::vector disk_values; + SPARSE_MATMUL_RETURN_IF_ERROR( + ReadArrayFromFile(file_name, &disk_values, path)); + array->resize(disk_values.size()); + std::transform( + disk_values.begin(), disk_values.end(), array->begin(), + [](int16_t disk_value) { return static_cast(ElemType(disk_value)); }); + return absl::OkStatus(); +} + +// Writes a vector to a binary file. Eventually serialization will be handled +// with protos. +template +absl::Status WriteArrayToFile(const std::vector& array, + const std::string& file_name, + std::string path = "/data/local/tmp/") { + path = (ghc::filesystem::path(path) / file_name).string(); + FILE* fp = fopen(path.c_str(), "wb"); + if (fp == nullptr) + return ErrnoToCanonicalStatus(errno, + absl::Substitute("Error opening $0", path)); + size_t write_count = fwrite(array.data(), sizeof(T), array.size(), fp); + if (write_count != array.size()) { + return ErrnoToCanonicalStatus( + errno, + absl::Substitute( + "Error writing array, only wrote $0 of $1 elements for file $2", + write_count, array.size(), path)); + } + SPARSE_MATMUL_RETURN_IF_ERROR(ErrnoToCanonicalStatus( + fclose(fp), absl::Substitute("Error closing $0", path))); + return absl::OkStatus(); +} + +// Reads an entire layer that consists of weights, bias and mask as a +// SparseLinearLayer. Eventually this serialization will be handled with +// protos, but the rest of the system currently does naive serialization. +// +// StatusOr might be preferred but does not compile on ARM. +// +// Here |DiskWeightType| is the metatype used to store the weights, usually +// fixed16_type, float, or bfloat. +// For |DiskWeightType| = fixed16_type specialization, this loads a file with a +// "fixed16_weights.raw" suffix which stores int16_ts as its element datatype. +// The disk elements should match fixed16. This cuts +// down disk storage of weights by +// >= half. For all other types it reads the weights as floats. +template +absl::Status LoadGenericLayer( + const std::string& prefix, bool zipped, const std::string& path, + float default_bias, + SparseLinearLayer* sparse_linear_layer) { + std::string fixed_prefix = + csrblocksparse::IsFixed16Type::value ? "fixed16_" : ""; + std::string extension = zipped ? ".gz" : ""; + std::string weight_name = + absl::StrCat(prefix, fixed_prefix, "weights.raw", extension); + std::string mask_name = absl::StrCat(prefix, "mask.raw", extension); + std::string bias_name = absl::StrCat(prefix, "bias.raw", extension); + + std::vector weight_vector; + std::vector mask_vector; + std::vector bias_vector; + + const auto status = ReadArrayFromFile( + weight_name, &weight_vector, path); + SPARSE_MATMUL_RETURN_IF_ERROR(status); + SPARSE_MATMUL_RETURN_IF_ERROR( + ReadArrayFromFile(mask_name, &mask_vector, path)); + SPARSE_MATMUL_RETURN_IF_ERROR( + ReadArrayFromFile(bias_name, &bias_vector, path)); + + CHECK(weight_vector.size() == mask_vector.size()) + << "Weight and mask must be" + << " the same size, weights: " << weight_vector.size() + << " mask: " << mask_vector.size(); + CHECK(weight_vector.size() % bias_vector.size() == 0) + << "Weights size must " + "be a multiple of the bias size. Weights: " + << weight_vector.size() + << " " + "bias: " + << bias_vector.size() + << " remainder: " << weight_vector.size() % bias_vector.size(); + + int rows = bias_vector.size(); + int cols = weight_vector.size() / rows; + + MaskedSparseMatrix weights_masked(rows, cols, mask_vector.data(), + weight_vector.data()); + + weights_masked.template CastWeights(); + using csrmatrix = CsrBlockSparseMatrix; + + csrmatrix weights(weights_masked); + // If the weights were not a multiple of the block size in rows, we need to + // expand the bias vector to match using the provided default_bias value. + bias_vector.resize(weights.rows(), default_bias); + using BiasType = typename TypeOfProduct::type; + CacheAlignedVector bias(bias_vector); + + *sparse_linear_layer = std::move(SparseLinearLayer( + std::move(weights), std::move(bias))); + + return absl::OkStatus(); +} +template +absl::Status LoadSparseLayer( + const std::string& prefix, bool zipped, + SparseLinearLayer* sparse_linear_layer, + const std::string& path = "/data/local/tmp/") { + return LoadGenericLayer( + prefix, zipped, path, 0.0f, sparse_linear_layer); +} +template +absl::Status LoadLogitLayer( + const std::string& prefix, bool zipped, const std::string& path, + SparseLinearLayer* sparse_linear_layer) { + return LoadGenericLayer( + prefix, zipped, path, std::numeric_limits::lowest(), + sparse_linear_layer); +} + +// Reads an entire layer that consists of weights, bias and mask as a +// MaskedLinearLayer. Eventually this serialization will be handled with +// protos, but the rest of the system currently does naive serialization. +// +// StatusOr might be preferred but does not compile on ARM. +template +absl::Status LoadMaskedLayer(const std::string& prefix, bool zipped, + MaskedLinearLayer* masked_sparse_matrix, + const std::string& path = "/data/local/tmp/") { + std::string extension = zipped ? ".gz" : ""; + std::string weight_name = absl::StrCat(prefix, "weights.raw", extension); + std::string mask_name = absl::StrCat(prefix, "mask.raw", extension); + std::string bias_name = absl::StrCat(prefix, "bias.raw", extension); + + std::vector weight_vector; + std::vector mask_vector; + std::vector bias_vector; + + SPARSE_MATMUL_RETURN_IF_ERROR( + ReadArrayFromFile(weight_name, &weight_vector, path)); + SPARSE_MATMUL_RETURN_IF_ERROR( + ReadArrayFromFile(mask_name, &mask_vector, path)); + SPARSE_MATMUL_RETURN_IF_ERROR( + ReadArrayFromFile(bias_name, &bias_vector, path)); + + CHECK(weight_vector.size() == mask_vector.size()) + << "Weight and mask must be" + << " the same size, weights: " << weight_vector.size() + << " mask: " << mask_vector.size(); + CHECK(weight_vector.size() % bias_vector.size() == 0) + << "Weights size must " + "be a multiple of the bias size. Weights: " + << weight_vector.size() + << " " + "bias: " + << bias_vector.size() + << " remainder: " << weight_vector.size() % bias_vector.size(); + + int rows = bias_vector.size(); + int cols = weight_vector.size() / rows; + + MaskedSparseMatrix weights_masked(rows, cols, mask_vector.data(), + weight_vector.data()); + CacheAlignedVector bias(bias_vector); + + *masked_sparse_matrix = + MaskedLinearLayer(std::move(weights_masked), std::move(bias)); + return absl::OkStatus(); +} + +// Load a vector of POD into a CacheAlignedVector. +// +// StatusOr might be preferred but does not compile on ARM. +template +absl::Status LoadVector(const std::string& file_name, + CacheAlignedVector* cache_aligned_vector, + const std::string& path = "/data/local/tmp/") { + std::vector values; + + SPARSE_MATMUL_RETURN_IF_ERROR(ReadArrayFromFile(file_name, &values, path)); + + *cache_aligned_vector = std::move(CacheAlignedVector(values)); + + return absl::OkStatus(); +} + +// Loads a 2D vector from a file. One of rows or cols can optionally be +// -1 to indicate that dimension should be inferred. +template +absl::Status LoadFatVector(const std::string& file_name, int rows, int cols, + FatCacheAlignedVector* fat_cache_aligned_vector, + const std::string& path = "/data/local/tmp/") { + // neither can be zero + CHECK(rows != 0 && cols != 0); + // only one can be -1 + CHECK(rows != -1 || cols != -1); + // otherwise must be positive + CHECK(rows >= -1 && cols >= -1); + + CacheAlignedVector values; + + SPARSE_MATMUL_RETURN_IF_ERROR(LoadVector(file_name, &values, path)); + + if (rows > 0) + CHECK_EQ(values.size() % rows, 0); + else + rows = values.size() / cols; + + if (cols > 0) + CHECK_EQ(values.size() % cols, 0); + else + cols = values.size() / rows; + + *fat_cache_aligned_vector = std::move(FatCacheAlignedVector(values, rows)); + + return absl::OkStatus(); +} + +// Return all files in a given directory +// If only File worked on Android and Windows... +absl::Status FilesInDirectory(const std::string& path, + const std::string& must_contain, + std::vector* result); + +} // namespace csrblocksparse + +#endif // LYRA_CODEC_SPARSE_MATMUL_LAYERS_UTILS_H_ diff --git a/sparse_matmul/layers/utils_test.cc b/sparse_matmul/layers/utils_test.cc new file mode 100644 index 00000000..c70ee074 --- /dev/null +++ b/sparse_matmul/layers/utils_test.cc @@ -0,0 +1,185 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "sparse_matmul/layers/utils.h" + +#include +#include +#include +#include +#include + +#include "absl/flags/flag.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "include/ghc/filesystem.hpp" +#include "sparse_matmul/layers/csr_blocksparse_matrix.h" +#include "sparse_matmul/layers/errno_mapping.h" +#include "sparse_matmul/layers/sparse_linear_layer.h" +#include "sparse_matmul/numerics/fast_transcendentals.h" +#include "sparse_matmul/numerics/fixed_types.h" +#include "sparse_matmul/numerics/float16_types.h" +#include "sparse_matmul/numerics/test_utils.h" +#include "sparse_matmul/numerics/type_utils.h" +#include "sparse_matmul/vector/cache_aligned_vector.h" + +namespace csrblocksparse { +namespace { + +static constexpr char kTempOutputDir[] = + "third_party/lyra_codec/sparse_matmul/layers/testdata/"; +static constexpr int kTestExponentBits = 5; + +template +class CsrBlockSparseMatrixUtilsTest : public testing::Test { + protected: + CsrBlockSparseMatrixUtilsTest() + : output_dir_((ghc::filesystem::path(testing::TempDir()) / kTempOutputDir) + .string()) { + if (std::is_floating_point::value) { + tolerance_ = 1e-5; + } else if (csrblocksparse::IsCustomFloatType::value) { + // Casting float --> bfloat truncates the least significant 16 bits from + // the mantissa, thus the larger the exponent bits the larger the rounding + // error. + // The exponent for max_val is 2^4, meaning the max rounding error + // for the weight input is ~ 0.124. The tolerance is 2x this because + // although the intermediate multiplications are accumulated in float, + // the output is cast to bfloat. + // Placeholder for internal diagram. + float max_val = + std::pow(2, kTestExponentBits) - + std::pow(2, -fixed16::kMantissaBits); + tolerance_ = 2 * (max_val - static_cast(ComputeType(max_val))); + } else { + tolerance_ = std::pow(2, -MantissaBitsOf::value); + } + } + + void SetUp() override { + std::error_code error_code; + ghc::filesystem::create_directories(output_dir_, error_code); + ASSERT_FALSE(error_code); + } + + void TearDown() override { + std::error_code error_code; + ghc::filesystem::remove_all(output_dir_, error_code); + ASSERT_FALSE(error_code); + } + + const std::string output_dir_; + float tolerance_; +}; + +void GenerateRandomWeightBiasMaskVectors( + int weight_vector_size, int bias_vector_size, + std::vector* weight_vector, std::vector* bias_vector, + std::vector* mask_vector, std::vector* masked_weight_vector) { + weight_vector->resize(weight_vector_size); + bias_vector->resize(bias_vector_size); + mask_vector->resize(weight_vector_size); + masked_weight_vector->resize(weight_vector_size); + // Fill Weight and Bias with random values between +/-[2^|kTestExponentBits| - + // 1] - 0.5 to prevent clipping in the fixed16 case when the weight and bias + // are added with all 1s in the exponent and mantissa. + const float max_abs_random_value = + std::pow(2, kTestExponentBits - 1) - 0.5; + std::uniform_real_distribution distribution(-max_abs_random_value, + max_abs_random_value); + std::default_random_engine generator(1337); + std::generate(weight_vector->begin(), weight_vector->end(), + [&]() { return distribution(generator); }); + std::generate(bias_vector->begin(), bias_vector->end(), + [&]() { return distribution(generator); }); + std::bernoulli_distribution mask_distribution(0.5); + std::generate(mask_vector->begin(), mask_vector->end(), + [&]() { return mask_distribution(generator) ? 1 : 0; }); + // Construct the combined weight and mask vector. + std::transform(mask_vector->begin(), mask_vector->end(), + weight_vector->begin(), masked_weight_vector->begin(), + [&](float mask_value, float weight_value) { + return mask_value * weight_value; + }); +} + +using ComputeTypes = + testing::Types, + csrblocksparse::bfloat16>; +TYPED_TEST_SUITE(CsrBlockSparseMatrixUtilsTest, ComputeTypes); + +TYPED_TEST(CsrBlockSparseMatrixUtilsTest, LoadLayer) { + const int kWeightVectorSize = 16; + const int kBiasVectorSize = 4; + std::vector ref_weight_vector; + std::vector ref_bias_vector; + std::vector ref_mask_vector; + std::vector ref_masked_weight_vector; + + GenerateRandomWeightBiasMaskVectors( + kWeightVectorSize, kBiasVectorSize, &ref_weight_vector, &ref_bias_vector, + &ref_mask_vector, &ref_masked_weight_vector); + + // This fixed16_weights.raw vector should only be read by LoadGenericLayer + // when |TypeParam| is a fixed16_type. + std::vector fixed_weight_vector(ref_weight_vector.size()); + std::transform(ref_weight_vector.begin(), ref_weight_vector.end(), + fixed_weight_vector.begin(), [](float weight) { + return fixed16(weight).raw_val(); + }); + ASSERT_TRUE(WriteArrayToFile(fixed_weight_vector, "fixed16_weights.raw", + this->output_dir_) + .ok()); + ASSERT_TRUE( + WriteArrayToFile(ref_weight_vector, "weights.raw", this->output_dir_) + .ok()); + ASSERT_TRUE( + WriteArrayToFile(ref_bias_vector, "bias.raw", this->output_dir_).ok()); + ASSERT_TRUE( + WriteArrayToFile(ref_mask_vector, "mask.raw", this->output_dir_).ok()); + + // Read in the weights, mask, and bias to a layer. + SparseLinearLayer actual_layer; + using DiskWeightType = + typename std::conditional::value, + csrblocksparse::fixed16_type, TypeParam>::type; + auto status = LoadGenericLayer( + /*prefix=*/"", /*zipped=*/false, this->output_dir_, + /*default_bias=*/0.f, &actual_layer); + ASSERT_TRUE(status.ok()); + // Multiply the read in layer with an identity matrix so we just get + // the weights added with bias. + std::vector identity(kBiasVectorSize * kBiasVectorSize, + TypeParam(0.f)); + for (int i = 0; i < identity.size(); i += kBiasVectorSize + 1) { + identity.at(i) = TypeParam(1.f); + } + FatCacheAlignedVector masked_weights_plus_bias(kBiasVectorSize, + kBiasVectorSize); + actual_layer.SpMM_bias( + VectorView(identity.data(), /*rows=*/kBiasVectorSize, + /*cols=*/kBiasVectorSize), + &masked_weights_plus_bias); + // |masked_weights_plus_bias| - bias = masked weights. + for (int col = 0; col < masked_weights_plus_bias.cols(); col++) { + MutableVectorView col_data = masked_weights_plus_bias.slice(col); + for (int row = 0; row < masked_weights_plus_bias.rows(); row++) { + int flat_index = row * masked_weights_plus_bias.cols() + col; + EXPECT_NEAR(static_cast(col_data[row]) - ref_bias_vector.at(row), + ref_masked_weight_vector.at(flat_index), this->tolerance_); + } + } +} +} // namespace +} // namespace csrblocksparse diff --git a/sparse_matmul/numerics/BUILD b/sparse_matmul/numerics/BUILD new file mode 100644 index 00000000..0a81aafb --- /dev/null +++ b/sparse_matmul/numerics/BUILD @@ -0,0 +1,160 @@ +# Base numeric types and transcendental functions. + +licenses(["notice"]) + +cc_library( + name = "fast_transcendentals", + srcs = [ + "fast_transcendentals.cc", + ], + hdrs = [ + "fast_transcendentals.h", + ], + visibility = [ + "//sparse_matmul:__subpackages__", + ], + deps = [":types"], +) + +cc_library( + name = "test_utils", + testonly = 1, + hdrs = [ + "test_utils.h", + ], + visibility = ["//sparse_matmul:__subpackages__"], + deps = [ + ":types", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "types", + hdrs = [ + "fixed_types.h", + "float16_types.h", + "type_utils.h", + ], + visibility = [ + "//sparse_matmul:__subpackages__", + ], + deps = [ + "@com_google_glog//:glog", + ], +) + +cc_library( + name = "fast_transcendentals_cc", + srcs = ["fast_transcendentals.cc"], + hdrs = ["fast_transcendentals.h"], + deps = [":types"], +) + +cc_test( + name = "fasttranscendentals_test", + size = "small", + srcs = [ + "fasttranscendentals_test.cc", + ], + deps = [ + ":fast_transcendentals", + ":test_utils", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "fasttranscendentals_test_fast", + size = "small", + srcs = [ + "fasttranscendentals_test.cc", + ], + copts = ["-DFAST_TRANSCENDENTALS"], + deps = [ + ":fast_transcendentals", + ":test_utils", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "fasttranscendentals_test_fast_accurate", + size = "small", + srcs = [ + "fasttranscendentals_test.cc", + ], + copts = [ + "-DFAST_TRANSCENDENTALS", + "-DACCURATE_TRANSCENDENTAL_APPROX", + ], + deps = [ + ":fast_transcendentals", + ":test_utils", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "fasttranscendentals_test_fast_accurate_sigmoidastanh", + size = "small", + srcs = [ + "fasttranscendentals_test.cc", + ], + copts = [ + "-DFAST_TRANSCENDENTALS", + "-DACCURATE_TRANSCENDENTAL_APPROX", + "-DSIGMOID_AS_TANH", + ], + deps = [ + ":fast_transcendentals", + ":test_utils", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "fasttranscendentals_test_fast_sigmoidastanh", + size = "small", + srcs = [ + "fasttranscendentals_test.cc", + ], + copts = [ + "-DFAST_TRANSCENDENTALS", + "-DSIGMOID_AS_TANH", + ], + deps = [ + ":fast_transcendentals", + ":test_utils", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "fasttranscendentals_test_faster_sigmoid", + size = "small", + srcs = [ + "fasttranscendentals_test.cc", + ], + copts = [ + "-DFASTER_TRANSCENDENTALS", + ], + deps = [ + ":fast_transcendentals", + ":test_utils", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "fixed_types_test", + size = "small", + srcs = [ + "fixed_types_test.cc", + ], + deps = [ + ":test_utils", + ":types", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/sparse_matmul/numerics/fast_transcendentals.cc b/sparse_matmul/numerics/fast_transcendentals.cc new file mode 100644 index 00000000..75adf01a --- /dev/null +++ b/sparse_matmul/numerics/fast_transcendentals.cc @@ -0,0 +1,81 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "sparse_matmul/numerics/fast_transcendentals.h" + +namespace csrblocksparse { + +// Maximum desired precision of the output. +static constexpr int kMaxMantissaBits = 14; + +// Returns (and builds if not done yet) a static data table that implements +// tanh on fixed32 input, returning another fixed32 with the given number of +// mantissa bits (which is assumed to be less than the input mantissa bits). +// NOTE that this function is intended to be used only with fixed16 outputs that +// are sign-extended to 32 bits for convenience, and will return a nullptr +// if asked for more than |kMaxMantissaBits| of precision in the output table. +const int32_t* TanhTable(int num_mantissa_bits_out) { + if (num_mantissa_bits_out > kMaxMantissaBits) return nullptr; + // Static data dynamically created and never destructed. + static const int32_t* tanh_luts[kMaxMantissaBits]; + if (tanh_luts[num_mantissa_bits_out - 1] == nullptr) { + // Total bits is number each side of the binary point. + int tanh_lut_bits = num_mantissa_bits_out + kNumTanhExpBits; + // Offset is the number of negative numbers represented. + int tanh_offset = 1 << tanh_lut_bits; + // Size is double the offset plus one more for zero. + int tanh_size = tanh_offset * 2 + 1; + // Conversion between int and float. + float float_factor = static_cast(1 << num_mantissa_bits_out); + int* tanh_lut = new int[tanh_size]; + // Initialize the table. + for (int i = 0; i < tanh_size; ++i) { + float x = (i - tanh_offset) / float_factor; + tanh_lut[i] = static_cast(std::round(tanhf(x) * float_factor)); + } + tanh_luts[num_mantissa_bits_out - 1] = tanh_lut; + } + return tanh_luts[num_mantissa_bits_out - 1]; +} + +// As TanhTable, but for Sigmoid. +const int32_t* SigmoidTable(int num_mantissa_bits_out) { + if (num_mantissa_bits_out > kMaxMantissaBits) return nullptr; + // Static data dynamically created and never destructed. + static const int32_t* sigmoid_luts[kMaxMantissaBits]; + if (sigmoid_luts[num_mantissa_bits_out - 1] == nullptr) { + // Total bits is number each side of the binary point minus one for the fact + // that the gradient never exceeds 1/4. (Could probably use -2.) + int sigmoid_lut_bits = + num_mantissa_bits_out + kNumSigmoidExpBits - kNumExtraSigmoidShiftBits; + // Offset is the number of negative numbers represented. + int sigmoid_offset = 1 << sigmoid_lut_bits; + // Size is double the offset plus one more for zero. + int sigmoid_size = sigmoid_offset * 2 + 1; + // Conversion between int and float. + float float_factor = static_cast(1 << num_mantissa_bits_out); + int* sigmoid_lut = new int[sigmoid_size]; + // Initialize the table. + for (int i = 0; i < sigmoid_size; ++i) { + constexpr int kSigmoidFactor = 1 << kNumExtraSigmoidShiftBits; + float x = ((i - sigmoid_offset) * kSigmoidFactor) / float_factor; + float sigmoid = 1.0f / (1.0f + expf(-x)); + sigmoid_lut[i] = static_cast(std::round(sigmoid * float_factor)); + } + sigmoid_luts[num_mantissa_bits_out - 1] = sigmoid_lut; + } + return sigmoid_luts[num_mantissa_bits_out - 1]; +} + +} // namespace csrblocksparse diff --git a/sparse_matmul/numerics/fast_transcendentals.h b/sparse_matmul/numerics/fast_transcendentals.h new file mode 100644 index 00000000..2c73eeec --- /dev/null +++ b/sparse_matmul/numerics/fast_transcendentals.h @@ -0,0 +1,1177 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LYRA_CODEC_SPARSE_MATMUL_NUMERICS_FAST_TRANSCENDENTALS_H_ +#define LYRA_CODEC_SPARSE_MATMUL_NUMERICS_FAST_TRANSCENDENTALS_H_ + +#include +#if defined __ARM_NEON || defined __aarch64__ +#include +#else +#include +#endif +#if defined __AVX__ || defined __AVX2__ +#include +#endif +#include + +#include "sparse_matmul/numerics/fixed_types.h" +#include "sparse_matmul/numerics/type_utils.h" + +namespace csrblocksparse { + +// The input to exp is clipped to bounds that prevent overflow/underflow in a +// 32 bit float representation. e^80 ~ 6e34, which is close to maxfloat. +constexpr float kMaxExpInput = 80.f; +constexpr int kMaxExpInputInt = static_cast(kMaxExpInput); +constexpr float kMinExpInput = -80.f; +// tanh(9) ~ 0.99999997, which cannot be resolved from 1 in a float32. +constexpr float kMaxTanhInput = 9.f; +constexpr float kMinTanhInput = -9.f; +// sigmoid(18) ~ 0.999999985, which cannot be resolved from 1 in a float32. +constexpr float kMaxSigmoidInput = 18.f; +constexpr float kMinSigmoidInput = -18.f; +// kAConstant ~= 2^23 / ln 2 +constexpr uint32_t kAConstant = 0x4b38aa3b; +// kBConstant ~= (127 << 23) - 366000 +constexpr uint32_t kBConstant = 0x4e7de9a9; +// Coefficients of the rational approximation to tanh. +// Coefficients of the numerator polynomial (odd). +constexpr float kTanhAlpha1 = 4.89352455891786e-03; +constexpr float kTanhAlpha3 = 6.37261928875436e-04; +constexpr float kTanhAlpha5 = 1.48572235717979e-05; +constexpr float kTanhAlpha7 = 5.12229709037114e-08; +constexpr float kTanhAlpha9 = -8.60467152213735e-11; +constexpr float kTanhAlpha11 = 2.00018790482477e-13; +constexpr float kTanhAlpha13 = -2.76076847742355e-16; +// The monomial coefficients of the denominator polynomial (even). +constexpr float kTanhBeta0 = 4.89352518554385e-03; +constexpr float kTanhBeta2 = 2.26843463243900e-03; +constexpr float kTanhBeta4 = 1.18534705686654e-04; +constexpr float kTanhBeta6 = 1.19825839466702e-06; + +// Coefficients of the rational approximation to sigmoid. +// Coefficients of the numerator polynomial (odd). +constexpr float kSigmoidAlpha1 = 2.48287947061529e-01; +constexpr float kSigmoidAlpha3 = 8.51377133304701e-03; +constexpr float kSigmoidAlpha5 = 6.08574864600143e-05; +constexpr float kSigmoidAlpha7 = 1.15627324459942e-07; +constexpr float kSigmoidAlpha9 = 4.37031012579801e-11; + +// The monomial coefficients of the denominator polynomial (even). +constexpr float kSigmoidBeta0 = 9.93151921023180e-01; +constexpr float kSigmoidBeta2 = 1.16817656904453e-01; +constexpr float kSigmoidBeta4 = 1.70198817374094e-03; +constexpr float kSigmoidBeta6 = 6.29106785017040e-06; +constexpr float kSigmoidBeta8 = 5.76102136993427e-09; +constexpr float kSigmoidBeta10 = 6.10247389755681e-13; + +// x is the first term of the Taylor series approximation of tanh near 0 and +// because the leading error term of tanh(x) - x is O(x^3), it is good for a +// wide interval, use it in this region where the other approximation is +// inaccurate. tanh(x) = x - x^3 / 3 + 2x^5 / 15 - 17x^7 / 315 + ... +// Similarly for sigmoid where the first term is .25x +constexpr float kTanhLinearRegion = .15f; +constexpr float kSigmoidLinearRegion = .75f; + +// Maximum shift factor for 1/log 2 to keep it inside int32. +constexpr int kMaxLog2Shift = 30; +static const int kLogFactor = static_cast((1 << kMaxLog2Shift) / log(2.f)); +static const float kOneOverLog2 = 1.0f / log(2.f); +// Number of real mantissa bits in IEEE float32. +constexpr int kFloatMantissaBits = 23; +// Offset to correct the exponent value in the resulting float. +constexpr int kFloatExponentOffset = 127 << kFloatMantissaBits; +// Mask for mantissa. +constexpr int kFloatMantissaMask = (1 << kFloatMantissaBits) - 1; +// Mask for exponent; +constexpr int kFloatExponentMask = (-1) ^ kFloatMantissaMask; + +// ========== COMMON DOCUMENTATION FOR THE FLOATING EXPONENT TRICK ============ +// Summary: Use the exponent-mantissa representation of a floating point number +// to give exponentiation of 2 for free. If we desire f(z) = e^z = 2^(x+n), (for +// some fixed-point z expressed as an integer with imaginary binary point within +// it) then we have to compute x+n = z / ln 2 and then splitting x+n into +// n = int(x+n) and x = fract(x+n) in [0, 1), we can use n and 2^x as the +// exponent and mantissa of a floating point number, and that float is equal to +// e^z. For original reference see: +// http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.9.4508&rep=rep1&type=pdf +// Important detail: +// IEEE floats are stored normalized, ie 1.bbbbbbb... x 2^exponent. The leading +// 1 bit is not actually stored, (as it is always 1), providing an extra bit of +// precision. +// Since 2^0=1 and 2^1=2, we can treat the problem as 2^x = 1 + u and we thus +// need a mapping x in [0, 1) -> u in [0, 1) and the 1 + is provided by the +// representation. +// In the original paper cited above, the mapping is u = x - c, where c is set +// to minimize the average error. The function to compute exp(x) this way is +// incredibly simple and computationally cheap, but not very accurate. +// Fortunately, the problem has been reduced to u = 2^x - 1 over [0, 1) for +// which it is far easier to construct accurate approximations with small +// polynomials than a full range exp(x), and this is what the cubic and quartic +// versions below do. An important feature of these functions is that they +// constrain the solution to be exact at 0 and 1 so there is continuity at each +// integer boundary where we wrap from 1 to 0 and increment the power of 2. + +// Coefficients for quartic representation of 2^x - 1 for x on [0,1). +// The quartic representation is 2^x - 1 ~ x - x(1-x)(ax^2 + bx + c), hence the +// coefficients of a quadratic are all that is required. +// Coefficients came from numerical experiments. +constexpr float kExpQuarticFactor2 = 0.0135302434f; +constexpr float kExpQuarticFactor1 = 0.0656107542f; +constexpr float kExpQuarticFactor0 = 0.306963906f; +// Coefficients for cubic representation of 2^x - 1 for x on [0,1] +// The cubic representation is 2^x - 1 ~ x - x(1-x)(mx + c), hence the +// coefficients of a linear function are all that is required. +// Coefficients came from numerical experiments. +constexpr float kExpCubicFactor1 = 0.0780252018f; +constexpr float kExpCubicFactor0 = 0.304684167f; +// Coefficients are optimized to minimize the absolute error on +// tanh = (e^2x - 1) / (e^2x + 1) instead of on pure e^x. + +// Enum that determines how a transcendental is computed. +enum TranscendentalMode { + // Cubic using 16 bit integer arithmetic. + TM_ORDER3_16BIT, + // Quartic using 16 bit integer arithmetic. + TM_ORDER4_16BIT, + // Quartic using 32 bit float arithmetic. + TM_ORDER4_FLOAT, +}; + +inline int FloatAsInt16(float x) { + return static_cast(x * (1 << 15) + 0.5f); +} + +inline int FloatAsInt32(float x) { + return static_cast(x * (1 << 30) + 0.5f); +} + +#if defined __ARM_NEON || defined __aarch64__ + +constexpr int kMaxSigmoidInputInt = static_cast(kMaxSigmoidInput); + +// Computes and returns 2^(x>>23) ie 2^u where x = u << 23 bits. +// Uses the quartic floating point exponent trick, see COMMON DOCUMENTATION FOR +// THE FLOATING EXPONENT TRICK above for details. +// Returns the true value, ie not scaled. +inline float32x4_t float32_pow2(float32x4_t x) { + // The input is already shifted left by 23 bits, so when we convert to int, + // the bottom 23 bits are the fractional part, and the top bits are the + // integer part. We want to compute a function of the fractional part, so + // we will mask it off and manipulate it. + int32x4_t exp_int_x = vcvtq_s32_f32(x); + // Mask to allow conversion of just the fractional part of x to fixed16<0>. + int32x4_t mantissa_mask16 = vdupq_n_s32(0x7fff00); + // Mask to allow conversion of just the fractional part of x to fixed32<1>. + int32x4_t mantissa_mask32 = vdupq_n_s32(0x7fffff); + // Narrowing shift to convert to fixed16<0>. + int16x4_t x_16 = vshrn_n_s32(vandq_s32(mantissa_mask16, exp_int_x), 8); + // Shift to convert to fixed32<1>. + int32x4_t x_32 = vshlq_n_s32(vandq_s32(mantissa_mask32, exp_int_x), 7); + // Compute the polynomial x(x - 1)(ax^2 + bx + c) of the fractional part. + // Ordering these lines carefully makes it faster, as some of the multiply + // operations can pipeline instead of waiting for the previous result. + int32x4_t x_squared = vmull_s16(x_16, x_16); + int16x4_t b = vdup_n_s16(FloatAsInt16(kExpQuarticFactor1)); + int32x4_t c = vdupq_n_s32(FloatAsInt32(kExpQuarticFactor0)); + int32x4_t bx_plus_c = vmlal_s16(c, b, x_16); + int16x4_t a = vdup_n_s16(FloatAsInt16(kExpQuarticFactor2)); + // Finish the quadratic: result = ax^2 + bx + c. + int32x4_t result = vmlal_s16(bx_plus_c, a, vshrn_n_s32(x_squared, 15)); + int32x4_t x_squared_minus_x = vsubq_s32(x_squared, x_32); + + // Multiply by x^2 - x. + result = vqrdmulhq_s32(result, x_squared_minus_x); + // Shift back to mantissa position. vqrdmulhq_s32 took 2x 30-mantissa bit + // inputs, made 60-mantissa bit result, doubled it to 61 bits, then discarded + // the bottom 32 making 29, so shift right 6 to get 23. + result = vshrq_n_s32(result, 6); + // Add the constant to normalize the exponent for IEEE format. + int32x4_t exp_offset = vdupq_n_s32(kFloatExponentOffset); + exp_int_x = vaddq_s32(exp_int_x, exp_offset); + exp_int_x = vaddq_s32(exp_int_x, result); + // Cast back to float, as we just computed the exponent and mantissa and + // assembled them in IEEE format. + return vreinterpretq_f32_s32(exp_int_x); +} + +// Scaled float to float exp approximation, using a quartic refinement of +// the exponent trick. See COMMON DOCUMENTATION FOR THE FLOATING EXPONENT TRICK +// above for details. Input is a fixed32<31 - mantissa_bits> that has been +// converted to a float without any further shifting. MUST HAVE ALREADY BEEN +// CLIPPED to a suitable range for exp! +// Returns a vector of standard unscaled floats. +inline float32x4_t fixed32_exp_float_preclipped(const int mantissa_bits, + float32x4_t x) { + // Divide by log 2 to convert problem to 2^x, and scale to match the + // mantissa bits required by IEEE floats. + // This is the shift of the FP mantissa relative to the input mantissa. + const int kXShift = kFloatMantissaBits - mantissa_bits; + const float kLogFactor = static_cast(1 << kXShift); + float32x4_t factor = vdupq_n_f32(kLogFactor * kOneOverLog2); + float32x4_t y = vmulq_f32(x, factor); + // Now compute 2^x. + return float32_pow2(y); +} + +// uses trick that 2^x can be computed by shifting integer into the +// exponent, see the following reference for a derivation using double: +// goo.gl/aUVTK3 +// Input x is clamped to [-64, 64], even infinity and NaN. +// Accurate to within 3% relative across the entire range. +// Fully pipelined throughput is about 10 cycles per fast_exp call. +inline float32x4_t fast_exp(float32x4_t x) { +#if defined FAST_TRANSCENDENTALS && __ARM_ARCH >= 800 + // Uses vcvtnq_s32_f32, not available on ARM v7 NEON. + + // Load A and B, which are defined as integers into float registers. + float32x4_t A = vreinterpretq_f32_u32(vdupq_n_u32(kAConstant)); + float32x4_t res = vreinterpretq_f32_u32(vdupq_n_u32(kBConstant)); + + // Make sure x within the allowed range. + x = vminq_f32(x, vdupq_n_f32(kMaxExpInput)); + x = vmaxq_f32(x, vdupq_n_f32(kMinExpInput)); + + // res = A * x + B. + // This shifts x into the exponent field and adds the bias. + res = vmlaq_f32(res, A, x); + + // Convert back to an integer, this is what uses the floating point + // unit to compute 2^x. + int32x4_t x_int = vcvtnq_s32_f32(res); + + return vreinterpretq_f32_s32(x_int); +#else + float32x4_t return_val = vdupq_n_f32(0.f); + + float exponent = expf(vgetq_lane_f32(x, 0)); + return_val = vld1q_lane_f32(&exponent, return_val, 0); + + exponent = expf(vgetq_lane_f32(x, 1)); + return_val = vld1q_lane_f32(&exponent, return_val, 1); + exponent = expf(vgetq_lane_f32(x, 2)); + return_val = vld1q_lane_f32(&exponent, return_val, 2); + exponent = expf(vgetq_lane_f32(x, 3)); + return_val = vld1q_lane_f32(&exponent, return_val, 3); + + return return_val; +#endif // FAST_TRANSCENDENTALS +} + +// This version does a conversion of the input to floating point, then calls +// the floating point fast_exp function. There is another version +// fast_exp_fixed, that never does a conversion and is less accurate, but much +// faster. +template +inline float32x4_t fast_exp(int32x4_t x) { + return fast_exp(vcvtq_n_f32_s32(x, 31 - ExponentBits)); +} + +// Performs an exp estimate without doing any floating point operations. The +// result is a floating point number. See scalar version for an explanation. +template +inline float32x4_t fast_exp_fixed(int32x4_t x) { + static_assert(ExponentBits > 8, "Must have more than 8 ExponentBits"); + constexpr int kA = 1.4426950408889634 * (1 << (ExponentBits - 8)); + constexpr int kB = (127 << 23) - 366000; + + constexpr int maxInput = 80 << (31 - ExponentBits); + constexpr int minInput = -maxInput; + + int32x4_t A = vdupq_n_s32(kA); + int32x4_t res = vdupq_n_s32(kB); + + // Make sure x within the allowed range. + x = vminq_s32(x, vdupq_n_s32(maxInput)); + x = vmaxq_s32(x, vdupq_n_s32(minInput)); + + // res = A * x + B. + // This shifts x into the exponent field and adds the bias. + res = vmlaq_s32(res, A, x); + + return vreinterpretq_f32_s32(res); +} + +// fast_exp_norange_check uses vcvtnq_s32_f32, not available on ARM v7 NEON. +#if __ARM_ARCH >= 800 +namespace detail { +// tanh can do range check once. +// Input x is clamped to [-64, 64], even infinity and NaN. +inline float32x4_t fast_exp_norange_check(float32x4_t x) { + float32x4_t A = vreinterpretq_f32_u32(vdupq_n_u32(kAConstant)); + float32x4_t res = vreinterpretq_f32_u32(vdupq_n_u32(kBConstant)); + + res = vmlaq_f32(res, A, x); + + int32x4_t x_int = vcvtnq_s32_f32(res); + + return vreinterpretq_f32_s32(x_int); +} + +} // namespace detail +#endif // __ARM_ARCH >= 800 + +// Clips float input to [-kLimit,kLimit]. +inline float32x4_t ClipToFloatBounds(const float kLimit, const float32x4_t x) { + // Clip to the input bounds for this approximation. + float32x4_t clip_limit = vdupq_n_f32(kLimit); + float32x4_t clipped_x = vminq_f32(x, clip_limit); + clip_limit = vnegq_f32(clip_limit); + return vmaxq_f32(clipped_x, clip_limit); +} + +inline float32x4_t float_tanh_float(const float32x4_t& x) { + float32x4_t clipped_x = ClipToFloatBounds(kMaxTanhInput, x); + // Divide by log 2 to convert problem to 2^x, double (as we need exp(2x)) and + // scale to the mantissa bits required by float32_pow2 all in one multiply. + // Add one to double the input. + const float kLogFactor = static_cast(1 << (kFloatMantissaBits + 1)); + float32x4_t factor = vdupq_n_f32(kLogFactor * kOneOverLog2); + clipped_x = vmulq_f32(clipped_x, factor); + // Now compute 2^x. + float32x4_t exp_result = float32_pow2(clipped_x); + // Now compute tanh using (e^2x - 1) / (e^2x + 1). + float32x4_t one = vdupq_n_f32(1.0f); + float32x4_t numerator = vsubq_f32(exp_result, one); + float32x4_t denominator = vaddq_f32(exp_result, one); + float32x4_t recp = vrecpeq_f32(denominator); + // Newton-Raphson iteration, accuracy is important for audio quality + recp = vmulq_f32(recp, vrecpsq_f32(recp, denominator)); + recp = vmulq_f32(recp, numerator); + // Compute 3rd-order Taylor tanh ~ x - x^3/3 for high accuracy and thus low + // relative error close to 0. + float32x4_t third = vdupq_n_f32(1.0f / 3.0f); + float32x4_t taylor = vmulq_f32(x, x); + taylor = vmulq_f32(taylor, x); + taylor = vmulq_f32(taylor, third); + taylor = vsubq_f32(x, taylor); + // Test |x| <= 1/9, roughly where the errors cross over, without needing yet + // another constant. + float32x4_t ninth = vmulq_f32(third, third); + uint32x4_t cmp_results = vcaleq_f32(x, ninth); + return vbslq_f32(cmp_results, taylor, recp); +} + +// Calculates (exp(x) - exp(-x)) / (exp(x) + exp(-x)). +// Input x is clamped to [-9, 9], even infinity and NaN. +// See test program for bounds. Throughput of FAST is 334 Mega/sec, +// throughput of accurate is 232 Mega/sec. +inline float32x4_t fast_tanh(float32x4_t x) { +#if defined FASTER_TRANSCENDENTALS + return float_tanh_float(x); +#elif defined ACCURATE_TRANSCENDENTAL_APPROX && defined FAST_TRANSCENDENTALS + x = vminq_f32(x, vdupq_n_f32(kMaxTanhInput)); + x = vmaxq_f32(x, vdupq_n_f32(kMinTanhInput)); + + // The monomial coefficients of the numerator polynomial (odd). + const float32x4_t alpha_1 = vdupq_n_f32(kTanhAlpha1); + const float32x4_t alpha_3 = vdupq_n_f32(kTanhAlpha3); + const float32x4_t alpha_5 = vdupq_n_f32(kTanhAlpha5); + const float32x4_t alpha_7 = vdupq_n_f32(kTanhAlpha7); + const float32x4_t alpha_9 = vdupq_n_f32(kTanhAlpha9); + const float32x4_t alpha_11 = vdupq_n_f32(kTanhAlpha11); + const float32x4_t alpha_13 = vdupq_n_f32(kTanhAlpha13); + + // The monomial coefficients of the denominator polynomial (even). + const float32x4_t beta_0 = vdupq_n_f32(kTanhBeta0); + const float32x4_t beta_2 = vdupq_n_f32(kTanhBeta2); + const float32x4_t beta_4 = vdupq_n_f32(kTanhBeta4); + const float32x4_t beta_6 = vdupq_n_f32(kTanhBeta6); + + // Since the polynomials are odd/even, we need x^2. + const float32x4_t x2 = vmulq_f32(x, x); + + // Evaluate the numerator polynomial |p|. + float32x4_t p = vmlaq_f32(alpha_11, x2, alpha_13); + p = vmlaq_f32(alpha_9, x2, p); + p = vmlaq_f32(alpha_7, x2, p); + p = vmlaq_f32(alpha_5, x2, p); + p = vmlaq_f32(alpha_3, x2, p); + p = vmlaq_f32(alpha_1, x2, p); + p = vmulq_f32(x, p); + + // Evaluate the denominator polynomial p. + float32x4_t q = vmlaq_f32(beta_4, x2, beta_6); + q = vmlaq_f32(beta_2, x2, q); + q = vmlaq_f32(beta_0, x2, q); + + // Divide the numerator by the denominator. + float32x4_t recp = vrecpeq_f32(q); + recp = vmulq_f32(recp, vrecpsq_f32(recp, q)); + return vmulq_f32(p, recp); +#elif defined FAST_TRANSCENDENTALS && __ARM_ARCH >= 800 + // Uses vcvtnq_s32_f32, not available on ARM v7 NEON. + + x = vminq_f32(x, vdupq_n_f32(kMaxTanhInput)); + x = vmaxq_f32(x, vdupq_n_f32(kMinTanhInput)); + float32x4_t exp_est = detail::fast_exp_norange_check(x); + float32x4_t neg_exp_est = detail::fast_exp_norange_check(-x); + + // If we're in the linear region. + // caleq = compare absolute <= + uint32x4_t cmp_results = vcaleq_f32(x, vdupq_n_f32(kTanhLinearRegion)); + + float32x4_t diff = vsubq_f32(exp_est, neg_exp_est); + float32x4_t sum = vaddq_f32(exp_est, neg_exp_est); + float32x4_t recp = vrecpeq_f32(sum); + recp = vmulq_f32(recp, vrecpsq_f32(recp, sum)); + float32x4_t tanh_estimate = vmulq_f32(diff, recp); + + // Based on comparison, possibly copy x through instead of calculated value. + // TODO(b/191497441): Is the compiler generating VBIT or VBSL ? VBIT is one + // cycle and VBSL is two... documentation suggests it can do either. + return vbslq_f32(cmp_results, x, tanh_estimate); +#else + float32x4_t return_val = vdupq_n_f32(0.f); + + float tanh_value = tanhf(vgetq_lane_f32(x, 0)); + return_val = vld1q_lane_f32(&tanh_value, return_val, 0); + tanh_value = tanhf(vgetq_lane_f32(x, 1)); + return_val = vld1q_lane_f32(&tanh_value, return_val, 1); + tanh_value = tanhf(vgetq_lane_f32(x, 2)); + return_val = vld1q_lane_f32(&tanh_value, return_val, 2); + tanh_value = tanhf(vgetq_lane_f32(x, 3)); + return_val = vld1q_lane_f32(&tanh_value, return_val, 3); + + return return_val; +#endif // FAST_TRANSCENDENTALS +} + +// Input x is clamped to [-18, 18], even infinity and NaN. +// See tests for error bounds. Using SIGMOID_AS_TANH with +// ACCURATE_TRANSCENDENTAL_APPROX is both faster and more accurate. Using +// SIGMOID_AS_TANH with just FAST is slower, but more accurate. +// SIGMOID_AS_TANH, ACCURATE is 205 Mega/sec +// SIGMOID_AS_TANH, FAST is 290 Mega/sec +// FAST is 340 Mega/sec +inline float32x4_t fast_sigmoid(float32x4_t x) { +#ifdef SIGMOID_AS_TANH + float32x4_t half = vdupq_n_f32(0.5f); + return vmlaq_f32(half, half, fast_tanh(vmulq_f32(half, x))); +#else // SIGMOID_AS_TANH +#if defined FAST_TRANSCENDENTALS && defined ACCURATE_TRANSCENDENTAL_APPROX + x = vminq_f32(x, vdupq_n_f32(kMaxSigmoidInput)); + x = vmaxq_f32(x, vdupq_n_f32(kMinSigmoidInput)); + + // The monomial coefficients of the numerator polynomial (odd). + const float32x4_t alpha_1 = vdupq_n_f32(kSigmoidAlpha1); + const float32x4_t alpha_3 = vdupq_n_f32(kSigmoidAlpha3); + const float32x4_t alpha_5 = vdupq_n_f32(kSigmoidAlpha5); + const float32x4_t alpha_7 = vdupq_n_f32(kSigmoidAlpha7); + const float32x4_t alpha_9 = vdupq_n_f32(kSigmoidAlpha9); + + // The monomial coefficients of the denominator polynomial (even). + const float32x4_t beta_0 = vdupq_n_f32(kSigmoidBeta0); + const float32x4_t beta_2 = vdupq_n_f32(kSigmoidBeta2); + const float32x4_t beta_4 = vdupq_n_f32(kSigmoidBeta4); + const float32x4_t beta_6 = vdupq_n_f32(kSigmoidBeta6); + const float32x4_t beta_8 = vdupq_n_f32(kSigmoidBeta8); + const float32x4_t beta_10 = vdupq_n_f32(kSigmoidBeta10); + + // Since the polynomials are odd/even, we need x^2. + const float32x4_t x2 = vmulq_f32(x, x); + + // Evaluate the numerator polynomial p. + float32x4_t p = vmlaq_f32(alpha_7, x2, alpha_9); + p = vmlaq_f32(alpha_5, x2, p); + p = vmlaq_f32(alpha_3, x2, p); + p = vmlaq_f32(alpha_1, x2, p); + p = vmulq_f32(x, p); + + // Evaluate the denominator polynomial p. + float32x4_t q = vmlaq_f32(beta_8, x2, beta_10); + q = vmlaq_f32(beta_6, x2, q); + q = vmlaq_f32(beta_4, x2, q); + q = vmlaq_f32(beta_2, x2, q); + q = vmlaq_f32(beta_0, x2, q); + + // Divide the numerator by the denominator. + float32x4_t recp = vrecpeq_f32(q); + recp = vmulq_f32(recp, vrecpsq_f32(recp, q)); + return vmlaq_f32(vdupq_n_f32(0.5f), p, recp); +#elif defined FAST_TRANSCENDENTALS + float32x4_t denom = vaddq_f32(fast_exp(vnegq_f32(x)), vdupq_n_f32(1.f)); + + float32x4_t recp = vrecpeq_f32(denom); + // Newton-Raphson iteration, accuracy is important for audio quality. + recp = vmulq_f32(recp, vrecpsq_f32(recp, denom)); + float32x4_t half = vdupq_n_f32(0.5f); + float32x4_t quarter = vdupq_n_f32(0.245f); + float32x4_t linear_approx = vmlaq_f32(half, quarter, x); + uint32x4_t cmp_results = vcaleq_f32(x, vdupq_n_f32(kSigmoidLinearRegion)); + + return vbslq_f32(cmp_results, linear_approx, recp); +#else + float32x4_t return_val = vdupq_n_f32(0.f); + + float result = 1.f / (1.f + expf(-vgetq_lane_f32(x, 0))); + return_val = vld1q_lane_f32(&result, return_val, 0); + result = 1.f / (1.f + expf(-vgetq_lane_f32(x, 1))); + return_val = vld1q_lane_f32(&result, return_val, 1); + result = 1.f / (1.f + expf(-vgetq_lane_f32(x, 2))); + return_val = vld1q_lane_f32(&result, return_val, 2); + result = 1.f / (1.f + expf(-vgetq_lane_f32(x, 3))); + return_val = vld1q_lane_f32(&result, return_val, 3); + + return return_val; +#endif // FAST_TRANSCENDENTALS +#endif // SIGMOID_AS_TANH +} + +// Scalar implementations, mainly useful for testing. +inline float fast_exp(float x) { + return vgetq_lane_f32(fast_exp(vdupq_n_f32(x)), 0); +} + +template +inline float fast_exp(fixed32 x) { + return vgetq_lane_f32(fast_exp(vdupq_n_s32(x.raw_val())), 0); +} + +// Returns the exponent of a fixed point number in floating point without ever +// doing any conversions. Less accurate than the version that does conversions, +// but still accurate to within 4% relative for x < 16. +template +inline float fast_exp_fixed(fixed32 x) { + return vgetq_lane_f32(fast_exp_fixed(vdupq_n_s32(x.raw_val())), + 0); +} + +inline float fast_sigmoid(float x) { + return vgetq_lane_f32(fast_sigmoid(vdupq_n_f32(x)), 0); +} + +inline float fast_tanh(float x) { + return vgetq_lane_f32(fast_tanh(vdupq_n_f32(x)), 0); +} + +// Clips integer input to [-|kLimit|, |kLimit|]. +// Input: register containins 4x fixed32 with mantissa_bits. +// Output: register containing 4x fixed32 limited to +// [-|kLimit| << |mantissa_bits|, |kLimit| << |mantissa_bits|]. +template +inline int32x4_t ClipToBounds(const int mantissa_bits, const int32x4_t x) { + // Clip to the input bounds for this approximation. + int32x4_t clip_limit = vdupq_n_s32(-(kLimit << mantissa_bits)); + int32x4_t clipped_x = vmaxq_s32(x, clip_limit); + clip_limit = vnegq_s32(clip_limit); + return vminq_s32(clipped_x, clip_limit); +} + +// Fixed32 sigmoid approximation via a quadratic refinement of the exponent +// trick. +// Input: Register containing 4x fixed32 with |mantissa_bits|. +// Output: Register containing 4x float results. +inline float32x4_t fixed32_sigmoid_float(const int mantissa_bits, + const int32x4_t x) { + int32x4_t input = vnegq_s32(x); + float32x4_t y = + vcvtq_f32_s32(ClipToBounds(mantissa_bits, input)); + y = fixed32_exp_float_preclipped(mantissa_bits, y); + float32x4_t one = vdupq_n_f32(1.0f); + // Approximate reciprocal is not accurate enough - use full division. + float32x4_t denom = vaddq_f32(y, one); + float32x4_t recp = vrecpeq_f32(denom); + // Newton-Raphson iteration, accuracy is important for audio quality + recp = vmulq_f32(recp, vrecpsq_f32(recp, denom)); + return recp; +} + +template +inline float32x4_t fast_sigmoid(int32x4_t x) { +#if defined FASTER_TRANSCENDENTALS + // Computation will fail to produce the right result if the input mantissa + // bits exceeds the number in a float. + static_assert(kFloatMantissaBits >= fixed32::kMantissaBits, + "Mantissa bits must be at most 23!"); + return fixed32_sigmoid_float(fixed32::kMantissaBits, x); +#else + return fast_sigmoid(vcvtq_n_f32_s32(x, fixed32::kMantissaBits)); +#endif // FASTER_TRANSCENDENTALS +} + +template +inline float fast_sigmoid(fixed32 x) { + return vgetq_lane_f32(fast_sigmoid(vdupq_n_s32(x.raw_val())), + 0); +} + +#else // defined __ARM_NEON || defined __aarch64__ + +inline float fast_exp(float x) { +#ifdef FAST_TRANSCENDENTALS + if (isnan(x)) return 0.0f; + x = std::max(std::min(x, kMaxExpInput), kMinExpInput); + float AConstant, BConstant; + memcpy(&AConstant, &kAConstant, sizeof(int)); + memcpy(&BConstant, &kBConstant, sizeof(int)); + float y = x * AConstant + BConstant; + int x_int = static_cast(y); + float ret; + memcpy(&ret, &x_int, sizeof(float)); + return ret; +#else + return expf(x); +#endif // FAST_TRANSCENDENTALS +} + +template +inline float fast_exp(fixed32 x) { + return fast_exp(static_cast(x)); +} + +template +inline float fast_exp_fixed(fixed32 x) { + static_assert(ExponentBits > 8, "Must have more than 8 ExponentBits"); + int matched_decimal = + std::max(std::min(x.raw_val(), (80 << (31 - ExponentBits))), + -(80 << (31 - ExponentBits))); + // Convert 1 / log(2) to 16-bit fixed point with 1 exponent bit + // (1 / log(2)) * (1 << 14), but then right shift by the appropriate amount to + // line the decimal point up with the 32-bit float representation. + // (MantissaBits of x) + (MantissaBits of constant) = 23 + // 23 - (MantissaBits of x) = MantissaBits of constant + // 23 - (31 - ExponentBits of x) = ... + // (ExponentBits of x - 8) = MantissaBits of constant + const int16_t A = (1.f / logf(2.f)) * (1 << (ExponentBits - 8)); + // Same rationale as for floating point versions, bias exponent, subtract + // 366000 to reduce error by centering approximation, instead of being + // one-sided. + const int B = (127 << 23) - 366000; + matched_decimal = A * matched_decimal + B; + float ret_val; + memcpy(&ret_val, &matched_decimal, sizeof(float)); + return ret_val; +} + +inline float fast_tanh(float x) { +#if defined FAST_TRANSCENDENTALS && defined ACCURATE_TRANSCENDENTAL_APPROX + // Doesn't do anything fancy, just a 13/6-degree rational interpolant which + // is accurate up to a couple of ulp in the range [-9, 9], outside of which + // fl(tanh(x)) = +/-1. + x = std::max(std::min(x, kMaxTanhInput), kMinTanhInput); + + // Since the polynomials are odd/even, we need x^2. + float x2 = x * x; + + // Evaluate numerator. + float p = kTanhAlpha11 + x2 * kTanhAlpha13; + p = kTanhAlpha9 + x2 * p; + p = kTanhAlpha7 + x2 * p; + p = kTanhAlpha5 + x2 * p; + p = kTanhAlpha3 + x2 * p; + p = kTanhAlpha1 + x2 * p; + p = x * p; + + // Evaluate denominator. + float q = kTanhBeta4 + x2 * kTanhBeta6; + q = kTanhBeta2 + x2 * q; + q = kTanhBeta0 + x2 * q; + + return p / q; +#elif defined FAST_TRANSCENDENTALS + if (std::abs(x) < kTanhLinearRegion) { + return x; + } else { + x = std::max(std::min(x, kMaxTanhInput), kMinTanhInput); + float positive = fast_exp(x); + float negative = fast_exp(-x); + return (positive - negative) / (positive + negative); + } +#else + return tanhf(x); +#endif // FAST_TRANSCENDENTALS +} + +inline float fast_sigmoid(float x) { +#ifdef SIGMOID_AS_TANH + return .5f * fast_tanh(.5f * x) + .5f; +#else +#if defined FAST_TRANSCENDENTALS && defined ACCURATE_TRANSCENDENTAL_APPROX + // Doesn't do anything fancy, just a 9/10-degree rational interpolant which + // interpolates 1/(1+exp(-x)) - 0.5 up to a couple of ulp in the range + // [-18, 18], outside of which the fl(sigmoid(x)) = {0|1}. The shifted + // sigmoid is interpolated because it was easier to make the fit converge. + // See GenericPacketMath.h* in the open source Eigen library. + x = std::max(std::min(x, kMaxSigmoidInput), kMinSigmoidInput); + + // Since the polynomials are odd/even, we need x^2. + float x2 = x * x; + + // Evaluate numerator. + float p = kSigmoidAlpha7 + x2 * kSigmoidAlpha9; + p = kSigmoidAlpha5 + x2 * p; + p = kSigmoidAlpha3 + x2 * p; + p = kSigmoidAlpha1 + x2 * p; + p = x * p; + + // Evaluate denominator. + float q = kSigmoidBeta8 + x2 * kSigmoidBeta10; + q = kSigmoidBeta6 + x2 * q; + q = kSigmoidBeta4 + x2 * q; + q = kSigmoidBeta2 + x2 * q; + q = kSigmoidBeta0 + x2 * q; + + return p / q + 0.5f; +#elif defined FAST_TRANSCENDENTALS + if (std::abs(x) < kSigmoidLinearRegion) { + return .245 * x + .5; + } else { + return 1.f / (1.f + fast_exp(-x)); + } +#else + return 1.f / (1.f + expf(-x)); +#endif // FAST_TRANSCENDENTALS +#endif // SIGMOID_AS_TANH +} + +template +inline float fast_sigmoid(fixed32 x) { + return fast_sigmoid(static_cast(x)); +} + +#endif // defined __aarch64__ + +// Number of exponent bits to use for tanh. +static constexpr int kNumTanhExpBits = 3; +// Number of exponent bits to use for sigmoid. +static constexpr int kNumSigmoidExpBits = 4; +// Number of extra bits to shift sigmoid, due to its low gradient. +static constexpr int kNumExtraSigmoidShiftBits = 1; + +// Returns (and builds if not done yet) a static data table (that is never +// deleted, as per the style guide) that implements tanh on fixed32 input, +// returning another fixed32 with the given number of mantissa bits (which is +// assumed to be less than the input mantissa bits). +// NOTE that this function is intended to be used only with fixed16 outputs that +// are sign-extended to 32 bits for convenience, and will return a nullptr +// if asked for more than |kMaxMantissaBits| of precision in the output table. +const int* TanhTable(int num_mantissa_bits_out); +// As TanhTable, but for Sigmoid. +const int* SigmoidTable(int num_mantissa_bits_out); + +// Scalar/generic function to compute and return the fast approximation to exp +// via a polynomial refinement of the floating point exponent trick. +// TM_ORDER4_16BIT:Max relative error < 5e-6, absolute error < 1e-5 for x < 1. +// TM_ORDER3_16BIT:Max relative error < 1.1e-4, absolute error < 3e-4 for x +// < 1. +template +float fixed32_exp(fixed32 x) { + constexpr int kMantissaBits = MantissaBitsOf>::value; + // Clip x to min/max exp input to avoid infinities. + int64_t clipped_x = + std::max(std::min(x.raw_val(), kMaxExpInputInt << kMantissaBits), + -(kMaxExpInputInt << kMantissaBits)); + // First convert problem from e^x to 2^x by multiplying by 1/log(2). + // To maximize precision, log_factor is shifted left the maximum amount to + // keep within int32, and we shift x left a further amount such that the + // binary point of the product sits in the correct place in the top 32 bits of + // the result to be used directly as a float. We can't do that directly, as x + // would overflow, so we have to shift by 1 bit less and shift the result by + // 1 bit less to match. + constexpr int kXShift = + kFloatMantissaBits + 31 - kMaxLog2Shift - kMantissaBits; + static_assert(kXShift >= 0, + "Mantissa bits > kFloatMantissaBits + 31 - kMaxLog2Shift"); + clipped_x <<= kXShift; + int float_as_int = (kLogFactor * clipped_x >> 31) + kFloatExponentOffset; + // Separate the resulting fixed-point into integer and fractional parts. + int int_part = float_as_int & kFloatExponentMask; + int float_part = float_as_int & kFloatMantissaMask; + float fraction = static_cast(float_part) / (1 << kFloatMantissaBits); + // Compute the mantissa = 2^fraction using: + // fraction - fraction*(1-fraction)*(polynomial of fraction) + // This guarantees exactness at 0 and 1, providing continuity of the error at + // integer boundaries. + float mantissa; + if (kOrder == TM_ORDER4_16BIT || kOrder == TM_ORDER4_FLOAT) { + mantissa = (kExpQuarticFactor2 * fraction + kExpQuarticFactor1) * fraction + + kExpQuarticFactor0; + } else if (kOrder == TM_ORDER3_16BIT) { + mantissa = kExpCubicFactor1 * fraction + kExpCubicFactor0; + } + mantissa = fraction - fraction * (1.0f - fraction) * mantissa; + // Since the function above guarantees to stay within [0, 1), we could do all + // the above in fixed point if necessary, in which case, we can just stuff + // the bottom kFloatMantissaBits in with the exponent and we are done. + // In the floating point world, it is simpler to just multiply them together. + float result; + memcpy(&result, &int_part, sizeof(float)); + return result * (1.0f + mantissa); +} + +// Computes and returns tanh(x) fixed32->float using a polynomial refinement of +// the floating point exponent trick. +// kOrder=4: Absolute error < 1.8e-6. Relative error < 1.2e-4 for |x| > 0.01. +// kOrder=3: Absolute error < 6e-5. Relative error < 3e-3 for |x| > 0.01 +template +float fixed32_tanh(fixed32 x) { + float float_x = static_cast(x); + if (std::abs(float_x) < 1.0f / 9.0f) { + return float_x * (1 - float_x * float_x / 3.0f); + } + x = static_cast>(x.raw_val() * 2); + float exp_2x = fixed32_exp(x); + return (exp_2x - 1.0f) / (exp_2x + 1.0f); +} + +// Computes and returns sigmoid(x) fixed32->float using a polynomial refinement +// of the floating point exponent trick. +// TM_ORDER4_16BIT: Absolute error < 9e-7, relative < 4e-6. +// TM_ORDER3_16BIT: Absolute error < 3e-5, relative < 1.1e-4. +template +float fixed32_sigmoid(fixed32 x) { + x = static_cast>(-x.raw_val()); + float exp_x = fixed32_exp(x); + return 1.0f / (exp_x + 1.0f); +} + +#if defined __AVX2__ + +// Inline function to access an int32 data table by shifting |x| right by +// |kNumShiftBits|, and adding |kTableOffset| to the result. |x| contains 8 +// indices and 8 results are returned. The data table is of size +// |kTableOffset| * 2 + 1. +template +inline __m256i index_data_table(const int32_t* data_table, const __m256i& x) { + // Shift right with rounding to match input and output precision. + __m256i shifted = _mm256_set1_epi32(1 << (kNumShiftBits - 1)); + shifted = _mm256_add_epi32(x, shifted); + shifted = _mm256_srai_epi32(shifted, kNumShiftBits); + // Add the offset. + __m256i addend = _mm256_set1_epi32(kTableOffset); + shifted = _mm256_add_epi32(shifted, addend); + // And clamp to the indices of the LUT. + addend = _mm256_add_epi32(addend, addend); + shifted = _mm256_min_epi32(shifted, addend); + shifted = _mm256_max_epi32(shifted, _mm256_setzero_si256()); + // Lookup the results in the table. + return _mm256_i32gather_epi32(data_table, shifted, 4); +} + +// Fixed32 to fixed16-in-an-int32 tanh LUT function. +// Input: register containins 8x fixed32 with |NumInputMantissaBits|. +// Output: a register containing 8x fixed16 with |NumOutputMantissaBits|, but +// note that they are sign-extended to 32 bits and are therefore basically the +// same as fixed32 with |NumOutputMantissaBits|. +template +inline __m256i fixed32_tanh_fixed16(const int* tanh_table, const __m256i& x) { + // Lose the unnecessary input precision. + constexpr int kNumShiftBits = NumInputMantissaBits - NumOutputMantissaBits; + constexpr int kTableOffset = 1 << (NumOutputMantissaBits + kNumTanhExpBits); + return index_data_table(tanh_table, x); +} + +// Fixed32 to fixed16-in-an-int32 sigmoid LUT function. +// Input: register containins 8x fixed32 with |NumInputMantissaBits|. +// Output: a register containing 8x fixed16 with |NumOutputMantissaBits|, but +// note that they are sign-extended to 32 bits and are therefore basically the +// same as fixed32 with |NumOutputMantissaBits|. +template +inline __m256i fixed32_sigmoid_fixed16(const int* sigmoid_table, + const __m256i& x) { + // Lose the unnecessary input precision. + constexpr int kNumShiftBits = + kNumExtraSigmoidShiftBits + NumInputMantissaBits - NumOutputMantissaBits; + constexpr int kTableOffset = 1 + << (NumOutputMantissaBits + kNumSigmoidExpBits - + kNumExtraSigmoidShiftBits); + return index_data_table(sigmoid_table, x); +} + +// Convert 2x registers of 8x float32 into 1 register of 16x16 bit fixed int, +// assuming that the floats are already scaled up. +inline __m256i PackFloatsToFixed16(const __m256& x0, const __m256& x1) { + __m256i int0 = _mm256_cvtps_epi32(x0); + __m256i int1 = _mm256_cvtps_epi32(x1); + int0 = _mm256_packs_epi32(int0, int1); + // Swap the middle 64 bit elements so the results are in the right order. + return _mm256_permute4x64_epi64(int0, 0xd8); +} + +// Clips integer input to [-|kLimit|, |kLimit|]. +// Input: register containins 8x fixed32 with |mantissa_bits|. +// Output: register containing 8x fixed32 limited to +// [-|kLimit| << |mantissa_bits|, |kLimit| << |mantissa_bits|]. +template +inline __m256i ClipToBounds(const int mantissa_bits, const __m256i& x) { + // Clip to the input bounds for this approximation. + __m256i clip_limit = _mm256_set1_epi32(-(kLimit << mantissa_bits)); + __m256i clipped_x = _mm256_max_epi32(x, clip_limit); + // This quickly negates the limit without having to load another constant. + clip_limit = _mm256_sign_epi32(clip_limit, clip_limit); + return _mm256_min_epi32(clipped_x, clip_limit); +} + +// Clips float input to [-|kLimit|, |kLimit|]. +// Input: register containins 8x float. +// Output: register containing 8x float limited to [-|kLimit|, |kLimit|]. +inline __m256 ClipToFloatBounds(const float kLimit, const __m256& x) { + __m256 clip_limit = _mm256_set1_ps(kLimit); + __m256 clipped_x = _mm256_min_ps(x, clip_limit); + clip_limit = _mm256_set1_ps(-kLimit); + return _mm256_max_ps(clipped_x, clip_limit); +} + +// Float to float power of 2 approximation, using a quartic refinement of +// the exponent trick. For TM_ORDER4_16BIT and TM_ORDER3_16BIT, implementation +// is entirely in integer, using 16x16=16 multiplication, using AVX2, which +// enables 16 elements to be computed in parallel, hence the double register +// input/output args. +// The price paid for this speed is an increase in error over the (scalar) int32 +// example implementations above by a variable factor of 4-10. +// For the TM_ORDER4_FLOAT case, the computation is all done in float, solving +// this lower precision problem. +// NOTE: The input must have already been clipped to prevent overflow, which +// sets the practical limit to +/-126 << kFloatMantissaBits. +// NOTE: The input is a scaled float, as if converted raw from int, and the +// scale factor is fixed at kFloatMantissaBits! +// Input: 2x register containining 8x float * 1 << kFloatMantissaBits. +// Output: 2x register containing 8x float. +// TM_ORDER4_FLOAT: Max relative error < 8e-6, absolute error < 9e-6 for x < 1. +// TM_ORDER4_16BIT: Max relative error < 3e-5, absolute error < 6e-5 for x < 1. +// TM_ORDER3_16BIT: Max relative error < 6e-4, absolute error < 2e-3 for x < 1. +template +inline void float32_pow2(__m256& x0, __m256& x1) { + // Convert straight to int. + __m256i exp_int_x0 = _mm256_cvtps_epi32(x0); + __m256i exp_int_x1 = _mm256_cvtps_epi32(x1); + __m256i result_x0, result_x1; + + static_assert(kOrder == TM_ORDER4_FLOAT || kOrder == TM_ORDER4_16BIT || + kOrder == TM_ORDER3_16BIT, + "Invalid order."); + + if (kOrder == TM_ORDER4_FLOAT) { + __m256i mantissa_mask = _mm256_set1_epi32(0x7fffff); + __m256 float_factor = + _mm256_set1_ps(1.0f / static_cast(1 << kFloatMantissaBits)); + __m256i fract0 = _mm256_and_si256(mantissa_mask, exp_int_x0); + __m256i fract1 = _mm256_and_si256(mantissa_mask, exp_int_x1); + __m256 float0 = _mm256_mul_ps(_mm256_cvtepi32_ps(fract0), float_factor); + __m256 float1 = _mm256_mul_ps(_mm256_cvtepi32_ps(fract1), float_factor); + // Compute the polynomial of the fractional part. + // Ordering these lines carefully makes it faster, as some of the multiply + // operations can pipeline instead of waiting for the previous result. + __m256 x_squared0 = _mm256_mul_ps(float0, float0); + __m256 x_squared1 = _mm256_mul_ps(float1, float1); + __m256 b = _mm256_set1_ps(kExpQuarticFactor1); + __m256 b_x0 = _mm256_mul_ps(b, float0); + __m256 b_x1 = _mm256_mul_ps(b, float1); + __m256 a = _mm256_set1_ps(kExpQuarticFactor2); + __m256 a_x_squared0 = _mm256_mul_ps(a, x_squared0); + __m256 a_x_squared1 = _mm256_mul_ps(a, x_squared1); + __m256 x_squared_minus_x0 = _mm256_sub_ps(x_squared0, float0); + __m256 x_squared_minus_x1 = _mm256_sub_ps(x_squared1, float1); + __m256 c = _mm256_set1_ps(kExpQuarticFactor0); + b_x0 = _mm256_add_ps(b_x0, c); + b_x1 = _mm256_add_ps(b_x1, c); + float_factor = _mm256_set1_ps(static_cast(1 << kFloatMantissaBits)); + a_x_squared0 = _mm256_add_ps(a_x_squared0, b_x0); + a_x_squared1 = _mm256_add_ps(a_x_squared1, b_x1); + a_x_squared0 = _mm256_mul_ps(a_x_squared0, x_squared_minus_x0); + a_x_squared1 = _mm256_mul_ps(a_x_squared1, x_squared_minus_x1); + result_x0 = _mm256_cvtps_epi32(_mm256_mul_ps(a_x_squared0, float_factor)); + result_x1 = _mm256_cvtps_epi32(_mm256_mul_ps(a_x_squared1, float_factor)); + } else { + // Combine the fractional part of both inputs into a single register. + // The representation is fixed16<0>, ie 15 mantissa bits. + __m256i mantissa_mask = _mm256_set1_epi32(0x7fff00); + __m256i x_01 = + _mm256_srli_epi32(_mm256_and_si256(mantissa_mask, exp_int_x0), 8); + x_01 = _mm256_or_si256( + x_01, + _mm256_slli_epi32(_mm256_and_si256(mantissa_mask, exp_int_x1), 8)); + // Compute the polynomial of the fractional part. + // Ordering these lines carefully makes it faster, as some of the multiply + // operations can pipeline instead of waiting for the previous result. + __m256i x_squared = _mm256_mulhrs_epi16(x_01, x_01); + __m256i result, x_squared_minus_x; + if (kOrder == TM_ORDER4_16BIT) { + __m256i b = _mm256_set1_epi16(FloatAsInt16(kExpQuarticFactor1)); + __m256i b_x = _mm256_mulhrs_epi16(b, x_01); + __m256i a = _mm256_set1_epi16(FloatAsInt16(kExpQuarticFactor2)); + __m256i a_x_squared = _mm256_mulhrs_epi16(a, x_squared); + x_squared_minus_x = _mm256_sub_epi16(x_squared, x_01); + // LOG(INFO) << "x_squared_minus_x=" << + // static_cast(_mm256_extract_epi16(x_squared_minus_x, 0)) / + // 32768.0f; + __m256i c = _mm256_set1_epi16(FloatAsInt16(kExpQuarticFactor0)); + b_x = _mm256_add_epi16(b_x, c); + // LOG(INFO) << "bx+c=" << static_cast(_mm256_extract_epi16(b_x, + // 0)) / 32768.0f; + result = _mm256_add_epi16(a_x_squared, b_x); + } else { // kOrder = TM_ORDER3_16BIT + __m256i a = _mm256_set1_epi16(FloatAsInt16(kExpCubicFactor1)); + __m256i b = _mm256_set1_epi16(FloatAsInt16(kExpQuarticFactor0)); + __m256i a_x = _mm256_mulhrs_epi16(a, x_01); + x_squared_minus_x = _mm256_sub_epi16(x_squared, x_01); + result = _mm256_add_epi16(a_x, b); + } + result = _mm256_mulhrs_epi16(result, x_squared_minus_x); + // Extract 16x16-bit results back to the separate sets of 8x32. + result_x0 = _mm256_slli_epi32(result, 16); + result_x0 = _mm256_srai_epi32(result_x0, 8); + result_x1 = _mm256_srai_epi32(result, 16); + result_x1 = _mm256_slli_epi32(result_x1, 8); + } + // Add the constant to normalize the exponent. + __m256i exp_offset = _mm256_set1_epi32(kFloatExponentOffset); + exp_int_x0 = _mm256_add_epi32(exp_int_x0, exp_offset); + exp_int_x0 = _mm256_add_epi32(exp_int_x0, result_x0); + exp_int_x1 = _mm256_add_epi32(exp_int_x1, exp_offset); + exp_int_x1 = _mm256_add_epi32(exp_int_x1, result_x1); + // Cast back to float, as we just computed the exponent and mantissa and + // assembled them in IEEE format. + x0 = _mm256_castsi256_ps(exp_int_x0); + x1 = _mm256_castsi256_ps(exp_int_x1); +} + +// Fixed32 to to float exp approximation, using a quartic/cubic refinement of +// the exponent trick. Implementation is entirely in integer, using 16x16=16 +// multiplication, using AVX2, which enables 16 elements to be computed in +// parallel, hence the double register input/output args. +// The price paid for this speed is an increase in error over the (scalar) int32 +// example implementations above by a variable factor of 4-10. +// The TM_ORDER4_FLOAT version uses floats and improves the precision. +// Input: 2x registers containins 8x fixed32 with kMantissaBits. +// Output: 2x registers containing 8x float32. +// TM_ORDER4_FLOAT: Max relative error < 8e-6, absolute error < 9e-6 for x < 1. +// TM_ORDER4_16BIT: Max relative error < 3e-5, absolute error < 6e-5 for x < 1. +// TM_ORDER3_16BIT: Max relative error < 6e-4, absolute error < 2e-3 for x < 1. +template +inline void float_exp_float_preclipped(__m256& y0, __m256& y1) { + // Divide by log 2 to convert problem to 2^x, and scale to match the + // mantissa bits required by IEEE floats. Without a _mm256_mulhrs_epi32, it is + // much easier to do this in float, even with the double conversion, as 16 bit + // is not precise enough here. + // This is the shift of the FP mantissa relative to the input mantissa. + constexpr int kXShift = kFloatMantissaBits - kInputMantissaBits; + constexpr float kLogFactor = static_cast(1 << kXShift); + __m256 factor = _mm256_set1_ps(kLogFactor * kOneOverLog2); + y0 = _mm256_mul_ps(y0, factor); + y1 = _mm256_mul_ps(y1, factor); + // Now compute 2^x. + float32_pow2(y0, y1); +} +template +inline void fixed32_exp_float(const __m256i& x0, const __m256i& x1, __m256& y0, + __m256& y1) { + // Clip to acceptable bounds to prevent overflow, and convert to float. + y0 = + _mm256_cvtepi32_ps(ClipToBounds(kInputMantissaBits, x0)); + y1 = + _mm256_cvtepi32_ps(ClipToBounds(kInputMantissaBits, x1)); + float_exp_float_preclipped(y0, y1); +} + +// Float->float tanh approximation via the exponent trick. +// Note that the input is scaled floats, as if converted raw from fixed16/32. +// Input: 2x registers containing 8x float scaled by input_mantissa_bits. +// Output: two registers containing 8x float. +// TM_ORDER4_FLOAT: Max relative error < 2.1e-5, absolute error < 2.3e-6. +// TM_ORDER4_16BIT: Max relative error < 1e-4, absolute error < 1.3e-5. +// TM_ORDER3_16BIT: Max relative error < 2.1e-3, absolute error < 3e-4. +template +inline void float_tanh_float(const __m256& x0, const __m256& x1, __m256& y0, + __m256& y1) { + // Divide by log 2 to convert problem to 2^x, double (as we need exp(2x)) and + // scale to the mantissa bits required by float32_pow2 all in one multiply. + // This is the shift of the FP mantissa relative to the input mantissa. + // Add one to double the input. + const float kLogFactor = + static_cast(1 << (kFloatMantissaBits - kInputMantissaBits + 1)); + __m256 factor = _mm256_set1_ps(kLogFactor * kOneOverLog2); + // Clip to suitable input bounds for tanh. + __m256 clip_limit = _mm256_set1_ps(kMaxTanhInput * (1 << kInputMantissaBits)); + __m256 clip0 = _mm256_min_ps(x0, clip_limit); + __m256 clip1 = _mm256_min_ps(x1, clip_limit); + clip_limit = _mm256_set1_ps(-kMaxTanhInput * (1 << kInputMantissaBits)); + clip0 = _mm256_max_ps(clip0, clip_limit); + clip1 = _mm256_max_ps(clip1, clip_limit); + __m256 exp0 = _mm256_mul_ps(clip0, factor); + __m256 exp1 = _mm256_mul_ps(clip1, factor); + // Now compute 2^x. + float32_pow2(exp0, exp1); + // Now compute tanh using (e^2x - 1) / (e^2x + 1). + __m256 one = _mm256_set1_ps(1.0f); + __m256 numerator = _mm256_sub_ps(exp0, one); + __m256 denominator = _mm256_add_ps(exp0, one); + // Approximate reciprocal is not accurate enough - use full division. + exp0 = _mm256_div_ps(numerator, denominator); + numerator = _mm256_sub_ps(exp1, one); + denominator = _mm256_add_ps(exp1, one); + exp1 = _mm256_div_ps(numerator, denominator); + // Compute 3rd-order Taylor tanh ~ x - x^3/3 for high accuracy and thus low + // relative error close to 0. + // Normalize the inputs back to proper floats. + factor = _mm256_set1_ps(1.0f / (1 << kInputMantissaBits)); + clip0 = _mm256_mul_ps(clip0, factor); + clip1 = _mm256_mul_ps(clip1, factor); + __m256 third = _mm256_set1_ps(-1.0f / 3.0f); + __m256 taylor0 = _mm256_mul_ps(clip0, clip0); + __m256 taylor1 = _mm256_mul_ps(clip1, clip1); + taylor0 = _mm256_mul_ps(taylor0, clip0); + taylor1 = _mm256_mul_ps(taylor1, clip1); + // TODO(b/191497441): The next two pairs of instructions could be combined to + // _mm256_fmadd_ps, but requires -mfma compilation option, eg: + // taylor0 = _mm256_fmadd_ps(taylor0, third, clip0); + taylor0 = _mm256_mul_ps(taylor0, third); + taylor1 = _mm256_mul_ps(taylor1, third); + taylor0 = _mm256_add_ps(clip0, taylor0); + taylor1 = _mm256_add_ps(clip1, taylor1); + // Test |x| <= 1/9, roughly where the errors cross over, without needing yet + // another constant. + third = _mm256_mul_ps(third, third); + __m256 neg_zero = _mm256_set1_ps(-0.0f); + clip0 = _mm256_andnot_ps(neg_zero, clip0); + clip1 = _mm256_andnot_ps(neg_zero, clip1); + __m256 cmp_results0 = _mm256_cmp_ps(clip0, third, _CMP_LE_OQ); + __m256 cmp_results1 = _mm256_cmp_ps(clip1, third, _CMP_LE_OQ); + y0 = _mm256_blendv_ps(exp0, taylor0, cmp_results0); + y1 = _mm256_blendv_ps(exp1, taylor1, cmp_results1); +} + +// Fixed32 sigmoid approximation via the AVX2 implementation of the exponent +// trick. +// Input: 2x registers containins 8x float containing converted fixed32 scaled +// with kInputMantissaBits. +// Output: 2x registers containing 8x float. +// TM_ORDER4_FLOAT: Max relative error < 4e-6, absolute error < 1e-6. +// TM_ORDER4_16BIT: Max relative error < 3e-5, absolute error < 7e-6. +// TM_ORDER3_16BIT: Max relative error < 5.4e-4, absolute error < 1.4e-4. +template +inline void float_sigmoid_float(__m256& y0, __m256& y1) { + constexpr float kInputFactor = static_cast(1 << kInputMantissaBits); + // Negate the inputs. + __m256 minus_zero = _mm256_set1_ps(-0.0f); + y0 = _mm256_xor_ps(y0, minus_zero); + y1 = _mm256_xor_ps(y1, minus_zero); + y0 = ClipToFloatBounds(kMaxSigmoidInput * kInputFactor, y0); + y1 = ClipToFloatBounds(kMaxSigmoidInput * kInputFactor, y1); + float_exp_float_preclipped(y0, y1); + __m256 one = _mm256_set1_ps(1.0f); + // Approximate reciprocal is not accurate enough - use full division. + y0 = _mm256_div_ps(one, _mm256_add_ps(y0, one)); + y1 = _mm256_div_ps(one, _mm256_add_ps(y1, one)); +} + +#endif // defined __AVX2__ + +} // namespace csrblocksparse + +#endif // LYRA_CODEC_SPARSE_MATMUL_NUMERICS_FAST_TRANSCENDENTALS_H_ diff --git a/sparse_matmul/numerics/fasttranscendentals_test.cc b/sparse_matmul/numerics/fasttranscendentals_test.cc new file mode 100644 index 00000000..004241e5 --- /dev/null +++ b/sparse_matmul/numerics/fasttranscendentals_test.cc @@ -0,0 +1,665 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#if defined __aarch64__ +#include +#endif +#if defined __AVX__ || defined __AVX2__ +#include +#endif + +#include + +#include +#include +#include +#include + +#include "gtest/gtest.h" +#include "sparse_matmul/numerics/fast_transcendentals.h" +#include "sparse_matmul/numerics/test_utils.h" + +namespace csrblocksparse { + +const float kExpFixedRelTolerance = .084f; + +#ifdef SIGMOID_AS_TANH +#if defined FAST_TRANSCENDENTALS && defined ACCURATE_TRANSCENDENTAL_APPROX +const float kSigmoidRelTolerance = .093f; // 9.3% relative +const float kSigmoidAbsTolerance = .0005f; +const float kSigmoidFixedRelTolerance = .093f; +const float kSigmoidFixedAbsTolerance = .0005f; +#elif defined FAST_TRANSCENDENTALS +const float kSigmoidRelTolerance = .09f; // 9.0% relative +const float kSigmoidAbsTolerance = .003f; +const float kSigmoidFixedRelTolerance = .09f; +const float kSigmoidFixedAbsTolerance = .003f; +#endif +#elif defined FAST_TRANSCENDENTALS and defined ACCURATE_TRANSCENDENTAL_APPROX +const float kSigmoidRelTolerance = .102f; // 10.2% relative +const float kSigmoidAbsTolerance = .0003f; +const float kSigmoidFixedRelTolerance = .102f; +const float kSigmoidFixedAbsTolerance = .0003f; +#elif defined FAST_TRANSCENDENTALS +const float kSigmoidRelTolerance = .09f; // 9.0% relative +const float kSigmoidAbsTolerance = .006f; +const float kSigmoidFixedRelTolerance = .09f; +const float kSigmoidFixedAbsTolerance = .006f; +#else +const float kSigmoidRelTolerance = .0001f; +const float kSigmoidAbsTolerance = 1e-5f; +const float kSigmoidFixedRelTolerance = .001f; +const float kSigmoidFixedAbsTolerance = .001f; +#endif + +#if (defined FAST_TRANSCENDENTALS && defined ACCURATE_TRANSCENDENTAL_APPROX || \ + defined FASTER_TRANSCENDENTALS) +const float kExpRelTolerance = .03f; // 3% relative +const float kTanhRelTolerance = .006f; // .6% relative +const float kTanhAbsTolerance = .0003f; +#elif defined FAST_TRANSCENDENTALS +const float kExpRelTolerance = .03f; // 3% relative +const float kTanhRelTolerance = .091f; // .91% relative +const float kTanhAbsTolerance = .00525f; +#else +const float kExpRelTolerance = .0001f; +const float kTanhRelTolerance = .0001f; +const float kTanhAbsTolerance = 1e-5f; +#endif + +constexpr float kQuarticFloatExpRelTolerance = 8e-6f; +constexpr float kQuarticFloatExpTolerance = 9e-6f; +constexpr float kQuarticExpRelTolerance = 3e-5f; +constexpr float kQuarticExpTolerance = 6e-5f; +constexpr float kCubicExpRelTolerance = 6e-4f; +constexpr float kCubicExpTolerance = 2e-3f; +constexpr float kQuarticFloatTanhRelTolerance = 3e-5f; +constexpr float kQuarticFloatTanhTolerance = 3e-6f; +constexpr float kCubicTanhRelTolerance = 3e-3f; +constexpr float kCubicTanhTolerance = 3e-4f; +constexpr float kQuarticSigmoidRelTolerance = 3e-5f; +constexpr float kQuarticSigmoidTolerance = 7e-6f; +constexpr float kCubicSigmoidRelTolerance = 6e-4f; +constexpr float kCubicSigmoidTolerance = 2e-4f; +#ifdef __AVX2__ +constexpr float kQuarticTanhRelTolerance = 1e-4f; +constexpr float kQuarticTanhTolerance = 2e-5f; +constexpr float kQuarticFloatSigmoidRelTolerance = 4e-6f; +constexpr float kQuarticFloatSigmoidTolerance = 1e-6f; +#endif // __AVX2__ + +TEST(Transcendentals, Exp) { + // 132 - 127 = 5, we check between -63.99... and 63.99... + const int maxExponent = 132; + const int minExponent = 0; + float max_error = 0.f; + constexpr int kExponentBits = 7; + for (int s = 0; s < 2; ++s) { + for (int e = minExponent; e < maxExponent; ++e) { + // Don't check every mantissa for speed reasons. + for (int m = 0; m < (1 << 23); m += (1 << 10)) { + uint32_t int_val = s << 31 | e << 23 | m; + float x; + memcpy(&x, &int_val, sizeof(float)); + + float exact_exp = expf(x); + float approx_exp = csrblocksparse::fast_exp(x); + float approx_exp_fixed = csrblocksparse::fast_exp( + csrblocksparse::fixed32(x)); + + float rel_diff = RelDiff(exact_exp, approx_exp); + float rel_diff_fixed = RelDiff(exact_exp, approx_exp_fixed); + max_error = std::max(max_error, rel_diff); + EXPECT_LT(rel_diff, kExpRelTolerance) + << exact_exp << " " << approx_exp << " " << x; + EXPECT_LT(rel_diff_fixed, kExpRelTolerance) + << exact_exp << " " << approx_exp << " " << x; + } + } + } +} + +TEST(Transcendentals, FixedExp) { + const int maxExponent = 132; + const int minExponent = 120; + float max_error = 0.f; + float max_abs_error = 0.f; + for (int s = 0; s < 2; ++s) { + for (int e = minExponent; e < maxExponent; ++e) { + // Don't check every mantissa for speed reasons. + for (int m = 0; m < (1 << 23); m += (1 << 10)) { + uint32_t int_val = s << 31 | e << 23 | m; + float x; + memcpy(&x, &int_val, sizeof(float)); + + float exact_exp = expf(x); + float approx_exp = + csrblocksparse::fast_exp_fixed(csrblocksparse::fixed32<16>(x)); + + float rel_diff = RelDiff(exact_exp, approx_exp); + float abs_diff = std::abs(exact_exp - approx_exp); + max_error = std::max(max_error, rel_diff); + max_abs_error = std::max(max_abs_error, abs_diff); + EXPECT_LT(rel_diff, kExpFixedRelTolerance) + << exact_exp << " " << approx_exp << " " << x; + } + } + } + LOG(INFO) << "Max relative exp error = " << max_error + << ", abs=" << max_abs_error; +} + +template +void TestExp(float abs_tolerance, float rel_tolerance) { + constexpr int kMaxInput = 80 << 16; + constexpr int kMinInput = -(80 << 16); + constexpr int kExponentBits = 15; + float max_error = 0.f; + float max_abs_error = 0.f; + for (int i = kMinInput; i <= kMaxInput; ++i) { + csrblocksparse::fixed32 fixed_int(i); + float x = static_cast(fixed_int); + float exact_exp = expf(x); + float approx_exp = fixed32_exp(fixed_int); + float diff = exact_exp - approx_exp; + float abs_diff = std::abs(diff); + float rel_diff = RelDiff(exact_exp, approx_exp); + max_error = std::max(max_error, rel_diff); + if (x <= 1.0f) { + ASSERT_LT(abs_diff, abs_tolerance) + << "x=" << x << ", target=" << exact_exp << ", aprx=" << approx_exp; + max_abs_error = std::max(max_abs_error, abs_diff); + } + ASSERT_LT(rel_diff, rel_tolerance) + << "x=" << x << ", target=" << exact_exp << ", aprx=" << approx_exp; + } + LOG(INFO) << "Max relative error = " << max_error + << ", abs=" << max_abs_error; +} + +TEST(Transcendentals, QuarticExp) { + TestExp(kQuarticFloatExpTolerance, + kQuarticFloatExpRelTolerance); +} + +TEST(Transcendentals, CubicExp) { + TestExp(kCubicExpTolerance, kCubicExpRelTolerance); +} + +template +void TestTanh(float abs_tolerance, float rel_tolerance) { + constexpr int kMaxInput = (40 << 16); + constexpr int kMinInput = -(40 << 16); + constexpr int kExponentBits = 15; + float max_error = 0.f; + float max_abs_error = 0.f; + for (int i = kMinInput; i <= kMaxInput; ++i) { + csrblocksparse::fixed32 fixed_int(i); + float x = static_cast(fixed_int); + float exact_tanh = tanh(x); + float approx_tanh = fixed32_tanh(fixed_int); + float diff = exact_tanh - approx_tanh; + float abs_diff = std::abs(diff); + float rel_diff = RelDiff(exact_tanh, approx_tanh); + ASSERT_LT(abs_diff, abs_tolerance) + << "x=" << x << ", target=" << exact_tanh << ", aprx=" << approx_tanh; + max_abs_error = std::max(max_abs_error, abs_diff); + max_error = std::max(max_error, rel_diff); + ASSERT_LT(rel_diff, rel_tolerance) + << "x=" << x << ", target=" << exact_tanh << ", aprx=" << approx_tanh; + } + LOG(INFO) << "Max relative error = " << max_error + << ", abs=" << max_abs_error; +} + +TEST(Transcendentals, QuarticTanh) { + TestTanh(kQuarticFloatTanhTolerance, + kQuarticFloatTanhRelTolerance); +} + +TEST(Transcendentals, CubicTanh) { + TestTanh(kCubicTanhTolerance, kCubicTanhRelTolerance); +} + +template +void TestSigmoid(float abs_tolerance, float rel_tolerance) { + constexpr int kMaxInput = 80 << 16; + constexpr int kMinInput = -(80 << 16); + constexpr int kExponentBits = 15; + float max_error = 0.f; + float max_abs_error = 0.f; + for (int i = kMinInput; i <= kMaxInput; ++i) { + csrblocksparse::fixed32 fixed_int(i); + float x = static_cast(fixed_int); + float exact_sigmoid = 1.0f / (1.0f + exp(-x)); + float approx_sigmoid = fixed32_sigmoid(fixed_int); + float diff = exact_sigmoid - approx_sigmoid; + float abs_diff = std::abs(diff); + float rel_diff = RelDiff(exact_sigmoid, approx_sigmoid); + max_error = std::max(max_error, rel_diff); + ASSERT_LT(abs_diff, abs_tolerance) + << "x=" << x << ", target=" << exact_sigmoid + << ", aprx=" << approx_sigmoid; + max_abs_error = std::max(max_abs_error, abs_diff); + ASSERT_LT(rel_diff, rel_tolerance) + << "x=" << x << ", target=" << exact_sigmoid + << ", aprx=" << approx_sigmoid; + } + LOG(INFO) << "Max relative sigmoid error = " << max_error + << ", abs=" << max_abs_error; +} + +TEST(Transcendentals, QuarticSigmoidExp) { + TestSigmoid(kQuarticSigmoidTolerance, + kQuarticSigmoidRelTolerance); +} + +TEST(Transcendentals, CubicSigmoidExp) { + TestSigmoid(kCubicSigmoidTolerance, + kCubicSigmoidRelTolerance); +} + +TEST(Transcendentals, Sigmoid) { + // 132 - 127 = 5, we check between -63.99... and 63.99... + const int maxExponent = 132; + const int minExponent = 0; + // The mantissa bits must not exceed 23, so min exponent bits here is: + // 31 - 23 = 8. + constexpr int kExponentBits = 9; + float max_error = 0.f; + float max_abs_error = 0.f; +#if defined __aarch64__ + float max_vector_error = 0.f; + float max_vector_abs_error = 0.f; +#endif + for (int s = 0; s < 2; ++s) { + for (int e = minExponent; e < maxExponent; ++e) { + // Don't check every mantissa for speed reasons. + for (int m = 0; m < (1 << 23); m += (1 << 10)) { + uint32_t int_val = s << 31 | e << 23 | m; + float x; + memcpy(&x, &int_val, sizeof(float)); + + float exact_sigmoid = 1. / (1. + expf(-x)); + float approx_sigmoid = csrblocksparse::fast_sigmoid(x); + float approx_sigmoid_fixed = + csrblocksparse::fast_sigmoid( + csrblocksparse::fixed32(x)); + + float rel_diff = RelDiff(exact_sigmoid, approx_sigmoid); + float abs_diff = std::abs(exact_sigmoid - approx_sigmoid); + float rel_diff_fixed = RelDiff(exact_sigmoid, approx_sigmoid_fixed); + max_error = std::max(max_error, rel_diff); + max_abs_error = std::max(max_abs_error, abs_diff); + EXPECT_LT(rel_diff, kSigmoidRelTolerance) + << exact_sigmoid << " " << approx_sigmoid << " " << x; + EXPECT_NEAR(approx_sigmoid, exact_sigmoid, kSigmoidAbsTolerance) << x; + + EXPECT_LT(rel_diff_fixed, kSigmoidFixedRelTolerance) + << exact_sigmoid << " " << approx_sigmoid_fixed << " " << x; + EXPECT_NEAR(approx_sigmoid_fixed, exact_sigmoid, + kSigmoidFixedAbsTolerance) + << x; +#if defined __aarch64__ + constexpr int kSIMD_WIDTH = 4; + float approx_results[kSIMD_WIDTH]; + int32x4_t input = + vdupq_n_s32(csrblocksparse::fixed32(x).raw_val()); + float32x4_t result = csrblocksparse::fast_sigmoid(input); + vst1q_f32(approx_results, result); + + for (int i = 0; i < kSIMD_WIDTH; ++i) { + float rel_diff = RelDiff(exact_sigmoid, approx_results[i]); + float abs_diff = std::abs(exact_sigmoid - approx_results[i]); + max_vector_error = std::max(max_vector_error, rel_diff); + max_vector_abs_error = std::max(max_vector_abs_error, abs_diff); + EXPECT_LT(rel_diff, kSigmoidRelTolerance) + << exact_sigmoid << " " << approx_sigmoid << " " << x; + EXPECT_NEAR(approx_sigmoid, exact_sigmoid, kSigmoidAbsTolerance) << x; + } +#endif + } + } + } + LOG(INFO) << "Max relative error in float sigmoid=" << max_error; + LOG(INFO) << "Max abs error in float sigmoid=" << max_abs_error; +#if defined __aarch64__ + LOG(INFO) << "Max relative vector error fixed sigmoid=" << max_vector_error; + LOG(INFO) << "Max abs vector error fixed sigmoid=" << max_vector_abs_error; +#endif +} + +TEST(Transcendentals, Tanh) { + // 132 - 127 = 5, we check between -63.99... and 63.99... + const int maxExponent = 132; + const int minExponent = 0; + float max_error = 0.f; + float max_abs_error = 0.f; + for (int s = 0; s < 2; ++s) { + for (int e = minExponent; e < maxExponent; ++e) { + // Don't check every mantissa for speed reasons. + for (int m = 0; m < (1 << 23); m += (1 << 10)) { + uint32_t int_val = s << 31 | e << 23 | m; + float x; + memcpy(&x, &int_val, sizeof(float)); + + float exact_tanh = tanhf(x); + float approx_tanh = csrblocksparse::fast_tanh(x); + + float rel_diff = RelDiff(exact_tanh, approx_tanh); + float abs_diff = std::abs(exact_tanh - approx_tanh); + max_error = std::max(rel_diff, max_error); + max_abs_error = std::max(abs_diff, max_abs_error); + + EXPECT_LT(rel_diff, kTanhRelTolerance) + << exact_tanh << " " << approx_tanh << " " << x; + EXPECT_NEAR(approx_tanh, exact_tanh, kTanhAbsTolerance) << x; + } + } + } + LOG(INFO) << "Max relative error in float tanh=" << max_error; + LOG(INFO) << "Max abs error in float tanh=" << max_abs_error; + + // tanh behavior is not identical across all lanes, so need to test + // with some values in the linear region and some not. +#if defined __aarch64__ + float vals[4] = {-1.f, -.1f, .1f, 1.f}; + float exact_results[4]; + float approx_results[4]; + max_error = 0.f; + max_abs_error = 0.f; + + float32x4_t input = vld1q_f32(vals); + float32x4_t result = csrblocksparse::fast_tanh(input); + vst1q_f32(approx_results, result); + + for (int i = 0; i < 4; ++i) { + exact_results[i] = tanh(vals[i]); + float rel_diff = RelDiff(exact_results[i], approx_results[i]); + float abs_diff = std::abs(exact_results[i] - approx_results[i]); + max_error = std::max(rel_diff, max_error); + max_abs_error = std::max(abs_diff, max_abs_error); + + EXPECT_LT(rel_diff, kTanhRelTolerance) + << exact_results[i] << " " << approx_results[i] << " " << vals[i]; + EXPECT_NEAR(approx_results[i], exact_results[i], kTanhAbsTolerance) + << vals[i]; + } + LOG(INFO) << "Max relative vector error in float tanh=" << max_error; + LOG(INFO) << "Max abs vector error in float tanh=" << max_abs_error; +#endif +} + +#if defined __AVX2__ + +constexpr int kSIMDSize = 8; +constexpr int kNumExpBitsIn = 10; +constexpr int kNumExpBitsOut = 5; + +TEST(Transcendentals, TanhLut) { + // Test every value in (-1, 1) for round-trip exactness. + constexpr int kNumMantissaBitsIn = fixed32::kMantissaBits; + constexpr int kNumMantissaBitsOut = fixed16::kMantissaBits; + const int32_t* tanh_table = TanhTable(kNumMantissaBitsOut); + float in_factor = static_cast(1 << kNumMantissaBitsIn); + float out_factor = static_cast(1 << kNumMantissaBitsOut); + for (int i = 1 - (1 << kNumMantissaBitsOut); + i + kSIMDSize < (1 << kNumMantissaBitsOut); i += kSIMDSize) { + int32_t inputs[kSIMDSize]; + int32_t outputs[kSIMDSize]; + int32_t target_outputs[kSIMDSize]; + for (int j = 0; j < kSIMDSize; ++j) { + float target_tanh = (i + j) / out_factor; + float x = atanhf(static_cast(target_tanh)); + inputs[j] = static_cast(x * in_factor); + target_outputs[j] = i + j; + } + __m256i x_in = _mm256_loadu_si256(reinterpret_cast<__m256i*>(inputs)); + __m256i output = + fixed32_tanh_fixed16( + tanh_table, x_in); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(outputs), output); + for (int j = 0; j < kSIMDSize; ++j) { + EXPECT_EQ(target_outputs[j], outputs[j]); + } + } +} + +TEST(Transcendentals, SigmoidLut) { + // Test every value in (-1, 1) for round-trip exactness. + constexpr int kNumMantissaBitsIn = fixed32::kMantissaBits; + constexpr int kNumMantissaBitsOut = fixed16::kMantissaBits; + const int32_t* sigmoid_table = SigmoidTable(kNumMantissaBitsOut); + float in_factor = static_cast(1 << kNumMantissaBitsIn); + float out_factor = static_cast(1 << kNumMantissaBitsOut); + for (int i = 1; i + kSIMDSize < (1 << kNumMantissaBitsOut); i += kSIMDSize) { + int32_t inputs[kSIMDSize]; + int32_t outputs[kSIMDSize]; + int32_t target_outputs[kSIMDSize]; + for (int j = 0; j < kSIMDSize; ++j) { + float target_sigmoid = (i + j) / out_factor; + float x = 2.0f * atanhf(2.0f * static_cast(target_sigmoid) - 1.0f); + inputs[j] = static_cast(x * in_factor); + target_outputs[j] = i + j; + } + __m256i x_in = _mm256_loadu_si256(reinterpret_cast<__m256i*>(inputs)); + __m256i output = + fixed32_sigmoid_fixed16( + sigmoid_table, x_in); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(outputs), output); + for (int j = 0; j < kSIMDSize; ++j) { + EXPECT_EQ(target_outputs[j], outputs[j]); + } + } +} + +template +static void TestExpAVX2(float abs_tolerance, float rel_tolerance) { + constexpr int kMantissaBits = 20; + // Test every value in [-80, 80] and report the max error. + constexpr int kMinInput = -(80 << kMantissaBits); + constexpr int kMaxInput = 80 << kMantissaBits; + constexpr int kNumInputs = kMaxInput - kMinInput; + std::vector inputs(kNumInputs); + std::vector outputs(kNumInputs); + std::vector target_outputs(kNumInputs); + for (int i = 0; i < inputs.size(); ++i) { + csrblocksparse::fixed32<31 - kMantissaBits> fixed_int(i + kMinInput); + float x = static_cast(fixed_int); + inputs[i] = fixed_int.raw_val(); + target_outputs[i] = expf(x); + } + absl::Time t_start = absl::Now(); + for (int i = 0; i + kSIMDSize * 2 <= kNumInputs; i += kSIMDSize * 2) { + __m256i x0 = + _mm256_loadu_si256(reinterpret_cast(inputs.data() + i)); + __m256i x1 = _mm256_loadu_si256( + reinterpret_cast(inputs.data() + i + kSIMDSize)); + __m256 y0, y1; + fixed32_exp_float(x0, x1, y0, y1); + _mm256_storeu_ps(outputs.data() + i, y0); + _mm256_storeu_ps(outputs.data() + i + kSIMDSize, y1); + } + LOG(INFO) << "Time=" << absl::ToDoubleMilliseconds(absl::Now() - t_start); + float max_error = 0.f; + float max_abs_error = 0.f; + for (int i = 0; i < kNumInputs; ++i) { + float diff = target_outputs[i] - outputs[i]; + float abs_diff = std::abs(diff); + csrblocksparse::fixed32<31 - kMantissaBits> fixed_int(i + kMinInput); + float x = static_cast(fixed_int); + float rel_diff = RelDiff(target_outputs[i], outputs[i]); + max_error = std::max(max_error, rel_diff); + if (x <= 1.0f) { + ASSERT_LT(abs_diff, abs_tolerance) + << "x=" << x << ", target=" << target_outputs[i] + << ", result= " << outputs[i] << ", i=" << i; + max_abs_error = std::max(max_abs_error, abs_diff); + } + ASSERT_LT(rel_diff, rel_tolerance) + << "x=" << x << ", target=" << target_outputs[i] + << ", result= " << outputs[i] << ", i=" << i; + } + LOG(INFO) << "Max relative error = " << max_error + << ", abs=" << max_abs_error; +} + +TEST(Transcendentals, QuarticFloatExpAVX2) { + TestExpAVX2(kQuarticFloatExpTolerance, + kQuarticFloatExpRelTolerance); +} + +TEST(Transcendentals, QuarticExpAVX2) { + TestExpAVX2(kQuarticExpTolerance, kQuarticExpRelTolerance); +} + +TEST(Transcendentals, CubicExpAVX2) { + TestExpAVX2(kCubicExpTolerance, kCubicExpRelTolerance); +} + +template +void TestTanhAVX2Float(float abs_tolerance, float rel_tolerance) { + constexpr int kMantissaBits = 16; + // Test every value in [-10, 10] and report the max error. + constexpr int kMinInput = -(10 << kMantissaBits); + constexpr int kMaxInput = 10 << kMantissaBits; + constexpr int kNumInputs = kMaxInput - kMinInput; + float max_error = 0.f; + float max_abs_error = 0.f; + std::vector inputs(kNumInputs); + std::vector outputs(kNumInputs); + std::vector target_outputs(kNumInputs); + for (int i = 0; i < inputs.size(); ++i) { + csrblocksparse::fixed32<31 - kMantissaBits> fixed_int(i + kMinInput); + float x = static_cast(fixed_int); + float exact = tanh(x); + inputs[i] = static_cast(fixed_int.raw_val()); + target_outputs[i] = exact; + } + absl::Time t_start = absl::Now(); + for (int i = 0; i + kSIMDSize * 2 <= inputs.size(); i += kSIMDSize * 2) { + __m256 x0 = _mm256_loadu_ps(inputs.data() + i); + __m256 x1 = _mm256_loadu_ps(inputs.data() + kSIMDSize + i); + __m256 y0, y1; + float_tanh_float(x0, x1, y0, y1); + _mm256_storeu_ps(outputs.data() + i, y0); + _mm256_storeu_ps(outputs.data() + i + kSIMDSize, y1); + } + LOG(INFO) << "Time=" << absl::ToDoubleMilliseconds(absl::Now() - t_start); + float worst_abs_x = 0.0f, worst_rel_x = 0.0f; + for (int i = 0; i < inputs.size(); ++i) { + float diff = target_outputs[i] - outputs[i]; + float abs_diff = std::abs(diff); + csrblocksparse::fixed32<31 - kMantissaBits> fixed_int(i + kMinInput); + float x = static_cast(fixed_int); + ASSERT_LT(abs_diff, abs_tolerance) + << "x=" << x << ", target=" << target_outputs[i] + << ", aprx=" << outputs[i]; + if (abs_diff > max_abs_error) worst_abs_x = x; + max_abs_error = std::max(max_abs_error, abs_diff); + float rel_diff = 0.0f; + rel_diff = RelDiff(target_outputs[i], outputs[i]); + if (rel_diff > max_error) worst_rel_x = x; + max_error = std::max(max_error, rel_diff); + ASSERT_LT(rel_diff, rel_tolerance) + << "x=" << x << ", target=" << target_outputs[i] + << ", aprx=" << outputs[i]; + } + LOG(INFO) << "Max relative error = " << max_error + << ", abs=" << max_abs_error; + LOG(INFO) << "Worst rel x = " << worst_rel_x << ", abs=" << worst_abs_x; +} + +TEST(Transcendentals, QuarticTanhFloatAVX2Float) { + TestTanhAVX2Float(kQuarticFloatTanhTolerance, + kQuarticFloatTanhRelTolerance); +} + +TEST(Transcendentals, QuarticTanhAVX2Float) { + TestTanhAVX2Float(kQuarticTanhTolerance, + kQuarticTanhRelTolerance); +} + +TEST(Transcendentals, CubicTanhAVX2Float) { + TestTanhAVX2Float(kCubicTanhTolerance, + kCubicTanhRelTolerance); +} + +template +void TestSigmoidAVX2Float(float abs_tolerance, float rel_tolerance) { + constexpr int kMantissaBits = 20; + // Test every value in [-20, 20] and report the max error. + constexpr int kMaxInput = 20 << kMantissaBits; + constexpr int kMinInput = -(20 << kMantissaBits); + float max_error = 0.f; + float max_abs_error = 0.f; + std::vector inputs(kMaxInput - kMinInput); + std::vector outputs(kMaxInput - kMinInput); + std::vector target_outputs(kMaxInput - kMinInput); + for (int i = 0; i < inputs.size(); ++i) { + csrblocksparse::fixed32<31 - kMantissaBits> fixed_int(i + kMinInput); + float x = static_cast(fixed_int); + float exact = 1.0f / (1.0f + expf(-x)); + inputs[i] = fixed_int.raw_val(); + target_outputs[i] = exact; + } + absl::Time t_start = absl::Now(); + for (int i = 0; i + kSIMDSize * 2 <= inputs.size(); i += kSIMDSize * 2) { + __m256i x0 = + _mm256_loadu_si256(reinterpret_cast(inputs.data() + i)); + __m256i x1 = _mm256_loadu_si256( + reinterpret_cast(inputs.data() + i + kSIMDSize)); + __m256 y0 = _mm256_cvtepi32_ps(x0); + __m256 y1 = _mm256_cvtepi32_ps(x1); + float_sigmoid_float(y0, y1); + _mm256_storeu_ps(outputs.data() + i, y0); + _mm256_storeu_ps(outputs.data() + i + kSIMDSize, y1); + } + LOG(INFO) << "Time=" << absl::ToDoubleMilliseconds(absl::Now() - t_start); + for (int i = 0; i < inputs.size(); ++i) { + float diff = target_outputs[i] - outputs[i]; + float abs_diff = std::abs(diff); + csrblocksparse::fixed32<31 - kMantissaBits> fixed_int(i + kMinInput); + float x = static_cast(fixed_int); + float rel_diff = RelDiff(target_outputs[i], outputs[i]); + max_error = std::max(max_error, rel_diff); + ASSERT_LT(abs_diff, abs_tolerance) + << "x=" << x << ", target=" << target_outputs[i] + << ", aprx=" << outputs[i]; + max_abs_error = std::max(max_abs_error, abs_diff); + ASSERT_LT(rel_diff, rel_tolerance) + << "x=" << x << ", target=" << target_outputs[i] + << ", aprx=" << outputs[i]; + } + LOG(INFO) << "Max relative error = " << max_error + << ", abs=" << max_abs_error; +} + +TEST(Transcendentals, QuarticSigmoidFloatAVX2Float) { + TestSigmoidAVX2Float(kQuarticFloatSigmoidTolerance, + kQuarticFloatSigmoidRelTolerance); +} + +TEST(Transcendentals, QuarticSigmoidAVX2Float) { + TestSigmoidAVX2Float(kQuarticSigmoidTolerance, + kQuarticSigmoidRelTolerance); +} + +TEST(Transcendentals, CubicSigmoidAVX2Float) { + TestSigmoidAVX2Float(kCubicSigmoidTolerance, + kCubicSigmoidRelTolerance); +} +#endif // __AVX2__ + +} // namespace csrblocksparse diff --git a/sparse_matmul/numerics/fixed_types.h b/sparse_matmul/numerics/fixed_types.h new file mode 100644 index 00000000..932f81a0 --- /dev/null +++ b/sparse_matmul/numerics/fixed_types.h @@ -0,0 +1,139 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LYRA_CODEC_SPARSE_MATMUL_NUMERICS_FIXED_TYPES_H_ +#define LYRA_CODEC_SPARSE_MATMUL_NUMERICS_FIXED_TYPES_H_ + +#include +#include +#include +#include +#include +#include + +#include "glog/logging.h" + +namespace csrblocksparse { + +// Useful for meta-programming and determining if a type is a fixed point type +class fixed_type {}; +class fixed16_type : fixed_type {}; +class fixed32_type : fixed_type {}; + +// Storage class for 16-bit fixed point values, not meant to be used directly +// for computation. Used for storage and converting to/from float32. +// N = 16 - 1 - |ExponentBits|. +// range = [-2^|ExponentBits|, 2^|ExponentBits|), increment = 2^-N. +template +class fixed16 : fixed16_type { + static_assert(ExponentBits >= 0 && ExponentBits < 16, + "ExponentBits must be in" + " the interval [0, 15]"); + + public: + static constexpr int kExponentBits = ExponentBits; + static constexpr int kMantissaBits = 16 - ExponentBits - 1; + + fixed16() = default; + explicit fixed16(float x) : val_(float_to_fixed16(x)) {} + explicit fixed16(int16_t x) : val_(x) {} + + explicit operator float() const { return fixed16_to_float(val_); } + + int raw_val() const { return val_; } + + private: + inline float fixed16_to_float(int16_t x) const { + return static_cast(x) / (1 << kMantissaBits); + } + + // Conversion clips to the representable range. + inline int16_t float_to_fixed16(float x) const { + float fval = std::round(x * static_cast(1 << kMantissaBits)); + const float max_bound = std::numeric_limits::max(); + const float min_bound = std::numeric_limits::min(); + auto val = + static_cast(std::max(std::min(fval, max_bound), min_bound)); + LOG_IF(INFO, fval > max_bound || fval < min_bound) + << "Conversion clipping: " << x << " to " << fixed16_to_float(val); + return val; + } + + int16_t val_; +}; + +// Storage class for 32-bit fixed point values, not meant to be used directly +// for computation. Used for storage and converting to/from float32. +// N = 32 - 1 - |ExponentBits|. +// range = [-2^|ExponentBits|, 2^|ExponentBits|), increment = 2^-N. +template +class fixed32 : fixed32_type { + static_assert(ExponentBits >= 0 && ExponentBits < 32, + "ExponentBits must be in" + " the interval [0, 31]"); + + public: + static constexpr int kExponentBits = ExponentBits; + static constexpr int kMantissaBits = 32 - ExponentBits - 1; + + fixed32() = default; + explicit fixed32(float x) : val_(float_to_fixed32(x)) {} + explicit fixed32(int32_t x) : val_(x) {} + + explicit operator float() const { return fixed32_to_float(val_); } + + int raw_val() const { return val_; } + + private: + inline float fixed32_to_float(int32_t x) const { + return static_cast(x) / (1LL << kMantissaBits); + } + + // Conversion clips to the representable range. + inline int32_t float_to_fixed32(float x) const { + float fval = std::round(x * static_cast(1LL << kMantissaBits)); + const int32_t max_bound = std::numeric_limits::max(); + const int32_t min_bound = std::numeric_limits::min(); + int32_t val = fval >= static_cast(max_bound) + ? max_bound + : (fval < static_cast(min_bound) + ? min_bound + : static_cast(fval)); + + LOG_IF(INFO, fval >= max_bound || fval < min_bound) + << "Conversion clipping: " << x << " to " << fixed32_to_float(val); + return val; + } + + int32_t val_; +}; + +template +struct IsFixed16Type + : std::integral_constant::value> {}; + +template +struct IsFixed32Type + : std::integral_constant::value> {}; + +template +struct IsFixedType : std::integral_constant::value || + IsFixed32Type::value> { +}; + +} // namespace csrblocksparse + +#endif // LYRA_CODEC_SPARSE_MATMUL_NUMERICS_FIXED_TYPES_H_ diff --git a/sparse_matmul/numerics/fixed_types_test.cc b/sparse_matmul/numerics/fixed_types_test.cc new file mode 100644 index 00000000..82fcd93d --- /dev/null +++ b/sparse_matmul/numerics/fixed_types_test.cc @@ -0,0 +1,43 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "sparse_matmul/numerics/fixed_types.h" + +#include + +#include "gtest/gtest.h" +#include "sparse_matmul/numerics/test_utils.h" +#include "sparse_matmul/numerics/type_utils.h" + +namespace csrblocksparse { + +// Basic test that makes sure basic multiplication and TypeOfProduct work +// correctly. +TEST(FixedPoint, Multiplication) { + fixed16<4> a(.1f); + fixed16<4> b(1.f); + + TypeOfProduct, fixed16<4>>::type c(a.raw_val() * b.raw_val()); + + EXPECT_NEAR(static_cast(c), .1f, + 1. / (1 << fixed16<2>::kMantissaBits)); +} + +TEST(FixedPoint, SafeCastingIntMax) { + const float int_max_float = std::numeric_limits::max(); + const csrblocksparse::fixed32<31> int_max_fixed(int_max_float); + EXPECT_FLOAT_EQ(int_max_float, static_cast(int_max_fixed)); +} + +} // namespace csrblocksparse diff --git a/sparse_matmul/numerics/float16_types.h b/sparse_matmul/numerics/float16_types.h new file mode 100644 index 00000000..5a313271 --- /dev/null +++ b/sparse_matmul/numerics/float16_types.h @@ -0,0 +1,149 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LYRA_CODEC_SPARSE_MATMUL_NUMERICS_FLOAT16_TYPES_H_ +#define LYRA_CODEC_SPARSE_MATMUL_NUMERICS_FLOAT16_TYPES_H_ + +#include +#include +#include + +namespace csrblocksparse { + +// Storage class for fp16 values, not meant to be used directly for computation. +// Used for converting to/from float32. +class fp16 { + public: + fp16() = default; + explicit fp16(float x) : val_(float_to_fp16(x)) {} + explicit fp16(uint16_t x) : val_(x) {} + static constexpr int kMantissaBits = 11; + + explicit operator float() const { return fp16_to_float(val_); } + + private: + inline float fp16_to_float(uint16_t as_int) const { +#if defined __aarch64__ + float x; + float* x_ptr = &x; + asm volatile( + "dup v0.8h, %w[as_int]\n" + "fcvtl v1.4s, v0.4h\n " + "st1 {v1.s}[0], [%[x_ptr]]\n" + : // outputs + : // inputs + [x_ptr] "r"(x_ptr), + [as_int] "r"(as_int) + : // clobbers + "cc", "memory", "v0", "v1"); + return x; +#else + unsigned int sign_bit = (as_int & 0x8000) << 16; + unsigned int exponent = as_int & 0x7c00; + + unsigned int mantissa; + if (exponent == 0) + mantissa = 0; + else + mantissa = ((as_int & 0x7fff) << 13) + 0x38000000; + mantissa |= sign_bit; + + float x; + memcpy(&x, &mantissa, sizeof(int)); + return x; +#endif // defined __aarch64__ + } + + inline uint16_t float_to_fp16(float x) const { +#if defined __aarch64__ + uint16_t as_int; + uint16_t* as_int_ptr = &as_int; + asm volatile( + "dup v0.4s, %w[x]\n" + "fcvtn v1.4h, v0.4s\n" + "st1 {v1.h}[0], [%[as_int_ptr]]\n" + : // outputs + : // inputs + [as_int_ptr] "r"(as_int_ptr), + [x] "r"(x) + : // clobbers + "cc", "memory", "v0", "v1"); + return as_int; +#else + unsigned int x_int; + memcpy(&x_int, &x, sizeof(int)); + + unsigned int sign_bit = (x_int & 0x80000000) >> 16; + unsigned int exponent = x_int & 0x7f800000; + + unsigned int mantissa; + if (exponent < 0x38800000) { // exponent too small or denormal + mantissa = 0; + } else if (exponent > 0x8e000000) { + mantissa = 0x7bff; // exponent too big, inf + } else { + mantissa = ((x_int & 0x7fffffff) >> 13) - 0x1c000; + } + + mantissa |= sign_bit; + + return static_cast(mantissa & 0xFFFF); +#endif + } + + uint16_t val_; +}; + +// Storage class for bfloat16 values, not meant to be used directly for +// computation. Used for converting to/from float32. +class bfloat16 { + public: + bfloat16() = default; + explicit bfloat16(float x) : val_(float_to_bfloat16(x)) {} + explicit bfloat16(uint16_t x) : val_(x) {} + static constexpr int kMantissaBits = 7; + + explicit operator float() const { return bfloat16_to_float(val_); } + + private: + inline uint16_t float_to_bfloat16(float x) const { + uint32_t as_int; + std::memcpy(&as_int, &x, sizeof(float)); + return as_int >> 16; + } + + inline float bfloat16_to_float(uint32_t as_int) const { + as_int <<= 16; + float x; + std::memcpy(&x, &as_int, sizeof(float)); + return x; + } + + uint16_t val_; +}; + +template +struct IsCustomFloatType + : std::integral_constant::value || + std::is_same::value> {}; +template +struct IsAnyFloatType + : std::integral_constant::value || + IsCustomFloatType::value> {}; + +} // namespace csrblocksparse + +#endif // LYRA_CODEC_SPARSE_MATMUL_NUMERICS_FLOAT16_TYPES_H_ diff --git a/sparse_matmul/numerics/test_utils.h b/sparse_matmul/numerics/test_utils.h new file mode 100644 index 00000000..7f33a4fe --- /dev/null +++ b/sparse_matmul/numerics/test_utils.h @@ -0,0 +1,75 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LYRA_CODEC_SPARSE_MATMUL_NUMERICS_TEST_UTILS_H_ +#define LYRA_CODEC_SPARSE_MATMUL_NUMERICS_TEST_UTILS_H_ + +#include +#include +#include + +#include "gtest/gtest.h" +#include "sparse_matmul/numerics/type_utils.h" + +namespace csrblocksparse { + +// Computes the relative difference between two floating point numbers +// std::abs(b - a) / a. If the a is < 10 * epsilon, then use the absolute +// difference instead of the relative one. +template +T RelDiff(T a, T b) { + static_assert(std::is_floating_point::value, + "RelDiff should only be used on floating point types."); + if (std::abs(a) < 600 * std::numeric_limits::epsilon()) { + return std::abs(b - a); + } + return std::abs((b - a) / a); +} + +// Compares two CacheAlignedVectors elementwise, checks if each pair passes a +// RelDiff check. The result of RelDiff is scaled by the log of the size of the +// column to account for increasing summation errors as the number of summands +// increases. +template +void CheckResult(const VectorType& lhs, const VectorType& rhs, int columns) { + ASSERT_EQ(lhs.size(), rhs.size()); + float epsilon = + 1.0f / + (1 << (MantissaBitsOf::value - 1)); + + // if we're summing a large number of values, then we can relax the tolerance + float log_scale = std::max(1.f, logf(columns)); + + // The tolerance is so large because it is a relative tolerance used to test + // numbers that are close to zero at the limit of the resolution of the + // representation. It would probably be better to focus on an absolute + // tolerance, based on the epsilon above. + const float tolerance = 0.026f; + for (int i = 0; i < lhs.size(); ++i) { + float lhs_value = static_cast(lhs.data()[i]); + float rhs_value = static_cast(rhs.data()[i]); + // If the absolute difference is no more than the epsilon for the + // representation, then it is OK. + if (std::abs(lhs_value - rhs_value) <= epsilon) continue; + float rel_diff = RelDiff(lhs_value, rhs_value) / log_scale; + EXPECT_LT(rel_diff, tolerance) << i % columns << " " << i / columns << " " + << lhs_value << " " << rhs_value; + } +} + +} // namespace csrblocksparse + +#endif // LYRA_CODEC_SPARSE_MATMUL_NUMERICS_TEST_UTILS_H_ diff --git a/sparse_matmul/numerics/type_utils.h b/sparse_matmul/numerics/type_utils.h new file mode 100644 index 00000000..51291abe --- /dev/null +++ b/sparse_matmul/numerics/type_utils.h @@ -0,0 +1,89 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LYRA_CODEC_SPARSE_MATMUL_NUMERICS_TYPE_UTILS_H_ +#define LYRA_CODEC_SPARSE_MATMUL_NUMERICS_TYPE_UTILS_H_ + +// A collection of useful utilities for determining types based on other types. + +#include + +#include "sparse_matmul/numerics/fixed_types.h" +#include "sparse_matmul/numerics/float16_types.h" + +namespace csrblocksparse { + +// Basic idea is that any two float types yield a float, fixed16 types +// yield a fixed32 with the exponent bits summed. Other options are not +// allowed. +template +struct TypeOfProduct {}; + +template +struct TypeOfProduct< + LhsType, RhsType, + typename std::enable_if::value && + IsAnyFloatType::value>::type> { + using type = float; +}; + +template +struct TypeOfProduct< + LhsType, RhsType, + typename std::enable_if::value && + IsFixed16Type::value>::type> { + static_assert(LhsType::kMantissaBits + RhsType::kMantissaBits < 31, + "Sum of mantissa bits must not exceed 31."); + using type = fixed32<31 - LhsType::kMantissaBits - RhsType::kMantissaBits>; +}; + +// Given a weight type T, determine what the RhsType should be for that type. +// bfloat16 / fp16 -> float; fixed16 = fixed16 +template +struct RhsTypeIs { + using type = float; +}; + +template +struct RhsTypeIs::value>::type> { + using type = T; +}; + +template +struct MantissaBitsOf { + // Although int types have zero mantissa bits, use 1 to avoid division by 0. + static constexpr int value = 1; +}; + +template +struct MantissaBitsOf< + T, typename std::enable_if::value || + IsCustomFloatType::value>::type> { + public: + static constexpr int value = T::kMantissaBits; +}; + +template +struct MantissaBitsOf< + T, typename std::enable_if::value>::type> { + public: + // Ignoring the fact that doubles have more mantissa bits. + static constexpr int value = 24; +}; + +} // namespace csrblocksparse + +#endif // LYRA_CODEC_SPARSE_MATMUL_NUMERICS_TYPE_UTILS_H_ diff --git a/sparse_matmul/os/BUILD b/sparse_matmul/os/BUILD new file mode 100644 index 00000000..9c9d7684 --- /dev/null +++ b/sparse_matmul/os/BUILD @@ -0,0 +1,26 @@ +# Modules that interact with the operating system, and have no other dependencies. + +licenses(["notice"]) + +cc_library( + name = "coop_threads", + srcs = ["coop_threads.cc"], + hdrs = ["coop_threads.h"], + visibility = ["//sparse_matmul:__subpackages__"], + deps = [ + "@com_google_absl//absl/memory", + "@com_google_glog//:glog", + ], +) + +cc_test( + name = "coop_threads_test", + size = "small", + srcs = [ + "coop_threads_test.cc", + ], + deps = [ + ":coop_threads", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/sparse_matmul/os/coop_threads.cc b/sparse_matmul/os/coop_threads.cc new file mode 100644 index 00000000..ece0995d --- /dev/null +++ b/sparse_matmul/os/coop_threads.cc @@ -0,0 +1,63 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "sparse_matmul/os/coop_threads.h" + +#include + +namespace csrblocksparse { + +// All threads must execute a std::memory_order_seq_cst operation on +// |barrier_step_| this is what ensures the global memory consistency across +// the barrier. +// +// It is possible for the |barrier_step_| to roll over, but this is safe here. +// +// |yield| instructs the processor that it is in a spin loop and can stop doing +// things like out of order, speculative execution, prefetching, etc. On hyper +// threaded machines it can also choose to swap in the other thread. Note that +// this is a hardware level decision and the OS is never involved. +void SpinBarrier::barrier() { + if (num_threads_ < 2) return; + + int old_step = barrier_step_.load(std::memory_order_relaxed); + + int val_threads = threads_at_barrier_.fetch_add(1, std::memory_order_acq_rel); + + if (val_threads == num_threads_ - 1) { + // This is where the logic can go all wrong if the barrier is called by + // more threads than |num_threads_| -- the assumption that we're the last + // thread is inherently invalid. + + // Assuming num_threads_ are calling this barrier, then we're the last + // thread to reach the barrier, reset and advance step count. + threads_at_barrier_.store(0, std::memory_order_relaxed); + barrier_step_.store(old_step + 1, std::memory_order_release); + } else { + // Wait for step count to advance, then continue. + while (barrier_step_.load(std::memory_order_acquire) == old_step) { + // Intel recommends the equivalent instruction PAUSE, not be called more + // than once in a row, I can't find any recommendations for ARM, so + // following that advice here. +#if defined __aarch64__ || defined __arm__ + asm volatile("yield\n" ::: "memory"); +#else + // No pause for x86! The pause instruction on Skylake takes 141 clock + // cycles, which in an AVX2-down-clocked CPU is getting on for 70ns. +#endif + } + } +} + +} // namespace csrblocksparse diff --git a/sparse_matmul/os/coop_threads.h b/sparse_matmul/os/coop_threads.h new file mode 100644 index 00000000..9aefa614 --- /dev/null +++ b/sparse_matmul/os/coop_threads.h @@ -0,0 +1,179 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LYRA_CODEC_SPARSE_MATMUL_OS_COOP_THREADS_H_ +#define LYRA_CODEC_SPARSE_MATMUL_OS_COOP_THREADS_H_ + +#include +#include // NOLINT +#include + +#define _COOP_THREADS_USE_STD_THREAD 1 + +#include "absl/memory/memory.h" +#include "glog/logging.h" + +namespace csrblocksparse { + +// A re-usable barrier. Keeps threads in extremely tight sync without +// relinquishing control. All memory writes _before_ this barrier are visible +// to all threads _after_ this barrier. Similar in spirit to +// pthreads_barrier. If you expect arrival times at this barrier to be varied +// by more than microseconds, this is probably not the right synchronization +// primitive for you. If |num_threads| exceeds the number of physical threads +// that can run simultaneously, then using this is certainly a bad idea +// (although it should still be correct). +// +// Callers MUST NOT call barrier from more threads than |num_threads|. The +// result is undefined behavior. +class SpinBarrier { + public: + explicit SpinBarrier(int num_threads) + : num_threads_(num_threads), threads_at_barrier_(0), barrier_step_(0) {} + + void barrier(); + + private: + const int num_threads_; + std::atomic threads_at_barrier_; + std::atomic barrier_step_; // unsigned to make overflow defined. +}; + +// Producer-consumer API using the same underlying mechanism as SpinBarrier. +// This class is intended to allow >=1 producers to produce data for >=1 +// consumers, without blocking the producers. +// The consumer will block if it is ready before all the producer(s) have +// produced. +// WARNING: By design this lock does not work without some other barrier that +// prevents any producer from producing again, or consumer from consuming again +// until all consumers have consumed. Basically any loop that uses +// ProducerConsumer must have at least two consume() calls in each thread (on +// different instances) in order for the lock to work correctly. +class ProducerConsumer { + public: + ProducerConsumer(int num_producers, int num_consumers) + : num_producers_(num_producers), + num_consumers_(num_consumers), + producers_ready_(0), + consumers_passed_(0) {} + + // Indicates that the data produced by this thread is ready. Does NOT block. + // NOTE that some other lock must exist between the call to this produce and + // looping back to call produce again on the same ProducerConsumer, that + // depends on all consumers having called consume. One such candidate would + // be a call to SpinBarrier above by all producers and consumers. + // Another candidate would be a separate ProducerConsumer object in which + // these producers consume some data produced by the threads that consume + // the data produced here. Eg. + // tid 0 1 2 3 + // action 1 produce produce consume consume (on ProducerConsumer 1) + // action 2 consume consume produce produce (on ProducerConsumer 2) + // action 3 produce produce consume consume (on ProducerConsumer 3) + // action 4 consume consume produce produce (on ProducerConsumer 4) + // loop back to action 1. + // NOTE: It is inadequate to loop back after action2, as thread 0 could loop + // back and consume again on PC2 while thread 1 is still completing its call + // to consume. It is still inadequate to loop back after action 3 for the same + // reason (but tsan doesn't seem to pick this up.) + inline void produce() { + producers_ready_.fetch_add(1, std::memory_order_acq_rel); + } + + // Waits if necessary for all producers to have produced before proceeding. + // The ProducerConsumer cannot be reused until all consumers have consumed. + // See detailed comment and example on produce(). + inline void consume() { + // We can't do anything until all the producers have produced. + while (producers_ready_.load(std::memory_order_acquire) < num_producers_) { +#if defined __aarch64__ || defined __arm__ + asm volatile("yield\n" ::: "memory"); +#else + // No pause for x86! The pause instruction on Skylake takes 141 clock + // cycles, which in an AVX2-down-clocked CPU is getting on for 70ns. +#endif + } + // NOTE: It is tempting to move this fetch_add to before the wait loop to + // reduce contention for the memory location, but that would break the lock, + // as then the last to arrive could zero out the producers_ready before the + // other consumers have noticed that all producers have produced. + // With the fetch_add after the wait loop, we are guaranteed that all + // producers have produced AND all consumers have noticed that they have + // produced before we zero out the counters. + int consumers = consumers_passed_.fetch_add(1, std::memory_order_acq_rel); + if (consumers == num_consumers_ - 1) { + // The last consumer to pass has to reset everything for the next time. + producers_ready_.store(0, std::memory_order_relaxed); + consumers_passed_.store(0, std::memory_order_relaxed); + } + } + int num_producers() const { return num_producers_; } + int num_consumers() const { return num_consumers_; } + + private: + const int num_producers_; + const int num_consumers_; + std::atomic producers_ready_; + std::atomic consumers_passed_; +}; + +// We define Thread here, so we can easily change its type later. + +using Thread = std::thread; +using ThreadId = std::thread::id; + +// Creates (|num_threads|-1) threads and executes a total of |num_threads| +// copies of |func| (executes one on the calling thread). +// +// Useful for long running func bodies that are intended to run in lock step. +// A possible use case for this style parallelism over a thread pool is when +// we want tight control over which memory is resident in the L2 cache of a +// processor. With a pool we have no control over which thread gets assigned +// which portion of the computation resulting in L2 thrashing. With this +// breakdown we can make sure each thread only acceses a specific L2-sized +// portion of memory. +// +// func's signature must be (SpinBarrier*, int thread_id, ...); +template +void LaunchOnThreadsWithBarrier(int num_threads, Function&& func, + Args&&... args) { + SpinBarrier spin_barrier(num_threads); + + std::vector> threads; + threads.reserve(num_threads); + for (int tid = 1; tid < num_threads; ++tid) { + auto f = [&, tid]() { func(&spin_barrier, tid, args...); }; + + threads.emplace_back(absl::make_unique(f)); +#ifndef _COOP_THREADS_USE_STD_THREAD + CHECK_OK(threads.back()->Start()); +#endif + } + + const int kLocalTid = 0; + func(&spin_barrier, kLocalTid, args...); + + for (auto& thread : threads) { +#ifdef _COOP_THREADS_USE_STD_THREAD + thread->join(); +#else + CHECK_OK(thread->Join()); +#endif + } +} + +} // namespace csrblocksparse + +#endif // LYRA_CODEC_SPARSE_MATMUL_OS_COOP_THREADS_H_ diff --git a/sparse_matmul/os/coop_threads_test.cc b/sparse_matmul/os/coop_threads_test.cc new file mode 100644 index 00000000..0aba27f9 --- /dev/null +++ b/sparse_matmul/os/coop_threads_test.cc @@ -0,0 +1,134 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "sparse_matmul/os/coop_threads.h" + +#include +#include +#include + +#include "gtest/gtest.h" + +TEST(Threads, LaunchThreads) { + std::atomic counter(0); + + auto f = [&](csrblocksparse::SpinBarrier* barrier, int tid) { + counter.fetch_add(tid); + }; + + const int kNumThreads = 10; + csrblocksparse::LaunchOnThreadsWithBarrier(kNumThreads, f); + + ASSERT_EQ(counter.load(), kNumThreads * (kNumThreads - 1) / 2); +} + +TEST(Threads, SpinBarrier) { + const int kNumThreads = 10; + + std::vector tids(kNumThreads, 0); + std::vector> expected; + for (int i = 0; i < 10; ++i) { + expected.emplace_back(kNumThreads); + std::iota(expected.back().begin(), expected.back().end(), 0); + std::transform(expected.back().begin(), expected.back().end(), + expected.back().begin(), + [i](int x) -> int { return (i + 1) * x; }); + } + + auto f = [&](csrblocksparse::SpinBarrier* barrier, int tid) { + for (int i = 0; i < 10; ++i) { + tids[tid] += tid; + barrier->barrier(); + EXPECT_EQ(tids, expected[i]); + barrier->barrier(); + } + }; + + csrblocksparse::LaunchOnThreadsWithBarrier(kNumThreads, f); +} + +TEST(Threads, ProducerConsumer) { + constexpr int kNumThreads = 4; + constexpr int kNumIterations = 10; + + std::vector shared_data(kNumThreads, 0); + std::vector> expected; + for (int i = 1; i <= kNumIterations; ++i) { + // Execute the parallel work sequentially. + // Last two threads write their id * iteration. + std::pair inputs = + std::make_pair((kNumThreads - 2) * i, (kNumThreads - 1) * i); + // First two threads compute sum and difference of those values. + std::pair diffs = std::make_pair(inputs.first + inputs.second, + inputs.first - inputs.second); + // Last two threads compute sum and product. + std::pair sums = + std::make_pair(diffs.first + diffs.second, diffs.first * diffs.second); + // First two threads compute product and difference of those values. + expected.emplace_back( + std::make_pair(sums.first * sums.second, sums.first - sums.second)); + // Last two threads will check for the correct result. + } + csrblocksparse::ProducerConsumer first_pc(2, 2); + csrblocksparse::ProducerConsumer second_pc(2, 2); + csrblocksparse::ProducerConsumer third_pc(2, 2); + csrblocksparse::ProducerConsumer fourth_pc(2, 2); + + auto f = [&](csrblocksparse::SpinBarrier* barrier, int tid) { + for (int i = 1; i <= kNumIterations; ++i) { + if (tid == kNumThreads - 2) { + // Last two threads write their id * iteration. + shared_data[tid] = tid * i; + first_pc.produce(); + second_pc.consume(); + // They then compute sum and product. + shared_data[tid] = shared_data[0] + shared_data[1]; + third_pc.produce(); + // They finally check the result. + fourth_pc.consume(); + EXPECT_EQ(expected[i - 1].first, shared_data[0]) << "i=" << i; + } else if (tid == kNumThreads - 1) { + shared_data[tid] = tid * i; + first_pc.produce(); + second_pc.consume(); + shared_data[tid] = shared_data[0] * shared_data[1]; + third_pc.produce(); + fourth_pc.consume(); + EXPECT_EQ(expected[i - 1].second, shared_data[1]) << "i=" << i; + } else if (tid == 0) { + // First two threads compute sum and difference. + first_pc.consume(); + shared_data[tid] = + shared_data[kNumThreads - 2] + shared_data[kNumThreads - 1]; + second_pc.produce(); + // They then compute product and difference. + third_pc.consume(); + shared_data[tid] = + shared_data[kNumThreads - 2] * shared_data[kNumThreads - 1]; + fourth_pc.produce(); + } else if (tid == 1) { + first_pc.consume(); + shared_data[tid] = + shared_data[kNumThreads - 2] - shared_data[kNumThreads - 1]; + second_pc.produce(); + third_pc.consume(); + shared_data[tid] = + shared_data[kNumThreads - 2] - shared_data[kNumThreads - 1]; + fourth_pc.produce(); + } + } + }; + + csrblocksparse::LaunchOnThreadsWithBarrier(kNumThreads, f); +} diff --git a/sparse_matmul/sparse_matmul.h b/sparse_matmul/sparse_matmul.h new file mode 100644 index 00000000..dc507278 --- /dev/null +++ b/sparse_matmul/sparse_matmul.h @@ -0,0 +1,34 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LYRA_CODEC_SPARSE_MATMUL_SPARSE_MATMUL_H_ +#define LYRA_CODEC_SPARSE_MATMUL_SPARSE_MATMUL_H_ + +// IWYU pragma: begin_exports +#include "sparse_matmul/compute/gru_gates.h" +#include "sparse_matmul/layers/csr_blocksparse_matrix.h" +#include "sparse_matmul/layers/masked_sparse_matrix.h" +#include "sparse_matmul/layers/sparse_linear_layer.h" +#include "sparse_matmul/layers/utils.h" +#include "sparse_matmul/numerics/fast_transcendentals.h" +#include "sparse_matmul/numerics/fixed_types.h" +#include "sparse_matmul/numerics/float16_types.h" +#include "sparse_matmul/numerics/type_utils.h" +#include "sparse_matmul/os/coop_threads.h" +#include "sparse_matmul/vector/cache_aligned_vector.h" +// IWYU pragma: end_exports + +#endif // LYRA_CODEC_SPARSE_MATMUL_SPARSE_MATMUL_H_ diff --git a/sparse_matmul/vector/BUILD b/sparse_matmul/vector/BUILD new file mode 100644 index 00000000..3fc064a4 --- /dev/null +++ b/sparse_matmul/vector/BUILD @@ -0,0 +1,63 @@ +# Vector that always aligns its data to the cache line of the host machine. + +licenses(["notice"]) + +cc_library( + name = "cache_aligned_vector", + hdrs = [ + "cache_aligned_vector.h", + ], + visibility = [ + "//sparse_matmul:__subpackages__", + ], + deps = [ + ":aligned_malloc", + "//sparse_matmul/numerics:fast_transcendentals", + "//sparse_matmul/numerics:types", + "//sparse_matmul/os:coop_threads", + "@com_google_absl//absl/strings:str_format", + ], +) + +cc_test( + name = "cachealignedvector_test", + size = "small", + srcs = [ + "cachealignedvector_test.cc", + ], + copts = [ + "-DFAST_TRANSCENDENTALS", + "-DSIGMOID_AS_TANH", + ], + deps = [ + ":cache_aligned_vector", + "//sparse_matmul/numerics:test_utils", + "//sparse_matmul/os:coop_threads", + "@com_google_googletest//:gtest_main", + ], +) + +cc_binary( + name = "cachealignedvector_benchmark", + srcs = [ + "cachealignedvector_benchmark.cc", + ], + copts = [ + "-DFAST_TRANSCENDENTALS", + "-DSIGMOID_AS_TANH", + "-DACCURATE_TRANSCENDENTAL_APPROX", + ], + deps = [ + ":cache_aligned_vector", + "@com_github_google_benchmark//:benchmark", + "@com_github_google_benchmark//:benchmark_main", + ], +) + +cc_library( + name = "aligned_malloc", + srcs = ["aligned_malloc.cc"], + hdrs = [ + "aligned_malloc.h", + ], +) diff --git a/sparse_matmul/vector/aligned_malloc.cc b/sparse_matmul/vector/aligned_malloc.cc new file mode 100644 index 00000000..410d268e --- /dev/null +++ b/sparse_matmul/vector/aligned_malloc.cc @@ -0,0 +1,46 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +namespace csrblocksparse { + +void Free(void* ptr) { free(ptr); } + +void* Malloc(size_t size) { return malloc(size); } + +void aligned_free(void* aligned_memory) { Free(aligned_memory); } + +void* aligned_malloc(size_t size, int minimum_alignment) { +#if defined(__ANDROID__) + return memalign(minimum_alignment, size); +#else // !defined(__ANDROID__) + void* ptr = nullptr; + // posix_memalign requires that the requested alignment be at least + // sizeof(void*). In this case, fall back on malloc which should return + // memory aligned to at least the size of a pointer. + const int required_alignment = sizeof(void*); + if (minimum_alignment < required_alignment) return Malloc(size); + int err = posix_memalign(&ptr, minimum_alignment, size); + if (err != 0) { + return nullptr; + } else { + return ptr; + } +#endif +} + +} // namespace csrblocksparse diff --git a/sparse_matmul/vector/aligned_malloc.h b/sparse_matmul/vector/aligned_malloc.h new file mode 100644 index 00000000..ff13d939 --- /dev/null +++ b/sparse_matmul/vector/aligned_malloc.h @@ -0,0 +1,32 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LYRA_CODEC_SPARSE_MATMUL_VECTOR_ALIGNED_MALLOC_H_ +#define LYRA_CODEC_SPARSE_MATMUL_VECTOR_ALIGNED_MALLOC_H_ + +#include +namespace csrblocksparse { + +void Free(void* ptr); + +void* Malloc(size_t size); + +void aligned_free(void* aligned_memory); + +void* aligned_malloc(size_t size, int minimum_alignment); +} // namespace csrblocksparse + +#endif // LYRA_CODEC_SPARSE_MATMUL_VECTOR_ALIGNED_MALLOC_H_ diff --git a/sparse_matmul/vector/cache_aligned_vector.h b/sparse_matmul/vector/cache_aligned_vector.h new file mode 100644 index 00000000..871298d2 --- /dev/null +++ b/sparse_matmul/vector/cache_aligned_vector.h @@ -0,0 +1,1117 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LYRA_CODEC_SPARSE_MATMUL_VECTOR_CACHE_ALIGNED_VECTOR_H_ +#define LYRA_CODEC_SPARSE_MATMUL_VECTOR_CACHE_ALIGNED_VECTOR_H_ + +#if defined __aarch64__ +#include +#endif +#if defined __AVX__ || defined __AVX2__ +#include +#endif + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/strings/str_format.h" +#include "sparse_matmul/numerics/fast_transcendentals.h" +#include "sparse_matmul/numerics/fixed_types.h" +#include "sparse_matmul/numerics/type_utils.h" +#include "sparse_matmul/os/coop_threads.h" +#include "sparse_matmul/vector/aligned_malloc.h" + +namespace csrblocksparse { + +template +class MutableVectorView; +template +class VectorView; + +// CacheAlignedVector is a simple vector-like class that makes sure its +// underlying buffer is aligned to a |kCacheLineSize| boundary. It is meant +// for numeric computation and cannot be used to store objects that are +// not POD as it will neither call their constructors nor destructors. +// +// It is meant to be used with the CSRBlockSparseMatrix class for +// implenting basic neural network layers composed of SpMV. +// +// This class is thread compatible. +template +class CacheAlignedVector { + static_assert(std::is_pod::value, + "CacheAlignedVector can only be" + " used with POD"); + + public: + using value_type = DataType; + + explicit CacheAlignedVector(std::size_t size) : size_(size), data_(nullptr) { + gen_ = absl::make_unique(0); + data_ = reinterpret_cast( + aligned_malloc(size_ * sizeof(DataType), kCacheLineSize)); + } + + explicit CacheAlignedVector(const std::vector& input) + : size_(input.size()), data_(nullptr) { + gen_ = absl::make_unique(0); + data_ = reinterpret_cast( + aligned_malloc(size_ * sizeof(DataType), kCacheLineSize)); + memcpy(data_, input.data(), size_ * sizeof(DataType)); + } + + template + explicit CacheAlignedVector(const std::vector& input) + : size_(input.size()), data_(nullptr) { + gen_ = absl::make_unique(0); + data_ = reinterpret_cast( + aligned_malloc(size_ * sizeof(DataType), kCacheLineSize)); + for (int i = 0; i < size_; ++i) + data_[i] = static_cast(input.data()[i]); + } + + CacheAlignedVector(const DataType* input, int size) + : size_(size), data_(nullptr) { + gen_ = absl::make_unique(0); + data_ = reinterpret_cast( + aligned_malloc(size_ * sizeof(DataType), kCacheLineSize)); + memcpy(data_, input, size_ * sizeof(DataType)); + } + + template + explicit CacheAlignedVector(const InputType* input, int size) + : size_(size), data_(nullptr) { + gen_ = absl::make_unique(0); + data_ = reinterpret_cast( + aligned_malloc(size_ * sizeof(DataType), kCacheLineSize)); + for (int i = 0; i < size_; ++i) data_[i] = static_cast(input[i]); + } + + CacheAlignedVector() : size_(0), data_(nullptr) {} + + ~CacheAlignedVector() { + aligned_free(data_); + data_ = nullptr; + size_ = 0; + } + + // Copies are _deep_ copies + CacheAlignedVector(CacheAlignedVector const& other) + : size_(0), data_(nullptr), gen_(nullptr) { + if (other.gen_) + gen_ = absl::make_unique(std::minstd_rand(*other.gen_)); + this->resize(other.size()); + memcpy(data_, other.data(), size_ * sizeof(DataType)); + } + // Copies a slice of the input. + CacheAlignedVector(CacheAlignedVector const& other, int start, int end) + : size_(0), data_(nullptr), gen_(nullptr) { + if (other.gen_) + gen_ = absl::make_unique(std::minstd_rand(*other.gen_)); + this->resize(end - start); + memcpy(data_, other.data() + start, size_ * sizeof(DataType)); + } + + void operator=(CacheAlignedVector const& other) { + if (other.gen_) + gen_ = absl::make_unique(std::minstd_rand(*other.gen_)); + else + gen_.reset(nullptr); + this->resize(other.size()); + memcpy(data_, other.data(), size_ * sizeof(DataType)); + } + + CacheAlignedVector(CacheAlignedVector&& other) + : size_(0), data_(nullptr), gen_(std::move(other.gen_)) { + size_ = other.size_; + data_ = other.data_; + other.size_ = 0; + other.data_ = nullptr; + } + + CacheAlignedVector& operator=( + CacheAlignedVector&& other) { + aligned_free(data_); + if (other.gen_) + gen_ = absl::make_unique(std::move(*other.gen_)); + else + gen_.reset(nullptr); + size_ = other.size_; + data_ = other.data_; + other.size_ = 0; + other.data_ = nullptr; + return *this; + } + + VectorView AsView() const { + return VectorView(this->data(), this->size(), 1); + } + + MutableVectorView AsMutableView() { + return MutableVectorView(this->data(), this->size(), 1); + } + + // Copies the |split_points| to use in ReducingSample. + void PrepareForThreads(const std::vector& split_points, + int block_height) { + maxes_.resize(split_points.size() - 1); + thread_starts_ = split_points; + for (int t = 0; t < thread_starts_.size(); ++t) { + thread_starts_[t] *= block_height; + } + } + + void FillRandom(float min = -10.f, float max = 10.f) { + // 10 is smaller than any nonzero bound of the range of any data type. + std::uniform_real_distribution dist(min, max); + for (std::size_t i = 0; i < size_; i++) { + data_[i] = DataType(dist(*gen_)); + } + } + + void FillZero() { + for (std::size_t i = 0; i < size_; i++) { + data_[i] = DataType(0.f); + } + } + + void FillOnes() { + for (std::size_t i = 0; i < size_; i++) { + data_[i] = DataType(1.f); + } + } + + void FillWith(const DataType& value) { + for (std::size_t i = 0; i < size_; i++) { + data_[i] = value; + } + } + + // Interprets |data_| as logits and samples from the distribution, this + // version operates IN PLACE and uses an internal random source. + template + typename std::enable_if::value, int>::type Sample( + float temperature = 1.f) { + return Sample(temperature, gen_.get(), this); + } + + // Interprets |data_| as logits and samples. This version requires the random + // source and temporary memory to be passed in. It is thread safe assuming + // no other threads are using the generator and temporary memory. +#if defined __aarch64__ + template + typename std::enable_if::value, int>::type Sample( + float temperature, std::minstd_rand* gen, + CacheAlignedVector* scratch) const { + DCHECK(scratch->size() >= size_); + // Round down to nearest multiple of 8. + int SIMD_iterations = 8 * (size_ / 8); + float* scratch_ptr = scratch->data(); + std::uniform_real_distribution dist; + float random_number = dist(*gen); + + float32x4_t sum = vdupq_n_f32(0.f); + float32x4_t sum1 = vdupq_n_f32(0.f); + float32x4_t max_value = vdupq_n_f32(std::numeric_limits::lowest()); + float32x4_t max_value1 = vdupq_n_f32(std::numeric_limits::lowest()); + float32x4_t inv_temp = vdupq_n_f32(1.f / temperature); + // Compute sum of exp(x) for the denominator. + // Hand unroll by 2, gives speed improvement. + constexpr int kUnrollFactor = 2; + constexpr int kElementsPerIter = kUnrollFactor * kSIMDWidth; + for (std::size_t i = 0; i < SIMD_iterations; i += kElementsPerIter) { + max_value = vmaxq_f32(vld1q_f32(data_ + i), max_value); + max_value1 = vmaxq_f32(vld1q_f32(data_ + i + 4), max_value1); + } + + // Pairwise reduction. + max_value = vpmaxq_f32(max_value, max_value1); + // Duplicate (dupq) maximum across vector (maxnmvq). + float scalar_max_value = vmaxvq_f32(max_value); + + for (int i = SIMD_iterations; i < size_; ++i) { + scalar_max_value = std::max(data_[i], scalar_max_value); + } + + max_value = vdupq_n_f32(scalar_max_value); + + for (std::size_t i = 0; i < SIMD_iterations; i += kElementsPerIter) { + // Load and multiply by temperature. + float32x4_t x = + vmulq_f32(vsubq_f32(vld1q_f32(data_ + i), max_value), inv_temp); + float32x4_t x1 = + vmulq_f32(vsubq_f32(vld1q_f32(data_ + i + 4), max_value), inv_temp); + + float32x4_t exponent = fast_exp(x); + float32x4_t exponent1 = fast_exp(x1); + + sum = vaddq_f32(sum, exponent); + sum1 = vaddq_f32(sum1, exponent1); + + vst1q_f32(scratch_ptr + i, exponent); + vst1q_f32(scratch_ptr + i + 4, exponent1); + } + + // Horizontally reduce the two sums. + sum = vpaddq_f32(sum, sum1); + sum = vpaddq_f32(sum, sum); + float denom = vgetq_lane_f32(sum, 0) + vgetq_lane_f32(sum, 1); + + for (int i = SIMD_iterations; i < size_; ++i) { + float x = (data_[i] - scalar_max_value) / temperature; + float x_exp = expf(x); + denom += x_exp; + scratch_ptr[i] = x_exp; + } + + // Note: rather than normalize all the probabilities, we can just + // apply the inverse normalization to the random number. + random_number *= denom; + + // Now do the scan in serial, return as soon as possible. + // TODO(b/188821456): This could be made into a parallel SIMD scan + // followed by a binary search, for a small speedup. + float cumsum = 0.f; + for (std::size_t i = 0; i < size_; i++) { + cumsum += scratch_ptr[i]; + if (cumsum >= random_number) return i; + } + return size_ - 1; + } + + template + static inline int32x4_t vmul_temp_fixed(int32x4_t x, int32x2_t inv_temp) { + int32x2_t xh = vget_high_s32(x); + int32x2_t xl = vget_low_s32(x); + int32x2_t ph = vqrshrn_n_s64(vmull_s32(xh, inv_temp), Q::kMantissaBits); + int32x2_t pl = vqrshrn_n_s64(vmull_s32(xl, inv_temp), Q::kMantissaBits); + return vcombine_s32(pl, ph); + } + + template + static inline int float_to_fixed(float x) { + return static_cast(x * (1 << Q::kMantissaBits)); + } + + template + static inline float fixed_to_float(int x) { + const float inv_denom = 1.f / (1 << Q::kMantissaBits); + return static_cast(x) * inv_denom; + } + + template + typename std::enable_if::value, int>::type Sample( + float temperature, std::minstd_rand* gen, + CacheAlignedVector* scratch) const { + DCHECK(scratch->size() >= size_); + // Round down to nearest multiple of 8. + int SIMD_iterations = 8 * (size_ / 8); + int* scratch_ptr = scratch->data(); + float scalar_inv_temp = 1.f / temperature; + + int32x4_t sum = vdupq_n_s32(0); + int32x4_t sum1 = vdupq_n_s32(0); + int32x4_t max_value = vdupq_n_s32(std::numeric_limits::lowest()); + int32x4_t max_value1 = vdupq_n_s32(std::numeric_limits::lowest()); + int32x2_t inv_temp = vdup_n_s32(float_to_fixed(scalar_inv_temp)); + // Compute sum of exp(x) for the denominator. + // Hand unroll by 2, gives speed improvement. + + const int* data_ptr = reinterpret_cast(data_); + constexpr int kUnrollFactor = 2; + constexpr int kElementsPerIter = kUnrollFactor * kSIMDWidth; + for (std::size_t i = 0; i < SIMD_iterations; i += kElementsPerIter) { + max_value = vmaxq_s32(vld1q_s32(data_ptr + i), max_value); + max_value1 = vmaxq_s32(vld1q_s32(data_ptr + i + kSIMDWidth), max_value1); + } + + // Pairwise reduction. + max_value = vpmaxq_s32(max_value, max_value1); + int scalar_max_value = vmaxvq_s32(max_value); + + for (int i = SIMD_iterations; i < size_; ++i) { + scalar_max_value = std::max(data_[i].raw_val(), scalar_max_value); + } + max_value = vdupq_n_s32(scalar_max_value); + // We clip all loaded values to a lower bound of the lowest possible arg to + // exp + the max value that we are going to subtract, to prevent underflow + // in exp and also to avoid wrap-around with values that are already minint. + int32x4_t clip_min = + vdupq_n_s32(scalar_max_value - (80 << MantissaBitsOf::value)); + + for (std::size_t i = 0; i < SIMD_iterations; i += kElementsPerIter) { + // Load and multiply by temperature. + int32x4_t loaded = vmaxq_s32(vld1q_s32(data_ptr + i), clip_min); + int32x4_t x = vmul_temp_fixed(vsubq_s32(loaded, max_value), inv_temp); + loaded = vmaxq_s32(vld1q_s32(data_ptr + i + kSIMDWidth), clip_min); + int32x4_t x1 = vmul_temp_fixed(vsubq_s32(loaded, max_value), inv_temp); + + int32x4_t exponent = vcvtq_n_s32_f32(fast_exp_fixed(x), + Q::kMantissaBits); + int32x4_t exponent1 = vcvtq_n_s32_f32( + fast_exp_fixed(x1), Q::kMantissaBits); + + sum = vaddq_s32(sum, exponent); + sum1 = vaddq_s32(sum1, exponent1); + + vst1q_s32(scratch_ptr + i, exponent); + vst1q_s32(scratch_ptr + i + kSIMDWidth, exponent1); + } + + // Horizontally reduce the two sums. + sum = vpaddq_s32(sum, sum1); + sum = vpaddq_s32(sum, sum); + float denom = + fixed_to_float(vgetq_lane_s32(sum, 0) + vgetq_lane_s32(sum, 1)); + for (int i = SIMD_iterations; i < size_; ++i) { + float x_exp = fast_exp_fixed( + DataType((data_[i].raw_val() - scalar_max_value) * scalar_inv_temp)); + + denom += x_exp; + scratch_ptr[i] = float_to_fixed(x_exp); + } + + // Note: rather than normalize all the probabilities, we can just + // apply the inverse normalization to the random number. + std::uniform_real_distribution dist; + int random_number = float_to_fixed(dist(*gen) * denom); + + // Now do the scan in serial, return as soon as possible. + // TODO(b/188821456): This could be made into a parallel SIMD scan + // followed by a binary search, for a small speedup. + int cumsum = 0; + for (std::size_t i = 0; i < size_; i += kSIMDWidth) { + int32x4_t next_vals = vld1q_s32(&scratch_ptr[i]); + cumsum += vaddvq_s32(next_vals); + if (cumsum >= random_number) { + int high_sum = vaddv_s32(vget_high_s32(next_vals)); + if (cumsum - high_sum > random_number) { + // One of the lower ones. + return (cumsum - high_sum - scratch_ptr[i + 1] > random_number) + ? i + : i + 1; + } else { + // One of the upper ones. + return (cumsum - scratch_ptr[i + 3] > random_number) ? i + 2 : i + 3; + } + } + } + return size_ - 1; + } +#endif // defined __aarch64__ + + template +#if defined __aarch64__ + typename std::enable_if< + !std::is_same::value && !IsFixed32Type::value, int>::type +#else + int +#endif + Sample(float temperature, std::minstd_rand* gen, + CacheAlignedVector* scratch, int tid = 0, + SpinBarrier* barrier = nullptr) const { + return ScalarSample(temperature, gen, scratch, tid, 0, -1, barrier); + } + + int ScalarSample(float temperature, std::minstd_rand* gen, + CacheAlignedVector* scratch, int tid = 0, + const int mindex = 0, const int maxdex = -1, + SpinBarrier* barrier = nullptr) const { + // TODO(b/188821456) Don't ignore |tid| and |barrier|. Currently all threads + // duplicate the same work and ignore |tid| and |barrier|, but they could + // be used to execute a reducing max over the data before the exp operation. + DCHECK_EQ(barrier, nullptr); + DCHECK_EQ(tid, 0); + DCHECK(scratch->size() >= size_); + DCHECK(size_ % 8 == 0) << "CacheAlignedVector size must be a multiple of " + "8 to allow for maximum SIMD and loop unroll, " + "got " + << size_ % 8; + DCHECK(size_ > mindex >= 0); + DCHECK((maxdex == -1) || (0 <= mindex < maxdex < size_)); + int maxindex = maxdex > 0 ? maxdex : size_; + + float* scratch_ptr = scratch->data(); + std::uniform_real_distribution dist; + float random_number = dist(*gen); + + float sum = 0.f; + float max_value = std::numeric_limits::lowest(); + for (int i = mindex; i < maxindex; ++i) { + max_value = std::max(max_value, static_cast(data_[i])); + } + float inv_temperature = 1.f / temperature; + for (int i = mindex; i < maxindex; ++i) { + float exponent = fast_exp((static_cast(data_[i]) - max_value) * + inv_temperature); + scratch_ptr[i] = exponent; + sum += exponent; + } + + // Note: rather than normalize all the probabilities, we can just + // apply the inverse normalization to the random number. + random_number *= sum; + + float cumsum = 0.f; + for (std::size_t i = mindex; i < maxindex; i++) { + cumsum += scratch_ptr[i]; + if (cumsum >= random_number) return i; + } + return maxindex - 1; + } + +#if defined __AVX2__ + // Some AVX2-only code. + // Returns the max of |data_| in the range [|t_start|, |t_end|). + inline int ThreadMax(int t_start, int t_end) const { + // Note: The AVX2 code requires that the number of threads and the output + // size be a power of 2. For efficiency purposes, these should be checked + // when preparing for threads in an architecture class. + // The output size must be a power of 2 so the binary search for the sample + // point works correctly. + // The number of threads must be a power of 2 so that it nicely divides the + // output size, which has to be a power of 2. + __m256i maxes = + _mm256_load_si256(reinterpret_cast<__m256i const*>(data_ + t_start)); + for (int i = t_start + kSIMDWidth; i < t_end; i += kSIMDWidth) { + __m256i data = + _mm256_load_si256(reinterpret_cast<__m256i const*>(data_ + i)); + maxes = _mm256_max_epi32(maxes, data); + } + // Max within the register. + // Bring the top lane down to the bottom. + __m256i other = _mm256_permute4x64_epi64(maxes, 0xe); + maxes = _mm256_max_epi32(maxes, other); + // Bring the 2nd 64 bits to the bottom. + other = _mm256_shuffle_epi32(maxes, 0xe); + maxes = _mm256_max_epi32(maxes, other); + // Bring the 2nd 32 bits to the bottom. + other = _mm256_shuffle_epi32(maxes, 1); + maxes = _mm256_max_epi32(maxes, other); + return _mm256_extract_epi32(maxes, 0); + } + + // Applies exp (approximately) to the difference between |data_| and + // |max_value|, storing the result in scratch, and returns the sum. + template + inline float ApplyExpAndSum(int max_value, float* scratch_ptr) { + // Rough approximation for exp(x). See fast_exp_fixed. + // Constant clipping limit on exp arg. Since its value is never positive, + // we only need to clip on the negative side. + constexpr int kClipLimit = -(80 << kMantissaBits); + __m256i clip_val = _mm256_set1_epi32(kClipLimit); + // Multiplication factor to convert x from log base e to log base 2, shifted + // by an amount that lines up the binary point with the float32 + // representation, after the multiplication + static const int kLogFactor = (1 << (23 - kMantissaBits)) / logf(2.f); + __m256i log_factor = _mm256_set1_epi32(kLogFactor); + // Fix the exponent bias and add the additive fudge factor for the mantissa + // to finish the approximate conversion. + constexpr int kAddConstant = (127 << 23) - 366000; + __m256i constant = _mm256_set1_epi32(kAddConstant); + // Broadcast the max_value. + __m256i max_val = _mm256_set1_epi32(max_value); + // Add the max to the |clip_val|, so it can be used before the subtraction. + clip_val = _mm256_add_epi32(clip_val, max_val); + // The sum of the exps. + __m256 sum1 = _mm256_setzero_ps(); + for (int i = 0; i < size_; i += kSIMDWidth) { + // |data_| - |max_value|. + __m256i data = + _mm256_load_si256(reinterpret_cast<__m256i const*>(data_ + i)); + // Clip to negative limit before the subtraction of |max_val| to avoid + // wrap-around with min-int values. + data = _mm256_max_epi32(data, clip_val); + __m256i difference = _mm256_sub_epi32(data, max_val); + // Exponent trick exp. + // Multiply by |log_factor|, keeping only the lower 32 bits. + difference = _mm256_mullo_epi32(difference, log_factor); + // Add the constant. + difference = _mm256_add_epi32(difference, constant); + // Reinterpret the results as float32. + __m256 float_exp = _mm256_castsi256_ps(difference); + // Sum the results and save to scratch space. + _mm256_store_ps(scratch_ptr + i, float_exp); + sum1 = _mm256_add_ps(sum1, float_exp); + } + // Horizontally add the 8 values in sum. + // Get the top lane down to the bottom. + __m256 sum2 = _mm256_permute2f128_ps(sum1, sum1, 1); + sum1 = _mm256_add_ps(sum1, sum2); + sum1 = _mm256_hadd_ps(sum1, sum1); + sum1 = _mm256_hadd_ps(sum1, sum1); + return _mm256_cvtss_f32(sum1); + } + + // Binary search for the index where the cumulative sum meets random_target. + inline void FindSamplePoint(const float* scratch_ptr, float* random_target, + int* start, int* end) { + int halfsize = (*end - *start) / 2; + do { + // Sum the first half. + // We sum the section in two independent parts, so we can step down 2 + // levels if we get a hit in this half. + int quartersize = halfsize / (2 * kSIMDWidth); + quartersize *= kSIMDWidth; + halfsize = quartersize * 2; + // The sums of the quarters. + __m256 sum1 = _mm256_setzero_ps(); + __m256 sum2 = _mm256_setzero_ps(); + const float* ptr1 = scratch_ptr + *start; + const float* ptr2 = ptr1 + quartersize; + for (int i = 0; i < quartersize; i += kSIMDWidth) { + __m256 data1 = _mm256_load_ps(ptr1 + i); + __m256 data2 = _mm256_load_ps(ptr2 + i); + sum1 = _mm256_add_ps(sum1, data1); + sum2 = _mm256_add_ps(sum2, data2); + } + // Horizontally add the two sums, keeping the results separate. + // Numbering |sum1|=[0-7] and |sum2|=[8-15]... + sum1 = _mm256_hadd_ps(sum1, sum2); + // |sum1| now has [0+1, 2+3, 8+9, 10+11, 4+5, 6+7, 12+13, 14+15]. + // Bring the top lane down to the bottom. + sum2 = _mm256_permute2f128_ps(sum1, sum1, 1); + sum1 = _mm256_hadd_ps(sum1, sum2); + // Now |sum1| has [0-3, 8-11, 4-7, 12-15], so swap the middle two + // elements. + sum1 = _mm256_shuffle_ps(sum1, sum1, 0xd8); + sum1 = _mm256_hadd_ps(sum1, sum1); + // Now |sum1| has [0-7, 8-15, ....]. + float bottom_quarter = _mm256_cvtss_f32(sum1); + if (bottom_quarter >= *random_target) { + *end = *start + quartersize; + } else { + float bottom_half = _mm256_cvtss_f32(_mm256_hadd_ps(sum1, sum1)); + if (bottom_half >= *random_target) { + *start += quartersize; + *end = *start + quartersize; + *random_target -= bottom_quarter; + } else { + *start += halfsize; + *random_target -= bottom_half; + } + } + halfsize = (*end - *start) / 2; + } while (halfsize >= kSIMDWidth * 2); + } +#endif // __AVX2__ code + + // Fixed32 version. + template + typename std::enable_if::value, int>::type ThreadMax( + int tid) const { + int t_start = thread_starts_[tid]; + int t_end = thread_starts_[tid + 1]; +#if defined __AVX2__ + return ThreadMax(t_start, t_end); +#else + // With operator<, could use std::max_element. + int max_value = data_[t_start].raw_val(); + for (int i = t_start + 1; i < t_end; ++i) { + max_value = std::max(max_value, data_[i].raw_val()); + } + return max_value; +#endif + } + + // As Sample above, except that if |tid| and |barrier| are provided, it will + // save some time by running a local max in each thread before combining them + // and doing the rest of the work duplicated across all threads. + // Fixed32 version. + template + typename std::enable_if::value, int>::type ReducingSample( + std::minstd_rand* gen, CacheAlignedVector* scratch, int tid = 0, + float temperature = 1.0f, SpinBarrier* barrier = nullptr) { + if (barrier != nullptr) barrier->barrier(); + // Sample only accepts tid of 0, as it would ignore it anyway. + // All threads duplicate the same work in this path. + return Sample(temperature, gen, scratch, /*tid=*/0); + } + + template + typename std::enable_if::value, int>::type ReducingSample( + std::minstd_rand* gen, CacheAlignedVector* scratch, int tid = 0, + float temperature = 1.0f, SpinBarrier* barrier = nullptr) { + int max_value; + if (barrier == nullptr) { + // There is only one thread. + max_value = ThreadMax(tid); + } else { + // Reduce max using the threads to do some of the work. + maxes_[tid] = ThreadMax(tid); + barrier->barrier(); + // The rest of the work is duplicated by all threads. + max_value = *std::max_element(maxes_.begin(), maxes_.end()); + } + float* scratch_ptr = scratch->data(); + std::uniform_real_distribution dist; + float sum = 0.0f; +#if defined __AVX2__ + sum = ApplyExpAndSum::value>(max_value, scratch_ptr); +#else + int clip_limit = max_value - (80 << MantissaBitsOf::value); + for (int i = 0; i < size_; ++i) { + int difference = std::max(data_[i].raw_val(), clip_limit) - max_value; + float exponent = expf(static_cast(DataType(difference))); + scratch_ptr[i] = exponent; + sum += exponent; + } +#endif // __AVX2__ + + float random_target = dist(*gen) * sum; + int start = 0; + int end = size_; + +#if defined __AVX2__ + FindSamplePoint(scratch_ptr, &random_target, &start, &end); + // The scalar code finishes the job from here... +#endif // __AVX2__ + float cumsum = 0.f; + for (std::size_t i = start; i < end; i++) { + cumsum += scratch_ptr[i]; + if (cumsum >= random_target) return i; + } + return end - 1; + } + + template + typename std::enable_if::value, void>::type Exp() { +#if defined __aarch64__ + DCHECK(size_ % 16 == 0) << "CacheAlignedVector size must be a multiple of " + "16 to allow for maximum SIMD and loop unroll " + "got " + << size_ % 16; + constexpr int kUnrollFactor = 4; + constexpr int kElementsPerIter = kUnrollFactor * kSIMDWidth; + for (std::size_t i = 0; i < size_; i += kElementsPerIter) { + float32x4_t x = vld1q_f32(data_ + i); + float32x4_t x1 = vld1q_f32(data_ + i + 4); + float32x4_t x2 = vld1q_f32(data_ + i + 8); + float32x4_t x3 = vld1q_f32(data_ + i + 12); + + vst1q_f32(data_ + i, fast_exp(x)); + vst1q_f32(data_ + i + 4, fast_exp(x1)); + vst1q_f32(data_ + i + 8, fast_exp(x2)); + vst1q_f32(data_ + i + 12, fast_exp(x3)); + } +#else + for (int i = 0; i < size_; ++i) { + data_[i] = expf(data_[i]); + } +#endif // defined __aarch64__ + } + + template + typename std::enable_if::value, void>::type Sigmoid() { +#if defined __aarch64__ + DCHECK(size_ % 8 == 0) << "CacheAlignedVector size must be a multiple of " + "8 to allow for maximum SIMD and loop unroll " + "got " + << size_ % 8; + constexpr int kUnrollFactor = 2; + constexpr int kElementsPerIter = kUnrollFactor * kSIMDWidth; + for (std::size_t i = 0; i < size_; i += kElementsPerIter) { + float32x4_t x = vld1q_f32(data_ + i); + float32x4_t x1 = vld1q_f32(data_ + i + 4); + + vst1q_f32(data_ + i, fast_sigmoid(x)); + vst1q_f32(data_ + i + 4, fast_sigmoid(x1)); + } +#else + for (int i = 0; i < size_; ++i) { + data_[i] = 1.f / (1.f + expf(-data_[i])); + } +#endif // defined __aarch64__ + } + + template + typename std::enable_if< + IsFixed32Type::value && IsFixed32Type::value, void>::type + // For benchmarking only. + Sigmoid(const int32_t* sigmoid_table, CacheAlignedVector* result) { +#if defined __AVX2__ + for (int i = 0; i < size_; i += kSIMDWidth) { + __m256i x_in = _mm256_loadu_si256(reinterpret_cast<__m256i*>(data_ + i)); + __m256i output = fixed32_sigmoid_fixed16::value, + MantissaBitsOf::value>( + sigmoid_table, x_in); + _mm256_store_si256(reinterpret_cast<__m256i*>(result->data() + i), + output); + } +#else + for (int i = 0; i < size_; ++i) { + result->data()[i] = 1.f / (1.f + expf(-data_[i])); + } +#endif // defined __AVX2__ + } + + template + typename std::enable_if::value, void>::type Tanh() { +#if defined __aarch64__ + DCHECK(size_ % 8 == 0) << "CacheAlignedVector size must be a multiple of " + "8 to allow for maximum SIMD and loop unroll " + "got " + << size_ % 8; + constexpr int kUnrollFactor = 2; + constexpr int kElementsPerIter = kUnrollFactor * kSIMDWidth; + for (std::size_t i = 0; i < size_; i += kElementsPerIter) { + float32x4_t x = vld1q_f32(data_ + i); + float32x4_t x1 = vld1q_f32(data_ + i + 4); + + vst1q_f32(data_ + i, fast_tanh(x)); + vst1q_f32(data_ + i + 4, fast_tanh(x1)); + } +#else + for (int i = 0; i < size_; ++i) { + data_[i] = tanhf(data_[i]); + } +#endif // defined __aarch64__ + } + + template + typename std::enable_if< + IsFixed32Type::value && IsFixed32Type::value, void>::type + // For benchmarking only + Tanh(const int32_t* tanh_table, CacheAlignedVector* result) { +#if defined __AVX2__ + for (int i = 0; i < size_; i += kSIMDWidth) { + __m256i x_in = _mm256_loadu_si256(reinterpret_cast<__m256i*>(data_ + i)); + __m256i output = + fixed32_tanh_fixed16::value, + MantissaBitsOf::value>(tanh_table, x_in); + _mm256_store_si256(reinterpret_cast<__m256i*>(result->data() + i), + output); + } +#else + for (int i = 0; i < size_; ++i) { + result->data()[i] = tanhf(data_[i]); + } +#endif // defined __AVX2__ + } + + // Returns |data_| cast to the correct integer type if fixed point. + template + typename std::enable_if::value, const int32_t*>::type + cast_data() const { + return reinterpret_cast(data_); + } + template + typename std::enable_if::value, const int16_t*>::type + cast_data() const { + return reinterpret_cast(data_); + } + template + typename std::enable_if::value || IsFixed16Type::value), + const Q*>::type + cast_data() const { + return data_; + } + const DataType* begin() const { return data_; } + const DataType* end() const { return data_ + size_; } + const DataType* data() const { return data_; } + DataType* data() { return data_; } + + const DataType& operator[](int pos) const { return data_[pos]; } + DataType& operator[](int pos) { return data_[pos]; } + + std::size_t size() const { return size_; } + bool empty() const { return size_ == 0; } + std::size_t bytes() const { return size_ * sizeof(DataType); } + + int rows() const { return size_; } + int cols() const { return 1; } + + // Stride to get to move over by one column (which is the number of rows). + int col_stride() const { return size_; } + + void Print() const { + for (int i = 0; i < size(); ++i) + absl::PrintF("[%d]=%g\n", i, static_cast(data_[i])); + } + + float maximum() const { + float max_val = std::numeric_limits::lowest(); + for (int i = 0; i < size_; ++i) { + max_val = std::max(max_val, std::abs(static_cast(data_[i]))); + } + + return max_val; + } + + private: + void resize(std::size_t size) { + aligned_free(data_); + size_ = size; + data_ = reinterpret_cast( + aligned_malloc(size_ * sizeof(DataType), kCacheLineSize)); + } + + std::size_t size_; + DataType* data_; + // Data used by the threaded version for sampling only. + std::vector maxes_; // Max value of logits. + std::vector thread_starts_; // First index for this thread. +#if defined __AVX__ || defined __AVX2__ + static constexpr int kCacheLineSize = 64; + static constexpr int kSIMDWidth = 8; +#else + static constexpr int kCacheLineSize = 128; + static constexpr int kSIMDWidth = 4; +#endif // __AVX__ + std::unique_ptr gen_; +}; + +// Used for doing Sparse Matrix * Dense Matrix multiplication. This class is +// not intended to be a general Matrix class, just for the RHS of a SpMM, hence +// the name fat vector rather than Matrix. The data layout is COLUMN MAJOR. +template +class FatCacheAlignedVector { + public: + using value_type = T; + + FatCacheAlignedVector() : rows_(0), cols_(0) {} + + // Creates a new vector that is (rows, cols), doesn't init memory. + FatCacheAlignedVector(int rows, int cols) + : vector_(rows * cols), rows_(rows), cols_(cols) {} + + // Copies and reshapes vector from (1, size) to (|rows|, size / |rows|). + FatCacheAlignedVector(const CacheAlignedVector& vector, int rows) + : vector_(vector), rows_(rows) { + CHECK_EQ(vector_.size() % rows_, 0); + cols_ = vector_.size() / rows_; + } + + template + explicit FatCacheAlignedVector(const FatCacheAlignedVector& vector) + : vector_(vector.size()), rows_(vector.rows()), cols_(vector.cols()) { + for (int i = 0; i < vector.size(); ++i) { + vector_[i] = static_cast(vector[i]); + } + } + + // Moves and reshapes vector from (1, size) to (|rows|, size / |rows|) + FatCacheAlignedVector(CacheAlignedVector&& vector, int rows) + : vector_(vector), rows_(rows) { + CHECK_EQ(vector_.size() % rows_, 0); + cols_ = vector_.size() / rows_; + } + + VectorView slice(const int col) const { + return VectorView(this->data() + rows() * col, rows(), 1); + } + MutableVectorView slice(const int col) { + return MutableVectorView(this->data() + rows() * col, rows(), 1); + } + + const T* data() const { return vector_.data(); } + T* data() { return vector_.data(); } + // Returns |data_| cast to the correct integer type if fixed point. + template + typename std::enable_if::value, const int32_t*>::type + cast_data() const { + return vector_.cast_data(); + } + template + typename std::enable_if::value, const int16_t*>::type + cast_data() const { + return vector_.cast_data(); + } + template + typename std::enable_if::value || IsFixed16Type::value), + const Q*>::type + cast_data() const { + return vector_.cast_data(); + } + + int rows() const { return rows_; } + int cols() const { return cols_; } + int size() const { return rows_ * cols_; } + bool empty() const { return rows_ == 0 || cols_ == 0; } + std::size_t bytes() const { return vector_.bytes(); } + + void reshape(int rows, int cols) { + CHECK_EQ(rows * cols, rows_ * cols_); + rows_ = rows; + cols_ = cols; + } + + float maximum() const { return vector_.maximum(); } + + // Stride to get to move over by one column (which is the number of rows). + int col_stride() const { return rows_; } + + void FillOnes() { vector_.FillOnes(); } + void FillZero() { vector_.FillZero(); } + void FillRandom(float min = -10.f, float max = 10.f) { + vector_.FillRandom(min, max); + } + + const T& operator[](int pos) const { return vector_[pos]; } + T& operator[](int pos) { return vector_[pos]; } + + private: + CacheAlignedVector vector_; + int rows_; + int cols_; +}; + +// View into a 2D Matrix. Currently only supports partitions by row. This is +// expected to be used with underlying data that is COLUMN MAJOR. +template +class MutableVectorView { + public: + using value_type = T; + + // Construct from a raw pointer, |rows|, |cols| and |col_stride|. + // |col_stride| will default to |rows| if not specified. + explicit MutableVectorView(T* data = nullptr, int rows = 0, int cols = 0, + int col_stride = 0) + : data_(data), + rows_(rows), + cols_(cols), + col_stride_(col_stride > 0 ? col_stride : rows) {} + + // Construct from a CacheAlignedVector, must have one column, can optionally + // specify an offset and row count. + explicit MutableVectorView(CacheAlignedVector* vector) + : MutableVectorView(vector->data(), vector->rows(), 1) {} + + explicit MutableVectorView(CacheAlignedVector* vector, int pos = 0, + int rows = 0) + : MutableVectorView(vector->data() + pos, + rows == 0 ? vector->rows() - pos : rows, 1, + vector->rows()) {} + + // Construct from a FatCacheAlignedVector, can optionally specify an offset, + // and row count. Views that have fewer columns than the original are not + // supported. + explicit MutableVectorView(FatCacheAlignedVector* vector) + : MutableVectorView(vector->data(), vector->rows(), vector->cols()) {} + + MutableVectorView(FatCacheAlignedVector* vector, int pos, int rows) + : MutableVectorView(vector->data() + pos, rows, vector->cols(), + vector->rows()) {} + + T* data() { return data_; } + const T* data() const { return data_; } + + // Returns |data_| cast to the correct integer type if fixed point. + template + typename std::enable_if::value, const int32_t*>::type + cast_data() const { + return reinterpret_cast(data_); + } + template + typename std::enable_if::value, const int16_t*>::type + cast_data() const { + return reinterpret_cast(data_); + } + template + typename std::enable_if::value || IsFixed16Type::value), + const Q*>::type + cast_data() const { + return data_; + } + + // Number of columns in the underlying (Fat)CacheAlignedVector. + int cols() const { return cols_; } + + // Number of rows in this view. + int rows() const { return rows_; } + + // Returns true if there's nothing in the MutableVectorView. + bool empty() const { return rows_ == 0 || cols_ == 0; } + + // Stride to get to the next column (usually the number of rows in the + // underlying data structure). + int col_stride() const { return col_stride_; } + + // Returns the total number of bytes that are "owned" by this view. Uses + // cols and not col_stride. + std::size_t bytes() const { return rows_ * cols_ * sizeof(T); } + + void reshape(int rows, int cols) { + CHECK_EQ(rows * cols, rows_ * cols_); + rows_ = rows; + cols_ = cols; + col_stride_ = rows_; + } + + const T& operator[](int pos) const { return data_[pos]; } + T& operator[](int pos) { return data_[pos]; } + + protected: + T* data_; + int rows_; + int cols_; + int col_stride_; +}; + +// Specialization of MutableVectorView which is read-only. +template +class VectorView : public MutableVectorView { + public: + using value_type = T; + + explicit VectorView(const MutableVectorView& other) + : MutableVectorView(other.data(), other.rows(), other.cols(), + other.col_stride()) {} + + // Construct from a raw pointer, |rows|, |cols| and |col_stride|. + // |col_stride| will default to |rows| if not specified. + explicit VectorView(const T* data = nullptr, int rows = 0, int cols = 0, + int col_stride = 0) + : MutableVectorView(data, rows, cols, col_stride) {} + + // Construct from a CacheAlignedVector, must have one column, can optionally + // specify an offset and row count + explicit VectorView(const CacheAlignedVector& vector) + : MutableVectorView(vector.data(), vector.rows(), 1) {} + + explicit VectorView(const CacheAlignedVector& vector, int pos = 0, + int rows = 0) + : MutableVectorView(vector.data() + pos, + rows == 0 ? vector.rows() - pos : rows, 1, + vector.rows()) {} + + // Construct from a FatCacheAlignedVector, can optionally specify an offset, + // and row count. Views that have fewer columns than the original are not + // supported. + explicit VectorView(const FatCacheAlignedVector& vector) + : MutableVectorView(vector.data(), vector.rows(), + vector.cols()) {} + + VectorView(const FatCacheAlignedVector& vector, int pos, int rows) + : MutableVectorView(vector.data() + pos, rows, vector.cols(), + vector.rows()) {} + + VectorView& operator=(const MutableVectorView& other) { + this->data_ = other.data(); + this->rows_ = other.rows(); + this->cols_ = other.cols(); + this->col_stride_ = other.col_stride(); + return *this; + } +}; + +} // namespace csrblocksparse +#endif // LYRA_CODEC_SPARSE_MATMUL_VECTOR_CACHE_ALIGNED_VECTOR_H_ diff --git a/sparse_matmul/vector/cachealignedvector_benchmark.cc b/sparse_matmul/vector/cachealignedvector_benchmark.cc new file mode 100644 index 00000000..9141e2d5 --- /dev/null +++ b/sparse_matmul/vector/cachealignedvector_benchmark.cc @@ -0,0 +1,60 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "benchmark/benchmark.h" +#include "sparse_matmul/vector/cache_aligned_vector.h" + +// A simple benchmark for CacheAlignedVector. +// +// Running on x86: +// As written, it's not representative of x86 performance since ReducingSample +// is used on x86 and not Sample. +// +// Running on arm64: +// bazel build -c opt --dynamic_mode=off --copt=-gmlt \ +// --copt=-DUSE_FIXED32 --config=android_arm64 \ +// sparse_matmul/vector:cachealignedvector_benchmark +namespace csrblocksparse { + +#ifdef USE_BFLOAT16 +using ComputeType = csrblocksparse::bfloat16; +#elif defined USE_FIXED32 +using ComputeType = csrblocksparse::fixed32<11>; // kGruMatMulOutBits +#else +using ComputeType = float; +#endif // USE_BFLOAT16 + +#if defined(USE_FIXED32) && defined(__aarch64__) +using ScratchType = int; +#else +using ScratchType = float; +#endif // defined(USE_FIXED32) && defined(__aarch64__) + +void BM_Sample(benchmark::State& state) { + constexpr int kVectorSize = 16384; // A large vector. + std::minstd_rand generator; + + CacheAlignedVector values(kVectorSize); + CacheAlignedVector scratch(kVectorSize); + values.FillRandom(); + + for (auto _ : state) { + values.Sample(/*temperature=*/0.98f, &generator, &scratch); + } +} +BENCHMARK(BM_Sample); + +} // namespace csrblocksparse diff --git a/sparse_matmul/vector/cachealignedvector_test.cc b/sparse_matmul/vector/cachealignedvector_test.cc new file mode 100644 index 00000000..245c64d7 --- /dev/null +++ b/sparse_matmul/vector/cachealignedvector_test.cc @@ -0,0 +1,405 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "sparse_matmul/vector/cache_aligned_vector.h" + +#if defined __aarch64__ +#include +#endif + +#include + +#include +#include +#include +#include +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "sparse_matmul/numerics/test_utils.h" +#include "sparse_matmul/os/coop_threads.h" + +namespace csrblocksparse { + +const float kExpRelTolerance = .03f; // 3% relative +#ifdef SIGMOID_AS_TANH +const float kSigmoidRelTolerance = .09f; // 9.0% relative +const float kSigmoidAbsTolerance = .003f; +#else +const float kSigmoidRelTolerance = .031f; // 3.1% relative +const float kSigmoidAbsTolerance = .006f; +#endif +const float kTanhRelTolerance = .014f; // 1.4% relative +const float kTanhAbsTolerance = .00525f; + +TEST(Transcendentals, CacheAlignedVectorExp) { + const int kTestSize = 1 << 16; + CacheAlignedVector values(kTestSize); + values.FillRandom(); + CacheAlignedVector values_ref = values; + + values.Exp(); + for (int i = 0; i < kTestSize; ++i) { + float exact_val = std::exp(values_ref[i]); + float rel_diff = RelDiff(exact_val, values[i]); + + EXPECT_LT(rel_diff, kExpRelTolerance) + << exact_val << " " << values[i] << " " << values_ref[i]; + } +} + +TEST(Transcendentals, CacheAlignedVectorSigmoid) { + const int kTestSize = 1 << 16; + CacheAlignedVector values(kTestSize); + values.FillRandom(); + CacheAlignedVector values_ref = values; + + values.Sigmoid(); + for (int i = 0; i < kTestSize; ++i) { + float exact_val = 1. / (1. + std::exp(-values_ref[i])); + float rel_diff = RelDiff(exact_val, values[i]); + + EXPECT_LT(rel_diff, kSigmoidRelTolerance) + << exact_val << " " << values[i] << " " << values_ref[i]; + EXPECT_NEAR(values[i], exact_val, kSigmoidAbsTolerance) << values_ref[i]; + } +} + +TEST(Transcendentals, CacheAlignedVectorTanh) { + const int kTestSize = 1 << 16; + CacheAlignedVector values(kTestSize); + values.FillRandom(); + CacheAlignedVector values_ref = values; + + values.Tanh(); + for (int i = 0; i < kTestSize; ++i) { + float exact_val = std::tanh(values_ref[i]); + float rel_diff = RelDiff(exact_val, values[i]); + + EXPECT_LT(rel_diff, kTanhRelTolerance) + << exact_val << " " << values[i] << " " << values_ref[i]; + EXPECT_NEAR(values[i], exact_val, kTanhAbsTolerance) << values_ref[i]; + } +} + +// Uniformly sample logits and check that the resulting sample choices are +// also (nearly) uniformly distributed. +TEST(Sampling, Random) { + const int kSize = 256; + + CacheAlignedVector logits(kSize); + logits.FillZero(); + + double histogram[kSize] = {}; + + const int kIterations = 10000; + for (int i = 0; i < kIterations; ++i) { + histogram[logits.Sample()]++; + } + + for (int i = 0; i < kSize; ++i) { + // .002 is an empirical bound + EXPECT_GT(histogram[i] / kIterations, 1. / kSize - .002f); + EXPECT_LT(histogram[i] / kIterations, 1. / kSize + .002f); + } +} + +// Put (nearly) all the probability mass on one bin and make sure only that bin +// is chosen. +TEST(Sampling, FixedDistribution) { + const int kSize = 256; + + CacheAlignedVector logits(kSize); + + int histogram[kSize] = {}; + + const int kIterations = 1000; + const int kIndex = 3; + const int kAllProbabilityMass = 10; + const int kNoProbabilityMass = -10; + for (int i = 0; i < kIterations; ++i) { + for (int i = 1; i <= kSize; ++i) { + logits.data()[i - 1] = + i == (kIndex + 1) ? kAllProbabilityMass : kNoProbabilityMass; + } + + histogram[logits.Sample()]++; + } + + EXPECT_EQ(histogram[kIndex], 1000); +} + +// Put (nearly) all the probability mass on one bin outside the target range, +// and make sure that bin is not chosen. +TEST(ScalarSample, ThreadedMasked) { + const int kSize = 256; + const int mindex = 2; + const int maxdex = 3; + const int kNumThreads = 4; + const int kIterations = 1000; + const int kIndex = 3; + const int kMostProbabilityMass = 3; + const int kLittleProbabilityMass = -3; + + CacheAlignedVector logits(kSize); + std::vector> tmp_vectors; + std::vector generators(kNumThreads); + + for (int i = 0; i < kNumThreads; ++i) { + tmp_vectors.emplace_back(kSize); + } + + for (int i = 0; i < kSize; ++i) { + logits.data()[i] = + (i + 1) == (kIndex + 1) ? kMostProbabilityMass : kLittleProbabilityMass; + } + + std::vector> histograms; + for (int i = 0; i < kNumThreads; ++i) { + histograms.emplace_back(kSize); + } + + auto f = [&](csrblocksparse::SpinBarrier* /*barrier*/, int tid) { + for (int i = 0; i < kIterations; ++i) { + histograms[tid][logits.ScalarSample( + 1.f, &generators[tid], &tmp_vectors[tid], 0, mindex, maxdex)]++; + } + }; + + csrblocksparse::LaunchOnThreadsWithBarrier(kNumThreads, f); + + // Every thread should generate the exact same set of samples. + for (int i = 0; i < kSize; ++i) { + int val = histograms[0][i]; + for (int tid = 1; tid < kNumThreads; ++tid) { + EXPECT_EQ(val, histograms[tid][i]); + } + } + + // The most probable sample should be the only one we're sampling. + for (int tid = 0; tid < kNumThreads; ++tid) { + EXPECT_EQ(std::distance(histograms[tid].begin(), + std::max_element(histograms[tid].begin(), + histograms[tid].end())), + mindex); + } +} + +TEST(Sampling, Threaded) { + const int kSize = 256; + const int kNumThreads = 4; + const int kIterations = 1000; + const int kIndex = 3; + const int kMostProbabilityMass = 3; + const int kLittleProbabilityMass = -3; + + CacheAlignedVector logits(kSize); + std::vector> tmp_vectors; + std::vector generators(kNumThreads); + + for (int i = 0; i < kNumThreads; ++i) { + tmp_vectors.emplace_back(kSize); + } + + for (int i = 0; i < kSize; ++i) { + logits.data()[i] = + (i + 1) == (kIndex + 1) ? kMostProbabilityMass : kLittleProbabilityMass; + } + + std::vector> histograms; + for (int i = 0; i < kNumThreads; ++i) { + histograms.emplace_back(kSize); + } + + auto f = [&](csrblocksparse::SpinBarrier* /*barrier*/, int tid) { + for (int i = 0; i < kIterations; ++i) { + histograms[tid] + [logits.Sample(1.f, &generators[tid], &tmp_vectors[tid])]++; + } + }; + + csrblocksparse::LaunchOnThreadsWithBarrier(kNumThreads, f); + + // Every thread should generate the exact same set of samples. + for (int i = 0; i < kSize; ++i) { + int val = histograms[0][i]; + for (int tid = 1; tid < kNumThreads; ++tid) { + EXPECT_EQ(val, histograms[tid][i]); + } + } + + // The most probable sample should be the one with the most probability mass. + for (int tid = 0; tid < kNumThreads; ++tid) { + EXPECT_EQ(std::distance(histograms[tid].begin(), + std::max_element(histograms[tid].begin(), + histograms[tid].end())), + kIndex); + } +} + +void CreateVectorHelper( + csrblocksparse::FatCacheAlignedVector* fat_vector, int cols, + int rows, std::unique_ptr>* view) { + *view = absl::make_unique>(*fat_vector, + cols, rows); +} + +void CreateVectorHelper( + csrblocksparse::FatCacheAlignedVector* fat_vector, int cols, + int rows, std::unique_ptr>* view) { + *view = absl::make_unique>( + fat_vector, cols, rows); +} + +csrblocksparse::FatCacheAlignedVector CreateFatAlignedVector(int rows, + int cols) { + csrblocksparse::FatCacheAlignedVector fat_vector(rows, cols); + // Usage intent of FatCacheAlignedVector is that they are COLUMN MAJOR. + float v = 0; + for (int c = 0; c < cols; ++c) { + for (int r = 0; r < rows; ++r) { + fat_vector.data()[c * rows + r] = v++; + } + } + + return fat_vector; +} + +template +void TestFatVectorView() { + const int kRows = 6; + const int kCols = 6; + auto fat_vector = CreateFatAlignedVector(kRows, kCols); + + std::unique_ptr top; + CreateVectorHelper(&fat_vector, 0, kRows / 2, &top); + std::unique_ptr bottom; + CreateVectorHelper(&fat_vector, kRows / 2, kRows / 2, &bottom); + + EXPECT_EQ(top->cols(), kCols); + EXPECT_EQ(bottom->cols(), kCols); + EXPECT_EQ(top->rows(), kRows / 2); + EXPECT_EQ(bottom->rows(), kRows / 2); + EXPECT_EQ(top->col_stride(), kRows); + EXPECT_EQ(bottom->col_stride(), kRows); + + for (int c = 0; c < kCols; ++c) { + for (int r = 0; r < kRows; ++r) { + if (r < kRows / 2) { + EXPECT_EQ(fat_vector[c * kRows + r], + top->data()[c * top->col_stride() + r]); + } else { + EXPECT_EQ(fat_vector[c * kRows + r], + bottom->data()[c * top->col_stride() + r - kRows / 2]); + } + } + } +} + +TEST(FatVector, View) { + TestFatVectorView>(); +} +TEST(FatVector, MutableView) { + TestFatVectorView>(); +} + +TEST(FatVector, SliceMutableView) { + const int kRows = 6; + const int kCols = 3; + auto fat_vector = CreateFatAlignedVector(kRows, kCols); + + int c = 1; + csrblocksparse::MutableVectorView slice = fat_vector.slice(c); + for (int r = 0; r < kRows; ++r) { + EXPECT_EQ(slice[r], c * kRows + r); + } +} + +TEST(FatVector, SliceConstView) { + const int kRows = 6; + const int kCols = 3; + auto fat_vector = CreateFatAlignedVector(kRows, kCols); + + int c = 1; + csrblocksparse::VectorView const_slice; + { + // Take a VectorView from a non-const slice. + const_slice = fat_vector.slice(c); + for (int r = 0; r < kRows; ++r) { + EXPECT_EQ(const_slice[r], c * kRows + r); + } + } + + { + // Take a VectorView from a const slice. + const auto& const_fat_vector = fat_vector; + const_slice = const_fat_vector.slice(c); + for (int r = 0; r < kRows; ++r) { + EXPECT_EQ(const_slice[r], c * kRows + r); + } + } +} + +TEST(View, FromMutableToConst) { + const int kRows = 6; + const int kCols = 3; + auto fat_vector = CreateFatAlignedVector(kRows, kCols); + csrblocksparse::MutableVectorView slice = fat_vector.slice(0); + + csrblocksparse::VectorView const_slice(slice); + for (int r = 0; r < kRows; ++r) { + EXPECT_EQ(const_slice[r], r); + } +} + +TEST(View, CopyTest) { + const int kRows = 6; + const int kCols = 3; + auto fat_vector = CreateFatAlignedVector(kRows, kCols); + csrblocksparse::MutableVectorView slice = fat_vector.slice(0); + csrblocksparse::MutableVectorView slice2(slice); + + for (int r = 0; r < kRows; ++r) { + EXPECT_EQ(slice2[r], r); + } +} + +TEST(Vector, CopyNull) { + // Check that we can copy a vector with a null generator without segfault. + CacheAlignedVector foo((CacheAlignedVector())); + // This is here to prevent foo from being optimized out. + CHECK_EQ(foo.size(), 0); + CacheAlignedVector foo_bar = CacheAlignedVector(); + CHECK_EQ(foo_bar.size(), 0); +} + +TEST(Vector, FromRawPointer) { + std::vector input; + for (int i = 0; i < 5; ++i) { + input.push_back(i * 2); + } + + // Calls first constructor. + CacheAlignedVector foo(input.data(), 5); + CHECK_EQ(foo.size(), 5); + EXPECT_THAT(input, testing::ElementsAreArray(foo.data(), 5)); + + // Calls the second constructor. + CacheAlignedVector foo2(input.data(), 5); + CHECK_EQ(foo2.size(), 5); + EXPECT_THAT(input, testing::ElementsAreArray(foo2.data(), 5)); +} + +} // namespace csrblocksparse diff --git a/sparse_matmul/zlib_wrapper/BUILD b/sparse_matmul/zlib_wrapper/BUILD new file mode 100644 index 00000000..b9653dab --- /dev/null +++ b/sparse_matmul/zlib_wrapper/BUILD @@ -0,0 +1,20 @@ +licenses(["notice"]) + +cc_library( + name = "zlib_wrapper", + srcs = [ + "gzipheader.cc", + "zlibwrapper.cc", + ], + hdrs = [ + "gzipheader.h", + "zlibwrapper.h", + ], + visibility = ["//:__subpackages__"], + deps = [ + "@com_google_absl//absl/base", + "@com_google_absl//absl/base:core_headers", + "@com_google_glog//:glog", + "@zlib", + ], +) diff --git a/sparse_matmul/zlib_wrapper/gzipheader.cc b/sparse_matmul/zlib_wrapper/gzipheader.cc new file mode 100644 index 00000000..a8d5c3ca --- /dev/null +++ b/sparse_matmul/zlib_wrapper/gzipheader.cc @@ -0,0 +1,190 @@ +// Copyright 2002 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Author: Neal Cardwell +// + +#include "sparse_matmul/zlib_wrapper/gzipheader.h" + +#include + +#include "absl/base/macros.h" +#include "glog/logging.h" +#include "zlib.h" // for Z_DEFAULT_COMPRESSION + +namespace csrblocksparse { + +const uint8_t GZipHeader::magic[] = {0x1f, 0x8b}; + +// ---------------------------------------------------------------------- +// GZipHeader::ReadMore() +// Attempt to parse the beginning of the given buffer as a gzip +// header. If these bytes do not constitute a complete gzip header, +// return INCOMPLETE_HEADER. If these bytes do not constitute a +// *valid* gzip header, return INVALID_HEADER. If we find a +// complete header, return COMPLETE_HEADER and set the pointer +// pointed to by header_end to the first byte beyond the gzip header. +// ---------------------------------------------------------------------- + +GZipHeader::Status GZipHeader::ReadMore(const char* inbuf, int inbuf_len, + const char** header_end) { + CHECK_GE(inbuf_len, 0); + const uint8_t* pos = reinterpret_cast(inbuf); + const uint8_t* const end = pos + inbuf_len; + + while (pos < end) { + switch (state_) { + case IN_HEADER_ID1: + if (*pos != magic[0]) return INVALID_HEADER; + pos++; + state_++; + break; + case IN_HEADER_ID2: + if (*pos != magic[1]) return INVALID_HEADER; + pos++; + state_++; + break; + case IN_HEADER_CM: + if (*pos != Z_DEFLATED) return INVALID_HEADER; + pos++; + state_++; + break; + case IN_HEADER_FLG: + flags_ = + (*pos) & (FLAG_FHCRC | FLAG_FEXTRA | FLAG_FNAME | FLAG_FCOMMENT); + pos++; + state_++; + break; + + case IN_HEADER_MTIME_BYTE_0: + pos++; + state_++; + break; + case IN_HEADER_MTIME_BYTE_1: + pos++; + state_++; + break; + case IN_HEADER_MTIME_BYTE_2: + pos++; + state_++; + break; + case IN_HEADER_MTIME_BYTE_3: + pos++; + state_++; + break; + + case IN_HEADER_XFL: + pos++; + state_++; + break; + + case IN_HEADER_OS: + pos++; + state_++; + break; + + case IN_XLEN_BYTE_0: + if (!(flags_ & FLAG_FEXTRA)) { + state_ = IN_FNAME; + break; + } + // We have a two-byte little-endian length, followed by a + // field of that length. + extra_length_ = *pos; + pos++; + state_++; + break; + case IN_XLEN_BYTE_1: + extra_length_ += *pos << 8; + pos++; + state_++; + // If we have a zero-length FEXTRA, we want to check to notice that + // we're done reading the FEXTRA before we exit this loop... + ABSL_FALLTHROUGH_INTENDED; + + case IN_FEXTRA: { + // Grab the rest of the bytes in the extra field, or as many + // of them as are actually present so far. + const int num_extra_bytes = std::min(extra_length_, (end - pos)); + pos += num_extra_bytes; + extra_length_ -= num_extra_bytes; + if (extra_length_ == 0) { + state_ = IN_FNAME; // advance when we've seen extra_length_ bytes + flags_ &= ~FLAG_FEXTRA; // we're done with the FEXTRA stuff + } + break; + } + + case IN_FNAME: + if (!(flags_ & FLAG_FNAME)) { + state_ = IN_FCOMMENT; + break; + } + // See if we can find the end of the \0-terminated FNAME field. + pos = reinterpret_cast(memchr(pos, '\0', (end - pos))); + if (pos != nullptr) { + pos++; // advance past the '\0' + flags_ &= ~FLAG_FNAME; // we're done with the FNAME stuff + state_ = IN_FCOMMENT; + } else { + pos = end; // everything we have so far is part of the FNAME + } + break; + + case IN_FCOMMENT: + if (!(flags_ & FLAG_FCOMMENT)) { + state_ = IN_FHCRC_BYTE_0; + break; + } + // See if we can find the end of the \0-terminated FCOMMENT field. + pos = reinterpret_cast(memchr(pos, '\0', (end - pos))); + if (pos != nullptr) { + pos++; // advance past the '\0' + flags_ &= ~FLAG_FCOMMENT; // we're done with the FCOMMENT stuff + state_ = IN_FHCRC_BYTE_0; + } else { + pos = end; // everything we have so far is part of the FNAME + } + break; + + case IN_FHCRC_BYTE_0: + if (!(flags_ & FLAG_FHCRC)) { + state_ = IN_DONE; + break; + } + pos++; + state_++; + break; + + case IN_FHCRC_BYTE_1: + pos++; + flags_ &= ~FLAG_FHCRC; // we're done with the FHCRC stuff + state_++; + break; + + case IN_DONE: + *header_end = reinterpret_cast(pos); + return COMPLETE_HEADER; + } + } + + if ((state_ > IN_HEADER_OS) && (flags_ == 0)) { + *header_end = reinterpret_cast(pos); + return COMPLETE_HEADER; + } else { + return INCOMPLETE_HEADER; + } +} + +} // namespace csrblocksparse diff --git a/sparse_matmul/zlib_wrapper/gzipheader.h b/sparse_matmul/zlib_wrapper/gzipheader.h new file mode 100644 index 00000000..21cd71e4 --- /dev/null +++ b/sparse_matmul/zlib_wrapper/gzipheader.h @@ -0,0 +1,107 @@ +// +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_LYRA_CODEC_SPARSE_MATMUL_ZLIB_GZIPHEADER_H +#define THIRD_PARTY_LYRA_CODEC_SPARSE_MATMUL_ZLIB_GZIPHEADER_H + +// The GZipHeader class allows you to parse a gzip header, such as you +// might find at the beginning of a file compressed by gzip (ie, a .gz +// file), or at the beginning of an HTTP response that uses a gzip +// Content-Encoding. See RFC 1952 for the specification for the gzip +// header. +// +// The model is that you call ReadMore() for each chunk of bytes +// you've read from a file or socket. +// + +#include + +namespace csrblocksparse { + +class GZipHeader { + public: + GZipHeader() { Reset(); } + ~GZipHeader() {} + + // Wipe the slate clean and start from scratch. + void Reset() { + state_ = IN_HEADER_ID1; + flags_ = 0; + extra_length_ = 0; + } + + enum Status { + INCOMPLETE_HEADER, // don't have all the bits yet... + COMPLETE_HEADER, // complete, valid header + INVALID_HEADER, // found something invalid in the header + }; + + // Attempt to parse the given buffer as the next installment of + // bytes from a gzip header. If the bytes we've seen so far do not + // yet constitute a complete gzip header, return + // INCOMPLETE_HEADER. If these bytes do not constitute a *valid* + // gzip header, return INVALID_HEADER. When we've seen a complete + // gzip header, return COMPLETE_HEADER and set the pointer pointed + // to by header_end to the first byte beyond the gzip header. + Status ReadMore(const char* inbuf, int inbuf_len, const char** header_end); + + private: + // NOLINTNEXTLINE + static const uint8_t magic[]; // gzip magic header + + enum { // flags (see RFC) + FLAG_FTEXT = 0x01, // bit 0 set: file probably ascii text + FLAG_FHCRC = 0x02, // bit 1 set: header CRC present + FLAG_FEXTRA = 0x04, // bit 2 set: extra field present + FLAG_FNAME = 0x08, // bit 3 set: original file name present + FLAG_FCOMMENT = 0x10, // bit 4 set: file comment present + FLAG_RESERVED = 0xE0, // bits 5..7: reserved + }; + + enum State { + // The first 10 bytes are the fixed-size header: + IN_HEADER_ID1, + IN_HEADER_ID2, + IN_HEADER_CM, + IN_HEADER_FLG, + IN_HEADER_MTIME_BYTE_0, + IN_HEADER_MTIME_BYTE_1, + IN_HEADER_MTIME_BYTE_2, + IN_HEADER_MTIME_BYTE_3, + IN_HEADER_XFL, + IN_HEADER_OS, + + IN_XLEN_BYTE_0, + IN_XLEN_BYTE_1, + IN_FEXTRA, + + IN_FNAME, + + IN_FCOMMENT, + + IN_FHCRC_BYTE_0, + IN_FHCRC_BYTE_1, + + IN_DONE, + }; + + int state_; // our current State in the parsing FSM: an int so we can ++ + uint8_t flags_; // the flags byte of the header ("FLG" in the RFC) + uint16_t extra_length_; // how much of the "extra field" we have yet to read +}; + +} // namespace csrblocksparse + +#endif // THIRD_PARTY_LYRA_CODEC_SPARSE_MATMUL_ZLIB_GZIPHEADER_H diff --git a/sparse_matmul/zlib_wrapper/zlibwrapper.cc b/sparse_matmul/zlib_wrapper/zlibwrapper.cc new file mode 100644 index 00000000..a3a2fa5e --- /dev/null +++ b/sparse_matmul/zlib_wrapper/zlibwrapper.cc @@ -0,0 +1,841 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "sparse_matmul/zlib_wrapper/zlibwrapper.h" + +#include +#include + +#include +#include +#include + +#include "glog/logging.h" +#include "sparse_matmul/zlib_wrapper/gzipheader.h" +#include "zconf.h" +#include "zlib.h" + +// The GZIP header (see RFC 1952): +// +---+---+---+---+---+---+---+---+---+---+ +// |ID1|ID2|CM |FLG| MTIME |XFL|OS | +// +---+---+---+---+---+---+---+---+---+---+ +// ID1 \037 +// ID2 \213 +// CM \010 (compression method == DEFLATE) +// FLG \000 (special flags that we do not support) +// MTIME Unix format modification time (0 means not available) +// XFL 2-4? DEFLATE flags +// OS ???? Operating system indicator (255 means unknown) + +// Header value we generate: +// We use a #define so sizeof() works correctly +#define GZIP_HEADER "\037\213\010\000\000\000\000\000\002\377" + +namespace csrblocksparse { + +// We allow all kinds of bad footers when this flag is true. +// Some web servers send bad pages corresponding to these cases +// and IE is tolerant with it. +// - Extra bytes after gzip footer (see bug 69126) +// - No gzip footer (see bug 72896) +// - Incomplete gzip footer (see bug 71871706) +bool ZLib::should_be_flexible_with_gzip_footer_ = false; + +// Initialize the ZLib class +ZLib::ZLib() + : comp_init_(false), uncomp_init_(false), gzip_header_(new GZipHeader) { + Reinit(); + init_settings_ = settings_; +} + +ZLib::~ZLib() { + if (comp_init_) { + deflateEnd(&comp_stream_); + } + if (uncomp_init_) { + inflateEnd(&uncomp_stream_); + } + delete gzip_header_; +} + +void ZLib::Reinit() { + settings_.dictionary_ = nullptr; + settings_.dict_len_ = 0; + settings_.compression_level_ = Z_DEFAULT_COMPRESSION; + settings_.window_bits_ = MAX_WBITS; + settings_.mem_level_ = 8; // DEF_MEM_LEVEL + settings_.no_header_mode_ = false; + settings_.gzip_header_mode_ = false; + settings_.dont_hide_zstream_end_ = false; + + if (comp_init_) { + int err = deflateReset(&comp_stream_); + if (err != Z_OK) { + deflateEnd(&comp_stream_); + comp_init_ = false; + } + } + if (uncomp_init_) { + // Use negative window bits size to indicate bare stream with no header. + int wbits = (settings_.no_header_mode_ ? -MAX_WBITS : MAX_WBITS); + int err = inflateReset2(&uncomp_stream_, wbits); + if (err == Z_OK) { + init_settings_.no_header_mode_ = settings_.no_header_mode_; + } else { + inflateEnd(&uncomp_stream_); + uncomp_init_ = false; + } + } + crc_ = 0; + uncompressed_size_ = 0; + gzip_header_->Reset(); + gzip_footer_bytes_ = -1; + first_chunk_ = true; +} + +void ZLib::Reset() { + first_chunk_ = true; + gzip_header_->Reset(); +} + +void ZLib::CheckValidParams() { + if (settings_.dictionary_ != nullptr && + (settings_.no_header_mode_ || settings_.gzip_header_mode_)) { + LOG(FATAL) + << "Incompatible params: require zlib headers with preset dictionary"; + } +} + +void ZLib::SetNoHeaderMode(bool no_header_mode) { + settings_.no_header_mode_ = no_header_mode; + if (init_settings_.no_header_mode_ != settings_.no_header_mode_) { + // Once the header mode changes, we have to reinitialize all our streams + if (comp_init_) { + deflateEnd(&comp_stream_); + comp_init_ = false; + } + if (uncomp_init_) { + inflateEnd(&uncomp_stream_); + uncomp_init_ = false; + } + } else { + // Mode hasn't changed, but treat this as a reset request nevertheless + Reset(); + } + CheckValidParams(); +} + +void ZLib::SetGzipHeaderMode() { + settings_.gzip_header_mode_ = true; + SetNoHeaderMode(true); // we use gzip headers, not zlib headers + CheckValidParams(); +} + +void ZLib::SetDictionary(const char* initial_dict, unsigned int dict_len) { + settings_.dictionary_ = (Bytef*)initial_dict; // NOLINT + settings_.dict_len_ = dict_len; + CheckValidParams(); +} + +void ZLib::SetDontHideStreamEnd() { settings_.dont_hide_zstream_end_ = true; } + +int ZLib::MinFooterSize() const { + int min_footer_size = 2; // Room for empty chunk. + if (settings_.gzip_header_mode_) { + min_footer_size += 8; // Room for actual footer. + } + return min_footer_size; +} + +// --------- COMPRESS MODE + +// Initialization method to be called if we hit an error while +// compressing. On hitting an error, call this method before returning +// the error. +void ZLib::CompressErrorInit() { + if (comp_init_) { + deflateEnd(&comp_stream_); + comp_init_ = false; + } + Reset(); +} + +// These probably return Z_OK, but may return Z_BUF_ERROR if outbuf is full +int ZLib::WriteGzipHeader() { + if (comp_stream_.avail_out < sizeof(GZIP_HEADER)) return Z_BUF_ERROR; + memcpy(comp_stream_.next_out, GZIP_HEADER, sizeof(GZIP_HEADER) - 1); + comp_stream_.next_out += sizeof(GZIP_HEADER) - 1; + comp_stream_.avail_out -= sizeof(GZIP_HEADER) - 1; + return Z_OK; +} + +int ZLib::WriteGzipFooter(Bytef* dest, uLongf destLen) { + if (destLen < 8) // not enough space for footer + return Z_BUF_ERROR; + *dest++ = (crc_ >> 0) & 255; + *dest++ = (crc_ >> 8) & 255; + *dest++ = (crc_ >> 16) & 255; + *dest++ = (crc_ >> 24) & 255; + *dest++ = (uncompressed_size_ >> 0) & 255; + *dest++ = (uncompressed_size_ >> 8) & 255; + *dest++ = (uncompressed_size_ >> 16) & 255; + *dest++ = (uncompressed_size_ >> 24) & 255; + return Z_OK; +} + +int ZLib::DeflateInit() { + int err = + deflateInit2(&comp_stream_, settings_.compression_level_, Z_DEFLATED, + (settings_.no_header_mode_ ? -settings_.window_bits_ + : settings_.window_bits_), + settings_.mem_level_, Z_DEFAULT_STRATEGY); + if (err == Z_OK) { + // Save parameters for later reusability checks + init_settings_.compression_level_ = settings_.compression_level_; + init_settings_.window_bits_ = settings_.window_bits_; + init_settings_.mem_level_ = settings_.mem_level_; + init_settings_.no_header_mode_ = settings_.no_header_mode_; + } + return err; +} + +int ZLib::CompressInit(Bytef* dest, uLongf* destLen, const Bytef* source, + uLong* sourceLen) { + int err; + + comp_stream_.next_in = (Bytef*)source; // NOLINT + comp_stream_.avail_in = (uInt)*sourceLen; + // Check for sourceLen (unsigned long) to fit into avail_in (unsigned int). + if ((uLong)comp_stream_.avail_in != *sourceLen) return Z_BUF_ERROR; + comp_stream_.next_out = dest; + comp_stream_.avail_out = (uInt)*destLen; + // Check for destLen (unsigned long) to fit into avail_out (unsigned int). + if ((uLong)comp_stream_.avail_out != *destLen) return Z_BUF_ERROR; + + if (!first_chunk_) // only need to set up stream the first time through + return Z_OK; + + // Force full reinit if properties have changed in a way we can't adjust. + if (comp_init_ && + (init_settings_.dictionary_ != settings_.dictionary_ || + init_settings_.dict_len_ != settings_.dict_len_ || + init_settings_.window_bits_ != settings_.window_bits_ || + init_settings_.mem_level_ != settings_.mem_level_ || + init_settings_.no_header_mode_ != settings_.no_header_mode_)) { + deflateEnd(&comp_stream_); + comp_init_ = false; + } + + // Reuse if we've already initted the object. + if (comp_init_) { // we've already initted it + err = deflateReset(&comp_stream_); + if (err != Z_OK) { + deflateEnd(&comp_stream_); + comp_init_ = false; + } + } + + // If compression level has changed, try to reconfigure instead of reinit + if (comp_init_ && + init_settings_.compression_level_ != settings_.compression_level_) { + err = deflateParams(&comp_stream_, settings_.compression_level_, + Z_DEFAULT_STRATEGY); + if (err == Z_OK) { + init_settings_.compression_level_ = settings_.compression_level_; + } else { + deflateEnd(&comp_stream_); + comp_init_ = false; + } + } + + // First use or previous state was not reusable with current settings. + if (!comp_init_) { + comp_stream_.zalloc = (alloc_func)0; + comp_stream_.zfree = (free_func)0; + comp_stream_.opaque = (voidpf)0; + err = DeflateInit(); + if (err != Z_OK) return err; + comp_init_ = true; + } + return Z_OK; +} + +// In a perfect world we'd always have the full buffer to compress +// when the time came, and we could just call Compress(). Alas, we +// want to do chunked compression on our webserver. In this +// application, we compress the header, send it off, then compress the +// results, send them off, then compress the footer. Thus we need to +// use the chunked compression features of zlib. +int ZLib::CompressAtMostOrAll(Bytef* dest, uLongf* destLen, const Bytef* source, + uLong* sourceLen, + int flush_mode) { // Z_FULL_FLUSH or Z_FINISH + int err; + + if ((err = CompressInit(dest, destLen, source, sourceLen)) != Z_OK) + return err; + + // This is used to figure out how many bytes we wrote *this chunk* + int compressed_size = comp_stream_.total_out; + + // Some setup happens only for the first chunk we compress in a run + if (first_chunk_) { + // Append the gzip header before we start compressing + if (settings_.gzip_header_mode_) { + if ((err = WriteGzipHeader()) != Z_OK) return err; + compressed_size -= sizeof(GZIP_HEADER) - 1; // -= is right: adds to size + crc_ = crc32(0, nullptr, 0); // initialize + } + + // Initialize the dictionary just before we start compressing + if (settings_.dictionary_) { + err = deflateSetDictionary(&comp_stream_, settings_.dictionary_, + settings_.dict_len_); + if (err != Z_OK) return err; + init_settings_.dictionary_ = settings_.dictionary_; + init_settings_.dict_len_ = settings_.dict_len_; + } + + uncompressed_size_ = 0; + first_chunk_ = false; // so we don't do this again + } + + // flush_mode is Z_FINISH for all mode, Z_SYNC_FLUSH for incremental + // compression. + err = deflate(&comp_stream_, flush_mode); + + const uLong source_bytes_consumed = *sourceLen - comp_stream_.avail_in; + *sourceLen = comp_stream_.avail_in; + + if ((err == Z_STREAM_END || err == Z_OK) && comp_stream_.avail_in == 0 && + comp_stream_.avail_out != 0) { + // we processed everything ok and the output buffer was large enough. + {} + } else if (err == Z_STREAM_END && comp_stream_.avail_in > 0) { + return Z_BUF_ERROR; // should never happen + } else if (err != Z_OK && err != Z_STREAM_END && err != Z_BUF_ERROR) { + // an error happened + CompressErrorInit(); + return err; + } else if (comp_stream_.avail_out == 0) { // not enough space + err = Z_BUF_ERROR; + } + + assert(err == Z_OK || err == Z_STREAM_END || err == Z_BUF_ERROR); + if (err == Z_STREAM_END) err = Z_OK; + + // update the crc and other metadata + uncompressed_size_ += source_bytes_consumed; + compressed_size = comp_stream_.total_out - compressed_size; // delta + *destLen = compressed_size; + if (settings_.gzip_header_mode_) // don't bother with crc else + crc_ = crc32(crc_, source, source_bytes_consumed); + + return err; +} + +int ZLib::CompressChunkOrAll(Bytef* dest, uLongf* destLen, const Bytef* source, + uLong sourceLen, + int flush_mode) { // Z_FULL_FLUSH or Z_FINISH + const int ret = + CompressAtMostOrAll(dest, destLen, source, &sourceLen, flush_mode); + if (ret == Z_BUF_ERROR) CompressErrorInit(); + return ret; +} + +int ZLib::CompressChunk(Bytef* dest, uLongf* destLen, const Bytef* source, + uLong sourceLen) { + return CompressChunkOrAll(dest, destLen, source, sourceLen, Z_SYNC_FLUSH); +} + +int ZLib::CompressAtMost(Bytef* dest, uLongf* destLen, const Bytef* source, + uLong* sourceLen) { + return CompressAtMostOrAll(dest, destLen, source, sourceLen, Z_SYNC_FLUSH); +} + +// This writes the gzip footer info, if necessary. +// No matter what, we call Reset() so we can compress Chunks again. +int ZLib::CompressChunkDone(Bytef* dest, uLongf* destLen) { + // Make sure our buffer is of reasonable size. + if (*destLen < MinFooterSize()) { + *destLen = 0; + return Z_BUF_ERROR; + } + + // The underlying zlib library requires a non-nullptr source pointer, even if + // the source length is zero, otherwise it will generate an (incorrect) zero- + // valued CRC checksum. + char dummy = '\0'; + int err; + + assert(!first_chunk_ && comp_init_); + + const uLongf orig_destLen = *destLen; + // NOLINTNEXTLINE + if ((err = CompressChunkOrAll(dest, destLen, (const Bytef*)&dummy, 0, + Z_FINISH)) != Z_OK) { + Reset(); // we assume they won't retry on error + return err; + } + + // Make sure that when we exit, we can start a new round of chunks later + // (This must be set after the call to CompressChunkOrAll() above.) + Reset(); + + // Write gzip footer if necessary. They're explicitly in little-endian order + if (settings_.gzip_header_mode_) { + if ((err = WriteGzipFooter(dest + *destLen, orig_destLen - *destLen)) != + Z_OK) + return err; + *destLen += 8; // zlib footer took up another 8 bytes + } + return Z_OK; // stream_end is ok +} + +// This routine only initializes the compression stream once. Thereafter, it +// just does a deflateReset on the stream, which should be faster. +int ZLib::Compress(Bytef* dest, uLongf* destLen, const Bytef* source, + uLong sourceLen) { + int err; + const uLongf orig_destLen = *destLen; + if ((err = CompressChunkOrAll(dest, destLen, source, sourceLen, Z_FINISH)) != + Z_OK) + return err; + Reset(); // reset for next call to Compress + + if (settings_.gzip_header_mode_) { + if ((err = WriteGzipFooter(dest + *destLen, orig_destLen - *destLen)) != + Z_OK) + return err; + *destLen += 8; // zlib footer took up another 8 bytes + } + + return Z_OK; +} + +// --------- UNCOMPRESS MODE + +int ZLib::InflateInit() { + // Use negative window bits size to indicate bare stream with no header. + int wbits = (settings_.no_header_mode_ ? -MAX_WBITS : MAX_WBITS); + int err = inflateInit2(&uncomp_stream_, wbits); + if (err == Z_OK) { + init_settings_.no_header_mode_ = settings_.no_header_mode_; + } + return err; +} + +// Initialization method to be called if we hit an error while +// uncompressing. On hitting an error, call this method before +// returning the error. +void ZLib::UncompressErrorInit() { + if (uncomp_init_) { + inflateEnd(&uncomp_stream_); + uncomp_init_ = false; + } + Reset(); +} + +int ZLib::UncompressInit(Bytef* dest, uLongf* destLen, const Bytef* source, + uLong* sourceLen) { + int err; + + uncomp_stream_.next_in = (Bytef*)source; // NOLINT + uncomp_stream_.avail_in = (uInt)*sourceLen; + // Check for sourceLen (unsigned long) to fit into avail_in (unsigned int). + if ((uLong)uncomp_stream_.avail_in != *sourceLen) return Z_BUF_ERROR; + + uncomp_stream_.next_out = dest; + uncomp_stream_.avail_out = (uInt)*destLen; + // Check for destLen (unsigned long) to fit into avail_out (unsigned int). + if ((uLong)uncomp_stream_.avail_out != *destLen) return Z_BUF_ERROR; + + if (!first_chunk_) // only need to set up stream the first time through + return Z_OK; + + // Force full reinit if properties have changed in a way we can't adjust. + if (uncomp_init_ && (init_settings_.dictionary_ != settings_.dictionary_ || + init_settings_.dict_len_ != settings_.dict_len_)) { + inflateEnd(&uncomp_stream_); + uncomp_init_ = false; + } + + // Reuse if we've already initted the object. + if (uncomp_init_) { + // Use negative window bits size to indicate bare stream with no header. + int wbits = (settings_.no_header_mode_ ? -MAX_WBITS : MAX_WBITS); + err = inflateReset2(&uncomp_stream_, wbits); + if (err == Z_OK) { + init_settings_.no_header_mode_ = settings_.no_header_mode_; + } else { + UncompressErrorInit(); + } + } + + // First use or previous state was not reusable with current settings. + if (!uncomp_init_) { + uncomp_stream_.zalloc = (alloc_func)0; + uncomp_stream_.zfree = (free_func)0; + uncomp_stream_.opaque = (voidpf)0; + err = InflateInit(); + if (err != Z_OK) return err; + uncomp_init_ = true; + } + return Z_OK; +} + +// If you compressed your data a chunk at a time, with CompressChunk, +// you can uncompress it a chunk at a time with UncompressChunk. +// Only difference bewteen chunked and unchunked uncompression +// is the flush mode we use: Z_SYNC_FLUSH (chunked) or Z_FINISH (unchunked). +int ZLib::UncompressAtMostOrAll(Bytef* dest, uLongf* destLen, + const Bytef* source, uLong* sourceLen, + int flush_mode) { // Z_SYNC_FLUSH or Z_FINISH + int err = Z_OK; + + if (first_chunk_) { + gzip_footer_bytes_ = -1; + if (settings_.gzip_header_mode_) { + // If we haven't read our first chunk of actual compressed data, + // and we're expecting gzip headers, then parse some more bytes + // from the gzip headers. + const Bytef* bodyBegin = nullptr; + GZipHeader::Status status = gzip_header_->ReadMore( + reinterpret_cast(source), *sourceLen, + reinterpret_cast(&bodyBegin)); + switch (status) { + case GZipHeader::INCOMPLETE_HEADER: // don't have the complete header + *destLen = 0; + *sourceLen = 0; // GZipHeader used all the input + return Z_OK; + case GZipHeader::INVALID_HEADER: // bogus header + Reset(); + return Z_DATA_ERROR; + case GZipHeader::COMPLETE_HEADER: // we have the full header + *sourceLen -= (bodyBegin - source); // skip past header bytes + source = bodyBegin; + crc_ = crc32(0, nullptr, 0); // initialize CRC + break; + default: + LOG(FATAL) << "Unexpected gzip header parsing result: " << status; + } + } + } else if (gzip_footer_bytes_ >= 0) { + // We're now just reading the gzip footer. We already read all the data. + if (gzip_footer_bytes_ + *sourceLen > sizeof(gzip_footer_) && + // When this flag is true, we allow some extra bytes after the + // gzip footer. + !should_be_flexible_with_gzip_footer_) { + VLOG(1) << "UncompressChunkOrAll: Received " + << (gzip_footer_bytes_ + *sourceLen - sizeof(gzip_footer_)) + << " extra bytes after gzip footer: " + << std::string(reinterpret_cast(source), + std::min(*sourceLen, 20UL)); + Reset(); + return Z_DATA_ERROR; + } + uLong len = sizeof(gzip_footer_) - gzip_footer_bytes_; + if (len > *sourceLen) len = *sourceLen; + if (len > 0) { + memcpy(gzip_footer_ + gzip_footer_bytes_, source, len); + gzip_footer_bytes_ += len; + } + *sourceLen -= len; + *destLen = 0; + return Z_OK; + } + + if ((err = UncompressInit(dest, destLen, source, sourceLen)) != Z_OK) { + LOG(WARNING) << "ZLib: UncompressInit: Error: " << err + << "SourceLen: " << *sourceLen; + return err; + } + + // This is used to figure out how many output bytes we wrote *this chunk*: + const uLong old_total_out = uncomp_stream_.total_out; + + // This is used to figure out how many input bytes we read *this chunk*: + const uLong old_total_in = uncomp_stream_.total_in; + + // Some setup happens only for the first chunk we compress in a run + if (first_chunk_) { + // Initialize the dictionary just before we start compressing + if (settings_.gzip_header_mode_ || settings_.no_header_mode_) { + // In no_header_mode, we can just set the dictionary, since no + // checking is done to advance past header bits to get us in the + // dictionary setting mode. In settings_.gzip_header_mode_ we've already + // removed headers, so this code works too. + if (settings_.dictionary_) { + err = inflateSetDictionary(&uncomp_stream_, settings_.dictionary_, + settings_.dict_len_); + if (err != Z_OK) { + LOG(WARNING) << "inflateSetDictionary: Error: " << err + << " dict_len: " << settings_.dict_len_; + UncompressErrorInit(); + return err; + } + init_settings_.dictionary_ = settings_.dictionary_; + init_settings_.dict_len_ = settings_.dict_len_; + } + } + + first_chunk_ = false; // so we don't do this again + + // For the first chunk *only* (to avoid infinite troubles), we let + // there be no actual data to uncompress. This sometimes triggers + // when the input is only the gzip header, say. + if (*sourceLen == 0) { + *destLen = 0; + return Z_OK; + } + } + + // We'll uncompress as much as we can. If we end OK great, otherwise + // if we get an error that seems to be the gzip footer, we store the + // gzip footer and return OK, otherwise we return the error. + + // flush_mode is Z_SYNC_FLUSH for chunked mode, Z_FINISH for all mode. + err = inflate(&uncomp_stream_, flush_mode); + if (settings_.dictionary_ && err == Z_NEED_DICT) { + err = inflateSetDictionary(&uncomp_stream_, settings_.dictionary_, + settings_.dict_len_); + if (err != Z_OK) { + LOG(WARNING) << "UncompressChunkOrAll: failed in inflateSetDictionary : " + << err; + UncompressErrorInit(); + return err; + } + init_settings_.dictionary_ = settings_.dictionary_; + init_settings_.dict_len_ = settings_.dict_len_; + err = inflate(&uncomp_stream_, flush_mode); + } + + // Figure out how many bytes of the input zlib slurped up: + const uLong bytes_read = uncomp_stream_.total_in - old_total_in; + CHECK_LE(source + bytes_read, source + *sourceLen); + *sourceLen = uncomp_stream_.avail_in; + + // Next we look at the footer, if any. Note that we might currently + // have just part of the footer (eg, if this data is arriving over a + // socket). After looking for a footer, log a warning if there is + // extra cruft. + if ((err == Z_STREAM_END) && + ((gzip_footer_bytes_ == -1) || + (gzip_footer_bytes_ < sizeof(gzip_footer_))) && + (uncomp_stream_.avail_in <= sizeof(gzip_footer_) || + // When this flag is true, we allow some extra bytes after the + // zlib footer. + should_be_flexible_with_gzip_footer_)) { + // Due to a bug in old versions of zlibwrapper, we appended the gzip + // footer even in non-gzip mode. Thus we always allow a gzip footer + // even if we're not in gzip mode, so we can continue to uncompress + // the old data. :-( + + // Store gzip footer bytes so we can check for footer consistency + // in UncompressChunkDone(). (If we have the whole footer, we + // could do the checking here, but we don't to keep consistency + // with CompressChunkDone().) + gzip_footer_bytes_ = std::min(static_cast(uncomp_stream_.avail_in), + sizeof(gzip_footer_)); + memcpy(gzip_footer_, source + bytes_read, gzip_footer_bytes_); + *sourceLen -= gzip_footer_bytes_; + } else if ((err == Z_STREAM_END || err == Z_OK) // everything went ok + && uncomp_stream_.avail_in == 0) { // and we read it all + {} + } else if (err == Z_STREAM_END && uncomp_stream_.avail_in > 0) { + VLOG(1) << "UncompressChunkOrAll: Received some extra data, bytes total: " + << uncomp_stream_.avail_in << " bytes: " + << std::string( + reinterpret_cast(uncomp_stream_.next_in), + std::min(static_cast(uncomp_stream_.avail_in), 20)); + UncompressErrorInit(); + return Z_DATA_ERROR; // what's the extra data for? + } else if (err != Z_OK && err != Z_STREAM_END && err != Z_BUF_ERROR) { + // an error happened + VLOG(1) << "UncompressChunkOrAll: Error: " << err + << " avail_out: " << uncomp_stream_.avail_out; + UncompressErrorInit(); + return err; + } else if (uncomp_stream_.avail_out == 0) { + err = Z_BUF_ERROR; + } + + assert(err == Z_OK || err == Z_BUF_ERROR || err == Z_STREAM_END); + if (err == Z_STREAM_END && !settings_.dont_hide_zstream_end_) err = Z_OK; + + // update the crc and other metadata + uncompressed_size_ = uncomp_stream_.total_out; + *destLen = uncomp_stream_.total_out - old_total_out; // size for this call + if (settings_.gzip_header_mode_) crc_ = crc32(crc_, dest, *destLen); + + return err; +} + +int ZLib::UncompressChunkOrAll(Bytef* dest, uLongf* destLen, + const Bytef* source, uLong sourceLen, + int flush_mode) { // Z_SYNC_FLUSH or Z_FINISH + const int ret = + UncompressAtMostOrAll(dest, destLen, source, &sourceLen, flush_mode); + if (ret == Z_BUF_ERROR) UncompressErrorInit(); + return ret; +} + +int ZLib::UncompressAtMost(Bytef* dest, uLongf* destLen, const Bytef* source, + uLong* sourceLen) { + return UncompressAtMostOrAll(dest, destLen, source, sourceLen, Z_SYNC_FLUSH); +} + +int ZLib::UncompressChunk(Bytef* dest, uLongf* destLen, const Bytef* source, + uLong sourceLen) { + return UncompressChunkOrAll(dest, destLen, source, sourceLen, Z_SYNC_FLUSH); +} + +// We make sure we've uncompressed everything, that is, the current +// uncompress stream is at a compressed-buffer-EOF boundary. In gzip +// mode, we also check the gzip footer to make sure we pass the gzip +// consistency checks. We RETURN true iff both types of checks pass. +bool ZLib::UncompressChunkDone() { + if (first_chunk_ || !uncomp_init_) { + return false; + } + // Make sure we're at the end-of-compressed-data point. This means + // if we call inflate with Z_FINISH we won't consume any input or + // write any output + Bytef dummyin, dummyout; + uLongf dummylen = 0; + if (UncompressChunkOrAll(&dummyout, &dummylen, &dummyin, 0, Z_FINISH) != + Z_OK) { + return false; + } + + // Make sure that when we exit, we can start a new round of chunks later + Reset(); + + // We don't need to check footer when this flag is true. + if (should_be_flexible_with_gzip_footer_) { + return true; + } + + // Whether we were hoping for a gzip footer or not, we allow a gzip + // footer. (See the note above about bugs in old zlibwrappers.) But + // by the time we've seen all the input, it has to be either a + // complete gzip footer, or no footer at all. + if ((gzip_footer_bytes_ != -1) && (gzip_footer_bytes_ != 0) && + (gzip_footer_bytes_ != sizeof(gzip_footer_))) + return false; + + if (!settings_.gzip_header_mode_) return true; + + return IsGzipFooterValid(); +} + +bool ZLib::IsGzipFooterValid() const { + // If we were expecting a gzip footer, and didn't get a full one, + // that's an error. + if (gzip_footer_bytes_ == -1 || gzip_footer_bytes_ < sizeof(gzip_footer_)) + return false; + + // The footer holds the lower four bytes of the length. + uLong uncompressed_size = 0; + uncompressed_size += static_cast(gzip_footer_[7]) << 24; + uncompressed_size += gzip_footer_[6] << 16; + uncompressed_size += gzip_footer_[5] << 8; + uncompressed_size += gzip_footer_[4] << 0; + if (uncompressed_size != (uncompressed_size_ & 0xffffffff)) { + return false; + } + + uLong checksum = 0; + checksum += static_cast(gzip_footer_[3]) << 24; + checksum += gzip_footer_[2] << 16; + checksum += gzip_footer_[1] << 8; + checksum += gzip_footer_[0] << 0; + if (crc_ != checksum) return false; + + return true; +} + +// Uncompresses the source buffer into the destination buffer. +// The destination buffer must be long enough to hold the entire +// decompressed contents. +// +// We only initialize the uncomp_stream once. Thereafter, we use +// inflateReset2, which should be faster. +// +// Returns Z_OK on success, otherwise, it returns a zlib error code. +int ZLib::Uncompress(Bytef* dest, uLongf* destLen, const Bytef* source, + uLong sourceLen) { + int err; + if ((err = UncompressChunkOrAll(dest, destLen, source, sourceLen, + Z_FINISH)) != Z_OK) { + Reset(); // let us try to compress again + return err; + } + if (!UncompressChunkDone()) // calls Reset() + return Z_DATA_ERROR; + return Z_OK; // stream_end is ok +} + +// read uncompress length from gzip footer +uLongf ZLib::GzipUncompressedLength(const Bytef* source, uLong len) { + if (len <= 4) return 0; // malformed data. + + return (static_cast(source[len - 1]) << 24) + + (static_cast(source[len - 2]) << 16) + + (static_cast(source[len - 3]) << 8) + + (static_cast(source[len - 4]) << 0); +} + +int ZLib::UncompressGzipAndAllocate(Bytef** dest, uLongf* destLen, + const Bytef* source, uLong sourceLen) { + *dest = nullptr; // until we successfully allocate + if (!settings_.gzip_header_mode_) return Z_VERSION_ERROR; // *shrug* + + uLongf uncompress_length = GzipUncompressedLength(source, sourceLen); + + // Do not trust the uncompress size reported by the compressed buffer. + if (uncompress_length > *destLen) { + if (!HasGzipHeader(reinterpret_cast(source), sourceLen)) { + VLOG(1) << "Attempted to un-gzip data that is not gzipped."; + return Z_DATA_ERROR; + } + VLOG(1) << "Uncompressed size " << uncompress_length + << " exceeds maximum expected size " << *destLen; + return Z_MEM_ERROR; // probably a corrupted gzip buffer + } + + *destLen = uncompress_length; + + *dest = (Bytef*)malloc(*destLen); // NOLINT + if (*dest == nullptr) // probably a corrupted gzip buffer + return Z_MEM_ERROR; + + const int retval = Uncompress(*dest, destLen, source, sourceLen); + if (retval != Z_OK) { // just to make life easier for them + free(*dest); + *dest = nullptr; + } + return retval; +} + +// Convenience method to check if a bytestream has a header. This +// is intended as a quick test: "Is this likely a GZip file?" +bool ZLib::HasGzipHeader(const char* source, int sourceLen) { + GZipHeader gzh; + const char* ptr = nullptr; + return gzh.ReadMore(source, sourceLen, &ptr) == GZipHeader::COMPLETE_HEADER; +} + +} // namespace csrblocksparse diff --git a/sparse_matmul/zlib_wrapper/zlibwrapper.h b/sparse_matmul/zlib_wrapper/zlibwrapper.h new file mode 100644 index 00000000..22e3980e --- /dev/null +++ b/sparse_matmul/zlib_wrapper/zlibwrapper.h @@ -0,0 +1,320 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef THIRD_PARTY_LYRA_CODEC_SPARSE_MATMUL_ZLIB_ZLIBWRAPPER_H +#define THIRD_PARTY_LYRA_CODEC_SPARSE_MATMUL_ZLIB_ZLIBWRAPPER_H + +#include "zlib.h" + +namespace csrblocksparse { + +class GZipHeader; + +class ZLib { + public: + ZLib(); + ~ZLib(); + + // Set this to true if you want to be flexible with the gzip footer. + static void set_should_be_flexible_with_gzip_footer(bool b) { + should_be_flexible_with_gzip_footer_ = b; + } + + static bool should_be_flexible_with_gzip_footer() { + return should_be_flexible_with_gzip_footer_; + } + + // Wipe a ZLib object to a virgin state. This differs from Reset() + // in that it also breaks any dictionary, gzip, etc, state. + void Reinit(); + + // Call this to make a zlib buffer as good as new. Here's the only + // case where they differ: + // CompressChunk(a); CompressChunk(b); CompressChunkDone(); vs + // CompressChunk(a); Reset(); CompressChunk(b); CompressChunkDone(); + // You'll want to use Reset(), then, when you interrupt a compress + // (or uncompress) in the middle of a chunk and want to start over. + void Reset(); + + // Sets no_header_mode appropriately. Note that using NoHeaderMode + // in conjunction with a preset dictionary is not supported (zlib + // starts behaving oddly if you try to do this). + void SetNoHeaderMode(bool no_header_mode); + + // Returns our current no_header_mode. + bool no_header_mode() const { return settings_.no_header_mode_; } + + // Uses a gzip header/footer; the output is a valid gzip file. + // This also causes us to generate a crc32 checksum used with gzip + void SetGzipHeaderMode(); + + // By default UncompressAtMostOrAll will return Z_OK upon hitting the end of + // the input stream. This function modifies that behavior by returning + // Z_STREAM_END instead. This is useful when getting multiple compressed + // documents in a single stream. Returning Z_STREAM_END will indicate the end + // of a document. + void SetDontHideStreamEnd(); + + // Sets the compression level to be used + void SetCompressionLevel(int level) { settings_.compression_level_ = level; } + + // Sets the size of the window (history buffer) used by the compressor. + // The size is expressed in bits (log base 2 of the desired size). + void SetCompressionWindowSizeInBits(int bits) { + settings_.window_bits_ = bits; + } + + // Controls the amount of memory used by the compresser. + // Legal value are 1 through 9. See zlib.h for more info. + void SetCompressionMemLevel(int level) { settings_.mem_level_ = level; } + + // Sets the initial dictionary to be used for decompression. + void SetDictionary(const char* initial_dict, unsigned int dict_len); + + // According to the zlib manual, when you Compress, the destination + // buffer must have size at least src + .1%*src + 12. This function + // helps you calculate that. Augment this to account for a potential + // gzip header and footer, plus a few bytes of slack. + static uLong MinCompressbufSize(uLong uncompress_size) { + return uncompress_size + uncompress_size / 1000 + 40; + } + + // The minimum size of footers written by CompressChunkDone(). + int MinFooterSize() const; + + // Compresses the source buffer into the destination buffer. + // sourceLen is the byte length of the source buffer. + // Upon entry, destLen is the total size of the destination buffer, + // which must be of size at least MinCompressbufSize(sourceLen). + // Upon exit, destLen is the actual size of the compressed buffer. + // + // This function can be used to compress a whole file at once if the + // input file is mmap'ed. + // + // Returns Z_OK if success, Z_MEM_ERROR if there was not + // enough memory, Z_BUF_ERROR if there was not enough room in the + // output buffer. Note that if the output buffer is exactly the same + // size as the compressed result, we still return Z_BUF_ERROR. + // (check CL#1936076) + // + // If the values of *destLen or sourceLen do not fit in an unsigned int, + // Z_BUF_ERROR is returned. + int Compress(Bytef* dest, uLongf* destLen, const Bytef* source, + uLong sourceLen); + + // Uncompresses the source buffer into the destination buffer. + // The destination buffer must be long enough to hold the entire + // decompressed contents. + // + // Returns Z_OK on success, otherwise, it returns a zlib error code. + // + // If the values of *destLen or sourceLen do not fit in an unsigned int, + // Z_BUF_ERROR is returned. + int Uncompress(Bytef* dest, uLongf* destLen, const Bytef* source, + uLong sourceLen); + + // Get the uncompressed size from the gzip header. Returns 0 if source is too + // short (len < 5). + uLongf GzipUncompressedLength(const Bytef* source, uLong len); + + // Special helper function to help uncompress gzipped documents: + // We'll allocate (with malloc) a destination buffer exactly big + // enough to hold the gzipped content. We set dest and destLen. + // If we don't return Z_OK, *dest will be NULL, otherwise you + // should free() it when you're done with it. + // Returns Z_OK on success, otherwise, it returns a zlib error code. + // Its the responsibility of the user to set *destLen to the + // expected maximum size of the uncompressed data. The size of the + // uncompressed data is read from the compressed buffer gzip footer. + // This value cannot be trusted, so we compare it to the expected + // maximum size supplied by the user, returning Z_MEM_ERROR if its + // greater than the expected maximum size. + int UncompressGzipAndAllocate(Bytef** dest, uLongf* destLen, + const Bytef* source, uLong sourceLen); + + // Streaming compression and decompression methods come in two + // variations. {Unc,C}ompressAtMost() and {Unc,C}ompressChunk(). + // The former decrements sourceLen by the amount of data that was + // consumed: if it returns Z_BUF_ERROR, set the source of the next + // {Unc,C}ompressAtMost() to the unconsumed data. + // {Unc,C}ompressChunk() is the legacy interface and does not do + // this, thus it cannot recover from a Z_BUF_ERROR (except for in + // the first chunk). + + // Compresses data one chunk at a time -- ie you can call this more + // than once. This is useful for a webserver, for instance, which + // might want to use chunked encoding with compression. To get this + // to work you need to call start and finish routines. + // + // Returns Z_OK if success, Z_MEM_ERROR if there was not + // enough memory, Z_BUF_ERROR if there was not enough room in the + // output buffer. + + int CompressAtMost(Bytef* dest, uLongf* destLen, const Bytef* source, + uLong* sourceLen); + + int CompressChunk(Bytef* dest, uLongf* destLen, const Bytef* source, + uLong sourceLen); + + // Emits gzip footer information, as needed. + // destLen should be at least MinFooterSize() long. + // Returns Z_OK, Z_MEM_ERROR, and Z_BUF_ERROR as in CompressChunk(). + int CompressChunkDone(Bytef* dest, uLongf* destLen); + + // Uncompress data one chunk at a time -- ie you can call this + // more than once. To get this to work you need to call per-chunk + // and "done" routines. + // + // Returns Z_OK if success, Z_MEM_ERROR if there was not + // enough memory, Z_BUF_ERROR if there was not enough room in the + // output buffer. + + int UncompressAtMost(Bytef* dest, uLongf* destLen, const Bytef* source, + uLong* sourceLen); + int UncompressChunk(Bytef* dest, uLongf* destLen, const Bytef* source, + uLong sourceLen); + + // Checks gzip footer information, as needed. Mostly this just + // makes sure the checksums match. Whenever you call this, it + // will assume the last 8 bytes from the previous UncompressChunk + // call are the footer. Returns true iff everything looks ok. + bool UncompressChunkDone(); + + // Only meaningful for chunked compressing/uncompressing. It's true + // after initialization or reset and before the first chunk of + // user data is received. + bool first_chunk() const { return first_chunk_; } + + // Returns a pointer to our current dictionary: + const Bytef* dictionary() const { return settings_.dictionary_; } + + // Convenience method to check if a bytestream has a header. This + // is intended as a quick test: "Is this likely a GZip file?" + static bool HasGzipHeader(const char* source, int sourceLen); + + // Have we parsed the complete gzip footer, and does it match the + // length and CRC checksum of the content that we have uncompressed + // so far? + bool IsGzipFooterValid() const; + + // Accessor for the uncompressed size (first added to address issue #509976) + uLong uncompressed_size() const { return uncompressed_size_; } + + private: + int InflateInit(); // sets up the zlib inflate structure + int DeflateInit(); // sets up the zlib deflate structure + + // These init the zlib data structures for compressing/uncompressing + int CompressInit(Bytef* dest, uLongf* destLen, const Bytef* source, + uLong* sourceLen); + int UncompressInit(Bytef* dest, uLongf* destLen, const Bytef* source, + uLong* sourceLen); + // Initialization method to be called if we hit an error while + // uncompressing. On hitting an error, call this method before + // returning the error. + void UncompressErrorInit(); + // Helper functions to write gzip-specific data + int WriteGzipHeader(); + int WriteGzipFooter(Bytef* dest, uLongf destLen); + + // Helper function for both Compress and CompressChunk + int CompressChunkOrAll(Bytef* dest, uLongf* destLen, const Bytef* source, + uLong sourceLen, int flush_mode); + int CompressAtMostOrAll(Bytef* dest, uLongf* destLen, const Bytef* source, + uLong* sourceLen, int flush_mode); + + // Likewise for UncompressAndUncompressChunk + int UncompressChunkOrAll(Bytef* dest, uLongf* destLen, const Bytef* source, + uLong sourceLen, int flush_mode); + + int UncompressAtMostOrAll(Bytef* dest, uLongf* destLen, const Bytef* source, + uLong* sourceLen, int flush_mode); + + // Initialization method to be called if we hit an error while + // compressing. On hitting an error, call this method before + // returning the error. + void CompressErrorInit(); + + // Makes sure the parameters are valid + void CheckValidParams(); + + struct Settings { + // null if we don't want an initial dictionary + Bytef* dictionary_; // NOLINT + + // initial dictionary length + unsigned int dict_len_; // NOLINT + + // compression level + int compression_level_; // NOLINT + + // log base 2 of the window size used in compression + int window_bits_; // NOLINT + + // specifies the amount of memory to be used by compressor (1-9) + int mem_level_; // NOLINT + + // true if we want/expect no zlib headers + bool no_header_mode_; // NOLINT + + // true if we want/expect gzip headers + bool gzip_header_mode_; // NOLINT + + // Controls behavior of UncompressAtMostOrAll with regards to returning + // Z_STREAM_END. See comments for SetDontHideStreamEnd. + bool dont_hide_zstream_end_; // NOLINT + }; + + // We allow all kinds of bad footers when this flag is true. + // Some web servers send bad pages corresponding to these cases + // and IE is tolerant with it. + // - Extra bytes after gzip footer (see bug 69126) + // - No gzip footer (see bug 72896) + // - Incomplete gzip footer (see bug 71871706) + static bool should_be_flexible_with_gzip_footer_; + + // "Current" settings. These will be used whenever we next configure zlib. + // For example changing compression level or header mode will be recorded + // in these, but don't usually get applied immediately but on next compress. + Settings settings_; + + // Settings last used to initialise and configure zlib. These are needed + // to know if the current desired configuration in settings_ is sufficiently + // compatible with the previous configuration and we can just reconfigure the + // underlying zlib objects, or have to recreate them from scratch. + Settings init_settings_; + + z_stream comp_stream_; // Zlib stream data structure + bool comp_init_; // True if we have initialized comp_stream_ + z_stream uncomp_stream_; // Zlib stream data structure + bool uncomp_init_; // True if we have initialized uncomp_stream_ + + // These are used only in gzip compression mode + uLong crc_; // stored in gzip footer, fitting 4 bytes + uLong uncompressed_size_; + + GZipHeader* gzip_header_; // our gzip header state + + Byte gzip_footer_[8]; // stored footer, used to uncompress + int gzip_footer_bytes_; // num of footer bytes read so far, or -1 + + // These are used only with chunked compression. + bool first_chunk_; // true if we need to emit headers with this chunk +}; + +} // namespace csrblocksparse + +#endif // THIRD_PARTY_LYRA_CODEC_SPARSE_MATMUL_ZLIB_ZLIBWRAPPER_H diff --git a/toolchain/BUILD b/toolchain/BUILD index dc31f3df..cbb209d8 100644 --- a/toolchain/BUILD +++ b/toolchain/BUILD @@ -12,14 +12,6 @@ filegroup( output_licenses = ["unencumbered"], ) -filegroup( - name = "compiler_files", - srcs = [ - ":includes", - ], - output_licenses = ["unencumbered"], -) - cc_toolchain( name = "k8_toolchain", all_files = ":empty", diff --git a/toolchain/README.md b/toolchain/README.md new file mode 100644 index 00000000..c6a33f9d --- /dev/null +++ b/toolchain/README.md @@ -0,0 +1,38 @@ +# Clang/libc++ toolchain setup + +These instructions are for building Lyra with clang using +`--config=clang_toolchain`. + +This is not necessary for most users, who will be fine using the default +toolchain (likely gcc). The clang toolchain is provided as a reference for +debugging on Linux, since the android NDK also requires the use of clang/libc++. + +You can use a default clang installed from your package manager. It should be a +version of clang that is at least 11.0. + +Optionally, you can install a certain version of clang and libc++ from source +with a recipe like the following: + +```shell +git clone https://github.com/llvm/llvm-project.git +cd llvm-project +git checkout 96ef4f307df2 + +mkdir build_clang +cd build_clang +cmake -G Ninja -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DLLVM_ENABLE_PROJECTS="clang" -DCMAKE_BUILD_TYPE=release ../llvm +ninja +sudo $(which ninja) install + +cd .. +mkdir build_libcxx +cd build_libcxx +cmake -G Ninja -DCMAKE_C_COMPILER=/usr/local/bin/clang -DCMAKE_CXX_COMPILER=/usr/local/bin/clang++ -DLLVM_ENABLE_PROJECTS="libcxx;libcxxabi" -DCMAKE_BUILD_TYPE=release ../llvm +ninja +sudo $(which ninja) install + +sudo ldconfig +``` + +Note: the above will install a particular version of libc++ to /usr/local/lib, +and clang to /usr/local/bin, which the toolchain depends on. diff --git a/transpose_convolutional_layer_wrapper.h b/transpose_convolutional_layer_wrapper.h index cc4c2e0e..fa25b276 100644 --- a/transpose_convolutional_layer_wrapper.h +++ b/transpose_convolutional_layer_wrapper.h @@ -24,7 +24,7 @@ #include "absl/memory/memory.h" #include "glog/logging.h" #include "layer_wrapper.h" -#include "sparse_inference_matrixvector.h" +#include "sparse_matmul/sparse_matmul.h" namespace chromemedia { namespace codec { diff --git a/transpose_convolutional_layer_wrapper_test.cc b/transpose_convolutional_layer_wrapper_test.cc index ae0e415b..a48bde6f 100644 --- a/transpose_convolutional_layer_wrapper_test.cc +++ b/transpose_convolutional_layer_wrapper_test.cc @@ -17,13 +17,13 @@ #include #include -// placeholder for get runfiles header. +// Placeholder for get runfiles header. #include "gmock/gmock.h" #include "gtest/gtest.h" #include "include/ghc/filesystem.hpp" #include "layer_wrapper.h" #include "layer_wrapper_test_common.h" -#include "sparse_inference_matrixvector.h" +#include "sparse_matmul/sparse_matmul.h" namespace chromemedia { namespace codec { diff --git a/vector_quantizer_impl.cc b/vector_quantizer_impl.cc index de2d04e9..125740a0 100644 --- a/vector_quantizer_impl.cc +++ b/vector_quantizer_impl.cc @@ -29,7 +29,7 @@ #include "audio/dsp/signal_vector_util.h" #include "glog/logging.h" #include "include/ghc/filesystem.hpp" -#include "sparse_inference_matrixvector.h" +#include "sparse_matmul/sparse_matmul.h" namespace chromemedia { namespace codec { diff --git a/vector_quantizer_impl_test.cc b/vector_quantizer_impl_test.cc index 65b42e42..ea8accd5 100644 --- a/vector_quantizer_impl_test.cc +++ b/vector_quantizer_impl_test.cc @@ -23,7 +23,7 @@ #include #include -// placeholder for get runfiles header. +// Placeholder for get runfiles header. #include "Eigen/Core" #include "absl/memory/memory.h" #include "absl/strings/string_view.h" diff --git a/wav_util_test.cc b/wav_util_test.cc index 0985ea1b..670e3b3c 100644 --- a/wav_util_test.cc +++ b/wav_util_test.cc @@ -14,8 +14,8 @@ #include "wav_util.h" -// placeholder for get runfiles header. -// placeholder for testing header. +// Placeholder for get runfiles header. +// Placeholder for testing header. #include "absl/flags/flag.h" #include "absl/status/statusor.h" #include "gtest/gtest.h" diff --git a/wavegru_model_impl.cc b/wavegru_model_impl.cc index cea00d8b..e690fa01 100644 --- a/wavegru_model_impl.cc +++ b/wavegru_model_impl.cc @@ -27,7 +27,7 @@ #include "glog/logging.h" #include "include/ghc/filesystem.hpp" #include "lyra_wavegru.h" -#include "sparse_inference_matrixvector.h" +#include "sparse_matmul/sparse_matmul.h" // IWYU pragma: no_include "speech/greco3/core/thread.h" #include "absl/memory/memory.h" #include "absl/status/status.h" @@ -42,7 +42,7 @@ namespace codec { std::unique_ptr WavegruModelImpl::Create( int num_samples_per_hop, int num_features, int num_frames_per_packet, - const ghc::filesystem::path& model_path) { + float silence_value, const ghc::filesystem::path& model_path) { const int kNumThreads = 1; const int kNumCondHiddens = 512; const std::string kModelPrefix = "lyra_16khz"; @@ -63,13 +63,13 @@ std::unique_ptr WavegruModelImpl::Create( return absl::WrapUnique(new WavegruModelImpl( std::string(model_path), kModelPrefix, kNumThreads, num_features, kNumCondHiddens, num_samples_per_hop, num_frames_per_packet, - std::move(wavegru), std::move(merge_filter))); + silence_value, std::move(wavegru), std::move(merge_filter))); } WavegruModelImpl::WavegruModelImpl( const std::string& model_path, const std::string& model_prefix, int num_threads, int num_features, int num_cond_hiddens, - int num_samples_per_hop, int num_frames_per_packet, + int num_samples_per_hop, int num_frames_per_packet, float silence_value, std::unique_ptr> wavegru, std::unique_ptr buffer_merger) : num_threads_(num_threads), @@ -90,7 +90,7 @@ WavegruModelImpl::WavegruModelImpl( conditioning_ = absl::make_unique( num_features, num_cond_hiddens, wavegru_->num_gru_hiddens(), num_samples_per_hop_, num_frames_per_packet, - /*num_threads=*/1, model_path, model_prefix); + /*num_threads=*/1, silence_value, model_path, model_prefix); } WavegruModelImpl::~WavegruModelImpl() { @@ -130,8 +130,7 @@ absl::optional> WavegruModelImpl::GenerateSamples( wavegru_->SampleThreaded(tid, conditioning_.get(), &model_split_samples_, 0); }; - background_threads_.emplace_back( - absl::make_unique(f)); + background_threads_.emplace_back(absl::make_unique(f)); } } diff --git a/wavegru_model_impl.h b/wavegru_model_impl.h index 85bb3abe..8d7c4137 100644 --- a/wavegru_model_impl.h +++ b/wavegru_model_impl.h @@ -20,6 +20,7 @@ #include #include #include +#include // NOLINT #include #include "absl/types/optional.h" @@ -29,7 +30,7 @@ #include "include/ghc/filesystem.hpp" #include "lyra_types.h" #include "lyra_wavegru.h" -#include "sparse_inference_matrixvector.h" +#include "sparse_matmul/sparse_matmul.h" namespace chromemedia { namespace codec { @@ -40,7 +41,7 @@ class WavegruModelImpl : public GenerativeModelInterface { // Returns a nullptr on failure. static std::unique_ptr Create( int num_samples_per_hop, int num_features, int num_frames_per_packet, - const ghc::filesystem::path& model_path); + float silence_value, const ghc::filesystem::path& model_path); ~WavegruModelImpl() override; @@ -65,6 +66,7 @@ class WavegruModelImpl : public GenerativeModelInterface { const std::string& model_prefix, int num_threads, int num_features, int num_cond_hiddens, int num_samples_per_hop, int num_frames_per_packet, + float silence_value, std::unique_ptr> wavegru, std::unique_ptr buffer_merger); @@ -73,7 +75,7 @@ class WavegruModelImpl : public GenerativeModelInterface { // The direct output samples from the model in the split domain. std::vector> model_split_samples_; - std::vector> background_threads_; + std::vector> background_threads_; std::unique_ptr> wavegru_; std::unique_ptr conditioning_; diff --git a/wavegru_model_impl_test.cc b/wavegru_model_impl_test.cc index e11d7772..c4877ceb 100644 --- a/wavegru_model_impl_test.cc +++ b/wavegru_model_impl_test.cc @@ -19,7 +19,7 @@ #include #include -// placeholder for get runfiles header. +// Placeholder for get runfiles header. #include "gtest/gtest.h" #include "include/ghc/filesystem.hpp" #include "lyra_config.h" @@ -33,7 +33,7 @@ class WavegruModelImplTest : public testing::Test { WavegruModelImplTest() : num_samples_per_hop_(GetNumSamplesPerHop(kInternalSampleRateHz)), model_(WavegruModelImpl::Create( - num_samples_per_hop_, kNumFeatures, kNumFramesPerPacket, + num_samples_per_hop_, kNumFeatures, kNumFramesPerPacket, 0.0f, ghc::filesystem::current_path() / "wavegru")) {} const int num_samples_per_hop_; std::unique_ptr model_;