
Commit 55607cd: quantized model
andreea-popescu-reef committed Sep 17, 2024
1 parent ddcc0a9
Showing 3 changed files with 21 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/compute_horde_prompt_gen/Dockerfile
@@ -18,7 +18,7 @@ RUN mkdir /output

 # Copy your Python script into the container
 COPY saved_models/ /app/saved_models/
-COPY *.py .
+COPY *.py ./
 
 # Set the entrypoint to run your script
 ENTRYPOINT ["python3", "run.py"]
18 changes: 18 additions & 0 deletions src/compute_horde_prompt_gen/download_model.py
@@ -31,16 +31,34 @@
default="./saved_models/",
help="Path to save the model and tokenizer to",
)
parser.add_argument(
"--quantize",
action="store_true",
help="Quantize the model",
default=False,
)

args = parser.parse_args()
save_path = os.path.join(args.save_path, args.model_name)
model_name = MODEL_PATHS[args.model_name]
print(f"Saving {model_name} model to {save_path}")

if args.quantize:
import torch
from transformers import BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(
llm_int8_enable_fp32_cpu_offload=False,
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.float16,
)
print("using quantized model")

model = AutoModelForCausalLM.from_pretrained(
model_name,
# either give token directly or assume logged in with huggingface-cli
token=args.huggingface_token or True,
quantization_config=quantization_config,
)
model.save_pretrained(save_path)
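
For orientation, the new flag would be exercised as python3 download_model.py --model_name llama3 --quantize, where llama3 is a hypothetical key; the real choices come from MODEL_PATHS. Below is a minimal sketch of reloading the checkpoint this script saves, under the same 4-bit settings the hunk introduces, assuming bitsandbytes and a CUDA device are available:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Mirror the 4-bit settings used at download time.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)

save_path = "./saved_models/llama3"  # hypothetical path under --save_path
model = AutoModelForCausalLM.from_pretrained(
    save_path,
    local_files_only=True,  # read only what save_pretrained wrote
    quantization_config=quantization_config,
)
tokenizer = AutoTokenizer.from_pretrained(save_path)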

4 changes: 2 additions & 2 deletions src/compute_horde_prompt_gen/model.py
@@ -15,22 +15,22 @@ def __init__(self):
         pass
 
     def generate(self, prompts: list[str], num_return_sequences: int, **_kwargs):
-        content = f"Here is the list of prompts:\nHow are you?\nDescribe something\nCount to ten\n"
+        content = "Here is the list of prompts:\nHow are you?\nDescribe something\nCount to ten\n"
         return [content for _ in range(len(prompts) * num_return_sequences)]
 
 
 class GenerativeModel:
     def __init__(self, model_path: str, quantize: bool = False):
         self.input_prompt_ending = None
 
+        import torch
         from transformers import (
             AutoTokenizer,
             AutoModelForCausalLM,
         )
 
         quantization_config = None
         if quantize:
-            import torch
             from transformers import BitsAndBytesConfig
 
             quantization_config = BitsAndBytesConfig(
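
Moving import torch to the top of __init__ suggests torch is needed even when quantize is False, though that part of the file sits outside the hunk. For context, a hedged usage sketch; GenerativeModel.generate is assumed to mirror the mock's signature above, and the paths are hypothetical:

from model import GenerativeModel  # assumed import, matching src/compute_horde_prompt_gen/model.py

# quantize=True takes the BitsAndBytesConfig branch in __init__.
llm = GenerativeModel("./saved_models/llama3", quantize=True)
outputs = llm.generate(["Write a short prompt about the sea"], num_return_sequences=1)
print(outputs[0])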
