diff --git a/TODO.md b/TODO.md index ecfd427..8869e5e 100644 --- a/TODO.md +++ b/TODO.md @@ -1,9 +1,6 @@ TODO for leancheck.py ===================== -* add documentation to module and to all functions - - reorder functions if necessary - * simplify code later diff --git a/src/leancheck.py b/src/leancheck.py index 201160f..8e9dfb4 100644 --- a/src/leancheck.py +++ b/src/leancheck.py @@ -236,6 +236,22 @@ class Enumerator: which can be registered using the `Enumerator.register()` method. This class supports computing sums and products of enumerations: + + >>> print(Enumerator[int] + Enumerator[bool]) + [0, False, True, 1, 2, 3, ...] + + Use `*` to take the product of two enumerations: + + >>> print(Enumerator[int] * Enumerator[bool]) + [(0, False), (0, True), (1, False), (1, True), (2, False), (2, True), ...] + """ + + tiers: typing.Callable[[], typing.Generator] + """ + Generate tiers of values. + + >>> list(Enumerator[bool].tiers()) + [[False, True]] """ def __init__(self, tiers): @@ -308,6 +324,9 @@ def __add__(self, other): >>> print(Enumerator[int] + Enumerator[bool]) [0, False, True, 1, 2, 3, ...] + + >>> Enumerator[int] + Enumerator[bool] + Enumerator(lambda: (xs for xs in [[0, False, True], [1], [2], [3], [4], [5], ...])) """ return Enumerator(lambda: _zippend(self.tiers(), other.tiers())) @@ -317,6 +336,9 @@ def __mul__(self, other): >>> print(Enumerator[int] * Enumerator[bool]) [(0, False), (0, True), (1, False), (1, True), (2, False), (2, True), ...] + + >>> Enumerator[int] * Enumerator[bool] + Enumerator(lambda: (xs for xs in [[(0, False), (0, True)], [(1, False), (1, True)], [(2, False), (2, True)], [(3, False), (3, True)], [(4, False), (4, True)], [(5, False), (5, True)], ...])) """ return Enumerator(lambda: _pproduct(self.tiers(), other.tiers())) @@ -335,6 +357,12 @@ def __str__(self): return "[" + ', '.join(xs) + "]" def map(self, f): + """ + Applies a function to all values in the enumeration. + + >>> Enumerator[int].map(lambda x: x*2) + Enumerator(lambda: (xs for xs in [[0], [2], [4], [6], [8], [10], ...])) + """ return Enumerator(lambda: _mmap(f, self.tiers())) @classmethod @@ -462,7 +490,7 @@ def _intercalate(generator1, generator2): def _zippend(*iiterables): - return itertools.starmap(itertools.chain,itertools.zip_longest(*iiterables, fillvalue=[])) + return map(list,itertools.starmap(itertools.chain,itertools.zip_longest(*iiterables, fillvalue=[]))) def _pproduct(xss, yss, with_f=None):