From 70480585c29d23c7eeb3158a4ffeea6fca60a3ec Mon Sep 17 00:00:00 2001
From: "Chris (Someguy123)"
Date: Mon, 9 Dec 2019 02:27:54 +0000
Subject: [PATCH] 0.9.0 - AsyncIO support for wrapper + builder

**Updates and Improvements**

- Refactored `SqliteWrapper._should_zip` into a standalone function, so it can be shared by other classes and methods.

**New additions**

- New dependencies: `aiosqlite`, `async-property` and `nest_asyncio`
- New class `db.base.GenericAsyncDBWrapper` - an abstract base class for AsyncIO database wrappers
- New class `db.query.asyncx.base.BaseAsyncQueryBuilder` - an abstract base class for AsyncIO query builders
- New protocol types in `db.types`:
  - `GenericAsyncCursor` - a protocol covering any class which implements the Python DB API (PEP 249) for cursors, with all methods being coroutines (`async def`).
  - `GenericAsyncConnection` - a protocol covering any class which implements the Python DB API (PEP 249) for DB connections, with all methods being coroutines (`async def`).
- AsyncIO Sqlite3 support:
  - `db.sqlite.SqliteAsyncWrapper` uses `aiosqlite` to provide async wrapper methods for interacting with Sqlite3 databases.
  - `db.query.asyncx.sqlite.SqliteAsyncQueryBuilder` is a query builder for Sqlite3 which uses `aiosqlite` to implement async methods for executing the built queries and fetching the results.

**Unit testing**

- Added `pytest-asyncio` as a development dependency, allowing for asynchronous unit tests.
- Changes to `tests/base.py`:
  - Created class `_TestAsyncWrapperMixin` - an async version of `_TestWrapperMixin`
  - Created class `ExampleAsyncWrapper` - an async version of `ExampleWrapper`
  - Refactored `_TestWrapperMixin.example_users` into a module-level variable, so it can be shared by other classes.
- Created `tests/test_async.py`:
  - Uses `pytest-asyncio` for running async unit test functions
  - Unit tests covering the async class `SqliteAsyncWrapper` (based on the unit tests in `test_sqlite_wrapper.py`)
  - Unit tests covering the async class `SqliteAsyncQueryBuilder` (based on the unit tests in `test_sqlite_builder.py`)

**This covers all notable changes and additions.
There may possibly be other small changes and improvements too :)** --- Pipfile | 8 +- Pipfile.lock | 222 ++++++-- privex/db/__init__.py | 18 +- privex/db/base.py | 859 ++++++++++++++++++++++++++++- privex/db/query/__init__.py | 8 +- privex/db/query/asyncx/__init__.py | 2 + privex/db/query/asyncx/base.py | 367 ++++++++++++ privex/db/query/asyncx/sqlite.py | 149 +++++ privex/db/query/base.py | 2 +- privex/db/query/sqlite.py | 20 +- privex/db/sqlite.py | 295 +++++++++- privex/db/types.py | 46 +- requirements.txt | 4 + setup.py | 5 +- tests/base.py | 99 +++- tests/test_async.py | 241 ++++++++ tests/test_sqlite_builder.py | 99 ++++ 17 files changed, 2346 insertions(+), 98 deletions(-) create mode 100644 privex/db/query/asyncx/__init__.py create mode 100644 privex/db/query/asyncx/base.py create mode 100644 privex/db/query/asyncx/sqlite.py create mode 100644 tests/test_async.py diff --git a/Pipfile b/Pipfile index 81ace53..c3d7293 100644 --- a/Pipfile +++ b/Pipfile @@ -8,7 +8,7 @@ coverage = "*" codecov = "*" pytest = "*" pytest-cov = "*" -privex-helpers = {extras = ["setuppy"],version = ">=2.3.0"} +privex-helpers = {extras = ["setuppy"],version = ">=2.6.0"} sphinx-autobuild = ">=0.7.1" restructuredtext-lint = ">=1.3.0" sphinx-rtd-theme = ">=0.4.3" @@ -20,11 +20,15 @@ python-dotenv = "*" [packages] privex-loghelper = "*" python-dateutil = "*" -privex-helpers = ">=2.3.0" +privex-helpers = ">=2.6.0" pytz = "*" mysqlclient = "*" psycopg2 = "*" typing-extensions = "*" +aiosqlite = "*" +pytest-asyncio = "*" +async-property = "*" +nest-asyncio = "*" [requires] python_version = "3.8" diff --git a/Pipfile.lock b/Pipfile.lock index 413aab7..e44aaeb 100644 --- a/Pipfile.lock +++ b/Pipfile.lock @@ -1,7 +1,7 @@ { "_meta": { "hash": { - "sha256": "72fd668c7ee110f31545ed15ff63b6581b6452bc9fed1196496ad7eeb9340bc7" + "sha256": "5373f0aeec83c9c50ef23ac3a9f9ba200ee4ae8eae93228906c322061c6d7965" }, "pipfile-spec": 6, "requires": { @@ -16,6 +16,43 @@ ] }, "default": { + "aiosqlite": { + "hashes": [ + "sha256:ad84fbd7516ca7065d799504fc41d6845c938e5306d1b7dd960caaeda12e22a9" + ], + "index": "pypi", + "version": "==0.10.0" + }, + "async-property": { + "hashes": [ + "sha256:53826fd45a67d7d6cca3d7abbc0e8ba951f7c7618c830021fbd3675979b0b67d", + "sha256:f1f105009a6216ed9a13031aa13632754ed8a5c2e329fb8f9f2082d83529eacd" + ], + "index": "pypi", + "version": "==0.2.1" + }, + "attrs": { + "hashes": [ + "sha256:08a96c641c3a74e44eb59afb61a24f2cb9f4d7188748e76ba4bb5edfa3cb7d1c", + "sha256:f7b7ce16570fe9965acd6d30101a28f62fb4a7f9e926b3bbc9b61f8b04247e72" + ], + "version": "==19.3.0" + }, + "importlib-metadata": { + "hashes": [ + "sha256:3a8b2dfd0a2c6a3636e7c016a7e54ae04b997d30e69d5eacdca7a6c2221a1402", + "sha256:41e688146d000891f32b1669e8573c57e39e5060e7f5f647aa617cd9a9568278" + ], + "markers": "python_version < '3.8'", + "version": "==1.2.0" + }, + "more-itertools": { + "hashes": [ + "sha256:b84b238cce0d9adad5ed87e745778d20a3f8487d0f0cb8b8a586816c7496458d", + "sha256:c833ef592a0324bcc6a60e48440da07645063c453880c9477ceb22490aec1564" + ], + "version": "==8.0.2" + }, "mysqlclient": { "hashes": [ "sha256:4c82187dd6ab3607150fbb1fa5ef4643118f3da122b8ba31c3149ddd9cf0cb39", @@ -26,13 +63,35 @@ "index": "pypi", "version": "==1.4.6" }, + "nest-asyncio": { + "hashes": [ + "sha256:7d4d7c1ca2aad0e5c2706d0222c8ff006805abfd05caa97e6127c8811d0f6adc", + "sha256:c5e710ef96b1f490f2facb47780314810c7131d695cddf829516c3ffd54beb83" + ], + "index": "pypi", + "version": "==1.2.1" + }, + "packaging": { + "hashes": [ + 
"sha256:28b924174df7a2fa32c1953825ff29c61e2f5e082343165438812f00d3a7fc47", + "sha256:d9551545c6d761f3def1677baf08ab2a3ca17c56879e70fecba2fc4dde4ed108" + ], + "version": "==19.2" + }, + "pluggy": { + "hashes": [ + "sha256:15b2acde666561e1298d71b523007ed7364de07029219b604cf808bfa1c765b0", + "sha256:966c145cd83c96502c3c3868f50408687b38434af77734af1e9ca461a4081d2d" + ], + "version": "==0.13.1" + }, "privex-helpers": { "hashes": [ - "sha256:009df1c637e7002faf7d5866eb63ed2f0bf919b117c172ff59914995d9146d60", - "sha256:bec2bf1a2f7753f162c127251541ebc4e2f3ff70224e906eaa776dc43822e4a1" + "sha256:2707f2904e92c9c86886b75d5a5e5b16308d942d26de8db4f1b2c8ef19a5c645", + "sha256:3aa73691c14c01b7d0ab395916bd256722d2704532affb5716b6063b83f3a384" ], "index": "pypi", - "version": "==2.4.0" + "version": "==2.6.0" }, "privex-loghelper": { "hashes": [ @@ -61,6 +120,35 @@ "index": "pypi", "version": "==2.8.4" }, + "py": { + "hashes": [ + "sha256:64f65755aee5b381cea27766a3a147c3f15b9b6b9ac88676de66ba2ae36793fa", + "sha256:dc639b046a6e2cff5bbe40194ad65936d6ba360b52b3c3fe1d08a82dd50b5e53" + ], + "version": "==1.8.0" + }, + "pyparsing": { + "hashes": [ + "sha256:20f995ecd72f2a1f4bf6b072b63b22e2eb457836601e76d6e5dfcd75436acc1f", + "sha256:4ca62001be367f01bd3e92ecbb79070272a9d4964dce6a48a82ff0b8bc7e683a" + ], + "version": "==2.4.5" + }, + "pytest": { + "hashes": [ + "sha256:63344a2e3bce2e4d522fd62b4fdebb647c019f1f9e4ca075debbd13219db4418", + "sha256:f67403f33b2b1d25a6756184077394167fe5e2f9d8bdaab30707d19ccec35427" + ], + "version": "==5.3.1" + }, + "pytest-asyncio": { + "hashes": [ + "sha256:9fac5100fd716cbecf6ef89233e8590a4ad61d729d1732e0a96b84182df1daaf", + "sha256:d734718e25cfc32d2bf78d346e99d33724deeba774cc4afdf491530c6184b63b" + ], + "index": "pypi", + "version": "==0.10.0" + }, "python-dateutil": { "hashes": [ "sha256:73ebfe9dbf22e832286dafa60473e4cd239f8592f699aa5adaf10050e6e1823c", @@ -84,6 +172,13 @@ ], "version": "==1.13.0" }, + "sniffio": { + "hashes": [ + "sha256:20ed6d5b46f8ae136d00b9dcb807615d83ed82ceea6b2058cecb696765246da5", + "sha256:8e3810100f69fe0edd463d02ad407112542a11ffdc29f67db2bf3771afb87a21" + ], + "version": "==1.1.0" + }, "typing-extensions": { "hashes": [ "sha256:091ecc894d5e908ac75209f10d5b4f118fbdb2eb1ede6a63544054bb1edb41f2", @@ -92,6 +187,20 @@ ], "index": "pypi", "version": "==3.7.4.1" + }, + "wcwidth": { + "hashes": [ + "sha256:3df37372226d6e63e1b1e1eda15c594bca98a22d33a23832a90998faa96bc65e", + "sha256:f4ebe71925af7b40a864553f761ed559b43544f8f71746c2d756c7fe788ade7c" + ], + "version": "==0.1.7" + }, + "zipp": { + "hashes": [ + "sha256:3718b1cbcd963c7d4c5511a8240812904164b7f381b647143a89d3b98f9bcd8e", + "sha256:f06903e9f1f43b12d371004b4ac7b06ab39a44adc747266928ae6debfa7b3335" + ], + "version": "==0.6.0" } }, "develop": { @@ -109,13 +218,6 @@ ], "version": "==0.26.2" }, - "asgiref": { - "hashes": [ - "sha256:7e06d934a7718bf3975acbf87780ba678957b87c7adc056f13b6215d610695a0", - "sha256:ea448f92fc35a0ef4b1508f53a04c4670255a3f33d22a81c8fc9c872036adbe5" - ], - "version": "==3.2.3" - }, "attrs": { "hashes": [ "sha256:08a96c641c3a74e44eb59afb61a24f2cb9f4d7188748e76ba4bb5edfa3cb7d1c", @@ -197,14 +299,6 @@ "index": "pypi", "version": "==4.5.4" }, - "django": { - "hashes": [ - "sha256:6f857bd4e574442ba35a7172f1397b303167dae964cf18e53db5e85fe248d000", - "sha256:d98c9b6e5eed147bc51f47c014ff6826bd1ab50b166956776ee13db5a58804ae" - ], - "index": "pypi", - "version": "==3.0" - }, "docutils": { "hashes": [ "sha256:6c4f696463b79f1fb8ba0c594b63840ebd41f059e92b31957c46b74a4599b6d0", @@ -228,6 +322,14 @@ ], 
"version": "==1.1.0" }, + "importlib-metadata": { + "hashes": [ + "sha256:3a8b2dfd0a2c6a3636e7c016a7e54ae04b997d30e69d5eacdca7a6c2221a1402", + "sha256:41e688146d000891f32b1669e8573c57e39e5060e7f5f647aa617cd9a9568278" + ], + "markers": "python_version < '3.8'", + "version": "==1.2.0" + }, "jinja2": { "hashes": [ "sha256:74320bb91f31270f9551d46522e33af46a80c3d619f4a4bf42b3164d30b5911f", @@ -237,10 +339,10 @@ }, "keyring": { "hashes": [ - "sha256:9b80469783d3f6106bce1d389c6b8b20c8d4d739943b1b8cd0ddc2a45d065f9d", - "sha256:ee3d35b7f1ac3cb69e9a1e4323534649d3ab2fea402738a77e4250c152970fed" + "sha256:a3f71fc0cf6b74e201e70532879ba1d15db25cb2c7407dce52fe52a6d5fc7b66", + "sha256:fc9cadedae35b77141f670f84c10a657147d2e526348698c93dd77f039979729" ], - "version": "==19.3.0" + "version": "==20.0.0" }, "livereload": { "hashes": [ @@ -285,10 +387,10 @@ }, "more-itertools": { "hashes": [ - "sha256:53ff73f186307d9c8ef17a9600309154a6ae27f25579e80af4db8f047ba14bc2", - "sha256:a0ea684c39bc4315ba7aae406596ef191fd84f873d2d2751f84d64e81a7a2d45" + "sha256:b84b238cce0d9adad5ed87e745778d20a3f8487d0f0cb8b8a586816c7496458d", + "sha256:c833ef592a0324bcc6a60e48440da07645063c453880c9477ceb22490aec1564" ], - "version": "==8.0.0" + "version": "==8.0.2" }, "packaging": { "hashes": [ @@ -325,11 +427,11 @@ }, "privex-helpers": { "hashes": [ - "sha256:009df1c637e7002faf7d5866eb63ed2f0bf919b117c172ff59914995d9146d60", - "sha256:bec2bf1a2f7753f162c127251541ebc4e2f3ff70224e906eaa776dc43822e4a1" + "sha256:2707f2904e92c9c86886b75d5a5e5b16308d942d26de8db4f1b2c8ef19a5c645", + "sha256:3aa73691c14c01b7d0ab395916bd256722d2704532affb5716b6063b83f3a384" ], "index": "pypi", - "version": "==2.4.0" + "version": "==2.6.0" }, "privex-loghelper": { "hashes": [ @@ -365,7 +467,6 @@ "sha256:63344a2e3bce2e4d522fd62b4fdebb647c019f1f9e4ca075debbd13219db4418", "sha256:f67403f33b2b1d25a6756184077394167fe5e2f9d8bdaab30707d19ccec35427" ], - "index": "pypi", "version": "==5.3.1" }, "pytest-cov": { @@ -376,6 +477,14 @@ "index": "pypi", "version": "==2.8.1" }, + "python-dateutil": { + "hashes": [ + "sha256:73ebfe9dbf22e832286dafa60473e4cd239f8592f699aa5adaf10050e6e1823c", + "sha256:75bb3f31ea686f1197762692a9ee6a7550b59fc6ca3a1f4b5d7e32fb98e2da2a" + ], + "index": "pypi", + "version": "==2.8.1" + }, "python-dotenv": { "hashes": [ "sha256:debd928b49dbc2bf68040566f55cdb3252458036464806f4094487244e2a4093", @@ -394,21 +503,19 @@ }, "pyyaml": { "hashes": [ - "sha256:0113bc0ec2ad727182326b61326afa3d1d8280ae1122493553fd6f4397f33df9", - "sha256:01adf0b6c6f61bd11af6e10ca52b7d4057dd0be0343eb9283c878cf3af56aee4", - "sha256:5124373960b0b3f4aa7df1707e63e9f109b5263eca5976c66e08b1c552d4eaf8", - "sha256:5ca4f10adbddae56d824b2c09668e91219bb178a1eee1faa56af6f99f11bf696", - "sha256:7907be34ffa3c5a32b60b95f4d95ea25361c951383a894fec31be7252b2b6f34", - "sha256:7ec9b2a4ed5cad025c2278a1e6a19c011c80a3caaac804fd2d329e9cc2c287c9", - "sha256:87ae4c829bb25b9fe99cf71fbb2140c448f534e24c998cc60f39ae4f94396a73", - "sha256:9de9919becc9cc2ff03637872a440195ac4241c80536632fffeb6a1e25a74299", - "sha256:a5a85b10e450c66b49f98846937e8cfca1db3127a9d5d1e31ca45c3d0bef4c5b", - "sha256:b0997827b4f6a7c286c01c5f60384d218dca4ed7d9efa945c3e1aa623d5709ae", - "sha256:b631ef96d3222e62861443cc89d6563ba3eeb816eeb96b2629345ab795e53681", - "sha256:bf47c0607522fdbca6c9e817a6e81b08491de50f3766a7a0e6a5be7905961b41", - "sha256:f81025eddd0327c7d4cfe9b62cf33190e1e736cc6e97502b3ec425f574b3e7a8" - ], - "version": "==5.1.2" + "sha256:0e7f69397d53155e55d10ff68fdfb2cf630a35e6daf65cf0bdeaf04f127c09dc", + 
"sha256:2e9f0b7c5914367b0916c3c104a024bb68f269a486b9d04a2e8ac6f6597b7803", + "sha256:35ace9b4147848cafac3db142795ee42deebe9d0dad885ce643928e88daebdcc", + "sha256:38a4f0d114101c58c0f3a88aeaa44d63efd588845c5a2df5290b73db8f246d15", + "sha256:483eb6a33b671408c8529106df3707270bfacb2447bf8ad856a4b4f57f6e3075", + "sha256:4b6be5edb9f6bb73680f5bf4ee08ff25416d1400fbd4535fe0069b2994da07cd", + "sha256:7f38e35c00e160db592091751d385cd7b3046d6d51f578b29943225178257b31", + "sha256:8100c896ecb361794d8bfdb9c11fce618c7cf83d624d73d5ab38aef3bc82d43f", + "sha256:c0ee8eca2c582d29c3c2ec6e2c4f703d1b7f1fb10bc72317355a746057e7346c", + "sha256:e4c015484ff0ff197564917b4b4246ca03f411b9bd7f16e02a2f586eb48b6d04", + "sha256:ebc4ed52dcc93eeebeae5cf5deb2ae4347b3a81c3fa12b0b8c976544829396a4" + ], + "version": "==5.2" }, "readme-renderer": { "hashes": [ @@ -452,6 +559,13 @@ ], "version": "==1.13.0" }, + "sniffio": { + "hashes": [ + "sha256:20ed6d5b46f8ae136d00b9dcb807615d83ed82ceea6b2058cecb696765246da5", + "sha256:8e3810100f69fe0edd463d02ad407112542a11ffdc29f67db2bf3771afb87a21" + ], + "version": "==1.1.0" + }, "snowballstemmer": { "hashes": [ "sha256:209f257d7533fdb3cb73bdbd24f436239ca3b2fa67d56f6ff88e86be08cc5ef0", @@ -525,13 +639,6 @@ ], "version": "==1.1.3" }, - "sqlparse": { - "hashes": [ - "sha256:40afe6b8d4b1117e7dff5504d7a8ce07d9a1b15aeeade8a2d10f130a834f8177", - "sha256:7c3dca29c022744e95b547e867cee89f4fce4373f3549ccd8797d8eb52cdb873" - ], - "version": "==0.3.0" - }, "tornado": { "hashes": [ "sha256:349884248c36801afa19e342a77cc4458caca694b0eda633f5878e458a44cb2c", @@ -546,10 +653,10 @@ }, "tqdm": { "hashes": [ - "sha256:156a0565f09d1f0ef8242932a0e1302462c93827a87ba7b4423d90f01befe94c", - "sha256:c0ffb55959ea5f3eaeece8d2db0651ba9ced9c72f40a6cce3419330256234289" + "sha256:895796ea8df435b6f502bf122f2b2034a3d48e6d8ff52175606ac1051b0e3e12", + "sha256:e405d16c98fcf30725d0c9d493ed07302a18846b5452de5253030ccd18996f87" ], - "version": "==4.40.0" + "version": "==4.40.1" }, "twine": { "hashes": [ @@ -591,6 +698,13 @@ "sha256:f4da1763d3becf2e2cd92a14a7c920f0f00eca30fdde9ea992c836685b9faf28" ], "version": "==0.33.6" + }, + "zipp": { + "hashes": [ + "sha256:3718b1cbcd963c7d4c5511a8240812904164b7f381b647143a89d3b98f9bcd8e", + "sha256:f06903e9f1f43b12d371004b4ac7b06ab39a44adc747266928ae6debfa7b3335" + ], + "version": "==0.6.0" } } } diff --git a/privex/db/__init__.py b/privex/db/__init__.py index 3f6ee32..c89ea26 100644 --- a/privex/db/__init__.py +++ b/privex/db/__init__.py @@ -43,10 +43,13 @@ import logging import warnings +import nest_asyncio from privex.db.base import GenericDBWrapper, CursorManager from privex.db.types import GenericCursor, GenericConnection from privex.db.query.base import BaseQueryBuilder, QueryMode +nest_asyncio.apply() + __all__ = [ 'GenericDBWrapper', 'CursorManager', 'GenericCursor', 'GenericConnection', 'BaseQueryBuilder', 'QueryMode', 'name', 'VERSION', @@ -76,7 +79,18 @@ from privex.db.sqlite import SqliteWrapper __all__ += ['SqliteWrapper'] except ImportError: - log.warning("Failed to import privex.db.sqlite (missing Python SQLite API?)") + log.warning("Failed to import privex.db.sqlite.SqliteWrapper (missing Python SQLite API?)") + +try: + from privex.db.sqlite import SqliteAsyncWrapper + from privex.db.query.asyncx import SqliteAsyncQueryBuilder + + __all__ += ['SqliteAsyncWrapper', 'SqliteAsyncQueryBuilder'] +except ImportError: + log.warning( + "Failed to import privex.db.sqlite.SqliteAsyncWrapper and/or privex.db.query.asyncx.SqliteAsyncQueryBuilder " + "(missing aiosqlite library?)" + ) def 
_setup_logging(level=logging.WARNING): @@ -99,7 +113,7 @@ def _setup_logging(level=logging.WARNING): log = _setup_logging() name = 'db' -VERSION = '0.8.0' +VERSION = '0.9.0' diff --git a/privex/db/base.py b/privex/db/base.py index 9f54bb1..fed38e7 100644 --- a/privex/db/base.py +++ b/privex/db/base.py @@ -23,12 +23,16 @@ """ import abc +import asyncio from abc import ABC -from typing import List, Tuple, Set, Union, Any, Optional, TypeVar, Generic, Dict -from privex.helpers import dictable_namedtuple, is_namedtuple, empty, DictObject +from typing import List, Tuple, Set, Union, Any, Optional, TypeVar, Generic, Dict, Iterable, Coroutine + +from async_property import async_property +from privex.helpers import dictable_namedtuple, is_namedtuple, empty, DictObject, awaitable, run_sync, aobject import logging -from privex.db.types import GenericCursor, GenericConnection +from privex.db.types import GenericCursor, GenericConnection, GenericAsyncCursor, CoroNone, GenericAsyncConnection, \ + BOOL_CORO, INT_CORO, TUPD_OPT_CORO, TUPD_CORO, DICT_CORO log = logging.getLogger(__name__) DBExecution = dictable_namedtuple('DBExecution', 'query result execute_result cursor') @@ -36,6 +40,13 @@ CUR = TypeVar('CUR') +def _should_zip(_res, query_mode='dict', auto_zip=True): + # auto_zip = self.AUTO_ZIP_COLS + is_dict_tuple = not isinstance(_res, dict) and not is_namedtuple(_res) + return auto_zip and (is_dict_tuple and query_mode == 'dict') + + + class CursorManager(GenericCursor, Generic[CUR], object): """ Not all database API's support context management with their cursors, so this class wraps a given database @@ -157,6 +168,133 @@ def __exit__(self, exc_type, exc_val, exc_tb): # pass +# noinspection DuplicatedCode +class AsyncCursorManager(GenericAsyncCursor, Generic[CUR], object): + """ + Async version of :class:`.CursorManager` + + Not all database API's support context management with their cursors, so this class wraps a given database + cursor objects, and provides context management methods :meth:`.__enter__` and :meth:`.__exit__` + """ + + _cursor: Union[CUR, GenericAsyncCursor, Any] + """The actual cursor object this class is wrapping""" + _cursor_id: int + """The object ID of the cursor instance - for context manager nesting tracking""" + _close_callback: Optional[CoroNone] + """The function/method to callback to when the cursor is closed""" + + _active_cursor_ids = set() + """Object IDs of cursors which have a responsible context managing CursorManager""" + can_cleanup: bool + """This becomes True if this is the **first** context manager instance for a cursor""" + is_context_manager: bool + """``True`` if this class is being used in a ``with`` statement, otherwise ``False``.""" + + def __init__(self, cursor: CUR, close_callback: Optional[CoroNone] = None): + """ + Initialise the cursor manager. + + :param CUR|GenericCursor cursor: A database cursor object to wrap + :param Coroutine close_callback: If specified, this awaitable callable (function/method) will be called BEFORE + and AFTER the cursor is closed, with the kwargs ``state='BEFORE_CLOSE'`` and + ``state='AFTER_CLOSE'`` respectively. 
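+
+        Example (an illustrative sketch - ``raw_cur`` stands in for any async cursor, e.g. one from ``aiosqlite``)::
+
+            >>> async def on_close(state=None):
+            ...     print('cursor close state:', state)
+            >>> cur = AsyncCursorManager(raw_cur, close_callback=on_close)
+            >>> await cur.close()   # on_close is awaited with state='BEFORE_CLOSE', then state='AFTER_CLOSE'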
+ + """ + self._cursor = cursor + self._cursor_id = id(cursor) + self.can_cleanup = False + self.is_context_manager = False + self._close_callback = close_callback + + @property + def rowcount(self) -> int: + return self._cursor.rowcount + + @property + def description(self): + return self._cursor.description + + @property + def lastrowid(self): + return self._cursor.lastrowid + + async def execute(self, query: str, params: Iterable = None, *args, **kwargs) -> Any: + args = [query, params] + list(args) if params is not None else [query] + list(args) + return await self._cursor.execute(query, *args, **kwargs) + + async def fetchone(self, *args, **kwargs) -> Union[tuple, list, dict, set]: + return await self._cursor.fetchone(*args, **kwargs) + + async def fetchall(self, *args, **kwargs) -> Iterable: + return await self._cursor.fetchall(*args, **kwargs) + + async def fetchmany(self, *args, **kwargs) -> Iterable: + return await self._cursor.fetchmany(*args, **kwargs) + + async def close(self, *args, **kwargs): + if self._close_callback is not None: + await self._close_callback(state='BEFORE_CLOSE') + _closed = await self._cursor.close(*args, **kwargs) + if self._close_callback is not None: + await self._close_callback(state='AFTER_CLOSE') + self._close_callback = None + return _closed + + def _cleanup(self, *args, **kwargs): + if self._cursor is None: return + try: + _closed = self.close(*args, **kwargs) + self._cursor = None + del self._cursor + return _closed + except (Exception, BaseException): + pass + try: + del self._cursor + except AttributeError: + pass + return None + + def __getattr__(self, item): + try: + return object.__getattribute__(self, item) + except AttributeError: + return getattr(self._cursor, item) + + def __setattr__(self, key, value): + # if hasattr(self, key): + try: + object.__setattr__(self, key, value) + except AttributeError: + setattr(self._cursor, key, value) + + def __next__(self): + return self._cursor.__next__() + + def __iter__(self): + return self._cursor.__iter__() + + def __getitem__(self, item): + return self._cursor.__getitem__(item) + + def __setitem__(self, item, value): + return self._cursor.__setitem__(item, value) + + def __enter__(self): + self.is_context_manager = True + self.can_cleanup = self._cursor_id not in self._active_cursor_ids + if self.can_cleanup: + self._active_cursor_ids.add(self._cursor_id) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.can_cleanup: + self._cleanup() + self._active_cursor_ids.remove(self._cursor_id) + + + def cursor_to_dict(cur: Union[GenericCursor, Any]) -> DictObject: """ Convert a cursor object into a dictionary (:class:`.DictObject`), allowing the original cursor to be safely closed @@ -449,10 +587,10 @@ def query(self, sql: str, *params, fetch='all', **kwparams) -> Tuple[Optional[it handling the returned results. 
""" - def _should_zip(_res): - auto_zip = self.AUTO_ZIP_COLS - is_dict_tuple = not isinstance(_res, dict) and not is_namedtuple(_res) - return auto_zip and (is_dict_tuple and query_mode == 'dict') + # def _should_zip(_res): + # auto_zip = self.AUTO_ZIP_COLS + # is_dict_tuple = not isinstance(_res, dict) and not is_namedtuple(_res) + # return auto_zip and (is_dict_tuple and query_mode == 'dict') cursor_name = kwparams.pop('cursor_name', None) query_mode = kwparams.pop('query_mode', self.query_mode) @@ -466,7 +604,7 @@ def _should_zip(_res): res = c.fetchone() if res is None: return None, res_exec, c - if _should_zip(res): + if _should_zip(res, query_mode=query_mode, auto_zip=self.AUTO_ZIP_COLS): res = self._zip_cols(c, tuple(res)) elif fetch == 'no': res = None @@ -857,3 +995,708 @@ def __exit__(self, exc_type, exc_val, exc_tb): self._conn.close() del self._conn self._conn = None + + +class GenericAsyncDBWrapper(GenericDBWrapper): + + def __init__(self, db=None, connector_func: callable = None, **kwargs): + """ + Initialise the database wrapper class. + + This constructor sets ``_conn`` to None, and sets up various instance variables such as ``connector_func`` + and ``query_mode``. + + While you can set various instance variables such as ``query_mode`` via this constructor, if you're inheriting + this class, it's recommended that you override the ``DEFAULT_`` static class attributes to your preference. + + + :param str db: The name / path of the database to connect to + :param callable connector_func: A function / method / lambda which returns a database connection object + which acts like :class:`.GenericConnection` + :key bool auto_create_schema: (Default: ``True``) If True, call :meth:`.create_schemas` during constructor. + :key list connector_args: A list of arguments to be passed as positional args to ``connector_func`` + :key dict connector_kwargs: A dict of arguments to passed as keyword args to ``connector_func`` + :key str query_mode: Either ``flat`` (return tuples) or ``dict`` (return dicts of column:value) + Controls how results are returned from query functions, + e.g. :py:meth:`.query` and :py:meth:`.fetch` + :key str table_query: The query used to check for existence of a table. The query should take one prepared + statement argument (the table name to check for), and the first column returned + must be named ``table_count`` - an integer containing how many tables were found under + the given name (usually just 0 if not found, 1 if found). + :key str table_list_query: The query used to obtain a list of tables in the database. + The query should take no arguments, and return rows containing one column each, + ``name`` - the name of the table. 
+ """ + self.db = db + self._conn = None + self._cursor = None + self._execution_log = [] + self._cursors = [] + self.connector_func = connector_func + # auto_create_schema = kwargs.pop('auto_create_schema', True) + self.enable_execution_log = kwargs.pop('enable_execution_log', self.DEFAULT_ENABLE_EXECUTION_LOG) + self.connector_args = kwargs.pop('connector_args', []) + self.connector_kwargs = kwargs.pop('connector_kwargs', {}) + self.query_mode = kwargs.pop('query_mode', self.DEFAULT_QUERY_MODE) + self.table_query = kwargs.pop('table_query', self.DEFAULT_TABLE_QUERY) + self.table_list_query = kwargs.pop('table_list_query', self.DEFAULT_TABLE_LIST_QUERY) + + # if auto_create_schema: + # # res = asyncio.run(self.create_schemas) + # res = self.create_schemas() + # log.debug('Create schema result: "%s"', res) + + @awaitable + def make_connection(self, *args, **kwargs) \ + -> Union[GenericAsyncConnection, Coroutine[Any, Any, GenericAsyncConnection]]: + """ + Creates a database connection using :py:attr:`.connector_func`, passing all arguments/kwargs directly + to the connector function. + + :return GenericConnection conn: A database connection object, which should implement at least the basic + connection object methods as specified in the Python DB API (PEP 249), + and in the Protocol type class :class:`.GenericConnection` + """ + return self._make_connection(*args, **kwargs) + + async def _make_connection(self, *args, **kwargs) -> GenericAsyncConnection: + await_conn = kwargs.pop('await_conn', True) + conn = self.connector_func(*args, **kwargs) + if asyncio.iscoroutine(conn) and await_conn: + return await conn + return conn + + @async_property + async def conn(self) -> GenericAsyncConnection: + """Get or create a database connection""" + # c = run_sync(self._get_connection) + # return run_sync(c) if asyncio.iscoroutine(c) else c + return await self._get_connection() + + @awaitable + def get_connection(self, new=False) -> Union[GenericAsyncConnection, Coroutine[Any, Any, GenericAsyncConnection]]: + """Get or create a database connection""" + return self._get_connection(new=new) + + async def _get_connection(self, new=False, await_conn=True) -> GenericAsyncConnection: + """Get or create a database connection""" + + if self._conn is None or new: + conn = await self._make_connection(*self.connector_args, **self.connector_kwargs, await_conn=await_conn) + if new: + return conn + self._conn = conn + + return self._conn + + @awaitable + def table_exists(self, table: str) -> BOOL_CORO: + return self._table_exists(table=table) + + async def _table_exists(self, table: str) -> bool: + """ + Returns ``True`` if the table ``table`` exists in the database, otherwise ``False``. + + + >>> GenericDBWrapper().table_exists('some_table') + True + >>> GenericDBWrapper().table_exists('other_table') + False + + + :param str table: The table to check for existence. + :return bool exists: ``True`` if the table ``table`` exists in the database, otherwise ``False``. + """ + + res = await self.fetchone(self.table_query, [table]) + if isinstance(res, dict): + return res['table_count'] == 1 + else: + return res[0] == 1 + + @awaitable + def list_tables(self) -> Union[List[str], Coroutine[Any, Any, List[str]]]: + return self._list_tables() + + async def _list_tables(self) -> List[str]: + """ + Get a list of tables present in the current database. 
+ + Example:: + + >>> GenericDBWrapper().list_tables() + ['sqlite_sequence', 'nodes', 'node_api', 'node_failures'] + + + :return List[str] tables: A list of tables in the database + """ + res = await self.fetchall(self.table_list_query) + if len(res) < 1: + return [] + if isinstance(res[0], dict): + return [r['name'] for r in res] + else: + return [r[0] for r in res] + + _Q_OUT_TYPE = Tuple[Optional[iter], Any, GenericAsyncCursor] + + @awaitable + def query(self, sql: str, *params, fetch='all', **kwparams) -> Union[_Q_OUT_TYPE, Coroutine[Any, Any, _Q_OUT_TYPE]]: + return self._query(sql, *params, fetch=fetch, **kwparams) + + async def _query(self, sql: str, *params, fetch='all', **kwparams) -> _Q_OUT_TYPE: + """ + + Create an instance of your database wrapper: + + >>> db = GenericDBWrapper() + + **Querying with prepared SQL queries and returning a single row**:: + + >>> res, res_exec, cur = db.query("SELECT * FROM users WHERE first_name = ?;", ['John'], fetch='one') + >>> res + (12, 'John', 'Doe', '123 Example Road',) + >>> cur.close() + + **Querying with plain SQL queries, using query_mode, handling an iterator result, and using the cursor** + + If your database API returns rows as ``tuple``s or ``list``s, you can use ``query_mode='dict'`` (or set + :py:attr:`.query_mode` in the constructor) to convert any row results into dictionaries which map + each column to their values. + + >>> res, _, cur = db.query("SELECT * FROM users;", fetch='all', query_mode='dict') + + When querying with ``fetch='all'``, depending on your database API, ``res`` may be an iterator, and cannot + be accessed via an index like ``res[0]``. + + You should make sure to iterate the rows using a ``for`` loop:: + + >>> for row in res: + ... print(row['first_name'], ':', row) + John : {'first_name': 'John', 'last_name': 'Doe', 'id': 12} + Dave : {'first_name': 'Dave', 'last_name': 'Johnson', 'id': 13} + Aaron : {'first_name': 'Aaron', 'last_name': 'Swartz', 'id': 14} + + Or, if the result object is a generator, then you can auto-iterate the results into a list + using ``x = list(res)``:: + + >>> rows = list(res) + >>> rows[0] + {'first_name': 'John', 'last_name': 'Doe', 'id': 12} + + Using the returned cursor (third return item), we can access various metadata about our query. Note that + cursor objects vary between database APIs, and not all methods/attributes may be available, or may + return different data than shown below:: + + >>> cur.description # cursor.description often contains the column names matching the query columns + (('id', None, None, None, None, None, None), ('first_name', None, None, None, None, None, None), + ('last_name', None, None, None, None, None, None)) + + >>> _, _, cur = db.query("INSERT INTO users (first_name, last_name) VALUES ('a', 'b')", fetch='no') + >>> cur.rowcount # cursor.rowcount tells us how many rows were affected by a query + 1 + >>> cur.lastrowid # cursor.lastrowid tells us the ID of the last row we inserted with this cursor + 3 + + + + + :param str sql: An SQL query to execute + :param params: Any positional arguments other than ``sql`` will be passed to ``cursor.execute``. + :param str fetch: Fetch mode. Either ``all`` (return ``cursor.fetchall()`` as first return arg), + ``one`` (return ``cursor.fetchone()``), or ``no`` (do not fetch. first return arg is None). 
+ :param kwparams: Any keyword arguments that aren't specified as parameters / keyword args for this method + will be forwarded to ``cursor.execute`` + + :key GenericCursor cursor: Use this specific cursor instead of automatically obtaining one + :key cursor_name: If your database API supports named cursors (e.g. PostgreSQL), then you may + specify ``cursor_name`` as a keyword argument to use a named cursor for this query. + :key query_mode: Either ``flat`` (fetch results as they were originally returned from the DB), or + ``dict`` (use :meth:`._zip_cols` to convert tuple/list rows into dicts mapping col:value). + + :return iter results: (tuple item 1) An iterable such as a generator, or storage type e.g. ``list`` or ``dict``. + **NOTE:** If you've set ``fetch='all'``, depending on your database adapter, this + may be a generator or other form of iterator that cannot be directly accessed via index + (i.e. ``res[123]``). Instead you must iterate it with a ``for`` loop, or cast it into + a list/tuple to automatically iterate it into an indexed object, e.g. ``list(res)` + + :return Any res_exec: (tuple item 2) The object returned from running ``cur.execute(sql, *params, **kwparams)``. + This may be a cursor, but may also vary based on database API. + + :return GenericCursor cur: (tuple item 3) The cursor that was used to execute and fetch your query. To allow + for use with server side cursors, the cursor is NOT closed automatically. + To avoid stale cursors, it's best to run ``cur.close()`` when you're done with + handling the returned results. + + """ + # cursor_name = kwparams.pop('cursor_name', None) + query_mode = kwparams.pop('query_mode', self.query_mode) + # c = kwparams.pop('cursor', await self.get_cursor(cursor_name)) + # res_exec = await c.execute(sql, *params, **kwparams) + # res, c = None, None + res, c = await self.execute(sql, *params, fetch=fetch, cleanup_cursor=False, **kwparams) + + if fetch == 'all': + if _should_zip(res, query_mode, auto_zip=self.AUTO_ZIP_COLS): + res = [self._zip_cols(c, r) for r in res] + elif fetch == 'one': + if res is None: + return None, c, c + if _should_zip(res, query_mode, auto_zip=self.AUTO_ZIP_COLS): + res = self._zip_cols(c, tuple(res)) + elif fetch == 'no': + res = None + else: + raise AttributeError("The parameter 'fetch' must be either 'all', 'one' or 'no'.") + if self.enable_execution_log: + self._execution_log += [DBExecution(sql, res, c, c)] + return res, c, c + + async def execute(self, query: str, *params: Iterable, fetch='all', **kwargs) \ + -> Tuple[Iterable, Union[GenericAsyncCursor, DictObject]]: + + cursor_name = kwargs.pop('cursor_name', None) + cleanup_cursor = kwargs.pop('cleanup_cursor', True) + _cur: GenericAsyncCursor = kwargs.pop('cursor', await self.get_cursor(cursor_name=cursor_name)) + res = None + if not cleanup_cursor: + await _cur.execute(query, *params) + if fetch == 'all': res = await _cur.fetchall() + if fetch == 'one': res = await _cur.fetchone() + return res, _cur + + async with _cur as cur: # type: GenericAsyncCursor + await cur.execute(query, *params) + if fetch == 'all': res = await cur.fetchall() + if fetch == 'one': res = await cur.fetchone() + + return res, cursor_to_dict(cur) + + @awaitable + def action(self, sql: str, *params, **kwparams) -> INT_CORO: + return self._action(sql, *params, **kwparams) + + async def _action(self, sql: str, *params, **kwparams) -> int: + """ + Use :meth:`.action` as a simple alias method for running "action" queries which don't return results, only + affected row counts. 
+ + For example ``INSERT``, ``UPDATE``, ``CREATE`` etc. queries. + + This method calls :meth:`.query` with ``fetch='no'``, saves the row count into a variable, closes the cursor, + then returns the affected row count as an integer. + + >>> db = GenericDBWrapper('SomeDB') + >>> rows_affected = db.action("DELETE FROM users WHERE first_name LIKE 'John%';") + >>> rows_affected + 4 + + :param str sql: An SQL query to execute on the current DB, as a string. + :param params: Extra arguments will be passed through to ``cursor.execute(sql, *params, **kwparams)`` + :param kwparams: Extra arguments will be passed through to ``cursor.execute(sql, *params, **kwparams)`` + :return int row_count: Number of rows affected + """ + # _, _, cur = self.query(sql, *params, fetch='no', **kwparams) + # async with await self.get_cursor() as cur: + res, cur = await self.execute(sql, *params, fetch='no', **kwparams) + row_count = int(cur.rowcount) + # cur.close() + return row_count + + @awaitable + def fetch(self, sql: str, *params, fetch='all', **kwparams) -> Union[TUPD_CORO, TUPD_OPT_CORO]: + return self._fetch(sql, *params, fetch=fetch, **kwparams) + + async def _fetch(self, sql: str, *params, fetch='all', **kwparams) \ + -> Optional[Union[dict, tuple, List[dict], List[tuple]]]: + """ + Similar to :py:meth:`.query` but only returns the fetch results, not the execution object nor cursor. + + Example Usage (default query mode):: + >>> s = GenericDBWrapper() + >>> user = s.fetch("SELECT * FROM users WHERE id = ?;", [123], fetch='one') + >>> user + (123, 'john', 'doe',) + + Example Usage (dict query mode):: + + >>> s.query_mode = 'dict' # Or s = SqliteWrapper(query_mode='dict') + >>> res = s.fetch("SELECT * FROM users WHERE id = ?;", [123], fetch='one') + >>> res + {'id': 123, 'first_name': 'john', 'last_name': 'doe'} + + + :param str fetch: Either ``'all'`` or ``'one'`` - controls whether the result is fetched with + :meth:`GenericCursor.fetchall` or :meth:`GenericCursor.fetchone` + :param str sql: An SQL query to execute on the current DB, as a string. + :param params: Extra arguments will be passed through to ``cursor.execute(sql, *params, **kwparams)`` + :param kwparams: Extra arguments will be passed through to ``cursor.execute(sql, *params, **kwparams)`` + :return: + """ + res, _, cur = await self.query(sql, *params, fetch=fetch, **kwparams) + return res + + @awaitable + def fetchone(self, sql: str, *params, **kwparams) -> Optional[Union[dict, tuple]]: + """Alias for :meth:`.fetch` with ``fetch='one'``""" + return self.fetch(sql, *params, fetch='one', **kwparams) + + @awaitable + def fetchall(self, sql: str, *params, **kwparams) -> Optional[Union[List[dict], List[tuple]]]: + """Alias for :meth:`.fetch` with ``fetch='all'``""" + return self.fetch(sql, *params, fetch='all', **kwparams) + + @awaitable + def insert(self, _table: str, _cursor: GenericAsyncCursor = None, **fields): + return self._insert(_table, _cursor=_cursor, **fields) + + async def _insert(self, _table: str, _cursor: GenericAsyncCursor = None, **fields) \ + -> Union[DictObject, GenericAsyncCursor]: + """ + Builds and executes an insert query into the table ``_table`` using the keyword arguments for column names + and values. 
+ + >>> db = GenericDBWrapper(db='SomeDB') + >>> cur = db.insert('users', first_name='John', last_name='Doe', phone='+1-800-123-4567') + >>> cur.lastrowid + 15 + + :param str _table: The table to insert into + :param GenericCursor _cursor: Optionally, specify a cursor to use, instead of the default :attr:`.cursor` + :param fields: Keyword args mapping column names to values + :return DictObject cur: If no custom cursor was specified, the cursor used to execute the query is converted + into a :class:`.DictObject` before closing it, then the dict is returned. + :return GenericAsyncCursor cur: If a custom cursor (``_cursor``) was specified, then that cursor will NOT be + auto-closed, and the original cursor instance will be returned. + """ + query = self._build_insert_query(_table, list(fields.keys())) + + if _cursor is not None: # If a custom cursor is passed, execute with that cursor and return the cursor. + await _cursor.execute(query, list(fields.values())) + return _cursor + + # If no custom cursor was passed, use self.cursor, and use cursor_to_dict to preserve the cursor data after + # the cursor is closed. + # async with await self.cursor as _cursor: + # await _cursor.execute(query, list(fields.values())) + # res = cursor_to_dict(_cursor) + res = await self.execute(query, list(fields.values()), fetch='no') + return res[1] + + def _build_insert_query(self, table, fields: list, placeholder: str = None) -> str: + """ + Builds an SQL ``INSERT`` query based on the passed table name ``table``, and a list of column names to + insert into (``fields``). + + :param str table: The table to insert into. + :param list fields: A ``dict`` mapping columns to values, e.g. ``dict(username='john', age=39)`` + :param str placeholder: The value placeholder. If None, defaults to :attr:`.DEFAULT_PLACEHOLDER`. + :return str query: The built SQL query. + """ + placeholder = self.DEFAULT_PLACEHOLDER if placeholder is None else placeholder + sql_placeholders = ', '.join([placeholder for _ in fields]) + return f"INSERT INTO {table} ({', '.join(fields)}) VALUES ({sql_placeholders});" + + @staticmethod + def _zip_cols(cursor: GenericCursor, row: iter) -> DictObject: + """Combine column names from ``cursor`` with the row values ``row`` into a singular dict ``res``""" + # If the passed row is already a dictionary, just make sure the row is casted to a real dict and return it. + if isinstance(row, dict): + return DictObject(row) + # combine the column names with the row data to create a dictionary of column_name:value + col_names = list(map(lambda x: x[0], cursor.description)) + res = DictObject(zip(col_names, row)) + return res + + async def get_cursor(self, cursor_name=None, cursor_class=None, *args, **kwargs) \ + -> Union[AsyncCursorManager, GenericAsyncCursor]: + """ + Create and return a new database cursor object, by default the cursor will be wrapped with + :class:`.CursorManager` to ensure context management (``with`` statements) works regardless of whether + the database API supports context managing cursors (e.g. :py:mod:`sqlite` does not support cursor contexts). + + For sub-classes, you should override :meth:`._get_cursor`, which returns an actual native DB cursor. 
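+
+        Example (a sketch - ``db`` is assumed to be an initialised async wrapper instance)::
+
+            >>> cur = await db.get_cursor()
+            >>> await cur.execute('SELECT 1;')
+            >>> row = await cur.fetchone()
+            >>> await cur.close()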
+ + :param str cursor_name: (If DB API supports it) The name for this cursor + :param type cursor_class: (If DB API supports it) The cursor class to use + :key bool cursor_mgr: (Default: ``True``) If True, wrap the returned cursor with :class:`.CursorManager` + :key callable close_callback: (Default: ``None``) Passed onto :class:`.CursorManager` + :return GenericCursor cursor: A cursor object which should implement at least the basic Python DB API Cursor + functionality as specified in :class:`.GenericCursor` ((PEP 249) + """ + cursor_mgr = kwargs.pop('cursor_mgr', True) + close_callback = kwargs.pop('close_callback', None) + c = await self._get_cursor(cursor_name=cursor_name, cursor_class=cursor_class, *args, **kwargs) + self._cursors += [c] + return AsyncCursorManager(c, close_callback=close_callback) if cursor_mgr else c + + @abc.abstractmethod + async def _get_cursor(self, cursor_name=None, cursor_class=None, *args, **kwargs): + """ + Create and return a new database cursor object. + + It's recommended to override this method if you're inheriting from this class, as this Generic version of + ``_get_cursor`` does not make use of ``cursor_name`` nor ``cursor_class``. + + :param str cursor_name: (If DB API supports it) The name for this cursor + :param type cursor_class: (If DB API supports it) The cursor class to use + :return GenericCursor cursor: A cursor object which should implement at least the basic Python DB API Cursor + functionality as specified in :class:`.GenericCursor` ((PEP 249) + """ + c = await self.conn + return await c.cursor(*args, **kwargs) + + def _close_callback(self, state=None): + log.debug("%s._close_callback was called with state: %s", self.__class__.__name__, state) + if state == 'AFTER_CLOSE': + self._cursor = None + + # noinspection PyTypeChecker + @async_property + async def cursor(self) -> Union[AsyncCursorManager, GenericAsyncCursor]: + if self._cursor is None: + # self._cursor = CursorManager(self.get_cursor(), close_callback=self._close_callback) + self._cursor = self.get_cursor(close_callback=self._close_callback) + if asyncio.iscoroutine(self._cursor): + self._cursor = await self._cursor + return self._cursor + + @awaitable + def close_cursor(self): + return self._close_cursor() + + async def _close_cursor(self): + if self._cursor is None: + return + try: + await self._cursor.close() + except (BaseException, Exception): + log.exception("close_cursor was called but exception was raised...") + self._cursor = None + + @awaitable + def create_schemas(self, *tables) -> DICT_CORO: + return self._create_schemas(*tables) + + async def _create_schemas(self, *tables) -> dict: + """ + Create all tables listed in :py:attr:`.SCHEMAS` if they don't already exist. + + :param str tables: If one or more table names are specified, then only these tables will be affected. + :return dict result: ``dict_keys(['executions', 'created', 'skipped', 'tables_created', 'tables_skipped'])`` + """ + results = dict(executions=[], created=0, skipped=0, tables_created=[], tables_skipped=[]) + tb_count = len(self.SCHEMAS) + tb_exists = 0 + for table, schema in self.SCHEMAS: + tb_key = f"{self.db}:{table}" + if tb_key in self.tables_created: + log.debug('Found key %s in table_created.', tb_key) + tb_exists += 1 + results['tables_skipped'] += [table] + + if tb_exists >= tb_count: + log.debug('According to %s.tables_created, all tables already exist. 
Not creating schemas.', + self.__class__.__name__) + results['skipped'] = len(results['tables_skipped']) + return results + + # c = self.get_cursor() + for table, schema in self.SCHEMAS: + if len(tables) > 0 and table not in tables: + log.debug("Table '%s' not specified in argument '*tables'. Skipping.", table) + continue + r = await self.create_schema(table=table, schema=schema) + results['executions'] += r.get('executions', []) + results['tables_created'] += r.get('tables_created', []) + results['tables_skipped'] += r.get('tables_skipped', []) + # c.close() + results['created'] = len(results['tables_created']) + results['skipped'] = len(results['tables_skipped']) + return results + + @awaitable + def create_schema(self, table: str, schema: str = None) -> DICT_CORO: + return self._create_schema(table=table, schema=schema) + + async def _create_schema(self, table: str, schema: str = None): + """ + Create the individual table ``table``, either uses the create statement ``schema``, or if ``schema`` is empty, + then checks for a pre-existing CREATE statement for ``table`` in :py:attr:`.SCHEMAS`. + + >>> db = GenericDBWrapper('SomeDBName') + >>> db.create_schema('users', 'CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(200));') + + :param table: + :param schema: + :return: + """ + tb_key = f"{self.db}:{table}" + results = dict(executions=[], tables_created=[], tables_skipped=[]) + if tb_key in self.tables_created: + log.debug('Skipping check for table %s as table is in table_created.', table) + return None + cls_name = self.__class__.__name__ + if empty(schema): + _schemas = dict(self.SCHEMAS) + if table not in _schemas: + raise AttributeError( + f"Missing schema - cannot create table {table}. " + f"Table does not exist in {cls_name}.SCHEMAS and schema param was empty." + ) + schema = _schemas[table] + + if not await self.table_exists(table): + log.debug("Table %s didn't exist. Creating it.", table) + # cur = self.get_cursor() + # async with await self.get_cursor() as cur: + results['executions'] += [await self.execute(schema)] + results['tables_created'] += [table] + # cur.close() + else: + log.debug("Table %s already exists. Not creating it.", table) + results['tables_skipped'] += [table] + + self.tables_created.add(tb_key) + return results + + @awaitable + def recreate_schemas(self, *tables) -> DICT_CORO: + return self._recreate_schemas(*tables) + + async def _recreate_schemas(self, *tables) -> dict: + """ + Drop all tables then re-create them. + + Shortcut for running :meth:`.drop_schemas` followed by :meth:`.create_schemas`. + + :param str tables: If one or more table names are specified, then only these tables will be affected. 
+ :return dict result: ``dict_keys(['tables_created', 'skipped_create', 'skipped_drop', 'tables_dropped'])`` + """ + res = dict(tables_created=[], tables_dropped=[], skipped_create=[], skipped_drop=[]) + # conn = await self.conn + log.debug("\n-------------------------\n") + dr = await self.drop_schemas(*tables) + # await conn.commit() + log.debug("\n-------------------------\n") + cr = await self.create_schemas(*tables) + # await conn.commit() + log.debug("\n-------------------------\n") + + res['tables_created'] += cr['tables_created'] + res['skipped_create'] += cr['tables_skipped'] + + res['tables_dropped'] += dr['tables_dropped'] + res['skipped_drop'] += dr['tables_skipped'] + + return res + + @awaitable + def drop_schemas(self, *tables) -> DICT_CORO: + return self._drop_schemas(*tables) + + async def _drop_schemas(self, *tables) -> dict: + """ + Drop all tables listed in :py:attr:`.SCHEMAS` if they exist. + + :param str tables: If one or more table names are specified, then only these tables will be affected. + :return dict result: ``dict_keys(['executions', 'created', 'skipped', 'tables_dropped', 'tables_skipped'])`` + """ + reversed_schemas = list(self.SCHEMAS) + reversed_schemas.reverse() + tables_drop = [t for t, _ in reversed_schemas] + cls_name = self.__class__.__name__ + + if len(tables) > 0: + _tables = list(tables) + log.debug("Tables specified to drop_schemas. Dropping tables: %s", _tables) + tables_drop = [t for t, _ in reversed_schemas if t in _tables] + tables_drop += [t for t in _tables if t not in reversed_schemas] + + results = dict(executions=[], created=0, skipped=0, tables_dropped=[], tables_skipped=[]) + for table in tables_drop: + tb_key = f"{self.db}:{table}" + if await self.table_exists(table): + log.debug("Table %s exists. Dropping it.", table) + was_dropped = await self.drop_table(table) + + if was_dropped: + if self.enable_execution_log: + results['executions'] += [self._execution_log[-1]] + else: + log.debug('%s.enable_execution_log is False. Not logging execution of drop_table', cls_name) + results['tables_dropped'] += [table] + else: + log.debug("%s.drop_table('%s') returned False... Table wasn't dropped?", + self.__class__.__name__, table) + results['tables_skipped'] += [table] + else: + log.debug("Table %s doesn't exist. Not dropping it.", table) + results['tables_skipped'] += [table] + if tb_key in self.tables_created: + log.debug('Removing key "%s" from tables_created, as table %s no longer exists.', tb_key, table) + self.tables_created.remove(tb_key) + return results + + def drop_table(self, table: str) -> BOOL_CORO: + return self._drop_table(table) + + async def _drop_table(self, table: str) -> bool: + """ + Drop the table ``table`` if it exists. If the table exists, it will be dropped and ``True`` will be returned. + + If the table doesn't exist (thus can't be dropped), ``False`` will be returned. + """ + if not await self.table_exists(table): + return False + q = await self.action(f'DROP TABLE {table};') + return True + + @awaitable + def drop_tables(self, *tables) -> Union[List[Tuple[str, bool]], Coroutine[Any, Any, List[tuple]]]: + return self._drop_tables(*tables) + + async def _drop_tables(self, *tables) -> List[Tuple[str, bool]]: + """ + Drop one or more tables as positional arguments. 
+ + Returns a list of tuples containing ``(table_name:str, was_dropped:bool,)`` + :param str tables: One or more table names to drop as positional args + :return list drop_results: List of tuples containing ``(table_name:str, was_dropped:bool,)`` + """ + tables = list(tables) + coros = [self.drop_table(table) for table in tables] + coro_res = await asyncio.gather(*coros) + + # return [(table, self.drop_table(table),) for table in tables] + return list(zip(tables, coro_res)) + + @abc.abstractmethod + def __enter__(self): + """ + Create the database connection at the start of a *with* statement, e.g.:: + + >>> with GenericDBWrapper('SomeDB') as db: + ... db.query("INSERT INTO users (name) VALUES ('Bob');", fetch='no') + ... + + """ + self._conn = self.conn + return self + + @abc.abstractmethod + def __exit__(self, exc_type, exc_val, exc_tb): + """ + Close / delete the database connection at the end of a *with* statement, e.g.:: + + >>> with GenericDBWrapper('SomeDB') as db: + ... db.query("INSERT INTO users (name) VALUES ('Bob');", fetch='no') + ... + + """ + self.close_cursor() + if self._conn is not None: + self._conn.close() + del self._conn + self._conn = None + diff --git a/privex/db/query/__init__.py b/privex/db/query/__init__.py index faf8717..f3a34a4 100644 --- a/privex/db/query/__init__.py +++ b/privex/db/query/__init__.py @@ -1,5 +1,8 @@ import logging from privex.db.query.base import BaseQueryBuilder, QueryMode +import nest_asyncio + +nest_asyncio.apply() log = logging.getLogger(__name__) @@ -13,7 +16,10 @@ try: from privex.db.query.sqlite import SqliteQueryBuilder - __all__ += ['SqliteQueryBuilder'] + from privex.db.query.asyncx.sqlite import SqliteAsyncQueryBuilder + + __all__ += ['SqliteQueryBuilder', 'SqliteAsyncQueryBuilder'] except ImportError: log.warning("Failed to import privex.db.query.sqlite (missing Python SQLite API?)") + diff --git a/privex/db/query/asyncx/__init__.py b/privex/db/query/asyncx/__init__.py new file mode 100644 index 0000000..be3cc34 --- /dev/null +++ b/privex/db/query/asyncx/__init__.py @@ -0,0 +1,2 @@ +from privex.db.query.asyncx.sqlite import SqliteAsyncQueryBuilder +from privex.db.query.asyncx.base import BaseAsyncQueryBuilder diff --git a/privex/db/query/asyncx/base.py b/privex/db/query/asyncx/base.py new file mode 100644 index 0000000..dc2fd14 --- /dev/null +++ b/privex/db/query/asyncx/base.py @@ -0,0 +1,367 @@ +from abc import ABC, abstractmethod +from typing import List, Optional, Union, Coroutine, Any + +from async_property import async_property +from privex.helpers import awaitable, run_sync, empty_if + +from privex.db.base import AsyncCursorManager +from privex.db.query.base import QueryMode +from privex.db.types import GenericAsyncCursor, GenericAsyncConnection, STR_CORO, ANY_CORO, TUPD_CORO, TUP_DICT, \ + TUPD_OPT_CORO, TUPDICT_OPT, GenericCursor + +import logging + +log = logging.getLogger(__name__) + + +async def async_iterate_list(obj): + x = [] + async for l in obj: + x += [l] + return x + + +class BaseAsyncQueryBuilder(ABC): + """ + This is an asynchronous version of :class:`.BaseQueryBuilder` - an abstract base class for Async query builders + to inherit from. + + To allow chaining to function correctly, the methods ``select``, ``where``, ``limit`` and similar methods + are still synchronous. 
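+
+    Example of chaining (a sketch - ``q`` is assumed to be an instance of a concrete sub-class)::
+
+        >>> q = q.select('id', 'name').where('id', 10, compare='>=')   # chaining methods stay synchronous
+        >>> row = await q.fetch()                                      # executing / fetching is awaited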
+ + The following methods are wrapped with ``@awaitable`` which allows them to work both synchronously, and + asynchronously, depending on whether they're called from a sync / async function: + + * ``get_cursor`` + * ``close_cursor`` + * ``build_query`` + * ``execute`` + * ``all`` + * ``fetch`` + * ``fetch_next`` + + When inheriting from this class, for each of the above methods, you should implement the method with an underline + in front (e.g. ``_execute``), as to avoid breaking the ``@awaitable`` "hybrid" wrapping. + """ + query: str + table: str + select_cols: List[str] + group_cols: List[str] + where_clauses: List[str] + where_clauses_values: List[str] + order_cols: List[str] + order_dir: str + order_dir: str + + _is_executed: bool + _co_str = Union[str, Coroutine[Any, Any, str]] + _cursor: Optional[GenericAsyncCursor] + _connection: GenericAsyncConnection + connection_kwargs: dict + connection_args: list + + # noinspection SqlNoDataSourceInspection + Q_SELECT_CLAUSE = ' SELECT {cols} FROM {table}' + Q_WHERE_CLAUSE = ' WHERE {w_clauses}' + Q_LIMIT_CLAUSE = ' LIMIT {limit}' + Q_OFFSET_CLAUSE = ' OFFSET {offset}' + Q_ORDER_CLAUSE = ' ORDER BY {order_cols} {order_dir}' + Q_GROUP_BY_CLAUSE = ' GROUP BY {group_cols}' + Q_PRE_QUERY = "" + Q_POST_QUERY = "" + Q_DEFAULT_PLACEHOLDER = "%s" + + def __init__(self, table: str, connection_args: list = None, connection_kwargs: dict = None, **kwargs): + self.query = "" + self.connection_kwargs = empty_if(connection_kwargs, {}) + self.connection_args = empty_if(connection_args, []) + self.table = table + self.select_cols = [] + self.group_cols = [] + self.where_clauses = [] + self.where_clauses_values = [] + self.order_cols = [] + self.order_dir = '' + self.limit_num = None + self.limit_offset = None + self._cursor = None + self._is_executed = False + + @abstractmethod + async def get_connection(self) -> GenericAsyncConnection: + raise NotImplementedError(f"{self.__class__.__name__} must implement .get_connection()!") + + async def get_cursor(self, cursor_name=None, cursor_class=None, *args, **kwargs) -> GenericAsyncCursor: + """ + Create and return a new database cursor object. + + It's recommended to override this method if you're inheriting from this class, as this Generic version of + ``_get_cursor`` does not make use of ``cursor_name`` nor ``cursor_class``. 
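+
+        Sub-classes must implement :meth:`.get_connection`, which this method uses to obtain its cursor.
+        A minimal override sketch (``aiosqlite`` is purely an illustrative backend)::
+
+            >>> import aiosqlite
+            >>> class MySqliteBuilder(BaseAsyncQueryBuilder):
+            ...     async def get_connection(self) -> GenericAsyncConnection:
+            ...         return await aiosqlite.connect('app.db')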
+ + :param str cursor_name: (If DB API supports it) The name for this cursor + :param type cursor_class: (If DB API supports it) The cursor class to use + :return GenericAsyncCursor cursor: An async cursor object which should implement at least the basic Python + DB API Cursor functionality as specified in :class:`.GenericAsyncCursor` + (PEP 249) + + """ + self._connection = await self.get_connection() + self._cursor = await self._connection.cursor(*args, **kwargs) + return self._cursor + + @async_property + async def cursor(self) -> GenericAsyncCursor: + if self._cursor is None: + self._cursor = await self.get_cursor() + return self._cursor + + async def close_cursor(self): + if not hasattr(self, '_cursor') or self._cursor is None: + return + try: + c = await self._cursor.close() + except (BaseException, Exception): + log.exception("close_cursor was called but exception was raised...") + try: + self._cursor = None + except: + pass + + async def build_query(self) -> str: + """ + Used internally by :py:meth:`.all` and :py:meth:`.fetch` - builds and returns a string SQL query using the + various class attributes such as :py:attr:`.where_clauses` + :return str query: The SQL query that will be sent to the database as a string + """ + # raise NotImplementedError(f"{self.__class__.__name__} must implement .build_query()!") + return await self._build_query() + + @abstractmethod + async def _build_query(self) -> str: + """ + This is a stock :meth:`.build_query` method which can be used by sub-classes if their DBMS is compatible + with the ANSI SQL outputted by this method. + + Example:: + + >>> class SomeDBQueryBuilder(BaseAsyncQueryBuilder): + >>> async def _build_query(self) -> str: + ... return await super()._build_query() + + + :return str query: The SQL query that will be sent to the database as a string + """ + s_cols = ', '.join(self.select_cols) if len(self.select_cols) > 0 else '*' + # q = f"{self.Q_PRE_QUERY} " + q = self.Q_PRE_QUERY + # SELECT {s_cols} FROM {self.table} + q += self.Q_SELECT_CLAUSE.format(cols=s_cols, table=self.table) + if len(self.where_clauses) > 0: + w_clauses = ' '.join(self.where_clauses) + # q += f" WHERE {w_clauses}" + q += self.Q_WHERE_CLAUSE.format(w_clauses=w_clauses) + if len(self.group_cols) > 0: + g_cols = ', '.join(self.group_cols) + # q += f" GROUP BY {g_cols}" + q += self.Q_GROUP_BY_CLAUSE.format(group_cols=g_cols) + if len(self.order_cols) > 0: + # q += f" ORDER BY {', '.join(self.order_cols)} {self.order_dir}" + q += self.Q_ORDER_CLAUSE.format(order_cols=', '.join(self.order_cols), order_dir=self.order_dir) + if self.limit_num is not None: + # q += f" LIMIT {self.limit_num}" + q += self.Q_LIMIT_CLAUSE.format(limit=self.limit_num) + if self.limit_offset is not None: + q += self.Q_OFFSET_CLAUSE.format(offset=self.limit_offset) + # q += f" OFFSET {self.limit_offset}" + + q += ';' + + log.debug(f"Built query: '''\n{q}\n'''") + + return q + + async def execute(self, *args, **kwargs) -> Any: + _exec = await self.cursor.execute(self.build_query(), self.where_clauses_values, *args, **kwargs) + self._is_executed = True + return _exec + + @abstractmethod + async def all(self, query_mode=QueryMode.ROW_DICT) -> TUP_DICT: + """ + Executes the current query, and returns an iterable cursor (results are loaded as you iterate the cursor) + + Usage: + + >>> results = BaseQueryBuilder('people').all() # Equivalent to ``SELECT * FROM people;`` + >>> for r in results: + >>> print(r['first_name'], r['last_name'], r['phone']) + + :return Iterable: A cursor which can be 
+    async def execute(self, *args, **kwargs) -> Any:
+        cur = await self.cursor
+        _exec = await cur.execute(await self.build_query(), self.where_clauses_values, *args, **kwargs)
+        self._is_executed = True
+        return _exec
+
+    @abstractmethod
+    async def all(self, query_mode=QueryMode.ROW_DICT) -> TUP_DICT:
+        """
+        Executes the current query, and returns an iterable cursor (results are loaded as you iterate the cursor)
+
+        Usage:
+
+            >>> results = BaseAsyncQueryBuilder('people').all()   # Equivalent to ``SELECT * FROM people;``
+            >>> async for r in results:
+            ...     print(r['first_name'], r['last_name'], r['phone'])
+
+        :return Iterable: A cursor which can be iterated using an ``async for`` loop.
+                          Ideally, it should load rows as you iterate, saving RAM.
+        """
+        raise NotImplementedError(f"{self.__class__.__name__} must implement .all()!")
+
+    @abstractmethod
+    async def fetch(self, query_mode=QueryMode.ROW_DICT) -> TUPDICT_OPT:
+        """
+        Executes the current query, and fetches the first result as a ``dict``.
+
+        If there are no results, returns ``None``.
+
+        :return dict: The query result as a dictionary: {column: value, }
+        :return None: If no results are found
+        """
+        raise NotImplementedError(f"{self.__class__.__name__} must implement .fetch()!")
+
+    @abstractmethod
+    async def fetch_next(self, query_mode=QueryMode.ROW_DICT) -> TUPDICT_OPT:
+        """
+        Similar to :meth:`.fetch`, but doesn't close the cursor after the query, so it can be run more than once
+        to iterate row-by-row over the results.
+
+        :param QueryMode query_mode: Return each row as a dict (``ROW_DICT``) or a tuple (``ROW_TUPLE``)
+        :return TUPDICT_OPT row: The next row as a dict / tuple, or ``None`` once the results are exhausted
+        """
+        raise NotImplementedError(f"{self.__class__.__name__} must implement .fetch_next()!")
+
+    def select(self, *args):
+        """
+        Add columns to the select clause, specified as individual args. NOTE: no escaping!
+
+        Example:
+
+            q.select('mycol', 'othercol', 'somecol as thiscol')
+
+        Calls can also be chained: ``q.select('mycol').select('othercol')``
+
+        :param args: Columns to select, as individual arguments
+        :return: QueryBuilder object (for chaining)
+        """
+        self.select_cols += list(args)
+        return self
+
+    def order(self, *args, direction='DESC'):
+        """
+        Example: ``order('mycol', 'othercol')`` == ``ORDER BY mycol, othercol DESC``
+
+        :param args: One or more order columns as individual args
+        :param direction: Direction to sort (``ASC`` or ``DESC``)
+        :return: QueryBuilder object (for chaining)
+        """
+        self.order_cols = list(args)
+        self.order_dir = direction
+        return self
+
+    def order_by(self, *args, **kwargs):
+        """Alias of :meth:`.order`"""
+        return self.order(*args, **kwargs)
+
+    def where(self, col, val, compare='=', placeholder=None):
+        """
+        Adds a simple ``col = value`` clause, prefixed with "AND" if there's at least one other clause.
+        ``val`` is escaped properly via a prepared statement placeholder.
+
+        Example: ``where('x', 'test').where('y', 'thing')`` produces the prepared SQL ``WHERE x = %s AND y = %s``
+
+        :param col: The column, function etc. to query
+        :param val: The value it should be equal to. Most Python objects will be converted and escaped properly
+        :param compare: Instead of '=', compare using this comparator, e.g. '>', '<=' etc.
+        :param placeholder: Set the value placeholder, e.g. placeholder='HOST(%s)'
+        :return: QueryBuilder object (for chaining)
+        """
+        placeholder = self.Q_DEFAULT_PLACEHOLDER if placeholder is None else placeholder
+
+        # Convert .where('name', None) into "name IS NULL"
+        if val is None:
+            placeholder = 'NULL'
+            if compare == '=':
+                compare = 'IS'
+            elif compare == '!=':
+                compare = 'IS NOT'
+        else:
+            self.where_clauses_values += [val]
+
+        if len(self.where_clauses) > 0:
+            self.where_clauses += ['AND {} {} {}'.format(col, compare, placeholder)]
+            return self
+        self.where_clauses += ['{} {} {}'.format(col, compare, placeholder)]
+        return self
+
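+    # NULL-handling example (illustrative sketch): passing ``None`` as the value emits an IS / IS NOT
+    # clause with no bound parameter, per the conversion above:
+    #
+    #     b.where('deleted_at', None)           # -> WHERE deleted_at IS NULL
+    #     b.where('deleted_at', None, '!=')     # -> ... AND deleted_at IS NOT NULL
+    #     # neither call appends anything to b.where_clauses_values
+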
+    def where_or(self, col, val, compare='=', placeholder=None):
+        """
+        Adds a simple ``col = value`` clause, prefixed with "OR" if there's at least one other clause.
+        ``val`` is escaped properly via a prepared statement placeholder.
+
+        Example: ``where('x', 'test').where_or('y', 'thing')`` produces the prepared SQL ``WHERE x = %s OR y = %s``
+
+        :param col: The column, function etc. to query
+        :param val: The value it should be equal to. Most Python objects will be converted and escaped properly
+        :param compare: Instead of '=', compare using this comparator, e.g. '>', '<=' etc.
+        :param placeholder: Set the value placeholder, e.g. placeholder='HOST(%s)'
+        :return: QueryBuilder object (for chaining)
+        """
+        placeholder = self.Q_DEFAULT_PLACEHOLDER if placeholder is None else placeholder
+
+        self.where_clauses_values += [val]
+
+        if len(self.where_clauses) > 0:
+            self.where_clauses += ['OR {} {} {}'.format(col, compare, placeholder)]
+            return self
+        self.where_clauses += ['{} {} {}'.format(col, compare, placeholder)]
+        return self
+
+    def limit(self, limit_num, offset=None):
+        """
+        Add a limit/offset. When using an offset, you should also use an ORDER BY to avoid inconsistent results.
+
+        :param limit_num: Amount of rows to limit to
+        :param offset: Offset by this many rows (optional)
+        :return: QueryBuilder object (for chaining)
+        """
+        self.limit_num = limit_num
+        if offset is not None:
+            self.limit_offset = offset
+
+        return self
+
+    def group_by(self, *args):
+        """
+        Add one or more columns to the group by clause.
+
+        Example: ``group_by('name', 'date')`` == ``GROUP BY name, date``
+
+        :param args: One or more columns to group by
+        :return: QueryBuilder object (for chaining)
+        """
+        self.group_cols += list(args)
+        return self
+
+    async def __aiter__(self):
+        """
+        Allow the query object to be iterated over to get results.
+
+        Iterating over the query builder object is equivalent to iterating over :meth:`.all`
+
+        >>> q = BaseAsyncQueryBuilder('users')
+        >>> async for r in q.select('username', 'first_name').where('id', 10, '>='):
+        ...     print(r['username'], r['first_name'])
+
+        """
+        async for r in self.all():
+            yield r
+
+    async def __anext__(self):
+        return await self.fetch_next()
+
+    def __getitem__(self, item):
+        if type(item) is int:
+            if item == 0:
+                return run_sync(self.fetch, query_mode=QueryMode.ROW_DICT)
+            return list(run_sync(async_iterate_list, self.all()))[item]
+        if type(item) is str:
+            return run_sync(self.fetch, query_mode=QueryMode.ROW_DICT)[item]
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        await self.close_cursor()
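+
+
+# Indexing example (illustrative sketch): ``__getitem__`` above bridges the async methods into sync code
+# via ``run_sync`` (which may require ``nest_asyncio`` inside a running event loop), so a builder can be
+# probed without awaiting:
+#
+#     q = SomeDBQueryBuilder('users')
+#     q[0]              # first row as a dict (runs .fetch() synchronously)
+#     q['first_name']   # the 'first_name' column of the first row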
diff --git a/privex/db/query/asyncx/sqlite.py b/privex/db/query/asyncx/sqlite.py
new file mode 100644
index 0000000..e06f899
--- /dev/null
+++ b/privex/db/query/asyncx/sqlite.py
@@ -0,0 +1,149 @@
+import sqlite3
+from typing import Union, Any, AsyncIterable, Iterable, Optional
+
+import aiosqlite
+from aiosqlite import Cursor
+from async_property import async_property
+from privex.helpers import DictObject
+
+from privex.db.query.asyncx.base import BaseAsyncQueryBuilder
+from privex.db.query.base import QueryMode
+from privex.db.types import GenericCursor, GenericAsyncCursor, TUP_DICT, TUPDICT_OPT, GenericAsyncConnection
+
+
+def _zip_cols(cursor: Union[sqlite3.Cursor, GenericCursor, GenericAsyncCursor], row: Iterable):
+    # combine the column names with the row data
+    # so it can be used like a dict
+    col_names = list(map(lambda x: x[0], cursor.description))
+    res = DictObject(zip(col_names, row))
+    return res
+
+
+class SqliteAsyncQueryBuilder(BaseAsyncQueryBuilder):
+    Q_DEFAULT_PLACEHOLDER = '?'
+    Q_PRE_QUERY = ''
+    _connection: Optional[aiosqlite.Connection]
+    _cursor: Optional[Cursor]
+
+    def __init__(self, table: str, connection_args: list = None, connection_kwargs: dict = None, **kwargs):
+        super().__init__(table, connection_args=connection_args, connection_kwargs=connection_kwargs, **kwargs)
+        self._connection = None
+
+    @async_property
+    async def connection(self) -> aiosqlite.Connection:
+        if self._connection is None:
+            self._connection = await self.get_connection()
+        return self._connection
+
+    async def get_connection(self) -> aiosqlite.Connection:
+        # aiosqlite.connect() returns a lazy Connection - it only actually connects once it's
+        # awaited, or entered via ``async with``.
+        return aiosqlite.connect(*self.connection_args, **self.connection_kwargs)
+
+    async def _build_query(self) -> str:
+        return await super()._build_query()
+
+    async def execute(self, *args, **kwargs) -> Any:
+        # NOTE: despite the name, the 'cursor' kwarg takes an aiosqlite Connection, since
+        # aiosqlite connections expose .execute() directly.
+        _conn = kwargs.pop('cursor', None)
+        if _conn is None:
+            if self._connection is None:
+                _conn = self._connection = await self.get_connection()
+            else:
+                _conn = self._connection
+
+        self._cursor = await _conn.execute(await self.build_query(), self.where_clauses_values)
+        self._is_executed = True
+        return self._cursor
+
+    async def all(self, query_mode=QueryMode.ROW_DICT) -> AsyncIterable[Union[tuple, dict]]:
+        async def _all_body(connection):
+            cur = await self.execute(cursor=connection)
+            for res in await cur.fetchall():
+                if query_mode == QueryMode.ROW_DICT:
+                    yield _zip_cols(cur, res)
+                else:
+                    yield res
+
+        if self._connection is None:
+            _conn = await self.get_connection()
+            async with _conn as conn:
+                async for row in _all_body(conn):
+                    yield row
+        else:
+            _conn = self._connection
+            async for row in _all_body(_conn):
+                yield row
+
+        self._cursor = None
+
+    async def fetch(self, query_mode=QueryMode.ROW_DICT) -> TUPDICT_OPT:
+        async def _fetch_body(connection):
+            cur = await self.execute(cursor=connection)
+            res = await cur.fetchone()
+            if res is not None and query_mode == QueryMode.ROW_DICT:
+                res = _zip_cols(cur, tuple(res))
+            return res
+
+        if self._connection is None:
+            _conn = await self.get_connection()
+            async with _conn as conn:
+                return await _fetch_body(conn)
+
+        return await _fetch_body(self._connection)
+
+    async def fetch_next(self, query_mode=QueryMode.ROW_DICT) -> TUPDICT_OPT:
+        if self._connection is None:
+            name = self.__class__.__name__
+            raise ConnectionError(
+                f"To use {name}.fetch_next() you MUST use this class in an async context manager, e.g: \n"
+                f"\tasync with {name}('users', connection_args=['example.db']) as b:\n"
+                f"\t\tuser = await b.__anext__()\n"
+            )
+        if not self._is_executed or self._cursor is None:
+            await self.execute()
+        res = await self._cursor.fetchone()
+        if res is None:
+            # Results are exhausted - reset the cursor so a later call re-executes the query.
+            self._cursor = None
+            return None
+        if query_mode == QueryMode.ROW_DICT:
+            res = _zip_cols(self._cursor, tuple(res))
+        return res
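+
+    # Usage sketch: fetch_next() needs a single connection held open across calls, so use the
+    # builder as an async context manager (as the ConnectionError message above explains):
+    #
+    #     async with SqliteAsyncQueryBuilder('users', connection_args=['example.db']) as b:
+    #         b.select('first_name', 'last_name')
+    #         row = await b.fetch_next()
+    #         row = await b.fetch_next()   # next row, or None when exhausted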
+
+    async def __aenter__(self):
+        self._connection = await self.get_connection()
+        await self._connection.__aenter__()
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        # Forward the exception info so aiosqlite can handle it properly, and clear the
+        # connection so fetch_next() doesn't try to reuse a closed connection.
+        await self._connection.__aexit__(exc_type, exc_val, exc_tb)
+        self._connection = None
+
diff --git a/privex/db/query/base.py b/privex/db/query/base.py
index 4a1ed8e..4a95897 100644
--- a/privex/db/query/base.py
+++ b/privex/db/query/base.py
@@ -354,5 +354,5 @@ def __exit__(self, exc_type, exc_val, exc_tb):
 
     def __del__(self):
         self.close_cursor()
-    
+
 
diff --git a/privex/db/query/sqlite.py b/privex/db/query/sqlite.py
index abde4c3..50bd187 100644
--- a/privex/db/query/sqlite.py
+++ b/privex/db/query/sqlite.py
@@ -1,9 +1,8 @@
 import sqlite3
+import warnings
 from typing import Iterable, Union
 
-from privex.helpers import DictObject
-
-from privex.db.types import GenericCursor
+from privex.db.query.asyncx.sqlite import _zip_cols
 from privex.db.query.base import BaseQueryBuilder, QueryMode
 
 
@@ -13,7 +12,7 @@ def fetch_next(self, query_mode=QueryMode.ROW_DICT) -> Union[dict, tuple, None]:
         self.execute()
         res = self.cursor.fetchone()
         if len(res) > 0 and query_mode == QueryMode.ROW_DICT:
-            res = self._zip_cols(self.cursor, tuple(res))
+            res = _zip_cols(self.cursor, tuple(res))
         return res
 
     def fetch(self, query_mode=QueryMode.ROW_DICT) -> Union[dict, tuple, None]:
@@ -23,7 +22,7 @@ def fetch(self, query_mode=QueryMode.ROW_DICT) -> Union[dict, tuple, None]:
         self.execute()
         res = cur.fetchone()
         if len(res) > 0 and query_mode == QueryMode.ROW_DICT:
-            res = self._zip_cols(cur, tuple(res))
+            res = _zip_cols(cur, tuple(res))
         # cur.close()
         return res
 
@@ -38,14 +37,6 @@ def conn(self) -> sqlite3.Connection:
     def build_query(self) -> str:
         return self._build_query()
 
-    @staticmethod
-    def _zip_cols(cursor: Union[sqlite3.Cursor, GenericCursor], row: iter):
-        # combine the column names with the row data
-        # so it can be used like a dict
-        col_names = list(map(lambda x: x[0], cursor.description))
-        res = DictObject(zip(col_names, row))
-        return res
-
     def all(self, query_mode=QueryMode.ROW_DICT) -> Union[Iterable[dict], Iterable[tuple]]:
         if self.conn is None:
             raise Exception('Please set SqliteQueryBuilder.connection to an sqlite3 connection')
@@ -54,7 +45,7 @@ def all(self, query_mode=QueryMode.ROW_DICT) -> Union[Iterable[dict], Iterable[t
         with self.cursor as cur:
             for res in self.execute():
                 if query_mode == QueryMode.ROW_DICT:
-                    yield self._zip_cols(cur, res)
+                    yield _zip_cols(cur, res)
                 else:
                     yield res
         # res = cur.fetchall()
@@ -63,3 +54,4 @@ def all(self, query_mode=QueryMode.ROW_DICT) -> Union[Iterable[dict], Iterable[t
         #     res = [self._zip_cols(cur, r) for r in orig_res]
         #     cur.close()
         #     return res
+
diff --git a/privex/db/sqlite.py b/privex/db/sqlite.py
index 5e8707c..d8308ce 100644
--- a/privex/db/sqlite.py
+++ b/privex/db/sqlite.py
@@ -22,15 +22,21 @@
 
 """
+import asyncio
 import os
 import sqlite3
 import logging
+import warnings
 from os.path import expanduser, join, dirname, isabs
-from typing import List, Tuple, Optional, Any, Union, Set
-from privex.helpers import empty, DictObject
+from typing import List, Tuple, Optional, Any, Union, Set, Iterable
 
-from privex.db.base import GenericDBWrapper
+from async_property import async_property
+from privex.helpers import empty, DictObject, is_namedtuple
+
+from privex.db.base import GenericDBWrapper, GenericAsyncDBWrapper, _should_zip, cursor_to_dict, DBExecution
 from privex.db.query.sqlite import SqliteQueryBuilder
+from privex.db.query import SqliteAsyncQueryBuilder
+from privex.db.types import GenericAsyncCursor
 
 log = logging.getLogger(__name__)
 
@@ -127,7 +133,7 @@ def __init__(self, db: str = None, isolation_level=None, **kwargs):
         """
         db = self.DEFAULT_DB if db is None else db
 
-        if db != ':memory:':
+        if ':memory:' not in db:
             db_folder = dirname(db)
             if not isabs(db):
                 log.debug("Passed 'db' argument isn't absolute: %s", db)
@@ -180,3 +186,284 @@ def __exit__(self, exc_type, exc_val, exc_tb):
         if self._conn is not None:
             self._conn.close()
             self._conn = None
+
+
+try:
+    import aiosqlite
+
+    class SqliteAsyncWrapper(GenericAsyncDBWrapper):
+        """
+
+        **Usage**
+
+        Creating an instance::
+
+            >>> from privex.db import SqliteAsyncWrapper
+            >>> db = SqliteAsyncWrapper('my_app.db')
+
+        Inserting rows::
+
+            >>> await db.insert('users', first_name='John', last_name='Doe')
+            >>> await db.insert('users', first_name='Dave', last_name='Johnson')
+
+        Running raw queries::
+
+            >>> # fetchone() allows you to run a raw query, and a dict is returned with the first row result
+            >>> row = await db.fetchone("SELECT * FROM users WHERE first_name = ?;", ['John'])
+            >>> row['first_name']
+            John
+            >>> row['last_name']
+            Doe
+
+            >>> # fetchall() runs a query and returns an iterator of the returned rows
+            >>> rows = await db.fetchall("SELECT * FROM users;")
+            >>> for user in rows:
+            ...     print(f"First Name: {user['first_name']} || Last Name: {user['last_name']}")
+            ...
+            First Name: John || Last Name: Doe
+            First Name: Dave || Last Name: Johnson
+
+            >>> # action() is for running queries where you don't want to fetch any results. It simply returns the
+            >>> # affected row count as an integer.
+            >>> row_count = await db.action('UPDATE users SET first_name = ? WHERE id = ?;', ['David', 2])
+            >>> print(row_count)
+            1
+
+        Creating tables if they don't already exist::
+
+            >>> # If the table 'users' doesn't exist, the CREATE TABLE query will be executed.
+            >>> await db.create_schema(
+            ...     'users',
+            ...     "CREATE TABLE users (id INTEGER PRIMARY KEY AUTOINCREMENT, first_name TEXT, last_name TEXT);"
+            ... )
+            >>>
+
+        Using the query builder::
+
+            >>> # You can either use it directly
+            >>> q = db.builder('users')
+            >>> q.select('first_name', 'last_name').where('first_name', 'John').where_or('last_name', 'Doe')
+            >>> results = q.all()
+            >>> async for row in results:
+            ...     print(f"First Name: {row['first_name']} || Last Name: {row['last_name']}")
+            ...
+            First Name: John || Last Name: Doe
+
+            >>> # Or, you can use it in an ``async with`` statement to maintain a single connection, which means
+            >>> # you can use .fetch_next to fetch one row at a time (you can still use .all() and .fetch())
+            >>> async with db.builder('users') as q:
+            ...     q.select('first_name', 'last_name')
+            ...     row = await q.fetch_next()
+            ...     print('Name:', row['first_name'], row['last_name'])  # John Doe
+            ...     row = await q.fetch_next()
+            ...     print('Name:', row['first_name'], row['last_name'])  # Dave Johnson
+            ...
+            Name: John Doe
+            Name: Dave Johnson
+
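+        Running a query through :meth:`.execute` directly (an illustrative sketch - ``execute`` returns a
+        tuple of ``(results, cursor_info)``, where ``cursor_info`` is a dict-like object built from the
+        finished cursor by ``cursor_to_dict``; its exact keys are an assumption here)::
+
+            >>> res, cur_info = await db.execute("SELECT * FROM users;", fetch='all')
+            >>> for row in res:
+            ...     print(row)
+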
+        Creating a wrapper sub-class of SqliteAsyncWrapper:
+
+        .. code-block:: python
+
+            class MyManager(SqliteAsyncWrapper):
+                ###
+                # If a database path isn't specified, then the class attribute DEFAULT_DB will be used.
+                ###
+                DEFAULT_DB_FOLDER: str = expanduser('~/.my_app')
+                DEFAULT_DB_NAME: str = 'my_app.db'
+                DEFAULT_DB: str = join(DEFAULT_DB_FOLDER, DEFAULT_DB_NAME)
+
+                ###
+                # The SCHEMAS class attribute contains a list of tuples, with each tuple containing the name of a
+                # table, as well as the SQL query required to create the table if it doesn't exist.
+                ###
+                SCHEMAS: List[Tuple[str, str]] = [
+                    ('users', "CREATE TABLE users (id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT);"),
+                    ('items', "CREATE TABLE items (id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT);"),
+                ]
+
+                async def get_items(self):
+                    # This is an example of a helper method you might want to define, which simply calls
+                    # self.fetchall with a pre-defined SQL query
+                    return await self.fetchall("SELECT * FROM items;")
+
+                async def find_item(self, id: int):
+                    # This is an example of a helper method you might want to define, which simply calls
+                    # self.fetchone with a pre-defined SQL query, and interpolates the 'id' parameter into
+                    # the prepared statement.
+                    return await self.fetchone("SELECT * FROM items WHERE id = ?;", [id])
+
+        """
+        AIO_CUR = aiosqlite.Cursor
+
+        DEFAULT_DB_FOLDER: str = expanduser('~/.privex_sqlite')
+        """If an absolute path isn't given, store the sqlite3 database file in this folder"""
+        DEFAULT_DB_NAME: str = 'privex_sqlite.db'
+        """If no database is specified to :meth:`.__init__`, then use this (appended to :py:attr:`.DEFAULT_DB_FOLDER`)"""
+        DEFAULT_DB: str = join(DEFAULT_DB_FOLDER, DEFAULT_DB_NAME)
+        """
+        Combined :py:attr:`.DEFAULT_DB_FOLDER` and :py:attr:`.DEFAULT_DB_NAME` used as default absolute path for
+        the sqlite3 database
+        """
+
+        DEFAULT_TABLE_QUERY = "SELECT count(name) as table_count FROM sqlite_master WHERE type = 'table' AND name = ?"
+        DEFAULT_TABLE_LIST_QUERY = "SELECT name FROM sqlite_master WHERE type = 'table'"
+
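+        # Sketch of how the generic wrapper is assumed to use these queries (see GenericAsyncDBWrapper
+        # for the authoritative implementation):
+        #
+        #     await self.fetchone(self.DEFAULT_TABLE_QUERY, ['users'])
+        #     # -> e.g. {'table_count': 1} (dict query_mode) when the 'users' table exists
+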
+        db: str
+        """Path to the SQLite3 database for this class instance"""
+
+        _conn: Optional[aiosqlite.Connection]
+        """Instance variable which holds the current SQLite3 connection object"""
+
+        _builder: Optional[SqliteAsyncQueryBuilder]
+
+        def __init__(self, db: str = None, isolation_level=None, **kwargs):
+            """
+
+            :param str db: Relative / absolute path to SQLite3 database file to use.
+            :param isolation_level: Isolation level for SQLite3 connection. Defaults to ``None`` (autocommit).
+                                    See the `Python SQLite3 Docs`_ for more information.
+
+            :key int db_timeout: Amount of time to wait for any SQLite3 locks to expire before giving up
+            :key str query_mode: Either ``'flat'`` (query returns tuples) or ``'dict'`` (query returns dicts).
+                                 More details in PyDoc block under :py:attr:`.query_mode`
+
+            .. _Python SQLite3 Docs: https://docs.python.org/3.8/library/sqlite3.html#sqlite3.Connection.isolation_level
+
+            """
+            db = self.DEFAULT_DB if db is None else db
+            if ':memory:' not in db:
+                db_folder = dirname(db)
+                if not isabs(db):
+                    log.debug("Passed 'db' argument isn't absolute: %s", db)
+                    db = join(self.DEFAULT_DB_FOLDER, db)
+                    log.debug("Prepended DEFAULT_DB_FOLDER to 'db' argument: %s", db)
+                    db_folder = dirname(db)
+
+                if not os.path.exists(db_folder):
+                    log.debug("Database folder '%s' doesn't exist. Creating it + any missing parent folders", db_folder)
+                    os.makedirs(db_folder)
+            else:
+                log.debug("Passed 'db' argument contains ':memory:' - using an in-memory sqlite3 database.")
+            self.db = db
+            self.isolation_level = isolation_level
+            self.db_timeout = int(kwargs.pop('db_timeout', 30))
+            self.query_mode = kwargs.pop('query_mode', 'dict')
+            self._conn = None
+            self._builder = None
+
+            super().__init__(
+                db=db, connector_func=aiosqlite.connect, connector_args=[db], query_mode=self.query_mode,
+                connector_kwargs=dict(isolation_level=self.isolation_level, timeout=self.db_timeout),
+                **kwargs
+            )
+
+        async def get_cursor(self, cursor_name=None, cursor_class=None, *args, **kwargs) -> aiosqlite.Cursor:
+            # Disable cursor_mgr by default, as aiosqlite already has a context manager.
+            kwargs = dict(kwargs)
+            kwargs['cursor_mgr'] = kwargs.pop('cursor_mgr', False)
+            # noinspection PyTypeChecker
+            return await super().get_cursor(cursor_name=cursor_name, cursor_class=cursor_class, *args, **kwargs)
+
+        @async_property
+        async def cursor(self) -> aiosqlite.Cursor:
+            return await self.get_cursor(cursor_mgr=False, close_callback=self._close_callback)
+
+        # noinspection PyTypeChecker
+        @async_property
+        async def conn(self) -> aiosqlite.Connection:
+            """Get or create an SQLite3 connection using DB file :py:attr:`.db` and return it"""
+            return await super().conn  # type: aiosqlite.Connection
+
+        # noinspection PyTypeChecker
+        def builder(self, table: str) -> SqliteAsyncQueryBuilder:
+            return SqliteAsyncQueryBuilder(
+                table=table, connection_args=self.connector_args, connection_kwargs=self.connector_kwargs
+            )
+
+        async def _get_cursor(self, cursor_name=None, cursor_class=None, *args, **kwargs):
+            return await self._get_connection(new=True, await_conn=False)
+
+        # noinspection PyTypeChecker
+        async def insert(self, _table: str, _cursor: AIO_CUR = None, **fields) -> Union[DictObject, AIO_CUR]:
+            return await super().insert(_table, _cursor, **fields)
+
+        _Q_OUT_TYPE = GenericAsyncDBWrapper._Q_OUT_TYPE
+
+        async def execute(self, query: str, *params: Iterable, fetch='all', **kwargs) \
+                -> Tuple[Iterable, DictObject]:
+
+            cleanup_cursor = kwargs.pop('cleanup_cursor', True)
+            # NOTE: despite the kwarg name, 'cursor' takes an aiosqlite Connection, which exposes
+            # .execute() directly and is used as an async context manager below.
+            _cur: aiosqlite.Connection = kwargs.pop('cursor', None)
+            res = None
+
+            if _cur is None:
+                # noinspection PyTypeChecker
+                _cur: aiosqlite.Connection = 
await self._get_connection(new=True, await_conn=False) + + # if not cleanup_cursor: + # # _cur. + # cur = await _cur.execute(query, *params) + # if fetch == 'all': res = await cur.fetchall() + # if fetch == 'one': res = await cur.fetchone() + # return res, cur + + async with _cur as conn: + async with conn.execute(query, *params) as cur: + if fetch == 'all': res = await cur.fetchall() + if fetch == 'one': res = await cur.fetchone() + cur_dict = cursor_to_dict(cur) + await conn.commit() + return res, cur_dict + + def __enter__(self): + self._conn = self.conn + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.close_cursor() + if self._conn is not None: + conn = await self._conn + await conn.close() + del self._conn + self._conn = None + + def __exit__(self, exc_type, exc_val, exc_tb): + self._conn = None + +except ImportError: + warnings.warn("Could not import 'aiosqlite'. SqliteAsyncWrapper will not be available.") + diff --git a/privex/db/types.py b/privex/db/types.py index 944d76a..5d36fca 100644 --- a/privex/db/types.py +++ b/privex/db/types.py @@ -23,9 +23,25 @@ """ -from typing import Any, Iterable, Union +from typing import Any, Iterable, Union, Coroutine, Type, Optional from typing_extensions import Protocol +CoroNone = Type[Coroutine[Any, Any, None]] + +TUP_DICT = Union[Iterable[dict], Iterable[tuple]] +TUPDICT_OPT = Optional[Union[dict, tuple]] + +# The below types are for @awaitable functions, which might either synchronously return their value, +# or return a coroutine which asynchronously returns the value. +STR_CORO = Union[str, Coroutine[Any, Any, str]] +INT_CORO = Union[int, Coroutine[Any, Any, int]] +BOOL_CORO = Union[bool, Coroutine[Any, Any, bool]] +DICT_CORO = Union[dict, Coroutine[Any, Any, dict]] +ITER_CORO = Union[Iterable, Coroutine[Any, Any, Iterable]] +TUPD_CORO = Union[TUP_DICT, Coroutine[Any, Any, TUP_DICT]] +TUPD_OPT_CORO = Union[TUPDICT_OPT, Coroutine[Any, Any, TUPDICT_OPT]] +ANY_CORO = Union[Any, Coroutine[Any, Any, Any]] + class GenericCursor(Protocol): """ @@ -50,6 +66,21 @@ def fetchall(self, *args, **kwargs) -> Iterable: pass def fetchmany(self, *args, **kwargs) -> Iterable: pass +class GenericAsyncCursor(Protocol): + + async def close(self, *args, **kwargs): pass + + async def execute(self, query: str, params: Iterable = None, *args, **kwargs) -> Any: pass + + async def executemany(self, query: str, params: Iterable = None, *args, **kwargs) -> Any: pass + + async def fetchone(self, *args, **kwargs) -> Union[tuple, list, dict, set]: pass + + async def fetchall(self, *args, **kwargs) -> Iterable: pass + + async def fetchmany(self, *args, **kwargs) -> Iterable: pass + + class GenericConnection(Protocol): """ This is a :class:`typing_extensions.Protocol` which represents any database Connection object which follows @@ -64,3 +95,16 @@ def commit(self, *args, **kwargs): pass def rollback(self, *args, **kwargs): pass def close(self, *args, **kwargs): pass + + +class GenericAsyncConnection(Protocol): + def __init__(self, *args, **kwargs): pass + + async def cursor(self, *args, **kwargs) -> GenericAsyncCursor: pass + + async def commit(self, *args, **kwargs): pass + + async def rollback(self, *args, **kwargs): pass + + async def close(self, *args, **kwargs): pass + diff --git a/requirements.txt b/requirements.txt index 6cf9dea..cba1af3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,6 +6,10 @@ coverage codecov pytest pytest-cov +pytest-asyncio +nest_asyncio +async-property +aiosqlite privex-helpers[setuppy]>=2.3.0 
 sphinx-autobuild>=0.7.1
 restructuredtext-lint>=1.3.0
diff --git a/setup.py b/setup.py
index fbef19e..efcbf63 100755
--- a/setup.py
+++ b/setup.py
@@ -76,7 +76,10 @@
     license='MIT',
 
     install_requires=[
-        'privex-helpers>=2.3.0',
+        'privex-helpers>=2.6.0',
+        'aiosqlite',
+        'async-property',
+        'nest_asyncio',
         'python-dateutil',
         'pytz',
         'typing-extensions',
diff --git a/tests/base.py b/tests/base.py
index bbe03e2..b6dec37 100644
--- a/tests/base.py
+++ b/tests/base.py
@@ -21,14 +21,17 @@
 """
 import logging
 import sqlite3
+import warnings
+
 import dotenv
 from os import getenv as env
 from typing import List, Tuple
 from unittest import TestCase
 from privex.loghelper import LogHelper
-from privex.helpers import dictable_namedtuple
+from privex.helpers import dictable_namedtuple, Mocker
 from privex.db import SqliteWrapper, BaseQueryBuilder, SqliteQueryBuilder, QueryMode
 from privex.db import _setup_logging
+from privex.db.sqlite import SqliteAsyncWrapper
 
 try:
     dotenv.read_dotenv()
@@ -46,7 +49,11 @@
 log = _lh.get_logger()
 
 
-class PrivexDBTestBase(TestCase):
+class PrivexTestBase(TestCase):
+    pass
+
+
+class PrivexDBTestBase(PrivexTestBase):
     """
     Base class for all privex-db test classes. Includes :meth:`.tearDown` to reset database after each test.
     """
@@ -60,7 +67,7 @@ def tearDown(self) -> None:
 
 __all__ = [
     'PrivexDBTestBase', 'SqliteWrapper', 'BaseQueryBuilder', 'SqliteQueryBuilder', 'QueryMode',
-    'ExampleWrapper', 'LOG_LEVEL', 'LOG_FORMATTER'
+    'ExampleWrapper', 'LOG_LEVEL', 'LOG_FORMATTER', 'User', 'example_users'
 ]
 """
 We manually specify __all__ so that we can safely use ``from tests.base import *`` within each test file.
@@ -68,15 +75,17 @@
 
 User = dictable_namedtuple('User', 'first_name last_name')
 
+example_users = [
+    User('John', 'Doe'),
+    User('Jane', 'Smith'),
+    User('John', 'Johnson'),
+    User('Dave', 'Johnson'),
+    User('John', 'Smith'),
+]
+
 
 class _TestWrapperMixin:
-    example_users = [
-        User('John', 'Doe'),
-        User('Jane', 'Smith'),
-        User('John', 'Johnson'),
-        User('Dave', 'Johnson'),
-        User('John', 'Smith'),
-    ]
+    example_users = example_users
 
     def __init__(self, *args, **kwargs):
         super(_TestWrapperMixin, self).__init__(*args, **kwargs)
@@ -124,4 +133,74 @@ class ExampleWrapper(SqliteWrapper, _TestWrapperMixin):
     def __init__(self, *args, **kwargs):
         super(ExampleWrapper, self).__init__(*args, **kwargs)
 
+try:
+    import aiosqlite
+    HAS_ASYNC = True
+
+    class _TestAsyncWrapperMixin:
+        def __init__(self, *args, **kwargs):
+            super(_TestAsyncWrapperMixin, self).__init__(*args, **kwargs)
+
+        async def get_items(self):
+            return await self.fetchall("SELECT * FROM items;")
+
+        async def find_item(self, id: int):
+            return await self.fetchone("SELECT * FROM items WHERE id = ?;", [id])
+
+        async def get_users(self):
+            return await self.fetchall("SELECT * FROM users;")
+
+        async def insert_user(self, first_name, last_name) -> aiosqlite.Cursor:
+            res = await self.execute(
+                "INSERT INTO users (first_name, last_name) "
+                "VALUES (?, ?);",
+                [first_name, last_name], fetch='no'
+            )
+            return res[1]
+
+        async def insert_item(self, name) -> aiosqlite.Cursor:
+            res = await self.execute(
+                "INSERT INTO items (name) VALUES (?);",
+                [name]
+            )
+            return res[1]
+
+        async def find_user(self, id: int):
+            return await self.fetchone("SELECT * FROM users WHERE id = ?;", [id])
+
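+    # Usage sketch (mirrors tests/test_async.py below; ExampleAsyncWrapper combines this mixin
+    # with SqliteAsyncWrapper):
+    #
+    #     wrp = ExampleAsyncWrapper()
+    #     await wrp.recreate_schemas()
+    #     cur = await wrp.insert_user('John', 'Doe')
+    #     user = await wrp.find_user(cur.lastrowid)
+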
+    class ExampleAsyncWrapper(SqliteAsyncWrapper, _TestAsyncWrapperMixin):
+        example_users = example_users
+
+        DEFAULT_DB: str = 'file::memory:?cache=privexdbtests'
+        SCHEMAS: List[Tuple[str, str]] = [
+            (
+                'users',
+                "CREATE TABLE users (id INTEGER PRIMARY KEY AUTOINCREMENT, first_name TEXT, last_name TEXT);"
+            ),
+            (
+                'items', "CREATE TABLE items (id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT);"
+            ),
+        ]
+
+        def __init__(self, *args, **kwargs):
+            super(ExampleAsyncWrapper, self).__init__(*args, **kwargs)
+
+    __all__ += ['ExampleAsyncWrapper']
+except ImportError:
+    HAS_ASYNC = False
+    ExampleAsyncWrapper = Mocker()
+    warnings.warn("Could not import 'aiosqlite'. ExampleAsyncWrapper will not be available.")
diff --git a/tests/test_async.py b/tests/test_async.py
new file mode 100644
index 0000000..c53418a
--- /dev/null
+++ b/tests/test_async.py
@@ -0,0 +1,241 @@
+import pytest
+
+from privex.db import QueryMode
+from tests.base import ExampleAsyncWrapper
+
+
+wrp: ExampleAsyncWrapper = ExampleAsyncWrapper()
+
+
+@pytest.fixture()
+async def setup_teardown() -> None:
+    global wrp
+    await wrp.close_cursor()
+    wrp = ExampleAsyncWrapper()
+    await wrp.recreate_schemas()
+    yield ""
+    await wrp.drop_schemas()
+
+
+def assertEqual(param, param1):
+    assert param == param1
+
+
+def assertIn(param, container):
+    assert param in container
+
+
+def assertNotIn(param, container):
+    assert param not in container
+
+
+@pytest.mark.asyncio
+async def test_all_call(setup_teardown):
+    b = wrp.builder('users')
+    await wrp.insert_user('John', 'Doe')
+    await wrp.insert_user('Dave', 'Johnson')
+
+    res = [row async for row in b.all()]
+    assertEqual(res[0]['first_name'], 'John')
+    assertEqual(res[0]['last_name'], 'Doe')
+    assertEqual(res[1]['first_name'], 'Dave')
+    assertEqual(res[1]['last_name'], 'Johnson')
+
+
+@pytest.mark.asyncio
+async def test_where_call(setup_teardown):
+    b = wrp.builder('users')
+    await wrp.insert_user('John', 'Doe')
+    await wrp.insert_user('Dave', 'Johnson')
+    await wrp.insert_user('Jane', 'Smith')
+
+    res = await b.where('first_name', 'Dave').fetch()
+    assertEqual(res['first_name'], 'Dave')
+    assertEqual(res['last_name'], 'Johnson')
+
+
+@pytest.mark.asyncio
+async def test_group_call(setup_teardown):
+    b = wrp.builder('users')
+    await wrp.insert_user('John', 'Doe')
+    await wrp.insert_user('John', 'Johnson')
+    await wrp.insert_user('John', 'Smith')
+    await wrp.insert_user('Dave', 'Johnson')
+    await wrp.insert_user('Jane', 'Smith')
+
+    b.select('first_name', 'COUNT(first_name)').where('first_name', 'John').group_by('first_name')
+
+    res = await b.fetch(query_mode=QueryMode.ROW_TUPLE)
+    assertEqual(res[0], 'John')
+    assertEqual(res[1], 3)
+
+
+@pytest.mark.asyncio
+async def test_iterate_builder(setup_teardown):
+    """
+    Test obtaining SqliteAsyncQueryBuilder results by iterating over the builder object itself with a for loop
+    """
+    b = wrp.builder('users')
+    ex_users = wrp.example_users
+    for u in ex_users:
+        await wrp.insert_user(u.first_name, u.last_name)
+
+    for i, row in enumerate(b):
+        assertEqual(row['first_name'], ex_users[i].first_name)
+        assertEqual(row['last_name'], ex_users[i].last_name)
+
+
+@pytest.mark.asyncio
+async def test_index_builder(setup_teardown):
+    """
+    Test obtaining SqliteAsyncQueryBuilder results by accessing an index of the builder object
+    """
+    b = wrp.builder('users')
+    ex_users = wrp.example_users
+    for u in ex_users:
+        await wrp.insert_user(u.first_name, u.last_name)
+
+    for i in range(0, 3):
+        assertEqual(b[i]['first_name'], ex_users[i].first_name)
+        assertEqual(b[i]['last_name'], ex_users[i].last_name)
+
+
+@pytest.mark.asyncio
+async def test_generator_builder(setup_teardown):
+    """
+    Test obtaining SqliteAsyncQueryBuilder results by calling :meth:`.__anext__` on the builder object
+    (like an async generator)
+    """
+
+    ex_users = wrp.example_users
+    for u in ex_users:
+        await wrp.insert_user(u.first_name, u.last_name)
+
+    async with wrp.builder('users') as b:
+        for i in range(0, len(ex_users)):
+            user = await b.__anext__()
+            assertEqual(user['first_name'], ex_users[i].first_name)
+            assertEqual(user['last_name'], ex_users[i].last_name)
+
+
+@pytest.mark.asyncio
+async def test_tables_created(setup_teardown):
+    w = wrp
+    assertEqual(w.db, ExampleAsyncWrapper.DEFAULT_DB)
+    tables = await w.list_tables()
+    assertIn('users', tables)
+    assertIn('items', tables)
+
+
+@pytest.mark.asyncio
+async def test_tables_drop(setup_teardown):
+    w = wrp
+    tables = await w.list_tables()
+    assertIn('users', tables)
+    assertIn('items', tables)
+
+    await w.drop_schemas()
+    tables = await w.list_tables()
+    assertNotIn('users', tables)
+    assertNotIn('items', tables)
+
+
+@pytest.mark.asyncio
+async def test_insert_find_user(setup_teardown):
+    w = wrp
+    w.query_mode = 'flat'
+    res = await w.insert_user('John', 'Doe')
+    assertEqual(res.rowcount, 1)
+    user = await w.find_user(res.lastrowid)
+    print('User is:', user)
+    assertEqual(user[1], 'John')
+    assertEqual(user[2], 'Doe')
+
+
+@pytest.mark.asyncio
+async def test_action_update(setup_teardown):
+    w = wrp
+    w.query_mode = 'dict'
+    res = await w.insert_user('John', 'Doe')
+    last_id = res.lastrowid
+    rows = await w.action("UPDATE users SET last_name = ? WHERE first_name = ?", ['Smith', 'John'])
+    assertEqual(rows, 1)
+    john = await w.find_user(last_id)
+    assertEqual(john['last_name'], 'Smith')
+
WHERE first_name = ?", ['Smith', 'John']) + assertEqual(rows, 1) + john = await w.find_user(last_id) + assertEqual(john['last_name'], 'Smith') + + +@pytest.mark.asyncio +async def test_find_user_dict_mode(setup_teardown): + w = wrp + w.query_mode = 'dict' + res = await w.insert_user('John', 'Doe') + assertEqual(res.rowcount, 1) + user = await w.find_user(res.lastrowid) + assertEqual(user['first_name'], 'John') + assertEqual(user['last_name'], 'Doe') + + +def assertIsNone(param): + assert param is None + + +@pytest.mark.asyncio +async def test_find_user_nonexistent(setup_teardown): + w = wrp + user = await w.find_user(99) + assertIsNone(user) + + +@pytest.mark.asyncio +async def test_get_users_tuple(setup_teardown): + w = wrp + w.query_mode = 'flat' + await w.insert_user('John', 'Doe') + await w.insert_user('Jane', 'Doe') + await w.insert_user('Dave', 'Johnson') + + users = list(await w.get_users()) + assertEqual(len(users), 3) + assertEqual(users[0][1], 'John') + + assertEqual(users[1][1], 'Jane') + assertEqual(users[1][2], 'Doe') + + assertEqual(users[2][2], 'Johnson') + + +@pytest.mark.asyncio +async def test_get_users_dict(setup_teardown): + w = wrp + w.query_mode = 'dict' + + await w.insert_user('John', 'Doe') + await w.insert_user('Jane', 'Doe') + await w.insert_user('Dave', 'Johnson') + + users = list(await w.get_users()) + assertEqual(len(users), 3) + assertEqual(users[0]['first_name'], 'John') + + assertEqual(users[1]['first_name'], 'Jane') + assertEqual(users[1]['last_name'], 'Doe') + + assertEqual(users[2]['last_name'], 'Johnson') + + +@pytest.mark.asyncio +async def test_insert_helper(setup_teardown): + w = wrp + w.query_mode = 'dict' + res = await w.insert('users', first_name='Dave', last_name='Johnson') + assertEqual(res.lastrowid, 1) + + user = await w.find_user(res.lastrowid) + assertEqual(user['first_name'], 'Dave') + assertEqual(user['last_name'], 'Johnson') diff --git a/tests/test_sqlite_builder.py b/tests/test_sqlite_builder.py index 4cdf844..f26dbad 100644 --- a/tests/test_sqlite_builder.py +++ b/tests/test_sqlite_builder.py @@ -1,8 +1,15 @@ """ Tests related to :class:`.SqliteQueryBuilder` and :class:`.ExampleWrapper` """ +from unittest import TestCase + +import pytest + +from privex.db.sqlite import SqliteAsyncWrapper from tests.base import * import logging +import nest_asyncio +nest_asyncio.apply() log = logging.getLogger(__name__) @@ -122,3 +129,95 @@ def test_generator_builder(self): user = next(b) self.assertEqual(user['first_name'], ex_users[i].first_name) self.assertEqual(user['last_name'], ex_users[i].last_name) + +# +# class TestAsyncSQLiteBuilder(TestCase): +# wrp: ExampleAsyncWrapper +# +# def setUp(self) -> None: +# self.wrp = ExampleAsyncWrapper() +# +# def tearDown(self) -> None: +# self.wrp.drop_schemas() +# +# @pytest.mark.asyncio +# async def test_all_call(self): +# b = self.wrp.builder('users') +# await self.wrp.insert_user('John', 'Doe') +# await self.wrp.insert_user('Dave', 'Johnson') +# +# res = list(await b.all()) +# self.assertEqual(res[0]['first_name'], 'John') +# self.assertEqual(res[0]['last_name'], 'Doe') +# self.assertEqual(res[1]['first_name'], 'Dave') +# self.assertEqual(res[1]['last_name'], 'Johnson') +# +# @pytest.mark.asyncio +# async def test_where_call(self): +# b = self.wrp.builder('users') +# await self.wrp.insert_user('John', 'Doe') +# await self.wrp.insert_user('Dave', 'Johnson') +# await self.wrp.insert_user('Jane', 'Smith') +# +# res = await b.where('first_name', 'Dave').fetch() +# self.assertEqual(res['first_name'], 'Dave') +# 