Commit

Merge branch 'run_compressed-tests' of github.com:vllm-project/llm-compressor into run_compressed-tests

horheynm committed Jan 14, 2025
2 parents 21e6b73 + de3da3a commit 9a3d14d
Showing 13 changed files with 374 additions and 40 deletions.
13 changes: 0 additions & 13 deletions examples/automodelforcausallm/README.md

This file was deleted.

11 changes: 0 additions & 11 deletions examples/automodelforcausallm/run_automodelforcausallm.py

This file was deleted.

23 changes: 19 additions & 4 deletions examples/multimodal_vision/llava_example.py
@@ -1,3 +1,5 @@
import requests
from PIL import Image
from transformers import AutoProcessor

from llmcompressor.modifiers.quantization import GPTQModifier
@@ -23,8 +25,8 @@
GPTQModifier(
targets="Linear",
scheme="W4A16",
ignore=["re:.*lm_head", "re:vision_tower.*", "re:multi_modal_projector.*"],
sequential_targets=["LlamaDecoderLayer"],
ignore=["re:.*lm_head", "re:vision_tower.*", "re:multi_modal_projector.*"],
),
]

@@ -43,9 +45,22 @@

# Confirm generations of the quantized model look sane.
print("========== SAMPLE GENERATION ==============")
input_ids = processor(text="Hello my name is", return_tensors="pt").input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=20)
print(processor.decode(output[0]))
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": "Please describe the animal in this image\n"},
{"type": "image"},
],
},
]
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
image_url = "http://images.cocodataset.org/train2017/000000231895.jpg"
raw_image = Image.open(requests.get(image_url, stream=True).raw)

inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to("cuda")
output = model.generate(**inputs, max_new_tokens=100)
print(processor.decode(output[0], skip_special_tokens=True))
print("==========================================")

# Save to disk compressed.
21 changes: 18 additions & 3 deletions examples/multimodal_vision/mllama_example.py
@@ -1,3 +1,5 @@
import requests
from PIL import Image
from transformers import AutoProcessor

from llmcompressor.modifiers.quantization import GPTQModifier
@@ -42,9 +44,22 @@

# Confirm generations of the quantized model look sane.
print("========== SAMPLE GENERATION ==============")
input_ids = processor(text="Hello my name is", return_tensors="pt").input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=20)
print(processor.decode(output[0]))
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": "Please describe the animal in this image\n"},
{"type": "image"},
],
},
]
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
image_url = "http://images.cocodataset.org/train2017/000000231895.jpg"
raw_image = Image.open(requests.get(image_url, stream=True).raw)

inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to("cuda")
output = model.generate(**inputs, max_new_tokens=100)
print(processor.decode(output[0], skip_special_tokens=True))
print("==========================================")

# Save to disk compressed.
23 changes: 19 additions & 4 deletions examples/multimodal_vision/pixtral_example.py
@@ -1,3 +1,5 @@
import requests
from PIL import Image
from transformers import AutoProcessor

from llmcompressor.modifiers.quantization import GPTQModifier
@@ -23,8 +25,8 @@
GPTQModifier(
targets="Linear",
scheme="W4A16",
ignore=["re:.*lm_head", "re:vision_tower.*", "re:multi_modal_projector.*"],
sequential_targets=["MistralDecoderLayer"],
ignore=["re:.*lm_head", "re:vision_tower.*", "re:multi_modal_projector.*"],
),
]

@@ -43,9 +45,22 @@

# Confirm generations of the quantized model look sane.
print("========== SAMPLE GENERATION ==============")
input_ids = processor(text="Hello my name is", return_tensors="pt").input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=20)
print(processor.decode(output[0]))
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": "Please describe the animal in this image\n"},
{"type": "image"},
],
},
]
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
image_url = "http://images.cocodataset.org/train2017/000000231895.jpg"
raw_image = Image.open(requests.get(image_url, stream=True).raw)

inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to("cuda")
output = model.generate(**inputs, max_new_tokens=100)
print(processor.decode(output[0], skip_special_tokens=True))
print("==========================================")

# Save to disk compressed.
4 changes: 4 additions & 0 deletions examples/quantizing_moe/deepseek_moe_w4a16.py
@@ -5,6 +5,10 @@
from llmcompressor.transformers import oneshot
from llmcompressor.transformers.compression.helpers import calculate_offload_device_map

# NOTE: transformers 4.48.0 has an import error with DeepSeek.
# Please consider either downgrading to an earlier transformers
# version or upgrading to a version in which this bug is fixed.

# select a Mixture of Experts model for quantization
MODEL_ID = "deepseek-ai/DeepSeek-V2.5"

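The note above recommends avoiding transformers 4.48.0 for the DeepSeek examples. A minimal guard that fails fast on that release, assuming the import error is limited to 4.48.0 as the note states (this sketch is not part of the commit), might look like:

import transformers

# Hypothetical guard, not in the committed example: abort early on the
# transformers release that cannot import DeepSeek models.
if transformers.__version__.startswith("4.48.0"):
    raise RuntimeError(
        "transformers 4.48.0 has an import error with DeepSeek; "
        "install an earlier or later release, e.g. "
        "pip install 'transformers!=4.48.0'"
    )
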
4 changes: 4 additions & 0 deletions examples/quantizing_moe/deepseek_moe_w8a8_fp8.py
@@ -4,6 +4,10 @@
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.transformers import oneshot

# NOTE: transformers 4.48.0 has an import error with DeepSeek.
# Please consider either downgrading to an earlier transformers
# version or upgrading to a version in which this bug is fixed.

# select a Mixture of Experts model for quantization
MODEL_ID = "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"

4 changes: 4 additions & 0 deletions examples/quantizing_moe/deepseek_moe_w8a8_int8.py
@@ -6,6 +6,10 @@
from llmcompressor.transformers import oneshot
from llmcompressor.transformers.compression.helpers import calculate_offload_device_map

# NOTE: transformers 4.48.0 has an import error with DeepSeek.
# Please consider either downgrading to an earlier transformers
# version or upgrading to a version in which this bug is fixed.

# select a Mixture of Experts model for quantization
MODEL_ID = "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"

7 changes: 4 additions & 3 deletions src/llmcompressor/transformers/tracing/llava.py
@@ -58,7 +58,7 @@ def maybe_install_metadata_image_features(

# TRACING: The shape of inputs_embeds is known. This function compensates for
# the fact that shape inference through `masked_scatter` is not implemented yet
def maybe_install_metadata_inputs_embeds(
def maybe_install_metadata_inputs_embeds_masked(
inputs_embeds_masked: Union[torch.Tensor, HFProxy],
inputs_embeds: Union[torch.Tensor, HFProxy],
special_image_mask: Union[torch.Tensor, HFProxy],
@@ -70,7 +70,7 @@ def maybe_install_metadata_inputs_embeds(
)
inputs_embeds_masked.install_metadata(metadata)

return inputs_embeds
return inputs_embeds_masked


# TRACING: override `__init__` and `forward`
@@ -153,6 +153,7 @@ def forward(
vision_feature_select_strategy=vision_feature_select_strategy,
)

# TRACING: install metadata
image_features = maybe_install_metadata_image_features(
image_features, pixel_values, self.config
)
@@ -223,7 +224,7 @@ def forward(
inputs_embeds_masked = inputs_embeds.masked_scatter(special_image_mask, image_features)

# TRACING: install metadata
inputs_embeds_masked = maybe_install_metadata_inputs_embeds(inputs_embeds_masked, inputs_embeds, special_image_mask, image_features)
inputs_embeds_masked = maybe_install_metadata_inputs_embeds_masked(inputs_embeds_masked, inputs_embeds, special_image_mask, image_features)
inputs_embeds = inputs_embeds_masked

outputs = self.language_model(
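The "TRACING" comments above describe the workaround this file relies on: because shape inference through `masked_scatter` is not yet implemented, the known output shape is attached to the traced proxy so downstream shape inference can continue. A generic sketch of that pattern follows; the metadata construction is an assumption (only the `install_metadata` call appears in the diff), not code from this file:

import torch
from transformers.utils.fx import HFProxy

# Hypothetical helper illustrating the metadata-installation pattern.
def maybe_install_metadata(value, known_shape, dtype):
    if isinstance(value, HFProxy):  # only relevant while tracing
        # masked_scatter returns a tensor shaped like its first argument,
        # so a meta tensor with that shape stands in for the real result
        value.install_metadata(torch.empty(known_shape, dtype=dtype, device="meta"))
    return value
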
12 changes: 11 additions & 1 deletion tests/e2e/vLLM/test_vllm.py
@@ -22,6 +22,7 @@
logger.warning("vllm is not installed. This test will be skipped")

HF_MODEL_HUB_NAME = "nm-testing"

TEST_DATA_FILE = os.environ.get("TEST_DATA_FILE", "")

EXPECTED_SAVED_FILES = [
@@ -129,8 +130,17 @@ def test_vllm(self):

logger.info("================= UPLOADING TO HUB ======================")

stub = f"{HF_MODEL_HUB_NAME}/{self.save_dir}-e2e"

self.api.create_repo(
repo_id=stub,
exist_ok=True,
repo_type="model",
private=False,
)

self.api.upload_folder(
repo_id=f"{HF_MODEL_HUB_NAME}/{self.save_dir}-e2e",
repo_id=stub,
folder_path=self.save_dir,
)

33 changes: 33 additions & 0 deletions tests/examples/test_sparse_2of4_quantization_fp8.py
@@ -0,0 +1,33 @@
from pathlib import Path

import pytest

from tests.examples.utils import (
copy_and_run_script,
gen_cmd_fail_message,
requires_gpu_count,
)


@pytest.fixture
def example_dir() -> str:
return "examples/sparse_2of4_quantization_fp8"


@requires_gpu_count(1)
class TestSparse2of4QuantizationFP8:
"""
Tests for examples in the "sparse_2of4_quantization_fp8" example folder.
"""

@pytest.mark.parametrize(("flags"), [[], ["--fp8"]])
def test_blah(self, example_dir: str, tmp_path: Path, flags: list[str]):
"""
Tests for the "llama3_8b_2of4.py" example script.
"""
script_filename = "llama3_8b_2of4.py"
command, result = copy_and_run_script(
tmp_path, example_dir, script_filename, flags=flags
)

assert result.returncode == 0, gen_cmd_fail_message(command, result)
7 changes: 6 additions & 1 deletion tests/examples/utils.py
@@ -68,7 +68,10 @@ def copy_and_run_command(


def copy_and_run_script(
tmp_path: Path, example_dir: str, script_filename: str
tmp_path: Path,
example_dir: str,
script_filename: str,
flags: Optional[list[str]] = None,
) -> Tuple[List[str], CompletedProcess[str]]:
"""
Copies the contents of example_dir (relative to the current working directory) to
@@ -81,6 +84,8 @@ def copy_and_run_script(
:return: subprocess.CompletedProcess object
"""
command = [sys.executable, script_filename]
if flags:
command.extend(flags)
return command, copy_and_run_command(tmp_path, example_dir, command)

