Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for MemoryDB client #339

Merged
merged 7 commits into from
Jul 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ __MACOSX
build/
venv/
.idea/
results/
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ All the database client supported
| pgvector | `pip install vectordb-bench[pgvector]` |
| pgvecto.rs | `pip install vectordb-bench[pgvecto_rs]` |
| redis | `pip install vectordb-bench[redis]` |
| memorydb | `pip install vectordb-bench[memorydb]` |
| chromadb | `pip install vectordb-bench[chromadb]` |
| awsopensearch | `pip install vectordb-bench[awsopensearch]` |

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ elastic = [ "elasticsearch" ]
pgvector = [ "psycopg", "psycopg-binary", "pgvector" ]
pgvecto_rs = [ "psycopg2" ]
redis = [ "redis" ]
memorydb = [ "memorydb" ]
chromadb = [ "chromadb" ]
awsopensearch = [ "awsopensearch" ]
zilliz_cloud = []
Expand Down
9 changes: 9 additions & 0 deletions vectordb_bench/backend/clients/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class DB(Enum):
PgVector = "PgVector"
PgVectoRS = "PgVectoRS"
Redis = "Redis"
MemoryDB = "MemoryDB"
Chroma = "Chroma"
AWSOpenSearch = "OpenSearch"
Test = "test"
Expand Down Expand Up @@ -74,6 +75,10 @@ def init_cls(self) -> Type[VectorDB]:
if self == DB.Redis:
from .redis.redis import Redis
return Redis

if self == DB.MemoryDB:
from .memorydb.memorydb import MemoryDB
return MemoryDB

if self == DB.Chroma:
from .chroma.chroma import ChromaClient
Expand Down Expand Up @@ -121,6 +126,10 @@ def config_cls(self) -> Type[DBConfig]:
if self == DB.Redis:
from .redis.config import RedisConfig
return RedisConfig

if self == DB.MemoryDB:
from .memorydb.config import MemoryDBConfig
return MemoryDBConfig

if self == DB.Chroma:
from .chroma.config import ChromaConfig
Expand Down
88 changes: 88 additions & 0 deletions vectordb_bench/backend/clients/memorydb/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from typing import Annotated, TypedDict, Unpack

import click
from pydantic import SecretStr

from ....cli.cli import (
CommonTypedDict,
HNSWFlavor2,
cli,
click_parameter_decorators_from_typed_dict,
run,
)
from .. import DB


class MemoryDBTypedDict(TypedDict):
    """Command-line options specific to the MemoryDB client.

    Every key is annotated with the ``click.option`` that supplies its value;
    the annotations are turned into decorators by
    ``click_parameter_decorators_from_typed_dict``.
    """

    # Endpoint to connect to; the only mandatory option in this group.
    host: Annotated[
        str,
        click.option("--host", required=True, type=str, help="Db host"),
    ]

    # No default is declared, so the value is None when the flag is omitted.
    password: Annotated[
        str,
        click.option("--password", type=str, help="Db password"),
    ]

    port: Annotated[
        int,
        click.option("--port", default=6379, type=int, help="Db Port"),
    ]

    # Paired on/off switch; SSL is on unless --no-ssl is given.
    ssl: Annotated[
        bool,
        click.option(
            "--ssl/--no-ssl",
            default=True,
            is_flag=True,
            show_default=True,
            help="Enable or disable SSL for MemoryDB",
        ),
    ]

    ssl_ca_certs: Annotated[
        str,
        click.option(
            "--ssl-ca-certs",
            help="Path to certificate authority file to use for SSL",
            show_default=True,
        ),
    ]

    # "cmd" stands for Cluster Mode Disabled (see the help text), not "command".
    cmd: Annotated[
        bool,
        click.option(
            "--cmd",
            default=False,
            is_flag=True,
            show_default=True,
            help="Cluster Mode Disabled (CMD), use this flag when testing locally on a single node instance. In production, MemoryDB only supports cluster mode (CME)",
        ),
    ]

    insert_batch_size: Annotated[
        int,
        click.option(
            "--insert-batch-size",
            default=10,
            type=int,
            help="Batch size for inserting data. Adjust this as needed, but don't make it too big",
        ),
    ]


class MemoryDBHNSWTypedDict(CommonTypedDict, MemoryDBTypedDict, HNSWFlavor2):
    """Full option set of the MemoryDB command.

    Adds no keys of its own: it only merges the shared benchmark options, the
    MemoryDB connection options and the ``HNSWFlavor2`` index options.
    """


@cli.command()
@click_parameter_decorators_from_typed_dict(MemoryDBHNSWTypedDict)
def MemoryDB(**parameters: Unpack[MemoryDBHNSWTypedDict]):
    # CLI entry point: translate the parsed options into a connection config
    # and an HNSW case config, then hand everything to the shared `run` helper.
    #
    # NOTE: documented with comments rather than a docstring on purpose; click
    # shows a command function's docstring as its --help text, so adding one
    # would change user-facing output.

    # Imported lazily so the config module is only loaded when this command runs.
    from .config import MemoryDBConfig, MemoryDBHNSWConfig

    # A missing or empty password is passed on as None rather than wrapped.
    raw_password = parameters["password"]
    secret_password = SecretStr(raw_password) if raw_password else None

    connection_config = MemoryDBConfig(
        db_label=parameters["db_label"],
        password=secret_password,
        host=SecretStr(parameters["host"]),
        port=parameters["port"],
        ssl=parameters["ssl"],
        ssl_ca_certs=parameters["ssl_ca_certs"],
        cmd=parameters["cmd"],
    )

    hnsw_config = MemoryDBHNSWConfig(
        M=parameters["m"],
        ef_construction=parameters["ef_construction"],
        ef_runtime=parameters["ef_runtime"],
        insert_batch_size=parameters["insert_batch_size"],
    )

    # The raw option dict is forwarded as well, exactly as parsed.
    run(
        db=DB.MemoryDB,
        db_config=connection_config,
        db_case_config=hnsw_config,
        **parameters,
    )
54 changes: 54 additions & 0 deletions vectordb_bench/backend/clients/memorydb/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from pydantic import BaseModel, SecretStr

from ..api import DBCaseConfig, DBConfig, IndexType, MetricType


class MemoryDBConfig(DBConfig):
    """Connection settings for a MemoryDB endpoint.

    ``host`` and ``password`` are held as ``SecretStr``; ``to_dict`` is the
    place where they are unwrapped into plain strings.
    """

    host: SecretStr
    password: SecretStr | None = None
    port: int | None = None
    ssl: bool | None = None
    # Cluster Mode Disabled switch, per the CLI help text for --cmd.
    cmd: bool | None = None
    ssl_ca_certs: str | None = None

    def to_dict(self) -> dict:
        """Return the connection settings as a plain dict.

        Secrets are unwrapped; a missing or empty password is reported as
        ``None``.
        """
        secret = self.password
        plain_password = secret.get_secret_value() if secret else None

        settings = {
            "host": self.host.get_secret_value(),
            "port": self.port,
            "password": plain_password,
            "ssl": self.ssl,
            "cmd": self.cmd,
            "ssl_ca_certs": self.ssl_ca_certs,
        }
        return settings


class MemoryDBIndexConfig(BaseModel, DBCaseConfig):
    """Case settings shared by MemoryDB index configurations."""

    metric_type: MetricType | None = None
    insert_batch_size: int | None = None

    def parse_metric(self) -> str:
        """Return the lowercase metric name for ``metric_type``.

        L2 maps to ``"l2"`` and IP to ``"ip"``; every other value, including
        an unset metric, falls back to ``"cosine"``.
        """
        known_metrics = (
            (MetricType.L2, "l2"),
            (MetricType.IP, "ip"),
        )
        for candidate, metric_name in known_metrics:
            if self.metric_type == candidate:
                return metric_name
        return "cosine"


class MemoryDBHNSWConfig(MemoryDBIndexConfig):
    """HNSW settings for MemoryDB: ``M``, ``ef_construction``, ``ef_runtime``.

    ``index_param`` reports the first two together with the metric and index
    type; ``search_param`` reports ``ef_runtime``.
    """

    M: int | None = 16
    ef_construction: int | None = 64
    ef_runtime: int | None = 10
    index: IndexType = IndexType.HNSW

    def index_param(self) -> dict:
        """Return the metric, index type, ``m`` and ``ef_construction``.

        Presumably consumed by the MemoryDB client when it creates the index;
        confirm against the caller.
        """
        params = {
            "metric": self.parse_metric(),
            "index_type": self.index.value,
            "m": self.M,
            "ef_construction": self.ef_construction,
        }
        return params

    def search_param(self) -> dict:
        """Return the query-time settings, currently only ``ef_runtime``."""
        params = {"ef_runtime": self.ef_runtime}
        return params
Loading