The softmax bottleneck is a special case of a more general phenomenon

by Emin Orhan

One of my favorite papers this year so far has been this ICLR oral paper by Zhilin Yang, Zihang Dai and their colleagues at CMU. The paper is titled “Breaking the softmax bottleneck: a high-rank RNN language model” and uncovers an important deficiency in neural language models. These models typically use a softmax layer at the top of the network, mapping a relatively low dimensional embedding of the current context to a relatively high dimensional probability distribution over the vocabulary (the distribution represents the probability of each possible word in the vocabulary given the current context). The relatively low dimensional nature of the embedding space causes a potential degeneracy in the model. Mathematically, we can express this degeneracy problem as follows:

\mathbf{H} \mathbf{W}^\top = \mathbf{A}

Here, \mathbf{A} is an N \times M matrix containing the log-probability of each word in the vocabulary given each context in the dataset: \mathbf{A}_{ij} = \log p(x_j|c_i), where N is the number of all distinct contexts in the dataset and M is the vocabulary size; \mathbf{H} is an N \times d matrix containing the embedding space representations of all distinct contexts in the dataset, where d is the dimensionality of the embedding space; and \mathbf{W} is an M\times d matrix containing the softmax weights.

In typical applications, d \sim O(10^{2-3}) and M \sim O(10^5), so d is a few orders of magnitude smaller than M. This means that while the right-hand side of the above equation can be full-rank, the left-hand side is rank-deficient: i.e. \mathrm{rank}(\mathbf{H} \mathbf{W}^\top) \leq d \ll M = \mathrm{rank}(\mathbf{A}) (we’re assuming that N is larger than M and d, which is usually the case). As a result, the model is not expressive enough to capture \mathbf{A}.
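As a quick sanity check on the rank argument, here is a minimal NumPy sketch (with toy sizes standing in for N, M and d; the variable names are illustrative, not taken from any particular implementation) showing that the product \mathbf{H}\mathbf{W}^\top can never exceed rank d:

```python
import numpy as np

# Toy sizes standing in for the quantities above: N contexts, M words, d-dimensional embeddings
N, M, d = 2000, 1000, 100
rng = np.random.default_rng(0)

H = rng.standard_normal((N, d))  # context embeddings (N x d)
W = rng.standard_normal((M, d))  # softmax weights (M x d)

logits = H @ W.T                 # the model's unnormalized log-probabilities (N x M)
print(np.linalg.matrix_rank(logits))  # prints 100 (= d), far below M
```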

I think it is actually more appropriate to frame the preceding discussion in terms of the distribution of singular values rather than the ranks of matrices. \mathbf{A} can be full-rank and yet have a large number of small singular values, in which case the softmax bottleneck would presumably not be a serious problem: there would be a good low-rank approximation to \mathbf{A}. So the real question is: what proportion of the singular values of \mathbf{A} is near zero, or small? Similarly, the proportion of near-zero singular values of the left-hand side, i.e. the degeneracy of the model itself, is equally important. One might argue that as long as the model’s singular values are technically non-zero, it shouldn’t matter how many of them are near zero, because training can simply grow those small singular values as needed. But this is not really the case: from an optimization perspective, near-zero singular values are about as bad as exactly zero singular values, since both restrict the effective capacity of the model (see this previous post for a discussion of this point).
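To make the “good low-rank approximation” point concrete, recall the Eckart–Young theorem: the best rank-d approximation error of \mathbf{A} (in the Frobenius norm) is determined entirely by the singular values that get discarded,

\min_{\mathrm{rank}(\mathbf{B}) \leq d} \| \mathbf{A} - \mathbf{B} \|_F = \sqrt{\sum_{i > d} \sigma_i^2}

where \sigma_1 \geq \sigma_2 \geq \ldots are the singular values of \mathbf{A}. This ignores the softmax normalization and the actual training objective, but it makes the intuition precise: if the tail of the spectrum is small, a d-dimensional model loses very little, regardless of the nominal rank of \mathbf{A}.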

Framed in this way, we can see that the softmax bottleneck is really a special case of a more general degeneracy problem that can arise even when we’re not mapping from a relatively low dimensional space to a relatively high dimensional space, and even when the nonlinearity is not a softmax. I will now illustrate this point with a simple example: a fully-connected (dense) layer with relu units.

In this example, we assume that the input and output spaces have the same dimensionality (i.e. d=M=128, using the notation above), so there is no “bottleneck” due to a dimensionality difference between the input and output spaces. Mathematically, the layer is described by the equation \mathbf{y} = f(\mathbf{W}\mathbf{x}), where f(\cdot) is relu and we ignore the biases for simplicity. The weights, \mathbf{W}, are drawn according to the standard Glorot uniform initialization scheme. We assume that the inputs \mathbf{x} are standard normal variables and we calculate the average singular values of the Jacobian of the layer, \partial \mathbf{y}/\partial \mathbf{x}, over a number of random inputs. The result is shown by the blue line (labeled “Dense”) below:

[Figure: average singular values of the Jacobian for the dense relu layer (“Dense”) and for mixture-of-experts layers with different numbers of experts]

The singular values decrease roughly linearly up to the middle of the spectrum, but drop sharply to zero after that. This is because, for any given input, approximately half of the relu units are inactive, so the corresponding rows of the Jacobian are exactly zero. Of course, the degeneracy here is not as dramatic as in the softmax bottleneck case, where more than 99\% of the singular values would have been degenerate (as opposed to roughly half in this example), but this is just a single layer, and degeneracy can increase sharply in deeper models (again see this previous post for a discussion of this point).
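For reference, here is a minimal NumPy sketch of the computation behind the “Dense” curve, assuming the setup described above (Glorot uniform weights, standard normal inputs); the variable names and the number of sampled inputs are my own choices for illustration:

```python
import numpy as np

d = 128            # input and output dimensionality
n_inputs = 100     # number of random inputs to average over
rng = np.random.default_rng(0)

# Glorot (Xavier) uniform initialization: limit = sqrt(6 / (fan_in + fan_out))
limit = np.sqrt(6.0 / (d + d))
W = rng.uniform(-limit, limit, size=(d, d))

spectra = np.zeros((n_inputs, d))
for i in range(n_inputs):
    x = rng.standard_normal(d)              # standard normal input
    active = (W @ x > 0).astype(float)      # which relu units are active for this input
    J = active[:, None] * W                 # Jacobian of y = relu(W x): rows of W, zeroed where inactive
    spectra[i] = np.linalg.svd(J, compute_uv=False)

mean_spectrum = spectra.mean(axis=0)        # average singular value spectrum (sorted, largest first)
print((spectra < 1e-8).mean())              # fraction of exactly-zero singular values per input: about 0.5
```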

The remaining lines in this figure show the average singular values of the Jacobian for a mixture-of-experts layer that’s directly inspired by, and closely related to, the mixture-of-softmaxes model proposed by the authors to deal with the softmax bottleneck problem. This mixture-of-experts layer is defined mathematically as follows:

\mathbf{y} = \sum_{k=1}^K g(\mathbf{v}_k^\top \mathbf{x})f(\mathbf{W}_k\mathbf{x})

where K denotes the number of experts, g(\mathbf{v}_k^\top \mathbf{x}) represents the gating model for the k-th expert and f(\mathbf{W}_k\mathbf{x}) is the k-th expert model (a similar mixture-of-experts layer was recently used in this paper from Google Brain). The mixture-of-softmaxes model proposed in the paper roughly corresponds to the case where both f(\cdot) and g(\cdot) are softmax functions (with the additional difference that in their model the input \mathbf{x} is first passed through a linear map followed by a tanh nonlinearity).
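Here is a minimal NumPy sketch of such a mixture-of-experts layer, assuming a softmax gate over the K experts and relu experts (one possible choice of g and f, not necessarily the one used for the figure), together with a finite-difference Jacobian whose spectrum can be compared against the dense layer above; names and sizes are illustrative:

```python
import numpy as np

d, K = 128, 4                    # feature dimensionality and number of experts
rng = np.random.default_rng(0)
limit = np.sqrt(6.0 / (d + d))
Ws = rng.uniform(-limit, limit, size=(K, d, d))  # expert weights W_k
Vs = rng.uniform(-limit, limit, size=(K, d))     # gate weights v_k

def moe(x):
    """y = sum_k softmax_k(v_k . x) * relu(W_k x)."""
    logits = Vs @ x
    gates = np.exp(logits - logits.max())
    gates /= gates.sum()                         # softmax over the K experts
    experts = np.maximum(Ws @ x, 0.0)            # (K, d) relu expert outputs
    return gates @ experts

def jacobian(fn, x, eps=1e-5):
    """Finite-difference Jacobian of fn at x."""
    y0 = fn(x)
    J = np.zeros((y0.size, x.size))
    for j in range(x.size):
        xp = x.copy()
        xp[j] += eps
        J[:, j] = (fn(xp) - y0) / eps
    return J

x = rng.standard_normal(d)
svals = np.linalg.svd(jacobian(moe, x), compute_uv=False)
print((svals < 1e-3 * svals[0]).mean())  # fraction of near-zero singular values; compare with ~0.5 for the single relu layer
```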

The figure above shows that this mixture-of-experts model effectively deals with the degeneracy problem (just like the mixture-of-softmaxes model effectively deals with the softmax bottleneck problem). Intuitively, this is because when we add a number of matrices that are each individually degenerate, the resulting matrix is less likely to be degenerate (assuming, of course, that the degeneracies of the different matrices are not “correlated” in some sense, e.g. that their null spaces do not substantially overlap). Consistent with this intuition, we see in the above figure that using more experts (larger K) makes the model better conditioned.

However, it should be emphasized that the mixture-of-experts layer (hence the mixture-of-softmaxes layer) likely has additional benefits other than just breaking the degeneracy in the model. This can be seen by observing that setting the gates to be constant, e.g. g(\cdot)= 1 / K, already effectively breaks the degeneracy:

[Figure: average singular values of the Jacobian for the mixture-of-experts layer with constant gates g(\cdot) = 1/K]
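For reference, this constant-gate variant corresponds to a one-line change in the mixture-of-experts sketch above (reusing the Ws and np definitions from that sketch):

```python
def moe_constant_gate(x):
    """Same as moe(x) above, but with constant, input-independent gates g = 1/K."""
    experts = np.maximum(Ws @ x, 0.0)   # (K, d) relu expert outputs
    return experts.mean(axis=0)         # equivalent to sum_k (1/K) * relu(W_k x)
```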

I haven’t yet run detailed benchmarks, but it seems highly unlikely to me that this version with constant gates would perform as well as the full mixture-of-experts layer with input-dependent gating. This suggests that, in addition to breaking the degeneracy, the mixture-of-experts layer implements other useful inductive biases (e.g. input-dependent, differential weighting of the experts in different parts of the input space), so the success of this layer cannot be explained entirely in terms of degeneracy breaking. The same comment applies to the mixture-of-softmaxes model too.

Finally, I would like to point out that the recently introduced non-local neural networks can also be seen as a special case of the mixture-of-experts architecture. In that paper, a non-local layer was defined as follows:

\mathbf{y}_i = \frac{1}{C(\mathbf{x})} \sum_{j} g(\mathbf{x}_i,\mathbf{x}_j) f(\mathbf{x}_j)

with input-dependent gating function g(\cdot,\cdot) and expert model f(\cdot). The canonical applications of this model are image-based, so the layers are actually two-dimensional, with a position dimension (represented by the indices i, j) and a feature dimension (implicitly represented by the vectors \mathbf{x}_j, \mathbf{y}_i, etc.). The inductive biases implemented by this model are therefore not exactly the same as in the flat dense case above, but the analogy suggests that part of its success may be explained by degeneracy breaking as well.
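To make the analogy concrete, here is a flat (one position dimension, one feature dimension) NumPy sketch of such a non-local layer, assuming an embedded-Gaussian-style gating g(\mathbf{x}_i, \mathbf{x}_j) = \exp(\theta(\mathbf{x}_i)^\top \phi(\mathbf{x}_j)) and a linear expert f; all names, sizes and initializations here are illustrative:

```python
import numpy as np

T, c = 50, 64                     # number of positions and number of channels (toy sizes)
rng = np.random.default_rng(0)
X = rng.standard_normal((T, c))   # one input vector x_j per position

# Linear maps for the gating (theta, phi) and for the expert (W_f); illustrative initialization
theta = rng.standard_normal((c, c)) / np.sqrt(c)
phi = rng.standard_normal((c, c)) / np.sqrt(c)
W_f = rng.standard_normal((c, c)) / np.sqrt(c)

scores = (X @ theta) @ (X @ phi).T                  # (T, T) pairwise gate logits
gates = np.exp(scores - scores.max(axis=1, keepdims=True))
gates /= gates.sum(axis=1, keepdims=True)           # the 1/C(x) normalization, per position i

Y = gates @ (X @ W_f)                               # y_i = (1/C(x)) * sum_j g(x_i, x_j) f(x_j)
```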

Note: I wrote a fairly general Keras layer implementing the mixture-of-experts layer discussed in this post. It is available here. Preliminary tests on small-scale problems suggest that this layer works much better than a generic dense layer, so I am making the code available as a Keras layer in the hope that other people will be encouraged to explore its benefits (and its potential weaknesses). A convolutional version of the mixture-of-experts layer has also been implemented and uploaded to the GitHub repository, together with some examples illustrating how to use both the dense and the convolutional mixture-of-experts layers.
