diff --git a/docs/build.md b/docs/build.md
index 2e687929..706341e4 100644
--- a/docs/build.md
+++ b/docs/build.md
@@ -5,8 +5,6 @@
 - protobuf >= 3.13
 - cmake >= 3.18
 
-There are submodules in this repository which you should clone with `--recurse-submodules`.
-
 To install cudatoolkit-dev, you could run `conda install -c conda-forge cudatoolkit-dev` or follow the [official guide](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#runfile); the runfile installation with the `--toolkit` arg is recommended.
 
 After installation, check the installation of `nvcc` and static libraries (*.a) in `${CUDA_PATH}/lib64`.
diff --git a/example/transformer_example.cc.cu b/example/transformer_example.cc.cu
index 2ac4df9c..4a5d27ab 100644
--- a/example/transformer_example.cc.cu
+++ b/example/transformer_example.cc.cu
@@ -10,8 +10,13 @@ Example of how to run transformer inference using our implementation.
 */
 
 // Appoint precision.
-const lightseq::cuda::OperationType optype =
+#ifdef FP16_MODE
+const lightseq::cuda::OperationType OPTYPE =
+    lightseq::cuda::OperationType::FP16;
+#else
+const lightseq::cuda::OperationType OPTYPE =
     lightseq::cuda::OperationType::FP32;
+#endif
 
 int main(int argc, char *argv[]) {
   /* ---step1. init environment--- */
@@ -21,10 +26,10 @@ int main(int argc, char *argv[]) {
   cudaStreamCreate(&stream_);
   cublasCreate(&hd_);
   cublasSetStream(hd_, stream_);
-  typedef lightseq::cuda::OperationTypeTraits<optype> optraits;
+  typedef lightseq::cuda::OperationTypeTraits<OPTYPE> optraits;
 
   /* ---step2. load model weights into GPU memory--- */
-  lightseq::cuda::TransformerWeight<optype> tw_;
+  lightseq::cuda::TransformerWeight<OPTYPE> tw_;  // saved in custom proto file
   std::string model_weights_path = argv[1];
   std::string res = tw_.initializing(model_weights_path);
@@ -47,8 +52,8 @@ int main(int argc, char *argv[]) {
       std::vector<optraits::DataType>(
           max_batch_size * tw_._max_step * tw_._hidden_size, 0);
   thrust::device_vector<int> d_output_ =
       std::vector<int>(max_batch_size * tw_._max_step, 0);
-  std::shared_ptr<lightseq::cuda::Encoder<optype>> encoder_ =
-      std::make_shared<lightseq::cuda::Encoder<optype>>(
+  std::shared_ptr<lightseq::cuda::Encoder<OPTYPE>> encoder_ =
+      std::make_shared<lightseq::cuda::Encoder<OPTYPE>>(
           max_batch_size,
           reinterpret_cast<int *>(thrust::raw_pointer_cast(d_input_.data())),
           reinterpret_cast<optraits::DataType *>(
              thrust::raw_pointer_cast(
@@ -62,15 +67,16 @@ int main(int argc, char *argv[]) {
     return 1;
   }
   // instantiate decoder
-  std::shared_ptr<lightseq::cuda::Decoder<optype>> decoder_ =
-      std::make_shared<lightseq::cuda::Decoder<optype>>(
+  std::shared_ptr<lightseq::cuda::Decoder<OPTYPE>> decoder_ =
+      std::make_shared<lightseq::cuda::Decoder<OPTYPE>>(
           max_batch_size,
           reinterpret_cast<int *>(
               thrust::raw_pointer_cast(d_padding_mask_.data())),
           reinterpret_cast<optraits::DataType *>(
              thrust::raw_pointer_cast(d_encoder_output_.data())),
           reinterpret_cast<int *>(thrust::raw_pointer_cast(d_output_.data())),
-          tw_, stream_, hd_);
+          tw_, stream_, hd_, false,
+          reinterpret_cast<int *>(thrust::raw_pointer_cast(d_input_.data())));
   res = decoder_->check();
   if (!res.empty()) {
     std::cout << res << std::endl;
@@ -104,7 +110,7 @@ int main(int argc, char *argv[]) {
                                      batch_seq_len, host_input);
 
   /* ---step5. infer and log--- */
-  for (int i = 0; i < 10; i++) {
+  for (int i = 0; i < 1; i++) {
     auto start = std::chrono::high_resolution_clock::now();
     // copy inputs from cpu memory to gpu memory
     cudaMemcpyAsync(
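Note on the `OPTYPE` / `optraits` pattern above: `OperationTypeTraits` maps the compile-time `OperationType` constant to the concrete device types used everywhere below. A minimal sketch of such a traits mapping, for orientation only (LightSeq's real definition also carries the cuBLAS enums `_AType`/`_BType`/`_CType`/`_computeType`):

```cpp
#include <cuda_fp16.h>

// Illustrative sketch, not the library's actual definition.
enum class OperationType { FP32, FP16 };

template <OperationType OpType>
struct OperationTypeTraits;

template <>
struct OperationTypeTraits<OperationType::FP32> {
  using DataType = float;  // device buffers hold float in FP32 mode
};

template <>
struct OperationTypeTraits<OperationType::FP16> {
  using DataType = __half;  // device buffers hold half in FP16 mode
};
```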
diff --git a/kernels/CMakeLists.txt b/kernels/CMakeLists.txt
index a3e6376b..95654447 100644
--- a/kernels/CMakeLists.txt
+++ b/kernels/CMakeLists.txt
@@ -1,6 +1,6 @@
 cmake_minimum_required(VERSION 3.18)
 
-set(cuda_kernel_files gptKernels.cc.cu transformerKernels.cc.cu)
+set(cuda_kernel_files gptKernels.cc.cu transformerKernels.cc.cu multilgKernels.cc.cu)
 
 add_library(cuda_kernels STATIC ${cuda_kernel_files})
 target_include_directories(cuda_kernels INTERFACE ${CMAKE_CURRENT_SOURCE_DIR})
diff --git a/kernels/gptKernels.h b/kernels/gptKernels.h
index 86fd0e20..d08f962e 100644
--- a/kernels/gptKernels.h
+++ b/kernels/gptKernels.h
@@ -17,6 +17,7 @@ void ker_gpt_embedding_launcher(int batch_size, int batch_seq_len,
                                 int pos_offset);
 
+template <typename T>
 void ker_correlation_softmax_gpt_launcher(int batch_size, int batch_seq_len,
                                           int head_num, cudaStream_t stream,
diff --git a/kernels/multilgKernels.cc.cu b/kernels/multilgKernels.cc.cu
new file mode 100644
index 00000000..94742e93
--- /dev/null
+++ b/kernels/multilgKernels.cc.cu
@@ -0,0 +1,451 @@
+#include <cuda_fp16.h>
+
+#include "common.h"
+#include "multilgKernels.h"
+#include "transformerKernels.h"
+/**
+@file
+Implements the CUDA kernel functions and their launchers
+required by the multilingual NMT model.
+Currently, fp16 and fp32 versions are provided.
+*/
+namespace lightseq {
+namespace cuda {
+/**
+@brief: ker_multilg_enc_emb
+for encoder, look up token embedding, add position embedding
+and source language embedding
+
+@thread
+gridDim.x = batch_size
+gridDim.y = batch_seq_len
+blockDim.x = max_thread_per_block
+
+@param
+token_emb: [vocab_size, hidden_size]
+pos_emb: [max_step, hidden_size]
+src_lang_emb: [lang_num, hidden_size]
+token_id: input token id, [batch_size, batch_seq_len]
+output: result, [batch_size, batch_seq_len, hidden_size]
+padding_mask: record the padding token, [batch_size, batch_seq_len]
+padding_id: the padding token id
+*/
+template <typename T>
+__global__ void ker_multilg_enc_emb(const T* token_emb, const T* pos_emb,
+                                    const T* src_lang_emb,
+                                    const int* token_id, T* output,
+                                    int* padding_mask, int padding_id,
+                                    const int hidden_size) {
+  int target_pos = blockIdx.x * gridDim.y + blockIdx.y;
+  int start = target_pos * hidden_size + threadIdx.x;
+  int end = (target_pos + 1) * hidden_size;
+  int tid = token_id[target_pos];
+  // the first token of every source sequence is its language id
+  int lang_id = token_id[blockIdx.x * gridDim.y];
+  if (tid == padding_id) {
+    // for padding id
+    if (threadIdx.x == 0) padding_mask[target_pos] = 1;
+    for (uint i = start; i < end; i += blockDim.x) {
+      output[i] = 0.f;
+    }
+    return;
+  }
+  if (threadIdx.x == 0) {
+    padding_mask[target_pos] = 0;
+  }
+  for (uint i = start; i < end; i += blockDim.x) {
+    int offset = i - target_pos * hidden_size;
+    output[i] = token_emb[tid * hidden_size + offset] +
+                pos_emb[blockIdx.y * hidden_size + offset] +
+                src_lang_emb[lang_id * hidden_size + offset];
+  }
+}
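The indexing in `ker_multilg_enc_emb` is easiest to verify against a scalar host reference. The sketch below restates the kernel's math for one (batch, position) pair; it is not LightSeq API, just the same arithmetic on the CPU. The language id comes from the first token of the sequence:

```cpp
#include <vector>

// CPU reference for ker_multilg_enc_emb at one sequence position.
// token_id layout: [batch_size, batch_seq_len]; the first token of each
// source sequence is its language id.
void enc_emb_reference(const std::vector<float>& token_emb,     // [vocab, hidden]
                       const std::vector<float>& pos_emb,       // [max_step, hidden]
                       const std::vector<float>& src_lang_emb,  // [lang_num, hidden]
                       const std::vector<int>& token_id, int batch_seq_len,
                       int hidden_size, int batch_idx, int pos,
                       std::vector<float>& out) {  // [batch, seq_len, hidden]
  int target_pos = batch_idx * batch_seq_len + pos;
  int tid = token_id[target_pos];
  int lang_id = token_id[batch_idx * batch_seq_len];  // first token = lang id
  for (int d = 0; d < hidden_size; ++d) {
    out[target_pos * hidden_size + d] = token_emb[tid * hidden_size + d] +
                                        pos_emb[pos * hidden_size + d] +
                                        src_lang_emb[lang_id * hidden_size + d];
  }
}
```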
+
+template <>
+__global__ void ker_multilg_enc_emb<__half>(const __half* token_emb,
+                                            const __half* pos_emb,
+                                            const __half* src_lang_emb,
+                                            const int* token_id, __half* output,
+                                            int* padding_mask, int padding_id,
+                                            const int half_hidden_size) {
+  int target_pos = blockIdx.x * gridDim.y + blockIdx.y;
+  int start = target_pos * half_hidden_size + threadIdx.x;
+  int end = (target_pos + 1) * half_hidden_size;
+  int tid = token_id[target_pos];
+  int lang_id = token_id[blockIdx.x * gridDim.y];
+  half2* output_h = (half2*)output;
+
+  if (tid == padding_id) {
+    // for padding id
+    if (threadIdx.x == 0) padding_mask[target_pos] = 1;
+    for (uint i = start; i < end; i += blockDim.x) {
+      output_h[i] = __float2half2_rn(0.f);
+    }
+    return;
+  }
+  if (threadIdx.x == 0) {
+    padding_mask[target_pos] = 0;
+  }
+  for (uint i = start; i < end; i += blockDim.x) {
+    int offset = i - target_pos * half_hidden_size;
+    float2 te = __half22float2(
+        ((const half2*)token_emb)[tid * half_hidden_size + offset]);
+    float2 pe = __half22float2(
+        ((const half2*)pos_emb)[blockIdx.y * half_hidden_size + offset]);
+    float2 le = __half22float2(
+        ((const half2*)src_lang_emb)[lang_id * half_hidden_size + offset]);
+    te.x = te.x + pe.x + le.x;
+    te.y = te.y + pe.y + le.y;
+
+    output_h[i] = __float22half2_rn(te);
+  }
+}
+
+template <typename T>
+void ker_multilg_enc_emb_launcher(int batch_size, int batch_seq_len,
+                                  int hidden_size, cudaStream_t stream,
+                                  const T* token_emb, const T* pos_emb,
+                                  const T* src_lang_emb,
+                                  const int* token_id, T* output,
+                                  int* padding_mask, int padding_id,
+                                  int max_thread_per_block) {
+  ker_multilg_enc_emb<T>
+      <<<dim3(batch_size, batch_seq_len), max_thread_per_block, 0, stream>>>(
+          token_emb, pos_emb, src_lang_emb, token_id, output, padding_mask,
+          padding_id, hidden_size);
+}
+
+template <>
+void ker_multilg_enc_emb_launcher<__half>(int batch_size, int batch_seq_len,
+                                          int hidden_size, cudaStream_t stream,
+                                          const __half* token_emb,
+                                          const __half* pos_emb,
+                                          const __half* src_lang_emb,
+                                          const int* token_id, __half* output,
+                                          int* padding_mask, int padding_id,
+                                          int max_thread_per_block) {
+  ker_multilg_enc_emb<__half>
+      <<<dim3(batch_size, batch_seq_len), max_thread_per_block, 0, stream>>>(
+          token_emb, pos_emb, src_lang_emb, token_id, output, padding_mask,
+          padding_id, hidden_size / 2);
+}
+
+template void ker_multilg_enc_emb_launcher<float>(
+    int batch_size, int batch_seq_len, int hidden_size, cudaStream_t stream,
+    const float* token_emb, const float* pos_emb, const float* src_lang_emb,
+    const int* token_id, float* output, int* padding_mask, int padding_id,
+    int max_thread_per_block);
+
+template void ker_multilg_enc_emb_launcher<__half>(
+    int batch_size, int batch_seq_len, int hidden_size, cudaStream_t stream,
+    const __half* token_emb, const __half* pos_emb, const __half* src_lang_emb,
+    const int* token_id, __half* output, int* padding_mask, int padding_id,
+    int max_thread_per_block);
+
+/**
+@brief: ker_multilg_dec_emb
+for multilingual decoder, look up token embedding, add position embedding
+and language embedding
+
+@thread
+gridDim.x = batch_size * beam_size
+blockDim.x = max_thread_per_block
+
+@param
+token_emb: [hidden_size, vocab_size], note, it is different from the encoder
+pos_emb: [max_step, hidden_size]
+src_lang_emb: [lang_num, hidden_size]
+trg_lang_emb: [lang_num, hidden_size]
+src_token_id: [batch_size, src_seq_len]
+token_id: input token id, [batch_size, beam_size, max_step]
+output: result, [batch_size, beam_size, hidden_size]
+step: current step
+max_step: max decoder steps
+vocab_size: vocabulary size
+*/
+template <typename T>
+__global__ void ker_multilg_dec_emb(const T* token_emb, const T* pos_emb,
+                                    const T* src_lang_emb, const T* trg_lang_emb,
+                                    const int* src_token_id,
+                                    const int* token_id, T* output, int step,
+                                    int max_step, int vocab_size,
+                                    int hidden_size, int beam_size,
+                                    int src_seq_len) {
+  int batch_id = blockIdx.x / beam_size;
+  // src seq is in [src_lang_id, trg_lang_id, tokens...] format
+  int src_lang_id = src_token_id[batch_id * src_seq_len];
+  int trg_lang_id = src_token_id[batch_id * src_seq_len + 1];
+  // at step 0 the target language id serves as the start token
+  int token_idx = (step == 0 ? trg_lang_id
+                             : token_id[blockIdx.x * max_step + step]);
+  for (uint offset = threadIdx.x; offset < hidden_size; offset += blockDim.x) {
+    output[blockIdx.x * hidden_size + offset] =
+        token_emb[offset * vocab_size + token_idx] +
+        pos_emb[step * hidden_size + offset] +
+        src_lang_emb[src_lang_id * hidden_size + offset] +
+        trg_lang_emb[trg_lang_id * hidden_size + offset];
+  }
+}
+
+template <typename T>
+void ker_multilg_dec_emb_launcher(int step_token_num, int hidden_size,
+                                  cudaStream_t stream, const T* token_emb,
+                                  const T* pos_emb, const T* src_lang_emb,
+                                  const T* trg_lang_emb, const int* src_token_id,
+                                  const int* token_id, T* output, int step,
+                                  int max_step, int vocab_size, int beam_size,
+                                  int src_seq_len, int max_thread_per_block) {
+  ker_multilg_dec_emb<T><<<step_token_num, max_thread_per_block, 0, stream>>>(
+      token_emb, pos_emb, src_lang_emb, trg_lang_emb, src_token_id,
+      token_id, output, step, max_step, vocab_size,
+      hidden_size, beam_size, src_seq_len);
+}
+
+template void ker_multilg_dec_emb_launcher<float>(
+    int step_token_num, int hidden_size, cudaStream_t stream,
+    const float* token_emb, const float* pos_emb,
+    const float* src_lang_emb, const float* trg_lang_emb,
+    const int* src_token_id, const int* token_id,
+    float* output, int step, int max_step,
+    int vocab_size, int beam_size, int src_seq_len, int max_thread_per_block);
+
+template void ker_multilg_dec_emb_launcher<__half>(
+    int step_token_num, int hidden_size, cudaStream_t stream,
+    const __half* token_emb, const __half* pos_emb,
+    const __half* src_lang_emb, const __half* trg_lang_emb,
+    const int* src_token_id, const int* token_id,
+    __half* output, int step, int max_step,
+    int vocab_size, int beam_size, int src_seq_len, int max_thread_per_block);
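The decoder kernel above leans on an input convention worth spelling out: every source sequence starts with `[src_lang_id, trg_lang_id, tokens...]`, and at step 0 the target language id doubles as the start token. An illustrative layout (token values are made up):

```cpp
#include <vector>

// Illustrative batch: batch_size = 2, src_seq_len = 6, 0 = padding id.
// Every source sequence begins with the source and target language ids.
const int src_seq_len = 6;
const std::vector<int> src_token_id = {
    /*seq 0*/ 3, 7, 120, 45, 9, 0,
    /*seq 1*/ 2, 7, 88, 301, 0, 0,
};
// ker_multilg_dec_emb recovers the ids per batch item as:
//   int src_lang_id = src_token_id[batch_id * src_seq_len];      // 3 or 2
//   int trg_lang_id = src_token_id[batch_id * src_seq_len + 1];  // 7
```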
+
+/**
+@brief: select_beam_rough_topk_multilg
+one block for one beam, compute the log seq probability ended with every token
+in vocab, based on the previous log seq probability and current step's logit;
+select rough topK candidates.
+
+@thread
+gridDim.x = batch_size * beam_size
+blockDim.x = max_thread_per_block
+
+@param
+logits: [batch_size, beam_size, vocab_size], cur step logit
+logit_bias: [vocab_size], logit bias
+seq_probs: [batch_size, beam_size], prefix sequence log probability
+seq_score: [batch_size, beam_size], prefix sequence score
+alive_seq: [batch_size, beam_size, max_step], prefix sequence id
+vocab_mask: [lang_num, vocab_size], 0/1 target vocabulary mask per language
+src_token_id: [batch_size, src_seq_len], used to read the target language id
+can_idx: [batch_size, beam_size, vocab_size], topk candidate's index
+can_score: [batch_size, beam_size, vocab_size], topk candidate's score
+num_beam_can: [1 + batch_size * beam_size].
+    the first ele saves the number of topk candidates of the whole batch
+    the remaining batch_size * beam_size ele save the number of topk candidates
+    of each beam
+vocab_size: the vocab size of decoder
+max_step: max decode step
+length_norm: length penalty value for current step
+cur_step: current step
+diverse_lambda: lambda for diverse beam search
+*/
+template <typename T, int beam_size>
+__global__ void select_beam_rough_topk_multilg(
+    const T* logits, const T* logit_bias, const float* seq_probs,
+    const float* seq_score, const int* alive_seq,
+    const int* vocab_mask, const int* src_token_id, int* can_idx,
+    float* can_score, int* num_beam_can, int vocab_size, int max_step,
+    float length_norm, int cur_step,
+    float diverse_lambda, int end_id, int src_seq_len) {
+  if (alive_seq[blockIdx.x * max_step + cur_step] == end_id) {
+    // this is a finished beam
+    if (threadIdx.x == 0) {
+      num_beam_can[blockIdx.x + 1] = 1;      // generate one candidate
+      int pos = atomicAdd(num_beam_can, 1);  // get a candidate pos
+      if (diverse_lambda == 0) {
+        can_score[pos] =
+            seq_score[blockIdx.x];  // this beam's score will not be changed
+      } else {
+        // add the beam id offset in score to sort in each beam
+        int batch_id = blockIdx.x / beam_size;
+        can_score[pos] = seq_score[blockIdx.x] +
+                         (blockIdx.x - batch_id) * min_log_probability;
+      }
+      can_idx[pos] = end_id + (blockIdx.x % beam_size) * vocab_size;  // EOS
+    }
+    return;
+  }
+
+  /* step1: compute each thread's max_logit and sum_exp_logit, store in
+   * rough_top_kth_logit, sum_exp_logit */
+  int batch_id = blockIdx.x / beam_size;
+  int trg_lang_id = src_token_id[batch_id * src_seq_len + 1];
+  const int block_start = blockIdx.x * vocab_size;
+  const int left_idx = block_start + threadIdx.x;
+  const int right_idx = (blockIdx.x + 1) * vocab_size;
+  float rough_top_kth_logit = CUDA_FLOAT_INF_NEG;
+  float sum_exp_logit = 0;
+  for (int i = left_idx; i < right_idx; i += blockDim.x) {
+    int lang_mask = vocab_mask[trg_lang_id * vocab_size + i - block_start];
+    float lgt =
+        (lang_mask == 0
+             ? CUDA_FLOAT_INF_NEG
+             : (float)logits[i] + (float)__ldg(&logit_bias[i - block_start]));
+    rough_top_kth_logit = fmaxf(rough_top_kth_logit, lgt);
+  }
+  float max_logit = blockReduceMax(rough_top_kth_logit);
+  __shared__ float s_max_logit;
+  if (threadIdx.x == 0) {
+    s_max_logit = max_logit;
+  }
+  __syncthreads();
+  for (int i = left_idx; i < right_idx; i += blockDim.x) {
+    int lang_mask = vocab_mask[trg_lang_id * vocab_size + i - block_start];
+    float lgt =
+        lang_mask == 0
+            ? 0.f
+            : expf(fmaxf((float)(logits[i]) +
+                             (float)__ldg(&logit_bias[i - block_start]) -
+                             s_max_logit,
+                         logit_thresh_min));
+    sum_exp_logit += lgt;
+  }
+
+  /*
+  step2: compute rough top-kth-logits and sum_exp_logit among the whole beam,
+  saved into s_topk and s_log_prob_base
+  */
+  __shared__ float
+      s_log_prob_base;      // prefix sequence log prob - log_sum_exp_logit
+  __shared__ float s_topk;  // rough top k-th value of logits
+  __shared__ int num_cur_beam_can;  // candidate number for this beam
+  sum_exp_logit = blockReduceSum(sum_exp_logit);
+  rough_top_kth_logit = blockRoughTopK<float, beam_size>(rough_top_kth_logit);
+  if (threadIdx.x == 0) {
+    s_log_prob_base = seq_probs[blockIdx.x] - logf(sum_exp_logit) - s_max_logit;
+    s_topk = rough_top_kth_logit;
+    num_cur_beam_can = 0;
+  }
+
+  /*
+  step3: select the candidate token with logits bigger than s_topk,
+  compute the seq probability ended with them,
+  save the probability, token_index, selected token number.
+  */
+ */ + int idx = left_idx; + int batch_start_pos = batch_id * beam_size * vocab_size; + // int unk_vocab_id = vocab_size - 3; // last three element: unk, start, eos + __shared__ int l_n; // current iteration candidate number + for (int iter = 0; iter < (vocab_size + blockDim.x - 1) / blockDim.x; + iter++) { + // zero the counter + if (threadIdx.x == 0) l_n = 0; + __syncthreads(); + + float lgt = CUDA_FLOAT_INF_NEG - 1.f; // min s_topk is CUDA_FLOAT_INF_NEG + int pos; + int vocab_id = idx - block_start; + + // if ((vocab_id < vocab_size) && (vocab_id != unk_vocab_id)) { + if (vocab_id < vocab_size) { + int lang_mask = vocab_mask[trg_lang_id * vocab_size + vocab_id]; + if (lang_mask != 0) { + lgt = (float)(logits[idx]) + (float)__ldg(&logit_bias[vocab_id]); + if (lgt >= s_topk) + // pos: relative pos inside this iteration + pos = atomicAdd(&l_n, 1); + } + } + __syncthreads(); + + // leader increments the global counter + if (threadIdx.x == 0) { + atomicAdd(&num_cur_beam_can, l_n); + l_n = atomicAdd(num_beam_can, l_n); + } + __syncthreads(); + + // threads with true predicates write their elements + if ((lgt >= s_topk)) { + pos += l_n; // increment local pos by global counter + if (diverse_lambda == 0) { + can_score[pos] = fmaxf((lgt + s_log_prob_base) * length_norm, + min_log_probability + 1.f) + + batch_id * min_log_probability; + } else { + can_score[pos] = fmaxf((lgt + s_log_prob_base) * length_norm, + min_log_probability + 1.f) + + blockIdx.x * min_log_probability; + } + can_idx[pos] = idx - batch_start_pos; + } + __syncthreads(); + idx += blockDim.x; + } + if (threadIdx.x == 0) { + num_beam_can[blockIdx.x + 1] = num_cur_beam_can; + } +} + +template +void select_beam_rough_topk_multilg_launcher( + const T* logits, const T* logit_bias, const float* seq_probs, + const float* seq_score, const int* alive_seq, + const int* vocab_mask, const int* src_token_id, + int* can_idx, float* can_score, int* num_beam_can, int vocab_size, int max_step, + float length_norm, int cur_step, int step_token_num, + int max_thread_per_block, cudaStream_t stream, int beam_size, + float diverse_lambda, int end_id, int src_seq_len) { + if (beam_size == 1) + select_beam_rough_topk_multilg + <<>>( + logits, logit_bias, seq_probs, seq_score, alive_seq, + vocab_mask, src_token_id, can_idx, can_score, num_beam_can, + vocab_size, max_step, length_norm, cur_step, + diverse_lambda, end_id, src_seq_len); + if (beam_size == 2) + select_beam_rough_topk_multilg + <<>>( + logits, logit_bias, seq_probs, seq_score, alive_seq, + vocab_mask, src_token_id, can_idx, can_score, num_beam_can, + vocab_size, max_step, length_norm, cur_step, + diverse_lambda, end_id, src_seq_len); + if (beam_size == 4) + select_beam_rough_topk_multilg + <<>>( + logits, logit_bias, seq_probs, seq_score, alive_seq, + vocab_mask, src_token_id, can_idx, can_score, num_beam_can, + vocab_size, max_step, length_norm, cur_step, + diverse_lambda, end_id, src_seq_len); + if (beam_size == 8) + select_beam_rough_topk_multilg + <<>>( + logits, logit_bias, seq_probs, seq_score, alive_seq, + vocab_mask, src_token_id, can_idx, can_score, num_beam_can, + vocab_size, max_step, length_norm, cur_step, + diverse_lambda, end_id, src_seq_len); + if (beam_size == 16) + select_beam_rough_topk_multilg + <<>>( + logits, logit_bias, seq_probs, seq_score, alive_seq, + vocab_mask, src_token_id, can_idx, can_score, num_beam_can, + vocab_size, max_step, length_norm, cur_step, + diverse_lambda, end_id, src_seq_len); + if (beam_size == 32) + select_beam_rough_topk_multilg + <<>>( + 
+
+template <typename T>
+void select_beam_rough_topk_multilg_launcher(
+    const T* logits, const T* logit_bias, const float* seq_probs,
+    const float* seq_score, const int* alive_seq,
+    const int* vocab_mask, const int* src_token_id,
+    int* can_idx, float* can_score, int* num_beam_can, int vocab_size,
+    int max_step, float length_norm, int cur_step, int step_token_num,
+    int max_thread_per_block, cudaStream_t stream, int beam_size,
+    float diverse_lambda, int end_id, int src_seq_len) {
+  if (beam_size == 1)
+    select_beam_rough_topk_multilg<T, 1>
+        <<<step_token_num, max_thread_per_block, 0, stream>>>(
+            logits, logit_bias, seq_probs, seq_score, alive_seq,
+            vocab_mask, src_token_id, can_idx, can_score, num_beam_can,
+            vocab_size, max_step, length_norm, cur_step,
+            diverse_lambda, end_id, src_seq_len);
+  if (beam_size == 2)
+    select_beam_rough_topk_multilg<T, 2>
+        <<<step_token_num, max_thread_per_block, 0, stream>>>(
+            logits, logit_bias, seq_probs, seq_score, alive_seq,
+            vocab_mask, src_token_id, can_idx, can_score, num_beam_can,
+            vocab_size, max_step, length_norm, cur_step,
+            diverse_lambda, end_id, src_seq_len);
+  if (beam_size == 4)
+    select_beam_rough_topk_multilg<T, 4>
+        <<<step_token_num, max_thread_per_block, 0, stream>>>(
+            logits, logit_bias, seq_probs, seq_score, alive_seq,
+            vocab_mask, src_token_id, can_idx, can_score, num_beam_can,
+            vocab_size, max_step, length_norm, cur_step,
+            diverse_lambda, end_id, src_seq_len);
+  if (beam_size == 8)
+    select_beam_rough_topk_multilg<T, 8>
+        <<<step_token_num, max_thread_per_block, 0, stream>>>(
+            logits, logit_bias, seq_probs, seq_score, alive_seq,
+            vocab_mask, src_token_id, can_idx, can_score, num_beam_can,
+            vocab_size, max_step, length_norm, cur_step,
+            diverse_lambda, end_id, src_seq_len);
+  if (beam_size == 16)
+    select_beam_rough_topk_multilg<T, 16>
+        <<<step_token_num, max_thread_per_block, 0, stream>>>(
+            logits, logit_bias, seq_probs, seq_score, alive_seq,
+            vocab_mask, src_token_id, can_idx, can_score, num_beam_can,
+            vocab_size, max_step, length_norm, cur_step,
+            diverse_lambda, end_id, src_seq_len);
+  if (beam_size == 32)
+    select_beam_rough_topk_multilg<T, 32>
+        <<<step_token_num, max_thread_per_block, 0, stream>>>(
+            logits, logit_bias, seq_probs, seq_score, alive_seq,
+            vocab_mask, src_token_id, can_idx, can_score, num_beam_can,
+            vocab_size, max_step, length_norm, cur_step,
+            diverse_lambda, end_id, src_seq_len);
+}
+
+template void select_beam_rough_topk_multilg_launcher<float>(
+    const float* logits, const float* logit_bias, const float* seq_probs,
+    const float* seq_score, const int* alive_seq,
+    const int* vocab_mask, const int* src_token_id,
+    int* can_idx, float* can_score, int* num_beam_can, int vocab_size,
+    int max_step, float length_norm, int cur_step, int step_token_num,
+    int max_thread_per_block, cudaStream_t stream, int beam_size,
+    float diverse_lambda, int end_id, int src_seq_len);
+
+template void select_beam_rough_topk_multilg_launcher<__half>(
+    const __half* logits, const __half* logit_bias, const float* seq_probs,
+    const float* seq_score, const int* alive_seq,
+    const int* vocab_mask, const int* src_token_id,
+    int* can_idx, float* can_score, int* num_beam_can, int vocab_size,
+    int max_step, float length_norm, int cur_step, int step_token_num,
+    int max_thread_per_block, cudaStream_t stream, int beam_size,
+    float diverse_lambda, int end_id, int src_seq_len);
+
+}  // namespace cuda
+}  // namespace lightseq
diff --git a/kernels/multilgKernels.h b/kernels/multilgKernels.h
new file mode 100644
index 00000000..8b8de61b
--- /dev/null
+++ b/kernels/multilgKernels.h
@@ -0,0 +1,37 @@
+#pragma once
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+
+namespace lightseq {
+namespace cuda {
+
+template <typename T>
+void ker_multilg_enc_emb_launcher(int batch_size, int batch_seq_len,
+                                  int hidden_size, cudaStream_t stream,
+                                  const T* token_emb, const T* pos_emb,
+                                  const T* src_lang_emb,
+                                  const int* token_id, T* output,
+                                  int* padding_mask, int padding_id,
+                                  int max_thread_per_block);
+
+template <typename T>
+void ker_multilg_dec_emb_launcher(int step_token_num, int hidden_size,
+                                  cudaStream_t stream, const T* token_emb,
+                                  const T* pos_emb, const T* src_lang_emb,
+                                  const T* trg_lang_emb, const int* src_token_id,
+                                  const int* token_id, T* output, int step,
+                                  int max_step, int vocab_size, int beam_size,
+                                  int src_seq_len, int max_thread_per_block);
+
+template <typename T>
+void select_beam_rough_topk_multilg_launcher(
+    const T* logits, const T* logit_bias, const float* seq_probs,
+    const float* seq_score, const int* alive_seq,
+    const int* vocab_mask, const int* src_token_id,
+    int* can_idx, float* can_score, int* num_beam_can, int vocab_size,
+    int max_step, float length_norm, int cur_step, int step_token_num,
+    int max_thread_per_block, cudaStream_t stream, int beam_size,
+    float diverse_lambda, int end_id, int src_seq_len);
+
+}  // namespace cuda
+}  // namespace lightseq
diff --git a/kernels/transformerKernels.cc.cu b/kernels/transformerKernels.cc.cu
index 374cc62c..3604b935 100644
--- a/kernels/transformerKernels.cc.cu
+++ b/kernels/transformerKernels.cc.cu
@@ -1221,9 +1221,11 @@ src_padding_mask: [batch_size, batch_seq_len],
 template <typename T>
 __global__ void ker_correlation_softmax_encself(T* correlation,
                                                 const int* src_padding_mask) {
-  if (src_padding_mask[blockIdx.x * blockDim.x + blockIdx.y % blockDim.x])
-    return;
   int idx = (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
+  if (src_padding_mask[blockIdx.x * blockDim.x + blockIdx.y % blockDim.x]) {
+    correlation[idx] = (T)0.f;
+    return;
+  }
   int mask = src_padding_mask[blockIdx.x * blockDim.x + threadIdx.x];
   float val = (float)correlation[idx];
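On the `ker_correlation_softmax_encself` change just above: previously, rows whose *query* position is padding returned early and left whatever the preceding GEMM wrote in `correlation`; the fix zeroes those rows explicitly. A scalar sketch of the intended per-row behavior:

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Reference for one attention row. query_is_pad: this row's query is a
// padding position -> the row must be zeroed (the fix). key_pad[j] != 0:
// key position j is padding -> excluded from the softmax.
void encself_softmax_row(std::vector<float>& row, const std::vector<int>& key_pad,
                         bool query_is_pad) {
  if (query_is_pad) {
    std::fill(row.begin(), row.end(), 0.f);  // no stale GEMM output left behind
    return;
  }
  float max_v = -1e30f;
  for (size_t j = 0; j < row.size(); ++j)
    if (!key_pad[j]) max_v = std::max(max_v, row[j]);
  float sum = 0.f;
  for (size_t j = 0; j < row.size(); ++j) {
    row[j] = key_pad[j] ? 0.f : std::exp(row[j] - max_v);
    sum += row[j];
  }
  for (float& v : row) v /= sum;
}
```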
diff --git a/model/decoder.cc.cu b/model/decoder.cc.cu
index b25cd1e1..218ea6dc 100644
--- a/model/decoder.cc.cu
+++ b/model/decoder.cc.cu
@@ -3,6 +3,7 @@
 #include "3rdparty/cub/cub/cub.cuh"
 #include "decoder.h"
 #include "kernels/transformerKernels.h"
+#include "kernels/multilgKernels.h"
 
 /**
 @file
@@ -17,7 +18,8 @@ template <OperationType OpType_>
 Decoder<OpType_>::Decoder(int max_batch_size, const int* p_d_padding_mask,
                           const _DataType* p_d_encoder_output, int* p_d_result,
                           TransformerWeight<OpType_>& tw, cudaStream_t stream,
-                          cublasHandle_t hd, bool output_topk)
+                          cublasHandle_t hd, bool output_topk,
+                          const int* p_d_token_id)
     : _max_batch_size(max_batch_size),
       _max_thread_per_block(1024),
       _h_can_num_batch(0),
@@ -32,14 +34,16 @@ Decoder<OpType_>::Decoder(int max_batch_size, const int* p_d_padding_mask,
       _stream(stream),
       _hd(hd),
       _output_topk(output_topk),
+      _p_d_token_id(p_d_token_id),  // source token id
       _layer_size_encdec_k(max_batch_size * tw._max_step * tw._hidden_size),
       _layer_size_self_k(max_batch_size * tw._max_step * tw._hidden_size *
                          tw._beam_size),
-      _fone(1.f),
+      _type_one(1.f),
+      _type_zero(0.f),
       _fzero(0.f),
       _atten_scaler(sqrt(1.f / tw._dim_per_head)),
-      _output_scaler(_tw._no_scale_embedding ? 1.f
-                                             : sqrt(1.f / tw._hidden_size)),
+      _logit_scaler(_tw._no_scale_embedding ? 1.f
+                                            : sqrt(1.f / tw._hidden_size)),
       _h_alive_seq_probs(max_batch_size * tw._beam_size,
                          min_log_probability / 2),
       _h_length_norm(tw._max_step, 1.f),
@@ -217,7 +221,10 @@ std::string Decoder<OpType_>::check() {
   if (_tw._dim_per_head & 1) {
     return "violate dim_per_head % 2 = 0";
   }
-  if (_p_d_trg_emb_wei.size() != 7) {
+  if (_tw._is_multilingual && _p_d_trg_emb_wei.size() != 8) {
+    return "violate p_d_trg_emb_wei.size() = 8";
+  }
+  if (_tw._is_multilingual == false && _p_d_trg_emb_wei.size() != 7) {
     return "violate p_d_trg_emb_wei.size() = 7";
   }
   if (_p_d_dec_wei.size() != _tw._weight_per_dec_layer * _tw._n_dec_layer) {
@@ -320,8 +327,8 @@ void Decoder<OpType_>::project_encoder_output() {
 #endif
   CHECK_GPU_ERROR(cublasGemmEx(
       _hd, CUBLAS_OP_N, CUBLAS_OP_N, kv_dim, _batch_token_num, _tw._hidden_size,
-      &_fone, _p_d_trg_emb_wei[4], _AType, kv_dim, _p_d_encoder_output, _BType,
-      _tw._hidden_size, &_fzero, _p_d_encoder_out_buf, _CType, kv_dim,
+      &_type_one, _p_d_trg_emb_wei[4], _AType, kv_dim, _p_d_encoder_output, _BType,
+      _tw._hidden_size, &_type_zero, _p_d_encoder_out_buf, _CType, kv_dim,
      _computeType, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
   // _p_d_encoder_out_buf: [batch_size, batch_seq_len, layer_num, 2,
   // hidden_size]
@@ -348,6 +355,27 @@ template <OperationType OpType_>
 bool Decoder<OpType_>::run_step() {
   embedding();
   decoder_stack();
+  /* --- Project hidden states to vocab logits--- */
+  CHECK_GPU_ERROR(cublasGemmEx(
+      _hd, CUBLAS_OP_N, CUBLAS_OP_N, _tw._trg_vocab_size, _step_token_num,
+      _tw._hidden_size, &_logit_scaler, _p_d_trg_emb_wei[0], _AType,
+      _tw._trg_vocab_size, _p_d_cur_step_query, _BType, _tw._hidden_size,
+      //&_type_zero, _p_d_logit_buf, _CType, _tw._trg_vocab_size, _computeType,
+      &_fzero, _p_d_logit_buf, _CType, _tw._trg_vocab_size, CUDA_R_32F,
+      CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+#ifdef DEBUG_RESULT
+  for (int i = 0; i < _batch_size; i++) {       // batch_id
+    for (int j = 0; j < _tw._beam_size; j++) {  // beam_id
+      std::cout << "decoder output: batch-" << i << ", beam-" << j << std::endl;
+      print_vec(_p_d_cur_step_query + i * _tw._beam_size * _tw._hidden_size +
+                    j * _tw._hidden_size,
+                "hidden", 10);
+      print_vec(_p_d_logit_buf + i * _tw._beam_size * _tw._trg_vocab_size +
+                    j * _tw._trg_vocab_size,
+                "logits", 10);
+    }
+  }
+#endif
   if (_tw._sampling_method == "topk") {
     return sample();
   } else if (_tw._sampling_method == "topp") {
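Hoisting the logit projection into `run_step()` removes three copies of the same GEMM from `sample()`, `beam_search()`, and `topk_greedy_search()`. Shape-wise it computes `logits = logit_scale * emb^T x hidden` with the (tied) target token embedding; a naive host restatement for orientation:

```cpp
#include <vector>

// Shape sketch of the logit projection in run_step():
//   emb:    [hidden_size, trg_vocab_size] target token embedding
//           (stored transposed relative to the encoder's, see the kernel docs)
//   hidden: [step_token_num, hidden_size] current-step decoder states
//   logits: [step_token_num, trg_vocab_size]
void project_logits(const std::vector<float>& emb, const std::vector<float>& hidden,
                    std::vector<float>& logits, int step_token_num,
                    int hidden_size, int trg_vocab_size, float logit_scale) {
  for (int t = 0; t < step_token_num; ++t)
    for (int v = 0; v < trg_vocab_size; ++v) {
      float acc = 0.f;
      for (int d = 0; d < hidden_size; ++d)
        acc += emb[d * trg_vocab_size + v] * hidden[t * hidden_size + d];
      logits[t * trg_vocab_size + v] = logit_scale * acc;
    }
}
```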
@@ -368,14 +396,29 @@ template <OperationType OpType_>
 void Decoder<OpType_>::embedding() {
   // _p_d_trg_emb_wei: {token_emb, position_emb, norm_scale, norm_bias,
   // enc_out_kernel_kv, enc_out_bias_kv, logit_bias}
-  ker_dec_embedding_launcher<_DataType>(
-      _step_token_num, _tw._hidden_size, _stream, _p_d_trg_emb_wei[0],
-      _p_d_trg_emb_wei[1], _p_d_alive_seq, _p_d_cur_step_query, _cur_step,
-      _tw._max_step, _tw._trg_vocab_size, _max_thread_per_block);
+  if (_tw._is_multilingual) {
+    ker_multilg_dec_emb_launcher<_DataType>(
+        _step_token_num, _tw._hidden_size, _stream,
+        _p_d_trg_emb_wei[0], _p_d_trg_emb_wei[1],
+        _tw.get_src_emb_wei()[4], _p_d_trg_emb_wei[7],
+        _p_d_token_id, _p_d_alive_seq,
+        _p_d_cur_step_query, _cur_step, _tw._max_step,
+        _tw._trg_vocab_size, _tw._beam_size, _batch_seq_len,
+        _max_thread_per_block);
+  } else {
+    ker_dec_embedding_launcher<_DataType>(
+        _step_token_num, _tw._hidden_size, _stream, _p_d_trg_emb_wei[0],
+        _p_d_trg_emb_wei[1], _p_d_alive_seq, _p_d_cur_step_query, _cur_step,
+        _tw._max_step, _tw._trg_vocab_size, _max_thread_per_block);
+  }
 #ifdef DEBUG_RESULT
-  print_vec(_p_d_cur_step_query, "decoder embedding(head):", 5);
-  print_vec(_p_d_cur_step_query + _step_token_num * _tw._hidden_size - 5,
-            "decoder embedding(tail):", 5);
+  for (int i = 0; i < _batch_size; i++) {       // batch_id
+    for (int j = 0; j < _tw._beam_size; j++) {  // beam_id
+      std::cout << "decoder emb: batch-" << i << ", beam-" << j << std::endl;
+      print_vec(_p_d_cur_step_query + i * _tw._beam_size * _tw._hidden_size +
+                    j * _tw._hidden_size,
+                "emb", 10);
+    }
+  }
 #endif
   return;
 }
@@ -431,8 +474,8 @@ void Decoder<OpType_>::self_attention() {
    * gemm--- */
   CHECK_GPU_ERROR(cublasGemmEx(
       _hd, CUBLAS_OP_N, CUBLAS_OP_N, _tw._hidden_size * 3, _step_token_num,
-      _tw._hidden_size, &_fone, _p_d_dec_wei[_weight_offset + 2], _AType,
-      _tw._hidden_size * 3, _p_d_query_buf1, _BType, _tw._hidden_size, &_fzero,
+      _tw._hidden_size, &_type_one, _p_d_dec_wei[_weight_offset + 2], _AType,
+      _tw._hidden_size * 3, _p_d_query_buf1, _BType, _tw._hidden_size, &_type_zero,
       _p_d_self_step_qkv, _CType, _tw._hidden_size * 3, _computeType,
       CUBLAS_GEMM_DEFAULT_TENSOR_OP));
 
@@ -479,7 +522,7 @@ void Decoder<OpType_>::self_attention() {
       _hd, CUBLAS_OP_T, CUBLAS_OP_N, _cur_step + 1, 1, _tw._dim_per_head,
       &_atten_scaler, _p_d_self_k_bgeem1[_layer_id], _AType, _tw._dim_per_head,
       _tw._max_step * _tw._dim_per_head, _p_d_query_buf1, _BType,
-      _tw._dim_per_head, _tw._dim_per_head, &_fzero, _p_d_c, _CType,
+      _tw._dim_per_head, _tw._dim_per_head, &_type_zero, _p_d_c, _CType,
       _cur_step + 1, _cur_step + 1, _step_token_num * _tw._head_num,
       _computeType, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
   ker_correlation_softmax_decself_launcher(_step_token_num * _tw._head_num,
@@ -494,9 +537,9 @@ void Decoder<OpType_>::self_attention() {
   /* ---step 3. new_q = correlation * v--- */
   CHECK_GPU_ERROR(cublasGemmStridedBatchedEx(
       _hd, CUBLAS_OP_N, CUBLAS_OP_N, _tw._dim_per_head, 1, _cur_step + 1,
-      &_fone, _p_d_self_v_bgeem1[_layer_id], _AType, _tw._dim_per_head,
+      &_type_one, _p_d_self_v_bgeem1[_layer_id], _AType, _tw._dim_per_head,
       _tw._max_step * _tw._dim_per_head, _p_d_c, _BType, _cur_step + 1,
-      _cur_step + 1, &_fzero, _p_d_query_buf1, _CType, _tw._dim_per_head,
+      _cur_step + 1, &_type_zero, _p_d_query_buf1, _CType, _tw._dim_per_head,
       _tw._dim_per_head, _step_token_num * _tw._head_num, _computeType,
       CUBLAS_GEMM_DEFAULT_TENSOR_OP));
 
@@ -509,8 +552,8 @@ void Decoder<OpType_>::self_attention() {
   /* ---step 4. new_q = ori_q + new_q * output_wei--- */
   CHECK_GPU_ERROR(cublasGemmEx(
       _hd, CUBLAS_OP_N, CUBLAS_OP_N, _tw._hidden_size, _step_token_num,
-      _tw._hidden_size, &_fone, _p_d_dec_wei[_weight_offset + 4], _AType,
-      _tw._hidden_size, _p_d_query_buf1, _BType, _tw._hidden_size, &_fone,
+      _tw._hidden_size, &_type_one, _p_d_dec_wei[_weight_offset + 4], _AType,
+      _tw._hidden_size, _p_d_query_buf1, _BType, _tw._hidden_size, &_type_one,
       _p_d_cur_step_query, _CType, _tw._hidden_size, _computeType,
       CUBLAS_GEMM_DEFAULT_TENSOR_OP));
 
@@ -543,8 +586,8 @@ void Decoder<OpType_>::encdec_attention() {
    * gemm--- */
   CHECK_GPU_ERROR(cublasGemmEx(
       _hd, CUBLAS_OP_N, CUBLAS_OP_N, _tw._hidden_size, _step_token_num,
-      _tw._hidden_size, &_fone, _p_d_dec_wei[_weight_offset + 8], _AType,
-      _tw._hidden_size, _p_d_query_buf1, _BType, _tw._hidden_size, &_fzero,
+      _tw._hidden_size, &_type_one, _p_d_dec_wei[_weight_offset + 8], _AType,
+      _tw._hidden_size, _p_d_query_buf1, _BType, _tw._hidden_size, &_type_zero,
       _p_d_query_buf2, _CType, _tw._hidden_size, _computeType,
       CUBLAS_GEMM_DEFAULT_TENSOR_OP));
   ker_arrange_encdec_q_launcher<_DataType>(
@@ -557,7 +600,7 @@ void Decoder<OpType_>::encdec_attention() {
       _hd, CUBLAS_OP_T, CUBLAS_OP_N, _batch_seq_len, _tw._beam_size,
       _tw._dim_per_head, &_atten_scaler, _p_d_encdec_k_bgeem[_layer_id], _AType,
       _tw._dim_per_head, _batch_seq_len * _tw._dim_per_head, _p_d_query_buf1,
-      _BType, _tw._dim_per_head, _tw._beam_size * _tw._dim_per_head, &_fzero,
+      _BType, _tw._dim_per_head, _tw._beam_size * _tw._dim_per_head, &_type_zero,
       _p_d_c, _CType, _batch_seq_len, _tw._beam_size * _batch_seq_len,
       _batch_size * _tw._head_num, _computeType,
       CUBLAS_GEMM_DEFAULT_TENSOR_OP));
@@ -568,9 +611,9 @@ void Decoder<OpType_>::encdec_attention() {
   /* ---step 3. new_q = correlation * v--- */
   CHECK_GPU_ERROR(cublasGemmStridedBatchedEx(
       _hd, CUBLAS_OP_N, CUBLAS_OP_N, _tw._dim_per_head, _tw._beam_size,
-      _batch_seq_len, &_fone, _p_d_encdec_v_bgeem[_layer_id], _AType,
+      _batch_seq_len, &_type_one, _p_d_encdec_v_bgeem[_layer_id], _AType,
       _tw._dim_per_head, _batch_seq_len * _tw._dim_per_head, _p_d_c, _BType,
-      _batch_seq_len, _tw._beam_size * _batch_seq_len, &_fzero, _p_d_query_buf1,
+      _batch_seq_len, _tw._beam_size * _batch_seq_len, &_type_zero, _p_d_query_buf1,
       _CType, _tw._dim_per_head, _tw._beam_size * _tw._dim_per_head,
       _batch_size * _tw._head_num, _computeType,
       CUBLAS_GEMM_DEFAULT_TENSOR_OP));
@@ -582,8 +625,8 @@ void Decoder<OpType_>::encdec_attention() {
   /* ---step 4. new_q = ori_q + new_q * output_wei--- */
   CHECK_GPU_ERROR(cublasGemmEx(
       _hd, CUBLAS_OP_N, CUBLAS_OP_N, _tw._hidden_size, _step_token_num,
-      _tw._hidden_size, &_fone, _p_d_dec_wei[_weight_offset + 10], _AType,
-      _tw._hidden_size, _p_d_query_buf2, _BType, _tw._hidden_size, &_fone,
+      _tw._hidden_size, &_type_one, _p_d_dec_wei[_weight_offset + 10], _AType,
+      _tw._hidden_size, _p_d_query_buf2, _BType, _tw._hidden_size, &_type_one,
       _p_d_cur_step_query, _CType, _tw._hidden_size, _computeType,
       CUBLAS_GEMM_DEFAULT_TENSOR_OP));
   return;
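Background for the `_fone`/`_fzero` to `_type_one`/`_type_zero` renames in these GEMMs: with `cublasGemmEx`, the pointee type of `alpha`/`beta` must match `computeType`. The attention and FFN GEMMs keep scalars of `_DataType` because they use `_computeType` (half in fp16 mode), while the new logit GEMM pins `computeType` to `CUDA_R_32F` and therefore needs the float scalars `_logit_scaler`/`_fzero`. A minimal sketch of that second case, assuming fp16 inputs with fp32 accumulation:

```cpp
#include <cublas_v2.h>
#include <cuda_fp16.h>

// fp16 A/B/C with fp32 accumulation: alpha and beta must be float,
// because computeType is CUDA_R_32F (pointee type follows computeType).
cublasStatus_t gemm_fp16_in_fp32_acc(cublasHandle_t hd, int m, int n, int k,
                                     const __half* A, const __half* B, __half* C) {
  const float alpha = 1.f, beta = 0.f;
  return cublasGemmEx(hd, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, &alpha, A,
                      CUDA_R_16F, m, B, CUDA_R_16F, k, &beta, C, CUDA_R_16F, m,
                      CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}
```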
@@ -607,8 +650,8 @@ void Decoder<OpType_>::ffn_add_norm() {
   /* ---step 1. first ffn layer--- */
   CHECK_GPU_ERROR(cublasGemmEx(
       _hd, CUBLAS_OP_N, CUBLAS_OP_N, _tw._inner_size, _step_token_num,
-      _tw._hidden_size, &_fone, _p_d_dec_wei[_weight_offset + 14], _AType,
-      _tw._inner_size, _p_d_query_buf1, _BType, _tw._hidden_size, &_fzero,
+      _tw._hidden_size, &_type_one, _p_d_dec_wei[_weight_offset + 14], _AType,
+      _tw._inner_size, _p_d_query_buf1, _BType, _tw._hidden_size, &_type_zero,
       _p_d_query_buf2, _CType, _tw._inner_size, _computeType,
       CUBLAS_GEMM_DEFAULT_TENSOR_OP));
 
@@ -625,8 +668,8 @@ void Decoder<OpType_>::ffn_add_norm() {
   /* ---step 2. second ffn layer--- */
   CHECK_GPU_ERROR(cublasGemmEx(
       _hd, CUBLAS_OP_N, CUBLAS_OP_N, _tw._hidden_size, _step_token_num,
-      _tw._inner_size, &_fone, _p_d_dec_wei[_weight_offset + 16], _AType,
-      _tw._hidden_size, _p_d_query_buf2, _BType, _tw._inner_size, &_fone,
+      _tw._inner_size, &_type_one, _p_d_dec_wei[_weight_offset + 16], _AType,
+      _tw._hidden_size, _p_d_query_buf2, _BType, _tw._inner_size, &_type_one,
       _p_d_cur_step_query, _CType, _tw._hidden_size, _computeType,
       CUBLAS_GEMM_DEFAULT_TENSOR_OP));
   return;
 }
 
@@ -634,24 +677,9 @@ template <OperationType OpType_>
 bool Decoder<OpType_>::sample() {
-  /* ---step 1. project hidden states to vocab logits--- */
-  CHECK_GPU_ERROR(cublasGemmEx(
-      _hd, CUBLAS_OP_N, CUBLAS_OP_N, _tw._trg_vocab_size, _step_token_num,
-      _tw._hidden_size, &_output_scaler, _p_d_trg_emb_wei[0], _AType,
-      _tw._trg_vocab_size, _p_d_cur_step_query, _BType, _tw._hidden_size,
-      &_fzero, _p_d_logit_buf, _CType, _tw._trg_vocab_size, _computeType,
-      CUBLAS_GEMM_DEFAULT_TENSOR_OP));
-
-#ifdef DEBUG_RESULT
-  print_vec(_p_d_logit_buf, "logits(head):", 5);
-  print_vec(
-      _p_d_logit_buf + _batch_size * _tw._beam_size * _tw._trg_vocab_size - 5,
-      "logits(tail):", 5);
-#endif
-
   CHECK_GPU_ERROR(
       cudaMemsetAsync(_p_d_sample_unfinished, 0, sizeof(int), _stream));
-  // /* ---step 2. sample new tokens from logits */
+  /* --- Sample new tokens from logits --- */
   if (_tw._sampling_method == "topk") {
     ker_topk_sample_launcher<_DataType>(
         _batch_size, (_cur_step + 1), _tw._max_step, 1, _max_thread_per_block,
@@ -683,19 +711,6 @@ template <OperationType OpType_>
 bool Decoder<OpType_>::beam_search() {
-  /* ---step 0. project hidden states to vocab logits--- */
-  CHECK_GPU_ERROR(cublasGemmEx(
-      _hd, CUBLAS_OP_N, CUBLAS_OP_N, _tw._trg_vocab_size, _step_token_num,
-      _tw._hidden_size, &_output_scaler, _p_d_trg_emb_wei[0], _AType,
-      _tw._trg_vocab_size, _p_d_cur_step_query, _BType, _tw._hidden_size,
-      &_fzero, _p_d_logit_buf, _CType, _tw._trg_vocab_size, _computeType,
-      CUBLAS_GEMM_DEFAULT_TENSOR_OP));
-#ifdef DEBUG_RESULT
-  print_vec(_p_d_logit_buf, "logits(head):", 5);
-  print_vec(_p_d_logit_buf + _step_token_num * _tw._trg_vocab_size - 5,
-            "logits(tail):", 5);
-#endif
-
   /*
     step 1. logits bias and softmax,
       select rough topk candidate for every batch item,
@@ -728,6 +743,11 @@ bool Decoder<OpType_>::beam_search() {
       _p_d_can_score + _h_can_num_batch, _p_d_can_idx,
       thrust::greater<float>());
 
+#ifdef DEBUG_RESULT
+  print_vec(_p_d_can_score, "can score", _h_can_num_batch);
+  print_vec(_p_d_can_idx, "can idx", _h_can_num_batch);
+#endif
+
   /*
     step 3. refresh alive_seq, seq_probs, seq_score, num_finish_beam
       based on sorted candidate.
@@ -794,13 +814,26 @@ Record the candidate's beam_id, vocab_id and probability
 template <OperationType OpType_>
 void Decoder<OpType_>::update_new_seq_probs() {
   CHECK_GPU_ERROR(cudaMemsetAsync(_p_d_can_num, 0, sizeof(int), _stream));
-  select_beam_rough_topk_launcher(
+
+  if (_tw._is_multilingual) {
+    select_beam_rough_topk_multilg_launcher(
+        _p_d_logit_buf, _p_d_trg_emb_wei[6], _p_d_alive_seq_probs,
+        _p_d_alive_seq_score, _p_d_alive_seq,
+        _tw._p_d_trg_vocab_mask, _p_d_token_id,
+        _p_d_can_idx, _p_d_can_score,
+        _p_d_can_num, _tw._trg_vocab_size, _tw._max_step,
+        _h_length_norm[_cur_step], _cur_step, _step_token_num,
+        _max_thread_per_block, _stream, _tw._beam_size, _tw._diverse_lambda,
+        _tw._end_id, _batch_seq_len);
+  } else {
+    select_beam_rough_topk_launcher(
       _p_d_logit_buf, _p_d_trg_emb_wei[6], _p_d_alive_seq_probs,
       _p_d_alive_seq_score, _p_d_alive_seq, _p_d_can_idx, _p_d_can_score,
       _p_d_can_num, _tw._trg_vocab_size, _tw._max_step,
       _h_length_norm[_cur_step], _cur_step, _step_token_num,
       _max_thread_per_block, _stream, _tw._beam_size, _tw._diverse_lambda,
       _tw._end_id);
+  }
   thrust::exclusive_scan(thrust::cuda::par.on(_stream), _p_d_can_num + 1,
                          _p_d_can_num + 1 + _step_token_num, _p_d_can_num + 1);
   return;
 }
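After the rough-topk launcher fills `_p_d_can_num[1..step_token_num]` with per-beam candidate counts, the exclusive scan converts counts into per-beam write offsets into the candidate arrays. The same call on host data, to make the semantics concrete (the real code runs it in place on the CUDA stream):

```cpp
#include <thrust/scan.h>

#include <cassert>
#include <vector>

int main() {
  // per-beam candidate counts, as written at _p_d_can_num + 1
  std::vector<int> can_num = {3, 1, 4, 2};
  std::vector<int> offsets(can_num.size());
  // exclusive scan: offsets[i] = sum of the counts of beams before i
  thrust::exclusive_scan(can_num.begin(), can_num.end(), offsets.begin());
  assert((offsets == std::vector<int>{0, 3, 4, 8}));
  return 0;
}
```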
@@ -813,24 +846,9 @@ bool Decoder<OpType_>::topk_greedy_search() {
     return beam_search();
   }
 
-  /* ---step 1. project hidden states to vocab logits--- */
-  CHECK_GPU_ERROR(cublasGemmEx(
-      _hd, CUBLAS_OP_N, CUBLAS_OP_N, _tw._trg_vocab_size, _step_token_num,
-      _tw._hidden_size, &_output_scaler, _p_d_trg_emb_wei[0], _AType,
-      _tw._trg_vocab_size, _p_d_cur_step_query, _BType, _tw._hidden_size,
-      &_fzero, _p_d_logit_buf, _CType, _tw._trg_vocab_size, _computeType,
-      CUBLAS_GEMM_DEFAULT_TENSOR_OP));
-
-#ifdef DEBUG_RESULT
-  print_vec(_p_d_logit_buf, "logits(head):", 5);
-  print_vec(
-      _p_d_logit_buf + _batch_size * _tw._beam_size * _tw._trg_vocab_size - 5,
-      "logits(tail):", 5);
-#endif
-
   CHECK_GPU_ERROR(
       cudaMemsetAsync(_p_d_sample_unfinished, 0, sizeof(int), _stream));
-  // /* ---step 2. sample new tokens from logits */
+  /* --- Sample new tokens from logits --- */
   ker_topk_sample_launcher<_DataType>(
       _step_token_num, (_cur_step + 1), _tw._max_step, 1,
       _max_thread_per_block, _stream, _p_d_logit_buf, _p_d_trg_emb_wei[6],
       _p_d_alive_seq,
diff --git a/model/decoder.h b/model/decoder.h
index e35edde5..95cf08c5 100644
--- a/model/decoder.h
+++ b/model/decoder.h
@@ -61,6 +61,7 @@ class Decoder {
   int* _p_d_result;
   int* _p_d_sample_unfinished;
   curandState* _p_d_curandstate;  //[batch_size]
+  const int* _p_d_token_id;  // source token id
 
   std::vector<float> _h_alive_seq_probs;
   std::vector<float> _h_length_norm;
@@ -113,11 +114,11 @@ class Decoder {
   const std::vector<const _DataType*>& _p_d_trg_emb_wei;  // size: 7
   const std::vector<const _DataType*>& _p_d_dec_wei;  // size: 18 * dec_layer_num
 
-  const _DataType _fone;
-  const _DataType _fzero;
-  const _DataType
-      _atten_scaler;  // scaling factor of Scaled Dot-Product Attention
-  const _DataType _output_scaler;  // output scaling factor of the liner project
+  const _DataType _type_one;
+  const _DataType _type_zero;
+  const float _fzero;
+  const _DataType _atten_scaler;  // scaling factor of Scaled Dot-Product Attention
+  const float _logit_scaler;  // output scaling factor of the linear projection
   // after decoder
   const long _layer_size_encdec_k;
   const long _layer_size_self_k;
@@ -126,7 +127,8 @@ class Decoder {
   Decoder(int max_batch_size, const int* p_d_padding_mask,
           const _DataType* p_d_encoder_output, int* p_d_result,
           TransformerWeight<OpType_>& tw, cudaStream_t stream,
-          cublasHandle_t hd, bool output_topk = false);
+          cublasHandle_t hd, bool output_topk = false,
+          const int* p_d_token_id = nullptr);
   long compute_buffer_bytesize();
   void init_buffer(void* pbuf);
   std::string check();
diff --git a/model/encoder.cc.cu b/model/encoder.cc.cu
index 8ae78b39..f9412379 100644
--- a/model/encoder.cc.cu
+++ b/model/encoder.cc.cu
@@ -1,5 +1,6 @@
 #include "encoder.h"
 #include "kernels/transformerKernels.h"
+#include "kernels/multilgKernels.h"
 
 /**
 @file
@@ -75,9 +76,12 @@ std::string Encoder<OpType_>::check() {
   if (_tw._dim_per_head & 1) {
     return "violate dim_per_head % 2 = 0";
   }
-  if (_p_d_src_emb_wei.size() != 4) {
+  if (_tw._is_multilingual == false && _p_d_src_emb_wei.size() != 4) {
     return "violate p_d_src_emb_wei.size() = 4";
   }
+  if (_tw._is_multilingual && _p_d_src_emb_wei.size() != 5) {
+    return "violate p_d_src_emb_wei.size() = 5";
+  }
   if (_p_d_enc_wei.size() != _tw._weight_per_enc_layer * _tw._n_enc_layer) {
     return "violate p_d_enc_wei.size() = weight_per_enc_layer * n_enc_layer";
   }
@@ -100,14 +104,29 @@ void Encoder<OpType_>::run_one_infer(int batch_size, int batch_seq_len) {
 #endif
 
   /* ---step2. encoder feedforward--- */
-  ker_enc_embedding_launcher<_DataType>(
-      batch_size, batch_seq_len, _tw._hidden_size, _stream, _p_d_src_emb_wei[0],
-      _p_d_src_emb_wei[1], _p_d_token_id, _p_d_output, _p_d_padding_mask,
-      _tw._padding_id, _max_thread_per_block);
+  if (_tw._is_multilingual) {
+    ker_multilg_enc_emb_launcher<_DataType>(
+        batch_size, batch_seq_len, _tw._hidden_size, _stream,
+        _p_d_src_emb_wei[0], _p_d_src_emb_wei[1], _p_d_src_emb_wei[4],
+        _p_d_token_id, _p_d_output, _p_d_padding_mask,
+        _tw._padding_id, _max_thread_per_block);
+  } else {
+    ker_enc_embedding_launcher<_DataType>(
+        batch_size, batch_seq_len, _tw._hidden_size, _stream,
+        _p_d_src_emb_wei[0], _p_d_src_emb_wei[1],
+        _p_d_token_id, _p_d_output, _p_d_padding_mask,
+        _tw._padding_id, _max_thread_per_block);
+  }
 #ifdef DEBUG_RESULT
-  print_vec(_p_d_output, "encoder embedding(head):", 5);
-  print_vec(_p_d_output + _batch_token_num * _tw._hidden_size - 5,
-            "encoder embedding(tail):", 5);
+  for (int i = 0; i < _batch_size; i++) {       // batch_id
+    for (int j = 0; j < _batch_seq_len; j++) {  // token_id
+      std::cout << "emb out: token-" << j << std::endl;
+      print_vec(_p_d_output + i * _batch_seq_len * _tw._hidden_size +
+                    j * _tw._hidden_size,
+                "emb out", 10);
+    }
+  }
 #endif
   for (_layer_id = 0; _layer_id < _tw._n_enc_layer; _layer_id++) {
     _weight_offset = _layer_id * _tw._weight_per_enc_layer;
@@ -151,6 +170,7 @@ void Encoder<OpType_>::self_attention() {
       _tw._hidden_size * 3, _p_d_q, _BType, _tw._hidden_size, &_fzero,
       _p_d_qkv_projected, _CType, _tw._hidden_size * 3, _computeType,
       CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+  // get q, k, v by splitting and reshaping qkv
   ker_arrange_encself_qkv_launcher<_DataType>(
       _batch_token_num, _tw._hidden_size, _stream, _p_d_qkv_projected,
diff --git a/proto/transformer.proto b/proto/transformer.proto
index 1fa20341..9aa514ba 100644
--- a/proto/transformer.proto
+++ b/proto/transformer.proto
@@ -79,13 +79,21 @@ message EmbeddingLayer {
   repeated float norm_scale = 3;  // [hidden_size]
   repeated float norm_bias = 4;   // [hidden_size]
 
-  // below only for trg, not in src
-  // [hidden_size, enc_layer, 2, hidden_size]
+  // only for trg, not in src
+  // [dec_layer_num, hidden_size, 2, hidden_size]
   repeated float encode_output_project_kernel_kv = 5;
-  // [enc_layer, 2, hidden_size]
+  // only for trg, not in src
+  // [dec_layer_num, 2, hidden_size]
   repeated float encode_output_project_bias_kv = 6;
+  // only for trg, not in src
   // decoder vocab logit bias
   repeated float shared_bias = 7;  // [target_vocab_size]
+
+  // For multilingual models, [num_lang, hidden_size]
+  repeated float lang_emb = 8;
+  // only for trg, not in src
+  // For multilingual models, [num_lang, target_vocab_size]
+  repeated int32 trg_vocab_mask = 9;
 }
 
 message ModelConf {
@@ -104,6 +112,9 @@ message ModelConf {
   bool is_post_ln = 12;          // Pre-LN or Post-LN
   bool no_scale_embedding = 13;  // whether to scale embedding by sqrt(emb_dim)
   bool use_gelu = 14;            // use gelu for activation otherwise relu
+  // Whether it is a multilingual model.
+  // If it is set to true, lang_emb and trg_vocab_mask should be non-empty.
+  bool is_multilingual = 15;
 }
 
 message Transformer {
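For exporters producing multilingual checkpoints: `is_multilingual` must be set and `lang_emb` populated on both embeddings, plus `trg_vocab_mask` on the target side, or the new size checks in `transformer_weight.cc` will fail. A hedged sketch with the C++ API generated from this proto (real exporters may use the Python bindings instead; the zero/one fill values here are placeholders):

```cpp
#include "transformer.pb.h"  // generated from proto/transformer.proto

void mark_multilingual(Transformer* transformer, int num_lang, int hidden_size,
                       int trg_vocab_size) {
  transformer->mutable_model_conf()->set_is_multilingual(true);
  // lang_emb: [num_lang, hidden_size] on both src and trg embeddings
  for (int i = 0; i < num_lang * hidden_size; ++i) {
    transformer->mutable_src_embedding()->add_lang_emb(0.f);
    transformer->mutable_trg_embedding()->add_lang_emb(0.f);
  }
  // trg_vocab_mask: [num_lang, target_vocab_size] 0/1 mask, trg side only
  for (int i = 0; i < num_lang * trg_vocab_size; ++i)
    transformer->mutable_trg_embedding()->add_trg_vocab_mask(1);
}
```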
diff --git a/proto/transformer_weight.cc b/proto/transformer_weight.cc
index 14b4b265..3c07a2ec 100644
--- a/proto/transformer_weight.cc
+++ b/proto/transformer_weight.cc
@@ -72,6 +72,7 @@ void TransformerWeight<OpType_>::get_model_config(
   _is_post_ln = transformer.model_conf().is_post_ln();
   _no_scale_embedding = transformer.model_conf().no_scale_embedding();
   _use_gelu = transformer.model_conf().use_gelu();
+  _is_multilingual = transformer.model_conf().is_multilingual();
 }
 
 /**
@@ -92,23 +93,23 @@ std::string TransformerWeight<OpType_>::parse_emb_wei(
   offset.push_back(idx);
   if (layer.token_embedding_size() != vocab_size * _hidden_size)
-    return "wrong token_embedding_size !";
+    return "Wrong token_embedding_size !";
   for (float ele : layer.token_embedding()) value.push_back(ele);
   idx += vocab_size * _hidden_size;
 
   offset.push_back(idx);
   if (layer.position_embedding_size() != _max_step * _hidden_size)
-    return "wrong position_embedding_size !";
+    return "Wrong position_embedding_size !";
   for (float ele : layer.position_embedding()) value.push_back(ele);
   idx += _max_step * _hidden_size;
 
   offset.push_back(idx);
-  if (layer.norm_scale_size() != _hidden_size) return "wrong norm_scale_size !";
+  if (layer.norm_scale_size() != _hidden_size) return "Wrong norm_scale_size !";
   for (float ele : layer.norm_scale()) value.push_back(ele);
   idx += _hidden_size;
 
   offset.push_back(idx);
-  if (layer.norm_bias_size() != _hidden_size) return "wrong norm_bias_size !";
+  if (layer.norm_bias_size() != _hidden_size) return "Wrong norm_bias_size !";
   for (float ele : layer.norm_bias()) value.push_back(ele);
   idx += _hidden_size;
 
@@ -125,7 +126,7 @@ std::string TransformerWeight<OpType_>::parse_emb_wei(
   offset.push_back(idx);
   if (layer.encode_output_project_kernel_kv_size() !=
       _hidden_size * _hidden_size * 2 * _n_dec_layer)
-    return "wrong encode_output_project_kernel_kv_size !";
+    return "Wrong encode_output_project_kernel_kv_size !";
   for (float ele : layer.encode_output_project_kernel_kv())
     value.push_back(ele);
   idx += _hidden_size * _hidden_size * 2 * _n_dec_layer;
 
@@ -133,25 +134,57 @@ std::string TransformerWeight<OpType_>::parse_emb_wei(
   offset.push_back(idx);
   if (layer.encode_output_project_bias_kv_size() !=
       _hidden_size * 2 * _n_dec_layer)
-    return "wrong encode_output_project_bias_kv_size !";
+    return "Wrong encode_output_project_bias_kv_size !";
   for (float ele : layer.encode_output_project_bias_kv()) value.push_back(ele);
   idx += _hidden_size * 2 * _n_dec_layer;
 
   offset.push_back(idx);
   if (layer.shared_bias_size() != vocab_size)
-    return "wrong shared_bias_size !";
+    return "Wrong shared_bias_size !";
   for (float ele : layer.shared_bias()) value.push_back(ele);
   idx += vocab_size;
 
   std::vector<_DataType> raw_value;
   for (float e : value) raw_value.push_back(float2required(e));
   _d_trg_emb_wei = raw_value;
-  for (int e : offset)
+  for (int e : offset) {
     _p_d_trg_emb_wei.push_back(
         thrust::raw_pointer_cast(_d_trg_emb_wei.data()) + e);
-  }
-  std::cout << "finish initializing " << source
+    }
+  }  // trg
+
+  if (_is_multilingual) {
+    // fill in language embedding
+    std::vector<_DataType> raw_value;
+    for (float e : layer.lang_emb()) {
+      raw_value.push_back(float2required(e));
+    }
+
+    if (source == "src") {
+      _d_src_lang_emb = raw_value;
+      _p_d_src_emb_wei.push_back(
+          thrust::raw_pointer_cast(_d_src_lang_emb.data()));
+    } else {
+      if (layer.lang_emb_size() / _hidden_size !=
+          layer.trg_vocab_mask_size() / _trg_vocab_size) {
+        return "Wrong trg_lang_emb_size or trg_vocab_mask_size !";
+      }
+      _d_trg_lang_emb = raw_value;
+      _p_d_trg_emb_wei.push_back(
+          thrust::raw_pointer_cast(_d_trg_lang_emb.data()));
+      // fill in target vocab mask
+      std::vector<int> h_mask;
+      for (int ele : layer.trg_vocab_mask()) h_mask.push_back(ele);
+      _d_trg_vocab_mask = h_mask;
+      _p_d_trg_vocab_mask = thrust::raw_pointer_cast(_d_trg_vocab_mask.data());
+    }
+
+    std::cout << "Finish loading multilingual weights from host to device"
+              << std::endl;
+  }
+
+  std::cout << "Finish loading " << source
             << "_emb_wei from host to device" << std::endl;
   return "";
 }
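The consistency check above guarantees `lang_emb` and `trg_vocab_mask` agree on `num_lang`. The mask is then consumed per (language, token) pair during decoding; restated on the host:

```cpp
#include <vector>

// trg_vocab_mask layout: [num_lang, trg_vocab_size], row-major.
// Returns true when token v may be generated for target language lang_id.
inline bool token_allowed(const std::vector<int>& trg_vocab_mask,
                          int trg_vocab_size, int lang_id, int v) {
  return trg_vocab_mask[lang_id * trg_vocab_size + v] != 0;
}
```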
@@ -169,27 +202,27 @@ std::string TransformerWeight<OpType_>::parse_enc_wei(
   for (auto enc_layer : transformer.encoder_stack()) {
     offset.push_back(idx);
     if (enc_layer.multihead_norm_scale_size() != _hidden_size)
-      return "wrong multihead_norm_scale_size !";
+      return "Wrong multihead_norm_scale_size !";
     for (float ele : enc_layer.multihead_norm_scale()) value.push_back(ele);
     idx += _hidden_size;
 
     offset.push_back(idx);
     if (enc_layer.multihead_norm_bias_size() != _hidden_size)
-      return "wrong multihead_norm_bias_size !";
+      return "Wrong multihead_norm_bias_size !";
     for (float ele : enc_layer.multihead_norm_bias()) value.push_back(ele);
     idx += _hidden_size;
 
     offset.push_back(idx);
     if (enc_layer.multihead_project_kernel_qkv_size() !=
         _hidden_size * _hidden_size * 3)
-      return "wrong multihead_project_kernel_qkv_size !";
+      return "Wrong multihead_project_kernel_qkv_size !";
     for (float ele : enc_layer.multihead_project_kernel_qkv())
       value.push_back(ele);
     idx += _hidden_size * _hidden_size * 3;
 
     offset.push_back(idx);
     if (enc_layer.multihead_project_bias_qkv_size() != _hidden_size * 3)
-      return "wrong multihead_project_bias_qkv_size !";
+      return "Wrong multihead_project_bias_qkv_size !";
     for (float ele : enc_layer.multihead_project_bias_qkv())
       value.push_back(ele);
     idx += _hidden_size * 3;
 
@@ -197,51 +230,51 @@ std::string TransformerWeight<OpType_>::parse_enc_wei(
     offset.push_back(idx);
     if (enc_layer.multihead_project_kernel_output_size() !=
         _hidden_size * _hidden_size)
-      return "wrong multihead_project_kernel_output_size !";
+      return "Wrong multihead_project_kernel_output_size !";
     for (float ele : enc_layer.multihead_project_kernel_output())
       value.push_back(ele);
     idx += _hidden_size * _hidden_size;
 
     offset.push_back(idx);
     if (enc_layer.multihead_project_bias_output_size() != _hidden_size)
-      return "wrong multihead_project_bias_output_size !";
+      return "Wrong multihead_project_bias_output_size !";
     for (float ele : enc_layer.multihead_project_bias_output())
       value.push_back(ele);
     idx += _hidden_size;
 
     offset.push_back(idx);
     if (enc_layer.ffn_norm_scale_size() != _hidden_size)
-      return "wrong ffn_norm_scale_size !";
+      return "Wrong ffn_norm_scale_size !";
     for (float ele : enc_layer.ffn_norm_scale()) value.push_back(ele);
     idx += _hidden_size;
 
     offset.push_back(idx);
     if (enc_layer.ffn_norm_bias_size() != _hidden_size)
-      return "wrong ffn_norm_bias_size !";
+      return "Wrong ffn_norm_bias_size !";
     for (float ele : enc_layer.ffn_norm_bias()) value.push_back(ele);
     idx += _hidden_size;
 
     offset.push_back(idx);
     if (enc_layer.ffn_first_kernel_size() != _hidden_size * _inner_size)
-      return "wrong ffn_first_kernel_size !";
+      return "Wrong ffn_first_kernel_size !";
     for (float ele : enc_layer.ffn_first_kernel()) value.push_back(ele);
     idx += _hidden_size * _inner_size;
 
     offset.push_back(idx);
     if (enc_layer.ffn_first_bias_size() != _inner_size)
-      return "wrong ffn_first_bias_size !";
+      return "Wrong ffn_first_bias_size !";
     for (float ele : enc_layer.ffn_first_bias()) value.push_back(ele);
     idx += _inner_size;
 
     offset.push_back(idx);
     if (enc_layer.ffn_second_kernel_size() != _hidden_size * _inner_size)
-      return "wrong ffn_second_kernel_size !";
+      return "Wrong ffn_second_kernel_size !";
     for (float ele : enc_layer.ffn_second_kernel()) value.push_back(ele);
     idx += _hidden_size * _inner_size;
 
     offset.push_back(idx);
     if (enc_layer.ffn_second_bias_size() != _hidden_size)
-      return "wrong ffn_second_bias_size !";
+      return "Wrong ffn_second_bias_size !";
     for (float ele : enc_layer.ffn_second_bias()) value.push_back(ele);
     idx += _hidden_size;
 
@@ -253,7 +286,7 @@ std::string TransformerWeight<OpType_>::parse_enc_wei(
   for (int e : offset)
     _p_d_enc_wei.push_back(thrust::raw_pointer_cast(_d_enc_wei.data()) + e);
 
-  std::cout << "finish initializing enc_wei from host to device" << std::endl;
+  std::cout << "Finish loading enc_wei from host to device" << std::endl;
   return "";
 }
 
@@ -270,115 +303,115 @@ std::string TransformerWeight<OpType_>::parse_dec_wei(
   for (auto dec_layer : transformer.decoder_stack()) {
     offset.push_back(idx);
     if (dec_layer.self_norm_scale_size() != _hidden_size)
-      return "wrong self_norm_scale size !";
+      return "Wrong self_norm_scale size !";
     for (float ele : dec_layer.self_norm_scale()) value.push_back(ele);
     idx += _hidden_size;
 
     offset.push_back(idx);
     if (dec_layer.self_norm_bias_size() != _hidden_size)
-      return "wrong self_norm_bias_size !";
+      return "Wrong self_norm_bias_size !";
     for (float ele : dec_layer.self_norm_bias()) value.push_back(ele);
     idx += _hidden_size;
 
     offset.push_back(idx);
     if (dec_layer.self_project_kernel_qkv_size() !=
         _hidden_size * _hidden_size * 3)
-      return "wrong self_project_kernel_qkv size !";
+      return "Wrong self_project_kernel_qkv size !";
     for (float ele : dec_layer.self_project_kernel_qkv()) value.push_back(ele);
     idx += _hidden_size * _hidden_size * 3;
 
     offset.push_back(idx);
     if (dec_layer.self_project_bias_qkv_size() != _hidden_size * 3)
-      return "wrong self_project_bias_qkv size !";
+      return "Wrong self_project_bias_qkv size !";
     for (float ele : dec_layer.self_project_bias_qkv()) value.push_back(ele);
     idx += _hidden_size * 3;
 
     offset.push_back(idx);
     if (dec_layer.self_project_kernel_output_size() !=
         _hidden_size * _hidden_size)
-      return "wrong self_project_kernel_output size !";
+      return "Wrong self_project_kernel_output size !";
     for (float ele : dec_layer.self_project_kernel_output())
       value.push_back(ele);
     idx += _hidden_size * _hidden_size;
 
     offset.push_back(idx);
     if (dec_layer.self_project_bias_output_size() != _hidden_size)
-      return "wrong self_project_bias_output size !";
+      return "Wrong self_project_bias_output size !";
     for (float ele : dec_layer.self_project_bias_output())
       value.push_back(ele);
     idx += _hidden_size;
 
     offset.push_back(idx);
     if (dec_layer.encdec_norm_scale_size() != _hidden_size)
-      return "wrong encdec_norm_scale size !";
+      return "Wrong encdec_norm_scale size !";
     for (float ele : dec_layer.encdec_norm_scale()) value.push_back(ele);
     idx += _hidden_size;
 
     offset.push_back(idx);
     if (dec_layer.encdec_norm_bias_size() != _hidden_size)
-      return "wrong encdec_norm_bias_size !";
+      return "Wrong encdec_norm_bias_size !";
     for (float ele : dec_layer.encdec_norm_bias()) value.push_back(ele);
     idx += _hidden_size;
 
     offset.push_back(idx);
     if (dec_layer.encdec_project_kernel_q_size() !=
         _hidden_size * _hidden_size)
-      return "wrong encdec_project_kernel_q size !";
+      return "Wrong encdec_project_kernel_q size !";
     for (float ele : dec_layer.encdec_project_kernel_q()) value.push_back(ele);
     idx += _hidden_size * _hidden_size;
 
     offset.push_back(idx);
     if (dec_layer.encdec_project_bias_q_size() != _hidden_size)
-      return "wrong encdec_project_bias_q size !";
+      return "Wrong encdec_project_bias_q size !";
     for (float ele : dec_layer.encdec_project_bias_q()) value.push_back(ele);
     idx += _hidden_size;
 
     offset.push_back(idx);
     if (dec_layer.encdec_project_kernel_output_size() !=
         _hidden_size * _hidden_size)
-      return "wrong encdec_project_kernel_output size !";
+      return "Wrong encdec_project_kernel_output size !";
     for (float ele : dec_layer.encdec_project_kernel_output())
       value.push_back(ele);
     idx += _hidden_size * _hidden_size;
 
     offset.push_back(idx);
     if (dec_layer.encdec_project_bias_output_size() != _hidden_size)
-      return "wrong encdec_project_bias_output size !";
+      return "Wrong encdec_project_bias_output size !";
     for (float ele : dec_layer.encdec_project_bias_output())
       value.push_back(ele);
     idx += _hidden_size;
 
     offset.push_back(idx);
     if (dec_layer.ffn_norm_scale_size() != _hidden_size)
-      return "wrong ffn_norm_scale_size !";
+      return "Wrong ffn_norm_scale_size !";
     for (float ele : dec_layer.ffn_norm_scale()) value.push_back(ele);
     idx += _hidden_size;
 
     offset.push_back(idx);
     if (dec_layer.ffn_norm_bias_size() != _hidden_size)
-      return "wrong ffn_norm_bias_size !";
+      return "Wrong ffn_norm_bias_size !";
     for (float ele : dec_layer.ffn_norm_bias()) value.push_back(ele);
     idx += _hidden_size;
 
     offset.push_back(idx);
     if (dec_layer.ffn_first_kernel_size() != _hidden_size * _inner_size)
-      return "wrong ffn_first_kernel_size !";
+      return "Wrong ffn_first_kernel_size !";
     for (float ele : dec_layer.ffn_first_kernel()) value.push_back(ele);
     idx += _hidden_size * _inner_size;
 
     offset.push_back(idx);
     if (dec_layer.ffn_first_bias_size() != _inner_size)
-      return "wrong ffn_first_bias_size !";
+      return "Wrong ffn_first_bias_size !";
     for (float ele : dec_layer.ffn_first_bias()) value.push_back(ele);
     idx += _inner_size;
 
     offset.push_back(idx);
     if (dec_layer.ffn_second_kernel_size() != _hidden_size * _inner_size)
-      return "wrong ffn_second_kernel_size !";
+      return "Wrong ffn_second_kernel_size !";
     for (float ele : dec_layer.ffn_second_kernel()) value.push_back(ele);
     idx += _hidden_size * _inner_size;
 
     offset.push_back(idx);
     if (dec_layer.ffn_second_bias_size() != _hidden_size)
-      return "wrong ffn_second_bias_size !";
+      return "Wrong ffn_second_bias_size !";
     for (float ele : dec_layer.ffn_second_bias()) value.push_back(ele);
     idx += _hidden_size;
 
@@ -390,7 +423,7 @@ std::string TransformerWeight<OpType_>::parse_dec_wei(
   for (int e : offset)
     _p_d_dec_wei.push_back(thrust::raw_pointer_cast(_d_dec_wei.data()) + e);
 
-  std::cout << "finish initializing dec_wei from host to device" << std::endl;
+  std::cout << "Finish loading dec_wei from host to device" << std::endl;
   return "";
 }
 
@@ -412,6 +445,10 @@ std::string TransformerWeight<OpType_>::initializing(std::string proto_path,
 
   get_model_config(transformer, only_decoder);
 
+  if (_hidden_size % 4 != 0) {
+    return "hidden_size should be a multiple of 4 to avoid misaligned "
+           "addresses in CUDA";
+  }
+
   std::string res;
   if (!only_decoder) {
     res = parse_emb_wei(transformer.src_embedding(), "src");
@@ -429,7 +466,7 @@ std::string TransformerWeight<OpType_>::initializing(std::string proto_path,
   res = parse_dec_wei(transformer);
   if (!res.empty()) return res;
 
-  std::cout << "finish initializing all weight from host to device"
+  std::cout << "Finish loading all weights from host to device"
             << std::endl;
 
   // Optional: Delete all global objects allocated by libprotobuf.
   // google::protobuf::ShutdownProtobufLibrary();
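On the new `_hidden_size % 4` guard: the fp16 kernels reinterpret embedding rows as `half2` (see `ker_multilg_enc_emb<__half>`), so row pointers must stay properly aligned; requiring a multiple of 4 keeps the vectorized paths safe in both precisions. An illustrative (non-LightSeq) kernel showing the access pattern the check protects:

```cpp
#include <cuda_fp16.h>

// half2 loads need 4-byte alignment. With an odd hidden_size, row pointers
// at emb + row * hidden_size alternate between 2- and 4-byte alignment and
// the reinterpret below would be a misaligned access.
__global__ void row_sum_half2(const __half* emb, float* out, int hidden_size) {
  const half2* row =
      reinterpret_cast<const half2*>(emb + blockIdx.x * hidden_size);
  float acc = 0.f;
  for (int i = threadIdx.x; i < hidden_size / 2; i += blockDim.x) {
    float2 v = __half22float2(row[i]);
    acc += v.x + v.y;
  }
  atomicAdd(&out[blockIdx.x], acc);
}
```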
diff --git a/proto/transformer_weight.h b/proto/transformer_weight.h
index 372fee4f..0283e529 100644
--- a/proto/transformer_weight.h
+++ b/proto/transformer_weight.h
@@ -45,6 +45,9 @@ class TransformerWeight {
   thrust::device_vector<_DataType> _d_trg_emb_wei;
   thrust::device_vector<_DataType> _d_enc_wei;
   thrust::device_vector<_DataType> _d_dec_wei;
+  thrust::device_vector<int> _d_trg_vocab_mask;
+  thrust::device_vector<_DataType> _d_src_lang_emb;
+  thrust::device_vector<_DataType> _d_trg_lang_emb;
 
  public:
   std::string initializing(std::string proto_path, bool only_decoder = false);
@@ -104,6 +107,8 @@ class TransformerWeight {
   bool _is_post_ln;
   bool _no_scale_embedding;
   bool _use_gelu;
+  bool _is_multilingual;
+  const int* _p_d_trg_vocab_mask;
 
   void print_model_config() {
     std::cout << "***model config***" << std::endl;
@@ -121,6 +126,7 @@ class TransformerWeight {
     std::cout << "start_id: " << _start_id << std::endl;
     std::cout << "end_id: " << _end_id << std::endl;
     std::cout << "padding_id: " << _padding_id << std::endl;
+    std::cout << "is_multilingual: " << _is_multilingual << std::endl;
     std::cout << std::endl;
     std::cout << "***generator config***" << std::endl;
     std::cout << "beam size: " << _beam_size << std::endl;
diff --git a/server/transformer_server.cc.cu b/server/transformer_server.cc.cu
index 9b9ba8d0..363c94ba 100644
--- a/server/transformer_server.cc.cu
+++ b/server/transformer_server.cc.cu
@@ -334,7 +334,8 @@ int Context::Init() {
   decoder_ = std::make_shared<lightseq::cuda::Decoder<OPTYPE>>(
       max_batch_size, reinterpret_cast<int*>(d_padding_mask_),
       reinterpret_cast<_optraits::DataType*>(d_encoder_output_),
-      reinterpret_cast<int*>(d_output_), tw_, stream_, hd_);
+      reinterpret_cast<int*>(d_output_), tw_, stream_, hd_,
+      false, reinterpret_cast<int*>(d_input_));
   res = decoder_->check();
   if (!res.empty()) {
     LOG_ERROR << res << std::endl;
diff --git a/setup.py b/setup.py
new file mode 100644
index 00000000..e0fee123
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,97 @@
+import os
+import re
+import sys
+import platform
+import subprocess
+
+from setuptools import setup, Extension
+import setuptools
+from setuptools.command.build_ext import build_ext
+from distutils.version import LooseVersion
+
+ENABLE_FP32 = int(os.environ.get("ENABLE_FP32", 0))
+ENABLE_DEBUG = int(os.environ.get("ENABLE_DEBUG", 0))
+
+
+class CMakeExtension(Extension):
+    def __init__(self, name, sourcedir="", *args, **kwargs):
+        Extension.__init__(self, name, sources=[], *args, **kwargs)
+        self.sourcedir = os.path.abspath(sourcedir)
+
+
+class CMakeBuild(build_ext):
+    def run(self):
+        try:
+            out = subprocess.check_output(["cmake", "--version"])
+        except OSError:
+            raise RuntimeError(
+                "CMake must be installed to build the following extensions: "
+                + ", ".join(e.name for e in self.extensions)
+            )
+
+        if platform.system() == "Windows":
+            cmake_version = LooseVersion(
+                re.search(r"version\s*([\d.]+)", out.decode()).group(1)
+            )
+            if cmake_version < LooseVersion("3.1.0"):
+                raise RuntimeError("CMake >= 3.1.0 is required on Windows")
+
+        for ext in self.extensions:
+            self.build_extension(ext)
+
+    def build_extension(self, ext):
+        extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name)))
+        # required for auto-detection of auxiliary "native" libs
+        if not extdir.endswith(os.path.sep):
+            extdir += os.path.sep
+
+        cmake_args = [
+            "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir,
+            "-DPYTHON_EXECUTABLE=" + sys.executable,
+        ]
+
+        cfg = "Release"
+        build_args = ["--config", cfg]
+
+        if platform.system() == "Windows":
+            cmake_args += [
"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{}={}".format(cfg.upper(), extdir) + ] + if sys.maxsize > 2 ** 32: + cmake_args += ["-A", "x64"] + build_args += ["--", "/m"] + else: + cmake_args += ["-DCMAKE_BUILD_TYPE=" + cfg] + if not ENABLE_FP32: + cmake_args += ["-DFP16_MODE=ON"] + if ENABLE_DEBUG: + cmake_args += ["-DDEBUG_MODE=ON"] + build_args += ["--target", "lightseq"] + build_args += ["--", "-j"] + + env = os.environ.copy() + env["CXXFLAGS"] = '{} -DVERSION_INFO=\\"{}\\"'.format( + env.get("CXXFLAGS", ""), self.distribution.get_version() + ) + if not os.path.exists(self.build_temp): + os.makedirs(self.build_temp) + subprocess.check_call( + ["cmake", ext.sourcedir] + cmake_args, cwd=self.build_temp, env=env + ) + subprocess.check_call( + ["cmake", "--build", "."] + build_args, cwd=self.build_temp + ) + + +setup( + name="lightseq", + version="0.1.0", + author="Ying Xiong", + author_email="xiongying.taka@bytedance.com", + description="python wrapper of LightSeq, LightSeq is a high performance inference library for SOTA NLU/NLG models", + long_description="", + ext_modules=[CMakeExtension("lightseq")], + cmdclass=dict(build_ext=CMakeBuild), + zip_safe=False, + packages=setuptools.find_packages(), +)