Skip to content

Commit

Permalink
polishing
Browse files: browse the repository at this point in the history
  • Loading branch information
khoroshevskyi committed Oct 14, 2024
1 parent 525aece commit 169002a
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 35 deletions.
53 changes: 31 additions & 22 deletions bbconf/config_parser/bedbaseconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import zarr
from botocore.exceptions import BotoCoreError, EndpointConnectionError
from geniml.region2vec.main import Region2VecExModel
from geniml.search import BED2BEDSearchInterface, QdrantBackend, Text2BEDSearchInterface
from geniml.search import BED2BEDSearchInterface, Text2BEDSearchInterface
from geniml.search.query2vec import BED2Vec, Text2Vec
from geniml.search.backends import BiVectorBackend, QdrantBackend
from geniml.search.interfaces import BiVectorSearchInterface
Expand Down Expand Up @@ -47,6 +47,7 @@ def __init__(self, config: Union[Path, str]):

self._db_engine = self._init_db_engine()
self._qdrant_engine = self._init_qdrant_backend()
self._qdrant_text_engine = self._init_qdrant_text_backend()
self._t2bsi = self._init_t2bsi_object()
self._b2bsi = self._init_b2bsi_object()
self._r2v = self._init_r2v_object()
Expand Down Expand Up @@ -97,14 +98,14 @@ def db_engine(self) -> BaseEngine:
"""
return self._db_engine

@property
def t2bsi(self) -> Union[Text2BEDSearchInterface, None]:
"""
Get text2bednn object
:return: text2bednn object
"""
return self._t2bsi
# @property
# def t2bsi(self) -> Union[Text2BEDSearchInterface, None]:
# """
# Get text2bednn object
#
# :return: text2bednn object
# """
# return self._t2bsi

@property
def b2bsi(self) -> Union[BED2BEDSearchInterface, None]:
Expand Down Expand Up @@ -134,7 +135,6 @@ def bivec(self) -> BiVectorSearchInterface:

return self._bivec


@property
def qdrant_engine(self) -> QdrantBackend:
"""
Expand Down Expand Up @@ -208,7 +208,7 @@ def _init_qdrant_backend(self) -> QdrantBackend:
"""
try:
return QdrantBackend(
collection=self._config.qdrant.collection,
collection=self._config.qdrant.file_collection,
qdrant_host=self._config.qdrant.host,
qdrant_port=self._config.qdrant.port,
qdrant_api_key=self._config.qdrant.api_key,
Expand All @@ -219,23 +219,30 @@ def _init_qdrant_backend(self) -> QdrantBackend:
f"error in Connection to qdrant! skipping... Error: {err}", UserWarning
)

def _init_qdrant_text_backend(self) -> QdrantBackend:
"""
Create the qdrant backend object for text embeddings using credentials provided in config file
:return: QdrantBackend
"""

return QdrantBackend(
dim=384,
collection=self.config.qdrant.text_collection,
qdrant_host=self.config.qdrant.host,
qdrant_api_key=self.config.qdrant.api_key,
)

def _init_bivec_object(self) -> Union[BiVectorSearchInterface, None]:
"""
Create BiVectorSearchInterface object using credentials provided in config file
:return: BiVectorSearchInterface
"""
text_backend = QdrantBackend(dim=384,
collection=self.config.qdrant.label_collection,
qdrant_host=self.config.qdrant.host,
qdrant_api_key=self.config.qdrant.api_key,
)
bed_backend = QdrantBackend(
collection=self.config.qdrant.file_collection,
qdrant_host=self.config.qdrant.host,
qdrant_api_key=self.config.qdrant.api_key,

search_backend = BiVectorBackend(
metadata_backend=self._qdrant_text_engine, bed_backend=self._qdrant_engine
)
search_backend = BiVectorBackend(text_backend, bed_backend)
search_interface = BiVectorSearchInterface(
backend=search_backend, query2vec="sentence-transformers/all-MiniLM-L6-v2"
)
Expand Down Expand Up @@ -325,7 +332,9 @@ def _init_r2v_object(self) -> Union[Region2VecExModel, None]:
return Region2VecExModel(self.config.path.region2vec)
except Exception as e:
_LOGGER.error(f"Error in creating Region2VecExModel object: {e}")
warnings.warn(f"Error in creating Region2VecExModel object: {e}", UserWarning)
warnings.warn(
f"Error in creating Region2VecExModel object: {e}", UserWarning
)
return None

def upload_s3(self, file_path: str, s3_path: Union[Path, str]) -> None:
Expand Down
1 change: 1 addition & 0 deletions bbconf/config_parser/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
DEFAULT_QDRANT_HOST = "localhost"
DEFAULT_QDRANT_PORT = 6333
DEFAULT_QDRANT_COLLECTION_NAME = "bedbase"
DEFAULT_QDRANT_TEXT_COLLECTION_NAME = "bed_text"
DEFAULT_QDRANT_API_KEY = None

DEFAULT_SERVER_PORT = 80
Expand Down
6 changes: 3 additions & 3 deletions bbconf/config_parser/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
DEFAULT_PEPHUB_NAMESPACE,
DEFAULT_PEPHUB_TAG,
DEFAULT_QDRANT_COLLECTION_NAME,
DEFAULT_QDRANT_TEXT_COLLECTION_NAME,
DEFAULT_QDRANT_PORT,
DEFAULT_REGION2_VEC_MODEL,
DEFAULT_S3_BUCKET,
Expand Down Expand Up @@ -52,9 +53,8 @@ class ConfigQdrant(BaseModel):
host: str
port: int = DEFAULT_QDRANT_PORT
api_key: Optional[str] = None
collection: str = DEFAULT_QDRANT_COLLECTION_NAME
label_collection: Optional[str] = "bed_text"
file_collection: Optional[str] = "bedbase2"
file_collection: str = DEFAULT_QDRANT_COLLECTION_NAME
text_collection: Optional[str] = DEFAULT_QDRANT_TEXT_COLLECTION_NAME


class ConfigServer(BaseModel):
Expand Down
19 changes: 9 additions & 10 deletions bbconf/modules/bedfiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ def get_embedding(self, identifier: str) -> BedEmbeddingResult:
if not self.exists(identifier):
raise BEDFileNotFoundError(f"Bed file with id: {identifier} not found.")
result = self._qdrant_engine.qd_client.retrieve(
collection_name=self._config.config.qdrant.collection,
collection_name=self._config.config.qdrant.file_collection,
ids=[identifier],
with_vectors=True,
with_payload=True,
Expand All @@ -362,14 +362,13 @@ def get_ids_list(
:param offset: offset to start from
:param genome: filter by genome
:param bed_type: filter by bed type. e.g. 'bed6+4'
:param full: if True, return full metadata, including statistics, files, and raw metadata from pephub
:return: list of bed file identifiers
"""
statement = select(Bed)
count_statement = select(func.count(Bed.id))

# TODO: make it generic, like in pephub
# TODO: make it generic, like in PEPhub
if genome:
statement = statement.where(and_(Bed.genome_alias == genome))
count_statement = count_statement.where(and_(Bed.genome_alias == genome))
Expand Down Expand Up @@ -769,7 +768,7 @@ def _embed_file(self, bed_file: Union[str, RegionSet]) -> np.ndarray:
"""
Create embedding for bed file
:param bed_id: bed file id
:param bed_file: bed file id
:param bed_file: path to the bed file, or RegionSet object
:return: np array of embeddings
Expand Down Expand Up @@ -806,9 +805,9 @@ def text_to_bed_search(
:return: list of bed file metadata
"""
_LOGGER.info(f"Looking for: {query}")
_LOGGER.info(f"Using backend: {self._config.t2bsi}")

# _LOGGER.info(f"Using backend: {self._config.t2bsi}")
# results = self._config.t2bsi.query_search(query, limit=limit, offset=offset)

results = self._config.bivec.query_search(query, limit=limit, offset=offset)
results_list = []
for result in results:
Expand Down Expand Up @@ -909,7 +908,7 @@ def delete_qdrant_point(self, identifier: str) -> None:
"""

result = self._config.qdrant_engine.qd_client.delete(
collection_name=self._config.config.qdrant.collection,
collection_name=self._config.config.qdrant.file_collection,
points_selector=PointIdsList(
points=[identifier],
),
Expand All @@ -921,7 +920,7 @@ def create_qdrant_collection(self) -> bool:
Create qdrant collection for bed files.
"""
return self._config.qdrant_engine.qd_client.create_collection(
collection_name=self._config.config.qdrant.collection,
collection_name=self._config.config.qdrant.file_collection,
vectors_config=VectorParams(size=100, distance=Distance.DOT),
)

Expand Down Expand Up @@ -1072,7 +1071,7 @@ def _add_zarr_s3(
"Set overwrite to True to overwrite it."
)

return os.path.join(ZARR_TOKENIZED_FOLDER, path)
return str(os.path.join(ZARR_TOKENIZED_FOLDER, path))

def get_tokenized(self, bed_id: str, universe_id: str) -> TokenizedBedResponse:
"""
Expand Down Expand Up @@ -1140,7 +1139,7 @@ def _get_tokenized_path(self, bed_id: str, universe_id: str) -> str:
),
)
tokenized_object = session.scalar(statement)
return tokenized_object.path
return str(tokenized_object.path)

def exist_tokenized(self, bed_id: str, universe_id: str) -> bool:
"""
Expand Down

0 comments on commit 169002a

Please sign in to comment.