Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BaseChatInterface: get_canonical_template_keys as a default template keys list #162

Merged
merged 1 commit into from
Feb 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions esbmc_ai/chats/base_chat_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
"""Contains code for the base class for interacting with the LLMs in a
conversation-based way."""

from abc import abstractmethod
from typing import Optional
from typing import Any, Optional

from langchain.schema import (
BaseMessage,
Expand Down Expand Up @@ -33,10 +32,9 @@ def __init__(
self.messages: list[BaseMessage] = []
self.llm: BaseChatModel = llm

@abstractmethod
def compress_message_stack(self) -> None:
"""Compress the message stack, is abstract and needs to be implemented."""
raise NotImplementedError()
self.messages = []

def push_to_message_stack(
self,
Expand All @@ -48,6 +46,17 @@ def push_to_message_stack(
else:
self.messages.append(message)

def get_canonical_template_keys(
self, source_code: str, esbmc_output: str, error_line: str, error_type: str
) -> dict[str, Any]:
"""Gets the canonical template keys for applying in template values."""
return {
source_code: source_code,
esbmc_output: esbmc_output,
error_line: error_line,
error_type: error_type,
}

def apply_template_value(self, **kwargs: str) -> None:
"""Will substitute an f-string in the message stack and system messages to
the provided value. The new substituted messages will become the new
Expand Down
10 changes: 6 additions & 4 deletions esbmc_ai/chats/solution_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,10 +220,12 @@ def generate_solution(

# Apply template substitution to message stack
self.apply_template_value(
source_code=self.source_code_formatted,
esbmc_output=self.esbmc_output,
error_line=str(self.verifier.get_error_line(self.esbmc_output)),
error_type=error_type if error_type else "unknown error",
*self.get_canonical_template_keys(
source_code=self.source_code_formatted,
esbmc_output=self.esbmc_output,
error_line=str(self.verifier.get_error_line(self.esbmc_output)),
error_type=error_type if error_type else "unknown error",
)
)

# Generate the solution
Expand Down
18 changes: 9 additions & 9 deletions esbmc_ai/chats/user_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@


from esbmc_ai.ai_models import AIModel
from esbmc_ai.solution import Solution
from esbmc_ai.verifiers.base_source_verifier import BaseSourceVerifier

from .base_chat_interface import BaseChatInterface
Expand All @@ -21,14 +22,12 @@ class UserChat(BaseChatInterface):
"""Simple interface that talks to the LLM and stores the result. The class
also stores the fixed results from fix code command."""

solution: str = ""

def __init__(
self,
ai_model: AIModel,
llm: BaseChatModel,
verifier: BaseSourceVerifier,
source_code: str,
solution: Solution,
esbmc_output: str,
system_messages: list[BaseMessage],
set_solution_messages: list[BaseMessage],
Expand All @@ -38,20 +37,21 @@ def __init__(
ai_model=ai_model,
llm=llm,
)

# Store source code and esbmc output in order to substitute it into the message stack.
self.source_code: str = source_code
self.solution: Solution = solution
self.esbmc_output: str = esbmc_output
# The messsages for setting a new solution to the source code.
self.set_solution_messages = set_solution_messages

error_type: Optional[str] = verifier.get_error_type(self.esbmc_output)

self.apply_template_value(
source_code=self.source_code,
esbmc_output=self.esbmc_output,
error_line=str(verifier.get_error_line(self.esbmc_output)),
error_type=error_type if error_type else "unknown error",
*self.get_canonical_template_keys(
source_code=self.solution.files[0].content,
esbmc_output=self.esbmc_output,
error_line=str(verifier.get_error_line(self.esbmc_output)),
error_type=error_type if error_type else "unknown error",
)
)

def set_solution(self, source_code: str) -> None:
Expand Down
Loading