diff --git a/config/env/ccube.yaml b/config/env/ccube.yaml index 57efa44ef..714638524 100644 --- a/config/env/ccube.yaml +++ b/config/env/ccube.yaml @@ -12,7 +12,7 @@ n_dim: 2 kappa: 1e-3 # Policy min_incr: 0.1 -n_comp: 1 +n_comp: 2 epsilon: 1e-6 beta_params_min: 0.1 beta_params_max: 100.0 diff --git a/config/env/crystals/ccrystal.yaml b/config/env/crystals/ccrystal.yaml new file mode 100644 index 000000000..47fd9455d --- /dev/null +++ b/config/env/crystals/ccrystal.yaml @@ -0,0 +1,27 @@ +defaults: + - base + +_target_: gflownet.envs.crystals.ccrystal.CCrystal + +# Composition config +id: ccrystal +composition_kwargs: + elements: 89 +# Lattice parameters config +lattice_parameters_kwargs: + min_length: 1.0 + max_length: 350.0 + min_angle: 50.0 + max_angle: 150.0 +# Space group config +space_group_kwargs: + space_groups_subset: null +# Stoichiometry <-> space group check +do_composition_to_sg_constraints: True +self.do_sg_to_lp_constraints: True + +# Buffer +buffer: + data_path: null + train: null + test: null diff --git a/config/env/crystals/clattice_parameters.yaml b/config/env/crystals/clattice_parameters.yaml new file mode 100644 index 000000000..da190ff97 --- /dev/null +++ b/config/env/crystals/clattice_parameters.yaml @@ -0,0 +1,43 @@ +defaults: + - base + +_target_: gflownet.envs.crystals.clattice_parameters.CLatticeParameters + +id: clattice_parameters + +# Lattice system +lattice_system: triclinic +# Allowed ranges of size and angles +min_length: 1.0 +max_length: 350.0 +min_angle: 50.0 +max_angle: 150.0 + +# Policy +min_incr: 0.1 +n_comp: 2 +epsilon: 1e-6 +beta_params_min: 0.1 +beta_params_max: 100.0 +fixed_distribution: + beta_weights: 1.0 + beta_alpha: 10.0 + beta_beta: 10.0 + bernoulli_bts_prob: 0.1 + bernoulli_eos_prob: 0.1 +random_distribution: + beta_weights: 1.0 + beta_alpha: 10.0 + beta_beta: 10.0 + bernoulli_bts_prob: 0.1 + bernoulli_eos_prob: 0.1 + +# Buffer +buffer: + data_path: null + train: null + test: + type: grid + n: 900 + output_csv: clp_test.csv + output_pkl: clp_test.pkl diff --git a/config/env/crystals/crystal.yaml b/config/env/crystals/crystal.yaml index 8a844ea17..3bb3640ef 100644 --- a/config/env/crystals/crystal.yaml +++ b/config/env/crystals/crystal.yaml @@ -7,6 +7,8 @@ _target_: gflownet.envs.crystals.crystal.Crystal id: crystal composition_kwargs: elements: 89 + max_atoms: 20 + max_atom_i: 16 lattice_parameters_kwargs: min_length: 1.0 max_length: 5.0 diff --git a/config/experiments/ccube/corners.yaml b/config/experiments/ccube/corners.yaml index e3594ac76..ccc207c6f 100644 --- a/config/experiments/ccube/corners.yaml +++ b/config/experiments/ccube/corners.yaml @@ -40,18 +40,17 @@ gflownet: z_dim: 16 lr_z_mult: 100 n_train_steps: 10000 - policy: - forward: - type: mlp - n_hid: 512 - n_layers: 5 - checkpoint: forward - backward: - type: mlp - n_hid: 512 - n_layers: 5 - shared_weights: False - checkpoint: backward + +# Policy +policy: + forward: + type: mlp + n_hid: 128 + n_layers: 2 + checkpoint: forward + backward: + shared_weights: True + checkpoint: backward # WandB logger: diff --git a/config/experiments/ccube/hyperparams_search_20230920_batch1.yaml b/config/experiments/ccube/hyperparams_search_20230920_batch1.yaml new file mode 100644 index 000000000..87e44bfb5 --- /dev/null +++ b/config/experiments/ccube/hyperparams_search_20230920_batch1.yaml @@ -0,0 +1,145 @@ +# Shared config +shared: + slurm: {} + script: + user: $USER + device: cpu + logger: + project_name: cube + do: + online: True + test: + period: 500 + n: 900 + checkpoints: + period: 10000 + # 
Contiunuous Cube environment + env: + __value__: ccube + n_dim: 2 + # Buffer + buffer: + data_path: null + train: null + test: + type: grid + n: 1000 + output_csv: ccube_test.csv + output_pkl: ccube_test.pkl + # Proxy + proxy: corners + # GFlowNet config + gflownet: + __value__: trajectorybalance + random_action_prob: 0.1 + optimizer: + batch_size: + forward: 100 + lr: 0.0001 + z_dim: 16 + lr_z_mult: 100 + n_train_steps: 10000 + # Policy + +gflownet: + policy: + forward: + type: mlp + n_hid: 512 + n_layers: 5 + checkpoint: forward + # Use + to add new variables + +gflownet: + policy: + backward: + type: mlp + n_hid: 512 + n_layers: 5 + checkpoint: backward + shared_weights: False + +# Jobs +jobs: + - slurm: + job_name: pigeonish + script: + env: + __value__: ccube + n_comp: 5 + beta_params_min: 0.01 + beta_params_max: 100.0 + random_distr_params: + beta_weights: 1.0 + beta_alpha: 100.0 + beta_beta: 100.0 + bernoulli_eos_prob: 0.7311 + bernoulli_bts_prob: 0.7311 + - slurm: + job_name: finch + script: + env: + __value__: ccube + n_comp: 5 + beta_params_min: 0.1 + beta_params_max: 100.0 + random_distr_params: + beta_weights: 1.0 + beta_alpha: 100.0 + beta_beta: 100.0 + bernoulli_eos_prob: 0.7311 + bernoulli_bts_prob: 0.7311 + - slurm: + job_name: dove + script: + env: + __value__: ccube + n_comp: 5 + beta_params_min: 1 + beta_params_max: 100.0 + random_distr_params: + beta_weights: 1.0 + beta_alpha: 100.0 + beta_beta: 100.0 + bernoulli_eos_prob: 0.7311 + bernoulli_bts_prob: 0.7311 + - slurm: + job_name: pine + script: + env: + __value__: ccube + n_comp: 5 + beta_params_min: 0.01 + beta_params_max: 1000.0 + random_distr_params: + beta_weights: 1.0 + beta_alpha: 100.0 + beta_beta: 100.0 + bernoulli_eos_prob: 0.7311 + bernoulli_bts_prob: 0.7311 + - slurm: + job_name: spruce + script: + env: + __value__: ccube + n_comp: 5 + beta_params_min: 0.1 + beta_params_max: 1000.0 + random_distr_params: + beta_weights: 1.0 + beta_alpha: 100.0 + beta_beta: 100.0 + bernoulli_eos_prob: 0.7311 + bernoulli_bts_prob: 0.7311 + - slurm: + job_name: fir + script: + env: + __value__: ccube + n_comp: 5 + beta_params_min: 1 + beta_params_max: 1000.0 + random_distr_params: + beta_weights: 1.0 + beta_alpha: 100.0 + beta_beta: 100.0 + bernoulli_eos_prob: 0.7311 + bernoulli_bts_prob: 0.7311 diff --git a/config/experiments/ccube/hyperparams_search_20230920_batch2.yaml b/config/experiments/ccube/hyperparams_search_20230920_batch2.yaml new file mode 100644 index 000000000..93491e3e9 --- /dev/null +++ b/config/experiments/ccube/hyperparams_search_20230920_batch2.yaml @@ -0,0 +1,145 @@ +# Shared config +shared: + slurm: {} + script: + user: $USER + device: cpu + logger: + project_name: cube + do: + online: True + test: + period: 500 + n: 900 + checkpoints: + period: 10000 + # Contiunuous Cube environment + env: + __value__: ccube + n_dim: 2 + # Buffer + buffer: + data_path: null + train: null + test: + type: grid + n: 1000 + output_csv: ccube_test.csv + output_pkl: ccube_test.pkl + # Proxy + proxy: corners + # GFlowNet config + gflownet: + __value__: trajectorybalance + random_action_prob: 0.1 + optimizer: + batch_size: + forward: 100 + lr: 0.0001 + z_dim: 16 + lr_z_mult: 100 + n_train_steps: 10000 + # Policy + +gflownet: + policy: + forward: + type: mlp + n_hid: 512 + n_layers: 5 + checkpoint: forward + # Use + to add new variables + +gflownet: + policy: + backward: + type: mlp + n_hid: 512 + n_layers: 5 + checkpoint: backward + shared_weights: False + +# Jobs +jobs: + - slurm: + job_name: large + script: + env: + __value__: 
ccube + n_comp: 5 + beta_params_min: 0.01 + beta_params_max: 100.0 + random_distr_params: + beta_weights: 1.0 + beta_alpha: 100.0 + beta_beta: 100.0 + bernoulli_eos_prob: 0.1 + bernoulli_bts_prob: 0.1 + - slurm: + job_name: cedar + script: + env: + __value__: ccube + n_comp: 5 + beta_params_min: 0.1 + beta_params_max: 100.0 + random_distr_params: + beta_weights: 1.0 + beta_alpha: 100.0 + beta_beta: 100.0 + bernoulli_eos_prob: 0.1 + bernoulli_bts_prob: 0.1 + - slurm: + job_name: hemlock + script: + env: + __value__: ccube + n_comp: 5 + beta_params_min: 1 + beta_params_max: 100.0 + random_distr_params: + beta_weights: 1.0 + beta_alpha: 100.0 + beta_beta: 100.0 + bernoulli_eos_prob: 0.1 + bernoulli_bts_prob: 0.1 + - slurm: + job_name: yew + script: + env: + __value__: ccube + n_comp: 5 + beta_params_min: 0.01 + beta_params_max: 1000.0 + random_distr_params: + beta_weights: 1.0 + beta_alpha: 100.0 + beta_beta: 100.0 + bernoulli_eos_prob: 0.1 + bernoulli_bts_prob: 0.1 + - slurm: + job_name: cycad + script: + env: + __value__: ccube + n_comp: 5 + beta_params_min: 0.1 + beta_params_max: 1000.0 + random_distr_params: + beta_weights: 1.0 + beta_alpha: 100.0 + beta_beta: 100.0 + bernoulli_eos_prob: 0.1 + bernoulli_bts_prob: 0.1 + - slurm: + job_name: palm + script: + env: + __value__: ccube + n_comp: 5 + beta_params_min: 1 + beta_params_max: 1000.0 + random_distr_params: + beta_weights: 1.0 + beta_alpha: 100.0 + beta_beta: 100.0 + bernoulli_eos_prob: 0.1 + bernoulli_bts_prob: 0.1 diff --git a/config/experiments/ccube/hyperparams_search_20230920_batch3.yaml b/config/experiments/ccube/hyperparams_search_20230920_batch3.yaml new file mode 100644 index 000000000..7912af9b3 --- /dev/null +++ b/config/experiments/ccube/hyperparams_search_20230920_batch3.yaml @@ -0,0 +1,145 @@ +# Shared config +shared: + slurm: {} + script: + user: $USER + device: cpu + logger: + project_name: cube + do: + online: True + test: + period: 500 + n: 900 + checkpoints: + period: 10000 + # Contiunuous Cube environment + env: + __value__: ccube + n_dim: 2 + # Buffer + buffer: + data_path: null + train: null + test: + type: grid + n: 1000 + output_csv: ccube_test.csv + output_pkl: ccube_test.pkl + # Proxy + proxy: corners + # GFlowNet config + gflownet: + __value__: trajectorybalance + random_action_prob: 0.1 + optimizer: + batch_size: + forward: 100 + lr: 0.0001 + z_dim: 16 + lr_z_mult: 100 + n_train_steps: 10000 + # Policy + +gflownet: + policy: + forward: + type: mlp + n_hid: 512 + n_layers: 5 + checkpoint: forward + # Use + to add new variables + +gflownet: + policy: + backward: + type: mlp + n_hid: 512 + n_layers: 5 + checkpoint: backward + shared_weights: False + +# Jobs +jobs: + - slurm: + job_name: papaya + script: + env: + __value__: ccube + n_comp: 2 + beta_params_min: 0.01 + beta_params_max: 100.0 + random_distr_params: + beta_weights: 1.0 + beta_alpha: 100.0 + beta_beta: 100.0 + bernoulli_eos_prob: 0.1 + bernoulli_bts_prob: 0.1 + - slurm: + job_name: mango + script: + env: + __value__: ccube + n_comp: 2 + beta_params_min: 0.1 + beta_params_max: 100.0 + random_distr_params: + beta_weights: 1.0 + beta_alpha: 100.0 + beta_beta: 100.0 + bernoulli_eos_prob: 0.1 + bernoulli_bts_prob: 0.1 + - slurm: + job_name: pineapple + script: + env: + __value__: ccube + n_comp: 2 + beta_params_min: 1 + beta_params_max: 100.0 + random_distr_params: + beta_weights: 1.0 + beta_alpha: 100.0 + beta_beta: 100.0 + bernoulli_eos_prob: 0.1 + bernoulli_bts_prob: 0.1 + - slurm: + job_name: apple + script: + env: + __value__: ccube + n_comp: 2 + 
beta_params_min: 0.01 + beta_params_max: 1000.0 + random_distr_params: + beta_weights: 1.0 + beta_alpha: 100.0 + beta_beta: 100.0 + bernoulli_eos_prob: 0.1 + bernoulli_bts_prob: 0.1 + - slurm: + job_name: pear + script: + env: + __value__: ccube + n_comp: 2 + beta_params_min: 0.1 + beta_params_max: 1000.0 + random_distr_params: + beta_weights: 1.0 + beta_alpha: 100.0 + beta_beta: 100.0 + bernoulli_eos_prob: 0.1 + bernoulli_bts_prob: 0.1 + - slurm: + job_name: avocado + script: + env: + __value__: ccube + n_comp: 2 + beta_params_min: 1 + beta_params_max: 1000.0 + random_distr_params: + beta_weights: 1.0 + beta_alpha: 100.0 + beta_beta: 100.0 + bernoulli_eos_prob: 0.1 + bernoulli_bts_prob: 0.1 diff --git a/config/experiments/ccube/hyperparams_search_20230920_batch4.yaml b/config/experiments/ccube/hyperparams_search_20230920_batch4.yaml new file mode 100644 index 000000000..cc82e322c --- /dev/null +++ b/config/experiments/ccube/hyperparams_search_20230920_batch4.yaml @@ -0,0 +1,229 @@ +# Shared config +shared: + slurm: {} + script: + user: $USER + device: cpu + logger: + project_name: cube + do: + online: True + test: + period: 500 + n: 900 + checkpoints: + period: 10000 + # Contiunuous Cube environment + env: + __value__: ccube + n_dim: 2 + # Buffer + buffer: + data_path: null + train: null + test: + type: grid + n: 1000 + output_csv: ccube_test.csv + output_pkl: ccube_test.pkl + # Proxy + proxy: corners + # GFlowNet config + gflownet: + __value__: trajectorybalance + random_action_prob: 0.1 + optimizer: + batch_size: + forward: 100 + lr: 0.0001 + z_dim: 16 + lr_z_mult: 100 + n_train_steps: 10000 + # Policy + +gflownet: + policy: + forward: + type: mlp + n_hid: 512 + n_layers: 5 + checkpoint: forward + # Use + to add new variables + +gflownet: + policy: + backward: + type: mlp + n_hid: 512 + n_layers: 5 + checkpoint: backward + shared_weights: False + +# Jobs +jobs: + - slurm: + job_name: papaya + script: + env: + __value__: ccube + n_comp: 2 + beta_params_min: 0.1 + beta_params_max: 100.0 + random_distr_params: + beta_weights: 1.0 + beta_alpha: 10.0 + beta_beta: 10.0 + bernoulli_eos_prob: 0.1 + bernoulli_bts_prob: 0.1 + - slurm: + job_name: mango + script: + env: + __value__: ccube + n_comp: 2 + beta_params_min: 0.1 + beta_params_max: 100.0 + random_distr_params: + beta_weights: 1.0 + beta_alpha: 10.0 + beta_beta: 10.0 + bernoulli_eos_prob: 0.1 + bernoulli_bts_prob: 0.1 + - slurm: + job_name: pineapple + script: + env: + __value__: ccube + n_comp: 2 + beta_params_min: 0.1 + beta_params_max: 100.0 + random_distr_params: + beta_weights: 1.0 + beta_alpha: 10.0 + beta_beta: 10.0 + bernoulli_eos_prob: 0.5 + bernoulli_bts_prob: 0.5 + - slurm: + job_name: apple + script: + env: + __value__: ccube + n_comp: 2 + beta_params_min: 0.1 + beta_params_max: 100.0 + random_distr_params: + beta_weights: 1.0 + beta_alpha: 10.0 + beta_beta: 10.0 + bernoulli_eos_prob: 0.5 + bernoulli_bts_prob: 0.5 + - slurm: + job_name: papaya + script: + env: + __value__: ccube + n_comp: 5 + beta_params_min: 0.1 + beta_params_max: 100.0 + random_distr_params: + beta_weights: 1.0 + beta_alpha: 10.0 + beta_beta: 10.0 + bernoulli_eos_prob: 0.1 + bernoulli_bts_prob: 0.1 + - slurm: + job_name: mango + script: + env: + __value__: ccube + n_comp: 5 + beta_params_min: 0.1 + beta_params_max: 100.0 + random_distr_params: + beta_weights: 1.0 + beta_alpha: 10.0 + beta_beta: 10.0 + bernoulli_eos_prob: 0.1 + bernoulli_bts_prob: 0.1 + - slurm: + job_name: pineapple + script: + env: + __value__: ccube + n_comp: 5 + beta_params_min: 0.1 + 
beta_params_max: 100.0 + random_distr_params: + beta_weights: 1.0 + beta_alpha: 10.0 + beta_beta: 10.0 + bernoulli_eos_prob: 0.5 + bernoulli_bts_prob: 0.5 + - slurm: + job_name: apple + script: + env: + __value__: ccube + n_comp: 5 + beta_params_min: 0.1 + beta_params_max: 100.0 + random_distr_params: + beta_weights: 1.0 + beta_alpha: 10.0 + beta_beta: 10.0 + bernoulli_eos_prob: 0.5 + bernoulli_bts_prob: 0.5 + - slurm: + job_name: papaya + script: + env: + __value__: ccube + n_comp: 1 + beta_params_min: 0.1 + beta_params_max: 100.0 + random_distr_params: + beta_weights: 1.0 + beta_alpha: 10.0 + beta_beta: 10.0 + bernoulli_eos_prob: 0.1 + bernoulli_bts_prob: 0.1 + - slurm: + job_name: mango + script: + env: + __value__: ccube + n_comp: 1 + beta_params_min: 0.1 + beta_params_max: 100.0 + random_distr_params: + beta_weights: 1.0 + beta_alpha: 10.0 + beta_beta: 10.0 + bernoulli_eos_prob: 0.1 + bernoulli_bts_prob: 0.1 + - slurm: + job_name: pineapple + script: + env: + __value__: ccube + n_comp: 1 + beta_params_min: 0.1 + beta_params_max: 100.0 + random_distr_params: + beta_weights: 1.0 + beta_alpha: 10.0 + beta_beta: 10.0 + bernoulli_eos_prob: 0.5 + bernoulli_bts_prob: 0.5 + - slurm: + job_name: apple + script: + env: + __value__: ccube + n_comp: 1 + beta_params_min: 0.1 + beta_params_max: 100.0 + random_distr_params: + beta_weights: 1.0 + beta_alpha: 10.0 + beta_beta: 10.0 + bernoulli_eos_prob: 0.5 + bernoulli_bts_prob: 0.5 diff --git a/config/experiments/ccube/uniform.yaml b/config/experiments/ccube/uniform.yaml index 6970a3e95..a81d58d05 100644 --- a/config/experiments/ccube/uniform.yaml +++ b/config/experiments/ccube/uniform.yaml @@ -40,18 +40,17 @@ gflownet: z_dim: 16 lr_z_mult: 100 n_train_steps: 10000 - policy: - forward: - type: mlp - n_hid: 256 - n_layers: 3 - checkpoint: forward - backward: - type: mlp - n_hid: 256 - n_layers: 3 - shared_weights: False - checkpoint: backward + +# Policy +policy: + forward: + type: mlp + n_hid: 128 + n_layers: 2 + checkpoint: forward + backward: + shared_weights: True + checkpoint: backward # WandB logger: diff --git a/config/experiments/clatticeparams/clatticeparams_owl.yaml b/config/experiments/clatticeparams/clatticeparams_owl.yaml new file mode 100644 index 000000000..30f1c1347 --- /dev/null +++ b/config/experiments/clatticeparams/clatticeparams_owl.yaml @@ -0,0 +1,80 @@ +# @package _global_ + +defaults: + - override /env: crystals/clattice_parameters + - override /gflownet: trajectorybalance + - override /proxy: corners + - override /logger: wandb + - override /user: alex + +# Environment +env: + # Lattice system + lattice_system: cubic + # Allowed ranges of size and angles + min_length: 1.0 + max_length: 5.0 + min_angle: 30.0 + max_angle: 150.0 + # Cube + n_comp: 5 + beta_params_min: 0.01 + beta_params_max: 100.0 + min_incr: 0.1 + fixed_distribution: + beta_weights: 1.0 + beta_alpha: 0.01 + beta_beta: 0.01 + bernoulli_source_logit: 1.0 + bernoulli_eos_logit: 1.0 + random_distribution: + beta_weights: 1.0 + beta_alpha: 0.01 + beta_beta: 0.01 + bernoulli_source_logit: 1.0 + bernoulli_eos_logit: 1.0 + reward_func: identity + +# GFlowNet hyperparameters +gflownet: + random_action_prob: 0.1 + optimizer: + batch_size: + forward: 100 + lr: 0.0001 + z_dim: 16 + lr_z_mult: 100 + n_train_steps: 10000 + +# Policy +policy: + forward: + type: mlp + n_hid: 512 + n_layers: 5 + checkpoint: forward + backward: + type: mlp + n_hid: 512 + n_layers: 5 + shared_weights: False + checkpoint: backward + +# WandB +logger: + lightweight: True + project_name: "GFlowNet 
Cube" + tags: + - gflownet + - continuous + - ccube + test: + period: 500 + n: 1000 + checkpoints: + period: 500 + +# Hydra +hydra: + run: + dir: ${user.logdir.root}/debug/ccube/${now:%Y-%m-%d_%H-%M-%S} diff --git a/config/experiments/crystals/albatross.yaml b/config/experiments/crystals/albatross.yaml new file mode 100644 index 000000000..db33a21c1 --- /dev/null +++ b/config/experiments/crystals/albatross.yaml @@ -0,0 +1,108 @@ +# @package _global_ + +defaults: + - override /env: crystals/ccrystal + - override /gflownet: trajectorybalance + - override /proxy: crystals/dave + - override /logger: wandb + +device: cpu + +# Environment +env: + do_composition_to_sg_constraints: True + do_sg_to_lp_constraints: True + composition_kwargs: + elements: [1, 3, 6, 7, 8, 9, 12, 14, 15, 16, 17, 26] + min_atoms: 2 + max_atoms: 50 + min_atom_i: 1 + max_atom_i: 16 + do_charge_check: True + space_group_kwargs: + space_groups_subset: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 18, 19, 20, 21, 25, 26, 29, 30, 31, 33, 36, 38, 40, 41, 43, 44, 46, 47, 51, 52, 53, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 69, 70, 71, 72, 74, 82, 84, 85, 86, 87, 88, 92, 99, 102, 107, 113, 114, 121, 122, 123, 126, 129, 131, 136, 137, 138, 139, 140, 141, 146, 147, 148, 150, 155, 156, 160, 161, 162, 163, 164, 166, 167, 176, 181, 185, 186, 187, 189, 192, 194, 198, 199, 205, 206, 216, 217, 220, 221, 224, 225, 227, 229, 230] + lattice_parameters_kwargs: + min_length: 0.9 + max_length: 100.0 + min_angle: 50.0 + max_angle: 150.0 + n_comp: 5 + beta_params_min: 0.1 + beta_params_max: 100.0 + min_incr: 0.1 + fixed_distr_params: + beta_weights: 1.0 + beta_alpha: 10.0 + beta_beta: 10.0 + bernoulli_eos_prob: 0.1 + bernoulli_bts_prob: 0.1 + random_distr_params: + beta_weights: 1.0 + beta_alpha: 10.0 + beta_beta: 10.0 + bernoulli_eos_prob: 0.1 + bernoulli_bts_prob: 0.1 + reward_func: boltzmann + reward_beta: 8 + buffer: + replay_capacity: 0 + test: + type: pkl + path: /home/mila/h/hernanga/gflownet/data/crystals/matbench_normed_l0.9-100_a50-150_val_12_SGinter_states_energy.pkl + output_csv: ccrystal_val.csv + output_pkl: ccrystal_val.pkl + +# GFlowNet hyperparameters +gflownet: + random_action_prob: 0.1 + optimizer: + batch_size: + forward: 10 + backward_replay: -1 + lr: 0.0001 + z_dim: 16 + lr_z_mult: 100 + n_train_steps: 50000 + lr_decay_period: 1000000 + replay_sampling: weighted + +# Policy +policy: + forward: + type: mlp + n_hid: 256 + n_layers: 3 + checkpoint: forward + backward: + type: mlp + n_hid: 256 + n_layers: 3 + shared_weights: False + checkpoint: backward + +# WandB +logger: + lightweight: True + project_name: "crystal-gfn" + tags: + - gflownet + - crystals + - matbench + - workshop23 + checkpoints: + period: 500 + do: + online: true + test: + n_trajs_logprobs: 10 + period: 500 + n: 10 + n_top_k: 5000 + top_k: 100 + top_k_period: -1 + +# Hydra +hydra: + run: + dir: ${user.logdir.root}/workshop23/discrete-matbench/${oc.env:SLURM_JOB_ID,local}/${now:%Y-%m-%d_%H-%M-%S} + diff --git a/config/experiments/crystals/pigeon.yaml b/config/experiments/crystals/pigeon.yaml new file mode 100644 index 000000000..880647ce1 --- /dev/null +++ b/config/experiments/crystals/pigeon.yaml @@ -0,0 +1,112 @@ +# @package _global_ +# No constraints: +# - no charge check +# - no composition to space group constraint +# - no space group to lattice parameters constraint + +defaults: + - override /env: crystals/ccrystal + - override /gflownet: trajectorybalance + - override /proxy: crystals/dave + - override /logger: wandb + +device: cpu 
+ +# Environment +env: + do_composition_to_sg_constraints: False + do_sg_to_lp_constraints: False + composition_kwargs: + elements: [1, 3, 6, 7, 8, 9, 12, 14, 15, 16, 17, 26] + min_atoms: 2 + max_atoms: 50 + min_atom_i: 1 + max_atom_i: 16 + do_charge_check: False + space_group_kwargs: + space_groups_subset: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 18, 19, 20, 21, 25, 26, 29, 30, 31, 33, 36, 38, 40, 41, 43, 44, 46, 47, 51, 52, 53, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 69, 70, 71, 72, 74, 82, 84, 85, 86, 87, 88, 92, 99, 102, 107, 113, 114, 121, 122, 123, 126, 129, 131, 136, 137, 138, 139, 140, 141, 146, 147, 148, 150, 155, 156, 160, 161, 162, 163, 164, 166, 167, 176, 181, 185, 186, 187, 189, 192, 194, 198, 199, 205, 206, 216, 217, 220, 221, 224, 225, 227, 229, 230] + lattice_parameters_kwargs: + min_length: 0.9 + max_length: 100.0 + min_angle: 50.0 + max_angle: 150.0 + n_comp: 5 + beta_params_min: 0.1 + beta_params_max: 100.0 + min_incr: 0.1 + fixed_distr_params: + beta_weights: 1.0 + beta_alpha: 10.0 + beta_beta: 10.0 + bernoulli_eos_prob: 0.1 + bernoulli_bts_prob: 0.1 + random_distr_params: + beta_weights: 1.0 + beta_alpha: 10.0 + beta_beta: 10.0 + bernoulli_eos_prob: 0.1 + bernoulli_bts_prob: 0.1 + reward_func: boltzmann + reward_beta: 8 + buffer: + replay_capacity: 0 + test: + type: pkl + path: /home/mila/h/hernanga/gflownet/data/crystals/matbench_normed_l0.9-100_a50-150_val_12_SGinter_states_energy.pkl + output_csv: ccrystal_val.csv + output_pkl: ccrystal_val.pkl + +# GFlowNet hyperparameters +gflownet: + random_action_prob: 0.1 + optimizer: + batch_size: + forward: 10 + backward_replay: -1 + lr: 0.0001 + z_dim: 16 + lr_z_mult: 100 + n_train_steps: 50000 + lr_decay_period: 1000000 + replay_sampling: weighted + +# Policy +policy: + forward: + type: mlp + n_hid: 256 + n_layers: 3 + checkpoint: forward + backward: + type: mlp + n_hid: 256 + n_layers: 3 + shared_weights: False + checkpoint: backward + +# WandB +logger: + lightweight: True + project_name: "crystal-gfn" + tags: + - gflownet + - crystals + - matbench + - workshop23 + checkpoints: + period: 500 + do: + online: true + test: + n_trajs_logprobs: 10 + period: 500 + n: 10 + n_top_k: 5000 + top_k: 100 + top_k_period: -1 + +# Hydra +hydra: + run: + dir: ${user.logdir.root}/workshop23/discrete-matbench/${oc.env:SLURM_JOB_ID,local}/${now:%Y-%m-%d_%H-%M-%S} + diff --git a/config/experiments/workshop23/discrete-matbench.yaml b/config/experiments/workshop23/discrete-matbench.yaml new file mode 100644 index 000000000..a68b0c367 --- /dev/null +++ b/config/experiments/workshop23/discrete-matbench.yaml @@ -0,0 +1,77 @@ +# @package _global_ + +defaults: + - override /env: crystals/crystal + - override /gflownet: trajectorybalance + - override /proxy: crystals/dave + - override /logger: wandb + +device: cpu + +# Environment +env: + lattice_parameters_kwargs: + min_length: 1.0 + max_length: 350.0 + min_angle: 50.0 + max_angle: 150.0 + grid_size: 10 + composition_kwargs: + elements: [1,3,4,5,6,7,8,9,11,12,13,14,15,16,17,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,89,90,91,92,93,94] + reward_func: boltzmann + reward_beta: 1 + buffer: + replay_capacity: 0 + +# GFlowNet hyperparameters +gflownet: + random_action_prob: 0.1 + optimizer: + batch_size: + forward: 10 + backward_replay: -1 + lr: 0.001 + z_dim: 16 + lr_z_mult: 100 + n_train_steps: 10000 + lr_decay_period: 1000000 + 
replay_sampling: weighted + +# Policy +policy: + forward: + type: mlp + n_hid: 512 + n_layers: 5 + checkpoint: forward + backward: + type: mlp + n_hid: 512 + n_layers: 5 + shared_weights: False + checkpoint: backward + +# WandB +logger: + lightweight: True + project_name: "crystal-gfn" + tags: + - gflownet + - crystals + - matbench + - workshop23 + checkpoints: + period: 500 + do: + online: true + test: + period: -1 + n: 500 + n_top_k: 5000 + top_k: 100 + top_k_period: -1 + +# Hydra +hydra: + run: + dir: ${user.logdir.root}/workshop23/discrete-matbench/${oc.env:SLURM_JOB_ID,local}/${now:%Y-%m-%d_%H-%M-%S} diff --git a/config/logger/base.yaml b/config/logger/base.yaml index 64dab5d05..7707b75c5 100644 --- a/config/logger/base.yaml +++ b/config/logger/base.yaml @@ -45,3 +45,4 @@ debug: False lightweight: False progress: True context: "0" +notes: null # wandb run notes (e.g. "baseline") diff --git a/config/proxy/crystals/dave.yaml b/config/proxy/crystals/dave.yaml index 20865c20c..d61c938fa 100644 --- a/config/proxy/crystals/dave.yaml +++ b/config/proxy/crystals/dave.yaml @@ -1,7 +1,13 @@ _target_: gflownet.proxy.crystals.dave.DAVE -release: 0.3.2 +release: 0.3.4 ckpt_path: mila: /network/scratch/s/schmidtv/crystals-proxys/proxy-ckpts/ victor: ~/Documents/Github/ActiveLearningMaterials/checkpoints/980065c0/checkpoints-3ff648a2 rescale_outputs: true +clip: + do: False + min_stds: null + max_stds: null + min: null + max: null diff --git a/data/crystals/matbench_val_12_SGinter_states_energy.pkl b/data/crystals/matbench_val_12_SGinter_states_energy.pkl new file mode 100644 index 000000000..456914f24 Binary files /dev/null and b/data/crystals/matbench_val_12_SGinter_states_energy.pkl differ diff --git a/gflownet/envs/base.py b/gflownet/envs/base.py index 9b5e81f3a..278490adc 100644 --- a/gflownet/envs/base.py +++ b/gflownet/envs/base.py @@ -90,6 +90,7 @@ def __init__( self.action_space, device=self.device, dtype=self.float ) self.action_space_dim = len(self.action_space) + self.mask_dim = self.action_space_dim # Max trajectory length self.max_traj_length = self.get_max_traj_length() # Policy outputs @@ -472,8 +473,14 @@ def sample_actions_batch( if sampling_method == "uniform": logits = torch.ones(policy_outputs.shape, dtype=self.float, device=device) elif sampling_method == "policy": - logits = policy_outputs + logits = policy_outputs.clone().detach() logits /= temperature_logits + else: + raise NotImplementedError( + f"Sampling method {sampling_method} is invalid. " + "Options are: policy, uniform." + ) + if mask is not None: assert not torch.all(mask), dedent( """ @@ -539,7 +546,7 @@ def get_logprobs( """ device = policy_outputs.device ns_range = torch.arange(policy_outputs.shape[0]).to(device) - logits = policy_outputs + logits = policy_outputs.clone() if mask is not None: logits[mask] = -torch.inf action_indices = ( @@ -607,7 +614,7 @@ def trajectory_random(self): The list of actions (tuples) in the trajectory. 
""" actions = [] - while self.done is not True: + while not self.done: _, action, valid = self.step_random() if valid: actions.append(action) @@ -663,7 +670,9 @@ def get_random_terminating_states( count += 1 return states - def get_policy_output(self, params: Optional[dict] = None): + def get_policy_output( + self, params: Optional[dict] = None + ) -> TensorType["policy_output_dim"]: """ Defines the structure of the output of the policy model, from which an action is to be determined or sampled, by returning a vector with a fixed @@ -672,7 +681,7 @@ def get_policy_output(self, params: Optional[dict] = None): Continuous environments will generally have to overwrite this method. """ - return np.ones(self.action_space_dim) + return torch.ones(self.action_space_dim, dtype=self.float, device=self.device) def state2proxy(self, state: List = None): """ diff --git a/gflownet/envs/crystals/ccrystal.py b/gflownet/envs/crystals/ccrystal.py new file mode 100644 index 000000000..19830339f --- /dev/null +++ b/gflownet/envs/crystals/ccrystal.py @@ -0,0 +1,923 @@ +import json +from collections import OrderedDict +from enum import Enum +from typing import Dict, List, Optional, Tuple, Union + +import torch +from torch import Tensor +from torchtyping import TensorType + +from gflownet.envs.base import GFlowNetEnv +from gflownet.envs.crystals.clattice_parameters import CLatticeParameters +from gflownet.envs.crystals.composition import Composition +from gflownet.envs.crystals.spacegroup import SpaceGroup +from gflownet.utils.common import copy, tbool, tfloat, tlong +from gflownet.utils.crystals.constants import TRICLINIC + + +class Stage(Enum): + """ + In addition to encoding current stage, contains methods used for padding individual + component environment's actions (to ensure they have the same length for + tensorization). + """ + + COMPOSITION = 0 + SPACE_GROUP = 1 + LATTICE_PARAMETERS = 2 + DONE = 3 + + def next(self) -> "Stage": + """ + Returns the next Stage in the enumeration or None if at the last stage. + """ + if self.value + 1 == len(Stage): + return None + return Stage(self.value + 1) + + def prev(self) -> "Stage": + """ + Returns the previous Stage in the enumeration or DONE if from the first stage. + """ + if self.value - 1 < 0: + return Stage.DONE + return Stage(self.value - 1) + + def to_pad(self) -> int: + """ + Maps stage value to a padding. The following mapping is used: + + COMPOSITION = -2 + SPACE_GROUP = -3 + LATTICE_PARAMETERS = -4 + + We use negative numbers starting from -2 because they are not used by any of + the underlying environments, which should lead to every padded action being + unique. + """ + return -(self.value + 2) + + @classmethod + def from_pad(cls, pad_value: int) -> "Stage": + return Stage(-pad_value - 2) + + +class CCrystal(GFlowNetEnv): + """ + A combination of Composition, SpaceGroup and CLatticeParameters into a single + environment. Works sequentially, by first filling in the Composition, then + SpaceGroup, and finally CLatticeParameters. 
+ """ + + def __init__( + self, + composition_kwargs: Optional[Dict] = None, + space_group_kwargs: Optional[Dict] = None, + lattice_parameters_kwargs: Optional[Dict] = None, + do_composition_to_sg_constraints: bool = True, + do_sg_to_lp_constraints: bool = True, + **kwargs, + ): + self.composition_kwargs = composition_kwargs or {} + self.space_group_kwargs = space_group_kwargs or {} + self.lattice_parameters_kwargs = lattice_parameters_kwargs or {} + self.do_composition_to_sg_constraints = do_composition_to_sg_constraints + self.do_sg_to_lp_constraints = do_sg_to_lp_constraints + + composition = Composition(**self.composition_kwargs) + space_group = SpaceGroup(**self.space_group_kwargs) + # We initialize lattice parameters with triclinic lattice system as it is the + # most general one, but it will have to be reinitialized using proper lattice + # system from space group once that is determined. + lattice_parameters = CLatticeParameters( + lattice_system=TRICLINIC, **self.lattice_parameters_kwargs + ) + self.subenvs = OrderedDict( + { + Stage.COMPOSITION: composition, + Stage.SPACE_GROUP: space_group, + Stage.LATTICE_PARAMETERS: lattice_parameters, + } + ) + + # 0-th element of state encodes current stage: 0 for composition, + # 1 for space group, 2 for lattice parameters + self.source = [Stage.COMPOSITION.value] + for subenv in self.subenvs.values(): + self.source.extend(subenv.source) + + # Get action dimensionality by computing the maximum action length among all + # sub-environments. + self.max_action_length = max( + [len(subenv.eos) for subenv in self.subenvs.values()] + ) + + # EOS is EOS of the last stage (lattice parameters) + self.eos = self._pad_action( + self.subenvs[Stage.LATTICE_PARAMETERS].eos, Stage.LATTICE_PARAMETERS + ) + + # Mask dimensionality + self.mask_dim = sum([subenv.mask_dim for subenv in self.subenvs.values()]) + + # Base class init + # Since only the lattice parameters subenv has distribution parameters, only + # these are pased to the base init. + super().__init__( + fixed_distr_params=self.subenvs[ + Stage.LATTICE_PARAMETERS + ].fixed_distr_params, + random_distr_params=self.subenvs[ + Stage.LATTICE_PARAMETERS + ].random_distr_params, + **kwargs, + ) + self.continuous = True + + # TODO: remove or redo + def _set_lattice_parameters(self): + """ + Sets CLatticeParameters conditioned on the lattice system derived from the + SpaceGroup. + """ + if self.subenvs[Stage.SPACE_GROUP].lattice_system == "None": + raise ValueError( + "Cannot set lattice parameters without lattice system determined in " + "the space group." + ) + self.subenvs[Stage.LATTICE_PARAMETERS] = CLatticeParameters( + lattice_system=self.subenvs[Stage.SPACE_GROUP].lattice_system, + **self.lattice_parameters_kwargs, + ) + + def _pad_action(self, action: Tuple[int], stage: Stage) -> Tuple[int]: + """ + Pads action such that all actions, regardless of the underlying environment, + have the same length. Required due to the fact that action space has to be + convertable to a tensor. + """ + return action + (Stage.to_pad(stage),) * (self.max_action_length - len(action)) + + def _pad_action_space( + self, action_space: List[Tuple[int]], stage: Stage + ) -> List[Tuple[int]]: + return [self._pad_action(a, stage) for a in action_space] + + def _depad_action(self, action: Tuple[int], stage: Stage) -> Tuple[int]: + """ + Reverses padding operation, such that the resulting action can be passed to the + underlying environment. 
+ """ + return action[: len(self.subenvs[stage].eos)] + + # TODO: consider removing if unused because too simple + def _get_actions_of_subenv( + self, actions: TensorType["n_states", "action_dim"], stage: Stage + ): + """ + Returns the columns of a tensor of actions that correspond to the + sub-environment indicated by stage. + + Args + actions + mask : tensor + A tensor containing a batch of actions. It is assumed that all the rows in + the this tensor correspond to the same stage. + + stage : Stage + Identifier of the sub-environment of which the corresponding columns of the + actions are to be extracted. + """ + return actions[:, len(self.subenvs[stage].eos)] + + def get_action_space(self) -> List[Tuple[int]]: + action_space = [] + for stage, subenv in self.subenvs.items(): + action_space.extend(self._pad_action_space(subenv.action_space, stage)) + + if len(action_space) != len(set(action_space)): + raise ValueError( + "Detected duplicate actions between different components of Crystal " + "environment." + ) + + return action_space + + def action2representative(self, action: Tuple) -> Tuple: + """ + Replaces the continuous values of lattice parameters actions by the + representative action of the environment so that it can be compared against the + action space. + """ + if self._get_stage() == Stage.LATTICE_PARAMETERS: + return self.subenvs[Stage.LATTICE_PARAMETERS].action2representative( + self._depad_action(action, Stage.LATTICE_PARAMETERS) + ) + return action + + def get_max_traj_length(self) -> int: + return sum([subenv.get_max_traj_length() for subenv in self.subenvs.values()]) + + def get_policy_output(self, params: dict) -> TensorType["policy_output_dim"]: + """ + Defines the structure of the output of the policy model. + + The policy output is in this case the concatenation of the policy outputs of + the three sub-environments. + """ + return torch.cat( + [subenv.get_policy_output(params) for subenv in self.subenvs.values()] + ) + + def _get_policy_outputs_of_subenv( + self, policy_outputs: TensorType["n_states", "policy_output_dim"], stage: Stage + ): + """ + Returns the columns of the policy outputs that correspond to the + sub-environment indicated by stage. + + Args + ---- + policy_outputs : tensor + A tensor containing a batch of policy outputs. It is assumed that all the + rows in the this tensor correspond to the same stage. + + stage : Stage + Identifier of the sub-environment of which the corresponding columns of the + policy outputs are to be extracted. + """ + init_col = 0 + for stg, subenv in self.subenvs.items(): + end_col = init_col + subenv.policy_output_dim + if stg == stage: + return policy_outputs[:, init_col:end_col] + init_col = end_col + + def _get_mask_of_subenv( + self, mask: Union[List, TensorType["n_states", "mask_dim"]], stage: Stage + ): + """ + Returns the columns of a tensor of masks that correspond to the sub-environment + indicated by stage. + + Args + ---- + mask : list or tensor + A mask of a single state as a list or a tensor containing a batch of masks. + It is assumed that all the rows in the this tensor correspond to the same + stage. + + stage : Stage + Identifier of the sub-environment of which the corresponding columns of the + masks are to be extracted. 
+ """ + init_col = 0 + for stg, subenv in self.subenvs.items(): + end_col = init_col + subenv.mask_dim + if stg == stage: + if isinstance(mask, list): + return mask[init_col:end_col] + else: + return mask[:, init_col:end_col] + init_col = end_col + + def reset(self, env_id: Union[int, str] = None): + self.subenvs[Stage.COMPOSITION].reset() + self.subenvs[Stage.SPACE_GROUP].reset() + self.subenvs[Stage.LATTICE_PARAMETERS] = CLatticeParameters( + lattice_system=TRICLINIC, **self.lattice_parameters_kwargs + ) + + super().reset(env_id=env_id) + self._set_stage(Stage.COMPOSITION) + + return self + + def _get_stage(self, state: Optional[List] = None) -> Stage: + """ + Returns the stage of the current environment from self.state[0] or from the + state passed as an argument. + """ + if state is None: + state = self.state + return Stage(state[0]) + + def _set_stage(self, stage: Stage, state: Optional[List] = None): + """ + Sets the stage of the current environment (self.state) or of the state passed + as an argument by updating state[0]. + """ + if state is None: + state = self.state + state[0] = stage.value + + def _get_policy_states_of_subenv( + self, state: TensorType["n_states", "state_dim"], stage: Stage + ): + """ + Returns the part of the states corresponding to the subenv indicated by stage. + + Args + ---- + states : tensor + A tensor containing a batch of states in policy format. + + stage : Stage + Identifier of the sub-environment of which the corresponding columns of the + batch of states are to be extracted. + """ + init_col = 0 + for stg, subenv in self.subenvs.items(): + end_col = init_col + subenv.policy_input_dim + if stg == stage: + return states[:, init_col:end_col] + init_col = end_col + + def _get_state_of_subenv(self, state: List, stage: Optional[Stage] = None): + """ + Returns the part of the state corresponding to the subenv indicated by stage. + + Args + ---- + state : list + A state of the parent Crystal environment. + + stage : Stage + Identifier of the sub-environment of which the corresponding part of the + state is to be extracted. If None, it is inferred from the state. + """ + if stage is None: + stage = self._get_stage(state) + init_col = 1 + for stg, subenv in self.subenvs.items(): + end_col = init_col + len(subenv.source) + if stg == stage: + return state[init_col:end_col] + init_col = end_col + + def _get_states_of_subenv( + self, states: TensorType["n_states", "state_dim"], stage: Stage + ): + """ + Returns the part of the batch of states corresponding to the subenv indicated + by stage. + + Args + ---- + states : tensor + A batch of states of the parent Crystal environment. + + stage : Stage + Identifier of the sub-environment of which the corresponding part of the + states is to be extracted. If None, it is inferred from the states. + """ + init_col = 1 + for stg, subenv in self.subenvs.items(): + end_col = init_col + len(subenv.source) + if stg == stage: + return states[:, init_col:end_col] + init_col = end_col + + # TODO: set mask of done state if stage is not the current one for correctness. + def get_mask_invalid_actions_forward( + self, state: Optional[List[int]] = None, done: Optional[bool] = None + ) -> List[bool]: + """ + Computes the forward actions mask of the state. + + The mask of the parent crystal is simply the concatenation of the masks of the + three sub-environments. This assumes that the methods that will use the mask + will extract the part corresponding to the relevant stage and ignore the rest. 
+ """ + state = self._get_state(state) + done = self._get_done(done) + + mask = [] + for stage, subenv in self.subenvs.items(): + mask.extend( + subenv.get_mask_invalid_actions_forward( + self._get_state_of_subenv(state, stage), done + ) + ) + return mask + + # TODO: this piece of code looks awful + def get_mask_invalid_actions_backward( + self, state: Optional[List[int]] = None, done: Optional[bool] = None + ) -> List[bool]: + """ + Computes the backward actions mask of the state. + + The mask of the parent crystal is, in general, simply the concatenation of the + masks of the three sub-environments. Only the mask of the state of the current + sub-environment is computed; for the other sub-environments, the mask of the + source is used. Note that this assumes that the methods that will use the mask + will extract the part corresponding to the relevant stage and ignore the rest. + + Nonetheless, in order to enable backward transitions between stages, the EOS + action of the preceding stage has to be the only valid action when the state of + a sub-environment is the source. Additionally, sample_batch_actions will have + to also detect the source states and change the stage. + + Note that the sub-environments are iterated in reversed order so as to save + unnecessary computations and simplify the code. + """ + state = self._get_state(state) + done = self._get_done(done) + stage = self._get_stage(state) + + mask = [] + do_eos_only = False + # Iterate stages in reverse order + for stg, subenv in reversed(self.subenvs.items()): + state_subenv = self._get_state_of_subenv(state, stg) + # Set mask of done state because state of next subenv is source + if do_eos_only: + mask_subenv = subenv.get_mask_invalid_actions_backward( + state_subenv, done=True + ) + do_eos_only = False + # General case + else: + # stg is the current stage + if stg == stage: + # state of subenv is the source state + if stg != Stage(0) and state_subenv == subenv.source: + do_eos_only = True + mask_subenv = subenv.get_mask_invalid_actions_backward( + subenv.source + ) + # General case + else: + mask_subenv = subenv.get_mask_invalid_actions_backward( + state_subenv, done + ) + # stg is not current stage, so set mask of source + else: + mask_subenv = subenv.get_mask_invalid_actions_backward( + subenv.source + ) + mask.extend(mask_subenv[::-1]) + return mask[::-1] + + def _update_state(self, stage: Stage): + """ + Updates the global state based on the states of the sub-environments and the + stage passed as an argument. + """ + state = [stage.value] + for subenv in self.subenvs.values(): + state.extend(subenv.state) + return state + + def step( + self, action: Tuple[int], skip_mask_check: bool = False + ) -> Tuple[List[int], Tuple[int], bool]: + """ + Executes forward step given an action. + + The action is performed by the corresponding sub-environment and then the + global state is updated accordingly. If the action is the EOS of the + sub-environment, the stage is advanced and constraints are set on the + subsequent sub-environment. + + Args + ---- + action : tuple + Action to be executed. The input action is global, that is padded. + + Returns + ------- + self.state : list + The state after executing the action. + + action : int + Action executed. + + valid : bool + False, if the action is not allowed for the current state. True otherwise. 
+ """ + stage = self._get_stage(self.state) + # Skip mask check if stage is lattice parameters (continuous actions) + if stage == Stage.LATTICE_PARAMETERS: + skip_mask_check = True + # Replace action by its representative to check against the mask. + action_to_check = self.action2representative(action) + do_step, self.state, action_to_check = self._pre_step( + action_to_check, + skip_mask_check=(skip_mask_check or self.skip_mask_check), + ) + if not do_step: + return self.state, action, False + + # Call step of current subenvironment + action_subenv = self._depad_action(action, stage) + _, action_subenv, valid = self.subenvs[stage].step(action_subenv) + + # If action is invalid, exit immediately. Otherwise increment actions and go on + if not valid: + return self.state, action, False + self.n_actions += 1 + + # If action is EOS of subenv, advance stage and set constraints or exit + if action_subenv == self.subenvs[stage].eos: + stage = Stage.next(stage) + if stage == Stage.SPACE_GROUP: + if self.do_composition_to_sg_constraints: + self.subenvs[Stage.SPACE_GROUP].set_n_atoms_compatibility_dict( + self.subenvs[Stage.COMPOSITION].state + ) + elif stage == Stage.LATTICE_PARAMETERS: + if self.do_sg_to_lp_constraints: + lattice_system = self.subenvs[Stage.SPACE_GROUP].lattice_system + self.subenvs[Stage.LATTICE_PARAMETERS].set_lattice_system( + lattice_system + ) + elif stage == Stage.DONE: + self.n_actions += 1 + self.done = True + return self.state, self.eos, True + else: + raise ValueError(f"Unrecognized stage {stage}.") + + self.state = self._update_state(stage) + return self.state, action, valid + + def step_backwards( + self, action: Tuple[int], skip_mask_check: bool = False + ) -> Tuple[List[int], Tuple[int], bool]: + """ + Executes backward step given an action. + + The action is performed by the corresponding sub-environment and then the + global state is updated accordingly. If the updated state of the + sub-environment becomes its source, the stage is decreased. + + Args + ---- + action : tuple + Action to be executed. The input action is global, that is padded. + + Returns + ------- + self.state : list + The state after executing the action. + + action : int + Action executed. + + valid : bool + False, if the action is not allowed for the current state. True otherwise. + """ + stage = self._get_stage(self.state) + # Skip mask check if stage is lattice parameters (continuous actions) + if stage == Stage.LATTICE_PARAMETERS: + skip_mask_check = True + # Replace action by its representative to check against the mask. + action_to_check = self.action2representative(action) + do_step, self.state, action_to_check = self._pre_step( + action_to_check, + backward=True, + skip_mask_check=(skip_mask_check or self.skip_mask_check), + ) + if not do_step: + return self.state, action, False + + # If state of subenv is source of subenv, decrease stage + if self._get_state_of_subenv(self.state, stage) == self.subenvs[stage].source: + stage = Stage.prev(stage) + # If stage is DONE, set global source and return + if stage == Stage.DONE: + self.state = self.source + return self.state, action, True + + # Call step of current subenvironment + action_subenv = self._depad_action(action, stage) + state_next, _, valid = self.subenvs[stage].step_backwards(action_subenv) + + # If action is invalid, exit immediately. 
Otherwise continue, + if not valid: + return self.state, action, False + self.n_actions += 1 + + # If action from done, set done False + if self.done: + assert action == self.eos + self.done = False + + self.state = self._update_state(stage) + return self.state, action, valid + + def sample_actions_batch( + self, + policy_outputs: TensorType["n_states", "policy_output_dim"], + mask: Optional[TensorType["n_states", "policy_output_dim"]] = None, + states_from: List = None, + is_backward: Optional[bool] = False, + sampling_method: Optional[str] = "policy", + temperature_logits: Optional[float] = 1.0, + max_sampling_attempts: Optional[int] = 10, + ) -> Tuple[List[Tuple], TensorType["n_states"]]: + """ + Samples a batch of actions from a batch of policy outputs. + + This method calls the sample_actions_batch() method of the sub-environment + corresponding to each state in the batch. For composition and space_group it + will be the method from the base discrete environment; for the lattice + parameters, it will be the method from the cube environment. + + Note that in order to call sample_actions_batch() of the sub-environments, we + need to first extract the part of the policy outputs, the masks and the states + that correspond to the sub-environment. + """ + states_dict = {stage: [] for stage in Stage} + """ + A dictionary with keys equal to Stage and the values are the list of states in + the stage of the key. The states are only the part corresponding to the + sub-environment. + """ + stages = [] + for s in states_from: + stage = self._get_stage(s) + state_subenv = self._get_state_of_subenv(s, stage) + # If the actions are backwards and state is source of subenv, decrease + # stage so that EOS of preceding stage is sampled. + if ( + is_backward + and stage != Stage(0) + and state_subenv == self.subenvs[stage].source + ): + stage = Stage.prev(stage) + states_dict[stage].append(state_subenv) + stages.append(stage) + stages_tensor = tlong([stage.value for stage in stages], device=self.device) + is_subenv_dict = {stage: stages_tensor == stage.value for stage in Stage} + + # Sample actions from each sub-environment + actions_logprobs_dict = { + stage: subenv.sample_actions_batch( + self._get_policy_outputs_of_subenv( + policy_outputs[is_subenv_dict[stage]], stage + ), + self._get_mask_of_subenv(mask[is_subenv_dict[stage]], stage), + states_dict[stage], + is_backward, + sampling_method, + temperature_logits, + max_sampling_attempts, + ) + for stage, subenv in self.subenvs.items() + if torch.any(is_subenv_dict[stage]) + } + + # Stitch all actions in the right order, with the right padding + actions = [] + for stage in stages: + actions.append( + self._pad_action(actions_logprobs_dict[stage][0].pop(0), stage) + ) + return actions, None + + def get_logprobs( + self, + policy_outputs: TensorType["n_states", "policy_output_dim"], + actions: TensorType["n_states", "actions_dim"], + mask: TensorType["n_states", "mask_dim"], + states_from: List, + is_backward: bool, + ) -> TensorType["batch_size"]: + """ + Computes log probabilities of actions given policy outputs and actions. + + Args + ---- + policy_outputs : tensor + The output of the GFlowNet policy model. + + mask : tensor + The mask containing information about invalid actions and special cases. + + actions : tensor + The actions (global) from each state in the batch for which to compute the + log probability. + + states_from : tensor + The states originating the actions, in GFlowNet format. 
+ + is_backward : bool + True if the actions are backward, False if the actions are forward + (default). + """ + n_states = policy_outputs.shape[0] + states_dict = {stage: [] for stage in Stage} + """ + A dictionary with keys equal to Stage and the values are the list of states in + the stage of the key. The states are only the part corresponding to the + sub-environment. + """ + stages = [] + for s in states_from: + stage = self._get_stage(s) + state_subenv = self._get_state_of_subenv(s, stage) + # If the actions are backwards and state is source of subenv, decrease + # stage so that EOS of preceding stage is sampled. + if ( + is_backward + and stage != Stage(0) + and state_subenv == self.subenvs[stage].source + ): + stage = Stage.prev(stage) + states_dict[stage].append(state_subenv) + stages.append(stage) + stages_tensor = tlong([stage.value for stage in stages], device=self.device) + is_subenv_dict = {stage: stages_tensor == stage.value for stage in Stage} + + # Compute logprobs from each sub-environment + logprobs = torch.empty(n_states, dtype=self.float, device=self.device) + for stage, subenv in self.subenvs.items(): + if not torch.any(is_subenv_dict[stage]): + continue + logprobs[is_subenv_dict[stage]] = subenv.get_logprobs( + self._get_policy_outputs_of_subenv( + policy_outputs[is_subenv_dict[stage]], stage + ), + actions[is_subenv_dict[stage], : len(subenv.eos)], + self._get_mask_of_subenv(mask[is_subenv_dict[stage]], stage), + states_dict[stage], + is_backward, + ) + return logprobs + + def state2policy(self, state: Optional[List[int]] = None) -> Tensor: + """ + Prepares one state in "GFlowNet format" for the policy. Simply + a concatenation of all crystal components. + """ + state = self._get_state(state) + return self.statetorch2policy( + torch.unsqueeze(tfloat(state, device=self.device, float_type=self.float), 0) + )[0] + + def statebatch2policy( + self, states: List[List] + ) -> TensorType["batch", "state_policy_dim"]: + """ + Prepares a batch of states in "GFlowNet format" for the policy. Simply + a concatenation of all crystal components. + """ + return self.statetorch2policy( + tfloat(states, device=self.device, float_type=self.float) + ) + + def statetorch2policy( + self, states: TensorType["batch", "state_dim"] + ) -> TensorType["batch", "state_policy_dim"]: + """ + Prepares a tensor batch of states in "GFlowNet format" for the policy. Simply + a concatenation of all crystal components. + """ + return torch.cat( + [ + subenv.statetorch2policy(self._get_states_of_subenv(states, stage)) + for stage, subenv in self.subenvs.items() + ], + dim=1, + ) + + def state2oracle(self, state: Optional[List[int]] = None) -> Tensor: + """ + Prepares one state in "GFlowNet format" for the oracle. Simply + a concatenation of all crystal components. + """ + state = self._get_state(state) + return self.statetorch2oracle( + torch.unsqueeze(tfloat(state, device=self.device, float_type=self.float), 0) + ) + + def statebatch2oracle( + self, states: List[List] + ) -> TensorType["batch", "state_oracle_dim"]: + """ + Prepares a batch of states in "GFlowNet format" for the oracle. Simply + a concatenation of all crystal components. + """ + return self.statetorch2oracle( + tfloat(states, device=self.device, float_type=self.float) + ) + + def statetorch2oracle( + self, states: TensorType["batch", "state_dim"] + ) -> TensorType["batch", "state_oracle_dim"]: + """ + Prepares one state in "GFlowNet format" for the oracle. Simply + a concatenation of all crystal components. 
+ """ + return torch.cat( + [ + subenv.statetorch2oracle(self._get_states_of_subenv(states, stage)) + for stage, subenv in self.subenvs.items() + ], + dim=1, + ) + + def state2proxy(self, state: Optional[List[int]] = None) -> Tensor: + """ + Returns state2oracle(state). + """ + return self.state2oracle(state) + + def statebatch2proxy( + self, states: List[List] + ) -> TensorType["batch", "state_oracle_dim"]: + """ + Returns statebatch2oracle(states). + """ + return self.statebatch2oracle(states) + + def statetorch2proxy( + self, states: TensorType["batch", "state_dim"] + ) -> TensorType["batch", "state_oracle_dim"]: + """ + Returns statetorch2oracle(states). + """ + return self.statetorch2oracle(states) + + def set_state(self, state: List, done: Optional[bool] = False): + super().set_state(state, done) + + stage_idx = self._get_stage(state).value + + # Determine which subenvs are done based on stage and done + done_subenvs = [True] * stage_idx + [False] * (len(self.subenvs) - stage_idx) + done_subenvs[-1] = done + # Set state and done of each sub-environment + for (stage, subenv), subenv_done in zip(self.subenvs.items(), done_subenvs): + subenv.set_state(self._get_state_of_subenv(state, stage), subenv_done) + + """ + We synchronize LatticeParameter's lattice system with the one of SpaceGroup + (if it was set) or reset it to the default triclinic otherwise. Why this is + needed: for backward sampling, where we start from an arbitrary terminal state, + and need to synchronize the LatticeParameter's lattice system to what that + state indicates, + """ + if self.subenvs[Stage.SPACE_GROUP].done: + lattice_system = self.subenvs[Stage.SPACE_GROUP].lattice_system + if lattice_system != "None" and self.do_sg_to_lp_constraints: + self.subenvs[Stage.LATTICE_PARAMETERS].set_lattice_system( + lattice_system + ) + else: + self.subenvs[Stage.LATTICE_PARAMETERS].set_lattice_system(TRICLINIC) + # Set stoichiometry constraints in space group sub-environment + if ( + self.do_composition_to_sg_constraints + and self.subenvs[Stage.COMPOSITION].done + ): + self.subenvs[Stage.SPACE_GROUP].set_n_atoms_compatibility_dict( + self.subenvs[Stage.COMPOSITION].state + ) + + def state2readable(self, state: Optional[List[int]] = None) -> str: + if state is None: + state = self.state + + readables = [ + subenv.state2readable(self._get_state_of_subenv(state, stage)) + for stage, subenv in self.subenvs.items() + ] + return ( + f"{self._get_stage(state)}; " + f"Composition = {readables[0]}; " + f"SpaceGroup = {readables[1]}; " + f"LatticeParameters = {readables[2]}" + ) + + def process_data_set(self, data: List[List]) -> List[List]: + is_valid_list = [] + for x in data: + is_valid_list.append( + all( + [ + subenv.is_valid(self._get_state_of_subenv(x, stage)) + for stage, subenv in self.subenvs.items() + ] + ) + ) + return [x for x, is_valid in zip(data, is_valid_list) if is_valid] + + # TODO: redo + + +# def readable2state(self, readable: str) -> List[int]: +# splits = readable.split("; ") +# readables = [x.split(" = ")[1] for x in splits] +# +# return ( +# [int(readables[0])] +# + self.composition.readable2state( +# json.loads(readables[1].replace("'", '"')) +# ) +# + self.space_group.readable2state(readables[2]) +# + self.lattice_parameters.readable2state(readables[3]) +# ) diff --git a/gflownet/envs/crystals/clattice_parameters.py b/gflownet/envs/crystals/clattice_parameters.py new file mode 100644 index 000000000..4891bf5ef --- /dev/null +++ b/gflownet/envs/crystals/clattice_parameters.py @@ -0,0 +1,374 @@ +""" +Classes 
to represent continuous lattice parameters environments. +""" +from typing import List, Optional, Tuple + +import torch +from torch import Tensor +from torchtyping import TensorType + +from gflownet.envs.cube import ContinuousCube +from gflownet.utils.common import copy, tfloat +from gflownet.utils.crystals.constants import ( + CUBIC, + HEXAGONAL, + MONOCLINIC, + ORTHORHOMBIC, + RHOMBOHEDRAL, + TETRAGONAL, + TRICLINIC, +) + +LENGTH_PARAMETER_NAMES = ("a", "b", "c") +ANGLE_PARAMETER_NAMES = ("alpha", "beta", "gamma") +PARAMETER_NAMES = LENGTH_PARAMETER_NAMES + ANGLE_PARAMETER_NAMES + + +# TODO: figure out a way to inherit the (discrete) LatticeParameters env or create a +# common class for both discrete and continous with the common methods. +class CLatticeParameters(ContinuousCube): + """ + Continuous lattice parameters environment for crystal structures generation. + + Models lattice parameters (three edge lengths and three angles describing unit + cell) with the constraints given by the provided lattice system (see + https://en.wikipedia.org/wiki/Bravais_lattice). This is implemented by inheriting + from the (continuous) cube environment, creating a mapping between cell position + and edge length or angle, and imposing lattice system constraints on their values. + + The environment is a hyper cube of dimensionality 6 (the number of lattice + parameters), but it takes advantage of the mask of ignored dimensions implemented + in the Cube environment. + + The values of the state will remain in the default [0, 1] range of the Cube, but + they are mapped to [min_length, max_length] in the case of the lengths and + [min_angle, max_angle] in the case of the angles. + """ + + def __init__( + self, + lattice_system: str, + min_length: Optional[float] = 1.0, + max_length: Optional[float] = 350.0, + min_angle: Optional[float] = 50.0, + max_angle: Optional[float] = 150.0, + **kwargs, + ): + """ + Args + ---- + lattice_system : str + One of the seven lattice systems. By default, the triclinic lattice system + is used, which has no constraints. + + min_length : float + Minimum value of the lengths. + + max_length : float + Maximum value of the lengths. + + min_angle : float + Minimum value of the angles. + + max_angle : float + Maximum value of the angles. 
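+
+        Example
+        -------
+        An illustrative instantiation with hexagonal constraints
+        (a == b != c, alpha == beta == 90.0, gamma == 120.0):
+
+            env = CLatticeParameters(lattice_system=HEXAGONAL)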
+ """ + self.continuous = True + self.lattice_system = lattice_system + self.min_length = min_length + self.max_length = max_length + self.length_range = self.max_length - self.min_length + self.min_angle = min_angle + self.max_angle = max_angle + self.angle_range = self.max_angle - self.min_angle + self._setup_constraints() + super().__init__(n_dim=6, **kwargs) + + # TODO: if source, keep as is + def _statevalue2length(self, value): + return self.min_length + value * self.length_range + + def _length2statevalue(self, length): + return (length - self.min_length) / self.length_range + + # TODO: if source, keep as is + def _statevalue2angle(self, value): + return self.min_angle + value * self.angle_range + + def _angle2statevalue(self, angle): + return (angle - self.min_angle) / self.angle_range + + def _get_param(self, state, param): + if hasattr(self, param): + return getattr(self, param) + else: + if param in LENGTH_PARAMETER_NAMES: + return self._statevalue2length(state[self._get_index_of_param(param)]) + elif param in ANGLE_PARAMETER_NAMES: + return self._statevalue2angle(state[self._get_index_of_param(param)]) + else: + raise ValueError(f"{param} is not a valid lattice parameter") + + def _set_param(self, state, param, value): + param_idx = self._get_index_of_param(param) + if param_idx is not None: + if param in LENGTH_PARAMETER_NAMES: + state[param_idx] = self._length2statevalue(value) + elif param in ANGLE_PARAMETER_NAMES: + state[param_idx] = self._angle2statevalue(value) + else: + raise ValueError(f"{param} is not a valid lattice parameter") + return state + + def _get_index_of_param(self, param): + param_idx = f"{param}_idx" + if hasattr(self, param_idx): + return getattr(self, param_idx) + else: + return None + + def set_lattice_system(self, lattice_system: str): + """ + Sets the lattice system of the unit cell and updates the constraints. + """ + self.lattice_system = lattice_system + self._setup_constraints() + + def _setup_constraints(self): + """ + Computes the mask of ignored dimensions, given the constraints imposed by the + lattice system. Sets self.ignored_dims. 
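+
+        For example, for the cubic lattice system a == b == c and
+        alpha == beta == gamma == 90.0, so only the first length dimension is
+        free and the three angle dimensions are ignored (fixed at 90 degrees).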
+ """ + # Lengths: a, b, c + # a == b == c + if self.lattice_system in [CUBIC, RHOMBOHEDRAL]: + lengths_ignored_dims = [False, True, True] + self.a_idx = 0 + self.b_idx = 0 + self.c_idx = 0 + # a == b != c + elif self.lattice_system in [HEXAGONAL, TETRAGONAL]: + lengths_ignored_dims = [False, True, False] + self.a_idx = 0 + self.b_idx = 0 + self.c_idx = 1 + # a != b and a != c and b != c + elif self.lattice_system in [MONOCLINIC, ORTHORHOMBIC, TRICLINIC]: + lengths_ignored_dims = [False, False, False] + self.a_idx = 0 + self.b_idx = 1 + self.c_idx = 2 + else: + raise NotImplementedError + # Angles: alpha, beta, gamma + # alpha == beta == gamma == 90.0 + if self.lattice_system in [CUBIC, ORTHORHOMBIC, TETRAGONAL]: + angles_ignored_dims = [True, True, True] + self.alpha_idx = None + self.alpha = 90.0 + self.alpha_state = self._angle2statevalue(self.alpha) + self.beta_idx = None + self.beta = 90.0 + self.beta_state = self._angle2statevalue(self.beta) + self.gamma_idx = None + self.gamma = 90.0 + self.gamma_state = self._angle2statevalue(self.gamma) + # alpha == beta == 90.0 and gamma == 120.0 + elif self.lattice_system == HEXAGONAL: + angles_ignored_dims = [True, True, True] + self.alpha_idx = None + self.alpha = 90.0 + self.alpha_state = self._angle2statevalue(self.alpha) + self.beta_idx = None + self.beta = 90.0 + self.beta_state = self._angle2statevalue(self.beta) + self.gamma_idx = None + self.gamma = 120.0 + self.gamma_state = self._angle2statevalue(self.gamma) + # alpha == gamma == 90.0 and beta != 90.0 + elif self.lattice_system == MONOCLINIC: + angles_ignored_dims = [True, False, True] + self.alpha_idx = None + self.alpha = 90.0 + self.alpha_state = self._angle2statevalue(self.alpha) + self.beta_idx = 4 + self.gamma_idx = None + self.gamma = 90.0 + self.gamma_state = self._angle2statevalue(self.gamma) + # alpha == beta == gamma != 90.0 + elif self.lattice_system == RHOMBOHEDRAL: + angles_ignored_dims = [False, True, True] + self.alpha_idx = 3 + self.beta_idx = 3 + self.gamma_idx = 3 + # alpha != beta, alpha != gamma, beta != gamma + elif self.lattice_system == TRICLINIC: + angles_ignored_dims = [False, False, False] + self.alpha_idx = 3 + self.beta_idx = 4 + self.gamma_idx = 5 + else: + raise NotImplementedError + self.ignored_dims = lengths_ignored_dims + angles_ignored_dims + + def _step( + self, + action: Tuple[float], + backward: bool, + ) -> Tuple[List[float], Tuple[float], bool]: + """ + Updates the dimensions of the state corresponding to the ignored dimensions + after a call to the Cube's _step(). + """ + state, action, valid = super()._step(action, backward) + for idx, (param, is_ignored) in enumerate( + zip(PARAMETER_NAMES, self.ignored_dims) + ): + if not is_ignored: + continue + param_idx = self._get_index_of_param(param) + if param_idx is not None: + state[idx] = state[param_idx] + else: + state[idx] = getattr(self, f"{param}_state") + self.state = copy(state) + return self.state, action, valid + + def _unpack_lengths_angles( + self, state: Optional[List[int]] = None + ) -> Tuple[Tuple, Tuple]: + """ + Helper that 1) unpacks values coding lengths and angles from the state or from + the attributes of the instance and 2) converts them to actual edge lengths and + angles in the target units (angstroms or degrees). 
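+
+        Returns
+        -------
+        A tuple of tuples ((a, b, c), (alpha, beta, gamma)), with lengths in
+        angstroms and angles in degrees.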
+ """ + state = self._get_state(state) + + a, b, c, alpha, beta, gamma = [ + self._get_param(state, p) for p in PARAMETER_NAMES + ] + return (a, b, c), (alpha, beta, gamma) + + def state2readable(self, state: Optional[List[int]] = None) -> str: + """ + Converts the state into a human-readable string in the format "(a, b, c), + (alpha, beta, gamma)". + """ + state = self._get_state(state) + + lengths, angles = self._unpack_lengths_angles(state) + return f"{lengths}, {angles}" + + def readable2state(self, readable: str) -> List[int]: + """ + Converts a human-readable representation of a state into the standard format. + """ + state = copy(self.source) + + for c in ["(", ")", " "]: + readable = readable.replace(c, "") + values = readable.split(",") + values = [float(value) for value in values] + + for param, value in zip(PARAMETER_NAMES, values): + state = self._set_param(state, param, value) + return state + + def state2policy(self, state: Optional[List[float]] = None) -> Tensor: + """ + Simply returns a torch tensor of the state as is, in the range [0, 1]. + """ + state = self._get_state(state) + return tfloat(state, float_type=self.float, device=self.device) + + def statebatch2policy( + self, states: List[List] + ) -> TensorType["batch", "state_proxy_dim"]: + """ + Simply returns a torch tensor of the states as are, in the range [0, 1], by + calling statetorch2policy. + """ + return self.statetorch2policy( + tfloat(states, device=self.device, float_type=self.float) + ) + + def statetorch2policy( + self, states: TensorType["batch", "state_dim"] = None + ) -> TensorType["batch", "policy_input_dim"]: + """ + Simply returns the states as are, in the range [0, 1]. + """ + return states + + def state2oracle(self, state: Optional[List[float]] = None) -> Tensor: + """ + Maps [0; 1] state values to edge lengths and angles. + """ + state = self._get_state(state) + + return tfloat( + [self._get_param(state, p) for p in PARAMETER_NAMES], + float_type=self.float, + device=self.device, + ) + + def statebatch2oracle( + self, states: List[List] + ) -> TensorType["batch", "state_oracle_dim"]: + """ + Maps [0; 1] state values to edge lengths and angles. + """ + return self.statetorch2oracle( + tfloat(states, device=self.device, float_type=self.float) + ) + + def statetorch2oracle( + self, states: TensorType["batch", "state_dim"] = None + ) -> TensorType["batch", "oracle_input_dim"]: + """ + Maps [0; 1] state values to edge lengths and angles. + """ + return torch.cat( + [ + self._statevalue2length(states[:, :3]), + self._statevalue2angle(states[:, 3:]), + ], + dim=1, + ) + + def state2proxy(self, state: Optional[List[int]] = None) -> Tensor: + """ + Returns state2oracle(state). + """ + return self.state2oracle(state) + + def statebatch2proxy( + self, states: List[List] + ) -> TensorType["batch", "state_oracle_dim"]: + """ + Returns statebatch2oracle(states). + """ + return self.statebatch2oracle(states) + + def statetorch2proxy( + self, states: TensorType["batch", "state_dim"] + ) -> TensorType["batch", "state_oracle_dim"]: + """ + Returns statetorch2oracle(states). + """ + return self.statetorch2oracle(states) + + def is_valid(self, x: List) -> bool: + """ + Determines whether a state is valid, according to the attributes of the + environment. 
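+
+        A state is valid if all lengths lie within [min_length, max_length] and
+        all angles lie within [min_angle, max_angle].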
+ """ + lengths, angles = self._unpack_lengths_angles(x) + # Check lengths + if any([l < self.min_length or l > self.max_length for l in lengths]): + return False + if any([l < self.min_angle or l > self.max_angle for l in angles]): + return False + + # If all checks are passed, return True + return True diff --git a/gflownet/envs/crystals/composition.py b/gflownet/envs/crystals/composition.py index 6fc62e991..55b8bb186 100644 --- a/gflownet/envs/crystals/composition.py +++ b/gflownet/envs/crystals/composition.py @@ -11,7 +11,7 @@ from torchtyping import TensorType from gflownet.envs.base import GFlowNetEnv -from gflownet.utils.common import tlong +from gflownet.utils.common import tfloat, tlong from gflownet.utils.crystals.constants import ELEMENT_NAMES, OXIDATION_STATES from gflownet.utils.crystals.pyxtal_cache import ( get_space_group, @@ -20,6 +20,8 @@ space_group_wyckoff_gcd, ) +N_ELEMENTS_ORACLE = 94 + class Composition(GFlowNetEnv): """ @@ -28,7 +30,7 @@ class Composition(GFlowNetEnv): def __init__( self, - elements: Union[List, int] = 84, + elements: Union[List, int] = 94, max_diff_elem: int = 5, min_diff_elem: int = 2, min_atoms: int = 2, @@ -406,8 +408,9 @@ def get_element_mask(min_atoms, max_atoms): def state2oracle(self, state: List = None) -> Tensor: """ - Prepares a state in "GFlowNet format" for the oracle. In this case, it simply - converts the state into a torch tensor, with dtype torch.long. + Prepares a state in "GFlowNet format" for the oracle. The output is a tensor of + length N_ELEMENTS_ORACLE + 1, where the positions of self.elements are filled with + the number of atoms of each element in the state. Args ---- @@ -421,15 +424,17 @@ def state2oracle(self, state: List = None) -> Tensor: """ if state is None: state = self.state - - return tlong(state, device=self.device) + return self.statetorch2oracle( + torch.unsqueeze(tfloat(state, device=self.device, float_type=self.float), 0) + )[0] def statetorch2oracle( self, states: TensorType["batch", "state_dim"] ) -> TensorType["batch", "state_oracle_dim"]: """ - Prepares a batch of states in "GFlowNet format" for the oracle. The input to the - oracle is the atom counts for individual elements. + Prepares a batch of states in "GFlowNet format" for the oracle. The output is + a tensor with N_ELEMENTS_ORACLE + 1 columns, where the positions of + self.elements are filled with the number of atoms of each element in the state. Args ---- @@ -440,7 +445,15 @@ def statetorch2oracle( ---- oracle_states : Tensor """ - return states + states_float = states.to(self.float) + + states_oracle = torch.zeros( + (states.shape[0], N_ELEMENTS_ORACLE + 1), + device=self.device, + dtype=self.float, + ) + states_oracle[:, tlong(self.elements, device=self.device)] = states_float + return states_oracle def statebatch2oracle( self, states: List[List] @@ -453,7 +466,7 @@ def statebatch2oracle( ---- state : list """ - return tlong(states, device=self.device) + return self.statetorch2oracle(tlong(states, device=self.device)) def state2readable(self, state=None): """ @@ -642,3 +655,34 @@ def _can_produce_neutral_charge(self, state: Optional[List[int]] = None) -> bool nums_charges[0] = (num - 1, charges) return 0 in poss_charge_sum + + def is_valid(self, x: List) -> bool: + """ + Determines whether a state is valid, according to the attributes of the + environment. 
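+
+        A state is valid if its length equals the number of elements, the total
+        number of atoms and the per-element atom counts are within the
+        configured bounds, the number of distinct elements is within
+        [min_diff_elem, max_diff_elem], and all required elements are present.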
+ """ + # Check length is equal to number of elements + if len(x) != len(self.elements): + return False + # Check total number of atoms + n_atoms = sum(x) + if n_atoms < self.min_atoms: + return False + if n_atoms > self.max_atoms: + return False + # Check number element + if any([n < self.min_atom_i for n in x if n > 0]): + return False + if any([n > self.max_atom_i for n in x if n > 0]): + return False + # Check required elements + used_elements = [self.idx2elem[idx] for idx, n in enumerate(x) if n > 0] + if len(used_elements) < self.min_diff_elem: + return False + if len(used_elements) > self.max_diff_elem: + return False + if any(r not in used_elements for r in self.required_elements): + return False + + # If all checks are passed, return True + return True diff --git a/gflownet/envs/crystals/spacegroup.py b/gflownet/envs/crystals/spacegroup.py index 38ea7fb4e..9506032ba 100644 --- a/gflownet/envs/crystals/spacegroup.py +++ b/gflownet/envs/crystals/spacegroup.py @@ -614,8 +614,6 @@ def set_n_atoms_compatibility_dict(self, n_atoms: List): removed from the list since they do not count towards the compatibility with a space group. """ - if n_atoms is not None: - n_atoms = [n for n in n_atoms if n > 0] # Get compatibility with stoichiometry self.n_atoms_compatibility_dict = SpaceGroup.build_n_atoms_compatibility_dict( n_atoms, self.space_groups.keys() @@ -650,7 +648,9 @@ def _is_compatible( return len(space_groups) > 0 @staticmethod - def build_n_atoms_compatibility_dict(n_atoms: List[int], space_groups: List[int]): + def build_n_atoms_compatibility_dict( + n_atoms: List[int], space_groups: Iterable[int] + ): """ Obtains which space groups are compatible with the stoichiometry given as argument (n_atoms). @@ -662,8 +662,9 @@ def build_n_atoms_compatibility_dict(n_atoms: List[int], space_groups: List[int] Args ---- n_atoms : list of int - A list of positive number of atoms for each element in a stoichiometry. If - None, all space groups will be marked as compatible. + A list of number of atoms for each element in a stoichiometry. 0s will be + removed from the list since they do not count towards the compatibility + with a space group. If None, all space groups will be marked as compatible. space_groups : list of int A list of space group international numbers, in [1, 230] @@ -676,6 +677,7 @@ def build_n_atoms_compatibility_dict(n_atoms: List[int], space_groups: List[int] """ if n_atoms is None: return {sg: True for sg in space_groups} + n_atoms = [n for n in n_atoms if n > 0] assert all([n > 0 for n in n_atoms]) assert all([sg > 0 and sg <= 230 for sg in space_groups]) return {sg: space_group_check_compatible(sg, n_atoms) for sg in space_groups} @@ -760,3 +762,12 @@ def get_all_terminating_states( continue all_x.append(self._set_constrained_properties([0, 0, sg])) return all_x + + def is_valid(self, x: List) -> bool: + """ + Determines whether a state is valid, according to the attributes of the + environment. + """ + if x[self.sg_idx] in self.space_groups: + return True + return False diff --git a/gflownet/envs/ctorus.py b/gflownet/envs/ctorus.py index 3cf8543bd..b2a6ac948 100644 --- a/gflownet/envs/ctorus.py +++ b/gflownet/envs/ctorus.py @@ -35,6 +35,8 @@ class ContinuousTorus(HybridTorus): def __init__(self, **kwargs): super().__init__(**kwargs) + # Mask dimensionality: + self.mask_dim = 2 def get_action_space(self): """ @@ -71,7 +73,9 @@ def get_policy_output(self, params: dict) -> TensorType["policy_output_dim"]: (self.n_comp). 
The first 3 x C entries in the policy output correspond to the first dimension, and so on. """ - policy_output = np.ones(self.n_dim * self.n_comp * 3) + policy_output = torch.ones( + self.n_dim * self.n_comp * 3, dtype=self.float, device=self.device + ) policy_output[1::3] = params["vonmises_mean"] policy_output[2::3] = params["vonmises_concentration"] return policy_output diff --git a/gflownet/envs/cube.py b/gflownet/envs/cube.py index a66363286..ca929ea6d 100644 --- a/gflownet/envs/cube.py +++ b/gflownet/envs/cube.py @@ -48,6 +48,11 @@ class CubeBase(GFlowNetEnv, ABC): Small constant to control the intervals of the generated sets of states (in a grid or uniformly). States will be in the interval [kappa, 1 - kappa]. Default: 1e-3. + + ignored_dims : list + Boolean mask of ignored dimensions. This can be used for trajectories that may + have multiple dimensions coupled or fixed. For each dimension, True if ignored, + False, otherwise. If None, no dimension is ignored. """ def __init__( @@ -59,6 +64,7 @@ def __init__( beta_params_max: float = 100.0, epsilon: float = 1e-6, kappa: float = 1e-3, + ignored_dims: Optional[List[bool]] = None, fixed_distr_params: dict = { "beta_weights": 1.0, "beta_alpha": 10.0, @@ -82,6 +88,10 @@ def __init__( # Main properties self.n_dim = n_dim self.min_incr = min_incr + if ignored_dims: + self.ignored_dims = ignored_dims + else: + self.ignored_dims = [False] * self.n_dim # Parameters of the policy distribution self.n_comp = n_comp self.beta_params_min = beta_params_min @@ -92,14 +102,6 @@ def __init__( self.epsilon = epsilon # Small constant to restrict the interval of (test) sets self.kappa = kappa - # Conversions: only conversions to policy are implemented and the rest are the - # same - self.state2proxy = self.state2policy - self.statebatch2proxy = self.statebatch2policy - self.statetorch2proxy = self.statetorch2policy - self.state2oracle = self.state2proxy - self.statebatch2oracle = self.statebatch2proxy - self.statetorch2oracle = self.statetorch2proxy # Base class init super().__init__( fixed_distr_params=fixed_distr_params, @@ -128,9 +130,9 @@ def get_mask_invalid_actions_forward( def get_mask_invalid_actions_backward(self, state=None, done=None, parents_a=None): pass - def statetorch2policy( + def statetorch2oracle( self, states: TensorType["batch", "state_dim"] = None - ) -> TensorType["batch", "policy_input_dim"]: + ) -> TensorType["batch", "oracle_input_dim"]: """ Clips the states into [0, 1] and maps them to [-1.0, 1.0] @@ -141,9 +143,9 @@ def statetorch2policy( """ return 2.0 * torch.clip(states, min=0.0, max=1.0) - 1.0 - def statebatch2policy( + def statebatch2oracle( self, states: List[List] - ) -> TensorType["batch", "state_proxy_dim"]: + ) -> TensorType["batch", "state_oracle_dim"]: """ Clips the states into [0, 1] and maps them to [-1.0, 1.0] @@ -152,11 +154,11 @@ def statebatch2policy( state : list State """ - return self.statetorch2policy( + return self.statetorch2oracle( tfloat(states, device=self.device, float_type=self.float) ) - def state2policy(self, state: List = None) -> List: + def state2oracle(self, state: List = None) -> List: """ Clips the state into [0, 1] and maps it to [-1.0, 1.0] """ @@ -164,6 +166,70 @@ def state2policy(self, state: List = None) -> List: state = self.state.copy() return [2.0 * min(max(0.0, s), 1.0) - 1.0 for s in state] + def statetorch2proxy( + self, states: TensorType["batch", "state_dim"] = None + ) -> TensorType["batch", "oracle_input_dim"]: + """ + Returns statetorch2oracle(states), that is states mapped 
to [-1.0, 1.0]. + + Args + ---- + state : list + State + """ + return self.statetorch2oracle(states) + + def statebatch2proxy( + self, states: List[List] + ) -> TensorType["batch", "state_oracle_dim"]: + """ + Returns statebatch2oracle(states), that is states mapped to [-1.0, 1.0]. + + Args + ---- + state : list + State + """ + return self.statebatch2oracle(states) + + def state2proxy(self, state: List = None) -> List: + """ + Returns state2oracle(state), that is the state mapped to [-1.0, 1.0]. + """ + return self.state2oracle(state) + + def statetorch2policy( + self, states: TensorType["batch", "state_dim"] = None + ) -> TensorType["batch", "policy_input_dim"]: + """ + Returns statetorch2proxy(states), that is states mapped to [-1.0, 1.0]. + + Args + ---- + state : list + State + """ + return self.statetorch2proxy(states) + + def statebatch2policy( + self, states: List[List] + ) -> TensorType["batch", "state_proxy_dim"]: + """ + Returns statebatch2proxy(states), that is states mapped to [-1.0, 1.0]. + + Args + ---- + state : list + State + """ + return self.statebatch2proxy(states) + + def state2policy(self, state: List = None) -> List: + """ + Returns state2proxy(state), that is the state mapped to [-1.0, 1.0]. + """ + return self.state2proxy(state) + def state2readable(self, state: List) -> str: """ Converts a state (a list of positions) into a human-readable string @@ -282,6 +348,10 @@ def _beta_params_to_policy_outputs(self, param_name: str, params_dict: dict): ) return torch.logit((param_value - self.beta_params_min) / self.beta_params_max) + def _get_effective_dims(self, state: Optional[List] = None) -> List: + state = self._get_state(state) + return [s for s, ign_dim in zip(state, self.ignored_dims) if not ign_dim] + class ContinuousCube(CubeBase): """ @@ -316,6 +386,9 @@ class ContinuousCube(CubeBase): def __init__(self, **kwargs): super().__init__(**kwargs) + # Mask dimensionality: 3 + number of dimensions + self.mask_dim_base = 3 + self.mask_dim = self.mask_dim_base + self.n_dim def get_action_space(self): """ @@ -341,10 +414,13 @@ def get_max_traj_length(self): def get_policy_output(self, params: dict) -> TensorType["policy_output_dim"]: """ - Defines the structure of the output of the policy model, from which an - action is to be determined or sampled, by returning a vector with a fixed - random policy. The environment consists of both continuous and discrete - actions. + Defines the structure of the output of the policy model. + + The policy output will be used to initialize a distribution, from which an + action is to be determined or sampled. This method returns a vector with a + fixed policy defined by params. + + The environment consists of both continuous and discrete actions. Continuous actions @@ -490,24 +566,27 @@ def get_mask_invalid_actions_forward( the source state, True otherwise. - 2 : whether EOS action is invalid. EOS is valid from any state, except the source state or if done is True. + - -n_dim: : dimensions that should be ignored when sampling actions or + computing logprobs. This can be used for trajectories that may have + multiple dimensions coupled or fixed. For each dimension, True if ignored, + False, otherwise. 
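+          For instance, in a 3-dimensional cube where only the second dimension
+          is ignored, the mask has 3 + 3 = 6 entries and its last three entries
+          are [False, True, False].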
""" state = self._get_state(state) done = self._get_done(done) - mask_dim = 3 # If done, the entire mask is True (all actions are "invalid" and no special # cases) if done: - return [True] * mask_dim - mask = [False] * mask_dim + return [True] * self.mask_dim + mask = [False] * self.mask_dim_base + self.ignored_dims # If the state is the source state, EOS is invalid - if state == self.source: + if self._get_effective_dims(state) == self._get_effective_dims(self.source): mask[2] = True # If the state is not the source, indicate not special case (True) else: mask[1] = True # If the value of any dimension is greater than 1 - min_incr, then continuous # actions are invalid (True). - if any([s > 1 - self.min_incr for s in state]): + if any([s > 1 - self.min_incr for s in self._get_effective_dims(state)]): mask[0] = True return mask @@ -531,13 +610,16 @@ def get_mask_invalid_actions_backward(self, state=None, done=None, parents_a=Non False if any dimension is smaller than min_incr, True otherwise. - 2 : whether EOS action is invalid. False only if done is True, True (invalid) otherwise. + - -n_dim: : dimensions that should be ignored when sampling actions or + computing logprobs. this can be used for trajectories that may have + multiple dimensions coupled or fixed. for each dimension, true if ignored, + false, otherwise. By default, no dimension is ignored. """ state = self._get_state(state) done = self._get_done(done) - mask_dim = 3 - mask = [True] * mask_dim + mask = [True] * self.mask_dim_base + self.ignored_dims # If the state is the source state, entire mask is True - if state == self.source: + if self._get_effective_dims(state) == self._get_effective_dims(self.source): return mask # If done, only valid action is EOS. if done: @@ -545,7 +627,7 @@ def get_mask_invalid_actions_backward(self, state=None, done=None, parents_a=Non return mask # If any dimension is smaller than m, then back-to-source action is the only # possible actiona. - if any([s < self.min_incr for s in state]): + if any([s < self.min_incr for s in self._get_effective_dims(state)]): mask[1] = False return mask # Otherwise, continuous actions are valid @@ -675,10 +757,38 @@ def _make_increments_distribution( beta_distr = Beta(alphas, betas) return MixtureSameFamily(mix, beta_distr) + def _mask_ignored_dimensions( + self, + mask: TensorType["n_states", "policy_outputs_dim"], + tensor_to_mask: TensorType["n_states", "n_dim"], + ) -> MixtureSameFamily: + """ + Makes the actions, logprobs or log jacobian entries of ignored dimensions zero. + + Since the shape of all the tensor of actions, the logprobs of increments and + the log of the diagonal of the Jacobian must be the same, this method makes no + distiction between for simplicity. + + Args + ---- + mask : tensor + Boolean mask indicating (True) which dimensions should be set to zero. + + tensor_to_mask : tensor + Tensor to be modified. It may be a tensor of actions, of logprobs of + increments or the log of the diagonal of the Jacobian. 
+ """ + is_ignored_dim = mask[:, -self.n_dim :] + if torch.any(is_ignored_dim): + shape_orig = tensor_to_mask.shape + tensor_to_mask[is_ignored_dim] = 0.0 + tensor_to_mask = tensor_to_mask.reshape(shape_orig) + return tensor_to_mask + def sample_actions_batch( self, policy_outputs: TensorType["n_states", "policy_output_dim"], - mask: Optional[TensorType["n_states", "policy_output_dim"]] = None, + mask: Optional[TensorType["n_states", "mask_dim"]] = None, states_from: List = None, is_backward: Optional[bool] = False, sampling_method: Optional[str] = "policy", @@ -700,7 +810,7 @@ def sample_actions_batch( def _sample_actions_batch_forward( self, policy_outputs: TensorType["n_states", "policy_output_dim"], - mask: Optional[TensorType["n_states", "policy_output_dim"]] = None, + mask: Optional[TensorType["n_states", "mask_dim"]] = None, states_from: List = None, sampling_method: Optional[str] = "policy", temperature_logits: Optional[float] = 1.0, @@ -792,10 +902,12 @@ def _sample_actions_batch_forward( (n_states, self.n_dim + 1), torch.inf, dtype=self.float, device=self.device ) if torch.any(do_increments): - increments = torch.cat( + # Make increments of ignored dimensions zero + increments = self._mask_ignored_dimensions(mask[do_increments], increments) + # Add dimension is_source and add to actions tensor + actions_tensor[do_increments] = torch.cat( (increments, torch.zeros((increments.shape[0], 1))), dim=1 ) - actions_tensor[do_increments] = increments actions_tensor[is_source, -1] = 1 actions = [tuple(a.tolist()) for a in actions_tensor] return actions, None @@ -803,7 +915,7 @@ def _sample_actions_batch_forward( def _sample_actions_batch_backward( self, policy_outputs: TensorType["n_states", "policy_output_dim"], - mask: Optional[TensorType["n_states", "policy_output_dim"]] = None, + mask: Optional[TensorType["n_states", "mask_dim"]] = None, states_from: List = None, sampling_method: Optional[str] = "policy", temperature_logits: Optional[float] = 1.0, @@ -880,10 +992,12 @@ def _sample_actions_batch_backward( self.eos, float_type=self.float, device=self.device ) if torch.any(do_increments): - increments = torch.cat( + # Make increments of ignored dimensions zero + increments = self._mask_ignored_dimensions(mask[do_increments], increments) + # Add dimension is_source and add to actions tensor + actions_tensor[do_increments] = torch.cat( (increments, torch.zeros((increments.shape[0], 1))), dim=1 ) - actions_tensor[do_increments] = increments if torch.any(is_bts): # BTS actions are equal to the originating states actions_bts = tfloat( @@ -893,6 +1007,10 @@ def _sample_actions_batch_backward( (actions_bts, torch.ones((actions_bts.shape[0], 1))), dim=1 ) actions_tensor[is_bts] = actions_bts + # Make ignored dimensions zero + actions_tensor[is_bts, :-1] = self._mask_ignored_dimensions( + mask[is_bts], actions_tensor[is_bts, :-1] + ) actions = [tuple(a.tolist()) for a in actions_tensor] return actions, None @@ -900,7 +1018,7 @@ def get_logprobs( self, policy_outputs: TensorType["n_states", "policy_output_dim"], actions: TensorType["n_states", "actions_dim"], - mask: TensorType["n_states", "3"], + mask: TensorType["n_states", "mask_dim"], states_from: List, is_backward: bool, ) -> TensorType["batch_size"]: @@ -913,7 +1031,7 @@ def get_logprobs( The output of the GFlowNet policy model. mask : tensor - The mask containing information invalid actions and special cases. + The mask containing information about invalid actions and special cases. 
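+            Its shape is (n_states, mask_dim), that is 3 + n_dim entries per
+            state for this environment.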
actions : tensor The actions (absolute increments) from each state in the batch for which to @@ -957,7 +1075,7 @@ def _get_logprobs_forward( logprobs_increments_rel = torch.zeros( (n_states, self.n_dim), dtype=self.float, device=self.device ) - jacobian_diag = torch.ones( + log_jacobian_diag = torch.zeros( (n_states, self.n_dim), device=self.device, dtype=self.float ) eos_tensor = tfloat(self.eos, float_type=self.float, device=self.device) @@ -1006,10 +1124,14 @@ def _get_logprobs_forward( # not source is_relative = torch.logical_and(do_increments, ~is_source) if torch.any(is_relative): - jacobian_diag[is_relative] = self._get_jacobian_diag( - states_from_rel, - is_backward=False, + log_jacobian_diag[is_relative] = torch.log( + self._get_jacobian_diag( + states_from_rel, + is_backward=False, + ) ) + # Make ignored dimensions zero + log_jacobian_diag = self._mask_ignored_dimensions(mask, log_jacobian_diag) # Get logprobs distr_increments = self._make_increments_distribution( policy_outputs[do_increments] @@ -1018,8 +1140,12 @@ def _get_logprobs_forward( logprobs_increments_rel[do_increments] = distr_increments.log_prob( torch.clamp(increments, min=self.epsilon, max=(1 - self.epsilon)) ) - # Get log determinant of the Jacobian - log_det_jacobian = torch.sum(torch.log(jacobian_diag), dim=1) + # Make ignored dimensions zero + logprobs_increments_rel = self._mask_ignored_dimensions( + mask, logprobs_increments_rel + ) + # Sum log Jacobian across dimensions + log_det_jacobian = torch.sum(log_jacobian_diag, dim=1) # Compute combined probabilities sumlogprobs_increments = logprobs_increments_rel.sum(axis=1) logprobs = logprobs_eos + sumlogprobs_increments + log_det_jacobian @@ -1045,7 +1171,7 @@ def _get_logprobs_backward( logprobs_increments_rel = torch.zeros( (n_states, self.n_dim), dtype=self.float, device=self.device ) - jacobian_diag = torch.ones( + log_jacobian_diag = torch.zeros( (n_states, self.n_dim), device=self.device, dtype=self.float ) # EOS is the only possible action only if done is True (mask[2] is False) @@ -1081,10 +1207,14 @@ def _get_logprobs_backward( is_backward=True, ) # Compute diagonal of the Jacobian (see _get_jacobian_diag()) - jacobian_diag[do_increments] = self._get_jacobian_diag( - states_from_tensor[do_increments], - is_backward=True, + log_jacobian_diag[do_increments] = torch.log( + self._get_jacobian_diag( + states_from_tensor[do_increments], + is_backward=True, + ) ) + # Make ignored dimensions zero + log_jacobian_diag = self._mask_ignored_dimensions(mask, log_jacobian_diag) # Get logprobs distr_increments = self._make_increments_distribution( policy_outputs[do_increments] @@ -1093,8 +1223,12 @@ def _get_logprobs_backward( logprobs_increments_rel[do_increments] = distr_increments.log_prob( torch.clamp(increments, min=self.epsilon, max=(1 - self.epsilon)) ) - # Get log determinant of the Jacobian - log_det_jacobian = torch.sum(torch.log(jacobian_diag), dim=1) + # Make ignored dimensions zero + logprobs_increments_rel = self._mask_ignored_dimensions( + mask, logprobs_increments_rel + ) + # Sum log Jacobian across dimensions + log_det_jacobian = torch.sum(log_jacobian_diag, dim=1) # Compute combined probabilities sumlogprobs_increments = logprobs_increments_rel.sum(axis=1) logprobs = logprobs_bts + sumlogprobs_increments + log_det_jacobian @@ -1179,6 +1313,7 @@ def _step( if not backward and action[-1] == 1 and self.state == self.source: state = [0.0 for _ in range(self.n_dim)] else: + assert action[-1] == 0 state = copy(self.state) # Increment dimensions for dim, 
incr in enumerate(action[:-1]): @@ -1188,7 +1323,10 @@ def _step( state[dim] += incr # If state is out of bounds, return invalid - if any([s > 1.0 for s in state]) or any([s < 0.0 for s in state]): + effective_dims = self._get_effective_dims(state) + if any([s > 1.0 for s in effective_dims]) or any( + [s < 0.0 for s in effective_dims] + ): warnings.warn( f""" State is out of cube bounds. @@ -1229,7 +1367,9 @@ def step(self, action: Tuple[float]) -> Tuple[List[float], Tuple[int, float], bo if self.done: return self.state, action, False if action == self.eos: - assert self.state != self.source + assert self._get_effective_dims(self.state) != self._get_effective_dims( + self.source + ) self.done = True self.n_actions += 1 return self.state, self.eos, True @@ -1277,6 +1417,14 @@ def step_backwards( # Otherwise perform action return self._step(action, backward=True) + def action2representative(self, action: Tuple) -> Tuple: + """ + Replaces the continuous values of an action by 0s (the "generic" or + "representative" action in the first position of the action space), so that + they can be compared against the action space or a mask. + """ + return self.action_space[0] + def get_grid_terminating_states( self, n_states: int, kappa: Optional[float] = None ) -> List[List]: diff --git a/gflownet/envs/htorus.py b/gflownet/envs/htorus.py index 6a82dee1a..011a74c51 100644 --- a/gflownet/envs/htorus.py +++ b/gflownet/envs/htorus.py @@ -127,7 +127,9 @@ def get_policy_output(self, params: dict): - d * n_params_per_dim + 3: logit of Bernoulli distribution with d in [0, ..., D] """ - policy_output = np.ones(self.n_dim * self.n_params_per_dim + 1) + policy_output = torch.ones( + self.n_dim * self.n_params_per_dim + 1, dtype=self.float, device=self.device + ) policy_output[1 :: self.n_params_per_dim] = params["vonmises_mean"] policy_output[2 :: self.n_params_per_dim] = params["vonmises_concentration"] return policy_output diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index c9cac3146..06777a851 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -507,7 +507,6 @@ def sample_batch( envs, actions, valids = self.step(envs, actions, backward=True) # Add to batch batch_replay.add_to_batch(envs, actions, valids, backward=True, train=train) - assert all(valids) # Filter out finished trajectories envs = [env for env in envs if not env.equal(env.state, env.source)] times["replay_actions"] = time.time() - t0_replay @@ -718,6 +717,7 @@ def estimate_logprobs_data( The logarithm of the average ratio PF/PB over n trajectories sampled for each data point. 
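+
+        The estimate is computed as
+        logsumexp(logprobs_f - logprobs_b, dim=1) - log(n_trajectories).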
""" + print("Compute logprobs...", flush=True) times = {} # Determine terminating states if isinstance(data, list): @@ -751,6 +751,10 @@ def estimate_logprobs_data( mult_indices = max(n_states, n_trajectories) init_batch = 0 end_batch = min(batch_size, n_states) + print( + "Sampling backward actions from test data to estimate logprobs...", + flush=True, + ) pbar = tqdm(total=n_states) while init_batch < n_states: batch = Batch(env=self.env, device=self.device, float_type=self.float) @@ -802,6 +806,7 @@ def estimate_logprobs_data( logprobs_estimates = torch.logsumexp( logprobs_f - logprobs_b, dim=1 ) - torch.log(torch.tensor(n_trajectories, device=self.device)) + print("Done computing logprobs", flush=True) return logprobs_estimates def train(self): @@ -1004,11 +1009,12 @@ def test(self, **plot_kwargs): ).item() nll_tt = -logprobs_x_tt.mean().item() - batch, _ = self.sample_batch(n_forward=self.logger.test.n, train=False) - assert batch.is_valid() - x_sampled = batch.get_terminating_states() - + x_sampled = [] if self.buffer.test_type is not None and self.buffer.test_type == "all": + batch, _ = self.sample_batch(n_forward=self.logger.test.n, train=False) + assert batch.is_valid() + x_sampled = batch.get_terminating_states() + if "density_true" in dict_tt: density_true = dict_tt["density_true"] else: @@ -1025,20 +1031,10 @@ def test(self, **plot_kwargs): density_pred = np.array([hist[tuple(x)] / z_pred for x in x_tt]) log_density_true = np.log(density_true + 1e-8) log_density_pred = np.log(density_pred + 1e-8) - elif self.buffer.test_type == "random": - # TODO: refactor - env_metrics = self.env.test(x_sampled) - return ( - self.l1, - self.kl, - self.jsd, - corr_prob_traj_rewards, - var_logrewards_logp, - nll_tt, - (None,), - env_metrics, - ) - elif self.continuous: + elif self.continuous and hasattr(self.env, "fit_kde"): + batch, _ = self.sample_batch(n_forward=self.logger.test.n, train=False) + assert batch.is_valid() + x_sampled = batch.get_terminating_states() # TODO make it work with conditional env x_sampled = torch2np(self.env.statebatch2proxy(x_sampled)) x_tt = torch2np(self.env.statebatch2proxy(x_tt)) @@ -1078,7 +1074,18 @@ def test(self, **plot_kwargs): density_true = np.exp(log_density_true) density_pred = np.exp(log_density_pred) else: - raise NotImplementedError + # TODO: refactor + env_metrics = self.env.test(x_sampled) + return ( + self.l1, + self.kl, + self.jsd, + corr_prob_traj_rewards, + var_logrewards_logp, + nll_tt, + (None,), + env_metrics, + ) # L1 error l1 = np.abs(density_pred - density_true).mean() # KL divergence diff --git a/gflownet/policy/base.py b/gflownet/policy/base.py index 766231481..50b625e09 100644 --- a/gflownet/policy/base.py +++ b/gflownet/policy/base.py @@ -12,12 +12,8 @@ def __init__(self, config, env, device, float_precision, base=None): self.float = set_float_precision(float_precision) # Input and output dimensions self.state_dim = env.policy_input_dim - self.fixed_output = torch.tensor(env.fixed_policy_output).to( - dtype=self.float, device=self.device - ) - self.random_output = torch.tensor(env.random_policy_output).to( - dtype=self.float, device=self.device - ) + self.fixed_output = env.fixed_policy_output + self.random_output = env.random_policy_output self.output_dim = len(self.fixed_output) # Optional base model self.base = base diff --git a/gflownet/proxy/crystals/dave.py b/gflownet/proxy/crystals/dave.py index d261e6b90..7197a9513 100644 --- a/gflownet/proxy/crystals/dave.py +++ b/gflownet/proxy/crystals/dave.py @@ -45,6 +45,10 @@ def 
__init__(self, ckpt_path=None, release=None, rescale_outputs=True, **kwargs) super().__init__(**kwargs) self.rescale_outputs = rescale_outputs self.scaled = False + if "clip" in kwargs: + self.clip = kwargs["clip"] + else: + self.clip = False print("Initializing DAVE proxy:") print(" Checking out release:", release) @@ -98,10 +102,28 @@ def _set_scales(self): self.scaled = True @torch.no_grad() - def __call__(self, states: TensorType["batch", "96"]) -> TensorType["batch"]: + def __call__(self, states: TensorType["batch", "102"]) -> TensorType["batch"]: """ Forward pass of the proxy. + The proxy will decompose the state as: + * composition: ``states[:, :-7]`` -> length 95 (dummy 0 then 94 elements) + * space group: ``states[:, -7] - 1`` + * lattice parameters: ``states[:, -6:]`` + + >>> composition MUST be a list of ATOMIC NUMBERS, prepended with a 0. + >>> dummy padding value at comp[0] MUST be 0. + ie -> comp[i] -> element Z=i + ie -> LiO2 -> [0, 0, 0, 1, 0, 0, 2, 0, ...] up until Z=94 for the MatBench proxy + ie -> len(comp) = 95 (0 then 94 elements) + + >>> sg MUST be a list of ACTUAL space group numbers (1-230) + + >>> lat_params MUST be a list of lattice parameters in the following order: + [a, b, c, alpha, beta, gamma] as floats. + + >>> the states tensor MUST already be on the device. + Args: states (torch.Tensor): States to infer on. Shape: ``(batch, [6 + 1 + n_elements])``. @@ -112,16 +134,9 @@ def __call__(self, states: TensorType["batch", "96"]) -> TensorType["batch"]: self._set_scales() comp = states[:, :-7] - sg = states[:, -7] - 1 + sg = states[:, -7] lat_params = states[:, -6:] - n_env = comp.shape[-1] - if n_env != self.model.n_elements: - missing = torch.zeros( - (len(comp), self.model.n_elements - n_env), device=comp.device - ) - comp = torch.cat([comp, missing], dim=-1) - if self.rescale_outputs: lat_params = (lat_params - self.scales["x"]["mean"]) / self.scales["x"][ "std" @@ -134,6 +149,21 @@ def __call__(self, states: TensorType["batch", "96"]) -> TensorType["batch"]: if self.rescale_outputs: y = y * self.scales["y"]["std"] + self.scales["y"]["mean"] + if self.clip and self.clip.do: + if self.rescale_outputs: + if self.clip.min_stds: + y_min = -1.0 * self.clip.min_stds * self.scales["y"]["std"] + else: + y_min = None + if self.clip.max_stds: + y_max = self.clip.max_stds * self.scales["y"]["std"] + else: + y_max = None + else: + y_min = self.clip.min + y_max = self.clip.max + y = torch.clamp(min=y_min, max=y_max) + return y @torch.no_grad() diff --git a/gflownet/utils/buffer.py b/gflownet/utils/buffer.py index c66f4d8d3..eb105613a 100644 --- a/gflownet/utils/buffer.py +++ b/gflownet/utils/buffer.py @@ -200,13 +200,33 @@ def make_data_set(self, config): """ if config is None: return None, None - elif "path" in config and config.path is not None: - path = self.logger.logdir / Path("data") / config.path - df = pd.read_csv(path, index_col=0) - # TODO: check if state2readable transformation is required. 
- return df - elif "type" not in config: + print("\nConstructing data set ", end="") + if "type" not in config: return None, None + elif config.type == "pkl" and "path" in config: + print(f"from pickled file: {config.path}\n") + with open(config.path, "rb") as f: + data_dict = pickle.load(f) + samples = data_dict["x"] + n_samples_orig = len(samples) + print(f"The data set containts {n_samples_orig} samples", end="") + samples = self.env.process_data_set(samples) + n_samples_new = len(samples) + if n_samples_new != n_samples_orig: + print( + f", but only {n_samples_new} are valid according to the " + "environment settings. Invalid samples have been discarded." + ) + n_max = 100 + samples = samples[:n_max] + print(f"Only the first {n_max} samples will be kept in the data.") + print("Remember to write a function to normalise the data in code") + print("Max number of elements in data set has to match config") + print("Actually, write a function that contrasts the stats") + elif config.type == "csv" and "path" in config: + print(f"from CSV: {config.path}\n") + df = pd.read_csv(config.path, index_col=0) + samples = df.iloc[:, :-1].values elif config.type == "all" and hasattr(self.env, "get_all_terminating_states"): samples = self.env.get_all_terminating_states() elif ( @@ -214,6 +234,7 @@ def make_data_set(self, config): and "n" in config and hasattr(self.env, "get_grid_terminating_states") ): + print(f"by sampling a grid of {config.n} points\n") samples = self.env.get_grid_terminating_states(config.n) elif ( config.type == "uniform" @@ -221,12 +242,14 @@ def make_data_set(self, config): and "seed" in config and hasattr(self.env, "get_uniform_terminating_states") ): + print(f"by sampling {config.n} points uniformly\n") samples = self.env.get_uniform_terminating_states(config.n, config.seed) elif ( config.type == "random" and "n" in config and hasattr(self.env, "get_random_terminating_states") ): + print(f"by sampling {config.n} points randomly\n") samples = self.env.get_random_terminating_states(config.n) else: return None, None diff --git a/gflownet/utils/common.py b/gflownet/utils/common.py index afa751816..2e4a6b2d2 100644 --- a/gflownet/utils/common.py +++ b/gflownet/utils/common.py @@ -11,6 +11,8 @@ from omegaconf import OmegaConf from torchtyping import TensorType +from gflownet.utils.policy import parse_policy_config + def set_device(device: Union[str, torch.device]): if isinstance(device, torch.device): @@ -102,59 +104,89 @@ def find_latest_checkpoint(ckpt_dir, pattern): return sorted(ckpts, key=lambda f: float(f.stem.split("iter")[1]))[-1] -def load_gflow_net_from_run_path(run_path, device="cuda"): - device = str(device) +def load_gflow_net_from_run_path( + run_path, + no_wandb=True, + print_config=False, + device="cuda", + load_final_ckpt=True, +): run_path = resolve_path(run_path) hydra_dir = run_path / ".hydra" + with initialize_config_dir( version_base=None, config_dir=str(hydra_dir), job_name="xxx" ): config = compose(config_name="config") + + if print_config: print(OmegaConf.to_yaml(config)) - # Disable wandb - config.logger.do.online = False + + if no_wandb: + # Disable wandb + config.logger.do.online = False + # Logger logger = instantiate(config.logger, config, _recursive_=False) # The proxy is required in the env for scoring: might be an oracle or a model proxy = instantiate( config.proxy, - device=device, + device=config.device, float_precision=config.float_precision, ) # The proxy is passed to env and used for computing rewards env = instantiate( config.env, proxy=proxy, - 
device=device, + device=config.device, float_precision=config.float_precision, ) + forward_config = parse_policy_config(config, kind="forward") + backward_config = parse_policy_config(config, kind="backward") + forward_policy = instantiate( + forward_config, + env=env, + device=config.device, + float_precision=config.float_precision, + ) + backward_policy = instantiate( + backward_config, + env=env, + device=config.device, + float_precision=config.float_precision, + base=forward_policy, + ) gflownet = instantiate( config.gflownet, - device=device, + device=config.device, float_precision=config.float_precision, env=env, buffer=config.env.buffer, + forward_policy=forward_policy, + backward_policy=backward_policy, logger=logger, ) - # Load final models - ckpt_dir = Path(run_path) / config.logger.logdir.ckpts - forward_latest = find_latest_checkpoint( - ckpt_dir, config.gflownet.policy.forward.checkpoint - ) + + if not load_final_ckpt: + return gflownet, config + + # ------------------------------- + # ----- Load final models ----- + # ------------------------------- + + ckpt = [f for f in run_path.rglob(config.logger.logdir.ckpts) if f.is_dir()][0] + forward_final = find_latest_checkpoint(ckpt, "pf") gflownet.forward_policy.model.load_state_dict( - torch.load(forward_latest, map_location=device) + torch.load(forward_final, map_location=set_device(device)) ) try: - backward_latest = find_latest_checkpoint( - ckpt_dir, config.gflownet.policy.backward.checkpoint - ) + backward_final = find_latest_checkpoint(ckpt, "pb") gflownet.backward_policy.model.load_state_dict( - torch.load(backward_latest, map_location=device) + torch.load(backward_final, map_location=set_device(device)) ) - except AttributeError: + except ValueError: print("No backward policy found") - - return gflownet + return gflownet, config def batch_with_rest(start, stop, step, tensor=False): diff --git a/gflownet/utils/logger.py b/gflownet/utils/logger.py index e9556f9c5..50356b377 100644 --- a/gflownet/utils/logger.py +++ b/gflownet/utils/logger.py @@ -32,6 +32,7 @@ def __init__( run_name=None, tags: list = None, context: str = "0", + notes: str = None, ): self.config = config self.do = do @@ -60,7 +61,7 @@ def __init__( if slurm_job_id: wandb_config["slurm_job_id"] = slurm_job_id self.run = self.wandb.init( - config=wandb_config, project=project_name, name=run_name + config=wandb_config, project=project_name, name=run_name, notes=notes ) else: self.wandb = None diff --git a/mila/launch.py b/mila/launch.py index 9b122ec7f..de00b79a7 100644 --- a/mila/launch.py +++ b/mila/launch.py @@ -7,8 +7,8 @@ from os.path import expandvars from pathlib import Path from textwrap import dedent -from git import Repo +from git import Repo from yaml import safe_load ROOT = Path(__file__).resolve().parent.parent @@ -265,6 +265,19 @@ def find_jobs_conf(args): return jobs_conf_path, local_out_dir +def quote(value): + v = str(value) + v = v.replace("(", r"\(").replace(")", r"\)") + if " " in v or "=" in v: + if '"' not in v: + v = f'"{v}"' + elif "'" not in v: + v = f"'{v}'" + else: + raise ValueError(f"Cannot quote {value}") + return v + + def script_dict_to_main_args_str(script_dict, is_first=True, nested_key=""): """ Recursively turns a dict of script args into a string of main.py args @@ -275,11 +288,24 @@ def script_dict_to_main_args_str(script_dict, is_first=True, nested_key=""): previous_str (str, optional): base string to append to. Defaults to "". 
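+
+    For example (illustrative), ``{"env": {"__value__": "ccube", "n_dim": 3}}``
+    becomes ``env=ccube env.n_dim=3``.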
""" if not isinstance(script_dict, dict): - return nested_key + "=" + str(script_dict) + " " + candidate = f"{nested_key}={quote(script_dict)}" + if candidate.count("=") > 1: + assert "'" not in candidate, """Keys cannot contain ` ` and `'` and `=` """ + candidate = f"'{candidate}'" + return candidate + " " new_str = "" for k, v in script_dict.items(): if k == "__value__": - new_str += nested_key + "=" + str(v) + " " + value = str(v) + if " " in value: + value = f"'{value}'" + candidate = f"{nested_key}={quote(v)} " + if candidate.count("=") > 1: + assert ( + "'" not in candidate + ), """Keys cannot contain ` ` and `'` and `=` """ + candidate = f"'{candidate}'" + new_str += candidate continue new_key = k if not nested_key else nested_key + "." + str(k) new_str += script_dict_to_main_args_str(v, nested_key=new_key, is_first=False) @@ -375,17 +401,21 @@ def code_dir_for_slurm_tmp_dir_checkout(git_checkout): sys.exit(0) GIT_WARNING = False + repo_url = ssh_to_https(repo.remotes.origin.url) + repo_name = repo_url.split("/")[-1].split(".git")[0] + return dedent( """\ $SLURM_TMPDIR - git clone {git_url} tpm-gflownet - cd tpm-gflownet + git clone {git_url} tmp-{repo_name} + cd tmp-{repo_name} {git_checkout} echo "Current commit: $(git rev-parse HEAD)" """ ).format( - git_url=ssh_to_https(repo.remotes.origin.url), + git_url=repo_url, git_checkout=f"git checkout {git_checkout}" if git_checkout else "", + repo_name=repo_name, ) @@ -615,13 +645,20 @@ def code_dir_for_slurm_tmp_dir_checkout(git_checkout): sbatch_path.parent.mkdir(parents=True, exist_ok=True) # write template sbatch_path.write_text(templated) - print(f"\n 🏷 Created ./{sbatch_path.relative_to(Path.cwd())}") + print() # Submit job to SLURM - out = popen(f"sbatch {sbatch_path}").read() + out = popen(f"sbatch {sbatch_path}").read().strip() # Identify printed-out job id job_id = re.findall(r"Submitted batch job (\d+)", out)[0] job_ids.append(job_id) print(" ✅ " + out) + # Rename sbatch file with job id + parts = sbatch_path.stem.split(f"_{now}") + new_name = f"{parts[0]}_{job_id}_{now}" + if len(parts) > 1: + new_name += f"_{parts[1]}" + sbatch_path = sbatch_path.rename(sbatch_path.parent / new_name) + print(f" 🏷 Created ./{sbatch_path.relative_to(Path.cwd())}") # Write job ID & output file path in the sbatch file job_output_file = str(outdir / f"{job_args['job_name']}-{job_id}.out") job_out_files.append(job_output_file) diff --git a/scripts/eval_gflownet.py b/scripts/eval_gflownet.py index c0cb359db..760155811 100644 --- a/scripts/eval_gflownet.py +++ b/scripts/eval_gflownet.py @@ -1,17 +1,23 @@ """ Computes evaluation metrics and plots from a pre-trained GFlowNet model. 
""" +import pickle +import shutil import sys from argparse import ArgumentParser from pathlib import Path -import hydra +import pandas as pd import torch -from hydra import compose, initialize, initialize_config_dir -from omegaconf import OmegaConf -from torch.distributions.categorical import Categorical +from tqdm import tqdm -from gflownet.gflownet import GFlowNetAgent, Policy +sys.path.append(str(Path(__file__).resolve().parent.parent)) + +from crystalrandom import generate_random_crystals + +from gflownet.gflownet import GFlowNetAgent +from gflownet.utils.common import load_gflow_net_from_run_path +from gflownet.utils.policy import parse_policy_config def add_args(parser): @@ -33,10 +39,77 @@ def add_args(parser): type=int, help="Number of sequences to sample", ) + parser.add_argument( + "--sampling_batch_size", + default=100, + type=int, + help="Number of samples to generate at a time to " + + "avoid memory issues. Will sum to n_samples.", + ) + parser.add_argument( + "--output_dir", + default=None, + type=str, + help="Path to output directory. If not provided, will use run_path.", + ) + parser.add_argument( + "--print_config", + default=False, + action="store_true", + help="Print the config file", + ) + parser.add_argument( + "--samples_only", + default=False, + action="store_true", + help="Only sample from the model, do not compute metrics", + ) + parser.add_argument( + "--randominit", + action="store_true", + help="Sample from an untrained GFlowNet", + ) + parser.add_argument( + "--random_crystals", + action="store_true", + help="Sample crystals uniformly, without constraints", + ) parser.add_argument("--device", default="cpu", type=str) return parser +def get_batch_sizes(total, b=1): + """ + Batches an iterable into chunks of size n and returns their expected lengths + + Args: + total (int): total samples to produce + b (int): the batch size + + Returns: + list: list of batch sizes + """ + n = total // b + chunks = [b] * n + if total % b != 0: + chunks += [total % b] + return chunks + + +def print_args(args): + """ + Prints the arguments + + Args: + args (argparse.Namespace): the parsed arguments + """ + print("Arguments:") + darg = vars(args) + max_k = max([len(k) for k in darg]) + for k in darg: + print(f"\t{k:{max_k}}: {darg[k]}") + + def set_device(device: str): if device.lower() == "cuda" and torch.cuda.is_available(): return torch.device("cuda") @@ -45,62 +118,136 @@ def set_device(device: str): def main(args): - # Load config - with initialize_config_dir( - version_base=None, config_dir=args.run_path + "/.hydra", job_name="xxx" - ): - config = compose(config_name="config") - print(OmegaConf.to_yaml(config)) - # Disable wandb - config.logger.do.online = False - # Logger - logger = hydra.utils.instantiate(config.logger, config, _recursive_=False) - # The proxy is required in the env for scoring: might be an oracle or a model - proxy = hydra.utils.instantiate( - config.proxy, - device=config.device, - float_precision=config.float_precision, - ) - # The proxy is passed to env and used for computing rewards - env = hydra.utils.instantiate( - config.env, - proxy=proxy, - device=config.device, - float_precision=config.float_precision, - ) - gflownet = hydra.utils.instantiate( - config.gflownet, - device=config.device, - float_precision=config.float_precision, - env=env, - buffer=config.env.buffer, - logger=logger, - ) - # Load final models - ckpt = Path(args.run_path) / config.logger.logdir.ckpts - forward_final = [ - f for f in 
ckpt.glob(f"{config.gflownet.policy.forward.checkpoint}*final*") - ][0] - gflownet.forward_policy.model.load_state_dict( - torch.load(forward_final, map_location=set_device(args.device)) - ) - backward_final = [ - f for f in ckpt.glob(f"{config.gflownet.policy.backward.checkpoint}*final*") - ][0] - gflownet.backward_policy.model.load_state_dict( - torch.load(backward_final, map_location=set_device(args.device)) + if args.randominit: + prefix = "randominit" + load_final_ckpt = False + else: + prefix = "gfn" + load_final_ckpt = True + + gflownet, config = load_gflow_net_from_run_path( + run_path=args.run_path, + device=args.device, + no_wandb=True, + print_config=args.print_config, + load_final_ckpt=load_final_ckpt, ) - # Test GFlowNet model - gflownet.logger.test.n = args.n_samples - l1, kl, jsd, figs = gflownet.test() - # Save figures - keys = ["True reward and GFlowNet samples", "GFlowNet KDE Policy", "Reward KDE"] - fignames = ["samples", "kde_gfn", "kde_reward"] - output_dir = Path(args.run_path) / "figures" + env = gflownet.env + + base_dir = Path(args.output_dir or args.run_path) + + # --------------------------------- + # ----- Test GFlowNet model ----- + # --------------------------------- + + if not args.samples_only: + gflownet.logger.test.n = args.n_samples + ( + l1, + kl, + jsd, + corr_prob_traj_rew, + var_logrew_logp, + nll, + figs, + env_metrics, + ) = gflownet.test() + # Save figures + keys = ["True reward and GFlowNet samples", "GFlowNet KDE Policy", "Reward KDE"] + fignames = ["samples", "kde_gfn", "kde_reward"] + + output_dir = base_dir / "figures" + print("output_dir: ", str(output_dir)) + output_dir.mkdir(parents=True, exist_ok=True) + + for fig, figname in zip(figs, fignames): + output_fig = output_dir / figname + if fig is not None: + fig.savefig(output_fig, bbox_inches="tight") + print(f"Saved figures to {output_dir}") + + # Print metrics + print(f"L1: {l1}") + print(f"KL: {kl}") + print(f"JSD: {jsd}") + print(f"Corr (exp(logp), rewards): {corr_prob_traj_rew}") + print(f"Var (log(R) - logp): {var_logrew_logp}") + print(f"NLL: {nll}") + + # ------------------------------------------ + # ----- Sample GFlowNet ----- + # ------------------------------------------ + + output_dir = base_dir / "eval" / "samples" output_dir.mkdir(parents=True, exist_ok=True) - for fig, figname in zip(figs, fignames): - output_fig = output_dir / figname - fig.savefig(output_fig, bbox_inches="tight") + tmp_dir = output_dir / "tmp" + tmp_dir.mkdir(parents=True, exist_ok=True) + + if args.n_samples > 0 and args.n_samples <= 1e5: + print( + f"Sampling {args.n_samples} forward trajectories", + f"from GFlowNet in batches of {args.sampling_batch_size}", + ) + for i, bs in enumerate( + tqdm(get_batch_sizes(args.n_samples, args.sampling_batch_size)) + ): + batch, times = gflownet.sample_batch(n_forward=bs, train=False) + x_sampled = batch.get_terminating_states(proxy=True) + energies = env.oracle(x_sampled) + x_sampled = batch.get_terminating_states() + df = pd.DataFrame( + { + "readable": [env.state2readable(x) for x in x_sampled], + "energies": energies.tolist(), + } + ) + df.to_csv(tmp_dir / f"gfn_samples_{i}.csv") + dct = {"x": x_sampled, "energy": energies.tolist()} + pickle.dump(dct, open(tmp_dir / f"gfn_samples_{i}.pkl", "wb")) + + # Concatenate all samples + print("Concatenating sample CSVs") + df = pd.concat([pd.read_csv(f) for f in tqdm(list(tmp_dir.glob("*.csv")))]) + df.to_csv(output_dir / f"{prefix}_samples.csv") + dct = {"x": [], "energy": []} + for f in tqdm(list(tmp_dir.glob("*.pkl"))): 
+ tmp_dict = pickle.load(open(f, "rb")) + dct = {k: v + tmp_dict[k] for k, v in dct.items()} + pickle.dump(dct, open(output_dir / f"{prefix}_samples.pkl", "wb")) + + if "y" in input("Delete temporary files? (y/n)"): + shutil.rmtree(tmp_dir) + + # ------------------------------------ + # ----- Sample random crystals ----- + # ------------------------------------ + + # Sample random crystals uniformly without constraints + if args.random_crystals and args.n_samples > 0 and args.n_samples <= 1e5: + print(f"Sampling {args.n_samples} random crystals without constraints...") + x_sampled = generate_random_crystals( + n_samples=args.n_samples, + elements=config.env.composition_kwargs.elements, + min_elements=2, + max_elements=5, + max_atoms=config.env.composition_kwargs.max_atoms, + max_atom_i=config.env.composition_kwargs.max_atom_i, + space_groups=config.env.space_group_kwargs.space_groups_subset, + min_length=0.0, + max_length=1.0, + min_angle=0.0, + max_angle=1.0, + ) + energies = env.oracle(env.statebatch2proxy(x_sampled)) + df = pd.DataFrame( + { + "readable": [env.state2readable(x) for x in x_sampled], + "energies": energies.tolist(), + } + ) + df.to_csv(output_dir / "randomcrystals_samples.csv") + dct = {"x": x_sampled, "energy": energies.tolist()} + pickle.dump(dct, open(output_dir / "randomcrystals_samples.pkl", "wb")) if __name__ == "__main__": @@ -108,6 +255,8 @@ def main(args): _, override_args = parser.parse_known_args() parser = add_args(parser) args = parser.parse_args() + torch.set_grad_enabled(False) torch.set_num_threads(1) + print_args(args) main(args) sys.exit() diff --git a/scripts/fit_lattice_proxy.py b/scripts/fit_lattice_proxy.py index a416a9bc8..83550d8b2 100644 --- a/scripts/fit_lattice_proxy.py +++ b/scripts/fit_lattice_proxy.py @@ -17,7 +17,6 @@ from gflownet.envs.crystals.lattice_parameters import LatticeParameters from gflownet.proxy.crystals.lattice_parameters import PICKLE_PATH - DATASET_PATH = ( Path(__file__).parents[1] / "data" / "crystals" / "matbench_mp_e_form_lp_stats.csv" ) diff --git a/scripts/mp20_matbench_lp_range.py b/scripts/mp20_matbench_lp_range.py index 4d3ec5180..8ae7fdedd 100644 --- a/scripts/mp20_matbench_lp_range.py +++ b/scripts/mp20_matbench_lp_range.py @@ -7,7 +7,6 @@ import numpy as np import pandas as pd - if __name__ == "__main__": mp = pd.read_csv(Path(__file__).parents[1] / "data/crystals/mp20_lp_stats.csv") mb = pd.read_csv( diff --git a/tests/gflownet/envs/common.py b/tests/gflownet/envs/common.py index 9f5d1b9a3..862c8bdcd 100644 --- a/tests/gflownet/envs/common.py +++ b/tests/gflownet/envs/common.py @@ -158,10 +158,7 @@ def test__sample_actions__backward__returns_eos_if_done(env, n=5): env.set_state(state, done=True) masks.append(env.get_mask_invalid_actions_backward()) # Build random policy outputs and tensor masks - policy_outputs = torch.tile( - tfloat(env.random_policy_output, float_type=env.float, device=env.device), - (len(states), 1), - ) + policy_outputs = torch.tile(env.random_policy_output, (len(states), 1)) # Add noise to policy outputs policy_outputs += torch.randn(policy_outputs.shape) masks = tbool(masks, device=env.device) @@ -188,10 +185,7 @@ def test__get_logprobs__backward__returns_zero_if_done(env, n=5): (len(states), 1), ) # Build random policy outputs and tensor masks - policy_outputs = torch.tile( - tfloat(env.random_policy_output, float_type=env.float, device=env.device), - (len(states), 1), - ) + policy_outputs = torch.tile(env.random_policy_output, (len(states), 1)) # Add noise to policy outputs policy_outputs 
+= torch.randn(policy_outputs.shape) masks = tbool(masks, device=env.device) @@ -306,7 +300,7 @@ def test__gflownet_minimal_runs(env): def test__sample_actions__get_logprobs__return_valid_actions_and_logprobs(env): env = env.reset() while not env.done: - policy_outputs = torch.unsqueeze(torch.tensor(env.random_policy_output), 0) + policy_outputs = torch.unsqueeze(env.random_policy_output, 0) mask_invalid = env.get_mask_invalid_actions_forward() valid_actions = [a for a, m in zip(env.action_space, mask_invalid) if not m] masks_invalid_torch = torch.unsqueeze(torch.BoolTensor(mask_invalid), 0) @@ -330,11 +324,9 @@ def test__sample_actions__get_logprobs__return_valid_actions_and_logprobs(env): @pytest.mark.repeat(1000) def test__forward_actions_have_nonzero_backward_prob(env): env = env.reset() - policy_random = torch.unsqueeze( - tfloat(env.random_policy_output, float_type=env.float, device=env.device), 0 - ) + policy_random = torch.unsqueeze(env.random_policy_output, 0) while not env.done: - state_new, action, valid = env.step_random(backward=False) + state_next, action, valid = env.step_random(backward=False) if not valid: continue # Get backward logprobs @@ -343,19 +335,56 @@ def test__forward_actions_have_nonzero_backward_prob(env): actions_torch = torch.unsqueeze( tfloat(action, float_type=env.float, device=env.device), 0 ) - states_torch = torch.unsqueeze( - tfloat(env.state, float_type=env.float, device=env.device), 0 - ) policy_outputs = policy_random.clone().detach() logprobs_bw = env.get_logprobs( policy_outputs=policy_outputs, actions=actions_torch, mask=masks, - states_from=states_torch, + states_from=[env.state], is_backward=True, ) assert torch.isfinite(logprobs_bw) assert logprobs_bw > -1e6 + state_prev = copy(state_next) + + +@pytest.mark.repeat(1000) +def test__trajectories_are_reversible(env): + # Skip for certain environments until fixed: + skip_envs = ["Crystal", "LatticeParameters", "Tree"] + if env.__class__.__name__ in skip_envs: + warnings.warn("Skipping test for this specific environment.") + return + env = env.reset() + + # Sample random forward trajectory + states_trajectory_fw = [] + actions_trajectory_fw = [] + while not env.done: + state, action, valid = env.step_random(backward=False) + if valid: + states_trajectory_fw.append(state) + actions_trajectory_fw.append(action) + + # Sample backward trajectory with actions in forward trajectory + states_trajectory_bw = [] + actions_trajectory_bw = [] + actions_trajectory_fw_copy = actions_trajectory_fw.copy() + while not env.equal(env.state, env.source) or env.done: + state, action, valid = env.step_backwards(actions_trajectory_fw_copy.pop()) + if valid: + states_trajectory_bw.append(state) + actions_trajectory_bw.append(action) + + assert all( + [ + env.equal(s_fw, s_bw) + for s_fw, s_bw in zip( + states_trajectory_fw[:-1], states_trajectory_bw[-2::-1] + ) + ] + ) + assert actions_trajectory_fw == actions_trajectory_bw[::-1] @pytest.mark.repeat(1000) @@ -407,15 +436,13 @@ def test__backward_actions_have_nonzero_forward_prob(env, n=1000): if states is None: warnings.warn("Skipping test because states are None.") return - policy_random = torch.unsqueeze( - tfloat(env.random_policy_output, float_type=env.float, device=env.device), 0 - ) + policy_random = torch.unsqueeze(env.random_policy_output, 0) for state in states: env.set_state(state, done=True) while True: if env.equal(env.state, env.source): break - state_new, action, valid = env.step_random(backward=True) + state_next, action, valid = 
env.step_random(backward=True) assert valid # Get forward logprobs mask_fw = env.get_mask_invalid_actions_forward() @@ -423,19 +450,17 @@ def test__backward_actions_have_nonzero_forward_prob(env, n=1000): actions_torch = torch.unsqueeze( tfloat(action, float_type=env.float, device=env.device), 0 ) - states_torch = torch.unsqueeze( - tfloat(env.state, float_type=env.float, device=env.device), 0 - ) policy_outputs = policy_random.clone().detach() logprobs_fw = env.get_logprobs( policy_outputs=policy_outputs, actions=actions_torch, mask=masks, - states_from=states_torch, + states_from=[env.state], is_backward=False, ) assert torch.isfinite(logprobs_fw) assert logprobs_fw > -1e6 + state_prev = copy(state_next) @pytest.mark.repeat(10) diff --git a/tests/gflownet/envs/test_ccrystal.py b/tests/gflownet/envs/test_ccrystal.py new file mode 100644 index 000000000..6bd49c576 --- /dev/null +++ b/tests/gflownet/envs/test_ccrystal.py @@ -0,0 +1,1296 @@ +import warnings + +import common +import numpy as np +import pytest +import torch +from torch import Tensor + +from gflownet.envs.crystals.ccrystal import CCrystal, Stage +from gflownet.envs.crystals.clattice_parameters import TRICLINIC +from gflownet.utils.common import tbool, tfloat + +SG_SUBSET_ALL_CLS_PS = [ + 1, + 2, + 3, + 6, + 16, + 17, + 67, + 81, + 89, + 127, + 143, + 144, + 146, + 148, + 168, + 169, + 189, + 195, + 200, + 230, +] + + +@pytest.fixture +def env(): + return CCrystal( + composition_kwargs={"elements": 4}, + do_composition_to_sg_constraints=False, + space_group_kwargs={"space_groups_subset": list(range(1, 15 + 1)) + [105]}, + ) + + +@pytest.fixture +def env_with_stoichiometry_sg_check(): + return CCrystal( + composition_kwargs={"elements": 4}, + do_composition_to_sg_constraints=True, + space_group_kwargs={"space_groups_subset": SG_SUBSET_ALL_CLS_PS}, + ) + + +def test__stage_next__returns_expected(): + assert Stage.next(Stage.COMPOSITION) == Stage.SPACE_GROUP + assert Stage.next(Stage.SPACE_GROUP) == Stage.LATTICE_PARAMETERS + assert Stage.next(Stage.LATTICE_PARAMETERS) == Stage.DONE + assert Stage.next(Stage.DONE) == None + + +def test__stage_prev__returns_expected(): + assert Stage.prev(Stage.COMPOSITION) == Stage.DONE + assert Stage.prev(Stage.SPACE_GROUP) == Stage.COMPOSITION + assert Stage.prev(Stage.LATTICE_PARAMETERS) == Stage.SPACE_GROUP + assert Stage.prev(Stage.DONE) == Stage.LATTICE_PARAMETERS + + +def test__environment__initializes_properly(env): + pass + + +def test__environment__has_expected_initial_state(env): + """ + The source of the composition and space group environments is all 0s. The source of + the continuous lattice parameters environment is all -1s. 
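+    With the fixture's composition_kwargs={"elements": 4}, the global state thus has
+    1 (stage) + 4 (composition) + 3 (space group) + 6 (lattice parameters) = 14
+    entries, as asserted below.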
+ """ + assert ( + env.state == env.source == [0] * (1 + 4 + 3) + [-1] * 6 + ) # stage + n elements + space groups + lattice parameters + + +def test__environment__has_expected_action_space(env): + assert len(env.action_space) == len( + env.subenvs[Stage.COMPOSITION].action_space + ) + len(env.subenvs[Stage.SPACE_GROUP].action_space) + len( + env.subenvs[Stage.LATTICE_PARAMETERS].action_space + ) + + underlying_action_space = ( + env.subenvs[Stage.COMPOSITION].action_space + + env.subenvs[Stage.SPACE_GROUP].action_space + + env.subenvs[Stage.LATTICE_PARAMETERS].action_space + ) + + for action, underlying_action in zip(env.action_space, underlying_action_space): + assert action[: len(underlying_action)] == underlying_action + + +def test__pad_depad_action(env): + for stage, subenv in env.subenvs.items(): + for action in subenv.action_space: + padded = env._pad_action(action, stage) + assert len(padded) == env.max_action_length + depadded = env._depad_action(padded, stage) + assert depadded == action + + +@pytest.mark.parametrize( + "states", + [ + [ + [0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [0, 0, 4, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [1, 3, 1, 0, 6, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [2, 1, 0, 4, 0, 4, 3, 105, 0.12, 0.23, 0.34, 0.45, 0.56, 0.67], + [1, 3, 1, 0, 6, 1, 0, 0, -1, -1, -1, -1, -1, -1], + [1, 3, 1, 0, 6, 1, 1, 0, -1, -1, -1, -1, -1, -1], + [0, 3, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [0, 3, 0, 0, 6, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [1, 3, 1, 0, 6, 1, 2, 0, -1, -1, -1, -1, -1, -1], + [0, 3, 1, 0, 6, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [2, 1, 0, 4, 0, 4, 3, 105, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + [2, 1, 0, 4, 0, 4, 3, 105, 0.76, 0.75, 0.74, 0.73, 0.72, 0.71], + [0, 0, 4, 3, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + ], + ], +) +def test__statetorch2policy__is_concatenation_of_subenv_states(env, states): + # Get policy states from the batch of states converted into each subenv + states_dict = {stage: [] for stage in env.subenvs} + for state in states: + for stage in env.subenvs: + states_dict[stage].append(env._get_state_of_subenv(state, stage)) + states_policy_dict = { + stage: subenv.statebatch2policy(states_dict[stage]) + for stage, subenv in env.subenvs.items() + } + states_policy_expected = torch.cat( + [el for el in states_policy_dict.values()], dim=1 + ) + # Get policy states from env.statetorch2policy + states_torch = tfloat(states, float_type=env.float, device=env.device) + states_policy = env.statetorch2policy(states_torch) + assert torch.all(torch.eq(states_policy, states_policy_expected)) + + +@pytest.mark.parametrize( + "states", + [ + [ + [0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [0, 0, 4, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [1, 3, 1, 0, 6, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [2, 1, 0, 4, 0, 4, 3, 105, 0.12, 0.23, 0.34, 0.45, 0.56, 0.67], + [1, 3, 1, 0, 6, 1, 0, 0, -1, -1, -1, -1, -1, -1], + [1, 3, 1, 0, 6, 1, 1, 0, -1, -1, -1, -1, -1, -1], + [0, 3, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [0, 3, 0, 0, 6, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [1, 3, 1, 0, 6, 1, 2, 0, -1, -1, -1, -1, -1, -1], + [0, 3, 1, 0, 6, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [2, 1, 0, 4, 0, 4, 3, 105, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + [2, 1, 0, 4, 0, 4, 3, 105, 0.76, 0.75, 0.74, 0.73, 0.72, 0.71], + [0, 0, 4, 3, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + ], + ], +) +def test__statetorch2proxy__is_concatenation_of_subenv_states(env, states): + # Get proxy states from the batch of states converted into each subenv + states_dict = {stage: [] for stage in 
env.subenvs} + for state in states: + for stage in env.subenvs: + states_dict[stage].append(env._get_state_of_subenv(state, stage)) + states_proxy_dict = { + stage: subenv.statebatch2proxy(states_dict[stage]) + for stage, subenv in env.subenvs.items() + } + states_proxy_expected = torch.cat([el for el in states_proxy_dict.values()], dim=1) + # Get proxy states from env.statetorch2proxy + states_torch = tfloat(states, float_type=env.float, device=env.device) + states_proxy = env.statetorch2proxy(states_torch) + assert torch.all(torch.eq(states_proxy, states_proxy_expected)) + + +@pytest.mark.parametrize( + "states", + [ + [ + [0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [0, 0, 4, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [1, 3, 1, 0, 6, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [2, 1, 0, 4, 0, 4, 3, 105, 0.12, 0.23, 0.34, 0.45, 0.56, 0.67], + [1, 3, 1, 0, 6, 1, 0, 0, -1, -1, -1, -1, -1, -1], + [1, 3, 1, 0, 6, 1, 1, 0, -1, -1, -1, -1, -1, -1], + [0, 3, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [0, 3, 0, 0, 6, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [1, 3, 1, 0, 6, 1, 2, 0, -1, -1, -1, -1, -1, -1], + [0, 3, 1, 0, 6, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [2, 1, 0, 4, 0, 4, 3, 105, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + [2, 1, 0, 4, 0, 4, 3, 105, 0.76, 0.75, 0.74, 0.73, 0.72, 0.71], + [0, 0, 4, 3, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + ], + ], +) +def test__state2readable__is_concatenation_of_subenv_states(env, states): + # Get policy states from the batch of states converted into each subenv + states_readable_expected = [] + for state in states: + readables = [] + for stage, subenv in env.subenvs.items(): + readables.append( + subenv.state2readable(env._get_state_of_subenv(state, stage)) + ) + states_readable_expected.append( + f"{env._get_stage(state)}; " + f"Composition = {readables[0]}; " + f"SpaceGroup = {readables[1]}; " + f"LatticeParameters = {readables[2]}" + ) + # Get policy states from env.statetorch2policy + states_readable = [env.state2readable(state) for state in states] + for readable, readable_expected in zip(states_readable, states_readable_expected): + assert readable == readable_expected + + +@pytest.mark.parametrize( + "state, state_composition, state_space_group, state_lattice_parameters", + [ + [ + [0, 1, 0, 4, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [1, 0, 4, 0], + [0, 0, 0], + [-1, -1, -1, -1, -1, -1], + ], + [ + [0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [0, 0, 0, 0], + [0, 0, 0], + [-1, -1, -1, -1, -1, -1], + ], + [ + [1, 1, 0, 4, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [1, 0, 4, 0], + [0, 0, 0], + [-1, -1, -1, -1, -1, -1], + ], + [ + [1, 1, 0, 4, 0, 4, 3, 105, -1, -1, -1, -1, -1, -1], + [1, 0, 4, 0], + [4, 3, 105], + [-1, -1, -1, -1, -1, -1], + ], + [ + [1, 1, 0, 4, 0, 4, 3, 105, -1, -1, -1, -1, -1, -1], + [1, 0, 4, 0], + [4, 3, 105], + [-1, -1, -1, -1, -1, -1], + ], + [ + [2, 1, 0, 4, 0, 4, 3, 105, -1, -1, -1, -1, -1, -1], + [1, 0, 4, 0], + [4, 3, 105], + [-1, -1, -1, -1, -1, -1], + ], + [ + [2, 1, 0, 4, 0, 4, 3, 105, -1, -1, -1, -1, -1, -1], + [1, 0, 4, 0], + [4, 3, 105], + [-1, -1, -1, -1, -1, -1], + ], + [ + [2, 1, 0, 4, 0, 4, 3, 105, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + [1, 0, 4, 0], + [4, 3, 105], + [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + ], + [ + [2, 1, 0, 4, 0, 4, 3, 105, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + [1, 0, 4, 0], + [4, 3, 105], + [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + ], + [ + [2, 1, 0, 4, 0, 4, 3, 105, 0.76, 0.75, 0.74, 0.73, 0.72, 0.71], + [1, 0, 4, 0], + [4, 3, 105], + [0.76, 0.75, 0.74, 0.73, 0.72, 0.71], + ], + [ + [2, 1, 0, 4, 0, 4, 3, 105, 0.76, 0.75, 0.74, 0.73, 
0.72, 0.71], + [1, 0, 4, 0], + [4, 3, 105], + [0.76, 0.75, 0.74, 0.73, 0.72, 0.71], + ], + ], +) +def test__state_of_subenv__returns_expected( + env, state, state_composition, state_space_group, state_lattice_parameters +): + for stage in env.subenvs: + state_subenv = env._get_state_of_subenv(state, stage) + if stage == Stage.COMPOSITION: + assert state_subenv == state_composition + elif stage == Stage.SPACE_GROUP: + assert state_subenv == state_space_group + elif stage == Stage.LATTICE_PARAMETERS: + assert state_subenv == state_lattice_parameters + else: + raise ValueError(f"Unrecognized stage {stage}.") + + +@pytest.mark.parametrize( + "env_input, state, dones, has_lattice_parameters, has_composition_constraints", + [ + ( + "env", + [0, 0, 4, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [False, False, False], + False, + False, + ), + ( + "env", + [0, 0, 4, 3, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [False, False, False], + False, + False, + ), + ( + "env", + [1, 3, 1, 0, 6, 1, 2, 0, -1, -1, -1, -1, -1, -1], + [True, False, False], + True, + False, + ), + ( + "env", + [2, 1, 0, 4, 0, 4, 3, 105, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + [True, True, False], + True, + False, + ), + ( + "env_with_stoichiometry_sg_check", + [2, 4, 0, 4, 0, 4, 3, 105, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + [True, True, False], + True, + True, + ), + ], +) +def test__set_state__sets_state_subenvs_dones_and_constraints( + env_input, + state, + dones, + has_lattice_parameters, + has_composition_constraints, + request, +): + env = request.getfixturevalue(env_input) + env.set_state(state) + # Check global state + assert env.state == state + + # Check states of subenvs + for stage, subenv in env.subenvs.items(): + assert subenv.state == env._get_state_of_subenv(state, stage) + + # Check dones + for subenv, done in zip(env.subenvs.values(), dones): + assert subenv.done == done + + # Check lattice parameters + if env.subenvs[Stage.SPACE_GROUP].lattice_system != "None": + assert has_lattice_parameters + assert ( + env.subenvs[Stage.SPACE_GROUP].lattice_system + == env.subenvs[Stage.LATTICE_PARAMETERS].lattice_system + ) + else: + assert not has_lattice_parameters + + # Check composition constraints + if has_composition_constraints: + n_atoms = [n for n in env.subenvs[Stage.COMPOSITION].state if n > 0] + n_atoms_compatibility_dict = env.subenvs[ + Stage.SPACE_GROUP + ].build_n_atoms_compatibility_dict( + n_atoms, + env.subenvs[Stage.SPACE_GROUP].space_groups.keys(), + ) + assert ( + n_atoms_compatibility_dict + == env.subenvs[Stage.SPACE_GROUP].n_atoms_compatibility_dict + ) + + +@pytest.mark.parametrize( + "state", + [ + [0, 0, 4, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [0, 3, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [0, 3, 0, 0, 6, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [0, 3, 1, 0, 6, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [1, 3, 1, 0, 6, 1, 0, 0, -1, -1, -1, -1, -1, -1], + [1, 3, 1, 0, 6, 1, 1, 0, -1, -1, -1, -1, -1, -1], + [1, 3, 1, 0, 6, 1, 2, 0, -1, -1, -1, -1, -1, -1], + [1, 3, 1, 0, 6, 1, 2, 2, -1, -1, -1, -1, -1, -1], + [2, 1, 0, 4, 0, 4, 3, 105, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + [2, 1, 0, 4, 0, 4, 3, 105, 0.12, 0.23, 0.34, 0.45, 0.56, 0.67], + [2, 1, 0, 4, 0, 4, 3, 105, 0.76, 0.75, 0.74, 0.73, 0.72, 0.71], + ], +) +def test__get_mask_invalid_actions_backward__returns_expected_general_case(env, state): + stage = env._get_stage(state) + mask = env.get_mask_invalid_actions_backward(state, done=False) + for stg, subenv in env.subenvs.items(): + if stg == stage: + # Mask of state if stage is current stage in state + 
mask_subenv_expected = subenv.get_mask_invalid_actions_backward( + env._get_state_of_subenv(state, stg) + ) + else: + # Mask of source if stage is other than current stage in state + mask_subenv_expected = subenv.get_mask_invalid_actions_backward( + subenv.source + ) + mask_subenv = env._get_mask_of_subenv(mask, stg) + assert mask_subenv == mask_subenv_expected + + +@pytest.mark.parametrize( + "state", + [ + [0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [1, 3, 0, 0, 6, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [1, 3, 1, 0, 6, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [2, 3, 1, 0, 6, 1, 2, 2, -1, -1, -1, -1, -1, -1], + [2, 3, 1, 0, 6, 2, 1, 3, -1, -1, -1, -1, -1, -1], + ], +) +def test__get_mask_invald_actions_backward__returns_expected_stage_transition( + env, state +): + stage = env._get_stage(state) + mask = env.get_mask_invalid_actions_backward(state, done=False) + for stg, subenv in env.subenvs.items(): + if stg == Stage.prev(stage) and stage != Stage(0): + # Mask of done (EOS only) if stage is previous stage in state + mask_subenv_expected = subenv.get_mask_invalid_actions_backward( + env._get_state_of_subenv(state, stg), done=True + ) + else: + mask_subenv_expected = subenv.get_mask_invalid_actions_backward( + subenv.source + ) + if stg == stage: + assert env._get_state_of_subenv(state, stg) == subenv.source + mask_subenv = env._get_mask_of_subenv(mask, stg) + assert mask_subenv == mask_subenv_expected + + +@pytest.mark.parametrize( + "action", [(1, 1, -2, -2, -2, -2, -2), (3, 4, -2, -2, -2, -2, -2)] +) +def test__step__single_action_works(env, action): + env.step(action) + + assert env.state != env.source + + +@pytest.mark.parametrize( + "actions, exp_result, exp_stage, last_action_valid", + [ + [ + [(1, 1, -2, -2, -2, -2, -2), (3, 4, -2, -2, -2, -2, -2)], + [0, 1, 0, 4, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + Stage.COMPOSITION, + True, + ], + [ + [(2, 105, 3, -3, -3, -3, -3)], + [0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + Stage.COMPOSITION, + False, + ], + [ + [ + (1, 1, -2, -2, -2, -2, -2), + (3, 4, -2, -2, -2, -2, -2), + (-1, -1, -2, -2, -2, -2, -2), + ], + [1, 1, 0, 4, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + Stage.SPACE_GROUP, + True, + ], + [ + [ + (1, 1, -2, -2, -2, -2, -2), + (3, 4, -2, -2, -2, -2, -2), + (-1, -1, -2, -2, -2, -2, -2), + (2, 105, 0, -3, -3, -3, -3), + ], + [1, 1, 0, 4, 0, 4, 3, 105, -1, -1, -1, -1, -1, -1], + Stage.SPACE_GROUP, + True, + ], + [ + [ + (1, 1, -2, -2, -2, -2, -2), + (3, 4, -2, -2, -2, -2, -2), + (-1, -1, -2, -2, -2, -2, -2), + (2, 105, 0, -3, -3, -3, -3), + (2, 105, 0, -3, -3, -3, -3), + ], + [1, 1, 0, 4, 0, 4, 3, 105, -1, -1, -1, -1, -1, -1], + Stage.SPACE_GROUP, + False, + ], + [ + [ + (1, 1, -2, -2, -2, -2, -2), + (3, 4, -2, -2, -2, -2, -2), + (-1, -1, -2, -2, -2, -2, -2), + (2, 105, 0, -3, -3, -3, -3), + (-1, -1, -1, -3, -3, -3, -3), + ], + [2, 1, 0, 4, 0, 4, 3, 105, -1, -1, -1, -1, -1, -1], + Stage.LATTICE_PARAMETERS, + True, + ], + [ + [ + (1, 1, -2, -2, -2, -2, -2), + (3, 4, -2, -2, -2, -2, -2), + (-1, -1, -2, -2, -2, -2, -2), + (2, 105, 0, -3, -3, -3, -3), + (-1, -1, -1, -3, -3, -3, -3), + (1.5, 0, 0, 0, 0, 0, 0), + ], + [2, 1, 0, 4, 0, 4, 3, 105, -1, -1, -1, -1, -1, -1], + Stage.LATTICE_PARAMETERS, + False, + ], + [ + [ + (1, 1, -2, -2, -2, -2, -2), + (3, 4, -2, -2, -2, -2, -2), + (-1, -1, -2, -2, -2, -2, -2), + (2, 105, 0, -3, -3, -3, -3), + (-1, -1, -1, -3, -3, -3, -3), + (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 1), + ], + [2, 1, 0, 4, 0, 4, 3, 105, 0.1, 0.1, 0.3, 0.4, 0.4, 0.4], + Stage.LATTICE_PARAMETERS, + True, + ], + [ + [ + (1, 1, -2, -2, 
-2, -2, -2), + (3, 4, -2, -2, -2, -2, -2), + (-1, -1, -2, -2, -2, -2, -2), + (2, 105, 0, -3, -3, -3, -3), + (-1, -1, -1, -3, -3, -3, -3), + (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 1), + (0.6, 0.5, 0.8, 0.3, 0.2, 0.6, 0), + ], + [2, 1, 0, 4, 0, 4, 3, 105, 0.1, 0.1, 0.3, 0.4, 0.4, 0.4], + Stage.LATTICE_PARAMETERS, + False, + ], + [ + [ + (1, 1, -2, -2, -2, -2, -2), + (3, 4, -2, -2, -2, -2, -2), + (-1, -1, -2, -2, -2, -2, -2), + (2, 105, 0, -3, -3, -3, -3), + (-1, -1, -1, -3, -3, -3, -3), + (0.1, 0.1, 0.3, 0.0, 0.0, 0.0, 1), + (0.66, 0.0, 0.44, 0.0, 0.0, 0.0, 0), + ], + [2, 1, 0, 4, 0, 4, 3, 105, 0.76, 0.76, 0.74, 0.4, 0.4, 0.4], + Stage.LATTICE_PARAMETERS, + True, + ], + [ + [ + (1, 1, -2, -2, -2, -2, -2), + (3, 4, -2, -2, -2, -2, -2), + (-1, -1, -2, -2, -2, -2, -2), + (2, 105, 0, -3, -3, -3, -3), + (-1, -1, -1, -3, -3, -3, -3), + (0.1, 0.1, 0.3, 0.0, 0.0, 0.0, 1), + (0.66, 0.66, 0.44, 0.0, 0.0, 0.0, 0), + (np.inf, np.inf, np.inf, np.inf, np.inf, np.inf, np.inf), + ], + [2, 1, 0, 4, 0, 4, 3, 105, 0.76, 0.76, 0.74, 0.4, 0.4, 0.4], + Stage.LATTICE_PARAMETERS, + True, + ], + ], +) +def test__step__action_sequence_has_expected_result( + env, actions, exp_result, exp_stage, last_action_valid +): + for action in actions: + warnings.filterwarnings("ignore") + _, _, valid = env.step(action) + + assert env.state == exp_result + assert env._get_stage() == exp_stage + assert valid == last_action_valid + + +@pytest.mark.parametrize( + "state_init, state_end, stage_init, stage_end, actions, last_action_valid", + [ + [ + [0, 1, 0, 4, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + Stage.COMPOSITION, + Stage.COMPOSITION, + [(3, 4, -2, -2, -2, -2, -2), (1, 1, -2, -2, -2, -2, -2)], + True, + ], + [ + [0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + Stage.COMPOSITION, + Stage.COMPOSITION, + [(2, 105, 3, -3, -3, -3, -3)], + False, + ], + [ + [1, 1, 0, 4, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + Stage.SPACE_GROUP, + Stage.COMPOSITION, + [ + (-1, -1, -2, -2, -2, -2, -2), + (3, 4, -2, -2, -2, -2, -2), + (1, 1, -2, -2, -2, -2, -2), + ], + True, + ], + [ + [1, 1, 0, 4, 0, 4, 3, 105, -1, -1, -1, -1, -1, -1], + [0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + Stage.SPACE_GROUP, + Stage.COMPOSITION, + [ + (2, 105, 0, -3, -3, -3, -3), + (-1, -1, -2, -2, -2, -2, -2), + (3, 4, -2, -2, -2, -2, -2), + (1, 1, -2, -2, -2, -2, -2), + ], + True, + ], + [ + [2, 1, 0, 4, 0, 4, 3, 105, -1, -1, -1, -1, -1, -1], + [0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + Stage.LATTICE_PARAMETERS, + Stage.COMPOSITION, + [ + (-1, -1, -1, -3, -3, -3, -3), + (2, 105, 0, -3, -3, -3, -3), + (-1, -1, -2, -2, -2, -2, -2), + (3, 4, -2, -2, -2, -2, -2), + (1, 1, -2, -2, -2, -2, -2), + ], + True, + ], + [ + [2, 1, 0, 4, 0, 4, 3, 105, 0.1, 0.1, 0.3, 0.4, 0.4, 0.4], + [2, 1, 0, 4, 0, 4, 3, 105, 0.1, 0.1, 0.3, 0.4, 0.4, 0.4], + Stage.LATTICE_PARAMETERS, + Stage.LATTICE_PARAMETERS, + [ + (1.5, 0, 0, 0, 0, 0, 0), + ], + False, + ], + [ + [2, 1, 0, 4, 0, 4, 3, 105, 0.1, 0.1, 0.3, 0.4, 0.4, 0.4], + [0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + Stage.LATTICE_PARAMETERS, + Stage.COMPOSITION, + [ + (0.1, 0.1, 0.3, 0.0, 0.0, 0.0, 1), + (-1, -1, -1, -3, -3, -3, -3), + (2, 105, 0, -3, -3, -3, -3), + (-1, -1, -2, -2, -2, -2, -2), + (3, 4, -2, -2, -2, -2, -2), + (1, 1, -2, -2, -2, -2, -2), + ], + True, + ], + [ + [2, 1, 0, 4, 0, 4, 3, 105, 0.76, 0.76, 0.74, 0.4, 0.4, 0.4], + [0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, 
-1, -1], + Stage.LATTICE_PARAMETERS, + Stage.COMPOSITION, + [ + (0.66, 0.0, 0.44, 0.0, 0.0, 0.0, 0), + (0.1, 0.1, 0.3, 0.0, 0.0, 0.0, 1), + (-1, -1, -1, -3, -3, -3, -3), + (2, 105, 0, -3, -3, -3, -3), + (-1, -1, -2, -2, -2, -2, -2), + (3, 4, -2, -2, -2, -2, -2), + (1, 1, -2, -2, -2, -2, -2), + ], + True, + ], + [ + [2, 1, 0, 4, 0, 4, 3, 105, 0.76, 0.76, 0.74, 0.4, 0.4, 0.4], + [0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + Stage.LATTICE_PARAMETERS, + Stage.COMPOSITION, + [ + (np.inf, np.inf, np.inf, np.inf, np.inf, np.inf, np.inf), + (0.66, 0.0, 0.44, 0.0, 0.0, 0.0, 0), + (0.1, 0.1, 0.3, 0.0, 0.0, 0.0, 1), + (-1, -1, -1, -3, -3, -3, -3), + (2, 105, 0, -3, -3, -3, -3), + (-1, -1, -2, -2, -2, -2, -2), + (3, 4, -2, -2, -2, -2, -2), + (1, 1, -2, -2, -2, -2, -2), + ], + True, + ], + ], +) +def test__step_backwards__action_sequence_has_expected_result( + env, state_init, state_end, stage_init, stage_end, actions, last_action_valid +): + # Hacky way to also test if first action global EOS + if actions[0] == env.eos: + env.set_state(state_init, done=True) + else: + env.set_state(state_init, done=False) + assert env.state == state_init + assert env._get_stage() == stage_init + for action in actions: + warnings.filterwarnings("ignore") + _, _, valid = env.step_backwards(action) + + assert env.state == state_end + assert env._get_stage() == stage_end + assert valid == last_action_valid + + +@pytest.mark.parametrize( + "actions", + [ + [ + (1, 1, -2, -2, -2, -2, -2), + (3, 4, -2, -2, -2, -2, -2), + (-1, -1, -2, -2, -2, -2, -2), + (2, 105, 0, -3, -3, -3, -3), + (-1, -1, -1, -3, -3, -3, -3), + (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 1), + (0.66, 0.55, 0.44, 0.33, 0.22, 0.11, 0), + (np.inf, np.inf, np.inf, np.inf, np.inf, np.inf, np.inf), + ] + ], +) +def test__reset(env, actions): + for action in actions: + env.step(action) + + assert env.state != env.source + for subenv in [ + env.subenvs[Stage.COMPOSITION], + env.subenvs[Stage.SPACE_GROUP], + env.subenvs[Stage.LATTICE_PARAMETERS], + ]: + assert subenv.state != subenv.source + assert env.subenvs[Stage.LATTICE_PARAMETERS].lattice_system != TRICLINIC + + env.reset() + + assert env.state == env.source + for subenv in [ + env.subenvs[Stage.COMPOSITION], + env.subenvs[Stage.SPACE_GROUP], + env.subenvs[Stage.LATTICE_PARAMETERS], + ]: + assert subenv.state == subenv.source + assert env.subenvs[Stage.LATTICE_PARAMETERS].lattice_system == TRICLINIC + + +# TODO: write new test of masks, both fw and bw +@pytest.mark.skip(reason="skip while developping other tests") +@pytest.mark.parametrize( + "actions, exp_stage", + [ + [ + [], + Stage.COMPOSITION, + ], + [ + [ + (1, 1, -2, -2, -2, -2, -2), + (3, 4, -2, -2, -2, -2, -2), + (-1, -1, -2, -2, -2, -2, -2), + ], + Stage.SPACE_GROUP, + ], + [ + [ + (1, 1, -2, -2, -2, -2, -2), + (3, 4, -2, -2, -2, -2, -2), + (-1, -1, -2, -2, -2, -2, -2), + (2, 105, 0, -3, -3, -3, -3), + (-1, -1, -1, -3, -3, -3, -3), + ], + Stage.LATTICE_PARAMETERS, + ], + ], +) +def test__get_mask_invalid_actions_forward__masks_all_actions_from_different_stages( + env, actions, exp_stage +): + for action in actions: + env.step(action) + + assert env._get_stage() == exp_stage + + mask = env.get_mask_invalid_actions_forward() + + if env._get_stage() == Stage.COMPOSITION: + assert not all(mask[: len(env.subenvs[Stage.COMPOSITION].action_space)]) + assert all(mask[len(env.subenvs[Stage.COMPOSITION].action_space) :]) + if env._get_stage() == Stage.SPACE_GROUP: + assert not all( + mask[ + len(env.subenvs[Stage.COMPOSITION].action_space) : len( + 
env.subenvs[Stage.COMPOSITION].action_space + ) + + len(env.subenvs[Stage.SPACE_GROUP].action_space) + ] + ) + assert all(mask[: len(env.subenvs[Stage.COMPOSITION].action_space)]) + assert all( + mask[ + len(env.subenvs[Stage.COMPOSITION].action_space) + + len(env.subenvs[Stage.SPACE_GROUP].action_space) : + ] + ) + if env._get_stage() == Stage.LATTICE_PARAMETERS: + assert not all( + mask[ + len(env.subenvs[Stage.COMPOSITION].action_space) + + len(env.subenvs[Stage.SPACE_GROUP].action_space) : + ] + ) + assert all( + mask[ + : len(env.subenvs[Stage.COMPOSITION].action_space) + + len(env.subenvs[Stage.SPACE_GROUP].action_space) + ] + ) + + +def test__get_policy_outputs__is_the_concatenation_of_subenvs(env): + policy_output_composition = env.subenvs[Stage.COMPOSITION].get_policy_output( + env.subenvs[Stage.COMPOSITION].fixed_distr_params + ) + policy_output_space_group = env.subenvs[Stage.SPACE_GROUP].get_policy_output( + env.subenvs[Stage.SPACE_GROUP].fixed_distr_params + ) + policy_output_lattice_parameters = env.subenvs[ + Stage.LATTICE_PARAMETERS + ].get_policy_output(env.subenvs[Stage.LATTICE_PARAMETERS].fixed_distr_params) + policy_output_cat = torch.cat( + ( + policy_output_composition, + policy_output_space_group, + policy_output_lattice_parameters, + ) + ) + policy_output = env.get_policy_output(env.fixed_distr_params) + assert torch.all(torch.eq(policy_output_cat, policy_output)) + + +def test___get_policy_outputs_of_subenv__returns_correct_output(env): + n_states = 5 + policy_output_composition = torch.tile( + env.subenvs[Stage.COMPOSITION].get_policy_output( + env.subenvs[Stage.COMPOSITION].fixed_distr_params + ), + dims=(n_states, 1), + ) + policy_output_space_group = torch.tile( + env.subenvs[Stage.SPACE_GROUP].get_policy_output( + env.subenvs[Stage.SPACE_GROUP].fixed_distr_params + ), + dims=(n_states, 1), + ) + policy_output_lattice_parameters = torch.tile( + env.subenvs[Stage.LATTICE_PARAMETERS].get_policy_output( + env.subenvs[Stage.LATTICE_PARAMETERS].fixed_distr_params + ), + dims=(n_states, 1), + ) + policy_outputs = torch.tile( + env.get_policy_output(env.fixed_distr_params), dims=(n_states, 1) + ) + assert torch.all( + torch.eq( + env._get_policy_outputs_of_subenv(policy_outputs, Stage.COMPOSITION), + policy_output_composition, + ) + ) + assert torch.all( + torch.eq( + env._get_policy_outputs_of_subenv(policy_outputs, Stage.SPACE_GROUP), + policy_output_space_group, + ) + ) + assert torch.all( + torch.eq( + env._get_policy_outputs_of_subenv(policy_outputs, Stage.LATTICE_PARAMETERS), + policy_output_lattice_parameters, + ) + ) + + +@pytest.mark.parametrize( + "states", + [ + [ + [0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [0, 0, 4, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [1, 3, 1, 0, 6, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [2, 1, 0, 4, 0, 4, 3, 105, 0.12, 0.23, 0.34, 0.45, 0.56, 0.67], + [1, 3, 1, 0, 6, 1, 0, 0, -1, -1, -1, -1, -1, -1], + [1, 3, 1, 0, 6, 1, 1, 0, -1, -1, -1, -1, -1, -1], + [0, 3, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [0, 3, 0, 0, 6, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [1, 3, 1, 0, 6, 1, 2, 0, -1, -1, -1, -1, -1, -1], + [0, 3, 1, 0, 6, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [2, 1, 0, 4, 0, 4, 3, 105, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + [2, 1, 0, 4, 0, 4, 3, 105, 0.76, 0.75, 0.74, 0.73, 0.72, 0.71], + [0, 0, 4, 3, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + ], + ], +) +def test__get_mask_of_subenv__returns_correct_submasks(env, states): + # Get states from each stage and masks computed with the Crystal env. 
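+    # _get_mask_of_subenv is then expected to extract, from each global mask, the
+    # same sub-mask that the corresponding subenv computes on its own state.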
+ states_dict = {stage: [] for stage in Stage} + masks_dict = {stage: [] for stage in Stage} + stages = [] + for s in states: + stage = env._get_stage(s) + states_dict[stage].append(s) + masks_dict[stage].append(env.get_mask_invalid_actions_forward(s)) + stages.append(stage) + + for stage, subenv in env.subenvs.items(): + # Get masks computed with subenv + masks_subenv = tbool( + [ + subenv.get_mask_invalid_actions_forward( + env._get_state_of_subenv(s, stage) + ) + for s in states_dict[stage] + ], + device=env.device, + ) + assert torch.all( + torch.eq( + env._get_mask_of_subenv( + tbool(masks_dict[stage], device=env.device), stage + ), + masks_subenv, + ) + ) + + +@pytest.mark.repeat(10) +def test__step_random__does_not_crash_from_source(env): + """ + Very low bar test... + """ + env.reset() + env.step_random() + pass + + +@pytest.mark.parametrize( + "states", + [ + [ + [0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [0, 0, 4, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [0, 0, 4, 3, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [0, 3, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [0, 3, 0, 0, 6, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [0, 3, 1, 0, 6, 0, 0, 0, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [1, 3, 1, 0, 6, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [1, 3, 1, 0, 6, 1, 0, 0, -1, -1, -1, -1, -1, -1], + [1, 3, 1, 0, 6, 1, 1, 0, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [0, 0, 4, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [1, 3, 1, 0, 6, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [2, 1, 0, 4, 0, 4, 3, 105, 0.1, 0.1, 0.3, 0.4, 0.4, 0.4], + [1, 3, 1, 0, 6, 1, 0, 0, -1, -1, -1, -1, -1, -1], + [1, 3, 1, 0, 6, 1, 1, 0, -1, -1, -1, -1, -1, -1], + [0, 3, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [0, 3, 0, 0, 6, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [1, 3, 1, 0, 6, 1, 2, 0, -1, -1, -1, -1, -1, -1], + [0, 3, 1, 0, 6, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [2, 1, 0, 4, 0, 4, 3, 105, 0.5, 0.5, 0.3, 0.4, 0.4, 0.4], + [2, 1, 0, 4, 0, 4, 3, 105, 0.45, 0.45, 0.33, 0.4, 0.4, 0.4], + [0, 0, 4, 3, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + ], + ], +) +def test__sample_actions_forward__returns_valid_actions(env, states): + """ + Still low bar, but getting better... 
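+    Checks that actions sampled with a random policy are valid according to the
+    forward mask of each state; states at the continuous lattice parameters stage
+    are excluded from the membership check.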
+ """ + n_states = len(states) + # Get masks + masks = tbool( + [env.get_mask_invalid_actions_forward(s) for s in states], device=env.device + ) + # Build policy outputs + params = env.random_distr_params + policy_outputs = torch.tile(env.get_policy_output(params), dims=(n_states, 1)) + # Sample actions + actions, _ = env.sample_actions_batch( + policy_outputs, masks, states, is_backward=False + ) + # Sample actions are valid + for state, action in zip(states, actions): + if env._get_stage(state) == Stage.LATTICE_PARAMETERS: + continue + assert action in env.get_valid_actions(state, done=False, backward=False) + + +@pytest.mark.parametrize( + "states", + [ + [ + [0, 0, 4, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [0, 0, 4, 3, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [0, 3, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [0, 3, 0, 0, 6, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [0, 3, 1, 0, 6, 0, 0, 0, -1, -1, -1, -1, -1, -1], + ], + [ + [1, 3, 1, 0, 6, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [1, 3, 1, 0, 6, 1, 0, 0, -1, -1, -1, -1, -1, -1], + [1, 3, 1, 0, 6, 1, 1, 0, -1, -1, -1, -1, -1, -1], + ], + [ + [0, 0, 4, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [1, 3, 1, 0, 6, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [2, 1, 0, 4, 0, 4, 3, 105, 0.1, 0.1, 0.3, 0.4, 0.4, 0.4], + [1, 3, 1, 0, 6, 1, 0, 0, -1, -1, -1, -1, -1, -1], + [1, 3, 1, 0, 6, 1, 1, 0, -1, -1, -1, -1, -1, -1], + [0, 3, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [0, 3, 0, 0, 6, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [1, 3, 1, 0, 6, 1, 2, 0, -1, -1, -1, -1, -1, -1], + [0, 3, 1, 0, 6, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [2, 1, 0, 4, 0, 4, 3, 105, 0.5, 0.5, 0.3, 0.4, 0.4, 0.4], + [2, 1, 0, 4, 0, 4, 3, 105, 0.45, 0.45, 0.33, 0.4, 0.4, 0.4], + [0, 0, 4, 3, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + ], + ], +) +def test__sample_actions_backward__returns_valid_actions(env, states): + """ + Just a little higher... + """ + n_states = len(states) + # Get masks + masks = tbool( + [env.get_mask_invalid_actions_backward(s) for s in states], device=env.device + ) + # Build policy outputs + params = env.random_distr_params + policy_outputs = torch.tile(env.get_policy_output(params), dims=(n_states, 1)) + # Sample actions + actions, _ = env.sample_actions_batch( + policy_outputs, masks, states, is_backward=True + ) + # Sample actions are valid + for state, action in zip(states, actions): + if env._get_stage(state) == Stage.LATTICE_PARAMETERS: + continue + assert action in env.get_valid_actions(state, done=False, backward=True) + + +@pytest.mark.repeat(100) +def test__trajectory_random__does_not_crash_from_source(env): + """ + Raising the bar... 
+ """ + env.reset() + env.trajectory_random() + pass + + +@pytest.mark.parametrize( + "states, actions", + [ + [ + [ + [0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [0, 0, 4, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [0, 0, 4, 3, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [0, 3, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [0, 3, 0, 0, 6, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [0, 3, 1, 0, 6, 0, 0, 0, -1, -1, -1, -1, -1, -1], + ], + [ + (1, 7, -2, -2, -2, -2, -2), + (3, 16, -2, -2, -2, -2, -2), + (1, 6, -2, -2, -2, -2, -2), + (3, 8, -2, -2, -2, -2, -2), + (2, 11, -2, -2, -2, -2, -2), + (3, 9, -2, -2, -2, -2, -2), + ], + ], + [ + [ + [0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [1, 3, 1, 0, 6, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [1, 3, 1, 0, 6, 1, 0, 0, -1, -1, -1, -1, -1, -1], + [1, 3, 1, 0, 6, 1, 1, 0, -1, -1, -1, -1, -1, -1], + ], + [ + (1, 6, -2, -2, -2, -2, -2), + (2, 14, 0, -3, -3, -3, -3), + (2, 2, 1, -3, -3, -3, -3), + (2, 1, 3, -3, -3, -3, -3), + ], + ], + [ + [ + [0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [0, 0, 4, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [1, 3, 1, 0, 6, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [2, 1, 0, 4, 0, 4, 3, 105, 0.1, 0.1, 0.3, 0.4, 0.4, 0.4], + [1, 3, 1, 0, 6, 1, 0, 0, -1, -1, -1, -1, -1, -1], + [1, 3, 1, 0, 6, 1, 1, 0, -1, -1, -1, -1, -1, -1], + [0, 3, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [0, 3, 0, 0, 6, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [1, 3, 1, 0, 6, 1, 2, 0, -1, -1, -1, -1, -1, -1], + [0, 3, 1, 0, 6, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [2, 1, 0, 4, 0, 4, 3, 105, 0.5, 0.5, 0.3, 0.4, 0.4, 0.4], + [2, 1, 0, 4, 0, 4, 3, 105, 0.45, 0.45, 0.33, 0.4, 0.4, 0.4], + [0, 0, 4, 3, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + ], + [ + (1, 15, -2, -2, -2, -2, -2), + (1, 2, -2, -2, -2, -2, -2), + (2, 7, 0, -3, -3, -3, -3), + (0.49, 0.40, 0.40, 0.37, 0.35, 0.36, 0.0), + (2, 1, 1, -3, -3, -3, -3), + (2, 1, 3, -3, -3, -3, -3), + (2, 11, -2, -2, -2, -2, -2), + (3, 9, -2, -2, -2, -2, -2), + (2, 2, 3, -3, -3, -3, -3), + (3, 2, -2, -2, -2, -2, -2), + (0.27, 0.28, 0.30, 0.39, 0.37, 0.29, 0.0), + (0.32, 0.30, 0.45, 0.33, 0.42, 0.39, 0.0), + (4, 4, -2, -2, -2, -2, -2), + ], + ], + ], +) +def test__get_logprobs_forward__returns_valid_actions(env, states, actions): + """ + This would already be not too bad! 
+ """ + n_states = len(states) + actions = tfloat(actions, float_type=env.float, device=env.device) + # Get masks + masks = tbool( + [env.get_mask_invalid_actions_forward(s) for s in states], device=env.device + ) + # Build policy outputs + params = env.random_distr_params + policy_outputs = torch.tile(env.get_policy_output(params), dims=(n_states, 1)) + # Get log probs + logprobs = env.get_logprobs( + policy_outputs, actions, masks, states, is_backward=False + ) + assert torch.all(torch.isfinite(logprobs)) + + +# TODO: Set lattice system +@pytest.mark.parametrize( + "states, actions", + [ + [ + [ + [0, 0, 4, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [0, 0, 4, 3, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [0, 3, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [0, 3, 0, 0, 6, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [0, 3, 1, 0, 6, 0, 0, 0, -1, -1, -1, -1, -1, -1], + ], + [ + (2, 4, -2, -2, -2, -2, -2), + (2, 4, -2, -2, -2, -2, -2), + (1, 3, -2, -2, -2, -2, -2), + (1, 3, -2, -2, -2, -2, -2), + (4, 6, -2, -2, -2, -2, -2), + ], + ], + [ + [ + [1, 3, 1, 0, 6, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [1, 3, 1, 0, 6, 1, 0, 0, -1, -1, -1, -1, -1, -1], + [1, 3, 1, 0, 6, 1, 1, 0, -1, -1, -1, -1, -1, -1], + ], + [ + (-1, -1, -2, -2, -2, -2, -2), + (0, 1, 0, -3, -3, -3, -3), + (1, 1, 1, -3, -3, -3, -3), + ], + ], + [ + [ + [0, 0, 4, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [1, 3, 1, 0, 6, 0, 0, 0, -1, -1, -1, -1, -1, -1], + # [2, 1, 0, 4, 0, 4, 3, 105, 0.1, 0.1, 0.3, 0.4, 0.4, 0.4], + [1, 3, 1, 0, 6, 1, 0, 0, -1, -1, -1, -1, -1, -1], + [1, 3, 1, 0, 6, 1, 1, 0, -1, -1, -1, -1, -1, -1], + [0, 3, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [0, 3, 0, 0, 6, 0, 0, 0, -1, -1, -1, -1, -1, -1], + [1, 3, 1, 0, 6, 1, 2, 0, -1, -1, -1, -1, -1, -1], + [0, 3, 1, 0, 6, 0, 0, 0, -1, -1, -1, -1, -1, -1], + # [2, 1, 0, 4, 0, 4, 3, 105, 0.5, 0.5, 0.3, 0.4, 0.4, 0.4], + # [2, 1, 0, 4, 0, 4, 3, 105, 0.45, 0.45, 0.33, 0.4, 0.4, 0.4], + [0, 0, 4, 3, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1], + ], + [ + (2, 4, -2, -2, -2, -2, -2), + (-1, -1, -2, -2, -2, -2, -2), + # (0.10, 0.10, 0.17, 0.0, 0.0, 0.0, 0.0), + (0, 1, 0, -3, -3, -3, -3), + (1, 1, 1, -3, -3, -3, -3), + (1, 3, -2, -2, -2, -2, -2), + (1, 3, -2, -2, -2, -2, -2), + (1, 2, 1, -3, -3, -3, -3), + (2, 1, -2, -2, -2, -2, -2), + # (0.37, 0.37, 0.23, 0.0, 0.0, 0.0, 0.0), + # (0.23, 0.23, 0.11, 0.0, 0.0, 0.0, 0.0), + (3, 3, -2, -2, -2, -2, -2), + ], + ], + ], +) +def test__get_logprobs_backward__returns_valid_actions(env, states, actions): + """ + And backwards? 
+ """ + n_states = len(states) + actions = tfloat(actions, float_type=env.float, device=env.device) + # Get masks + masks = tbool( + [env.get_mask_invalid_actions_backward(s) for s in states], device=env.device + ) + # Build policy outputs + params = env.random_distr_params + policy_outputs = torch.tile(env.get_policy_output(params), dims=(n_states, 1)) + # Get log probs + logprobs = env.get_logprobs( + policy_outputs, actions, masks, states, is_backward=True + ) + assert torch.all(torch.isfinite(logprobs)) + + +def test__continuous_env_common(env): + print( + "\n\nCommon tests for crystal without composition <-> space group constraints\n" + ) + return common.test__continuous_env_common(env) + + +def test__continuous_env_with_stoichiometry_sg_check_common( + env_with_stoichiometry_sg_check, +): + print("\n\nCommon tests for crystal with composition <-> space group constraints\n") + return common.test__continuous_env_common(env_with_stoichiometry_sg_check) diff --git a/tests/gflownet/envs/test_ccube.py b/tests/gflownet/envs/test_ccube.py index feda16bf3..25181ed4e 100644 --- a/tests/gflownet/envs/test_ccube.py +++ b/tests/gflownet/envs/test_ccube.py @@ -91,8 +91,8 @@ def test__mask_backward__returns_all_true_except_eos_if_done(env, request): for state in states: env.set_state(state, done=True) mask = env.get_mask_invalid_actions_backward() - assert all(mask[:-1]) - assert mask[-1] is False + assert all(mask[:2]) + assert mask[2] is False @pytest.mark.parametrize( @@ -100,23 +100,23 @@ def test__mask_backward__returns_all_true_except_eos_if_done(env, request): [ ( [-1.0], - [False, False, True], + [False, False, True, False], ), ( [0.0], - [False, True, False], + [False, True, False, False], ), ( [0.5], - [False, True, False], + [False, True, False, False], ), ( [0.90], - [False, True, False], + [False, True, False, False], ), ( [0.95], - [True, True, False], + [True, True, False, False], ), ], ) @@ -131,43 +131,43 @@ def test__mask_forward__1d__returns_expected(cube1d, state, mask_expected): [ ( [-1.0, -1.0], - [False, False, True], + [False, False, True, False, False], ), ( [0.0, 0.0], - [False, True, False], + [False, True, False, False, False], ), ( [0.5, 0.0], - [False, True, False], + [False, True, False, False, False], ), ( [0.0, 0.01], - [False, True, False], + [False, True, False, False, False], ), ( [0.5, 0.5], - [False, True, False], + [False, True, False, False, False], ), ( [0.90, 0.5], - [False, True, False], + [False, True, False, False, False], ), ( [0.95, 0.5], - [True, True, False], + [True, True, False, False, False], ), ( [0.5, 0.90], - [False, True, False], + [False, True, False, False, False], ), ( [0.5, 0.95], - [True, True, False], + [True, True, False, False, False], ), ( [0.95, 0.95], - [True, True, False], + [True, True, False, False, False], ), ], ) @@ -182,31 +182,31 @@ def test__mask_forward__2d__returns_expected(cube2d, state, mask_expected): [ ( [-1.0], - [True, True, True], + [True, True, True, False], ), ( [0.0], - [True, False, True], + [True, False, True, False], ), ( [0.05], - [True, False, True], + [True, False, True, False], ), ( [0.1], - [False, True, True], + [False, True, True, False], ), ( [0.5], - [False, True, True], + [False, True, True, False], ), ( [0.90], - [False, True, True], + [False, True, True, False], ), ( [0.95], - [False, True, True], + [False, True, True, False], ), ], ) @@ -221,47 +221,47 @@ def test__mask_backward__1d__returns_expected(cube1d, state, mask_expected): [ ( [-1.0, -1.0], - [True, True, True], + [True, True, True, False, 
False], ), ( [0.0, 0.0], - [True, False, True], + [True, False, True, False, False], ), ( [0.5, 0.5], - [False, True, True], + [False, True, True, False, False], ), ( [0.05, 0.5], - [True, False, True], + [True, False, True, False, False], ), ( [0.5, 0.05], - [True, False, True], + [True, False, True, False, False], ), ( [0.05, 0.05], - [True, False, True], + [True, False, True, False, False], ), ( [0.90, 0.5], - [False, True, True], + [False, True, True, False, False], ), ( [0.5, 0.90], - [False, True, True], + [False, True, True, False, False], ), ( [0.95, 0.5], - [False, True, True], + [False, True, True, False, False], ), ( [0.5, 0.95], - [False, True, True], + [False, True, True, False, False], ), ( [0.95, 0.95], - [False, True, True], + [False, True, True, False, False], ), ], ) diff --git a/tests/gflownet/envs/test_clattice_parameters.py b/tests/gflownet/envs/test_clattice_parameters.py new file mode 100644 index 000000000..6ecddb050 --- /dev/null +++ b/tests/gflownet/envs/test_clattice_parameters.py @@ -0,0 +1,304 @@ +import common +import pytest +import torch + +from gflownet.envs.crystals.clattice_parameters import ( + CUBIC, + HEXAGONAL, + MONOCLINIC, + ORTHORHOMBIC, + PARAMETER_NAMES, + RHOMBOHEDRAL, + TETRAGONAL, + TRICLINIC, + CLatticeParameters, +) +from gflownet.envs.crystals.lattice_parameters import LATTICE_SYSTEMS +from gflownet.utils.common import tfloat + +N_REPETITIONS = 100 + + +@pytest.fixture() +def env(lattice_system): + return CLatticeParameters( + lattice_system=lattice_system, + min_length=1.0, + max_length=5.0, + min_angle=30.0, + max_angle=150.0, + ) + + +@pytest.mark.parametrize("lattice_system", LATTICE_SYSTEMS) +def test__environment__initializes_properly(env, lattice_system): + pass + + +@pytest.mark.parametrize( + "lattice_system, expected_params", + [ + (CUBIC, [None, None, None, 90, 90, 90]), + (HEXAGONAL, [None, None, None, 90, 90, 120]), + (MONOCLINIC, [None, None, None, 90, None, 90]), + (ORTHORHOMBIC, [None, None, None, 90, 90, 90]), + (RHOMBOHEDRAL, [None, None, None, None, None, None]), + (TETRAGONAL, [None, None, None, 90, 90, 90]), + (TRICLINIC, [None, None, None, None, None, None]), + ], +) +def test__environment__has_expected_fixed_parameters( + env, lattice_system, expected_params +): + for expected_value, param_name in zip(expected_params, PARAMETER_NAMES): + if expected_value is not None: + assert getattr(env, param_name) == expected_value + + +@pytest.mark.parametrize( + "lattice_system", + [CUBIC], +) +@pytest.mark.repeat(N_REPETITIONS) +def test__cubic__constraints_remain_after_random_actions(env, lattice_system): + env = env.reset() + while not env.done: + (a, b, c), (alpha, beta, gamma) = env._unpack_lengths_angles() + assert len({a, b, c}) == 1 + assert len({alpha, beta, gamma, 90.0}) == 1 + env.step_random() + + +@pytest.mark.parametrize( + "lattice_system", + [HEXAGONAL], +) +@pytest.mark.repeat(N_REPETITIONS) +def test__hexagonal__constraints_remain_after_random_actions(env, lattice_system): + env = env.reset() + while not env.done: + env.step_random() + (a, b, c), (alpha, beta, gamma) = env._unpack_lengths_angles() + assert a == b + assert len({a, b, c}) == 2 + assert len({alpha, beta, 90.0}) == 1 + assert gamma == 120.0 + + +@pytest.mark.parametrize( + "lattice_system", + [MONOCLINIC], +) +@pytest.mark.repeat(N_REPETITIONS) +def test__monoclinic__constraints_remain_after_random_actions(env, lattice_system): + env = env.reset() + while not env.done: + env.step_random() + (a, b, c), (alpha, beta, gamma) = 
env._unpack_lengths_angles() + assert len({a, b, c}) == 3 + assert len({alpha, gamma, 90.0}) == 1 + assert beta != 90.0 + + +@pytest.mark.parametrize( + "lattice_system", + [ORTHORHOMBIC], +) +@pytest.mark.repeat(N_REPETITIONS) +def test__orthorhombic__constraints_remain_after_random_actions(env, lattice_system): + env = env.reset() + while not env.done: + env.step_random() + (a, b, c), (alpha, beta, gamma) = env._unpack_lengths_angles() + assert len({a, b, c}) == 3 + assert len({alpha, beta, gamma, 90.0}) == 1 + + +@pytest.mark.parametrize( + "lattice_system", + [RHOMBOHEDRAL], +) +@pytest.mark.repeat(N_REPETITIONS) +def test__rhombohedral__constraints_remain_after_random_actions(env, lattice_system): + env = env.reset() + while not env.done: + env.step_random() + (a, b, c), (alpha, beta, gamma) = env._unpack_lengths_angles() + assert len({a, b, c}) == 1 + assert len({alpha, beta, gamma}) == 1 + assert len({alpha, beta, gamma, 90.0}) == 2 + + +@pytest.mark.parametrize( + "lattice_system", + [TETRAGONAL], +) +@pytest.mark.repeat(N_REPETITIONS) +def test__tetragonal__constraints_remain_after_random_actions(env, lattice_system): + env = env.reset() + while not env.done: + env.step_random() + (a, b, c), (alpha, beta, gamma) = env._unpack_lengths_angles() + assert a == b + assert len({a, b, c}) == 2 + assert len({alpha, beta, gamma, 90.0}) == 1 + + +@pytest.mark.parametrize( + "lattice_system", + [TRICLINIC], +) +@pytest.mark.repeat(N_REPETITIONS) +def test__triclinic__constraints_remain_after_random_actions(env, lattice_system): + env = env.reset() + while not env.done: + env.step_random() + (a, b, c), (alpha, beta, gamma) = env._unpack_lengths_angles() + assert len({a, b, c}) == 3 + assert len({alpha, beta, gamma, 90.0}) == 4 + + +@pytest.mark.parametrize( + "lattice_system, states, states_proxy_expected", + [ + ( + TRICLINIC, + [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.2, 0.5, 0.0, 0.5, 1.0], + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + ], + [ + [1.0, 1.0, 1.0, 30.0, 30.0, 30.0], + [1.0, 1.8, 3.0, 30.0, 90.0, 150.0], + [5.0, 5.0, 5.0, 150.0, 150.0, 150.0], + ], + ), + ( + CUBIC, + [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.25, 0.5, 0.75, 0.25, 0.5, 0.75], + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + ], + [ + [1.0, 1.0, 1.0, 30.0, 30.0, 30.0], + [2.0, 3.0, 4.0, 60.0, 90.0, 120.0], + [5.0, 5.0, 5.0, 150.0, 150.0, 150.0], + ], + ), + ], +) +def test__statetorch2proxy__returns_expected( + env, lattice_system, states, states_proxy_expected +): + """ + Various lattice systems are tried because the conversion should be independent of + the lattice system, since the states are expected to satisfy the constraints. 
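+    With the fixture ranges (lengths in [1.0, 5.0], angles in [30.0, 150.0]), a state
+    value s in [0, 1] maps linearly onto the parameter range: s = 0.5 gives
+    1.0 + 0.5 * (5.0 - 1.0) = 3.0 for lengths and 30.0 + 0.5 * (150.0 - 30.0) = 90.0
+    for angles, as the expected proxy values below show.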
+ """ + # Get policy states from the batch of states converted into each subenv + # Get policy states from env.statetorch2policy + states_torch = tfloat(states, float_type=env.float, device=env.device) + states_proxy_expected_torch = tfloat( + states_proxy_expected, float_type=env.float, device=env.device + ) + states_proxy = env.statetorch2proxy(states_torch) + assert torch.all(torch.eq(states_proxy, states_proxy_expected_torch)) + states_proxy = env.statebatch2proxy(states_torch) + assert torch.all(torch.eq(states_proxy, states_proxy_expected_torch)) + + +@pytest.mark.parametrize( + "lattice_system, states, states_policy_expected", + [ + ( + TRICLINIC, + [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.2, 0.5, 0.0, 0.5, 1.0], + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.2, 0.5, 0.0, 0.5, 1.0], + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + ], + ), + ( + CUBIC, + [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.25, 0.5, 0.75, 0.25, 0.5, 0.75], + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.25, 0.5, 0.75, 0.25, 0.5, 0.75], + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + ], + ), + ], +) +def test__statetorch2policy__returns_expected( + env, lattice_system, states, states_policy_expected +): + """ + Various lattice systems are tried because the conversion should be independent of + the lattice system, since the states are expected to satisfy the constraints. + """ + # Get policy states from the batch of states converted into each subenv + # Get policy states from env.statetorch2policy + states_torch = tfloat(states, float_type=env.float, device=env.device) + states_policy_expected_torch = tfloat( + states_policy_expected, float_type=env.float, device=env.device + ) + states_policy = env.statetorch2policy(states_torch) + assert torch.all(torch.eq(states_policy, states_policy_expected_torch)) + states_policy = env.statebatch2policy(states_torch) + assert torch.all(torch.eq(states_policy, states_policy_expected_torch)) + + +@pytest.mark.parametrize( + "lattice_system, expected_output", + [ + (CUBIC, "(1.0, 1.0, 1.0), (90.0, 90.0, 90.0)"), + (HEXAGONAL, "(1.0, 1.0, 1.0), (90.0, 90.0, 120.0)"), + (MONOCLINIC, "(1.0, 1.0, 1.0), (90.0, 30.0, 90.0)"), + (ORTHORHOMBIC, "(1.0, 1.0, 1.0), (90.0, 90.0, 90.0)"), + (RHOMBOHEDRAL, "(1.0, 1.0, 1.0), (30.0, 30.0, 30.0)"), + (TETRAGONAL, "(1.0, 1.0, 1.0), (90.0, 90.0, 90.0)"), + (TRICLINIC, "(1.0, 1.0, 1.0), (30.0, 30.0, 30.0)"), + ], +) +@pytest.mark.skip(reason="skip until it gets updated") +def test__state2readable__gives_expected_results_for_initial_states( + env, lattice_system, expected_output +): + assert env.state2readable() == expected_output + + +@pytest.mark.parametrize( + "lattice_system, readable", + [ + (CUBIC, "(1.0, 1.0, 1.0), (90.0, 90.0, 90.0)"), + (HEXAGONAL, "(1.0, 1.0, 1.0), (90.0, 90.0, 120.0)"), + (MONOCLINIC, "(1.0, 1.0, 1.0), (90.0, 30.0, 90.0)"), + (ORTHORHOMBIC, "(1.0, 1.0, 1.0), (90.0, 90.0, 90.0)"), + (RHOMBOHEDRAL, "(1.0, 1.0, 1.0), (30.0, 30.0, 30.0)"), + (TETRAGONAL, "(1.0, 1.0, 1.0), (90.0, 90.0, 90.0)"), + (TRICLINIC, "(1.0, 1.0, 1.0), (30.0, 30.0, 30.0)"), + ], +) +@pytest.mark.skip(reason="skip until it gets updated") +def test__readable2state__gives_expected_results_for_initial_states( + env, lattice_system, readable +): + assert env.readable2state(readable) == env.state + + +@pytest.mark.parametrize( + "lattice_system", + [CUBIC, HEXAGONAL, MONOCLINIC, ORTHORHOMBIC, RHOMBOHEDRAL, TETRAGONAL, TRICLINIC], +) +def test__continuous_env_common(env, lattice_system): + return 
common.test__continuous_env_common(env) diff --git a/tests/gflownet/envs/test_composition.py b/tests/gflownet/envs/test_composition.py index 888e7221a..31fcd470e 100644 --- a/tests/gflownet/envs/test_composition.py +++ b/tests/gflownet/envs/test_composition.py @@ -38,15 +38,39 @@ def test__environment__initializes_properly(elements): [ ( [0, 0, 2, 0], - [0, 0, 2, 0], + [ + # fmt: off + 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + # fmt: on + ], ), ( [3, 0, 0, 0], - [3, 0, 0, 0], + [ + # fmt: off + 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + # fmt: on + ], ), ( [0, 1, 0, 1], - [0, 1, 0, 1], + [ + # fmt: off + 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + # fmt: on + ], ), ], ) diff --git a/tests/gflownet/envs/test_crystal.py b/tests/gflownet/envs/test_crystal.py index c11cac1ec..5d4c9cbed 100644 --- a/tests/gflownet/envs/test_crystal.py +++ b/tests/gflownet/envs/test_crystal.py @@ -66,11 +66,47 @@ def test__pad_depad_action(env): [ [ (2, 1, 1, 1, 1, 1, 2, 3, 1, 2, 3, 4, 5, 6), - Tensor([1.0, 1.0, 1.0, 1.0, 3.0, 1.4, 1.8, 2.2, 78.0, 90.0, 102.0]), + Tensor( + [ + # fmt: off + # Composition state + 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, + # Spacegroup state + 3.0, + # Lattice parameter state + 1.4, 1.8, 2.2, 78.0, 90.0, 102.0, + # fmt: on + ] + ), ], [ (2, 4, 9, 0, 3, 0, 0, 105, 5, 3, 1, 0, 0, 9), - Tensor([4.0, 9.0, 0.0, 3.0, 105.0, 3.0, 2.2, 1.4, 30.0, 30.0, 138.0]), + Tensor( + [ + # fmt: off + # Composition state + 0.0, 4.0, 9.0, 0.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, + # Spacegroup state + 105.0, + # Lattice parameter state + 3.0, 2.2, 1.4, 30.0, 30.0, 138.0, + # fmt: on + ] + ), ], ], ) @@ -83,11 +119,47 @@ def test__state2oracle__returns_expected_value(env, state, expected): [ [ (2, 1, 1, 1, 1, 1, 2, 3, 1, 2, 3, 4, 5, 6), - Tensor([1.0, 1.0, 1.0, 1.0, 3.0, 1.4, 1.8, 2.2, 78.0, 90.0, 102.0]), + Tensor( + [ + # fmt: off + # Composition state + 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 
0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, + # Spacegroup state + 3.0, + # Lattice parameter state + 1.4, 1.8, 2.2, 78.0, 90.0, 102.0, + # fmt: on + ] + ), ], [ (2, 4, 9, 0, 3, 0, 0, 105, 5, 3, 1, 0, 0, 9), - Tensor([4.0, 9.0, 0.0, 3.0, 105.0, 3.0, 2.2, 1.4, 30.0, 30.0, 138.0]), + Tensor( + [ + # fmt: off + # Composition state + 0.0, 4.0, 9.0, 0.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, + # Spacegroup state + 105.0, + # Lattice parameter state + 3.0, 2.2, 1.4, 30.0, 30.0, 138.0, + # fmt: on + ] + ), ], ], ) @@ -105,8 +177,40 @@ def test__state2proxy__returns_expected_value(env, state, expected): ], Tensor( [ - [1.0, 1.0, 1.0, 1.0, 3.0, 1.4, 1.8, 2.2, 78.0, 90.0, 102.0], - [4.0, 9.0, 0.0, 3.0, 105.0, 3.0, 2.2, 1.4, 30.0, 30.0, 138.0], + [ + # fmt: off + # Composition state + 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, + # Spacegroup state + 3.0, + # Lattice parameter state + 1.4, 1.8, 2.2, 78.0, 90.0, 102.0, + # fmt: on + ], + [ + # fmt: off + # Composition state + 0.0, 4.0, 9.0, 0.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, + # Spacegroup state + 105.0, + # Lattice parameter state + 3.0, 2.2, 1.4, 30.0, 30.0, 138.0, + # fmt: on + ], ] ), ], diff --git a/tests/gflownet/envs/test_ctorus.py b/tests/gflownet/envs/test_ctorus.py index 8c995e153..0f418a5f9 100644 --- a/tests/gflownet/envs/test_ctorus.py +++ b/tests/gflownet/envs/test_ctorus.py @@ -53,10 +53,7 @@ def test__sample_actions_batch__special_cases( mask = torch.unsqueeze( tbool(env.get_mask_invalid_actions_forward(), device=env.device), 0 ) - random_policy = torch.unsqueeze( - tfloat(env.random_policy_output, float_type=env.float, device=env.device), - 0, - ) + random_policy = torch.unsqueeze(env.random_policy_output, 0) action_sampled = env.sample_actions_batch( random_policy, mask, @@ -96,10 +93,7 @@ def test__sample_actions_batch__not_special_cases( mask = 
torch.unsqueeze( tbool(env.get_mask_invalid_actions_forward(), device=env.device), 0 ) - random_policy = torch.unsqueeze( - tfloat(env.random_policy_output, float_type=env.float, device=env.device), - 0, - ) + random_policy = torch.unsqueeze(env.random_policy_output, 0) action_sampled = env.sample_actions_batch( random_policy, mask, diff --git a/tests/gflownet/envs/test_spacegroup.py b/tests/gflownet/envs/test_spacegroup.py index 049d081c5..fd19c24b0 100644 --- a/tests/gflownet/envs/test_spacegroup.py +++ b/tests/gflownet/envs/test_spacegroup.py @@ -8,6 +8,7 @@ from gflownet.envs.crystals.spacegroup import SpaceGroup N_ATOMS = [3, 7, 9] +N_ATOMS_B = [5, 0, 14, 1] SG_SUBSET = [1, 17, 39, 123, 230] @@ -21,6 +22,11 @@ def env_with_composition(): return SpaceGroup(n_atoms=N_ATOMS) +@pytest.fixture +def env_with_composition_b(): + return SpaceGroup(n_atoms=N_ATOMS_B) + + @pytest.fixture def env_with_restricted_spacegroups(): return SpaceGroup(space_groups_subset=SG_SUBSET) @@ -89,6 +95,11 @@ def test__environment__action_space_has_eos(): assert env.eos in env.action_space +def test__env_with_composition_b__debug(env_with_composition_b): + env = env_with_composition_b + pass + + @pytest.mark.parametrize( "action, expected", [