Skip to content

Commit

Permalink
Revert pywavelets code changes and constrain scipy <1.15
Browse files Browse the repository at this point in the history
  • Loading branch information
tomvothecoder committed Jan 15, 2025
1 parent dd19a17 commit 01a91e4
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 18 deletions.
3 changes: 1 addition & 2 deletions conda-env/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ dependencies:
- netcdf4
- output_viewer >=1.3.0
- numpy >=2.0.0,<3.0.0
- pywavelets
- scipy
- scipy <1.15
- shapely >=2.0.0,<3.0.0
- xarray >=2024.3.0
- xcdat >=0.7.3,<1.0.0
Expand Down
3 changes: 1 addition & 2 deletions conda-env/dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@ dependencies:
- netcdf4
- output_viewer >=1.3.0
- numpy >=2.0.0,<3.0.0
- pywavelets
- scipy
- scipy <1.15
- shapely >=2.0.0,<3.0.0
- xarray >=2024.3.0
- xcdat >=0.7.3,<1.0.0
Expand Down
13 changes: 1 addition & 12 deletions e3sm_diags/driver/qbo_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from typing import TYPE_CHECKING, Dict, Literal, Tuple, TypedDict

import numpy as np
import pywt
import scipy.fftpack
import xarray as xr
import xcdat as xc
Expand Down Expand Up @@ -410,23 +409,13 @@ def _get_psd_from_wavelet(data: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
-------
Tuple(np.ndarray, np.ndarray)
The period and PSD arrays.
Notes
-----
- https://pywavelets.readthedocs.io/en/latest/ref/cwt.html#complex-morlet-wavelets
"""
deg = 6
period = np.arange(1, 55 + 1)
freq = 1 / period

# FIXME: Where do we pass in the `deg` argument for Omega0?
# scipy.signal.mortlet2() has a `deg` argument, but not
# pywt.ContinuousWavelet().
# Complex Morlet Wavelets ("cmorB-C")
wavelet_obj = pywt.ContinuousWavelet("cmorB-C", w=deg)
widths = deg / (2 * np.pi * freq)

cwtmatr, _ = pywt.cwt(data, widths, wavelet_obj)
cwtmatr = scipy.signal.cwt(data, scipy.signal.morlet2, widths=widths, w=deg)
psd = np.mean(np.square(np.abs(cwtmatr)), axis=1)

return (period, psd)
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,7 @@ dependencies = [
"netcdf4",
"output_viewer >=1.3.0",
"numpy >=2.0.0,<3.0.0",
"pywavelets",
"scipy",
"scipy <1.15",
"shapely >=2.0.0,<3.0.0",
"xarray >=2024.03.0",
"xcdat >=0.7.3,<1.0.0",
Expand Down

0 comments on commit 01a91e4

Please sign in to comment.