From 7e3b041bb87ec1065060088c98b94629a7fc0e65 Mon Sep 17 00:00:00 2001
From: Khoroshevskyi <sasha99250@gmail.com>
Date: Fri, 21 Jun 2024 14:00:37 -0400
Subject: [PATCH] Fix sample add: assign guid and parent_guid to new samples

---
 pepdbagent/db_utils.py        |  2 +-
 pepdbagent/modules/project.py | 11 ++++----
 pepdbagent/modules/sample.py  | 51 +++++++++++++++++++----------------
 pepdbagent/utils.py           |  5 ++++
 tests/test_pepagent.py        |  8 +++---
 5 files changed, 43 insertions(+), 34 deletions(-)

diff --git a/pepdbagent/db_utils.py b/pepdbagent/db_utils.py
index 7dc7a3a..fd5c7cd 100644
--- a/pepdbagent/db_utils.py
+++ b/pepdbagent/db_utils.py
@@ -139,7 +139,7 @@ class Samples(Base):
     sample_name: Mapped[Optional[str]] = mapped_column()
     guid: Mapped[Optional[str]] = mapped_column(nullable=False, unique=True)
 
-    parent_guid: Mapped[Optional[int]] = mapped_column(
+    parent_guid: Mapped[Optional[str]] = mapped_column(
         ForeignKey("samples.guid", ondelete="CASCADE"),
         nullable=True,
         doc="Parent sample id. Used to create a hierarchy of samples.",
diff --git a/pepdbagent/modules/project.py b/pepdbagent/modules/project.py
index 1675068..c9697f8 100644
--- a/pepdbagent/modules/project.py
+++ b/pepdbagent/modules/project.py
@@ -2,7 +2,6 @@
 import json
 import logging
 from typing import Union, List, NoReturn, Dict
-import uuid
 
 import peppy
 from sqlalchemy import and_, delete, select
@@ -36,7 +35,7 @@
     SampleTableUpdateError,
 )
 from pepdbagent.models import UpdateItems, UpdateModel, ProjectDict
-from pepdbagent.utils import create_digest, registry_path_converter, order_samples
+from pepdbagent.utils import create_digest, registry_path_converter, order_samples, generate_guid
 
 
 _LOGGER = logging.getLogger(PKG_NAME)
@@ -547,7 +546,7 @@ def update(
                             f"Please provide it to update samples, or use overwrite method."
                         )
 
-                    self._update_samples_with_ids(
+                    self._update_samples(
                         project_id=found_prj.id,
                         samples_list=update_dict["samples"],
                         sample_name_key=update_dict["config"].get(
@@ -574,7 +573,7 @@ def update(
         else:
             raise ProjectNotFoundError("No items will be updated!")
 
-    def _update_samples_with_ids(
+    def _update_samples(
         self,
         project_id: int,
         samples_list: List[Dict[str, str]],
@@ -600,7 +599,7 @@ def _update_samples_with_ids(
                 new_sample[PEPHUB_SAMPLE_ID_KEY] for new_sample in samples_list
             }
             new_samples_dict: dict = {
-                new_sample[PEPHUB_SAMPLE_ID_KEY] or str(uuid.uuid4()): new_sample
+                new_sample[PEPHUB_SAMPLE_ID_KEY] or generate_guid(): new_sample
                 for new_sample in samples_list
             }
 
@@ -887,7 +886,7 @@ def _add_samples_to_project(
                 row_number=row_number,
                 sample_name=sample.get(sample_table_index),
                 parent_guid=previous_sample_guid,
-                guid=str(uuid.uuid4()),
+                guid=generate_guid(),
             )
             projects_sa.samples_mapping.append(sample)
             previous_sample_guid = sample.guid
diff --git a/pepdbagent/modules/sample.py b/pepdbagent/modules/sample.py
index 5106fed..e2677cc 100644
--- a/pepdbagent/modules/sample.py
+++ b/pepdbagent/modules/sample.py
@@ -4,11 +4,10 @@
 
 import peppy
 from peppy.const import SAMPLE_TABLE_INDEX_KEY
-from sqlalchemy import select, and_, func
+from sqlalchemy import select, and_
 from sqlalchemy.orm import Session
 from sqlalchemy.orm.attributes import flag_modified
 
-
 from pepdbagent.const import (
     DEFAULT_TAG,
     PKG_NAME,
@@ -16,6 +15,7 @@
 from pepdbagent.exceptions import SampleNotFoundError, SampleAlreadyExistsError
 
 from pepdbagent.db_utils import BaseEngine, Samples, Projects
+from pepdbagent.utils import generate_guid, order_samples
 
 _LOGGER = logging.getLogger(PKG_NAME)
 
@@ -216,29 +216,10 @@ def add(
                 raise KeyError(
                     f"Sample index key {project_mapping.config.get(SAMPLE_TABLE_INDEX_KEY, 'sample_name')} not found in sample dict"
                 )
-            project_where_statement = (
-                Samples.project_id
-                == select(Projects.id)
-                .where(
-                    and_(
-                        Projects.namespace == namespace,
-                        Projects.name == name,
-                        Projects.tag == tag,
-                    ),
-                )
-                .scalar_subquery()
-            )
             statement = select(Samples).where(
-                and_(project_where_statement, Samples.sample_name == sample_name)
+                and_(Samples.project_id == project_mapping.id, Samples.sample_name == sample_name)
             )
-
             sample_mapping = session.scalar(statement)
-            row_number = (
-                session.execute(
-                    select(func.max(Samples.row_number)).where(project_where_statement)
-                ).one()[0]
-                or 0
-            )
 
             if sample_mapping and not overwrite:
                 raise SampleAlreadyExistsError(
@@ -257,9 +238,11 @@ def add(
             else:
                 sample_mapping = Samples(
                     sample=sample_dict,
-                    row_number=row_number + 1,
+                    row_number=0,
                     project_id=project_mapping.id,
                     sample_name=sample_name,
+                    guid=generate_guid(),
+                    parent_guid=self._get_last_sample_guid(project_mapping.id),
                 )
                 project_mapping.number_of_samples += 1
                 project_mapping.last_update_date = datetime.datetime.now(datetime.timezone.utc)
@@ -267,6 +250,28 @@ def add(
                 session.add(sample_mapping)
                 session.commit()
 
+    def _get_last_sample_guid(self, project_id: int) -> "str | None":
+        """
+        Get the guid of the last sample in the project's parent_guid chain
+
+        :param project_id: project_id of the project
+        :return: guid of the last sample, or None if the project has no samples
+        """
+        statement = select(Samples).where(Samples.project_id == project_id)
+        with Session(self._sa_engine) as session:
+            # Same per-guid record shape as before, keyed for order_samples.
+            result_dict = {
+                sample.guid: {
+                    "sample": sample.sample,
+                    "guid": sample.guid,
+                    "parent_guid": sample.parent_guid,
+                }
+                for sample in session.scalars(statement)
+            }
+        if not result_dict:  # empty project: the new sample has no parent
+            return None
+        return order_samples(result_dict)[-1]["guid"]
+
     def delete(
         self,
         namespace: str,
diff --git a/pepdbagent/utils.py b/pepdbagent/utils.py
index 7c1d32b..4bb7103 100644
--- a/pepdbagent/utils.py
+++ b/pepdbagent/utils.py
@@ -3,6 +3,7 @@
 from collections.abc import Iterable
 from hashlib import md5
 from typing import Tuple, Union, List
+import uuid
 
 import ubiquerg
 from peppy.const import SAMPLE_RAW_DICT_KEY
@@ -150,3 +151,7 @@ def order_samples(results: dict) -> List[dict]:
         else:
             current = None
     return ordered_sequence
+
+
+def generate_guid() -> str:  # returns a new random (version 4) UUID as a string
+    return str(uuid.uuid4())
diff --git a/tests/test_pepagent.py b/tests/test_pepagent.py
index 4d4b0b3..f1f9967 100644
--- a/tests/test_pepagent.py
+++ b/tests/test_pepagent.py
@@ -1103,7 +1103,7 @@ def test_update(self, initiate_pepdb_con, namespace, name, sample_name):
             sample_name=sample_name,
             update_dict={"organism": "butterfly"},
         )
-        one_sample = initiate_pepdb_con.sample.get(namespace, name, sample_name)
+        one_sample = initiate_pepdb_con.sample.get(namespace, name, sample_name, raw=False)
         assert one_sample.organism == "butterfly"
 
     @pytest.mark.parametrize(
@@ -1120,7 +1120,7 @@ def test_update_sample_name(self, initiate_pepdb_con, namespace, name, sample_na
             sample_name=sample_name,
             update_dict={"sample_name": "butterfly"},
         )
-        one_sample = initiate_pepdb_con.sample.get(namespace, name, "butterfly")
+        one_sample = initiate_pepdb_con.sample.get(namespace, name, "butterfly", raw=False)
         assert one_sample.sample_name == "butterfly"
 
     @pytest.mark.parametrize(
@@ -1212,10 +1212,10 @@ def test_delete_sample(self, initiate_pepdb_con, namespace, name, sample_name):
         ],
     )
     def test_add_sample(self, initiate_pepdb_con, namespace, name, tag, sample_dict):
-        prj = initiate_pepdb_con.project.get(namespace, name)
+        prj = initiate_pepdb_con.project.get(namespace, name, raw=False)
         initiate_pepdb_con.sample.add(namespace, name, tag, sample_dict)
 
-        prj2 = initiate_pepdb_con.project.get(namespace, name)
+        prj2 = initiate_pepdb_con.project.get(namespace, name, raw=False)
 
         assert len(prj.samples) + 1 == len(prj2.samples)
         assert prj2.samples[-1].sample_name == sample_dict["sample_name"]