fix prefix noquant bug
helloyongyang committed Feb 12, 2025
1 parent 865eff9 commit 180030d
Showing 6 changed files with 41 additions and 33 deletions.
13 changes: 9 additions & 4 deletions lightllm/common/int8kv_mem_manager.py
```diff
@@ -14,11 +14,16 @@ def get_cell_size(self):
         ) + 2 * self.head_num * self.layer_num * torch._utils._element_size(self.dtype)
 
     def _init_buffers(self, size, dtype, head_num, head_dim, layer_num):
-        self.kv_buffer = [
-            torch.empty((size, 2 * head_num, head_dim), dtype=torch.int8, device="cuda") for _ in range(layer_num)
-        ]
-        self.scale_buffer = [torch.empty((size, 2 * head_num, 1), dtype=dtype, device="cuda") for _ in range(layer_num)]
+        self.kv_buffer = torch.empty((layer_num, size, 2 * head_num, head_dim), dtype=torch.int8, device="cuda")
+        self.scale_buffer = torch.empty((layer_num, size, 2 * head_num, 1), dtype=dtype, device="cuda")
 
     def _free_buffers(self):
         self.kv_buffer = None
         self.scale_buffer = None
+
+    def get_index_kv_buffer(self, index):
+        return {"kv_buffer": self.kv_buffer[:, index], "scale_buffer": self.scale_buffer[:, index]}
+
+    def load_index_kv_buffer(self, index, load_tensor_dict):
+        self.kv_buffer[:, index].copy_(load_tensor_dict["kv_buffer"])
+        self.scale_buffer[:, index].copy_(load_tensor_dict["scale_buffer"])
```
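This hunk is the heart of the fix: the quantized managers previously held `kv_buffer` and `scale_buffer` as Python lists of per-layer tensors, so callers had to loop (or branch on `isinstance`) to touch a set of token rows, while the unquantized base manager already used one tensor. Stacking each buffer into a single `(layer_num, ...)` tensor lets one `[:, index]` slice address the same rows in every layer. A minimal standalone sketch of the two layouts, with toy sizes and CPU tensors for illustration only:

```python
import torch

layer_num, size, head_num, head_dim = 4, 8, 2, 16

# Old layout: one tensor per layer -- selecting rows needs a Python loop.
kv_list = [
    torch.empty((size, 2 * head_num, head_dim), dtype=torch.int8) for _ in range(layer_num)
]
rows_old = [buf[:3] for buf in kv_list]  # list of (3, 2*head_num, head_dim) tensors

# New layout: one stacked tensor -- a single slice covers all layers at once.
kv_stacked = torch.empty((layer_num, size, 2 * head_num, head_dim), dtype=torch.int8)
rows_new = kv_stacked[:, :3]  # (layer_num, 3, 2*head_num, head_dim)
```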
6 changes: 6 additions & 0 deletions lightllm/common/mem_manager.py
```diff
@@ -295,6 +295,12 @@ def resize_mem(self, new_size):
         self._init_buffers(size, dtype, head_num, head_dim, layer_num)
         return
 
+    def get_index_kv_buffer(self, index):
+        return {"kv_buffer": self.kv_buffer[:, index]}
+
+    def load_index_kv_buffer(self, index, load_tensor_dict):
+        self.kv_buffer[:, index].copy_(load_tensor_dict["kv_buffer"])
+
 
 class ReadOnlyStaticsMemoryManager:
     """
```
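The base (no-quant) manager returns only a `"kv_buffer"` entry, while the quantized subclasses above also add `"scale_buffer"`; because the contract is a plain dict of tensors, `torch.save`/`torch.load` round-trips it without any type checks. A sketch of that round trip, under stated assumptions: free functions over a toy stacked tensor stand in for the methods, and a `slice` index is used so that `copy_` writes through a view rather than a temporary:

```python
import torch

# Toy stand-in for the stacked cache: (layer_num, size, 2 * head_num, head_dim).
kv_buffer = torch.zeros((4, 8, 4, 16), dtype=torch.int8)

def get_index_kv_buffer(index):
    return {"kv_buffer": kv_buffer[:, index]}

def load_index_kv_buffer(index, load_tensor_dict):
    kv_buffer[:, index].copy_(load_tensor_dict["kv_buffer"])

state = get_index_kv_buffer(slice(0, 3))            # rows 0..2 of every layer
torch.save(state, "/tmp/demo_prompt_cache.pt")
loaded = torch.load("/tmp/demo_prompt_cache.pt", weights_only=True)
load_index_kv_buffer(slice(0, 3), loaded)           # write the rows back in place
```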
18 changes: 11 additions & 7 deletions lightllm/common/ppl_int4kv_mem_manager.py
```diff
@@ -17,14 +17,18 @@ def get_cell_size(self):
         )
 
     def _init_buffers(self, size, dtype, head_num, head_dim, layer_num):
-        self.kv_buffer = [
-            torch.empty((size, 2 * head_num, head_dim // 2), dtype=torch.int8, device="cuda") for _ in range(layer_num)
-        ]
-        self.scale_buffer = [
-            torch.empty((size, 2 * head_num, head_dim // self.group_quant_size), dtype=dtype, device="cuda")
-            for _ in range(layer_num)
-        ]
+        self.kv_buffer = torch.empty((layer_num, size, 2 * head_num, head_dim // 2), dtype=torch.int8, device="cuda")
+        self.scale_buffer = torch.empty(
+            (layer_num, size, 2 * head_num, head_dim // self.group_quant_size), dtype=dtype, device="cuda"
+        )
 
     def _free_buffers(self):
         self.kv_buffer = None
         self.scale_buffer = None
+
+    def get_index_kv_buffer(self, index):
+        return {"kv_buffer": self.kv_buffer[:, index], "scale_buffer": self.scale_buffer[:, index]}
+
+    def load_index_kv_buffer(self, index, load_tensor_dict):
+        self.kv_buffer[:, index].copy_(load_tensor_dict["kv_buffer"])
+        self.scale_buffer[:, index].copy_(load_tensor_dict["scale_buffer"])
```
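An aside on the shapes in this file: the int4 variant packs two 4-bit values into each int8 byte, hence the `head_dim // 2` payload dimension, and it stores one scale per `group_quant_size` elements. A rough per-token footprint helper, assuming fp16 scales; the function and numbers are illustrative, not part of the commit:

```python
def int4kv_bytes_per_token(head_num: int, head_dim: int, layer_num: int,
                           group_quant_size: int, scale_bytes: int = 2) -> int:
    """Approximate KV-cache bytes per token for the packed int4 layout above."""
    payload = 2 * head_num * (head_dim // 2)  # K and V, two 4-bit values per byte
    scales = 2 * head_num * (head_dim // group_quant_size) * scale_bytes
    return layer_num * (payload + scales)

# e.g. 32 layers, 8 KV heads, head_dim 128, groups of 8 -> bytes per cached token
print(int4kv_bytes_per_token(head_num=8, head_dim=128, layer_num=32, group_quant_size=8))
```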
18 changes: 11 additions & 7 deletions lightllm/common/ppl_int8kv_mem_manager.py
```diff
@@ -17,14 +17,18 @@ def get_cell_size(self):
         )
 
     def _init_buffers(self, size, dtype, head_num, head_dim, layer_num):
-        self.kv_buffer = [
-            torch.empty((size, 2 * head_num, head_dim), dtype=torch.int8, device="cuda") for _ in range(layer_num)
-        ]
-        self.scale_buffer = [
-            torch.empty((size, 2 * head_num, head_dim // self.group_quant_size), dtype=dtype, device="cuda")
-            for _ in range(layer_num)
-        ]
+        self.kv_buffer = torch.empty((layer_num, size, 2 * head_num, head_dim), dtype=torch.int8, device="cuda")
+        self.scale_buffer = torch.empty(
+            (layer_num, size, 2 * head_num, head_dim // self.group_quant_size), dtype=dtype, device="cuda"
+        )
 
     def _free_buffers(self):
         self.kv_buffer = None
         self.scale_buffer = None
+
+    def get_index_kv_buffer(self, index):
+        return {"kv_buffer": self.kv_buffer[:, index], "scale_buffer": self.scale_buffer[:, index]}
+
+    def load_index_kv_buffer(self, index, load_tensor_dict):
+        self.kv_buffer[:, index].copy_(load_tensor_dict["kv_buffer"])
+        self.scale_buffer[:, index].copy_(load_tensor_dict["scale_buffer"])
```
13 changes: 3 additions & 10 deletions lightllm/server/router/model_infer/infer_batch.py
```diff
@@ -116,16 +116,9 @@ def _save_promptcache_kvbuffer(self):
         """
         prompt_cache_token_id = list(self.radix_cache.root_node.children.values())[0].token_id_key
         print(f"prompt_cache_token_id : {prompt_cache_token_id}")
-        if isinstance(self.radix_cache.mem_manager.kv_buffer, list):
-            kv_buffer_list = []
-            for i in range(len(self.radix_cache.mem_manager.kv_buffer)):
-                kv_buffer_list.append(self.radix_cache.mem_manager.kv_buffer[i][: len(prompt_cache_token_id)])
-            torch.save(kv_buffer_list, f"prompt_cache_rank_{dist.get_rank()}.pt")
-        else:
-            torch.save(
-                self.radix_cache.mem_manager.kv_buffer[:, : len(prompt_cache_token_id)],
-                f"prompt_cache_rank_{dist.get_rank()}.pt",
-            )
+        index = range(len(prompt_cache_token_id))
+        prompt_cache_kv_buffer = self.radix_cache.mem_manager.get_index_kv_buffer(index)
+        torch.save(prompt_cache_kv_buffer, f"prompt_cache_rank_{dist.get_rank()}.pt")
 
     @torch.no_grad()
     def filter(self, finished_request_ids: List[int]):
```
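With the accessors in place, this save path no longer cares which manager variant is active: every manager hands back a dict of tensors that `torch.save` serializes as-is, replacing the old `isinstance` branch. A hedged sketch of the unified flow, where `mem_manager` and `rank` are hypothetical stand-ins for the real objects:

```python
import torch

def save_prompt_cache(mem_manager, prompt_cache_token_id, rank: int):
    # Works for the no-quant base manager (kv only) and the quantized
    # managers (kv + scales) alike; the dict keys carry the difference.
    index = range(len(prompt_cache_token_id))
    state = mem_manager.get_index_kv_buffer(index)
    torch.save(state, f"prompt_cache_rank_{rank}.pt")
```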
6 changes: 1 addition & 5 deletions
```diff
@@ -270,11 +270,7 @@ def preload_prompt_cache_kv_buffer(self, model_cfg):
         prompt_cache_kv_buffer = torch.load(prompt_cache_kv_buffer_path, weights_only=True, map_location="cpu")
         intact_kv_len = len(model_cfg["prompt_cache_token_ids"])
         intact_kv_index = self.radix_cache.mem_manager.alloc(intact_kv_len)
-        if isinstance(self.radix_cache.mem_manager.kv_buffer, list):
-            for i in range(len(self.radix_cache.mem_manager.kv_buffer)):
-                self.radix_cache.mem_manager.kv_buffer[i][intact_kv_index].copy_(prompt_cache_kv_buffer[i])
-        else:
-            self.radix_cache.mem_manager.kv_buffer[:, intact_kv_index].copy_(prompt_cache_kv_buffer)
+        self.radix_cache.mem_manager.load_index_kv_buffer(intact_kv_index, prompt_cache_kv_buffer)
         self.radix_cache.insert(
             torch.tensor(model_cfg["prompt_cache_token_ids"], dtype=torch.int64, device="cpu"),
             intact_kv_index,
```
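The preload side is the mirror image of the save path. A sketch under the same assumptions (`mem_manager`, `radix_cache`, `token_ids`, and `path` are hypothetical stand-ins, and `alloc` is assumed to return row indices into the cache):

```python
import torch

def preload_prompt_cache(mem_manager, radix_cache, token_ids, path: str):
    state = torch.load(path, weights_only=True, map_location="cpu")
    kv_index = mem_manager.alloc(len(token_ids))        # reserve cache rows for the prefix
    mem_manager.load_index_kv_buffer(kv_index, state)   # fill them from the checkpoint
    radix_cache.insert(torch.tensor(token_ids, dtype=torch.int64, device="cpu"), kv_index)
```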
