From f9d27bbe00f87ef939f40175bbb5686cfdb6e38c Mon Sep 17 00:00:00 2001 From: Mark Greenhalgh Date: Thu, 30 May 2024 20:03:19 +0000 Subject: [PATCH] Introduce a command line interface. --- README.md | 187 +++++++++ pyproject.toml | 9 +- vectordb_bench/__init__.py | 23 +- vectordb_bench/backend/cases.py | 44 ++- vectordb_bench/backend/clients/__init__.py | 1 + vectordb_bench/backend/clients/api.py | 2 +- vectordb_bench/backend/clients/milvus/cli.py | 291 ++++++++++++++ .../backend/clients/milvus/milvus.py | 19 +- .../backend/clients/pgvector/cli.py | 116 ++++++ .../backend/clients/pgvector/config.py | 2 +- .../backend/clients/pgvector/pgvector.py | 8 +- vectordb_bench/backend/clients/redis/cli.py | 74 ++++ vectordb_bench/backend/clients/test/cli.py | 25 ++ vectordb_bench/backend/clients/test/config.py | 18 + vectordb_bench/backend/clients/test/test.py | 62 +++ .../backend/clients/weaviate_cloud/cli.py | 41 ++ .../backend/clients/zilliz_cloud/cli.py | 55 +++ vectordb_bench/backend/task_runner.py | 98 +++-- vectordb_bench/cli/__init__.py | 0 vectordb_bench/cli/cli.py | 362 ++++++++++++++++++ vectordb_bench/cli/vectordbbench.py | 20 + vectordb_bench/config-files/sample_config.yml | 17 + .../components/run_test/dbSelector.py | 8 +- .../components/run_test/submitTask.py | 16 +- .../frontend/const/dbCaseConfigs.py | 3 +- vectordb_bench/interface.py | 46 +-- vectordb_bench/models.py | 50 ++- 27 files changed, 1505 insertions(+), 92 deletions(-) create mode 100644 vectordb_bench/backend/clients/milvus/cli.py create mode 100644 vectordb_bench/backend/clients/pgvector/cli.py create mode 100644 vectordb_bench/backend/clients/redis/cli.py create mode 100644 vectordb_bench/backend/clients/test/cli.py create mode 100644 vectordb_bench/backend/clients/test/config.py create mode 100644 vectordb_bench/backend/clients/test/test.py create mode 100644 vectordb_bench/backend/clients/weaviate_cloud/cli.py create mode 100644 vectordb_bench/backend/clients/zilliz_cloud/cli.py create mode 100644 vectordb_bench/cli/__init__.py create mode 100644 vectordb_bench/cli/cli.py create mode 100644 vectordb_bench/cli/vectordbbench.py create mode 100644 vectordb_bench/config-files/sample_config.yml diff --git a/README.md b/README.md index 4ff2f9121..eeda249ae 100644 --- a/README.md +++ b/README.md @@ -45,6 +45,115 @@ All the database client supported ``` shell init_bench ``` + +OR: + +### Run from the command line. + +``` shell +vectordbbench [OPTIONS] COMMAND [ARGS]... +``` +To list the clients that are runnable via the commandline option, execute: `vectordbbench --help` +``` text +$ vectordbbench --help +Usage: vectordbbench [OPTIONS] COMMAND [ARGS]... + +Options: + --help Show this message and exit. + +Commands: + pgvectorhnsw + pgvectorivfflat + test + weaviate +``` +To list the options for each command, execute `vectordbbench [command] --help` + +```text +$ vectordbbench pgvectorhnsw --help +Usage: vectordbbench pgvectorhnsw [OPTIONS] + +Options: + --config-file PATH Read configuration from yaml file + --drop-old / --skip-drop-old Drop old or skip [default: drop-old] + --load / --skip-load Load or skip [default: load] + --search-serial / --skip-search-serial + Search serial or skip [default: search- + serial] + --search-concurrent / --skip-search-concurrent + Search concurrent or skip [default: search- + concurrent] + --case-type [CapacityDim128|CapacityDim960|Performance768D100M|Performance768D10M|Performance768D1M|Performance768D10M1P|Performance768D1M1P|Performance768D10M99P|Performance768D1M99P|Performance1536D500K|Performance1536D5M|Performance1536D500K1P|Performance1536D5M1P|Performance1536D500K99P|Performance1536D5M99P|Performance1536D50K] + Case type + --db-label TEXT Db label, default: date in ISO format + [default: 2024-05-20T20:26:31.113290] + --dry-run Print just the configuration and exit + without running the tasks + --k INTEGER K value for number of nearest neighbors to + search [default: 100] + --concurrency-duration INTEGER Adjusts the duration in seconds of each + concurrency search [default: 30] + --num-concurrency TEXT Comma-separated list of concurrency values + to test during concurrent search [default: + 1,10,20] + --user-name TEXT Db username [required] + --password TEXT Db password [required] + --host TEXT Db host [required] + --db-name TEXT Db name [required] + --maintenance-work-mem TEXT Sets the maximum memory to be used for + maintenance operations (index creation). Can + be entered as string with unit like '64GB' + or as an integer number of KB.This will set + the parameters: + max_parallel_maintenance_workers, + max_parallel_workers & + table(parallel_workers) + --max-parallel-workers INTEGER Sets the maximum number of parallel + processes per maintenance operation (index + creation) + --m INTEGER hnsw m + --ef-construction INTEGER hnsw ef-construction + --ef-search INTEGER hnsw ef-search + --help Show this message and exit. +``` +#### Using a configuration file. + +The vectordbbench command can optionally read some or all the options from a yaml formatted configuration file. + +By default, configuration files are expected to be in vectordb_bench/config-files/, this can be overridden by setting +the environment variable CONFIG_LOCAL_DIR or by passing the full path to the file. + +The required format is: +```yaml +commandname: + parameter_name: parameter_value + parameter_name: parameter_value +``` +Example: +```yaml +pgvectorhnsw: + db_label: pgConfigTest + user_name: vectordbbench + password: vectordbbench + db_name: vectordbbench + host: localhost + m: 16 + ef_construction: 128 + ef_search: 128 +milvushnsw: + skip_search_serial: True + case_type: Performance1536D50K + uri: http://localhost:19530 + m: 16 + ef_construction: 128 + ef_search: 128 + drop_old: False + load: False +``` +> Notes: +> - Options passed on the command line will override the configuration file* +> - Parameter names use an _ not - + ## What is VectorDBBench VectorDBBench is not just an offering of benchmark results for mainstream vector databases and cloud services, it's your go-to tool for the ultimate performance and cost-effectiveness comparison. Designed with ease-of-use in mind, VectorDBBench is devised to help users, even non-professionals, reproduce results or test new systems, making the hunt for the optimal choice amongst a plethora of cloud services and open-source vector databases a breeze. @@ -220,6 +329,7 @@ class NewDBCaseConfig(DBCaseConfig): # Implement optional case-specific configuration fields # ... ``` + **Step 3: Importing the DB Client and Updating Initialization** In this final step, you will import your DB client into clients/__init__.py and update the initialization process. @@ -258,6 +368,83 @@ class DB(Enum): return NewClientCaseConfig ``` +**Step 4: Implement new_client/cli.py and vectordb_bench/cli/vectordbbench.py** + +In this (optional, but encouraged) step you will enable the test to be run from the command line. +1. Navigate to the vectordb_bench/backend/clients/"client" directory. +2. Inside the "client" folder, create a cli.py file. +Using zilliz as an example cli.py: +```python +from typing import Annotated, Unpack + +import click +import os +from pydantic import SecretStr + +from vectordb_bench.cli.cli import ( + CommonTypedDict, + cli, + click_parameter_decorators_from_typed_dict, + run, +) +from vectordb_bench.backend.clients import DB + + +class ZillizTypedDict(CommonTypedDict): + uri: Annotated[ + str, click.option("--uri", type=str, help="uri connection string", required=True) + ] + user_name: Annotated[ + str, click.option("--user-name", type=str, help="Db username", required=True) + ] + password: Annotated[ + str, + click.option("--password", + type=str, + help="Zilliz password", + default=lambda: os.environ.get("ZILLIZ_PASSWORD", ""), + show_default="$ZILLIZ_PASSWORD", + ), + ] + level: Annotated[ + str, + click.option("--level", type=str, help="Zilliz index level", required=False), + ] + + +@cli.command() +@click_parameter_decorators_from_typed_dict(ZillizTypedDict) +def ZillizAutoIndex(**parameters: Unpack[ZillizTypedDict]): + from .config import ZillizCloudConfig, AutoIndexConfig + + run( + db=DB.ZillizCloud, + db_config=ZillizCloudConfig( + db_label=parameters["db_label"], + uri=SecretStr(parameters["uri"]), + user=parameters["user_name"], + password=SecretStr(parameters["password"]), + ), + db_case_config=AutoIndexConfig( + params={parameters["level"]}, + ), + **parameters, + ) +``` +3. Update cli by adding: + 1. Add database specific options as an Annotated TypedDict, see ZillizTypedDict above. + 2. Add index configuration specific options as an Annotated TypedDict. (example: vectordb_bench/backend/clients/pgvector/cli.py) + 1. May not be needed if there is only one index config. + 2. Repeat for each index configuration, nesting them if possible. + 2. Add a index config specific function for each index type, see Zilliz above. The function name, in lowercase, will be the command name passed to the vectordbbench command. + 3. Update db_config and db_case_config to match client requirements + 4. Continue to add new functions for each index config. + 5. Import the client cli module and command to vectordb_bench/cli/vectordbbench.py (for databases with multiple commands (index configs), this only needs to be done for one command) + +> cli modules with multiple index configs: +> - pgvector: vectordb_bench/backend/clients/pgvector/cli.py +> - milvus: vectordb_bench/backend/clients/milvus/cli.py + That's it! You have successfully added a new DB client to the vectordb_bench project. ## Rules diff --git a/pyproject.toml b/pyproject.toml index 8267f8522..1311cceef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta" [tool.setuptools.packages.find] where = ["."] -include = ["vectordb_bench"] +include = ["vectordb_bench", "vectordb_bench.cli"] [project] name = "vectordb-bench" @@ -24,6 +24,7 @@ classifiers = [ "Operating System :: OS Independent", ] dependencies = [ + "click", "pytz", "streamlit-autorefresh", "streamlit!=1.34.0", @@ -60,21 +61,23 @@ all = [ "chromadb", "psycopg2", "psycopg", + "psycopg-binary", ] qdrant = [ "qdrant-client" ] pinecone = [ "pinecone-client" ] weaviate = [ "weaviate-client" ] elastic = [ "elasticsearch" ] -pgvector = [ "pgvector", "psycopg" ] +pgvector = [ "psycopg", "psycopg-binary", "pgvector" ] pgvecto_rs = [ "psycopg2" ] redis = [ "redis" ] chromadb = [ "chromadb" ] +zilliz_cloud = [] [project.urls] "repository" = "https://github.com/zilliztech/VectorDBBench" [project.scripts] init_bench = "vectordb_bench.__main__:main" - +vectordbbench = "vectordb_bench.cli.vectordbbench:cli" [tool.setuptools_scm] diff --git a/vectordb_bench/__init__.py b/vectordb_bench/__init__.py index eca190832..342059d18 100644 --- a/vectordb_bench/__init__.py +++ b/vectordb_bench/__init__.py @@ -1,11 +1,13 @@ -import environs import inspect import pathlib -from . import log_util +import environs + +from . import log_util env = environs.Env() -env.read_env(".env") +env.read_env(".env", False) + class config: ALIYUN_OSS_URL = "assets.zilliz.com.cn/benchmark/" @@ -19,9 +21,20 @@ class config: DROP_OLD = env.bool("DROP_OLD", True) USE_SHUFFLED_DATA = env.bool("USE_SHUFFLED_DATA", True) - NUM_CONCURRENCY = [1, 5, 10, 15, 20, 25, 30, 35] - RESULTS_LOCAL_DIR = pathlib.Path(__file__).parent.joinpath("results") + NUM_CONCURRENCY = env.list("NUM_CONCURRENCY", [1, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70, 75, 80, 85, 90, 95, 100], subcast=int ) + + CONCURRENCY_DURATION = 30 + + RESULTS_LOCAL_DIR = env.path( + "RESULTS_LOCAL_DIR", pathlib.Path(__file__).parent.joinpath("results") + ) + CONFIG_LOCAL_DIR = env.path( + "CONFIG_LOCAL_DIR", pathlib.Path(__file__).parent.joinpath("config-files") + ) + + + K_DEFAULT = 100 # default return top k nearest neighbors during search CAPACITY_TIMEOUT_IN_SECONDS = 24 * 3600 # 24h LOAD_TIMEOUT_DEFAULT = 2.5 * 3600 # 2.5h diff --git a/vectordb_bench/backend/cases.py b/vectordb_bench/backend/cases.py index 7f40ae4f8..6f4b35974 100644 --- a/vectordb_bench/backend/cases.py +++ b/vectordb_bench/backend/cases.py @@ -1,6 +1,7 @@ import typing import logging from enum import Enum, auto +from typing import Type from vectordb_bench import config from vectordb_bench.base import BaseModel @@ -10,8 +11,6 @@ log = logging.getLogger(__name__) -Case = typing.TypeVar("Case") - class CaseType(Enum): """ @@ -42,11 +41,15 @@ class CaseType(Enum): Performance1536D500K99P = 14 Performance1536D5M99P = 15 + Performance1536D50K = 50 + Custom = 100 @property - def case_cls(self, custom_configs: dict | None = None) -> Case: - return type2case.get(self) + def case_cls(self, custom_configs: dict | None = None) -> Type["Case"]: + if self not in type2case: + raise NotImplementedError(f"Case {self} has not implemented. You can add it manually to vectordb_bench.backend.cases.type2case or define a custom_configs['custom_cls']") + return type2case[self] @property def case_name(self) -> str: @@ -69,7 +72,7 @@ class CaseLabel(Enum): class Case(BaseModel): - """Undifined case + """Undefined case Fields: case_id(CaseType): default 9 case type plus one custom cases. @@ -86,9 +89,9 @@ class Case(BaseModel): dataset: DatasetManager load_timeout: float | int - optimize_timeout: float | int | None + optimize_timeout: float | int | None = None - filter_rate: float | None + filter_rate: float | None = None @property def filters(self) -> dict | None: @@ -115,20 +118,23 @@ class PerformanceCase(Case, BaseModel): load_timeout: float | int = config.LOAD_TIMEOUT_DEFAULT optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_DEFAULT + class CapacityDim960(CapacityCase): case_id: CaseType = CaseType.CapacityDim960 dataset: DatasetManager = Dataset.GIST.manager(100_000) name: str = "Capacity Test (960 Dim Repeated)" - description: str = """This case tests the vector database's loading capacity by repeatedly inserting large-dimension vectors (GIST 100K vectors, 960 dimensions) until it is fully loaded. -Number of inserted vectors will be reported.""" + description: str = """This case tests the vector database's loading capacity by repeatedly inserting large-dimension + vectors (GIST 100K vectors, 960 dimensions) until it is fully loaded. Number of inserted vectors will be + reported.""" class CapacityDim128(CapacityCase): case_id: CaseType = CaseType.CapacityDim128 dataset: DatasetManager = Dataset.SIFT.manager(500_000) name: str = "Capacity Test (128 Dim Repeated)" - description: str = """This case tests the vector database's loading capacity by repeatedly inserting small-dimension vectors (SIFT 100K vectors, 128 dimensions) until it is fully loaded. -Number of inserted vectors will be reported.""" + description: str = """This case tests the vector database's loading capacity by repeatedly inserting small-dimension + vectors (SIFT 100K vectors, 128 dimensions) until it is fully loaded. Number of inserted vectors will be + reported.""" class Performance768D10M(PerformanceCase): @@ -238,6 +244,7 @@ class Performance1536D500K1P(PerformanceCase): load_timeout: float | int = config.LOAD_TIMEOUT_1536D_500K optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_1536D_500K + class Performance1536D5M1P(PerformanceCase): case_id: CaseType = CaseType.Performance1536D5M1P filter_rate: float | int | None = 0.01 @@ -248,6 +255,7 @@ class Performance1536D5M1P(PerformanceCase): load_timeout: float | int = config.LOAD_TIMEOUT_1536D_5M optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_1536D_5M + class Performance1536D500K99P(PerformanceCase): case_id: CaseType = CaseType.Performance1536D500K99P filter_rate: float | int | None = 0.99 @@ -258,6 +266,7 @@ class Performance1536D500K99P(PerformanceCase): load_timeout: float | int = config.LOAD_TIMEOUT_1536D_500K optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_1536D_500K + class Performance1536D5M99P(PerformanceCase): case_id: CaseType = CaseType.Performance1536D5M99P filter_rate: float | int | None = 0.99 @@ -269,6 +278,17 @@ class Performance1536D5M99P(PerformanceCase): optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_1536D_5M +class Performance1536D50K(PerformanceCase): + case_id: CaseType = CaseType.Performance1536D50K + filter_rate: float | int | None = None + dataset: DatasetManager = Dataset.OPENAI.manager(50_000) + name: str = "Search Performance Test (50K Dataset, 1536 Dim)" + description: str = """This case tests the search performance of a vector database with a medium 50K dataset (OpenAI 50K vectors, 1536 dimensions), at varying parallel levels. +Results will show index building time, recall, and maximum QPS.""" + load_timeout: float | int = 3600 + optimize_timeout: float | int | None = 15 * 60 + + type2case = { CaseType.CapacityDim960: CapacityDim960, CaseType.CapacityDim128: CapacityDim128, @@ -290,5 +310,5 @@ class Performance1536D5M99P(PerformanceCase): CaseType.Performance1536D500K99P: Performance1536D500K99P, CaseType.Performance1536D5M99P: Performance1536D5M99P, - + CaseType.Performance1536D50K: Performance1536D50K, } diff --git a/vectordb_bench/backend/clients/__init__.py b/vectordb_bench/backend/clients/__init__.py index aa6641a93..6a99661c2 100644 --- a/vectordb_bench/backend/clients/__init__.py +++ b/vectordb_bench/backend/clients/__init__.py @@ -32,6 +32,7 @@ class DB(Enum): PgVectoRS = "PgVectoRS" Redis = "Redis" Chroma = "Chroma" + Test = "test" @property diff --git a/vectordb_bench/backend/clients/api.py b/vectordb_bench/backend/clients/api.py index d1325322a..0024bf600 100644 --- a/vectordb_bench/backend/clients/api.py +++ b/vectordb_bench/backend/clients/api.py @@ -47,7 +47,7 @@ def to_dict(self) -> dict: def not_empty_field(cls, v, field): if field.name == "db_label": return v - if isinstance(v, (str, SecretStr)) and len(v) == 0: + if not v and isinstance(v, (str, SecretStr)): raise ValueError("Empty string!") return v diff --git a/vectordb_bench/backend/clients/milvus/cli.py b/vectordb_bench/backend/clients/milvus/cli.py new file mode 100644 index 000000000..05ff95418 --- /dev/null +++ b/vectordb_bench/backend/clients/milvus/cli.py @@ -0,0 +1,291 @@ +from typing import Annotated, TypedDict, Unpack + +import click +from pydantic import SecretStr + +from vectordb_bench.cli.cli import ( + CommonTypedDict, + HNSWFlavor3, + IVFFlatTypedDictN, + cli, + click_parameter_decorators_from_typed_dict, + run, + +) +from vectordb_bench.backend.clients import DB + +DBTYPE = DB.Milvus + + +class MilvusTypedDict(TypedDict): + uri: Annotated[ + str, click.option("--uri", type=str, help="uri connection string", required=True) + ] + + +class MilvusAutoIndexTypedDict(CommonTypedDict, MilvusTypedDict): + ... + + +@cli.command() +@click_parameter_decorators_from_typed_dict(MilvusAutoIndexTypedDict) +def MilvusAutoIndex(**parameters: Unpack[MilvusAutoIndexTypedDict]): + from .config import MilvusConfig, AutoIndexConfig + + run( + db=DBTYPE, + db_config=MilvusConfig( + db_label=parameters["db_label"], + uri=SecretStr(parameters["uri"]), + ), + db_case_config=AutoIndexConfig(), + **parameters, + ) + + +@cli.command() +@click_parameter_decorators_from_typed_dict(MilvusAutoIndexTypedDict) +def MilvusFlat(**parameters: Unpack[MilvusAutoIndexTypedDict]): + from .config import MilvusConfig, FLATConfig + + run( + db=DBTYPE, + db_config=MilvusConfig( + db_label=parameters["db_label"], + uri=SecretStr(parameters["uri"]), + ), + db_case_config=FLATConfig(), + **parameters, + ) + + +class MilvusHNSWTypedDict(CommonTypedDict, MilvusTypedDict, HNSWFlavor3): + ... + + +@cli.command() +@click_parameter_decorators_from_typed_dict(MilvusHNSWTypedDict) +def MilvusHNSW(**parameters: Unpack[MilvusHNSWTypedDict]): + from .config import MilvusConfig, HNSWConfig + + run( + db=DBTYPE, + db_config=MilvusConfig( + db_label=parameters["db_label"], + uri=SecretStr(parameters["uri"]), + ), + db_case_config=HNSWConfig( + M=parameters["m"], + efConstruction=parameters["ef_construction"], + ef=parameters["ef_search"], + ), + **parameters, + ) + + +class MilvusIVFFlatTypedDict(CommonTypedDict, MilvusTypedDict, IVFFlatTypedDictN): + ... + + +@cli.command() +@click_parameter_decorators_from_typed_dict(MilvusIVFFlatTypedDict) +def MilvusIVFFlat(**parameters: Unpack[MilvusIVFFlatTypedDict]): + from .config import MilvusConfig, IVFFlatConfig + + run( + db=DBTYPE, + db_config=MilvusConfig( + db_label=parameters["db_label"], + uri=SecretStr(parameters["uri"]), + ), + db_case_config=IVFFlatConfig( + nlist=parameters["nlist"], + nprobe=parameters["nprobe"], + ), + **parameters, + ) + + +@cli.command() +@click_parameter_decorators_from_typed_dict(MilvusIVFFlatTypedDict) +def MilvusIVFSQ8(**parameters: Unpack[MilvusIVFFlatTypedDict]): + from .config import MilvusConfig, IVFSQ8Config + + run( + db=DBTYPE, + db_config=MilvusConfig( + db_label=parameters["db_label"], + uri=SecretStr(parameters["uri"]), + ), + db_case_config=IVFSQ8Config( + nlist=parameters["nlist"], + nprobe=parameters["nprobe"], + ), + **parameters, + ) + + +class MilvusDISKANNTypedDict(CommonTypedDict, MilvusTypedDict): + search_list: Annotated[ + str, click.option("--search-list", + type=int, + required=True) + ] + + +@cli.command() +@click_parameter_decorators_from_typed_dict(MilvusDISKANNTypedDict) +def MilvusDISKANN(**parameters: Unpack[MilvusDISKANNTypedDict]): + from .config import MilvusConfig, DISKANNConfig + + run( + db=DBTYPE, + db_config=MilvusConfig( + db_label=parameters["db_label"], + uri=SecretStr(parameters["uri"]), + ), + db_case_config=DISKANNConfig( + search_list=parameters["search_list"], + ), + **parameters, + ) + + +class MilvusGPUIVFTypedDict(CommonTypedDict, MilvusTypedDict, MilvusIVFFlatTypedDict): + cache_dataset_on_device: Annotated[ + str, click.option("--cache-dataset-on-device", + type=str, + required=True) + ] + refine_ratio: Annotated[ + str, click.option("--refine-ratio", + type=float, + required=True) + ] + + +@cli.command() +@click_parameter_decorators_from_typed_dict(MilvusGPUIVFTypedDict) +def MilvusGPUIVFFlat(**parameters: Unpack[MilvusGPUIVFTypedDict]): + from .config import MilvusConfig, GPUIVFFlatConfig + + run( + db=DBTYPE, + db_config=MilvusConfig( + db_label=parameters["db_label"], + uri=SecretStr(parameters["uri"]), + ), + db_case_config=GPUIVFFlatConfig( + nlist=parameters["nlist"], + nprobe=parameters["nprobe"], + cache_dataset_on_device=parameters["cache_dataset_on_device"], + refine_ratio=parameters.get("refine_ratio"), + ), + **parameters, + ) + + +class MilvusGPUIVFPQTypedDict(CommonTypedDict, MilvusTypedDict, MilvusIVFFlatTypedDict, MilvusGPUIVFTypedDict): + m: Annotated[ + str, click.option("--m", + type=int, help="hnsw m", + required=True) + ] + nbits: Annotated[ + str, click.option("--nbits", + type=int, + required=True) + ] + + +@cli.command() +@click_parameter_decorators_from_typed_dict(MilvusGPUIVFPQTypedDict) +def MilvusGPUIVFPQ(**parameters: Unpack[MilvusGPUIVFPQTypedDict]): + from .config import MilvusConfig, GPUIVFPQConfig + + run( + db=DBTYPE, + db_config=MilvusConfig( + db_label=parameters["db_label"], + uri=SecretStr(parameters["uri"]), + ), + db_case_config=GPUIVFPQConfig( + nlist=parameters["nlist"], + nprobe=parameters["nprobe"], + m=parameters["m"], + nbits=parameters["nbits"], + cache_dataset_on_device=parameters["cache_dataset_on_device"], + refine_ratio=parameters["refine_ratio"], + ), + **parameters, + ) + + +class MilvusGPUCAGRATypedDict(CommonTypedDict, MilvusTypedDict, MilvusGPUIVFTypedDict): + intermediate_graph_degree: Annotated[ + str, click.option("--intermediate-graph-degree", + type=int, + required=True) + ] + graph_degree: Annotated[ + str, click.option("--graph-degree", + type=int, + required=True) + ] + build_algo: Annotated[ + str, click.option("--build_algo", + type=str, + required=True) + ] + team_size: Annotated[ + str, click.option("--team-size", + type=int, + required=True) + ] + search_width: Annotated[ + str, click.option("--search-width", + type=int, + required=True) + ] + itopk_size: Annotated[ + str, click.option("--itopk-size", + type=int, + required=True) + ] + min_iterations: Annotated[ + str, click.option("--min-iterations", + type=int, + required=True) + ] + max_iterations: Annotated[ + str, click.option("--max-iterations", + type=int, + required=True) + ] + + +@cli.command() +@click_parameter_decorators_from_typed_dict(MilvusGPUCAGRATypedDict) +def MilvusGPUCAGRA(**parameters: Unpack[MilvusGPUCAGRATypedDict]): + from .config import MilvusConfig, GPUCAGRAConfig + + run( + db=DBTYPE, + db_config=MilvusConfig( + db_label=parameters["db_label"], + uri=SecretStr(parameters["uri"]), + ), + db_case_config=GPUCAGRAConfig( + intermediate_graph_degree=parameters["intermediate_graph_degree"], + graph_degree=parameters["graph_degree"], + itopk_size=parameters["itopk_size"], + team_size=parameters["team_size"], + search_width=parameters["search_width"], + min_iterations=parameters["min_iterations"], + max_iterations=parameters["max_iterations"], + build_algo=parameters["build_algo"], + cache_dataset_on_device=parameters["cache_dataset_on_device"], + refine_ratio=parameters["refine_ratio"], + ), + **parameters, + ) diff --git a/vectordb_bench/backend/clients/milvus/milvus.py b/vectordb_bench/backend/clients/milvus/milvus.py index 2436e2680..4590265ae 100644 --- a/vectordb_bench/backend/clients/milvus/milvus.py +++ b/vectordb_bench/backend/clients/milvus/milvus.py @@ -8,7 +8,7 @@ from pymilvus import Collection, utility from pymilvus import CollectionSchema, DataType, FieldSchema, MilvusException -from ..api import VectorDB +from ..api import VectorDB, IndexType from .config import MilvusIndexConfig @@ -122,10 +122,18 @@ def wait_index(): if self.case_config.is_gpu_index: log.debug("skip compaction for gpu index type.") else : - self.col.compact() - self.col.wait_for_compaction_completed() + try: + self.col.compact() + self.col.wait_for_compaction_completed() + except Exception as e: + log.warning(f"{self.name} compact error: {e}") + if hasattr(e, 'code'): + if e.code().name == 'PERMISSION_DENIED': + log.warning(f"Skip compact due to permission denied.") + pass + else: + raise e wait_index() - except Exception as e: log.warning(f"{self.name} optimize error: {e}") raise e from None @@ -143,7 +151,6 @@ def _pre_load(self, coll: Collection): self.case_config.index_param(), index_name=self._index_name, ) - coll.load() log.info(f"{self.name} load") except Exception as e: @@ -160,7 +167,7 @@ def need_normalize_cosine(self) -> bool: if self.case_config.is_gpu_index: log.info(f"current gpu_index only supports IP / L2, cosine dataset need normalize.") return True - + return False def insert_embeddings( diff --git a/vectordb_bench/backend/clients/pgvector/cli.py b/vectordb_bench/backend/clients/pgvector/cli.py new file mode 100644 index 000000000..31b268231 --- /dev/null +++ b/vectordb_bench/backend/clients/pgvector/cli.py @@ -0,0 +1,116 @@ +from typing import Annotated, Optional, TypedDict, Unpack + +import click +import os +from pydantic import SecretStr + +from ....cli.cli import ( + CommonTypedDict, + HNSWFlavor1, + IVFFlatTypedDict, + cli, + click_parameter_decorators_from_typed_dict, + run, +) +from vectordb_bench.backend.clients import DB + + +class PgVectorTypedDict(CommonTypedDict): + user_name: Annotated[ + str, click.option("--user-name", type=str, help="Db username", required=True) + ] + password: Annotated[ + str, + click.option("--password", + type=str, + help="Postgres database password", + default=lambda: os.environ.get("POSTGRES_PASSWORD", ""), + show_default="$POSTGRES_PASSWORD", + ), + ] + + host: Annotated[ + str, click.option("--host", type=str, help="Db host", required=True) + ] + db_name: Annotated[ + str, click.option("--db-name", type=str, help="Db name", required=True) + ] + maintenance_work_mem: Annotated[ + Optional[str], + click.option( + "--maintenance-work-mem", + type=str, + help="Sets the maximum memory to be used for maintenance operations (index creation). " + "Can be entered as string with unit like '64GB' or as an integer number of KB." + "This will set the parameters: max_parallel_maintenance_workers," + " max_parallel_workers & table(parallel_workers)", + required=False, + ), + ] + max_parallel_workers: Annotated[ + Optional[int], + click.option( + "--max-parallel-workers", + type=int, + help="Sets the maximum number of parallel processes per maintenance operation (index creation)", + required=False, + ), + ] + + +class PgVectorIVFFlatTypedDict(PgVectorTypedDict, IVFFlatTypedDict): + ... + + +@cli.command() +@click_parameter_decorators_from_typed_dict(PgVectorIVFFlatTypedDict) +def PgVectorIVFFlat( + **parameters: Unpack[PgVectorIVFFlatTypedDict], +): + from .config import PgVectorConfig, PgVectorIVFFlatConfig + + run( + db=DB.PgVector, + db_config=PgVectorConfig( + db_label=parameters["db_label"], + user_name=SecretStr(parameters["user_name"]), + password=SecretStr(parameters["password"]), + host=parameters["host"], + db_name=parameters["db_name"], + ), + db_case_config=PgVectorIVFFlatConfig( + metric_type=None, lists=parameters["lists"], probes=parameters["probes"] + ), + **parameters, + ) + + +class PgVectorHNSWTypedDict(PgVectorTypedDict, HNSWFlavor1): + ... + + +@cli.command() +@click_parameter_decorators_from_typed_dict(PgVectorHNSWTypedDict) +def PgVectorHNSW( + **parameters: Unpack[PgVectorHNSWTypedDict], +): + from .config import PgVectorConfig, PgVectorHNSWConfig + + run( + db=DB.PgVector, + db_config=PgVectorConfig( + db_label=parameters["db_label"], + user_name=SecretStr(parameters["user_name"]), + password=SecretStr(parameters["password"]), + host=parameters["host"], + db_name=parameters["db_name"], + ), + db_case_config=PgVectorHNSWConfig( + m=parameters["m"], + ef_construction=parameters["ef_construction"], + ef_search=parameters["ef_search"], + maintenance_work_mem=parameters["maintenance_work_mem"], + max_parallel_workers=parameters["max_parallel_workers"], + ), + **parameters, + ) diff --git a/vectordb_bench/backend/clients/pgvector/config.py b/vectordb_bench/backend/clients/pgvector/config.py index a4cc584e8..496a3b440 100644 --- a/vectordb_bench/backend/clients/pgvector/config.py +++ b/vectordb_bench/backend/clients/pgvector/config.py @@ -109,7 +109,7 @@ def _optionally_build_with_options(with_options: Mapping[str, Any]) -> Sequence[ def _optionally_build_set_options( set_mapping: Mapping[str, Any] ) -> Sequence[dict[str, Any]]: - """Walk through options, creating 'SET 'key1 = "value1";' commands""" + """Walk through options, creating 'SET 'key1 = "value1";' list""" session_options = [] for setting_name, value in set_mapping.items(): if value: diff --git a/vectordb_bench/backend/clients/pgvector/pgvector.py b/vectordb_bench/backend/clients/pgvector/pgvector.py index 8f8244412..5e66fd77e 100644 --- a/vectordb_bench/backend/clients/pgvector/pgvector.py +++ b/vectordb_bench/backend/clients/pgvector/pgvector.py @@ -58,14 +58,13 @@ def __init__( self.case_config.create_index_after_load, ) ): - err = f"{self.name} config must create an index using create_index_before_load and/or create_index_after_load" + err = f"{self.name} config must create an index using create_index_before_load or create_index_after_load" log.error(err) raise RuntimeError( f"{err}\n{pprint.pformat(self.db_config)}\n{pprint.pformat(self.case_config)}" ) if drop_old: - # self.pg_table.drop(pg_engine, checkfirst=True) self._drop_index() self._drop_table() self._create_table(dim) @@ -257,7 +256,10 @@ def _create_index(self): with_clause = sql.Composed(()) index_create_sql = sql.SQL( - "CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name} USING {index_type} (embedding {embedding_metric})" + """ + CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name} + USING {index_type} (embedding {embedding_metric}) + """ ).format( index_name=sql.Identifier(self._index_name), table_name=sql.Identifier(self.table_name), diff --git a/vectordb_bench/backend/clients/redis/cli.py b/vectordb_bench/backend/clients/redis/cli.py new file mode 100644 index 000000000..172b2b52a --- /dev/null +++ b/vectordb_bench/backend/clients/redis/cli.py @@ -0,0 +1,74 @@ +from typing import Annotated, TypedDict, Unpack + +import click +from pydantic import SecretStr + +from ....cli.cli import ( + CommonTypedDict, + HNSWFlavor2, + cli, + click_parameter_decorators_from_typed_dict, + run, +) +from .. import DB + + +class RedisTypedDict(TypedDict): + host: Annotated[ + str, click.option("--host", type=str, help="Db host", required=True) + ] + password: Annotated[str, click.option("--password", type=str, help="Db password")] + port: Annotated[int, click.option("--port", type=int, default=6379, help="Db Port")] + ssl: Annotated[ + bool, + click.option( + "--ssl/--no-ssl", + is_flag=True, + show_default=True, + default=True, + help="Enable or disable SSL for Redis", + ), + ] + ssl_ca_certs: Annotated[ + str, + click.option( + "--ssl-ca-certs", + show_default=True, + help="Path to certificate authority file to use for SSL", + ), + ] + cmd: Annotated[ + bool, + click.option( + "--cmd", + is_flag=True, + show_default=True, + default=False, + help="Cluster Mode Disabled (CMD) for Redis doesn't use Cluster conn", + ), + ] + + +class RedisHNSWTypedDict(CommonTypedDict, RedisTypedDict, HNSWFlavor2): + ... + + +@cli.command() +@click_parameter_decorators_from_typed_dict(RedisHNSWTypedDict) +def Redis(**parameters: Unpack[RedisHNSWTypedDict]): + from .config import RedisConfig + run( + db=DB.Redis, + db_config=RedisConfig( + db_label=parameters["db_label"], + password=SecretStr(parameters["password"]) + if parameters["password"] + else None, + host=SecretStr(parameters["host"]), + port=parameters["port"], + ssl=parameters["ssl"], + ssl_ca_certs=parameters["ssl_ca_certs"], + cmd=parameters["cmd"], + ), + **parameters, + ) diff --git a/vectordb_bench/backend/clients/test/cli.py b/vectordb_bench/backend/clients/test/cli.py new file mode 100644 index 000000000..f06f33492 --- /dev/null +++ b/vectordb_bench/backend/clients/test/cli.py @@ -0,0 +1,25 @@ +from typing import Unpack + +from ....cli.cli import ( + CommonTypedDict, + cli, + click_parameter_decorators_from_typed_dict, + run, +) +from .. import DB +from ..test.config import TestConfig, TestIndexConfig + + +class TestTypedDict(CommonTypedDict): + ... + + +@cli.command() +@click_parameter_decorators_from_typed_dict(TestTypedDict) +def Test(**parameters: Unpack[TestTypedDict]): + run( + db=DB.NewClient, + db_config=TestConfig(db_label=parameters["db_label"]), + db_case_config=TestIndexConfig(), + **parameters, + ) diff --git a/vectordb_bench/backend/clients/test/config.py b/vectordb_bench/backend/clients/test/config.py new file mode 100644 index 000000000..01a77e000 --- /dev/null +++ b/vectordb_bench/backend/clients/test/config.py @@ -0,0 +1,18 @@ +from pydantic import BaseModel, SecretStr + +from ..api import DBCaseConfig, DBConfig, IndexType, MetricType + + +class TestConfig(DBConfig): + def to_dict(self) -> dict: + return {"db_label": self.db_label} + + +class TestIndexConfig(BaseModel, DBCaseConfig): + metric_type: MetricType | None = None + + def index_param(self) -> dict: + return {} + + def search_param(self) -> dict: + return {} diff --git a/vectordb_bench/backend/clients/test/test.py b/vectordb_bench/backend/clients/test/test.py new file mode 100644 index 000000000..78732eb1e --- /dev/null +++ b/vectordb_bench/backend/clients/test/test.py @@ -0,0 +1,62 @@ +import logging +from contextlib import contextmanager +from typing import Any, Generator, Optional, Tuple + +from ..api import DBCaseConfig, VectorDB + +log = logging.getLogger(__name__) + + +class Test(VectorDB): + def __init__( + self, + dim: int, + db_config: dict, + db_case_config: DBCaseConfig, + drop_old: bool = False, + **kwargs, + ): + self.db_config = db_config + self.case_config = db_case_config + + log.info("Starting Test DB") + + @contextmanager + def init(self) -> Generator[None, None, None]: + """create and destroy connections to database. + + Examples: + >>> with self.init(): + >>> self.insert_embeddings() + """ + + yield + + def ready_to_load(self) -> bool: + return True + + def optimize(self) -> None: + pass + + def insert_embeddings( + self, + embeddings: list[list[float]], + metadata: list[int], + **kwargs: Any, + ) -> Tuple[int, Optional[Exception]]: + """Insert embeddings into the database. + Should call self.init() first. + """ + raise RuntimeError("Not implemented") + return len(metadata), None + + def search_embedding( + self, + query: list[float], + k: int = 100, + filters: dict | None = None, + timeout: int | None = None, + **kwargs: Any, + ) -> list[int]: + raise NotImplementedError + return [i for i in range(k)] diff --git a/vectordb_bench/backend/clients/weaviate_cloud/cli.py b/vectordb_bench/backend/clients/weaviate_cloud/cli.py new file mode 100644 index 000000000..b6f16011b --- /dev/null +++ b/vectordb_bench/backend/clients/weaviate_cloud/cli.py @@ -0,0 +1,41 @@ +from typing import Annotated, Unpack + +import click +from pydantic import SecretStr + +from ....cli.cli import ( + CommonTypedDict, + cli, + click_parameter_decorators_from_typed_dict, + run, +) +from .. import DB + + +class WeaviateTypedDict(CommonTypedDict): + api_key: Annotated[ + str, click.option("--api-key", type=str, help="Weaviate api key", required=True) + ] + url: Annotated[ + str, + click.option("--url", type=str, help="Weaviate url", required=True), + ] + + +@cli.command() +@click_parameter_decorators_from_typed_dict(WeaviateTypedDict) +def Weaviate(**parameters: Unpack[WeaviateTypedDict]): + from .config import WeaviateConfig, WeaviateIndexConfig + + run( + db=DB.WeaviateCloud, + db_config=WeaviateConfig( + db_label=parameters["db_label"], + api_key=SecretStr(parameters["api_key"]), + url=SecretStr(parameters["url"]), + ), + db_case_config=WeaviateIndexConfig( + ef=256, efConstruction=256, maxConnections=16 + ), + **parameters, + ) diff --git a/vectordb_bench/backend/clients/zilliz_cloud/cli.py b/vectordb_bench/backend/clients/zilliz_cloud/cli.py new file mode 100644 index 000000000..31618f4ec --- /dev/null +++ b/vectordb_bench/backend/clients/zilliz_cloud/cli.py @@ -0,0 +1,55 @@ +from typing import Annotated, Unpack + +import click +import os +from pydantic import SecretStr + +from vectordb_bench.cli.cli import ( + CommonTypedDict, + cli, + click_parameter_decorators_from_typed_dict, + run, +) +from vectordb_bench.backend.clients import DB + + +class ZillizTypedDict(CommonTypedDict): + uri: Annotated[ + str, click.option("--uri", type=str, help="uri connection string", required=True) + ] + user_name: Annotated[ + str, click.option("--user-name", type=str, help="Db username", required=True) + ] + password: Annotated[ + str, + click.option("--password", + type=str, + help="Zilliz password", + default=lambda: os.environ.get("ZILLIZ_PASSWORD", ""), + show_default="$ZILLIZ_PASSWORD", + ), + ] + level: Annotated[ + str, + click.option("--level", type=str, help="Zilliz index level", required=False), + ] + + +@cli.command() +@click_parameter_decorators_from_typed_dict(ZillizTypedDict) +def ZillizAutoIndex(**parameters: Unpack[ZillizTypedDict]): + from .config import ZillizCloudConfig, AutoIndexConfig + + run( + db=DB.ZillizCloud, + db_config=ZillizCloudConfig( + db_label=parameters["db_label"], + uri=SecretStr(parameters["uri"]), + user=parameters["user_name"], + password=SecretStr(parameters["password"]), + ), + db_case_config=AutoIndexConfig( + params={parameters["level"]}, + ), + **parameters, + ) diff --git a/vectordb_bench/backend/task_runner.py b/vectordb_bench/backend/task_runner.py index c5c368d02..0680847aa 100644 --- a/vectordb_bench/backend/task_runner.py +++ b/vectordb_bench/backend/task_runner.py @@ -8,7 +8,7 @@ from . import utils from .cases import Case, CaseLabel from ..base import BaseModel -from ..models import TaskConfig, PerformanceTimeoutError +from ..models import TaskConfig, PerformanceTimeoutError, TaskStage from .clients import ( api, @@ -29,7 +29,7 @@ class RunningStatus(Enum): class CaseRunner(BaseModel): - """ DataSet, filter_rate, db_class with db config + """DataSet, filter_rate, db_class with db config Fields: run_id(str): run_id of this case runner, @@ -49,8 +49,9 @@ class CaseRunner(BaseModel): db: api.VectorDB | None = None test_emb: list[list[float]] | None = None - search_runner: MultiProcessingSearchRunner | None = None serial_search_runner: SerialSearchRunner | None = None + search_runner: MultiProcessingSearchRunner | None = None + final_search_runner: MultiProcessingSearchRunner | None = None def __eq__(self, obj): if isinstance(obj, CaseRunner): @@ -58,7 +59,7 @@ def __eq__(self, obj): self.config.db == obj.config.db and \ self.config.db_case_config == obj.config.db_case_config and \ self.ca.dataset == obj.ca.dataset - return False + return False def display(self) -> dict: c_dict = self.ca.dict(include={'label':True, 'filters': True,'dataset':{'data': {'name': True, 'size': True, 'dim': True, 'metric_type': True, 'label': True}} }) @@ -79,20 +80,25 @@ def init_db(self, drop_old: bool = True) -> None: db_config=self.config.db_config.to_dict(), db_case_config=self.config.db_case_config, drop_old=drop_old, - ) + ) # type:ignore + def _pre_run(self, drop_old: bool = True): try: self.init_db(drop_old) self.ca.dataset.prepare(self.dataset_source, filters=self.ca.filter_rate) except ModuleNotFoundError as e: - log.warning(f"pre run case error: please install client for db: {self.config.db}, error={e}") + log.warning( + f"pre run case error: please install client for db: {self.config.db}, error={e}" + ) raise e from None except Exception as e: log.warning(f"pre run case error: {e}") raise e from None def run(self, drop_old: bool = True) -> Metric: + log.info("Starting run") + self._pre_run(drop_old) if self.ca.label == CaseLabel.Load: @@ -105,31 +111,35 @@ def run(self, drop_old: bool = True) -> Metric: raise ValueError(msg) def _run_capacity_case(self) -> Metric: - """ run capacity cases + """run capacity cases Returns: Metric: the max load count """ + assert self.db is not None log.info("Start capacity case") try: - runner = SerialInsertRunner(self.db, self.ca.dataset, self.normalize, self.ca.load_timeout) + runner = SerialInsertRunner( + self.db, self.ca.dataset, self.normalize, self.ca.load_timeout + ) count = runner.run_endlessness() except Exception as e: log.warning(f"Failed to run capacity case, reason = {e}") raise e from None else: - log.info(f"Capacity case loading dataset reaches VectorDB's limit: max capacity = {count}") + log.info( + f"Capacity case loading dataset reaches VectorDB's limit: max capacity = {count}" + ) return Metric(max_load_count=count) def _run_perf_case(self, drop_old: bool = True) -> Metric: - """ run performance cases + """run performance cases Returns: Metric: load_duration, recall, serial_latency_p99, and, qps """ - try: - m = Metric() - if drop_old: + ''' + if drop_old: _, load_dur = self._load_train_data() build_dur = self._optimize() m.load_duration = round(load_dur+build_dur, 4) @@ -142,6 +152,39 @@ def _run_perf_case(self, drop_old: bool = True) -> Metric: self._init_search_runner() m.qps = self._conc_search() m.recall, m.serial_latency_p99 = self._serial_search() + ''' + + log.info("Start performance case") + try: + m = Metric() + if drop_old: + if TaskStage.LOAD in self.config.stages: + # self._load_train_data() + _, load_dur = self._load_train_data() + build_dur = self._optimize() + m.load_duration = round(load_dur + build_dur, 4) + log.info( + f"Finish loading the entire dataset into VectorDB," + f" insert_duration={load_dur}, optimize_duration={build_dur}" + f" load_duration(insert + optimize) = {m.load_duration}" + ) + else: + log.info("Data loading skipped") + if ( + TaskStage.SEARCH_SERIAL in self.config.stages + or TaskStage.SEARCH_CONCURRENT in self.config.stages + ): + self._init_search_runner() + if TaskStage.SEARCH_SERIAL in self.config.stages: + search_results = self._serial_search() + ''' + m.recall = search_results.recall + m.serial_latencies = search_results.serial_latencies + ''' + m.recall, m.serial_latency_p99 = search_results + if TaskStage.SEARCH_CONCURRENT in self.config.stages: + search_results = self._conc_search() + m.qps = search_results except Exception as e: log.warning(f"Failed to run performance case, reason = {e}") traceback.print_exc() @@ -217,18 +260,23 @@ def _init_search_runner(self): gt_df = self.ca.dataset.gt_data - self.serial_search_runner = SerialSearchRunner( - db=self.db, - test_data=self.test_emb, - ground_truth=gt_df, - filters=self.ca.filters, - ) - - self.search_runner = MultiProcessingSearchRunner( - db=self.db, - test_data=self.test_emb, - filters=self.ca.filters, - ) + if TaskStage.SEARCH_SERIAL in self.config.stages: + self.serial_search_runner = SerialSearchRunner( + db=self.db, + test_data=self.test_emb, + ground_truth=gt_df, + filters=self.ca.filters, + k=self.config.case_config.k, + ) + if TaskStage.SEARCH_CONCURRENT in self.config.stages: + self.search_runner = MultiProcessingSearchRunner( + db=self.db, + test_data=self.test_emb, + filters=self.ca.filters, + concurrencies=self.config.case_config.concurrency_search_config.num_concurrency, + duration=self.config.case_config.concurrency_search_config.concurrency_duration, + k=self.config.case_config.k, + ) def stop(self): if self.search_runner: diff --git a/vectordb_bench/cli/__init__.py b/vectordb_bench/cli/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/vectordb_bench/cli/cli.py b/vectordb_bench/cli/cli.py new file mode 100644 index 000000000..00910261b --- /dev/null +++ b/vectordb_bench/cli/cli.py @@ -0,0 +1,362 @@ +import logging +import time +from concurrent.futures import wait +from datetime import datetime +from pprint import pformat +from typing import ( + Annotated, + Callable, + List, + Optional, + Type, + TypedDict, + Unpack, + get_origin, + get_type_hints, + Dict, + Any, +) +import click +from .. import config +from ..backend.clients import DB +from ..interface import benchMarkRunner, global_result_future +from ..models import ( + CaseConfig, + CaseType, + ConcurrencySearchConfig, + DBCaseConfig, + DBConfig, + TaskConfig, + TaskStage, +) +import os +from yaml import load +try: + from yaml import CLoader as Loader +except ImportError: + from yaml import Loader + + +def click_get_defaults_from_file(ctx, param, value): + if value: + if os.path.exists(value): + input_file = value + else: + input_file = os.path.join(config.CONFIG_LOCAL_DIR, value) + try: + with open(input_file, 'r') as f: + _config: Dict[str, Dict[str, Any]] = load(f.read(), Loader=Loader) + ctx.default_map = _config.get(ctx.command.name, {}) + except Exception as e: + raise click.BadParameter(f"Failed to load config file: {e}") + return value + + +def click_parameter_decorators_from_typed_dict( + typed_dict: Type, +) -> Callable[[click.decorators.FC], click.decorators.FC]: + """A convenience method decorator that will read in a TypedDict with parameters defined by Annotated types. + from .models import CaseConfig, CaseType, DBCaseConfig, DBConfig, TaskConfig, TaskStage + The click.options will be collected and re-composed as a single decorator to apply to the click.command. + + Args: + typed_dict (TypedDict) with Annotated[..., click.option()] keys + + Returns: + a fully decorated method + + + For clarity, the key names of the TypedDict will be used to determine the type hints for the input parameters. + The actual function parameters are controlled by the click.option definitions. You must manually ensure these are aligned in a sensible way! + + Example: + ``` + class CommonTypedDict(TypedDict): + z: Annotated[int, click.option("--z/--no-z", is_flag=True, type=bool, help="help z", default=True, show_default=True)] + name: Annotated[str, click.argument("name", required=False, default="Jeff")] + + class FooTypedDict(CommonTypedDict): + x: Annotated[int, click.option("--x", type=int, help="help x", default=1, show_default=True)] + y: Annotated[str, click.option("--y", type=str, help="help y", default="foo", show_default=True)] + + @cli.command() + @click_parameter_decorators_from_typed_dict(FooTypedDict) + def foo(**parameters: Unpack[FooTypedDict]): + "Foo docstring" + print(f"input parameters: {parameters["x"]}") + ``` + """ + decorators = [] + for _, t in get_type_hints(typed_dict, include_extras=True).items(): + assert get_origin(t) is Annotated + if ( + len(t.__metadata__) == 1 + and t.__metadata__[0].__module__ == "click.decorators" + ): + # happy path -- only accept Annotated[..., Union[click.option,click.argument,...]] with no additional metadata defined (len=1) + decorators.append(t.__metadata__[0]) + else: + raise RuntimeError( + "Click-TypedDict decorator parsing must only contain root type and a click decorator like click.option. See docstring" + ) + + def deco(f): + for dec in reversed(decorators): + f = dec(f) + return f + + return deco + + +def click_arg_split(ctx: click.Context, param: click.core.Option, value): + """Will split a comma-separated list input into an actual list. + + Args: + ctx (...): unused click arg + param (...): unused click arg + value (str): input comma-separated list + + Returns: + value (List[str]): list of original + """ + # split columns by ',' and remove whitespace + if value is None: + return [] + return [c.strip() for c in value.split(",") if c.strip()] + + +def parse_task_stages( + drop_old: bool, + load: bool, + search_serial: bool, + search_concurrent: bool, +) -> List[TaskStage]: + stages = [] + if load and not drop_old: + raise RuntimeError("Dropping old data cannot be skipped if loading data") + elif drop_old and not load: + raise RuntimeError("Load cannot be skipped if dropping old data") + if drop_old: + stages.append(TaskStage.DROP_OLD) + if load: + stages.append(TaskStage.LOAD) + if search_serial: + stages.append(TaskStage.SEARCH_SERIAL) + if search_concurrent: + stages.append(TaskStage.SEARCH_CONCURRENT) + return stages + + +log = logging.getLogger(__name__) + + +class CommonTypedDict(TypedDict): + config_file: Annotated[ + bool, + click.option('--config-file', + type=click.Path(), + callback=click_get_defaults_from_file, + is_eager=True, + expose_value=False, + help='Read configuration from yaml file'), + ] + drop_old: Annotated[ + bool, + click.option( + "--drop-old/--skip-drop-old", + type=bool, + default=True, + help="Drop old or skip", + show_default=True, + ), + ] + load: Annotated[ + bool, + click.option( + "--load/--skip-load", + type=bool, + default=True, + help="Load or skip", + show_default=True, + ), + ] + search_serial: Annotated[ + bool, + click.option( + "--search-serial/--skip-search-serial", + type=bool, + default=True, + help="Search serial or skip", + show_default=True, + ), + ] + search_concurrent: Annotated[ + bool, + click.option( + "--search-concurrent/--skip-search-concurrent", + type=bool, + default=True, + help="Search concurrent or skip", + show_default=True, + ), + ] + case_type: Annotated[ + str, + click.option( + "--case-type", + type=click.Choice([ct.name for ct in CaseType if ct.name != "Custom"]), + default="Performance1536D50K", + help="Case type", + ), + ] + db_label: Annotated[ + str, + click.option( + "--db-label", type=str, help="Db label, default: date in ISO format", + show_default=True, + default=datetime.now().isoformat() + ), + ] + dry_run: Annotated[ + bool, + click.option( + "--dry-run", + type=bool, + default=False, + is_flag=True, + help="Print just the configuration and exit without running the tasks", + ), + ] + k: Annotated[ + int, + click.option( + "--k", + type=int, + default=config.K_DEFAULT, + show_default=True, + help="K value for number of nearest neighbors to search", + ), + ] + concurrency_duration: Annotated[ + int, + click.option( + "--concurrency-duration", + type=int, + default=config.CONCURRENCY_DURATION, + show_default=True, + help="Adjusts the duration in seconds of each concurrency search", + ), + ] + num_concurrency: Annotated[ + List[str], + click.option( + "--num-concurrency", + type=str, + help="Comma-separated list of concurrency values to test during concurrent search", + show_default=True, + default=",".join(map(str, config.NUM_CONCURRENCY)), + callback=lambda *args: list(map(int, click_arg_split(*args))), + ), + ] + + +class HNSWBaseTypedDict(TypedDict): + m: Annotated[Optional[int], click.option("--m", type=int, help="hnsw m")] + ef_construction: Annotated[ + Optional[int], + click.option("--ef-construction", type=int, help="hnsw ef-construction"), + ] + + +class HNSWBaseRequiredTypedDict(TypedDict): + m: Annotated[Optional[int], click.option("--m", type=int, help="hnsw m", required=True)] + ef_construction: Annotated[ + Optional[int], + click.option("--ef-construction", type=int, help="hnsw ef-construction", required=True), + ] + + +class HNSWFlavor1(HNSWBaseTypedDict): + ef_search: Annotated[ + Optional[int], click.option("--ef-search", type=int, help="hnsw ef-search") + ] + + +class HNSWFlavor2(HNSWBaseTypedDict): + ef_runtime: Annotated[ + Optional[int], click.option("--ef-runtime", type=int, help="hnsw ef-runtime") + ] + + +class HNSWFlavor3(HNSWBaseRequiredTypedDict): + ef_search: Annotated[ + Optional[int], click.option("--ef-search", type=int, help="hnsw ef-search", required=True) + ] + + +class IVFFlatTypedDict(TypedDict): + lists: Annotated[ + Optional[int], click.option("--lists", type=int, help="ivfflat lists") + ] + probes: Annotated[ + Optional[int], click.option("--probes", type=int, help="ivfflat probes") + ] + + +class IVFFlatTypedDictN(TypedDict): + nlist: Annotated[ + Optional[int], click.option("--lists", "nlist", type=int, help="ivfflat lists", required=True) + ] + nprobe: Annotated[ + Optional[int], click.option("--probes", "nprobe", type=int, help="ivfflat probes", required=True) + ] + + +@click.group() +def cli(): + ... + + +def run( + db: DB, + db_config: DBConfig, + db_case_config: DBCaseConfig, + **parameters: Unpack[CommonTypedDict], +): + """Builds a single VectorDBBench Task and runs it, awaiting the task until finished. + + Args: + db (DB) + db_config (DBConfig) + db_case_config (DBCaseConfig) + **parameters: expects keys from CommonTypedDict + """ + + task = TaskConfig( + db=db, + db_config=db_config, + db_case_config=db_case_config, + case_config=CaseConfig( + case_id=CaseType[parameters["case_type"]], + k=parameters["k"], + concurrency_search_config=ConcurrencySearchConfig( + concurrency_duration=parameters["concurrency_duration"], + num_concurrency=[int(s) for s in parameters["num_concurrency"]], + ), + ), + stages=parse_task_stages( + ( + False if not parameters["load"] else parameters["drop_old"] + ), # only drop old data if loading new data + parameters["load"], + parameters["search_serial"], + parameters["search_concurrent"], + ), + ) + + log.info(f"Task:\n{pformat(task)}\n") + if not parameters["dry_run"]: + benchMarkRunner.run([task]) + time.sleep(5) + if global_result_future: + wait([global_result_future]) diff --git a/vectordb_bench/cli/vectordbbench.py b/vectordb_bench/cli/vectordbbench.py new file mode 100644 index 000000000..396909cd5 --- /dev/null +++ b/vectordb_bench/cli/vectordbbench.py @@ -0,0 +1,20 @@ +from ..backend.clients.pgvector.cli import PgVectorHNSW +from ..backend.clients.redis.cli import Redis +from ..backend.clients.test.cli import Test +from ..backend.clients.weaviate_cloud.cli import Weaviate +from ..backend.clients.zilliz_cloud.cli import ZillizAutoIndex +from ..backend.clients.milvus.cli import MilvusAutoIndex + + +from .cli import cli + +cli.add_command(PgVectorHNSW) +cli.add_command(Redis) +cli.add_command(Weaviate) +cli.add_command(Test) +cli.add_command(ZillizAutoIndex) +cli.add_command(MilvusAutoIndex) + + +if __name__ == "__main__": + cli() diff --git a/vectordb_bench/config-files/sample_config.yml b/vectordb_bench/config-files/sample_config.yml new file mode 100644 index 000000000..5a7d2cfd5 --- /dev/null +++ b/vectordb_bench/config-files/sample_config.yml @@ -0,0 +1,17 @@ +pgvectorhnsw: + db_label: pgConfigTest + user_name: vectordbbench + db_name: vectordbbench + host: localhost + m: 16 + ef_construction: 128 + ef_search: 128 +milvushnsw: + skip_search_serial: True + case_type: Performance1536D50K + uri: http://localhost:19530 + m: 16 + ef_construction: 128 + ef_search: 128 + drop_old: False + load: False diff --git a/vectordb_bench/frontend/components/run_test/dbSelector.py b/vectordb_bench/frontend/components/run_test/dbSelector.py index 61db843f3..5fcbd8c08 100644 --- a/vectordb_bench/frontend/components/run_test/dbSelector.py +++ b/vectordb_bench/frontend/components/run_test/dbSelector.py @@ -1,3 +1,5 @@ +from streamlit.runtime.media_file_storage import MediaFileStorageError + from vectordb_bench.frontend.const.styles import * from vectordb_bench.frontend.const.dbCaseConfigs import DB_LIST @@ -30,7 +32,11 @@ def dbSelector(st): for i, db in enumerate(DB_LIST): column = dbContainerColumns[i % DB_SELECTOR_COLUMNS] dbIsActived[db] = column.checkbox(db.name) - column.image(DB_TO_ICON.get(db, "")) + try: + column.image(DB_TO_ICON.get(db, "")) + except MediaFileStorageError as e: + column.warning(f"{db.name} image not available") + pass activedDbList = [db for db in DB_LIST if dbIsActived[db]] return activedDbList diff --git a/vectordb_bench/frontend/components/run_test/submitTask.py b/vectordb_bench/frontend/components/run_test/submitTask.py index 22d34b4f5..26cb1ef70 100644 --- a/vectordb_bench/frontend/components/run_test/submitTask.py +++ b/vectordb_bench/frontend/components/run_test/submitTask.py @@ -37,22 +37,30 @@ def taskLabelInput(st): def advancedSettings(st): container = st.columns([1, 2]) index_already_exists = container[0].checkbox("Index already exists", value=False) - container[1].caption("if actived, inserting and building will be skipped.") + container[1].caption("if selected, inserting and building will be skipped.") container = st.columns([1, 2]) use_aliyun = container[0].checkbox("Dataset from Aliyun (Shanghai)", value=False) container[1].caption( - "if actived, the dataset will be downloaded from Aliyun OSS shanghai, default AWS S3 aws-us-west." + "if selected, the dataset will be downloaded from Aliyun OSS shanghai, default AWS S3 aws-us-west." ) - return index_already_exists, use_aliyun + container = st.columns([1, 2]) + k = container[0].number_input("k",min_value=1, value=100, label_visibility="collapsed") + container[1].caption( + "K value for number of nearest neighbors to search" + ) + + return index_already_exists, use_aliyun, k def controlPanel(st, tasks, taskLabel, isAllValid): - index_already_exists, use_aliyun = advancedSettings(st) + index_already_exists, use_aliyun, k = advancedSettings(st) def runHandler(): benchMarkRunner.set_drop_old(not index_already_exists) + for task in tasks: + task.case_config.k = k benchMarkRunner.set_download_address(use_aliyun) benchMarkRunner.run(tasks, taskLabel) diff --git a/vectordb_bench/frontend/const/dbCaseConfigs.py b/vectordb_bench/frontend/const/dbCaseConfigs.py index 1e69c57aa..ed101ac69 100644 --- a/vectordb_bench/frontend/const/dbCaseConfigs.py +++ b/vectordb_bench/frontend/const/dbCaseConfigs.py @@ -9,7 +9,7 @@ MAX_STREAMLIT_INT = (1 << 53) - 1 -DB_LIST = [d for d in DB] +DB_LIST = [d for d in DB if d != DB.Test] DIVIDER = "DIVIDER" CASE_LIST_WITH_DIVIDER = [ @@ -19,6 +19,7 @@ DIVIDER, CaseType.Performance1536D5M, CaseType.Performance1536D500K, + CaseType.Performance1536D50K, DIVIDER, CaseType.Performance768D10M1P, CaseType.Performance768D1M1P, diff --git a/vectordb_bench/interface.py b/vectordb_bench/interface.py index c170c67dc..c765d0d63 100644 --- a/vectordb_bench/interface.py +++ b/vectordb_bench/interface.py @@ -1,38 +1,33 @@ -import traceback +import concurrent.futures +import logging +import multiprocessing as mp import pathlib import signal -import logging +import traceback import uuid -import concurrent -import multiprocessing as mp +from enum import Enum from multiprocessing.connection import Connection import psutil -from enum import Enum from . import config -from .metric import Metric -from .models import ( - TaskConfig, - TestResult, - CaseResult, - LoadTimeoutError, - PerformanceTimeoutError, - ResultLabel, -) -from .backend.result_collector import ResultCollector from .backend.assembler import Assembler -from .backend.task_runner import TaskRunner from .backend.data_source import DatasetSource +from .backend.result_collector import ResultCollector +from .backend.task_runner import TaskRunner +from .metric import Metric +from .models import (CaseResult, LoadTimeoutError, PerformanceTimeoutError, + ResultLabel, TaskConfig, TaskStage, TestResult) log = logging.getLogger(__name__) global_result_future: concurrent.futures.Future | None = None + class SIGNAL(Enum): - SUCCESS=0 - ERROR=1 - WIP=2 + SUCCESS = 0 + ERROR = 1 + WIP = 2 class BenchMarkRunner: @@ -42,9 +37,11 @@ def __init__(self): self.drop_old: bool = True self.dataset_source: DatasetSource = DatasetSource.S3 + def set_drop_old(self, drop_old: bool): self.drop_old = drop_old + def set_download_address(self, use_aliyun: bool): if use_aliyun: self.dataset_source = DatasetSource.AliyunOSS @@ -152,13 +149,13 @@ def _async_task_v2(self, running_task: TaskRunner, send_conn: Connection) -> Non latest_runner, cached_load_duration = None, None for idx, runner in enumerate(running_task.case_runners): case_res = CaseResult( - result_id=idx, metrics=Metric(), task_config=runner.config, ) # drop_old = False if latest_runner and runner == latest_runner else config.DROP_OLD - drop_old = config.DROP_OLD + # drop_old = config.DROP_OLD + drop_old = TaskStage.DROP_OLD in runner.config.stages if latest_runner and runner == latest_runner: drop_old = False elif not self.drop_old: @@ -167,7 +164,7 @@ def _async_task_v2(self, running_task: TaskRunner, send_conn: Connection) -> Non log.info(f"[{idx+1}/{running_task.num_cases()}] start case: {runner.display()}, drop_old={drop_old}") case_res.metrics = runner.run(drop_old) log.info(f"[{idx+1}/{running_task.num_cases()}] finish case: {runner.display()}, " - f"result={case_res.metrics}, label={case_res.label}") + f"result={case_res.metrics}, label={case_res.label}") # cache the latest succeeded runner latest_runner = runner @@ -193,7 +190,6 @@ def _async_task_v2(self, running_task: TaskRunner, send_conn: Connection) -> Non c_results.append(case_res) send_conn.send((SIGNAL.WIP, idx)) - test_result = TestResult( run_id=running_task.run_id, task_label=running_task.task_label, @@ -204,7 +200,7 @@ def _async_task_v2(self, running_task: TaskRunner, send_conn: Connection) -> Non send_conn.send((SIGNAL.SUCCESS, None)) send_conn.close() - log.info(f"Succes to finish task: label={running_task.task_label}, run_id={running_task.run_id}") + log.info(f"Success to finish task: label={running_task.task_label}, run_id={running_task.run_id}") except Exception as e: err_msg = f"An error occurs when running task={running_task.task_label}, run_id={running_task.run_id}, err={e}" @@ -246,7 +242,7 @@ def kill_proc_tree(self, sig=signal.SIGTERM, timeout=None, on_terminate=None): called as soon as a child terminates. """ children = psutil.Process().children(recursive=True) - for p in children: + for p in children: try: log.warning(f"sending SIGTERM to child process: {p}") p.send_signal(sig) diff --git a/vectordb_bench/models.py b/vectordb_bench/models.py index ec1b610e1..aa9c930ea 100644 --- a/vectordb_bench/models.py +++ b/vectordb_bench/models.py @@ -17,7 +17,6 @@ from . import config from .metric import Metric - log = logging.getLogger(__name__) @@ -65,15 +64,55 @@ class CaseConfigParamType(Enum): maintenance_work_mem = "maintenance_work_mem" max_parallel_workers = "max_parallel_workers" + class CustomizedCase(BaseModel): pass +class ConcurrencySearchConfig(BaseModel): + num_concurrency: List[int] = config.NUM_CONCURRENCY + concurrency_duration: int = config.CONCURRENCY_DURATION + + class CaseConfig(BaseModel): """cases, dataset, test cases, filter rate, params""" case_id: CaseType custom_case: dict | None = None + k: int | None = config.K_DEFAULT + concurrency_search_config: ConcurrencySearchConfig = ConcurrencySearchConfig() + + ''' + @property + def k(self): + """K search parameter, default is config.K_DEFAULT""" + return self._k + + # + @k.setter + def k(self, value): + self._k = value + ''' + +class TaskStage(StrEnum): + """Enumerations of various stages of the task""" + + DROP_OLD = auto() + LOAD = auto() + SEARCH_SERIAL = auto() + SEARCH_CONCURRENT = auto() + + def __repr__(self) -> str: + return str.__repr__(self.value) + + +# TODO: Add CapacityCase enums and adjust TaskRunner to utilize +ALL_TASK_STAGES = [ + TaskStage.DROP_OLD, + TaskStage.LOAD, + TaskStage.SEARCH_SERIAL, + TaskStage.SEARCH_CONCURRENT, +] class TaskConfig(BaseModel): @@ -81,6 +120,7 @@ class TaskConfig(BaseModel): db_config: DBConfig db_case_config: DBCaseConfig case_config: CaseConfig + stages: List[TaskStage] = ALL_TASK_STAGES @property def db_name(self): @@ -210,18 +250,18 @@ def append_return(x, y): max_db = max(map(len, [f.task_config.db.name for f in filtered_results])) max_db_labels = ( - max(map(len, [f.task_config.db_config.db_label for f in filtered_results])) - + 3 + max(map(len, [f.task_config.db_config.db_label for f in filtered_results])) + + 3 ) max_case = max( map(len, [f.task_config.case_config.case_id.name for f in filtered_results]) ) max_load_dur = ( - max(map(len, [str(f.metrics.load_duration) for f in filtered_results])) + 3 + max(map(len, [str(f.metrics.load_duration) for f in filtered_results])) + 3 ) max_qps = max(map(len, [str(f.metrics.qps) for f in filtered_results])) + 3 max_recall = ( - max(map(len, [str(f.metrics.recall) for f in filtered_results])) + 3 + max(map(len, [str(f.metrics.recall) for f in filtered_results])) + 3 ) max_db_labels = 8 if max_db_labels < 8 else max_db_labels