forked from vllm-project/vllm
-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Created the entry + arch structure for the compressor and ignore 2to4…
… tests for >90 sm capability
- Loading branch information
Showing
13 changed files
with
90 additions
and
54 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
#include "cutlass_extensions/common.hpp" | ||
|
||
int32_t get_sm_version_num() { | ||
int32_t major_capability, minor_capability; | ||
cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor, | ||
0); | ||
cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor, | ||
0); | ||
int32_t version_num = major_capability * 10 + minor_capability; | ||
return version_num; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
#include <cudaTypedefs.h> | ||
|
||
#include <c10/cuda/CUDAGuard.h> | ||
#include <torch/all.h> | ||
|
||
#include "cutlass_extensions/common.hpp" | ||
|
||
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X | ||
bool cutlass_sparse_compress_sm90(torch::Tensor& a_nzs, torch::Tensor& a_meta, | ||
torch::Tensor const& a); | ||
#endif | ||
|
||
bool cutlass_sparse_compress_entry(torch::Tensor& a_nzs, torch::Tensor& a_meta, | ||
torch::Tensor const& a) { | ||
// Checks for conformality | ||
TORCH_CHECK(a.dim() == 2 && a_meta.dim() == 2 && a_nzs.dim() == 2); | ||
TORCH_CHECK(a.size(0) == a_nzs.size(0) && a.size(0) == a_meta.size(0) && | ||
a_nzs.size(1) * 2 == a.size(1) && | ||
a_meta.size(1) * 2 * 4 == a.size(1)); | ||
// Considering elemsPerMetaElem = 8b / 2b_per_nz = 4 | ||
|
||
// Check for strides and alignment | ||
TORCH_CHECK(a.stride(1) == 1 && a_nzs.stride(1) == 1 && | ||
a_meta.stride(1) == 1); // Row-major | ||
TORCH_CHECK(a.stride(0) % 8 == 0); // 8 Byte Alignment for Compression | ||
|
||
at::cuda::OptionalCUDAGuard const device_guard(device_of(a)); | ||
int32_t version_num = get_sm_version_num(); | ||
|
||
// Guard against compilation issues for sm90 kernels | ||
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X | ||
if (version_num >= 90) { | ||
return cutlass_sparse_compress_sm90(a_nzs, a_meta, a); | ||
} | ||
#endif | ||
|
||
TORCH_CHECK_NOT_IMPLEMENTED( | ||
false, | ||
"No compiled cutlass_scaled_sparse_mm for a compute capability less than " | ||
"CUDA device capability: ", | ||
version_num); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters