-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: new optimization methods + poetry to hatch + cleaner linear search
- new optimization methods - poetry to hatch - cleaner linear search
- Loading branch information
Showing
17 changed files
with
875 additions
and
685 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
name: Build, test and publish | ||
|
||
on: | ||
push: | ||
branches: | ||
- main | ||
tags: | ||
- "*" | ||
pull_request: | ||
release: | ||
types: [published] | ||
|
||
jobs: | ||
build: | ||
runs-on: ubuntu-latest | ||
steps: | ||
- uses: actions/checkout@v4 | ||
with: | ||
fetch-depth: 0 | ||
- uses: actions/setup-python@v5 | ||
name: Install Python | ||
with: | ||
python-version: "3.10" | ||
- name: Build sdist and wheel | ||
run: | | ||
python -m pip install -U pip | ||
python -m pip install -U build | ||
python -m build . | ||
- uses: actions/upload-artifact@v4 | ||
with: | ||
path: dist/* | ||
|
||
test: | ||
runs-on: ubuntu-latest | ||
needs: [build] | ||
steps: | ||
- uses: actions/checkout@v4 | ||
- uses: actions/setup-python@v5 | ||
with: | ||
python-version: "3.10" | ||
- name: Install dependencies | ||
run: python -m pip install -U pip | ||
- name: Install package and test dependencies | ||
run: python -m pip install ".[dev]" | ||
- name: Run tests | ||
run: python -m pytest | ||
|
||
publish: | ||
environment: | ||
name: pypi | ||
url: https://pypi.org/p/portrait | ||
permissions: | ||
id-token: write | ||
needs: [test] | ||
runs-on: ubuntu-latest | ||
if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/') | ||
steps: | ||
- uses: actions/download-artifact@v4 | ||
with: | ||
name: artifact | ||
path: dist | ||
- uses: pypa/gh-action-pypi-publish@v1.8.14 |
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
# Hardware acceleration | ||
|
||
When running the linear search, nuance exploits the parallelization capabilities of JAX by using a default mapping strategy depending on the available devices. | ||
|
||
## Solving for `(t0, D)` | ||
|
||
To solve a particular model (like a transit) with a given epoch `t0` and duration `D`, we define the function | ||
|
||
```python | ||
import jax | ||
|
||
@jax.jit | ||
def solve(t0, D): | ||
m = model(time, t0, D) | ||
ll, w, v = nu._solve(m) | ||
return w[-1], v[-1, -1], ll | ||
``` | ||
|
||
where `model` is the [template model](../notebooks/templates.ipynb), and `nu._solve` is the `Nuance._solve` method returning: | ||
|
||
- `w[-1]` the template model depth | ||
- `v[-1, -1]` the variance of the template model depth | ||
- `ll` the log-likelihood of the data to the model | ||
|
||
## Batching over `(t0s, Ds)` | ||
The goal of the linear search is then to call `solve` for a grid of of epochs `t0s` and durations `Ds`. As `t0s` is usually very large compared to `Ds` (~1000 vs. ~10), the default strategy is to batch the `t0s`: | ||
|
||
```python | ||
# we pad to have fixed size batches | ||
t0s_padded = np.pad(t0s, [0, batch_size - (len(t0s) % batch_size) % batch_size]) | ||
t0s_batches = np.reshape( | ||
t0s_padded, (len(t0s_padded) // batch_size, batch_size) | ||
) | ||
``` | ||
|
||
## JAX mapping | ||
|
||
In order to solve a given batch in an optimal way, the `batch_size` can be set depending on the devices available: | ||
|
||
- If multiple **CPUs** are available, the `batch_size` is chosen as the number of devices (`jax.device_count()`) and we can solve a given batch using | ||
|
||
```python | ||
solve_batch = jax.pmap(jax.vmap(solve, in_axes=(None, 0)), in_axes=(0, None)) | ||
``` | ||
|
||
where each batch is `jax.pmap`ed over all available CPUs along the `t0s` axis. | ||
|
||
- If a **GPU** is available, the `batch_size` can be larger and the batch is `jax.vmap`ed along `t0s` | ||
|
||
```python | ||
solve_batch = jax.vmap(jax.vmap(solve, in_axes=(None, 0)), in_axes=(0, None)) | ||
``` | ||
|
||
Then, the linear search consists in iterating over `t0s_batches`: | ||
|
||
```python | ||
results = [] | ||
|
||
for t0_batch in t0s_batches: | ||
results.append(solve_batch(t0_batch, Ds)) | ||
``` | ||
|
||
```{note} | ||
Of course, one familiar with JAX can use their own mapping strategy to evaluate `solve` over a grid of epochs `t0s` and durations `Ds`. | ||
``` | ||
|
||
## The full method | ||
|
||
The method `nuance.Naunce.linear_search` is then | ||
|
||
```python | ||
def linear_search( self, t0s, Ds): | ||
|
||
backend = jax.default_backend() | ||
batch_size = {"cpu": DEVICES_COUNT, "gpu": 1000}[backend] | ||
|
||
@jax.jit | ||
def solve(t0, D): | ||
m = self.model(self.time, t0, D) | ||
ll, w, v = self._solve(m) | ||
return jnp.array([w[-1], v[-1, -1], ll]) | ||
|
||
# Batches | ||
t0s_padded = np.pad(t0s, [0, batch_size - (len(t0s) % batch_size) % batch_size]) | ||
t0s_batches = np.reshape( | ||
t0s_padded, (len(t0s_padded) // batch_size, batch_size) | ||
) | ||
|
||
# Mapping | ||
if backend == "cpu": | ||
solve_batch = jax.pmap(jax.vmap(solve, in_axes=(None, 0)), in_axes=(0, None)) | ||
else: | ||
solve_batch = jax.vmap(jax.vmap(solve, in_axes=(None, 0)), in_axes=(0, None)) | ||
|
||
# Iterate | ||
results = [] | ||
|
||
for t0_batch in t0s_batches: | ||
results.append(solve_batch(t0_batch, Ds)) | ||
|
||
... | ||
``` |
Binary file not shown.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Binary file not shown.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.