diff --git a/README.md b/README.md index 796f8dc3b..61d306145 100644 --- a/README.md +++ b/README.md @@ -1,39 +1,11 @@ -# GFlowNet +# Private sister repository of gflownet -This repository implements GFlowNets, generative flow networks for probabilistic modelling, on PyTorch. A design guideline behind this implementation is the separation of the logic of the GFlowNet agent and the environments on which the agent can be trained on. In other words, this implementation should allow its extension with new environments without major or any changes to to the agent. Another design guideline is flexibility and modularity. The configuration is handled via the use of [Hydra](https://hydra.cc/docs/intro/). +This repository (`gflownet-dev`) is private. It is meant to be used to develop research ideas and projects before making them public in the original [alexhernandezgarcia/gflownet](https://github.com/alexhernandezgarcia/gflownet) repository (`gflownet`). -## Installation +As of October 2023, it is uncertain whether we will stick to this plan in the long term, but the idea is the following: -### pip +- Develop ideas and projects in `gflownet-dev`. +- Upon publication or whenever the authors feel comfortable, transfer the relevant code to `gflownet`. +- Relevant code improvements and development that does not compromise research projects should be transferred to `gflownet` as early as possible. -```bash -python -m pip install --upgrade https://github.com/alexhernandezgarcia/gflownet/archive/main.zip -``` - -## How to train a GFlowNet model - -To train a GFlowNet model with the default configuration, simply run - -```bash -python main.py user.logdir.root= -``` - -Alternatively, you can create a user configuration file in `config/user/.yaml` specifying a `logdir.root` and run - -```bash -python main.py user= -``` - -Using Hydra, you can easily specify any variable of the configuration in the command line. For example, to train GFlowNet with the trajectory balance loss, on the continuous torus (`ctorus`) environment and the corresponding proxy: - -```bash -python main.py gflownet=trajectorybalance env=ctorus proxy=torus -``` - -The above command will overwrite the `env` and `proxy` default configuration with the configuration files in `config/env/ctorus.yaml` and `config/proxy/torus.yaml` respectively. - -Hydra configuration is hierarchical. For instance, a handy variable to change while debugging our code is to avoid logging to wandb. You can do this by setting `logger.do.online=False`. - -## Logging to wandb - -The repository supports logging of train and evaluation metrics to [wandb.ai](https://wandb.ai), but it is disabled by default. In order to enable it, set the configuration variable `logger.do.online` to `True`. +This involves extra complexity, so we will re-evaluate or refine this plan after a test period. 
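The remainder of this changeset consolidates the per-format conversion methods (`state2proxy`, `statebatch2proxy`, `statetorch2proxy` and their `*2oracle` / `*2policy` counterparts) into single `states2proxy` and `states2policy` methods that accept either a list of states or a tensor. The sketch below only illustrates that calling pattern under simplified assumptions: `ToyGridEnv`, its dimensions and the particular encodings are hypothetical stand-ins, not code from the repository.

```python
from typing import List, Union

import torch
from torch import Tensor


class ToyGridEnv:
    """Hypothetical 2-D grid (4 cells per dimension) used only to illustrate
    the unified states2proxy / states2policy calling pattern."""

    def __init__(self, n_dim: int = 2, length: int = 4):
        self.n_dim = n_dim
        self.length = length
        self.device = torch.device("cpu")
        self.float = torch.float32

    def states2proxy(self, states: Union[List[List[int]], Tensor]) -> Tensor:
        # A single method accepts either a list of states or a tensor.
        states = torch.as_tensor(states, dtype=self.float, device=self.device)
        # Map grid indices in [0, length - 1] to values in [-1.0, 1.0].
        return 2.0 * states / (self.length - 1) - 1.0

    def states2policy(self, states: Union[List[List[int]], Tensor]) -> Tensor:
        states = torch.as_tensor(states, dtype=torch.long, device=self.device)
        n_states = states.shape[0]
        # One-hot encode the position in each dimension.
        cols = states + torch.arange(self.n_dim) * self.length
        rows = torch.repeat_interleave(torch.arange(n_states), self.n_dim)
        policy = torch.zeros((n_states, self.n_dim * self.length), dtype=self.float)
        policy[rows, cols.flatten()] = 1.0
        return policy


env = ToyGridEnv()
batch = [[0, 3], [1, 2]]  # a batch in "environment format" (list of states)
print(env.states2proxy(batch))                 # works with a list ...
print(env.states2policy(torch.tensor(batch)))  # ... or with a tensor
```

Normalizing the input once at the top of each method (with a `tfloat`/`tlong`-style conversion) is what lets the new interface replace the three near-duplicate conversion paths that each environment previously maintained.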
diff --git a/config/env/aptamers.yaml b/config/env/aptamers.yaml deleted file mode 100644 index 5de209fa3..000000000 --- a/config/env/aptamers.yaml +++ /dev/null @@ -1,16 +0,0 @@ -defaults: - - base - -_target_: gflownet.envs.aptamers.AptamerSeq - -id: aptamers -func: nupack energy -# Minimum and maximum length for the sequences -min_seq_length: 30 -max_seq_length: 30 -# Number of letters in alphabet -n_alphabet: 4 -# Minimum and maximum number of steps in the action space -min_word_len: 1 -max_word_len: 1 - diff --git a/gflownet/envs/alaninedipeptide.py b/gflownet/envs/alaninedipeptide.py index 76b725e3b..04b8e39b8 100644 --- a/gflownet/envs/alaninedipeptide.py +++ b/gflownet/envs/alaninedipeptide.py @@ -1,5 +1,5 @@ from copy import deepcopy -from typing import List, Tuple +from typing import List, Tuple, Union import numpy as np import numpy.typing as npt @@ -40,25 +40,34 @@ def sync_conformer_with_state(self, state: List = None): self.conformer.set_torsion_angle(ta, state[idx]) return self.conformer - def statetorch2proxy(self, states: TensorType["batch", "state_dim"]) -> npt.NDArray: + # TODO: are the conversions to oracle relevant? + def states2proxy( + self, states: Union[List[List], TensorType["batch", "state_dim"]] + ) -> npt.NDArray: """ - Prepares a batch of states in torch "GFlowNet format" for the oracle. - """ - device = states.device - if device == torch.device("cpu"): - np_states = states.numpy() - else: - np_states = states.cpu().numpy() - return np_states[:, :-1] - - def statebatch2proxy(self, states: List[List]) -> npt.NDArray: - """ - Prepares a batch of states in "GFlowNet format" for the proxy: a tensor where - each state is a row of length n_dim with an angle in radians. The n_actions + Prepares a batch of states in "environment format" for the proxy: each state is + a vector of length n_dim where each value is an angle in radians. The n_actions item is removed. + + Important: this method returns a numpy array, unlike in most other + environments. + + Args + ---- + states : list or tensor + A batch of states in environment format, either as a list of states or as a + single tensor. + + Returns + ------- + A numpy array containing all the states in the batch. """ - return np.array(states)[:, :-1] + if torch.is_tensor(states[0]): + return states.cpu().numpy()[:, :-1] + else: + return np.array(states)[:, :-1] + # TODO: need to keep? def statetorch2oracle( self, states: TensorType["batch", "state_dim"] ) -> List[Tuple[npt.NDArray, npt.NDArray]]: @@ -73,6 +82,7 @@ def statetorch2oracle( result = self.statebatch2oracle(np_states) return result + # TODO: need to keep? 
def statebatch2oracle( self, states: List[List] ) -> List[Tuple[npt.NDArray, npt.NDArray]]: diff --git a/gflownet/envs/aptamers.py b/gflownet/envs/aptamers.py deleted file mode 100644 index 425a2eb91..000000000 --- a/gflownet/envs/aptamers.py +++ /dev/null @@ -1,431 +0,0 @@ -""" -Classes to represent aptamers environments -""" -import itertools -import time -from typing import List - -import numpy as np -import numpy.typing as npt -import pandas as pd - -from gflownet.envs.base import GFlowNetEnv - - -class AptamerSeq(GFlowNetEnv): - """ - Aptamer sequence environment - - Attributes - ---------- - max_seq_length : int - Maximum length of the sequences - - min_seq_length : int - Minimum length of the sequences - - n_alphabet : int - Number of letters in the alphabet - - state : list - Representation of a sequence (state), as a list of length max_seq_length where - each element is the index of a letter in the alphabet, from 0 to (n_alphabet - - 1). - - done : bool - True if the sequence has reached a terminal state (maximum length, or stop - action executed. - - func : str - Name of the reward function - - n_actions : int - Number of actions applied to the sequence - - proxy : lambda - Proxy model - """ - - def __init__( - self, - max_seq_length=42, - min_seq_length=1, - n_alphabet=4, - min_word_len=1, - max_word_len=1, - **kwargs, - ): - super().__init__() - self.source = [] - self.min_seq_length = min_seq_length - self.max_seq_length = max_seq_length - self.n_alphabet = n_alphabet - self.min_word_len = min_word_len - self.max_word_len = max_word_len - self.action_space = self.get_action_space() - self.eos = self.action_space_dim - self.reset() - self.fixed_policy_output = self.get_fixed_policy_output() - self.random_policy_output = self.get_fixed_policy_output() - self.policy_output_dim = len(self.fixed_policy_output) - self.policy_input_dim = len(self.state2policy()) - self.max_traj_len = self.get_max_traj_length() - # Set up proxy - self.setup_proxy() - - def get_action_space(self): - """ - Constructs list with all possible actions - """ - assert self.max_word_len >= self.min_word_len - valid_wordlens = np.arange(self.min_word_len, self.max_word_len + 1) - alphabet = [a for a in range(self.n_alphabet)] - actions = [] - for r in valid_wordlens: - actions_r = [el for el in itertools.product(alphabet, repeat=r)] - actions += actions_r - return actions - - def get_max_traj_length( - self, - ): - return self.max_seq_length / self.min_word_len + 1 - - def reward_arbitrary_i(self, state): - if len(state) > 0: - return (state[-1] + 1) * len(state) - else: - return 0 - - def statebatch2oracle(self, states: List[List]): - """ - Prepares a batch of sequence states for the oracles. - - Args - ---- - states : list of lists - List of sequences. - """ - queries = [s + [-1] * (self.max_seq_length - len(s)) for s in states] - queries = np.array(queries, dtype=int) - if queries.ndim == 1: - queries = queries[np.newaxis, ...] - queries += 1 - if queries.shape[1] == 1: - import ipdb - - ipdb.set_trace() - queries = np.column_stack((queries, np.zeros(queries.shape[0]))) - return queries - - def state2policy(self, state: List = None) -> List: - """ - Transforms the sequence (state) given as argument (or self.state if None) into a - one-hot encoding. The output is a list of length n_alphabet * max_seq_length, - where each n-th successive block of n_alphabet elements is a one-hot encoding of - the letter in the n-th position. 
- - Example: - - Sequence: AATGC - - state: [0, 1, 3, 2] - A, T, G, C - - state2policy(state): [1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0] - | A | T | G | C | - - If max_seq_length > len(s), the last (max_seq_length - len(s)) blocks are all - 0s. - """ - if state is None: - state = self.state.copy() - state_policy = np.zeros(self.n_alphabet * self.max_seq_length, dtype=np.float32) - if len(state) > 0: - state_policy[(np.arange(len(state)) * self.n_alphabet + state)] = 1 - return state_policy - - def statebatch2policy(self, states: List[List]) -> npt.NDArray[np.float32]: - """ - Transforms a batch of states into the policy model format. The output is a numpy - array of shape [n_states, n_angles * n_dim + 1]. - - See state2policy(). - """ - cols, lengths = zip( - *[ - (state + np.arange(len(state)) * self.n_alphabet, len(state)) - for state in states - ] - ) - rows = np.repeat(np.arange(len(states)), lengths) - state_policy = np.zeros( - (len(states), self.n_alphabet * self.max_seq_length), dtype=np.float32 - ) - state_policy[rows, np.concatenate(cols)] = 1.0 - return state_policy - - def policy2state(self, state_policy: List) -> List: - """ - Transforms the one-hot encoding version of a sequence (state) given as argument - into a a sequence of letter indices. - - Example: - - Sequence: AATGC - - state_policy: [1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0] - | A | A | T | G | C | - - state: [0, 0, 1, 3, 2] - A, A, T, G, C - """ - return np.where( - np.reshape(state_policy, (self.max_seq_length, self.n_alphabet)) - )[1].tolist() - - def state2readable(self, state, alphabet={0: "A", 1: "T", 2: "C", 3: "G"}): - """ - Transforms a sequence given as a list of indices into a sequence of letters - according to an alphabet. - """ - return [alphabet[el] for el in state] - - def readable2state(self, letters, alphabet={0: "A", 1: "T", 2: "C", 3: "G"}): - """ - Transforms a sequence given as a list of indices into a sequence of letters - according to an alphabet. - """ - alphabet = {v: k for k, v in alphabet.items()} - return [alphabet[el] for el in letters] - - def reset(self, env_id=None): - """ - Resets the environment. - """ - self.state = [] - self.n_actions = 0 - self.done = False - self.id = env_id - return self - - def get_parents(self, state=None, done=None, action=None): - """ - Determines all parents and actions that lead to sequence state - - Args - ---- - state : list - Representation of a sequence (state), as a list of length max_seq_length - where each element is the index of a letter in the alphabet, from 0 to - (n_alphabet - 1). - - done : bool - Whether the trajectory is done. If None, done is taken from instance. - - action : None - Ignored - - Returns - ------- - parents : list - List of parents in state format - - actions : list - List of actions that lead to state for each parent in parents - """ - if state is None: - state = self.state.copy() - if done is None: - done = self.done - if done: - return [state], [self.eos] - else: - parents = [] - actions = [] - for idx, a in enumerate(self.action_space): - is_parent = state[-len(a) :] == list(a) - if not isinstance(is_parent, bool): - is_parent = all(is_parent) - if is_parent: - parents.append(state[: -len(a)]) - actions.append(idx) - return parents, actions - - def step(self, action_idx): - """ - Executes step given an action index - - If action_idx is smaller than eos (no stop), add action to next - position. 
- - See: step_daug() - See: step_chain() - - Args - ---- - action_idx : int - Index of action in the action space. a == eos indicates "stop action" - - Returns - ------- - self.state : list - The sequence after executing the action - - valid : bool - False, if the action is not allowed for the current state, e.g. stop at the - root state - """ - # If only possible action is eos, then force eos - if len(self.state) == self.max_seq_length: - self.done = True - self.n_actions += 1 - return self.state, [self.eos], True - # If action is not eos, then perform action - if action_idx != self.eos: - action = self.action_space[action_idx] - state_next = self.state + list(action) - if len(state_next) > self.max_seq_length: - valid = False - else: - self.state = state_next - valid = True - self.n_actions += 1 - return self.state, action_idx, valid - # If action is eos, then perform eos - else: - if len(self.state) < self.min_seq_length: - valid = False - else: - self.done = True - valid = True - self.n_actions += 1 - return self.state, self.eos, valid - - def get_mask_invalid_actions_forward(self, state=None, done=None): - """ - Returns a vector of length the action space + 1: True if action is invalid - given the current state, False otherwise. - """ - if state is None: - state = self.state.copy() - if done is None: - done = self.done - if done: - return [True for _ in range(self.action_space_dim + 1)] - mask = [False for _ in range(self.action_space_dim + 1)] - seq_length = len(state) - if seq_length < self.min_seq_length: - mask[self.eos] = True - for idx, a in enumerate(self.action_space): - if seq_length + len(a) > self.max_seq_length: - mask[idx] = True - return mask - - def make_train_set( - self, - ntrain, - oracle=None, - seed=168, - output_csv=None, - ): - """ - Constructs a randomly sampled train set. - - Args - ---- - ntest : int - Number of test samples. - - seed : int - Random seed. - - output_csv: str - Optional path to store the test set as CSV. - """ - samples_dict = oracle.initializeDataset( - save=False, returnData=True, customSize=ntrain, custom_seed=seed - ) - energies = samples_dict["energies"] - samples_mat = samples_dict["samples"] - state_letters = oracle.numbers2letters(samples_mat) - state_ints = [ - "".join([str(el) for el in state if el > 0]) for state in samples_mat - ] - if isinstance(energies, dict): - energies.update({"samples": state_letters, "indices": state_ints}) - df_train = pd.DataFrame(energies) - else: - df_train = pd.DataFrame( - {"samples": state_letters, "indices": state_ints, "energies": energies} - ) - if output_csv: - df_train.to_csv(output_csv) - return df_train - - # TODO: improve approximation of uniform - def make_test_set( - self, - path_base_dataset, - ntest, - min_length=0, - max_length=np.inf, - seed=167, - output_csv=None, - ): - """ - Constructs an approximately uniformly distributed (on the score) set, by - selecting samples from a larger base set. - - Args - ---- - path_base_dataset : str - Path to a CSV file containing the base data set. - - ntest : int - Number of test samples. - - seed : int - Random seed. - - dask : bool - If True, use dask to efficiently read a large base file. - - output_csv: str - Optional path to store the test set as CSV. 
- """ - if path_base_dataset is None: - return None, None - times = { - "all": 0.0, - "indices": 0.0, - } - t0_all = time.time() - if seed: - np.random.seed(seed) - df_base = pd.read_csv(path_base_dataset, index_col=0) - df_base = df_base.loc[ - (df_base["samples"].map(len) >= min_length) - & (df_base["samples"].map(len) <= max_length) - ] - energies_base = df_base["energies"].values - min_base = energies_base.min() - max_base = energies_base.max() - distr_unif = np.random.uniform(low=min_base, high=max_base, size=ntest) - # Get minimum distance samples without duplicates - t0_indices = time.time() - idx_samples = [] - for idx in tqdm(range(ntest)): - dist = np.abs(energies_base - distr_unif[idx]) - idx_min = np.argmin(dist) - if idx_min in idx_samples: - idx_sort = np.argsort(dist) - for idx_next in idx_sort: - if idx_next not in idx_samples: - idx_samples.append(idx_next) - break - else: - idx_samples.append(idx_min) - t1_indices = time.time() - times["indices"] += t1_indices - t0_indices - # Make test set - df_test = df_base.iloc[idx_samples] - if output_csv: - df_test.to_csv(output_csv) - t1_all = time.time() - times["all"] += t1_all - t0_all - return df_test, times diff --git a/gflownet/envs/base.py b/gflownet/envs/base.py index 278490adc..ab68825a3 100644 --- a/gflownet/envs/base.py +++ b/gflownet/envs/base.py @@ -38,7 +38,6 @@ def __init__( energies_stats: List[int] = None, denorm_proxy: bool = False, proxy=None, - oracle=None, proxy_state_format: str = "oracle", fixed_distr_params: Optional[dict] = None, random_distr_params: Optional[dict] = None, @@ -68,17 +67,10 @@ def __init__( self.reward_func = reward_func self.energies_stats = energies_stats self.denorm_proxy = denorm_proxy - # Proxy and oracle + # Proxy self.proxy = proxy self.setup_proxy() - if oracle is None: - self.oracle = self.proxy - else: - self.oracle = oracle - if self.oracle is None or self.oracle.higher_is_better: - self.proxy_factor = 1.0 - else: - self.proxy_factor = -1.0 + self.proxy_factor = -1.0 self.proxy_state_format = proxy_state_format # Flag to skip checking if action is valid (computing mask) before step self.skip_mask_check = skip_mask_check @@ -100,9 +92,6 @@ def __init__( self.random_policy_output = self.get_policy_output(self.random_distr_params) self.policy_output_dim = len(self.fixed_policy_output) self.policy_input_dim = len(self.state2policy()) - if proxy is not None and self.proxy == self.oracle: - self.statebatch2proxy = self.statebatch2oracle - self.statetorch2proxy = self.statetorch2oracle @abstractmethod def get_action_space(self): @@ -683,91 +672,75 @@ def get_policy_output( """ return torch.ones(self.action_space_dim, dtype=self.float, device=self.device) - def state2proxy(self, state: List = None): + def states2proxy( + self, states: Union[List[List], TensorType["batch", "state_dim"]] + ) -> TensorType["batch", "state_proxy_dim"]: """ - Prepares a state in "GFlowNet format" for the proxy. + Prepares a batch of states in "environment format" for the proxy. By default, + the batch of states is converted into a tensor with float dtype and returned as + is. Args ---- - state : list - A state + states : list or tensor + A batch of states in environment format, either as a list of states or as a + single tensor. + + Returns + ------- + A tensor containing all the states in the batch. 
""" - if state is None: - state = self.state.copy() - return self.statebatch2proxy([state]) + return tfloat(states, device=self.device, float_type=self.float) - def statebatch2proxy(self, states: List[List]) -> npt.NDArray[np.float32]: + def state2proxy( + self, state: Union[List, TensorType["state_dim"]] = None + ) -> TensorType["state_proxy_dim"]: """ - Prepares a batch of states in "GFlowNet format" for the proxy. + Prepares a state in "GFlowNet format" for the proxy. By default, states2proxy + is called, which by default will return the state as is. Args ---- state : list A state """ - return np.array(states) - - def statetorch2proxy( - self, states: TensorType["batch_size", "state_dim"] - ) -> TensorType["batch_size", "state_proxy_dim"]: - """ - Prepares a batch of states in torch "GFlowNet format" for the proxy. - """ - return states + state = self._get_state(state) + return torch.squeeze(self.states2proxy([state]), dim=0) - def state2oracle(self, state: List = None): + def states2policy( + self, states: Union[List, TensorType["batch", "state_dim"]] + ) -> TensorType["batch", "policy_input_dim"]: """ - Prepares a state in "GFlowNet format" for the oracle. + Prepares a batch of states in "environment format" for the policy model: By + default, the batch of states is converted into a tensor with float dtype and + returned as is. Args ---- - state : list - A state - """ - if state is None: - state = self.state.copy() - return state - - def statebatch2oracle(self, states: List[List]): - """ - Prepares a batch of states in "GFlowNet format" for the oracles - """ - return states + states : list or tensor + A batch of states in environment format, either as a list of states or as a + single tensor. - def statetorch2policy( - self, states: TensorType["batch_size", "state_dim"] - ) -> TensorType["batch_size", "policy_input_dim"]: - """ - Prepares a batch of states in torch "GFlowNet format" for the policy - """ - return states - - def state2policy(self, state=None): - """ - Converts a state into a format suitable for a machine learning model, such as a - one-hot encoding. + Returns + ------- + A tensor containing all the states in the batch. """ - if state is None: - state = self.state - return tfloat(state, float_type=self.float, device=self.device) + return tfloat(states, device=self.device, float_type=self.float) - def statebatch2policy( - self, states: List[List] - ) -> TensorType["batch_size", "policy_input_dim"]: + def state2policy( + self, state: Union[List, TensorType["state_dim"]] = None + ) -> TensorType["policy_input_dim"]: """ - Converts a batch of states into a format suitable for a machine learning model, - such as a one-hot encoding. Returns a numpy array. - """ - return self.statetorch2policy( - tfloat(states, float_type=self.float, device=self.device) - ) + Prepares a state in "GFlowNet format" for the policy model. By default, + states2policy is called, which by default will return the state as is. - def policy2state(self, state_policy: List) -> List: - """ - Converts the model (e.g. one-hot encoding) version of a state given as - argument into a state. 
+ Args + ---- + state : list + A state """ - return state_policy + state = self._get_state(state) + return torch.squeeze(self.states2policy([state]), dim=0) def state2readable(self, state=None): """ @@ -797,15 +770,18 @@ def reward(self, state=None, done=None): done = self._get_done(done) if done is False: return tfloat(0.0, float_type=self.float, device=self.device) - return self.proxy2reward(self.proxy(self.state2proxy(state))[0]) + return self.proxy2reward( + self.proxy(torch.unsqueeze(self.state2proxy(state), dim=0))[0] + ) + # TODO: cleanup def reward_batch(self, states: List[List], done=None): """ Computes the rewards of a batch of states, given a list of states and 'dones' """ if done is None: done = np.ones(len(states), dtype=bool) - states_proxy = self.statebatch2proxy(states) + states_proxy = self.states2proxy(states) if isinstance(states_proxy, torch.Tensor): states_proxy = states_proxy[list(done), :] elif isinstance(states_proxy, list): @@ -815,27 +791,11 @@ def reward_batch(self, states: List[List], done=None): rewards[list(done)] = self.proxy2reward(self.proxy(states_proxy)).tolist() return rewards - def reward_torchbatch( - self, - states: TensorType["batch_size", "state_dim"], - done: TensorType["batch_size"] = None, - ): - """ - Computes the rewards of a batch of states in "GFlownet format" - """ - if done is None: - done = torch.ones(states.shape[0], dtype=torch.bool, device=self.device) - states_proxy = self.statetorch2proxy(states[done, :]) - reward = torch.zeros(done.shape[0], dtype=self.float, device=self.device) - if states[done, :].shape[0] > 0: - reward[done] = self.proxy2reward(self.proxy(states_proxy)) - return reward - def proxy2reward(self, proxy_vals): """ - Prepares the output of an oracle for GFlowNet: the inputs proxy_vals is - expected to be a negative value (energy), unless self.denorm_proxy is True. If - the latter, the proxy values are first de-normalized according to the mean and + Prepares the output of a proxy for GFlowNet: the inputs proxy_vals is expected + to be a negative value (energy), unless self.denorm_proxy is True. If the + latter, the proxy values are first de-normalized according to the mean and standard deviation in self.energies_stats. The output of the function is a strictly positive reward - provided self.reward_norm and self.reward_beta are positive - and larger than self.min_reward. @@ -878,7 +838,7 @@ def proxy2reward(self, proxy_vals): def reward2proxy(self, reward): """ Converts a "GFlowNet reward" into a (negative) energy or values as returned by - an oracle. + a proxy. 
""" if self.reward_func == "power": return self.proxy_factor * torch.exp( @@ -1359,7 +1319,7 @@ def top_k_metrics_and_plots( return metrics, figs, fig_names def plot_reward_distribution( - self, states=None, scores=None, ax=None, title=None, oracle=None, **kwargs + self, states=None, scores=None, ax=None, title=None, proxy=None, **kwargs ): if ax is None: fig, ax = plt.subplots() @@ -1368,15 +1328,15 @@ def plot_reward_distribution( standalone = False if title == None: title = "Scores of Sampled States" - if oracle is None: - oracle = self.oracle + if proxy is None: + proxy = self.proxy if scores is None: if isinstance(states[0], torch.Tensor): states = torch.vstack(states).to(self.device, self.float) if isinstance(states, torch.Tensor) == False: states = torch.tensor(states, device=self.device, dtype=self.float) - oracle_states = self.statetorch2oracle(states) - scores = oracle(oracle_states) + states_proxy = self.states2proxy(states) + scores = self.proxy(states_proxy) if isinstance(scores, TensorType): scores = scores.cpu().detach().numpy() ax.hist(scores) diff --git a/gflownet/envs/crystals/composition.py b/gflownet/envs/crystals/composition.py index 55b8bb186..371c002b0 100644 --- a/gflownet/envs/crystals/composition.py +++ b/gflownet/envs/crystals/composition.py @@ -134,10 +134,6 @@ def __init__( self.source = [0 for _ in self.elements] # End-of-sequence action self.eos = (-1, -1) - # Conversions - self.state2proxy = self.state2oracle - self.statebatch2proxy = self.statebatch2oracle - self.statetorch2proxy = self.statetorch2oracle super().__init__(**kwargs) def get_action_space(self): @@ -406,67 +402,24 @@ def get_element_mask(min_atoms, max_atoms): return mask - def state2oracle(self, state: List = None) -> Tensor: + def states2proxy( + self, states: Union[List[List], TensorType["batch", "state_dim"]] + ) -> TensorType["batch", "state_proxy_dim"]: """ - Prepares a state in "GFlowNet format" for the oracle. The output is a tensor of - length N_ELEMENTS_ORACLE + 1, where the positions of self.elements are filled with - the number of atoms of each element in the state. + Prepares a batch of states in "environment format" for the proxy: simply + returns the states as are with dtype long. Args ---- - state : list - A state - - Returns - ---- - oracle_state : Tensor - Tensor containing counts of individual elements - """ - if state is None: - state = self.state - return self.statetorch2oracle( - torch.unsqueeze(tfloat(state, device=self.device, float_type=self.float), 0) - )[0] - - def statetorch2oracle( - self, states: TensorType["batch", "state_dim"] - ) -> TensorType["batch", "state_oracle_dim"]: - """ - Prepares a batch of states in "GFlowNet format" for the oracle. The output is - a tensor with N_ELEMENTS_ORACLE + 1 columns, where the positions of - self.elements are filled with the number of atoms of each element in the state. - - Args - ---- - states : Tensor - A state + states : list or tensor + A batch of states in environment format, either as a list of states or as a + single tensor. Returns - ---- - oracle_states : Tensor - """ - states_float = states.to(self.float) - - states_oracle = torch.zeros( - (states.shape[0], N_ELEMENTS_ORACLE + 1), - device=self.device, - dtype=self.float, - ) - states_oracle[:, tlong(self.elements, device=self.device)] = states_float - return states_oracle - - def statebatch2oracle( - self, states: List[List] - ) -> TensorType["batch", "state_oracle_dim"]: - """ - Prepares a batch of states in "GFlowNet format" for the oracles. 
In this case, - it simply converts the states into a torch tensor, with dtype torch.long. - - Args - ---- - state : list + ------- + A tensor containing all the states in the batch. """ - return self.statetorch2oracle(tlong(states, device=self.device)) + return tlong(states, device=self.device) def state2readable(self, state=None): """ diff --git a/gflownet/envs/crystals/crystal.py b/gflownet/envs/crystals/crystal.py index 0e914ce9a..0acf0a8fa 100644 --- a/gflownet/envs/crystals/crystal.py +++ b/gflownet/envs/crystals/crystal.py @@ -10,6 +10,7 @@ from gflownet.envs.crystals.composition import Composition from gflownet.envs.crystals.lattice_parameters import LatticeParameters from gflownet.envs.crystals.spacegroup import SpaceGroup +from gflownet.utils.common import tlong from gflownet.utils.crystals.constants import TRICLINIC @@ -128,11 +129,6 @@ def __init__( self.lattice_parameters.eos, Stage.LATTICE_PARAMETERS ) - # Conversions - self.state2proxy = self.state2oracle - self.statebatch2proxy = self.statebatch2oracle - self.statetorch2proxy = self.statetorch2oracle - super().__init__(**kwargs) def _set_lattice_parameters(self): @@ -247,7 +243,7 @@ def _get_composition_state(self, state: Optional[List[int]] = None) -> List[int] def _get_composition_tensor_states( self, states: TensorType["batch", "state_dim"] - ) -> TensorType["batch", "state_oracle_dim"]: + ) -> TensorType["batch", "state_proxy_dim"]: return states[:, self.composition_state_start : self.composition_state_end] def _get_space_group_state(self, state: Optional[List[int]] = None) -> List[int]: @@ -258,7 +254,7 @@ def _get_space_group_state(self, state: Optional[List[int]] = None) -> List[int] def _get_space_group_tensor_states( self, states: TensorType["batch", "state_dim"] - ) -> TensorType["batch", "state_oracle_dim"]: + ) -> TensorType["batch", "state_proxy_dim"]: return states[:, self.space_group_state_start : self.space_group_state_end] def _get_lattice_parameters_state( @@ -273,7 +269,7 @@ def _get_lattice_parameters_state( def _get_lattice_parameters_tensor_states( self, states: TensorType["batch", "state_dim"] - ) -> TensorType["batch", "state_oracle_dim"]: + ) -> TensorType["batch", "state_proxy_dim"]: return states[ :, self.lattice_parameters_state_start : self.lattice_parameters_state_end ] @@ -466,58 +462,38 @@ def get_parents( return parents, actions - def state2oracle(self, state: Optional[List[int]] = None) -> Tensor: + def states2proxy( + self, states: Union[List[List], TensorType["batch", "state_dim"]] + ) -> TensorType["batch", "state_proxy_dim"]: """ - Prepares a list of states in "GFlowNet format" for the oracle. Simply - a concatenation of all crystal components. + Prepares a batch of states in "environment format" for the proxy: simply a + concatenation of all crystal components. + + Args + ---- + states : list or tensor + A batch of states in environment format, either as a list of states or as a + single tensor. + + Returns + ------- + A tensor containing all the states in the batch. 
""" - if state is None: - state = self.state.copy() - - composition_oracle_state = self.composition.state2oracle( - state=self._get_composition_state(state) - ).to(self.device) - space_group_oracle_state = ( - self.space_group.state2oracle(state=self._get_space_group_state(state)) - .unsqueeze(-1) # StateGroup oracle state is a single number - .to(self.device) - ) - lattice_parameters_oracle_state = self.lattice_parameters.state2oracle( - state=self._get_lattice_parameters_state(state) - ).to(self.device) - - return torch.cat( - [ - composition_oracle_state, - space_group_oracle_state, - lattice_parameters_oracle_state, - ] - ) - - def statebatch2oracle( - self, states: List[List] - ) -> TensorType["batch", "state_oracle_dim"]: - return self.statetorch2oracle( - torch.tensor(states, device=self.device, dtype=torch.long) - ) - - def statetorch2oracle( - self, states: TensorType["batch", "state_dim"] - ) -> TensorType["batch", "state_oracle_dim"]: - composition_oracle_states = self.composition.statetorch2oracle( + states = tlong(states, device=self.device) + composition_proxy_states = self.composition.states2proxy( self._get_composition_tensor_states(states) ).to(self.device) - space_group_oracle_states = self.space_group.statetorch2oracle( + space_group_proxy_states = self.space_group.states2proxy( self._get_space_group_tensor_states(states) ).to(self.device) - lattice_parameters_oracle_states = self.lattice_parameters.statetorch2oracle( + lattice_parameters_proxy_states = self.lattice_parameters.states2proxy( self._get_lattice_parameters_tensor_states(states) ).to(self.device) return torch.cat( [ - composition_oracle_states, - space_group_oracle_states, - lattice_parameters_oracle_states, + composition_proxy_states, + space_group_proxy_states, + lattice_parameters_proxy_states, ], dim=1, ) diff --git a/gflownet/envs/crystals/lattice_parameters.py b/gflownet/envs/crystals/lattice_parameters.py index 957a7e229..4901c6404 100644 --- a/gflownet/envs/crystals/lattice_parameters.py +++ b/gflownet/envs/crystals/lattice_parameters.py @@ -1,7 +1,7 @@ """ Classes to represent crystal environments """ -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union import numpy as np import torch @@ -9,6 +9,7 @@ from torchtyping import TensorType from gflownet.envs.grid import Grid +from gflownet.utils.common import tlong from gflownet.utils.crystals.constants import ( CUBIC, HEXAGONAL, @@ -336,48 +337,28 @@ def get_mask_invalid_actions_forward( return mask - def state2oracle(self, state: Optional[List[int]] = None) -> Tensor: + def states2proxy( + self, states: Union[List[List], TensorType["batch", "state_dim"]] + ) -> TensorType["batch", "state_proxy_dim"]: """ - Prepares a list of states in "GFlowNet format" for the oracle. + Prepares a batch of states in "environment format" for the proxy: the + concatenation of the lengths and angles. Args ---- - state : list - A state. - - Returns - ---- - oracle_state : Tensor - Tensor containing lengths and angles converted from the Grid format. - """ - if state is None: - state = self.state.copy() - - return Tensor( - [self.cell2length[s] for s in state[:3]] - + [self.cell2angle[s] for s in state[3:]] - ) - - def statetorch2oracle( - self, states: TensorType["batch", "state_dim"] - ) -> TensorType["batch", "state_oracle_dim"]: - """ - Prepares a batch of states in "GFlowNet format" for the oracle. The input to the - oracle is the lengths and angles. 
- - Args - ---- - states : Tensor - A state + states : list or tensor + A batch of states in environment format, either as a list of states or as a + single tensor. Returns - ---- - oracle_states : Tensor + ------- + A tensor containing all the states in the batch. """ + states = tlong(states, device=self.device) return torch.cat( [ - self.lengths_tensor[states[:, :3].long()], - self.angles_tensor[states[:, 3:].long()], + self.lengths_tensor[states[:, :3]], + self.angles_tensor[states[:, 3:]], ], dim=1, ) diff --git a/gflownet/envs/crystals/spacegroup.py b/gflownet/envs/crystals/spacegroup.py index 0faa4ffc3..011bc1e3e 100644 --- a/gflownet/envs/crystals/spacegroup.py +++ b/gflownet/envs/crystals/spacegroup.py @@ -14,6 +14,7 @@ from torchtyping import TensorType from gflownet.envs.base import GFlowNetEnv +from gflownet.utils.common import tlong from gflownet.utils.crystals.pyxtal_cache import space_group_check_compatible CRYSTAL_LATTICE_SYSTEMS = None @@ -130,10 +131,6 @@ def __init__( # Source state: index 0 (empty) for all three properties (crystal-lattice # system index, point symmetry index, space group) self.source = [0 for _ in range(3)] - # Conversions - self.state2proxy = self.state2oracle - self.statebatch2proxy = self.statebatch2oracle - self.statetorch2proxy = self.statetorch2oracle # Base class init super().__init__(**kwargs) @@ -247,65 +244,25 @@ def get_mask_invalid_actions_forward( ] return mask - def state2oracle(self, state: List = None) -> Tensor: + def states2proxy( + self, states: Union[List[List], TensorType["batch", "state_dim"]] + ) -> TensorType["batch", "state_proxy_dim"]: """ - Prepares a list of states in "GFlowNet format" for the oracle. The input to the - oracle is simply the space group. + Prepares a batch of states in "environment format" for the proxy: the proxy + format is simply the space group. Args ---- - state : list - A state - - Returns - ---- - oracle_state : Tensor - """ - if state is None: - state = self.state - if state[self.sg_idx] == 0: - raise ValueError( - "The space group must have been set in order to call the oracle" - ) - return torch.tensor(state[self.sg_idx], device=self.device, dtype=torch.long) - - def statebatch2oracle( - self, states: List[List] - ) -> TensorType["batch", "state_oracle_dim"]: - """ - Prepares a batch of states in "GFlowNet format" for the oracle. The input to the - oracle is simply the space group. - - Args - ---- - state : list - A state - - Returns - ---- - oracle_state : Tensor - """ - return self.statetorch2oracle( - torch.tensor(states, device=self.device, dtype=torch.long) - ) - - def statetorch2oracle( - self, states: TensorType["batch", "state_dim"] - ) -> TensorType["batch", "state_oracle_dim"]: - """ - Prepares a batch of states in "GFlowNet format" for the oracle. The input to the - oracle is simply the space group. - - Args - ---- - state : list - A state + states : list or tensor + A batch of states in environment format, either as a list of states or as a + single tensor. Returns - ---- - oracle_state : Tensor + ------- + A tensor containing all the states in the batch. 
""" - return torch.unsqueeze(states[:, self.sg_idx], dim=1).to(torch.long) + states = tlong(states, device=self.device) + return torch.unsqueeze(states[:, self.sg_idx], dim=1) def state2readable(self, state=None): """ diff --git a/gflownet/envs/cube.py b/gflownet/envs/cube.py index ca929ea6d..ee5c52622 100644 --- a/gflownet/envs/cube.py +++ b/gflownet/envs/cube.py @@ -4,7 +4,7 @@ import itertools import warnings from abc import ABC, abstractmethod -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union import matplotlib.pyplot as plt import numpy as np @@ -102,6 +102,9 @@ def __init__( self.epsilon = epsilon # Small constant to restrict the interval of (test) sets self.kappa = kappa + # Conversions: only conversions to policy are implemented and the conversion to + # proxy format is the same + self.states2proxy = self.states2policy # Base class init super().__init__( fixed_distr_params=fixed_distr_params, @@ -130,105 +133,25 @@ def get_mask_invalid_actions_forward( def get_mask_invalid_actions_backward(self, state=None, done=None, parents_a=None): pass - def statetorch2oracle( - self, states: TensorType["batch", "state_dim"] = None - ) -> TensorType["batch", "oracle_input_dim"]: - """ - Clips the states into [0, 1] and maps them to [-1.0, 1.0] - - Args - ---- - state : list - State - """ - return 2.0 * torch.clip(states, min=0.0, max=1.0) - 1.0 - - def statebatch2oracle( - self, states: List[List] - ) -> TensorType["batch", "state_oracle_dim"]: - """ - Clips the states into [0, 1] and maps them to [-1.0, 1.0] - - Args - ---- - state : list - State - """ - return self.statetorch2oracle( - tfloat(states, device=self.device, float_type=self.float) - ) - - def state2oracle(self, state: List = None) -> List: - """ - Clips the state into [0, 1] and maps it to [-1.0, 1.0] - """ - if state is None: - state = self.state.copy() - return [2.0 * min(max(0.0, s), 1.0) - 1.0 for s in state] - - def statetorch2proxy( - self, states: TensorType["batch", "state_dim"] = None - ) -> TensorType["batch", "oracle_input_dim"]: - """ - Returns statetorch2oracle(states), that is states mapped to [-1.0, 1.0]. - - Args - ---- - state : list - State - """ - return self.statetorch2oracle(states) - - def statebatch2proxy( - self, states: List[List] - ) -> TensorType["batch", "state_oracle_dim"]: - """ - Returns statebatch2oracle(states), that is states mapped to [-1.0, 1.0]. - - Args - ---- - state : list - State - """ - return self.statebatch2oracle(states) - - def state2proxy(self, state: List = None) -> List: - """ - Returns state2oracle(state), that is the state mapped to [-1.0, 1.0]. - """ - return self.state2oracle(state) - - def statetorch2policy( - self, states: TensorType["batch", "state_dim"] = None + def states2policy( + self, states: Union[List, TensorType["batch", "state_dim"]] ) -> TensorType["batch", "policy_input_dim"]: """ - Returns statetorch2proxy(states), that is states mapped to [-1.0, 1.0]. - - Args - ---- - state : list - State - """ - return self.statetorch2proxy(states) - - def statebatch2policy( - self, states: List[List] - ) -> TensorType["batch", "state_proxy_dim"]: - """ - Returns statebatch2proxy(states), that is states mapped to [-1.0, 1.0]. 
+ Prepares a batch of states in "environment format" for the policy model: clips + the states into [0, 1] and maps them to [-1.0, 1.0] Args ---- - state : list - State - """ - return self.statebatch2proxy(states) + states : list or tensor + A batch of states in environment format, either as a list of states or as a + single tensor. - def state2policy(self, state: List = None) -> List: - """ - Returns state2proxy(state), that is the state mapped to [-1.0, 1.0]. + Returns + ------- + A tensor containing all the states in the batch. """ - return self.state2proxy(state) + states = tfloat(states, device=self.device, float_type=self.float) + return 2.0 * torch.clip(states, min=0.0, max=1.0) - 1.0 def state2readable(self, state: List) -> str: """ @@ -1489,7 +1412,7 @@ def sample_from_reward( samples_final = [] max_reward = self.proxy2reward(self.proxy.min) while len(samples_final) < n_samples: - samples_uniform = self.statebatch2proxy( + samples_uniform = self.states2proxy( self.get_uniform_terminating_states(n_samples) ) rewards = self.proxy2reward(self.proxy(samples_uniform)) diff --git a/gflownet/envs/grid.py b/gflownet/envs/grid.py index 46e3639ef..8f6638eaa 100644 --- a/gflownet/envs/grid.py +++ b/gflownet/envs/grid.py @@ -2,7 +2,7 @@ Classes to represent a hyper-grid environments """ import itertools -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union import matplotlib.pyplot as plt import numpy as np @@ -12,6 +12,7 @@ from torchtyping import TensorType from gflownet.envs.base import GFlowNetEnv +from gflownet.utils.common import tfloat, tlong class Grid(GFlowNetEnv): @@ -80,10 +81,7 @@ def __init__( # Proxy format # TODO: assess if really needed if self.proxy_state_format == "ohe": - self.statebatch2proxy = self.statebatch2policy - elif self.proxy_state_format == "oracle": - self.statebatch2proxy = self.statebatch2oracle - self.statetorch2proxy = self.statetorch2oracle + self.states2proxy = self.states2policy def get_action_space(self): """ @@ -127,127 +125,68 @@ def get_mask_invalid_actions_forward( mask[idx] = True return mask - def state2oracle(self, state: List = None) -> List: + def states2proxy( + self, states: Union[List[List], TensorType["batch", "state_dim"]] + ) -> TensorType["batch", "state_proxy_dim"]: """ - Prepares a state in "GFlowNet format" for the oracles: a list of length - n_dim with values in the range [cell_min, cell_max] for each state. - - See: state2policy() - - Args - ---- - state : list - State - """ - if state is None: - state = self.state.copy() - return ( - ( - np.array(self.state2policy(state)).reshape((self.n_dim, self.length)) - * self.cells[None, :] - ) - .sum(axis=1) - .tolist() - ) - - def statebatch2oracle( - self, states: List[List] - ) -> TensorType["batch", "state_oracle_dim"]: - """ - Prepares a batch of states in "GFlowNet format" for the oracles: each state is + Prepares a batch of states in "environment format" for the proxy: each state is a vector of length n_dim with values in the range [cell_min, cell_max]. - See: statetorch2oracle() + See: states2policy() Args ---- - state : list - State - """ - return self.statetorch2oracle( - torch.tensor(states, device=self.device, dtype=self.float) - ) + states : list or tensor + A batch of states in environment format, either as a list of states or as a + single tensor. 
- def statetorch2oracle( - self, states: TensorType["batch", "state_dim"] - ) -> TensorType["batch", "state_oracle_dim"]: - """ - Prepares a batch of states in "GFlowNet format" for the oracles: each state is - a vector of length n_dim with values in the range [cell_min, cell_max]. - - See: statetorch2policy() + Returns + ------- + A tensor containing all the states in the batch. """ + states = tfloat(states, device=self.device, float_type=self.float) return ( - self.statetorch2policy(states).reshape( - (len(states), self.n_dim, self.length) + self.states2policy(states).reshape( + (states.shape[0], self.n_dim, self.length) ) * torch.tensor(self.cells[None, :]).to(states.device, self.float) ).sum(axis=2) - def state2policy(self, state: List = None) -> List: + def states2policy( + self, states: Union[List, TensorType["batch", "state_dim"]] + ) -> TensorType["batch", "policy_input_dim"]: """ - Transforms the state given as argument (or self.state if None) into a - one-hot encoding. The output is a list of len length * n_dim, + Prepares a batch of states in "environment format" for the policy model: states + are one-hot encoded. + + The output is a 2D tensor, with the second dimension of size length * n_dim, where each n-th successive block of length elements is a one-hot encoding of the position in the n-th dimension. - Example: - - State, state: [0, 3, 1] (n_dim = 3) - - state2policy(state): [1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0] (length = 4) - | 0 | 3 | 1 | - """ - if state is None: - state = self.state.copy() - state_policy = np.zeros(self.length * self.n_dim, dtype=np.float32) - state_policy[(np.arange(len(state)) * self.length + state)] = 1 - return state_policy.tolist() - - def statebatch2policy(self, states: List[List]) -> npt.NDArray[np.float32]: - """ - Transforms a batch of states into a one-hot encoding. The output is a numpy - array of shape [n_states, length * n_dim]. - - See state2policy(). - """ - cols = np.array(states) + np.arange(self.n_dim) * self.length - rows = np.repeat(np.arange(len(states)), self.n_dim) - state_policy = np.zeros( - (len(states), self.length * self.n_dim), dtype=np.float32 - ) - state_policy[rows, cols.flatten()] = 1.0 - return state_policy + Example (n_dim = 3, length = 4): + - state: [0, 3, 1] + - policy format: [1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0] + | 0 | 3 | 1 | - def statetorch2policy( - self, states: TensorType["batch", "state_dim"] - ) -> TensorType["batch", "policy_output_dim"]: - """ - Transforms a batch of states into a one-hot encoding. The output is a numpy - array of shape [n_states, length * n_dim]. + Args + ---- + states : list or tensor + A batch of states in environment format, either as a list of states or as a + single tensor. - See state2policy(). + Returns + ------- + A tensor containing all the states in the batch. 
""" - device = states.device - cols = (states + torch.arange(self.n_dim).to(device) * self.length).to(int) - rows = torch.repeat_interleave( - torch.arange(states.shape[0]).to(device), self.n_dim + states = tlong(states, device=self.device) + n_states = states.shape[0] + cols = states + torch.arange(self.n_dim) * self.length + rows = torch.repeat_interleave(torch.arange(n_states), self.n_dim) + states_policy = torch.zeros( + (n_states, self.length * self.n_dim), dtype=self.float, device=self.device ) - state_policy = torch.zeros( - (states.shape[0], self.length * self.n_dim), dtype=states.dtype - ).to(device) - state_policy[rows, cols.flatten()] = 1.0 - return state_policy - - def policy2state(self, state_policy: List) -> List: - """ - Transforms the one-hot encoding version of a state given as argument - into a state (list of the position at each dimension). - - Example: - - state_policy: [1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0] (length = 4, n_dim = 3) - | 0 | 3 | 1 | - - policy2state(state_policy): [0, 3, 1] - """ - return np.where(np.reshape(state_policy, (self.n_dim, self.length)))[1].tolist() + states_policy[rows, cols.flatten()] = 1.0 + return states_policy def readable2state(self, readable, alphabet={}): """ diff --git a/gflownet/envs/htorus.py b/gflownet/envs/htorus.py index 011a74c51..0f7146be4 100644 --- a/gflownet/envs/htorus.py +++ b/gflownet/envs/htorus.py @@ -4,7 +4,7 @@ import itertools import re from copy import deepcopy -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union import matplotlib.pyplot as plt import numpy as np @@ -73,10 +73,6 @@ def __init__( self.source = self.source_angles + [0] # End-of-sequence action: (n_dim, 0) self.eos = (self.n_dim, 0) - # TODO: assess if really needed - self.state2oracle = self.state2proxy - self.statebatch2oracle = self.statebatch2proxy - self.statetorch2oracle = self.statetorch2proxy # Base class init super().__init__( fixed_distr_params=fixed_distr_params, @@ -185,84 +181,63 @@ def get_mask_invalid_actions_backward(self, state=None, done=None, parents_a=Non ] + [mask[-1]] return mask - def statebatch2proxy( - self, states: List[List] + def states2proxy( + self, states: Union[List[List], TensorType["batch", "state_dim"]] ) -> TensorType["batch", "state_proxy_dim"]: """ - Prepares a batch of states in "GFlowNet format" for the proxy: a tensor where - each state is a row of length n_dim with an angle in radians. The n_actions + Prepares a batch of states in "environment format" for the proxy: each state is + a vector of length n_dim where each value is an angle in radians. The n_actions item is removed. - """ - return torch.tensor(states, device=self.device)[:, :-1] - - def statetorch2proxy( - self, states: TensorType["batch", "state_dim"] - ) -> TensorType["batch", "state_proxy_dim"]: - """ - Prepares a batch of states in torch "GFlowNet format" for the proxy. - """ - return states[:, :-1] - def state2policy(self, state: List = None) -> List: - """ - Returns the policy encoding of the state. + Args + ---- + states : list or tensor + A batch of states in environment format, either as a list of states or as a + single tensor. - See: statebatch2policy() + Returns + ------- + A tensor containing all the states in the batch. 
""" - if state is None: - state = self.state.copy() - return self.statebatch2policy([state]).tolist()[0] + return tfloat(states, device=self.device, float_type=self.float)[:, :-1] - def statetorch2policy( - self, states: TensorType["batch", "state_dim"] + def states2policy( + self, states: Union[List, TensorType["batch", "state_dim"]] ) -> TensorType["batch", "policy_input_dim"]: """ - Prepares a batch of states in torch "GFlowNet format" for the policy. - - If policy_encoding_dim_per_angle >= 2, then the state (angles) is encoded using + Prepares a batch of states in "environment format" for the policy model: if + policy_encoding_dim_per_angle >= 2, then the state (angles) is encoded using trigonometric components. - """ - if ( - self.policy_encoding_dim_per_angle is not None - and self.policy_encoding_dim_per_angle >= 2 - ): - step = states[:, -1] - code_half_size = self.policy_encoding_dim_per_angle // 2 - int_coeff = ( - torch.arange(1, code_half_size + 1) - .repeat(states.shape[-1] - 1) - .to(states) - ) - encoding = ( - torch.repeat_interleave(states[:, :-1], repeats=code_half_size, dim=1) - * int_coeff - ) - states = torch.cat( - [torch.cos(encoding), torch.sin(encoding), torch.unsqueeze(step, 1)], - dim=1, - ) - return states - def statebatch2policy( - self, states: List[List] - ) -> TensorType["batch_size", "policy_input_dim"]: - """ - Prepares a batch of states in "GFlowNet format" for the policy. + Args + ---- + states : list or tensor + A batch of states in environment format, either as a list of states or as a + single tensor. - See: statetorch2policy() + Returns + ------- + A tensor containing all the states in the batch. """ states = tfloat(states, float_type=self.float, device=self.device) - return self.statetorch2policy(states) - - def policy2state(self, state_policy: List) -> List: - """ - Returns the input as is. 
- """ - if self.policy_encoding_dim_per_angle is not None: - raise NotImplementedError( - "Convertion from encoded policy_state to state is not impemented" - ) - return state_policy + if ( + self.policy_encoding_dim_per_angle is None + or self.policy_encoding_dim_per_angle < 2 + ): + return states + step = states[:, -1] + code_half_size = self.policy_encoding_dim_per_angle // 2 + int_coeff = ( + torch.arange(1, code_half_size + 1).repeat(states.shape[-1] - 1).to(states) + ) + encoding = ( + torch.repeat_interleave(states[:, :-1], repeats=code_half_size, dim=1) + * int_coeff + ) + return torch.cat( + [torch.cos(encoding), torch.sin(encoding), torch.unsqueeze(step, 1)], + dim=1, + ) def state2readable(self, state: List) -> str: """ @@ -566,7 +541,9 @@ def sample_from_reward( ), axis=1, ) - rewards = self.reward_torchbatch(samples) + rewards = tfloat( + self.reward_batch(samples), device=self.device, float_type=self.float + ) mask = ( torch.rand(n_samples, dtype=self.float, device=self.device) * (max_reward + epsilon) @@ -606,7 +583,7 @@ def plot_reward_samples( [samples_mesh, torch.ones(samples_mesh.shape[0], 1)], 1 ).to(self.device) rewards = torch2np( - self.proxy2reward(self.proxy(self.statetorch2proxy(states_mesh))) + self.proxy2reward(self.proxy(self.states2proxy(states_mesh))) ) # Init figure fig, ax = plt.subplots() diff --git a/gflownet/envs/tetris.py b/gflownet/envs/tetris.py index a80d447ff..6c4e8bcb7 100644 --- a/gflownet/envs/tetris.py +++ b/gflownet/envs/tetris.py @@ -4,7 +4,7 @@ import itertools import re import warnings -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union import numpy as np import numpy.typing as npt @@ -99,10 +99,6 @@ def __init__( ) # End-of-sequence action: all -1 self.eos = (-1, -1, -1) - # Conversions - self.state2proxy = self.state2oracle - self.statebatch2proxy = self.statebatch2oracle - self.statetorch2proxy = self.statetorch2oracle # Precompute all possible rotations of each piece and the corresponding binary # mask @@ -251,87 +247,53 @@ def get_mask_invalid_actions_forward( mask[-1] = True return mask - def state2oracle( - self, state: Optional[TensorType["height", "width"]] = None - ) -> TensorType["height", "width"]: - """ - Prepares a state in "GFlowNet format" for the oracles: simply converts non-zero - (non-empty) cells into 1s. - - Args - ---- - state : tensor - """ - if state is None: - state = self.state.clone().detach() - state_oracle = state.clone().detach() - state_oracle[state_oracle != 0] = 1 - return state_oracle - - def statebatch2oracle( - self, states: List[TensorType["height", "width"]] - ) -> TensorType["batch", "state_oracle_dim"]: + def states2proxy( + self, + states: Union[ + List[TensorType["height", "width"]], TensorType["height", "width", "batch"] + ], + ) -> TensorType["height", "width", "batch"]: """ - Prepares a batch of states in "GFlowNet format" for the oracles: simply + Prepares a batch of states in "environment format" for a proxy: : simply converts non-zero (non-empty) cells into 1s. Args ---- - state : list - """ - states = torch.stack(states) - states[states != 0] = 1 - return states + states : list of 2D tensors or 3D tensor + A batch of states in environment format, either as a list of states or as a + single tensor. - def statetorch2oracle( - self, states: TensorType["height", "width", "batch"] - ) -> TensorType["height", "width", "batch"]: - """ - Prepares a batch of states in "GFlowNet format" for the oracles: : simply - converts non-zero (non-empty) cells into 1s. 
+ Returns + ------- + A tensor containing all the states in the batch. """ + states = tint(states, device=self.device, int_type=self.int) states[states != 0] = 1 return states - def state2policy( - self, state: Optional[TensorType["height", "width"]] = None - ) -> TensorType["height", "width"]: - """ - Prepares a state in "GFlowNet format" for the policy model. - - See: state2oracle() - """ - return self.state2oracle(state).flatten() - - def statebatch2policy( - self, states: List[TensorType["height", "width"]] - ) -> TensorType["batch", "state_oracle_dim"]: - """ - Prepares a batch of states in "GFlowNet format" for the policy model. - - See statebatch2oracle(). - """ - return self.statebatch2oracle(states).flatten(start_dim=1) - - def statetorch2policy( - self, states: TensorType["height", "width", "batch"] + def states2policy( + self, + states: Union[ + List[TensorType["height", "width"]], TensorType["height", "width", "batch"] + ], ) -> TensorType["height", "width", "batch"]: """ - Prepares a batch of states in "GFlowNet format" for the policy model. + Prepares a batch of states in "environment format" for the policy model. - See statetorch2oracle(). - """ - return self.statetorch2oracle(states).flatten(start_dim=1) + See states2proxy(). - def policy2state( - self, policy: Optional[TensorType["height", "width"]] = None - ) -> TensorType["height", "width"]: - """ - Returns None to signal that the conversion is not reversible. + Args + ---- + states : list of 2D tensors or 3D tensor + A batch of states in environment format, either as a list of states or as a + single tensor. - See: state2oracle() + Returns + ------- + A tensor containing all the states in the batch. """ - return None + states = tint(states, device=self.device, int_type=self.int) + return self.states2proxy(states).flatten(start_dim=1).to(self.float) def state2readable(self, state: Optional[TensorType["height", "width"]] = None): """ diff --git a/gflownet/envs/torus.py b/gflownet/envs/torus.py index 8c0ce712d..54b1183a3 100644 --- a/gflownet/envs/torus.py +++ b/gflownet/envs/torus.py @@ -2,7 +2,7 @@ Classes to represent hyper-torus environments """ import itertools -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union import numpy as np import numpy.typing as npt @@ -11,6 +11,7 @@ from torchtyping import TensorType from gflownet.envs.base import GFlowNetEnv +from gflownet.utils.common import tfloat, tlong class Torus(GFlowNetEnv): @@ -64,9 +65,6 @@ def __init__( self.eos = tuple([self.max_increment + 1 for _ in range(self.n_dim)]) # Angle increments in radians self.angle_rad = 2 * np.pi / self.n_angles - # TODO: assess if really needed - self.state2oracle = self.state2proxy - self.statebatch2oracle = self.statebatch2proxy # Base class init super().__init__(**kwargs) @@ -112,105 +110,67 @@ def get_mask_invalid_actions_forward( mask[-1] = True return mask - def statebatch2proxy(self, states: List[List]) -> npt.NDArray[np.float32]: + def states2proxy( + self, states: Union[List[List], TensorType["batch", "state_dim"]] + ) -> TensorType["batch", "state_proxy_dim"]: """ - Prepares a batch of states in "GFlowNet format" for the proxy: an array where - each state is a row of length n_dim with an angle in radians. The n_actions + Prepares a batch of states in "environment format" for the proxy: each state is + a vector of length n_dim where each value is an angle in radians. The n_actions item is removed. 
- """ - return torch.tensor(states, device=self.device)[:, :-1] * self.angle_rad - def statetorch2proxy( - self, states: TensorType["batch", "state_dim"] - ) -> TensorType["batch", "state_proxy_dim"]: - """ - Prepares a batch of states in torch "GFlowNet format" for the proxy. + Args + ---- + states : list or tensor + A batch of states in environment format, either as a list of states or as a + single tensor. + + Returns + ------- + A tensor containing all the states in the batch. """ - return states[:, :-1] * self.angle_rad + return ( + tfloat(states, device=self.device, float_type=self.float)[:, :-1] + * self.angle_rad + ) # TODO: circular encoding as in htorus - def state2policy(self, state=None) -> List: + def states2policy( + self, states: Union[List, TensorType["batch", "state_dim"]] + ) -> TensorType["batch", "policy_input_dim"]: """ - Transforms the angles part of the state given as argument (or self.state if - None) into a one-hot encoding. The output is a list of len n_angles * n_dim + - 1, where each n-th successive block of length elements is a one-hot encoding of - the position in the n-th dimension. + Prepares a batch of states in "environment format" for the policy model: the + policy format is a one-hot encoding of the states. + + Each row is a vector of length n_angles * n_dim + 1, where each n-th successive + block of length elements is a one-hot encoding of the position in the n-th + dimension. Example, n_dim = 2, n_angles = 4: - - State, state: [1, 3, 4] + - state: [1, 3, 4] | a | n | (a = angles, n = n_actions) - - state2policy(state): [0, 1, 0, 0, 0, 0, 0, 1, 4] - | 1 | 3 | 4 | - """ - if state is None: - state = self.state.copy() - # TODO: do we need float32? - # TODO: do we need one-hot? - state_policy = np.zeros(self.n_angles * self.n_dim + 1, dtype=np.float32) - # Angles - state_policy[: self.n_dim * self.n_angles][ - (np.arange(self.n_dim) * self.n_angles + state[: self.n_dim]) - ] = 1 - # Number of actions - state_policy[-1] = state[-1] - return state_policy - - def statebatch2policy(self, states: List[List]) -> npt.NDArray[np.float32]: - """ - Transforms a batch of states into the policy model format. The output is a numpy - array of shape [n_states, n_angles * n_dim + 1]. - - See state2policy(). - """ - states = np.array(states) - cols = states[:, :-1] + np.arange(self.n_dim) * self.n_angles - rows = np.repeat(np.arange(states.shape[0]), self.n_dim) - state_policy = np.zeros( - (len(states), self.n_angles * self.n_dim + 1), dtype=np.float32 - ) - state_policy[rows, cols.flatten()] = 1.0 - state_policy[:, -1] = states[:, -1] - return state_policy - - def statetorch2policy( - self, states: TensorType["batch", "state_dim"] - ) -> TensorType["batch", "policy_output_dim"]: - """ - Transforms a batch of torch states into the policy model format. The output is - a tensor of shape [n_states, n_angles * n_dim + 1]. + - policy format: [0, 1, 0, 0, 0, 0, 0, 1, 4] + | 1 | 3 | 4 | + Args + ---- + states : list or tensor + A batch of states in environment format, either as a list of states or as a + single tensor. - See state2policy(). + Returns + ------- + A tensor containing all the states in the batch. 
""" - device = states.device - cols = ( - states[:, :-1] + torch.arange(self.n_dim).to(device) * self.n_angles - ).to(int) + states = tlong(states, device=self.device) + cols = states[:, :-1] + torch.arange(self.n_dim).to(self.device) * self.n_angles rows = torch.repeat_interleave( - torch.arange(states.shape[0]).to(device), self.n_dim + torch.arange(states.shape[0]).to(self.device), self.n_dim ) - state_policy = torch.zeros( + states_policy = torch.zeros( (states.shape[0], self.n_angles * self.n_dim + 1) ).to(states) - state_policy[rows, cols.flatten()] = 1.0 - state_policy[:, -1] = states[:, -1] - return state_policy - - def policy2state(self, state_policy: List) -> List: - """ - Transforms the one-hot encoding version of a state given as argument - into a state (list of the position at each dimension). - - Example, n_dim = 2, n_angles = 4: - - state_policy: [0, 1, 0, 0, 0, 0, 0, 1, 4] - | 0 | 3 | 4 | - - policy2state(state_policy): [1, 3, 4] - | a | n | (a = angles, n = n_actions) - """ - mat_angles_policy = np.reshape( - state_policy[: self.n_dim * self.n_angles], (self.n_dim, self.n_angles) - ) - angles = np.where(mat_angles_policy)[1].tolist() - return angles + [int(state_policy[-1])] + states_policy[rows, cols.flatten()] = 1.0 + states_policy[:, -1] = states[:, -1] + return states_policy.to(self.float) def state2readable(self, state: Optional[List] = None) -> str: """ diff --git a/gflownet/envs/tree.py b/gflownet/envs/tree.py index bc92cda13..13684f4a3 100644 --- a/gflownet/envs/tree.py +++ b/gflownet/envs/tree.py @@ -285,13 +285,11 @@ def __init__( # Conversions policy_format = policy_format.lower() if policy_format == "mlp": - self.state2policy = self.state2policy_mlp - self.statetorch2policy = self.statetorch2policy_mlp + self.states2policy = self.states2policy_mlp elif policy_format != "gnn": raise ValueError( f"Unrecognized policy_format = {policy_format}, expected either 'mlp' or 'gnn'." ) - self.statetorch2oracle = self.statetorch2policy super().__init__( fixed_distr_params=fixed_distr_params, @@ -830,24 +828,19 @@ def get_logprobs( is_backward, ) - def state2policy_mlp( - self, state: Optional[TensorType["state_dim"]] = None - ) -> TensorType["policy_input_dim"]: - """ - Prepares a state in "GFlowNet format" for the policy model. - """ - if state is None: - state = self.state.clone().detach() - return self.statetorch2policy_mlp(state.unsqueeze(0))[0] - - def statetorch2policy_mlp( - self, states: TensorType["batch_size", "state_dim"] + def states2policy_mlp( + self, + states: Union[ + List[TensorType["state_dim"]], TensorType["batch_size", "state_dim"] + ], ) -> TensorType["batch_size", "policy_input_dim"]: """ Prepares a batch of states in torch "GFlowNet format" for an MLP policy model. It replaces the NaNs by -2s, removes the activity attribute, and explicitly appends the attribute vector of the active node (if present). """ + if isinstance(states, list): + states = torch.stack(states) rows, cols = torch.where(states[:, :-1, Attribute.ACTIVE] == Status.ACTIVE) active_features = torch.full((states.shape[0], 1, 4), -2.0) active_features[rows] = states[rows, cols, : Attribute.ACTIVE].unsqueeze(1) @@ -855,28 +848,6 @@ def statetorch2policy_mlp( states = torch.cat([states[:, :, : Attribute.ACTIVE], active_features], dim=1) return states.flatten(start_dim=1) - def policy2state( - self, policy: Optional[TensorType["policy_input_dim"]] = None - ) -> None: - """ - Returns None to signal that the conversion is not reversible. 
- """ - return None - - def statebatch2proxy( - self, states: List[TensorType["state_dim"]] - ) -> TensorType["batch", "state_proxy_dim"]: - """ - Prepares a batch of states in "GFlowNet format" for the proxy: simply - stacks the list of tensors and calls self.statetorch2proxy. - - Args - ---- - state : list - """ - states = torch.stack(states) - return self.statetorch2proxy(states) - def _attributes_to_readable(self, attributes: List) -> str: # Node type if attributes[Attribute.TYPE] == NodeType.CONDITION: diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index 06777a851..fdbea6e76 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -325,7 +325,7 @@ def sample_actions( # Check for at least one non-random action if idx_norandom.sum() > 0: states_policy = tfloat( - self.env.statebatch2policy( + self.env.states2policy( [s for s, do in zip(states, idx_norandom) if do] ), device=self.device, @@ -1036,8 +1036,8 @@ def test(self, **plot_kwargs): assert batch.is_valid() x_sampled = batch.get_terminating_states() # TODO make it work with conditional env - x_sampled = torch2np(self.env.statebatch2proxy(x_sampled)) - x_tt = torch2np(self.env.statebatch2proxy(x_tt)) + x_sampled = torch2np(self.env.states2proxy(x_sampled)) + x_tt = torch2np(self.env.states2proxy(x_tt)) kde_pred = self.env.fit_kde( x_sampled, kernel=self.logger.test.kde.kernel, @@ -1051,7 +1051,7 @@ def test(self, **plot_kwargs): x_from_reward = self.env.sample_from_reward( n_samples=self.logger.test.n ) - x_from_reward = torch2np(self.env.statetorch2proxy(x_from_reward)) + x_from_reward = torch2np(self.env.states2proxy(x_from_reward)) # Fit KDE with samples from reward kde_true = self.env.fit_kde( x_from_reward, @@ -1332,7 +1332,7 @@ def logq(traj_list, actions_list, model, env): with torch.no_grad(): logits_traj = model( tfloat( - env.statebatch2policy(traj), + env.states2policy(traj), device=self.device, float_type=self.float, ) diff --git a/gflownet/proxy/aptamers.py b/gflownet/proxy/aptamers.py deleted file mode 100644 index 338c72347..000000000 --- a/gflownet/proxy/aptamers.py +++ /dev/null @@ -1,35 +0,0 @@ -import numpy as np -import numpy.typing as npt - -from gflownet.proxy.base import Proxy - - -class Aptamers(Proxy): - """ - DNA Aptamer oracles - """ - - def __init__(self, oracle_id, norm): - super().__init__() - self.type = oracle_id - self.norm = norm - - def setup(self, env=None): - self.max_seq_length = env.max_seq_length - - def __call__(self, states: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]: - """ - args: - states : ndarray - """ - - def _length(x): - if self.norm: - return -1.0 * np.sum(x, axis=1) / self.max_seq_length - else: - return -1.0 * np.sum(x, axis=1) - - if self.type == "length": - return _length(states) - else: - raise NotImplementedError("self.type must be length") diff --git a/gflownet/utils/batch.py b/gflownet/utils/batch.py index a35f01ddf..9ab3dc24d 100644 --- a/gflownet/utils/batch.py +++ b/gflownet/utils/batch.py @@ -408,12 +408,7 @@ def states2policy( self.get_states_of_trajectory(traj_idx, states, traj_indices) ) return states_policy - # TODO: do we need tfloat or is done in env.statebatch2policy? 
- return tfloat( - self.env.statebatch2policy(states), - device=self.device, - float_type=self.float, - ) + return self.env.states2policy(states) def states2proxy( self, @@ -461,7 +456,7 @@ def states2proxy( if traj_idx not in traj_indices: continue states_proxy.append( - self.envs[traj_idx].statebatch2proxy( + self.envs[traj_idx].states2proxy( self.get_states_of_trajectory(traj_idx, states, traj_indices) ) ) @@ -471,7 +466,7 @@ def states2proxy( index[perm_index] = index.clone() states_proxy = concat_items(states_proxy, index) return states_proxy - return self.env.statebatch2proxy(states) + return self.env.states2proxy(states) def get_actions(self) -> TensorType["n_states, action_dim"]: """ @@ -678,13 +673,7 @@ def _compute_parents_all(self): self.parents_all.extend(parents) self.parents_actions_all.extend(parents_a) self.parents_all_indices.extend([idx] * len(parents)) - self.parents_all_policy.append( - tfloat( - self.envs[traj_idx].statebatch2policy(parents), - device=self.device, - float_type=self.float, - ) - ) + self.parents_all_policy.append(self.envs[traj_idx].states2policy(parents)) # Convert to tensors self.parents_actions_all = tfloat( self.parents_actions_all, diff --git a/gflownet/utils/buffer.py b/gflownet/utils/buffer.py index b5d3d3e42..17965bd93 100644 --- a/gflownet/utils/buffer.py +++ b/gflownet/utils/buffer.py @@ -250,7 +250,7 @@ def make_data_set(self, config): samples = self.env.get_random_terminating_states(config.n) else: return None, None - energies = self.env.oracle(self.env.statebatch2oracle(samples)).tolist() + energies = self.env.proxy(self.env.states2proxy(samples)).tolist() df = pd.DataFrame( { "samples": [self.env.state2readable(s) for s in samples], diff --git a/gflownet/utils/common.py b/gflownet/utils/common.py index e7e7f8afd..cfe81f40d 100644 --- a/gflownet/utils/common.py +++ b/gflownet/utils/common.py @@ -202,36 +202,36 @@ def batch_with_rest(start, stop, step, tensor=False): def tfloat(x, device, float_type): if isinstance(x, list) and torch.is_tensor(x[0]): - return torch.stack(x).type(float_type).to(device) + return torch.stack(x).to(device=device, dtype=float_type) if torch.is_tensor(x): - return x.type(float_type).to(device) + return x.to(device=device, dtype=float_type) else: return torch.tensor(x, dtype=float_type, device=device) def tlong(x, device): if isinstance(x, list) and torch.is_tensor(x[0]): - return torch.stack(x).type(torch.long).to(device) + return torch.stack(x).to(device=device, dtype=torch.long) if torch.is_tensor(x): - return x.type(torch.long).to(device) + return x.to(device=device, dtype=torch.long) else: return torch.tensor(x, dtype=torch.long, device=device) def tint(x, device, int_type): if isinstance(x, list) and torch.is_tensor(x[0]): - return torch.stack(x).type(int_type).to(device) + return torch.stack(x).to(device=device, dtype=int_type) if torch.is_tensor(x): - return x.type(int_type).to(device) + return x.to(device=device, dtype=int_type) else: return torch.tensor(x, dtype=int_type, device=device) def tbool(x, device): if isinstance(x, list) and torch.is_tensor(x[0]): - return torch.stack(x).type(torch.bool).to(device) + return torch.stack(x).to(device=device, dtype=torch.bool) if torch.is_tensor(x): - return x.type(torch.bool).to(device) + return x.to(device=device, dtype=torch.bool) else: return torch.tensor(x, dtype=torch.bool, device=device) diff --git a/main.py b/main.py index 127e59001..62a5d064a 100644 --- a/main.py +++ b/main.py @@ -77,7 +77,7 @@ def main(config): if config.n_samples > 0 and config.n_samples 
<= 1e5: batch, times = gflownet.sample_batch(n_forward=config.n_samples, train=False) x_sampled = batch.get_terminating_states(proxy=True) - energies = env.oracle(x_sampled) + energies = env.proxy(x_sampled) x_sampled = batch.get_terminating_states() df = pd.DataFrame( { diff --git a/tests/gflownet/envs/common.py b/tests/gflownet/envs/common.py index 862c8bdcd..fd8c296bf 100644 --- a/tests/gflownet/envs/common.py +++ b/tests/gflownet/envs/common.py @@ -213,16 +213,6 @@ def test__sample_backwards_reaches_source(env, n=100): assert n_actions <= env.max_traj_length -@pytest.mark.repeat(100) -def test__state2policy__is_reversible(env): - env = env.reset() - while not env.done: - state_recovered = env.policy2state(env.state2policy()) - if state_recovered is not None: - assert env.equal(env.state, state_recovered) - env.step_random() - - @pytest.mark.repeat(100) def test__state2readable__is_reversible(env): env = env.reset() diff --git a/tests/gflownet/envs/test_ccube.py b/tests/gflownet/envs/test_ccube.py index 25181ed4e..462eaf2f9 100644 --- a/tests/gflownet/envs/test_ccube.py +++ b/tests/gflownet/envs/test_ccube.py @@ -1097,10 +1097,8 @@ def test__state2policy_returns_expected(env, state, expected): ], ) @pytest.mark.skip(reason="skip while developping other tests") -def test__statetorch2policy_returns_expected(env, states, expected): - assert torch.equal( - env.statetorch2policy(torch.tensor(states)), torch.tensor(expected) - ) +def test__states2policy_returns_expected(env, states, expected): + assert torch.equal(env.states2policy(torch.tensor(states)), torch.tensor(expected)) @pytest.mark.parametrize( diff --git a/tests/gflownet/envs/test_composition.py b/tests/gflownet/envs/test_composition.py index 31fcd470e..eb6add01e 100644 --- a/tests/gflownet/envs/test_composition.py +++ b/tests/gflownet/envs/test_composition.py @@ -4,6 +4,7 @@ import torch from gflownet.envs.crystals.composition import Composition +from gflownet.utils.common import tlong @pytest.fixture @@ -74,8 +75,8 @@ def test__environment__initializes_properly(elements): ), ], ) -def test__state2oracle__returns_expected_tensor(env, state, exp_tensor): - assert torch.equal(env.state2oracle(state), torch.Tensor(exp_tensor)) +def test__state2proxy__returns_expected_tensor(env, state, exp_tensor): + assert torch.equal(env.state2proxy(state), tlong(exp_tensor, device=env.device)) def test__state2readable(env): diff --git a/tests/gflownet/envs/test_crystal.py b/tests/gflownet/envs/test_crystal.py index 5d4c9cbed..75dae0167 100644 --- a/tests/gflownet/envs/test_crystal.py +++ b/tests/gflownet/envs/test_crystal.py @@ -110,8 +110,8 @@ def test__pad_depad_action(env): ], ], ) -def test__state2oracle__returns_expected_value(env, state, expected): - assert torch.allclose(env.state2oracle(state), expected, atol=1e-4) +def test__state2proxy__returns_expected_value(env, state, expected): + assert torch.allclose(env.state2proxy(state), expected, atol=1e-4) @pytest.mark.parametrize( @@ -216,8 +216,8 @@ def test__state2proxy__returns_expected_value(env, state, expected): ], ], ) -def test__statebatch2proxy__returns_expected_value(env, batch, expected): - assert torch.allclose(env.statebatch2proxy(batch), expected, atol=1e-4) +def test__states2proxy__returns_expected_value(env, batch, expected): + assert torch.allclose(env.states2proxy(batch), expected, atol=1e-4) @pytest.mark.parametrize("action", [(1, 1, -2, -2, -2, -2), (3, 4, -2, -2, -2, -2)]) diff --git a/tests/gflownet/envs/test_grid.py b/tests/gflownet/envs/test_grid.py index 
2bf3ea4bf..afa375649 100644 --- a/tests/gflownet/envs/test_grid.py +++ b/tests/gflownet/envs/test_grid.py @@ -3,6 +3,7 @@ import torch from gflownet.envs.grid import Grid +from gflownet.utils.common import tfloat @pytest.fixture @@ -45,7 +46,7 @@ def config_path(): @pytest.mark.parametrize( - "state, state2oracle", + "state, state2proxy", [ ( [0, 0, 0], @@ -65,12 +66,15 @@ def config_path(): ), ], ) -def test__state2oracle__returns_expected(env, state, state2oracle): - assert state2oracle == env.state2oracle(state) +def test__state2proxy__returns_expected(env, state, state2proxy): + assert torch.equal( + tfloat(state2proxy, device=env.device, float_type=env.float), + env.state2proxy(state), + ) @pytest.mark.parametrize( - "states, statebatch2oracle", + "states, states2proxy", [ ( [[0, 0, 0], [4, 4, 4], [1, 2, 3], [4, 0, 1]], @@ -78,8 +82,8 @@ def test__state2oracle__returns_expected(env, state, state2oracle): ), ], ) -def test__statebatch2oracle__returns_expected(env, states, statebatch2oracle): - assert torch.equal(torch.Tensor(statebatch2oracle), env.statebatch2oracle(states)) +def test__states2proxy__returns_expected(env, states, states2proxy): + assert torch.equal(torch.Tensor(states2proxy), env.states2proxy(states)) @pytest.mark.parametrize( diff --git a/tests/gflownet/envs/test_lattice_parameters.py b/tests/gflownet/envs/test_lattice_parameters.py index 16aea2814..281d30783 100644 --- a/tests/gflownet/envs/test_lattice_parameters.py +++ b/tests/gflownet/envs/test_lattice_parameters.py @@ -282,8 +282,8 @@ def test__step__changes_state_as_expected(env, lattice_system, actions, exp_stat ), ], ) -def test__state2oracle__returns_expected_tensor(env, lattice_system, state, exp_tensor): - assert torch.equal(env.state2oracle(state), torch.Tensor(exp_tensor)) +def test__state2proxy__returns_expected_tensor(env, lattice_system, state, exp_tensor): + assert torch.equal(env.state2proxy(state), torch.Tensor(exp_tensor)) @pytest.mark.parametrize("lattice_system", [TRICLINIC]) diff --git a/tests/gflownet/utils/test_batch.py b/tests/gflownet/utils/test_batch.py index 338dfd061..776bba324 100644 --- a/tests/gflownet/utils/test_batch.py +++ b/tests/gflownet/utils/test_batch.py @@ -123,14 +123,7 @@ def test__get_states__single_env_returns_expected(env, batch, request): assert torch.equal(torch.stack(states_batch), torch.stack(states)) else: assert states_batch == states - assert torch.equal( - states_policy_batch, - tfloat( - env.statebatch2policy(states), - device=batch.device, - float_type=batch.float, - ), - ) + assert torch.equal(states_policy_batch, env.states2policy(states)) @pytest.mark.repeat(N_REPETITIONS) @@ -155,14 +148,7 @@ def test__get_parents__single_env_returns_expected(env, batch, request): assert torch.equal(torch.stack(parents_batch), torch.stack(parents)) else: assert parents_batch == parents - assert torch.equal( - parents_policy_batch, - tfloat( - env.statebatch2policy(parents), - device=batch.device, - float_type=batch.float, - ), - ) + assert torch.equal(parents_policy_batch, env.states2policy(parents)) @pytest.mark.repeat(N_REPETITIONS) @@ -197,14 +183,7 @@ def test__get_parents_all__single_env_returns_expected(env, batch, request): float_type=batch.float, ), ) - assert torch.equal( - parents_all_policy_batch, - tfloat( - env.statebatch2policy(parents_all), - device=batch.device, - float_type=batch.float, - ), - ) + assert torch.equal(parents_all_policy_batch, env.states2policy(parents_all)) @pytest.mark.repeat(N_REPETITIONS) @@ -365,14 +344,7 @@ def 
test__forward_sampling_multiple_envs_all_as_expected(env, proxy, batch, requ assert torch.equal(torch.stack(states_batch), torch.stack(states)) else: assert states_batch == states - assert torch.equal( - states_policy_batch, - tfloat( - env.statebatch2policy(states), - device=batch.device, - float_type=batch.float, - ), - ) + assert torch.equal(states_policy_batch, env.states2policy(states)) # Check actions actions_batch = batch.get_actions() assert torch.equal( @@ -399,14 +371,7 @@ def test__forward_sampling_multiple_envs_all_as_expected(env, proxy, batch, requ assert torch.equal(torch.stack(parents_batch), torch.stack(parents)) else: assert parents_batch == parents - assert torch.equal( - parents_policy_batch, - tfloat( - env.statebatch2policy(parents), - device=batch.device, - float_type=batch.float, - ), - ) + assert torch.equal(parents_policy_batch, env.states2policy(parents)) # Check parents_all if not env.continuous: parents_all_batch, parents_all_a_batch, _ = batch.get_parents_all() @@ -423,14 +388,7 @@ def test__forward_sampling_multiple_envs_all_as_expected(env, proxy, batch, requ float_type=batch.float, ), ) - assert torch.equal( - parents_all_policy_batch, - tfloat( - env.statebatch2policy(parents_all), - device=batch.device, - float_type=batch.float, - ), - ) + assert torch.equal(parents_all_policy_batch, env.states2policy(parents_all)) # Check rewards rewards_batch = batch.get_rewards() rewards = torch.stack(rewards) @@ -447,14 +405,7 @@ def test__forward_sampling_multiple_envs_all_as_expected(env, proxy, batch, requ ) else: assert states_term_batch == states_term_sorted - assert torch.equal( - states_term_policy_batch, - tfloat( - env.statebatch2policy(states_term_sorted), - device=batch.device, - float_type=batch.float, - ), - ) + assert torch.equal(states_term_policy_batch, env.states2policy(states_term_sorted)) @pytest.mark.repeat(N_REPETITIONS) @@ -551,14 +502,7 @@ def test__backward_sampling_multiple_envs_all_as_expected(env, proxy, batch, req assert torch.equal(torch.stack(states_batch), torch.stack(states)) else: assert states_batch == states - assert torch.equal( - states_policy_batch, - tfloat( - env.statebatch2policy(states), - device=batch.device, - float_type=batch.float, - ), - ) + assert torch.equal(states_policy_batch, env.states2policy(states)) # Check actions actions_batch = batch.get_actions() assert torch.equal( @@ -585,14 +529,7 @@ def test__backward_sampling_multiple_envs_all_as_expected(env, proxy, batch, req assert torch.equal(torch.stack(parents_batch), torch.stack(parents)) else: assert parents_batch == parents - assert torch.equal( - parents_policy_batch, - tfloat( - env.statebatch2policy(parents), - device=batch.device, - float_type=batch.float, - ), - ) + assert torch.equal(parents_policy_batch, env.states2policy(parents)) # Check parents_all if not env.continuous: parents_all_batch, parents_all_a_batch, _ = batch.get_parents_all() @@ -609,14 +546,7 @@ def test__backward_sampling_multiple_envs_all_as_expected(env, proxy, batch, req float_type=batch.float, ), ) - assert torch.equal( - parents_all_policy_batch, - tfloat( - env.statebatch2policy(parents_all), - device=batch.device, - float_type=batch.float, - ), - ) + assert torch.equal(parents_all_policy_batch, env.states2policy(parents_all)) # Check rewards rewards_batch = batch.get_rewards() rewards = torch.stack(rewards) @@ -633,14 +563,7 @@ def test__backward_sampling_multiple_envs_all_as_expected(env, proxy, batch, req ) else: assert states_term_batch == states_term_sorted - assert torch.equal( 
- states_term_policy_batch, - tfloat( - env.statebatch2policy(states_term_sorted), - device=batch.device, - float_type=batch.float, - ), - ) + assert torch.equal(states_term_policy_batch, env.states2policy(states_term_sorted)) @pytest.mark.repeat(N_REPETITIONS) @@ -794,14 +717,7 @@ def test__mixed_sampling_multiple_envs_all_as_expected(env, proxy, batch, reques assert torch.equal(torch.stack(states_batch), torch.stack(states)) else: assert states_batch == states - assert torch.equal( - states_policy_batch, - tfloat( - env.statebatch2policy(states), - device=batch.device, - float_type=batch.float, - ), - ) + assert torch.equal(states_policy_batch, env.states2policy(states)) # Check actions actions_batch = batch.get_actions() assert torch.equal( @@ -828,14 +744,7 @@ def test__mixed_sampling_multiple_envs_all_as_expected(env, proxy, batch, reques assert torch.equal(torch.stack(parents_batch), torch.stack(parents)) else: assert parents_batch == parents - assert torch.equal( - parents_policy_batch, - tfloat( - env.statebatch2policy(parents), - device=batch.device, - float_type=batch.float, - ), - ) + assert torch.equal(parents_policy_batch, env.states2policy(parents)) # Check parents_all if not env.continuous: parents_all_batch, parents_all_a_batch, _ = batch.get_parents_all() @@ -852,14 +761,7 @@ def test__mixed_sampling_multiple_envs_all_as_expected(env, proxy, batch, reques float_type=batch.float, ), ) - assert torch.equal( - parents_all_policy_batch, - tfloat( - env.statebatch2policy(parents_all), - device=batch.device, - float_type=batch.float, - ), - ) + assert torch.equal(parents_all_policy_batch, env.states2policy(parents_all)) # Check rewards rewards_batch = batch.get_rewards() rewards = torch.stack(rewards) @@ -876,14 +778,7 @@ def test__mixed_sampling_multiple_envs_all_as_expected(env, proxy, batch, reques ) else: assert states_term_batch == states_term_sorted - assert torch.equal( - states_term_policy_batch, - tfloat( - env.statebatch2policy(states_term_sorted), - device=batch.device, - float_type=batch.float, - ), - ) + assert torch.equal(states_term_policy_batch, env.states2policy(states_term_sorted)) @pytest.mark.repeat(N_REPETITIONS) @@ -1043,14 +938,7 @@ def test__mixed_sampling_merged_all_as_expected(env, proxy, request): assert torch.equal(torch.stack(states_batch), torch.stack(states)) else: assert states_batch == states - assert torch.equal( - states_policy_batch, - tfloat( - env.statebatch2policy(states), - device=batch.device, - float_type=batch.float, - ), - ) + assert torch.equal(states_policy_batch, env.states2policy(states)) # Check actions actions_batch = batch.get_actions() assert torch.equal( @@ -1077,14 +965,7 @@ def test__mixed_sampling_merged_all_as_expected(env, proxy, request): assert torch.equal(torch.stack(parents_batch), torch.stack(parents)) else: assert parents_batch == parents - assert torch.equal( - parents_policy_batch, - tfloat( - env.statebatch2policy(parents), - device=batch.device, - float_type=batch.float, - ), - ) + assert torch.equal(parents_policy_batch, env.states2policy(parents)) # Check parents_all if not env.continuous: parents_all_batch, parents_all_a_batch, _ = batch.get_parents_all() @@ -1101,14 +982,7 @@ def test__mixed_sampling_merged_all_as_expected(env, proxy, request): float_type=batch.float, ), ) - assert torch.equal( - parents_all_policy_batch, - tfloat( - env.statebatch2policy(parents_all), - device=batch.device, - float_type=batch.float, - ), - ) + assert torch.equal(parents_all_policy_batch, env.states2policy(parents_all)) # 
Check rewards rewards_batch = batch.get_rewards() rewards = torch.stack(rewards) @@ -1125,14 +999,7 @@ def test__mixed_sampling_merged_all_as_expected(env, proxy, request): ) else: assert states_term_batch == states_term_sorted - assert torch.equal( - states_term_policy_batch, - tfloat( - env.statebatch2policy(states_term_sorted), - device=batch.device, - float_type=batch.float, - ), - ) + assert torch.equal(states_term_policy_batch, env.states2policy(states_term_sorted)) @pytest.mark.repeat(N_REPETITIONS)
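
The hunks above replace the per-state and per-batch conversion methods (`state2policy`, `statebatch2policy`, `statetorch2policy` and their proxy/oracle counterparts) with single `states2policy` / `states2proxy` methods that accept either a list of states in environment format or a batch tensor, and handle dtype and device conversion internally via `tfloat` / `tlong` / `tint`. As a minimal usage sketch of the unified API (not part of the patch; the `Torus` constructor call below is schematic and the actual signature may require additional arguments):

```python
# Illustrative sketch only: constructor arguments are assumptions, not taken
# from this patch. The one-hot layout follows the Torus.states2policy docstring.
import torch

from gflownet.envs.torus import Torus

env = Torus(n_dim=2, n_angles=4)

# A batch of states in "environment format": [angle_0, angle_1, n_actions]
states = [[1, 3, 4], [0, 2, 4]]

# Proxy format: angles in radians, n_actions item removed -> shape [2, 2]
states_proxy = env.states2proxy(states)

# Policy format: one-hot block per dimension plus the n_actions counter
# -> shape [2, n_angles * n_dim + 1] = [2, 9], e.g. [1, 3, 4] -> [0, 1, 0, 0, 0, 0, 0, 1, 4]
states_policy = env.states2policy(states)

# The same methods accept a batch tensor directly instead of a list of states
assert torch.equal(states_policy, env.states2policy(torch.tensor(states)))
```

Callers in `Batch`, `Buffer`, `GFlowNetAgent` and the tests are updated accordingly, so the environment methods are now the single place where batching, dtype and device handling happen for these conversions.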