diff --git a/nodestream/pipeline/extractors/files.py b/nodestream/pipeline/extractors/files.py index c143e1315..59191dcbd 100644 --- a/nodestream/pipeline/extractors/files.py +++ b/nodestream/pipeline/extractors/files.py @@ -8,6 +8,7 @@ from csv import DictReader from glob import glob from io import BufferedReader, IOBase, TextIOWrapper +from logging import getLogger from pathlib import Path from typing import ( Any, @@ -176,6 +177,16 @@ def get_files(self) -> AsyncIterator[ReadableFile]: """ raise NotImplementedError + def describe(self) -> str: + """Return a human-readable description of the file source. + + By default this falls back to str(self); subclasses should + override it with a short, user-facing summary of the file + source (for example a path, a URL, or a count of files). + """ + return str(self) + @SUPPORTED_FILE_FORMAT_REGISTRY.connect_baseclass class FileCodec(Pluggable, ABC): @@ -451,6 +462,12 @@ async def get_files(self) -> AsyncIterator[ReadableFile]: for path in self.paths: yield LocalFile(path) + def describe(self) -> str: + if len(self.paths) == 1: + return f"{self.paths[0]}" + else: + return f"{len(self.paths)} local files" + class RemoteFileSource(FileSource, alias="http"): """A class that represents a source of remote files to be read. @@ -472,6 +489,12 @@ async def get_files(self) -> AsyncIterator[ReadableFile]: for url in self.urls: yield RemoteFile(url, client, self.memory_spooling_max_size) + def describe(self) -> str: + if len(self.urls) == 1: + return f"{self.urls[0]}" + else: + return f"{len(self.urls)} remote files" + class UnifiedFileExtractor(Extractor): """A class that extracts records from files. 
@@ -491,6 +514,7 @@ def from_file_data(cls, sources: List[Dict[str, Any]]) -> "UnifiedFileExtractor" def __init__(self, file_sources: Iterable[FileSource]) -> None: self.file_sources = file_sources + self.logger = getLogger(__name__) async def read_file(self, file: ReadableFile) -> Iterable[JsonLikeDocument]: intermediaries: List[AsyncContextManager[ReadableFile]] = [] @@ -536,10 +560,17 @@ async def read_file(self, file: ReadableFile) -> Iterable[JsonLikeDocument]: async def extract_records(self) -> AsyncGenerator[Any, Any]: for file_source in self.file_sources: + total_files_from_source = 0 async for file in file_source.get_files(): + total_files_from_source += 1 async for record in self.read_file(file): yield record + if total_files_from_source == 0: + self.logger.warning( + f"No files found for source: {file_source.describe()}" + ) + # DEPRECATED CODE BELOW ## # diff --git a/tests/unit/pipeline/extractors/test_files.py b/tests/unit/pipeline/extractors/test_files.py index 72d14b32a..556d7117d 100644 --- a/tests/unit/pipeline/extractors/test_files.py +++ b/tests/unit/pipeline/extractors/test_files.py @@ -8,12 +8,14 @@ import pandas as pd import pytest import yaml -from hamcrest import assert_that, equal_to, has_length +from hamcrest import assert_that, contains_string, equal_to, has_length from nodestream.pipeline.extractors.files import ( FileExtractor, + FileSource, LocalFileSource, RemoteFileExtractor, + RemoteFileSource, UnifiedFileExtractor, ) @@ -302,6 +304,64 @@ async def test_remote_file_extractor_extract_records(mocker, httpx_mock): assert_that(results, equal_to([SIMPLE_RECORD, SIMPLE_RECORD])) +@pytest.mark.asyncio +async def test_no_files_found_from_local_source(mocker): + subject = UnifiedFileExtractor([LocalFileSource([])]) + subject.logger = mocker.Mock() + results = [r async for r in subject.extract_records()] + assert_that(results, equal_to([])) + subject.logger.warning.assert_called_once_with( + "No files found for source: 0 local files" + ) + + 
+@pytest.mark.asyncio +async def test_no_files_found_from_remote_source(mocker): + subject = UnifiedFileExtractor([RemoteFileSource([], 10)]) + subject.logger = mocker.Mock() + results = [r async for r in subject.extract_records()] + assert_that(results, equal_to([])) + subject.logger.warning.assert_called_once_with( + "No files found for source: 0 remote files" + ) + + +def test_remote_file_source_single_file_description(): + url = "https://example.com/file.json" + subject = RemoteFileSource([url], 10) + assert_that(subject.describe(), equal_to(url)) + + +def test_remote_file_source_multiple_file_description(): + urls = ["https://example.com/file.json", "https://example.com/file2.json"] + subject = RemoteFileSource(urls, 10) + assert_that(subject.describe(), equal_to("2 remote files")) + + +def test_local_file_source_single_file_description(fixture_directory): + path = Path(f"{fixture_directory}/file.json") + subject = LocalFileSource([path]) + assert_that(subject.describe(), equal_to(str(path))) + + +def test_local_file_source_multiple_file_description(fixture_directory): + paths = [ + Path(f"{fixture_directory}/file.json"), + Path(f"{fixture_directory}/file2.json"), + ] + subject = LocalFileSource(paths) + assert_that(subject.describe(), equal_to("2 local files")) + + +def test_file_source_default_description(): + class SomeFileSource(FileSource): + def get_files(self): + pass + + subject = SomeFileSource() + assert_that(subject.describe(), contains_string("SomeFileSource")) + + @pytest.mark.asyncio async def test_unified_extractor_multiple_source_types(json_file, csv_file, httpx_mock): urls = ["https://example.com/file.json", "https://example.com/file2.json"]