Skip to content

Commit

Permalink
update fn
Browse files Browse the repository at this point in the history
  • Loading branch information
i-evi committed Oct 13, 2020
1 parent 292c0fb commit a2bd8eb
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 71 deletions.
80 changes: 28 additions & 52 deletions src/cc_cpufn.c
Original file line number Diff line number Diff line change
Expand Up @@ -410,58 +410,6 @@ void cc_cpu_conv2d(const void *inp, void *oup, cc_int32 x, cc_int32 y,
}
}

/*
 * Loop body shared by the typed fully-connected kernels.
 *
 * Expands to a single `for` statement, so it may directly follow an
 * OpenMP `parallel for` pragma. The expansion site must provide
 * inp, oup, w, b, iw, ow and the loop counters i and j.
 *
 * For every output index i the macro stores in oup[i] the dot product
 * of inp[0 .. iw) with row i of w (elements w[iw * i] onwards), then
 * adds b[i] when b is non-NULL.
 */
#define CC_CPU_FULLY_CONNECTED_PROCESS \
	for (i = 0; i < ow; ++i) { \
		oup[i] = 0; \
		for (j = 0; j < iw; ++j) \
			oup[i] += inp[j] * w[iw * i + j]; \
		if (b) { \
			oup[i] += b[i]; \
		} \
	}

static void cc_cpu_fully_connected_float32(
cc_float32 *inp, cc_float32 *oup, cc_float32 *w,
cc_float32 *b, cc_int32 iw, cc_int32 ow)
{
cc_int32 i, j;
#ifdef ENABLE_OPENMP
#pragma omp parallel for private(i, j)
#endif
CC_CPU_FULLY_CONNECTED_PROCESS;
}

static void cc_cpu_fully_connected_float64(
cc_float64 *inp, cc_float64 *oup, cc_float64 *w,
cc_float64 *b, cc_int32 iw, cc_int32 ow)
{
cc_int32 i, j;
#ifdef ENABLE_OPENMP
#pragma omp parallel for private(i, j)
#endif
CC_CPU_FULLY_CONNECTED_PROCESS;
}

/*
 * Type-dispatching entry point of the CPU fully-connected kernel.
 *
 * Reinterprets the untyped buffers according to dt and forwards them,
 * together with iw and ow, to the matching typed kernel. CC_FLOAT32
 * and CC_FLOAT64 are handled; any other dt is passed to
 * UNSUPPORTED_DTYPE_LOG instead.
 */
void cc_cpu_fully_connected(const void *inp,
	void *oup, const void *w, const void *b,
	cc_int32 iw, cc_int32 ow, cc_dtype dt)
{
	switch (dt) {
	case CC_FLOAT32: {
		cc_float32 *src     = (cc_float32*)inp;
		cc_float32 *dst     = (cc_float32*)oup;
		cc_float32 *weights = (cc_float32*)w;
		cc_float32 *bias    = (cc_float32*)b;

		cc_cpu_fully_connected_float32(src, dst, weights, bias, iw, ow);
		break;
	}
	case CC_FLOAT64: {
		cc_float64 *src     = (cc_float64*)inp;
		cc_float64 *dst     = (cc_float64*)oup;
		cc_float64 *weights = (cc_float64*)w;
		cc_float64 *bias    = (cc_float64*)b;

		cc_cpu_fully_connected_float64(src, dst, weights, bias, iw, ow);
		break;
	}
	default:
		UNSUPPORTED_DTYPE_LOG(dt);
		break;
	}
}

#ifndef CC_BN_OFFSET_CFG
#define CC_BN_OFFSET_CFG
enum {
Expand Down Expand Up @@ -889,6 +837,34 @@ void cc_cpu_array_div_ew(void *oup,
}
}

#define ARRAY_DOTPROD_CASE(_DT, _dt) \
case _DT: \
*((_dt*)x) = 0; \
for (i = 0; i < arrlen; ++i) \
*((_dt*)x) += *(((_dt*)a) + i) * *(((_dt*)b) + i); \
break;
void cc_cpu_array_dot_prod(
const void *a, const void *b, int arrlen, void *x, int dt)
{
cc_int32 i;
switch (dt) {
ARRAY_DOTPROD_CASE(CC_UINT8, cc_uint8);
ARRAY_DOTPROD_CASE(CC_UINT16, cc_uint16);
ARRAY_DOTPROD_CASE(CC_UINT32, cc_uint32);
ARRAY_DOTPROD_CASE(CC_UINT64, cc_uint64);
ARRAY_DOTPROD_CASE(CC_INT8, cc_int8);
ARRAY_DOTPROD_CASE(CC_INT16, cc_int16);
ARRAY_DOTPROD_CASE(CC_INT32, cc_int32);
ARRAY_DOTPROD_CASE(CC_INT64, cc_int64);
ARRAY_DOTPROD_CASE(CC_FLOAT32, cc_float32);
ARRAY_DOTPROD_CASE(CC_FLOAT64, cc_float64);
default:
utlog_format(UTLOG_ERR,
"cc_array: unsupported dtype %x\n", dt);
break;
}
}

#define ARRAY_SUM_CASE(_DT, _dt) \
case _DT: \
ARRAY_SUM(arr, arrlen, _dt, x); \
Expand Down
7 changes: 3 additions & 4 deletions src/cc_cpufn.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,6 @@ void cc_cpu_conv2d(const void *inp, void *oup, cc_int32 x, cc_int32 y,
cc_int32 oup_x, cc_int32 oup_y, cc_int32 sx, cc_int32 sy,
const void *filter, cc_int32 fw, cc_dtype dt);

/*
 * CPU fully-connected kernel.
 *
 * For each of the ow outputs, stores in oup[i] the dot product of the
 * iw values of inp with row i of w (ow rows of iw weights), plus b[i]
 * when b is non-NULL. dt selects the element type of all four buffers;
 * the implementation handles CC_FLOAT32 and CC_FLOAT64.
 */
void cc_cpu_fully_connected(const void *inp,
void *oup, const void *w, const void *b,
cc_int32 iw, cc_int32 ow, cc_dtype dt);

void cc_cpu_batch_norm(void *inp,
cc_int32 len, const void *bnpara, cc_dtype dt);

Expand Down Expand Up @@ -72,6 +68,9 @@ void cc_cpu_array_mul_ew(void *oup,
void cc_cpu_array_div_ew(void *oup,
int arrlen, const void *a, const void *b, int dt);

/*
 * Stores in *x the dot product of arrays a and b (arrlen elements
 * each). dt selects the element type shared by a, b and x; an
 * unsupported dt is logged and *x is not written.
 */
void cc_cpu_array_dot_prod(
const void *a, const void *b, int arrlen, void *x, int dt);

void cc_cpu_array_sum (const void *arr, int arrlen, void *x, int dt);
void cc_cpu_array_mean(const void *arr, int arrlen, void *x, int dt);

Expand Down
21 changes: 13 additions & 8 deletions src/cc_fullycon.c
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@
#include "cc_fullycon.h"

#include "global_fn_cfg.h"
extern fn_fully_connected _fully_connected;
extern fn_array_dot_prod _array_dot_prod;

cc_tensor_t *cc_fully_connected(const cc_tensor_t *inp,
const cc_tensor_t *w, const cc_tensor_t *b, const char *name)
{
cc_int32 i, mmsize, dtsize;
cc_tensor_t *oup = NULL;
cc_int32 shape[CC_CNN2D_SHAPE] = {0};
#ifdef ENABLE_CC_ASSERT
Expand All @@ -33,13 +34,17 @@ cc_tensor_t *cc_fully_connected(const cc_tensor_t *inp,
shape[CC_CNN2D_SHAPE_W] = 1;
oup = cc_create(shape, *inp->dtype, name);
}
dtsize = cc_dtype_size(*inp->dtype);
mmsize = inp->shape[CC_CNN2D_SHAPE_C] * dtsize;
#ifdef ENABLE_OPENMP
#pragma omp parallel for private(i)
#endif
for (i = 0; i < w->shape[CC_CONV2D_KERNEL_O]; ++i) {
_array_dot_prod(inp->data, w->data + i * mmsize,
inp->shape[CC_CNN2D_SHAPE_C],
oup->data + i * dtsize, *inp->dtype);
}
if (b)
_fully_connected(inp->data, oup->data,
w->data, b->data, w->shape[CC_CONV2D_KERNEL_I],
w->shape[CC_CONV2D_KERNEL_O], *inp->dtype);
else
_fully_connected(inp->data, oup->data,
w->data, NULL, w->shape[CC_CONV2D_KERNEL_I],
w->shape[CC_CONV2D_KERNEL_O], *inp->dtype);
oup = cc_fmap2d_bias(oup, b, oup->name);
return oup;
}
6 changes: 3 additions & 3 deletions src/global_fn_cfg.c
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@ fn_array_sub_ew _array_sub_ew = cc_cpu_array_sub_ew;
fn_array_mul_ew _array_mul_ew = cc_cpu_array_mul_ew;
fn_array_div_ew _array_div_ew = cc_cpu_array_div_ew;

fn_array_sum _array_sum = cc_cpu_array_sum;
fn_array_mean _array_mean = cc_cpu_array_mean;
/* Array reductions (dot product, sum, mean): each pointer is initialised to its CPU implementation. */
fn_array_dot_prod _array_dot_prod = cc_cpu_array_dot_prod;
fn_array_sum _array_sum = cc_cpu_array_sum;
fn_array_mean _array_mean = cc_cpu_array_mean;

#define GLOBAL_FN_SET_ARRAY_CAST(dtype) \
fn_array_cast_ ## dtype _array_cast_ ## dtype = \
Expand All @@ -43,5 +44,4 @@ fn_activation_softmax _activation_softmax = cc_cpu_activation_softmax;
/* Layer-level kernels: each pointer is initialised to its CPU implementation. */
fn_conv2d _conv2d = cc_cpu_conv2d;
fn_max_pool2d _max_pool2d = cc_cpu_max_pool2d;
/* NOTE(review): declared with the max-pool pointer type; presumably both
 * pooling kernels share one signature -- confirm. */
fn_max_pool2d _avg_pool2d = cc_cpu_avg_pool2d;
fn_fully_connected _fully_connected = cc_cpu_fully_connected;
fn_batch_norm _batch_norm = cc_cpu_batch_norm;
7 changes: 3 additions & 4 deletions src/global_fn_cfg.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ typedef void (*fn_array_mul_ew)(void *oup,
typedef void (*fn_array_div_ew)(void *oup,
int arrlen, const void *a, const void *b, int dt);

/*
 * Pointer type of the array dot-product backend: writes to *x the dot
 * product of a and b (arrlen elements each, element type selected by
 * dt). Same signature as cc_cpu_array_dot_prod.
 */
typedef void (*fn_array_dot_prod)(
const void *a, const void *b, int arrlen, void *x, int dt);

typedef void (*fn_array_sum )(
const void *arr, int arrlen, void *x, int dt);
typedef void (*fn_array_mean)(
Expand Down Expand Up @@ -74,10 +77,6 @@ typedef void (*fn_conv2d)(const void *inp, void *oup,
cc_int32 sx, cc_int32 sy, const void *filter,
cc_int32 fw, cc_dtype dt);

/*
 * Pointer type of the fully-connected backend: iw inputs in inp,
 * ow outputs in oup, weights in w, optional bias in b, element type
 * selected by dt. Same signature as cc_cpu_fully_connected.
 */
typedef void (*fn_fully_connected)(const void *inp,
void *oup, const void *w, const void *b,
cc_int32 iw, cc_int32 ow, cc_dtype dt);

typedef void (*fn_batch_norm)(void *inp,
cc_int32 len, const void *bnpara, cc_dtype dt);

Expand Down

0 comments on commit a2bd8eb

Please sign in to comment.