Expanding the Context Window

Introduction

In this lesson, we will discuss context windows in language models, their importance, and the limitations of the original Transformer architecture in handling large context lengths.

We explore various optimization techniques that have been developed to expand the context window, including ALiBi Positional Encoding, Sparse Attention, FlashAttention, Multi-Query Attention, and the use of large RAM GPUs.

We also introduce the latest advancements in this field, such as FlashAttention-2 and LongNet, which aim to push the context window to an unprecedented scale.

The Importance of Context Length

The context window refers to the number of input tokens the model can process simultaneously. In current models like GPT-4, this context window is around 32K tokens, roughly the length of 50 pages of text. Recent advancements have pushed this limit to an impressive 100K tokens (see Claude by Anthropic), equivalent to about 156 pages.

The context length of an LLM is a critical factor for several reasons. Firstly, it allows the model to process larger amounts of data at once, providing a more comprehensive understanding of the context. This is particularly useful when you want to feed a large amount of custom data into an LLM and ask questions about this specific data.

For instance, you might want to input a large document related to a specific company or problem and ask the model questions about this document. With a larger context window, the LLM can scan and retain more of this custom information, leading to more accurate and personalized responses.

Limitations of the Original Transformer Architecture

The original Transformer architecture, however, has some limitations when it comes to handling large context lengths. The main issue lies in its computational complexity: the attention layer computations have quadratic time and space complexity with respect to the number of input tokens n. This means that as the context length increases, the computational resources required for training and inference grow quadratically.

To understand this better, let's break down the computational complexity of the Transformer architecture. The complexity of the attention layer in the Transformer model is O(n²d + nd²), where n is the context length (number of input tokens) and d is the embedding size.

This complexity arises from two main operations in the attention layer: the linear projections to get the Query, Key, and Value matrices (complexity ~O(nd²)) and the multiplications of these matrices (complexity ~O(n²d)). As the context length or embedding size increases, the computational complexity grows quadratically, making it increasingly challenging to process larger context lengths.
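
To make these two terms concrete, here is a minimal single-head attention computation in plain PyTorch; the tensor names and sizes are illustrative rather than taken from any particular model.

```python
import torch

n, d = 1024, 512                      # context length and embedding size (illustrative values)
x = torch.randn(n, d)                 # token embeddings for one sequence

# Learned projection weights for a single attention head
W_q, W_k, W_v = (torch.randn(d, d) for _ in range(3))

# Linear projections: three (n, d) x (d, d) matmuls -> ~O(n * d^2)
Q, K, V = x @ W_q, x @ W_k, x @ W_v

# Attention scores: an (n, d) x (d, n) matmul producing an (n, n) matrix -> ~O(n^2 * d)
scores = (Q @ K.T) / d ** 0.5
weights = torch.softmax(scores, dim=-1)   # this (n, n) matrix also dominates memory

# Weighted sum of values: an (n, n) x (n, d) matmul -> ~O(n^2 * d)
out = weights @ V
```

Doubling the context length n quadruples the size of the (n, n) score matrix, which is why both compute and memory become the bottleneck for long inputs.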

Optimization Techniques to Expand the Context Window

Despite these challenges, researchers have developed several optimization techniques to speed up the Transformer and increase the context length to 100K tokens. Let's explore some of these techniques:

  1. ALiBi Positional Encoding: The original Transformer uses Positional Sinusoidal Encoding, which lacks the ability to extrapolate to larger context lengths. ALiBi, or Attention with Linear Biases, is a positional encoding technique that allows the model to be trained on a short context and then fine-tuned on a longer one (see the sketch after this list).
  2. Sparse Attention: This technique reduces the number of computations by considering only some tokens when calculating the attention scores. This makes the computation linear with respect to n, significantly reducing the computational complexity.
  3. FlashAttention: This is an efficient implementation of the attention layer for GPU. It optimizes the memory utilization of the GPU by splitting the input matrices into blocks and computing the attention output with respect to these blocks.
  4. Multi-Query Attention (MQA): MQA reduces the memory consumption of the decoder's key/value cache by sharing weights across all attention heads when linearly projecting the Key and Value matrices (a minimal sketch follows this list).
  5. Large RAM GPUs: Fitting a large context requires a lot of GPU memory, so models with larger context windows are often trained on GPUs with large memory, such as 80GB A100 GPUs.
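
As a rough sketch of the ALiBi idea (a single helper with illustrative names, not the reference implementation): each head adds a fixed, head-specific linear penalty to its attention scores that grows with the distance between query and key, so no position embeddings are needed and longer contexts simply extend the same bias pattern.

```python
import torch

def alibi_bias(n_tokens: int, n_heads: int) -> torch.Tensor:
    # Head-specific slopes: a geometric sequence 2^(-8/H), 2^(-16/H), ..., as in the ALiBi paper
    slopes = torch.tensor([2.0 ** (-8.0 * (h + 1) / n_heads) for h in range(n_heads)])
    pos = torch.arange(n_tokens)
    rel = pos[None, :] - pos[:, None]                 # relative distance j - i, shape (n, n)
    # Past tokens (j < i) receive a negative bias proportional to their distance;
    # future positions are removed later by the causal mask anyway.
    return slopes[:, None, None] * rel[None, :, :]    # shape (n_heads, n, n)

# Usage sketch: scores = Q @ K.T / d_head**0.5 + alibi_bias(n, H)[h], then causal mask and softmax.
```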

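The shared key/value idea behind MQA fits in a compact sketch (a minimal module with hypothetical names, not a production implementation): every query head keeps its own projection, but all heads read from a single key head and a single value head.

```python
import torch
import torch.nn as nn

class MultiQueryAttention(nn.Module):
    """Minimal multi-query attention sketch: H query heads share one key/value head."""
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.n_heads, self.d_head = n_heads, d_model // n_heads
        self.q_proj = nn.Linear(d_model, d_model)            # per-head query projections
        self.kv_proj = nn.Linear(d_model, 2 * self.d_head)   # one shared key/value head
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:      # x: (batch, seq, d_model)
        b, n, _ = x.shape
        q = self.q_proj(x).view(b, n, self.n_heads, self.d_head).transpose(1, 2)  # (b, H, n, d_head)
        k, v = self.kv_proj(x).split(self.d_head, dim=-1)                         # (b, n, d_head) each
        k, v = k.unsqueeze(1), v.unsqueeze(1)      # broadcast the single KV head over all query heads
        scores = q @ k.transpose(-2, -1) / self.d_head ** 0.5                     # (b, H, n, n)
        out = torch.softmax(scores, dim=-1) @ v                                   # (b, H, n, d_head)
        return self.out_proj(out.transpose(1, 2).reshape(b, n, -1))
```

During autoregressive decoding, only one key tensor and one value tensor need to be cached instead of one pair per head, shrinking the key/value cache by a factor equal to the number of heads.
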
FlashAttention-2

Building on the success of FlashAttention, researchers have recently developed FlashAttention-2, a more efficient version of the algorithm that further optimizes the attention layer's speed and memory usage. This new version has been completely rewritten from scratch, leveraging Nvidia's newer CUTLASS primitives, and is about 2x faster than its predecessor, reaching up to 230 TFLOPs/s on A100 GPUs.

FlashAttention-2 introduces several improvements over the original FlashAttention.

  • Firstly, it reduces the number of non-matmul FLOPs, which are 16x more expensive than matmul FLOPs, by tweaking the algorithm to spend more time on matmul FLOPs.
  • Secondly, it optimizes parallelism by parallelizing over batch size, number of heads, and the sequence length dimension. This results in significant speedup, especially for long sequences.
  • Lastly, it improves work partitioning within each thread block to reduce the amount of synchronization and communication between different warps, resulting in fewer shared memory reads/writes.
  • In addition to these improvements, FlashAttention-2 also introduces new features, such as support for head dimensions up to 256 and multi-query attention (MQA), further expanding the context window.

With these advancements, FlashAttention-2 is a meaningful step toward larger context windows, although it does not remove the fundamental quadratic complexity of the original Transformer architecture.
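
In practice, you rarely call these kernels by hand. As a hedged example, recent PyTorch releases can dispatch torch.nn.functional.scaled_dot_product_attention to a fused FlashAttention kernel on supported GPUs (whether that kernel is FlashAttention-2 depends on the installed PyTorch build); the snippet below restricts dispatch to the flash backend so a slower fallback is not silently used.

```python
import torch
import torch.nn.functional as F

# Shapes follow PyTorch's convention: (batch, heads, sequence length, head dim).
# Half precision on a CUDA GPU is required for the fused flash kernel to be eligible.
q = torch.randn(1, 16, 8192, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Allow only the flash backend; PyTorch errors out if the inputs do not qualify for it.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)  # never materializes the (n, n) matrix
```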

LongNet: A Leap Towards Billion-Token Context Window

Building on the advancements in Transformer optimization, a recent innovation comes from the paper "LONGNET: Scaling Transformers to 1,000,000,000 Tokens". This paper introduces a novel approach to handling the computational complexity of the Transformer architecture, pushing the context window potentially to an unprecedented 1 billion tokens.

The core innovation in LongNet is the introduction of "dilated attention." This attention mechanism expands the attentive field exponentially as the distance between tokens grows, so the attention allocated to each token decreases exponentially with distance. This design principle balances the limited attention resources against the need to access every token in the sequence.

Image from the paper "LONGNET: Scaling Transformers to 1,000,000,000 Tokens". Building blocks of dilated attention used in LONGNET. It consists of a series of attention patterns for modeling short- and long-range dependency. The number of attention patterns can be extended according to the sequence length.

The dilated attention mechanism in LongNet achieves a linear computational complexity, a significant improvement over the quadratic complexity of the standard Transformer.
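
To see where the linearity comes from, here is a rough, single-head sketch of one dilated-attention building block (the function name and the example (w, r) values are illustrative, not the paper's code): the sequence is split into segments of length w, and within each segment only every r-th position takes part in attention, so each block works on a small (w/r, w/r) score matrix instead of an (n, n) one.

```python
import torch

def dilated_attention_block(q, k, v, segment_len: int, dilation: int):
    """One (segment length w, dilation r) pattern: q, k, v have shape (n, d),
    with n assumed divisible by segment_len for simplicity."""
    n, d = q.shape
    out = torch.zeros_like(q)
    for start in range(0, n, segment_len):
        idx = torch.arange(start, start + segment_len, dilation)   # sparse positions in this segment
        qs, ks, vs = q[idx], k[idx], v[idx]
        weights = torch.softmax(qs @ ks.T / d ** 0.5, dim=-1)      # (w/r, w/r) instead of (n, n)
        out[idx] = weights @ vs
    return out

# LongNet mixes several such patterns with geometrically growing segment lengths and dilations,
# e.g. (w, r) = (2048, 1), (4096, 2), (8192, 4), ..., so nearby tokens get dense attention while
# distant tokens get progressively sparser attention.
```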

Image from the paper "LONGNET: Scaling Transformers to 1,000,000,000 Tokens". Comparison of computation complexity among different methods. N is the sequence length, and d is the hidden dimension.

Conclusion

In this lesson, we examined the limitations of the original Transformer architecture in handling large context lengths, primarily due to its quadratic computational complexity. We then explored various optimization techniques developed to overcome these limitations, including ALiBi Positional Encoding, Sparse Attention, FlashAttention, Multi-Query Attention, and the use of large RAM GPUs.

We also discussed the latest advancements in this field, such as FlashAttention-2, which further optimizes the speed and memory usage of the attention layer, and LongNet, a novel approach that introduces "dilated attention" to potentially expand the context window to an unprecedented 1 billion tokens.

These advancements are critical in pushing the boundaries of language models, enabling them to process larger amounts of data at once and providing a more comprehensive understanding of the context, leading to more accurate and personalized responses.