
Build an API for a machine learning model in 5 minutes using Flask

Tim Elfrink
Jan 16, 2019

Flask is a micro web framework written in Python. With it, you can create a REST API that lets you send data and receive a prediction as a response.

As a data science consultant, I want to make an impact with my machine learning models. However, this is easier said than done. A new project usually starts with exploring the data in a Jupyter notebook. Once you have a full understanding of the data you're dealing with and have aligned with the client on which steps to take, one of the outcomes can be a predictive model.

You get excited and go back to your notebook to build the best model possible. You present the model and the results, and everyone is happy. The client now wants to run the model in their own infrastructure to test whether it really creates the expected impact. Moreover, once people can actually use the model, you get the feedback you need to improve it step by step. But how can you do this quickly, given that the client has a complicated infrastructure that you might not be familiar with?

For this purpose you need a tool that fits into that complicated infrastructure, preferably in a language you're familiar with. This is where Flask comes in: with this micro web framework you can build a REST API around your model, so the client can send data and receive a prediction as a response.

Create your model

Let me show you how this works. For the purpose of demonstration, I will train a simple DecisionTreeClassifier model on an example dataset which can be loaded from the scikit-learn package.

Once the client is happy with the model you have created, you can save it as a pickle file. You can then open this pickle file later and call the predict function to get a prediction for new input data. This is exactly what we will do in Flask.
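The article doesn't show this step's code itself, so here is a minimal sketch of what it could look like. The iris dataset and the file name model.pkl are assumptions made for illustration, not necessarily what the original project used:

```python
import pickle

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Load an example dataset that ships with scikit-learn (assumption: iris)
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

# Train a simple decision tree classifier
clf = DecisionTreeClassifier()
clf.fit(X_train, y_train)
print("Test accuracy:", clf.score(X_test, y_test))

# Save the fitted model to disk so it can be loaded again later
with open("model.pkl", "wb") as f:
    pickle.dump(clf, f)

# Later (for example inside the Flask app) the pickle is opened again
# and predict is called on new input data
with open("model.pkl", "rb") as f:
    model = pickle.load(f)
print(model.predict([[5.9, 3.0, 4.2, 1.5]]))  # e.g. [1]
```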

Run Flask

Flask runs on a server. This can be in the client's environment or on a different server, depending on the client's requirements. When you run python app.py, it first loads the created pickle file. Once this is loaded, you can start making predictions.
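The app.py file itself isn't reproduced in this article, so the following is a minimal sketch of what it could look like. The endpoint name /predict and the file name model.pkl are assumptions carried over from the sketch above:

```python
# app.py -- a minimal sketch of the Flask web server
import pickle

import numpy as np
from flask import Flask, jsonify, request

app = Flask(__name__)

# Load the pickled model once, when the server starts
with open("model.pkl", "rb") as f:
    model = pickle.load(f)

@app.route("/predict", methods=["POST"])
def predict():
    # The POST body is expected to be a JSON array of data points
    data = request.get_json()
    prediction = model.predict(np.array(data))
    # Return the predictions as a JSON list
    return jsonify(prediction.tolist())

if __name__ == "__main__":
    # Flask serves on port 5000 by default
    app.run(port=5000)
```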

Request predictions

Predictions are requested by sending a POST request with a JSON body to the Flask web server, which listens on port 5000 by default. In app.py this request is received, and a prediction is made with the already loaded model. The prediction is returned in JSON format.
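The client side can be as simple as the sketch below of what request.py might contain; it assumes the /predict endpoint and the iris-like feature format from the sketches above:

```python
# request.py -- a minimal sketch of a client calling the Flask API
import requests

# One data point in the same format as the original dataset
data = [[5.9, 3.0, 4.2, 1.5]]

response = requests.post("http://localhost:5000/predict", json=data)
print(response)         # e.g. <Response [200]>
print(response.json())  # e.g. [1]
```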

Now, all you need to do is call the web server with data points in the same format as the original dataset, and you get a JSON response with your predictions. For example:

python request.py -> <Response [200]> "[1.]"

For the data we sent, the model predicted class 1. All you are really doing is sending an array of data points to an endpoint, serialized as JSON. The endpoint reads the JSON body of the POST request and transforms it back into the original array before calling predict.

With these simple steps you can easily let other people use your machine learning model and quickly make a big impact.

Notes

In this article, I didn't account for any errors in the data or other exceptions. It shows how to get started quickly and learn from the model's output, but a lot of improvements are needed before this is ready to be put into production. The solution can be made scalable by creating a Dockerfile for the API and hosting it on Kubernetes, so the load can be balanced across different machines. But these are all steps to take when going from a proof of concept to a production environment.