Skip to content

Commit

Permalink
cpu: softmax: improve performance of dense case for GNU compiler
Browse files Browse the repository at this point in the history
  • Loading branch information
densamoilov committed Jan 23, 2019
1 parent e1268a1 commit d36a2f5
Showing 1 changed file with 21 additions and 5 deletions.
26 changes: 21 additions & 5 deletions src/cpu/ref_softmax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@
#include "mkldnn_thread.hpp"

#include "ref_softmax.hpp"
#include "gemm/os_blas.hpp"

#ifdef USE_MKL
#include "mkl_vml_functions.h"
#include "mkl_cblas.h"
#endif

namespace mkldnn {
Expand Down Expand Up @@ -97,6 +97,12 @@ void ref_softmax_fwd_t<data_type>::execute_forward_generic() const {
template <impl::data_type_t data_type>
void ref_softmax_fwd_t<data_type>::_max(int n, const data_t *x,
data_t *max_data) const {
#ifdef USE_CBLAS
if (data_type == data_type::f32) {
max_data[0] = x[cblas_isamax(n, x, 1)];
return;
}
#endif
max_data[0] = x[0];
for (int c = 1; c < n; ++c)
max_data[0] = nstl::max(max_data[0], x[c]);
Expand All @@ -105,8 +111,18 @@ void ref_softmax_fwd_t<data_type>::_max(int n, const data_t *x,
template <impl::data_type_t data_type>
void ref_softmax_fwd_t<data_type>::_sub(int n, data_t alpha, const data_t *x,
data_t *y) const {
for (int c = 0; c < n; ++c)
y[c] = x[c] - alpha;
constexpr int unroll_factor = 32;
int tail = n % unroll_factor;
for (int i = 0; i < n - tail; i += unroll_factor) {
PRAGMA_OMP_SIMD()
for (int j = 0; j < unroll_factor; j++) {
y[i + j] = x[i + j] - alpha;
}
}
PRAGMA_OMP_SIMD()
for (int i = n - tail; i < n; i++) {
y[i] = x[i] - alpha;
}
}

template <impl::data_type_t data_type>
Expand All @@ -124,7 +140,7 @@ void ref_softmax_fwd_t<data_type>::_exp(int n, const data_t *a,
template <impl::data_type_t data_type>
void ref_softmax_fwd_t<data_type>::_sum(int n, const data_t *x,
data_t *sum_data) const {
#ifdef USE_MKL
#ifdef USE_CBLAS
// Here we are summing x's eg. e^z , which are positives
// so we can use BLAS ASUM
if (data_type == data_type::f32) {
Expand All @@ -141,7 +157,7 @@ void ref_softmax_fwd_t<data_type>::_sum(int n, const data_t *x,

template <impl::data_type_t data_type>
void ref_softmax_fwd_t<data_type>::_scal(int n, data_t alpha, data_t *x) const {
#ifdef USE_MKL
#ifdef USE_CBLAS
if (data_type == data_type::f32) {
cblas_sscal(n, alpha, x, 1);
return;
Expand Down

0 comments on commit d36a2f5

Please sign in to comment.