Skip to content

Commit

Permalink
black, isort
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexandraVolokhova committed Nov 13, 2023
1 parent 2157983 commit 436890f
Show file tree
Hide file tree
Showing 33 changed files with 329 additions and 308 deletions.
3 changes: 2 additions & 1 deletion gflownet/envs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
from torch.distributions import Categorical
from torchtyping import TensorType

from gflownet.utils.common import copy, set_device, set_float_precision, tbool, tfloat
from gflownet.utils.common import (copy, set_device, set_float_precision,
tbool, tfloat)

CMAP = mpl.colormaps["cividis"]

Expand Down
1 change: 0 additions & 1 deletion gflownet/envs/conformers/conformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from gflownet.utils.molecule.rdkit_conformer import RDKitConformer
from gflownet.utils.molecule.rotatable_bonds import find_rotor_from_smiles


PREDEFINED_SMILES = [
"O=C(c1ccccc1)c1ccc2c(c1)OCCOCCOCCOCCO2",
"O=S(=O)(NN=C1CCCCCC1)c1ccc(Cl)cc1",
Expand Down
7 changes: 2 additions & 5 deletions gflownet/envs/crystals/composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,8 @@
from gflownet.utils.common import tlong
from gflownet.utils.crystals.constants import ELEMENT_NAMES, OXIDATION_STATES
from gflownet.utils.crystals.pyxtal_cache import (
get_space_group,
space_group_check_compatible,
space_group_lowest_free_wp_multiplicity,
space_group_wyckoff_gcd,
)
get_space_group, space_group_check_compatible,
space_group_lowest_free_wp_multiplicity, space_group_wyckoff_gcd)


class Composition(GFlowNetEnv):
Expand Down
14 changes: 4 additions & 10 deletions gflownet/envs/crystals/lattice_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,10 @@
from torchtyping import TensorType

from gflownet.envs.grid import Grid
from gflownet.utils.crystals.constants import (
CUBIC,
HEXAGONAL,
LATTICE_SYSTEMS,
MONOCLINIC,
ORTHORHOMBIC,
RHOMBOHEDRAL,
TETRAGONAL,
TRICLINIC,
)
from gflownet.utils.crystals.constants import (CUBIC, HEXAGONAL,
LATTICE_SYSTEMS, MONOCLINIC,
ORTHORHOMBIC, RHOMBOHEDRAL,
TETRAGONAL, TRICLINIC)


class LatticeParameters(Grid):
Expand Down
3 changes: 2 additions & 1 deletion gflownet/envs/ctorus.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
import numpy.typing as npt
import pandas as pd
import torch
from torch.distributions import Categorical, MixtureSameFamily, Uniform, VonMises
from torch.distributions import (Categorical, MixtureSameFamily, Uniform,
VonMises)
from torchtyping import TensorType

from gflownet.envs.htorus import HybridTorus
Expand Down
12 changes: 3 additions & 9 deletions gflownet/gflownet.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,9 @@
from gflownet.envs.base import GFlowNetEnv
from gflownet.utils.batch import Batch
from gflownet.utils.buffer import Buffer
from gflownet.utils.common import (
batch_with_rest,
set_device,
set_float_precision,
tbool,
tfloat,
tlong,
torch2np,
)
from gflownet.utils.common import (batch_with_rest, set_device,
set_float_precision, tbool, tfloat, tlong,
torch2np)


class GFlowNetAgent:
Expand Down
12 changes: 2 additions & 10 deletions gflownet/utils/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,8 @@
from torchtyping import TensorType

from gflownet.envs.base import GFlowNetEnv
from gflownet.utils.common import (
concat_items,
copy,
extend,
set_device,
set_float_precision,
tbool,
tfloat,
tlong,
)
from gflownet.utils.common import (concat_items, copy, extend, set_device,
set_float_precision, tbool, tfloat, tlong)


class Batch:
Expand Down
19 changes: 6 additions & 13 deletions gflownet/utils/crystals/build_lattice_dicts.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,12 @@

import numpy as np
import yaml
from lattice_constants import (
CRYSTAL_CLASSES_WIKIPEDIA,
CRYSTAL_LATTICE_SYSTEMS,
CRYSTAL_SYSTEMS,
POINT_SYMMETRIES,
RHOMBOHEDRAL_SPACE_GROUPS_WIKIPEDIA,
)
from pymatgen.symmetry.groups import (
PointGroup,
SpaceGroup,
SymmetryGroup,
sg_symbol_from_int_number,
)
from lattice_constants import (CRYSTAL_CLASSES_WIKIPEDIA,
CRYSTAL_LATTICE_SYSTEMS, CRYSTAL_SYSTEMS,
POINT_SYMMETRIES,
RHOMBOHEDRAL_SPACE_GROUPS_WIKIPEDIA)
from pymatgen.symmetry.groups import (PointGroup, SpaceGroup, SymmetryGroup,
sg_symbol_from_int_number)

N_SPACE_GROUPS = 230

Expand Down
30 changes: 18 additions & 12 deletions gflownet/utils/molecule/geom.py
Original file line number Diff line number Diff line change
@@ -1,57 +1,63 @@
import os
import json
import os
import pickle
from pathlib import Path

import numpy as np
import pandas as pd

from rdkit import Chem
from pathlib import Path
from tqdm import tqdm

from gflownet.utils.molecule.rotatable_bonds import get_rotatable_ta_list, is_hydrogen_ta
from gflownet.utils.molecule.rotatable_bonds import (get_rotatable_ta_list,
is_hydrogen_ta)


def get_conf_geom(base_path, smiles, conf_idx=0, summary_file=None):
if summary_file is None:
drugs_file = base_path / 'rdkit_folder/summary_drugs.json'
drugs_file = base_path / "rdkit_folder/summary_drugs.json"
with open(drugs_file, "r") as f:
summary_file = json.load(f)

pickle_path = base_path / "rdkit_folder" / summary_file[smiles]['pickle_path']
pickle_path = base_path / "rdkit_folder" / summary_file[smiles]["pickle_path"]
if os.path.isfile(pickle_path):
with open(pickle_path, "rb") as f:
dic = pickle.load(f)
mol = dic['conformers'][conf_idx]['rd_mol']
mol = dic["conformers"][conf_idx]["rd_mol"]
return mol


def get_all_confs_geom(base_path, smiles, summary_file=None):
if summary_file is None:
drugs_file = base_path / 'rdkit_folder/summary_drugs.json'
drugs_file = base_path / "rdkit_folder/summary_drugs.json"
with open(drugs_file, "r") as f:
summary_file = json.load(f)
try:
pickle_path = base_path / "rdkit_folder" / summary_file[smiles]['pickle_path']
pickle_path = base_path / "rdkit_folder" / summary_file[smiles]["pickle_path"]
if os.path.isfile(pickle_path):
with open(pickle_path, "rb") as f:
dic = pickle.load(f)
conformers = [x['rd_mol'] for x in dic['conformers']]
conformers = [x["rd_mol"] for x in dic["conformers"]]
return conformers
except KeyError:
print('No pickle_path file for {}'.format(smiles))
print("No pickle_path file for {}".format(smiles))
return None


def get_rd_mol(smiles):
mol = Chem.MolFromSmiles(smiles)
mol = Chem.AddHs(mol)
return mol


def has_same_can_smiles(mol1, mol2):
sm1 = Chem.CanonSmiles(Chem.MolToSmiles(mol1))
sm2 = Chem.CanonSmiles(Chem.MolToSmiles(mol2))
return sm1 == sm2


def all_same_graphs(mols):
ref = mols[0]
same = []
for mol in mols:
same.append(has_same_can_smiles(ref, mol))
return np.all(same)
return np.all(same)
13 changes: 7 additions & 6 deletions gflownet/utils/molecule/metrics.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
# some functions inspired by: https://gist.github.com/ZhouGengmo/5b565f51adafcd911c0bc115b2ef027c

import numpy as np
import pandas as pd
import copy

import numpy as np
import pandas as pd
from rdkit import Chem

from rdkit.Chem import rdMolAlign as MA
from rdkit import Chem
from rdkit.Chem.rdForceFieldHelpers import MMFFOptimizeMolecule
from rdkit.Geometry.rdGeometry import Point3D


def get_best_rmsd(gen_mol, ref_mol):
gen_mol = Chem.RemoveHs(gen_mol)
ref_mol = Chem.RemoveHs(ref_mol)
rmsd = MA.GetBestRMS(gen_mol, ref_mol)
return rmsd


def get_cov_mat(ref_mols, gen_mols, threshold=1.25):
rmsd_mat = np.zeros([len(ref_mols), len(gen_mols)], dtype=np.float32)
for i, gen_mol in enumerate(gen_mols):
Expand All @@ -27,10 +27,11 @@ def get_cov_mat(ref_mols, gen_mols, threshold=1.25):
rmsd_mat_min = rmsd_mat.min(-1)
return (rmsd_mat_min <= threshold).mean(), rmsd_mat_min.mean()


def normalise_positions(mol):
conf = mol.GetConformer()
pos = conf.GetPositions()
pos = pos - pos.mean(axis=0)
pos = pos - pos.mean(axis=0)
for idx, p in enumerate(pos):
conf.SetAtomPosition(idx, Point3D(*p))
return mol
return mol
3 changes: 2 additions & 1 deletion gflownet/utils/molecule/rotatable_bonds.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,10 @@ def is_connected_to_three_hydrogens(mol, atom_id, except_id):
second = is_connected_to_three_hydrogens(mol, ta[2], ta[1])
return first or second


def has_hydrogen_tas(mol):
tas = get_rotatable_ta_list(mol)
hydrogen_flags = []
for t in tas:
hydrogen_flags.append(is_hydrogen_ta(mol, t))
return np.any(hydrogen_flags)
return np.any(hydrogen_flags)
3 changes: 2 additions & 1 deletion gflownet/utils/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
)
pass
try:
from bbdob import DeceptiveTrap, FourPeaks, NKLandscape, OneMax, TwoMin, WModel
from bbdob import (DeceptiveTrap, FourPeaks, NKLandscape, OneMax, TwoMin,
WModel)
from bbdob.utils import idx2one_hot
except:
print(
Expand Down
1 change: 0 additions & 1 deletion playground/botorch/mes_exact_deepKernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from math import floor

import gpytorch

# import tqdm
import torch
from botorch.test_functions import Hartmann
Expand Down
1 change: 0 additions & 1 deletion playground/botorch/mes_gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import numpy as np
import torch

# from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from botorch.test_functions import Branin, Hartmann
Expand Down
7 changes: 4 additions & 3 deletions playground/botorch/mes_gp_debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import gpytorch
import numpy as np
import torch

# from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from botorch.test_functions import Hartmann
Expand Down Expand Up @@ -50,8 +49,10 @@ def forward(self, x):

from botorch.models.utils import add_output_dim
from botorch.posteriors.gpytorch import GPyTorchPosterior
from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal
from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood
from gpytorch.distributions import (MultitaskMultivariateNormal,
MultivariateNormal)
from gpytorch.likelihoods.gaussian_likelihood import \
FixedNoiseGaussianLikelihood


class myGPModel(SingleTaskGP):
Expand Down
5 changes: 2 additions & 3 deletions playground/botorch/mes_nn_bao_fix.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import numpy as np
import torch

# from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from botorch.test_functions import Hartmann
Expand Down Expand Up @@ -56,8 +55,8 @@
from botorch.acquisition.max_value_entropy_search import qMaxValueEntropy
from botorch.models.model import Model
from botorch.posteriors.gpytorch import GPyTorchPosterior
from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal

from gpytorch.distributions import (MultitaskMultivariateNormal,
MultivariateNormal)
# from botorch.posteriors.
from torch.distributions import Normal

Expand Down
4 changes: 2 additions & 2 deletions playground/botorch/mes_nn_hardcode_gpVal.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import numpy as np
import torch

# from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from botorch.test_functions import Hartmann
Expand Down Expand Up @@ -57,7 +56,8 @@
from botorch.acquisition.max_value_entropy_search import qMaxValueEntropy
from botorch.models.model import Model
from botorch.posteriors.gpytorch import GPyTorchPosterior
from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal
from gpytorch.distributions import (MultitaskMultivariateNormal,
MultivariateNormal)


class NN_Model(Model):
Expand Down
5 changes: 2 additions & 3 deletions playground/botorch/mes_nn_like_gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,14 @@
import numpy as np
import torch
from botorch.acquisition.max_value_entropy_search import qMaxValueEntropy

# from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from botorch.models.model import Model
from botorch.posteriors.gpytorch import GPyTorchPosterior
from botorch.test_functions import Hartmann
from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal
from gpytorch.distributions import (MultitaskMultivariateNormal,
MultivariateNormal)
from gpytorch.mlls import ExactMarginalLogLikelihood

# from botorch.posteriors.
from torch import distributions, tensor
from torch.nn import Dropout, Linear, MSELoss, ReLU, Sequential
Expand Down
5 changes: 2 additions & 3 deletions playground/botorch/mes_nn_like_gp_nondiagonalcovar.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,14 @@
import numpy as np
import torch
from botorch.acquisition.max_value_entropy_search import qMaxValueEntropy

# from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from botorch.models.model import Model
from botorch.posteriors.gpytorch import GPyTorchPosterior
from botorch.test_functions import Hartmann
from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal
from gpytorch.distributions import (MultitaskMultivariateNormal,
MultivariateNormal)
from gpytorch.mlls import ExactMarginalLogLikelihood

# from botorch.posteriors.
from torch import distributions, tensor
from torch.nn import Dropout, Linear, MSELoss, ReLU, Sequential
Expand Down
5 changes: 1 addition & 4 deletions playground/botorch/mes_var_deepKernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from math import floor

import gpytorch

# import tqdm
import torch
from botorch.test_functions import Hartmann
Expand Down Expand Up @@ -215,9 +214,7 @@ def posterior(


from botorch.acquisition.max_value_entropy_search import (
qLowerBoundMaxValueEntropy,
qMaxValueEntropy,
)
qLowerBoundMaxValueEntropy, qMaxValueEntropy)

proxy = myGPModel(model, train_x, train_y.unsqueeze(-1))
qMES = qLowerBoundMaxValueEntropy(proxy, candidate_set=train_x, use_gumbel=True)
Expand Down
Loading

0 comments on commit 436890f

Please sign in to comment.