From 737cda81fa46501903982783334c36f2bbf0007d Mon Sep 17 00:00:00 2001 From: nikita-0209 Date: Tue, 11 Apr 2023 11:58:07 -0400 Subject: [PATCH 01/54] checkout files from test_merge_cont_mf --- config/env/amp.yaml | 34 +++ config/env/aptamers.yaml | 6 +- config/proxy/amp.yaml | 9 + config/proxy/aptamers.yaml | 6 + gflownet/envs/amp.py | 33 +++ gflownet/envs/aptamers.py | 424 +--------------------------- gflownet/envs/sequence.py | 473 ++++++++++++++++++++++++++++++++ gflownet/proxy/aptamers.py | 74 ++++- tests/gflownet/envs/test_amp.py | 168 ++++++++++++ 9 files changed, 801 insertions(+), 426 deletions(-) create mode 100644 config/env/amp.yaml create mode 100644 config/proxy/amp.yaml create mode 100644 config/proxy/aptamers.yaml create mode 100644 gflownet/envs/amp.py create mode 100644 gflownet/envs/sequence.py create mode 100644 tests/gflownet/envs/test_amp.py diff --git a/config/env/amp.yaml b/config/env/amp.yaml new file mode 100644 index 000000000..c3d4165d2 --- /dev/null +++ b/config/env/amp.yaml @@ -0,0 +1,34 @@ +defaults: + - base + +_target_: gflownet.envs.amp.AMP + +id: amp +# Minimum and maximum length for the sequences +min_seq_length: 1 +max_seq_length: 50 +# Number of letters in alphabet +n_alphabet: 20 +# Minimum and maximum number of steps in the action space +min_word_len: 1 +max_word_len: 1 +reward_func: power +reward_norm_std_mult: -1.0 +reward_norm: 0.1 +reward_beta: 8.0 +do_state_padding: True +# Buffer +buffer: + replay_capacity: 10 + train: + path: data_train.csv + n: 20 + seed: 168 + output: None +# TODO: Might need to delete irrelevant params after updating buffer + test: + base: None + path: null + n: 10 + seed: 168 + output: None \ No newline at end of file diff --git a/config/env/aptamers.yaml b/config/env/aptamers.yaml index 5de209fa3..8b09ee2dc 100644 --- a/config/env/aptamers.yaml +++ b/config/env/aptamers.yaml @@ -1,10 +1,10 @@ defaults: - base -_target_: gflownet.envs.aptamers.AptamerSeq +_target_: gflownet.envs.aptamers.Aptamers id: aptamers -func: nupack energy +# func: nupack energy # Minimum and maximum length for the sequences min_seq_length: 30 max_seq_length: 30 @@ -13,4 +13,4 @@ n_alphabet: 4 # Minimum and maximum number of steps in the action space min_word_len: 1 max_word_len: 1 - +corr_type: None diff --git a/config/proxy/amp.yaml b/config/proxy/amp.yaml new file mode 100644 index 000000000..ec4063f47 --- /dev/null +++ b/config/proxy/amp.yaml @@ -0,0 +1,9 @@ +_target_: gflownet.proxy.amp.AMPOracleWrapper + +oracle_split: "D2_target" +oracle_type: "MLP" +oracle_features: "AlBert" +dist_fn: "edit" +medoid_oracle_norm: 1 +maximize: True +cost: 1 diff --git a/config/proxy/aptamers.yaml b/config/proxy/aptamers.yaml new file mode 100644 index 000000000..5ebce2d19 --- /dev/null +++ b/config/proxy/aptamers.yaml @@ -0,0 +1,6 @@ +_target_: gflownet.proxy.aptamers.Aptamers + +oracle_id: "energy" +norm: False +cost: 4 +maximize: False \ No newline at end of file diff --git a/gflownet/envs/amp.py b/gflownet/envs/amp.py new file mode 100644 index 000000000..69f5e5ab3 --- /dev/null +++ b/gflownet/envs/amp.py @@ -0,0 +1,33 @@ +""" +Classes to represent aptamers environments +""" +from typing import List, Tuple +import itertools +import numpy as np +from gflownet.envs.base import GFlowNetEnv +import itertools +from polyleven import levenshtein +import numpy.typing as npt +from torchtyping import TensorType +import torch +import matplotlib.pyplot as plt +import torch.nn.functional as F +from gflownet.utils.sequence.amp import AMINO_ACIDS +from gflownet.envs.sequence import Sequence + + +class AMP(Sequence): + """ + Anti-microbial peptide sequence environment + """ + + def __init__( + self, + **kwargs, + ): + special_tokens = ["[PAD]", "[EOS]"] + self.vocab = AMINO_ACIDS + special_tokens + super().__init__( + **kwargs, + special_tokens=special_tokens, + ) diff --git a/gflownet/envs/aptamers.py b/gflownet/envs/aptamers.py index 425a2eb91..1840ddeb0 100644 --- a/gflownet/envs/aptamers.py +++ b/gflownet/envs/aptamers.py @@ -8,424 +8,30 @@ import numpy as np import numpy.typing as npt import pandas as pd - -from gflownet.envs.base import GFlowNetEnv +import time +from gflownet.utils.sequence.aptamers import NUCLEOTIDES +from gflownet.envs.sequence import Sequence -class AptamerSeq(GFlowNetEnv): +class Aptamers(Sequence): """ Aptamer sequence environment - - Attributes - ---------- - max_seq_length : int - Maximum length of the sequences - - min_seq_length : int - Minimum length of the sequences - - n_alphabet : int - Number of letters in the alphabet - - state : list - Representation of a sequence (state), as a list of length max_seq_length where - each element is the index of a letter in the alphabet, from 0 to (n_alphabet - - 1). - - done : bool - True if the sequence has reached a terminal state (maximum length, or stop - action executed. - - func : str - Name of the reward function - - n_actions : int - Number of actions applied to the sequence - - proxy : lambda - Proxy model """ def __init__( self, - max_seq_length=42, - min_seq_length=1, - n_alphabet=4, - min_word_len=1, - max_word_len=1, **kwargs, ): - super().__init__() - self.source = [] - self.min_seq_length = min_seq_length - self.max_seq_length = max_seq_length - self.n_alphabet = n_alphabet - self.min_word_len = min_word_len - self.max_word_len = max_word_len - self.action_space = self.get_action_space() - self.eos = self.action_space_dim - self.reset() - self.fixed_policy_output = self.get_fixed_policy_output() - self.random_policy_output = self.get_fixed_policy_output() - self.policy_output_dim = len(self.fixed_policy_output) - self.policy_input_dim = len(self.state2policy()) - self.max_traj_len = self.get_max_traj_length() - # Set up proxy - self.setup_proxy() - - def get_action_space(self): - """ - Constructs list with all possible actions - """ - assert self.max_word_len >= self.min_word_len - valid_wordlens = np.arange(self.min_word_len, self.max_word_len + 1) - alphabet = [a for a in range(self.n_alphabet)] - actions = [] - for r in valid_wordlens: - actions_r = [el for el in itertools.product(alphabet, repeat=r)] - actions += actions_r - return actions - - def get_max_traj_length( - self, - ): - return self.max_seq_length / self.min_word_len + 1 - - def reward_arbitrary_i(self, state): - if len(state) > 0: - return (state[-1] + 1) * len(state) - else: - return 0 - - def statebatch2oracle(self, states: List[List]): - """ - Prepares a batch of sequence states for the oracles. - - Args - ---- - states : list of lists - List of sequences. - """ - queries = [s + [-1] * (self.max_seq_length - len(s)) for s in states] - queries = np.array(queries, dtype=int) - if queries.ndim == 1: - queries = queries[np.newaxis, ...] - queries += 1 - if queries.shape[1] == 1: - import ipdb - - ipdb.set_trace() - queries = np.column_stack((queries, np.zeros(queries.shape[0]))) - return queries - - def state2policy(self, state: List = None) -> List: - """ - Transforms the sequence (state) given as argument (or self.state if None) into a - one-hot encoding. The output is a list of length n_alphabet * max_seq_length, - where each n-th successive block of n_alphabet elements is a one-hot encoding of - the letter in the n-th position. - - Example: - - Sequence: AATGC - - state: [0, 1, 3, 2] - A, T, G, C - - state2policy(state): [1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0] - | A | T | G | C | - - If max_seq_length > len(s), the last (max_seq_length - len(s)) blocks are all - 0s. - """ - if state is None: - state = self.state.copy() - state_policy = np.zeros(self.n_alphabet * self.max_seq_length, dtype=np.float32) - if len(state) > 0: - state_policy[(np.arange(len(state)) * self.n_alphabet + state)] = 1 - return state_policy - - def statebatch2policy(self, states: List[List]) -> npt.NDArray[np.float32]: - """ - Transforms a batch of states into the policy model format. The output is a numpy - array of shape [n_states, n_angles * n_dim + 1]. - - See state2policy(). - """ - cols, lengths = zip( - *[ - (state + np.arange(len(state)) * self.n_alphabet, len(state)) - for state in states - ] - ) - rows = np.repeat(np.arange(len(states)), lengths) - state_policy = np.zeros( - (len(states), self.n_alphabet * self.max_seq_length), dtype=np.float32 - ) - state_policy[rows, np.concatenate(cols)] = 1.0 - return state_policy - - def policy2state(self, state_policy: List) -> List: - """ - Transforms the one-hot encoding version of a sequence (state) given as argument - into a a sequence of letter indices. - - Example: - - Sequence: AATGC - - state_policy: [1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0] - | A | A | T | G | C | - - state: [0, 0, 1, 3, 2] - A, A, T, G, C - """ - return np.where( - np.reshape(state_policy, (self.max_seq_length, self.n_alphabet)) - )[1].tolist() - - def state2readable(self, state, alphabet={0: "A", 1: "T", 2: "C", 3: "G"}): - """ - Transforms a sequence given as a list of indices into a sequence of letters - according to an alphabet. - """ - return [alphabet[el] for el in state] - - def readable2state(self, letters, alphabet={0: "A", 1: "T", 2: "C", 3: "G"}): - """ - Transforms a sequence given as a list of indices into a sequence of letters - according to an alphabet. - """ - alphabet = {v: k for k, v in alphabet.items()} - return [alphabet[el] for el in letters] - - def reset(self, env_id=None): - """ - Resets the environment. - """ - self.state = [] - self.n_actions = 0 - self.done = False - self.id = env_id - return self - - def get_parents(self, state=None, done=None, action=None): - """ - Determines all parents and actions that lead to sequence state - - Args - ---- - state : list - Representation of a sequence (state), as a list of length max_seq_length - where each element is the index of a letter in the alphabet, from 0 to - (n_alphabet - 1). - - done : bool - Whether the trajectory is done. If None, done is taken from instance. - - action : None - Ignored - - Returns - ------- - parents : list - List of parents in state format - - actions : list - List of actions that lead to state for each parent in parents - """ - if state is None: - state = self.state.copy() - if done is None: - done = self.done - if done: - return [state], [self.eos] - else: - parents = [] - actions = [] - for idx, a in enumerate(self.action_space): - is_parent = state[-len(a) :] == list(a) - if not isinstance(is_parent, bool): - is_parent = all(is_parent) - if is_parent: - parents.append(state[: -len(a)]) - actions.append(idx) - return parents, actions - - def step(self, action_idx): - """ - Executes step given an action index - - If action_idx is smaller than eos (no stop), add action to next - position. - - See: step_daug() - See: step_chain() - - Args - ---- - action_idx : int - Index of action in the action space. a == eos indicates "stop action" - - Returns - ------- - self.state : list - The sequence after executing the action - - valid : bool - False, if the action is not allowed for the current state, e.g. stop at the - root state - """ - # If only possible action is eos, then force eos - if len(self.state) == self.max_seq_length: - self.done = True - self.n_actions += 1 - return self.state, [self.eos], True - # If action is not eos, then perform action - if action_idx != self.eos: - action = self.action_space[action_idx] - state_next = self.state + list(action) - if len(state_next) > self.max_seq_length: - valid = False - else: - self.state = state_next - valid = True - self.n_actions += 1 - return self.state, action_idx, valid - # If action is eos, then perform eos - else: - if len(self.state) < self.min_seq_length: - valid = False - else: - self.done = True - valid = True - self.n_actions += 1 - return self.state, self.eos, valid - - def get_mask_invalid_actions_forward(self, state=None, done=None): - """ - Returns a vector of length the action space + 1: True if action is invalid - given the current state, False otherwise. - """ - if state is None: - state = self.state.copy() - if done is None: - done = self.done - if done: - return [True for _ in range(self.action_space_dim + 1)] - mask = [False for _ in range(self.action_space_dim + 1)] - seq_length = len(state) - if seq_length < self.min_seq_length: - mask[self.eos] = True - for idx, a in enumerate(self.action_space): - if seq_length + len(a) > self.max_seq_length: - mask[idx] = True - return mask - - def make_train_set( - self, - ntrain, - oracle=None, - seed=168, - output_csv=None, - ): - """ - Constructs a randomly sampled train set. - - Args - ---- - ntest : int - Number of test samples. - - seed : int - Random seed. - - output_csv: str - Optional path to store the test set as CSV. - """ - samples_dict = oracle.initializeDataset( - save=False, returnData=True, customSize=ntrain, custom_seed=seed + special_tokens = ["[PAD]", "[EOS]"] + self.vocab = NUCLEOTIDES + special_tokens + super().__init__( + **kwargs, + special_tokens=special_tokens, ) - energies = samples_dict["energies"] - samples_mat = samples_dict["samples"] - state_letters = oracle.numbers2letters(samples_mat) - state_ints = [ - "".join([str(el) for el in state if el > 0]) for state in samples_mat - ] - if isinstance(energies, dict): - energies.update({"samples": state_letters, "indices": state_ints}) - df_train = pd.DataFrame(energies) - else: - df_train = pd.DataFrame( - {"samples": state_letters, "indices": state_ints, "energies": energies} - ) - if output_csv: - df_train.to_csv(output_csv) - return df_train - - # TODO: improve approximation of uniform - def make_test_set( - self, - path_base_dataset, - ntest, - min_length=0, - max_length=np.inf, - seed=167, - output_csv=None, - ): - """ - Constructs an approximately uniformly distributed (on the score) set, by - selecting samples from a larger base set. - - Args - ---- - path_base_dataset : str - Path to a CSV file containing the base data set. - - ntest : int - Number of test samples. - - seed : int - Random seed. - - dask : bool - If True, use dask to efficiently read a large base file. - output_csv: str - Optional path to store the test set as CSV. - """ - if path_base_dataset is None: - return None, None - times = { - "all": 0.0, - "indices": 0.0, - } - t0_all = time.time() - if seed: - np.random.seed(seed) - df_base = pd.read_csv(path_base_dataset, index_col=0) - df_base = df_base.loc[ - (df_base["samples"].map(len) >= min_length) - & (df_base["samples"].map(len) <= max_length) - ] - energies_base = df_base["energies"].values - min_base = energies_base.min() - max_base = energies_base.max() - distr_unif = np.random.uniform(low=min_base, high=max_base, size=ntest) - # Get minimum distance samples without duplicates - t0_indices = time.time() - idx_samples = [] - for idx in tqdm(range(ntest)): - dist = np.abs(energies_base - distr_unif[idx]) - idx_min = np.argmin(dist) - if idx_min in idx_samples: - idx_sort = np.argsort(dist) - for idx_next in idx_sort: - if idx_next not in idx_samples: - idx_samples.append(idx_next) - break - else: - idx_samples.append(idx_min) - t1_indices = time.time() - times["indices"] += t1_indices - t0_indices - # Make test set - df_test = df_base.iloc[idx_samples] - if output_csv: - df_test.to_csv(output_csv) - t1_all = time.time() - times["all"] += t1_all - t0_all - return df_test, times + # if ( + # hasattr(self, "proxy") + # and self.proxy is not None + # and hasattr(self.proxy, "setup") + # ): + # self.proxy.setup(self.max_seq_length) diff --git a/gflownet/envs/sequence.py b/gflownet/envs/sequence.py new file mode 100644 index 000000000..f562b1cb7 --- /dev/null +++ b/gflownet/envs/sequence.py @@ -0,0 +1,473 @@ +""" +Classes to represent sequence-like environments +Particularly AMP and DNA +""" +from typing import List, Tuple +import itertools +import numpy as np +from gflownet.envs.base import GFlowNetEnv +import itertools +from polyleven import levenshtein +import numpy.typing as npt +from torchtyping import TensorType +import torch +import matplotlib.pyplot as plt +import torch.nn.functional as F + + +class Sequence(GFlowNetEnv): + """ + Anti-microbial peptide sequence environment + + Attributes + ---------- + max_seq_length : int + Maximum length of the sequences + + min_seq_length : int + Minimum length of the sequences + + nalphabet : int + Number of letters in the alphabet + + state : list + Representation of a sequence (state), as a list of length max_seq_length where + each element is the index of a letter in the alphabet, from 0 to (nalphabet - + 1). + + done : bool + True if the sequence has reached a terminal state (maximum length, or stop + action executed. + + func : str + Name of the reward function + + n_actions : int + Number of actions applied to the sequence + + proxy : lambda + Proxy model + """ + + def __init__( + self, + corr_type, + max_seq_length=50, + min_seq_length=1, + # Not required in env. But used in config_env in MLP. TODO: Find a way out + n_alphabet=20, + min_word_len=1, + max_word_len=1, + special_tokens=None, + **kwargs, + ): + self.min_seq_length = min_seq_length + self.max_seq_length = max_seq_length + self.min_word_len = min_word_len + self.max_word_len = max_word_len + self.corr_type = corr_type + self.lookup = {a: i for (i, a) in enumerate(self.vocab)} + self.inverse_lookup = {i: a for (i, a) in enumerate(self.vocab)} + self.n_alphabet = len(self.vocab) - len(special_tokens) + self.padding_idx = self.lookup["[PAD]"] + # TODO: eos re-initalised in get_actions_space so why was this initialisation required in the first place (maybe mfenv) + self.eos = self.lookup["[EOS]"] + self.source = ( + torch.ones(self.max_seq_length, dtype=torch.int64) * self.padding_idx + ) + # reset this to a lower value + self.min_reward = 1e-20 + # if proxy is not None: + # self.proxy = proxy + super().__init__( + **kwargs, + ) + self.policy_input_dim = self.state2policy().shape[-1] + self.tokenizer = None + + def get_action_space(self): + """ + Constructs list with all possible actions + If min_word_len = n_alphabet = 2, actions: [(0, 0,), (1, 1)] and so on + """ + assert self.max_word_len >= self.min_word_len + valid_wordlens = np.arange(self.min_word_len, self.max_word_len + 1) + alphabet = [a for a in range(self.n_alphabet)] + actions = [] + for r in valid_wordlens: + actions_r = [el for el in itertools.product(alphabet, repeat=r)] + actions += actions_r + # Add "eos" action + # eos != n_alphabet in the init because it would break if max_word_len >1 + actions = actions + [(len(actions),)] + self.eos = len(actions) - 1 + return actions + + def copy(self): + return self.__class__(**self.__dict__) + # return deepcopy(self) + + def get_mask_invalid_actions_forward(self, state=None, done=None): + """ + Returns a vector of length the action space (where action space includes eos): True if action is invalid + given the current state, False otherwise. + """ + if state is None: + state = self.state.clone().detach() + if done is None: + done = self.done + if done: + return [True for _ in range(len(self.action_space))] + mask = [False for _ in range(len(self.action_space))] + seq_length = ( + torch.where(state == self.padding_idx)[0][0] + if state[-1] == self.padding_idx + else len(state) + ) + if seq_length < self.min_seq_length: + mask[self.eos] = True + for idx, a in enumerate(self.action_space[:-1]): + if seq_length + len(list(a)) > self.max_seq_length: + mask[idx] = True + return mask + + def true_density(self, max_states=1e6): + """ + Computes the reward density (reward / sum(rewards)) of the whole space, if the + dimensionality is smaller than specified in the arguments. + + Returns + ------- + Tuple: + - normalized reward for each state + - states + - (un-normalized) reward) + """ + if self._true_density is not None: + return self._true_density + if self.n_alphabet**self.max_seq_length > max_states: + return (None, None, None) + state_all = np.int32( + list( + itertools.product(*[list(range(self.n_alphabet))] * self.max_seq_length) + ) + ) + traj_rewards, state_end = zip( + *[ + (self.proxy(state), state) + for state in state_all + if len(self.get_parents(state, False)[0]) > 0 or sum(state) == 0 + ] + ) + traj_rewards = np.array(traj_rewards) + self._true_density = ( + traj_rewards / traj_rewards.sum(), + list(map(tuple, state_end)), + traj_rewards, + ) + return self._true_density + + # def state2oracle(self, state: List = None): + # return "".join(self.state2readable(state)) + + def get_max_traj_length( + self, + ): + return self.max_seq_length / self.min_word_len + 1 + + def statebatch2oracle( + self, states: List[TensorType["max_seq_length"]] + ) -> List[str]: + state_oracle = [] + for state in states: + if state[-1] == self.padding_idx: + state = state[: torch.where(state == self.padding_idx)[0][0]] + if self.tokenizer is not None and state[0] == self.tokenizer.bos_idx: + state = state[1:-1] + state_numpy = state.detach().cpu().numpy() + state_oracle.append(self.state2oracle(state_numpy)) + return state_oracle + + def statetorch2oracle( + self, states: TensorType["batch_dim", "max_seq_length"] + ) -> List[str]: + return self.statebatch2oracle(states) + + # TODO: Deprecate as never used. + def state2policy(self, state=None): + """ + Transforms the sequence (state) given as argument (or self.state if None) into a + one-hot encoding. The output is a list of length nalphabet * max_seq_length, + where each n-th successive block of nalphabet elements is a one-hot encoding of + the letter in the n-th position. + + Example: + - Sequence: AATGC + - state: [0, 1, 3, 2] + A, T, G, C + - state2obs(state): [1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0] + | A | T | G | C | + + If max_seq_length > len(state), the last (max_seq_length - len(state)) blocks are all + 0s. + """ + if state is None: + state = self.state.clone().detach() + state = ( + state[: torch.where(state == self.padding_idx)[0][0]] + if state[-1] == self.padding_idx + else state + ) + state_policy = torch.zeros(1, self.max_seq_length, self.n_alphabet) + if len(state) == 0: + return state_policy.reshape(1, -1) + state_onehot = F.one_hot(state, num_classes=self.n_alphabet + 1)[:, :, 1:].to( + self.float + ) + state_policy[:, : state_onehot.shape[1], :] = state_onehot + return state_policy.reshape(state.shape[0], -1) + + def statebatch2policy( + self, states: List[TensorType["1", "max_seq_length"]] + ) -> TensorType["batch", "policy_input_dim"]: + """ + Transforms a batch of states into the policy model format. The output is a numpy + array of shape [n_states, n_alphabet * max_seq_len]. + + See state2policy() + """ + state_tensor = torch.vstack(states) + state_policy = self.statetorch2policy(state_tensor) + return state_policy + + def statetorch2policy( + self, states: TensorType["batch", "max_seq_length"] + ) -> TensorType["batch", "policy_input_dim"]: + if states.dtype != torch.long: + states = states.to(torch.long) + state_onehot = ( + F.one_hot(states, self.n_alphabet + 2)[:, :, :-2] + .to(self.float) + .to(self.device) + ) + state_padding_mask = (states != self.padding_idx).to(self.float).to(self.device) + state_onehot_pad = state_onehot * state_padding_mask.unsqueeze(-1) + # Assertion works as long as [PAD] is last key in lookup table. + assert torch.eq(state_onehot_pad, state_onehot).all() + state_policy = torch.zeros( + states.shape[0], + self.max_seq_length, + self.n_alphabet, + device=self.device, + dtype=self.float, + ) + state_policy[:, : state_onehot.shape[1], :] = state_onehot + return state_policy.reshape(states.shape[0], -1) + + def policytorch2state(self, state_policy: List) -> List: + """ + Transforms the one-hot encoding version of a sequence (state) given as argument + into a a sequence of letter indices. + + Example: + - Sequence: AATGC + - state_policy: [1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0] + | A | A | T | G | C | + - policy2state(state_policy): [0, 0, 1, 3, 2] + A, A, T, G, C + """ + mat_state_policy = torch.reshape( + state_policy, (self.max_seq_length, self.n_alphabet) + ) + state = torch.where(mat_state_policy)[1].tolist() + return state + + # TODO: Deprecate as never used. + def policy2state(self, state_policy: List) -> List: + """ + Transforms the one-hot encoding version of a sequence (state) given as argument + into a a sequence of letter indices. + + Example: + - Sequence: AATGC + - state_policy: [1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0] + | A | A | T | G | C | + - policy2state(state_policy): [0, 0, 1, 3, 2] + A, A, T, G, C + """ + mat_state_policy = np.reshape( + state_policy, (self.max_seq_length, self.n_alphabet) + ) + state = np.where(mat_state_policy)[1].tolist() + return state + + def state2oracle(self, state: List = None): + return "".join(self.state2readable(state)) + + def statebatch2oracle( + self, states: List[TensorType["max_seq_length"]] + ) -> List[str]: + state_oracle = [] + for state in states: + if state[-1] == self.padding_idx: + state = state[: torch.where(state == self.padding_idx)[0][0]] + if self.tokenizer is not None and state[0] == self.tokenizer.bos_idx: + state = state[1:-1] + state_numpy = state.detach().cpu().numpy() + state_oracle.append(self.state2oracle(state_numpy)) + return state_oracle + + def statetorch2oracle( + self, states: TensorType["batch_dim", "max_seq_length"] + ) -> List[str]: + return self.statebatch2oracle(states) + + def state2readable(self, state: List) -> str: + """ + Transforms a sequence given as a list of indices into a sequence of letters + according to an alphabet. + Used only in Buffer + """ + if isinstance(state, torch.Tensor) == False: + state = torch.tensor(state).long() + if state[-1] == self.padding_idx: + state = state[: torch.where(state == self.padding_idx)[0][0]] + state = state.tolist() + return "".join([self.inverse_lookup[el] for el in state]) + + def statetorch2readable(self, state: TensorType["1", "max_seq_length"]) -> str: + if state[-1] == self.padding_idx: + state = state[: torch.where(state == self.padding_idx)[0][0]] + # TODO: neater way without having lookup as input arg + if ( + self.lookup is not None + and "[CLS]" in self.lookup.keys() + and state[0] == self.lookup["[CLS]"] + ): + state = state[1:-1] + state = state.tolist() + readable = [self.inverse_lookup[el] for el in state] + return "".join(readable) + + def readable2state(self, readable) -> TensorType["batch_dim", "max_seq_length"]: + """ + Transforms a list or string of letters into a list of indices according to an alphabet. + """ + if isinstance(readable, str): + encoded_readable = [self.lookup[el] for el in readable] + state = ( + torch.ones(self.max_seq_length, dtype=torch.int64) * self.padding_idx + ) + state[: len(encoded_readable)] = torch.tensor(encoded_readable) + else: + encoded_readable = [[self.lookup[el] for el in seq] for seq in readable] + state = ( + torch.ones((len(readable), self.max_seq_length), dtype=torch.int64) + * self.padding_idx + ) + for i, seq in enumerate(encoded_readable): + state[i, : len(seq)] = torch.tensor(seq) + return state + + def get_parents(self, state=None, done=None, action=None): + """ + Determines all parents and actions that lead to sequence state + + Args + ---- + state : list + Representation of a sequence (state), as a list of length max_seq_length + where each element is the index of a letter in the alphabet, from 0 to + (nalphabet - 1). + + action : int + Last action performed, only to determine if it was eos. + + Returns + ------- + parents : list + List of parents as state2obs(state) + + actions : list + List of actions that lead to state for each parent in parents + """ + # TODO: Adapt to tuple form actions + if state is None: + state = self.state.clone().detach() + if done is None: + done = self.done + if done: + return [state], [(self.eos,)] + elif torch.eq(state, self.source).all(): + return [], [] + else: + parents = [] + actions = [] + if state[-1] == self.padding_idx: + state_last_element = int(torch.where(state == self.padding_idx)[0][0]) + else: + state_last_element = len(state) + for idx, a in enumerate(self.action_space): + is_parent = state[ + state_last_element - len(a) : state_last_element + ] == torch.LongTensor(a) + if not isinstance(is_parent, bool): + is_parent = all(is_parent) + if is_parent: + parent = state.clone().detach() + parent[ + state_last_element - len(a) : state_last_element + ] = self.padding_idx + parents.append(parent) + actions.append(a) + return parents, actions + + def step(self, action: Tuple[int]) -> Tuple[List[int], Tuple[int, int], bool]: + """ + Executes step given an action index + + If action_idx is smaller than eos (no stop), add action to next + position. + + Args + ---- + action_idx : int + Index of action in the action space. a == eos indicates "stop action" + + Returns + ------- + self.state : list + The sequence after executing the action + + valid : bool + False, if the action is not allowed for the current state, e.g. stop at the + root state + """ + assert action in self.action_space + # If only possible action is eos, then force eos + if self.state[-1] != self.padding_idx: + self.done = True + self.n_actions += 1 + return self.state, (self.eos,), True + # If action is not eos, then perform action + state_last_element = int(torch.where(self.state == self.padding_idx)[0][0]) + if action[0] != self.eos: + state_next = self.state.clone().detach() + if state_last_element + len(action) > self.max_seq_length: + valid = False + else: + state_next[ + state_last_element : state_last_element + len(action) + ] = torch.LongTensor(action) + self.state = state_next + valid = True + self.n_actions += 1 + return self.state, action, valid + else: + if state_last_element < self.min_seq_length: + valid = False + else: + self.done = True + valid = True + self.n_actions += 1 + return self.state, (self.eos,), valid diff --git a/gflownet/proxy/aptamers.py b/gflownet/proxy/aptamers.py index 338c72347..34e172510 100644 --- a/gflownet/proxy/aptamers.py +++ b/gflownet/proxy/aptamers.py @@ -1,5 +1,7 @@ import numpy as np import numpy.typing as npt +from nupack import * +import torch from gflownet.proxy.base import Proxy @@ -9,27 +11,71 @@ class Aptamers(Proxy): DNA Aptamer oracles """ - def __init__(self, oracle_id, norm): - super().__init__() + def __init__(self, oracle_id, norm, cost, **kwargs): + super().__init__(**kwargs) self.type = oracle_id self.norm = norm + self.cost = cost - def setup(self, env=None): + def setup(self, env, norm=True): self.max_seq_length = env.max_seq_length - def __call__(self, states: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]: + def _length(self, x): + if self.norm: + return -1.0 * np.sum(x, axis=1) / self.max_seq_length + else: + return -1.0 * np.sum(x, axis=1) + + def __call__(self, sequences): + if self.type == "length": + return self._length(sequences) + elif self.type == "pairs": + self.function = self._func_pairs + return self._nupack(sequences) + elif self.type == "energy": + self.function = self._func_energy + return self._nupack(sequences) + else: + raise NotImplementedError + + def _nupack(self, sequences): """ args: - states : ndarray + inputs: list of arrays in desired format interpretable by oracle + returns: + array of scores + function: + creates the complex set and calls the desired nupack function """ + temperature = 310.0 # Kelvin + ionicStrength = 1.0 # molar + strandList = [] + comps = [] + i = -1 + for sequence in sequences: + i += 1 + strandList.append(Strand(sequence, name="strand{}".format(i))) + comps.append(Complex([strandList[-1]], name="comp{}".format(i))) - def _length(x): - if self.norm: - return -1.0 * np.sum(x, axis=1) / self.max_seq_length - else: - return -1.0 * np.sum(x, axis=1) + set = ComplexSet( + strands=strandList, complexes=SetSpec(max_size=1, include=comps) + ) + model1 = Model(material="dna", celsius=temperature - 273, sodium=ionicStrength) + results = complex_analysis(set, model=model1, compute=["mfe"]) - if self.type == "length": - return _length(states) - else: - raise NotImplementedError("self.type must be length") + energy = self.function(sequences, results, comps) + + return torch.tensor(energy, device=self.device, dtype=self.float) + + def _func_energy(self, sequences, results, comps): + energies = np.zeros(len(sequences)) + for i in range(len(energies)): + energies[i] = results[comps[i]].mfe[0].energy + return energies + + def _func_pairs(self, sequences, results, comps): + ssStrings = np.zeros(len(sequences), dtype=object) + for i in range(len(ssStrings)): + ssStrings[i] = str(results[comps[i]].mfe[0].structure) + nPairs = np.asarray([ssString.count("(") for ssString in ssStrings]).astype(int) + return -nPairs diff --git a/tests/gflownet/envs/test_amp.py b/tests/gflownet/envs/test_amp.py new file mode 100644 index 000000000..ed05f2201 --- /dev/null +++ b/tests/gflownet/envs/test_amp.py @@ -0,0 +1,168 @@ +import pytest +import torch +import numpy as np + +from gflownet.envs.amp import AMP + + +@pytest.fixture +def env(): + return AMP(proxy_state_format="state") + + +def test__environment__initializes_properly(): + env = AMP(proxy_state_format="state") + assert torch.eq( + env.source, torch.ones(env.max_seq_length, dtype=torch.int64) * env.padding_idx + ).all() + assert torch.eq( + env.state, torch.ones(env.max_seq_length, dtype=torch.int64) * env.padding_idx + ).all() + + +def test__environment__action_space_has_eos(): + env = AMP(proxy_state_format="state") + assert (env.eos,) in env.action_space + + +@pytest.mark.parametrize( + "state, expected_state_policy", + [ + ( + torch.tensor([[3, 2, 21, 21, 21]]), + [ + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + ), + ( + torch.tensor([[3, 2, 4, 2, 0]]), + [ + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + ), + ( + torch.tensor([[21, 21, 21, 21, 21]]), + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + ), + ], +) +def test_environment_policy_transformation(state, expected_state_policy): + env = AMP(proxy_state_format="state", max_seq_length=5) + expected_state_policy_tensor = torch.tensor( + expected_state_policy, dtype=env.float, device=env.device + ).reshape(state.shape[0], -1) + state_policy = env.statetorch2policy(state) + assert torch.eq(state_policy, expected_state_policy_tensor).all() + + +@pytest.mark.parametrize( + "state, done, expected_parent, expected_parent_action", + [ + ( + torch.tensor([3, 21, 21, 21, 21]), + False, + [torch.tensor([21, 21, 21, 21, 21])], + [(3,)], + ), + ( + torch.tensor([3, 2, 4, 2, 0]), + False, + [torch.tensor([3, 2, 4, 2, 21])], + [(0,)], + ), + ( + torch.tensor([3, 21, 21, 21, 21]), + False, + [torch.tensor([21, 21, 21, 21, 21])], + [(3,)], + ), + ( + torch.tensor([3, 21, 21, 21, 21]), + True, + [torch.tensor([3, 21, 21, 21, 21])], + [(20,)], + ), + ( + torch.tensor([21, 21, 21, 21, 21]), + False, + [], + [], + ), + ], +) +def test_environment_get_parents(state, done, expected_parent, expected_parent_action): + env = AMP(proxy_state_format="state", max_seq_length=5) + parent, parent_action = env.get_parents(state, done) + print(parent, parent_action) + if parent != []: + parent_tensor = torch.vstack(parent).to(env.device).to(env.float) + expected_parent_tensor = ( + torch.vstack(expected_parent).to(env.device).to(env.float) + ) + assert torch.eq(parent_tensor, expected_parent_tensor).all() + else: + assert parent == expected_parent + assert parent_action == expected_parent_action + + +@pytest.mark.parametrize( + "state, action, expected_next_state, expected_executed_action, expected_valid", + [ + ( + torch.tensor([3, 21, 21, 21, 21]), + (2,), + torch.tensor([3, 2, 21, 21, 21]), + (2,), + True, + ), + ( + torch.tensor([3, 2, 4, 2, 0]), + (2,), + torch.tensor([3, 2, 4, 2, 0]), + (20,), + True, + ), + ( + torch.tensor([21, 21, 21, 21, 21]), + (20,), + torch.tensor([21, 21, 21, 21, 21]), + (20,), + False, + ), + ( + torch.tensor([3, 21, 21, 21, 21]), + (20,), + torch.tensor([3, 21, 21, 21, 21]), + (20,), + True, + ), + ], +) +def test_environment_step( + state, action, expected_next_state, expected_executed_action, expected_valid +): + env = AMP(proxy_state_format="state", max_seq_length=5) + env.state = state + n_actions = env.n_actions + next_state, action_executed, valid = env.step(action) + if expected_executed_action == (20,) and expected_valid == True: + assert env.done == True + if expected_valid == True: + assert env.n_actions == n_actions + 1 + assert torch.eq(next_state, expected_next_state).all() + assert action_executed == expected_executed_action + assert valid == expected_valid From afdb76242287ebb3f3170ac84c97da9a287c9cc8 Mon Sep 17 00:00:00 2001 From: nikita-0209 Date: Tue, 11 Apr 2023 12:00:36 -0400 Subject: [PATCH 02/54] remove comments --- config/env/amp.yaml | 18 +----------------- config/env/aptamers.yaml | 1 - gflownet/envs/aptamers.py | 7 ------- 3 files changed, 1 insertion(+), 25 deletions(-) diff --git a/config/env/amp.yaml b/config/env/amp.yaml index c3d4165d2..b4b70bbb8 100644 --- a/config/env/amp.yaml +++ b/config/env/amp.yaml @@ -15,20 +15,4 @@ max_word_len: 1 reward_func: power reward_norm_std_mult: -1.0 reward_norm: 0.1 -reward_beta: 8.0 -do_state_padding: True -# Buffer -buffer: - replay_capacity: 10 - train: - path: data_train.csv - n: 20 - seed: 168 - output: None -# TODO: Might need to delete irrelevant params after updating buffer - test: - base: None - path: null - n: 10 - seed: 168 - output: None \ No newline at end of file +reward_beta: 8.0 \ No newline at end of file diff --git a/config/env/aptamers.yaml b/config/env/aptamers.yaml index 8b09ee2dc..210df4f83 100644 --- a/config/env/aptamers.yaml +++ b/config/env/aptamers.yaml @@ -4,7 +4,6 @@ defaults: _target_: gflownet.envs.aptamers.Aptamers id: aptamers -# func: nupack energy # Minimum and maximum length for the sequences min_seq_length: 30 max_seq_length: 30 diff --git a/gflownet/envs/aptamers.py b/gflownet/envs/aptamers.py index 1840ddeb0..5b50eedf1 100644 --- a/gflownet/envs/aptamers.py +++ b/gflownet/envs/aptamers.py @@ -28,10 +28,3 @@ def __init__( **kwargs, special_tokens=special_tokens, ) - - # if ( - # hasattr(self, "proxy") - # and self.proxy is not None - # and hasattr(self.proxy, "setup") - # ): - # self.proxy.setup(self.max_seq_length) From e7cd7a59e3f871d0d5911f4ce28ca3ee488c5cb1 Mon Sep 17 00:00:00 2001 From: Nikita Saxena <59296031+nikita-0209@users.noreply.github.com> Date: Tue, 11 Apr 2023 14:40:38 -0400 Subject: [PATCH 03/54] Update to a slightly faster version of get_parents() --- gflownet/envs/sequence.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/gflownet/envs/sequence.py b/gflownet/envs/sequence.py index f562b1cb7..3d1e41572 100644 --- a/gflownet/envs/sequence.py +++ b/gflownet/envs/sequence.py @@ -407,19 +407,16 @@ def get_parents(self, state=None, done=None, action=None): state_last_element = int(torch.where(state == self.padding_idx)[0][0]) else: state_last_element = len(state) - for idx, a in enumerate(self.action_space): - is_parent = state[ - state_last_element - len(a) : state_last_element - ] == torch.LongTensor(a) - if not isinstance(is_parent, bool): - is_parent = all(is_parent) - if is_parent: + max_parent_action_length = self.max_word_len + 1 - self.min_word_len + for parent_action_length in range(1, max_parent_action_length + 1): + parent_action = tuple(state[state_last_element - parent_action_length : state_last_element].numpy()) + if parent_action in self.action_space: parent = state.clone().detach() parent[ - state_last_element - len(a) : state_last_element + state_last_element - parent_action_length : state_last_element ] = self.padding_idx parents.append(parent) - actions.append(a) + actions.append(parent_action) return parents, actions def step(self, action: Tuple[int]) -> Tuple[List[int], Tuple[int, int], bool]: From 7567fe3ecd7dd11188941fd4400c521e67c9d3ba Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Tue, 1 Aug 2023 09:56:27 -0400 Subject: [PATCH 04/54] Move sequence config files into new dir seqs/ --- config/env/{ => seqs}/amp.yaml | 0 config/env/{ => seqs}/aptamers.yaml | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename config/env/{ => seqs}/amp.yaml (100%) rename config/env/{ => seqs}/aptamers.yaml (100%) diff --git a/config/env/amp.yaml b/config/env/seqs/amp.yaml similarity index 100% rename from config/env/amp.yaml rename to config/env/seqs/amp.yaml diff --git a/config/env/aptamers.yaml b/config/env/seqs/aptamers.yaml similarity index 100% rename from config/env/aptamers.yaml rename to config/env/seqs/aptamers.yaml From 2ba9ae40da83e2cbbc1f2d473dd5a20176686d10 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Tue, 1 Aug 2023 09:57:34 -0400 Subject: [PATCH 05/54] Move sequence proxy config files into new dir seqs/ --- config/proxy/{ => seqs}/amp.yaml | 0 config/proxy/{ => seqs}/aptamers.yaml | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename config/proxy/{ => seqs}/amp.yaml (100%) rename config/proxy/{ => seqs}/aptamers.yaml (100%) diff --git a/config/proxy/amp.yaml b/config/proxy/seqs/amp.yaml similarity index 100% rename from config/proxy/amp.yaml rename to config/proxy/seqs/amp.yaml diff --git a/config/proxy/aptamers.yaml b/config/proxy/seqs/aptamers.yaml similarity index 100% rename from config/proxy/aptamers.yaml rename to config/proxy/seqs/aptamers.yaml From e607efc4a10f925f73c9f87a2e66d17b9977f6e2 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Tue, 1 Aug 2023 09:59:39 -0400 Subject: [PATCH 06/54] Move sequence env *.py files into new dir seqs/ --- gflownet/envs/{ => seqs}/amp.py | 0 gflownet/envs/{ => seqs}/aptamers.py | 0 gflownet/envs/{ => seqs}/sequence.py | 0 3 files changed, 0 insertions(+), 0 deletions(-) rename gflownet/envs/{ => seqs}/amp.py (100%) rename gflownet/envs/{ => seqs}/aptamers.py (100%) rename gflownet/envs/{ => seqs}/sequence.py (100%) diff --git a/gflownet/envs/amp.py b/gflownet/envs/seqs/amp.py similarity index 100% rename from gflownet/envs/amp.py rename to gflownet/envs/seqs/amp.py diff --git a/gflownet/envs/aptamers.py b/gflownet/envs/seqs/aptamers.py similarity index 100% rename from gflownet/envs/aptamers.py rename to gflownet/envs/seqs/aptamers.py diff --git a/gflownet/envs/sequence.py b/gflownet/envs/seqs/sequence.py similarity index 100% rename from gflownet/envs/sequence.py rename to gflownet/envs/seqs/sequence.py From 690fde0b3cd71975a67085632bfa39c76345ea69 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Tue, 1 Aug 2023 11:24:51 -0400 Subject: [PATCH 07/54] wip: refactor of parent sequence class --- gflownet/envs/seqs/sequence.py | 133 ++++++++++++++++----------------- 1 file changed, 65 insertions(+), 68 deletions(-) diff --git a/gflownet/envs/seqs/sequence.py b/gflownet/envs/seqs/sequence.py index 3d1e41572..4bd012053 100644 --- a/gflownet/envs/seqs/sequence.py +++ b/gflownet/envs/seqs/sequence.py @@ -1,6 +1,8 @@ """ -Classes to represent sequence-like environments -Particularly AMP and DNA +Parent class to represent sequence-like environments, such as AMP and DNA. Sequences +are constructed by adding tokens from a dictionary. An alternative to this kind of +sequence environment (not-implemented as of July 2023) would be a "mutation-based" +modification of the sequences, or a combination of mutations and additions. """ from typing import List, Tuple import itertools @@ -17,55 +19,50 @@ class Sequence(GFlowNetEnv): """ - Anti-microbial peptide sequence environment + Parent of sequence environments. By default, for illustration purposes, this parent + class is functional and represents binary sequences of 0s and 1s. Attributes ---------- - max_seq_length : int - Maximum length of the sequences - - min_seq_length : int + dictionary : dict + A dictionary containing (key: value) pairs where the value is a string + representing the token and the key is an arbitrary integer uniquely identifying + the token. The dictionary should also include (key: value) pairs for the + special tokens: + - End of sequence token (EOS) + - Padding token (PAD) + + max_length : int + Maximum length of the sequences. + + min_length : int Minimum length of the sequences - nalphabet : int - Number of letters in the alphabet - - state : list - Representation of a sequence (state), as a list of length max_seq_length where - each element is the index of a letter in the alphabet, from 0 to (nalphabet - - 1). - - done : bool - True if the sequence has reached a terminal state (maximum length, or stop - action executed. - - func : str - Name of the reward function + max_word_length : int + Maximum number of tokens allowed per action. - n_actions : int - Number of actions applied to the sequence - - proxy : lambda - Proxy model + min_word_length : int + Minimum number of tokens allowed per action. """ def __init__( self, - corr_type, - max_seq_length=50, - min_seq_length=1, - # Not required in env. But used in config_env in MLP. TODO: Find a way out - n_alphabet=20, - min_word_len=1, - max_word_len=1, + dictionary={0: "0", 1: "1", -1: "[PAD]", -2: "[EOS]"}, + max_length=10, + min_length=1, + max_word_length=1, + min_word_length=1, special_tokens=None, **kwargs, ): - self.min_seq_length = min_seq_length - self.max_seq_length = max_seq_length - self.min_word_len = min_word_len - self.max_word_len = max_word_len - self.corr_type = corr_type + # Main attributes + self.dictionary= dictionary + self.min_length = min_length + self.max_length = max_length + self.min_word_length = min_word_length + self.max_word_length = max_word_length + + self.lookup = {a: i for (i, a) in enumerate(self.vocab)} self.inverse_lookup = {i: a for (i, a) in enumerate(self.vocab)} self.n_alphabet = len(self.vocab) - len(special_tokens) @@ -73,7 +70,7 @@ def __init__( # TODO: eos re-initalised in get_actions_space so why was this initialisation required in the first place (maybe mfenv) self.eos = self.lookup["[EOS]"] self.source = ( - torch.ones(self.max_seq_length, dtype=torch.int64) * self.padding_idx + torch.ones(self.max_length, dtype=torch.int64) * self.padding_idx ) # reset this to a lower value self.min_reward = 1e-20 @@ -88,17 +85,17 @@ def __init__( def get_action_space(self): """ Constructs list with all possible actions - If min_word_len = n_alphabet = 2, actions: [(0, 0,), (1, 1)] and so on + If min_word_length = n_alphabet = 2, actions: [(0, 0,), (1, 1)] and so on """ - assert self.max_word_len >= self.min_word_len - valid_wordlens = np.arange(self.min_word_len, self.max_word_len + 1) + assert self.max_word_length >= self.min_word_length + valid_wordlens = np.arange(self.min_word_length, self.max_word_length + 1) alphabet = [a for a in range(self.n_alphabet)] actions = [] for r in valid_wordlens: actions_r = [el for el in itertools.product(alphabet, repeat=r)] actions += actions_r # Add "eos" action - # eos != n_alphabet in the init because it would break if max_word_len >1 + # eos != n_alphabet in the init because it would break if max_word_length >1 actions = actions + [(len(actions),)] self.eos = len(actions) - 1 return actions @@ -124,10 +121,10 @@ def get_mask_invalid_actions_forward(self, state=None, done=None): if state[-1] == self.padding_idx else len(state) ) - if seq_length < self.min_seq_length: + if seq_length < self.min_length: mask[self.eos] = True for idx, a in enumerate(self.action_space[:-1]): - if seq_length + len(list(a)) > self.max_seq_length: + if seq_length + len(list(a)) > self.max_length: mask[idx] = True return mask @@ -145,11 +142,11 @@ def true_density(self, max_states=1e6): """ if self._true_density is not None: return self._true_density - if self.n_alphabet**self.max_seq_length > max_states: + if self.n_alphabet**self.max_length > max_states: return (None, None, None) state_all = np.int32( list( - itertools.product(*[list(range(self.n_alphabet))] * self.max_seq_length) + itertools.product(*[list(range(self.n_alphabet))] * self.max_length) ) ) traj_rewards, state_end = zip( @@ -173,10 +170,10 @@ def true_density(self, max_states=1e6): def get_max_traj_length( self, ): - return self.max_seq_length / self.min_word_len + 1 + return self.max_length / self.min_word_length + 1 def statebatch2oracle( - self, states: List[TensorType["max_seq_length"]] + self, states: List[TensorType["max_length"]] ) -> List[str]: state_oracle = [] for state in states: @@ -189,7 +186,7 @@ def statebatch2oracle( return state_oracle def statetorch2oracle( - self, states: TensorType["batch_dim", "max_seq_length"] + self, states: TensorType["batch_dim", "max_length"] ) -> List[str]: return self.statebatch2oracle(states) @@ -197,7 +194,7 @@ def statetorch2oracle( def state2policy(self, state=None): """ Transforms the sequence (state) given as argument (or self.state if None) into a - one-hot encoding. The output is a list of length nalphabet * max_seq_length, + one-hot encoding. The output is a list of length nalphabet * max_length, where each n-th successive block of nalphabet elements is a one-hot encoding of the letter in the n-th position. @@ -208,7 +205,7 @@ def state2policy(self, state=None): - state2obs(state): [1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0] | A | T | G | C | - If max_seq_length > len(state), the last (max_seq_length - len(state)) blocks are all + If max_length > len(state), the last (max_length - len(state)) blocks are all 0s. """ if state is None: @@ -218,7 +215,7 @@ def state2policy(self, state=None): if state[-1] == self.padding_idx else state ) - state_policy = torch.zeros(1, self.max_seq_length, self.n_alphabet) + state_policy = torch.zeros(1, self.max_length, self.n_alphabet) if len(state) == 0: return state_policy.reshape(1, -1) state_onehot = F.one_hot(state, num_classes=self.n_alphabet + 1)[:, :, 1:].to( @@ -228,7 +225,7 @@ def state2policy(self, state=None): return state_policy.reshape(state.shape[0], -1) def statebatch2policy( - self, states: List[TensorType["1", "max_seq_length"]] + self, states: List[TensorType["1", "max_length"]] ) -> TensorType["batch", "policy_input_dim"]: """ Transforms a batch of states into the policy model format. The output is a numpy @@ -241,7 +238,7 @@ def statebatch2policy( return state_policy def statetorch2policy( - self, states: TensorType["batch", "max_seq_length"] + self, states: TensorType["batch", "max_length"] ) -> TensorType["batch", "policy_input_dim"]: if states.dtype != torch.long: states = states.to(torch.long) @@ -256,7 +253,7 @@ def statetorch2policy( assert torch.eq(state_onehot_pad, state_onehot).all() state_policy = torch.zeros( states.shape[0], - self.max_seq_length, + self.max_length, self.n_alphabet, device=self.device, dtype=self.float, @@ -277,7 +274,7 @@ def policytorch2state(self, state_policy: List) -> List: A, A, T, G, C """ mat_state_policy = torch.reshape( - state_policy, (self.max_seq_length, self.n_alphabet) + state_policy, (self.max_length, self.n_alphabet) ) state = torch.where(mat_state_policy)[1].tolist() return state @@ -296,7 +293,7 @@ def policy2state(self, state_policy: List) -> List: A, A, T, G, C """ mat_state_policy = np.reshape( - state_policy, (self.max_seq_length, self.n_alphabet) + state_policy, (self.max_length, self.n_alphabet) ) state = np.where(mat_state_policy)[1].tolist() return state @@ -305,7 +302,7 @@ def state2oracle(self, state: List = None): return "".join(self.state2readable(state)) def statebatch2oracle( - self, states: List[TensorType["max_seq_length"]] + self, states: List[TensorType["max_length"]] ) -> List[str]: state_oracle = [] for state in states: @@ -318,7 +315,7 @@ def statebatch2oracle( return state_oracle def statetorch2oracle( - self, states: TensorType["batch_dim", "max_seq_length"] + self, states: TensorType["batch_dim", "max_length"] ) -> List[str]: return self.statebatch2oracle(states) @@ -335,7 +332,7 @@ def state2readable(self, state: List) -> str: state = state.tolist() return "".join([self.inverse_lookup[el] for el in state]) - def statetorch2readable(self, state: TensorType["1", "max_seq_length"]) -> str: + def statetorch2readable(self, state: TensorType["1", "max_length"]) -> str: if state[-1] == self.padding_idx: state = state[: torch.where(state == self.padding_idx)[0][0]] # TODO: neater way without having lookup as input arg @@ -349,20 +346,20 @@ def statetorch2readable(self, state: TensorType["1", "max_seq_length"]) -> str: readable = [self.inverse_lookup[el] for el in state] return "".join(readable) - def readable2state(self, readable) -> TensorType["batch_dim", "max_seq_length"]: + def readable2state(self, readable) -> TensorType["batch_dim", "max_length"]: """ Transforms a list or string of letters into a list of indices according to an alphabet. """ if isinstance(readable, str): encoded_readable = [self.lookup[el] for el in readable] state = ( - torch.ones(self.max_seq_length, dtype=torch.int64) * self.padding_idx + torch.ones(self.max_length, dtype=torch.int64) * self.padding_idx ) state[: len(encoded_readable)] = torch.tensor(encoded_readable) else: encoded_readable = [[self.lookup[el] for el in seq] for seq in readable] state = ( - torch.ones((len(readable), self.max_seq_length), dtype=torch.int64) + torch.ones((len(readable), self.max_length), dtype=torch.int64) * self.padding_idx ) for i, seq in enumerate(encoded_readable): @@ -376,7 +373,7 @@ def get_parents(self, state=None, done=None, action=None): Args ---- state : list - Representation of a sequence (state), as a list of length max_seq_length + Representation of a sequence (state), as a list of length max_length where each element is the index of a letter in the alphabet, from 0 to (nalphabet - 1). @@ -407,7 +404,7 @@ def get_parents(self, state=None, done=None, action=None): state_last_element = int(torch.where(state == self.padding_idx)[0][0]) else: state_last_element = len(state) - max_parent_action_length = self.max_word_len + 1 - self.min_word_len + max_parent_action_length = self.max_word_length + 1 - self.min_word_length for parent_action_length in range(1, max_parent_action_length + 1): parent_action = tuple(state[state_last_element - parent_action_length : state_last_element].numpy()) if parent_action in self.action_space: @@ -450,7 +447,7 @@ def step(self, action: Tuple[int]) -> Tuple[List[int], Tuple[int, int], bool]: state_last_element = int(torch.where(self.state == self.padding_idx)[0][0]) if action[0] != self.eos: state_next = self.state.clone().detach() - if state_last_element + len(action) > self.max_seq_length: + if state_last_element + len(action) > self.max_length: valid = False else: state_next[ @@ -461,7 +458,7 @@ def step(self, action: Tuple[int]) -> Tuple[List[int], Tuple[int, int], bool]: self.n_actions += 1 return self.state, action, valid else: - if state_last_element < self.min_seq_length: + if state_last_element < self.min_length: valid = False else: self.done = True From beed5df6b33f007f50481172412a12e2b7a7f028 Mon Sep 17 00:00:00 2001 From: Alex Date: Tue, 17 Oct 2023 22:19:47 -0400 Subject: [PATCH 08/54] Replace original README with info about the new private repository. --- README.md | 42 +++++++----------------------------------- 1 file changed, 7 insertions(+), 35 deletions(-) diff --git a/README.md b/README.md index 796f8dc3b..61d306145 100644 --- a/README.md +++ b/README.md @@ -1,39 +1,11 @@ -# GFlowNet +# Private sister repository of gflownet -This repository implements GFlowNets, generative flow networks for probabilistic modelling, on PyTorch. A design guideline behind this implementation is the separation of the logic of the GFlowNet agent and the environments on which the agent can be trained on. In other words, this implementation should allow its extension with new environments without major or any changes to to the agent. Another design guideline is flexibility and modularity. The configuration is handled via the use of [Hydra](https://hydra.cc/docs/intro/). +This repository (`gflownet-dev`) is private. It is meant to be used to develop research ideas and projects before making them public in the original [alexhernandezgarcia/gflownet](https://github.com/alexhernandezgarcia/gflownet) repository (`gflownet`). -## Installation +As of October 2023, it is uncertain whether we will stick to this plan in the long term, but the idea is the following: -### pip +- Develop ideas and projects in `gflownet-dev`. +- Upon publication or whenever the authors feel comfortable, transfer the relevant code to `gflownet`. +- Relevant code improvements and development that does not compromise research projects should be transferred to `gflownet` as early as possible. -```bash -python -m pip install --upgrade https://github.com/alexhernandezgarcia/gflownet/archive/main.zip -``` - -## How to train a GFlowNet model - -To train a GFlowNet model with the default configuration, simply run - -```bash -python main.py user.logdir.root= -``` - -Alternatively, you can create a user configuration file in `config/user/.yaml` specifying a `logdir.root` and run - -```bash -python main.py user= -``` - -Using Hydra, you can easily specify any variable of the configuration in the command line. For example, to train GFlowNet with the trajectory balance loss, on the continuous torus (`ctorus`) environment and the corresponding proxy: - -```bash -python main.py gflownet=trajectorybalance env=ctorus proxy=torus -``` - -The above command will overwrite the `env` and `proxy` default configuration with the configuration files in `config/env/ctorus.yaml` and `config/proxy/torus.yaml` respectively. - -Hydra configuration is hierarchical. For instance, a handy variable to change while debugging our code is to avoid logging to wandb. You can do this by setting `logger.do.online=False`. - -## Logging to wandb - -The repository supports logging of train and evaluation metrics to [wandb.ai](https://wandb.ai), but it is disabled by default. In order to enable it, set the configuration variable `logger.do.online` to `True`. +This involves extra complexity, so we will re-evaluate or refine this plan after a test period. From 9a4829df4c29389f3b19b43978ef3126ace08a9f Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Thu, 26 Oct 2023 21:29:43 -0400 Subject: [PATCH 09/54] WIP: few changes to sequence base --- gflownet/envs/seqs/sequence.py | 73 +++++++++++++++------------------- 1 file changed, 32 insertions(+), 41 deletions(-) diff --git a/gflownet/envs/seqs/sequence.py b/gflownet/envs/seqs/sequence.py index 4bd012053..8d93f7435 100644 --- a/gflownet/envs/seqs/sequence.py +++ b/gflownet/envs/seqs/sequence.py @@ -4,33 +4,30 @@ sequence environment (not-implemented as of July 2023) would be a "mutation-based" modification of the sequences, or a combination of mutations and additions. """ -from typing import List, Tuple import itertools +from typing import Iterable, List, Tuple + +import matplotlib.pyplot as plt import numpy as np -from gflownet.envs.base import GFlowNetEnv -import itertools -from polyleven import levenshtein import numpy.typing as npt -from torchtyping import TensorType import torch -import matplotlib.pyplot as plt import torch.nn.functional as F +from polyleven import levenshtein +from torchtyping import TensorType + +from gflownet.envs.base import GFlowNetEnv class Sequence(GFlowNetEnv): """ Parent of sequence environments. By default, for illustration purposes, this parent - class is functional and represents binary sequences of 0s and 1s. + class is functional and represents binary sequences of 0s and 1s, that can be + padded with the special token [PAD] and are terminated by the special token [EOS]. Attributes ---------- - dictionary : dict - A dictionary containing (key: value) pairs where the value is a string - representing the token and the key is an arbitrary integer uniquely identifying - the token. The dictionary should also include (key: value) pairs for the - special tokens: - - End of sequence token (EOS) - - Padding token (PAD) + tokens : iterable + An iterable containing the vocabulary of tokens that make the sequences. max_length : int Maximum length of the sequences. @@ -47,31 +44,31 @@ class is functional and represents binary sequences of 0s and 1s. def __init__( self, - dictionary={0: "0", 1: "1", -1: "[PAD]", -2: "[EOS]"}, - max_length=10, - min_length=1, - max_word_length=1, - min_word_length=1, - special_tokens=None, + tokens: Iterable = [0, 1], + max_length: int = 10, + min_length: int = 1, + max_word_length: int = 1, + min_word_length: int = 1, **kwargs, ): + assert min_length > 0 + assert max_length > 0 + assert min_word_length > 0 + assert max_word_length > 0 # Main attributes - self.dictionary= dictionary + self.tokens = tokens self.min_length = min_length self.max_length = max_length self.min_word_length = min_word_length self.max_word_length = max_word_length - self.lookup = {a: i for (i, a) in enumerate(self.vocab)} self.inverse_lookup = {i: a for (i, a) in enumerate(self.vocab)} self.n_alphabet = len(self.vocab) - len(special_tokens) self.padding_idx = self.lookup["[PAD]"] # TODO: eos re-initalised in get_actions_space so why was this initialisation required in the first place (maybe mfenv) self.eos = self.lookup["[EOS]"] - self.source = ( - torch.ones(self.max_length, dtype=torch.int64) * self.padding_idx - ) + self.source = torch.ones(self.max_length, dtype=torch.int64) * self.padding_idx # reset this to a lower value self.min_reward = 1e-20 # if proxy is not None: @@ -145,9 +142,7 @@ def true_density(self, max_states=1e6): if self.n_alphabet**self.max_length > max_states: return (None, None, None) state_all = np.int32( - list( - itertools.product(*[list(range(self.n_alphabet))] * self.max_length) - ) + list(itertools.product(*[list(range(self.n_alphabet))] * self.max_length)) ) traj_rewards, state_end = zip( *[ @@ -172,9 +167,7 @@ def get_max_traj_length( ): return self.max_length / self.min_word_length + 1 - def statebatch2oracle( - self, states: List[TensorType["max_length"]] - ) -> List[str]: + def statebatch2oracle(self, states: List[TensorType["max_length"]]) -> List[str]: state_oracle = [] for state in states: if state[-1] == self.padding_idx: @@ -292,18 +285,14 @@ def policy2state(self, state_policy: List) -> List: - policy2state(state_policy): [0, 0, 1, 3, 2] A, A, T, G, C """ - mat_state_policy = np.reshape( - state_policy, (self.max_length, self.n_alphabet) - ) + mat_state_policy = np.reshape(state_policy, (self.max_length, self.n_alphabet)) state = np.where(mat_state_policy)[1].tolist() return state def state2oracle(self, state: List = None): return "".join(self.state2readable(state)) - def statebatch2oracle( - self, states: List[TensorType["max_length"]] - ) -> List[str]: + def statebatch2oracle(self, states: List[TensorType["max_length"]]) -> List[str]: state_oracle = [] for state in states: if state[-1] == self.padding_idx: @@ -352,9 +341,7 @@ def readable2state(self, readable) -> TensorType["batch_dim", "max_length"]: """ if isinstance(readable, str): encoded_readable = [self.lookup[el] for el in readable] - state = ( - torch.ones(self.max_length, dtype=torch.int64) * self.padding_idx - ) + state = torch.ones(self.max_length, dtype=torch.int64) * self.padding_idx state[: len(encoded_readable)] = torch.tensor(encoded_readable) else: encoded_readable = [[self.lookup[el] for el in seq] for seq in readable] @@ -406,7 +393,11 @@ def get_parents(self, state=None, done=None, action=None): state_last_element = len(state) max_parent_action_length = self.max_word_length + 1 - self.min_word_length for parent_action_length in range(1, max_parent_action_length + 1): - parent_action = tuple(state[state_last_element - parent_action_length : state_last_element].numpy()) + parent_action = tuple( + state[ + state_last_element - parent_action_length : state_last_element + ].numpy() + ) if parent_action in self.action_space: parent = state.clone().detach() parent[ From b4919d6abc7cab2fda62b83b7bc8dc3464f9c753 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Thu, 26 Oct 2023 23:02:42 -0400 Subject: [PATCH 10/54] Finished __init__ for now. --- gflownet/envs/seqs/sequence.py | 43 ++++++++++++++++++++-------------- 1 file changed, 25 insertions(+), 18 deletions(-) diff --git a/gflownet/envs/seqs/sequence.py b/gflownet/envs/seqs/sequence.py index 8d93f7435..8ca72bc92 100644 --- a/gflownet/envs/seqs/sequence.py +++ b/gflownet/envs/seqs/sequence.py @@ -40,6 +40,12 @@ class is functional and represents binary sequences of 0s and 1s, that can be min_word_length : int Minimum number of tokens allowed per action. + + eos_token : int, str + EOS token. Default: -1. + + pad_token : int, str + PAD token. Default: -2. """ def __init__( @@ -49,42 +55,43 @@ def __init__( min_length: int = 1, max_word_length: int = 1, min_word_length: int = 1, + eos_token: Union[int, str] = -1, + pad_token: Union[int, str] = -2, **kwargs, ): assert min_length > 0 assert max_length > 0 + assert max_length >= min_length assert min_word_length > 0 assert max_word_length > 0 + assert max_word_length >= min_word_length # Main attributes - self.tokens = tokens + self.tokens = set(tokens) self.min_length = min_length self.max_length = max_length self.min_word_length = min_word_length self.max_word_length = max_word_length - - self.lookup = {a: i for (i, a) in enumerate(self.vocab)} - self.inverse_lookup = {i: a for (i, a) in enumerate(self.vocab)} - self.n_alphabet = len(self.vocab) - len(special_tokens) - self.padding_idx = self.lookup["[PAD]"] - # TODO: eos re-initalised in get_actions_space so why was this initialisation required in the first place (maybe mfenv) - self.eos = self.lookup["[EOS]"] - self.source = torch.ones(self.max_length, dtype=torch.int64) * self.padding_idx - # reset this to a lower value - self.min_reward = 1e-20 - # if proxy is not None: - # self.proxy = proxy - super().__init__( - **kwargs, + self.eos_idx = -1 + self.pad_idx = -2 + # Dictionaries + self.idx2token = {idx: token for idx, token in enumerate(self.tokens)} + self.idx2token[self.eos_idx] = eos_token + self.idx2token[self.pad_idx] = pad_token + self.token2idx = {token: idx for idx, token in self.idx2token.items()} + # Source state: vector of length max_length filled with pad token + self.source = torch.full( + self.max_length, self.pad_idx, dtype=torch.long, device=self.device ) - self.policy_input_dim = self.state2policy().shape[-1] - self.tokenizer = None + # End-of-sequence action + self.eos = (self.eos_idx,) + (self.pad_idx,) * (self.max_word_length - 1) + # Base class init + super().__init__(**kwargs) def get_action_space(self): """ Constructs list with all possible actions If min_word_length = n_alphabet = 2, actions: [(0, 0,), (1, 1)] and so on """ - assert self.max_word_length >= self.min_word_length valid_wordlens = np.arange(self.min_word_length, self.max_word_length + 1) alphabet = [a for a in range(self.n_alphabet)] actions = [] From 8ce60ab8271f60d6190ecf1e16d7c46afddc83af Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Thu, 26 Oct 2023 23:11:41 -0400 Subject: [PATCH 11/54] Docstring of get_action_space. --- gflownet/envs/seqs/sequence.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/gflownet/envs/seqs/sequence.py b/gflownet/envs/seqs/sequence.py index 8ca72bc92..23c23e57c 100644 --- a/gflownet/envs/seqs/sequence.py +++ b/gflownet/envs/seqs/sequence.py @@ -89,8 +89,19 @@ def __init__( def get_action_space(self): """ - Constructs list with all possible actions - If min_word_length = n_alphabet = 2, actions: [(0, 0,), (1, 1)] and so on + Constructs list with all possible actions, including eos. + + An action is represented by a vector of length max_word_length where each + element indicates the idex of the token to add to the sequence. Actions with a + number of tokens smaller than max_word_length are padded with pad_idx. + + Examples: + If min_word_length = 1 and max_word_length = 1: + actions: [(0,), (1,), (-1,)] + If min_word_length = 2 and max_word_length = 2: + actions: [(0, 0,), (0, 1), (1, 0), (1, 1), (-1, -2)] + If min_word_length = 1 and max_word_length = 2: + actions: [(0, -2), (1, -2), (0, 0,), (0, 1), (1, 0), (1, 1), (-1, -2)] """ valid_wordlens = np.arange(self.min_word_length, self.max_word_length + 1) alphabet = [a for a in range(self.n_alphabet)] From 8b0b6218ce1a30148a2225e34c5342e7b05bae8b Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Sun, 29 Oct 2023 18:33:36 -0400 Subject: [PATCH 12/54] Remove env.oracle and leave proxy only --- gflownet/envs/base.py | 52 +++++++++------------------------------- gflownet/utils/buffer.py | 2 +- main.py | 2 +- 3 files changed, 13 insertions(+), 43 deletions(-) diff --git a/gflownet/envs/base.py b/gflownet/envs/base.py index 9b5e81f3a..572935064 100644 --- a/gflownet/envs/base.py +++ b/gflownet/envs/base.py @@ -38,7 +38,6 @@ def __init__( energies_stats: List[int] = None, denorm_proxy: bool = False, proxy=None, - oracle=None, proxy_state_format: str = "oracle", fixed_distr_params: Optional[dict] = None, random_distr_params: Optional[dict] = None, @@ -68,17 +67,10 @@ def __init__( self.reward_func = reward_func self.energies_stats = energies_stats self.denorm_proxy = denorm_proxy - # Proxy and oracle + # Proxy self.proxy = proxy self.setup_proxy() - if oracle is None: - self.oracle = self.proxy - else: - self.oracle = oracle - if self.oracle is None or self.oracle.higher_is_better: - self.proxy_factor = 1.0 - else: - self.proxy_factor = -1.0 + self.proxy_factor = -1.0 self.proxy_state_format = proxy_state_format # Flag to skip checking if action is valid (computing mask) before step self.skip_mask_check = skip_mask_check @@ -99,9 +91,6 @@ def __init__( self.random_policy_output = self.get_policy_output(self.random_distr_params) self.policy_output_dim = len(self.fixed_policy_output) self.policy_input_dim = len(self.state2policy()) - if proxy is not None and self.proxy == self.oracle: - self.statebatch2proxy = self.statebatch2oracle - self.statetorch2proxy = self.statetorch2oracle @abstractmethod def get_action_space(self): @@ -706,25 +695,6 @@ def statetorch2proxy( """ return states - def state2oracle(self, state: List = None): - """ - Prepares a state in "GFlowNet format" for the oracle. - - Args - ---- - state : list - A state - """ - if state is None: - state = self.state.copy() - return state - - def statebatch2oracle(self, states: List[List]): - """ - Prepares a batch of states in "GFlowNet format" for the oracles - """ - return states - def statetorch2policy( self, states: TensorType["batch_size", "state_dim"] ) -> TensorType["batch_size", "policy_input_dim"]: @@ -824,9 +794,9 @@ def reward_torchbatch( def proxy2reward(self, proxy_vals): """ - Prepares the output of an oracle for GFlowNet: the inputs proxy_vals is - expected to be a negative value (energy), unless self.denorm_proxy is True. If - the latter, the proxy values are first de-normalized according to the mean and + Prepares the output of a proxy for GFlowNet: the inputs proxy_vals is expected + to be a negative value (energy), unless self.denorm_proxy is True. If the + latter, the proxy values are first de-normalized according to the mean and standard deviation in self.energies_stats. The output of the function is a strictly positive reward - provided self.reward_norm and self.reward_beta are positive - and larger than self.min_reward. @@ -869,7 +839,7 @@ def proxy2reward(self, proxy_vals): def reward2proxy(self, reward): """ Converts a "GFlowNet reward" into a (negative) energy or values as returned by - an oracle. + a proxy. """ if self.reward_func == "power": return self.proxy_factor * torch.exp( @@ -1350,7 +1320,7 @@ def top_k_metrics_and_plots( return metrics, figs, fig_names def plot_reward_distribution( - self, states=None, scores=None, ax=None, title=None, oracle=None, **kwargs + self, states=None, scores=None, ax=None, title=None, proxy=None, **kwargs ): if ax is None: fig, ax = plt.subplots() @@ -1359,15 +1329,15 @@ def plot_reward_distribution( standalone = False if title == None: title = "Scores of Sampled States" - if oracle is None: - oracle = self.oracle + if proxy is None: + proxy = self.proxy if scores is None: if isinstance(states[0], torch.Tensor): states = torch.vstack(states).to(self.device, self.float) if isinstance(states, torch.Tensor) == False: states = torch.tensor(states, device=self.device, dtype=self.float) - oracle_states = self.statetorch2oracle(states) - scores = oracle(oracle_states) + states_proxy = self.statetorch2proxy(states) + scores = self.proxy(states_proxy) if isinstance(scores, TensorType): scores = scores.cpu().detach().numpy() ax.hist(scores) diff --git a/gflownet/utils/buffer.py b/gflownet/utils/buffer.py index c66f4d8d3..c7fd61ade 100644 --- a/gflownet/utils/buffer.py +++ b/gflownet/utils/buffer.py @@ -230,7 +230,7 @@ def make_data_set(self, config): samples = self.env.get_random_terminating_states(config.n) else: return None, None - energies = self.env.oracle(self.env.statebatch2oracle(samples)).tolist() + energies = self.env.proxy(self.env.statebatch2proxy(samples)).tolist() df = pd.DataFrame( { "samples": [self.env.state2readable(s) for s in samples], diff --git a/main.py b/main.py index 127e59001..62a5d064a 100644 --- a/main.py +++ b/main.py @@ -77,7 +77,7 @@ def main(config): if config.n_samples > 0 and config.n_samples <= 1e5: batch, times = gflownet.sample_batch(n_forward=config.n_samples, train=False) x_sampled = batch.get_terminating_states(proxy=True) - energies = env.oracle(x_sampled) + energies = env.proxy(x_sampled) x_sampled = batch.get_terminating_states() df = pd.DataFrame( { From 0088ee4ca383316e21ed9e615c56602f45b4ebb5 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Sun, 29 Oct 2023 23:03:29 -0400 Subject: [PATCH 13/54] Grid: states2proxy, states2policy; temporary state because old code is still there. --- gflownet/envs/grid.py | 69 ++++++++++++++++++++++++++------ tests/gflownet/envs/test_grid.py | 12 +++--- 2 files changed, 63 insertions(+), 18 deletions(-) diff --git a/gflownet/envs/grid.py b/gflownet/envs/grid.py index 46e3639ef..6a7727c16 100644 --- a/gflownet/envs/grid.py +++ b/gflownet/envs/grid.py @@ -2,7 +2,7 @@ Classes to represent a hyper-grid environments """ import itertools -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union import matplotlib.pyplot as plt import numpy as np @@ -12,6 +12,7 @@ from torchtyping import TensorType from gflownet.envs.base import GFlowNetEnv +from gflownet.utils.common import tfloat, tlong class Grid(GFlowNetEnv): @@ -81,9 +82,6 @@ def __init__( # TODO: assess if really needed if self.proxy_state_format == "ohe": self.statebatch2proxy = self.statebatch2policy - elif self.proxy_state_format == "oracle": - self.statebatch2proxy = self.statebatch2oracle - self.statetorch2proxy = self.statetorch2oracle def get_action_space(self): """ @@ -127,7 +125,7 @@ def get_mask_invalid_actions_forward( mask[idx] = True return mask - def state2oracle(self, state: List = None) -> List: + def state2proxy(self, state: List = None) -> List: """ Prepares a state in "GFlowNet format" for the oracles: a list of length n_dim with values in the range [cell_min, cell_max] for each state. @@ -150,33 +148,52 @@ def state2oracle(self, state: List = None) -> List: .tolist() ) - def statebatch2oracle( + def states2proxy( + self, states: Union[List[List], TensorType["batch", "state_dim"]] + ) -> TensorType["batch", "state_proxy_dim"]: + """ + Prepares a batch of states in "GFlowNet format" for the proxy: each state is + a vector of length n_dim with values in the range [cell_min, cell_max]. + + See: statetorch2policy() + """ + states = tfloat(states, device=self.device, float_type=self.float) + return ( + self.statetorch2policy(states).reshape( + (states.shape[0], self.n_dim, self.length) + ) + * torch.tensor(self.cells[None, :]).to(states.device, self.float) + ).sum(axis=2) + + def statebatch2proxy( self, states: List[List] - ) -> TensorType["batch", "state_oracle_dim"]: + ) -> TensorType["batch", "state_proxy_dim"]: """ Prepares a batch of states in "GFlowNet format" for the oracles: each state is a vector of length n_dim with values in the range [cell_min, cell_max]. - See: statetorch2oracle() + See: statetorch2proxy() Args ---- state : list State """ - return self.statetorch2oracle( - torch.tensor(states, device=self.device, dtype=self.float) + return self.states2proxy(states) + return self.statetorch2proxy( + tfloat(states, device=self.device, float_type=self.float) ) - def statetorch2oracle( + def statetorch2proxy( self, states: TensorType["batch", "state_dim"] - ) -> TensorType["batch", "state_oracle_dim"]: + ) -> TensorType["batch", "state_proxy_dim"]: """ Prepares a batch of states in "GFlowNet format" for the oracles: each state is a vector of length n_dim with values in the range [cell_min, cell_max]. See: statetorch2policy() """ + return self.states2proxy(states) return ( self.statetorch2policy(states).reshape( (len(states), self.n_dim, self.length) @@ -202,6 +219,32 @@ def state2policy(self, state: List = None) -> List: state_policy[(np.arange(len(state)) * self.length + state)] = 1 return state_policy.tolist() + def states2policy( + self, states: Union[List, TensorType["batch", "state_dim"]] + ) -> TensorType["batch", "policy_output_dim"]: + """ + Prepares a batch of states in "GFlowNet format" for the policy model: states + are one-hot encoded. + + The output is a 2D tensor, with the second dimension of size length * n_dim, + where each n-th successive block of length elements is a one-hot encoding of + the position in the n-th dimension. + + Example (n_dim = 3, length = 4): + - state: [0, 3, 1] + - policy format: [1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0] + | 0 | 3 | 1 | + """ + states = tlong(states, device=self.device) + n_states = states.shape[0] + cols = states + torch.arange(self.n_dim) * self.length + rows = torch.repeat_interleave(torch.arange(n_states), self.n_dim) + states_policy = torch.zeros( + (n_states, self.length * self.n_dim), dtype=self.float, device=self.device + ) + states_policy[rows, cols.flatten()] = 1.0 + return states_policy + def statebatch2policy(self, states: List[List]) -> npt.NDArray[np.float32]: """ Transforms a batch of states into a one-hot encoding. The output is a numpy @@ -209,6 +252,7 @@ def statebatch2policy(self, states: List[List]) -> npt.NDArray[np.float32]: See state2policy(). """ + return self.states2policy(states) cols = np.array(states) + np.arange(self.n_dim) * self.length rows = np.repeat(np.arange(len(states)), self.n_dim) state_policy = np.zeros( @@ -226,6 +270,7 @@ def statetorch2policy( See state2policy(). """ + return self.states2policy(states) device = states.device cols = (states + torch.arange(self.n_dim).to(device) * self.length).to(int) rows = torch.repeat_interleave( diff --git a/tests/gflownet/envs/test_grid.py b/tests/gflownet/envs/test_grid.py index 2bf3ea4bf..747685554 100644 --- a/tests/gflownet/envs/test_grid.py +++ b/tests/gflownet/envs/test_grid.py @@ -45,7 +45,7 @@ def config_path(): @pytest.mark.parametrize( - "state, state2oracle", + "state, state2proxy", [ ( [0, 0, 0], @@ -65,12 +65,12 @@ def config_path(): ), ], ) -def test__state2oracle__returns_expected(env, state, state2oracle): - assert state2oracle == env.state2oracle(state) +def test__state2proxy__returns_expected(env, state, state2proxy): + assert state2proxy == env.state2proxy(state) @pytest.mark.parametrize( - "states, statebatch2oracle", + "states, statebatch2proxy", [ ( [[0, 0, 0], [4, 4, 4], [1, 2, 3], [4, 0, 1]], @@ -78,8 +78,8 @@ def test__state2oracle__returns_expected(env, state, state2oracle): ), ], ) -def test__statebatch2oracle__returns_expected(env, states, statebatch2oracle): - assert torch.equal(torch.Tensor(statebatch2oracle), env.statebatch2oracle(states)) +def test__statebatch2proxy__returns_expected(env, states, statebatch2proxy): + assert torch.equal(torch.Tensor(statebatch2proxy), env.statebatch2proxy(states)) @pytest.mark.parametrize( From c83d8b38ec78146070a296dcec407916dbdc2a23 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Sun, 29 Oct 2023 23:50:28 -0400 Subject: [PATCH 14/54] docstring --- gflownet/envs/grid.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/gflownet/envs/grid.py b/gflownet/envs/grid.py index 6a7727c16..a820a5647 100644 --- a/gflownet/envs/grid.py +++ b/gflownet/envs/grid.py @@ -156,6 +156,16 @@ def states2proxy( a vector of length n_dim with values in the range [cell_min, cell_max]. See: statetorch2policy() + + Args + ---- + states : list or tensor + A batch of states in environment format, either as a list of states or as a + single tensor. + + Returns + ------- + A tensor containing all the states in the batch. """ states = tfloat(states, device=self.device, float_type=self.float) return ( @@ -234,6 +244,16 @@ def states2policy( - state: [0, 3, 1] - policy format: [1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0] | 0 | 3 | 1 | + + Args + ---- + states : list or tensor + A batch of states in environment format, either as a list of states or as a + single tensor. + + Returns + ------- + A tensor containing all the states in the batch. """ states = tlong(states, device=self.device) n_states = states.shape[0] From 04bb59fb9e9496cf7d8ee7a802c1357f322fb7ff Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Sun, 29 Oct 2023 23:50:48 -0400 Subject: [PATCH 15/54] Tetris: states2proxy, states2policy; temporary state because old code is still there. --- gflownet/envs/tetris.py | 91 +++++++++++++++++++++++++++++++---------- 1 file changed, 70 insertions(+), 21 deletions(-) diff --git a/gflownet/envs/tetris.py b/gflownet/envs/tetris.py index a80d447ff..e9ff4d542 100644 --- a/gflownet/envs/tetris.py +++ b/gflownet/envs/tetris.py @@ -4,7 +4,7 @@ import itertools import re import warnings -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union import numpy as np import numpy.typing as npt @@ -99,10 +99,6 @@ def __init__( ) # End-of-sequence action: all -1 self.eos = (-1, -1, -1) - # Conversions - self.state2proxy = self.state2oracle - self.statebatch2proxy = self.statebatch2oracle - self.statetorch2proxy = self.statetorch2oracle # Precompute all possible rotations of each piece and the corresponding binary # mask @@ -251,11 +247,12 @@ def get_mask_invalid_actions_forward( mask[-1] = True return mask - def state2oracle( + def state2proxy( self, state: Optional[TensorType["height", "width"]] = None ) -> TensorType["height", "width"]: """ - Prepares a state in "GFlowNet format" for the oracles: simply converts non-zero + Prepares a state in "environment format" for the oracles: simply converts + non-zero (non-empty) cells into 1s. Args @@ -264,13 +261,37 @@ def state2oracle( """ if state is None: state = self.state.clone().detach() - state_oracle = state.clone().detach() - state_oracle[state_oracle != 0] = 1 - return state_oracle + state_proxy = state.clone().detach() + state_proxy[state_proxy != 0] = 1 + return state_proxy - def statebatch2oracle( + def states2proxy( + self, + states: Union[ + List[TensorType["height", "width"]], TensorType["height", "width", "batch"] + ], + ) -> TensorType["height", "width", "batch"]: + """ + Prepares a batch of states in "environment format" for a proxy: : simply + converts non-zero (non-empty) cells into 1s. + + Args + ---- + states : list of 2D tensors or 3D tensor + A batch of states in environment format, either as a list of states or as a + single tensor. + + Returns + ------- + A tensor containing all the states in the batch. + """ + states = tint(states, device=self.device, int_type=self.int) + states[states != 0] = 1 + return states + + def statebatch2proxy( self, states: List[TensorType["height", "width"]] - ) -> TensorType["batch", "state_oracle_dim"]: + ) -> TensorType["batch", "state_proxy_dim"]: """ Prepares a batch of states in "GFlowNet format" for the oracles: simply converts non-zero (non-empty) cells into 1s. @@ -279,17 +300,19 @@ def statebatch2oracle( ---- state : list """ + return self.states2proxy(states) states = torch.stack(states) states[states != 0] = 1 return states - def statetorch2oracle( + def statetorch2proxy( self, states: TensorType["height", "width", "batch"] ) -> TensorType["height", "width", "batch"]: """ Prepares a batch of states in "GFlowNet format" for the oracles: : simply converts non-zero (non-empty) cells into 1s. """ + return self.states2proxy(states) states[states != 0] = 1 return states @@ -299,19 +322,44 @@ def state2policy( """ Prepares a state in "GFlowNet format" for the policy model. - See: state2oracle() + See: state2proxy() + """ + return self.state2proxy(state).flatten() + + def states2policy( + self, + states: Union[ + List[TensorType["height", "width"]], TensorType["height", "width", "batch"] + ], + ) -> TensorType["height", "width", "batch"]: + """ + Prepares a batch of states in "environment format" for the policy model. + + See statetorch2proxy(). + + Args + ---- + states : list of 2D tensors or 3D tensor + A batch of states in environment format, either as a list of states or as a + single tensor. + + Returns + ------- + A tensor containing all the states in the batch. """ - return self.state2oracle(state).flatten() + states = tint(states, device=self.device, int_type=self.int) + return self.states2proxy(states).flatten(start_dim=1) def statebatch2policy( self, states: List[TensorType["height", "width"]] - ) -> TensorType["batch", "state_oracle_dim"]: + ) -> TensorType["batch", "state_proxy_dim"]: """ Prepares a batch of states in "GFlowNet format" for the policy model. - See statebatch2oracle(). + See statebatch2proxy(). """ - return self.statebatch2oracle(states).flatten(start_dim=1) + return self.states2policy(states) + return self.statebatch2proxy(states).flatten(start_dim=1) def statetorch2policy( self, states: TensorType["height", "width", "batch"] @@ -319,9 +367,10 @@ def statetorch2policy( """ Prepares a batch of states in "GFlowNet format" for the policy model. - See statetorch2oracle(). + See statetorch2proxy(). """ - return self.statetorch2oracle(states).flatten(start_dim=1) + return self.states2policy(states) + return self.statetorch2proxy(states).flatten(start_dim=1) def policy2state( self, policy: Optional[TensorType["height", "width"]] = None @@ -329,7 +378,7 @@ def policy2state( """ Returns None to signal that the conversion is not reversible. - See: state2oracle() + See: state2proxy() """ return None From 41962f955a3ef7c7d0dda83cd5ffa42544c439cc Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Sun, 29 Oct 2023 23:51:48 -0400 Subject: [PATCH 16/54] Refactoring of tfloat, tlogn, tint and tbool --- gflownet/utils/common.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/gflownet/utils/common.py b/gflownet/utils/common.py index afa751816..ab1c85d92 100644 --- a/gflownet/utils/common.py +++ b/gflownet/utils/common.py @@ -167,36 +167,36 @@ def batch_with_rest(start, stop, step, tensor=False): def tfloat(x, device, float_type): if isinstance(x, list) and torch.is_tensor(x[0]): - return torch.stack(x).type(float_type).to(device) + return torch.stack(x).to(device=device, dtype=float_type) if torch.is_tensor(x): - return x.type(float_type).to(device) + return x.to(device=device, dtype=float_type) else: return torch.tensor(x, dtype=float_type, device=device) def tlong(x, device): if isinstance(x, list) and torch.is_tensor(x[0]): - return torch.stack(x).type(torch.long).to(device) + return torch.stack(x).to(device=device, dtype=torch.long) if torch.is_tensor(x): - return x.type(torch.long).to(device) + return x.to(device=device, dtype=torch.long) else: return torch.tensor(x, dtype=torch.long, device=device) def tint(x, device, int_type): if isinstance(x, list) and torch.is_tensor(x[0]): - return torch.stack(x).type(int_type).to(device) + return torch.stack(x).to(device=device, dtype=int_type) if torch.is_tensor(x): - return x.type(int_type).to(device) + return x.to(device=device, dtype=int_type) else: return torch.tensor(x, dtype=int_type, device=device) def tbool(x, device): if isinstance(x, list) and torch.is_tensor(x[0]): - return torch.stack(x).type(torch.bool).to(device) + return torch.stack(x).to(device=device, dtype=torch.bool) if torch.is_tensor(x): - return x.type(torch.bool).to(device) + return x.to(device=device, dtype=torch.bool) else: return torch.tensor(x, dtype=torch.bool, device=device) From e446262dbf9c4f170be90178836ce5d0c40f2c62 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Sun, 29 Oct 2023 23:58:24 -0400 Subject: [PATCH 17/54] Typo in typing and correct docstrings --- gflownet/envs/grid.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/gflownet/envs/grid.py b/gflownet/envs/grid.py index a820a5647..5bfd60948 100644 --- a/gflownet/envs/grid.py +++ b/gflownet/envs/grid.py @@ -152,7 +152,7 @@ def states2proxy( self, states: Union[List[List], TensorType["batch", "state_dim"]] ) -> TensorType["batch", "state_proxy_dim"]: """ - Prepares a batch of states in "GFlowNet format" for the proxy: each state is + Prepares a batch of states in "environment format" for the proxy: each state is a vector of length n_dim with values in the range [cell_min, cell_max]. See: statetorch2policy() @@ -231,9 +231,9 @@ def state2policy(self, state: List = None) -> List: def states2policy( self, states: Union[List, TensorType["batch", "state_dim"]] - ) -> TensorType["batch", "policy_output_dim"]: + ) -> TensorType["batch", "policy_input_dim"]: """ - Prepares a batch of states in "GFlowNet format" for the policy model: states + Prepares a batch of states in "environment format" for the policy model: states are one-hot encoded. The output is a 2D tensor, with the second dimension of size length * n_dim, From be67b5c442a515a24bdd7237e5245613e331d1de Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Sun, 29 Oct 2023 23:58:44 -0400 Subject: [PATCH 18/54] Cube: states2policy; temporary state because old code is still there. --- gflownet/envs/cube.py | 31 +++++++++++++++++++++++++------ 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/gflownet/envs/cube.py b/gflownet/envs/cube.py index a66363286..7eb08f513 100644 --- a/gflownet/envs/cube.py +++ b/gflownet/envs/cube.py @@ -4,7 +4,7 @@ import itertools import warnings from abc import ABC, abstractmethod -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union import matplotlib.pyplot as plt import numpy as np @@ -92,14 +92,11 @@ def __init__( self.epsilon = epsilon # Small constant to restrict the interval of (test) sets self.kappa = kappa - # Conversions: only conversions to policy are implemented and the rest are the - # same + # Conversions: only conversions to policy are implemented and the conversion to + # proxy format is the same self.state2proxy = self.state2policy self.statebatch2proxy = self.statebatch2policy self.statetorch2proxy = self.statetorch2policy - self.state2oracle = self.state2proxy - self.statebatch2oracle = self.statebatch2proxy - self.statetorch2oracle = self.statetorch2proxy # Base class init super().__init__( fixed_distr_params=fixed_distr_params, @@ -128,6 +125,26 @@ def get_mask_invalid_actions_forward( def get_mask_invalid_actions_backward(self, state=None, done=None, parents_a=None): pass + def states2policy( + self, states: Union[List, TensorType["batch", "state_dim"]] + ) -> TensorType["batch", "policy_input_dim"]: + """ + Prepares a batch of states in "environment format" for the policy model: clips + the states into [0, 1] and maps them to [-1.0, 1.0] + + Args + ---- + states : list or tensor + A batch of states in environment format, either as a list of states or as a + single tensor. + + Returns + ------- + A tensor containing all the states in the batch. + """ + states = float(states, device=self.device, float_type=self.float) + return 2.0 * torch.clip(states, min=0.0, max=1.0) - 1.0 + def statetorch2policy( self, states: TensorType["batch", "state_dim"] = None ) -> TensorType["batch", "policy_input_dim"]: @@ -139,6 +156,7 @@ def statetorch2policy( state : list State """ + return self.states2policy(states) return 2.0 * torch.clip(states, min=0.0, max=1.0) - 1.0 def statebatch2policy( @@ -152,6 +170,7 @@ def statebatch2policy( state : list State """ + return self.states2policy(states) return self.statetorch2policy( tfloat(states, device=self.device, float_type=self.float) ) From 89ebdaf8bf653402438377d57b8474a92caed53b Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Mon, 30 Oct 2023 20:58:39 -0400 Subject: [PATCH 19/54] Fix typos --- gflownet/envs/grid.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gflownet/envs/grid.py b/gflownet/envs/grid.py index 5bfd60948..34ce110bc 100644 --- a/gflownet/envs/grid.py +++ b/gflownet/envs/grid.py @@ -155,7 +155,7 @@ def states2proxy( Prepares a batch of states in "environment format" for the proxy: each state is a vector of length n_dim with values in the range [cell_min, cell_max]. - See: statetorch2policy() + See: states2policy() Args ---- @@ -169,7 +169,7 @@ def states2proxy( """ states = tfloat(states, device=self.device, float_type=self.float) return ( - self.statetorch2policy(states).reshape( + self.states2policy(states).reshape( (states.shape[0], self.n_dim, self.length) ) * torch.tensor(self.cells[None, :]).to(states.device, self.float) From 43bf2ea91184a7ddff51f4c23e55d954e9423c24 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Mon, 30 Oct 2023 20:59:25 -0400 Subject: [PATCH 20/54] Continuous Tori: states2proxy, states2policy; temporary state because old code is still there. --- gflownet/envs/htorus.py | 68 ++++++++++++++++++++++++++++++++++++++--- 1 file changed, 63 insertions(+), 5 deletions(-) diff --git a/gflownet/envs/htorus.py b/gflownet/envs/htorus.py index 6a82dee1a..f14d53adb 100644 --- a/gflownet/envs/htorus.py +++ b/gflownet/envs/htorus.py @@ -4,7 +4,7 @@ import itertools import re from copy import deepcopy -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union import matplotlib.pyplot as plt import numpy as np @@ -73,10 +73,6 @@ def __init__( self.source = self.source_angles + [0] # End-of-sequence action: (n_dim, 0) self.eos = (self.n_dim, 0) - # TODO: assess if really needed - self.state2oracle = self.state2proxy - self.statebatch2oracle = self.statebatch2proxy - self.statetorch2oracle = self.statetorch2proxy # Base class init super().__init__( fixed_distr_params=fixed_distr_params, @@ -183,6 +179,26 @@ def get_mask_invalid_actions_backward(self, state=None, done=None, parents_a=Non ] + [mask[-1]] return mask + def states2proxy( + self, states: Union[List[List], TensorType["batch", "state_dim"]] + ) -> TensorType["batch", "state_proxy_dim"]: + """ + Prepares a batch of states in "environment format" for the proxy: each state is + a vector of length n_dim where each value is an angle in radians. The n_actions + item is removed. + + Args + ---- + states : list or tensor + A batch of states in environment format, either as a list of states or as a + single tensor. + + Returns + ------- + A tensor containing all the states in the batch. + """ + return tfloat(states, device=self.device, float_type=self.float)[:, :-1] + def statebatch2proxy( self, states: List[List] ) -> TensorType["batch", "state_proxy_dim"]: @@ -191,6 +207,7 @@ def statebatch2proxy( each state is a row of length n_dim with an angle in radians. The n_actions item is removed. """ + return self.states2proxy(states) return torch.tensor(states, device=self.device)[:, :-1] def statetorch2proxy( @@ -199,6 +216,7 @@ def statetorch2proxy( """ Prepares a batch of states in torch "GFlowNet format" for the proxy. """ + return self.states2proxy(states) return states[:, :-1] def state2policy(self, state: List = None) -> List: @@ -211,6 +229,44 @@ def state2policy(self, state: List = None) -> List: state = self.state.copy() return self.statebatch2policy([state]).tolist()[0] + def states2policy( + self, states: Union[List, TensorType["batch", "state_dim"]] + ) -> TensorType["batch", "policy_input_dim"]: + """ + Prepares a batch of states in "environment format" for the policy model: if + policy_encoding_dim_per_angle >= 2, then the state (angles) is encoded using + trigonometric components. + + Args + ---- + states : list or tensor + A batch of states in environment format, either as a list of states or as a + single tensor. + + Returns + ------- + A tensor containing all the states in the batch. + """ + states = tfloat(states, float_type=self.float, device=self.device) + if ( + self.policy_encoding_dim_per_angle is None + or self.policy_encoding_dim_per_angle < 2 + ): + return states + step = states[:, -1] + code_half_size = self.policy_encoding_dim_per_angle // 2 + int_coeff = ( + torch.arange(1, code_half_size + 1).repeat(states.shape[-1] - 1).to(states) + ) + encoding = ( + torch.repeat_interleave(states[:, :-1], repeats=code_half_size, dim=1) + * int_coeff + ) + return torch.cat( + [torch.cos(encoding), torch.sin(encoding), torch.unsqueeze(step, 1)], + dim=1, + ) + def statetorch2policy( self, states: TensorType["batch", "state_dim"] ) -> TensorType["batch", "policy_input_dim"]: @@ -220,6 +276,7 @@ def statetorch2policy( If policy_encoding_dim_per_angle >= 2, then the state (angles) is encoded using trigonometric components. """ + return self.states2policy(states) if ( self.policy_encoding_dim_per_angle is not None and self.policy_encoding_dim_per_angle >= 2 @@ -249,6 +306,7 @@ def statebatch2policy( See: statetorch2policy() """ + return self.states2policy(states) states = tfloat(states, float_type=self.float, device=self.device) return self.statetorch2policy(states) From 75935783ea4de5e61590e10fe88160d14855bc36 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Mon, 30 Oct 2023 21:06:06 -0400 Subject: [PATCH 21/54] Alanine Dipeptide: states2proxy; temporary state because old code is still there. --- gflownet/envs/alaninedipeptide.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/gflownet/envs/alaninedipeptide.py b/gflownet/envs/alaninedipeptide.py index 76b725e3b..d352c6bdf 100644 --- a/gflownet/envs/alaninedipeptide.py +++ b/gflownet/envs/alaninedipeptide.py @@ -40,10 +40,38 @@ def sync_conformer_with_state(self, state: List = None): self.conformer.set_torsion_angle(ta, state[idx]) return self.conformer + # TODO: are the conversions to oracle relevant? + def states2proxy( + self, states: Union[List[List], TensorType["batch", "state_dim"]] + ) -> npt.NDArray: + """ + Prepares a batch of states in "environment format" for the proxy: each state is + a vector of length n_dim where each value is an angle in radians. The n_actions + item is removed. + + Important: this method returns a numpy array, unlike in most other + environments. + + Args + ---- + states : list or tensor + A batch of states in environment format, either as a list of states or as a + single tensor. + + Returns + ------- + A numpy array containing all the states in the batch. + """ + if torch.is_tensor(states[0]): + return states.cpu().numpy()[:, :-1] + else: + return np.array(states)[:, :-1] + def statetorch2proxy(self, states: TensorType["batch", "state_dim"]) -> npt.NDArray: """ Prepares a batch of states in torch "GFlowNet format" for the oracle. """ + return self.states2proxy(states) device = states.device if device == torch.device("cpu"): np_states = states.numpy() @@ -57,6 +85,7 @@ def statebatch2proxy(self, states: List[List]) -> npt.NDArray: each state is a row of length n_dim with an angle in radians. The n_actions item is removed. """ + return self.states2proxy(states) return np.array(states)[:, :-1] def statetorch2oracle( From 72a31eab909de4df2ccbbebdbbcadfec5daa07d5 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Mon, 30 Oct 2023 21:22:16 -0400 Subject: [PATCH 22/54] Missing imports --- gflownet/envs/alaninedipeptide.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gflownet/envs/alaninedipeptide.py b/gflownet/envs/alaninedipeptide.py index d352c6bdf..00d9d4a42 100644 --- a/gflownet/envs/alaninedipeptide.py +++ b/gflownet/envs/alaninedipeptide.py @@ -1,5 +1,5 @@ from copy import deepcopy -from typing import List, Tuple +from typing import List, Tuple, Union import numpy as np import numpy.typing as npt From f7a8d2b426487e8bcde395a8dc01d64eb59bc833 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Mon, 30 Oct 2023 21:22:39 -0400 Subject: [PATCH 23/54] Discrete Torus: states2proxy, states2policy; temporary state because old code is still there. --- gflownet/envs/torus.py | 72 +++++++++++++++++++++++++++++++++++++++--- 1 file changed, 68 insertions(+), 4 deletions(-) diff --git a/gflownet/envs/torus.py b/gflownet/envs/torus.py index 8c0ce712d..bb6fca4c0 100644 --- a/gflownet/envs/torus.py +++ b/gflownet/envs/torus.py @@ -2,7 +2,7 @@ Classes to represent hyper-torus environments """ import itertools -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union import numpy as np import numpy.typing as npt @@ -11,6 +11,7 @@ from torchtyping import TensorType from gflownet.envs.base import GFlowNetEnv +from gflownet.utils.common import tfloat, tlong class Torus(GFlowNetEnv): @@ -64,9 +65,6 @@ def __init__( self.eos = tuple([self.max_increment + 1 for _ in range(self.n_dim)]) # Angle increments in radians self.angle_rad = 2 * np.pi / self.n_angles - # TODO: assess if really needed - self.state2oracle = self.state2proxy - self.statebatch2oracle = self.statebatch2proxy # Base class init super().__init__(**kwargs) @@ -112,12 +110,36 @@ def get_mask_invalid_actions_forward( mask[-1] = True return mask + def states2proxy( + self, states: Union[List[List], TensorType["batch", "state_dim"]] + ) -> TensorType["batch", "state_proxy_dim"]: + """ + Prepares a batch of states in "environment format" for the proxy: each state is + a vector of length n_dim where each value is an angle in radians. The n_actions + item is removed. + + Args + ---- + states : list or tensor + A batch of states in environment format, either as a list of states or as a + single tensor. + + Returns + ------- + A tensor containing all the states in the batch. + """ + return ( + tfloat(states, device=self.device, float_type=self.float)[:, :-1] + * self.angle_rad + ) + def statebatch2proxy(self, states: List[List]) -> npt.NDArray[np.float32]: """ Prepares a batch of states in "GFlowNet format" for the proxy: an array where each state is a row of length n_dim with an angle in radians. The n_actions item is removed. """ + return self.states2proxy(states) return torch.tensor(states, device=self.device)[:, :-1] * self.angle_rad def statetorch2proxy( @@ -126,6 +148,7 @@ def statetorch2proxy( """ Prepares a batch of states in torch "GFlowNet format" for the proxy. """ + return self.states2proxy(states) return states[:, :-1] * self.angle_rad # TODO: circular encoding as in htorus @@ -155,6 +178,45 @@ def state2policy(self, state=None) -> List: state_policy[-1] = state[-1] return state_policy + # TODO: circular encoding as in htorus + def states2policy( + self, states: Union[List, TensorType["batch", "state_dim"]] + ) -> TensorType["batch", "policy_input_dim"]: + """ + Prepares a batch of states in "environment format" for the policy model: the + policy format is a one-hot encoding of the states. + + Each row is a vector of length n_angles * n_dim + 1, where each n-th successive + block of length elements is a one-hot encoding of the position in the n-th + dimension. + + Example, n_dim = 2, n_angles = 4: + - state: [1, 3, 4] + | a | n | (a = angles, n = n_actions) + - policy format: [0, 1, 0, 0, 0, 0, 0, 1, 4] + | 1 | 3 | 4 | + Args + ---- + states : list or tensor + A batch of states in environment format, either as a list of states or as a + single tensor. + + Returns + ------- + A tensor containing all the states in the batch. + """ + states = tlong(states, device=self.device) + cols = states[:, :-1] + torch.arange(self.n_dim).to(self.device) * self.n_angles + rows = torch.repeat_interleave( + torch.arange(states.shape[0]).to(self.device), self.n_dim + ) + states_policy = torch.zeros( + (states.shape[0], self.n_angles * self.n_dim + 1) + ).to(states) + states_policy[rows, cols.flatten()] = 1.0 + states_policy[:, -1] = states[:, -1] + return states_policy + def statebatch2policy(self, states: List[List]) -> npt.NDArray[np.float32]: """ Transforms a batch of states into the policy model format. The output is a numpy @@ -162,6 +224,7 @@ def statebatch2policy(self, states: List[List]) -> npt.NDArray[np.float32]: See state2policy(). """ + return self.states2policy(states) states = np.array(states) cols = states[:, :-1] + np.arange(self.n_dim) * self.n_angles rows = np.repeat(np.arange(states.shape[0]), self.n_dim) @@ -181,6 +244,7 @@ def statetorch2policy( See state2policy(). """ + return self.states2policy(states) device = states.device cols = ( states[:, :-1] + torch.arange(self.n_dim).to(device) * self.n_angles From 7201991a4776d4462f2784f358a11e87bdcb475c Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Mon, 30 Oct 2023 21:26:56 -0400 Subject: [PATCH 24/54] Tree: remove mention to oracle --- gflownet/envs/tree.py | 1 - 1 file changed, 1 deletion(-) diff --git a/gflownet/envs/tree.py b/gflownet/envs/tree.py index bc92cda13..51a630fd3 100644 --- a/gflownet/envs/tree.py +++ b/gflownet/envs/tree.py @@ -291,7 +291,6 @@ def __init__( raise ValueError( f"Unrecognized policy_format = {policy_format}, expected either 'mlp' or 'gnn'." ) - self.statetorch2oracle = self.statetorch2policy super().__init__( fixed_distr_params=fixed_distr_params, From 127a017b1d85b478438dd09f0e3de4928b656f85 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Tue, 31 Oct 2023 10:06:00 -0400 Subject: [PATCH 25/54] Composition: states2proxy; temporary state because old code is still there. --- gflownet/envs/crystals/composition.py | 33 ++++++++++++--------------- 1 file changed, 14 insertions(+), 19 deletions(-) diff --git a/gflownet/envs/crystals/composition.py b/gflownet/envs/crystals/composition.py index 6fc62e991..955a512f1 100644 --- a/gflownet/envs/crystals/composition.py +++ b/gflownet/envs/crystals/composition.py @@ -6,10 +6,6 @@ import numpy as np import torch -from pyxtal.symmetry import Group -from torch import Tensor -from torchtyping import TensorType - from gflownet.envs.base import GFlowNetEnv from gflownet.utils.common import tlong from gflownet.utils.crystals.constants import ELEMENT_NAMES, OXIDATION_STATES @@ -19,6 +15,9 @@ space_group_lowest_free_wp_multiplicity, space_group_wyckoff_gcd, ) +from pyxtal.symmetry import Group +from torch import Tensor +from torchtyping import TensorType class Composition(GFlowNetEnv): @@ -132,10 +131,6 @@ def __init__( self.source = [0 for _ in self.elements] # End-of-sequence action self.eos = (-1, -1) - # Conversions - self.state2proxy = self.state2oracle - self.statebatch2proxy = self.statebatch2oracle - self.statetorch2proxy = self.statetorch2oracle super().__init__(**kwargs) def get_action_space(self): @@ -404,9 +399,9 @@ def get_element_mask(min_atoms, max_atoms): return mask - def state2oracle(self, state: List = None) -> Tensor: + def state2proxy(self, state: List = None) -> Tensor: """ - Prepares a state in "GFlowNet format" for the oracle. In this case, it simply + Prepares a state in "GFlowNet format" for the proxy. In this case, it simply converts the state into a torch tensor, with dtype torch.long. Args @@ -416,7 +411,7 @@ def state2oracle(self, state: List = None) -> Tensor: Returns ---- - oracle_state : Tensor + proxy_state : Tensor Tensor containing counts of individual elements """ if state is None: @@ -424,12 +419,12 @@ def state2oracle(self, state: List = None) -> Tensor: return tlong(state, device=self.device) - def statetorch2oracle( + def statetorch2proxy( self, states: TensorType["batch", "state_dim"] - ) -> TensorType["batch", "state_oracle_dim"]: + ) -> TensorType["batch", "state_proxy_dim"]: """ - Prepares a batch of states in "GFlowNet format" for the oracle. The input to the - oracle is the atom counts for individual elements. + Prepares a batch of states in "GFlowNet format" for the proxy. The input to the + proxy is the atom counts for individual elements. Args ---- @@ -438,15 +433,15 @@ def statetorch2oracle( Returns ---- - oracle_states : Tensor + proxy_states : Tensor """ return states - def statebatch2oracle( + def statebatch2proxy( self, states: List[List] - ) -> TensorType["batch", "state_oracle_dim"]: + ) -> TensorType["batch", "state_proxy_dim"]: """ - Prepares a batch of states in "GFlowNet format" for the oracles. In this case, + Prepares a batch of states in "GFlowNet format" for the proxy. In this case, it simply converts the states into a torch tensor, with dtype torch.long. Args From 650dfba0c3a9e2b2a485bc20c574aa4022b5efc4 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Tue, 31 Oct 2023 10:06:21 -0400 Subject: [PATCH 26/54] Lattice parameters: states2proxy; temporary state because old code is still there. --- gflownet/envs/crystals/lattice_parameters.py | 51 ++++++++++++++----- .../gflownet/envs/test_lattice_parameters.py | 4 +- 2 files changed, 41 insertions(+), 14 deletions(-) diff --git a/gflownet/envs/crystals/lattice_parameters.py b/gflownet/envs/crystals/lattice_parameters.py index 957a7e229..ffdd78e72 100644 --- a/gflownet/envs/crystals/lattice_parameters.py +++ b/gflownet/envs/crystals/lattice_parameters.py @@ -1,14 +1,12 @@ """ Classes to represent crystal environments """ -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union import numpy as np import torch -from torch import Tensor -from torchtyping import TensorType - from gflownet.envs.grid import Grid +from gflownet.utils.common import tlong from gflownet.utils.crystals.constants import ( CUBIC, HEXAGONAL, @@ -19,6 +17,8 @@ TETRAGONAL, TRICLINIC, ) +from torch import Tensor +from torchtyping import TensorType class LatticeParameters(Grid): @@ -336,9 +336,9 @@ def get_mask_invalid_actions_forward( return mask - def state2oracle(self, state: Optional[List[int]] = None) -> Tensor: + def state2proxy(self, state: Optional[List[int]] = None) -> Tensor: """ - Prepares a list of states in "GFlowNet format" for the oracle. + Prepares a list of states in "GFlowNet format" for the proxy. Args ---- @@ -347,7 +347,7 @@ def state2oracle(self, state: Optional[List[int]] = None) -> Tensor: Returns ---- - oracle_state : Tensor + proxy_state : Tensor Tensor containing lengths and angles converted from the Grid format. """ if state is None: @@ -358,12 +358,38 @@ def state2oracle(self, state: Optional[List[int]] = None) -> Tensor: + [self.cell2angle[s] for s in state[3:]] ) - def statetorch2oracle( + def states2proxy( + self, states: Union[List[List], TensorType["batch", "state_dim"]] + ) -> TensorType["batch", "state_proxy_dim"]: + """ + Prepares a batch of states in "environment format" for the proxy: the + concatenation of the lengths and angles. + + Args + ---- + states : list or tensor + A batch of states in environment format, either as a list of states or as a + single tensor. + + Returns + ------- + A tensor containing all the states in the batch. + """ + states = tlong(states, device=self.device) + return torch.cat( + [ + self.lengths_tensor[states[:, :3]], + self.angles_tensor[states[:, 3:]], + ], + dim=1, + ) + + def statetorch2proxy( self, states: TensorType["batch", "state_dim"] - ) -> TensorType["batch", "state_oracle_dim"]: + ) -> TensorType["batch", "state_proxy_dim"]: """ - Prepares a batch of states in "GFlowNet format" for the oracle. The input to the - oracle is the lengths and angles. + Prepares a batch of states in "GFlowNet format" for the proxy. The input to the + proxy is the lengths and angles. Args ---- @@ -372,8 +398,9 @@ def statetorch2oracle( Returns ---- - oracle_states : Tensor + proxy_states : Tensor """ + return self.states2proxy(states) return torch.cat( [ self.lengths_tensor[states[:, :3].long()], diff --git a/tests/gflownet/envs/test_lattice_parameters.py b/tests/gflownet/envs/test_lattice_parameters.py index 16aea2814..281d30783 100644 --- a/tests/gflownet/envs/test_lattice_parameters.py +++ b/tests/gflownet/envs/test_lattice_parameters.py @@ -282,8 +282,8 @@ def test__step__changes_state_as_expected(env, lattice_system, actions, exp_stat ), ], ) -def test__state2oracle__returns_expected_tensor(env, lattice_system, state, exp_tensor): - assert torch.equal(env.state2oracle(state), torch.Tensor(exp_tensor)) +def test__state2proxy__returns_expected_tensor(env, lattice_system, state, exp_tensor): + assert torch.equal(env.state2proxy(state), torch.Tensor(exp_tensor)) @pytest.mark.parametrize("lattice_system", [TRICLINIC]) From cdf4961e0f07220d3b289b232698606ddc9836b6 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Tue, 31 Oct 2023 10:06:36 -0400 Subject: [PATCH 27/54] Space group: states2proxy; temporary state because old code is still there. --- gflownet/envs/crystals/spacegroup.py | 64 ++++++++++++++++++---------- 1 file changed, 41 insertions(+), 23 deletions(-) diff --git a/gflownet/envs/crystals/spacegroup.py b/gflownet/envs/crystals/spacegroup.py index 38ea7fb4e..ab1425c7e 100644 --- a/gflownet/envs/crystals/spacegroup.py +++ b/gflownet/envs/crystals/spacegroup.py @@ -10,11 +10,11 @@ import numpy as np import torch import yaml -from torch import Tensor -from torchtyping import TensorType - from gflownet.envs.base import GFlowNetEnv +from gflownet.utils.common import tlong from gflownet.utils.crystals.pyxtal_cache import space_group_check_compatible +from torch import Tensor +from torchtyping import TensorType CRYSTAL_LATTICE_SYSTEMS = None POINT_SYMMETRIES = None @@ -130,10 +130,6 @@ def __init__( # Source state: index 0 (empty) for all three properties (crystal-lattice # system index, point symmetry index, space group) self.source = [0 for _ in range(3)] - # Conversions - self.state2proxy = self.state2oracle - self.statebatch2proxy = self.statebatch2oracle - self.statetorch2proxy = self.statetorch2oracle # Base class init super().__init__(**kwargs) @@ -247,10 +243,10 @@ def get_mask_invalid_actions_forward( ] return mask - def state2oracle(self, state: List = None) -> Tensor: + def state2proxy(self, state: List = None) -> Tensor: """ - Prepares a list of states in "GFlowNet format" for the oracle. The input to the - oracle is simply the space group. + Prepares a list of states in "GFlowNet format" for the proxy. The input to the + proxy is simply the space group. Args ---- @@ -259,22 +255,42 @@ def state2oracle(self, state: List = None) -> Tensor: Returns ---- - oracle_state : Tensor + proxy_state : Tensor """ if state is None: state = self.state if state[self.sg_idx] == 0: raise ValueError( - "The space group must have been set in order to call the oracle" + "The space group must have been set in order to call the proxy" ) return torch.tensor(state[self.sg_idx], device=self.device, dtype=torch.long) - def statebatch2oracle( + def states2proxy( + self, states: Union[List[List], TensorType["batch", "state_dim"]] + ) -> TensorType["batch", "state_proxy_dim"]: + """ + Prepares a batch of states in "environment format" for the proxy: the proxy + format is simply the space group. + + Args + ---- + states : list or tensor + A batch of states in environment format, either as a list of states or as a + single tensor. + + Returns + ------- + A tensor containing all the states in the batch. + """ + states = tlong(states, device=self.device) + return torch.unsqueeze(states[:, self.sg_idx], dim=1) + + def statebatch2proxy( self, states: List[List] - ) -> TensorType["batch", "state_oracle_dim"]: + ) -> TensorType["batch", "state_proxy_dim"]: """ - Prepares a batch of states in "GFlowNet format" for the oracle. The input to the - oracle is simply the space group. + Prepares a batch of states in "GFlowNet format" for the proxy. The input to the + proxy is simply the space group. Args ---- @@ -283,18 +299,19 @@ def statebatch2oracle( Returns ---- - oracle_state : Tensor + proxy_state : Tensor """ - return self.statetorch2oracle( + return self.states2proxy(states) + return self.statetorch2proxy( torch.tensor(states, device=self.device, dtype=torch.long) ) - def statetorch2oracle( + def statetorch2proxy( self, states: TensorType["batch", "state_dim"] - ) -> TensorType["batch", "state_oracle_dim"]: + ) -> TensorType["batch", "state_proxy_dim"]: """ - Prepares a batch of states in "GFlowNet format" for the oracle. The input to the - oracle is simply the space group. + Prepares a batch of states in "GFlowNet format" for the proxy. The input to the + proxy is simply the space group. Args ---- @@ -303,8 +320,9 @@ def statetorch2oracle( Returns ---- - oracle_state : Tensor + proxy_state : Tensor """ + return self.states2proxy(states) return torch.unsqueeze(states[:, self.sg_idx], dim=1).to(torch.long) def state2readable(self, state=None): From e7ee4afaba7a4ea9aade3adb627e82fe16b6461c Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Tue, 31 Oct 2023 10:06:56 -0400 Subject: [PATCH 28/54] Crystal: states2proxy; temporary state because old code is still there. --- gflownet/envs/crystals/crystal.py | 96 ++++++++++++++++++++----------- 1 file changed, 64 insertions(+), 32 deletions(-) diff --git a/gflownet/envs/crystals/crystal.py b/gflownet/envs/crystals/crystal.py index 0e914ce9a..e765f16a7 100644 --- a/gflownet/envs/crystals/crystal.py +++ b/gflownet/envs/crystals/crystal.py @@ -3,14 +3,13 @@ from typing import Dict, List, Optional, Tuple, Union import torch -from torch import Tensor -from torchtyping import TensorType - from gflownet.envs.base import GFlowNetEnv from gflownet.envs.crystals.composition import Composition from gflownet.envs.crystals.lattice_parameters import LatticeParameters from gflownet.envs.crystals.spacegroup import SpaceGroup from gflownet.utils.crystals.constants import TRICLINIC +from torch import Tensor +from torchtyping import TensorType class Stage(Enum): @@ -128,11 +127,6 @@ def __init__( self.lattice_parameters.eos, Stage.LATTICE_PARAMETERS ) - # Conversions - self.state2proxy = self.state2oracle - self.statebatch2proxy = self.statebatch2oracle - self.statetorch2proxy = self.statetorch2oracle - super().__init__(**kwargs) def _set_lattice_parameters(self): @@ -247,7 +241,7 @@ def _get_composition_state(self, state: Optional[List[int]] = None) -> List[int] def _get_composition_tensor_states( self, states: TensorType["batch", "state_dim"] - ) -> TensorType["batch", "state_oracle_dim"]: + ) -> TensorType["batch", "state_proxy_dim"]: return states[:, self.composition_state_start : self.composition_state_end] def _get_space_group_state(self, state: Optional[List[int]] = None) -> List[int]: @@ -258,7 +252,7 @@ def _get_space_group_state(self, state: Optional[List[int]] = None) -> List[int] def _get_space_group_tensor_states( self, states: TensorType["batch", "state_dim"] - ) -> TensorType["batch", "state_oracle_dim"]: + ) -> TensorType["batch", "state_proxy_dim"]: return states[:, self.space_group_state_start : self.space_group_state_end] def _get_lattice_parameters_state( @@ -273,7 +267,7 @@ def _get_lattice_parameters_state( def _get_lattice_parameters_tensor_states( self, states: TensorType["batch", "state_dim"] - ) -> TensorType["batch", "state_oracle_dim"]: + ) -> TensorType["batch", "state_proxy_dim"]: return states[ :, self.lattice_parameters_state_start : self.lattice_parameters_state_end ] @@ -466,58 +460,96 @@ def get_parents( return parents, actions - def state2oracle(self, state: Optional[List[int]] = None) -> Tensor: + def state2proxy(self, state: Optional[List[int]] = None) -> Tensor: """ - Prepares a list of states in "GFlowNet format" for the oracle. Simply + Prepares a list of states in "GFlowNet format" for the proxy. Simply a concatenation of all crystal components. """ if state is None: state = self.state.copy() - composition_oracle_state = self.composition.state2oracle( + composition_proxy_state = self.composition.state2proxy( state=self._get_composition_state(state) ).to(self.device) - space_group_oracle_state = ( - self.space_group.state2oracle(state=self._get_space_group_state(state)) - .unsqueeze(-1) # StateGroup oracle state is a single number + space_group_proxy_state = ( + self.space_group.state2proxy(state=self._get_space_group_state(state)) + .unsqueeze(-1) # StateGroup proxy state is a single number .to(self.device) ) - lattice_parameters_oracle_state = self.lattice_parameters.state2oracle( + lattice_parameters_proxy_state = self.lattice_parameters.state2proxy( state=self._get_lattice_parameters_state(state) ).to(self.device) return torch.cat( [ - composition_oracle_state, - space_group_oracle_state, - lattice_parameters_oracle_state, + composition_proxy_state, + space_group_proxy_state, + lattice_parameters_proxy_state, ] ) - def statebatch2oracle( + def states2proxy( + self, states: Union[List[List], TensorType["batch", "state_dim"]] + ) -> TensorType["batch", "state_proxy_dim"]: + """ + Prepares a batch of states in "environment format" for the proxy: simply a + concatenation of all crystal components. + + Args + ---- + states : list or tensor + A batch of states in environment format, either as a list of states or as a + single tensor. + + Returns + ------- + A tensor containing all the states in the batch. + """ + states = tlong(states, device=self.device) + composition_proxy_states = self.composition.statetorch2proxy( + self._get_composition_tensor_states(states) + ).to(self.device) + space_group_proxy_states = self.space_group.statetorch2proxy( + self._get_space_group_tensor_states(states) + ).to(self.device) + lattice_parameters_proxy_states = self.lattice_parameters.statetorch2proxy( + self._get_lattice_parameters_tensor_states(states) + ).to(self.device) + return torch.cat( + [ + composition_proxy_states, + space_group_proxy_states, + lattice_parameters_proxy_states, + ], + dim=1, + ) + + def statebatch2proxy( self, states: List[List] - ) -> TensorType["batch", "state_oracle_dim"]: - return self.statetorch2oracle( + ) -> TensorType["batch", "state_proxy_dim"]: + return self.states2proxy(states) + return self.statetorch2proxy( torch.tensor(states, device=self.device, dtype=torch.long) ) - def statetorch2oracle( + def statetorch2proxy( self, states: TensorType["batch", "state_dim"] - ) -> TensorType["batch", "state_oracle_dim"]: - composition_oracle_states = self.composition.statetorch2oracle( + ) -> TensorType["batch", "state_proxy_dim"]: + return self.states2proxy(states) + composition_proxy_states = self.composition.statetorch2proxy( self._get_composition_tensor_states(states) ).to(self.device) - space_group_oracle_states = self.space_group.statetorch2oracle( + space_group_proxy_states = self.space_group.statetorch2proxy( self._get_space_group_tensor_states(states) ).to(self.device) - lattice_parameters_oracle_states = self.lattice_parameters.statetorch2oracle( + lattice_parameters_proxy_states = self.lattice_parameters.statetorch2proxy( self._get_lattice_parameters_tensor_states(states) ).to(self.device) return torch.cat( [ - composition_oracle_states, - space_group_oracle_states, - lattice_parameters_oracle_states, + composition_proxy_states, + space_group_proxy_states, + lattice_parameters_proxy_states, ], dim=1, ) From 6792ebde01b3c9820ac012720ee5700a4f5d80eb Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Tue, 31 Oct 2023 10:22:38 -0400 Subject: [PATCH 29/54] Base: statebatch2proxy and statetorch2proxy unified into states2proxy; statebatch2policy and statetorch2policy unified into states2policy; policy2state removed --- gflownet/envs/base.py | 93 +++++++++++++++++++++---------------------- 1 file changed, 46 insertions(+), 47 deletions(-) diff --git a/gflownet/envs/base.py b/gflownet/envs/base.py index 572935064..9740018b8 100644 --- a/gflownet/envs/base.py +++ b/gflownet/envs/base.py @@ -663,72 +663,71 @@ def get_policy_output(self, params: Optional[dict] = None): """ return np.ones(self.action_space_dim) - def state2proxy(self, state: List = None): + def states2proxy( + self, states: Union[List[List], TensorType["batch", "state_dim"]] + ) -> TensorType["batch", "state_proxy_dim"]: """ - Prepares a state in "GFlowNet format" for the proxy. + Prepares a batch of states in "environment format" for the proxy. By default, + the batch of states is converted into a tensor with float dtype and returned as + is. Args ---- - state : list - A state + states : list or tensor + A batch of states in environment format, either as a list of states or as a + single tensor. + + Returns + ------- + A tensor containing all the states in the batch. """ - if state is None: - state = self.state.copy() - return self.statebatch2proxy([state]) + return tfloat(states, device=self.device, float_type=self.float) - def statebatch2proxy(self, states: List[List]) -> npt.NDArray[np.float32]: + def state2proxy(self, state: Union[List, TensorType["state_dim"]] = None): """ - Prepares a batch of states in "GFlowNet format" for the proxy. + Prepares a state in "GFlowNet format" for the proxy. By default, states2proxy + is called, which by default will return the state as is. Args ---- state : list A state """ - return np.array(states) + state = self._get_state(state) + return self.states2proxy([state]) - def statetorch2proxy( - self, states: TensorType["batch_size", "state_dim"] - ) -> TensorType["batch_size", "state_proxy_dim"]: - """ - Prepares a batch of states in torch "GFlowNet format" for the proxy. + def states2policy( + self, states: Union[List, TensorType["batch", "state_dim"]] + ) -> TensorType["batch", "policy_input_dim"]: """ - return states + Prepares a batch of states in "environment format" for the policy model: By + default, the batch of states is converted into a tensor with float dtype and + returned as is. - def statetorch2policy( - self, states: TensorType["batch_size", "state_dim"] - ) -> TensorType["batch_size", "policy_input_dim"]: - """ - Prepares a batch of states in torch "GFlowNet format" for the policy - """ - return states + Args + ---- + states : list or tensor + A batch of states in environment format, either as a list of states or as a + single tensor. - def state2policy(self, state=None): - """ - Converts a state into a format suitable for a machine learning model, such as a - one-hot encoding. + Returns + ------- + A tensor containing all the states in the batch. """ - if state is None: - state = self.state - return tfloat(state, float_type=self.float, device=self.device) + return tfloat(states, device=self.device, float_type=self.float) - def statebatch2policy( - self, states: List[List] - ) -> TensorType["batch_size", "policy_input_dim"]: - """ - Converts a batch of states into a format suitable for a machine learning model, - such as a one-hot encoding. Returns a numpy array. + def state2policy(self, state: Union[List, TensorType["state_dim"]] = None): """ - return self.statetorch2policy( - tfloat(states, float_type=self.float, device=self.device) - ) + Prepares a state in "GFlowNet format" for the policy model. By default, + states2policy is called, which by default will return the state as is. - def policy2state(self, state_policy: List) -> List: - """ - Converts the model (e.g. one-hot encoding) version of a state given as - argument into a state. + Args + ---- + state : list + A state """ - return state_policy + state = self._get_state(state) + return self.states2policy([state]) def state2readable(self, state=None): """ @@ -766,7 +765,7 @@ def reward_batch(self, states: List[List], done=None): """ if done is None: done = np.ones(len(states), dtype=bool) - states_proxy = self.statebatch2proxy(states) + states_proxy = self.states2proxy(states) if isinstance(states_proxy, torch.Tensor): states_proxy = states_proxy[list(done), :] elif isinstance(states_proxy, list): @@ -786,7 +785,7 @@ def reward_torchbatch( """ if done is None: done = torch.ones(states.shape[0], dtype=torch.bool, device=self.device) - states_proxy = self.statetorch2proxy(states[done, :]) + states_proxy = self.states2proxy(states[done, :]) reward = torch.zeros(done.shape[0], dtype=self.float, device=self.device) if states[done, :].shape[0] > 0: reward[done] = self.proxy2reward(self.proxy(states_proxy)) @@ -1336,7 +1335,7 @@ def plot_reward_distribution( states = torch.vstack(states).to(self.device, self.float) if isinstance(states, torch.Tensor) == False: states = torch.tensor(states, device=self.device, dtype=self.float) - states_proxy = self.statetorch2proxy(states) + states_proxy = self.states2proxy(states) scores = self.proxy(states_proxy) if isinstance(scores, TensorType): scores = scores.cpu().detach().numpy() From c5c2fd918b0f4dcbe4fb09c2b5b9dd26b35690d1 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Tue, 31 Oct 2023 10:25:10 -0400 Subject: [PATCH 30/54] Remove policy2state from all environments and tests because it is not only unused but ill-defined in general. --- gflownet/envs/grid.py | 12 ------------ gflownet/envs/htorus.py | 10 ---------- gflownet/envs/tetris.py | 10 ---------- gflownet/envs/torus.py | 17 ----------------- gflownet/envs/tree.py | 8 -------- tests/gflownet/envs/common.py | 10 ---------- 6 files changed, 67 deletions(-) diff --git a/gflownet/envs/grid.py b/gflownet/envs/grid.py index 34ce110bc..72d54cbd3 100644 --- a/gflownet/envs/grid.py +++ b/gflownet/envs/grid.py @@ -302,18 +302,6 @@ def statetorch2policy( state_policy[rows, cols.flatten()] = 1.0 return state_policy - def policy2state(self, state_policy: List) -> List: - """ - Transforms the one-hot encoding version of a state given as argument - into a state (list of the position at each dimension). - - Example: - - state_policy: [1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0] (length = 4, n_dim = 3) - | 0 | 3 | 1 | - - policy2state(state_policy): [0, 3, 1] - """ - return np.where(np.reshape(state_policy, (self.n_dim, self.length)))[1].tolist() - def readable2state(self, readable, alphabet={}): """ Converts a human-readable string representing a state into a state as a list of diff --git a/gflownet/envs/htorus.py b/gflownet/envs/htorus.py index f14d53adb..07ee53594 100644 --- a/gflownet/envs/htorus.py +++ b/gflownet/envs/htorus.py @@ -310,16 +310,6 @@ def statebatch2policy( states = tfloat(states, float_type=self.float, device=self.device) return self.statetorch2policy(states) - def policy2state(self, state_policy: List) -> List: - """ - Returns the input as is. - """ - if self.policy_encoding_dim_per_angle is not None: - raise NotImplementedError( - "Convertion from encoded policy_state to state is not impemented" - ) - return state_policy - def state2readable(self, state: List) -> str: """ Converts a state (a list of positions) into a human-readable string diff --git a/gflownet/envs/tetris.py b/gflownet/envs/tetris.py index e9ff4d542..cbef1c945 100644 --- a/gflownet/envs/tetris.py +++ b/gflownet/envs/tetris.py @@ -372,16 +372,6 @@ def statetorch2policy( return self.states2policy(states) return self.statetorch2proxy(states).flatten(start_dim=1) - def policy2state( - self, policy: Optional[TensorType["height", "width"]] = None - ) -> TensorType["height", "width"]: - """ - Returns None to signal that the conversion is not reversible. - - See: state2proxy() - """ - return None - def state2readable(self, state: Optional[TensorType["height", "width"]] = None): """ Converts a state (board) into a human-friendly string. diff --git a/gflownet/envs/torus.py b/gflownet/envs/torus.py index bb6fca4c0..e7ce84742 100644 --- a/gflownet/envs/torus.py +++ b/gflownet/envs/torus.py @@ -259,23 +259,6 @@ def statetorch2policy( state_policy[:, -1] = states[:, -1] return state_policy - def policy2state(self, state_policy: List) -> List: - """ - Transforms the one-hot encoding version of a state given as argument - into a state (list of the position at each dimension). - - Example, n_dim = 2, n_angles = 4: - - state_policy: [0, 1, 0, 0, 0, 0, 0, 1, 4] - | 0 | 3 | 4 | - - policy2state(state_policy): [1, 3, 4] - | a | n | (a = angles, n = n_actions) - """ - mat_angles_policy = np.reshape( - state_policy[: self.n_dim * self.n_angles], (self.n_dim, self.n_angles) - ) - angles = np.where(mat_angles_policy)[1].tolist() - return angles + [int(state_policy[-1])] - def state2readable(self, state: Optional[List] = None) -> str: """ Converts a state (a list of positions) into a human-readable string diff --git a/gflownet/envs/tree.py b/gflownet/envs/tree.py index 51a630fd3..bdacdeda8 100644 --- a/gflownet/envs/tree.py +++ b/gflownet/envs/tree.py @@ -854,14 +854,6 @@ def statetorch2policy_mlp( states = torch.cat([states[:, :, : Attribute.ACTIVE], active_features], dim=1) return states.flatten(start_dim=1) - def policy2state( - self, policy: Optional[TensorType["policy_input_dim"]] = None - ) -> None: - """ - Returns None to signal that the conversion is not reversible. - """ - return None - def statebatch2proxy( self, states: List[TensorType["state_dim"]] ) -> TensorType["batch", "state_proxy_dim"]: diff --git a/tests/gflownet/envs/common.py b/tests/gflownet/envs/common.py index 9f5d1b9a3..59b1aacdd 100644 --- a/tests/gflownet/envs/common.py +++ b/tests/gflownet/envs/common.py @@ -219,16 +219,6 @@ def test__sample_backwards_reaches_source(env, n=100): assert n_actions <= env.max_traj_length -@pytest.mark.repeat(100) -def test__state2policy__is_reversible(env): - env = env.reset() - while not env.done: - state_recovered = env.policy2state(env.state2policy()) - if state_recovered is not None: - assert env.equal(env.state, state_recovered) - env.step_random() - - @pytest.mark.repeat(100) def test__state2readable__is_reversible(env): env = env.reset() From 7c2ba156ab81804e715d5e823b71f29b44156aa7 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Tue, 31 Oct 2023 10:43:59 -0400 Subject: [PATCH 31/54] Tree: states2proxy --- gflownet/envs/tree.py | 36 ++++++++---------------------------- 1 file changed, 8 insertions(+), 28 deletions(-) diff --git a/gflownet/envs/tree.py b/gflownet/envs/tree.py index bdacdeda8..13684f4a3 100644 --- a/gflownet/envs/tree.py +++ b/gflownet/envs/tree.py @@ -285,8 +285,7 @@ def __init__( # Conversions policy_format = policy_format.lower() if policy_format == "mlp": - self.state2policy = self.state2policy_mlp - self.statetorch2policy = self.statetorch2policy_mlp + self.states2policy = self.states2policy_mlp elif policy_format != "gnn": raise ValueError( f"Unrecognized policy_format = {policy_format}, expected either 'mlp' or 'gnn'." @@ -829,24 +828,19 @@ def get_logprobs( is_backward, ) - def state2policy_mlp( - self, state: Optional[TensorType["state_dim"]] = None - ) -> TensorType["policy_input_dim"]: - """ - Prepares a state in "GFlowNet format" for the policy model. - """ - if state is None: - state = self.state.clone().detach() - return self.statetorch2policy_mlp(state.unsqueeze(0))[0] - - def statetorch2policy_mlp( - self, states: TensorType["batch_size", "state_dim"] + def states2policy_mlp( + self, + states: Union[ + List[TensorType["state_dim"]], TensorType["batch_size", "state_dim"] + ], ) -> TensorType["batch_size", "policy_input_dim"]: """ Prepares a batch of states in torch "GFlowNet format" for an MLP policy model. It replaces the NaNs by -2s, removes the activity attribute, and explicitly appends the attribute vector of the active node (if present). """ + if isinstance(states, list): + states = torch.stack(states) rows, cols = torch.where(states[:, :-1, Attribute.ACTIVE] == Status.ACTIVE) active_features = torch.full((states.shape[0], 1, 4), -2.0) active_features[rows] = states[rows, cols, : Attribute.ACTIVE].unsqueeze(1) @@ -854,20 +848,6 @@ def statetorch2policy_mlp( states = torch.cat([states[:, :, : Attribute.ACTIVE], active_features], dim=1) return states.flatten(start_dim=1) - def statebatch2proxy( - self, states: List[TensorType["state_dim"]] - ) -> TensorType["batch", "state_proxy_dim"]: - """ - Prepares a batch of states in "GFlowNet format" for the proxy: simply - stacks the list of tensors and calls self.statetorch2proxy. - - Args - ---- - state : list - """ - states = torch.stack(states) - return self.statetorch2proxy(states) - def _attributes_to_readable(self, attributes: List) -> str: # Node type if attributes[Attribute.TYPE] == NodeType.CONDITION: From df5af321b0245fed0305e47f1ddb626d6e07be3c Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Tue, 31 Oct 2023 10:44:40 -0400 Subject: [PATCH 32/54] All environments: remove statebatch2proxy, statetorch2proxy, statebatch2policy, statetorch2policy, state2policy and state2proxy --- gflownet/envs/alaninedipeptide.py | 23 +---- gflownet/envs/crystals/composition.py | 57 +++--------- gflownet/envs/crystals/crystal.py | 41 ++------- gflownet/envs/crystals/lattice_parameters.py | 30 +------ gflownet/envs/crystals/spacegroup.py | 45 +--------- gflownet/envs/cube.py | 43 +-------- gflownet/envs/grid.py | 93 +------------------- gflownet/envs/htorus.py | 75 +--------------- gflownet/envs/tetris.py | 61 +------------ gflownet/envs/torus.py | 60 ------------- 10 files changed, 32 insertions(+), 496 deletions(-) diff --git a/gflownet/envs/alaninedipeptide.py b/gflownet/envs/alaninedipeptide.py index 00d9d4a42..04b8e39b8 100644 --- a/gflownet/envs/alaninedipeptide.py +++ b/gflownet/envs/alaninedipeptide.py @@ -67,27 +67,7 @@ def states2proxy( else: return np.array(states)[:, :-1] - def statetorch2proxy(self, states: TensorType["batch", "state_dim"]) -> npt.NDArray: - """ - Prepares a batch of states in torch "GFlowNet format" for the oracle. - """ - return self.states2proxy(states) - device = states.device - if device == torch.device("cpu"): - np_states = states.numpy() - else: - np_states = states.cpu().numpy() - return np_states[:, :-1] - - def statebatch2proxy(self, states: List[List]) -> npt.NDArray: - """ - Prepares a batch of states in "GFlowNet format" for the proxy: a tensor where - each state is a row of length n_dim with an angle in radians. The n_actions - item is removed. - """ - return self.states2proxy(states) - return np.array(states)[:, :-1] - + # TODO: need to keep? def statetorch2oracle( self, states: TensorType["batch", "state_dim"] ) -> List[Tuple[npt.NDArray, npt.NDArray]]: @@ -102,6 +82,7 @@ def statetorch2oracle( result = self.statebatch2oracle(np_states) return result + # TODO: need to keep? def statebatch2oracle( self, states: List[List] ) -> List[Tuple[npt.NDArray, npt.NDArray]]: diff --git a/gflownet/envs/crystals/composition.py b/gflownet/envs/crystals/composition.py index 955a512f1..15212ae0d 100644 --- a/gflownet/envs/crystals/composition.py +++ b/gflownet/envs/crystals/composition.py @@ -6,6 +6,10 @@ import numpy as np import torch +from pyxtal.symmetry import Group +from torch import Tensor +from torchtyping import TensorType + from gflownet.envs.base import GFlowNetEnv from gflownet.utils.common import tlong from gflownet.utils.crystals.constants import ELEMENT_NAMES, OXIDATION_STATES @@ -15,9 +19,6 @@ space_group_lowest_free_wp_multiplicity, space_group_wyckoff_gcd, ) -from pyxtal.symmetry import Group -from torch import Tensor -from torchtyping import TensorType class Composition(GFlowNetEnv): @@ -399,54 +400,22 @@ def get_element_mask(min_atoms, max_atoms): return mask - def state2proxy(self, state: List = None) -> Tensor: - """ - Prepares a state in "GFlowNet format" for the proxy. In this case, it simply - converts the state into a torch tensor, with dtype torch.long. - - Args - ---- - state : list - A state - - Returns - ---- - proxy_state : Tensor - Tensor containing counts of individual elements - """ - if state is None: - state = self.state - - return tlong(state, device=self.device) - - def statetorch2proxy( - self, states: TensorType["batch", "state_dim"] + def states2proxy( + self, states: Union[List[List], TensorType["batch", "state_dim"]] ) -> TensorType["batch", "state_proxy_dim"]: """ - Prepares a batch of states in "GFlowNet format" for the proxy. The input to the - proxy is the atom counts for individual elements. + Prepares a batch of states in "environment format" for the proxy: simply + returns the states as are with dtype long. Args ---- - states : Tensor - A state + states : list or tensor + A batch of states in environment format, either as a list of states or as a + single tensor. Returns - ---- - proxy_states : Tensor - """ - return states - - def statebatch2proxy( - self, states: List[List] - ) -> TensorType["batch", "state_proxy_dim"]: - """ - Prepares a batch of states in "GFlowNet format" for the proxy. In this case, - it simply converts the states into a torch tensor, with dtype torch.long. - - Args - ---- - state : list + ------- + A tensor containing all the states in the batch. """ return tlong(states, device=self.device) diff --git a/gflownet/envs/crystals/crystal.py b/gflownet/envs/crystals/crystal.py index e765f16a7..11186176d 100644 --- a/gflownet/envs/crystals/crystal.py +++ b/gflownet/envs/crystals/crystal.py @@ -3,13 +3,14 @@ from typing import Dict, List, Optional, Tuple, Union import torch +from torch import Tensor +from torchtyping import TensorType + from gflownet.envs.base import GFlowNetEnv from gflownet.envs.crystals.composition import Composition from gflownet.envs.crystals.lattice_parameters import LatticeParameters from gflownet.envs.crystals.spacegroup import SpaceGroup from gflownet.utils.crystals.constants import TRICLINIC -from torch import Tensor -from torchtyping import TensorType class Stage(Enum): @@ -506,43 +507,13 @@ def states2proxy( A tensor containing all the states in the batch. """ states = tlong(states, device=self.device) - composition_proxy_states = self.composition.statetorch2proxy( - self._get_composition_tensor_states(states) - ).to(self.device) - space_group_proxy_states = self.space_group.statetorch2proxy( - self._get_space_group_tensor_states(states) - ).to(self.device) - lattice_parameters_proxy_states = self.lattice_parameters.statetorch2proxy( - self._get_lattice_parameters_tensor_states(states) - ).to(self.device) - return torch.cat( - [ - composition_proxy_states, - space_group_proxy_states, - lattice_parameters_proxy_states, - ], - dim=1, - ) - - def statebatch2proxy( - self, states: List[List] - ) -> TensorType["batch", "state_proxy_dim"]: - return self.states2proxy(states) - return self.statetorch2proxy( - torch.tensor(states, device=self.device, dtype=torch.long) - ) - - def statetorch2proxy( - self, states: TensorType["batch", "state_dim"] - ) -> TensorType["batch", "state_proxy_dim"]: - return self.states2proxy(states) - composition_proxy_states = self.composition.statetorch2proxy( + composition_proxy_states = self.composition.states2proxy( self._get_composition_tensor_states(states) ).to(self.device) - space_group_proxy_states = self.space_group.statetorch2proxy( + space_group_proxy_states = self.space_group.states2proxy( self._get_space_group_tensor_states(states) ).to(self.device) - lattice_parameters_proxy_states = self.lattice_parameters.statetorch2proxy( + lattice_parameters_proxy_states = self.lattice_parameters.states2proxy( self._get_lattice_parameters_tensor_states(states) ).to(self.device) return torch.cat( diff --git a/gflownet/envs/crystals/lattice_parameters.py b/gflownet/envs/crystals/lattice_parameters.py index ffdd78e72..11090bf02 100644 --- a/gflownet/envs/crystals/lattice_parameters.py +++ b/gflownet/envs/crystals/lattice_parameters.py @@ -5,6 +5,9 @@ import numpy as np import torch +from torch import Tensor +from torchtyping import TensorType + from gflownet.envs.grid import Grid from gflownet.utils.common import tlong from gflownet.utils.crystals.constants import ( @@ -17,8 +20,6 @@ TETRAGONAL, TRICLINIC, ) -from torch import Tensor -from torchtyping import TensorType class LatticeParameters(Grid): @@ -384,31 +385,6 @@ def states2proxy( dim=1, ) - def statetorch2proxy( - self, states: TensorType["batch", "state_dim"] - ) -> TensorType["batch", "state_proxy_dim"]: - """ - Prepares a batch of states in "GFlowNet format" for the proxy. The input to the - proxy is the lengths and angles. - - Args - ---- - states : Tensor - A state - - Returns - ---- - proxy_states : Tensor - """ - return self.states2proxy(states) - return torch.cat( - [ - self.lengths_tensor[states[:, :3].long()], - self.angles_tensor[states[:, 3:].long()], - ], - dim=1, - ) - def state2readable(self, state: Optional[List[int]] = None) -> str: """ Converts the state into a human-readable string in the format "(a, b, c), (alpha, beta, gamma)". diff --git a/gflownet/envs/crystals/spacegroup.py b/gflownet/envs/crystals/spacegroup.py index ab1425c7e..2f4052bcc 100644 --- a/gflownet/envs/crystals/spacegroup.py +++ b/gflownet/envs/crystals/spacegroup.py @@ -10,11 +10,12 @@ import numpy as np import torch import yaml +from torch import Tensor +from torchtyping import TensorType + from gflownet.envs.base import GFlowNetEnv from gflownet.utils.common import tlong from gflownet.utils.crystals.pyxtal_cache import space_group_check_compatible -from torch import Tensor -from torchtyping import TensorType CRYSTAL_LATTICE_SYSTEMS = None POINT_SYMMETRIES = None @@ -285,46 +286,6 @@ def states2proxy( states = tlong(states, device=self.device) return torch.unsqueeze(states[:, self.sg_idx], dim=1) - def statebatch2proxy( - self, states: List[List] - ) -> TensorType["batch", "state_proxy_dim"]: - """ - Prepares a batch of states in "GFlowNet format" for the proxy. The input to the - proxy is simply the space group. - - Args - ---- - state : list - A state - - Returns - ---- - proxy_state : Tensor - """ - return self.states2proxy(states) - return self.statetorch2proxy( - torch.tensor(states, device=self.device, dtype=torch.long) - ) - - def statetorch2proxy( - self, states: TensorType["batch", "state_dim"] - ) -> TensorType["batch", "state_proxy_dim"]: - """ - Prepares a batch of states in "GFlowNet format" for the proxy. The input to the - proxy is simply the space group. - - Args - ---- - state : list - A state - - Returns - ---- - proxy_state : Tensor - """ - return self.states2proxy(states) - return torch.unsqueeze(states[:, self.sg_idx], dim=1).to(torch.long) - def state2readable(self, state=None): """ Transforms the state, represented as a list of property indices, into a diff --git a/gflownet/envs/cube.py b/gflownet/envs/cube.py index 7eb08f513..cbb4ba2fe 100644 --- a/gflownet/envs/cube.py +++ b/gflownet/envs/cube.py @@ -94,9 +94,8 @@ def __init__( self.kappa = kappa # Conversions: only conversions to policy are implemented and the conversion to # proxy format is the same + self.states2proxy = self.states2policy self.state2proxy = self.state2policy - self.statebatch2proxy = self.statebatch2policy - self.statetorch2proxy = self.statetorch2policy # Base class init super().__init__( fixed_distr_params=fixed_distr_params, @@ -145,44 +144,6 @@ def states2policy( states = float(states, device=self.device, float_type=self.float) return 2.0 * torch.clip(states, min=0.0, max=1.0) - 1.0 - def statetorch2policy( - self, states: TensorType["batch", "state_dim"] = None - ) -> TensorType["batch", "policy_input_dim"]: - """ - Clips the states into [0, 1] and maps them to [-1.0, 1.0] - - Args - ---- - state : list - State - """ - return self.states2policy(states) - return 2.0 * torch.clip(states, min=0.0, max=1.0) - 1.0 - - def statebatch2policy( - self, states: List[List] - ) -> TensorType["batch", "state_proxy_dim"]: - """ - Clips the states into [0, 1] and maps them to [-1.0, 1.0] - - Args - ---- - state : list - State - """ - return self.states2policy(states) - return self.statetorch2policy( - tfloat(states, device=self.device, float_type=self.float) - ) - - def state2policy(self, state: List = None) -> List: - """ - Clips the state into [0, 1] and maps it to [-1.0, 1.0] - """ - if state is None: - state = self.state.copy() - return [2.0 * min(max(0.0, s), 1.0) - 1.0 for s in state] - def state2readable(self, state: List) -> str: """ Converts a state (a list of positions) into a human-readable string @@ -1360,7 +1321,7 @@ def sample_from_reward( samples_final = [] max_reward = self.proxy2reward(self.proxy.min) while len(samples_final) < n_samples: - samples_uniform = self.statebatch2proxy( + samples_uniform = self.states2proxy( self.get_uniform_terminating_states(n_samples) ) rewards = self.proxy2reward(self.proxy(samples_uniform)) diff --git a/gflownet/envs/grid.py b/gflownet/envs/grid.py index 72d54cbd3..4e6e90c87 100644 --- a/gflownet/envs/grid.py +++ b/gflownet/envs/grid.py @@ -81,7 +81,7 @@ def __init__( # Proxy format # TODO: assess if really needed if self.proxy_state_format == "ohe": - self.statebatch2proxy = self.statebatch2policy + self.states2proxy = self.states2policy def get_action_space(self): """ @@ -175,60 +175,6 @@ def states2proxy( * torch.tensor(self.cells[None, :]).to(states.device, self.float) ).sum(axis=2) - def statebatch2proxy( - self, states: List[List] - ) -> TensorType["batch", "state_proxy_dim"]: - """ - Prepares a batch of states in "GFlowNet format" for the oracles: each state is - a vector of length n_dim with values in the range [cell_min, cell_max]. - - See: statetorch2proxy() - - Args - ---- - state : list - State - """ - return self.states2proxy(states) - return self.statetorch2proxy( - tfloat(states, device=self.device, float_type=self.float) - ) - - def statetorch2proxy( - self, states: TensorType["batch", "state_dim"] - ) -> TensorType["batch", "state_proxy_dim"]: - """ - Prepares a batch of states in "GFlowNet format" for the oracles: each state is - a vector of length n_dim with values in the range [cell_min, cell_max]. - - See: statetorch2policy() - """ - return self.states2proxy(states) - return ( - self.statetorch2policy(states).reshape( - (len(states), self.n_dim, self.length) - ) - * torch.tensor(self.cells[None, :]).to(states.device, self.float) - ).sum(axis=2) - - def state2policy(self, state: List = None) -> List: - """ - Transforms the state given as argument (or self.state if None) into a - one-hot encoding. The output is a list of len length * n_dim, - where each n-th successive block of length elements is a one-hot encoding of - the position in the n-th dimension. - - Example: - - State, state: [0, 3, 1] (n_dim = 3) - - state2policy(state): [1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0] (length = 4) - | 0 | 3 | 1 | - """ - if state is None: - state = self.state.copy() - state_policy = np.zeros(self.length * self.n_dim, dtype=np.float32) - state_policy[(np.arange(len(state)) * self.length + state)] = 1 - return state_policy.tolist() - def states2policy( self, states: Union[List, TensorType["batch", "state_dim"]] ) -> TensorType["batch", "policy_input_dim"]: @@ -265,43 +211,6 @@ def states2policy( states_policy[rows, cols.flatten()] = 1.0 return states_policy - def statebatch2policy(self, states: List[List]) -> npt.NDArray[np.float32]: - """ - Transforms a batch of states into a one-hot encoding. The output is a numpy - array of shape [n_states, length * n_dim]. - - See state2policy(). - """ - return self.states2policy(states) - cols = np.array(states) + np.arange(self.n_dim) * self.length - rows = np.repeat(np.arange(len(states)), self.n_dim) - state_policy = np.zeros( - (len(states), self.length * self.n_dim), dtype=np.float32 - ) - state_policy[rows, cols.flatten()] = 1.0 - return state_policy - - def statetorch2policy( - self, states: TensorType["batch", "state_dim"] - ) -> TensorType["batch", "policy_output_dim"]: - """ - Transforms a batch of states into a one-hot encoding. The output is a numpy - array of shape [n_states, length * n_dim]. - - See state2policy(). - """ - return self.states2policy(states) - device = states.device - cols = (states + torch.arange(self.n_dim).to(device) * self.length).to(int) - rows = torch.repeat_interleave( - torch.arange(states.shape[0]).to(device), self.n_dim - ) - state_policy = torch.zeros( - (states.shape[0], self.length * self.n_dim), dtype=states.dtype - ).to(device) - state_policy[rows, cols.flatten()] = 1.0 - return state_policy - def readable2state(self, readable, alphabet={}): """ Converts a human-readable string representing a state into a state as a list of diff --git a/gflownet/envs/htorus.py b/gflownet/envs/htorus.py index 07ee53594..ff1dd5f69 100644 --- a/gflownet/envs/htorus.py +++ b/gflownet/envs/htorus.py @@ -199,36 +199,6 @@ def states2proxy( """ return tfloat(states, device=self.device, float_type=self.float)[:, :-1] - def statebatch2proxy( - self, states: List[List] - ) -> TensorType["batch", "state_proxy_dim"]: - """ - Prepares a batch of states in "GFlowNet format" for the proxy: a tensor where - each state is a row of length n_dim with an angle in radians. The n_actions - item is removed. - """ - return self.states2proxy(states) - return torch.tensor(states, device=self.device)[:, :-1] - - def statetorch2proxy( - self, states: TensorType["batch", "state_dim"] - ) -> TensorType["batch", "state_proxy_dim"]: - """ - Prepares a batch of states in torch "GFlowNet format" for the proxy. - """ - return self.states2proxy(states) - return states[:, :-1] - - def state2policy(self, state: List = None) -> List: - """ - Returns the policy encoding of the state. - - See: statebatch2policy() - """ - if state is None: - state = self.state.copy() - return self.statebatch2policy([state]).tolist()[0] - def states2policy( self, states: Union[List, TensorType["batch", "state_dim"]] ) -> TensorType["batch", "policy_input_dim"]: @@ -267,49 +237,6 @@ def states2policy( dim=1, ) - def statetorch2policy( - self, states: TensorType["batch", "state_dim"] - ) -> TensorType["batch", "policy_input_dim"]: - """ - Prepares a batch of states in torch "GFlowNet format" for the policy. - - If policy_encoding_dim_per_angle >= 2, then the state (angles) is encoded using - trigonometric components. - """ - return self.states2policy(states) - if ( - self.policy_encoding_dim_per_angle is not None - and self.policy_encoding_dim_per_angle >= 2 - ): - step = states[:, -1] - code_half_size = self.policy_encoding_dim_per_angle // 2 - int_coeff = ( - torch.arange(1, code_half_size + 1) - .repeat(states.shape[-1] - 1) - .to(states) - ) - encoding = ( - torch.repeat_interleave(states[:, :-1], repeats=code_half_size, dim=1) - * int_coeff - ) - states = torch.cat( - [torch.cos(encoding), torch.sin(encoding), torch.unsqueeze(step, 1)], - dim=1, - ) - return states - - def statebatch2policy( - self, states: List[List] - ) -> TensorType["batch_size", "policy_input_dim"]: - """ - Prepares a batch of states in "GFlowNet format" for the policy. - - See: statetorch2policy() - """ - return self.states2policy(states) - states = tfloat(states, float_type=self.float, device=self.device) - return self.statetorch2policy(states) - def state2readable(self, state: List) -> str: """ Converts a state (a list of positions) into a human-readable string @@ -652,7 +579,7 @@ def plot_reward_samples( [samples_mesh, torch.ones(samples_mesh.shape[0], 1)], 1 ).to(self.device) rewards = torch2np( - self.proxy2reward(self.proxy(self.statetorch2proxy(states_mesh))) + self.proxy2reward(self.proxy(self.states2proxy(states_mesh))) ) # Init figure fig, ax = plt.subplots() diff --git a/gflownet/envs/tetris.py b/gflownet/envs/tetris.py index cbef1c945..744b10001 100644 --- a/gflownet/envs/tetris.py +++ b/gflownet/envs/tetris.py @@ -289,43 +289,6 @@ def states2proxy( states[states != 0] = 1 return states - def statebatch2proxy( - self, states: List[TensorType["height", "width"]] - ) -> TensorType["batch", "state_proxy_dim"]: - """ - Prepares a batch of states in "GFlowNet format" for the oracles: simply - converts non-zero (non-empty) cells into 1s. - - Args - ---- - state : list - """ - return self.states2proxy(states) - states = torch.stack(states) - states[states != 0] = 1 - return states - - def statetorch2proxy( - self, states: TensorType["height", "width", "batch"] - ) -> TensorType["height", "width", "batch"]: - """ - Prepares a batch of states in "GFlowNet format" for the oracles: : simply - converts non-zero (non-empty) cells into 1s. - """ - return self.states2proxy(states) - states[states != 0] = 1 - return states - - def state2policy( - self, state: Optional[TensorType["height", "width"]] = None - ) -> TensorType["height", "width"]: - """ - Prepares a state in "GFlowNet format" for the policy model. - - See: state2proxy() - """ - return self.state2proxy(state).flatten() - def states2policy( self, states: Union[ @@ -335,7 +298,7 @@ def states2policy( """ Prepares a batch of states in "environment format" for the policy model. - See statetorch2proxy(). + See states2proxy(). Args ---- @@ -350,28 +313,6 @@ def states2policy( states = tint(states, device=self.device, int_type=self.int) return self.states2proxy(states).flatten(start_dim=1) - def statebatch2policy( - self, states: List[TensorType["height", "width"]] - ) -> TensorType["batch", "state_proxy_dim"]: - """ - Prepares a batch of states in "GFlowNet format" for the policy model. - - See statebatch2proxy(). - """ - return self.states2policy(states) - return self.statebatch2proxy(states).flatten(start_dim=1) - - def statetorch2policy( - self, states: TensorType["height", "width", "batch"] - ) -> TensorType["height", "width", "batch"]: - """ - Prepares a batch of states in "GFlowNet format" for the policy model. - - See statetorch2proxy(). - """ - return self.states2policy(states) - return self.statetorch2proxy(states).flatten(start_dim=1) - def state2readable(self, state: Optional[TensorType["height", "width"]] = None): """ Converts a state (board) into a human-friendly string. diff --git a/gflownet/envs/torus.py b/gflownet/envs/torus.py index e7ce84742..a298f3318 100644 --- a/gflownet/envs/torus.py +++ b/gflownet/envs/torus.py @@ -133,24 +133,6 @@ def states2proxy( * self.angle_rad ) - def statebatch2proxy(self, states: List[List]) -> npt.NDArray[np.float32]: - """ - Prepares a batch of states in "GFlowNet format" for the proxy: an array where - each state is a row of length n_dim with an angle in radians. The n_actions - item is removed. - """ - return self.states2proxy(states) - return torch.tensor(states, device=self.device)[:, :-1] * self.angle_rad - - def statetorch2proxy( - self, states: TensorType["batch", "state_dim"] - ) -> TensorType["batch", "state_proxy_dim"]: - """ - Prepares a batch of states in torch "GFlowNet format" for the proxy. - """ - return self.states2proxy(states) - return states[:, :-1] * self.angle_rad - # TODO: circular encoding as in htorus def state2policy(self, state=None) -> List: """ @@ -217,48 +199,6 @@ def states2policy( states_policy[:, -1] = states[:, -1] return states_policy - def statebatch2policy(self, states: List[List]) -> npt.NDArray[np.float32]: - """ - Transforms a batch of states into the policy model format. The output is a numpy - array of shape [n_states, n_angles * n_dim + 1]. - - See state2policy(). - """ - return self.states2policy(states) - states = np.array(states) - cols = states[:, :-1] + np.arange(self.n_dim) * self.n_angles - rows = np.repeat(np.arange(states.shape[0]), self.n_dim) - state_policy = np.zeros( - (len(states), self.n_angles * self.n_dim + 1), dtype=np.float32 - ) - state_policy[rows, cols.flatten()] = 1.0 - state_policy[:, -1] = states[:, -1] - return state_policy - - def statetorch2policy( - self, states: TensorType["batch", "state_dim"] - ) -> TensorType["batch", "policy_output_dim"]: - """ - Transforms a batch of torch states into the policy model format. The output is - a tensor of shape [n_states, n_angles * n_dim + 1]. - - See state2policy(). - """ - return self.states2policy(states) - device = states.device - cols = ( - states[:, :-1] + torch.arange(self.n_dim).to(device) * self.n_angles - ).to(int) - rows = torch.repeat_interleave( - torch.arange(states.shape[0]).to(device), self.n_dim - ) - state_policy = torch.zeros( - (states.shape[0], self.n_angles * self.n_dim + 1) - ).to(states) - state_policy[rows, cols.flatten()] = 1.0 - state_policy[:, -1] = states[:, -1] - return state_policy - def state2readable(self, state: Optional[List] = None) -> str: """ Converts a state (a list of positions) into a human-readable string From 9107d6308ea440cfd8b2ced37d4d21653eb29c5c Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Tue, 31 Oct 2023 10:46:43 -0400 Subject: [PATCH 33/54] All tests: remove statebatch2proxy, statetorch2proxy, statebatch2policy, statetorch2policy, state2policy and state2proxy --- tests/gflownet/envs/test_crystal.py | 4 ++-- tests/gflownet/envs/test_grid.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/gflownet/envs/test_crystal.py b/tests/gflownet/envs/test_crystal.py index c11cac1ec..53811228a 100644 --- a/tests/gflownet/envs/test_crystal.py +++ b/tests/gflownet/envs/test_crystal.py @@ -112,8 +112,8 @@ def test__state2proxy__returns_expected_value(env, state, expected): ], ], ) -def test__statebatch2proxy__returns_expected_value(env, batch, expected): - assert torch.allclose(env.statebatch2proxy(batch), expected, atol=1e-4) +def test__states2proxy__returns_expected_value(env, batch, expected): + assert torch.allclose(env.states2proxy(batch), expected, atol=1e-4) @pytest.mark.parametrize("action", [(1, 1, -2, -2, -2, -2), (3, 4, -2, -2, -2, -2)]) diff --git a/tests/gflownet/envs/test_grid.py b/tests/gflownet/envs/test_grid.py index 747685554..27ebbadd5 100644 --- a/tests/gflownet/envs/test_grid.py +++ b/tests/gflownet/envs/test_grid.py @@ -70,7 +70,7 @@ def test__state2proxy__returns_expected(env, state, state2proxy): @pytest.mark.parametrize( - "states, statebatch2proxy", + "states, states2proxy", [ ( [[0, 0, 0], [4, 4, 4], [1, 2, 3], [4, 0, 1]], @@ -78,8 +78,8 @@ def test__state2proxy__returns_expected(env, state, state2proxy): ), ], ) -def test__statebatch2proxy__returns_expected(env, states, statebatch2proxy): - assert torch.equal(torch.Tensor(statebatch2proxy), env.statebatch2proxy(states)) +def test__states2proxy__returns_expected(env, states, states2proxy): + assert torch.equal(torch.Tensor(states2proxy), env.states2proxy(states)) @pytest.mark.parametrize( From ed0329428487dbf64bcea2f7ed4cf9436408b619 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Tue, 31 Oct 2023 10:50:29 -0400 Subject: [PATCH 34/54] gflownet, buffer and batch: remove statebatch2proxy, statetorch2proxy, statebatch2policy, statetorch2policy, state2policy and state2proxy --- gflownet/gflownet.py | 10 +++++----- gflownet/utils/batch.py | 19 ++++--------------- gflownet/utils/buffer.py | 2 +- 3 files changed, 10 insertions(+), 21 deletions(-) diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index c9cac3146..6fd8fcaef 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -325,7 +325,7 @@ def sample_actions( # Check for at least one non-random action if idx_norandom.sum() > 0: states_policy = tfloat( - self.env.statebatch2policy( + self.env.states2policy( [s for s, do in zip(states, idx_norandom) if do] ), device=self.device, @@ -1040,8 +1040,8 @@ def test(self, **plot_kwargs): ) elif self.continuous: # TODO make it work with conditional env - x_sampled = torch2np(self.env.statebatch2proxy(x_sampled)) - x_tt = torch2np(self.env.statebatch2proxy(x_tt)) + x_sampled = torch2np(self.env.states2proxy(x_sampled)) + x_tt = torch2np(self.env.states2proxy(x_tt)) kde_pred = self.env.fit_kde( x_sampled, kernel=self.logger.test.kde.kernel, @@ -1055,7 +1055,7 @@ def test(self, **plot_kwargs): x_from_reward = self.env.sample_from_reward( n_samples=self.logger.test.n ) - x_from_reward = torch2np(self.env.statetorch2proxy(x_from_reward)) + x_from_reward = torch2np(self.env.states2proxy(x_from_reward)) # Fit KDE with samples from reward kde_true = self.env.fit_kde( x_from_reward, @@ -1325,7 +1325,7 @@ def logq(traj_list, actions_list, model, env): with torch.no_grad(): logits_traj = model( tfloat( - env.statebatch2policy(traj), + env.states2policy(traj), device=self.device, float_type=self.float, ) diff --git a/gflownet/utils/batch.py b/gflownet/utils/batch.py index a35f01ddf..9ab3dc24d 100644 --- a/gflownet/utils/batch.py +++ b/gflownet/utils/batch.py @@ -408,12 +408,7 @@ def states2policy( self.get_states_of_trajectory(traj_idx, states, traj_indices) ) return states_policy - # TODO: do we need tfloat or is done in env.statebatch2policy? - return tfloat( - self.env.statebatch2policy(states), - device=self.device, - float_type=self.float, - ) + return self.env.states2policy(states) def states2proxy( self, @@ -461,7 +456,7 @@ def states2proxy( if traj_idx not in traj_indices: continue states_proxy.append( - self.envs[traj_idx].statebatch2proxy( + self.envs[traj_idx].states2proxy( self.get_states_of_trajectory(traj_idx, states, traj_indices) ) ) @@ -471,7 +466,7 @@ def states2proxy( index[perm_index] = index.clone() states_proxy = concat_items(states_proxy, index) return states_proxy - return self.env.statebatch2proxy(states) + return self.env.states2proxy(states) def get_actions(self) -> TensorType["n_states, action_dim"]: """ @@ -678,13 +673,7 @@ def _compute_parents_all(self): self.parents_all.extend(parents) self.parents_actions_all.extend(parents_a) self.parents_all_indices.extend([idx] * len(parents)) - self.parents_all_policy.append( - tfloat( - self.envs[traj_idx].statebatch2policy(parents), - device=self.device, - float_type=self.float, - ) - ) + self.parents_all_policy.append(self.envs[traj_idx].states2policy(parents)) # Convert to tensors self.parents_actions_all = tfloat( self.parents_actions_all, diff --git a/gflownet/utils/buffer.py b/gflownet/utils/buffer.py index c7fd61ade..5f852b7ec 100644 --- a/gflownet/utils/buffer.py +++ b/gflownet/utils/buffer.py @@ -230,7 +230,7 @@ def make_data_set(self, config): samples = self.env.get_random_terminating_states(config.n) else: return None, None - energies = self.env.proxy(self.env.statebatch2proxy(samples)).tolist() + energies = self.env.proxy(self.env.states2proxy(samples)).tolist() df = pd.DataFrame( { "samples": [self.env.state2readable(s) for s in samples], From c9eac7927a1153508beafc2b0049b615416634fe Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Tue, 31 Oct 2023 11:12:29 -0400 Subject: [PATCH 35/54] Fix how policy_input_dim is computed --- gflownet/envs/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gflownet/envs/base.py b/gflownet/envs/base.py index 9740018b8..7b98b7b15 100644 --- a/gflownet/envs/base.py +++ b/gflownet/envs/base.py @@ -90,7 +90,7 @@ def __init__( self.fixed_policy_output = self.get_policy_output(self.fixed_distr_params) self.random_policy_output = self.get_policy_output(self.random_distr_params) self.policy_output_dim = len(self.fixed_policy_output) - self.policy_input_dim = len(self.state2policy()) + self.policy_input_dim = self.state2policy().shape[1] @abstractmethod def get_action_space(self): From 828141c057c27911907ec184ed5cef524475d0a8 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Tue, 31 Oct 2023 11:12:45 -0400 Subject: [PATCH 36/54] Tetris: policy output to float --- gflownet/envs/tetris.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gflownet/envs/tetris.py b/gflownet/envs/tetris.py index 744b10001..6338cc186 100644 --- a/gflownet/envs/tetris.py +++ b/gflownet/envs/tetris.py @@ -311,7 +311,7 @@ def states2policy( A tensor containing all the states in the batch. """ states = tint(states, device=self.device, int_type=self.int) - return self.states2proxy(states).flatten(start_dim=1) + return self.states2proxy(states).flatten(start_dim=1).to(self.float) def state2readable(self, state: Optional[TensorType["height", "width"]] = None): """ From 94c0c329dc2fd5c4c1aa0439ee62066705c1dfc7 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Tue, 31 Oct 2023 11:14:27 -0400 Subject: [PATCH 37/54] Fix typo --- gflownet/envs/cube.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gflownet/envs/cube.py b/gflownet/envs/cube.py index cbb4ba2fe..51ebb1d6b 100644 --- a/gflownet/envs/cube.py +++ b/gflownet/envs/cube.py @@ -141,7 +141,7 @@ def states2policy( ------- A tensor containing all the states in the batch. """ - states = float(states, device=self.device, float_type=self.float) + states = tfloat(states, device=self.device, float_type=self.float) return 2.0 * torch.clip(states, min=0.0, max=1.0) - 1.0 def state2readable(self, state: List) -> str: From ad2c7f6bc762bd4669d43a1fa391a36083087363 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Tue, 31 Oct 2023 11:33:02 -0400 Subject: [PATCH 38/54] squeeze output of state2policy and state2proxy and revert to previous way of getting policy_input_dim --- gflownet/envs/base.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/gflownet/envs/base.py b/gflownet/envs/base.py index 7b98b7b15..26e2a9668 100644 --- a/gflownet/envs/base.py +++ b/gflownet/envs/base.py @@ -90,7 +90,7 @@ def __init__( self.fixed_policy_output = self.get_policy_output(self.fixed_distr_params) self.random_policy_output = self.get_policy_output(self.random_distr_params) self.policy_output_dim = len(self.fixed_policy_output) - self.policy_input_dim = self.state2policy().shape[1] + self.policy_input_dim = len(self.state2policy()) @abstractmethod def get_action_space(self): @@ -694,7 +694,7 @@ def state2proxy(self, state: Union[List, TensorType["state_dim"]] = None): A state """ state = self._get_state(state) - return self.states2proxy([state]) + return torch.squeeze(self.states2proxy([state]), dim=0) def states2policy( self, states: Union[List, TensorType["batch", "state_dim"]] @@ -727,7 +727,7 @@ def state2policy(self, state: Union[List, TensorType["state_dim"]] = None): A state """ state = self._get_state(state) - return self.states2policy([state]) + return torch.squeeze(self.states2policy([state]), dim=0) def state2readable(self, state=None): """ From 17407b59dcdc5a46f08068f3181a07dd292de69f Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Tue, 31 Oct 2023 11:37:18 -0400 Subject: [PATCH 39/54] Add missing import --- gflownet/envs/crystals/crystal.py | 1 + 1 file changed, 1 insertion(+) diff --git a/gflownet/envs/crystals/crystal.py b/gflownet/envs/crystals/crystal.py index 11186176d..189b7927c 100644 --- a/gflownet/envs/crystals/crystal.py +++ b/gflownet/envs/crystals/crystal.py @@ -10,6 +10,7 @@ from gflownet.envs.crystals.composition import Composition from gflownet.envs.crystals.lattice_parameters import LatticeParameters from gflownet.envs.crystals.spacegroup import SpaceGroup +from gflownet.utils.common import tlong from gflownet.utils.crystals.constants import TRICLINIC From 1f817b1029c03d7ced30f5bc80af008ab5cae268 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Tue, 31 Oct 2023 11:37:34 -0400 Subject: [PATCH 40/54] Update composition and crystal tests --- tests/gflownet/envs/test_composition.py | 5 +++-- tests/gflownet/envs/test_crystal.py | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/gflownet/envs/test_composition.py b/tests/gflownet/envs/test_composition.py index 888e7221a..70250e587 100644 --- a/tests/gflownet/envs/test_composition.py +++ b/tests/gflownet/envs/test_composition.py @@ -4,6 +4,7 @@ import torch from gflownet.envs.crystals.composition import Composition +from gflownet.utils.common import tlong @pytest.fixture @@ -50,8 +51,8 @@ def test__environment__initializes_properly(elements): ), ], ) -def test__state2oracle__returns_expected_tensor(env, state, exp_tensor): - assert torch.equal(env.state2oracle(state), torch.Tensor(exp_tensor)) +def test__state2proxy__returns_expected_tensor(env, state, exp_tensor): + assert torch.equal(env.state2proxy(state), tlong(exp_tensor, device=env.device)) def test__state2readable(env): diff --git a/tests/gflownet/envs/test_crystal.py b/tests/gflownet/envs/test_crystal.py index 53811228a..ffbd50baa 100644 --- a/tests/gflownet/envs/test_crystal.py +++ b/tests/gflownet/envs/test_crystal.py @@ -74,8 +74,8 @@ def test__pad_depad_action(env): ], ], ) -def test__state2oracle__returns_expected_value(env, state, expected): - assert torch.allclose(env.state2oracle(state), expected, atol=1e-4) +def test__state2proxy__returns_expected_value(env, state, expected): + assert torch.allclose(env.state2proxy(state), expected, atol=1e-4) @pytest.mark.parametrize( From ea16fc204881d55c4f1f1c76a7b8ae3cda3ba54a Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Tue, 31 Oct 2023 11:44:39 -0400 Subject: [PATCH 41/54] Discrete torus: Remove state2policy and output of policy to float --- gflownet/envs/torus.py | 29 +---------------------------- 1 file changed, 1 insertion(+), 28 deletions(-) diff --git a/gflownet/envs/torus.py b/gflownet/envs/torus.py index a298f3318..54b1183a3 100644 --- a/gflownet/envs/torus.py +++ b/gflownet/envs/torus.py @@ -133,33 +133,6 @@ def states2proxy( * self.angle_rad ) - # TODO: circular encoding as in htorus - def state2policy(self, state=None) -> List: - """ - Transforms the angles part of the state given as argument (or self.state if - None) into a one-hot encoding. The output is a list of len n_angles * n_dim + - 1, where each n-th successive block of length elements is a one-hot encoding of - the position in the n-th dimension. - - Example, n_dim = 2, n_angles = 4: - - State, state: [1, 3, 4] - | a | n | (a = angles, n = n_actions) - - state2policy(state): [0, 1, 0, 0, 0, 0, 0, 1, 4] - | 1 | 3 | 4 | - """ - if state is None: - state = self.state.copy() - # TODO: do we need float32? - # TODO: do we need one-hot? - state_policy = np.zeros(self.n_angles * self.n_dim + 1, dtype=np.float32) - # Angles - state_policy[: self.n_dim * self.n_angles][ - (np.arange(self.n_dim) * self.n_angles + state[: self.n_dim]) - ] = 1 - # Number of actions - state_policy[-1] = state[-1] - return state_policy - # TODO: circular encoding as in htorus def states2policy( self, states: Union[List, TensorType["batch", "state_dim"]] @@ -197,7 +170,7 @@ def states2policy( ).to(states) states_policy[rows, cols.flatten()] = 1.0 states_policy[:, -1] = states[:, -1] - return states_policy + return states_policy.to(self.float) def state2readable(self, state: Optional[List] = None) -> str: """ From 55b8ed3573058744f61d5749d3656a6c60896181 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Tue, 31 Oct 2023 12:16:17 -0400 Subject: [PATCH 42/54] Update env.reward() in base --- gflownet/envs/base.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/gflownet/envs/base.py b/gflownet/envs/base.py index 26e2a9668..f6da1fc57 100644 --- a/gflownet/envs/base.py +++ b/gflownet/envs/base.py @@ -757,7 +757,9 @@ def reward(self, state=None, done=None): done = self._get_done(done) if done is False: return tfloat(0.0, float_type=self.float, device=self.device) - return self.proxy2reward(self.proxy(self.state2proxy(state))[0]) + return self.proxy2reward( + self.proxy(torch.unsqueeze(self.state2proxy(state), dim=0))[0] + ) def reward_batch(self, states: List[List], done=None): """ From 1e71a6ded28c12233701886811028b4543f66cc6 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Tue, 31 Oct 2023 12:16:43 -0400 Subject: [PATCH 43/54] Test batch: statebatch2policy -> states2policy --- tests/gflownet/utils/test_batch.py | 171 ++++------------------------- 1 file changed, 19 insertions(+), 152 deletions(-) diff --git a/tests/gflownet/utils/test_batch.py b/tests/gflownet/utils/test_batch.py index 338dfd061..776bba324 100644 --- a/tests/gflownet/utils/test_batch.py +++ b/tests/gflownet/utils/test_batch.py @@ -123,14 +123,7 @@ def test__get_states__single_env_returns_expected(env, batch, request): assert torch.equal(torch.stack(states_batch), torch.stack(states)) else: assert states_batch == states - assert torch.equal( - states_policy_batch, - tfloat( - env.statebatch2policy(states), - device=batch.device, - float_type=batch.float, - ), - ) + assert torch.equal(states_policy_batch, env.states2policy(states)) @pytest.mark.repeat(N_REPETITIONS) @@ -155,14 +148,7 @@ def test__get_parents__single_env_returns_expected(env, batch, request): assert torch.equal(torch.stack(parents_batch), torch.stack(parents)) else: assert parents_batch == parents - assert torch.equal( - parents_policy_batch, - tfloat( - env.statebatch2policy(parents), - device=batch.device, - float_type=batch.float, - ), - ) + assert torch.equal(parents_policy_batch, env.states2policy(parents)) @pytest.mark.repeat(N_REPETITIONS) @@ -197,14 +183,7 @@ def test__get_parents_all__single_env_returns_expected(env, batch, request): float_type=batch.float, ), ) - assert torch.equal( - parents_all_policy_batch, - tfloat( - env.statebatch2policy(parents_all), - device=batch.device, - float_type=batch.float, - ), - ) + assert torch.equal(parents_all_policy_batch, env.states2policy(parents_all)) @pytest.mark.repeat(N_REPETITIONS) @@ -365,14 +344,7 @@ def test__forward_sampling_multiple_envs_all_as_expected(env, proxy, batch, requ assert torch.equal(torch.stack(states_batch), torch.stack(states)) else: assert states_batch == states - assert torch.equal( - states_policy_batch, - tfloat( - env.statebatch2policy(states), - device=batch.device, - float_type=batch.float, - ), - ) + assert torch.equal(states_policy_batch, env.states2policy(states)) # Check actions actions_batch = batch.get_actions() assert torch.equal( @@ -399,14 +371,7 @@ def test__forward_sampling_multiple_envs_all_as_expected(env, proxy, batch, requ assert torch.equal(torch.stack(parents_batch), torch.stack(parents)) else: assert parents_batch == parents - assert torch.equal( - parents_policy_batch, - tfloat( - env.statebatch2policy(parents), - device=batch.device, - float_type=batch.float, - ), - ) + assert torch.equal(parents_policy_batch, env.states2policy(parents)) # Check parents_all if not env.continuous: parents_all_batch, parents_all_a_batch, _ = batch.get_parents_all() @@ -423,14 +388,7 @@ def test__forward_sampling_multiple_envs_all_as_expected(env, proxy, batch, requ float_type=batch.float, ), ) - assert torch.equal( - parents_all_policy_batch, - tfloat( - env.statebatch2policy(parents_all), - device=batch.device, - float_type=batch.float, - ), - ) + assert torch.equal(parents_all_policy_batch, env.states2policy(parents_all)) # Check rewards rewards_batch = batch.get_rewards() rewards = torch.stack(rewards) @@ -447,14 +405,7 @@ def test__forward_sampling_multiple_envs_all_as_expected(env, proxy, batch, requ ) else: assert states_term_batch == states_term_sorted - assert torch.equal( - states_term_policy_batch, - tfloat( - env.statebatch2policy(states_term_sorted), - device=batch.device, - float_type=batch.float, - ), - ) + assert torch.equal(states_term_policy_batch, env.states2policy(states_term_sorted)) @pytest.mark.repeat(N_REPETITIONS) @@ -551,14 +502,7 @@ def test__backward_sampling_multiple_envs_all_as_expected(env, proxy, batch, req assert torch.equal(torch.stack(states_batch), torch.stack(states)) else: assert states_batch == states - assert torch.equal( - states_policy_batch, - tfloat( - env.statebatch2policy(states), - device=batch.device, - float_type=batch.float, - ), - ) + assert torch.equal(states_policy_batch, env.states2policy(states)) # Check actions actions_batch = batch.get_actions() assert torch.equal( @@ -585,14 +529,7 @@ def test__backward_sampling_multiple_envs_all_as_expected(env, proxy, batch, req assert torch.equal(torch.stack(parents_batch), torch.stack(parents)) else: assert parents_batch == parents - assert torch.equal( - parents_policy_batch, - tfloat( - env.statebatch2policy(parents), - device=batch.device, - float_type=batch.float, - ), - ) + assert torch.equal(parents_policy_batch, env.states2policy(parents)) # Check parents_all if not env.continuous: parents_all_batch, parents_all_a_batch, _ = batch.get_parents_all() @@ -609,14 +546,7 @@ def test__backward_sampling_multiple_envs_all_as_expected(env, proxy, batch, req float_type=batch.float, ), ) - assert torch.equal( - parents_all_policy_batch, - tfloat( - env.statebatch2policy(parents_all), - device=batch.device, - float_type=batch.float, - ), - ) + assert torch.equal(parents_all_policy_batch, env.states2policy(parents_all)) # Check rewards rewards_batch = batch.get_rewards() rewards = torch.stack(rewards) @@ -633,14 +563,7 @@ def test__backward_sampling_multiple_envs_all_as_expected(env, proxy, batch, req ) else: assert states_term_batch == states_term_sorted - assert torch.equal( - states_term_policy_batch, - tfloat( - env.statebatch2policy(states_term_sorted), - device=batch.device, - float_type=batch.float, - ), - ) + assert torch.equal(states_term_policy_batch, env.states2policy(states_term_sorted)) @pytest.mark.repeat(N_REPETITIONS) @@ -794,14 +717,7 @@ def test__mixed_sampling_multiple_envs_all_as_expected(env, proxy, batch, reques assert torch.equal(torch.stack(states_batch), torch.stack(states)) else: assert states_batch == states - assert torch.equal( - states_policy_batch, - tfloat( - env.statebatch2policy(states), - device=batch.device, - float_type=batch.float, - ), - ) + assert torch.equal(states_policy_batch, env.states2policy(states)) # Check actions actions_batch = batch.get_actions() assert torch.equal( @@ -828,14 +744,7 @@ def test__mixed_sampling_multiple_envs_all_as_expected(env, proxy, batch, reques assert torch.equal(torch.stack(parents_batch), torch.stack(parents)) else: assert parents_batch == parents - assert torch.equal( - parents_policy_batch, - tfloat( - env.statebatch2policy(parents), - device=batch.device, - float_type=batch.float, - ), - ) + assert torch.equal(parents_policy_batch, env.states2policy(parents)) # Check parents_all if not env.continuous: parents_all_batch, parents_all_a_batch, _ = batch.get_parents_all() @@ -852,14 +761,7 @@ def test__mixed_sampling_multiple_envs_all_as_expected(env, proxy, batch, reques float_type=batch.float, ), ) - assert torch.equal( - parents_all_policy_batch, - tfloat( - env.statebatch2policy(parents_all), - device=batch.device, - float_type=batch.float, - ), - ) + assert torch.equal(parents_all_policy_batch, env.states2policy(parents_all)) # Check rewards rewards_batch = batch.get_rewards() rewards = torch.stack(rewards) @@ -876,14 +778,7 @@ def test__mixed_sampling_multiple_envs_all_as_expected(env, proxy, batch, reques ) else: assert states_term_batch == states_term_sorted - assert torch.equal( - states_term_policy_batch, - tfloat( - env.statebatch2policy(states_term_sorted), - device=batch.device, - float_type=batch.float, - ), - ) + assert torch.equal(states_term_policy_batch, env.states2policy(states_term_sorted)) @pytest.mark.repeat(N_REPETITIONS) @@ -1043,14 +938,7 @@ def test__mixed_sampling_merged_all_as_expected(env, proxy, request): assert torch.equal(torch.stack(states_batch), torch.stack(states)) else: assert states_batch == states - assert torch.equal( - states_policy_batch, - tfloat( - env.statebatch2policy(states), - device=batch.device, - float_type=batch.float, - ), - ) + assert torch.equal(states_policy_batch, env.states2policy(states)) # Check actions actions_batch = batch.get_actions() assert torch.equal( @@ -1077,14 +965,7 @@ def test__mixed_sampling_merged_all_as_expected(env, proxy, request): assert torch.equal(torch.stack(parents_batch), torch.stack(parents)) else: assert parents_batch == parents - assert torch.equal( - parents_policy_batch, - tfloat( - env.statebatch2policy(parents), - device=batch.device, - float_type=batch.float, - ), - ) + assert torch.equal(parents_policy_batch, env.states2policy(parents)) # Check parents_all if not env.continuous: parents_all_batch, parents_all_a_batch, _ = batch.get_parents_all() @@ -1101,14 +982,7 @@ def test__mixed_sampling_merged_all_as_expected(env, proxy, request): float_type=batch.float, ), ) - assert torch.equal( - parents_all_policy_batch, - tfloat( - env.statebatch2policy(parents_all), - device=batch.device, - float_type=batch.float, - ), - ) + assert torch.equal(parents_all_policy_batch, env.states2policy(parents_all)) # Check rewards rewards_batch = batch.get_rewards() rewards = torch.stack(rewards) @@ -1125,14 +999,7 @@ def test__mixed_sampling_merged_all_as_expected(env, proxy, request): ) else: assert states_term_batch == states_term_sorted - assert torch.equal( - states_term_policy_batch, - tfloat( - env.statebatch2policy(states_term_sorted), - device=batch.device, - float_type=batch.float, - ), - ) + assert torch.equal(states_term_policy_batch, env.states2policy(states_term_sorted)) @pytest.mark.repeat(N_REPETITIONS) From a50cd7dcdb0213db79fe958fa729f6e5733d5aee Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Tue, 31 Oct 2023 12:17:00 -0400 Subject: [PATCH 44/54] Envs: remove state2proxy --- gflownet/envs/crystals/crystal.py | 28 -------------------- gflownet/envs/crystals/lattice_parameters.py | 22 --------------- gflownet/envs/crystals/spacegroup.py | 22 --------------- gflownet/envs/grid.py | 23 ---------------- gflownet/envs/tetris.py | 18 ------------- 5 files changed, 113 deletions(-) diff --git a/gflownet/envs/crystals/crystal.py b/gflownet/envs/crystals/crystal.py index 189b7927c..0acf0a8fa 100644 --- a/gflownet/envs/crystals/crystal.py +++ b/gflownet/envs/crystals/crystal.py @@ -462,34 +462,6 @@ def get_parents( return parents, actions - def state2proxy(self, state: Optional[List[int]] = None) -> Tensor: - """ - Prepares a list of states in "GFlowNet format" for the proxy. Simply - a concatenation of all crystal components. - """ - if state is None: - state = self.state.copy() - - composition_proxy_state = self.composition.state2proxy( - state=self._get_composition_state(state) - ).to(self.device) - space_group_proxy_state = ( - self.space_group.state2proxy(state=self._get_space_group_state(state)) - .unsqueeze(-1) # StateGroup proxy state is a single number - .to(self.device) - ) - lattice_parameters_proxy_state = self.lattice_parameters.state2proxy( - state=self._get_lattice_parameters_state(state) - ).to(self.device) - - return torch.cat( - [ - composition_proxy_state, - space_group_proxy_state, - lattice_parameters_proxy_state, - ] - ) - def states2proxy( self, states: Union[List[List], TensorType["batch", "state_dim"]] ) -> TensorType["batch", "state_proxy_dim"]: diff --git a/gflownet/envs/crystals/lattice_parameters.py b/gflownet/envs/crystals/lattice_parameters.py index 11090bf02..4901c6404 100644 --- a/gflownet/envs/crystals/lattice_parameters.py +++ b/gflownet/envs/crystals/lattice_parameters.py @@ -337,28 +337,6 @@ def get_mask_invalid_actions_forward( return mask - def state2proxy(self, state: Optional[List[int]] = None) -> Tensor: - """ - Prepares a list of states in "GFlowNet format" for the proxy. - - Args - ---- - state : list - A state. - - Returns - ---- - proxy_state : Tensor - Tensor containing lengths and angles converted from the Grid format. - """ - if state is None: - state = self.state.copy() - - return Tensor( - [self.cell2length[s] for s in state[:3]] - + [self.cell2angle[s] for s in state[3:]] - ) - def states2proxy( self, states: Union[List[List], TensorType["batch", "state_dim"]] ) -> TensorType["batch", "state_proxy_dim"]: diff --git a/gflownet/envs/crystals/spacegroup.py b/gflownet/envs/crystals/spacegroup.py index 2f4052bcc..f5c8ee4cf 100644 --- a/gflownet/envs/crystals/spacegroup.py +++ b/gflownet/envs/crystals/spacegroup.py @@ -244,28 +244,6 @@ def get_mask_invalid_actions_forward( ] return mask - def state2proxy(self, state: List = None) -> Tensor: - """ - Prepares a list of states in "GFlowNet format" for the proxy. The input to the - proxy is simply the space group. - - Args - ---- - state : list - A state - - Returns - ---- - proxy_state : Tensor - """ - if state is None: - state = self.state - if state[self.sg_idx] == 0: - raise ValueError( - "The space group must have been set in order to call the proxy" - ) - return torch.tensor(state[self.sg_idx], device=self.device, dtype=torch.long) - def states2proxy( self, states: Union[List[List], TensorType["batch", "state_dim"]] ) -> TensorType["batch", "state_proxy_dim"]: diff --git a/gflownet/envs/grid.py b/gflownet/envs/grid.py index 4e6e90c87..8f6638eaa 100644 --- a/gflownet/envs/grid.py +++ b/gflownet/envs/grid.py @@ -125,29 +125,6 @@ def get_mask_invalid_actions_forward( mask[idx] = True return mask - def state2proxy(self, state: List = None) -> List: - """ - Prepares a state in "GFlowNet format" for the oracles: a list of length - n_dim with values in the range [cell_min, cell_max] for each state. - - See: state2policy() - - Args - ---- - state : list - State - """ - if state is None: - state = self.state.copy() - return ( - ( - np.array(self.state2policy(state)).reshape((self.n_dim, self.length)) - * self.cells[None, :] - ) - .sum(axis=1) - .tolist() - ) - def states2proxy( self, states: Union[List[List], TensorType["batch", "state_dim"]] ) -> TensorType["batch", "state_proxy_dim"]: diff --git a/gflownet/envs/tetris.py b/gflownet/envs/tetris.py index 6338cc186..6c4e8bcb7 100644 --- a/gflownet/envs/tetris.py +++ b/gflownet/envs/tetris.py @@ -247,24 +247,6 @@ def get_mask_invalid_actions_forward( mask[-1] = True return mask - def state2proxy( - self, state: Optional[TensorType["height", "width"]] = None - ) -> TensorType["height", "width"]: - """ - Prepares a state in "environment format" for the oracles: simply converts - non-zero - (non-empty) cells into 1s. - - Args - ---- - state : tensor - """ - if state is None: - state = self.state.clone().detach() - state_proxy = state.clone().detach() - state_proxy[state_proxy != 0] = 1 - return state_proxy - def states2proxy( self, states: Union[ From 3880e48babbf08b2f6e6186366a9cf832ac54121 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Tue, 31 Oct 2023 12:18:48 -0400 Subject: [PATCH 45/54] Remove reward_torchbatch because it is unused --- gflownet/envs/base.py | 16 ---------------- gflownet/envs/htorus.py | 2 +- 2 files changed, 1 insertion(+), 17 deletions(-) diff --git a/gflownet/envs/base.py b/gflownet/envs/base.py index f6da1fc57..f4e7fb829 100644 --- a/gflownet/envs/base.py +++ b/gflownet/envs/base.py @@ -777,22 +777,6 @@ def reward_batch(self, states: List[List], done=None): rewards[list(done)] = self.proxy2reward(self.proxy(states_proxy)).tolist() return rewards - def reward_torchbatch( - self, - states: TensorType["batch_size", "state_dim"], - done: TensorType["batch_size"] = None, - ): - """ - Computes the rewards of a batch of states in "GFlownet format" - """ - if done is None: - done = torch.ones(states.shape[0], dtype=torch.bool, device=self.device) - states_proxy = self.states2proxy(states[done, :]) - reward = torch.zeros(done.shape[0], dtype=self.float, device=self.device) - if states[done, :].shape[0] > 0: - reward[done] = self.proxy2reward(self.proxy(states_proxy)) - return reward - def proxy2reward(self, proxy_vals): """ Prepares the output of a proxy for GFlowNet: the inputs proxy_vals is expected diff --git a/gflownet/envs/htorus.py b/gflownet/envs/htorus.py index ff1dd5f69..461df04e5 100644 --- a/gflownet/envs/htorus.py +++ b/gflownet/envs/htorus.py @@ -539,7 +539,7 @@ def sample_from_reward( ), axis=1, ) - rewards = self.reward_torchbatch(samples) + rewards = self.reward_batch(samples) mask = ( torch.rand(n_samples, dtype=self.float, device=self.device) * (max_reward + epsilon) From 03308a96ebe4f1434611ef33c145126b6a4c122b Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Tue, 31 Oct 2023 12:19:09 -0400 Subject: [PATCH 46/54] Add TODO --- gflownet/envs/base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/gflownet/envs/base.py b/gflownet/envs/base.py index f4e7fb829..2620bb3b0 100644 --- a/gflownet/envs/base.py +++ b/gflownet/envs/base.py @@ -761,6 +761,7 @@ def reward(self, state=None, done=None): self.proxy(torch.unsqueeze(self.state2proxy(state), dim=0))[0] ) + # TODO: cleanup def reward_batch(self, states: List[List], done=None): """ Computes the rewards of a batch of states, given a list of states and 'dones' From 5e3c765a9556fc60e80454c3a6aa73a224c78e40 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Tue, 31 Oct 2023 12:20:52 -0400 Subject: [PATCH 47/54] statetorch2 -> states2 --- tests/gflownet/envs/test_ccube.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/gflownet/envs/test_ccube.py b/tests/gflownet/envs/test_ccube.py index feda16bf3..90b8a4549 100644 --- a/tests/gflownet/envs/test_ccube.py +++ b/tests/gflownet/envs/test_ccube.py @@ -1097,10 +1097,8 @@ def test__state2policy_returns_expected(env, state, expected): ], ) @pytest.mark.skip(reason="skip while developping other tests") -def test__statetorch2policy_returns_expected(env, states, expected): - assert torch.equal( - env.statetorch2policy(torch.tensor(states)), torch.tensor(expected) - ) +def test__states2policy_returns_expected(env, states, expected): + assert torch.equal(env.states2policy(torch.tensor(states)), torch.tensor(expected)) @pytest.mark.parametrize( From 388816cc08ceeafc12305216f3d23f06e756829a Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Tue, 31 Oct 2023 12:34:29 -0400 Subject: [PATCH 48/54] Fix --- gflownet/envs/htorus.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/gflownet/envs/htorus.py b/gflownet/envs/htorus.py index 461df04e5..df7199045 100644 --- a/gflownet/envs/htorus.py +++ b/gflownet/envs/htorus.py @@ -539,7 +539,9 @@ def sample_from_reward( ), axis=1, ) - rewards = self.reward_batch(samples) + rewards = tfloat( + self.reward_batch(samples), device=self.device, float_type=self.float + ) mask = ( torch.rand(n_samples, dtype=self.float, device=self.device) * (max_reward + epsilon) From 14e30c1d0ec39975d5648fe59ccd407928182315 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Tue, 31 Oct 2023 12:40:16 -0400 Subject: [PATCH 49/54] Delete files not relevant to branch --- gflownet/envs/seqs/amp.py | 33 --- gflownet/envs/seqs/aptamers.py | 30 -- gflownet/envs/seqs/sequence.py | 476 -------------------------------- tests/gflownet/envs/test_amp.py | 168 ----------- 4 files changed, 707 deletions(-) delete mode 100644 gflownet/envs/seqs/amp.py delete mode 100644 gflownet/envs/seqs/aptamers.py delete mode 100644 gflownet/envs/seqs/sequence.py delete mode 100644 tests/gflownet/envs/test_amp.py diff --git a/gflownet/envs/seqs/amp.py b/gflownet/envs/seqs/amp.py deleted file mode 100644 index 69f5e5ab3..000000000 --- a/gflownet/envs/seqs/amp.py +++ /dev/null @@ -1,33 +0,0 @@ -""" -Classes to represent aptamers environments -""" -from typing import List, Tuple -import itertools -import numpy as np -from gflownet.envs.base import GFlowNetEnv -import itertools -from polyleven import levenshtein -import numpy.typing as npt -from torchtyping import TensorType -import torch -import matplotlib.pyplot as plt -import torch.nn.functional as F -from gflownet.utils.sequence.amp import AMINO_ACIDS -from gflownet.envs.sequence import Sequence - - -class AMP(Sequence): - """ - Anti-microbial peptide sequence environment - """ - - def __init__( - self, - **kwargs, - ): - special_tokens = ["[PAD]", "[EOS]"] - self.vocab = AMINO_ACIDS + special_tokens - super().__init__( - **kwargs, - special_tokens=special_tokens, - ) diff --git a/gflownet/envs/seqs/aptamers.py b/gflownet/envs/seqs/aptamers.py deleted file mode 100644 index 5b50eedf1..000000000 --- a/gflownet/envs/seqs/aptamers.py +++ /dev/null @@ -1,30 +0,0 @@ -""" -Classes to represent aptamers environments -""" -import itertools -import time -from typing import List - -import numpy as np -import numpy.typing as npt -import pandas as pd -import time -from gflownet.utils.sequence.aptamers import NUCLEOTIDES -from gflownet.envs.sequence import Sequence - - -class Aptamers(Sequence): - """ - Aptamer sequence environment - """ - - def __init__( - self, - **kwargs, - ): - special_tokens = ["[PAD]", "[EOS]"] - self.vocab = NUCLEOTIDES + special_tokens - super().__init__( - **kwargs, - special_tokens=special_tokens, - ) diff --git a/gflownet/envs/seqs/sequence.py b/gflownet/envs/seqs/sequence.py deleted file mode 100644 index 23c23e57c..000000000 --- a/gflownet/envs/seqs/sequence.py +++ /dev/null @@ -1,476 +0,0 @@ -""" -Parent class to represent sequence-like environments, such as AMP and DNA. Sequences -are constructed by adding tokens from a dictionary. An alternative to this kind of -sequence environment (not-implemented as of July 2023) would be a "mutation-based" -modification of the sequences, or a combination of mutations and additions. -""" -import itertools -from typing import Iterable, List, Tuple - -import matplotlib.pyplot as plt -import numpy as np -import numpy.typing as npt -import torch -import torch.nn.functional as F -from polyleven import levenshtein -from torchtyping import TensorType - -from gflownet.envs.base import GFlowNetEnv - - -class Sequence(GFlowNetEnv): - """ - Parent of sequence environments. By default, for illustration purposes, this parent - class is functional and represents binary sequences of 0s and 1s, that can be - padded with the special token [PAD] and are terminated by the special token [EOS]. - - Attributes - ---------- - tokens : iterable - An iterable containing the vocabulary of tokens that make the sequences. - - max_length : int - Maximum length of the sequences. - - min_length : int - Minimum length of the sequences - - max_word_length : int - Maximum number of tokens allowed per action. - - min_word_length : int - Minimum number of tokens allowed per action. - - eos_token : int, str - EOS token. Default: -1. - - pad_token : int, str - PAD token. Default: -2. - """ - - def __init__( - self, - tokens: Iterable = [0, 1], - max_length: int = 10, - min_length: int = 1, - max_word_length: int = 1, - min_word_length: int = 1, - eos_token: Union[int, str] = -1, - pad_token: Union[int, str] = -2, - **kwargs, - ): - assert min_length > 0 - assert max_length > 0 - assert max_length >= min_length - assert min_word_length > 0 - assert max_word_length > 0 - assert max_word_length >= min_word_length - # Main attributes - self.tokens = set(tokens) - self.min_length = min_length - self.max_length = max_length - self.min_word_length = min_word_length - self.max_word_length = max_word_length - self.eos_idx = -1 - self.pad_idx = -2 - # Dictionaries - self.idx2token = {idx: token for idx, token in enumerate(self.tokens)} - self.idx2token[self.eos_idx] = eos_token - self.idx2token[self.pad_idx] = pad_token - self.token2idx = {token: idx for idx, token in self.idx2token.items()} - # Source state: vector of length max_length filled with pad token - self.source = torch.full( - self.max_length, self.pad_idx, dtype=torch.long, device=self.device - ) - # End-of-sequence action - self.eos = (self.eos_idx,) + (self.pad_idx,) * (self.max_word_length - 1) - # Base class init - super().__init__(**kwargs) - - def get_action_space(self): - """ - Constructs list with all possible actions, including eos. - - An action is represented by a vector of length max_word_length where each - element indicates the idex of the token to add to the sequence. Actions with a - number of tokens smaller than max_word_length are padded with pad_idx. - - Examples: - If min_word_length = 1 and max_word_length = 1: - actions: [(0,), (1,), (-1,)] - If min_word_length = 2 and max_word_length = 2: - actions: [(0, 0,), (0, 1), (1, 0), (1, 1), (-1, -2)] - If min_word_length = 1 and max_word_length = 2: - actions: [(0, -2), (1, -2), (0, 0,), (0, 1), (1, 0), (1, 1), (-1, -2)] - """ - valid_wordlens = np.arange(self.min_word_length, self.max_word_length + 1) - alphabet = [a for a in range(self.n_alphabet)] - actions = [] - for r in valid_wordlens: - actions_r = [el for el in itertools.product(alphabet, repeat=r)] - actions += actions_r - # Add "eos" action - # eos != n_alphabet in the init because it would break if max_word_length >1 - actions = actions + [(len(actions),)] - self.eos = len(actions) - 1 - return actions - - def copy(self): - return self.__class__(**self.__dict__) - # return deepcopy(self) - - def get_mask_invalid_actions_forward(self, state=None, done=None): - """ - Returns a vector of length the action space (where action space includes eos): True if action is invalid - given the current state, False otherwise. - """ - if state is None: - state = self.state.clone().detach() - if done is None: - done = self.done - if done: - return [True for _ in range(len(self.action_space))] - mask = [False for _ in range(len(self.action_space))] - seq_length = ( - torch.where(state == self.padding_idx)[0][0] - if state[-1] == self.padding_idx - else len(state) - ) - if seq_length < self.min_length: - mask[self.eos] = True - for idx, a in enumerate(self.action_space[:-1]): - if seq_length + len(list(a)) > self.max_length: - mask[idx] = True - return mask - - def true_density(self, max_states=1e6): - """ - Computes the reward density (reward / sum(rewards)) of the whole space, if the - dimensionality is smaller than specified in the arguments. - - Returns - ------- - Tuple: - - normalized reward for each state - - states - - (un-normalized) reward) - """ - if self._true_density is not None: - return self._true_density - if self.n_alphabet**self.max_length > max_states: - return (None, None, None) - state_all = np.int32( - list(itertools.product(*[list(range(self.n_alphabet))] * self.max_length)) - ) - traj_rewards, state_end = zip( - *[ - (self.proxy(state), state) - for state in state_all - if len(self.get_parents(state, False)[0]) > 0 or sum(state) == 0 - ] - ) - traj_rewards = np.array(traj_rewards) - self._true_density = ( - traj_rewards / traj_rewards.sum(), - list(map(tuple, state_end)), - traj_rewards, - ) - return self._true_density - - # def state2oracle(self, state: List = None): - # return "".join(self.state2readable(state)) - - def get_max_traj_length( - self, - ): - return self.max_length / self.min_word_length + 1 - - def statebatch2oracle(self, states: List[TensorType["max_length"]]) -> List[str]: - state_oracle = [] - for state in states: - if state[-1] == self.padding_idx: - state = state[: torch.where(state == self.padding_idx)[0][0]] - if self.tokenizer is not None and state[0] == self.tokenizer.bos_idx: - state = state[1:-1] - state_numpy = state.detach().cpu().numpy() - state_oracle.append(self.state2oracle(state_numpy)) - return state_oracle - - def statetorch2oracle( - self, states: TensorType["batch_dim", "max_length"] - ) -> List[str]: - return self.statebatch2oracle(states) - - # TODO: Deprecate as never used. - def state2policy(self, state=None): - """ - Transforms the sequence (state) given as argument (or self.state if None) into a - one-hot encoding. The output is a list of length nalphabet * max_length, - where each n-th successive block of nalphabet elements is a one-hot encoding of - the letter in the n-th position. - - Example: - - Sequence: AATGC - - state: [0, 1, 3, 2] - A, T, G, C - - state2obs(state): [1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0] - | A | T | G | C | - - If max_length > len(state), the last (max_length - len(state)) blocks are all - 0s. - """ - if state is None: - state = self.state.clone().detach() - state = ( - state[: torch.where(state == self.padding_idx)[0][0]] - if state[-1] == self.padding_idx - else state - ) - state_policy = torch.zeros(1, self.max_length, self.n_alphabet) - if len(state) == 0: - return state_policy.reshape(1, -1) - state_onehot = F.one_hot(state, num_classes=self.n_alphabet + 1)[:, :, 1:].to( - self.float - ) - state_policy[:, : state_onehot.shape[1], :] = state_onehot - return state_policy.reshape(state.shape[0], -1) - - def statebatch2policy( - self, states: List[TensorType["1", "max_length"]] - ) -> TensorType["batch", "policy_input_dim"]: - """ - Transforms a batch of states into the policy model format. The output is a numpy - array of shape [n_states, n_alphabet * max_seq_len]. - - See state2policy() - """ - state_tensor = torch.vstack(states) - state_policy = self.statetorch2policy(state_tensor) - return state_policy - - def statetorch2policy( - self, states: TensorType["batch", "max_length"] - ) -> TensorType["batch", "policy_input_dim"]: - if states.dtype != torch.long: - states = states.to(torch.long) - state_onehot = ( - F.one_hot(states, self.n_alphabet + 2)[:, :, :-2] - .to(self.float) - .to(self.device) - ) - state_padding_mask = (states != self.padding_idx).to(self.float).to(self.device) - state_onehot_pad = state_onehot * state_padding_mask.unsqueeze(-1) - # Assertion works as long as [PAD] is last key in lookup table. - assert torch.eq(state_onehot_pad, state_onehot).all() - state_policy = torch.zeros( - states.shape[0], - self.max_length, - self.n_alphabet, - device=self.device, - dtype=self.float, - ) - state_policy[:, : state_onehot.shape[1], :] = state_onehot - return state_policy.reshape(states.shape[0], -1) - - def policytorch2state(self, state_policy: List) -> List: - """ - Transforms the one-hot encoding version of a sequence (state) given as argument - into a a sequence of letter indices. - - Example: - - Sequence: AATGC - - state_policy: [1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0] - | A | A | T | G | C | - - policy2state(state_policy): [0, 0, 1, 3, 2] - A, A, T, G, C - """ - mat_state_policy = torch.reshape( - state_policy, (self.max_length, self.n_alphabet) - ) - state = torch.where(mat_state_policy)[1].tolist() - return state - - # TODO: Deprecate as never used. - def policy2state(self, state_policy: List) -> List: - """ - Transforms the one-hot encoding version of a sequence (state) given as argument - into a a sequence of letter indices. - - Example: - - Sequence: AATGC - - state_policy: [1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0] - | A | A | T | G | C | - - policy2state(state_policy): [0, 0, 1, 3, 2] - A, A, T, G, C - """ - mat_state_policy = np.reshape(state_policy, (self.max_length, self.n_alphabet)) - state = np.where(mat_state_policy)[1].tolist() - return state - - def state2oracle(self, state: List = None): - return "".join(self.state2readable(state)) - - def statebatch2oracle(self, states: List[TensorType["max_length"]]) -> List[str]: - state_oracle = [] - for state in states: - if state[-1] == self.padding_idx: - state = state[: torch.where(state == self.padding_idx)[0][0]] - if self.tokenizer is not None and state[0] == self.tokenizer.bos_idx: - state = state[1:-1] - state_numpy = state.detach().cpu().numpy() - state_oracle.append(self.state2oracle(state_numpy)) - return state_oracle - - def statetorch2oracle( - self, states: TensorType["batch_dim", "max_length"] - ) -> List[str]: - return self.statebatch2oracle(states) - - def state2readable(self, state: List) -> str: - """ - Transforms a sequence given as a list of indices into a sequence of letters - according to an alphabet. - Used only in Buffer - """ - if isinstance(state, torch.Tensor) == False: - state = torch.tensor(state).long() - if state[-1] == self.padding_idx: - state = state[: torch.where(state == self.padding_idx)[0][0]] - state = state.tolist() - return "".join([self.inverse_lookup[el] for el in state]) - - def statetorch2readable(self, state: TensorType["1", "max_length"]) -> str: - if state[-1] == self.padding_idx: - state = state[: torch.where(state == self.padding_idx)[0][0]] - # TODO: neater way without having lookup as input arg - if ( - self.lookup is not None - and "[CLS]" in self.lookup.keys() - and state[0] == self.lookup["[CLS]"] - ): - state = state[1:-1] - state = state.tolist() - readable = [self.inverse_lookup[el] for el in state] - return "".join(readable) - - def readable2state(self, readable) -> TensorType["batch_dim", "max_length"]: - """ - Transforms a list or string of letters into a list of indices according to an alphabet. - """ - if isinstance(readable, str): - encoded_readable = [self.lookup[el] for el in readable] - state = torch.ones(self.max_length, dtype=torch.int64) * self.padding_idx - state[: len(encoded_readable)] = torch.tensor(encoded_readable) - else: - encoded_readable = [[self.lookup[el] for el in seq] for seq in readable] - state = ( - torch.ones((len(readable), self.max_length), dtype=torch.int64) - * self.padding_idx - ) - for i, seq in enumerate(encoded_readable): - state[i, : len(seq)] = torch.tensor(seq) - return state - - def get_parents(self, state=None, done=None, action=None): - """ - Determines all parents and actions that lead to sequence state - - Args - ---- - state : list - Representation of a sequence (state), as a list of length max_length - where each element is the index of a letter in the alphabet, from 0 to - (nalphabet - 1). - - action : int - Last action performed, only to determine if it was eos. - - Returns - ------- - parents : list - List of parents as state2obs(state) - - actions : list - List of actions that lead to state for each parent in parents - """ - # TODO: Adapt to tuple form actions - if state is None: - state = self.state.clone().detach() - if done is None: - done = self.done - if done: - return [state], [(self.eos,)] - elif torch.eq(state, self.source).all(): - return [], [] - else: - parents = [] - actions = [] - if state[-1] == self.padding_idx: - state_last_element = int(torch.where(state == self.padding_idx)[0][0]) - else: - state_last_element = len(state) - max_parent_action_length = self.max_word_length + 1 - self.min_word_length - for parent_action_length in range(1, max_parent_action_length + 1): - parent_action = tuple( - state[ - state_last_element - parent_action_length : state_last_element - ].numpy() - ) - if parent_action in self.action_space: - parent = state.clone().detach() - parent[ - state_last_element - parent_action_length : state_last_element - ] = self.padding_idx - parents.append(parent) - actions.append(parent_action) - return parents, actions - - def step(self, action: Tuple[int]) -> Tuple[List[int], Tuple[int, int], bool]: - """ - Executes step given an action index - - If action_idx is smaller than eos (no stop), add action to next - position. - - Args - ---- - action_idx : int - Index of action in the action space. a == eos indicates "stop action" - - Returns - ------- - self.state : list - The sequence after executing the action - - valid : bool - False, if the action is not allowed for the current state, e.g. stop at the - root state - """ - assert action in self.action_space - # If only possible action is eos, then force eos - if self.state[-1] != self.padding_idx: - self.done = True - self.n_actions += 1 - return self.state, (self.eos,), True - # If action is not eos, then perform action - state_last_element = int(torch.where(self.state == self.padding_idx)[0][0]) - if action[0] != self.eos: - state_next = self.state.clone().detach() - if state_last_element + len(action) > self.max_length: - valid = False - else: - state_next[ - state_last_element : state_last_element + len(action) - ] = torch.LongTensor(action) - self.state = state_next - valid = True - self.n_actions += 1 - return self.state, action, valid - else: - if state_last_element < self.min_length: - valid = False - else: - self.done = True - valid = True - self.n_actions += 1 - return self.state, (self.eos,), valid diff --git a/tests/gflownet/envs/test_amp.py b/tests/gflownet/envs/test_amp.py deleted file mode 100644 index ed05f2201..000000000 --- a/tests/gflownet/envs/test_amp.py +++ /dev/null @@ -1,168 +0,0 @@ -import pytest -import torch -import numpy as np - -from gflownet.envs.amp import AMP - - -@pytest.fixture -def env(): - return AMP(proxy_state_format="state") - - -def test__environment__initializes_properly(): - env = AMP(proxy_state_format="state") - assert torch.eq( - env.source, torch.ones(env.max_seq_length, dtype=torch.int64) * env.padding_idx - ).all() - assert torch.eq( - env.state, torch.ones(env.max_seq_length, dtype=torch.int64) * env.padding_idx - ).all() - - -def test__environment__action_space_has_eos(): - env = AMP(proxy_state_format="state") - assert (env.eos,) in env.action_space - - -@pytest.mark.parametrize( - "state, expected_state_policy", - [ - ( - torch.tensor([[3, 2, 21, 21, 21]]), - [ - [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - ], - ), - ( - torch.tensor([[3, 2, 4, 2, 0]]), - [ - [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - ], - ), - ( - torch.tensor([[21, 21, 21, 21, 21]]), - [ - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - ], - ), - ], -) -def test_environment_policy_transformation(state, expected_state_policy): - env = AMP(proxy_state_format="state", max_seq_length=5) - expected_state_policy_tensor = torch.tensor( - expected_state_policy, dtype=env.float, device=env.device - ).reshape(state.shape[0], -1) - state_policy = env.statetorch2policy(state) - assert torch.eq(state_policy, expected_state_policy_tensor).all() - - -@pytest.mark.parametrize( - "state, done, expected_parent, expected_parent_action", - [ - ( - torch.tensor([3, 21, 21, 21, 21]), - False, - [torch.tensor([21, 21, 21, 21, 21])], - [(3,)], - ), - ( - torch.tensor([3, 2, 4, 2, 0]), - False, - [torch.tensor([3, 2, 4, 2, 21])], - [(0,)], - ), - ( - torch.tensor([3, 21, 21, 21, 21]), - False, - [torch.tensor([21, 21, 21, 21, 21])], - [(3,)], - ), - ( - torch.tensor([3, 21, 21, 21, 21]), - True, - [torch.tensor([3, 21, 21, 21, 21])], - [(20,)], - ), - ( - torch.tensor([21, 21, 21, 21, 21]), - False, - [], - [], - ), - ], -) -def test_environment_get_parents(state, done, expected_parent, expected_parent_action): - env = AMP(proxy_state_format="state", max_seq_length=5) - parent, parent_action = env.get_parents(state, done) - print(parent, parent_action) - if parent != []: - parent_tensor = torch.vstack(parent).to(env.device).to(env.float) - expected_parent_tensor = ( - torch.vstack(expected_parent).to(env.device).to(env.float) - ) - assert torch.eq(parent_tensor, expected_parent_tensor).all() - else: - assert parent == expected_parent - assert parent_action == expected_parent_action - - -@pytest.mark.parametrize( - "state, action, expected_next_state, expected_executed_action, expected_valid", - [ - ( - torch.tensor([3, 21, 21, 21, 21]), - (2,), - torch.tensor([3, 2, 21, 21, 21]), - (2,), - True, - ), - ( - torch.tensor([3, 2, 4, 2, 0]), - (2,), - torch.tensor([3, 2, 4, 2, 0]), - (20,), - True, - ), - ( - torch.tensor([21, 21, 21, 21, 21]), - (20,), - torch.tensor([21, 21, 21, 21, 21]), - (20,), - False, - ), - ( - torch.tensor([3, 21, 21, 21, 21]), - (20,), - torch.tensor([3, 21, 21, 21, 21]), - (20,), - True, - ), - ], -) -def test_environment_step( - state, action, expected_next_state, expected_executed_action, expected_valid -): - env = AMP(proxy_state_format="state", max_seq_length=5) - env.state = state - n_actions = env.n_actions - next_state, action_executed, valid = env.step(action) - if expected_executed_action == (20,) and expected_valid == True: - assert env.done == True - if expected_valid == True: - assert env.n_actions == n_actions + 1 - assert torch.eq(next_state, expected_next_state).all() - assert action_executed == expected_executed_action - assert valid == expected_valid From ce38e3fee66f095e7826f52d731500f5a42dbc8e Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Tue, 31 Oct 2023 12:46:02 -0400 Subject: [PATCH 50/54] Delete files not relevant to branch --- config/env/seqs/amp.yaml | 18 ------------------ config/env/seqs/aptamers.yaml | 15 --------------- 2 files changed, 33 deletions(-) delete mode 100644 config/env/seqs/amp.yaml delete mode 100644 config/env/seqs/aptamers.yaml diff --git a/config/env/seqs/amp.yaml b/config/env/seqs/amp.yaml deleted file mode 100644 index b4b70bbb8..000000000 --- a/config/env/seqs/amp.yaml +++ /dev/null @@ -1,18 +0,0 @@ -defaults: - - base - -_target_: gflownet.envs.amp.AMP - -id: amp -# Minimum and maximum length for the sequences -min_seq_length: 1 -max_seq_length: 50 -# Number of letters in alphabet -n_alphabet: 20 -# Minimum and maximum number of steps in the action space -min_word_len: 1 -max_word_len: 1 -reward_func: power -reward_norm_std_mult: -1.0 -reward_norm: 0.1 -reward_beta: 8.0 \ No newline at end of file diff --git a/config/env/seqs/aptamers.yaml b/config/env/seqs/aptamers.yaml deleted file mode 100644 index 210df4f83..000000000 --- a/config/env/seqs/aptamers.yaml +++ /dev/null @@ -1,15 +0,0 @@ -defaults: - - base - -_target_: gflownet.envs.aptamers.Aptamers - -id: aptamers -# Minimum and maximum length for the sequences -min_seq_length: 30 -max_seq_length: 30 -# Number of letters in alphabet -n_alphabet: 4 -# Minimum and maximum number of steps in the action space -min_word_len: 1 -max_word_len: 1 -corr_type: None From 88d3fa6dea4eff8d3ab27e37f98727bee6a08305 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Tue, 31 Oct 2023 12:46:30 -0400 Subject: [PATCH 51/54] Delete files not relevant to branch --- config/proxy/seqs/amp.yaml | 9 --------- config/proxy/seqs/aptamers.yaml | 6 ------ 2 files changed, 15 deletions(-) delete mode 100644 config/proxy/seqs/amp.yaml delete mode 100644 config/proxy/seqs/aptamers.yaml diff --git a/config/proxy/seqs/amp.yaml b/config/proxy/seqs/amp.yaml deleted file mode 100644 index ec4063f47..000000000 --- a/config/proxy/seqs/amp.yaml +++ /dev/null @@ -1,9 +0,0 @@ -_target_: gflownet.proxy.amp.AMPOracleWrapper - -oracle_split: "D2_target" -oracle_type: "MLP" -oracle_features: "AlBert" -dist_fn: "edit" -medoid_oracle_norm: 1 -maximize: True -cost: 1 diff --git a/config/proxy/seqs/aptamers.yaml b/config/proxy/seqs/aptamers.yaml deleted file mode 100644 index 5ebce2d19..000000000 --- a/config/proxy/seqs/aptamers.yaml +++ /dev/null @@ -1,6 +0,0 @@ -_target_: gflownet.proxy.aptamers.Aptamers - -oracle_id: "energy" -norm: False -cost: 4 -maximize: False \ No newline at end of file From f0e3feb625cdc9f3f534d686ef1d898d9d4e6ee3 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Tue, 31 Oct 2023 12:52:22 -0400 Subject: [PATCH 52/54] Delete files not relevant to branch --- gflownet/proxy/aptamers.py | 81 -------------------------------------- 1 file changed, 81 deletions(-) delete mode 100644 gflownet/proxy/aptamers.py diff --git a/gflownet/proxy/aptamers.py b/gflownet/proxy/aptamers.py deleted file mode 100644 index 34e172510..000000000 --- a/gflownet/proxy/aptamers.py +++ /dev/null @@ -1,81 +0,0 @@ -import numpy as np -import numpy.typing as npt -from nupack import * -import torch - -from gflownet.proxy.base import Proxy - - -class Aptamers(Proxy): - """ - DNA Aptamer oracles - """ - - def __init__(self, oracle_id, norm, cost, **kwargs): - super().__init__(**kwargs) - self.type = oracle_id - self.norm = norm - self.cost = cost - - def setup(self, env, norm=True): - self.max_seq_length = env.max_seq_length - - def _length(self, x): - if self.norm: - return -1.0 * np.sum(x, axis=1) / self.max_seq_length - else: - return -1.0 * np.sum(x, axis=1) - - def __call__(self, sequences): - if self.type == "length": - return self._length(sequences) - elif self.type == "pairs": - self.function = self._func_pairs - return self._nupack(sequences) - elif self.type == "energy": - self.function = self._func_energy - return self._nupack(sequences) - else: - raise NotImplementedError - - def _nupack(self, sequences): - """ - args: - inputs: list of arrays in desired format interpretable by oracle - returns: - array of scores - function: - creates the complex set and calls the desired nupack function - """ - temperature = 310.0 # Kelvin - ionicStrength = 1.0 # molar - strandList = [] - comps = [] - i = -1 - for sequence in sequences: - i += 1 - strandList.append(Strand(sequence, name="strand{}".format(i))) - comps.append(Complex([strandList[-1]], name="comp{}".format(i))) - - set = ComplexSet( - strands=strandList, complexes=SetSpec(max_size=1, include=comps) - ) - model1 = Model(material="dna", celsius=temperature - 273, sodium=ionicStrength) - results = complex_analysis(set, model=model1, compute=["mfe"]) - - energy = self.function(sequences, results, comps) - - return torch.tensor(energy, device=self.device, dtype=self.float) - - def _func_energy(self, sequences, results, comps): - energies = np.zeros(len(sequences)) - for i in range(len(energies)): - energies[i] = results[comps[i]].mfe[0].energy - return energies - - def _func_pairs(self, sequences, results, comps): - ssStrings = np.zeros(len(sequences), dtype=object) - for i in range(len(ssStrings)): - ssStrings[i] = str(results[comps[i]].mfe[0].structure) - nPairs = np.asarray([ssString.count("(") for ssString in ssStrings]).astype(int) - return -nPairs From cc634cf15135cf2decd51efac8b1e2a9eada7b1c Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Tue, 14 Nov 2023 12:54:08 -0500 Subject: [PATCH 53/54] Fix tensor comparison --- tests/gflownet/envs/test_grid.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/gflownet/envs/test_grid.py b/tests/gflownet/envs/test_grid.py index 27ebbadd5..afa375649 100644 --- a/tests/gflownet/envs/test_grid.py +++ b/tests/gflownet/envs/test_grid.py @@ -3,6 +3,7 @@ import torch from gflownet.envs.grid import Grid +from gflownet.utils.common import tfloat @pytest.fixture @@ -66,7 +67,10 @@ def config_path(): ], ) def test__state2proxy__returns_expected(env, state, state2proxy): - assert state2proxy == env.state2proxy(state) + assert torch.equal( + tfloat(state2proxy, device=env.device, float_type=env.float), + env.state2proxy(state), + ) @pytest.mark.parametrize( From 11d5fc4248eb780638756c53e7bbb4d49e0cac12 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Tue, 14 Nov 2023 12:54:34 -0500 Subject: [PATCH 54/54] Add typing in returns of state2policy and state2proxy of base env. --- gflownet/envs/base.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/gflownet/envs/base.py b/gflownet/envs/base.py index 2620bb3b0..3f8d535b3 100644 --- a/gflownet/envs/base.py +++ b/gflownet/envs/base.py @@ -683,7 +683,9 @@ def states2proxy( """ return tfloat(states, device=self.device, float_type=self.float) - def state2proxy(self, state: Union[List, TensorType["state_dim"]] = None): + def state2proxy( + self, state: Union[List, TensorType["state_dim"]] = None + ) -> TensorType["state_proxy_dim"]: """ Prepares a state in "GFlowNet format" for the proxy. By default, states2proxy is called, which by default will return the state as is. @@ -716,7 +718,9 @@ def states2policy( """ return tfloat(states, device=self.device, float_type=self.float) - def state2policy(self, state: Union[List, TensorType["state_dim"]] = None): + def state2policy( + self, state: Union[List, TensorType["state_dim"]] = None + ) -> TensorType["policy_input_dim"]: """ Prepares a state in "GFlowNet format" for the policy model. By default, states2policy is called, which by default will return the state as is.