From a5afcc15d80cf283ff4cd5b2d971f6c85ff4c77a Mon Sep 17 00:00:00 2001 From: Sourcery AI <> Date: Tue, 31 Oct 2023 00:44:17 +0000 Subject: [PATCH] 'Refactored by Sourcery' --- python/promplate/chain/node.py | 12 ++++-------- python/promplate/prompt/chat.py | 16 +++++----------- python/promplate/prompt/template.py | 4 ++-- 3 files changed, 11 insertions(+), 21 deletions(-) diff --git a/python/promplate/chain/node.py b/python/promplate/chain/node.py index 58ea18c..5f99c9a 100644 --- a/python/promplate/chain/node.py +++ b/python/promplate/chain/node.py @@ -59,11 +59,10 @@ def __ror__(self, other: Mapping | None): def __ior__(self, other: MutableMapping | None): if other is not None and other is not self: - if isinstance(other, ChainContext): - self.primary_map.maps[0:0] = other.primary_map - self.fallback_map.maps[0:0] = other.fallback_map - else: + if not isinstance(other, ChainContext): return super().__ior__(other) + self.primary_map.maps[:0] = other.primary_map + self.fallback_map.maps[:0] = other.fallback_map return self def __repr__(self): @@ -218,10 +217,7 @@ async def _arun(self, context, /, complete=None): await self._apply_async_post_processes(context) def next(self, chain: AbstractChain): - if isinstance(chain, Chain): - return Chain(self, *chain) - else: - return Chain(self, chain) + return Chain(self, *chain) if isinstance(chain, Chain) else Chain(self, chain) def __add__(self, chain: AbstractChain): return self.next(chain) diff --git a/python/promplate/prompt/chat.py b/python/promplate/prompt/chat.py index 39830d4..2848967 100644 --- a/python/promplate/prompt/chat.py +++ b/python/promplate/prompt/chat.py @@ -8,18 +8,13 @@ if version_info >= (3, 11): from typing import NotRequired, TypedDict - class Message(TypedDict): # type: ignore - role: Role - content: str - name: NotRequired[str] - else: from typing_extensions import NotRequired, TypedDict - class Message(TypedDict): - role: Role - content: str - name: NotRequired[str] +class Message(TypedDict): # type: ignore + role: Role + content: str + name: NotRequired[str] class MessageBuilder: @@ -84,8 +79,7 @@ def parse_chat_markup(text: str) -> list[Message]: buffer = [] for line in text.splitlines(): - match = is_message_start.match(line) - if match: + if match := is_message_start.match(line): role, name = match.group(1), match.group(2) if current_message: diff --git a/python/promplate/prompt/template.py b/python/promplate/prompt/template.py index 55036a2..7c53d7e 100644 --- a/python/promplate/prompt/template.py +++ b/python/promplate/prompt/template.py @@ -59,13 +59,13 @@ def _on_special_token(self, token, sync: bool): else: op = inner.split(" ", 1)[0] - if op == "if" or op == "for" or op == "while": + if op in ["if", "for", "while"]: self._ops_stack.append(op) self._flush() self._builder.add_line(f"{inner}:") self._builder.indent() - elif op == "else" or op == "elif": + elif op in ["else", "elif"]: self._flush() self._builder.dedent() self._builder.add_line(f"{inner}:")