diff --git a/telegram_upload/utils.py b/telegram_upload/utils.py index 5ea33d72..e5ecf495 100644 --- a/telegram_upload/utils.py +++ b/telegram_upload/utils.py @@ -52,10 +52,10 @@ async def aislice(iterator, limit): items = [] i = 0 async for value in iterator: - if i > limit: - break - i += 1 items.append(value) + i += 1 + if i >= limit: + break return items diff --git a/tests/test_utils.py b/tests/test_utils.py index 8c85e192..9b3e2051 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,7 +1,9 @@ import unittest from unittest.mock import patch, Mock -from telegram_upload.utils import sizeof_fmt, scantree +from telegram_upload.utils import sizeof_fmt, scantree, aislice + +import asyncio # for TestAISlice's test_general class TestSizeOfFmt(unittest.TestCase): @@ -36,3 +38,30 @@ def test_directory(self, m): side_effect = [[directory], [file] * 3] m.side_effect = side_effect self.assertEqual(list(scantree('foo')), side_effect[-1]) + +class TestAISlice(unittest.TestCase): + def test_general(self): + class asyncFromList: + def __init__(self, rawlist): + self.rawidx = -1 + self.rawlist = rawlist.copy() + def __aiter__(self): + return self + async def __anext__(self): + await asyncio.sleep(0) + self.rawidx += 1 + if self.rawidx >= len(self.rawlist): + raise StopAsyncIteration + return self.rawlist[self.rawidx] + async def async_collect_all_files(it, pagesize): + ret = [] + while True: + aislice_res = await aislice(it, pagesize) + if len(aislice_res) == 0: break + self.assertLessEqual(len(aislice_res), pagesize, "limit returns more than asked") + ret.extend(aislice_res) + return ret + pagesize = 10 + files = [ str(i) for i in range(51) ] + gotfiles = asyncio.get_event_loop().run_until_complete(async_collect_all_files(asyncFromList(files), pagesize)) + self.assertEqual(gotfiles, files, "got different set of values")