Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Oct 9, 2023
1 parent d567995 commit 7667a4b
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 82 deletions.
23 changes: 14 additions & 9 deletions deepmd/descriptor/se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,7 +564,7 @@ def build(
self.filter_precision,
)
self.negative_mask = -(2 << 32) * (1.0 - self.nmask)
# hard coding the magnitude of attention weight shift
# hard coding the magnitude of attention weight shift
self.smth_attn_w_shift = 20.0
# only used when tensorboard was set as true
tf.summary.histogram("descrpt", self.descrpt)
Expand Down Expand Up @@ -601,7 +601,9 @@ def build(
)
self.recovered_r = (
tf.reshape(
tf.slice(tf.reshape(self.descrpt_reshape, [-1, 4]), [0, 0], [-1, 1]),
tf.slice(
tf.reshape(self.descrpt_reshape, [-1, 4]), [0, 0], [-1, 1]
),
[-1, natoms[0], self.sel_all_a[0]],
)
* self.std_looked_up
Expand Down Expand Up @@ -870,18 +872,21 @@ def _scaled_dot_attn(
if self.smooth:
# (nb x nloc) x nsel
nsel = self.sel_all_a[0]
attn = ((attn + self.smth_attn_w_shift) *
tf.reshape(self.recovered_switch, [-1,1,nsel]) *
tf.reshape(self.recovered_switch, [-1,nsel,1]) -
self.smth_attn_w_shift)
attn = (attn + self.smth_attn_w_shift) * tf.reshape(
self.recovered_switch, [-1, 1, nsel]
) * tf.reshape(
self.recovered_switch, [-1, nsel, 1]
) - self.smth_attn_w_shift
else:
attn *= self.nmask
attn += self.negative_mask
attn = tf.nn.softmax(attn, axis=-1)
if self.smooth:
attn = (attn *
tf.reshape(self.recovered_switch, [-1,1,nsel]) *
tf.reshape(self.recovered_switch, [-1,nsel,1]))
attn = (
attn
* tf.reshape(self.recovered_switch, [-1, 1, nsel])
* tf.reshape(self.recovered_switch, [-1, nsel, 1])
)
else:
attn *= tf.reshape(self.nmask, [-1, attn.shape[-1], 1])
if save_weights:
Expand Down
93 changes: 50 additions & 43 deletions source/tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,77 +530,84 @@ def strerch_box(old_coord, old_box, new_box):
return ncoord.reshape(old_coord.shape)


def finite_difference_fv(sess, energy, feed_dict, t_coord, t_box, delta=1e-6):
"""for energy models, compute f, v by finite difference
"""
def finite_difference_fv(sess, energy, feed_dict, t_coord, t_box, delta=1e-6):
"""For energy models, compute f, v by finite difference."""
base_dict = feed_dict.copy()
coord0 = base_dict.pop(t_coord)
box0 = base_dict.pop(t_box)
fdf = -finite_difference(
lambda coord: sess.run(
energy, feed_dict={**base_dict, t_coord: coord, t_box: box0}
).reshape(-1),
coord0,
delta=delta,
).reshape(-1)
fdv = -(
finite_difference(
lambda box: sess.run(
energy,
feed_dict={
**base_dict,
t_coord: strerch_box(coord0, box0, box),
t_box: box,
},
lambda coord: sess.run(
energy, feed_dict={**base_dict, t_coord: coord, t_box: box0}
).reshape(-1),
box0,
coord0,
delta=delta,
)
.reshape([-1, 3, 3])
.transpose(0, 2, 1)
@ box0.reshape(3, 3)
).reshape(-1)
fdv = -(
finite_difference(
lambda box: sess.run(
energy,
feed_dict={
**base_dict,
t_coord: strerch_box(coord0, box0, box),
t_box: box,
},
).reshape(-1),
box0,
delta=delta,
)
.reshape([-1, 3, 3])
.transpose(0, 2, 1)
@ box0.reshape(3, 3)
).reshape(-1)
return fdf, fdv


def check_continuity(f, cc, rcut, delta):
"""coord[0:2] to [[0, 0, 0], [rcut+-.5*delta, 0, 0]]
"""
cc = cc.reshape([-1,3])
"""coord[0:2] to [[0, 0, 0], [rcut+-.5*delta, 0, 0]]."""
cc = cc.reshape([-1, 3])
cc0 = np.copy(cc)
cc1 = np.copy(cc)
cc0[:2,:] = np.array([
0.0, 0.0, 0.0,
rcut-0.5*delta, 0.0, 0.0,
]).reshape([-1,3])
cc1[:2,:] = np.array([
0.0, 0.0, 0.0,
rcut+0.5*delta, 0.0, 0.0,
]).reshape([-1,3])
cc0[:2, :] = np.array(
[
0.0,
0.0,
0.0,
rcut - 0.5 * delta,
0.0,
0.0,
]
).reshape([-1, 3])
cc1[:2, :] = np.array(
[
0.0,
0.0,
0.0,
rcut + 0.5 * delta,
0.0,
0.0,
]
).reshape([-1, 3])
return f(cc0.reshape(-1)), f(cc1.reshape(-1))


def check_smooth_efv(sess, energy, force, virial, feed_dict, t_coord, rcut, delta=1e-5):
"""check the smoothness of e, f and v
"""Check the smoothness of e, f and v
the returned values are de, df, dv
de[0] are supposed to be closed to de[1]
df[0] are supposed to be closed to df[1]
dv[0] are supposed to be closed to dv[1]
dv[0] are supposed to be closed to dv[1].
"""
base_dict = feed_dict.copy()
coord0 = base_dict.pop(t_coord)
[fe, ff, fv] = [
lambda coord: sess.run(
ii, feed_dict={**base_dict, t_coord: coord}
).reshape(-1)
for ii in [energy, force, virial]
lambda coord: sess.run(ii, feed_dict={**base_dict, t_coord: coord}).reshape(-1)
for ii in [energy, force, virial]
]
[de, df, dv] = [
check_continuity(ii, coord0, rcut, delta=delta)
for ii in [fe, ff, fv]
check_continuity(ii, coord0, rcut, delta=delta) for ii in [fe, ff, fv]
]
return de, df, dv


def run_dp(cmd: str) -> int:
"""Run DP directly from the entry point instead of the subprocess.
Expand Down
67 changes: 37 additions & 30 deletions source/tests/test_model_se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import numpy as np
from common import (
DataSystem,
check_smooth_efv,
finite_difference_fv,
gen_data,
j_loader,
)
Expand All @@ -28,7 +30,6 @@
from deepmd.utils.type_embed import (
TypeEmbedNet,
)
from common import finite_difference_fv, check_smooth_efv

GLOBAL_ENER_FLOAT_PRECISION = tf.float64
GLOBAL_TF_FLOAT_PRECISION = tf.float64
Expand Down Expand Up @@ -728,10 +729,8 @@ def test_stripped_type_embedding_exclude_types(self):
with self.assertRaises(AssertionError):
np.testing.assert_almost_equal(des[:, 2:6], 0.0, 10)


def test_smoothness_of_stripped_type_embedding_smooth_model(self):
"""test: auto-diff, continuity of e,f,v
"""
"""test: auto-diff, continuity of e,f,v."""
jfile = "water_se_atten.json"
jdata = j_loader(jfile)

Expand Down Expand Up @@ -834,34 +833,42 @@ def test_smoothness_of_stripped_type_embedding_smooth_model(self):
eps = 1e-4
delta = 1e-5
fdf, fdv = finite_difference_fv(
sess, energy, feed_dict_test, t_coord, t_box, delta=eps)
sess, energy, feed_dict_test, t_coord, t_box, delta=eps
)
np.testing.assert_allclose(pf, fdf, delta)
np.testing.assert_allclose(pv, fdv, delta)

tested_eps = [1e-3, 1e-4, 1e-5, 1e-6, 1e-7]
for eps in tested_eps:
deltae = eps
deltad = eps
de, df, dv = check_smooth_efv(
sess, energy, force, virial,
feed_dict_test, t_coord,
jdata["model"]["descriptor"]["rcut"],
delta=eps,
)
np.testing.assert_allclose(de[0], de[1], rtol=0, atol=deltae)
np.testing.assert_allclose(df[0], df[1], rtol=0, atol=deltad)
np.testing.assert_allclose(dv[0], dv[1], rtol=0, atol=deltad)

deltae = eps
deltad = eps
de, df, dv = check_smooth_efv(
sess,
energy,
force,
virial,
feed_dict_test,
t_coord,
jdata["model"]["descriptor"]["rcut"],
delta=eps,
)
np.testing.assert_allclose(de[0], de[1], rtol=0, atol=deltae)
np.testing.assert_allclose(df[0], df[1], rtol=0, atol=deltad)
np.testing.assert_allclose(dv[0], dv[1], rtol=0, atol=deltad)

for eps in tested_eps:
deltae = 5.*eps
deltad = 5.*eps
de, df, dv = check_smooth_efv(
sess, energy, force, virial,
feed_dict_test, t_coord,
jdata["model"]["descriptor"]["rcut_smth"],
delta=eps,
)
np.testing.assert_allclose(de[0], de[1], rtol=0, atol=deltae)
np.testing.assert_allclose(df[0], df[1], rtol=0, atol=deltad)
np.testing.assert_allclose(dv[0], dv[1], rtol=0, atol=deltad)

deltae = 5.0 * eps
deltad = 5.0 * eps
de, df, dv = check_smooth_efv(
sess,
energy,
force,
virial,
feed_dict_test,
t_coord,
jdata["model"]["descriptor"]["rcut_smth"],
delta=eps,
)
np.testing.assert_allclose(de[0], de[1], rtol=0, atol=deltae)
np.testing.assert_allclose(df[0], df[1], rtol=0, atol=deltad)
np.testing.assert_allclose(dv[0], dv[1], rtol=0, atol=deltad)

0 comments on commit 7667a4b

Please sign in to comment.