Skip to content

Commit

Permalink
change optimizer type
Browse files Browse the repository at this point in the history
  • Loading branch information
jla-gardner committed Mar 5, 2024
1 parent 3009a07 commit 9961b16
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 4 deletions.
26 changes: 23 additions & 3 deletions src/graph_pes/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def train_model(
model: T,
train_data: list[AtomicGraph],
val_data: list[AtomicGraph] | None = None,
optimizer: Callable[[], torch.optim.Optimizer | OptimizerLRSchedulerConfig]
optimizer: Callable[[T], torch.optim.Optimizer | OptimizerLRSchedulerConfig]
| None = None,
loss: WeightedLoss | Loss | None = None,
*,
Expand Down Expand Up @@ -97,7 +97,7 @@ def train_model(
if optimizer is None:
opt = torch.optim.Adam(model.parameters(), lr=3e-4)
else:
opt = optimizer()
opt = optimizer(model)

# create the task (a pytorch lightning module)
task = LearnThePES(model, opt, total_loss)
Expand Down Expand Up @@ -237,7 +237,7 @@ def process_loss(
return WeightedLoss([loss], [1.0])

default_transforms = {
keys.ENERGY: PerAtomStandardScaler(), # TODO is this right?
keys.ENERGY: PerAtomStandardScaler(),
keys.FORCES: PerAtomScale(),
keys.STRESS: Scale(),
}
Expand Down Expand Up @@ -292,3 +292,23 @@ def device_info_filter(record):
logging.getLogger("pytorch_lightning.utilities.rank_zero").addFilter(
device_info_filter
)


def Adam(
lr: float = 3e-4, weight_decay: float = 0.0
) -> Callable[[GraphPESModel], torch.optim.Optimizer]:
return lambda model: torch.optim.Adam(
model.parameters(),
lr=lr,
weight_decay=weight_decay,
)


def SGD(
lr: float = 3e-4, weight_decay: float = 0.0
) -> Callable[[GraphPESModel], torch.optim.Optimizer]:
return lambda model: torch.optim.SGD(
model.parameters(),
lr=lr,
weight_decay=weight_decay,
)
2 changes: 1 addition & 1 deletion tests/test_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
def test_calc():
calc = GraphPESCalculator(LennardJones(), cutoff=5)
ethanol = molecule("CH3CH2OH")
ethanol.set_calculator(calc)
ethanol.calc = calc

assert ethanol.get_potential_energy().shape == ()
assert ethanol.get_forces().shape == (9, 3)

0 comments on commit 9961b16

Please sign in to comment.