From 51d0a83a720bbae6e580207fdcef26fd8f6d2dab Mon Sep 17 00:00:00 2001 From: Muspi Merol Date: Sat, 9 Dec 2023 05:59:31 +0800 Subject: [PATCH] feat: support bind run config and mark config as immutable feat: see attributes started with `_` as non-config --- python/promplate/llm/base.py | 2 +- python/promplate/llm/openai/v1.py | 28 ++++++++++++++++++++-------- 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/python/promplate/llm/base.py b/python/promplate/llm/base.py index 2a54d1b..1037353 100644 --- a/python/promplate/llm/base.py +++ b/python/promplate/llm/base.py @@ -8,7 +8,7 @@ def __init__(self, **config): @property def _config(self): - return self.__dict__ + return {k: v for k, v in self.__dict__.items() if not k.startswith("_")} class Complete(Protocol): diff --git a/python/promplate/llm/openai/v1.py b/python/promplate/llm/openai/v1.py index b75f020..f285189 100644 --- a/python/promplate/llm/openai/v1.py +++ b/python/promplate/llm/openai/v1.py @@ -1,4 +1,6 @@ +from copy import copy from functools import cached_property +from types import MappingProxyType from typing import Any, Callable, ParamSpec, TypeVar from openai import AsyncClient, Client # type: ignore @@ -21,6 +23,16 @@ class ClientConfig(Configurable): @same_params_as(Client) def __init__(self, **config): super().__init__(**config) + self._run_config = {} + + def bind(self, **run_config): + obj = copy(self) + obj._run_config = self._run_config | run_config + return obj + + @property + def _config(self): + return MappingProxyType(super()._config) @cached_property def client(self): @@ -33,21 +45,21 @@ def aclient(self): class TextComplete(ClientConfig): def __call__(self, text: str, /, **config): - config = config | {"stream": False, "prompt": text} + config = self._run_config | config | {"stream": False, "prompt": text} result = self.client.completions.create(**config) return result.choices[0].text class AsyncTextComplete(ClientConfig): async def __call__(self, text: str, /, **config): - config = config | {"stream": False, "prompt": text} + config = self._run_config | config | {"stream": False, "prompt": text} result = await self.aclient.completions.create(**config) return result.choices[0].text class TextGenerate(ClientConfig): def __call__(self, text: str, /, **config): - config = config | {"stream": True, "prompt": text} + config = self._run_config | config | {"stream": True, "prompt": text} stream = self.client.completions.create(**config) for event in stream: yield event.choices[0].text @@ -55,7 +67,7 @@ def __call__(self, text: str, /, **config): class AsyncTextGenerate(ClientConfig): async def __call__(self, text: str, /, **config): - config = config | {"stream": True, "prompt": text} + config = self._run_config | config | {"stream": True, "prompt": text} stream = await self.aclient.completions.create(**config) async for event in stream: yield event.choices[0].text @@ -64,7 +76,7 @@ async def __call__(self, text: str, /, **config): class ChatComplete(ClientConfig): def __call__(self, messages: list[Message] | str, /, **config): messages = ensure(messages) - config = config | {"stream": False, "messages": messages} + config = self._run_config | config | {"stream": False, "messages": messages} result = self.client.chat.completions.create(**config) return result.choices[0].message.content @@ -72,7 +84,7 @@ def __call__(self, messages: list[Message] | str, /, **config): class AsyncChatComplete(ClientConfig): async def __call__(self, messages: list[Message] | str, /, **config): messages = ensure(messages) - config = config | {"stream": False, "messages": messages} + config = self._run_config | config | {"stream": False, "messages": messages} result = await self.aclient.chat.completions.create(**config) return result.choices[0].message.content @@ -80,7 +92,7 @@ async def __call__(self, messages: list[Message] | str, /, **config): class ChatGenerate(ClientConfig): def __call__(self, messages: list[Message] | str, /, **config): messages = ensure(messages) - config = config | {"stream": True, "messages": messages} + config = self._run_config | config | {"stream": True, "messages": messages} stream = self.client.chat.completions.create(**config) for event in stream: yield event.choices[0].delta.content or "" @@ -89,7 +101,7 @@ def __call__(self, messages: list[Message] | str, /, **config): class AsyncChatGenerate(ClientConfig): async def __call__(self, messages: list[Message] | str, /, **config): messages = ensure(messages) - config = config | {"stream": True, "messages": messages} + config = self.run_config | config | {"stream": True, "messages": messages} stream = await self.aclient.chat.completions.create(**config) async for event in stream: yield event.choices[0].delta.content or ""