feat: new optimization methods + poetry to hatch + cleaner linear search
- new optimization methods
- poetry to hatch
- cleaner linear search
lgrcia authored May 28, 2024
2 parents 2ff4457 + 55341af commit 254892d
Showing 17 changed files with 875 additions and 685 deletions.
62 changes: 62 additions & 0 deletions .github/workflows/ci.yaml
@@ -0,0 +1,62 @@
name: Build, test and publish

on:
  push:
    branches:
      - main
    tags:
      - "*"
  pull_request:
  release:
    types: [published]

jobs:
  build:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
        with:
          fetch-depth: 0
      - uses: actions/setup-python@v5
        name: Install Python
        with:
          python-version: "3.10"
      - name: Build sdist and wheel
        run: |
          python -m pip install -U pip
          python -m pip install -U build
          python -m build .
      - uses: actions/upload-artifact@v4
        with:
          path: dist/*

  test:
    runs-on: ubuntu-latest
    needs: [build]
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: "3.10"
      - name: Install dependencies
        run: python -m pip install -U pip
      - name: Install package and test dependencies
        run: python -m pip install ".[dev]"
      - name: Run tests
        run: python -m pytest

  publish:
    environment:
      name: pypi
      url: https://pypi.org/p/portrait
    permissions:
      id-token: write
    needs: [test]
    runs-on: ubuntu-latest
    if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/')
    steps:
      - uses: actions/download-artifact@v4
        with:
          name: artifact
          path: dist
      - uses: pypa/gh-action-pypi-publish@v1.8.14
29 changes: 0 additions & 29 deletions .github/workflows/ci.yml

This file was deleted.

50 changes: 0 additions & 50 deletions .github/workflows/publish.yml

This file was deleted.

5 changes: 3 additions & 2 deletions docs/index.md
@@ -61,8 +61,6 @@ Efficient detection of planets transiting quiet or active stars
markdown/install
notebooks/motivation.ipynb
notebooks/star.ipynb
notebooks/templates.ipynb
examples.md
```

@@ -81,5 +79,8 @@ notebooks/tutorials/exocomet.ipynb
:caption: Reference
markdown/how.ipynb
notebooks/star.ipynb
notebooks/templates.ipynb
markdown/hardware.md
markdown/API
```
102 changes: 102 additions & 0 deletions docs/markdown/hardware.md
@@ -0,0 +1,102 @@
# Hardware acceleration

When running the linear search, nuance exploits the parallelization capabilities of JAX, using a default mapping strategy that depends on the available devices.

## Solving for `(t0, D)`

To solve a particular model (like a transit) with a given epoch `t0` and duration `D`, we define the function

```python
import jax

@jax.jit
def solve(t0, D):
    m = model(time, t0, D)
    ll, w, v = nu._solve(m)
    return w[-1], v[-1, -1], ll
```

where `model` is the [template model](../notebooks/templates.ipynb), and `nu._solve` is the `Nuance._solve` method returning:

- `w[-1]`, the template model depth
- `v[-1, -1]`, the variance of the template model depth
- `ll`, the log-likelihood of the data given the model
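As a quick illustration, a single evaluation might look like the following (a sketch: the `(t0, D)` values are arbitrary, in the same units as `time`):

```python
# depth, its variance, and the log-likelihood for one (t0, D) pair
depth, depth_variance, log_likelihood = solve(0.2, 0.05)
```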

## Batching over `(t0s, Ds)`

The goal of the linear search is then to call `solve` over a grid of epochs `t0s` and durations `Ds`. As `t0s` is usually much larger than `Ds` (~1000 vs. ~10 values), the default strategy is to batch the `t0s`:

```python
import numpy as np

# pad so the number of epochs is a whole multiple of batch_size
t0s_padded = np.pad(t0s, [0, (batch_size - len(t0s) % batch_size) % batch_size])
t0s_batches = np.reshape(t0s_padded, (len(t0s_padded) // batch_size, batch_size))
```
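For example, with 1003 epochs and a batch size of 100, 97 zeros are appended and the epochs fold into 11 batches (illustrative values):

```python
import numpy as np

t0s = np.linspace(0, 10, 1003)
batch_size = 100
pad = (batch_size - len(t0s) % batch_size) % batch_size  # 97
t0s_batches = np.reshape(np.pad(t0s, [0, pad]), (-1, batch_size))
assert t0s_batches.shape == (11, 100)
```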

## JAX mapping

To solve a given batch optimally, the `batch_size` can be set according to the available devices:

- If multiple **CPUs** are available, the `batch_size` is chosen as the number of devices (`jax.device_count()`) and we can solve a given batch using

```python
solve_batch = jax.pmap(jax.vmap(solve, in_axes=(None, 0)), in_axes=(0, None))
```

where each batch is `jax.pmap`ed over all available CPUs along the `t0s` axis.

- If a **GPU** is available, the `batch_size` can be larger and the batch is `jax.vmap`ed along `t0s`:

```python
solve_batch = jax.vmap(jax.vmap(solve, in_axes=(None, 0)), in_axes=(0, None))
```

The linear search then consists of iterating over `t0s_batches`:

```python
results = []

for t0_batch in t0s_batches:
    results.append(solve_batch(t0_batch, Ds))
```
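The per-batch outputs then need to be reassembled and the padded epochs discarded. A minimal sketch, assuming each `solve_batch` call returns a `(w, v, ll)` tuple of `(batch_size, len(Ds))` arrays as with the mappings above:

```python
import numpy as np

# group the w, v, ll outputs across batches, stack them, and trim the padding
w, v, ll = (np.concatenate(r, axis=0)[: len(t0s)] for r in zip(*results))
```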

```{note}
Of course, anyone familiar with JAX can use their own mapping strategy to evaluate `solve` over a grid of epochs `t0s` and durations `Ds`.
```
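For instance, a custom strategy could skip the batching and padding entirely by letting `jax.lax.map` scan over epochs while vectorizing over durations (a sketch, not part of nuance's API, trading speed for memory):

```python
import jax

def solve_all(t0s, Ds):
    # sequential over epochs, vectorized over durations: no padding required
    solve_Ds = jax.vmap(solve, in_axes=(None, 0))
    return jax.lax.map(lambda t0: solve_Ds(t0, Ds), t0s)
```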

## The full method

The method `nuance.Nuance.linear_search` is then

```python
def linear_search(self, t0s, Ds):
    backend = jax.default_backend()
    batch_size = {"cpu": DEVICES_COUNT, "gpu": 1000}[backend]

    @jax.jit
    def solve(t0, D):
        m = self.model(self.time, t0, D)
        ll, w, v = self._solve(m)
        return jnp.array([w[-1], v[-1, -1], ll])

    # Batches
    t0s_padded = np.pad(t0s, [0, (batch_size - len(t0s) % batch_size) % batch_size])
    t0s_batches = np.reshape(t0s_padded, (len(t0s_padded) // batch_size, batch_size))

    # Mapping
    if backend == "cpu":
        solve_batch = jax.pmap(jax.vmap(solve, in_axes=(None, 0)), in_axes=(0, None))
    else:
        solve_batch = jax.vmap(jax.vmap(solve, in_axes=(None, 0)), in_axes=(0, None))

    # Iterate
    results = []

    for t0_batch in t0s_batches:
        results.append(solve_batch(t0_batch, Ds))

    ...
```
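where what happens to the collected `results` is elided. A hypothetical call, with illustrative grids (`nu` is a `Nuance` instance and `time` its time axis; units assumed to be days):

```python
import numpy as np

t0s = np.arange(time.min(), time.max(), 2 / 60 / 24)  # 2-minute epoch steps
Ds = np.linspace(0.01, 0.2, 10)  # trial durations
nu.linear_search(t0s, Ds)
```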
Binary file removed docs/notebooks/exocomet.pdf

12 changes: 6 additions & 6 deletions docs/notebooks/single.ipynb

273 changes: 273 additions & 0 deletions docs/notebooks/tutorials/GP_optimization.ipynb

342 changes: 169 additions & 173 deletions docs/notebooks/tutorials/exocomet.ipynb

Binary file removed docs/notebooks/tutorials/exocomet.pdf

135 changes: 30 additions & 105 deletions docs/notebooks/tutorials/ground_based.ipynb

235 changes: 47 additions & 188 deletions docs/notebooks/tutorials/tess_search.ipynb

33 changes: 32 additions & 1 deletion nuance/core.py
@@ -2,9 +2,10 @@

 import jax
 import jax.numpy as jnp
+import jaxopt


-def eval_model(flux, X, gp):
+def solve(flux, X, gp):
     Liy = gp.solver.solve_triangular(flux)
     LiX = gp.solver.solve_triangular(X.T)

@@ -22,6 +23,36 @@ def function(m):
     return function


+def gp_model(x, y, build_gp, X=None):
+
+    if X is None:
+        X = jnp.atleast_2d(jnp.ones_like(x))
+
+    @jax.jit
+    def nll_w(params):
+        gp = build_gp(params, x)
+        Liy = gp.solver.solve_triangular(y)
+        LiX = gp.solver.solve_triangular(X.T)
+        LiXT = LiX.T
+        LiX2 = LiXT @ LiX
+        w = jnp.linalg.lstsq(LiX2, LiXT @ Liy)[0]
+        nll = -gp.log_probability(y - w @ X)
+        return nll, w
+
+    @jax.jit
+    def nll(params):
+        return nll_w(params)[0]
+
+    @jax.jit
+    def mu(params):
+        gp = build_gp(params, x)
+        _, w = nll_w(params)
+        cond_gp = gp.condition(y - w @ X, x).gp
+        return cond_gp.loc + w @ X
+
+    return mu, nll
+
+
 def transit_protopapas(t, t0, D, P=1e15, c=12):
     _t = P * jnp.sin(jnp.pi * (t - t0) / P) / (jnp.pi * D)
     return -0.5 * jnp.tanh(c * (_t + 1 / 2)) + 0.5 * jnp.tanh(c * (_t - 1 / 2))
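
The `nll` function returned by `gp_model` is a jittable objective, which pairs naturally with the newly imported `jaxopt`. A sketch of how the GP hyperparameters might be optimized (the tinygp-style `build_gp` builder, its parameter names, and the toy data are assumptions for illustration, not part of this diff):

```python
import jax.numpy as jnp
import jaxopt
import numpy as np
from tinygp import GaussianProcess, kernels

from nuance.core import gp_model

# toy light curve: x is time, y is flux (hypothetical data)
x = np.linspace(0.0, 3.0, 500)
y = np.random.default_rng(42).normal(1.0, 0.01, 500)

def build_gp(params, x):
    # hypothetical builder; any function returning a tinygp GaussianProcess works
    kernel = jnp.exp(params["log_amp"]) * kernels.ExpSquared(jnp.exp(params["log_scale"]))
    return GaussianProcess(kernel, x, diag=jnp.exp(params["log_diag"]))

mu, nll = gp_model(x, y, build_gp)
init = {"log_amp": 0.0, "log_scale": 0.0, "log_diag": -8.0}

# scipy-backed minimization of the negative log-likelihood
opt_params, state = jaxopt.ScipyMinimize(fun=nll).run(init)
prediction = mu(opt_params)  # conditioned GP mean plus fitted linear mean
```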
