Merge pull request #64 from necludov:main
PiperOrigin-RevId: 556771559
Change-Id: I7b7ee318642e70f9b49d8c1934d2f0aaaed80489
jsspencer committed Aug 14, 2023
2 parents 6d17339 + dda112b commit c2f345d
Showing 6 changed files with 261 additions and 12 deletions.
1 change: 1 addition & 0 deletions ferminet/base_config.py
@@ -53,6 +53,7 @@ def default() -> ml_collections.ConfigDict:
       # importlib.import_module.
       'config_module': __name__,
       'optim': {
+          'objective': 'vmc',  # objective type. Either 'vmc' or 'wqmc'
           'iterations': 1000000,  # number of iterations
           'optimizer': 'kfac',  # one of adam, kfac, lamb, none
           'lr': {
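For context, the new field lands in the options returned by base_config.default(); a minimal sketch of how a caller would inspect or opt in to it, assuming only the public base_config API shown in this repository:

from ferminet import base_config

cfg = base_config.default()
print(cfg.optim.objective)    # 'vmc' by default
cfg.optim.objective = 'wqmc'  # opt in to the new WQMC loss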
85 changes: 85 additions & 0 deletions ferminet/configs/li_wqmc.py
@@ -0,0 +1,85 @@
# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Generic single-atom configuration for FermiNet."""

from ferminet import base_config
from ferminet.utils import elements
from ferminet.utils import system
import ml_collections


def _adjust_nuclear_charge(cfg):
  """Sets the molecule, nuclear charge and electrons for the atom.

  Note: function name predates this logic but is kept for compatibility with
  xm_expt.py.

  Args:
    cfg: ml_collections.ConfigDict after all argument parsing.

  Returns:
    ml_collections.ConfigDict with the nuclear charge for the atom in
    cfg.system.molecule and cfg.system.charge appropriately set.
  """
  if cfg.system.molecule:
    atom = cfg.system.molecule[0]
  else:
    atom = system.Atom(symbol=cfg.system.atom, coords=(0, 0, 0))

  if abs(cfg.system.delta_charge) > 1.0e-8:
    nuclear_charge = atom.charge + cfg.system.delta_charge
    cfg.system.molecule = [
        system.Atom(atom.symbol, atom.coords, nuclear_charge)
    ]
  else:
    cfg.system.molecule = [atom]

  if not cfg.system.electrons:
    atomic_number = elements.SYMBOLS[atom.symbol].atomic_number
    if 'charge' in cfg.system:
      atomic_number -= cfg.system.charge
    if (
        'spin_polarisation' in cfg.system
        and cfg.system.spin_polarisation is not None
    ):
      spin_polarisation = cfg.system.spin_polarisation
    else:
      spin_polarisation = elements.ATOMIC_NUMS[atomic_number].spin_config
    nalpha = (atomic_number + spin_polarisation) // 2
    cfg.system.electrons = (nalpha, atomic_number - nalpha)

  return cfg


def get_config():
  """Returns config for running generic atoms with qmc."""
  cfg = base_config.default()
  cfg.system.atom = 'Li'
  cfg.system.charge = 0
  cfg.system.delta_charge = 0.0
  cfg.system.spin_polarisation = ml_collections.FieldReference(
      None, field_type=int
  )
  with cfg.ignore_type():
    cfg.system.set_molecule = _adjust_nuclear_charge
    cfg.config_module = '.atom'
  cfg.network.network_type = 'psiformer'
  cfg.optim.iterations = 10_000
  cfg.optim.lr.delay = 5_000
  cfg.optim.clip_median = True
  cfg.debug.deterministic = True
  cfg.optim.kfac.norm_constraint = 1e-3
  cfg.optim.objective = 'wqmc'
  return cfg
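A minimal way to exercise this config end to end, as a sketch: it assumes ferminet.train.train accepts the config directly, and the reduced batch size and iteration count are illustrative smoke-test values, not part of this commit.

from ferminet import train
from ferminet.configs import li_wqmc

cfg = li_wqmc.get_config()
cfg.batch_size = 256        # illustrative override for a quick smoke test
cfg.optim.iterations = 100  # shortened from the config's 10_000
train.train(cfg)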
6 changes: 4 additions & 2 deletions ferminet/hamiltonian.py
@@ -150,7 +150,8 @@ def potential_electron_electron(r_ee: Array) -> jnp.ndarray:
       between electrons i and j. Other elements in the final axes are not
       required.
   """
-  return jnp.sum(jnp.triu(1 / r_ee[..., 0], k=1))
+  r_ee = r_ee[jnp.triu_indices_from(r_ee[..., 0], 1)]
+  return (1.0 / r_ee).sum()


def potential_electron_nuclear(charges: Array, r_ae: Array) -> jnp.ndarray:
@@ -233,7 +234,8 @@ def _e_l(
     """
     del key  # unused
     _, _, r_ae, r_ee = networks.construct_input_features(
-        data.positions, data.atoms)
+        data.positions, data.atoms
+    )
     potential = potential_energy(r_ae, r_ee, data.atoms, charges)
     kinetic = ke(params, data)
     return potential + kinetic
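The rewritten potential_electron_electron gathers the strict upper triangle of distances by index instead of zeroing a full matrix of inverse distances with jnp.triu. A standalone numerical check that the two forms agree, as a sketch (the shape convention r_ee[i, j, 0] = |r_i - r_j| is assumed from the surrounding code):

import numpy as np
import jax.numpy as jnp

pos = np.random.default_rng(0).normal(size=(4, 3))
r_ee = jnp.linalg.norm(pos[:, None] - pos[None], axis=-1)[..., None]

old = jnp.sum(jnp.triu(1 / r_ee[..., 0], k=1))                    # previous form
new = (1.0 / r_ee[jnp.triu_indices_from(r_ee[..., 0], 1)]).sum()  # new form
assert jnp.allclose(old, new)

The indexed form also avoids computing 1/r on the (zero) diagonal entirely, rather than computing infinities and masking them out afterwards.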
151 changes: 150 additions & 1 deletion ferminet/loss.py
@@ -34,10 +34,12 @@ class AuxiliaryLossData:
     variance: mean variance over batch, and over all devices if inside a pmap.
     local_energy: local energy for each MCMC configuration.
     clipped_energy: local energy after clipping has been applied
+    grad_local_energy: gradient of the local energy.
   """
   variance: jax.Array
   local_energy: jax.Array
   clipped_energy: jax.Array
+  grad_local_energy: jax.Array | None


class LossFn(Protocol):
@@ -203,7 +205,11 @@ def total_energy(
     loss_diff = e_l - loss
     variance = constants.pmean(jnp.mean(loss_diff * jnp.conj(loss_diff)))
     return loss, AuxiliaryLossData(
-        variance=variance.real, local_energy=e_l, clipped_energy=e_l)
+        variance=variance.real,
+        local_energy=e_l,
+        clipped_energy=e_l,
+        grad_local_energy=None,
+    )

   @total_energy.defjvp
   def total_energy_jvp(primals, tangents):  # pylint: disable=unused-variable
@@ -255,3 +261,146 @@ def total_energy_jvp(primals, tangents):  # pylint: disable=unused-variable
     return primals_out, tangents_out

   return total_energy


+def make_wqmc_loss(
+    network: networks.LogFermiNetLike,
+    local_energy: hamiltonian.LocalEnergy,
+    clip_local_energy: float = 0.0,
+    clip_from_median: bool = True,
+    center_at_clipped_energy: bool = True,
+    complex_output: bool = False,
+) -> LossFn:
+  """Creates the WQMC loss function, including custom gradients.
+
+  Args:
+    network: callable which evaluates the log of the magnitude of the
+      wavefunction (i.e. the log of the square root of the probability
+      distribution) at a single MCMC configuration given the network
+      parameters.
+    local_energy: callable which evaluates the local energy.
+    clip_local_energy: If greater than zero, clip local energies that are
+      outside [E_L - n D, E_L + n D], where E_L is the mean local energy, n is
+      this value and D the mean absolute deviation of the local energies from
+      the mean, to the boundaries. The clipped local energies are only used to
+      evaluate gradients.
+    clip_from_median: If true, center the clipping window at the median rather
+      than the mean. Potentially expensive in multi-host training, but more
+      accurate.
+    center_at_clipped_energy: If true, center the local energy differences
+      passed back to the gradient around the clipped local energy, so the mean
+      difference across the batch is guaranteed to be zero.
+    complex_output: If true, the local energies will be complex valued.
+
+  Returns:
+    Callable with signature (params, key, data) which returns (loss, aux_data),
+    where loss is the mean energy and aux_data is an AuxiliaryLossData object.
+    The loss is averaged over the batch and over all devices inside a pmap.
+  """
+  batch_local_energy = jax.vmap(
+      local_energy,
+      in_axes=(
+          None,
+          0,
+          networks.FermiNetData(positions=0, spins=0, atoms=0, charges=0),
+      ),
+      out_axes=0,
+  )
+  batch_network = jax.vmap(network, in_axes=(None, 0, 0, 0, 0), out_axes=0)
+
+  @jax.custom_jvp
+  def total_energy(
+      params: networks.ParamTree,
+      key: chex.PRNGKey,
+      data: networks.FermiNetData,
+  ) -> Tuple[jnp.ndarray, AuxiliaryLossData]:
+    """Evaluates the total energy of the network for a batch of configurations.
+
+    Note: the signature of this function is fixed to match that expected by
+    kfac_jax.optimizer.Optimizer with value_func_has_rng=True and
+    value_func_has_aux=True.
+
+    Args:
+      params: parameters to pass to the network.
+      key: PRNG state.
+      data: Batched MCMC configurations to pass to the local energy function.
+
+    Returns:
+      (loss, aux_data), where loss is the mean energy, and aux_data is an
+      AuxiliaryLossData object containing the variance of the energy and the
+      local energy per MCMC configuration. The loss and variance are averaged
+      over the batch and over all devices inside a pmap.
+    """
+    keys = jax.random.split(key, num=data.positions.shape[0])
+    e_l = batch_local_energy(params, keys, data)
+    loss = constants.pmean(jnp.mean(e_l))
+    loss_diff = e_l - loss
+    variance = constants.pmean(jnp.mean(loss_diff * jnp.conj(loss_diff)))
+
+    def batch_local_energy_pos(pos):
+      network_data = networks.FermiNetData(
+          positions=pos,
+          spins=data.spins,
+          atoms=data.atoms,
+          charges=data.charges,
+      )
+      return batch_local_energy(params, keys, network_data).sum()
+
+    grad_e_l = jax.grad(batch_local_energy_pos)(data.positions)
+    grad_e_l = jnp.tanh(jax.lax.stop_gradient(grad_e_l))
+    return loss, AuxiliaryLossData(
+        variance=variance.real,
+        local_energy=e_l,
+        clipped_energy=e_l,
+        grad_local_energy=grad_e_l,
+    )
+
+  @total_energy.defjvp
+  def total_energy_jvp(primals, tangents):  # pylint: disable=unused-variable
+    """Custom Jacobian-vector product for unbiased local energy gradients."""
+    params, key, data = primals
+    loss, aux_data = total_energy(params, key, data)
+
+    if clip_local_energy > 0.0:
+      aux_data.clipped_energy, diff = clip_local_values(
+          aux_data.local_energy,
+          loss,
+          clip_local_energy,
+          clip_from_median,
+          center_at_clipped_energy,
+          complex_output,
+      )
+    else:
+      diff = aux_data.local_energy - loss
+
+    def log_q(params_, pos_, spins_, atoms_, charges_):
+      out = batch_network(params_, pos_, spins_, atoms_, charges_)
+      kfac_jax.register_normal_predictive_distribution(out[:, None])
+      return out.sum()
+
+    score = jax.grad(log_q, argnums=1)
+    primals = (params, data.positions, data.spins, data.atoms, data.charges)
+    tangents = (
+        tangents[0],
+        tangents[2].positions,
+        tangents[2].spins,
+        tangents[2].atoms,
+        tangents[2].charges,
+    )
+    score_primal, score_tangent = jax.jvp(score, primals, tangents)
+
+    score_norm = jnp.linalg.norm(score_primal, axis=-1, keepdims=True)
+    median = jnp.median(constants.all_gather(score_norm))
+    deviation = jnp.mean(jnp.abs(score_norm - median))
+    mask = score_norm < (median + 5 * deviation)
+    log_q_tangent_out = (aux_data.grad_local_energy * score_tangent * mask).sum(
+        axis=1
+    )
+    log_q_tangent_out *= len(mask) / mask.sum()
+
+    _, psi_tangent = jax.jvp(batch_network, primals, tangents)
+    log_q_tangent_out += diff * psi_tangent
+    primals_out = loss, aux_data
+    tangents_out = (log_q_tangent_out.mean(), aux_data)
+    return primals_out, tangents_out
+
+  return total_energy
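For reference, the clipping that clip_local_energy and clip_from_median describe can be sketched as follows. This is a simplified, real-valued, single-host sketch, not the library's clip_local_values, which also handles complex energies, cross-device reductions, and the center_at_clipped_energy option.

import jax.numpy as jnp

def clip_local_values_sketch(e_l, e_mean, n, clip_from_median=True):
  # Clip to [center - n*D, center + n*D], with D the mean absolute deviation.
  center = jnp.median(e_l) if clip_from_median else e_mean
  d = jnp.mean(jnp.abs(e_l - center))
  clipped = jnp.clip(e_l, center - n * d, center + n * d)
  # Re-center the differences so they average to zero over the batch.
  return clipped, clipped - jnp.mean(clipped)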
1 change: 0 additions & 1 deletion ferminet/networks.py
@@ -473,7 +473,6 @@ def construct_input_features(
   n = ee.shape[0]
   r_ee = (
       jnp.linalg.norm(ee + jnp.eye(n)[..., None], axis=-1) * (1.0 - jnp.eye(n)))
-
   return ae, ee, r_ae, r_ee[..., None]


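The context around this one-line deletion is the reason for the jnp.eye terms in construct_input_features: |r_i - r_j| is zero on the diagonal, and the gradient of a norm at zero is NaN, so the identity is added before taking the norm and the diagonal is zeroed again afterwards. A standalone sketch (the pairwise helper name is hypothetical):

import jax
import jax.numpy as jnp

def pairwise(pos):  # pos: (n, 3)
  ee = pos[:, None] - pos[None]
  n = ee.shape[0]
  return jnp.linalg.norm(ee + jnp.eye(n)[..., None], axis=-1) * (1.0 - jnp.eye(n))

grads = jax.jacfwd(pairwise)(jnp.arange(9.0).reshape(3, 3))
assert not jnp.isnan(grads).any()  # no NaNs despite the zero diagonal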
29 changes: 21 additions & 8 deletions ferminet/train.py
@@ -571,14 +571,27 @@ def log_network(*args, **kwargs):
       nspins=nspins,
       use_scan=False,
       complex_output=cfg.network.get('complex', False))
-  evaluate_loss = qmc_loss_functions.make_loss(
-      log_network if cfg.network.get('complex', False) else logabs_network,
-      local_energy,
-      clip_local_energy=cfg.optim.clip_local_energy,
-      clip_from_median=cfg.optim.clip_median,
-      center_at_clipped_energy=cfg.optim.center_at_clip,
-      complex_output=cfg.network.get('complex', False)
-  )
+  if cfg.optim.objective == 'vmc':
+    evaluate_loss = qmc_loss_functions.make_loss(
+        log_network if cfg.network.get('complex', False) else logabs_network,
+        local_energy,
+        clip_local_energy=cfg.optim.clip_local_energy,
+        clip_from_median=cfg.optim.clip_median,
+        center_at_clipped_energy=cfg.optim.center_at_clip,
+        complex_output=cfg.network.get('complex', False),
+    )
+  elif cfg.optim.objective == 'wqmc':
+    evaluate_loss = qmc_loss_functions.make_wqmc_loss(
+        log_network if cfg.network.get('complex', False) else logabs_network,
+        local_energy,
+        clip_local_energy=cfg.optim.clip_local_energy,
+        clip_from_median=cfg.optim.clip_median,
+        center_at_clipped_energy=cfg.optim.center_at_clip,
+        complex_output=cfg.network.get('complex', False),
+    )
+  else:
+    raise ValueError(f'Not a recognized objective: {cfg.optim.objective}')

   # Compute the learning rate
   def learning_rate_schedule(t_: jnp.ndarray) -> jnp.ndarray:
     return cfg.optim.lr.rate * jnp.power(
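Since the vmc and wqmc branches above differ only in which loss constructor they call, a table-driven dispatch is an equivalent alternative. A sketch, which should behave identically under the same config values:

make_loss_fn = {
    'vmc': qmc_loss_functions.make_loss,
    'wqmc': qmc_loss_functions.make_wqmc_loss,
}.get(cfg.optim.objective)
if make_loss_fn is None:
  raise ValueError(f'Not a recognized objective: {cfg.optim.objective}')
evaluate_loss = make_loss_fn(
    log_network if cfg.network.get('complex', False) else logabs_network,
    local_energy,
    clip_local_energy=cfg.optim.clip_local_energy,
    clip_from_median=cfg.optim.clip_median,
    center_at_clipped_energy=cfg.optim.center_at_clip,
    complex_output=cfg.network.get('complex', False),
)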
