Skip to content

Commit

Permalink
squash-patch changes
Browse files Browse the repository at this point in the history
move heuristic into C++ code

fix unit tests + format

update for 3.5.1

remove custom scheduler

codespell

cleanup comment

cleanup diff

review comments

review comments

review comment changes

review comments

fix codespell

cleanup util logic

make dim names for prepack layout more canoncial

missed refactor

wip

interleaving + recasting

tweak tolerances

comments plus interleaving

format

codespell

review comments

end2end first pass

seperate out kernels, format

add machete as a gptq backend

update to use  ModelWeightParameter

formatting

update parameter.py

refactor permute layout

wip
  • Loading branch information
LucasWilkinson committed Aug 30, 2024
1 parent 1248e85 commit 735259b
Showing 16 changed files with 630 additions and 161 deletions.
3 changes: 2 additions & 1 deletion csrc/quantization/machete/machete_mm_kernel.cuh
Original file line number Diff line number Diff line change
@@ -152,7 +152,8 @@ struct MacheteKernelTemplate {

int M = size<0>(layout_A), N = size<1>(layout_D), K = size<1>(layout_A);

int const group_size = maybe_group_size.value_or(K);
int group_size = maybe_group_size.value_or(K);
group_size = (group_size == -1) ? K : group_size;
int const scale_k = (K + group_size - 1) / group_size;

TORCH_CHECK(size<0>(layout_A) == M && size<1>(layout_A) == K);
2 changes: 1 addition & 1 deletion csrc/quantization/machete/machete_prepack_launcher.cuh
Original file line number Diff line number Diff line change
@@ -53,7 +53,7 @@ torch::Tensor prepack_impl(torch::Tensor const B) {
// clang-format on

// Allocate output
torch::Tensor D = torch::empty_like(B);
torch::Tensor D = torch::empty_like(B, {}, at::MemoryFormat::Contiguous);

prepack_B<PrepackedLayoutB>(stream, B_ptr, layout_Bt,
static_cast<ElementB*>(D.mutable_data_ptr()));
9 changes: 5 additions & 4 deletions vllm/model_executor/layers/quantization/awq_marlin.py
Original file line number Diff line number Diff line change
@@ -7,10 +7,11 @@
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported,
marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales,
replace_tensor, verify_marlin_supported, verify_marlin_supports_shape)
verify_marlin_supported, verify_marlin_supports_shape)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
PackedvLLMParameter)
@@ -231,23 +232,23 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
size_k=layer.input_size_per_partition,
size_n=layer.output_size_per_partition,
num_bits=self.quant_config.quant_type.size_bits)
replace_tensor(layer, "qweight", marlin_qweight)
replace_parameter(layer, "qweight", marlin_qweight)

# Permute scales from AWQ format to marlin format.
marlin_scales = marlin_permute_scales(
layer.scales,
size_k=layer.input_size_per_partition,
size_n=layer.output_size_per_partition,
group_size=self.quant_config.group_size)
replace_tensor(layer, "scales", marlin_scales)
replace_parameter(layer, "scales", marlin_scales)

# Permute zero-points from AWQ format to marlin format.
marlin_zp = awq_to_marlin_zero_points(
layer.qzeros,
size_k=layer.num_groups,
size_n=layer.output_size_per_partition,
num_bits=self.quant_config.quant_type.size_bits)
replace_tensor(layer, "qzeros", marlin_zp)
replace_parameter(layer, "qzeros", marlin_zp)

# Not-used
layer.g_idx = marlin_make_empty_g_idx(device)
Original file line number Diff line number Diff line change
@@ -2,13 +2,10 @@

import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
apply_gptq_marlin_linear, marlin_make_empty_g_idx, marlin_make_workspace,
marlin_permute_scales, replace_tensor, verify_marlin_supported,
verify_marlin_supports_shape)
from vllm.model_executor.layers.quantization.kernels import (
MPLinearLayerConfig, choose_mp_linear_kernel)
from vllm.model_executor.parameter import (BasevLLMParameter,
ChannelQuantScaleParameter,
GroupQuantScaleParameter,
@@ -46,23 +43,32 @@ def __init__(self,

self.quant_type = WNA16_SUPPORTED_TYPES_MAP[num_bits]

# Verify supported on platform.
verify_marlin_supported(quant_type=self.quant_type,
group_size=self.group_size)

@classmethod
def get_min_capability(cls) -> int:
# ampere and up
return 80

def create_weights(self, layer: torch.nn.Module, input_size: int,
output_partition_sizes: List[int],
def create_weights(self, layer: torch.nn.Module, output_size: int,
input_size: int, output_partition_sizes: List[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):

output_size_per_partition = sum(output_partition_sizes)

mp_linear_kernel_config = MPLinearLayerConfig(
full_weight_shape=(input_size, output_size),
partition_weight_shape=\
(input_size_per_partition, output_size_per_partition),
weight_type=self.quant_type,
act_type=params_dtype,
group_size=self.group_size,
zero_points=False,
act_reordering=False
)

kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)

# If group_size is -1, we are in channelwise case.
channelwise = (self.group_size == -1)
group_size = self.group_size if self.group_size != -1 else input_size
@@ -71,12 +77,6 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
# scales across all gpus.
partition_scales = (row_parallel and not channelwise)

verify_marlin_supports_shape(
output_size_per_partition=output_size_per_partition,
input_size_per_partition=input_size_per_partition,
input_size=input_size,
group_size=group_size)

scales_and_zp_size = input_size // group_size

if partition_scales:
@@ -123,62 +123,17 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
layer.register_parameter("weight_scale", weight_scale)
layer.register_parameter("weight_shape", weight_shape)

layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
layer.input_size = input_size
layer.group_size = group_size
self.kernel = kernel_type(mp_linear_kernel_config,
w_q_param_name="weight_packed",
w_s_param_name="weight_scale",
w_zp_param_name=None,
w_gidx_param_name=None)

# Checkpoints are serialized in compressed-tensors format, which is
# different from marlin format. Handle repacking here.
# different from the format the kernel may want. Handle repacking here.
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
device = layer.weight_packed.device

# Allocate marlin workspace.
layer.workspace = marlin_make_workspace(
layer.output_size_per_partition, device)

# Act-order not supported in compressed-tensors yet, so set to empty.
layer.g_idx = marlin_make_empty_g_idx(device)
layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)

# No zero-point
layer.weight_zp = marlin_make_empty_g_idx(device)
# Update for kernel
layer.weight_packed = torch.nn.Parameter(
layer.weight_packed.t().contiguous(), requires_grad=False)
layer.weight_scale = torch.nn.Parameter(
layer.weight_scale.squeeze().t().contiguous(), requires_grad=False)

# Repack weights from compressed-tensors format to marlin format.
marlin_qweight = ops.gptq_marlin_repack(
layer.weight_packed,
perm=layer.g_idx_sort_indices,
size_k=layer.input_size_per_partition,
size_n=layer.output_size_per_partition,
num_bits=self.quant_type.size_bits)
replace_tensor(layer, "weight_packed", marlin_qweight)

# Permute scales from compressed-tensors format to marlin format.
marlin_scales = marlin_permute_scales(
layer.weight_scale,
size_k=layer.input_size_per_partition,
size_n=layer.output_size_per_partition,
group_size=layer.group_size)
replace_tensor(layer, "weight_scale", marlin_scales)
self.kernel.process_weights_after_loading(layer)

def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:

return apply_gptq_marlin_linear(
input=x,
weight=layer.weight_packed,
weight_scale=layer.weight_scale,
weight_zp=layer.weight_zp,
g_idx=layer.g_idx,
g_idx_sort_indices=layer.g_idx_sort_indices,
workspace=layer.workspace,
wtype=self.quant_type,
output_size_per_partition=layer.output_size_per_partition,
input_size_per_partition=layer.input_size_per_partition,
is_k_full=True,
bias=bias)
return self.kernel.apply_weights(layer, x, bias)
99 changes: 25 additions & 74 deletions vllm/model_executor/layers/quantization/gptq_marlin.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
from typing import Any, Dict, List, Optional

import torch
from torch.nn import Parameter

from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.kernels import (
MPLinearLayerConfig, choose_mp_linear_kernel)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
apply_gptq_marlin_linear, check_marlin_supported, marlin_is_k_full,
marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales,
marlin_repeat_scales_on_all_ranks, marlin_sort_g_idx, replace_tensor,
verify_marlin_supported, verify_marlin_supports_shape)
check_marlin_supported, marlin_repeat_scales_on_all_ranks,
verify_marlin_supported)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
GroupQuantScaleParameter,
@@ -163,24 +161,29 @@ def create_weights(
params_dtype: torch.dtype,
**extra_weight_attrs,
) -> None:

del output_size
output_size_per_partition = sum(output_partition_sizes)
is_row_parallel = input_size != input_size_per_partition
weight_loader = extra_weight_attrs.get("weight_loader")

mp_linear_kernel_config = MPLinearLayerConfig(
full_weight_shape=(input_size, output_size),
partition_weight_shape=\
(input_size_per_partition, output_size_per_partition),
weight_type=self.quant_config.quant_type,
act_type=params_dtype,
group_size=self.quant_config.group_size,
zero_points=False,
act_reordering=self.quant_config.desc_act
)

kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)

# Normalize group_size
if self.quant_config.group_size != -1:
group_size = self.quant_config.group_size
else:
group_size = input_size

verify_marlin_supports_shape(
output_size_per_partition=output_size_per_partition,
input_size_per_partition=input_size_per_partition,
input_size=input_size,
group_size=group_size)

# Determine sharding
if marlin_repeat_scales_on_all_ranks(self.quant_config.desc_act,
self.quant_config.group_size,
@@ -261,72 +264,20 @@ def create_weights(
layer.register_parameter("g_idx", g_idx)
layer.register_parameter("scales", scales)
layer.register_parameter("qzeros", qzeros)
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
layer.input_size = input_size
layer.is_k_full = marlin_is_k_full(self.quant_config.desc_act,
is_row_parallel)

# Checkpoints are serialized in AutoGPTQ format, which is different from the
# marlin format. This function is called after the weights are loaded.
# Here, we handle the repacking, including the activation reordering case.
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
device = layer.qweight.device

# required by torch.compile
layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
layer.scales = Parameter(layer.scales.data, requires_grad=False)

# Allocate marlin workspace
layer.workspace = marlin_make_workspace(
layer.output_size_per_partition, device)
self.kernel = kernel_type(mp_linear_kernel_config,
w_q_param_name="qweight",
w_s_param_name="scales",
w_zp_param_name="qzeros",
w_gidx_param_name="g_idx")

# Handle sorting for activation reordering if needed.
if self.quant_config.desc_act:
g_idx, g_idx_sort_indices = marlin_sort_g_idx(layer.g_idx)
layer.g_idx_sort_indices = g_idx_sort_indices
replace_tensor(layer, "g_idx", g_idx)
else:
layer.g_idx = marlin_make_empty_g_idx(device)
layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)

# No zero-point
layer.zp = marlin_make_empty_g_idx(device)

# Repack weights from autogptq format to marlin format.
marlin_qweight = ops.gptq_marlin_repack(
layer.qweight,
perm=layer.g_idx_sort_indices,
size_k=layer.input_size_per_partition,
size_n=layer.output_size_per_partition,
num_bits=self.quant_config.quant_type.size_bits)
replace_tensor(layer, "qweight", marlin_qweight)

# Permute scales from autogptq format to marlin format.
marlin_scales = marlin_permute_scales(
layer.scales,
size_k=(layer.input_size if self.quant_config.desc_act else
layer.input_size_per_partition),
size_n=layer.output_size_per_partition,
group_size=self.quant_config.group_size)
replace_tensor(layer, "scales", marlin_scales)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
self.kernel.process_weights_after_loading(layer)

def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return apply_gptq_marlin_linear(
input=x,
weight=layer.qweight,
weight_scale=layer.scales,
weight_zp=layer.zp,
g_idx=layer.g_idx,
g_idx_sort_indices=layer.g_idx_sort_indices,
workspace=layer.workspace,
wtype=self.quant_config.quant_type,
output_size_per_partition=layer.output_size_per_partition,
input_size_per_partition=layer.input_size_per_partition,
is_k_full=layer.is_k_full,
bias=bias)
return self.kernel.apply_weights(layer, x, bias)
Loading

0 comments on commit 735259b

Please sign in to comment.