Skip to content

Commit

Permalink
[SYCL][ESIMD][NFC] Optimize reduction implementation and add more tes…
Browse files Browse the repository at this point in the history
…t cases (intel#15215)
  • Loading branch information
fineg74 authored Aug 29, 2024
1 parent d5aaba1 commit c6b611d
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 64 deletions.
27 changes: 4 additions & 23 deletions sycl/include/sycl/ext/intel/esimd/math.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1476,35 +1476,16 @@ template <typename T0, typename T1, int SZ> struct esimd_apply_prod {
template <typename T0, typename T1, int SZ> struct esimd_apply_reduced_max {
template <typename... T>
simd<T0, SZ> operator()(simd<T1, SZ> v1, simd<T1, SZ> v2) {
if constexpr (detail::is_generic_floating_point_v<T1>) {
using CppT = __ESIMD_DNS::element_type_traits<T1>::EnclosingCppT;
return __ESIMD_DNS::convert_vector<T1, CppT, SZ>(
__spirv_ocl_fmax<CppT, SZ>(
__ESIMD_DNS::convert_vector<CppT, T1, SZ>(v1.data()),
__ESIMD_DNS::convert_vector<CppT, T1, SZ>(v2.data())));
} else if constexpr (std::is_unsigned<T1>::value) {
return __esimd_umax<T1, SZ>(v1.data(), v2.data());
} else {
return __esimd_smax<T1, SZ>(v1.data(), v2.data());
}
return __ESIMD_DNS::convert_vector<T0, T1, SZ>(
__ESIMD_NS::max(v1, v2).data());
}
};

template <typename T0, typename T1, int SZ> struct esimd_apply_reduced_min {
template <typename... T>
simd<T0, SZ> operator()(simd<T1, SZ> v1, simd<T1, SZ> v2) {

if constexpr (detail::is_generic_floating_point_v<T1>) {
using CppT = __ESIMD_DNS::element_type_traits<T1>::EnclosingCppT;
return __ESIMD_DNS::convert_vector<T1, CppT, SZ>(
__spirv_ocl_fmin<CppT, SZ>(
__ESIMD_DNS::convert_vector<CppT, T1, SZ>(v1.data()),
__ESIMD_DNS::convert_vector<CppT, T1, SZ>(v2.data())));
} else if constexpr (std::is_unsigned<T1>::value) {
return __esimd_umin<T1, SZ>(v1.data(), v2.data());
} else {
return __esimd_smin<T1, SZ>(v1.data(), v2.data());
}
return __ESIMD_DNS::convert_vector<T0, T1, SZ>(
__ESIMD_NS::min(v1, v2).data());
}
};

Expand Down
88 changes: 47 additions & 41 deletions sycl/test-e2e/ESIMD/reduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,16 @@
#include "esimd_test_utils.hpp"

using namespace sycl;
constexpr unsigned InputSize = 32;
constexpr unsigned OutputSize = 4;
constexpr unsigned VL = 32;
constexpr unsigned GroupSize = 1;

typedef short TYPE;

int main(void) {
constexpr unsigned InputSize = 32;
constexpr unsigned OutputSize = 4;
constexpr unsigned VL = 32;
constexpr unsigned GroupSize = 1;

queue q(esimd_test::ESIMDSelector, esimd_test::createExceptionHandler());

auto dev = q.get_device();
std::cout << "Running on " << dev.get_info<info::device::name>() << "\n";
TYPE *A = malloc_shared<TYPE>(InputSize, q);
int *B = malloc_shared<int>(OutputSize, q);
template <typename TIn, typename TOut> bool test(sycl::queue &Queue) {
std::cout << "TIn= " << esimd_test::type_name<TIn>()
<< " TOut= " << esimd_test::type_name<TOut>() << std::endl;
TIn *A = malloc_shared<TIn>(InputSize, Queue);
TOut *B = malloc_shared<TOut>(OutputSize, Queue);

for (unsigned i = 0; i < InputSize; ++i) {
if (i == 19) {
Expand All @@ -41,49 +36,48 @@ int main(void) {
range<1> TaskRange{GroupSize};
nd_range<1> Range(GroupRange, TaskRange);

auto e = q.submit([&](handler &cgh) {
cgh.parallel_for<class Test>(
GroupRange * TaskRange, [=](id<1> i) SYCL_ESIMD_KERNEL {
using namespace sycl::ext::intel::esimd;
auto e = Queue.submit([&](handler &cgh) {
cgh.parallel_for(GroupRange * TaskRange, [=](id<1> i) SYCL_ESIMD_KERNEL {
using namespace sycl::ext::intel::esimd;

simd<TYPE, VL> va;
va.copy_from(A + i * VL);
simd<int, OutputSize> vb;
simd<TIn, VL> va;
va.copy_from(A + i * VL);
simd<TOut, OutputSize> vb;

vb[0] = reduce<int>(va, std::plus<>());
vb[1] = reduce<int>(va, std::multiplies<>());
vb[2] = hmax<int>(va);
vb[3] = hmin<int>(va);
vb[0] = reduce<TOut>(va, std::plus<>());
vb[1] = reduce<TOut>(va, std::multiplies<>());
vb[2] = hmax<TOut>(va);
vb[3] = hmin<TOut>(va);

vb.copy_to(B + i * VL);
});
vb.copy_to(B + i * VL);
});
});
e.wait();
} catch (sycl::exception const &e) {
std::cout << "SYCL exception caught: " << e.what() << '\n';
free(A, q);
free(B, q);
return 1;
free(A, Queue);
free(B, Queue);
return 0;
}

auto compute_reduce_sum = [](TYPE A[InputSize]) -> int {
int retv = A[0];
auto compute_reduce_sum = [](TIn A[InputSize]) -> TOut {
TOut retv = A[0];
for (int i = 1; i < InputSize; i++) {
retv += A[i];
}
return retv;
};

auto compute_reduce_prod = [](TYPE A[InputSize]) -> int {
int retv = A[0];
auto compute_reduce_prod = [](TIn A[InputSize]) -> TOut {
TOut retv = A[0];
for (int i = 1; i < InputSize; i++) {
retv *= A[i];
}
return retv;
};

auto compute_reduce_max = [](TYPE A[InputSize]) -> int {
int retv = A[0];
auto compute_reduce_max = [](TIn A[InputSize]) -> TOut {
TOut retv = A[0];
for (int i = 1; i < InputSize; i++) {
if (A[i] > retv) {
retv = A[i];
Expand All @@ -92,8 +86,8 @@ int main(void) {
return retv;
};

auto compute_reduce_min = [](TYPE A[InputSize]) -> int {
int retv = A[0];
auto compute_reduce_min = [](TIn A[InputSize]) -> TOut {
TOut retv = A[0];
for (int i = 1; i < InputSize; i++) {
if (A[i] < retv) {
retv = A[i];
Expand All @@ -103,7 +97,7 @@ int main(void) {
};

bool TestPass = true;
int ref = compute_reduce_sum(A);
TOut ref = compute_reduce_sum(A);
if (B[0] != ref) {
std::cout << "Incorrect sum " << B[0] << ", expected " << ref << "\n";
TestPass = false;
Expand All @@ -127,8 +121,20 @@ int main(void) {
TestPass = false;
}

free(A, q);
free(B, q);
free(A, Queue);
free(B, Queue);
return TestPass;
}

int main(void) {

queue Queue(esimd_test::ESIMDSelector, esimd_test::createExceptionHandler());

esimd_test::printTestLabel(Queue);

bool TestPass = test<int16_t, int32_t>(Queue);
TestPass &= test<uint16_t, uint32_t>(Queue);
TestPass &= test<float, float>(Queue);

if (!TestPass) {
std::cout << "Failed\n";
Expand Down

0 comments on commit c6b611d

Please sign in to comment.