Serve Keras Model with Flask REST API
This tutorial briefly discusses the benefits of serving a trained machine learning model through an API. We then walk through a concrete example that uses a Keras model and Python Flask to serve it. You’ll learn how to serve a Keras model with a Flask REST API.
Why serve an ML Model with an API?
Flexibility
Encapsulating the execution and manipulation of your machine learning model in an API has a few benefits. One benefit is the abstraction layer that you create with a (REST) API. This abstraction layer enables you to
- test your application more easily (with tools that can send API requests but cannot import your TensorFlow/PyTorch model directly)
- develop your application (you can initiate an execution with a REST plugin, with your browser, or on the command line with curl)
- share functionality as a service (by deploying and making accessible via HTTPS; deploying it as a micro-service)
Mobility
Since your model can now be controlled with plain HTTP requests, you can deploy it and access or manage it via requests. There is no need to log in via SSH to change a cron job, to change the limit of an SQL query, or to initiate a new build to deploy a newer version.
Also, with a REST API you can deploy your model easily to services like AWS Elastic Beanstalk, Google App Engine, etc. They all need a working server in order to deploy your app. And now you can call your model an application, because in fact, it is.
Serving Keras Model with Flask
The following application structure and code are just one of many possible ways to tackle this idea. If you don’t like it, you can check out the CookieCutter template for more structure and MetaFlow for a complete framework.
Folder Structure
We need the following folders to encapsulate the scripts, classes, etc.
In the screenshot below you can see that we have a folder with different models (model_x.py). We need this separation because you could have multiple machine learning models that need to be served by the same Flask server. Optionally, you can create an ML model loader class that creates a machine learning model based on a configuration file (e.g. saved in YAML, JSON or a database).
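The optional loader mentioned above could look roughly like the following sketch. It assumes a hypothetical JSON configuration with the keys model_class, model_name and params, and a hypothetical registry dictionary. None of these names come from the original project, and the DynamicModel shown here is only a stand-in for the real class.

```python
import json

# Hypothetical registry: maps a name used in the config file to a model class.
MODEL_REGISTRY = {}


def register_model(name):
    """Class decorator that adds a model class to the registry under `name`."""
    def decorator(cls):
        MODEL_REGISTRY[name] = cls
        return cls
    return decorator


@register_model("dynamic_model")
class DynamicModel:
    # Stand-in for the real class in src/models; it only stores its settings.
    def __init__(self, model_name=None, **params):
        self.model_name = model_name
        self.params = params


class ModelLoader:
    """Creates model objects from a JSON configuration string."""

    def load(self, config_text):
        config = json.loads(config_text)
        class_key = config["model_class"]
        if class_key not in MODEL_REGISTRY:
            raise KeyError(f"Unknown model class: {class_key}")
        model_cls = MODEL_REGISTRY[class_key]
        return model_cls(model_name=config.get("model_name"),
                         **config.get("params", {}))


# Assumed config layout, for illustration only.
example_config = (
    '{"model_class": "dynamic_model", "model_name": "ModelA", '
    '"params": {"epochs": 4}}'
)
loaded = ModelLoader().load(example_config)
print(loaded.model_name, loaded.params)
```

With a registry like this, serving an additional model from the same Flask server would mostly mean registering one more class and adding one more configuration entry.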
In the queries folder you store your (in most cases very long) SQL Queries. You replace certain options with {parameter_x}, e.g: LIMIT {limit}. This way you’ll be able to dynamically generate parameterized SQL Queries reading the .sql File:
sql_file.read().format(limit=10000)
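A slightly fuller sketch of the same idea follows. The helper name load_query, the file name example.sql and the table name samples are made up for illustration. Note that str.format does plain text substitution, so it should only receive trusted values; user-supplied values are safer passed as bound parameters of your database driver.

```python
from pathlib import Path
import tempfile


def load_query(path, **params):
    """Read a .sql template and fill its {placeholders} with str.format."""
    template = Path(path).read_text(encoding="utf-8")
    return template.format(**params)


# Demo: a temporary file stands in for queries/example.sql (hypothetical name).
with tempfile.TemporaryDirectory() as tmp_dir:
    query_file = Path(tmp_dir) / "example.sql"
    query_file.write_text("SELECT * FROM samples LIMIT {limit}", encoding="utf-8")
    # int() guards against a non-numeric value slipping into the query text.
    query = load_query(query_file, limit=int("10000"))

print(query)
```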
Tests/ are for tests. We are going to skip this due to scope.
Config.py is for Configurations (SQL Creds, Server Envs, etc.). We are going to skip this due to scope.
Server.py is for Flask serving our Services.
Our services perform certain actions with our model. For example, "train_service" would initiate a training process for a certain model, "prediction_service" would initiate a prediction process for a model, and so on.
Machine Learning Model Class
In the code section below you can see a simple DynamicModel class with only one method, which returns the compiled Keras model. This model does not have to be static and can be moved into a "build_model()" method or similar. Also, all the parameters in the layers, like the input_shape, should be set via the parameters of the model() method. Since this is only an introductory tutorial, many useful methods are missing from this class.
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras import backend as K
import keras
"""
Author: Andrey Bulezyuk @ German IT Academy (git-academy.com)
Date: 18.01.2020
"""
class DynamicModel():
    def __init__(self, model_name = None):
        self.model_name = model_name

    def model(self):
        model = Sequential()
        model.add(Conv2D(32, (5, 5), input_shape=(28, 28, 1), activation='relu'))
        model.add(MaxPooling2D())
        model.add(Dropout(0.2))
        model.add(Flatten())
        model.add(Dense(128, activation='relu'))
        model.add(Dense(10, activation='softmax'))
        model.compile(loss=keras.losses.categorical_crossentropy,
                      optimizer=keras.optimizers.Adadelta(),
                      metrics=['accuracy'])
        return model
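As a rough sketch of the parameterization suggested above, the layer settings could be collected in a small configuration object and passed to a build function. The names ModelConfig and build_model and all field names are hypothetical; the defaults simply mirror the hard-coded values in the class. Keras is imported inside the function so that the configuration can be used on its own, and depending on your Keras version, passing input_shape to the first layer may produce a warning.

```python
from dataclasses import dataclass


@dataclass
class ModelConfig:
    # Defaults mirror the hard-coded values in DynamicModel.model().
    input_shape: tuple = (28, 28, 1)
    num_classes: int = 10
    conv_filters: int = 32
    kernel_size: tuple = (5, 5)
    dropout_rate: float = 0.2
    dense_units: int = 128


def build_model(config=None):
    """Build and compile the CNN described by a ModelConfig."""
    # Imported lazily so ModelConfig is usable without Keras installed.
    from keras.models import Sequential
    from keras.layers import Conv2D, MaxPooling2D, Dropout, Flatten, Dense

    config = config or ModelConfig()
    model = Sequential()
    model.add(Conv2D(config.conv_filters, config.kernel_size,
                     input_shape=config.input_shape, activation="relu"))
    model.add(MaxPooling2D())
    model.add(Dropout(config.dropout_rate))
    model.add(Flatten())
    model.add(Dense(config.dense_units, activation="relu"))
    model.add(Dense(config.num_classes, activation="softmax"))
    model.compile(loss="categorical_crossentropy",
                  optimizer="adadelta",
                  metrics=["accuracy"])
    return model


# Example: a variant with a wider dense layer; build_model(wide_config)
# would then create and compile the corresponding Keras model.
wide_config = ModelConfig(dense_units=256)
```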
Service Layer
Why do we need a service layer between Flask API (server.py) and the Machine Learning Model (dynamic_model.py)? Simple. By having this extra layer (service.py) you can execute the services (in our case Class Methods) not only via REST API, but also from within other python modules.
Our service layer is responsible for importing the DynamicModel Class, loading and saving the trained model for prediction or training respectively.
import sys, os, datetime
sys.path.insert(1, os.path.join(os.getcwd(), "src/models"))
from dynamic_model import DynamicModel
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import load_model
"""
Author: Andrey Bulezyuk @ German IT Academy (https://git-academy.com)
Date: 18.01.2020
"""
class Service():
    # model_name must be supplied,
    # otherwise no configuration can be loaded.
    def __init__(self, model_name=None):
        self.model_name = model_name
        self.dynamic_model = DynamicModel(self.model_name)

    def _get_train_data(self):
        (x_train, y_train), (x_test, y_test) = mnist.load_data()
        # reshape to be [samples][width][height][channels]
        x_train = x_train.reshape((x_train.shape[0], 28, 28, 1)).astype('float32')
        x_test = x_test.reshape((x_test.shape[0], 28, 28, 1)).astype('float32')
        # one-hot encode the digit labels
        y_train = np_utils.to_categorical(y_train)
        y_test = np_utils.to_categorical(y_test)
        self.x_train = x_train
        self.x_test = x_test
        self.y_train = y_train
        self.y_test = y_test

    def train(self):
        # Load data
        self._get_train_data()
        print(self.y_train)
        # This returns the compiled Keras Model from dynamic_model->model()
        model = self.dynamic_model.model()
        model.fit(self.x_train, self.y_train,
                  batch_size=1000,
                  epochs=4,
                  verbose=1)
        # Save trained model with a zero-padded timestamp (YYYYMMDD_HHMM)
        now = datetime.datetime.now()
        model.save(f"src/models/{self.model_name}_{now:%Y%m%d_%H%M}.h5")
        return True

    def predict(self, X):
        # Load model (_load_model is not implemented in this tutorial)
        model = self._load_model()
        # Execute
        results = model.predict(X)
        # Comparing a NumPy array with False is ambiguous, so only check for None
        if results is not None:
            return results
        return False
The train method works. You can see this in the section below, where we execute it via the Flask REST API with curl. The predict service method is not functional yet: it relies on a _load_model helper that is not implemented here. The code and explanation for this are outside the scope of this tutorial. Keep checking our IT Course Shop for similar courses with more in-depth material.
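The predict method calls self._load_model(), which the Service class does not define. One possible starting point, shown here only as a sketch and not as the author's implementation, is a helper that picks the most recently saved .h5 file for a given model name. The function name find_latest_model_path is hypothetical, and the demo uses empty placeholder files instead of real trained models.

```python
import glob
import os
import tempfile


def find_latest_model_path(model_dir, model_name):
    """Return the newest '<model_name>_*.h5' file in model_dir, or None."""
    pattern = os.path.join(model_dir, f"{glob.escape(model_name)}_*.h5")
    candidates = glob.glob(pattern)
    if not candidates:
        return None
    # Modification time is used instead of parsing the timestamp in the name.
    return max(candidates, key=os.path.getmtime)


# Demo with empty placeholder files in a temporary directory.
with tempfile.TemporaryDirectory() as tmp_dir:
    names = ["ModelA_old.h5", "ModelA_new.h5", "ModelB_x.h5"]
    for index, file_name in enumerate(names):
        file_path = os.path.join(tmp_dir, file_name)
        open(file_path, "w").close()
        # Fake, strictly increasing modification times.
        os.utime(file_path, (1000 + index, 1000 + index))
    latest = os.path.basename(find_latest_model_path(tmp_dir, "ModelA"))
    missing = find_latest_model_path(tmp_dir, "ModelC")

print(latest)
```

Inside the Service class, a _load_model method could then pass the returned path to load_model from keras.models, which service.py already imports, and handle the case where no saved model exists yet.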
Flask API Server
Our server part is pretty simple. We import Flask and our Service class. We create a route called "service" with two parameters: service_name (which can be train, predict, stop, status, history, …) and model_name. Based on the parameters we execute the specified service.
import sys, os, json
sys.path.insert(1, os.getcwd())
sys.path.insert(1, os.path.join(os.getcwd(), "src"))
from flask import Flask, request
from service import Service
"""
Author: Andrey Bulezyuk @ German IT Academy (https://git-academy.com)
Date: 18.01.2020
"""
application = Flask(__name__)

@application.route("/")
def hello():
    return "Hello World!"

@application.route("/<string:service_name>/<string:model_name>", methods=["GET", "POST"])
def service(service_name=None, model_name=None):
    service = Service(model_name=model_name)
    # GET Request is enough to trigger a training process
    if service_name == 'train':
        service.train()
    # POST Request is required to get the X data for prediction process
    elif service_name == 'predict':
        # predict() needs the input data X; read it from the JSON body of the request
        service.predict(request.get_json())
    return f"Service: {service_name}. Model: {model_name}. Success."

if __name__ == "__main__":
    application.run(debug=True)
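The route above reports success for any service name, including ones that are not implemented. A possible refinement, sketched here without Flask so that it runs on its own, is a small dispatch function that returns a message together with an HTTP status code. The names dispatch and FakeService are hypothetical; FakeService merely stands in for the real Service class.

```python
class FakeService:
    """Stand-in for the Service class so the sketch runs without Keras."""

    def __init__(self, model_name=None):
        self.model_name = model_name

    def train(self):
        return True

    def predict(self, X):
        # Dummy prediction: one zero per input row.
        return [0 for _ in X]


def dispatch(service, service_name, payload=None):
    """Run one service action and return a (message, status code) pair."""
    handlers = {
        "train": lambda: service.train(),
        "predict": lambda: service.predict(payload),
    }
    handler = handlers.get(service_name)
    if handler is None:
        return f"Unknown service: {service_name}.", 404
    if service_name == "predict" and payload is None:
        return "Prediction requires input data in the request body.", 400
    handler()
    return f"Service: {service_name}. Model: {service.model_name}. Success.", 200


print(dispatch(FakeService("ModelA"), "train"))
print(dispatch(FakeService("ModelA"), "status"))
```

A Flask view may return such a (body, status) tuple directly, so the route function could end with a return of the dispatch result.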
Example CLI & GET Request
C:\Users\andre\code\servekeraswithapi>curl localhost:5000/train/ModelA
Service: train. Model: ModelA. Success.
Epoch 1/4
60000/60000 [==============================] - 8s 141us/step - loss: 6.5852 - accuracy: 0.7280
Epoch 2/4
60000/60000 [==============================] - 8s 140us/step - loss: 0.3276 - accuracy: 0.9141
Epoch 3/4
60000/60000 [==============================] - 8s 140us/step - loss: 0.1897 - accuracy: 0.9495
Epoch 4/4
60000/60000 [==============================] - 8s 140us/step - loss: 0.1256 - accuracy: 0.9645
127.0.0.1 - - [18/Jan/2020 20:54:53] "GET /train/ModelA HTTP/1.1" 200 -
That’s it for our short tutorial. If you liked it, subscribe to our newsletter for more tutorials. If you have any questions, feel free to contact us or leave a comment.