[WIP, Kernel] (2/N) Machete - Integrate into GPTQMarlinLinearMethod and CompressedTensorsWNA16 #5

Draft · wants to merge 27 commits into main from lwilkinson/machete-end2end

Commits (27):
458d69e · squash-patch changes (LucasWilkinson, Jul 31, 2024)
1ee3608 · remove gptq support (LucasWilkinson, Aug 30, 2024)
ab7507e · formatting + fixes (LucasWilkinson, Aug 30, 2024)
68ff26d · add gptq_marlin support back (LucasWilkinson, Aug 31, 2024)
7b9e8b2 · remove extra prints (LucasWilkinson, Aug 31, 2024)
30f1056 · add machete act ordering (LucasWilkinson, Sep 6, 2024)
3bbb902 · update heuristic (LucasWilkinson, Sep 6, 2024)
196a9f2 · add to tests (LucasWilkinson, Sep 6, 2024)
38f5b84 · update benchmark (LucasWilkinson, Sep 6, 2024)
c59449b · tweak for llama 405b (LucasWilkinson, Sep 6, 2024)
3048911 · env var for disabling kernels (LucasWilkinson, Sep 10, 2024)
df7c4c0 · format + mypy (LucasWilkinson, Sep 11, 2024)
6f3f707 · yapf format (LucasWilkinson, Sep 11, 2024)
90b8e03 · refactor (LucasWilkinson, Sep 11, 2024)
c264c7a · add g_idx back (LucasWilkinson, Sep 11, 2024)
2d25a9a · clean-up (LucasWilkinson, Sep 11, 2024)
62508c5 · review comments (LucasWilkinson, Sep 12, 2024)
84cfdb2 · fix codespell (LucasWilkinson, Sep 12, 2024)
c452a86 · TorchDynamo Compatibility (LucasWilkinson, Sep 13, 2024)
096dd4a · add permute cols opcheck (LucasWilkinson, Sep 13, 2024)
a98f691 · fix correctness test (LucasWilkinson, Sep 16, 2024)
7c02bcf · bug in filtering kernels by compute capability (LucasWilkinson, Sep 16, 2024)
95a85c9 · Merge remote-tracking branch 'origin/main' into lwilkinson/machete-end2end (LucasWilkinson, Sep 20, 2024)
a019473 · add requirements.txt (LucasWilkinson, Sep 20, 2024)
306b283 · Merge branch 'main' into lwilkinson/machete-end2end (mgoin, Sep 21, 2024)
e32bfc5 · [dbrx] refactor dbrx experts to extend FusedMoe class (#8518) (divakar-amd, Sep 21, 2024)
05752e9 · [Kernel][Bugfix] Delete some more useless code in marlin_moe_ops.cu (…) (tlrmchlsmth, Sep 21, 2024)
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -219,6 +219,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/quantization/gguf/gguf_kernel.cu"
"csrc/quantization/fp8/fp8_marlin.cu"
"csrc/custom_all_reduce.cu"
"csrc/permute_cols.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu")
74 changes: 61 additions & 13 deletions benchmarks/kernels/benchmark_machete.py
@@ -4,8 +4,10 @@
import math
import pickle as pkl
import time
from typing import Callable, Iterable, List, Tuple
from itertools import product
from typing import Callable, Iterable, List, Optional, Tuple

import pandas as pd
import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement
@@ -84,6 +86,10 @@ def loop_over_weights(
fn(a, w_ref, w_q, w_s)


_SWEEP_SCHEDULES_RESULTS: Optional[pd.DataFrame] = None
_SWEEP_SCHEDULES_RESULTS_CSV: Optional[str] = None


def bench(atype: torch.dtype,
wtype: ScalarType,
group_size: int,
@@ -94,6 +100,8 @@ def bench(atype: torch.dtype,
sub_label: str,
benchmark_marlinv1: bool = True,
sweep_schedules: bool = True) -> Iterable[TMeasurement]:
global _SWEEP_SCHEDULES_RESULTS

a, weights = make_bench_tensors(atype, wtype, group_size, m, n, k)
sub_label += f", L={len(weights)}"

@@ -163,6 +171,11 @@ def marlinv1_permute_scales(w_s: torch.tensor) -> torch.tensor:
best_schedule = None
schedules = ops.machete_supported_schedules(wtype)
for schedule in reversed(schedules):
schedule_M = int(schedule.split("_")[0].split("x")[1])

# Prune known bad schedules
if schedule_M >= 2 * max(m, 16) or schedule_M < m // 4:
continue
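As a side note, the prune above keys off the tile's M extent encoded in the schedule name. A minimal sketch of the rule; the schedule strings here are hypothetical, only the `<tileN>x<tileM>_` prefix that the code parses matters:

```python
# Sketch of the pruning rule above; the sample schedule names are made up
# for illustration, only the "<N>x<M>_" prefix format is assumed.
def keep_schedule(schedule: str, m: int) -> bool:
    schedule_M = int(schedule.split("_")[0].split("x")[1])
    # Skip tiles far larger or far smaller than the problem's M dimension.
    return schedule_M < 2 * max(m, 16) and schedule_M >= m // 4

assert keep_schedule("128x16_s8", m=16)        # 16 < 32 and 16 >= 4
assert not keep_schedule("128x256_s8", m=16)   # 256 >= 2 * 16 -> pruned
```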

def run(a, _, w_q, w_s, schedule=schedule):
ops.machete_gemm(a,
@@ -175,6 +188,20 @@ def run(a, _, w_q, w_s, schedule=schedule):
res = bench_fn(label, sub_label, "machete_best",
lambda: loop_over_weights(a, weights_machete, run))

results_row = {
"M": m,
"K": k,
"N": n,
"group_size": group_size,
"schedule": schedule,
"median": res.median,
}
if _SWEEP_SCHEDULES_RESULTS is None:
_SWEEP_SCHEDULES_RESULTS = pd.DataFrame(
columns=results_row.keys())
_SWEEP_SCHEDULES_RESULTS.\
loc[len(_SWEEP_SCHEDULES_RESULTS)] = results_row

print(f" {res.median:5.5} ", schedule)
if not best or res.median < best.median:
best = res
@@ -235,18 +262,22 @@ def run_square_bench(args):
dim_sizes = list(
range(args.dim_start, args.dim_end + 1, args.dim_increment))
MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))

data = run(args.dtype, args.sweep_schedules, MKNs)

make_output(data, MKNs, f"square_bench-{args.dtype}")


def run_range_bench(args):
dim_sizes = list(range(args.dim_start, args.dim_end, args.dim_increment))
n = len(dim_sizes)
Ms = [args.m_constant] * n if args.m_constant is not None else dim_sizes
Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes
Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes
MKNs = list(zip(Ms, Ks, Ns))
m_start, k_start, n_start = [int(x) for x in args.dim_start.split(",")]
m_end, k_end, n_end = [int(x) for x in args.dim_end.split(",")]
m_increment, k_increment, n_increment = \
[int(x) for x in args.dim_increment.split(",")]
Ms = list(range(m_start, m_end + 1, m_increment))
Ks = list(range(k_start, k_end + 1, k_increment))
Ns = list(range(n_start, n_end + 1, n_increment))
MKNs = list(product(Ms, Ks, Ns))

data = run(args.dtype, args.sweep_schedules, MKNs)

make_output(data, MKNs, f"range_bench-{args.dtype}")
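For reference, a sketch (not from the PR) of how the new comma-separated flags expand into the benchmarked MKN grid; the flag values below are illustrative:

```python
# How run_range_bench expands its comma-separated flags into an MKN grid;
# the example values are made up for illustration.
from itertools import product

dim_start, dim_end, dim_increment = "1,4096,4096", "8,8192,8192", "1,4096,4096"

m_start, k_start, n_start = [int(x) for x in dim_start.split(",")]
m_end, k_end, n_end = [int(x) for x in dim_end.split(",")]
m_inc, k_inc, n_inc = [int(x) for x in dim_increment.split(",")]

Ms = list(range(m_start, m_end + 1, m_inc))   # [1, 2, ..., 8]
Ks = list(range(k_start, k_end + 1, k_inc))   # [4096, 8192]
Ns = list(range(n_start, n_end + 1, n_inc))   # [4096, 8192]
MKNs = list(product(Ms, Ks, Ns))              # full cross product: 8 * 2 * 2 = 32 shapes
```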
@@ -333,6 +364,9 @@ def to_torch_dtype(dt):
action="store_true",
help="Run a sweep over all supported schedules",
)
parser.add_argument("--sweep-csv-out",
help="CSV to store sweep results",
default="sch_sweep_results.csv")
subparsers = parser.add_subparsers(dest="cmd", required=True)

square_parser = subparsers.add_parser("square_bench")
@@ -342,12 +376,21 @@ def to_torch_dtype(dt):
square_parser.set_defaults(func=run_square_bench)

range_parser = subparsers.add_parser("range_bench")
range_parser.add_argument("--dim-start", type=int, required=True)
range_parser.add_argument("--dim-end", type=int, required=True)
range_parser.add_argument("--dim-increment", type=int, required=True)
range_parser.add_argument("--m-constant", type=int, default=None)
range_parser.add_argument("--n-constant", type=int, default=None)
range_parser.add_argument("--k-constant", type=int, default=None)
range_parser.add_argument(
"--dim-start",
type=str,
required=True,
help="Start value for M,K,N as common separated list")
range_parser.add_argument(
"--dim-end",
type=str,
required=True,
help="End value (inclusive) for M,K,N as common separated list")
range_parser.add_argument(
"--dim-increment",
type=str,
required=True,
help="Increment value for M,K,N as common separated list")
range_parser.set_defaults(func=run_range_bench)

model_parser = subparsers.add_parser("model_bench")
@@ -369,4 +412,9 @@ def to_torch_dtype(dt):
model_parser.set_defaults(func=run_model_bench)

args = parser.parse_args()

_SWEEP_SCHEDULES_RESULTS_CSV = args.sweep_csv_out
args.func(args)

if _SWEEP_SCHEDULES_RESULTS is not None:
_SWEEP_SCHEDULES_RESULTS.to_csv(_SWEEP_SCHEDULES_RESULTS_CSV)
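The sweep CSV written above gets one row per (M, K, N, schedule) with the measured median, so picking a winner per shape is a small pandas exercise. A sketch, assuming the default sch_sweep_results.csv path:

```python
# Sketch: find the fastest schedule per problem shape from the sweep CSV
# written above (column names match the results_row dict in bench()).
import pandas as pd

df = pd.read_csv("sch_sweep_results.csv", index_col=0)
best = df.loc[df.groupby(["M", "K", "N"])["median"].idxmin()]
print(best[["M", "K", "N", "schedule", "median"]].to_string(index=False))
```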
1 change: 1 addition & 0 deletions benchmarks/kernels/requirements.txt
@@ -0,0 +1 @@
pandas
8 changes: 7 additions & 1 deletion csrc/cutlass_extensions/torch_utils.hpp
@@ -68,7 +68,13 @@ static inline auto make_cute_layout(torch::Tensor const& tensor,
name, ".stride(", idx, ") to be ", StrideEle::value);
return StrideEle{};
} else {
return tensor.stride(idx);
if (tensor.size(idx) == 1) {
// use a 0 stride for dims of size 1; this is easier for
// cute/cutlass to optimize (helps the TMA code flatten dims)
return StrideEle{0};
} else {
return tensor.stride(idx);
}
}
} else {
// Extra strides are assumed to be 0 or 1
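The zero-stride trick above is safe because a dimension of extent 1 only ever contributes index 0 to the offset computation, so its stride is never observed. A quick torch illustration (not part of the PR):

```python
# A size-1 dim's stride is arbitrary: swapping it for 0 leaves every
# addressed element unchanged, but yields a dim that is trivial to flatten.
import torch

x = torch.arange(6, dtype=torch.float16).reshape(1, 6)  # stride (6, 1)
y = torch.as_strided(x, size=(1, 6), stride=(0, 1))     # stride (0, 1)
assert torch.equal(x, y)
```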
3 changes: 0 additions & 3 deletions csrc/moe/marlin_moe_ops.cu
@@ -1704,9 +1704,6 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
}

#define CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
\
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
2 changes: 2 additions & 0 deletions csrc/ops.h
@@ -113,6 +113,8 @@ torch::Tensor prepack_B(torch::Tensor const& B,

}; // namespace machete

torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm);

torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_meta,
torch::Tensor& b_scales,
88 changes: 88 additions & 0 deletions csrc/permute_cols.cu
@@ -0,0 +1,88 @@
#include <torch/all.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <cuda_fp16.h>

static constexpr int default_threads = 256;
static constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }

// For a given "a" of size [M,K], permutes the K columns according to the
// given "perm" indices.
// Currently only supports 16-bit types (since we permute half types).
__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
int const* __restrict__ perm_int_ptr,
int4* __restrict__ out_int4_ptr, int size_m,
int size_k, int block_rows) {
int start_row = block_rows * blockIdx.x;
int finish_row = start_row + block_rows;
if (finish_row > size_m) {
finish_row = size_m;
}
int cur_block_rows = std::max(finish_row - start_row, 0);

// row stride in int4 (16-byte) units: size_k halves * 2 bytes / 16
int row_stride = size_k * sizeof(half) / 16;

auto permute_row = [&](int row) {
int iters = size_k / default_threads;
int rest = size_k % default_threads;

int offset = row * row_stride;

half const* a_row_half = reinterpret_cast<half const*>(a_int4_ptr + offset);
half* out_half = reinterpret_cast<half*>(out_int4_ptr + offset);

int base_k = 0;

for (int i = 0; i < iters; i++) {
int cur_k = base_k + threadIdx.x;
int src_pos = perm_int_ptr[cur_k];

out_half[cur_k] = a_row_half[src_pos];

base_k += default_threads;
}

if (rest) {
if (threadIdx.x < rest) {
int cur_k = base_k + threadIdx.x;
int src_pos = perm_int_ptr[cur_k];

out_half[cur_k] = a_row_half[src_pos];
}
}
};

for (int i = 0; i < cur_block_rows; i++) {
int cur_row = start_row + i;
if (cur_row < size_m) {
permute_row(cur_row);
}
}
}

// More efficient version of A[..., perm]
// taken from gptq_marlin.cu
torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
auto dev = A.get_device();
auto stream = at::cuda::getCurrentCUDAStream(dev);

TORCH_CHECK(A.scalar_type() == at::kHalf || A.scalar_type() == at::kBFloat16,
"Currently only 16bit types are supported");
TORCH_CHECK(A.is_contiguous(), "A must be contiguous");
TORCH_CHECK(A.size(-1) % 8 == 0,
"A columns must be a multiple of 8 (128bits)");
auto A_2d = A.view({-1, A.size(-1)});

torch::Tensor D = torch::empty_like(A);
int sms;
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
int block_rows = div_ceil(A_2d.size(0), sms);
permute_cols_kernel<<<sms, default_threads, 0, stream>>>(
reinterpret_cast<int4 const*>(A_2d.const_data_ptr()),
perm.const_data_ptr<int>(), reinterpret_cast<int4*>(D.mutable_data_ptr()),
A_2d.size(0), A_2d.size(1), block_rows);
return D;
}
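For reference, a minimal correctness sketch against the A[..., perm] semantics noted above. The torch.ops._C.permute_cols binding name is an assumption here; the actual op registration lives outside this diff:

```python
# Correctness sketch for the kernel above. The op name below is assumed
# for illustration; the real binding is registered elsewhere in the PR.
import torch

A = torch.randn(32, 128, dtype=torch.float16, device="cuda")
perm = torch.randperm(128, device="cuda").to(torch.int32)  # kernel takes int32 indices

ref = A[:, perm.long()]                   # the reference semantics: A[..., perm]
out = torch.ops._C.permute_cols(A, perm)  # assumed binding name
assert torch.equal(out, ref)
```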