diff --git a/data_store/db/queries.py b/data_store/db/queries.py index 7604118b0..c50889d16 100644 --- a/data_store/db/queries.py +++ b/data_store/db/queries.py @@ -21,21 +21,17 @@ def query_extend_with_outcome_filter(base_query: Query, outcome_categories: list :return: updated query. """ - outcome_category_condition = ( - ents.OutcomeDim.outcome_category.in_(outcome_categories) if outcome_categories else True - ) - extended_query = ( - base_query.join( - ents.OutcomeData, - or_( - ents.Project.id == ents.OutcomeData.project_id, - ents.ProgrammeJunction.id == ents.OutcomeData.programme_junction_id, - ), - ) - .join(ents.OutcomeDim) - .filter(outcome_category_condition) - ) + extended_query = base_query.join( + ents.OutcomeData, + or_( + ents.Project.id == ents.OutcomeData.project_id, + ents.ProgrammeJunction.id == ents.OutcomeData.programme_junction_id, + ), + ).join(ents.OutcomeDim) + + if outcome_categories: + extended_query = extended_query.filter(ents.OutcomeDim.outcome_category.in_(outcome_categories)) return extended_query @@ -63,13 +59,10 @@ def query_extend_with_region_filter(base_query: Query, itl1_regions: list[str]) :return: updated query. 
""" - geospatial_region_condition = ents.GeospatialDim.itl1_region_code.in_(itl1_regions) if itl1_regions else True + extended_query = base_query.join(ents.project_geospatial_association).join(ents.GeospatialDim) - extended_query = ( - base_query.join(ents.project_geospatial_association) - .join(ents.GeospatialDim) - .filter(geospatial_region_condition) - ) + if itl1_regions: + extended_query = extended_query.filter(ents.GeospatialDim.itl1_region_code.in_(itl1_regions)) return extended_query diff --git a/data_store/messaging/messaging.py b/data_store/messaging/messaging.py index 77f087b90..97260514e 100644 --- a/data_store/messaging/messaging.py +++ b/data_store/messaging/messaging.py @@ -45,21 +45,32 @@ def remove_errors_already_caught_by_null_failure(error_messages: list[Message]) msgs.BLANK_UNIT_OF_MEASUREMENT, ] - cells_covered_by_null_failures = [ - (message.sheet, cell_index) - for message in error_messages - if message.description in null_descriptions - for cell_index in message.cell_indexes - ] - - filtered_errors = [ - message - for message in error_messages - if not any( - ((message.sheet, cell_index) in cells_covered_by_null_failures) for cell_index in message.cell_indexes - ) - or message.description in null_descriptions - ] + cells_covered_by_null_failures = [] + for message in error_messages: + if message.description not in null_descriptions: + continue + + if message.cell_indexes is None: + continue + + for cell_index in message.cell_indexes: + cells_covered_by_null_failures.append((message.sheet, cell_index)) + + filtered_errors = [] + for message in error_messages: + if message.cell_indexes is None: + # A message with no cell reference cannot be covered by a null failure, so keep it. + filtered_errors.append(message) + continue + + is_covered_by_null_failure = False + for cell_index in message.cell_indexes: + if (message.sheet, cell_index) in cells_covered_by_null_failures: + is_covered_by_null_failure = True + break + + if not is_covered_by_null_failure or message.description in null_descriptions: + filtered_errors.append(message) return filtered_errors diff --git 
a/data_store/messaging/tf_messaging.py b/data_store/messaging/tf_messaging.py index ff4b432ed..c634616f4 100644 --- a/data_store/messaging/tf_messaging.py +++ b/data_store/messaging/tf_messaging.py @@ -294,10 +294,19 @@ def _wrong_type_failure_message(self, validation_failure: WrongTypeFailure) -> M sheet = self.INTERNAL_TABLE_TO_FORM_SHEET[validation_failure.table] _, section = self.INTERNAL_COLUMN_TO_FORM_COLUMN_AND_SECTION[validation_failure.column] actual_type = self.INTERNAL_TYPE_TO_MESSAGE_FORMAT[validation_failure.actual_type] - cell_index = self._construct_cell_index( - table=validation_failure.table, - column=validation_failure.column, - row_index=validation_failure.row_index, + + # if column is a str make it a list + columns = ( + [validation_failure.column] if isinstance(validation_failure.column, str) else validation_failure.column + ) + + cell_index: tuple[str, ...] | None = tuple( + self._construct_cell_index( + table=validation_failure.table, + column=column, + row_index=validation_failure.row_index, + ) + for column in columns ) if sheet == "Outcomes": @@ -305,8 +314,12 @@ def _wrong_type_failure_message(self, validation_failure: WrongTypeFailure) -> M "Financial Year 2022/21 - Financial Year 2029/30", ("Outcome Indicators (excluding " "footfall) and Footfall Indicator"), ) - cell_index = self._get_cell_indexes_for_outcomes(validation_failure.failed_row) + cell_index = ( + (self._get_cell_indexes_for_outcomes(validation_failure.failed_row),) + if validation_failure.failed_row is not None + else None + ) if validation_failure.expected_type == datetime: message = self.msgs.WRONG_TYPE_DATE.format(wrong_type=actual_type) elif sheet == "PSI": @@ -318,7 +331,7 @@ def _wrong_type_failure_message(self, validation_failure: WrongTypeFailure) -> M else: message = self.msgs.WRONG_TYPE_UNKNOWN - return Message(sheet, section, (cell_index,), message, validation_failure.__class__.__name__) + return Message(sheet, section, cell_index, message, 
validation_failure.__class__.__name__) def _invalid_enum_value_failure_message(self, validation_failure: InvalidEnumValueFailure) -> Message: sheet = self.INTERNAL_TABLE_TO_FORM_SHEET[validation_failure.table] @@ -331,8 +344,7 @@ def _invalid_enum_value_failure_message(self, validation_failure: InvalidEnumVal # +5 as GeographyIndicator is 5 rows below Footfall Indicator if column == "Geography Indicator": actual_index = validation_failure.row_index + 5 - cell_index = f"C{actual_index}" - return Message(sheet, section, (cell_index,), message, validation_failure.__class__.__name__) + return Message(sheet, section, (f"C{actual_index}",), message, validation_failure.__class__.__name__) # additional logic for risk location if sheet == "Risk Register": @@ -343,13 +355,21 @@ def _invalid_enum_value_failure_message(self, validation_failure: InvalidEnumVal project_number = get_project_number_by_position(validation_failure.row_index, validation_failure.table) section = f"Project Funding Profiles - Project {project_number}" - cell_index = self._construct_cell_index( - table=validation_failure.table, - column=validation_failure.column, - row_index=validation_failure.row_index, + # if column is a str make it a list + columns = ( + [validation_failure.column] if isinstance(validation_failure.column, str) else validation_failure.column + ) + + cell_index = tuple( + self._construct_cell_index( + table=validation_failure.table, + column=column, + row_index=validation_failure.row_index, + ) + for column in columns ) - return Message(sheet, section, (cell_index,), message, validation_failure.__class__.__name__) + return Message(sheet, section, cell_index, message, validation_failure.__class__.__name__) def _non_nullable_constraint_failure_message(self, validation_failure: NonNullableConstraintFailure) -> Message: """Generate error message components for NonNullableConstraintFailure. 
@@ -363,10 +383,18 @@ def _non_nullable_constraint_failure_message(self, validation_failure: NonNullab sheet = self.INTERNAL_TABLE_TO_FORM_SHEET[validation_failure.table] column, section = self.INTERNAL_COLUMN_TO_FORM_COLUMN_AND_SECTION[validation_failure.column] - cell_index = self._construct_cell_index( - table=validation_failure.table, - column=validation_failure.column, - row_index=validation_failure.row_index, + # if column is a str make it a list + columns = ( + [validation_failure.column] if isinstance(validation_failure.column, str) else validation_failure.column + ) + + cell_index: tuple[str, ...] | None = tuple( + self._construct_cell_index( + table=validation_failure.table, + column=column, + row_index=validation_failure.row_index, + ) + for column in columns ) message = self.msgs.BLANK @@ -381,13 +409,17 @@ def _non_nullable_constraint_failure_message(self, validation_failure: NonNullab if column == "Financial Year 2022/21 - Financial Year 2025/26": section = "Outcome Indicators (excluding footfall) / Footfall Indicator" message = self.msgs.BLANK_ZERO - cell_index = self._get_cell_indexes_for_outcomes(validation_failure.failed_row) + cell_index = ( + (self._get_cell_indexes_for_outcomes(validation_failure.failed_row),) + if validation_failure.failed_row is not None + else None + ) elif sheet == "Funding Profiles": message = self.msgs.BLANK_ZERO elif section == "Programme-Wide Progress Summary": message = self.msgs.BLANK - return Message(sheet, section, (cell_index,), message, validation_failure.__class__.__name__) + return Message(sheet, section, cell_index, message, validation_failure.__class__.__name__) def _unauthorised_submission_failure(self, validation_failure: UnauthorisedSubmissionFailure) -> Message: places_or_funds = join_as_string(validation_failure.expected_values) @@ -397,18 +429,23 @@ def _unauthorised_submission_failure(self, validation_failure: UnauthorisedSubmi return Message(None, None, None, message, 
validation_failure.__class__.__name__) def _generic_failure(self, validation_failure: GenericFailure) -> Message: - if not validation_failure.cell_index: + if validation_failure.cell_index is not None: + cell_indexes = (validation_failure.cell_index,) + elif validation_failure.column is not None: validation_failure.cell_index = self._construct_cell_index( - validation_failure.table, - validation_failure.column, - validation_failure.row_index, + table=validation_failure.table, + column=validation_failure.column, + row_index=validation_failure.row_index or 0, ) - sheet = self.INTERNAL_TABLE_TO_FORM_SHEET[validation_failure.table] + + cell_indexes = (validation_failure.cell_index,) + else: + cell_indexes = None return Message( - sheet, + self.INTERNAL_TABLE_TO_FORM_SHEET[validation_failure.table], validation_failure.section, - (validation_failure.cell_index,), + cell_indexes, validation_failure.message, validation_failure.__class__.__name__, ) @@ -424,7 +461,7 @@ def _construct_cell_index(self, table: str, column: str, row_index: int) -> str: column_letter = self.TABLE_AND_COLUMN_TO_ORIGINAL_COLUMN_LETTER[table][column] return column_letter.format(i=row_index or "") - def _get_section_for_outcomes_by_row_index(self, index): + def _get_section_for_outcomes_by_row_index(self, index: int) -> str: return "Outcomes Indicators (excluding footfall)" if index < 60 else "Footfall Indicator" def _get_cell_indexes_for_outcomes(self, failed_row: pd.Series) -> str: @@ -447,6 +484,9 @@ def _get_cell_indexes_for_outcomes(self, failed_row: pd.Series) -> str: financial_year = self._get_uk_financial_year_start(start_date) index = failed_row.name + if not isinstance(index, int): + raise TypeError(f"Cell index not int for failed row {failed_row}") + # footfall outcomes starts from row 60 if self._get_section_for_outcomes_by_row_index(index) == "Footfall Indicator": # row for 'Amount' column is end number of start year of financial year * 5 + 'Footfall Indicator' index diff --git 
a/pyproject.toml b/pyproject.toml index a1ebd072b..512151060 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -82,17 +82,6 @@ module = [ ] disable_error_code = ["import-untyped", "name-defined", "attr-defined", "import-not-found", "index"] -[[tool.mypy.overrides]] -module = "data_store/db/queries" -disable_error_code = ["arg-type"] - -[[tool.mypy.overrides]] -module = "data_store.messaging.messaging" -disable_error_code = ["union-attr"] - -[[tool.mypy.overrides]] -module = "data_store.messaging.tf_messaging" -disable_error_code = ["call-overload", "arg-type"] [[tool.mypy.overrides]] module = "data_store/serialisation/data_serialiser" diff --git a/tests/data_store_tests/messaging_tests/test_tf_messaging.py b/tests/data_store_tests/messaging_tests/test_tf_messaging.py index 5af315032..547df9d43 100644 --- a/tests/data_store_tests/messaging_tests/test_tf_messaging.py +++ b/tests/data_store_tests/messaging_tests/test_tf_messaging.py @@ -433,7 +433,8 @@ def test_authorised_submission(): def test_generic_failure(): test_messeger = TFMessenger() - test_messeger._generic_failure( + + message = test_messeger._generic_failure( GenericFailure( table="Project Details", section="A Section", @@ -442,6 +443,30 @@ def test_generic_failure(): ) ) + assert message == Message( + sheet="Project Admin", + section="A Section", + cell_indexes=("C1",), + description="A message", + error_type="GenericFailure", + ) + + +def test_generic_failure_when_column_is_none(): + test_messeger = TFMessenger() + + with pytest.raises(ValueError): + test_messeger._generic_failure( + GenericFailure( + table="Project Details", + section="A Section", + cell_index=None, + message="A message", + column=None, + row_index=None, + ) + ) + def test_failures_to_messages(): failure1 = InvalidEnumValueFailure( @@ -742,6 +767,15 @@ def test_get_cell_indexes_for_outcomes(): assert cell4 == "G23" +def test_get_cell_indexes_for_outcomes_throws_exception(): + test_messenger = TFMessenger() + + failed_row = 
pd.Series({"Start_Date": pd.to_datetime("2024-05-01 12:00:00")}, name=None) + + with pytest.raises(TypeError): + test_messenger._get_cell_indexes_for_outcomes(failed_row) + + def test_get_uk_financial_year_start(): # Test case where start_date is in the same financial year start_date_1 = pd.to_datetime("2023-05-01 12:00:00")