From 7d5c66c323b6f40588862b20e9b01705a2aedc9f Mon Sep 17 00:00:00 2001 From: Michael Date: Wed, 8 Jan 2025 10:33:02 +0200 Subject: [PATCH] made function more robust per comments --- src/sempy_labs/_helper_functions.py | 14 +++++++++++--- src/sempy_labs/_model_bpa.py | 4 ++-- src/sempy_labs/_model_bpa_bulk.py | 4 ++-- src/sempy_labs/_vertipaq.py | 6 ++++-- src/sempy_labs/report/_report_bpa.py | 4 ++-- 5 files changed, 21 insertions(+), 11 deletions(-) diff --git a/src/sempy_labs/_helper_functions.py b/src/sempy_labs/_helper_functions.py index 5cf1b3aa..1066bfee 100644 --- a/src/sempy_labs/_helper_functions.py +++ b/src/sempy_labs/_helper_functions.py @@ -1188,16 +1188,24 @@ def generate_guid(): return str(uuid.uuid4()) -def _get_column_value(lakehouse: str, table_name: str, column_name: str = 'RunId', function: str = 'MAX') -> int: +def _get_column_aggregate( + lakehouse: str, + table_name: str, + column_name: str = "RunId", + function: str = "max", + default_value: int = 0, +) -> int: from pyspark.sql import SparkSession spark = SparkSession.builder.getOrCreate() + function = function.upper() query = f"SELECT {function}({column_name}) FROM {lakehouse}.{table_name}" + if "COUNT" in function and "DISTINCT" in function: + query = f"SELECT COUNT(DISTINCT({column_name})) FROM {lakehouse}.{table_name}" dfSpark = spark.sql(query) - max_run_id = dfSpark.collect()[0][0] or 0 - return max_run_id + return dfSpark.collect()[0][0] or default_value def _make_list_unique(my_list): diff --git a/src/sempy_labs/_model_bpa.py b/src/sempy_labs/_model_bpa.py index 6dc29666..4b7f6094 100644 --- a/src/sempy_labs/_model_bpa.py +++ b/src/sempy_labs/_model_bpa.py @@ -12,7 +12,7 @@ resolve_workspace_capacity, resolve_dataset_name_and_id, get_language_codes, - _get_column_value, + _get_column_aggregate, resolve_workspace_name_and_id, ) from sempy_labs.lakehouse import get_lakehouse_tables, lakehouse_attached @@ -383,7 +383,7 @@ def translate_using_spark(rule_file): if len(lakeT_filt) == 0: runId = 1 else: - max_run_id = _get_column_value( + max_run_id = _get_column_aggregate( lakehouse=lakehouse, table_name=delta_table_name ) runId = max_run_id + 1 diff --git a/src/sempy_labs/_model_bpa_bulk.py b/src/sempy_labs/_model_bpa_bulk.py index 74438727..5d2e8a4c 100644 --- a/src/sempy_labs/_model_bpa_bulk.py +++ b/src/sempy_labs/_model_bpa_bulk.py @@ -6,7 +6,7 @@ save_as_delta_table, resolve_workspace_capacity, retry, - _get_column_value, + _get_column_aggregate, ) from sempy_labs.lakehouse import ( get_lakehouse_tables, @@ -76,7 +76,7 @@ def run_model_bpa_bulk( if len(lakeT_filt) == 0: runId = 1 else: - max_run_id = _get_column_value(lakehouse=lakehouse, table_name=output_table) + max_run_id = _get_column_aggregate(lakehouse=lakehouse, table_name=output_table) runId = max_run_id + 1 if isinstance(workspace, str): diff --git a/src/sempy_labs/_vertipaq.py b/src/sempy_labs/_vertipaq.py index 1443b4a4..4802fe4d 100644 --- a/src/sempy_labs/_vertipaq.py +++ b/src/sempy_labs/_vertipaq.py @@ -12,7 +12,7 @@ resolve_lakehouse_name, save_as_delta_table, resolve_workspace_capacity, - _get_column_value, + _get_column_aggregate, resolve_workspace_name_and_id, resolve_dataset_name_and_id, ) @@ -519,7 +519,9 @@ def _style_columns_based_on_types(dataframe: pd.DataFrame, column_type_mapping): if len(lakeT_filt) == 0: runId = 1 else: - max_run_id = _get_column_value(lakehouse=lakehouse, table_name=lakeTName) + max_run_id = _get_column_aggregate( + lakehouse=lakehouse, table_name=lakeTName + ) runId = max_run_id + 1 dfMap = { diff --git a/src/sempy_labs/report/_report_bpa.py b/src/sempy_labs/report/_report_bpa.py index edab37f1..6219dd7e 100644 --- a/src/sempy_labs/report/_report_bpa.py +++ b/src/sempy_labs/report/_report_bpa.py @@ -10,7 +10,7 @@ resolve_report_id, resolve_lakehouse_name, resolve_workspace_capacity, - _get_column_value, + _get_column_aggregate, resolve_workspace_name_and_id, ) from sempy_labs.lakehouse import get_lakehouse_tables, lakehouse_attached @@ -217,7 +217,7 @@ def execute_rule(row): if len(lakeT_filt) == 0: runId = 1 else: - max_run_id = _get_column_value( + max_run_id = _get_column_aggregate( lakehouse=lakehouse, table_name=delta_table_name ) runId = max_run_id + 1