-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathlaunch_cluster.py
executable file
·94 lines (73 loc) · 2.72 KB
/
launch_cluster.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
#! /usr/bin/env python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import configs.distributed
import subprocess
import time
def _launch_tasks(script, tasks, FLAGS, CHILD_FLAGS, launched):
    """Start one ``python <script>`` child process per entry in *tasks*.

    Every child is appended to *launched* immediately after it is started,
    so the caller can still terminate the ones that exist if a later launch
    raises or is interrupted.

    Args:
        script: File name of the Python script to run for each task.
        tasks: Iterable with one entry per task; entries are only printed.
        FLAGS: Namespace providing ``ps_tasks``, ``worker_tasks`` and
            ``launch_delay``.
        CHILD_FLAGS: List of extra command line arguments appended verbatim
            to every child command line.
        launched: List that receives the ``subprocess.Popen`` objects.
    """
    for task_index, task in enumerate(tasks):
        print(task_index, task)
        # NOTE(review): 'python' is resolved through PATH, so it may not be
        # the interpreter that is running this launcher -- confirm intended.
        child = subprocess.Popen(
            [
                'python', script,
                '--task_index', '{}'.format(task_index),
                '--ps_tasks', '{}'.format(FLAGS.ps_tasks),
                '--worker_tasks', '{}'.format(FLAGS.worker_tasks),
            ] + CHILD_FLAGS
        )
        launched.append(child)
        time.sleep(FLAGS.launch_delay)


def main(FLAGS, CHILD_FLAGS):
    """Launch all ps and worker tasks locally and wait for the workers.

    One ``parameter_server.py`` process is started per entry of
    ``cluster['ps']`` and one ``worker.py`` process per entry of
    ``cluster['worker']``. The worker processes are then polled every
    ``FLAGS.poll_delay`` seconds until none of them is still running. On
    exit (normal completion, CTRL-C, or an error while launching) every
    process that is still alive is sent SIGTERM and all processes are
    waited for.

    Args:
        FLAGS: Parsed launcher arguments. Must provide ``ps_tasks``,
            ``worker_tasks``, ``launch_delay`` and ``poll_delay``, plus
            whatever ``configs.distributed.get_cluster`` reads.
        CHILD_FLAGS: List of arguments forwarded unchanged to every child.

    Returns:
        0, both when all workers finished and when interrupted by CTRL-C.
    """
    cluster = configs.distributed.get_cluster(FLAGS)
    processes = {'ps': [], 'worker': []}

    try:
        # Launching happens inside the try block so that a failure or a
        # CTRL-C half-way through still tears down the children that were
        # already started instead of leaving them running.
        _launch_tasks(
            'parameter_server.py', cluster['ps'],
            FLAGS, CHILD_FLAGS, processes['ps'],
        )
        _launch_tasks(
            'worker.py', cluster['worker'],
            FLAGS, CHILD_FLAGS, processes['worker'],
        )

        while True:
            # Build a real list: on Python 3 ``filter`` returns a lazy
            # iterator, which supports neither ``len()`` nor a meaningful
            # truth test.
            live_workers = [
                p for p in processes['worker'] if p.poll() is None
            ]
            # Quit if all workers are finished
            if not live_workers:
                break
            time.sleep(FLAGS.poll_delay)
        print('All workers finished, cleaning up and exiting.')
    except KeyboardInterrupt as e:
        # Catch CTRL-C
        print('Caught {} cleaning up and exiting.'.format(type(e).__name__))
    finally:
        # Send SIGTERM to everything that is still running. The ps tasks
        # are never polled above, so this is what stops them.
        for task in ['ps', 'worker']:
            for process in processes[task]:
                if process.poll() is None:
                    process.terminate()
        # Wait for exit
        for task in ['ps', 'worker']:
            for process in processes[task]:
                process.wait()
    return 0
if __name__ == '__main__':
    import argparse
    import sys

    # Command line of the launcher itself. Arguments that are not
    # recognised here are collected in CHILD_FLAGS and forwarded to every
    # launched task.
    parser = argparse.ArgumentParser()
    # NOTE(review): presumably registers the cluster options (e.g.
    # --ps_tasks / --worker_tasks) that main() reads -- confirm in
    # configs.distributed.
    parser = configs.distributed.add_argparse_args(parser)
    parser.add_argument(
        '--launch_delay', type=float, default=0.0,
        help='Wait time between launching tasks in seconds.'
    )
    parser.add_argument(
        '--poll_delay', type=float, default=60.0,
        help='Wait time between checking tasks in seconds.'
    )
    FLAGS, CHILD_FLAGS = parser.parse_known_args()
    # sys.exit instead of the builtin ``exit``: the latter is injected by
    # the ``site`` module for interactive use and is not guaranteed to
    # exist (e.g. when running with ``python -S``).
    sys.exit(main(FLAGS, CHILD_FLAGS))