Push activations and output transposes into CUTLASS code
Faraz9877 committed Dec 13, 2024
1 parent ac059b4 commit b559b6a
Showing 12 changed files with 236 additions and 165 deletions.
3 changes: 1 addition & 2 deletions benchmarks/benchmark_throughput.py
@@ -361,8 +361,7 @@ def main(args: argparse.Namespace):
     # TODO(vllm-project/vllm/issues/9778): Count multi-modal token length.
     print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
           f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
-          f"{total_output_tokens / elapsed_time:.2f} output tokens/s, "
-          f"{total_num_tokens=} | {total_output_tokens=}")
+          f"{total_output_tokens / elapsed_time:.2f} output tokens/s")
 
     # Output JSON results if specified
     if args.output_json:
20 changes: 20 additions & 0 deletions benchmarks/cutlass_benchmarks/sparse_benchmarks.py
@@ -46,6 +46,16 @@ def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
     scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
     bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)
 
+    out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b, torch.bfloat16)

Check failure on line 49 in benchmarks/cutlass_benchmarks/sparse_benchmarks.py (GitHub Actions / ruff (3.12)): Ruff E501, Line too long (92 > 80).

+
+    out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)
+
+    if not torch.allclose(out, out_ref):
+        print("Incorrect results")
+        print(out)
+        print(out_ref)
+    else:
+        print("Correct results")
 
     timers = []
     # pytorch impl - bfloat16
     timers.append(
@@ -95,6 +105,16 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
     scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
     bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)
 
+    out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b, torch.bfloat16)

Check failure on line 108 in benchmarks/cutlass_benchmarks/sparse_benchmarks.py (GitHub Actions / ruff (3.12)): Ruff E501, Line too long (92 > 80).

+
+    out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)
+
+    if not torch.allclose(out, out_ref):
+        print("Incorrect results")
+        print(out)
+        print(out_ref)
+    else:
+        print("Correct results")
 
     timers = []
 
     # pytorch impl w. bf16
9 changes: 6 additions & 3 deletions benchmarks/cutlass_benchmarks/utils.py
@@ -62,8 +62,11 @@ def prune_to_2_4(tensor):
 
 def make_rand_sparse_tensors(dtype: torch.dtype, m: int, n: int,
                              k: int) -> Tuple[torch.Tensor, torch.Tensor]:
-    a = torch.randn((m, k), device='cuda') * 5
-    b = torch.randn((n, k), device='cuda').t() * 5
+    # a = torch.randn((m, k), device='cuda') * 5
+    # b = torch.randn((n, k), device='cuda').t() * 5
+
+    a = torch.ones((m, k), device='cuda')
+    b = torch.ones((n, k), device='cuda').t()
 
     b = prune_to_2_4(b.t()).t()
 
@@ -78,7 +81,7 @@ def make_rand_sparse_tensors(dtype: torch.dtype, m: int, n: int,
     else:
         raise ValueError("unsupported dtype")
 
-    b_compressed, e = ops.cutlass_compress_entry(b.t())
+    b_compressed, e = ops.cutlass_sparse_compress(b.t())
 
     # Compressed B, Metadata, Original A, B
     return b_compressed, e, a, b
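For context, prune_to_2_4 (defined earlier in this file, outside the hunk) is what gives b the 2:4 structured-sparsity pattern the compressor expects: in every contiguous group of four values, at most two are nonzero. A minimal magnitude-based sketch of such a helper, assuming grouping along the last dimension (the real helper in this file may differ):

import torch

def prune_to_2_4_sketch(t: torch.Tensor) -> torch.Tensor:
    # View the last dimension as contiguous groups of 4 elements.
    groups = t.reshape(-1, 4)
    # Locate the two smallest-magnitude entries in each group...
    drop = groups.abs().argsort(dim=1)[:, :2]
    # ...and zero them, leaving exactly 2 of every 4 values nonzero.
    return groups.scatter(1, drop, 0).reshape(t.shape)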
4 changes: 2 additions & 2 deletions csrc/ops.h
@@ -156,12 +156,12 @@ void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
                           c10::optional<torch::Tensor> const& bias);
 
 void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a,
-                              torch::Tensor const& e, torch::Tensor const& b,
+                              torch::Tensor const& b, torch::Tensor const& e,
                               torch::Tensor const& a_scales,
                               torch::Tensor const& b_scales,
                               c10::optional<torch::Tensor> const& bias);
 
-bool cutlass_compress_entry(torch::Tensor& a_compressed, torch::Tensor& e,
+bool cutlass_sparse_compress(torch::Tensor& a_compressed, torch::Tensor& e,
                             torch::Tensor const& a);
 #endif

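At the Python level, the two renamed entry points compose as in the benchmarks above. A minimal sketch, assuming the binding path used in benchmarks/cutlass_benchmarks and an fp8 dtype as in bench_fp8 (the helper import and shapes are illustrative):

import torch
from vllm import _custom_ops as ops          # binding path as in the benchmarks
from utils import make_rand_sparse_tensors   # benchmarks/cutlass_benchmarks/utils.py

# Helper shown above: returns the compressed 2:4 weight, its metadata,
# and the dense operands.
b_compressed, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, 512, 512, 512)

scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)

# Argument order after this commit: compressed values (b) precede the
# metadata (e), matching the reordered C++ declaration above.
out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b,
                                   torch.bfloat16)
out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)
assert torch.allclose(out, out_ref)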
7 changes: 2 additions & 5 deletions csrc/sparse/cutlass/sparse_compressor.cu
@@ -73,9 +73,6 @@ bool sparsify_and_compress(torch::Tensor& a_compressed, torch::Tensor& e,
   using ElementAB = typename Gemm::ElementAB;
   using ElementD = typename Gemm::ElementD;
 
-  // Just a dummy value
-  int32_t n = 128;
-
   int64_t lda = a.stride(0);
 
   using StrideA = Stride<int64_t, Int<1>, int64_t>;
@@ -85,7 +82,7 @@ bool sparsify_and_compress(torch::Tensor& a_compressed, torch::Tensor& e,
   StrideA a_stride{lda, Int<1>{}, 0};
 
   using GemmKernel = typename Gemm::GemmKernel;
-  typename GemmKernel::ProblemShape prob_shape{m, n, k, 1};
+  typename GemmKernel::ProblemShape prob_shape{m, 1, k, 1};
 
   using LayoutA = typename GemmKernel::CollectiveMainloop::LayoutA;
   using LayoutE = typename GemmKernel::CollectiveMainloop::LayoutE;
@@ -155,7 +152,7 @@ bool sparsify_and_compress(torch::Tensor& a_compressed, torch::Tensor& e,
   return true;
 }
 
-bool cutlass_compress_entry(torch::Tensor& a_compressed, torch::Tensor& e,
+bool cutlass_sparse_compress(torch::Tensor& a_compressed, torch::Tensor& e,
                              torch::Tensor const& a) {
   if (a.dtype() == torch::kBFloat16) {
     return sparsify_and_compress<cutlass::bfloat16_t>(a_compressed, e, a);
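One note on the ProblemShape change: sparsify_and_compress only ever reads the A operand (m x k), so the GEMM's N extent never influences compression; the commit therefore replaces the arbitrary placeholder n = 128 with 1. A quick Python-level sanity check of the compressor's output size, as a sketch (it assumes the packed value tensor of a 2:4-compressed m x k matrix keeps exactly half the elements; the layout of the metadata tensor e is CUTLASS-specific and not asserted here):

import torch
from vllm import _custom_ops as ops  # binding path as in the benchmarks

m, k = 128, 256
a = torch.ones((m, k), device="cuda", dtype=torch.bfloat16)
# Zero the last two elements of every group of four for a valid 2:4 pattern.
a[:, 2::4] = 0
a[:, 3::4] = 0

a_compressed, e = ops.cutlass_sparse_compress(a)
# 2:4 compression keeps exactly half of the values, independent of any N.
assert a_compressed.numel() == m * k // 2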
