Skip to content

Commit

Permalink
more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
LeoGrin committed Jan 14, 2025
1 parent 07075fb commit f9f1680
Show file tree
Hide file tree
Showing 3 changed files with 259 additions and 14 deletions.
26 changes: 23 additions & 3 deletions tabpfn_client/tests/unit/test_tabpfn_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ def test_string_label_predictions(self):
y_pred.dtype.kind in {"U", "O"}, "Predictions should be string type"
)

def test_predict_with_long_and_comma_text(self):
def test_predict_with_uncleaned_text(self):
"""Test predictions with long text (>2500 chars) and text containing commas."""
# Skip initialization
tabpfn = TabPFNClassifier()
Expand All @@ -495,20 +495,26 @@ def test_predict_with_long_and_comma_text(self):
base_text = "very " * 500 # 2500 characters
long_text = base_text + " extra text that should be truncated"
text_with_commas = "very, " * 500 # Same length but with commas
text_with_spaces = (
"text\n\n with\t\tweird spaces\r\nand\n\n\nlinebreaks" * 100
) # Text with various whitespace

# Create variations of the same data with different text
X_normal = [row + [base_text] for row in X]
X_long = [row + [long_text] for row in X]
X_commas = [row + [text_with_commas] for row in X]
X_spaces = [row + [text_with_spaces] for row in X]

# Convert to numpy arrays and make copies for comparison
X_normal_array = np.array(X_normal)
X_long_array = np.array(X_long)
X_commas_array = np.array(X_commas)
X_spaces_array = np.array(X_spaces)

X_normal_copy = X_normal_array.copy()
X_long_copy = X_long_array.copy()
X_commas_copy = X_commas_array.copy()
X_spaces_copy = X_spaces_array.copy()

# Mock predictions
expected_predictions = np.random.randint(0, 2, n_samples)
Expand All @@ -525,6 +531,7 @@ def test_predict_with_long_and_comma_text(self):
pred_normal = tabpfn.predict(X_normal_array)
pred_long = tabpfn.predict(X_long_array)
pred_commas = tabpfn.predict(X_commas_array)
pred_spaces = tabpfn.predict(X_spaces_array)

# Verify input arrays were not modified
np.testing.assert_array_equal(
Expand All @@ -542,11 +549,17 @@ def test_predict_with_long_and_comma_text(self):
X_commas_copy,
"Input array with comma text was modified during prediction",
)
np.testing.assert_array_equal(
X_spaces_array,
X_spaces_copy,
"Input array with special spaces was modified during prediction",
)

# Verify predictions are returned as expected
np.testing.assert_array_equal(pred_normal, expected_predictions)
np.testing.assert_array_equal(pred_long, expected_predictions)
np.testing.assert_array_equal(pred_commas, expected_predictions)
np.testing.assert_array_equal(pred_spaces, expected_predictions)

# Verify that long text (which should be truncated) gives same predictions as normal text
np.testing.assert_array_equal(pred_normal, pred_long)
Expand All @@ -558,6 +571,7 @@ def test_predict_with_long_and_comma_text(self):
proba_normal = tabpfn.predict_proba(X_normal_array)
proba_long = tabpfn.predict_proba(X_long_array)
proba_commas = tabpfn.predict_proba(X_commas_array)
proba_spaces = tabpfn.predict_proba(X_spaces_array)

# Verify input arrays were not modified during predict_proba
np.testing.assert_array_equal(
Expand All @@ -575,17 +589,23 @@ def test_predict_with_long_and_comma_text(self):
X_commas_copy,
"Input array with comma text was modified during probability prediction",
)
np.testing.assert_array_equal(
X_spaces_array,
X_spaces_copy,
"Input array with special spaces was modified during probability prediction",
)

# Verify probability predictions are returned as expected
np.testing.assert_array_equal(proba_normal, expected_probas)
np.testing.assert_array_equal(proba_long, expected_probas)
np.testing.assert_array_equal(proba_commas, expected_probas)
np.testing.assert_array_equal(proba_spaces, expected_probas)

# Verify that long text gives same probability predictions as normal text
np.testing.assert_array_equal(proba_normal, proba_long)

# Verify predict and predict_proba were each called 3 times
self.assertEqual(mock_predict.call_count, 6)
# Verify predict and predict_proba were each called 4 times
self.assertEqual(mock_predict.call_count, 8)

def test_predict_with_pandas_dataframe(self):
"""Test predictions with pandas DataFrame input, including text columns."""
Expand Down
54 changes: 43 additions & 11 deletions tabpfn_client/tests/unit/test_tabpfn_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,36 +457,68 @@ def test_predict_with_long_and_comma_text(self):
base_text = "very " * 500 # 2500 characters
long_text = base_text + " extra text that should be truncated"
text_with_commas = "very, " * 500 # Same length but with commas
text_with_spaces = (
"text\n\n with\t\tweird spaces\r\nand\n\n\nlinebreaks" * 100
) # Text with various whitespace

# Create variations of the same data with different text
X_normal = [row + [base_text] for row in X]
X_long = [row + [long_text] for row in X]
X_commas = [row + [text_with_commas] for row in X]
X_spaces = [row + [text_with_spaces] for row in X]

# Convert to numpy arrays and make copies for comparison
X_normal_array = np.array(X_normal)
X_long_array = np.array(X_long)
X_commas_array = np.array(X_commas)
X_spaces_array = np.array(X_spaces)

X_normal_copy = X_normal_array.copy()
X_long_copy = X_long_array.copy()
X_commas_copy = X_commas_array.copy()
X_spaces_copy = X_spaces_array.copy()

# Mock predictions
expected_predictions = np.random.randn(n_samples)
with patch.object(InferenceClient, "predict") as mock_predict:
mock_predict.return_value = expected_predictions

# Test predictions for each variation
pred_normal = tabpfn.predict(np.array(X_normal))
pred_long = tabpfn.predict(np.array(X_long))
pred_commas = tabpfn.predict(np.array(X_commas))
pred_normal = tabpfn.predict(X_normal_array)
pred_long = tabpfn.predict(X_long_array)
pred_commas = tabpfn.predict(X_commas_array)
pred_spaces = tabpfn.predict(X_spaces_array)

# Verify input arrays were not modified
np.testing.assert_array_equal(
X_normal_array,
X_normal_copy,
"Input array with normal text was modified during prediction",
)
np.testing.assert_array_equal(
X_long_array,
X_long_copy,
"Input array with long text was modified during prediction",
)
np.testing.assert_array_equal(
X_commas_array,
X_commas_copy,
"Input array with comma text was modified during prediction",
)
np.testing.assert_array_equal(
X_spaces_array,
X_spaces_copy,
"Input array with special spaces was modified during prediction",
)

# Verify predictions are returned as expected
print("pred_normal", pred_normal)
print("pred_long", pred_long)
print("pred_commas", pred_commas)
print("expected_predictions", expected_predictions)
np.testing.assert_array_equal(pred_normal, expected_predictions)
np.testing.assert_array_equal(pred_long, expected_predictions)
np.testing.assert_array_equal(pred_commas, expected_predictions)

# Verify that long text (which should be truncated) gives same predictions as normal text
np.testing.assert_array_equal(pred_normal, pred_long)
np.testing.assert_array_equal(pred_spaces, expected_predictions)

# Verify predict was called the same way for all variations
self.assertEqual(mock_predict.call_count, 3)
self.assertEqual(mock_predict.call_count, 4)

def test_predict_with_pandas_dataframe(self):
"""Test predictions with pandas DataFrame input, including text columns."""
Expand Down
193 changes: 193 additions & 0 deletions tabpfn_client/tests/unit/test_text_features.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
import unittest
import numpy as np
import pandas as pd
from tabpfn_client.estimator import _clean_text_features
from io import BytesIO
from tabpfn_client.tabpfn_common_utils import utils


class TestCleanTextFeatures(unittest.TestCase):
    """Unit tests for ``_clean_text_features``.

    The tests pin the cleaning behaviour these assertions rely on: commas are
    removed from text cells, runs of whitespace are collapsed, over-long text
    is truncated, numeric data and missing values pass through, and the
    container type (``np.ndarray`` / ``pd.DataFrame``) is preserved.
    """

    # Length every over-long text cell is expected to be truncated to.
    # NOTE(review): mirrors the limit asserted throughout these tests; keep it
    # in sync with the limit used by the implementation.
    MAX_TEXT_LENGTH = 2500

    # Two consecutive spaces: must never survive whitespace normalization.
    # A single space is legitimate in cleaned text, so it is NOT checked for.
    DOUBLE_SPACE = "  "

    def test_numeric_numpy_array_unchanged(self):
        """Numeric numpy arrays should be returned as-is."""
        X = np.array([[1.0, 2.0], [3.0, 4.0]])

        result = _clean_text_features(X)

        np.testing.assert_array_equal(X, result)
        self.assertIs(type(result), np.ndarray)

    def test_object_numpy_array_cleaning(self):
        """Numpy arrays holding text should have every cell cleaned."""
        # NOTE(review): numpy infers a fixed-width unicode dtype (not
        # ``object``) for an all-string nested list -- confirm that is the
        # input kind this test is meant to cover.
        X = np.array(
            [
                # Runs of spaces on purpose, so the whitespace check below
                # actually exercises normalization.
                ["text1,with,commas   and  spaces", "short  text"],
                ["a" * 3000, "text2,more,commas   here"],
            ]
        )

        result = _clean_text_features(X)

        self.assertIs(type(result), np.ndarray)
        # Check comma removal
        self.assertNotIn(",", result[0, 0])
        # Check multiple spaces are normalized
        self.assertNotIn(self.DOUBLE_SPACE, result[0, 0])
        self.assertNotIn(self.DOUBLE_SPACE, result[0, 1])
        # Check text truncation
        self.assertEqual(len(result[1, 0]), self.MAX_TEXT_LENGTH)

    def test_pandas_dataframe_cleaning(self):
        """A DataFrame with mixed numeric and text columns is cleaned per column."""
        df = pd.DataFrame(
            {
                "numeric": [1.0, 2.0],
                "text": [
                    "text1,with,commas   and  spaces",
                    "text2,with,commas\n\nspaces",
                ],
                "long_text": ["a" * 3000, "b " * 750],
            }
        )

        result = _clean_text_features(df)

        self.assertIs(type(result), pd.DataFrame)
        # Numeric column should be unchanged
        np.testing.assert_array_equal(result["numeric"], df["numeric"])
        # Text columns should be cleaned
        self.assertNotIn(",", result["text"].iloc[0])
        self.assertNotIn(self.DOUBLE_SPACE, result["text"].iloc[0])
        self.assertNotIn("\n\n", result["text"].iloc[1])
        self.assertEqual(len(result["long_text"].iloc[0]), self.MAX_TEXT_LENGTH)

    def test_mixed_content_dataframe(self):
        """Columns mixing text and non-text values are still cleaned."""
        df = pd.DataFrame(
            {
                "mixed": ["text,with,comma", 123, "another,comma"],
                "numeric_as_string": ["123", "456", "789"],
            }
        )

        result = _clean_text_features(df)

        # Check that numeric strings are preserved
        self.assertEqual(result["numeric_as_string"].iloc[0], "123")
        # Check that text is cleaned
        self.assertNotIn(",", result["mixed"].iloc[0])

    def test_null_values_handling(self):
        """Null values in text and numeric columns must stay null."""
        df = pd.DataFrame(
            {
                "text": ["text,with,comma", None, np.nan],
                "numeric": [1.0, None, np.nan],
            }
        )

        result = _clean_text_features(df)

        # Verify null values are preserved
        for column in ("text", "numeric"):
            for row in (1, 2):
                with self.subTest(column=column, row=row):
                    self.assertTrue(pd.isna(result[column].iloc[row]))

    def test_numpy_array_with_missing_values(self):
        """Text cleaning works on arrays with missing values interspersed."""
        X = np.array(
            [
                ["long," * 1000 + "text", None],
                [np.nan, "short,text"],
                ["medium,text", ""],
            ]
        )

        result = _clean_text_features(X)

        self.assertIs(type(result), np.ndarray)
        # Check text cleaning still works with missing values present
        self.assertNotIn(",", result[0, 0])
        self.assertNotIn(",", result[1, 1])
        # Check missing values are preserved
        self.assertTrue(pd.isna(result[0, 1]))
        self.assertTrue(pd.isna(result[1, 0]))
        # Check empty string is preserved
        self.assertEqual(result[2, 1], "")
        # Check long text truncation still works
        self.assertEqual(len(result[0, 0]), self.MAX_TEXT_LENGTH)

    def test_dataframe_with_text_and_missing_values(self):
        """Each flavour of missing value (None, np.nan, pd.NA) is preserved."""
        df = pd.DataFrame(
            {
                "none_nulls": [
                    "long," * 1000 + "text",
                    None,
                    "text,with,commas",
                    None,
                    "",
                ],
                "numpy_nulls": [
                    "short,text",
                    np.nan,
                    "more,commas",
                    np.nan,
                    "last,text",
                ],
                "pandas_nulls": [
                    "first,text",
                    pd.NA,
                    "middle,text",
                    pd.NA,
                    "end,text",
                ],
                "mixed_nulls": [None, np.nan, pd.NA, "some,text", ""],
            }
        )

        result = _clean_text_features(df)

        self.assertIs(type(result), pd.DataFrame)
        # Check text cleaning still works for each column
        self.assertNotIn(",", result["none_nulls"].iloc[0])
        self.assertNotIn(",", result["numpy_nulls"].iloc[0])
        self.assertNotIn(",", result["pandas_nulls"].iloc[0])
        self.assertNotIn(",", result["mixed_nulls"].iloc[3])

        # Check None values are preserved
        self.assertTrue(pd.isna(result["none_nulls"].iloc[1]))
        self.assertTrue(pd.isna(result["none_nulls"].iloc[3]))

        # Check np.nan values are preserved
        self.assertTrue(pd.isna(result["numpy_nulls"].iloc[1]))
        self.assertTrue(pd.isna(result["numpy_nulls"].iloc[3]))

        # Check pd.NA values are preserved
        self.assertTrue(pd.isna(result["pandas_nulls"].iloc[1]))
        self.assertTrue(pd.isna(result["pandas_nulls"].iloc[3]))

        # Check mixed null types are preserved
        self.assertTrue(pd.isna(result["mixed_nulls"].iloc[0]))  # None
        self.assertTrue(pd.isna(result["mixed_nulls"].iloc[1]))  # np.nan
        self.assertTrue(pd.isna(result["mixed_nulls"].iloc[2]))  # pd.NA

        # Check empty strings are preserved
        self.assertEqual(result["none_nulls"].iloc[4], "")
        self.assertEqual(result["mixed_nulls"].iloc[4], "")

        # Check long text truncation
        self.assertEqual(
            len(result["none_nulls"].iloc[0]), self.MAX_TEXT_LENGTH
        )

    def test_serialization_with_cleaned_text(self):
        """Cleaned data must survive a CSV serialize/deserialize round trip."""
        test_data = pd.DataFrame(
            {
                "text": [
                    "text1,with,commas   and  spaces",
                    "text2,with,commas\n\nspaces",
                ],
                "numeric": [1.0, 2.0],
            }
        )

        cleaned_data = _clean_text_features(test_data)

        serialized = utils.serialize_to_csv_formatted_bytes(cleaned_data)
        # TODO: I think this serialization is not exactly what's happening on
        # the server.
        deserialized = pd.read_csv(BytesIO(serialized), delimiter=",")

        pd.testing.assert_frame_equal(cleaned_data, deserialized)

        # Verify text was properly cleaned and remained clean after
        # serialization. The frame-equality check above guarantees that
        # ``deserialized`` has one row per input row.
        for row, text in enumerate(deserialized["text"]):
            with self.subTest(row=row):
                self.assertNotIn(",", text)
                self.assertNotIn(self.DOUBLE_SPACE, text)
                self.assertNotIn("\n", text)

0 comments on commit f9f1680

Please sign in to comment.