diff --git a/src/api.py b/src/api.py index 3e56ba1..a31825a 100644 --- a/src/api.py +++ b/src/api.py @@ -3,7 +3,6 @@ from pydantic import BaseModel from pyfreeradius import User, Group, Nas from pyfreeradius import UserRepository, GroupRepository, NasRepository -from typing import List # # We want our REST API endpoints to be KISS! @@ -39,7 +38,7 @@ def read_root(): return {"Welcome!": f"API docs is available at {API_URL}/docs"} -@router.get("/nas", tags=["nas"], status_code=200, response_model=List[str]) +@router.get("/nas", tags=["nas"], status_code=200, response_model=list[str]) def get_nases(response: Response, from_nasname: str | None = None): nasnames = nas_repo.find_nasnames(from_nasname) if nasnames: @@ -48,7 +47,7 @@ def get_nases(response: Response, from_nasname: str | None = None): return nasnames -@router.get("/users", tags=["users"], status_code=200, response_model=List[str]) +@router.get("/users", tags=["users"], status_code=200, response_model=list[str]) def get_users(response: Response, from_username: str | None = None): usernames = user_repo.find_usernames(from_username) if usernames: @@ -57,7 +56,7 @@ def get_users(response: Response, from_username: str | None = None): return usernames -@router.get("/groups", tags=["groups"], status_code=200, response_model=List[str]) +@router.get("/groups", tags=["groups"], status_code=200, response_model=list[str]) def get_groups(response: Response, from_groupname: str | None = None): groupnames = group_repo.find_groupnames(from_groupname) if groupnames: diff --git a/src/pyfreeradius.py b/src/pyfreeradius.py index ebbb7d3..566ce86 100644 --- a/src/pyfreeradius.py +++ b/src/pyfreeradius.py @@ -1,7 +1,6 @@ from abc import ABC, abstractmethod from contextlib import contextmanager from pydantic import BaseModel, StringConstraints, Field, model_validator -from typing import List from typing_extensions import Annotated # @@ -31,9 +30,9 @@ class GroupUser(BaseModel): class User(BaseModel): username: Annotated[str, StringConstraints(min_length=1)] - checks: List[AttributeOpValue] = [] - replies: List[AttributeOpValue] = [] - groups: List[UserGroup] = [] + checks: list[AttributeOpValue] = [] + replies: list[AttributeOpValue] = [] + groups: list[UserGroup] = [] @model_validator(mode="after") def check_fields_on_init(self): @@ -71,9 +70,9 @@ def check_fields_on_init(self): class Group(BaseModel): groupname: Annotated[str, StringConstraints(min_length=1)] - checks: List[AttributeOpValue] = [] - replies: List[AttributeOpValue] = [] - users: List[GroupUser] = [] + checks: list[AttributeOpValue] = [] + replies: list[AttributeOpValue] = [] + users: list[GroupUser] = [] @model_validator(mode="after") def check_fields_on_init(self): @@ -164,7 +163,7 @@ def exists(self, username: str) -> bool: counts = [count for count, in db_cursor.fetchall()] return sum(counts) > 0 - def find_all_usernames(self) -> List[str]: + def find_all_usernames(self) -> list[str]: with self._db_cursor() as db_cursor: sql = f"""SELECT DISTINCT username FROM {self.radcheck} UNION SELECT DISTINCT username FROM {self.radreply} @@ -173,12 +172,12 @@ def find_all_usernames(self) -> List[str]: usernames = [username for username, in db_cursor.fetchall()] return usernames - def find_usernames(self, from_username: str | None = None) -> List[str]: + def find_usernames(self, from_username: str | None = None) -> list[str]: if not from_username: return self._find_first_usernames() return self._find_next_usernames(from_username) - def _find_first_usernames(self) -> List[str]: + def _find_first_usernames(self) -> list[str]: with self._db_cursor() as db_cursor: sql = f""" SELECT username FROM ( @@ -191,7 +190,7 @@ def _find_first_usernames(self) -> List[str]: usernames = [username for username, in db_cursor.fetchall()] return usernames - def _find_next_usernames(self, from_username: str) -> List[str]: + def _find_next_usernames(self, from_username: str) -> list[str]: with self._db_cursor() as db_cursor: sql = f""" SELECT username FROM ( @@ -257,7 +256,7 @@ def exists(self, groupname: str) -> bool: counts = [count for count, in db_cursor.fetchall()] return sum(counts) > 0 - def find_all_groupnames(self) -> List[str]: + def find_all_groupnames(self) -> list[str]: with self._db_cursor() as db_cursor: sql = f"""SELECT DISTINCT groupname FROM {self.radgroupcheck} UNION SELECT DISTINCT groupname FROM {self.radgroupreply} @@ -266,12 +265,12 @@ def find_all_groupnames(self) -> List[str]: groupnames = [groupname for groupname, in db_cursor.fetchall()] return groupnames - def find_groupnames(self, from_groupname: str | None = None) -> List[str]: + def find_groupnames(self, from_groupname: str | None = None) -> list[str]: if not from_groupname: return self._find_first_groupnames() return self._find_next_groupnames(from_groupname) - def _find_first_groupnames(self) -> List[str]: + def _find_first_groupnames(self) -> list[str]: with self._db_cursor() as db_cursor: sql = f""" SELECT groupname FROM ( @@ -284,7 +283,7 @@ def _find_first_groupnames(self) -> List[str]: groupnames = [groupname for groupname, in db_cursor.fetchall()] return groupnames - def _find_next_groupnames(self, from_groupname: str) -> List[str]: + def _find_next_groupnames(self, from_groupname: str) -> list[str]: with self._db_cursor() as db_cursor: sql = f""" SELECT groupname FROM ( @@ -355,19 +354,19 @@ def exists(self, nasname: str) -> bool: (count,) = db_cursor.fetchone() return count > 0 - def find_all_nasnames(self) -> List[str]: + def find_all_nasnames(self) -> list[str]: with self._db_cursor() as db_cursor: sql = f"SELECT DISTINCT nasname FROM {self.nas}" db_cursor.execute(sql) nasnames = [nasname for nasname, in db_cursor.fetchall()] return nasnames - def find_nasnames(self, from_nasname: str | None = None) -> List[str]: + def find_nasnames(self, from_nasname: str | None = None) -> list[str]: if not from_nasname: return self._find_first_nasnames() return self._find_next_nasnames(from_nasname) - def _find_first_nasnames(self) -> List[str]: + def _find_first_nasnames(self) -> list[str]: with self._db_cursor() as db_cursor: sql = f"""SELECT DISTINCT nasname FROM {self.nas} ORDER BY nasname LIMIT {self._PER_PAGE}""" @@ -375,7 +374,7 @@ def _find_first_nasnames(self) -> List[str]: nasnames = [nasname for nasname, in db_cursor.fetchall()] return nasnames - def _find_next_nasnames(self, from_nasname: str) -> List[str]: + def _find_next_nasnames(self, from_nasname: str) -> list[str]: with self._db_cursor() as db_cursor: sql = f"""SELECT DISTINCT nasname FROM {self.nas} WHERE nasname > %s ORDER BY nasname LIMIT {self._PER_PAGE}"""