Skip to content

Commit

Permalink
[Feature] MCTSForest
Browse files Browse the repository at this point in the history
ghstack-source-id: 761863077202685046b093070361970c88081e36
Pull Request resolved: #2307
  • Loading branch information
vmoens committed Jul 23, 2024
1 parent b2536e3 commit d1bc98a
Show file tree
Hide file tree
Showing 4 changed files with 457 additions and 0 deletions.
6 changes: 6 additions & 0 deletions test/test_storage_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,12 @@ def test_map_rollout(self):
assert contains[: rollout.shape[-1]].all()
assert not contains[rollout.shape[-1] :].any()

class TestMCTSForest:
def test_forest_build(self):
...
def test_forest_extend_and_get(self):
...


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
Expand Down
3 changes: 3 additions & 0 deletions torchrl/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
from .map import (
BinaryToDecimal,
HashToInt,
MCTSChildren,
MCTSForest,
MCTSNode,
QueryModule,
RandomProjectionHash,
SipHash,
Expand Down
1 change: 1 addition & 0 deletions torchrl/data/map/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@
from .hash import BinaryToDecimal, RandomProjectionHash, SipHash
from .query import HashToInt, QueryModule
from .tdstorage import TensorDictMap, TensorMap
from .tree import MCTSChildren, MCTSForest, MCTSNode
Loading

0 comments on commit d1bc98a

Please sign in to comment.