Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: Issue with search dialect 3 and JSON (resolves #140) #151

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 39 additions & 1 deletion redisvl/index/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,38 @@
logger = get_logger(__name__)


def _handle_dialect_3(result: Dict[str, Any]) -> Dict[str, Any]:
"""
Handle dialect 3 responses by converting JSON-encoded list values to strings.

Each JSON-encoded string in the result that is a list will be converted:
- If the list has one item, it is unpacked.
- If the list has multiple items, they are joined into a single comma-separated string.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems like this is the desired behavior from @tylerhutcherson but I feel like I'd rather have a list than a comma-separated string in this situation. I may be missing some deeper context though

Copy link
Collaborator Author

@bsbodden bsbodden May 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tylerhutcherson what do you think?


Args:
result (Dict[str, Any]): The dictionary containing the results to process.

Returns:
Dict[str, Any]: The processed dictionary with updated values.
"""
for field, value in result.items():
if isinstance(value, str):
try:
parsed_value = json.loads(value)
except json.JSONDecodeError:
continue # Skip processing if value is not valid JSON

if isinstance(parsed_value, list):
# Use a single value if the list contains only one item, else join all items.
result[field] = (
parsed_value[0]
if len(parsed_value) == 1
else ", ".join(map(str, parsed_value))
)

return result


def process_results(
results: "Result", query: BaseQuery, storage_type: StorageType
) -> List[Dict[str, Any]]:
Expand Down Expand Up @@ -81,7 +113,13 @@ def _process(doc: "Document") -> Dict[str, Any]:

return doc_dict

return [_process(doc) for doc in results.docs]
processed_results = [_process(doc) for doc in results.docs]

# Handle dialect 3 responses
if query._dialect == 3:
processed_results = [_handle_dialect_3(result) for result in processed_results]

return processed_results


def check_modules_present():
Expand Down
118 changes: 118 additions & 0 deletions tests/integration/test_dialects.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import json

import pytest
from redis import Redis
from redis.commands.search.query import Query

from redisvl.index import SearchIndex
from redisvl.query import FilterQuery, VectorQuery
from redisvl.query.filter import Tag
from redisvl.schema.schema import IndexSchema


@pytest.fixture
def sample_data():
return [
{
"name": "Noise-cancelling Bluetooth headphones",
"description": "Wireless Bluetooth headphones with noise-cancelling technology",
"connection": {"wireless": True, "type": "Bluetooth"},
"price": 99.98,
"stock": 25,
"colors": ["black", "silver"],
"embedding": [0.87, -0.15, 0.55, 0.03],
"embeddings": [[0.56, -0.34, 0.69, 0.02], [0.94, -0.23, 0.45, 0.19]],
},
{
"name": "Wireless earbuds",
"description": "Wireless Bluetooth in-ear headphones",
"connection": {"wireless": True, "type": "Bluetooth"},
"price": 64.99,
"stock": 17,
"colors": ["red", "black", "white"],
"embedding": [-0.7, -0.51, 0.88, 0.14],
"embeddings": [[0.54, -0.14, 0.79, 0.92], [0.94, -0.93, 0.45, 0.16]],
},
]


@pytest.fixture
def schema_dict():
return {
"index": {"name": "products", "prefix": "product", "storage_type": "json"},
"fields": [
{"name": "name", "type": "text"},
{"name": "description", "type": "text"},
{"name": "connection_type", "path": "$.connection.type", "type": "tag"},
{"name": "price", "type": "numeric"},
{"name": "stock", "type": "numeric"},
{"name": "color", "path": "$.colors.*", "type": "tag"},
{
"name": "embedding",
"type": "vector",
"attrs": {"dims": 4, "algorithm": "flat", "distance_metric": "cosine"},
},
{
"name": "embeddings",
"path": "$.embeddings[*]",
"type": "vector",
"attrs": {"dims": 4, "algorithm": "hnsw", "distance_metric": "l2"},
},
],
}


@pytest.fixture
def index(sample_data, redis_url, schema_dict):
index_schema = IndexSchema.from_dict(schema_dict)
redis_client = Redis.from_url(redis_url)
index = SearchIndex(index_schema, redis_client)
index.create(overwrite=True, drop=True)
index.load(sample_data)
yield index
index.delete(drop=True)


def test_dialect_3_json(index, sample_data):
# Create a VectorQuery with dialect 3
vector_query = VectorQuery(
vector=[0.23, 0.12, -0.03, 0.98],
vector_field_name="embedding",
return_fields=["name", "description", "price"],
dialect=3,
)

# Execute the query
results = index.query(vector_query)

# Print the results
print("VectorQuery Results:")
print(results)

# Assert the expected format of the results
assert len(results) > 0
for result in results:
assert not isinstance(result["name"], list)
assert not isinstance(result["description"], list)
assert not isinstance(result["price"], (list, str))

# Create a FilterQuery with dialect 3
filter_query = FilterQuery(
filter_expression=Tag("color") == "black",
return_fields=["name", "description", "price"],
dialect=3,
)

# Execute the query
results = index.query(filter_query)

# Print the results
print("FilterQuery Results:")
print(results)

# Assert the expected format of the results
assert len(results) > 0
for result in results:
assert not isinstance(result["name"], list)
assert not isinstance(result["description"], list)
assert not isinstance(result["price"], (list, str))