Skip to content

Commit

Permalink
Update expression/effects plotting
Browse files Browse the repository at this point in the history
  • Loading branch information
dzhu8 committed Feb 20, 2024
1 parent 5659780 commit e99379f
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 19 deletions.
58 changes: 47 additions & 11 deletions spateo/plotting/static/three_d_plot/three_dims_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -941,18 +941,20 @@ def plot_expression_3D(
group_key: Optional[str] = None,
ct_subset: Optional[list] = None,
pcutoff: Optional[float] = 99.7,
zero_opacity: float = 1.0,
):
"""Visualize gene expression in a 3D space.
Args:
target: Target gene to visualize
interaction: Interaction to visualize (e.g. "Igf1:Igf1r" for L:R model, "Igf1" for ligand model)
save_path: Path to save the figure to (will save as HTML file)
coords_key: Key for spatial coordinates in adata.obsm
adata: AnnData object containing spatial coordinates and cell type labels
save_path: Path to save the plot
gene: Will plot expression pattern of this gene
coords_key: Key in adata.obsm where spatial coordinates are stored
group_key: Optional key for grouping in adata.obs, but needed if "ct_subset" is provided
ct_subset: Optional list of cell types to include in the plot. If None, all cell types will be included.
pcutoff: Percentile cutoff for gene expression. Default is 99.7, which will set the max value plotted to the
99.7th percentile of gene expression values.
zero_opacity: Opacity of points with zero expression. Between 0.0 and 1.0. Default is 1.0.
"""
if group_key is not None:
if group_key not in adata.obs.keys():
Expand All @@ -963,25 +965,59 @@ def plot_expression_3D(
x, y, z = coords[:, 0], coords[:, 1], coords[:, 2]

gene_expr = adata[:, gene].X.toarray().flatten()

# Lenient w/ the max value cutoff so that the colored dots are more distinct from black background
cutoff = np.percentile(gene_expr, pcutoff)
gene_expr[gene_expr > cutoff] = cutoff
scatter_expr = go.Scatter3d(
x=x,
y=y,
z=z,

# Separately plot zeros and nonzeros:
zero_indices = gene_expr == 0
non_zero_indices = gene_expr > 0

x_zeros, y_zeros, z_zeros = x[zero_indices], y[zero_indices], z[zero_indices]
x_non_zeros, y_non_zeros, z_non_zeros = x[non_zero_indices], y[non_zero_indices], z[non_zero_indices]
gene_expr_non_zeros = gene_expr[non_zero_indices]

# Plot non-zero expression values including one zero for color consistency
gene_expr_nz = np.append(gene_expr_non_zeros, 0) # Include one zero value
x_nz = np.append(x_non_zeros, x[zero_indices][0]) if len(x_zeros) > 0 else x_non_zeros
y_nz = np.append(y_non_zeros, y[zero_indices][0]) if len(y_zeros) > 0 else y_non_zeros
z_nz = np.append(z_non_zeros, z[zero_indices][0]) if len(z_zeros) > 0 else z_non_zeros

scatter_expr_nz = go.Scatter3d(
x=x_nz,
y=y_nz,
z=z_nz,
mode="markers",
marker=dict(
color=gene_expr,
color=gene_expr_nz,
colorscale="Hot",
size=2,
colorbar=dict(title=f"{gene}", x=0.75, titlefont=dict(size=24), tickfont=dict(size=24)),
),
showlegend=False,
)

fig = go.Figure(data=[scatter_expr])
# Add separate trace for zero expression values, if any, with specified opacity
if len(x_zeros) > 0:
scatter_expr_zeros = go.Scatter3d(
x=x_zeros,
y=y_zeros,
z=z_zeros,
mode="markers",
marker=dict(
color="#000000", # Use zero for color to match color scale
size=2,
opacity=zero_opacity, # Apply custom opacity for zeros
),
showlegend=False,
)
else:
scatter_expr_zeros = None

fig = go.Figure(data=[scatter_expr_nz])
if scatter_expr_zeros is not None:
fig.add_trace(scatter_expr_zeros)

title_dict = dict(
text=f"{gene}",
y=0.9,
Expand Down
93 changes: 85 additions & 8 deletions spateo/tools/CCI_effects_modeling/MuSIC_downstream.py
Original file line number Diff line number Diff line change
Expand Up @@ -770,6 +770,7 @@ def plot_interaction_effect_3D(
save_path: str,
pcutoff: Optional[float] = 99.7,
min_value: Optional[float] = 0,
zero_opacity: float = 1.0,
):
"""Quick-visualize the magnitude of the predicted effect on target for a given interaction.
Expand All @@ -780,6 +781,7 @@ def plot_interaction_effect_3D(
pcutoff: Percentile cutoff for the colorbar. Will set all values above this percentile to this value.
min_value: Minimum value to set the colorbar to. Will set all values below this value to this value.
Defaults to 0.
zero_opacity: Opacity of points with zero expression. Between 0.0 and 1.0. Default is 1.0.
"""
targets = pd.read_csv(
os.path.join(os.path.splitext(self.output_path)[0], "design_matrix", "targets.csv"), index_col=0
Expand All @@ -806,13 +808,29 @@ def plot_interaction_effect_3D(
target_interaction_coef[target_interaction_coef > cutoff] = cutoff
target_interaction_coef[target_interaction_coef < min_value] = min_value
plot_vals = target_interaction_coef.values

# Separate data into zero and non-zero (keeping one zero with non-zeros)
is_zero = plot_vals == 0
if np.any(is_zero):
non_zeros = np.where(is_zero, 0, plot_vals)
# Select the first zero to keep
first_zero_idx = np.where(is_zero)[0][0]
# Temp- to get the correct indices of nonzeros
non_zeros[first_zero_idx] = 1
is_nonzero = non_zeros != 0
non_zeros[first_zero_idx] = 0
else:
non_zeros = plot_vals
is_nonzero = np.ones(len(plot_vals), dtype=bool)

# Two plots, one for the zeros and one for the nonzeros
scatter_effect = go.Scatter3d(
x=x,
y=y,
z=z,
x=x[is_nonzero],
y=y[is_nonzero],
z=z[is_nonzero],
mode="markers",
marker=dict(
color=plot_vals,
color=non_zeros,
colorscale="Hot",
size=2,
colorbar=dict(
Expand All @@ -825,7 +843,26 @@ def plot_interaction_effect_3D(
showlegend=False,
)

# Plot zeros separately (if there are any):
scatter_zeros = None
if np.any(is_zero):
scatter_zeros = go.Scatter3d(
x=x[is_zero],
y=y[is_zero],
z=z[is_zero],
mode="markers",
marker=dict(
color="#000000", # Use zero values for color to match the scale
size=2,
opacity=zero_opacity,
),
showlegend=False,
)

fig = go.Figure(data=[scatter_effect])
if scatter_zeros is not None:
fig.add_trace(scatter_zeros)

title_dict = dict(
text=f"{interaction.title()} Effect on {target.title()}",
y=0.9,
Expand Down Expand Up @@ -1031,6 +1068,8 @@ def plot_tf_effect_3D(
receptor_targets: bool = False,
target_gene_targets: bool = False,
pcutoff: float = 99.7,
min_value: float = 0,
zero_opacity: float = 1.0,
):
"""Quick-visualize the magnitude of the predicted effect on target for a given TF. Can only find the files
necessary for this if :func `CCI_deg_detection()` has been run.
Expand All @@ -1046,6 +1085,8 @@ def plot_tf_effect_3D(
target_gene_targets: Set True if target genes were used as the target genes for the :func
`CCI_deg_detection()` model.
pcutoff: Percentile cutoff for the colorbar. Will set all values above this percentile to this value.
min_value: Minimum value to set the colorbar to. Will set all values below this value to this value.
zero_opacity: Opacity of points with zero expression. Between 0.0 and 1.0. Default is 1.0.
"""
downstream_parent_dir = os.path.dirname(os.path.splitext(self.output_path)[0])
id = os.path.splitext(os.path.basename(self.output_path))[0]
Expand Down Expand Up @@ -1105,17 +1146,34 @@ def plot_tf_effect_3D(
)

target_tf_coef = downstream_coeffs[target].loc[adata.obs_names, f"b_{tf}"]

# Lenient w/ the max value cutoff so that the colored dots are more distinct from black background
cutoff = np.percentile(target_tf_coef.values, pcutoff)
target_tf_coef[target_tf_coef > cutoff] = cutoff
target_tf_coef[target_tf_coef < min_value] = min_value
plot_vals = target_tf_coef.values
# Separate data into zero and non-zero (keeping one zero with non-zeros)
is_zero = plot_vals == 0
if np.any(is_zero):
non_zeros = np.where(is_zero, 0, plot_vals)
# Select the first zero to keep
first_zero_idx = np.where(is_zero)[0][0]
# Temp- to get the correct indices of nonzeros
non_zeros[first_zero_idx] = 1
is_nonzero = non_zeros != 0
non_zeros[first_zero_idx] = 0
else:
non_zeros = plot_vals
is_nonzero = np.ones(len(plot_vals), dtype=bool)

# Two plots, one for the zeros and one for the nonzeros
scatter_effect = go.Scatter3d(
x=x,
y=y,
z=z,
x=x[is_nonzero],
y=y[is_nonzero],
z=z[is_nonzero],
mode="markers",
marker=dict(
color=plot_vals,
color=non_zeros,
colorscale="Hot",
size=2,
colorbar=dict(
Expand All @@ -1128,7 +1186,26 @@ def plot_tf_effect_3D(
showlegend=False,
)

# Plot zeros separately (if there are any):
scatter_zeros = None
if np.any(is_zero):
scatter_zeros = go.Scatter3d(
x=x[is_zero],
y=y[is_zero],
z=z[is_zero],
mode="markers",
marker=dict(
color="#000000", # Use zero values for color to match the scale
size=2,
opacity=zero_opacity,
),
showlegend=False,
)

fig = go.Figure(data=[scatter_effect])
if scatter_zeros is not None:
fig.add_trace(scatter_zeros)

title_dict = dict(
text=f"{tf.title()} Effect on {target.title()}",
y=0.9,
Expand Down

0 comments on commit e99379f

Please sign in to comment.