MNIST Generative Adversarial Model in Keras

This post discusses and demonstrates the implementation of a generative adversarial network in Keras, using the MNIST dataset.

By Tim O'Shea, O'Shea Research.

Some of the generative work done in the past year or two using generative adversarial networks (GANs) has been pretty exciting and demonstrated some very impressive results.  The general idea is that you train two models, one (G) to generate some sort of output example given random noise as input, and one (A) to discern generated model examples from real examples.  Then, by training A to be an effective discriminator, we can stack G and A to form our GAN, freeze the weights in the adversarial part of the network, and train the generative network weights to push random noisy inputs towards the “real” example class output of the adversarial half.
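
For reference, this two-player setup is usually written as a minimax objective. The notation below is assumed for illustration and does not come from the code: A(x) is the probability that A assigns to x being a real example, p_data is the distribution of real examples, and p_z is the distribution of the random noise fed to G.

```latex
\min_G \max_A \;
  \mathbb{E}_{x \sim p_{\text{data}}}\big[\log A(x)\big]
  + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - A(G(z))\big)\big]
```

Training G against forced “real” labels, as described above, roughly corresponds to the common variant in which G minimizes $-\mathbb{E}_{z \sim p_z}[\log A(G(z))]$ rather than the second term directly.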


High Level GAN Architecture

Building this style of network in the latest versions of Keras is actually quite straightforward. I’ve wanted to try this out on a number of things, so I put together a relatively simple version that uses a GAN to generate random handwritten digits from the classic MNIST dataset.

Before going further, I should mention that all of this code is available on GitHub here.

Generative Model

We set up a relatively straightforward generative model in Keras using the functional API, taking 100 random inputs and eventually mapping them down to a [1,28,28] pixel image to match the MNIST data shape.  We begin by generating a dense 14×14 set of values, then run them through a handful of filters of varying sizes and numbers of channels, and ultimately compile the model with an Adam optimizer and a binary cross-entropy loss (although we really only use the generator model in the forward direction; we don’t train directly on this model itself).  We use a sigmoid on the output layer to help saturate pixels into 0 or 1 states rather than a range of grays in between, and use batch normalization to help accelerate training and ensure that a wide range of activations are used within each layer.

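# Imports for the Keras 1.x API used throughout this post
from keras.models import Model
from keras.layers import Input
from keras.layers.core import Dense, Activation, Reshape, Flatten, Dropout
from keras.layers.convolutional import Convolution2D, UpSampling2D
from keras.layers.normalization import BatchNormalization
from keras.layers.advanced_activations import LeakyReLU

# Build Generative model ...
# (opt is an Adam optimizer instance created elsewhere in the full code)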
nch = 200
g_input = Input(shape=[100])
H = Dense(nch*14*14, init='glorot_normal')(g_input)
H = BatchNormalization(mode=2)(H)
H = Activation('relu')(H)
H = Reshape( [nch, 14, 14] )(H)
H = UpSampling2D(size=(2, 2))(H)
# Integer division keeps the filter counts as ints under Python 3 as well
H = Convolution2D(nch//2, 3, 3, border_mode='same', init='glorot_uniform')(H)
H = BatchNormalization(mode=2)(H)
H = Activation('relu')(H)
H = Convolution2D(nch//4, 3, 3, border_mode='same', init='glorot_uniform')(H)
H = BatchNormalization(mode=2)(H)
H = Activation('relu')(H)
H = Convolution2D(1, 1, 1, border_mode='same', init='glorot_uniform')(H)
g_V = Activation('sigmoid')(H)
generator = Model(g_input,g_V)
generator.compile(loss='binary_crossentropy', optimizer=opt)
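
As a rough sanity check on the shapes described above, here is a minimal pure-Python sketch that traces the tensor shape through each generator layer. It does not use Keras at all; the helper names (`dense_reshape`, `upsample`, `conv_same`) are hypothetical and only mimic the shape arithmetic, assuming channels-first ordering and stride-1 convolutions with 'same' padding.

```python
# Trace tensor shapes through the generator as (channels, height, width).
nch = 200  # channel count after the first dense layer, as in the post


def dense_reshape(channels, height, width):
    """Dense layer with channels*height*width units, reshaped to an image stack."""
    units = channels * height * width
    return units, (channels, height, width)


def upsample(shape, factor=2):
    """Upsampling repeats rows and columns, so height and width scale by factor."""
    c, h, w = shape
    return (c, h * factor, w * factor)


def conv_same(shape, filters):
    """A stride-1 convolution with 'same' padding changes only the channel count."""
    _, h, w = shape
    return (filters, h, w)


dense_units, shape = dense_reshape(nch, 14, 14)
trace = [("dense", (dense_units,)), ("reshape", shape)]

shape = upsample(shape)
trace.append(("upsample", shape))

for name, filters in [("conv1", nch // 2), ("conv2", nch // 4), ("conv_out", 1)]:
    shape = conv_same(shape, filters)
    trace.append((name, shape))

for name, layer_shape in trace:
    print(name, layer_shape)
```

The final entry in the trace is (1, 28, 28), which matches the MNIST image shape the generator is meant to produce.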

We now have a network which could in theory take in 100 random inputs and output digits, although the current weights are all random and this clearly isn’t happening just yet.

Sad images from an untrained generator

Adversarial Model

We build an adversarial discriminator network to take in [1,28,28] image vectors and decide if they are real or fake by using several convolutional layers, a dense layer, lots of dropout, and a two element softmax output layer encoding: [0,1] = fake, and [1,0] = real.  This is a relatively simple network, but the goal here is largely to get something that works passably and trains relatively quickly for experimentation.

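# Build Discriminative model ...
# (shp is the input image shape, [1,28,28]; dropout_rate and the optimizer dopt are set elsewhere in the full code)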
d_input = Input(shape=shp)
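# Note: with activation='relu' on a convolution, the LeakyReLU that follows it only
# ever sees non-negative values, so it acts as the identity as written.
H = Convolution2D(256, 5, 5, subsample=(2, 2), border_mode = 'same', activation='relu')(d_input)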
H = LeakyReLU(0.2)(H)
H = Dropout(dropout_rate)(H)
H = Convolution2D(512, 5, 5, subsample=(2, 2), border_mode = 'same', activation='relu')(H)
H = LeakyReLU(0.2)(H)
H = Dropout(dropout_rate)(H)
H = Flatten()(H)
H = Dense(256)(H)
H = LeakyReLU(0.2)(H)
H = Dropout(dropout_rate)(H)
d_V = Dense(2,activation='softmax')(H)
discriminator = Model(d_input,d_V)
discriminator.compile(loss='categorical_crossentropy', optimizer=dopt)
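
To make the size of this network concrete, the sketch below traces shapes through the two strided convolutions and the flatten step. It is plain Python rather than Keras, the helper name `conv_same_strided` is hypothetical, and it assumes the usual rule that a 'same'-padded convolution with stride s turns a side of length n into ceil(n / s).

```python
import math


def conv_same_strided(shape, filters, stride):
    """Shape after a 'same'-padded convolution with the given stride (channels first)."""
    _, h, w = shape
    return (filters, math.ceil(h / stride), math.ceil(w / stride))


shape = (1, 28, 28)                        # one MNIST image
shape = conv_same_strided(shape, 256, 2)   # first 5x5 convolution, stride 2
after_conv1 = shape
shape = conv_same_strided(shape, 512, 2)   # second 5x5 convolution, stride 2
after_conv2 = shape

# Flatten turns the feature maps into a single vector feeding Dense(256).
flat_units = shape[0] * shape[1] * shape[2]
dense_params = flat_units * 256 + 256      # weights plus biases

print(after_conv1, after_conv2, flat_units, dense_params)
```

Under these assumptions most of the parameters sit in the first dense layer, which is one reason the heavy dropout around it is plausible.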

We pre-train the discriminative model by generating a handful of random images using the untrained generative model, concatenating them with an equal number of real images of digits, labeling them appropriately, and then fitting until we reach a relatively stable loss value, which takes 1 epoch over 20,000 examples.  This is an important step which should not be skipped: pre-training accelerates the GAN massively, and I was not able to achieve convergence without it (possibly due to impatience).
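
The data layout for this pre-training step can be sketched with NumPy alone. The arrays below are random stand-ins (an assumption for illustration): in the real code, `real_images` would be a slice of the MNIST training set and `generated_images` would come from the untrained generator. The labels follow the encoding given earlier, [1,0] = real and [0,1] = fake.

```python
import numpy as np

rng = np.random.RandomState(0)
n = 8  # tiny illustrative count; the post uses far more examples

# Stand-ins for real MNIST digits and for images produced by the generator.
real_images = rng.uniform(0.0, 1.0, size=(n, 1, 28, 28))
generated_images = rng.uniform(0.0, 1.0, size=(n, 1, 28, 28))

# Stack real images first, then generated ones.
X = np.concatenate([real_images, generated_images], axis=0)

# One-hot labels matching that order.
y = np.zeros((2 * n, 2))
y[:n, 0] = 1.0  # [1, 0] = real
y[n:, 1] = 1.0  # [0, 1] = fake

print(X.shape, y.shape)
```

Arrays built this way are what would then be handed to the discriminator's `fit` method for the single pre-training epoch.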

Generative Adversarial Model

Now that we have both the generative and adversarial models, we can combine them to make a GAN quite easily in Keras.  Using the functional API, we can simply re-use the same network objects we have already instantiated, and they will conveniently maintain the same shared weights with the previously compiled models.  Since we want to freeze the weights in the adversarial half of the network during back-propagation of the joint model, we first run through and set the Keras trainable flag to False for each element in this part of the network.  For now, this seems to need to be applied at the primitive layer level rather than on the high-level network, so we introduce a simple function to do this.

# Freeze weights in the discriminator for stacked training
def make_trainable(net, val):
    net.trainable = val
    for layer in net.layers:
        layer.trainable = val

make_trainable(discriminator, False)

# Build stacked GAN model
gan_input = Input(shape=[100])
H = generator(gan_input)
gan_V = discriminator(H)
GAN = Model(gan_input, gan_V)
GAN.compile(loss='categorical_crossentropy', optimizer=opt)

At this point, we have a randomly initialized generator, a (poorly) trained discriminator, and a GAN which can be trained across the stacked model of both networks.  The core of the training routine for a GAN looks something like this:

  1. Generate images using G and random noise (forward pass only).
  2. Perform a batch update of the weights in A given generated images, real images, and labels.
  3. Perform a batch update of the weights in G given noise and forced “real” labels in the full GAN.
  4. Repeat…
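
The steps above can be sketched as a single function. This is a hedged outline rather than the exact loop from the repository: `train_gan` and its parameters are hypothetical names, the noise is assumed to be uniform on [0, 1), and the three model arguments are expected to behave like Keras models (only `predict` and `train_on_batch` are called). It leaves out any toggling of the trainable flag between steps, since whether that is needed depends on the Keras version.

```python
import numpy as np

REAL = np.array([1.0, 0.0])  # label encoding from the post: [1, 0] = real
FAKE = np.array([0.0, 1.0])  # [0, 1] = fake


def train_gan(generator, discriminator, gan, X_train,
              n_steps=100, batch_size=32, noise_dim=100, seed=0):
    """Alternate discriminator and generator updates; return both loss histories."""
    rng = np.random.RandomState(seed)
    d_losses, g_losses = [], []
    for _ in range(n_steps):
        # 1. Generate images from random noise (forward pass only).
        idx = rng.randint(0, X_train.shape[0], size=batch_size)
        real_batch = X_train[idx]
        noise = rng.uniform(0.0, 1.0, size=(batch_size, noise_dim))
        fake_batch = generator.predict(noise)

        # 2. Batch update of A on real and generated images with their true labels.
        X = np.concatenate([real_batch, fake_batch], axis=0)
        y = np.concatenate([np.tile(REAL, (batch_size, 1)),
                            np.tile(FAKE, (batch_size, 1))], axis=0)
        d_losses.append(discriminator.train_on_batch(X, y))

        # 3. Batch update of G through the stacked model, labelling noise as "real".
        noise = rng.uniform(0.0, 1.0, size=(batch_size, noise_dim))
        targets = np.tile(REAL, (batch_size, 1))
        g_losses.append(gan.train_on_batch(noise, targets))
    return d_losses, g_losses
```

With the models from this post, a call might look like `train_gan(generator, discriminator, GAN, X_train)`, and the two returned lists are what you would plot to track training.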

Running this process for a number of epochs, we can plot the GAN loss and the adversarial (discriminator) loss over time to see how training progresses.

GAN Training Loss

And finally, we can plot some samples from the trained generative model, which look reasonably similar to the original MNIST digits, alongside some examples from the original dataset for comparison.

GAN Generated Random Digits

Example Digits from Real MNIST Set

Original. Reposted with permission.

Bio: Tim O'Shea is the principal researcher at O’Shea Research, a small engineering research and development LLC based in Arlington, VA, specializing in high-performance software radio, signal processing, and machine learning systems.