From 8206e0245cc45e6ff1dc803b9e3c4ec00b9145c4 Mon Sep 17 00:00:00 2001 From: AlexandraVolokhova Date: Mon, 13 Nov 2023 18:14:05 -0500 Subject: [PATCH] black, isort again --- gflownet/envs/base.py | 3 +-- gflownet/envs/crystals/composition.py | 7 +++++-- gflownet/envs/crystals/lattice_parameters.py | 14 ++++++++++---- gflownet/envs/ctorus.py | 3 +-- gflownet/gflownet.py | 12 +++++++++--- gflownet/utils/batch.py | 12 ++++++++++-- .../utils/crystals/build_lattice_dicts.py | 19 +++++++++++++------ gflownet/utils/molecule/geom.py | 6 ++++-- gflownet/utils/oracle.py | 3 +-- playground/botorch/mes_exact_deepKernel.py | 1 + playground/botorch/mes_gp.py | 1 + playground/botorch/mes_gp_debug.py | 7 +++---- playground/botorch/mes_nn_bao_fix.py | 5 +++-- playground/botorch/mes_nn_hardcode_gpVal.py | 4 ++-- playground/botorch/mes_nn_like_gp.py | 5 +++-- .../mes_nn_like_gp_nondiagonalcovar.py | 5 +++-- playground/botorch/mes_var_deepKernel.py | 5 ++++- scripts/conformer/geom_stats.py | 14 ++++++++++---- scripts/dav_mp20_stats.py | 3 +-- scripts/pyxtal/pyxtal_vs_pymatgen.py | 8 ++++++-- .../gflownet/envs/test_lattice_parameters.py | 18 +++++++++++------- tests/gflownet/envs/test_tree.py | 11 +++++++++-- .../policy/test_multihead_tree_policy.py | 11 +++++++---- .../utils/molecule/test_rotatable_bonds.py | 6 ++++-- .../gflownet/utils/molecule/test_torsions.py | 6 ++---- tests/gflownet/utils/test_batch.py | 13 ++++++++++--- 26 files changed, 134 insertions(+), 68 deletions(-) diff --git a/gflownet/envs/base.py b/gflownet/envs/base.py index b7f59128f..e0381ed37 100644 --- a/gflownet/envs/base.py +++ b/gflownet/envs/base.py @@ -15,8 +15,7 @@ from torch.distributions import Categorical from torchtyping import TensorType -from gflownet.utils.common import (copy, set_device, set_float_precision, - tbool, tfloat) +from gflownet.utils.common import copy, set_device, set_float_precision, tbool, tfloat CMAP = mpl.colormaps["cividis"] diff --git a/gflownet/envs/crystals/composition.py b/gflownet/envs/crystals/composition.py index 721f702b7..2e1e75240 100644 --- a/gflownet/envs/crystals/composition.py +++ b/gflownet/envs/crystals/composition.py @@ -14,8 +14,11 @@ from gflownet.utils.common import tlong from gflownet.utils.crystals.constants import ELEMENT_NAMES, OXIDATION_STATES from gflownet.utils.crystals.pyxtal_cache import ( - get_space_group, space_group_check_compatible, - space_group_lowest_free_wp_multiplicity, space_group_wyckoff_gcd) + get_space_group, + space_group_check_compatible, + space_group_lowest_free_wp_multiplicity, + space_group_wyckoff_gcd, +) class Composition(GFlowNetEnv): diff --git a/gflownet/envs/crystals/lattice_parameters.py b/gflownet/envs/crystals/lattice_parameters.py index e15bfed54..957a7e229 100644 --- a/gflownet/envs/crystals/lattice_parameters.py +++ b/gflownet/envs/crystals/lattice_parameters.py @@ -9,10 +9,16 @@ from torchtyping import TensorType from gflownet.envs.grid import Grid -from gflownet.utils.crystals.constants import (CUBIC, HEXAGONAL, - LATTICE_SYSTEMS, MONOCLINIC, - ORTHORHOMBIC, RHOMBOHEDRAL, - TETRAGONAL, TRICLINIC) +from gflownet.utils.crystals.constants import ( + CUBIC, + HEXAGONAL, + LATTICE_SYSTEMS, + MONOCLINIC, + ORTHORHOMBIC, + RHOMBOHEDRAL, + TETRAGONAL, + TRICLINIC, +) class LatticeParameters(Grid): diff --git a/gflownet/envs/ctorus.py b/gflownet/envs/ctorus.py index 8b6b79d0e..c006428f6 100644 --- a/gflownet/envs/ctorus.py +++ b/gflownet/envs/ctorus.py @@ -9,8 +9,7 @@ import numpy.typing as npt import pandas as pd import torch -from torch.distributions import (Categorical, MixtureSameFamily, Uniform, - VonMises) +from torch.distributions import Categorical, MixtureSameFamily, Uniform, VonMises from torchtyping import TensorType from gflownet.envs.htorus import HybridTorus diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index 64b228acb..4eebc4d6b 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -21,9 +21,15 @@ from gflownet.envs.base import GFlowNetEnv from gflownet.utils.batch import Batch from gflownet.utils.buffer import Buffer -from gflownet.utils.common import (batch_with_rest, set_device, - set_float_precision, tbool, tfloat, tlong, - torch2np) +from gflownet.utils.common import ( + batch_with_rest, + set_device, + set_float_precision, + tbool, + tfloat, + tlong, + torch2np, +) class GFlowNetAgent: diff --git a/gflownet/utils/batch.py b/gflownet/utils/batch.py index c76ecfa9e..a35f01ddf 100644 --- a/gflownet/utils/batch.py +++ b/gflownet/utils/batch.py @@ -7,8 +7,16 @@ from torchtyping import TensorType from gflownet.envs.base import GFlowNetEnv -from gflownet.utils.common import (concat_items, copy, extend, set_device, - set_float_precision, tbool, tfloat, tlong) +from gflownet.utils.common import ( + concat_items, + copy, + extend, + set_device, + set_float_precision, + tbool, + tfloat, + tlong, +) class Batch: diff --git a/gflownet/utils/crystals/build_lattice_dicts.py b/gflownet/utils/crystals/build_lattice_dicts.py index f62c4bd19..65d4a7958 100644 --- a/gflownet/utils/crystals/build_lattice_dicts.py +++ b/gflownet/utils/crystals/build_lattice_dicts.py @@ -8,12 +8,19 @@ import numpy as np import yaml -from lattice_constants import (CRYSTAL_CLASSES_WIKIPEDIA, - CRYSTAL_LATTICE_SYSTEMS, CRYSTAL_SYSTEMS, - POINT_SYMMETRIES, - RHOMBOHEDRAL_SPACE_GROUPS_WIKIPEDIA) -from pymatgen.symmetry.groups import (PointGroup, SpaceGroup, SymmetryGroup, - sg_symbol_from_int_number) +from lattice_constants import ( + CRYSTAL_CLASSES_WIKIPEDIA, + CRYSTAL_LATTICE_SYSTEMS, + CRYSTAL_SYSTEMS, + POINT_SYMMETRIES, + RHOMBOHEDRAL_SPACE_GROUPS_WIKIPEDIA, +) +from pymatgen.symmetry.groups import ( + PointGroup, + SpaceGroup, + SymmetryGroup, + sg_symbol_from_int_number, +) N_SPACE_GROUPS = 230 diff --git a/gflownet/utils/molecule/geom.py b/gflownet/utils/molecule/geom.py index 273f2cdae..e7ac71308 100644 --- a/gflownet/utils/molecule/geom.py +++ b/gflownet/utils/molecule/geom.py @@ -8,8 +8,10 @@ from rdkit import Chem from tqdm import tqdm -from gflownet.utils.molecule.rotatable_bonds import (get_rotatable_ta_list, - is_hydrogen_ta) +from gflownet.utils.molecule.rotatable_bonds import ( + get_rotatable_ta_list, + is_hydrogen_ta, +) def get_conf_geom(base_path, smiles, conf_idx=0, summary_file=None): diff --git a/gflownet/utils/oracle.py b/gflownet/utils/oracle.py index 1fc9c92c6..c4713745a 100644 --- a/gflownet/utils/oracle.py +++ b/gflownet/utils/oracle.py @@ -14,8 +14,7 @@ ) pass try: - from bbdob import (DeceptiveTrap, FourPeaks, NKLandscape, OneMax, TwoMin, - WModel) + from bbdob import DeceptiveTrap, FourPeaks, NKLandscape, OneMax, TwoMin, WModel from bbdob.utils import idx2one_hot except: print( diff --git a/playground/botorch/mes_exact_deepKernel.py b/playground/botorch/mes_exact_deepKernel.py index 743b68256..b77bb2e89 100644 --- a/playground/botorch/mes_exact_deepKernel.py +++ b/playground/botorch/mes_exact_deepKernel.py @@ -8,6 +8,7 @@ from math import floor import gpytorch + # import tqdm import torch from botorch.test_functions import Hartmann diff --git a/playground/botorch/mes_gp.py b/playground/botorch/mes_gp.py index 8afde5dc8..b51df0ce6 100644 --- a/playground/botorch/mes_gp.py +++ b/playground/botorch/mes_gp.py @@ -6,6 +6,7 @@ import numpy as np import torch + # from botorch.fit import fit_gpytorch_mll from botorch.models import SingleTaskGP from botorch.test_functions import Branin, Hartmann diff --git a/playground/botorch/mes_gp_debug.py b/playground/botorch/mes_gp_debug.py index 76af6ff00..06c5a3ed6 100644 --- a/playground/botorch/mes_gp_debug.py +++ b/playground/botorch/mes_gp_debug.py @@ -3,6 +3,7 @@ import gpytorch import numpy as np import torch + # from botorch.fit import fit_gpytorch_mll from botorch.models import SingleTaskGP from botorch.test_functions import Hartmann @@ -49,10 +50,8 @@ def forward(self, x): from botorch.models.utils import add_output_dim from botorch.posteriors.gpytorch import GPyTorchPosterior -from gpytorch.distributions import (MultitaskMultivariateNormal, - MultivariateNormal) -from gpytorch.likelihoods.gaussian_likelihood import \ - FixedNoiseGaussianLikelihood +from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal +from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood class myGPModel(SingleTaskGP): diff --git a/playground/botorch/mes_nn_bao_fix.py b/playground/botorch/mes_nn_bao_fix.py index 861268aeb..c4f7de6d0 100644 --- a/playground/botorch/mes_nn_bao_fix.py +++ b/playground/botorch/mes_nn_bao_fix.py @@ -2,6 +2,7 @@ import numpy as np import torch + # from botorch.fit import fit_gpytorch_mll from botorch.models import SingleTaskGP from botorch.test_functions import Hartmann @@ -55,8 +56,8 @@ from botorch.acquisition.max_value_entropy_search import qMaxValueEntropy from botorch.models.model import Model from botorch.posteriors.gpytorch import GPyTorchPosterior -from gpytorch.distributions import (MultitaskMultivariateNormal, - MultivariateNormal) +from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal + # from botorch.posteriors. from torch.distributions import Normal diff --git a/playground/botorch/mes_nn_hardcode_gpVal.py b/playground/botorch/mes_nn_hardcode_gpVal.py index 42dcbb9b4..6320d4f05 100644 --- a/playground/botorch/mes_nn_hardcode_gpVal.py +++ b/playground/botorch/mes_nn_hardcode_gpVal.py @@ -2,6 +2,7 @@ import numpy as np import torch + # from botorch.fit import fit_gpytorch_mll from botorch.models import SingleTaskGP from botorch.test_functions import Hartmann @@ -56,8 +57,7 @@ from botorch.acquisition.max_value_entropy_search import qMaxValueEntropy from botorch.models.model import Model from botorch.posteriors.gpytorch import GPyTorchPosterior -from gpytorch.distributions import (MultitaskMultivariateNormal, - MultivariateNormal) +from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal class NN_Model(Model): diff --git a/playground/botorch/mes_nn_like_gp.py b/playground/botorch/mes_nn_like_gp.py index 0b15c98be..d0664a342 100644 --- a/playground/botorch/mes_nn_like_gp.py +++ b/playground/botorch/mes_nn_like_gp.py @@ -3,14 +3,15 @@ import numpy as np import torch from botorch.acquisition.max_value_entropy_search import qMaxValueEntropy + # from botorch.fit import fit_gpytorch_mll from botorch.models import SingleTaskGP from botorch.models.model import Model from botorch.posteriors.gpytorch import GPyTorchPosterior from botorch.test_functions import Hartmann -from gpytorch.distributions import (MultitaskMultivariateNormal, - MultivariateNormal) +from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal from gpytorch.mlls import ExactMarginalLogLikelihood + # from botorch.posteriors. from torch import distributions, tensor from torch.nn import Dropout, Linear, MSELoss, ReLU, Sequential diff --git a/playground/botorch/mes_nn_like_gp_nondiagonalcovar.py b/playground/botorch/mes_nn_like_gp_nondiagonalcovar.py index 1d6626b33..2c75fd6a4 100644 --- a/playground/botorch/mes_nn_like_gp_nondiagonalcovar.py +++ b/playground/botorch/mes_nn_like_gp_nondiagonalcovar.py @@ -3,14 +3,15 @@ import numpy as np import torch from botorch.acquisition.max_value_entropy_search import qMaxValueEntropy + # from botorch.fit import fit_gpytorch_mll from botorch.models import SingleTaskGP from botorch.models.model import Model from botorch.posteriors.gpytorch import GPyTorchPosterior from botorch.test_functions import Hartmann -from gpytorch.distributions import (MultitaskMultivariateNormal, - MultivariateNormal) +from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal from gpytorch.mlls import ExactMarginalLogLikelihood + # from botorch.posteriors. from torch import distributions, tensor from torch.nn import Dropout, Linear, MSELoss, ReLU, Sequential diff --git a/playground/botorch/mes_var_deepKernel.py b/playground/botorch/mes_var_deepKernel.py index 989af46c4..f712eaaf0 100644 --- a/playground/botorch/mes_var_deepKernel.py +++ b/playground/botorch/mes_var_deepKernel.py @@ -10,6 +10,7 @@ from math import floor import gpytorch + # import tqdm import torch from botorch.test_functions import Hartmann @@ -214,7 +215,9 @@ def posterior( from botorch.acquisition.max_value_entropy_search import ( - qLowerBoundMaxValueEntropy, qMaxValueEntropy) + qLowerBoundMaxValueEntropy, + qMaxValueEntropy, +) proxy = myGPModel(model, train_x, train_y.unsqueeze(-1)) qMES = qLowerBoundMaxValueEntropy(proxy, candidate_set=train_x, use_gumbel=True) diff --git a/scripts/conformer/geom_stats.py b/scripts/conformer/geom_stats.py index 5c6fa1699..5194b1068 100644 --- a/scripts/conformer/geom_stats.py +++ b/scripts/conformer/geom_stats.py @@ -9,10 +9,16 @@ from rdkit import Chem from tqdm import tqdm -from gflownet.utils.molecule.geom import (all_same_graphs, get_all_confs_geom, - get_conf_geom, get_rd_mol) -from gflownet.utils.molecule.rotatable_bonds import (get_rotatable_ta_list, - has_hydrogen_tas) +from gflownet.utils.molecule.geom import ( + all_same_graphs, + get_all_confs_geom, + get_conf_geom, + get_rd_mol, +) +from gflownet.utils.molecule.rotatable_bonds import ( + get_rotatable_ta_list, + has_hydrogen_tas, +) """ Here we use rdkit_folder format of the GEOM dataset diff --git a/scripts/dav_mp20_stats.py b/scripts/dav_mp20_stats.py index 868c3745e..3df1c78c9 100644 --- a/scripts/dav_mp20_stats.py +++ b/scripts/dav_mp20_stats.py @@ -19,8 +19,7 @@ from collections import Counter -from external.repos.ActiveLearningMaterials.dave.utils.loaders import \ - make_loaders +from external.repos.ActiveLearningMaterials.dave.utils.loaders import make_loaders from gflownet.proxy.crystals.dave import DAVE from gflownet.utils.common import load_gflow_net_from_run_path, resolve_path diff --git a/scripts/pyxtal/pyxtal_vs_pymatgen.py b/scripts/pyxtal/pyxtal_vs_pymatgen.py index 6ffcbaa21..62a226ae7 100644 --- a/scripts/pyxtal/pyxtal_vs_pymatgen.py +++ b/scripts/pyxtal/pyxtal_vs_pymatgen.py @@ -4,8 +4,12 @@ """ from argparse import ArgumentParser -from pymatgen.symmetry.groups import (PointGroup, SpaceGroup, SymmetryGroup, - sg_symbol_from_int_number) +from pymatgen.symmetry.groups import ( + PointGroup, + SpaceGroup, + SymmetryGroup, + sg_symbol_from_int_number, +) from pyxtal.symmetry import Group N_SYMMETRY_GROUPS = 230 diff --git a/tests/gflownet/envs/test_lattice_parameters.py b/tests/gflownet/envs/test_lattice_parameters.py index d74ea29bd..16aea2814 100644 --- a/tests/gflownet/envs/test_lattice_parameters.py +++ b/tests/gflownet/envs/test_lattice_parameters.py @@ -2,13 +2,17 @@ import pytest import torch -from gflownet.envs.crystals.lattice_parameters import (CUBIC, HEXAGONAL, - LATTICE_SYSTEMS, - MONOCLINIC, - ORTHORHOMBIC, - RHOMBOHEDRAL, - TETRAGONAL, TRICLINIC, - LatticeParameters) +from gflownet.envs.crystals.lattice_parameters import ( + CUBIC, + HEXAGONAL, + LATTICE_SYSTEMS, + MONOCLINIC, + ORTHORHOMBIC, + RHOMBOHEDRAL, + TETRAGONAL, + TRICLINIC, + LatticeParameters, +) @pytest.fixture() diff --git a/tests/gflownet/envs/test_tree.py b/tests/gflownet/envs/test_tree.py index 3d7af9c84..d21009288 100644 --- a/tests/gflownet/envs/test_tree.py +++ b/tests/gflownet/envs/test_tree.py @@ -5,8 +5,15 @@ import pytest import torch -from gflownet.envs.tree import (ActionType, Attribute, NodeType, Operator, - Stage, Status, Tree) +from gflownet.envs.tree import ( + ActionType, + Attribute, + NodeType, + Operator, + Stage, + Status, + Tree, +) from gflownet.utils.common import tfloat NAN = float("NaN") diff --git a/tests/gflownet/policy/test_multihead_tree_policy.py b/tests/gflownet/policy/test_multihead_tree_policy.py index 5d5448099..a28570a53 100644 --- a/tests/gflownet/policy/test_multihead_tree_policy.py +++ b/tests/gflownet/policy/test_multihead_tree_policy.py @@ -4,10 +4,13 @@ from torch_geometric.data import Batch from gflownet.envs.tree import Attribute, Operator, Tree -from gflownet.policy.multihead_tree import (Backbone, FeatureSelectionHead, - LeafSelectionHead, - OperatorSelectionHead, - ThresholdSelectionHead) +from gflownet.policy.multihead_tree import ( + Backbone, + FeatureSelectionHead, + LeafSelectionHead, + OperatorSelectionHead, + ThresholdSelectionHead, +) N_OBSERVATIONS = 17 N_FEATURES = 5 diff --git a/tests/gflownet/utils/molecule/test_rotatable_bonds.py b/tests/gflownet/utils/molecule/test_rotatable_bonds.py index 931bb518b..b316dc1c0 100644 --- a/tests/gflownet/utils/molecule/test_rotatable_bonds.py +++ b/tests/gflownet/utils/molecule/test_rotatable_bonds.py @@ -2,8 +2,10 @@ from rdkit import Chem from gflownet.utils.molecule import constants -from gflownet.utils.molecule.rotatable_bonds import (find_rotor_from_smiles, - is_hydrogen_ta) +from gflownet.utils.molecule.rotatable_bonds import ( + find_rotor_from_smiles, + is_hydrogen_ta, +) def test_simple_ad(): diff --git a/tests/gflownet/utils/molecule/test_torsions.py b/tests/gflownet/utils/molecule/test_torsions.py index ef11ca543..acdeef1db 100644 --- a/tests/gflownet/utils/molecule/test_torsions.py +++ b/tests/gflownet/utils/molecule/test_torsions.py @@ -8,8 +8,7 @@ from gflownet.utils.molecule import constants from gflownet.utils.molecule.featurizer import MolDGLFeaturizer from gflownet.utils.molecule.rdkit_conformer import get_torsion_angles_values -from gflownet.utils.molecule.torsions import (apply_rotations, - get_rotation_masks) +from gflownet.utils.molecule.torsions import apply_rotations, get_rotation_masks def test_four_nodes_chain(): @@ -148,8 +147,7 @@ def stress_test_apply_rotation_alanine_dipeptide(): from rdkit.Geometry.rdGeometry import Point3D from gflownet.utils.molecule.featurizer import MolDGLFeaturizer - from gflownet.utils.molecule.rdkit_conformer import \ - get_torsion_angles_values + from gflownet.utils.molecule.rdkit_conformer import get_torsion_angles_values mol = Chem.MolFromSmiles(constants.ad_smiles) mol = Chem.AddHs(mol) diff --git a/tests/gflownet/utils/test_batch.py b/tests/gflownet/utils/test_batch.py index e26dd2455..338dfd061 100644 --- a/tests/gflownet/utils/test_batch.py +++ b/tests/gflownet/utils/test_batch.py @@ -8,9 +8,16 @@ from gflownet.proxy.corners import Corners from gflownet.proxy.tetris import Tetris as TetrisScore from gflownet.utils.batch import Batch -from gflownet.utils.common import (concat_items, copy, set_device, - set_float_precision, tbool, tfloat, tint, - tlong) +from gflownet.utils.common import ( + concat_items, + copy, + set_device, + set_float_precision, + tbool, + tfloat, + tint, + tlong, +) # Sets the number of repetitions for the tests. Please increase to ~10 after # introducing changes to the Batch class and decrease again to 1 when passed.