Skip to content

Commit

Permalink
add control over the number of torch threads
Browse files Browse the repository at this point in the history
  • Loading branch information
minaskar committed Apr 24, 2024
1 parent 935ba54 commit 0783f26
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 0 deletions.
8 changes: 8 additions & 0 deletions pocomc/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .flow import Flow
from .particles import Particles
from .geometry import Geometry
from .threading import configure_threads

class Sampler:
r"""Preconditioned Monte Carlo class.
Expand Down Expand Up @@ -39,6 +40,9 @@ class Sampler:
pool : pool
Provided ``MPI`` or ``multiprocessing`` pool for
parallelisation (default is ``pool=None``).
pytorch_threads : int
Maximum number of threads to use for torch. If ``None`` torch uses all
available threads while training the normalizing flow (default is ``pytorch_threads=1``).
flow : ``torch.nn.Module`` or ``None``
Normalizing flow (default is ``None``). The default is a Masked Autoregressive Flow
(MAF) with 6 blocks of 3x64 layers and residual connections.
Expand Down Expand Up @@ -87,6 +91,7 @@ def __init__(self,
likelihood_kwargs: dict = None,
vectorize: bool = False,
pool=None,
pytorch_threads=1,
flow=None,
train_config: dict = None,
precondition: bool = True,
Expand All @@ -105,6 +110,9 @@ def __init__(self,
torch.manual_seed(random_state)
self.random_state = random_state

# Configure threads
configure_threads(pytorch_threads=pytorch_threads)

# Prior
self.prior = prior
self.log_prior = self.prior.logpdf
Expand Down
21 changes: 21 additions & 0 deletions pocomc/threading.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import torch

def configure_threads(pytorch_threads=None):
"""Configure the number of threads available.
This is necessary when using PyTorch on the CPU as by default it will use
all available threads.
Notes
-----
Uses ``torch.set_num_threads``. If pytorch threads is None but other
arguments are specified then the value is inferred from them.
Parameters
----------
pytorch_threads: int, optional
Maximum number of threads for PyTorch on CPU. If None, pytorch will
use all available threads.
"""
if pytorch_threads:
torch.set_num_threads(pytorch_threads)

0 comments on commit 0783f26

Please sign in to comment.