diff --git a/README.md b/README.md
index f4c1fee..87612b8 100644
--- a/README.md
+++ b/README.md
@@ -2,7 +2,9 @@
 
 [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
 
 Script to generate batches of random unique prompts to be used in the Compute Horde project synthetic jobs.
+
 The prompt that generates prompts is inspired by [Bittensor Subnet 18 (Cortex.t)](https://github.com/Datura-ai/cortex.t/blob/276cfcf742e8b442500435a1c1862ac4dffa9e20/cortext/utils.py#L193) (licensed under the MIT License).
+
 The generated prompts will be saved in `/prompts_.txt`, each line of the text file containing a prompt.
diff --git a/pdm.lock b/pdm.lock
index 3838b6f..03789e1 100644
--- a/pdm.lock
+++ b/pdm.lock
@@ -611,13 +611,13 @@ files = [
 
 [[package]]
 name = "setuptools"
-version = "74.1.0"
+version = "74.1.1"
 requires_python = ">=3.8"
 summary = "Easily download, build, install, upgrade, and uninstall Python packages"
 groups = ["default"]
 files = [
-    {file = "setuptools-74.1.0-py3-none-any.whl", hash = "sha256:cee604bd76cc092355a4e43ec17aee5369095974f41f088676724dc6bc2c9ef8"},
-    {file = "setuptools-74.1.0.tar.gz", hash = "sha256:bea195a800f510ba3a2bc65645c88b7e016fe36709fefc58a880c4ae8a0138d7"},
+    {file = "setuptools-74.1.1-py3-none-any.whl", hash = "sha256:fc91b5f89e392ef5b77fe143b17e32f65d3024744fba66dc3afe07201684d766"},
+    {file = "setuptools-74.1.1.tar.gz", hash = "sha256:2353af060c06388be1cecbf5953dcdb1f38362f87a2356c480b6b4d5fcfc8847"},
 ]
 
 [[package]]
diff --git a/src/compute_horde_prompt_gen/model.py b/src/compute_horde_prompt_gen/model.py
index 4998646..65f1c4f 100644
--- a/src/compute_horde_prompt_gen/model.py
+++ b/src/compute_horde_prompt_gen/model.py
@@ -1,9 +1,4 @@
-import torch
 import logging
-from transformers import (
-    AutoTokenizer,
-    AutoModelForCausalLM,
-)
 
 from prompt import PROMPT_ENDING
 
@@ -15,7 +10,7 @@ def __init__(self):
         pass
 
     def generate(self, prompts: list[str], num_return_sequences: int, **_kwargs):
-        return torch.rand(len(prompts) * num_return_sequences)
+        return [1 for _ in range(len(prompts) * num_return_sequences)]
 
     def decode(self, _output):
         return f"COPY PASTE INPUT PROMPT {PROMPT_ENDING} Here is the list of prompts:\nHow are you?\nDescribe something\nCount to ten\n"
@@ -23,6 +18,12 @@ def decode(self, _output):
 
 class GenerativeModel:
     def __init__(self, model_path: str, quantize: bool = False):
+        import torch
+        from transformers import (
+            AutoTokenizer,
+            AutoModelForCausalLM,
+        )
+
         quantization_config = None
         if quantize:
             from transformers import BitsAndBytesConfig
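
Note on the `model.py` hunks: the `torch`/`transformers` imports move from module scope into `GenerativeModel.__init__`, and `MockModel.generate` now returns a plain list instead of a tensor, so the mock code path no longer requires the heavy ML stack to be installed at all. A minimal sketch of the resulting structure is below; the diff truncates before the model is actually constructed, so everything past the deferred imports (the `from_pretrained` calls, dtype, and the `BitsAndBytesConfig` arguments) is an assumption, not the project's code.

```python
class MockModel:
    def generate(self, prompts: list[str], num_return_sequences: int, **_kwargs):
        # Dummy output, one entry per requested sequence -- no torch needed.
        return [1 for _ in range(len(prompts) * num_return_sequences)]


class GenerativeModel:
    def __init__(self, model_path: str, quantize: bool = False):
        # Heavy dependencies are imported only when a real model is built,
        # so importing this module stays cheap for the mock code path.
        import torch
        from transformers import AutoTokenizer, AutoModelForCausalLM

        quantization_config = None
        if quantize:
            from transformers import BitsAndBytesConfig

            # Assumed 4-bit config; the hunk ends before this point.
            quantization_config = BitsAndBytesConfig(load_in_4bit=True)

        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.float16,  # assumed dtype, not shown in the diff
            quantization_config=quantization_config,
        )
```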