Skip to content

Commit

Permalink
Add client aws_opensearch
Browse files Browse the repository at this point in the history
Signed-off-by: Cai Yudong <yudong.cai@zilliz.com>
  • Loading branch information
cydrain committed Jul 15, 2024
1 parent 09306a0 commit 676e68c
Show file tree
Hide file tree
Showing 5 changed files with 359 additions and 0 deletions.
13 changes: 13 additions & 0 deletions vectordb_bench/backend/clients/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class DB(Enum):
PgVectoRS = "PgVectoRS"
Redis = "Redis"
Chroma = "Chroma"
AWSOpenSearch = "OpenSearch"
Test = "test"


Expand Down Expand Up @@ -78,6 +79,10 @@ def init_cls(self) -> Type[VectorDB]:
from .chroma.chroma import ChromaClient
return ChromaClient

if self == DB.AWSOpenSearch:
from .aws_opensearch.aws_opensearch import AWSOpenSearch
return AWSOpenSearch

@property
def config_cls(self) -> Type[DBConfig]:
"""Import while in use"""
Expand Down Expand Up @@ -121,6 +126,10 @@ def config_cls(self) -> Type[DBConfig]:
from .chroma.config import ChromaConfig
return ChromaConfig

if self == DB.AWSOpenSearch:
from .aws_opensearch.config import AWSOpenSearchConfig
return AWSOpenSearchConfig

def case_config_cls(self, index_type: IndexType | None = None) -> Type[DBCaseConfig]:
if self == DB.Milvus:
from .milvus.config import _milvus_case_config
Expand Down Expand Up @@ -150,6 +159,10 @@ def case_config_cls(self, index_type: IndexType | None = None) -> Type[DBCaseCon
from .pgvecto_rs.config import _pgvecto_rs_case_config
return _pgvecto_rs_case_config.get(index_type)

if self == DB.AWSOpenSearch:
from .aws_opensearch.config import AWSOpenSearchIndexConfig
return AWSOpenSearchIndexConfig

# DB.Pinecone, DB.Chroma, DB.Redis
return EmptyDBCaseConfig

Expand Down
159 changes: 159 additions & 0 deletions vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
import logging
from contextlib import contextmanager
import time
from typing import Iterable, Type
from ..api import VectorDB, DBCaseConfig, DBConfig, IndexType
from .config import AWSOpenSearchConfig, AWSOpenSearchIndexConfig
from opensearchpy import OpenSearch
from opensearchpy.helpers import bulk

log = logging.getLogger(__name__)


class AWSOpenSearch(VectorDB):
def __init__(
self,
dim: int,
db_config: dict,
db_case_config: AWSOpenSearchIndexConfig,
index_name: str = "vdb_bench_index", # must be lowercase
id_col_name: str = "id",
vector_col_name: str = "embedding",
drop_old: bool = False,
**kwargs,
):
self.dim = dim
self.db_config = db_config
self.case_config = db_case_config
self.index_name = index_name
self.id_col_name = id_col_name
self.category_col_names = [
f"scalar-{categoryCount}" for categoryCount in [2, 5, 10, 100, 1000]
]
self.vector_col_name = vector_col_name

log.info(f"AWS_OpenSearch client config: {self.db_config}")
client = OpenSearch(**self.db_config)
if drop_old:
log.info(f"AWS_OpenSearch client drop old index: {self.index_name}")
is_existed = client.indices.exists(index=self.index_name)
if is_existed:
client.indices.delete(index=self.index_name)
self._create_index(client)

@classmethod
def config_cls(cls) -> AWSOpenSearchConfig:
return AWSOpenSearchConfig

@classmethod
def case_config_cls(
cls, index_type: IndexType | None = None
) -> AWSOpenSearchIndexConfig:
return AWSOpenSearchIndexConfig

def _create_index(self, client: OpenSearch):
settings = {
"index": {
"knn": True,
# "number_of_shards": 5,
# "refresh_interval": "600s",
}
}
mappings = {
"properties": {
self.id_col_name: {"type": "integer"},
**{
categoryCol: {"type": "keyword"}
for categoryCol in self.category_col_names
},
self.vector_col_name: {
"type": "knn_vector",
"dimension": self.dim,
"method": self.case_config.index_param(),
},
}
}
try:
client.indices.create(
index=self.index_name, body=dict(settings=settings, mappings=mappings)
)
except Exception as e:
log.warning(f"Failed to create index: {self.index_name} error: {str(e)}")
raise e from None

@contextmanager
def init(self) -> None:
"""connect to elasticsearch"""
self.client = OpenSearch(**self.db_config)

yield
# self.client.transport.close()
self.client = None
del self.client

def insert_embeddings(
self,
embeddings: Iterable[list[float]],
metadata: list[int],
**kwargs,
) -> tuple[int, Exception]:
"""Insert the embeddings to the elasticsearch."""
assert self.client is not None, "should self.init() first"

insert_data = []
for i in range(len(embeddings)):
insert_data.append({"index": {"_index": self.index_name, "_id": metadata[i]}})
insert_data.append({self.vector_col_name: embeddings[i]})
try:
resp = self.client.bulk(insert_data)
log.info(f"AWS_OpenSearch adding documents: {len(resp['items'])}")
resp = self.client.indices.stats(self.index_name)
log.info(f"Total document count in index: {resp['_all']['primaries']['indexing']['index_total']}")
return (len(embeddings), None)
except Exception as e:
log.warning(f"Failed to insert data: {self.index_name} error: {str(e)}")
time.sleep(10)
return self.insert_embeddings(embeddings, metadata)

def search_embedding(
self,
query: list[float],
k: int = 100,
filters: dict | None = None,
) -> list[int]:
"""Get k most similar embeddings to query vector.
Args:
query(list[float]): query embedding to look up documents similar to.
k(int): Number of most similar embeddings to return. Defaults to 100.
filters(dict, optional): filtering expression to filter the data while searching.
Returns:
list[tuple[int, float]]: list of k most similar embeddings in (id, score) tuple to the query embedding.
"""
assert self.client is not None, "should self.init() first"

body = {
"size": k,
"query": {"knn": {self.vector_col_name: {"vector": query, "k": k}}},
}
try:
resp = self.client.search(index=self.index_name, body=body)
log.info(f'Search took: {resp["took"]}')
log.info(f'Search shards: {resp["_shards"]}')
log.info(f'Search hits total: {resp["hits"]["total"]}')
result = [int(d["_id"]) for d in resp["hits"]["hits"]]
# log.info(f'success! length={len(res)}')

return result
except Exception as e:
log.warning(f"Failed to search: {self.index_name} error: {str(e)}")
raise e from None

def optimize(self):
"""optimize will be called between insertion and search in performance cases."""
pass

def ready_to_load(self):
"""ready_to_load will be called before load in load cases."""
pass
60 changes: 60 additions & 0 deletions vectordb_bench/backend/clients/aws_opensearch/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from enum import Enum
from pydantic import SecretStr, BaseModel

from ..api import DBConfig, DBCaseConfig, MetricType, IndexType


class AWSOpenSearchConfig(DBConfig, BaseModel):
host: str = (
"xxxxxx.us-west-2.es.amazonaws.com"
)
port: int = 443
user: str = "admin"
password: SecretStr = "xxxxxx"

def to_dict(self) -> dict:
return {
"hosts": [{'host': self.host, 'port': self.port}],
"http_auth": (self.user, self.password.get_secret_value()),
"use_ssl": True,
"http_compress": True,
"verify_certs": True,
"ssl_assert_hostname": False,
"ssl_show_warn": False,
"timeout": 600,
}


class AWSOS_Engine(Enum):
nmslib = "nmslib"
faiss = "faiss"
lucene = "Lucene"


class AWSOpenSearchIndexConfig(BaseModel, DBCaseConfig):
metric_type: MetricType = MetricType.L2
engine: AWSOS_Engine = AWSOS_Engine.nmslib
efConstruction: int = 360
M: int = 30

def parse_metric(self) -> str:
if self.metric_type == MetricType.IP:
return "innerproduct" # only support faiss / nmslib, not for Lucene.
elif self.metric_type == MetricType.COSINE:
return "cosinesimil"
return "l2"

def index_param(self) -> dict:
params = {
"name": "hnsw",
"space_type": self.parse_metric(),
"engine": self.engine.value,
"parameters": {
"ef_construction": self.efConstruction,
"m": self.M
}
}
return params

def search_param(self) -> dict:
return {}
125 changes: 125 additions & 0 deletions vectordb_bench/backend/clients/aws_opensearch/run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import time, random
from opensearchpy import OpenSearch
from opensearch_dsl import Search, Document, Text, Keyword

_HOST = 'xxxxxx.us-west-2.es.amazonaws.com'
_PORT = 443
_AUTH = ('admin', 'xxxxxx') # For testing only. Don't store credentials in code.

_INDEX_NAME = 'my-dsl-index'
_BATCH = 100
_ROWS = 100
_DIM = 128
_TOPK = 10


def create_client():
client = OpenSearch(
hosts=[{'host': _HOST, 'port': _PORT}],
http_compress=True, # enables gzip compression for request bodies
http_auth=_AUTH,
use_ssl=True,
verify_certs=True,
ssl_assert_hostname=False,
ssl_show_warn=False,
)
return client


def create_index(client, index_name):
settings = {
"index": {
"knn": True,
"number_of_shards": 1,
"refresh_interval": "5s",
}
}
mappings = {
"properties": {
"embedding": {
"type": "knn_vector",
"dimension": _DIM,
"method": {
"engine": "nmslib",
"name": "hnsw",
"space_type": "l2",
"parameters": {
"ef_construction": 128,
"m": 24,
}
}
}
}
}

response = client.indices.create(index=index_name, body=dict(settings=settings, mappings=mappings))
print('\nCreating index:')
print(response)


def delete_index(client, index_name):
response = client.indices.delete(index=index_name)
print('\nDeleting index:')
print(response)


def bulk_insert(client, index_name):
# Perform bulk operations
ids = [i for i in range(_ROWS)]
vec = [[random.random() for _ in range(_DIM)] for _ in range(_ROWS)]

docs = []
for i in range(0, _ROWS, _BATCH):
docs.clear()
for j in range(0, _BATCH):
docs.append({"index": {"_index": index_name, "_id": ids[i+j]}})
docs.append({"embedding": vec[i+j]})
response = client.bulk(docs)
print('\nAdding documents:', len(response['items']), response['errors'])
response = client.indices.stats(index_name)
print('\nTotal document count in index:', response['_all']['primaries']['indexing']['index_total'])


def search(client, index_name):
# Search for the document.
search_body = {
"size": _TOPK,
"query": {
"knn": {
"embedding": {
"vector": [random.random() for _ in range(_DIM)],
"k": _TOPK,
}
}
}
}
while True:
response = client.search(index=index_name, body=search_body)
print(f'\nSearch took: {response["took"]}')
print(f'\nSearch shards: {response["_shards"]}')
print(f'\nSearch hits total: {response["hits"]["total"]}')
result = response["hits"]["hits"]
if len(result) != 0:
print('\nSearch results:')
for hit in response["hits"]["hits"]:
print(hit["_id"], hit["_score"])
break
else:
print('\nSearch not ready, sleep 1s')
time.sleep(1)


def main():
client = create_client()
try:
create_index(client, _INDEX_NAME)
bulk_insert(client, _INDEX_NAME)
search(client, _INDEX_NAME)
delete_index(client, _INDEX_NAME)
except Exception as e:
print(e)
delete_index(client, _INDEX_NAME)


if __name__ == '__main__':
main()
Loading

0 comments on commit 676e68c

Please sign in to comment.