Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hqq support #21

Draft
wants to merge 17 commits into
base: main
Choose a base branch
from
Draft

Hqq support #21

wants to merge 17 commits into from

Conversation

ElizaWszola
Copy link

@ElizaWszola ElizaWszola commented Oct 14, 2024

unit tests:

pytest tests/kernels/test_marlin_gemm.py -k test_hqq_marlin_gemm

offline inference:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, HqqConfig
from vllm import LLM, SamplingParams

# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

model_path = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
quant_config = HqqConfig(nbits=4, group_size=64, axis=1)

model = AutoModelForCausalLM.from_pretrained(model_path,
                                             torch_dtype=torch.float16,
                                             cache_dir='.',
                                             device_map="cuda:0",
                                             quantization_config=quant_config,
                                             low_cpu_mem_usage=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

qp = "tinyllama_hqq"
model.save_pretrained(qp)
tokenizer.save_pretrained(qp)

llm = LLM(
    model=qp,
    quantization="hqq",
)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

Copy link

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

Copy link
Collaborator

@robertgshaw2-redhat robertgshaw2-redhat left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The main this that needs to be updated in this PR is that we should not make any changes to the vllm/model_executor/models directory (there should be no changes made to llama.py). This allows us to encapsulate the details of HQQ. Right now, it is coupled with llama.py so it will only work for this model

Just like the other quantization methods (e.g. GPTQMarlin), we should setup create_weights such that the state dict of the vllm model matches the state_dict of the serialized model ... (for example -> this hqq_map for example should not be needed. Instead, just name the parameter W_q rather than .qweight)

Additionally, the conversion from the serialized format to the kernel format should be handled in process_weights_after_loading. So the create_weights should make tensors with the same type / shape as the serialized state dict and then functions that convert to the kernel format (e.g. unpack_4bit_u8) can do the conversion during process_weights_after_loading

Is there something unique about HQQ that prevents us from following this pattern?

@ElizaWszola
Copy link
Author

Is there something unique about HQQ that prevents us from following this pattern?

@robertgshaw2-neuralmagic My main difficulty has been the 4-bit quantization pattern where a tensor A of size (2M, N) is quantized such that the lower 4-bits of the 8-bit result elements correspond to the first (M, N) elements of A (while high 4-bits stand for last (M, N) elements of A). This was causing some issues with sharding, so I ended up unpacking from 4-bit to 8-bit when loading data with llama.py. It gets repacked into marlin format later.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants