Skip to content

Commit

Permalink
Merge pull request #699 from OpenIDC/fix-mypy
Browse files Browse the repository at this point in the history
Fix updated mypy
  • Loading branch information
tpazderka authored Oct 25, 2019
2 parents 85f803b + ba81526 commit 2a949cb
Show file tree
Hide file tree
Showing 38 changed files with 235 additions and 128 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def run_tests(self):
'develop': ["cherrypy==3.2.4", "pyOpenSSL"],
'testing': tests_requires,
'docs': ['Sphinx', 'sphinx-autobuild', 'alabaster'],
'quality': ['pylama', 'isort', 'eradicate', 'mypy==0.730', 'black', 'bandit'],
'quality': ['pylama', 'isort', 'eradicate', 'mypy', 'black', 'bandit'],
'ldap_authn': ['pyldap'],
},
install_requires=[
Expand Down
6 changes: 4 additions & 2 deletions src/oic/extension/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,9 +387,11 @@ def parse_authz_response(self, query):
sformat="urlencoded",
keyjar=self.keyjar,
)
if aresp.type() == "ErrorResponse":
if isinstance(aresp, ErrorResponse):
logger.info("ErrorResponse: %s" % sanitize(aresp))
raise AuthzError(aresp.error)
raise AuthzError(
aresp.error # type: ignore # Messages have no classical attrs
)

logger.info("Aresp: %s" % sanitize(aresp))

Expand Down
2 changes: 1 addition & 1 deletion src/oic/extension/device_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def __init__(self, host):
# map between device_code and user_code
self.client_id2device = {} # type: Dict[str, str]
self.device2user = {} # type: Dict[str, str]
self.user_auth = {} # type: Dict[str, str]
self.user_auth = {} # type: Dict[str, bool]
self.device_code_expire_at = {} # type: Dict[str, int]
self.device_code_life_time = 900 # 15 minutes

Expand Down
7 changes: 4 additions & 3 deletions src/oic/extension/proof_of_possesion.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import base64
import json
import time
from typing import Mapping # noqa
from typing import Dict # noqa
from urllib.parse import parse_qs
from urllib.parse import parse_qsl

Expand All @@ -11,6 +11,7 @@

from oic.extension.signed_http_req import SignedHttpRequest
from oic.extension.signed_http_req import ValidationError
from oic.oauth2 import error_response
from oic.oic.message import AccessTokenRequest
from oic.oic.message import AccessTokenResponse
from oic.oic.provider import Provider
Expand All @@ -29,7 +30,7 @@ def __init__(self, *args, **kwargs):
super(PoPProvider, self).__init__(*args, **kwargs)

# mapping from signed pop token to access token in db
self.access_tokens = {} # type: Mapping[JWS, str]
self.access_tokens = {} # type: Dict[JWS, str]

def token_endpoint(self, dtype="urlencoded", **kwargs):
atr = AccessTokenRequest().deserialize(kwargs["request"], dtype)
Expand Down Expand Up @@ -76,7 +77,7 @@ def userinfo_endpoint(self, request, **kwargs):
strict_headers_verification=False,
)
except ValidationError:
return self._error_response(
return error_response(
"access_denied", descr="Could not verify proof of " "possession"
)

Expand Down
4 changes: 2 additions & 2 deletions src/oic/extension/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def __init__(
self.jwks_uri = jwks_uri
self.verify_ssl = verify_ssl
self.scopes.extend(kwargs.get("scopes", []))
self.keyjar = keyjar
self.keyjar = keyjar # type: KeyJar
if self.keyjar is None:
self.keyjar = KeyJar(verify_ssl=self.verify_ssl)

Expand All @@ -182,7 +182,7 @@ def __init__(
self.token_policy = {
"access_token": {},
"refresh_token": {},
} # type: Dict[str, Dict[str, str]]
} # type: Dict[str, Dict[str, Dict[str, int]]]
if lifetime_policy is None:
self.lifetime_policy = {
"access_token": {
Expand Down
9 changes: 6 additions & 3 deletions src/oic/oauth2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
from oic.utils.http_util import Response
from oic.utils.http_util import SeeOther
from oic.utils.keyio import KeyJar
from oic.utils.sdb import SessionBackend # noqa
from oic.utils.sdb import session_update
from oic.utils.time_util import utc_time_sans_frac

Expand Down Expand Up @@ -196,17 +197,19 @@ def __init__(
timeout=timeout,
)

self.sso_db = None
self.sso_db = None # type: Optional[SessionBackend]
self.client_id = client_id
self.client_authn_method = client_authn_method

self.nonce = None
self.nonce = None # type: Optional[str]

self.message_factory = message_factory
self.grant = {} # type: Dict[str, Grant]
self.state2nonce = {} # type: Dict[str, str]
# own endpoint
self.redirect_uris = [] # type: List[str]
# Default behaviour
self.response_type = ["code"]

# service endpoints
self.authorization_endpoint = None # type: Optional[str]
Expand Down Expand Up @@ -651,7 +654,7 @@ def parse_response(
except KeyError:
self.grant[_state] = self.grant_class(resp=resp)

if "id_token" in resp and self.sso_db:
if "id_token" in resp and self.sso_db is not None:
session_update(self.sso_db, _state, "sub", resp["id_token"]["sub"])
session_update(self.sso_db, _state, "issuer", resp["id_token"]["iss"])
if "sid" in resp["id_token"]:
Expand Down
2 changes: 1 addition & 1 deletion src/oic/oauth2/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def _cookies(self):
"""Turn cookiejar into a dict."""
cookie_dict = {}

for _, a in list(self.cookiejar._cookies.items()):
for _, a in list(self.cookiejar._cookies.items()): # type: ignore
for _, b in list(a.items()):
for cookie in list(b.values()):
cookie_dict[cookie.name] = cookie.value
Expand Down
5 changes: 1 addition & 4 deletions src/oic/oauth2/consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,10 +222,7 @@ def begin(self, baseurl, request, response_type="", **kwargs):
self.sdb["seed:%s" % self.seed] = sid

if not response_type:
if self.response_type:
response_type = self.response_type
else:
self.response_type = response_type = "code"
response_type = self.response_type

location = self.request_info(
AuthorizationRequest,
Expand Down
4 changes: 2 additions & 2 deletions src/oic/oauth2/grant.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@ def __init__(self, exp_in=600, resp=None, seed=""):
self.grant_expiration_time = 0
self.exp_in = exp_in
self.seed = seed
self.tokens = [] # type: List[Token]
self.tokens = []
self.id_token = None
self.code = None # type: str
self.code = None # type: Optional[str]
if resp:
self.add_code(resp)
self.add_token(resp)
Expand Down
5 changes: 3 additions & 2 deletions src/oic/oauth2/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from collections import namedtuple
from json import JSONDecodeError
from typing import Any # noqa - This is used for MyPy
from typing import Dict # noqa - This is used for MyPy
from typing import List # noqa - This is used for MyPy
from typing import Mapping # noqa - This is used for MyPy
from typing import Optional # noqa - This is used for MyPy
Expand Down Expand Up @@ -135,7 +136,7 @@ def jwt_header(txt):

class Message(MutableMapping):
c_param = {} # type: Mapping[str, ParamDefinition]
c_default = {} # type: Mapping[str, Any]
c_default = {} # type: Dict[str, Any]
c_allowed_values = {} # type: ignore

def __init__(self, **kwargs):
Expand Down Expand Up @@ -266,7 +267,7 @@ def from_urlencoded(self, urlencoded, **kwargs):
cparam = self._extract_cparam(key, _spec)
if cparam is None:
if len(val) == 1:
val = val[0]
val = val[0] # type: ignore

self._dict[key] = val
continue
Expand Down
2 changes: 1 addition & 1 deletion src/oic/oauth2/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ def __init__(

if capabilities:
self.verify_capabilities(capabilities)
self.capabilities = self.message_factory.get_response_type(
self.capabilities = message_factory.get_response_type(
"configuration_endpoint"
)(**capabilities)
else:
Expand Down
3 changes: 2 additions & 1 deletion src/oic/oic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ def __init__(
for endpoint in ENDPOINTS:
setattr(self, endpoint, "")

self.id_token = None
self.id_token = {} # type: Dict[str, Token]
self.log = None

self.request2endpoint = REQUEST2ENDPOINT
Expand All @@ -367,6 +367,7 @@ def __init__(
self.client_prefs = client_prefs or {}

self.behaviour = {} # type: Dict[str, Any]
self.scope = ["openid"]

self.wf = WebFinger(OIC_ISSUER)
self.wf.httpd = self
Expand Down
2 changes: 1 addition & 1 deletion src/oic/oic/consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def __init__(
"Please use `SessionBackend` to ensure proper API for the database.",
DeprecationWarning,
)
self.sso_db = sso_db
self.sso_db = sso_db # type: SessionBackend
else:
self.sso_db = DictSessionBackend()

Expand Down
5 changes: 4 additions & 1 deletion src/oic/oic/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -1700,7 +1700,10 @@ def read_registration(self, authn, request, **kwargs):

# Get client_id from request
_info = parse_qs(request)
client_id = _info.get("client_id", [None])[0]
cid = _info.get("client_id")
if cid is None:
return Unauthorized()
client_id = cid[0]

cdb_entry = self.cdb.get(client_id)
if cdb_entry is None:
Expand Down
21 changes: 20 additions & 1 deletion src/oic/utils/aes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,16 @@
import os
from base64 import b64decode
from base64 import b64encode
from typing import Union # noqa
from typing import cast

from Cryptodome import Random
from Cryptodome.Cipher import AES
from Cryptodome.Cipher._mode_ccm import CcmMode # noqa
from Cryptodome.Cipher._mode_eax import EaxMode # noqa
from Cryptodome.Cipher._mode_gcm import GcmMode # noqa
from Cryptodome.Cipher._mode_ocb import OcbMode # noqa
from Cryptodome.Cipher._mode_siv import SivMode # noqa

from oic.utils import tobytes

Expand Down Expand Up @@ -139,9 +146,21 @@ def __init__(self, key, iv, mode=AES.MODE_SIV):
assert isinstance(key, bytes)
assert isinstance(iv, bytes)
self.key = key
# The code is written in such a way, that only these modes are actually supported
# The other ones are missing `encrypt_and_digest`, `decrypt_and_verify` and `update` methods
assert mode in (
AES.MODE_CCM,
AES.MODE_EAX,
AES.MODE_GCM,
AES.MODE_SIV,
AES.MODE_OCB,
)
self.mode = mode
self.iv = iv
self.kernel = AES.new(self.key, self.mode, self.iv)
self.kernel = cast(
Union[CcmMode, EaxMode, GcmMode, SivMode, OcbMode],
AES.new(self.key, self.mode, self.iv),
)

def add_associated_data(self, data):
"""
Expand Down
35 changes: 20 additions & 15 deletions src/oic/utils/authn/saml.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import logging
from typing import Dict # noqa
from typing import List # noqa
from typing import Mapping # noqa
from urllib.parse import parse_qs
from urllib.parse import urlencode

Expand Down Expand Up @@ -68,7 +67,7 @@ def __init__(
self.userinfo = userinfo

if cache is None:
self.cache_outstanding_queries = {} # type: Mapping[str, str]
self.cache_outstanding_queries = {} # type: Dict[str, str]
else:
self.cache_outstanding_queries = cache
UserAuthnMethod.__init__(self, srv)
Expand All @@ -86,12 +85,12 @@ def __init__(
self.verification_endpoint = ""
# Configurations for the SP handler.
self.sp_conf = importlib.import_module(spconf)
config = SPConfig().load(self.sp_conf.CONFIG)
config = SPConfig().load(self.sp_conf.CONFIG) # type: ignore
self.sp = Saml2Client(config=config)
mte = lookup.get_template("unauthorized.mako")
argv = {"message": "You are not authorized!"}
self.not_authorized = mte.render(**argv)
self.samlcache = self.sp_conf.SAML_CACHE
self.samlcache = self.sp_conf.SAML_CACHE # type: ignore

def __call__(self, query="", end_point_index=None, *args, **kwargs):

Expand Down Expand Up @@ -173,8 +172,8 @@ def verify(self, request, cookie, path, requrl, end_point_index=None, **kwargs):
logger.error("Other error: %s" % (err,))
return Unauthorized(self.not_authorized), False

if self.sp_conf.VALID_ATTRIBUTE_RESPONSE is not None:
for k, v in self.sp_conf.VALID_ATTRIBUTE_RESPONSE.items():
if self.sp_conf.VALID_ATTRIBUTE_RESPONSE is not None: # type: ignore
for k, v in self.sp_conf.VALID_ATTRIBUTE_RESPONSE.items(): # type: ignore
if k not in response.ava:
return Unauthorized(self.not_authorized), False
else:
Expand Down Expand Up @@ -211,8 +210,11 @@ def verify(self, request, cookie, path, requrl, end_point_index=None, **kwargs):

def setup_userdb(self, uid, samldata):
attributes = {}
if self.sp_conf.ATTRIBUTE_WHITELIST is not None:
for attr, allowed in self.sp_conf.ATTRIBUTE_WHITELIST.items():
if self.sp_conf.ATTRIBUTE_WHITELIST is not None: # type: ignore
for (
attr,
allowed,
) in self.sp_conf.ATTRIBUTE_WHITELIST.items(): # type: ignore
if attr in samldata:
if allowed is not None:
tmp_attr_list = []
Expand All @@ -228,10 +230,10 @@ def setup_userdb(self, uid, samldata):
attributes = samldata
userdb = {} # type: Dict[str, List[str]]

if self.sp_conf.OPENID2SAMLMAP is None:
if self.sp_conf.OPENID2SAMLMAP is None: # type: ignore
userdb = attributes.copy()
else:
for oic, saml in self.sp_conf.OPENID2SAMLMAP.items():
for oic, saml in self.sp_conf.OPENID2SAMLMAP.items(): # type: ignore
if saml in attributes:
userdb[oic] = attributes[saml]
self.userdb[uid] = userdb
Expand Down Expand Up @@ -273,14 +275,14 @@ def _pick_idp(self, query, end_point_index):
'{"'
+ self.CONST_QUERY
+ '": "'
+ base64.b64encode(query)
+ base64.b64encode(query).decode()
+ '" , "'
+ self.CONST_HASIDP
+ '": "False" }',
self.CONST_SAML_COOKIE,
self.CONST_SAML_COOKIE,
)
if self.sp_conf.WAYF:
if self.sp_conf.WAYF: # type: ignore
if query:
try:
wayf_selected = query_dict["wayf_selected"][0]
Expand All @@ -289,7 +291,7 @@ def _pick_idp(self, query, end_point_index):
idp_entity_id = wayf_selected
else:
return self._wayf_redirect(cookie)
elif self.sp_conf.DISCOSRV:
elif self.sp_conf.DISCOSRV: # type: ignore
if query:
idp_entity_id = _cli.parse_discovery_service_response(query=query)
if not idp_entity_id:
Expand All @@ -303,7 +305,7 @@ def _pick_idp(self, query, end_point_index):
][0]
ret += "?sid=%s" % sid_
loc = _cli.create_discovery_service_request(
self.sp_conf.DISCOSRV, eid, **{"return": ret}
self.sp_conf.DISCOSRV, eid, **{"return": ret} # type: ignore
)
return -1, SeeOther(loc, headers=[cookie])
elif not len(idps):
Expand All @@ -320,7 +322,10 @@ def _wayf_redirect(self, cookie):
return (
-1,
SeeOther(
headers=[("Location", "%s?%s" % (self.sp_conf.WAYF, sid_)), cookie]
headers=[
("Location", "%s?%s" % (self.sp_conf.WAYF, sid_)), # type: ignore
cookie,
]
),
)

Expand Down
Loading

0 comments on commit 2a949cb

Please sign in to comment.