diff --git a/nlmod/dims/resample.py b/nlmod/dims/resample.py index 50a17cfe..6efa7780 100644 --- a/nlmod/dims/resample.py +++ b/nlmod/dims/resample.py @@ -407,7 +407,7 @@ def vertex_da_to_ds(da, ds, method="nearest"): # when there are more dimensions than icell2d z = [] if method == "nearest": - # geneterate the tree only once, to increase speed + # generate the tree only once, to increase speed tree = cKDTree(points) _, i = tree.query(xi) dims = np.array(da.dims) diff --git a/nlmod/plot/plot.py b/nlmod/plot/plot.py index f581ff78..bfec9cf2 100644 --- a/nlmod/plot/plot.py +++ b/nlmod/plot/plot.py @@ -282,20 +282,21 @@ def geotop_lithok_in_cross_section( return cs -def _get_figure(ax=None, da=None, ds=None, figsize=None, rotated=True): +def _get_figure(ax=None, da=None, ds=None, figsize=None, rotated=True, extent=None): # figure if ax is not None: f = ax.figure else: - if ds is None: - extent = [ - da.x.values.min(), - da.x.values.max(), - da.y.values.min(), - da.y.values.max(), - ] - else: - extent = get_extent(ds, rotated=rotated) + if extent is None: + if ds is None: + extent = [ + da.x.values.min(), + da.x.values.max(), + da.y.values.min(), + da.y.values.max(), + ] + else: + extent = get_extent(ds, rotated=rotated) if figsize is None: figsize = get_figsize(extent) @@ -337,6 +338,7 @@ def map_array( background=False, figsize=None, animate=False, + **kwargs, ): # get data if isinstance(da, str): @@ -377,7 +379,9 @@ def map_array( else: t = None - f, ax = _get_figure(ax=ax, da=da, ds=ds, figsize=figsize, rotated=rotated) + f, ax = _get_figure( + ax=ax, da=da, ds=ds, figsize=figsize, rotated=rotated, extent=extent + ) # get normalization if vmin/vmax are passed if vmin is not None or vmax is not None: @@ -388,10 +392,6 @@ def map_array( da, ds=ds, cmap=cmap, alpha=alpha, norm=norm, ax=ax, rotated=rotated ) - # set extent - if extent is not None: - ax.axis(extent) - # bgmap if background: add_background_map(ax, map_provider="nlmaps.water", alpha=0.5) @@ -407,6 +407,10 @@ def map_array( raise ValueError("Plotting modelgrid requires model Dataset!") modelgrid(ds, ax=ax, lw=0.25, alpha=0.5, color="k") + # set extent + if extent is not None: + ax.axis(extent) + # axes properties if ilay is not None: title += f" (layer={layer})" @@ -421,7 +425,7 @@ def map_array( divider = make_axes_locatable(ax) if colorbar: cax = divider.append_axes("right", size="5%", pad=0.1) - cbar = f.colorbar(pc, cax=cax) + cbar = f.colorbar(pc, cax=cax, extend=kwargs.pop("extend", "neither")) if levels is not None: cbar.set_ticks(levels) cbar.set_label(colorbar_label)