Skip to content

Commit

Permalink
feat: support bind run config and mark config as immutable
Browse files Browse the repository at this point in the history
feat: see attributes started with `_` as non-config
  • Loading branch information
CNSeniorious000 committed Dec 8, 2023
1 parent c282eb7 commit 51d0a83
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 9 deletions.
2 changes: 1 addition & 1 deletion python/promplate/llm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def __init__(self, **config):

@property
def _config(self):
return self.__dict__
return {k: v for k, v in self.__dict__.items() if not k.startswith("_")}


class Complete(Protocol):
Expand Down
28 changes: 20 additions & 8 deletions python/promplate/llm/openai/v1.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from copy import copy
from functools import cached_property
from types import MappingProxyType
from typing import Any, Callable, ParamSpec, TypeVar

from openai import AsyncClient, Client # type: ignore
Expand All @@ -21,6 +23,16 @@ class ClientConfig(Configurable):
@same_params_as(Client)
def __init__(self, **config):
super().__init__(**config)
self._run_config = {}

def bind(self, **run_config):
obj = copy(self)
obj._run_config = self._run_config | run_config
return obj

@property
def _config(self):
return MappingProxyType(super()._config)

@cached_property
def client(self):
Expand All @@ -33,29 +45,29 @@ def aclient(self):

class TextComplete(ClientConfig):
def __call__(self, text: str, /, **config):
config = config | {"stream": False, "prompt": text}
config = self._run_config | config | {"stream": False, "prompt": text}
result = self.client.completions.create(**config)
return result.choices[0].text


class AsyncTextComplete(ClientConfig):
async def __call__(self, text: str, /, **config):
config = config | {"stream": False, "prompt": text}
config = self._run_config | config | {"stream": False, "prompt": text}
result = await self.aclient.completions.create(**config)
return result.choices[0].text


class TextGenerate(ClientConfig):
def __call__(self, text: str, /, **config):
config = config | {"stream": True, "prompt": text}
config = self._run_config | config | {"stream": True, "prompt": text}
stream = self.client.completions.create(**config)
for event in stream:
yield event.choices[0].text


class AsyncTextGenerate(ClientConfig):
async def __call__(self, text: str, /, **config):
config = config | {"stream": True, "prompt": text}
config = self._run_config | config | {"stream": True, "prompt": text}
stream = await self.aclient.completions.create(**config)
async for event in stream:
yield event.choices[0].text
Expand All @@ -64,23 +76,23 @@ async def __call__(self, text: str, /, **config):
class ChatComplete(ClientConfig):
def __call__(self, messages: list[Message] | str, /, **config):
messages = ensure(messages)
config = config | {"stream": False, "messages": messages}
config = self._run_config | config | {"stream": False, "messages": messages}
result = self.client.chat.completions.create(**config)
return result.choices[0].message.content


class AsyncChatComplete(ClientConfig):
async def __call__(self, messages: list[Message] | str, /, **config):
messages = ensure(messages)
config = config | {"stream": False, "messages": messages}
config = self._run_config | config | {"stream": False, "messages": messages}
result = await self.aclient.chat.completions.create(**config)
return result.choices[0].message.content


class ChatGenerate(ClientConfig):
def __call__(self, messages: list[Message] | str, /, **config):
messages = ensure(messages)
config = config | {"stream": True, "messages": messages}
config = self._run_config | config | {"stream": True, "messages": messages}
stream = self.client.chat.completions.create(**config)
for event in stream:
yield event.choices[0].delta.content or ""
Expand All @@ -89,7 +101,7 @@ def __call__(self, messages: list[Message] | str, /, **config):
class AsyncChatGenerate(ClientConfig):
async def __call__(self, messages: list[Message] | str, /, **config):
messages = ensure(messages)
config = config | {"stream": True, "messages": messages}
config = self.run_config | config | {"stream": True, "messages": messages}
stream = await self.aclient.chat.completions.create(**config)
async for event in stream:
yield event.choices[0].delta.content or ""
Expand Down

0 comments on commit 51d0a83

Please sign in to comment.