From 75cc3562aa508877f1d852946ee05e8d9e33bf6a Mon Sep 17 00:00:00 2001 From: Andrey Kalinin <andrey.kalinin@intel.com> Date: Sat, 23 Mar 2019 14:05:49 -0700 Subject: [PATCH] cpu: conv: gemm f32 forward: oc/os/ic blocking --- src/common/mkldnn_thread.hpp | 3 +- src/cpu/gemm_convolution.cpp | 197 +++++--- src/cpu/gemm_convolution_utils.cpp | 754 ++++++++++++++++++++--------- src/cpu/gemm_convolution_utils.hpp | 2 +- src/cpu/jit_primitive_conf.hpp | 5 + 5 files changed, 663 insertions(+), 298 deletions(-) diff --git a/src/common/mkldnn_thread.hpp b/src/common/mkldnn_thread.hpp index 8e6f579cd7c..9ece18a8134 100644 --- a/src/common/mkldnn_thread.hpp +++ b/src/common/mkldnn_thread.hpp @@ -125,8 +125,7 @@ void balance2D(U nthr, U ithr, T ny, T &ny_start, T &ny_end, grp = ithr / grp_size_big; grp_ithr = ithr % grp_size_big; grp_nthr = grp_size_big; - } - else { // ithr in last groups + } else { // ithr in last groups grp = n_grp_big + ithr_bound_distance / grp_size_small; grp_ithr = ithr_bound_distance % grp_size_small; grp_nthr = grp_size_small; diff --git a/src/cpu/gemm_convolution.cpp b/src/cpu/gemm_convolution.cpp index 604a728b476..c6dc6088ce4 100644 --- a/src/cpu/gemm_convolution.cpp +++ b/src/cpu/gemm_convolution.cpp @@ -31,6 +31,18 @@ using namespace mkldnn::impl::status; using namespace mkldnn::impl::memory_tracking::names; using namespace mkldnn::impl::utils; +namespace { +struct im_pos_t { + im_pos_t() : n{ 0 }, g{ 0 }, od{ 0 }, sp{ 0 }, ic{ 0 }, oc{ 0 } {} + int n, g, od, sp, ic, oc; + bool do_im2col(const im_pos_t &prev) const { + return true + && (n != prev.n || g != prev.g || od != prev.od || sp != prev.sp + || ic != prev.ic); + } +}; +} // namespace + void gemm_convolution_fwd_t::execute_forward(const exec_ctx_t &ctx) const { auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS); @@ -41,97 +53,144 @@ void gemm_convolution_fwd_t::execute_forward(const exec_ctx_t &ctx) const { const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_; - const int M = jcp.os * jcp.od; const size_t src_step = jcp.ic * jcp.ih * jcp.iw * jcp.id; - const size_t dst_step = jcp.oc * M; - const size_t weights_g_size = jcp.ic * jcp.oc * jcp.ks; + const size_t weights_oc_size = jcp.ic * jcp.ks; + const size_t weights_g_size = weights_oc_size * jcp.oc; assert(IMPLICATION( - jcp.id != 1, jcp.oh_block == jcp.oh && jcp.ow_block == jcp.ow)); - assert(IMPLICATION(jcp.ow_block != jcp.ow, jcp.oh_block == 1)); - - const int K = jcp.ic * jcp.ks; - const int N = jcp.oc; + jcp.id != 1, jcp.os_block == jcp.os && jcp.ic_block == jcp.ic)); if (jcp.im2col_sz && jcp.id != 1) parallel_nd(jcp.im2col_sz * jcp.nthr, [&](ptrdiff_t i) { col[i] = (data_t)0; }); - const int nb_oh = div_up(jcp.oh, jcp.oh_block); - const int nb_ow = div_up(jcp.ow, jcp.ow_block); - const size_t work_amount = jcp.ngroups * jcp.mb * jcp.od * nb_oh * nb_ow; parallel(jcp.nthr, [&](const int ithr, const int nthr) { data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz; - int g{ 0 }, n{ 0 }, od{ 0 }, ohb{ 0 }, owb{ 0 }; - size_t start = 0, end = 0; - - balance211(work_amount, nthr, ithr, start, end); - nd_iterator_init(start, g, jcp.ngroups, n, jcp.mb, od, jcp.od, ohb, - nb_oh, owb, nb_ow); - for (size_t iwork = start; iwork < end; ++iwork) { - int oh = ohb * jcp.oh_block; - int ow = owb * jcp.ow_block; - const data_t *_src = src + (n * jcp.ngroups + g) * src_step; - const data_t *_weights = weights + g * weights_g_size; - data_t *_dst_im = dst + (n * jcp.ngroups + g) * dst_step; - const int h_step = 
nstl::min(jcp.oh_block, jcp.oh - oh); - const int w_step = nstl::min(jcp.ow_block, jcp.ow - ow); - if (jcp.im2col_sz) { + auto inner_ker = [&](int spatial, const im_pos_t &curr, im_pos_t &prev, + im_pos_t &step, const im_pos_t &end) { + const data_t *_src + = src + (curr.n * jcp.ngroups + curr.g) * src_step; + step.oc = nstl::min( + jcp.oc_block, nstl::min(jcp.oc, end.oc) - curr.oc); + step.sp = nstl::min(jcp.os_block, + nstl::min(jcp.os - curr.sp, end.sp - spatial)); + step.ic = nstl::min( + jcp.ic_block, nstl::min(jcp.ic, end.ic) - curr.ic); + bool do_im2col = curr.do_im2col(prev); + prev = curr; + + if (jcp.im2col_sz && do_im2col) { if (jcp.id == 1) - jit_gemm_convolution_utils::im2col( - jcp, _src, _col, oh, h_step, ow, w_step); + jit_gemm_convolution_utils::im2col(jcp, _src, _col, curr.sp, + step.sp, curr.ic, step.ic); else - jit_gemm_convolution_utils::im2col_3d(jcp, _src, _col, od); + jit_gemm_convolution_utils::im2col_3d( + jcp, _src, _col, curr.od); } - const data_t one = 1.0; - const int m = h_step * w_step; + const int M = jcp.os * jcp.od; + const size_t dst_step = jcp.oc * M; + const int m = step.sp; const int LDA = jcp.im2col_sz ? m : M; - data_t *_dst = _dst_im + od * jcp.os + oh * jcp.ow + ow; - - extended_sgemm("N", "N", &m, &N, &K, &one, - jcp.im2col_sz ? _col : _src + od * m, &LDA, _weights, &K, - &this->beta_, _dst, &M); - - data_t *d = _dst; - if (eltwise_) { - // fast branch for ReLU case - if (eltwise_->alg_ == alg_kind::eltwise_relu) { - parallel_nd(jcp.oc, [&](const int oc) { - data_t b = jcp.with_bias ? bias[g * jcp.oc + oc] : 0; - data_t *d_ = d + oc * M; - PRAGMA_OMP_SIMD() - for (int oS = 0; oS < m; ++oS) { - d_[oS] += b; - if (d_[oS] < 0) d_[oS] *= eltwise_->alpha_; - } - }); - } else { - parallel_nd(jcp.oc, [&](const int oc) { - data_t b = jcp.with_bias ? bias[g * jcp.oc + oc] : 0; - data_t *d_ = d + oc * M; + data_t *_dst = dst + (curr.n * jcp.ngroups + curr.g) * dst_step + + curr.oc * M + curr.od * jcp.os + curr.sp; + const int K = step.ic * jcp.ks; + const int LDB = jcp.ic * jcp.ks; + const int N = step.oc; + + // TODO: what if this->beta_ != 0 && != 1 ? + const float beta = (curr.ic == 0) ? this->beta_ : one; + const float *_source = jcp.im2col_sz + ? _col + : _src + curr.ic * M + curr.od * jcp.os + curr.sp; + const data_t *_weights = weights + curr.g * weights_g_size + + curr.oc * weights_oc_size + curr.ic * jcp.ks; + + extended_sgemm("N", "N", &m, &N, &K, &one, _source, &LDA, _weights, + &LDB, &beta, _dst, &M); + if (curr.ic == jcp.ic - step.ic) { + const int oc_start = curr.g * jcp.oc + curr.oc; + if (eltwise_) { + // fast branch for ReLU case + if (eltwise_->alg_ == alg_kind::eltwise_relu) { + parallel_nd(step.oc, [&](const int oc) { + data_t b = jcp.with_bias ? bias[oc_start + oc] : 0; + data_t *d_ = _dst + oc * M; + PRAGMA_OMP_SIMD() + for (int oS = 0; oS < m; ++oS) { + d_[oS] += b; + if (d_[oS] < 0) + d_[oS] *= eltwise_->alpha_; + } + }); + } else { + parallel_nd(step.oc, [&](const int oc) { + data_t b = jcp.with_bias ? 
bias[oc_start + oc] : 0; + data_t *d_ = _dst + oc * M; + PRAGMA_OMP_SIMD() + for (int oS = 0; oS < m; ++oS) { + d_[oS] += b; + d_[oS] = eltwise_->compute_scalar(d_[oS]); + } + }); + } + } else if (jcp.with_bias) { + parallel_nd(step.oc, [&](const int oc) { + data_t b = bias[oc_start + oc]; + data_t *d_ = _dst + oc * M; PRAGMA_OMP_SIMD() for (int oS = 0; oS < m; ++oS) { d_[oS] += b; - d_[oS] = eltwise_->compute_scalar(d_[oS]); } }); } - } else if (jcp.with_bias) { - parallel_nd(jcp.oc, [&](const int oc) { - data_t b = bias[g * jcp.oc + oc]; - data_t *d_ = d + oc * M; - PRAGMA_OMP_SIMD() - for (int oS = 0; oS < m; ++oS) { - d_[oS] += b; - } - }); } - nd_iterator_step(g, jcp.ngroups, n, jcp.mb, od, jcp.od, ohb, nb_oh, - owb, nb_ow); + }; + im_pos_t start, end; + end.ic = jcp.ic; + + if (jcp.id == 1) { + const int sp_work = jcp.mb * jcp.ngroups * jcp.od * jcp.os; + balance2D(nthr, ithr, sp_work, start.sp, end.sp, jcp.oc, start.oc, + end.oc, jcp.nthr_oc); + } else { + const int sp_work = jcp.mb * jcp.ngroups * jcp.od; + balance2D(nthr, ithr, sp_work, start.sp, end.sp, jcp.oc, start.oc, + end.oc, jcp.nthr_oc); + start.sp *= jcp.os; + end.sp *= jcp.os; } + + im_pos_t curr, prev, step; + prev.n = prev.g = prev.od = prev.sp = prev.ic = -1; + step.oc = jcp.oc_block; + step.sp = jcp.os_block; + step.ic = jcp.ic_block; + + if (jcp.loop_order == gemm_loop_rlb) + for (curr.ic = 0; curr.ic < jcp.ic; curr.ic += step.ic) + for (int spatial = start.sp; spatial < end.sp; + spatial += step.sp) { + nd_iterator_init(spatial, curr.n, jcp.mb, curr.g, + jcp.ngroups, curr.od, jcp.od, curr.sp, jcp.os); + for (curr.oc = start.oc; curr.oc < end.oc; + curr.oc += step.oc) { + inner_ker(spatial, curr, prev, step, end); + } + } + else if (jcp.loop_order == gemm_loop_lrb) + for (int spatial = start.sp; spatial < end.sp; spatial += step.sp) { + nd_iterator_init(spatial, curr.n, jcp.mb, curr.g, jcp.ngroups, + curr.od, jcp.od, curr.sp, jcp.os); + for (curr.ic = 0; curr.ic < jcp.ic; curr.ic += step.ic) + for (curr.oc = start.oc; curr.oc < end.oc; + curr.oc += step.oc) + inner_ker(spatial, curr, prev, step, end); + } + else + assert("Unknown loop order"); }); } @@ -256,10 +315,10 @@ void gemm_convolution_bwd_weights_t::execute_backward_weights( if (jcp.im2col_sz) { if (jcp.id == 1) jit_gemm_convolution_utils::im2col( - jcp, _src, _col, 0, jcp.oh, 0, jcp.ow); + jcp, _src, _col, 0, jcp.os, 0, jcp.ic); else - jit_gemm_convolution_utils::im2col_3d(jcp, _src, - _col, od); + jit_gemm_convolution_utils::im2col_3d( + jcp, _src, _col, od); } const data_t zero = 0.0, one = 1.0; diff --git a/src/cpu/gemm_convolution_utils.cpp b/src/cpu/gemm_convolution_utils.cpp index a78d7d2a7ed..0db572f28d7 100644 --- a/src/cpu/gemm_convolution_utils.cpp +++ b/src/cpu/gemm_convolution_utils.cpp @@ -119,106 +119,155 @@ void im2col_3d(const jit_gemm_conv_conf_t &jcp, const float *im, float *col, }); } +inline int saturate(int low, int upper, int value) { + return nstl::max(low, nstl::min(upper, value)); +} + /* col[ic][kh][kw][oh][ow] <-- im2col(im[ic][ih][iw]) */ -void im2col(const jit_gemm_conv_conf_t &jcp, const float *__restrict im, - float *__restrict col, int hs, int hb, int ws, int wb) { +void im2col(const jit_gemm_conv_conf_t &jcp, + const float *__restrict im, float *__restrict col, int ss, int sb, + int cs, int cb) { const size_t im_step = jcp.is; - const size_t col_step = jcp.ks * hb * wb; - if (jcp.stride_w == 1) { - // Generated code is more optimized for stride_w == 1 - // because innermost loop is by width - auto ker = [&](int ic, int kh, 
int kw, int oh) { - const float *__restrict im_ = im + ic * im_step; - float *__restrict col_ - = col + ic * col_step + ((kh * jcp.kw + kw) * hb + oh) * wb; - - const int ih = (oh + hs) * jcp.stride_h - jcp.t_pad - + kh * (1 + jcp.dilate_h); - if (ih < 0 || ih >= jcp.ih) { - for (int ow = 0; ow < wb; ++ow) - col_[ow] = 0.f; - } else { - for (int ow = 0; ow < wb; ++ow) { - const int iw = ow + ws - jcp.l_pad + kw * (1 + jcp.dilate_w); - if (iw < 0 || iw >= jcp.iw) - col_[ow] = 0.f; - else { - const size_t im_idx = ih * jcp.iw + iw; - col_[ow] = im_[im_idx]; + const size_t col_step = jcp.ks * sb; + const int dh = 1 + jcp.dilate_h; + const int dw = 1 + jcp.dilate_w; + const int sh = jcp.stride_h; + const int sw = jcp.stride_w; + const int tp = jcp.t_pad; + const int lp = jcp.l_pad; + const int first_oh = ss / jcp.ow; + const int last_oh = (ss + sb - 1) / jcp.ow; + const int oh_begin = first_oh; + const int oh_end = last_oh + 1; + const int first_ow = ss % jcp.ow; + const int last_ow = (ss + sb - 1) % jcp.ow; + + if (jcp.outer_threading) { + if (sw == 1) { + // Generated code is more optimized for stride_w == 1 + // because innermost loop is by width + for (int ic = 0; ic < cb; ic++) { + const float *__restrict im_ic = im + (ic + cs) * im_step; + for (int kh = 0; kh < jcp.kh; kh++) { + for (int kw = 0; kw < jcp.kw; kw++) { + float *__restrict col_k + = col + ic * col_step + (kh * jcp.kw + kw) * sb; + for (int oh = oh_begin; oh < oh_end; oh++) { + const int ih = oh * sh - tp + kh * dh; + const float *__restrict im_ + = im_ic + ih * jcp.iw - lp + kw * dw; + const int ow_begin + = (oh == first_oh) ? first_ow : 0; + const int ow_end + = (oh == last_oh) ? (last_ow + 1) : jcp.ow; + float *__restrict col_ = col_k + oh * jcp.ow - ss; + if (ih < 0 || ih >= jcp.ih) + for (int ow = ow_begin; ow < ow_end; ow++) + col_[ow] = 0.f; + else { + for (int ow = ow_begin; ow < ow_end; ++ow) { + const int iw = ow; + if (iw < lp - kw * dw || iw >= jcp.iw + lp - kw * dw) + col_[ow] = 0.f; + else + col_[ow] = im_[iw]; + } + } + } } } } - }; - - if (jcp.outer_threading) { - for (int ic = 0; ic < jcp.ic; ic++) - for (int kh = 0; kh < jcp.kh; kh++) - for (int kw = 0; kw < jcp.kw; kw++) - for (int oh = 0; oh < hb; oh++) - ker(ic, kh, kw, oh); - } - else { - parallel_nd(jcp.ic, jcp.kh, jcp.kw, hb, ker); - } - } else if (jcp.ic == 1) { - parallel_nd(jcp.kh, hb, [&](int kh, int oh) { - const int ih = (oh + hs) * jcp.stride_h - jcp.t_pad - + kh * (1 + jcp.dilate_h); - if (ih < 0 || ih >= jcp.ih) - for (int kw = 0; kw < jcp.kw; ++kw) { - for (int ow = 0; ow < wb; ++ow) { - const size_t col_idx - = ((kh * jcp.kw + kw) * hb + oh) * wb + ow; - col[col_idx] = 0; - } - } - else - for (int kw = 0; kw < jcp.kw; ++kw) { - for (int ow = 0; ow < wb; ++ow) { - const int iw = (ow + ws) * jcp.stride_w - jcp.l_pad - + kw * (1 + jcp.dilate_w); - const size_t col_idx - = ((kh * jcp.kw + kw) * hb + oh) * wb + ow; - const size_t im_idx = ih * jcp.iw + iw; - if (iw < 0 || iw >= jcp.iw) - col[col_idx] = 0; - else - col[col_idx] = im[im_idx]; + } else { + for (int ic = 0; ic < cb; ic++) { + const float *__restrict im_ = im + (ic + cs) * im_step; + for (int kh = 0; kh < jcp.kh; kh++) { + for (int kw = 0; kw < jcp.kw; kw++) { + float *__restrict col_k + = col + ic * col_step + (kh * jcp.kw + kw) * sb; + for (int oh = oh_begin; oh < oh_end; oh++) { + const int ih = oh * sh - tp + kh * dh; + const int ow_begin + = (oh == first_oh) ? first_ow : 0; + const int ow_end + = (oh == last_oh) ? 
(last_ow + 1) : jcp.ow; + float *__restrict col_oh = col_k + oh * jcp.ow - ss; + if (ih < 0 || ih >= jcp.ih) + for (int ow = ow_begin; ow < ow_end; ow++) + col_oh[ow] = 0.f; + else + for (int ow = ow_begin; ow < ow_end; ow++) { + const int iw = ow * sw - lp + kw * dw; + if (iw < 0 || iw >= jcp.iw) + col_oh[ow] = 0.f; + else { + const ptrdiff_t im_idx + = ih * jcp.iw + iw; + col_oh[ow] = im_[im_idx]; + } + } + } } } - }); - } else { - - parallel_nd(jcp.ic, jcp.kh, jcp.kw, hb, - [&](int ic, int kh, int kw, int oh) { - const float *__restrict im_ = im + ic * im_step; - float *__restrict col_ = col + ic * col_step - + ((kh * jcp.kw + kw) * hb + oh) * wb; - - const int ih = (oh + hs) * jcp.stride_h - jcp.t_pad - + kh * (1 + jcp.dilate_h); - if (ih < 0 || ih >= jcp.ih) { - for (int ow = 0; ow < wb; ++ow) - col_[ow] = 0.f; - } else { - for (int ow = 0; ow < wb; ++ow) { - const int iw = (ow + ws) * jcp.stride_w - jcp.l_pad - + kw * (1 + jcp.dilate_w); - const size_t im_idx = ih * jcp.iw + iw; - if (iw < 0 || iw >= jcp.iw) - col_[ow] = 0.f; - else - col_[ow] = im_[im_idx]; - } } - }); + } + } else { + // TODO: optimize threading if jcp.ic*jcp.kh*jcp.kw*oh_range is small + // comparing to number of threads + const int oh_range = oh_end - oh_begin; + // Generated code is more optimized for stride_w == 1 + // because innermost loop is by width + if (sw == 1) + parallel_nd(cb, jcp.kh, jcp.kw, oh_range, + [&](int ic, int kh, int kw, int ohr) { + const int oh = ohr + oh_begin; + const int ih = oh * sh - tp + kh * dh; + const int ow_start = (oh == first_oh) ? first_ow : 0; + const int ow_end + = (oh == last_oh) ? (last_ow + 1) : jcp.ow; + float *__restrict col_oh = col + ic * col_step + + (kh * jcp.kw + kw) * sb + oh * jcp.ow - ss; + const float *__restrict im_ = im + (ic + cs) * im_step + ih * jcp.iw; + const int iw_shift = kw * dw - lp; + if (ih < 0 || ih >= jcp.ih) + for (int ow = ow_start; ow < ow_end; ow++) + col_oh[ow] = 0.f; + else + for (int ow = ow_start; ow < ow_end; ow++) { + const int iw = ow + iw_shift; + if (iw < 0 || iw >= jcp.iw) + col_oh[ow] = 0.f; + else + col_oh[ow] = im_[iw]; + } + }); + else + parallel_nd(cb, jcp.kh, jcp.kw, oh_range, + [&](int ic, int kh, int kw, int ohr) { + const int oh = ohr + oh_begin; + const int ih = oh * sh - tp + kh * dh; + const int ow_start = (oh == first_oh) ? first_ow : 0; + const int ow_end + = (oh == last_oh) ? 
(last_ow + 1) : jcp.ow; + float *__restrict col_oh = col + ic * col_step + + (kh * jcp.kw + kw) * sb + oh * jcp.ow - ss; + const float *__restrict im_ = im + (ic + cs) * im_step; + if (ih < 0 || ih >= jcp.ih) + for (int ow = ow_start; ow < ow_end; ow++) + col_oh[ow] = 0.f; + else + for (int ow = ow_start; ow < ow_end; ow++) { + const int iw = ow * sw - lp + kw * dw; + if (iw < 0 || iw >= jcp.iw) + col_oh[ow] = 0.f; + else { + const ptrdiff_t im_idx = ih * jcp.iw + iw; + col_oh[ow] = im_[im_idx]; + } + } + }); } } -inline int limit(int low, int upper, int value) { - return nstl::max(low, nstl::min(upper, value)); -} - /* col[kh][kw][ic][oh][ow] <-- im2col_u8(im[ih][iw][ic]) */ template <typename T> void im2col_u8(const jit_gemm_conv_conf_t &jcp, const T *__restrict im, @@ -238,10 +287,10 @@ void im2col_u8(const jit_gemm_conv_conf_t &jcp, const T *__restrict im, /* im[ih][iw][ic] --> imtr[ic][ih][iw] --> col[kh][kw][ic][oh][ow] */ const int hp = hs - tp; const int wp = ws - lp; - const int ih_start = limit(0, jcp.ih, hp); - const int ih_end = limit(0, jcp.ih, hp + hb + jcp.kh); - const int iw_start = limit(0, jcp.iw, wp); - const int iw_end = limit(0, jcp.iw, wp + wb + jcp.kw); + const int ih_start = saturate(0, jcp.ih, hp); + const int ih_end = saturate(0, jcp.ih, hp + hb + jcp.kh); + const int iw_start = saturate(0, jcp.iw, wp); + const int iw_end = saturate(0, jcp.iw, wp + wb + jcp.kw); const int ihb = ih_end - ih_start; const int iwb = iw_end - iw_start; @@ -267,15 +316,15 @@ void im2col_u8(const jit_gemm_conv_conf_t &jcp, const T *__restrict im, for (int kh = 0; kh < jcp.kh; kh++) { const ptrdiff_t col_idx_kh = kh * col_kh_stride; const int oh_kh = oh_init - kh; - const int oh_start = limit(0, hb, oh_kh); - const int oh_end = limit(0, hb, oh_kh + ihb); + const int oh_start = saturate(0, hb, oh_kh); + const int oh_end = saturate(0, hb, oh_kh + ihb); for (int kw = 0; kw < jcp.kw; kw++) { const ptrdiff_t col_idx_kw = col_idx_kh + kw * jcp.ic * col_ic_str; const int ow_kw = ow_init - kw; const int imtr_shift = oh_kh * iwb + ow_kw; - const int ow_start = limit(0, wb, ow_kw); - const int ow_end = limit(0, wb, ow_kw + iwb); + const int ow_start = saturate(0, wb, ow_kw); + const int ow_end = saturate(0, wb, ow_kw + iwb); for (int ic = 0; ic < jcp.ic; ic++) { const ptrdiff_t col_idx_ic = col_idx_kw + ic * col_ic_str; const int imtr_idx_ic = ic * imtr_ic_stride - imtr_shift; @@ -315,9 +364,9 @@ void im2col_u8(const jit_gemm_conv_conf_t &jcp, const T *__restrict im, col[col_idx_base + ow] = shift; else { const int wp = lp - kw * dw; - const int ow_start = limit(0, wb, div_up(wp, sw) - ws); + const int ow_start = saturate(0, wb, div_up(wp, sw) - ws); const int ow_end - = limit(0, wb, div_up(jcp.iw + wp, sw) - ws); + = saturate(0, wb, div_up(jcp.iw + wp, sw) - ws); for (int ow = 0; ow < ow_start; ow++) col[col_idx_base + ow] = shift; const int iw_base = ws * sw - wp; @@ -540,6 +589,12 @@ status_t init_conf(jit_gemm_conv_conf_t &jcp, const bool is_bwd_d = jcp.prop_kind == backward_data; const bool is_bwd_w = jcp.prop_kind == backward_weights; const bool is_fwd = !is_bwd_d && !is_bwd_w; + jcp.os_block = jcp.os; + jcp.oc_block = jcp.oc; + jcp.ic_block = jcp.ic; + jcp.loop_order = gemm_loop_rlb; + jcp.nthr_oc = 1; + jcp.oh_block = is_fwd ? jcp.oh : jcp.ih; jcp.ow_block = is_fwd ? 
jcp.ow : jcp.iw; @@ -547,127 +602,133 @@ status_t init_conf(jit_gemm_conv_conf_t &jcp, bool is_depthwise = jcp.ic == 1 && jcp.oc == 1 && jcp.ngroups != 1; // TODO: maybe mitigate blocking restriction - const int wei_size = jcp.oc * jcp.ic * jcp.kh * jcp.kw; const int L2 = get_cache_size(2, true) / (is_int8_conv ? sizeof(int8_t) : sizeof(float)); - bool is_blocking_applicable = true - && is_fwd && jcp.im2col_sz - && jcp.id == 1 && jcp.od == 1 - && jcp.dilate_h == 0 && jcp.dilate_w == 0 - && !is_depthwise - && wei_size < L2/2; - if (is_blocking_applicable) { - // looking for oh and ow blocking - int h_block{ jcp.oh_block }, w_block{ jcp.ow_block }; - const int ic = jcp.ic; - const int oc = jcp.oc; - const int iw = jcp.iw; - const int ow = jcp.ow; - const int oh = jcp.oh; - const int os = oh * ow; - - // 1. cache requirement - int row_size = ic * ow * jcp.ks + 2 * (ic * iw + oc * ow); - if (is_int8_conv) { - // Heuristic rule: gemm needed a lot of memory for internal usage - row_size *= 5; - // memory for accumulators - row_size += oc * ow * sizeof(uint32_t); - // memory for transposition - row_size += ic * iw; - } - - h_block = nstl::max(1, nstl::min(oh, div_up(L2, row_size))); - if (h_block == 1) { - int col_size = ic * jcp.ks + 2 * (ic + oc); - if (is_int8_conv) { - col_size *= 5; - col_size += oc * sizeof(uint32_t); - col_size += ic; - } - w_block = nstl::max(1, nstl::min(ow, div_up(L2, col_size))); - } - - // 2. threading requirement - if (h_block != oh) - h_block = nstl::max(1, rnd_dn(h_block, 4)); - if (w_block != ow) - w_block = nstl::max(1, rnd_dn(w_block, simd_w)); - - float thr_eff = 0.f; - float thr_eff_treshold = 0.9f; - if (w_block == ow) { - do { - int nb_h = div_up(oh, h_block); - size_t work = jcp.ngroups * jcp.mb * jcp.od * nb_h; - float disb = (float)oh / rnd_up(oh, h_block); - thr_eff = (float)work / rnd_up(work, max_threads); - thr_eff = (thr_eff + disb) / 2.f; - if (thr_eff >= thr_eff_treshold) - break; - h_block = rnd_dn(h_block - 4, 4); - } while (h_block > 0); - } - if (thr_eff < thr_eff_treshold) // we didn't find suitable h_block - { - h_block = 1; - int nb_h = oh; - do { - int nb_w = div_up(ow, w_block); - size_t work_amount = jcp.ngroups * jcp.mb * nb_h * nb_w; - float disb = (float)ow / rnd_up(ow, w_block); - thr_eff = (float)work_amount / rnd_up(work_amount, max_threads); - thr_eff = (thr_eff + disb) / 2.f; - if (thr_eff > thr_eff_treshold) - break; - w_block = rnd_dn(w_block - simd_w, simd_w); - } while (w_block > 0); - } - h_block = nstl::max(1, h_block); - w_block = nstl::max(1, w_block); - const size_t inner_work = div_up(os, simd_w) * div_up(oc, simd_w); - const float inner_thr_eff - = (float)inner_work / rnd_up(inner_work, max_threads); - if (thr_eff >= inner_thr_eff / 2 && h_block > 0 && w_block > 0) { - jcp.oh_block = h_block; - jcp.ow_block = w_block; - jcp.outer_threading = true; - } - // updating jcp.im2col_sz - if (jcp.oh_block != 1) - jcp.ow_block = ow; - jcp.im2col_sz = (ptrdiff_t)ic * jcp.ks * jcp.oh_block * jcp.ow_block; - } - // For threading selection in bwd_d we do: - // 1. Rough estimation of efficiency for inner and outer threading. - // 2. Gemm size estimation in assumption that it does not work - // so effectively for small sizes. - // 64K - this is heuristic gemm size per thread threshold. 
const int gemm_thrld = 64 * 1024; if (is_int8_conv) { if (is_fwd) { + const int wei_size = jcp.oc * jcp.ic * jcp.kh * jcp.kw; + bool is_blocking_applicable = true && is_fwd && jcp.im2col_sz + && jcp.id == 1 && jcp.od == 1 && jcp.dilate_h == 0 + && jcp.dilate_w == 0 && !is_depthwise && wei_size < L2 / 2; + if (is_blocking_applicable) { + // looking for oh and ow blocking + int h_block{ jcp.oh_block }, w_block{ jcp.ow_block }; + const int ic = jcp.ic; + const int oc = jcp.oc; + const int iw = jcp.iw; + const int ow = jcp.ow; + const int oh = jcp.oh; + const int os = oh * ow; + + // 1. cache requirement + int row_size = ic * ow * jcp.ks + 2 * (ic * iw + oc * ow); + // Heuristic rule: gemm needed a lot of memory for internal + // usage + row_size *= 5; + // memory for accumulators + row_size += oc * ow * sizeof(uint32_t); + // memory for transposition + row_size += ic * iw; + + h_block = nstl::max(1, nstl::min(oh, div_up(L2, row_size))); + if (h_block == 1) { + int col_size = ic * jcp.ks + 2 * (ic + oc); + if (is_int8_conv) { + col_size *= 5; + col_size += oc * sizeof(uint32_t); + col_size += ic; + } + w_block = nstl::max(1, nstl::min(ow, div_up(L2, col_size))); + } + + // 2. threading requirement + if (h_block != oh) + h_block = nstl::max(1, rnd_dn(h_block, 4)); + if (w_block != ow) + w_block = nstl::max(1, rnd_dn(w_block, simd_w)); + + float thr_eff = 0.f; + float thr_eff_treshold = 0.9f; + if (w_block == ow) { + do { + int nb_h = div_up(oh, h_block); + size_t work = jcp.ngroups * jcp.mb * jcp.od * nb_h; + float disb = (float)oh / rnd_up(oh, h_block); + thr_eff = (float)work / rnd_up(work, max_threads); + thr_eff = (thr_eff + disb) / 2.f; + if (thr_eff >= thr_eff_treshold) + break; + h_block = rnd_dn(h_block - 4, 4); + } while (h_block > 0); + } + if (thr_eff + < thr_eff_treshold) // we didn't find suitable h_block + { + h_block = 1; + int nb_h = oh; + do { + int nb_w = div_up(ow, w_block); + size_t work_amount = jcp.ngroups * jcp.mb * nb_h * nb_w; + float disb = (float)ow / rnd_up(ow, w_block); + thr_eff = (float)work_amount + / rnd_up(work_amount, max_threads); + thr_eff = (thr_eff + disb) / 2.f; + if (thr_eff > thr_eff_treshold) + break; + w_block = rnd_dn(w_block - simd_w, simd_w); + } while (w_block > 0); + } + h_block = nstl::max(1, h_block); + w_block = nstl::max(1, w_block); + const size_t inner_work + = div_up(os, simd_w) * div_up(oc, simd_w); + const float inner_thr_eff + = (float)inner_work / rnd_up(inner_work, max_threads); + if (thr_eff >= inner_thr_eff / 2 && h_block > 0 + && w_block > 0) { + jcp.oh_block = h_block; + jcp.ow_block = w_block; + jcp.outer_threading = true; + } + // updating jcp.im2col_sz + if (jcp.oh_block != 1) + jcp.ow_block = ow; + jcp.im2col_sz + = (ptrdiff_t)ic * jcp.ks * jcp.oh_block * jcp.ow_block; + } + // For threading selection in bwd_d we do: + // 1. Rough estimation of efficiency for inner and outer threading. + // 2. Gemm size estimation in assumption that it does not work + // so effectively for small sizes. + // 64K - this is heuristic gemm size per thread threshold. 
+ const int gemm_thrld = 64 * 1024; if (!jcp.outer_threading) { - bool is_depthwise = jcp.ic == 1 && jcp.oc == 1 && jcp.ngroups != 1; + bool is_depthwise + = jcp.ic == 1 && jcp.oc == 1 && jcp.ngroups != 1; const size_t outer_work = jcp.ngroups * jcp.mb; const float outer_thr_eff - = (float)outer_work / rnd_up(outer_work, max_threads); + = (float)outer_work / rnd_up(outer_work, max_threads); const size_t inner_work - = div_up(jcp.is, simd_w) * div_up(jcp.ic, simd_w); + = div_up(jcp.is, simd_w) * div_up(jcp.ic, simd_w); const float inner_thr_eff - = (float)inner_work / rnd_up(inner_work, max_threads); - jcp.outer_threading = (is_depthwise - || (jcp.is / max_threads < 64 && jcp.mb != 1)) - && (outer_thr_eff / inner_thr_eff >= 1.f - || (jcp.os * jcp.ic * jcp.oc) / max_threads < gemm_thrld); + = (float)inner_work / rnd_up(inner_work, max_threads); + jcp.outer_threading + = (is_depthwise + || (jcp.is / max_threads < 64 && jcp.mb != 1)) + && (outer_thr_eff / inner_thr_eff >= 1.f + || (jcp.os * jcp.ic * jcp.oc) / max_threads + < gemm_thrld); } jcp.nthr = jcp.outer_threading ? max_threads : 1; scratchpad.book(key_conv_gemm_col, - sizeof(int8_t) * jcp.nthr * jcp.im2col_sz); + sizeof(int8_t) * jcp.nthr * jcp.im2col_sz); scratchpad.book(key_conv_int_dat_in_acc_dt, - sizeof(int32_t) * jcp.nthr * jcp.oh_block * jcp.ow_block * jcp.oc); + sizeof(int32_t) * jcp.nthr * jcp.oh_block * jcp.ow_block + * jcp.oc); scratchpad.book(key_conv_gemm_imtr, - sizeof(int8_t) * jcp.nthr * jcp.is * jcp.ic); + sizeof(int8_t) * jcp.nthr * jcp.is * jcp.ic); } else if (is_bwd_d) { bool is_depthwise = jcp.ic == 1 && jcp.oc == 1 && jcp.ngroups != 1; const size_t outer_work = jcp.ngroups * jcp.mb; @@ -677,23 +738,259 @@ status_t init_conf(jit_gemm_conv_conf_t &jcp, = div_up(jcp.is, simd_w) * div_up(jcp.ic, simd_w); const float inner_thr_eff = (float)inner_work / rnd_up(inner_work, max_threads); - jcp.outer_threading = (is_depthwise - || (jcp.is / max_threads < 64 && jcp.mb != 1)) - && (outer_thr_eff / inner_thr_eff >= 1.f - || (jcp.is * jcp.ic * jcp.oc) / max_threads < gemm_thrld); + jcp.outer_threading + = (is_depthwise + || (jcp.is / max_threads < 64 && jcp.mb != 1)) + && (outer_thr_eff / inner_thr_eff >= 1.f + || (jcp.is * jcp.ic * jcp.oc) / max_threads + < gemm_thrld); jcp.nthr = jcp.outer_threading ? 
max_threads : 1; scratchpad.book(key_conv_gemm_col, - sizeof(int32_t) * jcp.nthr * jcp.im2col_sz); + sizeof(int32_t) * jcp.nthr * jcp.im2col_sz); scratchpad.book(key_conv_int_dat_in_acc_dt, - sizeof(int32_t) * jcp.nthr * jcp.is * jcp.ic); + sizeof(int32_t) * jcp.nthr * jcp.is * jcp.ic); } else if (is_bwd_w) { assert(!"unimplemented prop_kind"); return status::unimplemented; } } else { if (is_fwd) { - if (!jcp.outer_threading) { + const int sh = jcp.stride_h; + const int sw = jcp.stride_w; + const int spatial = jcp.mb * jcp.ngroups * jcp.od * jcp.os; + int K = jcp.ic * jcp.ks; + + // There is some heuristics in the definition of + // inner/outer threading cross point due to the nature of the + // gemm implementation which we cannot control + bool is_blocking_applicable = true && jcp.id == 1 && jcp.od == 1 + && (!jcp.im2col_sz + // spatial is small + || spatial >= max_threads * simd_w + // inner threading work is greater then outer threading work + || jcp.os < jcp.mb * jcp.ngroups * jcp.od + // im2col is big + || (sw == 1 && K <= 0.05 * jcp.oc)) + // heuristic condition + && (jcp.im2col_sz + || (jcp.ic / jcp.oc < 42 + && jcp.ic * jcp.oc * jcp.is < 1024)); + + if (is_blocking_applicable) { + const int min_oc_block = 8; + const int min_os_block = simd_w; + const float non_cache_access = 20; + const float strided_im2col_k = 8; + const float thr_disb_k = 8; + const float thr_mem_eff_k{ 1 }, oc_disb_k{ 1 }, os_disb_k{ 1 }, + ic_disb_k{ 1 }, reg_osb_disb_k{ 1 }, gemm_eff_k{ 0.5 }, + gemm_calc_eff_k{ 1 }; + const float k_sum = thr_disb_k + oc_disb_k + os_disb_k + + ic_disb_k + reg_osb_disb_k + thr_mem_eff_k + + gemm_eff_k + gemm_calc_eff_k; + + auto calc_max_icb = [=](int nthr_oc, int ocb, int osb, + int oc_per_thr, int os_per_thr) { + const int block_out_size = ocb * osb; + // TODO: need more precise calculation if stride more than + // kernel size + const int inp_row_size = sh * sw * osb; + int max_icb = 1; + if (jcp.im2col_sz) { + const int col_row_size = jcp.ks * osb; + if (osb >= os_per_thr) { // one pass by os + const int wei_col_size = jcp.ks * ocb; + max_icb = L2 / (inp_row_size + col_row_size); + if (ocb < oc_per_thr) { + max_icb = nstl::min(max_icb, + (L2 - block_out_size) + / (col_row_size + + wei_col_size)); + } + } else { + const int wei_col_size = jcp.ks * oc_per_thr; + max_icb = (L2 - block_out_size) + / (inp_row_size + col_row_size + + wei_col_size); + } + } else { + if (osb >= os_per_thr) + max_icb = L2 / inp_row_size; + else { + const int wei_col_size = jcp.ks * oc_per_thr; + max_icb = L2 / (inp_row_size + wei_col_size); + } + } + if (max_icb < jcp.ic) { + if (jcp.im2col_sz) { + const int col_row_size = jcp.ks * osb; + const int wei_col_size = jcp.ks * oc_per_thr; + max_icb = (L2 - block_out_size) + / (inp_row_size + col_row_size + + wei_col_size); + } + } + return max_icb; + }; + + auto est_eff = [=](int nthr_oc, int ocb, int osb, int &icb, + int max_oc_per_thr, int max_os_per_thr) { + // for given nthr_oc, oc block: + // 1. find ic block to fit into cache + // 2. 
estimate efficiency basing on rules and heuristic: + // - Minimize im2col cost + // - ratio of FMA number to data size + // - gemm works better if M divided by 48 and N divided by 8 + if (osb > max_os_per_thr || ocb > max_oc_per_thr) + return 0.f; + + int sp_start{ 0 }, sp_end{ 0 }, oc_start{ 0 }, oc_end{ 0 }; + size_t max_thr_size{ 0 }, + min_thr_size{ (size_t)spatial * jcp.oc + 1 }, + max_y{ 0 }, max_oc{ 0 }; + + for (int i = 0; i < max_threads; i++) { + balance2D(max_threads, i, spatial, sp_start, sp_end, + jcp.oc, oc_start, oc_end, nthr_oc); + const size_t thr_size + = (sp_end - sp_start) * (oc_end - oc_start); + if (thr_size > max_thr_size) { + max_y = (sp_end - sp_start); + max_oc = (oc_end - oc_start); + max_thr_size = thr_size; + } + if (thr_size < min_thr_size) + min_thr_size = thr_size; + } + auto thr_disb = (float)min_thr_size / max_thr_size; + + const int oc_per_thr = max_oc; + const int os_per_thr = max_y; + ocb = nstl::min(oc_per_thr, ocb); + const int os_max = nstl::min(jcp.os, os_per_thr); + osb = nstl::min(os_max, osb); + + // -- selecting icb --------------------- + int max_ic_block = calc_max_icb( + nthr_oc, ocb, osb, oc_per_thr, os_per_thr); + // if we don't fit into cache then access to memory is + // expensive + int mem_access_cost + = (max_ic_block < 1) ? non_cache_access : 1; + max_ic_block = nstl::max(1, max_ic_block); + icb = nstl::max(1, jcp.ic / div_up(jcp.ic, max_ic_block)); + int nb_ic = div_up(jcp.ic, icb); + int kb = icb * jcp.ks; + int kb_caligned = rnd_up(kb, simd_w); + + // -- mem efficiency ------------ + const size_t out_size + = oc_per_thr * rnd_up(os_per_thr, simd_w); + const size_t out_ops = mem_access_cost * out_size + * ((icb == jcp.ic) ? 1 : (2 * nb_ic - 1)); + const int osb_caligned = rnd_up(osb, simd_w); + const size_t inp_size + = jcp.ic * rnd_up(os_per_thr * sh * sw, simd_w); + size_t inp_ops = 0; + size_t col_ops = 0; + // TODO: simplify calculations + if (jcp.im2col_sz) { + inp_ops = mem_access_cost * jcp.ks * inp_size; + const float col_tail_koeff = (float)osb_caligned / osb; + col_ops = mem_access_cost + * (jcp.ks * inp_size * col_tail_koeff + + jcp.ks * inp_size * col_tail_koeff); + if (sw != 1) // im2col with strides is much slower + col_ops *= strided_im2col_k; + } else { + inp_ops = mem_access_cost * jcp.ks * inp_size; + } + // TODO: what about groups? 
+ const size_t wei_size = oc_per_thr * rnd_up(K, simd_w); + const size_t wei_ops = mem_access_cost * wei_size; + // ratio of real FMA to number of memory ops + const float thr_mem_eff + = (((float)os_per_thr / simd_w) * oc_per_thr * K) + / (inp_ops + col_ops + wei_ops + out_ops); + + auto oc_disb = (float)oc_per_thr / rnd_up(oc_per_thr, ocb); + auto os_disb = (float)os_max / rnd_up(os_max, osb); + auto ic_disb = (float)jcp.ic / rnd_up(jcp.ic, icb); + + auto reg_osb_disb = (float)osb / rnd_up(osb, 3 * simd_w); + + // Heuristics + const float gemm_eff = ((float)osb * ocb * kb) + / ((float)oc_per_thr * os_per_thr * K); + + // number of FMA to memory size + const float gemm_calc_eff + = (((float)osb / simd_w) * ocb * kb) + / (osb_caligned * kb + ocb * kb_caligned + + ocb * osb_caligned); + + const float res_eff = pow(pow(thr_disb, thr_disb_k) + * pow(oc_disb, oc_disb_k) + * pow(os_disb, os_disb_k) + * pow(ic_disb, ic_disb) + * pow(reg_osb_disb, reg_osb_disb_k) + * pow(thr_mem_eff, thr_mem_eff_k) + * pow(gemm_eff, gemm_eff_k) + * pow(gemm_calc_eff, gemm_calc_eff_k), + 1.f / k_sum); + return res_eff; + }; + + /* find the best thread distribution and blocking with highest + * efficiency */ + int best_nthr_oc{ 1 }, best_ocb{ jcp.oc }, best_osb{ jcp.os }, + best_icb{ jcp.ic }; + float best_thr_eff = est_eff(best_nthr_oc, best_ocb, best_osb, + best_icb, jcp.oc, jcp.os); + + int icb{ best_icb }; + const int nthr_oc_max = max_threads; + for (int nthr_oc = 1; nthr_oc <= nthr_oc_max; ++nthr_oc) { + const int max_oc_per_thr = div_up(jcp.oc, nthr_oc); + const int min_oc_per_thr + = nstl::min(min_oc_block, max_oc_per_thr); + const int max_os_per_thr = nstl::min(jcp.os, + div_up(spatial, + nstl::max(1, max_threads / nthr_oc))); + const int min_os_per_thr + = nstl::min(min_os_block, max_os_per_thr); + for (int ocb = min_oc_per_thr; ocb <= max_oc_per_thr; + ocb += nstl::max(1, + nstl::min(min_oc_block, + max_oc_per_thr - ocb))) { + for (int osb = min_os_per_thr; osb <= jcp.os; + osb += nstl::max(1, + nstl::min(min_os_block, + max_os_per_thr - osb))) { + float thr_eff = est_eff(nthr_oc, ocb, osb, icb, + max_oc_per_thr, max_os_per_thr); + if (thr_eff > best_thr_eff) { + best_thr_eff = thr_eff; + best_nthr_oc = nthr_oc; + best_ocb = ocb; + best_osb = osb; + best_icb = icb; + } + } + } + } + + jcp.outer_threading = true; + jcp.nthr_oc = best_nthr_oc; + jcp.oc_block = best_ocb; + jcp.os_block = best_osb; + jcp.ic_block = best_icb; + // TODO: define loop order + // if im2col then gemm_loop_rlb and gemm_loop_lrb looks + // preferable otherwise other loop orders are possible + jcp.loop_order = gemm_loop_rlb; + } else { const size_t outer_work_amount = jcp.ngroups * jcp.mb * jcp.od; const float outer_thr_eff = (float)outer_work_amount / rnd_up(outer_work_amount, max_threads); @@ -702,38 +999,43 @@ status_t init_conf(jit_gemm_conv_conf_t &jcp, const float inner_thr_eff = (float)inner_work_amount / rnd_up(inner_work_amount, max_threads); jcp.outer_threading = jcp.os / max_threads < 512 - && IMPLICATION(jcp.od == 1, jcp.mb != 1 || jcp.ngroups > 2) - && (outer_thr_eff / inner_thr_eff >= 1.f - || (jcp.os * jcp.ic * jcp.oc) / max_threads < gemm_thrld); + && IMPLICATION( + jcp.od == 1, jcp.mb != 1 || jcp.ngroups > 2) + && (outer_thr_eff / inner_thr_eff >= 1.f + || (jcp.os * jcp.ic * jcp.oc) / max_threads + < gemm_thrld); } + if (jcp.im2col_sz) + jcp.im2col_sz = (ptrdiff_t)jcp.ic_block * jcp.ks * jcp.os_block; } else if (is_bwd_d) { const size_t outer_work_amount = jcp.ngroups * jcp.mb; const float outer_thr_eff = 
(float)outer_work_amount - / rnd_up(outer_work_amount, max_threads); + / rnd_up(outer_work_amount, max_threads); const size_t inner_work - = div_up(jcp.is, simd_w) * div_up(jcp.ic, simd_w); - const float inner_thr_eff = (float)inner_work - / rnd_up(inner_work, max_threads); + = div_up(jcp.is, simd_w) * div_up(jcp.ic, simd_w); + const float inner_thr_eff + = (float)inner_work / rnd_up(inner_work, max_threads); jcp.outer_threading = (jcp.os / max_threads < 512 || jcp.ks < 64) - && (jcp.mb != 1 || jcp.ngroups > 2) - && (outer_thr_eff / inner_thr_eff >= 1.f - || (jcp.is * jcp.ic * jcp.oc) / max_threads < gemm_thrld); + && (jcp.mb != 1 || jcp.ngroups > 2) + && (outer_thr_eff / inner_thr_eff >= 1.f + || (jcp.is * jcp.ic * jcp.oc) / max_threads + < gemm_thrld); } else if (is_bwd_w) jcp.outer_threading = jcp.os / max_threads < 256 - && (jcp.mb != 1 || jcp.ngroups > 2); + && (jcp.mb != 1 || jcp.ngroups > 2); jcp.nthr = jcp.outer_threading ? max_threads : 1; - scratchpad.book(key_conv_gemm_col, - sizeof(float) * jcp.nthr * jcp.im2col_sz); + scratchpad.book( + key_conv_gemm_col, sizeof(float) * jcp.nthr * jcp.im2col_sz); if (is_bwd_w) { jcp.need_wei_reduction = mkldnn_thr_syncable() - ? jcp.mb != 1 && jcp.nthr != 1 : false; + ? jcp.mb != 1 && jcp.nthr != 1 + : false; scratchpad.book(key_conv_wei_reduction, sizeof(float) * jcp.nthr * jcp.ngroups * weights_d.size()); } } - return status::success; } diff --git a/src/cpu/gemm_convolution_utils.hpp b/src/cpu/gemm_convolution_utils.hpp index e006789344e..ff4165b98ab 100644 --- a/src/cpu/gemm_convolution_utils.hpp +++ b/src/cpu/gemm_convolution_utils.hpp @@ -34,7 +34,7 @@ namespace jit_gemm_convolution_utils { void im2col_3d(const jit_gemm_conv_conf_t &jcp, const float *im, float *col, int od); void im2col(const jit_gemm_conv_conf_t &jcp, const float *__restrict im, - float *__restrict col, int hs, int hb, int ws, int wb); + float *__restrict col, int ss, int sb, int cs, int cb); template <typename T> void im2col_u8(const jit_gemm_conv_conf_t &jcp, const T *__restrict im, T* __restrict imtr, uint8_t *__restrict col, diff --git a/src/cpu/jit_primitive_conf.hpp b/src/cpu/jit_primitive_conf.hpp index 5c011cdb0cc..7b394a94f28 100644 --- a/src/cpu/jit_primitive_conf.hpp +++ b/src/cpu/jit_primitive_conf.hpp @@ -31,6 +31,8 @@ enum conv_loop_order_t {loop_cgn, loop_gnc, loop_ngc, loop_gncw, loop_cwgn, loop_ngcw, loop_nhwcg, loop_nwcg}; enum conv_1x1_loop_order_t {loop_rbl, loop_rlb, loop_lbr, loop_lrb, loop_blr, loop_brl}; +enum conv_gemm_loop_order_t { gemm_loop_rlb, gemm_loop_lrb }; + enum conv_kernel_kind_t {embd_bcast, expl_bcast}; enum conv_harness_t {harness_2d_reduction, harness_3d_reduction, harness_mb_reduction}; @@ -416,7 +418,10 @@ struct jit_gemm_conv_conf_t { bool signed_input; int oh_block; int ow_block; + int os_block; bool outer_threading; + conv_gemm_loop_order_t loop_order; + int nthr_oc; }; struct jit_1x1_conv_call_s {
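
Notes on the new blocking scheme (sketches only; none of the code below is part of the patch, and all helper names are illustrative):

For orientation, a minimal sketch of the accumulation scheme the reworked execute_forward uses: the reduction dimension K = ic * ks is split into ic blocks, the first block is issued with the primitive's beta (plain overwrite when beta is 0, which also covers the sum post-op case the patch leaves as a TODO) and every later block with beta = 1, so partial products accumulate in dst; bias and eltwise post-ops are applied only once the last ic block has been added. Plain loops stand in for the extended_sgemm call, and the layout below (dst[oc][os], col[K][os], wei[oc][K]) is chosen for brevity.

    #include <algorithm>
    #include <cstddef>
    #include <vector>

    // dst[oc][os], col[K][os], wei[oc][K] with K = ic * ks. The K dimension is
    // processed in blocks: the first block overwrites dst (beta = 0 here; the
    // patch passes its beta_), later blocks accumulate (beta = 1).
    void blocked_gemm_conv(int os, int oc, int K, int k_block,
            const std::vector<float> &col, const std::vector<float> &wei,
            std::vector<float> &dst) {
        for (int kb = 0; kb < K; kb += k_block) {
            const int kk = std::min(k_block, K - kb);
            const float beta = (kb == 0) ? 0.f : 1.f;
            for (int o = 0; o < oc; ++o)
                for (int s = 0; s < os; ++s) {
                    float acc = beta * dst[(size_t)o * os + s];
                    for (int k = kb; k < kb + kk; ++k)
                        acc += wei[(size_t)o * K + k] * col[(size_t)k * os + s];
                    dst[(size_t)o * os + s] = acc;
                }
        }
    }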
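
The patch introduces two loop orders, gemm_loop_rlb and gemm_loop_lrb. In both, the im2col result for a given (spatial, ic) block is reused across all oc blocks (do_im2col does not depend on oc); they differ in whether the ic blocks or the spatial blocks form the outer loop. A schematic driver, with inner_ker standing in for the per-block GEMM plus post-ops:

    #include <functional>

    // rlb_order == true  -> gemm_loop_rlb: reduction (ic) blocks outermost
    // rlb_order == false -> gemm_loop_lrb: spatial blocks outermost
    void drive_blocks(bool rlb_order, int ic_total, int icb, int sp_start,
            int sp_end, int spb, int oc_start, int oc_end, int ocb,
            const std::function<void(int, int, int)> &inner_ker /* (ic, sp, oc) */) {
        if (rlb_order) {
            for (int ic = 0; ic < ic_total; ic += icb)
                for (int sp = sp_start; sp < sp_end; sp += spb)
                    for (int oc = oc_start; oc < oc_end; oc += ocb)
                        inner_ker(ic, sp, oc);
        } else {
            for (int sp = sp_start; sp < sp_end; sp += spb)
                for (int ic = 0; ic < ic_total; ic += icb)
                    for (int oc = oc_start; oc < oc_end; oc += ocb)
                        inner_ker(ic, sp, oc);
        }
    }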
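
Thread distribution in the forward f32 path now uses balance2D: the thread pool is cut into jcp.nthr_oc groups along oc, and each group splits the flattened spatial work (mb * ngroups * od * os) among its members. A simplified version of that partitioning, assuming the thread count is divisible by nthr_oc (the real balance2D in mkldnn_thread.hpp also handles the non-divisible case):

    #include <algorithm>
    #include <tuple>
    #include <utility>

    // balance211-style split: the first (n % parts) chunks get one extra item.
    static std::pair<int, int> chunk(int n, int parts, int idx) {
        const int base = n / parts, rem = n % parts;
        const int start = idx * base + std::min(idx, rem);
        return {start, start + base + (idx < rem ? 1 : 0)};
    }

    void split_2d(int nthr, int ithr, int sp_work, int oc, int nthr_oc,
            int &sp_start, int &sp_end, int &oc_start, int &oc_end) {
        const int grp_size = nthr / nthr_oc; // threads per oc slice (assumed exact)
        const int grp = ithr / grp_size;     // which oc slice this thread works on
        const int grp_ithr = ithr % grp_size;
        std::tie(oc_start, oc_end) = chunk(oc, nthr_oc, grp);
        std::tie(sp_start, sp_end) = chunk(sp_work, grp_size, grp_ithr);
    }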
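
The im2col interface changes from an (oh, ow) block to a flat spatial block [ss, ss + sb) combined with an input-channel block [cs, cs + cb). A plain-loop reference for what the blocked call is expected to produce, keeping the col[ic][kh][kw][spatial] layout noted in the patch; the conv_shape_t struct and function name are stand-ins, and none of the stride-1 fast path or threading is reproduced:

    #include <cstddef>

    struct conv_shape_t {
        int ih, iw, oh, ow, kh, kw;
        int stride_h, stride_w, t_pad, l_pad, dilate_h, dilate_w;
    };

    void im2col_block(const conv_shape_t &p, const float *im /* [ic][ih][iw] */,
            float *col /* [cb][kh][kw][sb] */, int ss, int sb, int cs, int cb) {
        for (int ic = 0; ic < cb; ++ic)
            for (int kh = 0; kh < p.kh; ++kh)
                for (int kw = 0; kw < p.kw; ++kw)
                    for (int sp = 0; sp < sb; ++sp) {
                        const int oh = (ss + sp) / p.ow, ow = (ss + sp) % p.ow;
                        const int ih = oh * p.stride_h - p.t_pad + kh * (1 + p.dilate_h);
                        const int iw = ow * p.stride_w - p.l_pad + kw * (1 + p.dilate_w);
                        const size_t col_idx
                                = (((size_t)ic * p.kh + kh) * p.kw + kw) * sb + sp;
                        const bool pad = ih < 0 || ih >= p.ih || iw < 0 || iw >= p.iw;
                        col[col_idx] = pad ? 0.f
                                : im[((size_t)(cs + ic) * p.ih + ih) * p.iw + iw];
                    }
    }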
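
In the int8 branch the oh/ow blocking logic from the old common path is kept but moved under is_int8_conv && is_fwd. Its core is a cache estimate: size the working set of one output row and take as many rows as fit in L2, rounding down to a multiple of 4 for thread balance, and fall back to width blocking only when a single row does not fit. A compact restatement of that estimate as an illustrative function (constants follow the patch's heuristic; L2 is expressed in elements, as in the patch):

    #include <algorithm>
    #include <cstdint>

    int pick_oh_block(int oh, int ow, int ic, int oc, int iw, int ks,
            int l2_elems) {
        int row_size = ic * ow * ks + 2 * (ic * iw + oc * ow); // col + src + dst per row
        row_size *= 5;                                // gemm scratch heuristic (int8)
        row_size += oc * ow * (int)sizeof(uint32_t);  // int32 accumulators
        row_size += ic * iw;                          // transposed input copy
        int h_block = std::max(1,
                std::min(oh, (l2_elems + row_size - 1) / row_size)); // div_up
        if (h_block != oh)
            h_block = std::max(1, h_block - h_block % 4); // rnd_dn(h_block, 4)
        return h_block;
    }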
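
The f32 blocking search scores every (nthr_oc, ocb, osb, icb) candidate by combining several partial efficiencies (thread balance, oc/os/ic block divisibility, estimated memory traffic, gemm shape) with a weighted geometric mean, so one poor factor pulls the whole score down. A minimal form of that scoring idea:

    #include <cmath>

    // eff[i] in (0, 1], weight[i] > 0; returns the weighted geometric mean.
    float weighted_geo_mean(const float *eff, const float *weight, int n) {
        float prod = 1.f, weight_sum = 0.f;
        for (int i = 0; i < n; ++i) {
            prod *= std::pow(eff[i], weight[i]);
            weight_sum += weight[i];
        }
        return std::pow(prod, 1.f / weight_sum);
    }

For example, an icb that does not divide jcp.ic yields ic_disb = ic / rnd_up(ic, icb) < 1, penalizing the candidate roughly in proportion to the wasted work of the ragged last block; the search keeps the best-scoring candidate and records it in jcp.nthr_oc, jcp.oc_block, jcp.os_block and jcp.ic_block.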