Skip to content

Commit

Permalink
Write_file accepts file argument (#24)
Browse files Browse the repository at this point in the history
* Write_file accepts file argument

* Remote prints
  • Loading branch information
max-hoffman authored Aug 4, 2021
1 parent bb9ab3c commit 31608a3
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 9 deletions.
1 change: 1 addition & 0 deletions doltcli/dolt.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
write_rows,
)

global logger
logger = logging.getLogger(__name__)


Expand Down
33 changes: 26 additions & 7 deletions doltcli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import tempfile
from collections import defaultdict
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Union

from .types import DoltT
Expand Down Expand Up @@ -61,7 +62,8 @@ def read_table_sql(
def write_file(
dolt: DoltT,
table: str,
file_handle: io.StringIO,
file_handle: Optional[io.TextIOBase] = None,
file: Union[str, Path, None] = None,
# TODO what to do about this?
filetype: str = "csv",
import_mode: Optional[str] = None,
Expand All @@ -71,9 +73,25 @@ def write_file(
commit_date: Optional[datetime.datetime] = None,
do_continue: Optional[bool] = False,
):
def writer(filepath: str):
with open(filepath, "w", newline="") as f:
f.writelines(file_handle.readlines())
if file_handle is not None and file is not None:
raise ValueError("Specify one of: file, file_handle")
elif file_handle is None and file is None:
raise ValueError("Specify one of: file, file_handle")
elif file_handle is not None:

def writer(filepath: str):
if not isinstance(file_handle, io.TextIOBase):
raise ValueError(
f"file_handle expected type io.StringIO; found: {type(file_handle)}"
)
with open(filepath, "w", newline="") as f:
f.writelines(file_handle.readlines())
return filepath

elif file is not None:

def writer(filepath: str):
return str(file)

_import_helper(
dolt=dolt,
Expand Down Expand Up @@ -121,6 +139,7 @@ def writer(filepath: str):
rows = columns_to_rows(columns)
csv_writer.writeheader()
csv_writer.writerows(rows)
return filepath

_import_helper(
dolt=dolt,
Expand Down Expand Up @@ -163,12 +182,12 @@ def writer(filepath: str):
with open(filepath, "w", newline="") as f:
fieldnames: Set[str] = set()
for row in rows:
print(row)
fieldnames = fieldnames.union(set(row.keys()))

csv_writer = csv.DictWriter(f, fieldnames)
csv_writer.writeheader()
csv_writer.writerows(rows)
return filepath

_import_helper(
dolt=dolt,
Expand All @@ -186,7 +205,7 @@ def writer(filepath: str):
def _import_helper(
dolt: DoltT,
table: str,
write_import_file: Callable[[str], None],
write_import_file: Callable[[str], str],
import_mode: Optional[str] = None,
primary_key: Optional[List[str]] = None,
do_continue: Optional[bool] = False,
Expand All @@ -202,7 +221,7 @@ def _import_helper(
fname = tempfile.mktemp(suffix=".csv")
import_flags = IMPORT_MODES_TO_FLAGS[import_mode]
try:
write_import_file(fname)
fname = write_import_file(fname)
args = ["table", "import", table] + import_flags
if primary_key:
args += ["--pk={}".format(",".join(primary_key))]
Expand Down
75 changes: 73 additions & 2 deletions tests/test_write.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
import pytest

from doltcli import CREATE, read_rows, write_columns, write_rows, DoltException, write_file
from doltcli import (
CREATE,
DoltException,
read_rows,
write_columns,
write_file,
write_rows,
)
from tests.helpers import compare_rows_helper, write_dict_to_csv

# Note that we use string values here as serializing via CSV does preserve type information in any meaningful way
Expand Down Expand Up @@ -47,7 +54,7 @@ def test_write_columns_uneven(init_empty_test_repo):
write_columns(repo, "players", DICT_OF_LISTS_UNEVEN_LENGTHS, CREATE, ["name"])


def test_write_file(init_empty_test_repo, tmp_path):
def test_write_file_handle(init_empty_test_repo, tmp_path):
tempfile = tmp_path / "test.csv"
TEST_ROWS = [
{"name": "Anna", "adjective": "tragic", "id": "1", "date_of_death": "1877-01-01"},
Expand All @@ -74,3 +81,67 @@ def test_write_file(init_empty_test_repo, tmp_path):
)
actual = read_rows(dolt, "characters")
compare_rows_helper(TEST_ROWS[:2], actual)


def test_write_file(init_empty_test_repo, tmp_path):
tempfile = tmp_path / "test.csv"
TEST_ROWS = [
{"name": "Anna", "adjective": "tragic", "id": "1", "date_of_death": "1877-01-01"},
{"name": "Vronksy", "adjective": "honorable", "id": "2", "date_of_death": ""},
{"name": "Vronksy", "adjective": "honorable", "id": "2", "date_of_death": ""},
]
write_dict_to_csv(TEST_ROWS, tempfile)
dolt = init_empty_test_repo
write_file(
dolt=dolt,
table="characters",
file=tempfile,
import_mode=CREATE,
primary_key=["id"],
do_continue=True,
)
actual = read_rows(dolt, "characters")
compare_rows_helper(TEST_ROWS[:2], actual)


def test_write_file_errors(init_empty_test_repo, tmp_path):
tempfile = tmp_path / "test.csv"
TEST_ROWS = [
{"name": "Anna", "adjective": "tragic", "id": "1", "date_of_death": "1877-01-01"},
{"name": "Vronksy", "adjective": "honorable", "id": "2", "date_of_death": ""},
{"name": "Vronksy", "adjective": "honorable", "id": "2", "date_of_death": ""},
]
write_dict_to_csv(TEST_ROWS, tempfile)
dolt = init_empty_test_repo
with pytest.raises(DoltException):
write_file(
dolt=dolt,
table="characters",
file_handle=open(tempfile),
import_mode=CREATE,
primary_key=["id"],
)
with pytest.raises(ValueError):
write_file(
dolt=dolt,
table="characters",
file_handle=open(tempfile),
file=tempfile,
import_mode=CREATE,
primary_key=["id"],
)
with pytest.raises(ValueError):
write_file(
dolt=dolt,
table="characters",
import_mode=CREATE,
primary_key=["id"],
)
with pytest.raises(ValueError):
write_file(
dolt=dolt,
file_handle=tempfile,
table="characters",
import_mode=CREATE,
primary_key=["id"],
)

0 comments on commit 31608a3

Please sign in to comment.