Skip to content

Commit

Permalink
adapt decoding
Browse files Browse the repository at this point in the history
  • Loading branch information
timonmerk committed Oct 11, 2024
1 parent 5f83915 commit 670a7dd
Show file tree
Hide file tree
Showing 4 changed files with 634 additions and 619 deletions.
248 changes: 129 additions & 119 deletions figure_27_get_ch_ind_per_all_ch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np
import sys

def run_channel(sub, ch, ):
def run_channel(sub, ch, ch_idx):

per_ind = []
df_sub = df[df["sub"] == sub].copy()
Expand All @@ -23,7 +23,7 @@ def run_channel(sub, ch, ):
if label == "pkg_tremor":
y = df_sub[label].copy() > 1
elif label == "pkg_bk":
y = df_sub[label].copy() > 80
y = df_sub[label].copy() > 50
elif label == "pkg_dk":
y = (df_sub[label].copy() / df_sub[label].max()) > 0.02
else:
Expand Down Expand Up @@ -58,7 +58,7 @@ def run_channel(sub, ch, ):
per_ind.append({
"sub": sub,
"ch": ch,
"ch_orig": ch_names[ch_idx],
"ch_orig": ch_names_orig[ch_idx],
"label": label,
"classification": CLASSIFICATION,
"per": per
Expand All @@ -68,7 +68,7 @@ def run_channel(sub, ch, ):

if __name__ == "__main__":

RUN_DECODING = False
RUN_DECODING = True
if RUN_DECODING:
RUN_ON_CLUSTER = False
if RUN_ON_CLUSTER is False:
Expand All @@ -84,20 +84,29 @@ def run_channel(sub, ch, ):
df = pd.read_csv(PATH_, index_col=0)
df_ch_used = pd.read_csv(ch_used, index_col=0)

run_idx = int(sys.argv[1])
sub_idx = run_idx // 4
ch_idx = run_idx % 4
subs = df_ch_used["sub"].unique()
for sub in subs:
print(f"sub: {sub}")
ch_names_orig = df_ch_used[df_ch_used["sub"] == sub].iloc[0, :4].values
ch_names = df_ch_used.columns[:4]
for ch_idx, ch in enumerate(ch_names):
print(f"ch: {ch}")
run_channel(sub, ch, ch_idx)

#run_idx = int(sys.argv[1])
#sub_idx = run_idx // 4
#ch_idx = run_idx % 4


sub = df_ch_used["sub"].unique()[sub_idx]
ch_names_orig = df_ch_used[df_ch_used["sub"] == sub].iloc[0, :4].values
ch_names = df_ch_used.columns[:4]
#sub = df_ch_used["sub"].unique()[sub_idx]
#ch_names_orig = df_ch_used[df_ch_used["sub"] == sub].iloc[0, :4].values
#ch_names = df_ch_used.columns[:4]

ch = ch_names[ch_idx]
#ch = ch_names[ch_idx]

run_channel(sub, ch)

MERGE_FILES = False
MERGE_FILES = True
if MERGE_FILES:
PATH_PER = r"/Users/Timon/Library/CloudStorage/OneDrive-Charité-UniversitätsmedizinBerlin/Shared Documents - ICN Data World/General/Data/UCSF_OLARU/out_per/out_dir"
l_ = []
Expand All @@ -110,7 +119,7 @@ def run_channel(sub, ch, ):
new_df = pd.concat(l_, axis=0)
new_df.to_csv(os.path.join(PATH_PER, "df_per_ind_all.csv"))

MERGE_WITH_COORDS = False
MERGE_WITH_COORDS = True
if MERGE_WITH_COORDS:
PATH_PER = r"/Users/Timon/Library/CloudStorage/OneDrive-Charité-UniversitätsmedizinBerlin/Shared Documents - ICN Data World/General/Data/UCSF_OLARU/out_per/out_dir"
df = pd.read_csv(os.path.join(PATH_PER, "df_per_ind_all.csv"), index_col=0)
Expand Down Expand Up @@ -156,7 +165,7 @@ def run_channel(sub, ch, ):
df_per_ind_all_coords = pd.DataFrame(l_)
df_per_ind_all_coords.to_csv(os.path.join(PATH_PER, "df_per_ind_all_coords.csv"))

PLOT = True
PLOT = False
if PLOT:
PATH_PER = r"/Users/Timon/Library/CloudStorage/OneDrive-Charité-UniversitätsmedizinBerlin/Shared Documents - ICN Data World/General/Data/UCSF_OLARU/out_per/out_dir"
PATH_FIGURES = r"/Users/Timon/Library/CloudStorage/OneDrive-Charité-UniversitätsmedizinBerlin/Shared Documents - ICN Data World/General/Data/UCSF_OLARU/figures_ucsf"
Expand Down Expand Up @@ -208,127 +217,128 @@ def run_channel(sub, ch, ):
plt.savefig(os.path.join(PATH_FIGURES, "ECoG_performances_regress_small.pdf"))
plt.show(block=True)

# PLOT STN coordinates
df_query = df.query("loc == 'STN' and label== 'pkg_dk' and classification == False")
stn_strip_xyz = df_query[["x", "y", "z"]].values
strip_color = df_query["per"]
PLOT_STN = False
if PLOT_STN:
# PLOT STN coordinates
df_query = df.query("loc == 'STN' and label== 'pkg_dk' and classification == False")
stn_strip_xyz = df_query[["x", "y", "z"]].values
strip_color = df_query["per"]

fig, axes = plt.subplots(1, 1, facecolor=(1, 1, 1), figsize=(14, 9))
axes.scatter(x_stn, z_stn, c="gray", s=0.025)
axes.axes.set_aspect("equal", anchor="C")
pos_ecog = axes.scatter(
stn_strip_xyz[:, 0],
stn_strip_xyz[:, 2],
c=np.clip(strip_color, 0, 1),
s=100,
alpha=0.8,
cmap="viridis",
marker="o",
label="ecog electrode",
)
axes.axis("off")
plt.show(block=True)

fig, axes = plt.subplots(1, 1, facecolor=(1, 1, 1), figsize=(14, 9))
axes.scatter(x_stn, z_stn, c="gray", s=0.025)
axes.axes.set_aspect("equal", anchor="C")
pos_ecog = axes.scatter(
# load the GPi coordinates
PATH_GPI = r"/Users/Timon/Downloads/v_1_1/DISTAL (Ewert 2017)/lh/STN.nii"
PATH_GPI = r"/Users/Timon/Documents/MATLAB/leaddbs/templates/space/MNI152NLin2009bAsym/atlases/AHEAD Atlas (Alkemade 2020)/lh/GPi_mask.nii"

#PATH_GPI = r"/Users/Timon/Downloads/v_1_1/DISTAL Minimal (Ewert 2017)/lh/GPi.nii"
PATH_GPI = r"/Users/Timon/Documents/MATLAB/leaddbs/templates/space/MNI152NLin2009bAsym/atlases/DISTAL Nano (Ewert 2017)/lh/GPi.nii"

PATH_GPI = r"/Users/Timon/Documents/MATLAB/leaddbs/templates/space/MNI152NLin2009bAsym/atlases/DISTAL Nano (Ewert 2017)/lh/GPi_mask.nii"
PATH_GPI = r"/Users/Timon/Documents/MATLAB/leaddbs/templates/space/MNI152NLin2009bAsym/atlases/AHEAD Atlas (Alkemade 2020)/lh/GPi_mask.nii"

PATH_GPI = r"/Users/Timon/Documents/MATLAB/leaddbs/templates/space/MNI152NLin2009bAsym/atlases/DISTAL Nano (Ewert 2017)/lh/GPi.nii"
import nibabel as nib
img = nib.load(PATH_GPI)
data = img.get_fdata()
affine = img.affine
x_gpi, y_gpi, z_gpi = np.where(data > 0)

voxel_indices = np.array([x_gpi, y_gpi, z_gpi]).T
mni_coords = nib.affines.apply_affine(affine, voxel_indices)
x_gpi = mni_coords[:, 0]
y_gpi = mni_coords[:,1]
z_gpi = mni_coords[:,2]

df_query_GP = df.query("loc == 'GP' and label== 'pkg_dk' and classification == False")
gp_strip_xyz = df_query_GP[["x", "y", "z"]].values
strip_color_GP = df_query_GP["per"]

fig, axes = plt.subplots(1, 1, facecolor=(1, 1, 1), figsize=(14, 9))
axes.scatter(x_gpi, y_gpi, c="gray", s=0.025)
axes.scatter(x_stn, y_stn, c="gray", s=0.025)
axes.axes.set_aspect("equal", anchor="C")
pos_gpi = axes.scatter(
gp_strip_xyz[:, 0],
gp_strip_xyz[:, 1],
c=np.clip(strip_color_GP, 0, 1),
s=100,
alpha=0.8,
cmap="viridis",
marker="o",
label="ecog electrode",
)
df_query_STN = df.query("loc == 'STN' and label== 'pkg_dk' and classification == False")
stn_strip_xyz = df_query_STN[["x", "y", "z"]].values
strip_color_stn = df_query_STN["per"]
pos_stn = axes.scatter(
stn_strip_xyz[:, 0],
stn_strip_xyz[:, 2],
c=np.clip(strip_color, 0, 1),
stn_strip_xyz[:, 1],
c=np.clip(strip_color_stn, 0, 1),
s=100,
alpha=0.8,
cmap="viridis",
marker="o",
label="ecog electrode",
)
axes.axis("off")
plt.show(block=True)
axes.axis("off")
plt.show(block=True)

# load the GPi coordinates
PATH_GPI = r"/Users/Timon/Downloads/v_1_1/DISTAL (Ewert 2017)/lh/STN.nii"
PATH_GPI = r"/Users/Timon/Documents/MATLAB/leaddbs/templates/space/MNI152NLin2009bAsym/atlases/AHEAD Atlas (Alkemade 2020)/lh/GPi_mask.nii"

#PATH_GPI = r"/Users/Timon/Downloads/v_1_1/DISTAL Minimal (Ewert 2017)/lh/GPi.nii"
PATH_GPI = r"/Users/Timon/Documents/MATLAB/leaddbs/templates/space/MNI152NLin2009bAsym/atlases/DISTAL Nano (Ewert 2017)/lh/GPi.nii"

PATH_GPI = r"/Users/Timon/Documents/MATLAB/leaddbs/templates/space/MNI152NLin2009bAsym/atlases/DISTAL Nano (Ewert 2017)/lh/GPi_mask.nii"
PATH_GPI = r"/Users/Timon/Documents/MATLAB/leaddbs/templates/space/MNI152NLin2009bAsym/atlases/AHEAD Atlas (Alkemade 2020)/lh/GPi_mask.nii"

PATH_GPI = r"/Users/Timon/Documents/MATLAB/leaddbs/templates/space/MNI152NLin2009bAsym/atlases/DISTAL Nano (Ewert 2017)/lh/GPi.nii"
import nibabel as nib
img = nib.load(PATH_GPI)
data = img.get_fdata()
affine = img.affine
x_gpi, y_gpi, z_gpi = np.where(data > 0)

voxel_indices = np.array([x_gpi, y_gpi, z_gpi]).T
mni_coords = nib.affines.apply_affine(affine, voxel_indices)
x_gpi = mni_coords[:, 0]
y_gpi = mni_coords[:,1]
z_gpi = mni_coords[:,2]

df_query_GP = df.query("loc == 'GP' and label== 'pkg_dk' and classification == False")
gp_strip_xyz = df_query_GP[["x", "y", "z"]].values
strip_color_GP = df_query_GP["per"]

fig, axes = plt.subplots(1, 1, facecolor=(1, 1, 1), figsize=(14, 9))
axes.scatter(x_gpi, y_gpi, c="gray", s=0.025)
axes.scatter(x_stn, y_stn, c="gray", s=0.025)
axes.axes.set_aspect("equal", anchor="C")
pos_gpi = axes.scatter(
gp_strip_xyz[:, 0],
gp_strip_xyz[:, 1],
c=np.clip(strip_color_GP, 0, 1),
s=100,
alpha=0.8,
cmap="viridis",
marker="o",
label="ecog electrode",
)
df_query_STN = df.query("loc == 'STN' and label== 'pkg_dk' and classification == False")
stn_strip_xyz = df_query_STN[["x", "y", "z"]].values
strip_color_stn = df_query_STN["per"]
pos_stn = axes.scatter(
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np

# Assuming x_gpi, y_gpi, z_gpi, x_stn, y_stn, z_stn, gp_strip_xyz, strip_color_GP, df, and strip_color_stn are defined

fig = plt.figure(facecolor=(1, 1, 1), figsize=(14, 9))
axes = fig.add_subplot(111, projection='3d')

# Scatter plot for GPI
axes.scatter(x_gpi, y_gpi, z_gpi, c="gray", s=0.025)
axes.scatter(x_stn, y_stn, z_stn, c="gray", s=0.025)

# Scatter plot for GP strip
pos_gpi = axes.scatter(
gp_strip_xyz[:, 0],
gp_strip_xyz[:, 1],
gp_strip_xyz[:, 2],
c=np.clip(strip_color_GP, 0, 1),
s=100,
alpha=0.8,
cmap="viridis",
marker="o",
label="ecog electrode",
)

# Query and scatter plot for STN strip
df_query_STN = df.query("loc == 'STN' and label== 'pkg_dk' and classification == False")
stn_strip_xyz = df_query_STN[["x", "y", "z"]].values
strip_color_stn = df_query_STN["per"]
pos_stn = axes.scatter(
stn_strip_xyz[:, 0],
stn_strip_xyz[:, 1],
stn_strip_xyz[:, 2],
c=np.clip(strip_color_stn, 0, 1),
s=100,
alpha=0.8,
cmap="viridis",
marker="o",
label="ecog electrode",
)
axes.axis("off")
plt.show(block=True)


import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np

# Assuming x_gpi, y_gpi, z_gpi, x_stn, y_stn, z_stn, gp_strip_xyz, strip_color_GP, df, and strip_color_stn are defined

fig = plt.figure(facecolor=(1, 1, 1), figsize=(14, 9))
axes = fig.add_subplot(111, projection='3d')

# Scatter plot for GPI
axes.scatter(x_gpi, y_gpi, z_gpi, c="gray", s=0.025)
axes.scatter(x_stn, y_stn, z_stn, c="gray", s=0.025)

# Scatter plot for GP strip
pos_gpi = axes.scatter(
gp_strip_xyz[:, 0],
gp_strip_xyz[:, 1],
gp_strip_xyz[:, 2],
c=np.clip(strip_color_GP, 0, 1),
s=100,
alpha=0.8,
cmap="viridis",
marker="o",
label="ecog electrode",
)

# Query and scatter plot for STN strip
df_query_STN = df.query("loc == 'STN' and label== 'pkg_dk' and classification == False")
stn_strip_xyz = df_query_STN[["x", "y", "z"]].values
strip_color_stn = df_query_STN["per"]
pos_stn = axes.scatter(
stn_strip_xyz[:, 0],
stn_strip_xyz[:, 1],
stn_strip_xyz[:, 2],
c=np.clip(strip_color_stn, 0, 1),
s=100,
alpha=0.8,
cmap="viridis",
marker="o",
label="ecog electrode",
)

axes.set_aspect("equal")
axes.axis("off")
plt.show(block=True)
axes.set_aspect("equal")
axes.axis("off")
plt.show(block=True)
Loading

0 comments on commit 670a7dd

Please sign in to comment.