Skip to content

Commit

Permalink
feat: density heatmap (#598)
Browse files Browse the repository at this point in the history
Fixes #22 
This doesn't add every argument, but does add the core ones and trivial
ones

Here is a good example of the basic ways this can be used, either with
`x`, `y`, and `histfunc="count"`, or `x`, `y`, `z`, and any `histfunc`,
as well as other arguments to control the binning as, like `histogram`,
how someone chooses to bin has a large impact on the visualization

```
import deephaven.plot.express as dx
from deephaven import new_table, agg

from deephaven import time_table

result = time_table(period="PT1S").update(["X = randomGaussian(10, 3)", "Y = randomGaussian(10, 3)", "Z = randomGaussian(10, 3)"])

density_heatmap = dx.density_heatmap(result, "X", "Y", title="Test", range_bins_x=[0, 20], range_bins_y=[0, 20], nbinsx=20, nbinsy=20)
density_heatmap_z = dx.density_heatmap(result, "X", "Y", "Z", title="Test", histfunc="std", range_bins_x=[0, 20], range_bins_y=[0, 20], nbinsx=5, nbinsy=5)
```
  • Loading branch information
jnumainville authored Jul 16, 2024
1 parent e56424c commit 8fb924d
Show file tree
Hide file tree
Showing 15 changed files with 1,570 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
density_mapbox,
line_geo,
line_mapbox,
density_heatmap,
)

from .data import data_generators
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@
DeephavenFigureNode,
)
from .generate import generate_figure, update_traces
from .custom_draw import draw_ohlc, draw_candlestick
from .custom_draw import draw_ohlc, draw_candlestick, draw_density_heatmap
from .RevisionManager import RevisionManager
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pandas import DataFrame
import plotly.graph_objects as go
from plotly.graph_objects import Figure
from plotly.validators.heatmap import ColorscaleValidator


def draw_finance(
Expand Down Expand Up @@ -109,3 +110,79 @@ def draw_candlestick(
"""

return draw_finance(data_frame, x_finance, open, high, low, close, go.Candlestick)


def draw_density_heatmap(
data_frame: DataFrame,
x: str,
y: str,
z: str,
labels: dict[str, str] | None = None,
range_color: list[float] | None = None,
color_continuous_scale: str | list[str] | None = "plasma",
color_continuous_midpoint: float | None = None,
opacity: float = 1.0,
title: str | None = None,
template: str | None = None,
) -> Figure:
"""Create a density heatmap
Args:
data_frame: The data frame to draw with
x: The name of the column containing x-axis values
y: The name of the column containing y-axis values
z: The name of the column containing bin values
labels: A dictionary of labels mapping columns to new labels
color_continuous_scale: A color scale or list of colors for a continuous scale
range_color: A list of two numbers that form the endpoints of the color axis
color_continuous_midpoint: A number that is the midpoint of the color axis
opacity: Opacity to apply to all markers. 0 is completely transparent
and 1 is completely opaque.
title: The title of the chart
template: The template for the chart.
Returns:
The plotly density heatmap
"""

# currently, most plots rely on px setting several attributes such as coloraxis, opacity, etc.
# so we need to set some things manually
# this could be done with handle_custom_args in generate.py in the future if
# we need to provide more options, but it's much easier to just set it here
# and doesn't risk breaking any other plots

heatmap = go.Figure(
go.Heatmap(
x=data_frame[x],
y=data_frame[y],
z=data_frame[z],
coloraxis="coloraxis1",
opacity=opacity,
)
)

range_color_list = range_color or [None, None]

colorscale_validator = ColorscaleValidator("colorscale", "draw_density_heatmap")

coloraxis_layout = dict(
colorscale=colorscale_validator.validate_coerce(color_continuous_scale),
cmid=color_continuous_midpoint,
cmin=range_color_list[0],
cmax=range_color_list[1],
)

if labels:
x = labels.get(x, x)
y = labels.get(y, y)

heatmap.update_layout(
coloraxis1=coloraxis_layout,
title=title,
template=template,
xaxis_title=x,
yaxis_title=y,
)

return heatmap
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@
"current_partition",
"colors",
"unsafe_update_figure",
"heatmap_agg_label",
}

# these are columns that are "attached" sequentially to the traces
Expand Down Expand Up @@ -669,8 +670,9 @@ def handle_custom_args(

elif arg == "bargap" or arg == "rangemode":
fig.update_layout({arg: val})
# x_axis_generators.append(key_val_generator("bargap", [val]))
# y_axis_generators.append(key_val_generator("bargap", [val]))

elif arg == "heatmap_agg_label":
fig.update_coloraxes(colorbar_title_text=val)

trace_generator = combined_generator(trace_generators)

Expand Down Expand Up @@ -781,7 +783,6 @@ def get_hover_body(

def hover_text_generator(
hover_mapping: list[dict[str, str]],
# hover_data - todo, dependent on arrays supported in data mappings
types: set[str] | None = None,
current_partition: dict[str, str] | None = None,
) -> Generator[dict[str, Any], None, None]:
Expand Down Expand Up @@ -824,6 +825,7 @@ def hover_text_generator(
def compute_labels(
hover_mapping: list[dict[str, str]],
hist_val_name: str | None,
heatmap_agg_label: str | None,
# hover_data - todo, dependent on arrays supported in data mappings
types: set[str],
labels: dict[str, str] | None,
Expand All @@ -836,6 +838,7 @@ def compute_labels(
Args:
hover_mapping: The mapping of variables to columns
hist_val_name: The histogram name for the value axis, generally histfunc
heatmap_agg_label: The aggregate density heatmap column title
types: Any types of this chart that require special processing
labels: A dictionary of old column name to new column name mappings
current_partition: The columns that this figure is partitioned by
Expand All @@ -846,9 +849,36 @@ def compute_labels(

calculate_hist_labels(hist_val_name, hover_mapping[0])

calculate_density_heatmap_labels(heatmap_agg_label, hover_mapping[0], labels)

relabel_columns(labels, hover_mapping, types, current_partition)


def calculate_density_heatmap_labels(
heatmap_agg_label: str | None,
hover_mapping: dict[str, str],
labels: dict[str, str] | None,
) -> None:
"""Calculate the labels for a density heatmap
The z column is renamed to the heatmap_agg_label
Args:
heatmap_agg_label: The name of the heatmap aggregate label
hover_mapping: The mapping of variables to columns
labels: A dictionary of labels mapping columns to new labels.
"""
labels = labels or {}
if heatmap_agg_label:
# the last part of the label is the z column, and could be replaced by labels
split_label = heatmap_agg_label.split(" ")
split_label[-1] = labels.get(split_label[-1], split_label[-1])
# it's also possible that someone wants to override the whole label
# plotly doesn't seem to do that, but it seems reasonable to allow
new_label = " ".join(split_label)
hover_mapping["z"] = labels.get(new_label, new_label)


def calculate_hist_labels(
hist_val_name: str | None, current_mapping: dict[str, str]
) -> None:
Expand All @@ -871,6 +901,7 @@ def add_axis_titles(
custom_call_args: dict[str, Any],
hover_mapping: list[dict[str, str]],
hist_val_name: str | None,
heatmap_agg_label: str | None,
) -> None:
"""Add axis titles. Generally, this only applies when there is a list variable
Expand All @@ -879,6 +910,7 @@ def add_axis_titles(
create hover and axis titles
hover_mapping: The mapping of variables to columns
hist_val_name: The histogram name for the value axis, generally histfunc
heatmap_agg_label: The aggregate density heatmap column title
"""
# Although hovertext is handled above for all plot types, plotly still
Expand All @@ -892,6 +924,9 @@ def add_axis_titles(
new_xaxis_titles = [hover_mapping[0].get("x", None)]
new_yaxis_titles = [hover_mapping[0].get("y", None)]

if heatmap_agg_label:
custom_call_args["heatmap_agg_label"] = heatmap_agg_label

# a specified axis title update should override this
if new_xaxis_titles:
custom_call_args["xaxis_titles"] = custom_call_args.get(
Expand Down Expand Up @@ -928,6 +963,9 @@ def create_hover_and_axis_titles(
(such as the y-axis if the x-axis is specified). Otherwise, there is a
legend or not depending on if there is a list of columns or not.
Density heatmaps are also an exception. If "heatmap_agg_label" is specified,
the z column is renamed to this label.
Args:
custom_call_args: The custom_call_args that are used to
create hover and axis titles
Expand All @@ -941,14 +979,26 @@ def create_hover_and_axis_titles(

labels = custom_call_args.get("labels", None)
hist_val_name = custom_call_args.get("hist_val_name", None)
heatmap_agg_label = custom_call_args.get("heatmap_agg_label", None)

current_partition = custom_call_args.get("current_partition", {})

compute_labels(hover_mapping, hist_val_name, types, labels, current_partition)
compute_labels(
hover_mapping,
hist_val_name,
heatmap_agg_label,
types,
labels,
current_partition,
)

hover_text = hover_text_generator(hover_mapping, types, current_partition)

add_axis_titles(custom_call_args, hover_mapping, hist_val_name)
if heatmap_agg_label:
# it's possible that heatmap_agg_label was relabeled, so grab the new label
heatmap_agg_label = hover_mapping[0]["z"]

add_axis_titles(custom_call_args, hover_mapping, hist_val_name, heatmap_agg_label)

return hover_text

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,7 @@ def partition_generator(self) -> Generator[dict[str, Any], None, None]:
"preprocess_hist" in self.groups
or "preprocess_freq" in self.groups
or "preprocess_time" in self.groups
or "preprocess_heatmap" in self.groups
) and self.preprocessor:
# still need to preprocess the base table
table, arg_update = cast(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@
from ._layer import layer
from .subplots import make_subplots
from .maps import scatter_geo, scatter_mapbox, density_mapbox, line_geo, line_mapbox
from .heatmap import density_heatmap
Original file line number Diff line number Diff line change
Expand Up @@ -409,8 +409,8 @@ def histogram(
range_y: A list of two numbers that specify the range of the y-axis.
range_bins: A list of two numbers that specify the range of data that is used.
histfunc: The function to use when aggregating within bins. One of
'avg', 'count', 'count_distinct', 'max', 'median', 'min', 'std', 'sum',
or 'var'
'abs_sum', 'avg', 'count', 'count_distinct', 'max', 'median', 'min', 'std',
'sum', or 'var'
cumulative: If True, values are cumulative.
nbins: The number of bins to use.
text_auto: If True, display the value at each bar.
Expand Down
90 changes: 90 additions & 0 deletions plugins/plotly-express/src/deephaven/plot/express/plots/heatmap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
from __future__ import annotations

from typing import Callable, Literal

from deephaven.plot.express.shared import default_callback

from ._private_utils import process_args
from ..deephaven_figure import DeephavenFigure, draw_density_heatmap
from deephaven.table import Table


def density_heatmap(
table: Table,
x: str | None = None,
y: str | None = None,
z: str | None = None,
labels: dict[str, str] | None = None,
color_continuous_scale: str | list[str] | None = None,
range_color: list[float] | None = None,
color_continuous_midpoint: float | None = None,
opacity: float = 1.0,
log_x: bool = False,
log_y: bool = False,
range_x: list[float] | None = None,
range_y: list[float] | None = None,
range_bins_x: list[float | None] | None = None,
range_bins_y: list[float | None] | None = None,
histfunc: str = "count",
nbinsx: int = 10,
nbinsy: int = 10,
empty_bin_default: float | Literal["NaN"] | None = None,
title: str | None = None,
template: str | None = None,
unsafe_update_figure: Callable = default_callback,
) -> DeephavenFigure:
"""
A density heatmap creates a grid of colored bins. Each bin represents an aggregation of data points in that region.
Args:
table: A table to pull data from.
x: A column that contains x-axis values.
y: A column that contains y-axis values.
z: A column that contains z-axis values. If not provided, the count of joint occurrences of x and y will be used.
labels: A dictionary of labels mapping columns to new labels.
color_continuous_scale: A color scale or list of colors for a continuous scale
range_color: A list of two numbers that form the endpoints of the color axis
color_continuous_midpoint: A number that is the midpoint of the color axis
opacity: Opacity to apply to all markers. 0 is completely transparent
and 1 is completely opaque.
log_x: A boolean that specifies if the corresponding axis is a log axis or not.
log_y: A boolean that specifies if the corresponding axis is a log axis or not.
range_x: A list of two numbers that specify the range of the x axes.
None can be specified for no range
range_y: A list of two numbers that specify the range of the y axes.
None can be specified for no range
range_bins_x: A list of two numbers that specify the range of data that is used for x.
None can be specified to use the min and max of the data.
None can also be specified for either of the list values to use the min or max of the data, respectively.
range_bins_y: A list of two numbers that specify the range of data that is used for y.
None can be specified to use the min and max of the data.
None can also be specified for either of the list values to use the min or max of the data, respectively.
histfunc: The function to use when aggregating within bins. One of
'abs_sum', 'avg', 'count', 'count_distinct', 'max', 'median', 'min', 'std',
'sum', or 'var'
nbinsx: The number of bins to use for the x-axis
nbinsy: The number of bins to use for the y-axis
empty_bin_default: The value to use for bins that have no data.
If None and histfunc is 'count' or 'count_distinct', 0 is used.
Otherwise, if None or 'NaN', NaN is used.
'NaN' forces the bin to be NaN if no data is present, even if histfunc is 'count' or 'count_distinct'.
Note that if multiple points are required to color a bin, such as the case for a histfunc of 'std' or var,
the bin will still be NaN if less than the required number of points are present.
title: The title of the chart
template: The template for the chart.
unsafe_update_figure: An update function that takes a plotly figure
as an argument and optionally returns a plotly figure. If a figure is
not returned, the plotly figure passed will be assumed to be the return
value. Used to add any custom changes to the underlying plotly figure.
Note that the existing data traces should not be removed. This may lead
to unexpected behavior if traces are modified in a way that break data
mappings.
Returns:
DeephavenFigure: A DeephavenFigure that contains the density heatmap
"""
args = locals()

return process_args(args, {"preprocess_heatmap"}, px_func=draw_density_heatmap)
Loading

0 comments on commit 8fb924d

Please sign in to comment.