Skip to content

Commit

Permalink
- Let `dolt sql` write its output to a file
Browse files Browse the repository at this point in the history
- Expose a function parameter to parse the resulting file

The goal is to expose the ability to use the C path in Doltpy.
  • Loading branch information
Oscar Batori committed Mar 25, 2021
1 parent 8371a0f commit 1311516
Show file tree
Hide file tree
Showing 6 changed files with 142 additions and 129 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
dist/
zip/
.pycache/
*__pycache__*
.idea/
79 changes: 47 additions & 32 deletions doltcli/dolt.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,15 @@
from collections import OrderedDict
import datetime
from subprocess import PIPE, Popen
from typing import List, Dict, Tuple, Union, Optional
from typing import List, Dict, Tuple, Union, Optional, Callable, Any

logger = logging.getLogger(__name__)

SQL_OUTPUT_PARSERS = {
'csv': lambda fh: list(csv.DictReader(fh)),
'json': lambda fh: json.load(fh)
}

from .types import (
BranchT,
CommitT,
Expand All @@ -31,7 +36,6 @@
write_rows,
)


class DoltException(Exception):

"""
Expand Down Expand Up @@ -67,19 +71,23 @@ def __init__(self, message):
self.message = message


def _execute(args: List[str], cwd: Optional[str] = None):
def _execute(args: List[str], cwd: Optional[str] = None, outfile: Optional[str] = None):
_args = ["dolt"] + args
str_args = " ".join(" ".join(args).split())
logger.info(str_args)
proc = Popen(args=_args, cwd=cwd, stdout=PIPE, stderr=PIPE)
_outfile = open(outfile, 'w') if outfile else PIPE
proc = Popen(args=_args, cwd=cwd, stdout=_outfile, stderr=PIPE)
out, err = proc.communicate()
exitcode = proc.returncode

if exitcode != 0:
logger.error(err)
raise DoltException(str_args, out, err, exitcode)

return out.decode("utf-8")
if outfile:
return outfile
else:
return out.decode("utf-8")


class Status(StatusT):
Expand Down Expand Up @@ -280,22 +288,33 @@ def head(self):
return head_commit

def execute(
self, args: List[str], print_output: Optional[bool] = None
) -> List[str]:
self, args: List[str], print_output: Optional[bool] = None, stdout_to_file: bool = False
) -> str:
"""
Manages executing a dolt command, pass all commands, sub-commands, and arguments as they would appear on the
command line.
:param args:
:param print_output:
:param stdout_to_file:
:return:
"""
output = _execute(args, self.repo_dir)
if print_output and stdout_to_file:
raise ValueError('Cannot print output and send it to a file')

outfile = None
if stdout_to_file:
_, outfile = tempfile.mkstemp()

output = _execute(args, self.repo_dir, outfile=outfile)

print_output = print_output or self._print_output
if print_output:
logger.info(output)

return output.split("\n")
if outfile:
return outfile
else:
return output

@staticmethod
def init(repo_dir: Optional[str] = None) -> "Dolt":
Expand Down Expand Up @@ -337,7 +356,7 @@ def status(self) -> Status:
new_tables: Dict[str, bool] = {}
changes: Dict[str, bool] = {}

output = self.execute(["status"], print_output=False)
output = self.execute(["status"], print_output=False).split("\n")

if "clean" in str("\n".join(output)):
return Status(True, changes, new_tables)
Expand Down Expand Up @@ -449,7 +468,7 @@ def merge(self, branch: str, message: Optional[str] = None, squash: bool = False
args.append("--squash")

args.append(branch)
output = self.execute(args)
output = self.execute(args).split('\n')
merge_conflict_pos = 2

if len(output) == 3 and "Fast-forward" in output[1]:
Expand Down Expand Up @@ -492,6 +511,7 @@ def sql(
list_saved: bool = False,
batch: bool = False,
multi_db_dir: Optional[str] = None,
result_parser: Callable[[io.StringIO], Any] = None
):
"""
Execute a SQL query, using the options to dictate how it is executed, and where the output goes.
Expand All @@ -503,6 +523,7 @@ def sql(
:param list_saved: print out a list of saved queries
:param batch: execute in batch mode, one statement after the other delimited by ;
:param multi_db_dir: use a directory of Dolt repos, each one treated as a database
:param result_parser:
:return:
"""
args = ["sql"]
Expand Down Expand Up @@ -536,31 +557,22 @@ def sql(
"Must provide a query in order to specify a result format"
)
args.extend(["--query", query])
if result_format in ["csv", "tabular"]:
args.extend(["--result-format", "csv"])
output = self.execute(args)
dict_reader = csv.DictReader(io.StringIO("\n".join(output)))
return list(dict_reader)
elif result_format == "json":
args.extend(["--result-format", "json"])
output = self.execute(args)
return json.load(io.StringIO("".join(output)))

if result_format in ["csv", "json"]:
args.extend(["--result-format", result_format])
output_file = self.execute(args, stdout_to_file=True)
return SQL_OUTPUT_PARSERS[result_format](open(output_file))

else:
raise ValueError(
f"{result_format} is not a valid value for result_format"
)
args.extend(["--result-format", "csv"])
output_file = self.execute(args, stdout_to_file=True)
return result_parser(open(output_file))

logger.warning("Must provide a value for result_format to get output back")
if query:
args.extend(["--query", query])
self.execute(args)

def _parse_tabluar_output_to_dict(self, args: List[str]):
args.extend(["--result-format", "csv"])
output = self.execute(args)
dict_reader = csv.DictReader(io.StringIO("\n".join(output)))
return list(dict_reader)

def log(self, number: Optional[int] = None, commit: Optional[str] = None) -> Dict:
"""
Parses the log created by running the log command into instances of `Commit` that provide detail of the
Expand Down Expand Up @@ -734,7 +746,10 @@ def _get_branches(self) -> Tuple[Branch, List[Branch]]:
ab_dicts = read_rows_sql(
self, f"select * from dolt_branches where name = (select active_branch())"
)
assert len(ab_dicts) == 1

if len(ab_dicts) != 1:
raise ValueError('Ensure you have the latest version of Dolt installed, this is fixed as of 0.24.2')

active_branch = Branch(**ab_dicts[0])

if not active_branch:
Expand Down Expand Up @@ -795,7 +810,7 @@ def remote(
args = ["remote", "--verbose"]

if not (add or remove):
output = self.execute(args, print_output=False)
output = self.execute(args, print_output=False).split('\n')

remotes = []
for line in output:
Expand Down Expand Up @@ -1200,7 +1215,7 @@ def ls(self, system: bool = False, all: bool = False) -> List[TableT]:
if system:
args.append("--system")

output = self.execute(args, print_output=False)
output = self.execute(args, print_output=False).split('\n')
tables: List[TableT] = []
system_pos = None

Expand Down
4 changes: 3 additions & 1 deletion doltcli/types.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from dataclasses import asdict, dataclass
import datetime
import json
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Union, Callable, Any
import io


class Encoder(json.JSONEncoder):
Expand Down Expand Up @@ -136,6 +137,7 @@ def sql(
list_saved: bool = False,
batch: bool = False,
multi_db_dir: Optional[str] = None,
result_parser: Callable[[io.StringIO], Any] = None
) -> List:
...

Expand Down
14 changes: 7 additions & 7 deletions doltcli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,30 +17,30 @@
def read_columns(
dolt: DoltT, table: str, as_of: Optional[str] = None
) -> Dict[str, list]:
return read_columns_sql(dolt, _get_read_table_asof_query(table, as_of))
return read_columns_sql(dolt, get_read_table_asof_query(table, as_of))


def read_rows(dolt: DoltT, table: str, as_of: Optional[str] = None) -> List[dict]:
return read_rows_sql(dolt, _get_read_table_asof_query(table, as_of))
return read_rows_sql(dolt, get_read_table_asof_query(table, as_of))


def _get_read_table_asof_query(table: str, as_of: Optional[str] = None) -> str:
def get_read_table_asof_query(table: str, as_of: Optional[str] = None) -> str:
base_query = f"SELECT * FROM `{table}`"
return f'{base_query} AS OF "{as_of}"' if as_of else base_query


def read_columns_sql(dolt: DoltT, sql: str) -> Dict[str, list]:
rows = _read_table_sql(dolt, sql)
rows = read_table_sql(dolt, sql)
columns = rows_to_columns(rows)
return columns


def read_rows_sql(dolt: DoltT, sql: str) -> List[dict]:
return _read_table_sql(dolt, sql)
return read_table_sql(dolt, sql)


def _read_table_sql(dolt: DoltT, sql: str) -> List[dict]:
return dolt.sql(sql, result_format="csv")
def read_table_sql(dolt: DoltT, sql: str, result_parser: Callable[[io.StringIO], Any] = None) -> List[dict]:
return dolt.sql(sql, result_format="csv", result_parser=result_parser)


CREATE, FORCE_CREATE, REPLACE, UPDATE = "create", "force_create", "replace", "update"
Expand Down
Loading

0 comments on commit 1311516

Please sign in to comment.