perf: implement conv1d with im2col and GEMM (#1597)
ebraraktas authored Feb 9, 2024
1 parent 4c7b956 commit ce47032
Showing 2 changed files with 81 additions and 71 deletions.
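Note: the change below replaces a hand-rolled sliding-window kernel with an im2col unfolding followed by a single matrix multiplication per batch item. A minimal standalone sketch of why the two are equivalent (my own illustration, not the committed code; it assumes stride 1, no padding, no bias, and one batch item, and a plain triple loop stands in for the Gemm op):

    // im2col_demo.cc -- sketch of the conv1d == im2col + GEMM identity.
    #include <cassert>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Naive 1D convolution (cross-correlation), channels-first layout.
    static std::vector<float> conv1d_naive(const std::vector<float>& x, // in_ch x len
                                           const std::vector<float>& w, // out_ch x in_ch x k
                                           int in_ch, int len, int out_ch, int k) {
      const int out_len = len - k + 1;
      std::vector<float> y(out_ch * out_len, 0.f);
      for (int co = 0; co < out_ch; ++co)
        for (int t = 0; t < out_len; ++t)
          for (int ci = 0; ci < in_ch; ++ci)
            for (int j = 0; j < k; ++j)
              y[co * out_len + t] += w[(co * in_ch + ci) * k + j] * x[ci * len + t + j];
      return y;
    }

    // Unfold x into an (in_ch * k) x out_len matrix: row (ci * k + j) holds tap j
    // of channel ci for every output position. The convolution then collapses to
    // one (out_ch x in_ch*k) * (in_ch*k x out_len) matrix product.
    static std::vector<float> im2col(const std::vector<float>& x, int in_ch, int len, int k) {
      const int out_len = len - k + 1;
      std::vector<float> cols(in_ch * k * out_len);
      for (int ci = 0; ci < in_ch; ++ci)
        for (int j = 0; j < k; ++j)
          for (int t = 0; t < out_len; ++t)
            cols[(ci * k + j) * out_len + t] = x[ci * len + t + j];
      return cols;
    }

    int main() {
      const int in_ch = 2, len = 5, out_ch = 3, k = 3, out_len = len - k + 1;
      std::vector<float> x(in_ch * len), w(out_ch * in_ch * k);
      for (size_t i = 0; i < x.size(); ++i) x[i] = 0.1f * i;
      for (size_t i = 0; i < w.size(); ++i) w[i] = 0.05f * i;

      const std::vector<float> ref = conv1d_naive(x, w, in_ch, len, out_ch, k);
      const std::vector<float> cols = im2col(x, in_ch, len, k);

      std::vector<float> y(out_ch * out_len, 0.f); // plain loop as the GEMM stand-in
      for (int m = 0; m < out_ch; ++m)
        for (int n = 0; n < out_len; ++n)
          for (int kk = 0; kk < in_ch * k; ++kk)
            y[m * out_len + n] += w[m * (in_ch * k) + kk] * cols[kk * out_len + n];

      for (size_t i = 0; i < y.size(); ++i)
        assert(std::fabs(y[i] - ref[i]) < 1e-4f); // equal up to rounding
      std::printf("im2col + GEMM matches the naive convolution\n");
    }

The sketch uses the same row layout as the committed im2col below (rows indexed channel-major, kernel tap minor), which is why the weight tensor can be viewed as an out_channels x (in_channels * kernel_size) matrix without copying.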
4 changes: 4 additions & 0 deletions include/ctranslate2/ops/conv1d.h
@@ -34,6 +34,10 @@ namespace ctranslate2 {
const StorageView& weight,
const StorageView* bias,
StorageView& output) const;

+      void compute_with_gemm(const StorageView& input, const StorageView& weight, StorageView& output) const;
+
+      void im2col(const StorageView& input, StorageView& output, dim_t kernel_size) const;
};

}
148 changes: 77 additions & 71 deletions src/ops/conv1d_cpu.cc
@@ -123,61 +123,12 @@ namespace ctranslate2 {
# define CT2_NO_BLAS
# endif

# include "ctranslate2/ops/transpose.h"
# include "ctranslate2/ops/gemm.h"
# include "cpu/parallel.h"

namespace ctranslate2 {
namespace ops {

-    static void conv1d_kernel(const float* input,
-                              const float* weight,
-                              const float* bias,
-                              float* output,
-                              dim_t batch_size,
-                              dim_t input_length,
-                              dim_t output_length,
-                              dim_t in_channels,
-                              dim_t out_channels,
-                              dim_t kernel_size,
-                              dim_t stride,
-                              dim_t padding) {
-      cpu::parallel_for(0, batch_size * out_channels, 1, [&](dim_t begin, dim_t end) {
-        for (dim_t i = begin; i < end; ++i) {
-          const dim_t b = i / out_channels;
-          const dim_t c_out = i % out_channels;
-
-          const float* filter = weight + (c_out * in_channels * kernel_size);
-          const float* x = input + b * (in_channels * input_length);
-          float* y = output + b * (out_channels * output_length);
-
-          for (dim_t t_out = 0; t_out < output_length; ++t_out) {
-            const dim_t t_in = t_out * stride - padding;
-
-            const dim_t window_offset = std::clamp(t_in, dim_t(0), input_length);
-            const dim_t window_end = std::clamp(t_in + kernel_size, dim_t(0), input_length);
-            const dim_t window_size = window_end - window_offset;
-            const dim_t filter_offset = window_offset - t_in;
-
-            const float* window = x + (window_offset * in_channels);
-            const float* kernel = filter + (filter_offset * in_channels);
-
-#ifdef CT2_NO_BLAS
-            float value = 0;
-            for (dim_t j = 0; j < window_size * in_channels; ++j)
-              value += window[j] * kernel[j];
-#else
-            float value = cblas_sdot(window_size * in_channels, window, 1, kernel, 1);
-#endif
-
-            if (bias)
-              value += bias[c_out];
-
-            y[c_out * output_length + t_out] = value;
-          }
-        }
-      });
-    }
-
template<>
void Conv1D::compute<Device::CPU, float>(const StorageView& input,
const StorageView& weight,
@@ -186,34 +137,89 @@ namespace ctranslate2 {
if (_dilation != 1)
throw std::runtime_error("Dilation is not supported in this Conv1D implementation");

+      compute_with_gemm(input, weight, output);
+      // Add bias
+      if (bias) {
+        // Need to broadcast along dims 0 and 2, because output shape is:
+        // batch_size, out_channels, output_length
+        const auto batch_size = output.dim(0);
+        const auto out_channels = output.dim(1);
+        const auto output_length = output.dim(2);
+        const auto a = bias->data<float>();
+        const auto b = output.data<float>();
+        cpu::parallel_for(0, batch_size * out_channels, 1, [&](dim_t begin, dim_t end){
+          for (dim_t i = begin; i < end; ++i) {
+            // Add bias element a_i to output_length elements at once,
+            // indexing `a` modulo out_channels so the bias repeats for
+            // every item in the batch.
+            const auto a_i = a[i % out_channels];
+            const auto b_i = b + i * output_length;
+            primitives<>::add(a_i, b_i, b_i, output_length);
+          }
+        });
+      }
+    }

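Note: the bias addition above flattens (batch_size, out_channels) into one parallel range and reuses the bias across the batch via the modulo. A serial sketch of the same broadcast (assumed shapes; std::vector and plain loops stand in for StorageView, cpu::parallel_for, and primitives<>::add; add_bias is an illustrative name):

    // bias_broadcast_demo.cc -- plain-loop sketch of the bias broadcast.
    #include <cstdio>
    #include <vector>

    static void add_bias(const std::vector<float>& bias, std::vector<float>& out,
                         int batch_size, int out_channels, int output_length) {
      // out is batch_size x out_channels x output_length; row i of the flattened
      // (batch_size * out_channels) x output_length view gets bias[i % out_channels].
      for (int i = 0; i < batch_size * out_channels; ++i) {
        const float a_i = bias[i % out_channels];
        float* b_i = out.data() + i * output_length;
        for (int t = 0; t < output_length; ++t)
          b_i[t] += a_i;
      }
    }

    int main() {
      const int batch_size = 2, out_channels = 3, output_length = 4;
      std::vector<float> out(batch_size * out_channels * output_length, 1.f);
      const std::vector<float> bias = {10.f, 20.f, 30.f};
      add_bias(bias, out, batch_size, out_channels, output_length);
      // Every row of channel c, in every batch item, is now 1 + bias[c].
      std::printf("out[0]=%g out[%d]=%g\n", out[0],
                  out_channels * output_length, out[out_channels * output_length]);
    }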
+    void Conv1D::compute_with_gemm(const StorageView& input,
+                                   const StorageView& weight,
+                                   StorageView& output) const {
      const dim_t batch_size = input.dim(0);
      const dim_t in_channels = input.dim(1);
-      const dim_t input_length = input.dim(2);
-      const dim_t output_length = output.dim(2);
      const dim_t out_channels = weight.dim(0);
      const dim_t kernel_size = weight.dim(2);
+      const dim_t output_length = output.dim(2);

-      // Transpose input and weight to apply the kernel with a single contiguous dot.
-      const Transpose transpose_op({0, 2, 1});
-      StorageView input_t;
-      StorageView weight_t;
-      transpose_op(input, input_t);
-      transpose_op(weight, weight_t);
-
-      conv1d_kernel(input_t.data<float>(),
-                    weight_t.data<float>(),
-                    bias ? bias->data<float>() : nullptr,
-                    output.data<float>(),
-                    batch_size,
-                    input_length,
-                    output_length,
-                    in_channels,
-                    out_channels,
-                    kernel_size,
-                    _stride,
-                    _padding);
+      std::vector im2col_output_shape{batch_size, in_channels * kernel_size, output_length};
+      StorageView im2col_output(std::move(im2col_output_shape), static_cast<float>(0.0), Device::CPU);
+      im2col(input, im2col_output, kernel_size);
+      // Create a 2D view of weight to use in GEMM
+      const StorageView weight_view({weight.dim(0), in_channels * kernel_size}, const_cast<float*>(weight.data<float>()));
+
+      const dim_t m = out_channels;
+      const dim_t n = output_length;
+      const dim_t k = im2col_output.dim(1);
+      const dim_t strideb = k * output_length;
+      const dim_t stridec = out_channels * output_length;
+      auto* b = im2col_output.data<float>();
+      auto* c = output.data<float>();
+      const Gemm gemm(1.0, 0.0, false, false);
+      cpu::parallel_for(0, batch_size, 1, [&](dim_t begin, dim_t end) {
+        for (dim_t i = begin; i < end; ++i) {
+          float* b_i = b + (i * strideb);
+          float* c_i = c + (i * stridec);
+          StorageView cc({m, n}, c_i);
+          StorageView bb({k, n}, b_i);
+          gemm(weight_view, bb, cc);
+        }
+      });
+    }

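Note: for each batch item, the GEMM above multiplies the shared m x k weight view (m = out_channels, k = in_channels * kernel_size) by that item's k x n im2col slice (n = output_length) to produce its m x n output slice; strideb and stridec step between those slices. A serial reference of that strided, batched product (a sketch with raw pointers; batched_gemm is an illustrative name, the real code dispatches to the Gemm op):

    // batched_gemm_demo.cc -- serial reference for the per-batch GEMM.
    #include <cstdio>
    #include <vector>

    // a: m x k (weights, shared across the batch)
    // b: batch x k x n (im2col buffer), c: batch x m x n (conv output)
    static void batched_gemm(const float* a, const float* b, float* c,
                             int batch, int m, int n, int k) {
      const int strideb = k * n;  // = k * output_length above
      const int stridec = m * n;  // = out_channels * output_length above
      for (int i = 0; i < batch; ++i) {
        const float* b_i = b + i * strideb;
        float* c_i = c + i * stridec;
        for (int r = 0; r < m; ++r)
          for (int col = 0; col < n; ++col) {
            float acc = 0.f;
            for (int kk = 0; kk < k; ++kk)
              acc += a[r * k + kk] * b_i[kk * n + col];
            c_i[r * n + col] = acc;
          }
      }
    }

    int main() {
      const int batch = 2, m = 2, n = 3, k = 4;
      std::vector<float> a(m * k, 1.f), b(batch * k * n, 2.f), c(batch * m * n);
      batched_gemm(a.data(), b.data(), c.data(), batch, m, n, k);
      std::printf("c[0] = %g (expected %g)\n", c[0], 1.f * 2.f * k);
    }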
+    void Conv1D::im2col(const StorageView& input, StorageView& output, const dim_t kernel_size) const {
+      // input: batch_size x in_channels x input_length
+      // output: batch_size x (in_channels * kernel_size) x output_length
+      const dim_t batch_size = input.dim(0);
+      const dim_t in_channels = input.dim(1);
+      const dim_t input_length = input.dim(2);
+      auto* out = output.data<float>();
+      const auto* in = input.data<float>();
+      dim_t input_channel_offset = 0;
+      dim_t out_index = 0;
+      for (int i = 0; i < batch_size; i++) {
+        for (int c = 0; c < in_channels; c++) {
+          // For each input channel fill (kernel_size * output_length) items in output array
+          for (int k = 0; k < kernel_size; k++) {
+            for (dim_t ti = -_padding; ti <= (input_length - kernel_size + _padding); ti += _stride) {
+              // Fill items in [0, input_length) range
+              const auto window_i = k + ti;
+              if (0 <= window_i && window_i < input_length) {
+                out[out_index] = in[window_i + input_channel_offset];
+              }
+              out_index += 1;
+            }
+          }
+          input_channel_offset += input_length;
+        }
+      }
+    }
}
}
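Note: a worked example of the im2col loop above (my own illustration) with one batch item, one channel, input_length = 4, kernel_size = 2, stride = 1, padding = 1, so output_length = (4 + 2*1 - 2)/1 + 1 = 5. Taps that land in the padding keep the zero that im2col_output was initialized with:

    // im2col_padding_demo.cc -- worked example of the im2col index walk.
    #include <cstdio>
    #include <vector>

    int main() {
      const int input_length = 4, kernel_size = 2, stride = 1, padding = 1;
      const float in[input_length] = {1.f, 2.f, 3.f, 4.f};
      const int output_length = (input_length + 2 * padding - kernel_size) / stride + 1;

      std::vector<float> out(kernel_size * output_length, 0.f); // zero-filled, like im2col_output
      int out_index = 0;
      for (int k = 0; k < kernel_size; ++k)
        for (int ti = -padding; ti <= input_length - kernel_size + padding; ti += stride) {
          const int window_i = k + ti;
          if (0 <= window_i && window_i < input_length)
            out[out_index] = in[window_i];
          ++out_index;
        }

      // Row k holds tap k of every output window:
      //   row 0: 0 1 2 3 4   (first entry falls in the left padding)
      //   row 1: 1 2 3 4 0   (last entry falls in the right padding)
      for (int k = 0; k < kernel_size; ++k) {
        for (int t = 0; t < output_length; ++t)
          std::printf("%g ", out[k * output_length + t]);
        std::printf("\n");
      }
    }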

