Skip to content

Commit

Permalink
[Feature] MCTSForest
Browse files Browse the repository at this point in the history
ghstack-source-id: efee967055c8dab87bbef442451ce20f21730b13
Pull Request resolved: #2307
  • Loading branch information
vmoens committed Jul 22, 2024
1 parent 8bbb86f commit dda9371
Show file tree
Hide file tree
Showing 4 changed files with 453 additions and 0 deletions.
2 changes: 2 additions & 0 deletions test/test_storage_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,8 @@ def test_map_rollout(self):
assert contains[: rollout.shape[-1]].all()
assert not contains[rollout.shape[-1] :].any()

class TestMCTSForest:
...

if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
Expand Down
3 changes: 3 additions & 0 deletions torchrl/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
from .map import (
BinaryToDecimal,
HashToInt,
MCTSChildren,
MCTSForest,
MCTSNode,
QueryModule,
RandomProjectionHash,
SipHash,
Expand Down
1 change: 1 addition & 0 deletions torchrl/data/map/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@
from .hash import BinaryToDecimal, RandomProjectionHash, SipHash
from .query import HashToInt, QueryModule
from .tdstorage import TensorDictMap, TensorMap
from .tree import MCTSChildren, MCTSForest, MCTSNode
Loading

0 comments on commit dda9371

Please sign in to comment.