Skip to content

Commit

Permalink
Update Synth.py
Browse files Browse the repository at this point in the history
  • Loading branch information
win10ogod authored May 2, 2024
1 parent 2b29fe3 commit 0e77573
Showing 1 changed file with 31 additions and 13 deletions.
44 changes: 31 additions & 13 deletions Synth.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,26 @@
import json
import os
import datetime
import glob

def generate_data(model, prompts, images=None, options=None, system=None, template=None, context=None, stream=False, raw=False, keep_alive="5m"):
def generate_data(model, prompts, options=None, images=None, system=None, template=None, context=None, stream=False, raw=False, done=False, keep_alive="30m", output_json=False):
url = "http://localhost:11434/api/generate"

generated_texts = []
json_data = []


for prompt in prompts:
payload = {
"model": model,
"prompt": prompt,
"stream": stream,
"top_p": 0.9,
"temperature":0.7,
"max_ctx":32768,
"done":done,
"raw": raw,
"keep_alive": keep_alive
payload = {
"model": model,
"prompt": prompt,
"stream": stream,
"top_p": 0.9,
"temperature":0.7,
"max_ctx":32768,
"done":done,
"raw": raw,
"keep_alive": keep_alive
}

if images:
Expand Down Expand Up @@ -48,11 +51,20 @@ def generate_data(model, prompts, images=None, options=None, system=None, templa
if isinstance(data, dict) and "response" in data:
generated_text = data["response"]
generated_texts.append(generated_text)
if output_json:
json_data.append({"instruction": prompt, "input": prompt, "output": generated_text})
else:
print("API response format is incorrect")
else:
print(f"Request failed, status code: {response.status_code}")

if output_json:
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
json_file_name = f"generated_texts_{timestamp}.json"
with open(json_file_name, 'w', encoding='utf-8') as f:
json.dump(json_data, f, indent=4, ensure_ascii=False)
print(f"The generated JSON file has been saved to {json_file_name}")

return generated_texts

def load_file(file_path):
Expand All @@ -63,8 +75,10 @@ def load_prompts_from_file(file_path):
with open(file_path, 'r') as file:
prompts = file.readlines()
prompts = [prompt.strip() for prompt in prompts if prompt.strip()]

return prompts


def save_to_file(file_path, texts):
try:
with open(file_path, 'w') as file:
Expand All @@ -78,10 +92,13 @@ def create_output_file():
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
file_name = f"generated_texts_{timestamp}.txt"
return file_name


def main():
model = input("Please enter the model name: ")
prompt_type = input("Please select the prompt type (1-input prompt, 2-load prompt from file): ")
prompt_type = input("Please select the prompt type (1-input prompt, 2-load prompt from file:)")
output_json = input("Whether to output a JSON file (yes/no): ")
output_json = output_json.lower() == 'yes'

if prompt_type == "1":
prompts = []
Expand All @@ -93,11 +110,12 @@ def main():
elif prompt_type == "2":
file_path = input("Please enter the file path: ")
prompts = load_prompts_from_file(file_path)

else:
print("Invalid selection")
return

generated_texts = generate_data(model, prompts)
generated_texts = generate_data(model, prompts, output_json=output_json)

if generated_texts:
output_file = create_output_file()
Expand Down

0 comments on commit 0e77573

Please sign in to comment.