Skip to content

Commit

Permalink
Merge pull request #711 from communitiesuk/remove-mypy-exceptions-part-2
Browse files Browse the repository at this point in the history
Remove mypy exceptions part 2
  • Loading branch information
albertkol authored Sep 20, 2024
2 parents 0f6e493 + 79949c0 commit cd01229
Show file tree
Hide file tree
Showing 5 changed files with 139 additions and 74 deletions.
33 changes: 13 additions & 20 deletions data_store/db/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,21 +21,17 @@ def query_extend_with_outcome_filter(base_query: Query, outcome_categories: list
:return: updated query.
"""
outcome_category_condition = (
ents.OutcomeDim.outcome_category.in_(outcome_categories) if outcome_categories else True
)

extended_query = (
base_query.join(
ents.OutcomeData,
or_(
ents.Project.id == ents.OutcomeData.project_id,
ents.ProgrammeJunction.id == ents.OutcomeData.programme_junction_id,
),
)
.join(ents.OutcomeDim)
.filter(outcome_category_condition)
)
extended_query = base_query.join(
ents.OutcomeData,
or_(
ents.Project.id == ents.OutcomeData.project_id,
ents.ProgrammeJunction.id == ents.OutcomeData.programme_junction_id,
),
).join(ents.OutcomeDim)

if outcome_categories:
extended_query = extended_query.filter(ents.OutcomeDim.outcome_category.in_(outcome_categories))

return extended_query

Expand Down Expand Up @@ -63,13 +59,10 @@ def query_extend_with_region_filter(base_query: Query, itl1_regions: list[str])
:return: updated query.
"""

geospatial_region_condition = ents.GeospatialDim.itl1_region_code.in_(itl1_regions) if itl1_regions else True
extended_query = base_query.join(ents.project_geospatial_association).join(ents.GeospatialDim)

extended_query = (
base_query.join(ents.project_geospatial_association)
.join(ents.GeospatialDim)
.filter(geospatial_region_condition)
)
if itl1_regions:
extended_query = extended_query.filter(ents.GeospatialDim.itl1_region_code.in_(itl1_regions))

return extended_query

Expand Down
39 changes: 24 additions & 15 deletions data_store/messaging/messaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,21 +45,30 @@ def remove_errors_already_caught_by_null_failure(error_messages: list[Message])
msgs.BLANK_UNIT_OF_MEASUREMENT,
]

cells_covered_by_null_failures = [
(message.sheet, cell_index)
for message in error_messages
if message.description in null_descriptions
for cell_index in message.cell_indexes
]

filtered_errors = [
message
for message in error_messages
if not any(
((message.sheet, cell_index) in cells_covered_by_null_failures) for cell_index in message.cell_indexes
)
or message.description in null_descriptions
]
cells_covered_by_null_failures = []
for message in error_messages:
if message.description not in null_descriptions:
continue

if message.cell_indexes is None:
continue

for cell_index in message.cell_indexes:
cells_covered_by_null_failures.append((message.sheet, cell_index))

filtered_errors = []
for message in error_messages:
if message.cell_indexes is None:
continue

is_covered_by_null_failure = False
for cell_index in message.cell_indexes:
if (message.sheet, cell_index) in cells_covered_by_null_failures:
is_covered_by_null_failure = True
break

if not is_covered_by_null_failure or message.description in null_descriptions:
filtered_errors.append(message)

return filtered_errors

Expand Down
94 changes: 67 additions & 27 deletions data_store/messaging/tf_messaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,19 +294,32 @@ def _wrong_type_failure_message(self, validation_failure: WrongTypeFailure) -> M
sheet = self.INTERNAL_TABLE_TO_FORM_SHEET[validation_failure.table]
_, section = self.INTERNAL_COLUMN_TO_FORM_COLUMN_AND_SECTION[validation_failure.column]
actual_type = self.INTERNAL_TYPE_TO_MESSAGE_FORMAT[validation_failure.actual_type]
cell_index = self._construct_cell_index(
table=validation_failure.table,
column=validation_failure.column,
row_index=validation_failure.row_index,

# if column is a str make it a list
columns = (
[validation_failure.column] if isinstance(validation_failure.column, str) else validation_failure.column
)

cell_index: tuple[str, ...] | None = tuple(
self._construct_cell_index(
table=validation_failure.table,
column=column,
row_index=validation_failure.row_index,
)
for column in columns
)

if sheet == "Outcomes":
_, section = (
"Financial Year 2022/21 - Financial Year 2029/30",
("Outcome Indicators (excluding " "footfall) and Footfall Indicator"),
)
cell_index = self._get_cell_indexes_for_outcomes(validation_failure.failed_row)

cell_index = (
(self._get_cell_indexes_for_outcomes(validation_failure.failed_row),)
if validation_failure.failed_row is not None
else None
)
if validation_failure.expected_type == datetime:
message = self.msgs.WRONG_TYPE_DATE.format(wrong_type=actual_type)
elif sheet == "PSI":
Expand All @@ -318,7 +331,7 @@ def _wrong_type_failure_message(self, validation_failure: WrongTypeFailure) -> M
else:
message = self.msgs.WRONG_TYPE_UNKNOWN

return Message(sheet, section, (cell_index,), message, validation_failure.__class__.__name__)
return Message(sheet, section, cell_index, message, validation_failure.__class__.__name__)

def _invalid_enum_value_failure_message(self, validation_failure: InvalidEnumValueFailure) -> Message:
sheet = self.INTERNAL_TABLE_TO_FORM_SHEET[validation_failure.table]
Expand All @@ -331,8 +344,7 @@ def _invalid_enum_value_failure_message(self, validation_failure: InvalidEnumVal
# +5 as GeographyIndicator is 5 rows below Footfall Indicator
if column == "Geography Indicator":
actual_index = validation_failure.row_index + 5
cell_index = f"C{actual_index}"
return Message(sheet, section, (cell_index,), message, validation_failure.__class__.__name__)
return Message(sheet, section, (f"C{actual_index}",), message, validation_failure.__class__.__name__)

# additional logic for risk location
if sheet == "Risk Register":
Expand All @@ -343,13 +355,21 @@ def _invalid_enum_value_failure_message(self, validation_failure: InvalidEnumVal
project_number = get_project_number_by_position(validation_failure.row_index, validation_failure.table)
section = f"Project Funding Profiles - Project {project_number}"

cell_index = self._construct_cell_index(
table=validation_failure.table,
column=validation_failure.column,
row_index=validation_failure.row_index,
# if column is a str make it a list
columns = (
[validation_failure.column] if isinstance(validation_failure.column, str) else validation_failure.column
)

cell_index = tuple(
self._construct_cell_index(
table=validation_failure.table,
column=column,
row_index=validation_failure.row_index,
)
for column in columns
)

return Message(sheet, section, (cell_index,), message, validation_failure.__class__.__name__)
return Message(sheet, section, cell_index, message, validation_failure.__class__.__name__)

def _non_nullable_constraint_failure_message(self, validation_failure: NonNullableConstraintFailure) -> Message:
"""Generate error message components for NonNullableConstraintFailure.
Expand All @@ -363,10 +383,18 @@ def _non_nullable_constraint_failure_message(self, validation_failure: NonNullab
sheet = self.INTERNAL_TABLE_TO_FORM_SHEET[validation_failure.table]
column, section = self.INTERNAL_COLUMN_TO_FORM_COLUMN_AND_SECTION[validation_failure.column]

cell_index = self._construct_cell_index(
table=validation_failure.table,
column=validation_failure.column,
row_index=validation_failure.row_index,
# if column is a str make it a list
columns = (
[validation_failure.column] if isinstance(validation_failure.column, str) else validation_failure.column
)

cell_index: tuple[str, ...] | None = tuple(
self._construct_cell_index(
table=validation_failure.table,
column=column,
row_index=validation_failure.row_index,
)
for column in columns
)

message = self.msgs.BLANK
Expand All @@ -381,13 +409,17 @@ def _non_nullable_constraint_failure_message(self, validation_failure: NonNullab
if column == "Financial Year 2022/21 - Financial Year 2025/26":
section = "Outcome Indicators (excluding footfall) / Footfall Indicator"
message = self.msgs.BLANK_ZERO
cell_index = self._get_cell_indexes_for_outcomes(validation_failure.failed_row)
cell_index = (
(self._get_cell_indexes_for_outcomes(validation_failure.failed_row),)
if validation_failure.failed_row is not None
else None
)
elif sheet == "Funding Profiles":
message = self.msgs.BLANK_ZERO
elif section == "Programme-Wide Progress Summary":
message = self.msgs.BLANK

return Message(sheet, section, (cell_index,), message, validation_failure.__class__.__name__)
return Message(sheet, section, cell_index, message, validation_failure.__class__.__name__)

def _unauthorised_submission_failure(self, validation_failure: UnauthorisedSubmissionFailure) -> Message:
places_or_funds = join_as_string(validation_failure.expected_values)
Expand All @@ -397,18 +429,23 @@ def _unauthorised_submission_failure(self, validation_failure: UnauthorisedSubmi
return Message(None, None, None, message, validation_failure.__class__.__name__)

def _generic_failure(self, validation_failure: GenericFailure) -> Message:
if not validation_failure.cell_index:
if validation_failure.cell_index is not None:
cell_indexes = (validation_failure.cell_index,)
elif validation_failure.column is not None:
validation_failure.cell_index = self._construct_cell_index(
validation_failure.table,
validation_failure.column,
validation_failure.row_index,
table=validation_failure.table,
column=validation_failure.column,
row_index=validation_failure.row_index or 0,
)
sheet = self.INTERNAL_TABLE_TO_FORM_SHEET[validation_failure.table]

cell_indexes = (validation_failure.cell_index,)
else:
cell_indexes = None

return Message(
sheet,
self.INTERNAL_TABLE_TO_FORM_SHEET[validation_failure.table],
validation_failure.section,
(validation_failure.cell_index,),
cell_indexes,
validation_failure.message,
validation_failure.__class__.__name__,
)
Expand All @@ -424,7 +461,7 @@ def _construct_cell_index(self, table: str, column: str, row_index: int) -> str:
column_letter = self.TABLE_AND_COLUMN_TO_ORIGINAL_COLUMN_LETTER[table][column]
return column_letter.format(i=row_index or "")

def _get_section_for_outcomes_by_row_index(self, index):
def _get_section_for_outcomes_by_row_index(self, index: int) -> str:
return "Outcomes Indicators (excluding footfall)" if index < 60 else "Footfall Indicator"

def _get_cell_indexes_for_outcomes(self, failed_row: pd.Series) -> str:
Expand All @@ -447,6 +484,9 @@ def _get_cell_indexes_for_outcomes(self, failed_row: pd.Series) -> str:
financial_year = self._get_uk_financial_year_start(start_date)
index = failed_row.name

if not isinstance(index, int):
raise TypeError(f"Cell index not int for failed row {failed_row}")

# footfall outcomes starts from row 60
if self._get_section_for_outcomes_by_row_index(index) == "Footfall Indicator":
# row for 'Amount' column is end number of start year of financial year * 5 + 'Footfall Indicator' index
Expand Down
11 changes: 0 additions & 11 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -82,17 +82,6 @@ module = [
]
disable_error_code = ["import-untyped", "name-defined", "attr-defined", "import-not-found", "index"]

[[tool.mypy.overrides]]
module = "data_store/db/queries"
disable_error_code = ["arg-type"]

[[tool.mypy.overrides]]
module = "data_store.messaging.messaging"
disable_error_code = ["union-attr"]

[[tool.mypy.overrides]]
module = "data_store.messaging.tf_messaging"
disable_error_code = ["call-overload", "arg-type"]

[[tool.mypy.overrides]]
module = "data_store/serialisation/data_serialiser"
Expand Down
36 changes: 35 additions & 1 deletion tests/data_store_tests/messaging_tests/test_tf_messaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,8 @@ def test_authorised_submission():

def test_generic_failure():
test_messeger = TFMessenger()
test_messeger._generic_failure(

message = test_messeger._generic_failure(
GenericFailure(
table="Project Details",
section="A Section",
Expand All @@ -442,6 +443,30 @@ def test_generic_failure():
)
)

assert message == Message(
sheet="Project Admin",
section="A Section",
cell_indexes=("C1",),
description="A message",
error_type="GenericFailure",
)


def test_generic_failure_when_column_is_none():
test_messeger = TFMessenger()

with pytest.raises(ValueError):
test_messeger._generic_failure(
GenericFailure(
table="Project Details",
section="A Section",
cell_index=None,
message="A message",
column=None,
row_index=None,
)
)


def test_failures_to_messages():
failure1 = InvalidEnumValueFailure(
Expand Down Expand Up @@ -742,6 +767,15 @@ def test_get_cell_indexes_for_outcomes():
assert cell4 == "G23"


def test_get_cell_indexes_for_outcomes_throws_exception():
test_messenger = TFMessenger()

failed_row = pd.Series({"Start_Date": pd.to_datetime("2024-05-01 12:00:00")}, name=None)

with pytest.raises(TypeError):
test_messenger._get_cell_indexes_for_outcomes(failed_row)


def test_get_uk_financial_year_start():
# Test case where start_date is in the same financial year
start_date_1 = pd.to_datetime("2023-05-01 12:00:00")
Expand Down

0 comments on commit cd01229

Please sign in to comment.