Skip to content

Commit

Permalink
refactor: simplify code logic
Browse files: browse the repository at this point in the history
  • Loading branch information
reata committed May 4, 2024
1 parent d336046 commit 5725d61
Showing 1 changed file with 12 additions and 29 deletions.
41 changes: 12 additions & 29 deletions sqllineage/core/parser/sqlfluff/extractors/create_insert.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from sqlfluff.core.parser import BaseSegment

from sqllineage.core.holders import SubQueryLineageHolder
from sqllineage.core.models import Column, Path, Table
from sqllineage.core.models import Path, Table
from sqllineage.core.parser.sqlfluff.extractors.base import BaseExtractor
from sqllineage.core.parser.sqlfluff.extractors.select import SelectExtractor
from sqllineage.core.parser.sqlfluff.models import SqlFluffColumn, SqlFluffTable
from sqllineage.core.parser.sqlfluff.utils import (
list_child_segments,
)
from sqllineage.utils.entities import AnalyzerContext, ColumnQualifierTuple
from sqllineage.utils.entities import AnalyzerContext
from sqllineage.utils.helpers import escape_identifier_name


@@ -32,8 +32,7 @@ def extract(
) -> SubQueryLineageHolder:
holder = self._init_holder(context)
src_flag = tgt_flag = False
statement_child = list_child_segments(statement)
for seg_idx, segment in enumerate(list_child_segments(statement)):
for segment in list_child_segments(statement):
if segment.type == "with_compound_statement":
holder |= self.delegate_to_cte(segment, holder)
elif segment.type == "bracketed" and any(
@@ -105,38 +104,22 @@ def extract(
if segment.type in ["table_reference", "object_reference"]:
write_obj = SqlFluffTable.of(segment)
holder.add_write(write_obj)
# get target table columns from metadata if available
if (
isinstance(write_obj, Table)
and self.metadata_provider
and statement.type == "insert_statement"
):
holder.add_write_column(
*self.metadata_provider.get_table_columns(table=write_obj)
)
elif segment.type == "literal":
if segment.raw.isnumeric():
# Special Handling for Spark Bucket Table DDL
pass
else:
holder.add_write(Path(escape_identifier_name(segment.raw)))
tgt_flag = False

if statement.type == "insert_statement" and holder.write:
sub_segments = list_child_segments(
segment=statement_child[seg_idx + 1]
)
if not all(
sub_segment.type in ["column_reference", "column_definition"]
for sub_segment in sub_segments
):
tgt_tab = list(holder.write)[0]
if isinstance(tgt_tab, Table) and tgt_tab.schema:
col_list = [
Column(
name=col.raw_name,
source_columns=[
ColumnQualifierTuple(
column=col.raw_name, qualifier=None
)
],
)
for col in self.metadata_provider.get_table_columns(
table=tgt_tab
)
]
holder.add_write_column(*col_list)
if src_flag:
if segment.type in ["table_reference", "object_reference"]:
holder.add_read(SqlFluffTable.of(segment))
Expand Down

0 comments on commit 5725d61

Please sign in to comment.