How to deploy Keras model to production using flask (part – 2)


Hello everyone, this is part two of a two-part tutorial series on how to deploy a Keras model to production. In part one, we used a Convolutional Neural Network (CNN) to classify MNIST handwritten digits with Keras and saved the trained model to disk. In this part, we will see how to deploy that Keras model to production using Flask.

Flask is a micro-framework. Micro-frameworks are frameworks with few or no dependencies on external libraries. This keeps the framework light, with fewer dependencies to update and watch for security bugs. A very simple Flask app for web rendering would look something like this:

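from flask import Flask
app = Flask(__name__)

# map the root URL to the index function
@app.route("/")
def index():
    return "Hello World!"

if __name__ == "__main__":
    # start Flask's built-in development server
    app.run(port=5000)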

Save it to a Python file and run that file with the Python interpreter.

Just like that, your ‘hello world’ Flask web app is up and running. I won’t be able to cover a complete Flask tutorial in this post, but you can refer to the Flask documentation if you want to learn more. It’s well documented and easy to understand.

The basic structure of a flask web application looks like this:

$ tree deploy_mnist_flask/
|-- static
|-- templates

The templates folder is the place where the templates will be put. The static folder is the place where any files (images, css, javascript) needed by the web application will be put.
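
As a rough sketch, the same layout can also be created from Python. The helper name `scaffold` is made up for illustration, and the `model` folder is included only because it is used later in this post; the demo writes into a temporary directory so it does not touch your project.

```python
import tempfile
from pathlib import Path


def scaffold(root):
    """Create the project folders used in this tutorial under `root`."""
    root = Path(root)
    # templates/ holds the html, static/ holds js and css,
    # model/ holds the saved Keras files from part 1
    for name in ("templates", "static", "model"):
        (root / name).mkdir(parents=True, exist_ok=True)
    # report the folders that now exist, in alphabetical order
    return sorted(p.name for p in root.iterdir() if p.is_dir())


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        print(scaffold(Path(tmp) / "deploy_mnist_flask"))
```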

Make a file named index.html inside the templates directory and copy/paste the code below in the file. This is the html file we will render using flask.

<!DOCTYPE html>
<html lang="en">
  <head>
    <meta charset="utf-8">
    <meta http-equiv="X-UA-Compatible" content="IE=edge">
    <meta name="viewport" content="width=device-width, initial-scale=1">
    <!-- The above 3 meta tags *must* come first in the head; any other head content must come *after* these tags -->
    <meta name="description" content="">
    <meta name="author" content="">

    <title>MNIST Handwritten text recognition using keras</title>

    <!-- Bootstrap core CSS -->
    <link rel="stylesheet" href="" integrity="sha384-BVYiiSIFeK1dGmJRAkycuHAHRg32OmUcww7on3RYdg4Va+PmSTsz/K68vbdEjh4u" crossorigin="anonymous">
    <link rel="stylesheet" href="{{ url_for('static',filename='style.css') }}">
  </head>

  <body>
    <div class="container">
      <div class="header clearfix">
        <ul class="nav nav-pills pull-right">
          <li role="presentation" class="active"><a href="#">Home</a></li>
          <li role="presentation"><a href="">About</a></li>
        </ul>
        <h3 class="text-muted">MNIST Handwritten CNN</h3>
      </div>

      <div class="jumbotron">
        <h3 class="jumbotronHeading">Draw the digit inside this Box!</h3>
        <div class="slidecontainer">
          <p>Drag the slider to change the line width.</p>
          <input type="range" min="10" max="50" value="15" id="myRange">
          <p>Value: <span id="sliderValue"></span></p>
        </div>
        <div class="canvasDiv">
          <canvas id="canvas" width="280" height="280"></canvas>
          <p style="text-align:center;">
            <a class="btn btn-success myButton" href="#" role="button">Predict</a>
            <a class="btn btn-primary" href="#" id="clearButton" role="button">Clear</a>
          </p>
        </div>
      </div>

      <div class="jumbotron">
        <p id="result">Get your prediction here!!!</p>
      </div>

      <footer class="footer">
        <p>&copy; 2018,</p>
      </footer>

    </div> <!-- /container -->

    <!-- jQuery, needed by index.js and by the script below -->
    <script src=''></script>

    <script src="{{ url_for('static',filename='index.js') }}"></script>

    <script type="text/javascript">
      // send the canvas content to the server when Predict is clicked
      $(".myButton").click(function(){
        var $SCRIPT_ROOT = {{ request.script_root|tojson|safe }};
        var canvasObj = document.getElementById("canvas");
        var img = canvasObj.toDataURL();
        $.ajax({
          type: "POST",
          url: $SCRIPT_ROOT + "/predict/",
          data: img,
          success: function(data){
            $('#result').text(' Predicted Output: '+data);
          }
        });
      });
    </script>
  </body>
</html>

This is the html for our landing page. It links to two external files, index.js and style.css

The code for index.js is:

(function()
{
  var canvas = document.querySelector( "#canvas" );
  var context = canvas.getContext( "2d" );
  canvas.width = 280;
  canvas.height = 280;

  var Mouse = { x: 0, y: 0 };
  var lastMouse = { x: 0, y: 0 };
  context.color = "white";
  context.lineWidth = 15;
  context.lineJoin = context.lineCap = 'round';

  // track the current and previous mouse position relative to the canvas
  canvas.addEventListener( "mousemove", function( e )
  {
    lastMouse.x = Mouse.x;
    lastMouse.y = Mouse.y;

    Mouse.x = e.pageX - this.offsetLeft;
    Mouse.y = e.pageY - this.offsetTop;

  }, false );

  // draw only while the mouse button is held down
  canvas.addEventListener( "mousedown", function( e )
  {
    canvas.addEventListener( "mousemove", onPaint, false );

  }, false );

  canvas.addEventListener( "mouseup", function()
  {
    canvas.removeEventListener( "mousemove", onPaint, false );

  }, false );

  var onPaint = function()
  {
    context.lineJoin = "round";
    context.lineCap = "round";
    context.strokeStyle = context.color;

    // stroke a segment from the previous mouse position to the current one
    context.beginPath();
    context.moveTo( lastMouse.x, lastMouse.y );
    context.lineTo( Mouse.x, Mouse.y );
    context.closePath();
    context.stroke();
  };

  function debug()
  {
    /* CLEAR BUTTON */
    var clearButton = $( "#clearButton" );
    clearButton.on( "click", function()
    {
      context.clearRect( 0, 0, 280, 280 );
    });

    /* LINE WIDTH */
    var slider = document.getElementById("myRange");
    var output = document.getElementById("sliderValue");
    output.innerHTML = slider.value;

    slider.oninput = function() {
      output.innerHTML = this.value;
      context.lineWidth = $( this ).val();
    };

    $( "#lineWidth" ).change(function()
    {
      context.lineWidth = $( this ).val();
    });
  }

  // wire up the clear button and the line-width slider
  debug();
}());

And the code for style.css is

/* Space out content a bit */
body {
  padding-top: 20px;
  padding-bottom: 20px;
}

/* Everything but the jumbotron gets side spacing for mobile first views */
.footer {
  padding-right: 15px;
  padding-left: 15px;
}

/* Custom page header */
.header {
  padding-bottom: 20px;
  border-bottom: 1px solid #e5e5e5;
}

/* Make the masthead heading the same height as the navigation */
.header h3 {
  margin-top: 0;
  margin-bottom: 0;
  line-height: 40px;
}

/* Custom page footer */
.footer {
  padding-top: 19px;
  color: #777;
  border-top: 1px solid #e5e5e5;
}

/* Customize container */
@media (min-width: 768px) {
  .container {
    max-width: 730px;
  }
}

.container-narrow > hr {
  margin: 30px 0;
}

/* Main marketing message and sign up button */
.jumbotron {
  text-align: center;
  border-bottom: 1px solid #e5e5e5;
  padding-top: 20px;
  padding-bottom: 20px;
}

  text-align: center;

@media screen and (min-width: 768px) {
  /* Remove the padding we set earlier */
  .footer {
    padding-right: 0;
    padding-left: 0;
  }
  /* Space out the masthead */
  .header {
    margin-bottom: 30px;
  }
  /* Remove the bottom border on the jumbotron for visual effect */
  .jumbotron {
    border-bottom: 0;
  }
}

@media screen and (max-width: 500px) {
    display: none;


  float: left;
  width: 30%;

  margin-bottom: 7vh;

  display: flow-root;
  text-align: center;

Save these files inside the static directory. I am not going into an in-depth discussion of how to write JavaScript or CSS in this post, but if you have trouble understanding any of it, please feel free to mention it in the comments section below and we will help you.

Okay! We now have all the files required to render our webpage. We use the render_template function from the flask module to render the html file. Import it at the top of your Flask app file and use it to render the index.html page we just created, updating the file as below:

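from flask import Flask, render_template
app = Flask(__name__)

# serve the landing page at the root URL
@app.route("/")
def index():
    return render_template("index.html")

if __name__ == "__main__":
    # start Flask's built-in development server
    app.run(port=5000)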

If you haven’t modified the html/css files, you should now see the landing page with the drawing canvas, the line-width slider, and the Predict and Clear buttons.

We will copy the model.h5 and model.json files we created in part 1 of this tutorial into a folder named model in the working directory. Create the model directory and copy the files into it. We will also create a file named load.py, which loads the model structure and the model weights. Copy/paste the code below into the file:

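import os

from keras.models import model_from_json
import tensorflow as tf

# assumes load.py lives in the model folder, next to model.json and model.h5
MODEL_DIR = os.path.dirname(os.path.abspath(__file__))

def init():
    # read the model structure
    with open(os.path.join(MODEL_DIR, 'model.json'), 'r') as json_file:
        loaded_model_json = json_file.read()
    loaded_model = model_from_json(loaded_model_json)
    # load weights into new model
    loaded_model.load_weights(os.path.join(MODEL_DIR, 'model.h5'))
    print("Loaded Model from disk")

    # predict() works on an uncompiled model, so no compile step is needed here
    # keep a handle on the graph the model was loaded into (TensorFlow 1.x API)
    graph = tf.get_default_graph()

    return loaded_model, graph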

We will later call this function from our Flask app.
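
Before starting the server, it can help to confirm that both files from part 1 are where the loader will look for them. The sketch below uses only the standard library; the function name `missing_model_files` and its default folder name are assumptions made for illustration.

```python
import os
import tempfile


def missing_model_files(model_dir="model", names=("model.json", "model.h5")):
    """Return the expected model files that are absent from model_dir."""
    return [n for n in names if not os.path.isfile(os.path.join(model_dir, n))]


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        # create only the structure file, so the weights file gets reported
        open(os.path.join(tmp, "model.json"), "w").close()
        print(missing_model_files(tmp))
```

An empty list means both files were found.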

As of right now, the slider that changes the line width and the Clear button work, but as we can see, the Predict button doesn’t do anything. Let’s fix that. First, let’s update our Flask app file with the code below. The code is commented throughout, so I don’t think there is much need to explain everything. However, we will go through some core concepts.

# requests are objects that flask handles (get set post, etc)
from flask import Flask, render_template, request
# scientific computing library for saving, reading, and resizing images
# (imread and imresize only exist in older SciPy releases)
from scipy.misc import imread, imresize
# for matrix math
import numpy as np
# for regular expressions, saves time dealing with string data
import re
# for decoding the base64 image sent by the browser
import base64
# system level operations (like loading files)
import sys
# for reading operating system data
import os

# tell our app where our saved model is
sys.path.append(os.path.abspath("./model"))
from load import init

# initialize our flask app
app = Flask(__name__)
# module-level variables shared by every request
model, graph = init()

# decoding an image from base64 into raw representation
def convertImage(imgData1):
    # the request body is a data URL: b"data:image/png;base64,<payload>"
    imgstr = re.search(rb'base64,(.*)', imgData1).group(1)
    with open('output.png', 'wb') as output:
        output.write(base64.b64decode(imgstr))

@app.route('/')
def index():
    return render_template("index.html")

@app.route('/predict/', methods=['GET', 'POST'])
def predict():
    # whenever the predict method is called, we're going
    # to input the user drawn character as an image into the model
    # perform inference, and return the classification
    # get the raw data format of the image
    imgData = request.get_data()
    # decode it and save it as output.png
    convertImage(imgData)
    # read the image into memory
    x = imread('output.png', mode='L')
    # make it the right size
    x = imresize(x, (28, 28))
    # imsave('final_image.jpg', x)
    # convert to a 4D tensor to feed into our model
    x = x.reshape(1, 28, 28, 1)
    # in our computation graph
    with graph.as_default():
        # perform the prediction
        out = model.predict(x)
        print(np.argmax(out, axis=1))
        # convert the response to a string
        response = np.argmax(out, axis=1)
        return str(response[0])

if __name__ == "__main__":
    # run the app locally on the given port
    app.run(port=5000)
    # optional if we want to run in debugging mode
    # app.run(debug=True)
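
The image reaches the predict route as a data URL, that is, the text `data:image/png;base64,` followed by the base64-encoded PNG. Here is a small standard-library sketch of that decoding step. The helper name `decode_data_url` is made up for illustration, and the payload is a stand-in for real canvas output.

```python
import base64
import re


def decode_data_url(data_url: bytes) -> bytes:
    """Return the raw bytes carried by a base64 data URL."""
    match = re.search(rb"base64,(.*)", data_url)
    if match is None:
        raise ValueError("no base64 payload found")
    return base64.b64decode(match.group(1))


if __name__ == "__main__":
    # stand-in payload; a real request carries the PNG bytes of the canvas
    payload = b"\x89PNG fake image bytes"
    data_url = b"data:image/png;base64," + base64.b64encode(payload)
    # the decoded bytes should match what was encoded
    print(decode_data_url(data_url) == payload)
```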

At the top of our file, we import the necessary libraries. We already have the web server up and running. The load.py file we created earlier loads the model weights and the model structure so that we can make predictions. Whenever anyone clicks the Predict button, we read the image on the canvas. That image is sent to the predict function in base64 format, and we convert it to a .png file. We then resize the converted image to 28x28 pixels, as used in the MNIST dataset. The image is then reshaped into a 4D tensor, matching the training set, and we call model.predict() to predict the output. The response is passed back to the calling AJAX method, which updates the html element that displays the result.
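
To make the shapes concrete, here is a dependency-free sketch of the last steps: shrinking a 280x280 grayscale canvas to 28x28, wrapping it into the (1, 28, 28, 1) layout, and picking the most likely digit. It uses plain block averaging instead of scipy's imresize and nested lists instead of NumPy arrays. The names `downsample`, `to_batch`, and `argmax` are hypothetical, so treat it as an illustration of the data flow only.

```python
def downsample(pixels, factor=10):
    """Shrink a square grayscale image by averaging factor x factor blocks."""
    size = len(pixels) // factor
    out = []
    for r in range(size):
        row = []
        for c in range(size):
            block = [pixels[r * factor + i][c * factor + j]
                     for i in range(factor) for j in range(factor)]
            row.append(sum(block) / len(block))
        out.append(row)
    return out


def to_batch(image):
    """Wrap an H x W image as 1 x H x W x 1 (batch, height, width, channels)."""
    return [[[[value] for value in row] for row in image]]


def argmax(scores):
    """Index of the largest score, i.e. the predicted digit."""
    return max(range(len(scores)), key=lambda i: scores[i])


if __name__ == "__main__":
    canvas = [[0] * 280 for _ in range(280)]  # blank 280x280 canvas
    batch = to_batch(downsample(canvas))
    # dimensions of the nested list: batch, height, width, channels
    print(len(batch), len(batch[0]), len(batch[0][0]), len(batch[0][0][0]))
    # made-up scores standing in for the output of a 10-class model
    print(argmax([0.01, 0.02, 0.9, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01]))
```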

And presto!


If you have trouble with any of the steps, please reach out to us in the comments section and we will help you solve it. The complete code used in this tutorial can be found in this GitHub repo.
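
When debugging, it can be easier to call the prediction endpoint without the browser. The following standard-library sketch builds the same kind of POST request the page sends. The helper name `build_predict_request` and the base URL are assumptions, and nothing is sent until the request is passed to `urllib.request.urlopen` while the app is running.

```python
import base64
import urllib.request


def build_predict_request(base_url, png_bytes):
    """Build a POST to /predict/ whose body is a PNG encoded as a data URL."""
    body = b"data:image/png;base64," + base64.b64encode(png_bytes)
    url = base_url.rstrip("/") + "/predict/"
    return urllib.request.Request(url, data=body, method="POST")


if __name__ == "__main__":
    # placeholder bytes; use the contents of a real PNG when calling the server
    req = build_predict_request("http://localhost:5000", b"not a real png")
    print(req.get_method(), req.full_url)
```

With the server running and `png_bytes` read from a real PNG of a digit, `urllib.request.urlopen(req).read()` should return the predicted digit as text.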

A huge shoutout to Siraj Raval, as this tutorial series is inspired by his video “How to deploy Keras model to production”.

12 Comments on How to deploy Keras model to production using flask (part – 2)

  1. Hi! I keep getting an empty output.png and the prediction doesn’t work, I even copy paste the entire code to see if maybe I’m doing something wrong but I have the same problem!

  2. Hey!
    When running, I am getting the error, “NameError: name ‘init’ is not defined”.
    Please help me to solve this problem.

      • once i run code just for 3 epochs ..the files model.h5 and model_json are not saved in my directory me plz

  3. Hi!
    I stumbled over your work since I was done with my model I needed to get in production. I wanted to know if you can allow me to steal some parts of your work to accomplish what I want to do, since you could solve it very well with javascript and flask. I wanted to do it with django, but I gave up at a point.
    Thanks again,



  4. Hello , please fo templates and satatic, should i create it or they should be downloaded with flask! and if it is the choice one, i create theses files in the folder flask? thanks

    • Hello. For the purpose of following along with this tutorial, you need to create the files in the respective folders as mentioned in the post.

  5. Hello, thanks for the tutorial. I have a question: I’m quite familiar with Python and Flask, have been working with Keras for about 2 months now. I have my where I set up Flask with 2 endpoints, and where I’ve created Predictor class which basically loads two models upon being initialized. The issue I’m experiencing is this: if I test the Predictor class from given the same image path it works fine. However, if I call the same method via Flask (on client request) given the same path (I’ve checked, it definitely loads the image properly)… once it comes to model.predict() line in the code it crashes and returns ValueError. I am not sure what the issue is here, especially since I’m seeing that combining Keras with Flask is a popular practice 🙂

    • Could you please upload the link to your code so that we can have a more detailed look at it. It seems like you’re doing everything fine. Must be some small issue.

      • the code can be seen here (I’m not quite certain it’s proper, so that might be why it fails) but it is important that I don’t have to clear_session() every time before I try to call any Keras code to predict anything seeing as I have to load rather large models.
