From 7578887cb069ee01c05aef8aae1abeb8f9e4ba4c Mon Sep 17 00:00:00 2001
From: i-evi
Date: Fri, 30 Oct 2020 00:36:49 +0800
Subject: [PATCH] conv

---
Review notes (annotation only; the tokens of the diff below are unchanged).

What this patch does
  * ecpufn.c, ecpu_conv2d_f32: adds `case 4:` and `case 5:` ahead of the
    existing `default:` branch (naive_conv2d_f32). Each case forwards the
    same ten arguments to ALT_CONV2D_F32_K4SX / ALT_CONV2D_F32_K5SX and is
    compiled only when that macro is defined.
  * ecpufn.c, sse_conv2d_f32_k4sx: copies kw floats from each kernel row
    (k + u * kw) into lvk[u]. For every window origin (i, j), with i stepping
    by sy while i <= iy - kw and j stepping by sx while j <= ix - kw, it
    accumulates _mm_loadu_ps(in + off) * lvk[u] over kw input rows (off
    advances by ix per row) and stores the sum of the four lanes via *out++.
  * ecpufn.c, sse_conv2d_f32_k5sx: same traversal. The first four floats of
    each kernel row go into vk4[u] and the fifth into sk5[u]; the fifth
    column is accumulated in the scalar accs from in[off + 4] and added to
    the four-lane sum before the store.
  * ecpufn.h: maps ALT_CONV2D_F32_K4SX / ALT_CONV2D_F32_K5SX to the new
    functions next to the existing K2SX / K3S1 / K3SX mappings, and declares
    both functions.

Points to confirm
  * The switch expression is outside the hunk; presumably it is kw, which
    would guarantee kw == 4 / kw == 5 inside the new functions -- confirm.
    lvk[4], vk4[5] and sk5[5] are fixed-size while the fill and accumulate
    loops are bounded by kw, so a direct call with a larger kw would index
    past them.
  * `acc += blk * lvk[u]` and `accv += blkv * vk4[u]` apply arithmetic
    operators directly to __m128 values instead of calling _mm_mul_ps /
    _mm_add_ps; this relies on compiler vector extensions -- confirm every
    supported compiler accepts it.
  * `res` reads the accumulator lanes through an f32 pointer cast of the
    __m128 object.
  * ox and oy are accepted but never read by either new function.
  * sx and sy are used as loop increments without validation; presumably
    callers guarantee they are positive -- confirm.
  * Whitespace: this copy of the patch had its line breaks and indentation
    collapsed. The layout below is reconstructed (tabs assumed); verify the
    context-line indentation against the tree before applying.

 src/additional/ecpufn/ecpufn.c | 63 ++++++++++++++++++++++++++++++++++
 src/additional/ecpufn/ecpufn.h |  6 ++++
 2 files changed, 69 insertions(+)

diff --git a/src/additional/ecpufn/ecpufn.c b/src/additional/ecpufn/ecpufn.c
index 656a6c1..b859f5c 100644
--- a/src/additional/ecpufn/ecpufn.c
+++ b/src/additional/ecpufn/ecpufn.c
@@ -88,6 +88,16 @@ void ecpu_conv2d_f32(const f32 *in, f32 *out, i32 ix, i32 iy,
 			ALT_CONV2D_F32_K3SX(
 				in, out, ix, iy, ox, oy, sx, sy, k, kw);
 			break;
+#ifdef ALT_CONV2D_F32_K4SX
+		case 4:
+			ALT_CONV2D_F32_K4SX(in, out, ix, iy, ox, oy, sx, sy, k, kw);
+			break;
+#endif
+#ifdef ALT_CONV2D_F32_K5SX
+		case 5:
+			ALT_CONV2D_F32_K5SX(in, out, ix, iy, ox, oy, sx, sy, k, kw);
+			break;
+#endif
 #endif
 		default:
 			naive_conv2d_f32(in, out, ix, iy, ox, oy, sx, sy, k, kw);
@@ -232,6 +242,59 @@ void sse_conv2d_f32_k3sx(const f32 *in, f32 *out, i32 ix, i32 iy,
 		}
 	}
 }
+
+void sse_conv2d_f32_k4sx(const f32 *in, f32 *out, i32 ix, i32 iy,
+	i32 ox, i32 oy, i32 sx, i32 sy, const f32 *k, i32 kw)
+{
+	i32 i, j, u, off;
+	__m128 lvk[4];
+	__m128 blk, acc;
+	f32 *res = (f32*)&(acc);
+	for (u = 0; u < kw; ++u) {
+		memcpy(&(lvk[u]), k + u * kw, sizeof(f32) * kw);
+	}
+	for (i = 0; i <= iy - kw; i += sy) {
+		for (j = 0; j <= ix - kw; j += sx) {
+			acc = _mm_setzero_ps();
+			off = i * ix + j;
+			for (u = 0; u < kw; ++u) {
+				blk = _mm_loadu_ps(in + off);
+				acc += blk * lvk[u];
+				off += ix;
+			}
+			*out++ = res[0] + res[1] + res[2] + res[3];
+		}
+	}
+}
+
+void sse_conv2d_f32_k5sx(const f32 *in, f32 *out, i32 ix, i32 iy,
+	i32 ox, i32 oy, i32 sx, i32 sy, const f32 *k, i32 kw)
+{
+	i32 i, j, u, off;
+	__m128 vk4[5];
+	__m128 blkv, accv;
+	f32 blks, accs, sk5[5];
+	f32 *res = (f32*)&(accv);
+	for (u = 0; u < kw; ++u) {
+		memcpy(&(vk4[u]), k + u * kw, sizeof(f32) * 4);
+		sk5[u] = *(k + u * kw + 4);
+	}
+	for (i = 0; i <= iy - kw; i += sy) {
+		for (j = 0; j <= ix - kw; j += sx) {
+			accv = _mm_setzero_ps();
+			accs = 0;
+			off = i * ix + j;
+			for (u = 0; u < kw; ++u) {
+				blkv = _mm_loadu_ps(in + off);
+				accv += blkv * vk4[u];
+				blks = *(in + off + 4);
+				accs += blks * sk5[u];
+				off += ix;
+			}
+			*out++ = res[0] + res[1] + res[2] + res[3] + accs;
+		}
+	}
+}
 #endif /* __SSE__ */
 
 #define NAIVE_DOTPROD_IMPLEMENTATION(dtype) \
diff --git a/src/additional/ecpufn/ecpufn.h b/src/additional/ecpufn/ecpufn.h
index 707a7ef..39c66da 100644
--- a/src/additional/ecpufn/ecpufn.h
+++ b/src/additional/ecpufn/ecpufn.h
@@ -46,6 +46,8 @@ ECPU_CONV2D_DECLARATION (f64);
   #define ALT_CONV2D_F32_K2SX sse_conv2d_f32_k2sx
   #define ALT_CONV2D_F32_K3S1 sse_conv2d_f32_k3s1
   #define ALT_CONV2D_F32_K3SX sse_conv2d_f32_k3sx
+  #define ALT_CONV2D_F32_K4SX sse_conv2d_f32_k4sx
+  #define ALT_CONV2D_F32_K5SX sse_conv2d_f32_k5sx
 #endif
 
 void sse_conv2d_f32_k1s1(const f32 *in, f32 *out, i32 ix, i32 iy,
@@ -56,6 +58,10 @@ void sse_conv2d_f32_k3s1(const f32 *in, f32 *out, i32 ix, i32 iy,
 	i32 ox, i32 oy, i32 sx, i32 sy, const f32 *k, i32 kw);
 void sse_conv2d_f32_k3sx(const f32 *in, f32 *out, i32 ix, i32 iy,
 	i32 ox, i32 oy, i32 sx, i32 sy, const f32 *k, i32 kw);
+void sse_conv2d_f32_k4sx(const f32 *in, f32 *out, i32 ix, i32 iy,
+	i32 ox, i32 oy, i32 sx, i32 sy, const f32 *k, i32 kw);
+void sse_conv2d_f32_k5sx(const f32 *in, f32 *out, i32 ix, i32 iy,
+	i32 ox, i32 oy, i32 sx, i32 sy, const f32 *k, i32 kw);
 
 /* void ecpu_dot_prod_xxx(xxx *in, xxx *out, xxx *w, i32 iw); */
 #define ECPU_DOTPROD_DECLARATION(dtype) \