From 75cc3562aa508877f1d852946ee05e8d9e33bf6a Mon Sep 17 00:00:00 2001
From: Andrey Kalinin <andrey.kalinin@intel.com>
Date: Sat, 23 Mar 2019 14:05:49 -0700
Subject: [PATCH] cpu: conv: gemm f32 forward: oc/os/ic blocking

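Block the forward f32 gemm convolution over output channels (oc),
spatial (os), and input channels (ic) so that the per-thread gemm
operands fit into L2, and decompose threads two-dimensionally over
spatial and oc (balance2D, jcp.nthr_oc). Roughly, a sketch of the two
loop orders this adds to gemm_convolution.cpp:

    gemm_loop_rlb: for ic { for sp { [im2col] for oc { sgemm } } }
    gemm_loop_lrb: for sp { for ic { [im2col] for oc { sgemm } } }

Blocking parameters (oc_block, os_block, ic_block, nthr_oc, loop_order)
are selected in init_conf() by an efficiency estimator; im2col() now
takes a spatial range (ss, sb) and a channel range (cs, cb) instead of
oh/ow ranges.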
---
 src/common/mkldnn_thread.hpp       |   3 +-
 src/cpu/gemm_convolution.cpp       | 197 +++++---
 src/cpu/gemm_convolution_utils.cpp | 754 ++++++++++++++++++++---------
 src/cpu/gemm_convolution_utils.hpp |   2 +-
 src/cpu/jit_primitive_conf.hpp     |   5 +
 5 files changed, 663 insertions(+), 298 deletions(-)

diff --git a/src/common/mkldnn_thread.hpp b/src/common/mkldnn_thread.hpp
index 8e6f579cd7c..9ece18a8134 100644
--- a/src/common/mkldnn_thread.hpp
+++ b/src/common/mkldnn_thread.hpp
@@ -125,8 +125,7 @@ void balance2D(U nthr, U ithr, T ny, T &ny_start, T &ny_end,
         grp = ithr / grp_size_big;
         grp_ithr = ithr % grp_size_big;
         grp_nthr = grp_size_big;
-    }
-    else { // ithr in last groups
+    } else { // ithr in last groups
         grp = n_grp_big + ithr_bound_distance / grp_size_small;
         grp_ithr = ithr_bound_distance % grp_size_small;
         grp_nthr = grp_size_small;
diff --git a/src/cpu/gemm_convolution.cpp b/src/cpu/gemm_convolution.cpp
index 604a728b476..c6dc6088ce4 100644
--- a/src/cpu/gemm_convolution.cpp
+++ b/src/cpu/gemm_convolution.cpp
@@ -31,6 +31,18 @@ using namespace mkldnn::impl::status;
 using namespace mkldnn::impl::memory_tracking::names;
 using namespace mkldnn::impl::utils;
 
+namespace {
+struct im_pos_t {
+    im_pos_t() : n{ 0 }, g{ 0 }, od{ 0 }, sp{ 0 }, ic{ 0 }, oc{ 0 } {}
+    int n, g, od, sp, ic, oc;
+    bool do_im2col(const im_pos_t &prev) const {
+        return true
+                && (n != prev.n || g != prev.g || od != prev.od || sp != prev.sp
+                           || ic != prev.ic);
+    }
+};
+} // namespace
+
 void gemm_convolution_fwd_t::execute_forward(const exec_ctx_t &ctx) const {
     auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
     auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS);
@@ -41,97 +53,144 @@ void gemm_convolution_fwd_t::execute_forward(const exec_ctx_t &ctx) const {
 
     const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_;
 
-    const int M = jcp.os * jcp.od;
     const size_t src_step = jcp.ic * jcp.ih * jcp.iw * jcp.id;
-    const size_t dst_step = jcp.oc * M;
-    const size_t weights_g_size = jcp.ic * jcp.oc * jcp.ks;
+    const size_t weights_oc_size = jcp.ic * jcp.ks;
+    const size_t weights_g_size = weights_oc_size * jcp.oc;
 
     assert(IMPLICATION(
-            jcp.id != 1, jcp.oh_block == jcp.oh && jcp.ow_block == jcp.ow));
-    assert(IMPLICATION(jcp.ow_block != jcp.ow, jcp.oh_block == 1));
-
-    const int K = jcp.ic * jcp.ks;
-    const int N = jcp.oc;
+            jcp.id != 1, jcp.os_block == jcp.os && jcp.ic_block == jcp.ic));
 
     if (jcp.im2col_sz && jcp.id != 1)
         parallel_nd(jcp.im2col_sz * jcp.nthr,
                 [&](ptrdiff_t i) { col[i] = (data_t)0; });
 
-    const int nb_oh = div_up(jcp.oh, jcp.oh_block);
-    const int nb_ow = div_up(jcp.ow, jcp.ow_block);
-    const size_t work_amount = jcp.ngroups * jcp.mb * jcp.od * nb_oh * nb_ow;
     parallel(jcp.nthr, [&](const int ithr, const int nthr) {
         data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz;
 
-        int g{ 0 }, n{ 0 }, od{ 0 }, ohb{ 0 }, owb{ 0 };
-        size_t start = 0, end = 0;
-
-        balance211(work_amount, nthr, ithr, start, end);
-        nd_iterator_init(start, g, jcp.ngroups, n, jcp.mb, od, jcp.od, ohb,
-                nb_oh, owb, nb_ow);
-        for (size_t iwork = start; iwork < end; ++iwork) {
-            int oh = ohb * jcp.oh_block;
-            int ow = owb * jcp.ow_block;
-            const data_t *_src = src + (n * jcp.ngroups + g) * src_step;
-            const data_t *_weights = weights + g * weights_g_size;
-            data_t *_dst_im = dst + (n * jcp.ngroups + g) * dst_step;
-            const int h_step = nstl::min(jcp.oh_block, jcp.oh - oh);
-            const int w_step = nstl::min(jcp.ow_block, jcp.ow - ow);
-            if (jcp.im2col_sz) {
+        auto inner_ker = [&](int spatial, const im_pos_t &curr, im_pos_t &prev,
+                                 im_pos_t &step, const im_pos_t &end) {
+            const data_t *_src
+                    = src + (curr.n * jcp.ngroups + curr.g) * src_step;
+            step.oc = nstl::min(
+                    jcp.oc_block, nstl::min(jcp.oc, end.oc) - curr.oc);
+            step.sp = nstl::min(jcp.os_block,
+                    nstl::min(jcp.os - curr.sp, end.sp - spatial));
+            step.ic = nstl::min(
+                    jcp.ic_block, nstl::min(jcp.ic, end.ic) - curr.ic);
+            bool do_im2col = curr.do_im2col(prev);
+            prev = curr;
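+            // col depends only on (n, g, od, sp, ic), so it can be reused
+            // across consecutive oc blocks at the same position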
+
+            if (jcp.im2col_sz && do_im2col) {
                 if (jcp.id == 1)
-                    jit_gemm_convolution_utils::im2col(
-                            jcp, _src, _col, oh, h_step, ow, w_step);
+                    jit_gemm_convolution_utils::im2col(jcp, _src, _col, curr.sp,
+                            step.sp, curr.ic, step.ic);
                 else
-                    jit_gemm_convolution_utils::im2col_3d(jcp, _src, _col, od);
+                    jit_gemm_convolution_utils::im2col_3d(
+                            jcp, _src, _col, curr.od);
             }
-
             const data_t one = 1.0;
 
-            const int m = h_step * w_step;
+            const int M = jcp.os * jcp.od;
+            const size_t dst_step = jcp.oc * M;
+            const int m = step.sp;
             const int LDA = jcp.im2col_sz ? m : M;
-            data_t *_dst = _dst_im + od * jcp.os + oh * jcp.ow + ow;
-
-            extended_sgemm("N", "N", &m, &N, &K, &one,
-                    jcp.im2col_sz ? _col : _src + od * m, &LDA, _weights, &K,
-                    &this->beta_, _dst, &M);
-
-            data_t *d = _dst;
-            if (eltwise_) {
-                // fast branch for ReLU case
-                if (eltwise_->alg_ == alg_kind::eltwise_relu) {
-                    parallel_nd(jcp.oc, [&](const int oc) {
-                        data_t b = jcp.with_bias ? bias[g * jcp.oc + oc] : 0;
-                        data_t *d_ = d + oc * M;
-                        PRAGMA_OMP_SIMD()
-                        for (int oS = 0; oS < m; ++oS) {
-                            d_[oS] += b;
-                            if (d_[oS] < 0) d_[oS] *= eltwise_->alpha_;
-                        }
-                    });
-                } else {
-                    parallel_nd(jcp.oc, [&](const int oc) {
-                        data_t b = jcp.with_bias ? bias[g * jcp.oc + oc] : 0;
-                        data_t *d_ = d + oc * M;
+            data_t *_dst = dst + (curr.n * jcp.ngroups + curr.g) * dst_step
+                    + curr.oc * M + curr.od * jcp.os + curr.sp;
+            const int K = step.ic * jcp.ks;
+            const int LDB = jcp.ic * jcp.ks;
+            const int N = step.oc;
+
+            // TODO: what if this->beta_ is neither 0 nor 1?
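+            // the first ic block applies the primitive beta; subsequent ic
+            // blocks accumulate into dst with beta = 1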
+            const float beta = (curr.ic == 0) ? this->beta_ : one;
+            const float *_source = jcp.im2col_sz
+                    ? _col
+                    : _src + curr.ic * M + curr.od * jcp.os + curr.sp;
+            const data_t *_weights = weights + curr.g * weights_g_size
+                    + curr.oc * weights_oc_size + curr.ic * jcp.ks;
+
+            extended_sgemm("N", "N", &m, &N, &K, &one, _source, &LDA, _weights,
+                    &LDB, &beta, _dst, &M);
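+            // apply bias and the eltwise post-op only after the last ic
+            // block has completed the accumulation into dst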
+            if (curr.ic == jcp.ic - step.ic) {
+                const int oc_start = curr.g * jcp.oc + curr.oc;
+                if (eltwise_) {
+                    // fast branch for ReLU case
+                    if (eltwise_->alg_ == alg_kind::eltwise_relu) {
+                        parallel_nd(step.oc, [&](const int oc) {
+                            data_t b = jcp.with_bias ? bias[oc_start + oc] : 0;
+                            data_t *d_ = _dst + oc * M;
+                            PRAGMA_OMP_SIMD()
+                            for (int oS = 0; oS < m; ++oS) {
+                                d_[oS] += b;
+                                if (d_[oS] < 0)
+                                    d_[oS] *= eltwise_->alpha_;
+                            }
+                        });
+                    } else {
+                        parallel_nd(step.oc, [&](const int oc) {
+                            data_t b = jcp.with_bias ? bias[oc_start + oc] : 0;
+                            data_t *d_ = _dst + oc * M;
+                            PRAGMA_OMP_SIMD()
+                            for (int oS = 0; oS < m; ++oS) {
+                                d_[oS] += b;
+                                d_[oS] = eltwise_->compute_scalar(d_[oS]);
+                            }
+                        });
+                    }
+                } else if (jcp.with_bias) {
+                    parallel_nd(step.oc, [&](const int oc) {
+                        data_t b = bias[oc_start + oc];
+                        data_t *d_ = _dst + oc * M;
                         PRAGMA_OMP_SIMD()
                         for (int oS = 0; oS < m; ++oS) {
                             d_[oS] += b;
-                            d_[oS] = eltwise_->compute_scalar(d_[oS]);
                         }
                     });
                 }
-            } else if (jcp.with_bias) {
-                parallel_nd(jcp.oc, [&](const int oc) {
-                    data_t b = bias[g * jcp.oc + oc];
-                    data_t *d_ = d + oc * M;
-                    PRAGMA_OMP_SIMD()
-                    for (int oS = 0; oS < m; ++oS) {
-                        d_[oS] += b;
-                    }
-                });
             }
-            nd_iterator_step(g, jcp.ngroups, n, jcp.mb, od, jcp.od, ohb, nb_oh,
-                    owb, nb_ow);
+        };
+        im_pos_t start, end;
+        end.ic = jcp.ic;
+
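+        // balance2D splits the threads into jcp.nthr_oc groups along oc and
+        // distributes the spatial work within each group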
+        if (jcp.id == 1) {
+            const int sp_work = jcp.mb * jcp.ngroups * jcp.od * jcp.os;
+            balance2D(nthr, ithr, sp_work, start.sp, end.sp, jcp.oc, start.oc,
+                    end.oc, jcp.nthr_oc);
+        } else {
+            const int sp_work = jcp.mb * jcp.ngroups * jcp.od;
+            balance2D(nthr, ithr, sp_work, start.sp, end.sp, jcp.oc, start.oc,
+                    end.oc, jcp.nthr_oc);
+            start.sp *= jcp.os;
+            end.sp *= jcp.os;
         }
+
+        im_pos_t curr, prev, step;
+        prev.n = prev.g = prev.od = prev.sp = prev.ic = -1;
+        step.oc = jcp.oc_block;
+        step.sp = jcp.os_block;
+        step.ic = jcp.ic_block;
+
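+        // naming sketch (judging by the loops below): 'r' is the reduction
+        // (ic) loop, 'l' the spatial loop, 'b' the oc block loop; rlb keeps
+        // ic outermost, lrb keeps spatial outermost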
+        if (jcp.loop_order == gemm_loop_rlb)
+            for (curr.ic = 0; curr.ic < jcp.ic; curr.ic += step.ic)
+                for (int spatial = start.sp; spatial < end.sp;
+                        spatial += step.sp) {
+                    nd_iterator_init(spatial, curr.n, jcp.mb, curr.g,
+                            jcp.ngroups, curr.od, jcp.od, curr.sp, jcp.os);
+                    for (curr.oc = start.oc; curr.oc < end.oc;
+                            curr.oc += step.oc) {
+                        inner_ker(spatial, curr, prev, step, end);
+                    }
+                }
+        else if (jcp.loop_order == gemm_loop_lrb)
+            for (int spatial = start.sp; spatial < end.sp; spatial += step.sp) {
+                nd_iterator_init(spatial, curr.n, jcp.mb, curr.g, jcp.ngroups,
+                        curr.od, jcp.od, curr.sp, jcp.os);
+                for (curr.ic = 0; curr.ic < jcp.ic; curr.ic += step.ic)
+                    for (curr.oc = start.oc; curr.oc < end.oc;
+                            curr.oc += step.oc)
+                        inner_ker(spatial, curr, prev, step, end);
+            }
+        else
+            assert(!"Unknown loop order");
     });
 }
 
@@ -256,10 +315,10 @@ void gemm_convolution_bwd_weights_t::execute_backward_weights(
                     if (jcp.im2col_sz) {
                         if (jcp.id == 1)
                             jit_gemm_convolution_utils::im2col(
-                                    jcp, _src, _col, 0, jcp.oh, 0, jcp.ow);
+                                    jcp, _src, _col, 0, jcp.os, 0, jcp.ic);
                         else
-                            jit_gemm_convolution_utils::im2col_3d(jcp, _src,
-                                _col, od);
+                            jit_gemm_convolution_utils::im2col_3d(
+                                    jcp, _src, _col, od);
                     }
 
                     const data_t zero = 0.0, one = 1.0;
diff --git a/src/cpu/gemm_convolution_utils.cpp b/src/cpu/gemm_convolution_utils.cpp
index a78d7d2a7ed..0db572f28d7 100644
--- a/src/cpu/gemm_convolution_utils.cpp
+++ b/src/cpu/gemm_convolution_utils.cpp
@@ -119,106 +119,155 @@ void im2col_3d(const jit_gemm_conv_conf_t &jcp, const float *im, float *col,
     });
 }
 
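+// clamps value to the range [low, upper]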
+inline int saturate(int low, int upper, int value) {
+    return nstl::max(low, nstl::min(upper, value));
+}
+
 /* col[ic][kh][kw][oh][ow] <-- im2col(im[ic][ih][iw]) */
-void im2col(const jit_gemm_conv_conf_t &jcp, const float *__restrict im,
-       float *__restrict col, int hs, int hb, int ws, int wb) {
+void im2col(const jit_gemm_conv_conf_t &jcp,
+        const float *__restrict im, float *__restrict col, int ss, int sb,
+        int cs, int cb) {
     const size_t im_step = jcp.is;
-    const size_t col_step = jcp.ks * hb * wb;
-    if (jcp.stride_w == 1) {
-        // Generated code is more optimized for stride_w == 1
-        // because innermost loop is by width
-        auto ker = [&](int ic, int kh, int kw, int oh) {
-            const float *__restrict im_ = im + ic * im_step;
-            float *__restrict col_
-                = col + ic * col_step + ((kh * jcp.kw + kw) * hb + oh) * wb;
-
-            const int ih = (oh + hs) * jcp.stride_h - jcp.t_pad
-                + kh * (1 + jcp.dilate_h);
-            if (ih < 0 || ih >= jcp.ih) {
-                for (int ow = 0; ow < wb; ++ow)
-                    col_[ow] = 0.f;
-            } else {
-                for (int ow = 0; ow < wb; ++ow) {
-                    const int iw = ow + ws - jcp.l_pad + kw * (1 + jcp.dilate_w);
-                    if (iw < 0 || iw >= jcp.iw)
-                        col_[ow] = 0.f;
-                    else {
-                        const size_t im_idx = ih * jcp.iw + iw;
-                        col_[ow] = im_[im_idx];
+    const size_t col_step = jcp.ks * sb;
+    const int dh = 1 + jcp.dilate_h;
+    const int dw = 1 + jcp.dilate_w;
+    const int sh = jcp.stride_h;
+    const int sw = jcp.stride_w;
+    const int tp = jcp.t_pad;
+    const int lp = jcp.l_pad;
+    const int first_oh = ss / jcp.ow;
+    const int last_oh = (ss + sb - 1) / jcp.ow;
+    const int oh_begin = first_oh;
+    const int oh_end = last_oh + 1;
+    const int first_ow = ss % jcp.ow;
+    const int last_ow = (ss + sb - 1) % jcp.ow;
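+    // the spatial block [ss, ss + sb) may start and end mid-row: the first
+    // and last output rows are clipped to first_ow / last_ow below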
+
+    if (jcp.outer_threading) {
+        if (sw == 1) {
+            // Generated code is better optimized for stride_w == 1
+            // because the innermost loop is over width
+            for (int ic = 0; ic < cb; ic++) {
+                const float *__restrict im_ic = im + (ic + cs) * im_step;
+                for (int kh = 0; kh < jcp.kh; kh++) {
+                    for (int kw = 0; kw < jcp.kw; kw++) {
+                        float *__restrict col_k
+                                = col + ic * col_step + (kh * jcp.kw + kw) * sb;
+                        for (int oh = oh_begin; oh < oh_end; oh++) {
+                            const int ih = oh * sh - tp + kh * dh;
+                            const float *__restrict im_
+                                    = im_ic + ih * jcp.iw - lp + kw * dw;
+                            const int ow_begin
+                                    = (oh == first_oh) ? first_ow : 0;
+                            const int ow_end
+                                    = (oh == last_oh) ? (last_ow + 1) : jcp.ow;
+                            float *__restrict col_ = col_k + oh * jcp.ow - ss;
+                            if (ih < 0 || ih >= jcp.ih)
+                                for (int ow = ow_begin; ow < ow_end; ow++)
+                                    col_[ow] = 0.f;
+                            else {
+                                for (int ow = ow_begin; ow < ow_end; ++ow) {
+                                    const int iw = ow;
+                                    if (iw < lp - kw * dw
+                                            || iw >= jcp.iw + lp - kw * dw)
+                                        col_[ow] = 0.f;
+                                    else
+                                        col_[ow] = im_[iw];
+                                }
+                            }
+                        }
                     }
                 }
             }
-        };
-
-        if (jcp.outer_threading) {
-            for (int ic = 0; ic < jcp.ic; ic++)
-                for (int kh = 0; kh < jcp.kh; kh++)
-                    for (int kw = 0; kw < jcp.kw; kw++)
-                        for (int oh = 0; oh < hb; oh++)
-                            ker(ic, kh, kw, oh);
-        }
-        else {
-            parallel_nd(jcp.ic, jcp.kh, jcp.kw, hb, ker);
-        }
-    } else if (jcp.ic == 1) {
-        parallel_nd(jcp.kh, hb, [&](int kh, int oh) {
-            const int ih = (oh + hs) * jcp.stride_h - jcp.t_pad
-                    + kh * (1 + jcp.dilate_h);
-            if (ih < 0 || ih >= jcp.ih)
-                for (int kw = 0; kw < jcp.kw; ++kw) {
-                    for (int ow = 0; ow < wb; ++ow) {
-                        const size_t col_idx
-                                = ((kh * jcp.kw + kw) * hb + oh) * wb + ow;
-                        col[col_idx] = 0;
-                    }
-                }
-            else
-                for (int kw = 0; kw < jcp.kw; ++kw) {
-                    for (int ow = 0; ow < wb; ++ow) {
-                        const int iw = (ow + ws) * jcp.stride_w - jcp.l_pad
-                                + kw * (1 + jcp.dilate_w);
-                        const size_t col_idx
-                                = ((kh * jcp.kw + kw) * hb + oh) * wb + ow;
-                        const size_t im_idx = ih * jcp.iw + iw;
-                        if (iw < 0 || iw >= jcp.iw)
-                            col[col_idx] = 0;
-                        else
-                            col[col_idx] = im[im_idx];
+        } else {
+            for (int ic = 0; ic < cb; ic++) {
+                const float *__restrict im_ = im + (ic + cs) * im_step;
+                for (int kh = 0; kh < jcp.kh; kh++) {
+                    for (int kw = 0; kw < jcp.kw; kw++) {
+                        float *__restrict col_k
+                                = col + ic * col_step + (kh * jcp.kw + kw) * sb;
+                        for (int oh = oh_begin; oh < oh_end; oh++) {
+                            const int ih = oh * sh - tp + kh * dh;
+                            const int ow_begin
+                                    = (oh == first_oh) ? first_ow : 0;
+                            const int ow_end
+                                    = (oh == last_oh) ? (last_ow + 1) : jcp.ow;
+                            float *__restrict col_oh = col_k + oh * jcp.ow - ss;
+                            if (ih < 0 || ih >= jcp.ih)
+                                for (int ow = ow_begin; ow < ow_end; ow++)
+                                    col_oh[ow] = 0.f;
+                            else
+                                for (int ow = ow_begin; ow < ow_end; ow++) {
+                                    const int iw = ow * sw - lp + kw * dw;
+                                    if (iw < 0 || iw >= jcp.iw)
+                                        col_oh[ow] = 0.f;
+                                    else {
+                                        const ptrdiff_t im_idx
+                                                = ih * jcp.iw + iw;
+                                        col_oh[ow] = im_[im_idx];
+                                    }
+                                }
+                        }
                     }
                 }
-        });
-    } else {
-
-        parallel_nd(jcp.ic, jcp.kh, jcp.kw, hb,
-            [&](int ic, int kh, int kw, int oh) {
-            const float *__restrict im_ = im + ic * im_step;
-            float *__restrict col_ = col + ic * col_step
-                + ((kh * jcp.kw + kw) * hb + oh) * wb;
-
-            const int ih = (oh + hs) * jcp.stride_h - jcp.t_pad
-                + kh * (1 + jcp.dilate_h);
-            if (ih < 0 || ih >= jcp.ih) {
-                for (int ow = 0; ow < wb; ++ow)
-                    col_[ow] = 0.f;
-            } else {
-                for (int ow = 0; ow < wb; ++ow) {
-                    const int iw = (ow + ws) * jcp.stride_w - jcp.l_pad
-                        + kw * (1 + jcp.dilate_w);
-                    const size_t im_idx = ih * jcp.iw + iw;
-                    if (iw < 0 || iw >= jcp.iw)
-                        col_[ow] = 0.f;
-                    else
-                        col_[ow] = im_[im_idx];
-                }
             }
-        });
+        }
+    } else {
+        // TODO: optimize threading if jcp.ic * jcp.kh * jcp.kw * oh_range is
+        // small compared to the number of threads
+        const int oh_range = oh_end - oh_begin;
+        // Generated code is better optimized for stride_w == 1
+        // because the innermost loop is over width
+        if (sw == 1)
+            parallel_nd(cb, jcp.kh, jcp.kw, oh_range,
+                    [&](int ic, int kh, int kw, int ohr) {
+                        const int oh = ohr + oh_begin;
+                        const int ih = oh * sh - tp + kh * dh;
+                        const int ow_start = (oh == first_oh) ? first_ow : 0;
+                        const int ow_end
+                                = (oh == last_oh) ? (last_ow + 1) : jcp.ow;
+                        float *__restrict col_oh = col + ic * col_step
+                                + (kh * jcp.kw + kw) * sb + oh * jcp.ow - ss;
+                        const float *__restrict im_
+                                = im + (ic + cs) * im_step + ih * jcp.iw;
+                        const int iw_shift = kw * dw - lp;
+                        if (ih < 0 || ih >= jcp.ih)
+                            for (int ow = ow_start; ow < ow_end; ow++)
+                                col_oh[ow] = 0.f;
+                        else
+                            for (int ow = ow_start; ow < ow_end; ow++) {
+                                const int iw = ow + iw_shift;
+                                if (iw < 0 || iw >= jcp.iw)
+                                    col_oh[ow] = 0.f;
+                                else
+                                    col_oh[ow] = im_[iw];
+                            }
+                    });
+        else
+            parallel_nd(cb, jcp.kh, jcp.kw, oh_range,
+                    [&](int ic, int kh, int kw, int ohr) {
+                        const int oh = ohr + oh_begin;
+                        const int ih = oh * sh - tp + kh * dh;
+                        const int ow_start = (oh == first_oh) ? first_ow : 0;
+                        const int ow_end
+                                = (oh == last_oh) ? (last_ow + 1) : jcp.ow;
+                        float *__restrict col_oh = col + ic * col_step
+                                + (kh * jcp.kw + kw) * sb + oh * jcp.ow - ss;
+                        const float *__restrict im_ = im + (ic + cs) * im_step;
+                        if (ih < 0 || ih >= jcp.ih)
+                            for (int ow = ow_start; ow < ow_end; ow++)
+                                col_oh[ow] = 0.f;
+                        else
+                            for (int ow = ow_start; ow < ow_end; ow++) {
+                                const int iw = ow * sw - lp + kw * dw;
+                                if (iw < 0 || iw >= jcp.iw)
+                                    col_oh[ow] = 0.f;
+                                else {
+                                    const ptrdiff_t im_idx = ih * jcp.iw + iw;
+                                    col_oh[ow] = im_[im_idx];
+                                }
+                            }
+                    });
     }
 }
 
-inline int limit(int low, int upper, int value) {
-    return nstl::max(low, nstl::min(upper, value));
-}
-
 /* col[kh][kw][ic][oh][ow] <-- im2col_u8(im[ih][iw][ic]) */
 template <typename T>
 void im2col_u8(const jit_gemm_conv_conf_t &jcp, const T *__restrict im,
@@ -238,10 +287,10 @@ void im2col_u8(const jit_gemm_conv_conf_t &jcp, const T *__restrict im,
         /* im[ih][iw][ic] --> imtr[ic][ih][iw] --> col[kh][kw][ic][oh][ow] */
         const int hp = hs - tp;
         const int wp = ws - lp;
-        const int ih_start = limit(0, jcp.ih, hp);
-        const int ih_end = limit(0, jcp.ih, hp + hb + jcp.kh);
-        const int iw_start = limit(0, jcp.iw, wp);
-        const int iw_end = limit(0, jcp.iw, wp + wb + jcp.kw);
+        const int ih_start = saturate(0, jcp.ih, hp);
+        const int ih_end = saturate(0, jcp.ih, hp + hb + jcp.kh);
+        const int iw_start = saturate(0, jcp.iw, wp);
+        const int iw_end = saturate(0, jcp.iw, wp + wb + jcp.kw);
 
         const int ihb = ih_end - ih_start;
         const int iwb = iw_end - iw_start;
@@ -267,15 +316,15 @@ void im2col_u8(const jit_gemm_conv_conf_t &jcp, const T *__restrict im,
         for (int kh = 0; kh < jcp.kh; kh++) {
             const ptrdiff_t col_idx_kh = kh * col_kh_stride;
             const int oh_kh = oh_init - kh;
-            const int oh_start = limit(0, hb, oh_kh);
-            const int oh_end = limit(0, hb, oh_kh + ihb);
+            const int oh_start = saturate(0, hb, oh_kh);
+            const int oh_end = saturate(0, hb, oh_kh + ihb);
             for (int kw = 0; kw < jcp.kw; kw++) {
                 const ptrdiff_t col_idx_kw
                         = col_idx_kh + kw * jcp.ic * col_ic_str;
                 const int ow_kw = ow_init - kw;
                 const int imtr_shift = oh_kh * iwb + ow_kw;
-                const int ow_start = limit(0, wb, ow_kw);
-                const int ow_end = limit(0, wb, ow_kw + iwb);
+                const int ow_start = saturate(0, wb, ow_kw);
+                const int ow_end = saturate(0, wb, ow_kw + iwb);
                 for (int ic = 0; ic < jcp.ic; ic++) {
                     const ptrdiff_t col_idx_ic = col_idx_kw + ic * col_ic_str;
                     const int imtr_idx_ic = ic * imtr_ic_stride - imtr_shift;
@@ -315,9 +364,9 @@ void im2col_u8(const jit_gemm_conv_conf_t &jcp, const T *__restrict im,
                         col[col_idx_base + ow] = shift;
                 else {
                     const int wp = lp - kw * dw;
-                    const int ow_start = limit(0, wb, div_up(wp, sw) - ws);
+                    const int ow_start = saturate(0, wb, div_up(wp, sw) - ws);
                     const int ow_end
-                            = limit(0, wb, div_up(jcp.iw + wp, sw) - ws);
+                            = saturate(0, wb, div_up(jcp.iw + wp, sw) - ws);
                     for (int ow = 0; ow < ow_start; ow++)
                         col[col_idx_base + ow] = shift;
                     const int iw_base = ws * sw - wp;
@@ -540,6 +589,12 @@ status_t init_conf(jit_gemm_conv_conf_t &jcp,
     const bool is_bwd_d = jcp.prop_kind == backward_data;
     const bool is_bwd_w = jcp.prop_kind == backward_weights;
     const bool is_fwd = !is_bwd_d && !is_bwd_w;
+    jcp.os_block = jcp.os;
+    jcp.oc_block = jcp.oc;
+    jcp.ic_block = jcp.ic;
+    jcp.loop_order = gemm_loop_rlb;
+    jcp.nthr_oc = 1;
+
     jcp.oh_block = is_fwd ? jcp.oh : jcp.ih;
     jcp.ow_block = is_fwd ? jcp.ow : jcp.iw;
 
@@ -547,127 +602,133 @@ status_t init_conf(jit_gemm_conv_conf_t &jcp,
     bool is_depthwise = jcp.ic == 1 && jcp.oc == 1 && jcp.ngroups != 1;
 
     // TODO: maybe mitigate blocking restriction
-    const int wei_size = jcp.oc * jcp.ic * jcp.kh * jcp.kw;
     const int L2 = get_cache_size(2, true)
           / (is_int8_conv ? sizeof(int8_t) : sizeof(float));
-    bool is_blocking_applicable = true
-            && is_fwd && jcp.im2col_sz
-            && jcp.id == 1 && jcp.od == 1
-            && jcp.dilate_h == 0 && jcp.dilate_w == 0
-            && !is_depthwise
-            && wei_size < L2/2;
-    if (is_blocking_applicable) {
-        // looking for oh and ow blocking
-        int h_block{ jcp.oh_block }, w_block{ jcp.ow_block };
-        const int ic = jcp.ic;
-        const int oc = jcp.oc;
-        const int iw = jcp.iw;
-        const int ow = jcp.ow;
-        const int oh = jcp.oh;
-        const int os = oh * ow;
-
-        // 1. cache requirement
-        int row_size = ic * ow * jcp.ks + 2 * (ic * iw + oc * ow);
-        if (is_int8_conv) {
-            // Heuristic rule: gemm needed a lot of memory for internal usage
-            row_size *= 5;
-            // memory for accumulators
-            row_size += oc * ow * sizeof(uint32_t);
-            // memory for transposition
-            row_size += ic * iw;
-        }
-
-        h_block = nstl::max(1, nstl::min(oh, div_up(L2, row_size)));
-        if (h_block == 1) {
-            int col_size = ic * jcp.ks + 2 * (ic + oc);
-            if (is_int8_conv) {
-                col_size *= 5;
-                col_size += oc * sizeof(uint32_t);
-                col_size += ic;
-            }
-            w_block = nstl::max(1, nstl::min(ow, div_up(L2, col_size)));
-        }
-
-        // 2. threading requirement
-        if (h_block != oh)
-            h_block = nstl::max(1, rnd_dn(h_block, 4));
-        if (w_block != ow)
-            w_block = nstl::max(1, rnd_dn(w_block, simd_w));
-
-        float thr_eff = 0.f;
-        float thr_eff_treshold = 0.9f;
-        if (w_block == ow) {
-            do {
-                int nb_h = div_up(oh, h_block);
-                size_t work = jcp.ngroups * jcp.mb * jcp.od * nb_h;
-                float disb = (float)oh / rnd_up(oh, h_block);
-                thr_eff = (float)work / rnd_up(work, max_threads);
-                thr_eff = (thr_eff + disb) / 2.f;
-                if (thr_eff >= thr_eff_treshold)
-                    break;
-                h_block = rnd_dn(h_block - 4, 4);
-            } while (h_block > 0);
-        }
-        if (thr_eff < thr_eff_treshold) // we didn't find suitable h_block
-        {
-            h_block = 1;
-            int nb_h = oh;
-            do {
-                int nb_w = div_up(ow, w_block);
-                size_t work_amount = jcp.ngroups * jcp.mb * nb_h * nb_w;
-                float disb = (float)ow / rnd_up(ow, w_block);
-                thr_eff = (float)work_amount / rnd_up(work_amount, max_threads);
-                thr_eff = (thr_eff + disb) / 2.f;
-                if (thr_eff > thr_eff_treshold)
-                    break;
-                w_block = rnd_dn(w_block - simd_w, simd_w);
-            } while (w_block > 0);
-        }
-        h_block = nstl::max(1, h_block);
-        w_block = nstl::max(1, w_block);
-        const size_t inner_work = div_up(os, simd_w) * div_up(oc, simd_w);
-        const float inner_thr_eff
-                = (float)inner_work / rnd_up(inner_work, max_threads);
-        if (thr_eff >= inner_thr_eff / 2 && h_block > 0 && w_block > 0) {
-            jcp.oh_block = h_block;
-            jcp.ow_block = w_block;
-            jcp.outer_threading = true;
-        }
-        // updating jcp.im2col_sz
-        if (jcp.oh_block != 1)
-            jcp.ow_block = ow;
-        jcp.im2col_sz = (ptrdiff_t)ic * jcp.ks * jcp.oh_block * jcp.ow_block;
-    }
-    //  For threading selection in bwd_d we do:
-    //  1. Rough estimation of efficiency for inner and outer threading.
-    //  2. Gemm size estimation in assumption that it does not work
-    //  so effectively for small sizes.
-    //  64K - this is heuristic gemm size per thread threshold.
     const int gemm_thrld = 64 * 1024;
 
     if (is_int8_conv) {
         if (is_fwd) {
+            const int wei_size = jcp.oc * jcp.ic * jcp.kh * jcp.kw;
+            bool is_blocking_applicable = true && is_fwd && jcp.im2col_sz
+                    && jcp.id == 1 && jcp.od == 1 && jcp.dilate_h == 0
+                    && jcp.dilate_w == 0 && !is_depthwise && wei_size < L2 / 2;
+            if (is_blocking_applicable) {
+                // looking for oh and ow blocking
+                int h_block{ jcp.oh_block }, w_block{ jcp.ow_block };
+                const int ic = jcp.ic;
+                const int oc = jcp.oc;
+                const int iw = jcp.iw;
+                const int ow = jcp.ow;
+                const int oh = jcp.oh;
+                const int os = oh * ow;
+
+                // 1. cache requirement
+                int row_size = ic * ow * jcp.ks + 2 * (ic * iw + oc * ow);
+                // Heuristic rule: gemm needs a lot of memory for internal
+                // usage
+                row_size *= 5;
+                // memory for accumulators
+                row_size += oc * ow * sizeof(uint32_t);
+                // memory for transposition
+                row_size += ic * iw;
+
+                h_block = nstl::max(1, nstl::min(oh, div_up(L2, row_size)));
+                if (h_block == 1) {
+                    int col_size = ic * jcp.ks + 2 * (ic + oc);
+                    col_size *= 5;
+                    col_size += oc * sizeof(uint32_t);
+                    col_size += ic;
+                    w_block = nstl::max(1, nstl::min(ow, div_up(L2, col_size)));
+                }
+
+                // 2. threading requirement
+                if (h_block != oh)
+                    h_block = nstl::max(1, rnd_dn(h_block, 4));
+                if (w_block != ow)
+                    w_block = nstl::max(1, rnd_dn(w_block, simd_w));
+
+                float thr_eff = 0.f;
+                float thr_eff_threshold = 0.9f;
+                if (w_block == ow) {
+                    do {
+                        int nb_h = div_up(oh, h_block);
+                        size_t work = jcp.ngroups * jcp.mb * jcp.od * nb_h;
+                        float disb = (float)oh / rnd_up(oh, h_block);
+                        thr_eff = (float)work / rnd_up(work, max_threads);
+                        thr_eff = (thr_eff + disb) / 2.f;
+                        if (thr_eff >= thr_eff_threshold)
+                            break;
+                        h_block = rnd_dn(h_block - 4, 4);
+                    } while (h_block > 0);
+                }
+                // we didn't find a suitable h_block
+                if (thr_eff < thr_eff_threshold) {
+                    h_block = 1;
+                    int nb_h = oh;
+                    do {
+                        int nb_w = div_up(ow, w_block);
+                        size_t work_amount = jcp.ngroups * jcp.mb * nb_h * nb_w;
+                        float disb = (float)ow / rnd_up(ow, w_block);
+                        thr_eff = (float)work_amount
+                                / rnd_up(work_amount, max_threads);
+                        thr_eff = (thr_eff + disb) / 2.f;
+                        if (thr_eff > thr_eff_threshold)
+                            break;
+                        w_block = rnd_dn(w_block - simd_w, simd_w);
+                    } while (w_block > 0);
+                }
+                h_block = nstl::max(1, h_block);
+                w_block = nstl::max(1, w_block);
+                const size_t inner_work
+                        = div_up(os, simd_w) * div_up(oc, simd_w);
+                const float inner_thr_eff
+                        = (float)inner_work / rnd_up(inner_work, max_threads);
+                if (thr_eff >= inner_thr_eff / 2 && h_block > 0
+                        && w_block > 0) {
+                    jcp.oh_block = h_block;
+                    jcp.ow_block = w_block;
+                    jcp.outer_threading = true;
+                }
+                // updating jcp.im2col_sz
+                if (jcp.oh_block != 1)
+                    jcp.ow_block = ow;
+                jcp.im2col_sz
+                        = (ptrdiff_t)ic * jcp.ks * jcp.oh_block * jcp.ow_block;
+            }
+            //  For threading selection we do:
+            //  1. Rough estimation of efficiency for inner and outer
+            //  threading.
+            //  2. Gemm size estimation, assuming gemm does not work so
+            //  effectively for small sizes.
+            //  64K is a heuristic per-thread gemm size threshold
+            //  (gemm_thrld, declared above).
             if (!jcp.outer_threading) {
-                bool is_depthwise = jcp.ic == 1 && jcp.oc == 1 && jcp.ngroups != 1;
+                bool is_depthwise
+                        = jcp.ic == 1 && jcp.oc == 1 && jcp.ngroups != 1;
                 const size_t outer_work = jcp.ngroups * jcp.mb;
                 const float outer_thr_eff
-                    = (float)outer_work / rnd_up(outer_work, max_threads);
+                        = (float)outer_work / rnd_up(outer_work, max_threads);
                 const size_t inner_work
-                    = div_up(jcp.is, simd_w) * div_up(jcp.ic, simd_w);
+                        = div_up(jcp.is, simd_w) * div_up(jcp.ic, simd_w);
                 const float inner_thr_eff
-                    = (float)inner_work / rnd_up(inner_work, max_threads);
-                jcp.outer_threading = (is_depthwise
-                 || (jcp.is / max_threads < 64 && jcp.mb != 1))
-                 && (outer_thr_eff / inner_thr_eff >= 1.f
-                     || (jcp.os * jcp.ic * jcp.oc) / max_threads < gemm_thrld);
+                        = (float)inner_work / rnd_up(inner_work, max_threads);
+                jcp.outer_threading
+                        = (is_depthwise
+                                  || (jcp.is / max_threads < 64 && jcp.mb != 1))
+                        && (outer_thr_eff / inner_thr_eff >= 1.f
+                                   || (jcp.os * jcp.ic * jcp.oc) / max_threads
+                                           < gemm_thrld);
             }
             jcp.nthr = jcp.outer_threading ? max_threads : 1;
             scratchpad.book(key_conv_gemm_col,
-                sizeof(int8_t) * jcp.nthr * jcp.im2col_sz);
+                    sizeof(int8_t) * jcp.nthr * jcp.im2col_sz);
             scratchpad.book(key_conv_int_dat_in_acc_dt,
-                sizeof(int32_t) * jcp.nthr * jcp.oh_block * jcp.ow_block * jcp.oc);
+                    sizeof(int32_t) * jcp.nthr * jcp.oh_block * jcp.ow_block
+                            * jcp.oc);
             scratchpad.book(key_conv_gemm_imtr,
-                sizeof(int8_t) * jcp.nthr * jcp.is * jcp.ic);
+                    sizeof(int8_t) * jcp.nthr * jcp.is * jcp.ic);
         } else if (is_bwd_d) {
             bool is_depthwise = jcp.ic == 1 && jcp.oc == 1 && jcp.ngroups != 1;
             const size_t outer_work = jcp.ngroups * jcp.mb;
@@ -677,23 +738,259 @@ status_t init_conf(jit_gemm_conv_conf_t &jcp,
                     = div_up(jcp.is, simd_w) * div_up(jcp.ic, simd_w);
             const float inner_thr_eff
                     = (float)inner_work / rnd_up(inner_work, max_threads);
-            jcp.outer_threading = (is_depthwise
-                 || (jcp.is / max_threads < 64 && jcp.mb != 1))
-                 && (outer_thr_eff / inner_thr_eff >= 1.f
-                     || (jcp.is * jcp.ic * jcp.oc) / max_threads < gemm_thrld);
+            jcp.outer_threading
+                    = (is_depthwise
+                              || (jcp.is / max_threads < 64 && jcp.mb != 1))
+                    && (outer_thr_eff / inner_thr_eff >= 1.f
+                               || (jcp.is * jcp.ic * jcp.oc) / max_threads
+                                       < gemm_thrld);
 
             jcp.nthr = jcp.outer_threading ? max_threads : 1;
             scratchpad.book(key_conv_gemm_col,
-                sizeof(int32_t) * jcp.nthr * jcp.im2col_sz);
+                    sizeof(int32_t) * jcp.nthr * jcp.im2col_sz);
             scratchpad.book(key_conv_int_dat_in_acc_dt,
-                sizeof(int32_t) * jcp.nthr * jcp.is * jcp.ic);
+                    sizeof(int32_t) * jcp.nthr * jcp.is * jcp.ic);
         } else if (is_bwd_w) {
             assert(!"unimplemented prop_kind");
             return status::unimplemented;
         }
     } else {
         if (is_fwd) {
-            if (!jcp.outer_threading) {
+            const int sh = jcp.stride_h;
+            const int sw = jcp.stride_w;
+            const int spatial = jcp.mb * jcp.ngroups * jcp.od * jcp.os;
+            int K = jcp.ic * jcp.ks;
+
+            // There are some heuristics in the definition of the inner/outer
+            // threading cross point due to the nature of the gemm
+            // implementation, which we cannot control
+            bool is_blocking_applicable = true && jcp.id == 1 && jcp.od == 1
+                && (!jcp.im2col_sz
+                    // spatial is large enough to feed all threads
+                    || spatial >= max_threads * simd_w
+                    // outer threading work is greater than inner threading
+                    // work
+                    || jcp.os < jcp.mb * jcp.ngroups * jcp.od
+                    // im2col is cheap relative to the gemm
+                    || (sw == 1 && K <= 0.05 * jcp.oc))
+                // heuristic condition
+                && (jcp.im2col_sz
+                    || (jcp.ic / jcp.oc < 42
+                        && jcp.ic * jcp.oc * jcp.is < 1024));
+
+            if (is_blocking_applicable) {
+                const int min_oc_block = 8;
+                const int min_os_block = simd_w;
+                const float non_cache_access = 20;
+                const float strided_im2col_k = 8;
+                const float thr_disb_k = 8;
+                const float thr_mem_eff_k{ 1 }, oc_disb_k{ 1 }, os_disb_k{ 1 },
+                        ic_disb_k{ 1 }, reg_osb_disb_k{ 1 }, gemm_eff_k{ 0.5 },
+                        gemm_calc_eff_k{ 1 };
+                const float k_sum = thr_disb_k + oc_disb_k + os_disb_k
+                        + ic_disb_k + reg_osb_disb_k + thr_mem_eff_k
+                        + gemm_eff_k + gemm_calc_eff_k;
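+                // the *_k constants are weights: est_eff combines the
+                // efficiency factors below into a weighted geometric mean,
+                // pow(prod_i pow(f_i, k_i), 1 / k_sum)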
+
+                auto calc_max_icb = [=](int nthr_oc, int ocb, int osb,
+                                            int oc_per_thr, int os_per_thr) {
+                    const int block_out_size = ocb * osb;
+                    // TODO: need a more precise calculation if the stride is
+                    // greater than the kernel size
+                    const int inp_row_size = sh * sw * osb;
+                    int max_icb = 1;
+                    if (jcp.im2col_sz) {
+                        const int col_row_size = jcp.ks * osb;
+                        if (osb >= os_per_thr) { // one pass by os
+                            const int wei_col_size = jcp.ks * ocb;
+                            max_icb = L2 / (inp_row_size + col_row_size);
+                            if (ocb < oc_per_thr) {
+                                max_icb = nstl::min(max_icb,
+                                        (L2 - block_out_size)
+                                                / (col_row_size
+                                                          + wei_col_size));
+                            }
+                        } else {
+                            const int wei_col_size = jcp.ks * oc_per_thr;
+                            max_icb = (L2 - block_out_size)
+                                    / (inp_row_size + col_row_size
+                                              + wei_col_size);
+                        }
+                    } else {
+                        if (osb >= os_per_thr)
+                            max_icb = L2 / inp_row_size;
+                        else {
+                            const int wei_col_size = jcp.ks * oc_per_thr;
+                            max_icb = L2 / (inp_row_size + wei_col_size);
+                        }
+                    }
+                    if (max_icb < jcp.ic) {
+                        if (jcp.im2col_sz) {
+                            const int col_row_size = jcp.ks * osb;
+                            const int wei_col_size = jcp.ks * oc_per_thr;
+                            max_icb = (L2 - block_out_size)
+                                    / (inp_row_size + col_row_size
+                                              + wei_col_size);
+                        }
+                    }
+                    return max_icb;
+                };
+
+                auto est_eff = [=](int nthr_oc, int ocb, int osb, int &icb,
+                                       int max_oc_per_thr, int max_os_per_thr) {
+                    // for a given nthr_oc and oc block:
+                    // 1. find an ic block that fits into cache
+                    // 2. estimate efficiency based on rules and heuristics:
+                    // - minimize im2col cost
+                    // - the ratio of FMA count to data size
+                    // - gemm works better if M is divisible by 48 and N by 8
+                    if (osb > max_os_per_thr || ocb > max_oc_per_thr)
+                        return 0.f;
+
+                    int sp_start{ 0 }, sp_end{ 0 }, oc_start{ 0 }, oc_end{ 0 };
+                    size_t max_thr_size{ 0 },
+                            min_thr_size{ (size_t)spatial * jcp.oc + 1 },
+                            max_y{ 0 }, max_oc{ 0 };
+
+                    for (int i = 0; i < max_threads; i++) {
+                        balance2D(max_threads, i, spatial, sp_start, sp_end,
+                                jcp.oc, oc_start, oc_end, nthr_oc);
+                        const size_t thr_size
+                                = (sp_end - sp_start) * (oc_end - oc_start);
+                        if (thr_size > max_thr_size) {
+                            max_y = (sp_end - sp_start);
+                            max_oc = (oc_end - oc_start);
+                            max_thr_size = thr_size;
+                        }
+                        if (thr_size < min_thr_size)
+                            min_thr_size = thr_size;
+                    }
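+                    // load balance: ratio of the smallest to the largest
+                    // per-thread chunk (1 = perfectly even decomposition)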
+                    auto thr_disb = (float)min_thr_size / max_thr_size;
+
+                    const int oc_per_thr = max_oc;
+                    const int os_per_thr = max_y;
+                    ocb = nstl::min(oc_per_thr, ocb);
+                    const int os_max = nstl::min(jcp.os, os_per_thr);
+                    osb = nstl::min(os_max, osb);
+
+                    // -- selecting icb ---------------------
+                    int max_ic_block = calc_max_icb(
+                            nthr_oc, ocb, osb, oc_per_thr, os_per_thr);
+                    // if we don't fit into cache then access to memory is
+                    // expensive
+                    int mem_access_cost
+                            = (max_ic_block < 1) ? non_cache_access : 1;
+                    max_ic_block = nstl::max(1, max_ic_block);
+                    icb = nstl::max(1, jcp.ic / div_up(jcp.ic, max_ic_block));
+                    int nb_ic = div_up(jcp.ic, icb);
+                    int kb = icb * jcp.ks;
+                    int kb_caligned = rnd_up(kb, simd_w);
+
+                    // -- mem efficiency ------------
+                    const size_t out_size
+                            = oc_per_thr * rnd_up(os_per_thr, simd_w);
+                    const size_t out_ops = mem_access_cost * out_size
+                            * ((icb == jcp.ic) ? 1 : (2 * nb_ic - 1));
+                    const int osb_caligned = rnd_up(osb, simd_w);
+                    const size_t inp_size
+                            = jcp.ic * rnd_up(os_per_thr * sh * sw, simd_w);
+                    size_t inp_ops = 0;
+                    size_t col_ops = 0;
+                    // TODO: simplify calculations
+                    if (jcp.im2col_sz) {
+                        inp_ops = mem_access_cost * jcp.ks * inp_size;
+                        const float col_tail_coeff = (float)osb_caligned / osb;
+                        // two identical terms: col is written by im2col and
+                        // read back by the gemm
+                        col_ops = mem_access_cost
+                                * (jcp.ks * inp_size * col_tail_coeff
+                                          + jcp.ks * inp_size * col_tail_coeff);
+                        if (sw != 1) // im2col with strides is much slower
+                            col_ops *= strided_im2col_k;
+                    } else {
+                        inp_ops = mem_access_cost * jcp.ks * inp_size;
+                    }
+                    // TODO: what about groups?
+                    const size_t wei_size = oc_per_thr * rnd_up(K, simd_w);
+                    const size_t wei_ops = mem_access_cost * wei_size;
+                    // ratio of real FMA to number of memory ops
+                    const float thr_mem_eff
+                            = (((float)os_per_thr / simd_w) * oc_per_thr * K)
+                            / (inp_ops + col_ops + wei_ops + out_ops);
+
+                    auto oc_disb = (float)oc_per_thr / rnd_up(oc_per_thr, ocb);
+                    auto os_disb = (float)os_max / rnd_up(os_max, osb);
+                    auto ic_disb = (float)jcp.ic / rnd_up(jcp.ic, icb);
+
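+                    // penalize osb that is not a multiple of 3 * simd_w
+                    // (assumed register blocking of the gemm kernel along M)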
+                    auto reg_osb_disb = (float)osb / rnd_up(osb, 3 * simd_w);
+
+                    // Heuristics
+                    const float gemm_eff = ((float)osb * ocb * kb)
+                            / ((float)oc_per_thr * os_per_thr * K);
+
+                    // number of FMA to memory size
+                    const float gemm_calc_eff
+                            = (((float)osb / simd_w) * ocb * kb)
+                            / (osb_caligned * kb + ocb * kb_caligned
+                                      + ocb * osb_caligned);
+
+                    const float res_eff = pow(pow(thr_disb, thr_disb_k)
+                                    * pow(oc_disb, oc_disb_k)
+                                    * pow(os_disb, os_disb_k)
+                                    * pow(ic_disb, ic_disb_k)
+                                    * pow(reg_osb_disb, reg_osb_disb_k)
+                                    * pow(thr_mem_eff, thr_mem_eff_k)
+                                    * pow(gemm_eff, gemm_eff_k)
+                                    * pow(gemm_calc_eff, gemm_calc_eff_k),
+                            1.f / k_sum);
+                    return res_eff;
+                };
+
+                /* find the thread distribution and blocking with the highest
+                 * efficiency */
+                int best_nthr_oc{ 1 }, best_ocb{ jcp.oc }, best_osb{ jcp.os },
+                        best_icb{ jcp.ic };
+                float best_thr_eff = est_eff(best_nthr_oc, best_ocb, best_osb,
+                        best_icb, jcp.oc, jcp.os);
+
+                int icb{ best_icb };
+                const int nthr_oc_max = max_threads;
+                for (int nthr_oc = 1; nthr_oc <= nthr_oc_max; ++nthr_oc) {
+                    const int max_oc_per_thr = div_up(jcp.oc, nthr_oc);
+                    const int min_oc_per_thr
+                            = nstl::min(min_oc_block, max_oc_per_thr);
+                    const int max_os_per_thr = nstl::min(jcp.os,
+                            div_up(spatial,
+                                    nstl::max(1, max_threads / nthr_oc)));
+                    const int min_os_per_thr
+                            = nstl::min(min_os_block, max_os_per_thr);
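+                    // ocb/osb advance by min_oc_block / min_os_block steps;
+                    // the last step is clipped so that the per-thread maximum
+                    // is tried as well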
+                    for (int ocb = min_oc_per_thr; ocb <= max_oc_per_thr;
+                            ocb += nstl::max(1,
+                                    nstl::min(min_oc_block,
+                                            max_oc_per_thr - ocb))) {
+                        for (int osb = min_os_per_thr; osb <= jcp.os;
+                                osb += nstl::max(1,
+                                        nstl::min(min_os_block,
+                                                max_os_per_thr - osb))) {
+                            float thr_eff = est_eff(nthr_oc, ocb, osb, icb,
+                                    max_oc_per_thr, max_os_per_thr);
+                            if (thr_eff > best_thr_eff) {
+                                best_thr_eff = thr_eff;
+                                best_nthr_oc = nthr_oc;
+                                best_ocb = ocb;
+                                best_osb = osb;
+                                best_icb = icb;
+                            }
+                        }
+                    }
+                }
+
+                jcp.outer_threading = true;
+                jcp.nthr_oc = best_nthr_oc;
+                jcp.oc_block = best_ocb;
+                jcp.os_block = best_osb;
+                jcp.ic_block = best_icb;
+                // TODO: define loop order
+                // with im2col, gemm_loop_rlb and gemm_loop_lrb look
+                // preferable; otherwise other loop orders are possible
+                jcp.loop_order = gemm_loop_rlb;
+            } else {
                 const size_t outer_work_amount = jcp.ngroups * jcp.mb * jcp.od;
                 const float outer_thr_eff = (float)outer_work_amount
                         / rnd_up(outer_work_amount, max_threads);
@@ -702,38 +999,43 @@ status_t init_conf(jit_gemm_conv_conf_t &jcp,
                 const float inner_thr_eff = (float)inner_work_amount
                         / rnd_up(inner_work_amount, max_threads);
                 jcp.outer_threading = jcp.os / max_threads < 512
-                    && IMPLICATION(jcp.od == 1, jcp.mb != 1 || jcp.ngroups > 2)
-                    && (outer_thr_eff / inner_thr_eff >= 1.f
-                      || (jcp.os * jcp.ic * jcp.oc) / max_threads < gemm_thrld);
+                        && IMPLICATION(
+                                   jcp.od == 1, jcp.mb != 1 || jcp.ngroups > 2)
+                        && (outer_thr_eff / inner_thr_eff >= 1.f
+                                   || (jcp.os * jcp.ic * jcp.oc) / max_threads
+                                           < gemm_thrld);
             }
+            if (jcp.im2col_sz)
+                jcp.im2col_sz = (ptrdiff_t)jcp.ic_block * jcp.ks * jcp.os_block;
         } else if (is_bwd_d) {
             const size_t outer_work_amount = jcp.ngroups * jcp.mb;
             const float outer_thr_eff = (float)outer_work_amount
-                / rnd_up(outer_work_amount, max_threads);
+                    / rnd_up(outer_work_amount, max_threads);
             const size_t inner_work
-                = div_up(jcp.is, simd_w) * div_up(jcp.ic, simd_w);
-            const float inner_thr_eff = (float)inner_work
-                / rnd_up(inner_work, max_threads);
+                    = div_up(jcp.is, simd_w) * div_up(jcp.ic, simd_w);
+            const float inner_thr_eff
+                    = (float)inner_work / rnd_up(inner_work, max_threads);
             jcp.outer_threading = (jcp.os / max_threads < 512 || jcp.ks < 64)
-                && (jcp.mb != 1 || jcp.ngroups > 2)
-                && (outer_thr_eff / inner_thr_eff >= 1.f
-                    || (jcp.is * jcp.ic * jcp.oc) / max_threads < gemm_thrld);
+                    && (jcp.mb != 1 || jcp.ngroups > 2)
+                    && (outer_thr_eff / inner_thr_eff >= 1.f
+                               || (jcp.is * jcp.ic * jcp.oc) / max_threads
+                                       < gemm_thrld);
         } else if (is_bwd_w)
             jcp.outer_threading = jcp.os / max_threads < 256
-                && (jcp.mb != 1 || jcp.ngroups > 2);
+                    && (jcp.mb != 1 || jcp.ngroups > 2);
 
         jcp.nthr = jcp.outer_threading ? max_threads : 1;
-        scratchpad.book(key_conv_gemm_col,
-                sizeof(float) * jcp.nthr * jcp.im2col_sz);
+        scratchpad.book(
+                key_conv_gemm_col, sizeof(float) * jcp.nthr * jcp.im2col_sz);
 
         if (is_bwd_w) {
             jcp.need_wei_reduction = mkldnn_thr_syncable()
-                ? jcp.mb != 1 && jcp.nthr != 1 : false;
+                    ? jcp.mb != 1 && jcp.nthr != 1
+                    : false;
             scratchpad.book(key_conv_wei_reduction,
                     sizeof(float) * jcp.nthr * jcp.ngroups * weights_d.size());
         }
     }
-
     return status::success;
 }
 
diff --git a/src/cpu/gemm_convolution_utils.hpp b/src/cpu/gemm_convolution_utils.hpp
index e006789344e..ff4165b98ab 100644
--- a/src/cpu/gemm_convolution_utils.hpp
+++ b/src/cpu/gemm_convolution_utils.hpp
@@ -34,7 +34,7 @@ namespace jit_gemm_convolution_utils {
 void im2col_3d(const jit_gemm_conv_conf_t &jcp, const float *im, float *col,
         int od);
 void im2col(const jit_gemm_conv_conf_t &jcp, const float *__restrict im,
-       float *__restrict col, int hs, int hb, int ws, int wb);
+       float *__restrict col, int ss, int sb, int cs, int cb);
 template <typename T>
 void im2col_u8(const jit_gemm_conv_conf_t &jcp, const T *__restrict im,
         T* __restrict imtr, uint8_t *__restrict col,
diff --git a/src/cpu/jit_primitive_conf.hpp b/src/cpu/jit_primitive_conf.hpp
index 5c011cdb0cc..7b394a94f28 100644
--- a/src/cpu/jit_primitive_conf.hpp
+++ b/src/cpu/jit_primitive_conf.hpp
@@ -31,6 +31,8 @@ enum conv_loop_order_t {loop_cgn, loop_gnc, loop_ngc, loop_gncw, loop_cwgn,
                             loop_ngcw, loop_nhwcg, loop_nwcg};
 enum conv_1x1_loop_order_t {loop_rbl, loop_rlb, loop_lbr, loop_lrb, loop_blr,
                             loop_brl};
+enum conv_gemm_loop_order_t { gemm_loop_rlb, gemm_loop_lrb };
+
 enum conv_kernel_kind_t {embd_bcast, expl_bcast};
 enum conv_harness_t {harness_2d_reduction, harness_3d_reduction,
                      harness_mb_reduction};
@@ -416,7 +418,10 @@ struct jit_gemm_conv_conf_t {
     bool signed_input;
     int oh_block;
     int ow_block;
+    int os_block;
     bool outer_threading;
+    conv_gemm_loop_order_t loop_order;
+    int nthr_oc;
 };
 
 struct jit_1x1_conv_call_s {