Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
joegasewicz committed Nov 9, 2024
1 parent e7a15cd commit 4ae187e
Show file tree
Hide file tree
Showing 7 changed files with 146 additions and 23 deletions.
21 changes: 2 additions & 19 deletions flask_jwt_router/_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,8 @@
import jwt
from flask import g

from ._config import Config


class BaseAuthentication(ABC):
# pylint:disable=missing-class-docstring
@abstractmethod
def create_token(self, config: Config, exp: int, **kwargs):
# pylint:disable=missing-function-docstring
pass

@abstractmethod
def update_token(self, config: Config, exp: int, table_name, **kwarg):
# pylint:disable=missing-function-docstring
pass

@abstractmethod
def encode_token(self, config: Config, entity_id: Any, exp: int, table_name: str):
# pylint:disable=missing-function-docstring
pass
from flask_jwt_router._config import Config
from flask_jwt_router._base import BaseAuthentication


class Authentication(BaseAuthentication):
Expand Down
27 changes: 27 additions & 0 deletions flask_jwt_router/_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from abc import ABC, abstractmethod
from typing import Any
from datetime import datetime
# pylint:disable=wildcard-import,unused-wildcard-import
from dateutil.relativedelta import *
import jwt
from flask import g

from ._config import Config

class BaseAuthentication(ABC):
# pylint:disable=missing-class-docstring
@abstractmethod
def create_token(self, config: Config, exp: int, **kwargs):
# pylint:disable=missing-function-docstring
pass

@abstractmethod
def update_token(self, config: Config, exp: int, table_name, **kwarg):
# pylint:disable=missing-function-docstring
pass

@abstractmethod
def encode_token(self, config: Config, entity_id: Any, exp: int, table_name: str):
# pylint:disable=missing-function-docstring
pass

10 changes: 8 additions & 2 deletions flask_jwt_router/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


class SecretKeyError(Exception):
message = "You must define a secret key. " \
message = "You must define a secret key OR public & private keys. " \
"See https://flask-jwt-router.readthedocs.io/en/latest/extensions.html"

def __init__(self):
Expand Down Expand Up @@ -37,6 +37,8 @@ class Config(BaseConfig):
:param entity_models: Multiple entities to be authenticated
:param expire_days: Expire time for the token in days
:param oauth_entity: If google_oauth options are declared then this will indicate the entity key in flight
:param public_key: TODO
:param private_key: TODO
:kwargs:
:param entity_models: Multiple entities to be authenticated
:param google_oauth: Options if the type or auth is Google's OAuth 2.0
Expand All @@ -50,6 +52,8 @@ class Config(BaseConfig):
expire_days: int
google_oauth: Dict
oauth_entity: str = None
public_key: str = None
private_key: str = None

def init_config(self, app_config: Dict[str, Any], **kwargs) -> None:
"""
Expand All @@ -64,8 +68,10 @@ def init_config(self, app_config: Dict[str, Any], **kwargs) -> None:
self.entity_models = app_config.get("ENTITY_MODELS") or kwargs.get("entity_models") or []
self.expire_days = app_config.get("JWT_EXPIRE_DAYS")
self.google_oauth = kwargs.get("google_oauth")
self.public_key = app_config.get("JWT_PUBLIC_KEY")
self.private_key = app_config.get("JWT_PRIVATE_KEY")

if not self.secret_key:
if not self.secret_key and (not self.private_key or not self.public_key):
raise SecretKeyError

if self.google_oauth:
Expand Down
5 changes: 5 additions & 0 deletions flask_jwt_router/_exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
class AuthenticationError(Exception):
message = "algorithm kwarg must be set to either HS256 or RS256"

def __init__(self, err=""):
super(AuthenticationError, self).__init__(f"{err}\n{self.message}")
17 changes: 15 additions & 2 deletions flask_jwt_router/_jwt_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,11 +178,15 @@ def login():
from ._config import Config
from ._entity import BaseEntity, Entity, _ORMType
from ._routing import BaseRouting, RoutingMixin
from ._authentication import BaseAuthentication, Authentication
from ._base import BaseAuthentication
from ._authentication import Authentication
from ._rsa_authentication import RSAAuthentication
from .oauth2.google import Google
from .oauth2._base import BaseOAuth, TestBaseOAuth
from .oauth2.http_requests import HttpRequests
from .oauth2._urls import GOOGLE_OAUTH_URL
from ._exceptions import AuthenticationError


# pylint:disable=invalid-name
logger = logging.getLogger()
Expand Down Expand Up @@ -233,6 +237,9 @@ class BaseJwtRoutes:
#: Optional. See :class:`~flask_jwt_router.oauth2.google`
google_oauth: Dict # TODO needs to be a list

#: Optional. TODO
algorithm: str = "HS256"

#: Optional. A Lust of strategies to be implement in the routing
strategies: List[BaseOAuth]

Expand All @@ -243,8 +250,14 @@ def __init__(self, app=None, **kwargs):
self.entity_models = kwargs.get("entity_models")
self.google_oauth = kwargs.get("google_oauth")
self.strategies = kwargs.get("strategies")
self.algorithm = kwargs.get("algorithm")
self.config = Config()
self.auth = Authentication()
if self.algorithm == "HS256":
self.auth = Authentication()
elif self.algorithm == "RS256":
self.auth = RSAAuthentication()
else:
raise AuthenticationError()
self.app = app
if app:
self.init_app(app, entity_models=self.entity_models)
Expand Down
4 changes: 4 additions & 0 deletions flask_jwt_router/_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ def handle_token(self) -> None:
def init(self, app, config: Config, entity: BaseEntity, strategy_dict: Dict[str, BaseOAuth] = None) -> None:
pass

@abstractmethod
def before_middleware(self) -> None:
pass


class Routing(BaseRouting):
"""
Expand Down
85 changes: 85 additions & 0 deletions flask_jwt_router/_rsa_authentication.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from abc import ABC, abstractmethod
from typing import Any
from datetime import datetime
# pylint:disable=wildcard-import,unused-wildcard-import
from dateutil.relativedelta import *
import jwt
from flask import g

from flask_jwt_router._config import Config
from flask_jwt_router._base import BaseAuthentication


class RSAAuthentication(BaseAuthentication):
"""
Uses RSA-256 hash algorithm
"""
#: The reference to the entity key. Defaulted to `id`.
# See :class:`~flask_jwt_router._config` for more information.
entity_key: str = "id"

#: The reference to the entity key.
#: See :class:`~flask_jwt_router._config` for more information.
public_key: str = None

#: The reference to the entity key.
#: See :class:`~flask_jwt_router._config` for more information.
private_key: str = None

#: The reference to the entity ID.
entity_id: str = None

def __init__(self):
# pylint:disable=useless-super-delegation
super(RSAAuthentication, self).__init__()

def create_token(self, config: Config, exp: int, **kwargs):
"""
"""
self.entity_id = kwargs.get("entity_id", None)
table_name = kwargs.get("table_name", None)
return self.encode_token(config, self.entity_id, exp, table_name)

def encode_token(self, config: Config, entity_id: Any, exp: int, table_name: str):
"""
"""
self.entity_key = config.entity_key
self.secret_key = config.secret_key

encoded = jwt.encode({
"table_name": table_name,
self.entity_key: entity_id,
# pylint: disable=no-member
"exp": datetime.utcnow() + relativedelta(days=+exp)
}, self.private_key, algorithm="RS256")
try:
# Handle < pyJWT==2.0
encoded = encoded.decode("utf-8", self.public_key, algorithms=["RS256"])
except AttributeError:
pass
return encoded

def update_token(self,
config: Config,
exp: int,
table_name: str,
**kwargs,
) -> str:
"""
kwargs:
- entity_id: Represents the entity's primary key
:param config:
:param exp:
:param table_name:
:return: Union[str, None]
"""
self.entity_id = kwargs.get("entity_id", None)
return self.encode_token(config, self.entity_id, exp, table_name)

def get_oauth_token(self) -> str:
"""
:return: A Google OAuth 2.0 token
"""
return g.get("access_token")

0 comments on commit 4ae187e

Please sign in to comment.