Support FP4 gemm and FP4 checkpoints #3899

Open · wants to merge 4 commits into main
1 change: 1 addition & 0 deletions python/sglang/srt/layers/linear.py
@@ -46,6 +46,7 @@
"GPTQLinearMethod",
"FBGEMMFp8LinearMethod",
"ModelOptFp8LinearMethod",
"ModelOptFp4LinearMethod",
"IPEXAWQLinearMethod",
]

6 changes: 5 additions & 1 deletion python/sglang/srt/layers/quantization/__init__.py
@@ -26,7 +26,10 @@
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
from sglang.srt.layers.quantization.fp8 import Fp8Config
from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
from sglang.srt.layers.quantization.modelopt_quant import (
ModelOptFp4Config,
ModelOptFp8Config,
)
from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config

QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
@@ -39,6 +42,7 @@
"fbgemm_fp8": FBGEMMFp8Config,
"marlin": MarlinConfig,
"modelopt": ModelOptFp8Config,
"modelopt_fp4": ModelOptFp4Config,
"gguf": GGUFConfig,
"gptq_marlin_24": GPTQMarlin24Config,
"gptq_marlin": GPTQMarlinConfig,
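For context, a minimal sketch of how this new registry entry gets consumed: the "modelopt_fp4" key resolves to ModelOptFp4Config, whose from_config() reads quant_algo, kv_cache_quant_algo, group_size, and exclude_modules from the "quantization" section of hf_quant_config.json. The config values below are illustrative assumptions (only the group size of 16 is implied by the NVFP4 block size used in this PR), not taken from a real checkpoint.

from sglang.srt.layers.quantization import QUANTIZATION_METHODS

# Illustrative stand-in for a checkpoint's hf_quant_config.json.
hf_quant_config = {
    "quantization": {
        "quant_algo": "NVFP4",
        "kv_cache_quant_algo": "FP8",      # assumption
        "group_size": 16,                  # NVFP4 block size (16 elements per scale)
        "exclude_modules": ["lm_head"],    # assumption: modules left unquantized
    }
}

config_cls = QUANTIZATION_METHODS["modelopt_fp4"]      # -> ModelOptFp4Config
quant_config = config_cls.from_config(hf_quant_config)
assert quant_config.get_name() == "modelopt_fp4"
assert quant_config.group_size == 16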
241 changes: 241 additions & 0 deletions python/sglang/srt/layers/quantization/modelopt_quant.py
@@ -6,6 +6,7 @@
import torch
from torch.nn.parameter import Parameter
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
apply_fp8_linear,
cutlass_fp8_supported,
@@ -19,6 +20,11 @@
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8 import Fp8KVCacheMethod
from sglang.srt.utils import is_cuda_available

if is_cuda_available():
from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant

# Initialize logger for the module
logger = logging.getLogger(__name__)
@@ -188,3 +194,238 @@ class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):

def __init__(self, quant_config: ModelOptFp8Config):
super().__init__(quant_config)


class ModelOptFp4Config(QuantizationConfig):
"""Config class for FP4."""

def __init__(
self,
is_checkpoint_nvfp4_serialized: bool = False,
kv_cache_quant_algo: Optional[str] = None,
group_size: Optional[int] = None,
exclude_modules: Optional[List[str]] = None,
) -> None:
self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
if is_checkpoint_nvfp4_serialized:
logger.warning(
"Detected nvfp4 checkpoint. Please note that the "
"format is experimental and subject to change."
)
self.group_size = group_size
self.kv_cache_quant_algo = kv_cache_quant_algo
self.exclude_modules = exclude_modules

@classmethod
def get_name(cls) -> str:
return "modelopt_fp4"

@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.bfloat16, torch.half, torch.float8_e4m3fn]

@classmethod
def get_min_capability(cls) -> int:
return 100

@classmethod
def get_config_filenames(cls) -> List[str]:
return ["hf_quant_config.json"]

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp4Config":
quant_config = cls.get_from_keys(config, ["quantization"])
quant_method = quant_config["quant_algo"]
if quant_method not in ["FP8", "NVFP4"]:
raise ValueError(
"ModelOpt currently only supports FP8 and NVFP4 "
"quantization in sglang. Please check the "
"`hf_quant_config.json` file for your model's "
"quant configuration."
)
is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method
kv_cache_quant_algo = quant_config["kv_cache_quant_algo"]
group_size = quant_config["group_size"]
exclude_modules = quant_config["exclude_modules"]
if not (group_size and kv_cache_quant_algo and exclude_modules):
raise ValueError(
"NVFP4 quantization requires group_size, "
"kv_cache_quant_algo, and exclude_modules to be "
"specified in hf_quant_config.json"
)
return cls(
is_checkpoint_nvfp4_serialized,
kv_cache_quant_algo,
group_size,
exclude_modules,
)

def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import

if isinstance(layer, LinearBase):
if is_layer_skipped(prefix, self.exclude_modules):
return UnquantizedLinearMethod()
return ModelOptFp4LinearMethod(self)
elif isinstance(layer, Attention):
return Fp8KVCacheMethod(self)
return None

def get_scaled_act_names(self) -> List[str]:
return []


class ModelOptFp4LinearMethod(LinearMethodBase):
"""Linear method for NVFP4.
Supports loading NVFP4 checkpoints with the following structure:

|Tensor Name     | datatype      | shape       |
|----------------|---------------|-------------|
|input_scale     | torch.float32 | scalar      |
|weight          | NVFP4(SE2M1)  | [1, X, Y/2] |
|weight_scale    | FP8-E4M3      | [X, Y]      |
|weight_scale_2  | torch.float32 | scalar      |

The weights are quantized per block of 16 elements.

Args:
quant_config: The ModelOpt quantization config.
"""

def __init__(self, quant_config: ModelOptFp4Config):
self.quant_config = quant_config

def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
del input_size, output_size
if not self.quant_config.is_checkpoint_nvfp4_serialized:
raise ValueError(
"NVFP4 quantization was selected, but the checkpoint is "
"not serialized in NVFP4 format; dynamic quantization is "
"not supported."
)

output_size_per_partition = sum(output_partition_sizes)
weight_loader = extra_weight_attrs.get("weight_loader")

layer.logical_widths = output_partition_sizes

layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
if input_size_per_partition % 16 != 0:
raise ValueError(
"Unsupported model when in features size is " "not multiple of 16"
)

weight_dtype = (
torch.float8_e4m3fn
if self.quant_config.is_checkpoint_nvfp4_serialized
else params_dtype
)

weight = ModelWeightParameter(
data=torch.empty(
# Two FP4 values are packed into one uint8 along the input dimension
output_size_per_partition,
input_size_per_partition // 2,
dtype=torch.uint8,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
layer.register_parameter("weight", weight)

input_scale = PerTensorScaleParameter(
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader,
)

layer.register_parameter("input_scale", input_scale)

weight_scale_2 = PerTensorScaleParameter(
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader,
)
layer.register_parameter("weight_scale_2", weight_scale_2)

weight_scale = ModelWeightParameter(
data=torch.empty(
output_size_per_partition,
input_size_per_partition // self.quant_config.group_size,
dtype=weight_dtype,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)

layer.register_parameter("weight_scale", weight_scale)

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
input_scale_2 = layer.input_scale.max().to(torch.float32)
weight_scale_2 = layer.weight_scale_2.max().to(torch.float32)
layer.input_scale = Parameter(input_scale_2, requires_grad=False)
layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False)
layer.alpha = Parameter(
layer.input_scale * layer.weight_scale_2, requires_grad=False
)

# Pad and blockwise interleave weight_scale
scales = layer.weight_scale
if scales.ndim == 2:
scales = scales.unsqueeze(0)
assert scales.ndim == 3
B, M, K = scales.shape
round_up_multiple = lambda x, m: (x + m - 1) // m * m
M_padded = round_up_multiple(M, 128)
K_padded = round_up_multiple(K, 4)
padded_scales = torch.zeros((B, M_padded, K_padded), dtype=scales.dtype)
padded_scales[:B, :M, :K] = scales
batches, rows, cols = padded_scales.shape
assert rows % 128 == 0
assert cols % 4 == 0
padded_scales = padded_scales.reshape(batches, rows // 128, 4, 32, cols // 4, 4)
padded_scales = padded_scales.permute((0, 1, 4, 3, 2, 5))
layer.weight_scale_interleaved = Parameter(
padded_scales.contiguous().cuda(), requires_grad=False
)

def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
output_dtype = x.dtype
x_m, _ = x.shape
w_n, _ = layer.weight.shape
output_shape = [x_m, w_n]

# Quantize BF16 or FP16 to (FP4 and interleaved block scale)
x_fp4, x_scale_interleaved = scaled_fp4_quant(x, 1 / layer.input_scale)

assert x_fp4.dtype == torch.uint8
assert x_scale_interleaved.dtype == torch.float8_e4m3fn
assert layer.weight.dtype == torch.uint8
assert layer.weight_scale_interleaved.dtype == torch.float8_e4m3fn
assert layer.alpha.dtype == torch.float32

out = cutlass_scaled_fp4_mm(
x_fp4,
layer.weight,
x_scale_interleaved,
layer.weight_scale_interleaved,
layer.alpha,
output_dtype,
)
if bias is not None:
out = out + bias
return out.view(*output_shape)
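The apply() path above quantizes the BF16/FP16 activation on the fly with scaled_fp4_quant and then runs the packed-FP4 GEMM with cutlass_scaled_fp4_mm. Below is a minimal standalone sketch of that flow, assuming a compute-capability-100 GPU and an sgl-kernel build with NVFP4 enabled; the random weight and the 448 * 6 global-scale choice (FP8-E4M3 max times FP4-E2M1 max) are illustrative, not part of this PR.

import torch
from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant

m, n, k = 8, 4096, 4096  # illustrative sizes; k must be a multiple of 16
a = torch.randn(m, k, dtype=torch.bfloat16, device="cuda")  # activations
b = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")  # stand-in for a weight matrix

# Per-tensor global scales mapping each tensor's amax onto the FP8 x FP4 range.
a_gs = (448 * 6) / a.abs().max().float()
b_gs = (448 * 6) / b.abs().max().float()

# Quantize to packed FP4 (two values per uint8) plus interleaved FP8-E4M3 block scales.
a_fp4, a_scale = scaled_fp4_quant(a, a_gs)  # a_fp4: [m, k // 2], uint8
b_fp4, b_scale = scaled_fp4_quant(b, b_gs)  # b_fp4: [n, k // 2], uint8

# alpha undoes both global scales in the epilogue, playing the role of layer.alpha above.
alpha = 1.0 / (a_gs * b_gs)
out = cutlass_scaled_fp4_mm(a_fp4, b_fp4, a_scale, b_scale, alpha, torch.bfloat16)
assert out.shape == (m, n)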
1 change: 1 addition & 0 deletions python/sglang/srt/server_args.py
@@ -398,6 +398,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
"bitsandbytes",
"gguf",
"modelopt",
"modelopt_fp4",
"w8a8_int8",
],
help="The quantization method.",
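With this choice registered, an NVFP4 checkpoint can be selected at launch time through the existing quantization flag (e.g. --quantization modelopt_fp4); the hf_quant_config.json shipped with the checkpoint then drives ModelOptFp4Config.from_config as sketched earlier.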
14 changes: 14 additions & 0 deletions sgl-kernel/setup.py
@@ -103,6 +103,11 @@ def _get_version():
"src/sgl-kernel/csrc/speculative_sampling.cu",
"src/sgl-kernel/csrc/per_token_group_quant_fp8.cu",
"src/sgl-kernel/csrc/cublas_grouped_gemm.cu",
"src/sgl-kernel/csrc/cutlass_extensions/common.cpp",
"src/sgl-kernel/csrc/quantization/fp4/nvfp4_quant_entry.cu",
"src/sgl-kernel/csrc/quantization/fp4/nvfp4_quant_kernels.cu",
"src/sgl-kernel/csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu",
"src/sgl-kernel/csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu",
"3rdparty/flashinfer/csrc/activation.cu",
"3rdparty/flashinfer/csrc/bmm_fp8.cu",
"3rdparty/flashinfer/csrc/norm.cu",
@@ -113,11 +118,16 @@ def _get_version():

enable_bf16 = os.getenv("SGL_KERNEL_ENABLE_BF16", "0") == "1"
enable_fp8 = os.getenv("SGL_KERNEL_ENABLE_FP8", "0") == "1"
enable_fp4 = os.getenv("SGL_KERNEL_ENABLE_FP4", "0") == "1"
enable_sm90a = os.getenv("SGL_KERNEL_ENABLE_SM90A", "0") == "1"
enable_sm100a = os.getenv("SGL_KERNEL_ENABLE_SM100A", "0") == "1"
cuda_version = _get_cuda_version()
sm_version = _get_device_sm()

if torch.cuda.is_available():
if cuda_version >= (12, 8) and sm_version >= 100:
nvcc_flags.append("-gencode=arch=compute_100a,code=sm_100a")
nvcc_flags.append("-DENABLE_NVFP4=1")
if cuda_version >= (12, 0) and sm_version >= 90:
nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
if sm_version >= 90:
@@ -126,8 +136,12 @@ def _get_version():
nvcc_flags.append("-DFLASHINFER_ENABLE_BF16")
else:
# compilation environment without GPU
if enable_sm100a:
nvcc_flags.append("-gencode=arch=compute_100a,code=sm_100a")
if enable_sm90a:
nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
if enable_fp4:
nvcc_flags.append("-DENABLE_NVFP4=1")
if enable_fp8:
nvcc_flags.extend(nvcc_flags_fp8)
if enable_bf16:
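Build-wise, the FP4 sources are only compiled with FP4 support when the toolchain can target it: with a GPU present this requires CUDA 12.8+ and an SM 100 device (which also adds -DENABLE_NVFP4=1), while in a GPU-less build environment the new SGL_KERNEL_ENABLE_SM100A and SGL_KERNEL_ENABLE_FP4 switches opt in to the sm_100a gencode and the NVFP4 define, respectively.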
4 changes: 4 additions & 0 deletions sgl-kernel/src/sgl-kernel/__init__.py
@@ -15,6 +15,7 @@
cublas_grouped_gemm,
custom_dispose,
custom_reduce,
cutlass_scaled_fp4_mm,
fp8_blockwise_scaled_mm,
fp8_scaled_mm,
fused_add_rmsnorm,
@@ -31,6 +32,7 @@
register_graph_buffers,
rmsnorm,
sampling_scaling_penalties,
scaled_fp4_quant,
sgl_per_token_group_quant_fp8,
silu_and_mul,
top_k_renorm_prob,
@@ -47,6 +49,7 @@
"cublas_grouped_gemm",
"custom_dispose",
"custom_reduce",
"cutlass_scaled_fp4_mm",
"fp8_blockwise_scaled_mm",
"fp8_scaled_mm",
"fused_add_rmsnorm",
@@ -63,6 +66,7 @@
"register_graph_buffers",
"rmsnorm",
"sampling_scaling_penalties",
"scaled_fp4_quant",
"silu_and_mul",
"top_k_renorm_prob",
"top_k_top_p_sampling_from_probs",
11 changes: 11 additions & 0 deletions sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/common.cpp
@@ -0,0 +1,11 @@
#include "cutlass_extensions/common.hpp"

int32_t get_sm_version_num() {
int32_t major_capability, minor_capability;
cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
0);
cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
0);
int32_t version_num = major_capability * 10 + minor_capability;
return version_num;
}
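get_sm_version_num() folds the device's compute capability into a single integer (major * 10 + minor), so a compute-capability 10.0 (Blackwell-class) device reports 100, which lines up with ModelOptFp4Config.get_min_capability() returning 100.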