From 326f8f1c07d0b862a196e9f6a3915a75bc42a019 Mon Sep 17 00:00:00 2001
From: kanduric <chakri.co@gmail.com>
Date: Wed, 7 Feb 2024 14:12:38 +0100
Subject: [PATCH] - changes how positive labels are assigned
- changes how repertoire components are concatenated
- filters user-supplied signal for legal gene pairs

---
 .../RepComponentConcatenation.py              | 18 ++++--
 .../SignalComponentGeneration.py              |  4 +-
 simAIRR/util/utilities.py                     | 37 +++++++++++
 simAIRR/workflows/Workflows.py                |  9 ++-
 tests/util/test_utilities.py                  | 62 ++++++++++++++++++-
 tests/workflows/test_Workflows.py             | 15 ++++-
 6 files changed, 133 insertions(+), 12 deletions(-)

diff --git a/simAIRR/concatenate_repertoire_components/RepComponentConcatenation.py b/simAIRR/concatenate_repertoire_components/RepComponentConcatenation.py
index 44de477..56f9f09 100644
--- a/simAIRR/concatenate_repertoire_components/RepComponentConcatenation.py
+++ b/simAIRR/concatenate_repertoire_components/RepComponentConcatenation.py
@@ -4,12 +4,12 @@
 from multiprocessing import Pool
 import pandas as pd
 import numpy as np
-from simAIRR.util.utilities import makedir_if_not_exists
+from simAIRR.util.utilities import makedir_if_not_exists, concatenate_dataframes_with_replacement
 
 
 class RepComponentConcatenation:
     def __init__(self, components_type, super_path, n_threads, export_nt=None, n_sequences=None, annotate_signal=None,
-                 export_cdr3_aa=None):
+                 export_cdr3_aa=None, n_pos_repertoires=None):
         self.components_type = components_type
         self.super_path = str(super_path).rstrip('/')
         self.n_threads = n_threads
@@ -18,6 +18,7 @@ def __init__(self, components_type, super_path, n_threads, export_nt=None, n_seq
         self.n_sequences = n_sequences
         self.annotate_signal = annotate_signal
         self.export_cdr3_aa = export_cdr3_aa
+        self.n_pos_repertoires = n_pos_repertoires
 
     def _set_component_specific_paths(self):
         # super_path in case of "public_private" concatenation is baseline_repertoires_path and
@@ -54,7 +55,10 @@ def concatenate_repertoire_components(self, file_number):
             except (pd.errors.EmptyDataError, FileNotFoundError) as e:
                 continue
         try:
-            concatenated_df = pd.concat(dfs_list)
+            if self.components_type == "public_private":
+                concatenated_df = pd.concat(dfs_list)
+            else:
+                concatenated_df = concatenate_dataframes_with_replacement(dfs_list)
             if self.export_cdr3_aa is True:
                 concatenated_df['cdr3_aa'] = concatenated_df['junction_aa'].str[1:-1]
                 concatenated_df = concatenated_df.drop('junction_aa', axis=1)
@@ -83,11 +87,13 @@ def multi_concatenate_repertoire_components(self):
         proxy_subject_ids = [secrets.token_hex(16) for i in range(len(found_primary_reps))]
         proxy_primary_fns = [subject_id + ".tsv" for subject_id in proxy_subject_ids]
         self.proxy_primary_fns = dict(zip(primary_rep_fns, proxy_primary_fns))
-        secondary_rep_fns = [os.path.basename(rep) for rep in found_secondary_reps]
-        lab_pos = [True if rep in secondary_rep_fns else False for rep in primary_rep_fns]
+        rep_indices = [int(rep.split("_")[1].split(".")[0]) for rep in primary_rep_fns]
+        lab_pos = [True if i < self.n_pos_repertoires else False for i in rep_indices]
+        labels_mapping = dict(zip(primary_rep_fns, lab_pos))
         file_names = [self.proxy_primary_fns[rep] for rep in primary_rep_fns]
+        labels = [labels_mapping[rep] for rep in primary_rep_fns]
         subject_ids = [fn.split(".")[0] for fn in file_names]
-        metadata_dict = {'subject_id': subject_ids, 'filename': file_names, 'label_positive': lab_pos}
+        metadata_dict = {'subject_id': subject_ids, 'filename': file_names, 'label_positive': labels}
         metadata_df = pd.DataFrame.from_dict(metadata_dict)
         metadata_df.to_csv(os.path.join(self.super_path, "metadata.csv"))
         metadata_df.to_csv(os.path.join(self.concatenated_reps_path, "metadata.csv"))
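Note on the labeling change in multi_concatenate_repertoire_components above: a repertoire used to be labeled positive when a file of the same name was found among the signal (secondary) components; with this patch, repertoires whose numeric filename index is below n_pos_repertoires are labeled positive directly. A minimal sketch of the intended behavior, assuming the baseline repertoires follow the rep_<i>.tsv naming that the index parsing implies (the filenames below are hypothetical):

    # hypothetical filenames, mirroring the split()-based index parsing in the patch
    n_pos_repertoires = 2
    primary_rep_fns = ["rep_0.tsv", "rep_1.tsv", "rep_2.tsv", "rep_3.tsv"]
    rep_indices = [int(fn.split("_")[1].split(".")[0]) for fn in primary_rep_fns]
    labels = [i < n_pos_repertoires for i in rep_indices]
    assert labels == [True, True, False, False]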
diff --git a/simAIRR/expand_repertoire_components/SignalComponentGeneration.py b/simAIRR/expand_repertoire_components/SignalComponentGeneration.py
index f8fc96f..09dccc7 100644
--- a/simAIRR/expand_repertoire_components/SignalComponentGeneration.py
+++ b/simAIRR/expand_repertoire_components/SignalComponentGeneration.py
@@ -85,7 +85,7 @@ def _write_signal_components(self, implantable_seq_subset_indices, pgen_interval
             np.savetxt(os.path.join(self.signal_components_path, "implanted_sequences_frequencies_neg_reps.txt"),
                        neg_rep_num, fmt="%s")
             neg_rep_seq_presence_indices = ImplantationHelper.get_repertoire_sequence_presence_indices(
-                desired_num_repertoires=self.desired_num_repertoires, abs_num_of_reps_list=neg_rep_num)
+                desired_num_repertoires=self.n_neg_repertoires, abs_num_of_reps_list=neg_rep_num)
         else:
             neg_rep_seq_presence_indices = []
         seq_presence_indices = pos_rep_seq_presence_indices + neg_rep_seq_presence_indices
@@ -97,7 +97,7 @@
     def _determine_signal_sequence_combination(self, pgen_intervals_array, pool_size):
         if len(pgen_intervals_array) > pool_size:
             valid_seq_proportions = [1 - seq_proportion for seq_proportion in
-                                     [0, 0.05, 0.10, 0.20, 0.40, 0.60, 0.80, 0.90, 0.95] if
+                                     [0, 0.05, 0.10, 0.20, 0.30, 0.40, 0.50, 0.60, 0.70, 0.80, 0.90, 0.95] if
                                      seq_proportion * len(pgen_intervals_array) > pool_size]
             subset_seqs_total_implant_counts = {}
             subset_seq_indices = {}
diff --git a/simAIRR/util/utilities.py b/simAIRR/util/utilities.py
index 1e92b26..6f6f2f4 100644
--- a/simAIRR/util/utilities.py
+++ b/simAIRR/util/utilities.py
@@ -1,3 +1,4 @@
+import logging
 import os.path
 import subprocess
 
@@ -66,3 +67,39 @@ def count_lines(file_path):
         return line_count
     except (subprocess.CalledProcessError, FileNotFoundError):
         return 0
+
+
+def concatenate_dataframes_with_replacement(dfs_list):
+    df_a, df_b = dfs_list
+    for idx, row in df_b.iterrows():
+        match_idx = df_a[(df_a['v_call'] == row['v_call']) & (df_a['j_call'] == row['j_call'])].index
+        if not match_idx.empty:
+            df_a = df_a.drop(match_idx[0])
+    df = pd.concat([df_a, df_b], ignore_index=True)
+    return df
+
+
+def filter_legal_pairs(df, legal_pairs):
+    df['pairs'] = list(zip(df['v_gene'], df['j_gene']))
+    initial_n_rows = df.shape[0]
+    df = df[df['pairs'].isin(legal_pairs)]
+    df = df.drop(columns=['pairs'])
+    final_n_rows = df.shape[0]
+    logging.info('Number of sequences removed from the user-supplied signal because of lack of legal gene '
+                 'combinations: ' + str(initial_n_rows - final_n_rows))
+    return df
+
+def get_legal_vj_pairs(background_sequences_path):
+    df = pd.read_csv(background_sequences_path, sep="\t", header=0)
+    if df.shape[1] == 3:
+        df.insert(0, 'nt_seq', "NA")
+    df.columns = ['junction', 'junction_aa', 'v_call', 'j_call']
+    df['v_j_call'] = list(zip(df['v_call'], df['j_call']))
+    unique_combinations = df['v_j_call'].value_counts()
+    unique_combinations = unique_combinations.reset_index()
+    unique_combinations.columns = ['v_j_call', 'count']
+    unique_combinations['count'] = unique_combinations['count'].astype(int)
+    unique_combinations['percentage'] = unique_combinations['count'] / unique_combinations['count'].sum() * 100
+    filtered_combinations = unique_combinations[unique_combinations['percentage'] > 0.015]
+    legal_combinations = filtered_combinations['v_j_call'].to_list()
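Note on concatenate_dataframes_with_replacement above: for each sequence in the signal component, one baseline row with the same (v_call, j_call) combination is dropped before the two frames are concatenated, so the merged repertoire keeps the baseline's size rather than growing by the signal size. (Worth double-checking in review that utilities.py already imports pandas as pd outside the shown context, since the new helpers rely on it.) A small usage sketch, assuming the patch is applied; the gene and junction values are made up for illustration:

    import pandas as pd
    from simAIRR.util.utilities import concatenate_dataframes_with_replacement

    baseline = pd.DataFrame({'junction': ['b1', 'b2', 'b3'],
                             'junction_aa': ['ba1', 'ba2', 'ba3'],
                             'v_call': ['TRBV5-1', 'TRBV5-1', 'TRBV7-2'],
                             'j_call': ['TRBJ2-1', 'TRBJ2-1', 'TRBJ2-3']})
    signal = pd.DataFrame({'junction': ['s1'], 'junction_aa': ['sa1'],
                           'v_call': ['TRBV5-1'], 'j_call': ['TRBJ2-1']})

    merged = concatenate_dataframes_with_replacement([baseline, signal])
    # one baseline TRBV5-1/TRBJ2-1 row is replaced by the signal row,
    # so the output keeps the baseline's 3 rows instead of growing to 4
    assert merged.shape[0] == 3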
+    return legal_combinations
diff --git a/simAIRR/workflows/Workflows.py b/simAIRR/workflows/Workflows.py
index d532bde..54c4397 100644
--- a/simAIRR/workflows/Workflows.py
+++ b/simAIRR/workflows/Workflows.py
@@ -11,7 +11,7 @@
 from simAIRR.olga_compute_pgen.OlgaPgenComputation import OlgaPgenComputation
 from simAIRR.olga_compute_pgen.UniqueSequenceFilter import UniqueSequenceFilter
 from simAIRR.pgen_count_map.PgenCountMap import PgenCountMap
-from simAIRR.util.utilities import makedir_if_not_exists, sort_olga_seq_by_pgen
+from simAIRR.util.utilities import makedir_if_not_exists, sort_olga_seq_by_pgen, filter_legal_pairs, get_legal_vj_pairs
 
 
 class Workflows:
@@ -139,12 +139,17 @@ def _parse_and_validate_user_signal(self):
             self.export_nt = False
         if user_signal.iloc[:, 0].isnull().all():
             self.export_nt = False
+        user_signal.columns = ['nt_seq', 'aa_seq', 'v_gene', 'j_gene']
+        if self.background_sequences_path is not None:
+            legal_pairs = get_legal_vj_pairs(self.background_sequences_path)
+            user_signal = filter_legal_pairs(user_signal, legal_pairs)
         return user_signal
 
     def _simulated_repertoire_generation(self):
         rep_concat = RepComponentConcatenation(components_type="baseline_and_signal", super_path=self.output_path,
                                                n_threads=self.n_threads, export_nt=self.export_nt,
-                                               export_cdr3_aa=self.export_cdr3_aa, annotate_signal=self.annotate_signal)
+                                               export_cdr3_aa=self.export_cdr3_aa, annotate_signal=self.annotate_signal,
+                                               n_pos_repertoires=self.n_pos_repertoires)
         logging.info('Concatenating the signal component and baseline repertoire component')
         rep_concat.multi_concatenate_repertoire_components()
 
diff --git a/tests/util/test_utilities.py b/tests/util/test_utilities.py
index 2b044bc..3da9015 100644
--- a/tests/util/test_utilities.py
+++ b/tests/util/test_utilities.py
@@ -1,5 +1,9 @@
+import glob
 import os
-from simAIRR.util.utilities import count_lines
+import pandas as pd
+from simAIRR.util.utilities import count_lines, concatenate_dataframes_with_replacement, get_legal_vj_pairs, \
+    filter_legal_pairs
+from simAIRR.workflows.Workflows import Workflows
 
 
 def test_count_lines(tmp_path):
@@ -8,3 +12,59 @@ def test_count_lines(tmp_path):
     exit_code = os.system(command)
 
     assert count_lines(out_filename) == 1000
+
+def test_concatenate_dataframes_with_replacement():
+    df_a = pd.DataFrame({
+        'junction': ['j1', 'j2', 'j3', 'j4', 'j5'],
+        'junction_aa': ['aa1', 'aa2', 'aa3', 'aa4', 'aa5'],
+        'v_call': ['TRBV20-1', 'TRBV20-2', 'TRBV20-3', 'TRBV20-1', 'TRBV20-2'],
+        'j_call': ['TRBJ2-1', 'TRBJ2-2', 'TRBJ2-3', 'TRBJ2-1', 'TRBJ2-2']
+    })
+
+    df_b = pd.DataFrame({
+        'junction': ['j6', 'j7', 'j8'],
+        'junction_aa': ['aa5', 'aa6', 'aa7'],
+        'v_call': ['TRBV20-1', 'TRBV20-2', 'TRBV20-2'],
+        'j_call': ['TRBJ2-1', 'TRBJ2-2', 'TRBJ2-5']
+    })
+
+    df = concatenate_dataframes_with_replacement([df_a, df_b])
+    assert df.shape == (6, 4)
+    assert df['junction'].tolist() == ['j3', 'j4', 'j5', 'j6', 'j7', 'j8']
+
+def test_get_legal_vj_pairs(tmp_path):
+    user_config_dict = {'mode': 'baseline_repertoire_generation',
+                        'olga_model': 'humanTRB',
+                        'output_path': None,
+                        'n_repertoires': 1,
+                        'seed': 1234,
+                        'n_sequences': 10000,
+                        'n_threads': 1,
+                        'store_intermediate_files': True,
+                        'depth_variation': True}
+    out_path = tmp_path / "workflow_output"
+    user_config_dict['output_path'] = out_path
+    desired_workflow = Workflows(**user_config_dict)
+    desired_workflow.execute()
+    background_sequences = glob.glob(os.path.join(out_path, "simulated_repertoires", "*.tsv"))[0]
+    legal_pairs = get_legal_vj_pairs(background_sequences)
+    df = pd.read_csv(background_sequences, sep="\t", header=0)
+    if df.shape[1] == 3:
+        df.insert(0, 'nt_seq', "NA")
+    df.columns = ['junction', 'junction_aa', 'v_call', 'j_call']
+    df['v_j_call'] = list(zip(df['v_call'], df['j_call']))
+    unique_combinations = df['v_j_call'].value_counts()
+    assert len(legal_pairs) <= unique_combinations.shape[0]
+
+def test_filter_legal_pairs():
+    df = pd.DataFrame({
+        'junction': ['j1', 'j2', 'j3', 'j4', 'j5'],
+        'junction_aa': ['aa1', 'aa2', 'aa3', 'aa4', 'aa5'],
+        'v_gene': ['TRBV20-1', 'TRBV20-2', 'TRBV20-3', 'TRBV20-4', 'TRBV20-2'],
+        'j_gene': ['TRBJ2-1', 'TRBJ2-2', 'TRBJ2-3', 'TRBJ2-4', 'TRBJ2-2']
+    })
+
+    legal_pairs = [('TRBV20-1', 'TRBJ2-1'), ('TRBV20-2', 'TRBJ2-2')]
+    df = filter_legal_pairs(df, legal_pairs)
+    assert df.shape == (3, 4)
+    assert df['junction'].tolist() == ['j1', 'j2', 'j5']
\ No newline at end of file
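Note on the legal-pair filtering wired into _parse_and_validate_user_signal above: get_legal_vj_pairs derives, from the background sequences, the set of V-J gene pairs that each account for more than 0.015% of sequences, and filter_legal_pairs then drops user-supplied signal rows whose (v_gene, j_gene) pair falls outside that set, logging how many rows were removed. A minimal sketch, assuming the patch is applied; the gene names are illustrative only:

    import pandas as pd
    from simAIRR.util.utilities import filter_legal_pairs

    signal = pd.DataFrame({'junction': ['j1', 'j2'],
                           'junction_aa': ['aa1', 'aa2'],
                           'v_gene': ['TRBV20-1', 'TRBV9'],
                           'j_gene': ['TRBJ2-1', 'TRBJ1-1']})
    legal_pairs = [('TRBV20-1', 'TRBJ2-1')]

    filtered = filter_legal_pairs(signal, legal_pairs)
    assert filtered['junction'].tolist() == ['j1']  # the TRBV9/TRBJ1-1 row is removed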
diff --git a/tests/workflows/test_Workflows.py b/tests/workflows/test_Workflows.py
index d305db3..6f39355 100644
--- a/tests/workflows/test_Workflows.py
+++ b/tests/workflows/test_Workflows.py
@@ -69,6 +69,7 @@ def prepare_test_data_signal_implantation_workflow():
                         'signal_sequences_file': None,
                         'positive_label_rate': 0.5,
                         'phenotype_burden': 2,
+                        'noise_rate': 0.5,
                         'phenotype_pool_size': None,
                         'allow_closer_phenotype_burden': True,
                         'store_intermediate_files': True,
@@ -111,4 +112,16 @@ def test_workflow_generate_baseline_repertoires(tmp_path):
     user_config_dict['output_path'] = out_path
     desired_workflow = Workflows(**user_config_dict)
     desired_workflow.execute()
-    print(out_path)
\ No newline at end of file
+
+def test__parse_and_validate_user_signal(tmp_path):
+    out_path = tmp_path / "workflow_output"
+    print(out_path)
+    user_config_dict, signal_sequences = prepare_test_data_signal_implantation_workflow()
+    signal_file_path = os.path.join(tmp_path, 'signal_sequences.tsv')
+    signal_sequences = signal_sequences.drop(signal_sequences.columns[0], axis=1)
+    signal_sequences.to_csv(signal_file_path, index=None, header=None, sep='\t')
+    user_config_dict['signal_sequences_file'] = signal_file_path
+    user_config_dict['output_path'] = out_path
+    desired_workflow = Workflows(**user_config_dict)
+    user_signal = desired_workflow._parse_and_validate_user_signal()
+    print(user_signal)
\ No newline at end of file
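Once the patch is applied, the new and updated tests can be exercised with a standard pytest run, for example via pytest's Python API:

    import pytest
    pytest.main(["tests/util/test_utilities.py", "tests/workflows/test_Workflows.py"])

Note that test_get_legal_vj_pairs executes a full baseline-generation workflow (one repertoire of 10000 OLGA-generated sequences), so expect it to take noticeably longer than the pure in-memory tests.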