Skip to content

Commit

Permalink
Add more heuristics for backwards JWT compat
Browse files Browse the repository at this point in the history
These additional heuristics help in case the calling application was
correctly calssifying key usage, as this is another valid hitn of what
the application intended.
Invalid key usage would already cause failure, so this does not affect
the countermeasures introduced but can avoid issues in older
applications.

Signed-off-by: Simo Sorce <simo@redhat.com>
  • Loading branch information
simo5 committed Sep 14, 2022
1 parent 84f121f commit c4e0bee
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 12 deletions.
88 changes: 76 additions & 12 deletions jwcrypto/jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
from jwcrypto.common import json_decode, json_encode
from jwcrypto.jwe import JWE
from jwcrypto.jwe import default_allowed_algs as jwe_algs
from jwcrypto.jwk import JWK, JWKSet
from jwcrypto.jws import JWS
from jwcrypto.jws import default_allowed_algs as jws_algs


# RFC 7519 - 4.1
Expand Down Expand Up @@ -296,26 +298,89 @@ def validity(self):
def validity(self, v):
self._validity = int(v)

@property
def expected_type(self):
if self._expected_type is not None:
return self._expected_type

# If no expected type is set we default to accept only JWSs,
# however to improve backwards compatibility we try some
# heuristic to see if there has been strong indication of
# what the expected token type is.
def _expected_type_heuristics(self, key=None):
if self._expected_type is None and self._algs:
if set(self._algs).issubset(jwe_algs + ['RSA1_5']):
self._expected_type = "JWE"
elif set(self._algs).issubset(jws_algs):
self._expected_type = "JES"
if self._expected_type is None and self._header:
if "enc" in json_decode(self._header):
self._expected_type = "JWE"
if self._expected_type is None and key is not None:
if isinstance(key, JWK):
use = key.get('use')
if use == 'sig':
self._expected_type = "JWS"
elif use == 'enc':
self._expected_type = "JWE"
elif isinstance(key, JWKSet):
all_use = None
# we can infer only if all keys are of the same type
for k in key:
use = k.get('use')
if all_use is None:
all_use = use
elif use != all_use:
all_use = None
break
if all_use == 'sig':
self._expected_type = "JWS"
elif all_use == 'enc':
self._expected_type = "JWE"
if self._expected_type is None and key is not None:
if isinstance(key, JWK):
ops = key.get('key_ops')
if ops:
if not isinstance(ops, list):
ops = [ops]
if set(ops).issubset(['sign', 'verify']):
self._expected_type = "JWS"
elif set(ops).issubset(['encrypt', 'decrypt']):
self._expected_type = "JWE"
elif isinstance(key, JWKSet):
all_ops = None
ttype = None
# we can infer only if all keys are of the same type
for k in key:
ops = k.get('key_ops')
if ops:
if not isinstance(ops, list):
ops = [ops]
if all_ops is None:
if set(ops).issubset(['sign', 'verify']):
all_ops = set(['sign', 'verify'])
ttype = "JWS"
elif set(ops).issubset(['encrypt', 'decrypt']):
all_ops = set(['encrypt', 'decrypt'])
ttype = "JWE"
else:
ttype = None
break
else:
if not set(ops).issubset(all_ops):
ttype = None
break
elif all_ops:
ttype = None
break
if ttype:
self._expected_type = ttype
if self._expected_type is None:
self._expected_type = "JWS"

return self._expected_type

@property
def expected_type(self):
if self._expected_type is not None:
return self._expected_type

# If no expected type is set we default to accept only JWSs,
# however to improve backwards compatibility we try some
# heuristic to see if there has been strong indication of
# what the expected token type is.
return self._expected_type_heuristics()

@expected_type.setter
def expected_type(self, v):
if v in ["JWS", "JWE"]:
Expand Down Expand Up @@ -549,7 +614,7 @@ def validate(self, key):
if self.token is None:
raise ValueError("Token empty")

et = self.expected_type
et = self._expected_type_heuristics(key)
validate_fn = None

if isinstance(self.token, JWS):
Expand All @@ -558,7 +623,6 @@ def validate(self, key):
validate_fn = self.token.verify
elif isinstance(self.token, JWE):
if et != "JWE" and JWT_expect_type:
print("algs: {}".format(self._algs))
raise TypeError("Expected {}, got JWE".format(et))
validate_fn = self.token.decrypt
else:
Expand Down
36 changes: 36 additions & 0 deletions jwcrypto/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1758,6 +1758,24 @@ def test_unexpected(self):
with self.assertRaises(TypeError):
jwt.JWT(jwt=sertok, key=key, expected_type='JWE')

key.use = 'sig'
jwt.JWT(jwt=sertok, key=key)
key.use = 'enc'
with self.assertRaises(TypeError):
jwt.JWT(jwt=sertok, key=key)
key.use = None
key.key_ops = 'verify'
jwt.JWT(jwt=sertok, key=key)
key.key_ops = ['sign', 'verify']
jwt.JWT(jwt=sertok, key=key)
key.key_ops = 'decrypt'
with self.assertRaises(TypeError):
jwt.JWT(jwt=sertok, key=key)
key.key_ops = ['encrypt', 'decrypt']
with self.assertRaises(TypeError):
jwt.JWT(jwt=sertok, key=key)
key.key_ops = None

token = jwt.JWT(header={"alg": "A256KW", "enc": "A256GCM"},
claims=claims)
token.make_encrypted_token(key)
Expand All @@ -1781,6 +1799,24 @@ def test_unexpected(self):
with self.assertRaises(TypeError):
jwt.JWT(jwt=enctok, key=key, expected_type='JWS')

key.use = 'enc'
jwt.JWT(jwt=enctok, key=key)
key.use = 'sig'
with self.assertRaises(TypeError):
jwt.JWT(jwt=enctok, key=key)
key.use = None
key.key_ops = 'verify'
with self.assertRaises(TypeError):
jwt.JWT(jwt=enctok, key=key)
key.key_ops = ['sign', 'verify']
with self.assertRaises(TypeError):
jwt.JWT(jwt=enctok, key=key)
key.key_ops = 'decrypt'
jwt.JWT(jwt=enctok, key=key)
key.key_ops = ['encrypt', 'decrypt']
jwt.JWT(jwt=enctok, key=key)
key.key_ops = None


class ConformanceTests(unittest.TestCase):

Expand Down

0 comments on commit c4e0bee

Please sign in to comment.