Skip to content

Commit

Permalink
Merge pull request #10 from NatLibFi/feature/simplye-202/enable-mypy
Browse files Browse the repository at this point in the history
Fix mypy type checking errors
  • Loading branch information
attemoi authored Feb 7, 2024
2 parents fcca932 + 12005a6 commit 4cd2607
Show file tree
Hide file tree
Showing 8 changed files with 130 additions and 73 deletions.
17 changes: 10 additions & 7 deletions api/authenticator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
import sys
from abc import ABC
from typing import Dict, Iterable, List, Optional, Tuple, Type
from typing import Dict, Iterable, List, Optional, Tuple, Type, cast

import flask
import jwt
Expand Down Expand Up @@ -431,7 +431,7 @@ def register_ekirjasto_provider(
self.ekirjasto_provider = provider

# Finland
def get_ekirjasto_provider(self) -> EkirjastoAuthenticationAPI:
def get_ekirjasto_provider(self) -> EkirjastoAuthenticationAPI | None:
return self.ekirjasto_provider

@property
Expand Down Expand Up @@ -495,9 +495,10 @@ def authenticated_patron(
return INVALID_EKIRJASTO_DELEGATE_TOKEN
provider = self.ekirjasto_provider
# Get decoded payload from the delegate token.
provider_token = provider.validate_ekirjasto_delegate_token(auth.token)
if isinstance(provider_token, ProblemDetail):
return provider_token
validate_result = provider.validate_ekirjasto_delegate_token(auth.token)
if isinstance(validate_result, ProblemDetail):
return validate_result
provider_token = validate_result
elif auth.type.lower() == "bearer":
# The patron wants to use an
# SAMLAuthenticationProvider. Figure out which one.
Expand Down Expand Up @@ -593,12 +594,14 @@ def create_bearer_token(
# Maybe we should use something custom instead.
iss=provider_name,
)
return jwt.encode(payload, self.bearer_token_signing_secret, algorithm="HS256")
return jwt.encode(
payload, cast(str, self.bearer_token_signing_secret), algorithm="HS256"
)

def decode_bearer_token(self, token: str) -> Tuple[str, str]:
"""Extract auth provider name and access token from JSON web token."""
decoded = jwt.decode(
token, self.bearer_token_signing_secret, algorithms=["HS256"]
token, cast(str, self.bearer_token_signing_secret), algorithms=["HS256"]
)
provider_name = decoded["iss"]
token = decoded["token"]
Expand Down
68 changes: 35 additions & 33 deletions api/ekirjasto_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from abc import ABC
from base64 import b64decode, b64encode
from enum import Enum
from typing import Any
from typing import Any, Tuple

import jwt
import requests
Expand Down Expand Up @@ -110,8 +110,8 @@ def __init__(
self.ekirjasto_environment = settings.ekirjasto_environment
self.delegate_expire_timemestamp = settings.delegate_expire_time

self.delegate_token_signing_secret = None
self.delegate_token_encrypting_secret = None
self.delegate_token_signing_secret: str | None = None
self.delegate_token_encrypting_secret: bytes | None = None

self.analytics = analytics

Expand Down Expand Up @@ -284,7 +284,7 @@ def get_credential_from_header(self, auth: Authorization) -> str | None:
# circulation API needs additional authentication.
return None

def get_patron_delegate_id(self, _db: Session, patron: Patron) -> str:
def get_patron_delegate_id(self, _db: Session, patron: Patron) -> str | None:
"""Find or randomly create an identifier to use when identifying
this patron from delegate token.
"""
Expand Down Expand Up @@ -343,17 +343,6 @@ def set_secrets(self, _db):
_db.commit()
self.delegate_token_encrypting_secret = secret.value.encode()

def _check_secrets_or_throw(self):
if (
self.delegate_token_signing_secret == None
or len(self.delegate_token_signing_secret) == 0
or self.delegate_token_encrypting_secret == None
or len(self.delegate_token_encrypting_secret) == 0
):
raise InternalServerError(
"Ekirjasto authenticator not fully setup, secrets are missing."
)

def create_ekirjasto_delegate_token(
self, provider_token: str, patron_delegate_id: str, expires: int
) -> str:
Expand All @@ -362,7 +351,15 @@ def create_ekirjasto_delegate_token(
The patron will use this as the authentication toekn to authentiacte againsy circulation backend.
"""
self._check_secrets_or_throw()
if not self.delegate_token_encrypting_secret:
raise InternalServerError(
"Error creating delegate token, encryption secret missing"
)

if not self.delegate_token_signing_secret:
raise InternalServerError(
"Error creating delegate token, signing secret missing"
)

# Encrypt the ekirjasto token with a128cbc-hs256 algorithm.
fernet = Fernet(self.delegate_token_encrypting_secret)
Expand All @@ -377,6 +374,7 @@ def create_ekirjasto_delegate_token(
iat=int(utc_now().timestamp()),
exp=expires,
)

return jwt.encode(
payload, self.delegate_token_signing_secret, algorithm="HS256"
)
Expand All @@ -392,7 +390,10 @@ def decode_ekirjasto_delegate_token(
return decoded payload
"""
self._check_secrets_or_throw()
if not self.delegate_token_signing_secret:
raise InternalServerError(
"Error decoding delegate token, signing secret missing"
)

options = dict(
verify_signature=True,
Expand All @@ -418,6 +419,10 @@ def decode_ekirjasto_delegate_token(
return decoded_payload

def _decrypt_ekirjasto_token(self, token: str):
if not self.delegate_token_encrypting_secret:
raise InternalServerError(
"Error decrypting ekirjasto token, signing secret missing"
)
fernet = Fernet(self.delegate_token_encrypting_secret)
encrypted_token = b64decode(token.encode("ascii"))
return fernet.decrypt(encrypted_token).decode()
Expand Down Expand Up @@ -445,7 +450,9 @@ def validate_ekirjasto_delegate_token(
return INVALID_EKIRJASTO_DELEGATE_TOKEN
return decoded_payload

def remote_refresh_token(self, token: str) -> (str, int):
def remote_refresh_token(
self, token: str
) -> Tuple[ProblemDetail, None] | Tuple[str, int]:
"""Refresh ekirjasto token with ekirjasto API call.
We assume that the token is valid, API call fails if not.
Expand All @@ -454,9 +461,9 @@ def remote_refresh_token(self, token: str) -> (str, int):
"""

if self.ekirjasto_environment == EkirjastoEnvironment.FAKE:
token = self.fake_ekirjasto_token
expires = utc_now() + datetime.timedelta(days=1)
return token, expires.timestamp()
fake_token = self.fake_ekirjasto_token
expire_date = utc_now() + datetime.timedelta(days=1)
return fake_token, int(expire_date.timestamp())

url = self._ekirjasto_api_url + "/v1/auth/refresh"

Expand All @@ -469,10 +476,6 @@ def remote_refresh_token(self, token: str) -> (str, int):
# Do nothing if authentication fails, e.g. token expired.
return INVALID_EKIRJASTO_TOKEN, None
elif response.status_code != 200:
msg = "Got unexpected response code %d. Content: %s" % (
response.status_code,
response.content,
)
return EKIRJASTO_REMOTE_AUTHENTICATION_FAILED, None
else:
try:
Expand All @@ -482,6 +485,7 @@ def remote_refresh_token(self, token: str) -> (str, int):

token = content["token"]
expires = content["exp"]

return token, expires

def remote_patron_lookup(
Expand Down Expand Up @@ -532,8 +536,6 @@ def remote_patron_lookup(

return self._userinfo_to_patrondata(content)

return EKIRJASTO_REMOTE_AUTHENTICATION_FAILED

def remote_authenticate(
self, ekirjasto_token: str | None
) -> PatronData | ProblemDetail | None:
Expand Down Expand Up @@ -608,7 +610,7 @@ def local_patron_lookup(

def ekirjasto_authenticate(
self, _db: Session, ekirjasto_token: str
) -> (Patron, bool):
) -> Tuple[PatronData | Patron | ProblemDetail | None, bool]:
"""Authenticate patron with remote ekirjasto API and if necessary,
create authenticated patron if not in database.
Expand All @@ -620,17 +622,17 @@ def ekirjasto_authenticate(
log_method=self.logger().info,
message_prefix="authenticated_patron - ekirjasto_authenticate",
):
patron = self.authenticate_and_update_patron(_db, ekirjasto_token)
auth_result = self.authenticate_and_update_patron(_db, ekirjasto_token)

if isinstance(patron, PatronData):
if isinstance(auth_result, PatronData):
# We didn't find the patron, but authentication to external truth was
# succesfull, so we create a new patron with the information we have.
patron, is_new = patron.get_or_create_patron(
# successful, so we create a new patron with the information we have.
patron, is_new = auth_result.get_or_create_patron(
_db, self.library_id, analytics=self.analytics
)
patron.last_external_sync = utc_now()

return patron, is_new
return auth_result, is_new

def authenticated_patron(
self, _db: Session, authorization: dict | str
Expand Down
23 changes: 9 additions & 14 deletions api/ekirjasto_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,35 +34,30 @@ def __init__(self, circulation_manager, authenticator):

self._logger = logging.getLogger(__name__)

def _get_delegate_expire_timestamp(self, ekirjasto_token_expires: int) -> int:
def _get_delegate_expire_timestamp(self, ekirjasto_expire_millis: int) -> int:
"""Get the expire time to use for delegate token, it is calculated based on
expire time of the ekirjasto token.
:param ekirjasto_token_expires: Ekirjasto token expiration timestamp in milliseconds.
:param ekirjasto_expire_millis: Ekirjasto token expiration timestamp in milliseconds.
:return: Timestamp for the delegate token expiration.
:return: Timestamp for the delegate token expiration in seconds.
"""

# Ekirjasto expire is in milliseconds but JWT uses seconds.
ekirjasto_token_expires = int(ekirjasto_token_expires / 1000)

delegate_token_expires = (
delegate_expire_seconds = (
utc_now()
+ datetime.timedelta(
seconds=self._authenticator.ekirjasto_provider.delegate_expire_timemestamp
)
).timestamp()

# Use ekirjasto expire time at 70 % of the remaining duration, so we have some time to refresh it.
now_seconds = utc_now().timestamp()
ekirjasto_token_expires = (
ekirjasto_token_expires - now_seconds
) * 0.7 + now_seconds

if ekirjasto_token_expires < delegate_token_expires:
return int(ekirjasto_token_expires)
# Use ekirjasto expire time at 70 % of the remaining duration, so we have some time to refresh it.
ekirjasto_expire_seconds = (
(ekirjasto_expire_millis / 1000) - now_seconds
) * 0.7 + now_seconds

return int(delegate_token_expires)
return int(min(ekirjasto_expire_seconds, delegate_expire_seconds))

def get_tokens(self, authorization, validate_expire=False):
"""Extract possible delegate and ekirjasto tokens from the authorization header."""
Expand Down
4 changes: 2 additions & 2 deletions core/model/patron.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ class Patron(Base):
)

loan_checkouts: Mapped[List[LoanCheckout]] = relationship(
"LoanCheckout", backref="patron", cascade="delete", uselist=True
"LoanCheckout", back_populates="patron", cascade="delete", uselist=True
)

holds: Mapped[List[Hold]] = relationship(
Expand Down Expand Up @@ -586,7 +586,7 @@ class LoanCheckout(Base, LoanAndHoldMixin):
patron_id = Column(
Integer, ForeignKey("patrons.id", ondelete="CASCADE"), index=True
)
patron: Patron
patron: Mapped[Patron] = relationship("Patron", back_populates="loan_checkouts")

license_pool_id = Column(Integer, ForeignKey("licensepools.id"), index=True)
license_pool: Mapped[LicensePool] = relationship("LicensePool")
Expand Down
Loading

0 comments on commit 4cd2607

Please sign in to comment.