Skip to content

Commit

Permalink
Adding distance metric to stationary kernels (#35)
Browse files Browse the repository at this point in the history
* adding distance metric to stationary kernels

* switching to new interface for stationary kernels

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* updating multivariate tutorial

* adding some docstrings

* adding custom geometry tutorial

* reduce memory usage in geometry tutorial

* tweaking geometry tutorial labels

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
dfm and pre-commit-ci[bot] authored Feb 7, 2022
1 parent 8e71bd7 commit a71c51a
Show file tree
Hide file tree
Showing 14 changed files with 633 additions and 217 deletions.
23 changes: 23 additions & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,17 @@ Kernels
.. autoclass:: tinygp.kernels.Constant
.. autoclass:: tinygp.kernels.DotProduct
.. autoclass:: tinygp.kernels.Polynomial


.. _api-stationary:

Stationary Kernels
------------------

Stationary kernels are defined with a distance metric implementing the
:class:`tinygp.kernels.stationary.Distance` interface.

.. autoclass:: tinygp.kernels.Stationary
.. autoclass:: tinygp.kernels.Exp
.. autoclass:: tinygp.kernels.ExpSquared
.. autoclass:: tinygp.kernels.Matern32
Expand All @@ -31,6 +42,18 @@ Kernels
.. autoclass:: tinygp.kernels.RationalQuadratic


.. _api-distance:

Distance Metrics
----------------

.. autoclass:: tinygp.kernels.stationary.Distance
:members:

.. autoclass:: tinygp.kernels.stationary.L1Distance
.. autoclass:: tinygp.kernels.stationary.L2Distance


.. _api-transforms:

Transforms
Expand Down
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,5 +51,6 @@
autodoc_type_aliases = {
"JAXArray": "tinygp.types.JAXArray",
"Axis": "tinygp.kernels.Axis",
"Distance": "tinygp.kernels.Distance",
"Metric": "tinygp.metrics.Metric",
}
1 change: 1 addition & 0 deletions docs/tutorials.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ directly.
tutorials/quickstart
tutorials/modeling
tutorials/multivariate
tutorials/geometry
tutorials/transforms
tutorials/kernels
tutorials/derivative
Expand Down
263 changes: 263 additions & 0 deletions docs/tutorials/geometry.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,263 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "26e39d4c-9caf-4435-a8ed-aba9544bdfac",
"metadata": {
"tags": [
"hide-cell"
]
},
"outputs": [],
"source": [
"try:\n",
" import tinygp\n",
"except ImportError:\n",
" %pip install -q tinygp\n",
"\n",
"try:\n",
" import jaxopt\n",
"except ImportError:\n",
" %pip install -q jaxopt"
]
},
{
"cell_type": "markdown",
"id": "d6d8253c-8c31-49be-b7d0-ead5b1dccfb5",
"metadata": {},
"source": [
"(geometry)=\n",
"\n",
"# Tutorial: Custom Geometry\n",
"\n",
"When working with multivariate inputs, you'll always need to choose a metric for computing the distance between coordinates in your input space.\n",
"As discussed in {ref}`multivariate`, `tinygp` includes built-in support for some common metrics which, when combined with {ref}`transforms`, can cover a wide range of use cases.\n",
"But this tutorial covers a more general use case: custom geometries.\n",
"\n",
"In this example, we will fit a GP model to data that lives on the surface of a sphere.\n",
"Here, we want to use our knowledge of this system to design a metric that takes this geometry into account.\n",
"Specifically, our data will have unit vector coordinates, and we will define a [great-circle distance](https://en.wikipedia.org/wiki/Great-circle_distance#Vector_version) to compute the angular distances between these vectors."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e98bab1c-2a77-4f0d-8e87-1ff7a01bce06",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import jax\n",
"import jax.numpy as jnp\n",
"from tinygp import kernels, GaussianProcess\n",
"from jax.config import config\n",
"\n",
"config.update(\"jax_enable_x64\", True)\n",
"\n",
"\n",
"class GreatCircleDistance(kernels.stationary.Distance):\n",
" def distance(self, X1, X2):\n",
" if jnp.shape(X1) != (3,) or jnp.shape(X2) != (3,):\n",
" raise ValueError(\n",
" \"The great-circle distance is only defined for unit 3-vector\"\n",
" )\n",
" return jnp.arctan2(jnp.linalg.norm(jnp.cross(X1, X2)), (X1.T @ X2))\n",
"\n",
"\n",
"# Make a spherical grid\n",
"phi = np.linspace(-np.pi, np.pi, 50)\n",
"theta = np.linspace(-0.5 * np.pi, 0.5 * np.pi, 50)\n",
"phi_grid, theta_grid = np.meshgrid(phi, theta, indexing=\"ij\")\n",
"phi_grid = phi_grid.flatten()\n",
"theta_grid = theta_grid.flatten()\n",
"X_grid = np.vstack(\n",
" (\n",
" np.cos(phi_grid) * np.cos(theta_grid),\n",
" np.sin(phi_grid) * np.cos(theta_grid),\n",
" np.sin(theta_grid),\n",
" )\n",
").T\n",
"\n",
"# Choose some uniformly distributed coordinates to be our \"data\"\n",
"random = np.random.default_rng(456)\n",
"X_obs = random.normal(size=(100, 3))\n",
"X_obs /= np.sqrt(np.sum(X_obs**2, axis=1))[:, None]\n",
"theta_obs = np.arctan2(\n",
" X_obs[:, 2], np.sqrt(X_obs[:, 0] ** 2 + X_obs[:, 1] ** 2)\n",
")\n",
"phi_obs = np.arctan2(X_obs[:, 1], X_obs[:, 0])\n",
"\n",
"# Our kernel is parameterized by a length scale in **radians**\n",
"ell = 0.5\n",
"kernel = 1.5 * kernels.Matern52(ell, distance=GreatCircleDistance())\n",
"\n",
"# Sample a simulated dataset\n",
"gp = GaussianProcess(\n",
" kernel, np.concatenate((X_grid, X_obs), axis=0), diag=0.01\n",
")\n",
"y_samp = gp.sample(jax.random.PRNGKey(10))\n",
"y_grid = y_samp[: len(X_grid)]\n",
"y_obs = y_samp[len(X_grid) :] + 0.5 * random.normal(size=len(X_obs))\n",
"\n",
"# Plot the map\n",
"plt.pcolor(\n",
" phi,\n",
" theta,\n",
" y_grid.reshape((len(phi), len(theta))).T,\n",
" vmin=y_grid.min(),\n",
" vmax=y_grid.max(),\n",
")\n",
"plt.scatter(\n",
" phi_obs,\n",
" theta_obs,\n",
" c=y_obs,\n",
" edgecolor=\"k\",\n",
" vmin=y_grid.min(),\n",
" vmax=y_grid.max(),\n",
")\n",
"plt.xlabel(r\"$\\phi$\")\n",
"plt.ylabel(r\"$\\theta$\")\n",
"_ = plt.title(\"simulated data\")"
]
},
{
"cell_type": "markdown",
"id": "6efc27c0",
"metadata": {},
"source": [
"Using these simulated data, we can now fit the model as usual:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7d102f4d",
"metadata": {},
"outputs": [],
"source": [
"import jaxopt\n",
"\n",
"\n",
"def build_gp(params):\n",
" kernel = jnp.exp(params[\"log_amp\"]) * kernels.Matern52(\n",
" jnp.exp(params[\"log_scale\"]), distance=GreatCircleDistance()\n",
" )\n",
" return GaussianProcess(\n",
" kernel,\n",
" X_obs,\n",
" diag=jnp.exp(2 * params[\"log_sigma\"]),\n",
" mean=params[\"mean\"],\n",
" )\n",
"\n",
"\n",
"@jax.jit\n",
"def loss(params):\n",
" return -build_gp(params).condition(y_obs)\n",
"\n",
"\n",
"params = {\n",
" \"log_amp\": np.zeros(()),\n",
" \"log_scale\": np.zeros(()),\n",
" \"log_sigma\": np.zeros(()),\n",
" \"mean\": np.zeros(()),\n",
"}\n",
"solver = jaxopt.ScipyMinimize(fun=loss)\n",
"soln = solver.run(params)\n",
"gp = build_gp(soln.params)"
]
},
{
"cell_type": "markdown",
"id": "ab716742",
"metadata": {},
"source": [
"At the maximum point, we can plot our model prediction compared to the ground truth, with the residuals plotted on the same scale:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c93fe6d7",
"metadata": {},
"outputs": [],
"source": [
"y_pred = gp.predict(y_obs, X_grid)\n",
"\n",
"fig, axes = plt.subplots(3, 1, sharex=True, figsize=(8, 8))\n",
"\n",
"axes[0].set_title(\"truth\")\n",
"axes[0].pcolor(\n",
" phi,\n",
" theta,\n",
" y_grid.reshape((len(phi), len(theta))).T,\n",
" vmin=y_grid.min(),\n",
" vmax=y_grid.max(),\n",
")\n",
"\n",
"axes[1].set_title(\"predicted\")\n",
"axes[1].pcolor(\n",
" phi,\n",
" theta,\n",
" y_pred.reshape((len(phi), len(theta))).T,\n",
" vmin=y_grid.min(),\n",
" vmax=y_grid.max(),\n",
")\n",
"\n",
"axes[2].set_title(\"residuals\")\n",
"axes[2].pcolor(\n",
" phi,\n",
" theta,\n",
" (y_pred - y_grid).reshape((len(phi), len(theta))).T,\n",
" vmin=y_grid.min(),\n",
" vmax=y_grid.max(),\n",
")\n",
"\n",
"axes[2].set_xlabel(r\"$\\phi$\")\n",
"for ax in axes:\n",
" ax.set_ylabel(r\"$\\theta$\")"
]
},
{
"cell_type": "markdown",
"id": "c46b306f",
"metadata": {},
"source": [
"One thing that is worth commenting on here is that, unlike in {ref}`multivariate`, we're using only a single length scale.\n",
"This means that our kernel is _isotropic_.\n",
"For many use cases this is probably what you want because the whole point of this distance metric is that it is rotationally invariant.\n",
"If you want to model or discover anisotropies, you could use the methods discussed in {ref}`transforms`, but it would probably be worth considering designing a kernel that better captures what you're trying to model."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cfdaf219",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
4 changes: 2 additions & 2 deletions docs/tutorials/mixture.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@
"\n",
"\n",
"def build_gp(params):\n",
" kernel1 = jnp.exp(params[\"log_amp1\"]) * kernels.Matern32(\n",
" jnp.exp(params[\"log_scale1\"])\n",
" kernel1 = jnp.exp(params[\"log_amp1\"]) * transforms.Linear(\n",
" jnp.exp(params[\"log_scale1\"]), kernels.Matern32()\n",
" )\n",
" kernel2 = jnp.exp(params[\"log_amp2\"]) * transforms.Subspace(\n",
" 0,\n",
Expand Down
16 changes: 8 additions & 8 deletions docs/tutorials/multivariate.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@
"```\n",
"\n",
"as it was when using `george`.\n",
"This is indicated in the {ref}`api-kernels` section of the API docs, where the argument of each kernel is defined.\n",
"This is indicated in the {ref}`api-stationary` section of the API docs, where the argument of each kernel is defined.\n",
"\n",
"It is possible to change this behavior by specifying your preferred {class}`tinygp.kernels.stationary.Distance` metric using the `distance` argument to any {class}`tinygp.kernels.Stationary` kernel.\n",
"\n",
"Also, `tinygp` does not require that you specify dimension of the kernel using an `ndim` parameter when instantiating the kernel.\n",
"The parameters of the kernel must, however, be broadcastable to the dimension of your inputs.\n",
Expand Down Expand Up @@ -82,13 +84,11 @@
"ndim = 3\n",
"X = np.random.default_rng(1).normal(size=(10, ndim))\n",
"\n",
"# These to kernels are equivalent\n",
"# This kernel is equivalent...\n",
"scale = 1.5\n",
"kernel1 = kernels.Matern32(scale)\n",
"kernel2 = kernels.Matern32(jnp.full(ndim, scale))\n",
"np.testing.assert_allclose(kernel1(X, X), kernel2(X, X))\n",
"\n",
"# And both are equivalent to manually scaling the input coordinates\n",
"# ... to manually scaling the input coordinates\n",
"kernel0 = kernels.Matern32()\n",
"np.testing.assert_allclose(kernel0(X / scale, X / scale), kernel1(X, X))"
]
Expand Down Expand Up @@ -187,8 +187,8 @@
"\n",
"\n",
"def build_gp_uncorr(params):\n",
" kernel = jnp.exp(params[\"log_amp\"]) * kernels.ExpSquared(\n",
" jnp.exp(params[\"log_scale\"])\n",
" kernel = jnp.exp(params[\"log_amp\"]) * transforms.Linear(\n",
" jnp.exp(-params[\"log_scale\"]), kernels.ExpSquared()\n",
" )\n",
" return GaussianProcess(kernel, X, diag=yerr**2)\n",
"\n",
Expand All @@ -212,7 +212,7 @@
"outputs": [],
"source": [
"y_pred = uncorr_gp.predict(y, X_pred).reshape(y_true.shape)\n",
"xy = uncorr_gp.kernel.kernel2.scale[:, None] * ellipse\n",
"xy = ellipse / uncorr_gp.kernel.kernel2.scale[:, None]\n",
"\n",
"fig, axes = plt.subplots(1, 2, figsize=(12, 6), sharex=True, sharey=True)\n",
"axes[0].plot(xy[0], xy[1], \"--k\", lw=0.5)\n",
Expand Down
6 changes: 3 additions & 3 deletions src/tinygp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

__all__ = ["__version__", "kernels", "transforms", "GaussianProcess"]

from . import kernels, transforms
from .gp import GaussianProcess
from .tinygp_version import version as __version__
from tinygp import kernels, transforms
from tinygp.gp import GaussianProcess
from tinygp.tinygp_version import version as __version__

__author__ = "Dan Foreman-Mackey"
__email__ = "foreman.mackey@gmail.com"
Expand Down
4 changes: 2 additions & 2 deletions src/tinygp/gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
import jax.numpy as jnp
from jax.scipy import linalg

from .kernels import Kernel
from .types import JAXArray
from tinygp.kernels import Kernel
from tinygp.types import JAXArray


class GaussianProcess:
Expand Down
Loading

0 comments on commit a71c51a

Please sign in to comment.