From 04b0c301978bf2381c8d7103ab041ac54c393914 Mon Sep 17 00:00:00 2001 From: Jermiah Date: Sat, 18 Nov 2023 22:07:53 +0000 Subject: [PATCH] modify auth --- src/auth.py | 81 ++++++++++++++++++++----------------- src/nbia.py | 3 +- src/tests/test_auth.py | 13 ++++-- src/utils/nbia_endpoints.py | 2 +- 4 files changed, 55 insertions(+), 44 deletions(-) diff --git a/src/auth.py b/src/auth.py index 3898053..a7b8447 100644 --- a/src/auth.py +++ b/src/auth.py @@ -9,48 +9,53 @@ def __init__(self, self.username = username self.password = password self.access_token = None - self.api_headers = None #username=nbia_guest&password=&client_id=NBIA&grant_type=password def getToken(self): # Check if the access token is valid and not expired - if self.access_token is not None and self.access_token != 401: - return self.access_token - elif self.access_token == 401: - print("Failed to get access token. Status code: 401") - else: - # Prepare the request data4 - data = { - 'username': self.username, - 'password': self.password, - 'client_id': self.client_id, - 'grant_type': 'password' - } - token_url = 'https://services.cancerimagingarchive.net/nbia-api/oauth/token' - # Make a POST request to the token endpoint - response = requests.post(token_url, data=data) - # response.raise_for_status() + if self.access_token is not None: + return 401 if self.access_token == 401 else self.access_token + + # Prepare the request data4 + data = { + 'username': self.username, + 'password': self.password, + 'client_id': self.client_id, + 'grant_type': 'password' + } + token_url = 'https://services.cancerimagingarchive.net/nbia-api/oauth/token' + + response = requests.post(token_url, data=data) + + try: + response.raise_for_status() + except requests.exceptions.HTTPError as e: + print(f"HTTP Error occurred: {e}") + print(f"Failed to get access token. Status code: \ + {response.status_code}") - # Check if the request was successful - if response.status_code == 200: - token_data = response.json() - self.access_token = token_data.get('access_token') - # save access token to self for later use - # self.api_headers = { - # 'Authorization': f'Bearer {self.access_token}', - # 'Accept': 'application/json' - # } - # TODO::implement refresh token functionality - self.expiry_time = time.ctime( - time.time() + token_data.get('expires_in') ) - self.refresh_token = token_data.get('refresh_token') - self.refresh_expiry = token_data.get('refresh_expires_in') - self.scope = token_data.get('scope') - return self.access_token - else: - print(f"Failed to get access token. Status code: \ - {response.status_code}") + self.access_token = response.status_code + return response.status_code + + token_data = response.json() + self.access_token = token_data.get('access_token') + + self.api_headers = { + 'Authorization': f'Bearer {self.access_token}', + 'Accept': 'application/json' + } + + # TODO::implement refresh token functionality + self.expiry_time = time.ctime( + time.time() + token_data.get('expires_in') ) + self.refresh_token = token_data.get('refresh_token') + self.refresh_expiry = token_data.get('refresh_expires_in') + self.scope = token_data.get('scope') + return self.api_headers - self.access_token = response.status_code - return response.status_code + + # def logout(self): + # # Request for logout + # # curl -X -v -d "Authorization:Bearer YOUR_ACCESS_TOKEN" -k "https://services.cancerimagingarchive.net/nbia-api/logout" + \ No newline at end of file diff --git a/src/nbia.py b/src/nbia.py index 0c32ed8..33ebe6b 100644 --- a/src/nbia.py +++ b/src/nbia.py @@ -17,8 +17,7 @@ def __init__(self, # Setup OAuth2 client self.logger.info("Setting up OAuth2 client... with username %s", username) self._oauth2_client = OAuth2(username=username, password=password) - self.api_headers = {'Authorization': f'Bearer {self._oauth2_client.getToken()}'} - + self.api_headers = self._oauth2_client.getToken() def query_api(self, endpoint: NBIA_ENDPOINTS, params: dict = {}) -> dict: base_url = "https://services.cancerimagingarchive.net/nbia-api/services/" diff --git a/src/tests/test_auth.py b/src/tests/test_auth.py index 2ac3041..a5c60c7 100644 --- a/src/tests/test_auth.py +++ b/src/tests/test_auth.py @@ -21,7 +21,7 @@ def failed_oauth2(): return oauth def test_getToken(oauth2): - assert (oauth2.access_token is not None and oauth2.access_token != 401) + assert oauth2.access_token is not None def test_expiry(oauth2): # expiry should be in the form of :'Tue Jun 29 13:58:57 2077' @@ -39,6 +39,13 @@ def test_failed_oauth(failed_oauth2,capsys): def test_failed_oauth_retried(failed_oauth2,capsys): failed_oauth2.getToken() captured = capsys.readouterr() - assert captured.out == "Failed to get access token. Status code: 401\n" assert failed_oauth2.access_token == 401 - \ No newline at end of file + +def test_getToken_valid_token(oauth2): + # Test if the access token is valid and not expired + assert oauth2.getToken() == oauth2.access_token + +def test_getToken_failed_token(failed_oauth2, capsys): + # Test if the access token retrieval fails with incorrect credentials + assert failed_oauth2.getToken() == 401 + captured = capsys.readouterr() diff --git a/src/utils/nbia_endpoints.py b/src/utils/nbia_endpoints.py index 71e08df..c401fe9 100644 --- a/src/utils/nbia_endpoints.py +++ b/src/utils/nbia_endpoints.py @@ -10,7 +10,7 @@ class NBIA_ENDPOINTS(Enum): GET_COLLECTIONS = 'v2/getCollectionValues' GET_BODY_PART_PATIENT_COUNT = 'getBodyPartValuesAndCounts' GET_PATIENT_BY_COLLECTION_AND_MODALITY = 'v2/getPatientByCollectionAndModality' - + # Helper functions def __str__(self):