Skip to content

Commit

Permalink
improve T.Literal utils
Browse files Browse the repository at this point in the history
  • Loading branch information
zigai committed Dec 4, 2023
1 parent 2a0733d commit f93277d
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 62 deletions.
87 changes: 50 additions & 37 deletions objinspect/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,40 +29,6 @@ def type_to_str(t: T.Any) -> str:
return type_str.split(".")[-1]


def get_enum_choices(e) -> tuple[str, ...]:
"""
Get the options of a Python Enum.
Args:
e (enum.Enum): A Python Enum.
Returns:
tuple: A tuple of the names of the Enum options.
Example:
>>> import enum
>>> class Color(enum.Enum):
... RED = 1
... GREEN = 2
... BLUE = 3
>>> get_enum_choices(Color)
('RED', 'GREEN', 'BLUE')
"""
return tuple(e.__members__.keys())


def get_literal_choices(literal_t) -> tuple[str, ...]:
"""
Get the options of a Python Literal.
"""
if is_full_literal(literal_t):
return T.get_args(literal_t)
for i in T.get_args(literal_t):
if is_full_literal(i):
return T.get_args(i)
raise ValueError(f"{literal_t} is not a literal")


def call_method(obj: object, name: str, args: tuple = (), kwargs: dict = {}) -> T.Any:
"""
Call a method with the given name on the given object.
Expand Down Expand Up @@ -99,16 +65,38 @@ def is_enum(t: T.Any) -> bool:
return isinstance(t, EnumMeta)


def is_full_literal(t: T.Any) -> bool:
def get_enum_choices(e) -> tuple[str, ...]:
"""
Get the options of a Python Enum.
Args:
e (enum.Enum): A Python Enum.
Returns:
tuple: A tuple of the names of the Enum options.
Example:
>>> import enum
>>> class Color(enum.Enum):
... RED = 1
... GREEN = 2
... BLUE = 3
>>> get_enum_choices(Color)
('RED', 'GREEN', 'BLUE')
"""
return tuple(e.__members__.keys())


def is_pure_literal(t: T.Any) -> bool:
if t is typing_extensions.Literal:
return True
return False
if hasattr(t, "__origin__") and t.__origin__ is typing_extensions.Literal:
return True
return False


def is_literal(t: T.Any) -> bool:
if is_full_literal(t):
if is_pure_literal(t):
return True

for i in T.get_args(t):
Expand All @@ -117,6 +105,31 @@ def is_literal(t: T.Any) -> bool:
return False


def get_literal_choices(literal_t) -> tuple[str, ...]:
"""
Get the options of a Python Literal.
"""
if is_pure_literal(literal_t):
return T.get_args(literal_t)
for i in T.get_args(literal_t):
if is_pure_literal(i):
return T.get_args(i)
raise ValueError(f"{literal_t} is not a literal")


def literal_contains(literal_t, value: T.Any) -> bool:
"""
Check if a value is in a Python Literal.
"""
if not is_pure_literal(literal_t):
raise ValueError(f"{literal_t} is not a literal")

values = get_literal_choices(literal_t)
if not len(values):
raise ValueError(f"{literal_t} has no values")
return value in values


def create_function(
name: str,
args: dict[str, T.Tuple[T.Any, T.Any]],
Expand Down
79 changes: 54 additions & 25 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,13 @@

import pytest

from objinspect.util import create_function, get_literal_choices, is_literal
from objinspect.util import (
create_function,
get_literal_choices,
is_literal,
is_pure_literal,
literal_contains,
)


class TestCreateFunction:
Expand Down Expand Up @@ -64,53 +70,76 @@ def test_edge_cases(self):
assert nop() is None


import typing as T
class TestIsPureLiteral:
def test_literal_type_check(self):
assert is_pure_literal(T.Literal["a", "b"])

import pytest
def test_nested_literal_type_check(self):
assert not is_pure_literal(T.Union[T.Literal["a"], T.Literal["b"]])

def test_literal_or_none(self):
assert not is_pure_literal(T.Literal["b"] | None)

def test_basic_type_as_literal(self):
assert not is_pure_literal(str)

def test_t_literal_as_literal(self):
assert not is_pure_literal(T.Literal)


class TestIsLiteral:
def test_literal_type(self):
assert is_literal(T.Literal["a", "b"]) == True
assert is_literal(T.Literal["a", "b"])

def test_non_literal_type(self):
assert is_literal(int) == False
assert not is_literal(int)

def test_nested_literal_type(self):
nested_literal = T.Literal[T.Literal["a", "b"]]
assert is_literal(nested_literal) == True
assert is_literal(nested_literal)

def test_literal_or_none(self):
literal_or_none = T.Literal["a", "b"] | None
assert is_literal(literal_or_none) == True
assert is_literal(literal_or_none)

def test_composite_without_literal(self):
composite_without_literal = T.Union[int, str]
assert is_literal(composite_without_literal) == False
assert not is_literal(composite_without_literal)


class TestGetLiteralChoices:
def test_literal_type(self):
assert get_literal_choices(T.Literal["a", "b"]) == ("a", "b")
class TestIsInLiteral:
def test_value_matches_literal(self):
assert literal_contains(T.Literal["a", "b", "c"], "a")

def test_non_literal_type(self):
def test_value_does_not_match_literal(self):
assert not literal_contains(T.Literal["a", "b", "c"], "d")

def test_invalid_literal_type(self):
with pytest.raises(ValueError):
get_literal_choices(int)
literal_contains(int, 1)

def test_nested_literal_type(self):
nested_literal = T.Literal[T.Literal["a", "b"]]
assert get_literal_choices(nested_literal) == ("a", "b")
def test_none_value(self):
assert not literal_contains(T.Literal["a", "b", "c"], None)

def test_literal_or_none(self):
literal_or_none = T.Literal["a", "b"] | None
assert get_literal_choices(literal_or_none) == ("a", "b")
def test_complex_value(self):
class CustomClass:
pass

def test_composite_without_literal(self):
composite_without_literal = T.Union[int, str]
assert not literal_contains(T.Literal["a", "b", "c"], CustomClass())

def test_empty_literal(self):
with pytest.raises(ValueError):
get_literal_choices(composite_without_literal)
literal_contains(T.Literal[()], "a")


# Example of running the tests with pytest
if __name__ == "__main__":
pytest.main()
class TestGetLiteralChoices:
def test_get_choices_from_literal(self):
assert get_literal_choices(T.Literal["a", "b"]) == ("a", "b")

def test_invalid_literal_type_for_choices(self):
with pytest.raises(ValueError):
get_literal_choices(int)

def test_empty_literal_for_choices(self):
with pytest.raises(ValueError):
get_literal_choices(T.Literal)

0 comments on commit f93277d

Please sign in to comment.