Skip to content

Commit

Permalink
fixed geniml updates
Browse files Browse the repository at this point in the history
  • Loading branch information
khoroshevskyi committed Sep 27, 2024
1 parent db4f3ad commit 1a1865b
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 24 deletions.
2 changes: 1 addition & 1 deletion bbconf/config_parser/bedbaseconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import yacman
import zarr
from botocore.exceptions import BotoCoreError, EndpointConnectionError
from geniml.region2vec import Region2VecExModel
from geniml.region2vec.main import Region2VecExModel
from geniml.search import BED2BEDSearchInterface, QdrantBackend, Text2BEDSearchInterface
from geniml.search.query2vec import BED2Vec, Text2Vec
from pephubclient import PEPHubClient
Expand Down
40 changes: 24 additions & 16 deletions bbconf/modules/bedfiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np
from geniml.bbclient import BBClient
from geniml.io import RegionSet
from genimtools.tokenizers import RegionSet as GRegionSet
from gtars.tokenizers import RegionSet as GRegionSet
from pephubclient.exceptions import ResponseError
from qdrant_client.models import Distance, PointIdsList, VectorParams
from sqlalchemy import and_, delete, func, select
Expand Down Expand Up @@ -754,15 +754,33 @@ def upload_file_qdrant(
:param payload: additional metadata to store alongside vectors
:return: None
"""

_LOGGER.debug(f"Adding bed file to qdrant. bed_id: {bed_id}")
bed_embedding = self._embed_file(bed_file)

self._qdrant_engine.load(
ids=[bed_id],
vectors=bed_embedding,
payloads=[{**payload}],
)
return None

def _embed_file(self, bed_file: Union[str, RegionSet]) -> np.ndarray:
"""
Create embedding for bed file
:param bed_id: bed file id
:param bed_file: path to the bed file, or RegionSet object
:return: np array of embeddings
"""
if self._qdrant_engine is None:
raise QdrantInstanceNotInitializedError

if not self._config.r2v:
raise BedBaseConfError(
"Could not add add region to qdrant. Invalid type, or path. "
)

_LOGGER.debug(f"Adding bed file to qdrant. bed_id: {bed_id}")
if isinstance(bed_file, str):
bed_region_set = GRegionSet(bed_file)
elif isinstance(bed_file, RegionSet) or isinstance(bed_file, GRegionSet):
Expand All @@ -771,19 +789,9 @@ def upload_file_qdrant(
raise BedBaseConfError(
"Could not add add region to qdrant. Invalid type, or path. "
)
# Not really working
# bed_embedding = np.mean([self._config.r2v.encode(r) for r in bed_region_set], axis=0)

bed_embedding = np.mean(self._config.r2v.encode(bed_region_set), axis=0)

# Upload bed file vector to the database
vec_dim = bed_embedding.shape[0]
self._qdrant_engine.load(
ids=[bed_id],
vectors=bed_embedding.reshape(1, vec_dim),
payloads=[{**payload}],
)
return None
return bed_embedding.reshape(1, vec_dim)

def text_to_bed_search(
self, query: str, limit: int = 10, offset: int = 0
Expand Down Expand Up @@ -814,7 +822,7 @@ def text_to_bed_search(
if result_meta:
results_list.append(QdrantSearchResult(**result, metadata=result_meta))
return BedListSearchResult(
count=self.bb_agent.get_stats.bedfiles_number,
count=self.bb_agent.get_stats().bedfiles_number,
limit=limit,
offset=offset,
results=results_list,
Expand Down Expand Up @@ -842,7 +850,7 @@ def bed_to_bed_search(
if result_meta:
results_list.append(QdrantSearchResult(**result, metadata=result_meta))
return BedListSearchResult(
count=self.bb_agent.get_stats.bedfiles_number,
count=self.bb_agent.get_stats().bedfiles_number,
limit=limit,
offset=offset,
results=results_list,
Expand Down
6 changes: 3 additions & 3 deletions manual_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
import zarr
from dotenv import load_dotenv
from geniml.io import RegionSet
from genimtools.tokenizers import TreeTokenizer
from genimtools.utils import read_tokens_from_gtok
from gtars.tokenizers import TreeTokenizer
from gtars.utils import read_tokens_from_gtok

# from genimtools.tokenizers import RegionSet
# from gtars.tokenizers import RegionSet


load_dotenv()
Expand Down
4 changes: 1 addition & 3 deletions requirements/requirements-all.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
yacman >= 0.9.1
sqlalchemy >= 2.0.0
geniml >= 0.4.1
geniml >= 0.4.2
psycopg >= 3.1.15
colorlogs
pydantic >= 2.6.4
Expand All @@ -11,6 +11,4 @@ sqlalchemy_schemadisplay
zarr
pyyaml >= 6.0.1 # for s3fs because of the errors
s3fs >= 2024.3.1
# quick fix for gensim
scipy <= 1.11.0
pandas
2 changes: 1 addition & 1 deletion tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
@pytest.mark.skipif(SERVICE_UNAVAILABLE, reason="Database is not available")
def test_get_stats(bbagent_obj):
with ContextManagerDBTesting(config=bbagent_obj.config, add_data=True, bedset=True):
return_result = bbagent_obj.get_stats
return_result = bbagent_obj.get_stats()

assert return_result
assert return_result.bedfiles_number == 1
Expand Down

0 comments on commit 1a1865b

Please sign in to comment.