Skip to content

Commit

Permalink
examples: rnn: fix incorrect data type in int8 example
Browse files Browse the repository at this point in the history
  • Loading branch information
irinasok committed Feb 7, 2019
1 parent b5f1f92 commit 4661796
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions examples/simple_rnn_int8.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ void compute_attention(float *context_vectors, int src_seq_length_max,
int batch, int feature_size, int8_t *weights_src_layer,
float weights_src_layer_scale, int32_t *compensation,
uint8_t *dec_src_layer, float dec_src_layer_scale,
float dec_src_layer_shift, float *annotations,
float dec_src_layer_shift, uint8_t *annotations,
float *weighted_annotations, float *weights_alignments) {
// dst_iter : (n, c) matrix
// src_layer: (n, c) matrix
Expand Down Expand Up @@ -168,7 +168,10 @@ void compute_attention(float *context_vectors, int src_seq_length_max,
context_vectors[i * (feature_size + feature_size) + feature_size
+ j]
+= alignments[k * batch + i]
* annotations[j + feature_size * (i + batch * k)];
* (((float)annotations[j
+ feature_size * (i + batch * k)]
- dec_src_layer_shift)
/ dec_src_layer_scale);
}

void copy_context(float *src_iter, int n_layers, int n_states, int batch,
Expand Down Expand Up @@ -670,7 +673,7 @@ void simple_net() {
feature_size, user_weights_attention_src_layer.data(),
weights_attention_scale, weights_attention_sum_rows.data(),
src_att_layer_handle, data_scale, data_shift,
(float *)enc_bidir_dst_layer_memory.get_data_handle(),
(uint8_t *)enc_bidir_dst_layer_memory.get_data_handle(),
weighted_annotations.data(),
user_weights_alignments.data());

Expand Down

0 comments on commit 4661796

Please sign in to comment.