From 9f3bcee30afb9d29f339878634e7fe44fbe2ca2b Mon Sep 17 00:00:00 2001 From: Mohammad Mohtashim <45242107+keenborder786@users.noreply.github.com> Date: Wed, 12 Feb 2025 04:51:18 +0500 Subject: [PATCH] (Community): Adding Structured Support for ChatPerplexity (#29361) - **Description:** Adding Structured Support for ChatPerplexity - **Issue:** #29357 - This is implemented as per the Perplexity official docs: https://docs.perplexity.ai/guides/structured-outputs --------- Co-authored-by: ccurme --- .../chat_models/perplexity.py | 118 +++++++++++++++++- 1 file changed, 114 insertions(+), 4 deletions(-) diff --git a/libs/community/langchain_community/chat_models/perplexity.py b/libs/community/langchain_community/chat_models/perplexity.py index d2decb38df93b..0e5ea61344787 100644 --- a/libs/community/langchain_community/chat_models/perplexity.py +++ b/libs/community/langchain_community/chat_models/perplexity.py @@ -3,19 +3,23 @@ from __future__ import annotations import logging +from operator import itemgetter from typing import ( Any, Dict, Iterator, List, + Literal, Mapping, Optional, Tuple, Type, + TypeVar, Union, ) from langchain_core.callbacks import CallbackManagerForLLMRun +from langchain_core.language_models import LanguageModelInput from langchain_core.language_models.chat_models import ( BaseChatModel, generate_from_stream, @@ -34,17 +38,27 @@ SystemMessageChunk, ToolMessageChunk, ) +from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult -from langchain_core.utils import ( - from_env, - get_pydantic_field_names, +from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough +from langchain_core.utils import from_env, get_pydantic_field_names +from langchain_core.utils.pydantic import ( + is_basemodel_subclass, ) -from pydantic import ConfigDict, Field, model_validator +from pydantic import BaseModel, ConfigDict, Field, TypeAdapter, model_validator from typing_extensions import Self +_BM = TypeVar("_BM", bound=BaseModel) +_DictOrPydanticClass = Union[Dict[str, Any], Type[_BM], Type] +_DictOrPydantic = Union[Dict, _BM] + logger = logging.getLogger(__name__) +def _is_pydantic_class(obj: Any) -> bool: + return isinstance(obj, type) and is_basemodel_subclass(obj) + + class ChatPerplexity(BaseChatModel): """`Perplexity AI` Chat models API. @@ -282,3 +296,99 @@ def _invocation_params(self) -> Mapping[str, Any]: def _llm_type(self) -> str: """Return type of chat model.""" return "perplexitychat" + + def with_structured_output( + self, + schema: Optional[_DictOrPydanticClass] = None, + *, + method: Literal["json_schema"] = "json_schema", + include_raw: bool = False, + strict: Optional[bool] = None, + **kwargs: Any, + ) -> Runnable[LanguageModelInput, _DictOrPydantic]: + """Model wrapper that returns outputs formatted to match the given schema for Preplexity. + Currently, Preplexity only supports "json_schema" method for structured output + as per their official documentation: https://docs.perplexity.ai/guides/structured-outputs + + Args: + schema: + The output schema. Can be passed in as: + + - a JSON Schema, + - a TypedDict class, + - or a Pydantic class + + method: The method for steering model generation, currently only support: + + - "json_schema": Use the JSON Schema to parse the model output + + + include_raw: + If False then only the parsed structured output is returned. If + an error occurs during model output parsing it will be raised. If True + then both the raw model response (a BaseMessage) and the parsed model + response will be returned. If an error occurs during output parsing it + will be caught and returned as well. The final output is always a dict + with keys "raw", "parsed", and "parsing_error". + + kwargs: Additional keyword args aren't supported. + + Returns: + A Runnable that takes same inputs as a :class:`langchain_core.language_models.chat.BaseChatModel`. + + | If ``include_raw`` is False and ``schema`` is a Pydantic class, Runnable outputs an instance of ``schema`` (i.e., a Pydantic object). Otherwise, if ``include_raw`` is False then Runnable outputs a dict. + + | If ``include_raw`` is True, then Runnable outputs a dict with keys: + + - "raw": BaseMessage + - "parsed": None if there was a parsing error, otherwise the type depends on the ``schema`` as described above. + - "parsing_error": Optional[BaseException] + + """ # noqa: E501 + if method == "json_schema": + if schema is None: + raise ValueError( + "schema must be specified when method is not 'json_schema'. " + "Received None." + ) + is_pydantic_schema = _is_pydantic_class(schema) + if is_pydantic_schema and hasattr( + schema, "model_json_schema" + ): # accounting for pydantic v1 and v2 + response_format = schema.model_json_schema() # type: ignore[union-attr] + elif is_pydantic_schema: + response_format = schema.schema() # type: ignore[union-attr] + elif isinstance(schema, dict): + response_format = schema + elif type(schema).__name__ == "_TypedDictMeta": + adapter = TypeAdapter(schema) # if use passes typeddict + response_format = adapter.json_schema() + + llm = self.bind( + response_format={ + "type": "json_schema", + "json_schema": {"schema": response_format}, + } + ) + output_parser = ( + PydanticOutputParser(pydantic_object=schema) # type: ignore[arg-type] + if is_pydantic_schema + else JsonOutputParser() + ) + else: + raise ValueError( + f"Unrecognized method argument. Expected 'json_schema' Received:\ + '{method}'" + ) + + if include_raw: + parser_assign = RunnablePassthrough.assign( + parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None + ) + parser_none = RunnablePassthrough.assign(parsed=lambda _: None) + parser_with_fallback = parser_assign.with_fallbacks( + [parser_none], exception_key="parsing_error" + ) + return RunnableMap(raw=llm) | parser_with_fallback + else: + return llm | output_parser