Skip to content

Commit

Permalink
(Community): Adding Structured Support for ChatPerplexity (#29361)
Browse files Browse the repository at this point in the history
- **Description:** Adding Structured Support for ChatPerplexity
- **Issue:** #29357
- This is implemented as per the Perplexity official docs:
https://docs.perplexity.ai/guides/structured-outputs

---------

Co-authored-by: ccurme <chester.curme@gmail.com>
  • Loading branch information
keenborder786 and ccurme authored Feb 11, 2025
1 parent 994c546 commit 9f3bcee
Showing 1 changed file with 114 additions and 4 deletions.
118 changes: 114 additions & 4 deletions libs/community/langchain_community/chat_models/perplexity.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,23 @@
from __future__ import annotations

import logging
from operator import itemgetter
from typing import (
Any,
Dict,
Iterator,
List,
Literal,
Mapping,
Optional,
Tuple,
Type,
TypeVar,
Union,
)

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import (
BaseChatModel,
generate_from_stream,
Expand All @@ -34,17 +38,27 @@
SystemMessageChunk,
ToolMessageChunk,
)
from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.utils import (
from_env,
get_pydantic_field_names,
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.utils import from_env, get_pydantic_field_names
from langchain_core.utils.pydantic import (
is_basemodel_subclass,
)
from pydantic import ConfigDict, Field, model_validator
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter, model_validator
from typing_extensions import Self

_BM = TypeVar("_BM", bound=BaseModel)
_DictOrPydanticClass = Union[Dict[str, Any], Type[_BM], Type]
_DictOrPydantic = Union[Dict, _BM]

logger = logging.getLogger(__name__)


def _is_pydantic_class(obj: Any) -> bool:
return isinstance(obj, type) and is_basemodel_subclass(obj)


class ChatPerplexity(BaseChatModel):
"""`Perplexity AI` Chat models API.
Expand Down Expand Up @@ -282,3 +296,99 @@ def _invocation_params(self) -> Mapping[str, Any]:
def _llm_type(self) -> str:
"""Return type of chat model."""
return "perplexitychat"

def with_structured_output(
self,
schema: Optional[_DictOrPydanticClass] = None,
*,
method: Literal["json_schema"] = "json_schema",
include_raw: bool = False,
strict: Optional[bool] = None,
**kwargs: Any,
) -> Runnable[LanguageModelInput, _DictOrPydantic]:
"""Model wrapper that returns outputs formatted to match the given schema for Preplexity.
Currently, Preplexity only supports "json_schema" method for structured output
as per their official documentation: https://docs.perplexity.ai/guides/structured-outputs
Args:
schema:
The output schema. Can be passed in as:
- a JSON Schema,
- a TypedDict class,
- or a Pydantic class
method: The method for steering model generation, currently only support:
- "json_schema": Use the JSON Schema to parse the model output
include_raw:
If False then only the parsed structured output is returned. If
an error occurs during model output parsing it will be raised. If True
then both the raw model response (a BaseMessage) and the parsed model
response will be returned. If an error occurs during output parsing it
will be caught and returned as well. The final output is always a dict
with keys "raw", "parsed", and "parsing_error".
kwargs: Additional keyword args aren't supported.
Returns:
A Runnable that takes same inputs as a :class:`langchain_core.language_models.chat.BaseChatModel`.
| If ``include_raw`` is False and ``schema`` is a Pydantic class, Runnable outputs an instance of ``schema`` (i.e., a Pydantic object). Otherwise, if ``include_raw`` is False then Runnable outputs a dict.
| If ``include_raw`` is True, then Runnable outputs a dict with keys:
- "raw": BaseMessage
- "parsed": None if there was a parsing error, otherwise the type depends on the ``schema`` as described above.
- "parsing_error": Optional[BaseException]
""" # noqa: E501
if method == "json_schema":
if schema is None:
raise ValueError(
"schema must be specified when method is not 'json_schema'. "
"Received None."
)
is_pydantic_schema = _is_pydantic_class(schema)
if is_pydantic_schema and hasattr(
schema, "model_json_schema"
): # accounting for pydantic v1 and v2
response_format = schema.model_json_schema() # type: ignore[union-attr]
elif is_pydantic_schema:
response_format = schema.schema() # type: ignore[union-attr]
elif isinstance(schema, dict):
response_format = schema
elif type(schema).__name__ == "_TypedDictMeta":
adapter = TypeAdapter(schema) # if use passes typeddict
response_format = adapter.json_schema()

llm = self.bind(
response_format={
"type": "json_schema",
"json_schema": {"schema": response_format},
}
)
output_parser = (
PydanticOutputParser(pydantic_object=schema) # type: ignore[arg-type]
if is_pydantic_schema
else JsonOutputParser()
)
else:
raise ValueError(
f"Unrecognized method argument. Expected 'json_schema' Received:\
'{method}'"
)

if include_raw:
parser_assign = RunnablePassthrough.assign(
parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
)
parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
parser_with_fallback = parser_assign.with_fallbacks(
[parser_none], exception_key="parsing_error"
)
return RunnableMap(raw=llm) | parser_with_fallback
else:
return llm | output_parser

0 comments on commit 9f3bcee

Please sign in to comment.