From c3e53e1745cc01419118b048c2e706f151a0ca5e Mon Sep 17 00:00:00 2001 From: David Otte Date: Mon, 11 Mar 2024 12:32:45 +0100 Subject: [PATCH] Add comments and improve code for error raising --- tabpfn_client/client.py | 64 ++++++++++++++----------- tabpfn_client/tabpfn_classifier.py | 1 + tabpfn_client/tests/unit/test_client.py | 33 +++++-------- 3 files changed, 48 insertions(+), 50 deletions(-) diff --git a/tabpfn_client/client.py b/tabpfn_client/client.py index 4800f8b..a84511b 100644 --- a/tabpfn_client/client.py +++ b/tabpfn_client/client.py @@ -18,8 +18,8 @@ def get_client_version() -> str: try: return version('tabpfn_client') except PackageNotFoundError: - # Package not found, should only happen during development. Simply return a version number that works - # for development. + # Package not found, should only happen during development. Execute 'pip install -e .' to use the actual + # version number during development. Otherwise, simply return a version number that is large enough. return '5.5.5' @@ -90,7 +90,7 @@ def upload_train_set(self, X, y) -> str: ]) ) - self.error_raising(response, "upload_train_set") + self._validate_response(response, "upload_train_set") train_set_uid = response.json()["train_set_uid"] return train_set_uid @@ -122,22 +122,30 @@ def predict(self, train_set_uid: str, x_test): ]) ) - self.error_raising(response, "predict") + self._validate_response(response, "predict") return np.array(response.json()["y_pred"]) - def error_raising(self, response, method_name, only_version_check=False): - if response.status_code != 200: - load = None - try: - load = response.json() - except Exception: - pass - if (response.status_code == 403 and load and load.get("detail").startswith("Client version too old") - or response.status_code == 400 and load and load.get("detail").startswith("Client version")): - raise RuntimeError(load.get("detail")) - if only_version_check: - return + @staticmethod + def _validate_response(response, method_name, only_version_check=False): + # If status code is 200, no errors occurred on the server side. + if response.status_code == 200: + return + + # Read response. + load = None + try: + load = response.json() + except Exception: + pass + + # Check if the server requires a newer client version. + if response.status_code == 426: + logger.error(f"Fail to call {method_name}, response status: {response.status_code}") + raise RuntimeError(load.get("detail")) + + # If we not only want to check the version compatibility, also raise other errors. + if not only_version_check: if load is not None: raise RuntimeError(f"Fail to call {method_name} with error: {load}") logger.error(f"Fail to call {method_name}, response status: {response.status_code}") @@ -172,7 +180,7 @@ def predict_proba(self, train_set_uid: str, x_test): ]) ) - self.error_raising(response, "predict_proba") + self._validate_response(response, "predict_proba") return np.array(response.json()["y_pred_proba"]) @@ -183,7 +191,7 @@ def try_connection(self) -> bool: found_valid_connection = False try: response = self.httpx_client.get(self.server_endpoints.root.path) - self.error_raising(response, "try_connection", only_version_check=True) + self._validate_response(response, "try_connection", only_version_check=True) if response.status_code == 200: found_valid_connection = True @@ -202,7 +210,7 @@ def try_authenticate(self, access_token) -> bool: headers={"Authorization": f"Bearer {access_token}"}, ) - self.error_raising(response, "try_authenticate", only_version_check=True) + self._validate_response(response, "try_authenticate", only_version_check=True) if response.status_code == 200: is_authenticated = True @@ -239,7 +247,7 @@ def register( params={"email": email, "password": password, "password_confirm": password_confirm, "validation_link": validation_link} ) - self.error_raising(response, "register", only_version_check=True) + self._validate_response(response, "register", only_version_check=True) if response.status_code == 200: is_created = True message = response.json()["message"] @@ -270,7 +278,7 @@ def login(self, email: str, password: str) -> str | None: data=common_utils.to_oauth_request_form(email, password) ) - self.error_raising(response, "login", only_version_check=True) + self._validate_response(response, "login", only_version_check=True) if response.status_code == 200: access_token = response.json()["access_token"] @@ -289,7 +297,7 @@ def get_password_policy(self) -> {}: response = self.httpx_client.get( self.server_endpoints.password_policy.path, ) - self.error_raising(response, "get_password_policy", only_version_check=True) + self._validate_response(response, "get_password_policy", only_version_check=True) return response.json()["requirements"] @@ -299,7 +307,7 @@ def retrieve_greeting_messages(self) -> list[str]: """ response = self.httpx_client.get(self.server_endpoints.retrieve_greeting_messages.path) - self.error_raising(response, "retrieve_greeting_messages", only_version_check=True) + self._validate_response(response, "retrieve_greeting_messages", only_version_check=True) if response.status_code != 200: return [] @@ -318,7 +326,7 @@ def get_data_summary(self) -> {}: response = self.httpx_client.get( self.server_endpoints.get_data_summary.path, ) - self.error_raising(response, "get_data_summary") + self._validate_response(response, "get_data_summary") return response.json() @@ -337,7 +345,7 @@ def download_all_data(self, save_dir: Path) -> Path | None: full_url = self.base_url + self.server_endpoints.download_all_data.path with httpx.stream("GET", full_url, headers={"Authorization": f"Bearer {self.access_token}"}) as response: - self.error_raising(response, "download_all_data") + self._validate_response(response, "download_all_data") filename = response.headers["Content-Disposition"].split("filename=")[1] save_path = Path(save_dir) / filename @@ -368,7 +376,7 @@ def delete_dataset(self, dataset_uid: str) -> [str]: params={"dataset_uid": dataset_uid} ) - self.error_raising(response, "delete_dataset") + self._validate_response(response, "delete_dataset") return response.json()["deleted_dataset_uids"] @@ -385,7 +393,7 @@ def delete_all_datasets(self) -> [str]: self.server_endpoints.delete_all_datasets.path, ) - self.error_raising(response, "delete_all_datasets") + self._validate_response(response, "delete_all_datasets") return response.json()["deleted_dataset_uids"] @@ -395,4 +403,4 @@ def delete_user_account(self, confirm_pass: str) -> None: params={"confirm_password": confirm_pass} ) - self.error_raising(response, "delete_user_account") + self._validate_response(response, "delete_user_account") diff --git a/tabpfn_client/tabpfn_classifier.py b/tabpfn_client/tabpfn_classifier.py index beb3db9..362e297 100644 --- a/tabpfn_client/tabpfn_classifier.py +++ b/tabpfn_client/tabpfn_classifier.py @@ -52,6 +52,7 @@ def init(use_server=True): # prompt for login / register PromptAgent.prompt_and_set_token(user_auth_handler) + # Print new greeting messages. If there are no new messages, nothing will be printed. PromptAgent.prompt_retrieved_greeting_messages(user_auth_handler.retrieve_greeting_messages()) g_tabpfn_config.use_server = True diff --git a/tabpfn_client/tests/unit/test_client.py b/tabpfn_client/tests/unit/test_client.py index 2386d04..9c14318 100644 --- a/tabpfn_client/tests/unit/test_client.py +++ b/tabpfn_client/tests/unit/test_client.py @@ -31,12 +31,7 @@ def test_try_connection_with_invalid_server(self, mock_server): @with_mock_server() def test_try_connection_with_outdated_client(self, mock_server): mock_server.router.get(mock_server.endpoints.root.path).respond( - 400, json={"detail": "Client version header missing. Please make sure to use the ..."}) - with self.assertRaises(RuntimeError) as cm: - self.client.try_connection() - self.assertTrue(str(cm.exception).startswith("Client version header missing")) - mock_server.router.get(mock_server.endpoints.root.path).respond( - 403, json={"detail": "Client version too old. ..."}) + 426, json={"detail": "Client version too old. ..."}) with self.assertRaises(RuntimeError) as cm: self.client.try_connection() self.assertTrue(str(cm.exception).startswith("Client version too old.")) @@ -95,46 +90,40 @@ def test_predict_with_valid_train_set_and_test_set(self, mock_server): ) self.assertTrue(np.array_equal(pred, dummy_result["y_pred"])) - def test_error_raising_no_error(self): + def test_validate_response_no_error(self): response = Mock() response.status_code = 200 - r = self.client.error_raising(response, "test") + r = self.client._validate_response(response, "test") self.assertIsNone(r) - def test_error_raising(self): + def test_validate_response(self): response = Mock() # Test for "Client version too old." error - response.status_code = 403 + response.status_code = 426 response.json.return_value = {"detail": "Client version too old."} with self.assertRaises(RuntimeError) as cm: - self.client.error_raising(response, "test") + self.client._validate_response(response, "test") self.assertEqual(str(cm.exception), "Client version too old.") # Test for "Some other error" which is translated to a generic failure message response.status_code = 400 response.json.return_value = {"detail": "Some other error"} with self.assertRaises(RuntimeError) as cm: - self.client.error_raising(response, "test") + self.client._validate_response(response, "test") self.assertTrue(str(cm.exception).startswith("Fail to call test")) - def test_error_raising_only_version_check(self): + def test_validate_response_only_version_check(self): response = Mock() - response.status_code = 403 + response.status_code = 426 response.json.return_value = {"detail": "Client version too old."} with self.assertRaises(RuntimeError) as cm: - self.client.error_raising(response, "test", only_version_check=True) + self.client._validate_response(response, "test", only_version_check=True) self.assertEqual(str(cm.exception), "Client version too old.") - response.status_code = 400 - response.json.return_value = {"detail": "Client version header missing."} - with self.assertRaises(RuntimeError) as cm: - self.client.error_raising(response, "test", only_version_check=True) - self.assertEqual(str(cm.exception), "Client version header missing.") - # Errors that have nothing to do with client version should be skipped. response = Mock() response.status_code = 400 response.json.return_value = {"detail": "Some other error"} - r = self.client.error_raising(response, "test", only_version_check=True) + r = self.client._validate_response(response, "test", only_version_check=True) self.assertIsNone(r)