diff --git a/src/additional/ecpufn/ecpufn.c b/src/additional/ecpufn/ecpufn.c index d687b60..c125324 100644 --- a/src/additional/ecpufn/ecpufn.c +++ b/src/additional/ecpufn/ecpufn.c @@ -77,7 +77,12 @@ void ecpu_conv2d_f32(const f32 *in, f32 *out, i32 ix, i32 iy, #endif #ifdef ALT_CONV2D_F32_K3SX case 3: - ALT_CONV2D_F32_K3SX(in, out, ix, iy, ox, oy, sx, sy, k, kw); + if (sx == sy && sx == 1) + ALT_CONV2D_F32_K3S1( + in, out, ix, iy, ox, oy, sx, sy, k, kw); + else + ALT_CONV2D_F32_K3SX( + in, out, ix, iy, ox, oy, sx, sy, k, kw); break; #endif default: @@ -129,6 +134,59 @@ void sse_conv2d_f32_k2sx(const f32 *in, f32 *out, i32 ix, i32 iy, } } +void sse_conv2d_f32_k3s1(const f32 *in, f32 *out, i32 ix, i32 iy, + i32 ox, i32 oy, i32 sx, i32 sy, const f32 *k, i32 kw) +{ + i32 i, j, u, off; + __m128 vkl[3]; + __m128 vkh[3]; + __m128 blk, accl, acch; + f32 *resl = (f32*)&(accl); + f32 *resh = (f32*)&(acch); + for (u = 0; u < kw; ++u) { + memcpy(((f32*)&(vkl[u])) + 0, k + u * kw, sizeof(f32) * kw); + memcpy(((f32*)&(vkh[u])) + 1, k + u * kw, sizeof(f32) * kw); + } + for (i = 0; i <= iy - kw - sy; i += sy) { + for (j = 0; j < ix - kw; j += 2) { + accl = _mm_setzero_ps(); + acch = _mm_setzero_ps(); + off = i * ix + j; + for (u = 0; u < kw; ++u) { + blk = _mm_loadu_ps(in + off); + accl += blk * vkl[u]; + acch += blk * vkh[u]; + off += ix; + } + *out++ = (resl[0] + resl[1] + resl[2]); + *out++ = (resh[1] + resh[2] + resh[3]); + } + if (j == ix - kw) { + accl = _mm_setzero_ps(); + off = i * ix + j; + for (u = 0; u < kw; ++u) { + blk = _mm_loadu_ps(in + off); + accl += blk * vkl[u]; + off += ix; + } + *out++ = (resl[0] + resl[1] + resl[2]); + } + } + if (iy - kw >= i) { + off = i * ix; + for (j = -1; j <= ix - kw - 1; ++j) { + acch = _mm_setzero_ps(); + off = i * ix + j; + for (u = 0; u < kw; ++u) { + blk = _mm_loadu_ps(in + off); + acch += blk * vkh[u]; + off += ix; + } + *out++ = (resh[1] + resh[2] + resh[3]); + } + } +} + void sse_conv2d_f32_k3sx(const f32 *in, f32 *out, i32 ix, i32 iy, i32 ox, i32 oy, i32 sx, i32 sy, const f32 *k, i32 kw) { diff --git a/src/additional/ecpufn/ecpufn.h b/src/additional/ecpufn/ecpufn.h index 170a29f..d0a4245 100644 --- a/src/additional/ecpufn/ecpufn.h +++ b/src/additional/ecpufn/ecpufn.h @@ -43,6 +43,7 @@ ECPU_CONV2D_DECLARATION (f64); #if defined(__x86_64) && defined(__SSE__) #define ALT_CONV2D_F32_K1S1 sse_conv2d_f32_k1s1 #define ALT_CONV2D_F32_K2SX sse_conv2d_f32_k2sx + #define ALT_CONV2D_F32_K3S1 sse_conv2d_f32_k3s1 #define ALT_CONV2D_F32_K3SX sse_conv2d_f32_k3sx #endif @@ -50,10 +51,9 @@ void sse_conv2d_f32_k1s1(const f32 *in, f32 *out, i32 ix, i32 iy, i32 ox, i32 oy, i32 sx, i32 sy, const f32 *k, i32 kw); void sse_conv2d_f32_k2sx(const f32 *in, f32 *out, i32 ix, i32 iy, i32 ox, i32 oy, i32 sx, i32 sy, const f32 *k, i32 kw); -void sse_conv2d_f32_k3sx(const f32 *in, f32 *out, i32 ix, i32 iy, +void sse_conv2d_f32_k3s1(const f32 *in, f32 *out, i32 ix, i32 iy, i32 ox, i32 oy, i32 sx, i32 sy, const f32 *k, i32 kw); - -void avx_conv2d_f32_k3sx(const f32 *in, f32 *out, i32 ix, i32 iy, +void sse_conv2d_f32_k3sx(const f32 *in, f32 *out, i32 ix, i32 iy, i32 ox, i32 oy, i32 sx, i32 sy, const f32 *k, i32 kw); /* void ecpu_dot_prod_xxx(xxx *in, xxx *out, xxx *w, i32 iw); */