Skip to content

Commit

Permalink
refactor: pin to sqlfluff 3.3.0 and split a dedicated logic for funct…
Browse files Browse the repository at this point in the history
…ion handling
  • Loading branch information
reata committed Jan 31, 2025
1 parent 8c52ec2 commit 52d3034
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 28 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def run(self) -> None:
install_requires=[
"sqlparse==0.5.0",
"networkx>=2.4",
"sqlfluff~=3.2.5",
"sqlfluff==3.3.0",
"sqlalchemy>=2.0.0",
],
entry_points={"console_scripts": ["sqllineage = sqllineage.cli:main"]},
Expand Down
36 changes: 19 additions & 17 deletions sqllineage/core/parser/sqlfluff/extractors/select.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,23 +97,25 @@ def _handle_swap_partition(
if self.dialect == "vertica" and segment.type == "select_clause":
if select_clause_element := segment.get_child("select_clause_element"):
if function := select_clause_element.get_child("function"):
if (
function.first_non_whitespace_segment_raw_upper
== "SWAP_PARTITIONS_BETWEEN_TABLES"
):
if function_contents := function.get_child("function_contents"):
bracketed = function_contents.segments[0]
expressions = bracketed.get_children("expression")
holder.add_read(
SqlFluffTable(
escape_identifier_name(expressions[0].raw)
)
)
holder.add_write(
SqlFluffTable(
escape_identifier_name(expressions[3].raw)
)
)
if function_name := function.get_child("function_name"):
if function_name.raw_upper == "SWAP_PARTITIONS_BETWEEN_TABLES":
if function_contents := function.get_child(
"function_contents"
):
if bracketed := function_contents.get_child(
"bracketed"
):
expressions = bracketed.get_children("expression")
holder.add_read(
SqlFluffTable(
escape_identifier_name(expressions[0].raw)
)
)
holder.add_write(
SqlFluffTable(
escape_identifier_name(expressions[3].raw)
)
)

def _handle_select_into(self, segment: BaseSegment, holder: SubQueryLineageHolder):
"""
Expand Down
23 changes: 13 additions & 10 deletions sqllineage/core/parser/sqlfluff/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
from sqllineage.utils.helpers import escape_identifier_name

NON_IDENTIFIER_OR_COLUMN_SEGMENT_TYPE = [
"function",
"over_clause",
"partitionby_clause",
"orderby_clause",
"expression",
Expand All @@ -28,10 +26,13 @@
"cast_expression",
]

SOURCE_COLUMN_SEGMENT_TYPE = NON_IDENTIFIER_OR_COLUMN_SEGMENT_TYPE + [
"identifier",
"column_reference",
]
FUNCTION_SEGMENT_TYPE = ["function"]

COLUMN_SEGMENT_TYPE = ["identifier", "column_reference"]

SOURCE_COLUMN_SEGMENT_TYPE = (
NON_IDENTIFIER_OR_COLUMN_SEGMENT_TYPE + FUNCTION_SEGMENT_TYPE + COLUMN_SEGMENT_TYPE
)


class SqlFluffTable(Table):
Expand Down Expand Up @@ -151,14 +152,18 @@ def _extract_source_columns(segment: BaseSegment) -> List[ColumnQualifierTuple]:
:return: list of extracted source columns
"""
col_list = []
if segment.type in ("identifier", "column_reference") or is_wildcard(segment):
if segment.type in COLUMN_SEGMENT_TYPE or is_wildcard(segment):
if cqt := extract_column_qualifier(segment):
col_list = [cqt]
elif segment.type in FUNCTION_SEGMENT_TYPE:
for bracketed in segment.recursive_crawl("bracketed"):
# the bracketed could be in function_contents or over_clause in case of window function
col_list += SqlFluffColumn._get_column_from_parenthesis(bracketed)
elif segment.type in NON_IDENTIFIER_OR_COLUMN_SEGMENT_TYPE:
sub_segments = list_child_segments(segment)
col_list = []
for sub_segment in sub_segments:
if sub_segment.type in ("bracketed", "function_contents"):
if sub_segment.type == "bracketed":
if is_subquery(sub_segment):
col_list += SqlFluffColumn._get_column_from_subquery(
sub_segment
Expand Down Expand Up @@ -206,8 +211,6 @@ def _get_column_from_parenthesis(
# windows function has an extra layer, get rid of it so that it can be handled as regular functions
if window_specification := sub_segment.get_child("window_specification"):
sub_segment = window_specification
elif sub_segment.type == "function_contents":
sub_segment = sub_segment.segments[0]
col, _ = SqlFluffColumn._get_column_and_alias(sub_segment, False)
return col if col else []

Expand Down

0 comments on commit 52d3034

Please sign in to comment.