[XLA:CPU] Extend the custom algorithm for transposed convolutions
This commit adds support for cases with multiple input and multiple output channels at the same time.
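
For reference, the newly supported case is a 1D transposed convolution with both in_channels > 1 and out_channels > 1, expressed the way XLA sees it: a regular convolution whose LHS is dilated (zeros inserted between input elements) and padded, with no kernel rotation. The sketch below is a hypothetical, standalone C++ reference of that computation (illustrative names and loop nest only, not the XLA or Eigen implementation):

#include <cstdint>
#include <vector>

// Reference 1D transposed convolution with dim_labels bf0_io0->bf0:
// input[batch][in_ch][in_w], kernel[in_ch][out_ch][k_w],
// output[batch][out_ch][out_w]. lhs_dilation and pad correspond to the
// lhs_dilate and pad window attributes in the benchmark HLO further below.
std::vector<float> ReferenceTransposedConv1D(
    const std::vector<float>& input, const std::vector<float>& kernel,
    int64_t batch, int64_t in_ch, int64_t in_w, int64_t out_ch, int64_t k_w,
    int64_t lhs_dilation, int64_t pad) {
  const int64_t dilated_w = (in_w - 1) * lhs_dilation + 1;
  const int64_t out_w = dilated_w + 2 * pad - k_w + 1;
  std::vector<float> output(batch * out_ch * out_w, 0.0f);
  for (int64_t b = 0; b < batch; ++b) {
    for (int64_t o = 0; o < out_ch; ++o) {
      for (int64_t x = 0; x < out_w; ++x) {
        float acc = 0.0f;
        for (int64_t i = 0; i < in_ch; ++i) {
          for (int64_t k = 0; k < k_w; ++k) {
            // Position in the conceptually dilated-and-padded input.
            const int64_t dilated_pos = x + k - pad;
            // Only positions that land on an original input element (not on
            // an inserted zero or in the padding) contribute.
            if (dilated_pos < 0 || dilated_pos >= dilated_w ||
                dilated_pos % lhs_dilation != 0) {
              continue;
            }
            const int64_t in_pos = dilated_pos / lhs_dilation;
            acc += input[(b * in_ch + i) * in_w + in_pos] *
                   kernel[(i * out_ch + o) * k_w + k];
          }
        }
        output[(b * out_ch + o) * out_w + x] = acc;
      }
    }
  }
  return output;
}

With the benchmark shapes below (in_w = 400, k_w = 256, lhs_dilation = 64, pad = 159), the output width comes out to (400 - 1) * 64 + 1 + 2 * 159 - 256 + 1 = 25600, matching the 25600-wide outputs in the benchmark HLO.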

Performance of the already supported cases is not impacted, and the new cases show the expected performance improvement. Results:

name                                                           old cpu/op   new cpu/op   delta
BM_Conv1DTransposedStrided/129/1/process_time                  34.0ms ±15%  34.7ms ±17%     ~     (p=0.548 n=5+5)
BM_Conv1DTransposedStrided/129/3/process_time                   15.4s ±21%    0.1s ±13%  -99.52%  (p=0.008 n=5+5)
BM_Conv1DTransposedStridedNonDefaultLayout/129/1/process_time  32.5ms ±15%  32.4ms ±17%     ~     (p=1.000 n=5+5)
BM_Conv1DTransposedStridedNonDefaultLayout/129/3/process_time   16.2s ±18%    0.1s ±14%  -99.55%  (p=0.008 n=5+5)
BM_Conv2DTransposedStrided/process_time                        36.1ms ±16%  34.9ms ±19%     ~     (p=0.841 n=5+5)

name                                                           old time/op  new time/op  delta
BM_Conv1DTransposedStrided/129/1/process_time                  9.58ms ±22%  9.56ms ±21%     ~     (p=1.000 n=5+5)
BM_Conv1DTransposedStrided/129/3/process_time                   732ms ±26%    15ms ±19%  -97.91%  (p=0.008 n=5+5)
BM_Conv1DTransposedStridedNonDefaultLayout/129/1/process_time  8.96ms ±18%  8.91ms ±23%     ~     (p=0.841 n=5+5)
BM_Conv1DTransposedStridedNonDefaultLayout/129/3/process_time   783ms ±24%    14ms ±18%  -98.21%  (p=0.008 n=5+5)
BM_Conv2DTransposedStrided/process_time                        10.2ms ±22%   9.9ms ±22%     ~     (p=0.690 n=5+5)

Planned improvements to this algorithm:
- support feature_group_count > 1 (grouped convolution),
- parallelize the packing of the patches (the second step of the algorithm),
- explore kernel rotation possibilities and their performance impact.

PiperOrigin-RevId: 710297666
Adam-Banas authored and Google-ML-Automation committed Dec 28, 2024
1 parent 82d1bb3 commit 5042012
Showing 2 changed files with 92 additions and 49 deletions.
66 changes: 38 additions & 28 deletions xla/backends/cpu/runtime/convolution_thunk_internal.h
@@ -38,7 +38,8 @@ constexpr auto kMaxConvMatrixSize = static_cast<size_t>(8) << 30; // 8 GiB
// Returns in 'out_data' (assumes to be zero-initialized) image patch in storage
// order (width, height, depth), constructed from patches in 'conv_matrix',
// which is required to be in storage order (in_width * in_height, filter_width,
// filter_height, in_depth). Based on TF implementation by Yangqing Jia (jiayq).
// filter_height, out_depth).
// Based on TF implementation by Yangqing Jia (jiayq).
// TODO(adambanas): The original implementation implicitly rotates the kernel by
// 180 degrees, but to be backwards compatible, we cannot do that in XLA. This
// results in counterintuitive operations on conv_matrix, which is also 15-20%
@@ -109,17 +110,18 @@ bool EigenTransposedConv2D(
Eigen::Index padding_y_before, Eigen::Index padding_y_after,
Eigen::Index lhs_x_dilation, Eigen::Index lhs_y_dilation,
std::function<void()> done_callback, bool use_thunk_runtime) {
// TODO(adambanas): Current custom conv algorithm doesn't support both
// multiple input channels and multiple output channels (i.e. kernel_filters)
// at the same time.
CHECK(input_channels == 1 || kernel_filters == 1);

typedef Eigen::TensorMap<Eigen::Tensor<ScalarType, 2, Eigen::RowMajor>,
Eigen::Unaligned>
TensorMap;
typedef Eigen::TensorMap<Eigen::Tensor<const ScalarType, 2, Eigen::RowMajor>,
Eigen::Aligned>
ConstTensorMap;
// Grouped convolutions are not supported yet.
CHECK(kernel_channels == input_channels);

using TensorMap2D =
Eigen::TensorMap<Eigen::Tensor<ScalarType, 2, Eigen::RowMajor>,
Eigen::Unaligned>;
using ConstTensorMap3D =
Eigen::TensorMap<Eigen::Tensor<const ScalarType, 3, Eigen::RowMajor>,
Eigen::Aligned>;
using ConstTensorMap2D =
Eigen::TensorMap<Eigen::Tensor<const ScalarType, 2, Eigen::RowMajor>,
Eigen::Aligned>;

// Total spatial dimensions.
const int input_image_size = input_x * input_y;
@@ -147,17 +149,17 @@
out_data + input_batch * output_image_size * kernel_filters,
ScalarType(0.0f));

// Initialize contraction dims (we need to transpose 'B' below).
Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_dims;
contract_dims[0].first = 1;
contract_dims[0].second = 1;
// Initialize contraction dims (we need to transpose 'B' below, the dimension
// we need to contract is 'kernel_channels').
Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_dims = {
Eigen::IndexPair<Eigen::DenseIndex>(1, 1)};

// Compute intermediate results (convolution matrix) into conv_matrix.
TensorMap C(conv_matrix_data, input_batch * input_image_size,
kernel_total_size);
TensorMap2D C(conv_matrix_data, input_batch * input_image_size,
kernel_total_size);

ConstTensorMap A(lhs, input_batch * input_image_size, input_channels);
ConstTensorMap B(rhs, kernel_total_size, input_channels);
ConstTensorMap2D A(lhs, input_batch * input_image_size, input_channels);
ConstTensorMap3D B(rhs, kernel_x * kernel_y, kernel_channels, kernel_filters);

// Use concurrent execution if we have a thread pool device.
constexpr bool use_thread_pool =
@@ -200,25 +202,34 @@ }
}
};

// Molds the output of the contraction into the shape expected by packing
// algorithm:
// - the minor dimension (dims[1]): the patch values to be packed; contiguous
// in memory
// - the major dimension (dims[0]): everything else
Eigen::DSizes<Eigen::Index, 2> post_contract_dims;
post_contract_dims[0] = input_batch * input_image_size;
post_contract_dims[1] = kernel_total_size;

if (done_callback) {
// Schedule the work in the thread pool and return.
C.device(device, std::move(pack_patches)) = A.contract(B, contract_dims);
C.device(device, std::move(pack_patches)) =
A.contract(B, contract_dims).reshape(post_contract_dims);
} else {
// Run synchronously in the current thread.
C.device(device) = A.contract(B, contract_dims);
C.device(device) = A.contract(B, contract_dims).reshape(post_contract_dims);
pack_patches();
}
return true;
}

inline bool CanUseCustomTransposedConv(
Eigen::Index input_channels, Eigen::Index kernel_filters,
Eigen::Index x_stride, Eigen::Index y_stride, Eigen::Index lhs_x_dilation,
Eigen::Index lhs_y_dilation, Eigen::Index rhs_x_dilation,
Eigen::Index rhs_y_dilation, Eigen::Index feature_group_count) {
return (lhs_x_dilation > 1 || lhs_y_dilation > 1) && rhs_x_dilation == 1 &&
rhs_y_dilation == 1 && (input_channels == 1 || kernel_filters == 1) &&
feature_group_count == 1 && x_stride == 1 && y_stride == 1;
rhs_y_dilation == 1 && feature_group_count == 1 && x_stride == 1 &&
y_stride == 1;
}

// Algorithm that works for all types of 2D convolutions. Even though it works
@@ -372,9 +383,8 @@ void EigenConv2D(const EigenDevice& device, ScalarType* out, ScalarType* lhs,
Eigen::Index lhs_y_dilation, Eigen::Index rhs_x_dilation,
Eigen::Index rhs_y_dilation, Eigen::Index feature_group_count,
std::function<void()> done_callback, bool use_thunk_runtime) {
if (CanUseCustomTransposedConv(input_channels, kernel_filters, x_stride,
y_stride, lhs_x_dilation, lhs_y_dilation,
rhs_x_dilation, rhs_y_dilation,
if (CanUseCustomTransposedConv(x_stride, y_stride, lhs_x_dilation,
lhs_y_dilation, rhs_x_dilation, rhs_y_dilation,
feature_group_count)) {
if (EigenTransposedConv2D(
device, out, lhs, rhs, input_batch, input_x, input_y,
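
The essence of the change in EigenTransposedConv2D is that the kernel is now viewed as a rank-3 tensor (kernel spatial positions x kernel_channels x kernel_filters) and contracted with the input over the channel dimension, and the rank-3 contraction result is then reshaped so that each row holds one contiguous patch for the packing step. The following standalone sketch shows just that contraction-plus-reshape pattern on toy dimensions; it assumes Eigen's unsupported Tensor module and is not the XLA code itself:

#include <unsupported/Eigen/CXX11/Tensor>

#include <iostream>

int main() {
  // Toy dimensions: 6 input positions (batch * spatial), 3 input channels,
  // 4 kernel spatial positions, 2 kernel filters (output channels).
  Eigen::Tensor<float, 2, Eigen::RowMajor> A(6, 3);     // input
  Eigen::Tensor<float, 3, Eigen::RowMajor> B(4, 3, 2);  // kernel
  A.setRandom();
  B.setRandom();

  // Contract dimension 1 of A (input channels) with dimension 1 of B
  // (kernel channels): C(p, k, f) = sum_c A(p, c) * B(k, c, f).
  Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_dims = {
      Eigen::IndexPair<Eigen::DenseIndex>(1, 1)};
  Eigen::Tensor<float, 3, Eigen::RowMajor> C = A.contract(B, contract_dims);
  // Free dimensions of A come first, then free dimensions of B: 6 x 4 x 2.
  std::cout << C.dimension(0) << "x" << C.dimension(1) << "x" << C.dimension(2)
            << "\n";

  // Flatten the two minor dimensions so each row is one contiguous run of
  // kernel_spatial * kernel_filters patch values, ready to be packed.
  Eigen::DSizes<Eigen::Index, 2> post_contract_dims(6, 4 * 2);
  Eigen::Tensor<float, 2, Eigen::RowMajor> C2 = C.reshape(post_contract_dims);
  std::cout << C2.dimension(0) << "x" << C2.dimension(1) << "\n";
  return 0;
}

Because the tensors are row-major, the reshape does not move any data: it only reinterprets the 4 x 2 minor block of each row as 8 contiguous values, which is what post_contract_dims achieves in the code above.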
75 changes: 54 additions & 21 deletions xla/service/cpu/benchmarks/convolution_benchmark_test.cc
@@ -138,31 +138,39 @@ static void BM_GroupedConv2D(benchmark::State& state) {

// Regular strided 1D convolution. Shapes come from an actual use case.
static void BM_Conv1DStrided(benchmark::State& state) {
int input_channels = state.range(0);
int output_channels = state.range(1);

std::string hlo_module = R"(
HloModule jit_jconvf
ENTRY main.6 {
Arg_0.1 = f32[16,1,25600]{2,1,0} parameter(0)
Arg_1.2 = f32[1,129,256]{2,1,0} parameter(1)
ROOT conv.3 = f32[16,129,400]{2,1,0} convolution(Arg_0.1, Arg_1.2),
Arg_0.1 = $input_shape parameter(0)
Arg_1.2 = $kernel_shape parameter(1)
ROOT conv.3 = $output_shape convolution(Arg_0.1, Arg_1.2),
window={size=256 stride=64 pad=96_96}, dim_labels=bf0_io0->bf0
}
)";

std::minstd_rand0 engine;

// NCW layout
auto input_shape = ShapeUtil::MakeShape(F32, {16, 1, 25600});
auto input_shape = ShapeUtil::MakeShape(F32, {16, input_channels, 25600});
auto output_shape = ShapeUtil::MakeShape(F32, {16, output_channels, 400});
// IOW layout
auto kernel_shape = ShapeUtil::MakeShape(F32, {1, 129, 256});
auto kernel_shape =
ShapeUtil::MakeShape(F32, {input_channels, output_channels, 256});

auto input =
*LiteralUtil::CreateRandomLiteral<F32>(input_shape, &engine, 1.0f, 0.1f);
auto kernel =
*LiteralUtil::CreateRandomLiteral<F32>(kernel_shape, &engine, 1.0f, 0.1f);
std::vector<const Literal*> args = {&input, &kernel};

CHECK_OK(RunHloBenchmark(state, hlo_module, args));
CHECK_OK(RunHloBenchmark(state, hlo_module, args,
{{"$input_shape", input_shape.ToString()},
{"$kernel_shape", kernel_shape.ToString()},
{"$output_shape", output_shape.ToString()}}));
}

// Transposed version (i.e. gradient) of BM_Conv1DStrided. In terms of shapes,
@@ -172,61 +180,76 @@ static void BM_Conv1DStrided(benchmark::State& state) {
// Currently, the performance is few times worse than regular conv when they
// should be similar.
static void BM_Conv1DTransposedStrided(benchmark::State& state) {
int input_channels = state.range(0);
int output_channels = state.range(1);

std::string hlo_module = R"(
HloModule jit_jconvt
ENTRY main.6 {
Arg_0.1 = f32[16,129,400]{2,1,0} parameter(0)
Arg_1.2 = f32[129,1,256]{2,1,0} parameter(1)
ROOT conv.3 = f32[16,1,25600]{2,1,0} convolution(Arg_0.1, Arg_1.2),
Arg_0.1 = $input_shape parameter(0)
Arg_1.2 = $kernel_shape parameter(1)
ROOT conv.3 = $output_shape convolution(Arg_0.1, Arg_1.2),
window={size=256 pad=159_159 lhs_dilate=64}, dim_labels=bf0_io0->bf0
}
)";

std::minstd_rand0 engine;

// NCW layout
auto input_shape = ShapeUtil::MakeShape(F32, {16, 129, 400});
auto input_shape = ShapeUtil::MakeShape(F32, {16, input_channels, 400});
auto output_shape = ShapeUtil::MakeShape(F32, {16, output_channels, 25600});
// IOW layout
auto kernel_shape = ShapeUtil::MakeShape(F32, {129, 1, 256});
auto kernel_shape =
ShapeUtil::MakeShape(F32, {input_channels, output_channels, 256});

auto input =
*LiteralUtil::CreateRandomLiteral<F32>(input_shape, &engine, 1.0f, 0.1f);
auto kernel =
*LiteralUtil::CreateRandomLiteral<F32>(kernel_shape, &engine, 1.0f, 0.1f);
std::vector<const Literal*> args = {&input, &kernel};

CHECK_OK(RunHloBenchmark(state, hlo_module, args));
CHECK_OK(RunHloBenchmark(state, hlo_module, args,
{{"$input_shape", input_shape.ToString()},
{"$kernel_shape", kernel_shape.ToString()},
{"$output_shape", output_shape.ToString()}}));
}

// The same shapes as BM_Conv1DTransposedStrided, but with a different layout.
static void BM_Conv1DTransposedStridedNonDefaultLayout(
benchmark::State& state) {
int input_channels = state.range(0);
int output_channels = state.range(1);
std::string hlo_module = R"(
HloModule jit_jconvt
ENTRY main.6 {
Arg_0.1 = f32[16,400,129]{2,1,0} parameter(0)
Arg_1.2 = f32[256,1,129]{2,1,0} parameter(1)
ROOT conv.3 = f32[16,25600,1]{2,1,0} convolution(Arg_0.1, Arg_1.2),
Arg_0.1 = $input_shape parameter(0)
Arg_1.2 = $kernel_shape parameter(1)
ROOT conv.3 = $output_shape convolution(Arg_0.1, Arg_1.2),
window={size=256 pad=159_159 lhs_dilate=64}, dim_labels=b0f_0oi->b0f
}
)";

std::minstd_rand0 engine;

// NWC layout
auto input_shape = ShapeUtil::MakeShape(F32, {16, 400, 129});
auto input_shape = ShapeUtil::MakeShape(F32, {16, 400, input_channels});
auto output_shape = ShapeUtil::MakeShape(F32, {16, 25600, output_channels});
// WOI layout
auto kernel_shape = ShapeUtil::MakeShape(F32, {256, 1, 129});
auto kernel_shape =
ShapeUtil::MakeShape(F32, {256, output_channels, input_channels});

auto input =
*LiteralUtil::CreateRandomLiteral<F32>(input_shape, &engine, 1.0f, 0.1f);
auto kernel =
*LiteralUtil::CreateRandomLiteral<F32>(kernel_shape, &engine, 1.0f, 0.1f);
std::vector<const Literal*> args = {&input, &kernel};

CHECK_OK(RunHloBenchmark(state, hlo_module, args));
CHECK_OK(RunHloBenchmark(state, hlo_module, args,
{{"$input_shape", input_shape.ToString()},
{"$kernel_shape", kernel_shape.ToString()},
{"$output_shape", output_shape.ToString()}}));
}

// Regular strided 2D convolution. Buffer sizes and convolution parameters are
@@ -445,9 +468,19 @@ BENCHMARK(BM_GroupedConv2D)
// 1D and 2D strided convolutions
// -------------------------------------------------------------------------- //

BENCHMARK(BM_Conv1DStrided)->MeasureProcessCPUTime();
BENCHMARK(BM_Conv1DTransposedStrided)->MeasureProcessCPUTime();
BENCHMARK(BM_Conv1DTransposedStridedNonDefaultLayout)->MeasureProcessCPUTime();
BENCHMARK(BM_Conv1DStrided)
->MeasureProcessCPUTime()
->Args({1, 129})
->Args({3, 129});
BENCHMARK(BM_Conv1DTransposedStrided)
->MeasureProcessCPUTime()
->Args({129, 1})
->Args({129, 3});
BENCHMARK(BM_Conv1DTransposedStridedNonDefaultLayout)
->MeasureProcessCPUTime()
->Args({129, 1})
->Args({129, 3});

BENCHMARK(BM_Conv2DStrided)->MeasureProcessCPUTime();
BENCHMARK(BM_Conv2DTransposedStrided)->MeasureProcessCPUTime();
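
For readers unfamiliar with the benchmark naming above: the /129/1 and /129/3 suffixes in the results come from the Args values, which the benchmark body reads back with state.range(). A minimal, self-contained Google Benchmark sketch of that pattern (the benchmark body here is a hypothetical placeholder, not the XLA one):

#include <benchmark/benchmark.h>

#include <cstdint>

static void BM_ChannelsExample(benchmark::State& state) {
  // The two Args values become the channel counts and show up in reported
  // names such as BM_ChannelsExample/129/3/process_time.
  const int64_t input_channels = state.range(0);
  const int64_t output_channels = state.range(1);
  for (auto _ : state) {
    // Placeholder work; the real benchmarks build and run an HLO module.
    benchmark::DoNotOptimize(input_channels * output_channels);
  }
}

BENCHMARK(BM_ChannelsExample)
    ->MeasureProcessCPUTime()
    ->Args({129, 1})
    ->Args({129, 3});

BENCHMARK_MAIN();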
