From dbb1e4e1427ea88a6d997de6d704e705b3b873fc Mon Sep 17 00:00:00 2001 From: Bowen Wang Date: Thu, 13 Feb 2025 20:13:08 +0800 Subject: [PATCH] refactor: change to TORCH_LIBRARY (#823) This PR updates FlashInfer's C++/CUDA extensions from pybind11 modules to `torch.libraries`, which is recommended since PyTorch 2.5. This is mainly implemented in #764. We have investigated that the issue in #820 was not caused by this PR, so we're opening it up again. --------- Signed-off-by: youkaichao Signed-off-by: abmfy Co-authored-by: youkaichao Co-authored-by: Zihao Ye --- ...te_aot_default_additional_params_header.py | 12 +- csrc/batch_decode.cu | 20 +- csrc/batch_decode_jit_pybind.cu | 23 +- csrc/batch_decode_mla_plan.cu | 9 +- csrc/batch_decode_mla_pybind.cu | 18 +- csrc/batch_decode_mla_run.cu | 9 +- csrc/batch_mla_plan.cu | 16 +- csrc/batch_mla_pybind.cu | 29 ++- csrc/batch_mla_run.cu | 15 +- csrc/batch_prefill.cu | 25 +-- csrc/batch_prefill_jit_pybind.cu | 31 +-- csrc/batch_prefill_sm90.cu | 40 ++-- csrc/batch_prefill_sm90_jit_pybind.cu | 32 +-- csrc/flashinfer_cascade_ops.cu | 12 +- csrc/flashinfer_gemm_ops.cu | 8 +- csrc/flashinfer_gemm_sm90_ops.cu | 8 +- csrc/flashinfer_norm_ops.cu | 15 +- csrc/flashinfer_ops.cu | 197 ++++++++++-------- csrc/flashinfer_ops_sm90.cu | 38 ++-- csrc/flashinfer_page_ops.cu | 16 +- csrc/flashinfer_quantization_ops.cu | 8 +- csrc/flashinfer_rope_ops.cu | 39 ++-- csrc/flashinfer_sampling_ops.cu | 40 ++-- csrc/page.cu | 8 +- csrc/pytorch_conversion_utils.h | 29 +++ csrc/pytorch_extension_utils.h | 35 +++- csrc/renorm.cu | 4 +- csrc/rope.cu | 20 +- csrc/sampling.cu | 2 +- csrc/single_decode.cu | 4 +- csrc/single_decode_jit_pybind.cu | 9 +- csrc/single_prefill.cu | 4 +- csrc/single_prefill_jit_pybind.cu | 10 +- csrc/single_prefill_sm90.cu | 4 +- csrc/single_prefill_sm90_jit_pybind.cu | 10 +- flashinfer/activation.py | 2 +- flashinfer/cascade.py | 2 +- flashinfer/decode.py | 4 +- flashinfer/gemm.py | 4 +- flashinfer/jit/activation.py | 4 +- flashinfer/jit/attention.py | 12 +- flashinfer/jit/core.py | 8 +- flashinfer/norm.py | 2 +- flashinfer/page.py | 2 +- flashinfer/prefill.py | 8 +- flashinfer/quantization.py | 2 +- flashinfer/rope.py | 2 +- flashinfer/sampling.py | 2 +- include/flashinfer/attention/scheduler.cuh | 4 +- include/flashinfer/gemm/group_gemm.cuh | 2 +- include/flashinfer/gemm/group_gemm_sm90.cuh | 2 +- setup.py | 10 +- tests/test_jit_example.py | 18 +- 53 files changed, 503 insertions(+), 386 deletions(-) create mode 100644 csrc/pytorch_conversion_utils.h diff --git a/aot_build_utils/generate_aot_default_additional_params_header.py b/aot_build_utils/generate_aot_default_additional_params_header.py index 76f3b758e..2e832bb36 100644 --- a/aot_build_utils/generate_aot_default_additional_params_header.py +++ b/aot_build_utils/generate_aot_default_additional_params_header.py @@ -85,7 +85,7 @@ def get_aot_default_additional_params_header_str() -> str: "rope_rcp_scale", "rope_rcp_theta", ], # additional_scalar_names - ["float", "float", "float", "float"], # additional_scalar_dtypes + ["double", "double", "double", "double"], # additional_scalar_dtypes ) ret += generate_macro_entry( @@ -98,7 +98,7 @@ def get_aot_default_additional_params_header_str() -> str: "rope_rcp_scale", "rope_rcp_theta", ], - ["float", "float", "float", "float"], + ["double", "double", "double", "double"], ) ret += generate_macro_entry( @@ -106,7 +106,7 @@ def get_aot_default_additional_params_header_str() -> str: [], [], ["logits_soft_cap", "sm_scale"], - ["float", "float"], + ["double", 
"double"], is_sm90_template=True, ) @@ -120,7 +120,7 @@ def get_aot_default_additional_params_header_str() -> str: "rope_rcp_scale", "rope_rcp_theta", ], # additional_scalar_names - ["float", "float", "float", "float"], # additional_scalar_dtypes + ["double", "double", "double", "double"], # additional_scalar_dtypes ) ret += generate_macro_entry( @@ -133,7 +133,7 @@ def get_aot_default_additional_params_header_str() -> str: "rope_rcp_scale", "rope_rcp_theta", ], - ["float", "float", "float", "float"], + ["double", "double", "double", "double"], ) ret += generate_macro_entry( @@ -141,7 +141,7 @@ def get_aot_default_additional_params_header_str() -> str: [], [], ["logits_soft_cap", "sm_scale"], - ["float", "float"], + ["double", "double"], is_sm90_template=True, ) diff --git a/csrc/batch_decode.cu b/csrc/batch_decode.cu index 33d94ec8b..065bb5acb 100644 --- a/csrc/batch_decode.cu +++ b/csrc/batch_decode.cu @@ -20,6 +20,7 @@ #include "batch_decode_config.inc" #include "pytorch_extension_utils.h" +#include "pytorch_conversion_utils.h" namespace flashinfer { @@ -32,13 +33,12 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched(Params params, typename Params using namespace flashinfer; -std::vector BatchDecodeWithPagedKVCachePlan( +at::Tensor BatchDecodeWithPagedKVCachePlan( at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, - at::Tensor page_locked_int_workspace_buffer, at::Tensor indptr, unsigned int batch_size, - unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int page_size, - bool enable_cuda_graph, int window_left, float logits_soft_cap, unsigned int head_dim_qk, - unsigned int head_dim_vo, at::Tensor empty_q_data, at::Tensor empty_kv_data, - int64_t cuda_stream) { + at::Tensor page_locked_int_workspace_buffer, at::Tensor indptr, int64_t batch_size, + int64_t num_qo_heads, int64_t num_kv_heads, int64_t page_size, + bool enable_cuda_graph, int64_t window_left, double logits_soft_cap, int64_t head_dim_qk, + int64_t head_dim_vo, at::Tensor empty_q_data, at::Tensor empty_kv_data, int64_t cuda_stream) { size_t float_workspace_size_in_bytes = float_workspace_buffer.size(0) * float_workspace_buffer.element_size(); size_t int_workspace_size_in_bytes = @@ -74,17 +74,17 @@ std::vector BatchDecodeWithPagedKVCachePlan( }); }); - return plan_info.ToVector(); + return vec_to_tensor(plan_info.ToVector()); } void BatchDecodeWithPagedKVCacheRun( at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, - std::vector plan_info_vec, at::Tensor q, at::Tensor paged_k_cache, + at::Tensor plan_info_vec, at::Tensor q, at::Tensor paged_k_cache, at::Tensor paged_v_cache, at::Tensor paged_kv_indptr, at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len, at::Tensor o, std::optional maybe_lse, - unsigned int kv_layout_code, int window_left ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream) { + int64_t kv_layout_code, int64_t window_left ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream) { DecodePlanInfo plan_info; - plan_info.FromVector(plan_info_vec); + plan_info.FromVector(tensor_to_vec(plan_info_vec)); QKVLayout kv_layout = static_cast(kv_layout_code); auto device = q.device(); int64_t batch_size = q.size(0); diff --git a/csrc/batch_decode_jit_pybind.cu b/csrc/batch_decode_jit_pybind.cu index db43cddd9..b8d1ace4f 100644 --- a/csrc/batch_decode_jit_pybind.cu +++ b/csrc/batch_decode_jit_pybind.cu @@ -16,22 +16,23 @@ #include "batch_decode_config.inc" #include "pytorch_extension_utils.h" -std::vector BatchDecodeWithPagedKVCachePlan( +at::Tensor BatchDecodeWithPagedKVCachePlan( 
at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, - at::Tensor page_locked_int_workspace_buffer, at::Tensor indptr, unsigned int batch_size, - unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int page_size, - bool enable_cuda_graph, int window_left, float logits_soft_cap, unsigned int head_dim_qk, - unsigned int head_dim_vo, at::Tensor empty_q_data, at::Tensor empty_kv_data, - int64_t cuda_stream); + at::Tensor page_locked_int_workspace_buffer, at::Tensor indptr, int64_t batch_size, + int64_t num_qo_heads, int64_t num_kv_heads, int64_t page_size, + bool enable_cuda_graph, int64_t window_left, double logits_soft_cap, int64_t head_dim_qk, + int64_t head_dim_vo, at::Tensor empty_q_data, at::Tensor empty_kv_data, int64_t cuda_stream); void BatchDecodeWithPagedKVCacheRun( at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, - std::vector plan_info_vec, at::Tensor q, at::Tensor paged_k_cache, + at::Tensor plan_info_vec, at::Tensor q, at::Tensor paged_k_cache, at::Tensor paged_v_cache, at::Tensor paged_kv_indptr, at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len, at::Tensor o, std::optional maybe_lse, - unsigned int kv_layout_code, int window_left ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream); + int64_t kv_layout_code, int64_t window_left ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream); -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("plan", &BatchDecodeWithPagedKVCachePlan, "Batched decode with paged KV-Cache plan"); - m.def("run", &BatchDecodeWithPagedKVCacheRun, "Batched decode with paged KV-Cache run"); +TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) { + // Batched decode with paged KV-Cache plan + m.def("plan", BatchDecodeWithPagedKVCachePlan); + // Batched decode with paged KV-Cache run + m.def("run", BatchDecodeWithPagedKVCacheRun); } diff --git a/csrc/batch_decode_mla_plan.cu b/csrc/batch_decode_mla_plan.cu index ea37e87b5..8c4b114b4 100644 --- a/csrc/batch_decode_mla_plan.cu +++ b/csrc/batch_decode_mla_plan.cu @@ -4,13 +4,14 @@ #include "mla_config.inc" #include "pytorch_extension_utils.h" +#include "pytorch_conversion_utils.h" using namespace flashinfer; -std::vector BatchDecodeWithPagedKVCachePlanMLA( +at::Tensor BatchDecodeWithPagedKVCachePlanMLA( at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, - at::Tensor page_locked_int_workspace_buffer, at::Tensor indptr, unsigned int batch_size, - unsigned int num_qo_heads, unsigned int page_size, bool enable_cuda_graph, + at::Tensor page_locked_int_workspace_buffer, at::Tensor indptr, int64_t batch_size, + int64_t num_qo_heads, int64_t page_size, bool enable_cuda_graph, int64_t cuda_stream) { size_t float_workspace_size_in_bytes = float_workspace_buffer.size(0) * float_workspace_buffer.element_size(); @@ -35,5 +36,5 @@ std::vector BatchDecodeWithPagedKVCachePlanMLA( TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPagedKVCachePlanMLA failed with error ", cudaGetErrorString(status)); - return plan_info.ToVector(); + return vec_to_tensor(plan_info.ToVector()); } diff --git a/csrc/batch_decode_mla_pybind.cu b/csrc/batch_decode_mla_pybind.cu index 8eb9d5dab..1aace8295 100644 --- a/csrc/batch_decode_mla_pybind.cu +++ b/csrc/batch_decode_mla_pybind.cu @@ -1,20 +1,20 @@ #include "mla_config.inc" #include "pytorch_extension_utils.h" -std::vector BatchDecodeWithPagedKVCachePlanMLA( +at::Tensor BatchDecodeWithPagedKVCachePlanMLA( at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, - at::Tensor page_locked_int_workspace_buffer, at::Tensor indptr, unsigned int 
batch_size, - unsigned int num_qo_heads, unsigned int page_size, bool enable_cuda_graph, int64_t cuda_stream); + at::Tensor page_locked_int_workspace_buffer, at::Tensor indptr, int64_t batch_size, + int64_t num_qo_heads, int64_t page_size, bool enable_cuda_graph, int64_t cuda_stream); void BatchDecodeWithPagedKVCacheRunMLA( at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, - std::vector plan_info_vec, at::Tensor q_nope, at::Tensor q_pe, + at::Tensor plan_info_vec, at::Tensor q_nope, at::Tensor q_pe, at::Tensor paged_ckv_cache, at::Tensor paged_kpe_cache, at::Tensor paged_kv_indptr, - at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len, at::Tensor o, float sm_scale, - int window_left, float logits_soft_cap, float rope_scale, float rope_theta, + at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len, at::Tensor o, double sm_scale, + int64_t window_left, double logits_soft_cap, double rope_scale, double rope_theta, std::optional maybe_lse, int64_t cuda_stream); -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("plan", &BatchDecodeWithPagedKVCachePlanMLA); - m.def("run", &BatchDecodeWithPagedKVCacheRunMLA); +TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) { + m.def("plan", BatchDecodeWithPagedKVCachePlanMLA); + m.def("run", BatchDecodeWithPagedKVCacheRunMLA); } diff --git a/csrc/batch_decode_mla_run.cu b/csrc/batch_decode_mla_run.cu index 05254b53c..82c53ddf0 100644 --- a/csrc/batch_decode_mla_run.cu +++ b/csrc/batch_decode_mla_run.cu @@ -4,18 +4,19 @@ #include "mla_config.inc" #include "pytorch_extension_utils.h" +#include "pytorch_conversion_utils.h" using namespace flashinfer; void BatchDecodeWithPagedKVCacheRunMLA( at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, - std::vector plan_info_vec, at::Tensor q_nope, at::Tensor q_pe, + at::Tensor plan_info_vec, at::Tensor q_nope, at::Tensor q_pe, at::Tensor paged_ckv_cache, at::Tensor paged_kpe_cache, at::Tensor paged_kv_indptr, - at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len, at::Tensor o, float sm_scale, - int window_left, float logits_soft_cap, float rope_scale, float rope_theta, + at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len, at::Tensor o, double sm_scale, + int64_t window_left, double logits_soft_cap, double rope_scale, double rope_theta, std::optional maybe_lse, int64_t cuda_stream) { DecodePlanInfo plan_info; - plan_info.FromVector(plan_info_vec); + plan_info.FromVector(tensor_to_vec(plan_info_vec)); auto device = q_nope.device(); int64_t batch_size = q_nope.size(0); diff --git a/csrc/batch_mla_plan.cu b/csrc/batch_mla_plan.cu index e3d370868..cbfcf2fd7 100644 --- a/csrc/batch_mla_plan.cu +++ b/csrc/batch_mla_plan.cu @@ -17,17 +17,17 @@ #include #include "batch_mla_config.inc" +#include "pytorch_conversion_utils.h" #include "pytorch_extension_utils.h" using namespace flashinfer; -std::vector BatchMLAPagedAttentionPlan(at::Tensor float_workspace_buffer, - at::Tensor int_workspace_buffer, - at::Tensor page_locked_int_workspace_buffer, - at::Tensor qo_indptr, at::Tensor kv_indptr, - at::Tensor kv_len, unsigned int num_heads, - unsigned int head_dim_o, bool causal, - int64_t cuda_stream) { +at::Tensor BatchMLAPagedAttentionPlan(at::Tensor float_workspace_buffer, + at::Tensor int_workspace_buffer, + at::Tensor page_locked_int_workspace_buffer, + at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor kv_len, + int64_t num_heads, int64_t head_dim_o, bool causal, + int64_t cuda_stream) { size_t float_workspace_size_in_bytes = float_workspace_buffer.size(0) 
* float_workspace_buffer.element_size(); size_t int_workspace_size_in_bytes = @@ -47,5 +47,5 @@ std::vector BatchMLAPagedAttentionPlan(at::Tensor float_workspace_buffe TORCH_CHECK(status == cudaSuccess, "Failed to plan MLA, error: ", cudaGetErrorString(status)); - return plan_info.ToVector(); + return vec_to_tensor(plan_info.ToVector()); } diff --git a/csrc/batch_mla_pybind.cu b/csrc/batch_mla_pybind.cu index 8fae8b049..ddc6a9940 100644 --- a/csrc/batch_mla_pybind.cu +++ b/csrc/batch_mla_pybind.cu @@ -16,22 +16,21 @@ #include "batch_mla_config.inc" #include "pytorch_extension_utils.h" -std::vector BatchMLAPagedAttentionPlan(at::Tensor float_workspace_buffer, - at::Tensor int_workspace_buffer, - at::Tensor page_locked_int_workspace_buffer, - at::Tensor qo_indptr, at::Tensor kv_indptr, - at::Tensor kv_len, unsigned int num_heads, - unsigned int head_dim_o, bool causal, - int64_t cuda_stream); +at::Tensor BatchMLAPagedAttentionPlan(at::Tensor float_workspace_buffer, + at::Tensor int_workspace_buffer, + at::Tensor page_locked_int_workspace_buffer, + at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor kv_len, + int64_t num_heads, int64_t head_dim_o, bool causal, + int64_t cuda_stream); void BatchMLAPagedAttentionRun(at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, - std::vector plan_info_vec, at::Tensor q_nope, - at::Tensor q_pe, at::Tensor ckv_cache, at::Tensor kpe_cache, - at::Tensor kv_indices, at::Tensor o, - std::optional maybe_lse, int mask_mode_code, - int num_heads, int page_size, float sm_scale, int64_t cuda_stream); + at::Tensor plan_info_vec, at::Tensor q_nope, at::Tensor q_pe, + at::Tensor ckv_cache, at::Tensor kpe_cache, at::Tensor kv_indices, + at::Tensor o, std::optional maybe_lse, + int64_t mask_mode_code, int64_t num_heads, int64_t page_size, + double sm_scale, int64_t cuda_stream); -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("plan", &BatchMLAPagedAttentionPlan, "Batch MLA Page Attention Plan"); - m.def("run", &BatchMLAPagedAttentionRun, "Batch MLA Page Attention Run"); +TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) { + m.def("plan", &BatchMLAPagedAttentionPlan); + m.def("run", &BatchMLAPagedAttentionRun); } diff --git a/csrc/batch_mla_run.cu b/csrc/batch_mla_run.cu index ecbccb37e..ea2dbb866 100644 --- a/csrc/batch_mla_run.cu +++ b/csrc/batch_mla_run.cu @@ -13,30 +13,29 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -#include - #include #include #include #include #include "batch_mla_config.inc" +#include "pytorch_conversion_utils.h" #include "pytorch_extension_utils.h" using namespace flashinfer; void BatchMLAPagedAttentionRun(at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, - std::vector plan_info_vec, at::Tensor q_nope, - at::Tensor q_pe, at::Tensor ckv_cache, at::Tensor kpe_cache, - at::Tensor kv_indices, at::Tensor o, - std::optional maybe_lse, int mask_mode_code, - int num_heads, int page_size, float sm_scale, int64_t cuda_stream) { + at::Tensor plan_info_vec, at::Tensor q_nope, at::Tensor q_pe, + at::Tensor ckv_cache, at::Tensor kpe_cache, at::Tensor kv_indices, + at::Tensor o, std::optional maybe_lse, + int64_t mask_mode_code, int64_t num_heads, int64_t page_size, + double sm_scale, int64_t cuda_stream) { // q_nope: [n, num_heads, head_dim_ckv] // q_pe: [n, num_heads, head_dim_kpe] // ckv_cache: [num_pages, page_size, head_dim_ckv] // kpe_cache: [num_pages, page_size, head_dim_kpe] MLAPlanInfo plan_info; - plan_info.FromVector(plan_info_vec); + plan_info.FromVector(tensor_to_vec(plan_info_vec)); auto device = q_nope.device(); diff --git a/csrc/batch_prefill.cu b/csrc/batch_prefill.cu index 9c7b77029..98fa33e65 100644 --- a/csrc/batch_prefill.cu +++ b/csrc/batch_prefill.cu @@ -20,6 +20,7 @@ #include "batch_prefill_config.inc" #include "pytorch_extension_utils.h" +#include "pytorch_conversion_utils.h" namespace flashinfer { @@ -39,12 +40,12 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched(Params params, typename Para using namespace flashinfer; -std::vector BatchPrefillWithKVCachePlan( +at::Tensor BatchPrefillWithKVCachePlan( at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr, at::Tensor kv_indptr, - at::Tensor kv_len_arr, unsigned total_num_rows, unsigned int batch_size, - unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int page_size, - bool enable_cuda_graph, unsigned int head_dim_qk, unsigned int head_dim_vo, bool causal, + at::Tensor kv_len_arr, int64_t total_num_rows, int64_t batch_size, + int64_t num_qo_heads, int64_t num_kv_heads, int64_t page_size, + bool enable_cuda_graph, int64_t head_dim_qk, int64_t head_dim_vo, bool causal, int64_t cuda_stream) { size_t float_workspace_size_in_bytes = float_workspace_buffer.size(0) * float_workspace_buffer.element_size(); @@ -64,17 +65,17 @@ std::vector BatchPrefillWithKVCachePlan( TORCH_CHECK(status == cudaSuccess, "Failed to plan prefill with error: ", cudaGetErrorString(status)); - return plan_info.ToVector(); + return vec_to_tensor(plan_info.ToVector()); } void BatchPrefillWithRaggedKVCacheRun( at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, - std::vector plan_info_vec, at::Tensor q, at::Tensor k, at::Tensor v, + at::Tensor plan_info_vec, at::Tensor q, at::Tensor k, at::Tensor v, at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor o, std::optional maybe_lse, - unsigned int mask_mode_code, unsigned int layout, int32_t window_left ADDITIONAL_FUNC_PARAMS, + int64_t mask_mode_code, int64_t layout, int64_t window_left ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream) { PrefillPlanInfo plan_info; - plan_info.FromVector(plan_info_vec); + plan_info.FromVector(tensor_to_vec(plan_info_vec)); QKVLayout kv_layout = static_cast(layout); int64_t num_qo_heads = q.size(1); @@ -194,13 +195,13 @@ void BatchPrefillWithRaggedKVCacheRun( void BatchPrefillWithPagedKVCacheRun( at::Tensor float_workspace_buffer, at::Tensor 
int_workspace_buffer, - std::vector plan_info_vec, at::Tensor q, at::Tensor paged_k_cache, + at::Tensor plan_info_vec, at::Tensor q, at::Tensor paged_k_cache, at::Tensor paged_v_cache, at::Tensor qo_indptr, at::Tensor paged_kv_indptr, at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len, at::Tensor o, - std::optional maybe_lse, unsigned int mask_mode_code, unsigned int layout, - int32_t window_left ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream) { + std::optional maybe_lse, int64_t mask_mode_code, int64_t layout, + int64_t window_left ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream) { PrefillPlanInfo plan_info; - plan_info.FromVector(plan_info_vec); + plan_info.FromVector(tensor_to_vec(plan_info_vec)); QKVLayout kv_layout = static_cast(layout); auto device = q.device(); int64_t batch_size = paged_kv_indptr.size(0) - 1; diff --git a/csrc/batch_prefill_jit_pybind.cu b/csrc/batch_prefill_jit_pybind.cu index 6d35deef4..17dee4aae 100644 --- a/csrc/batch_prefill_jit_pybind.cu +++ b/csrc/batch_prefill_jit_pybind.cu @@ -16,33 +16,34 @@ #include "batch_prefill_config.inc" #include "pytorch_extension_utils.h" -std::vector BatchPrefillWithKVCachePlan( +at::Tensor BatchPrefillWithKVCachePlan( at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr, at::Tensor kv_indptr, - at::Tensor kv_len_arr, unsigned total_num_rows, unsigned int batch_size, - unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int page_size, - bool enable_cuda_graph, unsigned int head_dim_qk, unsigned int head_dim_vo, bool causal, + at::Tensor kv_len_arr, int64_t total_num_rows, int64_t batch_size, + int64_t num_qo_heads, int64_t num_kv_heads, int64_t page_size, + bool enable_cuda_graph, int64_t head_dim_qk, int64_t head_dim_vo, bool causal, int64_t cuda_stream); void BatchPrefillWithRaggedKVCacheRun( at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, - std::vector plan_info_vec, at::Tensor q, at::Tensor k, at::Tensor v, + at::Tensor plan_info_vec, at::Tensor q, at::Tensor k, at::Tensor v, at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor o, std::optional maybe_lse, - unsigned int mask_mode_code, unsigned int layout, int32_t window_left ADDITIONAL_FUNC_PARAMS, + int64_t mask_mode_code, int64_t layout, int64_t window_left ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream); void BatchPrefillWithPagedKVCacheRun( at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, - std::vector plan_info_vec, at::Tensor q, at::Tensor paged_k_cache, + at::Tensor plan_info_vec, at::Tensor q, at::Tensor paged_k_cache, at::Tensor paged_v_cache, at::Tensor qo_indptr, at::Tensor paged_kv_indptr, at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len, at::Tensor o, - std::optional maybe_lse, unsigned int mask_mode_code, unsigned int layout, - int32_t window_left ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream); + std::optional maybe_lse, int64_t mask_mode_code, int64_t layout, + int64_t window_left ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream); -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("plan", &BatchPrefillWithKVCachePlan, "Batch-request prefill attention with KV-Cache plan"); - m.def("ragged_run", &BatchPrefillWithRaggedKVCacheRun, - "Batch-request prefill attention with KV-Cache operator"); - m.def("paged_run", &BatchPrefillWithPagedKVCacheRun, - "Batch-request prefill attention with KV-Cache operator"); +TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) { + // Batch-request prefill attention with KV-Cache plan + 
m.def("plan", BatchPrefillWithKVCachePlan); + // Batch-request prefill attention with KV-Cache operator + m.def("ragged_run", BatchPrefillWithRaggedKVCacheRun); + // Batch-request prefill attention with KV-Cache operator + m.def("paged_run", BatchPrefillWithPagedKVCacheRun); } diff --git a/csrc/batch_prefill_sm90.cu b/csrc/batch_prefill_sm90.cu index 6ee020a7b..694525eb6 100644 --- a/csrc/batch_prefill_sm90.cu +++ b/csrc/batch_prefill_sm90.cu @@ -20,8 +20,10 @@ #include #include -#include "batch_prefill_sm90_config.inc" #include "pytorch_extension_utils.h" +#include "pytorch_conversion_utils.h" + +#include "batch_prefill_sm90_config.inc" namespace flashinfer { @@ -37,19 +39,19 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched(Params& params, cudaStream_t using namespace flashinfer; -std::vector BatchPrefillWithKVCacheSM90Plan( +at::Tensor BatchPrefillWithKVCacheSM90Plan( at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr, at::Tensor kv_indptr, - at::Tensor kv_len_arr, unsigned total_num_rows, unsigned int batch_size, - unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int page_size, - bool enable_cuda_graph, unsigned int head_dim_qk, unsigned int head_dim_vo, bool causal, + at::Tensor kv_len_arr, int64_t total_num_rows, int64_t batch_size, + int64_t num_qo_heads, int64_t num_kv_heads, int64_t page_size, + bool enable_cuda_graph, int64_t head_dim_qk, int64_t head_dim_vo, bool causal, int64_t cuda_stream) { size_t float_workspace_size_in_bytes = float_workspace_buffer.size(0) * float_workspace_buffer.element_size(); size_t int_workspace_size_in_bytes = int_workspace_buffer.size(0) * int_workspace_buffer.element_size(); - PrefillPlanSM90Info plan_info; + flashinfer::PrefillPlanSM90Info plan_info; cudaStream_t stream = reinterpret_cast(cuda_stream); @@ -64,17 +66,17 @@ std::vector BatchPrefillWithKVCacheSM90Plan( TORCH_CHECK(status == cudaSuccess, "PrefillSM90Plan failed with error: ", cudaGetErrorString(status)); - return plan_info.ToVector(); + return vec_to_tensor(plan_info.ToVector()); } void BatchPrefillWithRaggedKVCacheSM90Run( at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, - std::vector plan_info_vec, at::Tensor q, at::Tensor k, at::Tensor v, + at::Tensor plan_info_vec, at::Tensor q, at::Tensor k, at::Tensor v, at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor o, std::optional maybe_lse, - unsigned int mask_mode_code, unsigned int layout, int32_t window_left ADDITIONAL_FUNC_PARAMS, + int64_t mask_mode_code, int64_t layout, int64_t window_left ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream) { PrefillPlanSM90Info plan_info; - plan_info.FromVector(plan_info_vec); + plan_info.FromVector(tensor_to_vec(plan_info_vec)); if (maybe_lse) { const auto& lse = *maybe_lse; @@ -85,8 +87,8 @@ void BatchPrefillWithRaggedKVCacheSM90Run( void* float_buffer_ptr = float_workspace_buffer.data_ptr(); void* int_buffer_ptr = int_workspace_buffer.data_ptr(); - unsigned int head_dim_qk = q.size(2); - unsigned int head_dim_vo = v.size(2); + int64_t head_dim_qk = q.size(2); + int64_t head_dim_vo = v.size(2); auto q_scalar_type = q.scalar_type(); auto kv_scalar_type = k.scalar_type(); @@ -156,13 +158,13 @@ void BatchPrefillWithRaggedKVCacheSM90Run( void BatchPrefillWithPagedKVCacheSM90Run( at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, - std::vector plan_info_vec, at::Tensor q, at::Tensor paged_k_cache, + at::Tensor plan_info_vec, at::Tensor q, at::Tensor paged_k_cache, 
at::Tensor paged_v_cache, at::Tensor qo_indptr, at::Tensor paged_kv_indptr, at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len, at::Tensor o, - std::optional maybe_lse, unsigned int mask_mode_code, unsigned int layout, - int32_t window_left ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream) { + std::optional maybe_lse, int64_t mask_mode_code, int64_t layout, + int64_t window_left ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream) { PrefillPlanSM90Info plan_info; - plan_info.FromVector(plan_info_vec); + plan_info.FromVector(tensor_to_vec(plan_info_vec)); if (maybe_lse) { const auto& lse = *maybe_lse; @@ -170,9 +172,9 @@ void BatchPrefillWithPagedKVCacheSM90Run( TORCH_CHECK(lse.size(1) == q.size(1), lse.size(1), q.size(1)); } QKVLayout kv_layout = static_cast(layout); - unsigned int num_kv_heads, page_size; - unsigned int head_dim_qk = q.size(2); - unsigned int head_dim_vo = paged_v_cache.size(3); + int64_t num_kv_heads, page_size; + int64_t head_dim_qk = q.size(2); + int64_t head_dim_vo = paged_v_cache.size(3); if (kv_layout == QKVLayout::kHND) { num_kv_heads = paged_k_cache.size(1); page_size = paged_k_cache.size(2); diff --git a/csrc/batch_prefill_sm90_jit_pybind.cu b/csrc/batch_prefill_sm90_jit_pybind.cu index 5466a6ac8..1af16907b 100644 --- a/csrc/batch_prefill_sm90_jit_pybind.cu +++ b/csrc/batch_prefill_sm90_jit_pybind.cu @@ -16,34 +16,34 @@ #include "batch_prefill_sm90_config.inc" #include "pytorch_extension_utils.h" -std::vector BatchPrefillWithKVCacheSM90Plan( +at::Tensor BatchPrefillWithKVCacheSM90Plan( at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr, at::Tensor kv_indptr, - at::Tensor kv_len_arr, unsigned total_num_rows, unsigned int batch_size, - unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int page_size, - bool enable_cuda_graph, unsigned int head_dim_qk, unsigned int head_dim_vo, bool causal, + at::Tensor kv_len_arr, int64_t total_num_rows, int64_t batch_size, + int64_t num_qo_heads, int64_t num_kv_heads, int64_t page_size, + bool enable_cuda_graph, int64_t head_dim_qk, int64_t head_dim_vo, bool causal, int64_t cuda_stream); void BatchPrefillWithRaggedKVCacheSM90Run( at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, - std::vector plan_info_vec, at::Tensor q, at::Tensor k, at::Tensor v, + at::Tensor plan_info_vec, at::Tensor q, at::Tensor k, at::Tensor v, at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor o, std::optional maybe_lse, - unsigned int mask_mode_code, unsigned int layout, int32_t window_left ADDITIONAL_FUNC_PARAMS, + int64_t mask_mode_code, int64_t layout, int64_t window_left ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream); void BatchPrefillWithPagedKVCacheSM90Run( at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, - std::vector plan_info_vec, at::Tensor q, at::Tensor paged_k_cache, + at::Tensor plan_info_vec, at::Tensor q, at::Tensor paged_k_cache, at::Tensor paged_v_cache, at::Tensor qo_indptr, at::Tensor paged_kv_indptr, at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len, at::Tensor o, - std::optional maybe_lse, unsigned int mask_mode_code, unsigned int layout, - int32_t window_left ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream); + std::optional maybe_lse, int64_t mask_mode_code, int64_t layout, + int64_t window_left ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream); -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("plan", &BatchPrefillWithKVCacheSM90Plan, - "Batch-request prefill attention with KV-Cache plan"); - 
m.def("ragged_run", &BatchPrefillWithRaggedKVCacheSM90Run, - "Batch-request prefill attention with KV-Cache operator"); - m.def("paged_run", &BatchPrefillWithPagedKVCacheSM90Run, - "Batch-request prefill attention with KV-Cache operator"); +TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) { + // Batch-request prefill attention with KV-Cache plan + m.def("plan", BatchPrefillWithKVCacheSM90Plan); + // Batch-request prefill attention with KV-Cache operator + m.def("ragged_run", BatchPrefillWithRaggedKVCacheSM90Run); + // Batch-request prefill attention with KV-Cache operator + m.def("paged_run", BatchPrefillWithPagedKVCacheSM90Run); } diff --git a/csrc/flashinfer_cascade_ops.cu b/csrc/flashinfer_cascade_ops.cu index 4527022d7..90dcd3011 100644 --- a/csrc/flashinfer_cascade_ops.cu +++ b/csrc/flashinfer_cascade_ops.cu @@ -24,9 +24,11 @@ void merge_state_in_place(at::Tensor v, at::Tensor s, at::Tensor v_other, at::Te void merge_states(at::Tensor v, at::Tensor s, at::Tensor v_merged, at::Tensor s_merged, int64_t cuda_stream); -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("merge_state", &merge_state, "Merge two self-attention states"); - m.def("merge_state_in_place", &merge_state_in_place, - "Merge another self-attention state in-place."); - m.def("merge_states", &merge_states, "Merge multiple self-attention states"); +TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) { + // Merge two self-attention states + m.def("merge_state", merge_state); + // Merge another self-attention state in-place. + m.def("merge_state_in_place", merge_state_in_place); + // "Merge multiple self-attention states" + m.def("merge_states", merge_states); } diff --git a/csrc/flashinfer_gemm_ops.cu b/csrc/flashinfer_gemm_ops.cu index b13a2bcd6..6e6094c7b 100644 --- a/csrc/flashinfer_gemm_ops.cu +++ b/csrc/flashinfer_gemm_ops.cu @@ -23,7 +23,9 @@ void CutlassSegmentGEMM(at::Tensor workspace_buffer, at::Tensor all_problems, at at::Tensor y_ld, at::Tensor empty_x_data, bool weight_column_major, int64_t cuda_stream); -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("cutlass_segment_gemm", &CutlassSegmentGEMM, "Cutlass Segment GEMM"); - m.def("bmm_fp8", &bmm_fp8, "BMM FP8"); +TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) { + // "Cutlass Segment GEMM" + m.def("cutlass_segment_gemm", CutlassSegmentGEMM); + // "BMM FP8" + m.def("bmm_fp8", bmm_fp8); } diff --git a/csrc/flashinfer_gemm_sm90_ops.cu b/csrc/flashinfer_gemm_sm90_ops.cu index 55b075ab5..cbdf96310 100644 --- a/csrc/flashinfer_gemm_sm90_ops.cu +++ b/csrc/flashinfer_gemm_sm90_ops.cu @@ -19,9 +19,9 @@ void CutlassSegmentGEMMSM90(at::Tensor float_workspace_buffer, at::Tensor int_wo at::Tensor all_problems, at::Tensor x_ptr, at::Tensor w_ptr, at::Tensor y_ptr, at::Tensor x_stride, at::Tensor weight_stride, at::Tensor y_stride, at::Tensor empty_x_data, bool weight_column_major, - std::vector plan_info_vec, int64_t cuda_stream); + at::Tensor plan_info_vec, int64_t cuda_stream); -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("cutlass_segment_gemm_sm90", &CutlassSegmentGEMMSM90, - "Cutlass Segment GEMM operator for SM90"); +TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) { + // "Cutlass Segment GEMM operator for SM90" + m.def("cutlass_segment_gemm_sm90", CutlassSegmentGEMMSM90); } diff --git a/csrc/flashinfer_norm_ops.cu b/csrc/flashinfer_norm_ops.cu index 52a103508..4c6d40ba8 100644 --- a/csrc/flashinfer_norm_ops.cu +++ b/csrc/flashinfer_norm_ops.cu @@ -27,10 +27,13 @@ void gemma_rmsnorm(at::Tensor& out, at::Tensor& input, at::Tensor& weight, doubl void 
gemma_fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, int64_t cuda_stream); -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("rmsnorm", &rmsnorm, "Root mean square normalization"); - m.def("fused_add_rmsnorm", &fused_add_rmsnorm, "Fused add root mean square normalization"); - m.def("gemma_rmsnorm", &gemma_rmsnorm, "Gemma Root mean square normalization"); - m.def("gemma_fused_add_rmsnorm", &gemma_fused_add_rmsnorm, - "Gemma Fused add root mean square normalization"); +TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) { + // Root mean square normalization + m.def("rmsnorm", rmsnorm); + // Fused add root mean square normalization + m.def("fused_add_rmsnorm", fused_add_rmsnorm); + // Gemma Root mean square normalization + m.def("gemma_rmsnorm", gemma_rmsnorm); + // Gemma Fused add root mean square normalization + m.def("gemma_fused_add_rmsnorm", gemma_fused_add_rmsnorm); } diff --git a/csrc/flashinfer_ops.cu b/csrc/flashinfer_ops.cu index 2dc620603..b6c2f7cc8 100644 --- a/csrc/flashinfer_ops.cu +++ b/csrc/flashinfer_ops.cu @@ -36,23 +36,23 @@ void merge_states(at::Tensor v, at::Tensor s, at::Tensor v_merged, at::Tensor s_ //========== decode ========== void single_decode_with_kv_cache(at::Tensor q, at::Tensor k, at::Tensor v, at::Tensor tmp, - at::Tensor o, unsigned int layout, - int window_left SINGLE_DECODE_ADDITIONAL_FUNC_PARAMS, + at::Tensor o, int64_t layout, + int64_t window_left SINGLE_DECODE_ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream); -std::vector BatchDecodeWithPagedKVCachePlan( +at::Tensor BatchDecodeWithPagedKVCachePlan( at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, - at::Tensor page_locked_int_workspace_buffer, at::Tensor indptr, unsigned int batch_size, - unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int page_size, - bool enable_cuda_graph, int window_left, float logits_soft_cap, unsigned int head_dim_qk, - unsigned head_dim_vo, at::Tensor empty_q_data, at::Tensor empty_kv_data, int64_t cuda_stream); + at::Tensor page_locked_int_workspace_buffer, at::Tensor indptr, int64_t batch_size, + int64_t num_qo_heads, int64_t num_kv_heads, int64_t page_size, + bool enable_cuda_graph, int64_t window_left, double logits_soft_cap, int64_t head_dim_qk, + int64_t head_dim_vo, at::Tensor empty_q_data, at::Tensor empty_kv_data, int64_t cuda_stream); void BatchDecodeWithPagedKVCacheRun( at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, - std::vector plan_info_vec, at::Tensor q, at::Tensor paged_k_cache, + at::Tensor plan_info_vec, at::Tensor q, at::Tensor paged_k_cache, at::Tensor paged_v_cache, at::Tensor paged_kv_indptr, at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len, at::Tensor o, std::optional maybe_lse, - unsigned int kv_layout_code, int window_left BATCH_DECODE_ADDITIONAL_FUNC_PARAMS, + int64_t kv_layout_code, int64_t window_left BATCH_DECODE_ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream); //========== gemm ========== @@ -84,46 +84,46 @@ void gemma_fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor void append_paged_kv_cache(at::Tensor append_key, at::Tensor append_value, at::Tensor batch_indices, at::Tensor positions, at::Tensor paged_k_cache, at::Tensor paged_v_cache, at::Tensor kv_indices, at::Tensor kv_indptr, at::Tensor kv_last_page_len, - unsigned int layout, int64_t cuda_stream); + int64_t layout, int64_t cuda_stream); void block_sparse_indices_to_vector_sparse_offsets(at::Tensor block_sparse_indices, at::Tensor block_sparse_indptr, at::Tensor 
vector_sparse_offsets, at::Tensor vector_sparse_indptr, - at::Tensor kv_len_arr, unsigned int stride_block, - unsigned int stride_n, unsigned int batch_size, - unsigned int block_size, int64_t cuda_stream); + at::Tensor kv_len_arr, int64_t stride_block, + int64_t stride_n, int64_t batch_size, + int64_t block_size, int64_t cuda_stream); //========== prefill ========== void single_prefill_with_kv_cache(at::Tensor q, at::Tensor k, at::Tensor v, at::Tensor tmp, at::Tensor o, std::optional maybe_lse, - unsigned int mask_mode_code, unsigned int layout, - int32_t window_left SINGLE_PREFILL_ADDITIONAL_FUNC_PARAMS, + int64_t mask_mode_code, int64_t layout, + int64_t window_left SINGLE_PREFILL_ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream); -std::vector BatchPrefillWithKVCachePlan( +at::Tensor BatchPrefillWithKVCachePlan( at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr, at::Tensor kv_indptr, - at::Tensor kv_len_arr, unsigned total_num_rows, unsigned int batch_size, - unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int page_size, - bool enable_cuda_graph, unsigned int head_dim_qk, unsigned int head_dim_vo, bool causal, + at::Tensor kv_len_arr, int64_t total_num_rows, int64_t batch_size, + int64_t num_qo_heads, int64_t num_kv_heads, int64_t page_size, + bool enable_cuda_graph, int64_t head_dim_qk, int64_t head_dim_vo, bool causal, int64_t cuda_stream); void BatchPrefillWithRaggedKVCacheRun( at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, - std::vector plan_info_vec, at::Tensor q, at::Tensor k, at::Tensor v, + at::Tensor plan_info_vec, at::Tensor q, at::Tensor k, at::Tensor v, at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor o, std::optional maybe_lse, - unsigned int mask_mode_code, unsigned int layout, - int32_t window_left BATCH_PREFILL_ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream); + int64_t mask_mode_code, int64_t layout, + int64_t window_left BATCH_PREFILL_ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream); void BatchPrefillWithPagedKVCacheRun( at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, - std::vector plan_info_vec, at::Tensor q, at::Tensor paged_k_cache, + at::Tensor plan_info_vec, at::Tensor q, at::Tensor paged_k_cache, at::Tensor paged_v_cache, at::Tensor qo_indptr, at::Tensor paged_kv_indptr, at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len, at::Tensor o, - std::optional maybe_lse, unsigned int mask_mode_code, unsigned int layout, - int32_t window_left BATCH_PREFILL_ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream); + std::optional maybe_lse, int64_t mask_mode_code, int64_t layout, + int64_t window_left BATCH_PREFILL_ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream); //========== quantization ========== @@ -135,22 +135,22 @@ void segment_packbits(at::Tensor x, at::Tensor input_indptr, at::Tensor output_i //========== rope ========== void apply_rope(at::Tensor q, at::Tensor k, at::Tensor q_rope, at::Tensor k_rope, at::Tensor indptr, - at::Tensor offsets, unsigned int rotary_dim, bool interleave, float rope_scale, - float rope_theta, int64_t cuda_stream); + at::Tensor offsets, int64_t rotary_dim, bool interleave, double rope_scale, + double rope_theta, int64_t cuda_stream); void apply_llama31_rope(at::Tensor q, at::Tensor k, at::Tensor q_rope, at::Tensor k_rope, - at::Tensor indptr, at::Tensor offsets, unsigned int rotary_dim, - bool interleave, float rope_scale, float rope_theta, float low_freq_factor, - float high_freq_factor, float 
old_context_length, int64_t cuda_stream); + at::Tensor indptr, at::Tensor offsets, int64_t rotary_dim, + bool interleave, double rope_scale, double rope_theta, double low_freq_factor, + double high_freq_factor, double old_context_length, int64_t cuda_stream); void apply_rope_pos_ids(at::Tensor q, at::Tensor k, at::Tensor q_rope, at::Tensor k_rope, - at::Tensor pos_ids, unsigned int rotary_dim, bool interleave, - float rope_scale, float rope_theta, int64_t cuda_stream); + at::Tensor pos_ids, int64_t rotary_dim, bool interleave, + double rope_scale, double rope_theta, int64_t cuda_stream); void apply_llama31_rope_pos_ids(at::Tensor q, at::Tensor k, at::Tensor q_rope, at::Tensor k_rope, - at::Tensor pos_ids, unsigned int rotary_dim, bool interleave, - float rope_scale, float rope_theta, float low_freq_factor, - float high_freq_factor, float old_context_length, + at::Tensor pos_ids, int64_t rotary_dim, bool interleave, + double rope_scale, double rope_theta, double low_freq_factor, + double high_freq_factor, double old_context_length, int64_t cuda_stream); void apply_rope_pos_ids_cos_sin_cache(at::Tensor q, at::Tensor k, at::Tensor q_rope, @@ -168,7 +168,7 @@ void top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at: void top_k_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples, at::Tensor success, std::optional maybe_top_k_arr, - unsigned int top_k_val, bool deterministic, int64_t cuda_stream); + int64_t top_k_val, bool deterministic, int64_t cuda_stream); void min_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples, std::optional maybe_min_p_arr, double min_p_val, @@ -185,11 +185,11 @@ void top_p_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, int64_t cuda_stream); void top_k_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, - std::optional maybe_top_k_arr, unsigned int top_k_val, + std::optional maybe_top_k_arr, int64_t top_k_val, int64_t cuda_stream); void top_k_mask_logits(at::Tensor logits, at::Tensor mask_logits, - std::optional maybe_top_k_arr, unsigned int top_k_val, + std::optional maybe_top_k_arr, int64_t top_k_val, int64_t cuda_stream); void chain_speculative_sampling(at::Tensor draft_probs, at::Tensor draft_token_ids, @@ -198,75 +198,96 @@ void chain_speculative_sampling(at::Tensor draft_probs, at::Tensor draft_token_i at::Tensor output_emitted_token_num, bool deterministic, int64_t cuda_stream); -//========== pybind11 ========== +//========== Torch Library ========== -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { +TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) { // activation - m.def("silu_and_mul", &silu_and_mul, "Fused SiLU and Mul"); - m.def("gelu_tanh_and_mul", &gelu_tanh_and_mul, "Fused GeLU Tanh and Mul"); - m.def("gelu_and_mul", &gelu_and_mul, "Fused GeLU and Mul"); + // Fused SiLU and Mul + m.def("silu_and_mul", silu_and_mul); + // Fused GeLU Tanh and Mul + m.def("gelu_tanh_and_mul", gelu_tanh_and_mul); + // Fused GeLU and Mul + m.def("gelu_and_mul", gelu_and_mul); // cascade - m.def("merge_state", &merge_state, "Merge two self-attention states"); - m.def("merge_state_in_place", &merge_state_in_place, - "Merge another self-attention state in-place."); - m.def("merge_states", &merge_states, "Merge multiple self-attention states"); + // Merge two self-attention states + m.def("merge_state", merge_state); + // Merge another self-attention state in-place. 
+ m.def("merge_state_in_place", merge_state_in_place); + // "Merge multiple self-attention states" + m.def("merge_states", merge_states); // decode - m.def("single_decode_with_kv_cache", &single_decode_with_kv_cache, - "Single-request decode with KV-Cache operator"); - m.def("batch_decode_with_paged_kv_cache_plan", &BatchDecodeWithPagedKVCachePlan); - m.def("batch_decode_with_paged_kv_cache_run", &BatchDecodeWithPagedKVCacheRun); + // "Single-request decode with KV-Cache operator" + m.def("single_decode_with_kv_cache", single_decode_with_kv_cache); + m.def("batch_decode_with_paged_kv_cache_plan", BatchDecodeWithPagedKVCachePlan); + m.def("batch_decode_with_paged_kv_cache_run", BatchDecodeWithPagedKVCacheRun); // gemm - m.def("bmm_fp8", &bmm_fp8, "BMM FP8"); - m.def("cutlass_segment_gemm", &CutlassSegmentGEMM, "Cutlass Segment GEMM operator"); + // BMM FP8 + m.def("bmm_fp8", bmm_fp8); + // Cutlass Segment GEMM operator + m.def("cutlass_segment_gemm", CutlassSegmentGEMM); // norm - m.def("rmsnorm", &rmsnorm, "Root mean square normalization"); - m.def("fused_add_rmsnorm", &fused_add_rmsnorm, "Fused add root mean square normalization"); - m.def("gemma_rmsnorm", &gemma_rmsnorm, "Gemma Root mean square normalization"); - m.def("gemma_fused_add_rmsnorm", &gemma_fused_add_rmsnorm, - "Gemma Fused add root mean square normalization"); + // Root mean square normalization + m.def("rmsnorm", rmsnorm); + // Fused add root mean square normalization + m.def("fused_add_rmsnorm", fused_add_rmsnorm); + // Gemma Root mean square normalization + m.def("gemma_rmsnorm", gemma_rmsnorm); + // Gemma Fused add root mean square normalization + m.def("gemma_fused_add_rmsnorm", gemma_fused_add_rmsnorm); // page - m.def("append_paged_kv_cache", &append_paged_kv_cache, "Append paged KV-Cache operator"); + // Append paged KV-Cache operator + m.def("append_paged_kv_cache", append_paged_kv_cache); + // Precompute block sparse offsets m.def("block_sparse_indices_to_vector_sparse_offsets", - &block_sparse_indices_to_vector_sparse_offsets, "Precompute block sparse offsets"); + block_sparse_indices_to_vector_sparse_offsets); // prefill - m.def("single_prefill_with_kv_cache", &single_prefill_with_kv_cache, - "Single-request prefill attention with KV-Cache operator"); - m.def("batch_prefill_with_kv_cache_plan", &BatchPrefillWithKVCachePlan); - m.def("batch_prefill_with_ragged_kv_cache_run", &BatchPrefillWithRaggedKVCacheRun); - m.def("batch_prefill_with_paged_kv_cache_run", &BatchPrefillWithPagedKVCacheRun); + // Single-request prefill attention with KV-Cache operator + m.def("single_prefill_with_kv_cache", single_prefill_with_kv_cache); + m.def("batch_prefill_with_kv_cache_plan", BatchPrefillWithKVCachePlan); + m.def("batch_prefill_with_ragged_kv_cache_run", BatchPrefillWithRaggedKVCacheRun); + m.def("batch_prefill_with_paged_kv_cache_run", BatchPrefillWithPagedKVCacheRun); // quantization - m.def("packbits", &packbits, "GPU packbits operator"); - m.def("segment_packbits", &segment_packbits, "GPU segment packbits operator"); + // GPU packbits operator + m.def("packbits", packbits); + // GPU segment packbits operator + m.def("segment_packbits", segment_packbits); // rope - m.def("apply_rope", &apply_rope, "Apply RoPE"); - m.def("apply_llama31_rope", &apply_llama31_rope, "Apply Llama 3.1 style RoPE"); - m.def("apply_rope_pos_ids", &apply_rope_pos_ids, "Apply RoPE with positional ids"); - m.def("apply_llama31_rope_pos_ids", &apply_llama31_rope_pos_ids, - "Apply Llama 3.1 style RoPE with positional ids"); - 
m.def("apply_rope_pos_ids_cos_sin_cache", &apply_rope_pos_ids_cos_sin_cache, - "Apply RoPE with positional ids and cosine/sine cache"); + // "Apply RoPE" + m.def("apply_rope", apply_rope); + // "Apply Llama 3.1 style RoPE" + m.def("apply_llama31_rope", apply_llama31_rope); + // "Apply RoPE with positional ids" + m.def("apply_rope_pos_ids", apply_rope_pos_ids); + // "Apply Llama 3.1 style RoPE with positional ids" + m.def("apply_llama31_rope_pos_ids", apply_llama31_rope_pos_ids); + // "Apply RoPE with positional ids and cosine/sine cache" + m.def("apply_rope_pos_ids_cos_sin_cache", apply_rope_pos_ids_cos_sin_cache); // sampling - m.def("sampling_from_probs", &sampling_from_probs, "Sample from probabilities"); - m.def("top_k_sampling_from_probs", &top_k_sampling_from_probs, - "Top-k sampling from probabilities"); - m.def("min_p_sampling_from_probs", &min_p_sampling_from_probs, - "Min-p sampling from probabilities"); - m.def("top_p_sampling_from_probs", &top_p_sampling_from_probs, - "Top-p sampling from probabilities"); - m.def("top_k_top_p_sampling_from_probs", &top_k_top_p_sampling_from_probs, - "Top-k and top-p sampling from probabilities"); - m.def("top_k_renorm_probs", &top_k_renorm_probs, "Renormalize probabilities by top-k mask"); - m.def("top_p_renorm_probs", &top_p_renorm_probs, "Renormalize probabilities by top-p mask"); - m.def("top_k_mask_logits", &top_k_mask_logits, "Mask logits by top-k mask"); - m.def("chain_speculative_sampling", &chain_speculative_sampling, - "Speculative sampling from sequence of probabilities"); + // Sample from probabilities + m.def("sampling_from_probs", sampling_from_probs); + // Top-k sampling from probabilities + m.def("top_k_sampling_from_probs", top_k_sampling_from_probs); + // Min-p sampling from probabilities + m.def("min_p_sampling_from_probs", min_p_sampling_from_probs); + // Top-p sampling from probabilities + m.def("top_p_sampling_from_probs", top_p_sampling_from_probs); + // Top-k and top-p sampling from probabilities + m.def("top_k_top_p_sampling_from_probs", top_k_top_p_sampling_from_probs); + // Renormalize probabilities by top-k mask + m.def("top_k_renorm_probs", top_k_renorm_probs); + // Renormalize probabilities by top-p mask + m.def("top_p_renorm_probs", top_p_renorm_probs); + // Mask logits by top-k mask + m.def("top_k_mask_logits", top_k_mask_logits); + // Speculative sampling from sequence of probabilities + m.def("chain_speculative_sampling", chain_speculative_sampling); } diff --git a/csrc/flashinfer_ops_sm90.cu b/csrc/flashinfer_ops_sm90.cu index 7e2ab7bdd..dbba2f687 100644 --- a/csrc/flashinfer_ops_sm90.cu +++ b/csrc/flashinfer_ops_sm90.cu @@ -24,37 +24,37 @@ void CutlassSegmentGEMMSM90(at::Tensor float_workspace_buffer, at::Tensor int_wo void single_prefill_with_kv_cache_sm90( at::Tensor q, at::Tensor k, at::Tensor v, at::Tensor tmp, at::Tensor o, - std::optional maybe_lse, unsigned int mask_mode_code, unsigned int layout, - int32_t window_left SINGLE_PREFILL_SM90_ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream); + std::optional maybe_lse, int64_t mask_mode_code, int64_t layout, + int64_t window_left SINGLE_PREFILL_SM90_ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream); -std::vector BatchPrefillWithKVCacheSM90Plan( +at::Tensor BatchPrefillWithKVCacheSM90Plan( at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr, at::Tensor kv_indptr, - at::Tensor kv_len_arr, unsigned total_num_rows, unsigned int batch_size, - unsigned int num_qo_heads, unsigned int 
num_kv_heads, unsigned int page_size, - bool enable_cuda_graph, unsigned int head_dim_qk, unsigned int head_dim_vo, bool causal, + at::Tensor kv_len_arr, int64_t total_num_rows, int64_t batch_size, + int64_t num_qo_heads, int64_t num_kv_heads, int64_t page_size, + bool enable_cuda_graph, int64_t head_dim_qk, int64_t head_dim_vo, bool causal, int64_t cuda_stream); void BatchPrefillWithRaggedKVCacheSM90Run( at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, - std::vector plan_info_vec, at::Tensor q, at::Tensor k, at::Tensor v, + at::Tensor plan_info_vec, at::Tensor q, at::Tensor k, at::Tensor v, at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor o, std::optional maybe_lse, - unsigned int mask_mode_code, unsigned int layout, - int32_t window_left BATCH_PREFILL_SM90_ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream); + int64_t mask_mode_code, int64_t layout, + int64_t window_left BATCH_PREFILL_SM90_ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream); void BatchPrefillWithPagedKVCacheSM90Run( at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, - std::vector plan_info_vec, at::Tensor q, at::Tensor paged_k_cache, + at::Tensor plan_info_vec, at::Tensor q, at::Tensor paged_k_cache, at::Tensor paged_v_cache, at::Tensor qo_indptr, at::Tensor paged_kv_indptr, at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len, at::Tensor o, - std::optional maybe_lse, unsigned int mask_mode_code, unsigned int layout, - int32_t window_left BATCH_PREFILL_SM90_ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream); + std::optional maybe_lse, int64_t mask_mode_code, int64_t layout, + int64_t window_left BATCH_PREFILL_SM90_ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream); -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("cutlass_segment_gemm_sm90", &CutlassSegmentGEMMSM90, - "Cutlass Segment GEMM operator for SM90"); - m.def("single_prefill_with_kv_cache_sm90", &single_prefill_with_kv_cache_sm90); - m.def("batch_prefill_with_kv_cache_sm90_plan", &BatchPrefillWithKVCacheSM90Plan); - m.def("batch_prefill_with_ragged_kv_cache_sm90_run", &BatchPrefillWithRaggedKVCacheSM90Run); - m.def("batch_prefill_with_paged_kv_cache_sm90_run", &BatchPrefillWithPagedKVCacheSM90Run); +TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) { + // "Cutlass Segment GEMM operator for SM90" + m.def("cutlass_segment_gemm_sm90", CutlassSegmentGEMMSM90); + m.def("single_prefill_with_kv_cache_sm90", single_prefill_with_kv_cache_sm90); + m.def("batch_prefill_with_kv_cache_sm90_plan", BatchPrefillWithKVCacheSM90Plan); + m.def("batch_prefill_with_ragged_kv_cache_sm90_run", BatchPrefillWithRaggedKVCacheSM90Run); + m.def("batch_prefill_with_paged_kv_cache_sm90_run", BatchPrefillWithPagedKVCacheSM90Run); } diff --git a/csrc/flashinfer_page_ops.cu b/csrc/flashinfer_page_ops.cu index e365eb629..349100dbd 100644 --- a/csrc/flashinfer_page_ops.cu +++ b/csrc/flashinfer_page_ops.cu @@ -18,18 +18,20 @@ void append_paged_kv_cache(at::Tensor append_key, at::Tensor append_value, at::Tensor batch_indices, at::Tensor positions, at::Tensor paged_k_cache, at::Tensor paged_v_cache, at::Tensor kv_indices, at::Tensor kv_indptr, at::Tensor kv_last_page_len, - unsigned int layout, int64_t cuda_stream); + int64_t layout, int64_t cuda_stream); void block_sparse_indices_to_vector_sparse_offsets(at::Tensor block_sparse_indices, at::Tensor block_sparse_indptr, at::Tensor vector_sparse_offsets, at::Tensor vector_sparse_indptr, - at::Tensor kv_len_arr, unsigned int stride_block, - unsigned int stride_n, unsigned int batch_size, - unsigned int block_size, int64_t 
cuda_stream); + at::Tensor kv_len_arr, int64_t stride_block, + int64_t stride_n, int64_t batch_size, + int64_t block_size, int64_t cuda_stream); -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("append_paged_kv_cache", &append_paged_kv_cache, "Append paged KV-Cache operator"); +TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) { + // "Append paged KV-Cache operator" + m.def("append_paged_kv_cache", append_paged_kv_cache); + // "Precompute block sparse offsets" m.def("block_sparse_indices_to_vector_sparse_offsets", - &block_sparse_indices_to_vector_sparse_offsets, "Precompute block sparse offsets"); + block_sparse_indices_to_vector_sparse_offsets); } diff --git a/csrc/flashinfer_quantization_ops.cu b/csrc/flashinfer_quantization_ops.cu index d23867bfb..5ff8052d3 100644 --- a/csrc/flashinfer_quantization_ops.cu +++ b/csrc/flashinfer_quantization_ops.cu @@ -20,7 +20,9 @@ void packbits(at::Tensor x, const std::string& bitorder, at::Tensor y, int64_t c void segment_packbits(at::Tensor x, at::Tensor input_indptr, at::Tensor output_indptr, const std::string& bitorder, at::Tensor y, int64_t cuda_stream); -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("packbits", &packbits, "GPU packbits operator"); - m.def("segment_packbits", &segment_packbits, "GPU segment packbits operator"); +TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) { + // GPU packbits operator + m.def("packbits", packbits); + // GPU segment packbits operator + m.def("segment_packbits", segment_packbits); } diff --git a/csrc/flashinfer_rope_ops.cu b/csrc/flashinfer_rope_ops.cu index f447aba71..60205d40b 100644 --- a/csrc/flashinfer_rope_ops.cu +++ b/csrc/flashinfer_rope_ops.cu @@ -18,34 +18,37 @@ #include "pytorch_extension_utils.h" void apply_rope(at::Tensor q, at::Tensor k, at::Tensor q_rope, at::Tensor k_rope, at::Tensor indptr, - at::Tensor offsets, unsigned int rotary_dim, bool interleave, float rope_scale, - float rope_theta, int64_t cuda_stream); + at::Tensor offsets, int64_t rotary_dim, bool interleave, double rope_scale, + double rope_theta, int64_t cuda_stream); void apply_llama31_rope(at::Tensor q, at::Tensor k, at::Tensor q_rope, at::Tensor k_rope, - at::Tensor indptr, at::Tensor offsets, unsigned int rotary_dim, - bool interleave, float rope_scale, float rope_theta, float low_freq_factor, - float high_freq_factor, float old_context_length, int64_t cuda_stream); + at::Tensor indptr, at::Tensor offsets, int64_t rotary_dim, + bool interleave, double rope_scale, double rope_theta, double low_freq_factor, + double high_freq_factor, double old_context_length, int64_t cuda_stream); void apply_rope_pos_ids(at::Tensor q, at::Tensor k, at::Tensor q_rope, at::Tensor k_rope, - at::Tensor pos_ids, unsigned int rotary_dim, bool interleave, - float rope_scale, float rope_theta, int64_t cuda_stream); + at::Tensor pos_ids, int64_t rotary_dim, bool interleave, + double rope_scale, double rope_theta, int64_t cuda_stream); void apply_llama31_rope_pos_ids(at::Tensor q, at::Tensor k, at::Tensor q_rope, at::Tensor k_rope, - at::Tensor pos_ids, unsigned int rotary_dim, bool interleave, - float rope_scale, float rope_theta, float low_freq_factor, - float high_freq_factor, float old_context_length, + at::Tensor pos_ids, int64_t rotary_dim, bool interleave, + double rope_scale, double rope_theta, double low_freq_factor, + double high_freq_factor, double old_context_length, int64_t cuda_stream); void apply_rope_pos_ids_cos_sin_cache(at::Tensor q, at::Tensor k, at::Tensor q_rope, at::Tensor k_rope, at::Tensor cos_sin_cache, at::Tensor pos_ids, 
bool interleave, int64_t cuda_stream); -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("apply_rope", &apply_rope, "Apply RoPE"); - m.def("apply_llama31_rope", &apply_llama31_rope, "Apply Llama 3.1 style RoPE"); - m.def("apply_rope_pos_ids", &apply_rope_pos_ids, "Apply RoPE with positional ids"); - m.def("apply_llama31_rope_pos_ids", &apply_llama31_rope_pos_ids, - "Apply Llama 3.1 style RoPE with positional ids"); - m.def("apply_rope_pos_ids_cos_sin_cache", &apply_rope_pos_ids_cos_sin_cache, - "Apply RoPE with positional ids and cosine/sine cache"); +TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) { + // "Apply RoPE" + m.def("apply_rope", apply_rope); + // "Apply Llama 3.1 style RoPE" + m.def("apply_llama31_rope", apply_llama31_rope); + // "Apply RoPE with positional ids" + m.def("apply_rope_pos_ids", apply_rope_pos_ids); + // "Apply Llama 3.1 style RoPE with positional ids" + m.def("apply_llama31_rope_pos_ids", apply_llama31_rope_pos_ids); + // "Apply RoPE with positional ids and cosine/sine cache" + m.def("apply_rope_pos_ids_cos_sin_cache", apply_rope_pos_ids_cos_sin_cache); } diff --git a/csrc/flashinfer_sampling_ops.cu b/csrc/flashinfer_sampling_ops.cu index 313925e07..aaaac1c57 100644 --- a/csrc/flashinfer_sampling_ops.cu +++ b/csrc/flashinfer_sampling_ops.cu @@ -24,7 +24,7 @@ void top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at: void top_k_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples, at::Tensor success, std::optional maybe_top_k_arr, - unsigned int top_k_val, bool deterministic, int64_t cuda_stream); + int64_t top_k_val, bool deterministic, int64_t cuda_stream); void min_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples, std::optional maybe_min_p_arr, double min_p_val, @@ -41,11 +41,11 @@ void top_p_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, int64_t cuda_stream); void top_k_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, - std::optional maybe_top_k_arr, unsigned int top_k_val, + std::optional maybe_top_k_arr, int64_t top_k_val, int64_t cuda_stream); void top_k_mask_logits(at::Tensor logits, at::Tensor mask_logits, - std::optional maybe_top_k_arr, unsigned int top_k_val, + std::optional maybe_top_k_arr, int64_t top_k_val, int64_t cuda_stream); void chain_speculative_sampling(at::Tensor draft_probs, at::Tensor draft_token_ids, @@ -54,19 +54,23 @@ void chain_speculative_sampling(at::Tensor draft_probs, at::Tensor draft_token_i at::Tensor output_emitted_token_num, bool deterministic, int64_t cuda_stream); -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("sampling_from_probs", &sampling_from_probs, "Sample from probabilities"); - m.def("top_k_sampling_from_probs", &top_k_sampling_from_probs, - "Top-k sampling from probabilities"); - m.def("min_p_sampling_from_probs", &min_p_sampling_from_probs, - "Min-p sampling from probabilities"); - m.def("top_p_sampling_from_probs", &top_p_sampling_from_probs, - "Top-p sampling from probabilities"); - m.def("top_k_top_p_sampling_from_probs", &top_k_top_p_sampling_from_probs, - "Top-k and top-p sampling from probabilities"); - m.def("top_k_renorm_probs", &top_k_renorm_probs, "Renormalize probabilities by top-k mask"); - m.def("top_p_renorm_probs", &top_p_renorm_probs, "Renormalize probabilities by top-p mask"); - m.def("top_k_mask_logits", &top_k_mask_logits, "Mask logits by top-k mask"); - m.def("chain_speculative_sampling", &chain_speculative_sampling, - "Speculative sampling from sequence of probabilities"); 
+TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) { + // Sample from probabilities + m.def("sampling_from_probs", sampling_from_probs); + // Top-k sampling from probabilities + m.def("top_k_sampling_from_probs", top_k_sampling_from_probs); + // Min-p sampling from probabilities + m.def("min_p_sampling_from_probs", min_p_sampling_from_probs); + // Top-p sampling from probabilities + m.def("top_p_sampling_from_probs", top_p_sampling_from_probs); + // Top-k and top-p sampling from probabilities + m.def("top_k_top_p_sampling_from_probs", top_k_top_p_sampling_from_probs); + // Renormalize probabilities by top-k mask + m.def("top_k_renorm_probs", top_k_renorm_probs); + // Renormalize probabilities by top-p mask + m.def("top_p_renorm_probs", top_p_renorm_probs); + // Mask logits by top-k mask + m.def("top_k_mask_logits", top_k_mask_logits); + // Speculative sampling from sequence of probabilities + m.def("chain_speculative_sampling", chain_speculative_sampling); } diff --git a/csrc/page.cu b/csrc/page.cu index db6841944..dbc8d6cc8 100644 --- a/csrc/page.cu +++ b/csrc/page.cu @@ -22,7 +22,7 @@ using namespace flashinfer; void append_paged_kv_cache(at::Tensor append_key, at::Tensor append_value, at::Tensor batch_indices, at::Tensor positions, at::Tensor paged_k_cache, at::Tensor paged_v_cache, at::Tensor kv_indices, at::Tensor kv_indptr, at::Tensor kv_last_page_len, - unsigned int layout, int64_t cuda_stream) { + int64_t layout, int64_t cuda_stream) { CHECK_LAST_DIM_CONTIGUOUS(append_key); CHECK_LAST_DIM_CONTIGUOUS(append_value); CHECK_INPUT(batch_indices); @@ -115,9 +115,9 @@ void block_sparse_indices_to_vector_sparse_offsets(at::Tensor block_sparse_indic at::Tensor block_sparse_indptr, at::Tensor vector_sparse_offsets, at::Tensor vector_sparse_indptr, - at::Tensor kv_len_arr, unsigned int stride_block, - unsigned int stride_n, unsigned int batch_size, - unsigned int block_size, int64_t cuda_stream) { + at::Tensor kv_len_arr, int64_t stride_block, + int64_t stride_n, int64_t batch_size, + int64_t block_size, int64_t cuda_stream) { CHECK_INPUT(block_sparse_indices); CHECK_INPUT(block_sparse_indptr); CHECK_INPUT(vector_sparse_offsets); diff --git a/csrc/pytorch_conversion_utils.h b/csrc/pytorch_conversion_utils.h new file mode 100644 index 000000000..105b188ba --- /dev/null +++ b/csrc/pytorch_conversion_utils.h @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2025 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#pragma once
+#include 
+
+inline at::Tensor vec_to_tensor(const std::vector<int64_t>& vec) {
+  return at::tensor(vec, at::dtype(at::kLong).device(at::kCPU));
+}
+
+inline std::vector<int64_t> tensor_to_vec(const at::Tensor& tensor) {
+  const size_t size = tensor.numel();
+  const int64_t* first = tensor.const_data_ptr<int64_t>();
+  const int64_t* last = first + size;
+  return std::vector<int64_t>(first, last);
+}
diff --git a/csrc/pytorch_extension_utils.h b/csrc/pytorch_extension_utils.h
index ba965a95d..7976dec21 100644
--- a/csrc/pytorch_extension_utils.h
+++ b/csrc/pytorch_extension_utils.h
@@ -14,8 +14,9 @@
  * limitations under the License.
  */
 #pragma once
+#include <Python.h>
 
-#include 
+#include 
 
 #ifdef FLASHINFER_ENABLE_BF16
 #include 
@@ -29,6 +30,38 @@
 #include 
 #endif
 
+#ifndef FLASHINFER_EXT_MODULE_INITED
+#define FLASHINFER_EXT_MODULE_INITED
+
+// To expand macros in #name
+#define FLASHINFER_EXT_MODULE_INIT_EXPAND(name) FLASHINFER_EXT_MODULE_INIT(name)
+
+/* Creates a dummy empty module that can be imported from Python.
+   The import from Python will load the .so of this extension,
+   so that the TORCH_LIBRARY_FRAGMENT static initializers
+   are run. */
+#define FLASHINFER_EXT_MODULE_INIT(name)                                  \
+extern "C" {                                                              \
+  __attribute__((weak)) PyObject *PyInit_##name(void) {                   \
+    static struct PyModuleDef module_def = {                              \
+        PyModuleDef_HEAD_INIT,                                            \
+        #name, /* name of module */                                       \
+        NULL,  /* module documentation, may be NULL */                    \
+        -1,    /* size of per-interpreter state of the module,            \
+                  or -1 if the module keeps state in global variables. */ \
+        NULL,  /* methods */                                              \
+    };                                                                    \
+    return PyModule_Create(&module_def);                                  \
+  }                                                                       \
+}
+
+FLASHINFER_EXT_MODULE_INIT_EXPAND(TORCH_EXTENSION_NAME)
+
+#undef FLASHINFER_EXT_MODULE_INIT
+#undef FLASHINFER_EXT_MODULE_INIT_EXPAND
+
+#endif
+
 #ifdef FLASHINFER_ENABLE_F16
 #define _DISPATCH_CASE_F16(c_type, ...)
\ case at::ScalarType::Half: { \ diff --git a/csrc/renorm.cu b/csrc/renorm.cu index 4a17ce2e1..a79460a28 100644 --- a/csrc/renorm.cu +++ b/csrc/renorm.cu @@ -39,7 +39,7 @@ void top_p_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, } void top_k_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, - std::optional maybe_top_k_arr, unsigned int top_k_val, + std::optional maybe_top_k_arr, int64_t top_k_val, int64_t cuda_stream) { CHECK_INPUT(probs); auto device = probs.device(); @@ -59,7 +59,7 @@ void top_k_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, } void top_k_mask_logits(at::Tensor logits, at::Tensor mask_logits, - std::optional maybe_top_k_arr, unsigned int top_k_val, + std::optional maybe_top_k_arr, int64_t top_k_val, int64_t cuda_stream) { CHECK_INPUT(logits); auto device = logits.device(); diff --git a/csrc/rope.cu b/csrc/rope.cu index b4eefcc2c..d8d000fa0 100644 --- a/csrc/rope.cu +++ b/csrc/rope.cu @@ -20,8 +20,8 @@ using namespace flashinfer; void apply_rope(at::Tensor q, at::Tensor k, at::Tensor q_rope, at::Tensor k_rope, at::Tensor indptr, - at::Tensor offsets, unsigned int rotary_dim, bool interleave, float rope_scale, - float rope_theta, int64_t cuda_stream) { + at::Tensor offsets, int64_t rotary_dim, bool interleave, double rope_scale, + double rope_theta, int64_t cuda_stream) { CHECK_LAST_DIM_CONTIGUOUS(q); CHECK_LAST_DIM_CONTIGUOUS(k); CHECK_INPUT(indptr); @@ -65,8 +65,8 @@ void apply_rope(at::Tensor q, at::Tensor k, at::Tensor q_rope, at::Tensor k_rope } void apply_rope_pos_ids(at::Tensor q, at::Tensor k, at::Tensor q_rope, at::Tensor k_rope, - at::Tensor pos_ids, unsigned int rotary_dim, bool interleave, - float rope_scale, float rope_theta, int64_t cuda_stream) { + at::Tensor pos_ids, int64_t rotary_dim, bool interleave, + double rope_scale, double rope_theta, int64_t cuda_stream) { CHECK_LAST_DIM_CONTIGUOUS(q); CHECK_LAST_DIM_CONTIGUOUS(k); CHECK_INPUT(pos_ids); @@ -153,9 +153,9 @@ void apply_rope_pos_ids_cos_sin_cache(at::Tensor q, at::Tensor k, at::Tensor q_r } void apply_llama31_rope(at::Tensor q, at::Tensor k, at::Tensor q_rope, at::Tensor k_rope, - at::Tensor indptr, at::Tensor offsets, unsigned int rotary_dim, - bool interleave, float rope_scale, float rope_theta, float low_freq_factor, - float high_freq_factor, float old_context_length, int64_t cuda_stream) { + at::Tensor indptr, at::Tensor offsets, int64_t rotary_dim, + bool interleave, double rope_scale, double rope_theta, double low_freq_factor, + double high_freq_factor, double old_context_length, int64_t cuda_stream) { CHECK_CUDA(q); // not necessarily contiguous CHECK_CUDA(k); // not necessarily contiguous CHECK_INPUT(indptr); @@ -200,9 +200,9 @@ void apply_llama31_rope(at::Tensor q, at::Tensor k, at::Tensor q_rope, at::Tenso } void apply_llama31_rope_pos_ids(at::Tensor q, at::Tensor k, at::Tensor q_rope, at::Tensor k_rope, - at::Tensor pos_ids, unsigned int rotary_dim, bool interleave, - float rope_scale, float rope_theta, float low_freq_factor, - float high_freq_factor, float old_context_length, + at::Tensor pos_ids, int64_t rotary_dim, bool interleave, + double rope_scale, double rope_theta, double low_freq_factor, + double high_freq_factor, double old_context_length, int64_t cuda_stream) { CHECK_CUDA(q); // not necessarily contiguous CHECK_CUDA(k); // not necessarily contiguous diff --git a/csrc/sampling.cu b/csrc/sampling.cu index 2677419ea..815d710eb 100644 --- a/csrc/sampling.cu +++ b/csrc/sampling.cu @@ -66,7 +66,7 @@ void top_p_sampling_from_probs(at::Tensor probs, at::Tensor 
uniform_samples, at: void top_k_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples, at::Tensor success, std::optional maybe_top_k_arr, - unsigned int top_k_val, bool deterministic, int64_t cuda_stream) { + int64_t top_k_val, bool deterministic, int64_t cuda_stream) { CHECK_INPUT(probs); CHECK_INPUT(uniform_samples); auto device = probs.device(); diff --git a/csrc/single_decode.cu b/csrc/single_decode.cu index b3d31ed5f..661993114 100644 --- a/csrc/single_decode.cu +++ b/csrc/single_decode.cu @@ -30,8 +30,8 @@ cudaError_t SingleDecodeWithKVCacheDispatched(Params params, typename Params::DT using namespace flashinfer; void single_decode_with_kv_cache(at::Tensor q, at::Tensor k, at::Tensor v, at::Tensor tmp, - at::Tensor o, unsigned int layout, - int window_left ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream) { + at::Tensor o, int64_t layout, + int64_t window_left ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream) { CHECK_INPUT(q); CHECK_INPUT(k); CHECK_INPUT(v); diff --git a/csrc/single_decode_jit_pybind.cu b/csrc/single_decode_jit_pybind.cu index 107ab4f91..a79906057 100644 --- a/csrc/single_decode_jit_pybind.cu +++ b/csrc/single_decode_jit_pybind.cu @@ -18,9 +18,10 @@ #include "single_decode_config.inc" void single_decode_with_kv_cache(at::Tensor q, at::Tensor k, at::Tensor v, at::Tensor tmp, - at::Tensor o, unsigned int layout, - int window_left ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream); + at::Tensor o, int64_t layout, + int64_t window_left ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream); -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("run", &single_decode_with_kv_cache, "Single-request decode with KV-Cache operator"); +TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) { + // Single-request decode with KV-Cache operator + m.def("run", single_decode_with_kv_cache); } diff --git a/csrc/single_prefill.cu b/csrc/single_prefill.cu index caaa78133..2ead09e68 100644 --- a/csrc/single_prefill.cu +++ b/csrc/single_prefill.cu @@ -35,8 +35,8 @@ using namespace flashinfer; void single_prefill_with_kv_cache(at::Tensor q, at::Tensor k, at::Tensor v, at::Tensor tmp, at::Tensor o, std::optional maybe_lse, - unsigned int mask_mode_code, unsigned int layout, - int32_t window_left ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream) { + int64_t mask_mode_code, int64_t layout, + int64_t window_left ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream) { auto device = q.device(); unsigned int head_dim_qk = q.size(2); unsigned int kv_len, qo_len, num_kv_heads, num_qo_heads; diff --git a/csrc/single_prefill_jit_pybind.cu b/csrc/single_prefill_jit_pybind.cu index 63452fe68..fdfe46816 100644 --- a/csrc/single_prefill_jit_pybind.cu +++ b/csrc/single_prefill_jit_pybind.cu @@ -18,10 +18,10 @@ void single_prefill_with_kv_cache(at::Tensor q, at::Tensor k, at::Tensor v, at::Tensor tmp, at::Tensor o, std::optional maybe_lse, - unsigned int mask_mode_code, unsigned int layout, - int32_t window_left ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream); + int64_t mask_mode_code, int64_t layout, + int64_t window_left ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream); -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("run", &single_prefill_with_kv_cache, - "Single-request prefill attention with KV-Cache operator"); +TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) { + // Single-request prefill attention with KV-Cache operator + m.def("run", single_prefill_with_kv_cache); } diff --git a/csrc/single_prefill_sm90.cu b/csrc/single_prefill_sm90.cu index 4e89a1ecd..529250381 100644 --- a/csrc/single_prefill_sm90.cu +++ 
b/csrc/single_prefill_sm90.cu @@ -33,8 +33,8 @@ using namespace flashinfer; void single_prefill_with_kv_cache_sm90(at::Tensor q, at::Tensor k, at::Tensor v, at::Tensor tmp, at::Tensor o, std::optional maybe_lse, - unsigned int mask_mode_code, unsigned int layout, - int32_t window_left ADDITIONAL_FUNC_PARAMS, + int64_t mask_mode_code, int64_t layout, + int64_t window_left ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream) { unsigned int head_dim_qk = q.size(2); unsigned int head_dim_vo = v.size(2); diff --git a/csrc/single_prefill_sm90_jit_pybind.cu b/csrc/single_prefill_sm90_jit_pybind.cu index 4b288de8b..dfd95c8fc 100644 --- a/csrc/single_prefill_sm90_jit_pybind.cu +++ b/csrc/single_prefill_sm90_jit_pybind.cu @@ -18,11 +18,11 @@ void single_prefill_with_kv_cache_sm90(at::Tensor q, at::Tensor k, at::Tensor v, at::Tensor tmp, at::Tensor o, std::optional maybe_lse, - unsigned int mask_mode_code, unsigned int layout, - int32_t window_left ADDITIONAL_FUNC_PARAMS, + int64_t mask_mode_code, int64_t layout, + int64_t window_left ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream); -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("run", &single_prefill_with_kv_cache_sm90, - "Single-request prefill attention with KV-Cache operator"); +TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) { + // Single-request prefill attention with KV-Cache operator + m.def("run", single_prefill_with_kv_cache_sm90); } diff --git a/flashinfer/activation.py b/flashinfer/activation.py index c5795bbb4..e9455f0bd 100644 --- a/flashinfer/activation.py +++ b/flashinfer/activation.py @@ -56,7 +56,7 @@ def get_act_and_mul_module(act_func_name: str): global _jit_modules if act_func_name not in _jit_modules: if has_prebuilt_ops: - from . import _kernels # type: ignore[attr-defined] + _kernels = torch.ops.flashinfer_kernels module = _kernels else: diff --git a/flashinfer/cascade.py b/flashinfer/cascade.py index d64d39620..d7458db14 100644 --- a/flashinfer/cascade.py +++ b/flashinfer/cascade.py @@ -30,7 +30,7 @@ def get_cascade_module(): global _cascade_module if _cascade_module is None: if has_prebuilt_ops: - from . import _kernels + _kernels = torch.ops.flashinfer_kernels _cascade_module = _kernels else: diff --git a/flashinfer/decode.py b/flashinfer/decode.py index 5c562391a..25c7d1f0f 100644 --- a/flashinfer/decode.py +++ b/flashinfer/decode.py @@ -66,7 +66,7 @@ def get_single_decode_module(*args): if args not in _single_decode_modules: uri = get_single_decode_uri(*args) if has_prebuilt_ops and uri in prebuilt_ops_uri: - from . import _kernels + _kernels = torch.ops.flashinfer_kernels run_func = _kernels.single_decode_with_kv_cache else: @@ -213,7 +213,7 @@ def get_batch_decode_module(*args): if args not in _batch_decode_modules: uri = get_batch_decode_uri(*args) if has_prebuilt_ops and uri in prebuilt_ops_uri: - from . import _kernels + _kernels = torch.ops.flashinfer_kernels plan_func = _kernels.batch_decode_with_paged_kv_cache_plan run_func = _kernels.batch_decode_with_paged_kv_cache_run diff --git a/flashinfer/gemm.py b/flashinfer/gemm.py index 1dfafa4f7..2f5b88177 100644 --- a/flashinfer/gemm.py +++ b/flashinfer/gemm.py @@ -39,7 +39,7 @@ def get_gemm_module(): global _gemm_module if _gemm_module is None: if has_prebuilt_ops: - from . import _kernels + _kernels = torch.ops.flashinfer_kernels module = _kernels else: @@ -149,7 +149,7 @@ def get_gemm_sm90_module(): global _gemm_module_sm90 if _gemm_module_sm90 is None: if has_prebuilt_ops: - from . 
import _kernels_sm90 + _kernels_sm90 = torch.ops.flashinfer_kernels_sm90 module = _kernels_sm90 else: diff --git a/flashinfer/jit/activation.py b/flashinfer/jit/activation.py index c0dc59b09..e82a7df99 100644 --- a/flashinfer/jit/activation.py +++ b/flashinfer/jit/activation.py @@ -49,8 +49,8 @@ }); } -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("{{ func_name }}", &{{ func_name }}, "Fused {{ act_func_name }} and Mul"); +TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) { + m.def("{{ func_name }}", {{ func_name }}); } """ diff --git a/flashinfer/jit/attention.py b/flashinfer/jit/attention.py index 5816147e4..38720ba3f 100644 --- a/flashinfer/jit/attention.py +++ b/flashinfer/jit/attention.py @@ -375,7 +375,7 @@ def gen_single_decode_module( "rope_rcp_scale", "rope_rcp_theta", ], # additional_scalar_names - ["float", "float", "float", "float"], # additional_scalar_dtypes + ["double", "double", "double", "double"], # additional_scalar_dtypes f"DefaultAttention", # variant_name f"#include", # variant_decl pos_encoding_mode=pos_encoding_mode, @@ -417,14 +417,14 @@ def gen_single_prefill_module( "rope_rcp_scale", "rope_rcp_theta", ] - additional_scalar_dtypes = ["float", "float", "float", "float"] + additional_scalar_dtypes = ["double", "double", "double", "double"] variant_name = f"DefaultAttention" variant_decl = f"#include" else: additional_tensor_names = [] additional_tensor_dtypes = [] additional_scalar_names = ["logits_soft_cap", "sm_scale"] - additional_scalar_dtypes = ["float", "float"] + additional_scalar_dtypes = ["double", "double"] variant_name = f"DefaultAttention<{str(use_logits_soft_cap).lower()}>" variant_decl = f"#include" @@ -487,7 +487,7 @@ def gen_batch_decode_module( "rope_rcp_scale", "rope_rcp_theta", ], # additional_scalar_names - ["float", "float", "float", "float"], # additional_scalar_dtypes + ["double", "double", "double", "double"], # additional_scalar_dtypes f"DefaultAttention", # variant_name f"#include", # variant_decl pos_encoding_mode=pos_encoding_mode, @@ -540,14 +540,14 @@ def gen_batch_prefill_module( "rope_rcp_scale", "rope_rcp_theta", ] - additional_scalar_dtypes = ["float", "float", "float", "float"] + additional_scalar_dtypes = ["double", "double", "double", "double"] variant_name = f"DefaultAttention" variant_decl = f"#include" else: additional_tensor_names = [] additional_tensor_dtypes = [] additional_scalar_names = ["logits_soft_cap", "sm_scale"] - additional_scalar_dtypes = ["float", "float"] + additional_scalar_dtypes = ["double", "double"] variant_name = f"DefaultAttention<{str(use_logits_soft_cap).lower()}>" variant_decl = f"#include" diff --git a/flashinfer/jit/core.py b/flashinfer/jit/core.py index 0e7aea323..c8cd2946e 100644 --- a/flashinfer/jit/core.py +++ b/flashinfer/jit/core.py @@ -4,6 +4,7 @@ from pathlib import Path from typing import List, Union, Optional from contextlib import suppress +import torch import torch.utils.cpp_extension as torch_cpp_ext from filelock import FileLock @@ -117,7 +118,7 @@ def load_cuda_ops( ] + CUTLASS_INCLUDE_DIRS lock = FileLock(FLASHINFER_JIT_DIR / f"{name}.lock", thread_local=False) with lock: - module = torch_cpp_ext.load( + torch_cpp_ext.load( name, list(map(lambda _: str(_), sources)), extra_cflags=cflags, @@ -127,6 +128,9 @@ def load_cuda_ops( build_directory=build_directory, verbose=verbose, with_cuda=True, + # We switched to torch.library, so will be loaded into torch.ops + # instead of into a separate module. 
+ is_python_module=False, ) logger.info(f"Finished loading JIT ops: {name}") - return module + return getattr(torch.ops, name) diff --git a/flashinfer/norm.py b/flashinfer/norm.py index 1919296fb..600f86cd5 100644 --- a/flashinfer/norm.py +++ b/flashinfer/norm.py @@ -28,7 +28,7 @@ def get_norm_module(): global _norm_module if _norm_module is None: if has_prebuilt_ops: - from . import _kernels + _kernels = torch.ops.flashinfer_kernels _norm_module = _kernels else: diff --git a/flashinfer/page.py b/flashinfer/page.py index b0f80903d..898dafd62 100644 --- a/flashinfer/page.py +++ b/flashinfer/page.py @@ -37,7 +37,7 @@ def get_page_module(): global _page_module if _page_module is None: if has_prebuilt_ops: - from . import _kernels + _kernels = torch.ops.flashinfer_kernels _page_module = _kernels else: diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index 18e521b6e..e5df023b4 100644 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -70,11 +70,11 @@ def backend_module(*args): uri = get_single_prefill_uri(backend, *args) if has_prebuilt_ops and uri in prebuilt_ops_uri: if backend == "fa2": - from . import _kernels + _kernels = torch.ops.flashinfer_kernels run_func = _kernels.single_prefill_with_kv_cache else: - from . import _kernels_sm90 + _kernels_sm90 = torch.ops.flashinfer_kernels_sm90 run_func = _kernels_sm90.single_prefill_with_kv_cache_sm90 else: @@ -177,13 +177,13 @@ def backend_module(*args): uri = get_batch_prefill_uri(backend, *args) if has_prebuilt_ops and uri in prebuilt_ops_uri: if backend == "fa2": - from . import _kernels + _kernels = torch.ops.flashinfer_kernels plan_func = _kernels.batch_prefill_with_kv_cache_plan ragged_run_func = _kernels.batch_prefill_with_ragged_kv_cache_run paged_run_func = _kernels.batch_prefill_with_paged_kv_cache_run else: - from . import _kernels_sm90 + _kernels_sm90 = torch.ops.flashinfer_kernels_sm90 plan_func = _kernels_sm90.batch_prefill_with_kv_cache_sm90_plan ragged_run_func = ( diff --git a/flashinfer/quantization.py b/flashinfer/quantization.py index f5c00340b..20e7574fc 100644 --- a/flashinfer/quantization.py +++ b/flashinfer/quantization.py @@ -28,7 +28,7 @@ def get_quantization_module(): global _quantization_module if _quantization_module is None: if has_prebuilt_ops: - from . import _kernels + _kernels = torch.ops.flashinfer_kernels _quantization_module = _kernels else: diff --git a/flashinfer/rope.py b/flashinfer/rope.py index e587f0b60..97a7a5b5a 100644 --- a/flashinfer/rope.py +++ b/flashinfer/rope.py @@ -28,7 +28,7 @@ def get_rope_module(): global _rope_module if _rope_module is None: if has_prebuilt_ops: - from . import _kernels + _kernels = torch.ops.flashinfer_kernels _rope_module = _kernels else: diff --git a/flashinfer/sampling.py b/flashinfer/sampling.py index 942167f00..233988445 100644 --- a/flashinfer/sampling.py +++ b/flashinfer/sampling.py @@ -29,7 +29,7 @@ def get_sampling_module(): global _sampling_module if _sampling_module is None: if has_prebuilt_ops: - from . 
import _kernels + _kernels = torch.ops.flashinfer_kernels module = _kernels else: diff --git a/include/flashinfer/attention/scheduler.cuh b/include/flashinfer/attention/scheduler.cuh index 8744d8f3d..82b60045b 100644 --- a/include/flashinfer/attention/scheduler.cuh +++ b/include/flashinfer/attention/scheduler.cuh @@ -579,7 +579,7 @@ struct PrefillPlanInfo { void FromVector(const std::vector& vec) { if (vec.size() != 15) { std::ostringstream err_msg; - err_msg << "PrefillPlanInfo::FromVector: vec.size() should be 14, but got " << vec.size(); + err_msg << "PrefillPlanInfo::FromVector: vec.size() should be 15, but got " << vec.size(); FLASHINFER_ERROR(err_msg.str()); } padded_batch_size = vec[0]; @@ -734,7 +734,7 @@ struct PrefillPlanSM90Info { head_indices_offset(0), work_indptr_offset(0), same_schedule_for_all_heads(false) {} - + // convert PrefillPlanSM90Info to std::vector std::vector ToVector() const { return {qo_tile_indices_offset, qo_indptr_offset, diff --git a/include/flashinfer/gemm/group_gemm.cuh b/include/flashinfer/gemm/group_gemm.cuh index 3c142eea5..a7944a2ab 100644 --- a/include/flashinfer/gemm/group_gemm.cuh +++ b/include/flashinfer/gemm/group_gemm.cuh @@ -36,7 +36,7 @@ namespace group_gemm { template cudaError_t CutlassSegmentGEMMRun(void* workspace_buffer, size_t workspace_buffer_size_in_bytes, - void* all_problems, unsigned int batch_size, void* x, void* w, + void* all_problems, int64_t batch_size, void* x, void* w, void* y, void* x_ld, void* w_ld, void* y_ld, bool weight_column_major, cudaStream_t stream) { using cutlass::epilogue::thread::LinearCombination; diff --git a/include/flashinfer/gemm/group_gemm_sm90.cuh b/include/flashinfer/gemm/group_gemm_sm90.cuh index 07e60d977..c5cee1e3e 100644 --- a/include/flashinfer/gemm/group_gemm_sm90.cuh +++ b/include/flashinfer/gemm/group_gemm_sm90.cuh @@ -51,7 +51,7 @@ using namespace cute; template cudaError_t CutlassSegmentGEMMSM90Run(void* float_buffer, size_t float_buffer_size_in_bytes, void* int_buffer, size_t int_buffer_size_in_bytes, - void* all_problems, unsigned int batch_size, void* x, void* w, + void* all_problems, int64_t batch_size, void* x, void* w, void* y, void* x_stride, void* w_stride, void* y_stride, bool weight_column_major, cudaStream_t stream) { auto compute_capacity = GetCudaComputeCapability(); diff --git a/setup.py b/setup.py index 1a606a3cf..72152dc44 100644 --- a/setup.py +++ b/setup.py @@ -221,6 +221,10 @@ def __init__(self, *args, **kwargs) -> None: "-compress-all", "-use_fast_math", ] + libraries = [ + "cublas", + "cublasLt", + ] sm90a_flags = "-gencode arch=compute_90a,code=sm_90a".split() kernel_sources = [ "csrc/bmm_fp8.cu", @@ -252,9 +256,10 @@ def __init__(self, *args, **kwargs) -> None: prefill_sm90_sources = list(gen_dir.glob("*prefill_head*_sm90.cu")) ext_modules = [ torch_cpp_ext.CUDAExtension( - name="flashinfer._kernels", + name="flashinfer.flashinfer_kernels", sources=kernel_sources + decode_sources + prefill_sources, include_dirs=include_dirs, + libraries=libraries, extra_compile_args={ "cxx": cxx_flags, "nvcc": nvcc_flags, @@ -264,9 +269,10 @@ def __init__(self, *args, **kwargs) -> None: if enable_sm90: ext_modules += [ torch_cpp_ext.CUDAExtension( - name="flashinfer._kernels_sm90", + name="flashinfer.flashinfer_kernels_sm90", sources=kernel_sm90_sources + prefill_sm90_sources, include_dirs=include_dirs, + libraries=libraries, extra_compile_args={ "cxx": cxx_flags, "nvcc": nvcc_flags + sm90a_flags, diff --git a/tests/test_jit_example.py b/tests/test_jit_example.py index 
84fb86bc0..970417e33 100644 --- a/tests/test_jit_example.py +++ b/tests/test_jit_example.py @@ -52,7 +52,7 @@ def test_single_decode_mask(): ["custom_mask"], # additional_tensor_names ["uint8_t"], # additional_tensor_dtypes ["sm_scale"], # # additional_scalar_names - ["float"], # additional_scalar_dtypes + ["double"], # additional_scalar_dtypes "SingleDecodeWithCustomMask", variant_decl, ) @@ -137,7 +137,7 @@ def test_flash_sigmoid(): [], # additional_tensor_names [], # additional_tensor_dtypes ["logits_scale", "sigmoid_bias"], # additional_scalar_names - ["float", "float"], # additional_scalar_dtypes + ["double", "double"], # additional_scalar_dtypes "FlashSigmoid", variant_decl, ) @@ -196,7 +196,7 @@ def test_dump_logits(): ["output_logits"], # additional_tensor_names ["float"], # additional_tensor_dtypes ["sm_scale"], # additional_scalar_names - ["float"], # additional_scalar_dtypes + ["double"], # additional_scalar_dtypes "DumpLogits", variant_decl, ) @@ -221,7 +221,7 @@ def test_batch_decode_flash_sigmoid(use_tensor_cores): torch.manual_seed(42) variant_decl = flash_sigmoid_sm80_decl jit_args = ( - "batch_decode_flash_sigmoid_sm80", # uri + f"batch_decode_flash_sigmoid_sm80_{use_tensor_cores}", # uri torch.float16, # dtype_q torch.float16, # dtype_kv torch.float16, # dtype_o @@ -231,7 +231,7 @@ def test_batch_decode_flash_sigmoid(use_tensor_cores): [], # additional_tensor_names [], # additional_tensor_dtypes ["logits_scale", "sigmoid_bias"], # additional_scalar_names - ["float", "float"], # additional_scalar_dtypes + ["double", "double"], # additional_scalar_dtypes "FlashSigmoid", variant_decl, ) @@ -338,7 +338,7 @@ def test_batch_prefill_flash_sigmoid(): [], # additional_tensor_names [], # additional_tensor_dtypes ["logits_scale", "sigmoid_bias"], # additional_scalar_names - ["float", "float"], # additional_scalar_dtypes + ["double", "double"], # additional_scalar_dtypes "FlashSigmoid", variant_decl, ) @@ -457,7 +457,7 @@ def test_batch_prefill_sm90_flash_sigmoid(): [], # additional_tensor_names [], # additional_tensor_dtypes ["logits_scale", "sigmoid_bias"], # additional_scalar_names - ["float", "float"], # additional_scalar_dtypes + ["double", "double"], # additional_scalar_dtypes "FlashSigmoid", variant_decl, ) @@ -600,7 +600,7 @@ def test_debug_print_logits(): [], # additional_tensor_names [], # additional_tensor_dtypes ["sm_scale"], # additional_scalar_names - ["float"], # additional_scalar_dtypes + ["double"], # additional_scalar_dtypes "DebugPrintLogits", variant_decl, ) @@ -673,7 +673,7 @@ def test_sm90_debug_print_logits(): [], # additional_tensor_names [], # additional_tensor_dtypes ["sm_scale"], # additional_scalar_names - ["float"], # additional_scalar_dtypes + ["double"], # additional_scalar_dtypes "DebugPrintLogits", variant_decl, )
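
Note on the registration pattern used throughout this patch (a minimal standalone sketch, not code from this repository; the namespace `toy_ext` and op `scaled_add` are made up for illustration): every PYBIND11_MODULE block becomes a TORCH_LIBRARY_FRAGMENT block. A fragment appends operators to a namespace rather than owning it, which is what allows each csrc/*_ops.cu translation unit to register its own subset of ops under the same TORCH_EXTENSION_NAME.

    // toy_ext.cpp -- hypothetical extension, not part of FlashInfer.
    #include <ATen/ATen.h>
    #include <torch/library.h>

    // torch.library infers the operator schema from the C++ signature.
    at::Tensor scaled_add(at::Tensor a, at::Tensor b, double scale) {
      return a + b * scale;
    }

    // FRAGMENT appends to the "toy_ext" namespace instead of owning it,
    // so several source files can each contribute their own ops.
    TORCH_LIBRARY_FRAGMENT(toy_ext, m) {
      m.def("scaled_add", scaled_add);
    }

Once the shared library has been loaded (the dummy PyInit_ shim added to pytorch_extension_utils.h is what makes the prebuilt .so importable from Python), the op is reachable as torch.ops.toy_ext.scaled_add(a, b, 2.0). This is why the flashinfer/*.py changes above replace `from . import _kernels` with `torch.ops.flashinfer_kernels`.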
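Note on the plan()/run() interface change (a sketch under assumed names; ToyPlanInfo, toy_plan, and toy_run are illustrative stand-ins for DecodePlanInfo/PrefillPlanInfo and the FlashInfer plan/run ops): the plan functions used to return std::vector<int64_t>, and now serialize the scheduler metadata into a CPU int64 tensor with the vec_to_tensor/tensor_to_vec helpers from the new csrc/pytorch_conversion_utils.h; the run functions deserialize it on the way back in.

    #include <vector>
    #include <torch/all.h>
    #include "pytorch_conversion_utils.h"  // vec_to_tensor / tensor_to_vec (added by this patch)

    // Stand-in for DecodePlanInfo / PrefillPlanInfo.
    struct ToyPlanInfo {
      int64_t padded_batch_size = 0;
      int64_t split_kv = 0;
      std::vector<int64_t> ToVector() const { return {padded_batch_size, split_kv}; }
      void FromVector(const std::vector<int64_t>& v) {
        padded_batch_size = v[0];
        split_kv = v[1];
      }
    };

    // plan(): do the scheduling work once, hand the result back as a tensor
    // so it can cross the torch.library boundary.
    at::Tensor toy_plan() {
      ToyPlanInfo info;
      info.padded_batch_size = 8;
      info.split_kv = 1;
      return vec_to_tensor(info.ToVector());
    }

    // run(): rebuild the plan info from the tensor produced by plan().
    void toy_run(at::Tensor plan_info_vec) {
      ToyPlanInfo info;
      info.FromVector(tensor_to_vec(plan_info_vec));
      // ... launch kernels according to `info` ...
    }

The Python wrappers only thread this opaque tensor from the plan op to the run op; they never need to interpret its contents.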
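Note on the scalar signature changes (float -> double, unsigned int/int32_t -> int64_t): schema inference for torch.library only accepts a fixed set of scalar types, so narrower C++ types must be widened at the binding boundary. A minimal illustration under assumed names (not FlashInfer code):

    #include <ATen/ATen.h>
    #include <torch/library.h>

    // OK: double and int64_t are accepted by schema inference.
    at::Tensor scale_row(at::Tensor x, double scale, int64_t row) {
      at::Tensor y = x.clone();
      y[row] *= scale;  // scale one row of the copy
      return y;
    }

    // Would be rejected at compile time if registered: float and unsigned int
    // are not valid schema argument types (use double / int64_t instead).
    // at::Tensor scale_row_bad(at::Tensor x, float scale, unsigned int row);

    TORCH_LIBRARY_FRAGMENT(toy_ext, m) {
      m.def("scale_row", scale_row);
    }

This is why the aot_build_utils/ and flashinfer/jit/ templates switch additional_scalar_dtypes from "float" to "double", and why the kernel wrappers now take int64_t where they previously took unsigned int.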