diff --git a/libs/community/langchain_community/chat_models/__init__.py b/libs/community/langchain_community/chat_models/__init__.py index 9eddf4f324247..a4a0e96c9653c 100644 --- a/libs/community/langchain_community/chat_models/__init__.py +++ b/libs/community/langchain_community/chat_models/__init__.py @@ -46,6 +46,7 @@ from langchain_community.chat_models.openai import ChatOpenAI from langchain_community.chat_models.pai_eas_endpoint import PaiEasChatEndpoint from langchain_community.chat_models.promptlayer_openai import PromptLayerChatOpenAI +from langchain_community.chat_models.tongyi import ChatTongyi from langchain_community.chat_models.vertexai import ChatVertexAI from langchain_community.chat_models.volcengine_maas import VolcEngineMaasChat from langchain_community.chat_models.yandex import ChatYandexGPT @@ -76,6 +77,7 @@ "ChatKonko", "PaiEasChatEndpoint", "QianfanChatEndpoint", + "ChatTongyi", "ChatFireworks", "ChatYandexGPT", "ChatBaichuan", diff --git a/libs/community/langchain_community/chat_models/tongyi.py b/libs/community/langchain_community/chat_models/tongyi.py index 70119fbfff6ee..5e4f4c1ab901d 100644 --- a/libs/community/langchain_community/chat_models/tongyi.py +++ b/libs/community/langchain_community/chat_models/tongyi.py @@ -1,23 +1,25 @@ from __future__ import annotations +import asyncio +import functools import logging from typing import ( Any, + AsyncIterator, Callable, Dict, Iterator, List, Mapping, Optional, - Tuple, - Type, + Union, ) -from langchain_core.callbacks import CallbackManagerForLLMRun -from langchain_core.language_models.chat_models import ( - BaseChatModel, - generate_from_stream, +from langchain_core.callbacks import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, ) +from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.messages import ( AIMessage, AIMessageChunk, @@ -25,8 +27,6 @@ BaseMessageChunk, ChatMessage, ChatMessageChunk, - FunctionMessage, - FunctionMessageChunk, HumanMessage, HumanMessageChunk, SystemMessage, @@ -36,41 +36,63 @@ ChatGeneration, ChatGenerationChunk, ChatResult, - GenerationChunk, ) from langchain_core.pydantic_v1 import Field, root_validator from langchain_core.utils import get_from_dict_or_env from requests.exceptions import HTTPError from tenacity import ( - RetryCallState, + before_sleep_log, retry, retry_if_exception_type, stop_after_attempt, wait_exponential, ) -logger = logging.getLogger(__name__) +from langchain_community.llms.tongyi import check_response +logger = logging.getLogger(__name__) -def convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage: - """Convert a dict to a message.""" +def convert_dict_to_message( + _dict: Mapping[str, Any], is_chunk: bool = False +) -> Union[BaseMessage, BaseMessageChunk]: role = _dict["role"] + content = _dict["content"] if role == "user": - return HumanMessage(content=_dict["content"]) + return ( + HumanMessageChunk(content=content) + if is_chunk + else HumanMessage(content=content) + ) elif role == "assistant": - content = _dict.get("content", "") or "" - if _dict.get("function_call"): - additional_kwargs = {"function_call": dict(_dict["function_call"])} - else: - additional_kwargs = {} - return AIMessage(content=content, additional_kwargs=additional_kwargs) + return ( + AIMessageChunk(content=content) if is_chunk else AIMessage(content=content) + ) elif role == "system": - return SystemMessage(content=_dict["content"]) - elif role == "function": - return FunctionMessage(content=_dict["content"], name=_dict["name"]) + return ( + SystemMessageChunk(content=content) + if is_chunk + else SystemMessage(content=content) + ) + else: + return ( + ChatMessageChunk(role=role, content=content) + if is_chunk + else ChatMessage(role=role, content=content) + ) + + +def convert_message_chunk_to_message(message_chunk: BaseMessageChunk) -> BaseMessage: + if isinstance(message_chunk, HumanMessageChunk): + return HumanMessage(content=message_chunk.content) + elif isinstance(message_chunk, AIMessageChunk): + return AIMessage(content=message_chunk.content) + elif isinstance(message_chunk, SystemMessageChunk): + return SystemMessage(content=message_chunk.content) + elif isinstance(message_chunk, ChatMessageChunk): + return ChatMessage(role=message_chunk.role, content=message_chunk.content) else: - return ChatMessage(content=_dict["content"], role=role) + raise TypeError(f"Got unknown type {message_chunk}") def convert_message_to_dict(message: BaseMessage) -> dict: @@ -83,109 +105,27 @@ def convert_message_to_dict(message: BaseMessage) -> dict: message_dict = {"role": "user", "content": message.content} elif isinstance(message, AIMessage): message_dict = {"role": "assistant", "content": message.content} - if "function_call" in message.additional_kwargs: - message_dict["function_call"] = message.additional_kwargs["function_call"] - # If function call only, content is None not empty string - if message_dict["content"] == "": - message_dict["content"] = None elif isinstance(message, SystemMessage): message_dict = {"role": "system", "content": message.content} - elif isinstance(message, FunctionMessage): - message_dict = { - "role": "function", - "content": message.content, - "name": message.name, - } else: raise TypeError(f"Got unknown type {message}") - if "name" in message.additional_kwargs: - message_dict["name"] = message.additional_kwargs["name"] return message_dict -def _stream_response_to_generation_chunk( - stream_response: Dict[str, Any], - length: int, -) -> GenerationChunk: - """Convert a stream response to a generation chunk. - - As the low level API implement is different from openai and other llm. - Stream response of Tongyi is not split into chunks, but all data generated before. - For example, the answer 'Hi Pickle Rick! How can I assist you today?' - Other llm will stream answer: - 'Hi Pickle', - ' Rick!', - ' How can I assist you today?'. - - Tongyi answer: - 'Hi Pickle', - 'Hi Pickle Rick!', - 'Hi Pickle Rick! How can I assist you today?'. - - As the GenerationChunk is implemented with chunks. Only return full_text[length:] - for new chunk. - """ - full_text = stream_response["output"]["text"] - text = full_text[length:] - finish_reason = stream_response["output"].get("finish_reason", None) - - return GenerationChunk( - text=text, - generation_info=dict( - finish_reason=finish_reason, - ), - ) - - -def _create_retry_decorator( - llm: ChatTongyi, - run_manager: Optional[CallbackManagerForLLMRun] = None, -) -> Callable[[Any], Any]: - def _before_sleep(retry_state: RetryCallState) -> None: - if run_manager: - run_manager.on_retry(retry_state) - return None - +def _create_retry_decorator(llm: ChatTongyi) -> Callable[[Any], Any]: min_seconds = 1 max_seconds = 4 # Wait 2^x * 1 second between each retry starting with - # 4 seconds, then up to 10 seconds, then 10 seconds afterwards + # 4 seconds, then up to 10 seconds, then 10 seconds afterward return retry( reraise=True, stop=stop_after_attempt(llm.max_retries), wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds), retry=(retry_if_exception_type(HTTPError)), - before_sleep=_before_sleep, + before_sleep=before_sleep_log(logger, logging.WARNING), ) -def _convert_delta_to_message_chunk( - _dict: Mapping[str, Any], - default_class: Type[BaseMessageChunk], - length: int, -) -> BaseMessageChunk: - role = _dict.get("role") - full_content = _dict.get("content") or "" - content = full_content[length:] - if _dict.get("function_call"): - additional_kwargs = {"function_call": dict(_dict["function_call"])} - else: - additional_kwargs = {} - - if role == "user" or default_class == HumanMessageChunk: - return HumanMessageChunk(content=content) - elif role == "assistant" or default_class == AIMessageChunk: - return AIMessageChunk(content=content, additional_kwargs=additional_kwargs) - elif role == "system" or default_class == SystemMessageChunk: - return SystemMessageChunk(content=content) - elif role == "function" or default_class == FunctionMessageChunk: - return FunctionMessageChunk(content=content, name=_dict["name"]) - elif role or default_class == ChatMessageChunk: - return ChatMessageChunk(content=content, role=role) - else: - return default_class(content=content) - - class ChatTongyi(BaseChatModel): """Alibaba Tongyi Qwen chat models API. @@ -204,10 +144,6 @@ class ChatTongyi(BaseChatModel): def lc_secrets(self) -> Dict[str, str]: return {"dashscope_api_key": "DASHSCOPE_API_KEY"} - @property - def lc_serializable(self) -> bool: - return True - client: Any #: :meta private: model_name: str = Field(default="qwen-turbo", alias="model") @@ -218,10 +154,7 @@ def lc_serializable(self) -> bool: """Total probability mass of tokens to consider at each step.""" dashscope_api_key: Optional[str] = None - """Dashscope api key provide by alicloud.""" - - n: int = 1 - """How many completions to generate for each prompt.""" + """Dashscope api key provide by Alibaba Cloud.""" streaming: bool = False """Whether to stream the results or not.""" @@ -229,12 +162,6 @@ def lc_serializable(self) -> bool: max_retries: int = 10 """Maximum number of retries to make when generating.""" - prefix_messages: List = Field(default_factory=list) - """Series of messages for Chat input.""" - - result_format: str = Field(default="message") - """Return result format""" - @property def _llm_type(self) -> str: """Return type of llm.""" @@ -243,7 +170,9 @@ def _llm_type(self) -> str: @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" - get_from_dict_or_env(values, "dashscope_api_key", "DASHSCOPE_API_KEY") + values["dashscope_api_key"] = get_from_dict_or_env( + values, "dashscope_api_key", "DASHSCOPE_API_KEY" + ) try: import dashscope except ImportError: @@ -264,81 +193,141 @@ def validate_environment(cls, values: Dict) -> Dict: @property def _default_params(self) -> Dict[str, Any]: - """Get the default parameters for calling OpenAI API.""" + """Get the default parameters for calling Tongyi Qwen API.""" return { "model": self.model_name, "top_p": self.top_p, - "stream": self.streaming, - "n": self.n, - "result_format": self.result_format, + "api_key": self.dashscope_api_key, + "result_format": "message", **self.model_kwargs, } - def completion_with_retry( - self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any - ) -> Any: + def completion_with_retry(self, **kwargs: Any) -> Any: """Use tenacity to retry the completion call.""" - retry_decorator = _create_retry_decorator(self, run_manager=run_manager) + retry_decorator = _create_retry_decorator(self) @retry_decorator def _completion_with_retry(**_kwargs: Any) -> Any: resp = self.client.call(**_kwargs) - if resp.status_code == 200: - return resp - elif resp.status_code in [400, 401]: - raise ValueError( - f"status_code: {resp.status_code} \n " - f"code: {resp.code} \n message: {resp.message}" - ) - else: - raise HTTPError( - f"HTTP error occurred: status_code: {resp.status_code} \n " - f"code: {resp.code} \n message: {resp.message}", - response=resp, - ) + return check_response(resp) return _completion_with_retry(**kwargs) - def stream_completion_with_retry( - self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any - ) -> Any: + def stream_completion_with_retry(self, **kwargs: Any) -> Any: """Use tenacity to retry the completion call.""" - retry_decorator = _create_retry_decorator(self, run_manager=run_manager) + retry_decorator = _create_retry_decorator(self) @retry_decorator def _stream_completion_with_retry(**_kwargs: Any) -> Any: - return self.client.call(**_kwargs) + responses = self.client.call(**_kwargs) + for resp in responses: + yield check_response(resp) return _stream_completion_with_retry(**kwargs) + async def astream_completion_with_retry(self, **kwargs: Any) -> Any: + """Because the dashscope SDK doesn't provide an async API, + we wrap `stream_generate_with_retry` with an async generator.""" + + class _AioTongyiGenerator: + def __init__(self, generator: Any): + self.generator = generator + + def __aiter__(self) -> AsyncIterator[Any]: + return self + + async def __anext__(self) -> Any: + value = await asyncio.get_running_loop().run_in_executor( + None, self._safe_next + ) + if value is not None: + return value + else: + raise StopAsyncIteration + + def _safe_next(self) -> Any: + try: + return next(self.generator) + except StopIteration: + return None + + async for chunk in _AioTongyiGenerator( + generator=self.stream_completion_with_retry(**kwargs) + ): + yield chunk + def _generate( self, messages: List[BaseMessage], stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, - stream: Optional[bool] = None, **kwargs: Any, ) -> ChatResult: - should_stream = stream if stream is not None else self.streaming - if should_stream: - stream_iter = self._stream( + generations = [] + if self.streaming: + generation: Optional[ChatGenerationChunk] = None + for chunk in self._stream( messages, stop=stop, run_manager=run_manager, **kwargs + ): + if generation is None: + generation = chunk + else: + generation += chunk + assert generation is not None + generations.append(self._chunk_to_generation(generation)) + else: + params: Dict[str, Any] = self._invocation_params( + messages=messages, stop=stop, **kwargs ) - return generate_from_stream(stream_iter) - - if not messages: - raise ValueError("No messages provided.") - - message_dicts, params = self._create_message_dicts(messages, stop) - - if message_dicts[-1]["role"] != "user": - raise ValueError("Last message should be user message.") + resp = self.completion_with_retry(**params) + generations.append( + ChatGeneration(**self._chat_generation_from_qwen_resp(resp)) + ) + return ChatResult( + generations=generations, + llm_output={ + "model_name": self.model_name, + }, + ) - params = {**params, **kwargs} - response = self.completion_with_retry( - messages=message_dicts, run_manager=run_manager, **params + async def _agenerate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + generations = [] + if self.streaming: + generation: Optional[ChatGenerationChunk] = None + async for chunk in self._astream( + messages, stop=stop, run_manager=run_manager, **kwargs + ): + if generation is None: + generation = chunk + else: + generation += chunk + assert generation is not None + generations.append(self._chunk_to_generation(generation)) + else: + params: Dict[str, Any] = self._invocation_params( + messages=messages, stop=stop, **kwargs + ) + resp = await asyncio.get_running_loop().run_in_executor( + None, + functools.partial( + self.completion_with_retry, **{"run_manager": run_manager, **params} + ), + ) + generations.append( + ChatGeneration(**self._chat_generation_from_qwen_resp(resp)) + ) + return ChatResult( + generations=generations, + llm_output={ + "model_name": self.model_name, + }, ) - return self._create_chat_result(response) def _stream( self, @@ -347,62 +336,83 @@ def _stream( run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: - message_dicts, params = self._create_message_dicts(messages, stop) - params = {**params, **kwargs, "stream": True} - # Mark current chunk total length - length = 0 - default_chunk_class = AIMessageChunk - for chunk in self.stream_completion_with_retry( - messages=message_dicts, run_manager=run_manager, **params - ): - if len(chunk["output"]["choices"]) == 0: - continue - choice = chunk["output"]["choices"][0] - - chunk = _convert_delta_to_message_chunk( - choice["message"], default_chunk_class, length - ) - finish_reason = choice.get("finish_reason") - generation_info = ( - dict(finish_reason=finish_reason) if finish_reason is not None else None + params: Dict[str, Any] = self._invocation_params( + messages=messages, stop=stop, stream=True, **kwargs + ) + for stream_resp in self.stream_completion_with_retry(**params): + chunk = ChatGenerationChunk( + **self._chat_generation_from_qwen_resp(stream_resp, is_chunk=True) ) - default_chunk_class = chunk.__class__ - chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info) yield chunk if run_manager: run_manager.on_llm_new_token(chunk.text, chunk=chunk) - length = len(choice["message"]["content"]) - def _create_message_dicts( - self, messages: List[BaseMessage], stop: Optional[List[str]] - ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: - params = self._client_params() + async def _astream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> AsyncIterator[ChatGenerationChunk]: + params: Dict[str, Any] = self._invocation_params( + messages=messages, stop=stop, stream=True, **kwargs + ) + async for stream_resp in self.astream_completion_with_retry(**params): + chunk = ChatGenerationChunk( + **self._chat_generation_from_qwen_resp(stream_resp, is_chunk=True) + ) + yield chunk + if run_manager: + await run_manager.on_llm_new_token(chunk.text, chunk=chunk) - # Ensure `stop` is a list of strings + def _invocation_params( + self, messages: List[BaseMessage], stop: Any, **kwargs: Any + ) -> Dict[str, Any]: + params = {**self._default_params, **kwargs} if stop is not None: - if "stop" in params: - raise ValueError("`stop` found in both the input and default params.") params["stop"] = stop + if params.get("stream"): + params["incremental_output"] = True message_dicts = [convert_message_to_dict(m) for m in messages] - return message_dicts, params - def _client_params(self) -> Dict[str, Any]: - """Get the parameters used for the openai client.""" - creds: Dict[str, Any] = { - "api_key": self.dashscope_api_key, - } - return {**self._default_params, **creds} + # According to the docs, the last message should be a `user` message + if message_dicts[-1]["role"] != "user": + raise ValueError("Last message should be user message.") + # And the `system` message should be the first message if present + system_message_indices = [ + i for i, m in enumerate(message_dicts) if m["role"] == "system" + ] + if len(system_message_indices) != 1 or system_message_indices[0] != 0: + raise ValueError("System message can only be the first message.") + + params["messages"] = message_dicts + + return params + + def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict: + if llm_outputs[0] is None: + return {} + return llm_outputs[0] + + @staticmethod + def _chat_generation_from_qwen_resp( + resp: Any, is_chunk: bool = False + ) -> Dict[str, Any]: + choice = resp["output"]["choices"][0] + message = convert_dict_to_message(choice["message"], is_chunk=is_chunk) + return dict( + message=message, + generation_info=dict( + finish_reason=choice["finish_reason"], + request_id=resp["request_id"], + token_usage=dict(resp["usage"]), + ), + ) - def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult: - generations = [] - for res in response["output"]["choices"]: - message = convert_dict_to_message(res["message"]) - gen = ChatGeneration( - message=message, - generation_info=dict(finish_reason=res.get("finish_reason")), - ) - generations.append(gen) - token_usage = response.get("usage", {}) - llm_output = {"token_usage": token_usage, "model_name": self.model_name} - return ChatResult(generations=generations, llm_output=llm_output) + @staticmethod + def _chunk_to_generation(chunk: ChatGenerationChunk) -> ChatGeneration: + return ChatGeneration( + message=convert_message_chunk_to_message(chunk.message), + generation_info=chunk.generation_info, + ) diff --git a/libs/community/langchain_community/llms/tongyi.py b/libs/community/langchain_community/llms/tongyi.py index 8098612392e8c..69b09b7eb07de 100644 --- a/libs/community/langchain_community/llms/tongyi.py +++ b/libs/community/langchain_community/llms/tongyi.py @@ -1,11 +1,25 @@ from __future__ import annotations +import asyncio +import functools import logging -from typing import Any, Callable, Dict, List, Optional +from typing import ( + Any, + AsyncIterator, + Callable, + Dict, + Iterator, + List, + Mapping, + Optional, +) -from langchain_core.callbacks import CallbackManagerForLLMRun -from langchain_core.language_models.llms import LLM -from langchain_core.outputs import Generation, LLMResult +from langchain_core.callbacks import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) +from langchain_core.language_models.llms import BaseLLM +from langchain_core.outputs import Generation, GenerationChunk, LLMResult from langchain_core.pydantic_v1 import Field, root_validator from langchain_core.utils import get_from_dict_or_env from requests.exceptions import HTTPError @@ -24,7 +38,7 @@ def _create_retry_decorator(llm: Tongyi) -> Callable[[Any], Any]: min_seconds = 1 max_seconds = 4 # Wait 2^x * 1 second between each retry starting with - # 4 seconds, then up to 10 seconds, then 10 seconds afterwards + # 4 seconds, then up to 10 seconds, then 10 seconds afterward return retry( reraise=True, stop=stop_after_attempt(llm.max_retries), @@ -34,6 +48,23 @@ def _create_retry_decorator(llm: Tongyi) -> Callable[[Any], Any]: ) +def check_response(resp: Any) -> Any: + """Check the response from the completion call.""" + if resp.status_code == 200: + return resp + elif resp.status_code in [400, 401]: + raise ValueError( + f"status_code: {resp.status_code} \n " + f"code: {resp.code} \n message: {resp.message}" + ) + else: + raise HTTPError( + f"HTTP error occurred: status_code: {resp.status_code} \n " + f"code: {resp.code} \n message: {resp.message}", + response=resp, + ) + + def generate_with_retry(llm: Tongyi, **kwargs: Any) -> Any: """Use tenacity to retry the completion call.""" retry_decorator = _create_retry_decorator(llm) @@ -41,19 +72,7 @@ def generate_with_retry(llm: Tongyi, **kwargs: Any) -> Any: @retry_decorator def _generate_with_retry(**_kwargs: Any) -> Any: resp = llm.client.call(**_kwargs) - if resp.status_code == 200: - return resp - elif resp.status_code in [400, 401]: - raise ValueError( - f"status_code: {resp.status_code} \n " - f"code: {resp.code} \n message: {resp.message}" - ) - else: - raise HTTPError( - f"HTTP error occurred: status_code: {resp.status_code} \n " - f"code: {resp.code} \n message: {resp.message}", - response=resp, - ) + return check_response(resp) return _generate_with_retry(**kwargs) @@ -64,28 +83,44 @@ def stream_generate_with_retry(llm: Tongyi, **kwargs: Any) -> Any: @retry_decorator def _stream_generate_with_retry(**_kwargs: Any) -> Any: - stream_resps = [] - resps = llm.client.call(**_kwargs) - for resp in resps: - if resp.status_code == 200: - stream_resps.append(resp) - elif resp.status_code in [400, 401]: - raise ValueError( - f"status_code: {resp.status_code} \n " - f"code: {resp.code} \n message: {resp.message}" - ) - else: - raise HTTPError( - f"HTTP error occurred: status_code: {resp.status_code} \n " - f"code: {resp.code} \n message: {resp.message}", - response=resp, - ) - return stream_resps + responses = llm.client.call(**_kwargs) + for resp in responses: + yield check_response(resp) return _stream_generate_with_retry(**kwargs) -class Tongyi(LLM): +async def astream_generate_with_retry(llm: Tongyi, **kwargs: Any) -> Any: + """Because the dashscope SDK doesn't provide an async API, + we wrap `stream_generate_with_retry` with an async generator.""" + + class _AioTongyiGenerator: + def __init__(self, _llm: Tongyi, **_kwargs: Any): + self.generator = stream_generate_with_retry(_llm, **_kwargs) + + def __aiter__(self) -> AsyncIterator[Any]: + return self + + async def __anext__(self) -> Any: + value = await asyncio.get_running_loop().run_in_executor( + None, self._safe_next + ) + if value is not None: + return value + else: + raise StopAsyncIteration + + def _safe_next(self) -> Any: + try: + return next(self.generator) + except StopIteration: + return None + + async for chunk in _AioTongyiGenerator(llm, **kwargs): + yield chunk + + +class Tongyi(BaseLLM): """Tongyi Qwen large language models. To use, you should have the ``dashscope`` python package installed, and the @@ -96,17 +131,13 @@ class Tongyi(LLM): .. code-block:: python from langchain_community.llms import Tongyi - Tongyi = tongyi() + tongyi = tongyi() """ @property def lc_secrets(self) -> Dict[str, str]: return {"dashscope_api_key": "DASHSCOPE_API_KEY"} - @classmethod - def is_lc_serializable(cls) -> bool: - return False - client: Any #: :meta private: model_name: str = "qwen-plus" @@ -117,10 +148,7 @@ def is_lc_serializable(cls) -> bool: """Total probability mass of tokens to consider at each step.""" dashscope_api_key: Optional[str] = None - """Dashscope api key provide by alicloud.""" - - n: int = 1 - """How many completions to generate for each prompt.""" + """Dashscope api key provide by Alibaba Cloud.""" streaming: bool = False """Whether to stream the results or not.""" @@ -128,9 +156,6 @@ def is_lc_serializable(cls) -> bool: max_retries: int = 10 """Maximum number of retries to make when generating.""" - prefix_messages: List = Field(default_factory=list) - """Series of messages for Chat input.""" - @property def _llm_type(self) -> str: """Return type of llm.""" @@ -139,7 +164,9 @@ def _llm_type(self) -> str: @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" - get_from_dict_or_env(values, "dashscope_api_key", "DASHSCOPE_API_KEY") + values["dashscope_api_key"] = get_from_dict_or_env( + values, "dashscope_api_key", "DASHSCOPE_API_KEY" + ) try: import dashscope except ImportError: @@ -160,118 +187,157 @@ def validate_environment(cls, values: Dict) -> Dict: @property def _default_params(self) -> Dict[str, Any]: - """Get the default parameters for calling OpenAI API.""" + """Get the default parameters for calling Tongyi Qwen API.""" normal_params = { + "model": self.model_name, "top_p": self.top_p, + "api_key": self.dashscope_api_key, } return {**normal_params, **self.model_kwargs} - def _call( + @property + def _identifying_params(self) -> Mapping[str, Any]: + return {"model_name": self.model_name, **super()._identifying_params} + + def _generate( self, - prompt: str, + prompts: List[str], stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, - ) -> str: - """Call out to Tongyi's generate endpoint. - - Args: - prompt: The prompt to pass into the model. - - Returns: - The string generated by the model. - - Example: - .. code-block:: python - - response = tongyi("Tell me a joke.") - """ - params: Dict[str, Any] = { - **{"model": self.model_name}, - **self._default_params, - **kwargs, - } - - completion = generate_with_retry( - self, - prompt=prompt, - **params, + ) -> LLMResult: + generations = [] + if self.streaming: + if len(prompts) > 1: + raise ValueError("Cannot stream results with multiple prompts.") + generation: Optional[GenerationChunk] = None + for chunk in self._stream(prompts[0], stop, run_manager, **kwargs): + if generation is None: + generation = chunk + else: + generation += chunk + assert generation is not None + generations.append([self._chunk_to_generation(generation)]) + else: + params: Dict[str, Any] = self._invocation_params(stop=stop, **kwargs) + for prompt in prompts: + completion = generate_with_retry(self, prompt=prompt, **params) + generations.append( + [Generation(**self._generation_from_qwen_resp(completion))] + ) + return LLMResult( + generations=generations, + llm_output={ + "model_name": self.model_name, + }, ) - return completion["output"]["text"] - def _generate( + async def _agenerate( self, prompts: List[str], stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, **kwargs: Any, ) -> LLMResult: generations = [] - params: Dict[str, Any] = { - **{"model": self.model_name}, - **self._default_params, - **kwargs, - } if self.streaming: if len(prompts) > 1: raise ValueError("Cannot stream results with multiple prompts.") - params["stream"] = True - temp = "" - for stream_resp in stream_generate_with_retry( - self, prompt=prompts[0], **params - ): - if run_manager: - stream_resp_text = stream_resp["output"]["text"] - stream_resp_text = stream_resp_text.replace(temp, "") - # Ali Cloud's streaming transmission interface, each return content - # will contain the output - # of the previous round(as of September 20, 2023, future updates to - # the Alibaba Cloud API may vary) - run_manager.on_llm_new_token(stream_resp_text) - # The implementation of streaming transmission primarily relies on - # the "on_llm_new_token" method - # of the streaming callback. - temp = stream_resp["output"]["text"] - - generations.append( - [ - Generation( - text=stream_resp["output"]["text"], - generation_info=dict( - finish_reason=stream_resp["output"]["finish_reason"], - ), - ) - ] - ) - generations.reverse() - # In the official implementation of the OpenAI API, - # the "generations" parameter passed to LLMResult seems to be a 1*1*1 - # two-dimensional list - # (including in non-streaming mode). - # Considering that Alibaba Cloud's streaming transmission - # (as of September 20, 2023, future updates to the Alibaba Cloud API may - # vary) - # includes the output of the previous round in each return, - # reversing this "generations" list should suffice - # (This is the solution with the least amount of changes to the source code, - # while still allowing for convenient modifications in the future, - # although it may result in slightly more memory consumption). + generation: Optional[GenerationChunk] = None + async for chunk in self._astream(prompts[0], stop, run_manager, **kwargs): + if generation is None: + generation = chunk + else: + generation += chunk + assert generation is not None + generations.append([self._chunk_to_generation(generation)]) else: + params: Dict[str, Any] = self._invocation_params(stop=stop, **kwargs) for prompt in prompts: - completion = generate_with_retry( - self, - prompt=prompt, - **params, + completion = await asyncio.get_running_loop().run_in_executor( + None, + functools.partial( + generate_with_retry, **{"llm": self, "prompt": prompt, **params} + ), ) generations.append( - [ - Generation( - text=completion["output"]["text"], - generation_info=dict( - finish_reason=completion["output"]["finish_reason"], - ), - ) - ] + [Generation(**self._generation_from_qwen_resp(completion))] ) - return LLMResult(generations=generations) + return LLMResult( + generations=generations, + llm_output={ + "model_name": self.model_name, + }, + ) + + def _stream( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[GenerationChunk]: + params: Dict[str, Any] = self._invocation_params( + stop=stop, stream=True, **kwargs + ) + for stream_resp in stream_generate_with_retry(self, prompt=prompt, **params): + chunk = GenerationChunk(**self._generation_from_qwen_resp(stream_resp)) + yield chunk + if run_manager: + run_manager.on_llm_new_token( + chunk.text, + chunk=chunk, + verbose=self.verbose, + ) + + async def _astream( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> AsyncIterator[GenerationChunk]: + params: Dict[str, Any] = self._invocation_params( + stop=stop, stream=True, **kwargs + ) + async for stream_resp in astream_generate_with_retry( + self, prompt=prompt, **params + ): + chunk = GenerationChunk(**self._generation_from_qwen_resp(stream_resp)) + yield chunk + if run_manager: + await run_manager.on_llm_new_token( + chunk.text, + chunk=chunk, + verbose=self.verbose, + ) + + def _invocation_params(self, stop: Any, **kwargs: Any) -> Dict[str, Any]: + params = { + **self._default_params, + **kwargs, + } + if stop is not None: + params["stop"] = stop + if params.get("stream"): + params["incremental_output"] = True + return params + + @staticmethod + def _generation_from_qwen_resp(resp: Any) -> Dict[str, Any]: + return dict( + text=resp["output"]["text"], + generation_info=dict( + finish_reason=resp["output"]["finish_reason"], + request_id=resp["request_id"], + token_usage=dict(resp["usage"]), + ), + ) + + @staticmethod + def _chunk_to_generation(chunk: GenerationChunk) -> Generation: + return Generation( + text=chunk.text, + generation_info=chunk.generation_info, + ) diff --git a/libs/community/tests/unit_tests/chat_models/test_imports.py b/libs/community/tests/unit_tests/chat_models/test_imports.py index a984a7de6c0a2..0019d552c9ee7 100644 --- a/libs/community/tests/unit_tests/chat_models/test_imports.py +++ b/libs/community/tests/unit_tests/chat_models/test_imports.py @@ -26,6 +26,7 @@ "ChatKonko", "PaiEasChatEndpoint", "QianfanChatEndpoint", + "ChatTongyi", "ChatFireworks", "ChatYandexGPT", "ChatBaichuan",