Commit cce3648b authored by Mihai Patrascoiu's avatar Mihai Patrascoiu
Browse files

FTS-1674: Rework the offline validation decoding. A list of filtered keys is...

FTS-1674: Rework the offline validation decoding. A list of filtered keys is returned which will be tried until the first key that works is found
parent c4600296
Pipeline #2026355 passed with stage
in 2 minutes and 5 seconds
...@@ -271,38 +271,55 @@ class FTS3OAuth2ResourceProvider(ResourceProvider): ...@@ -271,38 +271,55 @@ class FTS3OAuth2ResourceProvider(ResourceProvider):
:param access_token: :param access_token:
:return: tuple(valid, credential) or tuple(False, None) :return: tuple(valid, credential) or tuple(False, None)
""" """
def decode(key):
log.debug('Attempt decoding using key={}'.format(key.export()))
try:
if 'wlcg' in issuer:
audience = 'https://wlcg.cern.ch/jwt/v1/any'
credential = jwt.decode(access_token,
key.export_to_pem(),
algorithms=[algorithm],
audience=audience
)
else:
# We don't check audience for non-WLCG token
credential = jwt.decode(access_token,
key.export_to_pem(),
algorithms=[algorithm],
options={'verify_aud': False}
)
return credential
except Exception:
return None
log.debug('entered validate_token_offline') log.debug('entered validate_token_offline')
credential = None
try: try:
unverified_payload = jwt.decode(access_token, verify=False) unverified_payload = jwt.decode(access_token, verify=False)
unverified_header = jwt.get_unverified_header(access_token) unverified_header = jwt.get_unverified_header(access_token)
issuer = unverified_payload['iss'] issuer = unverified_payload['iss']
key_id = unverified_header['kid'] key_id = unverified_header.get('kid')
log.debug('issuer={}, key_id={}'.format(issuer, key_id)) algorithm = unverified_header.get('alg')
log.debug('issuer={}, key_id={}, alg={}'.format(issuer, key_id, algorithm))
algorithm = unverified_header.get('alg', 'RS256') # Retrieval of keys
log.debug('alg={}'.format(algorithm)) keys = oidc_manager.filter_provider_keys(issuer, key_id, algorithm)
jwkeys = [JWK.from_json(json.dumps(key.to_dict())) for key in keys]
pub_key = oidc_manager.get_provider_key(issuer, key_id)
log.debug('key={}'.format(pub_key)) # Find the first key which decodes the token
pub_key = JWK.from_json(json.dumps(pub_key.to_dict())) for jwkey in jwkeys:
log.debug('pubkey={}'.format(pub_key)) credential = decode(jwkey)
# Verify & Validate if credential is not None:
if 'wlcg' in issuer: log.debug('offline_response::: {}'.format(credential))
audience = 'https://wlcg.cern.ch/jwt/v1/any' break
credential = jwt.decode(access_token, pub_key.export_to_pem(), algorithms=[algorithm], audience=audience)
else:
credential = jwt.decode(access_token,
pub_key.export_to_pem(),
algorithms=[algorithm],
options={'verify_aud': False} # We don't check audience for non-WLCG token)
)
log.debug('offline_response::: {}'.format(credential))
except Exception as ex: except Exception as ex:
log.debug(ex) log.debug('return False, Exception: {}'.format(ex))
log.debug('return False (exception)')
return False, None return False, None
log.debug('return True, credential')
return True, credential if credential is None:
log.debug('No key managed to decode the token')
log.debug('return {}, credential'.format(credential is not None))
return (credential is not None), credential
def _validate_token_online(self, access_token): def _validate_token_online(self, access_token):
""" """
......
...@@ -55,20 +55,25 @@ class OIDCmanager: ...@@ -55,20 +55,25 @@ class OIDCmanager:
for keybundle in keybundles: for keybundle in keybundles:
keybundle.cache_time = cache_time keybundle.cache_time = cache_time
def get_provider_key(self, issuer, kid): def filter_provider_keys(self, issuer, kid=None, alg=None):
""" """
Get a Provider Key by ID Return Provider Keys after applying Key ID and Algorithm filter.
If no filters match, return the full set.
:param issuer: provider :param issuer: provider
:param kid: Key ID :param kid: Key ID
:return: key :param alg: Algorithm
:raise ValueError: if key not found :return: keys
:raise ValueError: client could not be retrieved
""" """
client = self.clients[issuer] client = self.clients.get(issuer)
keys = client.keyjar.get_issuer_keys(issuer) # List of Keys (from pyjwkest) if client is None:
for key in keys: raise ValueError('Could not retrieve client for issuer={}'.format(issuer))
if key.kid == kid: # List of Keys (from pyjwkest)
return key keys = client.keyjar.get_issuer_keys(issuer)
raise ValueError("Key with kid {} not found".format(kid)) filtered_keys = [key for key in keys if key.kid == kid or key.alg == alg]
if len(filtered_keys) is 0:
return keys
return filtered_keys
def introspect(self, issuer, access_token): def introspect(self, issuer, access_token):
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment