Semi-supervised Training Using GAN on SVHN Dataset

Posted 2018-12-12

Goal:

  • Use GANs to implement semi-supervised learning
  • Semi-supervised learning allows us to learn not just from labeled data but also from unlabeled data (of which there is far more, and which is widely available)
    • Discriminator learns (optimizes losses) from the following sources:
      • Real data labeled with the corresponding digit (~1k samples)
      • Real data with a label of 'real' (~72k samples)
      • Generated data with a label of 'fake' (~72k samples)
    • Generator learns (optimizes losses) via 'feature matching' against the discriminator's learned features (see the loss sketch after this list)
  • This is a replication of the Udacity AIND semi-supervised project using the Stanford Street View House Numbers (SVHN) dataset.
  • Performance
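
A minimal numpy sketch of how these loss sources combine (illustrative only; the function and variable names here are hypothetical, and the actual TensorFlow implementation is in model_loss further down):

In [ ]:
import numpy as np

def sketch_losses(class_logits_real, gan_logit_real, gan_logit_fake,
                  y_one_hot, label_mask, features_real, features_fake):
    """Illustrative numpy version of the semi-supervised GAN losses."""
    # 1) Supervised loss: softmax cross entropy, counted only where label_mask == 1
    p = np.exp(class_logits_real) / np.exp(class_logits_real).sum(axis=1, keepdims=True)
    supervised = (label_mask * -(y_one_hot * np.log(p)).sum(axis=1)).sum() \
               / max(1.0, label_mask.sum())
    # 2) Unsupervised loss: real data should be scored 'real' (sigmoid CE, target 1)
    unsup_real = np.log1p(np.exp(-gan_logit_real)).mean()
    # 3) Unsupervised loss: generated data should be scored 'fake' (sigmoid CE, target 0)
    unsup_fake = np.log1p(np.exp(gan_logit_fake)).mean()
    d_loss = supervised + unsup_real + unsup_fake
    # Generator: feature matching - match the batch-averaged discriminator features
    g_loss = np.abs(features_real.mean(axis=0) - features_fake.mean(axis=0)).mean()
    return d_loss, g_loss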
In [ ]:
import time

import matplotlib.pyplot as plt
import numpy as np
from scipy.io import loadmat
import tensorflow as tf

%matplotlib inline

Download Data

In [2]:
from urllib.request import urlretrieve
from os.path import isfile, isdir
from tqdm import tqdm

!mkdir data
data_dir = 'data/'

if not isdir(data_dir):
    raise Exception("Data directory does not exist!")

class DLProgress(tqdm):
    """Extends tqdm class. https://github.com/tqdm/tqdm 
    
    Attributes
    ----------
    total : int
        Total size (in tqdm units)
    """
    last_block = 0
    
    def hook(self, block_num=1, block_size=1, total_size=None):
        """Hook to customize and manually update the progress bar
        
        Parameters
        ----------
        block_num : int, optional
            Number of blocks transferred so far [default: 1]
        block_size : int, optional
            Size of each block (in tqdm units) [default: 1]
        total_size : int, optional
            Total size (in tqdm units); remains unchanged if None 
            [default: None]
        """
        self.total = total_size
        self.update((block_num - self.last_block) * block_size)
        self.last_block = block_num

# Get training and test data
if not isfile(data_dir + 'train_32x32.mat'):
    with DLProgress(unit='B', unit_scale=True, miniters=1, 
                    desc='SVHN Training Set') as pbar:
        urlretrieve(
            'http://ufldl.stanford.edu/housenumbers/train_32x32.mat',
            data_dir + 'train_32x32.mat',
            pbar.hook)
    
if not isfile(data_dir + 'test_32x32.mat'):
    with DLProgress(unit='B', unit_scale=True, miniters=1, 
                    desc='SVHN Test Set') as pbar:
        urlretrieve(
            'http://ufldl.stanford.edu/housenumbers/test_32x32.mat',
            data_dir + 'test_32x32.mat',
            pbar.hook) 
mkdir: data: File exists

Load Data

Get a feel for the dataset

In [3]:
trainset = loadmat(data_dir + 'train_32x32.mat')
testset = loadmat(data_dir + 'test_32x32.mat')
trainset.keys(), testset.keys()
Out[3]:
(dict_keys(['X', '__globals__', '__version__', '__header__', 'y']),
 dict_keys(['X', '__globals__', '__version__', '__header__', 'y']))
In [4]:
trainset['__globals__'], trainset['__version__'], trainset['__header__']
Out[4]:
([],
 '1.0',
 b'MATLAB 5.0 MAT-file, Platform: GLNXA64, Created on: Mon Dec  5 21:09:26 2011')
In [5]:
trainset['X'].shape, trainset['y'].shape
Out[5]:
((32, 32, 3, 73257), (73257, 1))
In [6]:
testset['X'].shape, testset['y'].shape
Out[6]:
((32, 32, 3, 26032), (26032, 1))
In [7]:
trainset['y'][:5]
Out[7]:
array([[1],
       [9],
       [2],
       [3],
       [2]], dtype=uint8)
In [8]:
# A particular RGB image
testset['X'][:,:,:,0].shape
Out[8]:
(32, 32, 3)
In [9]:
# Min, max values
testset['X'][:,:,:,0].min(), testset['X'][:,:,:,0].max()
Out[9]:
(1, 125)
In [10]:
# Note that some images contain 2 digits but are still labeled by the most central one.
# This shows the red channel of the first training sample as a greyscale image.
plt.imshow(trainset['X'][:,:,0,0], cmap='Greys')
Out[10]:
<matplotlib.image.AxesImage at 0xb3198c5c0>

Show Image Examples

In [11]:
idx = np.random.randint(0, trainset['X'].shape[3], size=36)
fig, axes = plt.subplots(6, 6, sharex=True, sharey=True, figsize=(7, 7))
for ax, i in zip(axes.flatten(), idx):
    ax.imshow(trainset['X'][:,:,:,i])
    ax.axis('off')
plt.subplots_adjust(wspace=0, hspace=0)

SVHN Dataset Class Utility

In [12]:
def scale(x, feature_range=(-1, 1)):
    """ Normalize values to features range based on min of all values in 
    ndarray and 255.
    
    Parameters
    ----------
    x : ndarray
        Input to normalize
    feature_range : tuple
        Min and max of desired range [default: (-1, 1)]
    
    Returns
    -------
    out : ndarray
        Normalized output
    """
    # Scale to (0, 1)
    x = (x - x.min()) / (255 - x.min())
    
    # Scale to desired range
    min_, max_ = feature_range
    x = x * (max_ - min_) + min_
    return x
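
A quick sanity check of scale on one raw SVHN image (a sketch; uses the trainset loaded above):

In [ ]:
# Sanity check (sketch): scale one raw uint8 image loaded above into [-1, 1].
sample = trainset['X'][:, :, :, 0].astype(np.float32)
scaled = scale(sample)
print(sample.min(), sample.max())   # raw pixel range of this image
print(scaled.min(), scaled.max())   # scaled range, within [-1, 1]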
In [13]:
class Dataset:
    """SVHN dataset class utility - mask to show only 1000 labels.  The rest
    will be treated as unlabeled.
    
    Parameters
    ----------
    train : ndarray
        Training dataset
    test : ndarray
        Test dataset
    val_frac : float
        Fraction of the test data to set aside for validation [default: 0.5]
    shuffle : bool
        Shuffle data on call to batches method if True
    scale_fn : callable or None
        Scaling function; if None, use the scale function defined above
    
    Attributes
    ----------
    train_x, test_x, valid_x : 4D ndarray
        Training, test, and validation observations
    train_y, test_y, valid_y : 2D ndarray
        Ground truth labels for the corresponding observations
    label_mask : array
        0, 1 mask to suppress labels
    scaler : callable
        Scaling function applied to the data
    shuffle : bool
        Shuffle if True
    """
    def __init__(self, train, test, val_frac=0.5, shuffle=True, scale_fn=None):
        # Split dataset
        split_idx = int(len(test['y']) * (1 - val_frac))
        self.train_x, self.train_y = train['X'], train['y']
        self.test_x, self.valid_x = test['X'][:,:,:,:split_idx], test['X'][:,:,:,split_idx:]
        self.test_y, self.valid_y = test['y'][:split_idx], test['y'][split_idx:]
        
        # Use the first 1000 labels, mask the rest
        self.label_mask = np.zeros_like(self.train_y)
        self.label_mask[:1000] = 1
        
        # Change axis of data to (batch, height, width, channel)
        self.train_x = np.rollaxis(self.train_x, 3)
        self.test_x = np.rollaxis(self.test_x, 3)
        self.valid_x = np.rollaxis(self.valid_x, 3)
        
        # Normalize data
        if scale_fn is None:
            self.scaler = scale
        else:
            self.scaler = scale_fn
        self.train_x = self.scaler(self.train_x)
        self.test_x = self.scaler(self.test_x)
        self.valid_x = self.scaler(self.valid_x)
        self.shuffle = shuffle
    
    def batches(self, batch_size, which_set='train'):
        """Generator to yield batches of data at a time.
        
        Parameters
        ----------
        batch_size : int
            Batch size for each yield
        which_set : {'train', 'test', 'valid'}
            Specifies which dataset
            
        Yields
        ------
        x : 4D ndarray
            Next batch of dataset
        y : 2D array
            Next batch of labels
        label_mask : array, optional
            Next batch of label masks for training set
        
        """
        x_name = which_set + '_x'
        y_name = which_set + '_y'
        
        # Introspection to get / set the requested dataset at runtime
        num_examples = len(getattr(self, y_name))
        if self.shuffle:
            shuffled_idx = np.random.permutation(num_examples)
            setattr(self, x_name, getattr(self, x_name)[shuffled_idx])
            setattr(self, y_name, getattr(self, y_name)[shuffled_idx])
            if which_set == 'train':
                self.label_mask = self.label_mask[shuffled_idx]
        
        dataset_x = getattr(self, x_name)
        dataset_y = getattr(self, y_name)
        for idx in range(0, num_examples, batch_size):
            x = dataset_x[idx:idx+batch_size]
            y = dataset_y[idx:idx+batch_size]
            
            if which_set == 'train':
                # When training, also provide label masks
                yield x, y, self.label_mask[idx:idx+batch_size]
            else:
                yield x, y   
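
A quick usage sketch of the batches generator, drawing a single training batch to check shapes (the Main section below constructs the same Dataset object for training):

In [ ]:
# Usage sketch: construct the Dataset and draw one training batch.
demo_data = Dataset(trainset, testset)
x, y, mask = next(demo_data.batches(128, which_set='train'))
print(x.shape, y.shape, mask.shape)  # (128, 32, 32, 3) (128, 1) (128, 1)
print(mask.sum())                    # how many of these 128 samples are labeled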

TensorFlow Graph

Inputs to Generator and Discriminator

In [14]:
def model_inputs(real_shape, z_shape):
    """Inputs to Generator and Discriminator
    
    Parameters
    ----------
    real_shape : tuple of ints
        Shape of input 
    z_shape : int
        Length of random noise as seed for fake data
    
    Returns
    -------
    inputs_real : placeholder tensor
        Input for real data 
    inputs_z : placeholder tensor
        Input for fake data
    y : placeholder 1D tensor
        Ground truth labels for real data
    label_mask : placeholder 1D tensor
        0/1 mask indicating which labels to use (1 = use the label)
    """
    inputs_real = tf.placeholder(tf.float32, (None, *real_shape), 
                                 name='input_real')
    inputs_z = tf.placeholder(tf.float32, (None, z_shape), name='input_z')
    y = tf.placeholder(tf.int32, (None), name='y')
    label_mask = tf.placeholder(tf.int32, (None), name='label_mask')
    
    return inputs_real, inputs_z, y, label_mask

Generator Architecture

In [15]:
def generator(z, output_dim, reuse=False, alpha=0.2, training=True, 
              size_mult=128):
    """NN architecture for generator.
    A fairly standard generator: takes in random noise and upsamples with 
    strided transposed convolutions, leaky ReLU activations, and batch 
    normalization.
    
    Parameters
    ----------
    z : placeholder tensor
        Generated random noise 
    output_dim : int
        Number of output channels, matching the input depth expected by the 
        discriminator
    reuse : bool
        Reuse weights if True [default: False]
    alpha : float
        Left half leaky relu alpha multiplier [default: 0.2]
    training : bool
        If True, batch normalization uses batch statistics (training mode); 
        set False for inference [default: True]
    size_mult : int
        Base layer unit size multiplier for all layers [default: 128]
    
    Returns
    -------
    out : tensor
        Output layer of generator which matches input shape for discriminator
        Shape (batch, 32, 32, output_dim)
    """
    with tf.variable_scope('generator', reuse=reuse):
        # layer 1 - input (batch, z_dim)
        # Reshape to (batch, 4, 4, size_mult * 4) for CNN
        # Fully connected, batch norm, leaky relu
        x1 = tf.layers.dense(z, 4 * 4 * size_mult * 4)
        x1 = tf.reshape(x1, (-1, 4, 4, size_mult * 4))
        x1 = tf.layers.batch_normalization(x1, training=training)
        x1 = tf.maximum(alpha * x1, x1)
        
        # layer 2 - input (batch, 4, 4, size_mult * 4)
        # Batch norm, leaky relu, kernel 5, stride 2, same padding
        x2 = tf.layers.conv2d_transpose(x1, size_mult * 2, 5, strides=2, 
                                        padding='same')
        x2 = tf.layers.batch_normalization(x2, training=training)
        x2 = tf.maximum(alpha * x2, x2)
        
        # layer 3 - input (batch, 8, 8, size_mult * 2)
        # Batch norm, leaky relu, kernel 5, stride 2, same padding
        x3 = tf.layers.conv2d_transpose(x2, size_mult, 5, strides=2, 
                                        padding='same')
        x3 = tf.layers.batch_normalization(x3, training=training)
        x3 = tf.maximum(alpha * x3, x3)
        
        # Final layer - input (batch, 16, 16, size_mult)
        # kernel 5, stride 2, same padding, tanh
        # Final layer shape - (batch, 32, 32, output_dim)
        logits = tf.layers.conv2d_transpose(x3, output_dim, 5, strides=2, 
                                           padding='same')
        out = tf.tanh(logits)
        
        return out
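
A quick shape check for the generator (a sketch using the TensorFlow 1.x API above): a length-100 noise vector should be upsampled 4 → 8 → 16 → 32 into a (batch, 32, 32, 3) image.

In [ ]:
# Shape check (sketch): noise in, 32x32x3 image out.
tf.reset_default_graph()
z_demo = tf.placeholder(tf.float32, (None, 100), name='z_demo')
g_demo = generator(z_demo, output_dim=3, size_mult=32)
print(g_demo.get_shape().as_list())  # [None, 32, 32, 3]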

Discriminator Architecture

In [16]:
def discriminator(x, reuse=False, alpha=0.2, drop_rate=0., num_classes=10, 
                  size_mult=64, training=True):
    """NN architecture for discriminator.
    Fairly standard discriminator using strided convolutions and no pooling. 
    Due to the small amount of labeled data, dropout is used to regularize.
    
    Apply global average pooling over height x width prior to the final output
    layer so that the generator can feature-match against that layer.
    
    Calculate the real/fake logit from the class logits: the larger the 
    log-sum-exp of the class logits, the more likely the data is real.
    
    Parameters
    ----------
    x : tensor
        Input to discriminator
    reuse : bool
        Reuse model weights if True [default: False]
    alpha : float
        Left half leaky relu multiplier parameter [default: 0.2]
    drop_rate : float
        Dropout rate parameter for layers [default: 0.]
    num_classes : int
        Number of real digits to predict [default: 10]
    size_mult : int
        Base layer unit size multiplier for all layers [default: 64]
    training : bool
        If True, batch normalization uses batch statistics [default: True]
    
    Returns
    -------
    out : tensor
        Output layer on real data on digit classes prediction
        Shape (batch, num_classes)
    class_logits : tensor
        Logits layer on real data on digit classes prediction
        Shape (batch, num_classes)
    gan_logits : tensor
        Calculated logits on real/fake prediction. Calculated based on logits
        of the classes. 
        Shape (batch, 1)
    features : tensor
        Features layer before the output layer (global average pooling over
        height x width)
        Shape (batch, size_mult * 2)
    """
    with tf.variable_scope('discriminator', reuse=reuse):
        # Apply dropout to input
        x = tf.layers.dropout(x, rate=drop_rate/2.5)
        
        # layer 1 - input (batch, 32, 32, 3)
        # No batch norm on the first layer so the raw input passes through intact
        # Kernel 3, stride 2, leaky relu, dropout, same padding
        x1 = tf.layers.conv2d(x, size_mult, 3, strides=2, padding='same')
        relu1 = tf.maximum(alpha * x1, x1)
        relu1 = tf.layers.dropout(relu1, rate=drop_rate)
        
        # layer 2 - input (batch, 16, 16, size_mult)
        # Batch norm, kernel 3, stride 2, leaky relu, same padding
        x2 = tf.layers.conv2d(relu1, size_mult, 3, strides=2, padding='same')
        bn2 = tf.layers.batch_normalization(x2, training=training)
        relu2 = tf.maximum(alpha * bn2, bn2)
        
        # layer 3 - input (batch, 8, 8, size_mult)
        # Batch norm, kernel 3, stride 2, leaky relu, dropout, same padding
        x3 = tf.layers.conv2d(relu2, size_mult, 3, strides=2, padding='same')
        bn3 = tf.layers.batch_normalization(x3, training=training)
        relu3 = tf.maximum(alpha * bn3, bn3)
        relu3 = tf.layers.dropout(relu3, rate=drop_rate)
        
        # layer 4 - input (batch, 4, 4, size_mult)
        # Batch norm, kernel 3, stride 1, leaky relu, same padding
        x4 = tf.layers.conv2d(relu3, size_mult * 2, 3, strides=1, padding='same')
        bn4 = tf.layers.batch_normalization(x4, training=training)
        relu4 = tf.maximum(alpha * bn4, bn4)
        
        # layer 5 - input (batch, 4, 4, size_mult * 2)
        # Batch norm, kernel 3, stride 1, leaky relu, same padding
        x5 = tf.layers.conv2d(relu4, size_mult * 2, 3, strides=1, padding='same')
        bn5 = tf.layers.batch_normalization(x5, training=training)
        relu5 = tf.maximum(alpha * bn5, bn5)
        
        # layer 6 - input (batch, 4, 4, size_mult * 2)
        # Batch norm, kernel 3, stride 2, leaky relu, dropout, same padding
        x6 = tf.layers.conv2d(relu5, size_mult * 2, 3, strides=2, padding='same')
        bn6 = tf.layers.batch_normalization(x6, training=training)
        relu6 = tf.maximum(alpha * bn6, bn6)
        relu6 = tf.layers.dropout(relu6, rate=drop_rate)
        
        # layer 7 - input (batch, 2, 2, size_mult * 2)
        # No batch norm so we can use feature matching on it 
        # Kernel 2, stride 1, leaky relu, valid padding
        x7 = tf.layers.conv2d(relu6, size_mult * 2, 2, strides=1, padding='valid')
        relu7 = tf.maximum(alpha * x7, x7)
        
        # features (global average pooling): average over axis 1 (height) and 
        # axis 2 (width), collapsing each feature map to a scalar. Resulting 
        # shape is (batch, size_mult * 2)
        features = tf.reduce_mean(relu7, axis=[1, 2])
        
        # output layer - input (batch, size_mult * 2)
        # output layer shape (batch, num_classes)
        class_logits = tf.layers.dense(features, num_classes)
        out = tf.nn.softmax(class_logits)
        
        # Computing GAN logits
        # One method is to use 11 classes instead of 10 and split off the 
        # extra 'fake' class.
        # A second method takes advantage of the fact that 
        # sum(softmax(class_logits)) = 1, so we can leave out one unknown class
        # (fake_class_logits = 0) and softmax still works.
        real_class_logits = class_logits
        fake_class_logits = 0
        
        # Numerical stability trick for log softmax due to possibility of 
        #     - one value being really large
        #     - all the values are very negative
        # Get max of the class logits (axis 1). Note axis 0 is batches
        mx = tf.reduce_max(real_class_logits, axis=1, keepdims=True)
        # Subtract max from each logit element wise
        stable_real_class_logits = real_class_logits - mx  
        #    - Calc log-sum-exp then add back the max (the trick above), thus 
        #      giving us the logit as it should be, but numerically stable.
        #    - gan_logits = log(sum(exp(stable_real_class_logits))) + mx
        #    - gan_logits shape -> (batch,)
        gan_logits = tf.log(tf.reduce_sum(tf.exp(stable_real_class_logits), axis=1))\
                   + tf.squeeze(mx) + fake_class_logits
        
        return out, class_logits, gan_logits, features
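
A small numpy illustration of the numerical stability trick above: computing log(sum(exp(logits))) directly overflows when one logit is very large, but subtracting the max first and adding it back gives the same value safely.

In [ ]:
# Numeric sketch of the log-sum-exp stability trick (numpy will warn about the
# deliberate overflow in the naive version).
demo_logits = np.array([[1000., 2., 3.]])            # one very large logit
naive = np.log(np.exp(demo_logits).sum(axis=1))      # exp overflows -> inf
mx = demo_logits.max(axis=1, keepdims=True)
stable = np.log(np.exp(demo_logits - mx).sum(axis=1)) + mx.squeeze()
print(naive, stable)                                 # [inf] [1000.]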

D & G Losses

In [17]:
def model_loss(input_real, input_z, output_dim, y, num_classes, label_mask, 
               alpha=0.2, drop_rate=0.):
    """Get outputs from discriminator and generator and calculate losses.
    
    
    Parameters
    ----------
    input_real : placeholder tensor
        Input for real data
    input_z : placeholder tensor
        Input for fake data
    output_dim : int
        Number of channels for input of discriminator
    y : placeholder 1D tensor
        Labels for real data
    num_classes : int
        Number of real digits to predict
    label_mask : placeholder 1D tensor
        0, 1 mask. 1 means use the label.
    alpha : float
        Left half of leaky relu multiplier parameter [default: 0.2]
    drop_rate : scalar tensor
        Dropout rate parameter in layers [default: 0.]
    
    
    Returns
    -------
    d_loss : 1D tensor
        Discriminator loss for each observation
    g_loss : scalar tensor
        Generator feature-matching loss
    correct : float
        Number of correct predictions
    masked_correct : float
        Number of correct predictions taking into account masking
    g_model : tensor
        Generator output layer 
    """
    g_size_mult = 32
    d_size_mult = 64
    
    # Get generator output
    g_model = generator(input_z, output_dim, alpha=alpha, 
                        size_mult=g_size_mult)
    
    # Get discriminator output for real data 
    _ = discriminator(input_real, alpha=alpha, drop_rate=drop_rate, 
                      size_mult=d_size_mult)
    d_model_real, class_logits_real, gan_logits_real, features_real = _
    
    # Get discriminator output for fake data while reusing the same weights.
    # Note that batch norm statistics are computed separately for the real
    # batch and again for the fake batch, which is a bit unusual.
    _ = discriminator(g_model, reuse=True, alpha=alpha, drop_rate=drop_rate, 
                      size_mult=d_size_mult)
    d_model_fake, class_logits_fake, gan_logits_fake, features_fake = _

    # LOSSES
    # Discriminator loss = unsupervised + supervised
    # Unsupervised GAN logits loss - real data should score 'real', generated
    # data should score 'fake'. gan_logits shape (batch,)
    d_loss_real = tf.nn.sigmoid_cross_entropy_with_logits(logits=gan_logits_real, 
                                                          labels=tf.ones_like(gan_logits_real))
    d_loss_fake = tf.nn.sigmoid_cross_entropy_with_logits(logits=gan_logits_fake,
                                                          labels=tf.zeros_like(gan_logits_fake))
    
    # Supervised classes logits loss - match on classes of digits. We only 
    # calculate this for ones with unmasked labels
    # Class logit shape (batch, num_classes)
    y_one_hot = tf.one_hot(tf.squeeze(y), num_classes, dtype=tf.float32)
    class_cross_entropy_loss = tf.nn.softmax_cross_entropy_with_logits_v2(logits=class_logits_real,
                                                                          labels=y_one_hot)
    class_cross_entropy_loss = tf.squeeze(class_cross_entropy_loss)
    label_mask = tf.squeeze(tf.to_float(label_mask))
    # Make sure not to divide by 0
    # d_loss shape (batch)
    d_loss_class = tf.reduce_sum(label_mask * class_cross_entropy_loss) / tf.maximum(1., tf.reduce_sum(label_mask))
    d_loss = d_loss_class + d_loss_real + d_loss_fake
    
    # Generator loss (feature matching)
    # Measure how closely the generated data's features match the real data's
    # features at the discriminator's pre-logits layer, i.e. the generator
    # tries to produce images whose discriminator activations look like those
    # of real images, pushing the real/fake decision toward 'real'.
    # Note the chain of averaging:
    #    - features out of the discriminator, shape (batch, filters)
    #    - moments (average over the batch), shape (filters,)
    #    - loss is a scalar: mean absolute difference between real and fake
    #      moments, averaged over the filters
    moments_real = tf.reduce_mean(features_real, axis=0)
    moments_fake = tf.reduce_mean(features_fake, axis=0)
    g_loss = tf.reduce_mean(tf.abs(moments_real - moments_fake))
    
    # ACCURACY
    pred_class = tf.cast(tf.argmax(class_logits_real, axis=1), tf.int32)
    eq = tf.to_float(tf.equal(tf.squeeze(y), pred_class))
    correct = tf.reduce_sum(eq)
    masked_correct = tf.reduce_sum(label_mask * eq)

    return d_loss, g_loss, correct, masked_correct, g_model
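
A small numpy sketch of the feature-matching loss described in the comments above: average each discriminator feature over the batch for real and for fake data, then take the mean absolute difference (shapes here are illustrative).

In [ ]:
# Feature-matching sketch: batch of 4 samples, 128 discriminator features each.
feat_real = np.random.normal(0.0, 1.0, size=(4, 128))
feat_fake = np.random.normal(0.5, 1.0, size=(4, 128))
moments_real = feat_real.mean(axis=0)   # shape (128,)
moments_fake = feat_fake.mean(axis=0)   # shape (128,)
print(np.abs(moments_real - moments_fake).mean())  # scalar feature-matching loss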
    

D & G Optimizer

In [18]:
def model_opt(d_loss, g_loss, learning_rate, beta1):
    """Optimize separately for discriminator and generator. Shrinkable learning
    rate. 
    
    Parameters
    ---------
    d_loss : 1D tensor
        Discriminator loss
    g_loss : scalar tensor
        Generator loss
    learning_rate : non-trainable variable scalar tensor
        learning rate for optimizer
    beta1 : float
        beta1 for AdamOptimizer
    
    Returns
    -------
    d_train_opt : Optimizer
        Discriminator optimizer training operation
    g_train_opt : Optimizer
        Generator optimizer training operation 
    shrink_lr : non-trainable variable scalar tensor
        Assign op that shrinks the learning rate by 10% each time it is run
    """
    t_vars = tf.trainable_variables()
    d_vars = [var for var in t_vars if var.name.startswith('discriminator')]
    g_vars = [var for var in t_vars if var.name.startswith('generator')]
    
    # Check that there are no straggling vars
    for t in t_vars:
        assert t in d_vars or t in g_vars
    
    # Separate Adam optimizers, each minimizing its own loss over its own variables
    d_train_opt = tf.train.AdamOptimizer(learning_rate, beta1=beta1).minimize(d_loss, var_list=d_vars)
    g_train_opt = tf.train.AdamOptimizer(learning_rate, beta1=beta1).minimize(g_loss, var_list=g_vars)
    
    # Decrease learning rate over time
    shrink_lr = tf.assign(learning_rate, learning_rate * 0.9)
    
    return d_train_opt, g_train_opt, shrink_lr
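
A quick check (a sketch) that the name-prefix filters above work: variables created under tf.variable_scope('generator') or tf.variable_scope('discriminator') get those prefixes in their names.

In [ ]:
# Scope-prefix sketch: a tiny layer inside a 'generator' scope.
tf.reset_default_graph()
with tf.variable_scope('generator'):
    _ = tf.layers.dense(tf.placeholder(tf.float32, (None, 8)), 4)
print([v.name for v in tf.trainable_variables()])
# e.g. ['generator/dense/kernel:0', 'generator/dense/bias:0']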

Set up GAN

In [19]:
class GAN:
    """Set up the GAN model.
    
    Parameters
    ---------
    real_shape : tuple
        Shape of a single real training example
    z_shape : int
        Length of the random noise vector used to seed fake images
    learning_rate : float
        Learning rate of optimizers
    num_classes : int
        Number of classes to predict [default: 10]
    alpha : float
        Multiplier for left half of leaky relu [default: 0.2]
    beta1 : float
        beta1 for AdamOptimizer [default: 0.5]
    
    Attributes
    ----------
    learning_rate : variable scalar tensor
        Learning rate for optimizer
    input_real : placeholder tensor
        Input for the real data
    input_z : placeholder tensor
        Input for the fake data
    y : placeholder 1D tensor
        Labels of digits for dataset
    label_mask : placeholder 1D tensor
        0, 1 masks, 1 means use the label
    drop_rate : placeholder scalar tensor
        Drop rate of layers
    d_loss : 1D tensor
        Discriminator loss for each observation
    g_loss : scalar tensor
        Generator feature-matching loss
    correct : float
        Number of correct predictions
    masked_correct : float
        Number of correct predictions taking into account label masks
    g_model : tensor
        Generator output layer
    d_opt : Optimizer
        Discriminator optimizer
    g_opt : Optimizer
        Generator optimizer
    shrink_lr : non-trainable variable scalar tensor
        Mutable learning rate
    """
    def __init__(self, real_shape, z_shape, learning_rate, num_classes=10, 
                 alpha=0.2, beta1=0.5):
        tf.reset_default_graph()
        
        self.learning_rate = tf.Variable(learning_rate, trainable=False)
        # Setup inputs
        self.input_real, self.input_z, self.y, self.label_mask = model_inputs(real_shape, z_shape)
        self.drop_rate = tf.placeholder_with_default(0.5, shape=(), name='drop_rate')
        
        # Calculate model losses 
        _ = model_loss(self.input_real, self.input_z, real_shape[2], self.y, 
                       num_classes, self.label_mask, alpha=alpha, 
                       drop_rate=self.drop_rate)
        self.d_loss, self.g_loss, self.correct, self.masked_correct, self.g_model = _
        
        # Setup optimizers
        self.d_opt, self.g_opt, self.shrink_lr = model_opt(self.d_loss, self.g_loss, self.learning_rate, beta1)
        

Training loop

In [20]:
def view_samples(epoch_samples, nrows, ncols, figsize=(5, 5)):
    """View generated samples during training"""
    fig, axes = plt.subplots(nrows, ncols, figsize=figsize, sharey=True, 
                            sharex=True)
    for ax, img in zip(axes.flatten(), epoch_samples):
        img = ((img - img.min()) * 255 / (img.max() - img.min())).astype(np.uint8)
        ax.axis('off')
        ax.set_adjustable('box-forced')
        ax.imshow(img)
    plt.subplots_adjust(wspace=0, hspace=0)
    
    return fig, axes
In [21]:
def train(gan_net, dataset, epochs, batch_size, z_shape, figsize=(5, 5)):
    """Training loop.
    
    Parameters
    ----------
    gan_net : GAN
        GAN network 
    dataset : Dataset
        Dataset 
    epochs : int
        Number of epochs to train for
    batch_size : int
        Batch size for training
    z_shape : int
        Length of random noise
    figsize : tuple
        [default: (5, 5)]
        
    
    Returns
    -------
    train_accs : list
        Training accuracies for each epoch
    test_accs : list
        Test accuracies for each epoch
    samples : list
        Generated samples from each epoch
    """
    saver = tf.train.Saver()
    
    # Generate a set of noise so that it can be used to visualize the training
    # progress over time.
    sample_z = np.random.normal(0, 1, size=(50, z_shape))
    
    samples, train_accs, test_accs = [], [], []
    steps = 0
    
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for e in range(epochs):
            print('Epoch:', e)
            t_epoch_start = time.time()
            
            # Train data
            num_examples = 0
            num_correct = 0
            for x, y, label_mask in dataset.batches(batch_size, which_set='train'):
                steps += 1
                num_examples += label_mask.sum()
                batch_z = np.random.normal(0, 1, size=(len(x), z_shape))

                t1 = time.time()
                _, _, correct = sess.run([gan_net.d_opt, 
                                          gan_net.g_opt, 
                                          gan_net.masked_correct], 
                                         feed_dict={gan_net.input_real: x, 
                                                    gan_net.input_z: batch_z, 
                                                    gan_net.y: y,
                                                    gan_net.label_mask: label_mask})
                t2 = time.time()
                num_correct += correct
            sess.run([gan_net.shrink_lr])
            
            # Training accuracy is calculated only on the examples with labels
            train_acc = num_correct / float(num_examples)
            print('\t\tClassifier train accuracy (subset w/labels):', train_acc)
            
            # Test set accuracy: all test examples have labels
            num_examples = 0
            num_correct = 0
            for x, y in dataset.batches(batch_size, which_set='test'):
                num_examples += x.shape[0]
                # Drop rate set to 0 for inference so all input features are
                # kept.
                # Note this does not calculate losses; it just feeds x and y 
                # through the discriminator to predict the class and counts 
                # the number correct over the total number of examples.
                # Running only gan_net.correct does not execute the optimizer 
                # ops, so no weights are updated during this evaluation.
                correct, = sess.run([gan_net.correct], 
                                    feed_dict={gan_net.input_real: x,
                                               gan_net.y: y,
                                               gan_net.drop_rate: 0.})
                num_correct += correct
            test_acc = num_correct / float(num_examples)
            print('\t\tClassifier test accuracy:', test_acc)
            print('\t\tStep time for 1 batch:', t2 - t1)
            t_epoch_end = time.time()
            print('\t\tEpoch time:', t_epoch_end - t_epoch_start)
            
            # Save accuracies
            train_accs.append(train_acc)
            test_accs.append(test_acc)
            
            # Visualize generated samples using the generator's current weights 
            gen_samples = sess.run(gan_net.g_model, 
                                   feed_dict={gan_net.input_z: sample_z})
            samples.append(gen_samples)
            _ = view_samples(samples[-1], 5, 10, figsize=figsize)
            plt.show()
        
        saver.save(sess, './checkpoints/generator.ckpt')
    
    return train_accs, test_accs, samples    

Main

In [22]:
real_shape = (32, 32, 3)
z_shape = 100
learning_rate = 0.0003
epochs = 25
batch_size = 128

gan_net = GAN(real_shape, z_shape, learning_rate)
dataset = Dataset(trainset, testset)
In [23]:
train_acc, test_acc, samples = train(gan_net, dataset, epochs, batch_size, z_shape, figsize=(10, 5))
Epoch: 0
		Classifier train accuracy (subset w/labels): 0.181
		Classifier test accuracy: 0.18000921942224954
		Step time for 1 batch: 0.20109796524047852
		Epoch time: 340.14670300483704
Epoch: 1
		Classifier train accuracy (subset w/labels): 0.314
		Classifier test accuracy: 0.34964658881376764
		Step time for 1 batch: 0.18351078033447266
		Epoch time: 318.1429190635681
Epoch: 2
		Classifier train accuracy (subset w/labels): 0.544
		Classifier test accuracy: 0.5192839582052858
		Step time for 1 batch: 0.18638205528259277
		Epoch time: 309.2164921760559