diff --git a/ecco_v4_py/calc_meridional_trsp.py b/ecco_v4_py/calc_meridional_trsp.py index 7ee581f..9cf6645 100644 --- a/ecco_v4_py/calc_meridional_trsp.py +++ b/ecco_v4_py/calc_meridional_trsp.py @@ -21,7 +21,8 @@ RHO_CONST = 1029 HEAT_CAPACITY = 4000 -def calc_meridional_vol_trsp(ds,lat_vals,basin_name=None,grid=None): +def calc_meridional_vol_trsp(ds,lat_vals, + basin_name=None,coords=None,grid=None): """Compute volumetric transport across latitude band in Sverdrups Parameters @@ -34,6 +35,9 @@ def calc_meridional_vol_trsp(ds,lat_vals,basin_name=None,grid=None): denote ocean basin over which to compute streamfunction If not specified, compute global quantity see get_basin.get_available_basin_names for options + coords : xarray Dataset + separate dataset containing the coordinate information + YC, Z, drF, dyG, dxG, optionally maskW, maskS grid : xgcm Grid object, optional denotes LLC90 operations for xgcm, see ecco_utils.get_llc_grid see also the [xgcm documentation](https://xgcm.readthedocs.io/en/latest/grid_topology.html) @@ -52,12 +56,14 @@ def calc_meridional_vol_trsp(ds,lat_vals,basin_name=None,grid=None): dimensions: 'time' (if provided), 'lat', and 'k' """ - x_vol = ds['UVELMASS'] * ds['drF'] * ds['dyG'] - y_vol = ds['VVELMASS'] * ds['drF'] * ds['dxG'] + coords = _parse_coords(ds,coords,['Z','YC','drF','dyG','dxG']) + + x_vol = ds['UVELMASS'] * coords['drF'] * coords['dyG'] + y_vol = ds['VVELMASS'] * coords['drF'] * coords['dxG'] # Computes salt transport in m^3/s at each depth level ds_out = meridional_trsp_at_depth(x_vol,y_vol, - cds=ds, + coords=coords, lat_vals=lat_vals, basin_name=basin_name, grid=grid) @@ -76,7 +82,8 @@ def calc_meridional_vol_trsp(ds,lat_vals,basin_name=None,grid=None): return ds_out -def calc_meridional_heat_trsp(ds,lat_vals,basin_name=None,grid=None): +def calc_meridional_heat_trsp(ds,lat_vals, + basin_name=None,coords=None,grid=None): """Compute heat transport across latitude band in Petwatts see calc_meridional_vol_trsp for argument documentation. The only differences are: @@ -85,6 +92,9 @@ def calc_meridional_heat_trsp(ds,lat_vals,basin_name=None,grid=None): ---------- ds : xarray Dataset must contain fields 'ADVx_TH','ADVy_TH','DFxE_TH','DFyE_TH' + coords : xarray Dataset, optional + in case coordinates are in a separate dataset + only needs field 'YC' and optionally maskW, maskS Returns ------- @@ -99,12 +109,14 @@ def calc_meridional_heat_trsp(ds,lat_vals,basin_name=None,grid=None): dimensions: 'time' (if provided), 'lat', and 'k' """ + coords = _parse_coords(ds,coords,['Z','YC']) + x_heat = ds['ADVx_TH'] + ds['DFxE_TH'] y_heat = ds['ADVy_TH'] + ds['DFyE_TH'] # Computes heat transport in degC * m^3/s at each depth level ds_out = meridional_trsp_at_depth(x_heat,y_heat, - cds=ds, + coords=coords, lat_vals=lat_vals, basin_name=basin_name, grid=grid) @@ -122,7 +134,8 @@ def calc_meridional_heat_trsp(ds,lat_vals,basin_name=None,grid=None): return ds_out -def calc_meridional_salt_trsp(ds,lat_vals,basin_name=None,grid=None): +def calc_meridional_salt_trsp(ds,lat_vals, + basin_name=None,coords=None,grid=None): """Compute salt transport across latitude band in psu * Sv see calc_meridional_vol_trsp for argument documentation. The only differences are: @@ -131,6 +144,9 @@ def calc_meridional_salt_trsp(ds,lat_vals,basin_name=None,grid=None): ---------- ds : xarray Dataset must contain fields 'ADVx_SLT','ADVy_SLT','DFxE_SLT','DFyE_SLT' + coords : xarray Dataset, optional + in case coordinates are in a separate dataset + only needs field 'YC' and optionally maskW, maskS Returns ------- @@ -145,12 +161,14 @@ def calc_meridional_salt_trsp(ds,lat_vals,basin_name=None,grid=None): dimensions: 'time' (if provided), 'lat', and 'k' """ + coords = _parse_coords(ds,coords,['Z','YC']) + x_salt = ds['ADVx_SLT'] + ds['DFxE_SLT'] y_salt = ds['ADVy_SLT'] + ds['DFyE_SLT'] # Computes salt transport in psu * m^3/s at each depth level ds_out = meridional_trsp_at_depth(x_salt,y_salt, - cds=ds, + coords=coords, lat_vals=lat_vals, basin_name=basin_name, grid=grid) @@ -170,7 +188,7 @@ def calc_meridional_salt_trsp(ds,lat_vals,basin_name=None,grid=None): # --------------------------------------------------------------------- -def meridional_trsp_at_depth(xfld, yfld, lat_vals, cds, +def meridional_trsp_at_depth(xfld, yfld, lat_vals, coords, basin_name=None, grid=None, less_output=True): """ Compute transport of vector quantity at each depth level @@ -182,8 +200,8 @@ def meridional_trsp_at_depth(xfld, yfld, lat_vals, cds, 3D spatial (+ time, optional) field at west and south grid cell edges lat_vals : float or list latitude value(s) specifying where to compute transport - cds : xarray Dataset - with all LLC90 coordinates, including: maskW, maskS, YC + coords : xarray Dataset + only needs YC, and optionally maskW, maskS (defining wet points) basin_name : string, optional denote ocean basin over which to compute streamfunction If not specified, compute global quantity @@ -203,14 +221,14 @@ def meridional_trsp_at_depth(xfld, yfld, lat_vals, cds, """ if grid is None: - grid = get_llc_grid(cds) + grid = get_llc_grid(coords) # Initialize empty DataArray with coordinates and dims - ds_out = _initialize_trsp_data_array(cds, lat_vals) + ds_out = _initialize_trsp_data_array(coords, lat_vals) # Get basin mask - maskW = cds['maskW'] if 'maskW' in cds else xr.ones_like(xfld) - maskS = cds['maskS'] if 'maskS' in cds else xr.ones_like(yfld) + maskW = coords['maskW'] if 'maskW' in coords else xr.ones_like(xfld) + maskS = coords['maskS'] if 'maskS' in coords else xr.ones_like(yfld) if basin_name is not None: maskW = get_basin_mask(basin_name,maskW) maskS = get_basin_mask(basin_name,maskS) @@ -225,7 +243,7 @@ def meridional_trsp_at_depth(xfld, yfld, lat_vals, cds, print ('calculating transport for latitutde ', lat) # Compute mask for particular latitude band - lat_maskW, lat_maskS = vector_calc.get_latitude_masks(lat, cds['YC'], grid) + lat_maskW, lat_maskS = vector_calc.get_latitude_masks(lat, coords['YC'], grid) # Sum horizontally lat_trsp_x = (tmp_x * lat_maskW).sum(dim=['i_g','j','tile']) @@ -236,12 +254,12 @@ def meridional_trsp_at_depth(xfld, yfld, lat_vals, cds, return ds_out -def _initialize_trsp_data_array(cds, lat_vals): +def _initialize_trsp_data_array(coords, lat_vals): """Create an xarray DataArray with time, depth, and latitude dims Parameters ---------- - cds : xarray Dataset + coords : xarray Dataset contains LLC coordinates 'k' and (optionally) 'time' lat_vals : int or array of ints latitude value(s) rounded to the nearest degree @@ -258,19 +276,31 @@ def _initialize_trsp_data_array(cds, lat_vals): the original depth coordinate """ - coords = OrderedDict() - dims = () lat_vals = np.array(lat_vals) if isinstance(lat_vals,list) else lat_vals lat_vals = np.array([lat_vals]) if np.isscalar(lat_vals) else lat_vals lat_vals = xr.DataArray(lat_vals,coords={'lat':lat_vals},dims=('lat',)) - xda = xr.zeros_like(lat_vals*cds['k']) - xda = xda if 'time' not in cds.dims else xda.broadcast_like(cds['time']) + xda = xr.zeros_like(coords['k']*lat_vals) + xda = xda if 'time' not in coords.dims else xda.broadcast_like(coords['time']).copy() # Convert to dataset to add Z coordinate xds = xda.to_dataset(name='trsp_z') - xds['Z'] = cds['Z'] + xds['Z'] = coords['Z'] xds = xds.set_coords('Z') return xds +def _parse_coords(ds,coords,coordlist): + if coords is not None: + return coords + else: + for f in set(['maskW','maskS']).intersection(ds.reset_coords().keys()): + coordlist.append(f) + + if 'time' in ds.dims: + coordlist.append('time') + + dsout = ds[coordlist] + if 'domain' in ds.attrs: + dsout.attrs['domain'] = ds.attrs['domain'] + return dsout diff --git a/ecco_v4_py/calc_section_trsp.py b/ecco_v4_py/calc_section_trsp.py index a61f534..c5b2f39 100644 --- a/ecco_v4_py/calc_section_trsp.py +++ b/ecco_v4_py/calc_section_trsp.py @@ -11,6 +11,7 @@ from collections import OrderedDict from .ecco_utils import get_llc_grid +from .calc_meridional_trsp import _parse_coords from .get_section_masks import get_available_sections, \ get_section_endpoints, get_section_line_masks @@ -28,7 +29,7 @@ def calc_section_vol_trsp(ds, pt1=None, pt2=None, section_name=None, maskW=None, maskS=None, - grid=None): + coords=None, grid=None): """Compute volumetric transport across section in Sverdrups There are 3 ways to call this function: @@ -59,7 +60,7 @@ def calc_section_vol_trsp(ds, Parameters ---------- ds : xarray Dataset - must contain UVELMASS,VVELMASS, drF, dyG, dxG + must contain UVELMASS,VVELMASS and fields listed under coords below pt1, pt2 : list or tuple with two floats, optional end points for section line as [lon lat] or (lon, lat) maskW, maskS : xarray DataArray, optional @@ -68,6 +69,9 @@ def calc_section_vol_trsp(ds, name for the section. If predefined value, section mask is defined via get_section_endpoints otherwise, adds name to returned DataArray + coords : xarray Dataset + separate dataset containing the coordinate information + XC, YC, drF, Z, dyG, dxG, optionally maskW, maskS grid : xgcm Grid object, optional denotes LLC90 operations for xgcm, see ecco_utils.get_llc_grid see also the [xgcm documentation](https://xgcm.readthedocs.io/en/latest/grid_topology.html) @@ -87,16 +91,18 @@ def calc_section_vol_trsp(ds, and the section_name as an attribute if it is provided """ - maskW, maskS = _parse_section_trsp_inputs(ds,pt1,pt2,maskW,maskS,section_name) + coords = _parse_coords(ds,coords,['Z','YC','XC','drF','dyG','dxG']) + + maskW, maskS = _parse_section_trsp_inputs(coords,pt1,pt2,maskW,maskS,section_name, + grid=grid) # Define volumetric transport - x_vol = ds['UVELMASS'] * ds['drF'] * ds['dyG'] - y_vol = ds['VVELMASS'] * ds['drF'] * ds['dxG'] + x_vol = ds['UVELMASS'] * coords['drF'] * coords['dyG'] + y_vol = ds['VVELMASS'] * coords['drF'] * coords['dxG'] # Computes salt transport in m^3/s at each depth level ds_out = section_trsp_at_depth(x_vol,y_vol,maskW,maskS, - cds=ds.coords.to_dataset(), - grid=grid) + coords=coords) # Rename to useful data array name ds_out = ds_out.rename({'trsp_z': 'vol_trsp_z'}) @@ -121,7 +127,7 @@ def calc_section_heat_trsp(ds, pt1=None, pt2=None, section_name=None, maskW=None, maskS=None, - grid=None): + coords=None,grid=None): """Compute heat transport across section in PW Inputs and usage are same as calc_section_vol_trsp. The only differences are: @@ -130,6 +136,8 @@ def calc_section_heat_trsp(ds, ---------- ds : xarray Dataset must contain ADVx_TH, ADVy_TH, DFxe_TH, DFyE_TH + coords : xarray Dataset, optional + must contain XC, YC, Z, optionally maskW, maskS Returns ------- @@ -146,7 +154,10 @@ def calc_section_heat_trsp(ds, and the section_name as an attribute if it is provided """ - maskW, maskS = _parse_section_trsp_inputs(ds,pt1,pt2,maskW,maskS,section_name) + coords = _parse_coords(ds,coords,['Z','YC','XC']) + + maskW, maskS = _parse_section_trsp_inputs(coords,pt1,pt2,maskW,maskS,section_name, + grid=grid) # Define heat transport x_heat = ds['ADVx_TH'] + ds['DFxE_TH'] @@ -154,8 +165,7 @@ def calc_section_heat_trsp(ds, # Computes salt transport in degC * m^3/s at each depth level ds_out = section_trsp_at_depth(x_heat,y_heat,maskW,maskS, - cds=ds.coords.to_dataset(), - grid=grid) + coords=coords) # Rename to useful data array name ds_out = ds_out.rename({'trsp_z': 'heat_trsp_z'}) @@ -180,7 +190,7 @@ def calc_section_salt_trsp(ds, pt1=None, pt2=None, section_name=None, maskW=None, maskS=None, - grid=None): + coords=None, grid=None): """Compute salt transport across section in psu*Sv Inputs and usage are same as calc_section_vol_trsp. The only differences are: @@ -189,6 +199,8 @@ def calc_section_salt_trsp(ds, ---------- ds : xarray Dataset must contain ADVx_SLT, ADVy_SLT, DFxe_SLT, DFyE_SLT + coords : xarray Dataset, optional + must contain XC, YC, Z, optionally maskW, maskS Returns ------- @@ -205,7 +217,10 @@ def calc_section_salt_trsp(ds, and the section_name as an attribute if it is provided """ - maskW, maskS = _parse_section_trsp_inputs(ds,pt1,pt2,maskW,maskS,section_name) + coords = _parse_coords(ds,coords,['Z','YC','XC']) + + maskW, maskS = _parse_section_trsp_inputs(coords,pt1,pt2,maskW,maskS,section_name, + grid=grid) # Define salt transport x_salt = ds['ADVx_SLT'] + ds['DFxE_SLT'] @@ -213,8 +228,7 @@ def calc_section_salt_trsp(ds, # Computes salt transport in psu * m^3/s at each depth level ds_out = section_trsp_at_depth(x_salt,y_salt,maskW,maskS, - cds=ds.coords.to_dataset(), - grid=grid) + coords=coords) # Rename to useful data array name ds_out = ds_out.rename({'trsp_z': 'salt_trsp_z'}) @@ -239,8 +253,8 @@ def calc_section_salt_trsp(ds, # Main function for computing standard transport quantities # ------------------------------------------------------------------------------- -def section_trsp_at_depth(xfld, yfld, maskW, maskS, cds, - grid=None): +def section_trsp_at_depth(xfld, yfld, maskW, maskS, + coords=None): """ Compute transport of vector quantity at each depth level across latitude(s) defined in lat_vals @@ -251,10 +265,8 @@ def section_trsp_at_depth(xfld, yfld, maskW, maskS, cds, 3D spatial (+ time, optional) field at west and south grid cell edge maskW, maskS : xarray DataArray defines the section to define transport across - cds : xarray Dataset - with all LLC90 coordinates, including: maskW/S, YC - grid : xgcm Grid object, optional - denotes LLC90 operations for xgcm, see utils.get_llc_grid + coords : xarray Dataset, optional + include if providing maskW/S (i.e. wet point masks in addition to line masks) Returns ------- @@ -266,15 +278,14 @@ def section_trsp_at_depth(xfld, yfld, maskW, maskS, cds, and 'k' (depth) """ - if grid is None: - grid = get_llc_grid(cds) - # Initialize empty DataArray with coordinates and dims - ds_out = _initialize_section_trsp_data_array(cds) + coords = coords if coords is not None else xfld.to_dataset(name='xfld') + ds_out = _initialize_section_trsp_data_array(coords) # Apply section mask and sum horizontally - maskW = maskW.where(cds['maskW']) if 'maskW' in cds else maskW - maskS = maskS.where(cds['maskS']) if 'maskS' in cds else maskS + # if wet point mask in coords, use it + maskW = maskW.where(coords['maskW']) if 'maskW' in coords else maskW + maskS = maskS.where(coords['maskS']) if 'maskS' in coords else maskS sec_trsp_x = (xfld * maskW).sum(dim=['i_g','j','tile']) sec_trsp_y = (yfld * maskS).sum(dim=['i','j_g','tile']) @@ -291,7 +302,7 @@ def section_trsp_at_depth(xfld, yfld, maskW, maskS, cds, # Helper functions for the computing volume, heat, and salt transport # ------------------------------------------------------------------------------- -def _parse_section_trsp_inputs(ds,pt1,pt2,maskW,maskS,section_name): +def _parse_section_trsp_inputs(ds,pt1,pt2,maskW,maskS,section_name,grid=None): """Handle inputs for computing volume, heat, or salt transport across a section @@ -331,22 +342,21 @@ def _parse_section_trsp_inputs(ds,pt1,pt2,maskW,maskS,section_name): if use_endpoints or use_masks: raise TypeError('Cannot provide more than one method for defining section') pt1, pt2 = get_section_endpoints(section_name) - _, maskW, maskS = get_section_line_masks(pt1, pt2, ds) else: # Secondly, try to use endpoints or mask if use_endpoints and use_masks: raise TypeError('Cannot provide more than one method for defining section') - elif use_endpoints: - _, maskW, maskS = get_section_line_masks(pt1, pt2, ds) + if not use_masks: + _, maskW, maskS = get_section_line_masks(pt1, pt2, ds, grid=grid) return maskW, maskS -def _initialize_section_trsp_data_array(cds): +def _initialize_section_trsp_data_array(coords): """Create an xarray DataArray with time, depth, and latitude dims Parameters ---------- - cds : xarray Dataset + coords : xarray Dataset contains LLC coordinates 'k' and (optionally) 'time' Returns @@ -360,15 +370,12 @@ def _initialize_section_trsp_data_array(cds): the original depth coordinate """ - coords = OrderedDict() - dims = () - - xda = xr.zeros_like(cds['k']) - xda = xda if 'time' not in cds.dims else xda.broadcast_like(cds['time']) + xda = xr.zeros_like(coords['k']) + xda = xda if 'time' not in coords.dims else xda.broadcast_like(coords['time']).copy() # Convert to dataset to add Z coordinate xds = xda.to_dataset(name='trsp_z') - xds['Z'] = cds['Z'] + xds['Z'] = coords['Z'] xds = xds.set_coords('Z') return xds diff --git a/ecco_v4_py/calc_stf.py b/ecco_v4_py/calc_stf.py index 181260e..65f88a8 100644 --- a/ecco_v4_py/calc_stf.py +++ b/ecco_v4_py/calc_stf.py @@ -6,14 +6,15 @@ import numpy as np from .ecco_utils import get_llc_grid -from .calc_meridional_trsp import meridional_trsp_at_depth +from .calc_meridional_trsp import _parse_coords, meridional_trsp_at_depth from .calc_section_trsp import _parse_section_trsp_inputs, section_trsp_at_depth # Define constants METERS_CUBED_TO_SVERDRUPS = 10**-6 -def calc_meridional_stf(ds,lat_vals,doFlip=True,basin_name=None,grid=None): - """Compute the meridional overturning streamfunction in Sverdrups +def calc_meridional_stf(ds,lat_vals,doFlip=True, + basin_name=None,coords=None,grid=None): + """Compute the meridional overturning streamfunction in Sverdrups at specified latitude(s) Parameters @@ -24,12 +25,15 @@ def calc_meridional_stf(ds,lat_vals,doFlip=True,basin_name=None,grid=None): latitude value(s) rounded to the nearest degree specifying where to compute overturning streamfunction doFlip : logical, optional - True: integrate from "bottom" by flipping Z dimension before cumsum(), + True: integrate from "bottom" by flipping Z dimension before cumsum(), then multiply by -1. False: flip neither dim nor sign. basin_name : string, optional denote ocean basin over which to compute streamfunction If not specified, compute global quantity see utils.get_available_basin_names for options + coords : xarray Dataset + separate dataset containing the coordinate information + YC, drF, dyG, dxG, optionally maskW, maskS grid : xgcm Grid object, optional denotes LLC90 operations for xgcm, see ecco_utils.get_llc_grid see also the [xgcm documentation](https://xgcm.readthedocs.io/en/latest/grid_topology.html) @@ -49,15 +53,18 @@ def calc_meridional_stf(ds,lat_vals,doFlip=True,basin_name=None,grid=None): with dimensions 'time' (if in given dataset), 'lat', and 'k' """ + # get coords + coords = _parse_coords(ds,coords,['Z','YC','drF','dyG','dxG']) + # Compute volume transport - trsp_x = ds['UVELMASS'] * ds['drF'] * ds['dyG'] - trsp_y = ds['VVELMASS'] * ds['drF'] * ds['dxG'] + trsp_x = ds['UVELMASS'] * coords['drF'] * coords['dyG'] + trsp_y = ds['VVELMASS'] * coords['drF'] * coords['dxG'] # Creates an empty streamfunction - ds_out = meridional_trsp_at_depth(trsp_x, trsp_y, - lat_vals=lat_vals, - cds=ds.coords.to_dataset(), - basin_name=basin_name, + ds_out = meridional_trsp_at_depth(trsp_x, trsp_y, + lat_vals=lat_vals, + coords=coords, + basin_name=basin_name, grid=grid) psi_moc = ds_out['trsp_z'].copy(deep=True) @@ -66,9 +73,9 @@ def calc_meridional_stf(ds,lat_vals,doFlip=True,basin_name=None,grid=None): if doFlip: psi_moc = psi_moc.isel(k=slice(None,None,-1)) - # Should this be done with a grid object??? + # Should this be done with a grid object??? psi_moc = psi_moc.cumsum(dim='k') - + if doFlip: psi_moc = -1 * psi_moc.isel(k=slice(None,None,-1)) @@ -90,23 +97,24 @@ def calc_meridional_stf(ds,lat_vals,doFlip=True,basin_name=None,grid=None): return ds_out -def calc_section_stf(ds, - pt1=None, pt2=None, +def calc_section_stf(ds, + pt1=None, pt2=None, section_name=None, maskW=None, maskS=None, - doFlip=True,grid=None): - """Compute the overturning streamfunction in plane normal to + doFlip=True,coords=None,grid=None): + """Compute the overturning streamfunction in plane normal to section defined by pt1 and pt2 in depth space - See calc_section_trsp.calc_section_vol_trsp for the various ways - to call this function + See calc_section_trsp.calc_section_vol_trsp for the various ways + to call this function All inputs are the same except: Parameters ---------- - ds : xarray DataSet - must contain UVELMASS,VVELMASS, drF, dyG, dxG + doFlip : logical, optional + True: integrate from "bottom" by flipping Z dimension before cumsum(), + then multiply by -1. False: flip neither dim nor sign. Returns ------- @@ -126,17 +134,19 @@ def calc_section_stf(ds, and the section_name as an attribute if it is provided """ + coords = _parse_coords(ds,coords,['Z','YC','XC','drF','dyG','dxG']) + # Compute volume transport - trsp_x = ds['UVELMASS'] * ds['drF'] * ds['dyG'] - trsp_y = ds['VVELMASS'] * ds['drF'] * ds['dxG'] + trsp_x = ds['UVELMASS'] * coords['drF'] * coords['dyG'] + trsp_y = ds['VVELMASS'] * coords['drF'] * coords['dxG'] - maskW, maskS = _parse_section_trsp_inputs(ds,pt1,pt2,maskW,maskS,section_name) + maskW, maskS = _parse_section_trsp_inputs(coords,pt1,pt2,maskW,maskS,section_name, + grid=grid) # Creates an empty streamfunction ds_out = section_trsp_at_depth(trsp_x, trsp_y, - maskW, maskS, - cds=ds.coords.to_dataset(), - grid=grid) + maskW, maskS, + coords=coords) psi_moc = ds_out['trsp_z'].copy(deep=True) @@ -144,9 +154,9 @@ def calc_section_stf(ds, if doFlip: psi_moc = psi_moc.isel(k=slice(None,None,-1)) - # Should this be done with a grid object??? + # Should this be done with a grid object??? psi_moc = psi_moc.cumsum(dim='k') - + if doFlip: psi_moc = -1 * psi_moc.isel(k=slice(None,None,-1)) diff --git a/ecco_v4_py/ecco_utils.py b/ecco_v4_py/ecco_utils.py index 2d3c18d..d5227dc 100644 --- a/ecco_v4_py/ecco_utils.py +++ b/ecco_v4_py/ecco_utils.py @@ -283,7 +283,7 @@ def months2days(nmon=288, baseyear=1992, basemon=1): #%% -def get_llc_grid(ds): +def get_llc_grid(ds,domain='global'): """ Define xgcm Grid object for the LLC grid See example usage in the xgcm documentation: @@ -302,40 +302,62 @@ def get_llc_grid(ds): """ + if 'domain' in ds.attrs: + domain = ds.attrs['domain'] + + if domain == 'global': + # Establish grid topology + tile_connections = {'tile': { + 0: {'X': ((12, 'Y', False), (3, 'X', False)), + 'Y': (None, (1, 'Y', False))}, + 1: {'X': ((11, 'Y', False), (4, 'X', False)), + 'Y': ((0, 'Y', False), (2, 'Y', False))}, + 2: {'X': ((10, 'Y', False), (5, 'X', False)), + 'Y': ((1, 'Y', False), (6, 'X', False))}, + 3: {'X': ((0, 'X', False), (9, 'Y', False)), + 'Y': (None, (4, 'Y', False))}, + 4: {'X': ((1, 'X', False), (8, 'Y', False)), + 'Y': ((3, 'Y', False), (5, 'Y', False))}, + 5: {'X': ((2, 'X', False), (7, 'Y', False)), + 'Y': ((4, 'Y', False), (6, 'Y', False))}, + 6: {'X': ((2, 'Y', False), (7, 'X', False)), + 'Y': ((5, 'Y', False), (10, 'X', False))}, + 7: {'X': ((6, 'X', False), (8, 'X', False)), + 'Y': ((5, 'X', False), (10, 'Y', False))}, + 8: {'X': ((7, 'X', False), (9, 'X', False)), + 'Y': ((4, 'X', False), (11, 'Y', False))}, + 9: {'X': ((8, 'X', False), None), + 'Y': ((3, 'X', False), (12, 'Y', False))}, + 10: {'X': ((6, 'Y', False), (11, 'X', False)), + 'Y': ((7, 'Y', False), (2, 'X', False))}, + 11: {'X': ((10, 'X', False), (12, 'X', False)), + 'Y': ((8, 'Y', False), (1, 'X', False))}, + 12: {'X': ((11, 'X', False), None), + 'Y': ((9, 'Y', False), (0, 'X', False))} + }} + + grid = xgcm.Grid(ds, + periodic=False, + face_connections=tile_connections + ) + elif domain == 'aste': + tile_connections = {'tile':{ + 0:{'X':((5,'Y',False),None), + 'Y':(None,(1,'Y',False))}, + 1:{'X':((4,'Y',False),None), + 'Y':((0,'Y',False),(2,'X',False))}, + 2:{'X':((1,'Y',False),(3,'X',False)), + 'Y':(None,(4,'X',False))}, + 3:{'X':((2,'X',False),None), + 'Y':(None,None)}, + 4:{'X':((2,'Y',False),(5,'X',False)), + 'Y':(None,(1,'X',False))}, + 5:{'X':((4,'X',False),None), + 'Y':(None,(0,'X',False))} + }} + grid = xgcm.Grid(ds,periodic=False,face_connections=tile_connections) + else: + raise TypeError(f'Domain {domain} not recognized') - # Establish grid topology - tile_connections = {'tile': { - 0: {'X': ((12, 'Y', False), (3, 'X', False)), - 'Y': (None, (1, 'Y', False))}, - 1: {'X': ((11, 'Y', False), (4, 'X', False)), - 'Y': ((0, 'Y', False), (2, 'Y', False))}, - 2: {'X': ((10, 'Y', False), (5, 'X', False)), - 'Y': ((1, 'Y', False), (6, 'X', False))}, - 3: {'X': ((0, 'X', False), (9, 'Y', False)), - 'Y': (None, (4, 'Y', False))}, - 4: {'X': ((1, 'X', False), (8, 'Y', False)), - 'Y': ((3, 'Y', False), (5, 'Y', False))}, - 5: {'X': ((2, 'X', False), (7, 'Y', False)), - 'Y': ((4, 'Y', False), (6, 'Y', False))}, - 6: {'X': ((2, 'Y', False), (7, 'X', False)), - 'Y': ((5, 'Y', False), (10, 'X', False))}, - 7: {'X': ((6, 'X', False), (8, 'X', False)), - 'Y': ((5, 'X', False), (10, 'Y', False))}, - 8: {'X': ((7, 'X', False), (9, 'X', False)), - 'Y': ((4, 'X', False), (11, 'Y', False))}, - 9: {'X': ((8, 'X', False), None), - 'Y': ((3, 'X', False), (12, 'Y', False))}, - 10: {'X': ((6, 'Y', False), (11, 'X', False)), - 'Y': ((7, 'Y', False), (2, 'X', False))}, - 11: {'X': ((10, 'X', False), (12, 'X', False)), - 'Y': ((8, 'Y', False), (1, 'X', False))}, - 12: {'X': ((11, 'X', False), None), - 'Y': ((9, 'Y', False), (0, 'X', False))} - }} - - grid = xgcm.Grid(ds, - periodic=False, - face_connections=tile_connections - ) return grid diff --git a/ecco_v4_py/get_basin.py b/ecco_v4_py/get_basin.py index e169a1a..2305325 100644 --- a/ecco_v4_py/get_basin.py +++ b/ecco_v4_py/get_basin.py @@ -21,6 +21,7 @@ def get_basin_mask(basin_name, mask, basin_path=os.path.join('..','binary_data')): """Return mask for ocean basin. Note: This mirrors gcmfaces/ecco_v4/v4_basin.m + And this only works for the global LLC90 domain Parameters ---------- @@ -51,6 +52,8 @@ def get_basin_mask(basin_name, mask, mask with values at cell centers, 1's for denoted ocean basin dimensions are the same as input field """ + if 'tile' not in mask.dims or len(mask.tile)!=13 or mask.shape[-1]!=90 or mask.shape[-2]!=90: + raise NotImplementedError("Basin masks only available for global LLC90 domain") if type(basin_name) is not list: basin_name = [basin_name] diff --git a/ecco_v4_py/get_section_masks.py b/ecco_v4_py/get_section_masks.py index 497738a..df9dcb1 100644 --- a/ecco_v4_py/get_section_masks.py +++ b/ecco_v4_py/get_section_masks.py @@ -22,7 +22,7 @@ def get_section_endpoints(section_name): pt1 = [-68, -54] pt2 = [-63, -66] - These sections mirror the gcmfaces definitions, see + These sections mirror the gcmfaces definitions, see gcmfaces/gcmfaces_calc/gcmfaces_lines_pairs.m Parameters @@ -35,9 +35,9 @@ def get_section_endpoints(section_name): pt1, pt2 : array_like array with two values, [lon, lat] of each endpoint - or + or - None + None if section_name is not in the pre-defined list of sections """ @@ -182,11 +182,11 @@ def get_available_sections(): return section_list # ------------------------------------------------------------------------------- -# Main function to compute section masks +# Main function to compute section masks # ------------------------------------------------------------------------------- -def get_section_line_masks(pt1, pt2, cds): - """Compute 2D mask with 1's along great circle line +def get_section_line_masks(pt1, pt2, cds, grid=None): + """Compute 2D mask with 1's along great circle line from lat/lon1 -> lat/lon2 Parameters @@ -195,6 +195,8 @@ def get_section_line_masks(pt1, pt2, cds): [longitude, latitude] or (longitude, latitude) of endpoints cds : xarray Dataset containing grid coordinate information, at least XC, YC + grid : xgcm grid object + see ecco_utils.get_llc_grid Returns ------- @@ -202,7 +204,7 @@ def get_section_line_masks(pt1, pt2, cds): 2D mask along section """ - # Get cartesian coordinates of end points + # Get cartesian coordinates of end points x1, y1, z1 = _convert_latlon_to_cartesian(pt1[0],pt1[1]) x2, y2, z2 = _convert_latlon_to_cartesian(pt2[0],pt2[1]) @@ -234,19 +236,19 @@ def get_section_line_masks(pt1, pt2, cds): x1, y1, z1 = _apply_rotation_matrix(rot_3, (x1,y1,z1)) x2, y2, z2 = _apply_rotation_matrix(rot_3, (x2,y2,z2)) - # Now apply rotations to the grid - # and get cartesian coordinates at cell centers + # Now apply rotations to the grid + # and get cartesian coordinates at cell centers xc, yc, zc = _rotate_the_grid(cds.XC, cds.YC, rot_1, rot_2, rot_3) # Interpolate for x,y to west and south edges - grid = get_llc_grid(cds) + grid = get_llc_grid(cds) if grid is None else grid xw = grid.interp(xc, 'X', boundary='fill') yw = grid.interp(yc, 'X', boundary='fill') xs = grid.interp(xc, 'Y', boundary='fill') ys = grid.interp(yc, 'Y', boundary='fill') # Compute the great circle mask, covering the entire globe - maskC = scalar_calc.get_edge_mask(zc>0,grid) + maskC = scalar_calc.get_edge_mask(zc>0,grid) maskW = grid.diff( 1*(zc>0), 'X', boundary='fill') maskS = grid.diff( 1*(zc>0), 'Y', boundary='fill') @@ -263,11 +265,11 @@ def get_section_line_masks(pt1, pt2, cds): # All functions below are non-user facing # # ------------------------------------------------------------------------------- -# Helper functions for computing section masks +# Helper functions for computing section masks # ------------------------------------------------------------------------------- def _calc_section_along_full_arc_mask( mask, x1, y1, x2, y2, xg, yg ): - """Given a mask which has a great circle passing through + """Given a mask which has a great circle passing through pt1 = (x1, y1) and pt2 = (x2,y2), grab the section just connecting pt1 and pt2 Parameters @@ -317,7 +319,15 @@ def _rotate_the_grid(lon, lat, rot_1, rot_2, rot_3): """ # Get cartesian of 1D view of lat/lon - xg, yg, zg = _convert_latlon_to_cartesian(lon.values.ravel(),lat.values.ravel()) + lon_v = lon.values.ravel() + lat_v = lat.values.ravel() + get_mesh = False + if not set(('i','j')).issubset(lon.coords): + lon_v,lat_v = np.meshgrid(lon_v,lat_v) + lon_v = lon_v.ravel() + lat_v = lat_v.ravel() + get_mesh = True + xg, yg, zg = _convert_latlon_to_cartesian(lon_v,lat_v) # These rotations result in: # xg = 0 at pt1 @@ -328,9 +338,13 @@ def _rotate_the_grid(lon, lat, rot_1, rot_2, rot_3): xg, yg, zg = _apply_rotation_matrix(rot_3, (xg,yg,zg)) # Remake into LLC xarray DataArray - xg = llc_tiles_to_xda(xg, grid_da=lon, less_output=True) - yg = llc_tiles_to_xda(yg, grid_da=lat, less_output=True) - zg = llc_tiles_to_xda(zg, grid_da=lon, less_output=True) + if get_mesh: + xdalike = lat*lon + else: + xdalike = lat + xg = llc_tiles_to_xda(xg, grid_da=xdalike, less_output=True) + yg = llc_tiles_to_xda(yg, grid_da=xdalike, less_output=True) + zg = llc_tiles_to_xda(zg, grid_da=xdalike, less_output=True) return xg, yg, zg @@ -365,7 +379,7 @@ def _apply_rotation_matrix(rot_mat,xyz): def _convert_latlon_to_cartesian(lon, lat): """Convert latitude, longitude (degrees) to cartesian coordinates - Note: conversion to cartesian differs from what is found at e.g. Wolfram + Note: conversion to cartesian differs from what is found at e.g. Wolfram because here lat \in [-pi/2, pi/2] with 0 at equator, not [0, pi], pi/2 at equator Parameters diff --git a/ecco_v4_py/read_bin_llc.py b/ecco_v4_py/read_bin_llc.py index d87c460..69b7343 100644 --- a/ecco_v4_py/read_bin_llc.py +++ b/ecco_v4_py/read_bin_llc.py @@ -39,7 +39,8 @@ def load_ecco_vars_from_mds(mds_var_dir, meta_common=dict(), mds_datatype = '>f4', llc_method = 'bigchunks', - less_output=True): + less_output=True, + **kwargs): """ @@ -123,6 +124,9 @@ def load_ecco_vars_from_mds(mds_var_dir, less_output : logical, optional if True (default), omit additional print statements + **kwargs: optional + extra inputs passed to xmitgcm.open_mdsdataset + Returns ======= @@ -170,7 +174,8 @@ def load_ecco_vars_from_mds(mds_var_dir, default_dtype = np.dtype(mds_datatype), grid_vars_to_coords=True, llc_method = llc_method, - ignore_unknown_vars=True) + ignore_unknown_vars=True, + **kwargs) else: if not less_output: @@ -191,7 +196,8 @@ def load_ecco_vars_from_mds(mds_var_dir, default_dtype = np.dtype(mds_datatype), grid_vars_to_coords=True, llc_method=llc_method, - ignore_unknown_vars=True) + ignore_unknown_vars=True, + **kwargs) else: raise TypeError('not a valid model_time_steps_to_load. must be "all", an "int", or a list of "int"') diff --git a/ecco_v4_py/test/test_common.py b/ecco_v4_py/test/test_common.py index cb91ec3..4be8d93 100644 --- a/ecco_v4_py/test/test_common.py +++ b/ecco_v4_py/test/test_common.py @@ -2,49 +2,81 @@ Helper functions for all tests """ import pytest -from xmitgcm.test.test_xmitgcm_common import llc_mds_datadirs +from xmitgcm.utils import get_extra_metadata +from xmitgcm.test.test_xmitgcm_common import ( + _experiments, llc_mds_datadirs, setup_mds_dir, dlroot ) import ecco_v4_py as ecco +# Following xmitgcm's lead to add an ASTE domain for testing +_experiments['aste270']= {'geometry':'llc', + 'dlink': dlroot + '25286756', + 'md5': 'f616fe46330f1125472f274af2c96e44', + 'shape': (50,6,270,270), + 'ref_date':'2002-01-01 00:00:00', + 'diagnostics':('state_2d_set1', + ['ETAN ', 'SIarea ', 'SIheff ', 'SIhsnow ', + 'DETADT2 ', 'PHIBOT ', 'sIceLoad', 'MXLDEPTH', + 'oceSPDep', 'SIatmQnt', 'SIatmFW ', 'oceQnet ', + 'oceFWflx', 'oceTAUX ', 'oceTAUY ', 'ADVxHEFF', + 'ADVyHEFF', 'DFxEHEFF', 'DFyEHEFF', 'ADVxSNOW', + 'ADVySNOW', 'DFxESNOW', 'DFyESNOW', 'SIuice ', + 'SIvice ' 'ETANSQ ']), + 'test_iternum':8} + +# Modify xmitgcm's function for both global ECCO and ASTE +@pytest.fixture(scope='module', params=['global_oce_llc90','aste270']) +def all_mds_datadirs(tmpdir_factory, request): + return setup_mds_dir(tmpdir_factory,request, _experiments) + @pytest.fixture(scope='module') -def get_test_ds(llc_mds_datadirs): +def get_test_ds(all_mds_datadirs): + """make 2 tests when called, one with global, one with ASTE, + using fixture above""" - dirname, expected = llc_mds_datadirs + dirname, expected = all_mds_datadirs + + kwargs = {} + if 'aste' in dirname: + kwargs['extra_metadata']=get_extra_metadata('aste',270) + kwargs['tiles_to_load']=[0,1,2,3,4,5] + kwargs['nx']=270 + domain = 'aste' + else: + domain = 'global' # read in array ds = ecco.load_ecco_vars_from_mds(dirname, model_time_steps_to_load=expected['test_iternum'], - mds_files=expected['diagnostics'][0]) + mds_files=['state_2d_set1','U','V','W','T','S'], + **kwargs) + ds.attrs['domain'] = domain return ds @pytest.fixture(scope='module') -def get_test_array_2d(llc_mds_datadirs): - """download, unzip and return 2D field""" +def get_global_ds(llc_mds_datadirs): + """just get the global dataset""" dirname, expected = llc_mds_datadirs # read in array ds = ecco.load_ecco_vars_from_mds(dirname, model_time_steps_to_load=expected['test_iternum'], - mds_files=expected['diagnostics'][0]) - - xda = ds['ETAN'] - - if 'time' in xda.dims: - xda = xda.isel(time=0) - return xda + mds_files=['state_2d_set1','U','V','W','T','S']) + return ds @pytest.fixture(scope='module') -def get_test_vectors(llc_mds_datadirs): - """download, unzip and return zonal/meridional velocity in dataset""" +def get_test_array_2d(llc_mds_datadirs): + """download, unzip and return 2D field""" dirname, expected = llc_mds_datadirs # read in array ds = ecco.load_ecco_vars_from_mds(dirname, model_time_steps_to_load=expected['test_iternum'], - mds_files=['U','V','W']) + mds_files='state_2d_set1') - if 'time' in ds.dims: - ds = ds.isel(time=-1) + xda = ds['ETAN'] - return ds + if 'time' in xda.dims: + xda = xda.isel(time=0) + return xda diff --git a/ecco_v4_py/test/test_ecco_utils.py b/ecco_v4_py/test/test_ecco_utils.py index 0e9f371..27f2f72 100644 --- a/ecco_v4_py/test/test_ecco_utils.py +++ b/ecco_v4_py/test/test_ecco_utils.py @@ -6,7 +6,7 @@ import pytest import ecco_v4_py -from .test_common import llc_mds_datadirs, get_test_ds +from .test_common import all_mds_datadirs, get_test_ds @pytest.mark.parametrize("mytype",['xda','nparr','list','single']) def test_extract_dates(mytype): @@ -36,4 +36,5 @@ def test_extract_dates(mytype): def test_get_grid(get_test_ds): """make sure we can make a grid ... that's it""" + grid = ecco_v4_py.get_llc_grid(get_test_ds) diff --git a/ecco_v4_py/test/test_get_basin.py b/ecco_v4_py/test/test_get_basin.py index 222ba46..bb2f4fb 100644 --- a/ecco_v4_py/test/test_get_basin.py +++ b/ecco_v4_py/test/test_get_basin.py @@ -4,26 +4,39 @@ import ecco_v4_py import pytest -from .test_common import llc_mds_datadirs, get_test_ds, get_test_vectors +from .test_common import ( + llc_mds_datadirs, get_global_ds, + all_mds_datadirs, get_test_ds) _test_dir = os.path.dirname(os.path.abspath(__file__)) -def test_each_basin_masks(get_test_ds): +# test error out for different domains +def test_notimplemented(get_test_ds): + ds = get_test_ds + with pytest.raises(NotImplementedError): + if len(ds.tile)<13: + ecco_v4_py.get_basin_mask('atl',ds.maskC) + else: + ecco_v4_py.get_basin_mask('atl',ds.sel(tile=0).maskC) + ecco_v4_py.get_basin_mask('atl',(1.*ds.maskC).diff(dim='i')) + ecco_v4_py.get_basin_mask('atl',(1.*ds.maskC).diff(dim='j')) + +def test_each_basin_masks(get_global_ds): """make sure we can make the basin masks """ - ds = get_test_ds + ds = get_global_ds all_basins = ecco_v4_py.read_llc_to_tiles(os.path.join(_test_dir,'..','..','binary_data'),'basins.data',less_output=True) ext_names = ['atlExt','pacExt','indExt'] for i,basin in enumerate(ecco_v4_py.get_available_basin_names(),start=1): mask = ecco_v4_py.get_basin_mask(basin,ds.maskC.isel(k=0)) assert np.all(mask.values == (all_basins==i)) -def test_ext_basin_masks(get_test_ds): +def test_ext_basin_masks(get_global_ds): """make sure we can make the extended masks """ - ds = get_test_ds + ds = get_global_ds ext_names = ['atlExt','pacExt','indExt'] individual_names = [['atl','mexico','hudson','med','north','baffin','gin'], @@ -34,10 +47,10 @@ def test_ext_basin_masks(get_test_ds): maskI = ecco_v4_py.get_basin_mask(ind,ds.maskC.isel(k=0)) assert np.all(maskE==maskI) -def test_3d(get_test_vectors): +def test_3d(get_global_ds): """check that vertical coordinate""" - ds = get_test_vectors + ds = get_global_ds grid = ecco_v4_py.get_llc_grid(ds) maskK = ds['maskC'] maskL = grid.interp(maskK,'Z',to='left',boundary='fill') @@ -46,10 +59,10 @@ def test_3d(get_test_vectors): for mask in [maskK,maskL,maskU,maskKp1]: ecco_v4_py.get_basin_mask('atl',mask) -def test_bin_dir_is_here(get_test_ds,hide_bin_dir): +def test_bin_dir_is_here(get_global_ds,hide_bin_dir): hide_bin_dir - ds = get_test_ds + ds = get_global_ds with pytest.raises(OSError): ecco_v4_py.get_basin_mask('atl',ds.maskC.isel(k=0)) diff --git a/ecco_v4_py/test/test_llc_array_conversion.py b/ecco_v4_py/test/test_llc_array_conversion.py index d8eac4c..a10b2e2 100644 --- a/ecco_v4_py/test/test_llc_array_conversion.py +++ b/ecco_v4_py/test/test_llc_array_conversion.py @@ -7,7 +7,7 @@ import pytest import ecco_v4_py as ecco -from .test_common import llc_mds_datadirs,get_test_ds +from .test_common import llc_mds_datadirs,get_global_ds # Define bin directory for test reading _PKG_DIR = Path(__file__).resolve().parent.parent.parent @@ -91,13 +91,13 @@ def test_convert_tiles_to_compact(llc_mds_datadirs,mydir,fname,nk,nl,skip, @pytest.mark.parametrize("grid_da",[None,True]) @pytest.mark.parametrize("var_type",['c','w','s','z']) @pytest.mark.parametrize("use_xmitgcm",[True,False]) -def test_convert_tiles_to_xda(llc_mds_datadirs,get_test_ds,mydir,fname,nk,nl, skip, +def test_convert_tiles_to_xda(llc_mds_datadirs,get_global_ds,mydir,fname,nk,nl, skip, grid_da, var_type, use_xmitgcm): if mydir == 'xmitgcm': mydir,_ = llc_mds_datadirs - ds = get_test_ds + ds = get_global_ds data_tiles = ecco.read_llc_to_tiles(fdir=mydir, fname=fname, diff --git a/ecco_v4_py/test/test_meridional_trsp.py b/ecco_v4_py/test/test_meridional_trsp.py index 1a11321..97d19ea 100644 --- a/ecco_v4_py/test/test_meridional_trsp.py +++ b/ecco_v4_py/test/test_meridional_trsp.py @@ -7,7 +7,7 @@ import pytest import ecco_v4_py -from .test_common import llc_mds_datadirs, get_test_ds, get_test_vectors +from .test_common import all_mds_datadirs, get_test_ds from .test_vector_calc import get_fake_vectors @pytest.mark.parametrize("lats",[-20,0,10,np.array([-30,-15,20,45])]) @@ -19,96 +19,91 @@ def test_trsp_ds(get_test_ds,lats): assert np.all(test.time==exp.time) assert np.all(test.k == exp.k) -@pytest.mark.parametrize("lats",[-20,0,10,np.array([-30,-15,20,45])]) +@pytest.mark.parametrize("myfunc, tfld, xflds, yflds, factor", + [ (ecco_v4_py.calc_meridional_vol_trsp,"vol_trsp_z", + ['UVELMASS'],['VVELMASS'], 1e-6), + (ecco_v4_py.calc_meridional_heat_trsp,"heat_trsp_z", + ['ADVx_TH','DFxE_TH'],['ADVy_TH','DFyE_TH'],1e-15*1029*4000), + (ecco_v4_py.calc_meridional_salt_trsp,"salt_trsp_z", + ['ADVx_SLT','DFxE_SLT'],['ADVy_SLT','DFyE_SLT'],1e-6)]) +@pytest.mark.parametrize("lats",[0,np.array([-20,30,45])]) @pytest.mark.parametrize("basin",[None,'atlExt','pacExt','indExt']) -def test_vol_trsp(get_test_vectors,lats,basin): - """compute a volume transport""" +def test_meridional_trsp(get_test_ds,myfunc,tfld,xflds,yflds,factor,lats,basin): + """compute a transport""" - ds = get_test_vectors + ds = get_test_ds grid = ecco_v4_py.get_llc_grid(ds) ds['U'],ds['V'] = get_fake_vectors(ds['U'],ds['V']) - ds = ds.rename({'U':'UVELMASS','V':'VVELMASS'}) + for fx,fy in zip(xflds,yflds): + ds[fx] = ds['U'].copy() + ds[fy] = ds['V'].copy() - trsp = ecco_v4_py.calc_meridional_vol_trsp(ds,lats,basin_name=basin,grid=grid) - if basin is not None: - basinW = ecco_v4_py.get_basin_mask(basin,ds['maskW']) - basinS = ecco_v4_py.get_basin_mask(basin,ds['maskS']) - else: + if basin is None or len(ds.tile)==13: + trsp = myfunc(ds,lats,basin_name=basin,grid=grid) basinW = ds['maskW'] basinS = ds['maskS'] - - - lats = [lats] if np.isscalar(lats) else lats - for lat in lats: - maskW,maskS = ecco_v4_py.vector_calc.get_latitude_masks(lat,ds['YC'],grid) - - trspx = (ds['drF']*ds['dyG']*np.abs(maskW)).where(basinW).sum(dim=['i_g','j','tile']) - trspy = (ds['drF']*ds['dxG']*np.abs(maskS)).where(basinS).sum(dim=['i','j_g','tile']) - test = trsp.sel(lat=lat).vol_trsp_z.reset_coords(drop=True) - expected = (1e-6*(trspx+trspy)).reset_coords(drop=True) - xr.testing.assert_allclose(test,expected) - -@pytest.mark.parametrize("lats",[-20,0,10,np.array([-30,-15,20,45])]) -@pytest.mark.parametrize("basin",[None,'atlExt','pacExt','indExt']) -def test_heat_trsp(get_test_vectors,lats,basin): - """compute heat transport""" - - ds = get_test_vectors + if basin is not None: + basinW = ecco_v4_py.get_basin_mask(basin,basinW) + basinS = ecco_v4_py.get_basin_mask(basin,basinS) + + lats = [lats] if np.isscalar(lats) else lats + expx = (ds['drF']*ds['dyG']).copy() if tfld == 'vol_trsp_z' else 2.*xr.ones_like(ds['hFacW']) + expy = (ds['drF']*ds['dxG']).copy() if tfld == 'vol_trsp_z' else 2.*xr.ones_like(ds['hFacS']) + for lat in lats: + maskW,maskS = ecco_v4_py.vector_calc.get_latitude_masks(lat,ds['YC'],grid) + + trspx = (expx*np.abs(maskW)).where(basinW).sum(dim=['i_g','j','tile']) + trspy = (expy*np.abs(maskS)).where(basinS).sum(dim=['i','j_g','tile']) + + test = trsp.sel(lat=lat)[tfld].squeeze().reset_coords(drop=True) + expected = (factor*(trspx+trspy)).reset_coords(drop=True) + xr.testing.assert_allclose(test,expected) + else: + with pytest.raises(NotImplementedError): + trsp = myfunc(ds,lats,basin_name=basin,grid=grid) + +@pytest.mark.parametrize("myfunc, fld, xflds, yflds", + [ (ecco_v4_py.calc_meridional_vol_trsp,"vol_trsp", + ['UVELMASS'],['VVELMASS']), + (ecco_v4_py.calc_meridional_heat_trsp,"heat_trsp", + ['ADVx_TH','DFxE_TH'],['ADVy_TH','DFyE_TH']), + (ecco_v4_py.calc_meridional_salt_trsp,"salt_trsp", + ['ADVx_SLT','DFxE_SLT'],['ADVy_SLT','DFyE_SLT'])]) +@pytest.mark.parametrize("lat",[10]) # more is unnecessary +def test_separate_coords(get_test_ds,myfunc,fld,xflds,yflds,lat): + ds = get_test_ds grid = ecco_v4_py.get_llc_grid(ds) ds['U'],ds['V'] = get_fake_vectors(ds['U'],ds['V']) - ds = ds.rename({'U':'ADVx_TH','V':'ADVy_TH'}) - ds['DFxE_TH'] = ds['ADVx_TH'].copy() - ds['DFyE_TH'] = ds['ADVy_TH'].copy() - - trsp = ecco_v4_py.calc_meridional_heat_trsp(ds,lats,basin_name=basin,grid=grid) - if basin is not None: - basinW = ecco_v4_py.get_basin_mask(basin,ds['maskW']) - basinS = ecco_v4_py.get_basin_mask(basin,ds['maskS']) - else: - basinW = ds['maskW'] - basinS = ds['maskS'] + for fx,fy in zip(xflds,yflds): + ds[fx] = ds['U'].copy() + ds[fy] = ds['V'].copy() + expected = myfunc(ds,lat,grid=grid) + coords = ds.coords.to_dataset().reset_coords() + ds = ds.reset_coords(drop=True) - lats = [lats] if np.isscalar(lats) else lats - for lat in lats: - maskW,maskS = ecco_v4_py.vector_calc.get_latitude_masks(lat,ds['YC'],grid) + test = myfunc(ds,lat,coords=coords,grid=grid) + xr.testing.assert_equal(test[fld].reset_coords(drop=True), + expected[fld].reset_coords(drop=True)) - trspx = (2*np.abs(maskW)).where(basinW).sum(dim=['i_g','j','tile']) - trspy = (2*np.abs(maskS)).where(basinS).sum(dim=['i','j_g','tile']) - test = trsp.sel(lat=lat).heat_trsp_z.reset_coords(drop=True) - expected = (1e-15*1029*4000*(trspx+trspy)).reset_coords(drop=True) - xr.testing.assert_allclose(test,expected) - -@pytest.mark.parametrize("lats",[-20,0,10,np.array([-30,-15,20,45])]) -@pytest.mark.parametrize("basin",[None,'atlExt','pacExt','indExt']) -def test_salt_trsp(get_test_vectors,lats,basin): - """compute salt transport""" +@pytest.mark.parametrize("lat",[10]) +def test_trsp_masking(get_test_ds,lat): + """make sure internal masking is legit""" - ds = get_test_vectors + ds = get_test_ds grid = ecco_v4_py.get_llc_grid(ds) ds['U'],ds['V'] = get_fake_vectors(ds['U'],ds['V']) - ds = ds.rename({'U':'ADVx_SLT','V':'ADVy_SLT'}) - ds['DFxE_SLT'] = ds['ADVx_SLT'].copy() - ds['DFyE_SLT'] = ds['ADVy_SLT'].copy() - - trsp = ecco_v4_py.calc_meridional_salt_trsp(ds,lats,basin_name=basin,grid=grid) - if basin is not None: - basinW = ecco_v4_py.get_basin_mask(basin,ds['maskW']) - basinS = ecco_v4_py.get_basin_mask(basin,ds['maskS']) - else: - basinW = ds['maskW'] - basinS = ds['maskS'] - + ds['U'] = ds['U'].where(ds['maskW'],0.) + ds['V'] = ds['V'].where(ds['maskS'],0.) - lats = [lats] if np.isscalar(lats) else lats - for lat in lats: - maskW,maskS = ecco_v4_py.vector_calc.get_latitude_masks(lat,ds['YC'],grid) + expected = ecco_v4_py.meridional_trsp_at_depth(ds['U'],ds['V'],lat,ds) + coords = ds[['Z','YC','XC','dyG','dxG','time']].copy() + coords.attrs=ds.attrs.copy() + ds = ds.reset_coords(drop=True) + test = ecco_v4_py.meridional_trsp_at_depth(ds['U'],ds['V'],lat,coords) - trspx = (2*np.abs(maskW)).where(basinW).sum(dim=['i_g','j','tile']) - trspy = (2*np.abs(maskS)).where(basinS).sum(dim=['i','j_g','tile']) - test = trsp.sel(lat=lat).salt_trsp_z.reset_coords(drop=True) - expected = (1e-6*(trspx+trspy)).reset_coords(drop=True) - xr.testing.assert_allclose(test,expected) + xr.testing.assert_equal(test['trsp_z'].reset_coords(drop=True), + expected['trsp_z'].reset_coords(drop=True)) diff --git a/ecco_v4_py/test/test_proj_plot.py b/ecco_v4_py/test/test_proj_plot.py index 0aa70d4..e1bb30b 100644 --- a/ecco_v4_py/test/test_proj_plot.py +++ b/ecco_v4_py/test/test_proj_plot.py @@ -9,7 +9,7 @@ import pytest from ecco_v4_py import plot_proj_to_latlon_grid -from .test_common import llc_mds_datadirs,get_test_array_2d +from .test_common import all_mds_datadirs, get_test_ds @pytest.mark.parametrize("kwargs", [ {'projection_type':'Mercator'}, @@ -36,17 +36,18 @@ {'show_land':False,'show_coastline':False}, {'show_grid_lines':False}]) @pytest.mark.parametrize("dx, dy",[(1,1)]) -def test_plot_proj(get_test_array_2d,kwargs,dx,dy): +def test_plot_proj(get_test_ds,kwargs,dx,dy): """Run through various options and make sure nothing is broken""" - test_arr = get_test_array_2d + ds = get_test_ds kwargs['dx']=dx kwargs['dy']=dy + print(ds) if 'blah' in kwargs.values(): with pytest.raises(NotImplementedError): - plot_proj_to_latlon_grid(test_arr.XC,test_arr.YC,test_arr,**kwargs) + plot_proj_to_latlon_grid(ds.XC,ds.YC,ds.ETAN,**kwargs) else: - plot_proj_to_latlon_grid(test_arr.XC,test_arr.YC,test_arr,**kwargs) + plot_proj_to_latlon_grid(ds.XC,ds.YC,ds.ETAN,**kwargs) plt.close() diff --git a/ecco_v4_py/test/test_scalar_calc.py b/ecco_v4_py/test/test_scalar_calc.py index f6d2fa0..7ae6890 100644 --- a/ecco_v4_py/test/test_scalar_calc.py +++ b/ecco_v4_py/test/test_scalar_calc.py @@ -9,7 +9,7 @@ import pytest from ecco_v4_py import scalar_calc, get_llc_grid -from .test_common import llc_mds_datadirs, get_test_ds +from .test_common import all_mds_datadirs, get_test_ds def test_latitude_mask(get_test_ds): """run through lats, and ensure we're grabbing the closest @@ -24,7 +24,7 @@ def test_latitude_mask(get_test_ds): dLat = 0.5 # is this robust? nx = 90 - for lat in np.arange(-89,89,5): + for lat in np.arange(-89,89,10): print('lat: ',lat) maskC = scalar_calc.get_latitude_mask(lat,ds['YC'],grid) diff --git a/ecco_v4_py/test/test_section_masks.py b/ecco_v4_py/test/test_section_masks.py index b71e975..2abf27d 100644 --- a/ecco_v4_py/test/test_section_masks.py +++ b/ecco_v4_py/test/test_section_masks.py @@ -1,7 +1,7 @@ import ecco_v4_py as ecco -from .test_common import llc_mds_datadirs, get_test_vectors +from .test_common import all_mds_datadirs, get_test_ds def test_section_endpoints(): """Ensure that the listed available sections are actually there @@ -10,12 +10,12 @@ def test_section_endpoints(): for section in ecco.get_available_sections(): assert ecco.get_section_endpoints(section) is not None -def test_calc_all_sections(get_test_vectors): +def test_calc_all_sections(get_test_ds): """Ensure that we can compute all section masks... not sure how to test these exactly... """ - ds = get_test_vectors + ds = get_test_ds for section in ecco.get_available_sections(): pt1,pt2 = ecco.get_section_endpoints(section) diff --git a/ecco_v4_py/test/test_section_trsp.py b/ecco_v4_py/test/test_section_trsp.py index 74f33bd..8747431 100644 --- a/ecco_v4_py/test/test_section_trsp.py +++ b/ecco_v4_py/test/test_section_trsp.py @@ -7,9 +7,13 @@ import pytest import ecco_v4_py -from .test_common import llc_mds_datadirs, get_test_ds, get_test_vectors +from .test_common import all_mds_datadirs, get_test_ds from .test_vector_calc import get_fake_vectors +_section='floridastrait' +_pt1=[-81,28] +_pt2=[-79,22] + def test_trsp_ds(get_test_ds): """stupid simple""" exp = get_test_ds @@ -17,165 +21,104 @@ def test_trsp_ds(get_test_ds): assert np.all(test.time==exp.time) assert np.all(test.k == exp.k) -@pytest.mark.parametrize("name, pt1, pt2, maskW, maskS, expArr", +@pytest.mark.parametrize("myfunc, tfld, xflds, yflds, factor", + [ (ecco_v4_py.calc_section_vol_trsp,"vol_trsp_z", + ['UVELMASS'],['VVELMASS'], 1e-6), + (ecco_v4_py.calc_section_heat_trsp,"heat_trsp_z", + ['ADVx_TH','DFxE_TH'],['ADVy_TH','DFyE_TH'],1029*4000*1e-15), + (ecco_v4_py.calc_section_salt_trsp,"salt_trsp_z", + ['ADVx_SLT','DFxE_SLT'],['ADVy_SLT','DFyE_SLT'],1e-6)]) +@pytest.mark.parametrize("args, mask, error", [ - ("drakepassage",None,None,None,None,None), - (None,[-173,65.5],[-164,65.5],None,None,None), - (None,None,None,True,True,None), - (None,None,None,None,None,TypeError), - ("drakepassage",[-173,65.5],[-164,65.5],None,None,TypeError), - ("drakepassage",None,None,True,True,TypeError), - (None,[-173,65.5],[-164,65.5],True,True,TypeError), - ("noname",None,None,None,None,TypeError) + ({'section_name':_section,'pt1':None,'pt2':None},False,None), + ({'section_name':None, 'pt1':_pt1,'pt2':_pt2},False,None), + ({'section_name':None, 'pt1':None,'pt2':None},True ,None), + ({'section_name':None, 'pt1':None,'pt2':None},False,TypeError), + ({'section_name':_section,'pt1':_pt1,'pt2':_pt2},False,TypeError), + ({'section_name':_section,'pt1':None,'pt2':None},True ,TypeError), + ({'section_name':"noname",'pt1':None,'pt2':None},False,TypeError), ]) -def test_vol_trsp(get_test_vectors,name,pt1,pt2,maskW,maskS,expArr): - """compute a volume transport""" +def test_section_trsp(get_test_ds,myfunc,tfld,xflds,yflds,factor,args,mask,error): + """compute a volume transport, + within the lat/lon portion of the domain""" - ds = get_test_vectors + ds = get_test_ds grid = ecco_v4_py.get_llc_grid(ds) ds['U'],ds['V'] = get_fake_vectors(ds['U'],ds['V']) - ds = ds.rename({'U':'UVELMASS','V':'VVELMASS'}) + for fx,fy in zip(xflds,yflds): + ds[fx] = ds['U'].copy() + ds[fy] = ds['V'].copy() - if maskW is not None and maskS is not None: - if maskW and maskS: - maskW,maskS = ecco_v4_py.vector_calc.get_latitude_masks(30,ds['YC'],grid) + myargs = args.copy() + if mask: + myargs['maskW'],myargs['maskS'] = ecco_v4_py.vector_calc.get_latitude_masks(30,ds['YC'],grid) + else: + myargs['maskW']=None + myargs['maskS']=None - if expArr is None: - trsp = ecco_v4_py.calc_section_vol_trsp(ds, - pt1=pt1,pt2=pt2, - maskW=maskW,maskS=maskS, - section_name=name, - grid=grid) + if error is None: + trsp = myfunc(ds,grid=grid,**myargs) maskW,maskS = ecco_v4_py.calc_section_trsp._parse_section_trsp_inputs(ds, - pt1=pt1,pt2=pt2,maskW=maskW,maskS=maskS, - section_name=name) + grid=grid,**myargs) - trspx = (ds['drF']*ds['dyG']*np.abs(maskW)).where(ds['maskW']).sum(dim=['i_g','j','tile']) - trspy = (ds['drF']*ds['dxG']*np.abs(maskS)).where(ds['maskS']).sum(dim=['i','j_g','tile']) - test = trsp.vol_trsp_z.reset_coords(drop=True) - expected = (1e-6*(trspx+trspy)).reset_coords(drop=True) - xr.testing.assert_allclose(test,expected) + expx = (ds['drF']*ds['dyG']).copy() if tfld == 'vol_trsp_z' else 2.*xr.ones_like(ds['hFacW']) + expy = (ds['drF']*ds['dxG']).copy() if tfld == 'vol_trsp_z' else 2.*xr.ones_like(ds['hFacS']) + trspx = (expx*np.abs(maskW)).where(ds['maskW']).sum(dim=['i_g','j','tile']) + trspy = (expy*np.abs(maskS)).where(ds['maskS']).sum(dim=['i','j_g','tile']) - else: - with pytest.raises(expArr): - trsp = ecco_v4_py.calc_section_vol_trsp(ds, - pt1=pt1,pt2=pt2, - maskW=maskW,maskS=maskS, - section_name=name, - grid=grid) - - maskW,maskS = ecco_v4_py.calc_section_trsp._parse_section_trsp_inputs(ds, - pt1=pt1,pt2=pt2,maskW=maskW,maskS=maskS, - section_name=name) - -@pytest.mark.parametrize("name, pt1, pt2, maskW, maskS, expArr", - [ - ("drakepassage",None,None,None,None,None), - (None,[-173,65.5],[-164,65.5],None,None,None), - (None,None,None,True,True,None), - (None,None,None,None,None,TypeError), - ("drakepassage",[-173,65.5],[-164,65.5],None,None,TypeError), - ("drakepassage",None,None,True,True,TypeError), - (None,[-173,65.5],[-164,65.5],True,True,TypeError), - ("noname",None,None,None,None,TypeError) - ]) -def test_heat_trsp(get_test_vectors,name,pt1,pt2,maskW,maskS,expArr): - """compute heat transport""" + test = trsp[tfld].squeeze().reset_coords(drop=True) + expected = (factor*(trspx+trspy)).reset_coords(drop=True) + xr.testing.assert_equal(test,expected) - ds = get_test_vectors + else: + with pytest.raises(error): + trsp = myfunc(ds,**myargs) + +@pytest.mark.parametrize("myfunc, tfld, xflds, yflds", + [ (ecco_v4_py.calc_section_vol_trsp,"vol_trsp", + ['UVELMASS'],['VVELMASS']), + (ecco_v4_py.calc_section_heat_trsp,"heat_trsp", + ['ADVx_TH','DFxE_TH'],['ADVy_TH','DFyE_TH']), + (ecco_v4_py.calc_section_salt_trsp,"salt_trsp", + ['ADVx_SLT','DFxE_SLT'],['ADVy_SLT','DFyE_SLT'])]) +@pytest.mark.parametrize("section_name",["beringstrait"]) # more is unnecessary +def test_separate_coords(get_test_ds,myfunc,tfld,xflds,yflds,section_name): + ds = get_test_ds grid = ecco_v4_py.get_llc_grid(ds) ds['U'],ds['V'] = get_fake_vectors(ds['U'],ds['V']) - ds = ds.rename({'U':'ADVx_TH','V':'ADVy_TH'}) - ds['DFxE_TH'] = ds['ADVx_TH'].copy() - ds['DFyE_TH'] = ds['ADVy_TH'].copy() + for fx,fy in zip(xflds,yflds): + ds[fx] = ds['U'] + ds[fy] = ds['V'] - if maskW is not None and maskS is not None: - if maskW and maskS: - maskW,maskS = ecco_v4_py.vector_calc.get_latitude_masks(30,ds['YC'],grid) + expected = myfunc(ds,section_name=section_name,grid=grid) + coords = ds.coords.to_dataset().reset_coords() + ds = ds.reset_coords(drop=True) - if expArr is None: - trsp = ecco_v4_py.calc_section_heat_trsp(ds, - pt1=pt1,pt2=pt2, - maskW=maskW,maskS=maskS, - section_name=name, - grid=grid) - - maskW,maskS = ecco_v4_py.calc_section_trsp._parse_section_trsp_inputs(ds, - pt1=pt1,pt2=pt2,maskW=maskW,maskS=maskS, - section_name=name) + test = myfunc(ds,section_name=section_name,coords=coords,grid=grid) + xr.testing.assert_equal(test[tfld].reset_coords(drop=True), + expected[tfld].reset_coords(drop=True)) - trspx = (2*np.abs(maskW)).where(ds['maskW']).sum(dim=['i_g','j','tile']) - trspy = (2*np.abs(maskS)).where(ds['maskS']).sum(dim=['i','j_g','tile']) - test = trsp.heat_trsp_z.reset_coords(drop=True) - expected = (1e-15*1029*4000*(trspx+trspy)).reset_coords(drop=True) - xr.testing.assert_allclose(test,expected) +@pytest.mark.parametrize("section_name",["beringstrait"]) +def test_trsp_masking(get_test_ds,section_name): + """make sure internal masking is legit""" - else: - with pytest.raises(expArr): - trsp = ecco_v4_py.calc_section_heat_trsp(ds, - pt1=pt1,pt2=pt2, - maskW=maskW,maskS=maskS, - section_name=name, - grid=grid) - - maskW,maskS = ecco_v4_py.calc_section_trsp._parse_section_trsp_inputs(ds, - pt1=pt1,pt2=pt2,maskW=maskW,maskS=maskS, - section_name=name) - -@pytest.mark.parametrize("name, pt1, pt2, maskW, maskS, expArr", - [ - ("drakepassage",None,None,None,None,None), - (None,[-173,65.5],[-164,65.5],None,None,None), - (None,None,None,True,True,None), - (None,None,None,None,None,TypeError), - ("drakepassage",[-173,65.5],[-164,65.5],None,None,TypeError), - ("drakepassage",None,None,True,True,TypeError), - (None,[-173,65.5],[-164,65.5],True,True,TypeError), - ("noname",None,None,None,None,TypeError) - ]) -def test_salt_trsp(get_test_vectors,name,pt1,pt2,maskW,maskS,expArr): - """compute salt transport""" - - ds = get_test_vectors + ds = get_test_ds grid = ecco_v4_py.get_llc_grid(ds) ds['U'],ds['V'] = get_fake_vectors(ds['U'],ds['V']) - ds = ds.rename({'U':'ADVx_SLT','V':'ADVy_SLT'}) - ds['DFxE_SLT'] = ds['ADVx_SLT'].copy() - ds['DFyE_SLT'] = ds['ADVy_SLT'].copy() + ds['U'] = ds['U'].where(ds['maskW'],0.) + ds['V'] = ds['V'].where(ds['maskS'],0.) - if maskW is not None and maskS is not None: - if maskW and maskS: - maskW,maskS = ecco_v4_py.vector_calc.get_latitude_masks(30,ds['YC'],grid) + pt1,pt2 = ecco_v4_py.get_section_endpoints(section_name) + _, maskW,maskS = ecco_v4_py.get_section_line_masks(pt1,pt2,ds) - if expArr is None: - trsp = ecco_v4_py.calc_section_salt_trsp(ds, - pt1=pt1,pt2=pt2, - maskW=maskW,maskS=maskS, - section_name=name, - grid=grid) + expected = ecco_v4_py.section_trsp_at_depth(ds['U'],ds['V'],maskW,maskS,ds) - maskW,maskS = ecco_v4_py.calc_section_trsp._parse_section_trsp_inputs(ds, - pt1=pt1,pt2=pt2,maskW=maskW,maskS=maskS, - section_name=name) - - trspx = (2*np.abs(maskW)).where(ds['maskW']).sum(dim=['i_g','j','tile']) - trspy = (2*np.abs(maskS)).where(ds['maskS']).sum(dim=['i','j_g','tile']) - test = trsp.salt_trsp_z.reset_coords(drop=True) - expected = (1e-6*(trspx+trspy)).reset_coords(drop=True) - xr.testing.assert_allclose(test,expected) + ds = ds.drop_vars(['maskW','maskS']) + test = ecco_v4_py.section_trsp_at_depth(ds['U'],ds['V'],maskW,maskS) - else: - with pytest.raises(expArr): - trsp = ecco_v4_py.calc_section_salt_trsp(ds, - pt1=pt1,pt2=pt2, - maskW=maskW,maskS=maskS, - section_name=name, - grid=grid) - - maskW,maskS = ecco_v4_py.calc_section_trsp._parse_section_trsp_inputs(ds, - pt1=pt1,pt2=pt2,maskW=maskW,maskS=maskS, - section_name=name) + xr.testing.assert_equal(test['trsp_z'].reset_coords(drop=True), + expected['trsp_z'].reset_coords(drop=True)) diff --git a/ecco_v4_py/test/test_stf.py b/ecco_v4_py/test/test_stf.py index 513181b..8f4ed1b 100644 --- a/ecco_v4_py/test/test_stf.py +++ b/ecco_v4_py/test/test_stf.py @@ -7,85 +7,88 @@ import pytest import ecco_v4_py -from .test_common import llc_mds_datadirs, get_test_ds, get_test_vectors +from .test_common import all_mds_datadirs, get_test_ds from .test_vector_calc import get_fake_vectors -@pytest.mark.parametrize("lats",[-20,0,10,np.array([-30,-15,20,45])]) +_section='floridastrait' +_pt1=[-81,28] +_pt2=[-79,22] + +@pytest.mark.parametrize("lats",[0,np.array([-20,30,45])]) @pytest.mark.parametrize("basin",[None,'atlExt','pacExt','indExt']) @pytest.mark.parametrize("doFlip",[True,False]) -def test_meridional_stf(get_test_vectors,lats,basin,doFlip): +def test_meridional_stf(get_test_ds,lats,basin,doFlip): """compute a meridional streamfunction""" - ds = get_test_vectors + ds = get_test_ds grid = ecco_v4_py.get_llc_grid(ds) ds['U'],ds['V'] = get_fake_vectors(ds['U'],ds['V']) ds = ds.rename({'U':'UVELMASS','V':'VVELMASS'}) - trsp = ecco_v4_py.calc_meridional_stf(ds,lats,doFlip=doFlip,basin_name=basin,grid=grid) - if basin is not None: - basinW = ecco_v4_py.get_basin_mask(basin,ds['maskW']) - basinS = ecco_v4_py.get_basin_mask(basin,ds['maskS']) - else: + if basin is None or len(ds.tile)==13: + trsp = ecco_v4_py.calc_meridional_stf(ds,lats,doFlip=doFlip,basin_name=basin,grid=grid) + basinW = ds['maskW'] basinS = ds['maskS'] + if basin is not None: + basinW = ecco_v4_py.get_basin_mask(basin,basinW) + basinS = ecco_v4_py.get_basin_mask(basin,basinS) + + lats = [lats] if np.isscalar(lats) else lats + for lat in lats: + maskW,maskS = ecco_v4_py.vector_calc.get_latitude_masks(lat,ds['YC'],grid) + + trspx = (ds['drF']*ds['dyG']*np.abs(maskW)).where(basinW).sum(dim=['i_g','j','tile']) + trspy = (ds['drF']*ds['dxG']*np.abs(maskS)).where(basinS).sum(dim=['i','j_g','tile']) + test = trsp.sel(lat=lat).psi_moc.squeeze().reset_coords(drop=True) + expected = (1e-6*(trspx+trspy)).reset_coords(drop=True) + if doFlip: + expected = expected.isel(k=slice(None,None,-1)) + expected=expected.cumsum(dim='k') + if doFlip: + expected = -1*expected.isel(k=slice(None,None,-1)) + xr.testing.assert_allclose(test,expected) + else: + with pytest.raises(NotImplementedError): + trsp = ecco_v4_py.calc_meridional_stf(ds,lats,doFlip=doFlip,basin_name=basin,grid=grid) - - lats = [lats] if np.isscalar(lats) else lats - for lat in lats: - maskW,maskS = ecco_v4_py.vector_calc.get_latitude_masks(lat,ds['YC'],grid) - - trspx = (ds['drF']*ds['dyG']*np.abs(maskW)).where(basinW).sum(dim=['i_g','j','tile']) - trspy = (ds['drF']*ds['dxG']*np.abs(maskS)).where(basinS).sum(dim=['i','j_g','tile']) - test = trsp.sel(lat=lat).psi_moc.reset_coords(drop=True) - expected = (1e-6*(trspx+trspy)).reset_coords(drop=True) - if doFlip: - expected = expected.isel(k=slice(None,None,-1)) - expected=expected.cumsum(dim='k') - if doFlip: - expected = -1*expected.isel(k=slice(None,None,-1)) - xr.testing.assert_allclose(test,expected) - -@pytest.mark.parametrize("name, pt1, pt2, maskW, maskS, expArr", +@pytest.mark.parametrize("args, mask, error", [ - ("drakepassage",None,None,None,None,None), - (None,[-173,65.5],[-164,65.5],None,None,None), - (None,None,None,True,True,None), - (None,None,None,None,None,TypeError), - ("drakepassage",[-173,65.5],[-164,65.5],None,None,TypeError), - ("drakepassage",None,None,True,True,TypeError), - (None,[-173,65.5],[-164,65.5],True,True,TypeError), - ("noname",None,None,None,None,TypeError) + ({'section_name':_section,'pt1':None,'pt2':None},False,None), + ({'section_name':None, 'pt1':_pt1,'pt2':_pt2},False,None), + ({'section_name':None, 'pt1':None,'pt2':None},True ,None), + ({'section_name':None, 'pt1':None,'pt2':None},False,TypeError), + ({'section_name':_section,'pt1':_pt1,'pt2':_pt2},False,TypeError), + ({'section_name':_section,'pt1':None,'pt2':None},True ,TypeError), + ({'section_name':"noname",'pt1':None,'pt2':None},False,TypeError), ]) @pytest.mark.parametrize("doFlip",[True,False]) -def test_section_stf(get_test_vectors,name,pt1,pt2,maskW,maskS,expArr,doFlip): +def test_section_stf(get_test_ds,args,mask,error,doFlip): """compute streamfunction across section""" - ds = get_test_vectors + ds = get_test_ds grid = ecco_v4_py.get_llc_grid(ds) ds['U'],ds['V'] = get_fake_vectors(ds['U'],ds['V']) ds = ds.rename({'U':'UVELMASS','V':'VVELMASS'}) - if maskW is not None and maskS is not None: - if maskW and maskS: - maskW,maskS = ecco_v4_py.vector_calc.get_latitude_masks(30,ds['YC'],grid) + myargs = args.copy() + if mask: + myargs['maskW'],myargs['maskS'] = ecco_v4_py.vector_calc.get_latitude_masks(30,ds['YC'],grid) + else: + myargs['maskW']=None + myargs['maskS']=None - if expArr is None: - trsp = ecco_v4_py.calc_section_stf(ds, - pt1=pt1,pt2=pt2, - maskW=maskW,maskS=maskS, - section_name=name, - doFlip=doFlip,grid=grid) + if error is None: + trsp = ecco_v4_py.calc_section_stf(ds,doFlip=doFlip,grid=grid,**myargs) - maskW,maskS = ecco_v4_py.calc_section_trsp._parse_section_trsp_inputs(ds, - pt1=pt1,pt2=pt2,maskW=maskW,maskS=maskS, - section_name=name) + maskW,maskS = ecco_v4_py.calc_section_trsp._parse_section_trsp_inputs(ds,**myargs) trspx = (ds['drF']*ds['dyG']*np.abs(maskW)).where(ds['maskW']).sum(dim=['i_g','j','tile']) trspy = (ds['drF']*ds['dxG']*np.abs(maskS)).where(ds['maskS']).sum(dim=['i','j_g','tile']) - test = trsp.psi_moc.reset_coords(drop=True) + test = trsp.psi_moc.squeeze().reset_coords(drop=True) expected = (1e-6*(trspx+trspy)).reset_coords(drop=True) if doFlip: expected = expected.isel(k=slice(None,None,-1)) @@ -95,13 +98,24 @@ def test_section_stf(get_test_vectors,name,pt1,pt2,maskW,maskS,expArr,doFlip): xr.testing.assert_allclose(test,expected) else: - with pytest.raises(expArr): - trsp = ecco_v4_py.calc_section_stf(ds, - pt1=pt1,pt2=pt2, - maskW=maskW,maskS=maskS, - section_name=name, - doFlip=doFlip,grid=grid) - - maskW,maskS = ecco_v4_py.calc_section_trsp._parse_section_trsp_inputs(ds, - pt1=pt1,pt2=pt2,maskW=maskW,maskS=maskS, - section_name=name) + with pytest.raises(error): + trsp = ecco_v4_py.calc_section_stf(ds,**myargs) + +@pytest.mark.parametrize("myfunc, myarg", + [ (ecco_v4_py.calc_meridional_stf, {'lat_vals':10}), + (ecco_v4_py.calc_section_stf,{'section_name':_section})]) +def test_separate_coords(get_test_ds,myfunc,myarg): + ds = get_test_ds + grid = ecco_v4_py.get_llc_grid(ds) + + ds['U'],ds['V'] = get_fake_vectors(ds['U'],ds['V']) + ds = ds.rename({'U':'UVELMASS','V':'VVELMASS'}) + + myarg['grid']=grid + expected = myfunc(ds,**myarg) + coords = ds.coords.to_dataset().reset_coords() + ds = ds.reset_coords(drop=True) + + test = myfunc(ds,coords=coords,**myarg) + xr.testing.assert_allclose(test['psi_moc'].reset_coords(drop=True), + expected['psi_moc'].reset_coords(drop=True)) diff --git a/ecco_v4_py/test/test_vector_calc.py b/ecco_v4_py/test/test_vector_calc.py index e68d04f..8781e7b 100644 --- a/ecco_v4_py/test/test_vector_calc.py +++ b/ecco_v4_py/test/test_vector_calc.py @@ -9,20 +9,20 @@ import pytest from ecco_v4_py import vector_calc, get_llc_grid -from .test_common import llc_mds_datadirs, get_test_vectors +from .test_common import all_mds_datadirs, get_test_ds -def test_no_angles(get_test_vectors): +def test_no_angles(get_test_ds): """quick error handling test""" - ds = get_test_vectors + ds = get_test_ds ds = ds.drop_vars(['CS','SN']) with pytest.raises(KeyError): vector_calc.UEVNfromUXVY(ds['U'],ds['V'],ds) -def test_optional_grid(get_test_vectors): +def test_optional_grid(get_test_ds): """simple, make sure we can optionally provide grid...""" - ds = get_test_vectors + ds = get_test_ds grid = get_llc_grid(ds) uX = xr.ones_like(ds['U'].isel(k=0)).load(); @@ -33,31 +33,34 @@ def test_optional_grid(get_test_vectors): assert (u1==u2).all() assert (v1==v2).all() -def test_uevn_from_uxvy(get_test_vectors): +def test_uevn_from_uxvy(get_test_ds): """make sure grid loc is correct... etc... test by feeding right combo of 1, -1's so that all velocities are positive 1's ... interp should preserve this in the lat/lon components""" - ds = get_test_vectors + ds = get_test_ds uX,vY = get_fake_vectors(ds['U'],ds['V']) uE,vN = vector_calc.UEVNfromUXVY(uX,vY,ds) assert set(('i','j')).issubset(uE.dims) - assert set(('j','j')).issubset(vN.dims) + assert set(('i','j')).issubset(vN.dims) # check the lat/lon tiles - for t in [1,4,8,11]: - assert np.allclose(uE.sel(tile=t),1,atol=1e-15) - assert np.allclose(vN.sel(tile=t),1,atol=1e-15) - -def test_latitude_masks(get_test_vectors): + tilelist = [1,4,8,11] if len(uE.tile)==13 else [0,5] + for t in tilelist: + assert np.allclose(uE.where(ds.maskC,0.).sel(tile=t), + 1.*ds.maskC.sel(tile=t),atol=1e-12) + assert np.allclose(vN.where(ds.maskC,0.).sel(tile=t), + 1.*ds.maskC.sel(tile=t),atol=1e-12) + +def test_latitude_masks(get_test_ds): """run through lats, and ensure we're grabbing the closest "south-grid-cell-location" (whether that's in the x or y direction!) to each latitude value""" - ds = get_test_vectors + ds = get_test_ds grid = get_llc_grid(ds) yW = grid.interp(ds['YC'],'X',boundary='fill') @@ -68,26 +71,29 @@ def test_latitude_masks(get_test_vectors): dLat = 0.5 # is this robust? nx = 90 - for lat in np.arange(-89,89,5): + for lat in np.arange(-89,89,10): print('lat: ',lat) maskW,maskS = vector_calc.get_latitude_masks(lat,ds['YC'],grid) maskW = maskW.where((maskW!=0) & wetW,0.) maskS = maskS.where((maskS!=0) & wetS,0.) - assert not (maskW>0).sel(tile=slice(7,None)).any() - assert not (maskS<0).sel(tile=slice(5)).any() + arctic = int((len(maskW.tile)-1)/2) + assert not (maskW>0).sel(tile=slice(arctic+1,None)).any() + assert not (maskS<0).sel(tile=slice(arctic-1)).any() assert (yW-lat < dLat).where(ds['maskW'].isel(k=0) & (maskW!=0)).all().values assert (yS-lat < dLat).where(ds['maskS'].isel(k=0) & (maskS!=0)).all().values + def get_fake_vectors(fldx,fldy): fldx.load(); fldy.load(); + arctic = int((len(fldx.tile)-1)/2) for t in fldx.tile.values: - if t<6: + if t