Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resolve bug in bar_model.delta_f_ creation #397

Merged
merged 3 commits into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions src/alchemlyb/estimators/bar_.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,25 +100,24 @@ def fit(self, u_nk):
(len(groups.get_group(i)) if i in groups.groups else 0)
for i in u_nk.columns
]

# Pull lambda states from indices
states = list(set( x[1:] for x in u_nk.index))
states = list(set(x[1:] if len(x[1:]) > 1 else x[1] for x in u_nk.index))
for state in states:
if len(state) == 1:
state = state[0]
if state not in self._states_:
raise ValueError(
f"Indexed lambda state, {state}, is not represented in u_nk columns:"
f" {self._states_}"
)

states.sort(key=lambda x: self._states_.index(x))

# Now get free energy differences and their uncertainties for each step
deltas = np.array([])
d_deltas = np.array([])
for k in range(len(N_k) - 1):
if N_k[k] == 0 or N_k[k + 1] == 0:
continue

# get us from lambda step k
uk = groups.get_group(self._states_[k])
# get w_F
Expand Down Expand Up @@ -149,7 +148,7 @@ def fit(self, u_nk):
"To compute the free energy with BAR, ensure that values in u_nk exist"
f" for the columns:\n{states}."
)

# build matrix of deltas between each state
adelta = np.zeros((len(deltas) + 1, len(deltas) + 1))
ad_delta = np.zeros_like(adelta)
Expand Down
44 changes: 44 additions & 0 deletions src/alchemlyb/tests/test_fep_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,50 @@ def test_states_(self, estimator):
_estimator.states_ = 1


def test_delta_f_columns(
gmx_benzene_Coulomb_u_nk,
gmx_expanded_ensemble_case_1,
):
"""Ensure columns are tuples when appropriate."""

bar_1lambda = BAR().fit(alchemlyb.concat(gmx_benzene_Coulomb_u_nk))
assert set(bar_1lambda.delta_f_.columns) == set([0.0, 0.25, 0.5, 0.75, 1.0])

bar_4lambda = BAR().fit(alchemlyb.concat(gmx_expanded_ensemble_case_1))
assert set(bar_4lambda.delta_f_.columns) == set(
[
(0.0, 0.1, 0.0, 0.0),
(0.0, 0.4, 0.0, 0.0),
(0.0, 1.0, 0.4, 0.002),
(0.0, 1.0, 0.0, 0.0001),
(0.0, 1.0, 0.1, 0.0002),
(0.0, 0.84, 0.0, 0.0),
(0.0, 0.68, 0.0, 0.0),
(0.0, 1.0, 0.84, 0.2),
(0.0, 1.0, 0.3, 0.001),
(0.0, 1.0, 0.2, 0.0004),
(0.0, 0.16, 0.0, 0.0),
(0.0, 1.0, 0.52, 0.01),
(0.0, 1.0, 0.92, 0.4),
(0.0, 0.76, 0.0, 0.0),
(0.0, 0.46, 0.0, 0.0),
(0.0, 1.0, 0.6, 0.02),
(0.0, 0.92, 0.0, 0.0),
(0.0, 0.6, 0.0, 0.0),
(0.0, 0.34, 0.0, 0.0),
(0.0, 1.0, 0.76, 0.1),
(0.0, 1.0, 1.0, 1.0),
(0.0, 0.05, 0.0, 0.0),
(0.0, 1.0, 0.48, 0.004),
(0.0, 0.0, 0.0, 0.0),
(0.0, 0.22, 0.0, 0.0),
(0.0, 0.52, 0.0, 0.0),
(0.0, 1.0, 0.68, 0.04),
(0.0, 0.28, 0.0, 0.0),
]
)


def test_bootstrap(gmx_benzene_Coulomb_u_nk):
u_nk = alchemlyb.concat(gmx_benzene_Coulomb_u_nk)
mbar = MBAR(n_bootstraps=2)
Expand Down
Loading