diff --git a/conftest.py b/conftest.py
index 622ae345..2db79ec0 100644
--- a/conftest.py
+++ b/conftest.py
@@ -95,15 +95,6 @@ def sample_data():
         },
     ]
 
-@pytest.fixture(scope="session")
-def event_loop():
-    try:
-        loop = asyncio.get_running_loop()
-    except RuntimeError:
-        loop = asyncio.new_event_loop()
-    yield loop
-    loop.close()
-
 @pytest.fixture
 def clear_db(redis):
     redis.flushall()
diff --git a/redisvl/redis/utils.py b/redisvl/redis/utils.py
index e8108bbb..29ad5f78 100644
--- a/redisvl/redis/utils.py
+++ b/redisvl/redis/utils.py
@@ -21,11 +21,11 @@ def convert_bytes(data: Any) -> Any:
     except:
         return data
     if isinstance(data, dict):
-        return dict(map(convert_bytes, data.items()))
+        return {convert_bytes(key): convert_bytes(value) for key, value in data.items()}
     if isinstance(data, list):
-        return list(map(convert_bytes, data))
+        return [convert_bytes(item) for item in data]
    if isinstance(data, tuple):
-        return map(convert_bytes, data)
+        return tuple(convert_bytes(item) for item in data)
     return data
diff --git a/tests/integration/test_llmcache.py b/tests/integration/test_llmcache.py
index f406498d..de7495b0 100644
--- a/tests/integration/test_llmcache.py
+++ b/tests/integration/test_llmcache.py
@@ -4,7 +4,8 @@
 from redisvl.extensions.llmcache import SemanticCache
 from redisvl.utils.vectorize import HFTextVectorizer
-
+from redisvl.index.index import SearchIndex
+from collections import namedtuple
 
 @pytest.fixture
 def vectorizer():
@@ -18,6 +19,10 @@ def cache(vectorizer):
     cache_instance.clear()  # Clear cache after each test
     cache_instance._index.delete(True)  # Clean up index
 
+@pytest.fixture
+def cache_no_cleanup(vectorizer):
+    cache_instance = SemanticCache(vectorizer=vectorizer, distance_threshold=0.2)
+    yield cache_instance
 
 @pytest.fixture
 def cache_with_ttl(vectorizer):
@@ -26,6 +31,12 @@ def cache_with_ttl(vectorizer):
     cache_instance.clear()  # Clear cache after each test
     cache_instance._index.delete(True)  # Clean up index
 
+@pytest.fixture
+def cache_with_redis_client(vectorizer, client):
+    cache_instance = SemanticCache(vectorizer=vectorizer, redis_client=client, distance_threshold=0.2)
+    yield cache_instance
+    cache_instance.clear()  # Clear cache after each test
+    cache_instance._index.delete(True)  # Clean up index
 
 # Test basic store and check functionality
 def test_store_and_check(cache, vectorizer):
@@ -83,6 +94,10 @@ def test_check_invalid_input(cache):
     with pytest.raises(TypeError):
         cache.check(prompt="test", return_fields="bad value")
 
+# Test handling an invalid TTL value
+def test_bad_ttl(cache):
+    with pytest.raises(ValueError):
+        cache.set_ttl(2.5)
 
 # Test storing with metadata
 def test_store_with_metadata(cache, vectorizer):
@@ -100,6 +115,16 @@ def test_store_with_metadata(cache, vectorizer):
     assert check_result[0]["metadata"] == metadata
     assert check_result[0]["prompt"] == prompt
 
+# Test storing with invalid metadata
+def test_store_with_invalid_metadata(cache, vectorizer):
+    prompt = "This is another test prompt."
+    response = "This is another test response."
+    metadata = namedtuple('metadata', 'source')(**{'source': 'test'})
+
+    vector = vectorizer.embed(prompt)
+
+    with pytest.raises(TypeError, match=r"If specified, cached metadata must be a dictionary."):
+        cache.store(prompt, response, vector=vector, metadata=metadata)
 
 # Test setting and getting the distance threshold
 def test_distance_threshold(cache):
@@ -110,6 +135,11 @@ def test_distance_threshold(cache):
     assert cache.distance_threshold == new_threshold
     assert cache.distance_threshold != initial_threshold
 
+# Test out of range distance threshold
+def test_distance_threshold_out_of_range(cache):
+    out_of_range_threshold = -1
+    with pytest.raises(ValueError):
+        cache.set_threshold(out_of_range_threshold)
 
 # Test storing and retrieving multiple items
 def test_multiple_items(cache, vectorizer):
@@ -130,3 +160,26 @@ def test_multiple_items(cache, vectorizer):
     print(check_result, flush=True)
     assert check_result[0]["response"] == expected_response
     assert "metadata" not in check_result[0]
+
+# Test retrieving underlying SearchIndex for the cache.
+def test_get_index(cache):
+    assert isinstance(cache.index, SearchIndex)
+
+# Test basic functionality with cache created with user-provided Redis client
+def test_store_and_check_with_provided_client(cache_with_redis_client, vectorizer):
+    prompt = "This is a test prompt."
+    response = "This is a test response."
+    vector = vectorizer.embed(prompt)
+
+    cache_with_redis_client.store(prompt, response, vector=vector)
+    check_result = cache_with_redis_client.check(vector=vector)
+
+    assert len(check_result) == 1
+    print(check_result, flush=True)
+    assert response == check_result[0]["response"]
+    assert "metadata" not in check_result[0]
+
+# Test deleting the cache
+def test_delete(cache_no_cleanup, vectorizer):
+    cache_no_cleanup.delete()
+    assert not cache_no_cleanup.index.exists()
\ No newline at end of file
diff --git a/tests/integration/test_search_results.py b/tests/integration/test_search_results.py
new file mode 100644
index 00000000..15b2048a
--- /dev/null
+++ b/tests/integration/test_search_results.py
@@ -0,0 +1,62 @@
+import pytest
+
+from redisvl.index import SearchIndex
+from redisvl.query import FilterQuery
+from redisvl.query.filter import Tag
+
+@pytest.fixture
+def filter_query():
+    return FilterQuery(
+        return_fields=None,
+        filter_expression=Tag("credit_score") == "high",
+    )
+
+@pytest.fixture
+def index(sample_data):
+    fields_spec = [
+        {"name": "credit_score", "type": "tag"},
+        {"name": "user", "type": "tag"},
+        {"name": "job", "type": "text"},
+        {"name": "age", "type": "numeric"},
+        {
+            "name": "user_embedding",
+            "type": "vector",
+            "attrs": {
+                "dims": 3,
+                "distance_metric": "cosine",
+                "algorithm": "flat",
+                "datatype": "float32",
+            },
+        },
+    ]
+
+    json_schema = {
+        "index": {
+            "name": "user_index_json",
+            "prefix": "users_json",
+            "storage_type": "json",
+        },
+        "fields": fields_spec,
+    }
+
+    # construct a search index from the schema
+    index = SearchIndex.from_dict(json_schema)
+
+    # connect to local redis instance
+    index.connect("redis://localhost:6379")
+
+    # create the index (no data yet)
+    index.create(overwrite=True)
+
+    # Prepare and load the data
+    index.load(sample_data)
+
+    # run the test
+    yield index
+
+    # clean up
+    index.delete(drop=True)
+
+def test_process_results_unpacks_json_properly(index, filter_query):
+    results = index.query(filter_query)
+    assert len(results) == 4
\ No newline at end of file
diff --git a/tests/unit/test_async_search_index.py b/tests/unit/test_async_search_index.py
index c2b4c359..1676ff7d 100644
--- a/tests/unit/test_async_search_index.py
+++ b/tests/unit/test_async_search_index.py
@@ -3,6 +3,7 @@
 from redisvl.index import AsyncSearchIndex
 from redisvl.redis.utils import convert_bytes
 from redisvl.schema import IndexSchema, StorageType
+from redisvl.query import VectorQuery
 
 fields = [{"name": "test", "type": "tag"}]
 
@@ -137,3 +138,36 @@ async def test_no_id_field(async_client, async_index):
     # catch missing / invalid id_field
     with pytest.raises(ValueError):
         await async_index.load(bad_data, id_field="key")
+
+
+@pytest.mark.asyncio
+async def test_check_index_exists_before_delete(async_client, async_index):
+    async_index.set_client(async_client)
+    await async_index.create(overwrite=True, drop=True)
+    await async_index.delete(drop=True)
+    with pytest.raises(ValueError):
+        await async_index.delete()
+
+@pytest.mark.asyncio
+async def test_check_index_exists_before_search(async_client, async_index):
+    async_index.set_client(async_client)
+    await async_index.create(overwrite=True, drop=True)
+    await async_index.delete(drop=True)
+
+    query = VectorQuery(
+        [0.1, 0.1, 0.5],
+        "user_embedding",
+        return_fields=["user", "credit_score", "age", "job", "location"],
+        num_results=7,
+    )
+    with pytest.raises(ValueError):
+        await async_index.search(query.query, query_params=query.params)
+
+@pytest.mark.asyncio
+async def test_check_index_exists_before_info(async_client, async_index):
+    async_index.set_client(async_client)
+    await async_index.create(overwrite=True, drop=True)
+    await async_index.delete(drop=True)
+
+    with pytest.raises(ValueError):
+        await async_index.info()
diff --git a/tests/unit/test_search_index.py b/tests/unit/test_search_index.py
index 5b98164d..4e83c9ab 100644
--- a/tests/unit/test_search_index.py
+++ b/tests/unit/test_search_index.py
@@ -3,6 +3,7 @@
 from redisvl.index import SearchIndex
 from redisvl.redis.utils import convert_bytes
 from redisvl.schema import IndexSchema, StorageType
+from redisvl.query import VectorQuery
 
 fields = [{"name": "test", "type": "tag"}]
 
@@ -11,11 +12,13 @@ def index_schema():
     return IndexSchema.from_dict({"index": {"name": "my_index"}, "fields": fields})
 
-
 @pytest.fixture
 def index(index_schema):
     return SearchIndex(schema=index_schema)
 
+@pytest.fixture
+def index_from_yaml():
+    return SearchIndex.from_yaml("schemas/test_json_schema.yaml")
 
 def test_search_index_properties(index_schema, index):
     assert index.schema == index_schema
@@ -28,6 +31,13 @@ def test_search_index_properties(index_schema, index):
     assert index.storage_type == index_schema.index.storage_type == StorageType.HASH
     assert index.key("foo").startswith(index.prefix)
 
+def test_search_index_from_yaml(index_from_yaml):
+    assert index_from_yaml.name == "json-test"
+    assert index_from_yaml.client == None
+    assert index_from_yaml.prefix == "json"
+    assert index_from_yaml.key_separator == ":"
+    assert index_from_yaml.storage_type == StorageType.JSON
+    assert index_from_yaml.key("foo").startswith(index_from_yaml.prefix)
 
 def test_search_index_no_prefix(index_schema):
     # specify an explicitly empty prefix...
@@ -118,3 +128,36 @@ def test_no_id_field(client, index):
     # catch missing / invalid id_field
     with pytest.raises(ValueError):
         index.load(bad_data, id_field="key")
+
+def test_check_index_exists_before_delete(client, index):
+    index.set_client(client)
+    index.create(overwrite=True, drop=True)
+    index.delete(drop=True)
+    with pytest.raises(ValueError):
+        index.delete()
+
+def test_check_index_exists_before_search(client, index):
+    index.set_client(client)
+    index.create(overwrite=True, drop=True)
+    index.delete(drop=True)
+
+    query = VectorQuery(
+        [0.1, 0.1, 0.5],
+        "user_embedding",
+        return_fields=["user", "credit_score", "age", "job", "location"],
+        num_results=7,
+    )
+    with pytest.raises(ValueError):
+        index.search(query.query, query_params=query.params)
+
+def test_check_index_exists_before_info(client, index):
+    index.set_client(client)
+    index.create(overwrite=True, drop=True)
+    index.delete(drop=True)
+
+    with pytest.raises(ValueError):
+        index.info()
+
+def test_index_needs_valid_schema():
+    with pytest.raises(ValueError, match=r"Must provide a valid IndexSchema object"):
+        index = SearchIndex(schema="Not A Valid Schema")
\ No newline at end of file
diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py
new file mode 100644
index 00000000..9a3c45db
--- /dev/null
+++ b/tests/unit/test_utils.py
@@ -0,0 +1,127 @@
+import pytest
+import numpy as np
+from redisvl.redis.utils import make_dict, buffer_to_array, convert_bytes, array_to_buffer
+
+def test_even_number_of_elements():
+    """Test with an even number of elements"""
+    values = ['key1', 'value1', 'key2', 'value2']
+    expected = {'key1': 'value1', 'key2': 'value2'}
+    assert make_dict(values) == expected
+
+def test_odd_number_of_elements():
+    """Test with an odd number of elements - expecting the last element to be ignored"""
+    values = ['key1', 'value1', 'key2']
+    expected = {'key1': 'value1'}  # 'key2' has no pair, so it's ignored
+    assert make_dict(values) == expected
+
+def test_different_data_types():
+    """Test with different data types as keys and values"""
+    values = [1, 'one', 2.0, 'two']
+    expected = {1: 'one', 2.0: 'two'}
+    assert make_dict(values) == expected
+
+def test_empty_list():
+    """Test with an empty list"""
+    values = []
+    expected = {}
+    assert make_dict(values) == expected
+
+def test_with_complex_objects():
+    """Test with complex objects like lists and dicts as values"""
+    key = 'a list'
+    value = [1, 2, 3]
+    values = [key, value]
+    expected = {key: value}
+    assert make_dict(values) == expected
+
+def test_simple_byte_buffer_to_floats():
+    """Test conversion of a simple byte buffer into floats"""
+    buffer = np.array([1.0, 2.0, 3.0], dtype=np.float32).tobytes()
+    expected = [1.0, 2.0, 3.0]
+    assert buffer_to_array(buffer, dtype=np.float32) == expected
+
+def test_buffer_to_array_different_data_types():
+    """Test conversion with different data types"""
+    # Integer test
+    buffer = np.array([1, 2, 3], dtype=np.int32).tobytes()
+    expected = [1, 2, 3]
+    assert buffer_to_array(buffer, dtype=np.int32) == expected
+
+    # Float64 test
+    buffer = np.array([1.0, 2.0, 3.0], dtype=np.float64).tobytes()
+    expected = [1.0, 2.0, 3.0]
+    assert buffer_to_array(buffer, dtype=np.float64) == expected
+
+def test_empty_byte_buffer():
+    """Test conversion of an empty byte buffer"""
+    buffer = b''
+    expected = []
+    assert buffer_to_array(buffer, dtype=np.float32) == expected
+
+def test_plain_bytes_to_string():
+    """Test conversion of plain bytes to string"""
+    data = b'hello world'
+    expected = 'hello world'
+    assert convert_bytes(data) == expected
+
+def test_bytes_in_dict():
+    """Test conversion of bytes in a dictionary, including nested dictionaries"""
+    data = {'key': b'value', 'nested': {'nkey': b'nvalue'}}
+    expected = {'key': 'value', 'nested': {'nkey': 'nvalue'}}
+    assert convert_bytes(data) == expected
+
+def test_bytes_in_list():
+    """Test conversion of bytes in a list, including nested lists"""
+    data = [b'item1', b'item2', ['nested', b'nested item']]
+    expected = ['item1', 'item2', ['nested', 'nested item']]
+    assert convert_bytes(data) == expected
+
+def test_bytes_in_tuple():
+    """Test conversion of bytes in a tuple, including nested tuples"""
+    data = (b'item1', b'item2', ('nested', b'nested item'))
+    expected = ('item1', 'item2', ('nested', 'nested item'))
+    assert convert_bytes(data) == expected
+
+def test_non_bytes_data():
+    """Test handling of non-bytes data types"""
+    data = 'already a string'
+    expected = 'already a string'
+    assert convert_bytes(data) == expected
+
+def test_bytes_with_invalid_utf8():
+    """Test handling bytes that cannot be decoded with UTF-8"""
+    data = b'\xff\xff'  # Invalid in UTF-8
+    expected = data
+    assert convert_bytes(data) == expected
+
+def test_simple_list_to_bytes_default_dtype():
+    """Test conversion of a simple list of floats to bytes using the default dtype"""
+    array = [1.0, 2.0, 3.0]
+    expected = np.array(array, dtype=np.float32).tobytes()
+    assert array_to_buffer(array) == expected
+
+def test_list_to_bytes_non_default_dtype():
+    """Test conversion with a non-default dtype"""
+    array = [1.0, 2.0, 3.0]
+    dtype = np.float64
+    expected = np.array(array, dtype=dtype).tobytes()
+    assert array_to_buffer(array, dtype=dtype) == expected
+
+def test_empty_list_to_bytes():
+    """Test conversion of an empty list"""
+    array = []
+    expected = np.array(array, dtype=np.float32).tobytes()
+    assert array_to_buffer(array) == expected
+
+@pytest.mark.parametrize("dtype", [np.int32, np.float64])
+def test_conversion_with_various_dtypes(dtype):
+    """Test conversion of a list of floats to bytes with various dtypes"""
+    array = [1.0, -2.0, 3.5]
+    expected = np.array(array, dtype=dtype).tobytes()
+    assert array_to_buffer(array, dtype=dtype) == expected
+
+def test_conversion_with_invalid_floats():
+    """Test conversion with invalid float values (numpy should handle them)"""
+    array = [float('inf'), float('-inf'), float('nan')]
+    result = array_to_buffer(array)
+    assert len(result) > 0  # Simple check to ensure it returns anything