refactor: change to TORCH_LIBRARY (#823)
This PR migrates FlashInfer's C++/CUDA extensions from pybind11 modules to `torch.libraries`, the approach recommended since PyTorch 2.5. Most of the work was implemented in #764. We have confirmed that the issue in #820 was not caused by this PR, so we are reopening it.

---------

Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: abmfy <abmfy@icloud.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: Zihao Ye <expye@outlook.com>
1 parent c716aed, commit dbb1e4e
Showing 53 changed files with 503 additions and 386 deletions.
@@ -1,20 +1,20 @@
 #include "mla_config.inc"
 #include "pytorch_extension_utils.h"
 
-std::vector<int64_t> BatchDecodeWithPagedKVCachePlanMLA(
+at::Tensor BatchDecodeWithPagedKVCachePlanMLA(
     at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
-    at::Tensor page_locked_int_workspace_buffer, at::Tensor indptr, unsigned int batch_size,
-    unsigned int num_qo_heads, unsigned int page_size, bool enable_cuda_graph, int64_t cuda_stream);
+    at::Tensor page_locked_int_workspace_buffer, at::Tensor indptr, int64_t batch_size,
+    int64_t num_qo_heads, int64_t page_size, bool enable_cuda_graph, int64_t cuda_stream);
 
 void BatchDecodeWithPagedKVCacheRunMLA(
     at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
-    std::vector<int64_t> plan_info_vec, at::Tensor q_nope, at::Tensor q_pe,
+    at::Tensor plan_info_vec, at::Tensor q_nope, at::Tensor q_pe,
     at::Tensor paged_ckv_cache, at::Tensor paged_kpe_cache, at::Tensor paged_kv_indptr,
-    at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len, at::Tensor o, float sm_scale,
-    int window_left, float logits_soft_cap, float rope_scale, float rope_theta,
+    at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len, at::Tensor o, double sm_scale,
+    int64_t window_left, double logits_soft_cap, double rope_scale, double rope_theta,
     std::optional<at::Tensor> maybe_lse, int64_t cuda_stream);
 
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("plan", &BatchDecodeWithPagedKVCachePlanMLA);
-  m.def("run", &BatchDecodeWithPagedKVCacheRunMLA);
+TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
+  m.def("plan", BatchDecodeWithPagedKVCachePlanMLA);
+  m.def("run", BatchDecodeWithPagedKVCacheRunMLA);
 }
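
The signature changes in this hunk widen every scalar parameter: `unsigned int` and `int` become `int64_t`, and `float` becomes `double`. That matches the scalar types that operators registered through `TORCH_LIBRARY` accept (`int64_t`, `double`, `bool`). The diff does not show how the function bodies handle the wider values, so the following is only a sketch of one way a wrapper could narrow them back safely, assuming the underlying kernels still take 32-bit arguments. `checked_narrow`, `PlanArgs`, and `narrow_plan_args` are hypothetical names used for illustration; they are not part of FlashInfer or PyTorch.

```cpp
#include <cstdint>
#include <limits>
#include <stdexcept>
#include <string>
#include <type_traits>

// Hypothetical helper (not from FlashInfer): narrow an int64_t received at the
// operator boundary to a smaller integer type, throwing if the value does not fit.
template <typename To>
To checked_narrow(int64_t value, const char* name) {
  static_assert(std::is_integral<To>::value, "To must be an integer type");
  static_assert(sizeof(To) < sizeof(int64_t), "To must be narrower than int64_t");
  // Both limits of a narrower integer type are representable as int64_t.
  const int64_t lo = static_cast<int64_t>(std::numeric_limits<To>::min());
  const int64_t hi = static_cast<int64_t>(std::numeric_limits<To>::max());
  if (value < lo || value > hi) {
    throw std::out_of_range(std::string(name) + " is out of range");
  }
  return static_cast<To>(value);
}

// Hypothetical container for the narrowed plan parameters.
struct PlanArgs {
  uint32_t batch_size;
  uint32_t num_qo_heads;
  uint32_t page_size;
};

// Mirrors the three integer parameters of the new plan signature and converts
// them to the 32-bit types used by the old signature.
PlanArgs narrow_plan_args(int64_t batch_size, int64_t num_qo_heads, int64_t page_size) {
  return PlanArgs{
      checked_narrow<uint32_t>(batch_size, "batch_size"),
      checked_narrow<uint32_t>(num_qo_heads, "num_qo_heads"),
      checked_narrow<uint32_t>(page_size, "page_size"),
  };
}
```

With a helper like this, a negative or oversized value passed from Python raises an exception at the boundary instead of silently wrapping around in a plain `static_cast`.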