Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move counting routines to estimator class #25

Merged
merged 8 commits into from
Jul 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/corrgi/correlation/angular_correlation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Callable

import gundam.cflibfor as cff
import numpy as np
import pandas as pd
from gundam import gundam
from hipscat.catalog.catalog_info import CatalogInfo
Expand Down Expand Up @@ -64,3 +65,7 @@ def _construct_cross_args(
*args[5:],
]
return args

def get_bdd_counts(self) -> np.ndarray:
    """Returns the bootstrap counts for the angular correlation.

    Returns:
        A zero-filled array of shape `(nsept, 0)`, where `nsept` is read
        from the correlation parameters.
    """
    shape = [self.params.nsept, 0]
    return np.zeros(shape)
9 changes: 9 additions & 0 deletions src/corrgi/correlation/correlation.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,15 @@
"""Generate the arguments required for the cross pairing method"""
raise NotImplementedError()

@abstractmethod
def get_bdd_counts(self) -> np.ndarray:
    """Returns the bootstrap counts for the correlation.

    Abstract: this base implementation always raises `NotImplementedError`.
    """
    raise NotImplementedError()

Check warning on line 89 in src/corrgi/correlation/correlation.py

View check run for this annotation

Codecov / codecov/patch

src/corrgi/correlation/correlation.py#L89

Added line #L89 was not covered by tests

def transform_counts(self, counts: list[np.ndarray]) -> list[np.ndarray]:
    """Applies final transformations to the correlation counts.

    Args:
        counts (list[np.ndarray]): The correlation counts.

    Returns:
        The same `counts` object, unchanged (this base implementation
        applies no transformation).
    """
    return counts

@staticmethod
def get_coords(df: pd.DataFrame, catalog_info: CatalogInfo) -> tuple[float, float, float]:
"""Calculate the cartesian coordinates for the points in the partition"""
Expand Down
8 changes: 8 additions & 0 deletions src/corrgi/correlation/projected_correlation.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,11 @@ def _construct_cross_args(
*args[7:],
]
return args

def transform_counts(self, counts: list[np.ndarray]) -> list[np.ndarray]:
    """Transposes every projected counts array.

    The projected counts need to be transposed before being sent to Fortran.

    Args:
        counts (list[np.ndarray]): The counts arrays; each one has its two
            axes swapped.

    Returns:
        A new list with the transposed arrays, in the same order.
    """
    axes_order = [1, 0]
    return [histogram.transpose(axes_order) for histogram in counts]

def get_bdd_counts(self) -> np.ndarray:
    """Returns the bootstrap counts for the projected correlation.

    Returns:
        A zero-filled array of shape `(nsepp, nsepv, 0)`, with both sizes
        read from the correlation parameters.
    """
    params = self.params
    return np.zeros([params.nsepp, params.nsepv, 0])
9 changes: 9 additions & 0 deletions src/corrgi/correlation/redshift_correlation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Callable

import numpy as np

Check warning on line 3 in src/corrgi/correlation/redshift_correlation.py

View check run for this annotation

Codecov / codecov/patch

src/corrgi/correlation/redshift_correlation.py#L3

Added line #L3 was not covered by tests
import pandas as pd
from hipscat.catalog.catalog_info import CatalogInfo
from lsdb import Catalog
Expand Down Expand Up @@ -35,3 +36,11 @@
right_catalog_info: CatalogInfo,
) -> list:
raise NotImplementedError()

def get_bdd_counts(self) -> np.ndarray:

Check warning on line 40 in src/corrgi/correlation/redshift_correlation.py

View check run for this annotation

Codecov / codecov/patch

src/corrgi/correlation/redshift_correlation.py#L40

Added line #L40 was not covered by tests
"""Returns the bootstrap counts for the correlation"""
raise NotImplementedError()

Check warning on line 42 in src/corrgi/correlation/redshift_correlation.py

View check run for this annotation

Codecov / codecov/patch

src/corrgi/correlation/redshift_correlation.py#L42

Added line #L42 was not covered by tests

def transform_counts(self, counts: list[np.ndarray]) -> list[np.ndarray]:

Check warning on line 44 in src/corrgi/correlation/redshift_correlation.py

View check run for this annotation

Codecov / codecov/patch

src/corrgi/correlation/redshift_correlation.py#L44

Added line #L44 was not covered by tests
"""Applies final transformations to the correlation counts"""
raise NotImplementedError()

Check warning on line 46 in src/corrgi/correlation/redshift_correlation.py

View check run for this annotation

Codecov / codecov/patch

src/corrgi/correlation/redshift_correlation.py#L46

Added line #L46 was not covered by tests
10 changes: 3 additions & 7 deletions src/corrgi/corrgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@
from munch import Munch

from corrgi.correlation.correlation import Correlation
from corrgi.dask import compute_autocorrelation_counts
from corrgi.estimators import calculate_natural_estimate
from corrgi.utils import compute_catalog_size
from corrgi.estimators.estimator_factory import get_estimator_for_correlation


def compute_autocorrelation(
Expand All @@ -25,10 +23,8 @@ def compute_autocorrelation(
"""
correlation = corr_type(**kwargs)
correlation.validate([catalog, random])
num_galaxies = compute_catalog_size(catalog)
num_random = compute_catalog_size(random)
counts_dd, counts_rr = compute_autocorrelation_counts(catalog, random, correlation)
return calculate_natural_estimate(counts_dd, counts_rr, num_galaxies, num_random)
estimator = get_estimator_for_correlation(correlation)
return estimator.compute_auto_estimate(catalog, random)


def compute_crosscorrelation(left: Catalog, right: Catalog, random: Catalog, params: Munch) -> np.ndarray:
Expand Down
16 changes: 0 additions & 16 deletions src/corrgi/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,6 @@
from corrgi.utils import join_count_histograms


def compute_autocorrelation_counts(
    catalog: Catalog, random: Catalog, correlation: Correlation
) -> tuple[np.ndarray, np.ndarray]:
    """Computes the auto-correlation counts for a catalog.

    Args:
        catalog (Catalog): The catalog with galaxy samples.
        random (Catalog): The catalog with random samples.
        correlation (Correlation): The correlation instance.

    Returns:
        A tuple `(counts_dd, counts_rr)` with the histogram counts used to
        calculate the auto-correlation.
    """
    counts_dd = perform_auto_counts(catalog, correlation)
    counts_rr = perform_auto_counts(random, correlation)
    # dask.compute returns a tuple with one computed result per argument,
    # so DD and RR are evaluated together in a single call.
    return dask.compute(counts_dd, counts_rr)


def perform_auto_counts(catalog: Catalog, *args) -> np.ndarray:
"""Aligns the pixel of a single catalog and performs the pairs counting.

Expand Down
25 changes: 0 additions & 25 deletions src/corrgi/estimators.py

This file was deleted.

Empty file.
66 changes: 66 additions & 0 deletions src/corrgi/estimators/estimator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Callable

import numpy as np
from gundam.gundam import tpcf, tpcf_wrp
from lsdb import Catalog

from corrgi.correlation.correlation import Correlation
from corrgi.correlation.projected_correlation import ProjectedCorrelation
from corrgi.utils import compute_catalog_size


class Estimator(ABC):
"""Estimator base class"""

def __init__(self, correlation: Correlation) -> None:
    """Stores the correlation this estimator operates on.

    Args:
        correlation (Correlation): The correlation instance, kept as
            `self.correlation` for use by the estimator methods.
    """
    self.correlation = correlation

def compute_auto_estimate(self, catalog: Catalog, random: Catalog) -> np.ndarray:
    """Computes the auto-correlation for this estimator.

    The catalog sizes and the pair counts are gathered first, then handed
    to the estimator subroutine selected for the wrapped correlation.

    Args:
        catalog (Catalog): The catalog of galaxy samples.
        random (Catalog): The catalog of random samples.

    Returns:
        The statistical estimate of the auto-correlation function, as a numpy array.
    """
    n_data = compute_catalog_size(catalog)
    n_rand = compute_catalog_size(random)
    counts_dd, counts_rr, counts_dr = self.compute_autocorrelation_counts(catalog, random)
    routine_args = self._get_auto_args(n_data, n_rand, counts_dd, counts_rr, counts_dr)
    routine = self._get_auto_subroutine()
    # The routine returns a pair; only its first element is the estimate.
    estimate, _ = routine(*routine_args)
    return estimate

@abstractmethod
def compute_autocorrelation_counts(
    self, catalog: Catalog, random: Catalog
) -> list[np.ndarray | int]:
    """Computes the auto-correlation counts (DD, RR, DR).

    Args:
        catalog (Catalog): The catalog of galaxy samples.
        random (Catalog): The catalog of random samples.

    Returns:
        A three-element list `[DD, RR, DR]`. The counts are represented as
        numpy arrays, but DR may be 0 if it isn't used (e.g. with the
        natural estimator).

    Raises:
        NotImplementedError: Always, in this abstract base implementation.
    """
    raise NotImplementedError()

Check warning on line 45 in src/corrgi/estimators/estimator.py

View check run for this annotation

Codecov / codecov/patch

src/corrgi/estimators/estimator.py#L45

Added line #L45 was not covered by tests

def _get_auto_subroutine(self) -> Callable:
    """Returns the Fortran routine to calculate the correlation estimate.

    `tpcf_wrp` is selected for projected correlations; every other
    correlation type uses `tpcf`.
    """
    if isinstance(self.correlation, ProjectedCorrelation):
        return tpcf_wrp
    return tpcf

def _get_auto_args(
    self,
    num_galaxies: int,
    num_random: int,
    counts_dd: np.ndarray,
    counts_rr: np.ndarray,
    counts_dr: np.ndarray | int,
) -> list:
    """Returns the args for the auto-correlation estimator routine.

    Args:
        num_galaxies (int): Size of the galaxy catalog.
        num_random (int): Size of the random catalog.
        counts_dd (np.ndarray): The DD counts.
        counts_rr (np.ndarray): The RR counts.
        counts_dr (np.ndarray | int): The DR counts; may be the integer 0
            when the estimator does not use them.

    Returns:
        The positional arguments, in order: galaxy size, random size, DD,
        bootstrap DD, RR and DR counts, then `dsepv` (projected
        correlations only) and finally the estimator name.
    """
    counts_bdd = self.correlation.get_bdd_counts()
    args = [num_galaxies, num_random, counts_dd, counts_bdd, counts_rr, counts_dr]
    if isinstance(self.correlation, ProjectedCorrelation):
        # The projected routines require an additional parameter
        args.append(self.correlation.params.dsepv)
    args.append(self.correlation.params.estimator)
    return args
22 changes: 22 additions & 0 deletions src/corrgi/estimators/estimator_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from corrgi.correlation.correlation import Correlation
from corrgi.estimators.estimator import Estimator
from corrgi.estimators.natural_estimator import NaturalEstimator

estimator_class_for_type: dict[str, type[Estimator]] = {"NAT": NaturalEstimator}


def get_estimator_for_correlation(correlation: Correlation) -> Estimator:
"""Constructs an Estimator instance for the specified correlation.

Args:
correlation (Correlation): The correlation instance. The type of
"estimator" to use is specified in its parameters.

Returns:
An initialized Estimator object wrapping the correlation to compute.
"""
type_to_use = correlation.params.estimator
if type_to_use not in estimator_class_for_type:
raise ValueError(f"Cannot load estimator type: {str(type_to_use)}")

Check warning on line 20 in src/corrgi/estimators/estimator_factory.py

View check run for this annotation

Codecov / codecov/patch

src/corrgi/estimators/estimator_factory.py#L20

Added line #L20 was not covered by tests
estimator_class = estimator_class_for_type[type_to_use]
return estimator_class(correlation)
31 changes: 31 additions & 0 deletions src/corrgi/estimators/natural_estimator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from __future__ import annotations

import dask
import numpy as np
from lsdb import Catalog

from corrgi.dask import perform_auto_counts
from corrgi.estimators.estimator import Estimator


class NaturalEstimator(Estimator):
    """Natural Estimator (`DD/RR - 1`)"""

    def compute_autocorrelation_counts(
        self, catalog: Catalog, random: Catalog
    ) -> list[np.ndarray | int]:
        """Computes the auto-correlation counts for the provided catalog.

        Args:
            catalog (Catalog): A galaxy samples catalog.
            random (Catalog): A random samples catalog.

        Returns:
            A three-element list `[DD, RR, DR]` for the natural estimator.
            DD and RR are the computed counts after the correlation's
            `transform_counts` step; DR is always the integer 0.
        """
        counts_dd = perform_auto_counts(catalog, self.correlation)
        counts_rr = perform_auto_counts(random, self.correlation)
        counts_dr = 0  # The natural estimator does not use DR counts
        # Evaluate DD and RR together in a single dask.compute call
        counts_dd_rr = dask.compute(*[counts_dd, counts_rr])
        counts_dd_rr = self.correlation.transform_counts(counts_dd_rr)
        return [*counts_dd_rr, counts_dr]
5 changes: 5 additions & 0 deletions tests/corrgi/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,11 @@ def acf_nat_estimate(acf_expected_results):
return np.load(acf_expected_results / "w_acf_nat.npy")


@pytest.fixture
def pcf_nat_estimate(pcf_expected_results):
    """Loads the `w_pcf_nat.npy` array from the `pcf_expected_results` directory."""
    estimate_path = pcf_expected_results / "w_pcf_nat.npy"
    return np.load(estimate_path)


@pytest.fixture
def single_data_partition(data_catalog_dir):
    """Reads the `Norder=0/Dir=0/Npix=1.parquet` file of the data catalog as a DataFrame."""
    partition_path = data_catalog_dir / "Norder=0" / "Dir=0" / "Npix=1.parquet"
    return pd.read_parquet(partition_path)
Expand Down
44 changes: 14 additions & 30 deletions tests/corrgi/test_acf.py
Original file line number Diff line number Diff line change
@@ -1,68 +1,52 @@
import numpy as np
import numpy.testing as npt
import pytest
from gundam import gundam
import hipscat

from corrgi.correlation.angular_correlation import AngularCorrelation
from corrgi.corrgi import compute_autocorrelation
from corrgi.dask import compute_autocorrelation_counts
from corrgi.estimators import calculate_natural_estimate
from corrgi.estimators.natural_estimator import NaturalEstimator


def test_acf_bins_are_correct(acf_bins_left_edges, acf_bins_right_edges, acf_params):
bins, _ = gundam.makebins(
acf_params.nsept,
acf_params.septmin,
acf_params.dsept,
acf_params.logsept,
)
bins = AngularCorrelation(params=acf_params).make_bins()
all_bins = np.append(acf_bins_left_edges, acf_bins_right_edges[-1])
assert np.array_equal(bins, all_bins)


def test_acf_counts_are_correct(
def test_acf_natural_counts_are_correct(
dask_client, data_catalog, rand_catalog, acf_dd_counts, acf_rr_counts, acf_params
):
ang_corr = AngularCorrelation(params=acf_params)
counts_dd, counts_rr = compute_autocorrelation_counts(
data_catalog, rand_catalog, ang_corr
estimator = NaturalEstimator(AngularCorrelation(params=acf_params))
counts_dd, counts_rr, _ = estimator.compute_autocorrelation_counts(
data_catalog, rand_catalog
)
npt.assert_allclose(counts_dd, acf_dd_counts, rtol=1e-3)
npt.assert_allclose(counts_rr, acf_rr_counts, rtol=2e-3)


def test_acf_natural_estimate_is_correct(
data_catalog_dir, rand_catalog_dir, acf_dd_counts, acf_rr_counts, acf_nat_estimate
dask_client, data_catalog, rand_catalog, acf_nat_estimate, acf_params
):
galaxy_hc_catalog = hipscat.read_from_hipscat(data_catalog_dir)
random_hc_catalog = hipscat.read_from_hipscat(rand_catalog_dir)
num_galaxies = galaxy_hc_catalog.catalog_info.total_rows
num_random = random_hc_catalog.catalog_info.total_rows
estimate = calculate_natural_estimate(
acf_dd_counts, acf_rr_counts, num_galaxies, num_random
)
npt.assert_allclose(acf_nat_estimate, estimate, rtol=2e-3)


def test_acf_e2e(dask_client, data_catalog, rand_catalog, acf_nat_estimate, acf_params):
acf_params.estimator = "NAT"
estimate = compute_autocorrelation(
data_catalog, rand_catalog, AngularCorrelation, params=acf_params
)
npt.assert_allclose(estimate, acf_nat_estimate, rtol=1e-7)


def test_acf_counts_with_weights_are_correct(
def test_acf_natural_counts_with_weights_are_correct(
dask_client,
acf_gals_weight_catalog,
acf_rans_weight_catalog,
acf_dd_counts_with_weights,
acf_rr_counts_with_weights,
acf_params,
):
ang_corr = AngularCorrelation(params=acf_params, use_weights=True)
counts_dd, counts_rr = compute_autocorrelation_counts(
acf_gals_weight_catalog, acf_rans_weight_catalog, ang_corr
estimator = NaturalEstimator(
AngularCorrelation(params=acf_params, use_weights=True)
)
counts_dd, counts_rr, _ = estimator.compute_autocorrelation_counts(
acf_gals_weight_catalog, acf_rans_weight_catalog
)
npt.assert_allclose(counts_dd, acf_dd_counts_with_weights, rtol=1e-3)
npt.assert_allclose(counts_rr, acf_rr_counts_with_weights, rtol=2e-3)
Expand Down
Loading