Skip to content

Commit

Permalink
[Feature] inline hold_out_net
Browse files Browse the repository at this point in the history
ghstack-source-id: c315202c8af55f0852195fe488ae855966386c4c
Pull Request resolved: #2499
  • Loading branch information
vmoens committed Oct 17, 2024
1 parent d894358 commit 815eece
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions torchrl/objectives/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,11 +408,18 @@ def __init__(self, network: nn.Module) -> None:

def __enter__(self) -> None:
if self.mode:
self.network.requires_grad_(False)
if is_dynamo_compiling():
self._params = TensorDict.from_module(self.network)
self._params.data.to_module(self.network)
else:
self.network.requires_grad_(False)

def __exit__(self, exc_type, exc_val, exc_tb) -> None:
if self.mode:
self.network.requires_grad_()
if is_dynamo_compiling():
self._params.to_module(self.network)
else:
self.network.requires_grad_()


class hold_out_params(_context_manager):
Expand Down

1 comment on commit 815eece

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'CPU Benchmark Results'.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 2.

Benchmark suite Current: 815eece Previous: d894358 Ratio
benchmarks/test_replaybuffer_benchmark.py::test_rb_populate[TensorDictReplayBuffer-ListStorage-RandomSampler-400] 36.6420406162568 iter/sec (stddev: 0.1617383042882511) 238.30340065427748 iter/sec (stddev: 0.000733430823282556) 6.50

This comment was automatically generated by workflow using github-action-benchmark.

CC: @vmoens

Please sign in to comment.