November 24, 2021

The beauty of efficient neural networks


State-of-the-art (SotA) machine learning models are growing so large that hardware can’t keep up. What are the symptoms of this growing problem, and is there an antidote?
In this blog post, we’ll start by discussing what makes a model efficient and why you would want to use efficient models. Then we’ll take a deep dive into the techniques you can use to end up with smaller, faster and more sustainable machine learning models.


Efficiency and machine learning models

A machine learning model is more efficient if we can achieve one or more of the following:

  • reduce the memory footprint
  • decrease the amount of compute needed

without a significant drop in performance. The reduction in memory footprint or compute can happen at training time, inference time, or both. Note that these improvements don’t always go hand in hand: a model might, for example, need more compute to train but take up less memory and be faster at inference time.

Why do we need it?

A first reason follows from the growth in the size of SotA machine learning models. Comparing GPT with DeepSpeed, we see a 10,000-fold increase in the number of parameters in only two years.

Size of SotA NLP models through time (visuals from EMNLP tutorial on high-performance NLP)

Models are growing in size so fast that hardware can’t keep up, and this has a number of drawbacks. For one, it raises the financial barrier to using these models, which makes them less accessible to people with regular hardware. A recent poll on Twitter suggests that even most Ph.D. researchers only have basic hardware at their disposal.

Furthermore, while you may have very powerful GPUs available when you train your model, in production you’ll often be constrained by the limitations of the target hardware. This is especially important when the goal is to deploy the model on-premise or on edge devices.

There are multiple reasons to invest time into optimizing your model for efficiency. For example, when your model is too large for the target hardware and doesn’t fit in memory, you can use model compression techniques to reduce the size of your weights file. This is especially useful for on-premise and edge setups: if you can successfully compress your model, you can keep using the state of the art without any hardware upgrades. Another use case is speeding up inference. Most efficient-model techniques reduce the number of computations needed to make a prediction or make those computations less expensive. Lastly, efficient models can consume significantly less energy, which lowers the cost of long-running models in the cloud and can notably improve battery life on edge devices.


Pruning

Pruning is a family of methods that removes redundant parts of a network while only minimally impacting performance.

Pruning is a pretty old concept: it was already used in the ’80s to reduce the size of decision trees. So it might not come as a surprise that many different types of pruning exist. They can be differentiated by:

  • The pruning criteria, i.e. how do we decide what to prune?

The selection can be based on any simple or complex function of the weights, the activations, the gradients, or a combination of them. A lot of research has been done on this topic in recent years, and Uber recently released a nice summary paper on it. However, let’s not get lost in the details: in general, the most popular pruning criterion is simply to remove the weights with the smallest absolute value (magnitude pruning).
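To make the magnitude criterion concrete, here is a minimal NumPy sketch. The function name `magnitude_prune` and its `sparsity` parameter are our own, not a specific library’s API: it zeroes out the requested fraction of weights with the smallest absolute value and leaves the tensor’s shape untouched.

```python
import numpy as np

def magnitude_prune(weights, sparsity):
    """Return a copy of `weights` where the `sparsity` fraction of entries with the
    smallest absolute value is set to 0 (unstructured magnitude pruning)."""
    k = int(round(sparsity * weights.size))      # number of weights to remove
    if k == 0:
        return weights.copy()
    flat = np.abs(weights).ravel()
    smallest = np.argpartition(flat, k - 1)[:k]  # indices of the k smallest |w|
    mask = np.ones(weights.size, dtype=bool)
    mask[smallest] = False
    return np.where(mask.reshape(weights.shape), weights, 0.0)

print(magnitude_prune(np.array([[0.5, -0.1], [0.05, -2.0]]), 0.5).tolist())
# [[0.5, 0.0], [0.0, -2.0]]
```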

  • What is pruned, e.g. connections, channels or entire layers.

Connection, channel and layer pruning refer to the smallest unit that can be pruned in a single step based on the pruning criterion. Choosing between them is a trade-off between size reduction and speed improvement. With connection pruning, you can remove more connections without a significant drop in accuracy, because you only remove the individual connections that contribute least to a prediction. With a more structured approach such as channel or layer pruning, you remove the least important channels or layers as a whole. These inevitably also contain some more important connections, so you’ll be able to prune fewer connections before accuracy drops significantly.

On the other hand, connection pruning only introduces a lot of zeros into the weight matrices. Since the multiplications still have to be carried out (at least with standard dense kernels), it barely improves inference speed. With channel pruning, the computations for entire channels can be skipped, and with layer pruning even entire steps of the forward pass. The latter two methods therefore have a much larger impact on inference time.
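The difference can be illustrated with a small NumPy sketch of channel pruning for a dense layer `y = x @ weight.T + bias`. Ranking output channels by the L2 norm of their weights is one common choice that we assume here; the helper name `prune_output_channels` is hypothetical. Unlike connection pruning, the result is a genuinely smaller dense matrix.

```python
import numpy as np

def prune_output_channels(weight, bias, keep_ratio):
    """Structured pruning of a dense layer y = x @ weight.T + bias.

    weight has shape (out_features, in_features); each row is one output channel.
    Whole rows with the smallest L2 norm are removed, so the pruned layer is a
    smaller dense matrix and its matrix multiplication is genuinely cheaper.
    """
    n_keep = max(1, int(round(keep_ratio * weight.shape[0])))
    norms = np.linalg.norm(weight, axis=1)
    keep = np.sort(np.argsort(norms)[-n_keep:])  # strongest channels, original order
    return weight[keep], bias[keep], keep

rng = np.random.default_rng(0)
W, b = rng.normal(size=(64, 32)), rng.normal(size=64)
W_small, b_small, kept = prune_output_channels(W, b, keep_ratio=0.25)
print(W.shape, "->", W_small.shape)  # (64, 32) -> (16, 32)
```

In a real network, the next layer’s input columns that correspond to the removed channels have to be dropped as well, so the shapes keep matching.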

  • When the pruning happens, e.g. before or after training.

The standard approach is to first train a large neural network and then prune away unnecessary components based on the results of the training process. The holy grail of pruning, however, is to find these smaller network structures before training, as this can significantly reduce training costs. This idea is central to the lottery ticket hypothesis, which has been a hot research topic since its introduction in 2018. However, a generalizable and easily applicable technique has yet to be developed, so for the time being pruning after training is still the way to go.

  • How many times we prune.

Here the choice is between pruning once (one-shot) or iteratively. In general, the iterative approach with a fine-tuning step in between yields the best results, as can be seen in the figure below. However, there’s no free lunch: this approach also takes the most time and compute.

Pruning results from this paper, with vs without retraining and one-shot vs iterative.
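An iterative schedule can be sketched in NumPy as follows, assuming magnitude pruning with sparsity raised in equal increments. The `fine_tune` callback is a hypothetical stand-in for a few epochs of retraining; it receives the current weights and mask and returns updated weights.

```python
import numpy as np

def iterative_prune(weights, target_sparsity, steps, fine_tune):
    """Iterative magnitude pruning: raise sparsity in `steps` equal increments,
    fine-tuning the surviving weights after every pruning round (steps >= 1)."""
    for step in range(1, steps + 1):
        k = int(round(target_sparsity * step / steps * weights.size))
        # Rank all weights by |w|; weights pruned in earlier rounds are 0 and stay pruned.
        order = np.argsort(np.abs(weights), axis=None)
        mask = np.ones(weights.size, dtype=bool)
        mask[order[:k]] = False
        mask = mask.reshape(weights.shape)
        weights = np.where(mask, weights, 0.0)
        # Hypothetical fine-tuning callback; re-apply the mask so pruned weights stay 0.
        weights = np.where(mask, fine_tune(weights, mask), 0.0)
    return weights, mask

rng = np.random.default_rng(0)
w0 = rng.normal(size=(50, 40))
# Placeholder "fine-tuning" that just nudges the weights, for illustration only.
pruned, mask = iterative_prune(w0, target_sparsity=0.8, steps=4,
                               fine_tune=lambda w, m: w + 0.01)
```

With `steps=1` this reduces to one-shot pruning followed by a single fine-tuning round.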

Practical advice when pruning

So what if you want to try pruning yourself? That depends on whether you’re using TensorFlow or PyTorch.
At the time of writing, TensorFlow only supports one kind of pruning: connection pruning. It does give you the option to prune only certain types of layers, for example only the dense layers. Other pruning methods are on the roadmap, but for now none of the available methods will help you speed up your model.
PyTorch, however, supports both structured and unstructured pruning and provides an easily extensible pruning class that lets you implement your own pruning methods if you ever want to. On top of the standard pruning methods that come with PyTorch, multiple open-source pruning toolkits provide even more options. So with respect to pruning, PyTorch takes the cake.
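As a minimal sketch of PyTorch’s built-in pruning utilities (this assumes PyTorch is installed; the layer sizes and pruning amounts are arbitrary), here is connection pruning on one layer and channel pruning on another:

```python
import torch.nn as nn
import torch.nn.utils.prune as prune

# Connection (unstructured) pruning: zero the 30% of weights with the smallest |w|.
layer_a = nn.Linear(in_features=10, out_features=20)
prune.l1_unstructured(layer_a, name="weight", amount=0.3)

# Channel (structured) pruning: zero the 50% of output rows (dim=0) with the smallest L2 norm.
layer_b = nn.Linear(in_features=10, out_features=20)
prune.ln_structured(layer_b, name="weight", amount=0.5, n=2, dim=0)

# While pruning is active, `weight` is recomputed as weight_orig * weight_mask;
# prune.remove() makes the zeros permanent and drops the reparametrization.
for layer in (layer_a, layer_b):
    prune.remove(layer, "weight")

print(int((layer_a.weight == 0).sum()))                   # 60 of 200 weights are zero
print(int((layer_b.weight.abs().sum(dim=1) == 0).sum()))  # 10 of 20 rows are zero
print(tuple(layer_b.weight.shape))                        # (20, 10): shape is unchanged
```

Note that these utilities apply masks: the weight tensor keeps its original shape. To actually get a speed-up from channel pruning, the zeroed rows still have to be physically removed (and the next layer adjusted accordingly).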


Quantization

A second well-known and really powerful technique is quantization. The basic idea is surprisingly simple: reduce the precision of the numbers used to represent the weights of a model. For example, if a model’s weights are stored as 32-bit floats, converting them to 8-bit integers already gives a 4x reduction in size, and often at least a 50% speed improvement. That’s quite surprising if you consider that 32 bits can represent 2²⁴ times as many values as 8 bits. But apparently ML models don’t care that much, since most of the time we can achieve similar performance.
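The float32-to-int8 conversion can be sketched in a few lines of NumPy. We assume asymmetric (affine) uniform quantization with a scale and a zero point, which is one common scheme among several; the function names are our own.

```python
import numpy as np

def quantize_uint8(w):
    """Uniform (affine) quantization of a float32 array to 8-bit unsigned integers."""
    w_min, w_max = min(float(w.min()), 0.0), max(float(w.max()), 0.0)  # range includes 0.0
    scale = (w_max - w_min) / 255.0 if w_max > w_min else 1.0          # step between levels
    zero_point = int(round(-w_min / scale))                            # integer that encodes 0.0
    q = np.clip(np.round(w / scale) + zero_point, 0, 255).astype(np.uint8)
    return q, scale, zero_point

def dequantize(q, scale, zero_point):
    """Map the 8-bit codes back to (approximate) float32 weights."""
    return (q.astype(np.float32) - zero_point) * scale

rng = np.random.default_rng(0)
w = rng.normal(size=(256, 256)).astype(np.float32)
q, scale, zp = quantize_uint8(w)
print(w.nbytes, q.nbytes)  # 262144 65536
```

Each weight is reconstructed with an error of at most half a quantization step (`scale / 2`), which is the precision the model trades for the 4x smaller footprint.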

A distinction can be made between uniform and non-uniform quantization. Uniform means that the quantization levels are equally spaced; non-uniform means they’re unequally spaced.
Uniform quantization is the easier of the two: you only have one degree of freedom, namely how many bits you’re going to use to represent the weights after quantization.
With non-uniform quantization, however, you have the extra degree of freedom of how to divide the quantization levels.
An interesting approach that has been successful in research is to cluster the weight values (e.g. with k-means) and place the quantization levels at the cluster centroids. This has the benefit that lots of connections share the same weight, which allows for a large compression of the original network. However, due to some optimization issues, more research is still needed before this approach becomes applicable in the wild.
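The clustering idea can be sketched with a plain 1-D k-means over the weight values; the algorithm choice, the linear initialisation and the helper name `kmeans_weight_sharing` are our assumptions for illustration. Only a small codebook of shared values plus a per-weight index has to be stored.

```python
import numpy as np

def kmeans_weight_sharing(weights, n_clusters=16, iters=20):
    """Non-uniform quantization by clustering weight values with 1-D k-means.
    Returns a codebook of shared values and, per weight, the index of its value."""
    flat = weights.ravel()
    # Linear initialisation: start from evenly spaced (uniform) levels.
    centroids = np.linspace(flat.min(), flat.max(), n_clusters)
    for _ in range(iters):
        # Assign every weight to its nearest centroid.
        idx = np.abs(flat[:, None] - centroids[None, :]).argmin(axis=1)
        # Move each centroid to the mean of its members (empty clusters keep their value).
        for c in range(n_clusters):
            members = flat[idx == c]
            if members.size:
                centroids[c] = members.mean()
    idx = np.abs(flat[:, None] - centroids[None, :]).argmin(axis=1)
    return centroids.astype(np.float32), idx.astype(np.uint8).reshape(weights.shape)

rng = np.random.default_rng(0)
w = rng.normal(size=(100, 100)).astype(np.float32)
codebook, idx = kmeans_weight_sharing(w, n_clusters=16)
w_shared = codebook[idx]  # every weight now takes one of 16 shared values
```

With 16 clusters each index needs only 4 bits (stored here in a `uint8` for simplicity). Because k-means starts from the uniform levels and never increases the squared error, the learned non-uniform levels fit bell-shaped weight distributions at least as well as uniform ones.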

Another smart trick is quantization-aware training, which simulates inference-time quantization in the forward passes during training. This introduces a quantization error that becomes part of the model’s total loss, and the optimizer reduces it by adjusting the parameters accordingly. This makes the model more robust to quantization later on.
Both TensorFlow and PyTorch support quantization and quantization aware training.
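The mechanism can be sketched in NumPy on a toy linear regression (the problem, the 4-bit symmetric scheme and all names are our assumptions). The forward pass uses "fake-quantized" weights, and the gradient is passed straight through the rounding step to full-precision shadow weights, a common approach known as the straight-through estimator.

```python
import numpy as np

def fake_quantize(w, n_bits=4):
    """Symmetric uniform quantize-then-dequantize: the result is still a float array,
    but it only takes values on a (2**n_bits - 1)-level grid."""
    q_max = 2 ** (n_bits - 1) - 1  # 7 for 4 bits
    max_abs = np.abs(w).max()
    scale = max_abs / q_max if max_abs > 0 else 1.0
    return np.clip(np.round(w / scale), -q_max, q_max) * scale

# Toy problem (assumed for illustration): fit y = x @ w_true with a linear model.
rng = np.random.default_rng(0)
x = rng.normal(size=(256, 8))
w_true = rng.normal(size=8)
y = x @ w_true

w = np.zeros(8)  # full-precision "shadow" weights updated by the optimizer
lr = 0.1
losses = []
for step in range(200):
    w_q = fake_quantize(w)         # forward pass sees the quantized weights,
    err = x @ w_q - y              # so the quantization error ends up in the loss
    losses.append(np.mean(err ** 2))
    grad = 2 * x.T @ err / len(y)  # straight-through estimator: treat round() as the
    w -= lr * grad                 # identity and apply the gradient to the float weights

w_deploy = fake_quantize(w)        # the weights you would actually ship
```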

Knowledge distillation

One of the central observations behind knowledge distillation is that there’s a difference in requirements during training and inference time. In many cases, the goal at training time is to learn and extract structure from large amounts of data. This is where large models, the “teacher”, with a lot of capacity are able to shine. At inference time however, these large models are cumbersome to work with (due to the typically more stringent latency and memory requirements) and smaller models, the “student”, would be much more suitable.

In knowledge distillation, the goal is to combine the best of both worlds. In a first step, a large model is trained on a large amount of data. In a second step, the knowledge captured by the large model is “distilled” into a much smaller model. But what do we mean by knowledge? And how can we “distill” this knowledge into a smaller model?

When talking about knowledge in the context of knowledge distillation, we mean the learned relationship between input and output vectors. This relationship can tell us a lot about the way a model generalizes to unseen data (one of the reasons why large models typically perform better than smaller ones is that they generalize better).
Take classification, for example. Although the output probability for the target class will typically be by far the highest, the output probabilities for the other classes still contain a lot of information. It’s exactly the relative differences between the small probabilities of the incorrect classes that reveal a lot about how the model generalizes.

Finally, the knowledge learned by the teacher has to be transferred to the student. This is done by adding an extra term, called the distillation loss, to the usual loss function. Again taking classification as an example, we’re not only interested in the hard label (i.e. the one-hot encoded vector representing the class) but also in the soft label generated by the teacher. To magnify the small differences between the probabilities of the incorrect classes, a temperature parameter is added to the final softmax layer: the higher the temperature, the smoother the output distribution.

Knowledge distillation makes use of an extra distillation loss term in the loss function (no dogs were harmed during distillation).
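A NumPy sketch of such a combined loss follows. We assume a KL-divergence distillation term between the temperature-softened teacher and student outputs, scaled by the squared temperature as suggested by Hinton et al. (2015), and mixed with the usual cross-entropy through a weight `alpha`. The values `temperature=4.0` and `alpha=0.5` are arbitrary illustrative choices.

```python
import numpy as np

def softmax(logits, temperature=1.0):
    z = logits / temperature
    z = z - z.max(axis=-1, keepdims=True)  # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def distillation_loss(student_logits, teacher_logits, hard_labels,
                      temperature=4.0, alpha=0.5):
    """alpha * distillation loss (soft labels) + (1 - alpha) * usual loss (hard labels)."""
    n = len(hard_labels)
    # Usual loss: cross-entropy between the student's predictions (T = 1) and the hard labels.
    p_student = softmax(student_logits)
    hard_loss = -np.mean(np.log(p_student[np.arange(n), hard_labels]))
    # Distillation loss: KL divergence between the teacher's and the student's softened outputs.
    p_teacher_t = softmax(teacher_logits, temperature)
    p_student_t = softmax(student_logits, temperature)
    kl = np.sum(p_teacher_t * (np.log(p_teacher_t) - np.log(p_student_t)), axis=-1).mean()
    # Scaling by T**2 keeps the soft-label gradients on a scale comparable to the hard-label ones.
    return alpha * temperature ** 2 * kl + (1 - alpha) * hard_loss

rng = np.random.default_rng(0)
teacher = rng.normal(size=(4, 10)) * 3  # hypothetical teacher logits: 4 samples, 10 classes
student = rng.normal(size=(4, 10))      # hypothetical student logits
labels = np.array([1, 2, 3, 4])
loss = distillation_loss(student, teacher, labels)
```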

Smart optimization tricks

The previous approaches all have in common that they change the network somehow (remove some parts, reduce the precision of the weights or even train a smaller one altogether). The two techniques discussed below don’t really change the network itself, but they do make the optimization procedure more efficient in terms of the amount of memory that is required at training time.

Gradient checkpointing

In gradient checkpointing, we are basically reducing the amount of memory needed at the cost of an increase in compute needed. This trade-off is especially useful since it’s typically much easier to wait a bit longer for your model to be trained than to get some more VRAM.
Gradient checkpointing works by not keeping all activations in memory all the time, but only storing them at certain checkpoints. Whenever other activations are needed again (e.g. when calculating gradients in the backward pass), you simply recompute them starting from the closest checkpoint.

Without gradient checkpointing, all activations are kept in memory (visualisation from this blogpost).
With gradient checkpointing, activations are only stored at certain checkpoints and are recalculated when necessary (visualisation from this blogpost).
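To make the trade-off concrete, here is a toy NumPy sketch, assuming a chain of layers that each compute `tanh(w * x)` with a scalar weight (purely illustrative). In practice frameworks do this for you, e.g. `torch.utils.checkpoint` in PyTorch. Both functions compute the same gradient, but the checkpointed version only keeps every `every`-th layer input and recomputes the rest segment by segment.

```python
import numpy as np

def layer(w, x):
    """One toy layer: elementwise tanh(w * x) with a scalar weight w."""
    return np.tanh(w * x)

def layer_backward(w, x, upstream):
    """Chain rule through one layer, given that layer's input x."""
    return upstream * w * (1.0 - np.tanh(w * x) ** 2)

def grad_without_checkpointing(ws, x):
    """Keep every layer input in memory, then backpropagate d(sum(output))/dx."""
    inputs = []
    for w in ws:
        inputs.append(x)
        x = layer(w, x)
    grad = np.ones_like(x)
    for w, inp in zip(ws[::-1], inputs[::-1]):
        grad = layer_backward(w, inp, grad)
    return grad, len(inputs)  # gradient, number of stored activations

def grad_with_checkpointing(ws, x, every):
    """Keep only every `every`-th layer input; recompute the others segment by segment."""
    checkpoints = {}
    for i, w in enumerate(ws):
        if i % every == 0:
            checkpoints[i] = x
        x = layer(w, x)
    grad = np.ones_like(x)
    for start in sorted(checkpoints, reverse=True):
        end = min(start + every, len(ws))
        # Extra compute: recompute this segment's layer inputs from its checkpoint.
        seg = [checkpoints[start]]
        for i in range(start, end - 1):
            seg.append(layer(ws[i], seg[-1]))
        # Backpropagate through the segment; its activations can be freed afterwards.
        for i in range(end - 1, start - 1, -1):
            grad = layer_backward(ws[i], seg[i - start], grad)
    return grad, len(checkpoints)

rng = np.random.default_rng(0)
ws = rng.uniform(0.5, 1.5, size=16)
x0 = rng.normal(size=8)
g_full, n_full = grad_without_checkpointing(ws, x0)
g_ckpt, n_ckpt = grad_with_checkpointing(ws, x0, every=4)
```

Here 16 stored activations shrink to 4 checkpoints plus one recomputed segment of at most 4 inputs at a time, at the price of roughly one extra forward pass.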

Gradient accumulation

If you’ve ever tried to fine-tune a relatively large transformer-based model (e.g. BERT, T5, …) on a modest GPU, you’ve without a doubt encountered your fair share of out-of-memory exceptions. The issue is that in many cases you want to fine-tune these models with larger batch sizes than fit in memory. Enter gradient accumulation. As the name implies, it accumulates gradients over a number of training steps before doing any weight update. This way, you can simulate larger batch sizes without needing the memory to hold them.

Gradients are accumulated over a number of training steps and weights are only updated afterwards, effectively simulating larger batch sizes using a smaller amount of memory.
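A small NumPy sketch shows why this works, assuming a mean-squared-error loss on a toy linear model and equally sized micro-batches. Averaging the gradients of four micro-batches of 8 gives exactly the gradient of one batch of 32, so the single weight update afterwards is the same.

```python
import numpy as np

def mse_grad(w, x, y):
    """Gradient of mean((x @ w - y) ** 2) with respect to w."""
    return 2 * x.T @ (x @ w - y) / len(y)

rng = np.random.default_rng(0)
x, y = rng.normal(size=(32, 5)), rng.normal(size=32)
w = np.zeros(5)

# Reference: the gradient of one full batch of 32 (pretend it doesn't fit in memory).
full_grad = mse_grad(w, x, y)

# Gradient accumulation: 4 micro-batches of 8, each gradient scaled by 1/4 and summed.
accum_steps = 4
accum = np.zeros_like(w)
for xb, yb in zip(np.split(x, accum_steps), np.split(y, accum_steps)):
    accum += mse_grad(w, xb, yb) / accum_steps

lr = 0.01
w_new = w - lr * accum  # a single weight update per 4 micro-batches
```

The equivalence relies on the loss being an average over samples; layers that compute batch statistics, such as batch normalization, still only see the smaller micro-batches.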

Beyond general techniques

All techniques discussed in the previous sections are (at least in theory) generally applicable to any neural network architecture, but additional efficiency gains can be obtained by focusing on specific architectures.
One very relevant example is the transformer architecture which has dominated the state of the art in many fields (especially NLP) for a few years now and keeps on pushing the limits in terms of model sizes.

This is all well and good if you work for Google or OpenAI, but for many people the hardware requirements are prohibitive. It should therefore not surprise you that a lot of research has focused on making the transformer architecture (and more specifically the self-attention mechanism) more efficient. An extensive overview of these techniques is out of scope for this blog post, but the talk on high-performance natural language processing at EMNLP 2020 is a great resource to dive into this topic.
