diff --git a/.gitignore b/.gitignore index 57142fc..cc851b9 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ dist/ zip/ .pycache/ +*__pycache__* +.idea/ \ No newline at end of file diff --git a/doltcli/dolt.py b/doltcli/dolt.py index 454fe9b..4b07b9f 100644 --- a/doltcli/dolt.py +++ b/doltcli/dolt.py @@ -7,10 +7,15 @@ from collections import OrderedDict import datetime from subprocess import PIPE, Popen -from typing import List, Dict, Tuple, Union, Optional +from typing import List, Dict, Tuple, Union, Optional, Callable, Any logger = logging.getLogger(__name__) +SQL_OUTPUT_PARSERS = { + 'csv': lambda fh: list(csv.DictReader(fh)), + 'json': lambda fh: json.load(fh) +} + from .types import ( BranchT, CommitT, @@ -31,7 +36,6 @@ write_rows, ) - class DoltException(Exception): """ @@ -67,11 +71,12 @@ def __init__(self, message): self.message = message -def _execute(args: List[str], cwd: Optional[str] = None): +def _execute(args: List[str], cwd: Optional[str] = None, outfile: Optional[str] = None): _args = ["dolt"] + args str_args = " ".join(" ".join(args).split()) logger.info(str_args) - proc = Popen(args=_args, cwd=cwd, stdout=PIPE, stderr=PIPE) + _outfile = open(outfile, 'w') if outfile else PIPE + proc = Popen(args=_args, cwd=cwd, stdout=_outfile, stderr=PIPE) out, err = proc.communicate() exitcode = proc.returncode @@ -79,7 +84,10 @@ def _execute(args: List[str], cwd: Optional[str] = None): logger.error(err) raise DoltException(str_args, out, err, exitcode) - return out.decode("utf-8") + if outfile: + return outfile + else: + return out.decode("utf-8") class Status(StatusT): @@ -280,22 +288,33 @@ def head(self): return head_commit def execute( - self, args: List[str], print_output: Optional[bool] = None - ) -> List[str]: + self, args: List[str], print_output: Optional[bool] = None, stdout_to_file: bool = False + ) -> str: """ Manages executing a dolt command, pass all commands, sub-commands, and arguments as they would appear on the command line. :param args: :param print_output: + :param stdout_to_file: :return: """ - output = _execute(args, self.repo_dir) + if print_output and stdout_to_file: + raise ValueError('Cannot print output and send it to a file') + + outfile = None + if stdout_to_file: + _, outfile = tempfile.mkstemp() + + output = _execute(args, self.repo_dir, outfile=outfile) print_output = print_output or self._print_output if print_output: logger.info(output) - return output.split("\n") + if outfile: + return outfile + else: + return output @staticmethod def init(repo_dir: Optional[str] = None) -> "Dolt": @@ -337,7 +356,7 @@ def status(self) -> Status: new_tables: Dict[str, bool] = {} changes: Dict[str, bool] = {} - output = self.execute(["status"], print_output=False) + output = self.execute(["status"], print_output=False).split("\n") if "clean" in str("\n".join(output)): return Status(True, changes, new_tables) @@ -449,7 +468,7 @@ def merge(self, branch: str, message: Optional[str] = None, squash: bool = False args.append("--squash") args.append(branch) - output = self.execute(args) + output = self.execute(args).split('\n') merge_conflict_pos = 2 if len(output) == 3 and "Fast-forward" in output[1]: @@ -492,6 +511,7 @@ def sql( list_saved: bool = False, batch: bool = False, multi_db_dir: Optional[str] = None, + result_parser: Callable[[io.StringIO], Any] = None ): """ Execute a SQL query, using the options to dictate how it is executed, and where the output goes. @@ -503,6 +523,7 @@ def sql( :param list_saved: print out a list of saved queries :param batch: execute in batch mode, one statement after the other delimited by ; :param multi_db_dir: use a directory of Dolt repos, each one treated as a database + :param result_parser: :return: """ args = ["sql"] @@ -536,31 +557,22 @@ def sql( "Must provide a query in order to specify a result format" ) args.extend(["--query", query]) - if result_format in ["csv", "tabular"]: - args.extend(["--result-format", "csv"]) - output = self.execute(args) - dict_reader = csv.DictReader(io.StringIO("\n".join(output))) - return list(dict_reader) - elif result_format == "json": - args.extend(["--result-format", "json"]) - output = self.execute(args) - return json.load(io.StringIO("".join(output))) + + if result_format in ["csv", "json"]: + args.extend(["--result-format", result_format]) + output_file = self.execute(args, stdout_to_file=True) + return SQL_OUTPUT_PARSERS[result_format](open(output_file)) + else: - raise ValueError( - f"{result_format} is not a valid value for result_format" - ) + args.extend(["--result-format", "csv"]) + output_file = self.execute(args, stdout_to_file=True) + return result_parser(open(output_file)) logger.warning("Must provide a value for result_format to get output back") if query: args.extend(["--query", query]) self.execute(args) - def _parse_tabluar_output_to_dict(self, args: List[str]): - args.extend(["--result-format", "csv"]) - output = self.execute(args) - dict_reader = csv.DictReader(io.StringIO("\n".join(output))) - return list(dict_reader) - def log(self, number: Optional[int] = None, commit: Optional[str] = None) -> Dict: """ Parses the log created by running the log command into instances of `Commit` that provide detail of the @@ -734,7 +746,10 @@ def _get_branches(self) -> Tuple[Branch, List[Branch]]: ab_dicts = read_rows_sql( self, f"select * from dolt_branches where name = (select active_branch())" ) - assert len(ab_dicts) == 1 + + if len(ab_dicts) != 1: + raise ValueError('Ensure you have the latest version of Dolt installed, this is fixed as of 0.24.2') + active_branch = Branch(**ab_dicts[0]) if not active_branch: @@ -795,7 +810,7 @@ def remote( args = ["remote", "--verbose"] if not (add or remove): - output = self.execute(args, print_output=False) + output = self.execute(args, print_output=False).split('\n') remotes = [] for line in output: @@ -1200,7 +1215,7 @@ def ls(self, system: bool = False, all: bool = False) -> List[TableT]: if system: args.append("--system") - output = self.execute(args, print_output=False) + output = self.execute(args, print_output=False).split('\n') tables: List[TableT] = [] system_pos = None diff --git a/doltcli/types.py b/doltcli/types.py index e8119a3..8e76d64 100644 --- a/doltcli/types.py +++ b/doltcli/types.py @@ -1,7 +1,8 @@ from dataclasses import asdict, dataclass import datetime import json -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, Union, Callable, Any +import io class Encoder(json.JSONEncoder): @@ -136,6 +137,7 @@ def sql( list_saved: bool = False, batch: bool = False, multi_db_dir: Optional[str] = None, + result_parser: Callable[[io.StringIO], Any] = None ) -> List: ... diff --git a/doltcli/utils.py b/doltcli/utils.py index f222aae..49d85b5 100644 --- a/doltcli/utils.py +++ b/doltcli/utils.py @@ -17,30 +17,30 @@ def read_columns( dolt: DoltT, table: str, as_of: Optional[str] = None ) -> Dict[str, list]: - return read_columns_sql(dolt, _get_read_table_asof_query(table, as_of)) + return read_columns_sql(dolt, get_read_table_asof_query(table, as_of)) def read_rows(dolt: DoltT, table: str, as_of: Optional[str] = None) -> List[dict]: - return read_rows_sql(dolt, _get_read_table_asof_query(table, as_of)) + return read_rows_sql(dolt, get_read_table_asof_query(table, as_of)) -def _get_read_table_asof_query(table: str, as_of: Optional[str] = None) -> str: +def get_read_table_asof_query(table: str, as_of: Optional[str] = None) -> str: base_query = f"SELECT * FROM `{table}`" return f'{base_query} AS OF "{as_of}"' if as_of else base_query def read_columns_sql(dolt: DoltT, sql: str) -> Dict[str, list]: - rows = _read_table_sql(dolt, sql) + rows = read_table_sql(dolt, sql) columns = rows_to_columns(rows) return columns def read_rows_sql(dolt: DoltT, sql: str) -> List[dict]: - return _read_table_sql(dolt, sql) + return read_table_sql(dolt, sql) -def _read_table_sql(dolt: DoltT, sql: str) -> List[dict]: - return dolt.sql(sql, result_format="csv") +def read_table_sql(dolt: DoltT, sql: str, result_parser: Callable[[io.StringIO], Any] = None) -> List[dict]: + return dolt.sql(sql, result_format="csv", result_parser=result_parser) CREATE, FORCE_CREATE, REPLACE, UPDATE = "create", "force_create", "replace", "update" diff --git a/poetry.lock b/poetry.lock index b23b473..80edfc8 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,122 +1,122 @@ [[package]] -category = "dev" -description = "A small Python module for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." name = "appdirs" +version = "1.4.4" +description = "A small Python module for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." +category = "dev" optional = false python-versions = "*" -version = "1.4.4" [[package]] -category = "dev" -description = "Atomic file writes." name = "atomicwrites" +version = "1.4.0" +description = "Atomic file writes." +category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" -version = "1.4.0" [[package]] -category = "dev" -description = "Classes Without Boilerplate" name = "attrs" +version = "20.3.0" +description = "Classes Without Boilerplate" +category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" -version = "20.3.0" [package.extras] -dev = ["coverage (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "zope.interface", "furo", "sphinx", "pre-commit"] +dev = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "zope.interface", "furo", "sphinx", "pre-commit"] docs = ["furo", "sphinx", "zope.interface"] -tests = ["coverage (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "zope.interface"] -tests_no_zope = ["coverage (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six"] +tests = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "zope.interface"] +tests_no_zope = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six"] [[package]] -category = "dev" -description = "The uncompromising code formatter." name = "black" +version = "20.8b1" +description = "The uncompromising code formatter." +category = "dev" optional = false python-versions = ">=3.6" -version = "20.8b1" [package.dependencies] appdirs = "*" click = ">=7.1.2" +dataclasses = {version = ">=0.6", markers = "python_version < \"3.7\""} mypy-extensions = ">=0.4.3" pathspec = ">=0.6,<1" regex = ">=2020.1.8" toml = ">=0.10.1" typed-ast = ">=1.4.0" typing-extensions = ">=3.7.4" -dataclasses = {version = ">=0.6", markers = "python_version < \"3.7\""} [package.extras] colorama = ["colorama (>=0.4.3)"] d = ["aiohttp (>=3.3.2)", "aiohttp-cors"] [[package]] -category = "dev" -description = "Composable command line interface toolkit" name = "click" +version = "7.1.2" +description = "Composable command line interface toolkit" +category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" -version = "7.1.2" [[package]] -category = "dev" -description = "Cross-platform colored terminal text." name = "colorama" +version = "0.4.4" +description = "Cross-platform colored terminal text." +category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" -version = "0.4.4" [[package]] -category = "dev" -description = "Code coverage measurement for Python" name = "coverage" +version = "5.5" +description = "Code coverage measurement for Python" +category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, <4" -version = "5.5" [package.extras] toml = ["toml"] [[package]] -category = "dev" -description = "A backport of the dataclasses module for Python 3.6" name = "dataclasses" +version = "0.8" +description = "A backport of the dataclasses module for Python 3.6" +category = "dev" optional = false python-versions = ">=3.6, <3.7" -version = "0.8" [[package]] -category = "dev" -description = "Read metadata from Python packages" name = "importlib-metadata" +version = "3.7.3" +description = "Read metadata from Python packages" +category = "dev" optional = false python-versions = ">=3.6" -version = "3.7.3" [package.dependencies] -zipp = ">=0.5" typing-extensions = {version = ">=3.6.4", markers = "python_version < \"3.8\""} +zipp = ">=0.5" [package.extras] docs = ["sphinx", "jaraco.packaging (>=8.2)", "rst.linker (>=1.9)"] -testing = ["pytest (>=3.5,<3.7.3 || >3.7.3)", "pytest-checkdocs (>=1.2.3)", "pytest-flake8", "pytest-cov", "pytest-enabler", "packaging", "pep517", "pyfakefs", "flufl.flake8", "pytest-black (>=0.3.7)", "pytest-mypy", "importlib-resources (>=1.3)"] +testing = ["pytest (>=3.5,!=3.7.3)", "pytest-checkdocs (>=1.2.3)", "pytest-flake8", "pytest-cov", "pytest-enabler", "packaging", "pep517", "pyfakefs", "flufl.flake8", "pytest-black (>=0.3.7)", "pytest-mypy", "importlib-resources (>=1.3)"] [[package]] -category = "dev" -description = "iniconfig: brain-dead simple config-ini parsing" name = "iniconfig" +version = "1.1.1" +description = "iniconfig: brain-dead simple config-ini parsing" +category = "dev" optional = false python-versions = "*" -version = "1.1.1" [[package]] -category = "dev" -description = "Optional static typing for Python" name = "mypy" +version = "0.800" +description = "Optional static typing for Python" +category = "dev" optional = false python-versions = ">=3.5" -version = "0.800" [package.dependencies] mypy-extensions = ">=0.4.3,<0.5.0" @@ -127,39 +127,39 @@ typing-extensions = ">=3.7.4" dmypy = ["psutil (>=4.0)"] [[package]] -category = "dev" -description = "Experimental type system extensions for programs checked with the mypy typechecker." name = "mypy-extensions" +version = "0.4.3" +description = "Experimental type system extensions for programs checked with the mypy typechecker." +category = "dev" optional = false python-versions = "*" -version = "0.4.3" [[package]] -category = "dev" -description = "Core utilities for Python packages" name = "packaging" +version = "20.9" +description = "Core utilities for Python packages" +category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" -version = "20.9" [package.dependencies] pyparsing = ">=2.0.2" [[package]] -category = "dev" -description = "Utility library for gitignore style pattern matching of file paths." name = "pathspec" +version = "0.8.1" +description = "Utility library for gitignore style pattern matching of file paths." +category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" -version = "0.8.1" [[package]] -category = "dev" -description = "plugin and hook calling mechanisms for python" name = "pluggy" +version = "0.13.1" +description = "plugin and hook calling mechanisms for python" +category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" -version = "0.13.1" [package.dependencies] importlib-metadata = {version = ">=0.12", markers = "python_version < \"3.8\""} @@ -168,106 +168,106 @@ importlib-metadata = {version = ">=0.12", markers = "python_version < \"3.8\""} dev = ["pre-commit", "tox"] [[package]] -category = "dev" -description = "library with cross-python path, ini-parsing, io, code, log facilities" name = "py" +version = "1.10.0" +description = "library with cross-python path, ini-parsing, io, code, log facilities" +category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" -version = "1.10.0" [[package]] -category = "dev" -description = "Python parsing module" name = "pyparsing" +version = "2.4.7" +description = "Python parsing module" +category = "dev" optional = false python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" -version = "2.4.7" [[package]] -category = "dev" -description = "pytest: simple powerful testing with Python" name = "pytest" +version = "6.2.2" +description = "pytest: simple powerful testing with Python" +category = "dev" optional = false python-versions = ">=3.6" -version = "6.2.2" [package.dependencies] +atomicwrites = {version = ">=1.0", markers = "sys_platform == \"win32\""} attrs = ">=19.2.0" +colorama = {version = "*", markers = "sys_platform == \"win32\""} +importlib-metadata = {version = ">=0.12", markers = "python_version < \"3.8\""} iniconfig = "*" packaging = "*" pluggy = ">=0.12,<1.0.0a1" py = ">=1.8.2" toml = "*" -atomicwrites = {version = ">=1.0", markers = "sys_platform == \"win32\""} -colorama = {version = "*", markers = "sys_platform == \"win32\""} -importlib-metadata = {version = ">=0.12", markers = "python_version < \"3.8\""} [package.extras] testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "requests", "xmlschema"] [[package]] -category = "dev" -description = "Pytest plugin for measuring coverage." name = "pytest-cov" +version = "2.11.1" +description = "Pytest plugin for measuring coverage." +category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" -version = "2.11.1" [package.dependencies] coverage = ">=5.2.1" pytest = ">=4.6" [package.extras] -testing = ["fields", "hunter", "process-tests (2.0.2)", "six", "pytest-xdist", "virtualenv"] +testing = ["fields", "hunter", "process-tests (==2.0.2)", "six", "pytest-xdist", "virtualenv"] [[package]] -category = "dev" -description = "Alternative regular expression module, to replace re." name = "regex" +version = "2021.3.17" +description = "Alternative regular expression module, to replace re." +category = "dev" optional = false python-versions = "*" -version = "2021.3.17" [[package]] -category = "dev" -description = "Python Library for Tom's Obvious, Minimal Language" name = "toml" +version = "0.10.2" +description = "Python Library for Tom's Obvious, Minimal Language" +category = "dev" optional = false python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" -version = "0.10.2" [[package]] -category = "dev" -description = "a fork of Python 2 and 3 ast modules with type comment support" name = "typed-ast" +version = "1.4.2" +description = "a fork of Python 2 and 3 ast modules with type comment support" +category = "dev" optional = false python-versions = "*" -version = "1.4.2" [[package]] -category = "dev" -description = "Backported and Experimental Type Hints for Python 3.5+" name = "typing-extensions" +version = "3.7.4.3" +description = "Backported and Experimental Type Hints for Python 3.5+" +category = "dev" optional = false python-versions = "*" -version = "3.7.4.3" [[package]] -category = "dev" -description = "Backport of pathlib-compatible object wrapper for zip files" name = "zipp" +version = "3.4.1" +description = "Backport of pathlib-compatible object wrapper for zip files" +category = "dev" optional = false python-versions = ">=3.6" -version = "3.4.1" [package.extras] docs = ["sphinx", "jaraco.packaging (>=8.2)", "rst.linker (>=1.9)"] testing = ["pytest (>=4.6)", "pytest-checkdocs (>=1.2.3)", "pytest-flake8", "pytest-cov", "pytest-enabler", "jaraco.itertools", "func-timeout", "pytest-black (>=0.3.7)", "pytest-mypy"] [metadata] -content-hash = "69bd5d5a73c6a65e339a0c999def8af2b5be98bf7c2c74192bc26830e4ea2ac8" lock-version = "1.1" python-versions = "^3.6" +content-hash = "69bd5d5a73c6a65e339a0c999def8af2b5be98bf7c2c74192bc26830e4ea2ac8" [metadata.files] appdirs = [ diff --git a/tests/test_dolt.py b/tests/test_dolt.py index 3ff7232..9a0f5e1 100644 --- a/tests/test_dolt.py +++ b/tests/test_dolt.py @@ -393,12 +393,6 @@ def test_sql_csv(create_test_table): _verify_against_base_rows(result) -def test_sql_tabular(create_test_table): - repo, test_table = create_test_table - result = repo.sql(query='SELECT * FROM `{table}`'.format(table=test_table), result_format='tabular') - _verify_against_base_rows(result) - - def _verify_against_base_rows(result: List[dict]): assert len(result) == len(BASE_TEST_ROWS)