Skip to content

Commit

Permalink
Add in TOX dev tools and testing
Browse files Browse the repository at this point in the history
  • Loading branch information
whalenpt committed Feb 25, 2023
1 parent 6b59eb1 commit 0fab433
Show file tree
Hide file tree
Showing 17 changed files with 2,036 additions and 1,435 deletions.
17 changes: 15 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ classifiers=[
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
]
Expand All @@ -34,8 +36,19 @@ demo = [
"matplotlib",
"jupyter"
]
dev = ["twine"]
test = ["pytest"]
dev = [
"flake8",
"mccabe",
"mypy",
"pylint",
"twine"
]

test = [
"coverage",
"pytest",
"tox"
]

[project.urls]
homepage = "https://github.com/whalenpt/rkstiff"
Expand Down
2 changes: 1 addition & 1 deletion rkstiff/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
try:
import importlib_metadata as metadata
except:
except ModuleNotFoundError:
import importlib.metadata as metadata

__version__ = metadata.version("rkstiff")
34 changes: 16 additions & 18 deletions rkstiff/derivatives.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,33 @@

import numpy as np

def dx_rfft(kx,u,n=1):
""" Takes derivative(s) of a real valued array in spectral space

def dx_rfft(kx, u, n=1):
"""Takes derivative(s) of a real valued array in spectral space
INPUTS
kx - wavenumbers of the spectral grid
u - real valued array
n - order of derivative (default is 1)
OUTPUTS
uxn - derivative of u to the nth power
"""
if not isinstance(n,int):
raise TypeError('derivative order n must be an integer, it is {}'.format(n))
if not isinstance(n, int):
raise TypeError("derivative order n must be an integer, it is {}".format(n))
if n < 0:
raise ValueError('derivative order n must non-negative, it is {}'.format(n))
raise ValueError("derivative order n must non-negative, it is {}".format(n))

if n == 0:
return u

uFFT = np.fft.rfft(u)
if n == 1:
uxn = np.fft.irfft(1j*kx*uFFT)
uxn = np.fft.irfft(1j * kx * uFFT)
else:
uxn = np.fft.irfft(np.power(1j*kx,n)*uFFT)
uxn = np.fft.irfft(np.power(1j * kx, n) * uFFT)
return uxn

def dx_fft(kx,u,n=1):
""" Takes derivative(s) of a complex valued array in spectral space

def dx_fft(kx, u, n=1):
"""Takes derivative(s) of a complex valued array in spectral space
INPUTS
kx - wavenumbers of the spectral grid
u - complex valued array
Expand All @@ -35,19 +36,16 @@ def dx_fft(kx,u,n=1):
uxn - derivative of u to the nth power
"""

if not isinstance(n,int):
raise TypeError('derivative order n must be an integer, it is {}'.format(n))
if not isinstance(n, int):
raise TypeError("derivative order n must be an integer, it is {}".format(n))
if n < 0:
raise ValueError('derivative order n must non-negative, it is {}'.format(n))
raise ValueError("derivative order n must non-negative, it is {}".format(n))
if n == 0:
return u

uFFT = np.fft.fft(u)
if n == 1:
uxn = np.fft.ifft(1j*kx*uFFT)
uxn = np.fft.ifft(1j * kx * uFFT)
else:
uxn = np.fft.ifft(np.power(1j*kx,n)*uFFT)
uxn = np.fft.ifft(np.power(1j * kx, n) * uFFT)
return uxn



118 changes: 85 additions & 33 deletions rkstiff/etd.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,36 @@

import numpy as np
from rkstiff.solver import StiffSolverAS,StiffSolverCS
from rkstiff.solver import StiffSolverAS, StiffSolverCS


def phi1(z):
""" Computes RKETD psi-function of the first order.
"""Computes RKETD psi-function of the first order.
INPUTS
z - real or complex-valued input array
z - real or complex-valued input array
OUTPUT
return (exp(z)-1)/z -> real or complex-valued RKETD function of first order
"""
return (np.exp(z)-1)/z
return (np.exp(z) - 1) / z


def phi2(z):
""" Computes RKETD psi-function of the second order.
"""Computes RKETD psi-function of the second order.
INPUTS
z - real or complex-valued input array
z - real or complex-valued input array
OUTPUT
return 2!*(exp(z)-1-z/2)/z^2 -> real or complex-valued RKETD function of second order
"""
return 2*(np.exp(z)-1-z)/z**2
return 2 * (np.exp(z) - 1 - z) / z**2


def phi3(z):
""" Computes RKETD psi-function of the third order.
"""Computes RKETD psi-function of the third order.
INPUTS
z - real or complex-valued input array
z - real or complex-valued input array
OUTPUT
return 3!*(exp(z)-1-z/2-z^3/6)/z^3 -> real or complex-valued RKETD function of third order
"""
return 6*(np.exp(z)-1-z-z**2/2)/z**3
return 6 * (np.exp(z) - 1 - z - z**2 / 2) / z**3


class ETDAS(StiffSolverAS):
"""
Expand All @@ -47,25 +50,60 @@ class ETDAS(StiffSolverAS):
"""

def __init__(self,linop,NLfunc,epsilon=1e-4,incrF = 1.25, decrF = 0.85, safetyF = 0.8,\
adapt_cutoff = 0.01, minh = 1e-16, modecutoff = 0.01, contour_points = 32,\
contour_radius = 1.0):
super().__init__(linop,NLfunc,epsilon=epsilon,incrF=incrF,decrF=decrF,\
safetyF=safetyF,adapt_cutoff=adapt_cutoff,minh=minh)
def __init__(
self,
linop,
NLfunc,
epsilon=1e-4,
incrF=1.25,
decrF=0.85,
safetyF=0.8,
adapt_cutoff=0.01,
minh=1e-16,
modecutoff=0.01,
contour_points=32,
contour_radius=1.0,
):
super().__init__(
linop,
NLfunc,
epsilon=epsilon,
incrF=incrF,
decrF=decrF,
safetyF=safetyF,
adapt_cutoff=adapt_cutoff,
minh=minh,
)
self.modecutoff = modecutoff
if (self.modecutoff > 1.0) or (self.modecutoff <= 0):
raise ValueError('modecutoff must be between 0.0 and 1.0 but is {}'.format(self.modecutoff))
raise ValueError(
"modecutoff must be between 0.0 and 1.0 but is {}".format(
self.modecutoff
)
)
self.contour_points = contour_points
if not isinstance(self.contour_points,int):
raise TypeError('contour_points must be an integer but is {}'.format(self.contour_points))
if not isinstance(self.contour_points, int):
raise TypeError(
"contour_points must be an integer but is {}".format(
self.contour_points
)
)
if self.contour_points <= 1:
raise ValueError('contour_points must be an integer greater than 1 but is {}'.format(self.contour_points))
raise ValueError(
"contour_points must be an integer greater than 1 but is {}".format(
self.contour_points
)
)

self.contour_radius = contour_radius
if self.contour_radius <= 0:
raise ValueError('contour_radius must greater than 0 but is {}'.format(self.contour_radius))
raise ValueError(
"contour_radius must greater than 0 but is {}".format(
self.contour_radius
)
)
self._h_coeff = None


class ETDCS(StiffSolverCS):
"""
Expand All @@ -85,21 +123,35 @@ class ETDCS(StiffSolverCS):
"""

def __init__(self,linop,NLfunc,modecutoff = 0.01,contour_points = 32,contour_radius = 1.0):
super().__init__(linop,NLfunc)
def __init__(
self, linop, NLfunc, modecutoff=0.01, contour_points=32, contour_radius=1.0
):
super().__init__(linop, NLfunc)
self.modecutoff = modecutoff
if (self.modecutoff > 1.0) or (self.modecutoff <= 0):
raise ValueError('modecutoff must be between 0.0 and 1.0 but is {}'.format(self.modecutoff))
raise ValueError(
"modecutoff must be between 0.0 and 1.0 but is {}".format(
self.modecutoff
)
)
self.contour_points = contour_points
if not isinstance(self.contour_points,int):
raise TypeError('contour_points must be an integer but is {}'.format(self.contour_points))
if not isinstance(self.contour_points, int):
raise TypeError(
"contour_points must be an integer but is {}".format(
self.contour_points
)
)
if self.contour_points <= 1:
raise ValueError('contour_points must be an integer greater than 1 but is {}'.format(self.contour_points))
raise ValueError(
"contour_points must be an integer greater than 1 but is {}".format(
self.contour_points
)
)
self.contour_radius = contour_radius
if self.contour_radius <= 0:
raise ValueError('contour_radius must greater than 0 but is {}'.format(self.contour_radius))
raise ValueError(
"contour_radius must greater than 0 but is {}".format(
self.contour_radius
)
)
self._h_coeff = None




Loading

0 comments on commit 0fab433

Please sign in to comment.