-
Notifications
You must be signed in to change notification settings - Fork 25
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding distance metric to stationary kernels (#35)
* adding distance metric to stationary kernels * switching to new interface for stationary kernels * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * updating multivariate tutorial * adding some docstrings * adding custom geometry tutorial * reduce memory usage in geometry tutorial * tweaking geometry tutorial labels * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
- Loading branch information
1 parent
8e71bd7
commit a71c51a
Showing
14 changed files
with
633 additions
and
217 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,263 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "26e39d4c-9caf-4435-a8ed-aba9544bdfac", | ||
"metadata": { | ||
"tags": [ | ||
"hide-cell" | ||
] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"try:\n", | ||
" import tinygp\n", | ||
"except ImportError:\n", | ||
" %pip install -q tinygp\n", | ||
"\n", | ||
"try:\n", | ||
" import jaxopt\n", | ||
"except ImportError:\n", | ||
" %pip install -q jaxopt" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "d6d8253c-8c31-49be-b7d0-ead5b1dccfb5", | ||
"metadata": {}, | ||
"source": [ | ||
"(geometry)=\n", | ||
"\n", | ||
"# Tutorial: Custom Geometry\n", | ||
"\n", | ||
"When working with multivariate inputs, you'll always need to choose a metric for computing the distance between coordinates in your input space.\n", | ||
"As discussed in {ref}`multivariate`, `tinygp` includes built-in support for some common metrics which, when combined with {ref}`transforms`, can cover a wide range of use cases.\n", | ||
"But this tutorial covers a more general use case: custom geometries.\n", | ||
"\n", | ||
"In this example, we will fit a GP model to data that lives on the surface of a sphere.\n", | ||
"Here, we want to use our knowledge of this system to design a metric that takes this geometry into account.\n", | ||
"Specifically, our data will have unit vector coordinates, and we will define a [great-circle distance](https://en.wikipedia.org/wiki/Great-circle_distance#Vector_version) to compute the angular distances between these vectors." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "e98bab1c-2a77-4f0d-8e87-1ff7a01bce06", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import numpy as np\n", | ||
"import matplotlib.pyplot as plt\n", | ||
"import jax\n", | ||
"import jax.numpy as jnp\n", | ||
"from tinygp import kernels, GaussianProcess\n", | ||
"from jax.config import config\n", | ||
"\n", | ||
"config.update(\"jax_enable_x64\", True)\n", | ||
"\n", | ||
"\n", | ||
"class GreatCircleDistance(kernels.stationary.Distance):\n", | ||
" def distance(self, X1, X2):\n", | ||
" if jnp.shape(X1) != (3,) or jnp.shape(X2) != (3,):\n", | ||
" raise ValueError(\n", | ||
" \"The great-circle distance is only defined for unit 3-vector\"\n", | ||
" )\n", | ||
" return jnp.arctan2(jnp.linalg.norm(jnp.cross(X1, X2)), (X1.T @ X2))\n", | ||
"\n", | ||
"\n", | ||
"# Make a spherical grid\n", | ||
"phi = np.linspace(-np.pi, np.pi, 50)\n", | ||
"theta = np.linspace(-0.5 * np.pi, 0.5 * np.pi, 50)\n", | ||
"phi_grid, theta_grid = np.meshgrid(phi, theta, indexing=\"ij\")\n", | ||
"phi_grid = phi_grid.flatten()\n", | ||
"theta_grid = theta_grid.flatten()\n", | ||
"X_grid = np.vstack(\n", | ||
" (\n", | ||
" np.cos(phi_grid) * np.cos(theta_grid),\n", | ||
" np.sin(phi_grid) * np.cos(theta_grid),\n", | ||
" np.sin(theta_grid),\n", | ||
" )\n", | ||
").T\n", | ||
"\n", | ||
"# Choose some uniformly distributed coordinates to be our \"data\"\n", | ||
"random = np.random.default_rng(456)\n", | ||
"X_obs = random.normal(size=(100, 3))\n", | ||
"X_obs /= np.sqrt(np.sum(X_obs**2, axis=1))[:, None]\n", | ||
"theta_obs = np.arctan2(\n", | ||
" X_obs[:, 2], np.sqrt(X_obs[:, 0] ** 2 + X_obs[:, 1] ** 2)\n", | ||
")\n", | ||
"phi_obs = np.arctan2(X_obs[:, 1], X_obs[:, 0])\n", | ||
"\n", | ||
"# Our kernel is parameterized by a length scale in **radians**\n", | ||
"ell = 0.5\n", | ||
"kernel = 1.5 * kernels.Matern52(ell, distance=GreatCircleDistance())\n", | ||
"\n", | ||
"# Sample a simulated dataset\n", | ||
"gp = GaussianProcess(\n", | ||
" kernel, np.concatenate((X_grid, X_obs), axis=0), diag=0.01\n", | ||
")\n", | ||
"y_samp = gp.sample(jax.random.PRNGKey(10))\n", | ||
"y_grid = y_samp[: len(X_grid)]\n", | ||
"y_obs = y_samp[len(X_grid) :] + 0.5 * random.normal(size=len(X_obs))\n", | ||
"\n", | ||
"# Plot the map\n", | ||
"plt.pcolor(\n", | ||
" phi,\n", | ||
" theta,\n", | ||
" y_grid.reshape((len(phi), len(theta))).T,\n", | ||
" vmin=y_grid.min(),\n", | ||
" vmax=y_grid.max(),\n", | ||
")\n", | ||
"plt.scatter(\n", | ||
" phi_obs,\n", | ||
" theta_obs,\n", | ||
" c=y_obs,\n", | ||
" edgecolor=\"k\",\n", | ||
" vmin=y_grid.min(),\n", | ||
" vmax=y_grid.max(),\n", | ||
")\n", | ||
"plt.xlabel(r\"$\\phi$\")\n", | ||
"plt.ylabel(r\"$\\theta$\")\n", | ||
"_ = plt.title(\"simulated data\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "6efc27c0", | ||
"metadata": {}, | ||
"source": [ | ||
"Using these simulated data, we can now fit the model as usual:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "7d102f4d", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import jaxopt\n", | ||
"\n", | ||
"\n", | ||
"def build_gp(params):\n", | ||
" kernel = jnp.exp(params[\"log_amp\"]) * kernels.Matern52(\n", | ||
" jnp.exp(params[\"log_scale\"]), distance=GreatCircleDistance()\n", | ||
" )\n", | ||
" return GaussianProcess(\n", | ||
" kernel,\n", | ||
" X_obs,\n", | ||
" diag=jnp.exp(2 * params[\"log_sigma\"]),\n", | ||
" mean=params[\"mean\"],\n", | ||
" )\n", | ||
"\n", | ||
"\n", | ||
"@jax.jit\n", | ||
"def loss(params):\n", | ||
" return -build_gp(params).condition(y_obs)\n", | ||
"\n", | ||
"\n", | ||
"params = {\n", | ||
" \"log_amp\": np.zeros(()),\n", | ||
" \"log_scale\": np.zeros(()),\n", | ||
" \"log_sigma\": np.zeros(()),\n", | ||
" \"mean\": np.zeros(()),\n", | ||
"}\n", | ||
"solver = jaxopt.ScipyMinimize(fun=loss)\n", | ||
"soln = solver.run(params)\n", | ||
"gp = build_gp(soln.params)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "ab716742", | ||
"metadata": {}, | ||
"source": [ | ||
"At the maximum point, we can plot our model prediction compared to the ground truth, with the residuals plotted on the same scale:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "c93fe6d7", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"y_pred = gp.predict(y_obs, X_grid)\n", | ||
"\n", | ||
"fig, axes = plt.subplots(3, 1, sharex=True, figsize=(8, 8))\n", | ||
"\n", | ||
"axes[0].set_title(\"truth\")\n", | ||
"axes[0].pcolor(\n", | ||
" phi,\n", | ||
" theta,\n", | ||
" y_grid.reshape((len(phi), len(theta))).T,\n", | ||
" vmin=y_grid.min(),\n", | ||
" vmax=y_grid.max(),\n", | ||
")\n", | ||
"\n", | ||
"axes[1].set_title(\"predicted\")\n", | ||
"axes[1].pcolor(\n", | ||
" phi,\n", | ||
" theta,\n", | ||
" y_pred.reshape((len(phi), len(theta))).T,\n", | ||
" vmin=y_grid.min(),\n", | ||
" vmax=y_grid.max(),\n", | ||
")\n", | ||
"\n", | ||
"axes[2].set_title(\"residuals\")\n", | ||
"axes[2].pcolor(\n", | ||
" phi,\n", | ||
" theta,\n", | ||
" (y_pred - y_grid).reshape((len(phi), len(theta))).T,\n", | ||
" vmin=y_grid.min(),\n", | ||
" vmax=y_grid.max(),\n", | ||
")\n", | ||
"\n", | ||
"axes[2].set_xlabel(r\"$\\phi$\")\n", | ||
"for ax in axes:\n", | ||
" ax.set_ylabel(r\"$\\theta$\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "c46b306f", | ||
"metadata": {}, | ||
"source": [ | ||
"One thing that is worth commenting on here is that, unlike in {ref}`multivariate`, we're using only a single length scale.\n", | ||
"This means that our kernel is _isotropic_.\n", | ||
"For many use cases this is probably what you want because the whole point of this distance metric is that it is rotationally invariant.\n", | ||
"If you want to model or discover anisotropies, you could use the methods discussed in {ref}`transforms`, but it would probably be worth considering designing a kernel that better captures what you're trying to model." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "cfdaf219", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.9.9" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.