diff --git a/src/cpu/ref_softmax.cpp b/src/cpu/ref_softmax.cpp index aa707a60ac8..350cac335ae 100644 --- a/src/cpu/ref_softmax.cpp +++ b/src/cpu/ref_softmax.cpp @@ -23,10 +23,10 @@ #include "mkldnn_thread.hpp" #include "ref_softmax.hpp" +#include "gemm/os_blas.hpp" #ifdef USE_MKL #include "mkl_vml_functions.h" -#include "mkl_cblas.h" #endif namespace mkldnn { @@ -97,6 +97,12 @@ void ref_softmax_fwd_t::execute_forward_generic() const { template void ref_softmax_fwd_t::_max(int n, const data_t *x, data_t *max_data) const { +#ifdef USE_CBLAS + if (data_type == data_type::f32) { + max_data[0] = x[cblas_isamax(n, x, 1)]; + return; + } +#endif max_data[0] = x[0]; for (int c = 1; c < n; ++c) max_data[0] = nstl::max(max_data[0], x[c]); @@ -105,8 +111,18 @@ void ref_softmax_fwd_t::_max(int n, const data_t *x, template void ref_softmax_fwd_t::_sub(int n, data_t alpha, const data_t *x, data_t *y) const { - for (int c = 0; c < n; ++c) - y[c] = x[c] - alpha; + constexpr int unroll_factor = 32; + int tail = n % unroll_factor; + for (int i = 0; i < n - tail; i += unroll_factor) { + PRAGMA_OMP_SIMD() + for (int j = 0; j < unroll_factor; j++) { + y[i + j] = x[i + j] - alpha; + } + } + PRAGMA_OMP_SIMD() + for (int i = n - tail; i < n; i++) { + y[i] = x[i] - alpha; + } } template @@ -124,7 +140,7 @@ void ref_softmax_fwd_t::_exp(int n, const data_t *a, template void ref_softmax_fwd_t::_sum(int n, const data_t *x, data_t *sum_data) const { -#ifdef USE_MKL +#ifdef USE_CBLAS // Here we are summing x's eg. e^z , which are positives // so we can use BLAS ASUM if (data_type == data_type::f32) { @@ -141,7 +157,7 @@ void ref_softmax_fwd_t::_sum(int n, const data_t *x, template void ref_softmax_fwd_t::_scal(int n, data_t alpha, data_t *x) const { -#ifdef USE_MKL +#ifdef USE_CBLAS if (data_type == data_type::f32) { cblas_sscal(n, alpha, x, 1); return;