Skip to content

Commit

Permalink
Hhh 118 update interpolation method in waverespons as todays method i…
Browse files Browse the repository at this point in the history
…s deprecated (#65)
  • Loading branch information
helene-pisani-4ss authored Jul 4, 2024
1 parent d0bd0d8 commit f45955b
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 12 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ classifiers = [
"Operating System :: Microsoft :: Windows",
]
dependencies = [
"numpy<2.0.0",
"numpy",
"pandas",
"scipy<1.14.0",
"scipy",
"pyarrow"
]

Expand Down
23 changes: 14 additions & 9 deletions src/waveresponse/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import numpy as np
from scipy.integrate import trapezoid
from scipy.interpolate import interp2d
from scipy.interpolate import RegularGridInterpolator as RGI
from scipy.special import gamma


Expand Down Expand Up @@ -601,11 +601,12 @@ def rotate(self, angle, degrees=False):

def _interpolate_function(self, complex_convert="rectangular", **kw):
"""
Interpolation function based on ``scipy.interpolate.interp2d``.
Interpolation function based on ``scipy.interpolate.RegularGridInterpolator``.
"""
xp = np.concatenate(
(self._dirs[-1:] - 2 * np.pi, self._dirs, self._dirs[:1] + 2.0 * np.pi)
)

yp = self._freq
zp = np.concatenate(
(
Expand All @@ -617,11 +618,11 @@ def _interpolate_function(self, complex_convert="rectangular", **kw):
)

if np.all(np.isreal(zp)):
return interp2d(xp, yp, zp, **kw)
return RGI((xp, yp), zp.T, **kw)
elif complex_convert.lower() == "polar":
amp, phase = complex_to_polar(zp, phase_degrees=False)
interp_amp = interp2d(xp, yp, amp, **kw)
interp_phase = interp2d(xp, yp, phase, **kw)
interp_amp = RGI((xp, yp), amp.T, **kw)
interp_phase = RGI((xp, yp), phase.T, **kw)
return lambda *args, **kwargs: (
polar_to_complex(
interp_amp(*args, **kwargs),
Expand All @@ -630,8 +631,8 @@ def _interpolate_function(self, complex_convert="rectangular", **kw):
)
)
elif complex_convert.lower() == "rectangular":
interp_real = interp2d(xp, yp, np.real(zp), **kw)
interp_imag = interp2d(xp, yp, np.imag(zp), **kw)
interp_real = RGI((xp, yp), np.real(zp.T), **kw)
interp_imag = RGI((xp, yp), np.imag(zp.T), **kw)
return lambda *args, **kwargs: (
interp_real(*args, **kwargs) + 1j * interp_imag(*args, **kwargs)
)
Expand Down Expand Up @@ -702,10 +703,14 @@ def interpolate(
self._check_dirs(dirs)

interp_fun = self._interpolate_function(
complex_convert=complex_convert, kind="linear", fill_value=fill_value
complex_convert=complex_convert,
method="linear",
bounds_error=False,
fill_value=fill_value,
)

return interp_fun(dirs, freq, assume_sorted=True)
dirsnew, freqnew = np.meshgrid(dirs, freq, indexing="ij", sparse=True)
return interp_fun((dirsnew, freqnew)).T

def reshape(
self,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1400,7 +1400,7 @@ def test_interpolate_single_coordinate(self):

vals_out = grid.interpolate(1.8, 12.1, freq_hz=True, degrees=True)

vals_expect = np.array([a * 12.1 + b * 1.8])
vals_expect = np.array(a * 12.1 + b * 1.8)

np.testing.assert_array_almost_equal(vals_out, vals_expect)

Expand Down

0 comments on commit f45955b

Please sign in to comment.