From dda72a02e32f24eadc859ccf7bb1f4d25dc5ade0 Mon Sep 17 00:00:00 2001 From: i-evi Date: Fri, 25 Sep 2020 01:24:30 +0800 Subject: [PATCH] load bn --- src/cc_normfn.c | 39 +++++++++++++++++++++++++++++++++++++++ src/cc_normfn.h | 12 ++++++++++++ 2 files changed, 51 insertions(+) diff --git a/src/cc_normfn.c b/src/cc_normfn.c index 17839a7..24d3e91 100644 --- a/src/cc_normfn.c +++ b/src/cc_normfn.c @@ -8,6 +8,45 @@ #include "global_fn_cfg.h" extern fn_batch_norm _batch_norm; +static cc_float32 cc_bn_epsilon_dfl_fp32 = CC_BN_EPSILON_DFL_FP32; + +cc_tensor_t *cc_load_bin_bnpara(const char *w_path, const char *b_path, + const char *m_path, const char *v_path, const char *e_path, + cc_int32 nchl, cc_dtype dtype, const char *name) +{ + static cc_int32 shape[] = {0, 1, 1, 0}; + cc_int32 i; + cc_tensor_t *para; + cc_tensor_t *tsr[CC_BN_PARAMETERS]; + shape[0] = nchl; + tsr[CC_BN_OFFSET_GAMMA] = cc_load_bin(w_path, shape, dtype, NULL); + tsr[CC_BN_OFFSET_BETA] = cc_load_bin(b_path, shape, dtype, NULL); + tsr[CC_BN_OFFSET_MEAN] = cc_load_bin(m_path, shape, dtype, NULL); + tsr[CC_BN_OFFSET_VAR] = cc_load_bin(v_path, shape, dtype, NULL); + if (e_path) { + tsr[CC_BN_OFFSET_EPSILON] = + cc_load_bin(e_path, shape, dtype, NULL); + } else { + tsr[CC_BN_OFFSET_EPSILON] = + cc_create_tensor(shape, dtype, NULL); + switch (dtype) { + case CC_FLOAT32: + cc_set_tensor(tsr[CC_BN_OFFSET_EPSILON], + &cc_bn_epsilon_dfl_fp32); + break; + default: + utlog_format(UTLOG_ERR, + "[%s: %d] Unsupported dtype %x\n", + __FILE__, __LINE__, dtype); + break; + } + } + para = cc_tensor_stack(tsr, CC_BN_PARAMETERS, 1, name); + for (i = 0; i < CC_BN_PARAMETERS; ++i) + cc_free_tensor(tsr[i]); + return para; +} + cc_tensor_t *cc_batch_norm2d(cc_tensor_t *inp, const cc_tensor_t *para, const char *name) { diff --git a/src/cc_normfn.h b/src/cc_normfn.h index 5343d9d..3d26495 100644 --- a/src/cc_normfn.h +++ b/src/cc_normfn.h @@ -28,6 +28,18 @@ enum cc_batch_norm_paraoff { }; #endif +#define CC_BN_EPSILON_DFL_FP32 1e-3 + +cc_tensor_t *cc_load_bin_bnpara( + const char *w_path, /* Gamma */ + const char *b_path, /* Beta */ + const char *m_path, /* Mean */ + const char *v_path, /* Var */ + const char *e_path, /* Epsilon */ + cc_int32 nchl, /* Channel */ + cc_dtype dtype, + const char *name); + cc_tensor_t *cc_batch_norm2d(cc_tensor_t *inp, const cc_tensor_t *para, const char *name);