
Commit b177ab6
Fix the scale swap bug
Faraz9877 committed Dec 13, 2024
1 parent 4c927a0 commit b177ab6
Showing 2 changed files with 6 additions and 7 deletions.
9 changes: 4 additions & 5 deletions benchmarks/cutlass_benchmarks/sparse_benchmarks.py
@@ -100,12 +100,11 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
     assert dtype == torch.float8_e4m3fn
     b_compressed, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n,
                                                      k)
-    scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
-    scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
-    bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)
+    scale_a = (torch.randn((m, 1), device="cuda", dtype=torch.float32))
+    scale_b = (torch.randn((1, n), device="cuda", dtype=torch.float32))
+    bias = torch.rand((n, ), device="cuda", dtype=torch.bfloat16) * 10

-    out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b,
-                                       torch.bfloat16)
+    out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b, torch.bfloat16)
Check failure on line 107 in benchmarks/cutlass_benchmarks/sparse_benchmarks.py
GitHub Actions / ruff (3.12): Ruff (E501) Line too long (92 > 80)
(The failure is triggered by the new single-line cutlass_scaled_sparse_mm call above; the deleted two-line version stayed within the 80-column limit.)
     out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)

     if not torch.allclose(out, out_ref):
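
Why the new benchmark inputs matter: with the old all-ones scalar scales, exchanging the two scale arguments inside the kernel is a no-op, so the allclose check above could pass even with the scales swapped. Random per-row and per-column scales break that symmetry. A minimal standalone sketch (plain PyTorch matmul rather than vLLM's ops, assuming the usual scaled-mm semantics out = scale_a * (A @ B) * scale_b):

import torch

m, k, n = 8, 16, 4
a = torch.randn(m, k)
b = torch.randn(k, n)

# Old inputs: uniform scalar scales. Exchanging them changes nothing,
# so a kernel that swapped its scale arguments still matched the
# reference.
one_a = torch.tensor(1.0)
one_b = torch.tensor(1.0)
assert torch.equal(one_a * (a @ b) * one_b, one_b * (a @ b) * one_a)

# New inputs: per-row (m, 1) and per-column (1, n) scales. A swapped
# kernel applies row scales per column and vice versa; with m != n that
# is not even shape-compatible, and with m == n it is numerically wrong:
sq = torch.randn(4, 4)
row_s, col_s = torch.randn(4, 1), torch.randn(1, 4)
assert not torch.allclose(row_s * sq * col_s, row_s.T * sq * col_s.T)

Replacing the all-zero bias with torch.rand(...) * 10 hardens the check in the same way for the bias epilogue.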
4 changes: 2 additions & 2 deletions csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu
@@ -294,10 +294,10 @@ void cutlass_scaled_sparse_mm_sm90(torch::Tensor& out, torch::Tensor const& a,
     TORCH_CHECK(bias->dtype() == out.dtype(),
                 "currently bias dtype must match output dtype ", out.dtype());
     return cutlass_scaled_sparse_mm_sm90_epilogue<c3x::ScaledEpilogueBias>(
-        out, a, bt_nzs, bt_meta, a_scales, b_scales, *bias);
+        out, a, bt_nzs, bt_meta, b_scales, a_scales, *bias);
   } else {
     return cutlass_scaled_sparse_mm_sm90_epilogue<c3x::ScaledEpilogue>(
-        out, a, bt_nzs, bt_meta, a_scales, b_scales);
+        out, a, bt_nzs, bt_meta, b_scales, a_scales);
   }
 }
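
The direction of the swap is consistent with the kernel computing the transposed problem: bt_nzs/bt_meta hold a compressed, transposed B, so the underlying GEMM's operand-A scales are the logical b_scales and its operand-B scales are a_scales. A small PyTorch identity (a sketch under that assumption, not vLLM code) shows why the scales trade places when the GEMM is transposed:

import torch

m, k, n = 8, 16, 4
a = torch.randn(m, k)
b = torch.randn(k, n)
scale_a = torch.randn(m, 1)  # per-row scales of the logical output
scale_b = torch.randn(1, n)  # per-column scales of the logical output

logical = scale_a * (a @ b) * scale_b

# For the transposed GEMM out^T = b^T @ a^T, b's scales become the row
# (operand-A) scales and a's scales the column (operand-B) scales:
transposed = scale_b.T * (b.T @ a.T) * scale_a.T

assert torch.allclose(logical, transposed.T, rtol=1e-4, atol=1e-5)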

