Skip to content

Commit

Permalink
Merge pull request #8309 from OpenMined/enhance/hb_898
Browse files Browse the repository at this point in the history
ADD Thread mode execution for jobs and subjobs.
  • Loading branch information
koenvanderveen authored Dec 12, 2023
2 parents a9359a6 + ebb59cd commit 8a12f6b
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 25 deletions.
5 changes: 5 additions & 0 deletions packages/hagrid/hagrid/orchestra.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ def deploy_to_python(
node_side_type: NodeSideType,
enable_warnings: bool,
n_consumers: int,
thread_workers: bool,
create_producer: bool = False,
queue_port: Optional[int] = None,
) -> Optional[NodeHandle]:
Expand Down Expand Up @@ -307,6 +308,7 @@ def deploy_to_python(
node_side_type=node_side_type,
enable_warnings=enable_warnings,
n_consumers=n_consumers,
thread_workers=thread_workers,
create_producer=create_producer,
queue_port=queue_port,
migrate=True,
Expand Down Expand Up @@ -490,11 +492,13 @@ def launch(
render: bool = False,
enable_warnings: bool = False,
n_consumers: int = 0,
thread_workers: bool = False,
create_producer: bool = False,
queue_port: Optional[int] = None,
) -> Optional[NodeHandle]:
if dev_mode is True:
os.environ["DEV_MODE"] = "True"
thread_workers = True

# syft 0.8.1
if node_type == "python":
Expand Down Expand Up @@ -535,6 +539,7 @@ def launch(
node_side_type=node_side_type_enum,
enable_warnings=enable_warnings,
n_consumers=n_consumers,
thread_workers=thread_workers,
create_producer=create_producer,
queue_port=queue_port,
)
Expand Down
5 changes: 3 additions & 2 deletions packages/syft/src/syft/node/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,7 @@ def named(
node_side_type: Union[str, NodeSideType] = NodeSideType.HIGH_SIDE,
enable_warnings: bool = False,
n_consumers: int = 0,
thread_workers: bool = False,
create_producer: bool = False,
queue_port: Optional[int] = None,
dev_mode: bool = False,
Expand Down Expand Up @@ -502,11 +503,11 @@ def named(
create_producer=create_producer,
queue_port=queue_port,
n_consumers=n_consumers,
)
),
thread_workers=thread_workers,
)
else:
queue_config = None

return cls(
name=name,
id=uid,
Expand Down
11 changes: 7 additions & 4 deletions packages/syft/src/syft/service/job/job_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,14 +124,17 @@ def kill(
return SyftError(message=res.err())

job = res.ok()
if job.job_pid is not None:
if job.job_pid is not None and job.status == JobStatus.PROCESSING:
job.status = JobStatus.INTERRUPTED
res = self.stash.update(context.credentials, obj=job)
if res.is_err():
return SyftError(message=res.err())

res = res.ok()
return SyftSuccess(message="Great Success!")
return SyftSuccess(message="Job killed successfully!")
else:
return SyftError(
message="Job is not running or isn't running in multiprocessing mode."
"Killing threads is currently not supported"
)

@service_method(
path="job.get_subjobs",
Expand Down
6 changes: 5 additions & 1 deletion packages/syft/src/syft/service/job/job_stash.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def restart(self, kill=False) -> None:
"Job is running or scheduled, if you want to kill it use job.kill() first"
)

def kill(self) -> None:
def kill(self) -> Union[None, SyftError]:
if self.job_pid is not None:
api = APIRegistry.api_for(
node_uid=self.node_uid,
Expand All @@ -215,6 +215,10 @@ def kill(self) -> None:
blocking=True,
)
api.make_call(call)
else:
return SyftError(
message="Job is not running or isn't running in multiprocessing mode."
)

def fetch(self) -> None:
api = APIRegistry.api_for(
Expand Down
35 changes: 18 additions & 17 deletions packages/syft/src/syft/service/queue/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,22 +250,23 @@ def handle_message(message: bytes):
if isinstance(worker_result, SyftError):
raise Exception(message=f"{worker_result.err()}")

# from threading import Thread
# p = Thread(
# target=handle_message_multiprocessing,
# args=(worker_settings, queue_item, credentials),
# )
# p.start()

# handle_message_multiprocessing(worker_settings, queue_item, credentials)

p = multiprocessing.Process(
target=handle_message_multiprocessing,
args=(worker_settings, queue_item, credentials),
)
p.start()

job_item.job_pid = p.pid
worker.job_stash.set_result(credentials, job_item)
if queue_config.thread_workers:
# stdlib
from threading import Thread

p = Thread(
target=handle_message_multiprocessing,
args=(worker_settings, queue_item, credentials),
)
p.start()
else:
p = multiprocessing.Process(
target=handle_message_multiprocessing,
args=(worker_settings, queue_item, credentials),
)
p.start()

job_item.job_pid = p.pid
worker.job_stash.set_result(credentials, job_item)

p.join()
5 changes: 4 additions & 1 deletion packages/syft/src/syft/service/queue/zmq_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,9 @@ def purge_all(self) -> Union[SyftError, SyftSuccess]:

@serializable()
class ZMQQueueConfig(QueueConfig):
    """Configuration for the ZMQ-backed task queue.

    Holds the client implementation/config used to talk to the queue and
    the worker execution mode.

    Args:
        client_type: Queue client class; defaults to ``ZMQClient``.
        client_config: Client configuration; defaults to a fresh
            ``ZMQClientConfig``.
        thread_workers: When True, consumers execute queued messages in a
            ``threading.Thread`` instead of a ``multiprocessing.Process``
            (see the ``thread_workers`` branch in ``queue.py``).
    """

    def __init__(
        self, client_type=None, client_config=None, thread_workers: bool = False
    ):
        # Fall back to the ZMQ defaults when the caller supplies nothing.
        self.client_type = client_type or ZMQClient
        self.client_config: ZMQClientConfig = client_config or ZMQClientConfig()
        self.thread_workers = thread_workers

0 comments on commit 8a12f6b

Please sign in to comment.