Remove unused dtype and device parameters from potential classes
E-Rum committed Feb 3, 2025
1 parent ddff22b commit a680e9b
Showing 6 changed files with 15 additions and 56 deletions.
15 changes: 0 additions & 15 deletions src/torchpme/_utils.py
@@ -3,21 +3,6 @@
 import torch
 
 
-def _get_dtype(dtype: Optional[torch.dtype]) -> torch.dtype:
-    return torch.get_default_dtype() if dtype is None else dtype
-
-
-def _get_device(device: Union[None, str, torch.device]) -> torch.device:
-    new_device = torch.get_default_device() if device is None else torch.device(device)
-
-    # Add default index of 0 to a cuda device to avoid errors when comparing with
-    # devices from tensors
-    if new_device.type == "cuda" and new_device.index is None:
-        new_device = torch.device("cuda:0")
-
-    return new_device
-
-
 def _validate_parameters(
     charges: torch.Tensor,
     cell: torch.Tensor,
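The deleted `_get_device` helper existed only to work around a device-comparison quirk. A minimal sketch of that quirk (plain PyTorch, assuming a CUDA build; not torchpme code):

```python
import torch

# A bare "cuda" device carries index None, so it does not compare equal to
# the fully indexed "cuda:0" device that tensors report. This is the quirk
# the removed _get_device helper normalized away.
if torch.cuda.is_available():
    t = torch.zeros(1, device="cuda")
    print(t.device == torch.device("cuda"))    # False: t.device is cuda:0
    print(t.device == torch.device("cuda:0"))  # True
```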
8 changes: 1 addition & 7 deletions src/torchpme/potentials/combined.py
@@ -38,14 +38,10 @@ def __init__(
         learnable_weights: Optional[bool] = True,
         smearing: Optional[float] = None,
         exclusion_radius: Optional[float] = None,
-        dtype: Optional[torch.dtype] = None,
-        device: Union[None, str, torch.device] = None,
     ):
         super().__init__(
             smearing=smearing,
             exclusion_radius=exclusion_radius,
-            dtype=dtype,
-            device=device,
         )
 
         smearings = [pot.smearing for pot in potentials]
@@ -73,9 +69,7 @@ def __init__(
                 "The number of initial weights must match the number of potentials being combined"
             )
         else:
-            initial_weights = torch.ones(
-                len(potentials), dtype=self.dtype, device=self.device
-            )
+            initial_weights = torch.ones(len(potentials))
         # for torchscript
         self.potentials = torch.nn.ModuleList(potentials)
         if learnable_weights:
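With `dtype` and `device` gone from the constructor, the combination weights are created on PyTorch's defaults and relocated afterwards. A minimal sketch of that pattern (generic `torch.nn` code, not the torchpme API):

```python
import torch

class Weighted(torch.nn.Module):
    """Toy module with learnable per-potential weights."""

    def __init__(self, n_potentials: int):
        super().__init__()
        # Created with the default dtype on CPU, like initial_weights above;
        # Module.to() converts it together with every other parameter/buffer.
        self.weights = torch.nn.Parameter(torch.ones(n_potentials))

model = Weighted(3).to(dtype=torch.float64)  # or .to("cuda") when available
print(model.weights.dtype)                   # torch.float64
```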
8 changes: 3 additions & 5 deletions src/torchpme/potentials/coulomb.py
@@ -34,20 +34,18 @@ def __init__(
         self,
         smearing: Optional[float] = None,
         exclusion_radius: Optional[float] = None,
-        dtype: Optional[torch.dtype] = None,
-        device: Union[None, str, torch.device] = None,
     ):
-        super().__init__(smearing, exclusion_radius, dtype, device)
+        super().__init__(smearing, exclusion_radius)
 
         # constants used in the forward
         self.register_buffer(
             "_rsqrt2",
-            torch.rsqrt(torch.tensor(2.0, dtype=self.dtype, device=self.device)),
+            torch.rsqrt(torch.tensor(2.0)),
         )
         self.register_buffer(
             "_sqrt_2_on_pi",
             torch.sqrt(
-                torch.tensor(2.0 / torch.pi, dtype=self.dtype, device=self.device)
+                torch.tensor(2.0 / torch.pi)
             ),
         )
 
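The two buffers are plain scalar constants that later enter the smeared-Coulomb expressions. A hedged sketch of where a constant like 1/√2 typically shows up (the library's exact formula may differ; `smearing` is a placeholder value):

```python
import torch

smearing = 1.0                               # placeholder sigma
rsqrt2 = torch.rsqrt(torch.tensor(2.0))      # the _rsqrt2 buffer above
r = torch.linspace(0.5, 5.0, 10)
# Short-range part of a Gaussian-smeared Coulomb pair potential:
# erfc(r / (sqrt(2) * sigma)) / r
v_sr = torch.erfc(rsqrt2 * r / smearing) / r
```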
10 changes: 3 additions & 7 deletions src/torchpme/potentials/inversepowerlaw.py
@@ -31,24 +31,20 @@ class InversePowerLawPotential(Potential):
     :param exclusion_radius: float or torch.Tensor containing the length scale
         corresponding to a local environment. See also
         :class:`Potential`.
-    :param dtype: type used for the internal buffers and parameters
-    :param device: device used for the internal buffers and parameters
     """
 
     def __init__(
         self,
         exponent: int,
         smearing: Optional[float] = None,
         exclusion_radius: Optional[float] = None,
-        dtype: Optional[torch.dtype] = None,
-        device: Union[None, str, torch.device] = None,
     ):
-        super().__init__(smearing, exclusion_radius, dtype, device)
+        super().__init__(smearing, exclusion_radius)
 
         # function call to check the validity of the exponent
-        gammaincc_over_powerlaw(exponent, torch.tensor(1.0, dtype=dtype, device=device))
+        gammaincc_over_powerlaw(exponent, torch.tensor(1.0))
         self.register_buffer(
-            "exponent", torch.tensor(exponent, dtype=self.dtype, device=self.device)
+            "exponent", torch.tensor(exponent)
         )
 
     @torch.jit.export
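A minimal sketch of the buffer-based pattern this class now follows (toy code, not the torchpme implementation): because the exponent is a registered buffer, it travels with the module, so no constructor-level device handling is needed.

```python
import torch

class IPL(torch.nn.Module):
    """Toy inverse-power-law pair potential v(r) = 1 / r**p."""

    def __init__(self, exponent: int):
        super().__init__()
        # Stored as a buffer, so it follows Module.to() moves.
        self.register_buffer("exponent", torch.tensor(exponent))

    def forward(self, r: torch.Tensor) -> torch.Tensor:
        return r ** (-self.exponent)

ipl = IPL(3)
print(ipl(torch.tensor([1.0, 2.0])))  # tensor([1.0000, 0.1250])
```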
14 changes: 2 additions & 12 deletions src/torchpme/potentials/potential.py
@@ -2,9 +2,6 @@
 
 import torch
 
-from .._utils import _get_device, _get_dtype
-
-
 class Potential(torch.nn.Module):
     r"""
     Base class defining the interface for a pair potential energy function
@@ -32,32 +29,25 @@ class Potential(torch.nn.Module):
     :param exclusion_radius: A length scale that defines a *local environment* within
         which the potential should be smoothly zeroed out, as it will be described by a
         separate model.
-    :param dtype: type used for the internal buffers and parameters
-    :param device: device used for the internal buffers and parameters
     """
 
     def __init__(
         self,
         smearing: Optional[float] = None,
         exclusion_radius: Optional[float] = None,
-        dtype: Optional[torch.dtype] = None,
-        device: Union[None, str, torch.device] = None,
     ):
         super().__init__()
 
-        self.device = _get_device(device)
-        self.dtype = _get_dtype(dtype)
-
         if smearing is not None:
             self.register_buffer(
-                "smearing", torch.tensor(smearing, device=self.device, dtype=self.dtype)
+                "smearing", torch.tensor(smearing)
             )
         else:
             self.smearing = None
         if exclusion_radius is not None:
             self.register_buffer(
                 "exclusion_radius",
-                torch.tensor(exclusion_radius, device=self.device, dtype=self.dtype),
+                torch.tensor(exclusion_radius),
             )
         else:
             self.exclusion_radius = None
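The base class now leaves dtype and device to standard `torch.nn.Module` semantics: buffers created on the defaults are converted by `.to()` and serialized with the module. A short sketch of that behavior (toy code):

```python
import torch

class Toy(torch.nn.Module):
    def __init__(self, smearing: float):
        super().__init__()
        self.register_buffer("smearing", torch.tensor(smearing))

toy = Toy(0.5)
print(toy.smearing.dtype)              # torch.float32 (default dtype)
toy = toy.to(torch.float64)
print(toy.smearing.dtype)              # torch.float64
print("smearing" in toy.state_dict())  # True: buffers are serialized
```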
16 changes: 6 additions & 10 deletions src/torchpme/potentials/spline.py
@@ -57,21 +57,17 @@ def __init__(
         yhat_at_zero: Optional[float] = None,
         smearing: Optional[float] = None,
         exclusion_radius: Optional[float] = None,
-        dtype: Optional[torch.dtype] = None,
-        device: Union[None, str, torch.device] = None,
     ):
         super().__init__(
             smearing=smearing,
             exclusion_radius=exclusion_radius,
-            dtype=dtype,
-            device=device,
         )
 
         if len(y_grid) != len(r_grid):
             raise ValueError("Length of radial grid and value array mismatch.")
 
-        r_grid = r_grid.to(dtype=self.dtype, device=self.device)
-        y_grid = y_grid.to(dtype=self.dtype, device=self.device)
+        self.register_buffer("r_grid", r_grid)
+        self.register_buffer("y_grid", y_grid)
 
         if reciprocal:
             if torch.min(r_grid) <= 0.0:
@@ -87,9 +83,9 @@ def __init__(
             if reciprocal:
                 k_grid = torch.pi * 2 * torch.reciprocal(r_grid).flip(dims=[0])
             else:
-                k_grid = r_grid.clone()
+                k_grid = r_grid.clone().detach()
         else:
-            k_grid = k_grid.to(dtype=self.dtype, device=self.device)
+            self.register_buffer("k_grid", k_grid)
 
         if yhat_grid is None:
             # computes automatically!
@@ -100,7 +96,7 @@ def __init__(
                 compute_second_derivatives(r_grid, y_grid),
             )
         else:
-            yhat_grid = yhat_grid.to(dtype=self.dtype, device=self.device)
+            self.register_buffer("yhat_grid", yhat_grid)
 
         # the function is defined for k**2, so we define the grid accordingly
         if reciprocal:
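The switch from `r_grid.clone()` to `r_grid.clone().detach()` matters when the caller's grid participates in autograd: `clone()` stays attached to the computation graph, while `detach()` yields a tensor that is safe to store as a buffer. A minimal illustration (plain PyTorch):

```python
import torch

r_grid = torch.linspace(0.1, 5.0, 8, requires_grad=True)

attached = r_grid.clone()       # still tracked by autograd
print(attached.requires_grad)   # True

safe = r_grid.clone().detach()  # cut from the graph
print(safe.requires_grad)       # False
```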
