Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/memorymind' into memorymind
Browse files Browse the repository at this point in the history
  • Loading branch information
Baswanth Vegunta committed Jul 8, 2024
2 parents 6da877d + d8b1407 commit 5faae4f
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 19 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ All the database client supported
|pgvector|`pip install vectordb-bench[pgvector]`|
|pgvecto.rs|`pip install vectordb-bench[pgvecto_rs]`|
|redis|`pip install vectordb-bench[redis]`|
|memorydb| `pip install vectordb-bench[memorydb]`|
|chromadb|`pip install vectordb-bench[chromadb]`|

### Run
Expand Down
4 changes: 2 additions & 2 deletions vectordb_bench/backend/clients/memorydb/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class MemoryDBTypedDict(TypedDict):
is_flag=True,
show_default=True,
default=True,
help="Enable or disable SSL for Redis",
help="Enable or disable SSL for MemoryDB",
),
]
ssl_ca_certs: Annotated[
Expand All @@ -44,7 +44,7 @@ class MemoryDBTypedDict(TypedDict):
is_flag=True,
show_default=True,
default=False,
help="Cluster Mode Disabled (CMD) for Redis doesn't use Cluster conn",
help="Cluster Mode Disabled (CMD), use this flag when testing locally on a single node instance. In production, MemoryDB only supports CME mode",
),
]

Expand Down
32 changes: 15 additions & 17 deletions vectordb_bench/backend/clients/memorydb/memorydb.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,24 +29,24 @@ def __init__(
self.case_config = db_case_config
self.collection_name = INDEX_NAME
self.target_nodes = RedisCluster.RANDOM if not self.db_config["cmd"] else None
self.insert_batch_size = db_case_config.insert_batch_size or 1
self.insert_batch_size = db_case_config.insert_batch_size or 10
self.dbsize = kwargs.get("num_rows")

# Create a redis connection, if db has password configured, add it to the connection here and in init():
log.info(f"Redis establishing connection to: {self.db_config}")
# Create a MemoryDB connection, if db has password configured, add it to the connection here and in init():
log.info(f"Establishing connection to: {self.db_config}")
conn = self.get_client(primary=True)
log.info(f"Connection established: {conn}")
log.info(conn.execute_command("INFO server"))

if drop_old:
try:
log.info(f"Redis client getting info for: {INDEX_NAME}")
log.info(f"MemoryDB client getting info for: {INDEX_NAME}")
info = conn.ft(INDEX_NAME).info()
log.info(f"Index info: {info}")
except redis.exceptions.ResponseError as e:
log.error(e)
drop_old = False
log.info(f"Redis client drop_old collection: {self.collection_name}")
log.info(f"MemoryDB client drop_old collection: {self.collection_name}")

log.info("Executing FLUSHALL")
conn.flushall()
Expand All @@ -73,11 +73,9 @@ def make_index(self, vector_dimensions: int, conn: redis.Redis):
index_param = self.case_config.index_param()
search_param = self.case_config.search_param()
vector_parameters = { # Vector Index Type: FLAT or HNSW
"TYPE": "FLOAT32", # FLOAT32 or FLOAT64
"TYPE": "FLOAT32",
"DIM": vector_dimensions, # Number of Vector Dimensions
"DISTANCE_METRIC": index_param[
"metric"
], # Vector Search Distance Metric
"DISTANCE_METRIC": index_param["metric"], # Vector Search Distance Metric
}
if index_param["m"]:
vector_parameters["M"] = index_param["m"]
Expand All @@ -89,7 +87,7 @@ def make_index(self, vector_dimensions: int, conn: redis.Redis):
schema = (
TagField("id"),
NumericField("metadata"),
VectorField("vector", # Vector Field Name
VectorField("vector", # Vector Field Name
"HNSW", vector_parameters
),
)
Expand All @@ -100,8 +98,8 @@ def make_index(self, vector_dimensions: int, conn: redis.Redis):

def get_client(self, **kwargs):
"""
Gets either cluster connection or normal redis connection based on `cmd` flag.
CMD stands for Cluster Mode Disabled and is a "mode" for Redis.
Gets either cluster connection or normal connection based on `cmd` flag.
CMD stands for Cluster Mode Disabled and is a "mode".
"""
if not self.db_config["cmd"]:
# Cluster mode enabled
Expand Down Expand Up @@ -228,15 +226,15 @@ def wait_for_empty_db(self, client: redis.RedisCluster | redis.Redis):
def search_embedding(
self,
query: list[float],
k: int = 100,
k: int = 10,
filters: dict | None = None,
timeout: int | None = None,
**kwargs: Any,
) -> (list[int]):
assert self.conn is not None

query_vector = np.array(query).astype(np.float32).tobytes()
query_obj = Query(f"*=>[KNN {k} @vector $vec]").return_fields("id").paging(0, k).dialect(2)
query_obj = Query(f"*=>[KNN {k} @vector $vec]").return_fields("id").paging(0, k)
query_params = {"vec": query_vector}

if filters:
Expand All @@ -246,11 +244,11 @@ def search_embedding(
# Removing '>=' from the id_value: '>=10000'
metadata_value = filters.get("metadata")[2:]
if id_value and metadata_value:
query_obj = Query(f"(@metadata:[{metadata_value} +inf] @id:{ {id_value} })=>[KNN {k} @vector $vec]").return_fields("id").paging(0, k).dialect(2)
query_obj = Query(f"(@metadata:[{metadata_value} +inf] @id:{ {id_value} })=>[KNN {k} @vector $vec]").return_fields("id").paging(0, k)
elif id_value:
#gets exact match for id
query_obj = Query(f"@id:{ {id_value} }=>[KNN {k} @vector $vec]").return_fields("id").paging(0, k).dialect(2)
query_obj = Query(f"@id:{ {id_value} }=>[KNN {k} @vector $vec]").return_fields("id").paging(0, k)
else: #metadata only case, greater than or equal to metadata value
query_obj = Query(f"@metadata:[{metadata_value} +inf]=>[KNN {k} @vector $vec]").return_fields("id").paging(0, k).dialect(2)
query_obj = Query(f"@metadata:[{metadata_value} +inf]=>[KNN {k} @vector $vec]").return_fields("id").paging(0, k)
res = self.conn.ft(INDEX_NAME).search(query_obj, query_params)
return [int(doc["id"]) for doc in res.docs]

0 comments on commit 5faae4f

Please sign in to comment.