Skip to content

Commit

Permalink
fix formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
ilyes319 committed Jan 13, 2025
1 parent 14be2ae commit 40c0ec8
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 42 deletions.
7 changes: 6 additions & 1 deletion mace/cli/select_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,12 @@ def main():
else:

if args.output_file is None:
args.output_file = args.model_file + "." + args.head_name + ("." + args.target_device if (args.target_device is not None) else "")
args.output_file = (
args.model_file
+ "."
+ args.head_name
+ ("." + args.target_device if (args.target_device is not None) else "")
)

model_single = remove_pt_head(model, args.head_name)
if args.target_device is not None:
Expand Down
2 changes: 1 addition & 1 deletion mace/modules/irreps_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def forward(self, tensor: torch.Tensor) -> torch.Tensor:
out.append(field)

if hasattr(self, "cueq_config"):
if self.cueq_config is not None:
if self.cueq_config is not None: # pylint: disable=no-else-return
if self.cueq_config.layout_str == "mul_ir":
return torch.cat(out, dim=-1)
return torch.cat(out, dim=-2)
Expand Down
22 changes: 17 additions & 5 deletions tests/modules/test_radial.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import pytest
import torch
from mace.modules.radial import ZBLBasis, AgnesiTransform

from mace.modules.radial import AgnesiTransform, ZBLBasis


@pytest.fixture
def zbl_basis():
return ZBLBasis(p=6, trainable=False)


def test_zbl_basis_initialization(zbl_basis):
assert zbl_basis.p == torch.tensor(6.0)
assert torch.allclose(zbl_basis.c, torch.tensor([0.1818, 0.5099, 0.2802, 0.02817]))
Expand All @@ -15,6 +18,7 @@ def test_zbl_basis_initialization(zbl_basis):
assert not zbl_basis.a_exp.requires_grad
assert not zbl_basis.a_prefactor.requires_grad


def test_trainable_zbl_basis_initialization(zbl_basis):
zbl_basis = ZBLBasis(p=6, trainable=True)
assert zbl_basis.p == torch.tensor(6.0)
Expand All @@ -25,9 +29,12 @@ def test_trainable_zbl_basis_initialization(zbl_basis):
assert zbl_basis.a_exp.requires_grad
assert zbl_basis.a_prefactor.requires_grad


def test_forward(zbl_basis):
x = torch.tensor([1.0, 1.0, 2.0]).unsqueeze(-1) # [n_edges]
node_attrs = torch.tensor([[1, 0], [0, 1]]) # [n_nodes, n_node_features] - one_hot encoding of atomic numbers
node_attrs = torch.tensor(
[[1, 0], [0, 1]]
) # [n_nodes, n_node_features] - one_hot encoding of atomic numbers
edge_index = torch.tensor([[0, 1, 1], [1, 0, 1]]) # [2, n_edges]
atomic_numbers = torch.tensor([1, 6]) # [n_nodes]
output = zbl_basis(x, node_attrs, edge_index, atomic_numbers)
Expand All @@ -37,13 +44,15 @@ def test_forward(zbl_basis):
assert torch.allclose(
output,
torch.tensor([0.0031, 0.0031], dtype=torch.get_default_dtype()),
rtol=1e-2
rtol=1e-2,
)


@pytest.fixture
def agnesi():
return AgnesiTransform(trainable=False)


def test_agnesi_transform_initialization(agnesi: AgnesiTransform):
assert agnesi.q.item() == pytest.approx(0.9183, rel=1e-4)
assert agnesi.p.item() == pytest.approx(4.5791, rel=1e-4)
Expand All @@ -52,6 +61,7 @@ def test_agnesi_transform_initialization(agnesi: AgnesiTransform):
assert not agnesi.q.requires_grad
assert not agnesi.p.requires_grad


def test_trainable_agnesi_transform_initialization():
agnesi = AgnesiTransform(trainable=True)

Expand All @@ -62,6 +72,7 @@ def test_trainable_agnesi_transform_initialization():
assert agnesi.q.requires_grad
assert agnesi.p.requires_grad


def test_agnesi_transform_forward():
agnesi = AgnesiTransform()
x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.get_default_dtype()).unsqueeze(-1)
Expand All @@ -76,8 +87,9 @@ def test_agnesi_transform_forward():
torch.tensor(
[0.3646, 0.2175, 0.2089], dtype=torch.get_default_dtype()
).unsqueeze(-1),
rtol=1e-2
rtol=1e-2,
)


if __name__ == "__main__":
pytest.main([__file__])
pytest.main([__file__])
69 changes: 34 additions & 35 deletions tests/test_benchmark.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import json
import os
from typing import Optional
from pathlib import Path
from typing import List, Optional

import pandas as pd
import json
import pytest
import torch
from ase import build

from mace import data
from mace import data as mace_data
from mace.calculators.foundations_models import mace_mp
from mace.tools import AtomicNumberTable, torch_geometric, torch_tools

Expand Down Expand Up @@ -57,8 +58,8 @@ def create_batch(size: int, model: torch.nn.Module, device: str) -> dict:
z_table = AtomicNumberTable([int(z) for z in model.atomic_numbers])
atoms = build.bulk("C", "diamond", a=3.567, cubic=True)
atoms = atoms.repeat((size, size, size))
config = data.config_from_atoms(atoms)
dataset = [data.AtomicData.from_config(config, z_table=z_table, cutoff=cutoff)]
config = mace_data.config_from_atoms(atoms)
dataset = [mace_data.AtomicData.from_config(config, z_table=z_table, cutoff=cutoff)]
data_loader = torch_geometric.dataloader.DataLoader(
dataset=dataset,
batch_size=1,
Expand All @@ -78,45 +79,43 @@ def log_bench_info(benchmark, dtype, compile_mode, batch):
benchmark.extra_info["device_name"] = torch.cuda.get_device_name()


def read_bench_results(files: list[str]) -> pd.DataFrame:
def read(file):
with open(file, "r") as f:
data = json.load(f)
def process_benchmark_file(bench_file: Path) -> pd.DataFrame:
with open(bench_file, "r", encoding="utf-8") as f:
bench_data = json.load(f)

records = []
for bench in data["benchmarks"]:
record = {**bench["extra_info"], **bench["stats"]}
records.append(record)
records = []
for bench in bench_data["benchmarks"]:
record = {**bench["extra_info"], **bench["stats"]}
records.append(record)

df = pd.DataFrame(records)
df["ns/day (1 fs/step)"] = 0.086400 / df["median"]
df["Steps per day"] = df["ops"] * 86400
columns = [
"num_atoms",
"num_edges",
"dtype",
"is_compiled",
"device_name",
"median",
"Steps per day",
"ns/day (1 fs/step)",
]
return df[columns]
result_df = pd.DataFrame(records)
result_df["ns/day (1 fs/step)"] = 0.086400 / result_df["median"]
result_df["Steps per day"] = result_df["ops"] * 86400
columns = [
"num_atoms",
"num_edges",
"dtype",
"is_compiled",
"device_name",
"median",
"Steps per day",
"ns/day (1 fs/step)",
]
return result_df[columns]

return pd.concat([read(f) for f in files])

def read_bench_results(result_files: List[str]) -> pd.DataFrame:
return pd.concat([process_benchmark_file(Path(f)) for f in result_files])


if __name__ == "__main__":
# Print to stdout a csv of the benchmark metrics
import subprocess

result = subprocess.run(
["pytest-benchmark", "list"], capture_output=True, text=True
["pytest-benchmark", "list"], capture_output=True, text=True, check=True
)

if result.returncode != 0:
raise RuntimeError(f"Command failed with return code {result.returncode}")

files = result.stdout.strip().split("\n")
df = read_bench_results(files)
print(df.to_csv(index=False))
bench_files = result.stdout.strip().split("\n")
bench_results = read_bench_results(bench_files)
print(bench_results.to_csv(index=False))

0 comments on commit 40c0ec8

Please sign in to comment.