Although the causal mask is ubiquitously used in large-scale language modeling, its necessity is seldom questioned in the literature. "The causal mask is needed to prevent information leakage from future tokens" is a commonly encountered, almost dogmatically repeated phrase. Yet among researchers and practitioners alike, there is a certain confusion about what the causal mask is and why we actually need it.

1 A primer on the causal mask

The confusion about the causal mask already starts with its name. The original Transformer paper does not mention the term causal mask at all. While we cannot definitively pinpoint the origin of the term, its first well-known appearance is in the T5 paper, where it describes the triangular mask applied to the attention weights in the self-attention mechanism (Figure 1, centre). Because the mask is triangular, only information from previous tokens can be used in the computation of the current token, which already leads us to the first common misconception: that the causal mask only hides information that is not yet available. For causal LMs at inference time, even after $$n$$ tokens have been generated, token $$k$$ with $$k < n$$ cannot attend to token $$j$$ with $$k < j < n$$, even though token $$j$$ is already known. From an information-theoretic perspective, this is clearly suboptimal, and indeed recent work has investigated the algorithmic deficiencies of causal language models.
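
To make the mechanics concrete, here is a minimal sketch of a single attention head with a triangular mask in PyTorch; it is a toy illustration, not the implementation used by any particular model.

```python
import torch
import torch.nn.functional as F

def causal_attention(q, k, v):
    """Toy single-head attention with a triangular (causal) mask.

    q, k, v: tensors of shape (seq_len, d). Position i may only attend to
    positions j <= i, i.e. to itself and to previous tokens.
    """
    seq_len, d = q.shape
    scores = q @ k.T / d**0.5                                    # (seq_len, seq_len)
    mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    scores = scores.masked_fill(~mask, float("-inf"))            # block attention to future tokens
    weights = F.softmax(scores, dim=-1)
    return weights @ v
```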

Figure 1: A schematic of full attention, causal masking, and prefix masking. Figure from . Used under CC-BY 4.0 license.

Looking solely at inference time, what we would ideally want is full attention, i.e. every (generated) token can attend to every other (generated) token (Figure 1, left). This is the no-mask regime, where the attention weights are not modified at all. Yet all LLMs use the causal mask at inference time as well, since omitting it would introduce a distribution shift between training and inference and thus impair performance. In other words, the causal mask is needed at inference time only because it was used during training, which raises a natural question: why do we need the causal mask during training?

2 On the necessity of the causal mask

For brevity's sake, we will focus on the GPT-style pre-training regime, where the model is trained to predict the next token given the previous tokens. One of the key advantages of the Transformer architecture over classical RNNs is that we can predict all tokens of a sequence in parallel. Specifically, this is achieved using a technique called teacher-forcing, where instead of using the tokens generated by the model, we use the ground-truth previous tokens to predict each 'next token'. Clearly, when the model predicts token $$k$$, its computation should not use information from token $$k$$ or beyond, which brings us to the second common misconception: that the causal mask is needed to prevent this information leakage from future tokens during training. In fact, it is not.

To illustrate this point, suppose we have a sequence of tokens $$x_1, x_2, x_3, x_4$$ and we want the model to use tokens $$x_1, x_2, x_3$$ to predict token $$x_4$$. We need to prohibit tokens from attending to token $$x_4$$, but we do not need to prohibit token $$x_1$$ from attending to tokens $$\{x_2, x_3\}$$, or token $$x_2$$ from attending to token $$x_3$$, yet this is exactly what the causal mask does. Instead of a triangular mask, a block-sparse mask would suffice to prevent information leakage from future tokens while allowing all tokens in the context to attend to each other.
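
A small sketch of such a block-sparse mask for this four-token example (positions are 0-indexed; the helper below is purely illustrative): every row may attend to the columns of $$x_1, x_2, x_3$$, and no row may attend to the column of $$x_4$$.

```python
import torch

def block_sparse_mask(seq_len, target_pos):
    """Mask for predicting the token at `target_pos` (0-indexed).

    All positions may attend to every token before `target_pos` (full
    attention within the context); the column of the target token is
    blocked, so no information leaks from the token being predicted.
    The row of the target token itself is irrelevant here, since its
    output is not used for this prediction.
    """
    mask = torch.zeros(seq_len, seq_len, dtype=torch.bool)
    mask[:, :target_pos] = True
    return mask

# Predicting x_4 (index 3) from x_1, x_2, x_3:
print(block_sparse_mask(4, 3).int())
# tensor([[1, 1, 1, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 0]])
```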

The illustration above depicts only a single token prediction during training, raising the question of whether the point also holds for parallel training. It does: the causal mask is needed neither for teacher-forcing nor for parallel training, and a block-sparse mask (one per target position) suffices for both.
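
For concreteness, here is a minimal sketch of parallel, teacher-forced next-token training in PyTorch; `model` is a hypothetical Transformer that applies its attention mask internally, and nothing in the loss itself depends on whether that mask is triangular or block-sparse.

```python
import torch.nn.functional as F

def teacher_forced_loss(model, tokens):
    """Parallel next-token prediction with teacher forcing.

    `tokens`: ground-truth ids of shape (batch, seq_len).
    `model`: hypothetical Transformer mapping (batch, seq_len - 1) inputs to
    logits of shape (batch, seq_len - 1, vocab_size).
    """
    inputs = tokens[:, :-1]    # ground-truth context fed to the model (teacher forcing)
    targets = tokens[:, 1:]    # position i is trained to predict token i + 1
    logits = model(inputs)     # all positions are predicted in a single parallel pass
    return F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        targets.reshape(-1),
    )
```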

3 What about PrefixLMs?

We can motivate the use of block-sparse masks by looking at PrefixLMs, which are language models originally designed to solve sequence-to-sequence tasks. In PrefixLMs, the model is trained to predict the next token given the previous tokens and a so-called prefix that is provided as input to the model. During supervised training of PrefixLMs, the prefix could be a prompt, a question, or any other kind of context given to the model, which is then trained to predict the answer or the continuation of the prompt. Since the prefix is fixed and not predicted by the model, the tokens in the prefix are not masked at all, while the rest of the sequence is causally masked (Figure 1, right).
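
A sketch of this prefix mask (again 0-indexed and purely illustrative): start from the triangular mask and unmask the prefix block so that prefix tokens attend to each other bidirectionally.

```python
import torch

def prefix_lm_mask(seq_len, prefix_len):
    """PrefixLM mask: full attention within the prefix, causal elsewhere."""
    mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))  # causal baseline
    mask[:prefix_len, :prefix_len] = True                              # prefix attends bidirectionally
    return mask

print(prefix_lm_mask(5, 2).int())
# tensor([[1, 1, 0, 0, 0],
#         [1, 1, 0, 0, 0],
#         [1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1]])
```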

This masking procedure is specifically tailored to sequence-to-sequence tasks with explicit supervision. However, the vast majority of textual data available on the Internet is unlabeled and largely unstructured; in this setting, PrefixLM training amounts to randomly splitting text into prefixes and targets. Causal language modeling then provides much denser supervision than prefix language modeling, since the prefix contributes no direct supervisory signal to the model (no loss is computed on prefix tokens).
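
To make the supervision-density argument concrete, the sketch below (shapes and names are illustrative assumptions) computes a PrefixLM loss for a single sequence: only predictions of target tokens contribute loss terms, so a length-$$n$$ sequence with a prefix of length $$p$$ yields roughly $$n - p$$ training signals, versus roughly $$n$$ for causal language modeling.

```python
import torch.nn.functional as F

def prefix_lm_loss(logits, tokens, prefix_len):
    """PrefixLM loss for one sequence under a random prefix/target split.

    `tokens`: token ids of shape (seq_len,).
    `logits`: next-token logits of shape (seq_len - 1, vocab_size), assumed
    to come from a PrefixLM forward pass; `prefix_len` is assumed >= 1.
    """
    targets = tokens[1:]                                   # next-token targets
    per_token = F.cross_entropy(logits, targets, reduction="none")
    # Predictions of prefix tokens are discarded: only target positions
    # (tokens at index >= prefix_len) are supervised.
    return per_token[prefix_len - 1:].mean()
```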

This motivates a procedure that benefits from both dense supervision and full attention: taking each token as the sole target and using all previous tokens as the prefix. This is exactly the block-sparse mask, and nothing inherently prohibits parallelization in this regime. However, the block-sparse mask depends on the position of the target token in the sequence, which means that a naïve implementation requires $$n$$ times as much memory and $$n$$ times as much computation (where $$n$$ is the sequence length), since everything after the first attention map computation is token-position dependent. Thus, the main challenge in moving beyond the causal mask is mitigating the memory and compute overhead that comes with it.
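
The sketch below illustrates the naïve approach for a single attention layer: one full-attention computation per target position. With only one layer the final row coincides with what causal attention would give; the benefit (and the $$n$$-fold overhead) appears once these position-dependent context representations feed into subsequent layers, which would likewise have to be recomputed per target.

```python
import torch
import torch.nn.functional as F

def naive_per_target_attention(q, k, v):
    """Naive realisation of 'each token is the sole target, all previous
    tokens are the prefix': one attention computation per target position.

    q, k, v: tensors of shape (seq_len, d). For target position t, the
    context tokens 0..t-1 attend to each other with full attention, and the
    last context row serves as the representation for predicting token t.
    Looping over targets like this costs roughly n times the memory and
    compute of a single shared causal pass.
    """
    seq_len, d = q.shape
    preds = []
    for t in range(1, seq_len + 1):
        ctx_q, ctx_k, ctx_v = q[:t], k[:t], v[:t]
        scores = ctx_q @ ctx_k.T / d**0.5      # full attention within the context
        weights = F.softmax(scores, dim=-1)
        preds.append((weights @ ctx_v)[-1])    # representation used to predict token t
    return torch.stack(preds)                  # (seq_len, d)
```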

Contributions

FS worked on all aspects of this post, including research, analysis, and writing. This blog post has benefited from various discussions with senior colleagues, among them Preetum Nakkiran and Thomas Scialom.