From bce902d6095cbe3d182adb452a01120bb4b298ba Mon Sep 17 00:00:00 2001 From: David Otte Date: Mon, 11 Mar 2024 14:29:32 +0100 Subject: [PATCH] Fix version check and quick_test.py --- quick_test.py | 3 +-- tabpfn_client/client.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/quick_test.py b/quick_test.py index 31b3a29..3f0caac 100644 --- a/quick_test.py +++ b/quick_test.py @@ -34,8 +34,7 @@ tabpfn_classifier.init() tabpfn = TabPFNClassifier(model="latest_tabpfn_hosted") # print("checking estimator", check_estimator(tabpfn)) - print(X_train.shape[0]*100) - tabpfn.fit(np.repeat(X_train, 100, axis=0), np.repeat(y_train, 100, axis=0)) + tabpfn.fit(X_train[:99], y_train[:99]) print("predicting") print(tabpfn.predict(X_test)) print("predicting_proba") diff --git a/tabpfn_client/client.py b/tabpfn_client/client.py index a84511b..1360393 100644 --- a/tabpfn_client/client.py +++ b/tabpfn_client/client.py @@ -278,7 +278,7 @@ def login(self, email: str, password: str) -> str | None: data=common_utils.to_oauth_request_form(email, password) ) - self._validate_response(response, "login", only_version_check=True) + self._validate_response(response, "login", only_version_check=False) if response.status_code == 200: access_token = response.json()["access_token"]