Skip to content

Commit

Permalink
Fix bug when all classes are hidden (#55)
Browse files Browse the repository at this point in the history
* Docstring and typing fixes

* Deprecations and private methods

* Fix bug when all classes are interactively hidden

Hiding all classes would throw an exception related to mismatched
data types during some internal update of the widget state. I traced
that back to the `SankeyPlot.plot.data[0].link` update and found it
broke if you used the link data from a `go.Sankey` that was
initialized with `link_kwargs`. Instead, intitializing with an empty
`go.Sankey` seems to solve the issue. That only arises if all classes
are hidden, so I just threw in an early return of an empty `go.Sankey`
in that case.
  • Loading branch information
aazuspan authored Mar 13, 2024
1 parent 75d61a1 commit 61ef35e
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 52 deletions.
28 changes: 17 additions & 11 deletions sankee/datasets.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from warnings import warn

import ee
import pandas as pd

Expand Down Expand Up @@ -75,14 +77,14 @@ def get_year(self, year: int) -> ee.Image:
)

img = self.collection.filterDate(str(year), str(year + 1)).first()
img = self.set_visualization_properties(img)
img = self._set_visualization_properties(img)

if self.nodata is not None:
img = img.updateMask(img.neq(self.nodata))

return img.select(self.band)

def set_visualization_properties(self, image: ee.Image) -> ee.Image:
def _set_visualization_properties(self, image: ee.Image) -> ee.Image:
"""Set the properties used by Earth Engine to automatically assign a palette to an image
from this dataset."""
return image.set(
Expand All @@ -92,7 +94,7 @@ def set_visualization_properties(self, image: ee.Image) -> ee.Image:
[c.replace("#", "") for c in self.palette.values()],
)

def list_years(self) -> ee.List:
def _list_years(self) -> ee.List:
"""Get an ee.List of all years in the collection."""
return (
self.collection.aggregate_array("system:time_start")
Expand Down Expand Up @@ -137,8 +139,6 @@ def sankify(
projection.
seed : int, default 0
The seed value used to generate repeatable results during random sampling.
exclude : None
Unused parameter that will be removed in a future release.
label_type : str, default "class"
The type of label to display for each link, one of "class", "percent", or "count".
Selecting "class" will use the class label, "percent" will use the proportion of
Expand All @@ -153,6 +153,12 @@ def sankify(
SankeyPlot
An interactive Sankey plot widget.
"""
if exclude is not None:
warn(
"The `exclude` parameter is unused and will be removed in a future release.",
DeprecationWarning,
stacklevel=2,
)
if len(years) < 2:
raise ValueError("Select at least two years.")
if len(set(years)) != len(years):
Expand Down Expand Up @@ -184,7 +190,7 @@ def sankify(
)


class LCMS_Dataset(Dataset):
class _LCMS_Dataset(Dataset):
def get_year(self, year: int) -> ee.Image:
"""Get one year's image from the dataset. LCMS splits up each year into two images: CONUS
and SEAK. This merges those into a single image."""
Expand All @@ -200,7 +206,7 @@ def get_year(self, year: int) -> ee.Image:
return merged


class CCAP_Dataset(Dataset):
class _CCAP_Dataset(Dataset):
def get_year(self, year: int) -> ee.Image:
"""Get one year's image from the dataset. C-CAP splits up each year into multiple images,
so merge those and set the class value and palette metadata to allow automatic
Expand All @@ -221,12 +227,12 @@ def get_year(self, year: int) -> ee.Image:
.setDefaultProjection("EPSG:5070")
)

img = self.set_visualization_properties(img)
img = self._set_visualization_properties(img)

return img


LCMS_LU = LCMS_Dataset(
LCMS_LU = _LCMS_Dataset(
name="LCMS LU - Land Change Monitoring System Land Use",
id="USFS/GTAC/LCMS/v2022-8",
band="Land_Use",
Expand All @@ -253,7 +259,7 @@ def get_year(self, year: int) -> ee.Image:
)

# https://developers.google.com/earth-engine/datasets/catalog/USFS_GTAC_LCMS_v2020-5
LCMS_LC = LCMS_Dataset(
LCMS_LC = _LCMS_Dataset(
name="LCMS LC - Land Change Monitoring System Land Cover",
id="USFS/GTAC/LCMS/v2022-8",
band="Land_Cover",
Expand Down Expand Up @@ -535,7 +541,7 @@ def get_year(self, year: int) -> ee.Image:
)

# https://samapriya.github.io/awesome-gee-community-datasets/projects/ccap_mlc/
CCAP_LC30 = CCAP_Dataset(
CCAP_LC30 = _CCAP_Dataset(
name="C-CAP - NOAA Coastal Change Analysis Program 30m",
id="projects/sat-io/open-datasets/NOAA/ccap_30m",
band="b1",
Expand Down
77 changes: 41 additions & 36 deletions sankee/plotting.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from collections import namedtuple
from typing import Literal

import ee
import ipywidgets as widgets
Expand Down Expand Up @@ -36,7 +37,7 @@ def sankify(
title: None | str = None,
scale: None | int = None,
seed: int = 0,
label_type: None | str = "class",
label_type: None | Literal["class", "percent", "count"] = "class",
theme: str | themes.Theme = "default",
) -> SankeyPlot:
"""
Expand Down Expand Up @@ -77,10 +78,10 @@ def sankify(
seed : int, default 0
The seed value used to generate repeatable results during random sampling.
label_type : str, default "class"
The type of label to display for each link, one of "class", "percent", "count", or False.
The type of label to display for each link, one of "class", "percent", "count", or None.
Selecting "class" will use the class label, "percent" will use the proportion of sampled
pixels in each class, and "count" will use the number of sampled pixels in each class.
False will disable link labels.
None will disable link labels.
theme : str or Theme
The theme to apply to the Sankey diagram. Can be the name of a built-in theme (e.g. "d3") or
a custom `sankee.Theme` object.
Expand Down Expand Up @@ -126,12 +127,13 @@ def sankify(
class SankeyPlot(widgets.DOMWidget):
def __init__(
self,
*,
data: pd.DataFrame,
labels: dict[int, str],
palette: dict[int, str],
title: str,
samples: ee.FeatureCollection,
label_type: str,
label_type: None | Literal["class", "percent", "count"],
theme: str | themes.Theme,
):
self.data = data
Expand All @@ -140,15 +142,15 @@ def __init__(
self.title = title
self.samples = samples
self.label_type = label_type
self.theme = theme
self.theme = theme if isinstance(theme, themes.Theme) else themes.load_theme(theme)

self.hide = []
# Initialized by `self.generate_plot`
self.df = None
self.plot = self.generate_plot()
self.gui = self.generate_gui()
self.plot = self._generate_figurewidget()
self.gui = self._generate_gui()

def get_sorted_classes(self) -> pd.Series:
def _get_sorted_classes(self) -> pd.Series:
"""Return all unique class values, sorted by the total number of observations."""
start_count = (
self.df.loc[:, ["source", "total"]]
Expand All @@ -168,11 +170,11 @@ def get_sorted_classes(self) -> pd.Series:

return total_count.sort_values(by="count", ascending=False)["class"].reset_index(drop=True)

def get_active_classes(self) -> pd.Series:
def _get_active_classes(self) -> pd.Series:
"""Return all unique active, visibile class values after filtering."""
return self.df[["source", "target"]].melt().value.unique()

def generate_plot_parameters(self) -> SankeyParameters:
def _generate_plot_parameters(self) -> SankeyParameters:
"""Generate Sankey plot parameters from a formatted, cleaned dataframe"""
df = self.df.copy()

Expand Down Expand Up @@ -226,7 +228,7 @@ def generate_plot_parameters(self) -> SankeyParameters:
all_classes["label"] = ""
else:
raise ValueError(
"Invalid label_type. Choose from 'class', 'percent', 'count', or False."
"Invalid label_type. Choose from 'class', 'percent', 'count', or None."
)

return SankeyParameters(
Expand All @@ -240,7 +242,7 @@ def generate_plot_parameters(self) -> SankeyParameters:
value=df.changed,
)

def generate_dataframe(self) -> pd.DataFrame:
def _generate_dataframe(self) -> pd.DataFrame:
"""Convert raw sampling data to a formatted dataframe"""
data = self.data.copy()

Expand Down Expand Up @@ -299,14 +301,15 @@ def _model_id(self):
return self.gui._model_id

def update_layout(self, *args, **kwargs):
"""Pass layout changes to the plot. This is primarily kept for compatibility with geemap."""
"""Pass layout changes to the plot."""
# This is primarily kept for compatibility with geemap
self.plot.update_layout(*args, **kwargs)

def generate_gui(self):
def _generate_gui(self):
BUTTON_HEIGHT = "24px"
BUTTON_WIDTH = "24px"

unique_classes = self.get_sorted_classes()
unique_classes = self._get_sorted_classes()

def toggle_button(button):
button.toggle()
Expand All @@ -322,13 +325,13 @@ def toggle_button(button):
update_plot()

def update_plot():
"""Swap new data into the plot"""
new_plot = self.generate_plot()
self.plot.data[0].link = new_plot.data[0].link
self.plot.data[0].node = new_plot.data[0].node
"""Swap new data into the plot."""
new_sankey = self._generate_sankey()
self.plot.data[0].link = new_sankey.link
self.plot.data[0].node = new_sankey.node

buttons = []
active_classes = self.get_active_classes()
active_classes = self._get_active_classes()
for i in unique_classes:
label = self.labels[i]
on_color = self.palette[i]
Expand Down Expand Up @@ -373,18 +376,20 @@ def reset_plot(_):

return gui

def generate_plot(self) -> go.Figure:
self.df = self.generate_dataframe()
params = self.generate_plot_parameters()
def _generate_sankey(self) -> go.Figure:
"""Generate the Sankey plot based on the currently visible classes."""
self.df = self._generate_dataframe()
# Explicitly return an empty Sankey plot if all classes are hidden to avoid widget update
# errors.
if len(self.df) == 0:
return go.Sankey()

theme = (
self.theme if isinstance(self.theme, themes.Theme) else themes.load_theme(self.theme)
)
params = self._generate_plot_parameters()

node_kwargs = dict(
customdata=params.node_labels,
hovertemplate="<b>%{customdata}</b><extra></extra>",
label=[f"<span style='{theme.label_style}'>{s}</span>" for s in params.label],
label=[f"<span style='{self.theme.label_style}'>{s}</span>" for s in params.label],
color=params.node_palette,
)
link_kwargs = dict(
Expand All @@ -396,18 +401,18 @@ def generate_plot(self) -> go.Figure:
hovertemplate="%{customdata} <extra></extra>",
)

fig = go.FigureWidget(
data=[
go.Sankey(
arrangement="snap",
node={**node_kwargs, **theme.node_kwargs},
link={**link_kwargs, **theme.link_kwargs},
)
]
return go.Sankey(
arrangement="snap",
node={**node_kwargs, **self.theme.node_kwargs},
link={**link_kwargs, **self.theme.link_kwargs},
)

def _generate_figurewidget(self) -> go.FigureWidget:
"""Generate the FigureWidget that wraps the Sankey plot."""
fig = go.FigureWidget(data=[self._generate_sankey()])

fig.update_layout(
title_text=f"<span style='{theme.title_style}'>{self.title}</span>"
title_text=f"<span style='{self.theme.title_style}'>{self.title}</span>"
if self.title
else None,
font_size=16,
Expand Down
7 changes: 4 additions & 3 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ def test_get_year_CORINE():

@pytest.mark.parametrize("dataset", sankee.datasets.datasets, ids=lambda d: d.name)
def test_years(dataset):
assert dataset.years == tuple(dataset.list_years().getInfo())
"""Check that the hard-coded dataset years match the Earth Engine catalog years."""
assert dataset.years == tuple(dataset._list_years().getInfo())


def test_get_unsupported_year():
Expand Down Expand Up @@ -117,8 +118,8 @@ def test_sankify():
title="My plot!",
)

params1 = sankey1.generate_plot_parameters()
params2 = sankey2.generate_plot_parameters()
params1 = sankey1._generate_plot_parameters()
params2 = sankey2._generate_plot_parameters()

for p1, p2 in zip(params1, params2):
assert_series_equal(p1, p2)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@ def sankey():

def test_get_sorted_classes(sankey):
"""Test that classes are correctly sorted."""
assert_series_equal(sankey.get_sorted_classes(), pd.Series([1, 2, 4, 3]), check_names=False)
assert_series_equal(sankey._get_sorted_classes(), pd.Series([1, 2, 4, 3]), check_names=False)


def test_plot_parameters(sankey):
"""Test that plot parameters are generated correctly."""
params = sankey.generate_plot_parameters()
params = sankey._generate_plot_parameters()
node_labels = ["start", "start", "start", "end", "end", "end", "end"]
label = [
"Agriculture",
Expand Down

0 comments on commit 61ef35e

Please sign in to comment.