{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# CS-109B Introduction to Data Science\n",
"## Lab 6: Convolutional Neural Networks 2\n",
"\n",
"**Harvard University** \n",
"**Spring 2020** \n",
"**Instructors:** Mark Glickman, Pavlos Protopapas, and Chris Tanner \n",
"**Lab Instructors:** Chris Tanner and Eleni Angelaki Kaxiras \n",
"**Content:** Eleni Angelaki Kaxiras, Cedric Flamant, Pavlos Protopapas\n",
"\n",
"---"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n"
],
"text/plain": [
""
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# RUN THIS CELL TO PROPERLY HIGHLIGHT THE EXERCISES\n",
"import requests\n",
"from IPython.core.display import HTML\n",
"styles = requests.get(\"https://raw.githubusercontent.com/Harvard-IACS/2019-CS109B/master/content/styles/cs109.css\").text\n",
"HTML(styles)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Learning Goals\n",
"\n",
"In this lab we will continue with Convolutional Neural Networks (CNNs), will look into the `tf.data` interface which enables us to build complex input pipelines for our data. We will also touch upon visualization techniques to peak into our CNN's hidden layers.\n",
"\n",
"By the end of this lab, you should be able to:\n",
"\n",
"- know how a CNN works from start to finish\n",
"- use `tf.data.Dataset` to import and, if needed, transform, your data for feeding into the network. Transformations might include normalization, scaling, tilting, resizing, or applying other data augmentation techniques.\n",
"- understand how `saliency maps` are implemented with code."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
" \n",
"\n",
"## Table of Contents\n",
"\n",
"1. **Part 1**: [Beginning-to-end Convolutional Neural Networks](#part1).\n",
"2. **Part 2**: [Image Pipelines with `tf.data.Dataset`](#part2). \n",
"3. **Part 3**: [Hidden Layer Visualization, Saliency Maps](#part3)."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from scipy.optimize import minimize\n",
"from sklearn.utils import shuffle\n",
"\n",
"import matplotlib.pyplot as plt\n",
"plt.rcParams[\"figure.figsize\"] = (5,5)\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"from tensorflow.keras.models import Sequential, Model\n",
"from tensorflow.keras.layers import Dense, Conv2D, Conv1D, MaxPooling2D, MaxPooling1D,\\\n",
" Dropout, Flatten, Activation, Input\n",
"from tensorflow.keras.optimizers import Adam, SGD, RMSprop\n",
"from tensorflow.keras.utils import to_categorical\n",
"from tensorflow.keras.metrics import AUC, Precision, Recall, FalsePositives, \\\n",
" FalseNegatives, TruePositives, TrueNegatives\n",
"from tensorflow.keras.preprocessing import image\n",
"from tensorflow.keras.regularizers import l2"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2.1.0\n",
"0 GPUs\n"
]
}
],
"source": [
"from __future__ import absolute_import, division, print_function, unicode_literals\n",
"tf.keras.backend.clear_session() # For easy reset of notebook state.\n",
"print(tf.__version__) # You should see a > 2.0.0 here!\n",
"from tf_keras_vis.utils import print_gpus\n",
"print_gpus()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"## Additional Packages required if you don't already have them\n",
"# While in your conda environment,\n",
"\n",
"# imageio\n",
"# Install using \"conda install imageio\"\n",
"# pillow\n",
"# Install using \"conda install pillow\"\n",
"# tensorflow-datasets\n",
"# Install using \"conda install tensorflow-datasets\"\n",
"# tf-keras-vis\n",
"# Install using \"pip install tf-keras-vis\"\n",
"# tensorflow-addons\n",
"# Install using \"pip install tensorflow-addons\""
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"from tf_keras_vis.saliency import Saliency\n",
"from tf_keras_vis.utils import normalize\n",
"import tf_keras_vis.utils as utils\n",
"from matplotlib import cm\n",
"from tf_keras_vis.gradcam import Gradcam"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"np.random.seed(109)\n",
"tf.random.set_seed(109)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Part 0: Running on SEAS JupyterHub\n",
"\n",
"**PLEASE READ**: [Instructions for Using SEAS JupyterHub](https://canvas.harvard.edu/courses/65462/pages/instructions-for-using-seas-jupyterhub?module_item_id=638544)\n",
"\n",
"SEAS and FAS are providing you with a platform in AWS to use for the class (accessible from the 'Jupyter' menu link in Canvas). These are AWS p2 instances with a GPU, 10GB of disk space, and 61 GB of RAM, for faster training for your networks. Most of the libraries such as keras, tensorflow, pandas, etc. are pre-installed. If a library is missing you may install it via the Terminal.\n",
"\n",
"**NOTE: The AWS platform is funded by SEAS and FAS for the purposes of the class. It is FREE for you - not running against your personal AWS credit. For this reason you are only allowed to use it for purposes related to this course, and with prudence.**\n",
"\n",
"**Help us keep this service: Make sure you stop your instance as soon as you do not need it. Your instance will terminate after 30 min of inactivity.**\n",
"\n",
"\n",
"*source: CS231n Stanford, Google Cloud Tutorial*"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"\n",
"## Part 1: Beginning-to-end Convolutional Neural Networks\n",
"\n",
"\n",
"\n",
"*image [source](http://www.wildml.com/2015/11/understanding-convolutional-neural-networks-for-nlp/)*\n",
"
\n",
"We will go through the various steps of training a CNN, including:\n",
"- difference between cross-validation and validation\n",
"- specifying a loss, metrics, and an optimizer,\n",
"- performing validation,\n",
"- using callbacks, specifically `EarlyStopping`, which stops the training when training is no longer improving the validation metrics,\n",
"- learning rate significance\n",
"
\n",
"
Table Exercise: Use the whiteboard next to your table to draw a CNN from start to finish as per the instructions. We will then draw it together in class.
"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
" [Back to Table of Contents](#top)\n",
"\n",
"## Part 2: Image Preprocessing: Using `tf.data.Dataset`"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow_addons as tfa\n",
"import tensorflow_datasets as tfds"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
" `tf.data` API in `tensorflow` enables you to build complex **input pipelines** from simple, reusable pieces. For example, the pipeline for an image model might aggregate data from files in a distributed file system, apply random perturbations to each image, and merge randomly selected images into a batch for training. \n",
"\n",
"The pipeline for a text model might involve extracting symbols from raw text data, converting them to embedding identifiers with a lookup table, and batching together sequences of different lengths. The `tf.data API` makes it possible to handle large amounts of data, read from different data formats, and perform complex transformations.\n",
"\n",
"The `tf.data API` introduces a `tf.data.Dataset` that represents a sequence of **elements**, consistινγ of one or more **components**. For example, in an image pipeline, an element might be a single training example, with a pair of tensor components representing the image and its label.\n",
"\n",
"To create an input pipeline, you must start with a data **source**. For example, to construct a Dataset from data in memory, you can use `tf.data.Dataset.from_tensors()` or `tf.data.Dataset.from_tensor_slices()`. Alternatively, if your input data is stored in a file in the recommended TFRecord format, you can use `tf.data.TFRecordDataset()`.\n",
"\n",
"The Dataset object is a Python iterable. You may view its elements using a for loop:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[4 3 1 9 7 4 8 9 4 6]\n",
"[9 6 2 2 6 4 7 2 9 8]\n",
"[5 7 5 4 8 5 6 4 8 4]\n",
"[6 2 2 2 6 6 4 2 2 2]\n"
]
}
],
"source": [
"dataset = tf.data.Dataset.from_tensor_slices(tf.random.uniform([4, 10], minval=1, maxval=10, dtype=tf.int32))\n",
"\n",
"for elem in dataset:\n",
" print(elem.numpy())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Once you have a Dataset object, you can **transform** it into a new Dataset by chaining method calls on the `tf.data.Dataset` object. For example, you can apply per-element transformations such as `Dataset.map()`, and multi-element transformations such as `Dataset.batch()`. See the [documentation](https://www.tensorflow.org/api_docs/python/tf/data/Dataset) for `tf.data.Dataset` for a complete list of transformations.\n",
"\n",
"The `map` function takes a function and returns a new and augmented dataset. "
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[ 8 6 2 18 14 8 16 18 8 12]\n",
"[18 12 4 4 12 8 14 4 18 16]\n",
"[10 14 10 8 16 10 12 8 16 8]\n",
"[12 4 4 4 12 12 8 4 4 4]\n"
]
}
],
"source": [
"dataset = dataset.map(lambda x: x*2) \n",
"for elem in dataset:\n",
" print(elem.numpy())"
]
},
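{
"cell_type": "markdown",
"metadata": {},
"source": [
"To build intuition for what `map` and `batch` do, here is a plain-Python analogy using generators. It is only a sketch of the semantics under simplifying assumptions, not how `tf.data` is implemented; the helper names `map_elements` and `batch_elements` are hypothetical and not part of TensorFlow.\n",
"\n",
"```python\n",
"from itertools import islice\n",
"\n",
"def map_elements(fn, elements):\n",
"    # per-element transformation, analogous to Dataset.map()\n",
"    for elem in elements:\n",
"        yield fn(elem)\n",
"\n",
"def batch_elements(elements, batch_size):\n",
"    # multi-element transformation, analogous to Dataset.batch()\n",
"    it = iter(elements)\n",
"    while True:\n",
"        batch = list(islice(it, batch_size))\n",
"        if not batch:\n",
"            return\n",
"        yield batch\n",
"\n",
"rows = [[4, 3, 1], [9, 6, 2], [5, 7, 5]]\n",
"doubled = map_elements(lambda row: [2 * v for v in row], rows)\n",
"batches = list(batch_elements(doubled, batch_size=2))\n",
"print(batches)  # [[[8, 6, 2], [18, 12, 4]], [[10, 14, 10]]]\n",
"```\n",
"\n",
"As with `Dataset.batch()` in its default setting, the last batch is smaller when the number of elements is not a multiple of the batch size."
]
},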
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Datasets are powerful objects because they are effectively dictionaries that can store tensors and other data such as the response variable. We can also construct them by passing small sized `numpy` arrays, such as in the following example.\n",
"\n",
"Tensorflow has a plethora of them:"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"# uncomment to see available datasets\n",
"#tfds.list_builders()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### `mnist` dataset"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"((60000, 28, 28), (60000,))"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# load mnist\n",
"(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()\n",
"x_train.shape, y_train.shape"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"# take only 10 images for simplicity\n",
"train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))\n",
"test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(28, 28)\n",
"(28, 28)\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAAN80lEQVR4nO3df6hcdXrH8c+ncf3DrBpTMYasNhuRWBWbLRqLSl2RrD9QNOqWDVgsBrN/GHChhEr6xyolEuqP0qAsuYu6sWyzLqgYZVkVo6ZFCF5j1JjU1YrdjV6SSozG+KtJnv5xT+Su3vnOzcyZOZP7vF9wmZnzzJnzcLife87Md879OiIEYPL7k6YbANAfhB1IgrADSRB2IAnCDiRxRD83ZpuP/oEeiwiPt7yrI7vtS22/aftt27d281oAesudjrPbniLpd5IWSNou6SVJiyJia2EdjuxAj/XiyD5f0tsR8U5EfCnpV5Ku6uL1APRQN2GfJekPYx5vr5b9EdtLbA/bHu5iWwC61M0HdOOdKnzjND0ihiQNSZzGA03q5si+XdJJYx5/R9L73bUDoFe6CftLkk61/V3bR0r6kaR19bQFoG4dn8ZHxD7bSyU9JWmKpAci4o3aOgNQq46H3jraGO/ZgZ7ryZdqABw+CDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUii4ymbcXiYMmVKsX7sscf2dPtLly5tWTvqqKOK686dO7dYv/nmm4v1u+66q2Vt0aJFxXU///zzYn3lypXF+u23316sN6GrsNt+V9IeSfsl7YuIs+toCkD96jiyXxQRH9TwOgB6iPfsQBLdhj0kPW37ZdtLxnuC7SW2h20Pd7ktAF3o9jT+/Ih43/YJkp6x/V8RsWHsEyJiSNKQJNmOLrcHoENdHdkj4v3qdqekxyTNr6MpAPXrOOy2p9o++uB9ST+QtKWuxgDUq5vT+BmSHrN98HX+PSJ+W0tXk8zJJ59crB955JHF+nnnnVesX3DBBS1r06ZNK6577bXXFutN2r59e7G+atWqYn3hwoUta3v27Cmu++qrrxbrL7zwQrE+iDoOe0S8I+kvauwFQA8x9AYkQdiBJAg7kARhB5Ig7EASjujfl9om6zfo5s2bV6yvX7++WO/1ZaaD6sCBA8X6jTfeWKx/8sknHW97ZGSkWP/www+L9TfffLPjbfdaRHi85RzZgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJxtlrMH369GJ948aNxfqcOXPqbKdW7XrfvXt3sX7RRRe1rH355ZfFdbN+/6BbjLMDyRF2IAnCDiRB2IEkCDuQBGEHkiDsQBJM2VyDXbt2FevLli0r1q+44opi/ZVXXinW2/1L5ZLNmzcX6wsWLCjW9+7dW6yfccYZLWu33HJLcV3UiyM7kARhB5Ig7EAShB1IgrADSRB2IAnCDiTB9ewD4JhjjinW200vvHr16pa1xYsXF9e9/vrri/W1a9cW6xg8HV/PbvsB2zttbxmzbLrtZ2y/Vd0eV2ezAOo3kdP4X0i69GvLbpX0bEScKunZ6jGAAdY27BGxQdLXvw96laQ11f01kq6uuS8ANev0u/EzImJEkiJixPYJrZ5oe4mkJR1uB0BNen4hTEQMSRqS+IAOaFKnQ287bM+UpOp2Z30tAeiFTsO+TtIN1f0bJD1eTzsAeqXtabzttZK+L+l429sl/VTSSkm/tr1Y0u8l/bCXTU52H3/8cVfrf/TRRx2ve9NNNxXrDz/8cLHebo51DI62YY+IRS1KF9fcC4Ae4uuyQBKEHUiCsANJEHYgCcIOJMElrpPA1KlTW9aeeOKJ4roXXnhhsX7ZZZcV608//XSxjv5jymYgOcIOJEHYgS
QIO5AEYQeSIOxAEoQdSIJx9knulFNOKdY3bdpUrO/evbtYf+6554r14eHhlrX77ruvuG4/fzcnE8bZgeQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJxtmTW7hwYbH+4IMPFutHH310x9tevnx5sf7QQw8V6yMjIx1vezJjnB1IjrADSRB2IAnCDiRB2IEkCDuQBGEHkmCcHUVnnnlmsX7PPfcU6xdf3Plkv6tXry7WV6xYUay/9957HW/7cNbxOLvtB2zvtL1lzLLbbL9ne3P1c3mdzQKo30RO438h6dJxlv9LRMyrfn5Tb1sA6tY27BGxQdKuPvQCoIe6+YBuqe3XqtP841o9yfYS28O2W/8zMgA912nYfybpFEnzJI1IurvVEyNiKCLOjoizO9wWgBp0FPaI2BER+yPigKSfS5pfb1sA6tZR2G3PHPNwoaQtrZ4LYDC0HWe3vVbS9yUdL2mHpJ9Wj+dJCknvSvpxRLS9uJhx9sln2rRpxfqVV17ZstbuWnl73OHir6xfv75YX7BgQbE+WbUaZz9iAisuGmfx/V13BKCv+LoskARhB5Ig7EAShB1IgrADSXCJKxrzxRdfFOtHHFEeLNq3b1+xfskll7SsPf/888V1D2f8K2kgOcIOJEHYgSQIO5AEYQeSIOxAEoQdSKLtVW/I7ayzzirWr7vuumL9nHPOaVlrN47eztatW4v1DRs2dPX6kw1HdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgnH2SW7u3LnF+tKlS4v1a665plg/8cQTD7mnidq/f3+xPjJS/u/lBw4cqLOdwx5HdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgnH2w0C7sexFi8abaHdUu3H02bNnd9JSLYaHh4v1FStWFOvr1q2rs51Jr+2R3fZJtp+zvc32G7ZvqZZPt/2M7beq2+N63y6ATk3kNH6fpL+PiD+X9FeSbrZ9uqRbJT0bEadKerZ6DGBAtQ17RIxExKbq/h5J2yTNknSVpDXV09ZIurpXTQLo3iG9Z7c9W9L3JG2UNCMiRqTRPwi2T2ixzhJJS7prE0C3Jhx229+W9Iikn0TEx/a4c8d9Q0QMSRqqXoOJHYGGTGjozfa3NBr0X0bEo9XiHbZnVvWZknb2pkUAdWh7ZPfoIfx+Sdsi4p4xpXWSbpC0srp9vCcdTgIzZswo1k8//fRi/d577y3WTzvttEPuqS4bN24s1u+8886WtccfL//KcIlqvSZyGn++pL+V9LrtzdWy5RoN+a9tL5b0e0k/7E2LAOrQNuwR8Z+SWr1Bv7jedgD0Cl+XBZIg7EAShB1IgrADSRB2IAkucZ2g6dOnt6ytXr26uO68efOK9Tlz5nTUUx1efPHFYv3uu+8u1p966qli/bPPPjvkntAbHNmBJAg7kARhB5Ig7EAShB1IgrADSRB2IIk04+znnntusb5s2bJiff78+S1rs2bN6qinunz66acta6tWrSque8cddxTre/fu7agnDB6O7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQRJpx9oULF3ZV78bWrVuL9SeffLJY37dvX7FeuuZ89+7dxXWRB0d2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUjCEVF+gn2SpIcknSjpgKShiPhX27dJuknS/1ZPXR4Rv2nzWuWNAehaRIw76/JEwj5T0syI2GT7aEkvS7pa0t9I+iQi7ppoE4Qd6L1WYZ/I/Owjkkaq+3tsb5PU7L9mAXDIDuk9u+3Zkr4naWO1aKnt12w/YPu4FusssT1se7irTgF0pe1p/FdPtL8t6QVJKyLiUdszJH0gKST9k0ZP9W9s8xqcxgM91vF7dkmy/S1JT0p6KiLuGac+W9KTEXFmm9ch7ECPtQp729N425Z0v6RtY4NefXB30EJJW7ptEkDvTOTT+Ask/Yek1zU69CZJyyUtkjRPo6fx70r6cfVhXum1OLIDPdbVaXxdCDvQex2fxgOYHAg7kARhB5Ig7E
AShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJ9HvK5g8k/c+Yx8dXywbRoPY2qH1J9NapOnv7s1aFvl7P/o2N28MRcXZjDRQMam+D2pdEb53qV2+cxgNJEHYgiabDPtTw9ksGtbdB7Uuit071pbdG37MD6J+mj+wA+oSwA0k0Enbbl9p+0/bbtm9toodWbL9r+3Xbm5uen66aQ2+n7S1jlk23/Yztt6rbcefYa6i322y/V+27zbYvb6i3k2w/Z3ub7Tds31Itb3TfFfrqy37r+3t221Mk/U7SAknbJb0kaVFEbO1rIy3YflfS2RHR+BcwbP+1pE8kPXRwai3b/yxpV0SsrP5QHhcR/zAgvd2mQ5zGu0e9tZpm/O/U4L6rc/rzTjRxZJ8v6e2IeCcivpT0K0lXNdDHwIuIDZJ2fW3xVZLWVPfXaPSXpe9a9DYQImIkIjZV9/dIOjjNeKP7rtBXXzQR9lmS/jDm8XYN1nzvIelp2y/bXtJ0M+OYcXCarer2hIb7+bq203j309emGR+YfdfJ9OfdaiLs401NM0jjf+dHxF9KukzSzdXpKibmZ5JO0egcgCOS7m6ymWqa8Uck/SQiPm6yl7HG6asv+62JsG+XdNKYx9+R9H4DfYwrIt6vbndKekyjbzsGyY6DM+hWtzsb7ucrEbEjIvZHxAFJP1eD+66aZvwRSb+MiEerxY3vu/H66td+ayLsL0k61fZ3bR8p6UeS1jXQxzfYnlp9cCLbUyX9QIM3FfU6STdU92+Q9HiDvfyRQZnGu9U042p43zU+/XlE9P1H0uUa/UT+vyX9YxM9tOhrjqRXq583mu5N0lqNntb9n0bPiBZL+lNJz0p6q7qdPkC9/ZtGp/Z+TaPBmtlQbxdo9K3ha5I2Vz+XN73vCn31Zb/xdVkgCb5BByRB2IEkCDuQBGEHkiDsQBKEHUiCsANJ/D+f1mbtgJ8kQQAAAABJRU5ErkJggg==\n",
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# In case you want to retrieve the images/numpy arrays\n",
"for element in iter(train_dataset.take(1)):\n",
" image = element[0].numpy()\n",
" print(image.shape)\n",
" print(image.shape)\n",
" plt.figure()\n",
" plt.imshow(image, cmap='gray')\n",
" plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Once you have your Model, you may pass a Dataset instance directly to the methods `fit()`, `evaluate()`, and `predict()`. The difference with the way we have been previously using these methods is that we are not passing the images and labels separately. They are now both in the Dataset object.\n",
"\n",
"```\n",
"model.fit(train_dataset, epochs=3)\n",
"\n",
"model.evaluate(test_dataset)\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Data Augmentation"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig, axes = plt.subplots(1,6, figsize=(10,3))\n",
"for i, (image, label) in enumerate(train_dataset.take(4)):\n",
" axes[i].imshow(image)\n",
" axes[i].set_title(f'{label:.2f}')\n",
"image_flip_up = tf.image.flip_up_down(np.expand_dims(image, axis=2)).numpy()\n",
"image_rot_90 = tf.image.rot90(np.expand_dims(image, axis=2), k=1).numpy()\n",
"axes[4].imshow(image_flip_up.reshape(28,-1))\n",
"axes[4].set_title(f'{label:.2f}-flip')\n",
"axes[5].imshow(image_rot_90.reshape(28,-1))\n",
"axes[5].set_title(f'{label:.2f}-rot90')\n",
"plt.show();"
]
},
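{
"cell_type": "markdown",
"metadata": {},
"source": [
"The two augmentations above can be reproduced on a tiny nested list to see exactly which pixels move. This is a plain-Python sketch for intuition only; `flip_up_down` and `rot90_ccw` are hypothetical helpers that mimic, on a single-channel 2-D grid, what `tf.image.flip_up_down` and `tf.image.rot90` with `k=1` do.\n",
"\n",
"```python\n",
"def flip_up_down(grid):\n",
"    # reverse the order of the rows\n",
"    return [list(row) for row in grid[::-1]]\n",
"\n",
"def rot90_ccw(grid):\n",
"    # transpose, then reverse the row order: a 90-degree counter-clockwise turn\n",
"    return [list(row) for row in zip(*grid)][::-1]\n",
"\n",
"grid = [[1, 2, 3],\n",
"        [4, 5, 6]]\n",
"flipped = flip_up_down(grid)\n",
"rotated = rot90_ccw(grid)\n",
"print(flipped)  # [[4, 5, 6], [1, 2, 3]]\n",
"print(rotated)  # [[3, 6], [2, 5], [1, 4]]\n",
"```\n",
"\n",
"Note that the rotation swaps height and width, which is invisible on square 28-by-28 digits but matters for rectangular images."
]
},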
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Note:\n",
"\n",
"The tf.data API is a set of utilities in TensorFlow 2.0 for loading and preprocessing data in a way that's fast and scalable. You also have the option to use the `keras` [`ImageDataGenerator`](https://www.tensorflow.org/api_docs/python/tf/keras/preprocessing/image/ImageDataGenerator), that accepts `numpy` arrays, instead of the Dataset. We think it's good for you to learn to use Datasets.\n",
"\n",
"As a general rule, for input to NNs, Tensorflow recommends that you use `numpy` arrays if your data is small and fit in memory, and `tf.data.Datasets` otherwise.\n",
"\n",
"#### References:\n",
"1. `tf.data.Dataset` [Documentation](https://www.tensorflow.org/api_docs/python/tf/data/Dataset).\n",
"2. Import [`numpy` arrays in Tensorflow](https://www.tensorflow.org/tutorials/load_data/numpy)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### The Street View House Numbers (SVHN) Dataset\n",
"\n",
"We will play with the SVHN real-world image dataset. It can be seen as similar in flavor to MNIST (e.g., the images are of small cropped digits), but incorporates an order of magnitude more labeled data (over 600,000 digit images) and comes from a significantly harder, unsolved, real world problem (recognizing digits and numbers in natural scene images). SVHN is obtained from house numbers in Google Street View images. \n",
"\n",
"All digits have been resized to a fixed resolution of 32-by-32 pixels. The original character bounding boxes are extended in the appropriate dimension to become square windows, so that resizing them to 32-by-32 pixels does not introduce aspect ratio distortions. Nevertheless this preprocessing introduces some distracting digits to the sides of the digit of interest. Loading the .mat files creates 2 variables: X which is a 4-D matrix containing the images, and y which is a vector of class labels. To access the images, $X(:,:,:,i)$ gives the i-th 32-by-32 RGB image, with class label $y(i)$.\n",
"\n",
"\n",
"\n",
"*Yuval Netzer, Tao Wang, Adam Coates, Alessandro Bissacco, Bo Wu, Andrew Y. Ng Reading Digits in Natural Images with Unsupervised Feature Learning NIPS Workshop on Deep Learning and Unsupervised Feature Learning 2011.*"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"# Will take some time but will only load once\n",
"train_svhn_cropped, test_svhn_cropped = tfds.load('svhn_cropped', split=['train', 'test'], shuffle_files=False)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"isinstance(train_svhn_cropped, tf.data.Dataset)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"((32, 32, 3), ())"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# # convert to numpy if needed\n",
"features = next(iter(train_svhn_cropped))\n",
"images = features['image'].numpy()\n",
"labels = features['label'].numpy()\n",
"images.shape, labels.shape"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tf.Tensor(5, shape=(), dtype=int64)\n"
]
}
],
"source": [
"for i, element in enumerate(train_svhn_cropped):\n",
" if i==1: break;\n",
" image = element['image'] \n",
" label = element['label']\n",
" print(label)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# batch_size indicates that the dataset should be divided in batches \n",
"# each consisting of 4 elements (a.k.a images and their labels)\n",
"# take_size chooses a number of these batches, e.g. 3 of them for display\n",
"\n",
"batch_size = 4\n",
"take_size = 3\n",
"\n",
"# Plot\n",
"fig, axes = plt.subplots(take_size,batch_size, figsize=(10,10))\n",
"for i, element in enumerate(train_svhn_cropped.batch(batch_size).take(take_size)):\n",
" for j in range(4):\n",
" image = element['image'][j]\n",
" label = element['label'][j]\n",
" axes[i][j].imshow(image)\n",
" axes[i][j].set_title(f'true label={label:d}')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here we convert from a collection of dictionaries to a collection of tuples. We will still have a `tf.data.Dataset`"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"def normalize_image(img):\n",
" return tf.cast(img, tf.float32)/255.\n",
"\n",
"def normalize_dataset(element):\n",
" img = element['image']\n",
" lbl = element['label']\n",
" return normalize_image(img), lbl"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"train_svhn = train_svhn_cropped.map(normalize_dataset)\n",
"test_svhn = test_svhn_cropped.map(normalize_dataset)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"isinstance(train_svhn, tf.data.Dataset)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Define our CNN model "
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"sequential_5\"\n",
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"conv2d_15 (Conv2D) (None, 30, 30, 16) 448 \n",
"_________________________________________________________________\n",
"max_pooling2d_10 (MaxPooling (None, 15, 15, 16) 0 \n",
"_________________________________________________________________\n",
"conv2d_16 (Conv2D) (None, 13, 13, 32) 4640 \n",
"_________________________________________________________________\n",
"max_pooling2d_11 (MaxPooling (None, 6, 6, 32) 0 \n",
"_________________________________________________________________\n",
"conv2d_17 (Conv2D) (None, 4, 4, 64) 18496 \n",
"_________________________________________________________________\n",
"flatten_5 (Flatten) (None, 1024) 0 \n",
"_________________________________________________________________\n",
"dense_10 (Dense) (None, 32) 32800 \n",
"_________________________________________________________________\n",
"dense_11 (Dense) (None, 10) 330 \n",
"=================================================================\n",
"Total params: 56,714\n",
"Trainable params: 56,714\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"n_filters = 16\n",
"input_shape = (32, 32, 3)\n",
"\n",
"svhn_model = Sequential() \n",
"svhn_model.add(Conv2D(n_filters, (3, 3), activation='relu', input_shape=input_shape))\n",
"svhn_model.add(MaxPooling2D((2, 2)))\n",
"svhn_model.add(Conv2D(n_filters*2, (3, 3), activation='relu')) \n",
"svhn_model.add(MaxPooling2D((2, 2)))\n",
"svhn_model.add(Conv2D(n_filters*4, (3, 3), activation='relu'))\n",
"svhn_model.add(Flatten())\n",
"svhn_model.add(Dense(n_filters*2, activation='relu'))\n",
"svhn_model.add(Dense(10, activation='softmax'))\n",
"svhn_model.summary()"
]
},
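{
"cell_type": "markdown",
"metadata": {},
"source": [
"The output shapes and parameter counts in the summary above can be checked by hand. The sketch below assumes what the model definition uses: 3-by-3 kernels with stride 1 and no padding, 2-by-2 max pooling, and one bias per filter or unit. The helper functions are hypothetical, written only for this check.\n",
"\n",
"```python\n",
"def conv_out(size, kernel=3):\n",
"    # spatial size after a stride-1 convolution without padding\n",
"    return size - kernel + 1\n",
"\n",
"def pool_out(size, pool=2):\n",
"    # spatial size after non-overlapping max pooling\n",
"    return size // pool\n",
"\n",
"def conv_params(c_in, c_out, kernel=3):\n",
"    # kernel weights per filter plus one bias, times the number of filters\n",
"    return (kernel * kernel * c_in + 1) * c_out\n",
"\n",
"def dense_params(n_in, n_out):\n",
"    # weights plus one bias per output unit\n",
"    return (n_in + 1) * n_out\n",
"\n",
"size, channels, total = 32, 3, 0\n",
"for filters, pooled in [(16, True), (32, True), (64, False)]:\n",
"    total += conv_params(channels, filters)\n",
"    size, channels = conv_out(size), filters\n",
"    if pooled:\n",
"        size = pool_out(size)\n",
"\n",
"flat = size * size * channels\n",
"total += dense_params(flat, 32) + dense_params(32, 10)\n",
"print(size, flat, total)  # 4 1024 56714\n",
"```\n",
"\n",
"The three printed numbers match the final feature-map size, the `Flatten` output, and `Total params` in the summary."
]
},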
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [],
"source": [
"loss = keras.losses.sparse_categorical_crossentropy # we use this because we did not 1-hot encode the labels\n",
"optimizer = Adam(lr=0.001)\n",
"metrics = ['accuracy'] \n",
"\n",
"# Compile model\n",
"svhn_model.compile(optimizer=optimizer,\n",
" loss=loss,\n",
" metrics=metrics)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### With Early Stopping"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/15\n",
"1145/1145 [==============================] - 30s 26ms/step - loss: 1.0362 - accuracy: 0.6684 - val_loss: 0.6124 - val_accuracy: 0.8285\n",
"Epoch 2/15\n",
"1145/1145 [==============================] - 30s 26ms/step - loss: 0.5177 - accuracy: 0.8515 - val_loss: 0.5254 - val_accuracy: 0.8519\n",
"Epoch 3/15\n",
"1145/1145 [==============================] - 30s 26ms/step - loss: 0.4393 - accuracy: 0.8739 - val_loss: 0.4789 - val_accuracy: 0.8639\n",
"Epoch 4/15\n",
"1145/1145 [==============================] - 30s 26ms/step - loss: 0.3956 - accuracy: 0.8865 - val_loss: 0.4440 - val_accuracy: 0.8750\n",
"Epoch 5/15\n",
"1145/1145 [==============================] - 30s 26ms/step - loss: 0.3654 - accuracy: 0.8951 - val_loss: 0.4233 - val_accuracy: 0.8816\n",
"Epoch 6/15\n",
"1145/1145 [==============================] - 29s 25ms/step - loss: 0.3412 - accuracy: 0.9014 - val_loss: 0.4168 - val_accuracy: 0.8846\n",
"Epoch 7/15\n",
"1145/1145 [==============================] - 29s 25ms/step - loss: 0.3215 - accuracy: 0.9072 - val_loss: 0.4084 - val_accuracy: 0.8871\n",
"Epoch 8/15\n",
"1145/1145 [==============================] - 30s 26ms/step - loss: 0.3055 - accuracy: 0.9124 - val_loss: 0.4026 - val_accuracy: 0.8888\n",
"Epoch 9/15\n",
"1145/1145 [==============================] - 31s 27ms/step - loss: 0.2916 - accuracy: 0.9163 - val_loss: 0.4100 - val_accuracy: 0.8887\n",
"Epoch 10/15\n",
"1145/1145 [==============================] - 30s 26ms/step - loss: 0.2780 - accuracy: 0.9200 - val_loss: 0.4217 - val_accuracy: 0.8861\n",
"Epoch 00010: early stopping\n",
"CPU times: user 20min 1s, sys: 8min 31s, total: 28min 33s\n",
"Wall time: 4min 58s\n"
]
}
],
"source": [
"%%time\n",
"batch_size = 64\n",
"epochs=15\n",
"\n",
"callbacks = [ \n",
" keras.callbacks.EarlyStopping(\n",
" # Stop training when `val_accuracy` is no longer improving\n",
" monitor='val_accuracy',\n",
" # \"no longer improving\" being further defined as \"for at least 2 epochs\"\n",
" patience=2,\n",
" verbose=1)\n",
" ]\n",
"\n",
"history = svhn_model.fit(train_svhn.batch(batch_size), #.take(50), # change 50 only\n",
" epochs=epochs,\n",
" callbacks=callbacks,\n",
" validation_data=test_svhn.batch(batch_size)) #.take(50))"
]
},
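{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see why training stopped at epoch 10, we can replay the patience rule on the `val_accuracy` values printed above. This is a simplified sketch of the rule (stop once the monitored value has failed to improve for `patience` consecutive epochs), not the Keras source; `early_stopping_epoch` is a hypothetical helper. The same cell checks the 1145 steps per epoch, assuming the 73,257 images of the SVHN cropped-digits training split.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"# val_accuracy per epoch, copied from the training log above\n",
"val_accuracy = [0.8285, 0.8519, 0.8639, 0.8750, 0.8816,\n",
"                0.8846, 0.8871, 0.8888, 0.8887, 0.8861]\n",
"\n",
"def early_stopping_epoch(values, patience):\n",
"    # return the 1-based epoch at which training stops, or None if it never does\n",
"    best, wait = float('-inf'), 0\n",
"    for epoch, value in enumerate(values, start=1):\n",
"        if value > best:\n",
"            best, wait = value, 0\n",
"        else:\n",
"            wait += 1\n",
"            if wait >= patience:\n",
"                return epoch\n",
"    return None\n",
"\n",
"stopped_at = early_stopping_epoch(val_accuracy, patience=2)\n",
"steps_per_epoch = math.ceil(73257 / 64)  # training images / batch size\n",
"print(stopped_at, steps_per_epoch)  # 10 1145\n",
"```\n",
"\n",
"The best value is reached at epoch 8; epochs 9 and 10 do not improve on it, so a patience of 2 ends training after epoch 10."
]
},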
{
"cell_type": "code",
"execution_count": 45,
"metadata": {
"scrolled": false
},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"image/png": "\n",
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"def print_history(history):\n",
" fig, ax = plt.subplots(1, 1, figsize=(8,4))\n",
" ax.plot((history.history['accuracy']), 'b', label='train')\n",
" ax.plot((history.history['val_accuracy']), 'g' ,label='val')\n",
" ax.set_xlabel(r'Epoch', fontsize=20)\n",
" ax.set_ylabel(r'Accuracy', fontsize=20)\n",
" ax.legend()\n",
" ax.tick_params(labelsize=20)\n",
" fig, ax = plt.subplots(1, 1, figsize=(8,4))\n",
" ax.plot((history.history['loss']), 'b', label='train')\n",
" ax.plot((history.history['val_loss']), 'g' ,label='val')\n",
" ax.set_xlabel(r'Epoch', fontsize=20)\n",
" ax.set_ylabel(r'Loss', fontsize=20)\n",
" ax.legend()\n",
" ax.tick_params(labelsize=20)\n",
" plt.show();\n",
" \n",
"print_history(history)"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [],
"source": [
"svhn_model.save('svhn_good.h5')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Too High Learning Rate"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [],
"source": [
"loss = keras.losses.sparse_categorical_crossentropy \n",
"optimizer = Adam(lr=0.5) # really big learning rate\n",
"metrics = ['accuracy'] \n",
"\n",
"# Compile model\n",
"svhn_model.compile(optimizer=optimizer,\n",
" loss=loss,\n",
" metrics=metrics)"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"1145/1145 [==============================] - 29s 25ms/step - loss: 1518.9293 - accuracy: 0.1763 - val_loss: 2.2455 - val_accuracy: 0.1594\n",
"Epoch 2/10\n",
"1145/1145 [==============================] - 29s 25ms/step - loss: 2.2719 - accuracy: 0.1741 - val_loss: 2.2437 - val_accuracy: 0.1594\n",
"Epoch 3/10\n",
"1145/1145 [==============================] - 29s 25ms/step - loss: 2.2734 - accuracy: 0.1745 - val_loss: 2.2431 - val_accuracy: 0.1959\n",
"Epoch 4/10\n",
"1145/1145 [==============================] - 29s 26ms/step - loss: 2.2737 - accuracy: 0.1743 - val_loss: 2.2429 - val_accuracy: 0.1959\n",
"Epoch 5/10\n",
"1145/1145 [==============================] - 29s 26ms/step - loss: 2.2738 - accuracy: 0.1743 - val_loss: 2.2428 - val_accuracy: 0.1959\n",
"Epoch 6/10\n",
"1145/1145 [==============================] - 29s 25ms/step - loss: 2.2738 - accuracy: 0.1743 - val_loss: 2.2428 - val_accuracy: 0.1959\n",
"Epoch 7/10\n",
"1145/1145 [==============================] - 29s 25ms/step - loss: 2.2738 - accuracy: 0.1743 - val_loss: 2.2428 - val_accuracy: 0.1959\n",
"Epoch 8/10\n",
"1145/1145 [==============================] - 29s 25ms/step - loss: 2.2738 - accuracy: 0.1743 - val_loss: 2.2428 - val_accuracy: 0.1959\n",
"Epoch 9/10\n",
"1145/1145 [==============================] - 29s 25ms/step - loss: 2.2738 - accuracy: 0.1743 - val_loss: 2.2428 - val_accuracy: 0.1959\n",
"Epoch 10/10\n",
"1145/1145 [==============================] - 29s 25ms/step - loss: 2.2738 - accuracy: 0.1743 - val_loss: 2.2428 - val_accuracy: 0.1959\n",
"CPU times: user 19min 22s, sys: 8min 17s, total: 27min 40s\n",
"Wall time: 4min 50s\n"
]
}
],
"source": [
"%%time\n",
"batch_size = 64\n",
"epochs=10\n",
"\n",
"history = svhn_model.fit(train_svhn.batch(batch_size), #.take(50), # change 50 to see the difference\n",
" epochs=epochs,\n",
" validation_data=test_svhn.batch(batch_size)) #.take(50))"
]
},
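{
"cell_type": "markdown",
"metadata": {},
"source": [
"The loss plateau near 2.27 is close to what a classifier that ignores its input would score. As a rough sanity check (an interpretation of the log, not something it proves), the cross-entropy of a uniform guess over the 10 digit classes is ln(10):\n",
"\n",
"```python\n",
"import math\n",
"\n",
"n_classes = 10\n",
"# cross-entropy when every class is assigned probability 1/10\n",
"uniform_loss = -math.log(1.0 / n_classes)\n",
"print(round(uniform_loss, 4))  # 2.3026\n",
"```\n",
"\n",
"A plateau slightly below this value is what one would expect if the network outputs roughly the same class distribution for every image, and the frozen accuracy values in the log point the same way."
]
},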
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"image/png": "\n",
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"print_history(history)\n",
"fig.savefig('../images/train_high_lr.png')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Too Low Learning Rate\n",
"\n",
"Experiment with the learning rate using a small sample of the training set by using .take(num) which takes only `num` number of samples.\n",
"```\n",
"history = svhn_model.fit(train_svhn.batch(batch_size).take(50))\n",
"```"
]
},
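{
"cell_type": "markdown",
"metadata": {},
"source": [
"The order of `.batch()` and `.take()` matters. The sketch below is a pure-Python analogy rather than `tf.data` itself: the helpers `batch` and `take` are hypothetical stand-ins built on `itertools`, meant only to show that taking after batching counts batches, while taking before batching counts individual samples.\n",
"\n",
"```python\n",
"from itertools import islice\n",
"\n",
"def batch(items, batch_size):\n",
"    # Group an iterable into lists of at most batch_size elements.\n",
"    it = iter(items)\n",
"    while True:\n",
"        chunk = list(islice(it, batch_size))\n",
"        if not chunk:\n",
"            return\n",
"        yield chunk\n",
"\n",
"def take(items, count):\n",
"    # Keep only the first count elements of an iterable.\n",
"    return islice(items, count)\n",
"\n",
"samples = range(1000)  # stand-in for 1000 training examples\n",
"\n",
"batches_then_take = list(take(batch(samples, 32), 5))  # 5 batches of 32\n",
"take_then_batches = list(batch(take(samples, 5), 32))  # 1 batch of 5\n",
"\n",
"n_after_batch_take = sum(len(b) for b in batches_then_take)\n",
"n_after_take_batch = sum(len(b) for b in take_then_batches)\n",
"print(n_after_batch_take, n_after_take_batch)\n",
"```\n",
"\n",
"With `batch_size = 32`, `.batch(batch_size).take(50)` therefore feeds at most 50 batches, i.e. up to 1600 images, per epoch."
]
},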
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [],
"source": [
"#loss = keras.losses.categorical_crossentropy\n",
"loss = keras.losses.sparse_categorical_crossentropy # we use this because we did not 1-hot encode the labels\n",
"optimizer = Adam(lr=1e-5) # very low learning rate\n",
"metrics = ['accuracy'] \n",
"\n",
"# Compile model\n",
"svhn_model.compile(optimizer=optimizer,\n",
" loss=loss,\n",
" metrics=metrics)"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"2290/2290 [==============================] - 37s 16ms/step - loss: 2.2603 - accuracy: 0.1707 - val_loss: 2.2314 - val_accuracy: 0.1957\n",
"Epoch 2/10\n",
"2290/2290 [==============================] - 34s 15ms/step - loss: 2.2295 - accuracy: 0.1894 - val_loss: 2.2119 - val_accuracy: 0.1970\n",
"Epoch 3/10\n",
"2290/2290 [==============================] - 35s 15ms/step - loss: 2.2046 - accuracy: 0.2012 - val_loss: 2.1738 - val_accuracy: 0.2342\n",
"Epoch 4/10\n",
"2290/2290 [==============================] - 35s 15ms/step - loss: 2.1504 - accuracy: 0.2458 - val_loss: 2.0987 - val_accuracy: 0.2948\n",
"Epoch 5/10\n",
"2290/2290 [==============================] - 36s 16ms/step - loss: 2.0492 - accuracy: 0.3008 - val_loss: 1.9756 - val_accuracy: 0.3434\n",
"Epoch 6/10\n",
"2290/2290 [==============================] - 37s 16ms/step - loss: 1.9201 - accuracy: 0.3507 - val_loss: 1.8509 - val_accuracy: 0.3832\n",
"Epoch 7/10\n",
"2290/2290 [==============================] - 38s 16ms/step - loss: 1.7967 - accuracy: 0.3975 - val_loss: 1.7373 - val_accuracy: 0.4274\n",
"Epoch 8/10\n",
"2290/2290 [==============================] - 35s 15ms/step - loss: 1.6818 - accuracy: 0.4490 - val_loss: 1.6338 - val_accuracy: 0.4714\n",
"Epoch 9/10\n",
"2290/2290 [==============================] - 34s 15ms/step - loss: 1.5778 - accuracy: 0.4939 - val_loss: 1.5412 - val_accuracy: 0.5111\n",
"Epoch 10/10\n",
"2290/2290 [==============================] - 35s 15ms/step - loss: 1.4837 - accuracy: 0.5307 - val_loss: 1.4577 - val_accuracy: 0.5436\n",
"CPU times: user 20min 26s, sys: 9min 2s, total: 29min 28s\n",
"Wall time: 5min 56s\n"
]
}
],
"source": [
"%%time\n",
"batch_size = 32\n",
"epochs=10\n",
"\n",
"history = svhn_model.fit(train_svhn.batch(batch_size).take(50),\n",
" epochs=epochs,\n",
" validation_data=test_svhn.batch(batch_size)) #.take(50))"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"print_history(history)\n",
"fig.savefig('../images/train_50.png')"
]
},
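{
"cell_type": "markdown",
"metadata": {},
"source": [
"The effect of a learning rate that is too high or too low can be mimicked on a toy problem. The following is a minimal sketch, assuming plain gradient descent on f(w) = w^2 rather than Adam on a CNN; the function name `gradient_descent` and the step sizes are illustrative choices, not values from this lab.\n",
"\n",
"```python\n",
"def gradient_descent(lr, steps=20, w0=1.0):\n",
"    # Minimize f(w) = w**2, whose gradient is 2*w, starting from w0.\n",
"    w = w0\n",
"    for _ in range(steps):\n",
"        w = w - lr * 2 * w\n",
"    return w\n",
"\n",
"too_high = gradient_descent(lr=1.1)    # every step overshoots, so |w| grows\n",
"too_low = gradient_descent(lr=1e-4)    # every step barely moves w\n",
"reasonable = gradient_descent(lr=0.1)  # w shrinks steadily toward 0\n",
"\n",
"print(abs(too_high), abs(too_low), abs(reasonable))\n",
"```\n",
"\n",
"The analogy is loose: a real network trained with too high a learning rate does not have to diverge forever. It can also get stuck at a loss near ln(10), roughly 2.30, which is the loss of a ten-class classifier that assigns equal probability to every class."
]
},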
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Changing the batch size"
]
},
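{
"cell_type": "markdown",
"metadata": {},
"source": [
"Batch size controls how many gradient updates happen per epoch. A minimal sketch of the arithmetic, assuming the SVHN training split holds 73,257 images (an assumption, although it is consistent with the step counts printed in the training logs of this lab); the helper `steps_per_epoch` is a hypothetical name:\n",
"\n",
"```python\n",
"import math\n",
"\n",
"n_train = 73257  # assumed size of the SVHN training split\n",
"\n",
"def steps_per_epoch(n_samples, batch_size):\n",
"    # The last batch may be smaller than the rest, hence the ceiling.\n",
"    return math.ceil(n_samples / batch_size)\n",
"\n",
"for batch_size in (64, 32, 2):\n",
"    print(batch_size, steps_per_epoch(n_train, batch_size))\n",
"```\n",
"\n",
"Smaller batches mean more, noisier updates per epoch and usually a longer wall time per epoch."
]
},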
{
"cell_type": "code",
"execution_count": 56,
"metadata": {},
"outputs": [],
"source": [
"#loss = keras.losses.categorical_crossentropy\n",
"loss = keras.losses.sparse_categorical_crossentropy # we use this because we did not 1-hot encode the labels\n",
"optimizer = Adam(lr=0.001)\n",
"metrics = ['accuracy'] \n",
"\n",
"# Compile model\n",
"svhn_model.compile(optimizer=optimizer,\n",
" loss=loss,\n",
" metrics=metrics)"
]
},
{
"cell_type": "code",
"execution_count": 57,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/5\n",
"36629/36629 [==============================] - 175s 5ms/step - loss: 0.8544 - accuracy: 0.7295 - val_loss: 0.5765 - val_accuracy: 0.8363\n",
"Epoch 2/5\n",
"36629/36629 [==============================] - 135s 4ms/step - loss: 0.5045 - accuracy: 0.8494 - val_loss: 0.5326 - val_accuracy: 0.8511\n",
"Epoch 3/5\n",
"36629/36629 [==============================] - 134s 4ms/step - loss: 0.4520 - accuracy: 0.8649 - val_loss: 0.5270 - val_accuracy: 0.8584\n",
"Epoch 4/5\n",
"36629/36629 [==============================] - 141s 4ms/step - loss: 0.4209 - accuracy: 0.8744 - val_loss: 0.5106 - val_accuracy: 0.8614\n",
"Epoch 5/5\n",
"36629/36629 [==============================] - 126s 3ms/step - loss: 0.4007 - accuracy: 0.8811 - val_loss: 0.5079 - val_accuracy: 0.8617\n",
"CPU times: user 19min 36s, sys: 10min 1s, total: 29min 37s\n",
"Wall time: 11min 50s\n"
]
}
],
"source": [
"%%time\n",
"batch_size = 2\n",
"epochs=5\n",
"\n",
"history = svhn_model.fit(train_svhn.batch(batch_size), \n",
" epochs=epochs,\n",
" validation_data=test_svhn.batch(batch_size)) "
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"image/png": "\n",
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"print_history(history)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
" [Back to Table of Contents](#top)\n",
"## Part 3: Hidden Layer Visualization, Saliency Maps\n",
"\n",
"[Deep Inside Convolutional Networks: Visualising Image Classification Models and Saliency Maps](https://arxiv.org/pdf/1312.6034.pdf)\n",
"\n",
"It is often said that Deep Learning Models are black boxes. But we can peak into these boxes. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Let's train a small model on MNIST"
]
},
{
"cell_type": "code",
"execution_count": 408,
"metadata": {},
"outputs": [],
"source": [
"from tensorflow.keras.datasets import mnist\n",
"# load MNIST data\n",
"(x_train, y_train), (x_test, y_test) = mnist.load_data()"
]
},
{
"cell_type": "code",
"execution_count": 409,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(0, 255)"
]
},
"execution_count": 409,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x_train.min(), x_train.max()"
]
},
{
"cell_type": "code",
"execution_count": 410,
"metadata": {},
"outputs": [],
"source": [
"x_train = x_train.reshape((60000, 28, 28, 1)) # Reshape to get third dimension\n",
"x_test = x_test.reshape((10000, 28, 28, 1)) \n",
"\n",
"x_train = x_train.astype('float32') / 255 # Normalize between 0 and 1\n",
"x_test = x_test.astype('float32') / 255 \n",
"\n",
"# Convert labels to categorical data \n",
"y_train = to_categorical(y_train)\n",
"y_test = to_categorical(y_test)"
]
},
{
"cell_type": "code",
"execution_count": 411,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(0.0, 1.0)"
]
},
"execution_count": 411,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x_train.min(), x_train.max()"
]
},
{
"cell_type": "code",
"execution_count": 412,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(60000, 28, 28, 1)"
]
},
"execution_count": 412,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# (train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data(\n",
"# path='mnist.npz')\n",
"x_train.shape"
]
},
{
"cell_type": "code",
"execution_count": 413,
"metadata": {},
"outputs": [],
"source": [
"class_idx = 0\n",
"indices = np.where(y_test[:, class_idx] == 1.)[0]\n",
"\n",
"# pick some random input from here.\n",
"idx = indices[0]\n",
"img = x_test[idx]"
]
},
{
"cell_type": "code",
"execution_count": 414,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAWgAAAFlCAYAAADGe3ILAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAAQ1ElEQVR4nO3db4hd9Z3H8c9nEvsk+iDiKNG6SddIXCmsWWJcsETXxZIIwfjAGhHJYmGiqImwD1YSsPHBgmjN7gNDwhRDs9BaCqlrkLBVJJAKQUzEPzGzrX9I0uiQ+AfUEKQx890Hc8JOk5nM+c3cM/d7732/IMy9Z75z7vfkTD75ze+e8xtHhAAA+fS1uwEAwPgIaABIioAGgKQIaABIioAGgKQIaABIavZMvphtrukDgHNEhMfbzggaAJKaVkDbXm77j7Y/tP14q5oCAEie6p2EtmdJ+pOk2yUdk/SmpHsj4tAFvoYpDgA4RxNTHEslfRgRH0fEXyT9RtKd09gfAGCM6QT0VZL+POb5sWrbX7E9YHu/7f3TeC0A6DnTuYpjvCH5eVMYETEoaVBiigMASkxnBH1M0tVjnn9f0qfTawcAcNZ0AvpNSdfa/oHt70laLWlXa9oCAEx5iiMivrP9iKTfS5olaXtEvN+yzgCgx035MrspvRhz0ABwHu4kBIAOQ0ADQFIENAAkRUADQFIENAAkRUADQFIENAAkRUADQFIENAAkRUADQFIENAAkRUADQFIENAAkRUADQFIENAAkRUADQFIENAAkRUADQFIENAAkRUADQFIENAAkRUADQFIENAAkRUADQFIENAAkRUADQFIENAAkRUADQFIENAAkRUADQFIENAAkRUADQFIENAAkRUADQFIENAAkNbvdDQATmTNnTlH9M888U7t27dq1Rfs+cOBAUf3dd99du/bIkSNF+0bvYAQNAEkR0ACQFAENAEkR0ACQFAENAEkR0ACQFAENAEkR0ACQFAENAEkR0ACQlCNi5l7MnrkXQ8dbuHBhUf3Q0FBDnUh9fWVjmXXr1tWu3bJlS2k76DIR4fG2M4IGgKQIaABIalqr2dk+LOkbSWckfRcRS1rRFACgNcuN/lNEfN6C/QAAxmCKAwCSmm5Ah6RXbB+wPdCKhgAAo6Y7xXFzRHxq+3JJr9r+34jYO7agCm7CGwAKTWsEHRGfVh9PSHpR0tJxagYjYglvIAJAmSkHtO05ti85+1jSjyUdbFVjANDrpjPFcYWkF22f3c+vI+J/WtIVAGDqAR0RH0v6+xb2AgAYoxXXQQO19Pf3F9Xv2LGjoU6AzsB10ACQFAENAEkR0ACQFAENAEkR0ACQFAENAEkR0ACQFAENAEkR0ACQFAENAEkR0ACQFGtxYFrWrVtXu3bVqlVF+1669LzlxTvGsmXLatf29ZWNk955552i+r17905ehJQYQQNAUgQ0ACRFQANAUgQ0ACRFQANAUgQ0ACRFQANAUgQ0ACRFQANAUgQ0ACTliJi5F7Nn7sUwI86cOVO7dmRkpMFOmlV6O3aTx3rkyJGi+nvuuaeo/sCBA0X1mL6I8HjbGUEDQFIENAAkRUADQFIENAAkRUADQFIENAAkRUADQFIENAAkRUADQFIENAAkRUADQFKsxYG/snv37qL6FStW1K7t5LU4vvjii6L6kydP1q6dP39+aTuNmjVrVrtb6DmsxQEAHYaABoCkCGgASIqABoCkCGgASIqABoCkCGgASIqABoCkCGgASIqABoCkCGgASGp2uxtAs2655Zai+kWLFhXVl6yvkWktjm3bthXVv/LKK0X1X331Ve3a2267rWjfGzduLKov9dBDD9Wu3bp1a4OdgBE0ACRFQANAUpMGtO3ttk/YPjhm26W2X7X9QfVxbrNtAkDvqTOC/qWk5edse1zSaxFxraTXqucAgBaaNKAjYq+kL8/ZfKekHdXjHZJWtbgvAOh5U72K44
qIGJakiBi2fflEhbYHJA1M8XUAoGc1fpldRAxKGpT4lVcAUGKqV3Ectz1PkqqPJ1rXEgBAmnpA75K0pnq8RtJLrWkHAHBWncvsXpC0T9Ii28ds/1TSU5Jut/2BpNur5wCAFnLEzE0LMwfdGgsWLKhdu2/fvqJ9X3bZZUX1fX31fwgrvdX7yJEjRfU7d+6sXfvkk08W7fvUqVNF9SXmz59fVF96Tvv7+4vqv/3229q1TzzxRNG+n3vuuaL606dPF9V3qojweNu5kxAAkiKgASApAhoAkiKgASApAhoAkiKgASApAhoAkiKgASApAhoAkiKgASApAhoAkmItjg60cOHC2rVDQ0MNdlK2FseePXuK9r169eqi+s8//7yovlM9+uijRfWbN28uqm9yfZXrrruuqP6jjz4qqu9UrMUBAB2GgAaApAhoAEiKgAaApAhoAEiKgAaApAhoAEiKgAaApAhoAEiKgAaApAhoAEhqdrsbQGfbv39/7doHHnigaN+9srZGqV27dhXV33fffUX1N954Y1E9msMIGgCSIqABICkCGgCSIqABICkCGgCSIqABICkCGgCSIqABICkCGgCSIqABIClu9e5yfX3N/h980003Nbp/nM92UX3p90CT3zObNm0qqr///vubaaRDMIIGgKQIaABIioAGgKQIaABIioAGgKQIaABIioAGgKQIaABIioAGgKQIaABIioAGgKRYi6MDPfjgg7VrR0ZGGuwE7bBy5cqi+sWLFxfVl3zPlH5/la7F0esYQQNAUgQ0ACQ1aUDb3m77hO2DY7Ztsv2J7berP3c02yYA9J46I+hfSlo+zvb/iIgbqj+7W9sWAGDSgI6IvZK+nIFeAABjTGcO+hHb71ZTIHMnKrI9YHu/7f3TeC0A6DlTDeitkq6RdIOkYUnPTlQYEYMRsSQilkzxtQCgJ00poCPieESciYgRSb+QtLS1bQEAphTQtueNeXqXpIMT1QIApmbSOwltvyDpVkmX2T4m6WeSbrV9g6SQdFjS2gZ7BICeNGlAR8S942x+voFeAABjsBZHBypdiwEzr7+/v3bt9ddfX7TvDRs2lLbTmM8++6yo/vTp0w110p241RsAkiKgASApAhoAkiKgASApAhoAkiKgASApAhoAkiKgASApAhoAkiKgASApAhoAkmItDqABGzdurF378MMPN9hJucOHD9euXbNmTdG+jx49WthNb2MEDQBJEdAAkBQBDQBJEdAAkBQBDQBJEdAAkBQBDQBJEdAAkBQBDQBJEdAAkBS3egM17N69u6h+0aJFDXXSvEOHDtWuff311xvsBIygASApAhoAkiKgASApAhoAkiKgASApAhoAkiKgASApAhoAkiKgASApAhoAkiKgASAp1uLoQLZr1/b1Nft/8IoVKxrb9+DgYFH9lVde2VAn5X+PIyMjDXXSvJUrV7a7BVQYQQNAUgQ0ACRFQANAUgQ0ACRFQANAUgQ0ACRFQANAUgQ0ACRFQANAUgQ0ACRFQANAUqzF0YG2bt1au/bpp59usBPp5Zdfrl3b9PoUmda/yNTLtm3b2t0CpogRNAAkNWlA277a9h7bQ7bft72+2n6p7Vdtf1B9nNt8uwDQO+qMoL+T9K8R8XeS/lHSw7avl/S4pNci4lpJr1XPAQAtMmlAR8RwRLxVPf5G0pCkqyTdKWlHVbZD0qqmmgSAXlT0JqHtBZIWS3pD0hURMSyNhrjtyyf4mgFJA9NrEwB6T+2Atn2xpJ2SHouIr+v+Vo+IGJQ0WO0jptIkAPSiWldx2L5Io+H8q4j4XbX5uO151efnSTrRTIsA0JvqXMVhSc9LGoqIzWM+tUvSmurxGkkvtb49AOhddaY4bpZ0v6T3bL9dbdsg6SlJv7X9U0lHJd3dTIsA0JsmDeiIeF3SRBPO/9zadgAAZzli5t63403C1pg/f37t2n379hXtu7+/v6i+r6/+zaiZbn8uVXKcknT8+PHatUNDQ0X7HhgouyhqeH
i4qP7UqVNF9Zi+iBh3EMyt3gCQFAENAEkR0ACQFAENAEkR0ACQFAENAEkR0ACQFAENAEkR0ACQFAENAEkR0ACQFGtxdLlly5YV1a9aVfaby9avX1+7tpfW4li3bl3t2i1btpS2gy7DWhwA0GEIaABIioAGgKQIaABIioAGgKQIaABIioAGgKQIaABIioAGgKQIaABIioAGgKRYiwPTsnz58tq1AwMDRfteuXJlUf2uXbtq1w4ODhbt2x53qYQJHTp0qHbt0aNHi/aN7sNaHADQYQhoAEiKgAaApAhoAEiKgAaApAhoAEiKgAaApAhoAEiKgAaApAhoAEiKW70BoM241RsAOgwBDQBJEdAAkBQBDQBJEdAAkBQBDQBJEdAAkBQBDQBJEdAAkBQBDQBJEdAAkBQBDQBJEdAAkNSkAW37att7bA/Zft/2+mr7Jtuf2H67+nNH8+0CQO+YdLlR2/MkzYuIt2xfIumApFWSfiLpZET8vPaLsdwoAJxnouVGZ9f4wmFJw9Xjb2wPSbqqte0BAM5VNAdte4GkxZLeqDY9Yvtd29ttz21xbwDQ02oHtO2LJe2U9FhEfC1pq6RrJN2g0RH2sxN83YDt/bb3t6BfAOgZtX7lle2LJL0s6fcRsXmczy+Q9HJE/HCS/TAHDQDnmPKvvLJtSc9LGhobztWbh2fdJengdJsEAPy/Oldx/EjSHyS9J2mk2rxB0r0and4ISYclra3eULzQvhhBA8A5JhpB81u9AaDN+K3eANBhCGgASIqABoCkCGgASIqABoCkCGgASIqABoCkCGgASIqABoCkCGgASIqABoCkCGgASIqABoCkCGgASIqABoCkCGgASIqABoCkCGgASIqABoCkCGgASIqABoCkCGgASIqABoCkCGgASGr2DL/e55KOjLP9supz3Y7j7D69cqwcZ3PmT/QJR8RMNjJ+E/b+iFjS7j6axnF2n145Vo6zPZjiAICkCGgASCpLQA+2u4EZwnF2n145Vo6zDVLMQQMAzpdlBA0AOEdbA9r2ctt/tP2h7cfb2UvTbB+2/Z7tt23vb3c/rWJ7u+0Ttg+O2Xap7Vdtf1B9nNvOHlthguPcZPuT6py+bfuOdvbYCravtr3H9pDt922vr7Z31Tm9wHGmOqdtm+KwPUvSnyTdLumYpDcl3RsRh9rSUMNsH5a0JCK66lpS28sknZT0XxHxw2rb05K+jIinqv9450bEv7Wzz+ma4Dg3SToZET9vZ2+tZHuepHkR8ZbtSyQdkLRK0r+oi87pBY7zJ0p0Tts5gl4q6cOI+Dgi/iLpN5LubGM/mIKI2Cvpy3M23ylpR/V4h0a/8TvaBMfZdSJiOCLeqh5/I2lI0lXqsnN6geNMpZ0BfZWkP495fkwJ/4JaKCS9YvuA7YF2N9OwKyJiWBr9hyDp8jb306RHbL9bTYF09I/957K9QNJiSW+oi8/pOccpJTqn7Qxoj7Otmy8puTki/kHSCkkPVz8yo7NtlXSNpBskDUt6tr3ttI7tiyXtlPRYRHzd7n6aMs5xpjqn7QzoY5KuHvP8+5I+bVMvjYuIT6uPJyS9qNEpnm51vJrjOzvXd6LN/TQiIo5HxJmIGJH0C3XJObV9kUZD61cR8btqc9ed0/GOM9s5bWdAvynpWts/sP09Sasl7WpjP42xPad6I0K250j6saSDF/6qjrZL0prq8RpJL7Wxl8acDazKXeqCc2rbkp6XNBQRm8d8qqvO6UTHme2ctvVGleoSlv+UNEvS9oj497Y10yDbf6vRUbM0uoLgr7vlWG2/IOlWja4CdlzSzyT9t6TfSvobSUcl3R0RHf0G2wTHeatGfxQOSYclrT07T9upbP9I0h8kvSdppNq8QaPzs11zTi9wnPcq0TnlTkIASIo7CQEgKQIaAJIioAEgKQIaAJIioAEgKQIaAJIioAEgKQIaAJL6P1WfDJeAVvxdAAAAAElFTkSuQmCC\n",
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# pick some random input from here.\n",
"idx = indices[0]\n",
"\n",
"# Lets sanity check the picked image.\n",
"from matplotlib import pyplot as plt\n",
"%matplotlib inline\n",
"plt.rcParams['figure.figsize'] = (18, 6)\n",
"\n",
"#plt.imshow(test_images[idx][..., 0])\n",
"img = x_test[idx] * 255 \n",
"img = img.astype('float32')\n",
"img = np.squeeze(img) # trick to reduce img from (28,28,1) to (28,28)\n",
"plt.imshow(img, cmap='gray');"
]
},
{
"cell_type": "code",
"execution_count": 415,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"sequential_10\"\n",
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"conv2d_29 (Conv2D) (None, 26, 26, 32) 320 \n",
"_________________________________________________________________\n",
"conv2d_30 (Conv2D) (None, 24, 24, 64) 18496 \n",
"_________________________________________________________________\n",
"max_pooling2d_19 (MaxPooling (None, 12, 12, 64) 0 \n",
"_________________________________________________________________\n",
"dropout_2 (Dropout) (None, 12, 12, 64) 0 \n",
"_________________________________________________________________\n",
"flatten_10 (Flatten) (None, 9216) 0 \n",
"_________________________________________________________________\n",
"dense_19 (Dense) (None, 128) 1179776 \n",
"_________________________________________________________________\n",
"dropout_3 (Dropout) (None, 128) 0 \n",
"_________________________________________________________________\n",
"preds (Dense) (None, 10) 1290 \n",
"=================================================================\n",
"Total params: 1,199,882\n",
"Trainable params: 1,199,882\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"input_shape=(28, 28, 1)\n",
"num_classes = 10\n",
"\n",
"model = Sequential()\n",
"model.add(Conv2D(32, kernel_size=(3, 3),\n",
" activation='relu',\n",
" input_shape=input_shape))\n",
"model.add(Conv2D(64, (3, 3), activation='relu'))\n",
"model.add(MaxPooling2D(pool_size=(2, 2)))\n",
"model.add(Dropout(0.25))\n",
"model.add(Flatten())\n",
"model.add(Dense(128, activation='relu'))\n",
"model.add(Dropout(0.5))\n",
"model.add(Dense(num_classes, activation='softmax', name='preds'))\n",
"model.summary()"
]
},
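{
"cell_type": "markdown",
"metadata": {},
"source": [
"The parameter counts in the model summary can be checked by hand. A small sketch of the arithmetic; the helper names `conv2d_params` and `dense_params` are hypothetical, and both assume one bias per filter or unit, which is the Keras default:\n",
"\n",
"```python\n",
"def conv2d_params(kernel_h, kernel_w, in_channels, filters):\n",
"    # Each filter has kernel_h * kernel_w * in_channels weights plus one bias.\n",
"    return (kernel_h * kernel_w * in_channels + 1) * filters\n",
"\n",
"def dense_params(in_features, units):\n",
"    # Each unit has in_features weights plus one bias.\n",
"    return (in_features + 1) * units\n",
"\n",
"conv1 = conv2d_params(3, 3, 1, 32)    # 28x28x1  -> 26x26x32\n",
"conv2 = conv2d_params(3, 3, 32, 64)   # 26x26x32 -> 24x24x64\n",
"flat = 12 * 12 * 64                   # after 2x2 max pooling and Flatten\n",
"dense1 = dense_params(flat, 128)\n",
"preds = dense_params(128, 10)\n",
"\n",
"total = conv1 + conv2 + dense1 + preds\n",
"print(conv1, conv2, dense1, preds, total)\n",
"```\n",
"\n",
"Almost all of the parameters sit in the first dense layer, because it connects every one of the 9216 flattened activations to each of its 128 units."
]
},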
{
"cell_type": "code",
"execution_count": 416,
"metadata": {},
"outputs": [],
"source": [
"model.compile(loss=keras.losses.categorical_crossentropy,\n",
" optimizer=keras.optimizers.Adam(),\n",
" metrics=['accuracy'])"
]
},
{
"cell_type": "code",
"execution_count": 417,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"60000"
]
},
"execution_count": 417,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"num_samples = x_train.shape[0]\n",
"num_samples"
]
},
{
"cell_type": "code",
"execution_count": 418,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train on 48000 samples, validate on 12000 samples\n",
"Epoch 1/10\n",
"48000/48000 [==============================] - 60s 1ms/sample - loss: 0.2007 - accuracy: 0.9381 - val_loss: 0.0620 - val_accuracy: 0.9823\n",
"Epoch 2/10\n",
"48000/48000 [==============================] - 62s 1ms/sample - loss: 0.0851 - accuracy: 0.9741 - val_loss: 0.0476 - val_accuracy: 0.9871\n",
"Epoch 3/10\n",
"48000/48000 [==============================] - 62s 1ms/sample - loss: 0.0625 - accuracy: 0.9806 - val_loss: 0.0414 - val_accuracy: 0.9890\n",
"Epoch 4/10\n",
"48000/48000 [==============================] - 62s 1ms/sample - loss: 0.0527 - accuracy: 0.9839 - val_loss: 0.0438 - val_accuracy: 0.9875\n",
"Epoch 5/10\n",
"48000/48000 [==============================] - 62s 1ms/sample - loss: 0.0442 - accuracy: 0.9864 - val_loss: 0.0335 - val_accuracy: 0.9902\n",
"Epoch 6/10\n",
"48000/48000 [==============================] - 63s 1ms/sample - loss: 0.0380 - accuracy: 0.9875 - val_loss: 0.0359 - val_accuracy: 0.9907\n",
"Epoch 7/10\n",
"48000/48000 [==============================] - 65s 1ms/sample - loss: 0.0329 - accuracy: 0.9894 - val_loss: 0.0385 - val_accuracy: 0.9903\n",
"Epoch 8/10\n",
"48000/48000 [==============================] - 65s 1ms/sample - loss: 0.0286 - accuracy: 0.9910 - val_loss: 0.0396 - val_accuracy: 0.9904\n",
"Epoch 9/10\n",
"48000/48000 [==============================] - 65s 1ms/sample - loss: 0.0294 - accuracy: 0.9905 - val_loss: 0.0427 - val_accuracy: 0.9901\n",
"Epoch 10/10\n",
"48000/48000 [==============================] - 67s 1ms/sample - loss: 0.0254 - accuracy: 0.9915 - val_loss: 0.0456 - val_accuracy: 0.9894\n",
"CPU times: user 33min 28s, sys: 24min 14s, total: 57min 42s\n",
"Wall time: 10min 33s\n"
]
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 418,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"batch_size = 32\n",
"epochs = 10\n",
"\n",
"model.fit(x_train, y_train,\n",
" batch_size=batch_size,\n",
" epochs=epochs,\n",
" verbose=1,\n",
" validation_split=0.2,\n",
" shuffle=True)"
]
},
{
"cell_type": "code",
"execution_count": 419,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test loss: 0.03391535646364073\n",
"Test accuracy: 0.9909\n"
]
}
],
"source": [
"score = model.evaluate(x_test, y_test, verbose=0)\n",
"print('Test loss:', score[0])\n",
"print('Test accuracy:', score[1])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Let's look at the layers with `tf.keras.viz` \n",
"\n",
"https://pypi.org/project/tf-keras-vis/\n",
"\n",
"And an example: https://github.com/keisen/tf-keras-vis/blob/master/examples/visualize_conv_filters.ipynb"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can identify layers by their layer id:"
]
},
{
"cell_type": "code",
"execution_count": 638,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"('conv2d_29', 'dropout_3')"
]
},
"execution_count": 638,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Alternatively we can specify layer_id as -1 since it corresponds to the last layer.\n",
"layer_id = 0\n",
"model.layers[layer_id].name, model.layers[-2].name"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Or you may look at their output"
]
},
{
"cell_type": "code",
"execution_count": 639,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[]"
]
},
"execution_count": 639,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"output = [model.layers[layer_id].output]\n",
"output"
]
},
{
"cell_type": "code",
"execution_count": 640,
"metadata": {},
"outputs": [],
"source": [
"# # You may also replace part of your NN with other parts,\n",
"# # e.g. replace the activation function of the last layer\n",
"# # with a linear one\n",
"\n",
"# model.layers[-1].activation = tf.keras.activations.linear"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Generate Feature Maps"
]
},
{
"cell_type": "code",
"execution_count": 641,
"metadata": {},
"outputs": [],
"source": [
"def get_feature_maps(model, layer_id, input_image):\n",
" \"\"\"Returns intermediate output (activation map) from passing an image to the model\n",
" \n",
" Parameters:\n",
" model (tf.keras.Model): Model to examine\n",
" layer_id (int): Which layer's (from zero) output to return\n",
" input_image (ndarray): The input image\n",
" Returns:\n",
" maps (List[ndarray]): Feature map stack output by the specified layer\n",
" \"\"\"\n",
" model_ = Model(inputs=[model.input], outputs=[model.layers[layer_id].output]) \n",
" return model_.predict(np.expand_dims(input_image, axis=0))[0,:,:,:].transpose((2,0,1))"
]
},
{
"cell_type": "code",
"execution_count": 664,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(28, 28, 1)"
]
},
"execution_count": 664,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Choose an arbitrary image\n",
"image_id = 67\n",
"img = x_test[image_id,:,:,:]\n",
"img.shape"
]
},
{
"cell_type": "code",
"execution_count": 665,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
""
]
},
"execution_count": 665,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAWgAAAFlCAYAAADGe3ILAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAAQ00lEQVR4nO3db2hddZ7H8c9na0dKW7QyVuq/dVZEdhEb11IXWhd1MXat0vrAwT6QDg5WcYQKgusf0IJdsDK61SdCB4tdcByF6lrHxbEUobMgpVVCjUbHKh2t1sShgg2i1vrdBzmdzaZJc37Jvbnf3Pt+QUly8u25v+PtvOf05N5TR4QAAPn8TasXAAAYHYEGgKQINAAkRaABICkCDQBJEWgASOqkqXww27ymDwBGiAiPtp0zaABIalKBtr3M9ge299m+t1GLAgBInug7CW3PkPQnSVdLOiBpt6RVEfHeCX4PlzgAYIRmXOJYLGlfRHwcEd9L+p2kFZPYHwBgmMkE+ixJnw77+kC17f+xvcb2Htt7JvFYANBxJvMqjtFOyY+7hBERmyRtkrjEAQAlJnMGfUDSOcO+PlvS55NbDgDgmMkEerekC2z/zPZPJN0kaVtjlgUAmPAljoj4wfadkv4gaYakzRHxbsNWBgAdbsIvs5vQg3ENGgCOwzsJAWCaIdAAkBSBBoCkCDQAJEWgASApAg0ASRFoAEiKQANAUgQaAJIi0ACQFIEGgKQINAAkRaABICkCDQBJEWgASIpAA0BSBBoAkiLQAJAUgQaApAg0ACRFoAEgKQINAEkRaABIikADQFIEGgCSItAAkBSBBoCkCDQAJEWgASApAg0ASRFoAEiKQANAUgQaAJIi0ACQFIEGgKQINAAkRaABICkCDQBJEWgASIpAA0BSBBoAkiLQAJAUgQaApAg0ACRFoAEgKQINAEmd1OoFAMhl7ty5tWdfeOGFon0vW7asaH7hwoW1Z/fu3Vu07+mAM2gASIpAA0BSk7rEYXu/pMOSjkr6ISIWNWJRAIDGXIO+MiL+0oD9AACG4RIHACQ12UCHpNdtv2V7TSMWBAAYMtlLHEsi4nPb8yVtt/1+ROwcPlCFm3gDQKFJnUFHxOfVxwFJL0laPMrMpohYxA8QAaDMhANte7btucc+l9QtqbdRCwOATjeZSxxnSHrJ9rH9/DYiXmvIqgAAEw90RHwsqf77MAEARbgXB9DmVq1aVTRfcv+L7u7uon339/cXzX/11VdF8+2G10EDQFIEGgCSItAAkBSBBoCkCDQAJEWgASApAg0ASRFoAEiKQANAUgQaAJIi0ACQFPfiAKaZW2+9tWh+w4YNRfOnnHJK7dkjR44U7bv03h2ffvpp0Xy74QwaAJIi0ACQFIEGgKQINAAkRaABICkCDQBJEWgASIpAA0BSBBoAkiLQAJAUb/Vuc7NmzSqaX758edH8+++/X3u2t7e3aN+dYvbs2UXzt99+e9F8yVu3JWlgYKBpa9m7d2/RfKfjDBoAkiLQAJAUgQaApAg0ACRFoAEgKQINAEkRaABIikADQFIEGgCSItAAkBSBBoCkHBFT92D21D1YG1u8eHHt2Ycffrho35dffnnR/E033VR7dtu2bUX77hQ33nhj0bztovmNGzcWzZfcM6W7u7to3xhdRIz6pHIGDQBJEWgASIpAA0BSBBoAkiLQAJAUgQaApAg0ACRFoAEgKQINAEkRaABIikADQFIntXoBkM4888yi+VdeeaX27Ouvv1607yuvvLJofteuXUXzON7AwEDR/M0331w0X3rvjgcffLBoHs3DGTQAJEWgASCpcQNte7PtAdu9w7adZnu77Q+rj/Oau0wA6Dx1zqCfkbRsxLZ7Je2IiAsk7ai+BgA00LiBjoidkg6N2LxC0pbq8y2SVjZ4XQDQ8Sb6Ko4zIuKgJEXEQdvzxxq0vUbSmgk+DgB0rKa/zC4iNknaJPFPXg
FAiYm+iqPf9gJJqj6WvZATADCuiQZ6m6TV1eerJb3cmOUAAI6p8zK75yS9KelC2wds/1LSI5Kutv2hpKurrwEADeSIqbssPJ2vQc+YMaP27Pr164v2ffLJJxfNX3XVVbVnr7nmmqJ99/f3F81j8t57772i+QsvvLBofuPGjUXzd999d9E8Ji8iRn0/Pu8kBICkCDQAJEWgASApAg0ASRFoAEiKQANAUgQaAJIi0ACQFIEGgKQINAAkRaABIKmm3w+6XVx00UW1Z++5554mrkTq7u6uPcu9NVpjyZIltWdnzZrVxJVIg4ODTd0/moczaABIikADQFIEGgCSItAAkBSBBoCkCDQAJEWgASApAg0ASRFoAEiKQANAUgQaAJLiXhxAE3z77be1Zz/66KOifZ977rlF8729vUXzyIMzaABIikADQFIEGgCSItAAkBSBBoCkCDQAJEWgASApAg0ASRFoAEiKQANAUrzVu6b58+e3egl/tXLlytqzO3fuLNr3kSNHSpfTERYuXFg0f/jw4dqzl156adG+e3p6iuZfffXVonnkwRk0ACRFoAEgKQINAEkRaABIikADQFIEGgCSItAAkBSBBoCkCDQAJEWgASApAg0ASXEvjpoiotVL+Ks77rij9mxXV1fRvp955pmi+eeff7727ODgYNG+Sy1durT2bOnzuW7duqL5s88+u/bsnDlzivb9ySefFM1/8803RfPIgzNoAEiKQANAUuMG2vZm2wO2e4dtW2f7M9s91a9rm7tMAOg8dc6gn5G0bJTt/xERXdWv/27ssgAA4wY6InZKOjQFawEADDOZa9B32t5bXQKZN9aQ7TW299jeM4nHAoCOM9FAPyXpfEldkg5KemyswYjYFBGLImLRBB8LADrShAIdEf0RcTQifpT0G0mLG7ssAMCEAm17wbAvb5DUO9YsAGBixn0noe3nJF0h6ae2D0h6SNIVtrskhaT9km5r4hoBoCONG+iIWDXK5qebsBYAwDCeyntM2M5zQ4tCtmvPPvDAA0X7Lr3PQ8lamu3o0aO1Z5v9Z23GjBlN23em/+Zbtmwpmr/llluatBI0SkSM+geMt3oDQFIEGgCSItAAkBSBBoCkCDQAJEWgASApAg0ASRFoAEiKQANAUgQaAJIi0ACQ1Lg3S8KQkvtIrF+/vmjf3333XdH89ddfX3t2yZIlRfsu1cz7X5Q6dKj+v8y2b9++on0vXty8W55v3bq1aP6JJ55o0kqQDWfQAJAUgQaApAg0ACRFoAEgKQINAEkRaABIikADQFIEGgCSItAAkBSBBoCkXPIW5kk/mD11D9bGZs6cWXv29NNPL9r3fffdVzQ/MDBQe/bJJ58s2nepo0eP1p79/vvvi/bd1dVVNP/mm2/Wnt2+fXvRvteuXVs0/8EHHxTNY+pFhEfbzhk0ACRFoAEgKQINAEkRaABIikADQFIEGgCSItAAkBSBBoCkCDQAJEWgASApAg0ASXEvDqCGOXPmFM0/++yztWevu+66on0vX768aP61114rmsfU414cADDNEGgASIpAA0BSBBoAkiLQAJAUgQaApAg0ACRFoAEgKQINAEkRaABIikADQFIntXoBwHTw0EMPFc0vXbq0SStBJ+EMGgCSGjfQts+x/YbtPtvv2l5bbT/N9nbbH1Yf5zV/uQDQOeqcQf8g6e6I+HtJ/yTpV7b/QdK9knZExAWSdlRfAwAaZNxAR8TBiHi7+vywpD5JZ0laIWlLNbZF0spmLRIAOlHRDwltnyfpEkm7JJ0REQeloYjbnj/G71kjac3klgkAnad2oG3PkbRV0l0R8bU96j8AcJyI2CRpU7UP/kUVAKip1qs4bM/UUJyfjYgXq839thdU318gaaA5SwSAzlTnVRyW9LSkvoh4fNi3tklaXX2+WtLLjV8eAHSuOpc4lki6WdI7tnuqbfdLekTSC7Z/KekTSTc2Z4kA0JnGDXRE/I+ksS44/0tjlwMAOIa3egM1XHbZZU
Xzp556au3ZL7/8smjfPT094w+hLfBWbwBIikADQFIEGgCSItAAkBSBBoCkCDQAJEWgASApAg0ASRFoAEiKQANAUgQaAJLiXhxADbNmzSqa37FjR+3Z3bt3F+37iy++KJrH9MUZNAAkRaABICkCDQBJEWgASIpAA0BSBBoAkiLQAJAUgQaApAg0ACRFoAEgKQINAElxLw6ghosvvrho/sUXX6w929fXV7ocdAjOoAEgKQINAEkRaABIikADQFIEGgCSItAAkBSBBoCkCDQAJEWgASApAg0ASfFWb6CGDRs2FM0/+uijtWcHBwdLl4MOwRk0ACRFoAEgKQINAEkRaABIikADQFIEGgCSItAAkBSBBoCkCDQAJEWgASApAg0ASTkipu7B7Kl7MACYJiLCo23nDBoAkho30LbPsf2G7T7b79peW21fZ/sz2z3Vr2ubv1wA6BzjXuKwvUDSgoh42/ZcSW9JWinp55IGI+LXtR+MSxwAcJyxLnGMez/oiDgo6WD1+WHbfZLOauzyAAAjFV2Dtn2epEsk7ao23Wl7r+3Ntuc1eG0A0NFqB9r2HElbJd0VEV9LekrS+ZK6NHSG/dgYv2+N7T229zRgvQDQMWq9zM72TEm/l/SHiHh8lO+fJ+n3EXHROPvhGjQAjDDhl9nZtqSnJfUNj3P1w8NjbpDUO9lFAgD+T51XcSyV9EdJ70j6sdp8v6RVGrq8EZL2S7qt+oHiifbFGTQAjDDWGTTvJASAFuOdhAAwzRBoAEiKQANAUgQaAJIi0ACQFIEGgKQINAAkRaABICkCDQBJEWgASIpAA0BSBBoAkiLQAJAUgQaApAg0ACRFoAEgKQINAEkRaABIikADQFIEGgCSItAAkBSBBoCkCDQAJEWgASCpk6b48f4i6c+jbP9p9b12x3G2n045Vo6zef52rG84IqZyIaMvwt4TEYtavY5m4zjbT6ccK8fZGlziAICkCDQAJJUl0JtavYApwnG2n045Vo6zBVJcgwYAHC/LGTQAYISWBtr2Mtsf2N5n+95WrqXZbO+3/Y7tHtt7Wr2eRrG92faA7d5h206zvd32h9XHea1cYyOMcZzrbH9WPac9tq9t5RobwfY5tt+w3Wf7Xdtrq+1t9Zye4DhTPactu8Rhe4akP0m6WtIBSbslrYqI91qyoCazvV/Soohoq9eS2v5nSYOS/jMiLqq2PSrpUEQ8Uv0f77yI+LdWrnOyxjjOdZIGI+LXrVxbI9leIGlBRLxte66ktyStlPQLtdFzeoLj/LkSPaetPINeLGlfRHwcEd9L+p2kFS1cDyYgInZKOjRi8wpJW6rPt2joD/60NsZxtp2IOBgRb1efH5bUJ+kstdlzeoLjTKWVgT5L0qfDvj6ghP+BGigkvW77LdtrWr2YJjsjIg5KQ/9DkDS/xetppjtt760ugUzrv/aPZPs8SZdI2qU2fk5HHKeU6DltZaA9yrZ2fknJkoj4R0n/KulX1V+ZMb09Jel8SV2SDkp6rLXLaRzbcyRtlXRXRHzd6vU0yyjHmeo5bWWgD0g6Z9jXZ0v6vEVrabqI+Lz6OCDpJQ1d4mlX/dU1vmPX+gZavJ6miIj+iDgaET9K+o3a5Dm1PVND0Xo2Il6sNrfdczracWZ7TlsZ6N2SLrD9M9s/kXSTpG0tXE/T2J5d/SBCtmdL6pbUe+LfNa1tk7S6+ny1pJdbuJamORasyg1qg+fUtiU9LakvIh4f9q22ek7HOs5sz2lL36hSvYRlo6QZkjZHxL+3bDFNZPvvNHTWLA3dQfC37XKstp+TdIWG7gLWL+khSf8l6QVJ50r6RNKNETGtf8A2xnFeoaG/Coek/ZJuO3addrqyvVTSHyW9I+nHavP9Gro+2zbP6QmOc5USPae8kxAAkuKdhACQFIEGgKQINAAkRaABICkCDQBJEWgASIpAA0BSBBoAkvpf+4znI+zDUfAAAAAASUVORK5CYII=\n",
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"img_to_show = np.squeeze(img)\n",
"plt.imshow(img_to_show, cmap='gray')"
]
},
{
"cell_type": "code",
"execution_count": 666,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(1, 28, 28, 1)\n",
"Prediction is: 4\n"
]
}
],
"source": [
"# Was this successfully predicted?\n",
"img_batch = (np.expand_dims(img,0))\n",
"print(img_batch.shape)\n",
"predictions_single = model.predict(img_batch)\n",
"print(f'Prediction is: {np.argmax(predictions_single[0])}') "
]
},
{
"cell_type": "code",
"execution_count": 667,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(32, 26, 26)"
]
},
"execution_count": 667,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# layer id should be for a Conv layer, a Flatten will not do\n",
"maps = get_feature_maps(model, layer_id, img)# [0:10]\n",
"maps.shape"
]
},
{
"cell_type": "code",
"execution_count": 668,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAWgAAAF1CAYAAADFrXCQAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAAUgklEQVR4nO3dfZBddX3H8c8nySYZw1MiD6YhPBhCC1UbnBWtVER5VGsDHVHTFkLHTuggY2ixgpQp0RGKVtTiAxoKNbUqYIGCDioMOiIUUxYmQsgSnowQsiYgKEkKm2Tz7R/3rN7EfThn99693937fs3s7N1zv/d3vidn89lfzt7ziyNCAIB8JrW6AQDAwAhoAEiKgAaApAhoAEiKgAaApAhoAEiKgMa4Z/s42+tL1p5l++4R7mfErwVGgoBGw9leZ/uEVvcx3ti+xHbwZ4d+BDTGnO0pre4hG9vzJL1HUk+re0EeBDQayvbXJB0k6du2t9j+iO1DipnhB2w/JekHA12WqJ95255k+0LbT9j+pe0bbM8q2UP/6zbbXmP7tN8t8edt/9r2I7aPr3tib9vX2O6x/YztT9iePLo/lVK+IOkCSdvGYF8YJwhoNFREnCHpKUnvjog9IuJTdU+/VdIRkk4uMdSHJJ1avOb3JL0g6Ysl23hC0lsk7S3pY5L+0/bsuuffKOlJSftKukTSTXXhv0LSDkmHSTpK0kmS/qbMTm3/aoiPC4d43emStkXEbSWPD22Cf2piLC2LiK2SZHu42rMlnRsR64v6ZZKesn1GROwY6oUR8a26L6+3/VFJR0u6pdi2SdLnorYQzfW2z5f0Ltu3S3qHpH0i4iVJW21/VtISSV8ZruGI2Ge4mt3Z3kPSZar9IAB2QUBjLD1dofZgSTfb3lm3rU/SAZKeGeqFts+U9PeSDik27aHabLnfM7HrKmE/V22WfrCkDkk9dT9AJlXsu6qPSfpaRPysifvAOMUlDjTDYEsk1m/fKukV/V8U13n3q3v+aUnviIh96j6mR8Rw4XywpKslnSvplcWsdrWk+in7HO86hT9I0oZin72S9q3b514R8YdDHu1v971liI+LBnnZ8ZI+ZPsXtn8haa6kG2xfUGafmNgIaDTDRkmvHqbmUUnTbb/LdoekiyVNq3v+y5IuLQJXtvezvbDEvmeo9oPg2eJ1fy3pNbvV7K9aKHYU13+PkHRbRPRIul3SFbb3Kn5ROc/2W0vsV8U198E+LhvkZccX/S0oPjaodnmn7PV2TGAENJrhnyVdXPxy7MMDFUTEryWdI+nfVLtksVVS/bs6/lXSrZJut71Z0k9U++XekCJijaQrJN2r2g+K10q6Z7eylZLmS3pO0qWS3hMRvyyeO1PSVElrVPvF5H9Jmq0miYhfRsQv+j9Uu4zzQkRsadY+MX6YBfsBICdm0ACQFAENAEkR0ACQFAENAEkR0ACQ1JjeSTjV02K6ZozlLgEgtZe1Vduid8C1D8Y0oKdrht7424XDAKDtrYw7B31uVJc4bJ9ie63tx4darQsAUN2IA7pYO+GLqq3+daSkRbaPbFRjANDuRjODPlrS4xHxZERsk3SdpDJrJQAAShhNQM/Rrsswri+27cL2Ettdtru2q3cUuwOA9jKagB7ot46/s7BHRCyPiM6I6OzYZbEyAMBQRhPQ61Vbu7bfgaotlQgAaIDRBPR9kubbPtT2VEnvV215SABAA4z4fdARscP2uZK+L2mypGsj4uGGdQYAbW5UN6oU/wsx/xMxADQBa3EAQFIENAAkRUADQFIENAAkRUADQFIENAAkRUADQFIENAAkRUADQFIENAAkRUADQFIENAAkRUADQFIENAAkRUADQFIENAAkRUADQFIENAAkRUADQFIENAAkRUADQFIENAAkRUADQFIENAAkRUADQFIENAAkRUADQFIENAAkRUADQFIENAAkRUADQFIENAAkRUADQFIENA
AkRUADQFIENAAkRUADQFIENAAkRUADQFIENAAkRUADQFIENAAkRUADQFIENAAkRUADQFIENAAkNaXVDQDIZfIrZ5WunXLT1Epj3zr/e5XqT3zvWaVrJ929qtLY4wEzaABIioAGgKRGdYnD9jpJmyX1SdoREZ2NaAoA0Jhr0G+LiOcaMA4AoA6XOAAgqdEGdEi63fb9tpc0oiEAQM1oL3EcExEbbO8v6Q7bj0TEXfUFRXAvkaTpesUodwcA7WNUM+iI2FB83iTpZklHD1CzPCI6I6KzQ9NGszsAaCsjDmjbM2zv2f9Y0kmSVjeqMQBod6O5xHGApJtt94/zjYiodpsQAGBQIw7oiHhS0h81sBcAQB3W4gAmuE3nvLlS/a9eu6N07aOHXVVp7HtedqX6jme3lK7tqzTy+MD7oAEgKQIaAJIioAEgKQIaAJIioAEgKQIaAJIioAEgKQIaAJIioAEgKQIaAJIioAEgKdbiAMaZDR+ptrbGdedcUan+8I6ppWu37OytNPayM8+pVD9p7apK9RMNM2gASIqABoCkCGgASIqABoCkCGgASIqABoCkCGgASIqABoCkCGgASIqABoCkuNV7gpu0556V6p99/2sq1c9a81LpWt/T3rftDmbyPntXqv/zv/xRpfoqt25L0r29k0vXfvjj51cae+bd91aqb3fMoAEgKQIaAJIioAEgKQIaAJIioAEgKQIaAJIioAEgKQIaAJIioAEgKQIaAJIioAEgKdbiGIe2n9RZunafi39eaezbDv10pfo/vfjDpWtn3lNp6Lax8b1HVqr/zpeqjf+2j3RXqv/ChreXrp35VdbWaCZm0ACQFAENAEkR0ACQFAENAEkR0ACQFAENAEkR0ACQFAENAEkR0ACQFAENAEkR0ACQFGtxJDD5sEMr1X/sy18pXbvkwb+qNPb7Ll1aqX7m91mLYbRmbOqrVL9x0cuV6id5Z6X65y4r//04Tc9VGhvVMIMGgKQIaABIatiAtn2t7U22V9dtm2X7DtuPFZ9nNrdNAGg/ZWbQX5V0ym7bLpR0Z0TMl3Rn8TUAoIGGDeiIuEvS87ttXihpRfF4haRTG9wXALS9kV6DPiAieiSp+Lz/YIW2l9just21Xb0j3B0AtJ+m/5IwIpZHRGdEdHZoWrN3BwATxkgDeqPt2ZJUfN7UuJYAANLIA/pWSYuLx4sl3dKYdgAA/cq8ze6bku6V9Pu219v+gKTLJZ1o+zFJJxZfAwAaaNhbvSNi0SBPHd/gXlLzlPJ3xa+98vXVBp9a7VbcH2w5snTt3PO2Vhp7x7ruSvUYvfdd9t1K9Uv2Xlep/vDb/rZa/Xfvq1SP5uFOQgBIioAGgKQIaABIioAGgKQIaABIioAGgKQIaABIioAGgKQIaABIioAGgKQIaABIqvwCE22u749fW7r20YVXNbET6W1LzyldO2PdyiZ2gsH0vusNpWtfNeX6JnYiTdo6uanjo3mYQQNAUgQ0ACRFQANAUgQ0ACRFQANAUgQ0ACRFQANAUgQ0ACRFQANAUgQ0ACRFQANAUqzFATTB5Jd2lq791rOdlcb+sxl3VKrfZ40r1SMPZtAAkBQBDQBJEdAAkBQBDQBJEdAAkBQBDQBJEdAAkBQBDQBJEdAAkBQBDQBJcat3Sf83e1qrW/iNDe/eXrr28G9X6zt6e6u20xZ2vuWoSvUdv365dO0lc75TaexPPPeGSvUHXPdwpfq+StVoJmbQAJAUAQ0ASRHQAJAUAQ0ASRHQAJAUAQ0ASRHQAJAUAQ0ASRHQAJAUAQ0ASRHQAJAUa3GUFdHqDn5j7QlXl65devcxlcb+8fVvrlQ/95ru0rV9L7xQaeyqXn730aVrvbPa+dz/o09Wqn/7rEdK1x48pdpfwx9tnF+pftqL6yrVIw9m0ACQFAENAEkNG9C2r7W9yfbqum3LbD9je1Xx8c7mtgkA7afMDPqrkk4ZYPtnI2JB8XFbY9sCAAwb0BFxl6Tnx6AXAECd0V
yDPtf2g8UlkJmDFdleYrvLdtd28b91AEBZIw3oqyTNk7RAUo+kKwYrjIjlEdEZEZ0dyvPfRgFAdiMK6IjYGBF9EbFT0tWSyr8BFQBQyogC2vbsui9Pk7R6sFoAwMgMewuT7W9KOk7SvrbXS7pE0nG2F0gKSesknd3EHgGgLQ0b0BGxaIDN1zShFwBAHccYrjGxl2fFG338mO2voezSpU986k2Vhn74L66sVD8p0Q2gvbG9dG2fmvu9Nt3NW1om05/56/7nrEr1B53+UHMaQUOsjDv1Yjw/YMDk+a4DAOyCgAaApAhoAEiKgAaApAhoAEiKgAaApAhoAEiKgAaApAhoAEiKgAaApAhoAEiqeYsXTDQV1iyZ9w/3Vhr6dds+VKn+2BMeLF37pQPvqjR2VdPc0dTxq3hwW1/p2q89/+ZKY//Lq1ZWbae0t/z0fZXqD7ySv7btghk0ACRFQANAUgQ0ACRFQANAUgQ0ACRFQANAUgQ0ACRFQANAUgQ0ACRFQANAUtwzmsAh/1jt1vCnPz6tdO3Cue+pNPYj/zSzUv2kTVNL1x7+yScqjV3Zjh2lS+OllyoNffKxZ1eq//6/f6V07atmbK409ssvTa9UX36RAmTDDBoAkiKgASApAhoAkiKgASApAhoAkiKgASApAhoAkiKgASApAhoAkiKgASApAhoAkmItjnEoentL1/Y9/rNKY88/s1p9FX1NG7n5pt9XbR2RU7pPK137vSNurjT2sfM/WKl+r65K5UiEGTQAJEVAA0BSBDQAJEVAA0BSBDQAJEVAA0BSBDQAJEVAA0BSBDQAJEVAA0BSBDQAJMVaHEAJ3ZcfVqn+lsM+X6G6o1ozaBvMoAEgqWED2vZc2z+03W37YdtLi+2zbN9h+7Hi88zmtwsA7aPMDHqHpPMj4ghJb5L0QdtHSrpQ0p0RMV/SncXXAIAGGTagI6InIh4oHm+W1C1pjqSFklYUZSskndqsJgGgHVW6Bm37EElHSVop6YCI6JFqIS5p/0Fes8R2l+2u7Sq/0DwAtLvSAW17D0k3SjovIl4s+7qIWB4RnRHR2aFpI+kRANpSqYC23aFaOH89Im4qNm+0Pbt4frakTc1pEQDaU5l3cVjSNZK6I+IzdU/dKmlx8XixpFsa3x4AtK8yN6ocI+kMSQ/ZXlVsu0jS5ZJusP0BSU9JOr05LQJAexo2oCPibkke5OnjG9sOAKAft3oDJZywYE2l+iM6yt++/b+9g81/BjbrJz2V6ndUqkYm3OoNAEkR0ACQFAENAEkR0ACQFAENAEkR0ACQFAENAEkR0ACQFAENAEkR0ACQFAENAEmxFgdQwn5TN1eqX/TkyaVrH/jpvEpjz//Zykr1GL+YQQNAUgQ0ACRFQANAUgQ0ACRFQANAUgQ0ACRFQANAUgQ0ACRFQANAUgQ0ACRFQANAUqzFAZRwwX7V1r94+/rFpWv3Wju5ajtoE8ygASApAhoAkiKgASApAhoAkiKgASApAhoAkiKgASApAhoAkiKgASApAhoAkuJWb6CEo278u0r1f7BsbenavhcerdoO2gQzaABIioAGgKQIaABIioAGgKQIaABIioAGgKQIaABIioAGgKQIaABIioAGgKQIaABIirU4gBLmL/1Jpfq+JvWB9sIMGgCSGjagbc+1/UPb3bYftr202L7M9jO2VxUf72x+uwDQPspc4tgh6fyIeMD2npLut31H8dxnI+LTzWsPANrXsAEdET2SeorHm213S5rT7MYAoN1VugZt+xBJR0laWWw61/aDtq+1PbPBvQFAWysd0Lb3kHSjpPMi4kVJV0maJ2mBajPsKwZ53RLbXba7tqu3AS0DQHsoFdC2O1QL569HxE2SFBEbI6IvInZKulrS0QO9NiKWR0RnRHR2aFqj+gaACa/Muzgs6RpJ3RHxmbrts+vKTpO0uvHtAUD7KvMujmMknSHpIdurim0XSVpke4GkkLRO0tlN6RAA2lSZd3HcLckDPHVb49
sBAPTjTkIASIqABoCkCGgASIqABoCkCGgASIqABoCkCGgASIqABoCkCGgASIqABoCkCGgASIqABoCkCGgASIqABoCkCGgASIqABoCkCGgASIqABoCkCGgASIqABoCkCGgASIqABoCkCGgASIqABoCkHBFjtzP7WUk/H+CpfSU9N2aNtA7HOfG0y7FynM1zcETsN9ATYxrQg7HdFRGdre6j2TjOiaddjpXjbA0ucQBAUgQ0ACSVJaCXt7qBMcJxTjztcqwcZwukuAYNAPhdWWbQAIDdtDSgbZ9ie63tx21f2Mpems32OtsP2V5lu6vV/TSK7Wttb7K9um7bLNt32H6s+DyzlT02wiDHucz2M8U5XWX7na3ssRFsz7X9Q9vdth+2vbTYPqHO6RDHmeqctuwSh+3Jkh6VdKKk9ZLuk7QoIta0pKEms71OUmdETKj3kto+VtIWSf8REa8ptn1K0vMRcXnxg3dmRFzQyj5Ha5DjXCZpS0R8upW9NZLt2ZJmR8QDtveUdL+kUyWdpQl0Toc4zvcq0Tlt5Qz6aEmPR8STEbFN0nWSFrawH4xARNwl6fndNi+UtKJ4vEK1b/xxbZDjnHAioiciHigeb5bULWmOJtg5HeI4U2llQM+R9HTd1+uV8A+ogULS7bbvt72k1c002QER0SPV/iJI2r/F/TTTubYfLC6BjOt/9u/O9iGSjpK0UhP4nO52nFKic9rKgPYA2ybyW0qOiYjXS3qHpA8W/2TG+HaVpHmSFkjqkXRFa9tpHNt7SLpR0nkR8WKr+2mWAY4z1TltZUCvlzS37usDJW1oUS9NFxEbis+bJN2s2iWeiWpjcY2v/1rfphb30xQRsTEi+iJip6SrNUHOqe0O1ULr6xFxU7F5wp3TgY4z2zltZUDfJ2m+7UNtT5X0fkm3trCfprE9o/hFhGzPkHSSpNVDv2pcu1XS4uLxYkm3tLCXpukPrMJpmgDn1LYlXSOpOyI+U/fUhDqngx1ntnPa0htVirewfE7SZEnXRsSlLWumiWy/WrVZsyRNkfSNiXKstr8p6TjVVgHbKOkSSf8t6QZJB0l6StLpETGuf8E2yHEep9o/hUPSOkln91+nHa9s/4mkH0t6SNLOYvNFql2fnTDndIjjXKRE55Q7CQEgKe4kBICkCGgASIqABoCkCGgASIqABoCkCGgASIqABoCkCGgASOr/AeCx68Pwssv6AAAAAElFTkSuQmCC\n",
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"image/png": "\n",
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Plot just a subset\n",
"maps = get_feature_maps(model, layer_id, img)[0:10]\n",
"\n",
"fig, ax = plt.subplots()\n",
"img = np.squeeze(img)\n",
"ax.imshow(img + 0.5)\n",
"label = y_test[image_id,:]\n",
"label = int(np.where(label == 1.)[0])\n",
"\n",
"ax.set_title(f'true label = {label}')\n",
"\n",
"f, ax = plt.subplots(3,3, figsize=(8,8))\n",
"for i, axis in enumerate(ax.ravel()):\n",
" axis.imshow(maps[i], cmap='gray')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### `tf_keras_vis.gradcam.Gradcam`\n",
"\n",
"[Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization](https://arxiv.org/pdf/1610.02391.pdf)"
]
},
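{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a hedged summary of the paper linked above (the notation follows the paper and is not taken from this lab's code): for a class score $y^c$ taken before the softmax, and the feature maps $A^k$ of a chosen convolutional layer, Grad-CAM first averages the gradients over all spatial positions to get one weight per feature map, then forms a ReLU-rectified weighted sum of the maps:\n",
"\n",
"$$\\alpha_k^c = \\frac{1}{Z} \\sum_i \\sum_j \\frac{\\partial y^c}{\\partial A^k_{ij}}, \\qquad L^c_{\\text{Grad-CAM}} = \\mathrm{ReLU}\\left( \\sum_k \\alpha_k^c A^k \\right)$$\n",
"\n",
"Here $Z$ is the number of spatial positions $(i, j)$ in a feature map. Working with the pre-softmax score is presumably why the cells below swap the last layer's softmax for a linear activation via `model_modifier`."
]
},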
{
"cell_type": "code",
"execution_count": 669,
"metadata": {},
"outputs": [],
"source": [
"#from tensorflow.keras import backend as K\n",
"# Define modifier to replace a softmax function of the last layer to a linear function.\n",
"def model_modifier(m):\n",
" m.layers[-1].activation = tf.keras.activations.linear"
]
},
{
"cell_type": "code",
"execution_count": 670,
"metadata": {},
"outputs": [],
"source": [
"#img_batch = (np.expand_dims(img,0))\n",
"# Define modifier to replace a softmax function of the last layer to a linear function.\n",
"def model_modifier(m):\n",
" m.layers[-1].activation = tf.keras.activations.linear\n",
"\n",
"# Create Saliency object\n",
"saliency = Saliency(model, model_modifier)\n",
"\n",
"# Define loss function. Pass it the correct class label.\n",
"loss = lambda output: tf.keras.backend.mean(output[:, tf.argmax(y_test[image_id])])"
]
},
{
"cell_type": "code",
"execution_count": 671,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(1, 28, 28, 1)\n"
]
}
],
"source": [
"# Generate saliency map\n",
"print(img_batch.shape)"
]
},
{
"cell_type": "code",
"execution_count": 679,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
""
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"saliency_map = saliency(loss, img_batch)\n",
"\n",
"saliency_map = normalize(saliency_map)\n",
"\n",
"f, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 5)) #, subplot_kw={'xticks': [], 'yticks': []})\n",
"ax[0].imshow(saliency_map[i], cmap='jet')\n",
"ax[1].imshow(img);"
]
},
{
"cell_type": "code",
"execution_count": 686,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAASEAAAEhCAYAAAAwHRYbAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAAL4klEQVR4nO3cW4xdZRkG4H+3QztTylA7AwUplFYFDWlATVCj4iEEY4x4iIZg0Bhj4gXBxFtvveBeTfSCaDzFRE2IJh4IMeI5oAIiolRap4XSQoexFNtph7bbC65Qkv+dssevM/M8l83bb63Zu/Ou1eTLPxgOhw2gyprqGwBWNyUElFJCQCklBJRSQkApJQSUGltMeO1gw3CsbVqqewFWsIV2YHY4HF7w33++qBIaa5va1vbp0d0VsGrsaV/Y+1J/7r9jQCklBJRSQkApJQSUUkJAKSUElFJCQKlF7QnB8nU6zI3yuVxxzeVndf/0QDklBJRSQkApJQSUUkJAKSUElFJCQCklBJSyrMhZatSLfqN83qazJsLcXJA5Gc4a9a/0+hHP+1/ehIBSSggopYSAUkoIKKWEgFJKCCilhIBSSggopYSAUjamOUuN+vmYbhwnW86XhrPOyWIbdvQzC+El0x8zlmxzH3hZV/AmBJRSQkApJQSUUkJAKSUElFJCQCklBJRSQkApJQSUsjH9f5Gel5xYLc+NE2FuS5i75ExvZOlNB5n0n1D6G70uzO2aDEJPhcNe2mr5Fw2cpZQQUEoJAaWUEFBKCQGllBBQSgkBpZQQUEoJAaVW4cb0KLeXU8nHnN7XqO9/lM+hUa/1Bq4IN6HfG97b3f3PYzCTzVpz09Eo97GpP3cz3/zee6JZg6eza7ZjWezlnh+d8CYElFJCQCklBJRSQkApJQSUUkJAKSUElFJCQKllsKyYLsClfZrk0lnpvS2EucSov7LkZxjlZ9ta9nlcGU2avvlglHtdeyTK/ez1n+xmju84FM26bfPvo9y+k/3NwTXH1kezhsfCZcX2WJibD3NnzpsQUEoJAaWUEFBKCQGllBBQSgkBpZQQUEoJAaWUEFBqiTamR7mFm24lbx5hLt06fT7MzQaZ9PO4OMyln9vSH9/5v070I5edG006rz0Z5X7+5o9EuQ8f+XE3M3XRU9GsheEFUe6en/XvbRB/T+m/3SNhLtvUfjm8CQGllBBQSgkBpZQQUEoJAaWUEFBKCQGllBBQSgkBpZZoY3qU3ZZtzrbJS7LcZUHm4Yls1rosNrju8m7m6PYLo1kTvziZXfOxv0a5bCM23QxPXdqPXJtNmtm7M8rtiLbWW5uanu5mnjiUnRl+1z3XRbmJ08E2dLrg3PaFuaXfhE55EwJKKSGglBICSikhoJQSAkopIaCUEgJKKSGg1BItKyYLdeFCYOsvj7XW4sXB+NTTwKlL/x3lxoOjSm8573fRrO9s3B7lJtpclGvtVJBJP7Qwtyk4onZtNupoeITqrv6pra211ra/9VA3c+9zl0ezJg4/m100+dgWsmXL1ubDXLqsOMqjmpfibwO8TEoIKKWEgFJKCCilhIBSSggopYSAUkoIKKWEgFKL3JgetmiDcvLqfub68JLpaZV/D3OP9CODzceiUdd/9Lko96O9/c3Z7//qeDRr4rE/RbncKJfmk+3r1toV/ciW1+6PRj2xJTvW9+h4dm+DwUw3c+zX745mjV1+NMq1+5JQf5P7BeeEudTSv6d4EwJKKSGglBICSikhoJQSAkopIaCUEgJKKSGglBICSi1yXXYQ/ZXJz/XPN/5QuzO64p07PxTljtyfnTX8zPuCzdOJndGsPTv3RrnzvzrsZk4f/ks0q7WNYe5EmEueQ8mZ4bmxd/W3w69pD0SzrvhItml+6foNUe5LP3x/NzN16OloVtuSxVq0oJ9+B8vvvWL53TGwoighoJQSAkopIaCUEgJKKSGglBICSikhoJQSAk
otcmN6TWutv3l6bbu3m9mQrYm27W1PlHvw81dFudve9EQ3M/amX0ez7rjxuii37vDdQWptNKu158NcKjgzPH5WTUSpNWv617zvNTdGs3Y8+kyUa9PZZvXglev7ocuyS7aFMBd9ByuXNyGglBICSikhoJQSAkopIaCUEgJKKSGglBICSi3J8a4H28XdzO726uiKN7S7otyNd90f5W65qZ/72vgnoln5Mlpi1M+DdN7/f1lx4fZ13czx110UzZr/7HiUu3Lr7ig3Ph+ENkejWnsyzLUjaXBF8iYElFJCQCklBJRSQkApJQSUUkJAKSUElFJCQCklBJRa5Mb06dZaf6X04T9c0x/VPyW2tdbalvHvRrm7j22Ncjdv6Wce/+v2aNZz2w5HuamHgq4/WfU8GOV1s83f4bbzupk1z+6NZr1z1yNR7leP7Ixy5+451M0Mz4lGtXYwzLW5IJMeATvKTfnFzDtz3oSAUkoIKKWEgFJKCCilhIBSSggopYSAUkoIKKWEgFKL3JgettZO9GM/CUZNZ1f88q2fiXLX3/hQlPvi+LX90A/6W7OttXbbjr9FuZ9+6pJuZt/D/UxrrU0+kGzXtjacTw5Lzpy8Mru3NjwVxc59+2w3s31iXzRr05rsu5p5Nvs8xnYHz+Ud0ahFHB0d/E6VnUG+9M6eOwFWJSUElFJCQCklBJRSQkApJQSUUkJAKSUElDqDZcXkWMhd/chsdum527PNsO9ddEM2744Luplbd94RzRq09VHuvec93g+9ZSaadeotUSw+vLO1/oLhWHsynJU90wZtGKSyxcfWjmap49m9nZ+EslN9W2sLaXBV8yYElFJCQCklBJRSQkApJQSUUkJAKSUElFJCQCklBJRa5Mb0oGW9lRxXGZ59ubA2y+17Popt3vebbuYrp94dzdq240CUe9/kniCV/ZzhpxHnkuRTp05Gk/48vzXK3bBxf5DaGM362rqPRbnJC9dFuTYdfKcz2ajWsqNns/32lfu+sHJ/MmBZUEJAKSUElFJCQCklBJRSQkApJQSUUkJAKSUElFrkxnQq6baJcFa4WR3b3E1s+vk/o0nP/jLbwv3u5M5uZvYd49GswdFsF3rqt/+KcpHT4YnV2dJ6+9a2q7uZj38wO8h5403zUW7HvY9GuV2/m+6HjmQb5On51zamAQopIaCUEgJKKSGglBICSikhoJQSAkopIaCUEgJKLdHGdCLcwh15TybbruE1T2abs8O5/vby1J3ZJVPDOJl8D+l3lZwt3trYk/3ct2ffEM265Z+7o9zXn7k2yq1feDxIpb82p8LcKFX9Xp25s+dOgFVJCQGllBBQSgkBpZQQUEoJAaWUEFBKCQGlCpcVq/pP775Y8nmk/0yyI3tnr9/ezdy8+eHskg9tyHL3Z7Fo169/QvAL5sJ7G/kRxsuL30iglBICSikhoJQSAkopIaCUEgJKKSGglBICSikhoFThxjQvNupjObOjZ1s7tx+55tXRpKs/8McoNzPev+b07Npo1v6jUaxN3PdclDt9fhBKF6Hn/HolvAkBpZQQUEoJAaWUEFBKCQGllBBQSgkBpZQQUEoJAaWsdC476XNjPsxt6yZe8YFD0aRXtd1R7untb+xmfvCP10azDhzMzrWeOvCvKNemgszhbFRrJ9LgquZNCCilhIBSSggopYSAUkoIKKWEgFJKCCilhIBSlhWXnezY09w53cSJ8Fn12/a2KHfdo/u7mW8cvySatf6Z8KzV9HH77yQTnilbsqy4/N4rlt8dAyuKEgJKKSGglBICSikhoJQSAkopIaCUEgJKKSGglI3pZedUmEu/2iPdxLEfXhxNOjY3GeW+etXGbmb6nuyo2OH8wSjX2qYsNpMcFzsXXvN0mFvdvAkBpZQQUEoJAaWUEFBKCQGllBBQSgkBpZQQUEoJAaVsTC876RZu/+zoFwTnJT+YbgjPRKmpfVu7mWE7Hl7zUJibDXPJxrRN6FHyJgSUUkJAKSUElFJCQC
klBJRSQkApJQSUUkJAKSUElLIxfdYY9fMg/Wrng8z+cNb6MPdUmEskG86LkXweqfQ7Xd3vAqv7pwfKKSGglBICSikhoJQSAkopIaCUEgJKKSGglGXFVa/iOXQ2P/vO5ntbmXziQCklBJRSQkApJQSUUkJAKSUElFJCQCklBJRSQkCpwXA4zMODwaHW2t6lux1gBds2HA4v+O8/XFQJAYya/44BpZQQUEoJAaWUEFBKCQGllBBQSgkBpZQQUEoJAaX+A6lXvVJbtehkAAAAAElFTkSuQmCC\n",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# from matplotlib import cm\n",
"# from tf_keras_vis.gradcam import Gradcam\n",
"\n",
"# Create Gradcam object\n",
"gradcam = Gradcam(model, model_modifier)\n",
"\n",
"# Generate heatmap with GradCAM\n",
"cam = gradcam(loss, img_batch)\n",
"cam = normalize(cam)\n",
"\n",
"f, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 5),\n",
" subplot_kw={'xticks': [], 'yticks': []})\n",
"for i in range(len(cam)):\n",
" heatmap = np.uint8(cm.jet(cam[i])[..., :3] * 255)\n",
" ax.imshow(img)\n",
" ax.imshow(heatmap, cmap='jet', alpha=0.5)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}