Skip to content

Commit

Permalink
Merge pull request #13 from joefernandez/gemini-refactor
Browse files Browse the repository at this point in the history
Gemini refactor
  • Loading branch information
cannoneyed authored Mar 20, 2024
2 parents 7d012e6 + ec6c5b8 commit 4ea64b0
Show file tree
Hide file tree
Showing 21 changed files with 209 additions and 194 deletions.
20 changes: 10 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ Wordcraft is an LLM-powered text editor with an emphasis on short story writing.

Wordcraft is a tool built by researchers at Google
[PAIR](https://pair.withgoogle.com/) for writing stories with AI. The
application is powered by LLMs such as
[PaLM](https://developers.generativeai.google/), one of the latest generation of
large language models. At its core, LLMs are simple machines — it's trained to
application is powered by generative models such as
[Gemini](https://ai.google.dev/docs/).
At its core, a generative model is a simple machine — it's trained to
predict the most likely next word given a textual prompt. But because the model
is so large and has been trained on a massive amount of text, it's able to learn
higher-level concepts. It also demonstrates a fascinating emergent capability
Expand All @@ -32,23 +32,23 @@ npm run dev

# ☁️ API

In order to run Wordcraft, you'll need a PaLM API key. Please follow the
In order to run Wordcraft, you'll need a Gemini API key. Please follow the
instructions at
[developers.generativeai.google/tutorials/setup](https://developers.generativeai.google/tutorials/setup).
[ai.google.dev/tutorials/setup](https://ai.google.dev/tutorials/setup).
Once you have your API key, create a .env file and add the key!

```bash
touch .env
echo "PALM_API_KEY=\"<INSERT_PALM_API_KEY>\"" > .env
echo "API_KEY=\"<INSERT_API_KEY>\"" > .env
```

Remember, use your API keys securely. Do not share them with others, or embed
them directly in code that's exposed to the public! This application
stores/loads API keys on the client for ease of development, but these should be
removed in all production apps!

You can find more information about the PaLM 2 API at
[developers.generativeai.google](https://developers.generativeai.google/)
You can find more information about the Gemini API at
[ai.google.dev/docs/](https://ai.google.dev/docs/)

# 🤖 App

Expand Down Expand Up @@ -88,9 +88,9 @@ To add a new custom control (e.g. a button that translates into pig latin):
- Create a new `pig_latin_examples.json` in `/app/context/json/`
- Register the examples in the `WordCraftContext` constructor
(`/app/context/index.ts`)
- Create a corresponding prompt handler in `/app/models/palm/prompts`
- Create a corresponding prompt handler in `/app/models/gemini/prompts`
- Register that prompt handler with the underlying `Model` class in
`/app/models/palm/index.ts`
`/app/models/gemini/index.ts`
- Create a new `PigLatinOperation` in `/app/core/operations`
- Register the operation in `main.ts`

Expand Down
8 changes: 4 additions & 4 deletions app/main.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ import {OperationsService} from '@services/operations_service';
import {WordcraftContext} from './context';
import {makeServiceProvider} from './service_provider';
import {InitializationService} from '@services/initialization_service';
import {PalmModel} from '@models/palm';
import {PalmDialogModel} from '@models/palm/dialog';
import {GeminiModel} from '@models/gemini';
import {GeminiDialogModel} from '@models/gemini/dialog';

wordcraftCore.initialize(makeServiceProvider);

Expand All @@ -55,8 +55,8 @@ operationsService.registerOperations(

// Register prompts with models
const modelService = wordcraftCore.getService(ModelService);
modelService.useModel(PalmModel);
modelService.useDialogModel(PalmDialogModel);
modelService.useModel(GeminiModel);
modelService.useDialogModel(GeminiDialogModel);

// Initialize the app after page load, so that all of the javascript is present
// before we build the app.
Expand Down
123 changes: 123 additions & 0 deletions app/models/gemini/api.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
/**
* @license
*
* Copyright 2023 Google LLC.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* ==============================================================================
*/

// Set up the Gemini generative AI SDK.
// NOTE(review): the original mixed CommonJS `require` with ES `import`
// statements; ES imports are hoisted above the `require` anyway, so a
// single consolidated import is both cleaner and order-accurate.
import {
  GoogleGenerativeAI,
  HarmBlockThreshold,
  HarmCategory,
} from "@google/generative-ai";
import { DialogParams } from '@core/shared/interfaces';

// Remember to set an environment variable for API_KEY in .env
const genAI = new GoogleGenerativeAI(process.env.API_KEY);

// Default safety settings applied to every request: most categories are
// relaxed to block only high-probability harms, while hate speech keeps the
// stricter medium-and-above threshold.
const BLOCK_HIGH = HarmBlockThreshold.BLOCK_ONLY_HIGH;
const safetySettings = [
  {category: HarmCategory.HARM_CATEGORY_HARASSMENT, threshold: BLOCK_HIGH},
  {
    category: HarmCategory.HARM_CATEGORY_HATE_SPEECH,
    threshold: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
  },
  {category: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, threshold: BLOCK_HIGH},
  {category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, threshold: BLOCK_HIGH},
];

/**
 * Generation parameters accepted by the Gemini calls in this module.
 * All fields are optional; unset fields fall back to
 * DEFAULT_GENERATION_PARAMS (and per-call overrides below).
 */
export interface ModelParams {
  generationConfig?: {
    topK?: number;
    topP?: number;
    candidateCount?: number;
    maxOutputTokens?: number;
    temperature?: number;
  }
}

// Baseline generation settings shared by text and dialog calls.
// candidateCount: 8 requests multiple alternatives for text generation;
// callDialogModel overrides it to 1.
const DEFAULT_GENERATION_PARAMS: ModelParams = {
  generationConfig: {
    temperature: 0.8,
    topK: 40,
    topP: 0.95,
    candidateCount: 8
  }
};

// Model IDs passed to the SDK. Both currently point at 'gemini-pro';
// kept separate so text and chat models can diverge later.
const TEXT_MODEL_ID = 'gemini-pro';
const DIALOG_MODEL_ID = 'gemini-pro';

/**
 * Sends a single text prompt to the Gemini text model and returns the
 * response text.
 *
 * @param textPrompt The fully-assembled prompt string.
 * @param genConfig Optional generation overrides; merged over
 *     DEFAULT_GENERATION_PARAMS.
 * @returns The model's response text.
 */
export async function callTextModel(
  textPrompt: string,
  genConfig: ModelParams) {
  // Merge caller overrides over the defaults into a NEW object. The original
  // shallow Object.assign shared the nested generationConfig object, so the
  // maxOutputTokens write below mutated DEFAULT_GENERATION_PARAMS (or the
  // caller's object) across calls.
  const generationConfig = {
    ...DEFAULT_GENERATION_PARAMS.generationConfig,
    ...genConfig.generationConfig,
    // Cap output length for text prompts.
    maxOutputTokens: 1024,
  };

  const model = genAI.getGenerativeModel({
    model: TEXT_MODEL_ID,
    // The SDK expects the key `generationConfig`; the original passed the
    // shorthand property `genConfig`, which the SDK silently ignored.
    generationConfig,
    safetySettings,
  });
  const result = await model.generateContent(textPrompt);
  const response = await result.response;
  return response.text();
}

/**
 * Sends the latest chat message (plus prior history) to the Gemini chat
 * model and returns the response text.
 *
 * @param chatParams Dialog state; the last message is the request to send,
 *     messages between the first and last become chat history.
 * @param genConfig Optional generation overrides; merged over
 *     DEFAULT_GENERATION_PARAMS, then dialog-specific values are applied.
 * @returns The model's response text.
 */
export async function callDialogModel(
  chatParams: DialogParams,
  genConfig: ModelParams) {
  // Merge caller overrides over the defaults into a NEW object (the original
  // shallow Object.assign mutated the shared DEFAULT_GENERATION_PARAMS
  // nested object), then apply dialog-specific settings, which intentionally
  // take precedence over caller values.
  const generationConfig = {
    ...DEFAULT_GENERATION_PARAMS.generationConfig,
    ...genConfig.generationConfig,
    temperature: 0.7,
    candidateCount: 1,
  };

  const model = genAI.getGenerativeModel({
    model: DIALOG_MODEL_ID,
    // The SDK expects the key `generationConfig`; the original passed the
    // shorthand property `genConfig`, which the SDK silently ignored.
    generationConfig,
    safetySettings,
  });

  // Get the latest chat request (last message).
  const messages = chatParams.messages;
  const message = messages[messages.length - 1].content;

  // Set the chat history from the interior messages.
  const history = remapHistory(chatParams);
  // startChat takes a params object ({history}), not a bare history array.
  const chat = model.startChat({ history });

  const result = await chat.sendMessage(message);
  const response = await result.response;
  return response.text();
}

/**
 * Converts Wordcraft dialog messages into the Gemini chat-history shape
 * ({role, parts}). The first message and the last message are excluded —
 * only the interior messages become history.
 */
export function remapHistory(chatParams: DialogParams) {
  // slice(1, -1) drops the first and last messages; it is safely empty for
  // lists of length 0, 1, or 2.
  const interior = chatParams.messages.slice(1, -1);
  return interior.map((msg) => ({
    role: msg.author,
    parts: msg.content,
  }));
}
30 changes: 9 additions & 21 deletions app/models/palm/dialog.ts → app/models/gemini/dialog.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import {DialogParams} from '@core/shared/interfaces';
import {DialogModel} from '../dialog_model';
import {callDialogModel, ModelParams} from './api';
import { callDialogModel, ModelParams } from './api';
import {createModelResults} from '../utils';

import {ContextService, StatusService} from '@services/services';
Expand All @@ -29,35 +29,23 @@ interface ServiceProvider {
statusService: StatusService;
}
/**
* A Model representing PaLM Dialog API.
* A Model representing Gemini API for chat.
*/
export class PalmDialogModel extends DialogModel {
export class GeminiDialogModel extends DialogModel {
constructor(serviceProvider: ServiceProvider) {
super(serviceProvider);
}

override async query(
params: DialogParams,
chatParams: DialogParams,
modelParams: Partial<ModelParams> = {}
) {
let temperature = (params as any).temperature;
temperature = temperature === undefined ? 0.7 : temperature;
console.log('🚀 DialogParams: ', JSON.stringify(chatParams));

const queryParams = {
...modelParams,
candidateCount: 1,
prompt: {
messages: params.messages,
},
temperature: temperature,
};

const res = await callDialogModel(queryParams);
const response = await res.json();
console.log('🚀 model results: ', response);

const responseText = response.candidates?.length
? response.candidates.map((candidate) => candidate.content)
const singleResponse = await callDialogModel(chatParams, modelParams);
console.log('🚀 model results: ', singleResponse);
const responseText = singleResponse.length
? [singleResponse]
: [];

const results = createModelResults(responseText);
Expand Down
41 changes: 22 additions & 19 deletions app/models/palm/index.ts → app/models/gemini/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ interface ServiceProvider {
}

/**
* A Model representing PaLM API.
* A Model representing Gemini API.
*/
export class PalmModel extends Model {
export class GeminiModel extends Model {
constructor(serviceProvider: ServiceProvider) {
super(serviceProvider);
}
Expand Down Expand Up @@ -125,26 +125,15 @@ export class PalmModel extends Model {
params: Partial<ModelParams> = {},
shouldParse = true
) {
let temperature = (params as any).temperature;
temperature = temperature === undefined ? 1 : temperature;

const modelParams = {
...params,
prompt: {
text: promptText,
},
temperature: temperature,
};

// candidateCount: setting is being rejected or ignored. workaround:
promptText = promptText + "\nGenerate 8 responses. " +
"Each response must start with: " + D0 + " and end with: " + D1;
console.log('🚀 prompt text: ', promptText);

const res = await callTextModel(modelParams);
const response = await res.json();
console.log('🚀 model results: ', response);
const res = await callTextModel(promptText, params);
console.log('🚀 model results: ', res);

const responseText = response.candidates?.length
? response.candidates.map((candidate) => candidate.output)
: [];
const responseText = getListOfReponses(res, D0, D1);

const results = createModelResults(responseText);
const output = shouldParse
Expand All @@ -169,3 +158,17 @@ export class PalmModel extends Model {
override rewriteSentence = this.makePromptHandler(rewriteSentence);
override suggestRewrite = this.makePromptHandler(suggestRewrite);
}

/**
 * Splits raw model output into a list of responses delimited by d0 … d1,
 * re-adding the delimiters around each extracted response.
 * (The name keeps the original "Reponses" typo because callers reference it.)
 *
 * @param txt Raw model output text.
 * @param d0 Opening delimiter (literal text, not a regex).
 * @param d1 Closing delimiter (literal text, not a regex).
 * @returns One delimiter-wrapped string per response; empty if none found.
 */
export function getListOfReponses(txt: string, d0: string, d1: string) {
  // Escape the delimiters so regex metacharacters (e.g. '(', '[', '*') are
  // treated literally instead of throwing or mismatching.
  const escape = (s: string) => s.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
  // 's' flag: '.' matches newlines, so multi-line responses are captured.
  const re = new RegExp(`(?<=${escape(d0)})(.*?)(?=${escape(d1)})`, 'gs');
  // match() returns null when there are no matches; the original crashed
  // iterating over null.
  const matches = txt.match(re) ?? [];
  // Re-add the stripped delimiters around each response.
  return matches.map((match) => d0 + match + d1);
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
import {ContinuePromptParams} from '@core/shared/interfaces';
import {ContinueExample, WordcraftContext} from '../../../context';
import {OperationType} from '@core/shared/types';
import {PalmModel} from '..';
import { GeminiModel } from '..';

export function makePromptHandler(model: PalmModel, context: WordcraftContext) {
export function makePromptHandler(model: GeminiModel, context: WordcraftContext) {
function generatePrompt(text: string) {
const prefix = model.getStoryPrefix();
const suffix = 'Continue the story: ';
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
import {ElaboratePromptParams} from '@core/shared/interfaces';
import {ElaborateExample, WordcraftContext} from '../../../context';
import {OperationType} from '@core/shared/types';
import {PalmModel} from '..';
import { GeminiModel } from '..';

export function makePromptHandler(model: PalmModel, context: WordcraftContext) {
export function makePromptHandler(model: GeminiModel, context: WordcraftContext) {
function generatePrompt(text: string, subject: string) {
const prefix = model.getStoryPrefix();
const suffix = `Describe "${subject}" in more detail.`;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@
import {FirstSentencePromptParams} from '@core/shared/interfaces';
import {FirstSentenceExample, WordcraftContext} from '../../../context';
import {OperationType} from '@core/shared/types';
import {PalmModel} from '..';
import { GeminiModel } from '..';
import {parseSentences} from '@lib/parse_sentences';

export function makePromptHandler(model: PalmModel, context: WordcraftContext) {
export function makePromptHandler(model: GeminiModel, context: WordcraftContext) {
function generatePrompt(textAfterBlank: string) {
const prefix = model.getStoryPrefix();
const suffix = 'Tell me the first sentence that fills in the blank: ';
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ import {shuffle} from '@lib/utils';
import {FreeformPromptParams} from '@core/shared/interfaces';
import {FreeformExample, WordcraftContext} from '../../../context';
import {OperationType} from '@core/shared/types';
import {PalmModel} from '..';
import { GeminiModel } from '..';

export function makePromptHandler(model: PalmModel, context: WordcraftContext) {
export function makePromptHandler(model: GeminiModel, context: WordcraftContext) {
function getPromptContext() {
const examples = context.getExampleData<FreeformExample>(
OperationType.FREEFORM
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ import {
WordcraftContext,
} from '../../../context';
import {OperationType} from '@core/shared/types';
import {PalmModel} from '..';
import { GeminiModel } from '..';

export function makePromptHandler(model: PalmModel, context: WordcraftContext) {
export function makePromptHandler(model: GeminiModel, context: WordcraftContext) {
function generatePrompt(
textBeforeBlank: string,
textAfterBlank: string,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ import {MetaPromptPromptParams} from '@core/shared/interfaces';
import {MetaPromptExample, WordcraftContext} from '../../../context';
import {OperationType} from '@core/shared/types';
import {endsWithPunctuation} from '@lib/parse_sentences/utils';
import {PalmModel} from '..';
import { GeminiModel } from '..';

export function makePromptHandler(model: PalmModel, context: WordcraftContext) {
export function makePromptHandler(model: GeminiModel, context: WordcraftContext) {
function generatePrompt(text: string) {
const prefix = model.getStoryPrefix();
const suffix = 'Next prompt:';
Expand Down
Loading

0 comments on commit 4ea64b0

Please sign in to comment.