Skip to content

Commit

Permalink
Merge pull request #8419 from OpenMined/fix-nested-job-fails
Browse files — browse the repository at this point in the history
Fix bug in jobs and launching of workers
  • Loading branch information
shubham3121 authored Jan 25, 2024
2 parents 895c0c4 + f823d8c commit 2feece6
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 18 deletions.
3 changes: 3 additions & 0 deletions packages/hagrid/hagrid/orchestra.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,9 @@ def deploy_to_python(
port=port,
reset=reset,
processes=processes,
queue_port=queue_port,
n_consumers=n_consumers,
create_producer=create_producer,
dev_mode=dev_mode,
tail=tail,
node_type=node_type_enum,
Expand Down
22 changes: 17 additions & 5 deletions packages/syft/src/syft/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,24 +529,36 @@ def create_project(
return project

def sync_code_from_request(self, request):
code = request.code
# relative
from ..service.code.user_code import UserCode
from ..store.linked_obj import LinkedObject

code: Union[UserCode, SyftError] = request.code
if isinstance(code, SyftError):
return code
elif code is None:
return SyftError(message="no code inside request")

code = deepcopy(code)
code.node_uid = self.id
code.user_verify_key = self.verify_key

def get_nested_codes(code):
def get_nested_codes(code: UserCode):
# Recursively collect deep copies of every UserCode reachable through
# code.nested_codes, retargeting each copy to this client's node/key,
# and rewrite the parent's nested_codes links to point at the copies.
# NOTE(review): indentation was lost in this diff scrape — presumably the
# `updated_code_links` block sits after the for-loop (the blank line
# suggests so); confirm against the original file before relying on it.
result = []
for __, (linked_code_obj, _) in code.nested_codes.items():
# Resolve the linked object, then deep-copy so the original stored
# code is not mutated when we re-point node_uid / user_verify_key.
nested_code = linked_code_obj.resolve
nested_code = deepcopy(nested_code)
# Inherit the parent's node binding and verify key (parent was already
# retargeted to this client before this helper is called).
nested_code.node_uid = code.node_uid
nested_code.user_verify_key = code.user_verify_key
result.append(nested_code)
# Recurse: grandchildren (and deeper) are flattened into result too.
result += get_nested_codes(nested_code)

# Rebuild the parent's links so they reference the retargeted copies
# rather than the originals; keyed by service_func_name, with an empty
# dict in the tuple's second slot (purpose not visible here — TODO confirm).
updated_code_links = {
nested_code.service_func_name: (LinkedObject.from_obj(nested_code), {})
for nested_code in result
}
# Side effect: mutates the passed-in code object's nested_codes mapping.
code.nested_codes = updated_code_links
return result

nested_codes = get_nested_codes(request.code)
nested_codes = get_nested_codes(code)

for c in nested_codes + [code]:
res = self.code.submit(c)
Expand Down
1 change: 1 addition & 0 deletions packages/syft/src/syft/node/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,7 @@ def named(
)
else:
queue_config = None

return cls(
name=name,
id=uid,
Expand Down
14 changes: 14 additions & 0 deletions packages/syft/src/syft/node/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import time
from typing import Callable
from typing import List
from typing import Optional
from typing import Tuple

# third party
Expand Down Expand Up @@ -77,6 +78,9 @@ def run_uvicorn(
node_side_type: str,
enable_warnings: bool,
in_memory_workers: bool,
queue_port: Optional[int],
create_producer: bool,
n_consumers: int,
):
async def _run_uvicorn(
name: str,
Expand Down Expand Up @@ -106,6 +110,9 @@ async def _run_uvicorn(
enable_warnings=enable_warnings,
migrate=True,
in_memory_workers=in_memory_workers,
queue_port=queue_port,
create_producer=create_producer,
n_consumers=n_consumers,
)
else:
worker = worker_class(
Expand All @@ -117,6 +124,7 @@ async def _run_uvicorn(
enable_warnings=enable_warnings,
migrate=True,
in_memory_workers=in_memory_workers,
queue_port=queue_port,
)
router = make_routes(worker=worker)
app = make_app(worker.name, router=router)
Expand Down Expand Up @@ -172,6 +180,9 @@ def serve_node(
tail: bool = False,
enable_warnings: bool = False,
in_memory_workers: bool = True,
queue_port: Optional[int] = None,
create_producer: bool = False,
n_consumers: int = 0,
) -> Tuple[Callable, Callable]:
server_process = multiprocessing.Process(
target=run_uvicorn,
Expand All @@ -186,6 +197,9 @@ def serve_node(
node_side_type,
enable_warnings,
in_memory_workers,
queue_port,
create_producer,
n_consumers,
),
)

Expand Down
41 changes: 29 additions & 12 deletions packages/syft/src/syft/service/queue/zmq_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
from .queue_stash import Status

# Producer/Consumer heartbeat interval (in seconds)
HEARTBEAT_INTERVAL_SEC = 5
HEARTBEAT_INTERVAL_SEC = 2

# Thread join timeout (in seconds)
THREAD_TIMEOUT_SEC = 5
Expand Down Expand Up @@ -246,7 +246,16 @@ def unwrap_nested_actionobjects(self, data):
key: self.unwrap_nested_actionobjects(obj) for key, obj in data.items()
}
if isinstance(data, ActionObject):
return data.get()
res = self.action_service.get(self.auth_context, data.id)
res = res.ok() if res.is_ok() else res.err()
if not isinstance(res, ActionObject):
return SyftError(message=f"{res}")
else:
nested_res = res.syft_action_data
if isinstance(nested_res, ActionObject):
nested_res.syft_node_location = res.syft_node_location
nested_res.syft_client_verify_key = res.syft_client_verify_key
return nested_res
return data

def preprocess_action_arg(self, arg):
Expand Down Expand Up @@ -378,16 +387,21 @@ def update_consumer_state_for_worker(
)
return

res = self.worker_stash.update_consumer_state(
credentials=self.worker_stash.partition.root_verify_key,
worker_uid=syft_worker_id,
consumer_state=consumer_state,
)
if res.is_err():
try:
res = self.worker_stash.update_consumer_state(
credentials=self.worker_stash.partition.root_verify_key,
worker_uid=syft_worker_id,
consumer_state=consumer_state,
)
if res.is_err():
logger.error(
"Failed to update consumer state for worker id={} error={}",
syft_worker_id,
res.err(),
)
except Exception as e:
logger.error(
"Failed to update consumer state for worker id={} error={}",
syft_worker_id,
res.err(),
f"Failed to update consumer state for worker id: {syft_worker_id}. Error: {e}"
)

def worker_waiting(self, worker: Worker):
Expand Down Expand Up @@ -468,8 +482,8 @@ def _run(self):
else:
logger.error("Invalid message header: {}", header)

self.purge_workers()
self.send_heartbeats()
self.purge_workers()

def require_worker(self, address):
"""Finds the worker (creates if necessary)."""
Expand All @@ -492,6 +506,9 @@ def process_worker(self, address: bytes, msg: List[bytes]):
syft_worker_id = msg.pop(0).decode()
if worker_ready:
# Not first command in session or Reserved service name
# If worker was already present, then we disconnect it first
# and wait for it to re-register itself to the producer. This ensures that
# we always have a healthy worker in place that can talk to the producer.
self.delete_worker(worker, True)
else:
# Attach worker to service and mark as idle
Expand Down
2 changes: 1 addition & 1 deletion packages/syft/src/syft/service/worker/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def run_workers_in_threads(
)
except Exception as e:
print(
f"Failed to start consumer for Pool Name: {pool_name}, Worker Name: {worker_name}"
f"Failed to start consumer for Pool Name: {pool_name}, Worker Name: {worker_name}. Error: {e}"
)
worker.status = WorkerStatus.STOPPED
error = str(e)
Expand Down

0 comments on commit 2feece6

Please sign in to comment.