Skip to content

Commit

Permalink
fix skipped tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Han Wang committed Feb 6, 2024
1 parent 6c12380 commit cefc181
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 86 deletions.
15 changes: 3 additions & 12 deletions source/tests/pt/model/test_env_mat.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,9 @@
import numpy as np
import torch

try:
from deepmd.dpmodel import (
EnvMat,
)

support_env_mat = True
except ModuleNotFoundError:
support_env_mat = False
except ImportError:
support_env_mat = False

from deepmd.dpmodel.utils import (
EnvMat,
)
from deepmd.pt.model.descriptor.env_mat import (
prod_env_mat_se_a,
)
Expand Down Expand Up @@ -77,7 +69,6 @@ def setUp(self):


# to be merged with the tf test case
@unittest.skipIf(not support_env_mat, "EnvMat not supported")
class TestEnvMat(unittest.TestCase, TestCaseSingleFrameWithNlist):
def setUp(self):
TestCaseSingleFrameWithNlist.setUp(self)
Expand Down
75 changes: 12 additions & 63 deletions source/tests/pt/model/test_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,72 +5,25 @@
import numpy as np
import torch

from deepmd.dpmodel.utils import EmbeddingNet as DPEmbeddingNet
from deepmd.dpmodel.utils import FittingNet as DPFittingNet
from deepmd.dpmodel.utils import (
NativeLayer,
NativeNet,
)
from deepmd.pt.model.network.mlp import (
MLP,
EmbeddingNet,
FittingNet,
MLPLayer,
)
from deepmd.pt.utils import (
env,
)
from deepmd.pt.utils.env import (
PRECISION_DICT,
)

try:
from deepmd.pt.model.network.mlp import (
MLP,
MLPLayer,
)

support_native_net = True
except ModuleNotFoundError:
support_native_net = False

try:
from deepmd.pt.model.network.mlp import (
EmbeddingNet,
)

support_embedding_net = True
except ModuleNotFoundError:
support_embedding_net = False

try:
from deepmd.pt.model.network.mlp import (
FittingNet,
)

support_fitting_net = True
except ModuleNotFoundError:
support_fitting_net = False


try:
from deepmd.dpmodel import (
NativeLayer,
NativeNet,
)

support_native_net = True
except ModuleNotFoundError:
support_native_net = False
except ImportError:
support_native_net = False

try:
from deepmd.dpmodel import EmbeddingNet as DPEmbeddingNet

support_embedding_net = True
except ModuleNotFoundError:
support_embedding_net = False
except ImportError:
support_embedding_net = False

try:
from deepmd.dpmodel import FittingNet as DPFittingNet

support_fitting_net = True
except ModuleNotFoundError:
support_fitting_net = False
except ImportError:
support_fitting_net = False


def get_tols(prec):
if prec in ["single", "float32"]:
Expand All @@ -84,7 +37,6 @@ def get_tols(prec):
return rtol, atol


@unittest.skipIf(not support_native_net, "NativeLayer not supported")
class TestMLPLayer(unittest.TestCase):
def setUp(self):
self.test_cases = itertools.product(
Expand Down Expand Up @@ -141,7 +93,6 @@ def test_jit(self):
model = torch.jit.script(ml1)


@unittest.skipIf(not support_native_net, "NativeLayer not supported")
class TestMLP(unittest.TestCase):
def setUp(self):
self.test_cases = itertools.product(
Expand Down Expand Up @@ -210,7 +161,6 @@ def test_jit(self):
model = torch.jit.script(ml1)


@unittest.skipIf(not support_embedding_net, "EmbeddingNet not supported")
class TestEmbeddingNet(unittest.TestCase):
def setUp(self):
self.test_cases = itertools.product(
Expand Down Expand Up @@ -261,7 +211,6 @@ def test_jit(
model = torch.jit.script(ml1)


@unittest.skipIf(not support_fitting_net, "FittingNet not supported")
class TestFittingNet(unittest.TestCase):
def setUp(self):
self.test_cases = itertools.product(
Expand Down
12 changes: 1 addition & 11 deletions source/tests/pt/model/test_se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,7 @@
import numpy as np
import torch

try:
# from deepmd.dpmodel import PRECISION_DICT as DP_PRECISION_DICT
from deepmd.dpmodel import DescrptSeA as DPDescrptSeA

support_se_e2_a = True
except ModuleNotFoundError:
support_se_e2_a = False
except ImportError:
support_se_e2_a = False

from deepmd.dpmodel.descriptor import DescrptSeA as DPDescrptSeA
from deepmd.pt.model.descriptor.se_a import (
DescrptSeA,
)
Expand All @@ -36,7 +27,6 @@


# to be merged with the tf test case
@unittest.skipIf(not support_se_e2_a, "EnvMat not supported")
class TestDescrptSeA(unittest.TestCase, TestCaseSingleFrameWithNlist):
def setUp(self):
TestCaseSingleFrameWithNlist.setUp(self)
Expand Down

0 comments on commit cefc181

Please sign in to comment.