Skip to content

Commit

Permalink
load bn
Browse files Browse the repository at this point in the history
  • Loading branch information
i-evi committed Sep 24, 2020
1 parent 1eb9643 commit dda72a0
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 0 deletions.
39 changes: 39 additions & 0 deletions src/cc_normfn.c
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,45 @@
#include "global_fn_cfg.h"
extern fn_batch_norm _batch_norm;

static cc_float32 cc_bn_epsilon_dfl_fp32 = CC_BN_EPSILON_DFL_FP32;

cc_tensor_t *cc_load_bin_bnpara(const char *w_path, const char *b_path,
const char *m_path, const char *v_path, const char *e_path,
cc_int32 nchl, cc_dtype dtype, const char *name)
{
static cc_int32 shape[] = {0, 1, 1, 0};
cc_int32 i;
cc_tensor_t *para;
cc_tensor_t *tsr[CC_BN_PARAMETERS];
shape[0] = nchl;
tsr[CC_BN_OFFSET_GAMMA] = cc_load_bin(w_path, shape, dtype, NULL);
tsr[CC_BN_OFFSET_BETA] = cc_load_bin(b_path, shape, dtype, NULL);
tsr[CC_BN_OFFSET_MEAN] = cc_load_bin(m_path, shape, dtype, NULL);
tsr[CC_BN_OFFSET_VAR] = cc_load_bin(v_path, shape, dtype, NULL);
if (e_path) {
tsr[CC_BN_OFFSET_EPSILON] =
cc_load_bin(e_path, shape, dtype, NULL);
} else {
tsr[CC_BN_OFFSET_EPSILON] =
cc_create_tensor(shape, dtype, NULL);
switch (dtype) {
case CC_FLOAT32:
cc_set_tensor(tsr[CC_BN_OFFSET_EPSILON],
&cc_bn_epsilon_dfl_fp32);
break;
default:
utlog_format(UTLOG_ERR,
"[%s: %d] Unsupported dtype %x\n",
__FILE__, __LINE__, dtype);
break;
}
}
para = cc_tensor_stack(tsr, CC_BN_PARAMETERS, 1, name);
for (i = 0; i < CC_BN_PARAMETERS; ++i)
cc_free_tensor(tsr[i]);
return para;
}

cc_tensor_t *cc_batch_norm2d(cc_tensor_t *inp,
const cc_tensor_t *para, const char *name)
{
Expand Down
12 changes: 12 additions & 0 deletions src/cc_normfn.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,18 @@ enum cc_batch_norm_paraoff {
};
#endif

#define CC_BN_EPSILON_DFL_FP32 1e-3

cc_tensor_t *cc_load_bin_bnpara(
const char *w_path, /* Gamma */
const char *b_path, /* Beta */
const char *m_path, /* Mean */
const char *v_path, /* Var */
const char *e_path, /* Epsilon */
cc_int32 nchl, /* Channel */
cc_dtype dtype,
const char *name);

cc_tensor_t *cc_batch_norm2d(cc_tensor_t *inp,
const cc_tensor_t *para, const char *name);

Expand Down

0 comments on commit dda72a0

Please sign in to comment.