Skip to content

Commit

Permalink
refactor: use built-in collection types as generic types
Browse files Browse the repository at this point in the history
  • Loading branch information
reata committed Feb 15, 2025
1 parent 3562f12 commit e733914
Show file tree
Hide file tree
Showing 21 changed files with 117 additions and 120 deletions.
6 changes: 3 additions & 3 deletions sqllineage/config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
import threading
from typing import Any, Dict, Set
from typing import Any

from sqllineage.exceptions import ConfigException

Expand All @@ -23,8 +23,8 @@ class _SQLLineageConfigLoader:
}

def __init__(self) -> None:
self._thread_config: Dict[int, Dict[str, Any]] = {}
self._thread_in_context_manager: Set[int] = set()
self._thread_config: dict[int, dict[str, Any]] = {}
self._thread_in_context_manager: set[int] = set()

def __getattr__(self, item: str):
if item in self.config.keys():
Expand Down
3 changes: 1 addition & 2 deletions sqllineage/core/analyzer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from abc import abstractmethod
from typing import List

from sqllineage.core.holders import StatementLineageHolder
from sqllineage.core.metadata_provider import MetaDataProvider
Expand All @@ -11,7 +10,7 @@ class LineageAnalyzer:
"""

PARSER_NAME: str = ""
SUPPORTED_DIALECTS: List[str] = []
SUPPORTED_DIALECTS: list[str] = []

@abstractmethod
def analyze(
Expand Down
58 changes: 29 additions & 29 deletions sqllineage/core/holders.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import itertools
from typing import Dict, List, Optional, Set, Tuple, Union
from typing import Optional, Union

import networkx as nx
from networkx import DiGraph
Expand All @@ -8,13 +8,13 @@
from sqllineage.core.models import Column, Path, Schema, SubQuery, Table
from sqllineage.utils.constant import EdgeTag, EdgeType, NodeTag

DATASET_CLASSES = (Path, Table)
DATAset_CLASSES = (Path, Table)


class ColumnLineageMixin:
def get_column_lineage(
self, exclude_path_ending_in_subquery=True, exclude_subquery_columns=False
) -> Set[Tuple[Column, ...]]:
) -> set[tuple[Column, ...]]:
"""
:param exclude_path_ending_in_subquery: exclude_subquery rename to exclude_path_ending_in_subquery
exclude column from SubQuery in the ending path
Expand Down Expand Up @@ -58,7 +58,7 @@ class SubQueryLineageHolder(ColumnLineageMixin):
SubQueryLineageHolder will hold attributes like read, write, cte.
Each of them is a Set[:class:`sqllineage.core.models.Table`].
Each of them is a set[:class:`sqllineage.core.models.Table`].
This is the most atomic representation of lineage result.
"""
Expand All @@ -70,14 +70,14 @@ def __or__(self, other):
self.graph = nx.compose(self.graph, other.graph)
return self

def _property_getter(self, prop) -> Set[Union[SubQuery, Table]]:
def _property_getter(self, prop) -> set[Union[SubQuery, Table]]:
return {t for t, attr in self.graph.nodes(data=True) if attr.get(prop) is True}

def _property_setter(self, value, prop) -> None:
self.graph.add_node(value, **{prop: True})

@property
def read(self) -> Set[Union[SubQuery, Table]]:
def read(self) -> set[Union[SubQuery, Table]]:
return self._property_getter(NodeTag.READ)

def add_read(self, value) -> None:
Expand All @@ -87,7 +87,7 @@ def add_read(self, value) -> None:
self.graph.add_edge(value, value.alias, type=EdgeType.HAS_ALIAS)

@property
def write(self) -> Set[Union[SubQuery, Table]]:
def write(self) -> set[Union[SubQuery, Table]]:
# SubQueryLineageHolder.write can return a single SubQuery or Table, or both when __or__ together.
# This is different from StatementLineageHolder.write, where Table is the only possibility.
return self._property_getter(NodeTag.WRITE)
Expand All @@ -96,22 +96,22 @@ def add_write(self, value) -> None:
self._property_setter(value, NodeTag.WRITE)

@property
def cte(self) -> Set[SubQuery]:
def cte(self) -> set[SubQuery]:
return self._property_getter(NodeTag.CTE) # type: ignore

def add_cte(self, value) -> None:
self._property_setter(value, NodeTag.CTE)

@property
def write_columns(self) -> List[Column]:
def write_columns(self) -> list[Column]:
"""
return a list of columns that write table contains.
It's either manually added via `add_write_column` if specified in DML
or automatic added via `add_column_lineage` after parsing from SELECT
"""
tgt_cols = []
if tgt_tbl := self._get_target_table():
tgt_col_with_idx: List[Tuple[Column, int]] = sorted(
tgt_col_with_idx: list[tuple[Column, int]] = sorted(
[
(col, attr.get(EdgeTag.INDEX, 0))
for tbl, col, attr in self.graph.out_edges(tgt_tbl, data=True)
Expand Down Expand Up @@ -151,7 +151,7 @@ def add_column_lineage(self, src: Column, tgt: Column) -> None:
# starting NetworkX v2.6, None is not allowed as node, see https://github.com/networkx/networkx/pull/4892
self.graph.add_edge(src.parent, src, type=EdgeType.HAS_COLUMN)

def get_table_columns(self, table: Union[Table, SubQuery]) -> List[Column]:
def get_table_columns(self, table: Union[Table, SubQuery]) -> list[Column]:
return [
tgt
for (src, tgt, edge_type) in self.graph.out_edges(nbunch=table, data="type")
Expand Down Expand Up @@ -185,8 +185,8 @@ def expand_wildcard(self, metadata_provider: MetaDataProvider) -> None:
)

def get_alias_mapping_from_table_group(
self, table_group: List[Union[Path, Table, SubQuery]]
) -> Dict[str, Union[Path, Table, SubQuery]]:
self, table_group: list[Union[Path, Table, SubQuery]]
) -> dict[str, Union[Path, Table, SubQuery]]:
"""
A table can be referred to as alias, table name, or database_name.table_name, create the mapping here.
For SubQuery, it's only alias then.
Expand All @@ -210,7 +210,7 @@ def _get_target_table(self) -> Optional[Union[SubQuery, Table]]:
table = next(iter(write_only))
return table

def get_source_columns(self, node: Column) -> List[Column]:
def get_source_columns(self, node: Column) -> list[Column]:
return [
src
for (src, tgt, edge_type) in self.graph.in_edges(nbunch=node, data="type")
Expand All @@ -220,7 +220,7 @@ def get_source_columns(self, node: Column) -> List[Column]:
def _replace_wildcard(
self,
tgt_table: Union[Table, SubQuery],
src_table_columns: List[Column],
src_table_columns: list[Column],
tgt_wildcard: Column,
src_wildcard: Column,
) -> None:
Expand All @@ -246,9 +246,9 @@ class StatementLineageHolder(SubQueryLineageHolder, ColumnLineageMixin):
Based on SubQueryLineageHolder, StatementLineageHolder holds extra attributes like drop and rename
For drop, it is a Set[:class:`sqllineage.core.models.Table`].
For drop, it is a set[:class:`sqllineage.core.models.Table`].
For rename, it a Set[Tuple[:class:`sqllineage.core.models.Table`, :class:`sqllineage.core.models.Table`]],
For rename, it a set[tuple[:class:`sqllineage.core.models.Table`, :class:`sqllineage.core.models.Table`]],
with the first table being original table before renaming and the latter after renaming.
"""

Expand All @@ -262,22 +262,22 @@ def __repr__(self):
return str(self)

@property
def read(self) -> Set[Table]: # type: ignore
return {t for t in super().read if isinstance(t, DATASET_CLASSES)}
def read(self) -> set[Table]: # type: ignore
return {t for t in super().read if isinstance(t, DATAset_CLASSES)}

@property
def write(self) -> Set[Table]: # type: ignore
return {t for t in super().write if isinstance(t, DATASET_CLASSES)}
def write(self) -> set[Table]: # type: ignore
return {t for t in super().write if isinstance(t, DATAset_CLASSES)}

@property
def drop(self) -> Set[Table]:
def drop(self) -> set[Table]:
return self._property_getter(NodeTag.DROP) # type: ignore

def add_drop(self, value) -> None:
self._property_setter(value, NodeTag.DROP)

@property
def rename(self) -> Set[Tuple[Table, Table]]:
def rename(self) -> set[tuple[Table, Table]]:
return {
(src, tgt)
for src, tgt, attr in self.graph.edges(data=True)
Expand Down Expand Up @@ -311,7 +311,7 @@ def table_lineage_graph(self) -> DiGraph:
"""
The table level DiGraph held by SQLLineageHolder
"""
table_nodes = [n for n in self.graph.nodes if isinstance(n, DATASET_CLASSES)]
table_nodes = [n for n in self.graph.nodes if isinstance(n, DATAset_CLASSES)]
return self.graph.subgraph(table_nodes)

@property
Expand All @@ -323,7 +323,7 @@ def column_lineage_graph(self) -> DiGraph:
return self.graph.subgraph(column_nodes)

@property
def source_tables(self) -> Set[Table]:
def source_tables(self) -> set[Table]:
"""
a list of source :class:`sqllineage.core.models.Table`
"""
Expand All @@ -337,7 +337,7 @@ def source_tables(self) -> Set[Table]:
return source_tables

@property
def target_tables(self) -> Set[Table]:
def target_tables(self) -> set[Table]:
"""
a list of target :class:`sqllineage.core.models.Table`
"""
Expand All @@ -351,7 +351,7 @@ def target_tables(self) -> Set[Table]:
return target_tables

@property
def intermediate_tables(self) -> Set[Table]:
def intermediate_tables(self) -> set[Table]:
"""
a list of intermediate :class:`sqllineage.core.models.Table`
"""
Expand All @@ -363,11 +363,11 @@ def intermediate_tables(self) -> Set[Table]:
intermediate_tables -= self.__retrieve_tag_tables(NodeTag.SELFLOOP)
return intermediate_tables

def __retrieve_tag_tables(self, tag) -> Set[Union[Path, Table]]:
def __retrieve_tag_tables(self, tag) -> set[Union[Path, Table]]:
return {
table
for table, attr in self.graph.nodes(data=True)
if attr.get(tag) is True and isinstance(table, DATASET_CLASSES)
if attr.get(tag) is True and isinstance(table, DATAset_CLASSES)
}

@staticmethod
Expand Down
6 changes: 3 additions & 3 deletions sqllineage/core/metadata/dummy.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, List, Optional
from typing import Optional

from sqllineage.core.metadata_provider import MetaDataProvider

Expand All @@ -8,14 +8,14 @@ class DummyMetaDataProvider(MetaDataProvider):
A Dummy MetaDataProvider that accept metadata as a dict
"""

def __init__(self, metadata: Optional[Dict[str, List[str]]] = None):
def __init__(self, metadata: Optional[dict[str, list[str]]] = None):
"""
:param metadata: a dict with schema.table name as key and a list of unqualified column name as value
"""
super().__init__()
self.metadata = metadata if metadata is not None else {}

def _get_table_columns(self, schema: str, table: str, **kwargs) -> List[str]:
def _get_table_columns(self, schema: str, table: str, **kwargs) -> list[str]:
return self.metadata.get(f"{schema}.{table}", [])

def __bool__(self):
Expand Down
6 changes: 3 additions & 3 deletions sqllineage/core/metadata/sqlalchemy.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Any, Dict, List, Optional
from typing import Any, Optional

from sqlalchemy import MetaData, Table, create_engine, make_url
from sqlalchemy.exc import NoSuchModuleError, NoSuchTableError, OperationalError
Expand All @@ -16,7 +16,7 @@ class SQLAlchemyMetaDataProvider(MetaDataProvider):
SQLAlchemyMetaDataProvider queries metadata from database using SQLAlchemy
"""

def __init__(self, url: str, engine_kwargs: Optional[Dict[str, Any]] = None):
def __init__(self, url: str, engine_kwargs: Optional[dict[str, Any]] = None):
"""
:param url: sqlalchemy url
:param engine_kwargs: a dictionary of keyword arguments that will be passed to sqlalchemy create_engine
Expand All @@ -37,7 +37,7 @@ def __init__(self, url: str, engine_kwargs: Optional[Dict[str, Any]] = None):
except OperationalError as e:
raise MetaDataProviderException(f"Could not connect to {url}") from e

def _get_table_columns(self, schema: str, table: str, **kwargs) -> List[str]:
def _get_table_columns(self, schema: str, table: str, **kwargs) -> list[str]:
columns = []
try:
sqlalchemy_table = Table(
Expand Down
11 changes: 5 additions & 6 deletions sqllineage/core/metadata_provider.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from abc import abstractmethod
from typing import Dict, List

from sqllineage.core.models import Column, Table

Expand All @@ -23,9 +22,9 @@ class MetaDataProvider:
"""

def __init__(self) -> None:
self._session_metadata: Dict[str, List[str]] = {}
self._session_metadata: dict[str, list[str]] = {}

def get_table_columns(self, table: Table, **kwargs) -> List[Column]:
def get_table_columns(self, table: Table, **kwargs) -> list[Column]:
"""
return columns of given table.
"""
Expand All @@ -41,10 +40,10 @@ def get_table_columns(self, table: Table, **kwargs) -> List[Column]:
return columns

@abstractmethod
def _get_table_columns(self, schema: str, table: str, **kwargs) -> List[str]:
def _get_table_columns(self, schema: str, table: str, **kwargs) -> list[str]:
"""To be implemented by subclasses."""

def register_session_metadata(self, table: Table, columns: List[Column]) -> None:
def register_session_metadata(self, table: Table, columns: list[Column]) -> None:
"""Register session-level metadata, like temporary table or view created."""
self._session_metadata[str(table)] = [c.raw_name for c in columns]

Expand Down Expand Up @@ -78,5 +77,5 @@ def __enter__(self):
def __exit__(self, exc_type, exc_val, exc_tb):
self.metadata_provider.deregister_session_metadata()

def register_session_metadata(self, table: Table, columns: List[Column]) -> None:
def register_session_metadata(self, table: Table, columns: list[Column]) -> None:
self.metadata_provider.register_session_metadata(table, columns)
8 changes: 4 additions & 4 deletions sqllineage/core/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import warnings
from typing import Any, Dict, List, Optional, Set, Union
from typing import Any, Optional, Union

from sqllineage.config import SQLLineageConfig
from sqllineage.exceptions import SQLLineageException
Expand Down Expand Up @@ -151,7 +151,7 @@ def __init__(self, name: str, **kwargs):
:param parent: :class:`Table` or :class:`SubQuery`
:param kwargs:
"""
self._parent: Set[Union[Path, Table, SubQuery]] = set()
self._parent: set[Union[Path, Table, SubQuery]] = set()
self.raw_name = escape_identifier_name(name)
self.source_columns = [
(
Expand Down Expand Up @@ -193,7 +193,7 @@ def parent(self, value: Union[Path, Table, SubQuery]):
self._parent.add(value)

@property
def parent_candidates(self) -> List[Union[Path, Table, SubQuery]]:
def parent_candidates(self) -> list[Union[Path, Table, SubQuery]]:
return sorted(self._parent, key=lambda p: str(p))

@staticmethod
Expand All @@ -205,7 +205,7 @@ def of(column: Any, **kwargs) -> "Column":
"""
raise NotImplementedError

def to_source_columns(self, alias_mapping: Dict[str, Union[Path, Table, SubQuery]]):
def to_source_columns(self, alias_mapping: dict[str, Union[Path, Table, SubQuery]]):
"""
Best guess for source table given all the possible table/subquery and their alias.
"""
Expand Down
10 changes: 5 additions & 5 deletions sqllineage/core/parser/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, List, Tuple, Union
from typing import Union

from sqllineage.config import SQLLineageConfig
from sqllineage.core.holders import SubQueryLineageHolder
Expand All @@ -7,9 +7,9 @@


class SourceHandlerMixin:
tables: List[Union[Path, SubQuery, Table]]
columns: List[Column]
union_barriers: List[Tuple[int, int]]
tables: list[Union[Path, SubQuery, Table]]
columns: list[Column]
union_barriers: list[tuple[int, int]]

def end_of_query_cleanup(self, holder: SubQueryLineageHolder) -> None:
for i, tbl in enumerate(self.tables):
Expand All @@ -25,7 +25,7 @@ def end_of_query_cleanup(self, holder: SubQueryLineageHolder) -> None:
if len(holder.write) > 1:
raise SQLLineageException
tgt_tbl = next(iter(holder.write))
lateral_column_aliases: Dict[str, List[Column]] = {}
lateral_column_aliases: dict[str, list[Column]] = {}
for idx, tgt_col_from_query in enumerate(col_grp):
tgt_col_from_query.parent = tgt_tbl
tgt_col_resolved = tgt_col_from_query
Expand Down
Loading

0 comments on commit e733914

Please sign in to comment.