Skip to content

Commit

Permalink
black, isort again
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexandraVolokhova committed Nov 13, 2023
1 parent 0d785d4 commit 8206e02
Show file tree
Hide file tree
Showing 26 changed files with 134 additions and 68 deletions.
3 changes: 1 addition & 2 deletions gflownet/envs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@
from torch.distributions import Categorical
from torchtyping import TensorType

from gflownet.utils.common import (copy, set_device, set_float_precision,
tbool, tfloat)
from gflownet.utils.common import copy, set_device, set_float_precision, tbool, tfloat

CMAP = mpl.colormaps["cividis"]

Expand Down
7 changes: 5 additions & 2 deletions gflownet/envs/crystals/composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,11 @@
from gflownet.utils.common import tlong
from gflownet.utils.crystals.constants import ELEMENT_NAMES, OXIDATION_STATES
from gflownet.utils.crystals.pyxtal_cache import (
get_space_group, space_group_check_compatible,
space_group_lowest_free_wp_multiplicity, space_group_wyckoff_gcd)
get_space_group,
space_group_check_compatible,
space_group_lowest_free_wp_multiplicity,
space_group_wyckoff_gcd,
)


class Composition(GFlowNetEnv):
Expand Down
14 changes: 10 additions & 4 deletions gflownet/envs/crystals/lattice_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,16 @@
from torchtyping import TensorType

from gflownet.envs.grid import Grid
from gflownet.utils.crystals.constants import (CUBIC, HEXAGONAL,
LATTICE_SYSTEMS, MONOCLINIC,
ORTHORHOMBIC, RHOMBOHEDRAL,
TETRAGONAL, TRICLINIC)
from gflownet.utils.crystals.constants import (
CUBIC,
HEXAGONAL,
LATTICE_SYSTEMS,
MONOCLINIC,
ORTHORHOMBIC,
RHOMBOHEDRAL,
TETRAGONAL,
TRICLINIC,
)


class LatticeParameters(Grid):
Expand Down
3 changes: 1 addition & 2 deletions gflownet/envs/ctorus.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
import numpy.typing as npt
import pandas as pd
import torch
from torch.distributions import (Categorical, MixtureSameFamily, Uniform,
VonMises)
from torch.distributions import Categorical, MixtureSameFamily, Uniform, VonMises
from torchtyping import TensorType

from gflownet.envs.htorus import HybridTorus
Expand Down
12 changes: 9 additions & 3 deletions gflownet/gflownet.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,15 @@
from gflownet.envs.base import GFlowNetEnv
from gflownet.utils.batch import Batch
from gflownet.utils.buffer import Buffer
from gflownet.utils.common import (batch_with_rest, set_device,
set_float_precision, tbool, tfloat, tlong,
torch2np)
from gflownet.utils.common import (
batch_with_rest,
set_device,
set_float_precision,
tbool,
tfloat,
tlong,
torch2np,
)


class GFlowNetAgent:
Expand Down
12 changes: 10 additions & 2 deletions gflownet/utils/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,16 @@
from torchtyping import TensorType

from gflownet.envs.base import GFlowNetEnv
from gflownet.utils.common import (concat_items, copy, extend, set_device,
set_float_precision, tbool, tfloat, tlong)
from gflownet.utils.common import (
concat_items,
copy,
extend,
set_device,
set_float_precision,
tbool,
tfloat,
tlong,
)


class Batch:
Expand Down
19 changes: 13 additions & 6 deletions gflownet/utils/crystals/build_lattice_dicts.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,19 @@

import numpy as np
import yaml
from lattice_constants import (CRYSTAL_CLASSES_WIKIPEDIA,
CRYSTAL_LATTICE_SYSTEMS, CRYSTAL_SYSTEMS,
POINT_SYMMETRIES,
RHOMBOHEDRAL_SPACE_GROUPS_WIKIPEDIA)
from pymatgen.symmetry.groups import (PointGroup, SpaceGroup, SymmetryGroup,
sg_symbol_from_int_number)
from lattice_constants import (
CRYSTAL_CLASSES_WIKIPEDIA,
CRYSTAL_LATTICE_SYSTEMS,
CRYSTAL_SYSTEMS,
POINT_SYMMETRIES,
RHOMBOHEDRAL_SPACE_GROUPS_WIKIPEDIA,
)
from pymatgen.symmetry.groups import (
PointGroup,
SpaceGroup,
SymmetryGroup,
sg_symbol_from_int_number,
)

N_SPACE_GROUPS = 230

Expand Down
6 changes: 4 additions & 2 deletions gflownet/utils/molecule/geom.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
from rdkit import Chem
from tqdm import tqdm

from gflownet.utils.molecule.rotatable_bonds import (get_rotatable_ta_list,
is_hydrogen_ta)
from gflownet.utils.molecule.rotatable_bonds import (
get_rotatable_ta_list,
is_hydrogen_ta,
)


def get_conf_geom(base_path, smiles, conf_idx=0, summary_file=None):
Expand Down
3 changes: 1 addition & 2 deletions gflownet/utils/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@
)
pass
try:
from bbdob import (DeceptiveTrap, FourPeaks, NKLandscape, OneMax, TwoMin,
WModel)
from bbdob import DeceptiveTrap, FourPeaks, NKLandscape, OneMax, TwoMin, WModel
from bbdob.utils import idx2one_hot
except:
print(
Expand Down
1 change: 1 addition & 0 deletions playground/botorch/mes_exact_deepKernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from math import floor

import gpytorch

# import tqdm
import torch
from botorch.test_functions import Hartmann
Expand Down
1 change: 1 addition & 0 deletions playground/botorch/mes_gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import numpy as np
import torch

# from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from botorch.test_functions import Branin, Hartmann
Expand Down
7 changes: 3 additions & 4 deletions playground/botorch/mes_gp_debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import gpytorch
import numpy as np
import torch

# from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from botorch.test_functions import Hartmann
Expand Down Expand Up @@ -49,10 +50,8 @@ def forward(self, x):

from botorch.models.utils import add_output_dim
from botorch.posteriors.gpytorch import GPyTorchPosterior
from gpytorch.distributions import (MultitaskMultivariateNormal,
MultivariateNormal)
from gpytorch.likelihoods.gaussian_likelihood import \
FixedNoiseGaussianLikelihood
from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal
from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood


class myGPModel(SingleTaskGP):
Expand Down
5 changes: 3 additions & 2 deletions playground/botorch/mes_nn_bao_fix.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy as np
import torch

# from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from botorch.test_functions import Hartmann
Expand Down Expand Up @@ -55,8 +56,8 @@
from botorch.acquisition.max_value_entropy_search import qMaxValueEntropy
from botorch.models.model import Model
from botorch.posteriors.gpytorch import GPyTorchPosterior
from gpytorch.distributions import (MultitaskMultivariateNormal,
MultivariateNormal)
from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal

# from botorch.posteriors.
from torch.distributions import Normal

Expand Down
4 changes: 2 additions & 2 deletions playground/botorch/mes_nn_hardcode_gpVal.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy as np
import torch

# from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from botorch.test_functions import Hartmann
Expand Down Expand Up @@ -56,8 +57,7 @@
from botorch.acquisition.max_value_entropy_search import qMaxValueEntropy
from botorch.models.model import Model
from botorch.posteriors.gpytorch import GPyTorchPosterior
from gpytorch.distributions import (MultitaskMultivariateNormal,
MultivariateNormal)
from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal


class NN_Model(Model):
Expand Down
5 changes: 3 additions & 2 deletions playground/botorch/mes_nn_like_gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
import numpy as np
import torch
from botorch.acquisition.max_value_entropy_search import qMaxValueEntropy

# from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from botorch.models.model import Model
from botorch.posteriors.gpytorch import GPyTorchPosterior
from botorch.test_functions import Hartmann
from gpytorch.distributions import (MultitaskMultivariateNormal,
MultivariateNormal)
from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal
from gpytorch.mlls import ExactMarginalLogLikelihood

# from botorch.posteriors.
from torch import distributions, tensor
from torch.nn import Dropout, Linear, MSELoss, ReLU, Sequential
Expand Down
5 changes: 3 additions & 2 deletions playground/botorch/mes_nn_like_gp_nondiagonalcovar.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
import numpy as np
import torch
from botorch.acquisition.max_value_entropy_search import qMaxValueEntropy

# from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from botorch.models.model import Model
from botorch.posteriors.gpytorch import GPyTorchPosterior
from botorch.test_functions import Hartmann
from gpytorch.distributions import (MultitaskMultivariateNormal,
MultivariateNormal)
from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal
from gpytorch.mlls import ExactMarginalLogLikelihood

# from botorch.posteriors.
from torch import distributions, tensor
from torch.nn import Dropout, Linear, MSELoss, ReLU, Sequential
Expand Down
5 changes: 4 additions & 1 deletion playground/botorch/mes_var_deepKernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from math import floor

import gpytorch

# import tqdm
import torch
from botorch.test_functions import Hartmann
Expand Down Expand Up @@ -214,7 +215,9 @@ def posterior(


from botorch.acquisition.max_value_entropy_search import (
qLowerBoundMaxValueEntropy, qMaxValueEntropy)
qLowerBoundMaxValueEntropy,
qMaxValueEntropy,
)

proxy = myGPModel(model, train_x, train_y.unsqueeze(-1))
qMES = qLowerBoundMaxValueEntropy(proxy, candidate_set=train_x, use_gumbel=True)
Expand Down
14 changes: 10 additions & 4 deletions scripts/conformer/geom_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,16 @@
from rdkit import Chem
from tqdm import tqdm

from gflownet.utils.molecule.geom import (all_same_graphs, get_all_confs_geom,
get_conf_geom, get_rd_mol)
from gflownet.utils.molecule.rotatable_bonds import (get_rotatable_ta_list,
has_hydrogen_tas)
from gflownet.utils.molecule.geom import (
all_same_graphs,
get_all_confs_geom,
get_conf_geom,
get_rd_mol,
)
from gflownet.utils.molecule.rotatable_bonds import (
get_rotatable_ta_list,
has_hydrogen_tas,
)

"""
Here we use rdkit_folder format of the GEOM dataset
Expand Down
3 changes: 1 addition & 2 deletions scripts/dav_mp20_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@

from collections import Counter

from external.repos.ActiveLearningMaterials.dave.utils.loaders import \
make_loaders
from external.repos.ActiveLearningMaterials.dave.utils.loaders import make_loaders

from gflownet.proxy.crystals.dave import DAVE
from gflownet.utils.common import load_gflow_net_from_run_path, resolve_path
Expand Down
8 changes: 6 additions & 2 deletions scripts/pyxtal/pyxtal_vs_pymatgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,12 @@
"""
from argparse import ArgumentParser

from pymatgen.symmetry.groups import (PointGroup, SpaceGroup, SymmetryGroup,
sg_symbol_from_int_number)
from pymatgen.symmetry.groups import (
PointGroup,
SpaceGroup,
SymmetryGroup,
sg_symbol_from_int_number,
)
from pyxtal.symmetry import Group

N_SYMMETRY_GROUPS = 230
Expand Down
18 changes: 11 additions & 7 deletions tests/gflownet/envs/test_lattice_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,17 @@
import pytest
import torch

from gflownet.envs.crystals.lattice_parameters import (CUBIC, HEXAGONAL,
LATTICE_SYSTEMS,
MONOCLINIC,
ORTHORHOMBIC,
RHOMBOHEDRAL,
TETRAGONAL, TRICLINIC,
LatticeParameters)
from gflownet.envs.crystals.lattice_parameters import (
CUBIC,
HEXAGONAL,
LATTICE_SYSTEMS,
MONOCLINIC,
ORTHORHOMBIC,
RHOMBOHEDRAL,
TETRAGONAL,
TRICLINIC,
LatticeParameters,
)


@pytest.fixture()
Expand Down
11 changes: 9 additions & 2 deletions tests/gflownet/envs/test_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,15 @@
import pytest
import torch

from gflownet.envs.tree import (ActionType, Attribute, NodeType, Operator,
Stage, Status, Tree)
from gflownet.envs.tree import (
ActionType,
Attribute,
NodeType,
Operator,
Stage,
Status,
Tree,
)
from gflownet.utils.common import tfloat

NAN = float("NaN")
Expand Down
11 changes: 7 additions & 4 deletions tests/gflownet/policy/test_multihead_tree_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@
from torch_geometric.data import Batch

from gflownet.envs.tree import Attribute, Operator, Tree
from gflownet.policy.multihead_tree import (Backbone, FeatureSelectionHead,
LeafSelectionHead,
OperatorSelectionHead,
ThresholdSelectionHead)
from gflownet.policy.multihead_tree import (
Backbone,
FeatureSelectionHead,
LeafSelectionHead,
OperatorSelectionHead,
ThresholdSelectionHead,
)

N_OBSERVATIONS = 17
N_FEATURES = 5
Expand Down
6 changes: 4 additions & 2 deletions tests/gflownet/utils/molecule/test_rotatable_bonds.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
from rdkit import Chem

from gflownet.utils.molecule import constants
from gflownet.utils.molecule.rotatable_bonds import (find_rotor_from_smiles,
is_hydrogen_ta)
from gflownet.utils.molecule.rotatable_bonds import (
find_rotor_from_smiles,
is_hydrogen_ta,
)


def test_simple_ad():
Expand Down
Loading

0 comments on commit 8206e02

Please sign in to comment.