"""
Distributed Tensorflow example
The original code was in @ischlag, but the distributed architecture is quite
different.
The code runs on TF 1.1.
Trains a simple sigmoid neural network on mnist for 20 epochs on three machines using one parameter server.
The code requires 'tmux'.
The code runs on the local server only.
Run like this:
$ bash run.sh
Then, by using ctrl+b+(window number, e.g., 0, 1, 2),
you can change the terminal.
"""
from __future__ import print_function
import signal
import sys
import time
import numpy as np
import tensorflow as tf
from worker import Worker
from utils import cluster_spec

flags = tf.app.flags
flags.DEFINE_string('job_name', 'ps', "Either 'ps' or 'worker'")
flags.DEFINE_integer('task_index', 0, "Index of task within the job")
flags.DEFINE_integer('batch_size', 100, "Batch size")
flags.DEFINE_float('learning_rate', 0.001, "Learning rate")
flags.DEFINE_integer('training_steps', 10**7,
                     "Training steps (1 step = 1 batch update)")
flags.DEFINE_string('logdir', './tmp/mnist/1', "Log directory")
flags.DEFINE_integer('num_workers', 2, "Number of workers")
flags.DEFINE_integer('num_gpus', 1,
                     "Number of GPUs, less than or equal to num_workers")
FLAGS = flags.FLAGS


def main():
    # Load the MNIST dataset.
    from tensorflow.examples.tutorials.mnist import input_data
    mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

    # Cluster specification.
    spec = cluster_spec(FLAGS.num_workers, 1)
    cluster = tf.train.ClusterSpec(spec)
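    # cluster_spec comes from utils.py (not shown here); judging by its use,
    # it presumably returns a dict of local host:port addresses, e.g. for
    # num_workers=2 and one ps, something like:
    #   {'ps': ['localhost:12222'],
    #    'worker': ['localhost:12223', 'localhost:12224']}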
    # Signal handlers: exit with the conventional 128+N status on
    # HUP/INT/TERM.
    def shutdown(signum, frame):
        sys.exit(128 + signum)
    signal.signal(signal.SIGHUP, shutdown)
    signal.signal(signal.SIGINT, shutdown)
    signal.signal(signal.SIGTERM, shutdown)
    # Split GPU memory among the worker processes sharing each GPU.
    processes_per_gpu = \
        np.ceil(float(FLAGS.num_workers) / float(FLAGS.num_gpus))
    fraction = 0.9 / processes_per_gpu
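    # With the defaults (num_workers=2, num_gpus=1) this gives
    # ceil(2 / 1) = 2 processes per GPU, so each worker claims
    # 0.9 / 2 = 0.45 of the GPU's memory.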
    print('-' * 100)
    print("Per-process GPU memory fraction: {}".format(fraction))
    print('-' * 100)
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=fraction)
    if FLAGS.job_name == 'ps':
        server = tf.train.Server(cluster, job_name='ps',
                                 task_index=FLAGS.task_index)
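        # The ps only hosts variables, so block forever. Sleeping in a loop
        # rather than calling server.join() presumably keeps the signal
        # handlers above responsive: join() blocks inside native code and
        # never returns, so the Python handlers would never get to run.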
        while True:
            time.sleep(1000)
    elif FLAGS.job_name == 'worker':
        config = tf.ConfigProto(gpu_options=gpu_options)
        server = tf.train.Server(cluster, job_name='worker',
                                 task_index=FLAGS.task_index, config=config)
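        # Worker is defined in worker.py (not shown here); presumably its
        # learn() method builds the replicated graph against this server
        # and runs the training loop over the MNIST data.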
        worker = Worker(FLAGS.job_name, FLAGS.task_index, server)
        worker.learn(mnist)


if __name__ == '__main__':
    main()