-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvisualization.py
169 lines (151 loc) · 6.96 KB
/
visualization.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np
from tqdm import tqdm
# Global matplotlib text sizes. This runs once at import time and affects
# every figure created afterwards in the same process.
_LARGE_TEXT_RC_PARAMS = {
    'figure.titlesize': 36,
    'axes.titlesize': 32,
    'font.size': 28,
    'axes.labelsize': 28,
    'xtick.labelsize': 24,
    'ytick.labelsize': 24,
    'legend.fontsize': 24,
}
plt.rcParams.update(_LARGE_TEXT_RC_PARAMS)
def plot_movement(ax, session, label):
    """
    Plot the movement pattern of a worm session.

    Each (X, Y) position is scattered with a colour taken from its frame
    number ('viridis' colormap). Positions are also joined in row order by a
    faint gray trajectory line. A colorbar titled "Frame Number" is attached
    next to ``ax``.

    Args:
        ax (matplotlib.axes.Axes): The axis to plot on.
        session (pd.DataFrame): Session data; the 'X', 'Y' and 'Frame'
            columns are read.
        label (str): Label for the plot (e.g., "Animal 1, Session 1, Drugged").
    """
    with tqdm(total=2, desc="Plotting movement", leave=False) as pbar:
        scatter = ax.scatter(
            session['X'], session['Y'], c=session['Frame'], cmap='viridis', s=20
        )
        pbar.update(1)

        ax.plot(session['X'], session['Y'], color='gray', alpha=0.5, linewidth=1.0)
        ax.set_title(label, pad=30)
        ax.set_xlabel("X Coordinate", labelpad=20)
        ax.set_ylabel("Y Coordinate", labelpad=20)

        # Go through the axis' own figure instead of pyplot's current figure
        # (plt.colorbar delegates to plt.gcf()). The colorbar then belongs to
        # the figure that actually contains ``ax``.
        cbar = ax.figure.colorbar(scatter, ax=ax)
        cbar.ax.tick_params(labelsize=24)
        # Label applied once, with explicit size and padding.
        cbar.set_label("Frame Number", size=24, labelpad=20)
        pbar.update(1)
def plot_absolute_diff(ax, session, label):
    """
    Draw a session's per-frame absolute distance change as a marker line plot.

    Args:
        ax (matplotlib.axes.Axes): Axis that receives the plot.
        session (pd.DataFrame): Session data; 'Frame' supplies the x values
            and 'Delta_Distance' the y values.
        label (str): Title text (e.g., "Animal 1, Session 1, Drugged").
    """
    line_style = dict(marker='o', markersize=3, linestyle='-', alpha=0.8)
    with tqdm(total=1, desc="Plotting differences", leave=False) as progress:
        frames = session['Frame']
        deltas = session['Delta_Distance']
        ax.plot(frames, deltas, **line_style)
        ax.set_title(label, pad=15)
        ax.set_xlabel("Frame Number", labelpad=10)
        ax.set_ylabel("Absolute Distance Change", labelpad=10)
        progress.update(1)
def plot_clustered_movement(ax, session, clusters, label):
    """
    Plot the movement pattern of a worm session with clusters color-coded.

    Every point is scattered in its cluster's colour. Consecutive points (in
    row order) are then joined by line segments:

    - a segment whose two endpoints share a cluster is drawn red
      (linewidth 1.0);
    - a segment that crosses from one cluster to another is drawn light gray
      (linewidth 1.6).

    Colours come from the 'tab10' colormap, indexed by each cluster's
    position in the sorted list of unique IDs. The assignment is therefore
    deterministic, while the legend still shows the original cluster IDs.

    Args:
        ax (matplotlib.axes.Axes): The axis to plot on.
        session (pd.DataFrame): Session data with 'X' and 'Y' columns.
        clusters (pd.Series): Cluster label for each row of ``session``. It
            is used both as an index-aligned boolean mask and positionally,
            so it should share ``session``'s index and row order.
        label (str): Title for the plot.
    """
    unique_clusters = sorted(clusters.unique())  # Sort to ensure consistent color assignment
    # matplotlib.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in
    # 3.9; pyplot.get_cmap(name, lut) works on old and new versions. A lookup
    # table needs at least one entry, hence max(..., 1) for empty input.
    color_map = plt.get_cmap('tab10', max(len(unique_clusters), 1))
    # Map cluster IDs to sequential colour indices while keeping the IDs as labels.
    color_indices = {cluster_id: idx for idx, cluster_id in enumerate(unique_clusters)}

    # Pull plain arrays out once. Indexing session.iloc[i]['X'] inside the
    # segment loop would build a new row Series on every access.
    xs = session['X'].to_numpy()
    ys = session['Y'].to_numpy()
    labels = clusters.to_numpy()
    n_segments = max(len(session) - 1, 0)  # never a negative progress total

    with tqdm(total=len(unique_clusters) + n_segments,
              desc="Plotting clustered movement", leave=False) as pbar:
        # Points: one scatter call (and one legend entry) per cluster.
        for cluster_id in unique_clusters:
            cluster_points = session[clusters == cluster_id]
            rgba = color_map(color_indices[cluster_id])
            ax.scatter(
                cluster_points['X'], cluster_points['Y'],
                label=f"Cluster {cluster_id}", s=20,
                c=np.array([rgba] * len(cluster_points)), alpha=0.6
            )
            pbar.update(1)

        # Segments between consecutive points.
        # NOTE(review): an earlier docstring described the opposite scheme
        # (gray within a cluster, red across clusters). The drawn scheme is
        # kept unchanged here; confirm which one is intended.
        for i in range(1, len(session)):
            same_cluster = labels[i] == labels[i - 1]
            ax.plot(
                [xs[i - 1], xs[i]],
                [ys[i - 1], ys[i]],
                color='red' if same_cluster else 'lightgray',
                linewidth=1.0 if same_cluster else 1.6,
            )
            pbar.update(1)

    ax.set_title(label, pad=30)
    ax.set_xlabel("X Coordinate", labelpad=20)
    ax.set_ylabel("Y Coordinate", labelpad=20)
    # Show the legend only if there is something to list and not too many clusters.
    if 0 < len(unique_clusters) <= 10:
        ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=24)
def plot_clustered_movement_with_original_colors(ax, session, clusters, original_clusters, label):
    """
    Plot the movement pattern with clusters color-coded, preserving original colors.

    Intended for a filtered session. Colours are assigned from the *original*
    (unfiltered) cluster labels: 'tab10' colormap, indexed by each ID's
    position in the sorted unique original IDs. This is the same rule
    ``plot_clustered_movement`` uses, so a cluster keeps its colour after
    other clusters have been removed. Only clusters that still occur in
    ``clusters`` are scattered.

    Consecutive points (in row order) are joined by line segments:

    - red (linewidth 0.5) when both endpoints share a cluster;
    - light gray (linewidth 0.8) when the segment crosses between clusters.

    Args:
        ax (matplotlib.axes.Axes): The axis to plot on.
        session (pd.DataFrame): Filtered session data with 'X' and 'Y' columns.
        clusters (pd.Series): Filtered cluster labels. They are used both as
            an index-aligned boolean mask and positionally, so they should
            share ``session``'s index and row order.
        original_clusters (pd.Series): Original cluster labels for color reference.
        label (str): Title for the plot.
    """
    unique_clusters = sorted(original_clusters.unique())
    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap
    # accepts the same (name, lut) arguments. lut must be >= 1.
    color_map = plt.get_cmap('tab10', max(len(unique_clusters), 1))
    # Index the colormap by position, not by raw ID. A Colormap treats
    # integers as LUT indices, so negative or out-of-range IDs are clamped to
    # the first/last colour. Floats are read as fractions in [0, 1]. Either
    # case would make distinct clusters share a colour.
    color_indices = {cluster_id: idx for idx, cluster_id in enumerate(unique_clusters)}

    # Clusters that survive filtering, in sorted-ID order (set: O(1) membership).
    remaining_clusters = set(clusters.values)
    plotted_clusters = [c for c in unique_clusters if c in remaining_clusters]

    # Plain arrays avoid building a row Series per access in the segment loop.
    xs = session['X'].to_numpy()
    ys = session['Y'].to_numpy()
    labels = clusters.to_numpy()
    n_segments = max(len(session) - 1, 0)  # never a negative progress total

    with tqdm(total=len(plotted_clusters) + n_segments,
              desc="Plotting with original colors", leave=False) as pbar:
        # Points for each remaining cluster, coloured by its original position.
        for cluster_id in plotted_clusters:
            cluster_points = session[clusters == cluster_id]
            rgba = color_map(color_indices[cluster_id])
            ax.scatter(
                cluster_points['X'], cluster_points['Y'],
                label=f"Cluster {cluster_id}",
                s=10,
                c=np.array([rgba] * len(cluster_points)),
                alpha=0.6
            )
            pbar.update(1)

        # Segments between consecutive points.
        # NOTE(review): an earlier docstring described the opposite scheme
        # (gray within a cluster, red across clusters). The drawn scheme is
        # kept unchanged here; confirm which one is intended.
        for i in range(1, len(session)):
            same_cluster = labels[i] == labels[i - 1]
            ax.plot(
                [xs[i - 1], xs[i]],
                [ys[i - 1], ys[i]],
                color='red' if same_cluster else 'lightgray',
                linewidth=0.5 if same_cluster else 0.8,
            )
            pbar.update(1)

    ax.set_title(label, pad=15)
    ax.set_xlabel("X Coordinate", labelpad=10)
    ax.set_ylabel("Y Coordinate", labelpad=10)
    # The legend lists only the clusters actually drawn, so cap on that count.
    if 0 < len(plotted_clusters) <= 10:
        ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=12)