Skip to content

Commit

Permalink
Increase redisvl test coverage (#109)
Browse files Browse the repository at this point in the history
Upped the coverage to 81% and fix a minor bug in the utils
  • Loading branch information
bsbodden authored Feb 5, 2024
1 parent 16c11e2 commit 8d38801
Show file tree
Hide file tree
Showing 7 changed files with 324 additions and 14 deletions.
9 changes: 0 additions & 9 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,15 +95,6 @@ def sample_data():
},
]

@pytest.fixture(scope="session")
def event_loop():
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
yield loop
loop.close()

@pytest.fixture
def clear_db(redis):
redis.flushall()
Expand Down
6 changes: 3 additions & 3 deletions redisvl/redis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ def convert_bytes(data: Any) -> Any:
except:
return data
if isinstance(data, dict):
return dict(map(convert_bytes, data.items()))
return {convert_bytes(key): convert_bytes(value) for key, value in data.items()}
if isinstance(data, list):
return list(map(convert_bytes, data))
return [convert_bytes(item) for item in data]
if isinstance(data, tuple):
return map(convert_bytes, data)
return tuple(convert_bytes(item) for item in data)
return data


Expand Down
55 changes: 54 additions & 1 deletion tests/integration/test_llmcache.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

from redisvl.extensions.llmcache import SemanticCache
from redisvl.utils.vectorize import HFTextVectorizer

from redisvl.index.index import SearchIndex
from collections import namedtuple

@pytest.fixture
def vectorizer():
Expand All @@ -18,6 +19,10 @@ def cache(vectorizer):
cache_instance.clear() # Clear cache after each test
cache_instance._index.delete(True) # Clean up index

@pytest.fixture
def cache_no_cleanup(vectorizer):
cache_instance = SemanticCache(vectorizer=vectorizer, distance_threshold=0.2)
yield cache_instance

@pytest.fixture
def cache_with_ttl(vectorizer):
Expand All @@ -26,6 +31,12 @@ def cache_with_ttl(vectorizer):
cache_instance.clear() # Clear cache after each test
cache_instance._index.delete(True) # Clean up index

@pytest.fixture
def cache_with_redis_client(vectorizer, client):
cache_instance = SemanticCache(vectorizer=vectorizer, redis_client=client, distance_threshold=0.2)
yield cache_instance
cache_instance.clear() # Clear cache after each test
cache_instance._index.delete(True) # Clean up index

# Test basic store and check functionality
def test_store_and_check(cache, vectorizer):
Expand Down Expand Up @@ -83,6 +94,10 @@ def test_check_invalid_input(cache):
with pytest.raises(TypeError):
cache.check(prompt="test", return_fields="bad value")

# Test handling invalid input for check method
def test_bad_ttl(cache):
with pytest.raises(ValueError):
cache.set_ttl(2.5)

# Test storing with metadata
def test_store_with_metadata(cache, vectorizer):
Expand All @@ -100,6 +115,16 @@ def test_store_with_metadata(cache, vectorizer):
assert check_result[0]["metadata"] == metadata
assert check_result[0]["prompt"] == prompt

# Test storing with invalid metadata
def test_store_with_invalid_metadata(cache, vectorizer):
prompt = "This is another test prompt."
response = "This is another test response."
metadata = namedtuple('metadata', 'source')(**{'source': 'test'})

vector = vectorizer.embed(prompt)

with pytest.raises(TypeError, match=r"If specified, cached metadata must be a dictionary."):
cache.store(prompt, response, vector=vector, metadata=metadata)

# Test setting and getting the distance threshold
def test_distance_threshold(cache):
Expand All @@ -110,6 +135,11 @@ def test_distance_threshold(cache):
assert cache.distance_threshold == new_threshold
assert cache.distance_threshold != initial_threshold

# Test out of range distance threshold
def test_distance_threshold_out_of_range(cache):
out_of_range_threshold = -1
with pytest.raises(ValueError):
cache.set_threshold(out_of_range_threshold)

# Test storing and retrieving multiple items
def test_multiple_items(cache, vectorizer):
Expand All @@ -130,3 +160,26 @@ def test_multiple_items(cache, vectorizer):
print(check_result, flush=True)
assert check_result[0]["response"] == expected_response
assert "metadata" not in check_result[0]

# Test retrieving underlying SearchIndex for the cache.
def test_get_index(cache):
assert isinstance(cache.index, SearchIndex)

# Test basic functionality with cache created with user-provided Redis client
def test_store_and_check_with_provided_client(cache_with_redis_client, vectorizer):
prompt = "This is a test prompt."
response = "This is a test response."
vector = vectorizer.embed(prompt)

cache_with_redis_client.store(prompt, response, vector=vector)
check_result = cache_with_redis_client.check(vector=vector)

assert len(check_result) == 1
print(check_result, flush=True)
assert response == check_result[0]["response"]
assert "metadata" not in check_result[0]

# Test deleting the cache
def test_delete(cache_no_cleanup, vectorizer):
cache_no_cleanup.delete()
assert not cache_no_cleanup.index.exists()
62 changes: 62 additions & 0 deletions tests/integration/test_search_results.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import pytest

from redisvl.index import SearchIndex
from redisvl.query import FilterQuery
from redisvl.query.filter import Tag

@pytest.fixture
def filter_query():
return FilterQuery(
return_fields=None,
filter_expression=Tag("credit_score") == "high",
)

@pytest.fixture
def index(sample_data):
fields_spec = [
{"name": "credit_score", "type": "tag"},
{"name": "user", "type": "tag"},
{"name": "job", "type": "text"},
{"name": "age", "type": "numeric"},
{
"name": "user_embedding",
"type": "vector",
"attrs": {
"dims": 3,
"distance_metric": "cosine",
"algorithm": "flat",
"datatype": "float32",
},
},
]

json_schema = {
"index": {
"name": "user_index_json",
"prefix": "users_json",
"storage_type": "json",
},
"fields": fields_spec,
}

# construct a search index from the schema
index = SearchIndex.from_dict(json_schema)

# connect to local redis instance
index.connect("redis://localhost:6379")

# create the index (no data yet)
index.create(overwrite=True)

# Prepare and load the data
index.load(sample_data)

# run the test
yield index

# clean up
index.delete(drop=True)

def test_process_results_unpacks_json_properly(index, filter_query):
results = index.query(filter_query)
assert len(results) == 4
34 changes: 34 additions & 0 deletions tests/unit/test_async_search_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from redisvl.index import AsyncSearchIndex
from redisvl.redis.utils import convert_bytes
from redisvl.schema import IndexSchema, StorageType
from redisvl.query import VectorQuery

fields = [{"name": "test", "type": "tag"}]

Expand Down Expand Up @@ -137,3 +138,36 @@ async def test_no_id_field(async_client, async_index):
# catch missing / invalid id_field
with pytest.raises(ValueError):
await async_index.load(bad_data, id_field="key")


@pytest.mark.asyncio
async def test_check_index_exists_before_delete(async_client, async_index):
async_index.set_client(async_client)
await async_index.create(overwrite=True, drop=True)
await async_index.delete(drop=True)
with pytest.raises(ValueError):
await async_index.delete()

@pytest.mark.asyncio
async def test_check_index_exists_before_search(async_client, async_index):
async_index.set_client(async_client)
await async_index.create(overwrite=True, drop=True)
await async_index.delete(drop=True)

query = VectorQuery(
[0.1, 0.1, 0.5],
"user_embedding",
return_fields=["user", "credit_score", "age", "job", "location"],
num_results=7,
)
with pytest.raises(ValueError):
await async_index.search(query.query, query_params=query.params)

@pytest.mark.asyncio
async def test_check_index_exists_before_info(async_client, async_index):
async_index.set_client(async_client)
await async_index.create(overwrite=True, drop=True)
await async_index.delete(drop=True)

with pytest.raises(ValueError):
await async_index.info()
45 changes: 44 additions & 1 deletion tests/unit/test_search_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from redisvl.index import SearchIndex
from redisvl.redis.utils import convert_bytes
from redisvl.schema import IndexSchema, StorageType
from redisvl.query import VectorQuery

fields = [{"name": "test", "type": "tag"}]

Expand All @@ -11,11 +12,13 @@
def index_schema():
return IndexSchema.from_dict({"index": {"name": "my_index"}, "fields": fields})


@pytest.fixture
def index(index_schema):
return SearchIndex(schema=index_schema)

@pytest.fixture
def index_from_yaml():
return SearchIndex.from_yaml("schemas/test_json_schema.yaml")

def test_search_index_properties(index_schema, index):
assert index.schema == index_schema
Expand All @@ -28,6 +31,13 @@ def test_search_index_properties(index_schema, index):
assert index.storage_type == index_schema.index.storage_type == StorageType.HASH
assert index.key("foo").startswith(index.prefix)

def test_search_index_from_yaml(index_from_yaml):
assert index_from_yaml.name == "json-test"
assert index_from_yaml.client == None
assert index_from_yaml.prefix == "json"
assert index_from_yaml.key_separator == ":"
assert index_from_yaml.storage_type == StorageType.JSON
assert index_from_yaml.key("foo").startswith(index_from_yaml.prefix)

def test_search_index_no_prefix(index_schema):
# specify an explicitly empty prefix...
Expand Down Expand Up @@ -118,3 +128,36 @@ def test_no_id_field(client, index):
# catch missing / invalid id_field
with pytest.raises(ValueError):
index.load(bad_data, id_field="key")

def test_check_index_exists_before_delete(client, index):
index.set_client(client)
index.create(overwrite=True, drop=True)
index.delete(drop=True)
with pytest.raises(ValueError):
index.delete()

def test_check_index_exists_before_search(client, index):
index.set_client(client)
index.create(overwrite=True, drop=True)
index.delete(drop=True)

query = VectorQuery(
[0.1, 0.1, 0.5],
"user_embedding",
return_fields=["user", "credit_score", "age", "job", "location"],
num_results=7,
)
with pytest.raises(ValueError):
index.search(query.query, query_params=query.params)

def test_check_index_exists_before_info(client, index):
index.set_client(client)
index.create(overwrite=True, drop=True)
index.delete(drop=True)

with pytest.raises(ValueError):
index.info()

def test_index_needs_valid_schema():
with pytest.raises(ValueError, match=r"Must provide a valid IndexSchema object"):
index = SearchIndex(schema="Not A Valid Schema")
Loading

0 comments on commit 8d38801

Please sign in to comment.