Skip to content

Commit

Permalink
feat: upgrade sqlfluff to 3.3.0 (#660)
Browse files Browse the repository at this point in the history
* feat: upgrade sqlfluff to 3.2.x

* clean up

* refactor: pin to sqlfluff 3.3.0 and split out dedicated logic for function handling

* test: remove unneeded logic after upgrading sqlfluff

---------

Co-authored-by: reata <reddevil.hjw@gmail.com>
  • Loading branch information
rubytobi and reata authored Jan 31, 2025
1 parent a3a8ada commit b5b5e10
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 27 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def run(self) -> None:
install_requires=[
"sqlparse==0.5.0",
"networkx>=2.4",
"sqlfluff==3.0.5",
"sqlfluff==3.3.0",
"sqlalchemy>=2.0.0",
],
entry_points={"console_scripts": ["sqllineage = sqllineage.cli:main"]},
Expand Down
3 changes: 0 additions & 3 deletions sqllineage/core/holders.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,6 @@ def add_write_column(self, *tgt_cols: Column) -> None:
tgt_tbl = list(self.write)[0]
for idx, tgt_col in enumerate(tgt_cols):
tgt_col.parent = tgt_tbl
if tgt_col in self.write_columns:
# for DDL with PARTITIONED BY (col) or CLUSTERED BY (col), column can be added multiple times
break
self.graph.add_edge(
tgt_tbl, tgt_col, type=EdgeType.HAS_COLUMN, **{EdgeTag.INDEX: idx}
)
Expand Down
35 changes: 19 additions & 16 deletions sqllineage/core/parser/sqlfluff/extractors/select.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,22 +97,25 @@ def _handle_swap_partition(
if self.dialect == "vertica" and segment.type == "select_clause":
if select_clause_element := segment.get_child("select_clause_element"):
if function := select_clause_element.get_child("function"):
if (
function.first_non_whitespace_segment_raw_upper
== "SWAP_PARTITIONS_BETWEEN_TABLES"
):
if bracketed := function.get_child("bracketed"):
expressions = bracketed.get_children("expression")
holder.add_read(
SqlFluffTable(
escape_identifier_name(expressions[0].raw)
)
)
holder.add_write(
SqlFluffTable(
escape_identifier_name(expressions[3].raw)
)
)
if function_name := function.get_child("function_name"):
if function_name.raw_upper == "SWAP_PARTITIONS_BETWEEN_TABLES":
if function_contents := function.get_child(
"function_contents"
):
if bracketed := function_contents.get_child(
"bracketed"
):
expressions = bracketed.get_children("expression")
holder.add_read(
SqlFluffTable(
escape_identifier_name(expressions[0].raw)
)
)
holder.add_write(
SqlFluffTable(
escape_identifier_name(expressions[3].raw)
)
)

def _handle_select_into(self, segment: BaseSegment, holder: SubQueryLineageHolder):
"""
Expand Down
19 changes: 12 additions & 7 deletions sqllineage/core/parser/sqlfluff/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
from sqllineage.utils.helpers import escape_identifier_name

NON_IDENTIFIER_OR_COLUMN_SEGMENT_TYPE = [
"function",
"over_clause",
"partitionby_clause",
"orderby_clause",
"expression",
Expand All @@ -28,10 +26,13 @@
"cast_expression",
]

SOURCE_COLUMN_SEGMENT_TYPE = NON_IDENTIFIER_OR_COLUMN_SEGMENT_TYPE + [
"identifier",
"column_reference",
]
FUNCTION_SEGMENT_TYPE = ["function"]

COLUMN_SEGMENT_TYPE = ["identifier", "column_reference"]

SOURCE_COLUMN_SEGMENT_TYPE = (
NON_IDENTIFIER_OR_COLUMN_SEGMENT_TYPE + FUNCTION_SEGMENT_TYPE + COLUMN_SEGMENT_TYPE
)


class SqlFluffTable(Table):
Expand Down Expand Up @@ -151,9 +152,13 @@ def _extract_source_columns(segment: BaseSegment) -> List[ColumnQualifierTuple]:
:return: list of extracted source columns
"""
col_list = []
if segment.type in ("identifier", "column_reference") or is_wildcard(segment):
if segment.type in COLUMN_SEGMENT_TYPE or is_wildcard(segment):
if cqt := extract_column_qualifier(segment):
col_list = [cqt]
elif segment.type in FUNCTION_SEGMENT_TYPE:
for bracketed in segment.recursive_crawl("bracketed"):
# the bracketed segment may be inside function_contents, or inside over_clause in the case of a window function
col_list += SqlFluffColumn._get_column_from_parenthesis(bracketed)
elif segment.type in NON_IDENTIFIER_OR_COLUMN_SEGMENT_TYPE:
sub_segments = list_child_segments(segment)
col_list = []
Expand Down

0 comments on commit b5b5e10

Please sign in to comment.