CS109B Data Science 2: Advanced Topics in Data Science

Lab 10 - Autoencoders and Variational Autoencoders

Harvard University
Spring 2019
Instructors: Mark Glickman and Pavlos Protopapas


In [1]:
# !pip install imgaug
In [2]:
## load the libraries 
import sys
import warnings
import os
import glob
warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd
import cv2
from sklearn.model_selection import train_test_split

from keras.layers import *
from keras.callbacks import EarlyStopping
from keras.utils import to_categorical
from keras.models import Model, Sequential
from keras.metrics import *
from keras.optimizers import Adam, RMSprop
from scipy.stats import norm
from keras.preprocessing import image

from keras import backend as K

from imgaug import augmenters
import matplotlib.pyplot as plt
plt.gray()
Using TensorFlow backend.

Part 1: Data

Reading data

Download the data given at the following link: . Use pandas and numpy to read in the data as a matrix

In [3]:
### read dataset 
train = pd.read_csv("data/fashion-mnist_train.csv")
train_x = train[list(train.columns)[1:]].values
train_x, val_x = train_test_split(train_x, test_size=0.15)

## create train and validation datasets
train_x, val_x = train_test_split(train_x, test_size=0.15)
In [4]:
## normalize and reshape
train_x = train_x/255.
val_x = val_x/255.

train_x = train_x.reshape(-1, 28, 28, 1)
val_x = val_x.reshape(-1, 28, 28, 1)
In [5]:
train_x.shape
Out[5]:
(43350, 28, 28, 1)

Visualizing Samples

Visualize 10 images from dataset

In [6]:
f, ax = plt.subplots(1,5)
f.set_size_inches(80, 40)
for i in range(5,10):
    ax[i-5].imshow(train_x[i, :, :, 0].reshape(28, 28))

Part 2: Denoise Images using AEs

Understanding AEs

Autoencoders

Add Noise to Images

Check out imgaug docs for more info and other ways to add noise.

In [7]:
# Lets add sample noise - Salt and Pepper
noise = augmenters.SaltAndPepper(0.1)
seq_object = augmenters.Sequential([noise])

train_x_n = seq_object.augment_images(train_x * 255) / 255
val_x_n = seq_object.augment_images(val_x * 255) / 255
In [8]:
f, ax = plt.subplots(1,5)
f.set_size_inches(80, 40)
for i in range(5,10):
    ax[i-5].imshow(train_x_n[i, :, :, 0].reshape(28, 28))

Setup Encoder Neural Network

Try different number of hidden layers, nodes?

In [9]:
# input layer
input_layer = Input(shape=(28, 28, 1))

# encoding architecture
encoded_layer1 = Conv2D(64, (3, 3), activation='relu', padding='same')(input_layer)
encoded_layer1 = MaxPool2D( (2, 2), padding='same')(encoded_layer1)
encoded_layer2 = Conv2D(32, (3, 3), activation='relu', padding='same')(encoded_layer1)
encoded_layer2 = MaxPool2D( (2, 2), padding='same')(encoded_layer2)
encoded_layer3 = Conv2D(16, (3, 3), activation='relu', padding='same')(encoded_layer2)
latent_view    = MaxPool2D( (2, 2), padding='same')(encoded_layer3)

Setup Decoder Neural Network

Try different number of hidden layers, nodes?

In [10]:
# decoding architecture
decoded_layer1 = Conv2D(16, (3, 3), activation='relu', padding='same')(latent_view)
decoded_layer1 = UpSampling2D((2, 2))(decoded_layer1)
decoded_layer2 = Conv2D(32, (3, 3), activation='relu', padding='same')(decoded_layer1)
decoded_layer2 = UpSampling2D((2, 2))(decoded_layer2)
decoded_layer3 = Conv2D(64, (3, 3), activation='relu')(decoded_layer2)
decoded_layer3 = UpSampling2D((2, 2))(decoded_layer3)
output_layer   = Conv2D(1, (3, 3), padding='same')(decoded_layer3)

Train AE

In [11]:
# compile the model
model = Model(input_layer, output_layer)
model.compile(optimizer='adam', loss='mse')
In [12]:
model.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         (None, 28, 28, 1)         0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 28, 28, 64)        640       
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 14, 14, 64)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 14, 14, 32)        18464     
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 7, 7, 32)          0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 7, 7, 16)          4624      
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 4, 4, 16)          0         
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 4, 4, 16)          2320      
_________________________________________________________________
up_sampling2d_1 (UpSampling2 (None, 8, 8, 16)          0         
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 8, 8, 32)          4640      
_________________________________________________________________
up_sampling2d_2 (UpSampling2 (None, 16, 16, 32)        0         
_________________________________________________________________
conv2d_6 (Conv2D)            (None, 14, 14, 64)        18496     
_________________________________________________________________
up_sampling2d_3 (UpSampling2 (None, 28, 28, 64)        0         
_________________________________________________________________
conv2d_7 (Conv2D)            (None, 28, 28, 1)         577       
=================================================================
Total params: 49,761
Trainable params: 49,761
Non-trainable params: 0
_________________________________________________________________
In [13]:
early_stopping = EarlyStopping(monitor='val_loss', min_delta=0, patience=10, verbose=5, mode='auto')
history = model.fit(train_x_n, train_x, epochs=20, batch_size=2048, validation_data=(val_x_n, val_x), callbacks=[early_stopping])
Train on 43350 samples, validate on 7650 samples
Epoch 1/20
43350/43350 [==============================] - 7s 154us/step - loss: 0.1083 - val_loss: 0.0654
Epoch 2/20
43350/43350 [==============================] - 2s 47us/step - loss: 0.0480 - val_loss: 0.0390
Epoch 3/20
43350/43350 [==============================] - 2s 47us/step - loss: 0.0356 - val_loss: 0.0326
Epoch 4/20
43350/43350 [==============================] - 2s 47us/step - loss: 0.0303 - val_loss: 0.0281
Epoch 5/20
43350/43350 [==============================] - 2s 48us/step - loss: 0.0264 - val_loss: 0.0248
Epoch 6/20
43350/43350 [==============================] - 2s 46us/step - loss: 0.0238 - val_loss: 0.0230
Epoch 7/20
43350/43350 [==============================] - 2s 47us/step - loss: 0.0223 - val_loss: 0.0217
Epoch 8/20
43350/43350 [==============================] - 2s 47us/step - loss: 0.0213 - val_loss: 0.0208
Epoch 9/20
43350/43350 [==============================] - 2s 47us/step - loss: 0.0205 - val_loss: 0.0201
Epoch 10/20
43350/43350 [==============================] - 2s 47us/step - loss: 0.0199 - val_loss: 0.0197
Epoch 11/20
43350/43350 [==============================] - 2s 47us/step - loss: 0.0193 - val_loss: 0.0191
Epoch 12/20
43350/43350 [==============================] - 2s 47us/step - loss: 0.0189 - val_loss: 0.0187
Epoch 13/20
43350/43350 [==============================] - 2s 47us/step - loss: 0.0184 - val_loss: 0.0182
Epoch 14/20
43350/43350 [==============================] - 2s 48us/step - loss: 0.0182 - val_loss: 0.0184
Epoch 15/20
43350/43350 [==============================] - 2s 48us/step - loss: 0.0179 - val_loss: 0.0176
Epoch 16/20
43350/43350 [==============================] - 2s 48us/step - loss: 0.0175 - val_loss: 0.0174
Epoch 17/20
43350/43350 [==============================] - 2s 47us/step - loss: 0.0174 - val_loss: 0.0172
Epoch 18/20
43350/43350 [==============================] - 2s 48us/step - loss: 0.0171 - val_loss: 0.0170
Epoch 19/20
43350/43350 [==============================] - 2s 48us/step - loss: 0.0169 - val_loss: 0.0167
Epoch 20/20
43350/43350 [==============================] - 2s 47us/step - loss: 0.0167 - val_loss: 0.0166

Visualize Intermediate Layers of AE

Visualize intermediate layers

In [14]:
# compile the model
model_2 = Model(input_layer, latent_view)
model_2.compile(optimizer='adam', loss='mse')
In [15]:
n = np.random.randint(0,len(val_x)-5)
f, ax = plt.subplots(1,5)
f.set_size_inches(80, 40)
for i,a in enumerate(range(n,n+5)):
    ax[i].imshow(val_x_n[a, :, :, 0].reshape(28, 28))
plt.show()
In [16]:
preds = model_2.predict(val_x_n[n:n+5])
preds.shape
Out[16]:
(5, 4, 4, 16)
In [17]:
f, ax = plt.subplots(16,5)
ax = ax.ravel()
f.set_size_inches(20, 40)
for j in range(16):
    for i,a in enumerate(range(n,n+5)):
        ax[j*5 + i].imshow(preds[i, :, :, j])
plt.show()

Visualize Samples reconstructed by AE

In [18]:
n = np.random.randint(0,len(val_x)-5)
In [19]:
f, ax = plt.subplots(1,5)
f.set_size_inches(80, 40)
for i,a in enumerate(range(n,n+5)):
    ax[i].imshow(val_x[a, :, :, 0].reshape(28, 28))
In [20]:
f, ax = plt.subplots(1,5)
f.set_size_inches(80, 40)
for i,a in enumerate(range(n,n+5)):
    ax[i].imshow(val_x_n[a, :, :, 0].reshape(28, 28))
In [21]:
preds = model.predict(val_x_n[n:n+5])
f, ax = plt.subplots(1,5)
f.set_size_inches(80, 40)
for i,a in enumerate(range(n,n+5)):
    ax[i].imshow(preds[i].reshape(28, 28))
plt.show()

Part 3: Exercise: Denoising noisy documents

In [22]:
TRAIN_IMAGES = glob.glob('data/train/*.png')
CLEAN_IMAGES = glob.glob('data/train_cleaned/*.png')
TEST_IMAGES = glob.glob('data/test/*.png')
In [23]:
plt.figure(figsize=(20,8))
img = cv2.imread('data/train/101.png', 0)
plt.imshow(img, cmap='gray')
print(img.shape)
(420, 540)
In [24]:
def load_image(path):
    image_list = np.zeros((len(path), 258, 540, 1))
    for i, fig in enumerate(path):
        img = image.load_img(fig, grayscale=True, target_size=(258, 540))
        x = image.img_to_array(img).astype('float32')
        x = x / 255.0
        image_list[i] = x
    
    return image_list

x_train = load_image(TRAIN_IMAGES)
y_train = load_image(CLEAN_IMAGES)
x_test = load_image(TEST_IMAGES)

print(x_train.shape, x_test.shape)
(144, 258, 540, 1) (72, 258, 540, 1)
In [25]:
x_train.shape
Out[25]:
(144, 258, 540, 1)
In [26]:
x_train, x_val, y_train, y_val = train_test_split(x_train, y_train, test_size=0.15)
print(x_train.shape, x_val.shape)
(122, 258, 540, 1) (22, 258, 540, 1)
In [27]:
plt.imshow(x_train[0, :, :, 0])
Out[27]:
In [28]:
plt.imshow(y_train[0, :, :, 0])
Out[28]:
In [29]:
input_layer = Input(shape=(258, 540, 1))
        
# encoder
encoder = Conv2D(64, (3, 3), activation='relu', padding='same')(input_layer)
encoder = MaxPooling2D((2, 2), padding='same')(encoder)

# decoder
decoder = Conv2D(64, (3, 3), activation='relu', padding='same')(encoder)
decoder = UpSampling2D((2, 2))(decoder)
output_layer = Conv2D(1, (3, 3), activation='sigmoid', padding='same')(decoder)

ae = Model(input_layer, output_layer)
In [30]:
ae.compile(loss='mse', optimizer=Adam(lr=0.001))
ae.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         (None, 258, 540, 1)       0         
_________________________________________________________________
conv2d_8 (Conv2D)            (None, 258, 540, 64)      640       
_________________________________________________________________
max_pooling2d_4 (MaxPooling2 (None, 129, 270, 64)      0         
_________________________________________________________________
conv2d_9 (Conv2D)            (None, 129, 270, 64)      36928     
_________________________________________________________________
up_sampling2d_4 (UpSampling2 (None, 258, 540, 64)      0         
_________________________________________________________________
conv2d_10 (Conv2D)           (None, 258, 540, 1)       577       
=================================================================
Total params: 38,145
Trainable params: 38,145
Non-trainable params: 0
_________________________________________________________________
In [31]:
batch_size = 16
epochs = 200

early_stopping = EarlyStopping(monitor='val_loss',min_delta=0,patience=5,verbose=1, mode='auto')
history = ae.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_data=(x_val, y_val), callbacks=[early_stopping])
Train on 122 samples, validate on 22 samples
Epoch 1/200
122/122 [==============================] - 2s 17ms/step - loss: 0.1448 - val_loss: 0.0700
Epoch 2/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0729 - val_loss: 0.0763
Epoch 3/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0772 - val_loss: 0.0742
Epoch 4/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0700 - val_loss: 0.0643
Epoch 5/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0624 - val_loss: 0.0576
Epoch 6/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0546 - val_loss: 0.0491
Epoch 7/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0453 - val_loss: 0.0395
Epoch 8/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0374 - val_loss: 0.0331
Epoch 9/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0323 - val_loss: 0.0290
Epoch 10/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0288 - val_loss: 0.0260
Epoch 11/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0262 - val_loss: 0.0238
Epoch 12/200
122/122 [==============================] - 1s 9ms/step - loss: 0.0241 - val_loss: 0.0218
Epoch 13/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0222 - val_loss: 0.0201
Epoch 14/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0205 - val_loss: 0.0184
Epoch 15/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0189 - val_loss: 0.0171
Epoch 16/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0178 - val_loss: 0.0158
Epoch 17/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0165 - val_loss: 0.0148
Epoch 18/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0156 - val_loss: 0.0140
Epoch 19/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0146 - val_loss: 0.0134
Epoch 20/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0139 - val_loss: 0.0125
Epoch 21/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0131 - val_loss: 0.0116
Epoch 22/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0124 - val_loss: 0.0111
Epoch 23/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0119 - val_loss: 0.0106
Epoch 24/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0114 - val_loss: 0.0101
Epoch 25/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0108 - val_loss: 0.0097
Epoch 26/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0104 - val_loss: 0.0094
Epoch 27/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0101 - val_loss: 0.0091
Epoch 28/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0099 - val_loss: 0.0086
Epoch 29/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0095 - val_loss: 0.0084
Epoch 30/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0092 - val_loss: 0.0081
Epoch 31/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0089 - val_loss: 0.0079
Epoch 32/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0087 - val_loss: 0.0077
Epoch 33/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0085 - val_loss: 0.0075
Epoch 34/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0082 - val_loss: 0.0073
Epoch 35/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0080 - val_loss: 0.0072
Epoch 36/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0079 - val_loss: 0.0071
Epoch 37/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0079 - val_loss: 0.0070
Epoch 38/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0077 - val_loss: 0.0070
Epoch 39/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0076 - val_loss: 0.0066
Epoch 40/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0073 - val_loss: 0.0064
Epoch 41/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0071 - val_loss: 0.0064
Epoch 42/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0070 - val_loss: 0.0064
Epoch 43/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0070 - val_loss: 0.0062
Epoch 44/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0068 - val_loss: 0.0060
Epoch 45/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0068 - val_loss: 0.0060
Epoch 46/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0067 - val_loss: 0.0059
Epoch 47/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0065 - val_loss: 0.0057
Epoch 48/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0064 - val_loss: 0.0057
Epoch 49/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0063 - val_loss: 0.0056
Epoch 50/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0063 - val_loss: 0.0055
Epoch 51/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0064 - val_loss: 0.0059
Epoch 52/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0062 - val_loss: 0.0055
Epoch 53/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0060 - val_loss: 0.0053
Epoch 54/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0059 - val_loss: 0.0052
Epoch 55/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0058 - val_loss: 0.0052
Epoch 56/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0057 - val_loss: 0.0051
Epoch 57/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0057 - val_loss: 0.0050
Epoch 58/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0056 - val_loss: 0.0050
Epoch 59/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0056 - val_loss: 0.0050
Epoch 60/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0055 - val_loss: 0.0050
Epoch 61/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0054 - val_loss: 0.0048
Epoch 62/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0053 - val_loss: 0.0047
Epoch 63/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0053 - val_loss: 0.0048
Epoch 64/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0052 - val_loss: 0.0046
Epoch 65/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0051 - val_loss: 0.0046
Epoch 66/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0051 - val_loss: 0.0045
Epoch 67/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0050 - val_loss: 0.0045
Epoch 68/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0050 - val_loss: 0.0044
Epoch 69/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0049 - val_loss: 0.0044
Epoch 70/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0049 - val_loss: 0.0043
Epoch 71/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0048 - val_loss: 0.0044
Epoch 72/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0048 - val_loss: 0.0044
Epoch 73/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0047 - val_loss: 0.0042
Epoch 74/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0046 - val_loss: 0.0041
Epoch 75/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0046 - val_loss: 0.0041
Epoch 76/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0045 - val_loss: 0.0041
Epoch 77/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0045 - val_loss: 0.0040
Epoch 78/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0045 - val_loss: 0.0041
Epoch 79/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0045 - val_loss: 0.0039
Epoch 80/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0044 - val_loss: 0.0041
Epoch 81/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0044 - val_loss: 0.0039
Epoch 82/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0043 - val_loss: 0.0038
Epoch 83/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0042 - val_loss: 0.0038
Epoch 84/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0042 - val_loss: 0.0038
Epoch 85/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0042 - val_loss: 0.0037
Epoch 86/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0041 - val_loss: 0.0037
Epoch 87/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0041 - val_loss: 0.0037
Epoch 88/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0041 - val_loss: 0.0037
Epoch 89/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0040 - val_loss: 0.0036
Epoch 90/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0040 - val_loss: 0.0036
Epoch 91/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0039 - val_loss: 0.0035
Epoch 92/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0039 - val_loss: 0.0035
Epoch 93/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0038 - val_loss: 0.0035
Epoch 94/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0038 - val_loss: 0.0034
Epoch 95/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0038 - val_loss: 0.0034
Epoch 96/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0037 - val_loss: 0.0034
Epoch 97/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0037 - val_loss: 0.0033
Epoch 98/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0037 - val_loss: 0.0033
Epoch 99/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0036 - val_loss: 0.0033
Epoch 100/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0036 - val_loss: 0.0033
Epoch 101/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0036 - val_loss: 0.0033
Epoch 102/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0036 - val_loss: 0.0032
Epoch 103/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0035 - val_loss: 0.0032
Epoch 104/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0035 - val_loss: 0.0031
Epoch 105/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0035 - val_loss: 0.0033
Epoch 106/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0035 - val_loss: 0.0031
Epoch 107/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0034 - val_loss: 0.0031
Epoch 108/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0034 - val_loss: 0.0032
Epoch 109/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0034 - val_loss: 0.0030
Epoch 110/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0033 - val_loss: 0.0031
Epoch 111/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0033 - val_loss: 0.0030
Epoch 112/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0033 - val_loss: 0.0030
Epoch 113/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0033 - val_loss: 0.0031
Epoch 114/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0033 - val_loss: 0.0029
Epoch 115/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0032 - val_loss: 0.0029
Epoch 116/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0032 - val_loss: 0.0029
Epoch 117/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0032 - val_loss: 0.0029
Epoch 118/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0032 - val_loss: 0.0029
Epoch 119/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0032 - val_loss: 0.0029
Epoch 120/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0032 - val_loss: 0.0029
Epoch 121/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0032 - val_loss: 0.0029
Epoch 122/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0031 - val_loss: 0.0028
Epoch 123/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0031 - val_loss: 0.0028
Epoch 124/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0031 - val_loss: 0.0028
Epoch 125/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0031 - val_loss: 0.0029
Epoch 126/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0030 - val_loss: 0.0028
Epoch 127/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0030 - val_loss: 0.0028
Epoch 128/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0030 - val_loss: 0.0027
Epoch 129/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0030 - val_loss: 0.0028
Epoch 130/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0030 - val_loss: 0.0027
Epoch 131/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0030 - val_loss: 0.0028
Epoch 132/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0030 - val_loss: 0.0029
Epoch 133/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0030 - val_loss: 0.0027
Epoch 134/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0029 - val_loss: 0.0028
Epoch 135/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0030 - val_loss: 0.0027
Epoch 136/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0029 - val_loss: 0.0026
Epoch 137/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0029 - val_loss: 0.0026
Epoch 138/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0029 - val_loss: 0.0026
Epoch 139/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0028 - val_loss: 0.0026
Epoch 140/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0028 - val_loss: 0.0026
Epoch 141/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0028 - val_loss: 0.0026
Epoch 142/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0028 - val_loss: 0.0026
Epoch 143/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0028 - val_loss: 0.0026
Epoch 144/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0028 - val_loss: 0.0025
Epoch 145/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0028 - val_loss: 0.0026
Epoch 146/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0028 - val_loss: 0.0025
Epoch 147/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0028 - val_loss: 0.0026
Epoch 148/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0028 - val_loss: 0.0025
Epoch 149/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0027 - val_loss: 0.0025
Epoch 150/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0027 - val_loss: 0.0025
Epoch 151/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0027 - val_loss: 0.0025
Epoch 152/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0027 - val_loss: 0.0025
Epoch 153/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0027 - val_loss: 0.0025
Epoch 154/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0027 - val_loss: 0.0025
Epoch 155/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0027 - val_loss: 0.0024
Epoch 156/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0027 - val_loss: 0.0025
Epoch 157/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0027 - val_loss: 0.0024
Epoch 158/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0026 - val_loss: 0.0025
Epoch 159/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0026 - val_loss: 0.0024
Epoch 160/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0027 - val_loss: 0.0024
Epoch 161/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0026 - val_loss: 0.0025
Epoch 162/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0026 - val_loss: 0.0024
Epoch 163/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0026 - val_loss: 0.0024
Epoch 164/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0025 - val_loss: 0.0023
Epoch 165/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0025 - val_loss: 0.0023
Epoch 166/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0025 - val_loss: 0.0023
Epoch 167/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0025 - val_loss: 0.0023
Epoch 168/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0025 - val_loss: 0.0023
Epoch 169/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0025 - val_loss: 0.0023
Epoch 170/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0025 - val_loss: 0.0023
Epoch 171/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0025 - val_loss: 0.0024
Epoch 172/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0025 - val_loss: 0.0023
Epoch 173/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0025 - val_loss: 0.0023
Epoch 174/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0025 - val_loss: 0.0023
Epoch 175/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0025 - val_loss: 0.0023
Epoch 176/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0024 - val_loss: 0.0022
Epoch 177/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0024 - val_loss: 0.0022
Epoch 178/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0025 - val_loss: 0.0024
Epoch 179/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0024 - val_loss: 0.0022
Epoch 180/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0024 - val_loss: 0.0022
Epoch 181/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0024 - val_loss: 0.0022
Epoch 182/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0024 - val_loss: 0.0022
Epoch 183/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0024 - val_loss: 0.0022
Epoch 184/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0024 - val_loss: 0.0022
Epoch 185/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0024 - val_loss: 0.0022
Epoch 186/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0024 - val_loss: 0.0022
Epoch 187/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0024 - val_loss: 0.0022
Epoch 188/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0024 - val_loss: 0.0022
Epoch 189/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0023 - val_loss: 0.0022
Epoch 190/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0024 - val_loss: 0.0023
Epoch 191/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0025 - val_loss: 0.0022
Epoch 192/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0025 - val_loss: 0.0021
Epoch 193/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0024 - val_loss: 0.0022
Epoch 194/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0023 - val_loss: 0.0022
Epoch 195/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0023 - val_loss: 0.0021
Epoch 196/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0023 - val_loss: 0.0023
Epoch 197/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0024 - val_loss: 0.0023
Epoch 198/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0024 - val_loss: 0.0021
Epoch 199/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0023 - val_loss: 0.0021
Epoch 200/200
122/122 [==============================] - 1s 8ms/step - loss: 0.0023 - val_loss: 0.0021
Epoch 00200: early stopping
In [32]:
preds = ae.predict(x_test)
In [33]:
n = 25
preds_0 = preds[n] * 255.0
preds_0 = preds_0.reshape(258, 540)
x_test_0 = x_test[n] * 255.0
x_test_0 = x_test_0.reshape(258, 540)
plt.imshow(x_test_0, cmap='gray')
Out[33]:
In [34]:
plt.imshow(preds_0, cmap='gray')
Out[34]:

Part 4: Generating New Fashion using VAEs

Understanding VAEs

"VAE"

Reset data

In [35]:
### read dataset 
train = pd.read_csv("data/fashion-mnist_train.csv")
train_x = train[list(train.columns)[1:]].values
train_x, val_x = train_test_split(train_x, test_size=0.2)

## create train and validation datasets
train_x, val_x = train_test_split(train_x, test_size=0.2)
In [36]:
## normalize and reshape
train_x = train_x/255.
val_x = val_x/255.

train_x = train_x.reshape(-1, 28, 28, 1)
val_x = val_x.reshape(-1, 28, 28, 1)

Setup Encoder Neural Network

Try different number of hidden layers, nodes?

In [37]:
import keras.backend as K
In [38]:
batch_size = 16
latent_dim = 2  # Number of latent dimension parameters

input_img = Input(shape=(784,), name="input")
x = Dense(512, activation='relu', name="intermediate_encoder")(input_img)
x = Dense(2, activation='relu', name="latent_encoder")(x)

z_mu = Dense(latent_dim)(x)
z_log_sigma = Dense(latent_dim)(x)
In [39]:
# sampling function
def sampling(args):
    z_mu, z_log_sigma = args
    epsilon = K.random_normal(shape=(K.shape(z_mu)[0], latent_dim),
                              mean=0., stddev=1.)
    return z_mu + K.exp(z_log_sigma) * epsilon

# sample vector from the latent distribution
z = Lambda(sampling)([z_mu, z_log_sigma])
In [40]:
# decoder takes the latent distribution sample as input
decoder_input = Input((2,), name="input_decoder")

x = Dense(512, activation='relu', name="intermediate_decoder", input_shape=(2,))(decoder_input)

# Expand to 784 total pixels
x = Dense(784, activation='sigmoid', name="original_decoder")(x)

# decoder model statement
decoder = Model(decoder_input, x)

# apply the decoder to the sample from the latent distribution
z_decoded = decoder(z)
In [41]:
decoder.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_decoder (InputLayer)   (None, 2)                 0         
_________________________________________________________________
intermediate_decoder (Dense) (None, 512)               1536      
_________________________________________________________________
original_decoder (Dense)     (None, 784)               402192    
=================================================================
Total params: 403,728
Trainable params: 403,728
Non-trainable params: 0
_________________________________________________________________
In [42]:
# construct a custom layer to calculate the loss
class CustomVariationalLayer(Layer):

    def vae_loss(self, x, z_decoded):
        x = K.flatten(x)
        z_decoded = K.flatten(z_decoded)
        # Reconstruction loss
        xent_loss = binary_crossentropy(x, z_decoded)
        return xent_loss

    # adds the custom loss to the class
    def call(self, inputs):
        x = inputs[0]
        z_decoded = inputs[1]
        loss = self.vae_loss(x, z_decoded)
        self.add_loss(loss, inputs=inputs)
        return x

# apply the custom loss to the input images and the decoded latent distribution sample
y = CustomVariationalLayer()([input_img, z_decoded])
In [43]:
z_decoded
Out[43]:
In [44]:
# VAE model statement
vae = Model(input_img, y)
vae.compile(optimizer='rmsprop', loss=None)
In [45]:
vae.summary()
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input (InputLayer)              (None, 784)          0                                            
__________________________________________________________________________________________________
intermediate_encoder (Dense)    (None, 512)          401920      input[0][0]                      
__________________________________________________________________________________________________
latent_encoder (Dense)          (None, 2)            1026        intermediate_encoder[0][0]       
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 2)            6           latent_encoder[0][0]             
__________________________________________________________________________________________________
dense_2 (Dense)                 (None, 2)            6           latent_encoder[0][0]             
__________________________________________________________________________________________________
lambda_1 (Lambda)               (None, 2)            0           dense_1[0][0]                    
                                                                 dense_2[0][0]                    
__________________________________________________________________________________________________
model_4 (Model)                 (None, 784)          403728      lambda_1[0][0]                   
__________________________________________________________________________________________________
custom_variational_layer_1 (Cus [(None, 784), (None, 0           input[0][0]                      
                                                                 model_4[1][0]                    
==================================================================================================
Total params: 806,686
Trainable params: 806,686
Non-trainable params: 0
__________________________________________________________________________________________________
In [46]:
train_x.shape
Out[46]:
(38400, 28, 28, 1)
In [47]:
train_x = train_x.reshape(-1, 784)
val_x = val_x.reshape(-1, 784)
In [48]:
vae.fit(x=train_x, y=None,
        shuffle=True,
        epochs=20,
        batch_size=batch_size,
        validation_data=(val_x, None))
Train on 38400 samples, validate on 9600 samples
Epoch 1/20
38400/38400 [==============================] - 10s 258us/step - loss: 0.3728 - val_loss: 0.3446
Epoch 2/20
38400/38400 [==============================] - 9s 244us/step - loss: 0.3423 - val_loss: 0.3378
Epoch 3/20
38400/38400 [==============================] - 9s 245us/step - loss: 0.3369 - val_loss: 0.3380
Epoch 4/20
38400/38400 [==============================] - 9s 244us/step - loss: 0.3345 - val_loss: 0.3418
Epoch 5/20
38400/38400 [==============================] - 9s 243us/step - loss: 0.3329 - val_loss: 0.3353
Epoch 6/20
38400/38400 [==============================] - 9s 244us/step - loss: 0.3318 - val_loss: 0.3318
Epoch 7/20
38400/38400 [==============================] - 9s 242us/step - loss: 0.3309 - val_loss: 0.3307
Epoch 8/20
38400/38400 [==============================] - 9s 245us/step - loss: 0.3302 - val_loss: 0.3301
Epoch 9/20
38400/38400 [==============================] - 9s 241us/step - loss: 0.3297 - val_loss: 0.3297
Epoch 10/20
38400/38400 [==============================] - 9s 241us/step - loss: 0.3293 - val_loss: 0.3298
Epoch 11/20
38400/38400 [==============================] - 9s 242us/step - loss: 0.3289 - val_loss: 0.3312
Epoch 12/20
38400/38400 [==============================] - 9s 246us/step - loss: 0.3285 - val_loss: 0.3290
Epoch 13/20
38400/38400 [==============================] - 9s 245us/step - loss: 0.3283 - val_loss: 0.3284
Epoch 14/20
38400/38400 [==============================] - 9s 244us/step - loss: 0.3279 - val_loss: 0.3311
Epoch 15/20
38400/38400 [==============================] - 9s 245us/step - loss: 0.3277 - val_loss: 0.3298
Epoch 16/20
38400/38400 [==============================] - 9s 245us/step - loss: 0.3276 - val_loss: 0.3276
Epoch 17/20
38400/38400 [==============================] - 9s 243us/step - loss: 0.3274 - val_loss: 0.3285
Epoch 18/20
38400/38400 [==============================] - 9s 245us/step - loss: 0.3273 - val_loss: 0.3287
Epoch 19/20
38400/38400 [==============================] - 9s 245us/step - loss: 0.3271 - val_loss: 0.3292
Epoch 20/20
38400/38400 [==============================] - 9s 244us/step - loss: 0.3269 - val_loss: 0.3278
Out[48]:
In [49]:
# Display a 2D manifold of the samples
n = 20  # figure with 20x20 samples
digit_size = 28
figure = np.zeros((digit_size * n, digit_size * n))

# Construct grid of latent variable values - can change values here to generate different things
grid_x = norm.ppf(np.linspace(0.05, 0.95, n))
grid_y = norm.ppf(np.linspace(0.05, 0.95, n))

# decode for each square in the grid
for i, yi in enumerate(grid_x):
    for j, xi in enumerate(grid_y):
        z_sample = np.array([[xi, yi]])
        z_sample = np.tile(z_sample, batch_size).reshape(batch_size, 2)
        
        x_decoded = decoder.predict(z_sample, batch_size=batch_size)
        
        digit = x_decoded[0].reshape(digit_size, digit_size)
        
        figure[i * digit_size: (i + 1) * digit_size,
               j * digit_size: (j + 1) * digit_size] = digit

plt.figure(figsize=(20, 20))
plt.imshow(figure)
plt.show()  
In [50]:
### read dataset 
train = pd.read_csv("data/fashion-mnist_train.csv")
train_x = train[list(train.columns)[1:]].values
train_y = train[list(train.columns)[0]].values

train_x = train_x/255.
# train_x = train_x.reshape(-1, 28, 28, 1)

# Translate into the latent space
encoder = Model(input_img, z_mu)
x_valid_noTest_encoded = encoder.predict(train_x, batch_size=batch_size)
plt.figure(figsize=(10, 10))
plt.scatter(x_valid_noTest_encoded[:, 0], x_valid_noTest_encoded[:, 1], c=train_y, cmap='brg')
plt.colorbar()
plt.show()

Part 5: Exercise: Generating New Fashion using VAEs: Adding CNNs and KL Divergence Loss

In [51]:
batch_size = 16
latent_dim = 2  # Number of latent dimension parameters

# Encoder architecture: Input -> Conv2D*4 -> Flatten -> Dense
input_img = Input(shape=(28, 28, 1))

x = Conv2D(32, 3,
                  padding='same', 
                  activation='relu')(input_img)
x = Conv2D(64, 3,
                  padding='same', 
                  activation='relu',
                  strides=(2, 2))(x)
x = Conv2D(64, 3,
                  padding='same', 
                  activation='relu')(x)
x = Conv2D(64, 3,
                  padding='same', 
                  activation='relu')(x)

# need to know the shape of the network here for the decoder
shape_before_flattening = K.int_shape(x)

x = Flatten()(x)
x = Dense(32, activation='relu')(x)

# Two outputs, latent mean and (log)variance
z_mu = Dense(latent_dim)(x)
z_log_sigma = Dense(latent_dim)(x)

Set up sampling function

In [52]:
# sampling function
def sampling(args):
    z_mu, z_log_sigma = args
    epsilon = K.random_normal(shape=(K.shape(z_mu)[0], latent_dim),
                              mean=0., stddev=1.)
    return z_mu + K.exp(z_log_sigma) * epsilon

# sample vector from the latent distribution
z = Lambda(sampling)([z_mu, z_log_sigma])

Setup Decoder Neural Network

Try different number of hidden layers, nodes?

In [53]:
# decoder takes the latent distribution sample as input
decoder_input = Input(K.int_shape(z)[1:])

# Expand to 784 total pixels
x = Dense(np.prod(shape_before_flattening[1:]),
                 activation='relu')(decoder_input)

# reshape
x = Reshape(shape_before_flattening[1:])(x)

# use Conv2DTranspose to reverse the conv layers from the encoder
x = Conv2DTranspose(32, 3,
                           padding='same', 
                           activation='relu',
                           strides=(2, 2))(x)
x = Conv2D(1, 3,
                  padding='same', 
                  activation='sigmoid')(x)

# decoder model statement
decoder = Model(decoder_input, x)

# apply the decoder to the sample from the latent distribution
z_decoded = decoder(z)

Set up loss functions

In [54]:
# construct a custom layer to calculate the loss
class CustomVariationalLayer(Layer):

    def vae_loss(self, x, z_decoded):
        x = K.flatten(x)
        z_decoded = K.flatten(z_decoded)
        # Reconstruction loss
        xent_loss = binary_crossentropy(x, z_decoded)
        # KL divergence
        kl_loss = -5e-4 * K.mean(1 + z_log_sigma - K.square(z_mu) - K.exp(z_log_sigma), axis=-1)
        return K.mean(xent_loss + kl_loss)

    # adds the custom loss to the class
    def call(self, inputs):
        x = inputs[0]
        z_decoded = inputs[1]
        loss = self.vae_loss(x, z_decoded)
        self.add_loss(loss, inputs=inputs)
        return x

# apply the custom loss to the input images and the decoded latent distribution sample
y = CustomVariationalLayer()([input_img, z_decoded])

Train VAE

In [55]:
# VAE model statement
vae = Model(input_img, y)
vae.compile(optimizer='rmsprop', loss=None)
In [56]:
vae.summary()
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_3 (InputLayer)            (None, 28, 28, 1)    0                                            
__________________________________________________________________________________________________
conv2d_11 (Conv2D)              (None, 28, 28, 32)   320         input_3[0][0]                    
__________________________________________________________________________________________________
conv2d_12 (Conv2D)              (None, 14, 14, 64)   18496       conv2d_11[0][0]                  
__________________________________________________________________________________________________
conv2d_13 (Conv2D)              (None, 14, 14, 64)   36928       conv2d_12[0][0]                  
__________________________________________________________________________________________________
conv2d_14 (Conv2D)              (None, 14, 14, 64)   36928       conv2d_13[0][0]                  
__________________________________________________________________________________________________
flatten_1 (Flatten)             (None, 12544)        0           conv2d_14[0][0]                  
__________________________________________________________________________________________________
dense_3 (Dense)                 (None, 32)           401440      flatten_1[0][0]                  
__________________________________________________________________________________________________
dense_4 (Dense)                 (None, 2)            66          dense_3[0][0]                    
__________________________________________________________________________________________________
dense_5 (Dense)                 (None, 2)            66          dense_3[0][0]                    
__________________________________________________________________________________________________
lambda_2 (Lambda)               (None, 2)            0           dense_4[0][0]                    
                                                                 dense_5[0][0]                    
__________________________________________________________________________________________________
model_7 (Model)                 (None, 28, 28, 1)    56385       lambda_2[0][0]                   
__________________________________________________________________________________________________
custom_variational_layer_2 (Cus [(None, 28, 28, 1),  0           input_3[0][0]                    
                                                                 model_7[1][0]                    
==================================================================================================
Total params: 550,629
Trainable params: 550,629
Non-trainable params: 0
__________________________________________________________________________________________________
In [57]:
train_x = train_x.reshape(-1, 28, 28, 1)
val_x = val_x.reshape(-1, 28, 28, 1)
In [58]:
vae.fit(x=train_x, y=None,
        shuffle=True,
        epochs=20,
        batch_size=batch_size,
        validation_data=(val_x, None))
Train on 60000 samples, validate on 9600 samples
Epoch 1/20
60000/60000 [==============================] - 30s 500us/step - loss: 0.3938 - val_loss: 0.3383
Epoch 2/20
60000/60000 [==============================] - 29s 486us/step - loss: 0.3371 - val_loss: 0.3379
Epoch 3/20
60000/60000 [==============================] - 29s 490us/step - loss: 0.3337 - val_loss: 0.3315
Epoch 4/20
60000/60000 [==============================] - 29s 485us/step - loss: 0.3315 - val_loss: 0.3306
Epoch 5/20
60000/60000 [==============================] - 29s 489us/step - loss: 0.3305 - val_loss: 0.3335
Epoch 6/20
60000/60000 [==============================] - 29s 484us/step - loss: 0.3289 - val_loss: 0.3291
Epoch 7/20
60000/60000 [==============================] - 29s 491us/step - loss: 0.3283 - val_loss: 0.3299
Epoch 8/20
60000/60000 [==============================] - 29s 489us/step - loss: 0.3282 - val_loss: 0.3322
Epoch 9/20
60000/60000 [==============================] - 29s 488us/step - loss: 0.3286 - val_loss: 0.3276
Epoch 10/20
60000/60000 [==============================] - 29s 489us/step - loss: 0.3271 - val_loss: 0.3295
Epoch 11/20
60000/60000 [==============================] - 29s 487us/step - loss: 0.3261 - val_loss: 0.3267
Epoch 12/20
60000/60000 [==============================] - 29s 484us/step - loss: 0.3254 - val_loss: 0.3260
Epoch 13/20
60000/60000 [==============================] - 29s 486us/step - loss: 0.3249 - val_loss: 0.3311
Epoch 14/20
60000/60000 [==============================] - 29s 489us/step - loss: 0.3244 - val_loss: 0.3244
Epoch 15/20
60000/60000 [==============================] - 29s 491us/step - loss: 0.3240 - val_loss: 0.3251
Epoch 16/20
60000/60000 [==============================] - 29s 487us/step - loss: 0.3237 - val_loss: 0.3240
Epoch 17/20
60000/60000 [==============================] - 29s 491us/step - loss: 0.3234 - val_loss: 0.3231
Epoch 18/20
60000/60000 [==============================] - 29s 489us/step - loss: 0.3232 - val_loss: 0.3233
Epoch 19/20
60000/60000 [==============================] - 29s 488us/step - loss: 0.3230 - val_loss: 0.3236
Epoch 20/20
60000/60000 [==============================] - 29s 489us/step - loss: 0.3227 - val_loss: 0.3231
Out[58]:

Visualize Samples reconstructed by VAE

In [59]:
# Display a 2D manifold of the samples
n = 20  # figure with 20x20 samples
digit_size = 28
figure = np.zeros((digit_size * n, digit_size * n))

# Construct grid of latent variable values - can change values here to generate different things
grid_x = norm.ppf(np.linspace(0.05, 0.95, n))
grid_y = norm.ppf(np.linspace(0.05, 0.95, n))

# decode for each square in the grid
for i, yi in enumerate(grid_x):
    for j, xi in enumerate(grid_y):
        z_sample = np.array([[xi, yi]])
        z_sample = np.tile(z_sample, batch_size).reshape(batch_size, 2)
        x_decoded = decoder.predict(z_sample, batch_size=batch_size)
        digit = x_decoded[0].reshape(digit_size, digit_size)
        figure[i * digit_size: (i + 1) * digit_size,
               j * digit_size: (j + 1) * digit_size] = digit

plt.figure(figsize=(20, 20))
plt.imshow(figure)
plt.show()  

TODO:

VAE: Visualize latent space

In [60]:
train = pd.read_csv("data/fashion-mnist_train.csv")
In [61]:
### read dataset 
train = pd.read_csv("data/fashion-mnist_train.csv")
train_x = train[list(train.columns)[1:]].values
train_y = train[list(train.columns)[0]].values

train_x = train_x/255.
train_x = train_x.reshape(-1, 28, 28, 1)
In [62]:
# Translate into the latent space
encoder = Model(input_img, z_mu)
x_valid_noTest_encoded = encoder.predict(train_x, batch_size=batch_size)
plt.figure(figsize=(10, 10))
plt.scatter(x_valid_noTest_encoded[:, 0], x_valid_noTest_encoded[:, 1], c=train_y, cmap='brg')
plt.colorbar()
plt.show()