Skip to content

Commit

Permalink
add input validation functions
Browse files Browse the repository at this point in the history
  • Loading branch information
marcoalopez committed Jan 30, 2025
1 parent df68f7e commit 62c0bc1
Showing 1 changed file with 62 additions and 18 deletions.
80 changes: 62 additions & 18 deletions src/christoffel.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# #
# Filename: christoffel.py #
# Description: This module calculates phase and group seismic velocities #
# in solids. #
# in solids based on the Christoffel equation. #
# #
# Copyright (c) 2023-Present #
# #
Expand Down Expand Up @@ -38,7 +38,10 @@

# Function definitions
def christoffel_wave_speeds(
Cij: np.ndarray, density: float, wavevectors: np.ndarray, type="phase"
Cij: np.ndarray,
density_gcm3: float,
wavevectors: np.ndarray,
type="phase"
):
"""_summary_
Expand Down Expand Up @@ -67,30 +70,20 @@ def christoffel_wave_speeds(
"""

# Sanity checks on inputs
# Check if Cij is a 6x6 symmetric matrix
if not isinstance(Cij, np.ndarray) or Cij.shape != (6, 6):
raise ValueError("Cij should be a 6x6 NumPy array.")
if not np.allclose(Cij, Cij.T):
raise ValueError("Cij should be symmetric.")
# validate wavevectors
if not isinstance(wavevectors, np.ndarray):
raise ValueError("wavevectors should be a NumPy array.")
validate_cijs(Cij)
validate_wavevectors(wavevectors)

# rearrange in case wavevectors have a shape (3,)
if wavevectors.ndim == 1 and wavevectors.shape[0] == 3:
wavevectors = wavevectors.reshape(1, 3)
elif wavevectors.ndim == 2 and wavevectors.shape[1] == 3:
pass
else:
raise ValueError(
"wavevectors should be a NumPy array of shape (3,) if 1D or (n, 3) if 2D."
)

# rearrange tensor Cij → Cijkl
Cijkl = _rearrange_tensor(Cij)

# estimate the normalized Christoffel matrix (M) for
# every wavevector
Mij = _christoffel_matrix(wavevectors, Cijkl)
scaling_factor = 1 / density
scaling_factor = 1 / density_gcm3
norm_Mij = Mij * scaling_factor

# estimate the eigenvalues and eigenvectors
Expand Down Expand Up @@ -450,7 +443,7 @@ def calc_spherical_angles(group_directions: np.ndarray) -> np.ndarray:
x, z = group_directions[ori, : 0], group_directions[ori, : 2]

# handle edge cases for z near ±1
near_pole
# near_pole


# TODO (UNTESTED!)
Expand Down Expand Up @@ -600,4 +593,55 @@ def _calc_enhancement_factor(Hλ):
pass


############################################################################################
def validate_cijs(Cij) -> bool:
"""Input validation
Parameters
----------
Cij : np.ndarray
The input array to validate.
Returns
-------
bool : True if the array is of the correct shape and type,
otherwise raises a ValueError.
"""
if not isinstance(Cij, np.ndarray) or Cij.shape != (6, 6):
raise ValueError("Cij should be a 6x6 NumPy array.")

if Cij is not np.allclose(Cij, Cij.T):
raise ValueError("Cij should be symmetric.")

return True


def validate_wavevectors(wavevectors) -> bool:
"""Input validation
Parameters
----------
wavevectors : np.ndarray of shape (3,) or (n, 3)
The input array to validate.
Returns
-------
bool : True if the array is of the correct shape and type,
otherwise raises a ValueError.
"""
if not isinstance(wavevectors, np.ndarray):
raise ValueError("Input must be a NumPy array.")

if wavevectors.ndim not in [1, 2]:
raise ValueError("Input array must be 1-dimensional or 2-dimensional.")

if wavevectors.ndim == 1 and wavevectors.shape != (3,):
raise ValueError("1-dimensional array must have shape (3,).")

if wavevectors.ndim == 2 and wavevectors.shape[1] != 3:
raise ValueError("2-dimensional array must have shape (n, 3).")

return True


# End of file

0 comments on commit 62c0bc1

Please sign in to comment.