-
Notifications
You must be signed in to change notification settings - Fork 290
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
4d9f803
commit 6f021e8
Showing
37 changed files
with
2,342 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,7 @@ | ||
[submodule "third_party/angular_components"] | ||
path = third_party/angular_components | ||
url = https://github.com/angular/components.git | ||
|
||
[submodule "ai/data"] | ||
path = ai/data | ||
url = git@hf.co:datasets/wwwillchen/mesop-data |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import re | ||
from typing import NamedTuple | ||
|
||
EDIT_HERE_MARKER = " # <--- EDIT HERE" | ||
|
||
|
||
class ApplyPatchResult(NamedTuple): | ||
has_error: bool | ||
result: str | ||
|
||
|
||
def apply_patch(original_code: str, patch: str) -> ApplyPatchResult: | ||
# Extract the diff content | ||
diff_pattern = r"<<<<<<< ORIGINAL(.*?)=======\n(.*?)>>>>>>> UPDATED" | ||
matches = re.findall(diff_pattern, patch, re.DOTALL) | ||
patched_code = original_code | ||
if len(matches) == 0: | ||
print("[WARN] No diff found:", patch) | ||
return ApplyPatchResult( | ||
True, | ||
"[AI-001] Sorry! AI output was mis-formatted. Please try again.", | ||
) | ||
for original, updated in matches: | ||
original = original.strip().replace(EDIT_HERE_MARKER, "") | ||
updated = updated.strip().replace(EDIT_HERE_MARKER, "") | ||
|
||
# Replace the original part with the updated part | ||
new_patched_code = patched_code.replace(original, updated, 1) | ||
if new_patched_code == patched_code: | ||
return ApplyPatchResult( | ||
True, | ||
"[AI-002] Sorry! AI output could not be used. Please try again.", | ||
) | ||
patched_code = new_patched_code | ||
|
||
return ApplyPatchResult(False, patched_code) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
import os | ||
from typing import Generic, TypeVar | ||
|
||
from pydantic import BaseModel | ||
|
||
T = TypeVar("T", bound=BaseModel) | ||
|
||
|
||
def get_data_path(dirname: str) -> str: | ||
return os.path.join( | ||
os.path.dirname(__file__), "..", "..", "..", "data", dirname | ||
) | ||
|
||
|
||
class EntityStore(Generic[T]): | ||
def __init__(self, entity_type: type[T], *, dirname: str): | ||
self.entity_type = entity_type | ||
self.directory_path = get_data_path(dirname) | ||
|
||
def get(self, id: str) -> T: | ||
file_path = os.path.join(self.directory_path, f"{id}.json") | ||
with open(file_path) as f: | ||
entity_json = f.read() | ||
entity = self.entity_type.model_validate_json(entity_json) | ||
return entity | ||
|
||
def get_all(self) -> list[T]: | ||
entities: list[T] = [] | ||
for filename in os.listdir(self.directory_path): | ||
if filename.endswith(".json"): | ||
file_path = os.path.join(self.directory_path, filename) | ||
with open(file_path) as f: | ||
entity_json = f.read() | ||
entities.append(self.entity_type.model_validate_json(entity_json)) | ||
entities.sort(key=lambda x: x.id, reverse=True) | ||
return entities | ||
|
||
def save(self, entity: T, overwrite: bool = False): | ||
id = entity.id # type: ignore | ||
entity_path = os.path.join(self.directory_path, f"{id}.json") | ||
if not overwrite and os.path.exists(entity_path): | ||
raise ValueError( | ||
f"{self.entity_type.__name__} with id {id} already exists" | ||
) | ||
with open(entity_path, "w") as f: | ||
f.write(entity.model_dump_json(indent=4)) | ||
|
||
def delete(self, entity_id: str): | ||
os.remove(os.path.join(self.directory_path, f"{entity_id}.json")) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
""" | ||
An example is a single input/output pair. | ||
- Examples are used for fine-tuning a model (i.e. golden example) or running an eval (i.e. expected example). | ||
- There are two types of examples: | ||
- **Golden Example**: A golden example is an example that is used to create a golden dataset. | ||
- **Expected Example**: An expected example is an example that is used to evaluate a producer. | ||
Internally, once an expected example has been run through an eval, we create an **evaluated example**, but you don't need to create this manually in the UI. | ||
""" | ||
|
||
import os | ||
import shutil | ||
from typing import Generic, Literal, TypeVar | ||
|
||
from pydantic import BaseModel | ||
|
||
|
||
class ExampleInput(BaseModel): | ||
prompt: str | ||
input_code: str | None = None | ||
line_number_target: int | None = None | ||
|
||
|
||
class BaseExample(BaseModel): | ||
id: str | ||
input: ExampleInput | ||
|
||
|
||
class ExampleOutput(BaseModel): | ||
output_code: str | None = None | ||
raw_output: str | None = None | ||
output_type: Literal["full", "diff"] = "diff" | ||
|
||
|
||
class ExpectedExample(BaseExample): | ||
expect_executable: bool = True | ||
expect_type_checkable: bool = True | ||
|
||
|
||
class ExpectResult(BaseModel): | ||
name: Literal["executable", "type_checkable", "patchable"] | ||
score: int # 0 or 1 | ||
message: str | None = None | ||
|
||
|
||
class EvaluatedExampleOutput(BaseModel): | ||
time_spent_secs: float | ||
tokens: int | ||
output: ExampleOutput | ||
expect_results: list[ExpectResult] | ||
|
||
|
||
class EvaluatedExample(BaseModel): | ||
expected: ExpectedExample | ||
outputs: list[EvaluatedExampleOutput] | ||
|
||
|
||
class GoldenExample(BaseExample): | ||
output: ExampleOutput | ||
|
||
|
||
T = TypeVar("T", bound=BaseExample) | ||
|
||
|
||
class ExampleStore(Generic[T]): | ||
def __init__(self, entity_type: type[T], *, dirname: str): | ||
self.entity_type = entity_type | ||
self.directory_path = os.path.join( | ||
os.path.dirname(__file__), "..", "..", "..", "data", dirname | ||
) | ||
|
||
def get(self, id: str) -> T: | ||
dir_path = os.path.join(self.directory_path, id) | ||
json_path = os.path.join(dir_path, "example_input.json") | ||
with open(json_path) as f: | ||
entity_json = f.read() | ||
entity = self.entity_type.model_validate_json(entity_json) | ||
input = entity.input | ||
input_py_path = os.path.join(dir_path, "input.py") | ||
if os.path.exists(input_py_path): | ||
with open(input_py_path) as f: | ||
input.input_code = f.read() | ||
if isinstance(entity, GoldenExample): | ||
output_py_path = os.path.join(dir_path, "output.py") | ||
if os.path.exists(output_py_path): | ||
with open(output_py_path) as f: | ||
entity.output.output_code = f.read() | ||
raw_output_path = os.path.join(dir_path, "raw_output.txt") | ||
if os.path.exists(raw_output_path): | ||
with open(raw_output_path) as f: | ||
entity.output.raw_output = f.read() | ||
return entity | ||
|
||
def get_all(self) -> list[T]: | ||
entities: list[T] = [] | ||
for filename in os.listdir(self.directory_path): | ||
entities.append(self.get(filename)) | ||
return entities | ||
|
||
def save(self, entity: T, overwrite: bool = False): | ||
id = entity.id | ||
dir_path = os.path.join(self.directory_path, id) | ||
|
||
if not overwrite: | ||
if os.path.exists(dir_path): | ||
raise ValueError( | ||
f"{self.entity_type.__name__} with id {id} already exists" | ||
) | ||
else: | ||
os.mkdir(dir_path) | ||
json_path = os.path.join(dir_path, "example_input.json") | ||
input_code = entity.input.input_code | ||
if input_code: | ||
input_py_path = os.path.join(dir_path, "input.py") | ||
with open(input_py_path, "w") as f: | ||
f.write(input_code) | ||
entity.input.input_code = None | ||
|
||
if isinstance(entity, GoldenExample): | ||
output_py_path = os.path.join(dir_path, "output.py") | ||
with open(output_py_path, "w") as f: | ||
f.write(entity.output.output_code) | ||
raw_output_path = os.path.join(dir_path, "raw_output.txt") | ||
with open(raw_output_path, "w") as f: | ||
f.write(entity.output.raw_output) | ||
entity.output.output_code = None | ||
entity.output.raw_output = None | ||
with open(json_path, "w") as f: | ||
f.write(entity.model_dump_json(indent=4)) | ||
|
||
def delete(self, entity_id: str): | ||
shutil.rmtree(os.path.join(self.directory_path, entity_id)) | ||
|
||
|
||
expected_example_store = ExampleStore( | ||
ExpectedExample, dirname="expected_examples" | ||
) | ||
golden_example_store = ExampleStore(GoldenExample, dirname="golden_examples") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
from os import getenv | ||
from typing import Iterator | ||
|
||
from openai import OpenAI | ||
from openai.types.chat import ( | ||
ChatCompletionMessageParam, | ||
) | ||
|
||
from ai.common.diff import EDIT_HERE_MARKER, ApplyPatchResult, apply_patch | ||
from ai.common.entity_store import get_data_path | ||
from ai.common.example import ExampleInput | ||
from ai.common.model import model_store | ||
from ai.common.producer import producer_store | ||
from ai.common.prompt_context import prompt_context_store | ||
from ai.common.prompt_fragment import PromptFragment, prompt_fragment_store | ||
|
||
|
||
class ProviderExecutor: | ||
def __init__(self, model_name: str, prompt_fragments: list[PromptFragment]): | ||
self.model_name = model_name | ||
|
||
self.prompt_fragments = [ | ||
PromptFragment( | ||
id=pf.id, | ||
role=pf.role, | ||
chain_of_thought=pf.chain_of_thought, | ||
content_value=get_content_value(pf), | ||
content_path=None, | ||
) | ||
for pf in prompt_fragments | ||
] | ||
|
||
def format_messages( | ||
self, input: ExampleInput | ||
) -> list[ChatCompletionMessageParam]: | ||
code = input.input_code or "" | ||
# Add sentinel token based on line_number (1-indexed) | ||
if input.line_number_target is not None: | ||
code_lines = code.splitlines() | ||
if 1 <= input.line_number_target <= len(code_lines): | ||
code_lines[input.line_number_target - 1] += EDIT_HERE_MARKER | ||
code = "\n".join(code_lines) | ||
|
||
return [ | ||
{ | ||
"role": pf.role, | ||
"content": pf.content_value.replace("<APP_CODE>", code).replace( # type: ignore | ||
"<APP_CHANGES>", input.prompt | ||
), | ||
} | ||
for pf in self.prompt_fragments | ||
] | ||
|
||
def execute(self, input: ExampleInput) -> str: ... | ||
|
||
def execute_stream(self, input: ExampleInput) -> Iterator[str]: ... | ||
|
||
|
||
class OpenaiExecutor(ProviderExecutor): | ||
def __init__(self, model_name: str, prompt_fragments: list[PromptFragment]): | ||
super().__init__(model_name, prompt_fragments) | ||
self.client = OpenAI( | ||
api_key=getenv("OPENAI_API_KEY"), | ||
) | ||
|
||
def execute(self, input: ExampleInput) -> str: | ||
response = self.client.chat.completions.create( | ||
model=self.model_name, | ||
max_tokens=10_000, | ||
messages=self.format_messages(input), | ||
) | ||
return response.choices[0].message.content or "" | ||
|
||
def execute_stream(self, input: ExampleInput) -> Iterator[str]: | ||
stream = self.client.chat.completions.create( | ||
model=self.model_name, | ||
max_tokens=10_000, | ||
messages=self.format_messages(input), | ||
stream=True, | ||
) | ||
for chunk in stream: | ||
content = chunk.choices[0].delta.content | ||
yield content or "" | ||
|
||
|
||
provider_executors: dict[str, type[ProviderExecutor]] = { | ||
"openai": OpenaiExecutor, | ||
} | ||
|
||
|
||
class ProducerExecutor: | ||
def __init__(self, producer_id: str): | ||
self.producer = producer_store.get(producer_id) | ||
|
||
def get_provider_executor(self) -> ProviderExecutor: | ||
prompt_context = prompt_context_store.get(self.producer.prompt_context_id) | ||
prompt_fragments = [ | ||
prompt_fragment_store.get(pfid) for pfid in prompt_context.fragment_ids | ||
] | ||
model = model_store.get(self.producer.mesop_model_id) | ||
provider_executor_type = provider_executors.get(model.provider) | ||
if provider_executor_type is None: | ||
raise ValueError(f"Provider {model.provider} not supported") | ||
provider_executor = provider_executor_type(model.name, prompt_fragments) | ||
return provider_executor | ||
|
||
def execute(self, input: ExampleInput): | ||
return self.get_provider_executor().execute(input) | ||
|
||
def execute_stream(self, input: ExampleInput): | ||
return self.get_provider_executor().execute_stream(input) | ||
|
||
def transform_output(self, input_code: str, output: str): | ||
if self.producer.output_format == "diff": | ||
return apply_patch(input_code, output) | ||
elif self.producer.output_format == "full": | ||
return ApplyPatchResult(True, output) | ||
else: | ||
raise ValueError(f"Unknown output format: {self.producer.output_format}") | ||
|
||
|
||
def get_content_value(pf: PromptFragment) -> str | None: | ||
if pf.content_value is not None: | ||
return pf.content_value | ||
if pf.content_path is not None: | ||
with open(get_data_path(pf.content_path.replace("//", ""))) as f: | ||
return f.read() | ||
return None |
Oops, something went wrong.