From b5b5e10df55f0d153053f5e6fcb419dac89cd42d Mon Sep 17 00:00:00 2001 From: Tobias Ruby Date: Fri, 31 Jan 2025 20:57:27 +1100 Subject: [PATCH] feat: upgrade sqlfluff to 3.3.0 (#660) * feat: upgrade sqlfluff to 3.2.x * clean up * refactor: pin to sqlfluff 3.3.0 and split a dedicated logic for function handling * test: remove unneeded logic after upgrading sqlfluff --------- Co-authored-by: reata --- setup.py | 2 +- sqllineage/core/holders.py | 3 -- .../core/parser/sqlfluff/extractors/select.py | 35 ++++++++++--------- sqllineage/core/parser/sqlfluff/models.py | 19 ++++++---- 4 files changed, 32 insertions(+), 27 deletions(-) diff --git a/setup.py b/setup.py index b6603d82..75fb559e 100644 --- a/setup.py +++ b/setup.py @@ -67,7 +67,7 @@ def run(self) -> None: install_requires=[ "sqlparse==0.5.0", "networkx>=2.4", - "sqlfluff==3.0.5", + "sqlfluff==3.3.0", "sqlalchemy>=2.0.0", ], entry_points={"console_scripts": ["sqllineage = sqllineage.cli:main"]}, diff --git a/sqllineage/core/holders.py b/sqllineage/core/holders.py index 8a7086d9..d8938161 100644 --- a/sqllineage/core/holders.py +++ b/sqllineage/core/holders.py @@ -137,9 +137,6 @@ def add_write_column(self, *tgt_cols: Column) -> None: tgt_tbl = list(self.write)[0] for idx, tgt_col in enumerate(tgt_cols): tgt_col.parent = tgt_tbl - if tgt_col in self.write_columns: - # for DDL with PARTITIONED BY (col) or CLUSTERED BY (col), column can be added multiple times - break self.graph.add_edge( tgt_tbl, tgt_col, type=EdgeType.HAS_COLUMN, **{EdgeTag.INDEX: idx} ) diff --git a/sqllineage/core/parser/sqlfluff/extractors/select.py b/sqllineage/core/parser/sqlfluff/extractors/select.py index ce4490b3..a61ad067 100644 --- a/sqllineage/core/parser/sqlfluff/extractors/select.py +++ b/sqllineage/core/parser/sqlfluff/extractors/select.py @@ -97,22 +97,25 @@ def _handle_swap_partition( if self.dialect == "vertica" and segment.type == "select_clause": if select_clause_element := segment.get_child("select_clause_element"): if function := select_clause_element.get_child("function"): - if ( - function.first_non_whitespace_segment_raw_upper - == "SWAP_PARTITIONS_BETWEEN_TABLES" - ): - if bracketed := function.get_child("bracketed"): - expressions = bracketed.get_children("expression") - holder.add_read( - SqlFluffTable( - escape_identifier_name(expressions[0].raw) - ) - ) - holder.add_write( - SqlFluffTable( - escape_identifier_name(expressions[3].raw) - ) - ) + if function_name := function.get_child("function_name"): + if function_name.raw_upper == "SWAP_PARTITIONS_BETWEEN_TABLES": + if function_contents := function.get_child( + "function_contents" + ): + if bracketed := function_contents.get_child( + "bracketed" + ): + expressions = bracketed.get_children("expression") + holder.add_read( + SqlFluffTable( + escape_identifier_name(expressions[0].raw) + ) + ) + holder.add_write( + SqlFluffTable( + escape_identifier_name(expressions[3].raw) + ) + ) def _handle_select_into(self, segment: BaseSegment, holder: SubQueryLineageHolder): """ diff --git a/sqllineage/core/parser/sqlfluff/models.py b/sqllineage/core/parser/sqlfluff/models.py index da498d0a..836e318f 100644 --- a/sqllineage/core/parser/sqlfluff/models.py +++ b/sqllineage/core/parser/sqlfluff/models.py @@ -16,8 +16,6 @@ from sqllineage.utils.helpers import escape_identifier_name NON_IDENTIFIER_OR_COLUMN_SEGMENT_TYPE = [ - "function", - "over_clause", "partitionby_clause", "orderby_clause", "expression", @@ -28,10 +26,13 @@ "cast_expression", ] -SOURCE_COLUMN_SEGMENT_TYPE = NON_IDENTIFIER_OR_COLUMN_SEGMENT_TYPE + [ - "identifier", - "column_reference", -] +FUNCTION_SEGMENT_TYPE = ["function"] + +COLUMN_SEGMENT_TYPE = ["identifier", "column_reference"] + +SOURCE_COLUMN_SEGMENT_TYPE = ( + NON_IDENTIFIER_OR_COLUMN_SEGMENT_TYPE + FUNCTION_SEGMENT_TYPE + COLUMN_SEGMENT_TYPE +) class SqlFluffTable(Table): @@ -151,9 +152,13 @@ def _extract_source_columns(segment: BaseSegment) -> List[ColumnQualifierTuple]: :return: list of extracted source columns """ col_list = [] - if segment.type in ("identifier", "column_reference") or is_wildcard(segment): + if segment.type in COLUMN_SEGMENT_TYPE or is_wildcard(segment): if cqt := extract_column_qualifier(segment): col_list = [cqt] + elif segment.type in FUNCTION_SEGMENT_TYPE: + for bracketed in segment.recursive_crawl("bracketed"): + # the bracketed could be in function_contents or over_clause in case of window function + col_list += SqlFluffColumn._get_column_from_parenthesis(bracketed) elif segment.type in NON_IDENTIFIER_OR_COLUMN_SEGMENT_TYPE: sub_segments = list_child_segments(segment) col_list = []