Skip to content

Commit

Permalink
Add comments and improve code for error raising
Browse files Browse the repository at this point in the history
  • Loading branch information
davidotte committed Mar 11, 2024
1 parent 68e6d5f commit c3e53e1
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 50 deletions.
64 changes: 36 additions & 28 deletions tabpfn_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ def get_client_version() -> str:
try:
return version('tabpfn_client')
except PackageNotFoundError:
# Package not found, should only happen during development. Simply return a version number that works
# for development.
# Package not found, should only happen during development. Execute 'pip install -e .' to use the actual
# version number during development. Otherwise, simply return a version number that is large enough.
return '5.5.5'


Expand Down Expand Up @@ -90,7 +90,7 @@ def upload_train_set(self, X, y) -> str:
])
)

self.error_raising(response, "upload_train_set")
self._validate_response(response, "upload_train_set")

train_set_uid = response.json()["train_set_uid"]
return train_set_uid
Expand Down Expand Up @@ -122,22 +122,30 @@ def predict(self, train_set_uid: str, x_test):
])
)

self.error_raising(response, "predict")
self._validate_response(response, "predict")

return np.array(response.json()["y_pred"])

def error_raising(self, response, method_name, only_version_check=False):
if response.status_code != 200:
load = None
try:
load = response.json()
except Exception:
pass
if (response.status_code == 403 and load and load.get("detail").startswith("Client version too old")
or response.status_code == 400 and load and load.get("detail").startswith("Client version")):
raise RuntimeError(load.get("detail"))
if only_version_check:
return
@staticmethod
def _validate_response(response, method_name, only_version_check=False):
# If status code is 200, no errors occurred on the server side.
if response.status_code == 200:
return

# Read response.
load = None
try:
load = response.json()
except Exception:
pass

# Check if the server requires a newer client version.
if response.status_code == 426:
logger.error(f"Fail to call {method_name}, response status: {response.status_code}")
raise RuntimeError(load.get("detail"))

# If we not only want to check the version compatibility, also raise other errors.
if not only_version_check:
if load is not None:
raise RuntimeError(f"Fail to call {method_name} with error: {load}")
logger.error(f"Fail to call {method_name}, response status: {response.status_code}")
Expand Down Expand Up @@ -172,7 +180,7 @@ def predict_proba(self, train_set_uid: str, x_test):
])
)

self.error_raising(response, "predict_proba")
self._validate_response(response, "predict_proba")

return np.array(response.json()["y_pred_proba"])

Expand All @@ -183,7 +191,7 @@ def try_connection(self) -> bool:
found_valid_connection = False
try:
response = self.httpx_client.get(self.server_endpoints.root.path)
self.error_raising(response, "try_connection", only_version_check=True)
self._validate_response(response, "try_connection", only_version_check=True)
if response.status_code == 200:
found_valid_connection = True

Expand All @@ -202,7 +210,7 @@ def try_authenticate(self, access_token) -> bool:
headers={"Authorization": f"Bearer {access_token}"},
)

self.error_raising(response, "try_authenticate", only_version_check=True)
self._validate_response(response, "try_authenticate", only_version_check=True)

if response.status_code == 200:
is_authenticated = True
Expand Down Expand Up @@ -239,7 +247,7 @@ def register(
params={"email": email, "password": password, "password_confirm": password_confirm, "validation_link": validation_link}
)

self.error_raising(response, "register", only_version_check=True)
self._validate_response(response, "register", only_version_check=True)
if response.status_code == 200:
is_created = True
message = response.json()["message"]
Expand Down Expand Up @@ -270,7 +278,7 @@ def login(self, email: str, password: str) -> str | None:
data=common_utils.to_oauth_request_form(email, password)
)

self.error_raising(response, "login", only_version_check=True)
self._validate_response(response, "login", only_version_check=True)
if response.status_code == 200:
access_token = response.json()["access_token"]

Expand All @@ -289,7 +297,7 @@ def get_password_policy(self) -> {}:
response = self.httpx_client.get(
self.server_endpoints.password_policy.path,
)
self.error_raising(response, "get_password_policy", only_version_check=True)
self._validate_response(response, "get_password_policy", only_version_check=True)

return response.json()["requirements"]

Expand All @@ -299,7 +307,7 @@ def retrieve_greeting_messages(self) -> list[str]:
"""
response = self.httpx_client.get(self.server_endpoints.retrieve_greeting_messages.path)

self.error_raising(response, "retrieve_greeting_messages", only_version_check=True)
self._validate_response(response, "retrieve_greeting_messages", only_version_check=True)
if response.status_code != 200:
return []

Expand All @@ -318,7 +326,7 @@ def get_data_summary(self) -> {}:
response = self.httpx_client.get(
self.server_endpoints.get_data_summary.path,
)
self.error_raising(response, "get_data_summary")
self._validate_response(response, "get_data_summary")

return response.json()

Expand All @@ -337,7 +345,7 @@ def download_all_data(self, save_dir: Path) -> Path | None:

full_url = self.base_url + self.server_endpoints.download_all_data.path
with httpx.stream("GET", full_url, headers={"Authorization": f"Bearer {self.access_token}"}) as response:
self.error_raising(response, "download_all_data")
self._validate_response(response, "download_all_data")

filename = response.headers["Content-Disposition"].split("filename=")[1]
save_path = Path(save_dir) / filename
Expand Down Expand Up @@ -368,7 +376,7 @@ def delete_dataset(self, dataset_uid: str) -> [str]:
params={"dataset_uid": dataset_uid}
)

self.error_raising(response, "delete_dataset")
self._validate_response(response, "delete_dataset")

return response.json()["deleted_dataset_uids"]

Expand All @@ -385,7 +393,7 @@ def delete_all_datasets(self) -> [str]:
self.server_endpoints.delete_all_datasets.path,
)

self.error_raising(response, "delete_all_datasets")
self._validate_response(response, "delete_all_datasets")

return response.json()["deleted_dataset_uids"]

Expand All @@ -395,4 +403,4 @@ def delete_user_account(self, confirm_pass: str) -> None:
params={"confirm_password": confirm_pass}
)

self.error_raising(response, "delete_user_account")
self._validate_response(response, "delete_user_account")
1 change: 1 addition & 0 deletions tabpfn_client/tabpfn_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def init(use_server=True):
# prompt for login / register
PromptAgent.prompt_and_set_token(user_auth_handler)

# Print new greeting messages. If there are no new messages, nothing will be printed.
PromptAgent.prompt_retrieved_greeting_messages(user_auth_handler.retrieve_greeting_messages())

g_tabpfn_config.use_server = True
Expand Down
33 changes: 11 additions & 22 deletions tabpfn_client/tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,7 @@ def test_try_connection_with_invalid_server(self, mock_server):
@with_mock_server()
def test_try_connection_with_outdated_client(self, mock_server):
mock_server.router.get(mock_server.endpoints.root.path).respond(
400, json={"detail": "Client version header missing. Please make sure to use the ..."})
with self.assertRaises(RuntimeError) as cm:
self.client.try_connection()
self.assertTrue(str(cm.exception).startswith("Client version header missing"))
mock_server.router.get(mock_server.endpoints.root.path).respond(
403, json={"detail": "Client version too old. ..."})
426, json={"detail": "Client version too old. ..."})
with self.assertRaises(RuntimeError) as cm:
self.client.try_connection()
self.assertTrue(str(cm.exception).startswith("Client version too old."))
Expand Down Expand Up @@ -95,46 +90,40 @@ def test_predict_with_valid_train_set_and_test_set(self, mock_server):
)
self.assertTrue(np.array_equal(pred, dummy_result["y_pred"]))

def test_error_raising_no_error(self):
def test_validate_response_no_error(self):
response = Mock()
response.status_code = 200
r = self.client.error_raising(response, "test")
r = self.client._validate_response(response, "test")
self.assertIsNone(r)

def test_error_raising(self):
def test_validate_response(self):
response = Mock()
# Test for "Client version too old." error
response.status_code = 403
response.status_code = 426
response.json.return_value = {"detail": "Client version too old."}
with self.assertRaises(RuntimeError) as cm:
self.client.error_raising(response, "test")
self.client._validate_response(response, "test")
self.assertEqual(str(cm.exception), "Client version too old.")

# Test for "Some other error" which is translated to a generic failure message
response.status_code = 400
response.json.return_value = {"detail": "Some other error"}
with self.assertRaises(RuntimeError) as cm:
self.client.error_raising(response, "test")
self.client._validate_response(response, "test")
self.assertTrue(str(cm.exception).startswith("Fail to call test"))

def test_error_raising_only_version_check(self):
def test_validate_response_only_version_check(self):
response = Mock()
response.status_code = 403
response.status_code = 426
response.json.return_value = {"detail": "Client version too old."}
with self.assertRaises(RuntimeError) as cm:
self.client.error_raising(response, "test", only_version_check=True)
self.client._validate_response(response, "test", only_version_check=True)
self.assertEqual(str(cm.exception), "Client version too old.")

response.status_code = 400
response.json.return_value = {"detail": "Client version header missing."}
with self.assertRaises(RuntimeError) as cm:
self.client.error_raising(response, "test", only_version_check=True)
self.assertEqual(str(cm.exception), "Client version header missing.")

# Errors that have nothing to do with client version should be skipped.
response = Mock()
response.status_code = 400
response.json.return_value = {"detail": "Some other error"}
r = self.client.error_raising(response, "test", only_version_check=True)
r = self.client._validate_response(response, "test", only_version_check=True)
self.assertIsNone(r)

0 comments on commit c3e53e1

Please sign in to comment.