
feat: new optimization methods + poetry to hatch + cleaner linear search #24

Merged

merged 13 commits on May 28, 2024

62 changes: 62 additions & 0 deletions .github/workflows/ci.yaml
@@ -0,0 +1,62 @@
name: Build, test and publish

on:
  push:
    branches:
      - main
    tags:
      - "*"
  pull_request:
  release:
    types: [published]

jobs:
  build:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
        with:
          fetch-depth: 0
      - uses: actions/setup-python@v5
        name: Install Python
        with:
          python-version: "3.10"
      - name: Build sdist and wheel
        run: |
          python -m pip install -U pip
          python -m pip install -U build
          python -m build .
      - uses: actions/upload-artifact@v4
        with:
          path: dist/*

  test:
    runs-on: ubuntu-latest
    needs: [build]
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: "3.10"
      - name: Install dependencies
        run: python -m pip install -U pip
      - name: Install package and test dependencies
        run: python -m pip install ".[dev]"
      - name: Run tests
        run: python -m pytest

  publish:
    environment:
      name: pypi
      url: https://pypi.org/p/portrait
    permissions:
      id-token: write
    needs: [test]
    runs-on: ubuntu-latest
    if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/')
    steps:
      - uses: actions/download-artifact@v4
        with:
          name: artifact
          path: dist
      - uses: pypa/gh-action-pypi-publish@v1.8.14
29 changes: 0 additions & 29 deletions .github/workflows/ci.yml

This file was deleted.

50 changes: 0 additions & 50 deletions .github/workflows/publish.yml

This file was deleted.

5 changes: 3 additions & 2 deletions docs/index.md
@@ -61,8 +61,6 @@ Efficient detection of planets transiting quiet or active stars

markdown/install
notebooks/motivation.ipynb
notebooks/star.ipynb
notebooks/templates.ipynb
examples.md
```

@@ -81,5 +79,8 @@ notebooks/tutorials/exocomet.ipynb
:caption: Reference

markdown/how.ipynb
notebooks/star.ipynb
notebooks/templates.ipynb
markdown/hardware.md
markdown/API
```
102 changes: 102 additions & 0 deletions docs/markdown/hardware.md
@@ -0,0 +1,102 @@
# Hardware acceleration

When running the linear search, nuance exploits the parallelization capabilities of JAX, using a default mapping strategy that depends on the available devices.
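
The strategy is selected from the devices JAX reports at runtime. As a quick check, the backend and device count can be inspected directly (a minimal sketch using only standard JAX calls):

```python
import jax

# "cpu", "gpu" or "tpu", depending on the platform JAX selected
backend = jax.default_backend()

# number of devices available to jax.pmap on that backend
n_devices = jax.device_count()

print(backend, n_devices)
```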

## Solving for `(t0, D)`

To solve a particular model (like a transit) with a given epoch `t0` and duration `D`, we define the function

```python
import jax

@jax.jit
def solve(t0, D):
    m = model(time, t0, D)
    ll, w, v = nu._solve(m)
    return w[-1], v[-1, -1], ll
```

where `model` is the [template model](../notebooks/templates.ipynb), and `nu._solve` is the `Nuance._solve` method, which returns:

- `w[-1]`, the template model depth
- `v[-1, -1]`, the variance of the template model depth
- `ll`, the log-likelihood of the data given the model
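
For example, a single grid point can be evaluated directly (a minimal sketch, assuming `nu`, `model`, `time`, `t0s` and `Ds` are defined as elsewhere on this page):

```python
# depth estimate, its variance, and the log-likelihood for a transit
# of epoch t0s[0] and duration Ds[0]
depth, depth_var, ll = solve(t0s[0], Ds[0])
```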

## Batching over `(t0s, Ds)`
The goal of the linear search is then to call `solve` for a grid of epochs `t0s` and durations `Ds`. As `t0s` usually contains far more values than `Ds` (~1000 vs. ~10), the default strategy is to batch the `t0s`:

```python
# we pad to have fixed-size batches
t0s_padded = np.pad(t0s, [0, (batch_size - len(t0s) % batch_size) % batch_size])
t0s_batches = np.reshape(
    t0s_padded, (len(t0s_padded) // batch_size, batch_size)
)
```
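
As a concrete check of the padding arithmetic (an illustrative example, not part of the library): with 1002 epochs and a batch size of 8, 6 zeros are appended so the grid splits into 126 full batches.

```python
import numpy as np

t0s = np.linspace(0, 10, 1002)  # 1002 epochs
batch_size = 8

pad = (batch_size - len(t0s) % batch_size) % batch_size  # 6
t0s_padded = np.pad(t0s, [0, pad])  # 1008 values
t0s_batches = t0s_padded.reshape(-1, batch_size)

print(t0s_batches.shape)  # (126, 8)
```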

## JAX mapping

To solve a given batch in an optimal way, the `batch_size` can be set depending on the available devices (the resulting output shapes are sketched after this list):

- If multiple **CPUs** are available, the `batch_size` is chosen as the number of devices (`jax.device_count()`), and a given batch is solved using

  ```python
  solve_batch = jax.pmap(jax.vmap(solve, in_axes=(None, 0)), in_axes=(0, None))
  ```

  where each batch is `jax.pmap`ed over all available CPUs along the `t0s` axis.

- If a **GPU** is available, the `batch_size` can be larger and the batch is `jax.vmap`ed along `t0s`:

  ```python
  solve_batch = jax.vmap(jax.vmap(solve, in_axes=(None, 0)), in_axes=(0, None))
  ```
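
In both cases `solve_batch` has the same signature, and each output carries one value per `(t0, D)` pair in the batch (an illustrative sketch of the shapes):

```python
# three arrays, each of shape (batch_size, len(Ds)):
# depths, their variances, and the log-likelihoods
depths, depth_vars, lls = solve_batch(t0s_batches[0], Ds)
```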

The linear search then consists of iterating over `t0s_batches`:

```python
results = []

for t0_batch in t0s_batches:
    results.append(solve_batch(t0_batch, Ds))
```

```{note}
Of course, users familiar with JAX can apply their own mapping strategy to evaluate `solve` over a grid of epochs `t0s` and durations `Ds`, as in the sketch below.
```
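
For instance, on a small grid the batching can be skipped entirely and the whole grid evaluated with a double `jax.vmap` (a sketch of one such alternative, not what `linear_search` does by default):

```python
import jax

# map solve over all epochs and all durations at once;
# each output has shape (len(t0s), len(Ds))
solve_grid = jax.jit(jax.vmap(jax.vmap(solve, in_axes=(None, 0)), in_axes=(0, None)))

depths, depth_vars, lls = solve_grid(t0s, Ds)
```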

## The full method

The method `nuance.Nuance.linear_search` is then

```python
def linear_search(self, t0s, Ds):

    backend = jax.default_backend()
    batch_size = {"cpu": DEVICES_COUNT, "gpu": 1000}[backend]

    @jax.jit
    def solve(t0, D):
        m = self.model(self.time, t0, D)
        ll, w, v = self._solve(m)
        return jnp.array([w[-1], v[-1, -1], ll])

    # Batches
    t0s_padded = np.pad(t0s, [0, (batch_size - len(t0s) % batch_size) % batch_size])
    t0s_batches = np.reshape(
        t0s_padded, (len(t0s_padded) // batch_size, batch_size)
    )

    # Mapping
    if backend == "cpu":
        solve_batch = jax.pmap(jax.vmap(solve, in_axes=(None, 0)), in_axes=(0, None))
    else:
        solve_batch = jax.vmap(jax.vmap(solve, in_axes=(None, 0)), in_axes=(0, None))

    # Iterate
    results = []

    for t0_batch in t0s_batches:
        results.append(solve_batch(t0_batch, Ds))

    ...
```
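
The elided part of the method is not reproduced here. As a purely illustrative sketch (consistent with the shapes above, but not necessarily what `linear_search` actually does), the per-batch outputs could be combined and the padded epochs dropped as follows:

```python
import numpy as np

# each element of `results` has shape (batch_size, len(Ds), 3):
# depth, depth variance and log-likelihood for every (t0, D) in the batch
stacked = np.concatenate(results, axis=0)  # (len(t0s_padded), len(Ds), 3)
stacked = stacked[: len(t0s)]              # drop the padded epochs
depths, depth_vars, lls = np.moveaxis(stacked, -1, 0)
```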
Binary file removed docs/notebooks/exocomet.pdf

12 changes: 6 additions & 6 deletions docs/notebooks/single.ipynb

273 changes: 273 additions & 0 deletions docs/notebooks/tutorials/GP_optimization.ipynb

342 changes: 169 additions & 173 deletions docs/notebooks/tutorials/exocomet.ipynb

Binary file removed docs/notebooks/tutorials/exocomet.pdf

135 changes: 30 additions & 105 deletions docs/notebooks/tutorials/ground_based.ipynb

235 changes: 47 additions & 188 deletions docs/notebooks/tutorials/tess_search.ipynb

33 changes: 32 additions & 1 deletion nuance/core.py
@@ -2,9 +2,10 @@

import jax
import jax.numpy as jnp
import jaxopt


def eval_model(flux, X, gp):
def solve(flux, X, gp):
    Liy = gp.solver.solve_triangular(flux)
    LiX = gp.solver.solve_triangular(X.T)

@@ -22,6 +23,36 @@ def function(m):
    return function


def gp_model(x, y, build_gp, X=None):
    # default design matrix: a constant baseline
    if X is None:
        X = jnp.atleast_2d(jnp.ones_like(x))

    @jax.jit
    def nll_w(params):
        # solve the generalized least-squares problem for the linear
        # coefficients w, then evaluate the GP negative log-likelihood
        # of the residuals
        gp = build_gp(params, x)
        Liy = gp.solver.solve_triangular(y)
        LiX = gp.solver.solve_triangular(X.T)
        LiXT = LiX.T
        LiX2 = LiXT @ LiX
        w = jnp.linalg.lstsq(LiX2, LiXT @ Liy)[0]
        nll = -gp.log_probability(y - w @ X)
        return nll, w

    @jax.jit
    def nll(params):
        return nll_w(params)[0]

    @jax.jit
    def mu(params):
        # conditional GP mean plus the linear model
        gp = build_gp(params, x)
        _, w = nll_w(params)
        cond_gp = gp.condition(y - w @ X, x).gp
        return cond_gp.loc + w @ X

    return mu, nll


def transit_protopapas(t, t0, D, P=1e15, c=12):
    _t = P * jnp.sin(jnp.pi * (t - t0) / P) / (jnp.pi * D)
    return -0.5 * jnp.tanh(c * (_t + 1 / 2)) + 0.5 * jnp.tanh(c * (_t - 1 / 2))
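
Since `jaxopt` is now imported in `nuance/core.py`, a natural use of `gp_model` is to minimize the returned `nll` with a jaxopt optimizer. The sketch below is illustrative only: `x`, `y`, `build_gp` and the initial parameters are placeholders, not part of this diff.

```python
import jax.numpy as jnp
import jaxopt

# x, y: time and flux arrays; build_gp(params, x) must return a GP object
# exposing .solver, .log_probability and .condition (e.g. a tinygp GP)
mu, nll = gp_model(x, y, build_gp)

init_params = {"log_sigma": jnp.log(1e-3), "log_tau": jnp.log(1.0)}  # placeholder
solver = jaxopt.ScipyMinimize(fun=nll)
opt_params, state = solver.run(init_params)

detrended = y - mu(opt_params)  # flux with the GP mean and baseline removed
```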