sampler_algorithms.py
import jax.numpy as jnp
from jax import vmap, jit, grad, random, lax
import scipy as sc
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
class ModelParallelizer:
def __init__(self, model: callable = None, has_input: bool = True, chains: int = None, n_obs: int = None,
activate_jit: bool = False):
"""
Parallelling the model function for the fast evaluation of the model as well as the derivatives of the model
with respect to the model parameters. The model can be either in the format of y=f(theta,x) or y=f(theta).
Constraint: The model should be a multi-input-single-output model.
:param: chains: an integer indicating the number of chains used for parallel evaluation of the model
:param model: Given an input of the data, the output of the model is returned. The model inputs are parameters
(ndim x 1) and model input variables (N x s). For parallel evaluation, the model input would be (ndim x C).
:param: n_obs: an integer indicating the number of observations (measurements) of the model
:param activate_jit: A boolean variable used to activate(deactivate) just-in-time evaluation of the model
"""
if isinstance(chains, int):
self.chains = chains
elif not chains:
self.chains = None
else:
raise Exception('The number of chains (optional) is not specified correctly!')
if isinstance(n_obs, int):
self.n_obs = n_obs
elif not n_obs:
self.n_obs = None
else:
raise Exception('The number of observations (optional) is not specified correctly!')
if isinstance(activate_jit, bool):
self.activate_jit = activate_jit
else:
self.activate_jit = False
            print(
                f'---------------------------------------------------------------------------------------------------\n'
                f'activate_jit is not a boolean; the default value of {self.activate_jit} is used for the simulations\n'
                f'----------------------------------------------------------------------------------------------------')
if isinstance(has_input, bool):
self.has_input = has_input
if self.has_input:
print(f'---------------------------------------------------------\n'
' you have specified that your model is like y = f(theta,x)\n'
'----------------------------------------------------------')
else:
print(f'---------------------------------------------------------\n'
' you have specified that your model is like y = f(theta)\n'
' ----------------------------------------------------------')
else:
raise Exception('Please specify whether the model has any input other than model parameters')
        # parallelize the model evaluation as well as the calculation of its derivatives
if hasattr(model, "__call__"):
self.model_eval = model
if self.has_input:
model_val = vmap(vmap(self.model_eval,
in_axes=[None, 0], # this means that we loop over input observations(1 -> N)
axis_size=self.n_obs, # specifying the number of measurements
out_axes=0), # means that we stack the observations in rows
in_axes=[1, None], # means that we loop over chains (1 -> C)
axis_size=self.chains, # specifying the number of chains
out_axes=1) # means that we stack chains in columns
model_der = vmap(vmap(grad(self.model_eval,
argnums=0), # parameter 0 means model parameters (d/d theta)
in_axes=[1, None], # [1, None] means that we loop over chains (1 -> C)
out_axes=1), # means that chains are stacked in the second dimension
in_axes=[None, 0], # [None, 0] looping over model inputs (1 -> N)
axis_size=self.n_obs, # the size of observations
out_axes=2) # staking
else:
def reshaped_model(inputs):
return self.model_eval(inputs)[jnp.newaxis]
model_val = vmap(reshaped_model, in_axes=1, axis_size=self.chains, out_axes=1)
model_der = vmap(grad(self.model_eval, argnums=0), in_axes=1, axis_size=self.chains, out_axes=1)
else:
raise Exception('The function of the model is not defined properly!')
if self.activate_jit:
self.model_evaluate_ = jit(model_val)
self.diff_model_evaluate_ = jit(model_der)
else:
self.model_evaluate_ = model_val
self.diff_model_evaluate_ = model_der
def model_evaluation(self, parameter: jnp.ndarray = None, x: jnp.ndarray = None):
"""
Vectorized evaluation of the model
:param parameter: The matrix (vector) of model parameters
:param x: The matrix of model input
:return: The vectorized evaluation of the model
"""
return self.model_evaluate_(parameter, x)
def model_derivatives(self, parameter: jnp.ndarray = None, x: jnp.ndarray = None):
"""
Vectorized evaluation of the model derivatives
:param parameter: The matrix (vector) of model parameters
:param x: The matrix of model input
        :return: The vectorized evaluation of the first derivative of the model with respect to each parameter
"""
return self.diff_model_evaluate_(parameter, x)
def model_full_evaluation(self, parameter: jnp.ndarray = None, x: jnp.ndarray = None):
"""
Vectorized evaluation of the model and the first derivatives of the model
:param parameter: The matrix (vector) of model parameters
:param x: The matrix of model input
        :return: The vectorized evaluation of the model together with its first derivatives with respect to each parameter
"""
return self.model_evaluate_(parameter, x), self.diff_model_evaluate_(parameter, x)
@property
def info(self):
if self.has_input:
            print('----------------------------------------------------------\n'
                  'The input of the model should be in the format:\n'
                  'theta (ndim x C): ndim indicates the dimension of the\n'
                  'problem and C the number of chains (parallel evaluation).\n'
                  'X (N x s): a matrix of the inputs (other than the model\n'
                  'parameters), where N indicates the number of observations\n'
                  'and s the number of input variables.\n'
                  'Output:\n'
                  'y (N x C): parallelized evaluation of the model output\n'
                  'dy/dtheta (ndim x C x N): the derivatives of the model\n'
                  'output with respect to each model parameter\n'
                  '----------------------------------------------------------')
        else:
            print('----------------------------------------------------------\n'
                  'The input of the model should be in the format:\n'
                  'theta (ndim x C): ndim indicates the dimension of the\n'
                  'problem and C the number of chains (parallel evaluation).\n'
                  'Output:\n'
                  'y (1 x C): parallelized evaluation of the model output\n'
                  'dy/dtheta (ndim x C): the derivatives of the model\n'
                  'output with respect to each model parameter\n'
                  '----------------------------------------------------------')
return
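

# A minimal usage sketch of ModelParallelizer (illustration only; the toy linear
# model, shapes, and values below are hypothetical, not part of the library).
def _demo_model_parallelizer():
    n_obs, n_chains, ndim = 20, 5, 2

    def model(theta, x):
        # theta: (ndim,) parameter vector, x: (s,) single observation -> scalar output
        return theta[0] + theta[1] * x[0]

    key = random.PRNGKey(0)
    theta = random.normal(key, shape=(ndim, n_chains))    # (ndim x C)
    x = jnp.linspace(0.0, 1.0, n_obs).reshape(-1, 1)      # (N x s)
    parallel = ModelParallelizer(model=model, has_input=True, chains=n_chains, n_obs=n_obs)
    y = parallel.model_evaluation(theta, x)               # expected shape (N, C)
    dy = parallel.model_derivatives(theta, x)             # expected shape (ndim, C, N)
    print(y.shape, dy.shape)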
class ParameterProposalInitialization:
def __init__(self, log_prop_fcn: callable = None,
iterations: int = None,
burnin: int = None,
x_init: jnp.ndarray = None,
activate_jit: bool = False,
chains: int = 1,
progress_bar: bool = True,
random_seed: int = 1,
move: str = 'single_stretch',
cov: jnp.ndarray = None,
n_split: int = 2,
a: float = None):
if isinstance(move, str): # checking the correctness of parameter proposal algorithm
if move in ['single_stretch', 'random_walk', 'parallel_stretch']:
self.move = move
elif not move:
self.move = 'random_walk'
else:
raise Exception('The algorithm of updating proposal parameters is not specified correctly')
if hasattr(log_prop_fcn, "__call__"): # checking the correctness of log probability function
self.log_prop_fcn = log_prop_fcn
else:
raise Exception('The log probability function is not defined properly!')
self.key = random.PRNGKey(random_seed)
if isinstance(iterations, int): # checking the correctness of the iteration
self.iterations = iterations
        else:
            self.iterations = 1000
            print(f'-------------------------------------------------------------------------------------------------\n'
                  f'The number of iterations is not an integer value.\n'
                  f'The default value of {self.iterations} is selected as the number of iterations\n'
                  f'--------------------------------------------------------------------------------------------------')
if isinstance(burnin, int): # checking the correctness of the burnin period
self.burnin = burnin
        elif burnin is None:
            self.burnin = 0
        else:
            self.burnin = 0
            print(f'-------------------------------------------------------------------------------------------------\n'
                  f'The number of samples to drop after the simulation (burn-in) is not an integer value.\n'
                  f'The default value of {self.burnin} is selected as the number of burn-in samples\n'
                  f'--------------------------------------------------------------------------------------------------')
if self.burnin >= self.iterations: # checking the correctness of iteration and burnin period
            raise Exception('The number of burn-in samples must be less than the number of iterations!')
if isinstance(chains, int): # checking the correctness of the number of chains
self.n_chains = chains
else:
self.n_chains = 1
print(
f'---------------------------------------------------------------------------------------------------\n'
f'The number of chains is not an integer value.\n'
f' The default value of {self.n_chains} is selected as the number of chains\n'
f'----------------------------------------------------------------------------------------------------')
        if isinstance(n_split, int):  # checking the correctness of the number of splits
            if self.move == 'parallel_stretch' and self.n_chains % n_split:
                raise Exception(f'The number of chains should be a multiple of the number of splits.\n'
                                f'As a suggestion, you may use {((self.n_chains // n_split) + 1) * n_split} as the number\n'
                                f'of chains.')
            if self.move == 'parallel_stretch' and not self.n_chains % n_split:
                self.n_split = n_split
                self.split_len = self.n_chains // self.n_split
        elif not n_split:
            self.n_split = 2
            self.split_len = self.n_chains // self.n_split
            print(
                f'---------------------------------------------------------------------------------------------------\n'
                f'The number of splits is not specified.\n'
                f'The default value of {self.n_split} is selected as the number of splits\n'
                f'----------------------------------------------------------------------------------------------------')
        else:
            raise Exception('The number of splits for ensemble sampling is not specified correctly')
if isinstance(x_init, jnp.ndarray): # checking the correctness of initial condition
dim1, dim2 = x_init.shape
if dim2 != self.n_chains:
raise Exception('The initial condition is not consistent with the number of chains!')
elif dim1 * 2 > self.n_chains:
                raise Exception('The number of chains should be at least twice the dimension of the parameters')
else:
self.ndim = dim1
self.x_init = x_init
else:
raise Exception('The initial condition is not selected properly!')
if isinstance(activate_jit, bool): # checking the correctness of the just-in-time simulation
self.activate_jit = activate_jit
else:
self.activate_jit = False
print(
f'---------------------------------------------------------------------------------------------------\n'
f'The default value of {self.activate_jit} is selected for parallelized simulations\n'
f'----------------------------------------------------------------------------------------------------')
        if isinstance(progress_bar, bool):  # checking the correctness of the progress bar
            self.progress_bar = not progress_bar  # stored inverted so it can be passed directly to tqdm's disable flag
        else:
            self.progress_bar = False
            print(
                f'---------------------------------------------------------------------------------------------------\n'
                f'The progress bar is activated by default since it was not specified by the user\n'
                f'----------------------------------------------------------------------------------------------------')
if isinstance(cov, jnp.ndarray): # checking the correctness of the covariance
if (cov.shape[0] != cov.shape[1]) or (cov.shape[0] != self.ndim):
raise Exception('The size of the covariance matrix is either incorrect or inconsistent with the'
' dimension of the parameters')
else:
if jnp.all(jnp.linalg.eigvals(cov) > 0):
self.cov_proposal = cov
else:
raise Exception('The covariance matrix of updating parameters should be positive definite')
elif not cov:
self.cov_proposal = None
else:
            raise Exception('The covariance matrix for calculating proposal parameters is not entered correctly')
if isinstance(a,
(float, int)): # checking the correctness of the scaling factor a (for single/parallel stretch)
if a > 1:
self.a_proposal = a
else:
raise Exception('The value of a should be greater than 1')
elif not a:
self.a_proposal = None
else:
raise Exception('The value of a is not specified correctly')
        if self.move == 'random_walk':  # using random walk proposal algorithm
            if self.cov_proposal is None:
                raise Exception('A covariance matrix (cov) is required for the random walk proposal')
            self.rndw_samples = jnp.transpose(random.multivariate_normal(key=self.key, mean=jnp.zeros((1, self.ndim)),
                                                                         cov=self.cov_proposal[jnp.newaxis, :, :],
                                                                         shape=(self.iterations, self.n_chains)),
                                              axes=(2, 1, 0))
            self.proposal_alg = self.random_walk_proposal
        elif self.move == 'single_stretch':  # using single stretch proposal algorithm
            if self.a_proposal is None:
                raise Exception('The scale parameter a is required for the stretch move')
            # drawing z from g(z) ~ 1/sqrt(z) on [1/a, a] via the inverse CDF
            self.z = jnp.power(
                (random.uniform(key=self.key, minval=0, maxval=1.0, shape=(self.iterations, self.n_chains)) *
                 (jnp.sqrt(self.a_proposal) - jnp.sqrt(1 / self.a_proposal)) + jnp.sqrt(1 / self.a_proposal)),
                2)
            self.index = jnp.zeros((self.iterations, self.n_chains), dtype=int)
            ordered_index = jnp.arange(self.n_chains, dtype=int)
            for i in range(self.n_chains):
                # for each walker, pick a complementary walker (never itself) at every iteration
                self.index = self.index.at[:, i].set(random.choice(key=self.key, a=jnp.delete(arr=ordered_index, obj=i),
                                                                   replace=True, shape=(self.iterations,)))
                self.key += 1  # crude key update to decorrelate successive draws
            self.proposal_alg = self.single_stretch_move
        elif self.move == 'parallel_stretch':  # using the parallel stretch move, dividing the chains into
            # n_split sub-ensembles of walkers
            self.iterations *= 2
            self.z = jnp.power(
                (random.uniform(key=self.key, minval=0, maxval=1.0, shape=(self.iterations, self.n_chains)) *
                 (jnp.sqrt(self.a_proposal) - jnp.sqrt(1 / self.a_proposal)) + jnp.sqrt(1 / self.a_proposal)),
                2)
            self.index = jnp.zeros((self.iterations, self.n_chains), dtype=int)
            ordered_index = jnp.arange(self.n_split).astype(int)
            single_split = jnp.arange(start=0, step=1, stop=self.split_len)
            for i in range(self.n_split):
                # for each sub-ensemble, pick a complementary sub-ensemble and permute its walkers
                selected_split = random.choice(key=self.key, a=jnp.delete(arr=ordered_index, obj=i), replace=True,
                                               shape=(self.iterations, 1))
                self.index = self.index.at[:, i * self.split_len:(i + 1) * self.split_len].set(random.permutation(
                    key=self.key,
                    x=selected_split * self.split_len + single_split,
                    axis=1,
                    independent=True))
                self.key += 1  # crude key update to decorrelate successive draws
            self.proposal_alg = self.single_stretch_move
else:
raise Exception('A method for updating parameters should be entered')
def random_walk_proposal(self, whole_chains: jnp.ndarray = None, itr: int = None):
return whole_chains[:, :, itr - 1] + self.rndw_samples[:, :, itr - 1]
    def single_stretch_move(self, whole_chains: jnp.ndarray = None, itr: int = None):
        # stretch move: x_proposed = x_k + z * (x - x_k), with x_k a complementary walker
        return whole_chains[:, self.index[itr - 1, :], itr - 1] + self.z[itr, :] * (
                whole_chains[:, :, itr - 1] - whole_chains[:, self.index[itr - 1, :], itr - 1])
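

# The stretch move above follows Goodman & Weare (2010): a complementary walker
# x_k is selected, z is drawn with density g(z) proportional to 1/sqrt(z) on
# [1/a, a], and the proposal is x_proposed = x_k + z * (x - x_k). A minimal
# standalone sketch of the z draw (illustration only, mirroring the inverse-CDF
# expression used in ParameterProposalInitialization.__init__):
def _demo_stretch_z_draw(key, a: float = 2.0, n: int = 5):
    u = random.uniform(key=key, minval=0, maxval=1.0, shape=(n,))
    # inverse CDF of g(z) ~ 1/sqrt(z) on [1/a, a]
    return jnp.power(u * (jnp.sqrt(a) - jnp.sqrt(1 / a)) + jnp.sqrt(1 / a), 2)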
class MetropolisHastings(ParameterProposalInitialization):
def __init__(self, log_prop_fcn: callable = None, iterations: int = None, burnin: int = None,
x_init: jnp.ndarray = None, activate_jit: bool = False, chains: int = 1, progress_bar: bool = True,
random_seed: int = 1, cov: jnp.ndarray = None):
"""
Metropolis Hastings sampling algorithm
:param log_prop_fcn: Takes the log posteriori function
:param iterations: The number of iteration
:param burnin: The number of initial samples to be droped sowing to non-stationary behaviour
:param x_init: The initialized value of parameters
:param parallelized: A boolean variable used to activate or deactivate the parallelized calculation
:param chains: the number of chains used for simulation
:param progress_bar: A boolean variable used to activate or deactivate the progress bar. Deactivation of the
progress bar results in activating XLA -accelerated iteration for the fast evaluation of the
chains(recommended!)
:param model: The model function (a function that input parameters and returns estimations)
"""
super(MetropolisHastings, self).__init__(log_prop_fcn=log_prop_fcn, iterations=iterations, burnin=burnin,
x_init=x_init, activate_jit=activate_jit, chains=chains, cov=cov,
progress_bar=progress_bar, random_seed=random_seed, move='random_walk')
# initializing chain values
self.chains = jnp.zeros((self.ndim, self.n_chains, self.iterations))
# initializing the log of the posteriori values
self.log_prop_values = jnp.zeros((self.iterations, self.n_chains))
# initializing the track of hasting ratio values
self.accept_rate = jnp.zeros((self.iterations, self.n_chains))
        # in order to calculate the acceptance ratio of all chains
self.n_of_accept = jnp.zeros((1, self.n_chains))
def sample(self):
"""
vectorized metropolis-hastings sampling algorithm used for sampling from the posteriori distribution
:returns: chains: The chains of samples drawn from the posteriori distribution
                  acceptance rate: The acceptance rate of the samples drawn from the posteriori distributions
"""
self.chains = self.chains.at[:, :, 0].set(self.x_init)
self.log_prop_values = self.log_prop_values.at[0:1, :].set(self.log_prop_fcn(self.x_init))
self.uniform_rand = random.uniform(key=self.key, minval=0, maxval=1.0, shape=(self.iterations, self.n_chains))
def alg_with_progress_bar(itr: int = None) -> None:
# The function suited for using progress bar
proposed = self.proposal_alg(whole_chains=self.chains, itr=itr)
ln_prop = self.log_prop_fcn(proposed)
hastings = jnp.minimum(jnp.exp(ln_prop - self.log_prop_values[itr - 1, :]), 1)
satis = (self.uniform_rand[itr, :] < hastings)[0, :]
non_satis = ~satis
self.chains = self.chains.at[:, satis, itr].set(proposed[:, satis])
self.chains = self.chains.at[:, non_satis, itr].set(self.chains[:, non_satis, itr - 1])
self.log_prop_values = self.log_prop_values.at[itr, satis].set(ln_prop[0, satis])
            self.log_prop_values = self.log_prop_values.at[itr, non_satis].set(self.log_prop_values[itr - 1, non_satis])
self.n_of_accept = self.n_of_accept.at[0, satis].set(self.n_of_accept[0, satis] + 1)
self.accept_rate = self.accept_rate.at[itr, :].set(self.n_of_accept[0, :] / itr)
return
        def alg_with_lax_accelerated(itr: int, recursive_variables: tuple) -> tuple:
# The function suited for fast and efficient evaluation of chains
lax_chains, lax_log_prop_values, lax_n_of_accept, lax_accept_rate = recursive_variables
proposed = self.proposal_alg(whole_chains=lax_chains, itr=itr)
ln_prop = self.log_prop_fcn(proposed)
hastings = jnp.minimum(jnp.exp(ln_prop - lax_log_prop_values[itr - 1, :]), 1)
lax_log_prop_values = lax_log_prop_values.at[itr, :].set(jnp.where(self.uniform_rand[itr, :] < hastings,
ln_prop,
lax_log_prop_values[itr - 1, :])[0, :])
lax_chains = lax_chains.at[:, :, itr].set(jnp.where(self.uniform_rand[itr, :] < hastings,
proposed,
lax_chains[:, :, itr - 1]))
lax_n_of_accept = lax_n_of_accept.at[0, :].set(jnp.where(self.uniform_rand[itr, :] < hastings,
lax_n_of_accept[0, :] + 1,
lax_n_of_accept[0, :])[0, :])
lax_accept_rate = lax_accept_rate.at[itr, :].set(lax_n_of_accept[0, :] / itr)
return lax_chains, lax_log_prop_values, lax_n_of_accept, lax_accept_rate
        if not self.progress_bar:  # note: self.progress_bar is stored inverted, so False here means the bar is shown
            for i in tqdm(range(1, self.iterations), disable=self.progress_bar):
                alg_with_progress_bar(itr=i)
else:
print('Simulating...')
self.chains, \
self.log_prop_values, \
self.n_of_accept, \
self.accept_rate = lax.fori_loop(lower=1,
upper=self.iterations,
                                             body_fun=alg_with_lax_accelerated,
init_val=(
self.chains.copy(),
self.log_prop_values.copy(),
self.n_of_accept.copy(),
self.accept_rate.copy()
))
return self.chains[:, :, self.burnin:], self.accept_rate
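

# A minimal usage sketch of MetropolisHastings (illustration only; the standard
# Gaussian target and all values below are hypothetical). log_prop_fcn must
# accept parameters of shape (ndim x C) and return a (1 x C) row of
# log-densities, matching the shapes used inside sample().
def _demo_metropolis_hastings():
    ndim, n_chains = 2, 10

    def log_prob(theta):
        return -0.5 * jnp.sum(theta ** 2, axis=0, keepdims=True)

    x0 = random.normal(random.PRNGKey(3), shape=(ndim, n_chains))
    mh = MetropolisHastings(log_prop_fcn=log_prob, iterations=2000, burnin=500,
                            x_init=x0, chains=n_chains, progress_bar=True,
                            cov=0.5 * jnp.eye(ndim))
    chains, accept_rate = mh.sample()
    print(chains.shape)  # expected (ndim, n_chains, iterations - burnin)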
class MCMCHammer(ParameterProposalInitialization):
def __init__(self, log_prop_fcn: callable = None, iterations: int = None, burnin: int = None,
x_init: jnp.ndarray = None, activate_jit: bool = False, chains: int = 1, progress_bar: bool = True,
random_seed: int = 1, move: str = 'single_stretch', a: float = 2, n_split: int = 2):
"""
MCMC Hammer empowered with jax to large scale simulation
:param log_prop_fcn: A callable function returning the log-likelihood (or posteriori) of the distribution
:param iterations: An integer indicating the number of steps(or samples)
:param burnin: An integer used for truncating chains of samples to remove the transient variation of chains
:param x_init: An matrix (NxC) encompassing the initial condition for each chain
:param activate_jit: A boolean variable for activating/deactivating just-in-time evaluation of functions
:param chains: An integer determining the number of chains
:param progress_bar: A boolean variable used for activating or deactivating the progress bar
:param random_seed: An integer for fixing rng
:param move: A string variable used to determine the algorithm for calculating the proposal parameters. Options
are "single_stretch", "parallel_stretch"
:param a: An adjustable scale parameter (1<a) used for calculating the proposal parameters
"""
super(MCMCHammer, self).__init__(log_prop_fcn=log_prop_fcn, iterations=iterations, burnin=burnin,
x_init=x_init, activate_jit=activate_jit, chains=chains,
progress_bar=progress_bar, random_seed=random_seed, move=move,
a=a, n_split=n_split)
# initializing chain values
self.chains = jnp.zeros((self.ndim, self.n_chains, self.iterations))
# initializing the log of the posteriori values
self.log_prop_values = jnp.zeros((self.iterations, self.n_chains))
# initializing the track of hasting ratio values
self.accept_rate = jnp.zeros((self.iterations, self.n_chains))
        # in order to calculate the acceptance ratio of all chains
self.n_of_accept = jnp.zeros((1, self.n_chains))
def sample(self):
"""
vectorized MCMC Hammer sampling algorithm used for sampling from the posteriori distribution. Developed based on
the paper published in 2013:
<<Foreman-Mackey, Daniel, et al. "emcee: the MCMC hammer." Publications of the Astronomical
Society of the Pacific 125.925 (2013): 306.>>
:returns: chains: The chains of samples drawn from the posteriori distribution
                  acceptance rate: The acceptance rate of the samples drawn from the posteriori distributions
"""
self.chains = self.chains.at[:, :, 0].set(self.x_init)
self.log_prop_values = self.log_prop_values.at[0:1, :].set(self.log_prop_fcn(self.x_init))
self.uniform_rand = random.uniform(key=self.key, minval=0, maxval=1.0, shape=(self.iterations, self.n_chains))
def alg_with_progress_bar(itr: int = None) -> None:
# The function suited for using progress bar
proposed = self.proposal_alg(whole_chains=self.chains, itr=itr)
ln_prop = self.log_prop_fcn(proposed)
            hastings = jnp.minimum(jnp.power(self.z[itr, :], self.ndim - 1) *
                                   jnp.exp(ln_prop - self.log_prop_values[itr - 1, :]), 1)
satis = (self.uniform_rand[itr, :] < hastings)[0, :]
non_satis = ~satis
self.chains = self.chains.at[:, satis, itr].set(proposed[:, satis])
self.chains = self.chains.at[:, non_satis, itr].set(self.chains[:, non_satis, itr - 1])
self.log_prop_values = self.log_prop_values.at[itr, satis].set(ln_prop[0, satis])
            self.log_prop_values = self.log_prop_values.at[itr, non_satis].set(self.log_prop_values[itr - 1, non_satis])
self.n_of_accept = self.n_of_accept.at[0, satis].set(self.n_of_accept[0, satis] + 1)
self.accept_rate = self.accept_rate.at[itr, :].set(self.n_of_accept[0, :] / itr)
return
        def alg_with_lax_accelerated(itr: int, recursive_variables: tuple) -> tuple:
# The function suited for fast and efficient evaluation of chains
lax_chains, lax_log_prop_values, lax_n_of_accept, lax_accept_rate = recursive_variables
proposed = self.proposal_alg(whole_chains=lax_chains, itr=itr)
ln_prop = self.log_prop_fcn(proposed)
            hastings = jnp.minimum(jnp.power(self.z[itr, :], self.ndim - 1) *
                                   jnp.exp(ln_prop - lax_log_prop_values[itr - 1, :]), 1)
lax_log_prop_values = lax_log_prop_values.at[itr, :].set(jnp.where(self.uniform_rand[itr, :] < hastings,
ln_prop,
lax_log_prop_values[itr - 1, :])[0, :])
lax_chains = lax_chains.at[:, :, itr].set(jnp.where(self.uniform_rand[itr, :] < hastings,
proposed,
lax_chains[:, :, itr - 1]))
lax_n_of_accept = lax_n_of_accept.at[0, :].set(jnp.where(self.uniform_rand[itr, :] < hastings,
lax_n_of_accept[0, :] + 1,
lax_n_of_accept[0, :])[0, :])
lax_accept_rate = lax_accept_rate.at[itr, :].set(lax_n_of_accept[0, :] / itr)
return lax_chains, lax_log_prop_values, lax_n_of_accept, lax_accept_rate
        if not self.progress_bar:  # note: self.progress_bar is stored inverted, so False here means the bar is shown
for i in tqdm(range(1, self.iterations), disable=self.progress_bar):
alg_with_progress_bar(itr=i)
else:
print('Simulating...')
self.chains, \
self.log_prop_values, \
self.n_of_accept, \
self.accept_rate = lax.fori_loop(lower=1,
upper=self.iterations,
                                             body_fun=alg_with_lax_accelerated,
init_val=(
self.chains.copy(),
self.log_prop_values.copy(),
self.n_of_accept.copy(),
self.accept_rate.copy()
))
if self.move == 'parallel_stretch':
return self.chains[:, :, self.burnin::2], self.accept_rate[::2, :]
elif self.move == 'single_stretch':
return self.chains[:, :, self.burnin:], self.accept_rate
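

# A minimal usage sketch of MCMCHammer with the single stretch move (illustration
# only; the Gaussian target and all values below are hypothetical).
def _demo_mcmc_hammer():
    ndim, n_chains = 2, 10

    def log_prob(theta):
        return -0.5 * jnp.sum(theta ** 2, axis=0, keepdims=True)

    x0 = random.normal(random.PRNGKey(7), shape=(ndim, n_chains))
    hammer = MCMCHammer(log_prop_fcn=log_prob, iterations=2000, burnin=500,
                        x_init=x0, chains=n_chains, progress_bar=True,
                        move='single_stretch', a=2.0)
    chains, accept_rate = hammer.sample()
    print(chains.shape, accept_rate.shape)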
class HMC(ParameterProposalInitialization):
def __init__(self, log_prop_fcn: callable = None, iterations: int = None, burnin: int = None,
x_init: jnp.ndarray = None, activate_jit: bool = False, chains: int = 1, progress_bar: bool = True,
random_seed: int = 1, move: str = 'single_stretch'):
"""
:param log_prop_fcn:
:param iterations:
:param burnin:
:param x_init:
:param activate_jit:
:param chains:
:param progress_bar:
:param random_seed:
:param move:
"""
super(HMC, self).__init__(log_prop_fcn=log_prop_fcn, iterations=iterations, burnin=burnin,
x_init=x_init, activate_jit=activate_jit, chains=chains,
progress_bar=progress_bar, random_seed=random_seed, move=move)
class NUTS(ParameterProposalInitialization):
def __init__(self, log_prop_fcn: callable = None,
grad_log_prob: callable = None,
iterations: int = None,
burnin: int = None,
x_init: jnp.ndarray = None,
activate_jit: bool = False,
chains: int = 1,
progress_bar: bool = True,
random_seed: int = 1,
step_size: float = None,
max_depth: int = None,
move: str = 'single_stretch'):
super(NUTS, self).__init__(log_prop_fcn=log_prop_fcn,
iterations=iterations,
burnin=burnin,
x_init=x_init,
activate_jit=activate_jit,
chains=chains,
progress_bar=progress_bar,
random_seed=random_seed,
move=move)
        self.grad_log_prob = grad_log_prob
        self.step_size = step_size
        self.max_depth = max_depth
    def sample(self):
        # NOTE: a simplified, numpy-based draft of a gradient-informed sampler;
        # this is not yet a full No-U-Turn implementation.
        x = self.x_init.copy()
        samples = [self.x_init]
        log_prob_x = self.log_prop_fcn(x)
        grad_log_prob_x = self.grad_log_prob(x)
        eps = np.random.normal(size=self.x_init.shape)
        p = self.step_size * eps  # initial momentum scaled by the step size
        rho = 1.0  # running product of rejection mass used as a stochastic stopping rule
        for i in range(self.max_depth):
            x_prime = x + p
            log_prob_x_prime = self.log_prop_fcn(x_prime)
            grad_log_prob_x_prime = self.grad_log_prob(x_prime)
            alpha = np.min([1, np.exp(log_prob_x_prime - log_prob_x - 0.5 * np.dot(p, grad_log_prob_x))])
            u = np.random.rand()
            if u < alpha:  # accept the proposal
                samples.append(x_prime)
                x = x_prime
                log_prob_x = log_prob_x_prime
                grad_log_prob_x = grad_log_prob_x_prime
            else:  # reject the proposal and flip the momentum
                p = -p
            rho = rho * (1 - alpha)
            if rho < np.random.rand():
                break
        return samples