Skip to content

Commit

Permalink
wip: cleaner reader
Browse files Browse the repository at this point in the history
  • Loading branch information
jourdain committed Jan 24, 2025
1 parent 85eae17 commit 891cef4
Show file tree
Hide file tree
Showing 4 changed files with 194 additions and 70 deletions.
11 changes: 9 additions & 2 deletions pan3d/ui/preview.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from pan3d.ui.css import base, preview
from pan3d.ui.collapsible import CollapsableSection
from pan3d.xarray.cf.constants import Projection


class SummaryToolbar(v3.VCard):
Expand Down Expand Up @@ -732,7 +733,9 @@ def update_from_source(self, source=None):
self.state.axis_names = [source.x, source.y, source.z]
self.state.slice_extents = source.slice_extents
self.state.projection_mode = (
"spherical" if source.spherical else "euclidean"
"spherical"
if source.projection == Projection.SPHERICAL
else "euclidean"
)
self.state.spherical_bias = source.vertical_bias
self.state.spherical_scale = source.vertical_scale
Expand Down Expand Up @@ -865,7 +868,11 @@ def _on_array_selection(self, data_arrays, **_):
def _on_projection_change(
self, spherical_bias, spherical_scale, projection_mode, **_
):
self.source.spherical = projection_mode == "spherical"
self.source.projection = (
Projection.SPHERICAL
if projection_mode == "spherical"
else Projection.EUCLIDEAN
)
self.source.vertical_bias = spherical_bias
self.source.vertical_scale = spherical_scale

Expand Down
5 changes: 5 additions & 0 deletions pan3d/xarray/cf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,11 @@ def __repr__(self):
"""


class Projection(Enum):
SPHERICAL = "Spherical"
EUCLIDEAN = "Euclidean"


class Scale(Enum):
da = (1e1, {"deca", "deka"})
h = (1e2, {"hecto"})
Expand Down
122 changes: 107 additions & 15 deletions pan3d/xarray/cf/coords/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,13 @@
import numpy as np
from pan3d.xarray.cf import mesh
from pan3d.xarray.cf.coords.convert import is_uniform
from pan3d.xarray.cf.constants import Projection
from vtkmodules.vtkCommonDataModel import (
vtkImageData,
vtkRectilinearGrid,
vtkStructuredGrid,
vtkUnstructuredGrid,
)

PRESSURE_UNITS = {
"bar",
Expand All @@ -21,7 +28,6 @@
"kilometer",
"km",
}

COORDINATES_DETECTION = {
"longitude": {
"units": {
Expand Down Expand Up @@ -393,18 +399,104 @@ def use_coords(self, dims):
return False
return True

def get_mesh(self, time_index=0, spherical=True, fields=None):
if self.xr_dataset is None or not fields:
return None
def compatible_fields(self, fields=None):
if not fields:
return []
data_dims = self.xr_dataset[fields[0]].dims
return [n for n in fields if self.xr_dataset[n].dims == data_dims]

def dimensions(self, field):
return self.xr_dataset[field].dims

def timeless_dimensions(self, field):
dims = self.xr_dataset[field].dims
return dims[1:] if dims[0] == self.time else dims

def field_extent(self, field):
extent = [0, 0, 0, 0, 0, 0]
dimensions = self.timeless_dimensions(field)
for idx in range(len(dimensions)):
array = self.xr_dataset[dimensions[-(1 + idx)]]
# Fill in reverse order (t, z, y, x) => [0, x.size, 0, y.size, 0, z.size]
# And extent include both index so (len-1)
extent[idx * 2 + 1] = array.size - 1

return extent

def get_vtk_mesh_type(self, projection, fields=None):
fields = self.compatible_fields(fields)

if self.longitude is None or self.latitude is None or not fields:
# default empty mesh
return vtkImageData()

# unstructured
timeless_dims = self.timeless_dimensions(fields[0])
if len(timeless_dims) == 1:
return vtkUnstructuredGrid()

# structured
if (
self.coords_has_bounds
or projection == Projection.SPHERICAL
or not self.coords_1d
):
return vtkStructuredGrid()

# rectilinear
if not self.uniform_spacing:
return vtkRectilinearGrid()

# imagedata
return vtkImageData()

def get_vtk_whole_extent(self, projection, fields=None):
if self.longitude is None or self.latitude is None or not fields:
return [
0,
0,
0,
0,
0,
0,
]

mesh_type = self.get_vtk_mesh_type(projection, fields)
fields = self.compatible_fields(fields)
extent = self.field_extent(fields[0])
dimensions = self.timeless_dimensions(fields[0])

print(f"before {extent=}")
print(f"class {mesh_type.GetClassName()}")

if mesh_type.IsA("vtkStructuredGrid") and not (
self.uniform_lat_lon and self.use_coords(dimensions)
):
# point data
return extent

# cell data, need to +1 on the extent
for i in range(3):
if extent[i * 2 + 1] > 0:
extent[i * 2 + 1] += 1

print(f"after {extent=}")

return extent

def get_vtk_mesh(self, time_index=0, projection=None, fields=None):
vtk_mesh, data_location = None, None
if self.xr_dataset is None or not fields:
return vtk_mesh

# resolve projection
if projection is None:
projection = Projection.SPHERICAL
spherical_proj = projection == Projection.SPHERICAL

# ensure similar dimension across array names
data_dims = self.xr_dataset[fields[0]].dims
data_dims_no_time = data_dims[1:] if data_dims[0] == self.time else data_dims
valid_data_array_names = [
n for n in fields if self.xr_dataset[n].dims == data_dims
]
fields = self.compatible_fields(fields)
data_dims_no_time = self.timeless_dimensions(fields[0])

# No mesh if no lon/lat
if self.longitude is None or self.latitude is None:
Expand All @@ -413,20 +505,20 @@ def get_mesh(self, time_index=0, spherical=True, fields=None):
# Unstructured
if len(data_dims_no_time) == 1:
vtk_mesh, data_location = mesh.unstructured.generate_mesh(
self, data_dims_no_time, time_index, spherical
self, data_dims_no_time, time_index, spherical_proj
)

# Structured
if vtk_mesh is None and (
self.coords_has_bounds or spherical or not self.coords_1d
self.coords_has_bounds or spherical_proj or not self.coords_1d
):
vtk_mesh, data_location = mesh.structured.generate_mesh(
self, data_dims_no_time, time_index, spherical
self, data_dims_no_time, time_index, spherical_proj
)

# This should only happen if we don't want spherical
# This should only happen if we don't want spherical_proj
if vtk_mesh is None:
assert not spherical
assert not spherical_proj

# Rectilinear
if vtk_mesh is None and not self.uniform_spacing:
Expand All @@ -443,7 +535,7 @@ def get_mesh(self, time_index=0, spherical=True, fields=None):
# Add fields
if vtk_mesh:
container = getattr(vtk_mesh, data_location)
for field_name in valid_data_array_names:
for field_name in fields:
field = (
self.xr_dataset[field_name][time_index].values
if self.time
Expand Down
Loading

0 comments on commit 891cef4

Please sign in to comment.