Skip to content

Commit

Permalink
Compare and Swap Holoviews Extension (#733)
Browse files Browse the repository at this point in the history
* initial fix for comparing and swaping hv enviroment

* add docstring and check for supported backends

* pre-commit

* refactor

* update docstring

* update comments and docstring

* add to internal API

---------

Co-authored-by: Orhan Eroglu <32553057+erogluorhan@users.noreply.github.com>
  • Loading branch information
philipc2 and erogluorhan authored Apr 2, 2024
1 parent 545a18d commit b417ac8
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 15 deletions.
1 change: 1 addition & 0 deletions docs/internal_api/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ Visualization
plot.dataarray_plot._plot_data_as_points
plot.dataarray_plot._polygon_raster
plot.dataarray_plot._point_raster
plot.utils.HoloviewsBackend

Slicing
-------
Expand Down
18 changes: 8 additions & 10 deletions uxarray/plot/dataarray_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

import warnings

import uxarray.plot.utils


def plot(uxda, **kwargs):
"""Default Plotting Method for UxDataArray."""
Expand Down Expand Up @@ -229,8 +231,7 @@ def _point_raster(
# apply projection to coordinates
lon, lat, _ = projection.transform_points(ccrs.PlateCarree(), lon, lat).T

# this will be fixed in #733
hv.extension("bokeh")
uxarray.plot.utils.backend.assign(backend=backend)

# construct a dask dataframe from coordinates and data
point_dict = {"lon": lon, "lat": lat, "var": uxda.data}
Expand All @@ -241,7 +242,6 @@ def _point_raster(

if backend == "matplotlib":
# use holoviews matplotlib backend
hv.extension("matplotlib")
raster = hds_rasterize(
points,
pixel_ratio=pixel_ratio,
Expand All @@ -258,7 +258,6 @@ def _point_raster(
)
elif backend == "bokeh":
# use holoviews bokeh backend
hv.extension("bokeh")
raster = hds_rasterize(
points,
pixel_ratio=pixel_ratio,
Expand Down Expand Up @@ -315,9 +314,10 @@ def _polygon_raster(

hv_polygons = hv.Polygons(gdf, vdims=[uxda.name])

uxarray.plot.utils.backend.assign(backend=backend)

if backend == "matplotlib":
# use holoviews matplotlib backend
hv.extension("matplotlib")
raster = hds_rasterize(
hv_polygons,
pixel_ratio=pixel_ratio,
Expand All @@ -334,7 +334,6 @@ def _polygon_raster(
)
elif backend == "bokeh":
# use holoviews bokeh backend
hv.extension("bokeh")
raster = hds_rasterize(
hv_polygons,
pixel_ratio=pixel_ratio,
Expand Down Expand Up @@ -404,15 +403,14 @@ def polygons(

hv_polygons = hv.Polygons(gdf, vdims=[uxda.name])

uxarray.plot.utils.backend.assign(backend=backend)
if backend == "matplotlib":
# use holoviews matplotlib backend
hv.extension("matplotlib")

return hv_polygons.opts(colorbar=colorbar, cmap=cmap, **kwargs)

elif backend == "bokeh":
# use holoviews bokeh backend
hv.extension("bokeh")
return hv_polygons.opts(
width=width,
height=height,
Expand Down Expand Up @@ -521,9 +519,10 @@ def _plot_data_as_points(
verts = np.column_stack([lon, lat, uxda.values])
hv_points = Points(verts, vdims=["z"])

uxarray.plot.utils.backend.assign(backend=backend)

if backend == "matplotlib":
# use holoviews matplotlib backend
hv.extension("matplotlib")
return hv_points.opts(
color="z",
colorbar=colorbar,
Expand All @@ -535,7 +534,6 @@ def _plot_data_as_points(

elif backend == "bokeh":
# use holoviews bokeh backend
hv.extension("bokeh")
return hv_points.opts(
color="z",
width=width,
Expand Down
11 changes: 6 additions & 5 deletions uxarray/plot/grid_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import numpy as np
import holoviews as hv

import uxarray.plot.utils


def plot(grid: Grid, **kwargs):
"""Default Plotting Method for Grid."""
Expand Down Expand Up @@ -42,15 +44,14 @@ def mesh(

hv_paths = hv.Path(gdf)

uxarray.plot.utils.backend.assign(backend=backend)

if backend == "matplotlib":
# use holoviews matplotlib backend
hv.extension("matplotlib")

return hv_paths.opts(**kwargs)

elif backend == "bokeh":
# use holoviews bokeh backend
hv.extension("bokeh")
return hv_paths.opts(
width=width, height=height, xlabel=xlabel, ylabel=ylabel, **kwargs
)
Expand Down Expand Up @@ -174,14 +175,14 @@ def _plot_coords_as_points(
else:
raise ValueError("Invalid element selected.")

uxarray.plot.utils.backend.assign(backend=backend)

if backend == "matplotlib":
# use holoviews matplotlib backend
hv.extension("matplotlib")
return hv_points.opts(**kwargs)

elif backend == "bokeh":
# use holoviews bokeh backend
hv.extension("bokeh")
return hv_points.opts(
width=width, height=height, xlabel=xlabel, ylabel=ylabel, **kwargs
)
32 changes: 32 additions & 0 deletions uxarray/plot/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import holoviews as hv


class HoloviewsBackend:
"""Utility class to compare and set a HoloViews plotting backend for
visualization."""

def __init__(self):
self.backend = None

def assign(self, backend: str):
"""Assigns a backend for use with HoloViews visualization.
Parameters
----------
backend : str
Plotting backend to use, one of 'matplotlib', 'bokeh'
"""

if backend not in ["bokeh", "matplotlib"]:
raise ValueError(
f"Unsupported backend. Expected one of ['bokeh', 'matplotlib'], but received {backend}"
)

if backend != self.backend:
# only call hv.extension if it needs to be changed
hv.extension(backend)
self.backend = backend


# global reference to holoviews backend utility class
backend = HoloviewsBackend()

0 comments on commit b417ac8

Please sign in to comment.