Skip to content

Commit

Permalink
Lint/format
Browse files Browse the repository at this point in the history
  • Loading branch information
danpalmer committed Mar 1, 2021
1 parent 72c12a5 commit 5ea0b5f
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 25 deletions.
10 changes: 7 additions & 3 deletions src/devdata/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,14 @@ def export_data(django_dbname, only=None, no_update=False):
continue

model = to_model(app_model_label)
bar.set_description('{} ({})'.format(app_model_label, strategy.name))
bar.set_description(
"{} ({})".format(app_model_label, strategy.name)
)

if isinstance(strategy, Exportable):
strategy.export_data(django_dbname, model, exporter, no_update, log=bar.write)
strategy.export_data(
django_dbname, model, exporter, no_update, log=bar.write
)


def import_schema(django_dbname):
Expand Down Expand Up @@ -142,7 +146,7 @@ def import_data(django_dbname):
bar = progress(model_strategies)
for app_model_label, strategy in bar:
model = to_model(app_model_label)
bar.set_description('{} ({})'.format(app_model_label, strategy.name))
bar.set_description("{} ({})".format(app_model_label, strategy.name))
strategy.import_data(django_dbname, model)


Expand Down
35 changes: 20 additions & 15 deletions src/devdata/exporting.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,18 @@
from tempfile import NamedTemporaryFile, TemporaryDirectory, TemporaryFile
from typing import Any, Dict, Iterator, NamedTuple

SEPARATOR_BYTE = b'\x00'
SEPARATOR_BYTE = b"\x00"
BLOCK_SIZE = 4096

ExportRequest = NamedTuple('ExportRequest', (
('app_model_label', str),
('strategy_class', str),
('strategy_kwargs', Dict[str, Any]),
('database', str),
))
ExportRequest = NamedTuple(
"ExportRequest",
(
("app_model_label", str),
("strategy_class", str),
("strategy_kwargs", Dict[str, Any]),
("database", str),
),
)


def read_requests(fh: io.BytesIO) -> Iterator[ExportRequest]:
Expand All @@ -25,7 +28,7 @@ def read_requests(fh: io.BytesIO) -> Iterator[ExportRequest]:

left, separator, right = buffer.partition(SEPARATOR_BYTE)
if separator:
command_dict = json.loads(left.decode('utf-8'))
command_dict = json.loads(left.decode("utf-8"))
yield ExportRequest(**command_dict)
buffer = right

Expand All @@ -41,7 +44,7 @@ def response_writer(fh: io.BytesIO) -> Iterator[io.BytesIO]:
class ExportFailed(Exception):
def __init__(self, exitcode: int):
self.exitcode = exitcode

def __repr__(self):
return "ExportFailed(exitcode={})".format(self.exitcode)

Expand All @@ -51,16 +54,16 @@ def __init__(self, command: str):
self.command = command
self._process = None
self.buffer = bytes()
def __enter__(self) -> 'Exporter':

def __enter__(self) -> "Exporter":
self._process = subprocess.Popen(
self.command.split(),
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
)
self._process.__enter__()
return self

def __exit__(self, *args, **kwargs):
self._process.__exit__(*args, **kwargs)

Expand All @@ -83,7 +86,7 @@ def export(
written = 0

with TemporaryDirectory() as tempdir:
request_data = json.dumps(export_request._asdict()).encode('utf-8')
request_data = json.dumps(export_request._asdict()).encode("utf-8")
self._process.stdin.write(request_data)
self._process.stdin.write(SEPARATOR_BYTE)
self._process.stdin.flush()
Expand All @@ -95,12 +98,14 @@ def export(
raise ExportFailed(self._process.returncode)

self.buffer += self._process.stdout.read1(BLOCK_SIZE)
data, separator, rest = self.buffer.partition(SEPARATOR_BYTE)
data, separator, rest = self.buffer.partition(
SEPARATOR_BYTE
)
written += output.write(data)
self.buffer = rest
if separator:
break

temp_file_path.rename(output_path)

return written
6 changes: 4 additions & 2 deletions src/devdata/management/commands/devdata_exporter_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

class Command(BaseCommand):
help = "Dump the data generated by a given strategy."

def handle(self, **options):
for request in read_requests(sys.stdin.buffer):
with response_writer(sys.stdout.buffer) as output:
Expand All @@ -21,7 +21,9 @@ def handle_request(self, request, output):
app_config = apps.get_app_config(app_name)
model = app_config.get_model(model_name)

strategy = import_string(request.strategy_class)(**request.strategy_kwargs)
strategy = import_string(request.strategy_class)(
**request.strategy_kwargs
)

if not isinstance(strategy, Exportable):
raise CommandError("Strategy not locally exportable")
Expand Down
28 changes: 23 additions & 5 deletions src/devdata/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,14 @@ def __init__(self, *args, name, **kwargs):

self.name = name

def export_data(self, django_dbname, model, exporter, no_update=False, log=lambda x: None):
def export_data(
self,
django_dbname,
model,
exporter,
no_update=False,
log=lambda x: None,
):
"""
Export the data to a directory on disk. `no_update` indicates not to
update if there is any data already existing locally.
Expand Down Expand Up @@ -139,7 +146,7 @@ def export_local(self, django_dbname, model, output):
else serializers.get_serializer("json")
)

stream = codecs.getwriter('utf-8')(output)
stream = codecs.getwriter("utf-8")(output)

serializer.serialize(
queryset.iterator(),
Expand All @@ -149,15 +156,24 @@ def export_local(self, django_dbname, model, output):
stream=stream,
)

def export_data(self, django_dbname, model, exporter, no_update=False, log=lambda x: None):
def export_data(
self,
django_dbname,
model,
exporter,
no_update=False,
log=lambda x: None,
):
app_model_label = to_app_model_label(model)
data_file = self.data_file(app_model_label)

if no_update and data_file.exists():
return

kwargs = self.get_kwargs(model)
klass = "{}.{}".format(self.__class__.__module__, self.__class__.__name__)
klass = "{}.{}".format(
self.__class__.__module__, self.__class__.__name__
)

self.ensure_dir_exists(app_model_label)

Expand Down Expand Up @@ -191,7 +207,9 @@ def import_data(self, django_dbname, model):
)
self.import_objects(django_dbname, model, objects)
except Exception:
print("Failed to import {} ({})".format(app_model_label, self.name))
print(
"Failed to import {} ({})".format(app_model_label, self.name)
)
raise

def import_objects(self, django_dbname, model, objects):
Expand Down

0 comments on commit 5ea0b5f

Please sign in to comment.