Skip to content

Commit

Permalink
Added transform and test (#37)
Browse files Browse the repository at this point in the history
  • Loading branch information
George G Vega Yon authored Mar 21, 2024
1 parent 7135843 commit 22ae4c7
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 2 deletions.
29 changes: 29 additions & 0 deletions model/src/pyrenew/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,32 @@ def transform(self, x):

def inverse(self, x):
return jax.scipy.special.expit(x)


class ScaledLogitTransform(AbstractTransform):
"""
Scaled logistic transformation from the
interval (0, X_max) to the interval
(-infinity, +infinity).
It's inverse is the inverse logit or
'expit' function multiplied by X_max
f(x) = log(x/X_max) - log(1 - x/X_max)
f^-1(x) = X_max / (1 + exp(-x))
"""

def __init__(self, x_max: float):
"""
Default constructor
Parameters
----------
x_max : float
Maximum value on the untransformed scale
(will be transformed to +infinity)
"""
self.x_max = x_max

def transform(self, x):
return jax.scipy.special.logit(x / self.x_max)

def inverse(self, x):
return self.x_max * jax.scipy.special.expit(x)
13 changes: 11 additions & 2 deletions model/src/test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from numpy.testing import assert_array_almost_equal


def generic_inversion_test(transform, test_vals, decimal=1e-8):
def generic_inversion_test(transform, test_vals, decimal=1e-8, **kwargs):
"""
Generic test for inverting a
pyrenew transform, confirming
Expand All @@ -29,8 +29,12 @@ def generic_inversion_test(transform, test_vals, decimal=1e-8):
decimal : float
Decimal tolerance, passed to
numpy.testing.assert_array_almost_equal()
**kwargs :
Additional keyword arguments passed
to the transform constructor
"""
instantiated = transform()
instantiated = transform(**kwargs)

assert_array_almost_equal(
test_vals,
Expand All @@ -51,6 +55,11 @@ def test_invert_dists():
generic_inversion_test(
t.LogitTransform, jnp.array([0.99235, 0.13242, 0.5, 0.235, 0.862])
)
generic_inversion_test(
t.ScaledLogitTransform,
50 * jnp.array([0.99235, 0.13242, 0.5, 0.235, 0.862]),
x_max=50,
)
generic_inversion_test(
t.IdentityTransform, jnp.array([0.99235, 0.13242, 0.5, 0.235, 0.862])
)

0 comments on commit 22ae4c7

Please sign in to comment.