Skip to content

Commit

Permalink
fix types: fit with stricter pyright rules
Browse files Browse the repository at this point in the history
  • Loading branch information
CNSeniorious000 committed Dec 6, 2023
1 parent dc186fe commit 69dd3b1
Showing 1 changed file with 8 additions and 11 deletions.
19 changes: 8 additions & 11 deletions python/promplate/chain/node.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections import ChainMap
from typing import Callable, Mapping, MutableMapping, TypeVar, overload
from typing import TYPE_CHECKING, Callable, Mapping, MutableMapping, Self, TypeVar, overload

from ..llm.base import *
from ..prompt.template import Context, Loader, Template
Expand Down Expand Up @@ -38,6 +38,9 @@ def result(self, result):
def result(self):
self.__delitem__("__result__")

if TYPE_CHECKING: # fix type from `collections.ChainMap`
copy: Callable[[Self], Self]


CTX = TypeVar("CTX", Context, ChainContext)

Expand All @@ -64,8 +67,6 @@ async def arun(
) -> ChainContext:
...

context: Context

complete: Complete | AsyncComplete | None


Expand All @@ -75,15 +76,15 @@ def _run(
context: ChainContext,
/,
complete: Complete | None = None,
) -> ChainContext:
):
...

async def _arun(
self,
context: ChainContext,
/,
complete: Complete | AsyncComplete | None = None,
) -> ChainContext:
):
...

def run(self, context=None, /, complete=None) -> ChainContext:
Expand Down Expand Up @@ -242,15 +243,11 @@ def __iter__(self):

def _run(self, context, /, complete=None):
for node in self.nodes:
context = node.run(context, self.complete or complete) # type: ignore

return context
node.run(context, self.complete or complete) # type: ignore

async def _arun(self, context, /, complete=None):
for node in self.nodes:
context = await node.arun(context, self.complete or complete)

return context
await node.arun(context, self.complete or complete)

def __repr__(self):
return " + ".join(map(str, self.nodes))
Expand Down

0 comments on commit 69dd3b1

Please sign in to comment.