
Functionality to move the posterior to GPU/CPU #1368

Open
michaeldeistler opened this issue Jan 15, 2025 · 1 comment
Labels
enhancement New feature or request

Comments

@michaeldeistler
Contributor

In many cases, it is useful to train the network on GPU but then run inference or diagnostics on CPU. This is currently a bit hacky, see also here.

It would be nice to have something like:

posterior.to("cpu")
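A minimal sketch of what such a method might look like. Everything here is hypothetical (the class name, attributes, and constructor are illustrative only); the real posterior classes hold more state, e.g. the prior and cached defaults, which a proper `.to()` would also have to move:

```python
import torch
from torch import nn


class PosteriorSketch:
    """Hypothetical posterior wrapper; only illustrates device handling."""

    def __init__(self, net: nn.Module, x_o: torch.Tensor):
        self.net = net  # trained density estimator, possibly on GPU
        self.x_o = x_o  # cached default observation used for conditioning

    def to(self, device) -> "PosteriorSketch":
        # Move the network's parameters/buffers and all cached tensors.
        self.net.to(device)
        self.x_o = self.x_o.to(device)
        return self


posterior = PosteriorSketch(nn.Linear(3, 2), x_o=torch.zeros(3))
posterior.to("cpu")  # after this, sampling would run entirely on CPU
```

The key design point is that the posterior owns tensors outside the `nn.Module` (like the default `x`), so `posterior.net.to(device)` alone is not enough.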
michaeldeistler added the enhancement (New feature or request) label on Jan 15, 2025
@ali-akhavan89

ali-akhavan89 commented Jan 22, 2025

I am on sbi 0.23.2 and noticed that NFlowsFlow.sample() begins like this:

def sample(self, sample_shape: Shape, condition: Tensor) -> Tensor:
    condition_batch_dim = condition.shape[0]
    num_samples = torch.Size(sample_shape).numel()

    samples = self.net.sample(num_samples, context=condition)
    ...

Adding the following two lines would fix a lot of CPU/GPU handling:

def sample(self, sample_shape: Shape, condition: Tensor) -> Tensor:
    # Look up the device the flow's parameters live on ...
    net_device = next(self.net.parameters()).device
    # ... and move the condition there before sampling.
    condition = condition.to(net_device)

    condition_batch_dim = condition.shape[0]
    num_samples = torch.Size(sample_shape).numel()

    samples = self.net.sample(num_samples, context=condition)
    ...
    return samples
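The device lookup used above can be sketched standalone (a hedged example, not sbi code; the function name is made up for illustration):

```python
import torch
from torch import nn


def match_module_device(module: nn.Module, tensor: torch.Tensor) -> torch.Tensor:
    """Move `tensor` to the device of `module`'s first parameter.

    Mirrors the `next(self.net.parameters()).device` pattern above; it
    assumes the module has at least one parameter and that all of its
    parameters live on the same device.
    """
    device = next(module.parameters()).device
    return tensor.to(device)


net = nn.Linear(4, 4)          # lives on CPU in this snippet
condition = torch.randn(8, 4)  # could have come from another device
condition = match_module_device(net, condition)
```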

I hit this when I tried to move everything to CPU after training on GPU; I was getting this error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[6], line 2
      1 for _ in range(num_rounds):
----> 2     theta = proposal.sample((num_simulations,)).cpu()
      3     x = simulator(theta).cpu()
      4     print(theta.device)

File ~\.conda\envs\ptgpu\Lib\site-packages\sbi\inference\posteriors\direct_posterior.py:134, in DirectPosterior.sample(self, sample_shape, x, max_sampling_batch_size, sample_with, show_progress_bars)
    127 if sample_with is not None:
    128     raise ValueError(
    129         f"You set `sample_with={sample_with}`. As of sbi v0.18.0, setting "
    130         f"`sample_with` is no longer supported. You have to rerun "
    131         f"`.build_posterior(sample_with={sample_with}).`"
    132     )
--> 134 samples = rejection.accept_reject_sample(
    135     proposal=self.posterior_estimator,
    136     accept_reject_fn=lambda theta: within_support(self.prior, theta),
    137     num_samples=num_samples,
    138     show_progress_bars=show_progress_bars,
    139     max_sampling_batch_size=max_sampling_batch_size,
    140     proposal_sampling_kwargs={"condition": x},
    141     alternative_method="build_posterior(..., sample_with='mcmc')",
    142 )[0]  # [0] to return only samples, not acceptance probabilities.
    144 return samples[:, 0]

File ~\.conda\envs\ptgpu\Lib\site-packages\torch\utils\_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    113 @functools.wraps(func)
    114 def decorate_context(*args, **kwargs):
    115     with ctx_factory():
--> 116         return func(*args, **kwargs)

File ~\.conda\envs\ptgpu\Lib\site-packages\sbi\samplers\rejection\rejection.py:281, in accept_reject_sample(proposal, accept_reject_fn, num_samples, show_progress_bars, warn_acceptance, sample_for_correction_factor, max_sampling_batch_size, proposal_sampling_kwargs, alternative_method, **kwargs)
    278 num_samples_possible = 0
    279 while num_remaining > 0:
    280     # Sample and reject.
--> 281     candidates = proposal.sample(
    282         (sampling_batch_size,),  # type: ignore
    283         **proposal_sampling_kwargs,
    284     )
    285     # SNPE-style rejection-sampling when the proposal is the neural net.
    286     are_accepted = accept_reject_fn(candidates)

File ~\.conda\envs\ptgpu\Lib\site-packages\sbi\neural_nets\estimators\nflows_flow.py:142, in NFlowsFlow.sample(self, sample_shape, condition)
    139 condition_batch_dim = condition.shape[0]
    140 num_samples = torch.Size(sample_shape).numel()
--> 142 samples = self.net.sample(num_samples, context=condition)
    143 # Change from Nflows' convention of (batch_dim, sample_dim, *event_shape) to
    144 # (sample_dim, batch_dim, *event_shape) (PyTorch + SBI).
    145 samples = samples.transpose(0, 1)

File ~\.conda\envs\ptgpu\Lib\site-packages\nflows\distributions\base.py:65, in Distribution.sample(self, num_samples, context, batch_size)
     62     context = torch.as_tensor(context)
     64 if batch_size is None:
---> 65     return self._sample(num_samples, context)
     67 else:
     68     if not check.is_positive_int(batch_size):

File ~\.conda\envs\ptgpu\Lib\site-packages\nflows\flows\base.py:44, in Flow._sample(self, num_samples, context)
     43 def _sample(self, num_samples, context):
---> 44     embedded_context = self._embedding_net(context)
     45     noise = self._distribution.sample(num_samples, context=embedded_context)
     47     if embedded_context is not None:
     48         # Merge the context dimension with sample dimension in order to apply the transform.

File ~\.conda\envs\ptgpu\Lib\site-packages\torch\nn\modules\module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~\.conda\envs\ptgpu\Lib\site-packages\torch\nn\modules\module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~\.conda\envs\ptgpu\Lib\site-packages\torch\nn\modules\container.py:250, in Sequential.forward(self, input)
    248 def forward(self, input):
    249     for module in self:
--> 250         input = module(input)
    251     return input

File ~\.conda\envs\ptgpu\Lib\site-packages\torch\nn\modules\module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~\.conda\envs\ptgpu\Lib\site-packages\torch\nn\modules\module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~\.conda\envs\ptgpu\Lib\site-packages\sbi\utils\sbiutils.py:252, in Standardize.forward(self, tensor)
    251 def forward(self, tensor):
--> 252     return (tensor - self._mean) / self._std

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
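The failure is in Standardize.forward: its `_mean`/`_std` tensors are still on `cuda:0` while the condition arrives on CPU. A minimal stand-in (not sbi's actual class) shows why moving the estimator as a whole module would fix this: tensors registered as buffers follow `module.to(...)` automatically. Since this snippet may run on a machine without a GPU, the effect is demonstrated via a dtype change rather than a device change:

```python
import torch
from torch import nn


class Standardize(nn.Module):
    """Minimal stand-in for sbi's Standardize; not the actual implementation."""

    def __init__(self, mean: torch.Tensor, std: torch.Tensor):
        super().__init__()
        # Registering as buffers makes them follow `module.to(...)`.
        self.register_buffer("_mean", mean)
        self.register_buffer("_std", std)

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        return (tensor - self._mean) / self._std


standardize = Standardize(torch.zeros(3), torch.ones(3))
# `.to()` converts buffers along with parameters; with a GPU available,
# `standardize.to("cuda")` would move `_mean`/`_std` the same way.
standardize.to(torch.float64)
```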

Hope it helps,
Ali


Update: the example above did not use TSNPE. With TSNPE I now get a similar error when calling:
accept_reject_fn = get_density_thresholder(posterior, quantile=1e-4, num_samples_to_estimate_support=10_000)
I will try to troubleshoot and share my understanding here if it helps. Thanks again.
