This repository has been archived by the owner on Oct 15, 2024. It is now read-only.

core[patch]: simple prompt pretty printing (langchain-ai#15968)
baskaryan authored Jan 13, 2024
1 parent 3f75fd4 commit bccb07f
Showing 6 changed files with 84 additions and 3 deletions.
20 changes: 20 additions & 0 deletions libs/core/langchain_core/messages/base.py
@@ -4,6 +4,8 @@

from langchain_core.load.serializable import Serializable
from langchain_core.pydantic_v1 import Extra, Field
from langchain_core.utils import get_bolded_text
from langchain_core.utils.interactive_env import is_interactive_env

if TYPE_CHECKING:
from langchain_core.prompts.chat import ChatPromptTemplate
@@ -42,6 +44,14 @@ def __add__(self, other: Any) -> ChatPromptTemplate:
prompt = ChatPromptTemplate(messages=[self])
return prompt + other

def pretty_repr(self, html: bool = False) -> str:
title = get_msg_title_repr(self.type.title() + " Message", bold=html)
# TODO: handle non-string content.
return f"{title}\n\n{self.content}"

def pretty_print(self) -> None:
print(self.pretty_repr(html=is_interactive_env()))


def merge_content(
first_content: Union[str, List[Union[str, Dict]]],
@@ -176,3 +186,13 @@ def messages_to_dict(messages: Sequence[BaseMessage]) -> List[dict]:
List of messages as dicts.
"""
return [message_to_dict(m) for m in messages]


def get_msg_title_repr(title: str, *, bold: bool = False) -> str:
padded = " " + title + " "
sep_len = (80 - len(padded)) // 2
sep = "=" * sep_len
second_sep = sep + "=" if len(padded) % 2 else sep
if bold:
padded = get_bolded_text(padded)
return f"{sep}{padded}{second_sep}"
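
A rough usage sketch of the new message-level helpers (illustrative only, not part of the diff; assumes a HumanMessage from langchain_core.messages):

from langchain_core.messages import HumanMessage

msg = HumanMessage(content="What is the capital of France?")
msg.pretty_print()
# In a plain (non-interactive) environment this prints an 80-character title
# bar built by get_msg_title_repr, followed by the content, roughly:
# ================================ Human Message =================================
#
# What is the capital of France?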
40 changes: 37 additions & 3 deletions libs/core/langchain_core/prompts/chat.py
@@ -28,11 +28,14 @@
HumanMessage,
SystemMessage,
)
from langchain_core.messages.base import get_msg_title_repr
from langchain_core.prompt_values import ChatPromptValue, PromptValue
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.prompts.string import StringPromptTemplate
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.utils import get_colored_text
from langchain_core.utils.interactive_env import is_interactive_env


class BaseMessagePromptTemplate(Serializable, ABC):
@@ -68,6 +71,13 @@ def input_variables(self) -> List[str]:
List of input variables.
"""

def pretty_repr(self, html: bool = False) -> str:
"""Human-readable representation."""
raise NotImplementedError

def pretty_print(self) -> None:
print(self.pretty_repr(html=is_interactive_env()))

def __add__(self, other: Any) -> ChatPromptTemplate:
"""Combine two prompt templates.
@@ -95,9 +105,7 @@ def get_lc_namespace(cls) -> List[str]:
return ["langchain", "prompts", "chat"]

def __init__(self, variable_name: str, *, optional: bool = False, **kwargs: Any):
return super().__init__(
variable_name=variable_name, optional=optional, **kwargs
)
super().__init__(variable_name=variable_name, optional=optional, **kwargs)

def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""Format messages from kwargs.
@@ -135,6 +143,15 @@ def input_variables(self) -> List[str]:
"""
return [self.variable_name] if not self.optional else []

def pretty_repr(self, html: bool = False) -> str:
var = "{" + self.variable_name + "}"
if html:
title = get_msg_title_repr("Messages Placeholder", bold=True)
var = get_colored_text(var, "yellow")
else:
title = get_msg_title_repr("Messages Placeholder")
return f"{title}\n\n{var}"


MessagePromptTemplateT = TypeVar(
"MessagePromptTemplateT", bound="BaseStringMessagePromptTemplate"
@@ -237,6 +254,12 @@ def input_variables(self) -> List[str]:
"""
return self.prompt.input_variables

def pretty_repr(self, html: bool = False) -> str:
# TODO: Handle partials
title = self.__class__.__name__.replace("MessagePromptTemplate", " Message")
title = get_msg_title_repr(title, bold=html)
return f"{title}\n\n{self.prompt.pretty_repr(html=html)}"


class ChatMessagePromptTemplate(BaseStringMessagePromptTemplate):
"""Chat message prompt template."""
@@ -369,6 +392,13 @@ def format_prompt(self, **kwargs: Any) -> PromptValue:
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""Format kwargs into a list of messages."""

def pretty_repr(self, html: bool = False) -> str:
"""Human-readable representation."""
raise NotImplementedError

def pretty_print(self) -> None:
print(self.pretty_repr(html=is_interactive_env()))


MessageLike = Union[BaseMessagePromptTemplate, BaseMessage, BaseChatPromptTemplate]

@@ -701,6 +731,10 @@ def save(self, file_path: Union[Path, str]) -> None:
"""
raise NotImplementedError()

def pretty_repr(self, html: bool = False) -> str:
# TODO: handle partials
return "\n\n".join(msg.pretty_repr(html=html) for msg in self.messages)


def _create_template_from_message_type(
message_type: str, template: str
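
A sketch of how the template-level pretty printing composes the per-message representations (illustrative usage, not part of the diff):

from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder

prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You are a helpful assistant named {name}."),
        MessagesPlaceholder(variable_name="history"),
        ("human", "{question}"),
    ]
)
# pretty_repr joins each message template's pretty_repr with blank lines;
# template variables stay visible as {name} / {question}, and the placeholder
# renders under a "Messages Placeholder" title with {history}.
prompt.pretty_print()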
3 changes: 3 additions & 0 deletions libs/core/langchain_core/prompts/few_shot.py
@@ -340,3 +340,6 @@ def format(self, **kwargs: Any) -> str:
"""
messages = self.format_messages(**kwargs)
return get_buffer_string(messages)

def pretty_repr(self, html: bool = False) -> str:
raise NotImplementedError()
16 changes: 16 additions & 0 deletions libs/core/langchain_core/prompts/string.py
@@ -8,7 +8,9 @@

from langchain_core.prompt_values import PromptValue, StringPromptValue
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.utils import get_colored_text
from langchain_core.utils.formatting import formatter
from langchain_core.utils.interactive_env import is_interactive_env


def jinja2_formatter(template: str, **kwargs: Any) -> str:
Expand Down Expand Up @@ -159,3 +161,17 @@ def get_lc_namespace(cls) -> List[str]:
def format_prompt(self, **kwargs: Any) -> PromptValue:
"""Create Chat Messages."""
return StringPromptValue(text=self.format(**kwargs))

def pretty_repr(self, html: bool = False) -> str:
# TODO: handle partials
dummy_vars = {
input_var: "{" + f"{input_var}" + "}" for input_var in self.input_variables
}
if html:
dummy_vars = {
k: get_colored_text(v, "yellow") for k, v in dummy_vars.items()
}
return self.format(**dummy_vars)

def pretty_print(self) -> None:
print(self.pretty_repr(html=is_interactive_env()))
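
For string prompt templates, pretty_repr formats the template with each input variable substituted by itself in braces, so placeholders stay visible. A minimal sketch (illustrative, not part of the diff):

from langchain_core.prompts import PromptTemplate

template = PromptTemplate.from_template("Tell me a {adjective} joke about {topic}.")
# dummy_vars maps "adjective" -> "{adjective}" and "topic" -> "{topic}", so the
# output is the template text itself: Tell me a {adjective} joke about {topic}.
print(template.pretty_repr())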
5 changes: 5 additions & 0 deletions libs/core/langchain_core/utils/interactive_env.py
@@ -0,0 +1,5 @@
def is_interactive_env() -> bool:
"""Determine if running within IPython or Jupyter."""
import sys

return hasattr(sys, "ps2")
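
The interactive check just keys off the REPL's secondary prompt; a quick illustration (not part of the diff):

import sys

# sys.ps2 (the "..." continuation prompt) only exists in interactive sessions,
# so this is False when run as a script and True in a REPL (and, per the
# docstring above, in IPython/Jupyter).
print(hasattr(sys, "ps2"))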
@@ -101,3 +101,6 @@ def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
messages += historical_messages
messages.append(input_message)
return messages

def pretty_repr(self, html: bool = False) -> str:
raise NotImplementedError
