From 9b489ce823f3a6ae7cd2ba9505553af3c4018049 Mon Sep 17 00:00:00 2001 From: Khoroshevskyi Date: Mon, 5 Feb 2024 23:42:49 +0100 Subject: [PATCH] Fixed view update and errors --- pepdbagent/exceptions.py | 5 ++ pepdbagent/modules/view.py | 121 ++++++++++++++++++------------------- tests/test_pepagent.py | 4 ++ 3 files changed, 67 insertions(+), 63 deletions(-) diff --git a/pepdbagent/exceptions.py b/pepdbagent/exceptions.py index 9f81d65..4b90fac 100644 --- a/pepdbagent/exceptions.py +++ b/pepdbagent/exceptions.py @@ -71,6 +71,11 @@ def __init__(self, msg=""): super().__init__(f"""View does not exist. {msg}""") +class SampleNotInViewError(PEPDatabaseAgentError): + def __init__(self, msg=""): + super().__init__(f"""Sample is not in the view. {msg}""") + + class SampleAlreadyInView(PEPDatabaseAgentError): """ Sample is already in the view exception diff --git a/pepdbagent/modules/view.py b/pepdbagent/modules/view.py index a43bddd..ef4fc30 100644 --- a/pepdbagent/modules/view.py +++ b/pepdbagent/modules/view.py @@ -19,6 +19,7 @@ ProjectNotFoundError, SampleNotFoundError, ViewAlreadyExistsError, + SampleNotInViewError, ) from pepdbagent.db_utils import BaseEngine, Samples, Projects, Views, ViewSampleAssociation @@ -163,41 +164,40 @@ def create( Projects.tag == view_dict.project_tag, ) ) - - with Session(self._sa_engine) as sa_session: - project = sa_session.scalar(project_statement) - if not project: - raise ProjectNotFoundError( - f"Project {view_dict.project_namespace}/{view_dict.project_name}:{view_dict.project_tag} does not exist" - ) - view = Views( - name=view_name, - description=description, - project_mapping=project, - ) - sa_session.add(view) - - for sample_name in view_dict.sample_list: - sample_statement = select(Samples.id).where( - and_( - Samples.project_id == project.id, - Samples.sample_name == sample_name, + try: + with Session(self._sa_engine) as sa_session: + project = sa_session.scalar(project_statement) + if not project: + raise ProjectNotFoundError( + f"Project {view_dict.project_namespace}/{view_dict.project_name}:{view_dict.project_tag} does not exist" ) + view = Views( + name=view_name, + description=description, + project_mapping=project, ) - sample_id = sa_session.execute(sample_statement).one()[0] - if not sample_id: - raise SampleNotFoundError( - f"Sample {view_dict.project_namespace}/{view_dict.project_name}:{view_dict.project_tag}:{sample_name} does not exist" + sa_session.add(view) + + for sample_name in view_dict.sample_list: + sample_statement = select(Samples.id).where( + and_( + Samples.project_id == project.id, + Samples.sample_name == sample_name, + ) ) - try: - sa_session.add(ViewSampleAssociation(sample_id=sample_id, view=view)) + sample_id = sa_session.execute(sample_statement).one()[0] + if not sample_id: + raise SampleNotFoundError( + f"Sample {view_dict.project_namespace}/{view_dict.project_name}:{view_dict.project_tag}:{sample_name} does not exist" + ) - except IntegrityError: - raise ViewAlreadyExistsError( - f"View {view_name} of the project {view_dict.project_namespace}/{view_dict.project_name}:{view_dict.project_tag} already exists" - ) + sa_session.add(ViewSampleAssociation(sample_id=sample_id, view=view)) - sa_session.commit() + sa_session.commit() + except IntegrityError: + raise ViewAlreadyExistsError( + f"View {view_name} of the project {view_dict.project_namespace}/{view_dict.project_name}:{view_dict.project_tag} already exists" + ) def delete( self, @@ -265,34 +265,32 @@ def add_sample( Views.name == view_name, ) ) - - with Session(self._sa_engine) as sa_session: - view = sa_session.scalar(view_statement) - if not view: - raise ViewNotFoundError( - f"View {view_name} of the project {namespace}/{name}:{tag} does not exist" - ) - for sample_name_one in sample_name: - sample_statement = select(Samples).where( - and_( - Samples.project_id == view.project_mapping.id, - Samples.sample_name == sample_name_one, + try: + with Session(self._sa_engine) as sa_session: + view = sa_session.scalar(view_statement) + if not view: + raise ViewNotFoundError( + f"View {view_name} of the project {namespace}/{name}:{tag} does not exist" ) - ) - sample = sa_session.scalar(sample_statement) - if not sample: - raise SampleNotFoundError( - f"Sample {namespace}/{name}:{tag}:{sample_name} does not exist" + for sample_name_one in sample_name: + sample_statement = select(Samples).where( + and_( + Samples.project_id == view.project_mapping.id, + Samples.sample_name == sample_name_one, + ) ) - try: + sample = sa_session.scalar(sample_statement) + if not sample: + raise SampleNotFoundError( + f"Sample {namespace}/{name}:{tag}:{sample_name} does not exist" + ) + sa_session.add(ViewSampleAssociation(sample=sample, view=view)) sa_session.commit() - except IntegrityError: - raise SampleAlreadyInView( - f"Sample {namespace}/{name}:{tag}:{sample_name} already in view {view_name}" - ) - - return None + except IntegrityError: + raise SampleAlreadyInView( + f"Sample {namespace}/{name}:{tag}:{sample_name} already in view {view_name}" + ) def remove_sample( self, @@ -335,21 +333,18 @@ def remove_sample( ) ) sample = sa_session.scalar(sample_statement) + if sample.id not in [view_sample.sample_id for view_sample in view.samples]: + raise SampleNotInViewError( + f"Sample {namespace}/{name}:{tag}:{sample_name} does not exist in view {view_name}" + ) delete_statement = delete(ViewSampleAssociation).where( and_( ViewSampleAssociation.sample_id == sample.id, ViewSampleAssociation.view_id == view.id, ) ) - try: - sa_session.execute(delete_statement) - sa_session.commit() - except IntegrityError: - raise SampleNotFoundError( - f"Sample {namespace}/{name}:{tag}:{sample_name} does not exist in view {view_name}" - ) - - return None + sa_session.execute(delete_statement) + sa_session.commit() def get_snap_view( self, namespace: str, name: str, tag: str, sample_name_list: List[str], raw: bool = False diff --git a/tests/test_pepagent.py b/tests/test_pepagent.py index a7dedf6..d03ea14 100644 --- a/tests/test_pepagent.py +++ b/tests/test_pepagent.py @@ -15,6 +15,7 @@ SampleNotFoundError, ViewNotFoundError, SampleAlreadyInView, + SampleNotInViewError, ) from .conftest import DNS @@ -1189,6 +1190,9 @@ def test_remove_sample_from_view(self, initiate_pepdb_con, namespace, name, samp assert len(initiate_pepdb_con.view.get(namespace, name, "default", "view1").samples) == 1 assert len(initiate_pepdb_con.project.get(namespace, name).samples) == 4 + with pytest.raises(SampleNotInViewError): + initiate_pepdb_con.view.remove_sample(namespace, name, "default", "view1", sample_name) + @pytest.mark.parametrize( "namespace, name, sample_name", [