Skip to content

Commit

Permalink
cpu: conv: gemm f32 forward: oc/os/ic blocking
Browse files Browse the repository at this point in the history
  • Loading branch information
ankalinin committed May 7, 2019
1 parent dc5a566 commit 75cc356
Show file tree
Hide file tree
Showing 5 changed files with 663 additions and 298 deletions.
3 changes: 1 addition & 2 deletions src/common/mkldnn_thread.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,7 @@ void balance2D(U nthr, U ithr, T ny, T &ny_start, T &ny_end,
grp = ithr / grp_size_big;
grp_ithr = ithr % grp_size_big;
grp_nthr = grp_size_big;
}
else { // ithr in last groups
} else { // ithr in last groups
grp = n_grp_big + ithr_bound_distance / grp_size_small;
grp_ithr = ithr_bound_distance % grp_size_small;
grp_nthr = grp_size_small;
Expand Down
197 changes: 128 additions & 69 deletions src/cpu/gemm_convolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,18 @@ using namespace mkldnn::impl::status;
using namespace mkldnn::impl::memory_tracking::names;
using namespace mkldnn::impl::utils;

namespace {
// Set of integer coordinates (n, g, od, sp, ic, oc) identifying one position
// in the blocked iteration space. Every coordinate starts at zero.
struct im_pos_t {
    int n, g, od, sp, ic, oc;

    im_pos_t() : n(0), g(0), od(0), sp(0), ic(0), oc(0) {}

    // Reports whether this position differs from `prev` in any coordinate
    // other than `oc`. A change in `oc` alone returns false, so the caller
    // can use the result to decide whether im2col has to be redone.
    bool do_im2col(const im_pos_t &prev) const {
        const bool same_except_oc = n == prev.n && g == prev.g
                && od == prev.od && sp == prev.sp && ic == prev.ic;
        return !same_except_oc;
    }
};
} // namespace

void gemm_convolution_fwd_t::execute_forward(const exec_ctx_t &ctx) const {
auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS);
Expand All @@ -41,97 +53,144 @@ void gemm_convolution_fwd_t::execute_forward(const exec_ctx_t &ctx) const {

const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_;

const int M = jcp.os * jcp.od;
const size_t src_step = jcp.ic * jcp.ih * jcp.iw * jcp.id;
const size_t dst_step = jcp.oc * M;
const size_t weights_g_size = jcp.ic * jcp.oc * jcp.ks;
const size_t weights_oc_size = jcp.ic * jcp.ks;
const size_t weights_g_size = weights_oc_size * jcp.oc;

assert(IMPLICATION(
jcp.id != 1, jcp.oh_block == jcp.oh && jcp.ow_block == jcp.ow));
assert(IMPLICATION(jcp.ow_block != jcp.ow, jcp.oh_block == 1));

const int K = jcp.ic * jcp.ks;
const int N = jcp.oc;
jcp.id != 1, jcp.os_block == jcp.os && jcp.ic_block == jcp.ic));

if (jcp.im2col_sz && jcp.id != 1)
parallel_nd(jcp.im2col_sz * jcp.nthr,
[&](ptrdiff_t i) { col[i] = (data_t)0; });

const int nb_oh = div_up(jcp.oh, jcp.oh_block);
const int nb_ow = div_up(jcp.ow, jcp.ow_block);
const size_t work_amount = jcp.ngroups * jcp.mb * jcp.od * nb_oh * nb_ow;
parallel(jcp.nthr, [&](const int ithr, const int nthr) {
data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz;

int g{ 0 }, n{ 0 }, od{ 0 }, ohb{ 0 }, owb{ 0 };
size_t start = 0, end = 0;

balance211(work_amount, nthr, ithr, start, end);
nd_iterator_init(start, g, jcp.ngroups, n, jcp.mb, od, jcp.od, ohb,
nb_oh, owb, nb_ow);
for (size_t iwork = start; iwork < end; ++iwork) {
int oh = ohb * jcp.oh_block;
int ow = owb * jcp.ow_block;
const data_t *_src = src + (n * jcp.ngroups + g) * src_step;
const data_t *_weights = weights + g * weights_g_size;
data_t *_dst_im = dst + (n * jcp.ngroups + g) * dst_step;
const int h_step = nstl::min(jcp.oh_block, jcp.oh - oh);
const int w_step = nstl::min(jcp.ow_block, jcp.ow - ow);
if (jcp.im2col_sz) {
auto inner_ker = [&](int spatial, const im_pos_t &curr, im_pos_t &prev,
im_pos_t &step, const im_pos_t &end) {
const data_t *_src
= src + (curr.n * jcp.ngroups + curr.g) * src_step;
step.oc = nstl::min(
jcp.oc_block, nstl::min(jcp.oc, end.oc) - curr.oc);
step.sp = nstl::min(jcp.os_block,
nstl::min(jcp.os - curr.sp, end.sp - spatial));
step.ic = nstl::min(
jcp.ic_block, nstl::min(jcp.ic, end.ic) - curr.ic);
bool do_im2col = curr.do_im2col(prev);
prev = curr;

if (jcp.im2col_sz && do_im2col) {
if (jcp.id == 1)
jit_gemm_convolution_utils::im2col(
jcp, _src, _col, oh, h_step, ow, w_step);
jit_gemm_convolution_utils::im2col(jcp, _src, _col, curr.sp,
step.sp, curr.ic, step.ic);
else
jit_gemm_convolution_utils::im2col_3d(jcp, _src, _col, od);
jit_gemm_convolution_utils::im2col_3d(
jcp, _src, _col, curr.od);
}

const data_t one = 1.0;

const int m = h_step * w_step;
const int M = jcp.os * jcp.od;
const size_t dst_step = jcp.oc * M;
const int m = step.sp;
const int LDA = jcp.im2col_sz ? m : M;
data_t *_dst = _dst_im + od * jcp.os + oh * jcp.ow + ow;

extended_sgemm("N", "N", &m, &N, &K, &one,
jcp.im2col_sz ? _col : _src + od * m, &LDA, _weights, &K,
&this->beta_, _dst, &M);

data_t *d = _dst;
if (eltwise_) {
// fast branch for ReLU case
if (eltwise_->alg_ == alg_kind::eltwise_relu) {
parallel_nd(jcp.oc, [&](const int oc) {
data_t b = jcp.with_bias ? bias[g * jcp.oc + oc] : 0;
data_t *d_ = d + oc * M;
PRAGMA_OMP_SIMD()
for (int oS = 0; oS < m; ++oS) {
d_[oS] += b;
if (d_[oS] < 0) d_[oS] *= eltwise_->alpha_;
}
});
} else {
parallel_nd(jcp.oc, [&](const int oc) {
data_t b = jcp.with_bias ? bias[g * jcp.oc + oc] : 0;
data_t *d_ = d + oc * M;
data_t *_dst = dst + (curr.n * jcp.ngroups + curr.g) * dst_step
+ curr.oc * M + curr.od * jcp.os + curr.sp;
const int K = step.ic * jcp.ks;
const int LDB = jcp.ic * jcp.ks;
const int N = step.oc;

// TODO: what if this->beta_ != 0 && != 1 ?
const float beta = (curr.ic == 0) ? this->beta_ : one;
const float *_source = jcp.im2col_sz
? _col
: _src + curr.ic * M + curr.od * jcp.os + curr.sp;
const data_t *_weights = weights + curr.g * weights_g_size
+ curr.oc * weights_oc_size + curr.ic * jcp.ks;

extended_sgemm("N", "N", &m, &N, &K, &one, _source, &LDA, _weights,
&LDB, &beta, _dst, &M);
if (curr.ic == jcp.ic - step.ic) {
const int oc_start = curr.g * jcp.oc + curr.oc;
if (eltwise_) {
// fast branch for ReLU case
if (eltwise_->alg_ == alg_kind::eltwise_relu) {
parallel_nd(step.oc, [&](const int oc) {
data_t b = jcp.with_bias ? bias[oc_start + oc] : 0;
data_t *d_ = _dst + oc * M;
PRAGMA_OMP_SIMD()
for (int oS = 0; oS < m; ++oS) {
d_[oS] += b;
if (d_[oS] < 0)
d_[oS] *= eltwise_->alpha_;
}
});
} else {
parallel_nd(step.oc, [&](const int oc) {
data_t b = jcp.with_bias ? bias[oc_start + oc] : 0;
data_t *d_ = _dst + oc * M;
PRAGMA_OMP_SIMD()
for (int oS = 0; oS < m; ++oS) {
d_[oS] += b;
d_[oS] = eltwise_->compute_scalar(d_[oS]);
}
});
}
} else if (jcp.with_bias) {
parallel_nd(step.oc, [&](const int oc) {
data_t b = bias[oc_start + oc];
data_t *d_ = _dst + oc * M;
PRAGMA_OMP_SIMD()
for (int oS = 0; oS < m; ++oS) {
d_[oS] += b;
d_[oS] = eltwise_->compute_scalar(d_[oS]);
}
});
}
} else if (jcp.with_bias) {
parallel_nd(jcp.oc, [&](const int oc) {
data_t b = bias[g * jcp.oc + oc];
data_t *d_ = d + oc * M;
PRAGMA_OMP_SIMD()
for (int oS = 0; oS < m; ++oS) {
d_[oS] += b;
}
});
}
nd_iterator_step(g, jcp.ngroups, n, jcp.mb, od, jcp.od, ohb, nb_oh,
owb, nb_ow);
};
im_pos_t start, end;
end.ic = jcp.ic;

if (jcp.id == 1) {
const int sp_work = jcp.mb * jcp.ngroups * jcp.od * jcp.os;
balance2D(nthr, ithr, sp_work, start.sp, end.sp, jcp.oc, start.oc,
end.oc, jcp.nthr_oc);
} else {
const int sp_work = jcp.mb * jcp.ngroups * jcp.od;
balance2D(nthr, ithr, sp_work, start.sp, end.sp, jcp.oc, start.oc,
end.oc, jcp.nthr_oc);
start.sp *= jcp.os;
end.sp *= jcp.os;
}

im_pos_t curr, prev, step;
prev.n = prev.g = prev.od = prev.sp = prev.ic = -1;
step.oc = jcp.oc_block;
step.sp = jcp.os_block;
step.ic = jcp.ic_block;

if (jcp.loop_order == gemm_loop_rlb)
for (curr.ic = 0; curr.ic < jcp.ic; curr.ic += step.ic)
for (int spatial = start.sp; spatial < end.sp;
spatial += step.sp) {
nd_iterator_init(spatial, curr.n, jcp.mb, curr.g,
jcp.ngroups, curr.od, jcp.od, curr.sp, jcp.os);
for (curr.oc = start.oc; curr.oc < end.oc;
curr.oc += step.oc) {
inner_ker(spatial, curr, prev, step, end);
}
}
else if (jcp.loop_order == gemm_loop_lrb)
for (int spatial = start.sp; spatial < end.sp; spatial += step.sp) {
nd_iterator_init(spatial, curr.n, jcp.mb, curr.g, jcp.ngroups,
curr.od, jcp.od, curr.sp, jcp.os);
for (curr.ic = 0; curr.ic < jcp.ic; curr.ic += step.ic)
for (curr.oc = start.oc; curr.oc < end.oc;
curr.oc += step.oc)
inner_ker(spatial, curr, prev, step, end);
}
else
assert("Unknown loop order");
});
}

Expand Down Expand Up @@ -256,10 +315,10 @@ void gemm_convolution_bwd_weights_t::execute_backward_weights(
if (jcp.im2col_sz) {
if (jcp.id == 1)
jit_gemm_convolution_utils::im2col(
jcp, _src, _col, 0, jcp.oh, 0, jcp.ow);
jcp, _src, _col, 0, jcp.os, 0, jcp.ic);
else
jit_gemm_convolution_utils::im2col_3d(jcp, _src,
_col, od);
jit_gemm_convolution_utils::im2col_3d(
jcp, _src, _col, od);
}

const data_t zero = 0.0, one = 1.0;
Expand Down
Loading

0 comments on commit 75cc356

Please sign in to comment.