From 25e7aa6d0128d55d42d8d8988b1d5ce6fa7c452e Mon Sep 17 00:00:00 2001 From: i-evi Date: Mon, 26 Oct 2020 19:44:39 +0800 Subject: [PATCH] pool --- src/additional/ecpufn/ecpufn.c | 220 +++++++++++++++++++++++++++++++-- src/additional/ecpufn/ecpufn.h | 40 ++++-- 2 files changed, 240 insertions(+), 20 deletions(-) diff --git a/src/additional/ecpufn/ecpufn.c b/src/additional/ecpufn/ecpufn.c index c125324..20da372 100644 --- a/src/additional/ecpufn/ecpufn.c +++ b/src/additional/ecpufn/ecpufn.c @@ -124,8 +124,8 @@ void sse_conv2d_f32_k2sx(const f32 *in, f32 *out, i32 ix, i32 iy, for (j = 0; j <= ix - kw; j += sx) { off = i * ix + j; blk = _mm_shuffle_ps( - _mm_loadu_ps(in + off), // l2 - _mm_loadu_ps(in + off + ix - 2), // h2 + _mm_loadu_ps(in + off), + _mm_loadu_ps(in + off + ix - 2), _MM_SHUFFLE(3, 2, 1, 0)); blk = blk * lvk; *res += res[1] + res[2] + res[3]; @@ -328,17 +328,17 @@ void naive_max_pool2d_ ## dtype (const dtype *in, \ i32 oy = y / s; \ i32 i, j, k, l; \ const dtype *curr; \ - dtype v_max; \ + dtype max; \ for (i = 0; i < oy; ++i) { \ for (j = 0; j < ox; ++j) { \ - v_max = *(in + s * i * x + s * j); \ + max = *(in + s * i * x + s * j); \ for (k = 0; k < s; ++k) { \ for (l = 0; l < s; ++l) { \ curr = in + (s * i + k) * x + j * s + l; \ - v_max = *curr > v_max ? *curr : v_max; \ + max = *curr > max ? *curr : max; \ } \ } \ - *out++ = v_max; \ + *out++ = max; \ } \ } \ } @@ -372,7 +372,6 @@ DFL_MAXPOOL2D_IMPLEMENTATION (u64); /* DFL_MAXPOOL2D_IMPLEMENTATION (f32); */ DFL_MAXPOOL2D_IMPLEMENTATION (f64); - void ecpu_max_pool2d_f32(const f32 *in, f32 *out, i32 x, i32 y, i32 s) { switch (s) { @@ -380,6 +379,16 @@ void ecpu_max_pool2d_f32(const f32 *in, f32 *out, i32 x, i32 y, i32 s) case 2: ALT_MAXPOOL2D_F32_S2(in, out, x, y, s); break; +#endif +#ifdef ALT_MAXPOOL2D_F32_S3 + case 3: + ALT_MAXPOOL2D_F32_S3(in, out, x, y, s); + break; +#endif +#ifdef ALT_MAXPOOL2D_F32_S4 + case 4: + ALT_MAXPOOL2D_F32_S4(in, out, x, y, s); + break; #endif default: naive_max_pool2d_f32(in, out, x, y, s); @@ -395,8 +404,8 @@ void sse_max_pool2d_f32_s2(const f32 *in, f32 *out, i32 x, i32 y, i32 s) __m128 blkb; f32 maxl, maxh, a, b; f32 *pa = (f32*)&blka, *pb = ((f32*)&blka) + 2; - for (i = 0; i < (y - 1); i += 2) { - for (j = 0; j < (x - 3); j += 4) { + for (i = 0; i <= (y - 2); i += 2) { + for (j = 0; j <= (x - 4); j += 4) { blka = _mm_loadu_ps(in + (i + 0) * x + j); blkb = _mm_loadu_ps(in + (i + 1) * x + j); blka = _mm_max_ps(blka, blkb); @@ -414,4 +423,197 @@ void sse_max_pool2d_f32_s2(const f32 *in, f32 *out, i32 x, i32 y, i32 s) } } } + +void sse_max_pool2d_f32_s3(const f32 *in, f32 *out, i32 x, i32 y, i32 s) +{ + i32 i, j; + __m128 blka, blkb, blkc; + f32 t; + f32 *p = (f32*)&blka; + for (i = 0; i <= (y - 3); i += 3) { + for (j = 0; j <= (x - 4); j += 3) { + blka = _mm_loadu_ps(in + (i + 0) * x + j); + blkb = _mm_loadu_ps(in + (i + 1) * x + j); + blkc = _mm_loadu_ps(in + (i + 2) * x + j); + blka = _mm_max_ps(blka, blkb); + blka = _mm_max_ps(blka, blkc); + t = p[0] > p[1] ? p[0] : p[1]; + *out++ = t > p[2] ? t : p[2]; + } + if (j == x - 3) { + blka = _mm_loadu_ps(in + (i + 0) * x + j - 1); + blkb = _mm_loadu_ps(in + (i + 1) * x + j - 1); + blkc = _mm_loadu_ps(in + (i + 2) * x + j - 1); + blka = _mm_max_ps(blka, blkb); + blka = _mm_max_ps(blka, blkc); + t = p[1] > p[2] ? p[1] : p[2]; + *out++ = t > p[3] ? t : p[3]; + } + } +} + +void sse_max_pool2d_f32_s4(const f32 *in, f32 *out, i32 x, i32 y, i32 s) +{ + i32 i, j; + __m128 blka, blkb, blkc, blkd; + f32 maxh, maxl; + f32 *p = (f32*)&blka; + for (i = 0; i <= (y - 4); i += 4) { + for (j = 0; j <= (x - 4); j += 4) { + blka = _mm_loadu_ps(in + (i + 0) * x + j); + blkb = _mm_loadu_ps(in + (i + 1) * x + j); + blkc = _mm_loadu_ps(in + (i + 2) * x + j); + blkd = _mm_loadu_ps(in + (i + 3) * x + j); + blka = _mm_max_ps(blka, blkb); + blkc = _mm_max_ps(blkc, blkd); + blka = _mm_max_ps(blka, blkc); + maxl = p[0] > p[1] ? p[0] : p[1]; + maxh = p[2] > p[3] ? p[2] : p[3]; + *out++ = maxh > maxl ? maxh : maxl; + } + } +} +#endif + +#define NAIVE_AVGPOOL2D_IMPLEMENTATION(dtype) \ +void naive_avg_pool2d_ ## dtype (const dtype *in, \ + dtype *out, i32 x, i32 y, i32 s) \ +{ \ + i32 ox = x / s; \ + i32 oy = y / s; \ + i32 i, j, k, l, nelem; \ + const dtype *curr; \ + dtype sum; \ + nelem = s * s; \ + for (i = 0; i < oy; ++i) { \ + for (j = 0; j < ox; ++j) { \ + sum = 0; \ + for (k = 0; k < s; ++k) { \ + for (l = 0; l < s; ++l) { \ + curr = in + (s * i + k) * x + j * s + l; \ + sum += *curr; \ + } \ + } \ + *out++ = sum / nelem; \ + } \ + } \ +} + +NAIVE_AVGPOOL2D_IMPLEMENTATION (i8); +NAIVE_AVGPOOL2D_IMPLEMENTATION (u8); +NAIVE_AVGPOOL2D_IMPLEMENTATION (i16); +NAIVE_AVGPOOL2D_IMPLEMENTATION (u16); +NAIVE_AVGPOOL2D_IMPLEMENTATION (i32); +NAIVE_AVGPOOL2D_IMPLEMENTATION (u32); +NAIVE_AVGPOOL2D_IMPLEMENTATION (i64); +NAIVE_AVGPOOL2D_IMPLEMENTATION (u64); +NAIVE_AVGPOOL2D_IMPLEMENTATION (f32); +NAIVE_AVGPOOL2D_IMPLEMENTATION (f64); + +#define DFL_AVGPOOL2D_IMPLEMENTATION(dtype) \ +void ecpu_avg_pool2d_ ## dtype ( \ + const dtype *in, dtype *out, i32 x, i32 y, i32 s) \ +{ \ + naive_avg_pool2d_ ## dtype(in, out, x, y, s); \ +} + +DFL_AVGPOOL2D_IMPLEMENTATION (i8); +DFL_AVGPOOL2D_IMPLEMENTATION (u8); +DFL_AVGPOOL2D_IMPLEMENTATION (i16); +DFL_AVGPOOL2D_IMPLEMENTATION (u16); +DFL_AVGPOOL2D_IMPLEMENTATION (i32); +DFL_AVGPOOL2D_IMPLEMENTATION (u32); +DFL_AVGPOOL2D_IMPLEMENTATION (i64); +DFL_AVGPOOL2D_IMPLEMENTATION (u64); +/* DFL_AVGPOOL2D_IMPLEMENTATION (f32); */ +DFL_AVGPOOL2D_IMPLEMENTATION (f64); + +void ecpu_avg_pool2d_f32(const f32 *in, f32 *out, i32 x, i32 y, i32 s) +{ + switch (s) { +#ifdef ALT_AVGPOOL2D_F32_S2 + case 2: + ALT_AVGPOOL2D_F32_S2(in, out, x, y, s); + break; +#endif +#ifdef ALT_AVGPOOL2D_F32_S3 + case 3: + ALT_AVGPOOL2D_F32_S3(in, out, x, y, s); + break; +#endif +#ifdef ALT_AVGPOOL2D_F32_S4 + case 4: + ALT_AVGPOOL2D_F32_S4(in, out, x, y, s); + break; +#endif + default: + naive_avg_pool2d_f32(in, out, x, y, s); + break; + } +} + +#if defined(__x86_64) && defined(__SSE__) +void sse_avg_pool2d_f32_s2(const f32 *in, f32 *out, i32 x, i32 y, i32 s) +{ + i32 i, j; + __m128 blka, blkb; + f32 sum; + f32 *pa = (f32*)&blka, *pb = ((f32*)&blka) + 2; + for (i = 0; i <= (y - 2); i += 2) { + for (j = 0; j <= (x - 4); j += 4) { + blka = _mm_loadu_ps(in + (i + 0) * x + j); + blkb = _mm_loadu_ps(in + (i + 1) * x + j); + blka = blka + blkb; + *out++ = (pa[0] + pa[1]) / 4; + *out++ = (pb[0] + pb[1]) / 4; + } + for (; j < (x - 1); j += 2) { + sum = *(in + (i + 0) * x + j + 0); + sum += *(in + (i + 0) * x + j + 1); + sum += *(in + (i + 1) * x + j + 0); + sum += *(in + (i + 1) * x + j + 1); + *out++ = sum / 4; + } + } +} + +void sse_avg_pool2d_f32_s3(const f32 *in, f32 *out, i32 x, i32 y, i32 s) +{ + i32 i, j; + __m128 blka, blkb, blkc; + f32 *p = (f32*)&blka; + for (i = 0; i <= (y - 3); i += 3) { + for (j = 0; j <= (x - 4); j += 3) { + blka = _mm_loadu_ps(in + (i + 0) * x + j); + blkb = _mm_loadu_ps(in + (i + 1) * x + j); + blkc = _mm_loadu_ps(in + (i + 2) * x + j); + blka = blka + blkb + blkc; + *out++ = (p[0] + p[1] + p[2]) / 9; + } + if (j == x - 3) { + blka = _mm_loadu_ps(in + (i + 0) * x + j - 1); + blkb = _mm_loadu_ps(in + (i + 1) * x + j - 1); + blkc = _mm_loadu_ps(in + (i + 2) * x + j - 1); + blka = blka + blkb + blkc; + *out++ = (p[1] + p[2] + p[3]) / 9; + } + } +} + +void sse_avg_pool2d_f32_s4(const f32 *in, f32 *out, i32 x, i32 y, i32 s) +{ + i32 i, j; + __m128 blka, blkb, blkc, blkd; + f32 *p = (f32*)&blka; + for (i = 0; i <= (y - 4); i += 4) { + for (j = 0; j <= (x - 4); j += 4) { + blka = _mm_loadu_ps(in + (i + 0) * x + j); + blkb = _mm_loadu_ps(in + (i + 1) * x + j); + blkc = _mm_loadu_ps(in + (i + 2) * x + j); + blkd = _mm_loadu_ps(in + (i + 3) * x + j); + blka = blka + blkb + blkc + blkd; + *out++ = (p[0] + p[1] + p[2] + p[3]) / 16; + } + } +} #endif diff --git a/src/additional/ecpufn/ecpufn.h b/src/additional/ecpufn/ecpufn.h index d0a4245..a92d2cc 100644 --- a/src/additional/ecpufn/ecpufn.h +++ b/src/additional/ecpufn/ecpufn.h @@ -86,28 +86,46 @@ void avx_dot_prod_f32(const f32 *in, f32 *out, const f32 *w, i32 iw); void sse_dot_prod_f32(const f32 *in, f32 *out, const f32 *w, i32 iw); /* void ecpu_max_pool2d_xxx(xxx *in, xxx *out, i32 x, i32 y, i32 s); */ -#define ECPU_MAXPOOL2D_DECLARATION(dtype) \ +/* void ecpu_avg_pool2d_xxx(xxx *in, xxx *out, i32 x, i32 y, i32 s); */ +#define ECPU_POOL2D_DECLARATION(dtype) \ void ecpu_max_pool2d_ ## dtype ( \ const dtype *in, dtype *out, i32 x, i32 y, i32 s); \ void naive_max_pool2d_ ## dtype ( \ + const dtype *in, dtype *out, i32 x, i32 y, i32 s); \ +void ecpu_avg_pool2d_ ## dtype ( \ + const dtype *in, dtype *out, i32 x, i32 y, i32 s); \ +void naive_avg_pool2d_ ## dtype ( \ const dtype *in, dtype *out, i32 x, i32 y, i32 s); -ECPU_MAXPOOL2D_DECLARATION (i8); -ECPU_MAXPOOL2D_DECLARATION (u8); -ECPU_MAXPOOL2D_DECLARATION (i16); -ECPU_MAXPOOL2D_DECLARATION (u16); -ECPU_MAXPOOL2D_DECLARATION (i32); -ECPU_MAXPOOL2D_DECLARATION (u32); -ECPU_MAXPOOL2D_DECLARATION (i64); -ECPU_MAXPOOL2D_DECLARATION (u64); -ECPU_MAXPOOL2D_DECLARATION (f32); -ECPU_MAXPOOL2D_DECLARATION (f64); +ECPU_POOL2D_DECLARATION (i8); +ECPU_POOL2D_DECLARATION (u8); +ECPU_POOL2D_DECLARATION (i16); +ECPU_POOL2D_DECLARATION (u16); +ECPU_POOL2D_DECLARATION (i32); +ECPU_POOL2D_DECLARATION (u32); +ECPU_POOL2D_DECLARATION (i64); +ECPU_POOL2D_DECLARATION (u64); +ECPU_POOL2D_DECLARATION (f32); +ECPU_POOL2D_DECLARATION (f64); #if defined(__x86_64) && defined(__SSE__) #define ALT_MAXPOOL2D_F32_S2 sse_max_pool2d_f32_s2 + #define ALT_MAXPOOL2D_F32_S3 sse_max_pool2d_f32_s3 + #define ALT_MAXPOOL2D_F32_S4 sse_max_pool2d_f32_s4 + #define ALT_AVGPOOL2D_F32_S2 sse_avg_pool2d_f32_s2 + #define ALT_AVGPOOL2D_F32_S3 sse_avg_pool2d_f32_s3 + #define ALT_AVGPOOL2D_F32_S4 sse_avg_pool2d_f32_s4 #endif void sse_max_pool2d_f32_s2(const f32 *in, f32 *out, i32 x, i32 y, i32 s); +void sse_avg_pool2d_f32_s2(const f32 *in, f32 *out, i32 x, i32 y, i32 s); + +/* ((s > x) && (s >= y))*/ +void sse_max_pool2d_f32_s3(const f32 *in, f32 *out, i32 x, i32 y, i32 s); +void sse_avg_pool2d_f32_s3(const f32 *in, f32 *out, i32 x, i32 y, i32 s); + +void sse_max_pool2d_f32_s4(const f32 *in, f32 *out, i32 x, i32 y, i32 s); +void sse_avg_pool2d_f32_s4(const f32 *in, f32 *out, i32 x, i32 y, i32 s); #ifdef __cplusplus }