From a2bd8ebf056c6b7b94c36b80d9e500a232980985 Mon Sep 17 00:00:00 2001 From: i-evi Date: Tue, 13 Oct 2020 15:15:10 +0800 Subject: [PATCH] update fn --- src/cc_cpufn.c | 80 ++++++++++++++++----------------------------- src/cc_cpufn.h | 7 ++-- src/cc_fullycon.c | 21 +++++++----- src/global_fn_cfg.c | 6 ++-- src/global_fn_cfg.h | 7 ++-- 5 files changed, 50 insertions(+), 71 deletions(-) diff --git a/src/cc_cpufn.c b/src/cc_cpufn.c index b9d5ac4..3c61e44 100644 --- a/src/cc_cpufn.c +++ b/src/cc_cpufn.c @@ -410,58 +410,6 @@ void cc_cpu_conv2d(const void *inp, void *oup, cc_int32 x, cc_int32 y, } } -#define CC_CPU_FULLY_CONNECTED_PROCESS \ -for (i = 0; i < ow; ++i) { \ - oup[i] = 0; \ - for (j = 0; j < iw; ++j) { \ - oup[i] += (*(inp + j)) * (*(w + iw * i + j)); \ - } \ - if (b) \ - oup[i] += b[i]; \ -} - -static void cc_cpu_fully_connected_float32( - cc_float32 *inp, cc_float32 *oup, cc_float32 *w, - cc_float32 *b, cc_int32 iw, cc_int32 ow) -{ - cc_int32 i, j; -#ifdef ENABLE_OPENMP -#pragma omp parallel for private(i, j) -#endif - CC_CPU_FULLY_CONNECTED_PROCESS; -} - -static void cc_cpu_fully_connected_float64( - cc_float64 *inp, cc_float64 *oup, cc_float64 *w, - cc_float64 *b, cc_int32 iw, cc_int32 ow) -{ - cc_int32 i, j; -#ifdef ENABLE_OPENMP -#pragma omp parallel for private(i, j) -#endif - CC_CPU_FULLY_CONNECTED_PROCESS; -} - -void cc_cpu_fully_connected(const void *inp, - void *oup, const void *w, const void *b, - cc_int32 iw, cc_int32 ow, cc_dtype dt) -{ - switch (dt) { - case CC_FLOAT32: - cc_cpu_fully_connected_float32( - (cc_float32*)inp, (cc_float32*)oup, - (cc_float32*)w, (cc_float32*)b, iw, ow); - break; - case CC_FLOAT64: - cc_cpu_fully_connected_float64( - (cc_float64*)inp, (cc_float64*)oup, - (cc_float64*)w, (cc_float64*)b, iw, ow); - break; - default: - UNSUPPORTED_DTYPE_LOG(dt); - } -} - #ifndef CC_BN_OFFSET_CFG #define CC_BN_OFFSET_CFG enum { @@ -889,6 +837,34 @@ void cc_cpu_array_div_ew(void *oup, } } +#define ARRAY_DOTPROD_CASE(_DT, _dt) \ +case _DT: \ +*((_dt*)x) = 0; \ +for (i = 0; i < arrlen; ++i) \ + *((_dt*)x) += *(((_dt*)a) + i) * *(((_dt*)b) + i); \ +break; +void cc_cpu_array_dot_prod( + const void *a, const void *b, int arrlen, void *x, int dt) +{ + cc_int32 i; + switch (dt) { + ARRAY_DOTPROD_CASE(CC_UINT8, cc_uint8); + ARRAY_DOTPROD_CASE(CC_UINT16, cc_uint16); + ARRAY_DOTPROD_CASE(CC_UINT32, cc_uint32); + ARRAY_DOTPROD_CASE(CC_UINT64, cc_uint64); + ARRAY_DOTPROD_CASE(CC_INT8, cc_int8); + ARRAY_DOTPROD_CASE(CC_INT16, cc_int16); + ARRAY_DOTPROD_CASE(CC_INT32, cc_int32); + ARRAY_DOTPROD_CASE(CC_INT64, cc_int64); + ARRAY_DOTPROD_CASE(CC_FLOAT32, cc_float32); + ARRAY_DOTPROD_CASE(CC_FLOAT64, cc_float64); + default: + utlog_format(UTLOG_ERR, + "cc_array: unsupported dtype %x\n", dt); + break; + } +} + #define ARRAY_SUM_CASE(_DT, _dt) \ case _DT: \ ARRAY_SUM(arr, arrlen, _dt, x); \ diff --git a/src/cc_cpufn.h b/src/cc_cpufn.h index fbf9360..41f1097 100644 --- a/src/cc_cpufn.h +++ b/src/cc_cpufn.h @@ -23,10 +23,6 @@ void cc_cpu_conv2d(const void *inp, void *oup, cc_int32 x, cc_int32 y, cc_int32 oup_x, cc_int32 oup_y, cc_int32 sx, cc_int32 sy, const void *filter, cc_int32 fw, cc_dtype dt); -void cc_cpu_fully_connected(const void *inp, - void *oup, const void *w, const void *b, - cc_int32 iw, cc_int32 ow, cc_dtype dt); - void cc_cpu_batch_norm(void *inp, cc_int32 len, const void *bnpara, cc_dtype dt); @@ -72,6 +68,9 @@ void cc_cpu_array_mul_ew(void *oup, void cc_cpu_array_div_ew(void *oup, int arrlen, const void *a, const void *b, int dt); +void cc_cpu_array_dot_prod( + const void *a, const void *b, int arrlen, void *x, int dt); + void cc_cpu_array_sum (const void *arr, int arrlen, void *x, int dt); void cc_cpu_array_mean(const void *arr, int arrlen, void *x, int dt); diff --git a/src/cc_fullycon.c b/src/cc_fullycon.c index 34ebe75..f40be69 100644 --- a/src/cc_fullycon.c +++ b/src/cc_fullycon.c @@ -8,11 +8,12 @@ #include "cc_fullycon.h" #include "global_fn_cfg.h" -extern fn_fully_connected _fully_connected; +extern fn_array_dot_prod _array_dot_prod; cc_tensor_t *cc_fully_connected(const cc_tensor_t *inp, const cc_tensor_t *w, const cc_tensor_t *b, const char *name) { + cc_int32 i, mmsize, dtsize; cc_tensor_t *oup = NULL; cc_int32 shape[CC_CNN2D_SHAPE] = {0}; #ifdef ENABLE_CC_ASSERT @@ -33,13 +34,17 @@ cc_tensor_t *cc_fully_connected(const cc_tensor_t *inp, shape[CC_CNN2D_SHAPE_W] = 1; oup = cc_create(shape, *inp->dtype, name); } + dtsize = cc_dtype_size(*inp->dtype); + mmsize = inp->shape[CC_CNN2D_SHAPE_C] * dtsize; +#ifdef ENABLE_OPENMP + #pragma omp parallel for private(i) +#endif + for (i = 0; i < w->shape[CC_CONV2D_KERNEL_O]; ++i) { + _array_dot_prod(inp->data, w->data + i * mmsize, + inp->shape[CC_CNN2D_SHAPE_C], + oup->data + i * dtsize, *inp->dtype); + } if (b) - _fully_connected(inp->data, oup->data, - w->data, b->data, w->shape[CC_CONV2D_KERNEL_I], - w->shape[CC_CONV2D_KERNEL_O], *inp->dtype); - else - _fully_connected(inp->data, oup->data, - w->data, NULL, w->shape[CC_CONV2D_KERNEL_I], - w->shape[CC_CONV2D_KERNEL_O], *inp->dtype); + oup = cc_fmap2d_bias(oup, b, oup->name); return oup; } diff --git a/src/global_fn_cfg.c b/src/global_fn_cfg.c index e72b6af..49ebfba 100644 --- a/src/global_fn_cfg.c +++ b/src/global_fn_cfg.c @@ -18,8 +18,9 @@ fn_array_sub_ew _array_sub_ew = cc_cpu_array_sub_ew; fn_array_mul_ew _array_mul_ew = cc_cpu_array_mul_ew; fn_array_div_ew _array_div_ew = cc_cpu_array_div_ew; -fn_array_sum _array_sum = cc_cpu_array_sum; -fn_array_mean _array_mean = cc_cpu_array_mean; +fn_array_dot_prod _array_dot_prod = cc_cpu_array_dot_prod; +fn_array_sum _array_sum = cc_cpu_array_sum; +fn_array_mean _array_mean = cc_cpu_array_mean; #define GLOBAL_FN_SET_ARRAY_CAST(dtype) \ fn_array_cast_ ## dtype _array_cast_ ## dtype = \ @@ -43,5 +44,4 @@ fn_activation_softmax _activation_softmax = cc_cpu_activation_softmax; fn_conv2d _conv2d = cc_cpu_conv2d; fn_max_pool2d _max_pool2d = cc_cpu_max_pool2d; fn_max_pool2d _avg_pool2d = cc_cpu_avg_pool2d; -fn_fully_connected _fully_connected = cc_cpu_fully_connected; fn_batch_norm _batch_norm = cc_cpu_batch_norm; diff --git a/src/global_fn_cfg.h b/src/global_fn_cfg.h index 40eabd0..d882b92 100644 --- a/src/global_fn_cfg.h +++ b/src/global_fn_cfg.h @@ -35,6 +35,9 @@ typedef void (*fn_array_mul_ew)(void *oup, typedef void (*fn_array_div_ew)(void *oup, int arrlen, const void *a, const void *b, int dt); +typedef void (*fn_array_dot_prod)( + const void *a, const void *b, int arrlen, void *x, int dt); + typedef void (*fn_array_sum )( const void *arr, int arrlen, void *x, int dt); typedef void (*fn_array_mean)( @@ -74,10 +77,6 @@ typedef void (*fn_conv2d)(const void *inp, void *oup, cc_int32 sx, cc_int32 sy, const void *filter, cc_int32 fw, cc_dtype dt); -typedef void (*fn_fully_connected)(const void *inp, - void *oup, const void *w, const void *b, - cc_int32 iw, cc_int32 ow, cc_dtype dt); - typedef void (*fn_batch_norm)(void *inp, cc_int32 len, const void *bnpara, cc_dtype dt);