From 326f8f1c07d0b862a196e9f6a3915a75bc42a019 Mon Sep 17 00:00:00 2001
From: kanduric <chakri.co@gmail.com>
Date: Wed, 7 Feb 2024 14:12:38 +0100
Subject: [PATCH] Change how positive labels are assigned, change how
 repertoire components are concatenated, and filter the user-supplied signal
 for legal V/J gene pairs

---
 .../RepComponentConcatenation.py              | 18 ++++--
 .../SignalComponentGeneration.py              |  4 +-
 simAIRR/util/utilities.py                     | 44 ++++++++++++
 simAIRR/workflows/Workflows.py                | 10 ++-
 tests/util/test_utilities.py                  | 66 +++++++++++++++++-
 tests/workflows/test_Workflows.py             | 14 ++++-
 6 files changed, 145 insertions(+), 11 deletions(-)

diff --git a/simAIRR/concatenate_repertoire_components/RepComponentConcatenation.py b/simAIRR/concatenate_repertoire_components/RepComponentConcatenation.py
index 44de477..56f9f09 100644
--- a/simAIRR/concatenate_repertoire_components/RepComponentConcatenation.py
+++ b/simAIRR/concatenate_repertoire_components/RepComponentConcatenation.py
@@ -4,12 +4,12 @@
 from multiprocessing import Pool
 import pandas as pd
 import numpy as np
-from simAIRR.util.utilities import makedir_if_not_exists
+from simAIRR.util.utilities import makedir_if_not_exists, concatenate_dataframes_with_replacement
 
 
 class RepComponentConcatenation:
     def __init__(self, components_type, super_path, n_threads, export_nt=None, n_sequences=None, annotate_signal=None,
-                 export_cdr3_aa=None):
+                 export_cdr3_aa=None, n_pos_repertoires=None):
         self.components_type = components_type
         self.super_path = str(super_path).rstrip('/')
         self.n_threads = n_threads
@@ -18,6 +18,7 @@ def __init__(self, components_type, super_path, n_threads, export_nt=None, n_seq
         self.n_sequences = n_sequences
         self.annotate_signal = annotate_signal
         self.export_cdr3_aa = export_cdr3_aa
+        self.n_pos_repertoires = n_pos_repertoires
 
     def _set_component_specific_paths(self):
         # super_path in case of "public_private" concatenation is baseline_repertoires_path and
@@ -54,7 +55,12 @@ def concatenate_repertoire_components(self, file_number):
             except (pd.errors.EmptyDataError, FileNotFoundError) as e:
                 continue
         try:
-            concatenated_df = pd.concat(dfs_list)
+            if self.components_type == "public_private":
+                concatenated_df = pd.concat(dfs_list)
+            else:
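+                # for baseline + signal components, each signal sequence replaces one baseline
+                # sequence with the same V/J gene pair (if present) instead of being appended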
+                concatenated_df = concatenate_dataframes_with_replacement(dfs_list)
             if self.export_cdr3_aa is True:
                 concatenated_df['cdr3_aa'] = concatenated_df['junction_aa'].str[1:-1]
                 concatenated_df = concatenated_df.drop('junction_aa', axis=1)
@@ -83,11 +89,13 @@ def multi_concatenate_repertoire_components(self):
             proxy_subject_ids = [secrets.token_hex(16) for i in range(len(found_primary_reps))]
             proxy_primary_fns = [subject_id + ".tsv" for subject_id in proxy_subject_ids]
             self.proxy_primary_fns = dict(zip(primary_rep_fns, proxy_primary_fns))
-            secondary_rep_fns = [os.path.basename(rep) for rep in found_secondary_reps]
-            lab_pos = [True if rep in secondary_rep_fns else False for rep in primary_rep_fns]
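+            # repertoire filenames are expected to follow the "rep_<idx>.tsv" pattern;
+            # the repertoires with idx < n_pos_repertoires are assigned the positive label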
+            rep_indices = [int(rep.split("_")[1].split(".")[0]) for rep in primary_rep_fns]
+            lab_pos = [idx < self.n_pos_repertoires for idx in rep_indices]
             file_names = [self.proxy_primary_fns[rep] for rep in primary_rep_fns]
             subject_ids = [fn.split(".")[0] for fn in file_names]
             metadata_dict = {'subject_id': subject_ids, 'filename': file_names, 'label_positive': lab_pos}
             metadata_df = pd.DataFrame.from_dict(metadata_dict)
             metadata_df.to_csv(os.path.join(self.super_path, "metadata.csv"))
             metadata_df.to_csv(os.path.join(self.concatenated_reps_path, "metadata.csv"))
diff --git a/simAIRR/expand_repertoire_components/SignalComponentGeneration.py b/simAIRR/expand_repertoire_components/SignalComponentGeneration.py
index f8fc96f..09dccc7 100644
--- a/simAIRR/expand_repertoire_components/SignalComponentGeneration.py
+++ b/simAIRR/expand_repertoire_components/SignalComponentGeneration.py
@@ -85,7 +85,7 @@ def _write_signal_components(self, implantable_seq_subset_indices, pgen_interval
             np.savetxt(os.path.join(self.signal_components_path, "implanted_sequences_frequencies_neg_reps.txt"),
                        neg_rep_num, fmt="%s")
             neg_rep_seq_presence_indices = ImplantationHelper.get_repertoire_sequence_presence_indices(
-                desired_num_repertoires=self.desired_num_repertoires, abs_num_of_reps_list=neg_rep_num)
+                desired_num_repertoires=self.n_neg_repertoires, abs_num_of_reps_list=neg_rep_num)
         else:
             neg_rep_seq_presence_indices = []
         seq_presence_indices = pos_rep_seq_presence_indices + neg_rep_seq_presence_indices
@@ -97,7 +97,7 @@ def _write_signal_components(self, implantable_seq_subset_indices, pgen_interval
     def _determine_signal_sequence_combination(self, pgen_intervals_array, pool_size):
         if len(pgen_intervals_array) > pool_size:
             valid_seq_proportions = [1 - seq_proportion for seq_proportion in
-                                     [0, 0.05, 0.10, 0.20, 0.40, 0.60, 0.80, 0.90, 0.95] if
+                                     [0, 0.05, 0.10, 0.20, 0.30, 0.40, 0.50, 0.60, 0.70, 0.80, 0.90, 0.95] if
                                      seq_proportion * len(pgen_intervals_array) > pool_size]
             subset_seqs_total_implant_counts = {}
             subset_seq_indices = {}
diff --git a/simAIRR/util/utilities.py b/simAIRR/util/utilities.py
index 1e92b26..6f6f2f4 100644
--- a/simAIRR/util/utilities.py
+++ b/simAIRR/util/utilities.py
@@ -1,3 +1,4 @@
+import logging
 import os.path
 import subprocess
 
@@ -66,3 +67,46 @@ def count_lines(file_path):
         return line_count
     except (subprocess.CalledProcessError, FileNotFoundError):
         return 0
+
+
+def concatenate_dataframes_with_replacement(dfs_list):
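+    """Concatenate two repertoire component dataframes: for each row of the second
+    dataframe, drop one row of the first dataframe with the same v_call/j_call
+    combination (if any) before concatenating."""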
+    df_a, df_b = dfs_list
+    for _, row in df_b.iterrows():
+        match_idx = df_a[(df_a['v_call'] == row['v_call']) & (df_a['j_call'] == row['j_call'])].index
+        if not match_idx.empty:
+            df_a = df_a.drop(match_idx[0])
+    df = pd.concat([df_a, df_b], ignore_index=True)
+    return df
+
+
+def filter_legal_pairs(df, legal_pairs):
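+    """Keep only the rows of df whose (v_gene, j_gene) combination occurs in legal_pairs."""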
+    df['pairs'] = list(zip(df['v_gene'], df['j_gene']))
+    initial_n_rows = df.shape[0]
+    df = df[df['pairs'].isin(legal_pairs)]
+    df = df.drop(columns=['pairs'])
+    final_n_rows = df.shape[0]
+    logging.info('Number of sequences removed from the user-supplied signal because their '
+                 f'V/J gene combination is not among the legal pairs: {initial_n_rows - final_n_rows}')
+    return df
+
+
+def get_legal_vj_pairs(background_sequences_path):
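+    """Return the V/J gene combinations that occur at a frequency above 0.015% among
+    the background sequences."""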
+    df = pd.read_csv(background_sequences_path, sep="\t", header=0)
+    if df.shape[1] == 3:
+        df.insert(0, 'nt_seq', "NA")
+    df.columns = ['junction', 'junction_aa', 'v_call', 'j_call']
+    df['v_j_call'] = list(zip(df['v_call'], df['j_call']))
+    unique_combinations = df['v_j_call'].value_counts()
+    unique_combinations = unique_combinations.reset_index()
+    unique_combinations.columns = ['v_j_call', 'count']
+    unique_combinations['count'] = unique_combinations['count'].astype(int)
+    unique_combinations['percentage'] = unique_combinations['count'] / unique_combinations['count'].sum() * 100
+    filtered_combinations = unique_combinations[unique_combinations['percentage'] > 0.015]
+    legal_combinations = filtered_combinations['v_j_call'].to_list()
+    return legal_combinations
diff --git a/simAIRR/workflows/Workflows.py b/simAIRR/workflows/Workflows.py
index d532bde..54c4397 100644
--- a/simAIRR/workflows/Workflows.py
+++ b/simAIRR/workflows/Workflows.py
@@ -11,7 +11,7 @@
 from simAIRR.olga_compute_pgen.OlgaPgenComputation import OlgaPgenComputation
 from simAIRR.olga_compute_pgen.UniqueSequenceFilter import UniqueSequenceFilter
 from simAIRR.pgen_count_map.PgenCountMap import PgenCountMap
-from simAIRR.util.utilities import makedir_if_not_exists, sort_olga_seq_by_pgen
+from simAIRR.util.utilities import makedir_if_not_exists, sort_olga_seq_by_pgen, filter_legal_pairs, get_legal_vj_pairs
 
 
 class Workflows:
@@ -139,12 +139,18 @@ def _parse_and_validate_user_signal(self):
                     self.export_nt = False
             if user_signal.iloc[:, 0].isnull().all():
                 self.export_nt = False
+        user_signal.columns = ['nt_seq', 'aa_seq', 'v_gene', 'j_gene']
+        if self.background_sequences_path is not None:
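+            # restrict the user-supplied signal to V/J gene pairs observed in the background sequences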
+            legal_pairs = get_legal_vj_pairs(self.background_sequences_path)
+            user_signal = filter_legal_pairs(user_signal, legal_pairs)
         return user_signal
 
     def _simulated_repertoire_generation(self):
         rep_concat = RepComponentConcatenation(components_type="baseline_and_signal", super_path=self.output_path,
                                                n_threads=self.n_threads, export_nt=self.export_nt,
-                                               export_cdr3_aa=self.export_cdr3_aa, annotate_signal=self.annotate_signal)
+                                               export_cdr3_aa=self.export_cdr3_aa, annotate_signal=self.annotate_signal,
+                                               n_pos_repertoires=self.n_pos_repertoires)
         logging.info('Concatenating the signal component and baseline repertoire component')
         rep_concat.multi_concatenate_repertoire_components()
 
diff --git a/tests/util/test_utilities.py b/tests/util/test_utilities.py
index 2b044bc..3da9015 100644
--- a/tests/util/test_utilities.py
+++ b/tests/util/test_utilities.py
@@ -1,5 +1,9 @@
+import glob
 import os
-from simAIRR.util.utilities import count_lines
+import pandas as pd
+from simAIRR.util.utilities import count_lines, concatenate_dataframes_with_replacement, get_legal_vj_pairs, \
+    filter_legal_pairs
+from simAIRR.workflows.Workflows import Workflows
 
 
 def test_count_lines(tmp_path):
@@ -8,3 +12,63 @@ def test_count_lines(tmp_path):
     exit_code = os.system(command)
     assert count_lines(out_filename) == 1000
 
+
+def test_concatenate_dataframes_with_replacement():
+    df_a = pd.DataFrame({
+        'junction': ['j1', 'j2', 'j3', 'j4', 'j5'],
+        'junction_aa': ['aa1', 'aa2', 'aa3', 'aa4', 'aa5'],
+        'v_call': ['TRBV20-1', 'TRBV20-2', 'TRBV20-3', 'TRBV20-1', 'TRBV20-2'],
+        'j_call': ['TRBJ2-1', 'TRBJ2-2', 'TRBJ2-3', 'TRBJ2-1', 'TRBJ2-2']
+    })
+
+    df_b = pd.DataFrame({
+        'junction': ['j6', 'j7', 'j8'],
+        'junction_aa': ['aa5', 'aa6', 'aa7'],
+        'v_call': ['TRBV20-1', 'TRBV20-2', 'TRBV20-2'],
+        'j_call': ['TRBJ2-1', 'TRBJ2-2', 'TRBJ2-5']
+    })
+
+    df = concatenate_dataframes_with_replacement([df_a, df_b])
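+    # j6 and j7 replace the first baseline rows with matching V/J pairs (j1 and j2);
+    # j8 (TRBV20-2/TRBJ2-5) has no match in df_a and is simply appended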
+    assert df.shape == (6, 4)
+    assert df['junction'].tolist() == ['j3', 'j4', 'j5', 'j6', 'j7', 'j8']
+
+def test_get_legal_vj_pairs(tmp_path):
+    user_config_dict = {'mode': 'baseline_repertoire_generation',
+                        'olga_model': 'humanTRB',
+                        'output_path': None,
+                        'n_repertoires': 1,
+                        'seed': 1234,
+                        'n_sequences': 10000,
+                        'n_threads': 1,
+                        'store_intermediate_files': True,
+                        'depth_variation': True}
+    out_path = tmp_path / "workflow_output"
+    user_config_dict['output_path'] = out_path
+    desired_workflow = Workflows(**user_config_dict)
+    desired_workflow.execute()
+    background_sequences = glob.glob(os.path.join(out_path, "simulated_repertoires", "*.tsv"))[0]
+    legal_pairs = get_legal_vj_pairs(background_sequences)
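+    # recompute all unique V/J combinations from the raw file; the thresholded legal set
+    # can at most cover all of them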
+    df = pd.read_csv(background_sequences, sep="\t", header=0)
+    if df.shape[1] == 3:
+        df.insert(0, 'nt_seq', "NA")
+    df.columns = ['junction', 'junction_aa', 'v_call', 'j_call']
+    df['v_j_call'] = list(zip(df['v_call'], df['j_call']))
+    unique_combinations = df['v_j_call'].value_counts()
+    assert len(legal_pairs) <= unique_combinations.shape[0]
+
+def test_filter_legal_pairs():
+    df = pd.DataFrame({
+        'junction': ['j1', 'j2', 'j3', 'j4', 'j5'],
+        'junction_aa': ['aa1', 'aa2', 'aa3', 'aa4', 'aa5'],
+        'v_gene': ['TRBV20-1', 'TRBV20-2', 'TRBV20-3', 'TRBV20-4', 'TRBV20-2'],
+        'j_gene': ['TRBJ2-1', 'TRBJ2-2', 'TRBJ2-3', 'TRBJ2-4', 'TRBJ2-2']
+    })
+
+    legal_pairs = [('TRBV20-1', 'TRBJ2-1'), ('TRBV20-2', 'TRBJ2-2')]
+    df = filter_legal_pairs(df, legal_pairs)
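+    # only j1, j2 and j5 have a (v_gene, j_gene) combination listed in legal_pairs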
+    assert df.shape == (3, 4)
+    assert df['junction'].tolist() == ['j1', 'j2', 'j5']
diff --git a/tests/workflows/test_Workflows.py b/tests/workflows/test_Workflows.py
index d305db3..6f39355 100644
--- a/tests/workflows/test_Workflows.py
+++ b/tests/workflows/test_Workflows.py
@@ -69,6 +69,7 @@ def prepare_test_data_signal_implantation_workflow():
                         'signal_sequences_file': None,
                         'positive_label_rate': 0.5,
                         'phenotype_burden': 2,
+                        'noise_rate': 0.5,
                         'phenotype_pool_size': None,
                         'allow_closer_phenotype_burden': True,
                         'store_intermediate_files': True,
@@ -111,4 +112,15 @@ def test_workflow_generate_baseline_repertoires(tmp_path):
     user_config_dict['output_path'] = out_path
     desired_workflow = Workflows(**user_config_dict)
     desired_workflow.execute()
-    print(out_path)
\ No newline at end of file
+
+def test__parse_and_validate_user_signal(tmp_path):
+    out_path = tmp_path / "workflow_output"
+    user_config_dict, signal_sequences = prepare_test_data_signal_implantation_workflow()
+    signal_file_path = os.path.join(tmp_path, 'signal_sequences.tsv')
+    signal_sequences = signal_sequences.drop(signal_sequences.columns[0], axis=1)
+    signal_sequences.to_csv(signal_file_path, index=None, header=None, sep='\t')
+    user_config_dict['signal_sequences_file'] = signal_file_path
+    user_config_dict['output_path'] = out_path
+    desired_workflow = Workflows(**user_config_dict)
+    user_signal = desired_workflow._parse_and_validate_user_signal()
+    assert list(user_signal.columns) == ['nt_seq', 'aa_seq', 'v_gene', 'j_gene']