Commit 326f8f1

- changes how positive labels are assigned
- changes how repertoire components are concatenated
- filters user-supplied signal for legal gene pairs

KanduriC committed Feb 7, 2024
1 parent 2817a52, commit 326f8f1
Showing 6 changed files with 133 additions and 12 deletions.
@@ -4,12 +4,12 @@
 from multiprocessing import Pool
 import pandas as pd
 import numpy as np
-from simAIRR.util.utilities import makedir_if_not_exists
+from simAIRR.util.utilities import makedir_if_not_exists, concatenate_dataframes_with_replacement


 class RepComponentConcatenation:
     def __init__(self, components_type, super_path, n_threads, export_nt=None, n_sequences=None, annotate_signal=None,
-                 export_cdr3_aa=None):
+                 export_cdr3_aa=None, n_pos_repertoires=None):
         self.components_type = components_type
         self.super_path = str(super_path).rstrip('/')
         self.n_threads = n_threads
@@ -18,6 +18,7 @@ def __init__(self, components_type, super_path, n_threads, export_nt=None, n_seq
         self.n_sequences = n_sequences
         self.annotate_signal = annotate_signal
         self.export_cdr3_aa = export_cdr3_aa
+        self.n_pos_repertoires = n_pos_repertoires

     def _set_component_specific_paths(self):
         # super_path in case of "public_private" concatenation is baseline_repertoires_path and
@@ -54,7 +55,10 @@ def concatenate_repertoire_components(self, file_number):
             except (pd.errors.EmptyDataError, FileNotFoundError) as e:
                 continue
         try:
-            concatenated_df = pd.concat(dfs_list)
+            if self.components_type == "public_private":
+                concatenated_df = pd.concat(dfs_list)
+            else:
+                concatenated_df = concatenate_dataframes_with_replacement(dfs_list)
             if self.export_cdr3_aa is True:
                 concatenated_df['cdr3_aa'] = concatenated_df['junction_aa'].str[1:-1]
                 concatenated_df = concatenated_df.drop('junction_aa', axis=1)
@@ -83,11 +87,13 @@ def multi_concatenate_repertoire_components(self):
         proxy_subject_ids = [secrets.token_hex(16) for i in range(len(found_primary_reps))]
         proxy_primary_fns = [subject_id + ".tsv" for subject_id in proxy_subject_ids]
         self.proxy_primary_fns = dict(zip(primary_rep_fns, proxy_primary_fns))
-        secondary_rep_fns = [os.path.basename(rep) for rep in found_secondary_reps]
-        lab_pos = [True if rep in secondary_rep_fns else False for rep in primary_rep_fns]
+        rep_indices = [int(rep.split("_")[1].split(".")[0]) for rep in primary_rep_fns]
+        lab_pos = [True if i < self.n_pos_repertoires else False for i in rep_indices]
+        labels_mapping = dict(zip(primary_rep_fns, lab_pos))
         file_names = [self.proxy_primary_fns[rep] for rep in primary_rep_fns]
+        labels = [labels_mapping[rep] for rep in primary_rep_fns]
         subject_ids = [fn.split(".")[0] for fn in file_names]
-        metadata_dict = {'subject_id': subject_ids, 'filename': file_names, 'label_positive': lab_pos}
+        metadata_dict = {'subject_id': subject_ids, 'filename': file_names, 'label_positive': labels}
         metadata_df = pd.DataFrame.from_dict(metadata_dict)
         metadata_df.to_csv(os.path.join(self.super_path, "metadata.csv"))
         metadata_df.to_csv(os.path.join(self.concatenated_reps_path, "metadata.csv"))
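In words: positive labels are no longer inferred from whether a matching signal file exists on disk; they are now read off the repertoire index embedded in the primary file name. A minimal standalone sketch of the rule (the helper below is hypothetical, not part of the repository), assuming primary repertoires are named rep_<index>.tsv and the first n_pos_repertoires indices carry the signal:

# Minimal sketch (hypothetical helper): the new positive-label rule.
# Assumes primary repertoires are named "rep_0.tsv", "rep_1.tsv", ... and
# that repertoires with index < n_pos_repertoires received the signal.
def assign_positive_labels(primary_rep_fns, n_pos_repertoires):
    rep_indices = [int(fn.split("_")[1].split(".")[0]) for fn in primary_rep_fns]
    return [idx < n_pos_repertoires for idx in rep_indices]

print(assign_positive_labels(["rep_0.tsv", "rep_3.tsv", "rep_1.tsv"], 2))
# [True, False, True]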
@@ -85,7 +85,7 @@ def _write_signal_components(self, implantable_seq_subset_indices, pgen_interval
             np.savetxt(os.path.join(self.signal_components_path, "implanted_sequences_frequencies_neg_reps.txt"),
                        neg_rep_num, fmt="%s")
             neg_rep_seq_presence_indices = ImplantationHelper.get_repertoire_sequence_presence_indices(
-                desired_num_repertoires=self.desired_num_repertoires, abs_num_of_reps_list=neg_rep_num)
+                desired_num_repertoires=self.n_neg_repertoires, abs_num_of_reps_list=neg_rep_num)
         else:
             neg_rep_seq_presence_indices = []
         seq_presence_indices = pos_rep_seq_presence_indices + neg_rep_seq_presence_indices
@@ -97,7 +97,7 @@ def _write_signal_components(self, implantable_seq_subset_indices, pgen_interval
     def _determine_signal_sequence_combination(self, pgen_intervals_array, pool_size):
         if len(pgen_intervals_array) > pool_size:
             valid_seq_proportions = [1 - seq_proportion for seq_proportion in
-                                     [0, 0.05, 0.10, 0.20, 0.40, 0.60, 0.80, 0.90, 0.95] if
+                                     [0, 0.05, 0.10, 0.20, 0.30, 0.40, 0.50, 0.60, 0.70, 0.80, 0.90, 0.95] if
                                      seq_proportion * len(pgen_intervals_array) > pool_size]
             subset_seqs_total_implant_counts = {}
             subset_seq_indices = {}
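The proportion grid used to pick a sequence subset is refined from 9 to 12 steps (adding 0.30, 0.50, and 0.70), giving the search finer granularity. A worked example of the filter under assumed inputs (1000 pgen intervals and a pool size of 300, both made up for illustration):

# Worked example (hypothetical inputs): only proportions whose implied sequence
# count exceeds the pool size survive, and each survivor is flipped to 1 - p.
n_intervals = 1000   # stand-in for len(pgen_intervals_array)
pool_size = 300

grid = [0, 0.05, 0.10, 0.20, 0.30, 0.40, 0.50, 0.60, 0.70, 0.80, 0.90, 0.95]
valid_seq_proportions = [1 - p for p in grid if p * n_intervals > pool_size]
print([round(v, 2) for v in valid_seq_proportions])
# [0.6, 0.5, 0.4, 0.3, 0.2, 0.1, 0.05]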
simAIRR/util/utilities.py (37 additions & 0 deletions)
@@ -1,3 +1,4 @@
+import logging
 import os.path
 import subprocess

@@ -66,3 +67,39 @@ def count_lines(file_path):
         return line_count
     except (subprocess.CalledProcessError, FileNotFoundError):
         return 0
+
+
+def concatenate_dataframes_with_replacement(dfs_list):
+    df_a, df_b = dfs_list
+    for idx, row in df_b.iterrows():
+        match_idx = df_a[(df_a['v_call'] == row['v_call']) & (df_a['j_call'] == row['j_call'])].index
+        if not match_idx.empty:
+            df_a = df_a.drop(match_idx[0])
+    df = pd.concat([df_a, df_b], ignore_index=True)
+    return df
+
+
+def filter_legal_pairs(df, legal_pairs):
+    df['pairs'] = list(zip(df['v_gene'], df['j_gene']))
+    initial_n_rows = df.shape[0]
+    df = df[df['pairs'].isin(legal_pairs)]
+    df = df.drop(columns=['pairs'])
+    final_n_rows = df.shape[0]
+    logging.info('Number of sequences removed from the user-supplied signal because of lack of legal gene '
+                 'combinations: ' + str(initial_n_rows - final_n_rows))
+    return df
+
+def get_legal_vj_pairs(background_sequences_path):
+    df = pd.read_csv(background_sequences_path, sep="\t", header=0)
+    if df.shape[1] == 3:
+        df.insert(0, 'nt_seq', "NA")
+    df.columns = ['junction', 'junction_aa', 'v_call', 'j_call']
+    df['v_j_call'] = list(zip(df['v_call'], df['j_call']))
+    unique_combinations = df['v_j_call'].value_counts()
+    unique_combinations = unique_combinations.reset_index()
+    unique_combinations.columns = ['v_j_call', 'count']
+    unique_combinations['count'] = unique_combinations['count'].astype(int)
+    unique_combinations['percentage'] = unique_combinations['count'] / unique_combinations['count'].sum() * 100
+    filtered_combinations = unique_combinations[unique_combinations['percentage'] > 0.015]
+    legal_combinations = filtered_combinations['v_j_call'].to_list()
+    return legal_combinations
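Two behaviors of the new helpers are worth spelling out. concatenate_dataframes_with_replacement drops at most one baseline row per signal row with a matching (v_call, j_call) pair before appending, so the repertoire size is preserved. And the 0.015% cutoff in get_legal_vj_pairs means that with, say, 10,000 background sequences a pair must occur at least twice (0.015% of 10,000 is 1.5 sequences) to count as legal. A small sketch of the replacement semantics on toy data (values made up, not repository fixtures):

# Toy data (hypothetical values), exercising the replacement semantics.
import pandas as pd
from simAIRR.util.utilities import concatenate_dataframes_with_replacement

baseline = pd.DataFrame({'junction': ['j1', 'j2'],
                         'junction_aa': ['aa1', 'aa2'],
                         'v_call': ['TRBV20-1', 'TRBV20-2'],
                         'j_call': ['TRBJ2-1', 'TRBJ2-2']})
signal = pd.DataFrame({'junction': ['j9'], 'junction_aa': ['aa9'],
                       'v_call': ['TRBV20-1'], 'j_call': ['TRBJ2-1']})

# j1 shares (TRBV20-1, TRBJ2-1) with the signal row, so it is replaced.
merged = concatenate_dataframes_with_replacement([baseline, signal])
print(merged['junction'].tolist())  # ['j2', 'j9']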
simAIRR/workflows/Workflows.py (7 additions & 2 deletions)
@@ -11,7 +11,7 @@
 from simAIRR.olga_compute_pgen.OlgaPgenComputation import OlgaPgenComputation
 from simAIRR.olga_compute_pgen.UniqueSequenceFilter import UniqueSequenceFilter
 from simAIRR.pgen_count_map.PgenCountMap import PgenCountMap
-from simAIRR.util.utilities import makedir_if_not_exists, sort_olga_seq_by_pgen
+from simAIRR.util.utilities import makedir_if_not_exists, sort_olga_seq_by_pgen, filter_legal_pairs, get_legal_vj_pairs


 class Workflows:
@@ -139,12 +139,17 @@ def _parse_and_validate_user_signal(self):
             self.export_nt = False
         if user_signal.iloc[:, 0].isnull().all():
             self.export_nt = False
+        user_signal.columns = ['nt_seq', 'aa_seq', 'v_gene', 'j_gene']
+        if self.background_sequences_path is not None:
+            legal_pairs = get_legal_vj_pairs(self.background_sequences_path)
+            user_signal = filter_legal_pairs(user_signal, legal_pairs)
         return user_signal

     def _simulated_repertoire_generation(self):
         rep_concat = RepComponentConcatenation(components_type="baseline_and_signal", super_path=self.output_path,
                                                n_threads=self.n_threads, export_nt=self.export_nt,
-                                               export_cdr3_aa=self.export_cdr3_aa, annotate_signal=self.annotate_signal)
+                                               export_cdr3_aa=self.export_cdr3_aa, annotate_signal=self.annotate_signal,
+                                               n_pos_repertoires=self.n_pos_repertoires)
         logging.info('Concatenating the signal component and baseline repertoire component')
         rep_concat.multi_concatenate_repertoire_components()
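One wiring detail worth noting: get_legal_vj_pairs derives pairs from columns it names v_call/j_call, while the user signal is renamed to v_gene/j_gene just before filtering. The two still interoperate because filter_legal_pairs compares plain (V, J) gene-name tuples rather than column labels. A toy check with made-up gene names:

# Toy check (hypothetical data): legal pairs built from v_call/j_call columns
# match signal rows keyed by v_gene/j_gene, since membership is tested on
# plain (V, J) name tuples.
import pandas as pd

signal = pd.DataFrame({'v_gene': ['TRBV20-1', 'TRBV9'],
                       'j_gene': ['TRBJ2-1', 'TRBJ1-1']})
legal_pairs = [('TRBV20-1', 'TRBJ2-1')]  # as returned by get_legal_vj_pairs
signal['pairs'] = list(zip(signal['v_gene'], signal['j_gene']))
print(signal['pairs'].isin(legal_pairs).tolist())  # [True, False]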
tests/util/test_utilities.py (61 additions & 1 deletion)
@@ -1,5 +1,9 @@
+import glob
 import os
-from simAIRR.util.utilities import count_lines
+import pandas as pd
+from simAIRR.util.utilities import count_lines, concatenate_dataframes_with_replacement, get_legal_vj_pairs, \
+    filter_legal_pairs
+from simAIRR.workflows.Workflows import Workflows


 def test_count_lines(tmp_path):
@@ -8,3 +12,59 @@ def test_count_lines(tmp_path):
     exit_code = os.system(command)
     assert count_lines(out_filename) == 1000

+
+def test_concatenate_dataframes_with_replacement():
+    df_a = pd.DataFrame({
+        'junction': ['j1', 'j2', 'j3', 'j4', 'j5'],
+        'junction_aa': ['aa1', 'aa2', 'aa3', 'aa4', 'aa5'],
+        'v_call': ['TRBV20-1', 'TRBV20-2', 'TRBV20-3', 'TRBV20-1', 'TRBV20-2'],
+        'j_call': ['TRBJ2-1', 'TRBJ2-2', 'TRBJ2-3', 'TRBJ2-1', 'TRBJ2-2']
+    })
+
+    df_b = pd.DataFrame({
+        'junction': ['j6', 'j7', 'j8'],
+        'junction_aa': ['aa5', 'aa6', 'aa7'],
+        'v_call': ['TRBV20-1', 'TRBV20-2', 'TRBV20-2'],
+        'j_call': ['TRBJ2-1', 'TRBJ2-2', 'TRBJ2-5']
+    })
+
+    df = concatenate_dataframes_with_replacement([df_a, df_b])
+    assert df.shape == (6, 4)
+    assert df['junction'].tolist() == ['j3', 'j4', 'j5', 'j6', 'j7', 'j8']
+
+def test_get_legal_vj_pairs(tmp_path):
+    user_config_dict = {'mode': 'baseline_repertoire_generation',
+                        'olga_model': 'humanTRB',
+                        'output_path': None,
+                        'n_repertoires': 1,
+                        'seed': 1234,
+                        'n_sequences': 10000,
+                        'n_threads': 1,
+                        'store_intermediate_files': True,
+                        'depth_variation': True}
+    out_path = tmp_path / "workflow_output"
+    user_config_dict['output_path'] = out_path
+    desired_workflow = Workflows(**user_config_dict)
+    desired_workflow.execute()
+    background_sequences = glob.glob(os.path.join(out_path, "simulated_repertoires", "*.tsv"))[0]
+    legal_pairs = get_legal_vj_pairs(background_sequences)
+    df = pd.read_csv(background_sequences, sep="\t", header=0)
+    if df.shape[1] == 3:
+        df.insert(0, 'nt_seq', "NA")
+    df.columns = ['junction', 'junction_aa', 'v_call', 'j_call']
+    df['v_j_call'] = list(zip(df['v_call'], df['j_call']))
+    unique_combinations = df['v_j_call'].value_counts()
+    assert len(legal_pairs) <= unique_combinations.shape[0]
+
+def test_filter_legal_pairs():
+    df = pd.DataFrame({
+        'junction': ['j1', 'j2', 'j3', 'j4', 'j5'],
+        'junction_aa': ['aa1', 'aa2', 'aa3', 'aa4', 'aa5'],
+        'v_gene': ['TRBV20-1', 'TRBV20-2', 'TRBV20-3', 'TRBV20-4', 'TRBV20-2'],
+        'j_gene': ['TRBJ2-1', 'TRBJ2-2', 'TRBJ2-3', 'TRBJ2-4', 'TRBJ2-2']
+    })
+
+    legal_pairs = [('TRBV20-1', 'TRBJ2-1'), ('TRBV20-2', 'TRBJ2-2')]
+    df = filter_legal_pairs(df, legal_pairs)
+    assert df.shape == (3, 4)
+    assert df['junction'].tolist() == ['j1', 'j2', 'j5']
tests/workflows/test_Workflows.py (14 additions & 1 deletion)
@@ -69,6 +69,7 @@ def prepare_test_data_signal_implantation_workflow():
                         'signal_sequences_file': None,
                         'positive_label_rate': 0.5,
                         'phenotype_burden': 2,
+                        'noise_rate': 0.5,
                         'phenotype_pool_size': None,
                         'allow_closer_phenotype_burden': True,
                         'store_intermediate_files': True,
@@ -111,4 +112,16 @@ def test_workflow_generate_baseline_repertoires(tmp_path):
     user_config_dict['output_path'] = out_path
     desired_workflow = Workflows(**user_config_dict)
     desired_workflow.execute()
-    print(out_path)
+
+def test__parse_and_validate_user_signal(tmp_path):
+    out_path = tmp_path / "workflow_output"
+    print(out_path)
+    user_config_dict, signal_sequences = prepare_test_data_signal_implantation_workflow()
+    signal_file_path = os.path.join(tmp_path, 'signal_sequences.tsv')
+    signal_sequences = signal_sequences.drop(signal_sequences.columns[0], axis=1)
+    signal_sequences.to_csv(signal_file_path, index=None, header=None, sep='\t')
+    user_config_dict['signal_sequences_file'] = signal_file_path
+    user_config_dict['output_path'] = out_path
+    desired_workflow = Workflows(**user_config_dict)
+    user_signal = desired_workflow._parse_and_validate_user_signal()
+    print(user_signal)
