-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgan_fashion_mnist.py
112 lines (99 loc) · 5.04 KB
/
gan_fashion_mnist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import tensorflow as tf
import numpy as np
import os
import argparse
from tensorflow.examples.tutorials.mnist import input_data
# Directory where training checkpoints are written by train() and restored by generate().
save_path = "./save_model/fashion_mnist/"
# F_MNIST image dimension: (28, 28)
class Machine:
    """Conditional GAN (TF1 graph-mode) for 28x28 Fashion-MNIST images.

    Builds generator/discriminator MLPs, the non-saturating GAN losses,
    and one Adam train op per sub-network.
    """

    def __init__(self, args):
        self.image_dimension = args.image_dimension  # flattened image size (28*28)
        self.n_class = args.n_class                  # one-hot label width
        self.n_noise = args.n_noise                  # latent noise width
        # BUG FIX: --learning_rate was parsed by the CLI but never used; the
        # optimizers hard-coded 0.0001. getattr keeps old behavior for args
        # objects that lack the attribute.
        self.learning_rate = getattr(args, "learning_rate", 0.0001)
        self.network()

    def network(self):
        """Construct placeholders, both sub-networks, losses, and train ops."""
        self.Y = tf.placeholder(tf.float32, [None, self.n_class])
        self.G_input = tf.placeholder(tf.float32, [None, self.n_noise])
        self.generated_image = self.generator(self.G_input, self.Y)
        self.D_g = self.discriminator(self.generated_image, self.Y)
        self.real_image = tf.placeholder(tf.float32, [None, self.image_dimension])
        # reuse=True: score real images with the same discriminator weights.
        self.D_r = self.discriminator(self.real_image, self.Y, True)
        # tf.maximum clamps probabilities away from 0 so log() never yields -inf.
        self.loss_D = tf.reduce_mean(tf.log(tf.maximum(self.D_r, 1e-10))
                                     + tf.log(tf.maximum(1 - self.D_g, 1e-10)))
        self.loss_G = tf.reduce_mean(tf.log(tf.maximum(self.D_g, 1e-10)))
        D_weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='discriminator')
        G_weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='generator')
        # Losses are formulated as quantities to maximize, hence minimize(-loss).
        self.train_D = tf.train.AdamOptimizer(self.learning_rate).minimize(
            -self.loss_D, var_list=D_weights)
        self.train_G = tf.train.AdamOptimizer(self.learning_rate).minimize(
            -self.loss_G, var_list=G_weights)

    def generator(self, input, labels):
        """Map (noise, one-hot label) -> sigmoid image vector in [0, 1]."""
        with tf.variable_scope('generator'):
            inputs = tf.concat([input, labels], 1)  # (?, n_noise + n_class)
            hidden = tf.contrib.layers.fully_connected(inputs, 256, activation_fn=tf.nn.relu)
            output = tf.contrib.layers.fully_connected(
                hidden, self.image_dimension, activation_fn=tf.nn.sigmoid)
        return output

    def discriminator(self, input, labels, reuse=None):
        """Map (image, one-hot label) -> scalar real-probability in (0, 1)."""
        with tf.variable_scope('discriminator', reuse=reuse):
            inputs = tf.concat([input, labels], 1)
            hidden = tf.contrib.layers.fully_connected(inputs, 256, activation_fn=tf.nn.relu)
            output = tf.contrib.layers.fully_connected(hidden, 1, activation_fn=tf.nn.sigmoid)
        return output
def get_noise(batch_size, n_noise):
    """Draw a (batch_size, n_noise) matrix of standard-normal latent noise."""
    shape = (batch_size, n_noise)
    return np.random.normal(0.0, 1.0, shape)
def train(args):
    """Train the conditional GAN on Fashion-MNIST.

    Alternates one discriminator step and one generator step per mini-batch,
    prints per-epoch mean losses, and (when --save True) checkpoints at
    epoch 0 and every 50th epoch thereafter.
    """
    GAN = Machine(args)
    saver = tf.train.Saver(max_to_keep=None)
    mnist = input_data.read_data_sets("./fashion_mnist/data/", one_hot=True)
    total_batch = int(mnist.train.num_examples / args.batch_size)
    # FIX: use a context manager so the session is always closed (the
    # original leaked it).
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for epoch in range(args.total_epoch):
            avg_loss_D = 0
            avg_loss_G = 0
            for i in range(total_batch):
                batch_xs, batch_ys = mnist.train.next_batch(args.batch_size)
                noise = get_noise(args.batch_size, args.n_noise)
                # D sees reals + fakes; G is then updated on the same noise.
                _, loss_val_D = sess.run(
                    [GAN.train_D, GAN.loss_D],
                    feed_dict={GAN.real_image: batch_xs, GAN.G_input: noise, GAN.Y: batch_ys})
                _, loss_val_G = sess.run(
                    [GAN.train_G, GAN.loss_G],
                    feed_dict={GAN.G_input: noise, GAN.Y: batch_ys})
                avg_loss_D += loss_val_D
                avg_loss_G += loss_val_G
            print("Epoch: %3d | D_loss: %.8f | G_loss: %.8f"
                  % (epoch, avg_loss_D / total_batch, avg_loss_G / total_batch))
            if epoch == 0 or (epoch + 1) % 50 == 0:
                if args.save == "True":
                    # FIX: only create the checkpoint directory when actually
                    # saving (previously created even with --save False).
                    if not os.path.exists(save_path):
                        os.makedirs(save_path)
                    print("Saving model...")
                    saver.save(sess, save_path + str(epoch) + ".cptk")
def generate(args):
    """Restore a checkpoint and plot one generated sample per class.

    Generalized: produces args.n_class samples (one per label) instead of a
    hard-coded 10, which is identical behavior at the default n_class=10.
    """
    import matplotlib.pyplot as plt
    GAN = Machine(args)
    saver = tf.train.Saver(max_to_keep=None)
    model_name = args.n_model + ".cptk"
    # FIX: context-managed session so it is closed after sampling (the
    # original leaked it).
    with tf.Session() as sess:
        saver.restore(sess, save_path + model_name)
        n_samples = args.n_class
        noise = get_noise(n_samples, args.n_noise)
        # One-hot row per class label 0..n_class-1 (simplified from the
        # original hand-written [[0,...,9]].reshape(-1)).
        targets = np.arange(args.n_class)
        Y = np.eye(args.n_class)[targets]
        samples = sess.run(GAN.generated_image,
                           feed_dict={GAN.G_input: noise, GAN.Y: Y})
    fig, ax = plt.subplots(1, n_samples, figsize=(n_samples, 1))
    for i in range(n_samples):
        ax[i].set_axis_off()
        ax[i].imshow(np.reshape(samples[i], (28, 28)))
    plt.show()
def main():
    """Parse CLI flags and dispatch to the requested action."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--todo", default="train", type=str, help="train, generate")
    parser.add_argument("--total_epoch", default=200, type=int)
    parser.add_argument("--batch_size", default=100, type=int)
    parser.add_argument("--learning_rate", default=0.0002, type=float)
    parser.add_argument("--n_noise", default=128, type=int)
    parser.add_argument("--n_class", default=10, type=int)
    parser.add_argument("--image_dimension", default=28 * 28, type=int)
    parser.add_argument("--n_model", default="199", type=str, help="0, 49, 99")
    parser.add_argument("--save", default="False", type=str)
    args = parser.parse_args()

    todo = args.todo
    if todo == "train":
        train(args)
    elif todo == "generate":
        generate(args)


if __name__ == "__main__":
    main()