Skip to content

Commit

Permalink
Fix mypy type checking errors
Browse files Browse the repository at this point in the history
  • Loading branch information
attemoi committed Feb 5, 2024
1 parent fcca932 commit f47fdd4
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 67 deletions.
17 changes: 10 additions & 7 deletions api/authenticator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
import sys
from abc import ABC
from typing import Dict, Iterable, List, Optional, Tuple, Type
from typing import Dict, Iterable, List, Optional, Tuple, Type, cast

import flask
import jwt
Expand Down Expand Up @@ -431,7 +431,7 @@ def register_ekirjasto_provider(
self.ekirjasto_provider = provider

# Finland
def get_ekirjasto_provider(self) -> EkirjastoAuthenticationAPI:
def get_ekirjasto_provider(self) -> EkirjastoAuthenticationAPI | None:
return self.ekirjasto_provider

@property
Expand Down Expand Up @@ -495,9 +495,10 @@ def authenticated_patron(
return INVALID_EKIRJASTO_DELEGATE_TOKEN
provider = self.ekirjasto_provider
# Get decoded payload from the delegate token.
provider_token = provider.validate_ekirjasto_delegate_token(auth.token)
if isinstance(provider_token, ProblemDetail):
return provider_token
validate_result = provider.validate_ekirjasto_delegate_token(auth.token)
if isinstance(validate_result, ProblemDetail):
return validate_result
provider_token = validate_result
elif auth.type.lower() == "bearer":
# The patron wants to use an
# SAMLAuthenticationProvider. Figure out which one.
Expand Down Expand Up @@ -593,12 +594,14 @@ def create_bearer_token(
# Maybe we should use something custom instead.
iss=provider_name,
)
return jwt.encode(payload, self.bearer_token_signing_secret, algorithm="HS256")
return jwt.encode(
payload, cast(str, self.bearer_token_signing_secret), algorithm="HS256"
)

def decode_bearer_token(self, token: str) -> Tuple[str, str]:
"""Extract auth provider name and access token from JSON web token."""
decoded = jwt.decode(
token, self.bearer_token_signing_secret, algorithms=["HS256"]
token, cast(str, self.bearer_token_signing_secret), algorithms=["HS256"]
)
provider_name = decoded["iss"]
token = decoded["token"]
Expand Down
68 changes: 35 additions & 33 deletions api/ekirjasto_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from abc import ABC
from base64 import b64decode, b64encode
from enum import Enum
from typing import Any
from typing import Any, Tuple

import jwt
import requests
Expand Down Expand Up @@ -110,8 +110,8 @@ def __init__(
self.ekirjasto_environment = settings.ekirjasto_environment
self.delegate_expire_timemestamp = settings.delegate_expire_time

self.delegate_token_signing_secret = None
self.delegate_token_encrypting_secret = None
self.delegate_token_signing_secret: str | None = None
self.delegate_token_encrypting_secret: bytes | None = None

self.analytics = analytics

Expand Down Expand Up @@ -284,7 +284,7 @@ def get_credential_from_header(self, auth: Authorization) -> str | None:
# circulation API needs additional authentication.
return None

def get_patron_delegate_id(self, _db: Session, patron: Patron) -> str:
def get_patron_delegate_id(self, _db: Session, patron: Patron) -> str | None:
"""Find or randomly create an identifier to use when identifying
this patron from delegate token.
"""
Expand Down Expand Up @@ -343,17 +343,6 @@ def set_secrets(self, _db):
_db.commit()
self.delegate_token_encrypting_secret = secret.value.encode()

def _check_secrets_or_throw(self):
if (
self.delegate_token_signing_secret == None
or len(self.delegate_token_signing_secret) == 0
or self.delegate_token_encrypting_secret == None
or len(self.delegate_token_encrypting_secret) == 0
):
raise InternalServerError(
"Ekirjasto authenticator not fully setup, secrets are missing."
)

def create_ekirjasto_delegate_token(
self, provider_token: str, patron_delegate_id: str, expires: int
) -> str:
Expand All @@ -362,7 +351,15 @@ def create_ekirjasto_delegate_token(
The patron will use this as the authentication toekn to authentiacte againsy circulation backend.
"""
self._check_secrets_or_throw()
if not self.delegate_token_encrypting_secret:
raise InternalServerError(
"Error creating delegate token, encryption secret missing"
)

if not self.delegate_token_signing_secret:
raise InternalServerError(
"Error creating delegate token, signing secret missing"
)

# Encrypt the ekirjasto token with a128cbc-hs256 algorithm.
fernet = Fernet(self.delegate_token_encrypting_secret)
Expand All @@ -377,6 +374,7 @@ def create_ekirjasto_delegate_token(
iat=int(utc_now().timestamp()),
exp=expires,
)

return jwt.encode(
payload, self.delegate_token_signing_secret, algorithm="HS256"
)
Expand All @@ -392,7 +390,10 @@ def decode_ekirjasto_delegate_token(
return decoded payload
"""
self._check_secrets_or_throw()
if not self.delegate_token_signing_secret:
raise InternalServerError(
"Error decoding delegate token, signing secret missing"
)

options = dict(
verify_signature=True,
Expand All @@ -418,6 +419,10 @@ def decode_ekirjasto_delegate_token(
return decoded_payload

def _decrypt_ekirjasto_token(self, token: str):
if not self.delegate_token_encrypting_secret:
raise InternalServerError(
"Error decrypting ekirjasto token, signing secret missing"
)
fernet = Fernet(self.delegate_token_encrypting_secret)
encrypted_token = b64decode(token.encode("ascii"))
return fernet.decrypt(encrypted_token).decode()
Expand Down Expand Up @@ -445,7 +450,9 @@ def validate_ekirjasto_delegate_token(
return INVALID_EKIRJASTO_DELEGATE_TOKEN
return decoded_payload

def remote_refresh_token(self, token: str) -> (str, int):
def remote_refresh_token(
self, token: str
) -> Tuple[ProblemDetail, None] | Tuple[str, int]:
"""Refresh ekirjasto token with ekirjasto API call.
We assume that the token is valid, API call fails if not.
Expand All @@ -454,9 +461,9 @@ def remote_refresh_token(self, token: str) -> (str, int):
"""

if self.ekirjasto_environment == EkirjastoEnvironment.FAKE:
token = self.fake_ekirjasto_token
expires = utc_now() + datetime.timedelta(days=1)
return token, expires.timestamp()
fake_token = self.fake_ekirjasto_token
expire_date = utc_now() + datetime.timedelta(days=1)
return fake_token, int(expire_date.timestamp())

url = self._ekirjasto_api_url + "/v1/auth/refresh"

Expand All @@ -469,10 +476,6 @@ def remote_refresh_token(self, token: str) -> (str, int):
# Do nothing if authentication fails, e.g. token expired.
return INVALID_EKIRJASTO_TOKEN, None
elif response.status_code != 200:
msg = "Got unexpected response code %d. Content: %s" % (
response.status_code,
response.content,
)
return EKIRJASTO_REMOTE_AUTHENTICATION_FAILED, None
else:
try:
Expand All @@ -482,6 +485,7 @@ def remote_refresh_token(self, token: str) -> (str, int):

token = content["token"]
expires = content["exp"]

return token, expires

def remote_patron_lookup(
Expand Down Expand Up @@ -532,8 +536,6 @@ def remote_patron_lookup(

return self._userinfo_to_patrondata(content)

return EKIRJASTO_REMOTE_AUTHENTICATION_FAILED

def remote_authenticate(
self, ekirjasto_token: str | None
) -> PatronData | ProblemDetail | None:
Expand Down Expand Up @@ -608,7 +610,7 @@ def local_patron_lookup(

def ekirjasto_authenticate(
self, _db: Session, ekirjasto_token: str
) -> (Patron, bool):
) -> Tuple[PatronData | Patron | ProblemDetail | None, bool]:
"""Authenticate patron with remote ekirjasto API and if necessary,
create authenticated patron if not in database.
Expand All @@ -620,17 +622,17 @@ def ekirjasto_authenticate(
log_method=self.logger().info,
message_prefix="authenticated_patron - ekirjasto_authenticate",
):
patron = self.authenticate_and_update_patron(_db, ekirjasto_token)
auth_result = self.authenticate_and_update_patron(_db, ekirjasto_token)

if isinstance(patron, PatronData):
if isinstance(auth_result, PatronData):
# We didn't find the patron, but authentication to external truth was
# succesfull, so we create a new patron with the information we have.
patron, is_new = patron.get_or_create_patron(
# successful, so we create a new patron with the information we have.
patron, is_new = auth_result.get_or_create_patron(
_db, self.library_id, analytics=self.analytics
)
patron.last_external_sync = utc_now()

return patron, is_new
return auth_result, is_new

def authenticated_patron(
self, _db: Session, authorization: dict | str
Expand Down
21 changes: 8 additions & 13 deletions api/ekirjasto_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,35 +34,30 @@ def __init__(self, circulation_manager, authenticator):

self._logger = logging.getLogger(__name__)

def _get_delegate_expire_timestamp(self, ekirjasto_token_expires: int) -> int:
def _get_delegate_expire_timestamp(self, ekirjasto_expire_millis: int) -> int:
"""Get the expire time to use for delegate token, it is calculated based on
expire time of the ekirjasto token.
:param ekirjasto_token_expires: Ekirjasto token expiration timestamp in milliseconds.
:return: Timestamp for the delegate token expiration.
:return: Timestamp for the delegate token expiration in seconds.
"""

# Ekirjasto expire is in milliseconds but JWT uses seconds.
ekirjasto_token_expires = int(ekirjasto_token_expires / 1000)

delegate_token_expires = (
delegate_expire_seconds = (
utc_now()
+ datetime.timedelta(
seconds=self._authenticator.ekirjasto_provider.delegate_expire_timemestamp
)
).timestamp()

# Use ekirjasto expire time at 70 % of the remaining duration, so we have some time to refresh it.
now_seconds = utc_now().timestamp()
ekirjasto_token_expires = (
ekirjasto_token_expires - now_seconds
) * 0.7 + now_seconds

if ekirjasto_token_expires < delegate_token_expires:
return int(ekirjasto_token_expires)
# Use ekirjasto expire time at 70 % of the remaining duration, so we have some time to refresh it.
ekirjasto_expire_seconds = (
(ekirjasto_expire_millis / 1000) - now_seconds
) * 0.7 + now_seconds

return int(delegate_token_expires)
return int(min(ekirjasto_expire_seconds, delegate_expire_seconds))

def get_tokens(self, authorization, validate_expire=False):
"""Extract possible delegate and ekirjasto tokens from the authorization header."""
Expand Down
4 changes: 2 additions & 2 deletions core/model/patron.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ class Patron(Base):
)

loan_checkouts: Mapped[List[LoanCheckout]] = relationship(
"LoanCheckout", backref="patron", cascade="delete", uselist=True
"LoanCheckout", back_populates="patron", cascade="delete", uselist=True
)

holds: Mapped[List[Hold]] = relationship(
Expand Down Expand Up @@ -586,7 +586,7 @@ class LoanCheckout(Base, LoanAndHoldMixin):
patron_id = Column(
Integer, ForeignKey("patrons.id", ondelete="CASCADE"), index=True
)
patron: Patron
patron: Mapped[Patron] = relationship("Patron", back_populates="loan_checkouts")

license_pool_id = Column(Integer, ForeignKey("licensepools.id"), index=True)
license_pool: Mapped[LicensePool] = relationship("LicensePool")
Expand Down
20 changes: 9 additions & 11 deletions tests/finland/test_ekirjasto.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def test_authentication_flow_document(
controller_fixture.app.config["SERVER_NAME"] = "localhost"

with controller_fixture.app.test_request_context("/"):
doc = provider._authentication_flow_document(None)
doc = provider._authentication_flow_document(controller_fixture.db.session)
assert provider.label() == doc["description"]
assert provider.flow_type == doc["type"]

Expand Down Expand Up @@ -311,19 +311,12 @@ def test_secrets(
):
provider = create_provider()

# Secrets are not set, so this will fail.
pytest.raises(InternalServerError, provider._check_secrets_or_throw)
assert provider.delegate_token_signing_secret == None
assert provider.delegate_token_encrypting_secret == None

provider.set_secrets(controller_fixture.db.session)

provider._check_secrets_or_throw()

assert provider.delegate_token_signing_secret != None
assert provider.delegate_token_encrypting_secret != None
assert provider.delegate_token_signing_secret is not None
assert provider.delegate_token_encrypting_secret is not None

# Screts should be strong enough.
# Secrets should be strong enough.
assert len(provider.delegate_token_signing_secret) > 30
assert len(provider.delegate_token_encrypting_secret) > 30

Expand Down Expand Up @@ -685,6 +678,9 @@ def test_authenticated_patron_delegate_token_expired(
decoded_payload = provider.validate_ekirjasto_delegate_token(
delegate_token, validate_expire=False
)

assert type(decoded_payload) is dict

patron = provider.authenticated_patron(
controller_fixture.db.session, decoded_payload
)
Expand Down Expand Up @@ -718,6 +714,8 @@ def test_authenticated_patron_ekirjasto_token_invald(
controller_fixture.db.session, patron
)

assert patron_delegate_id is not None

# Delegate token with the ekirjasto token.
delegate_token = provider.create_ekirjasto_delegate_token(
ekirjasto_token, patron_delegate_id, expires_at
Expand Down
2 changes: 1 addition & 1 deletion tests/finland/test_opensearch_analytics_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


class MockLibrary:
def short_name():
def short_name(self):
return "testlib"


Expand Down

0 comments on commit f47fdd4

Please sign in to comment.