Severely Theoretical

Computational neuroscience and machine learning

January 2018

Optimal synaptic memory consolidation

One of the fundamental challenges facing the brain is a trade-off between what we might loosely call learning and memory. On the one hand, we want our synapses to be highly plastic so that we can learn from our experiences quickly. On the other hand, we want our memories to be stable so that they are not easily overwritten by the constant barrage of experiences that we face every day, and this stability seems to require rigid synapses. So, what is the optimal way to strike the balance between these two conflicting requirements?

Building on a series of previous papers, Benna & Fusi (2016) address this question in the context of a highly simplified model of synaptic modifications. In this model, the current value of a synapse is assumed to be a linear superposition of all the past modifications made to that synapse, weighted by a function of the time since that modification:

w_a(t) = \sum_{t^\prime} \Delta w_a (t^\prime) r(t-t^\prime)

Here a indexes the synapse, \Delta w_a (t^\prime) is the synaptic modification made at time t^\prime and r(t-t^\prime) is the weighting function that determines the effect of a modification made at time t^\prime on the current value of the synapse. Another important assumption is that the modifications are unit-magnitude, equal-probability depression or potentiation events, i.e. \Delta w_a(t^\prime)=\pm 1, that arrive at the synapse at a constant rate and are uncorrelated across time. It is assumed that there are N such synapses in total, uncorrelated with each other.
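To make the model concrete, here is a minimal sketch of this superposition in numpy (the particular choice of r and the unit time step are mine, purely for illustration):

import numpy as np

rng = np.random.default_rng(0)
T = 1000                                 # one candidate modification per time step
dw = rng.choice([-1.0, 1.0], size=T)     # unit-magnitude, equal-probability events

def r(dt):
    return 1.0 / np.sqrt(dt + 1.0)       # a candidate weighting function with r(0) = 1

# current value of the synapse: weighted superposition of all past modifications
w_now = np.sum(dw * r((T - 1) - np.arange(T)))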

To formalize the plasticity-rigidity trade-off, they define a signal-to-noise ratio that quantifies how well we can decode a past synaptic modification event from the current values of the synapses. The signal is defined as the mean overlap between the current values of the synapses and the past synaptic modification pattern:

\mathcal{S}_{t^\prime}(t) \equiv \frac{1}{N}  \langle \sum_{a=1}^N w_a(t) \Delta w_a(t^\prime)  \rangle

and the noise is defined as the standard deviation of this overlap:

\mathcal{N}_{t^\prime}(t) \equiv \sqrt{ \frac{1}{N^2} \langle (\sum_{a=1}^N w_a(t) \Delta w_a(t^\prime) )^2 \rangle - \mathcal{S}^{2}_{t^\prime}(t) }

where the angle brackets denote averages over the stochastic modifications. With the assumptions made, it is then straightforward to derive that the signal-to-noise ratio for a memory induced at time t^\prime is proportional to:

SNR_{t^\prime} (t) \propto r(t-t^\prime) / \sqrt{\sum_{l: t_l<t} r(t-t_l)^2}

Heuristically, we can see from this equation that the slowest decaying r(t) we can afford is r(t) \approx 1 / \sqrt{t}: plugging this into the sum in the denominator gives the harmonic series, which diverges only logarithmically, whereas any more slowly decaying r(t) would make the noise grow polynomially in t. One can make this argument more formal by writing down an objective (e.g. the area under the SNR(t) curve above an arbitrary threshold), optimizing it with respect to r(t) and finding that r(t) \approx 1 / \sqrt{t} indeed gives the correct answer, but I won’t do that here. This is the first main result of the paper (and to me the more important one). One can show that a system displaying the 1/\sqrt{t} decay can achieve an almost extensive N / \log N memory capacity (capacity is defined as the time at which SNR(t) drops to an arbitrary fixed value, which they take to be 1 in the paper).
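As a quick numerical sanity check (my own sketch, not the paper’s code), one can plug r(t) = 1/\sqrt{t+1} directly into the SNR expression above, including the 1/\sqrt{N} factor from averaging over synapses, and watch how it decays with the age of the memory:

import numpy as np

N, T = 10_000, 100_000                      # number of synapses, number of past events

def r(dt):
    return 1.0 / np.sqrt(dt + 1.0)

kernel = r(np.arange(T, dtype=float))       # r(t - t_l) over all past events t_l
noise = np.sqrt(np.sum(kernel ** 2) / N)    # harmonic sum: grows only like sqrt(log T / N)

for age in [1, 10, 100, 1000, 10_000]:
    print(age, r(age) / noise)              # SNR decays roughly like 1/sqrt(age)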

The rest of the paper is devoted to coming up with a hand-crafted, tractable dynamical system (which can be interpreted as describing the internal dynamics of a complex synapse model) that would display the desired 1/\sqrt{t} decay. The solution they come up with is based on the heat equation:

\frac{\partial u}{\partial t} = D \frac{\partial^2 u}{\partial x^2}

where D \propto g/C is the diffusion coefficient (C is known as the heat capacity and g as the conductivity). Recall that the solution to the heat equation at x=0 already displays the required 1/\sqrt{t} decay. However, this naive solution is not good enough, because any discretized version of it would require too many variables (\sim \sqrt{N} variables) to achieve the 1/\sqrt{t} decay over a sufficiently long time. The “patch” they offer for this problem is to use an inhomogeneous heat equation with an exponentially increasing heat capacity, C(x) = \exp(\beta x), and an exponentially decreasing conductivity, g(x) = \exp(-\beta x). This has the effect of slowing down the diffusion along the x direction and requires only \sim \log N variables to implement the desired 1/\sqrt{t} decay in a discretized version.
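Here is a minimal sketch of what such a discretized chain can look like (the constants, the chain length and the boundary leak below are illustrative choices of mine, not the paper’s exact parameters): a handful of variables coupled by exponentially decreasing conductances and damped by exponentially increasing capacities, with the first variable playing the role of the synaptic weight.

import numpy as np

m, beta, dt = 8, np.log(2.0), 0.01
C = np.exp(beta * np.arange(m))        # exponentially increasing "heat capacities"
g = np.exp(-beta * np.arange(m))       # exponentially decreasing "conductances"

u = np.zeros(m)
u[0] = 1.0                             # a single potentiation event at t = 0

trace = []
for step in range(100_000):
    flux = g[:-1] * (u[:-1] - u[1:])   # current flowing from variable k to k+1
    du = np.zeros(m)
    du[:-1] -= flux / C[:-1]
    du[1:] += flux / C[1:]
    du[-1] -= g[-1] * u[-1] / C[-1]    # leak at the far end of the chain (illustrative)
    u += dt * du
    trace.append(u[0])                 # u[0] relaxes roughly like 1/sqrt(t) over a wide range of t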

Although I find this model informative for elucidating the kinds of mechanisms required for achieving extensive memory capacity in an efficient way under the studied scenario (e.g. multiple variables with exponentially differing time-scales interacting with each other), I have two basic issues with the paper. First, the assumptions and simplifications they make to render the problem tractable also make it somewhat uninteresting from a practical viewpoint: uncorrelated events arriving at uncorrelated synapses is not a very practically relevant scenario. In the studied scenario, there’s also no accounting of the importance of a synapse for prior experiences (cf. Kirkpatrick et al., 2017). Incidentally, Kirkpatrick et al. (2017) show that any synapse that optimizes their elastic weight consolidation objective (which combines the loss in the current episode and deviation from the current value of the synapse weighted by the importance of that synapse for prior episodes) automatically satisfies the optimal 1/\sqrt{t} decay under a scenario similar to that studied in Benna & Fusi (2016). This result suggests that this decay might naturally fall out of other objectives that the synapse would have to optimize.
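For reference, the elastic weight consolidation objective mentioned above is simple enough to state in a few lines; here is a schematic version in PyTorch (the function and variable names are mine, and lam is an arbitrary trade-off coefficient):

import torch

def ewc_loss(current_loss, params, old_params, importances, lam):
    # current_loss: task loss on the current episode
    # importances:  per-parameter importance estimates for prior episodes
    #               (Kirkpatrick et al. use the diagonal of the Fisher information)
    penalty = sum((F * (p - p_old) ** 2).sum()
                  for p, p_old, F in zip(params, old_params, importances))
    return current_loss + 0.5 * lam * penalty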

Secondly, the desire to come up with a tractable model also restricts the authors to a linear model (i.e. the heat equation). But there is no reason synapses should be linear, and nonlinear models might, in fact, perform better. There’s not much room for improvement over the linear model proposed in the paper, since both the 1/\sqrt{t} decay and the nearly extensive memory capacity are optimal under the studied scenario, but perhaps the \log N discrete variables required in the linear model, or the robustness of the model, could be improved upon with a nonlinear model. So, an alternative approach would be to model the synapse as a nonlinear dynamical system (e.g. an RNN) and optimize its parameters (one can think of more sophisticated variations on this idea to learn more interpretable solutions). I find this approach very useful, because (i) the hand-crafted solutions we humans come up with tend to be highly special (e.g. an attractor with all eigenvalues equal to 1), whereas nature prefers more generic solutions, and (ii) with a hand-crafted model, we can rarely beat or even match the performance of an optimized nonlinear model even in relatively simple problems, so such optimized models are useful for delineating the contours of what is achievable and for giving us valuable hints about more general principles.


Why is it hard to train deep neural networks? Degeneracy, not vanishing gradients, is the key

In this post, I will try to address a common misunderstanding about the difficulty of training deep neural networks. It seems to be a widely held belief that this difficulty is mainly, if not completely, due to the vanishing (and/or exploding) gradients problem. “Vanishing gradients” refers to the gradient norms becoming exponentially smaller for parameters deeper in the network. Smaller gradients mean the parameters change ever so slowly, so learning gets stuck until the gradients become large enough, which could take exponentially long. This idea goes back, at least, to Bengio et al. (1994) and still seems to be everybody’s favorite explanation for why it is hard to train deep neural networks.

Let’s first consider a simple scenario: a deep linear network being trained to learn a linear mapping. Of course, deep linear networks aren’t interesting from a computational perspective, but Saxe et al. (2013) showed that learning dynamics in such networks can still be informative about the learning dynamics in nonlinear networks. So, let’s start with this simple scenario. Here’s the learning curve and the initial gradient norms (before any training) for a 30-layer network (error bars are standard errors over 10 independent runs).

[Figure: plain_30_fold — learning curve and initial gradient norms for the 30-layer linear network (“Fold 0”)]

I will shortly explain why these results are labeled “Fold 0” in the figure. The gradients here are with respect to layer activations (gradients with respect to parameters behave similarly). The network weights are initialized with the standard 1/\sqrt{n} initialization. The training loss decreases rapidly at first, but then quickly asymptotes at a suboptimal value. The gradients certainly don’t vanish (or explode), at least initially! They do become smaller as training progresses, but that is to be expected and it’s not clear at all that the gradients here are “too small” in any sense:

[Figure: fold0_gradnorms — gradient norms over the course of training for the “Fold 0” networks]
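For concreteness, here is a rough sketch in PyTorch of the kind of experiment described above (the batch size, learning rate, optimizer and target mapping are my assumptions; the post only specifies 30 layers of 64 \times 64 weight matrices with the standard 1/\sqrt{n} initialization):

import torch
import torch.nn as nn

depth, width = 30, 64
layers = [nn.Linear(width, width, bias=False) for _ in range(depth)]
net = nn.Sequential(*layers)
for layer in layers:
    nn.init.normal_(layer.weight, std=1.0 / width ** 0.5)   # standard 1/sqrt(n) initialization

target = torch.randn(width, width) / width ** 0.5           # a random linear mapping to learn
opt = torch.optim.SGD(net.parameters(), lr=1e-3)

for step in range(1000):
    x = torch.randn(256, width)
    loss = ((net(x) - x @ target.t()) ** 2).mean()           # squared error on the linear task
    opt.zero_grad()
    loss.backward()
    opt.step()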

To show that convergence to a suboptimal solution here doesn’t have anything to do with the size of the gradient norms per se, I will now introduce a manipulation that will increase the gradient norms, but will worsen the performance. Here it is (in blue):

[Figure: plain_30_fold1 — learning curves for the “Fold 1” networks (blue) compared with “Fold 0”]

So, what did I do? I simply changed the initialization in a pretty minimal way. Each initial weight matrix in the original networks is a 64 \times 64 matrix (initialized with the standard 1/\sqrt{n} initialization). In the networks shown in blue here, I just copied the first half of each initial weight matrix to the second half (the initial weight matrix is “folded” once, so that’s why I denote these as “Fold 1” networks). This reduces the rank of the initial weight matrices, and makes them more degenerate. Note that this manipulation is done only on the initial weight matrices, so no other constraint is imposed on learning once the training gets started. Here’s how the gradient norms look after the first few epochs:

[Figure: fold1_gradnorms — gradient norms after the first few epochs for the “Fold 1” networks]
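In code, the folding manipulation might look something like this (a sketch; the post doesn’t say whether rows or columns are copied, so copying rows here is an assumption):

import torch

def fold_once(w):
    # copy the first half of the rows into the second half,
    # so the rank of the matrix drops to at most half the width
    w = w.clone()
    half = w.shape[0] // 2
    w[half:, :] = w[:half, :]
    return w

w0 = torch.randn(64, 64) / 64 ** 0.5        # a standard 1/sqrt(n) initial matrix
w1 = fold_once(w0)                          # a "Fold 1" initial matrix
print(torch.linalg.matrix_rank(w1))         # 32, vs. 64 for w0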

So, I introduced a manipulation that increased the size of the gradient norms overall, yet the performance got significantly worse. Conversely, I will now introduce another manipulation that will shrink the size of the gradient norms, yet will improve performance substantially. Here it is (in green):

[Figure: plain_30_ortho — learning curves for the orthogonally initialized networks (green)]

As you may have guessed from the label, this new manipulation initializes the weight matrices to be orthogonal. Let’s recall that orthogonal matrices are the least degenerate matrices among matrices of a fixed (Frobenius) norm, where degeneracy can be measured in different ways, for example, as the fraction of singular values smaller than a given constant. Here’s how the gradients look after the first few epochs in this case:

[Figure: ortho_gradnorms — gradient norms after the first few epochs for the orthogonally initialized networks]
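The degeneracy measure mentioned above is easy to compute directly; here is a small sketch comparing a 1/\sqrt{n}-initialized matrix with an orthogonal one (the 0.1 threshold is an arbitrary choice of mine):

import torch
import torch.nn as nn

w_std = torch.randn(64, 64) / 64 ** 0.5      # standard 1/sqrt(n) initialization
w_orth = torch.empty(64, 64)
nn.init.orthogonal_(w_orth)                  # all singular values equal to 1

def frac_small_svals(w, eps=0.1):
    # fraction of singular values below eps: one crude measure of degeneracy
    s = torch.linalg.svdvals(w)
    return (s < eps).float().mean().item()

print(frac_small_svals(w_std), frac_small_svals(w_orth))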

So, what’s going on? If the size of the gradient norms per se isn’t responsible for training difficulties, then what is? The answer is that it’s the degeneracy of the model that, by and large, determines the training performance. Why is degeneracy harmful for training? Intuitively, learning slows down substantially along degenerate directions in the parameter space, hence degeneracies reduce the effective dimensionality of the model. So, you might think that you’re fitting a model with \sim 130\mathrm{K} parameters, as in the above examples, but in reality you’re effectively fitting a model with substantially fewer degrees of freedom because of the degeneracies. The problem with the “Fold 0” and “Fold 1” networks above was that although the gradient norms were fine, the networks’ available degrees of freedom contributed extremely unevenly to those norms: a handful of degrees of freedom (the non-degenerate ones) accounted for almost all of the gradient norm, while the vast majority (the degenerate ones) contributed almost nothing. A conceptually helpful, though not strictly accurate, way of thinking about this is to imagine only a few of the hidden units in each layer changing their activity in response to different inputs, while the vast majority respond the same way regardless of the input.

This reduction in the effective degrees of freedom can be quite substantial, as in the 1/\sqrt{n}-initialized random matrices above. As shown by Saxe et al. (2013), the product of such matrices becomes increasingly degenerate as the number of matrices multiplied (i.e. network depth) increases. Here’s an example with 1-layer, 10-layer and 100-layer networks respectively (adapted from Saxe et al. (2013)):

[Figure: saxe_plot — distributions of the singular values of the product matrix for 1-layer, 10-layer and 100-layer linear networks, adapted from Saxe et al. (2013)]

As depth increases, the singular values of the product matrix become increasingly concentrated around 0, except for a vanishingly small fraction of singular values that become arbitrarily large relative to the rest. This result isn’t just relevant for linear networks. A similar thing happens in nonlinear networks too: as depth increases, the responses of the hidden units in a given layer become increasingly lower-dimensional, i.e. increasingly degenerate. In fact, this “degeneration process” can proceed much more quickly with depth in nonlinear networks with hard-saturating nonlinearities, as in ReLU networks.
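A quick numerical check of this concentration effect (a sketch in numpy; the width of 64 and the 1/\sqrt{n} scaling follow the setup above):

import numpy as np

rng = np.random.default_rng(0)
n = 64
for depth in [1, 10, 100]:
    prod = np.eye(n)
    for _ in range(depth):
        prod = (rng.standard_normal((n, n)) / np.sqrt(n)) @ prod
    s = np.linalg.svd(prod, compute_uv=False)
    # with depth, most singular values collapse toward zero,
    # leaving only a handful that remain comparatively large
    print(depth, float(s.max()), float(np.median(s)))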

A nice visualization of this degeneration process is presented in a paper by Duvenaud et al. (2014):

[Figure: drawing-1 — visualization from Duvenaud et al. (2014) of a two-dimensional input space being distorted through successive layers]

As depth increases, the input space (shown on the top left) gets distorted into increasingly thin filaments, with only a single direction (orthogonal to the filament) at each point in the input space affecting the network’s response. It may be a bit hard to extrapolate from this figure how input spaces of more than two dimensions behave under a similar mapping, but it turns out they become “hyper-pancakey” locally: at each point there is a single direction, orthogonal to the pancake surface, that the network is sensitive to. Along that sensitive direction, the network in fact becomes hyper-sensitive to variations.

Finally, I can’t resist mentioning my own paper (with Xaq Pitkow) at this point. In this paper, through a series of experiments, we argue that the degeneracy problem I discussed in this post severely afflicts training in deep nonlinear networks, and that one of the ways (and quite possibly the single most important way) in which skip connections help training in deep networks is through breaking such degeneracies. I suspect that other methods like batch normalization or layer normalization that also help training in deep networks also work at least partly through a similar degeneracy-breaking mechanism, in addition to any other potentially independent mechanisms such as reducing the internal covariate shift as originally proposed. It is well-known, for example, that divisive normalization is an effective way of decorrelating the responses of hidden units, which in turn can be seen as a degeneracy-breaking mechanism.

Update (1/5/18): I should also mention this important recent paper by Pennington, Schoenholz & Ganguli. Orthogonal initialization completely eliminates degeneracies in linear networks, but not in nonlinear networks. In this paper, they provide a method to calculate the entire singular value distribution of the Jacobian of a nonlinear network and show that a depth-independent, non-degenerate singular value distribution can be achieved, with careful initialization, in networks with hard-tanh nonlinearities, but not in ReLU networks. The empirical results show that networks with depth-independent, non-degenerate singular value distributions train orders of magnitude faster than networks whose singular value distributions become wider (higher variance) and more degenerate with depth. This is a powerful demonstration of the importance of eliminating degeneracies and controlling the entire singular value distribution of a network, not just its mean.
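For readers who want to poke at this themselves, here is a rough sketch of the kind of quantity involved: the singular values of the network’s input-output Jacobian for deep orthogonally initialized hard-tanh versus ReLU networks (this is only a crude illustration; Pennington et al. also tune the weight and bias variances carefully, which I don’t do here):

import torch
import torch.nn as nn

def jacobian_svals(act, gain, depth=50, width=64):
    layers = []
    for _ in range(depth):
        lin = nn.Linear(width, width, bias=False)
        nn.init.orthogonal_(lin.weight, gain=gain)
        layers += [lin, act]
    net = nn.Sequential(*layers)
    x = 0.1 * torch.randn(width)      # small input scale keeps hard-tanh units mostly unsaturated
    J = torch.autograd.functional.jacobian(net, x)   # input-output Jacobian at x
    return torch.linalg.svdvals(J)

s_tanh = jacobian_svals(nn.Hardtanh(), gain=1.0)
s_relu = jacobian_svals(nn.ReLU(), gain=2 ** 0.5)    # sqrt(2) gain keeps the mean square near 1
print(s_tanh.max(), s_tanh.median())                 # close to 1: near-isometric
print(s_relu.max(), s_relu.median())                 # much more spread out, i.e. more degenerate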