From a71c51a82bdf9fd58124acfcb73111fc118f57fa Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Mon, 7 Feb 2022 13:52:37 -0500 Subject: [PATCH] Adding distance metric to stationary kernels (#35) * adding distance metric to stationary kernels * switching to new interface for stationary kernels * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * updating multivariate tutorial * adding some docstrings * adding custom geometry tutorial * reduce memory usage in geometry tutorial * tweaking geometry tutorial labels * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- docs/api.rst | 23 ++ docs/conf.py | 1 + docs/tutorials.md | 1 + docs/tutorials/geometry.ipynb | 263 +++++++++++++++++++ docs/tutorials/mixture.ipynb | 4 +- docs/tutorials/multivariate.ipynb | 16 +- src/tinygp/__init__.py | 6 +- src/tinygp/gp.py | 4 +- src/tinygp/kernels/__init__.py | 45 ++++ src/tinygp/{kernels.py => kernels/base.py} | 182 +------------- src/tinygp/kernels/stationary.py | 280 +++++++++++++++++++++ src/tinygp/transforms.py | 4 +- tests/test_kernels.py | 7 - tests/test_transforms.py | 14 +- 14 files changed, 633 insertions(+), 217 deletions(-) create mode 100644 docs/tutorials/geometry.ipynb create mode 100644 src/tinygp/kernels/__init__.py rename src/tinygp/{kernels.py => kernels/base.py} (59%) create mode 100644 src/tinygp/kernels/stationary.py diff --git a/docs/api.rst b/docs/api.rst index 972e95a4..948556e5 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -22,6 +22,17 @@ Kernels .. autoclass:: tinygp.kernels.Constant .. autoclass:: tinygp.kernels.DotProduct .. autoclass:: tinygp.kernels.Polynomial + + +.. _api-stationary: + +Stationary Kernels +------------------ + +Stationary kernels are defined with a distance metric implementing the +:class:`tinygp.kernels.stationary.Distance` interface. 
+ +.. autoclass:: tinygp.kernels.Stationary .. autoclass:: tinygp.kernels.Exp .. autoclass:: tinygp.kernels.ExpSquared .. autoclass:: tinygp.kernels.Matern32 @@ -31,6 +42,18 @@ Kernels .. autoclass:: tinygp.kernels.RationalQuadratic +.. _api-distance: + +Distance Metrics +---------------- + +.. autoclass:: tinygp.kernels.stationary.Distance + :members: + +.. autoclass:: tinygp.kernels.stationary.L1Distance +.. autoclass:: tinygp.kernels.stationary.L2Distance + + .. _api-transforms: Transforms diff --git a/docs/conf.py b/docs/conf.py index aab13673..d1484efd 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -51,5 +51,6 @@ autodoc_type_aliases = { "JAXArray": "tinygp.types.JAXArray", "Axis": "tinygp.kernels.Axis", + "Distance": "tinygp.kernels.Distance", "Metric": "tinygp.metrics.Metric", } diff --git a/docs/tutorials.md b/docs/tutorials.md index 7ccfa931..f7955b59 100644 --- a/docs/tutorials.md +++ b/docs/tutorials.md @@ -19,6 +19,7 @@ directly. tutorials/quickstart tutorials/modeling tutorials/multivariate +tutorials/geometry tutorials/transforms tutorials/kernels tutorials/derivative diff --git a/docs/tutorials/geometry.ipynb b/docs/tutorials/geometry.ipynb new file mode 100644 index 00000000..97f5be42 --- /dev/null +++ b/docs/tutorials/geometry.ipynb @@ -0,0 +1,263 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "26e39d4c-9caf-4435-a8ed-aba9544bdfac", + "metadata": { + "tags": [ + "hide-cell" + ] + }, + "outputs": [], + "source": [ + "try:\n", + " import tinygp\n", + "except ImportError:\n", + " %pip install -q tinygp\n", + "\n", + "try:\n", + " import jaxopt\n", + "except ImportError:\n", + " %pip install -q jaxopt" + ] + }, + { + "cell_type": "markdown", + "id": "d6d8253c-8c31-49be-b7d0-ead5b1dccfb5", + "metadata": {}, + "source": [ + "(geometry)=\n", + "\n", + "# Tutorial: Custom Geometry\n", + "\n", + "When working with multivariate inputs, you'll always need to choose a metric for computing the distance between coordinates in 
your input space.\n", + "As discussed in {ref}`multivariate`, `tinygp` includes built-in support for some common metrics which, when combined with {ref}`transforms`, can cover a wide range of use cases.\n", + "But this tutorial covers a more general use case: custom geometries.\n", + "\n", + "In this example, we will fit a GP model to data that lives on the surface of a sphere.\n", + "Here, we want to use our knowledge of this system to design a metric that takes this geometry into account.\n", + "Specifically, our data will have unit vector coordinates, and we will define a [great-circle distance](https://en.wikipedia.org/wiki/Great-circle_distance#Vector_version) to compute the angular distances between these vectors." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e98bab1c-2a77-4f0d-8e87-1ff7a01bce06", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "import jax\n", + "import jax.numpy as jnp\n", + "from tinygp import kernels, GaussianProcess\n", + "from jax.config import config\n", + "\n", + "config.update(\"jax_enable_x64\", True)\n", + "\n", + "\n", + "class GreatCircleDistance(kernels.stationary.Distance):\n", + " def distance(self, X1, X2):\n", + " if jnp.shape(X1) != (3,) or jnp.shape(X2) != (3,):\n", + " raise ValueError(\n", + " \"The great-circle distance is only defined for unit 3-vector\"\n", + " )\n", + " return jnp.arctan2(jnp.linalg.norm(jnp.cross(X1, X2)), (X1.T @ X2))\n", + "\n", + "\n", + "# Make a spherical grid\n", + "phi = np.linspace(-np.pi, np.pi, 50)\n", + "theta = np.linspace(-0.5 * np.pi, 0.5 * np.pi, 50)\n", + "phi_grid, theta_grid = np.meshgrid(phi, theta, indexing=\"ij\")\n", + "phi_grid = phi_grid.flatten()\n", + "theta_grid = theta_grid.flatten()\n", + "X_grid = np.vstack(\n", + " (\n", + " np.cos(phi_grid) * np.cos(theta_grid),\n", + " np.sin(phi_grid) * np.cos(theta_grid),\n", + " np.sin(theta_grid),\n", + " )\n", + ").T\n", + "\n", + "# 
Choose some uniformly distributed coordinates to be our \"data\"\n", + "random = np.random.default_rng(456)\n", + "X_obs = random.normal(size=(100, 3))\n", + "X_obs /= np.sqrt(np.sum(X_obs**2, axis=1))[:, None]\n", + "theta_obs = np.arctan2(\n", + " X_obs[:, 2], np.sqrt(X_obs[:, 0] ** 2 + X_obs[:, 1] ** 2)\n", + ")\n", + "phi_obs = np.arctan2(X_obs[:, 1], X_obs[:, 0])\n", + "\n", + "# Our kernel is parameterized by a length scale in **radians**\n", + "ell = 0.5\n", + "kernel = 1.5 * kernels.Matern52(ell, distance=GreatCircleDistance())\n", + "\n", + "# Sample a simulated dataset\n", + "gp = GaussianProcess(\n", + " kernel, np.concatenate((X_grid, X_obs), axis=0), diag=0.01\n", + ")\n", + "y_samp = gp.sample(jax.random.PRNGKey(10))\n", + "y_grid = y_samp[: len(X_grid)]\n", + "y_obs = y_samp[len(X_grid) :] + 0.5 * random.normal(size=len(X_obs))\n", + "\n", + "# Plot the map\n", + "plt.pcolor(\n", + " phi,\n", + " theta,\n", + " y_grid.reshape((len(phi), len(theta))).T,\n", + " vmin=y_grid.min(),\n", + " vmax=y_grid.max(),\n", + ")\n", + "plt.scatter(\n", + " phi_obs,\n", + " theta_obs,\n", + " c=y_obs,\n", + " edgecolor=\"k\",\n", + " vmin=y_grid.min(),\n", + " vmax=y_grid.max(),\n", + ")\n", + "plt.xlabel(r\"$\\phi$\")\n", + "plt.ylabel(r\"$\\theta$\")\n", + "_ = plt.title(\"simulated data\")" + ] + }, + { + "cell_type": "markdown", + "id": "6efc27c0", + "metadata": {}, + "source": [ + "Using these simulated data, we can now fit the model as usual:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7d102f4d", + "metadata": {}, + "outputs": [], + "source": [ + "import jaxopt\n", + "\n", + "\n", + "def build_gp(params):\n", + " kernel = jnp.exp(params[\"log_amp\"]) * kernels.Matern52(\n", + " jnp.exp(params[\"log_scale\"]), distance=GreatCircleDistance()\n", + " )\n", + " return GaussianProcess(\n", + " kernel,\n", + " X_obs,\n", + " diag=jnp.exp(2 * params[\"log_sigma\"]),\n", + " mean=params[\"mean\"],\n", + " )\n", + "\n", + "\n", + 
"@jax.jit\n", + "def loss(params):\n", + " return -build_gp(params).condition(y_obs)\n", + "\n", + "\n", + "params = {\n", + " \"log_amp\": np.zeros(()),\n", + " \"log_scale\": np.zeros(()),\n", + " \"log_sigma\": np.zeros(()),\n", + " \"mean\": np.zeros(()),\n", + "}\n", + "solver = jaxopt.ScipyMinimize(fun=loss)\n", + "soln = solver.run(params)\n", + "gp = build_gp(soln.params)" + ] + }, + { + "cell_type": "markdown", + "id": "ab716742", + "metadata": {}, + "source": [ + "At the maximum point, we can plot our model prediction compared to the ground truth, with the residuals plotted on the same scale:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c93fe6d7", + "metadata": {}, + "outputs": [], + "source": [ + "y_pred = gp.predict(y_obs, X_grid)\n", + "\n", + "fig, axes = plt.subplots(3, 1, sharex=True, figsize=(8, 8))\n", + "\n", + "axes[0].set_title(\"truth\")\n", + "axes[0].pcolor(\n", + " phi,\n", + " theta,\n", + " y_grid.reshape((len(phi), len(theta))).T,\n", + " vmin=y_grid.min(),\n", + " vmax=y_grid.max(),\n", + ")\n", + "\n", + "axes[1].set_title(\"predicted\")\n", + "axes[1].pcolor(\n", + " phi,\n", + " theta,\n", + " y_pred.reshape((len(phi), len(theta))).T,\n", + " vmin=y_grid.min(),\n", + " vmax=y_grid.max(),\n", + ")\n", + "\n", + "axes[2].set_title(\"residuals\")\n", + "axes[2].pcolor(\n", + " phi,\n", + " theta,\n", + " (y_pred - y_grid).reshape((len(phi), len(theta))).T,\n", + " vmin=y_grid.min(),\n", + " vmax=y_grid.max(),\n", + ")\n", + "\n", + "axes[2].set_xlabel(r\"$\\phi$\")\n", + "for ax in axes:\n", + " ax.set_ylabel(r\"$\\theta$\")" + ] + }, + { + "cell_type": "markdown", + "id": "c46b306f", + "metadata": {}, + "source": [ + "One thing that is worth commenting on here is that, unlike in {ref}`multivariate`, we're using only a single length scale.\n", + "This means that our kernel is _isotropic_.\n", + "For many use cases this is probably what you want because the whole point of this distance metric is that it is 
rotationally invariant.\n", + "If you want to model or discover anisotropies, you could use the methods discussed in {ref}`transforms`, but it would probably be worth considering designing a kernel that better captures what you're trying to model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cfdaf219", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/tutorials/mixture.ipynb b/docs/tutorials/mixture.ipynb index 38989e0b..847509d4 100644 --- a/docs/tutorials/mixture.ipynb +++ b/docs/tutorials/mixture.ipynb @@ -84,8 +84,8 @@ "\n", "\n", "def build_gp(params):\n", - " kernel1 = jnp.exp(params[\"log_amp1\"]) * kernels.Matern32(\n", - " jnp.exp(params[\"log_scale1\"])\n", + " kernel1 = jnp.exp(params[\"log_amp1\"]) * transforms.Linear(\n", + " jnp.exp(params[\"log_scale1\"]), kernels.Matern32()\n", " )\n", " kernel2 = jnp.exp(params[\"log_amp2\"]) * transforms.Subspace(\n", " 0,\n", diff --git a/docs/tutorials/multivariate.ipynb b/docs/tutorials/multivariate.ipynb index 83cd496f..8c255637 100644 --- a/docs/tutorials/multivariate.ipynb +++ b/docs/tutorials/multivariate.ipynb @@ -48,7 +48,9 @@ "```\n", "\n", "as it was when using `george`.\n", - "This is indicated in the {ref}`api-kernels` section of the API docs, where the argument of each kernel is defined.\n", + "This is indicated in the {ref}`api-stationary` section of the API docs, where the argument of each kernel is defined.\n", + "\n", + "It is possible to change this behavior by specifying your preferred 
{class}`tinygp.kernels.stationary.Distance` metric using the `distance` argument to any {class}`tinygp.kernels.Stationary` kernel.\n", "\n", "Also, `tinygp` does not require that you specify dimension of the kernel using an `ndim` parameter when instantiating the kernel.\n", "The parameters of the kernel must, however, be broadcastable to the dimension of your inputs.\n", @@ -82,13 +84,11 @@ "ndim = 3\n", "X = np.random.default_rng(1).normal(size=(10, ndim))\n", "\n", - "# These to kernels are equivalent\n", + "# This kernel is equivalent...\n", "scale = 1.5\n", "kernel1 = kernels.Matern32(scale)\n", - "kernel2 = kernels.Matern32(jnp.full(ndim, scale))\n", - "np.testing.assert_allclose(kernel1(X, X), kernel2(X, X))\n", "\n", - "# And both are equivalent to manually scaling the input coordinates\n", + "# ... to manually scaling the input coordinates\n", "kernel0 = kernels.Matern32()\n", "np.testing.assert_allclose(kernel0(X / scale, X / scale), kernel1(X, X))" ] @@ -187,8 +187,8 @@ "\n", "\n", "def build_gp_uncorr(params):\n", - " kernel = jnp.exp(params[\"log_amp\"]) * kernels.ExpSquared(\n", - " jnp.exp(params[\"log_scale\"])\n", + " kernel = jnp.exp(params[\"log_amp\"]) * transforms.Linear(\n", + " jnp.exp(-params[\"log_scale\"]), kernels.ExpSquared()\n", " )\n", " return GaussianProcess(kernel, X, diag=yerr**2)\n", "\n", @@ -212,7 +212,7 @@ "outputs": [], "source": [ "y_pred = uncorr_gp.predict(y, X_pred).reshape(y_true.shape)\n", - "xy = uncorr_gp.kernel.kernel2.scale[:, None] * ellipse\n", + "xy = ellipse / uncorr_gp.kernel.kernel2.scale[:, None]\n", "\n", "fig, axes = plt.subplots(1, 2, figsize=(12, 6), sharex=True, sharey=True)\n", "axes[0].plot(xy[0], xy[1], \"--k\", lw=0.5)\n", diff --git a/src/tinygp/__init__.py b/src/tinygp/__init__.py index e3f089b0..b3046606 100644 --- a/src/tinygp/__init__.py +++ b/src/tinygp/__init__.py @@ -2,9 +2,9 @@ __all__ = ["__version__", "kernels", "transforms", "GaussianProcess"] -from . 
import kernels, transforms -from .gp import GaussianProcess -from .tinygp_version import version as __version__ +from tinygp import kernels, transforms +from tinygp.gp import GaussianProcess +from tinygp.tinygp_version import version as __version__ __author__ = "Dan Foreman-Mackey" __email__ = "foreman.mackey@gmail.com" diff --git a/src/tinygp/gp.py b/src/tinygp/gp.py index 74979e84..4410b18b 100644 --- a/src/tinygp/gp.py +++ b/src/tinygp/gp.py @@ -11,8 +11,8 @@ import jax.numpy as jnp from jax.scipy import linalg -from .kernels import Kernel -from .types import JAXArray +from tinygp.kernels import Kernel +from tinygp.types import JAXArray class GaussianProcess: diff --git a/src/tinygp/kernels/__init__.py b/src/tinygp/kernels/__init__.py new file mode 100644 index 00000000..638bf5ed --- /dev/null +++ b/src/tinygp/kernels/__init__.py @@ -0,0 +1,45 @@ +# -*- coding: utf-8 -*- + +__all__ = [ + "Distance", + "L1Distance", + "L2Distance", + "Kernel", + "Custom", + "Sum", + "Product", + "Constant", + "DotProduct", + "Polynomial", + "Stationary", + "Exp", + "ExpSquared", + "Matern32", + "Matern52", + "Cosine", + "ExpSineSquared", + "RationalQuadratic", +] + +from tinygp.kernels.base import ( + Constant, + Custom, + DotProduct, + Kernel, + Polynomial, + Product, + Sum, +) +from tinygp.kernels.stationary import ( + Cosine, + Distance, + Exp, + ExpSineSquared, + ExpSquared, + L1Distance, + L2Distance, + Matern32, + Matern52, + RationalQuadratic, + Stationary, +) diff --git a/src/tinygp/kernels.py b/src/tinygp/kernels/base.py similarity index 59% rename from src/tinygp/kernels.py rename to src/tinygp/kernels/base.py index d4b20ee1..5a1d8cd3 100644 --- a/src/tinygp/kernels.py +++ b/src/tinygp/kernels/base.py @@ -24,7 +24,7 @@ import jax import jax.numpy as jnp -from .types import JAXArray +from tinygp.types import JAXArray Axis = Union[int, Sequence[int]] @@ -203,183 +203,3 @@ def evaluate(self, X1: JAXArray, X2: JAXArray) -> JAXArray: return ( (X1 / self.scale) @ (X2 / 
self.scale) + self.sigma2 ) ** self.order - - -class Exp(Kernel): - r"""The exponential kernel - - .. math:: - - k(\mathbf{x}_i,\,\mathbf{x}_j) = \exp(-r) - - where - - .. math:: - - r = ||(\mathbf{x}_i - \mathbf{x}_j) / \ell||_1 - - Args: - scale: The parameter :math:`\ell`. - """ - - def __init__(self, scale: JAXArray = jnp.ones(())): - self.scale = scale - - def evaluate(self, X1: JAXArray, X2: JAXArray) -> JAXArray: - return jnp.exp(-jnp.sum(jnp.abs((X1 - X2) / self.scale))) - - -class ExpSquared(Kernel): - r"""The exponential squared or radial basis function kernel - - .. math:: - - k(\mathbf{x}_i,\,\mathbf{x}_j) = \exp(-r^2 / 2) - - where - - .. math:: - - r^2 = ||(\mathbf{x}_i - \mathbf{x}_j) / \ell||_2^2 - - Args: - scale: The parameter :math:`\ell`. - """ - - def __init__(self, scale: JAXArray = jnp.ones(())): - self.scale = scale - - def evaluate(self, X1: JAXArray, X2: JAXArray) -> JAXArray: - return jnp.exp(-0.5 * jnp.sum(jnp.square((X1 - X2) / self.scale))) - - -class Matern32(Kernel): - r"""The Matern-3/2 kernel - - .. math:: - - k(\mathbf{x}_i,\,\mathbf{x}_j) = (1 + \sqrt{3}\,r)\,\exp(-\sqrt{3}\,r) - - where - - .. math:: - - r = ||(\mathbf{x}_i - \mathbf{x}_j) / \ell||_1 - - Args: - scale: The parameter :math:`\ell`. - """ - - def __init__(self, scale: JAXArray = jnp.ones(())): - self.scale = scale - - def evaluate(self, X1: JAXArray, X2: JAXArray) -> JAXArray: - r = jnp.sum(jnp.abs((X1 - X2) / self.scale)) - arg = jnp.sqrt(3.0) * r - return (1.0 + arg) * jnp.exp(-arg) - - -class Matern52(Kernel): - r"""The Matern-5/2 kernel - - .. math:: - - k(\mathbf{x}_i,\,\mathbf{x}_j) = (1 + \sqrt{5}\,r + - 5\,r^2/\sqrt{3})\,\exp(-\sqrt{5}\,r) - - where - - .. math:: - - r = ||(\mathbf{x}_i - \mathbf{x}_j) / \ell||_1 - - Args: - scale: The parameter :math:`\ell`. 
- """ - - def __init__(self, scale: JAXArray = jnp.ones(())): - self.scale = scale - - def evaluate(self, X1: JAXArray, X2: JAXArray) -> JAXArray: - r = jnp.sum(jnp.abs((X1 - X2) / self.scale)) - arg = jnp.sqrt(5.0) * r - return (1.0 + arg + jnp.square(arg) / 3.0) * jnp.exp(-arg) - - -class Cosine(Kernel): - r"""The cosine kernel - - .. math:: - - k(\mathbf{x}_i,\,\mathbf{x}_j) = \cos(2\,\pi\,r) - - where - - .. math:: - - r = ||(\mathbf{x}_i - \mathbf{x}_j) / P||_1 - - Args: - period: The parameter :math:`P`. - """ - - def __init__(self, period: JAXArray): - self.period = period - - def evaluate(self, X1: JAXArray, X2: JAXArray) -> JAXArray: - r = jnp.sum(jnp.abs((X1 - X2) / self.period)) - return jnp.cos(2 * jnp.pi * r) - - -class ExpSineSquared(Kernel): - r"""The exponential sine squared or quasiperiodic kernel - - .. math:: - - k(\mathbf{x}_i,\,\mathbf{x}_j) = \exp(-\Gamma\,\sin^2 \pi r) - - where - - .. math:: - - r = ||(\mathbf{x}_i - \mathbf{x}_j) / P||_1 - - Args: - period: The parameter :math:`P`. - gamma: The parameter :math:`\Gamma`. - """ - - def __init__(self, *, period: JAXArray, gamma: JAXArray): - self.period = period - self.gamma = gamma - - def evaluate(self, X1: JAXArray, X2: JAXArray) -> JAXArray: - r = jnp.sum(jnp.abs((X1 - X2) / self.period)) - return jnp.exp(-self.gamma * jnp.square(jnp.sin(jnp.pi * r))) - - -class RationalQuadratic(Kernel): - r"""The rational quadratic - - .. math:: - - k(\mathbf{x}_i,\,\mathbf{x}_j) = (1 + r^2 / 2\,\alpha)^{-\alpha} - - where - - .. math:: - - r^2 = ||(\mathbf{x}_i - \mathbf{x}_j) / \ell||_2^2 - - Args: - scale: The parameter :math:`\ell`. - alpha: The parameter :math:`\alpha`. 
- """ - - def __init__(self, *, alpha: JAXArray, scale: Optional[JAXArray] = None): - self.scale = jnp.ones_like(alpha) if scale is None else scale - self.alpha = alpha - - def evaluate(self, X1: JAXArray, X2: JAXArray) -> JAXArray: - r2 = jnp.sum(jnp.square((X1 - X2) / self.scale)) - return (1.0 + 0.5 * r2 / self.alpha) ** -self.alpha diff --git a/src/tinygp/kernels/stationary.py b/src/tinygp/kernels/stationary.py new file mode 100644 index 00000000..0352d7a0 --- /dev/null +++ b/src/tinygp/kernels/stationary.py @@ -0,0 +1,280 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +__all__ = [ + "Distance", + "L1Distance", + "L2Distance", + "Stationary", + "Exp", + "ExpSquared", + "Matern32", + "Matern52", + "Cosine", + "ExpSineSquared", + "RationalQuadratic", +] + +from typing import Optional + +import jax.numpy as jnp + +from tinygp.kernels import Kernel +from tinygp.types import JAXArray + + +class Distance: + """An abstract base class defining a distance metric interface""" + + def distance(self, X1: JAXArray, X2: JAXArray) -> JAXArray: + """Compute the distance between two coordinates under this metric""" + raise NotImplementedError() + + def squared_distance(self, X1: JAXArray, X2: JAXArray) -> JAXArray: + """Compute the squared distance between two coordinates + + By default this returns the squared result of + :func:`tinygp.kernels.stationary.Distance.distance`, but some metrics + can take advantage of these separate implementations to avoid + unnecessary square roots. 
+ """ + return jnp.square(self.distance(X1, X2)) + + +class L1Distance(Distance): + """The L1 or Manhattan distance between two coordinates""" + + def distance(self, X1: JAXArray, X2: JAXArray) -> JAXArray: + return jnp.sum(jnp.abs(X1 - X2)) + + +class L2Distance(Distance): + """The L2 or Euclidean distance between two coordinates""" + + def distance(self, X1: JAXArray, X2: JAXArray) -> JAXArray: + return jnp.sqrt(self.squared_distance(X1, X2)) + + def squared_distance(self, X1: JAXArray, X2: JAXArray) -> JAXArray: + return jnp.sum(jnp.square(X1 - X2)) + + +class Stationary(Kernel): + """A stationary kernel is defined with respect to a distance metric + + Note that a stationary kernel is *always* isotropic. If you need + non-isotropic length scales, wrap your kernel in a transform using + :class:`tinygp.transforms.Linear` or :class:`tinygp.transforms.Cholesky`. + + Args: + scale: The length scale, in the same units as ``distance`` for the + kernel. This must be a scalar. + distance: An object that implements ``distance`` and + ``squared_distance`` methods. Typically a subclass of + :class:`tinygp.kernels.stationary.Distance`. Each stationary kernel + also has a ``default_distance`` property that is used when + ``distance`` isn't provided. + """ + + default_distance: Distance = L1Distance() + + def __init__( + self, + scale: JAXArray = jnp.ones(()), + *, + distance: Optional[Distance] = None, + ): + if jnp.ndim(scale): + raise ValueError( + "Only scalar scales are permitted for stationary kernels; use" + "transforms.Linear or transforms.Cholesky for more flexiblity" + ) + self.scale = scale + self.distance = self.default_distance if distance is None else distance + + +class Exp(Stationary): + r"""The exponential kernel + + .. math:: + + k(\mathbf{x}_i,\,\mathbf{x}_j) = \exp(-r) + + where, by default, + + .. math:: + + r = ||(\mathbf{x}_i - \mathbf{x}_j) / \ell||_1 + + Args: + scale: The parameter :math:`\ell`.
+ """ + + def evaluate(self, X1: JAXArray, X2: JAXArray) -> JAXArray: + return jnp.exp(-self.distance.distance(X1, X2) / self.scale) + + +class ExpSquared(Stationary): + r"""The exponential squared or radial basis function kernel + + .. math:: + + k(\mathbf{x}_i,\,\mathbf{x}_j) = \exp(-r^2 / 2) + + where, by default, + + .. math:: + + r^2 = ||(\mathbf{x}_i - \mathbf{x}_j) / \ell||_2^2 + + Args: + scale: The parameter :math:`\ell`. + """ + default_distance: Distance = L2Distance() + + def evaluate(self, X1: JAXArray, X2: JAXArray) -> JAXArray: + r2 = self.distance.squared_distance(X1, X2) / jnp.square(self.scale) + return jnp.exp(-0.5 * r2) + + +class Matern32(Stationary): + r"""The Matern-3/2 kernel + + .. math:: + + k(\mathbf{x}_i,\,\mathbf{x}_j) = (1 + \sqrt{3}\,r)\,\exp(-\sqrt{3}\,r) + + where, by default, + + .. math:: + + r = ||(\mathbf{x}_i - \mathbf{x}_j) / \ell||_1 + + Args: + scale: The parameter :math:`\ell`. + """ + + def evaluate(self, X1: JAXArray, X2: JAXArray) -> JAXArray: + r = self.distance.distance(X1, X2) / self.scale + arg = jnp.sqrt(3.0) * r + return (1.0 + arg) * jnp.exp(-arg) + + +class Matern52(Stationary): + r"""The Matern-5/2 kernel + + .. math:: + + k(\mathbf{x}_i,\,\mathbf{x}_j) = (1 + \sqrt{5}\,r + + 5\,r^2/3)\,\exp(-\sqrt{5}\,r) + + where, by default, + + .. math:: + + r = ||(\mathbf{x}_i - \mathbf{x}_j) / \ell||_1 + + Args: + scale: The parameter :math:`\ell`. + """ + + def evaluate(self, X1: JAXArray, X2: JAXArray) -> JAXArray: + r = self.distance.distance(X1, X2) / self.scale + arg = jnp.sqrt(5.0) * r + return (1.0 + arg + jnp.square(arg) / 3.0) * jnp.exp(-arg) + + +class Cosine(Stationary): + r"""The cosine kernel + + .. math:: + + k(\mathbf{x}_i,\,\mathbf{x}_j) = \cos(2\,\pi\,r) + + where, by default, + + .. math:: + + r = ||(\mathbf{x}_i - \mathbf{x}_j) / P||_1 + + Args: + period: The parameter :math:`P`.
+ """ + + def __init__( + self, period: JAXArray, *, distance: Optional[Distance] = None + ): + super().__init__(scale=period, distance=distance) + self.period = self.scale + + def evaluate(self, X1: JAXArray, X2: JAXArray) -> JAXArray: + r = self.distance.distance(X1, X2) / self.period + return jnp.cos(2 * jnp.pi * r) + + +class ExpSineSquared(Stationary): + r"""The exponential sine squared or quasiperiodic kernel + + .. math:: + + k(\mathbf{x}_i,\,\mathbf{x}_j) = \exp(-\Gamma\,\sin^2 \pi r) + + where, by default, + + .. math:: + + r = ||(\mathbf{x}_i - \mathbf{x}_j) / P||_1 + + Args: + period: The parameter :math:`P`. + gamma: The parameter :math:`\Gamma`. + """ + + def __init__( + self, + *, + period: JAXArray, + gamma: JAXArray, + distance: Optional[Distance] = None, + ): + super().__init__(scale=period, distance=distance) + self.period = self.scale + self.gamma = gamma + + def evaluate(self, X1: JAXArray, X2: JAXArray) -> JAXArray: + r = self.distance.distance(X1, X2) / self.period + return jnp.exp(-self.gamma * jnp.square(jnp.sin(jnp.pi * r))) + + +class RationalQuadratic(Stationary): + r"""The rational quadratic + + .. math:: + + k(\mathbf{x}_i,\,\mathbf{x}_j) = (1 + r^2 / 2\,\alpha)^{-\alpha} + + where, by default, + + .. math:: + + r^2 = ||(\mathbf{x}_i - \mathbf{x}_j) / \ell||_2^2 + + Args: + scale: The parameter :math:`\ell`. + alpha: The parameter :math:`\alpha`. 
+ """ + + def __init__( + self, + *, + alpha: JAXArray, + scale: JAXArray = jnp.ones(()), + distance: Optional[Distance] = None, + ): + super().__init__(scale=scale, distance=distance) + self.scale = scale + self.alpha = alpha + + def evaluate(self, X1: JAXArray, X2: JAXArray) -> JAXArray: + r2 = self.distance.squared_distance(X1, X2) / jnp.square(self.scale) + return (1.0 + 0.5 * r2 / self.alpha) ** -self.alpha diff --git a/src/tinygp/transforms.py b/src/tinygp/transforms.py index 9aa2e734..6f4ea9eb 100644 --- a/src/tinygp/transforms.py +++ b/src/tinygp/transforms.py @@ -10,8 +10,8 @@ import jax.numpy as jnp from jax.scipy import linalg -from .kernels import Kernel -from .types import JAXArray +from tinygp.kernels import Kernel +from tinygp.types import JAXArray class Transform(Kernel): diff --git a/tests/test_kernels.py b/tests/test_kernels.py index 0c859b67..48bd020a 100644 --- a/tests/test_kernels.py +++ b/tests/test_kernels.py @@ -48,13 +48,6 @@ def test_custom(data): k2 = kernels.ExpSquared(scale) np.testing.assert_allclose(k1(x1, x2), k2(x1, x2)) - scale = 1.5 * jnp.ones(x1.shape[1]) - k1 = kernels.Custom( - lambda X1, X2: jnp.exp(-0.5 * jnp.sum(jnp.square((X1 - X2) / scale))) - ) - k2 = kernels.ExpSquared(scale) - np.testing.assert_allclose(k1(x1, x2), k2(x1, x2)) - # Check that an invalid kernel raises as expected kernel = kernels.Custom( lambda X1, X2: jnp.exp(-0.5 * jnp.square((X1 - X2) / scale)) diff --git a/tests/test_transforms.py b/tests/test_transforms.py index ed58d51f..b2209cce 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -17,16 +17,11 @@ def test_linear(): def test_multivariate_linear(): kernel0 = kernels.Matern32(4.5) - kernel1 = kernels.Matern32(jnp.full(3, 4.5)) - kernel2 = transforms.Linear(jnp.full(3, 1 / 4.5), kernels.Matern32()) + kernel1 = transforms.Linear(jnp.full(3, 1 / 4.5), kernels.Matern32()) np.testing.assert_allclose( kernel0.evaluate(jnp.full(3, 0.5), jnp.full(3, 0.1)), kernel1.evaluate(jnp.full(3, 
0.5), jnp.full(3, 0.1)), ) - np.testing.assert_allclose( - kernel0.evaluate(jnp.full(3, 0.5), jnp.full(3, 0.1)), - kernel2.evaluate(jnp.full(3, 0.5), jnp.full(3, 0.1)), - ) def test_cholesky(): @@ -39,16 +34,11 @@ def test_cholesky(): def test_multivariate_cholesky(): kernel0 = kernels.Matern32(4.5) - kernel1 = kernels.Matern32(jnp.full(3, 4.5)) - kernel2 = transforms.Cholesky(jnp.full(3, 4.5), kernels.Matern32()) + kernel1 = transforms.Cholesky(jnp.full(3, 4.5), kernels.Matern32()) np.testing.assert_allclose( kernel0.evaluate(jnp.full(3, 0.5), jnp.full(3, 0.1)), kernel1.evaluate(jnp.full(3, 0.5), jnp.full(3, 0.1)), ) - np.testing.assert_allclose( - kernel0.evaluate(jnp.full(3, 0.5), jnp.full(3, 0.1)), - kernel2.evaluate(jnp.full(3, 0.5), jnp.full(3, 0.1)), - ) def test_subspace():