"""
An MNIST training example using a neural network with a single hidden
layer followed by a softmax classifier.
To run this file,
1. please check '172.16.3.227:~/tensorflow/scripts'
   and execute 'exec_mnist_distributed.sh'.
2. execute 'python mnist_distributed.py \
            --job_name=worker \
            --task_index=${0|1|2}' on each worker
   and 'python mnist_distributed.py \
        --job_name=ps \
        --task_index=0' on parameter server
See the TensorFlow style guide at
https://www.tensorflow.org/versions/r0.10/how_tos/style_guide.html.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import sys
import tempfile
import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# Command-line flags for this script. `flags` is the TensorFlow flags module;
# `FLAGS` (bound at the end of this section) is the object main() reads.
flags = tf.app.flags

# --- Filesystem locations -------------------------------------------------
flags.DEFINE_string("data_dir", "/root/tensorflow/MNIST_data",
                    "Directory for storing mnist data")
flags.DEFINE_string("log_dir", "/root/tensorflow/logs/mnist_log",
                    "Directory for storing log")
flags.DEFINE_boolean("download_only", False,
                     "Only perform downloading of data")

# --- Role of this process in the cluster ----------------------------------
# Both default to None; main() refuses to run when they are missing.
flags.DEFINE_string("job_name", None, "job name: worker or ps")
flags.DEFINE_integer("task_index", None,
                     "Worker task index, should be >= 0.")

# --- Model and training hyper-parameters ----------------------------------
flags.DEFINE_integer("hidden_units", 100,
                     "Number of units in the hidden layer of the NN")
flags.DEFINE_integer("training_steps", 20000,
                     "Number of (global) training steps to perform")
flags.DEFINE_integer("batch_size", 100,
                     "Training batch size to be fetched each time")
flags.DEFINE_float("learning_rate", 0.01,
                   "Learning rate in machine learning")

# --- Cluster layout -------------------------------------------------------
# The defaults describe one parameter server and three workers.
flags.DEFINE_string("ps_hosts", "172.16.3.230:2222",
                    "Comma-separated list of hostname:port pairs")
flags.DEFINE_string(
    "worker_hosts",
    "172.16.3.227:2222,172.16.3.228:2222,172.16.3.229:2222",
    "Comma-separated list of hostname:port pairs")

FLAGS = flags.FLAGS

# Side length of an input image in pixels; main() flattens each image to a
# vector of IMAGE_PIXELS * IMAGE_PIXELS values.
IMAGE_PIXELS = 28
def main(_):
  """Run one task (parameter server or worker) of the distributed MNIST job.

  The role is selected by the `job_name` and `task_index` flags:

  * "ps": start the in-process server and block in `server.join()`.
  * "worker": build the model (one ReLU hidden layer followed by a softmax
    output), train it with Adam until the shared global step reaches
    `training_steps` or the supervisor asks to stop, then print the
    validation cross entropy and the test accuracy.

  If `download_only` is set, the MNIST data is fetched into `data_dir` and
  the function returns without joining the cluster.

  Args:
    _: Unused positional argument supplied by `tf.app.run()`.

  Raises:
    ValueError: If `job_name` is missing or is not "ps"/"worker", or if
      `task_index` is missing or negative.
  """
  # Downloading the data does not need a cluster role, so handle it before
  # validating job_name/task_index.
  if FLAGS.download_only:
    input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
    return

  if not FLAGS.job_name:
    raise ValueError("Must specify an explicit `job_name`")
  if FLAGS.job_name not in ("ps", "worker"):
    raise ValueError(
        "`job_name` must be 'ps' or 'worker', got %r" % FLAGS.job_name)
  if FLAGS.task_index is None:
    raise ValueError("Must specify an explicit `task_index`")
  if FLAGS.task_index < 0:
    raise ValueError(
        "`task_index` should be >= 0, got %d" % FLAGS.task_index)
  print("job name = %s" % FLAGS.job_name)
  print("task index = %d" % FLAGS.task_index)

  # Describe the cluster and start this task's in-process server.
  ps_spec = FLAGS.ps_hosts.split(",")
  worker_spec = FLAGS.worker_hosts.split(",")
  cluster = tf.train.ClusterSpec({"ps": ps_spec, "worker": worker_spec})
  server = tf.train.Server(
      cluster,
      job_name=FLAGS.job_name,
      task_index=FLAGS.task_index)

  if FLAGS.job_name == "ps":
    # A parameter server only serves requests; join() blocks here.
    server.join()
    return

  # ---- Worker: build the graph -------------------------------------------
  # Device placement is delegated to replica_device_setter, with this
  # worker's own task as the worker device.
  with tf.device(tf.train.replica_device_setter(
      worker_device="/job:worker/task:%d" % FLAGS.task_index,
      cluster=cluster)):
    # Hidden layer parameters.
    hid_w = tf.Variable(tf.truncated_normal(
        [IMAGE_PIXELS * IMAGE_PIXELS, FLAGS.hidden_units],
        stddev=1.0 / IMAGE_PIXELS), name="hid_w")
    hid_b = tf.Variable(tf.zeros([FLAGS.hidden_units]), name="hid_b")

    # Softmax (output) layer parameters, 10 classes.
    sm_w = tf.Variable(tf.truncated_normal(
        [FLAGS.hidden_units, 10],
        stddev=1.0 / math.sqrt(FLAGS.hidden_units)), name="sm_w")
    sm_b = tf.Variable(tf.zeros([10]), name="sm_b")

    # Inputs: flattened images and one-hot labels.
    x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS])
    y_ = tf.placeholder(tf.float32, [None, 10])

    hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b)
    hid = tf.nn.relu(hid_lin)

    y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b))
    # Clip the probabilities so that log() never sees 0.
    cross_entropy = -tf.reduce_sum(
        y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))

    # Step counter passed to minimize() so every applied update bumps it.
    global_step = tf.Variable(0, name="global_step", trainable=False)

    train_step = tf.train.AdamOptimizer(FLAGS.learning_rate).minimize(
        cross_entropy, global_step=global_step)

    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    init_op = tf.initialize_all_variables()
    summary_op = tf.merge_all_summaries()

  # Task 0 acts as the chief.
  # NOTE(review): saver=None is passed, so per the Supervisor docs no
  # automatic checkpointing is expected and save_model_secs presumably has
  # no effect -- confirm this is intended.
  sv = tf.train.Supervisor(
      is_chief=(FLAGS.task_index == 0),
      logdir=FLAGS.log_dir,
      init_op=init_op,
      summary_op=summary_op,
      saver=None,
      global_step=global_step,
      save_model_secs=600)

  mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)

  # ---- Worker: train and evaluate ----------------------------------------
  with sv.managed_session(server.target) as sess:
    time_begin = time.time()
    local_step = 0  # steps performed by this worker
    step = 0        # last observed value of the shared global step
    while not sv.should_stop():
      batch_xs, batch_ys = mnist.train.next_batch(FLAGS.batch_size)
      train_feed = {x: batch_xs, y_: batch_ys}

      _, step = sess.run([train_step, global_step], feed_dict=train_feed)
      local_step += 1
      now = time.time()
      print("%f: Worker %d: training step %d done (global step: %d)"
            % (now, FLAGS.task_index, local_step, step))
      if step >= FLAGS.training_steps:
        break
    time_end = time.time()
    training_time = time_end - time_begin
    print("Training elapsed time: %f s" % training_time)

    # Report the global step actually reached: the loop can end early (the
    # supervisor asked to stop) or slightly past the configured limit.
    val_feed = {x: mnist.validation.images, y_: mnist.validation.labels}
    val_xent = sess.run(cross_entropy, feed_dict=val_feed)
    print("After %d training step(s), validation cross entropy = %g"
          % (step, val_xent))

    pred_feed = {x: mnist.test.images, y_: mnist.test.labels}
    print("Accuracy is: %f" % sess.run(accuracy, feed_dict=pred_feed))
    
# Script entry point. Control is handed to tf.app.run(), which (per the
# TensorFlow app module) is expected to parse the command-line flags and
# then call main().
if __name__ == "__main__":
  tf.app.run()