Deep Java Library (DJL) is an open-source library created by Amazon to develop machine learning (ML) and deep learning (DL) models natively in Java while simplifying the use of deep learning frameworks.
I recently used DJL to develop a footwear classification model and found the toolkit super intuitive and easy to use; it’s obvious a lot of thought went into the design and into how Java developers would use it. DJL APIs abstract away the functions commonly used to develop models and handle much of the infrastructure management for you. I found that the high-level APIs for training, testing, and running inference allowed me to use my knowledge of Java and the ML lifecycle to develop a model in less than an hour with minimal code.
Footwear classification model
The footwear classification model is a multiclass classification computer vision (CV) model, trained using supervised learning, that classifies footwear into one of four class labels: boots, sandals, shoes, or slippers.
One of the most important parts of developing an accurate ML model is using data from a reputable source. The data source for the footwear classification model is the UTZappos50k dataset, provided by The University of Texas at Austin and freely available for academic, non-commercial use. The dataset consists of 50,025 labeled catalog images collected from Zappos.com.
Train the footwear classification model
Training is the process of producing an ML model by giving a learning algorithm training data to learn from. The term model refers to the artifact produced during the training process; the model contains patterns found in the training data and can be used to make a prediction (or inference). Before I started the training process, I set up my local environment for development. You will need JDK 8 (or later), IntelliJ, an ML engine for training (like Apache MXNet), an environment variable pointing to your engine’s path, and the build dependencies for DJL.
DJL stays true to Java’s motto, “write once, run anywhere” (WORA), by being engine-agnostic and deep learning framework-agnostic: developers can write code once and run it on any supported engine. At the time of writing, DJL provides an implementation for Apache MXNet, an ML engine that eases the development of deep neural networks. DJL APIs use JNA (Java Native Access) to call the corresponding Apache MXNet operations. From a hardware perspective, training occurred locally on my laptop using a CPU. However, for the best performance, the DJL team recommends using a machine with at least one GPU. If you don’t have a GPU available to you, you can always run Apache MXNet on Amazon EC2. A nice feature of DJL is that it automatically detects the available hardware and uses a GPU when one is present, falling back to the CPU otherwise, so the same code makes the best use of whatever machine it runs on.
Load dataset from the source
The footwear data was saved locally and loaded using DJL’s ImageFolder dataset, which retrieves images from a local folder. In DJL terms, a Dataset simply holds the training data. There are dataset implementations that can download data (based on the URL you provide), extract it, and automatically split it into training and validation sets. The automatic split is useful because you should never validate a model’s performance on the same data it was trained with. The training dataset is used to find patterns in the data; the validation dataset is used to estimate the footwear model’s accuracy during the training process.
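Conceptually, an automatic split shuffles the item indices and cuts them at a ratio, so no image ends up in both sets. The plain-Java sketch below illustrates that idea only; it is not DJL’s implementation, and the `TrainValidationSplit` class, the 80/20 ratio, and the fixed seed are hypothetical choices for illustration.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Hypothetical helper (not a DJL class): splits item indices into disjoint training and validation sets. */
final class TrainValidationSplit {
    final List<Integer> train;
    final List<Integer> validation;

    private TrainValidationSplit(List<Integer> train, List<Integer> validation) {
        this.train = train;
        this.validation = validation;
    }

    /**
     * Shuffles the indices 0..size-1 with a seeded Random, then puts the first
     * trainFraction of them in the training set and the rest in the validation set.
     */
    static TrainValidationSplit split(int size, double trainFraction, long seed) {
        if (size < 0 || trainFraction < 0.0 || trainFraction > 1.0) {
            throw new IllegalArgumentException("invalid size or fraction");
        }
        List<Integer> indices = new ArrayList<>(size);
        for (int i = 0; i < size; i++) {
            indices.add(i);
        }
        Collections.shuffle(indices, new Random(seed));
        int cut = (int) Math.round(size * trainFraction);
        // Copy the subList views so the two sets are independent lists.
        return new TrainValidationSplit(
                new ArrayList<>(indices.subList(0, cut)),
                new ArrayList<>(indices.subList(cut, size)));
    }
}
```

For example, `TrainValidationSplit.split(50_025, 0.8, 42L)` yields 40,020 training indices and 10,005 validation indices with no overlap. The fixed seed makes the split reproducible between runs.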
When structuring the data locally, I didn’t go down to the most granular level identified by the UTZappos50k dataset, such as the ankle, knee-high, mid-calf, and over-the-knee labels for boots. My local data is kept at the highest level of classification, which includes only boots, sandals, shoes, and slippers.
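Keeping only the top-level category amounts to using the first folder under the dataset root as the label and ignoring any deeper subfolders. The plain-Java sketch below shows that mapping. The layout `<root>/<Category>/<Subcategory>/.../<image>.jpg` (for example `Boots/Ankle/...`), the capitalized folder names, and the `TopLevelLabel` helper are all assumptions for illustration, not part of DJL.

```java
import java.nio.file.Path;
import java.util.Arrays;
import java.util.List;

/** Hypothetical helper (not a DJL class): maps an image path to one of four top-level labels. */
final class TopLevelLabel {
    // The four classes kept locally; these folder names are an assumption.
    static final List<String> LABELS = Arrays.asList("Boots", "Sandals", "Shoes", "Slippers");

    /**
     * Returns the label index for an image stored somewhere under root, using only
     * the first folder below root and ignoring finer subfolders such as "Ankle".
     */
    static int labelIndex(Path root, Path image) {
        Path relative = root.relativize(image);
        if (relative.getNameCount() < 2) {
            throw new IllegalArgumentException("image is not inside a category folder: " + image);
        }
        String category = relative.getName(0).toString();
        int index = LABELS.indexOf(category);
        if (index < 0) {
            throw new IllegalArgumentException("unknown category: " + category);
        }
        return index;
    }
}
```

With this mapping, an ankle boot and a knee-high boot both receive label index 0, which is exactly the coarser four-class view described above.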
Train the model
Now that I have the footwear data separated into training and validation sets, I will use a neural network to train the model.
Training starts by feeding the training data as input to a Block. In DJL terms, a Block is a composable unit that forms a neural network. You can combine Blocks (just like Lego blocks) to form a complex network. At the end of the training process, the Block represents a fully trained model. The first step is to get a model instance by calling Models.getModel(NUM_OF_OUTPUT, NEW_HEIGHT, NEW_WIDTH). The getModel() method creates an empty model, constructs the neural network, and assigns the network to the model.
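The Lego analogy can be made concrete with a toy interface in plain Java. Each block maps an input vector to an output vector, and a sequential block chains other blocks so the output of one feeds the next. This is a hypothetical sketch of the idea only (`ToyBlock`, `ToyLinear`, `ToyActivation`, and `ToySequential` are made-up names). DJL’s real `Block` interface works on `NDList` tensors and also manages parameters and initialization; its `SequentialBlock` plays the role that `ToySequential` plays here.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.DoubleUnaryOperator;

/** Hypothetical stand-in for a network building block: maps an input vector to an output vector. */
interface ToyBlock {
    double[] forward(double[] input);
}

/** Fully connected layer: output[j] = bias[j] + sum over i of input[i] * weights[i][j]. */
final class ToyLinear implements ToyBlock {
    private final double[][] weights; // shape [inputSize][outputSize]
    private final double[] bias;      // shape [outputSize]

    ToyLinear(double[][] weights, double[] bias) {
        this.weights = weights;
        this.bias = bias;
    }

    @Override
    public double[] forward(double[] input) {
        double[] out = Arrays.copyOf(bias, bias.length);
        for (int i = 0; i < input.length; i++) {
            for (int j = 0; j < out.length; j++) {
                out[j] += input[i] * weights[i][j];
            }
        }
        return out;
    }
}

/** Applies a scalar function (such as ReLU) to every element. */
final class ToyActivation implements ToyBlock {
    private final DoubleUnaryOperator fn;

    ToyActivation(DoubleUnaryOperator fn) {
        this.fn = fn;
    }

    @Override
    public double[] forward(double[] input) {
        double[] out = new double[input.length];
        for (int i = 0; i < input.length; i++) {
            out[i] = fn.applyAsDouble(input[i]);
        }
        return out;
    }
}

/** Chains blocks: the output of each block becomes the input of the next. */
final class ToySequential implements ToyBlock {
    private final List<ToyBlock> blocks = new ArrayList<>();

    ToySequential add(ToyBlock block) {
        blocks.add(block);
        return this;
    }

    @Override
    public double[] forward(double[] input) {
        double[] x = input;
        for (ToyBlock block : blocks) {
            x = block.forward(x);
        }
        return x;
    }
}
```

A small network is then just `new ToySequential().add(linear1).add(new ToyActivation(v -> Math.max(0.0, v))).add(linear2)`, where the last layer has four outputs, one per footwear label.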
The next step is to set up and configure a Trainer by calling the model.newTrainer(config) method. The config object was initialized by calling the setupTrainingConfig(loss) method, which sets the training configuration (or hyperparameters) to determine how the network is trained.
Several hyperparameters are set for training:
- newHeight and newWidth: the shape (height and width) of the input image.
- batchSize: the batch size used for training; pick a proper size based on your model.
- numOfOutput: the number of labels; there are 4 labels for footwear classification.
- loss: the loss function, which evaluates model predictions against the true labels, measuring how good (or bad) a model is.
- Initializer: identifies an initialization method for the network’s parameters; in this case, Xavier initialization.
- MultiFactorTracker: configures the learning rate schedule, multiplying the learning rate by a factor at specified training steps.
- Optimizer: an optimization technique to minimize the value of the loss function; in this case, stochastic gradient descent (SGD).
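To make these hyperparameters less abstract, the sketch below writes out, in plain Java, the textbook math behind several of them. It uses softmax cross-entropy as the loss (an assumption, since the text does not name the loss function), the Glorot/Xavier uniform bound, a multi-factor learning-rate schedule, and a plain SGD update. It is illustrative only. DJL’s built-in implementations and defaults (for example, the `XavierInitializer` magnitude and factor type, or exactly when `MultiFactorTracker` applies each factor) may differ, and `TrainingMath` is a hypothetical name.

```java
/** Hypothetical plain-Java versions of the math behind several training hyperparameters. */
final class TrainingMath {
    private TrainingMath() {}

    /** Softmax cross-entropy for one example: -log(softmax(logits)[label]), computed stably. */
    static double softmaxCrossEntropy(double[] logits, int label) {
        if (logits.length == 0) {
            throw new IllegalArgumentException("no logits");
        }
        double max = logits[0];
        for (double z : logits) {
            max = Math.max(max, z);
        }
        double sumExp = 0.0;
        for (double z : logits) {
            sumExp += Math.exp(z - max); // subtracting max avoids overflow in exp
        }
        // log softmax(z)[label] = (z[label] - max) - log(sum of exp(z - max))
        return -((logits[label] - max) - Math.log(sumExp));
    }

    /** Glorot/Xavier uniform bound: weights drawn from U(-b, b) with b = sqrt(6 / (fanIn + fanOut)). */
    static double xavierUniformBound(int fanIn, int fanOut) {
        return Math.sqrt(6.0 / (fanIn + fanOut));
    }

    /**
     * Multi-factor schedule: the base rate is multiplied by factor once for every
     * milestone step already reached. Using step >= milestone is an assumed convention.
     */
    static double multiFactorLearningRate(double baseRate, double factor, int[] milestones, int step) {
        double rate = baseRate;
        for (int milestone : milestones) {
            if (step >= milestone) {
                rate *= factor;
            }
        }
        return rate;
    }

    /** Plain SGD update, applied in place: w = w - learningRate * grad. */
    static void sgdStep(double[] weights, double[] gradients, double learningRate) {
        if (weights.length != gradients.length) {
            throw new IllegalArgumentException("weights and gradients differ in length");
        }
        for (int i = 0; i < weights.length; i++) {
            weights[i] -= learningRate * gradients[i];
        }
    }
}
```

With four equal logits, the loss is ln 4 (about 1.386), the value of a model that is guessing uniformly among the four footwear classes. A loss noticeably below that means the model has learned something.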
The next step is to set up Metrics and a training listener, and to initialize the Trainer with the proper input shape. Metrics collect and report key performance indicators (KPIs) during training that can be used to analyze and monitor training performance and stability. Next, I kick off the training process by calling the fit(trainer, trainingDataset, validateDataset, "build/logs/training") method, which iterates over the training data and stores the patterns found in the model.
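A metric such as accuracy is typically accumulated batch by batch, then read out and reset at the end of each epoch, often by a listener. The hypothetical plain-Java sketch below shows that bookkeeping. It is not DJL’s `Accuracy` evaluator or `TrainingListener` API, only the same idea: take the arg-max of each prediction, count matches, and report correct / total. `AccuracyMetric`, `EpochListener`, and `HistoryListener` are made-up names.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical accumulator for classification accuracy across batches. */
final class AccuracyMetric {
    private long correct;
    private long total;

    /** Index of the largest score, i.e. the predicted class. */
    static int argMax(double[] scores) {
        int best = 0;
        for (int i = 1; i < scores.length; i++) {
            if (scores[i] > scores[best]) {
                best = i;
            }
        }
        return best;
    }

    /** Adds one batch: predictions[k] holds the scores for example k, labels[k] its true class. */
    void update(double[][] predictions, int[] labels) {
        for (int k = 0; k < labels.length; k++) {
            if (argMax(predictions[k]) == labels[k]) {
                correct++;
            }
            total++;
        }
    }

    /** Fraction of examples classified correctly since the last reset (0 if none seen). */
    double value() {
        return total == 0 ? 0.0 : (double) correct / total;
    }

    void reset() {
        correct = 0;
        total = 0;
    }
}

/** Hypothetical listener notified at the end of each epoch with the metric value. */
interface EpochListener {
    void onEpochEnd(int epoch, double accuracy);
}

/** Example listener that records the per-epoch accuracy history. */
final class HistoryListener implements EpochListener {
    final List<Double> history = new ArrayList<>();

    @Override
    public void onEpochEnd(int epoch, double accuracy) {
        history.add(accuracy);
    }
}
```

Keeping the metric separate from the listener lets the same accumulator be reused for the training and validation passes, with the listener deciding whether to log, plot, or store the values.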
At the end of the training, a well-performing validated model artifact is saved locally along with its properties using the model.save(Paths.get(modelParamsPath), modelParamsName) method. The metrics reported during the training process are shown below.
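The saved artifact holds the learned parameter values, while the network structure is rebuilt in code. This is consistent with the inference step calling Models.getModel(...) again before loading. The round trip can be sketched in plain Java with a hypothetical flat `double[]` of parameters and a made-up file format (a count followed by the values). DJL’s own parameter file format is different and is handled for you by `model.save` and `model.load`; `ParamStore` is a hypothetical name.

```java
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;

/** Hypothetical helper that persists a flat parameter vector and reads it back. */
final class ParamStore {
    private ParamStore() {}

    /** Writes the number of parameters, then each value, in big-endian binary form. */
    static void save(Path file, double[] params) throws IOException {
        try (DataOutputStream out = new DataOutputStream(
                new BufferedOutputStream(Files.newOutputStream(file)))) {
            out.writeInt(params.length);
            for (double p : params) {
                out.writeDouble(p);
            }
        }
    }

    /** Reads a vector written by save; the caller must rebuild the network to know what each value means. */
    static double[] load(Path file) throws IOException {
        try (DataInputStream in = new DataInputStream(
                new BufferedInputStream(Files.newInputStream(file)))) {
            double[] params = new double[in.readInt()];
            for (int i = 0; i < params.length; i++) {
                params[i] = in.readDouble();
            }
            return params;
        }
    }
}
```

Because only numbers are stored, loading is meaningful only after the same architecture has been constructed, which is why the empty model is created first and then filled with the saved parameters.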
Now that I have a model, I can use it to perform inference (or prediction) on new data for which I do not know the classification (or target). After setting the necessary paths to the model and the image to be classified, I obtain an empty model instance using the Models.getModel(NUM_OF_OUTPUT, NEW_HEIGHT, NEW_WIDTH) method and initialize it using the model.load(Paths.get(modelParamsPath), modelParamsName) method. This loads the model I trained in the previous step. Next, I initialize a Predictor with a specified Translator using the model.newPredictor(translator) method. In DJL terms, a Translator provides the model’s pre-processing and post-processing functionality. For example, with CV models, input images typically need to be resized to the shape the model was trained on and converted into tensors; a Translator can do this for you. The Predictor allows me to perform inference on the loaded Model using the predictor.predict(img) method, passing in the image to classify. I’m doing a single prediction, but DJL also supports batch predictions. The inference is stored in predictResult, which contains the probability estimate per label. Because Model and Predictor are AutoCloseable, opening them in a try-with-resources block closes them automatically once inference completes, which releases their native memory promptly.
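A Translator sits on both sides of the model. Pre-processing turns the raw image into the input the network expects, and post-processing turns raw network outputs into per-label probabilities. The hypothetical plain-Java sketch below shows both halves. For brevity it assumes a single-channel (grayscale) image stored as a 2-D array of 0-255 values, nearest-neighbour resizing to the training size, scaling to [0, 1], and softmax over the four footwear labels in an assumed order. DJL’s own `Translator` interface and built-in image translators work on `Image` and `NDList` objects and are not reproduced here; `FootwearTranslatorSketch` is a made-up name.

```java
import java.util.Arrays;
import java.util.List;

/** Hypothetical translator sketch: pre-processes an image and post-processes the network output. */
final class FootwearTranslatorSketch {
    // Label order is an assumption for this sketch.
    static final List<String> LABELS = Arrays.asList("boots", "sandals", "shoes", "slippers");

    private final int newHeight;
    private final int newWidth;

    FootwearTranslatorSketch(int newHeight, int newWidth) {
        this.newHeight = newHeight;
        this.newWidth = newWidth;
    }

    /** Nearest-neighbour resize of 0-255 grayscale pixels, flattened row by row and scaled to [0, 1]. */
    double[] preprocess(int[][] pixels) {
        int srcHeight = pixels.length;
        int srcWidth = pixels[0].length;
        double[] input = new double[newHeight * newWidth];
        for (int y = 0; y < newHeight; y++) {
            int srcY = y * srcHeight / newHeight;
            for (int x = 0; x < newWidth; x++) {
                int srcX = x * srcWidth / newWidth;
                input[y * newWidth + x] = pixels[srcY][srcX] / 255.0;
            }
        }
        return input;
    }

    /** Softmax: converts raw scores into probabilities that sum to 1, one per label. */
    double[] postprocess(double[] logits) {
        double max = logits[0];
        for (double z : logits) {
            max = Math.max(max, z);
        }
        double sum = 0.0;
        double[] probs = new double[logits.length];
        for (int i = 0; i < logits.length; i++) {
            probs[i] = Math.exp(logits[i] - max);
            sum += probs[i];
        }
        for (int i = 0; i < probs.length; i++) {
            probs[i] /= sum;
        }
        return probs;
    }

    /** Label with the highest probability. */
    static String bestLabel(double[] probs) {
        int best = 0;
        for (int i = 1; i < probs.length; i++) {
            if (probs[i] > probs[best]) {
                best = i;
            }
        }
        return LABELS.get(best);
    }
}
```

The probabilities returned by `postprocess` correspond to the per-label estimates described above, and `bestLabel` picks the single most likely class.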
The inferences (per image) are shown below with their corresponding probability scores.
Takeaways & Next Steps
I’ve been developing Java-based applications since the late ’90s and started my machine learning journey in 2017. My journey would’ve been much easier had DJL been around back then. I highly recommend that Java developers looking to transition to machine learning give DJL a try. In my example, I developed the footwear classification model from scratch; however, DJL also allows developers to deploy pre-trained models with minimal effort. DJL also comes with popular datasets out of the box so developers can get started with ML right away. Before starting with DJL, I recommend that you have a firm understanding of the ML lifecycle and are familiar with common ML terms. Once you have a basic understanding of ML, you can quickly come up to speed on DJL APIs.
Amazon has open-sourced DJL; more detailed information about the toolkit can be found on the DJL website and the Java Library API Specification page. The code for the footwear classification model can be found on GitLab. Good luck on your ML journey, and please feel free to reach out to me if you have any questions.