Compact precise definition of a transformer function

Although I’ve been repeatedly advised it’s not a good social strategy, a glorious way to start a research paper is with specific, righteous criticism of your anonymous colleagues:[1]

Transformers are deep feed-forward artificial neural networks with a (self)attention mechanism. They have been tremendously successful in natural language processing tasks and other domains. Since their inception 5 years ago, many variants have been suggested. Descriptions are usually graphical, verbal, partial, or incremental. Despite their popularity, it seems no pseudocode has ever been published for any variant. Contrast this to other fields of computer science, even to “cousin” discipline reinforcement learning.

So begin Phuong & Hutter in a great, rant-filled paper that “covers what Transformers are, how they are trained, what they’re used for, their key architectural components, tokenization, and a preview of practical considerations, and the most prominent models.” As an exercise, in this post I dig into the first item by writing down an even more compact definition of a transformer than theirs, in the form of a mathematical function rather than pseudocode, while avoiding the ambiguities rampant in the rest of the literature. I will consider only what a single forward pass of a transformer does, viewed as a map from token sequences to probability distributions over the token vocabulary. I do not try to explain the transformer, nor do I address other important aspects like motivation, training, and computational efficiency.

(This post also draws on a nice introduction by Turner. If you are interested in understanding and interpretation, you might check out — in descending order of sophistication — Elhage et al., Molina, and Lee & Trott. Corrections to this post are welcome. EDIT: See also the recent pseudocode from Carpenter.)

As detailed well by Phuong & Hutter, the basic transformer structure can be used in several overall “configurations” (e.g., sequence classification, or sequence-to-sequence prediction for translation). We will present transformers as used for next-token prediction. There are many other smaller ways of varying the architecture (see a sample at the end of this post), so I have tried to pick a “standard” version, but I don’t have enough experience in the literature or industry to know confidently which are standard/promising vs speculative/unpromising.

Preliminaries

First, our basic notation: The matrix transpose is denoted “{}^\intercal” and the positive integers are \mathbb{Z}_{+} = \{1,2,3,\dots\}. The set of positive integers up to N is [N]=\{1,2,\dots,N\} and we use \mathbf{z} = [\mathbf{z}_n]_{n=1}^{N}= [\mathbf{z}_n]_{n} to express a vector \mathbf{z} of length N with elements \mathbf{z}_n.

We denote our token vocabulary (e.g., ASCII characters or English words) by \mathcal{V}\cong[V]=\{1,2,\dots,V\}. A sequence is a vector of tokens, \mathbf{s}\in\mathcal{S} = \mathcal{V}^{T}, and a probability distribution over the vocabulary is p\in\mathcal{P}_{\mathcal{V}} = \left\{ p \in [0,1]^{\left|\mathcal{V}\right|} \, \middle|\, \|p\|_1=1 \right\}. Our goal is to construct a next-token predictor: G:\mathcal{S}\to\mathcal{P}_{\mathcal{V}}.

Finally, we will use these common nonlinear functions:

  • Rectified linear unit, \mathrm{ReLU}:\mathbb{R}\to\mathbb{R}

    (1)   \begin{align*} \mathrm{ReLU}(x) := \max(x,0). \end{align*}

    When \mathrm{ReLU} is applied to vectors, we take it to act element-wise.

  • Layer normalization (aka z-scoring), \mathrm{LayerNorm}:\mathbb{R}^{N}\to\mathbb{R}^{N}

    (2)   \begin{align*} \mathrm{LayerNorm}(\mathbf{z}):= \frac{\mathbf{z}-\mathrm{mean}(\mathbf{z})}{\sqrt{\mathrm{var}(\mathbf{z})}},  \end{align*}

    with \mathrm{mean}(\mathbf{z}) := N^{-1}\sum_{n=1}^{N} \mathbf{z}_n and \mathrm{var}(\mathbf{z}) := N^{-1}\sum_{n=1}^{N} (\mathbf{z}_n-\mathrm{mean}(\mathbf{z}))^2, as usual.

  • Softmax (aka Boltzmann distribution), \mathrm{softmax}:\mathbb{R}^N\to\mathbb{R}^N

    (3)   \begin{align*} \mathrm{softmax}(\mathbf{z}) := \frac{\left[\exp(\mathbf{z}_n)\right]_{n=1}^N}{\sum_{n=1}^N \exp(\mathbf{z}_n)}. \end{align*}
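For concreteness, here is a minimal numpy sketch of these three functions acting on 1-D arrays (the function names, the eps guard in the layer norm, and the max-subtraction in the softmax are my additions, not part of the definitions):

```python
import numpy as np

def relu(x):
    # Eq. (1), applied element-wise.
    return np.maximum(x, 0.0)

def layer_norm(z, eps=1e-5):
    # Eq. (2); the eps term guards against division by zero and is not part of Eq. (2).
    return (z - z.mean()) / np.sqrt(z.var() + eps)

def softmax(z):
    # Eq. (3); subtracting the max is a standard numerical-stability trick
    # that leaves the mathematical value unchanged.
    e = np.exp(z - z.max())
    return e / e.sum()
```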

Hyperparameters, parameters, and activations

Model hyperparameters (i.e., architectural choices)

  • Number of layers: L \in \mathbb{Z}_{+}
  • Number of attention heads: H \in \mathbb{Z}_{+}
  • Embedding, query-key, value-output, and feed-forward dimensions: D_{\mathrm{E}}, D_{\mathrm{QK}}, D_{\mathrm{VO}}, D_{\mathrm{FF}} \in \mathbb{Z}_{+}

Model parameters (i.e., trainable weights)

  • Token embedding matrix: \mathbf{W}^{\mathrm{emb}} \in \mathbb{R}^{V \times D_{\mathrm{E}}}, which is composed of the row vectors \mathbf{W}^{\mathrm{emb}}_v \in \mathbb{R}^{D_{\mathrm{E}}} indexed by token v\in \mathcal{V}\cong[V].
  • Position embedding matrix: \mathbf{W}^{\mathrm{pos}} \in \mathbb{R}^{T\times D_{\mathrm{E}}}, with \mathbf{W}^{\mathrm{pos}}_t \in \mathbb{R}^{D_{\mathrm{E}}} for t\in[T].
  • Token unembedding matrix: \mathbf{W}^{\mathrm{une}} \in \mathbb{R}^{D_{\mathrm{E}}\times V}
  • For each transformer layer \ell \in [L]:

    • Two feed-forward weight matrices: \mathbf{W}_{\mathrm{FF1}}^{(\ell)} \in \mathbb{R}^{D_{\mathrm{FF}}\times D_{\mathrm{E}}}, \mathbf{W}_{\mathrm{FF2}}^{(\ell)} \in \mathbb{R}^{D_{\mathrm{E}}\times D_{\mathrm{FF}}}
    • Two feed-forward bias vectors: \mathbf{b}_{\mathrm{FF1}}^{(\ell)} \in \mathbb{R}^{D_{\mathrm{FF}}}, \mathbf{b}_{\mathrm{FF2}}^{(\ell)} \in \mathbb{R}^{D_{\mathrm{E}}}
    • For each head h \in [H]:

      • Query and key matrices: \mathbf{W}_{\mathrm{K}}^{(\ell,h)}, \mathbf{W}_{\mathrm{Q}}^{(\ell,h)} \in \mathbb{R}^{D_{\mathrm{E}}\times D_{\mathrm{QK}}}
      • Value and output matrices: \mathbf{W}_{\mathrm{V}}^{(\ell,h)}, \mathbf{W}_{\mathrm{O}}^{(\ell,h)} \in \mathbb{R}^{D_{\mathrm{E}}\times D_{\mathrm{VO}}}

Activations (i.e., neuron outputs)

  • For each transformer layer \ell \in [L]:

    • Pre- and post-attention hidden state: \mathbf{X}^{(\ell)}, \mathbf{Y}^{(\ell)} \in \mathbb{R}^{T \times D_{\mathrm{E}}}, which are composed of the row vectors \mathbf{X}_t^{(\ell)}, \mathbf{Y}_t^{(\ell)} \in \mathbb{R}^{D_{\mathrm{E}}} indexed by position t\in[T].
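As a shape check on the lists above, here is a sketch that randomly initializes one full set of parameters (the hyperparameter values and the dictionary layout are arbitrary choices of mine, not anything standard):

```python
import numpy as np

rng = np.random.default_rng(0)

# Arbitrary illustrative hyperparameters.
V, T = 1000, 64                        # vocabulary size, sequence length
L, H = 2, 4                            # layers, heads
D_E, D_QK, D_VO, D_FF = 128, 32, 32, 512

params = {
    "W_emb": rng.normal(size=(V, D_E)),    # token embedding
    "W_pos": rng.normal(size=(T, D_E)),    # position embedding
    "W_une": rng.normal(size=(D_E, V)),    # token unembedding
    "layers": [
        {
            "W_FF1": rng.normal(size=(D_FF, D_E)), "b_FF1": np.zeros(D_FF),
            "W_FF2": rng.normal(size=(D_E, D_FF)), "b_FF2": np.zeros(D_E),
            "heads": [
                {
                    "W_Q": rng.normal(size=(D_E, D_QK)), "W_K": rng.normal(size=(D_E, D_QK)),
                    "W_V": rng.normal(size=(D_E, D_VO)), "W_O": rng.normal(size=(D_E, D_VO)),
                }
                for _ in range(H)
            ],
        }
        for _ in range(L)
    ],
}
```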

Definition of a transformer function

Our definition makes use of two workhorses:

  • Multi-head attention, \mathrm{attn}^{(\ell)}: \mathbb{R}^{T \times D_{\mathrm{E}}} \times \mathbb{R}^{D_{\mathrm{E}}} \to \mathbb{R}^{D_{\mathrm{E}}}

    (4)   \begin{align*} \mathrm{attn}^{(\ell)}(\mathbf{X}, \mathbf{z}) &:=  \sum_{h=1}^H       \mathrm{softmax}\left( \frac{\mathbf{z} \cdot \mathbf{W}_{\mathrm{Q}}^{(\ell,h)} \cdot {\mathbf{W}_{\mathrm{K}}^{(\ell,h)}}^\intercal \cdot \mathbf{X}^\intercal}{\sqrt{D_{\mathrm{QK}}}} \right) \cdot \mathbf{X} \cdot   {\mathbf{W}_{\mathrm{V}}^{(\ell,h)}} \cdot {\mathbf{W}_{\mathrm{O}}^{(\ell,h)}}^\intercal \end{align*}

  • Feed-forward network (aka multilayer perceptron, MLP), \mathrm{ffn}^{(\ell)}: \mathbb{R}^{D_{\mathrm{E}}} \to \mathbb{R}^{D_{\mathrm{E}}}

    (5)   \begin{align*} \mathrm{ffn}^{(\ell)}(\mathbf{z}) &:= \mathbf{W}_{\mathrm{FF2}}^{(\ell)} \cdot \mathrm{ReLU}\left( \mathbf{W}_{\mathrm{FF1}}^{(\ell)} \cdot \mathbf{z} + \mathbf{b}_{\mathrm{FF1}}^{(\ell)} \right) + \mathbf{b}_{\mathrm{FF2}}^{(\ell)} \end{align*}

    which has just one hidden layer.
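Here is a minimal numpy sketch of Eqs. (4) and (5), written against the parameter dictionary from the earlier sketch (the relu and softmax helpers are repeated from the Preliminaries; all names are mine):

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def attn(layer, X, z):
    # Eq. (4): multi-head attention. X has shape (T, D_E); z has shape (D_E,).
    out = np.zeros(X.shape[1])
    for h in layer["heads"]:
        scores = z @ h["W_Q"] @ h["W_K"].T @ X.T / np.sqrt(h["W_Q"].shape[1])
        out = out + softmax(scores) @ X @ h["W_V"] @ h["W_O"].T
    return out

def ffn(layer, z):
    # Eq. (5): one-hidden-layer feed-forward network applied to a single position.
    return layer["W_FF2"] @ relu(layer["W_FF1"] @ z + layer["b_FF1"]) + layer["b_FF2"]
```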

We can then define our transformer, G:\mathcal{S}\to\mathcal{P}_{\mathcal{V}}, recursively with

(6)   \begin{align*} \mathbf{X}_t^{(\ell=0)} &= \mathbf{W}^{\mathrm{emb}}_{\mathbf{s}_{t}}  + \mathbf{W}^{\mathrm{pos}}_t \\ \mathbf{Y}_t^{(\ell\ge 1)} &= \mathrm{LayerNorm}\left(\mathbf{X}_t^{(\ell-1)} +  \mathrm{attn}^{(\ell)}\left(\mathbf{X}^{(\ell-1)},\mathbf{X}_{t}^{(\ell-1)}\right)\right) \\ \mathbf{X}_t^{(\ell\ge 1)} &= \mathrm{LayerNorm}\left(\mathbf{Y}_t^{(\ell)}+\mathrm{ffn}^{(\ell)}\left(\mathbf{Y}_t^{(\ell)}\right) \right) \\ G &= \mathrm{softmax}\left(\mathbf{X}_T^{(L)} \cdot \mathbf{W}^{\mathrm{une}}  \right) \end{align*}
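A sketch of the full forward pass in Eq. (6), reusing layer_norm, softmax, attn, ffn, and the params dictionary from the sketches above:

```python
import numpy as np

def transformer(params, s):
    # Eq. (6): s is a 1-D array of token indices; returns a distribution over the vocabulary.
    T = len(s)
    X = params["W_emb"][s] + params["W_pos"][:T]       # X^{(0)}, shape (T, D_E)
    for layer in params["layers"]:                     # ell = 1, ..., L
        Y = np.stack([layer_norm(X[t] + attn(layer, X, X[t])) for t in range(T)])
        X = np.stack([layer_norm(Y[t] + ffn(layer, Y[t])) for t in range(T)])
    return softmax(X[-1] @ params["W_une"])            # p = G(s), shape (V,)
```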

Given a token sequence \mathbf{s}=[\mathbf{s}_t]_{t=1}^T, this defines the transformer next-token predictor, \mathbf{p}=G(\mathbf{s}). We can of course probabilistically extend any sequence \mathbf{s} to arbitrary length by computing \mathbf{p}=G(\mathbf{s}), sampling an element w_*\in\mathcal{V} according to the probability distribution \mathbf{p}, concatenating w_* to the end of the sequence \mathbf{s}, and iterating.
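That extension procedure can be sketched as a short sampling loop on top of the transformer function above (note that with learned position embeddings the total length cannot exceed the number of rows of \mathbf{W}^{\mathrm{pos}}):

```python
import numpy as np

def generate(params, s, num_new_tokens, seed=0):
    # Repeatedly compute p = G(s), sample a token from p, and append it to s.
    rng = np.random.default_rng(seed)
    s = list(s)
    for _ in range(num_new_tokens):
        p = transformer(params, np.array(s))
        s.append(int(rng.choice(len(p), p=p)))
    return s
```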

(Fin. The rest of this post is optional.)

Some comments

  • Using \mathbf{Q}=\mathbf{X}\cdot\mathbf{W}_\mathrm{Q}^{(\ell,h)} and \mathbf{K}=\mathbf{Y}\cdot\mathbf{W}_\mathrm{K}^{(\ell,h)}, one often sees attention presented with the highly symmetric looking softmax expression

    (7)   \begin{align*} \mathrm{softmax}\left(\frac{\mathbf{X} \cdot \mathbf{W}_\mathrm{Q}^{(\ell,h)} \cdot{\mathbf{W}_\mathrm{K}^{(\ell,h)}}^\intercal \cdot \mathbf{Y}^\intercal}{\sqrt{D_{\mathrm{QK}}}}\right) = \mathrm{softmax}\left(\frac{ {\mathbf{Q}} \cdot {\mathbf{K}}^\intercal}{\sqrt{D_{\mathrm{QK}}}}\right), \end{align*}

    but as far as I can tell the apparent elegance is misleading. The softmax function naturally acts on a vector. When applied to a matrix, it is by convention taken to act column-wise or row-wise, i.e., collectively on one index but independently on the other, breaking the symmetry. We have kept the broken symmetry explicit in our expression (4) for the attention function above.

  • As the D_{\mathrm{E}} \times D_{\mathrm{VO}} weight matrices \mathbf{W}_{\mathrm{V}}^{(\ell,h)} and \mathbf{W}_{\mathrm{O}}^{(\ell,h)} only appear together in the product \mathbf{W}_{\mathrm{V}}^{(\ell,h)} \cdot {\mathbf{W}_{\mathrm{O}}^{(\ell,h)}}^\intercal, it is tempting to combine them into a single D_{\mathrm{E}}\times D_{\mathrm{E}} matrix. However, that combined matrix could have rank as large as D_{\mathrm{E}}, whereas the original product \mathbf{W}_{\mathrm{V}} \cdot \mathbf{W}_{\mathrm{O}}^\intercal has rank at most \min(D_{\mathrm{E}},D_{\mathrm{VO}}). Typically one chooses D_{\mathrm{VO}} = D_{\mathrm{E}}/H \ll D_{\mathrm{E}} for reasons of computational complexity. This all applies similarly to the pair \mathbf{W}_{\mathrm{K}}^{(\ell,h)} and \mathbf{W}_{\mathrm{Q}}^{(\ell,h)}. See here and here for more.
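A quick numerical check of the rank claim (a sketch with arbitrary dimensions): the combined D_E × D_E matrix built from the two skinny factors has rank at most D_VO.

```python
import numpy as np

rng = np.random.default_rng(0)
D_E, D_VO = 128, 32

W_V = rng.normal(size=(D_E, D_VO))
W_O = rng.normal(size=(D_E, D_VO))
combined = W_V @ W_O.T                       # the D_E x D_E product appearing in Eq. (4)

print(np.linalg.matrix_rank(combined))       # prints 32 (= D_VO), not 128 (= D_E)
```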

Variations

This list of a few variations may help you match up this definition with others you’ve seen. (There are many more variations not listed here.) Below I will drop the superscript indices “{}^{(\ell)}” and “{}^{(\ell,h)}”.

  • Our layer normalization follows Phuong & Hutter and the original work by Vaswani et al., where \mathrm{LayerNorm} is applied to the complete hidden state after it is modified by multi-head attention \mathrm{attn}^{(\ell)} and again after it is modified by the feed-forward net \mathrm{ffn}^{(\ell)}. In contrast, Xiong et al. argue for the superiority of LayerNorm-ing only the changes (“residuals”) from these two mechanisms before they are added to the hidden state, and this choice is used by Turner.
  • The layer normalization is also often generalized from \mathrm{LayerNorm}(\mathbf{z}) to \mathrm{LayerNorm}(\mathbf{z})\odot \boldsymbol\gamma+\boldsymbol\beta, where “\odot” is element-wise multiplication and where \boldsymbol{\gamma}, \boldsymbol{\beta}\in\mathbb{R}^{D_\mathrm{E}} are learned vectors in the embedding space for multiplicative and additive scaling, respectively.
  • The use of attention \mathrm{attn}\left(\mathbf{X},\mathbf{X}_{t}\right) in Eq. (6) is known as bidirectional unmasked self-attention, but unidirectional, masked, and non-self versions are defined and explored elsewhere.
  • The linear transformations \mathbf{W}_{\mathrm{K}}, \mathbf{W}_{\mathrm{Q}}, \mathbf{W}_{\mathrm{V}}, and \mathbf{W}_{\mathrm{O}} can be generalized to non-homogeneous linear functions by including additive biases, e.g., replacing \mathbf{X}_t \cdot \mathbf{W}_{\mathrm{Q}} with \mathbf{X}_t \cdot \mathbf{W}_{\mathrm{Q}} + \mathbf{c}_{\mathrm{Q}}, where the bias \mathbf{c}_{\mathrm{Q}} is an additional (trainable) vector of weights. (This makes them more closely resemble feed-forward networks.) However, some of these biases are non-functional, i.e., they have no effect on the output; the key bias, for instance, only shifts every logit by the same amount and so cancels inside the softmax.
  • Rather than learned, the unembedding matrix is often taken to be simply the transpose of the token embedding matrix: \mathbf{W}^{\mathrm{une}} = {\mathbf{W}^{\mathrm{emb}}}^\intercal.
  • Rather than learned, \mathbf{W}^{\mathrm{pos}} is often taken to simply be fixed as a discrete Fourier transform of position: \mathbf{W}^{\mathrm{pos}}[t,2i] = \cos\left(t/T^{2i/D_{\mathrm{E}}}\right) and \mathbf{W}^{\mathrm{pos}}[t,2i-1] = \sin\left(t/T^{2i/D_{\mathrm{E}}}\right) for i=1,\ldots,D_{\mathrm{E}}/2; see the sketch after this list. This allows the model to handle arbitrary-length sequences even if the number of long sequences in the training set is small.
  • One can modify the feed-forward network to use a larger number of layers, or a different nonlinearity than \mathrm{ReLU}.
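As an illustration of the fixed sinusoidal position embedding mentioned above, here is a sketch following the indexing used in this post (1-based t and i; many implementations use a large fixed constant such as 10000 in place of the base T):

```python
import numpy as np

def sinusoidal_positions(T, D_E):
    # W_pos[t, 2i-1] = sin(t / T^(2i/D_E)) and W_pos[t, 2i] = cos(t / T^(2i/D_E)),
    # with 1-based t and i converted to 0-based numpy indices.
    W_pos = np.zeros((T, D_E))
    for t in range(1, T + 1):
        for i in range(1, D_E // 2 + 1):
            angle = t / T ** (2 * i / D_E)
            W_pos[t - 1, 2 * i - 2] = np.sin(angle)   # column 2i-1 (1-based)
            W_pos[t - 1, 2 * i - 1] = np.cos(angle)   # column 2i   (1-based)
    return W_pos
```

This could be dropped in for the learned "W_pos" in the parameter sketch earlier in the post.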
[I thank Dan Ranard and Steve Howard for suggestions and feedback.]

Footnotes

(↵ returns to text)

  1. For readability, I have dropped the citations and section references from these quotes without marking the ellipses.