-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathanalyze_perturbation_patterns.py
181 lines (142 loc) · 6.17 KB
/
analyze_perturbation_patterns.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
#!/usr/bin/env python3
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import os
def load_data(data_path):
    """Load the perturbation results CSV and derive an integer topic column.

    Args:
        data_path: Location of the CSV file; handed to ``pandas.read_csv``.

    Returns:
        DataFrame with the CSV contents plus a ``topic`` column holding the
        digits that follow the first ``X`` found in each ``response_id``.

    Raises:
        ValueError: If any ``response_id`` has no ``X<digits>`` token.
    """
    print(f"Loading data from: {data_path}")
    df = pd.read_csv(data_path)
    print(f"Loaded {len(df)} rows of data")

    # Raw string: '\d' inside a plain literal is an invalid escape sequence.
    # expand=False returns a Series instead of a one-column DataFrame.
    topic = df['response_id'].str.extract(r'X(\d+)', expand=False)

    # Fail with a message that names the offending ids instead of letting
    # astype(int) fail on NaN with an opaque cast error.
    unmatched = topic.isna()
    if unmatched.any():
        examples = df.loc[unmatched, 'response_id'].head(5).tolist()
        raise ValueError(
            f"{int(unmatched.sum())} response_id value(s) do not contain a "
            f"topic number of the form 'X<digits>', e.g. {examples}"
        )

    df['topic'] = topic.astype(int)
    return df
def create_perturbation_heatmap(df, output_dir):
    """Save a clustered heatmap of log2 fold changes (genes x topics).

    Only genes flagged ``significant`` in at least one topic are plotted.
    The figure is written to ``perturbation_heatmap.pdf`` inside
    ``output_dir``. When fewer than two genes or two topics remain there is
    nothing to cluster, so the plot is skipped with a message.

    Args:
        df: Long-format table with ``grna_target``, ``topic``,
            ``log_2_fold_change`` and ``significant`` columns. ``pivot``
            requires each (gene, topic) pair to appear at most once.
        output_dir: Directory the PDF is written to.
    """
    print("Creating perturbation heatmap...")

    # Genes on rows, topics on columns.
    pivot_df = df.pivot(index='grna_target',
                        columns='topic',
                        values='log_2_fold_change')

    # Keep only genes that are significant in at least one topic.
    sig_mask = df.pivot(index='grna_target',
                        columns='topic',
                        values='significant')
    sig_genes = sig_mask.index[sig_mask.any(axis=1)]
    pivot_df = pivot_df.loc[sig_genes]
    print(f"Found {len(sig_genes)} genes with significant effects")

    # Hierarchical clustering needs at least two rows and two columns.
    if min(pivot_df.shape) < 2:
        print("Skipping perturbation heatmap: "
              "need at least 2 genes and 2 topics to cluster")
        return

    # clustermap builds its own figure, so no separate plt.figure() call is
    # made (it would only create an extra, never-closed figure).
    # Missing (gene, topic) cells are treated as no effect, matching the
    # fillna(0) used for the topic similarity matrix; clustering cannot
    # handle NaN.
    g = sns.clustermap(pivot_df.fillna(0),
                       cmap='RdBu_r',
                       center=0,
                       vmin=-2, vmax=2,
                       xticklabels=True,
                       yticklabels=True,
                       dendrogram_ratio=(.1, .2),
                       cbar_pos=(.02, .32, .03, .2),
                       figsize=(20, 12))

    # Tilt the topic labels; keep gene labels horizontal.
    plt.setp(g.ax_heatmap.get_xticklabels(), rotation=45, ha='right')
    plt.setp(g.ax_heatmap.get_yticklabels(), rotation=0)
    g.fig.suptitle('Perturbation Effects Across Topics', y=1.02, fontsize=16)

    output_path = os.path.join(output_dir, 'perturbation_heatmap.pdf')
    print(f"Saving heatmap to: {output_path}")
    g.savefig(output_path, bbox_inches='tight', dpi=300)
    # Close exactly the figure that was drawn rather than "the current one".
    plt.close(g.fig)
def analyze_gene_patterns(df):
    """Summarise the significant perturbation effects of each gene.

    Args:
        df: Table with a boolean ``significant`` column plus
            ``grna_target``, ``topic`` and ``log_2_fold_change``.

    Returns:
        DataFrame indexed by ``grna_target`` with the columns
        ``n_topics_affected`` (number of significant rows), ``mean_effect``
        and ``effect_std`` (mean / standard deviation of the significant
        log2 fold changes, rounded to 3 decimals), sorted so the genes with
        the most significant rows come first.
    """
    print("Analyzing gene patterns...")

    significant_rows = df.loc[df['significant']]
    per_gene = significant_rows.groupby('grna_target')

    # Named aggregation yields the final flat column names directly.
    summary = per_gene.agg(
        n_topics_affected=('topic', 'count'),
        mean_effect=('log_2_fold_change', 'mean'),
        effect_std=('log_2_fold_change', 'std'),
    )
    summary = summary.round(3)

    return summary.sort_values('n_topics_affected', ascending=False)
def create_topic_similarity_matrix(df, output_dir):
    """Correlate topics by their perturbation profiles and plot the result.

    Each topic is described by its vector of log2 fold changes across all
    genes (missing gene/topic combinations count as 0). The pairwise
    correlation between those vectors is drawn as a clustered heatmap and
    saved to ``topic_similarity_matrix.pdf`` inside ``output_dir``. With
    fewer than two topics there is nothing to cluster, so the plot is
    skipped and only the matrix is returned.

    Args:
        df: Long-format table with ``grna_target``, ``topic`` and
            ``log_2_fold_change`` columns.
        output_dir: Directory the PDF is written to.

    Returns:
        Square DataFrame of pairwise correlations whose index and columns
        are the topic numbers.
    """
    print("Creating topic similarity matrix...")

    # DataFrame.corr() correlates *columns*, so topics must be the columns
    # and genes the observations. With topics on the index the result would
    # be a gene-by-gene matrix instead of a topic-by-topic one.
    pivot_df = df.pivot(index='grna_target',
                        columns='topic',
                        values='log_2_fold_change').fillna(0)
    corr_matrix = pivot_df.corr()

    # Hierarchical clustering needs at least two topics.
    if len(corr_matrix) < 2:
        print("Skipping topic similarity plot: "
              "need at least 2 topics to cluster")
        return corr_matrix

    # clustermap builds its own figure, so no separate plt.figure() call is
    # made (it would only create an extra, never-closed figure).
    g = sns.clustermap(corr_matrix,
                       cmap='RdBu_r',
                       center=0,
                       vmin=-1, vmax=1,
                       xticklabels=True,
                       yticklabels=True)

    plt.setp(g.ax_heatmap.get_xticklabels(), rotation=45, ha='right')
    plt.setp(g.ax_heatmap.get_yticklabels(), rotation=0)
    g.fig.suptitle('Topic Similarity Based on Perturbation Profiles',
                   y=1.02, fontsize=16)

    output_path = os.path.join(output_dir, 'topic_similarity_matrix.pdf')
    print(f"Saving similarity matrix to: {output_path}")
    g.savefig(output_path, bbox_inches='tight', dpi=300)
    # Close exactly the figure that was drawn rather than "the current one".
    plt.close(g.fig)

    return corr_matrix
def analyze_topic_clusters(corr_matrix, threshold=0.7):
    """List the pairs whose correlation is strictly above ``threshold``.

    Only the upper triangle (excluding the diagonal) is scanned, so each
    unordered pair is reported once and self-correlations are ignored.

    Args:
        corr_matrix: Square DataFrame of pairwise correlations; its index
            supplies the labels for both members of a pair.
        threshold: Correlation a pair must exceed to be reported.

    Returns:
        DataFrame with columns ``topic1``, ``topic2`` and ``correlation``,
        one row per qualifying pair. The columns are present even when no
        pair qualifies, so a CSV written from the result always has a header.
    """
    print("Analyzing topic clusters...")

    labels = corr_matrix.index
    above = np.triu(corr_matrix.to_numpy() > threshold, k=1)

    topic_clusters = [
        {
            'topic1': labels[i],
            'topic2': labels[j],
            'correlation': corr_matrix.iloc[i, j],
        }
        for i, j in zip(*np.where(above))
    ]

    # Explicit columns keep the schema stable for an empty result.
    return pd.DataFrame(topic_clusters,
                        columns=['topic1', 'topic2', 'correlation'])
def create_top_genes_barplot(gene_effects, output_dir, top_n=20):
    """Plot how many topics each of the leading genes affects.

    The first ``top_n`` rows of ``gene_effects`` are drawn as a bar chart
    and saved to ``top_genes_barplot.pdf`` inside ``output_dir``.

    Args:
        gene_effects: Table with an ``n_topics_affected`` column whose
            index, once reset, provides the ``grna_target`` column.
        output_dir: Directory the PDF is written to.
        top_n: Number of leading rows to plot.
    """
    print(f"Creating bar plot of top {top_n} genes...")

    plt.figure(figsize=(15, 8))

    # Move the gene index into a regular column so it can be used as x.
    plot_data = gene_effects.head(top_n).reset_index()
    ax = sns.barplot(data=plot_data,
                     x='grna_target',
                     y='n_topics_affected',
                     color='skyblue')

    # Axis cosmetics.
    ax.set_title(f'Top {top_n} Genes by Number of Topics Affected')
    ax.set_ylabel('Number of Topics Affected')
    ax.set_xlabel('Gene')
    plt.xticks(rotation=45, ha='right')

    # Write the figure out and release it.
    output_path = os.path.join(output_dir, 'top_genes_barplot.pdf')
    print(f"Saving bar plot to: {output_path}")
    plt.savefig(output_path, bbox_inches='tight', dpi=300)
    plt.close()
def main():
    """Run the whole analysis: load the data, plot, and write the CSVs.

    The input CSV is read from the directory above this script; every
    output goes into an ``analysis_results`` folder next to the script.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    output_dir = os.path.join(here, 'analysis_results')
    data_path = os.path.join(
        os.path.dirname(here),
        'FP_moi15_thresho20_60k_-celltype_default.csv',
    )

    def in_output(filename):
        # Path of a result file inside the output directory.
        return os.path.join(output_dir, filename)

    os.makedirs(output_dir, exist_ok=True)
    print(f"Output directory: {output_dir}")

    df = load_data(data_path)

    # Gene-level view: heatmap, per-gene summary table, top-gene bar plot.
    create_perturbation_heatmap(df, output_dir)
    gene_patterns = analyze_gene_patterns(df)
    gene_patterns.to_csv(in_output('gene_patterns.csv'))
    create_top_genes_barplot(gene_patterns, output_dir)

    # Topic-level view: similarity matrix and highly correlated pairs.
    corr_matrix = create_topic_similarity_matrix(df, output_dir)
    topic_clusters = analyze_topic_clusters(corr_matrix)
    topic_clusters.to_csv(in_output('topic_clusters.csv'), index=False)

    print("Analysis complete! Results saved in:", output_dir)


# Run the analysis only when executed as a script, not when imported.
if __name__ == '__main__':
    main()