Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handling complex numbers for PairwisePotential #655

Merged
merged 31 commits into from
Jun 28, 2022
Merged
Changes from 1 commit
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
8a5bd40
Handling complex numbers for PairwisePotential
epolack Apr 29, 2022
281e161
comments from Antoine
epolack Apr 29, 2022
42b82c0
assert
epolack Apr 29, 2022
4198c1c
bug
epolack May 2, 2022
1b9a3c7
Update pairwise.jl
epolack May 3, 2022
20c8a17
workarounds for complex exponentiation
epolack May 3, 2022
ca7f803
Update pairwise.jl
epolack May 3, 2022
98b65a0
Update pairwise.jl
epolack May 5, 2022
4185299
Update pairwise.jl
epolack May 5, 2022
45da86c
Update pairwise.jl
epolack May 5, 2022
d4d48e9
testing ph_disp
epolack Jun 7, 2022
ef14519
comment
epolack Jun 7, 2022
9135a16
renaming
epolack Jun 7, 2022
3f343ea
factorisation
epolack Jun 7, 2022
0157e80
Move estimate_integer_bounds to structure.jl and support 1D and 2D sy…
niklasschmitz Jun 7, 2022
b4beeee
Rewrite pairwise without shelll_indices
niklasschmitz Jun 7, 2022
cbe2980
trim whitespace
niklasschmitz Jun 7, 2022
de7405c
Fix pairwise bound comments
niklasschmitz Jun 7, 2022
86141e0
Fix comment
niklasschmitz Jun 7, 2022
aa2bf98
first batch of modifications
epolack Jun 8, 2022
afc7c14
some more
epolack Jun 8, 2022
a558373
Merge branch 'nfs/pairwise-bounds' into complex_pairwise
epolack Jun 8, 2022
f88dd86
workaround back with a vengeance
epolack Jun 8, 2022
1d1d790
bugfix
epolack Jun 8, 2022
610d404
Update forwarddiff_rules.jl
epolack Jun 9, 2022
284a915
Antoine's comment
epolack Jun 13, 2022
82c3bda
Merge remote-tracking branch 'origin/master' into complex_pairwise
epolack Jun 13, 2022
74908cf
bugfix
epolack Jun 13, 2022
0927546
Update phonons.jl
epolack Jun 28, 2022
1c3d395
Update pairwise.jl
epolack Jun 28, 2022
9200787
Merge remote-tracking branch 'origin/master' into complex_pairwise
epolack Jun 28, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
workarounds for complex exponentiation
  • Loading branch information
epolack committed Jun 7, 2022
commit 20c8a1764fd4c805ca0ddff6e6e3bc9159b65eeb
23 changes: 9 additions & 14 deletions src/terms/pairwise.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
# We cannot use `LinearAlgebra.norm` with complex numbers due to the need to use its
# analytic continuation
function norm_cplx(x)
# TODO: ForwardDiff bug (https://github.com/JuliaDiff/ForwardDiff.jl/issues/324)
sqrt(sum(x.*x))
sqrt(sum(x.^2))
end

struct PairwisePotential
@@ -17,10 +16,10 @@ Lennard—Jones terms.
The potential is dependent on the distance between to atomic positions and the pairwise
atomic types:
For a distance `d` between to atoms `A` and `B`, the potential is `V(d, params[(A, B)])`.
The parameters `max_radius` is of `1000` by default, and gives the maximum (Cartesian)
distance between nuclei for which we consider interactions.
The parameters `max_radius` is of `100` by default, and gives the maximum distance (in
Cartesian coordinates) between nuclei for which we consider interactions.
"""
function PairwisePotential(V, params; max_radius=1000)
function PairwisePotential(V, params; max_radius=100)
params = Dict(minmax(key[1], key[2]) => value for (key, value) in params)
PairwisePotential(V, params, max_radius)
end
@@ -43,8 +42,7 @@ end
@timing "forces: Pairwise" function compute_forces(term::TermPairwisePotential,
basis::PlaneWaveBasis{T}, ψ, occ;
kwargs...) where {T}
TT = promote_type(T, eltype(basis.model.positions[1]))
forces = zero(TT, basis.model.positions)
forces = zero(basis.model.positions)
energy_pairwise(basis.model, term.V, term.params; max_radius=term.max_radius,
forces=forces, kwargs...)
forces
@@ -65,15 +63,13 @@ end

# This could be factorised with Ewald, but the use of `symbols` would slow down the
# computationally intensive Ewald sums. So we leave it as it for now.
# TODO: *Beware* of using ForwardDiff to derive this function with complex numbers, use
# multiplications and not powers (https://github.com/JuliaDiff/ForwardDiff.jl/issues/324).
# `q` is the phonon `q`-point (`Vec3`), and `ph_disp` a list of `Vec3` displacements to
# compute the Fourier transform of the force constant matrix.
function energy_pairwise(lattice, symbols, positions, V, params;
max_radius=1000, forces=nothing, ph_disp=nothing, q=nothing)
max_radius=100, forces=nothing, ph_disp=nothing, q=nothing)
@assert length(symbols) == length(positions)

T = eltype(lattice)
T = eltype(positions[1])
if ph_disp !== nothing
@assert q !== nothing
T = promote_type(complex(T), eltype(ph_disp[1]))
@@ -135,9 +131,8 @@ function energy_pairwise(lattice, symbols, positions, V, params;
sum_pairwise += energy_contribution
if forces !== nothing
dE_ddist = ForwardDiff.derivative(real(zero(eltype(dist)))) do ε
res = V(dist + ε, param_ij)
[real(res), imag(res)]
end |> x -> complex(x...)
V(dist + ε, param_ij)
end
dE_dti = lattice' * ((dE_ddist / dist) * Δr)
# We need to "break" the symmetry for phonons; at equilibrium, expect
# the forces to be zero at machine precision.
29 changes: 29 additions & 0 deletions src/workarounds/forwarddiff_rules.jl
Original file line number Diff line number Diff line change
@@ -272,3 +272,32 @@ function Smearing.occupation(S::Smearing.FermiDirac, d::ForwardDiff.Dual{T}) whe
end
ForwardDiff.Dual{T}(Smearing.occupation(S, x), ∂occ * ForwardDiff.partials(d))
end

# Workarounds for issue https://github.com/JuliaDiff/ForwardDiff.jl/issues/324
ForwardDiff.derivative(f, x::Complex) = throw(DimensionMismatch("derivative(f, x) expects that x is a real number (does not support Wirtinger derivatives). Separate real and imaginary parts of the input."))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These fixes are not for the issue above. Is the fix not already upstream in a released version? If yes, only define the functions if the version is below the one that has the fix (and then at some point we remove the code and depend on a new version)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This JuliaDiff/ForwardDiff.jl#577 seems to fix the exponentiation problem. So maybe remove it altogether?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it fix JuliaDiff/ForwardDiff.jl#514 (comment)? Since my message was after that PR was merge I don't think so?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, I thought is was after. This is strange, something seems to have fixed the issue, because I am not running into it anymore, even for this test:

using ForwardDiff, FiniteDifferences, Random, Test

v = randn()
p, m = randn(ComplexF64), randn(ComplexF64)

for f in (x -> (x*m)^p,
          x -> m^(p*x),
          x -> (x*m)^(p*x))
  @test ≈(ForwardDiff.derivative(f, v), central_fdm(5,1)(f, v), atol=1e-8)
end

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about the code in my comment in the forwarddiff tracker above?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note the issue was very specific: when differentiating at a real number in a complex direction.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed, didn't test the right thing 🙄.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, the bug was due to something we talked about earlier if the types of x and y differ. Let's keep it that way.

@inline ForwardDiff.extract_derivative(::Type{T}, y::Complex) where {T} = zero(y)
@inline function ForwardDiff.extract_derivative(::Type{T}, y::Complex{TD}) where {T, TD <: ForwardDiff.Dual}
complex(ForwardDiff.partials(T, real(y), 1), ForwardDiff.partials(T, imag(y), 1))
end
function Base.:^(x::Complex{ForwardDiff.Dual{T,V,N}}, y::Complex{ForwardDiff.Dual{T,V,N}}) where {T,V,N}
xx = complex(ForwardDiff.value(real(x)), ForwardDiff.value(imag(x)))
yy = complex(ForwardDiff.value(real(y)), ForwardDiff.value(imag(y)))
dx = complex.(ForwardDiff.partials(real(x)), ForwardDiff.partials(imag(x)))
dy = complex.(ForwardDiff.partials(real(y)), ForwardDiff.partials(imag(y)))

expv = xx^yy
∂expv∂x = yy * xx^(yy-1)
∂expv∂y = log(xx) * expv
dxexpv = ∂expv∂x * dx
# TODO: Fishy and should be checked, but seems to catch most cases
if iszero(xx) && ForwardDiff.isconstant(real(y)) && ForwardDiff.isconstant(imag(y)) && imag(y) === zero(imag(y)) && real(y) > 0
dexpv = zero(expv)
elseif iszero(xx)
throw(DomainError(x, "mantissa cannot be zero for complex exponentiation"))
else
dyexpv = ∂expv∂y * dy
dexpv = dxexpv + dyexpv
end
complex(ForwardDiff.Dual{T,V,N}(real(expv), ForwardDiff.Partials{N,V}(tuple(real(dexpv)...))),
ForwardDiff.Dual{T,V,N}(imag(expv), ForwardDiff.Partials{N,V}(tuple(imag(dexpv)...))))
end