Add npu support for LLM.int8 forward #1534
base: multi-backend-refactor
Changes from all commits
@@ -11,6 +11,7 @@
from bitsandbytes.cextension import lib
from bitsandbytes.functional import (
    COOSparseTensor,
    get_4bit_type,
    get_ptr,
)
@@ -28,6 +29,43 @@ def assert_on_npu(tensors):
    return True

def coo_zeros(rows, cols, rowidx, colidx, values, nnz, device, dtype=torch.half):
    rowidx = rowidx.to(torch.int32)
    colidx = colidx.to(torch.int32)
    values = values.to(device).to(dtype)
    return COOSparseTensor(rows, cols, nnz, rowidx, colidx, values)

def row_col_stats(A, threshold):
    cols = A.shape[-1]
    if len(A.shape) == 3:
        rows = A.shape[0] * A.shape[1]
    else:
        rows = A.shape[0]

    row_max = torch.zeros(rows, dtype=torch.float32, device="npu")
    col_max = torch.zeros(cols, dtype=torch.float32, device="npu")
    outlier_num = torch.zeros(1, dtype=torch.int32, device="npu")
    lib.cget_col_row_stats(
        get_ptr(A),
        get_ptr(row_max),
        get_ptr(col_max),
        get_ptr(outlier_num),
        ct.c_float(threshold),
        ct.c_int32(rows),
        ct.c_int32(cols),
        torch.npu.current_stream(),
    )
    return row_max, col_max, outlier_num
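
For readers without Ascend hardware, here is a minimal pure-PyTorch sketch of the statistics this helper requests from the kernel, assuming cget_col_row_stats computes per-row and per-column absolute maxima plus a count of entries at or above the threshold (the real kernel fuses these into one pass):

import torch

def row_col_stats_reference(A, threshold):
    # Hypothetical CPU reference: per-row/per-column absmax and outlier count.
    A2d = A.reshape(-1, A.shape[-1]).float()
    row_max = A2d.abs().amax(dim=1)  # shape (rows,)
    col_max = A2d.abs().amax(dim=0)  # shape (cols,)
    outlier_num = (A2d.abs() >= threshold).sum().to(torch.int32)
    return row_max, col_max, outlier_num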

class Int8AB:
    def __init__(self, A: torch.Tensor, B: torch.Tensor):
        self.A = A
        self.B = B
        self.device = A.device

class NPUBackend(Backend):
    def int8_double_quant(
        self,

@@ -38,7 +76,53 @@ def int8_double_quant(
        out_row: Optional[torch.Tensor] = None,
        threshold=0.0,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        raise NotImplementedError
        past_device = None
        device = A.device
        assert A.dtype == torch.half
        assert device.type == "npu"

        cols = A.shape[-1]
        if len(A.shape) == 3:
            rows = A.shape[0] * A.shape[1]
        else:
            rows = A.shape[0]

        if past_device != str(A.device):
            torch.npu.set_device(A.device)  # reset context
            past_device = str(A.device)

        row_stats, col_stats, cnt_npu = row_col_stats(A, threshold)

        quant_row = torch.empty((rows, cols), dtype=torch.int8, device=device)
        quant_col = torch.empty((rows, cols), dtype=torch.int8, device=device)
        outliers_row_idx = torch.zeros(rows, dtype=torch.int32, device=device)
        outliers_col_idx = torch.zeros(40 * cols, dtype=torch.int32, device=device) - 1
        outliers_value = torch.empty(0, dtype=torch.float16, device=device)

        lib.cdouble_rowcol_quant(
            get_ptr(A),
            get_ptr(row_stats),
            get_ptr(col_stats),
            get_ptr(quant_row),
            get_ptr(quant_col),
            get_ptr(outliers_row_idx),
            get_ptr(outliers_col_idx),
            get_ptr(outliers_value),
            ct.c_int(cols),
            ct.c_float(threshold),
            ct.c_int32(rows),
            ct.c_int32(cols),
            torch.npu.current_stream(),
        )

        colidx_tmp = torch.unique(outliers_col_idx)
        colidx = colidx_tmp[colidx_tmp != -1]

        coo_tensor = None
        if threshold != 0.0:
            coo_tensor = coo_zeros(rows, cols, outliers_row_idx, colidx, outliers_value, cnt_npu, device, dtype=torch.half)

        return quant_row, quant_col, row_stats, col_stats, coo_tensor

    def int8_vectorwise_dequant(self, A, stats):
        return super().int8_vectorwise_dequant(A, stats)

@@ -48,7 +132,35 @@ def int8_vectorwise_quant(
        A: torch.Tensor,
        threshold=0.0,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        raise NotImplementedError
        device = A.device
        assert A.dtype == torch.half
        assert device.type == "npu"

        cols = A.shape[-1]
        if len(A.shape) == 3:
            rows = A.shape[0] * A.shape[1]
        else:
            rows = A.shape[0]

        A_no_threshold = None
        if threshold > 0.0:
            zero = torch.tensor(0.0, dtype=torch.half, device=device)
            A_no_threshold = torch.where(A.view(rows, cols).abs() < threshold, A.view(rows, cols), zero)
            row_stats = torch.amax(A_no_threshold.abs(), dim=1, keepdim=True).to(device)
            out_row = torch.round(A_no_threshold * 127.0 / row_stats).to(torch.int8)
        else:
            row_stats = torch.amax(A.view(rows, cols).abs(), dim=1, keepdim=True).to(device)
            out_row = torch.round(A * 127.0 / row_stats).to(torch.int8)

        outlier_cols = None
        if threshold > 0.0:
            # TODO: we could improve perf of this
            outliers = A.abs() >= threshold

            if outliers.any():
                outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1)

        return out_row, row_stats, outlier_cols

    def transform(
        self,
@@ -69,7 +181,7 @@ def int8_linear_matmul(
        out: Optional[torch.Tensor] = None,
        dtype=torch.int32,
    ) -> torch.Tensor:
        raise NotImplementedError
        return Int8AB(A, B)

Review comment: Interesting and clever! While this does break the expected API, since it doesn't return a Tensor (or perform any operations, really), I can completely understand why it is done this way. I think this will be OK for now, and we'll improve the interface in this regard later on.
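
For context, a sketch of the call sequence this design relies on (simplified; the variable names are illustrative, not the library's public API): in LLM.int8() the int8 matmul is always immediately followed by the dequantization step, so the backend can defer the multiply and fuse both into a single NPU kernel.

# Hypothetical caller-side view of the deferred pattern:
product = backend.int8_linear_matmul(CA, CB)  # returns Int8AB(CA, CB); no compute happens yet
output = backend.int8_mm_dequant(product, row_stats, col_stats)  # one fused npu_quant_matmul call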

    def int8_mm_dequant(
        self,
@@ -79,7 +191,15 @@ def int8_mm_dequant(
        out: Optional[torch.Tensor] = None,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        raise NotImplementedError
        A, B = A.A, A.B
        out = torch_npu.npu_quant_matmul(
            A,
            B.t(),
            scale=col_stats.float() / 127.0,
            pertoken_scale=row_stats.float().view(-1) / 127.0,
            output_dtype=torch.float16,
        )
        return out
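
The scale arithmetic here follows the usual LLM.int8() dequantization: the int8 product is rescaled by (row absmax / 127) per output row and (column absmax / 127) per output column. A small CPU sketch to check the math, no NPU required (toy data; the tolerance is only quantization error):

import torch

A = torch.randn(4, 8)
B = torch.randn(3, 8)
row_stats = A.abs().amax(dim=1, keepdim=True)  # per-row absmax of A
col_stats = B.abs().amax(dim=1)                # per-row absmax of B = per-column absmax of B.t()
A_q = torch.round(A * 127.0 / row_stats).to(torch.int8)
B_q = torch.round(B * 127.0 / col_stats.unsqueeze(1)).to(torch.int8)
approx = (A_q.float() @ B_q.float().t()) * (row_stats / 127.0) * (col_stats / 127.0)
print((approx - A @ B.t()).abs().max())  # small residual from rounding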

    def extract_outliers(
        self,
@@ -106,6 +226,10 @@ def quantize_4bit(
        if blocksize is None:
            blocksize = 128
        total_blocks = A.numel() // blocksize
        chunks = 8 if A.numel() > 2048 * 2048 else 1
        chunksize = (total_blocks + chunks - 1) // chunks

        prev_device = torch.npu.current_device()
        torch.npu.set_device(A.device)
        if A.dtype in [torch.float32, torch.float16, torch.bfloat16]:

@@ -128,12 +252,27 @@
                1.0,
            ]
            data = torch.tensor(data, device="npu", dtype=torch.float32).view(1, -1)
            absmax = A.view(-1, blocksize).abs().max(dim=1, keepdim=True).values
            a = A.view(-1, blocksize) / absmax.float()
            diff = torch.abs(a.unsqueeze(-1) - data)
            out = (torch.argmin(diff, dim=-1) + 8) % 16
            out = out.reshape(-1, 2)
            out = (out[:, 0] + out[:, 1] * 16).to(torch.uint8)
            chunks_absmax = []
            chunks_out = []

            for i in range(chunks):
                start = i * chunksize * blocksize
                end = min((i + 1) * chunksize * blocksize, A.numel())
                chunk_data = A.view(-1)[start:end].view(-1, blocksize)

                absmax = chunk_data.abs().max(dim=1, keepdim=True).values
                chunks_absmax.append(absmax)

                a = chunk_data / absmax.float()
                diff = torch.abs(a.unsqueeze(-1) - data)
                out = (torch.argmin(diff, dim=-1) + 8) % 16

                out = out.reshape(-1, 2)
                out = (out[:, 0] + out[:, 1] * 16).to(torch.uint8)
                chunks_out.append(out)

            absmax = torch.cat(chunks_absmax, dim=0)
            out = torch.cat(chunks_out, dim=0)
        else:
            raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
        assert_on_npu([A, absmax, out])
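
The chunking introduced above bounds peak memory: the nearest-codebook search materializes a float32 difference tensor with 16 entries per input element, so a 2048x2048 input would otherwise allocate roughly 2048 * 2048 * 16 * 4 bytes (about 256 MB) of temporaries; splitting the work into 8 chunks caps that at around 32 MB. A self-contained sketch of one chunk's codebook lookup (with a uniform stand-in codebook; the real table is the NF4 data list above):

import torch

codebook = torch.linspace(-1.0, 1.0, 16)  # stand-in for the NF4 table
blocksize = 4
chunk = torch.randn(2, blocksize)  # two blocks of one chunk
absmax = chunk.abs().max(dim=1, keepdim=True).values
normalized = chunk / absmax  # scale each block into [-1, 1]
idx = (torch.argmin((normalized.unsqueeze(-1) - codebook).abs(), dim=-1) + 8) % 16
pairs = idx.reshape(-1, 2)
packed = (pairs[:, 0] + pairs[:, 1] * 16).to(torch.uint8)  # two 4-bit codes per byte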
@@ -3,6 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import copy
import importlib
from typing import Any, Dict, Optional, TypeVar, Union, overload
import warnings

@@ -320,9 +321,6 @@ def cpu(self, non_blocking: bool = False):
        return self.to(device="cpu", non_blocking=non_blocking)

    def npu(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False):
        # `torch.Tensor.to(<int num>)` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)).
        if isinstance(device, int):
            device = f"npu:{device}"

Review comment on lines -323 to -325: Does this require a bump in the minimum supported torch_npu version now?

        return self.to(device="npu" if device is None else device, non_blocking=non_blocking)

    def xpu(self, non_blocking: bool = False):
@@ -345,7 +343,10 @@ def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: ...
    def to(self, *args, **kwargs):
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)

        if device is not None and device.type in ["cuda", "cpu", "npu", "xpu"] and not self.bnb_quantized:
        # `torch.Tensor.to(<int num>)` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)).
        if importlib.util.find_spec("torch_npu") and device.type == "cuda" and not self.bnb_quantized:
            return self._quantize(f"npu:{device}" if isinstance(device, int) else str(device).replace("cuda", "npu"))
        elif device is not None and device.type in ["cuda", "cpu", "npu", "xpu"] and not self.bnb_quantized:
            return self._quantize(device)
        else:
            if self.quant_state is not None:
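
In effect, this new branch lets model code written for CUDA run unchanged on Ascend: when torch_npu is installed, a .to("cuda:0") call on a not-yet-quantized parameter is transparently redirected to "npu:0". A hypothetical sketch of just the device-remapping rule (the helper name is illustrative, not part of the library):

import importlib.util

def remap_cuda_to_npu(device_str):
    # When torch_npu is present, treat "cuda:N" as "npu:N".
    if importlib.util.find_spec("torch_npu"):
        return device_str.replace("cuda", "npu")
    return device_str

print(remap_cuda_to_npu("cuda:0"))  # -> "npu:0" when torch_npu is installed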

@@ -677,6 +678,19 @@ def xpu(self, device):
        self.SCB = SCB
        return self

    def npu(self, device):
        # we store the 8-bit row-major weight
        B = self.data.contiguous().to(torch.float16).npu(device)
        CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
        if CBt is not None:
            del CBt
        if SCBt is not None:
            del SCBt
        self.data = CB
        self.CB = CB
        self.SCB = SCB
        return self

    @overload
    def to(
        self: T,
@@ -695,7 +709,10 @@ def to(self, *args, **kwargs):
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)

        if device is not None:
            if device.type == "cuda" and self.data.device.type == "cpu":
            # `torch.Tensor.to(<int num>)` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)).
            if importlib.util.find_spec("torch_npu") and device.type == "cuda":
                return self.npu(f"npu:{device}" if isinstance(device, int) else str(device).replace("cuda", "npu"))
            elif device.type == "cuda" and self.data.device.type == "cpu":
                return self.cuda(device)
            elif device.type == "cpu":
                if self.data.dtype == torch.int8:

Review comment: As an optimization, this can probably be avoided when threshold == 0.0 and moved into the condition below.
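
Assuming this comment refers to the unconditional torch.unique over outliers_col_idx in int8_double_quant (colidx is only consumed when threshold != 0.0), the suggested restructuring would look roughly like this sketch:

# Sketch of the suggested optimization: the unique/filter pass over the
# 40 * cols index buffer is only needed when outliers are actually extracted.
coo_tensor = None
if threshold != 0.0:
    colidx_tmp = torch.unique(outliers_col_idx)
    colidx = colidx_tmp[colidx_tmp != -1]
    coo_tensor = coo_zeros(rows, cols, outliers_row_idx, colidx, outliers_value, cnt_npu, device, dtype=torch.half)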