diff --git a/bbconf/config_parser/bedbaseconfig.py b/bbconf/config_parser/bedbaseconfig.py index 3eeb875..66a3c1e 100644 --- a/bbconf/config_parser/bedbaseconfig.py +++ b/bbconf/config_parser/bedbaseconfig.py @@ -10,7 +10,7 @@ import yacman import zarr from botocore.exceptions import BotoCoreError, EndpointConnectionError -from geniml.region2vec import Region2VecExModel +from geniml.region2vec.main import Region2VecExModel from geniml.search import BED2BEDSearchInterface, QdrantBackend, Text2BEDSearchInterface from geniml.search.query2vec import BED2Vec, Text2Vec from pephubclient import PEPHubClient diff --git a/bbconf/modules/bedfiles.py b/bbconf/modules/bedfiles.py index c85d17a..a556d4a 100644 --- a/bbconf/modules/bedfiles.py +++ b/bbconf/modules/bedfiles.py @@ -6,7 +6,7 @@ import numpy as np from geniml.bbclient import BBClient from geniml.io import RegionSet -from genimtools.tokenizers import RegionSet as GRegionSet +from gtars.tokenizers import RegionSet as GRegionSet from pephubclient.exceptions import ResponseError from qdrant_client.models import Distance, PointIdsList, VectorParams from sqlalchemy import and_, delete, func, select @@ -754,15 +754,33 @@ def upload_file_qdrant( :param payload: additional metadata to store alongside vectors :return: None """ + + _LOGGER.debug(f"Adding bed file to qdrant. bed_id: {bed_id}") + bed_embedding = self._embed_file(bed_file) + + self._qdrant_engine.load( + ids=[bed_id], + vectors=bed_embedding, + payloads=[{**payload}], + ) + return None + + def _embed_file(self, bed_file: Union[str, RegionSet]) -> np.ndarray: + """ + Create embeding for bed file + + :param bed_id: bed file id + :param bed_file: path to the bed file, or RegionSet object + + :return np array of embeddings + """ if self._qdrant_engine is None: raise QdrantInstanceNotInitializedError - if not self._config.r2v: raise BedBaseConfError( "Could not add add region to qdrant. Invalid type, or path. " ) - _LOGGER.debug(f"Adding bed file to qdrant. bed_id: {bed_id}") if isinstance(bed_file, str): bed_region_set = GRegionSet(bed_file) elif isinstance(bed_file, RegionSet) or isinstance(bed_file, GRegionSet): @@ -771,19 +789,9 @@ def upload_file_qdrant( raise BedBaseConfError( "Could not add add region to qdrant. Invalid type, or path. " ) - # Not really working - # bed_embedding = np.mean([self._config.r2v.encode(r) for r in bed_region_set], axis=0) - bed_embedding = np.mean(self._config.r2v.encode(bed_region_set), axis=0) - - # Upload bed file vector to the database vec_dim = bed_embedding.shape[0] - self._qdrant_engine.load( - ids=[bed_id], - vectors=bed_embedding.reshape(1, vec_dim), - payloads=[{**payload}], - ) - return None + return bed_embedding.reshape(1, vec_dim) def text_to_bed_search( self, query: str, limit: int = 10, offset: int = 0 @@ -814,7 +822,7 @@ def text_to_bed_search( if result_meta: results_list.append(QdrantSearchResult(**result, metadata=result_meta)) return BedListSearchResult( - count=self.bb_agent.get_stats.bedfiles_number, + count=self.bb_agent.get_stats().bedfiles_number, limit=limit, offset=offset, results=results_list, @@ -842,7 +850,7 @@ def bed_to_bed_search( if result_meta: results_list.append(QdrantSearchResult(**result, metadata=result_meta)) return BedListSearchResult( - count=self.bb_agent.get_stats.bedfiles_number, + count=self.bb_agent.get_stats().bedfiles_number, limit=limit, offset=offset, results=results_list, diff --git a/manual_testing.py b/manual_testing.py index b71e15b..cb24b56 100644 --- a/manual_testing.py +++ b/manual_testing.py @@ -4,10 +4,10 @@ import zarr from dotenv import load_dotenv from geniml.io import RegionSet -from genimtools.tokenizers import TreeTokenizer -from genimtools.utils import read_tokens_from_gtok +from gtars.tokenizers import TreeTokenizer +from gtars.utils import read_tokens_from_gtok -# from genimtools.tokenizers import RegionSet +# from gtars.tokenizers import RegionSet load_dotenv() diff --git a/requirements/requirements-all.txt b/requirements/requirements-all.txt index 538c13e..7851ff1 100644 --- a/requirements/requirements-all.txt +++ b/requirements/requirements-all.txt @@ -1,6 +1,6 @@ yacman >= 0.9.1 sqlalchemy >= 2.0.0 -geniml >= 0.4.1 +geniml >= 0.4.2 psycopg >= 3.1.15 colorlogs pydantic >= 2.6.4 @@ -11,6 +11,4 @@ sqlalchemy_schemadisplay zarr pyyaml >= 6.0.1 # for s3fs because of the errors s3fs >= 2024.3.1 -# quick fix for gesnim -scipy <= 1.11.0 pandas \ No newline at end of file diff --git a/tests/test_common.py b/tests/test_common.py index 5b9048d..45ad8af 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -9,7 +9,7 @@ @pytest.mark.skipif(SERVICE_UNAVAILABLE, reason="Database is not available") def test_get_stats(bbagent_obj): with ContextManagerDBTesting(config=bbagent_obj.config, add_data=True, bedset=True): - return_result = bbagent_obj.get_stats + return_result = bbagent_obj.get_stats() assert return_result assert return_result.bedfiles_number == 1