From a3f2cfb0cbdb259ab7e31b6534c539844ca72f10 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Samuel=20M=C3=BCller?= Date: Sat, 4 May 2024 14:32:13 +0200 Subject: [PATCH 1/6] Pass hyper params from client --- tabpfn_client/client.py | 13 +- tabpfn_client/remote_tabpfn_classifier.py | 70 ------ tabpfn_client/service_wrapper.py | 13 +- tabpfn_client/tabpfn_classifier.py | 234 +++++++++++------- .../integration/test_tabpfn_classifier.py | 11 +- tabpfn_client/tests/unit/test_prompt_agent.py | 3 - .../unit/test_remote_tabpfn_classifier.py | 64 ----- .../tests/unit/test_tabpfn_classifier.py | 31 +-- 8 files changed, 182 insertions(+), 257 deletions(-) delete mode 100644 tabpfn_client/remote_tabpfn_classifier.py delete mode 100644 tabpfn_client/tests/unit/test_remote_tabpfn_classifier.py diff --git a/tabpfn_client/client.py b/tabpfn_client/client.py index 155f8ad..de20dec 100644 --- a/tabpfn_client/client.py +++ b/tabpfn_client/client.py @@ -1,3 +1,6 @@ +from __future__ import annotations + +import traceback from pathlib import Path import httpx import logging @@ -96,7 +99,7 @@ def upload_train_set(self, X, y) -> str: train_set_uid = response.json()["train_set_uid"] return train_set_uid - def predict(self, train_set_uid: str, x_test): + def predict(self, train_set_uid: str, x_test, tabpfn_config: dict | None=None): """ Predict the class labels for the provided data (test set). 
@@ -115,9 +118,14 @@ def predict(self, train_set_uid: str, x_test): x_test = common_utils.serialize_to_csv_formatted_bytes(x_test) + params = {"train_set_uid": train_set_uid} + + if tabpfn_config is not None: + params["tabpfn_config"] = json.dumps(tabpfn_config, default=lambda x: x.to_dict()) + response = self.httpx_client.post( url=self.server_endpoints.predict.path, - params={"train_set_uid": train_set_uid}, + params=params, files=common_utils.to_httpx_post_file_format([ ("x_file", "x_test_filename", x_test) ]) @@ -198,7 +206,6 @@ def try_connection(self) -> bool: except httpx.ConnectError as e: logger.error(f"Failed to connect to the server with error: {e}") - import traceback traceback.print_exc() found_valid_connection = False diff --git a/tabpfn_client/remote_tabpfn_classifier.py b/tabpfn_client/remote_tabpfn_classifier.py deleted file mode 100644 index c6024fb..0000000 --- a/tabpfn_client/remote_tabpfn_classifier.py +++ /dev/null @@ -1,70 +0,0 @@ -from sklearn.utils.validation import check_is_fitted -from sklearn.base import BaseEstimator, ClassifierMixin - -from tabpfn_client.service_wrapper import InferenceClient - - -class RemoteTabPFNClassifier(BaseEstimator, ClassifierMixin): - - def __init__( - self, - model=None, - device="cpu", - base_path=".", - model_string="", - batch_size_inference=4, - fp16_inference=False, - inference_mode=True, - c=None, - N_ensemble_configurations=10, - preprocess_transforms=("none", "power_all"), - feature_shift_decoder=False, - normalize_with_test=False, - average_logits=False, - categorical_features=tuple(), - optimize_metric=None, - seed=None, - transformer_predict_kwargs_init=None, - multiclass_decoder="permutation", - - # dependency injection (for testing) - inference_handler=InferenceClient() - ): - # TODO: - # These configs are ignored at the moment -> all clients share the same (default) on-server TabPFNClassifier. 
- # In the future version, these configs will be used to create per-user TabPFNClassifier, - # allowing the user to setup the desired TabPFNClassifier on the server. - # config for tabpfn - self.model = model - self.device = device - self.base_path = base_path - self.model_string = model_string - self.batch_size_inference = batch_size_inference - self.fp16_inference = fp16_inference - self.inference_mode = inference_mode - self.c = c - self.N_ensemble_configurations = N_ensemble_configurations - self.preprocess_transforms = preprocess_transforms - self.feature_shift_decoder = feature_shift_decoder - self.normalize_with_test = normalize_with_test - self.average_logits = average_logits - self.categorical_features = categorical_features - self.optimize_metric = optimize_metric - self.seed = seed - self.transformer_predict_kwargs_init = transformer_predict_kwargs_init - self.multiclass_decoder = multiclass_decoder - - self.inference_handler = inference_handler - - def fit(self, X, y): - self.inference_handler.fit(X, y) - self.fitted_ = True - return self - - def predict(self, X): - check_is_fitted(self) - return self.inference_handler.predict(X) - - def predict_proba(self, X): - check_is_fitted(self) - return self.inference_handler.predict_proba(X) diff --git a/tabpfn_client/service_wrapper.py b/tabpfn_client/service_wrapper.py index 3773d17..77bb322 100644 --- a/tabpfn_client/service_wrapper.py +++ b/tabpfn_client/service_wrapper.py @@ -177,16 +177,9 @@ def fit(self, X, y) -> None: self.last_train_set_uid = self.service_client.upload_train_set(X, y) - def predict(self, X): + def predict(self, X, config=None): return self.service_client.predict( train_set_uid=self.last_train_set_uid, - x_test=X + x_test=X, + tabpfn_config=config ) - - def predict_proba(self, X): - return self.service_client.predict_proba( - train_set_uid=self.last_train_set_uid, - x_test=X - ) - - diff --git a/tabpfn_client/tabpfn_classifier.py b/tabpfn_client/tabpfn_classifier.py index 
362e297..be0b161 100644 --- a/tabpfn_client/tabpfn_classifier.py +++ b/tabpfn_client/tabpfn_classifier.py @@ -1,12 +1,13 @@ +from typing import Optional, Tuple, Literal import logging -from pathlib import Path import shutil +from dataclasses import dataclass, asdict +import numpy as np from sklearn.base import BaseEstimator, ClassifierMixin from sklearn.utils.validation import check_is_fitted from tabpfn import TabPFNClassifier as LocalTabPFNClassifier -from tabpfn_client.remote_tabpfn_classifier import RemoteTabPFNClassifier from tabpfn_client.service_wrapper import UserAuthenticationClient, InferenceClient from tabpfn_client.client import ServiceClient from tabpfn_client.constants import CACHE_DIR @@ -78,108 +79,175 @@ def reset(): shutil.rmtree(CACHE_DIR, ignore_errors=True) +@dataclass(eq=True, frozen=True) +class PreprocessorConfig: + """ + Configuration for data preprocessors. + + Attributes: + name (Literal): Name of the preprocessor. + categorical_name (Literal): Name of the categorical encoding method. Valid options are "none", "numeric", + "onehot", "ordinal", "ordinal_shuffled". Default is "none". + append_original (bool): Whether to append the original features to the transformed features. Default is False. + subsample_features (float): Fraction of features to subsample. -1 means no subsampling. Default is -1. + global_transformer_name (str): Name of the global transformer to use. Default is None. 
+ """ + + name: Literal[ + "per_feature", # a different transformation for each feature + "power", # a standard sklearn power transformer + "safepower", # a power transformer that prevents some numerical issues + "power_box", + "safepower_box", + "quantile_uni_coarse", # different quantile transformations with few quantiles up to a lot + "quantile_norm_coarse", + "quantile_uni", + "quantile_norm", + "quantile_uni_fine", + "quantile_norm_fine", + "robust", # a standard sklearn robust scaler + "kdi", + "none", # no transformation (inside the transformer we anyways do a standardization) + "kdi_random_alpha", + "kdi_uni", + "kdi_random_alpha_uni", + "adaptive", + "norm_and_kdi", + # KDI with alpha collection + "kdi_alpha_0.3_uni", + "kdi_alpha_0.5_uni", + "kdi_alpha_0.8_uni", + "kdi_alpha_1.0_uni", + "kdi_alpha_1.2_uni", + "kdi_alpha_1.5_uni", + "kdi_alpha_2.0_uni", + "kdi_alpha_3.0_uni", + "kdi_alpha_5.0_uni", + "kdi_alpha_0.3", + "kdi_alpha_0.5", + "kdi_alpha_0.8", + "kdi_alpha_1.0", + "kdi_alpha_1.2", + "kdi_alpha_1.5", + "kdi_alpha_2.0", + "kdi_alpha_3.0", + "kdi_alpha_5.0", + ] + categorical_name: Literal[ + "none", + "numeric", + "onehot", + "ordinal", + "ordinal_shuffled", + "ordinal_very_common_categories_shuffled", + ] = "none" + # categorical_name meanings: + # "none": categorical features are pretty much treated as ordinal, just not resorted + # "numeric": categorical features are treated as numeric, that means they are also power transformed for example + # "onehot": categorical features are onehot encoded + # "ordinal": categorical features are sorted and encoded as integers from 0 to n_categories - 1 + # "ordinal_shuffled": categorical features are encoded as integers from 0 to n_categories - 1 in a random order + append_original: bool = False + subsample_features: Optional[float] = -1 + global_transformer_name: Optional[str] = None + # if True, the transformed features (e.g. 
power transformed) are appended to the original features + + def __str__(self): + return ( + f"{self.name}_cat:{self.categorical_name}" + + ("_and_none" if self.append_original else "") + + ( + "_subsample_feats_" + str(self.subsample_features) + if self.subsample_features > 0 + else "" + ) + + ( + f"_global_transformer_{self.global_transformer_name}" + if self.global_transformer_name is not None + else "" + ) + ) + + def can_be_cached(self): + return not self.subsample_features > 0 + + def to_dict(self): + return {k: str(v) if not isinstance(v, (str, int, float, list, dict)) else v for k, v in asdict(self).items()} + + +ClassificationOptimizationMetricType = Literal[ + "auroc", "roc", "auroc_ovo", "balanced_acc", "acc", "log_loss", None +] + + class TabPFNClassifier(BaseEstimator, ClassifierMixin): - # def __init__(self): - # Configuration for TabPFNClassifier is still under development. - # pass - def __init__( - self, - model="latest_tabpfn_hosted", - device="cpu", - base_path=Path(__file__).parent.parent.resolve(), - model_string="", - batch_size_inference=4, - fp16_inference=False, - inference_mode=True, - c=None, - N_ensemble_configurations=10, - preprocess_transforms=("none", "power_all"), - feature_shift_decoder=False, - normalize_with_test=False, - average_logits=False, - categorical_features=tuple(), - optimize_metric=None, - seed=None, - transformer_predict_kwargs_init=None, - multiclass_decoder="permutation", + self, + model="latest_tabpfn_hosted", + n_estimators: int = 4, + preprocess_transforms: Tuple[PreprocessorConfig, ...] 
= ( + PreprocessorConfig( + "quantile_uni_coarse", + append_original=True, + categorical_name="ordinal_very_common_categories_shuffled", + global_transformer_name="svd", + subsample_features=-1, + ), + PreprocessorConfig( + "none", categorical_name="numeric", subsample_features=-1 + ), + ), + feature_shift_decoder: str = "shuffle", + normalize_with_test: bool = False, + average_logits: bool = False, + optimize_metric: ClassificationOptimizationMetricType = "roc", + transformer_predict_kwargs: Optional[dict] = None, + multiclass_decoder="shuffle", + softmax_temperature: Optional[float] = -0.1, + use_poly_features=False, + max_poly_features=50, + remove_outliers=12.0, + add_fingerprint_features=True, + subsample_samples=-1, ): - # config for tabpfn self.model = model - self.device = device - self.base_path = base_path - self.model_string = model_string - self.batch_size_inference = batch_size_inference - self.fp16_inference = fp16_inference - self.inference_mode = inference_mode - self.c = c - self.N_ensemble_configurations = N_ensemble_configurations + self.n_estimators = n_estimators self.preprocess_transforms = preprocess_transforms self.feature_shift_decoder = feature_shift_decoder self.normalize_with_test = normalize_with_test self.average_logits = average_logits - self.categorical_features = categorical_features self.optimize_metric = optimize_metric - self.seed = seed - self.transformer_predict_kwargs_init = transformer_predict_kwargs_init + self.transformer_predict_kwargs = transformer_predict_kwargs self.multiclass_decoder = multiclass_decoder + self.softmax_temperature = softmax_temperature + self.use_poly_features = use_poly_features + self.max_poly_features = max_poly_features + self.remove_outliers = remove_outliers + self.add_fingerprint_features = add_fingerprint_features + self.subsample_samples = subsample_samples def fit(self, X, y): # assert init() is called if not g_tabpfn_config.is_initialized: raise RuntimeError("TabPFNClassifier.init() must be 
called before using TabPFNClassifier") - # create classifier if not created yet - if not hasattr(self, "classifier"): - # arguments that are commented out are not used at the moment - # (not supported until new TabPFN interface is released) - classifier_cfg = { - # "model": self.model, - "device": self.device, - "base_path": self.base_path, - "model_string": self.model_string, - "batch_size_inference": self.batch_size_inference, - # "fp16_inference": self.fp16_inference, - # "inference_mode": self.inference_mode, - # "c": self.c, - "N_ensemble_configurations": self.N_ensemble_configurations, - # "preprocess_transforms": self.preprocess_transforms, - "feature_shift_decoder": self.feature_shift_decoder, - # "normalize_with_test": self.normalize_with_test, - # "average_logits": self.average_logits, - # "categorical_features": self.categorical_features, - # "optimize_metric": self.optimize_metric, - "seed": self.seed, - # "transformer_predict_kwargs_init": self.transformer_predict_kwargs_init, - "multiclass_decoder": self.multiclass_decoder - } - #classifier_cfg = {} - - if g_tabpfn_config.use_server: - try: - assert self.model == "latest_tabpfn_hosted", "Only 'latest_tabpfn_hosted' model is supported at the moment for tabpfn_classifier.init(use_server=True)" - except AssertionError as e: - print(e) - self.classifier_ = RemoteTabPFNClassifier( - **classifier_cfg, - inference_handler=g_tabpfn_config.inference_handler - ) - else: - try: - assert self.model == "tabpfn_1_local", "Only 'tabpfn_1_local' model is supported at the moment for tabpfn_classifier.init(use_server=False)" - except AssertionError as e: - print(e) - self.classifier_ = LocalTabPFNClassifier(**classifier_cfg) - - self.classifier_.fit(X, y) + if g_tabpfn_config.use_server: + try: + assert self.model == "latest_tabpfn_hosted", "Only 'latest_tabpfn_hosted' model is supported at the moment for tabpfn_classifier.init(use_server=True)" + except AssertionError as e: + print(e) + 
g_tabpfn_config.inference_handler.fit(X, y) + self.fitted_ = True + else: + raise NotImplementedError("Only server mode is supported at the moment for tabpfn_classifier.init(use_server=False)") return self def predict(self, X): - check_is_fitted(self) - return self.classifier_.predict(X) + probas = self.predict_proba(X) + return np.argmax(probas, axis=1) def predict_proba(self, X): check_is_fitted(self) - return self.classifier_.predict_proba(X) + return g_tabpfn_config.inference_handler.predict(X, config=self.get_params()) diff --git a/tabpfn_client/tests/integration/test_tabpfn_classifier.py b/tabpfn_client/tests/integration/test_tabpfn_classifier.py index 1fe4dbf..09adb72 100644 --- a/tabpfn_client/tests/integration/test_tabpfn_classifier.py +++ b/tabpfn_client/tests/integration/test_tabpfn_classifier.py @@ -19,15 +19,6 @@ def tearDown(self): tabpfn_classifier.reset() ServiceClient().delete_instance() - def test_use_local_tabpfn_classifier(self): - tabpfn_classifier.init(use_server=False) - tabpfn = TabPFNClassifier(device="cpu", model="tabpfn_1_local") - tabpfn.fit(self.X_train, self.y_train) - - self.assertTrue(isinstance(tabpfn.classifier_, LocalTabPFNClassifier)) - pred = tabpfn.predict(self.X_test) - self.assertEqual(pred.shape[0], self.X_test.shape[0]) - @with_mock_server() def test_use_remote_tabpfn_classifier(self, mock_server): # create dummy token file @@ -52,7 +43,7 @@ def test_use_remote_tabpfn_classifier(self, mock_server): # mock prediction mock_server.router.post(mock_server.endpoints.predict.path).respond( 200, - json={"y_pred": LocalTabPFNClassifier().fit(self.X_train, self.y_train).predict(self.X_test).tolist()} + json={"y_pred": LocalTabPFNClassifier().fit(self.X_train, self.y_train).predict_proba(self.X_test).tolist()} ) pred = tabpfn.predict(self.X_test) self.assertEqual(pred.shape[0], self.X_test.shape[0]) diff --git a/tabpfn_client/tests/unit/test_prompt_agent.py b/tabpfn_client/tests/unit/test_prompt_agent.py index 14b2b7b..7c17370 100644 
--- a/tabpfn_client/tests/unit/test_prompt_agent.py +++ b/tabpfn_client/tests/unit/test_prompt_agent.py @@ -1,10 +1,7 @@ import unittest from unittest.mock import patch, MagicMock -import respx -from httpx import Response from tabpfn_client.prompt_agent import PromptAgent from tabpfn_client.tests.mock_tabpfn_server import with_mock_server -from tabpfn_client.service_wrapper import UserAuthenticationClient, ServiceClient class TestPromptAgent(unittest.TestCase): diff --git a/tabpfn_client/tests/unit/test_remote_tabpfn_classifier.py b/tabpfn_client/tests/unit/test_remote_tabpfn_classifier.py deleted file mode 100644 index 5c76b55..0000000 --- a/tabpfn_client/tests/unit/test_remote_tabpfn_classifier.py +++ /dev/null @@ -1,64 +0,0 @@ -import unittest -from unittest.mock import MagicMock, patch -import shutil - -from sklearn.datasets import load_breast_cancer -from sklearn.model_selection import train_test_split -from sklearn.exceptions import NotFittedError - -from tabpfn_client.remote_tabpfn_classifier import RemoteTabPFNClassifier -from tabpfn_client.client import ServiceClient -from tabpfn_client.service_wrapper import InferenceClient -from tabpfn_client.constants import CACHE_DIR - - -class TestRemoteTabPFNClassifier(unittest.TestCase): - - def setUp(self): - self.dummy_token = "dummy_token" - X, y = load_breast_cancer(return_X_y=True) - self.X_train, self.X_test, self.y_train, self.y_test = \ - train_test_split(X, y, test_size=0.33) - - # mock service client - self.mock_client = MagicMock(spec=ServiceClient) - self.mock_client.is_initialized.return_value = True - inference_handler = InferenceClient(service_client=self.mock_client) - - self.remote_tabpfn = RemoteTabPFNClassifier(inference_handler=inference_handler) - - def tearDown(self): - patch.stopall() - shutil.rmtree(CACHE_DIR, ignore_errors=True) - - def test_fit_and_predict_with_valid_datasets(self): - # mock responses - self.mock_client.upload_train_set.return_value = "dummy_train_set_uid" - - 
mock_predict_response = [1, 1, 0] - self.mock_client.predict.return_value = mock_predict_response - - self.remote_tabpfn.fit(self.X_train, self.y_train) - y_pred = self.remote_tabpfn.predict(self.X_test) - - self.assertEqual(mock_predict_response, y_pred) - self.mock_client.upload_train_set.called_once_with(self.X_train, self.y_train) - self.mock_client.predict.called_once_with(self.X_test) - - def test_call_predict_without_calling_fit_before(self): - self.assertRaises( - NotFittedError, - self.remote_tabpfn.predict, - self.X_test - ) - - def test_call_predict_proba_without_calling_fit_before(self): - self.assertRaises( - NotFittedError, - self.remote_tabpfn.predict_proba, - self.X_test - ) - - def test_predict_with_conflicting_test_set(self): - # TODO: implement this - pass diff --git a/tabpfn_client/tests/unit/test_tabpfn_classifier.py b/tabpfn_client/tests/unit/test_tabpfn_classifier.py index ccf8a88..45e5b2f 100644 --- a/tabpfn_client/tests/unit/test_tabpfn_classifier.py +++ b/tabpfn_client/tests/unit/test_tabpfn_classifier.py @@ -2,13 +2,13 @@ from unittest.mock import patch import shutil +import numpy as np from sklearn.datasets import load_breast_cancer from sklearn.model_selection import train_test_split -from tabpfn import TabPFNClassifier as LocalTabPFNClassifier +from sklearn.exceptions import NotFittedError from tabpfn_client import tabpfn_classifier from tabpfn_client.tabpfn_classifier import TabPFNClassifier -from tabpfn_client.remote_tabpfn_classifier import RemoteTabPFNClassifier from tabpfn_client.service_wrapper import UserAuthenticationClient from tabpfn_client.client import ServiceClient from tabpfn_client.tests.mock_tabpfn_server import with_mock_server @@ -34,11 +34,6 @@ def tearDown(self): # remove cache dir shutil.rmtree(CACHE_DIR, ignore_errors=True) - def test_init_local_classifier(self): - tabpfn_classifier.init(use_server=False) - tabpfn = TabPFNClassifier(model="tabpfn_1_local").fit(self.X_train, self.y_train) - 
self.assertTrue(isinstance(tabpfn.classifier_, LocalTabPFNClassifier)) - @with_mock_server() @patch("tabpfn_client.prompt_agent.PromptAgent.prompt_and_set_token") @patch("tabpfn_client.prompt_agent.PromptAgent.prompt_terms_and_cond", @@ -54,13 +49,26 @@ def test_init_remote_classifier(self, mock_server, mock_prompt_for_terms_and_con ) mock_server.router.get(mock_server.endpoints.retrieve_greeting_messages.path).respond( 200, json={"messages": []}) + mock_predict_response = [[1, 0.],[.9, .1],[0.01, 0.99]] + mock_server.router.post(mock_server.endpoints.predict.path).respond( + 200, json={"y_pred": mock_predict_response} + ) tabpfn_classifier.init(use_server=True) - tabpfn = TabPFNClassifier().fit(self.X_train, self.y_train) - self.assertTrue(isinstance(tabpfn.classifier_, RemoteTabPFNClassifier)) + + tabpfn = TabPFNClassifier() + self.assertRaises( + NotFittedError, + tabpfn.predict, + self.X_test + ) + tabpfn.fit(self.X_train, self.y_train) self.assertTrue(mock_prompt_and_set_token.called) self.assertTrue(mock_prompt_for_terms_and_cond.called) + y_pred = tabpfn.predict(self.X_test) + self.assertTrue(np.all(np.argmax(mock_predict_response, axis=1) == y_pred)) + @with_mock_server() def test_reuse_saved_access_token(self, mock_server): # mock connection and authentication @@ -99,11 +107,6 @@ def test_invalid_saved_access_token(self, mock_server, mock_prompt_for_terms_and self.assertRaises(RuntimeError, tabpfn_classifier.init, use_server=True) self.assertTrue(mock_prompt_and_set_token.called) - def test_reset_on_local_classifier(self): - tabpfn_classifier.init(use_server=False) - tabpfn_classifier.reset() - self.assertFalse(tabpfn_classifier.g_tabpfn_config.is_initialized) - @with_mock_server() def test_reset_on_remote_classifier(self, mock_server): # create dummy token file From 70dd739aff5c32879df2dc07c6123df92d312ddc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Samuel=20M=C3=BCller?= Date: Mon, 6 May 2024 14:58:32 +0200 Subject: [PATCH 2/6] updated test to include a 
test that an example param is passed to the server --- tabpfn_client/tests/unit/test_tabpfn_classifier.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tabpfn_client/tests/unit/test_tabpfn_classifier.py b/tabpfn_client/tests/unit/test_tabpfn_classifier.py index 45e5b2f..8acb318 100644 --- a/tabpfn_client/tests/unit/test_tabpfn_classifier.py +++ b/tabpfn_client/tests/unit/test_tabpfn_classifier.py @@ -50,13 +50,14 @@ def test_init_remote_classifier(self, mock_server, mock_prompt_for_terms_and_con mock_server.router.get(mock_server.endpoints.retrieve_greeting_messages.path).respond( 200, json={"messages": []}) mock_predict_response = [[1, 0.],[.9, .1],[0.01, 0.99]] - mock_server.router.post(mock_server.endpoints.predict.path).respond( + predict_route = mock_server.router.post(mock_server.endpoints.predict.path) + predict_route.respond( 200, json={"y_pred": mock_predict_response} ) tabpfn_classifier.init(use_server=True) - tabpfn = TabPFNClassifier() + tabpfn = TabPFNClassifier(n_estimators=10) self.assertRaises( NotFittedError, tabpfn.predict, @@ -69,6 +70,8 @@ def test_init_remote_classifier(self, mock_server, mock_prompt_for_terms_and_con y_pred = tabpfn.predict(self.X_test) self.assertTrue(np.all(np.argmax(mock_predict_response, axis=1) == y_pred)) + self.assertIn('n_estimators%22%3A%2010', str(predict_route.calls.last.request.url), "check that n_estimators is passed to the server") + @with_mock_server() def test_reuse_saved_access_token(self, mock_server): # mock connection and authentication From f5795095d02b3390e0f6c2408cd22faf24d84941 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Samuel=20M=C3=BCller?= Date: Mon, 6 May 2024 16:26:30 +0200 Subject: [PATCH 3/6] updated doc string --- tabpfn_client/tabpfn_classifier.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tabpfn_client/tabpfn_classifier.py b/tabpfn_client/tabpfn_classifier.py index be0b161..cf30ff2 100644 --- a/tabpfn_client/tabpfn_classifier.py +++ 
b/tabpfn_client/tabpfn_classifier.py @@ -229,7 +229,7 @@ def __init__( def fit(self, X, y): # assert init() is called if not g_tabpfn_config.is_initialized: - raise RuntimeError("TabPFNClassifier.init() must be called before using TabPFNClassifier") + raise RuntimeError("tabpfn_client.init() must be called before using TabPFNClassifier") if g_tabpfn_config.use_server: try: From c474a9d1cc156b0f9237ee1265a9395b34c94ba9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Samuel=20M=C3=BCller?= Date: Wed, 8 May 2024 15:49:33 +0200 Subject: [PATCH 4/6] updated api --- quick_test.py | 2 +- tabpfn_client/client.py | 2 +- tabpfn_client/prompt_agent.py | 2 +- tabpfn_client/server_config.yaml | 13 +++++++------ tabpfn_client/tabpfn_common_utils | 2 +- .../tests/integration/test_tabpfn_classifier.py | 2 +- tabpfn_client/tests/unit/test_client.py | 4 ++-- tabpfn_client/tests/unit/test_tabpfn_classifier.py | 2 +- 8 files changed, 15 insertions(+), 14 deletions(-) diff --git a/quick_test.py b/quick_test.py index 3f0caac..128b1af 100644 --- a/quick_test.py +++ b/quick_test.py @@ -32,7 +32,7 @@ else: tabpfn_classifier.init() - tabpfn = TabPFNClassifier(model="latest_tabpfn_hosted") + tabpfn = TabPFNClassifier(model="latest_tabpfn_hosted", n_estimators=3) # print("checking estimator", check_estimator(tabpfn)) tabpfn.fit(X_train[:99], y_train[:99]) print("predicting") diff --git a/tabpfn_client/client.py b/tabpfn_client/client.py index de20dec..f78551c 100644 --- a/tabpfn_client/client.py +++ b/tabpfn_client/client.py @@ -133,7 +133,7 @@ def predict(self, train_set_uid: str, x_test, tabpfn_config: dict | None=None): self._validate_response(response, "predict") - return np.array(response.json()["y_pred"]) + return np.array(response.json()["y_pred_proba"]) @staticmethod def _validate_response(response, method_name, only_version_check=False): diff --git a/tabpfn_client/prompt_agent.py b/tabpfn_client/prompt_agent.py index 432fd34..66f88d7 100644 --- a/tabpfn_client/prompt_agent.py +++ 
b/tabpfn_client/prompt_agent.py @@ -51,7 +51,7 @@ def prompt_and_set_token(cls, user_auth_handler: "UserAuthenticationClient"): # Registration if choice == "1": # validation_link = input(cls.indent("Please enter your secret code: ")) - validation_link = "tabpfn-2023" + validation_link = "tabpfn-test" while True: email = input(cls.indent("Please enter your email: ")) # Send request to server to check if email is valid and not already taken. diff --git a/tabpfn_client/server_config.yaml b/tabpfn_client/server_config.yaml index ee05c1e..4923bf2 100644 --- a/tabpfn_client/server_config.yaml +++ b/tabpfn_client/server_config.yaml @@ -1,12 +1,13 @@ ## testing -#protocol: "http" -#host: "0.0.0.0" -#port: "8000" +protocol: "http" +host: "localhost" +port: "8080" # production -protocol: "https" -host: "tabpfn-server-wjedmz7r5a-ez.a.run.app" -port: "443" +#protocol: "https" +#host: "tabpfn-server-wjedmz7r5a-ez.a.run.app" +#host: tabpfn-server-preprod-wjedmz7r5a-ez.a.run.app # preprod +#port: "443" endpoints: root: path: "/" diff --git a/tabpfn_client/tabpfn_common_utils b/tabpfn_client/tabpfn_common_utils index a2df122..cb44694 160000 --- a/tabpfn_client/tabpfn_common_utils +++ b/tabpfn_client/tabpfn_common_utils @@ -1 +1 @@ -Subproject commit a2df122f2894369a444eb2335776d7dd5eade5d9 +Subproject commit cb4469425eba995b4cefad1357c020878e1a6d02 diff --git a/tabpfn_client/tests/integration/test_tabpfn_classifier.py b/tabpfn_client/tests/integration/test_tabpfn_classifier.py index 09adb72..52eb80d 100644 --- a/tabpfn_client/tests/integration/test_tabpfn_classifier.py +++ b/tabpfn_client/tests/integration/test_tabpfn_classifier.py @@ -43,7 +43,7 @@ def test_use_remote_tabpfn_classifier(self, mock_server): # mock prediction mock_server.router.post(mock_server.endpoints.predict.path).respond( 200, - json={"y_pred": LocalTabPFNClassifier().fit(self.X_train, self.y_train).predict_proba(self.X_test).tolist()} + json={"y_pred_proba": LocalTabPFNClassifier().fit(self.X_train, 
self.y_train).predict_proba(self.X_test).tolist()} ) pred = tabpfn.predict(self.X_test) self.assertEqual(pred.shape[0], self.X_test.shape[0]) diff --git a/tabpfn_client/tests/unit/test_client.py b/tabpfn_client/tests/unit/test_client.py index ab2a850..a913896 100644 --- a/tabpfn_client/tests/unit/test_client.py +++ b/tabpfn_client/tests/unit/test_client.py @@ -91,7 +91,7 @@ def test_predict_with_valid_train_set_and_test_set(self, mock_server): self.client.upload_train_set(self.X_train, self.y_train) - dummy_result = {"y_pred": [1, 2, 3]} + dummy_result = {"y_pred_proba": [1, 2, 3]} mock_server.router.post(mock_server.endpoints.predict.path).respond( 200, json=dummy_result) @@ -99,7 +99,7 @@ def test_predict_with_valid_train_set_and_test_set(self, mock_server): train_set_uid=dummy_json["train_set_uid"], x_test=self.X_test ) - self.assertTrue(np.array_equal(pred, dummy_result["y_pred"])) + self.assertTrue(np.array_equal(pred, dummy_result["y_pred_proba"])) @with_mock_server() def test_add_user_information(self, mock_server): diff --git a/tabpfn_client/tests/unit/test_tabpfn_classifier.py b/tabpfn_client/tests/unit/test_tabpfn_classifier.py index 8acb318..5653a25 100644 --- a/tabpfn_client/tests/unit/test_tabpfn_classifier.py +++ b/tabpfn_client/tests/unit/test_tabpfn_classifier.py @@ -52,7 +52,7 @@ def test_init_remote_classifier(self, mock_server, mock_prompt_for_terms_and_con mock_predict_response = [[1, 0.],[.9, .1],[0.01, 0.99]] predict_route = mock_server.router.post(mock_server.endpoints.predict.path) predict_route.respond( - 200, json={"y_pred": mock_predict_response} + 200, json={"y_pred_proba": mock_predict_response} ) tabpfn_classifier.init(use_server=True) From 70d9cfeee827fce4f4578a61e417b9af73b58b9b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Samuel=20M=C3=BCller?= Date: Wed, 8 May 2024 15:51:42 +0200 Subject: [PATCH 5/6] updated version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml 
index 7d0f872..71254f5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "tabpfn-client" -version = "0.0.11" +version = "0.0.12" requires-python = ">=3.10" dependencies = [ "httpx>=0.24.1", From 79822bdebced4d65b0c24547214694d44b9bd4c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Samuel=20M=C3=BCller?= Date: Mon, 13 May 2024 14:56:48 +0200 Subject: [PATCH 6/6] updated validation link --- tabpfn_client/prompt_agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tabpfn_client/prompt_agent.py b/tabpfn_client/prompt_agent.py index 66f88d7..432fd34 100644 --- a/tabpfn_client/prompt_agent.py +++ b/tabpfn_client/prompt_agent.py @@ -51,7 +51,7 @@ def prompt_and_set_token(cls, user_auth_handler: "UserAuthenticationClient"): # Registration if choice == "1": # validation_link = input(cls.indent("Please enter your secret code: ")) - validation_link = "tabpfn-test" + validation_link = "tabpfn-2023" while True: email = input(cls.indent("Please enter your email: ")) # Send request to server to check if email is valid and not already taken.