Simplify state conversions (excludes Crystal environments) #247

Merged 60 commits on Dec 9, 2023
737cda8
checkout files from test_merge_cont_mf
nikita-0209 Apr 11, 2023
afdb762
remove comments
nikita-0209 Apr 11, 2023
e7cd7a5
Update to a slightly faster version of get_parents()
nikita-0209 Apr 11, 2023
b57c695
Merge branch 'main_seq' into sequence-envs
alexhernandezgarcia Aug 1, 2023
7567fe3
Move sequence config files into new dir seqs/
alexhernandezgarcia Aug 1, 2023
2ba9ae4
Move sequence proxy config files into new dir seqs/
alexhernandezgarcia Aug 1, 2023
e607efc
Move sequence env *.py files into new dir seqs/
alexhernandezgarcia Aug 1, 2023
690fde0
wip: refactor of parent sequence class
alexhernandezgarcia Aug 1, 2023
beed5df
Replace original README with info about the new private repository.
alexhernandezgarcia Oct 18, 2023
40c7f33
Merge remote-tracking branch 'public/main' into main-public
alexhernandezgarcia Oct 26, 2023
9a4829d
WIP: few changes to sequence base
alexhernandezgarcia Oct 27, 2023
02ab997
Merge remote-tracking branch 'public/sequence-envs' into main
alexhernandezgarcia Oct 27, 2023
b4919d6
Finished __init__ for now.
alexhernandezgarcia Oct 27, 2023
8ce60ab
Docstring of get_action_space.
alexhernandezgarcia Oct 27, 2023
d280222
Merge pull request #2 from alexhernandezgarcia/main-public
alexhernandezgarcia Oct 29, 2023
1c17674
Merge branch 'main' of github.com:alexhernandezgarcia/gflownet-dev in…
alexhernandezgarcia Oct 29, 2023
8b0b621
Remove env.oracle and leave proxy only
alexhernandezgarcia Oct 29, 2023
0088ee4
Grid: states2proxy, states2policy; temporary state because old code i…
alexhernandezgarcia Oct 30, 2023
c83d8b3
docstring
alexhernandezgarcia Oct 30, 2023
04bb59f
Tetris: states2proxy, states2policy; temporary state because old code…
alexhernandezgarcia Oct 30, 2023
41962f9
Refactoring of tfloat, tlogn, tint and tbool
alexhernandezgarcia Oct 30, 2023
e446262
Typo in typing and correct docstrings
alexhernandezgarcia Oct 30, 2023
be67b5c
Cube: states2policy; temporary state because old code is still there.
alexhernandezgarcia Oct 30, 2023
89ebdaf
Fix typos
alexhernandezgarcia Oct 31, 2023
43bf2ea
Continuous Tori: states2proxy, states2policy; temporary state because…
alexhernandezgarcia Oct 31, 2023
7593578
Alanine Dipeptide: states2proxy; temporary state because old code is …
alexhernandezgarcia Oct 31, 2023
72a31ea
Missing imports
alexhernandezgarcia Oct 31, 2023
f7a8d2b
Discrete Torus: states2proxy, states2policy; temporary state because …
alexhernandezgarcia Oct 31, 2023
7201991
Tree: remove mention to oracle
alexhernandezgarcia Oct 31, 2023
127a017
Composition: states2proxy; temporary state because old code is still …
alexhernandezgarcia Oct 31, 2023
650dfba
Lattice parameters: states2proxy; temporary state because old code is…
alexhernandezgarcia Oct 31, 2023
cdf4961
Space group: states2proxy; temporary state because old code is still …
alexhernandezgarcia Oct 31, 2023
e7ee4af
Crystal: states2proxy; temporary state because old code is still there.
alexhernandezgarcia Oct 31, 2023
6792ebd
Base: statebatch2proxy and statetorch2proxy unified into states2proxy…
alexhernandezgarcia Oct 31, 2023
c5c2fd9
Remove policy2state from all environments and tests because it is not…
alexhernandezgarcia Oct 31, 2023
7c2ba15
Tree: states2proxy
alexhernandezgarcia Oct 31, 2023
df5af32
All environments: remove statebatch2proxy, statetorch2proxy, statebat…
alexhernandezgarcia Oct 31, 2023
9107d63
All tests: remove statebatch2proxy, statetorch2proxy, statebatch2poli…
alexhernandezgarcia Oct 31, 2023
ed03294
gflownet, buffer and batch: remove statebatch2proxy, statetorch2proxy…
alexhernandezgarcia Oct 31, 2023
c9eac79
Fix how policy_input_dim is computed
alexhernandezgarcia Oct 31, 2023
828141c
Tetris: policy output to float
alexhernandezgarcia Oct 31, 2023
94c0c32
Fix typo
alexhernandezgarcia Oct 31, 2023
ad2c7f6
squeeze output of state2policy and state2proxy and revert to previous…
alexhernandezgarcia Oct 31, 2023
17407b5
Add missing import
alexhernandezgarcia Oct 31, 2023
1f817b1
Update composition and crystal tests
alexhernandezgarcia Oct 31, 2023
ea16fc2
Discrete torus: Remove state2policy and output of policy to float
alexhernandezgarcia Oct 31, 2023
55b8ed3
Update env.reward() in base
alexhernandezgarcia Oct 31, 2023
1e71a6d
Test batch: statebatch2policy -> states2policy
alexhernandezgarcia Oct 31, 2023
a50cd7d
Envs: remove state2proxy
alexhernandezgarcia Oct 31, 2023
3880e48
Remove reward_torchbatch because it is unused
alexhernandezgarcia Oct 31, 2023
03308a9
Add TODO
alexhernandezgarcia Oct 31, 2023
5e3c765
statetorch2 -> states2
alexhernandezgarcia Oct 31, 2023
388816c
Fix
alexhernandezgarcia Oct 31, 2023
14e30c1
Delete files not relevant to branch
alexhernandezgarcia Oct 31, 2023
ce38e3f
Delete files not relevant to branch
alexhernandezgarcia Oct 31, 2023
88d3fa6
Delete files not relevant to branch
alexhernandezgarcia Oct 31, 2023
f0e3feb
Delete files not relevant to branch
alexhernandezgarcia Oct 31, 2023
cc634cf
Fix tensor comparison
alexhernandezgarcia Nov 14, 2023
11d5fc4
Add typing in returns of state2policy and state2proxy of base env.
alexhernandezgarcia Nov 14, 2023
fe105d1
Resolve conflicts with main
alexhernandezgarcia Dec 9, 2023
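
Several of the commits above replace the type-specific converters (`statebatch2proxy`, `statetorch2proxy` and their policy counterparts) with single `states2proxy` and `states2policy` methods. The snippet below is only a minimal sketch of that pattern, not the code from this PR: `ToyEnv` is a hypothetical class that does not exist in gflownet, and NumPy arrays stand in for torch tensors so that it runs without PyTorch.

```python
from typing import List, Union

import numpy as np


class ToyEnv:
    """Hypothetical environment used for illustration; not part of gflownet."""

    def states2proxy(
        self, states: Union[List[List[float]], np.ndarray]
    ) -> np.ndarray:
        # Single entry point: np.asarray accepts both a list of states and an
        # array batch, so callers no longer pick a method by input type.
        return np.asarray(states, dtype=np.float32)

    def states2policy(
        self, states: Union[List[List[float]], np.ndarray]
    ) -> np.ndarray:
        # In this toy case the policy sees the same encoding as the proxy.
        return self.states2proxy(states)


env = ToyEnv()
from_list = env.states2proxy([[0, 1], [2, 3]])
from_array = env.states2proxy(np.array([[0, 1], [2, 3]]))
```

Both calls go through the same method and yield the same batch, which is what lets call sites in the agent, buffer and batch code drop their type-dependent branching.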
42 changes: 7 additions & 35 deletions README.md
@@ -1,39 +1,11 @@
-# GFlowNet
+# Private sister repository of gflownet

-This repository implements GFlowNets, generative flow networks for probabilistic modelling, on PyTorch. A design guideline behind this implementation is the separation of the logic of the GFlowNet agent and the environments on which the agent can be trained on. In other words, this implementation should allow its extension with new environments without major or any changes to to the agent. Another design guideline is flexibility and modularity. The configuration is handled via the use of [Hydra](https://hydra.cc/docs/intro/).
+This repository (`gflownet-dev`) is private. It is meant to be used to develop research ideas and projects before making them public in the original [alexhernandezgarcia/gflownet](https://github.com/alexhernandezgarcia/gflownet) repository (`gflownet`).

-## Installation
+As of October 2023, it is uncertain whether we will stick to this plan in the long term, but the idea is the following:

-### pip
+- Develop ideas and projects in `gflownet-dev`.
+- Upon publication or whenever the authors feel comfortable, transfer the relevant code to `gflownet`.
+- Relevant code improvements and development that does not compromise research projects should be transferred to `gflownet` as early as possible.

-```bash
-python -m pip install --upgrade https://github.com/alexhernandezgarcia/gflownet/archive/main.zip
-```
-
-## How to train a GFlowNet model
-
-To train a GFlowNet model with the default configuration, simply run
-
-```bash
-python main.py user.logdir.root=<path/to/log/files/>
-```
-
-Alternatively, you can create a user configuration file in `config/user/<username>.yaml` specifying a `logdir.root` and run
-
-```bash
-python main.py user=<username>
-```
-
-Using Hydra, you can easily specify any variable of the configuration in the command line. For example, to train GFlowNet with the trajectory balance loss, on the continuous torus (`ctorus`) environment and the corresponding proxy:
-
-```bash
-python main.py gflownet=trajectorybalance env=ctorus proxy=torus
-```
-
-The above command will overwrite the `env` and `proxy` default configuration with the configuration files in `config/env/ctorus.yaml` and `config/proxy/torus.yaml` respectively.
-
-Hydra configuration is hierarchical. For instance, a handy variable to change while debugging our code is to avoid logging to wandb. You can do this by setting `logger.do.online=False`.
-
-## Logging to wandb
-
-The repository supports logging of train and evaluation metrics to [wandb.ai](https://wandb.ai), but it is disabled by default. In order to enable it, set the configuration variable `logger.do.online` to `True`.
+This involves extra complexity, so we will re-evaluate or refine this plan after a test period.
16 changes: 0 additions & 16 deletions config/env/aptamers.yaml

This file was deleted.

42 changes: 26 additions & 16 deletions gflownet/envs/alaninedipeptide.py
@@ -1,5 +1,5 @@
 from copy import deepcopy
-from typing import List, Tuple
+from typing import List, Tuple, Union

 import numpy as np
 import numpy.typing as npt
@@ -40,25 +40,34 @@ def sync_conformer_with_state(self, state: List = None):
             self.conformer.set_torsion_angle(ta, state[idx])
         return self.conformer

-    def statetorch2proxy(self, states: TensorType["batch", "state_dim"]) -> npt.NDArray:
+    # TODO: are the conversions to oracle relevant?
+    def states2proxy(
+        self, states: Union[List[List], TensorType["batch", "state_dim"]]
+    ) -> npt.NDArray:
         """
-        Prepares a batch of states in torch "GFlowNet format" for the oracle.
-        """
-        device = states.device
-        if device == torch.device("cpu"):
-            np_states = states.numpy()
-        else:
-            np_states = states.cpu().numpy()
-        return np_states[:, :-1]
-
-    def statebatch2proxy(self, states: List[List]) -> npt.NDArray:
-        """
-        Prepares a batch of states in "GFlowNet format" for the proxy: a tensor where
-        each state is a row of length n_dim with an angle in radians. The n_actions
+        Prepares a batch of states in "environment format" for the proxy: each state is
+        a vector of length n_dim where each value is an angle in radians. The n_actions
         item is removed.
+
+        Important: this method returns a numpy array, unlike in most other
+        environments.
+
+        Args
+        ----
+        states : list or tensor
+            A batch of states in environment format, either as a list of states or as a
+            single tensor.
+
+        Returns
+        -------
+        A numpy array containing all the states in the batch.
         """
-        return np.array(states)[:, :-1]
+        if torch.is_tensor(states[0]):
+            return states.cpu().numpy()[:, :-1]
+        else:
+            return np.array(states)[:, :-1]

+    # TODO: need to keep?
     def statetorch2oracle(
         self, states: TensorType["batch", "state_dim"]
     ) -> List[Tuple[npt.NDArray, npt.NDArray]]:
@@ -73,6 +82,7 @@ def statetorch2oracle(
         result = self.statebatch2oracle(np_states)
         return result

+    # TODO: need to keep?
     def statebatch2oracle(
         self, states: List[List]
     ) -> List[Tuple[npt.NDArray, npt.NDArray]]:
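
The new `states2proxy` in the `alaninedipeptide.py` diff accepts either a list of states or a single tensor, drops the trailing `n_actions` item from each state and returns a numpy array. The standalone sketch below illustrates that behaviour under a few assumptions: the function name `states2proxy_sketch` is hypothetical, the torch import is made optional so the snippet also runs without PyTorch, and the tensor check is done on the whole batch rather than on `states[0]` as in the PR.

```python
import numpy as np

try:  # torch is optional for this sketch (assumption, not in the PR)
    import torch
except ImportError:
    torch = None


def states2proxy_sketch(states) -> np.ndarray:
    """Drop the trailing n_actions item of each state; return a numpy array."""
    if torch is not None and torch.is_tensor(states):
        # A single batch tensor: move it to the CPU before converting.
        return states.detach().cpu().numpy()[:, :-1]
    # A list of states, each a list of angles followed by n_actions.
    return np.array(states)[:, :-1]


# Two states with two torsion angles each, plus the n_actions counter.
batch = [[0.0, 1.57, 2.0], [3.14, 0.5, 1.0]]
proxy_input = states2proxy_sketch(batch)
```

Here `proxy_input` has one row per state and one column per angle; a list of tensors is deliberately not handled in this sketch.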