Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

25 make code base more maintainablereadable #31

Merged
merged 5 commits into from
Jul 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 29 additions & 83 deletions lib/src/main/java/com/frazik/instructgpt/Agent.java
Original file line number Diff line number Diff line change
@@ -1,49 +1,38 @@
package com.frazik.instructgpt;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.frazik.instructgpt.auto.Cli;
import com.frazik.instructgpt.embedding.OpenAIEmbeddingProvider;
import com.frazik.instructgpt.memory.LocalMemory;
import com.frazik.instructgpt.models.OpenAIModel;
import com.frazik.instructgpt.prompts.Prompt;
import com.frazik.instructgpt.prompts.PromptHistory;
import com.frazik.instructgpt.response.Response;
import com.frazik.instructgpt.response.Thought;
import com.frazik.instructgpt.tools.Browser;
import com.frazik.instructgpt.tools.GoogleSearch;
import com.frazik.instructgpt.tools.Tool;
import com.google.gson.Gson;
import lombok.extern.slf4j.Slf4j;
import org.json.JSONObject;

import java.time.ZonedDateTime;
import java.time.format.DateTimeFormatter;
import java.util.*;
@Slf4j
public class Agent {
private final String name;
private final String description;
private final List<String> goals;
private final Map<String, Object> subAgents;
private final LocalMemory memory;
private final PromptHistory history;
private final List<Tool> tools;
private Map<String, Object> stagingTool;

private JsonNode stagingResponse;
private final OpenAIModel openAIModel;

private final String responseFormat;

public Agent(String name, String description, List<String> goals, String model) {
this.history = new PromptHistory();
this.name = name;
this.description = description != null ? description : "A personal assistant that responds exclusively in JSON";
this.goals = goals != null ? goals : new ArrayList<>();
this.subAgents = new HashMap<>();
this.description = description;
this.goals = goals;
this.memory = new LocalMemory(new OpenAIEmbeddingProvider());
this.responseFormat = Constants.getDefaultResponseFormat();
this.tools = Arrays.asList(new Browser(), new GoogleSearch());
this.openAIModel = new OpenAIModel(model);
}
Expand All @@ -65,7 +54,7 @@ private List<Map<String, String>> getFullPrompt(String userInput) {
// Build current date and time prompt
Prompt currentTimePrompt = new Prompt.Builder("current_time")
.withRole("system")
.formatted(0, ZonedDateTime.now().format(DateTimeFormatter.ofPattern("EEE MMM dd HH:mm:ss yyyy")))
.formattedWithCurrentTime(0)
.build();
prompt.add(currentTimePrompt.getPrompt());

Expand Down Expand Up @@ -195,13 +184,9 @@ public Response chat(String message, boolean runTool) {
this.history.addNewPrompt("user", message);
this.history.addNewPrompt("assistant", resp);

try {
JsonNode parsedResp = this.loadJson(resp);

if (parsedResp.has("name")) {
parsedResp = mapper.createObjectNode().set("command", parsedResp);
}

Response response = Response.getResponseFromRaw(resp);
if (response != null) {
JsonNode parsedResp = response.getParsedResp();
String commandArgs = parsedResp.get("command").get("args").asText();
String commandName = parsedResp.get("command").get("name").asText();

Expand All @@ -210,57 +195,12 @@ public Response chat(String message, boolean runTool) {
this.stagingTool.put("name", commandName);

this.stagingResponse = parsedResp;
// Parse the 'thoughts' and 'command' parts of the response into objects
if (parsedResp.has("thoughts") && parsedResp.has("command")) {
JsonNode thoughtsNode = parsedResp.get("thoughts");
Thought thoughts = new Thought(
thoughtsNode.get("text").asText(),
thoughtsNode.get("reasoning").asText(),
thoughtsNode.get("plan").asText(),
thoughtsNode.get("criticism").asText(),
thoughtsNode.get("speak").asText()
);
JsonNode commandNode = parsedResp.get("command");
return new Response(thoughts, commandNode.get("name").asText());
}

} catch (Exception e) {
log.error("Error parsing response: " + resp, e);
return response;
}

return null;
}

private static final ObjectMapper mapper = new ObjectMapper();

private JsonNode loadJson(String s) throws Exception {
if (!s.contains("{") || !s.contains("}")) {
throw new Exception();
}
try {
return mapper.readTree(s);
} catch (Exception e1) {
int startIndex = s.indexOf("{");
int endIndex = s.indexOf("}") + 1;
String subString = s.substring(startIndex, endIndex);
try {
return mapper.readTree(subString);
} catch (Exception e2) {
subString += "}";
try {
return mapper.readTree(subString);
} catch (Exception e3) {
subString = subString.replace("'", "\"");
try {
return mapper.readTree(subString);
} catch (Exception e4) {
throw new Exception();
}
}
}
}
}

public String lastUserInput() {
for (int i = history.getSize() - 1; i >= 0; i--) {
Map<String, String> msg = history.getValue(i);
Expand Down Expand Up @@ -330,7 +270,6 @@ public Object runStagingTool() {

public void clearState() {
history.clear();
subAgents.clear();
memory.clear();
}

Expand All @@ -347,13 +286,20 @@ public String headerPrompt() {
}
prompt.add(resourcesPrompt());
prompt.add(evaluationPrompt());
prompt.add(this.responseFormat);
prompt.add(defaultResponsePrompt());
return newLineDelimited(prompt);
}

public String defaultResponsePrompt() {
String defaultResponse = Prompt.getDefaultResponse();
Prompt defaultResponsePrompt = new Prompt.Builder("use_only_defined_format")
.formatted(0, defaultResponse)
.build();
return defaultResponsePrompt.getContent();
}
public String personaPrompt() {
Prompt personaPrompt = new Prompt.Builder("persona")
.formatted(0, "name", "description")
.formatted(0, name, description)
.build();
return personaPrompt.getContent();
}
Expand All @@ -373,6 +319,19 @@ public String constraintsPrompt() {
return constraintsPrompt.getContent();
}

public String resourcesPrompt() {
Prompt resourcesPrompt = new Prompt.Builder("resources")
.build();
return resourcesPrompt.getContent();
}

public String evaluationPrompt() {
Prompt evaluationPrompt = new Prompt.Builder("evaluation")
.delimited()
.build();
return evaluationPrompt.getContent();
}

/**
* The given code snippet represents a method called tools_prompt() that generates a prompt for a list of tools in a specific format. Here's a breakdown of what the code does:
* It initializes an empty list called prompt to store the lines of the prompt.
Expand Down Expand Up @@ -411,19 +370,6 @@ public String toolsPrompt() {
return newLineDelimited(prompt);
}

public String resourcesPrompt() {
Prompt resourcesPrompt = new Prompt.Builder("resources")
.build();
return resourcesPrompt.getContent();
}

public String evaluationPrompt() {
Prompt evaluationPrompt = new Prompt.Builder("evaluation")
.delimited()
.build();
return evaluationPrompt.getContent();
}

private static String newLineDelimited(List<String> prompt) {
return String.join("\n", prompt) + "\n";
}
Expand Down
44 changes: 0 additions & 44 deletions lib/src/main/java/com/frazik/instructgpt/Constants.java

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import java.util.Collections;
import java.util.List;
import java.util.Map;

public class OpenAIEmbeddingProvider extends EmbeddingProvider {
public static final String OPENAI_API_KEY = "OPENAI_API_KEY";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

import java.util.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;


@Slf4j
Expand Down
51 changes: 41 additions & 10 deletions lib/src/main/java/com/frazik/instructgpt/prompts/Prompt.java
Original file line number Diff line number Diff line change
@@ -1,18 +1,37 @@
package com.frazik.instructgpt.prompts;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializationFeature;
import com.google.gson.Gson;
import org.openqa.selenium.json.TypeToken;

import java.io.*;
import java.util.*;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.time.ZonedDateTime;
import java.time.format.DateTimeFormatter;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class Prompt {
private final String role;
private final String content;
private static Map<String, List<String>> promptsBundle = new HashMap<>();
private static final Map<String, List<String>> promptsBundle;
private static final Map<String, Map<String, String>> defaultResponsesJson;
private static final ObjectMapper objectMapper = new ObjectMapper().enable(SerializationFeature.INDENT_OUTPUT);

static {
readPromptJson();
// Use generic types to prepare for future expansion
TypeToken<Map<String, List<String>>> promptsToken = new TypeToken<Map<String, List<String>>>() {};
promptsBundle = readPromptJson(promptsToken, "prompts_en.json");
// Use generic types to prepare for future expansion
TypeToken<Map<String, Map<String, String>>> defaultResponsesToken =
new TypeToken<Map<String, Map<String, String>>>() {};
defaultResponsesJson = readPromptJson(defaultResponsesToken, "default_response_en.json");
}

public Prompt(String role, String content) {
Expand All @@ -23,6 +42,14 @@ public Prompt(String role, String content) {
public String getContent() {
return content;
}

public static String getDefaultResponse() {
try {
return objectMapper.writeValueAsString(defaultResponsesJson);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
}
public Map<String, String> getPrompt() {
Map<String, String> prompt = new HashMap<>();
prompt.put("role", role);
Expand Down Expand Up @@ -58,6 +85,11 @@ public Builder formatted(int i, Object... args) {
prompts.set(i, String.format(prompts.get(i), args));
return this;
}

public Builder formattedWithCurrentTime(int i) {
String currentTime = ZonedDateTime.now().format(DateTimeFormatter.ofPattern("EEE MMM dd HH:mm:ss yyyy"));
return formatted(i, currentTime);
}
public Builder withRole(String role) {
this.role = role;
return this;
Expand All @@ -69,19 +101,18 @@ public Prompt build() {
}
}

private static void readPromptJson() {
private static <T> T readPromptJson(TypeToken<T> token, String jsonFileName) {
try {
InputStream inputStream = Prompt.class.getClassLoader().getResourceAsStream("prompts_en.json");
InputStream inputStream = Prompt.class.getClassLoader().getResourceAsStream(jsonFileName);

if (inputStream == null) {
throw new FileNotFoundException("prompts_en.json file not found.");
throw new FileNotFoundException(jsonFileName + " file not found.");
}

InputStreamReader reader = new InputStreamReader(inputStream);
TypeToken<Map<String, List<String>>> token = new TypeToken<Map<String, List<String>>>() {};
promptsBundle = new Gson().fromJson(reader, token.getType());
return new Gson().fromJson(reader, token.getType());
} catch (IOException e) {
throw new RuntimeException("Error reading prompts_en.json file.", e);
throw new RuntimeException("Error reading " + jsonFileName, e);
}
}
}
Loading
Loading