Skip to content

Commit

Permalink
Create MVP AI console (#934)
Browse files Browse the repository at this point in the history
  • Loading branch information
wwwillchen authored Sep 10, 2024
1 parent 4d9f803 commit 6f021e8
Show file tree
Hide file tree
Showing 37 changed files with 2,342 additions and 11 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ __pycache__

*.log

# This is a git submodule
/ai/data/

# Do not save generated files
/ai/ft/outputs/
/ai/outputs/
Expand Down
4 changes: 4 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
[submodule "third_party/angular_components"]
path = third_party/angular_components
url = https://github.com/angular/components.git

[submodule "ai/data"]
path = ai/data
url = git@hf.co:datasets/wwwillchen/mesop-data
18 changes: 18 additions & 0 deletions ai/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,24 @@ All the commands should be run from the `ai/` directory.
- All entry-points are in `src/*.py` - this includes the AI service and scripts.
- `src/common` contains code that's shared between offline scripts and the online service.

## AI Console

**Setup**:

```sh
git clone git@hf.co:datasets/wwwillchen/mesop-data data
```

**Running**:

Inside `ai/src/`, run the following command:

```sh
mesop console.py --port=32124
```

> Note: you can run this on a separate port to avoid conflicting with the main Mesop development app.
## Scripts

These are scripts used to generate and process data for offline evaluation.
Expand Down
36 changes: 36 additions & 0 deletions ai/src/ai/common/diff.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import re
from typing import NamedTuple

EDIT_HERE_MARKER = " # <--- EDIT HERE"


class ApplyPatchResult(NamedTuple):
has_error: bool
result: str


def apply_patch(original_code: str, patch: str) -> ApplyPatchResult:
# Extract the diff content
diff_pattern = r"<<<<<<< ORIGINAL(.*?)=======\n(.*?)>>>>>>> UPDATED"
matches = re.findall(diff_pattern, patch, re.DOTALL)
patched_code = original_code
if len(matches) == 0:
print("[WARN] No diff found:", patch)
return ApplyPatchResult(
True,
"[AI-001] Sorry! AI output was mis-formatted. Please try again.",
)
for original, updated in matches:
original = original.strip().replace(EDIT_HERE_MARKER, "")
updated = updated.strip().replace(EDIT_HERE_MARKER, "")

# Replace the original part with the updated part
new_patched_code = patched_code.replace(original, updated, 1)
if new_patched_code == patched_code:
return ApplyPatchResult(
True,
"[AI-002] Sorry! AI output could not be used. Please try again.",
)
patched_code = new_patched_code

return ApplyPatchResult(False, patched_code)
49 changes: 49 additions & 0 deletions ai/src/ai/common/entity_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import os
from typing import Generic, TypeVar

from pydantic import BaseModel

T = TypeVar("T", bound=BaseModel)


def get_data_path(dirname: str) -> str:
return os.path.join(
os.path.dirname(__file__), "..", "..", "..", "data", dirname
)


class EntityStore(Generic[T]):
def __init__(self, entity_type: type[T], *, dirname: str):
self.entity_type = entity_type
self.directory_path = get_data_path(dirname)

def get(self, id: str) -> T:
file_path = os.path.join(self.directory_path, f"{id}.json")
with open(file_path) as f:
entity_json = f.read()
entity = self.entity_type.model_validate_json(entity_json)
return entity

def get_all(self) -> list[T]:
entities: list[T] = []
for filename in os.listdir(self.directory_path):
if filename.endswith(".json"):
file_path = os.path.join(self.directory_path, filename)
with open(file_path) as f:
entity_json = f.read()
entities.append(self.entity_type.model_validate_json(entity_json))
entities.sort(key=lambda x: x.id, reverse=True)
return entities

def save(self, entity: T, overwrite: bool = False):
id = entity.id # type: ignore
entity_path = os.path.join(self.directory_path, f"{id}.json")
if not overwrite and os.path.exists(entity_path):
raise ValueError(
f"{self.entity_type.__name__} with id {id} already exists"
)
with open(entity_path, "w") as f:
f.write(entity.model_dump_json(indent=4))

def delete(self, entity_id: str):
os.remove(os.path.join(self.directory_path, f"{entity_id}.json"))
137 changes: 137 additions & 0 deletions ai/src/ai/common/example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
"""
An example is a single input/output pair.
- Examples are used for fine-tuning a model (i.e. golden example) or running an eval (i.e. expected example).
- There are two types of examples:
- **Golden Example**: A golden example is an example that is used to create a golden dataset.
- **Expected Example**: An expected example is an example that is used to evaluate a producer.
Internally, once an expected example has been run through an eval, we create an **evaluated example**, but you don't need to create this manually in the UI.
"""

import os
import shutil
from typing import Generic, Literal, TypeVar

from pydantic import BaseModel


class ExampleInput(BaseModel):
prompt: str
input_code: str | None = None
line_number_target: int | None = None


class BaseExample(BaseModel):
id: str
input: ExampleInput


class ExampleOutput(BaseModel):
output_code: str | None = None
raw_output: str | None = None
output_type: Literal["full", "diff"] = "diff"


class ExpectedExample(BaseExample):
expect_executable: bool = True
expect_type_checkable: bool = True


class ExpectResult(BaseModel):
name: Literal["executable", "type_checkable", "patchable"]
score: int # 0 or 1
message: str | None = None


class EvaluatedExampleOutput(BaseModel):
time_spent_secs: float
tokens: int
output: ExampleOutput
expect_results: list[ExpectResult]


class EvaluatedExample(BaseModel):
expected: ExpectedExample
outputs: list[EvaluatedExampleOutput]


class GoldenExample(BaseExample):
output: ExampleOutput


T = TypeVar("T", bound=BaseExample)


class ExampleStore(Generic[T]):
def __init__(self, entity_type: type[T], *, dirname: str):
self.entity_type = entity_type
self.directory_path = os.path.join(
os.path.dirname(__file__), "..", "..", "..", "data", dirname
)

def get(self, id: str) -> T:
dir_path = os.path.join(self.directory_path, id)
json_path = os.path.join(dir_path, "example_input.json")
with open(json_path) as f:
entity_json = f.read()
entity = self.entity_type.model_validate_json(entity_json)
input = entity.input
input_py_path = os.path.join(dir_path, "input.py")
if os.path.exists(input_py_path):
with open(input_py_path) as f:
input.input_code = f.read()
if isinstance(entity, GoldenExample):
output_py_path = os.path.join(dir_path, "output.py")
if os.path.exists(output_py_path):
with open(output_py_path) as f:
entity.output.output_code = f.read()
raw_output_path = os.path.join(dir_path, "raw_output.txt")
if os.path.exists(raw_output_path):
with open(raw_output_path) as f:
entity.output.raw_output = f.read()
return entity

def get_all(self) -> list[T]:
entities: list[T] = []
for filename in os.listdir(self.directory_path):
entities.append(self.get(filename))
return entities

def save(self, entity: T, overwrite: bool = False):
id = entity.id
dir_path = os.path.join(self.directory_path, id)

if not overwrite:
if os.path.exists(dir_path):
raise ValueError(
f"{self.entity_type.__name__} with id {id} already exists"
)
else:
os.mkdir(dir_path)
json_path = os.path.join(dir_path, "example_input.json")
input_code = entity.input.input_code
if input_code:
input_py_path = os.path.join(dir_path, "input.py")
with open(input_py_path, "w") as f:
f.write(input_code)
entity.input.input_code = None

if isinstance(entity, GoldenExample):
output_py_path = os.path.join(dir_path, "output.py")
with open(output_py_path, "w") as f:
f.write(entity.output.output_code)
raw_output_path = os.path.join(dir_path, "raw_output.txt")
with open(raw_output_path, "w") as f:
f.write(entity.output.raw_output)
entity.output.output_code = None
entity.output.raw_output = None
with open(json_path, "w") as f:
f.write(entity.model_dump_json(indent=4))

def delete(self, entity_id: str):
shutil.rmtree(os.path.join(self.directory_path, entity_id))


expected_example_store = ExampleStore(
ExpectedExample, dirname="expected_examples"
)
golden_example_store = ExampleStore(GoldenExample, dirname="golden_examples")
128 changes: 128 additions & 0 deletions ai/src/ai/common/executor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
from os import getenv
from typing import Iterator

from openai import OpenAI
from openai.types.chat import (
ChatCompletionMessageParam,
)

from ai.common.diff import EDIT_HERE_MARKER, ApplyPatchResult, apply_patch
from ai.common.entity_store import get_data_path
from ai.common.example import ExampleInput
from ai.common.model import model_store
from ai.common.producer import producer_store
from ai.common.prompt_context import prompt_context_store
from ai.common.prompt_fragment import PromptFragment, prompt_fragment_store


class ProviderExecutor:
def __init__(self, model_name: str, prompt_fragments: list[PromptFragment]):
self.model_name = model_name

self.prompt_fragments = [
PromptFragment(
id=pf.id,
role=pf.role,
chain_of_thought=pf.chain_of_thought,
content_value=get_content_value(pf),
content_path=None,
)
for pf in prompt_fragments
]

def format_messages(
self, input: ExampleInput
) -> list[ChatCompletionMessageParam]:
code = input.input_code or ""
# Add sentinel token based on line_number (1-indexed)
if input.line_number_target is not None:
code_lines = code.splitlines()
if 1 <= input.line_number_target <= len(code_lines):
code_lines[input.line_number_target - 1] += EDIT_HERE_MARKER
code = "\n".join(code_lines)

return [
{
"role": pf.role,
"content": pf.content_value.replace("<APP_CODE>", code).replace( # type: ignore
"<APP_CHANGES>", input.prompt
),
}
for pf in self.prompt_fragments
]

def execute(self, input: ExampleInput) -> str: ...

def execute_stream(self, input: ExampleInput) -> Iterator[str]: ...


class OpenaiExecutor(ProviderExecutor):
def __init__(self, model_name: str, prompt_fragments: list[PromptFragment]):
super().__init__(model_name, prompt_fragments)
self.client = OpenAI(
api_key=getenv("OPENAI_API_KEY"),
)

def execute(self, input: ExampleInput) -> str:
response = self.client.chat.completions.create(
model=self.model_name,
max_tokens=10_000,
messages=self.format_messages(input),
)
return response.choices[0].message.content or ""

def execute_stream(self, input: ExampleInput) -> Iterator[str]:
stream = self.client.chat.completions.create(
model=self.model_name,
max_tokens=10_000,
messages=self.format_messages(input),
stream=True,
)
for chunk in stream:
content = chunk.choices[0].delta.content
yield content or ""


provider_executors: dict[str, type[ProviderExecutor]] = {
"openai": OpenaiExecutor,
}


class ProducerExecutor:
def __init__(self, producer_id: str):
self.producer = producer_store.get(producer_id)

def get_provider_executor(self) -> ProviderExecutor:
prompt_context = prompt_context_store.get(self.producer.prompt_context_id)
prompt_fragments = [
prompt_fragment_store.get(pfid) for pfid in prompt_context.fragment_ids
]
model = model_store.get(self.producer.mesop_model_id)
provider_executor_type = provider_executors.get(model.provider)
if provider_executor_type is None:
raise ValueError(f"Provider {model.provider} not supported")
provider_executor = provider_executor_type(model.name, prompt_fragments)
return provider_executor

def execute(self, input: ExampleInput):
return self.get_provider_executor().execute(input)

def execute_stream(self, input: ExampleInput):
return self.get_provider_executor().execute_stream(input)

def transform_output(self, input_code: str, output: str):
if self.producer.output_format == "diff":
return apply_patch(input_code, output)
elif self.producer.output_format == "full":
return ApplyPatchResult(True, output)
else:
raise ValueError(f"Unknown output format: {self.producer.output_format}")


def get_content_value(pf: PromptFragment) -> str | None:
if pf.content_value is not None:
return pf.content_value
if pf.content_path is not None:
with open(get_data_path(pf.content_path.replace("//", ""))) as f:
return f.read()
return None
Loading

0 comments on commit 6f021e8

Please sign in to comment.