-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdiscrete_gaussian_distribution.hpp
171 lines (134 loc) · 5.37 KB
/
discrete_gaussian_distribution.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
#ifndef _DISCRETE_GAUSSIAN_H_
#define _DISCRETE_GAUSSIAN_H_
#include <algorithm>
#include <cassert>
#include <cmath>
#include <ios>
#include <istream>
#include <limits>
#include <numbers>
#include <ostream>
#include <random>
#include <type_traits>
#include "include/discrete_distributions/discrete_laplacian_distribution.hpp"
template <typename IntType = int> class DiscreteGaussian {
static_assert(std::is_integral<IntType>::value &&
std::is_signed<IntType>::value,
"template type must be a signed integral type");
public:
using result_type = IntType;
class param_type {
double _M_sigma_square;
double _M_sigma;
public:
explicit param_type(double sigma_square)
: _M_sigma_square(sigma_square), _M_sigma(std::sqrt(sigma_square)) {
assert(_M_sigma_square > 0 && "the variance must be positive");
}
double sigma_square() const { return _M_sigma_square; }
double sigma() const { return _M_sigma; }
friend bool operator==(const param_type &l, const param_type &r) {
return l._M_sigma_square == r._M_sigma_square;
}
friend bool operator!=(const param_type &l, const param_type &r) {
return !(l == r);
}
};
explicit DiscreteGaussian(double sigma) : _M_param(sigma) {}
DiscreteGaussian(const param_type ¶m) : _M_param(param) {}
void reset() {}
template <std::uniform_random_bit_generator URNG>
result_type operator()(URNG &urng) {
return (*this)(urng, _M_param);
}
template <std::uniform_random_bit_generator URNG>
result_type operator()(URNG &urng, const param_type ¶m) {
auto t = std::floor(param.sigma()) + 1;
auto p = std::exp(-1.0 / t);
while (1) {
// sample discrete laplacian
auto y = DiscreteLaplacian<>(p)(urng);
// sample bernoulli
auto q = std::exp(-std::pow(std::abs(y) - param.sigma_square() / t, 2) /
(2 * param.sigma_square()));
auto c = std::bernoulli_distribution(q)(urng);
if (c == 1) {
return y;
}
}
}
double sigma_square() const { return _M_param.sigma_square(); }
param_type param() const { return _M_param; }
// these are defined by the integer type we use
result_type min() const { return std::numeric_limits<IntType>::min(); }
result_type max() const { return std::numeric_limits<IntType>::max(); }
// the discrete gaussian is symmetric around 0
result_type mean() const { return 0; }
// the variance is not exactly sigma square for this distribution,
// but sigma square provides an upper bound that is very close
double var() const {
if (_M_param.sigma_square() > 1.0 / 3.0) {
return _M_param.sigma_square();
} else {
return 3 * std::exp(-1.0 / (2 * _M_param.sigma_square()));
}
}
// this function gives an **approximate** probability mass function for the
// discrete gaussian. it is approximate in the sense that we can not exactly
// compute the normalization term, so we use a lower bound that is fairly
// tight. therefore the probability mass values produced by this function
// slightly **overestimate** the probability of the number occuring. as the
// deviations are very small, this should not be an issue in real world
// applications, but if the behavior of an algorithm that depends on this
// function is not as expected, it might warrant investigation if the
// approxmation is sufficiently accurate.
double pmf(IntType k) const {
auto p = std::exp(-(k * k) / (2 * _M_param.sigma_square()));
// this is an approximate normalization constant, which gives a rather tight
// **lower bound** on the constant
auto normalization =
std::max(std::sqrt(2 * std::numbers::pi_v<double>) * _M_param.sigma() *
(1 + 2 * std::exp(-2.0 * std::numbers::pi_v<double> *
std::numbers::pi_v<double> *
_M_param.sigma_square())),
1 + 2 * std::exp(-1.0 / (2 * _M_param.sigma_square())));
return p / normalization;
}
// NOTE: currently we don't implement a cumulative density function for the
// discrete gaussian, as there is no good closed form expression
friend bool operator==(const DiscreteGaussian &l, const DiscreteGaussian &r) {
return l._M_param == r._M_param;
}
friend bool operator!=(const DiscreteGaussian &l, const DiscreteGaussian &r) {
return !(l == r);
}
private:
param_type _M_param;
};
/// Writes the distribution's sigma^2 to @p os so that operator>> can read it
/// back exactly.
///
/// Uses scientific notation with max_digits10 significant digits. The
/// previous fixed|scientific combination selects hexfloat output, which
/// libstdc++'s num_get cannot parse. The stream's format flags and precision
/// are restored before returning.
template <class CharT, class Traits, class IntegerType>
std::basic_ostream<CharT, Traits> &
operator<<(std::basic_ostream<CharT, Traits> &os,
           const DiscreteGaussian<IntegerType> &dnd) {
  using OS = std::basic_ostream<CharT, Traits>;
  // Save only what we change. Copying the format into a null-streambuf stream
  // would throw when os.exceptions() includes badbit, since that scratch stream
  // starts in the bad state.
  const typename OS::fmtflags saved_flags = os.flags();
  const std::streamsize saved_precision = os.precision();
  os.flags(OS::dec | OS::left | OS::scientific);
  os.precision(std::numeric_limits<double>::max_digits10);
  os << dnd.sigma_square();
  os.flags(saved_flags);
  os.precision(saved_precision);
  return os;
}
template <class CharT, class Traits, class IntegerType>
std::basic_istream<CharT, Traits> &
operator>>(std::basic_istream<CharT, Traits> &is,
const DiscreteGaussian<IntegerType> &dnd) {
std::basic_istream<CharT, Traits> savestate(nullptr);
savestate.copyfmt(is);
using IS = std::basic_istream<CharT, Traits>;
using param_type = DiscreteGaussian<IntegerType>::param_type;
is.flags(IS::dec | IS::skipws);
double sigma_square;
is >> sigma_square;
if (!is.fail()) {
dnd.param(param_type(sigma_square));
}
is.copyfmt(savestate);
return is;
}
#endif