Merge pull request #194 from moj-analytical-services/fix-s3-read-schema
Fix reading parquet schema from S3
oliver-critchfield authored May 10, 2023
2 parents 363095a + 25b7eef commit 87203ae
Showing 6 changed files with 114 additions and 26 deletions.
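For orientation before the file-by-file diff: a minimal sketch of the failure mode this PR addresses, as I read it. This is not code from the PR — the bucket and key are hypothetical, and it assumes a pyarrow build with S3 support plus valid AWS credentials.

```python
# Sketch only: pq.read_schema treats a plain string as a local file path, so
# an "s3://..." URI fails unless the object is opened through a filesystem
# first. The bucket and key below are made up.
import pyarrow.parquet as pq
from pyarrow.fs import S3FileSystem

uri = "s3://some-bucket/tables/example.parquet"

s3fs = S3FileSystem(region="eu-west-1")
with s3fs.open_input_file(uri.removeprefix("s3://")) as f:  # expects "bucket/key"
    schema = pq.read_schema(f).remove_metadata()

print(schema)
```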
26 changes: 15 additions & 11 deletions CHANGELOG.md
@@ -4,6 +4,10 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).

+## 6.2.1 2023-05-10
+
+- Updated parquet validator to read schemas from S3
+
## 6.2.0 2023-01-10

- Added parquet validator for validating parquet schemas
@@ -90,20 +94,20 @@ allow there to be some misalignment between the meta and the data

```
[ALL CLOSED ISSUES]
- issue #140
- issue #140
- issue #139
- issue #133
- issue #133
- issue #132
- issue #131
- issue #130
- issue #129
- issue #128
- issue #131
- issue #130
- issue #129
- issue #128
- issue #125
- issue #122
- issue #121
- issue #120
- issue #110
- issue #100
- issue #122
- issue #121
- issue #120
- issue #110
- issue #100
- issue #98
- issue #87
- issue #70
2 changes: 1 addition & 1 deletion data_linter/__init__.py
@@ -1 +1 @@
__version__ = "6.2.0"
__version__ = "6.2.1"
33 changes: 25 additions & 8 deletions data_linter/validators/parquet_validator.py
@@ -1,36 +1,53 @@
import logging
+import os
+from typing import Union


-from data_linter.validators.base import (
-    BaseTableValidator,
-)
+import pyarrow.parquet as pq
+from dataengineeringutils3.s3 import s3_path_to_bucket_key
from mojap_metadata import Metadata
from mojap_metadata.converters.arrow_converter import ArrowConverter
-from typing import Union
-import pyarrow.parquet as pq
+from pyarrow import Schema
+from pyarrow.fs import S3FileSystem

+from data_linter.validators.base import BaseTableValidator

log = logging.getLogger("root")
default_date_format = "%Y-%m-%d"
default_datetime_format = "%Y-%m-%d %H:%M:%S"
+aws_default_region = os.getenv(
+    "AWS_DEFAULT_REGION", os.getenv("AWS_REGION", "eu-west-1")
+)


class ParquetValidator(BaseTableValidator):
"""
Validator for checking that a parquet file's schema matches a given Metadata.
For validating the data itself, use the Pandas validator.
"""

def __init__(
self,
filepath: str,
table_params: dict,
metadata: Union[dict, str, Metadata],
**kwargs
**kwargs,
):
super().__init__(filepath, table_params, metadata)

@staticmethod
def _read_schema(filepath: str) -> Schema:
if filepath.startswith("s3://"):
s3fs = S3FileSystem(region=aws_default_region)
b, k = s3_path_to_bucket_key(filepath)
pa_pth = os.path.join(b, k)
with s3fs.open_input_file(pa_pth) as file:
schema = pq.read_schema(file).remove_metadata()
else:
schema = pq.read_schema(filepath).remove_metadata()
return schema

def read_data_and_validate(self):
table_arrow_schema = pq.read_schema(self.filepath).remove_metadata()
table_arrow_schema = self._read_schema(self.filepath)
ac = ArrowConverter()
metadata_arrow_schema = ac.generate_from_meta(self.metadata).remove_metadata()
metas_match = table_arrow_schema.equals(metadata_arrow_schema)
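Taken together, a hypothetical end-to-end use of the validator after this change — the S3 URI is invented, while the metadata path is one of the repo's test fixtures; bucket access and credentials are assumed. Note the region is resolved from `AWS_DEFAULT_REGION`, then `AWS_REGION`, falling back to `eu-west-1`.

```python
from mojap_metadata import Metadata

from data_linter.validators.parquet_validator import ParquetValidator

# Metadata fixture path comes from the repo's tests; the bucket is made up.
meta = Metadata.from_json("tests/data/parquet_validator/meta_data/table1_pass.json")

pv = ParquetValidator(
    filepath="s3://my-bucket/table1.parquet",  # s3:// now works alongside local paths
    table_params={},
    metadata=meta,
)
pv.read_data_and_validate()
print(pv.response.result["valid"])  # True when the parquet schema matches the metadata
```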
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "data_linter"
version = "6.2.0"
version = "6.2.1"
description = "data linter"
authors = ["Thomas Hirsch <thomas.hirsch@digital.justice.gov.uk>",
"George Kelly <george.kelly@digital.justice.gov.uk>",
16 changes: 14 additions & 2 deletions tests/helpers.py
@@ -1,7 +1,10 @@
-import os
import io
-import boto3
+import os
+import pathlib
+from contextlib import contextmanager
+from tempfile import NamedTemporaryFile

+import boto3
from dataengineeringutils3.s3 import s3_path_to_bucket_key


@@ -68,6 +71,15 @@ def open_input_stream(s3_file_path_in: str) -> io.BytesIO:
        finally:
            obj_io_bytes.close()

+    @staticmethod
+    @contextmanager
+    def open_input_file(s3_file_path_in: str):
+        s3_client = boto3.client("s3")
+        bucket, key = s3_path_to_bucket_key(s3_file_path_in)
+        tmp_file = NamedTemporaryFile(suffix=pathlib.Path(key).suffix)
+        s3_client.download_file(bucket, key, tmp_file.name)
+        yield tmp_file.name


def mock_get_file(*args, **kwargs):
    return MockS3FilesystemReadInputStream()
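The new `open_input_file` mock mirrors just enough of `pyarrow.fs.S3FileSystem` for `_read_schema` to work under test: it downloads the object with boto3 (which moto can intercept, unlike pyarrow's C++-backed S3 client) into a `NamedTemporaryFile` and yields the file's name — `pq.read_schema` accepts a path string just as happily as a file handle. A sketch of how it slots in, mirroring what the new test below does; the bucket contents are assumed to already exist under moto:

```python
import data_linter.validators.parquet_validator as pqv
from tests.helpers import mock_get_file

def test_read_schema_via_mock(monkeypatch):
    # Calling the patched-in "S3FileSystem(...)" now returns the boto3-backed
    # mock, so the s3:// branch of _read_schema never touches real AWS.
    monkeypatch.setattr(pqv, "S3FileSystem", mock_get_file)
    schema = pqv.ParquetValidator._read_schema("s3://dummy-bucket/table1.parquet")
    assert schema.names  # hypothetical check on the returned pyarrow schema
```

One design detail worth noting: `NamedTemporaryFile` deletes its file when the object is closed or garbage-collected, so the temp file survives exactly as long as the `with` block that consumes the yielded name.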
61 changes: 58 additions & 3 deletions tests/test_parquet_validator.py
@@ -1,7 +1,62 @@
-import pytest
import os
-from data_linter.validators.parquet_validator import ParquetValidator

+import awswrangler as wr
+import boto3
+import pytest
+from mojap_metadata import Metadata
+from mojap_metadata.converters.arrow_converter import ArrowConverter
+from moto import mock_s3

+import data_linter.validators.parquet_validator as pqv
+from tests.helpers import mock_get_file

+bucket = "dummy-bucket"


+@mock_s3
+@pytest.mark.parametrize(
+    "filepath, mock_s3, meta_path",
+    [
+        (
+            "tests/data/parquet_validator/table1.parquet",
+            True,
+            "tests/data/parquet_validator/meta_data/table1_pass.json",
+        ),
+        (
+            "tests/data/parquet_validator/table1.parquet",
+            False,
+            "tests/data/parquet_validator/meta_data/table1_pass.json",
+        ),
+    ],
+)
+def test_parquet_schema_reader(filepath, mock_s3, meta_path, monkeypatch):
+    if mock_s3:

+        s3_client = boto3.client("s3")
+        _ = s3_client.create_bucket(
+            Bucket=bucket,
+            CreateBucketConfiguration={"LocationConstraint": "eu-west-1"},
+        )

+        full_path = f"s3://{bucket}/{filepath}"

+        wr.s3.upload(filepath, full_path)

+    else:
+        full_path = filepath

+    _ = monkeypatch.setattr(pqv, "S3FileSystem", mock_get_file)

+    schema = pqv.ParquetValidator._read_schema(full_path)
+    ac = ArrowConverter()

+    meta = ac.generate_to_meta(schema)
+    column_names = sorted(meta.column_names)

+    expected_meta = Metadata.from_json(meta_path)
+    expected_column_names = sorted(expected_meta.column_names)

+    assert column_names == expected_column_names


@pytest.mark.parametrize(
@@ -17,6 +72,6 @@ def test_parquet_validator(meta_file, expected_pass):
        os.path.join("tests/data/parquet_validator/meta_data/", meta_file)
    )
    file_path = "tests/data/parquet_validator/table1.parquet"
-    pv = ParquetValidator(filepath=file_path, table_params={}, metadata=meta)
+    pv = pqv.ParquetValidator(filepath=file_path, table_params={}, metadata=meta)
    pv.read_data_and_validate()
    assert pv.response.result["valid"] == expected_pass

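A closing practical note: moto's `@mock_s3` only intercepts botocore-level calls, so test suites like this are commonly paired with dummy AWS credentials to guarantee nothing falls through to real AWS. A typical fixture — my assumption, not something this PR adds — looks like:

```python
import pytest

@pytest.fixture(autouse=True)
def aws_credentials(monkeypatch):
    # Dummy values so botocore never picks up real credentials from the host.
    monkeypatch.setenv("AWS_ACCESS_KEY_ID", "testing")
    monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "testing")
    monkeypatch.setenv("AWS_DEFAULT_REGION", "eu-west-1")
```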