diff --git a/src/pyfreeradius.py b/src/pyfreeradius.py index f7281ee..8dcb3ee 100644 --- a/src/pyfreeradius.py +++ b/src/pyfreeradius.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from contextlib import contextmanager -from pydantic import BaseModel, IPvAnyAddress, StringConstraints, Field, model_validator +from pydantic import BaseModel, StringConstraints, Field, model_validator from typing import List from typing_extensions import Annotated @@ -99,7 +99,7 @@ def check_fields_on_init(self): class Nas(BaseModel): - nasname: IPvAnyAddress + nasname: Annotated[str, StringConstraints(min_length=1)] shortname: Annotated[str, StringConstraints(min_length=1)] secret: Annotated[str, StringConstraints(min_length=1)] @@ -348,10 +348,10 @@ class NasRepository(BaseRepository): def __init__(self, db_connection, db_tables: RadTables): super().__init__(db_connection, db_tables) - def exists(self, nasname: IPvAnyAddress) -> bool: + def exists(self, nasname: str) -> bool: with self._db_cursor() as db_cursor: sql = f"SELECT COUNT(DISTINCT nasname) FROM {self.nas} WHERE nasname = %s" - db_cursor.execute(sql, (str(nasname),)) + db_cursor.execute(sql, (nasname,)) (count,) = db_cursor.fetchone() return count > 0 @@ -362,7 +362,7 @@ def find_all_nasnames(self) -> List[str]: nasnames = [nasname for nasname, in db_cursor.fetchall()] return nasnames - def find_nasnames(self, from_nasname: IPvAnyAddress = None) -> List[str]: + def find_nasnames(self, from_nasname: str = None) -> List[str]: if not from_nasname: return self._find_first_nasnames() return self._find_next_nasnames(from_nasname) @@ -375,29 +375,29 @@ def _find_first_nasnames(self) -> List[str]: nasnames = [nasname for nasname, in db_cursor.fetchall()] return nasnames - def _find_next_nasnames(self, from_nasname: IPvAnyAddress) -> List[str]: + def _find_next_nasnames(self, from_nasname: str) -> List[str]: with self._db_cursor() as db_cursor: sql = f"""SELECT DISTINCT nasname FROM {self.nas} WHERE nasname > %s ORDER BY nasname LIMIT {self._PER_PAGE}""" - db_cursor.execute(sql, (str(from_nasname),)) + db_cursor.execute(sql, (from_nasname,)) nasnames = [nasname for nasname, in db_cursor.fetchall()] return nasnames - def find_one(self, nasname: IPvAnyAddress) -> Nas: + def find_one(self, nasname: str) -> Nas: if not self.exists(nasname): return None with self._db_cursor() as db_cursor: sql = f"SELECT nasname, shortname, secret FROM {self.nas} WHERE nasname = %s" - db_cursor.execute(sql, (str(nasname),)) + db_cursor.execute(sql, (nasname,)) n, sh, se = db_cursor.fetchone() return Nas(nasname=n, shortname=sh, secret=se) def add(self, nas: Nas): with self._db_cursor() as db_cursor: sql = f"INSERT INTO {self.nas} (nasname, shortname, secret) VALUES (%s, %s, %s)" - db_cursor.execute(sql, (str(nas.nasname), nas.shortname, nas.secret)) + db_cursor.execute(sql, (nas.nasname, nas.shortname, nas.secret)) - def remove(self, nasname: IPvAnyAddress): + def remove(self, nasname: str): with self._db_cursor() as db_cursor: - db_cursor.execute(f"DELETE FROM {self.nas} WHERE nasname = %s", (str(nasname),)) + db_cursor.execute(f"DELETE FROM {self.nas} WHERE nasname = %s", (nasname,)) diff --git a/src/tests/test_pyfreeradius.py b/src/tests/test_pyfreeradius.py index d043ca4..b24384f 100644 --- a/src/tests/test_pyfreeradius.py +++ b/src/tests/test_pyfreeradius.py @@ -115,9 +115,9 @@ def test_valid_nas(): # Repository: finding assert nas_repo.find_one(n.nasname) == n - assert str(n.nasname) in nas_repo.find_all_nasnames() - assert str(n.nasname) in nas_repo.find_nasnames() - assert str(n.nasname) in nas_repo.find_nasnames(from_nasname="1.1.1.0") + assert n.nasname in nas_repo.find_all_nasnames() + assert n.nasname in nas_repo.find_nasnames() + assert n.nasname in nas_repo.find_nasnames(from_nasname="1.1.1.0") # Repository: removing nas_repo.remove(n.nasname)