diff --git a/sqllineage/config.py b/sqllineage/config.py index 6e75ac64..b4d83e01 100644 --- a/sqllineage/config.py +++ b/sqllineage/config.py @@ -1,4 +1,8 @@ import os +import threading +from typing import Any, Dict, Set + +from sqllineage.exceptions import ConfigException class _SQLLineageConfigLoader: @@ -17,10 +21,18 @@ class _SQLLineageConfigLoader: # lateral column alias reference supported by some dialect (redshift, spark 3.4+, etc) "LATERAL_COLUMN_ALIAS_REFERENCE": (bool, False), } - BOOLEAN_TRUE_STRINGS = ("true", "on", "ok", "y", "yes", "1") - def __getattr__(self, item): - if item in self.config: + def __init__(self) -> None: + self._thread_config: Dict[int, Dict[str, Any]] = {} + self._thread_in_context_manager: Set[int] = set() + + def __getattr__(self, item: str): + if item in self.config.keys(): + if ( + value := self._thread_config.get(self.get_ident(), {}).get(item) + ) is not None: + return value + type_, default = self.config[item] # require SQLLINEAGE_ prefix from environment variable return self.parse_value( @@ -29,23 +41,57 @@ def __getattr__(self, item): else: return super().__getattribute__(item) - @classmethod - def parse_value(cls, value, cast): - """Parse and cast provided value + def __setattr__(self, key, value) -> None: + if key in self.config: + raise ConfigException( + "SQLLineageConfig is read-only. Use context manager to update thread level config." + ) + else: + super().__setattr__(key, value) + + def __call__(self, *args, **kwargs): + if self.get_ident() not in self._thread_config.keys(): + self._thread_config[self.get_ident()] = {} + for key, value in kwargs.items(): + if key in self.config.keys(): + self._thread_config[self.get_ident()][key] = self.parse_value( + value, self.config[key][0] + ) + else: + raise ConfigException(f"Invalid config key: {key}") + return self + def __enter__(self): + if (thread_id := self.get_ident()) not in self._thread_in_context_manager: + self._thread_in_context_manager.add(thread_id) + else: + raise ConfigException("SQLLineageConfig context manager is not reentrant") + + def __exit__(self, exc_type, exc_val, exc_tb): + thread_id = self.get_ident() + if thread_id in self._thread_config: + self._thread_config.pop(self.get_ident()) + if thread_id in self._thread_in_context_manager: + self._thread_in_context_manager.remove(thread_id) + + @staticmethod + def get_ident() -> int: + return threading.get_ident() + + @staticmethod + def parse_value(value, cast) -> Any: + """Parse and cast provided value :param value: Stringed value. :param cast: Type to cast return value as. - - :returns: Casted value + :returns: cast value """ if cast is bool: try: value = int(value) != 0 except ValueError: - value = value.lower().strip() in cls.BOOLEAN_TRUE_STRINGS + value = value.lower().strip() in ("true", "on", "ok", "y", "yes", "1") else: value = cast(value) - return value diff --git a/sqllineage/exceptions.py b/sqllineage/exceptions.py index 99d42ca9..3b87d9ff 100644 --- a/sqllineage/exceptions.py +++ b/sqllineage/exceptions.py @@ -12,3 +12,7 @@ class InvalidSyntaxException(SQLLineageException): class MetaDataProviderException(SQLLineageException): """Raised for MetaDataProvider errors""" + + +class ConfigException(SQLLineageException): + """Raised for configuration errors""" diff --git a/tests/core/test_config.py b/tests/core/test_config.py index faf09d0d..9d7ed46d 100644 --- a/tests/core/test_config.py +++ b/tests/core/test_config.py @@ -1,7 +1,13 @@ +import concurrent.futures import os +import random +import time from unittest.mock import patch +import pytest + from sqllineage.config import SQLLineageConfig +from sqllineage.exceptions import ConfigException @patch( @@ -21,3 +27,53 @@ def test_config(): assert type(SQLLineageConfig.TSQL_NO_SEMICOLON) is bool assert SQLLineageConfig.TSQL_NO_SEMICOLON is True + + +def test_disable_direct_update_config(): + with pytest.raises(ConfigException): + SQLLineageConfig.DEFAULT_SCHEMA = "ods" + + +def test_update_config_using_context_manager(): + with SQLLineageConfig(LATERAL_COLUMN_ALIAS_REFERENCE=True): + assert SQLLineageConfig.LATERAL_COLUMN_ALIAS_REFERENCE is True + assert SQLLineageConfig.LATERAL_COLUMN_ALIAS_REFERENCE is False + + with SQLLineageConfig(DEFAULT_SCHEMA="ods"): + assert SQLLineageConfig.DEFAULT_SCHEMA == "ods" + assert SQLLineageConfig.DEFAULT_SCHEMA == "" + + with SQLLineageConfig(DIRECTORY=""): + assert SQLLineageConfig.DIRECTORY == "" + assert SQLLineageConfig.DIRECTORY != "" + + +def test_update_config_context_manager_non_reentrant(): + with pytest.raises(ConfigException): + with SQLLineageConfig(DEFAULT_SCHEMA="ods"): + with SQLLineageConfig(DEFAULT_SCHEMA="dwd"): + pass + + +def test_disable_update_unknown_config(): + with pytest.raises(ConfigException): + with SQLLineageConfig(UNKNOWN_KEY="value"): + pass + + +def _check_schema(schema: str): + # used by test_config_parallel, must be a global function so that it can be pickled between processes + with SQLLineageConfig(DEFAULT_SCHEMA=schema): + # randomly sleep [0, 0.1) second to simulate real parsing scenario + time.sleep(random.random() * 0.1) + return SQLLineageConfig.DEFAULT_SCHEMA + + +@pytest.mark.parametrize("pool", ["ThreadPoolExecutor", "ProcessPoolExecutor"]) +def test_config_parallel(pool: str): + executor_class = getattr(concurrent.futures, pool) + schemas = [f"db{i}" for i in range(100)] + with executor_class() as executor: + futures = [executor.submit(_check_schema, schema) for schema in schemas] + for i, future in enumerate(futures): + assert future.result() == schemas[i] diff --git a/tox.ini b/tox.ini index d4163831..f8f73461 100644 --- a/tox.ini +++ b/tox.ini @@ -14,7 +14,7 @@ commands = [flake8] exclude = .tox,.git,__pycache__,build,sqllineagejs,venv,env max-line-length = 120 -# ignore = D100,D101 +ignore = A005,W503 show-source = true enable-extensions=G application-import-names = sqllineage