Skip to content

Commit

Permalink
Support for MemoryDB client (#339)
Browse files Browse the repository at this point in the history
* Add support for MemoryDB

* Add support for MemoryDB

* Update vectordb_bench/backend/clients/memorydb/cli.py

Co-authored-by: Jonathan S. Katz <jkatz@users.noreply.github.com>

* Update vectordb_bench/backend/clients/memorydb/cli.py

Co-authored-by: Jonathan S. Katz <jkatz@users.noreply.github.com>

---------

Co-authored-by: Baswanth Vegunta <baswanth@amazon.com>
Co-authored-by: Jonathan S. Katz <jkatz@users.noreply.github.com>
  • Loading branch information
3 people authored Jul 31, 2024
1 parent f4ab02a commit e546a42
Show file tree
Hide file tree
Showing 8 changed files with 410 additions and 0 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ __MACOSX
build/
venv/
.idea/
results/
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ All the database client supported
| pgvector | `pip install vectordb-bench[pgvector]` |
| pgvecto.rs | `pip install vectordb-bench[pgvecto_rs]` |
| redis | `pip install vectordb-bench[redis]` |
| memorydb | `pip install vectordb-bench[memorydb]` |
| chromadb | `pip install vectordb-bench[chromadb]` |
| awsopensearch | `pip install vectordb-bench[awsopensearch]` |

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ elastic = [ "elasticsearch" ]
pgvector = [ "psycopg", "psycopg-binary", "pgvector" ]
pgvecto_rs = [ "psycopg2" ]
redis = [ "redis" ]
memorydb = [ "memorydb" ]
chromadb = [ "chromadb" ]
awsopensearch = [ "awsopensearch" ]
zilliz_cloud = []
Expand Down
9 changes: 9 additions & 0 deletions vectordb_bench/backend/clients/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class DB(Enum):
PgVector = "PgVector"
PgVectoRS = "PgVectoRS"
Redis = "Redis"
MemoryDB = "MemoryDB"
Chroma = "Chroma"
AWSOpenSearch = "OpenSearch"
Test = "test"
Expand Down Expand Up @@ -74,6 +75,10 @@ def init_cls(self) -> Type[VectorDB]:
if self == DB.Redis:
from .redis.redis import Redis
return Redis

if self == DB.MemoryDB:
from .memorydb.memorydb import MemoryDB
return MemoryDB

if self == DB.Chroma:
from .chroma.chroma import ChromaClient
Expand Down Expand Up @@ -121,6 +126,10 @@ def config_cls(self) -> Type[DBConfig]:
if self == DB.Redis:
from .redis.config import RedisConfig
return RedisConfig

if self == DB.MemoryDB:
from .memorydb.config import MemoryDBConfig
return MemoryDBConfig

if self == DB.Chroma:
from .chroma.config import ChromaConfig
Expand Down
88 changes: 88 additions & 0 deletions vectordb_bench/backend/clients/memorydb/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from typing import Annotated, TypedDict, Unpack

import click
from pydantic import SecretStr

from ....cli.cli import (
CommonTypedDict,
HNSWFlavor2,
cli,
click_parameter_decorators_from_typed_dict,
run,
)
from .. import DB


class MemoryDBTypedDict(TypedDict):
host: Annotated[
str, click.option("--host", type=str, help="Db host", required=True)
]
password: Annotated[str, click.option("--password", type=str, help="Db password")]
port: Annotated[int, click.option("--port", type=int, default=6379, help="Db Port")]
ssl: Annotated[
bool,
click.option(
"--ssl/--no-ssl",
is_flag=True,
show_default=True,
default=True,
help="Enable or disable SSL for MemoryDB",
),
]
ssl_ca_certs: Annotated[
str,
click.option(
"--ssl-ca-certs",
show_default=True,
help="Path to certificate authority file to use for SSL",
),
]
cmd: Annotated[
bool,
click.option(
"--cmd",
is_flag=True,
show_default=True,
default=False,
help="Cluster Mode Disabled (CMD), use this flag when testing locally on a single node instance. In production, MemoryDB only supports cluster mode (CME)",
),
]
insert_batch_size: Annotated[
int,
click.option(
"--insert-batch-size",
type=int,
default=10,
help="Batch size for inserting data. Adjust this as needed, but don't make it too big",
),
]


class MemoryDBHNSWTypedDict(CommonTypedDict, MemoryDBTypedDict, HNSWFlavor2):
...


@cli.command()
@click_parameter_decorators_from_typed_dict(MemoryDBHNSWTypedDict)
def MemoryDB(**parameters: Unpack[MemoryDBHNSWTypedDict]):
from .config import MemoryDBConfig, MemoryDBHNSWConfig

run(
db=DB.MemoryDB,
db_config=MemoryDBConfig(
db_label=parameters["db_label"],
password=SecretStr(parameters["password"]) if parameters["password"] else None,
host=SecretStr(parameters["host"]),
port=parameters["port"],
ssl=parameters["ssl"],
ssl_ca_certs=parameters["ssl_ca_certs"],
cmd=parameters["cmd"],
),
db_case_config=MemoryDBHNSWConfig(
M=parameters["m"],
ef_construction=parameters["ef_construction"],
ef_runtime=parameters["ef_runtime"],
insert_batch_size=parameters["insert_batch_size"]
),
**parameters,
)
54 changes: 54 additions & 0 deletions vectordb_bench/backend/clients/memorydb/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from pydantic import BaseModel, SecretStr

from ..api import DBCaseConfig, DBConfig, IndexType, MetricType


class MemoryDBConfig(DBConfig):
host: SecretStr
password: SecretStr | None = None
port: int | None = None
ssl: bool | None = None
cmd: bool | None = None
ssl_ca_certs: str | None = None

def to_dict(self) -> dict:
return {
"host": self.host.get_secret_value(),
"port": self.port,
"password": self.password.get_secret_value() if self.password else None,
"ssl": self.ssl,
"cmd": self.cmd,
"ssl_ca_certs": self.ssl_ca_certs,
}


class MemoryDBIndexConfig(BaseModel, DBCaseConfig):
metric_type: MetricType | None = None
insert_batch_size: int | None = None

def parse_metric(self) -> str:
if self.metric_type == MetricType.L2:
return "l2"
elif self.metric_type == MetricType.IP:
return "ip"
return "cosine"


class MemoryDBHNSWConfig(MemoryDBIndexConfig):
M: int | None = 16
ef_construction: int | None = 64
ef_runtime: int | None = 10
index: IndexType = IndexType.HNSW

def index_param(self) -> dict:
return {
"metric": self.parse_metric(),
"index_type": self.index.value,
"m": self.M,
"ef_construction": self.ef_construction,
}

def search_param(self) -> dict:
return {
"ef_runtime": self.ef_runtime,
}
Loading

0 comments on commit e546a42

Please sign in to comment.