Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Working Basic Checkpointing #385

Merged
merged 6 commits into from
Feb 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion nodestream/cli/operations/run_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from ...databases import Copier, GraphDatabaseWriter
from ...pipeline import Pipeline
from ...pipeline.object_storage import ObjectStore
from ...project import Project, Target
from ..commands.nodestream_command import NodestreamCommand
from .operation import Operation
Expand Down Expand Up @@ -29,7 +30,9 @@ async def perform(self, command: NodestreamCommand):
def build_pipeline(self) -> Pipeline:
copier = self.build_copier()
writer = self.build_writer()
return Pipeline([copier, writer], step_outbox_size=10000)
return Pipeline(
[copier, writer], step_outbox_size=10000, object_store=ObjectStore.null()
)

def build_copier(self) -> Copier:
return Copier(
Expand Down
2 changes: 2 additions & 0 deletions nodestream/cli/operations/run_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from ...metrics import Metrics
from ...pipeline import PipelineInitializationArguments, PipelineProgressReporter
from ...pipeline.object_storage import ObjectStore
from ...project import Project, RunRequest
from ...project.pipeline_definition import PipelineDefinition
from ...utils import StringSuggester
Expand Down Expand Up @@ -101,6 +102,7 @@ def print_effective_config(config):
extra_steps=list(
self.get_writer_steps_for_specified_targets(command, pipeline)
),
object_store=ObjectStore.in_current_directory(),
),
progress_reporter=self.create_progress_reporter(command, pipeline.name),
)
Expand Down
43 changes: 37 additions & 6 deletions nodestream/pipeline/extractors/extractor.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,51 @@
from abc import abstractmethod
from typing import Any, AsyncGenerator
from typing import Any, AsyncGenerator, Generic, TypeVar

from ..step import Step
from ..step import Step, StepContext

R = TypeVar("R")
T = TypeVar("T")
CHECKPOINT_OBJECT_KEY = "extractor_progress_checkpoint"
CHECKPOINT_INTERVAL = 1000

class Extractor(Step):

class Extractor(Step, Generic[R, T]):
"""Extractors represent the source of a set of records.

They are like any other step. However, they ignore the incoming record
stream and instead produce their own stream of records. For this reason
they generally should only be set at the beginning of a pipeline.
"""

def emit_outstanding_records(self):
return self.extract_records()
async def start(self, context: StepContext):
if checkpoint := context.object_store.get_pickled(CHECKPOINT_OBJECT_KEY):
context.info("Found Checkpoint For Extractor. Signaling to resume from it.")
await self.resume_from_checkpoint(checkpoint)

async def finish(self, context: StepContext):
context.debug("Clearing checkpoint for extractor since extractor is finished.")
context.object_store.delete(CHECKPOINT_OBJECT_KEY)

async def make_checkpoint(self) -> T:
return None

async def resume_from_checkpoint(self, checkpoint_object: T):
pass

async def commit_checkpoint(self, context: StepContext) -> None:
if checkpoint := await self.make_checkpoint():
context.object_store.put_picklable(CHECKPOINT_OBJECT_KEY, checkpoint)

async def emit_outstanding_records(
self, context: StepContext
) -> AsyncGenerator[R, None]:
items_generated = 0
async for record in self.extract_records():
yield record
items_generated += 1
if items_generated % CHECKPOINT_INTERVAL == 0:
await self.commit_checkpoint(context)

@abstractmethod
async def extract_records(self) -> AsyncGenerator[Any, Any]:
def extract_records(self) -> AsyncGenerator[R, Any]:
raise NotImplementedError
16 changes: 15 additions & 1 deletion nodestream/pipeline/extractors/iterable.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from logging import getLogger
from typing import Any, AsyncGenerator, Iterable

from .extractor import Extractor
Expand All @@ -12,7 +13,20 @@ def range(cls, start=0, stop=100, step=1):

def __init__(self, iterable: Iterable[Any]) -> None:
self.iterable = iterable
self.index = 0
self.logger = getLogger(self.__class__.__name__)

async def extract_records(self) -> AsyncGenerator[Any, Any]:
for record in self.iterable:
for index, record in enumerate(self.iterable):
if index < self.index:
continue
self.index = index
yield record

async def make_checkpoint(self):
return self.index

async def resume_from_checkpoint(self, checkpoint):
if isinstance(checkpoint, int):
self.index = checkpoint
self.logger.info(f"Resuming from checkpoint {checkpoint}")
12 changes: 8 additions & 4 deletions nodestream/pipeline/object_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,17 +202,21 @@ def signed(self, signer: "Signer") -> "ObjectStore":
"""
return SignedObjectStore(self, signer)

@staticmethod
def null() -> "ObjectStore":
return NullObjectStore()

@staticmethod
def in_current_directory():
return DirectoryObjectStore(Path.cwd() / ".nodestream" / "objects")


class DirectoryObjectStore(ObjectStore, alias="directory"):
"""An object store that stores objects in a directory on a file system."""

def __init__(self, root: Path):
self.root = root

@staticmethod
def in_current_directory():
return DirectoryObjectStore(Path.cwd() / ".nodestream" / "objects")

def get(self, key: str) -> Optional[bytes]:
path = self.root / key
if not path.exists():
Expand Down
13 changes: 9 additions & 4 deletions nodestream/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from ..metrics import Metric, Metrics
from ..schema import ExpandsSchema, ExpandsSchemaFromChildren
from .channel import StepInput, StepOutput, channel
from .object_storage import ObjectStore
from .progress_reporter import PipelineProgressReporter
from .step import Step, StepContext

Expand Down Expand Up @@ -62,7 +63,7 @@ async def drive_step(self):
if not await self.emit_record(record):
return

async for record in self.step.emit_outstanding_records():
async for record in self.step.emit_outstanding_records(self.context):
if not await self.emit_record(record):
return

Expand Down Expand Up @@ -129,12 +130,15 @@ class Pipeline(ExpandsSchemaFromChildren):
and running the steps in the pipeline.
"""

__slots__ = ("steps", "step_outbox_size")
__slots__ = ("steps", "step_outbox_size", "logger", "object_store")

def __init__(self, steps: Tuple[Step, ...], step_outbox_size: int) -> None:
def __init__(
self, steps: Tuple[Step, ...], step_outbox_size: int, object_store: ObjectStore
) -> None:
self.steps = steps
self.step_outbox_size = step_outbox_size
self.logger = getLogger(self.__class__.__name__)
self.object_store = object_store

def get_child_expanders(self) -> Iterable[ExpandsSchema]:
return (s for s in self.steps if isinstance(s, ExpandsSchema))
Expand Down Expand Up @@ -175,7 +179,8 @@ async def run(self, reporter: PipelineProgressReporter):
# input of the next step.
for reversed_index, step in reversed(list(enumerate(self.steps))):
index = len(self.steps) - reversed_index - 1
context = StepContext(step.__class__.__name__, index, reporter)
storage = self.object_store.namespaced(str(index))
context = StepContext(step.__class__.__name__, index, reporter, storage)
current_input, next_output = channel(self.step_outbox_size)
exec = StepExecutor(step, current_input, current_output, context)
current_output = next_output
Expand Down
21 changes: 19 additions & 2 deletions nodestream/pipeline/pipeline_file_loader.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from dataclasses import dataclass
import hashlib
from dataclasses import dataclass, field
from logging import getLogger
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Set
Expand All @@ -12,6 +13,7 @@
from .argument_resolvers import set_config
from .class_loader import ClassLoader
from .normalizers import Normalizer
from .object_storage import ObjectStore
from .pipeline import Pipeline
from .scope_config import ScopeConfig
from .step import Step
Expand Down Expand Up @@ -51,6 +53,7 @@ class PipelineInitializationArguments:
on_effective_configuration_resolved: Optional[Callable[[List[Dict]], None]] = None
extra_steps: Optional[List[Step]] = None
effecitve_config_values: Optional[ScopeConfig] = None
object_store: ObjectStore = field(default_factory=ObjectStore.null)
angelosantos4 marked this conversation as resolved.
Show resolved Hide resolved

@classmethod
def for_introspection(cls):
Expand Down Expand Up @@ -149,19 +152,33 @@ def initialize_with_arguments(self, init_args: PipelineInitializationArguments):
if step_definition.should_be_loaded(init_args.annotations)
]
steps = steps_defined_in_file + (init_args.extra_steps or [])
return Pipeline(steps, step_outbox_size=init_args.step_outbox_size)
return Pipeline(
steps,
step_outbox_size=init_args.step_outbox_size,
object_store=init_args.object_store,
)


class PipelineFile:
def __init__(self, file_path: Path):
self.file_path = file_path
self.logger = getLogger(self.__class__.__name__)

def file_sha_256(self) -> str:
angelosantos4 marked this conversation as resolved.
Show resolved Hide resolved
sha = hashlib.sha256()
b = bytearray(128 * 1024)
mv = memoryview(b)
with self.file_path.open("rb", buffering=0) as file:
while n := file.readinto(mv):
sha.update(mv[:n])
return sha.hexdigest()

def load_pipeline(
self, init_args: Optional[PipelineInitializationArguments] = None
) -> Pipeline:
self.logger.info("Loading Pipeline")
init_args = init_args or PipelineInitializationArguments()
init_args.object_store = init_args.object_store.namespaced(self.file_sha_256())
contents = self.get_contents()
return contents.initialize_with_arguments(init_args)

Expand Down
12 changes: 9 additions & 3 deletions nodestream/pipeline/step.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import AsyncGenerator, Optional

from ..metrics import Metric, Metrics
from .object_storage import ObjectStore
from .progress_reporter import PipelineProgressReporter


Expand All @@ -11,14 +12,19 @@ class StepContext:
and report and persist information about the state of the pipeline.
"""

__slots__ = ("reporter", "index", "name")
__slots__ = ("reporter", "index", "name", "object_store")

def __init__(
self, name: str, index: int, reporter: PipelineProgressReporter
self,
name: str,
index: int,
reporter: PipelineProgressReporter,
object_store: ObjectStore,
) -> None:
self.name = name
self.reporter = reporter
self.index = index
self.object_store = object_store

def report_error(
self,
Expand Down Expand Up @@ -113,7 +119,7 @@ async def process_record(
"""
yield record

async def emit_outstanding_records(self):
async def emit_outstanding_records(self, context: StepContext):
"""Emit any outstanding records.

This method is called after all records have been processed. It is
Expand Down
6 changes: 4 additions & 2 deletions tests/integration/test_pipeline_flush_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ def test_extractor_with_flushes(mocker):


@pytest.mark.asyncio
async def test_flush_handling(writer, interpreter, test_extractor_with_flushes):
pipeline = Pipeline([test_extractor_with_flushes, interpreter, writer], 1000)
async def test_flush_handling(writer, interpreter, test_extractor_with_flushes, mocker):
pipeline = Pipeline(
[test_extractor_with_flushes, interpreter, writer], 1000, mocker.Mock()
)
await pipeline.run(PipelineProgressReporter())
assert writer.ingest_strategy.flush.call_count == 5
19 changes: 19 additions & 0 deletions tests/unit/pipeline/extractors/test_iterable_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,22 @@ async def test_iterable_extractor_range_factory():
subject = IterableExtractor.range(1, 6, 2)
results = [item async for item in subject.extract_records()]
assert_that(results, equal_to(expected_results))


@pytest.mark.asyncio
async def test_iterable_extractor_resume():
expected_results = [{"index": 3}, {"index": 5}]
subject = IterableExtractor.range(1, 6, 2)
await subject.resume_from_checkpoint(1)
results = [item async for item in subject.extract_records()]
assert_that(results, equal_to(expected_results))


@pytest.mark.asyncio
async def test_iterable_extractor_checkpoint():
subject = IterableExtractor([1, 2, 3])
generator = subject.extract_records()
await anext(generator)
await anext(generator)
checkpoint = await subject.make_checkpoint()
assert_that(checkpoint, equal_to(1))
3 changes: 2 additions & 1 deletion tests/unit/pipeline/test_object_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
InvalidSignatureError,
MalformedSignedObjectError,
NullObjectStore,
ObjectStore,
SignedObject,
StaticNamespace,
)
Expand Down Expand Up @@ -106,7 +107,7 @@ def test_get_pickled_missing_object(directory_object_store):


def test_directory_object_store_default_directory():
store = DirectoryObjectStore.in_current_directory()
store = ObjectStore.in_current_directory()
assert_that(store.root, equal_to(Path.cwd() / ".nodestream" / "objects"))


Expand Down
2 changes: 1 addition & 1 deletion tests/unit/pipeline/test_pipeline_progress_reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

@pytest.mark.asyncio
async def test_pipeline_progress_reporter_calls_with_reporting_frequency(mocker):
pipeline = Pipeline([IterableExtractor(range(100))], 10)
pipeline = Pipeline([IterableExtractor(range(100))], 10, mocker.Mock())
reporter = PipelineProgressReporter(reporting_frequency=10, callback=mocker.Mock())
await pipeline.run(reporter)
assert_that(reporter.callback.call_count, equal_to(10))
Expand Down
Loading