diff --git a/aiopg/extras.py b/aiopg/extras.py new file mode 100644 index 00000000..db7fc51c --- /dev/null +++ b/aiopg/extras.py @@ -0,0 +1,43 @@ +def _paginate(seq, page_size): + """Consume an iterable and return it in chunks. + + Every chunk is at most `page_size`. Never return an empty chunk. + """ + page = [] + count = len(seq) + it = iter(seq) + for s in range(count + 1): + try: + for i in range(page_size): + page.append(next(it)) + yield page + page = [] + except StopIteration: + if page: + yield page + return + + +async def execute_batch(cur, sql, argslist, page_size=100): + r"""Execute groups of statements in fewer server roundtrips. + + Execute *sql* several times, against all parameters set (sequences or + mappings) found in *argslist*. + + The function is semantically similar to + + .. parsed-literal:: + + *cur*\.\ `~cursor.executemany`\ (\ *sql*\ , *argslist*\ ) + + but has a different implementation: Psycopg will join the statements into + fewer multi-statement commands, each one containing at most *page_size* + statements, resulting in a reduced number of server roundtrips. + + After the execution of the function the `cursor.rowcount` property will + **not** contain a total result. + + """ + for page in _paginate(argslist, page_size=page_size): + sqls = [cur.mogrify(sql, args) for args in page] + await cur.execute(b";".join(sqls)) diff --git a/tests/test_extras.py b/tests/test_extras.py new file mode 100644 index 00000000..46e1bec7 --- /dev/null +++ b/tests/test_extras.py @@ -0,0 +1,54 @@ +import pytest + +from aiopg.extras import _paginate, execute_batch + + +@pytest.fixture +def connect(make_connection): + async def go(**kwargs): + conn = await make_connection(**kwargs) + async with conn.cursor() as cur: + await cur.execute("DROP TABLE IF EXISTS tbl_extras") + await cur.execute("CREATE TABLE tbl_extras (id int)") + return conn + + return go + + +@pytest.fixture +def cursor(connect, loop): + async def go(): + return await (await connect()).cursor() + + cur = loop.run_until_complete(go()) + yield cur + cur.close() + + +def test__paginate(): + data = [ + [1, 2, 3], + [4, 5, 6], + [7], + ] + for index, val in enumerate(_paginate((1, 2, 3, 4, 5, 6, 7), page_size=3)): + assert data[index] == list(val) + + +def test__paginate_even(): + data = [ + [1, 2, 3], + [4, 5, 6], + ] + for index, val in enumerate(_paginate((1, 2, 3, 4, 5, 6), page_size=3)): + assert data[index] == list(val) + + +async def test_execute_batch(cursor): + args = [(1,), (2,), (3,), (4,)] + sql = 'insert into tbl_extras values(%s)' + await execute_batch(cursor, sql, argslist=args, page_size=3) + + await cursor.execute('SELECT * from tbl_extras') + ret = await cursor.fetchall() + assert list(ret) == args