Speed up neural networks 3.3x with Sparse Weight Activation Training [Breakdowns]
Sparsity is one of the next frontiers in Deep Learning. Don’t sleep on it.
Hey, it’s Devansh 👋👋
In my series Breakdowns, I go through complicated literature on Machine Learning to extract the most valuable insights. Expect concise, jargon-free, but still useful analysis aimed at helping you understand the intricacies of Cutting Edge AI Research and the applications of Deep Learning at the highest level.
If you’d like to support my writing, please consider buying and rating my 1 Dollar Ebook on Amazon or becoming a premium subscriber to my sister publication Tech Made Simple using the button below.
p.s. you can learn more about the paid plan here.
A little bit ago, I covered Google AI’s Pathways architecture, calling it a revolution in Machine Learning. One of the standouts in Google’s novel approach was the implementation of sparse activation in their training architecture. I liked this idea so much that I decided to explore it in a lot more depth. That’s where I came across Sparse Weight Activation Training (SWAT), by researchers at the Department of Electrical and Computer Engineering, University of British Columbia. And the paper definitely has me excited.
For ResNet-50 on ImageNet, SWAT reduces total floating-point operations (FLOPS) during training by 80%, resulting in a 3.3× training speedup when run on a simulated sparse learning accelerator representative of emerging platforms, while incurring only a 1.63% reduction in validation accuracy. Moreover, SWAT reduces memory footprint during the backward pass by 23% to 50% for activations and 50% to 90% for weights.
In this article, I will be sharing some key insights from the paper. As we see ML operations scaling up, algorithms that exploit sparsity will be at the forefront of the cutting edge. You definitely don’t want to miss out.
Understanding the need for Sparse Activation
To understand why Sparse Activation and SWAT are so cool, think back to how Neural Networks work. When we train them, input flows through all the neurons, in both the forward and backward passes. This is why adding more parameters to a Neural Network keeps driving up the cost: training compute grows roughly in proportion to the parameter count, and at modern scales that adds up fast.
Adding more neurons to our network allows for our model to learn from more complex data (like data from multiple tasks and data from multiple senses). However, this adds a lot of computational overhead.
Sparse Activation allows for a best-of-both-worlds scenario. Adding a lot of parameters allows our model to learn more tasks effectively (and make deeper connections). Sparse Activation only fires the parts of the network relevant to a given input, so the network can learn and get good at multiple tasks without being too costly. The following video is more dedicated to that idea.
The concept kind of reminds me of a more modern twist on the Mixture of Experts learning protocol. Instead of deciphering which expert can handle the task best, we are instead routing the task to the part of the neural network that handles it best. This is similar to our brain, where different parts of our brain are good at different things. It’s not a coincidence that MoE is itself making a small comeback in large-scale model training. Delegation of tasks to smaller experts or sub-networks is an amazing way to balance scale and cost.
Now that you’re sold on the amazing world of sparsity, let’s dive into SWAT and what it does differently.
Breaking down SWAT
As the name suggests, SWAT sparsifies the weights and activations of the network. The process is relatively intuitive. It assumes that the values with the biggest magnitudes are the most important. By the 80–20 principle, we can keep only these important values and set the other, less influential values to 0, eliminating them from the computation.
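To make the Top-K idea concrete, here is a minimal PyTorch sketch of magnitude-based sparsification. This is my own illustration, not the authors’ code; the function name topk_sparsify and the flat, per-tensor Top-K are my assumptions.

```python
import torch

def topk_sparsify(tensor: torch.Tensor, sparsity: float) -> torch.Tensor:
    """Zero out everything except the largest-magnitude entries of `tensor`.

    `sparsity` is the fraction of entries to drop, e.g. 0.8 keeps the top 20%.
    """
    k = max(1, int(tensor.numel() * (1.0 - sparsity)))     # number of entries to keep
    threshold = torch.topk(tensor.abs().flatten(), k).values.min()
    mask = (tensor.abs() >= threshold).to(tensor.dtype)    # 1 for survivors, 0 otherwise
    return tensor * mask
```

At 80% sparsity, only the top 20% of values by magnitude survive; everything else contributes nothing to the computation.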
This is not a very hard algorithm to conceptualize. However, there are a few design choices to make when applying such sparsification. We can choose to drop either weights, activations, gradients (calculated during backpropagation), or some combination of them. The SWAT team conducted a sensitivity analysis, checking how convergence was affected by each of these choices.
Figure 2 is an interesting one. The difference between dropping gradients and dropping weights + activations is clear: the former wrecks your performance. The authors themselves point out this phenomenon- ‘The “sparse weight and activation” curve shows that convergence is relatively insensitive to applying Top-K sparsification. In contrast, the “sparse output gradient” curve shows that convergence is sensitive to applying Top-K sparsification to back-propagated error gradients (∇aₗ). The latter observation indicates that meProp, which drops back-propagated error-gradients, will suffer convergence issues on larger networks.’
This gives us a good starting point- ‘In the forward pass use sparse weights (but not activations) and in the backward pass use sparse weights and activations (but not gradients).’ And that, my lovely reader, is the basis of SWAT, explained simply.
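Here is a hedged sketch of that recipe for a single linear layer, reusing the topk_sparsify helper from above. This is my reading of the paper expressed as a custom autograd function; the class name and the exact placement of the sparsification calls are simplifications, not the authors’ implementation.

```python
import torch

class SWATLinear(torch.autograd.Function):
    """Forward: sparse weights, dense activations.
    Backward: sparse weights and sparse activations, but dense (unsparsified) gradients."""

    @staticmethod
    def forward(ctx, x, weight, sparsity=0.8):
        w_sparse = topk_sparsify(weight, sparsity)          # sparse weights in the forward pass
        ctx.save_for_backward(topk_sparsify(x, sparsity),   # stash sparse activations for backward
                              w_sparse)
        return x @ w_sparse.t()                             # the activations themselves stay dense here

    @staticmethod
    def backward(ctx, grad_output):
        x_sparse, w_sparse = ctx.saved_tensors
        grad_x = grad_output @ w_sparse        # gradient w.r.t. input uses sparse weights
        grad_w = grad_output.t() @ x_sparse    # gradient w.r.t. weights uses sparse activations
        return grad_x, grad_w, None            # the back-propagated gradient is never sparsified
```

In use, you would call SWATLinear.apply(x, weight) in place of a plain matrix multiply while keeping weight itself dense. A nice side effect of saving only the sparsified activations for the backward pass is a smaller activation memory footprint, which lines up with the savings quoted earlier.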
There is another idea that I found extremely important. And that is how this team resurrects the dead. Yes, I’m completely serious.
Using Zombies and reviving Dead Neurons
During both forward and back propagation, only the most important weights are used for calculations. Most people would just stop here and ignore the dead neurons for the rest of training. This is how most network pruning happens. However, these researchers are also woke.
They update both the active and the dead weights with the dense gradient calculations (remember, we have already established that gradients should not be dropped). This adds a comeback mechanism of sorts for the previously dead weights: just because a weight is dead in one iteration doesn’t mean it won’t show up in a later one. This allows the network to explore network topologies (structures) dynamically.
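A minimal sketch of what that update could look like, again reusing the hypothetical topk_sparsify helper. This is my simplified rendition of the idea, not the paper’s algorithm listing: the gradient step is applied to a dense master copy of the weights, and the Top-K mask is recomputed afterwards, so a weight that was dropped in this iteration can climb back above the threshold in the next.

```python
import torch

def swat_weight_update(dense_weight, grad, lr=0.1, sparsity=0.8):
    """One plain-SGD update step. `dense_weight` is the master copy that never gets
    pruned; `grad` is the dense gradient coming out of the backward pass."""
    dense_weight = dense_weight - lr * grad                   # active AND dead weights are updated
    active_weight = topk_sparsify(dense_weight, sparsity)     # re-select the topology for the next pass
    return dense_weight, active_weight
```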
This allows the algorithm to perform a beautiful balancing act. Pruning and dropout can be used to stop overfitting, improve generalization, and reduce the costs of training. However, reducing connectivity is tricky and tends to increase training loss. Especially if done wrong. This approach has the same effect as removing layers/neurons but dynamically updates to find the best configuration. Below is a full description of the algorithm-
Now for the final bit, let’s evaluate the results of SWAT on a bunch of tasks. Speedups and memory efficiency are useless if they aren’t backed by strong performance.
SWAT on tasks
The first graphic to look at compares the drop in accuracy against the reduction in training time/cost. If SWAT can reduce costs while keeping reasonable performance, it will get a pass. SWAT is compared to its competitors twice, using sparsities of both 80% and 90%.
The performance is quite impressive. Given that most of the baselines are already very good, the slight drop in accuracy at 80% (or even 90%) sparsity is not a huge concern. The 8x reduction that SWAT-U achieves at 90% sparsity is also pretty exciting and makes a case for this algorithm to be explored further. Next, let’s look at some raw numbers. Take a look at the following analysis from the authors-
For those of you curious, these are the aforementioned tables-
These results are quite spectacular. However, there is still a lot left to explore. I would’ve liked to see more comparisons, covering multiple kinds of tasks and policies. Given the utility of Data Augmentation in vision tasks these days, it would be interesting to see how sparsity plays a role there. What about tasks like generation, segmentation, etc.? I think there are many areas where we can test out SWAT.
That being said, this paper is an amazing first step. The authors have established a pretty exciting algorithm and I will definitely be looking into this further. If you have any experience with SWAT or other sparsity-oriented algorithms/procedures share them with me. I’m definitely looking to learn a lot more.
That is it for this piece. I appreciate your time. As always, if you’re interested in reaching out to me or checking out my other work, links will be at the end of this email/post. If you like my writing, I would really appreciate an anonymous testimonial. You can drop it here. And if you found value in this write-up, I would appreciate you sharing it with more people. It is word-of-mouth referrals like yours that help me grow.
Reach out to me
Use the links below to check out my other content, learn more about tutoring, reach out to me about projects, or just to say hi.
Small Snippets about Tech, AI and Machine Learning over here
AI Newsletter- https://artificialintelligencemadesimple.substack.com/
My grandma’s favorite Tech Newsletter- https://codinginterviewsmadesimple.substack.com/
Check out my other articles on Medium. : https://rb.gy/zn1aiu
My YouTube: https://rb.gy/88iwdd
Reach out to me on LinkedIn. Let’s connect: https://rb.gy/m5ok2y
My Instagram: https://rb.gy/gmvuy9
My Twitter: https://twitter.com/Machine01776819
"This is why adding more parameters to a Neural Network adds to the cost exponentially."
Nonsense. Training cost is linear -- not even quadratic -- in the parameter count.
I know: nontechnical people throw around the term "exponentially" as if it only meant "a lot". But it has a very specific technical meaning, and we computer scientists should use it only when that meaning is applicable.
In fact, parameter counts themselves have been growing exponentially. Training these large networks would not be possible if the cost were even quadratic, never mind exponential, in the parameter count.