diff --git a/.github/workflows/execute-notebook.yml b/.github/workflows/execute-notebook.yml index 674efe86de1..f1cf30ac392 100644 --- a/.github/workflows/execute-notebook.yml +++ b/.github/workflows/execute-notebook.yml @@ -1,11 +1,6 @@ name: Execute Notebooks on: - push: - branches: [ main ] - paths: - - "python/sglang/**" - - "docs/**" pull_request: branches: [ main ] paths: diff --git a/README.md b/README.md index 8675d522425..2780c9562c5 100644 --- a/README.md +++ b/README.md @@ -59,7 +59,7 @@ Learn more in the release blogs: [v0.2 blog](https://lmsys.org/blog/2024-07-25-s ## Adoption and Sponsorship The project has been deployed to large-scale production, generating trillions of tokens every day. -It is supported by the following institutions: AMD, Atlas Cloud, Baseten, Cursor, DataCrunch, Etched, Hyperbolic, Jam & Tea Studios, LinkedIn, LMSYS, Meituan, Nebius, Novita AI, NVIDIA, RunPod, Stanford, UC Berkeley, UCLA, xAI, and 01.AI. +It is supported by the following institutions: AMD, Atlas Cloud, Baseten, Cursor, DataCrunch, Etched, Hyperbolic, Iflytek, Jam & Tea Studios, LinkedIn, LMSYS, Meituan, Nebius, Novita AI, NVIDIA, RunPod, Stanford, UC Berkeley, UCLA, xAI, and 01.AI. logo diff --git a/benchmark/deepseek_v3/README.md b/benchmark/deepseek_v3/README.md index 910862db5e3..6196b09c417 100644 --- a/benchmark/deepseek_v3/README.md +++ b/benchmark/deepseek_v3/README.md @@ -6,14 +6,6 @@ Special thanks to Meituan's Search & Recommend Platform Team and Baseten's Model For optimizations made on the DeepSeek series models regarding SGLang, please refer to [DeepSeek Model Optimizations in SGLang](https://docs.sglang.ai/references/deepseek.html). -## Hardware Recommendation - -- 8 x NVIDIA H200 GPUs - -If you do not have GPUs with large enough memory, please try multi-node tensor parallelism. There is an example serving with [2 H20 nodes](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3#example-serving-with-2-h208) below. - -For running on AMD MI300X, use this as a reference. [Running DeepSeek-R1 on a single NDv5 MI300X VM](https://techcommunity.microsoft.com/blog/azurehighperformancecomputingblog/running-deepseek-r1-on-a-single-ndv5-mi300x-vm/4372726) - ## Installation & Launch If you encounter errors when starting the server, ensure the weights have finished downloading. It's recommended to download them beforehand or restart multiple times until all weights are downloaded. @@ -183,6 +175,15 @@ python3 benchmark/gsm8k/bench_sglang.py --num-questions 1319 --host http://10.0. python3 -m sglang.bench_one_batch_server --model None --base-url http://10.0.0.1:30000 --batch-size 1 --input-len 128 --output-len 128 ``` + +### Example: Serving with 8 A100/A800 with AWQ Quantization + +AWQ does not support BF16, so add the `--dtype half` flag if AWQ is used for quantization. One example is as follows: + +```bash +python3 -m sglang.launch_server --model cognitivecomputations/DeepSeek-R1-AWQ --tp 8 --trust-remote-code --dtype half +``` + ### Example: Serving on any cloud or Kubernetes with SkyPilot SkyPilot helps find cheapest available GPUs across any cloud or existing Kubernetes clusters and launch distributed serving with a single command. See details [here](https://github.com/skypilot-org/skypilot/tree/master/llm/deepseek-r1). 
diff --git a/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py b/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py index b6db00520ad..a14cf5ee925 100644 --- a/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py +++ b/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py @@ -41,13 +41,14 @@ def benchmark_config( topk: int, dtype: torch.dtype, use_fp8_w8a8: bool, + use_int8_w8a8: bool, use_int8_w8a16: bool, block_shape: List[int] = None, num_iters: int = 100, ) -> float: init_dtype = torch.float16 if use_fp8_w8a8 else dtype x = torch.randn(num_tokens, hidden_size, dtype=dtype) - if use_int8_w8a16: + if use_int8_w8a16 or use_int8_w8a8: w1 = torch.randint( -127, 127, @@ -86,7 +87,7 @@ def benchmark_config( (num_experts, 2 * shard_intermediate_size), dtype=torch.float32 ) w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32) - if use_fp8_w8a8: + if use_fp8_w8a8 or use_int8_w8a8: if block_shape is None: w1_scale = torch.randn(num_experts, dtype=torch.float32) w2_scale = torch.randn(num_experts, dtype=torch.float32) @@ -105,6 +106,7 @@ def benchmark_config( (num_experts, n_tiles_w2, k_tiles_w2), dtype=torch.float32 ) + if use_fp8_w8a8: w1 = w1.to(torch.float8_e4m3fnuz if _is_hip_ else torch.float8_e4m3fn) w2 = w2.to(torch.float8_e4m3fnuz if _is_hip_ else torch.float8_e4m3fn) @@ -126,6 +128,7 @@ def run(): renormalize=True, inplace=True, use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, use_int8_w8a16=use_int8_w8a16, w1_scale=w1_scale, w2_scale=w2_scale, @@ -235,6 +238,7 @@ def benchmark( topk: int, dtype: torch.dtype, use_fp8_w8a8: bool, + use_int8_w8a8: bool, use_int8_w8a16: bool, block_shape: List[int], ) -> Tuple[Dict[str, int], float]: @@ -270,6 +274,7 @@ def benchmark( topk, dtype, use_fp8_w8a8, + use_int8_w8a8, use_int8_w8a16, block_shape, ) @@ -284,6 +289,7 @@ def tune( topk: int, dtype: torch.dtype, use_fp8_w8a8: bool, + use_int8_w8a8: bool, use_int8_w8a16: bool, block_shape: List[int], search_space: List[Dict[str, int]], @@ -301,6 +307,7 @@ def tune( topk, dtype, use_fp8_w8a8, + use_int8_w8a8, use_int8_w8a16, block_shape, num_iters=10, @@ -340,11 +347,15 @@ def save_configs( topk: int, dtype: torch.dtype, use_fp8_w8a8: bool, + use_int8_w8a8: bool, use_int8_w8a16: bool, block_shape: List[int], ) -> None: dtype_str = get_config_dtype_str( - dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8 + dtype, + use_int8_w8a16=use_int8_w8a16, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, ) # NOTE(woosuk): The current naming convention uses w2.shape[2], which @@ -396,6 +407,7 @@ def main(args: argparse.Namespace): hidden_size = config.hidden_size dtype = config.torch_dtype use_fp8_w8a8 = args.dtype == "fp8_w8a8" + use_int8_w8a8 = args.dtype == "int8_w8a8" use_int8_w8a16 = args.dtype == "int8_w8a16" block_shape = None if ( @@ -467,6 +479,7 @@ def _distribute(method: str, inputs: List[Any]) -> List[Any]: topk, dtype, use_fp8_w8a8, + use_int8_w8a8, use_int8_w8a16, block_shape, search_space, @@ -485,6 +498,7 @@ def _distribute(method: str, inputs: List[Any]) -> List[Any]: topk, dtype, use_fp8_w8a8, + use_int8_w8a8, use_int8_w8a16, block_shape, ) @@ -502,6 +516,7 @@ def _distribute(method: str, inputs: List[Any]) -> List[Any]: topk, dtype, use_fp8_w8a8, + use_int8_w8a8, use_int8_w8a16, block_shape, ) @@ -521,7 +536,10 @@ def _distribute(method: str, inputs: List[Any]) -> List[Any]: ) parser.add_argument("--tp-size", "-tp", type=int, default=2) parser.add_argument( - "--dtype", type=str, choices=["auto", 
"fp8_w8a8", "int8_w8a16"], default="auto" + "--dtype", + type=str, + choices=["auto", "fp8_w8a8", "int8_w8a16", "int8_w8a8"], + default="auto", ) parser.add_argument("--seed", type=int, default=0) parser.add_argument("--batch-size", type=int, required=False) diff --git a/benchmark/kernels/quantization/tuning_block_wise_fp8.py b/benchmark/kernels/quantization/tuning_block_wise_kernel.py similarity index 84% rename from benchmark/kernels/quantization/tuning_block_wise_fp8.py rename to benchmark/kernels/quantization/tuning_block_wise_kernel.py index a84ec0f3071..197939f0292 100644 --- a/benchmark/kernels/quantization/tuning_block_wise_fp8.py +++ b/benchmark/kernels/quantization/tuning_block_wise_kernel.py @@ -30,6 +30,7 @@ _w8a8_block_fp8_matmul, _w8a8_block_fp8_matmul_unrolledx4, ) +from sglang.srt.layers.quantization.int8_kernel import _w8a8_block_int8_matmul from sglang.srt.utils import get_device_core_count, get_device_name, is_hip is_hip_ = is_hip() @@ -42,7 +43,7 @@ } -def w8a8_block_fp8_matmul( +def w8a8_block_matmul( A: torch.Tensor, B: torch.Tensor, As: torch.Tensor, @@ -94,11 +95,15 @@ def grid(META): num_workgroups = triton.cdiv(M, config["BLOCK_SIZE_M"]) * triton.cdiv( N, config["BLOCK_SIZE_N"] ) - kernel = ( - _w8a8_block_fp8_matmul_unrolledx4 - if (is_hip_ == True and num_workgroups <= get_device_core_count()) - else _w8a8_block_fp8_matmul - ) + + if A.dtype == torch.float8_e4m3fnuz or A.dtype == torch.float8_e4m3fn: + kernel = ( + _w8a8_block_fp8_matmul_unrolledx4 + if (is_hip_ == True and num_workgroups <= get_device_core_count()) + else _w8a8_block_fp8_matmul + ) + else: + kernel = _w8a8_block_int8_matmul kernel[grid]( A, @@ -208,10 +213,10 @@ def get_weight_shapes(tp_size): def benchmark_config( - A_fp8, B_fp8, As, Bs, block_size, config, out_dtype=torch.float16, num_iters=10 + A, B, As, Bs, block_size, config, out_dtype=torch.float16, num_iters=10 ): def run(): - w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, config, out_dtype) + w8a8_block_matmul(A, B, As, Bs, block_size, config, out_dtype) torch.cuda.synchronize() # JIT complication & warmup @@ -234,20 +239,41 @@ def run(): return avg -def tune(M, N, K, block_size, out_dtype, search_space): +def tune(M, N, K, block_size, out_dtype, search_space, input_type): factor_for_scale = 1e-2 - fp8_info = torch.finfo(torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn) - fp8_max, fp8_min = fp8_info.max, fp8_info.min - A_fp32 = (torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max - A_fp8 = A_fp32.clamp(min=fp8_min, max=fp8_max).to( - torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn - ) + if input_type == "fp8": + fp8_info = torch.finfo( + torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn + ) + fp8_max, fp8_min = fp8_info.max, fp8_info.min - B_fp32 = (torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max - B_fp8 = B_fp32.clamp(min=fp8_min, max=fp8_max).to( - torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn - ) + A_fp32 = ( + (torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max + ) + A = A_fp32.clamp(min=fp8_min, max=fp8_max).to( + torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn + ) + + B_fp32 = ( + (torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max + ) + B = B_fp32.clamp(min=fp8_min, max=fp8_max).to( + torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn + ) + else: + int8_info = torch.iinfo(torch.int8) + int8_max, int8_min = int8_info.max, int8_info.min + + A_fp32 = ( + (torch.rand(M, K, 
dtype=torch.float32, device="cuda") - 0.5) * 2 * int8_max + ) + A = A_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8) + + B_fp32 = ( + (torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * int8_max + ) + B = B_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8) block_n, block_k = block_size[0], block_size[1] n_tiles = (N + block_n - 1) // block_n @@ -264,8 +290,8 @@ def tune(M, N, K, block_size, out_dtype, search_space): for config in tqdm(search_space): try: kernel_time = benchmark_config( - A_fp8, - B_fp8, + A, + B, As, Bs, block_size, @@ -293,10 +319,11 @@ def save_configs( block_k, configs, save_path, + input_type="fp8", ) -> None: os.makedirs(save_path, exist_ok=True) device_name = get_device_name().replace(" ", "_") - json_file_name = f"N={N},K={K},device_name={device_name},dtype=fp8_w8a8,block_shape=[{block_n}, {block_k}].json" + json_file_name = f"N={N},K={K},device_name={device_name},dtype={input_type}_w8a8,block_shape=[{block_n}, {block_k}].json" config_file_path = os.path.join(save_path, json_file_name) print(f"Writing best config to {config_file_path}...") @@ -325,6 +352,7 @@ def tune_on_gpu(args_dict): block_k = args.block_k out_dtype = DTYPE_MAP[args.out_dtype] save_path = args.save_path + input_type = args.input_type search_space = get_configs_compute_bound() search_space = [ @@ -337,11 +365,19 @@ def tune_on_gpu(args_dict): N, K = shape[0], shape[1] print(f"[GPU {gpu_id}] Tune for weight shape of `N: {N}, K: {K}`") benchmark_results = [ - tune(batch_size, N, K, [block_n, block_k], out_dtype, search_space) + tune( + batch_size, + N, + K, + [block_n, block_k], + out_dtype, + search_space, + input_type, + ) for batch_size in tqdm(batch_sizes, desc=f"GPU {gpu_id} - Batch sizes") ] best_configs = {M: config for M, config in zip(batch_sizes, benchmark_results)} - save_configs(N, K, block_n, block_k, best_configs, save_path) + save_configs(N, K, block_n, block_k, best_configs, save_path, input_type) end = time.time() print(f"Tuning on GPU {gpu_id} took {end - start:.2f} seconds") @@ -418,6 +454,9 @@ def main(args): parser = argparse.ArgumentParser() parser.add_argument("--tp-size", "-tp", type=int, default=8) + parser.add_argument( + "--input-type", type=str, choices=["fp8", "int8"], default="fp8" + ) parser.add_argument( "--out-dtype", type=str, diff --git a/docs/backend/offline_engine_api.ipynb b/docs/backend/offline_engine_api.ipynb index 53a3bb4967c..6a95e59e616 100644 --- a/docs/backend/offline_engine_api.ipynb +++ b/docs/backend/offline_engine_api.ipynb @@ -23,6 +23,17 @@ "Additionally, you can easily build a custom server on top of the SGLang offline engine. A detailed example working in a python script can be found in [custom_server](https://github.com/sgl-project/sglang/blob/main/examples/runtime/engine/custom_server.py)." ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Advanced Usage\n", + "\n", + "The engine supports [vlm inference](https://github.com/sgl-project/sglang/blob/main/examples/runtime/engine/offline_batch_inference_vlm.py) as well as [extracting hidden states](https://github.com/sgl-project/sglang/blob/main/examples/runtime/engine/hidden_states.py). \n", + "\n", + "Please see [the examples](https://github.com/sgl-project/sglang/tree/main/examples/runtime/engine) for further use cases." 
+ ] + }, { "cell_type": "markdown", "metadata": {}, @@ -39,14 +50,22 @@ "outputs": [], "source": [ "# launch the offline engine\n", - "from sglang.utils import stream_and_merge, async_stream_and_merge\n", - "import sglang as sgl\n", "import asyncio\n", + "import io\n", + "import os\n", + "\n", + "from PIL import Image\n", + "import requests\n", + "import sglang as sgl\n", + "\n", + "from sglang.srt.conversation import chat_templates\n", "from sglang.test.test_utils import is_in_ci\n", + "from sglang.utils import async_stream_and_merge, stream_and_merge\n", "\n", "if is_in_ci():\n", " import patch\n", "\n", + "\n", "llm = sgl.Engine(model_path=\"meta-llama/Meta-Llama-3.1-8B-Instruct\")" ] }, @@ -185,57 +204,6 @@ "asyncio.run(main())" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "llm.shutdown()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Return Hidden States" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "llm = sgl.Engine(\n", - " model_path=\"meta-llama/Meta-Llama-3.1-8B-Instruct\", return_hidden_states=True\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "prompts = [\n", - " \"Hello, my name is\",\n", - " \"The president of the United States is\",\n", - " \"The capital of France is\",\n", - " \"The future of AI is\",\n", - "]\n", - "\n", - "sampling_params = {\"temperature\": 0.8, \"top_p\": 0.95, \"max_new_tokens\": 10}\n", - "\n", - "outputs = llm.generate(prompts, sampling_params=sampling_params)\n", - "for prompt, output in zip(prompts, outputs):\n", - " print(\"===============================\")\n", - " print(\n", - " f\"Prompt: {prompt}\\nGenerated text: {output['text']}\\nPrompt_Tokens: {output['meta_info']['prompt_tokens']}\\tCompletion_tokens: {output['meta_info']['completion_tokens']}\\nHidden states: {[i.shape for i in output['meta_info']['hidden_states']]}\"\n", - " )\n", - " print()" - ] - }, { "cell_type": "code", "execution_count": null, diff --git a/docs/backend/openai_api_vision.ipynb b/docs/backend/openai_api_vision.ipynb index b4ee64ad74e..c868f6a3b11 100644 --- a/docs/backend/openai_api_vision.ipynb +++ b/docs/backend/openai_api_vision.ipynb @@ -26,7 +26,7 @@ "\n", "Launch the server in your terminal and wait for it to initialize.\n", "\n", - "**Remember to add `--chat-template llama_3_vision` to specify the vision chat template, otherwise the server only supports text, and performance degradation may occur.**\n", + "**Remember to add** `--chat-template llama_3_vision` **to specify the vision chat template, otherwise the server only supports text, and performance degradation may occur.**\n", "\n", "We need to specify `--chat-template` for vision language models because the chat template provided in Hugging Face tokenizer only supports text." 
] @@ -46,7 +46,7 @@ "\n", "from sglang.utils import wait_for_server, print_highlight, terminate_process\n", "\n", - "embedding_process, port = launch_server_cmd(\n", + "vision_process, port = launch_server_cmd(\n", " \"\"\"\n", "python3 -m sglang.launch_server --model-path meta-llama/Llama-3.2-11B-Vision-Instruct \\\n", " --chat-template=llama_3_vision\n", @@ -245,7 +245,7 @@ "metadata": {}, "outputs": [], "source": [ - "terminate_process(embedding_process)" + "terminate_process(vision_process)" ] }, { diff --git a/docs/backend/sampling_params.md b/docs/backend/sampling_params.md index 91df324f4df..ef8c8bb5424 100644 --- a/docs/backend/sampling_params.md +++ b/docs/backend/sampling_params.md @@ -55,6 +55,7 @@ Please refer to our dedicated guide on [constrained decoding](https://docs.sglan * `ignore_eos`: Don't stop generation when EOS token is sampled. * `skip_special_tokens`: Remove special tokens during decoding. * `custom_params`: Used when employing `CustomLogitProcessor`. For usage see below. +* `return_hidden_states`: Whether to return hidden states of the model. Note that each time it changes, the cuda graph will be recaptured, which might lead to a performance hit. See the [examples](https://github.com/sgl-project/sglang/blob/main/examples/runtime/engine/hidden_states.py) for more information. ### Custom Logit Processor diff --git a/docs/backend/server_arguments.md b/docs/backend/server_arguments.md index b7f33f87f9a..7879ada577a 100644 --- a/docs/backend/server_arguments.md +++ b/docs/backend/server_arguments.md @@ -52,7 +52,7 @@ Please consult the documentation below to learn more about the parameters you ma * `chat_template`: The chat template to use. Deviating from the default might lead to unexpected responses. For multi-modal chat templates, refer to [here](https://docs.sglang.ai/backend/openai_api_vision.html#Chat-Template). * `is_embedding`: Set to true to perform [embedding](https://docs.sglang.ai/backend/openai_api_embeddings.html) / [encode](https://docs.sglang.ai/backend/native_api.html#Encode-(embedding-model)) and [reward](https://docs.sglang.ai/backend/native_api.html#Classify-(reward-model)) tasks. * `revision`: Adjust if a specific version of the model should be used. -* `skip_tokenizer_init`: Set to true to provide the tokens to the engine and get the output tokens directly, typically used in RLHF. +* `skip_tokenizer_init`: Set to true to provide the tokens to the engine and get the output tokens directly, typically used in RLHF. Please see this [example for reference](https://github.com/sgl-project/sglang/blob/main/examples/runtime/engine/input_ids.py). * `json_model_override_args`: Override model config with the provided JSON. * `delete_ckpt_after_loading`: Delete the model checkpoint after loading the model. @@ -133,6 +133,7 @@ Please consult the documentation below to learn more about the parameters you ma * `attention_backend`: The backend for attention computation and KV cache management. * `sampling_backend`: The backend for sampling. +* `enable_flashinfer_mla`: The backend for flashinfer MLA wrapper that accelerates deepseek models. 
(In Experiment Stage) ## Constrained Decoding diff --git a/docs/backend/structured_outputs.ipynb b/docs/backend/structured_outputs.ipynb index 1fe1b05762f..94e8902d66f 100644 --- a/docs/backend/structured_outputs.ipynb +++ b/docs/backend/structured_outputs.ipynb @@ -17,10 +17,13 @@ "\n", "- [Outlines](https://github.com/dottxt-ai/outlines) (default): Supports JSON schema and regular expression constraints.\n", "- [XGrammar](https://github.com/mlc-ai/xgrammar): Supports JSON schema, regular expression, and EBNF constraints.\n", + "- [Llguidance](https://github.com/guidance-ai/llguidance): Supports JSON schema, regular expression, and EBNF constraints.\n", "\n", "We suggest using XGrammar for its better performance and utility. XGrammar currently uses the [GGML BNF format](https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md). For more details, see [XGrammar technical overview](https://blog.mlc.ai/2024/11/22/achieving-efficient-flexible-portable-structured-generation-with-xgrammar).\n", "\n", - "To use Xgrammar, simply add `--grammar-backend` xgrammar when launching the server. If no backend is specified, Outlines will be used as the default.\n", + "To use Xgrammar, simply add `--grammar-backend xgrammar` when launching the server.\n", + "To use llguidance, add `--grammar-backend llguidance` when launching the server.\n", + "If no backend is specified, Outlines will be used as the default.\n", "\n", "For better output quality, **It's advisable to explicitly include instructions in the prompt to guide the model to generate the desired format.** For example, you can specify, 'Please generate the output in the following JSON format: ...'.\n" ] diff --git a/docs/frontend/__init__.py b/docs/frontend/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/docs/frontend/frontend.ipynb b/docs/frontend/frontend.ipynb new file mode 100644 index 00000000000..7bc07f66c74 --- /dev/null +++ b/docs/frontend/frontend.ipynb @@ -0,0 +1,464 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# SGLang Frontend Language" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "SGLang frontend language can be used to define simple and easy prompts in a convenient, structured way." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Launch A Server\n", + "\n", + "Launch the server in your terminal and wait for it to initialize." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import requests\n", + "import os\n", + "\n", + "from sglang import assistant_begin, assistant_end\n", + "from sglang import assistant, function, gen, system, user\n", + "from sglang import image\n", + "from sglang import RuntimeEndpoint, set_default_backend\n", + "from sglang.srt.utils import load_image\n", + "from sglang.test.test_utils import is_in_ci\n", + "from sglang.utils import print_highlight, terminate_process, wait_for_server\n", + "\n", + "if is_in_ci():\n", + " from patch import launch_server_cmd\n", + "else:\n", + " from sglang.utils import launch_server_cmd\n", + "\n", + "\n", + "server_process, port = launch_server_cmd(\n", + " \"python -m sglang.launch_server --model-path Qwen/Qwen2.5-7B-Instruct --host 0.0.0.0\"\n", + ")\n", + "\n", + "wait_for_server(f\"http://localhost:{port}\")\n", + "print(f\"Server started on http://localhost:{port}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Set the default backend. 
Note: Besides the local server, you may also use `OpenAI` or other API endpoints." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "set_default_backend(RuntimeEndpoint(f\"http://localhost:{port}\"))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Basic Usage\n", + "\n", + "The simplest way to use the SGLang frontend language is a question-and-answer dialog between a user and an assistant." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "@function\n", + "def basic_qa(s, question):\n", + " s += system(f\"You are a helpful assistant that can answer questions.\")\n", + " s += user(question)\n", + " s += assistant(gen(\"answer\", max_tokens=512))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "state = basic_qa(\"List 3 countries and their capitals.\")\n", + "print_highlight(state[\"answer\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Multi-turn Dialog\n", + "\n", + "SGLang frontend language can also be used to define multi-turn dialogs." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "@function\n", + "def multi_turn_qa(s):\n", + " s += system(f\"You are a helpful assistant that can answer questions.\")\n", + " s += user(\"Please give me a list of 3 countries and their capitals.\")\n", + " s += assistant(gen(\"first_answer\", max_tokens=512))\n", + " s += user(\"Please give me another list of 3 countries and their capitals.\")\n", + " s += assistant(gen(\"second_answer\", max_tokens=512))\n", + " return s\n", + "\n", + "\n", + "state = multi_turn_qa()\n", + "print_highlight(state[\"first_answer\"])\n", + "print_highlight(state[\"second_answer\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Control Flow\n", + "\n", + "You may use any Python code within the function to define more complex control flows." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "@function\n", + "def tool_use(s, question):\n", + " s += assistant(\n", + " \"To answer this question: \"\n", + " + question\n", + " + \". I need to use a \"\n", + " + gen(\"tool\", choices=[\"calculator\", \"search engine\"])\n", + " + \". \"\n", + " )\n", + "\n", + " if s[\"tool\"] == \"calculator\":\n", + " s += assistant(\"The math expression is: \" + gen(\"expression\"))\n", + " elif s[\"tool\"] == \"search engine\":\n", + " s += assistant(\"The key word to search is: \" + gen(\"word\"))\n", + "\n", + "\n", + "state = tool_use(\"What is 2 * 2?\")\n", + "print_highlight(state[\"tool\"])\n", + "print_highlight(state[\"expression\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Parallelism\n", + "\n", + "Use `fork` to launch parallel prompts. Because `sgl.gen` is non-blocking, the for loop below issues two generation calls in parallel." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "@function\n", + "def tip_suggestion(s):\n", + " s += assistant(\n", + " \"Here are two tips for staying healthy: \"\n", + " \"1. Balanced Diet. 2. 
Regular Exercise.\\n\\n\"\n", + " )\n", + "\n", + " forks = s.fork(2)\n", + " for i, f in enumerate(forks):\n", + " f += assistant(\n", + " f\"Now, expand tip {i+1} into a paragraph:\\n\"\n", + " + gen(\"detailed_tip\", max_tokens=256, stop=\"\\n\\n\")\n", + " )\n", + "\n", + " s += assistant(\"Tip 1:\" + forks[0][\"detailed_tip\"] + \"\\n\")\n", + " s += assistant(\"Tip 2:\" + forks[1][\"detailed_tip\"] + \"\\n\")\n", + " s += assistant(\n", + " \"To summarize the above two tips, I can say:\\n\" + gen(\"summary\", max_tokens=512)\n", + " )\n", + "\n", + "\n", + "state = tip_suggestion()\n", + "print_highlight(state[\"summary\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Constrained Decoding\n", + "\n", + "Use `regex` to specify a regular expression as a decoding constraint. This is only supported for local models." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "@function\n", + "def regular_expression_gen(s):\n", + " s += user(\"What is the IP address of the Google DNS servers?\")\n", + " s += assistant(\n", + " gen(\n", + " \"answer\",\n", + " temperature=0,\n", + " regex=r\"((25[0-5]|2[0-4]\\d|[01]?\\d\\d?).){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\",\n", + " )\n", + " )\n", + "\n", + "\n", + "state = regular_expression_gen()\n", + "print_highlight(state[\"answer\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Use `regex` to define a `JSON` decoding schema." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "character_regex = (\n", + " r\"\"\"\\{\\n\"\"\"\n", + " + r\"\"\" \"name\": \"[\\w\\d\\s]{1,16}\",\\n\"\"\"\n", + " + r\"\"\" \"house\": \"(Gryffindor|Slytherin|Ravenclaw|Hufflepuff)\",\\n\"\"\"\n", + " + r\"\"\" \"blood status\": \"(Pure-blood|Half-blood|Muggle-born)\",\\n\"\"\"\n", + " + r\"\"\" \"occupation\": \"(student|teacher|auror|ministry of magic|death eater|order of the phoenix)\",\\n\"\"\"\n", + " + r\"\"\" \"wand\": \\{\\n\"\"\"\n", + " + r\"\"\" \"wood\": \"[\\w\\d\\s]{1,16}\",\\n\"\"\"\n", + " + r\"\"\" \"core\": \"[\\w\\d\\s]{1,16}\",\\n\"\"\"\n", + " + r\"\"\" \"length\": [0-9]{1,2}\\.[0-9]{0,2}\\n\"\"\"\n", + " + r\"\"\" \\},\\n\"\"\"\n", + " + r\"\"\" \"alive\": \"(Alive|Deceased)\",\\n\"\"\"\n", + " + r\"\"\" \"patronus\": \"[\\w\\d\\s]{1,16}\",\\n\"\"\"\n", + " + r\"\"\" \"bogart\": \"[\\w\\d\\s]{1,16}\"\\n\"\"\"\n", + " + r\"\"\"\\}\"\"\"\n", + ")\n", + "\n", + "\n", + "@function\n", + "def character_gen(s, name):\n", + " s += user(\n", + " f\"{name} is a character in Harry Potter. Please fill in the following information about this character.\"\n", + " )\n", + " s += assistant(gen(\"json_output\", max_tokens=256, regex=character_regex))\n", + "\n", + "\n", + "state = character_gen(\"Harry Potter\")\n", + "print_highlight(state[\"json_output\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Batching \n", + "\n", + "Use `run_batch` to run a batch of prompts." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "@function\n", + "def text_qa(s, question):\n", + " s += user(question)\n", + " s += assistant(gen(\"answer\", stop=\"\\n\"))\n", + "\n", + "\n", + "states = text_qa.run_batch(\n", + " [\n", + " {\"question\": \"What is the capital of the United Kingdom?\"},\n", + " {\"question\": \"What is the capital of France?\"},\n", + " {\"question\": \"What is the capital of Japan?\"},\n", + " ],\n", + " progress_bar=True,\n", + ")\n", + "\n", + "for i, state in enumerate(states):\n", + " print_highlight(f\"Answer {i+1}: {states[i]['answer']}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Streaming \n", + "\n", + "Use `stream` to stream the output to the user." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "@function\n", + "def text_qa(s, question):\n", + " s += user(question)\n", + " s += assistant(gen(\"answer\", stop=\"\\n\"))\n", + "\n", + "\n", + "state = text_qa.run(\n", + " question=\"What is the capital of France?\", temperature=0.1, stream=True\n", + ")\n", + "\n", + "for out in state.text_iter():\n", + " print(out, end=\"\", flush=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Complex Prompts\n", + "\n", + "You may use `{system|user|assistant}_{begin|end}` to define complex prompts." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "@function\n", + "def chat_example(s):\n", + " s += system(\"You are a helpful assistant.\")\n", + " # Same as: s += s.system(\"You are a helpful assistant.\")\n", + "\n", + " with s.user():\n", + " s += \"Question: What is the capital of France?\"\n", + "\n", + " s += assistant_begin()\n", + " s += \"Answer: \" + gen(\"answer\", max_tokens=100, stop=\"\\n\")\n", + " s += assistant_end()\n", + "\n", + "\n", + "state = chat_example()\n", + "print_highlight(state[\"answer\"])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "terminate_process(server_process)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Multi-modal Generation\n", + "\n", + "You may use SGLang frontend language to define multi-modal prompts.\n", + "See [here](https://docs.sglang.ai/references/supported_models.html) for supported models." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "server_process, port = launch_server_cmd(\n", + " \"python -m sglang.launch_server --model-path Qwen/Qwen2.5-VL-7B-Instruct --host 0.0.0.0\"\n", + ")\n", + "\n", + "wait_for_server(f\"http://localhost:{port}\")\n", + "print(f\"Server started on http://localhost:{port}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "set_default_backend(RuntimeEndpoint(f\"http://localhost:{port}\"))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Ask a question about an image." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "@function\n", + "def image_qa(s, image_file, question):\n", + " s += user(image(image_file) + question)\n", + " s += assistant(gen(\"answer\", max_tokens=256))\n", + "\n", + "\n", + "image_url = \"https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true\"\n", + "image_bytes, _ = load_image(image_url)\n", + "state = image_qa(image_bytes, \"What is in the image?\")\n", + "print_highlight(state[\"answer\"])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "terminate_process(server_process)" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/frontend/frontend.md b/docs/frontend/frontend.md deleted file mode 100644 index 1a02d6adb86..00000000000 --- a/docs/frontend/frontend.md +++ /dev/null @@ -1,238 +0,0 @@ -# Structured Generation Language -The frontend language can be used with local models or API models. It is an alternative to the OpenAI API. You may find it easier to use for complex prompting workflow. - -## Quick Start -The example below shows how to use SGLang to answer a multi-turn question. - -### Using Local Models -First, launch a server with -``` -python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 30000 -``` - -Then, connect to the server and answer a multi-turn question. - -```python -from sglang import function, system, user, assistant, gen, set_default_backend, RuntimeEndpoint - -@function -def multi_turn_question(s, question_1, question_2): - s += system("You are a helpful assistant.") - s += user(question_1) - s += assistant(gen("answer_1", max_tokens=256)) - s += user(question_2) - s += assistant(gen("answer_2", max_tokens=256)) - -set_default_backend(RuntimeEndpoint("http://localhost:30000")) - -state = multi_turn_question.run( - question_1="What is the capital of the United States?", - question_2="List two local attractions.", -) - -for m in state.messages(): - print(m["role"], ":", m["content"]) - -print(state["answer_1"]) -``` - -### Using OpenAI Models -Set the OpenAI API Key -``` -export OPENAI_API_KEY=sk-****** -``` - -Then, answer a multi-turn question. -```python -from sglang import function, system, user, assistant, gen, set_default_backend, OpenAI - -@function -def multi_turn_question(s, question_1, question_2): - s += system("You are a helpful assistant.") - s += user(question_1) - s += assistant(gen("answer_1", max_tokens=256)) - s += user(question_2) - s += assistant(gen("answer_2", max_tokens=256)) - -set_default_backend(OpenAI("gpt-3.5-turbo")) - -state = multi_turn_question.run( - question_1="What is the capital of the United States?", - question_2="List two local attractions.", -) - -for m in state.messages(): - print(m["role"], ":", m["content"]) - -print(state["answer_1"]) -``` - -### More Examples -Anthropic and VertexAI (Gemini) models are also supported. -You can find more examples at [examples/quick_start](https://github.com/sgl-project/sglang/tree/main/examples/frontend_language/quick_start). - -## Language Feature -To begin with, import sglang. -```python -import sglang as sgl -``` - -`sglang` provides some simple primitives such as `gen`, `select`, `fork`, `image`. -You can implement your prompt flow in a function decorated by `sgl.function`. -You can then invoke the function with `run` or `run_batch`. 
-The system will manage the state, chat template, parallelism and batching for you. - -The complete code for the examples below can be found at [readme_examples.py](https://github.com/sgl-project/sglang/blob/main/examples/frontend_language/usage/readme_examples.py) - -### Control Flow -You can use any Python code within the function body, including control flow, nested function calls, and external libraries. - -```python -@sgl.function -def tool_use(s, question): - s += "To answer this question: " + question + ". " - s += "I need to use a " + sgl.gen("tool", choices=["calculator", "search engine"]) + ". " - - if s["tool"] == "calculator": - s += "The math expression is" + sgl.gen("expression") - elif s["tool"] == "search engine": - s += "The key word to search is" + sgl.gen("word") -``` - -### Parallelism -Use `fork` to launch parallel prompts. -Because `sgl.gen` is non-blocking, the for loop below issues two generation calls in parallel. - -```python -@sgl.function -def tip_suggestion(s): - s += ( - "Here are two tips for staying healthy: " - "1. Balanced Diet. 2. Regular Exercise.\n\n" - ) - - forks = s.fork(2) - for i, f in enumerate(forks): - f += f"Now, expand tip {i+1} into a paragraph:\n" - f += sgl.gen(f"detailed_tip", max_tokens=256, stop="\n\n") - - s += "Tip 1:" + forks[0]["detailed_tip"] + "\n" - s += "Tip 2:" + forks[1]["detailed_tip"] + "\n" - s += "In summary" + sgl.gen("summary") -``` - -### Multi-Modality -Use `sgl.image` to pass an image as input. - -```python -@sgl.function -def image_qa(s, image_file, question): - s += sgl.user(sgl.image(image_file) + question) - s += sgl.assistant(sgl.gen("answer", max_tokens=256) -``` - -See also [local_example_llava_next.py](https://github.com/sgl-project/sglang/blob/main/examples/frontend_language/quick_start/local_example_llava_next.py). - -### Constrained Decoding -Use `regex` to specify a regular expression as a decoding constraint. -This is only supported for local models. - -```python -@sgl.function -def regular_expression_gen(s): - s += "Q: What is the IP address of the Google DNS servers?\n" - s += "A: " + sgl.gen( - "answer", - temperature=0, - regex=r"((25[0-5]|2[0-4]\d|[01]?\d\d?).){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)", - ) -``` - -### JSON Decoding -Use `regex` to specify a JSON schema with a regular expression. - -```python -character_regex = ( - r"""\{\n""" - + r""" "name": "[\w\d\s]{1,16}",\n""" - + r""" "house": "(Gryffindor|Slytherin|Ravenclaw|Hufflepuff)",\n""" - + r""" "blood status": "(Pure-blood|Half-blood|Muggle-born)",\n""" - + r""" "occupation": "(student|teacher|auror|ministry of magic|death eater|order of the phoenix)",\n""" - + r""" "wand": \{\n""" - + r""" "wood": "[\w\d\s]{1,16}",\n""" - + r""" "core": "[\w\d\s]{1,16}",\n""" - + r""" "length": [0-9]{1,2}\.[0-9]{0,2}\n""" - + r""" \},\n""" - + r""" "alive": "(Alive|Deceased)",\n""" - + r""" "patronus": "[\w\d\s]{1,16}",\n""" - + r""" "bogart": "[\w\d\s]{1,16}"\n""" - + r"""\}""" -) - -@sgl.function -def character_gen(s, name): - s += name + " is a character in Harry Potter. Please fill in the following information about this character.\n" - s += sgl.gen("json_output", max_tokens=256, regex=character_regex) -``` - -See also [json_decode.py](https://github.com/sgl-project/sglang/blob/main/examples/frontend_language/usage/json_decode.py) for an additional example of specifying formats with Pydantic models. - -### Batching -Use `run_batch` to run a batch of requests with continuous batching. 
- -```python -@sgl.function -def text_qa(s, question): - s += "Q: " + question + "\n" - s += "A:" + sgl.gen("answer", stop="\n") - -states = text_qa.run_batch( - [ - {"question": "What is the capital of the United Kingdom?"}, - {"question": "What is the capital of France?"}, - {"question": "What is the capital of Japan?"}, - ], - progress_bar=True -) -``` - -### Streaming -Add `stream=True` to enable streaming. - -```python -@sgl.function -def text_qa(s, question): - s += "Q: " + question + "\n" - s += "A:" + sgl.gen("answer", stop="\n") - -state = text_qa.run( - question="What is the capital of France?", - temperature=0.1, - stream=True -) - -for out in state.text_iter(): - print(out, end="", flush=True) -``` - -### Roles - -Use `sgl.system`, `sgl.user` and `sgl.assistant` to set roles when using Chat models. You can also define more complex role prompts using begin and end tokens. - -```python -@sgl.function -def chat_example(s): - s += sgl.system("You are a helpful assistant.") - # Same as: s += s.system("You are a helpful assistant.") - - with s.user(): - s += "Question: What is the capital of France?" - - s += sgl.assistant_begin() - s += "Answer: " + sgl.gen(max_tokens=100, stop="\n") - s += sgl.assistant_end() -``` - -### Tips and Implementation Details -- The `choices` argument in `sgl.gen` is implemented by computing the [token-length normalized log probabilities](https://blog.eleuther.ai/multiple-choice-normalization/) of all choices and selecting the one with the highest probability. -- The `regex` argument in `sgl.gen` is implemented through autoregressive decoding with logit bias masking, according to the constraints set by the regex. It is compatible with `temperature=0` and `temperature != 0`. diff --git a/docs/frontend/patch.py b/docs/frontend/patch.py new file mode 100644 index 00000000000..d16422d08f0 --- /dev/null +++ b/docs/frontend/patch.py @@ -0,0 +1,50 @@ +import os +import weakref + +from sglang.utils import execute_shell_command, reserve_port + +DEFAULT_MAX_RUNNING_REQUESTS = 200 +DEFAULT_MAX_TOTAL_TOKENS = 20480 + +import sglang.srt.server_args as server_args_mod + +_original_post_init = server_args_mod.ServerArgs.__post_init__ + + +def patched_post_init(self): + _original_post_init(self) + if self.max_running_requests is None: + self.max_running_requests = DEFAULT_MAX_RUNNING_REQUESTS + if self.max_total_tokens is None: + self.max_total_tokens = DEFAULT_MAX_TOTAL_TOKENS + self.disable_cuda_graph = True + + +server_args_mod.ServerArgs.__post_init__ = patched_post_init + +process_socket_map = weakref.WeakKeyDictionary() + + +def launch_server_cmd(command: str, host: str = "0.0.0.0", port: int = None): + """ + Launch the server using the given command. + If no port is specified, a free port is reserved. + """ + if port is None: + port, lock_socket = reserve_port(host) + else: + lock_socket = None + + extra_flags = ( + f"--max-running-requests {DEFAULT_MAX_RUNNING_REQUESTS} " + f"--max-total-tokens {DEFAULT_MAX_TOTAL_TOKENS} " + f"--disable-cuda-graph" + ) + + full_command = f"{command} --port {port} {extra_flags}" + process = execute_shell_command(full_command) + + if lock_socket is not None: + process_socket_map[process] = lock_socket + + return process, port diff --git a/docs/index.rst b/docs/index.rst index a3fbd3c106c..8553b5e47b6 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -30,7 +30,6 @@ The core features include: backend/sampling_params.md backend/hyperparameter_tuning.md - .. 
toctree:: :maxdepth: 1 :caption: Advanced Features @@ -45,7 +44,7 @@ The core features include: :maxdepth: 1 :caption: Frontend Tutorial - frontend/frontend.md + frontend/frontend.ipynb frontend/choices_methods.md .. toctree:: @@ -58,6 +57,7 @@ The core features include: :maxdepth: 1 :caption: References + references/deepseek references/general references/hardware references/advanced_deploy diff --git a/docs/references/advanced_deploy.rst b/docs/references/advanced_deploy.rst index 24f46c4aa2d..b1059015dc9 100644 --- a/docs/references/advanced_deploy.rst +++ b/docs/references/advanced_deploy.rst @@ -3,6 +3,5 @@ Multi-Node Deployment .. toctree:: :maxdepth: 1 - deepseek.md multi_node.md k8s.md diff --git a/docs/references/amd.md b/docs/references/amd.md index 4b1c8230d06..4f88b137373 100644 --- a/docs/references/amd.md +++ b/docs/references/amd.md @@ -124,6 +124,8 @@ drun -p 30000:30000 \ --port 30000 ``` +[Running DeepSeek-R1 on a single NDv5 MI300X VM](https://techcommunity.microsoft.com/blog/azurehighperformancecomputingblog/running-deepseek-r1-on-a-single-ndv5-mi300x-vm/4372726) could also be a good reference. + ### Running Llama3.1 Running Llama3.1 is nearly identical. The only difference is in the model specified when starting the server, shown by the following example command: diff --git a/docs/references/deepseek.md b/docs/references/deepseek.md index 0267ab5a877..ad180d1bdba 100644 --- a/docs/references/deepseek.md +++ b/docs/references/deepseek.md @@ -1,10 +1,59 @@ -# DeepSeek Model Usage and Optimizations +# DeepSeek Usage SGLang provides several optimizations specifically designed for the DeepSeek model to boost its inference speed. This document outlines current optimizations for DeepSeek. Additionally, the SGLang team is actively developing enhancements for [DeepSeek V3](https://github.com/sgl-project/sglang/issues/2591). ## Launch DeepSeek V3 with SGLang -SGLang is recognized as one of the top engines for [DeepSeek model inference](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3). Refer to [installation and launch](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3#installation--launch) to learn how to run fast inference of DeepSeek V3/R1 with SGLang. +SGLang is recognized as one of the top engines for [DeepSeek model inference](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3). To run DeepSeek V3/R1 models, the requirements are as follows: + +| Weight Type | Configuration | +|------------|-------------------| +| **Full precision FP8**
*(recommended)* | 8 x H200 | +| | 8 x MI300X | +| | 2 x 8 x H100/800/20 | +| **Full precision BF16** | 2 x 8 x H200 | +| | 2 x 8 x MI300X | +| | 4 x 8 x H100/800/20 | +| | 4 x 8 x A100/A800 | +| **Quantized weights (AWQ)** | 8 x H100/800/20 | +| | 8 x A100/A800 | + + + +Detailed commands for reference: + +- [8 x H200](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3#using-docker-recommended) +- [8 x MI300X](https://docs.sglang.ai/references/amd.html#running-deepseek-v3) +- [2 x 8 x H200](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3#example-serving-with-two-h208-nodes) +- [4 x 8 x A100](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3#example-serving-with-four-a1008-nodes) +- [8 x A100 (AWQ)](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3#example-serving-with-8-a100a800-with-awq-quantization) ### Download Weights @@ -34,7 +83,7 @@ Please refer to [the example](https://github.com/sgl-project/sglang/tree/main/be - **Weight Absorption**: By applying the associative law of matrix multiplication to reorder computation steps, this method balances computation and memory access and improves efficiency in the decoding phase. -- **Triton Decoding Kernel Optimization**: In the MLA decoding kernel, there is only one KV head. This optimization reduces memory access to the KV cache by processing multiple query heads within one block, accelerating the decoding process. +- **FlashInfer MLA Wrapper**: By providing the `--enable-flashinfer-mla` argument, the server will use MLA kernels customized by FlashInfer. More details can be found in [this document](https://docs.flashinfer.ai/api/mla.html). (Experimental) - **FP8 Quantization**: W8A8 FP8 and KV Cache FP8 quantization enables efficient FP8 inference. Additionally, we have implemented Batched Matrix Multiplication (BMM) operator to facilitate FP8 inference in MLA with weight absorption. @@ -84,6 +133,6 @@ Overall, with these optimizations, we have achieved up to a 7x acceleration in o ## FAQ -**Question**: What should I do if model loading takes too long and NCCL timeout occurs? +1. **Question**: What should I do if model loading takes too long and NCCL timeout occurs? -Answer: You can try to add `--dist-timeout 3600` when launching the model, this allows for 1-hour timeout.i + **Answer**: You can try to add `--dist-timeout 3600` when launching the model, which allows for a 1-hour timeout. diff --git a/docs/references/deepseek.rst b/docs/references/deepseek.rst new file mode 100644 index 00000000000..b45383a4b3b --- /dev/null +++ b/docs/references/deepseek.rst @@ -0,0 +1,6 @@ +DeepSeek Usage +========================== +.. toctree:: + :maxdepth: 1 + + deepseek.md diff --git a/docs/router/router.md b/docs/router/router.md index 051233264d4..feda3932992 100644 --- a/docs/router/router.md +++ b/docs/router/router.md @@ -27,11 +27,13 @@ The router supports two working modes: This will be a drop-in replacement for the existing `--dp-size` argument of SGLang Runtime. Under the hood, it uses multi-processes to launch multiple workers, wait for them to be ready, then connect the router to all workers. ```bash -$ python -m sglang_router.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --dp-size 1 +python -m sglang_router.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --dp-size 4 ``` After the server is ready, you can directly send requests to the router as the same way as sending requests to each single worker. 
+Please adjust the batch size accordingly to achieve maximum throughput. + ```python import requests @@ -47,7 +49,7 @@ print(response.json()) This is useful for multi-node DP. First, launch workers on multiple nodes, then launch a router on the main node, and connect the router to all workers. ```bash -$ python -m sglang_router.launch_router --worker-urls http://worker_url_1 http://worker_url_2 +python -m sglang_router.launch_router --worker-urls http://worker_url_1 http://worker_url_2 ``` ## Dynamic Scaling APIs @@ -59,15 +61,17 @@ We offer `/add_worker` and `/remove_worker` APIs to dynamically add or remove wo Usage: ```bash -$ curl -X POST http://localhost:30000/add_worker?url=http://worker_url_1 +curl -X POST http://localhost:30000/add_worker?url=http://worker_url_1 ``` Example: ```bash -$ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --port 30001 -$ curl -X POST http://localhost:30000/add_worker?url=http://127.0.0.1:30001 -Successfully added worker: http://127.0.0.1:30001 +python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --port 30001 + +curl -X POST http://localhost:30000/add_worker?url=http://127.0.0.1:30001 + +# Successfully added worker: http://127.0.0.1:30001 ``` - `/remove_worker` @@ -75,14 +79,15 @@ Usage: ```bash -$ curl -X POST http://localhost:30000/remove_worker?url=http://worker_url_1 +curl -X POST http://localhost:30000/remove_worker?url=http://worker_url_1 ``` Example: ```bash -$ curl -X POST http://localhost:30000/remove_worker?url=http://127.0.0.1:30001 -Successfully removed worker: http://127.0.0.1:30001 +curl -X POST http://localhost:30000/remove_worker?url=http://127.0.0.1:30001 + +# Successfully removed worker: http://127.0.0.1:30001 ``` Note: diff --git a/docs/start/install.md b/docs/start/install.md index 8d58f155773..55b084d381c 100644 --- a/docs/start/install.md +++ b/docs/start/install.md @@ -4,11 +4,13 @@ You can install SGLang using any of the methods below. For running DeepSeek V3/R1, refer to [DeepSeek V3 Support](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3). It is recommended to use the [latest version](https://pypi.org/project/sglang/#history) and deploy it with [Docker](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3#using-docker-recommended) to avoid environment-related problems. -## Method 1: With pip +We recommend using uv to install the dependencies, as it is faster than pip: +## Method 1: With pip or uv ```bash pip install --upgrade pip -pip install "sglang[all]>=0.4.3.post2" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python +pip install uv +uv pip install "sglang[all]>=0.4.3.post2" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python ``` **Quick Fixes to Installation** @@ -141,4 +143,4 @@ sky status --endpoint 30000 sglang - [FlashInfer](https://github.com/flashinfer-ai/flashinfer) is the default attention kernel backend. It only supports sm75 and above. If you encounter any FlashInfer-related issues on sm75+ devices (e.g., T4, A10, A100, L4, L40S, H100), please switch to other kernels by adding `--attention-backend triton --sampling-backend pytorch` and open an issue on GitHub. - If you only need to use OpenAI models with the frontend language, you can avoid installing other dependencies by using `pip install "sglang[openai]"`. - The language frontend operates independently of the backend runtime. 
You can install the frontend locally without needing a GPU, while the backend can be set up on a GPU-enabled machine. To install the frontend, run `pip install sglang`, and for the backend, use `pip install sglang[srt]`. `srt` is the abbreviation of SGLang runtime. -- To reinstall flashinfer locally, use the following command: `pip install "flashinfer-python>=0.2.1.post2" -i https://flashinfer.ai/whl/cu124/torch2.5 --force-reinstall --no-deps` and then delete the cache with `rm -rf ~/.cache/flashinfer`. +- To reinstall flashinfer locally, use the following command: `pip install "flashinfer-python>=0.2.2.post1" -i https://flashinfer.ai/whl/cu124/torch2.5 --force-reinstall --no-deps` and then delete the cache with `rm -rf ~/.cache/flashinfer`. diff --git a/examples/runtime/engine/hidden_states.py b/examples/runtime/engine/hidden_states.py new file mode 100644 index 00000000000..9c7b89b740f --- /dev/null +++ b/examples/runtime/engine/hidden_states.py @@ -0,0 +1,44 @@ +""" +Usage: +python hidden_states.py + +Note that each time you change the `return_hidden_states` parameter, +the cuda graph will be recaptured, which might lead to a performance hit. +So avoid getting hidden states and completions alternately. +""" + +import sglang as sgl + + +def main(): + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + # Create an LLM. + llm = sgl.Engine( + model_path="Alibaba-NLP/gte-Qwen2-1.5B-instruct", + ) + + sampling_params = { + "temperature": 0.8, + "top_p": 0.95, + "max_new_tokens": 10, + "return_hidden_states": True, + } + + outputs = llm.generate(prompts, sampling_params=sampling_params) + for prompt, output in zip(prompts, outputs): + print("===============================") + print( + f"Prompt: {prompt}\nGenerated text: {output['text']}\nPrompt_Tokens: {output['meta_info']['prompt_tokens']}\tCompletion_tokens: {output['meta_info']['completion_tokens']}\nHidden states: {[i.shape for i in output['meta_info']['hidden_states']]}" + ) + print() + + +# The __main__ condition is necessary here because we use "spawn" to create subprocesses +# Spawn starts a fresh program every time, if there is no __main__, it will run into infinite loop to keep spawning processes from sgl.Engine +if __name__ == "__main__": + main() diff --git a/examples/runtime/engine/offline_batch_inference_vlm.py b/examples/runtime/engine/offline_batch_inference_vlm.py index 808d0fce9b7..28ab7a2efdb 100644 --- a/examples/runtime/engine/offline_batch_inference_vlm.py +++ b/examples/runtime/engine/offline_batch_inference_vlm.py @@ -5,56 +5,45 @@ import argparse import dataclasses +import io +import os -from transformers import AutoProcessor +import requests +from PIL import Image import sglang as sgl -from sglang.srt.openai_api.adapter import v1_chat_generate_request -from sglang.srt.openai_api.protocol import ChatCompletionRequest +from sglang.srt.conversation import chat_templates from sglang.srt.server_args import ServerArgs def main( server_args: ServerArgs, ): - # Create an LLM. vlm = sgl.Engine(**dataclasses.asdict(server_args)) - # prepare prompts. 
- messages = [ - { - "role": "user", - "content": [ - {"type": "text", "text": "What’s in this image?"}, - { - "type": "image_url", - "image_url": { - "url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true", - }, - }, - ], - } - ] - chat_request = ChatCompletionRequest( - messages=messages, - model=server_args.model_path, - temperature=0.8, - top_p=0.95, - ) - gen_request, _ = v1_chat_generate_request( - [chat_request], - vlm.tokenizer_manager, - ) + conv = chat_templates[server_args.chat_template].copy() + image_token = conv.image_token + + image_url = "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true" + + prompt = f"What's in this image?\n{image_token}" - outputs = vlm.generate( - input_ids=gen_request.input_ids, - image_data=gen_request.image_data, - sampling_params=gen_request.sampling_params, + sampling_params = { + "temperature": 0.001, + "max_new_tokens": 30, + } + + output = vlm.generate( + prompt=prompt, + image_data=image_url, + sampling_params=sampling_params, ) print("===============================") - print(f"Prompt: {messages[0]['content'][0]['text']}") - print(f"Generated text: {outputs['text']}") + print(f"Prompt: {prompt}") + print(f"Generated text: {output['text']}") + + vlm.shutdown() # The __main__ condition is necessary here because we use "spawn" to create subprocesses @@ -63,5 +52,6 @@ def main( parser = argparse.ArgumentParser() ServerArgs.add_cli_args(parser) args = parser.parse_args() + server_args = ServerArgs.from_cli_args(args) main(server_args) diff --git a/python/pyproject.toml b/python/pyproject.toml index 91430603cfd..1c762bbcb12 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -35,14 +35,15 @@ runtime_common = [ "torchao>=0.7.0", "uvicorn", "uvloop", - "xgrammar==0.1.10", + "xgrammar==0.1.14", "ninja", "transformers==4.48.3", + "llguidance>=0.6.15" ] srt = [ "sglang[runtime_common]", "sgl-kernel>=0.0.3.post6", - "flashinfer_python>=0.2.1.post2", + "flashinfer_python>=0.2.2.post1", "torch==2.5.1", "vllm>=0.6.4.post1,<=0.7.2", "cuda-python", diff --git a/python/sglang/srt/constrained/base_grammar_backend.py b/python/sglang/srt/constrained/base_grammar_backend.py index 6f304ea171e..82a647fa02e 100644 --- a/python/sglang/srt/constrained/base_grammar_backend.py +++ b/python/sglang/srt/constrained/base_grammar_backend.py @@ -86,6 +86,13 @@ def create_grammar_backend(server_args: ServerArgs, tokenizer, vocab_size): from sglang.srt.constrained.xgrammar_backend import XGrammarGrammarBackend grammar_backend = XGrammarGrammarBackend(tokenizer, vocab_size=vocab_size) + elif server_args.grammar_backend == "llguidance": + from sglang.srt.constrained.llguidance_backend import GuidanceBackend + + grammar_backend = GuidanceBackend( + tokenizer=tokenizer, + whitespace_pattern=server_args.constrained_json_whitespace_pattern, + ) else: raise ValueError(f"Invalid grammar backend: {server_args.grammar_backend}") diff --git a/python/sglang/srt/constrained/llguidance_backend.py b/python/sglang/srt/constrained/llguidance_backend.py new file mode 100644 index 00000000000..5d2b69790db --- /dev/null +++ b/python/sglang/srt/constrained/llguidance_backend.py @@ -0,0 +1,146 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Constrained decoding with llguidance backend.""" + +import json +import os +from typing import List, Optional, Tuple + +import llguidance +import llguidance.hf +import llguidance.torch +import torch +from llguidance.gbnf_to_lark import any_to_lark + +from sglang.srt.constrained.base_grammar_backend import ( + BaseGrammarBackend, + BaseGrammarObject, +) + + +class GuidanceGrammar(BaseGrammarObject): + def __init__( + self, llguidance_tokenizer: llguidance.LLTokenizer, serialized_grammar: str + ): + self.llguidance_tokenizer = llguidance_tokenizer + self.serialized_grammar = serialized_grammar + + # TODO: add support for fast-forward tokens in the future + self.ll_interpreter = llguidance.LLInterpreter( + self.llguidance_tokenizer, + self.serialized_grammar, + enable_backtrack=False, + enable_ff_tokens=False, + log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")), + ) + self.pending_ff_tokens: list[int] = [] + self.finished = False + self.bitmask = None + + def try_jump_forward(self, tokenizer) -> Tuple[List[int], str]: + if len(self.pending_ff_tokens) > 0: + s = self.llguidance_tokenizer.decode_str(self.pending_ff_tokens) + ff_tokens = self.pending_ff_tokens + self.pending_ff_tokens = [] + return (ff_tokens, s) + + return None + + def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]: + return "", -1 + + def jump_and_retokenize( + self, old_output_ids: List[int], new_output_ids: List[int], next_state: int + ): + pass + + def accept_token(self, token: int): + backtrack, ff_tokens = self.ll_interpreter.commit_token(token) + if len(ff_tokens) > 0 and backtrack == 0: + # first token is last generated token + ff_tokens = ff_tokens[1:] + self.pending_ff_tokens.extend(ff_tokens) + + def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None: + if len(self.pending_ff_tokens) > 0: + # if we have pending fast-forward tokens, + # just return them immediately + ff_token = self.pending_ff_tokens.pop(0) + vocab_mask[idx, :] = 0 + vocab_mask[idx, ff_token // 32] = 1 << (ff_token % 32) + return + + if self.ll_interpreter.has_pending_stop(): + self.finished = True + + llguidance.torch.fill_next_token_bitmask(self.ll_interpreter, vocab_mask, idx) + + def allocate_vocab_mask( + self, vocab_size: int, batch_size: int, device + ) -> torch.Tensor: + if self.bitmask is None or self.bitmask.shape[0] < batch_size: + # only create bitmask when batch gets larger + self.bitmask = llguidance.torch.allocate_token_bitmask( + batch_size, self.llguidance_tokenizer.vocab_size + ) + bitmask = self.bitmask + else: + bitmask = self.bitmask[:batch_size] + + return bitmask + + @staticmethod + def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor: + return vocab_mask.to(device, non_blocking=True) + + @staticmethod + def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None: + llguidance.torch.apply_token_bitmask_inplace(logits, vocab_mask) + + def copy(self): + return GuidanceGrammar( + llguidance_tokenizer=self.llguidance_tokenizer, + 
serialized_grammar=self.serialized_grammar, + ) + + +class GuidanceBackend(BaseGrammarBackend): + def __init__(self, tokenizer, whitespace_pattern: Optional[str] = None): + super().__init__() + + self.tokenizer = tokenizer + self.whitespace_flexible = ( + True if whitespace_pattern == "whitespace_flexible" else False + ) + self.llguidance_tokenizer = llguidance.hf.from_tokenizer(self.tokenizer, None) + + def init_value_impl(self, key: Tuple[str, str]) -> GuidanceGrammar: + mode, value = key + if mode == "json": + json_schema = value + compiler = llguidance.JsonCompiler( + whitespace_flexible=self.whitespace_flexible + ) + serialized_grammar = compiler.compile(json_schema) + elif mode == "regex": + compiler = llguidance.RegexCompiler() + serialized_grammar = compiler.compile(regex=value) + elif mode == "ebnf": + compiler = llguidance.LarkCompiler() + serialized_grammar = compiler.compile(any_to_lark(value)) + + return GuidanceGrammar( + llguidance_tokenizer=self.llguidance_tokenizer, + serialized_grammar=serialized_grammar, + ) diff --git a/python/sglang/srt/constrained/xgrammar_backend.py b/python/sglang/srt/constrained/xgrammar_backend.py index c423a567eda..7bf14bfc285 100644 --- a/python/sglang/srt/constrained/xgrammar_backend.py +++ b/python/sglang/srt/constrained/xgrammar_backend.py @@ -19,7 +19,6 @@ import torch from xgrammar import ( CompiledGrammar, - Grammar, GrammarCompiler, GrammarMatcher, TokenizerInfo, @@ -135,9 +134,7 @@ def init_value_impl(self, key: Tuple[str, str]) -> XGrammarGrammar: return None elif key_type == "regex": try: - ctx = self.grammar_compiler.compile_grammar( - Grammar.from_regex(key_string) - ) + ctx = self.grammar_compiler.compile_regex(key_string) except RuntimeError as e: logging.warning(f"Skip invalid regex: regex={key_string}, {e=}") return None diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 1fb2f7c64ca..671b8f2c3f5 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -330,7 +330,7 @@ def _set_envs_and_config(server_args: ServerArgs): if server_args.attention_backend == "flashinfer": assert_pkg_version( "flashinfer_python", - "0.2.1.post2", + "0.2.2.post1", "Please uninstall the old version and " "reinstall the latest version by following the instructions " "at https://docs.flashinfer.ai/installation.html.", diff --git a/python/sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py b/python/sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py index 2182445012e..4c3e6396828 100644 --- a/python/sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +++ b/python/sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py @@ -230,7 +230,7 @@ def _fwd_grouped_kernel_stage1_rope( other=0.0, ) # positional embedding part of keys - if USE_ROPE and start_n >= cur_batch_seq_len - BLOCK_N: + if (USE_ROPE and LAST_SPLIT) and start_n >= cur_batch_seq_len - BLOCK_N: k_pe = tl.where( offs_n[None, :] != (split_kv_end - 1), k_pe, diff --git a/python/sglang/srt/layers/moe/ep_moe/kernels.py b/python/sglang/srt/layers/moe/ep_moe/kernels.py index e0486891aa7..ae7d13ea593 100644 --- a/python/sglang/srt/layers/moe/ep_moe/kernels.py +++ b/python/sglang/srt/layers/moe/ep_moe/kernels.py @@ -1,10 +1,17 @@ import logging -from typing import Optional +from typing import List, Optional import torch import triton import triton.language as tl +from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8 + +_is_cuda = 
torch.cuda.is_available() and torch.version.cuda +if _is_cuda: + from sglang.srt.layers.quantization.fp8_kernel import ( + sglang_per_token_group_quant_fp8, + ) logger = logging.getLogger(__name__) @@ -218,12 +225,19 @@ def grouped_gemm_triton_kernel( seg_indptr, weight_indices, m_num_tiles_indptr, - use_fp8_w8a8, scale_a, scale_b, + use_fp8_w8a8: tl.constexpr, + group_n: tl.constexpr, + group_k: tl.constexpr, a_stride_0: tl.constexpr, b_stride_0: tl.constexpr, b_stride_1: tl.constexpr, + as_stride_0: tl.constexpr, + as_stride_1: tl.constexpr, + bs_stride_0: tl.constexpr, + bs_stride_2: tl.constexpr, + bs_stride_1: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, @@ -260,6 +274,12 @@ def grouped_gemm_triton_kernel( + (n_range_start + offs_bn[:, None]) * b_stride_1 + offs_k[None, :] ) + + if group_k > 0 and group_n > 0: + a_scale_ptrs = scale_a + (m_range_start + offs_am[:, None]) * as_stride_0 + offs_bsn = (n_range_start + offs_bn) // group_n + b_scale_ptrs = scale_b + (expert_id * bs_stride_0) + offs_bsn * bs_stride_1 + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): a_tile = tl.load( @@ -268,14 +288,23 @@ def grouped_gemm_triton_kernel( b_tile = tl.load( b_ptr, mask=offs_k[None, :] < (K - k * BLOCK_SIZE_K), other=0.0 ) - accumulator = tl.dot(a_tile, b_tile.T, accumulator) + + if group_k > 0 and group_n > 0: + k_start = k * BLOCK_SIZE_K + offs_ks = k_start // group_k + a_scale = tl.load(a_scale_ptrs + offs_ks * as_stride_1) + b_scale = tl.load(b_scale_ptrs + offs_ks * bs_stride_2) + accumulator += tl.dot(a_tile, b_tile.T) * a_scale * b_scale[None, :] + else: + accumulator = tl.dot(a_tile, b_tile.T, accumulator) a_ptr += BLOCK_SIZE_K b_ptr += BLOCK_SIZE_K - if use_fp8_w8a8: + if use_fp8_w8a8 and not (group_k > 0 and group_n > 0): scale_a_value = tl.load(scale_a + expert_id) scale_b_value = tl.load(scale_b + expert_id) accumulator *= scale_a_value * scale_b_value + c_tile = accumulator.to(c_dtype) offs_cm = m_range_start + tl.arange(0, BLOCK_SIZE_M) @@ -307,14 +336,29 @@ def grouped_gemm_triton( use_fp8_w8a8: bool = False, scale_a: torch.Tensor = None, scale_b: torch.Tensor = None, + block_shape: Optional[List[int]] = None, ): assert weight_column_major == True # TODO: more - if use_fp8_w8a8: + if use_fp8_w8a8 and block_shape is None: assert scale_a is not None and scale_b is not None + if block_shape is not None: + assert len(block_shape) == 2 + block_n, block_k = block_shape[0], block_shape[1] + if _is_cuda: + a, scale_a = sglang_per_token_group_quant_fp8(a, block_k) + else: + a, scale_a = per_token_group_quant_fp8(a, block_k) + + assert triton.cdiv(a.shape[-1], block_k) == scale_a.shape[-1] + assert triton.cdiv(b.shape[-2], block_n) == scale_b.shape[-2] + assert triton.cdiv(b.shape[-1], block_k) == scale_b.shape[-1] + + # TODO: adjust config or tune kernel + # Reduce block size to prevent L40 shared memory overflow. 
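For readers following the new block-quantized path, here is a minimal single-expert PyTorch sketch of the per-block dequantization the kernel above performs. It assumes `a` holds activations already quantized per token-group of size `block_k` with scales `scale_a`, and `b` holds column-major weights with one scale per `[block_n, block_k]` tile in `scale_b`; the names mirror the kernel arguments and the helper itself is illustrative only.

```python
import torch

def blockwise_dequant_gemm(a, scale_a, b, scale_b, block_n, block_k):
    # a:       [M, K] quantized activations (one scale per token per K group)
    # scale_a: [M, ceil(K / block_k)]
    # b:       [N, K] quantized weights (column-major layout, as in the kernel)
    # scale_b: [ceil(N / block_n), ceil(K / block_k)]
    M, K = a.shape
    N = b.shape[0]
    out = torch.zeros(M, N, dtype=torch.float32)
    n_group = torch.arange(N) // block_n  # N-block index of each output column
    for k0 in range(0, K, block_k):
        ks = k0 // block_k
        a_tile = a[:, k0 : k0 + block_k].float()
        b_tile = b[:, k0 : k0 + block_k].float()
        a_s = scale_a[:, ks : ks + 1]  # [M, 1], broadcast over output columns
        b_s = scale_b[n_group, ks]     # [N], one scale per output column
        out += (a_tile @ b_tile.T) * a_s * b_s[None, :]
    return out
```

This mirrors `accumulator += tl.dot(a_tile, b_tile.T) * a_scale * b_scale[None, :]` in the kernel; on the non-block path the per-expert tensor scales are instead applied once after the K loop.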
config = { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, } @@ -338,12 +382,19 @@ def grouped_gemm_triton( seg_indptr, weight_indices, m_num_tiles_indptr, - use_fp8_w8a8, scale_a, scale_b, + use_fp8_w8a8, + 0 if block_shape is None else block_shape[0], + 0 if block_shape is None else block_shape[1], a.stride(0), b.stride(0), b.stride(1), + scale_a.stride(0) if scale_a is not None and scale_a.ndim == 2 else 0, + scale_a.stride(1) if scale_a is not None and scale_a.ndim == 2 else 0, + scale_b.stride(0) if scale_b is not None and scale_b.ndim >= 2 else 0, + scale_b.stride(2) if scale_b is not None and scale_b.ndim == 3 else 0, + scale_b.stride(1) if scale_b is not None and scale_b.ndim >= 2 else 0, **config, ) return c diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index 4d6040646b3..7468c0b9192 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -17,6 +17,7 @@ run_moe_ep_preproess, silu_and_mul_triton_kernel, ) +from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoEMethodBase from sglang.srt.layers.moe.topk import select_experts from sglang.srt.layers.quantization.base_config import ( @@ -61,6 +62,7 @@ def forward( use_fp8_w8a8: bool = False, scale_a: torch.Tensor = None, scale_b: torch.Tensor = None, + block_shape: Optional[List[int]] = None, ): if self.use_flashinfer: # TODO: flashinfer @@ -87,6 +89,7 @@ def forward( use_fp8_w8a8, scale_a, scale_b, + block_shape=block_shape, ) return c @@ -147,12 +150,20 @@ def __init__( if quant_config is None: self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod() self.use_fp8_w8a8 = False + self.use_block_quant = False + self.block_shape = None self.activation_scheme = None else: self.quant_method: Optional[QuantizeMethodBase] = Fp8EPMoEMethod( quant_config ) self.use_fp8_w8a8 = True + self.use_block_quant = getattr(self.quant_method, "block_quant", False) + self.block_shape = ( + self.quant_method.quant_config.weight_block_size + if self.use_block_quant + else None + ) self.fp8_dtype = torch.float8_e4m3fn self.activation_scheme = quant_config.activation_scheme @@ -173,7 +184,8 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): if self.grouped_gemm_runner is None: self.grouped_gemm_runner = GroupedGemmRunner( - hidden_states.device, use_flashinfer=False # TODO: use flashinfer + hidden_states.device, + use_flashinfer=False, # TODO: use flashinfer ) topk_weights, topk_ids = select_experts( @@ -195,9 +207,13 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): gateup_input = torch.empty( (int(hidden_states.shape[0] * self.top_k), hidden_states.shape[1]), device=hidden_states.device, - dtype=self.fp8_dtype if self.use_fp8_w8a8 else hidden_states.dtype, + dtype=( + self.fp8_dtype + if (self.use_fp8_w8a8 and not self.use_block_quant) + else hidden_states.dtype + ), ) - if self.activation_scheme == "dynamic": + if self.activation_scheme == "dynamic" and not self.use_block_quant: max_value = ( torch.max(hidden_states) .repeat(self.num_experts_per_partition) @@ -243,7 +259,12 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): weight_indices=weight_indices_cur_rank, use_fp8_w8a8=self.use_fp8_w8a8, scale_a=self.w13_input_scale, - scale_b=self.w13_weight_scale, + scale_b=( + 
self.w13_weight_scale_inv + if self.use_block_quant + else self.w13_weight_scale + ), + block_shape=self.block_shape, ) # Act @@ -251,9 +272,13 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): gateup_output.shape[0], gateup_output.shape[1] // 2, device=gateup_output.device, - dtype=self.fp8_dtype if self.use_fp8_w8a8 else hidden_states.dtype, + dtype=( + self.fp8_dtype + if (self.use_fp8_w8a8 and not self.use_block_quant) + else hidden_states.dtype + ), ) - if self.w2_input_scale is None: + if self.w2_input_scale is None and not self.use_block_quant: self.w2_input_scale = torch.ones( self.num_experts_per_partition, dtype=torch.float32, @@ -291,7 +316,12 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): weight_indices=weight_indices_cur_rank, use_fp8_w8a8=self.use_fp8_w8a8, scale_a=self.w2_input_scale, - scale_b=self.w2_weight_scale, + scale_b=( + self.w2_weight_scale_inv + if self.use_block_quant + else self.w2_weight_scale + ), + block_shape=self.block_shape, ) # PostReorder @@ -358,7 +388,11 @@ def weight_loader( # Special case for fp8 scales. if "scale" in weight_name: self._load_fp8_scale( - param.data, loaded_weight, weight_name, shard_id, expert_id + param.data, + loaded_weight, + weight_name, + shard_id, + expert_id, ) return @@ -395,18 +429,33 @@ def _load_fp8_scale( param_data[expert_id] = loaded_weight # Weight scales elif "weight_scale" in weight_name: + if self.use_block_quant: + block_n, block_k = self.block_shape[0], self.block_shape[1] + if shard_id == "w1": + param_data[expert_id][ + : (self.intermediate_size + block_n - 1) // block_n, : + ] = loaded_weight + elif shard_id == "w3": + param_data[expert_id][ + (self.intermediate_size + block_n - 1) // block_n :, : + ] = loaded_weight + else: # w2 + param_data[expert_id] = loaded_weight # If we are in merged column case (gate_up_proj) - if shard_id in ("w1", "w3"): - # We have to keep the weight scales of w1 and w3 because - # we need to re-quantize w1/w3 weights after weight loading. - idx = 0 if shard_id == "w1" else 1 - param_data[expert_id][idx] = loaded_weight - # If we are in the row parallel case (down_proj) else: - param_data[expert_id] = loaded_weight + if shard_id in ("w1", "w3"): + # We have to keep the weight scales of w1 and w3 because + # we need to re-quantize w1/w3 weights after weight loading. + idx = 0 if shard_id == "w1" else 1 + param_data[expert_id][idx] = loaded_weight + + # If we are in the row parallel case (down_proj) + else: + param_data[expert_id] = loaded_weight class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp): + def create_weights( self, layer: torch.nn.Module, @@ -498,6 +547,7 @@ class Fp8EPMoEMethod(Fp8MoEMethod): def __init__(self, quant_config: Fp8Config): self.quant_config = quant_config + self.block_quant = self.quant_config.weight_block_size is not None def create_weights( self, @@ -512,6 +562,29 @@ def create_weights( if self.quant_config.is_checkpoint_fp8_serialized: params_dtype = torch.float8_e4m3fn + tp_size = get_tensor_model_parallel_world_size() + if self.block_quant: + block_n, block_k = ( + self.quant_config.weight_block_size[0], + self.quant_config.weight_block_size[1], + ) + # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n. 
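A small numeric sketch of the per-expert scale shapes that this alignment requirement protects; the layer sizes below are assumed purely for illustration, while the ceil-division matches the parameter shapes registered further down.

```python
def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

block_n, block_k = 128, 128                  # weight_block_size
intermediate_size, hidden_size = 2048, 7168  # assumed example sizes

# Per-expert shapes of the block-wise weight scales:
w13_scale = (2 * ceil_div(intermediate_size, block_n), ceil_div(hidden_size, block_k))
w2_scale = (ceil_div(hidden_size, block_n), ceil_div(intermediate_size, block_k))
print(w13_scale, w2_scale)  # (32, 56) (56, 16)
```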
+ # Required by collum parallel or enabling merged weights + if intermediate_size % block_n != 0: + raise ValueError( + f"The output_size of gate's and up's weight = " + f"{intermediate_size} is not divisible by " + f"weight quantization block_n = {block_n}." + ) + if tp_size > 1: + # Required by row parallel + if intermediate_size % block_k != 0: + raise ValueError( + f"The input_size of down's weight = " + f"{intermediate_size} is not divisible by " + f"weight quantization block_k = {block_k}." + ) + # WEIGHTS w13_weight = torch.nn.Parameter( torch.empty( @@ -538,21 +611,49 @@ def create_weights( set_weight_attrs(w2_weight, extra_weight_attrs) # WEIGHT_SCALES - # Allocate 2 scales for w1 and w3 respectively. - w13_weight_scale = torch.nn.Parameter( - torch.ones(num_experts_per_partition, 2, dtype=torch.float32), - requires_grad=False, - ) - layer.register_parameter("w13_weight_scale", w13_weight_scale) + if self.block_quant: + w13_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts_per_partition, + 2 * ((intermediate_size + block_n - 1) // block_n), + (hidden_size + block_k - 1) // block_k, + dtype=torch.float32, + ), + requires_grad=False, + ) + w2_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts_per_partition, + (hidden_size + block_n - 1) // block_n, + (intermediate_size + block_k - 1) // block_k, + dtype=torch.float32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale_inv", w13_weight_scale) + layer.register_parameter("w2_weight_scale_inv", w2_weight_scale) + assert self.quant_config.activation_scheme == "dynamic" + else: + # WEIGHT_SCALES + # Allocate 2 scales for w1 and w3 respectively. + w13_weight_scale = torch.nn.Parameter( + torch.ones(num_experts_per_partition, 2, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) - w2_weight_scale = torch.nn.Parameter( - torch.ones(num_experts_per_partition, dtype=torch.float32), - requires_grad=False, - ) - layer.register_parameter("w2_weight_scale", w2_weight_scale) + w2_weight_scale = torch.nn.Parameter( + torch.ones(num_experts_per_partition, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w2_weight_scale", w2_weight_scale) # Add the quantization method used (per tensor/grouped/channel) # to ensure the weight scales are loaded in properly - extra_weight_attrs.update({"quant_method": "tensor"}) + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value} + if self.block_quant + else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} + ) # If loading fp8 checkpoint, pass the weight loaders. 
# If loading an fp16 checkpoint, do not (we will quantize in # process_weights_after_loading() diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..1cd253d118e --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..6ae89c7571a --- /dev/null +++ 
b/python/sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..653bb997ee1 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + 
"BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..7d8044772f8 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "32": 
{ + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..5011fd572c4 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + 
"GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..7b10d6d7814 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + 
"BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..d72288781e4 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2048": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + 
"GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..933170b7be9 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json 
new file mode 100644 index 00000000000..13e3c20b748 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..5e40bbe6471 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + 
}, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..246062a221d --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, 
+ "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..bfdff8adcbd --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + 
"BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..b2cab083111 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + 
"GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..6ae00329ecc --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + 
"3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index f4ffed10b1d..ea72804854e 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -607,9 +607,6 @@ class ScheduleBatch: # Enable custom logit processor enable_custom_logit_processor: bool = False - # Return hidden states - return_hidden_states: bool = False - @classmethod def init_new( cls, @@ -621,7 +618,6 @@ def init_new( enable_overlap: bool, spec_algorithm: SpeculativeAlgorithm, enable_custom_logit_processor: bool, - return_hidden_states: bool = False, ): return cls( reqs=reqs, @@ -636,7 +632,6 @@ def init_new( device=req_to_token_pool.device, spec_algorithm=spec_algorithm, enable_custom_logit_processor=enable_custom_logit_processor, - return_hidden_states=return_hidden_states, ) def batch_size(self): @@ -1205,7 +1200,7 @@ def get_model_worker_batch(self): spec_info=self.spec_info, capture_hidden_mode=( CaptureHiddenMode.FULL - if self.return_hidden_states + if self.sampling_info.return_hidden_states else ( getattr( self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index e4a141a9c8b..ea6d6751819 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -1030,7 +1030,6 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: self.enable_overlap, self.spec_algorithm, self.server_args.enable_custom_logit_processor, - self.server_args.return_hidden_states, ) new_batch.prepare_for_extend() @@ -1221,9 +1220,8 @@ def process_batch_result_prefill( logprob_pt += self.add_logprob_return_values( i, req, logprob_pt, next_token_ids, logits_output ) - if ( - self.server_args.return_hidden_states + req.sampling_params.return_hidden_states and logits_output.hidden_states is not None ): req.hidden_states.append( @@ -1331,7 +1329,7 @@ def process_batch_result_decode( ) if ( - self.server_args.return_hidden_states + req.sampling_params.return_hidden_states and logits_output.hidden_states is not None ): req.hidden_states.append(logits_output.hidden_states[i].cpu().clone()) @@ -1459,7 +1457,10 @@ def stream_output( completion_tokens = [] cached_tokens = [] spec_verify_ct = [] - output_hidden_states = [] if self.server_args.return_hidden_states else None + return_hidden_states = any( + req.sampling_params.return_hidden_states for req in reqs + ) + output_hidden_states = [] if return_hidden_states else None if return_logprob: input_token_logprobs_val = [] @@ -1526,7 +1527,7 @@ def stream_output( output_top_logprobs_val.append(req.output_top_logprobs_val) output_top_logprobs_idx.append(req.output_top_logprobs_idx) - if self.server_args.return_hidden_states: + if req.sampling_params.return_hidden_states: output_hidden_states.append(req.hidden_states) # Send to detokenizer @@ -1619,7 +1620,6 @@ def get_idle_batch(self): self.enable_overlap, self.spec_algorithm, self.server_args.enable_custom_logit_processor, - self.server_args.return_hidden_states, ) idle_batch.prepare_for_idle() return idle_batch diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py 
b/python/sglang/srt/model_executor/cuda_graph_runner.py index d3f2e5146e1..e8877e1f8d0 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -120,7 +120,7 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner): if max(capture_bs) > model_runner.req_to_token_pool.size: # In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests - # is very samll. We add more values here to make sure we capture the maximum bs. + # is very small. We add more values here to make sure we capture the maximum bs. capture_bs = list( sorted( set( @@ -175,6 +175,7 @@ def __init__(self, model_runner: ModelRunner): # Batch sizes to capture self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner) self.capture_forward_mode = ForwardMode.DECODE + self.capture_hidden_mode = CaptureHiddenMode.NULL self.num_tokens_per_bs = 1 if model_runner.spec_algorithm.is_eagle(): if self.model_runner.is_draft_worker: @@ -335,6 +336,10 @@ def capture_one_batch_size(self, bs: int, forward: Callable): gathered_buffer = None spec_info = self.get_spec_info(num_tokens) + if self.capture_hidden_mode != CaptureHiddenMode.FULL: + self.capture_hidden_mode = ( + spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL + ) forward_batch = ForwardBatch( forward_mode=self.capture_forward_mode, @@ -355,15 +360,7 @@ def capture_one_batch_size(self, bs: int, forward: Callable): mrope_positions=mrope_positions, spec_algorithm=self.model_runner.spec_algorithm, spec_info=spec_info, - capture_hidden_mode=( - CaptureHiddenMode.FULL - if self.model_runner.server_args.return_hidden_states - else ( - spec_info.capture_hidden_mode - if spec_info - else CaptureHiddenMode.NULL - ) - ), + capture_hidden_mode=self.capture_hidden_mode, ) # Attention backend @@ -406,6 +403,23 @@ def run_once(): def replay(self, forward_batch: ForwardBatch): assert forward_batch.out_cache_loc is not None + hidden_mode_from_spec_info = getattr( + forward_batch.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL + ) + # If the capture_hidden_mode changes, we need to recapture the graph + if ( + forward_batch.sampling_info.return_hidden_states + and self.capture_hidden_mode != CaptureHiddenMode.FULL + ): + self.capture_hidden_mode = CaptureHiddenMode.FULL + self.capture() + elif ( + not forward_batch.sampling_info.return_hidden_states + and self.capture_hidden_mode != hidden_mode_from_spec_info + ): + self.capture_hidden_mode = hidden_mode_from_spec_info + self.capture() + raw_bs = forward_batch.batch_size raw_num_token = raw_bs * self.num_tokens_per_bs diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index 9521a34f4f6..6297e1fe058 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -37,6 +37,9 @@ class SamplingBatchInfo: # Whether any request has custom logit processor has_custom_logit_processor: bool + # Whether any request needs to return hidden states + return_hidden_states: bool + # Bias Tensors vocab_size: int grammars: Optional[List] = None @@ -91,6 +94,9 @@ def from_schedule_batch( and any(r.custom_logit_processor for r in reqs) # then check the requests. 
) + # Check if any request needs to return hidden states + return_hidden_states = any(r.sampling_params.return_hidden_states for r in reqs) + if has_custom_logit_processor: # Merge the same type of custom logit processors together processor_dict = {} @@ -130,6 +136,7 @@ def from_schedule_batch( device=device, custom_params=custom_params, custom_logit_processor=merged_custom_logit_processor, + return_hidden_states=return_hidden_states, ) # TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge. @@ -336,6 +343,10 @@ def merge_batch(self, other: "SamplingBatchInfo"): self.logit_bias = SamplingBatchInfo.merge_bias_tensor( self.logit_bias, other.logit_bias, len(self), len(other), self.device ) + + # Merge the return hidden states flag + self.return_hidden_states |= other.return_hidden_states + # Merge the custom logit processors and custom params lists if self.has_custom_logit_processor or other.has_custom_logit_processor: # Merge the custom logit processors diff --git a/python/sglang/srt/sampling/sampling_params.py b/python/sglang/srt/sampling/sampling_params.py index 2224fb0919a..d82a0f28217 100644 --- a/python/sglang/srt/sampling/sampling_params.py +++ b/python/sglang/srt/sampling/sampling_params.py @@ -23,7 +23,7 @@ class SamplingParams: The sampling parameters. See docs/references/sampling_params.md or - https://docs.sglang.ai/references/sampling_params.html + https://docs.sglang.ai/backend/sampling_params.html for the documentation. """ @@ -48,6 +48,7 @@ def __init__( no_stop_trim: bool = False, ignore_eos: bool = False, skip_special_tokens: bool = True, + return_hidden_states: bool = False, custom_params: Optional[Dict[str, Any]] = None, ) -> None: self.temperature = temperature @@ -72,6 +73,7 @@ def __init__( self.json_schema = json_schema self.ebnf = ebnf self.no_stop_trim = no_stop_trim + self.return_hidden_states = return_hidden_states self.custom_params = custom_params # Process some special cases diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 610c0f5a87c..fd2188dcce7 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -162,7 +162,6 @@ class ServerArgs: delete_ckpt_after_loading: bool = False enable_memory_saver: bool = False allow_auto_truncate: bool = False - return_hidden_states: bool = False enable_custom_logit_processor: bool = False tool_call_parser: str = None enable_hierarchical_cache: bool = False @@ -698,7 +697,7 @@ def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument( "--grammar-backend", type=str, - choices=["xgrammar", "outlines"], + choices=["xgrammar", "outlines", "llguidance"], default=ServerArgs.grammar_backend, help="Choose the backend for grammar-guided decoding.", ) @@ -917,11 +916,6 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="Enable users to pass custom logit processors to the server (disabled by default for security)", ) - parser.add_argument( - "--return-hidden-states", - action="store_true", - help="Return hidden states in the response.", - ) parser.add_argument( "--tool-call-parser", type=str, diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 8f5227d5169..3bcb086f295 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -499,14 +499,17 @@ def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = N pass if include_parent: - try: - itself.kill() + if parent_pid == os.getpid(): + sys.exit(0) + else: + try: + itself.kill() - # Sometime 
processes cannot be killed with SIGKILL (e.g, PID=1 launched by kubernetes), - # so we send an additional signal to kill them. - itself.send_signal(signal.SIGQUIT) - except psutil.NoSuchProcess: - pass + # Sometime processes cannot be killed with SIGKILL (e.g, PID=1 launched by kubernetes), + # so we send an additional signal to kill them. + itself.send_signal(signal.SIGQUIT) + except psutil.NoSuchProcess: + pass def monkey_patch_p2p_access_check(): diff --git a/python/sglang/test/test_block_fp8_ep.py b/python/sglang/test/test_block_fp8_ep.py new file mode 100644 index 00000000000..c077d0c458b --- /dev/null +++ b/python/sglang/test/test_block_fp8_ep.py @@ -0,0 +1,361 @@ +import itertools +import random +import unittest +from typing import Any, Callable, Dict, List, Optional, Tuple + +import torch + +from sglang.srt.layers.moe.ep_moe.kernels import ( + grouped_gemm_triton, + post_reorder_triton_kernel, + pre_reorder_triton_kernel, + run_moe_ep_preproess, + silu_and_mul_triton_kernel, +) +from sglang.srt.layers.moe.topk import select_experts + + +# For test +def ep_moe( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + # ep config + num_experts: int = 256, + fp8_dtype: torch.types = torch.float8_e4m3fn, + num_experts_per_partition: int = 128, + start_expert_id: int = 0, + end_expert_id: int = 127, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + use_fp8_w8a8: bool = False, + w1_scale_inv: Optional[torch.Tensor] = None, + w2_scale_inv: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, +): + use_blockwise_fp8 = block_shape is not None + topk_weights, topk_ids = select_experts( + hidden_states=hidden_states, + router_logits=router_logits, + top_k=top_k, + use_grouped_topk=use_grouped_topk, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + # correction_bias=correction_bias, #skip this in test + custom_routing_function=custom_routing_function, + ) + + reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(topk_ids, num_experts) + + gateup_input = torch.empty( + (int(hidden_states.shape[0] * top_k), hidden_states.shape[1]), + device=hidden_states.device, + dtype=( + fp8_dtype + if (use_fp8_w8a8 and not use_blockwise_fp8) + else hidden_states.dtype + ), + ) + + if use_fp8_w8a8 and not use_blockwise_fp8: + max_value = ( + torch.max(hidden_states).repeat(num_experts_per_partition).to(torch.float32) + ) + w1_input_scale = max_value / torch.finfo(fp8_dtype).max + else: + w1_input_scale = None + + # PreReorder + pre_reorder_triton_kernel[(hidden_states.shape[0],)]( + hidden_states, + gateup_input, + src2dst, + topk_ids, + w1_input_scale, + start_expert_id, + end_expert_id, + top_k, + hidden_states.shape[1], + BLOCK_SIZE=512, + ) + + seg_indptr_cur_rank = seg_indptr[start_expert_id : end_expert_id + 2] + weight_indices_cur_rank = torch.arange( + 0, + num_experts_per_partition, + device=hidden_states.device, + dtype=torch.int64, + ) + + # GroupGemm-0 + gateup_output = torch.empty( + gateup_input.shape[0], + w1.shape[1], + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + + gateup_output = grouped_gemm_triton( + a=gateup_input, + b=w1, + c=gateup_output, + batch_size=num_experts_per_partition, + weight_column_major=True, + seg_indptr=seg_indptr_cur_rank, + weight_indices=weight_indices_cur_rank, + 
use_fp8_w8a8=use_fp8_w8a8, + scale_a=w1_input_scale, + scale_b=w1_scale_inv, + block_shape=block_shape, + ) + + # Act + down_input = torch.empty( + gateup_output.shape[0], + gateup_output.shape[1] // 2, + device=gateup_output.device, + dtype=( + fp8_dtype + if (use_fp8_w8a8 and not use_blockwise_fp8) + else hidden_states.dtype + ), + ) + if use_fp8_w8a8 and not use_blockwise_fp8: + w2_input_scale = torch.ones( + num_experts_per_partition, + dtype=torch.float32, + device=hidden_states.device, + ) + else: + w2_input_scale = None + + silu_and_mul_triton_kernel[(gateup_output.shape[0],)]( + gateup_output, + down_input, + gateup_output.shape[1], + reorder_topk_ids, + w2_input_scale, + start_expert_id, + end_expert_id, + BLOCK_SIZE=512, + ) + + # GroupGemm-1 + down_output = torch.empty( + down_input.shape[0], + w2.shape[1], + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + + down_output = grouped_gemm_triton( + a=down_input, + b=w2, + c=down_output, + batch_size=num_experts_per_partition, + weight_column_major=True, + seg_indptr=seg_indptr_cur_rank, + weight_indices=weight_indices_cur_rank, + use_fp8_w8a8=use_fp8_w8a8, + scale_a=w2_input_scale, + scale_b=w2_scale_inv, + block_shape=block_shape, + ) + + # PostReorder + output = torch.empty_like(hidden_states) + post_reorder_triton_kernel[(hidden_states.size(0),)]( + down_output, + output, + src2dst, + topk_ids, + topk_weights, + start_expert_id, + end_expert_id, + top_k, + hidden_states.size(1), + BLOCK_SIZE=512, + ) + return output + + +# test util +def block_dequant( + x_q_block: torch.Tensor, + x_s: torch.Tensor, + block_size: List[int], +) -> Tuple[torch.Tensor, torch.Tensor]: + """This function converts block-wise quantization to tensor-wise quantization. + The inputs are block-wise quantization tensor `x_q_block`, block-wise quantization scale + and the block size. + The outputs are tensor-wise quantization tensor and tensor-wise quantization scale. + Note only float8 is supported for now. 
+ """ + + # process 3D tensor + if x_q_block.dim() == 3: + batch_size = x_q_block.size(0) + return torch.stack( + [block_dequant(x_q_block[b], x_s[b], block_size) for b in range(batch_size)] + ) + + block_n, block_k = block_size[0], block_size[1] + n, k = x_q_block.shape + n_tiles = (n + block_n - 1) // block_n + k_tiles = (k + block_k - 1) // block_k + assert n_tiles == x_s.shape[0] + assert k_tiles == x_s.shape[1] + + x_dq_block = x_q_block.to(torch.float32) + + x_dq_block_tiles = [ + [ + x_dq_block[ + j * block_n : min((j + 1) * block_n, n), + i * block_k : min((i + 1) * block_k, k), + ] + for i in range(k_tiles) + ] + for j in range(n_tiles) + ] + + for i in range(k_tiles): + for j in range(n_tiles): + x_dq_block_tiles[j][i][:, :] = x_dq_block_tiles[j][i] * x_s[j][i] + + return x_dq_block + + +class TestW8A8BlockFP8EPMoE(unittest.TestCase): + DTYPES = [torch.half, torch.bfloat16] + M = [1, 222, 1024, 2048] + N = [128, 1024, 2048] + K = [256, 4096, 5120] + E = [8, 16] + ep_size = [2, 4] + TOP_KS = [2, 4] + BLOCK_SIZE = [[128, 128]] + SEEDS = [0] + + @classmethod + def setUpClass(cls): + if not torch.cuda.is_available(): + raise unittest.SkipTest("CUDA is not available") + torch.set_default_device("cuda") + + def _w8a8_block_fp8_ep_moe( + self, M, N, K, E, ep_size, topk, block_size, dtype, seed + ): + torch.manual_seed(seed) + random.seed(seed) + # NOTE(HandH1998): to avoid overflow when out_dtype = torch.half + factor_for_scale = 1e-2 + fp8_info = torch.finfo(torch.float8_e4m3fn) + fp8_max, fp8_min = fp8_info.max, fp8_info.min + + a = torch.randn((M, K), dtype=dtype) / 10 + + w1_fp32 = (torch.rand((E, 2 * N, K), dtype=dtype) - 0.5) * 2 * fp8_max + w1 = w1_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + w2_fp32 = (torch.rand((E, K, N), dtype=dtype) - 0.5) * 2 * fp8_max + w2 = w2_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + block_n, block_k = block_size[0], block_size[1] + n_tiles_w1 = (2 * N + block_n - 1) // block_n + n_tiles_w2 = (K + block_n - 1) // block_n + k_tiles_w1 = (K + block_k - 1) // block_k + k_tiles_w2 = (N + block_k - 1) // block_k + + w1_s = ( + torch.rand((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) + * factor_for_scale + ) + w2_s = ( + torch.rand((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) + * factor_for_scale + ) + + w1_ref = block_dequant(w1, w1_s, block_size).to(dtype) + w2_ref = block_dequant(w2, w2_s, block_size).to(dtype) + + score = torch.randn((M, E), dtype=dtype) + num_experts_per_partition = E // ep_size + cur_rank = random.randint(0, ep_size - 1) + start_id = cur_rank * num_experts_per_partition + end_id = start_id + num_experts_per_partition - 1 + + with torch.inference_mode(): + out = ep_moe( + hidden_states=a, + w1=w1, + w2=w2, + router_logits=score, + top_k=topk, + renormalize=False, + use_fp8_w8a8=True, + w1_scale_inv=w1_s, + w2_scale_inv=w2_s, + block_shape=block_size, + num_experts=E, + num_experts_per_partition=num_experts_per_partition, + start_expert_id=start_id, + end_expert_id=end_id, + ) + ref_out = ep_moe( + hidden_states=a, + w1=w1_ref, + w2=w2_ref, + router_logits=score, + top_k=topk, + renormalize=False, + use_fp8_w8a8=False, + w1_scale_inv=None, + w2_scale_inv=None, + block_shape=None, + num_experts=E, + num_experts_per_partition=num_experts_per_partition, + start_expert_id=start_id, + end_expert_id=end_id, + ) + self.assertTrue( + torch.mean(torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) + / (torch.mean(torch.abs(ref_out.to(torch.float32))) + 1e-6) + < 0.06 + ) + + def 
test_w8a8_block_fp8_ep_moe(self): + for params in itertools.product( + self.M, + self.N, + self.K, + self.E, + self.ep_size, + self.TOP_KS, + self.BLOCK_SIZE, + self.DTYPES, + self.SEEDS, + ): + with self.subTest( + M=params[0], + N=params[1], + K=params[2], + E=params[3], + ep_size=params[4], + topk=params[5], + block_size=params[6], + dtype=params[7], + seed=params[8], + ): + self._w8a8_block_fp8_ep_moe(*params) + torch.cuda.empty_cache() + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/scripts/ci_install_dependency.sh b/scripts/ci_install_dependency.sh index d5c09751b26..b496ac78741 100755 --- a/scripts/ci_install_dependency.sh +++ b/scripts/ci_install_dependency.sh @@ -16,7 +16,7 @@ pip install -e "python[all]" --find-links https://flashinfer.ai/whl/cu124/torch2 rm -rf /root/.cache/flashinfer # Force reinstall flashinfer and torch_memory_saver -pip install flashinfer_python==0.2.1.post2 --find-links ${FLASHINFER_REPO} --force-reinstall --no-deps +pip install flashinfer_python==0.2.2.post1 --find-links ${FLASHINFER_REPO} --force-reinstall --no-deps pip install torch_memory_saver --force-reinstall diff --git a/test/srt/test_ebnf_constrained.py b/test/srt/test_ebnf_constrained.py index 5e852bec6e4..863dc633f36 100644 --- a/test/srt/test_ebnf_constrained.py +++ b/test/srt/test_ebnf_constrained.py @@ -1,6 +1,8 @@ """ python3 -m unittest test_ebnf_constrained.TestEBNFConstrained.test_ebnf_generate_email python3 -m unittest test_ebnf_constrained.TestEBNFConstrained.test_ebnf_generate_greeting +python3 -m unittest test_ebnf_constrained.TestEBNFConstrainedLLGuidance.test_ebnf_generate_email +python3 -m unittest test_ebnf_constrained.TestEBNFConstrainedLLGuidance.test_ebnf_generate_greeting """ import json @@ -17,7 +19,7 @@ ) -def setup_class(cls, disable_overlap: bool): +def setup_class(cls, backend: str, disable_overlap: bool): cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST cls.base_url = DEFAULT_URL_FOR_TEST cls.ebnf_grammar = 'root ::= "test"' # Default grammar @@ -26,7 +28,7 @@ def setup_class(cls, disable_overlap: bool): "--max-running-requests", "10", "--grammar-backend", - "xgrammar", + backend, ] if disable_overlap: @@ -43,7 +45,7 @@ def setup_class(cls, disable_overlap: bool): class TestEBNFConstrained(unittest.TestCase): @classmethod def setUpClass(cls): - setup_class(cls, disable_overlap=False) + setup_class(cls, "xgrammar", disable_overlap=False) cls.check_jump_forward = False @classmethod @@ -236,5 +238,12 @@ def test_ebnf_generate_custom_log_format(self): ) +class TestEBNFConstrainedLLGuidance(TestEBNFConstrained): + @classmethod + def setUpClass(cls): + setup_class(cls, "llguidance", disable_overlap=False) + cls.check_jump_forward = False + + if __name__ == "__main__": unittest.main() diff --git a/test/srt/test_hidden_states.py b/test/srt/test_hidden_states.py index 219c04693ad..5e39ea60773 100644 --- a/test/srt/test_hidden_states.py +++ b/test/srt/test_hidden_states.py @@ -14,12 +14,15 @@ def test_return_hidden_states(self): tokenizer = AutoTokenizer.from_pretrained(model_path) input_ids = tokenizer(prompts).input_ids - sampling_params = {"temperature": 0, "max_new_tokens": 8} + sampling_params = { + "temperature": 0, + "max_new_tokens": 8, + "return_hidden_states": True, + } engine = sgl.Engine( model_path=model_path, random_seed=42, - return_hidden_states=True, skip_tokenizer_init=True, ) outputs = engine.generate(input_ids=input_ids, sampling_params=sampling_params) @@ -72,6 +75,58 @@ def test_return_hidden_states(self): ) ) + def 
test_repeatedly_changes_hidden_states(self): + prompts = ["Today is", "Today is a sunny day and I like"] + model_path = "meta-llama/Meta-Llama-3.1-8B-Instruct" + tokenizer = AutoTokenizer.from_pretrained(model_path) + input_ids = tokenizer(prompts).input_ids + + sample_completion = { + "temperature": 0, + "max_new_tokens": 8, + "return_hidden_states": True, + } + + sample_hidden_state = { + "temperature": 0, + "max_new_tokens": 8, + "return_hidden_states": False, + } + + engine = sgl.Engine( + model_path=model_path, + random_seed=42, + skip_tokenizer_init=True, + ) + outputs_completion_first_round = engine.generate( + input_ids=input_ids, sampling_params=sample_completion + ) + outputs_hidden_state = engine.generate( + input_ids=input_ids, sampling_params=sample_hidden_state + ) + + outputs_completion_last_round = engine.generate( + input_ids=input_ids, sampling_params=sample_completion + ) + engine.shutdown() + + for ( + output_completion_first_round, + output_hidden_state, + output_completion_last_round, + ) in zip( + outputs_completion_first_round, + outputs_hidden_state, + outputs_completion_last_round, + ): + self.assertEqual( + len(output_completion_first_round["meta_info"]["hidden_states"]), 8 + ) + self.assertNotIn("hidden_states", output_hidden_state["meta_info"]) + self.assertEqual( + len(output_completion_last_round["meta_info"]["hidden_states"]), 8 + ) + if __name__ == "__main__": unittest.main() diff --git a/test/srt/test_json_constrained.py b/test/srt/test_json_constrained.py index adb5c18fbe2..464604bbab6 100644 --- a/test/srt/test_json_constrained.py +++ b/test/srt/test_json_constrained.py @@ -1,6 +1,7 @@ """ python3 -m unittest test_json_constrained.TestJSONConstrainedOutlinesBackend.test_json_generate python3 -m unittest test_json_constrained.TestJSONConstrainedXGrammarBackend.test_json_generate +python3 -m unittest test_json_constrained.TestJSONConstrainedLLGuidanceBackend.test_json_generate """ import json @@ -30,6 +31,7 @@ def setup_class(cls, backend: str, disable_overlap: bool): "population": {"type": "integer"}, }, "required": ["name", "population"], + "additionalProperties": False, } ) @@ -146,5 +148,12 @@ def setUpClass(cls): cls.check_jump_forward = False +class TestJSONConstrainedLLGuidanceBackend(TestJSONConstrainedOutlinesBackend): + @classmethod + def setUpClass(cls): + setup_class(cls, backend="llguidance", disable_overlap=False) + cls.check_jump_forward = False + + if __name__ == "__main__": unittest.main() diff --git a/test/srt/test_regex_constrained.py b/test/srt/test_regex_constrained.py index 6d5acec15e2..303f5f118b8 100644 --- a/test/srt/test_regex_constrained.py +++ b/test/srt/test_regex_constrained.py @@ -1,6 +1,10 @@ """ python3 -m unittest test_regex_constrained.TestRegexConstrained.test_regex_generate_email python3 -m unittest test_regex_constrained.TestRegexConstrained.test_regex_generate_greeting +python3 -m unittest test_regex_constrained.TestRegexConstrainedLLGuidance.test_regex_generate_email +python3 -m unittest test_regex_constrained.TestRegexConstrainedLLGuidance.test_regex_generate_greeting +python3 -m unittest test_regex_constrained.TestJumpForwardLLGuidance.test_regex_generate_email +python3 -m unittest test_regex_constrained.TestJumpForwardLLGuidance.test_regex_generate_greeting """ import json @@ -17,7 +21,7 @@ ) -def setup_class(cls, disable_overlap: bool): +def setup_class(cls, backend: str, disable_overlap: bool): cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST cls.base_url = DEFAULT_URL_FOR_TEST @@ -25,7 +29,7 @@ def 
setup_class(cls, disable_overlap: bool): "--max-running-requests", "10", "--grammar-backend", - "xgrammar", + backend, ] if disable_overlap: @@ -42,7 +46,7 @@ def setup_class(cls, disable_overlap: bool): class TestRegexConstrained(unittest.TestCase): @classmethod def setUpClass(cls): - setup_class(cls, disable_overlap=False) + setup_class(cls, "xgrammar", disable_overlap=False) cls.check_jump_forward = False @classmethod @@ -178,9 +182,22 @@ def test_regex_generate_custom_log_format(self): class TestJumpForward(TestRegexConstrained): @classmethod def setUpClass(cls): - setup_class(cls, disable_overlap=True) + setup_class(cls, "xgrammar", disable_overlap=True) cls.check_jump_forward = True +class TestJumpForwardLLGuidance(TestRegexConstrained): + @classmethod + def setUpClass(cls): + setup_class(cls, "llguidance", disable_overlap=True) + cls.check_jump_forward = True + + +class TestRegexConstrainedLLGuidance(TestRegexConstrained): + @classmethod + def setUpClass(cls): + setup_class(cls, "llguidance", disable_overlap=True) + + if __name__ == "__main__": unittest.main()
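
Usage sketch (not part of the patch): the hunks above move hidden-state capture from the removed `--return-hidden-states` server flag to a per-request `return_hidden_states` field in `sampling_params`, and the updated `test/srt/test_hidden_states.py` exercises exactly this path. The snippet below mirrors that test as a minimal illustration; the model path, prompt, and the expectation that `hidden_states` has one entry per `max_new_tokens` step are taken from the test and should be treated as assumptions of this sketch, not a separate API guarantee.

```python
# Minimal sketch of per-request hidden-state capture, mirroring the updated
# test_hidden_states.py. Model path and prompt are illustrative placeholders.
import sglang as sgl
from transformers import AutoTokenizer

model_path = "meta-llama/Meta-Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_path)
input_ids = tokenizer(["Today is"]).input_ids

# No server-level flag is needed anymore; the engine is started normally.
engine = sgl.Engine(
    model_path=model_path,
    random_seed=42,
    skip_tokenizer_init=True,
)

# Hidden states are requested per generation call via sampling_params.
outputs = engine.generate(
    input_ids=input_ids,
    sampling_params={
        "temperature": 0,
        "max_new_tokens": 8,
        "return_hidden_states": True,
    },
)

# Per the updated test, this matches max_new_tokens when capture is enabled;
# a follow-up call with "return_hidden_states": False omits the key entirely.
print(len(outputs[0]["meta_info"]["hidden_states"]))

engine.shutdown()
```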