Skip to content
This repository has been archived by the owner on May 3, 2022. It is now read-only.

Commit

Permalink
Resnet patch (#7)
Browse files Browse the repository at this point in the history
* minor bug: GroupConvolution should be instantiated before called

* add training pipeline for testing

* fix minor data error with validation

* fix minor data error with validation
  • Loading branch information
DavidMChan authored Jan 22, 2019
1 parent f060998 commit 0574bdd
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 4 deletions.
91 changes: 91 additions & 0 deletions research/resnet_cifar10.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import tensorflow as tf
import os
import sys
import multiprocessing
tf.enable_eager_execution()

from flux.datasets.vision.cifar import Cifar10
from rinokeras.models.resnet import ResNeXt50

# Training-loop configuration read by run().
NUM_EPOCHS = 10000    # number of iterations of the training loop
TEST_INTERVAL = 100   # a validation pass happens every this many iterations
BATCH_SIZE = 64       # number of examples gathered per batch

# Load CIFAR-10 once at import time and keep every split as a tensor:
# images as float64, labels as int64.
cifar = Cifar10()
train_image, val_image = (
    tf.convert_to_tensor(images, dtype=tf.float64)
    for images in (cifar.X_train, cifar.X_test)
)
train_labels, val_labels = (
    tf.convert_to_tensor(targets, dtype=tf.int64)
    for targets in (cifar.Y_train, cifar.Y_test)
)


class PredictionNet(tf.keras.Model):
    """ResNeXt50 backbone followed by a 10-unit dense layer.

    The dense layer has no activation, so the model output is a tensor of
    raw class scores (logits) with 10 entries per example.

    Args:
        use_layer_norm: Forwarded unchanged to ``ResNeXt50``.
    """

    def __init__(self, use_layer_norm=True) -> None:
        super().__init__()
        self.resnet = ResNeXt50(use_layer_norm=use_layer_norm)
        self.prediction_module = tf.keras.layers.Dense(units=10)

    def call(self, inputs, training=True):
        """Run the backbone and the prediction layer on ``inputs``.

        Args:
            inputs: Batch of images fed to the ResNeXt50 backbone.
            training: Keras training-mode flag. It is forwarded to the
                backbone so that any mode-dependent layers inside it behave
                correctly; previously it was accepted but ignored.
                NOTE(review): assumes ``ResNeXt50.call`` accepts a
                ``training`` keyword (true for ``tf.keras.Sequential``) --
                confirm against rinokeras.

        Returns:
            The output of the 10-unit dense layer applied to the backbone
            features.
        """
        features = self.resnet(inputs, training=training)
        return self.prediction_module(features)

def loss(logits, labels):
    """Return the mean softmax cross-entropy of ``logits`` against ``labels``.

    Args:
        logits: Unnormalised class scores.
        labels: Target distribution for each example (run() passes one-hot
            vectors); passed straight to
            ``tf.nn.softmax_cross_entropy_with_logits_v2``.

    Returns:
        A scalar tensor: the per-example cross-entropy averaged over all
        examples.
    """
    per_example = tf.nn.softmax_cross_entropy_with_logits_v2(
        labels=labels,
        logits=logits,
    )
    return tf.reduce_mean(per_example)

def compute_accuracy(logits, labels):
    """Return the fraction of examples whose arg-max logit equals the label.

    Args:
        logits: Class scores; the predicted class is the arg-max along
            axis 1, and the size of axis 0 is used as the batch size.
        labels: Integer class ids compared element-wise with the
            predictions.  NOTE(review): presumably shape ``(batch,)`` so the
            comparison does not broadcast -- confirm against the dataset.

    Returns:
        A float64 scalar tensor: number of matches divided by the batch size.
    """
    predicted = tf.argmax(logits, axis=1)
    hits = tf.equal(predicted, labels)
    n_correct = tf.reduce_sum(tf.cast(hits, tf.float64))
    n_examples = int(logits.shape[0])
    return n_correct / n_examples

# Model being trained (layer normalisation enabled).
resnet = PredictionNet(use_layer_norm=True)

# Optimiser and the global step it increments on every update.
optimizer = tf.train.AdamOptimizer(learning_rate=1e-4)
step_counter = tf.train.get_or_create_global_step()

# Checkpoints track the model, the optimiser state and the step counter, and
# are written under ./checkpoints/ with the file prefix "ckpt".
checkpoint_prefix = os.path.join('./checkpoints/', 'ckpt')
checkpoint = tf.train.Checkpoint(
    model=resnet,
    optimizer=optimizer,
    step_counter=step_counter,
)

# Resume from the most recent checkpoint found in ./checkpoints/, if any.
checkpoint.restore(
    tf.train.latest_checkpoint('./checkpoints/')
)

def _sample_batch(images, labels, batch_size):
    """Return ``batch_size`` distinct random examples from ``images``/``labels``.

    Indices are shuffled over the *whole* first axis of ``images`` and the
    first ``batch_size`` of them are kept, so every example can be drawn.
    (Shuffling only ``range(batch_size)`` would always return the same first
    ``batch_size`` examples, merely reordered.)
    """
    num_examples = int(images.shape[0])
    index = tf.random_shuffle(tf.range(0, num_examples, delta=1))[:batch_size]
    return tf.gather(images, index), tf.gather(labels, index)


def run():
    """Train ``resnet`` on random CIFAR-10 batches for ``NUM_EPOCHS`` iterations.

    Each iteration draws one random training batch, computes the softmax
    cross-entropy loss and applies one Adam update.  The training loss is
    printed every 5 iterations.  Every ``TEST_INTERVAL`` iterations one random
    validation batch is evaluated, its accuracy and loss are printed, and a
    checkpoint is saved to ``checkpoint_prefix``.
    """
    for iteration in range(NUM_EPOCHS):
        batch, labels = _sample_batch(train_image, train_labels, BATCH_SIZE)
        one_hot = tf.one_hot(labels, depth=10, dtype=tf.float64)

        with tf.GradientTape() as tape:
            logits = resnet(batch, training=True)
            loss_value = loss(logits, one_hot)

        # Differentiate only w.r.t. trainable variables; non-trainable ones
        # have no gradient and must not be handed to the optimizer.
        variables = resnet.trainable_variables
        grads = tape.gradient(loss_value, variables)
        optimizer.apply_gradients(
            zip(grads, variables), global_step=step_counter)

        if iteration % 5 == 0:
            print('[Iteration {}] Loss: {}'.format(iteration, loss_value))
            sys.stdout.flush()

        if iteration % TEST_INTERVAL == 0:
            val_batch, val_batch_labels = _sample_batch(
                val_image, val_labels, BATCH_SIZE)
            val_one_hot = tf.one_hot(
                val_batch_labels, depth=10, dtype=tf.float64)
            # Evaluate in inference mode.
            val_logits = resnet(val_batch, training=False)
            val_loss = loss(val_logits, val_one_hot)
            val_accuracy = compute_accuracy(val_logits, val_batch_labels)
            print('[TEST ITERATION, Iteration {}] Validation Accuracy: {}, Validation Loss: {}'.format(
                iteration, float(val_accuracy), float(val_loss)))
            checkpoint.save(checkpoint_prefix)
            sys.stdout.flush()


# Start training only when executed as a script, not when imported.
if __name__ == "__main__":
    run()
8 changes: 4 additions & 4 deletions rinokeras/models/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

class GroupedConvolution(tf.keras.Model):
def __init__(self, cardinality: int = 1, n_filters: int = 64, kernel_size: Tuple[int, int] = (3, 3), stride: Tuple[int, int] = (1,1)) -> None:

super(GroupedConvolution, self).__init__()
self.cardinality = cardinality

if self.cardinality == 1:
Expand All @@ -20,15 +20,15 @@ def __init__(self, cardinality: int = 1, n_filters: int = 64, kernel_size: Tuple
self._layer_list = tf.contrib.checkpoint.List()
for idx in range(self.cardinality):
group = tf.keras.layers.Lambda(lambda z: z[:,:,:, idx * self._dim: (idx + 1) * self._dim])
group = tf.keras.layers.Conv2D(filters=self._dim, kernel_size=kernel_size, stride=stride, padding='same')
group = tf.keras.layers.Conv2D(filters=self._dim, kernel_size=kernel_size, strides=stride, padding='same')
self._layer_list.append(group)

def call(self, inputs, *args, **kwargs):
if self.cardinality == 1:
return self.output_layer(inputs)
else:
layers = [layer(inputs) for layer in self._layer_list]
return tf.keras.layers.Concatenate(layers)
return tf.keras.layers.Concatenate()(layers)


class ResidualBlock(tf.keras.Model):
Expand Down Expand Up @@ -134,4 +134,4 @@ def __init__(self, use_layer_norm: bool = True) -> None:
self.add(ResidualBlock(self.cardinality, n_filters_in=1024, n_filters_out=2048, stride=strides))

self.add(tf.keras.layers.GlobalAveragePooling2D())
self.add(tf.keras.layers.Dense(1))
# self.add(tf.keras.layers.Dense(1))

0 comments on commit 0574bdd

Please sign in to comment.