Skip to content

Commit

Permalink
Improves Java API signatures maintaining back compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
lordofthejars authored and manyoso committed Oct 12, 2023
1 parent f39df09 commit 3c45a55
Show file tree
Hide file tree
Showing 2 changed files with 265 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.*;
import java.util.stream.Collectors;

public class LLModel implements AutoCloseable {

Expand Down Expand Up @@ -306,6 +305,197 @@ static LLModelLibrary.ResponseCallback getResponseCallback(boolean streamToStdOu
};
}

/**
* The array of messages for the conversation.
*/
public static class Messages {

private final List<PromptMessage> messages = new ArrayList<>();

public Messages(PromptMessage...messages) {
this.messages.addAll(Arrays.asList(messages));
}

public Messages(List<PromptMessage> messages) {
this.messages.addAll(messages);
}

public Messages addPromptMessage(PromptMessage promptMessage) {
this.messages.add(promptMessage);
return this;
}

List<PromptMessage> toList() {
return Collections.unmodifiableList(this.messages);
}

List<Map<String, String>> toListMap() {
return messages.stream()
.map(PromptMessage::toMap).collect(Collectors.toList());
}

}

/**
* A message in the conversation, identical to OpenAI's chat message.
*/
public static class PromptMessage {

private static final String ROLE = "role";
private static final String CONTENT = "content";

private final Map<String, String> message = new HashMap<>();

public PromptMessage() {
}

public PromptMessage(Role role, String content) {
addRole(role);
addContent(content);
}

public PromptMessage addRole(Role role) {
return this.addParameter(ROLE, role.type());
}

public PromptMessage addContent(String content) {
return this.addParameter(CONTENT, content);
}

public PromptMessage addParameter(String key, String value) {
this.message.put(key, value);
return this;
}

public String content() {
return this.parameter(CONTENT);
}

public Role role() {
String role = this.parameter(ROLE);
return Role.from(role);
}

public String parameter(String key) {
return this.message.get(key);
}

Map<String, String> toMap() {
return Collections.unmodifiableMap(this.message);
}

}

public enum Role {

SYSTEM("system"), ASSISTANT("assistant"), USER("user");

private final String type;

String type() {
return this.type;
}

static Role from(String type) {

if (type == null) {
return null;
}

switch (type) {
case "system": return SYSTEM;
case "assistant": return ASSISTANT;
case "user": return USER;
default: throw new IllegalArgumentException(
String.format("You passed %s type but only %s are supported",
type, Arrays.toString(Role.values())
)
);
}
}

Role(String type) {
this.type = type;
}

@Override
public String toString() {
return type();
}
}

/**
* The result of the completion, similar to OpenAI's format.
*/
public static class CompletionReturn {
private String model;
private Usage usage;
private Choices choices;

public CompletionReturn(String model, Usage usage, Choices choices) {
this.model = model;
this.usage = usage;
this.choices = choices;
}

public Choices choices() {
return choices;
}

public String model() {
return model;
}

public Usage usage() {
return usage;
}
}

/**
* The generated completions.
*/
public static class Choices {

private final List<CompletionChoice> choices = new ArrayList<>();

public Choices(List<CompletionChoice> choices) {
this.choices.addAll(choices);
}

public Choices(CompletionChoice...completionChoices){
this.choices.addAll(Arrays.asList(completionChoices));
}

public Choices addCompletionChoice(CompletionChoice completionChoice) {
this.choices.add(completionChoice);
return this;
}

public CompletionChoice first() {
return this.choices.get(0);
}

public int totalChoices() {
return this.choices.size();
}

public CompletionChoice get(int index) {
return this.choices.get(index);
}

public List<CompletionChoice> choices() {
return Collections.unmodifiableList(choices);
}
}

/**
* A completion choice, similar to OpenAI's format.
*/
public static class CompletionChoice extends PromptMessage {
public CompletionChoice(Role role, String content) {
super(role, content);
}
}

public static class ChatCompletionResponse {
public String model;
Expand All @@ -323,6 +513,41 @@ public static class Usage {
// Getters and setters
}

public CompletionReturn chatCompletionResponse(Messages messages,
GenerationConfig generationConfig) {
return chatCompletion(messages, generationConfig, false, false);
}

/**
* chatCompletion formats the existing chat conversation into a template to be
* easier to process for chat UIs. It is not absolutely necessary as generate method
* may be directly used to make generations with gpt models.
*
* @param messages object to create theMessages to send to GPT model
* @param generationConfig How to decode/process the generation.
* @param streamToStdOut Send tokens as they are calculated Standard output.
* @param outputFullPromptToStdOut Should full prompt built out of messages be sent to Standard output.
* @return CompletionReturn contains stats and generated Text.
*/
public CompletionReturn chatCompletion(Messages messages,
GenerationConfig generationConfig, boolean streamToStdOut,
boolean outputFullPromptToStdOut) {

String fullPrompt = buildPrompt(messages.toListMap());

if(outputFullPromptToStdOut)
System.out.print(fullPrompt);

String generatedText = generate(fullPrompt, generationConfig, streamToStdOut);

final CompletionChoice promptMessage = new CompletionChoice(Role.ASSISTANT, generatedText);
final Choices choices = new Choices(promptMessage);

final Usage usage = getUsage(fullPrompt, generatedText);
return new CompletionReturn(this.modelName, usage, choices);

}

public ChatCompletionResponse chatCompletion(List<Map<String, String>> messages,
GenerationConfig generationConfig) {
return chatCompletion(messages, generationConfig, false, false);
Expand Down Expand Up @@ -352,19 +577,23 @@ public ChatCompletionResponse chatCompletion(List<Map<String, String>> messages
ChatCompletionResponse response = new ChatCompletionResponse();
response.model = this.modelName;

Usage usage = new Usage();
usage.promptTokens = fullPrompt.length();
usage.completionTokens = generatedText.length();
usage.totalTokens = fullPrompt.length() + generatedText.length();
response.usage = usage;
response.usage = getUsage(fullPrompt, generatedText);

Map<String, String> message = new HashMap<>();
message.put("role", "assistant");
message.put("content", generatedText);

response.choices = List.of(message);

return response;

}

private Usage getUsage(String fullPrompt, String generatedText) {
Usage usage = new Usage();
usage.promptTokens = fullPrompt.length();
usage.completionTokens = generatedText.length();
usage.totalTokens = fullPrompt.length() + generatedText.length();
return usage;
}

protected static String buildPrompt(List<Map<String, String>> messages) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,33 @@
@ExtendWith(MockitoExtension.class)
public class BasicTests {

@Test
public void simplePromptWithObject(){

LLModel model = Mockito.spy(new LLModel());

LLModel.GenerationConfig config =
LLModel.config()
.withNPredict(20)
.build();

// The generate method will return "4"
doReturn("4").when( model ).generate(anyString(), eq(config), eq(true));

LLModel.PromptMessage promptMessage1 = new LLModel.PromptMessage(LLModel.Role.SYSTEM, "You are a helpful assistant");
LLModel.PromptMessage promptMessage2 = new LLModel.PromptMessage(LLModel.Role.USER, "Add 2+2");

LLModel.Messages messages = new LLModel.Messages(promptMessage1, promptMessage2);

LLModel.CompletionReturn response = model.chatCompletion(
messages, config, true, true);

assertTrue( response.choices().first().content().contains("4") );

// Verifies the prompt and response are certain length.
assertEquals( 224 , response.usage().totalTokens );
}

@Test
public void simplePrompt(){

Expand Down

0 comments on commit 3c45a55

Please sign in to comment.