Merge pull request #343 from nodestream-proj/feature/improve-error-messages-file-extractor

Add log message for when a source yields no files
zprobst authored Aug 7, 2024
2 parents 008868f + 9b5e95a commit 46483ba
Showing 2 changed files with 92 additions and 1 deletion.
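The user-visible effect of the change is easiest to see from the caller's side. The sketch below is not part of the commit; it is a minimal illustration that assumes the constructors shown in the diff (UnifiedFileExtractor, LocalFileSource), with an empty path list standing in for a source that matches nothing:

import asyncio
import logging

from nodestream.pipeline.extractors.files import LocalFileSource, UnifiedFileExtractor

async def main():
    logging.basicConfig(level=logging.WARNING)
    # An empty path list stands in for a glob or prefix that matched no files.
    extractor = UnifiedFileExtractor([LocalFileSource([])])
    async for record in extractor.extract_records():
        print(record)
    # Nothing is yielded; instead the extractor now logs a warning along the lines of:
    #   No files found for source: 0 local files

asyncio.run(main())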
31 changes: 31 additions & 0 deletions nodestream/pipeline/extractors/files.py
@@ -8,6 +8,7 @@
from csv import DictReader
from glob import glob
from io import BufferedReader, IOBase, TextIOWrapper
from logging import getLogger
from pathlib import Path
from typing import (
Any,
@@ -176,6 +177,16 @@ def get_files(self) -> AsyncIterator[ReadableFile]:
"""
raise NotImplementedError

def describe(self) -> str:
"""Return a human-readable description of the file source.
The description should be concise, informative, and understandable to the user.
"""
return str(self)


@SUPPORTED_FILE_FORMAT_REGISTRY.connect_baseclass
class FileCodec(Pluggable, ABC):
@@ -451,6 +462,12 @@ async def get_files(self) -> AsyncIterator[ReadableFile]:
for path in self.paths:
yield LocalFile(path)

def describe(self) -> str:
if len(self.paths) == 1:
return f"{self.paths[0]}"
else:
return f"{len(self.paths)} local files"


class RemoteFileSource(FileSource, alias="http"):
"""A class that represents a source of remote files to be read.
@@ -472,6 +489,12 @@ async def get_files(self) -> AsyncIterator[ReadableFile]:
for url in self.urls:
yield RemoteFile(url, client, self.memory_spooling_max_size)

def describe(self) -> str:
if len(self.urls) == 1:
return f"{self.urls[0]}"
else:
return f"{len(self.urls)} remote files"


class UnifiedFileExtractor(Extractor):
"""A class that extracts records from files.
@@ -491,6 +514,7 @@ def from_file_data(cls, sources: List[Dict[str, Any]]) -> "UnifiedFileExtractor"

def __init__(self, file_sources: Iterable[FileSource]) -> None:
self.file_sources = file_sources
self.logger = getLogger(__name__)

async def read_file(self, file: ReadableFile) -> Iterable[JsonLikeDocument]:
intermediaries: List[AsyncContextManager[ReadableFile]] = []
@@ -536,10 +560,17 @@ async def read_file(self, file: ReadableFile) -> Iterable[JsonLikeDocument]:

async def extract_records(self) -> AsyncGenerator[Any, Any]:
for file_source in self.file_sources:
total_files_from_source = 0
async for file in file_source.get_files():
total_files_from_source += 1
async for record in self.read_file(file):
yield record

if total_files_from_source == 0:
self.logger.warning(
f"No files found for source: {file_source.describe()}"
)


# DEPRECATED CODE BELOW ##
#
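The tests below verify the describe() behavior added above: a FileSource that does not override describe() falls back to str(self), while a subclass can override it to control how it appears in the new warning. A minimal sketch of such an override, not part of the commit and using a hypothetical source class, assuming only the FileSource base class from this diff:

from nodestream.pipeline.extractors.files import FileSource

class ExampleBucketSource(FileSource):  # hypothetical source, for illustration only
    def __init__(self, bucket: str) -> None:
        self.bucket = bucket

    async def get_files(self):
        # Pretend the bucket is empty: an async generator that yields nothing.
        return
        yield

    def describe(self) -> str:
        # Used verbatim in the "No files found for source: ..." warning.
        return f"bucket '{self.bucket}'"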
62 changes: 61 additions & 1 deletion tests/unit/pipeline/extractors/test_files.py
@@ -8,12 +8,14 @@
import pandas as pd
import pytest
import yaml
from hamcrest import assert_that, equal_to, has_length
from hamcrest import assert_that, contains_string, equal_to, has_length

from nodestream.pipeline.extractors.files import (
FileExtractor,
FileSource,
LocalFileSource,
RemoteFileExtractor,
RemoteFileSource,
UnifiedFileExtractor,
)

@@ -302,6 +304,64 @@ async def test_remote_file_extractor_extract_records(mocker, httpx_mock):
assert_that(results, equal_to([SIMPLE_RECORD, SIMPLE_RECORD]))


@pytest.mark.asyncio
async def test_no_files_found_from_local_source(mocker):
subject = UnifiedFileExtractor([LocalFileSource([])])
subject.logger = mocker.Mock()
results = [r async for r in subject.extract_records()]
assert_that(results, equal_to([]))
subject.logger.warning.assert_called_once_with(
"No files found for source: 0 local files"
)


@pytest.mark.asyncio
async def test_no_files_found_from_remote_source(mocker):
subject = UnifiedFileExtractor([RemoteFileSource([], 10)])
subject.logger = mocker.Mock()
results = [r async for r in subject.extract_records()]
assert_that(results, equal_to([]))
subject.logger.warning.assert_called_once_with(
"No files found for source: 0 remote files"
)


def test_remote_file_source_single_file_description():
url = "https://example.com/file.json"
subject = RemoteFileSource([url], 10)
assert_that(subject.describe(), equal_to(url))


def test_remote_file_source_multiple_file_description():
urls = ["https://example.com/file.json", "https://example.com/file2.json"]
subject = RemoteFileSource(urls, 10)
assert_that(subject.describe(), equal_to("2 remote files"))


def test_local_file_source_single_file_description(fixture_directory):
path = Path(f"{fixture_directory}/file.json")
subject = LocalFileSource([path])
assert_that(subject.describe(), equal_to(str(path)))


def test_local_file_source_multiple_file_description(fixture_directory):
paths = [
Path(f"{fixture_directory}/file.json"),
Path(f"{fixture_directory}/file2.json"),
]
subject = LocalFileSource(paths)
assert_that(subject.describe(), equal_to("2 local files"))


def test_file_source_default_description():
class SomeFileSource(FileSource):
def get_files(self):
pass

subject = SomeFileSource()
assert_that(subject.describe(), contains_string("SomeFileSource"))


@pytest.mark.asyncio
async def test_unified_extractor_multiple_source_types(json_file, csv_file, httpx_mock):
urls = ["https://example.com/file.json", "https://example.com/file2.json"]
