diff --git a/objinspect/util.py b/objinspect/util.py index da187f5..89afae9 100644 --- a/objinspect/util.py +++ b/objinspect/util.py @@ -29,40 +29,6 @@ def type_to_str(t: T.Any) -> str: return type_str.split(".")[-1] -def get_enum_choices(e) -> tuple[str, ...]: - """ - Get the options of a Python Enum. - - Args: - e (enum.Enum): A Python Enum. - - Returns: - tuple: A tuple of the names of the Enum options. - - Example: - >>> import enum - >>> class Color(enum.Enum): - ... RED = 1 - ... GREEN = 2 - ... BLUE = 3 - >>> get_enum_choices(Color) - ('RED', 'GREEN', 'BLUE') - """ - return tuple(e.__members__.keys()) - - -def get_literal_choices(literal_t) -> tuple[str, ...]: - """ - Get the options of a Python Literal. - """ - if is_full_literal(literal_t): - return T.get_args(literal_t) - for i in T.get_args(literal_t): - if is_full_literal(i): - return T.get_args(i) - raise ValueError(f"{literal_t} is not a literal") - - def call_method(obj: object, name: str, args: tuple = (), kwargs: dict = {}) -> T.Any: """ Call a method with the given name on the given object. @@ -99,16 +65,38 @@ def is_enum(t: T.Any) -> bool: return isinstance(t, EnumMeta) -def is_full_literal(t: T.Any) -> bool: +def get_enum_choices(e) -> tuple[str, ...]: + """ + Get the options of a Python Enum. + + Args: + e (enum.Enum): A Python Enum. + + Returns: + tuple: A tuple of the names of the Enum options. + + Example: + >>> import enum + >>> class Color(enum.Enum): + ... RED = 1 + ... GREEN = 2 + ... BLUE = 3 + >>> get_enum_choices(Color) + ('RED', 'GREEN', 'BLUE') + """ + return tuple(e.__members__.keys()) + + +def is_pure_literal(t: T.Any) -> bool: if t is typing_extensions.Literal: - return True + return False if hasattr(t, "__origin__") and t.__origin__ is typing_extensions.Literal: return True return False def is_literal(t: T.Any) -> bool: - if is_full_literal(t): + if is_pure_literal(t): return True for i in T.get_args(t): @@ -117,6 +105,31 @@ def is_literal(t: T.Any) -> bool: return False +def get_literal_choices(literal_t) -> tuple[str, ...]: + """ + Get the options of a Python Literal. + """ + if is_pure_literal(literal_t): + return T.get_args(literal_t) + for i in T.get_args(literal_t): + if is_pure_literal(i): + return T.get_args(i) + raise ValueError(f"{literal_t} is not a literal") + + +def literal_contains(literal_t, value: T.Any) -> bool: + """ + Check if a value is in a Python Literal. + """ + if not is_pure_literal(literal_t): + raise ValueError(f"{literal_t} is not a literal") + + values = get_literal_choices(literal_t) + if not len(values): + raise ValueError(f"{literal_t} has no values") + return value in values + + def create_function( name: str, args: dict[str, T.Tuple[T.Any, T.Any]], diff --git a/tests/test_utils.py b/tests/test_utils.py index ecc07cc..c054208 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -2,7 +2,13 @@ import pytest -from objinspect.util import create_function, get_literal_choices, is_literal +from objinspect.util import ( + create_function, + get_literal_choices, + is_literal, + is_pure_literal, + literal_contains, +) class TestCreateFunction: @@ -64,53 +70,76 @@ def test_edge_cases(self): assert nop() is None -import typing as T +class TestIsPureLiteral: + def test_literal_type_check(self): + assert is_pure_literal(T.Literal["a", "b"]) -import pytest + def test_nested_literal_type_check(self): + assert not is_pure_literal(T.Union[T.Literal["a"], T.Literal["b"]]) + + def test_literal_or_none(self): + assert not is_pure_literal(T.Literal["b"] | None) + + def test_basic_type_as_literal(self): + assert not is_pure_literal(str) + + def test_t_literal_as_literal(self): + assert not is_pure_literal(T.Literal) class TestIsLiteral: def test_literal_type(self): - assert is_literal(T.Literal["a", "b"]) == True + assert is_literal(T.Literal["a", "b"]) def test_non_literal_type(self): - assert is_literal(int) == False + assert not is_literal(int) def test_nested_literal_type(self): nested_literal = T.Literal[T.Literal["a", "b"]] - assert is_literal(nested_literal) == True + assert is_literal(nested_literal) def test_literal_or_none(self): literal_or_none = T.Literal["a", "b"] | None - assert is_literal(literal_or_none) == True + assert is_literal(literal_or_none) def test_composite_without_literal(self): composite_without_literal = T.Union[int, str] - assert is_literal(composite_without_literal) == False + assert not is_literal(composite_without_literal) -class TestGetLiteralChoices: - def test_literal_type(self): - assert get_literal_choices(T.Literal["a", "b"]) == ("a", "b") +class TestIsInLiteral: + def test_value_matches_literal(self): + assert literal_contains(T.Literal["a", "b", "c"], "a") - def test_non_literal_type(self): + def test_value_does_not_match_literal(self): + assert not literal_contains(T.Literal["a", "b", "c"], "d") + + def test_invalid_literal_type(self): with pytest.raises(ValueError): - get_literal_choices(int) + literal_contains(int, 1) - def test_nested_literal_type(self): - nested_literal = T.Literal[T.Literal["a", "b"]] - assert get_literal_choices(nested_literal) == ("a", "b") + def test_none_value(self): + assert not literal_contains(T.Literal["a", "b", "c"], None) - def test_literal_or_none(self): - literal_or_none = T.Literal["a", "b"] | None - assert get_literal_choices(literal_or_none) == ("a", "b") + def test_complex_value(self): + class CustomClass: + pass - def test_composite_without_literal(self): - composite_without_literal = T.Union[int, str] + assert not literal_contains(T.Literal["a", "b", "c"], CustomClass()) + + def test_empty_literal(self): with pytest.raises(ValueError): - get_literal_choices(composite_without_literal) + literal_contains(T.Literal[()], "a") -# Example of running the tests with pytest -if __name__ == "__main__": - pytest.main() +class TestGetLiteralChoices: + def test_get_choices_from_literal(self): + assert get_literal_choices(T.Literal["a", "b"]) == ("a", "b") + + def test_invalid_literal_type_for_choices(self): + with pytest.raises(ValueError): + get_literal_choices(int) + + def test_empty_literal_for_choices(self): + with pytest.raises(ValueError): + get_literal_choices(T.Literal)