-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathlagrange.jl
185 lines (167 loc) · 5.26 KB
/
lagrange.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
# Inspired by `Hypatia.jl/src/PolyUtils/realinterp.jl`
export BoxSampling
"""
    AbstractNodes{T,V}

Abstract supertype of the node generators used to build a `LagrangeBasis`.

`T` is the scalar element type. For `BoxSampling`, `V` is the vector type of
its `lower`/`upper` bounds; `ImplicitLagrangeBasis` uses this second parameter
as its point type.
"""
abstract type AbstractNodes{T,V} end

"""
    BoxSampling(lower, upper; sample_factor = 0)

Random nodes drawn from the axis-aligned box with corners `lower` and `upper`
(`sample` draws them uniformly in `[lower, upper)`).

`lower` and `upper` are converted with `float.(...)` and promoted to a common
vector type `V`, and `T = eltype(V)`. `sample_factor` scales the number of
candidate points drawn relative to the number of basis elements; `0` selects
the default heuristic of `num_samples`.

Throws a `DimensionMismatch` if `lower` and `upper` have different lengths.
"""
mutable struct BoxSampling{T,V} <: AbstractNodes{T,V}
    lower::V
    upper::V
    sample_factor::Int
    function BoxSampling(lower, upper; sample_factor = 0)
        # Validate user input with an explicit exception: `@assert` is meant
        # for internal invariants and may be disabled.
        if length(lower) != length(upper)
            throw(
                DimensionMismatch(
                    "`lower` has length $(length(lower)) but `upper` has length $(length(upper))",
                ),
            )
        end
        l = float.(lower)
        u = float.(upper)
        # Store both bounds with one common (promoted) vector type.
        V = promote_type(typeof(l), typeof(u))
        return new{eltype(V),V}(l, u, sample_factor)
    end
end
"""
    sample(s::BoxSampling{T}, n::Integer)

Draw `n` random points in the box `[s.lower, s.upper)` and return them as the
columns of a `length(s.lower) × n` matrix with element type `T`.
"""
function sample(s::BoxSampling{T}, n::Integer) where {T}
    half = inv(T(2))
    # Draws centered on zero, in `[-1/2, 1/2)`; one column per point.
    centered = rand(T, length(s.lower), n) .- half
    midpoint = (s.upper .+ s.lower) .* half
    width = s.upper .- s.lower
    # Stretch every coordinate to the box width, then move it to the center.
    return centered .* width .+ midpoint
end
"""
    LagrangePolynomial(variables, point)

Element of a Lagrange basis, identified by its `variables` and its node
`point`. The type parameters are `T = eltype(point)`, `P = typeof(point)` and
`V = typeof(variables)`.
"""
struct LagrangePolynomial{T,P,V}
    variables::V
    point::P
    function LagrangePolynomial(variables, point)
        scalar_type = eltype(point)
        point_type = typeof(point)
        vars_type = typeof(variables)
        return new{scalar_type,point_type,vars_type}(variables, point)
    end
end
"""
    ImplicitLagrangeBasis(variables, nodes::AbstractNodes)

Implicit Lagrange basis in `variables` whose points are generated from `nodes`
when an explicit basis is requested (see `explicit_basis_covering`). Indexing
it with `variables => point` returns the corresponding `LagrangePolynomial`.
"""
struct ImplicitLagrangeBasis{T,P,N<:AbstractNodes{T,P},V} <:
       SA.ImplicitBasis{LagrangePolynomial{T,P,V},Pair{V,P}}
    variables::V
    # Typed by the concrete parameter `N` rather than the abstract
    # `AbstractNodes{T,P}`, so that reading `nodes` is type-stable.
    nodes::N
    function ImplicitLagrangeBasis(
        variables,
        nodes::AbstractNodes{T,P},
    ) where {T,P}
        return new{T,P,typeof(nodes),typeof(variables)}(variables, nodes)
    end
end
# Return the `LagrangePolynomial` of `implicit` at the node `subs.second`.
# `subs.first` must equal the variables of `implicit`; otherwise this errors.
function Base.getindex(
    implicit::ImplicitLagrangeBasis{T,P,N,V},
    subs::Pair{V,P},
) where {T,P,N,V}
    subs.first == implicit.variables || error(
        "Variables `$(subs.first)` do not match Lagrange basis variables `$(implicit.variables)`",
    )
    return LagrangePolynomial(implicit.variables, subs.second)
end
"""
    LagrangeBasis(variables, points::AbstractVector)

Explicit Lagrange basis in `variables`, with one `LagrangePolynomial` per entry
of `points`. Type parameters: `P = eltype(points)` is the point type,
`T = eltype(P)` the scalar type, `U = typeof(points)` and
`V = typeof(variables)`.
"""
struct LagrangeBasis{T,P,U<:AbstractVector{P},V} <:
       SA.ExplicitBasis{LagrangePolynomial{T,P,V},Int}
    variables::V
    points::U
    function LagrangeBasis(variables, points::AbstractVector)
        point_type = eltype(points)
        scalar_type = eltype(point_type)
        return new{scalar_type,point_type,typeof(points),typeof(variables)}(
            variables,
            points,
        )
    end
end
# Number of elements of `basis`, one per interpolation point.
function Base.length(basis::LagrangeBasis)
    return length(basis.points)
end

# Number of variables of the polynomials in `basis`.
function MP.nvariables(basis::LagrangeBasis)
    return length(basis.variables)
end

# Variables of the polynomials in `basis`.
function MP.variables(basis::LagrangeBasis)
    return basis.variables
end
# Concrete `LagrangeBasis` type obtained from an `ImplicitLagrangeBasis`.
# Explicit bases store their points as the columns that `_eachcol` returns for a
# `T` matrix, so the column types are read off a 1×1 dummy matrix.
function explicit_basis_type(
    ::Type{<:ImplicitLagrangeBasis{T,_P,N,V}},
) where {T,_P,N,V}
    columns = _eachcol(ones(T, 1, 1))
    column_type = eltype(columns)
    return LagrangeBasis{eltype(column_type),column_type,typeof(columns),V}
end
# Evaluate every element of `basis` at the point `values` and store the results
# in `result` (indexed like `basis`). `values[k]` is the coordinate of
# `variables[k]`.
#
# Column `k` of `univariate_buffer` is scratch space. Rows `1:d+1` receive the
# univariate evaluations computed by `univariate_eval!` for basis type `B`,
# where `d` is the largest degree of `variables[k]` in `basis`. Row `r` is used
# as the factor for degree `r - 1`, so the buffer needs at least `d + 1` rows.
#
# Errors if a variable of `basis` is not among `variables`.
function eval_basis!(
    univariate_buffer,
    result,
    basis::SubBasis{B},
    variables,
    values,
) where {B}
    for v in MP.variables(basis)
        v in variables && continue
        error(
            "Cannot evaluate `$basis` as its variable `$v` is not part of the variables `$variables` of the `LagrangeBasis`",
        )
    end
    # Tabulate the univariate factors once per coordinate.
    for k in eachindex(values)
        nrows = MP.maxdegree(basis.monomials, variables[k]) + 1
        column = view(univariate_buffer, 1:nrows, k)
        univariate_eval!(B, column, values[k])
    end
    # Each element is the product over the variables of the tabulated factor
    # matching its degree in that variable.
    for i in eachindex(basis)
        mono = basis.monomials[i]
        acc = one(eltype(result))
        for k in eachindex(values)
            deg = MP.degree(mono, variables[k])
            acc = MA.operate!!(*, acc, univariate_buffer[deg+1, k])
        end
        result[i] = acc
    end
    return result
end
"""
    transformation_to(basis::SubBasis, lag::LagrangeBasis{T})

Return the `length(lag) × length(basis)` matrix whose row `i` holds the elements
of `basis` evaluated (via `eval_basis!`) at the point `lag.points[i]`.
`SA.coeffs` multiplies this matrix by coefficients in `basis` to obtain
coefficients in `lag`.
"""
function transformation_to(basis::SubBasis, lag::LagrangeBasis{T}) where {T}
    vars = MP.variables(lag)
    # `eval_basis!` writes rows `1:maxdegree(basis.monomials, v) + 1` of the
    # column of each variable `v`, so the buffer must have that many rows.
    # `length(basis)` is not a valid bound for sparse bases: the one-element
    # basis `[x^4]` needs 5 rows.
    max_degree = maximum(v -> MP.maxdegree(basis.monomials, v), vars; init = 0)
    # Allocate the scratch buffer once and reuse it for every point.
    univariate_buffer = Matrix{T}(undef, max_degree + 1, MP.nvariables(lag))
    V = Matrix{T}(undef, length(lag), length(basis))
    for i in eachindex(lag)
        eval_basis!(univariate_buffer, view(V, i, :), basis, vars, lag.points[i])
    end
    return V
end
# Heuristic taken from Hypatia
# Number of candidate points to draw for a basis of `dim` elements. A nonzero
# `sample_factor` is used as-is. Zero picks a factor that shrinks as `dim`
# grows (10, 5, 2, then 1), which bounds the cost of the candidate set.
function num_samples(sample_factor, dim)
    factor = if !iszero(sample_factor)
        sample_factor
    elseif dim <= 12_000
        10
    elseif dim <= 15_000
        5
    elseif dim <= 22_000
        2
    else
        1
    end
    return factor * dim
end
# Version shims:
# - `_eachcol(x)` returns the columns of `x` as an `AbstractVector`.
# - `_column_norm()` returns the argument that requests column-pivoted QR.
@static if VERSION < v"1.10"
    # `eachcol` yields a `Base.Generator` here, which is not an
    # `AbstractVector`, so materialize it.
    _eachcol(x) = collect(eachcol(x))
    _column_norm() = Val(true)
else
    _eachcol(x) = eachcol(x)
    _column_norm() = LinearAlgebra.ColumnNorm()
end
"""
    sample(variables, s::AbstractNodes, basis::SubBasis)

Build a `LagrangeBasis` in `variables` with `length(basis)` points.

The procedure:
1. Draw `num_samples(s.sample_factor, length(basis))` candidate points from `s`.
2. Evaluate `basis` at each candidate.
3. Keep the candidates ranked first by the column pivoting of a QR
   factorization of the transposed evaluation matrix.
"""
function sample(variables, s::AbstractNodes, basis::SubBasis)
    dim = length(basis)
    candidates = sample(s, num_samples(s.sample_factor, dim))
    candidate_basis = LagrangeBasis(variables, _eachcol(candidates))
    evaluations = transformation_to(basis, candidate_basis)
    # One column per candidate. Column pivoting greedily picks the candidates
    # that add the most to the span of those already chosen (presumably to get
    # a well-conditioned interpolation problem, as in Hypatia).
    factorization = LinearAlgebra.qr!(Matrix(adjoint(evaluations)), _column_norm())
    chosen = factorization.p[1:dim]
    return LagrangeBasis(variables, _eachcol(candidates[:, chosen]))
end
# Explicit `LagrangeBasis` for `implicit` with one point per element of `basis`,
# sampled from `implicit.nodes`.
explicit_basis_covering(implicit::ImplicitLagrangeBasis, basis::SubBasis) =
    sample(implicit.variables, implicit.nodes, basis)
# Convert coefficients `coeffs` in `source` into coefficients in `target` by
# applying the evaluation matrix of `source` at the points of `target`.
SA.coeffs(coeffs, source::SubBasis, target::LagrangeBasis) =
    transformation_to(source, target) * coeffs
# Convert coefficients given in a `FullBasis` into coefficients in `target`.
# The element is first expressed in its explicit basis; those coefficients are
# then converted to `target` through another `SA.coeffs` call.
function SA.coeffs(coeffs, implicit::FullBasis, target::LagrangeBasis)
    element = algebra_element(coeffs, implicit)
    explicit = explicit_basis(element)
    explicit_coeffs = SA.coeffs(element, explicit)
    return SA.coeffs(explicit_coeffs, explicit, target)
end