Fix tqdm too early report for parallel pool (#1053)
* Fix tqdm too early report for parallel pool

* fix typo

* simplifying parallelization warnings
yannikschaelte authored Apr 29, 2023
1 parent efdae5b commit b0ad7f1
Showing 6 changed files with 41 additions and 33 deletions.
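
The gist of the fix: previously the progress bar wrapped the input iterable handed to the pool, so it advanced as tasks were dispatched (essentially immediately) rather than as results came back; the patch wraps the pool's lazy result iterator instead. A minimal standalone sketch of the two patterns (plain multiprocessing and tqdm, not pypesto code; names are illustrative):

# Standalone sketch (plain multiprocessing + tqdm, not pypesto code) of why the
# bar reported too early and how the fix changes that.
import multiprocessing
import time

from tqdm import tqdm


def work(x):
    time.sleep(0.1)  # stand-in for an expensive task
    return x * x


if __name__ == "__main__":
    tasks = list(range(20))

    with multiprocessing.Pool(processes=4) as pool:
        # Old pattern: tqdm wraps the *input* iterable. pool.map() consumes it
        # while dispatching, so the bar hits 100% long before any result exists.
        results_old = pool.map(work, tqdm(tasks, desc="dispatch (too early)"))

        # New pattern: tqdm wraps the lazy result iterator of pool.imap(), so
        # the bar advances only as results actually come back.
        results_new = list(
            tqdm(pool.imap(work, tasks), total=len(tasks), desc="results")
        )
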
6 changes: 4 additions & 2 deletions pypesto/engine/base.py
@@ -1,6 +1,6 @@
 """Abstract engine base class."""
 import abc
-from typing import List
+from typing import Any, List
 
 from .task import Task
 
@@ -12,7 +12,9 @@ def __init__(self):
         pass
 
     @abc.abstractmethod
-    def execute(self, tasks: List[Task], progress_bar: bool = True):
+    def execute(
+        self, tasks: List[Task], progress_bar: bool = True
+    ) -> List[Any]:
         """Execute tasks.
 
         Parameters
11 changes: 5 additions & 6 deletions pypesto/engine/mpi_pool.py
@@ -1,6 +1,6 @@
 """Engines with multi-node parallelization."""
 import logging
-from typing import List
+from typing import Any, List
 
 import cloudpickle as pickle
 from mpi4py import MPI
@@ -31,7 +31,9 @@ class MPIPoolEngine(Engine):
     def __init__(self):
         super().__init__()
 
-    def execute(self, tasks: List[Task], progress_bar: bool = True):
+    def execute(
+        self, tasks: List[Task], progress_bar: bool = True
+    ) -> List[Any]:
         """
         Pickle tasks and distribute work to workers.
 
@@ -45,10 +47,7 @@ def execute(self, tasks: List[Task], progress_bar: bool = True):
         pickled_tasks = [pickle.dumps(task) for task in tasks]
 
         n_procs = MPI.COMM_WORLD.Get_size()  # Size of communicator
-        logger.info(
-            f"Performing parallel task execution on {n_procs-1} "
-            f"workers with one manager."
-        )
+        logger.info(f"Parallelizing on {n_procs-1} workers with one manager.")
 
         with MPIPoolExecutor() as executor:
             results = executor.map(
24 changes: 13 additions & 11 deletions pypesto/engine/multi_process.py
@@ -2,7 +2,7 @@
 import logging
 import multiprocessing
 import os
-from typing import List
+from typing import Any, List
 
 import cloudpickle as pickle
 from tqdm import tqdm
@@ -41,15 +41,15 @@ def __init__(self, n_procs: int = None, method: str = None):
 
         if n_procs is None:
            n_procs = os.cpu_count()
-            logger.warning(
-                f"Engine set up to use up to {n_procs} processes in total. "
-                f"The number was automatically determined and might not be "
-                f"appropriate on some systems."
+            logger.info(
+                f"Engine will use up to {n_procs} processes (= CPU count)."
             )
         self.n_procs: int = n_procs
         self.method: str = method
 
-    def execute(self, tasks: List[Task], progress_bar: bool = True):
+    def execute(
+        self, tasks: List[Task], progress_bar: bool = True
+    ) -> List[Any]:
         """Pickle tasks and distribute work over parallel processes.
 
         Parameters
@@ -64,15 +64,17 @@ def execute(self, tasks: List[Task], progress_bar: bool = True):
         pickled_tasks = [pickle.dumps(task) for task in tasks]
 
         n_procs = min(self.n_procs, n_tasks)
-        logger.info(
-            f"Performing parallel task execution on {n_procs} " f"processes."
-        )
+        logger.debug(f"Parallelizing on {n_procs} processes.")
 
         ctx = multiprocessing.get_context(method=self.method)
 
         with ctx.Pool(processes=n_procs) as pool:
-            results = pool.map(
-                work, tqdm(pickled_tasks, disable=not progress_bar)
+            results = list(
+                tqdm(
+                    pool.imap(work, pickled_tasks),
+                    total=len(pickled_tasks),
+                    disable=not progress_bar,
+                ),
             )
 
         return results
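
A note on the new multiprocessing pattern: pool.imap returns a lazy iterator without a length, so total=len(pickled_tasks) is passed explicitly, and the results are materialized with list(...) inside the with block, because Pool.__exit__ calls terminate(). A standalone sketch (plain multiprocessing, not pypesto code):

# Standalone sketch (plain multiprocessing, not pypesto code).
import multiprocessing


def work(x):
    return x * x


if __name__ == "__main__":
    items = list(range(8))
    with multiprocessing.Pool(processes=2) as pool:
        lazy = pool.imap(work, items)  # lazy and ordered, has no len()
        results = list(lazy)  # drain here; Pool.__exit__ calls terminate()
    print(results)  # [0, 1, 4, 9, 16, 25, 36, 49]
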
24 changes: 13 additions & 11 deletions pypesto/engine/multi_thread.py
@@ -3,7 +3,7 @@
 import logging
 import os
 from concurrent.futures import ThreadPoolExecutor
-from typing import List
+from typing import Any, List
 
 from tqdm import tqdm
 
@@ -37,14 +37,14 @@ def __init__(self, n_threads: int = None):
 
         if n_threads is None:
             n_threads = os.cpu_count()
-            logger.warning(
-                f"Engine set up to use up to {n_threads} processes in total. "
-                f"The number was automatically determined and might not be "
-                f"appropriate on some systems."
+            logger.info(
+                f"Engine will use up to {n_threads} threads (= CPU count)."
             )
         self.n_threads: int = n_threads
 
-    def execute(self, tasks: List[Task], progress_bar: bool = True):
+    def execute(
+        self, tasks: List[Task], progress_bar: bool = True
+    ) -> List[Any]:
         """Deepcopy tasks and distribute work over parallel threads.
 
         Parameters
@@ -59,13 +59,15 @@ def execute(self, tasks: List[Task], progress_bar: bool = True):
         copied_tasks = [copy.deepcopy(task) for task in tasks]
 
         n_threads = min(self.n_threads, n_tasks)
-        logger.info(
-            f"Performing parallel task execution on {n_threads} " f"threads."
-        )
+        logger.debug(f"Parallelizing on {n_threads} threads.")
 
         with ThreadPoolExecutor(max_workers=n_threads) as pool:
-            results = pool.map(
-                work, tqdm(copied_tasks, disable=not progress_bar)
+            results = list(
+                tqdm(
+                    pool.map(work, copied_tasks),
+                    total=len(copied_tasks),
+                    disable=not progress_bar,
+                ),
             )
 
         return results
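
For the thread engine no imap counterpart is needed: concurrent.futures.Executor.map submits everything up front and returns a lazy iterator that yields results in input order as they become available, so moving tqdm from the input list to that iterator is enough. Standalone sketch (not pypesto code):

# Standalone sketch (concurrent.futures + tqdm, not pypesto code).
import time
from concurrent.futures import ThreadPoolExecutor

from tqdm import tqdm


def work(x):
    time.sleep(0.1)  # stand-in for an expensive task
    return x + 1


items = list(range(20))
with ThreadPoolExecutor(max_workers=4) as pool:
    # pool.map() returns a lazy iterator without a length, hence total=...;
    # the bar now tracks returned results instead of submitted inputs.
    results = list(tqdm(pool.map(work, items), total=len(items)))
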
6 changes: 4 additions & 2 deletions pypesto/engine/single_core.py
@@ -1,5 +1,5 @@
 """Engines without parallelization."""
-from typing import List
+from typing import Any, List
 
 from tqdm import tqdm
 
@@ -17,7 +17,9 @@ class SingleCoreEngine(Engine):
     def __init__(self):
         super().__init__()
 
-    def execute(self, tasks: List[Task], progress_bar: bool = True):
+    def execute(
+        self, tasks: List[Task], progress_bar: bool = True
+    ) -> List[Any]:
         """Execute all tasks in a simple for loop sequentially.
 
         Parameters
3 changes: 2 additions & 1 deletion pypesto/engine/task.py
@@ -1,5 +1,6 @@
 """Abstract Task class."""
 import abc
+from typing import Any
 
 
 class Task(abc.ABC):
@@ -15,5 +16,5 @@ def __init__(self):
         pass
 
     @abc.abstractmethod
-    def execute(self):
+    def execute(self) -> Any:
         """Execute the task and return its results."""

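For context, a hypothetical end-to-end use of the changed interfaces: a Task subclass implementing the new execute(self) -> Any, run through an engine whose execute now returns List[Any] with the progress bar enabled. The task/engine interface follows this diff; the import path and the MultiProcessEngine class name are assumptions.

# Hypothetical usage; the Task/engine interface follows this diff, but the
# import path and the MultiProcessEngine class name are assumptions.
from typing import Any

from pypesto.engine import MultiProcessEngine, Task


class SquareTask(Task):
    """Toy task that squares a number."""

    def __init__(self, x: float):
        super().__init__()
        self.x = x

    def execute(self) -> Any:  # matches the new abstract signature
        return self.x**2


if __name__ == "__main__":  # guard needed for the multiprocessing spawn method
    engine = MultiProcessEngine(n_procs=2)
    # execute() now returns List[Any]; the bar advances as results arrive.
    results = engine.execute(
        [SquareTask(x) for x in range(8)], progress_bar=True
    )
    print(results)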