Designing A Modular Neural Network Library

A PyTorch-inspired API, but in C++

There are a lot of good open source neural network libraries out there today. Some of the most popular include TensorFlow, PyTorch, Keras, Caffe, and Theano. They are all robust and can accomplish pretty much anything you want in the realm of deep learning. So why would we write our own?

The first reason is simply to learn. I always gain a deeper understanding of a topic when I know how it works from the bottom up as well as from the top down. Most of these neural network libraries have already taken care of designing all the modules, implementing backpropagation, and releasing example models for you to play with. Understanding how these functions really work under the hood will help you implement, debug, and test new ideas faster. You will develop some intuition for why a network might not be learning, and you will be able to come up with ideas for new architectures that researchers may not have thought of yet. For more on why you should understand the details, I recommend reading Andrej Karpathy's post on why you should understand backprop.

We are choosing C++ for our library for the simple reason that it is what runs under the hood of most neural network libraries. Many of the popular ones have Python front ends but use C++ on the back end for performance. We will eventually speed up our implementations with BLAS and CUDA to squeeze out as much performance as possible. Implementing this library in C++ from scratch also gives you the knowledge and ability to put these networks on embedded systems or custom hardware if you so choose. In the end, our training loop will look very similar to one in Keras or PyTorch, just written in C++ instead of Python. Here is an example of what model creation will eventually look like:
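As a rough illustration only, model creation could rest on a small module interface, with a container that runs its layers in order. The names Module, Linear, ReLU, and Sequential, and every signature below, are assumptions made for this sketch and are not the final API of the series. Each layer maps a vector of floats to a vector of floats, and the Linear weights get a constant placeholder value instead of a real initialization scheme.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical base class: every layer maps an input vector to an output vector.
class Module {
 public:
  virtual ~Module() = default;
  virtual std::vector<float> forward(const std::vector<float>& x) = 0;
};

// Hypothetical fully connected layer: y = W x + b, W stored row-major (out x in).
// Weights start at a constant placeholder value, not a real init scheme.
class Linear : public Module {
 public:
  Linear(std::size_t in, std::size_t out, float init = 0.01f)
      : in_(in), out_(out), weights_(in * out, init), bias_(out, 0.0f) {}

  std::vector<float> forward(const std::vector<float>& x) override {
    assert(x.size() == in_ && "input size must match layer width");
    std::vector<float> y(bias_);  // start every output from its bias
    for (std::size_t o = 0; o < out_; ++o)
      for (std::size_t i = 0; i < in_; ++i)
        y[o] += weights_[o * in_ + i] * x[i];
    return y;
  }

 private:
  std::size_t in_, out_;
  std::vector<float> weights_, bias_;
};

// Element-wise max(0, x).
class ReLU : public Module {
 public:
  std::vector<float> forward(const std::vector<float>& x) override {
    std::vector<float> y(x);
    for (float& v : y) v = v > 0.0f ? v : 0.0f;
    return y;
  }
};

// Hypothetical container: feeds each layer's output into the next layer.
class Sequential : public Module {
 public:
  Sequential& add(std::unique_ptr<Module> m) {
    layers_.push_back(std::move(m));
    return *this;  // returning *this allows chained add() calls
  }

  std::vector<float> forward(const std::vector<float>& x) override {
    std::vector<float> out = x;
    for (auto& layer : layers_) out = layer->forward(out);
    return out;
  }

  std::size_t size() const { return layers_.size(); }

 private:
  std::vector<std::unique_ptr<Module>> layers_;
};
```

With these pieces, a model could be assembled with chained calls such as model.add(std::make_unique<Linear>(784, 128)).add(std::make_unique<ReLU>()). An MNIST image flattens to 28 x 28 = 784 inputs, but the hidden size of 128 is purely illustrative. Keeping every layer behind one virtual forward() is what lets the container stay ignorant of what each layer does.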

By the end of this series of posts, you will be able to train a neural network to recognize handwritten digits from the MNIST dataset in about 40 lines of C++ code. This is about the same amount of code you would need in your main function for any of the other frameworks I mentioned above. In fact, many of the models we create will be easily transferable to these other frameworks, since you will know the terminology. You will be able to follow our models line by line and port them to PyTorch or the framework of your choice if you prefer Python in the future.

So without further ado, let's hop into how we would design a succinct and extensible neural network API. When designing libraries, I like to start at a high level with the main function and imagine, as a user of the library, how I would like to use it. Of course, this may change as you get into the implementation, but it gives a good starting point for the high-level functions you want to expose. Let's aim for our main function to look something like the following.

First, we have a data loader that knows how to generate training examples from a data path on disk.
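One way such a data loader could hand out training examples is sketched below. Everything here is hypothetical: the Example, Batch, and DataLoader names, and the choice to keep examples in memory. Reading the files from disk is left out so the batching logic stands on its own. The sketch shows how fixed-size batches are sliced from the dataset, including a smaller final batch.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// One labelled example, e.g. a flattened 28x28 image and its digit label.
struct Example {
  std::vector<float> input;
  int label;
};

// A batch is a contiguous slice of examples.
struct Batch {
  std::vector<Example> examples;
};

// Hypothetical loader: holds examples in memory and hands them out in
// fixed-size batches. Loading from a data path on disk is not shown.
class DataLoader {
 public:
  DataLoader(std::vector<Example> data, std::size_t batch_size)
      : data_(std::move(data)), batch_size_(batch_size) {
    assert(batch_size_ > 0);
  }

  // Number of batches, counting a final partial batch.
  std::size_t num_batches() const {
    return (data_.size() + batch_size_ - 1) / batch_size_;
  }

  // Copy out batch i: examples [i * batch_size, min((i + 1) * batch_size, N)).
  Batch batch(std::size_t i) const {
    assert(i < num_batches());
    std::size_t begin = i * batch_size_;
    std::size_t end = std::min(begin + batch_size_, data_.size());
    return Batch{std::vector<Example>(
        data_.begin() + static_cast<std::ptrdiff_t>(begin),
        data_.begin() + static_cast<std::ptrdiff_t>(end))};
  }

 private:
  std::vector<Example> data_;
  std::size_t batch_size_;
};
```

A real loader would also shuffle the examples between epochs, so that batches do not arrive in the same order every pass.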

Then we construct a feedforward network built up from trainable modules of arbitrary size. In this case we have a stack of three fully connected linear layers with ReLU activation functions, and a softmax function at the end that turns the outputs into probabilities.
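The two nonlinear pieces of that stack are small enough to write out directly. Here is a sketch of ReLU and softmax as free functions over a vector of floats; the function names are illustrative. The softmax subtracts the largest input before exponentiating. This does not change the result mathematically, because the common factor cancels in the ratio, but it keeps exp() from overflowing on large inputs.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// ReLU: max(0, x) applied element-wise.
std::vector<float> relu(const std::vector<float>& x) {
  std::vector<float> y(x);
  for (float& v : y) v = std::max(v, 0.0f);
  return y;
}

// Softmax: exp(x_i) / sum_j exp(x_j).
// Subtracting max(x) first keeps every exponent <= 0, so exp() cannot overflow.
std::vector<float> softmax(const std::vector<float>& x) {
  assert(!x.empty());
  float max_val = *std::max_element(x.begin(), x.end());
  std::vector<float> out(x.size());
  float sum = 0.0f;
  for (std::size_t i = 0; i < x.size(); ++i) {
    out[i] = std::exp(x[i] - max_val);
    sum += out[i];
  }
  for (float& v : out) v /= sum;  // normalize so the outputs sum to 1
  return out;
}
```

Without the subtraction, an input such as {1000, 1000} would produce exp(1000) = inf and a meaningless result. With it, the output is exactly {0.5, 0.5}.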

Next we choose the loss function and the optimization technique. CrossEntropyLoss trains the network to output a probability for each class, penalizing it when the correct class receives a low probability. Stochastic gradient descent updates the model's weights at each training step, moving them toward a local minimum of the loss.
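Both ideas fit in a few lines each. The sketch below assumes the loss is computed on the probabilities produced by the final softmax. For a single example with true class y, cross-entropy is -log(p_y). When softmax and cross-entropy are combined, the gradient with respect to the softmax inputs (the logits) simplifies to p minus the one-hot vector for y. A plain SGD step then moves each weight against its gradient, scaled by a learning rate. The function names are illustrative, and the clamp value 1e-12 is an arbitrary choice for this sketch.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Cross-entropy for one example: -log(p[label]), where probs is the softmax
// output. The probability is clamped so log(0) is never evaluated.
float cross_entropy(const std::vector<float>& probs, std::size_t label) {
  assert(label < probs.size());
  return -std::log(std::max(probs[label], 1e-12f));
}

// Gradient of the loss with respect to the logits (the softmax inputs).
// For softmax followed by cross-entropy this simplifies to p - onehot(label).
std::vector<float> softmax_cross_entropy_grad(const std::vector<float>& probs,
                                              std::size_t label) {
  assert(label < probs.size());
  std::vector<float> grad(probs);
  grad[label] -= 1.0f;
  return grad;
}

// Plain SGD update: w <- w - lr * dL/dw, applied element-wise.
void sgd_step(std::vector<float>& weights, const std::vector<float>& grads,
              float lr) {
  assert(weights.size() == grads.size());
  for (std::size_t i = 0; i < weights.size(); ++i) weights[i] -= lr * grads[i];
}
```

For example, if the network assigns probability 0.5 to the correct class, the loss is -log(0.5), which is about 0.693. The gradient entry for that class is 0.5 - 1 = -0.5, so the update pushes that class's logit up.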

Finally, we loop over the dataset using our data loader and train the model on batches of data. The model keeps track of statistics about how well it is performing so that we can monitor the learning process.
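The bookkeeping side of that loop can be sketched independently of any particular model. In this hypothetical sketch, each training step reports the summed loss, the number of correct predictions, and the batch size. An EpochStats accumulator turns those reports into a mean loss and an accuracy for the pass. The StepResult, EpochStats, and train_epoch names are assumptions, and the step callback stands in for the forward pass, loss, backward pass, and optimizer update.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>

// What one training step reports back about its batch.
struct StepResult {
  float loss_sum;       // sum of per-example losses in the batch
  std::size_t correct;  // examples whose most likely class matched the label
  std::size_t count;    // examples in the batch
};

// Running statistics for one pass (epoch) over the data.
struct EpochStats {
  double loss_sum = 0.0;
  std::size_t correct = 0;
  std::size_t count = 0;

  void update(const StepResult& r) {
    loss_sum += r.loss_sum;
    correct += r.correct;
    count += r.count;
  }
  double mean_loss() const { return count ? loss_sum / count : 0.0; }
  double accuracy() const {
    return count ? static_cast<double>(correct) / count : 0.0;
  }
};

// Hypothetical loop skeleton: `step` would run forward, loss, backward and
// the optimizer update for batch i, then report what it observed.
EpochStats train_epoch(std::size_t num_batches,
                       const std::function<StepResult(std::size_t)>& step) {
  EpochStats stats;
  for (std::size_t i = 0; i < num_batches; ++i) stats.update(step(i));
  return stats;
}
```

Summing losses and counts, and dividing only at the end, keeps the averages correct even when the last batch is smaller than the others.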

There may be terms in this code you are unfamiliar with, but don't worry, we will go over what everything means in detail as we implement them. Hopefully this gives you a sense of the general structure we are going for.

We are going to start from the top and work our way down to make this happen. The very first thing we define in this main function is a data loader, which will load handwritten digits into a data structure called a tensor. Tensors are the building blocks of many deep learning libraries, so they are a great place to start. Follow me to the next post to learn more about them and to start designing a Tensor class that can handle much of the data storage we will need.
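As a preview of the idea only, here is a common way to lay out a tensor: a shape plus one flat buffer of values in row-major order, where the last axis varies fastest. This is a minimal sketch under that assumption and not the Tensor class the next post designs. A batch of MNIST images could then have the shape {batch, 28, 28}.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <numeric>
#include <utility>
#include <vector>

// Minimal sketch: an n-dimensional array stored as one flat, row-major buffer.
class Tensor {
 public:
  explicit Tensor(std::vector<std::size_t> shape, float fill = 0.0f)
      : shape_(std::move(shape)),
        // Total element count is the product of all dimensions.
        data_(std::accumulate(shape_.begin(), shape_.end(), std::size_t{1},
                              std::multiplies<std::size_t>()),
              fill) {}

  const std::vector<std::size_t>& shape() const { return shape_; }
  std::size_t size() const { return data_.size(); }

  // Convert a multi-index to a flat offset; the last axis varies fastest.
  float& at(const std::vector<std::size_t>& idx) {
    assert(idx.size() == shape_.size());
    std::size_t offset = 0;
    for (std::size_t d = 0; d < shape_.size(); ++d) {
      assert(idx[d] < shape_[d]);
      offset = offset * shape_[d] + idx[d];
    }
    return data_[offset];
  }

 private:
  std::vector<std::size_t> shape_;  // declared first: data_ is sized from it
  std::vector<float> data_;
};
```

For a tensor of shape {2, 3, 4}, the index {1, 2, 3} maps to offset ((1 * 3) + 2) * 4 + 3 = 23, the last of its 24 elements. Keeping the data in one contiguous buffer is also what makes it easy to hand off to BLAS or CUDA later.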
