[Feature] Randint on device for buffers
ghstack-source-id: b055d47928161b6a081705872a91434d65a8b92a
Pull Request resolved: #2470
vmoens committed Oct 9, 2024
1 parent 011ce2c commit c8b508e
Showing 2 changed files with 10 additions and 2 deletions.
1 change: 1 addition & 0 deletions torchrl/data/replay_buffers/samplers.py
@@ -2148,6 +2148,7 @@ def sample(self, storage, batch_size):
                 len(self._samplers),
                 (self.num_buffer_sampled,),
                 generator=self._rng,
+                device=getattr(storage, "device", None),
             )
         else:
             buffer_ids = torch.multinomial(self.p, self.num_buffer_sampled, True)
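The pattern above, as a minimal self-contained sketch (the variable names below are illustrative stand-ins, not torchrl's): passing device= to torch.randint makes the sampled buffer indices land directly on the storage's device, so the indexing that follows needs no host-to-device copy. Note that when a generator is also passed, it must live on the same device as the requested output.

import torch

# Sketch of the change above: draw random indices directly on the
# storage's device to avoid a CPU -> GPU transfer when sampling.
storage_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

num_samplers = 4         # stand-in for len(self._samplers)
num_buffer_sampled = 32  # stand-in for self.num_buffer_sampled

buffer_ids = torch.randint(
    num_samplers,
    (num_buffer_sampled,),
    device=storage_device,  # new in this commit: indices are created on-device
)
print(buffer_ids.device)  # cuda:0 when available, else cpu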
11 changes: 9 additions & 2 deletions torchrl/data/replay_buffers/storages.py
@@ -146,7 +146,13 @@ def _empty(self):
     def _rand_given_ndim(self, batch_size):
         # a method to return random indices given the storage ndim
         if self.ndim == 1:
-            return torch.randint(0, len(self), (batch_size,), generator=self._rng)
+            return torch.randint(
+                0,
+                len(self),
+                (batch_size,),
+                generator=self._rng,
+                device=getattr(self, "device", None),
+            )
         raise RuntimeError(
             f"Random number generation is not implemented for storage of type {type(self)} with ndim {self.ndim}. "
             f"Please report this exception as well as the use case (incl. buffer construction) on github."
@@ -507,7 +513,8 @@ def _rand_given_ndim(self, batch_size):
             return super()._rand_given_ndim(batch_size)
         shape = self.shape
         return tuple(
-            torch.randint(_dim, (batch_size,), generator=self._rng) for _dim in shape
+            torch.randint(_dim, (batch_size,), generator=self._rng, device=self.device)
+            for _dim in shape
         )
 
     def flatten(self):
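End to end, the effect is that a replay buffer backed by GPU storage now draws its sampling indices on that device rather than on CPU. A usage sketch, assuming torchrl's ReplayBuffer and LazyTensorStorage APIs as of this commit (sizes and keys are illustrative):

import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, ReplayBuffer

device = "cuda" if torch.cuda.is_available() else "cpu"

# The storage lives on `device`; with this commit, sample() generates its
# random indices on that same device instead of on CPU.
rb = ReplayBuffer(storage=LazyTensorStorage(1_000, device=device))
rb.extend(TensorDict({"obs": torch.randn(100, 4)}, batch_size=[100]))

batch = rb.sample(32)
print(batch["obs"].device)  # matches the storage device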
