refactor: change to TORCH_LIBRARY (#823)
This PR updates FlashInfer's C++/CUDA extensions from pybind11 modules to `torch.library` (`TORCH_LIBRARY`) registrations, which has been the recommended way to expose custom ops since PyTorch 2.5.

This was mainly implemented in #764. We have investigated the issue in #820 and confirmed it was not caused by that change, so we are reopening it here.
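
For reference, here is a minimal before/after sketch of the registration change. It is illustrative only; the actual registrations are in the `*_pybind.cu` diffs below, and the comment about dispatcher-visible types is our reading of why scalar parameters were widened to `int64_t`/`double`.

```cpp
// Before: a plain pybind11 module; ops are ordinary Python functions on the
// extension module.
// PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//   m.def("plan", &BatchDecodeWithPagedKVCachePlan, "Batched decode with paged KV-Cache plan");
// }

// After: ops are registered with the PyTorch dispatcher and show up as
// torch.ops.<extension_name>.plan. The schema is inferred from the C++
// signature, which accepts only dispatcher types such as at::Tensor, int64_t,
// double, and bool; presumably this is why unsigned int/int arguments became
// int64_t and float arguments became double throughout this diff.
TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
  m.def("plan", BatchDecodeWithPagedKVCachePlan);
}
```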

---------

Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: abmfy <abmfy@icloud.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: Zihao Ye <expye@outlook.com>
3 people authored Feb 13, 2025
1 parent c716aed commit dbb1e4e
Showing 53 changed files with 503 additions and 386 deletions.
12 changes: 6 additions & 6 deletions aot_build_utils/generate_aot_default_additional_params_header.py
@@ -85,7 +85,7 @@ def get_aot_default_additional_params_header_str() -> str:
"rope_rcp_scale",
"rope_rcp_theta",
], # additional_scalar_names
["float", "float", "float", "float"], # additional_scalar_dtypes
["double", "double", "double", "double"], # additional_scalar_dtypes
)

ret += generate_macro_entry(
@@ -98,15 +98,15 @@ def get_aot_default_additional_params_header_str() -> str:
"rope_rcp_scale",
"rope_rcp_theta",
],
["float", "float", "float", "float"],
["double", "double", "double", "double"],
)

ret += generate_macro_entry(
"SINGLE_PREFILL_SM90",
[],
[],
["logits_soft_cap", "sm_scale"],
["float", "float"],
["double", "double"],
is_sm90_template=True,
)

@@ -120,7 +120,7 @@ def get_aot_default_additional_params_header_str() -> str:
"rope_rcp_scale",
"rope_rcp_theta",
], # additional_scalar_names
["float", "float", "float", "float"], # additional_scalar_dtypes
["double", "double", "double", "double"], # additional_scalar_dtypes
)

ret += generate_macro_entry(
@@ -133,15 +133,15 @@ def get_aot_default_additional_params_header_str() -> str:
"rope_rcp_scale",
"rope_rcp_theta",
],
["float", "float", "float", "float"],
["double", "double", "double", "double"],
)

ret += generate_macro_entry(
"BATCH_PREFILL_SM90",
[],
[],
["logits_soft_cap", "sm_scale"],
["float", "float"],
["double", "double"],
is_sm90_template=True,
)

20 changes: 10 additions & 10 deletions csrc/batch_decode.cu
@@ -20,6 +20,7 @@

#include "batch_decode_config.inc"
#include "pytorch_extension_utils.h"
#include "pytorch_conversion_utils.h"

namespace flashinfer {

@@ -32,13 +33,12 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched(Params params, typename Params

using namespace flashinfer;

std::vector<int64_t> BatchDecodeWithPagedKVCachePlan(
at::Tensor BatchDecodeWithPagedKVCachePlan(
at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
at::Tensor page_locked_int_workspace_buffer, at::Tensor indptr, unsigned int batch_size,
unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int page_size,
bool enable_cuda_graph, int window_left, float logits_soft_cap, unsigned int head_dim_qk,
unsigned int head_dim_vo, at::Tensor empty_q_data, at::Tensor empty_kv_data,
int64_t cuda_stream) {
at::Tensor page_locked_int_workspace_buffer, at::Tensor indptr, int64_t batch_size,
int64_t num_qo_heads, int64_t num_kv_heads, int64_t page_size,
bool enable_cuda_graph, int64_t window_left, double logits_soft_cap, int64_t head_dim_qk,
int64_t head_dim_vo, at::Tensor empty_q_data, at::Tensor empty_kv_data, int64_t cuda_stream) {
size_t float_workspace_size_in_bytes =
float_workspace_buffer.size(0) * float_workspace_buffer.element_size();
size_t int_workspace_size_in_bytes =
@@ -74,17 +74,17 @@ std::vector<int64_t> BatchDecodeWithPagedKVCachePlan(
});
});

return plan_info.ToVector();
return vec_to_tensor(plan_info.ToVector());
}

void BatchDecodeWithPagedKVCacheRun(
at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
std::vector<int64_t> plan_info_vec, at::Tensor q, at::Tensor paged_k_cache,
at::Tensor plan_info_vec, at::Tensor q, at::Tensor paged_k_cache,
at::Tensor paged_v_cache, at::Tensor paged_kv_indptr, at::Tensor paged_kv_indices,
at::Tensor paged_kv_last_page_len, at::Tensor o, std::optional<at::Tensor> maybe_lse,
unsigned int kv_layout_code, int window_left ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream) {
int64_t kv_layout_code, int64_t window_left ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream) {
DecodePlanInfo plan_info;
plan_info.FromVector(plan_info_vec);
plan_info.FromVector(tensor_to_vec(plan_info_vec));
QKVLayout kv_layout = static_cast<QKVLayout>(kv_layout_code);
auto device = q.device();
int64_t batch_size = q.size(0);
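
Plan metadata, previously passed between `plan` and `run` as a `std::vector<int64_t>`, now round-trips through an `at::Tensor` via the helpers in the new `pytorch_conversion_utils.h` include. A rough sketch of what `vec_to_tensor`/`tensor_to_vec` could look like (an assumption for illustration, not the actual contents of that header):

```cpp
// Hypothetical sketch of the conversion helpers assumed above: pack a
// std::vector<int64_t> into a 1-D int64 CPU tensor so plan() can return it
// through the dispatcher, and unpack it again inside run().
#include <algorithm>
#include <vector>

#include <ATen/ATen.h>

inline at::Tensor vec_to_tensor(const std::vector<int64_t>& vec) {
  at::Tensor t = at::empty({static_cast<int64_t>(vec.size())},
                           at::TensorOptions().dtype(at::kLong).device(at::kCPU));
  std::copy(vec.begin(), vec.end(), t.data_ptr<int64_t>());
  return t;
}

inline std::vector<int64_t> tensor_to_vec(const at::Tensor& tensor) {
  // Expects the contiguous 1-D int64 CPU tensor produced by vec_to_tensor.
  at::Tensor t = tensor.contiguous();
  const int64_t* data = t.data_ptr<int64_t>();
  return std::vector<int64_t>(data, data + t.numel());
}
```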
23 changes: 12 additions & 11 deletions csrc/batch_decode_jit_pybind.cu
@@ -16,22 +16,23 @@
#include "batch_decode_config.inc"
#include "pytorch_extension_utils.h"

std::vector<int64_t> BatchDecodeWithPagedKVCachePlan(
at::Tensor BatchDecodeWithPagedKVCachePlan(
at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
at::Tensor page_locked_int_workspace_buffer, at::Tensor indptr, unsigned int batch_size,
unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int page_size,
bool enable_cuda_graph, int window_left, float logits_soft_cap, unsigned int head_dim_qk,
unsigned int head_dim_vo, at::Tensor empty_q_data, at::Tensor empty_kv_data,
int64_t cuda_stream);
at::Tensor page_locked_int_workspace_buffer, at::Tensor indptr, int64_t batch_size,
int64_t num_qo_heads, int64_t num_kv_heads, int64_t page_size,
bool enable_cuda_graph, int64_t window_left, double logits_soft_cap, int64_t head_dim_qk,
int64_t head_dim_vo, at::Tensor empty_q_data, at::Tensor empty_kv_data, int64_t cuda_stream);

void BatchDecodeWithPagedKVCacheRun(
at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
std::vector<int64_t> plan_info_vec, at::Tensor q, at::Tensor paged_k_cache,
at::Tensor plan_info_vec, at::Tensor q, at::Tensor paged_k_cache,
at::Tensor paged_v_cache, at::Tensor paged_kv_indptr, at::Tensor paged_kv_indices,
at::Tensor paged_kv_last_page_len, at::Tensor o, std::optional<at::Tensor> maybe_lse,
unsigned int kv_layout_code, int window_left ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream);
int64_t kv_layout_code, int64_t window_left ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("plan", &BatchDecodeWithPagedKVCachePlan, "Batched decode with paged KV-Cache plan");
m.def("run", &BatchDecodeWithPagedKVCacheRun, "Batched decode with paged KV-Cache run");
TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
// Batched decode with paged KV-Cache plan
m.def("plan", BatchDecodeWithPagedKVCachePlan);
// Batched decode with paged KV-Cache run
m.def("run", BatchDecodeWithPagedKVCacheRun);
}
9 changes: 5 additions & 4 deletions csrc/batch_decode_mla_plan.cu
@@ -4,13 +4,14 @@

#include "mla_config.inc"
#include "pytorch_extension_utils.h"
#include "pytorch_conversion_utils.h"

using namespace flashinfer;

std::vector<int64_t> BatchDecodeWithPagedKVCachePlanMLA(
at::Tensor BatchDecodeWithPagedKVCachePlanMLA(
at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
at::Tensor page_locked_int_workspace_buffer, at::Tensor indptr, unsigned int batch_size,
unsigned int num_qo_heads, unsigned int page_size, bool enable_cuda_graph,
at::Tensor page_locked_int_workspace_buffer, at::Tensor indptr, int64_t batch_size,
int64_t num_qo_heads, int64_t page_size, bool enable_cuda_graph,
int64_t cuda_stream) {
size_t float_workspace_size_in_bytes =
float_workspace_buffer.size(0) * float_workspace_buffer.element_size();
@@ -35,5 +36,5 @@ std::vector<int64_t> BatchDecodeWithPagedKVCachePlanMLA(
TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPagedKVCachePlanMLA failed with error ",
cudaGetErrorString(status));

return plan_info.ToVector();
return vec_to_tensor(plan_info.ToVector());
}
18 changes: 9 additions & 9 deletions csrc/batch_decode_mla_pybind.cu
@@ -1,20 +1,20 @@
#include "mla_config.inc"
#include "pytorch_extension_utils.h"

std::vector<int64_t> BatchDecodeWithPagedKVCachePlanMLA(
at::Tensor BatchDecodeWithPagedKVCachePlanMLA(
at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
at::Tensor page_locked_int_workspace_buffer, at::Tensor indptr, unsigned int batch_size,
unsigned int num_qo_heads, unsigned int page_size, bool enable_cuda_graph, int64_t cuda_stream);
at::Tensor page_locked_int_workspace_buffer, at::Tensor indptr, int64_t batch_size,
int64_t num_qo_heads, int64_t page_size, bool enable_cuda_graph, int64_t cuda_stream);

void BatchDecodeWithPagedKVCacheRunMLA(
at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
std::vector<int64_t> plan_info_vec, at::Tensor q_nope, at::Tensor q_pe,
at::Tensor plan_info_vec, at::Tensor q_nope, at::Tensor q_pe,
at::Tensor paged_ckv_cache, at::Tensor paged_kpe_cache, at::Tensor paged_kv_indptr,
at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len, at::Tensor o, float sm_scale,
int window_left, float logits_soft_cap, float rope_scale, float rope_theta,
at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len, at::Tensor o, double sm_scale,
int64_t window_left, double logits_soft_cap, double rope_scale, double rope_theta,
std::optional<at::Tensor> maybe_lse, int64_t cuda_stream);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("plan", &BatchDecodeWithPagedKVCachePlanMLA);
m.def("run", &BatchDecodeWithPagedKVCacheRunMLA);
TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
m.def("plan", BatchDecodeWithPagedKVCachePlanMLA);
m.def("run", BatchDecodeWithPagedKVCacheRunMLA);
}
9 changes: 5 additions & 4 deletions csrc/batch_decode_mla_run.cu
@@ -4,18 +4,19 @@

#include "mla_config.inc"
#include "pytorch_extension_utils.h"
#include "pytorch_conversion_utils.h"

using namespace flashinfer;

void BatchDecodeWithPagedKVCacheRunMLA(
at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
std::vector<int64_t> plan_info_vec, at::Tensor q_nope, at::Tensor q_pe,
at::Tensor plan_info_vec, at::Tensor q_nope, at::Tensor q_pe,
at::Tensor paged_ckv_cache, at::Tensor paged_kpe_cache, at::Tensor paged_kv_indptr,
at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len, at::Tensor o, float sm_scale,
int window_left, float logits_soft_cap, float rope_scale, float rope_theta,
at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len, at::Tensor o, double sm_scale,
int64_t window_left, double logits_soft_cap, double rope_scale, double rope_theta,
std::optional<at::Tensor> maybe_lse, int64_t cuda_stream) {
DecodePlanInfo plan_info;
plan_info.FromVector(plan_info_vec);
plan_info.FromVector(tensor_to_vec(plan_info_vec));

auto device = q_nope.device();
int64_t batch_size = q_nope.size(0);
16 changes: 8 additions & 8 deletions csrc/batch_mla_plan.cu
@@ -17,17 +17,17 @@
#include <optional>

#include "batch_mla_config.inc"
#include "pytorch_conversion_utils.h"
#include "pytorch_extension_utils.h"

using namespace flashinfer;

std::vector<int64_t> BatchMLAPagedAttentionPlan(at::Tensor float_workspace_buffer,
at::Tensor int_workspace_buffer,
at::Tensor page_locked_int_workspace_buffer,
at::Tensor qo_indptr, at::Tensor kv_indptr,
at::Tensor kv_len, unsigned int num_heads,
unsigned int head_dim_o, bool causal,
int64_t cuda_stream) {
at::Tensor BatchMLAPagedAttentionPlan(at::Tensor float_workspace_buffer,
at::Tensor int_workspace_buffer,
at::Tensor page_locked_int_workspace_buffer,
at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor kv_len,
int64_t num_heads, int64_t head_dim_o, bool causal,
int64_t cuda_stream) {
size_t float_workspace_size_in_bytes =
float_workspace_buffer.size(0) * float_workspace_buffer.element_size();
size_t int_workspace_size_in_bytes =
@@ -47,5 +47,5 @@ std::vector<int64_t> BatchMLAPagedAttentionPlan(at::Tensor float_workspace_buffe

TORCH_CHECK(status == cudaSuccess, "Failed to plan MLA, error: ", cudaGetErrorString(status));

return plan_info.ToVector();
return vec_to_tensor(plan_info.ToVector());
}
29 changes: 14 additions & 15 deletions csrc/batch_mla_pybind.cu
@@ -16,22 +16,21 @@
#include "batch_mla_config.inc"
#include "pytorch_extension_utils.h"

std::vector<int64_t> BatchMLAPagedAttentionPlan(at::Tensor float_workspace_buffer,
at::Tensor int_workspace_buffer,
at::Tensor page_locked_int_workspace_buffer,
at::Tensor qo_indptr, at::Tensor kv_indptr,
at::Tensor kv_len, unsigned int num_heads,
unsigned int head_dim_o, bool causal,
int64_t cuda_stream);
at::Tensor BatchMLAPagedAttentionPlan(at::Tensor float_workspace_buffer,
at::Tensor int_workspace_buffer,
at::Tensor page_locked_int_workspace_buffer,
at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor kv_len,
int64_t num_heads, int64_t head_dim_o, bool causal,
int64_t cuda_stream);

void BatchMLAPagedAttentionRun(at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
std::vector<int64_t> plan_info_vec, at::Tensor q_nope,
at::Tensor q_pe, at::Tensor ckv_cache, at::Tensor kpe_cache,
at::Tensor kv_indices, at::Tensor o,
std::optional<at::Tensor> maybe_lse, int mask_mode_code,
int num_heads, int page_size, float sm_scale, int64_t cuda_stream);
at::Tensor plan_info_vec, at::Tensor q_nope, at::Tensor q_pe,
at::Tensor ckv_cache, at::Tensor kpe_cache, at::Tensor kv_indices,
at::Tensor o, std::optional<at::Tensor> maybe_lse,
int64_t mask_mode_code, int64_t num_heads, int64_t page_size,
double sm_scale, int64_t cuda_stream);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("plan", &BatchMLAPagedAttentionPlan, "Batch MLA Page Attention Plan");
m.def("run", &BatchMLAPagedAttentionRun, "Batch MLA Page Attention Run");
TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
m.def("plan", &BatchMLAPagedAttentionPlan);
m.def("run", &BatchMLAPagedAttentionRun);
}
15 changes: 7 additions & 8 deletions csrc/batch_mla_run.cu
@@ -13,30 +13,29 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <driver_types.h>

#include <flashinfer/attention/mla_fa2.cuh>
#include <flashinfer/attention/scheduler.cuh>
#include <flashinfer/fastdiv.cuh>
#include <optional>

#include "batch_mla_config.inc"
#include "pytorch_conversion_utils.h"
#include "pytorch_extension_utils.h"

using namespace flashinfer;

void BatchMLAPagedAttentionRun(at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
std::vector<int64_t> plan_info_vec, at::Tensor q_nope,
at::Tensor q_pe, at::Tensor ckv_cache, at::Tensor kpe_cache,
at::Tensor kv_indices, at::Tensor o,
std::optional<at::Tensor> maybe_lse, int mask_mode_code,
int num_heads, int page_size, float sm_scale, int64_t cuda_stream) {
at::Tensor plan_info_vec, at::Tensor q_nope, at::Tensor q_pe,
at::Tensor ckv_cache, at::Tensor kpe_cache, at::Tensor kv_indices,
at::Tensor o, std::optional<at::Tensor> maybe_lse,
int64_t mask_mode_code, int64_t num_heads, int64_t page_size,
double sm_scale, int64_t cuda_stream) {
// q_nope: [n, num_heads, head_dim_ckv]
// q_pe: [n, num_heads, head_dim_kpe]
// ckv_cache: [num_pages, page_size, head_dim_ckv]
// kpe_cache: [num_pages, page_size, head_dim_kpe]
MLAPlanInfo plan_info;
plan_info.FromVector(plan_info_vec);
plan_info.FromVector(tensor_to_vec(plan_info_vec));

auto device = q_nope.device();

25 changes: 13 additions & 12 deletions csrc/batch_prefill.cu
@@ -20,6 +20,7 @@

#include "batch_prefill_config.inc"
#include "pytorch_extension_utils.h"
#include "pytorch_conversion_utils.h"

namespace flashinfer {

@@ -39,12 +40,12 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched(Params params, typename Para

using namespace flashinfer;

std::vector<int64_t> BatchPrefillWithKVCachePlan(
at::Tensor BatchPrefillWithKVCachePlan(
at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr, at::Tensor kv_indptr,
at::Tensor kv_len_arr, unsigned total_num_rows, unsigned int batch_size,
unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int page_size,
bool enable_cuda_graph, unsigned int head_dim_qk, unsigned int head_dim_vo, bool causal,
at::Tensor kv_len_arr, int64_t total_num_rows, int64_t batch_size,
int64_t num_qo_heads, int64_t num_kv_heads, int64_t page_size,
bool enable_cuda_graph, int64_t head_dim_qk, int64_t head_dim_vo, bool causal,
int64_t cuda_stream) {
size_t float_workspace_size_in_bytes =
float_workspace_buffer.size(0) * float_workspace_buffer.element_size();
@@ -64,17 +65,17 @@ std::vector<int64_t> BatchPrefillWithKVCachePlan(
TORCH_CHECK(status == cudaSuccess,
"Failed to plan prefill with error: ", cudaGetErrorString(status));

return plan_info.ToVector();
return vec_to_tensor(plan_info.ToVector());
}

void BatchPrefillWithRaggedKVCacheRun(
at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
std::vector<int64_t> plan_info_vec, at::Tensor q, at::Tensor k, at::Tensor v,
at::Tensor plan_info_vec, at::Tensor q, at::Tensor k, at::Tensor v,
at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor o, std::optional<at::Tensor> maybe_lse,
unsigned int mask_mode_code, unsigned int layout, int32_t window_left ADDITIONAL_FUNC_PARAMS,
int64_t mask_mode_code, int64_t layout, int64_t window_left ADDITIONAL_FUNC_PARAMS,
int64_t cuda_stream) {
PrefillPlanInfo plan_info;
plan_info.FromVector(plan_info_vec);
plan_info.FromVector(tensor_to_vec(plan_info_vec));
QKVLayout kv_layout = static_cast<QKVLayout>(layout);

int64_t num_qo_heads = q.size(1);
@@ -194,13 +195,13 @@ void BatchPrefillWithRaggedKVCacheRun(

void BatchPrefillWithPagedKVCacheRun(
at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
std::vector<int64_t> plan_info_vec, at::Tensor q, at::Tensor paged_k_cache,
at::Tensor plan_info_vec, at::Tensor q, at::Tensor paged_k_cache,
at::Tensor paged_v_cache, at::Tensor qo_indptr, at::Tensor paged_kv_indptr,
at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len, at::Tensor o,
std::optional<at::Tensor> maybe_lse, unsigned int mask_mode_code, unsigned int layout,
int32_t window_left ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream) {
std::optional<at::Tensor> maybe_lse, int64_t mask_mode_code, int64_t layout,
int64_t window_left ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream) {
PrefillPlanInfo plan_info;
plan_info.FromVector(plan_info_vec);
plan_info.FromVector(tensor_to_vec(plan_info_vec));
QKVLayout kv_layout = static_cast<QKVLayout>(layout);
auto device = q.device();
int64_t batch_size = paged_kv_indptr.size(0) - 1;