BrandPost: Detecting Pneumonia From Chest X-Ray Images in Java
In this blog post, we demonstrate how deep learning (DL) can be used to detect pneumonia from chest X-ray images. This work is inspired by the Chest X-ray Images Challenge on Kaggle and a related paper. We also illustrate how artificial intelligence can assist clinical decision making, with a focus on enterprise deployment. The model used here was trained with Keras and TensorFlow using this Kaggle kernel. We then show how to generate predictions with it using Deep Java Library (DJL), an open-source library for building and deploying DL models in Java.
We chose Keras and DJL, two user-friendly DL tools, to implement our image classifier. Keras is an easy-to-use, high-level API for deep learning that enables fast prototyping. DJL, meanwhile, offers Keras users simple APIs to deploy DL models natively in Java. Now, let’s dive into the tutorial.
Train and save your model using Keras
The first step is to train the image classification model. You can follow the instructions in this kernel for a step-by-step guide. The model attempts to identify pneumonia by detecting visual signatures in chest X-ray images. As a reference point, the images below compare three patients with normal lungs (left), bacterial pneumonia (middle), and viral pneumonia (right).
Fig. 1: According to Fig. S6 of the paper, a normal chest X-ray (left) shows clear lungs, bacterial pneumonia (middle) shows a focal lobar consolidation, and viral pneumonia (right) shows a more diffuse “interstitial” pattern in both lungs.
The training process consists of three steps: preparing the data, constructing the model, and training the model. You can download the dataset used to train this model using this link. The model is built from Depthwise Separable Convolution layers and is partially initialized with weights pre-trained on ImageNet. Depthwise Separable Convolution layers have fewer parameters and are more computationally efficient than standard convolution layers. We also used transfer learning, a popular DL technique that adapts a model trained on one problem to a second, related problem. Instead of developing a model from scratch, transfer learning reuses features already learned on a similar problem, which can produce a more robust model in less training time. For the first two layers of our model, we used the weights of a VGG network pre-trained on ImageNet, a much larger dataset.
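Here is a small arithmetic sketch in Java that makes the parameter savings concrete. It counts the weights of a standard 3×3 convolution and of a depthwise separable convolution with the same channel counts, ignoring biases. The layer sizes are illustrative assumptions, not values taken from the Kaggle kernel.

```java
/**
 * Compares the weight counts of a standard 2D convolution and a depthwise
 * separable convolution with the same kernel size and channel counts.
 * Bias terms are ignored to keep the arithmetic simple.
 */
public class SeparableConvParams {

    // Standard convolution: one k x k x cIn filter per output channel.
    static long standardConvWeights(int k, int cIn, int cOut) {
        return (long) k * k * cIn * cOut;
    }

    // Depthwise separable: one k x k depthwise filter per input channel,
    // followed by a 1 x 1 pointwise convolution that mixes the channels.
    static long separableConvWeights(int k, int cIn, int cOut) {
        long depthwise = (long) k * k * cIn;
        long pointwise = (long) cIn * cOut;
        return depthwise + pointwise;
    }

    public static void main(String[] args) {
        // Illustrative layer sizes (assumption), not taken from the kernel.
        int k = 3, cIn = 128, cOut = 256;
        long standard = standardConvWeights(k, cIn, cOut);
        long separable = separableConvWeights(k, cIn, cOut);
        System.out.println("standard:  " + standard);
        System.out.println("separable: " + separable);
        System.out.printf("ratio:     %.2f%n", (double) standard / separable);
    }
}
```

With 128 input channels and 256 output channels, the standard layer needs 3 × 3 × 128 × 256 = 294,912 weights. The depthwise separable version needs 1,152 + 32,768 = 33,920, roughly 8.7 times fewer.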
You can download the kernel notebook and run it locally to produce the model. Note that the model needs to be saved in the TensorFlow SavedModel format. You can do this by adding a single line at the end of the notebook, such as model.save("models/saved_model/"). For more information about working with Keras models in DJL, see How to import Keras models in DJL.
If you want to run predictions directly with the pre-trained model, start off by downloading this model.
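Whether you train the model yourself or download it, DJL is pointed at a directory in TensorFlow SavedModel format. The sketch below is a hypothetical helper, not part of DJL. It checks for the files such a directory typically contains: saved_model.pb and a variables/ subdirectory. The default path models/saved_model matches the location used later in this post.

```java
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;

/**
 * Hypothetical helper (not part of DJL): a quick sanity check that a
 * directory looks like a TensorFlow SavedModel export.
 */
public class SavedModelCheck {

    // A SavedModel directory typically holds a saved_model.pb graph file
    // and a variables/ subdirectory with the trained weights.
    static boolean looksLikeSavedModel(Path dir) {
        return Files.isRegularFile(dir.resolve("saved_model.pb"))
                && Files.isDirectory(dir.resolve("variables"));
    }

    public static void main(String[] args) {
        Path dir = Paths.get(args.length > 0 ? args[0] : "models/saved_model");
        if (looksLikeSavedModel(dir)) {
            System.out.println("Found a SavedModel at " + dir.toAbsolutePath());
        } else {
            System.out.println("No SavedModel found at " + dir.toAbsolutePath());
        }
    }
}
```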
Load and Run Prediction using the Deep Java Library
Once you have the trained model, you can generate predictions using DJL. For the full code, see Pneumonia Detection. You can run predictions from the command line with the following command. Use -Dai.djl.repository.zoo.location to specify the location of your model.
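The -D flag sets a JVM system property, and that is how the model location reaches the program. The sketch below only illustrates this mechanism using System.getProperty. DJL reads ai.djl.repository.zoo.location on its own, and the class name ZooLocation is hypothetical.

```java
import java.nio.file.Path;
import java.nio.file.Paths;

/**
 * Illustrates how a -Dkey=value flag reaches Java code as a system property.
 * This is not DJL's code; DJL reads this property itself.
 */
public class ZooLocation {

    static final String KEY = "ai.djl.repository.zoo.location";

    // Returns the configured location, or null if the flag was not given.
    static Path configuredLocation() {
        String value = System.getProperty(KEY);
        return value == null ? null : Paths.get(value);
    }

    public static void main(String[] args) {
        Path location = configuredLocation();
        if (location == null) {
            System.out.println(KEY + " is not set; DJL would use its built-in repositories");
        } else {
            System.out.println("Models will be looked up in " + location);
        }
    }
}
```

For example, running java -Dai.djl.repository.zoo.location=models/saved_model ZooLocation prints the configured path.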
The following is an example output:
The following sections walk you through the code in detail.
Import DJL library and TensorFlow engine
To run prediction on Keras models, you need the DJL high-level API library and the underlying TensorFlow engine. They can be imported using either Gradle or Maven. For more details, see the Pneumonia Detection README. The following example uses Gradle to set up the dependencies:
Load model and run prediction
Next, we need to load our trained model. DJL provides simple, easy-to-use APIs for loading models. You can load models from our model zoo or use your own models from your local drive. The following sample uses a local model zoo to load the model and run prediction in just a few lines.
In this code, we first use a Criteria builder to tell the model zoo what kind of model we want to load. Here, we specify a model that takes a BufferedImage as input and returns a Classifications object as its result. We then use ModelZoo.loadModel to find the matching model in the model repository. By default, DJL looks for models in its built-in repositories. We therefore need to tell it to look in a custom path that contains the model we saved in TensorFlow SavedModel format in the training section. We do that by specifying `-Dai.djl.repository.zoo.location=models/saved_model`. Finally, we create a new predictor to run prediction and print out the classification result.
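Conceptually, the Criteria lookup matches a request against available models by their declared input and output types. The following toy registry illustrates that idea in plain Java. It is not DJL's ModelZoo implementation. The ModelEntry class, the model names, and the use of String as a stand-in for DJL's Classifications type are all hypothetical.

```java
import java.awt.image.BufferedImage;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;

/**
 * Toy illustration only, not DJL's API: returns the first registered model
 * whose declared input and output types match the request.
 */
public class ToyModelZoo {

    /** Hypothetical description of one model in a repository. */
    static final class ModelEntry {
        final String name;
        final Class<?> inputType;
        final Class<?> outputType;

        ModelEntry(String name, Class<?> inputType, Class<?> outputType) {
            this.name = name;
            this.inputType = inputType;
            this.outputType = outputType;
        }
    }

    private final List<ModelEntry> entries = new ArrayList<>();

    void register(ModelEntry entry) {
        entries.add(entry);
    }

    /** Returns the first model whose input and output types both match. */
    Optional<ModelEntry> find(Class<?> inputType, Class<?> outputType) {
        for (ModelEntry e : entries) {
            if (e.inputType.equals(inputType) && e.outputType.equals(outputType)) {
                return Optional.of(e);
            }
        }
        return Optional.empty();
    }

    public static void main(String[] args) {
        ToyModelZoo zoo = new ToyModelZoo();
        // Hypothetical entries; String stands in for a classification result type.
        zoo.register(new ModelEntry("text-model", String.class, String.class));
        zoo.register(new ModelEntry("xray-classifier", BufferedImage.class, String.class));

        String match = zoo.find(BufferedImage.class, String.class)
                .map(e -> e.name)
                .orElse("no matching model");
        System.out.println("Matched: " + match);
    }
}
```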
Define your Translator
When we load the model, we also want to define how to pre-process the input data and post-process the model's output. DJL uses the Translator interface for this purpose. Here is the implementation:
The translator converts the input from a BufferedImage to an NDArray to match the model's input requirements. It also resizes the image to 224×224 and normalizes it by dividing the pixel values by 255 before feeding it into the model. When running inference, you need to follow the same pre-processing procedure that was used during training. In this case, that means matching the Keras training code. After running prediction, the model outputs the probability of each class as an NDArray. We then translate these predictions back into our desired classes, namely “Pneumonia” or “Normal”.
That’s it! We’ve finished running prediction on chest X-ray images. From here, you can try building more complex models and training them on larger datasets. Follow our GitHub, demo repository, and Twitter for more documentation and examples of DJL!
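DJL performs these steps on NDArray objects, but the logic can be sketched in plain Java to show what happens to the data. The following hypothetical class does three things:

- resizes an image to 224×224;
- scales each RGB value to [0, 1] in height-width-channel order;
- picks the most probable label.

The bilinear interpolation, the HWC layout, and the label order (“Normal”, “Pneumonia”) are assumptions. In practice, all three must match how the Keras model was trained.

```java
import java.awt.Graphics2D;
import java.awt.RenderingHints;
import java.awt.image.BufferedImage;
import java.util.Arrays;
import java.util.List;

/**
 * Plain-Java sketch of a translator's two jobs, without DJL's NDArray:
 * pre-processing (resize to 224x224, scale pixels to [0, 1]) and
 * post-processing (map class probabilities to a label).
 */
public class XrayTranslatorSketch {

    static final int SIZE = 224;
    // Assumption: this order must match the label encoding used in training.
    static final List<String> CLASSES = Arrays.asList("Normal", "Pneumonia");

    /** Resizes to SIZE x SIZE and returns RGB values in HWC order, divided by 255. */
    static float[] preprocess(BufferedImage input) {
        BufferedImage resized = new BufferedImage(SIZE, SIZE, BufferedImage.TYPE_INT_RGB);
        Graphics2D g = resized.createGraphics();
        g.setRenderingHint(RenderingHints.KEY_INTERPOLATION,
                RenderingHints.VALUE_INTERPOLATION_BILINEAR);
        g.drawImage(input, 0, 0, SIZE, SIZE, null);
        g.dispose();

        float[] pixels = new float[SIZE * SIZE * 3];
        int i = 0;
        for (int y = 0; y < SIZE; y++) {
            for (int x = 0; x < SIZE; x++) {
                int rgb = resized.getRGB(x, y);
                pixels[i++] = ((rgb >> 16) & 0xFF) / 255f; // red
                pixels[i++] = ((rgb >> 8) & 0xFF) / 255f;  // green
                pixels[i++] = (rgb & 0xFF) / 255f;         // blue
            }
        }
        return pixels;
    }

    /** Returns the label whose probability is highest (argmax). */
    static String postprocess(float[] probabilities) {
        if (probabilities.length != CLASSES.size()) {
            throw new IllegalArgumentException("expected " + CLASSES.size() + " probabilities");
        }
        int best = 0;
        for (int c = 1; c < probabilities.length; c++) {
            if (probabilities[c] > probabilities[best]) {
                best = c;
            }
        }
        return CLASSES.get(best);
    }

    public static void main(String[] args) {
        // All-black placeholder image standing in for a real X-ray.
        BufferedImage xray = new BufferedImage(1024, 768, BufferedImage.TYPE_INT_RGB);
        float[] input = preprocess(xray);
        System.out.println("input values: " + input.length); // 224 * 224 * 3
        System.out.println("prediction: " + postprocess(new float[] {0.2f, 0.8f}));
    }
}
```

The preprocess and postprocess methods are kept separate to mirror the two halves of a translator. This lets each half be checked independently of the model.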
Wei Lai is a software development engineer @AWS AI, working on deep learning.
Disclaimer: This blog post is intended for educational purposes only. The application was developed using experimental code. The result should not be used for any medical diagnosis of pneumonia. This content has not been reviewed or approved by any scientists or medical professionals.
Copyright © 2020 IDG Communications, Inc.