Skip to content

Commit

Permalink
Support Xarray Plotting Routines (#760)
Browse files Browse the repository at this point in the history
* add support for calling xarray's plotting routines

* add tests

* reset mpl backend

* add matplotlib-inline to enviroments

* fix package name

* generalize implementation
  • Loading branch information
philipc2 authored Apr 17, 2024
1 parent 00a4f89 commit a85db05
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 4 deletions.
1 change: 1 addition & 0 deletions ci/asv.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ dependencies:
- gmpy2
- holoviews
- matplotlib-base
- matplotlib-inline
- netcdf4
- numba
- numpy
Expand Down
1 change: 1 addition & 0 deletions ci/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ dependencies:
- sphinx-design
- nbsphinx
- matplotlib-base
- matplotlib-inline
- shapely
- hvplot
- holoviews
Expand Down
1 change: 1 addition & 0 deletions ci/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ dependencies:
- gmpy2
- holoviews
- matplotlib-base
- matplotlib-inline
- netcdf4
- numba
- numpy
Expand Down
21 changes: 21 additions & 0 deletions test/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,24 @@ def test_clabel(self):
raster_no_clabel = uxds['v1'][0][0].plot.rasterize(method='point')

raster_with_clabel = uxds['v1'][0][0].plot.rasterize(method='point', clabel='Foo')



class TestXarrayMethods(TestCase):

def test_dataset(self):
"""Tests whether a Xarray DataArray method can be called through the
UxDataArray plotting accessor."""
uxds = ux.open_dataset(gridfile_geoflow, datafile_geoflow)

# plot.hist() is an xarray method
assert hasattr(uxds['v1'].plot, 'hist')


def test_dataarray(self):
"""Tests whether a Xarray Dataset method can be called through the
UxDataset plotting accessor."""
uxds = ux.open_dataset(gridfile_geoflow, datafile_geoflow)

# plot.scatter() is an xarray method
assert hasattr(uxds.plot, 'scatter')
41 changes: 37 additions & 4 deletions uxarray/plot/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import functools

import warnings

if TYPE_CHECKING:
from uxarray.core.dataset import UxDataset
Expand All @@ -12,6 +11,7 @@

import uxarray.plot.grid_plot as grid_plot
import uxarray.plot.dataarray_plot as dataarray_plot
import uxarray.plot.utils

import cartopy.crs as ccrs

Expand Down Expand Up @@ -269,6 +269,22 @@ def __init__(self, uxda: UxDataArray) -> None:
def __call__(self, **kwargs) -> Any:
return dataarray_plot.plot(self._uxda, **kwargs)

def __getattr__(self, name: str) -> Any:
"""When a function that isn't part of the class is invoked (i.e.
uxda.plot.hist), an attempt is made to try and call Xarray's
implementation of that function if it exsists."""

# reference to xr.DataArray.plot accessor
xarray_plot_accessor = super(type(self._uxda), self._uxda).plot

if hasattr(xarray_plot_accessor, name):
# call xarray plot method if it exists
# use inline backend to reset configuration if holoviz methods were called before
uxarray.plot.utils.backend.reset_mpl_backend()
return getattr(xarray_plot_accessor, name)
else:
raise AttributeError(f"Unsupported Plotting Method: '{name}'")

@functools.wraps(dataarray_plot.datashade)
def datashade(
self,
Expand Down Expand Up @@ -468,7 +484,24 @@ def __init__(self, uxds: UxDataset) -> None:
self._uxds = uxds

def __call__(self, **kwargs) -> Any:
warnings.warn(
"Plotting for UxDataset instances not yet supported. Did you mean to plot a data variable, i.e. uxds['data_variable'].plot()"
raise ValueError(
"UxDataset.plot cannot be called directly. Use an explicit plot method, "
"e.g uxds.plot.scatter(...)"
)
pass

def __getattr__(self, name: str) -> Any:
"""When a function that isn't part of the class is invoked (i.e.
uxds.plot.scatter), an attempt is made to try and call Xarray's
implementation of that function if it exists."""

# reference to xr.Dataset.plot accessor
xarray_plot_accessor = super(type(self._uxds), self._uxds).plot

if hasattr(xarray_plot_accessor, name):
# call xarray plot method if it exists
# # use inline backend to reset configuration if holoviz methods were called before
uxarray.plot.utils.backend.reset_mpl_backend()

return getattr(xarray_plot_accessor, name)
else:
raise AttributeError(f"Unsupported Plotting Method: '{name}'")
6 changes: 6 additions & 0 deletions uxarray/plot/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import holoviews as hv
import matplotlib as mpl


class HoloviewsBackend:
Expand All @@ -7,6 +8,7 @@ class HoloviewsBackend:

def __init__(self):
self.backend = None
self.matplotlib_backend = mpl.get_backend()

def assign(self, backend: str):
"""Assigns a backend for use with HoloViews visualization.
Expand All @@ -27,6 +29,10 @@ def assign(self, backend: str):
hv.extension(backend)
self.backend = backend

def reset_mpl_backend(self):
"""Resets the default backend for the ``matplotlib`` module."""
mpl.use(self.matplotlib_backend)


# global reference to holoviews backend utility class
backend = HoloviewsBackend()

0 comments on commit a85db05

Please sign in to comment.