The curious case of Spike Function
At the core of neuromorphic computing are spiking neurons as the computational unit. It turns out that deep networks of spiking neurons can be trained using the workhorse of modern AI, deep learning’s backpropagation algorithm, to provide workloads that are natively suited to exploit the efficiency and latency advantages of neuromorphic chips.
In this article, I will describe one of the key challenges in training deep spiking neural networks (SNNs): the spike function and its derivative, how the ill-defined derivative of the spike function hinders backpropagation training, and the ways to circumvent this problem. Let's take a tour of the what, why, and how of the spike function and its derivative.
This blog focuses on one of the technical fundamentals of training deep SNNs. Understanding it gives a better intuition of what happens during deep SNN training, but it is not absolutely necessary for getting started with training deep SNNs.
1. Brief introduction
What are spiking events?
Biological neurons communicate with a brief impulse of charge. We call these events spikes. Spikes are sparse events in time. The neurons only communicate information when needed and the brief impulse helps the spiking event stand out. This is how evolution has found a way to communicate effectively in an inherently noisy pool of ionic interaction. It is debatable whether these events carry magnitude or whether the precise timing of spikes is key. However, it is certain that sparse message passing is the currency of information exchange in the brain.
Formally, a spiking event can be represented as

$$s(t') = a\,\delta(t' - t)$$

where δ(⋅) is the Dirac delta, t is the time of the spike, and a (typically 1) is the magnitude or payload of the spike.
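In a discrete-time simulation, such a train of Dirac impulses is usually represented as a mostly-zero vector on a time grid, with a nonzero entry of magnitude a at each spike time. A minimal sketch (variable names like dt and t_spikes are illustrative, not from any particular library):

```python
import numpy as np

dt = 1e-3                          # time step: 1 ms
t = np.arange(0.0, 0.1, dt)        # 100 ms of simulation time
s = np.zeros_like(t)               # the spike train: zero almost everywhere
t_spikes = [0.012, 0.047, 0.081]   # illustrative spike times, in seconds
a = 1.0                            # spike magnitude (payload)
s[np.searchsorted(t, t_spikes)] = a

# Only a handful of non-zero entries: this sparsity is what neuromorphic
# hardware exploits for efficiency.
print(f"{int((s != 0).sum())} spikes in {t.size} time steps")
```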
Now, what is exciting about these spikes? This sparse message-passing paradigm makes the brain very efficient. The human brain uses a mere 20 W. That's nothing compared to the kilowatts of power consumed by current-day AI systems. Yet, the brain is a far more versatile system. In neuromorphic computing, we try to take cues from the brain to build efficient yet versatile systems. Spike-based messaging is a key property here.
Spiking neurons and the spike function
Spiking neurons are formal models of biological neurons. They are the non-linear component of SNNs, similar to activation functions like ReLU, tanh, sigmoid (σ), etc., in standard Artificial Neural Networks (ANNs). They have the following key properties:
They have internal dynamics in their states u(t). The evolution of the state,

$$\dot{u}(t) = f\big(u(t), \text{input}(t)\big),$$

is the neuron dynamics. The temporal dynamics and statefulness are distinct features of a spiking neuron.
They respond with spikes. A spiking event occurs when a certain criterion is met; we call it the spike condition. Typically, the spike condition marks the point when the neuron's internal state exceeds a threshold, ϑ:

$$s(t) = f_s\big(u(t)\big) = \Theta\big(u(t) - \vartheta\big)$$

f_s(⋅) is the spike function that we are curious about.
Θ(⋅) is the Heaviside step function.
Neuron dynamics is, perhaps, a topic for a separate blog. Here we will focus on the spike function. Some typical examples of spike functions are:
Leaky integrate-and-fire (LIF) neuron: the neuron spikes when its voltage, u, exceeds a threshold, ϑ (see the sketch below).
Resonate-and-fire neuron: the neuron spikes when its complex state z = u + iv exceeds a threshold, ϑ, at the moment its phase is 0.
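To ground the two ingredients above (stateful dynamics plus a spike condition), here is a minimal discrete-time LIF neuron. The decay constant, threshold, and reset-to-zero behavior are illustrative choices for a sketch, not a reference implementation of any particular framework.

```python
import numpy as np

def lif_neuron(current, decay=0.9, threshold=1.0):
    """Minimal discrete-time leaky integrate-and-fire neuron (illustrative parameters)."""
    u = 0.0
    voltage, spikes = [], []
    for i_t in current:
        u = decay * u + i_t                 # leaky integration: the neuron dynamics
        s = 1.0 if u >= threshold else 0.0  # spike condition: Heaviside of (u - threshold)
        if s:
            u = 0.0                         # reset the state after a spike
        voltage.append(u)
        spikes.append(s)
    return np.array(voltage), np.array(spikes)

# Drive the neuron with a constant input current and observe sparse spiking.
v, s = lif_neuron(np.full(50, 0.3))
print("spike count:", int(s.sum()))
```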
2. Spiking neuron meets backpropagation
It is well known that error backpropagation is the key supervised learning algorithm that has enabled highly scalable training of modern deep ANNs and the explosion of progress in AI over the past decade. We would like to apply backpropagation to train networks of spiking neurons to achieve similar functional breakthroughs in a way that reaps the benefits of efficient sparse message passing in neuromorphic computing. Whether or not backpropagation exists in the brain, this training approach might help us take a step towards realizing the versatility of the natural brain with spiking neural networks.
The error backpropagation algorithm fundamentally involves calculating gradients in a network’s parameter space that point in the direction of lower task error. For a deep network, this means that the derivatives of neuron output activities must be computed layer by layer backwards, starting from the network’s output, using the chain rule. Each component in the chain needs to be differentiable, and the product of derivatives along the chain is the final gradient (change) of the network parameter to be updated.
So to apply backpropagation to spiking neural networks, we have to evaluate the derivative of the spike function, and we immediately face a problem.
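Concretely, for a weight w feeding a spiking neuron with membrane potential u(t) and spike output s(t), the gradient of the task error E decomposes schematically as (ignoring the temporal unrolling for clarity):

$$\frac{\partial E}{\partial w} = \frac{\partial E}{\partial s(t)} \cdot \underbrace{\frac{\partial s(t)}{\partial u(t)}}_{f_s'(\,\cdot\,)} \cdot \frac{\partial u(t)}{\partial w}$$

The first and last factors are well behaved; all the trouble sits in the middle term, the derivative of the spike function.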
3. Spike function f_s(⋅) and its derivative f_s′(⋅)
The key challenge we face is that the derivative of the step function is zero everywhere except at the origin, where it is infinite; it is the Dirac delta. Even worse, the derivative of the Dirac delta function is completely ill-defined – both positive and negative infinity at the same point!
Despite these issues, we know for a fact that the argument of the spike function, u(t), is continuous and differentiable and its output, s(t), is continuous. So perhaps we can go back to first principles and try to evaluate, or at least approximate, the spike function derivative.
The first attempts to backpropagate with spiking neurons [1] estimated the spike function derivative at the point of the spike event as

$$\left.\frac{\partial t}{\partial u}\right|_{t = t_{\text{spike}}} = \frac{-1}{\dot{u}(t_{\text{spike}})}$$

where u̇ denotes the time derivative of the membrane potential.
This formulation is in fact an exact derivative and a consequence of the implicit function theorem. The derivative is only defined at the time of the spike event and enables backpropagating errors only at those precise times.
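A quick way to see where it comes from (assuming the membrane potential crosses the threshold with a non-zero slope): the spike time is defined implicitly by the crossing condition u(t_spike) = ϑ. Perturbing the potential by a small δu near the crossing shifts the spike time by δt so that the condition still holds,

$$\dot{u}(t_{\text{spike}})\,\delta t + \delta u = 0 \quad\Longrightarrow\quad \frac{\partial t_{\text{spike}}}{\partial u} = \frac{-1}{\dot{u}(t_{\text{spike}})}.$$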
Similarly, a straight-through estimation of the spike function derivative was proposed in [2]:

$$\frac{\partial s}{\partial u} \approx 1$$
This formulation does enable gradient backpropagation even when the neuron does not spike. However, it does not distinguish between active membrane potentials that are close to spiking and those that are far from spiking.
When we treat the spiking mechanism in the discrete time domain, the issues with the spike function derivative become more manageable. In discrete time the Dirac delta is 1 (not undefined) when the condition is met. The spike function simplifies to

$$s[t] = f_s\big(u[t]\big) = \Theta\big(u[t] - \vartheta\big)$$

and its derivative is

$$f_s'\big(u[t]\big) = \frac{\partial s[t]}{\partial u[t]} = \delta\big(u[t] - \vartheta\big)$$
Note that the spike function is a point process, i.e., its output is dependent on u at that point in time only. The formulation of the spike function derivative above holds in the continuous time domain in a probabilistic sense.
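Before moving on, a tiny PyTorch experiment makes the problem tangible: expressed naively as a hard step, the spike function passes no useful gradient back to the membrane potential. The sign-and-clamp construction below is just one way to keep the operation inside the autograd graph for this demonstration; it is not how SNN libraries implement spiking.

```python
import torch

u = torch.linspace(-2.0, 2.0, 9, requires_grad=True)  # candidate membrane potentials
threshold = 1.0

# A hard spike function: 1 where u is above the threshold, 0 elsewhere.
# torch.sign keeps the op in the autograd graph, but its gradient is zero everywhere.
s = torch.clamp(torch.sign(u - threshold), min=0.0)

s.sum().backward()
print(s.detach())  # a couple of spikes ...
print(u.grad)      # ... yet the gradient w.r.t. u is zero everywhere: no learning signal
```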
Surrogate gradients
In non-spiking ANNs, the breakthrough for handling the step non-linearity of multilayer perceptrons was the use of the sigmoid function, which, in the limit of its steepness scaling, becomes the Heaviside step function.
We cannot change the spiking mechanism; however, we can relax the step function for the purposes of gradient calculation and use that as a proxy. This proxy substitution of the step function's gradient during backpropagation is the crux of the surrogate gradient approach. Surrogate gradient methods have proven to be an effective way to train SNNs in recent years [3][4].
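Here is one common way to wire this up in PyTorch: the forward pass keeps the exact Heaviside spike, while the backward pass substitutes a sigmoid-derivative surrogate. The class name and the constants THRESHOLD, TAU, and ALPHA are illustrative choices for this sketch, not the API of any particular SNN library.

```python
import torch

THRESHOLD = 1.0  # spike threshold (the ϑ above); illustrative value
TAU = 0.5        # relaxation factor: width of the surrogate; illustrative value
ALPHA = 1.0      # scale of the surrogate; illustrative value

class SpikeFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, u):
        ctx.save_for_backward(u)
        return (u >= THRESHOLD).float()  # exact Heaviside spike in the forward pass

    @staticmethod
    def backward(ctx, grad_output):
        u, = ctx.saved_tensors
        sig = torch.sigmoid((u - THRESHOLD) / TAU)
        surrogate = ALPHA * sig * (1.0 - sig) / TAU  # smooth stand-in for the ill-defined derivative
        return grad_output * surrogate

spike = SpikeFunction.apply

# The forward output is still binary, but gradients now flow through the surrogate.
u = torch.randn(8, requires_grad=True)
spike(u).sum().backward()
print(u.grad)
```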
The same kind of surrogate gradient proxy can be derived by looking at the expectation of the spike state change [5], which turns out to be the spike escape rate of a probabilistic spiking neuron.
With surrogate gradients, what we are really doing is peeking around the spiking event and trying to capture as much information as possible from the neighboring points. It is like looking at the ripples when a raindrop falls on a lake to try and figure out the surface beneath.
Nascent Delta functions ϕ(⋅)
Different SNN training works have proposed and used different types of surrogate gradient functions. What kind of functions are potential surrogate functions? What are the necessary conditions? In the limit, the surrogate gradient functions need to converge to a Dirac delta function. Such a family of functions is called nascent delta functions ϕ(⋅).
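Concretely, a family of functions ϕ_τ(⋅), parameterized by a relaxation factor τ > 0, qualifies if it integrates to one and sharpens towards the Dirac delta as τ shrinks:

$$\int_{-\infty}^{\infty} \phi_\tau(x)\,\mathrm{d}x = 1, \qquad \lim_{\tau \to 0} \phi_\tau(x) = \delta(x).$$

Two common examples (the normalizations are one convention among several) are the scaled sigmoid derivative and the box function,

$$\phi_\tau(x) = \frac{1}{\tau}\,\sigma\!\Big(\frac{x}{\tau}\Big)\Big(1 - \sigma\!\Big(\frac{x}{\tau}\Big)\Big) \quad\text{and}\quad \phi_\tau(x) = \frac{1}{2\tau}\,\mathbb{1}_{\{|x| \le \tau\}},$$

evaluated at x = u − ϑ, so that the surrogate gradient mass is concentrated around the threshold crossing.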
The relaxation factor (τ) controls the smoothness of the surrogate gradient. The appropriate degree of relaxation varies from case to case. An illustration of the effect of the relaxation parameter is shown below.
It is typically beneficial to consider a wider reach when the neuron activity is sparse, whereas when there is a burst of spiking activity, a sharper surrogate gradient formulation can be beneficial. A smaller τ gives a selective response, suitable for operating in the precise spike-time regime. A larger τ, on the other hand, has less discriminative power over spike times and, as a result, operates in the spike-rate regime.
The scale (α) of the surrogate gradient can be used to control the flow of the gradient to the input layer. The effect of different scales on surrogate gradient response is shown below.
The scale (α) controls the amount of gradient propagated to the previous layer. A small α results in a vanishing gradient, whereas a large α results in an exploding gradient. A suitable α needs to be chosen for healthy gradient flow to the previous layer. This is an extra knob that the surrogate gradient enables, although it needs to be treated with care.
The form (ϕ) of the surrogate gradient that works best is typically a matter of hyperparameter tuning and network initialization. Just as we can choose different ANN activation functions for a specific network architecture and task, in principle we can choose different surrogate gradient functions that suit the task.
Different surrogate gradients emphasize different properties. For example, the box function gives weight only to voltages above a threshold level, the d_sigmoid (derivative of sigmoid) function squashes high values close to the threshold, and the double exponential function emphasizes higher peaks while relaxing smaller values.
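To make these shapes concrete, here is a small sketch of three such surrogate profiles as functions of the distance of the membrane potential from the threshold. The exact normalizations, widths, and names are illustrative and may differ from the forms used in any given paper or figure.

```python
import numpy as np

def box(x, tau=0.5, alpha=1.0):
    """Flat window of width 2*tau around the threshold crossing (x = u - threshold)."""
    return alpha * (np.abs(x) <= tau) / (2.0 * tau)

def d_sigmoid(x, tau=0.5, alpha=1.0):
    """Derivative of a sigmoid: squashes values close to the threshold."""
    s = 1.0 / (1.0 + np.exp(-x / tau))
    return alpha * s * (1.0 - s) / tau

def double_exp(x, tau=0.5, alpha=1.0):
    """Double exponential (Laplace) shape: sharp peak at the threshold, long tails."""
    return alpha * np.exp(-np.abs(x) / tau) / (2.0 * tau)

x = np.linspace(-3.0, 3.0, 7)  # u - threshold
for phi in (box, d_sigmoid, double_exp):
    print(f"{phi.__name__:>10}: {np.round(phi(x), 3)}")
```

Sweeping tau makes each profile wider or sharper, and alpha rescales its height, which is exactly the pair of knobs discussed above.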
4. Key takeaways
In this blog, we looked at the spike function and the challenge it poses for training deep SNNs with backpropagation. I showed how approximating the spike function derivative with nascent delta functions tackles the problem. Surrogate gradients like these allow gradients to flow properly during backpropagation. The spike function, however, is not the only unique problem that needs to be tackled in training deep SNNs. We will delve into the other challenges in the next installment of this blog series.
[1] Bohte, S. M.; Kok, J. N. & La Poutré, J. A., SpikeProp: backpropagation for networks of spiking neurons. ESANN, 2000, 48, 17-37
[2] Lee, J. H.; Delbruck, T. & Pfeiffer, M., Training Deep Spiking Neural Networks Using Backpropagation. Frontiers in Neuroscience, 2016, 10, 508
[3] Neftci, E. O.; Mostafa, H. & Zenke, F., Surrogate gradient learning in spiking neural networks. IEEE Signal Processing Magazine, 2019, 36, 61-63
[4] Yin, B.; Corradi, F. & Bohté, S. M., Accurate and efficient time-domain classification with adaptive spiking recurrent neural networks. Nature Machine Intelligence, 2021, 905-913
[5] Shrestha, S. B. & Orchard, G., SLAYER: Spike Layer Error Reassignment in Time. Advances in Neural Information Processing Systems 31, Curran Associates, Inc., 2018, 1412-1421
This blog is a part of the blog series on Deep learning with Spiking Neural Networks.