Skip to content

Commit

Permalink
fix: t0s grid + CombinedNuance features + new tutorial
Browse files Browse the repository at this point in the history
fix: t0s grid + CombinedNuance features + new tutorial
  • Loading branch information
lgrcia authored Feb 6, 2024
2 parents 12ad2de + f1736ba commit c09d943
Show file tree
Hide file tree
Showing 15 changed files with 677 additions and 648 deletions.
32 changes: 17 additions & 15 deletions docs/notebooks/combined.ipynb

Large diffs are not rendered by default.

30 changes: 14 additions & 16 deletions docs/notebooks/multi.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,7 @@
" CpuDevice(id=6),\n",
" CpuDevice(id=7),\n",
" CpuDevice(id=8),\n",
" CpuDevice(id=9),\n",
" CpuDevice(id=10),\n",
" CpuDevice(id=11)]"
" CpuDevice(id=9)]"
]
},
"execution_count": 1,
Expand Down Expand Up @@ -81,7 +79,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/lgarcia/code/dev/nuance/nuance/combined.py:7: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)\n",
"/Users/lgrcia/code/dev/nuance/nuance/combined.py:7: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)\n",
" from tqdm.autonotebook import tqdm\n"
]
},
Expand Down Expand Up @@ -152,12 +150,12 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "52a1143ec15c42d0ac62c8c4da2e0d77",
"model_id": "b48b609eac8c478cb5630c685587f851",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/84 [00:00<?, ?it/s]"
" 0%| | 0/100 [00:00<?, ?it/s]"
]
},
"metadata": {},
Expand Down Expand Up @@ -211,13 +209,13 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "99bdae816415436eb32eecaf88d2bf80",
"model_id": "1126a6088496485f9d1a3f090a166c66",
"version_major": 2,
"version_minor": 0
},
Expand All @@ -240,10 +238,10 @@
{
"data": {
"text/plain": [
"<matplotlib.legend.Legend at 0x2d46a2290>"
"<matplotlib.legend.Legend at 0x17768e580>"
]
},
"execution_count": 6,
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
},
Expand Down Expand Up @@ -317,7 +315,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -344,13 +342,13 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "73de98f383484957acc59fa4a045bd1e",
"model_id": "e1ce39b180d447d6be6e1c0b4cbde82d",
"version_major": 2,
"version_minor": 0
},
Expand All @@ -373,10 +371,10 @@
{
"data": {
"text/plain": [
"<matplotlib.legend.Legend at 0x2d46447d0>"
"<matplotlib.legend.Legend at 0x2b51d74f0>"
]
},
"execution_count": 11,
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
},
Expand Down Expand Up @@ -455,7 +453,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
"version": "3.8.13"
},
"orig_nbformat": 4,
"vscode": {
Expand Down
18 changes: 8 additions & 10 deletions docs/notebooks/periodic.ipynb

Large diffs are not rendered by default.

12 changes: 5 additions & 7 deletions docs/notebooks/single.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,7 @@
" CpuDevice(id=6),\n",
" CpuDevice(id=7),\n",
" CpuDevice(id=8),\n",
" CpuDevice(id=9),\n",
" CpuDevice(id=10),\n",
" CpuDevice(id=11)]"
" CpuDevice(id=9)]"
]
},
"execution_count": 1,
Expand All @@ -61,7 +59,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/lgarcia/code/dev/nuance/nuance/combined.py:7: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)\n",
"/Users/lgrcia/code/dev/nuance/nuance/combined.py:7: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)\n",
" from tqdm.autonotebook import tqdm\n"
]
},
Expand Down Expand Up @@ -109,12 +107,12 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0d08808b1e0a4879ba363307b474c9d8",
"model_id": "61fb5d6d76c048dfb2d842a437de8f10",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/167 [00:00<?, ?it/s]"
" 0%| | 0/200 [00:00<?, ?it/s]"
]
},
"metadata": {},
Expand Down Expand Up @@ -247,7 +245,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
"version": "3.8.13"
},
"orig_nbformat": 4,
"vscode": {
Expand Down
458 changes: 458 additions & 0 deletions docs/notebooks/tutorials/ground_based.ipynb

Large diffs are not rendered by default.

27 changes: 17 additions & 10 deletions docs/notebooks/tutorials/tess_search.ipynb

Large diffs are not rendered by default.

45 changes: 0 additions & 45 deletions docs/source/multi.ipynb

This file was deleted.

279 changes: 0 additions & 279 deletions docs/source/periodic.ipynb

This file was deleted.

215 changes: 0 additions & 215 deletions docs/source/single.ipynb

This file was deleted.

54 changes: 27 additions & 27 deletions nuance/combined.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ def __post_init__(self):
self._fill_search_data()
self._compute_L()

def __getitem__(self, i):
return self.datasets[i]

@property
def model(self):
"""The model"""
Expand All @@ -44,7 +47,14 @@ def model(self):
def _fill_search_data(self):
if all([d.search_data is not None for d in self.datasets]):
t0s = np.hstack([d.search_data.t0s for d in self.datasets])
Ds = np.hstack([d.search_data.Ds for d in self.datasets])
all_Ds = [n.search_data.Ds for n in self.datasets]
all_equal = (
np.diff(np.vstack(all_Ds).reshape(len(all_Ds), -1), axis=0) == 0
).all()
assert (
all_equal
), "All datasets linear searches must have same duration grids"
Ds = self.datasets[0].search_data.Ds
ll = np.vstack([d.search_data.ll for d in self.datasets])
z = np.vstack([d.search_data.z for d in self.datasets])
vz = np.vstack([d.search_data.vz for d in self.datasets])
Expand Down Expand Up @@ -73,7 +83,7 @@ def _compute_L(self):
Liy = solve_triangular(*[(d.gp, d.flux) for d in self.datasets])
LiX = solve_triangular(*[(d.gp, d.X.T) for d in self.datasets])

def eval_m(ms):
def eval_model(ms):
Lim = solve_triangular(*[(d.gp, m) for d, m in zip(self.datasets, ms)])
LiXm = jnp.hstack([LiX, Lim[:, None]])
LiXmT = LiXm.T
Expand All @@ -82,18 +92,13 @@ def eval_m(ms):
v = jnp.linalg.inv(LimX2)
return w, v

self.eval_m = eval_m
self.eval_model = eval_model

def linear_search(self, t0s, Ds, progress=True):
for d in self.datasets:
d.linear_search(t0s, Ds, progress=progress)

def periodic_transits(self, t0, D, P, c=None):
if c is None:
c = self.c
return [self.model(d.search_data.t0s, t0, D, P) for d in self.datasets]

def solve(self, t0, D, P, c=None):
def solve(self, t0, D, P):
"""Solve the combined model for a given set of parameters.
Parameters
Expand All @@ -102,7 +107,7 @@ def solve(self, t0, D, P, c=None):
epoch, same unit as time
D : float
duration, same unit as time
P : float, optional
P : float, optionale
period, same unit as time, by default None
c : float, optional
c parameter of the transit model, by default None
Expand All @@ -112,12 +117,11 @@ def solve(self, t0, D, P, c=None):
list
(w, v): linear coefficients and their covariance matrix
"""
if c is None:
c = self.c
w, v = self.eval_m(self.periodic_transits(t0, D, P, c))
models = [self.model(d.search_data.t0s, t0, D, P) for d in self.datasets]
w, v = self.eval_model(models)
return w, v

def snr(self, t0, D, P, c=None):
def snr(self, t0, D, P):
"""SNR of transit linearly solved for epoch `t0` and duration `D` (and period `P` for a periodic transit)
Parameters
Expand All @@ -136,10 +140,8 @@ def snr(self, t0, D, P, c=None):
float
transit snr
"""
if c is None:
c = self.c
w, v = self.solve(t0, D, P, c)
return np.max([0, w[-1] / jnp.sqrt(v[-1, -1])])
w, v = self.solve(t0, D, P)
return jnp.max(jnp.array([0, w[-1] / jnp.sqrt(v[-1, -1])]))

def periodic_search(self, periods, dphi=0.01):
"""Performs the periodic search
Expand Down Expand Up @@ -193,7 +195,7 @@ def _search(p):

return new_search_data

def models(self, t0, D, P):
def models(self, t0, D, P, split=False):
"""Solve the combined model for a given set of parameters.
Parameters
Expand All @@ -204,18 +206,14 @@ def models(self, t0, D, P):
duration, same unit as time
P : float, optional
period, same unit as time, by default None
c : float, optional
c parameter of the transit model, by default None
Returns
-------
list
(w, v): linear coefficients and their covariance matrix
"""
if c is None:
c = self.c
m = self.model(t0, D, P)
w, _ = self.eval_m(m)
ms = [d.model(d.time, t0, D, P) for d in self.datasets]
w, _ = self.eval_model(ms)

# means
w_idxs = [0, *np.cumsum([d.X.shape[0] for d in self.datasets])]
Expand All @@ -231,8 +229,10 @@ def models(self, t0, D, P):
for i, d in enumerate(self.datasets):
_, cond = d.gp.condition(d.flux - means[i] - signals[i])
noises.append(cond.mean)

return np.hstack(means), np.hstack(signals), np.hstack(noises)
if split:
return means, signals, noises
else:
return np.hstack(means), np.hstack(signals), np.hstack(noises)

def mask_model(self, t0: float, D: float, P: float):
new_self = self.__class__(
Expand Down
25 changes: 23 additions & 2 deletions nuance/core.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from functools import partial

import jax
import jax.numpy as jnp

Expand All @@ -21,11 +23,30 @@ def function(m):


@jax.jit
def transit_protopapas(t, t0, D, P=1e15, c=12, d=1.0):
def transit_protopapas(t, t0, D, P=1e15, c=12):
_t = P * jnp.sin(jnp.pi * (t - t0) / P) / (jnp.pi * D)
return -d * 0.5 * jnp.tanh(c * (_t + 1 / 2)) + 0.5 * jnp.tanh(c * (_t - 1 / 2))
return -0.5 * jnp.tanh(c * (_t + 1 / 2)) + 0.5 * jnp.tanh(c * (_t - 1 / 2))


@jax.jit
def transit_box(time, t0, D, P=1e15):
return -((jnp.abs(time - t0) % P) < D / 2).astype(float)


def map_function(eval_function, model, time, backend, map_t0, map_D):
jitted_eval = jax.jit(eval_function, backend=backend)

@jax.jit
def single_eval(t0, D):
m = model(time, t0, D)
ll, w, v = jitted_eval(m)
return w[-1], v[-1, -1], ll

t0s_eval = map_t0(single_eval, in_axes=(0, None))
ds_t0s_eval = map_D(t0s_eval, in_axes=(None, 0))

return ds_t0s_eval


pmap_cpus = partial(map_function, backend="cpu", map_t0=jax.pmap, map_D=jax.vmap)
vmap_gpu = partial(map_function, backend="gpu", map_t0=jax.vmap, map_D=jax.vmap)
Loading

0 comments on commit c09d943

Please sign in to comment.