Skip to content

Commit d905b68

Browse files
authoredDec 6, 2023
Merge pull request #58 from rmnldwg/release-1.0.0.a3
Release 1.0.0.a3
2 parents e464ea8 + 1358d52 commit d905b68

8 files changed

+157
-10
lines changed
 

‎CHANGELOG.md

+29-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,30 @@
33
All notable changes to this project will be documented in this file.
44

55

6+
<a name="1.0.0.a3"></a>
7+
## [1.0.0.a3] - 2023-12-06
8+
9+
Fourth alpha release. [@YoelPH](https://github.com/YoelPH) noticed some more bugs that have been fixed now. Most notably, the risk prediction raised exceptions, because of a missing transponed matrix `.T`.
10+
11+
### Bug Fixes
12+
13+
- Raise `ValueError` if diagnose time parameters are invalid (Fixes [#53])
14+
- Use names of LNLs in unilateral `comp_encoding()` (Fixes [#56])
15+
- Wrong shape in unilateral posterior computation (missing `.T`) (Fixes [#57])
16+
- Wrong shape in bilateral joint posterior computation (missing `.T`) (Fixes [#57])
17+
18+
### Documentation
19+
20+
- Add info on diagnose time distribution's `ValueError`
21+
22+
### Testing
23+
24+
- `ValueError` raised in diagnose time distribution's `set_params`
25+
- Check `comp_encoding_diagnoses()` for shape and dtype
26+
- Test unilateral posterior state distribution for shape and sum
27+
- Test bilateral posterior joint state distribution for shape and sum
28+
29+
630
<a name="1.0.0.a2"></a>
731
## [1.0.0.a2] - 2023-09-15
832

@@ -160,7 +184,8 @@ Almost the entire API has changed. I'd therefore recommend to have a look at the
160184
- add pre-commit hook to check commit msg
161185

162186

163-
[Unreleased]: https://github.com/rmnldwg/lymph/compare/1.0.0.a2...HEAD
187+
[Unreleased]: https://github.com/rmnldwg/lymph/compare/1.0.0.a3...HEAD
188+
[1.0.0.a3]: https://github.com/rmnldwg/lymph/compare/1.0.0.a2...1.0.0.a3
164189
[1.0.0.a2]: https://github.com/rmnldwg/lymph/compare/1.0.0.a1...1.0.0.a2
165190
[1.0.0.a1]: https://github.com/rmnldwg/lymph/compare/1.0.0.a0...1.0.0.a1
166191
[1.0.0.a0]: https://github.com/rmnldwg/lymph/compare/0.4.3...1.0.0.a0
@@ -169,6 +194,9 @@ Almost the entire API has changed. I'd therefore recommend to have a look at the
169194
[0.4.1]: https://github.com/rmnldwg/lymph/compare/0.4.0...0.4.1
170195
[0.4.0]: https://github.com/rmnldwg/lymph/compare/0.3.10...0.4.0
171196

197+
[#57]: https://github.com/rmnldwg/lymph/issues/57
198+
[#56]: https://github.com/rmnldwg/lymph/issues/56
199+
[#53]: https://github.com/rmnldwg/lymph/issues/53
172200
[#46]: https://github.com/rmnldwg/lymph/issues/46
173201
[#45]: https://github.com/rmnldwg/lymph/issues/45
174202
[#41]: https://github.com/rmnldwg/lymph/issues/41

‎docs/source/quickstart_unilateral.ipynb

+3
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,8 @@
322322
"\n",
323323
"Here, it's important that the first argument is the support of the probability mass function, i.e., the discrete time-steps from 0 to `max_time`. Also, all parameters must have default values. Otherwise, there would be cases when such a stored distribution cannot be accessed.\n",
324324
"\n",
325+
"Lastly, if some parameters have bounds, like e.g. the binomial distribution, they should raise a `ValueError`. This exception is propagated upwards but causes the `likelihood` method to simply return `-np.inf`. That way it will be seamlessly rejected during an MCMC sampling round.\n",
326+
"\n",
325327
"Let's look at a concrete, binomial example:"
326328
]
327329
},
@@ -336,6 +338,7 @@
336338
"def binom_pmf(k: np.ndarray, n: int, p: float):\n",
337339
" \"\"\"Binomial PMF\"\"\"\n",
338340
" if p > 1. or p < 0.:\n",
341+
" # This value error is important to enable seamless sampling!\n",
339342
" raise ValueError(\"Binomial prob must be btw. 0 and 1\")\n",
340343
" q = (1. - p)\n",
341344
" binom_coeff = factorial(n) / (factorial(k) * factorial(n - k))\n",

‎lymph/diagnose_times.py

+24-5
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,9 @@ def __init__(
4444
function must return a list of probabilities for each diagnose time.
4545
4646
Note:
47-
All arguments except ``support`` must have default values.
47+
All arguments except ``support`` must have default values and if some
48+
parameters have bounds (like the binomial distribution's ``p``), the
49+
function must raise a ``ValueError`` if the parameter is invalid.
4850
4951
Since ``max_time`` specifies the support of the distribution (rangin from 0 to
5052
``max_time``), it must be provided if a parametrized function is passed. If a
@@ -180,12 +182,29 @@ def get_params(
180182

181183

182184
def set_params(self, **kwargs) -> None:
183-
"""Update distribution by setting its parameters and storing the frozen PMF."""
185+
"""Update distribution by setting its parameters and storing the frozen PMF.
186+
187+
To work during inference using e.g. MCMC sampling, it needs to throw a
188+
``ValueError`` if the parameters are invalid. To this end, it expects the
189+
underlying function to raise a ``ValueError`` if one of the parameters is
190+
invalid. If the parameters are valid, the frozen PMF is stored and can be
191+
retrieved via the :py:meth:`distribution` property.
192+
"""
184193
params_to_set = set(kwargs.keys()).intersection(self._kwargs.keys())
185194
if self.is_updateable:
186-
if hasattr(self, "_frozen"):
187-
del self._frozen
188-
self._kwargs.update({p: kwargs[p] for p in params_to_set})
195+
new_kwargs = self._kwargs.copy()
196+
new_kwargs.update({p: kwargs[p] for p in params_to_set})
197+
198+
try:
199+
self._frozen = self.normalize(
200+
self._func(self.support, **new_kwargs)
201+
)
202+
except ValueError as val_err:
203+
raise ValueError(
204+
"Invalid parameter(s) provided to distribution over diagnose times"
205+
) from val_err
206+
207+
self._kwargs = new_kwargs
189208
else:
190209
warnings.warn("Distribution is not updateable, skipping...")
191210

‎lymph/models/bilateral.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -519,7 +519,7 @@ def comp_posterior_joint_state_dist(
519519
)
520520
observation_matrix = getattr(self, side).observation_matrix
521521
# vector with P(Z=z|X) for each state X. A data matrix for one "patient"
522-
diagnose_given_state[side] = diagnose_encoding @ observation_matrix
522+
diagnose_given_state[side] = diagnose_encoding @ observation_matrix.T
523523

524524
joint_state_dist = self.comp_joint_state_dist(t_stage=t_stage, mode=mode)
525525
# matrix with P(Zi=zi,Zc=zc|Xi,Xc) * P(Xi,Xc) for all states Xi,Xc.

‎lymph/models/unilateral.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -819,7 +819,7 @@ def comp_diagnose_encoding(
819819
diagnose_encoding = np.kron(
820820
diagnose_encoding,
821821
matrix.compute_encoding(
822-
lnls=[lnl.name for lnl in self.graph.lnls],
822+
lnls=self.graph.lnls.keys(),
823823
pattern=given_diagnoses.get(modality, {}),
824824
),
825825
)
@@ -873,7 +873,7 @@ def comp_posterior_state_dist(
873873

874874
diagnose_encoding = self.comp_diagnose_encoding(given_diagnoses)
875875
# vector containing P(Z=z|X). Essentially a data matrix for one patient
876-
diagnose_given_state = diagnose_encoding @ self.observation_matrix
876+
diagnose_given_state = diagnose_encoding @ self.observation_matrix.T
877877

878878
# vector P(X=x) of probabilities of arriving in state x (marginalized over time)
879879
state_dist = self.comp_state_dist(t_stage, mode=mode)

‎tests/binary_bilateral_test.py

+33
Original file line numberDiff line numberDiff line change
@@ -122,3 +122,36 @@ def test_compute_likelihood_twice(self):
122122
first_llh = self.model.likelihood(log=True)
123123
second_llh = self.model.likelihood(log=True)
124124
self.assertEqual(first_llh, second_llh)
125+
126+
127+
class RiskTestCase(fixtures.BilateralModelMixin, unittest.TestCase):
128+
"""Check that the risk is computed correctly."""
129+
130+
def setUp(self):
131+
super().setUp()
132+
self.model.modalities = fixtures.MODALITIES
133+
134+
def create_random_diagnoses(self):
135+
"""Create a random diagnosis for each modality and LNL."""
136+
diagnoses = {}
137+
138+
for modality in self.model.modalities:
139+
diagnoses[modality] = {}
140+
for lnl in self.model.ipsi.graph.lnls.keys():
141+
diagnoses[modality][lnl] = self.rng.choice([True, False, None])
142+
143+
return diagnoses
144+
145+
def test_posterior_state_dist(self):
146+
"""Test that the posterior state distribution is computed correctly."""
147+
num_states = len(self.model.ipsi.state_list)
148+
random_parameters = self.create_random_params()
149+
random_diagnoses = self.create_random_diagnoses()
150+
151+
posterior = self.model.comp_posterior_joint_state_dist(
152+
given_param_kwargs=random_parameters,
153+
given_diagnoses=random_diagnoses,
154+
)
155+
self.assertEqual(posterior.shape, (num_states, num_states))
156+
self.assertEqual(posterior.dtype, float)
157+
self.assertTrue(np.isclose(posterior.sum(), 1.))

‎tests/binary_unilateral_test.py

+41
Original file line numberDiff line numberDiff line change
@@ -317,3 +317,44 @@ def test_likelihood_invalid_params_isinf(self):
317317
mode="HMM",
318318
)
319319
self.assertEqual(likelihood, -np.inf)
320+
321+
322+
class RiskTestCase(fixtures.BinaryUnilateralModelMixin, unittest.TestCase):
323+
"""Test anything related to the risk computation."""
324+
325+
def setUp(self):
326+
"""Load params."""
327+
super().setUp()
328+
self.model.modalities = fixtures.MODALITIES
329+
self.init_diag_time_dists(early="frozen", late="parametric")
330+
self.model.assign_params(**self.create_random_params())
331+
332+
def create_random_diagnoses(self):
333+
"""Create a random diagnosis for each modality and LNL."""
334+
self.diagnoses = {}
335+
336+
for modality in self.model.modalities:
337+
self.diagnoses[modality] = {}
338+
for lnl in self.model.graph.lnls.keys():
339+
self.diagnoses[modality][lnl] = self.rng.choice([True, False, None])
340+
341+
def test_comp_diagnose_encoding(self):
342+
"""Check computation of one-hot encoding of diagnoses."""
343+
self.create_random_diagnoses()
344+
num_lnls, num_mods = len(self.model.graph.lnls), len(self.model.modalities)
345+
num_posible_diagnoses = 2**(num_lnls * num_mods)
346+
347+
diagnose_encoding = self.model.comp_diagnose_encoding(self.diagnoses)
348+
self.assertEqual(diagnose_encoding.shape, (num_posible_diagnoses,))
349+
self.assertEqual(diagnose_encoding.dtype, bool)
350+
351+
def test_posterior_state_dist(self):
352+
"""Make sure the posterior state dist is correctly computed."""
353+
posterior_state_dist = self.model.comp_posterior_state_dist(
354+
given_param_kwargs=self.create_random_params(),
355+
given_diagnoses=self.create_random_diagnoses(),
356+
t_stage=self.rng.choice(["early", "late"]),
357+
)
358+
self.assertEqual(posterior_state_dist.shape, (2**len(self.model.graph.lnls),))
359+
self.assertEqual(posterior_state_dist.dtype, float)
360+
self.assertTrue(np.isclose(np.sum(posterior_state_dist), 1.))

‎tests/distribution_test.py

+24-1
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,27 @@
1111
class FixtureMixin:
1212
"""Mixin that provides fixtures for the tests."""
1313

14+
@staticmethod
15+
def binom_pmf(
16+
support: np.ndarray,
17+
max_time: int = 10,
18+
p: float = 0.5,
19+
) -> np.ndarray:
20+
"""Binomial probability mass function."""
21+
if max_time <= 0:
22+
raise ValueError("max_time must be a positive integer.")
23+
if len(support) != max_time + 1:
24+
raise ValueError("support must have length max_time + 1.")
25+
if not 0. <= p <= 1.:
26+
raise ValueError("p must be between 0 and 1.")
27+
28+
return sp.stats.binom.pmf(support, max_time, p)
29+
30+
1431
def setUp(self):
1532
self.max_time = 10
1633
self.array_arg = np.random.uniform(size=self.max_time + 1, low=0., high=10.)
17-
self.func_arg = lambda support, p=0.5: sp.stats.binom.pmf(support, self.max_time, p)
34+
self.func_arg = lambda support, p=0.5: self.binom_pmf(support, self.max_time, p)
1835

1936

2037
class DistributionTestCase(FixtureMixin, unittest.TestCase):
@@ -58,6 +75,12 @@ def test_updateable_distribution_with_max_time(self):
5875
self.assertTrue(len(dist.distribution) == self.max_time + 1)
5976
self.assertTrue(np.allclose(sum(dist.distribution), 1.))
6077

78+
def test_updateable_distribution_raises_value_error(self):
79+
"""Check that an invalid parameter raises a ValueError."""
80+
dist = Distribution(self.func_arg, max_time=self.max_time)
81+
self.assertTrue(dist.is_updateable)
82+
self.assertRaises(ValueError, dist.set_params, p=1.5)
83+
6184

6285
class DistributionDictTestCase(FixtureMixin, unittest.TestCase):
6386
"""Test the distribution dictionary."""

0 commit comments

Comments
 (0)
Please sign in to comment.