Plot Keras Model in Colaboratory

Google Colaboratory, or Colab for short, provides a sufficiently powerful platform for running machine learning projects in Jupyter notebooks. The best thing is that we can use an NVIDIA Tesla K80 GPU for free! However, there are some caveats.



Update (2019 April 22): NVIDIA Tesla T4 GPU is now available for free! It has much better performance than K80.
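Since the GPU model depends on what Colab assigns to the session, it can be worth checking what the current runtime actually has. The following is a minimal sketch, assuming the `nvidia-smi` tool is on the PATH when a GPU runtime is attached; the helper name `gpu_names` is made up for illustration and returns an empty list when no GPU is visible.

```python
import shutil
import subprocess


def gpu_names():
    """Return the GPU names reported by nvidia-smi, or [] if none are visible."""
    # nvidia-smi is only present when an NVIDIA driver is installed
    if shutil.which("nvidia-smi") is None:
        return []
    result = subprocess.run(
        ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"],
        capture_output=True,
        text=True,
    )
    # A non-zero exit code usually means no GPU is attached to the runtime
    if result.returncode != 0:
        return []
    return [line.strip() for line in result.stdout.splitlines() if line.strip()]


names = gpu_names()
print(names if names else "No GPU detected; check the runtime type setting.")
```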


Just upload a notebook to Google Drive and open it with Colaboratory. Yes, it is as simple as that. We can also create a new notebook from the main page, which will be saved in the Colab Notebooks folder in Google Drive.

As of the date of this post, a deep learning model cannot be plotted with Keras directly on Colab, even though Keras is installed by default. Attempting to plot fails with the following error:

ImportError: Failed to import pydot. Please install pydot. For example with pip install pydot.

NOTE: If your import is failing due to a missing package, you can
manually install dependencies using either !pip or !apt.

To view examples of installing some common dependencies, click the
“Open Examples” button below.
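Keras model plotting relies on two things: the `pydot` Python package and the Graphviz `dot` executable. A quick check like the one below can show which of them is missing before anything is installed. This is only a sketch; the function name `check_plot_dependencies` is hypothetical and not part of Keras.

```python
import importlib.util
import shutil


def check_plot_dependencies():
    """Report which pieces needed for Keras model plotting are present."""
    return {
        # Python package that Keras imports when plotting a model
        "pydot": importlib.util.find_spec("pydot") is not None,
        # Graphviz command-line renderer, normally installed through apt
        "graphviz_dot": shutil.which("dot") is not None,
    }


status = check_plot_dependencies()
for name, available in status.items():
    print(f"{name}: {'found' if available else 'missing'}")
```

If either entry is reported as missing, the install commands in the solution notebook are the place to start.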

The errors encountered and the steps taken towards the solution can be seen in the Plot Keras Model in Colab Error Reproduced.ipynb notebook.


Looking at just the solution:

Plot Keras Model in Colab.ipynb
# Install dependencies
!apt install graphviz
!pip install pydot pydot-ng
!echo "Double check with Python 3"
!python -c "import pydot"

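# Restart the runtime so Jupyter picks up the packages installed above.
# os._exit(0) kills the current process, so run the code below in a new cell afterwards.
import os
os._exit(0)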

from keras.models import Model
from keras.layers import Input, Dense
from keras.utils import plot_model

# Multi-layer neural networks
inputs = Input(shape=(10,))
hidden1 = Dense(10, activation='relu')(inputs)
hidden2 = Dense(30, activation='relu')(hidden1)
hidden3 = Dense(10, activation='relu')(hidden2)
output = Dense(1, activation='sigmoid')(hidden3)
model = Model(inputs=inputs, outputs=output)

# Model summary (summary() prints the table itself and returns None)
model.summary()

# Plot model graph
plot_model(model, show_shapes=True, show_layer_names=True, to_file='model.png')
from IPython.display import Image
Image(retina=True, filename='model.png')

Keras model graph plot sample