Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Crystal-GFN] Restricted sampling #331

Merged
merged 2 commits into from
Aug 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 134 additions & 0 deletions config/experiments/crystals/starling_fe_restricted_a.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
# @package _global_
#
# Restricted sampling A: only elements O and Fe, with maximum 10 atoms per element
#
# Forward trajectories (10) + Replay buffer (5) + Train set (5)
# Learning rate decay

defaults:
- override /env: crystals/crystal
- override /gflownet: trajectorybalance
- override /proxy: crystals/dave
- override /logger: wandb

device: cpu

# Environment
env:
do_composition_to_sg_constraints: False
do_sg_to_composition_constraints: True
do_sg_to_lp_constraints: True
do_sg_before_composition: True
composition_kwargs:
elements: [8, 26]
max_diff_elem: 5
min_diff_elem: 1
min_atoms: 1
max_atoms: 80
min_atom_i: 1
max_atom_i: 10
do_charge_check: True
space_group_kwargs:
space_groups_subset: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 18, 19, 20, 21, 25, 26, 29, 30, 31, 33, 36, 38, 40, 41, 43, 44, 46, 47, 51, 52, 53, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 69, 70, 71, 72, 74, 82, 84, 85, 86, 87, 88, 92, 99, 102, 107, 113, 114, 121, 122, 123, 126, 129, 131, 136, 137, 138, 139, 140, 141, 146, 147, 148, 150, 155, 156, 160, 161, 162, 163, 164, 166, 167, 176, 181, 185, 186, 187, 189, 192, 194, 198, 199, 205, 206, 216, 217, 220, 221, 224, 225, 227, 229, 230]
lattice_parameters_kwargs:
min_length: 0.9
max_length: 100.0
min_angle: 50.0
max_angle: 150.0
n_comp: 5
beta_params_min: 0.1
beta_params_max: 100.0
min_incr: 0.1
fixed_distr_params:
beta_weights: 1.0
beta_alpha: 10.0
beta_beta: 10.0
bernoulli_eos_prob: 0.1
bernoulli_bts_prob: 0.1
random_distr_params:
beta_weights: 1.0
beta_alpha: 10.0
beta_beta: 10.0
bernoulli_eos_prob: 0.1
bernoulli_bts_prob: 0.1
buffer:
replay_capacity: 1000
train:
type: csv
path: /network/projects/crystalgfn/data/eform/train.csv
test:
type: csv
path: /network/projects/crystalgfn/data/eform/val.csv

# GFlowNet hyperparameters
gflownet:
random_action_prob: 0.1
optimizer:
batch_size:
forward: 10
backward_replay: 5
backward_dataset: 5
lr: 0.0001
z_dim: 16
lr_z_mult: 100
n_train_steps: 100000
lr_decay_period: 11000
lr_decay_gamma: 0.5
replay_sampling: weighted
train_sampling: permutation

# Policy
policy:
forward:
type: mlp
n_hid: 256
n_layers: 3
checkpoint: forward
backward:
type: mlp
n_hid: 256
n_layers: 3
shared_weights: False
checkpoint: backward

# Proxy (eform)
proxy:
reward_min: 1e-08
do_clip_rewards: True
release: 0.3.4 # Formation energy release
# Boltzmann (exponential), with negative beta because the formation energy is negative and the lower the better
reward_function: exponential
# Parameters of the reward function
reward_function_kwargs:
beta: -8.0
alpha: 1.0

# Evaluator
evaluator:
first_it: False
period: -1
checkpoints_period: 500
n_trajs_logprobs: 100
logprobs_batch_size: 10
n: 10
n_top_k: 5000
top_k: 100
top_k_period: -1

# WandB
logger:
lightweight: True
project_name: "crystal-gfn"
tags:
- gflownet
- crystals
- stack
- matbench
- formationenergy
do:
online: true

# Hydra
hydra:
run:
dir: ${user.logdir.root}/crystalgfn/${oc.env:SLURM_JOB_ID,local}/${now:%Y-%m-%d_%H-%M-%S_%f}
135 changes: 135 additions & 0 deletions config/experiments/crystals/starling_fe_restricted_b.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
# @package _global_
#
# Restricted sampling B: Composition Li-Mn-O
# - max_diff_elem and min_diff_elem are set to 3
#
# Forward trajectories (10) + Replay buffer (5) + Train set (5)
# Learning rate decay

defaults:
- override /env: crystals/crystal
- override /gflownet: trajectorybalance
- override /proxy: crystals/dave
- override /logger: wandb

device: cpu

# Environment
env:
do_composition_to_sg_constraints: False
do_sg_to_composition_constraints: True
do_sg_to_lp_constraints: True
do_sg_before_composition: True
composition_kwargs:
elements: [3, 8, 25]
max_diff_elem: 3
min_diff_elem: 3
min_atoms: 1
max_atoms: 80
min_atom_i: 1
max_atom_i: 16
do_charge_check: True
space_group_kwargs:
space_groups_subset: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 18, 19, 20, 21, 25, 26, 29, 30, 31, 33, 36, 38, 40, 41, 43, 44, 46, 47, 51, 52, 53, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 69, 70, 71, 72, 74, 82, 84, 85, 86, 87, 88, 92, 99, 102, 107, 113, 114, 121, 122, 123, 126, 129, 131, 136, 137, 138, 139, 140, 141, 146, 147, 148, 150, 155, 156, 160, 161, 162, 163, 164, 166, 167, 176, 181, 185, 186, 187, 189, 192, 194, 198, 199, 205, 206, 216, 217, 220, 221, 224, 225, 227, 229, 230]
lattice_parameters_kwargs:
min_length: 0.9
max_length: 100.0
min_angle: 50.0
max_angle: 150.0
n_comp: 5
beta_params_min: 0.1
beta_params_max: 100.0
min_incr: 0.1
fixed_distr_params:
beta_weights: 1.0
beta_alpha: 10.0
beta_beta: 10.0
bernoulli_eos_prob: 0.1
bernoulli_bts_prob: 0.1
random_distr_params:
beta_weights: 1.0
beta_alpha: 10.0
beta_beta: 10.0
bernoulli_eos_prob: 0.1
bernoulli_bts_prob: 0.1
buffer:
replay_capacity: 1000
train:
type: csv
path: /network/projects/crystalgfn/data/eform/train.csv
test:
type: csv
path: /network/projects/crystalgfn/data/eform/val.csv

# GFlowNet hyperparameters
gflownet:
random_action_prob: 0.1
optimizer:
batch_size:
forward: 10
backward_replay: 5
backward_dataset: 5
lr: 0.0001
z_dim: 16
lr_z_mult: 100
n_train_steps: 100000
lr_decay_period: 11000
lr_decay_gamma: 0.5
replay_sampling: weighted
train_sampling: permutation

# Policy
policy:
forward:
type: mlp
n_hid: 256
n_layers: 3
checkpoint: forward
backward:
type: mlp
n_hid: 256
n_layers: 3
shared_weights: False
checkpoint: backward

# Proxy (eform)
proxy:
reward_min: 1e-08
do_clip_rewards: True
release: 0.3.4 # Formation energy release
# Boltzmann (exponential), with negative beta because the formation energy is negative and the lower the better
reward_function: exponential
# Parameters of the reward function
reward_function_kwargs:
beta: -8.0
alpha: 1.0

# Evaluator
evaluator:
first_it: False
period: -1
checkpoints_period: 500
n_trajs_logprobs: 100
logprobs_batch_size: 10
n: 10
n_top_k: 5000
top_k: 100
top_k_period: -1

# WandB
logger:
lightweight: True
project_name: "crystal-gfn"
tags:
- gflownet
- crystals
- stack
- matbench
- formationenergy
do:
online: true

# Hydra
hydra:
run:
dir: ${user.logdir.root}/crystalgfn/${oc.env:SLURM_JOB_ID,local}/${now:%Y-%m-%d_%H-%M-%S_%f}
Loading
Loading