diff --git a/src/compute_horde_prompt_gen/Dockerfile b/src/compute_horde_prompt_gen/Dockerfile
index ca7b698..a504b03 100644
--- a/src/compute_horde_prompt_gen/Dockerfile
+++ b/src/compute_horde_prompt_gen/Dockerfile
@@ -18,7 +18,7 @@ RUN mkdir /output
 
 # Copy your Python script into the container
 COPY saved_models/ /app/saved_models/
-COPY *.py .
+COPY *.py ./
 
 # Set the entrypoint to run your script
 ENTRYPOINT ["python3", "run.py"]
diff --git a/src/compute_horde_prompt_gen/download_model.py b/src/compute_horde_prompt_gen/download_model.py
index 7516d76..0975e23 100644
--- a/src/compute_horde_prompt_gen/download_model.py
+++ b/src/compute_horde_prompt_gen/download_model.py
@@ -31,16 +31,35 @@
         default="./saved_models/",
         help="Path to save the model and tokenizer to",
     )
+    parser.add_argument(
+        "--quantize",
+        action="store_true",
+        help="Quantize the model",
+        default=False,
+    )
     args = parser.parse_args()
 
     save_path = os.path.join(args.save_path, args.model_name)
     model_name = MODEL_PATHS[args.model_name]
     print(f"Saving {model_name} model to {save_path}")
 
+    quantization_config = None
+    if args.quantize:
+        import torch
+        from transformers import BitsAndBytesConfig
+
+        quantization_config = BitsAndBytesConfig(
+            llm_int8_enable_fp32_cpu_offload=False,
+            load_in_4bit=True,
+            bnb_4bit_compute_dtype=torch.float16,
+        )
+        print("using quantized model")
+
     model = AutoModelForCausalLM.from_pretrained(
         model_name,
         # either give token directly or assume logged in with huggingface-cli
         token=args.huggingface_token or True,
+        quantization_config=quantization_config,
     )
 
     model.save_pretrained(save_path)
diff --git a/src/compute_horde_prompt_gen/model.py b/src/compute_horde_prompt_gen/model.py
index d68b513..43218f3 100644
--- a/src/compute_horde_prompt_gen/model.py
+++ b/src/compute_horde_prompt_gen/model.py
@@ -15,7 +15,7 @@ def __init__(self):
         pass
 
     def generate(self, prompts: list[str], num_return_sequences: int, **_kwargs):
-        content = f"Here is the list of prompts:\nHow are you?\nDescribe something\nCount to ten\n"
+        content = "Here is the list of prompts:\nHow are you?\nDescribe something\nCount to ten\n"
         return [content for _ in range(len(prompts) * num_return_sequences)]
 
 
@@ -23,7 +23,6 @@ class GenerativeModel:
     def __init__(self, model_path: str, quantize: bool = False):
        self.input_prompt_ending = None
 
-        import torch
         from transformers import (
             AutoTokenizer,
             AutoModelForCausalLM,
@@ -31,6 +30,7 @@ def __init__(self, model_path: str, quantize: bool = False):
 
         quantization_config = None
         if quantize:
+            import torch
             from transformers import BitsAndBytesConfig
 
             quantization_config = BitsAndBytesConfig(
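
Note: as a rough sketch of what the new --quantize path in download_model.py amounts to, assuming transformers, accelerate and bitsandbytes are installed and a CUDA GPU is available; the model id and save path below are placeholders, not the values the script actually resolves via MODEL_PATHS:

    import torch
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    # 4-bit weights, fp16 compute, no fp32 CPU offload -- mirrors the config added in the diff
    quantization_config = BitsAndBytesConfig(
        llm_int8_enable_fp32_cpu_offload=False,
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
    )

    # placeholder model id; the script uses MODEL_PATHS[args.model_name]
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Meta-Llama-3-8B-Instruct",
        quantization_config=quantization_config,
    )
    model.save_pretrained("./saved_models/llama3")  # placeholder save path

Invocation would then look something like "python3 download_model.py --model_name llama3 --quantize", where --quantize is the flag added here and the other arguments come from the existing script (the model_name value is illustrative).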