diff --git a/vectordb_bench/backend/clients/memorydb/cli.py b/vectordb_bench/backend/clients/memorydb/cli.py index 53d2a5b24..50b5f89ba 100644 --- a/vectordb_bench/backend/clients/memorydb/cli.py +++ b/vectordb_bench/backend/clients/memorydb/cli.py @@ -44,7 +44,16 @@ class MemoryDBTypedDict(TypedDict): is_flag=True, show_default=True, default=False, - help="Cluster Mode Disabled (CMD), use this flag when testing locally on a single node instance. In production, MemoryDB only supports CME mode", + help="Cluster Mode Disabled (CMD), use this flag when testing locally on a single node instance. In production, MemoryDB only supports cluster mode (CME)", + ), + ] + insert_batch_size: Annotated[ + int, + click.option( + "--insert-batch-size", + type=int, + default=10, + help="Batch size for inserting data. Adjust this as needed, but don't make it too big", ), ] @@ -73,6 +82,7 @@ def MemoryDB(**parameters: Unpack[MemoryDBHNSWTypedDict]): M=parameters["m"], ef_construction=parameters["ef_construction"], ef_runtime=parameters["ef_runtime"], + insert_batch_size=parameters["insert_batch_size"] ), **parameters, ) \ No newline at end of file diff --git a/vectordb_bench/backend/clients/memorydb/config.py b/vectordb_bench/backend/clients/memorydb/config.py index 94f2478b1..1284d3449 100644 --- a/vectordb_bench/backend/clients/memorydb/config.py +++ b/vectordb_bench/backend/clients/memorydb/config.py @@ -24,7 +24,7 @@ def to_dict(self) -> dict: class MemoryDBIndexConfig(BaseModel, DBCaseConfig): metric_type: MetricType | None = None - insert_batch_size: int | None = 10 # Adjust this as needed, but don't make too big + insert_batch_size: int | None = None def parse_metric(self) -> str: if self.metric_type == MetricType.L2: diff --git a/vectordb_bench/backend/clients/memorydb/memorydb.py b/vectordb_bench/backend/clients/memorydb/memorydb.py index 20f00ccba..c5f80eb2a 100644 --- a/vectordb_bench/backend/clients/memorydb/memorydb.py +++ b/vectordb_bench/backend/clients/memorydb/memorydb.py @@ -29,7 +29,7 @@ def __init__( self.case_config = db_case_config self.collection_name = INDEX_NAME self.target_nodes = RedisCluster.RANDOM if not self.db_config["cmd"] else None - self.insert_batch_size = db_case_config.insert_batch_size or 10 + self.insert_batch_size = db_case_config.insert_batch_size self.dbsize = kwargs.get("num_rows") # Create a MemoryDB connection, if db has password configured, add it to the connection here and in init():