From 5f0b62cde54f59bdeac7978c9f9c12d0a4bc56db Mon Sep 17 00:00:00 2001 From: Rachel Guo <35738743+YUNQIUGUO@users.noreply.github.com> Date: Tue, 30 Jan 2024 16:51:05 -0800 Subject: [PATCH] [ORT 1.17.0 Release] Cherry-pick Final Round (#19327) ### Description Cherry-pick Final Round ### Motivation and Context --------- Co-authored-by: Adrian Lizarraga Co-authored-by: Changming Sun Co-authored-by: Chi Lo <54722500+chilo-ms@users.noreply.github.com> Co-authored-by: rachguo Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com> Co-authored-by: kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com> Co-authored-by: aciddelgado <139922440+aciddelgado@users.noreply.github.com> Co-authored-by: Yufeng Li --- ...anch.Nuget-WindowsAI-Pipeline.Official.yml | 2 + ThirdPartyNotices.txt | 207 ++++++ docs/ContribOperators.md | 16 +- docs/OperatorKernels.md | 2 +- .../contrib_ops/cpu/bert/attention_common.h | 5 + .../decoder_masked_multihead_attention.cc | 2 + .../cuda/bert/flash_attention/flash_api.cc | 51 +- .../cuda/bert/flash_attention/flash_api.h | 6 +- .../cuda/bert/group_query_attention.cc | 26 +- .../cuda/bert/group_query_attention.h | 5 + .../cuda/bert/group_query_attention_helper.h | 150 ++-- .../cuda/bert/group_query_attention_impl.cu | 125 ++-- .../cuda/bert/group_query_attention_impl.h | 2 + .../core/graph/contrib_ops/bert_defs.cc | 34 +- .../tensorrt/tensorrt_execution_provider.cc | 48 +- .../tensorrt/tensorrt_execution_provider.h | 6 +- .../transformers/models/whisper/README.md | 4 +- .../models/whisper/convert_to_onnx.py | 2 +- .../models/whisper/whisper_helper.py | 15 +- onnxruntime/test/perftest/README.md | 4 + .../test/perftest/command_args_parser.cc | 47 +- onnxruntime/test/perftest/ort_test_session.cc | 29 +- .../test/perftest/test_configuration.h | 2 + .../apple/apple_package_test/Podfile.template | 6 +- .../test/python/transformers/rotary_flash.py | 693 ++++++++++++++++++ .../python/transformers/test_flash_attn.py | 668 ++++++++++++++--- tools/ci_build/build.py | 3 +- ...ult_full_ios_framework_build_settings.json | 30 + .../github/apple/test_apple_packages.py | 13 +- .../azure-pipelines/templates/c-api-cpu.yml | 23 +- ...txt => requirements-transformers-test.txt} | 3 +- 31 files changed, 1923 insertions(+), 306 deletions(-) create mode 100644 onnxruntime/test/python/transformers/rotary_flash.py create mode 100644 tools/ci_build/github/apple/default_full_ios_framework_build_settings.json rename tools/ci_build/{requirements.txt => requirements-transformers-test.txt} (94%) diff --git a/.pipelines/OneBranch.Nuget-WindowsAI-Pipeline.Official.yml b/.pipelines/OneBranch.Nuget-WindowsAI-Pipeline.Official.yml index 67f9d8b0ce392..fd3b7266d30f7 100644 --- a/.pipelines/OneBranch.Nuget-WindowsAI-Pipeline.Official.yml +++ b/.pipelines/OneBranch.Nuget-WindowsAI-Pipeline.Official.yml @@ -29,6 +29,8 @@ extends: git: submodules: false globalSdl: # https://aka.ms/obpipelines/sdl + asyncSdl: + enabled: false tsa: enabled: true prefast: diff --git a/ThirdPartyNotices.txt b/ThirdPartyNotices.txt index 700206180decd..30894903ec8d2 100644 --- a/ThirdPartyNotices.txt +++ b/ThirdPartyNotices.txt @@ -6299,3 +6299,210 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
+ +_____ + +neural-speed + +https://github.com/intel/neural-speed + + Apache License + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. 
Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + ============================================================================ + + Copyright 2016-2019 Intel Corporation + Copyright 2018 YANDEX LLC + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + This distribution includes third party software ("third party programs"). 
+ This third party software, even if included with the distribution of
+ the Intel software, may be governed by separate license terms, including
+ without limitation, third party license terms, other Intel software license
+ terms, and open source software license terms. These separate license terms
+ govern your use of the third party programs as set forth in the
+ "THIRD-PARTY-PROGRAMS" file.
diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 22e82443167f6..fd26b09b09531 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -2398,24 +2398,28 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
 #### Attributes
 
 <dl>
+<dt><tt>do_rotary</tt> : int</dt>
+<dd>Whether to use rotary position embedding. Default value is 0.</dd>
 <dt><tt>kv_num_heads</tt> : int (required)</dt>
 <dd>Number of attention heads for k and v</dd>
 <dt><tt>local_window_size</tt> : int</dt>
 <dd>left_window_size for local attention (like Mistral). Default value is -1, meaning unused.</dd>
 <dt><tt>num_heads</tt> : int (required)</dt>
 <dd>Number of attention heads for q</dd>
+<dt><tt>rotary_interleaved</tt> : int</dt>
+<dd>Rotate using interleaved pattern. Default value is 0 (False).</dd>
 <dt><tt>scale</tt> : float</dt>
 <dd>Custom scale will be used if specified. Default value is 1/sqrt(head_size)</dd>
 </dl>
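Editorial aside, not part of the generated operator spec: the sketch below shows, with illustrative (non-normative) numbers, how the packed-QKV width `d` described in the inputs below and the rotary cache shapes used by `do_rotary` fit together.

```python
import numpy as np

# Illustrative config only (a Llama-style GQA layout; not taken from the spec).
num_heads, kv_num_heads, head_size = 32, 8, 128
max_sequence_length = 4096

# Packed QKV width: d = num_heads*head_size + 2*kv_num_heads*head_size.
d = num_heads * head_size + 2 * kv_num_heads * head_size  # 32*128 + 2*8*128 = 6144

# Standard rotary cache construction (base 10000); the resulting shapes match
# the cos_cache/sin_cache inputs, (max_sequence_length, head_size / 2).
inv_freq = 1.0 / (10000.0 ** (np.arange(0, head_size, 2) / head_size))
angles = np.outer(np.arange(max_sequence_length), inv_freq)
cos_cache = np.cos(angles).astype(np.float16)
sin_cache = np.sin(angles).astype(np.float16)
assert cos_cache.shape == (max_sequence_length, head_size // 2)
```

Note that `rotary_interleaved` only changes how these cached values are paired with head elements (adjacent pairs vs. first/second half), not the cache shapes.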
-#### Inputs
+#### Inputs (7 - 9)
 
 <dl>
 <dt><tt>query</tt> : T</dt>
-<dd>Query with shape (batch_size, sequence_length, hidden_size)</dd>
-<dt><tt>key</tt> : T</dt>
+<dd>Query with shape (batch_size, sequence_length, hidden_size), or packed QKV with shape (batch_size, sequence_length, d) where d is (num_heads * head_size + 2 * kv_num_heads * head_size).</dd>
+<dt><tt>key</tt> (optional) : T</dt>
 <dd>Key with shape (batch_size, kv_sequence_length, kv_hidden_size)</dd>
-<dt><tt>value</tt> : T</dt>
+<dt><tt>value</tt> (optional) : T</dt>
 <dd>Value with shape (batch_size, kv_sequence_length, kv_hidden_size)</dd>
 <dt><tt>past_key</tt> (optional) : T</dt>
 <dd>past state key with support for format BNSH. When past_key uses the same tensor as present_key (k-v cache), it is of length max_sequence_length... otherwise of length past_sequence_length.</dd>
@@ -2425,6 +2429,10 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
 <dd>1d Tensor of shape (batch_size). Indicates past sequence lengths for token generation case.</dd>
 <dt><tt>total_sequence_length</tt> : M</dt>
 <dd>Scalar tensor of total sequence length (past + new).</dd>
+<dt><tt>cos_cache</tt> (optional) : T</dt>
+<dd>2D tensor with shape (max_sequence_length, head_size / 2).</dd>
+<dt><tt>sin_cache</tt> (optional) : T</dt>
+<dd>2D tensor with shape (max_sequence_length, head_size / 2).</dd>
 </dl>
 
 #### Outputs
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 9ecc58bee0725..6e5d842002116 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -843,7 +843,7 @@ Do not modify directly.*
 |GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)|
 |GridSample|*in* X:**T1**<br> *in* Grid:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(float)<br> **T2** = tensor(float)|
 |GroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
-|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br> **T** = tensor(bfloat16), tensor(float16)|
+|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br> **T** = tensor(bfloat16), tensor(float16)|
 |Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
 |Irfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
 |LongformerAttention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask:**T**<br> *in* global_weight:**T**<br> *in* global_bias:**T**<br> *in* global:**G**<br>
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index da489a6901512..8afeb874750b4 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -99,10 +99,15 @@ struct GroupQueryAttentionParameters { bool is_unidirectional; // causal int local_window_size; bool kv_share_buffer; + bool is_packed_qkv; bool is_prompt; // determines if seqlens_k is past or kv sequence length tensor + bool do_rotary; + bool rotary_interleaved; float scale; AttentionQkvFormat qkv_format; AttentionQkvFormat past_kv_format; + int zeros_count; + int* zero_ptr; }; namespace attention { diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc index a9b60da0c96ca..66c0aceaed1e7 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_multihead_attention.cc @@ -74,6 +74,7 @@ Status DecoderMaskedMultiHeadAttention::ComputeInternal(OpKernelContext* parameters.kv_data_in_flight = ParseEnvironmentVariableWithDefault( attention::kDecoderMaskedAttentionLoadKVDataInFlight, false); + bool is_unidirectional = false; bool is_dmmha_packing = (key == nullptr && value == nullptr); ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckInputs(query, key, @@ -88,6 +89,7 @@ Status DecoderMaskedMultiHeadAttention::ComputeInternal(OpKernelContext* num_heads_, mask_filter_value_, scale_, + is_unidirectional, past_present_share_buffer_, is_dmmha_packing, // dmmha_packing device_prop.maxThreadsPerBlock)); diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc index d6eb87228bb4a..2c296bf4f8483 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc @@ -355,13 +355,15 @@ bool is_supported(const cudaDeviceProp& dprops, int head_size, int num_heads, in Status mha_fwd_kvcache(const cudaDeviceProp& dprops, cudaStream_t stream, void* q, // batch_size x seqlen_q x num_heads x head_size - void* kcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x head_size - void* vcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x head_size - void* k, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size - void* v, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size + void* kcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k seqlen_k_max x head_size + void* vcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k seqlen_k_max x head_size + void* k_new, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size + void* v_new, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size void* out, // batch_size x seqlen_q x num_heads x head_size void* softmax_lse, // batch_size x num_heads x seqlen_q void* seqlens_k_, // batch_size + void* rotary_cos, // seqlen_ro x (rotary_dim / 2) + void* rotary_sin, // seqlen_ro x (rotary_dim / 2) int batch_size, int num_heads, int num_heads_k, @@ -376,16 +378,15 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, int num_splits, void* softmax_lse_accum, // num_splits x batch_size 
x seqlen_q x num_heads void* out_accum, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded - int local_window_size) { - // if (seqlen_q == 1) { - // is_causal = false; - // } // causal=true is the same as causal=false in this case - + int local_window_size, + bool is_rotary_interleaved, + bool is_packed_qkv) { auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; const int head_size_rounded = round_multiple(head_size, 32); const int seqlen_q_rounded = round_multiple(seqlen_q, 128); const int seqlen_k_rounded = round_multiple(seqlen_k, 128); + // In kv-cache case, seqlen_k_max as kv sequence length Flash_fwd_params params; set_params_fprop(params, batch_size, @@ -406,15 +407,24 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, is_causal ? 0 : -1); params.dprops = &dprops; - if (k != nullptr && v != nullptr) { + if (k_new != nullptr && v_new != nullptr) { params.seqlen_knew = seqlen_k_new; - params.knew_ptr = k; - params.vnew_ptr = v; + params.knew_ptr = k_new; + params.vnew_ptr = v_new; // All stride are in elements, not bytes. - params.knew_batch_stride = seqlen_k_new * num_heads_k * head_size; - params.vnew_batch_stride = seqlen_k_new * num_heads_k * head_size; - params.knew_row_stride = num_heads_k * head_size; - params.vnew_row_stride = num_heads_k * head_size; + if (is_packed_qkv) { + params.q_batch_stride = (seqlen_q * num_heads * head_size) + (2 * seqlen_k_new * num_heads_k * head_size); + params.q_row_stride = (num_heads * head_size) + (2 * num_heads_k * head_size); + params.knew_batch_stride = (seqlen_q * num_heads * head_size) + (2 * seqlen_k_new * num_heads_k * head_size); + params.vnew_batch_stride = (seqlen_q * num_heads * head_size) + (2 * seqlen_k_new * num_heads_k * head_size); + params.knew_row_stride = (num_heads * head_size) + (2 * num_heads_k * head_size); + params.vnew_row_stride = (num_heads * head_size) + (2 * num_heads_k * head_size); + } else { + params.knew_batch_stride = seqlen_k_new * num_heads_k * head_size; + params.vnew_batch_stride = seqlen_k_new * num_heads_k * head_size; + params.knew_row_stride = num_heads_k * head_size; + params.vnew_row_stride = num_heads_k * head_size; + } params.knew_head_stride = head_size; params.vnew_head_stride = head_size; } else { @@ -434,6 +444,13 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, params.cu_seqlens_k = static_cast(seqlens_k_); } + if (rotary_cos != nullptr) { + params.rotary_cos_ptr = rotary_cos; + params.rotary_sin_ptr = rotary_sin; + params.is_rotary_interleaved = is_rotary_interleaved; + params.rotary_dim = (head_size / 16) * 16; + } + params.num_splits = num_splits; if (params.num_splits > 1 && softmax_lse_accum != nullptr && out_accum != nullptr) { params.softmax_lseaccum_ptr = softmax_lse_accum; @@ -444,7 +461,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, } // Only split kernel supports appending to KV cache - run_mha_fwd(params, stream, /*force_split_kernel=*/k != nullptr); + run_mha_fwd(params, stream, /*force_split_kernel=*/k_new != nullptr); return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h index 3d75d6834b8e0..387d1cf9d84fe 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h @@ -87,6 +87,8 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, void* out, // batch_size x seqlen_q x num_heads x head_size void* softmax_lse, // batch_size x 
num_heads x seqlen_q void* seqlens_k_, // batch_size + void* rotary_sin, // seqlen_ro x (rotary_dim / 2) + void* rotary_cos, // seqlen_ro x (rotary_dim / 2) int batch_size, int num_heads, int num_heads_k, @@ -101,7 +103,9 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, int num_splits = 0, void* softmax_lse_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads void* out_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded - int local_window_size = -1); + int local_window_size = -1, + bool is_rotary_interleaved = false, + bool is_packed_qkv = false); size_t get_softmax_lse_size(int max_seqlen_q, int batch_size, int num_heads); diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index fd6fb79742cac..fe56f84f0a886 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -47,6 +47,8 @@ GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) kv_num_heads_ = static_cast(kv_num_heads); is_past_bsnh_ = false; // info.GetAttrOrDefault("is_past_bsnh", 1) == 1; local_window_size_ = static_cast(info.GetAttrOrDefault("local_window_size", -1)); + do_rotary_ = info.GetAttrOrDefault("do_rotary", 0) == 1; + rotary_interleaved_ = info.GetAttrOrDefault("rotary_interleaved", 0) == 1; scale_ = info.GetAttrOrDefault("scale", 0.0f); #if USE_FLASH_ATTENTION @@ -62,6 +64,9 @@ GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) #else disable_memory_efficient_attention_ = true; #endif + if (!disable_flash_attention_) { + zeros_ = this->GetScratchBuffer(kZerosCount, nullptr); + } } template @@ -73,6 +78,8 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { const Tensor* past_value = context->Input(4); const Tensor* seqlens_k = context->Input(5); const Tensor* total_seqlen = context->Input(6); + const Tensor* cos_cache = context->Input(7); + const Tensor* sin_cache = context->Input(8); auto& device_prop = GetDeviceProp(); GroupQueryAttentionParameters parameters; @@ -84,6 +91,8 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { value, past_key, past_value, + cos_cache, + sin_cache, ¶meters, num_heads_, kv_num_heads_, @@ -93,7 +102,13 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { scale_, device_prop.maxThreadsPerBlock)); parameters.local_window_size = local_window_size_; + parameters.is_unidirectional = is_unidirectional_; + parameters.zeros_count = kZerosCount; + parameters.zero_ptr = zeros_.get(); + // parameters.left_padding = left_padding_; int sequence_length = parameters.sequence_length; + parameters.do_rotary = do_rotary_; + parameters.rotary_interleaved = rotary_interleaved_; TensorShapeVector output_shape(3); output_shape[0] = static_cast(parameters.batch_size); @@ -139,6 +154,8 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { !use_flash_attention && !disable_memory_efficient_attention_ && local_window_size_ == -1 && + do_rotary_ == false && + key != nullptr && (parameters.head_size & 7) == 0 && parameters.sequence_length <= parameters.seqlen_past_kv_cache + parameters.sequence_length && (sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) && @@ -182,8 +199,8 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { Tensor* present_value = context->Output(2, present_shape); 
data.query = reinterpret_cast(query->Data()); - data.key = reinterpret_cast(key->Data()); - data.value = reinterpret_cast(value->Data()); + data.key = key == nullptr ? nullptr : reinterpret_cast(key->Data()); + data.value = value == nullptr ? nullptr : reinterpret_cast(value->Data()); data.past_key = (nullptr == past_key) ? nullptr : reinterpret_cast(past_key->Data()); data.past_value = (nullptr == past_value) ? nullptr : reinterpret_cast(past_value->Data()); data.output = reinterpret_cast(output->MutableData()); @@ -229,6 +246,11 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { if (fmha_buffer != nullptr) { data.fmha_buffer = reinterpret_cast(fmha_buffer.get()); } + // Rotary + if (parameters.do_rotary) { + data.cos_cache = reinterpret_cast(cos_cache->Data()); + data.sin_cache = reinterpret_cast(sin_cache->Data()); + } cublasHandle_t cublas = GetCublasHandle(context); diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h index 54a8127e29e7b..15573ece166fc 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h @@ -23,10 +23,15 @@ class GroupQueryAttention final : public CudaKernel { int num_heads_; // number of attention heads int kv_num_heads_; // different for k and v for group query attention int local_window_size_; + bool is_unidirectional_; bool is_past_bsnh_; + bool do_rotary_; + bool rotary_interleaved_; float scale_; bool disable_flash_attention_; bool disable_memory_efficient_attention_; + static constexpr int kZerosCount = 256; // In prompt case we create a zero buffer of size 256 for seqlen (assume batch_size <= 256) + IAllocatorUniquePtr zeros_; }; } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h index 2cb9955807f26..853e1a710cb24 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h @@ -16,6 +16,8 @@ Status CheckInputs(const Tensor* query, const Tensor* value, const Tensor* past_key, const Tensor* past_value, + const Tensor* cos_cache, + const Tensor* sin_cache, void* parameters, int num_heads, int kv_num_heads, @@ -24,19 +26,18 @@ Status CheckInputs(const Tensor* query, bool is_past_bsnh, float scale) { // Note: Here S* is past_cache_sequence_length, S- is past_sequence_length, S+ is sequence_length - // past_key : (B, N_k, S*, H) or (B, N_k, S-, H) - // past_value : (B, N_k, S*, H) or (B, N_k, S-, H) + // past_key : (B, N_k, S*, H) or (B, N_k, S-, H) or nullptr + // past_value : (B, N_k, S*, H) or (B, N_k, S-, H) or nullptr // no packing for q/k/v: - // query (Q) : (B, S, D) - // key (K) : (B, S, D_kv) - // value (V) : (B, S, D_kv) + // query (Q) : (B, S, D) or (B, S, (D_q + 2 D_kv)) + // key (K) : (B, S, D_kv) or nullptr + // value (V) : (B, S, D_kv) or nullptr ORT_UNUSED_PARAMETER(value); AttentionQkvFormat qkv_format = Q_K_V_BSNH; AttentionQkvFormat past_kv_format = is_past_bsnh ? 
Q_K_V_BSNH : Q_K_V_BNSH; - + const bool is_packed_qkv = key == nullptr; const auto& query_dims = query->Shape().GetDims(); - const auto& key_dims = key->Shape().GetDims(); if (query_dims.size() != 3) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions, got ", @@ -46,10 +47,69 @@ Status CheckInputs(const Tensor* query, int batch_size = static_cast(query_dims[0]); int sequence_length = static_cast(query_dims[1]); int q_hidden_size = static_cast(query_dims[2]); - int head_size = static_cast(q_hidden_size) / num_heads; + int head_size = 0; + + if (num_heads % kv_num_heads != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "num_heads must be a multiple of kv_num_heads. Got num_heads % kv_num_heads == ", + num_heads % kv_num_heads); + } - int kv_hidden_size = static_cast(key_dims[2]); + int kv_hidden_size = 0; + // Check key and value when not packed + if (!is_packed_qkv) { + head_size = static_cast(q_hidden_size) / num_heads; + if (head_size % 8 != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "head_size must be a multiple of 8. Got head_size % 8 == ", + head_size % 8); + } + if (value == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv."); + } + const auto& key_dims = key->Shape().GetDims(); + if (key_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ", + key_dims.size()); + } else if (query_dims[0] != key_dims[0]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query' and 'key' shall have same dim 0 (batch size)"); + } else if (query_dims[1] != key_dims[1]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query' and 'key' shall have same dim 1 (sequence length)"); + } + kv_hidden_size = static_cast(key_dims[2]); + const auto& value_dims = value->Shape().GetDims(); + if (value_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ", + value_dims.size()); + } else if (query_dims[0] != value_dims[0]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query' and 'value' shall have same dim 0 (batch size)"); + } else if (query_dims[1] != value_dims[1]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query' and 'value' shall have same dim 1 (sequence length)"); + } else if (value_dims[2] != kv_hidden_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have same hidden size as key."); + } + } else { + // Check packed qkv + head_size = static_cast(q_hidden_size) / (num_heads + 2 * kv_num_heads); + if (head_size % 8 != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "head_size must be a multiple of 8. 
Got head_size % 8 == ", + head_size % 8); + } + if (value != nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv."); + } + q_hidden_size = head_size * num_heads; + kv_hidden_size = head_size * kv_num_heads; + } + // Check past-present KV int32_t past_sequence_length = 0; if (past_key != nullptr && past_value != nullptr) { const auto& past_key_dims = past_key->Shape().GetDims(); @@ -130,41 +190,6 @@ Status CheckInputs(const Tensor* query, "Input 'past_key' and 'past_value' shall be both present or both absent."); } - if (key_dims.size() != 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ", - key_dims.size()); - } - if (query_dims[0] != key_dims[0]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'key' shall have same dim 0 (batch size)"); - } - - if (num_heads % kv_num_heads != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "num_heads must be a multiple of kv_num_heads. Got num_heads % kv_num_heads == ", - num_heads % kv_num_heads); - } - - const auto& value_dims = value->Shape().GetDims(); - if (value_dims.size() != 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ", - value_dims.size()); - } - - if (query_dims[0] != value_dims[0]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'value' shall have same dim 0 (batch_size)"); - } - - if (static_cast(sequence_length) != value_dims[1]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query,' 'key,' and 'value' shall have the same dim 1 (sequence_length)"); - } - - if (value_dims[2] != kv_hidden_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have same hidden size as key."); - } - // Check seqlens_k tensor (holding past seqlen for token gen) const auto& seqlens_dim = seqlens_k->Shape().GetDims(); if (seqlens_dim.size() != 1 && seqlens_dim[0] != batch_size) { @@ -180,6 +205,36 @@ Status CheckInputs(const Tensor* query, int total_sequence_length = *((*total_seqlen).template Data()); int present_sequence_length = std::max(total_sequence_length, past_sequence_length); + if (cos_cache != nullptr && sin_cache != nullptr) { + const auto& cos_dims = cos_cache->Shape().GetDims(); + const auto& sin_dims = sin_cache->Shape().GetDims(); + + if (head_size % 16 != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "head_size shall be a multiple of 16. 
Got head_size % 16 == ", + head_size % 16); + } + if (cos_dims[0] != present_sequence_length) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "cos_cache dimension 0 must be of present_sequence_length."); + } + if (sin_dims[0] != present_sequence_length) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "sin_cache dimension 0 must be of present_sequence_length."); + } + if (cos_dims[1] != (head_size / 16) * 8) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "cos_cache dimension 1 must be <= head_size / 2 and a multiple of 8."); + } + if (sin_dims[1] != (head_size / 16) * 8) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "sin_cache dimension 1 must be <= head_size / 2 and a multiple of 8."); + } + } else if (cos_cache != nullptr || sin_cache != nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'cos_cache' and 'sin_cache' shall be both present or both absent."); + } + bool is_prompt = sequence_length != 1; if (parameters != nullptr) { @@ -190,9 +245,10 @@ Status CheckInputs(const Tensor* query, output_parameters->seqlen_present_kv_cache = present_sequence_length; // max sequence length of present kv tensors output_parameters->hidden_size = q_hidden_size; output_parameters->num_heads = num_heads; - output_parameters->head_size = q_hidden_size / num_heads; + output_parameters->head_size = head_size; output_parameters->kv_hidden_size = kv_hidden_size; output_parameters->kv_num_heads = kv_num_heads; + output_parameters->is_packed_qkv = is_packed_qkv; output_parameters->is_unidirectional = true; output_parameters->is_prompt = is_prompt; output_parameters->scale = scale; @@ -208,6 +264,8 @@ Status CheckInputs(const Tensor* query, const Tensor* value, const Tensor* past_key, const Tensor* past_value, + const Tensor* cos_cache, + const Tensor* sin_cache, void* parameters, int num_heads, int kv_num_heads, @@ -220,7 +278,7 @@ Status CheckInputs(const Tensor* query, return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads should be no larger than ", max_threads_per_block); } - return CheckInputs(query, key, value, past_key, past_value, parameters, num_heads, kv_num_heads, seqlens_k, total_seqlen, is_past_bsnh, scale); + return CheckInputs(query, key, value, past_key, past_value, cos_cache, sin_cache, parameters, num_heads, kv_num_heads, seqlens_k, total_seqlen, is_past_bsnh, scale); } } // namespace group_query_attention_helper diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index 5b0f5d0cfe601..d88e9a49fb5ee 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -151,9 +151,10 @@ template Status LaunchConcatNewToPastKV(contrib::GroupQueryAttentionParameters& parameters, GroupQueryAttentionData& data, cudaStream_t stream, - const int max_threads_per_block) { + const int max_threads_per_block, + const bool past_only = false) { const int batch_size = parameters.batch_size; - const int kv_sequence_length = parameters.sequence_length; + const int kv_sequence_length = past_only ? 
0 : parameters.sequence_length; const int past_sequence_length = parameters.seqlen_past_kv_cache; const int present_sequence_length = parameters.seqlen_present_kv_cache; const int kv_num_heads = parameters.kv_num_heads; @@ -441,7 +442,6 @@ Status LaunchUngroup(contrib::GroupQueryAttentionParameters& parameters, return CUDA_CALL(cudaGetLastError()); } - __global__ void PastToTotalSeqlen(int32_t* seqlens_k, int32_t* seqlens_k_buff, const int add_seqlen) { @@ -451,7 +451,7 @@ __global__ void PastToTotalSeqlen(int32_t* seqlens_k, // Convert Past to Total sequence length tensor Status LaunchGetSeqlenBuff(contrib::GroupQueryAttentionParameters& parameters, int32_t* seqlens_k, int32_t* seqlens_k_buff, bool is_total, cudaStream_t stream, - const int threads_per_block) { + const int threads_per_block) { if (parameters.is_prompt) { return Status::OK(); } @@ -482,91 +482,63 @@ Status FlashAttention( const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; const int kv_sequence_length = parameters.sequence_length; - const int present_sequence_length = parameters.seqlen_present_kv_cache; const int num_heads = parameters.num_heads; const int kv_num_heads = parameters.kv_num_heads; const int head_size = parameters.head_size; AttentionQkvFormat past_kv_format = parameters.past_kv_format; - - void* query = reinterpret_cast(const_cast(data.query)); - void* key = reinterpret_cast(const_cast(data.key)); - void* value = reinterpret_cast(const_cast(data.value)); - bool is_causal = true; - bool is_bf16 = std::is_same::value; - // Note: seqlens_k is past sequence length for flash - if (parameters.is_prompt) { - // Launch kernel to copy seqlen - constexpr int thr_per_blk = 256; - int blk_in_grid = (batch_size + thr_per_blk -1) / thr_per_blk; - repeat_seqlen<<>>(data.seqlens_k_total, parameters.sequence_length, batch_size); - } - - void* seqlens_k = reinterpret_cast(data.seqlens_k); - - if (parameters.kv_share_buffer) { - // Share buffer case - if (data.past_key == nullptr || data.past_key != data.present_key) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Past and present kv shall share the same tensor when kv_share_buffer is on."); - } - - if (parameters.is_prompt) { - ORT_RETURN_IF_ERROR(LaunchConcatKVInPlace(parameters, data, stream, max_threads_per_block)); - key = nullptr; - value = nullptr; - seqlens_k = reinterpret_cast(data.seqlens_k_total); - } - - void* present_key = reinterpret_cast(const_cast(data.present_key)); - void* present_value = reinterpret_cast(const_cast(data.present_value)); - - DUMP_TENSOR_INIT(); - DUMP_TENSOR("seqlens_k", reinterpret_cast(seqlens_k), batch_size, 1); + void* query = reinterpret_cast(const_cast(data.query)); + void* key; + void* value; - bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; - ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache( - device_prop, stream, query, present_key, present_value, key, value, data.output, reinterpret_cast(data.softmax_lse), - seqlens_k, batch_size, num_heads, kv_num_heads, - head_size, sequence_length, present_sequence_length, kv_sequence_length, - scale, is_causal, is_bf16, past_bsnh, parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), - reinterpret_cast(data.out_accum), parameters.local_window_size)); + if (!parameters.is_packed_qkv) { + key = reinterpret_cast(const_cast(data.key)); + value = reinterpret_cast(const_cast(data.value)); } else { - // Not share buffer case - // Note that Flash Attention kv-caching operates in place on a buffer... 
therefore this path is inneficient - if (data.past_key != nullptr && data.past_key == data.present_key) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Past and present kv share the same tensor but kv_share_buffer is not on."); - } - - ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block)); + const size_t key_offset = static_cast(num_heads * head_size); + const size_t value_offset = static_cast(kv_num_heads * head_size); + key = reinterpret_cast(query) + key_offset; + value = reinterpret_cast(key) + value_offset; + } - if (!parameters.is_prompt) { - ORT_RETURN_IF_ERROR(LaunchGetSeqlenBuff(parameters, data.seqlens_k, data.seqlens_k_total, true, stream, 256)); + void* seqlens_k = reinterpret_cast(data.seqlens_k); + if (parameters.is_prompt) { + // set seqlens_k to zeros... flash api uses seqlens_k to indicate where to append key and value + // user should use seqlens_k to index into output to get new tokens + if (batch_size <= parameters.zeros_count) { + seqlens_k = parameters.zero_ptr; + } else { + // Launch kernel to create larger seqlen tensor when batch_size > 256 + constexpr int thr_per_blk = 256; + int blk_in_grid = (batch_size + thr_per_blk - 1) / thr_per_blk; + repeat_seqlen<<>>(data.seqlens_k_total, 0, batch_size); + seqlens_k = data.seqlens_k_total; } - - seqlens_k = reinterpret_cast(data.seqlens_k_total); - - void* present_key = reinterpret_cast(const_cast(data.present_key)); - void* present_value = reinterpret_cast(const_cast(data.present_value)); - - DUMP_TENSOR_INIT(); - DUMP_TENSOR("seqlens_k", reinterpret_cast(seqlens_k), batch_size, 1); - DUMP_TENSOR("Q", data.query, batch_size, sequence_length, num_heads, head_size); - DUMP_TENSOR("K", data.present_key, batch_size, kv_num_heads, present_sequence_length, head_size); - DUMP_TENSOR("V", data.present_value, batch_size, kv_num_heads, present_sequence_length, head_size); - - bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; - ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache( - device_prop, stream, query, present_key, present_value, nullptr, nullptr, data.output, reinterpret_cast(data.softmax_lse), - seqlens_k, batch_size, num_heads, kv_num_heads, - head_size, sequence_length, present_sequence_length, 0, - scale, is_causal, is_bf16, past_bsnh, parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), - reinterpret_cast(data.out_accum), parameters.local_window_size)); + } else if (!parameters.kv_share_buffer) { // copy past kv to present kv + ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block, true)); } + void* present_key = reinterpret_cast(const_cast(data.present_key)); + void* present_value = reinterpret_cast(const_cast(data.present_value)); + void* cos_cache = reinterpret_cast(const_cast(data.cos_cache)); + void* sin_cache = reinterpret_cast(const_cast(data.sin_cache)); + + bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; + ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache( + device_prop, stream, query, present_key, present_value, key, value, data.output, + reinterpret_cast(data.softmax_lse), seqlens_k, cos_cache, sin_cache, + batch_size, num_heads, kv_num_heads, head_size, sequence_length, + parameters.seqlen_present_kv_cache, kv_sequence_length, + scale, is_causal, is_bf16, past_bsnh, parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), + reinterpret_cast(data.out_accum), parameters.local_window_size, parameters.rotary_interleaved, + parameters.is_packed_qkv)); + + // 
if (parameters.left_padding && parameters.is_prompt) { + // ORT_RETURN_IF_ERROR(LaunchLeftPadLast(parameters, data, stream, device_prop.maxThreadsPerBlock)); + // } + DUMP_TENSOR_INIT(); DUMP_TENSOR("flash attention output", data.output, batch_size, sequence_length, num_heads, head_size); @@ -672,7 +644,6 @@ Status EfficientAttention( p.has_custom_right_padding = true; run_memory_efficient_attention(p); - DUMP_TENSOR_INIT(); DUMP_TENSOR("efficient attention output", data.output, batch_size, sequence_length, num_heads, head_size); return Status::OK(); diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h index de32d7ea93163..1bf91f9c875eb 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h @@ -21,6 +21,8 @@ struct GroupQueryAttentionData { const T* past_key = nullptr; const T* past_value = nullptr; int* seqlens_k = nullptr; + const T* cos_cache = nullptr; + const T* sin_cache = nullptr; // Flash buffers T* softmax_lse = nullptr; T* softmax_lse_accum = nullptr; diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 7f34647f1faef..8583474a1e391 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -259,13 +259,13 @@ void GroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& *output_shape.add_dim() = query_dims[1]; *output_shape.add_dim() = query_dims[2]; updateOutputShape(ctx, 0, output_shape); - } else { - fail_shape_inference("Missing input 2 (value)"); } } if (ctx.getNumOutputs() > 1) { // has present output if (hasInputShape(ctx, past_key_index)) { + // auto& query_shape = getInputShape(ctx, 0); + // auto& query_dims = query_shape.dim(); auto& past_shape = getInputShape(ctx, past_key_index); auto& past_dims = past_shape.dim(); if (past_dims.size() != 4) { @@ -273,8 +273,7 @@ void GroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& } ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, past_key_index, 1); ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, static_cast(past_key_index) + 1, 2); - ONNX_NAMESPACE::propagateShapeFromInputToOutput(ctx, past_key_index, 1); - ONNX_NAMESPACE::propagateShapeFromInputToOutput(ctx, static_cast(past_key_index) + 1, 2); + // TODO(aciddelgado): propagate output shapes depending if kv-share buffer is on or not } } } @@ -1015,18 +1014,29 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "left_window_size for local attention (like Mistral). Default value is -1 meaning unused.", AttributeProto::INT, static_cast(-1)) + .Attr("do_rotary", + "Whether to use rotary position embedding. Default value is 0.", + AttributeProto::INT, + OPTIONAL_VALUE) + .Attr("rotary_interleaved", + "Rotate using interleaved pattern. 
Default value is 0 (False).", + AttributeProto::INT, + OPTIONAL_VALUE) .Input(0, "query", - "Query with shape (batch_size, sequence_length, hidden_size)", + "Query with shape (batch_size, sequence_length, hidden_size), or packed QKV with shape" + "(batch_size, sequence_length, d) where d is (num_heads * head_size + 2 * kv_num_heads * head_size).", "T") .Input(1, "key", "Key with shape (batch_size, kv_sequence_length, kv_hidden_size) ", - "T") + "T", + OpSchema::Optional) .Input(2, "value", "Value with shape (batch_size, kv_sequence_length, kv_hidden_size)", - "T") + "T", + OpSchema::Optional) .Input(3, "past_key", "past state key with support for format BNSH. When past_key uses same tensor as present_key" @@ -1047,6 +1057,16 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "total_sequence_length", "Scalar tensor of total sequence length (past + new).", "M") + .Input(7, + "cos_cache", + "2D tensor with shape (max_sequence_length, head_size / 2).", + "T", + OpSchema::Optional) + .Input(8, + "sin_cache", + "2D tensor with shape (max_sequence_length, head_size / 2).", + "T", + OpSchema::Optional) .Output(0, "output", "3D output tensor with shape (batch_size, sequence_length, hidden_size)", diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 39e5f5be000e5..1c9340dfd0e6d 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -1684,6 +1684,16 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv } } + // cuda graph: + // cudaStreamSynchronize() is not allowed in cuda graph capture. + // + // external stream: + // If user provides "external" cuda stream, only this cuda stream will be used even if multiple threads are running InferenceSession.Run() concurrently. + // So, no need to synchronize different streams after enqueueV3. 
+ if (cuda_graph_enable_ || external_stream_) { + sync_stream_after_enqueue_ = false; + } + { auto lock = GetApiLock(); runtime_ = std::unique_ptr(nvinfer1::createInferRuntime(GetTensorrtLogger())); @@ -2529,7 +2539,6 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, } else if (number_of_trt_nodes == number_of_ort_nodes) { LOGS_DEFAULT(INFO) << "[TensorRT EP] Whole graph will run on TensorRT execution provider"; } else { - sync_stream_after_enqueue_ = true; LOGS_DEFAULT(INFO) << "[TensorRT EP] Graph is partitioned and number of subgraphs running on TensorRT execution provider is " << number_of_subgraphs; } @@ -3131,7 +3140,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView *p = {context->allocate_func, context->release_func, context->allocator_handle, context->node_name, builder_.get(), &parsers_[context->node_name], &engines_[context->node_name], &contexts_[context->node_name], &networks_[context->node_name], input_info_[context->node_name], output_info_[context->node_name], - input_shape_ranges_[context->node_name], sync_stream_after_enqueue_, &tensorrt_mu_, fp16_enable_, int8_enable_, int8_calibration_cache_available_, + input_shape_ranges_[context->node_name], &tensorrt_mu_, fp16_enable_, int8_enable_, int8_calibration_cache_available_, dla_enable_, dla_core_, &max_workspace_size_, trt_node_name_with_precision, engine_cache_enable_, cache_path_, runtime_.get(), profiles_[context->node_name], context_memory_sharing_enable_, &max_ctx_mem_size_, dynamic_range_map, engine_decryption_enable_, engine_decryption_, engine_encryption_, timing_cache_enable_, @@ -3159,7 +3168,6 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView const std::unordered_map& input_indexes = (trt_state->input_info)[0]; const std::unordered_map& output_indexes = (trt_state->output_info)[0]; const std::unordered_map& output_types = (trt_state->output_info)[1]; - bool sync_stream_after_enqueue = trt_state->sync_stream_after_enqueue; auto fused_node_name = trt_state->fused_node_name; auto& shape_ranges = trt_state->input_shape_ranges; auto& dds_output_allocator_map = this->dds_output_allocator_maps_[fused_node_name]; @@ -3552,7 +3560,21 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "TensorRT EP execution context enqueue failed."); } - if (sync_stream_after_enqueue || dds_output_set.size() > 0) { + /* + * Given that InferenceSession::Run() is guaranteed to be thread-safe meaning multiple threads can call this function concurrently, + * TRT EP needs to carefully take care of concurrency here, if not, following concurrent issue might happen: + * + * It's suggested that to perform inference concurrently in multiple streams, use one trt execution context per stream. + * In the design of TRT EP (Not apply per-thread context implementation) and if multiple threads are calling InferenceSession::Run() concurrently, + * the trt execution context instance is shared by all the threads and each thread aquires different stream from ORT. + * So TRT EP will end up having one trt execution context using multiple streams which is not suggested. + * But, since the whole compute_func() is protected by the lock and if cudaStreamSynchronize() is enforced here, one trt execution context per stream + * is guaranteed. 
+ * + * Therefore, TRT EP needs to call cudaStreamSynchronize() which means to wait until stream has completed all operations to prevent the concurrent issue mentioned above. + * However, if cuda graph is enabled, TRT EP won't call cudaStreamSynchronize() since it's not allowed during graph capture. + */ + if (sync_stream_after_enqueue_) { CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream)); } @@ -3696,7 +3718,6 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con &contexts_[context->node_name], input_info_[context->node_name], output_info_[context->node_name], - sync_stream_after_enqueue_, context_memory_sharing_enable_, &max_ctx_mem_size_, &tensorrt_mu_}; @@ -3723,7 +3744,6 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con const std::unordered_map& output_indexes = (trt_state->output_info)[0]; const std::unordered_map& output_types = (trt_state->output_info)[1]; auto fused_node_name = trt_state->fused_node_name; - bool sync_stream_after_enqueue = trt_state->sync_stream_after_enqueue; auto& dds_output_allocator_map = this->dds_output_allocator_maps_[fused_node_name]; auto trt_engine = trt_state->engine->get(); auto trt_context = trt_state->context->get(); @@ -3833,7 +3853,21 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "TensorRT EP execution context enqueue failed."); } - if (sync_stream_after_enqueue || dds_output_set.size() > 0) { + /* + * Given that InferenceSession::Run() is guaranteed to be thread-safe meaning multiple threads can call this function concurrently, + * TRT EP needs to carefully take care of concurrency here, if not, following concurrent issue might happen: + * + * It's suggested that to perform inference concurrently in multiple streams, use one trt execution context per stream. + * In the design of TRT EP (Not apply per-thread context implementation) and if multiple threads are calling InferenceSession::Run() concurrently, + * the trt execution context instance is shared by all the threads and each thread aquires different stream from ORT. + * So TRT EP will end up having one trt execution context using multiple streams which is not suggested. + * But, since the whole compute_func() is protected by the lock and if cudaStreamSynchronize() is enforced here, one trt execution context per stream + * is guaranteed. + * + * Therefore, TRT EP needs to call cudaStreamSynchronize() which means to wait until stream has completed all operations to prevent the concurrent issue mentioned above. + * However, if cuda graph is enabled, TRT EP won't call cudaStreamSynchronize() since it's not allowed during graph capture. 
+ */ + if (sync_stream_after_enqueue_) { CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream)); } diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index ad2d2c55c67e1..e86f997b6597a 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -149,7 +149,6 @@ struct TensorrtFuncState { std::vector> input_info; std::vector> output_info; std::unordered_map>>> input_shape_ranges; - bool sync_stream_after_enqueue = false; OrtMutex* tensorrt_mu_ptr = nullptr; bool fp16_enable = false; bool int8_enable = false; @@ -193,7 +192,6 @@ struct TensorrtShortFuncState { std::unique_ptr* context = nullptr; std::vector> input_info; std::vector> output_info; - bool sync_stream_after_enqueue = false; bool context_memory_sharing_enable = false; size_t* max_context_mem_size_ptr = nullptr; OrtMutex* tensorrt_mu_ptr = nullptr; @@ -335,8 +333,8 @@ class TensorrtExecutionProvider : public IExecutionProvider { cudnnHandle_t external_cudnn_handle_ = nullptr; cublasHandle_t external_cublas_handle_ = nullptr; - // Call cudaStreamSynchronize() after TRT enqueueV2()/enqueueV3() - mutable bool sync_stream_after_enqueue_ = false; + // Call cudaStreamSynchronize() after TRT enqueueV3() + mutable bool sync_stream_after_enqueue_ = true; CUDAGraph cuda_graph_; bool is_graph_captured_ = false; diff --git a/onnxruntime/python/tools/transformers/models/whisper/README.md b/onnxruntime/python/tools/transformers/models/whisper/README.md index 8ff5c8a6e1de0..02100266200f8 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/README.md +++ b/onnxruntime/python/tools/transformers/models/whisper/README.md @@ -60,10 +60,10 @@ $ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/w Export + Optimize for FP16 and GPU ``` # From source: -$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda +$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda --disable_auto_mixed_precision # From wheel: -$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda +$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda --disable_auto_mixed_precision ``` Export + Quantize for INT8 diff --git a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py index 50637b772c233..e15a12c07bed7 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py @@ -478,7 +478,7 @@ def main(argv=None): # Wrap parity check in try-except to allow export to continue in case this produces an error try: with torch.no_grad(): - max_diff = WhisperHelper.verify_onnx(args.model_name_or_path, ort_session, device) + max_diff = WhisperHelper.verify_onnx(args.model_name_or_path, cache_dir, ort_session, device) if max_diff > 1e-4: logger.warning("PyTorch and ONNX 
Runtime results are NOT close")
             else:
diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py
index 8c22cd5e745b3..a4bef1f06b4fe 100644
--- a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py
+++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py
@@ -12,7 +12,9 @@

 import numpy as np
 import torch
+from packaging import version
 from transformers import WhisperConfig, WhisperForConditionalGeneration, WhisperProcessor
+from transformers import __version__ as transformers_version
 from whisper_decoder import WhisperDecoder, WhisperDecoderHelper, WhisperDecoderInit
 from whisper_encoder import WhisperEncoder, WhisperEncoderHelper
 from whisper_encoder_decoder_init import WhisperEncoderDecoderInit, WhisperEncoderDecoderInitHelper
@@ -88,7 +90,10 @@ def load_model(
         Returns:
             Dict[str, torch.nn.Module]: mapping from name to modules for ONNX conversion.
         """
-        model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path, cache_dir=cache_dir)
+        extra_kwargs = {}
+        if version.parse(transformers_version) >= version.parse("4.36.0"):
+            extra_kwargs["attn_implementation"] = "eager"
+        model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path, cache_dir=cache_dir, **extra_kwargs)

         if state_dict_path:
             model.load_state_dict(torch.load(state_dict_path), strict=False)
@@ -262,11 +267,17 @@ def optimize_onnx(
     @staticmethod
     def verify_onnx(
         model_name_or_path: str,
+        cache_dir: str,
         ort_session: InferenceSession,
         device: torch.device,
     ):
         """Compare the result from PyTorch and ONNX Runtime to verify the ONNX model is good."""
-        pt_model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path).to(device)
+        extra_kwargs = {}
+        if version.parse(transformers_version) >= version.parse("4.36.0"):
+            extra_kwargs["attn_implementation"] = "eager"
+        pt_model = WhisperForConditionalGeneration.from_pretrained(
+            model_name_or_path, cache_dir=cache_dir, **extra_kwargs
+        ).to(device)
         processor = WhisperProcessor.from_pretrained(model_name_or_path)
         config = WhisperConfig.from_pretrained(model_name_or_path)
diff --git a/onnxruntime/test/perftest/README.md b/onnxruntime/test/perftest/README.md
index 59059cf6b62b7..4169d1bf54c65 100644
--- a/onnxruntime/test/perftest/README.md
+++ b/onnxruntime/test/perftest/README.md
@@ -35,6 +35,10 @@ Options:
         -x: [intra_op_num_threads]: Sets the number of threads used to parallelize the execution within nodes. A value of 0 means the test will auto-select a default. Must >=0.

         -y: [inter_op_num_threads]: Sets the number of threads used to parallelize the execution of the graph (across nodes), A value of 0 means the test will auto-select a default. Must >=0.
+
+        -C: [session_config_entries]: Specify session configuration entries as key-value pairs: -C "<key1>|<value1> <key2>|<value2>"
+            Refer to onnxruntime_session_options_config_keys.h for valid keys and values.
+            [Example] -C "session.disable_cpu_ep_fallback|1 ep.context_enable|1"

         -h: help.
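For reference, each `-C "key|value"` pair maps to one session configuration entry on the session options. A minimal sketch of the equivalent calls through the onnxruntime Python API (the model path here is a placeholder, not part of this patch):

```python
import onnxruntime as ort

sess_options = ort.SessionOptions()
# Equivalent of: onnxruntime_perf_test -C "session.disable_cpu_ep_fallback|1 ep.context_enable|1" ...
sess_options.add_session_config_entry("session.disable_cpu_ep_fallback", "1")
sess_options.add_session_config_entry("ep.context_enable", "1")

# "model.onnx" is a placeholder model path.
session = ort.InferenceSession("model.onnx", sess_options)
```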
diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc
index 6c1d447c7b3a3..7cfbe0a84e3e6 100644
--- a/onnxruntime/test/perftest/command_args_parser.cc
+++ b/onnxruntime/test/perftest/command_args_parser.cc
@@ -6,6 +6,9 @@

 #include <string.h>
 #include <iostream>
+#include <sstream>
+#include <string_view>
+#include <unordered_map>

 // Windows Specific
 #ifdef _WIN32
@@ -57,6 +60,9 @@ namespace perftest {
       "\t-d [CUDA only][cudnn_conv_algorithm]: Specify CUDNN convolution algorithms: 0(benchmark), 1(heuristic), 2(default). \n"
       "\t-q [CUDA only] use separate stream for copy. \n"
       "\t-z: Set denormal as zero. When turning on this option reduces latency dramatically, a model may have denormals.\n"
+      "\t-C: Specify session configuration entries as key-value pairs: -C \"<key1>|<value1> <key2>|<value2>\" \n"
+      "\t    Refer to onnxruntime_session_options_config_keys.h for valid keys and values. \n"
+      "\t    [Example] -C \"session.disable_cpu_ep_fallback|1 ep.context_enable|1\" \n"
       "\t-i: Specify EP specific runtime options as key value pairs. Different runtime options available are: \n"
       "\t    [DML only] [performance_preference]: DML device performance preference, options: 'default', 'minimum_power', 'high_performance', \n"
       "\t    [DML only] [device_filter]: DML device filter, options: 'any', 'gpu', 'npu', \n"
@@ -149,9 +155,42 @@ static bool ParseDimensionOverride(std::basic_string<ORTCHAR_T>& dim_identifier,
   return true;
 }

+static bool ParseSessionConfigs(const std::string& configs_string,
+                                std::unordered_map<std::string, std::string>& session_configs) {
+  std::istringstream ss(configs_string);
+  std::string token;
+
+  while (ss >> token) {
+    if (token == "") {
+      continue;
+    }
+
+    std::string_view token_sv(token);
+
+    auto pos = token_sv.find("|");
+    if (pos == std::string_view::npos || pos == 0 || pos == token_sv.length() - 1) {
+      // Error: must use a '|' to separate the key and value for session configuration entries.
+      return false;
+    }
+
+    std::string key(token_sv.substr(0, pos));
+    std::string value(token_sv.substr(pos + 1));
+
+    auto it = session_configs.find(key);
+    if (it != session_configs.end()) {
+      // Error: specified duplicate session configuration entry: {key}
+      return false;
+    }
+
+    session_configs.insert(std::make_pair(std::move(key), std::move(value)));
+  }
+
+  return true;
+}
+
 /*static*/ bool CommandLineParser::ParseArguments(PerformanceTestConfig& test_config, int argc, ORTCHAR_T* argv[]) {
   int ch;
-  while ((ch = getopt(argc, argv, ORT_TSTR("b:m:e:r:t:p:x:y:c:d:o:u:i:f:F:S:T:AMPIDZvhsqz"))) != -1) {
+  while ((ch = getopt(argc, argv, ORT_TSTR("b:m:e:r:t:p:x:y:c:d:o:u:i:f:F:S:T:C:AMPIDZvhsqz"))) != -1) {
     switch (ch) {
       case 'f': {
         std::basic_string<ORTCHAR_T> dim_name;
@@ -322,6 +361,12 @@ static bool ParseDimensionOverride(std::basic_string<ORTCHAR_T>& dim_identifier,
       case 'T':
         test_config.run_config.intra_op_thread_affinities = ToUTF8String(optarg);
         break;
+      case 'C': {
+        if (!ParseSessionConfigs(ToUTF8String(optarg), test_config.run_config.session_config_entries)) {
+          return false;
+        }
+        break;
+      }
       case 'D':
         test_config.run_config.disable_spinning = true;
         break;
diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc
index 6854a2649060a..87506c7240578 100644
--- a/onnxruntime/test/perftest/ort_test_session.cc
+++ b/onnxruntime/test/perftest/ort_test_session.cc
@@ -634,22 +634,41 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'.
\n)"); session_options.DisableMemPattern(); session_options.SetExecutionMode(performance_test_config.run_config.execution_mode); + // Set any extra session configuration entries provided by the user via command-line arguments. + // + // Some session config entries can also be set via dedicated command-line options. + // If the user uses multiple command-line options to set the same session config entry, + // we'll print a warning. Note that the dedicated command-line options will take precedence. + const auto& user_session_configs = performance_test_config.run_config.session_config_entries; + for (auto& it : user_session_configs) { + session_options.AddConfigEntry(it.first.c_str(), it.second.c_str()); + } + + auto warn_dup_config_entry = [&user_session_configs](const char* key) -> void { + if (user_session_configs.find(key) != user_session_configs.end()) { + fprintf(stderr, "[WARNING]: Trying to set session config entry '%s' via multiple command-line options\n", key); + } + }; + if (performance_test_config.run_config.intra_op_num_threads > 0) { fprintf(stdout, "Setting intra_op_num_threads to %d\n", performance_test_config.run_config.intra_op_num_threads); session_options.SetIntraOpNumThreads(performance_test_config.run_config.intra_op_num_threads); } if (!performance_test_config.run_config.intra_op_thread_affinities.empty()) { + warn_dup_config_entry(kOrtSessionOptionsConfigIntraOpThreadAffinities); fprintf(stdout, "Setting intra op thread affinity as %s\n", performance_test_config.run_config.intra_op_thread_affinities.c_str()); session_options.AddConfigEntry(kOrtSessionOptionsConfigIntraOpThreadAffinities, performance_test_config.run_config.intra_op_thread_affinities.c_str()); } if (performance_test_config.run_config.disable_spinning) { + warn_dup_config_entry(kOrtSessionOptionsConfigAllowIntraOpSpinning); fprintf(stdout, "Disabling intra-op thread spinning entirely\n"); session_options.AddConfigEntry(kOrtSessionOptionsConfigAllowIntraOpSpinning, "0"); } if (performance_test_config.run_config.disable_spinning_between_run) { + warn_dup_config_entry(kOrtSessionOptionsConfigForceSpinningStop); fprintf(stdout, "Disabling intra-op thread spinning between runs\n"); session_options.AddConfigEntry(kOrtSessionOptionsConfigForceSpinningStop, "1"); } @@ -661,12 +680,16 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); // Set optimization level. 
session_options.SetGraphOptimizationLevel(performance_test_config.run_config.optimization_level); - if (!performance_test_config.run_config.profile_file.empty()) + if (!performance_test_config.run_config.profile_file.empty()) { session_options.EnableProfiling(performance_test_config.run_config.profile_file.c_str()); - if (!performance_test_config.run_config.optimized_model_path.empty()) + } + if (!performance_test_config.run_config.optimized_model_path.empty()) { session_options.SetOptimizedModelFilePath(performance_test_config.run_config.optimized_model_path.c_str()); - if (performance_test_config.run_config.set_denormal_as_zero) + } + if (performance_test_config.run_config.set_denormal_as_zero) { + warn_dup_config_entry(kOrtSessionOptionsConfigSetDenormalAsZero); session_options.AddConfigEntry(kOrtSessionOptionsConfigSetDenormalAsZero, "1"); + } if (!performance_test_config.run_config.free_dim_name_overrides.empty()) { for (auto const& dim_override : performance_test_config.run_config.free_dim_name_overrides) { if (g_ort->AddFreeDimensionOverrideByName(session_options, ToUTF8String(dim_override.first).c_str(), dim_override.second) != nullptr) { diff --git a/onnxruntime/test/perftest/test_configuration.h b/onnxruntime/test/perftest/test_configuration.h index 43ad556247f97..5a49414a49004 100644 --- a/onnxruntime/test/perftest/test_configuration.h +++ b/onnxruntime/test/perftest/test_configuration.h @@ -6,6 +6,7 @@ #include #include #include +#include #include "core/graph/constants.h" #include "core/framework/session_options.h" @@ -56,6 +57,7 @@ struct RunConfig { bool do_cuda_copy_in_separate_stream{false}; bool set_denormal_as_zero{false}; std::basic_string ep_runtime_config_string; + std::unordered_map session_config_entries; std::map, int64_t> free_dim_name_overrides; std::map, int64_t> free_dim_denotation_overrides; std::string intra_op_thread_affinities; diff --git a/onnxruntime/test/platform/apple/apple_package_test/Podfile.template b/onnxruntime/test/platform/apple/apple_package_test/Podfile.template index 3d191d6fb1cc6..4958e4fa85490 100644 --- a/onnxruntime/test/platform/apple/apple_package_test/Podfile.template +++ b/onnxruntime/test/platform/apple/apple_package_test/Podfile.template @@ -1,6 +1,10 @@ def include_macos_target if '@C_POD_NAME@' != 'onnxruntime-mobile-c' - return true + if ENV['SKIP_MACOS_TEST'] != 'true' + return true + else + return false + end end return false end diff --git a/onnxruntime/test/python/transformers/rotary_flash.py b/onnxruntime/test/python/transformers/rotary_flash.py new file mode 100644 index 0000000000000..42bff9c92b41b --- /dev/null +++ b/onnxruntime/test/python/transformers/rotary_flash.py @@ -0,0 +1,693 @@ +# Copyright (c) 2023, Tri Dao. 
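+#
+# This module bundles a Triton kernel for rotary position embeddings (RoPE), a host-side
+# launcher (apply_rotary), autograd wrappers for the QKV/KV layouts, and a RotaryEmbedding
+# module; test_flash_attn.py uses it to build rotary reference outputs for GroupQueryAttention.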
+ + +from typing import Optional, Tuple, Union + +import torch +import triton +import triton.language as tl +from einops import rearrange, repeat + +##### TRITON KERNEL FOR ROTARY ##### + + +# @triton.autotune( +# configs=[ +# triton.Config({"block_m": 2}), +# triton.Config({"block_m": 4}), +# triton.Config({"block_m": 8}), +# triton.Config({"block_m": 16}), +# ], +# key=["CACHE_KEY_SEQLEN", "BLOCK_K", "INTERLEAVED"], +# ) +@triton.jit +def rotary_kernel( + out_, # Pointers to matrices + x_, + cos_, + sin_, + CU_SEQLENS, + SEQLEN_OFFSETS, # this could be int or a pointer + # Matrix dimensions + seqlen, + nheads, + rotary_dim, + seqlen_ro, + CACHE_KEY_SEQLEN, + # strides + stride_out_batch, + stride_out_seqlen, + stride_out_nheads, + stride_out_headdim, + stride_x_batch, + stride_x_seqlen, + stride_x_nheads, + stride_x_headdim, + # Meta-parameters + block_k: tl.constexpr, + IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr, + IS_VARLEN: tl.constexpr, + INTERLEAVED: tl.constexpr, + CONJUGATE: tl.constexpr, + block_m: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + pid_batch = tl.program_id(axis=1) + pid_head = tl.program_id(axis=2) + rotary_dim_half = rotary_dim // 2 + + if not IS_VARLEN: + x_ = x_ + pid_batch * stride_x_batch + pid_head * stride_x_nheads + out_ = out_ + pid_batch * stride_out_batch + pid_head * stride_out_nheads + else: + start_idx = tl.load(CU_SEQLENS + pid_batch) + seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx + x_ = x_ + start_idx * stride_x_seqlen + pid_head * stride_x_nheads + out_ = out_ + start_idx * stride_out_seqlen + pid_head * stride_out_nheads + + if pid_m * block_m >= seqlen: + return + rm = pid_m * block_m + tl.arange(0, block_m) + if not IS_SEQLEN_OFFSETS_TENSOR: + rm_cs = rm + SEQLEN_OFFSETS + else: + rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch) + rk = tl.arange(0, block_k) + rk_half = tl.arange(0, block_k // 2) + + if not INTERLEAVED: + # Load the 1st and 2nd halves of x_, do calculation, then store to 1st and 2nd halves of out_ + x_ = x_ + (rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim) + cos_ = cos_ + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :]) + sin_ = sin_ + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :]) + cos = tl.load(cos_, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=1.0).to( + tl.float32 + ) + sin = tl.load(sin_, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=0.0).to( + tl.float32 + ) + x0 = tl.load(x_, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), other=0.0).to(tl.float32) + x1 = tl.load( + x_ + rotary_dim_half * stride_x_headdim, + mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), + other=0.0, + ).to(tl.float32) + if CONJUGATE: + sin = -sin + o0 = x0 * cos - x1 * sin + o1 = x0 * sin + x1 * cos + # write back result + out_ = out_ + (rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim) + tl.store(out_, o0, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half)) + tl.store( + out_ + rotary_dim_half * stride_out_headdim, + o1, + mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), + ) + else: + # We don't want to load x_[0, 2, 4, ...] and x_[1, 3, 5, ...] separately since both are slow. + # Instead, we load x0 = x_[0, 1, 2, 3, ...] and x1 = x_[1, 0, 3, 2, ...]. + # Loading x0 will be fast but x1 will be slow. + # Then we load cos = cos_[0, 0, 1, 1, ...] and sin = sin_[0, 0, 1, 1, ...]. 
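+        # For example, with rotary_dim = 4 the interleaved layout rotates the pairs (x0, x1)
+        # and (x2, x3); each pair behaves like the complex number x_even + i * x_odd being
+        # multiplied by (cos + i * sin).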
+        # Then we do the calculation and use tl.where to pick out the right outputs for the even
+        # and for the odd indices.
+        rk_swap = rk + ((rk + 1) % 2) * 2 - 1  # 1, 0, 3, 2, 5, 4, ...
+        rk_repeat = tl.arange(0, block_k) // 2
+        x0_ = x_ + (rm[:, None] * stride_x_seqlen + rk[None, :] * stride_x_headdim)
+        x1_ = x_ + (rm[:, None] * stride_x_seqlen + rk_swap[None, :] * stride_x_headdim)
+        cos_ = cos_ + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])
+        sin_ = sin_ + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])
+        cos = tl.load(
+            cos_,
+            mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half),
+            other=1.0,
+        ).to(tl.float32)
+        sin = tl.load(
+            sin_,
+            mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half),
+            other=0.0,
+        ).to(tl.float32)
+        x0 = tl.load(x0_, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim), other=0.0).to(tl.float32)
+        x1 = tl.load(x1_, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0).to(tl.float32)
+        if CONJUGATE:
+            sin = -sin
+        x0_cos = x0 * cos
+        x1_sin = x1 * sin
+        out = tl.where(rk[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin)
+        out_ = out_ + (rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim)
+        tl.store(out_, out, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim))
+
+
+def apply_rotary(
+    x: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+    seqlen_offsets: Union[int, torch.Tensor] = 0,
+    cu_seqlens: Optional[torch.Tensor] = None,
+    max_seqlen: Optional[int] = None,
+    interleaved=False,
+    inplace=False,
+    conjugate=False,
+) -> torch.Tensor:
+    """
+    Arguments:
+        x: (batch, seqlen, nheads, headdim) if cu_seqlens is None
+            else (total_seqlen, nheads, headdim).
+        cos: (seqlen_ro, rotary_dim / 2)
+        sin: (seqlen_ro, rotary_dim / 2)
+        seqlen_offsets: integer or integer tensor of size (batch,)
+        cu_seqlens: (batch + 1,) or None
+        max_seqlen: int
+    Returns:
+        y: (batch, seqlen, nheads, headdim)
+    """
+    is_varlen = cu_seqlens is not None
+    if not is_varlen:
+        batch, seqlen, nheads, headdim = x.shape
+    else:
+        assert max_seqlen is not None, "If cu_seqlens is passed in, then max_seqlen must be passed"
+        total_seqlen, nheads, headdim = x.shape
+        batch_p_1 = cu_seqlens.shape[0]
+        batch = batch_p_1 - 1
+        seqlen = max_seqlen
+    seqlen_ro, rotary_dim = cos.shape
+    assert sin.shape == cos.shape
+    rotary_dim *= 2
+    assert rotary_dim <= headdim, "rotary_dim must be <= headdim"
+    assert headdim <= 256, "Only support headdim <= 256"
+    assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen"
+
+    assert cos.dtype == sin.dtype, f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}"
+    assert x.dtype == cos.dtype, f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}"
+
+    cos, sin = cos.contiguous(), sin.contiguous()
+    if isinstance(seqlen_offsets, torch.Tensor):
+        assert seqlen_offsets.shape == (batch,)
+        assert seqlen_offsets.dtype in [torch.int32, torch.int64]
+        seqlen_offsets = seqlen_offsets.contiguous()
+    else:
+        assert seqlen_offsets + seqlen <= seqlen_ro
+
+    output = torch.empty_like(x) if not inplace else x
+    if rotary_dim < headdim and not inplace:
+        output[..., rotary_dim:].copy_(x[..., rotary_dim:])
+
+    block_k = 32 if rotary_dim <= 32 else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256))
+    grid = lambda META: (triton.cdiv(seqlen, META["block_m"]), batch, nheads)  # noqa
+    block_m = 4 if interleaved else (8 if rotary_dim <= 64 else 4)
+
+    # Need this, otherwise Triton tries to launch from cuda:0 and we get
+ # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?) + with torch.cuda.device(x.device.index): + rotary_kernel[grid]( + output, # data ptrs + x, + cos, + sin, + cu_seqlens, + seqlen_offsets, + seqlen, # shapes + nheads, + rotary_dim, + seqlen_ro, + seqlen // 128, # key for triton cache (limit number of compilations) + output.stride(0) if not is_varlen else 0, # batch_strides if not varlen else 0 + output.stride(-3), # seqlen_stride or total_seqlen_stride + output.stride(-2), # nheads_stride + output.stride(-1), # headdim_stride + x.stride(0) if not is_varlen else 0, # batch_strides if not varlen else 0 + x.stride(-3), # seqlen stride or total_seqlen_stride + x.stride(-2), # nheads stride + x.stride(-1), # headdim stride + block_k, + isinstance(seqlen_offsets, torch.Tensor), + is_varlen, + interleaved, + conjugate, + block_m, + ) + return output + + +##### ROTARY API ##### + + +def rotate_half(x, interleaved=False): + if not interleaved: + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + else: + x1, x2 = x[..., ::2], x[..., 1::2] + return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2) + + +def apply_rotary_emb_torch(x, cos, sin, interleaved=False): + """ + x: (batch_size, seqlen, nheads, headdim) + cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) + """ + ro_dim = cos.shape[-1] * 2 + assert ro_dim <= x.shape[-1] + cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") + sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") + return torch.cat( + [x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]], + dim=-1, + ) + + +class ApplyRotaryEmb(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + cos, + sin, + interleaved=False, + inplace=False, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + ): + out = apply_rotary( + x, + cos, + sin, + seqlen_offsets=seqlen_offsets, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + interleaved=interleaved, + inplace=inplace, + ) + if isinstance(seqlen_offsets, int): + ctx.save_for_backward(cos, sin, cu_seqlens) # Can't save int with save_for_backward + ctx.seqlen_offsets = seqlen_offsets + else: + ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets) + ctx.seqlen_offsets = None + ctx.interleaved = interleaved + ctx.inplace = inplace + ctx.max_seqlen = max_seqlen + return out if not inplace else x + + @staticmethod + def backward(ctx, do): + seqlen_offsets = ctx.seqlen_offsets + if seqlen_offsets is None: + cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors + else: + cos, sin, cu_seqlens = ctx.saved_tensors + # TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with + # "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works. 
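+        # The gradient of a rotation is the rotation by the negated angle, so the backward
+        # pass reuses the same kernel with conjugate=True (which flips the sign of sin).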
+        if not ctx.interleaved and not ctx.inplace:
+            do = do.clone()
+        dx = apply_rotary(
+            do,
+            cos,
+            sin,
+            seqlen_offsets=seqlen_offsets,
+            cu_seqlens=cu_seqlens,
+            max_seqlen=ctx.max_seqlen,
+            interleaved=ctx.interleaved,
+            inplace=ctx.inplace,
+            conjugate=True,
+        )
+        return dx, None, None, None, None, None, None, None
+
+
+def apply_rotary_emb(
+    x,
+    cos,
+    sin,
+    interleaved=False,
+    inplace=False,
+    seqlen_offsets: Union[int, torch.Tensor] = 0,
+    cu_seqlens: Optional[torch.Tensor] = None,
+    max_seqlen: Optional[int] = None,
+):
+    """
+    Arguments:
+        x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
+            else (total_seqlen, nheads, headdim)
+        cos, sin: (seqlen_rotary, rotary_dim / 2)
+        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
+            of 1st half and 2nd half (GPT-NeoX style).
+        inplace: if True, apply rotary embedding in-place.
+        seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
+            Most commonly used in inference when we have KV cache.
+        cu_seqlens: (batch + 1,) or None
+        max_seqlen: int
+    Return:
+        out: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
+            else (total_seqlen, nheads, headdim)
+    rotary_dim must be <= headdim
+    Apply rotary embedding to the first rotary_dim of x.
+    """
+    return ApplyRotaryEmb.apply(x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen)
+
+
+# For backward compatibility
+apply_rotary_emb_func = apply_rotary_emb
+
+
+class ApplyRotaryEmbQKV(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx,
+        qkv,
+        cos,
+        sin,
+        cos_k=None,
+        sin_k=None,
+        interleaved=False,
+        seqlen_offsets: Union[int, torch.Tensor] = 0,
+    ):
+        batch, seqlen, three, nheads, headdim = qkv.shape
+        assert three == 3
+        if cos_k is None and sin_k is None and qkv.is_contiguous():
+            # Call 1 kernel instead of 2 kernels
+            # We need qkv to be contiguous so that when we reshape to combine (3, nheads)
+            # dimensions, we get the same tensor
+            # qk = rearrange(qkv[:, :, :2], "b s t h d -> b s (t h) d")
+            qk = qkv[:, :, :2].reshape(batch, seqlen, -1, headdim)
+            apply_rotary(qk, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True)
+        else:
+            cos_k = cos if cos_k is None else cos_k
+            sin_k = sin if sin_k is None else sin_k
+            q, k = qkv[:, :, 0], qkv[:, :, 1]
+            apply_rotary(q, cos, sin, seqlen_offsets, interleaved=interleaved, inplace=True)
+            apply_rotary(k, cos_k, sin_k, seqlen_offsets, interleaved=interleaved, inplace=True)
+        if isinstance(seqlen_offsets, int):
+            ctx.save_for_backward(cos, sin, cos_k, sin_k)  # Can't save int with save_for_backward
+            ctx.seqlen_offsets = seqlen_offsets
+        else:
+            ctx.save_for_backward(cos, sin, cos_k, sin_k, seqlen_offsets)
+            ctx.seqlen_offsets = None
+        ctx.interleaved = interleaved
+        return qkv
+
+    @staticmethod
+    def backward(ctx, dqkv):
+        seqlen_offsets = ctx.seqlen_offsets
+        if seqlen_offsets is None:
+            cos, sin, cos_k, sin_k, seqlen_offsets = ctx.saved_tensors
+        else:
+            cos, sin, cos_k, sin_k = ctx.saved_tensors
+        if cos_k is None and sin_k is None and dqkv.is_contiguous():
+            # Call 1 kernel instead of 2 kernels
+            # We need dqkv to be contiguous so that when we reshape to combine (3, nheads)
+            # dimensions, we get the same tensor
+            dqk = rearrange(dqkv[:, :, :2], "b s t h d -> b s (t h) d")
+            apply_rotary(
+                dqk,
+                cos,
+                sin,
+                seqlen_offsets=seqlen_offsets,
+                interleaved=ctx.interleaved,
+                inplace=True,
+                conjugate=True,
+            )
+        else:
+            cos_k = cos if cos_k is None else cos_k
+            sin_k = sin if sin_k is None else
sin_k + dq, dk = dqkv[:, :, 0], dqkv[:, :, 1] + apply_rotary(dq, cos, sin, seqlen_offsets, interleaved=ctx.interleaved, inplace=True, conjugate=True) + apply_rotary( + dk, + cos_k, + sin_k, + seqlen_offsets, + interleaved=ctx.interleaved, + inplace=True, + conjugate=True, + ) + return dqkv, None, None, None, None, None, None + + +def apply_rotary_emb_qkv_( + qkv, + cos, + sin, + cos_k=None, + sin_k=None, + interleaved=False, + seqlen_offsets: Union[int, torch.Tensor] = 0, +): + """ + Arguments: + qkv: (batch_size, seqlen, 3, nheads, headdim) + cos, sin: (seqlen, rotary_dim / 2) + cos_k, sin_k: (seqlen, rotary_dim / 2), optional + interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of + 1st half and 2nd half (GPT-NeoX style). + seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount. + Most commonly used in inference when we have KV cache. + Return: + qkv: (batch_size, seqlen, 3, nheads, headdim) + rotary_dim must be <= headdim + Apply rotary embedding *inplace* to the first rotary_dim of Q and K. + """ + return ApplyRotaryEmbQKV.apply(qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets) + + +class ApplyRotaryEmbKV(torch.autograd.Function): + @staticmethod + def forward(ctx, kv, cos, sin, interleaved=False, seqlen_offsets: Union[int, torch.Tensor] = 0): + batch, seqlen, two, nheads, headdim = kv.shape + assert two == 2 + k = kv[:, :, 0] + apply_rotary(k, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True) + if isinstance(seqlen_offsets, int): + ctx.save_for_backward(cos, sin) # Can't save int with save_for_backward + ctx.seqlen_offsets = seqlen_offsets + else: + ctx.save_for_backward(cos, sin, seqlen_offsets) + ctx.seqlen_offsets = None + ctx.interleaved = interleaved + return kv + + @staticmethod + def backward(ctx, dkv): + seqlen_offsets = ctx.seqlen_offsets + if seqlen_offsets is None: + cos, sin, seqlen_offsets = ctx.saved_tensors + else: + cos, sin = ctx.saved_tensors + apply_rotary( + dkv[:, :, 0], + cos, + sin, + seqlen_offsets=seqlen_offsets, + interleaved=ctx.interleaved, + inplace=True, + conjugate=True, + ) + return dkv, None, None, None, None + + +apply_rotary_emb_kv_ = ApplyRotaryEmbKV.apply + + +def apply_rotary_emb_kv_( + kv, + cos, + sin, + interleaved=False, + seqlen_offsets: Union[int, torch.Tensor] = 0, +): + """ + Arguments: + kv: (batch_size, seqlen, 2, nheads, headdim) + cos, sin: (seqlen, rotary_dim / 2) + interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of + 1st half and 2nd half (GPT-NeoX style). + seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount. + Most commonly used in inference when we have KV cache. + Return: + kv: (batch_size, seqlen, 2, nheads, headdim) + rotary_dim must be <= headdim + Apply rotary embedding *inplace* to the first rotary_dim of K. + """ + return ApplyRotaryEmbKV.apply(kv, cos, sin, interleaved, seqlen_offsets) + + +class RotaryEmbedding(torch.nn.Module): + """ + The rotary position embeddings from RoFormer_ (Su et. al). + A crucial insight from the method is that the query and keys are + transformed by rotation matrices which depend on the relative positions. + + Other implementations are available in the Rotary Transformer repo_ and in + GPT-NeoX_, GPT-NeoX was an inspiration + + .. _RoFormer: https://arxiv.org/abs/2104.09864 + .. _repo: https://github.com/ZhuiyiTechnology/roformer + .. 
_GPT-NeoX: https://github.com/EleutherAI/gpt-neox + + If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554). + A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96 + Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py + """ + + def __init__( + self, + dim: int, + base=10000.0, + interleaved=False, + scale_base=None, + pos_idx_in_fp32=True, + device=None, + ): + """ + interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead + of 1st half and 2nd half (GPT-NeoX style). + pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32, + otherwise they might be in lower precision. + This option was added because previously (before 2023-07-02), when we construct + the position indices, we use the dtype of self.inv_freq. In most cases this would + be fp32, but if the model is trained in pure bf16 (not mixed precision), then + self.inv_freq would be bf16, and the position indices are also in bf16. + Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the + embeddings for some positions will coincide. + To maintain compatibility with models previously trained in pure bf16, + we add this option. + """ + super().__init__() + self.dim = dim + self.base = float(base) + self.pos_idx_in_fp32 = pos_idx_in_fp32 + # Generate and save the inverse frequency buffer (non trainable) + inv_freq = self._compute_inv_freq(device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.interleaved = interleaved + self.scale_base = scale_base + scale = ( + (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim) + if scale_base is not None + else None + ) + self.register_buffer("scale", scale, persistent=False) + + self._seq_len_cached = 0 + self._cos_cached = None + self._sin_cached = None + self._cos_k_cached = None + self._sin_k_cached = None + + def _compute_inv_freq(self, device=None): + return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)) + + def _update_cos_sin_cache(self, seqlen, device=None, dtype=None): + # Reset the tables if the sequence length has changed, + # if we're on a new device (possibly due to tracing for instance), + # or if we're switching from inference mode to training + if ( + seqlen > self._seq_len_cached + or self._cos_cached is None + or self._cos_cached.device != device + or self._cos_cached.dtype != dtype + or (self.training and self._cos_cached.is_inference()) + ): + self._seq_len_cached = seqlen + # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16 + # And the output of arange can be quite large, so bf16 would lose a lot of precision. + # However, for compatibility reason, we add an option to use the dtype of self.inv_freq. + if self.pos_idx_in_fp32: + t = torch.arange(seqlen, device=device, dtype=torch.float32) + # We want fp32 here as well since inv_freq will be multiplied with t, and the output + # will be large. Having it in bf16 will lose a lot of precision and cause the + # cos & sin output to change significantly. 
+ # We want to recompute self.inv_freq if it was not loaded in fp32 + if self.inv_freq.dtype != torch.float32: + inv_freq = self._compute_inv_freq(device=device) + else: + inv_freq = self.inv_freq + else: + t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) + inv_freq = self.inv_freq + # Don't do einsum, it converts fp32 to fp16 under AMP + # freqs = torch.einsum("i,j->ij", t, self.inv_freq) + freqs = torch.outer(t, inv_freq) + if self.scale is None: + self._cos_cached = torch.cos(freqs).to(dtype) + self._sin_cached = torch.sin(freqs).to(dtype) + else: + power = ( + torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2 + ) / self.scale_base + scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1") + # We want the multiplication by scale to happen in fp32 + self._cos_cached = (torch.cos(freqs) * scale).to(dtype) + self._sin_cached = (torch.sin(freqs) * scale).to(dtype) + self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype) + self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype) + + def forward( + self, + qkv: torch.Tensor, + kv: Optional[torch.Tensor] = None, + seqlen_offset: Union[int, torch.Tensor] = 0, + max_seqlen: Optional[int] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + qkv: (batch, seqlen, 3, nheads, headdim) if kv is none, + else it's just q of shape (batch, seqlen, nheads, headdim) + kv: (batch, seqlen, 2, nheads, headdim) + seqlen_offset: (batch_size,) or int. Each sequence in x is shifted by this amount. + Most commonly used in inference when we have KV cache. + If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one + should pass in max_seqlen, which will update the cos / sin cache up to that length. + Apply rotary embedding *inplace* to qkv and / or kv. 
+ """ + seqlen = qkv.shape[1] + if max_seqlen is not None: + self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype) + elif isinstance(seqlen_offset, int): + self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype) + if kv is None: + if self.scale is None: + return apply_rotary_emb_qkv_( + qkv, + self._cos_cached, + self._sin_cached, + interleaved=self.interleaved, + seqlen_offsets=seqlen_offset, + ) + else: + return apply_rotary_emb_qkv_( + qkv, + self._cos_cached, + self._sin_cached, + self._cos_k_cached, + self._sin_k_cached, + interleaved=self.interleaved, + seqlen_offsets=seqlen_offset, + ) + else: + q = qkv + q = apply_rotary_emb_func( + q, + self._cos_cached, + self._sin_cached, + interleaved=self.interleaved, + inplace=True, + seqlen_offsets=seqlen_offset, + ) + if self.scale is None: + kv = apply_rotary_emb_kv_( + kv, + self._cos_cached, + self._sin_cached, + interleaved=self.interleaved, + seqlen_offsets=seqlen_offset, + ) + else: + kv = apply_rotary_emb_kv_( + kv, + self._cos_k_cached, + self._sin_k_cached, + interleaved=self.interleaved, + seqlen_offsets=seqlen_offset, + ) + return q, kv diff --git a/onnxruntime/test/python/transformers/test_flash_attn.py b/onnxruntime/test/python/transformers/test_flash_attn.py index 8a839875de2a2..90d28872d3cc8 100644 --- a/onnxruntime/test/python/transformers/test_flash_attn.py +++ b/onnxruntime/test/python/transformers/test_flash_attn.py @@ -20,6 +20,7 @@ from bert_padding import pad_input, unpad_input from einops import rearrange, repeat from onnx import TensorProto, helper +from rotary_flash import apply_rotary_emb from onnxruntime import InferenceSession, OrtValue, SessionOptions @@ -184,7 +185,13 @@ def create_multihead_attention_graph(config): def create_group_query_attention_graph_prompt( - config, past_kv_format=Formats.BSNH, share_buffer=True, local_window_size=-1 + config, + past_kv_format=Formats.BSNH, + share_buffer=True, + local_window_size=-1, + rotary=False, + rotary_interleaved=False, + packed=False, ): past_kv_seqlen = config.buffer_sequence_length if share_buffer else 0 present_kv_seqlen = config.buffer_sequence_length if share_buffer else config.kv_sequence_length @@ -193,18 +200,22 @@ def create_group_query_attention_graph_prompt( "GroupQueryAttention", [ "query", - "key", - "value", + "key" if not packed else "", + "value" if not packed else "", "past_key" if share_buffer else "", "past_value" if share_buffer else "", "seqlens_k", "total_sequence_length", + "cos_cache" if rotary else "", + "sin_cache" if rotary else "", ], ["output", "present_key", "present_value"], "GroupQueryAttention_0", num_heads=config.num_heads, kv_num_heads=config.kv_num_heads, local_window_size=local_window_size, + do_rotary=rotary, + rotary_interleaved=rotary_interleaved, # is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0, # kv_share_buffer=1 if share_buffer else 0, domain="com.microsoft", @@ -218,25 +229,9 @@ def create_group_query_attention_graph_prompt( [ config.batch_size, config.q_sequence_length, - config.num_heads * config.head_size, - ], - ), - helper.make_tensor_value_info( - "key", - TensorProto.FLOAT16, - [ - config.batch_size, - config.kv_sequence_length, - config.kv_num_heads * config.head_size, - ], - ), - helper.make_tensor_value_info( - "value", - TensorProto.FLOAT16, - [ - config.batch_size, - config.kv_sequence_length, - config.kv_num_heads * config.head_size, + (config.num_heads * config.head_size) + if not packed + else (config.num_heads * config.head_size + 2 * 
config.kv_num_heads * config.head_size), ], ), helper.make_tensor_value_info( @@ -250,6 +245,27 @@ def create_group_query_attention_graph_prompt( [1], ), ] + if not packed: + graph_input += [ + helper.make_tensor_value_info( + "key", + TensorProto.FLOAT16, + [ + config.batch_size, + config.kv_sequence_length, + config.kv_num_heads * config.head_size, + ], + ), + helper.make_tensor_value_info( + "value", + TensorProto.FLOAT16, + [ + config.batch_size, + config.kv_sequence_length, + config.kv_num_heads * config.head_size, + ], + ), + ] if share_buffer: graph_input += [ helper.make_tensor_value_info( @@ -273,6 +289,25 @@ def create_group_query_attention_graph_prompt( ], ), ] + if rotary: + graph_input += [ + helper.make_tensor_value_info( + "cos_cache", + TensorProto.FLOAT16, + [ + config.buffer_sequence_length if share_buffer else config.kv_sequence_length, + (math.floor(config.head_size / 16) * 16) // 2, + ], + ), + helper.make_tensor_value_info( + "sin_cache", + TensorProto.FLOAT16, + [ + config.buffer_sequence_length if share_buffer else config.kv_sequence_length, + (math.floor(config.head_size / 16) * 16) // 2, + ], + ), + ] graph_output = [ helper.make_tensor_value_info( @@ -334,7 +369,13 @@ def create_group_query_attention_graph_prompt( def create_group_query_attention_graph_past( - config, past_kv_format=Formats.BSNH, share_buffer=True, local_window_size=-1 + config, + past_kv_format=Formats.BSNH, + share_buffer=True, + local_window_size=-1, + rotary=False, + rotary_interleaved=False, + packed=False, ): past_kv_seqlen = config.kv_sequence_length present_kv_seqlen = ( @@ -345,18 +386,22 @@ def create_group_query_attention_graph_past( "GroupQueryAttention", [ "query", - "key", - "value", + "key" if not packed else "", + "value" if not packed else "", "past_key", "past_value", "seqlens_k", "total_sequence_length", + "cos_cache" if rotary else "", + "sin_cache" if rotary else "", ], ["output", "present_key", "present_value"], "GroupQueryAttention_0", num_heads=config.num_heads, kv_num_heads=config.kv_num_heads, local_window_size=local_window_size, + do_rotary=rotary, + rotary_interleaved=rotary_interleaved, # is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0, # kv_share_buffer=1 if share_buffer else 0, domain="com.microsoft", @@ -370,25 +415,9 @@ def create_group_query_attention_graph_past( [ config.batch_size, config.sequence_length, - config.num_heads * config.head_size, - ], - ), - helper.make_tensor_value_info( - "key", - TensorProto.FLOAT16, - [ - config.batch_size, - config.sequence_length, - config.kv_num_heads * config.head_size, - ], - ), - helper.make_tensor_value_info( - "value", - TensorProto.FLOAT16, - [ - config.batch_size, - config.sequence_length, - config.kv_num_heads * config.head_size, + (config.num_heads * config.head_size) + if not packed + else (config.num_heads * config.head_size + 2 * config.kv_num_heads * config.head_size), ], ), helper.make_tensor_value_info( @@ -411,8 +440,6 @@ def create_group_query_attention_graph_past( config.head_size, ], ), - ] - graph_input += [ helper.make_tensor_value_info( "seqlens_k", TensorProto.INT32, @@ -424,6 +451,46 @@ def create_group_query_attention_graph_past( [1], ), ] + if not packed: + graph_input += [ + helper.make_tensor_value_info( + "key", + TensorProto.FLOAT16, + [ + config.batch_size, + config.sequence_length, + config.kv_num_heads * config.head_size, + ], + ), + helper.make_tensor_value_info( + "value", + TensorProto.FLOAT16, + [ + config.batch_size, + config.sequence_length, + config.kv_num_heads * 
config.head_size, + ], + ), + ] + if rotary: + graph_input += [ + helper.make_tensor_value_info( + "cos_cache", + TensorProto.FLOAT16, + [ + config.kv_sequence_length + (0 if share_buffer else config.sequence_length), + (math.floor(config.head_size / 16) * 16) // 2, + ], + ), + helper.make_tensor_value_info( + "sin_cache", + TensorProto.FLOAT16, + [ + config.kv_sequence_length + (0 if share_buffer else config.sequence_length), + (math.floor(config.head_size / 16) * 16) // 2, + ], + ), + ] graph_output = [ helper.make_tensor_value_info( @@ -663,21 +730,38 @@ def mha_func(q, k, v, config): def gqa_prompt_func( - q, k, v, config, new_k, new_v, seqlens_k=None, window_size=-1, past_kv_format=Formats.BSNH, share_buffer=True + q, + k, + v, + config, + new_k, + new_v, + cos=None, + sin=None, + seqlens_k=None, + window_size=-1, + past_kv_format=Formats.BSNH, + share_buffer=True, + rotary_interleaved=False, ): onnx_model_str = create_group_query_attention_graph_prompt( - config, past_kv_format, share_buffer, local_window_size=window_size + config, + past_kv_format, + share_buffer, + local_window_size=window_size, + rotary=cos is not None, + rotary_interleaved=rotary_interleaved, + packed=new_k is None, ) q = torch.reshape(q, (config.batch_size, config.q_sequence_length, -1)) past_k = k.clone() if share_buffer else None past_v = v.clone() if share_buffer else None - new_k = torch.reshape(new_k, (config.batch_size, config.kv_sequence_length, -1)) - new_v = torch.reshape(new_v, (config.batch_size, config.kv_sequence_length, -1)) + if new_k is not None: + new_k = torch.reshape(new_k, (config.batch_size, config.kv_sequence_length, -1)) + new_v = torch.reshape(new_v, (config.batch_size, config.kv_sequence_length, -1)) if share_buffer: ort_inputs = { "query": q.detach().cpu().numpy(), - "key": new_k.detach().cpu().numpy(), - "value": new_v.detach().cpu().numpy(), "past_key": OrtValue.ortvalue_from_numpy(past_k.detach().cpu().numpy(), "cuda", 0), "past_value": OrtValue.ortvalue_from_numpy(past_v.detach().cpu().numpy(), "cuda", 0), "seqlens_k": seqlens_k.detach().cpu().numpy().astype(numpy.int32), @@ -686,9 +770,17 @@ def gqa_prompt_func( sess_options = SessionOptions() ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CUDAExecutionProvider"]) io_binding = ort_session.io_binding() + if new_k is not None: + ort_inputs["key"] = new_k.detach().cpu().numpy() + ort_inputs["value"] = new_v.detach().cpu().numpy() + io_binding.bind_cpu_input("key", ort_inputs["key"]) + io_binding.bind_cpu_input("value", ort_inputs["value"]) + if cos is not None: + ort_inputs["cos_cache"] = cos.detach().cpu().numpy() + ort_inputs["sin_cache"] = sin.detach().cpu().numpy() + io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) + io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) io_binding.bind_cpu_input("query", ort_inputs["query"]) - io_binding.bind_cpu_input("key", ort_inputs["key"]) - io_binding.bind_cpu_input("value", ort_inputs["value"]) io_binding.bind_input( "past_key", "cuda", 0, numpy.float16, ort_inputs["past_key"].shape(), ort_inputs["past_key"].data_ptr() ) @@ -713,17 +805,23 @@ def gqa_prompt_func( else: ort_inputs = { "query": q.detach().cpu().numpy(), - "key": new_k.detach().cpu().numpy(), - "value": new_v.detach().cpu().numpy(), "seqlens_k": seqlens_k.detach().cpu().numpy().astype(numpy.int32), "total_sequence_length": torch.tensor([config.q_sequence_length], dtype=torch.int32).detach().cpu().numpy(), } sess_options = SessionOptions() ort_session = 
InferenceSession(onnx_model_str, sess_options, providers=["CUDAExecutionProvider"]) io_binding = ort_session.io_binding() + if new_k is not None: + ort_inputs["key"] = new_k.detach().cpu().numpy() + ort_inputs["value"] = new_v.detach().cpu().numpy() + io_binding.bind_cpu_input("key", ort_inputs["key"]) + io_binding.bind_cpu_input("value", ort_inputs["value"]) + if cos is not None: + ort_inputs["cos_cache"] = cos.detach().cpu().numpy() + ort_inputs["sin_cache"] = sin.detach().cpu().numpy() + io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) + io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) io_binding.bind_cpu_input("query", ort_inputs["query"]) - io_binding.bind_cpu_input("key", ort_inputs["key"]) - io_binding.bind_cpu_input("value", ort_inputs["value"]) io_binding.bind_cpu_input("seqlens_k", ort_inputs["seqlens_k"]) io_binding.bind_cpu_input("total_sequence_length", ort_inputs["total_sequence_length"]) io_binding.bind_output("output") @@ -737,21 +835,38 @@ def gqa_prompt_func( def gqa_past_func( - q, k, v, config, new_k, new_v, seqlens_k=None, past_kv_format=Formats.BSNH, share_buffer=True, window_size=-1 + q, + k, + v, + config, + new_k, + new_v, + cos=None, + sin=None, + seqlens_k=None, + past_kv_format=Formats.BSNH, + share_buffer=True, + window_size=-1, + rotary_interleaved=False, ): onnx_model_str = create_group_query_attention_graph_past( - config, past_kv_format, share_buffer, local_window_size=window_size + config, + past_kv_format, + share_buffer, + local_window_size=window_size, + rotary=cos is not None, + rotary_interleaved=rotary_interleaved, + packed=new_k is None, ) q = torch.reshape(q, (config.batch_size, config.sequence_length, -1)) past_k = k.clone() past_v = v.clone() - new_k = torch.reshape(new_k, (config.batch_size, config.sequence_length, -1)) - new_v = torch.reshape(new_v, (config.batch_size, config.sequence_length, -1)) + if new_k is not None: + new_k = torch.reshape(new_k, (config.batch_size, config.sequence_length, -1)) + new_v = torch.reshape(new_v, (config.batch_size, config.sequence_length, -1)) if share_buffer: ort_inputs = { "query": q.detach().cpu().numpy(), - "key": new_k.detach().cpu().numpy(), - "value": new_v.detach().cpu().numpy(), "past_key": OrtValue.ortvalue_from_numpy(past_k.detach().cpu().numpy(), "cuda", 0), "past_value": OrtValue.ortvalue_from_numpy(past_v.detach().cpu().numpy(), "cuda", 0), "seqlens_k": seqlens_k.detach().cpu().numpy().astype(numpy.int32), @@ -763,9 +878,17 @@ def gqa_past_func( sess_options = SessionOptions() ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CUDAExecutionProvider"]) io_binding = ort_session.io_binding() + if new_k is not None: + ort_inputs["key"] = new_k.detach().cpu().numpy() + ort_inputs["value"] = new_v.detach().cpu().numpy() + io_binding.bind_cpu_input("key", ort_inputs["key"]) + io_binding.bind_cpu_input("value", ort_inputs["value"]) + if cos is not None: + ort_inputs["cos_cache"] = cos.detach().cpu().numpy() + ort_inputs["sin_cache"] = sin.detach().cpu().numpy() + io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) + io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) io_binding.bind_cpu_input("query", ort_inputs["query"]) - io_binding.bind_cpu_input("key", ort_inputs["key"]) - io_binding.bind_cpu_input("value", ort_inputs["value"]) io_binding.bind_input( "past_key", "cuda", 0, numpy.float16, ort_inputs["past_key"].shape(), ort_inputs["past_key"].data_ptr() ) @@ -790,8 +913,6 @@ def gqa_past_func( else: ort_inputs = { "query": 
q.detach().cpu().numpy(), - "key": new_k.detach().cpu().numpy(), - "value": new_v.detach().cpu().numpy(), "past_key": past_k.detach().cpu().numpy(), "past_value": past_v.detach().cpu().numpy(), "seqlens_k": seqlens_k.detach().cpu().numpy().astype(numpy.int32), @@ -805,9 +926,17 @@ def gqa_past_func( sess_options = SessionOptions() ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CUDAExecutionProvider"]) io_binding = ort_session.io_binding() + if new_k is not None: + ort_inputs["key"] = new_k.detach().cpu().numpy() + ort_inputs["value"] = new_v.detach().cpu().numpy() + io_binding.bind_cpu_input("key", ort_inputs["key"]) + io_binding.bind_cpu_input("value", ort_inputs["value"]) + if cos is not None: + ort_inputs["cos_cache"] = cos.detach().cpu().numpy() + ort_inputs["sin_cache"] = sin.detach().cpu().numpy() + io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) + io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) io_binding.bind_cpu_input("query", ort_inputs["query"]) - io_binding.bind_cpu_input("key", ort_inputs["key"]) - io_binding.bind_cpu_input("value", ort_inputs["value"]) io_binding.bind_cpu_input("past_key", ort_inputs["past_key"]) io_binding.bind_cpu_input("past_value", ort_inputs["past_value"]) io_binding.bind_cpu_input("seqlens_k", ort_inputs["seqlens_k"]) @@ -1029,9 +1158,12 @@ def parity_check_mha( def parity_check_gqa_prompt( config, - causal=False, + causal=True, local=False, past_format=Formats.BSNH, + rotary=False, + rotary_interleaved=False, + packed=False, rtol=1e-3, atol=1e-3, ): @@ -1080,6 +1212,8 @@ def parity_check_gqa_prompt( dtype=torch.float16, requires_grad=False, ) + # print(k.shape) + # print(new_k.shape) window_size = (-1, -1) left_window_size = -1 @@ -1105,19 +1239,47 @@ def parity_check_gqa_prompt( # device="cuda", # ) # cache_seqlens[random.randint(0, cache_seqlens.size(dim=0) - 1)] = config.kv_sequence_length + rotary_seqlens = torch.tensor([0], device="cuda").repeat(config.batch_size) + + if rotary: + rotary_fraction = 1.0 + rotary_dim = math.floor(int(rotary_fraction * config.head_size) / 16) * 16 + angle = torch.rand(config.buffer_sequence_length, rotary_dim // 2, device="cuda") * 2 * math.pi + cos = torch.cos(angle).to(dtype=torch.float16) + sin = torch.sin(angle).to(dtype=torch.float16) + if causal or local: + q_ro = apply_rotary_emb(q, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved) + else: + q_ro = rearrange( + apply_rotary_emb( + rearrange(q, "b s h d -> b 1 (s h) d"), + cos, + sin, + seqlen_offsets=rotary_seqlens, + interleaved=rotary_interleaved, + ), + "b 1 (s h) d -> b s h d", + s=config.q_sequence_length, + ) + # q_ro = q + k_ro = apply_rotary_emb(new_k, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved) + else: + cos, sin = None, None + q_ro, k_ro = q, new_k + rearrange(torch.arange(config.kv_sequence_length, device="cuda"), "s -> 1 s") arange = rearrange(torch.arange(config.buffer_sequence_length, device="cuda"), "s -> 1 s") cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") kv_seqlens = torch.tensor([config.kv_sequence_length], device="cuda").repeat(config.batch_size) kv_seqlens_expanded = rearrange(kv_seqlens, "b -> b 1") update_mask = arange < kv_seqlens_expanded - k_cache_ref[update_mask] = rearrange(new_k, "b s ... -> (b s) ...") + k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...") v_cache_ref[update_mask] = rearrange(new_v, "b s ... 
-> (b s) ...") k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) key_padding_mask = arange < cache_seqlens_expanded out_ref, _ = attention_ref( - q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size + q_ro, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: @@ -1125,13 +1287,47 @@ def parity_check_gqa_prompt( v_cache_ref = v_cache_ref.transpose(1, 2) # Flash function - out, present_k, present_v = gqa_prompt_func( - q, k, v, config, new_k, new_v, cache_seqlens, left_window_size, past_format, True - ) + if packed: + packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) + out, present_k, present_v = gqa_prompt_func( + packed_qkv, + k, + v, + config, + None, + None, + cos, + sin, + cache_seqlens, + left_window_size, + past_format, + True, + rotary_interleaved, + ) + else: + out, present_k, present_v = gqa_prompt_func( + q, + k, + v, + config, + new_k, + new_v, + cos, + sin, + cache_seqlens, + left_window_size, + past_format, + True, + rotary_interleaved, + ) out = torch.squeeze(out, 0) out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() + # print(cache_seqlens[0]) + # print((present_k - k_cache_ref.detach().cpu().numpy())[0, 0, :, 0]) + # print((out - out_ref)[0, :, 0, 0]) + # Make sure past-present buffer updating correctly assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) @@ -1139,10 +1335,16 @@ def parity_check_gqa_prompt( # Compare results print( "KV-buffer", + " packed:", + packed, " causal:", causal, " local:", local, + " rotary:", + rotary, + " rotary_interleaved:", + rotary_interleaved, "past kv format:", "BSNH" if past_format == Formats.BSNH else "BNSH", " B:", @@ -1171,9 +1373,12 @@ def parity_check_gqa_prompt( def parity_check_gqa_prompt_no_buff( config, - causal=False, + causal=True, local=False, past_format=Formats.BSNH, + rotary=False, + rotary_interleaved=False, + packed=False, rtol=1e-3, atol=1e-3, ): @@ -1229,13 +1434,42 @@ def parity_check_gqa_prompt_no_buff( # device="cuda", # ) # cache_seqlens[random.randint(0, cache_seqlens.size(dim=0) - 1)] = config.kv_sequence_length + rotary_seqlens = torch.tensor([0], device="cuda").repeat(config.batch_size) + + if rotary: + rotary_fraction = 1.0 + rotary_dim = math.floor(int(rotary_fraction * config.head_size) / 16) * 16 + angle = torch.rand(config.kv_sequence_length, rotary_dim // 2, device="cuda") * 2 * math.pi + cos = torch.cos(angle).to(dtype=torch.float16) + sin = torch.sin(angle).to(dtype=torch.float16) + if causal or local: + q_ro = apply_rotary_emb(q, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved) + else: + q_ro = rearrange( + apply_rotary_emb( + rearrange(q, "b s h d -> b 1 (s h) d"), + cos, + sin, + seqlen_offsets=rotary_seqlens, + interleaved=rotary_interleaved, + ), + "b 1 (s h) d -> b s h d", + s=config.q_sequence_length, + ) + # q_ro = q + k_ro = apply_rotary_emb(k_cache_ref, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved) + else: + cos, sin = None, None + q_ro, k_ro = q, k_cache_ref + k_cache_ref = 
@@ -1171,9 +1373,12 @@

 def parity_check_gqa_prompt_no_buff(
     config,
-    causal=False,
+    causal=True,
     local=False,
     past_format=Formats.BSNH,
+    rotary=False,
+    rotary_interleaved=False,
+    packed=False,
     rtol=1e-3,
     atol=1e-3,
 ):
@@ -1229,13 +1434,42 @@ def parity_check_gqa_prompt_no_buff(
     #     device="cuda",
     # )
     # cache_seqlens[random.randint(0, cache_seqlens.size(dim=0) - 1)] = config.kv_sequence_length
+    rotary_seqlens = torch.tensor([0], device="cuda").repeat(config.batch_size)
+
+    if rotary:
+        rotary_fraction = 1.0
+        rotary_dim = math.floor(int(rotary_fraction * config.head_size) / 16) * 16
+        angle = torch.rand(config.kv_sequence_length, rotary_dim // 2, device="cuda") * 2 * math.pi
+        cos = torch.cos(angle).to(dtype=torch.float16)
+        sin = torch.sin(angle).to(dtype=torch.float16)
+        if causal or local:
+            q_ro = apply_rotary_emb(q, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved)
+        else:
+            q_ro = rearrange(
+                apply_rotary_emb(
+                    rearrange(q, "b s h d -> b 1 (s h) d"),
+                    cos,
+                    sin,
+                    seqlen_offsets=rotary_seqlens,
+                    interleaved=rotary_interleaved,
+                ),
+                "b 1 (s h) d -> b s h d",
+                s=config.q_sequence_length,
+            )
+        # q_ro = q
+        k_ro = apply_rotary_emb(k_cache_ref, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved)
+    else:
+        cos, sin = None, None
+        q_ro, k_ro = q, k_cache_ref
+    k_cache_ref = k_ro
+
     brange = rearrange(torch.arange(config.kv_sequence_length, device="cuda"), "s -> 1 s")
     cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
     new_mask = brange < cache_seqlens_expanded
     k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads)
     v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads)
     out_ref, _ = attention_ref(
-        q, k_cache_rep, v_cache_rep, None, new_mask, 0.0, None, causal=True, window_size=window_size
+        q_ro, k_cache_rep, v_cache_rep, None, new_mask, 0.0, None, causal=True, window_size=window_size
     )
     out_ref = out_ref.detach().cpu().numpy()
     if past_format == Formats.BNSH:
@@ -1243,9 +1477,39 @@ def parity_check_gqa_prompt_no_buff(
         v_cache_ref = v_cache_ref.transpose(1, 2)

     # Flash function
-    out, present_k, present_v = gqa_prompt_func(
-        q, None, None, config, new_k, new_v, cache_seqlens, left_window_size, past_format, False
-    )
+    if packed:
+        packed_qkv = torch.concatenate([q, new_k, new_v], dim=2)
+        out, present_k, present_v = gqa_prompt_func(
+            packed_qkv,
+            None,
+            None,
+            config,
+            None,
+            None,
+            cos,
+            sin,
+            cache_seqlens,
+            left_window_size,
+            past_format,
+            False,
+            rotary_interleaved,
+        )
+    else:
+        out, present_k, present_v = gqa_prompt_func(
+            q,
+            None,
+            None,
+            config,
+            new_k,
+            new_v,
+            cos,
+            sin,
+            cache_seqlens,
+            left_window_size,
+            past_format,
+            False,
+            rotary_interleaved,
+        )
     out = torch.squeeze(out, 0)
     out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size))
     out = out.detach().cpu().numpy()
@@ -1256,7 +1520,17 @@

     # Compare results
     print(
-        "KV-buffer",
+        "No buff",
+        " packed:",
+        packed,
+        " causal:",
+        causal,
+        " local:",
+        local,
+        " rotary:",
+        rotary,
+        " rotary_interleaved:",
+        rotary_interleaved,
         "past kv format:",
         "BSNH" if past_format == Formats.BSNH else "BNSH",
         " B:",
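Both prompt parity functions run Q and K through apply_rotary_emb (from the new rotary_flash.py test helper) before calling the reference. The helper itself lives elsewhere in this patch; the following is only a rough single-head reference for the two layouts the `interleaved` flag switches between, written independently and not taken from rotary_flash.py:

import torch

def rotary_ref(x, cos, sin, interleaved=False):
    # x: [seqlen, head_size]; only the first 2 * cos.shape[1] dims are rotated
    rot_dim = 2 * cos.shape[1]
    x_rot, x_pass = x[:, :rot_dim], x[:, rot_dim:]
    if interleaved:
        x1, x2 = x_rot[:, 0::2], x_rot[:, 1::2]  # rotate (x0, x1), (x2, x3), ... pairs
    else:
        x1, x2 = x_rot.chunk(2, dim=-1)  # rotate first half against second half
    o1 = x1 * cos - x2 * sin
    o2 = x1 * sin + x2 * cos
    if interleaved:
        out = torch.stack((o1, o2), dim=-1).flatten(-2)  # re-interleave the pairs
    else:
        out = torch.cat((o1, o2), dim=-1)
    return torch.cat((out, x_pass), dim=-1)

x = torch.randn(4, 16)
angle = torch.rand(4, 8) * 2 * torch.pi  # one angle per rotated pair, per position
y = rotary_ref(x, torch.cos(angle), torch.sin(angle), interleaved=True)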
@@ -1285,9 +1559,12 @@

 def parity_check_gqa_past(
     config,
-    causal=False,
+    causal=True,
     local=False,
     past_format=Formats.BSNH,
+    rotary=False,
+    rotary_interleaved=False,
+    packed=False,
     rtol=1e-3,
     atol=1e-3,
 ):
@@ -1336,6 +1613,7 @@ def parity_check_gqa_past(
         dtype=torch.float16,
         requires_grad=False,
     )
+    window_size = (-1, -1)
     left_window_size = -1
     if local:
@@ -1359,18 +1637,45 @@
         dtype=torch.int32,
         device="cuda",
     )
+
+    if rotary:
+        rotary_fraction = 1.0
+        rotary_dim = math.floor(int(rotary_fraction * config.head_size) / 16) * 16
+        angle = torch.rand(config.kv_sequence_length, rotary_dim // 2, device="cuda") * 2 * math.pi
+        cos = torch.cos(angle).to(dtype=torch.float16)
+        sin = torch.sin(angle).to(dtype=torch.float16)
+        if causal or local:
+            q_ro = apply_rotary_emb(q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved)
+        else:
+            q_ro = rearrange(
+                apply_rotary_emb(
+                    rearrange(q, "b s h d -> b 1 (s h) d"),
+                    cos,
+                    sin,
+                    seqlen_offsets=cache_seqlens,
+                    interleaved=rotary_interleaved,
+                ),
+                "b 1 (s h) d -> b s h d",
+                s=config.sequence_length,
+            )
+        # q_ro = q
+        k_ro = apply_rotary_emb(new_k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved)
+    else:
+        cos, sin = None, None
+        q_ro, k_ro = q, new_k
+
     arange = rearrange(torch.arange(config.kv_sequence_length, device="cuda"), "s -> 1 s")
     cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
     update_mask = torch.logical_and(
         cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + config.sequence_length
     )
-    k_cache_ref[update_mask] = rearrange(new_k, "b s ... -> (b s) ...")
+    k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...")
     v_cache_ref[update_mask] = rearrange(new_v, "b s ... -> (b s) ...")
     k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads)
     v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads)
     key_padding_mask = arange < cache_seqlens_expanded + config.sequence_length
     out_ref, _ = attention_ref(
-        q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size
+        q_ro, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size
     )
     out_ref = out_ref.detach().cpu().numpy()
     if past_format == Formats.BNSH:
@@ -1378,13 +1683,46 @@ def parity_check_gqa_past(
         v_cache_ref = v_cache_ref.transpose(1, 2)

     # Flash function
-    out, present_k, present_v = gqa_past_func(
-        q, k, v, config, new_k, new_v, cache_seqlens, past_format, True, left_window_size
-    )
+    if packed:
+        packed_qkv = torch.concatenate([q, new_k, new_v], dim=2)
+        out, present_k, present_v = gqa_past_func(
+            packed_qkv,
+            k,
+            v,
+            config,
+            None,
+            None,
+            cos,
+            sin,
+            cache_seqlens,
+            past_format,
+            True,
+            left_window_size,
+            rotary_interleaved,
+        )
+    else:
+        out, present_k, present_v = gqa_past_func(
+            q,
+            k,
+            v,
+            config,
+            new_k,
+            new_v,
+            cos,
+            sin,
+            cache_seqlens,
+            past_format,
+            True,
+            left_window_size,
+            rotary_interleaved,
+        )
     out = torch.squeeze(out, 0)
     out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size))
     out = out.detach().cpu().numpy()

+    # print(cache_seqlens[0])
+    # print((present_k - k_cache_ref.detach().cpu().numpy())[0, 0, cache_seqlens[0], :])
+
     # Make sure past-present buffer updating correctly
     assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True)
     assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True)
@@ -1394,10 +1732,16 @@ def parity_check_gqa_past(
         "KV-buffer",
         "past kv format:",
         "BSNH" if past_format == Formats.BSNH else "BNSH",
+        " packed:",
+        packed,
         " causal:",
         causal,
         " local:",
         local,
+        " rotary:",
+        rotary,
+        " rotary_interleaved:",
+        rotary_interleaved,
         " B:",
         config.batch_size,
         " S:",
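The token-generation checks above splice the new K/V tokens into the reference cache with a boolean mask instead of a Python loop: each batch row writes `sequence_length` new entries starting at its own `cache_seqlens` offset. A small sketch of that masked scatter with toy shapes (all values invented for illustration):

import torch
from einops import rearrange

batch, cache_len, new_len, heads, dim = 2, 8, 3, 1, 4
k_cache = torch.zeros(batch, cache_len, heads, dim)
new_k = torch.ones(batch, new_len, heads, dim)
cache_seqlens = torch.tensor([0, 2])  # tokens already present per batch row

arange = rearrange(torch.arange(cache_len), "s -> 1 s")           # [1, cache_len]
seqlens = rearrange(cache_seqlens, "b -> b 1")                    # [batch, 1]
update_mask = (seqlens <= arange) & (arange < seqlens + new_len)  # [batch, cache_len]

# boolean-mask assignment: mask order (batch-major) matches the (b s) flatten
k_cache[update_mask] = rearrange(new_k, "b s ... -> (b s) ...")
# row 0 now holds new tokens at positions 0..2, row 1 at positions 2..4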
@@ -1427,6 +1771,9 @@ def parity_check_gqa_past_no_buff(
     causal=False,
     local=False,
     past_format=Formats.BSNH,
+    rotary=False,
+    rotary_interleaved=False,
+    packed=False,
     rtol=1e-3,
     atol=1e-3,
 ):
@@ -1503,18 +1850,47 @@
         device="cuda",
     )
     cache_seqlens[random.randint(0, config.batch_size - 1)] = config.kv_sequence_length
+
+    if rotary:
+        rotary_fraction = 1.0
+        rotary_dim = math.floor(int(rotary_fraction * config.head_size) / 16) * 16
+        angle = (
+            torch.rand(config.kv_sequence_length + config.sequence_length, rotary_dim // 2, device="cuda") * 2 * math.pi
+        )
+        cos = torch.cos(angle).to(dtype=torch.float16)
+        sin = torch.sin(angle).to(dtype=torch.float16)
+        if causal or local:
+            q_ro = apply_rotary_emb(q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved)
+        else:
+            q_ro = rearrange(
+                apply_rotary_emb(
+                    rearrange(q, "b s h d -> b 1 (s h) d"),
+                    cos,
+                    sin,
+                    seqlen_offsets=cache_seqlens,
+                    interleaved=rotary_interleaved,
+                ),
+                "b 1 (s h) d -> b s h d",
+                s=config.sequence_length,
+            )
+        # q_ro = q
+        k_ro = apply_rotary_emb(new_k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved)
+    else:
+        cos, sin = None, None
+        q_ro, k_ro = q, new_k
+
     arange = rearrange(torch.arange(config.kv_sequence_length + config.sequence_length, device="cuda"), "s -> 1 s")
     cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
     update_mask = torch.logical_and(
         cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + config.sequence_length
     )
-    k_cache_ref[update_mask] = rearrange(new_k, "b s ... -> (b s) ...")
+    k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...")
     v_cache_ref[update_mask] = rearrange(new_v, "b s ... -> (b s) ...")
     k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads)
     v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads)
     key_padding_mask = arange < cache_seqlens_expanded + config.sequence_length
     out_ref, _ = attention_ref(
-        q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size
+        q_ro, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size
     )
     out_ref = out_ref.detach().cpu().numpy()
     if past_format == Formats.BNSH:
@@ -1522,13 +1898,47 @@ def parity_check_gqa_past_no_buff(
         v_cache_ref = v_cache_ref.transpose(1, 2)

     # Flash function
-    out, present_k, present_v = gqa_past_func(
-        q, k, v, config, new_k, new_v, cache_seqlens, past_format, False, window_size=left_window_size
-    )
+    if packed:
+        packed_qkv = torch.concatenate([q, new_k, new_v], dim=2)
+        out, present_k, present_v = gqa_past_func(
+            packed_qkv,
+            k,
+            v,
+            config,
+            None,
+            None,
+            cos,
+            sin,
+            cache_seqlens,
+            past_format,
+            False,
+            window_size=left_window_size,
+            rotary_interleaved=rotary_interleaved,
+        )
+    else:
+        out, present_k, present_v = gqa_past_func(
+            q,
+            k,
+            v,
+            config,
+            new_k,
+            new_v,
+            cos,
+            sin,
+            cache_seqlens,
+            past_format,
+            False,
+            window_size=left_window_size,
+            rotary_interleaved=rotary_interleaved,
+        )
     out = torch.squeeze(out, 0)
     out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size))
     out = out.detach().cpu().numpy()

+    # print(cache_seqlens[0])
+    # print((out - out_ref)[0])
+    # print((present_k - k_cache_ref.detach().cpu().numpy())[0, 0, :, 0])
+
     # Make sure past-present buffer updating correctly
     # assert numpy.allclose(
     #     present_k[:, :, :-1, :], k_cache_ref.detach().cpu().numpy()[:, :, :-1, :], rtol=rtol, atol=atol, equal_nan=True
@@ -1540,10 +1950,16 @@
     # Compare results
     print(
         "NO buff",
+        " packed:",
+        packed,
         " causal:",
         causal,
         " local:",
         local,
+        " rotary:",
+        rotary,
+        " rotary_interleaved:",
+        rotary_interleaved,
         "past kv format:",
         "BSNH" if past_format == Formats.BSNH else "BNSH",
         " B:",
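When `packed` is true, the tests hand GroupQueryAttention a single QKV tensor concatenated along the head axis, `num_heads + 2 * kv_num_heads` heads wide. A sketch of the packing convention and its inverse split, with example head counts:

import torch

batch, seq, num_heads, kv_num_heads, head_size = 2, 4, 8, 2, 16
q = torch.randn(batch, seq, num_heads, head_size)
k = torch.randn(batch, seq, kv_num_heads, head_size)
v = torch.randn(batch, seq, kv_num_heads, head_size)

# torch.cat is the same operation as the torch.concatenate alias used in the tests
packed_qkv = torch.cat([q, k, v], dim=2)  # [batch, seq, num_heads + 2 * kv_num_heads, head_size]

# splitting by head counts recovers the original tensors exactly
q2, k2, v2 = torch.split(packed_qkv, [num_heads, kv_num_heads, kv_num_heads], dim=2)
assert torch.equal(q2, q) and torch.equal(k2, k) and torch.equal(v2, v)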
@@ -1671,10 +2087,25 @@ def test_gqa_no_past(self):
             for n, n2 in num_h:
                 for h in h_sizes:
                     for local in [False, True]:
-                        for past_kv_format in [Formats.BNSH]:
-                            config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h)
-                            parity_check_gqa_prompt(config, local=local, past_format=past_kv_format)
-                            parity_check_gqa_prompt_no_buff(config, local=local, past_format=past_kv_format)
+                        for rotary, rotary_interleaved in [(True, False), (True, True), (False, False)]:
+                            for past_kv_format, packed in [(Formats.BNSH, False), (Formats.BNSH, True)]:
+                                config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h)
+                                parity_check_gqa_prompt(
+                                    config,
+                                    local=local,
+                                    past_format=past_kv_format,
+                                    rotary=rotary,
+                                    rotary_interleaved=rotary_interleaved,
+                                    packed=packed,
+                                )
+                                parity_check_gqa_prompt_no_buff(
+                                    config,
+                                    local=local,
+                                    past_format=past_kv_format,
+                                    rotary=rotary,
+                                    rotary_interleaved=rotary_interleaved,
+                                    packed=packed,
+                                )

     def test_gqa_past(self):
         if not torch.cuda.is_available():
@@ -1684,7 +2115,6 @@ def test_gqa_past(self):
             return
         os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1"
         print("-------- TEST GQA PAST (TOKEN GEN) ---------")
-        print("-------- MEMORY EFFICIENT (TOKEN GEN) --------")
         batches = [5] if pipeline_mode else [1, 3, 5]
         seqs = (
             [(1, 128), (1, 1024), (1, 2048)]
@@ -1706,6 +2136,7 @@ def test_gqa_past(self):
         num_h = [(32, 32), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
         h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
         random.seed(69)
+        print("-------- MEMORY EFFICIENT (TOKEN GEN) --------")
         for b in batches:
             for s, s2 in seqs:
                 for n, n2 in num_h:
@@ -1734,23 +2165,30 @@ def test_gqa_past(self):
             for n, n2 in num_h:
                 for h in h_sizes:
                     for local in [False, True]:
-                        for past_kv_format in [Formats.BNSH]:
-                            sp = random.randint(1, s2 - s) if s2 - s > 0 else 0
-                            config = Config(b, s, s2, sp, n, n2, h)
-                            parity_check_gqa_past(
-                                config,
-                                local=local,
-                                past_format=past_kv_format,
-                                rtol=1e-3,
-                                atol=1e-3,
-                            )
-                            parity_check_gqa_past_no_buff(
-                                config,
-                                local=local,
-                                past_format=past_kv_format,
-                                rtol=1e-3,
-                                atol=1e-3,
-                            )
+                        for rotary, rotary_interleaved in [(True, False), (True, True), (False, False)]:
+                            for past_kv_format, packed in [(Formats.BNSH, False), (Formats.BNSH, True)]:
+                                sp = random.randint(1, s2 - s) if s2 - s > 0 else 0
+                                config = Config(b, s, s2, sp, n, n2, h)
+                                parity_check_gqa_past(
+                                    config,
+                                    local=local,
+                                    past_format=past_kv_format,
+                                    rtol=1e-3,
+                                    atol=1e-3,
+                                    rotary=rotary,
+                                    rotary_interleaved=rotary_interleaved,
+                                    packed=packed,
+                                )
+                                parity_check_gqa_past_no_buff(
+                                    config,
+                                    local=local,
+                                    past_format=past_kv_format,
+                                    rtol=1e-3,
+                                    atol=1e-3,
+                                    rotary=rotary,
+                                    rotary_interleaved=rotary_interleaved,
+                                    packed=packed,
+                                )


 if __name__ == "__main__":
diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py
index 0da4adb51767d..31f242cce1a23 100644
--- a/tools/ci_build/build.py
+++ b/tools/ci_build/build.py
@@ -2064,7 +2064,8 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs):
                 numpy_init_version = numpy.__version__
                 pb_init_version = google.protobuf.__version__
                 run_subprocess(
-                    [sys.executable, "-m", "pip", "install", "-r", "requirements.txt"], cwd=SCRIPT_DIR
+                    [sys.executable, "-m", "pip", "install", "-r", "requirements-transformers-test.txt"],
+                    cwd=SCRIPT_DIR,
                 )
                 run_subprocess([sys.executable, "-m", "pytest", "transformers"], cwd=cwd)
                 # Restore initial numpy/protobuf version in case other tests use it
diff --git a/tools/ci_build/github/apple/default_full_ios_framework_build_settings.json b/tools/ci_build/github/apple/default_full_ios_framework_build_settings.json
new file mode 100644
index 0000000000000..445bfca9889ff
--- /dev/null
+++ b/tools/ci_build/github/apple/default_full_ios_framework_build_settings.json
@@ -0,0 +1,30 @@
+{
+  "build_osx_archs": {
+    "iphoneos": [
+      "arm64"
+    ],
+    "iphonesimulator": [
+      "arm64",
+      "x86_64"
+    ]
+  },
+  "build_params": {
+    "base": [
+      "--parallel",
+      "--use_xcode",
+      "--build_apple_framework",
+      "--use_coreml",
+      "--use_xnnpack",
+      "--skip_tests",
+      "--cmake_extra_defines=onnxruntime_BUILD_UNIT_TESTS=OFF"
+    ],
+    "iphoneos": [
+      "--ios",
+      "--apple_deploy_target=12.0"
+    ],
+    "iphonesimulator": [
+      "--ios",
+      "--apple_deploy_target=12.0"
+    ]
+  }
+}
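The new settings file above drives an iOS-only xcframework build (device arm64 plus arm64/x86_64 simulator, with CoreML and XNNPACK enabled). A sketch of invoking the build script with it from Python, mirroring the pipeline step changed later in this patch; the build directory here is a placeholder, not a value from the repo:

import subprocess
import sys

# placeholder build directory; CI uses $(Build.BinariesDirectory)/ios_framework
subprocess.run(
    [
        sys.executable,
        "tools/ci_build/github/apple/build_apple_framework.py",
        "--build_dir",
        "build/ios_framework",
        "tools/ci_build/github/apple/default_full_ios_framework_build_settings.json",
    ],
    check=True,
)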
diff --git a/tools/ci_build/github/apple/test_apple_packages.py b/tools/ci_build/github/apple/test_apple_packages.py
index 6dc4868dac8a3..cd360a63a3a0f 100644
--- a/tools/ci_build/github/apple/test_apple_packages.py
+++ b/tools/ci_build/github/apple/test_apple_packages.py
@@ -112,7 +112,10 @@ def _test_apple_packages(args):
             subprocess.run(["pod", "cache", "clean", "--all"], shell=False, check=True, cwd=target_proj_path)

             # install pods
-            subprocess.run(["pod", "install"], shell=False, check=True, cwd=target_proj_path)
+            # set env to skip macos test targets accordingly
+            env = os.environ.copy()
+            env["SKIP_MACOS_TEST"] = "true" if args.skip_macos_test else "false"
+            subprocess.run(["pod", "install"], shell=False, check=True, cwd=target_proj_path, env=env)

             # run the tests
             if not args.prepare_test_project_only:
@@ -144,7 +147,7 @@ def _test_apple_packages(args):
                     cwd=target_proj_path,
                 )

-                if PackageVariant[args.variant] != PackageVariant.Mobile:
+                if PackageVariant[args.variant] != PackageVariant.Mobile and not args.skip_macos_test:
                     subprocess.run(
                         [
                             "xcrun",
@@ -206,6 +209,12 @@ def parse_args():
         help="Prepare the test project only, without running the tests",
     )

+    parser.add_argument(
+        "--skip_macos_test",
+        action="store_true",
+        help="Skip macos platform tests. Specify this argument when build targets only contain ios archs. ",
+    )
+
     return parser.parse_args()

diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml
index 81319e07c6b17..3325e261715cf 100644
--- a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml
+++ b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml
@@ -119,31 +119,32 @@ stages:
     - script: |
        set -e -x
        python3 tools/ci_build/github/apple/build_apple_framework.py \
-         --build_dir "$(Build.BinariesDirectory)/apple_framework" \
+         --build_dir "$(Build.BinariesDirectory)/ios_framework" \
          --path_to_protoc_exe $(Build.BinariesDirectory)/protobuf_install/bin/protoc \
-         tools/ci_build/github/apple/default_full_apple_framework_build_settings.json
+         tools/ci_build/github/apple/default_full_ios_framework_build_settings.json

        mkdir $(Build.BinariesDirectory)/artifacts
-       mkdir -p $(Build.BinariesDirectory)/artifacts_staging/onnxruntime-apple-xcframework-$(OnnxRuntimeVersion)
-       cp -R $(Build.BinariesDirectory)/apple_framework/framework_out/onnxruntime.xcframework \
-         $(Build.BinariesDirectory)/artifacts_staging/onnxruntime-apple-xcframework-$(OnnxRuntimeVersion)
+       mkdir -p $(Build.BinariesDirectory)/artifacts_staging/onnxruntime-ios-xcframework-$(OnnxRuntimeVersion)
+       cp -R $(Build.BinariesDirectory)/ios_framework/framework_out/onnxruntime.xcframework \
+         $(Build.BinariesDirectory)/artifacts_staging/onnxruntime-ios-xcframework-$(OnnxRuntimeVersion)
        pushd $(Build.BinariesDirectory)/artifacts_staging
        zip -vr $(Build.BinariesDirectory)/artifacts/onnxruntime_xcframework.zip \
-         onnxruntime-apple-xcframework-$(OnnxRuntimeVersion)
+         onnxruntime-ios-xcframework-$(OnnxRuntimeVersion)
        popd
      displayName: "Build Apple xcframework"

    - script: |
        python3 tools/ci_build/github/apple/test_apple_packages.py \
          --fail_if_cocoapods_missing \
-         --framework_info_file "$(Build.BinariesDirectory)/apple_framework/xcframework_info.json" \
-         --c_framework_dir "$(Build.BinariesDirectory)/apple_framework/framework_out" \
-         --variant Full
+         --framework_info_file "$(Build.BinariesDirectory)/ios_framework/xcframework_info.json" \
+         --c_framework_dir "$(Build.BinariesDirectory)/ios_framework/framework_out" \
+         --variant Full \
+         --skip_macos_test
      displayName: "Test Apple framework"

    - task: PublishBuildArtifacts@1
      inputs:
        pathtoPublish: '$(Build.BinariesDirectory)/artifacts'
-       artifactName: 'onnxruntime-apple-full-xcframework'
+       artifactName: 'onnxruntime-ios-full-xcframework'

    - template: component-governance-component-detection-steps.yml
      parameters:
@@ -350,7 +351,7 @@ stages:
      - template: flex-downloadPipelineArtifact.yml
        parameters:
          StepName: 'Download iOS Pipeline Artifact'
-         ArtifactName: 'onnxruntime-apple-full-xcframework'
+         ArtifactName: 'onnxruntime-ios-full-xcframework'
          TargetPath: '$(Build.BinariesDirectory)/nuget-artifact'
          SpecificArtifact: ${{ parameters.specificArtifact }}
          BuildId: ${{ parameters.BuildId }}
diff --git a/tools/ci_build/requirements.txt b/tools/ci_build/requirements-transformers-test.txt
similarity index 94%
rename from tools/ci_build/requirements.txt
rename to tools/ci_build/requirements-transformers-test.txt
index 57fc8f08336d2..a5279781462a7 100644
--- a/tools/ci_build/requirements.txt
+++ b/tools/ci_build/requirements-transformers-test.txt
@@ -3,7 +3,8 @@ packaging
 protobuf==3.20.2
 numpy==1.24.0 ; python_version < '3.12'
 numpy==1.26.0 ; python_version >= '3.12'
+torch
 coloredlogs==15.0
 transformers==4.36.0
 psutil
-einops
\ No newline at end of file
+einops
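With the rename, the transformers test requirements are installed explicitly rather than from a generically named file. The equivalent steps as a standalone sketch, mirroring the build.py change above (it assumes the same working-directory conventions build.py uses, and a CUDA-capable environment for the flash-attention tests):

import subprocess
import sys

# install the renamed requirements file (lives in tools/ci_build), then run
# the transformers test suite, matching run_onnxruntime_tests() in build.py
subprocess.run(
    [sys.executable, "-m", "pip", "install", "-r", "requirements-transformers-test.txt"],
    check=True,
)
subprocess.run([sys.executable, "-m", "pytest", "transformers"], check=True)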