diff --git a/src/christoffel.py b/src/christoffel.py index 5e11be3..fb74b67 100644 --- a/src/christoffel.py +++ b/src/christoffel.py @@ -5,7 +5,7 @@ # # # Filename: christoffel.py # # Description: This module calculates phase and group seismic velocities # -# in solids. # +# in solids based on the Christoffel equation. # # # # Copyright (c) 2023-Present # # # @@ -38,7 +38,10 @@ # Function definitions def christoffel_wave_speeds( - Cij: np.ndarray, density: float, wavevectors: np.ndarray, type="phase" + Cij: np.ndarray, + density_gcm3: float, + wavevectors: np.ndarray, + type="phase" ): """_summary_ @@ -67,22 +70,12 @@ def christoffel_wave_speeds( """ # Sanity checks on inputs - # Check if Cij is a 6x6 symmetric matrix - if not isinstance(Cij, np.ndarray) or Cij.shape != (6, 6): - raise ValueError("Cij should be a 6x6 NumPy array.") - if not np.allclose(Cij, Cij.T): - raise ValueError("Cij should be symmetric.") - # validate wavevectors - if not isinstance(wavevectors, np.ndarray): - raise ValueError("wavevectors should be a NumPy array.") + validate_cijs(Cij) + validate_wavevectors(wavevectors) + + # rearrange in case wavevectors have a shape (3,) if wavevectors.ndim == 1 and wavevectors.shape[0] == 3: wavevectors = wavevectors.reshape(1, 3) - elif wavevectors.ndim == 2 and wavevectors.shape[1] == 3: - pass - else: - raise ValueError( - "wavevectors should be a NumPy array of shape (3,) if 1D or (n, 3) if 2D." - ) # rearrange tensor Cij → Cijkl Cijkl = _rearrange_tensor(Cij) @@ -90,7 +83,7 @@ def christoffel_wave_speeds( # estimate the normalized Christoffel matrix (M) for # every wavevector Mij = _christoffel_matrix(wavevectors, Cijkl) - scaling_factor = 1 / density + scaling_factor = 1 / density_gcm3 norm_Mij = Mij * scaling_factor # estimate the eigenvalues and eigenvectors @@ -450,7 +443,7 @@ def calc_spherical_angles(group_directions: np.ndarray) -> np.ndarray: x, z = group_directions[ori, : 0], group_directions[ori, : 2] # handle edge cases for z near ±1 - near_pole + # near_pole # TODO (UNTESTED!) @@ -600,4 +593,55 @@ def _calc_enhancement_factor(Hλ): pass +############################################################################################ +def validate_cijs(Cij) -> bool: + """Input validation + + Parameters + ---------- + Cij : np.ndarray + The input array to validate. + + Returns + ------- + bool : True if the array is of the correct shape and type, + otherwise raises a ValueError. + """ + if not isinstance(Cij, np.ndarray) or Cij.shape != (6, 6): + raise ValueError("Cij should be a 6x6 NumPy array.") + + if Cij is not np.allclose(Cij, Cij.T): + raise ValueError("Cij should be symmetric.") + + return True + + +def validate_wavevectors(wavevectors) -> bool: + """Input validation + + Parameters + ---------- + wavevectors : np.ndarray of shape (3,) or (n, 3) + The input array to validate. + + Returns + ------- + bool : True if the array is of the correct shape and type, + otherwise raises a ValueError. + """ + if not isinstance(wavevectors, np.ndarray): + raise ValueError("Input must be a NumPy array.") + + if wavevectors.ndim not in [1, 2]: + raise ValueError("Input array must be 1-dimensional or 2-dimensional.") + + if wavevectors.ndim == 1 and wavevectors.shape != (3,): + raise ValueError("1-dimensional array must have shape (3,).") + + if wavevectors.ndim == 2 and wavevectors.shape[1] != 3: + raise ValueError("2-dimensional array must have shape (n, 3).") + + return True + + # End of file