From ffa070dec5e1af1c4b073e2ad41e5d4d0f2b179c Mon Sep 17 00:00:00 2001
From: Michael Gschwind
Date: Wed, 10 Apr 2024 13:25:16 -0700
Subject: [PATCH] embedding on cuda

---
 .github/workflows/compile_t4.yml | 4 ++--
 quantize.py                      | 7 ++++---
 2 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/.github/workflows/compile_t4.yml b/.github/workflows/compile_t4.yml
index 3c9c33570..a3f6da5b5 100644
--- a/.github/workflows/compile_t4.yml
+++ b/.github/workflows/compile_t4.yml
@@ -52,8 +52,8 @@ jobs:
         echo "******************************************"
         echo "******* Emb: channel-wise quantized ******"
         echo "******************************************"
-        # python generate.py --device cuda --quant '{"embedding" : {"bitwidth": 8, "group_size": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
-        # cat ./output_eager
+        python generate.py --device cuda --quant '{"embedding" : {"bitwidth": 8, "group_size": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_eager
+        cat ./output_eager
         # python generate.py --device cuda --compile --quant '{"embedding" : {"bitwidth": 8, "group_size": 0}}' --checkpoint-path ${MODEL_PATH} --temperature 0 > ./output_compiled
         # cat ./output_compiled
         # python export.py --device cuda --quant '{"embedding" : {"bitwidth": 8, "group_size": 0}}' --checkpoint-path ${MODEL_PATH} --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so
diff --git a/quantize.py b/quantize.py
index c2093cfba..120c14ee5 100644
--- a/quantize.py
+++ b/quantize.py
@@ -572,17 +572,18 @@ def __init__(
             group_size = embedding_dim
         self.group_size = group_size
         self.dtype = dtype
+        if device is None: device = "cpu"
         self.register_buffer(
-            "weight", torch.empty((vocab_size, embedding_dim), dtype=torch.int8)
+            "weight", torch.empty((vocab_size, embedding_dim), dtype=torch.int8, device=device)
         )
         groups_per_row = (embedding_dim + group_size - 1) // group_size
         if groups_per_row > 1:
             self.register_buffer(
-                "scales", torch.ones((vocab_size, groups_per_row), dtype=torch.float16)
+                "scales", torch.ones((vocab_size, groups_per_row), dtype=torch.float16, device=device)
             )
         else:
             self.register_buffer(
-                "scales", torch.ones((vocab_size,), dtype=torch.float16)
+                "scales", torch.ones((vocab_size,), dtype=torch.float16, device=device)
             )
 
     @torch.no_grad()
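
Note (illustration, not part of the patch): the quantize.py hunk makes the quantized embedding allocate its int8 "weight" and fp16 "scales" buffers directly on the requested device instead of implicitly on CPU, which is what lets the re-enabled channel-wise quantized embedding step in compile_t4.yml run on CUDA. The following is a minimal, self-contained sketch of that pattern; the class name DeviceAwareQuantizedEmbedding and the dequantizing forward are hypothetical simplifications, not code from the repository.

import torch
import torch.nn as nn


class DeviceAwareQuantizedEmbedding(nn.Module):
    # Hypothetical, simplified version of the pattern in the patch:
    # int8 weights plus fp16 scales, registered on an explicit device.
    def __init__(self, vocab_size, embedding_dim, group_size=None, device=None):
        super().__init__()
        if group_size is None or group_size == 0:
            group_size = embedding_dim  # group_size 0 means one scale per row (channel-wise)
        if device is None:
            device = "cpu"
        self.group_size = group_size
        self.register_buffer(
            "weight",
            torch.empty((vocab_size, embedding_dim), dtype=torch.int8, device=device),
        )
        groups_per_row = (embedding_dim + group_size - 1) // group_size
        if groups_per_row > 1:
            self.register_buffer(
                "scales",
                torch.ones((vocab_size, groups_per_row), dtype=torch.float16, device=device),
            )
        else:
            self.register_buffer(
                "scales", torch.ones((vocab_size,), dtype=torch.float16, device=device)
            )

    def forward(self, indices: torch.Tensor) -> torch.Tensor:
        # Dequantize on the fly: split each int8 row into groups and apply its fp16 scale.
        rows = self.weight[indices].to(torch.float16)   # (..., embedding_dim)
        scales = self.scales[indices]                   # (...,) or (..., groups_per_row)
        if scales.dim() == rows.dim() - 1:              # channel-wise: one scale per row
            scales = scales.unsqueeze(-1)
        grouped = rows.reshape(*rows.shape[:-1], -1, self.group_size)
        return (grouped * scales.unsqueeze(-1)).reshape(rows.shape)


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    emb = DeviceAwareQuantizedEmbedding(32, 16, group_size=0, device=device)
    out = emb(torch.tensor([1, 2, 3], device=device))
    print(out.shape, out.device)  # buffers and output live on the chosen device

Because the buffers are created on the target device at construction time, no later .to(device) copy is needed before running eager generation on CUDA, matching the now-uncommented generate.py --device cuda invocation in the workflow.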