Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: nequip model in batch #28

Merged
merged 1 commit into from
Oct 7, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
200 changes: 100 additions & 100 deletions deepmd_gnn/nequip.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,6 @@
build_neighbor_list,
extend_input_and_build_neighbor_list,
)
from deepmd.pt.utils.region import (
phys2inter,
)
from deepmd.pt.utils.stat import (
compute_output_stats,
)
Expand Down Expand Up @@ -487,111 +484,114 @@ def forward_lower_common(
extended_atype = extended_atype.to(torch.int64)
nall = extended_coord.shape[1]

# loop on nf
energies = []
forces = []
virials = []
atom_energies = []
atomic_virials = []
for ff in range(nf):
extended_coord_ff = extended_coord[ff]
extended_atype_ff = extended_atype[ff]
nlist_ff = nlist[ff]
edge_index = torch.ops.deepmd_gnn.edge_index(
nlist_ff,
extended_atype_ff,
torch.tensor(self.mm_types, dtype=torch.int64, device="cpu"),
)
edge_index = edge_index.T
            # Nequip and MACE have different definitions for edge_index
edge_index = edge_index[[1, 0]]

# nequip can convert dtype by itself
default_dtype = torch.float64
extended_coord_ff = extended_coord_ff.to(default_dtype)
extended_coord_ff.requires_grad_(True) # noqa: FBT003

input_dict = {
"pos": extended_coord_ff,
"edge_index": edge_index,
"atom_types": extended_atype_ff,
}
if box is not None and mapping is not None:
# pass box, map edge index to real
box_ff = box[ff].to(extended_coord_ff.device)
input_dict["cell"] = box_ff
input_dict["pbc"] = torch.zeros(
3,
dtype=torch.bool,
device=box_ff.device,
)
shifts_atoms = extended_coord_ff - extended_coord_ff[mapping[ff]]
shifts = shifts_atoms[edge_index[1]] - shifts_atoms[edge_index[0]]
edge_index = mapping[ff][edge_index]
input_dict["edge_index"] = edge_index
edge_cell_shift = phys2inter(shifts, box_ff.view(3, 3))
input_dict["edge_cell_shift"] = edge_cell_shift

ret = self.model.forward(
input_dict,
)

atom_energy = ret["atomic_energy"]
if atom_energy is None:
msg = "atom_energy is None"
raise ValueError(msg)
atom_energy = atom_energy.view(1, nall).to(extended_coord_.dtype)[:, :nloc]
# adds e0
atom_energy = atom_energy + self.e0[extended_atype_ff[:nloc]].view(
1,
nloc,
).to(
atom_energy.dtype,
# fake as one frame
extended_coord_ff = extended_coord.view(nf * nall, 3)
extended_atype_ff = extended_atype.view(nf * nall)
edge_index = torch.ops.deepmd_gnn.edge_index(
nlist,
extended_atype,
torch.tensor(self.mm_types, dtype=torch.int64, device="cpu"),
)
edge_index = edge_index.T
        # Nequip and MACE have different definitions for edge_index
edge_index = edge_index[[1, 0]]

# nequip can convert dtype by itself
default_dtype = torch.float64
extended_coord_ff = extended_coord_ff.to(default_dtype)
extended_coord_ff.requires_grad_(True) # noqa: FBT003

input_dict = {
"pos": extended_coord_ff,
"edge_index": edge_index,
"atom_types": extended_atype_ff,
}
if box is not None and mapping is not None:
# pass box, map edge index to real
box_ff = box.to(extended_coord_ff.device)
input_dict["cell"] = box_ff
input_dict["pbc"] = torch.zeros(
3,
dtype=torch.bool,
device=box_ff.device,
)
energy = torch.sum(atom_energy, dim=1).view(1, 1).to(extended_coord_.dtype)
grad_outputs: list[Optional[torch.Tensor]] = [
torch.ones_like(energy),
]
force = torch.autograd.grad(
outputs=[energy],
inputs=[extended_coord_ff],
grad_outputs=grad_outputs,
retain_graph=True,
create_graph=self.training,
)[0]
if force is None:
msg = "force is None"
raise ValueError(msg)
force = -force
atomic_virial = force.unsqueeze(-1).to(
extended_coord_.dtype,
) @ extended_coord_ff.unsqueeze(-2).to(
extended_coord_.dtype,
batch = torch.arange(nf, device=box_ff.device).repeat(nall)
input_dict["batch"] = batch
ptr = torch.arange(
start=0,
end=nf * nall + 1,
step=nall,
dtype=torch.int64,
device=batch.device,
)
force = force.view(1, nall, 3).to(extended_coord_.dtype)
virial = (
torch.sum(atomic_virial, dim=0).view(1, 9).to(extended_coord_.dtype)
input_dict["ptr"] = ptr
mapping_ff = mapping.view(nf * nall) + torch.arange(
0,
nf * nall,
nall,
dtype=mapping.dtype,
device=mapping.device,
).unsqueeze(-1).expand(nf, nall).reshape(-1)
shifts_atoms = extended_coord_ff - extended_coord_ff[mapping_ff]
shifts = shifts_atoms[edge_index[1]] - shifts_atoms[edge_index[0]]
edge_index = mapping_ff[edge_index]
input_dict["edge_index"] = edge_index
rec_cell, _ = torch.linalg.inv_ex(box.view(nf, 3, 3))
edge_cell_shift = torch.einsum(
"ni,nij->nj",
shifts,
rec_cell[batch[edge_index[0]]],
)
input_dict["edge_cell_shift"] = edge_cell_shift

ret = self.model.forward(
input_dict,
)

energies.append(energy)
forces.append(force)
virials.append(virial)
atom_energies.append(atom_energy)
atomic_virials.append(atomic_virial)
energies_t = torch.cat(energies, dim=0)
forces_t = torch.cat(forces, dim=0)
virials_t = torch.cat(virials, dim=0)
atom_energies_t = torch.cat(atom_energies, dim=0)
atomic_virials_t = torch.cat(atomic_virials, dim=0)
atom_energy = ret["atomic_energy"]
if atom_energy is None:
msg = "atom_energy is None"
raise ValueError(msg)
atom_energy = atom_energy.view(nf, nall).to(extended_coord_.dtype)[:, :nloc]
# adds e0
atom_energy = atom_energy + self.e0[extended_atype[:, :nloc]].view(
nf,
nloc,
).to(
atom_energy.dtype,
)
energy = torch.sum(atom_energy, dim=1).view(nf, 1).to(extended_coord_.dtype)
grad_outputs: list[Optional[torch.Tensor]] = [
torch.ones_like(energy),
]
force = torch.autograd.grad(
outputs=[energy],
inputs=[extended_coord_ff],
grad_outputs=grad_outputs,
retain_graph=True,
create_graph=self.training,
)[0]
if force is None:
msg = "force is None"
raise ValueError(msg)
force = -force
atomic_virial = force.unsqueeze(-1).to(
extended_coord_.dtype,
) @ extended_coord_ff.unsqueeze(-2).to(
extended_coord_.dtype,
)
force = force.view(nf, nall, 3).to(extended_coord_.dtype)
atomic_virial = atomic_virial.view(nf, nall, 1, 9)
virial = torch.sum(atomic_virial, dim=1).view(nf, 9).to(extended_coord_.dtype)

return {
"energy_redu": energies_t.view(nf, 1),
"energy_derv_r": forces_t.view(nf, nall, 1, 3),
"energy_derv_c_redu": virials_t.view(nf, 1, 9),
"energy_redu": energy.view(nf, 1),
"energy_derv_r": force.view(nf, nall, 1, 3),
"energy_derv_c_redu": virial.view(nf, 1, 9),
# take the first nloc atoms to match other models
"energy": atom_energies_t.view(nf, nloc, 1),
"energy": atom_energy.view(nf, nloc, 1),
# fake atom_virial
"energy_derv_c": atomic_virials_t.view(nf, nall, 1, 9),
"energy_derv_c": atomic_virial.view(nf, nall, 1, 9),
}

def serialize(self) -> dict:
Expand Down