diff --git a/nodestream/cli/operations/run_copy.py b/nodestream/cli/operations/run_copy.py
index d3331329..0b5629db 100644
--- a/nodestream/cli/operations/run_copy.py
+++ b/nodestream/cli/operations/run_copy.py
@@ -2,6 +2,7 @@
 
 from ...databases import Copier, GraphDatabaseWriter
 from ...pipeline import Pipeline
+from ...pipeline.object_storage import ObjectStore
 from ...project import Project, Target
 from ..commands.nodestream_command import NodestreamCommand
 from .operation import Operation
@@ -29,7 +30,9 @@ async def perform(self, command: NodestreamCommand):
     def build_pipeline(self) -> Pipeline:
         copier = self.build_copier()
         writer = self.build_writer()
-        return Pipeline([copier, writer], step_outbox_size=10000)
+        return Pipeline(
+            [copier, writer], step_outbox_size=10000, object_store=ObjectStore.null()
+        )
 
     def build_copier(self) -> Copier:
         return Copier(
diff --git a/nodestream/cli/operations/run_pipeline.py b/nodestream/cli/operations/run_pipeline.py
index df12034c..724cf6c6 100644
--- a/nodestream/cli/operations/run_pipeline.py
+++ b/nodestream/cli/operations/run_pipeline.py
@@ -4,6 +4,7 @@
 
 from ...metrics import Metrics
 from ...pipeline import PipelineInitializationArguments, PipelineProgressReporter
+from ...pipeline.object_storage import ObjectStore
 from ...project import Project, RunRequest
 from ...project.pipeline_definition import PipelineDefinition
 from ...utils import StringSuggester
@@ -101,6 +102,7 @@ def print_effective_config(config):
                 extra_steps=list(
                     self.get_writer_steps_for_specified_targets(command, pipeline)
                 ),
+                object_store=ObjectStore.in_current_directory(),
             ),
             progress_reporter=self.create_progress_reporter(command, pipeline.name),
         )
diff --git a/nodestream/pipeline/extractors/extractor.py b/nodestream/pipeline/extractors/extractor.py
index 259113da..95002a4e 100644
--- a/nodestream/pipeline/extractors/extractor.py
+++ b/nodestream/pipeline/extractors/extractor.py
@@ -1,10 +1,15 @@
 from abc import abstractmethod
-from typing import Any, AsyncGenerator
+from typing import Any, AsyncGenerator, Generic, TypeVar
 
-from ..step import Step
+from ..step import Step, StepContext
 
+R = TypeVar("R")
+T = TypeVar("T")
+CHECKPOINT_OBJECT_KEY = "extractor_progress_checkpoint"
+CHECKPOINT_INTERVAL = 1000
 
-class Extractor(Step):
+
+class Extractor(Step, Generic[R, T]):
     """Extractors represent the source of a set of records.
 
     They are like any other step. However, they ignore the incoming record
@@ -12,9 +17,35 @@ class Extractor(Step, Generic[R, T]):
     they generally should only be set at the beginning of a pipeline.
     """
 
-    def emit_outstanding_records(self):
-        return self.extract_records()
+    async def start(self, context: StepContext):
+        if checkpoint := context.object_store.get_pickled(CHECKPOINT_OBJECT_KEY):
+            context.info("Found Checkpoint For Extractor. Signaling to resume from it.")
+            await self.resume_from_checkpoint(checkpoint)
+
+    async def finish(self, context: StepContext):
+        context.debug("Clearing checkpoint for extractor since extractor is finished.")
+        context.object_store.delete(CHECKPOINT_OBJECT_KEY)
+
+    async def make_checkpoint(self) -> T:
+        return None
+
+    async def resume_from_checkpoint(self, checkpoint_object: T):
+        pass
+
+    async def commit_checkpoint(self, context: StepContext) -> None:
+        if checkpoint := await self.make_checkpoint():
+            context.object_store.put_picklable(CHECKPOINT_OBJECT_KEY, checkpoint)
+
+    async def emit_outstanding_records(
+        self, context: StepContext
+    ) -> AsyncGenerator[R, None]:
+        items_generated = 0
+        async for record in self.extract_records():
+            yield record
+            items_generated += 1
+            if items_generated % CHECKPOINT_INTERVAL == 0:
+                await self.commit_checkpoint(context)
 
     @abstractmethod
-    async def extract_records(self) -> AsyncGenerator[Any, Any]:
+    def extract_records(self) -> AsyncGenerator[R, Any]:
         raise NotImplementedError
diff --git a/nodestream/pipeline/extractors/iterable.py b/nodestream/pipeline/extractors/iterable.py
index 6e02710f..5bc8b454 100644
--- a/nodestream/pipeline/extractors/iterable.py
+++ b/nodestream/pipeline/extractors/iterable.py
@@ -1,3 +1,4 @@
+from logging import getLogger
 from typing import Any, AsyncGenerator, Iterable
 
 from .extractor import Extractor
@@ -12,7 +13,20 @@ def range(cls, start=0, stop=100, step=1):
 
     def __init__(self, iterable: Iterable[Any]) -> None:
         self.iterable = iterable
+        self.index = 0
+        self.logger = getLogger(self.__class__.__name__)
 
     async def extract_records(self) -> AsyncGenerator[Any, Any]:
-        for record in self.iterable:
+        for index, record in enumerate(self.iterable):
+            if index < self.index:
+                continue
+            self.index = index
             yield record
+
+    async def make_checkpoint(self):
+        return self.index
+
+    async def resume_from_checkpoint(self, checkpoint):
+        if isinstance(checkpoint, int):
+            self.index = checkpoint
+            self.logger.info(f"Resuming from checkpoint {checkpoint}")
diff --git a/nodestream/pipeline/object_storage.py b/nodestream/pipeline/object_storage.py
index fc4627a2..89dff2c7 100644
--- a/nodestream/pipeline/object_storage.py
+++ b/nodestream/pipeline/object_storage.py
@@ -202,6 +202,14 @@ def signed(self, signer: "Signer") -> "ObjectStore":
         """
         return SignedObjectStore(self, signer)
 
+    @staticmethod
+    def null() -> "ObjectStore":
+        return NullObjectStore()
+
+    @staticmethod
+    def in_current_directory():
+        return DirectoryObjectStore(Path.cwd() / ".nodestream" / "objects")
+
 
 class DirectoryObjectStore(ObjectStore, alias="directory"):
     """An object store that stores objects in a directory on a file system."""
@@ -209,10 +217,6 @@ class DirectoryObjectStore(ObjectStore, alias="directory"):
     def __init__(self, root: Path):
         self.root = root
 
-    @staticmethod
-    def in_current_directory():
-        return DirectoryObjectStore(Path.cwd() / ".nodestream" / "objects")
-
     def get(self, key: str) -> Optional[bytes]:
         path = self.root / key
         if not path.exists():
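
For illustration, a minimal sketch of how a custom extractor could adopt the checkpoint hooks introduced above. FileLineExtractor and its line-offset checkpoint are hypothetical; only the Extractor base class, make_checkpoint, and resume_from_checkpoint come from this patch.

from typing import Any, AsyncGenerator

from nodestream.pipeline.extractors.extractor import Extractor


class FileLineExtractor(Extractor):
    """Hypothetical extractor that emits one record per line of a file."""

    def __init__(self, path: str) -> None:
        self.path = path
        self.lines_emitted = 0

    async def extract_records(self) -> AsyncGenerator[Any, Any]:
        with open(self.path) as f:
            for number, line in enumerate(f):
                if number < self.lines_emitted:
                    continue  # Already emitted before the last checkpoint.
                self.lines_emitted = number + 1
                yield {"line": line.rstrip("\n")}

    async def make_checkpoint(self):
        # Called by Extractor.emit_outstanding_records every CHECKPOINT_INTERVAL
        # records; the value is pickled into the pipeline's object store under
        # CHECKPOINT_OBJECT_KEY.
        return self.lines_emitted

    async def resume_from_checkpoint(self, checkpoint_object):
        # Called from Extractor.start when a pickled checkpoint is found.
        if isinstance(checkpoint_object, int):
            self.lines_emitted = checkpoint_object
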
diff --git a/nodestream/pipeline/pipeline.py b/nodestream/pipeline/pipeline.py
index 210d0e8c..a4cdb770 100644
--- a/nodestream/pipeline/pipeline.py
+++ b/nodestream/pipeline/pipeline.py
@@ -5,6 +5,7 @@
 from ..metrics import Metric, Metrics
 from ..schema import ExpandsSchema, ExpandsSchemaFromChildren
 from .channel import StepInput, StepOutput, channel
+from .object_storage import ObjectStore
 from .progress_reporter import PipelineProgressReporter
 from .step import Step, StepContext
 
@@ -62,7 +63,7 @@ async def drive_step(self):
                 if not await self.emit_record(record):
                     return
 
-        async for record in self.step.emit_outstanding_records():
+        async for record in self.step.emit_outstanding_records(self.context):
             if not await self.emit_record(record):
                 return
 
@@ -129,12 +130,15 @@ class Pipeline(ExpandsSchemaFromChildren):
     and running the steps in the pipeline.
     """
 
-    __slots__ = ("steps", "step_outbox_size")
+    __slots__ = ("steps", "step_outbox_size", "logger", "object_store")
 
-    def __init__(self, steps: Tuple[Step, ...], step_outbox_size: int) -> None:
+    def __init__(
+        self, steps: Tuple[Step, ...], step_outbox_size: int, object_store: ObjectStore
+    ) -> None:
         self.steps = steps
         self.step_outbox_size = step_outbox_size
         self.logger = getLogger(self.__class__.__name__)
+        self.object_store = object_store
 
     def get_child_expanders(self) -> Iterable[ExpandsSchema]:
         return (s for s in self.steps if isinstance(s, ExpandsSchema))
@@ -175,7 +179,8 @@ async def run(self, reporter: PipelineProgressReporter):
         # input of the next step.
         for reversed_index, step in reversed(list(enumerate(self.steps))):
             index = len(self.steps) - reversed_index - 1
-            context = StepContext(step.__class__.__name__, index, reporter)
+            storage = self.object_store.namespaced(str(index))
+            context = StepContext(step.__class__.__name__, index, reporter, storage)
             current_input, next_output = channel(self.step_outbox_size)
             exec = StepExecutor(step, current_input, current_output, context)
             current_output = next_output
diff --git a/nodestream/pipeline/pipeline_file_loader.py b/nodestream/pipeline/pipeline_file_loader.py
index fe48fc76..587640e1 100644
--- a/nodestream/pipeline/pipeline_file_loader.py
+++ b/nodestream/pipeline/pipeline_file_loader.py
@@ -1,4 +1,5 @@
-from dataclasses import dataclass
+import hashlib
+from dataclasses import dataclass, field
 from logging import getLogger
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Set
@@ -12,6 +13,7 @@
 from .argument_resolvers import set_config
 from .class_loader import ClassLoader
 from .normalizers import Normalizer
+from .object_storage import ObjectStore
 from .pipeline import Pipeline
 from .scope_config import ScopeConfig
 from .step import Step
@@ -51,6 +53,7 @@ class PipelineInitializationArguments:
     on_effective_configuration_resolved: Optional[Callable[[List[Dict]], None]] = None
     extra_steps: Optional[List[Step]] = None
     effecitve_config_values: Optional[ScopeConfig] = None
+    object_store: ObjectStore = field(default_factory=ObjectStore.null)
 
     @classmethod
     def for_introspection(cls):
@@ -149,7 +152,11 @@ def initialize_with_arguments(self, init_args: PipelineInitializationArguments):
             if step_definition.should_be_loaded(init_args.annotations)
         ]
         steps = steps_defined_in_file + (init_args.extra_steps or [])
-        return Pipeline(steps, step_outbox_size=init_args.step_outbox_size)
+        return Pipeline(
+            steps,
+            step_outbox_size=init_args.step_outbox_size,
+            object_store=init_args.object_store,
+        )
 
 
 class PipelineFile:
@@ -157,11 +164,21 @@ def __init__(self, file_path: Path):
         self.file_path = file_path
         self.logger = getLogger(self.__class__.__name__)
 
+    def file_sha_256(self) -> str:
+        sha = hashlib.sha256()
+        b = bytearray(128 * 1024)
+        mv = memoryview(b)
+        with self.file_path.open("rb", buffering=0) as file:
+            while n := file.readinto(mv):
+                sha.update(mv[:n])
+        return sha.hexdigest()
+
     def load_pipeline(
         self, init_args: Optional[PipelineInitializationArguments] = None
     ) -> Pipeline:
         self.logger.info("Loading Pipeline")
         init_args = init_args or PipelineInitializationArguments()
+        init_args.object_store = init_args.object_store.namespaced(self.file_sha_256())
         contents = self.get_contents()
         return contents.initialize_with_arguments(init_args)
 
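
Taken together, the loader and Pipeline.run above namespace every checkpoint twice: once by the SHA-256 of the pipeline file, and once by the step index. A sketch of the equivalent wiring done by hand; the placeholder namespace string and the literal checkpoint value are illustrative, while the ObjectStore API is exactly what this patch uses.

from nodestream.pipeline.object_storage import ObjectStore

# Root store used by `nodestream run`: .nodestream/objects under the cwd.
store = ObjectStore.in_current_directory()

# PipelineFile.load_pipeline namespaces by the file's SHA-256 so two
# pipeline files never read each other's checkpoints...
per_pipeline = store.namespaced("<sha256 of the pipeline file>")

# ...and Pipeline.run namespaces by step index so two steps never collide.
per_step = per_pipeline.namespaced("0")

per_step.put_picklable("extractor_progress_checkpoint", 1000)
assert per_step.get_pickled("extractor_progress_checkpoint") == 1000
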
diff --git a/nodestream/pipeline/step.py b/nodestream/pipeline/step.py
index fcab657e..b18181e0 100644
--- a/nodestream/pipeline/step.py
+++ b/nodestream/pipeline/step.py
@@ -1,6 +1,7 @@
 from typing import AsyncGenerator, Optional
 
 from ..metrics import Metric, Metrics
+from .object_storage import ObjectStore
 from .progress_reporter import PipelineProgressReporter
 
 
@@ -11,14 +12,19 @@ class StepContext:
     and report and perist information about the state of the pipeline.
     """
 
-    __slots__ = ("reporter", "index", "name")
+    __slots__ = ("reporter", "index", "name", "object_store")
 
     def __init__(
-        self, name: str, index: int, reporter: PipelineProgressReporter
+        self,
+        name: str,
+        index: int,
+        reporter: PipelineProgressReporter,
+        object_store: ObjectStore,
     ) -> None:
         self.name = name
         self.reporter = reporter
         self.index = index
+        self.object_store = object_store
 
     def report_error(
         self,
@@ -113,7 +119,7 @@ async def process_record(
         """
         yield record
 
-    async def emit_outstanding_records(self):
+    async def emit_outstanding_records(self, context: StepContext):
         """Emit any outstanding records.
 
         This method is called after all records have been processed. It is
diff --git a/tests/integration/test_pipeline_flush_handling.py b/tests/integration/test_pipeline_flush_handling.py
index b98423b4..8c48655e 100644
--- a/tests/integration/test_pipeline_flush_handling.py
+++ b/tests/integration/test_pipeline_flush_handling.py
@@ -39,7 +39,9 @@ def test_extractor_with_flushes(mocker):
 
 
 @pytest.mark.asyncio
-async def test_flush_handling(writer, interpreter, test_extractor_with_flushes):
-    pipeline = Pipeline([test_extractor_with_flushes, interpreter, writer], 1000)
+async def test_flush_handling(writer, interpreter, test_extractor_with_flushes, mocker):
+    pipeline = Pipeline(
+        [test_extractor_with_flushes, interpreter, writer], 1000, mocker.Mock()
+    )
     await pipeline.run(PipelineProgressReporter())
     assert writer.ingest_strategy.flush.call_count == 5
diff --git a/tests/unit/pipeline/extractors/test_iterable_extractor.py b/tests/unit/pipeline/extractors/test_iterable_extractor.py
index cfcfa3a6..3f9bf8d0 100644
--- a/tests/unit/pipeline/extractors/test_iterable_extractor.py
+++ b/tests/unit/pipeline/extractors/test_iterable_extractor.py
@@ -18,3 +18,22 @@ async def test_iterable_extractor_range_factory():
     subject = IterableExtractor.range(1, 6, 2)
     results = [item async for item in subject.extract_records()]
     assert_that(results, equal_to(expected_results))
+
+
+@pytest.mark.asyncio
+async def test_iterable_extractor_resume():
+    expected_results = [{"index": 3}, {"index": 5}]
+    subject = IterableExtractor.range(1, 6, 2)
+    await subject.resume_from_checkpoint(1)
+    results = [item async for item in subject.extract_records()]
+    assert_that(results, equal_to(expected_results))
+
+
+@pytest.mark.asyncio
+async def test_iterable_extractor_checkpoint():
+    subject = IterableExtractor([1, 2, 3])
+    generator = subject.extract_records()
+    await anext(generator)
+    await anext(generator)
+    checkpoint = await subject.make_checkpoint()
+    assert_that(checkpoint, equal_to(1))
diff --git a/tests/unit/pipeline/test_object_storage.py b/tests/unit/pipeline/test_object_storage.py
index 24eddba4..0d2cdf1a 100644
--- a/tests/unit/pipeline/test_object_storage.py
+++ b/tests/unit/pipeline/test_object_storage.py
@@ -10,6 +10,7 @@
     InvalidSignatureError,
     MalformedSignedObjectError,
     NullObjectStore,
+    ObjectStore,
     SignedObject,
     StaticNamespace,
 )
@@ -106,7 +107,7 @@ def test_get_pickled_missing_object(directory_object_store):
 
 
 def test_directory_object_store_default_directory():
-    store = DirectoryObjectStore.in_current_directory()
+    store = ObjectStore.in_current_directory()
     assert_that(store.root, equal_to(Path.cwd() / ".nodestream" / "objects"))
 
 
diff --git a/tests/unit/pipeline/test_pipeline_progress_reporter.py b/tests/unit/pipeline/test_pipeline_progress_reporter.py
index 4c8d2f27..46610961 100644
--- a/tests/unit/pipeline/test_pipeline_progress_reporter.py
+++ b/tests/unit/pipeline/test_pipeline_progress_reporter.py
@@ -6,7 +6,7 @@
 
 @pytest.mark.asyncio
 async def test_pipeline_progress_reporter_calls_with_reporting_frequency(mocker):
-    pipeline = Pipeline([IterableExtractor(range(100))], 10)
+    pipeline = Pipeline([IterableExtractor(range(100))], 10, mocker.Mock())
     reporter = PipelineProgressReporter(reporting_frequency=10, callback=mocker.Mock())
     await pipeline.run(reporter)
     assert_that(reporter.callback.call_count, equal_to(10))
diff --git a/tests/unit/pipeline/test_step.py b/tests/unit/pipeline/test_step.py
index 7aa523bb..8a156d78 100644
--- a/tests/unit/pipeline/test_step.py
+++ b/tests/unit/pipeline/test_step.py
@@ -23,7 +23,7 @@ async def test_default_finish_does_nothing(mocker):
 @pytest.mark.asyncio
 async def test_default_emit_outstanding_records_does_nothing(mocker):
     step = Step()
-    async for _ in step.emit_outstanding_records():
+    async for _ in step.emit_outstanding_records(mocker.Mock(spec=StepContext)):
         assert False, "Should not emit any records"
 
 
@@ -40,7 +40,10 @@ async def test_default_process_record_yields_input_record(mocker):
 async def test_step_context_report_error(mocker):
     exception = Exception()
     ctx = StepContext(
-        "bob", 1, PipelineProgressReporter(on_fatal_error_callback=mocker.Mock())
+        "bob",
+        1,
+        PipelineProgressReporter(on_fatal_error_callback=mocker.Mock()),
+        object_store=mocker.Mock(),
     )
     ctx.report_error("oh no, an error!", exception)
     ctx.reporter.on_fatal_error_callback.assert_not_called()
@@ -50,7 +53,10 @@ async def test_step_context_report_error(mocker):
 async def test_step_context_report_fatal_error(mocker):
     exception = Exception()
     ctx = StepContext(
-        "bob", 1, PipelineProgressReporter(on_fatal_error_callback=mocker.Mock())
+        "bob",
+        1,
+        PipelineProgressReporter(on_fatal_error_callback=mocker.Mock()),
+        object_store=mocker.Mock(),
     )
     ctx.report_error("oh no, a fatal error!", exception, fatal=True)
     ctx.reporter.on_fatal_error_callback.assert_called_once_with(exception)
@@ -58,7 +64,12 @@ async def test_step_context_report_fatal_error(mocker):
 
 @pytest.mark.asyncio
 async def test_step_context_report_debug_message(mocker):
-    ctx = StepContext("bob", 1, PipelineProgressReporter(logger=mocker.Mock()))
+    ctx = StepContext(
+        "bob",
+        1,
+        PipelineProgressReporter(logger=mocker.Mock()),
+        object_store=mocker.Mock(),
+    )
     ctx.debug("debug message", x=12)
     ctx.reporter.logger.debug.assert_called_once_with(
         "debug message", extra={"index": 1, "x": 12, "step_name": "bob"}
@@ -67,7 +78,12 @@ async def test_step_context_report_debug_message(mocker):
 
 @pytest.mark.asyncio
 async def test_step_context_report_info_message(mocker):
-    ctx = StepContext("bob", 1, PipelineProgressReporter(logger=mocker.Mock()))
+    ctx = StepContext(
+        "bob",
+        1,
+        PipelineProgressReporter(logger=mocker.Mock()),
+        object_store=mocker.Mock(),
+    )
     ctx.info("info message", x=12)
     ctx.reporter.logger.info.assert_called_once_with(
         "info message", extra={"index": 1, "x": 12, "step_name": "bob"}
@@ -76,7 +92,12 @@ async def test_step_context_report_info_message(mocker):
 
 @pytest.mark.asyncio
 async def test_step_context_report_warning_message(mocker):
-    ctx = StepContext("bob", 1, PipelineProgressReporter(logger=mocker.Mock()))
+    ctx = StepContext(
+        "bob",
+        1,
+        PipelineProgressReporter(logger=mocker.Mock()),
+        object_store=mocker.Mock(),
+    )
     ctx.warning("warning message", x=12)
     ctx.reporter.logger.warning.assert_called_once_with(
         "warning message", extra={"index": 1, "x": 12, "step_name": "bob"}
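
One consequence of the step.py change above: any custom Step that overrides emit_outstanding_records must adopt the new signature, since StepExecutor.drive_step now passes the StepContext through. A sketch of the adaptation; MyBufferingStep and its buffer are hypothetical, only the Step and StepContext API comes from this patch.

from typing import Any, AsyncGenerator

from nodestream.pipeline.step import Step, StepContext


class MyBufferingStep(Step):
    def __init__(self) -> None:
        self.buffer = []

    # Before this patch the override was: async def emit_outstanding_records(self).
    # The executor now supplies the step's context as an argument.
    async def emit_outstanding_records(
        self, context: StepContext
    ) -> AsyncGenerator[Any, None]:
        for record in self.buffer:
            yield record
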
diff --git a/tests/unit/test_extractor.py b/tests/unit/test_extractor.py
new file mode 100644
index 00000000..0f093e02
--- /dev/null
+++ b/tests/unit/test_extractor.py
@@ -0,0 +1,71 @@
+import pytest
+from hamcrest import assert_that, is_, none
+
+from nodestream.pipeline import Extractor
+from nodestream.pipeline.object_storage import ObjectStore
+from nodestream.pipeline.progress_reporter import PipelineProgressReporter
+from nodestream.pipeline.step import StepContext
+
+
+class DummyExtractor(Extractor):
+    def __init__(self, checkpoint=None):
+        self.checkpoint = checkpoint
+
+    async def extract_records(self):
+        for x in []:
+            yield x
+
+    async def make_checkpoint(self):
+        return self.checkpoint
+
+    async def resume_from_checkpoint(self, checkpoint_object):
+        self.checkpoint = checkpoint_object
+
+
+@pytest.fixture
+def object_store(mocker):
+    return mocker.Mock(spec=ObjectStore)
+
+
+@pytest.fixture
+def context(object_store):
+    return StepContext("test", 1, PipelineProgressReporter(), object_store)
+
+
+@pytest.mark.asyncio
+async def test_extractor_commit_does_nothing(object_store, context):
+    extractor = DummyExtractor()
+    await extractor.commit_checkpoint(context)
+    object_store.put_picklable.assert_not_called()
+
+
+@pytest.mark.asyncio
+async def test_extractor_commit_does_something_with_checkpoint(object_store, context):
+    extractor = DummyExtractor(checkpoint=1)
+    await extractor.commit_checkpoint(context)
+    object_store.put_picklable.assert_called_once_with(
+        "extractor_progress_checkpoint", 1
+    )
+
+
+@pytest.mark.asyncio
+async def test_extractor_finish_cleans(object_store, context):
+    extractor = DummyExtractor()
+    await extractor.finish(context)
+    object_store.delete.assert_called_once_with("extractor_progress_checkpoint")
+
+
+@pytest.mark.asyncio
+async def test_extractor_start_does_nothing_when_no_checkpoint(object_store, context):
+    extractor = DummyExtractor()
+    object_store.get_pickled.return_value = None
+    await extractor.start(context)
+    assert_that(extractor.checkpoint, is_(none()))
+
+
+@pytest.mark.asyncio
+async def test_extractor_start_does_something_when_checkpoint(object_store, context):
+    extractor = DummyExtractor()
+    object_store.get_pickled.return_value = 1
+    await extractor.start(context)
+    assert_that(extractor.checkpoint, is_(1))
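
Finally, an end-to-end sketch of the resume behavior these tests exercise. The driver below is hypothetical, and IterableExtractor is imported by its module path as an assumption; Pipeline, PipelineProgressReporter, and ObjectStore are used as in the patch.

import asyncio

from nodestream.pipeline import Pipeline, PipelineProgressReporter
from nodestream.pipeline.extractors.iterable import IterableExtractor
from nodestream.pipeline.object_storage import ObjectStore


async def main():
    store = ObjectStore.in_current_directory()
    extractor = IterableExtractor(range(5000))
    pipeline = Pipeline([extractor], step_outbox_size=1000, object_store=store)
    # If a previous run died after a checkpoint (one is committed every
    # CHECKPOINT_INTERVAL = 1000 records), Extractor.start finds the pickled
    # index and IterableExtractor.extract_records skips past it. A run that
    # finishes cleanly deletes the checkpoint in Extractor.finish.
    await pipeline.run(PipelineProgressReporter())


asyncio.run(main())
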