Skip to content

Commit

Permalink
Merge branch 'main' into fix-typo
Browse files Browse the repository at this point in the history
  • Loading branch information
jcapriot authored Dec 17, 2024
2 parents 12c3c4a + 9f10381 commit 06a0d47
Show file tree
Hide file tree
Showing 7 changed files with 36 additions and 8 deletions.
2 changes: 1 addition & 1 deletion discretize/mixins/mesh_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,7 @@ def read_UBC(TreeMesh, file_name, directory=None):
else:
max_level = min(ls) + 1

mesh = TreeMesh(hs, origin=origin)
mesh = TreeMesh(hs, origin=origin, diagonal_balance=False)
levels = indArr[:, -1]
indArr = indArr[:, :-1]

Expand Down
2 changes: 1 addition & 1 deletion discretize/mixins/mpl_mod.py
Original file line number Diff line number Diff line change
Expand Up @@ -2149,7 +2149,7 @@ def __plot_slice_tree(
normal[normalInd] = 1

# create a temporary TreeMesh with the slice through
temp_mesh = discretize.TreeMesh(h2d, x2d)
temp_mesh = discretize.TreeMesh(h2d, x2d, diagonal_balance=False)
level_diff = self.max_level - temp_mesh.max_level

# get list of cells which intersect the slicing plane
Expand Down
2 changes: 1 addition & 1 deletion discretize/tree_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -1071,7 +1071,7 @@ def equals(self, other): # NOQA D102

def __reduce__(self):
"""Return the necessary items to reconstruct this object's state."""
return TreeMesh, (self.h, self.origin), self.__getstate__()
return TreeMesh, (self.h, self.origin, False), self.__getstate__()

cellGrad = deprecate_property(
"cell_gradient", "cellGrad", removal_version="1.0.0", error=True
Expand Down
6 changes: 5 additions & 1 deletion discretize/utils/code_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,11 @@ def is_scalar(f):
"""
if isinstance(f, SCALARTYPES):
return True
elif isinstance(f, np.ndarray) and f.size == 1 and isinstance(f[0], SCALARTYPES):
elif (
isinstance(f, np.ndarray)
and f.size == 1
and isinstance(f.reshape(-1)[0], SCALARTYPES)
):
return True
return False

Expand Down
6 changes: 5 additions & 1 deletion discretize/utils/mesh_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,7 @@ def mesh_builder_xyz(
depth_core=None,
expansion_factor=1.3,
mesh_type="tensor",
tree_diagonal_balance=None,
):
"""Generate a tensor or tree mesh using a cloud of points.
Expand Down Expand Up @@ -428,6 +429,9 @@ def mesh_builder_xyz(
Expansion factor for padding cells. Ignored if *mesh_type* = *tree*
mesh_type : {'tensor', 'tree'}
Specify output mesh type
tree_diagonal_balance : bool, optional
Whether to diagonally balance the tree mesh, `None` will use the `TreeMesh`
default behavoir.
Returns
-------
Expand Down Expand Up @@ -521,7 +525,7 @@ def expand(dx, pad):
h_dim += [np.ones(2**maxLevel) * h[ii]]

# Define the mesh and origin
mesh = discretize.TreeMesh(h_dim)
mesh = discretize.TreeMesh(h_dim, diagonal_balance=tree_diagonal_balance)

for ii, _cc in enumerate(nC):
core = limits[ii][0] - limits[ii][1]
Expand Down
12 changes: 9 additions & 3 deletions tests/base/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,9 +238,15 @@ def test_is_scalar(self):
self.assertTrue(is_scalar(1.0))
self.assertTrue(is_scalar(1))
self.assertTrue(is_scalar(1j))
self.assertTrue(is_scalar(np.r_[1.0]))
self.assertTrue(is_scalar(np.r_[1]))
self.assertTrue(is_scalar(np.r_[1j]))
self.assertTrue(is_scalar(np.array(1.0)))
self.assertTrue(is_scalar(np.array(1)))
self.assertTrue(is_scalar(np.array(1j)))
self.assertTrue(is_scalar(np.array([1.0])))
self.assertTrue(is_scalar(np.array([1])))
self.assertTrue(is_scalar(np.array([1j])))
self.assertTrue(is_scalar(np.array([[1.0]])))
self.assertTrue(is_scalar(np.array([[1]])))
self.assertTrue(is_scalar(np.array([[1j]])))

def test_as_array_n_by_dim(self):
true = np.array([[1, 2, 3]])
Expand Down
14 changes: 14 additions & 0 deletions tests/tree/test_tree_io.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
import numpy as np
import discretize
import pickle
Expand Down Expand Up @@ -49,6 +50,19 @@ def test_UBCfiles(mesh, tmp_path):
np.testing.assert_array_equal(vec, vecUBC2)


def test_ubc_files_no_warning_diagonal_balance(mesh, tmp_path):
"""
Test that reading UBC files don't trigger the diagonal balance warning.
"""
# Save the sample mesh into a UBC file
fname = tmp_path / "temp.msh"
mesh.write_UBC(fname)
# Make sure that no warning is raised when reading the mesh
with warnings.catch_warnings():
warnings.simplefilter("error")
discretize.TreeMesh.read_UBC(fname)


if has_vtk:

def test_write_VTU_files(mesh, tmp_path):
Expand Down

0 comments on commit 06a0d47

Please sign in to comment.