Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP, Kernel] (3/N) Machete W4A8 #8046

Draft
wants to merge 16 commits into
base: main
Choose a base branch
from
511 changes: 379 additions & 132 deletions benchmarks/kernels/benchmark_machete.py

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions benchmarks/kernels/graph_machete_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@
args = parser.parse_args()

with open(args.filename, 'rb') as f:
data: List[TMeasurement] = pickle.load(f)
data = pickle.load(f)
raw_results: List[TMeasurement] = data["results"]

results = defaultdict(lambda: list())
for v in data:
for v in raw_results:
result = re.search(r"MKN=\(\d+x(\d+x\d+)\)", v.task_spec.sub_label)
if result is not None:
KN = result.group(1)
Expand Down
6 changes: 6 additions & 0 deletions benchmarks/kernels/weight_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,10 @@
([8192, 57344], 1),
([28672, 8192], 0),
],
"meta-llama/Llama-3.1-405b-hf": [
([16384, 18432], 1),
([16384, 16384], 0),
([16384, 106496], 1),
([53248, 16384], 0),
],
}
4 changes: 2 additions & 2 deletions csrc/cutlass_extensions/cute_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ CUTE_HOST_DEVICE static constexpr auto permute_layout(Layout l) {
// is the layout f(x) = x
template <typename Layout>
CUTE_HOST_DEVICE static constexpr bool is_identity_layout() {
if constexpr (std::is_same_v<Layout, void>)
if constexpr (std::is_same_v<Layout, void>) {
return true;
else {
} else {
constexpr auto coalesced_layout = coalesce(Layout{});
if constexpr (rank(coalesced_layout) == 1 &&
stride<0>(coalesced_layout) == 1) {
Expand Down
29 changes: 29 additions & 0 deletions csrc/cutlass_extensions/vllm_cutlass_library_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,35 @@ class MixedInputKernelScheduleType(enum.Enum):
}
}

VLLMDataTypeSize: Dict[Union[VLLMDataType, DataType], int] = {
**DataTypeSize, # type: ignore
**{
VLLMDataType.u4b8: 4,
VLLMDataType.u8b128: 8,
}
}

VLLMDataTypeVLLMScalarTypeTag: Dict[Union[VLLMDataType, DataType], str] = {
VLLMDataType.u4b8: "vllm::kU4B8",
VLLMDataType.u8b128: "vllm::kU8B128",
DataType.u4: "vllm::kU4",
DataType.u8: "vllm::kU8",
DataType.s4: "vllm::kS4",
DataType.s8: "vllm::kS8",
DataType.f16: "vllm::kFloat16",
DataType.bf16: "vllm::kBfloat16",
}

VLLMDataTypeTorchDataTypeTag: Dict[Union[VLLMDataType, DataType], str] = {
DataType.u8: "at::ScalarType::Byte",
DataType.s8: "at::ScalarType::Char",
DataType.e4m3: "at::ScalarType::Float8_e4m3fn",
DataType.s32: "at::ScalarType::Int",
DataType.f16: "at::ScalarType::Half",
DataType.bf16: "at::ScalarType::BFloat16",
DataType.f32: "at::ScalarType::Float",
}

VLLMKernelScheduleTag: Dict[Union[
MixedInputKernelScheduleType, KernelScheduleType], str] = {
**KernelScheduleTag, # type: ignore
Expand Down
Loading
Loading