Add npu support for LLM.int8 forward #1534

Open · wants to merge 1 commit into base: multi-backend-refactor
159 changes: 149 additions & 10 deletions bitsandbytes/backends/npu.py
@@ -11,6 +11,7 @@

from bitsandbytes.cextension import lib
from bitsandbytes.functional import (
    COOSparseTensor,
    get_4bit_type,
    get_ptr,
)
@@ -28,6 +29,43 @@ def assert_on_npu(tensors):
    return True


def coo_zeros(rows, cols, rowidx, colidx, values, nnz, device, dtype=torch.half):
    rowidx = rowidx.to(torch.int32)
    colidx = colidx.to(torch.int32)
    values = values.to(device).to(dtype)
    return COOSparseTensor(rows, cols, nnz, rowidx, colidx, values)


def row_col_stats(A, threshold):
    cols = A.shape[-1]
    if len(A.shape) == 3:
        rows = A.shape[0] * A.shape[1]
    else:
        rows = A.shape[0]

    row_max = torch.zeros(rows, dtype=torch.float32, device="npu")
    col_max = torch.zeros(cols, dtype=torch.float32, device="npu")
    outlier_num = torch.zeros(1, dtype=torch.int32, device="npu")
    lib.cget_col_row_stats(
        get_ptr(A),
        get_ptr(row_max),
        get_ptr(col_max),
        get_ptr(outlier_num),
        ct.c_float(threshold),
        ct.c_int32(rows),
        ct.c_int32(cols),
        torch.npu.current_stream()
    )
    return row_max, col_max, outlier_num


class Int8AB:
    def __init__(self, A: torch.Tensor, B: torch.Tensor):
        self.A = A
        self.B = B
        self.device = A.device


class NPUBackend(Backend):
    def int8_double_quant(
        self,
@@ -38,7 +76,53 @@ def int8_double_quant(
        out_row: Optional[torch.Tensor] = None,
        threshold=0.0,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        raise NotImplementedError
        past_device = None
        device = A.device
        assert A.dtype == torch.half
        assert device.type == "npu"

        cols = A.shape[-1]
        if len(A.shape) == 3:
            rows = A.shape[0] * A.shape[1]
        else:
            rows = A.shape[0]

        if past_device != str(A.device):
            torch.npu.set_device(A.device)  # reset context
            past_device = str(A.device)

        row_stats, col_stats, cnt_npu = row_col_stats(A, threshold)

        quant_row = torch.empty((rows, cols), dtype=torch.int8, device=device)
        quant_col = torch.empty((rows, cols), dtype=torch.int8, device=device)
        outliers_row_idx = torch.zeros(rows, dtype=torch.int32, device=device)
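        # Fixed-size outlier-column index buffer, initialized to -1 so that
        # slots the kernel leaves untouched can be filtered out afterwards.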
        outliers_col_idx = torch.zeros(40 * cols, dtype=torch.int32, device=device) - 1
        outliers_value = torch.empty(0, dtype=torch.float16, device=device)

        lib.cdouble_rowcol_quant(
            get_ptr(A),
            get_ptr(row_stats),
            get_ptr(col_stats),
            get_ptr(quant_row),
            get_ptr(quant_col),
            get_ptr(outliers_row_idx),
            get_ptr(outliers_col_idx),
            get_ptr(outliers_value),
            ct.c_int(cols),
            ct.c_float(threshold),
            ct.c_int32(rows),
            ct.c_int32(cols),
            torch.npu.current_stream()
        )

        colidx_tmp = torch.unique(outliers_col_idx)
        colidx = colidx_tmp[colidx_tmp != -1]
Comment on lines +118 to +119 (Member):

As an optimization this can probably be avoided when threshold==0.0 and moved into the condition below.
        coo_tensor = None
        if threshold != 0.0:
            coo_tensor = coo_zeros(rows, cols, outliers_row_idx, colidx, outliers_value, cnt_npu, device, dtype=torch.half)

        return quant_row, quant_col, row_stats, col_stats, coo_tensor
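A minimal sketch of the rearrangement suggested in the comment above (illustrative, not part of this diff): the unique/filter step only matters when a COO tensor is actually built, so it can move inside the threshold check.

    coo_tensor = None
    if threshold != 0.0:
        # Only pay for the unique/filter work when outliers are tracked.
        colidx_tmp = torch.unique(outliers_col_idx)
        colidx = colidx_tmp[colidx_tmp != -1]
        coo_tensor = coo_zeros(
            rows, cols, outliers_row_idx, colidx, outliers_value, cnt_npu, device, dtype=torch.half
        )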

    def int8_vectorwise_dequant(self, A, stats):
        return super().int8_vectorwise_dequant(A, stats)
@@ -48,7 +132,35 @@ def int8_vectorwise_quant(
        A: torch.Tensor,
        threshold=0.0,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        raise NotImplementedError
        device = A.device
        assert A.dtype == torch.half
        assert device.type == "npu"

        cols = A.shape[-1]
        if len(A.shape) == 3:
            rows = A.shape[0] * A.shape[1]
        else:
            rows = A.shape[0]

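        # When a threshold is set, outlier values are zeroed out before the
        # per-row absmax is taken, so they do not inflate the quantization scale.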
        A_no_threshold = None
        if threshold > 0.0:
            zero = torch.tensor(0.0, dtype=torch.half, device=device)
            A_no_threshold = torch.where(A.view(rows, cols).abs() < threshold, A.view(rows, cols), zero)
            row_stats = torch.amax(A_no_threshold.abs(), dim=1, keepdim=True).to(device)
            out_row = torch.round(A_no_threshold * 127.0 / row_stats).to(torch.int8)
        else:
            row_stats = torch.amax(A.view(rows, cols).abs(), dim=1, keepdim=True).to(device)
            out_row = torch.round(A * 127.0 / row_stats).to(torch.int8)

        outlier_cols = None
        if threshold > 0.0:
            # TODO we could improve perf of this
            outliers = A.abs() >= threshold

            if outliers.any():
                outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1)

        return out_row, row_stats, outlier_cols

    def transform(
        self,
@@ -69,7 +181,7 @@ def int8_linear_matmul(
        out: Optional[torch.Tensor] = None,
        dtype=torch.int32,
    ) -> torch.Tensor:
        raise NotImplementedError
        return Int8AB(A, B)
Comment (Member):

Interesting and clever! While this does break the expected API as this isn't returning a Tensor (or performing any operations really), I can completely understand why it is done this way. I think this will be OK for right now and we'll make the interface better in this regard later on.
    def int8_mm_dequant(
        self,
@@ -79,7 +191,15 @@ def int8_mm_dequant(
        out: Optional[torch.Tensor] = None,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        raise NotImplementedError
        A, B = A.A, A.B
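        # npu_quant_matmul fuses the int8 matmul with dequantization: `scale`
        # carries B's per-channel factor and `pertoken_scale` A's per-row
        # factor (each stat divided by 127).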
        out = torch_npu.npu_quant_matmul(
            A,
            B.t(),
            scale=col_stats.float() / 127.0,
            pertoken_scale=row_stats.float().view(-1) / 127.0,
            output_dtype=torch.float16
        )
        return out
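Putting int8_linear_matmul and int8_mm_dequant together, the deferred flow looks roughly like this (sketch with hypothetical variable names):

    ab = backend.int8_linear_matmul(A_int8, B_int8)               # no kernel launch; returns Int8AB(A, B)
    out_fp16 = backend.int8_mm_dequant(ab, row_stats, col_stats)  # single fused npu_quant_matmul call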

    def extract_outliers(
        self,
@@ -106,6 +226,10 @@ def quantize_4bit(
        if blocksize is None:
            blocksize = 128

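        # The nearest-codebook search below materializes an (elements, 16)
        # distance tensor, so large inputs are split into up to 8 chunks to
        # bound peak memory.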
        total_blocks = A.numel() // blocksize
        chunks = 8 if A.numel() > 2048 * 2048 else 1
        chunksize = (total_blocks + chunks - 1) // chunks

        prev_device = torch.npu.current_device()
        torch.npu.set_device(A.device)
        if A.dtype in [torch.float32, torch.float16, torch.bfloat16]:
@@ -128,12 +252,27 @@
                1.0,
            ]
            data = torch.tensor(data, device="npu", dtype=torch.float32).view(1, -1)
            absmax = A.view(-1, blocksize).abs().max(dim=1, keepdim=True).values
            a = A.view(-1, blocksize) / absmax.float()
            diff = torch.abs(a.unsqueeze(-1) - data)
            out = (torch.argmin(diff, dim=-1) + 8) % 16
            out = out.reshape(-1, 2)
            out = (out[:, 0] + out[:, 1] * 16).to(torch.uint8)
            chunks_absmax = []
            chunks_out = []

            for i in range(chunks):
                start = i * chunksize * blocksize
                end = min((i + 1) * chunksize * blocksize, A.numel())
                chunk_data = A.view(-1)[start:end].view(-1, blocksize)

                absmax = chunk_data.abs().max(dim=1, keepdim=True).values
                chunks_absmax.append(absmax)

                a = chunk_data / absmax.float()
                diff = torch.abs(a.unsqueeze(-1) - data)
                out = (torch.argmin(diff, dim=-1) + 8) % 16

                out = out.reshape(-1, 2)
                out = (out[:, 0] + out[:, 1] * 16).to(torch.uint8)
                chunks_out.append(out)

            absmax = torch.cat(chunks_absmax, dim=0)
            out = torch.cat(chunks_out, dim=0)
        else:
            raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
        assert_on_npu([A, absmax, out])
27 changes: 22 additions & 5 deletions bitsandbytes/nn/modules.py
@@ -3,6 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import copy
import importlib
from typing import Any, Dict, Optional, TypeVar, Union, overload
import warnings

@@ -320,9 +321,6 @@ def cpu(self, non_blocking: bool = False):
        return self.to(device="cpu", non_blocking=non_blocking)

    def npu(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False):
        # `torch.Tensor.to(<int num>)` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)).
        if isinstance(device, int):
            device = f"npu:{device}"
Comment on lines -323 to -325 (Member):

Does this require a bump in the minimum torch_npu to support now?
        return self.to(device="npu" if device is None else device, non_blocking=non_blocking)

    def xpu(self, non_blocking: bool = False):
@@ -345,7 +343,10 @@ def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: ...
    def to(self, *args, **kwargs):
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)

        if device is not None and device.type in ["cuda", "cpu", "npu", "xpu"] and not self.bnb_quantized:
        # `torch.Tensor.to(<int num>)` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)).
        if importlib.util.find_spec("torch_npu") and device.type == "cuda" and not self.bnb_quantized:
            return self._quantize(f"npu:{device}" if isinstance(device, int) else str(device).replace("cuda", "npu"))
        elif device is not None and device.type in ["cuda", "cpu", "npu", "xpu"] and not self.bnb_quantized:
            return self._quantize(device)
        else:
            if self.quant_state is not None:
@@ -677,6 +678,19 @@ def xpu(self, device):
        self.SCB = SCB
        return self

    def npu(self, device):
        # we store the 8-bit row-major weight
        B = self.data.contiguous().to(torch.float16).npu(device)
        CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
        if CBt is not None:
            del CBt
        if SCBt is not None:
            del SCBt
        self.data = CB
        self.CB = CB
        self.SCB = SCB
        return self

    @overload
    def to(
        self: T,
@@ -695,7 +709,10 @@ def to(self, *args, **kwargs):
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)

        if device is not None:
            if device.type == "cuda" and self.data.device.type == "cpu":
            # `torch.Tensor.to(<int num>)` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)).
            if importlib.util.find_spec("torch_npu") and device.type == "cuda":
                return self.npu(f"npu:{device}" if isinstance(device, int) else str(device).replace("cuda", "npu"))
            elif device.type == "cuda" and self.data.device.type == "cpu":
                return self.cuda(device)
            elif device.type == "cpu":
                if self.data.dtype == torch.int8:
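Taken together, the modules.py changes reroute CUDA device requests to the NPU whenever torch_npu is importable. A usage sketch under that assumption (illustrative, not part of the diff):

    import bitsandbytes as bnb

    linear = bnb.nn.Linear8bitLt(1024, 1024, has_fp16_weights=False)
    # With torch_npu installed, "cuda:0" is rewritten internally to "npu:0";
    # Int8Params.to() then routes through .npu(), which double-quantizes the
    # fp16 weight to int8 and stores CB/SCB on the module's weight.
    linear = linear.to("cuda:0")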