Chapter 8. Multi-Head Attention, Parallel Understanding

In Chapter 7, you learned how a single attention head works: it computes queries, keys, and values, compares them with scaled dot products, applies causal masking, and produces context-aware output vectors. But a single attention head produces just one attention distribution per token, so in practice it captures roughly one type of relationship at a time. If a head learns to connect verbs with their subjects, it cannot also give full attention to connecting pronouns with their antecedents, tracking syntactic structure, or identifying semantic similarity. Real language is full of overlapping, simultaneous relationships, and a single attention pattern is not enough to capture them all. Multi-head attention solves this by running many attention heads in parallel, each free to learn a different type of relationship, and combining their outputs into a single rich representation.


Why One Attention Pattern Isn’t Enough

Consider the sentence: “The cat that I saw yesterday was sleeping on the mat.”

A single attention head processing the token “was” might learn to attend strongly to “cat” (the subject of the verb). That is useful for subject-verb agreement. But at the same time, the model also needs to understand:

  • That “that” introduces a relative clause modifying “cat”
  • That “I” is the subject of “saw,” not of “was”
  • That “yesterday” is a temporal modifier of “saw”
  • That “on the mat” is a prepositional phrase modifying “sleeping”
  • That “the” before “mat” is a determiner for “mat,” not for “cat”

Each of these is a different type of linguistic relationship. A single set of W_Q, W_K, W_V weight matrices produces a single query, key, and value per token, which means a single attention pattern per token. That pattern can focus on one type of relationship (say, subject-verb agreement), but it cannot simultaneously focus on a completely different type (say, determiner-noun attachment).

The solution is straightforward: run multiple attention heads in parallel, each with its own set of W_Q, W_K, W_V weight matrices. Each head learns its own query-key-value projections, producing its own attention pattern. One head might specialize in subject-verb relationships, another in coreference (connecting pronouns to nouns), another in local syntactic structure, and another in long-range semantic connections. Together, the heads capture a multi-faceted understanding of the relationships between tokens.

This is multi-head attention, introduced in the original Transformer paper by Vaswani et al. (2017). The paper describes it this way: “Multi-head attention allows the model to jointly attend to information from different representation subspaces at different positions.”

Source: Vaswani et al., “Attention Is All You Need,” NeurIPS 2017, Section 3.2.2.


How Multi-Head Attention Works

The mechanics of multi-head attention are simple once you understand single-head attention from Chapter 7. Here is the process, step by step:

Step 1: Split the Hidden Dimension into Heads

Instead of computing a single set of Q, K, V vectors with the full hidden dimension, we split the computation into h separate heads. Each head operates on a smaller dimension.

In the original Transformer (Vaswani et al., 2017):

  • d_model = 512 (the full hidden dimension)
  • h = 8 heads
  • d_k = d_v = 512 / 8 = 64 dimensions per head

In LLaMA 4 Maverick (Meta, April 2025):

  • hidden_size = 5,120
  • 40 query attention heads
  • head_dim = 128 dimensions per head

Source: Vaswani et al., 2017, Section 3.2.2: d_model = 512, h = 8, d_k = d_v = 64. LLaMA 4 Maverick from Ollama model metadata and HuggingFace Transformers (v5.3.0): hidden_size = 5,120, num_attention_heads = 40, head_dim = 128.

Note that in the original Transformer, d_k = d_model / h, so the total computation across all heads is roughly the same as a single head with the full dimension. This is a key design choice: multi-head attention does not increase the total computation compared to single-head attention with the same total dimension. It redistributes the same computation across multiple parallel heads.
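To check the claim that splitting into heads only redistributes the work, here is a small sketch that counts parameters and multiply-adds for one attention layer. It assumes bias-free projections and a single sequence of n tokens, and it ignores softmax and masking; `attention_cost` is a hypothetical helper, not part of any library.

```python
def attention_cost(d_model: int, n_heads: int, n: int) -> dict:
    """Rough cost of one self-attention layer (bias-free, softmax/masking ignored)."""
    d_k = d_model // n_heads                 # per-head dimension
    params = 4 * d_model * d_model           # W_Q, W_K, W_V, W_O, each d_model x d_model
    proj_macs = 4 * n * d_model * d_model    # X @ W_Q, X @ W_K, X @ W_V, concat @ W_O
    score_macs = n_heads * n * n * d_k       # Q_i @ K_i^T for every head
    mix_macs = n_heads * n * n * d_k         # weights_i @ V_i for every head
    return {"params": params, "macs": proj_macs + score_macs + mix_macs}

single = attention_cost(d_model=512, n_heads=1, n=128)
multi = attention_cost(d_model=512, n_heads=8, n=128)
print(single)
print(multi)
print(single == multi)  # True: 8 heads of 64 dims cost the same as 1 head of 512 dims
```

The equality holds because n_heads * d_k = d_model, so the per-head terms always sum to the single-head total.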

Step 2: Each Head Computes Its Own Q, K, V

Each head i has its own weight matrices: W_Q^i, W_K^i, W_V^i. For each token’s input vector x:

q_i = x * W_Q^i    shape: [head_dim]
k_i = x * W_K^i    shape: [head_dim]
v_i = x * W_V^i    shape: [head_dim]

In practice, the model does not literally have h separate weight matrices. Instead, it has one large W_Q matrix of shape [hidden_size x (h * head_dim)], one large W_K, and one large W_V. The output of each large matrix is then reshaped and split into h heads. This is computationally equivalent but more efficient on GPUs because it uses a single large matrix multiplication instead of h smaller ones.
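The equivalence between one fused projection and h per-head projections can be checked directly. In this sketch (toy sizes, random weights), head i is assumed to own the i-th contiguous block of head_dim columns of W_Q, which is what a row-major reshape to [seq_len, h, head_dim] implies.

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, hidden_size, n_heads, head_dim = 5, 16, 4, 4   # toy sizes, not a real model

X = rng.standard_normal((seq_len, hidden_size))
W_Q = rng.standard_normal((hidden_size, n_heads * head_dim))  # one fused matrix

# Fused: one big matmul, then reshape to [seq_len, n_heads, head_dim]
Q_fused = (X @ W_Q).reshape(seq_len, n_heads, head_dim)

# Per-head: head i uses only the column block W_Q[:, i*head_dim:(i+1)*head_dim]
Q_per_head = np.stack(
    [X @ W_Q[:, i * head_dim:(i + 1) * head_dim] for i in range(n_heads)],
    axis=1,
)

print(np.allclose(Q_fused, Q_per_head))  # True
```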

For LLaMA 4 Maverick:

  • W_Q has shape [5,120 x (40 * 128)] = [5,120 x 5,120]
  • The output is reshaped to [sequence_length x 40 x 128], giving 40 separate query vectors of dimension 128

Step 3: Each Head Runs Independent Attention

Each head independently computes the full attention operation from Chapter 7:

head_i = Attention(Q_i, K_i, V_i) = softmax(Q_i * K_i^T / sqrt(d_k)) * V_i

Each head produces its own attention weight matrix (which tokens attend to which) and its own output vectors. The attention patterns are completely independent across heads, because each head has its own learned W_Q, W_K, W_V matrices.
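Because the heads never interact inside the attention step, all of them can be computed in one batched operation instead of a Python loop. The following NumPy sketch uses `np.einsum` over a leading head axis and applies the same causal mask to every head; the function name `attention_all_heads` is illustrative.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention_all_heads(Q, K, V, causal=True):
    """Q, K, V: [n_heads, seq_len, head_dim]. Returns per-head outputs and weights."""
    n_heads, seq_len, d_k = Q.shape
    scores = np.einsum("hqd,hkd->hqk", Q, K) / np.sqrt(d_k)   # [n_heads, seq_len, seq_len]
    if causal:
        future = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
        scores = np.where(future, -np.inf, scores)           # one mask, broadcast to all heads
    weights = softmax(scores, axis=-1)
    return np.einsum("hqk,hkd->hqd", weights, V), weights

rng = np.random.default_rng(1)
Q, K, V = (rng.standard_normal((4, 6, 8)) for _ in range(3))
out, w = attention_all_heads(Q, K, V)

# Same result as running each head on its own
loop = [attention_all_heads(Q[i:i + 1], K[i:i + 1], V[i:i + 1])[0][0] for i in range(4)]
print(np.allclose(out, np.stack(loop)))  # True
```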

Step 4: Concatenate All Heads

After all heads compute their outputs, the results are concatenated along the head dimension:

concat = [head_1 ; head_2 ; ... ; head_h]    shape: [sequence_length x (h * head_dim)]

For LLaMA 4 Maverick, this produces a vector of dimension 40 * 128 = 5,120 for each token, which matches the hidden_size.

Step 5: Apply the Output Projection

The concatenated output is multiplied by a final weight matrix W_O to produce the multi-head attention output:

MultiHead(Q, K, V) = concat * W_O

Where W_O has shape [(h * head_dim) x hidden_size]. For LLaMA 4 Maverick, W_O is [5,120 x 5,120].

The output projection serves two purposes. First, it mixes information across heads: the output for each token is a learned linear combination of all heads’ outputs, allowing the model to combine the different types of information captured by different heads. Second, it projects the concatenated output back to the model’s hidden dimension, so the output has the same shape as the input and can be passed to the next layer.
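The mixing claim follows from block matrix multiplication: multiplying the concatenation by W_O is the same as projecting each head's output through its own block of rows of W_O and summing the results, so every output dimension draws on every head. A small NumPy check, using random stand-ins for the head outputs:

```python
import numpy as np

rng = np.random.default_rng(2)
seq_len, n_heads, head_dim = 3, 4, 2
d_model = n_heads * head_dim

heads = [rng.standard_normal((seq_len, head_dim)) for _ in range(n_heads)]  # stand-in head outputs
W_O = rng.standard_normal((n_heads * head_dim, d_model))

# Standard form: concatenate, then project
out_concat = np.concatenate(heads, axis=-1) @ W_O

# Equivalent form: each head is projected by its own row block of W_O, then summed
out_sum = sum(heads[i] @ W_O[i * head_dim:(i + 1) * head_dim, :] for i in range(n_heads))

print(np.allclose(out_concat, out_sum))  # True
```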

The Complete Formula

The full multi-head attention formula from Vaswani et al. (2017) is:

MultiHead(Q, K, V) = Concat(head_1, ..., head_h) * W_O

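where head_i = Attention(Q * W_Q^i, K * W_K^i, V * W_V^i). In self-attention, Q, K, and V are all the same input matrix X, so head_i = Attention(X * W_Q^i, X * W_K^i, X * W_V^i).

Source: Vaswani et al., “Attention Is All You Need,” NeurIPS 2017, Section 3.2.2.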


What Different Heads Learn

One of the most fascinating findings in Transformer research is that different attention heads spontaneously specialize in different linguistic tasks during training. Nobody programs a head to track subject-verb agreement or coreference; these specializations emerge from the training objective alone, whether that is masked-token prediction (as in BERT), translation, or next-token prediction (as in GPT-style models).

Research Evidence

Clark et al. (2019) conducted a detailed analysis of BERT’s attention heads and found remarkably specific patterns:

  • Some heads attend to direct objects of verbs with high accuracy. When processing a verb like “ate,” these heads attend strongly to the direct object (“pizza”).
  • Some heads attend to determiners of nouns. When processing a noun like “cat,” these heads attend to its determiner (“the” or “a”).
  • Some heads attend to objects of prepositions. When processing a preposition like “on,” these heads attend to its object (“table”).
  • Some heads attend to coreferent mentions with high accuracy. When processing a pronoun like “she,” these heads attend to the noun it refers to (“Alice”).

Source: Clark et al., “What Does BERT Look At? An Analysis of BERT’s Attention,” BlackboxNLP Workshop at ACL, 2019. Found heads attending to direct objects, determiners, prepositional objects, and coreferent mentions.

Voita et al. (2019) went further, identifying three main types of specialized heads in a Transformer trained on machine translation:

  1. Positional heads: Attend to tokens at specific relative positions (e.g., always attend to the immediately preceding token, or the token two positions back).
  2. Syntactic heads: Track syntactic relationships like subject-verb or modifier-noun connections.
  3. Rare token heads: Attend to infrequent or unusual tokens in the sequence, which often carry the most information.

Critically, Voita et al. found that when they pruned (removed) attention heads using a learned pruning method, the specialized heads were the last to be pruned. The model could lose many “generic” heads with minimal performance impact, but removing the specialized heads caused significant degradation. This suggests that a relatively small number of heads do the heavy lifting, while many others are partially redundant.

Source: Voita et al., “Analyzing Multi-Head Self-Attention: Specialized Heads Do the Heavy Lifting, the Rest Can Be Pruned,” ACL 2019.
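Head pruning experiments come down to switching individual heads off and measuring the effect. The sketch below multiplies each head's output by a fixed 0/1 gate before W_O. This is a simplified ablation for illustration only: Voita et al. instead learned stochastic gates with a relaxed L0 penalty, which this code does not implement, and `gated_output` is a hypothetical name.

```python
import numpy as np

def gated_output(head_outputs, W_O, gates):
    """head_outputs: [n_heads, seq_len, head_dim]; gates: [n_heads] of 0/1 values.
    Scales each head's output by its gate, then concatenates and projects with W_O."""
    n_heads, seq_len, head_dim = head_outputs.shape
    gated = head_outputs * gates[:, None, None]                        # zero out disabled heads
    concat = gated.transpose(1, 0, 2).reshape(seq_len, n_heads * head_dim)
    return concat @ W_O

rng = np.random.default_rng(3)
heads = rng.standard_normal((4, 5, 2))   # 4 heads, 5 tokens, head_dim = 2
W_O = rng.standard_normal((8, 8))

full = gated_output(heads, W_O, np.ones(4))
pruned = gated_output(heads, W_O, np.array([1.0, 0.0, 1.0, 1.0]))  # switch off head 1
print(np.abs(full - pruned).mean())  # size of head 1's contribution on this input
```

In a real experiment the difference would be measured on a task metric (such as BLEU) rather than on raw activations.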

Visualizing Head Specialization

To make this concrete, imagine a 4-head attention layer processing the sentence “The cat sat on the mat”:

Head 1 (Positional - "previous token"):
  "cat"  attends strongly to "The"   (position -1)
  "sat"  attends strongly to "cat"   (position -1)
  "on"   attends strongly to "sat"   (position -1)
  "the"  attends strongly to "on"    (position -1)
  "mat"  attends strongly to "the"   (position -1)

Head 2 (Syntactic - subject-verb):
  "sat"  attends strongly to "cat"   (subject of "sat")
  Other tokens have diffuse attention

Head 3 (Syntactic - determiner-noun):
  "cat"  attends strongly to "The"   (determiner of "cat")
  "mat"  attends strongly to "the"   (determiner of "mat")
  Other tokens have diffuse attention

Head 4 (Syntactic - prepositional phrase):
  "mat"  attends strongly to "on"    (preposition governing "mat")
  "on"   attends strongly to "sat"   (verb modified by "on the mat")
  Other tokens have diffuse attention

Each head captures a different aspect of the sentence’s structure. When the outputs of all four heads are concatenated and projected through W_O, the resulting vector for each token contains information about its local context (Head 1), its role in subject-verb relationships (Head 2), its determiner (Head 3), and its role in prepositional phrases (Head 4). This multi-faceted representation is far richer than what any single head could produce.
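One simple way to quantify whether a head behaves like Head 1 is to average the weight each token places on the token immediately before it. The helper below (`previous_token_score`, a hypothetical name) sketches that idea and applies it to an idealized previous-token pattern and to a uniform causal pattern for the 6-token sentence above.

```python
import numpy as np

def previous_token_score(weights):
    """Average attention weight that each token (after the first) puts on the token
    immediately before it. weights: [seq_len, seq_len], rows sum to 1."""
    seq_len = weights.shape[0]
    return float(np.mean([weights[i, i - 1] for i in range(1, seq_len)]))

# Idealized "previous token" head for "The cat sat on the mat" (6 tokens):
# token 0 can only see itself; every other token puts all its weight on i-1.
ideal = np.zeros((6, 6))
ideal[0, 0] = 1.0
for i in range(1, 6):
    ideal[i, i - 1] = 1.0

# Maximally diffuse causal head: uniform weight over all visible tokens.
uniform = np.tril(np.ones((6, 6)))
uniform /= uniform.sum(axis=1, keepdims=True)

print(previous_token_score(ideal))    # 1.0
print(previous_token_score(uniform))
```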


The Dimension Split: How Heads Divide the Work

A critical detail of multi-head attention is how the hidden dimension is divided among heads. This division determines the capacity of each head and has important implications for model design.

The Math

In the original Transformer:

  • Total hidden dimension (d_model): 512
  • Number of heads (h): 8
  • Dimension per head (d_k = d_v): 512 / 8 = 64

Each head works with 64-dimensional query, key, and value vectors. The total computation across all 8 heads is 8 * 64 = 512 dimensions, which equals d_model. This means multi-head attention with h heads and d_k = d_model/h dimensions per head uses the same total number of parameters and roughly the same computation as single-head attention with d_model dimensions.

In LLaMA 4 Maverick:

  • hidden_size: 5,120
  • Number of query heads: 40
  • head_dim: 128
  • Total query dimension: 40 * 128 = 5,120 = hidden_size

Source: LLaMA 4 Maverick from Ollama model metadata and HuggingFace Transformers (v5.3.0): hidden_size = 5,120, num_attention_heads = 40, head_dim = 128, num_key_value_heads = 8, num_hidden_layers = 48.

Why Smaller Heads Work

You might wonder: why would 8 heads with 64 dimensions each be better than 1 head with 512 dimensions? After all, each individual head has less capacity (fewer dimensions to work with).

The answer is that different types of relationships require different types of comparisons. A head that learns to detect subject-verb agreement needs to compare different features of the query and key vectors than a head that learns to detect coreference. By giving each head its own projection matrices (W_Q^i, W_K^i, W_V^i), each head can learn to project the input into a subspace that is optimized for the specific type of relationship it captures.

A single large head would have to use the same projection to capture all types of relationships simultaneously, which is a harder optimization problem. Multiple smaller heads can each specialize, and the output projection W_O learns to combine their specialized outputs into a unified representation.

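Vaswani et al. (2017) tested this empirically, varying the number of heads while keeping the total attention dimension fixed at 512:

  • 1 head with d_k = 512 (single-head attention)
  • 4 heads with d_k = 128
  • 8 heads with d_k = 64 (the default configuration)
  • 16 heads with d_k = 32
  • 32 heads with d_k = 16

Single-head attention was 0.9 BLEU worse than the best setting, and quality also dropped off with too many heads (32). The 8-head and 16-head configurations performed about equally well. A plausible explanation is that a single head cannot capture diverse relationships, while 32 heads leave each head only 16 dimensions to work with. The sweet spot was in the middle.

Source: Vaswani et al., 2017, Table 3, rows (A): varying the number of attention heads, measured on the English-to-German development set (newstest2013). h = 8 (d_k = 64) and h = 16 (d_k = 32) both reached 25.8 BLEU; h = 1 reached 24.9.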

Real Model Configurations

Here are the head configurations of several real models:

Model                  hidden_size   Query Heads   KV Heads   head_dim   Year
Original Transformer   512           8             8          64         2017
GPT-2 (small)          768           12            12         64         2019
LLaMA 2 7B             4,096         32            32         128        2023
LLaMA 2 70B            8,192         64            8          128        2023
Mistral 7B             4,096         32            8          128        2023
LLaMA 4 Maverick       5,120         40            8          128        2025
DeepSeek-V3            7,168         128           MLA        128        2024

Sources: Original Transformer from Vaswani et al. (2017); GPT-2 from OpenAI (2019); LLaMA 2 from Meta (2023), 7B uses MHA (32 KV heads = 32 query heads), 70B uses GQA with 8 KV heads; Mistral 7B from Mistral AI (September 27, 2023), 32 query heads, 8 KV heads; LLaMA 4 Maverick from Ollama model metadata and HuggingFace Transformers (v5.3.0), released April 5, 2025; DeepSeek-V3 from technical report (arXiv:2412.19437, December 2024), 128 heads, head_dim = 128, uses MLA instead of standard KV heads.
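The table's dimensions can be checked with a few lines of arithmetic. The sketch below transcribes the rows and prints, for each model, whether query heads * head_dim equals hidden_size and how many query heads share each KV head. The identity holds for every row except DeepSeek-V3, where 128 * 128 = 16,384 exceeds the hidden size of 7,168; matching the hidden size is a common convention rather than a requirement, since W_O maps the concatenated heads back to hidden_size either way.

```python
# (hidden_size, query_heads, kv_heads, head_dim) transcribed from the table;
# DeepSeek-V3 uses MLA, so its KV-head count is recorded as None.
configs = {
    "Original Transformer": (512, 8, 8, 64),
    "GPT-2 (small)": (768, 12, 12, 64),
    "LLaMA 2 7B": (4096, 32, 32, 128),
    "LLaMA 2 70B": (8192, 64, 8, 128),
    "Mistral 7B": (4096, 32, 8, 128),
    "LLaMA 4 Maverick": (5120, 40, 8, 128),
    "DeepSeek-V3": (7168, 128, None, 128),
}

for name, (hidden, q_heads, kv_heads, head_dim) in configs.items():
    q_dim = q_heads * head_dim
    group = q_heads // kv_heads if kv_heads else "n/a (MLA)"
    print(f"{name:22s} q_heads*head_dim={q_dim:6d}  "
          f"matches hidden_size={q_dim == hidden}  queries per KV head={group}")
```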

Notice two trends:

  1. head_dim = 128 has become the common choice in recent models. Earlier models used 64, but every 2023-2025 model in the table uses 128, which gives each head more capacity to learn complex patterns.
  2. The number of KV heads is often much smaller than the number of query heads. LLaMA 2 70B has 64 query heads but only 8 KV heads. LLaMA 4 Maverick has 40 query heads but only 8 KV heads. This is Grouped Query Attention (GQA), which we will explain in detail later in this chapter.

Hands-On: Implementing Multi-Head Attention

Let’s implement multi-head attention from scratch. This code follows the exact same structure as the single-head attention from Chapter 7, but runs multiple heads in parallel:

import numpy as np

def scaled_dot_product_attention(Q, K, V, mask=None):
    """Single-head scaled dot-product attention (from Chapter 7)."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)
    if mask is not None:
        scores = np.where(mask, -1e9, scores)
    exp_scores = np.exp(scores - np.max(scores, axis=-1, keepdims=True))
    weights = exp_scores / np.sum(exp_scores, axis=-1, keepdims=True)
    return weights @ V, weights

def multi_head_attention(X, W_Q, W_K, W_V, W_O, n_heads, mask=None):
    """Multi-head attention.

    X: input matrix, shape [seq_len, d_model]
    W_Q: query projection, shape [d_model, d_model]
    W_K: key projection, shape [d_model, d_model]
    W_V: value projection, shape [d_model, d_model]
    W_O: output projection, shape [d_model, d_model]
    n_heads: number of attention heads
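    mask: optional boolean mask, shape [seq_len, seq_len]; True marks
          positions a query may NOT attend to (e.g., future tokens)
    """
    seq_len, d_model = X.shape
    assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
    d_k = d_model // n_heads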

    # Project to Q, K, V
    Q = X @ W_Q  # [seq_len, d_model]
    K = X @ W_K
    V = X @ W_V

    # Reshape into heads: [seq_len, n_heads, d_k]
    Q = Q.reshape(seq_len, n_heads, d_k)
    K = K.reshape(seq_len, n_heads, d_k)
    V = V.reshape(seq_len, n_heads, d_k)

    # Run attention for each head independently
    head_outputs = []
    all_weights = []
    for h in range(n_heads):
        out, w = scaled_dot_product_attention(
            Q[:, h, :], K[:, h, :], V[:, h, :], mask
        )
        head_outputs.append(out)
        all_weights.append(w)

    # Concatenate heads: [seq_len, d_model]
    concat = np.concatenate(head_outputs, axis=-1)

    # Output projection
    output = concat @ W_O

    return output, all_weights


# Example: 6-token sequence, 4 heads, d_model=32
np.random.seed(42)
seq_len = 6
d_model = 32
n_heads = 4
d_k = d_model // n_heads  # 8 per head

X = np.random.randn(seq_len, d_model) * 0.5
W_Q = np.random.randn(d_model, d_model) * 0.2
W_K = np.random.randn(d_model, d_model) * 0.2
W_V = np.random.randn(d_model, d_model) * 0.2
W_O = np.random.randn(d_model, d_model) * 0.2

mask = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)

output, weights = multi_head_attention(X, W_Q, W_K, W_V, W_O, n_heads, mask)

print(f"Input shape:  {X.shape}")
print(f"Output shape: {output.shape}")
print(f"Number of heads: {n_heads}")
print(f"Dimensions per head: {d_k}")
print()

for h in range(n_heads):
    print(f"Head {h} attention weights:")
    for i in range(seq_len):
        row = "  Token {}: [".format(i)
        row += ", ".join("{:.3f}".format(w) for w in weights[h][i])
        row += "]"
        print(row)
    print()

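When you run this, each head produces a different attention pattern, because each head reads its own slice of the W_Q, W_K, and W_V projections. The weights here are random rather than trained, so the patterns carry no linguistic meaning; in every head, the first row is [1.000, 0.000, ...] because the causal mask lets token 0 attend only to itself. The output has the same shape as the input ([seq_len, d_model]), because the concatenation and output projection map back to the original dimension.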


The KV Cache Problem: Why Multi-Head Attention Is Expensive

Multi-head attention has a significant memory cost during inference (text generation), and understanding this cost is essential for understanding why modern models use Grouped Query Attention instead of the original multi-head attention.

The Problem

During text generation, the model produces one token at a time. For each new token, it needs to compute attention against all previous tokens. To avoid recomputing the key and value vectors for every previous token at every step, the model stores them in what is called the KV cache. (We will cover the KV cache in full detail in Chapter 18; for now, the key point is that it stores one key vector and one value vector per token, per layer, per head.)

The KV cache stores, for every layer and every head:

  • One key vector per token (dimension: head_dim)
  • One value vector per token (dimension: head_dim)

The total KV cache size for standard multi-head attention (MHA) is:

KV cache = 2 * num_layers * num_heads * seq_len * head_dim * bytes_per_element

Let’s compute this for a single sequence in a hypothetical model with LLaMA 4 Maverick’s dimensions but using full MHA, so that num_kv_heads = num_query_heads = 40 (instead of Maverick’s actual 8):

  • 2 (K and V) * 48 layers * 40 heads * 128 head_dim = 491,520 values per token
  • In float16 (2 bytes per value): 491,520 * 2 = 983,040 bytes per token, roughly 0.98 MB per token
  • For a 100,000-token sequence: about 98 GB just for the KV cache
  • For a 1,000,000-token sequence: about 983 GB, far exceeding the memory of any single GPU
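The KV cache arithmetic is easy to script. A minimal sketch, assuming 2-byte (float16/bfloat16) storage, a single sequence, and that every layer caches every token; `kv_cache_bytes` is a hypothetical helper.

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len, bytes_per_element=2, batch=1):
    """Keys + values for every layer, KV head, and cached token."""
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * bytes_per_element * batch

per_token_mha = kv_cache_bytes(48, 40, 128, seq_len=1)   # Maverick-sized model, full MHA
per_token_gqa = kv_cache_bytes(48, 8, 128, seq_len=1)    # Maverick's actual 8 KV heads

print(per_token_mha)                                  # 983040 bytes
print(kv_cache_bytes(48, 40, 128, 100_000) / 1e9)     # about 98.3 GB
print(kv_cache_bytes(48, 40, 128, 1_000_000) / 1e9)   # about 983 GB
print(per_token_mha / per_token_gqa)                  # 5.0
```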

This is the fundamental problem: with standard multi-head attention, the KV cache grows linearly with both the number of heads and the sequence length. For long-context models with many heads, the KV cache becomes the dominant memory cost and can even exceed the memory required for the model weights themselves.

The Solution: Share KV Heads

The key insight is that the query heads need to be diverse (each head should learn a different attention pattern), but the key and value heads do not need to be as diverse. Multiple query heads can share the same key and value vectors without significant loss in model quality. This is the core idea behind Multi-Query Attention and Grouped Query Attention.


Multi-Query Attention (MQA): The Extreme Approach

Multi-Query Attention (MQA) was proposed by Noam Shazeer in 2019 in the paper “Fast Transformer Decoding: One Write-Head is All You Need.” The idea is radical: instead of giving each query head its own key and value vectors, use a single shared key head and a single shared value head across all query heads.

Source: Shazeer, “Fast Transformer Decoding: One Write-Head is All You Need,” arXiv:1911.02150, November 2019.

In MQA:

  • Each of the h query heads still has its own W_Q^i projection, producing h different query vectors
  • There is only one W_K and one W_V projection, producing a single key vector and a single value vector per token
  • All h query heads compute attention using the same key and value vectors

This dramatically reduces the KV cache:

MHA KV cache per token: 2 * num_layers * num_heads * head_dim
MQA KV cache per token: 2 * num_layers * 1 * head_dim

For a model with 40 heads, MQA reduces the KV cache by a factor of 40. This is an enormous memory saving that directly translates to faster inference (less memory bandwidth consumed reading the KV cache) and the ability to handle longer sequences.

The Tradeoff

MQA’s weakness is that sharing a single key-value pair across all query heads limits the model’s representational capacity. Different query heads are forced to attend based on the same key vectors, which means they cannot learn completely independent notions of “what is relevant.” Shazeer (2019) found that MQA produced only a small quality degradation compared to full MHA, but Ainslie et al. (2023) pointed out that MQA can still degrade quality, and that because larger models tend to have more heads, collapsing them to a single KV head is a more aggressive cut in capacity for those models.

Source: Ainslie et al., “GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints,” arXiv:2305.13245, May 2023. Published at EMNLP 2023.


Grouped Query Attention (GQA): The Best of Both Worlds

Grouped Query Attention (GQA) was proposed by Ainslie et al. (2023) as a middle ground between full Multi-Head Attention (MHA) and Multi-Query Attention (MQA). Instead of giving every query head its own KV pair (MHA) or sharing a single KV pair across all query heads (MQA), GQA groups the query heads and assigns one KV pair per group.

Source: Ainslie et al., “GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints,” arXiv:2305.13245, May 2023. Published at EMNLP 2023.

How GQA Works

In GQA:

  • There are h query heads (e.g., 40 in LLaMA 4 Maverick)
  • There are g KV heads, where g < h (e.g., 8 in LLaMA 4 Maverick)
  • Each KV head is shared by h/g query heads (e.g., 40/8 = 5 query heads per KV group)

The query heads within each group compute their own independent queries (using their own W_Q^i projections), but they all use the same key and value vectors (from the shared W_K^g and W_V^g projections for their group).

Here is a visual representation for LLaMA 4 Maverick (40 query heads, 8 KV heads):

KV Head 0:  shared by Query Heads  0,  1,  2,  3,  4
KV Head 1:  shared by Query Heads  5,  6,  7,  8,  9
KV Head 2:  shared by Query Heads 10, 11, 12, 13, 14
KV Head 3:  shared by Query Heads 15, 16, 17, 18, 19
KV Head 4:  shared by Query Heads 20, 21, 22, 23, 24
KV Head 5:  shared by Query Heads 25, 26, 27, 28, 29
KV Head 6:  shared by Query Heads 30, 31, 32, 33, 34
KV Head 7:  shared by Query Heads 35, 36, 37, 38, 39

Each group of 5 query heads shares one set of key and value vectors. The 5 query heads within a group can still produce different attention patterns (because they have different W_Q projections), but they are all comparing against the same keys and selecting from the same values.
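The grouping in the diagram is a simple integer division: with h query heads and g KV heads, query head q reads KV head q // (h / g). A sketch, assuming contiguous groups as drawn above; `kv_head_for` is a hypothetical helper.

```python
def kv_head_for(query_head, n_query_heads, n_kv_heads):
    """Index of the KV head that a query head reads from, assuming contiguous groups."""
    group_size = n_query_heads // n_kv_heads
    return query_head // group_size

# Reproduce the 40-query / 8-KV layout from the diagram
for kv in range(8):
    members = [q for q in range(40) if kv_head_for(q, 40, 8) == kv]
    print(f"KV Head {kv}: shared by Query Heads {members}")
```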

GQA as a Spectrum

GQA is actually a generalization that includes both MHA and MQA as special cases:

  • MHA (Multi-Head Attention): g = h. Every query head has its own KV head. This is the original Transformer design.
  • GQA (Grouped Query Attention): 1 < g < h. Query heads are grouped, with each group sharing one KV head.
  • MQA (Multi-Query Attention): g = 1. All query heads share a single KV head.

MHA:  Q Q Q Q Q Q Q Q     (8 query heads)
      K K K K K K K K     (8 KV heads, one per query head)

GQA:  Q Q Q Q Q Q Q Q     (8 query heads)
      K . . . K . . .     (2 KV heads, each shared by a group of 4)

MQA:  Q Q Q Q Q Q Q Q     (8 query heads)
      K . . . . . . .     (1 KV head, shared by all)
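Since MHA and MQA are just the endpoints of GQA, one function covers all three. The sketch below (NumPy, one sequence, causal mask, no RoPE or batching; `gqa_attention` is a hypothetical name) projects only g key and value heads, which is what would be cached, and then repeats each KV head for its group of query heads at attention time.

```python
import numpy as np

def gqa_attention(X, W_Q, W_K, W_V, W_O, n_heads, n_kv_heads, head_dim):
    """Causal grouped-query attention for one sequence.
    X: [seq_len, d_model]; W_Q: [d_model, n_heads*head_dim];
    W_K, W_V: [d_model, n_kv_heads*head_dim]; W_O: [n_heads*head_dim, d_model]."""
    assert n_heads % n_kv_heads == 0
    seq_len = X.shape[0]
    group = n_heads // n_kv_heads
    Q = (X @ W_Q).reshape(seq_len, n_heads, head_dim).transpose(1, 0, 2)     # [h, T, d]
    K = (X @ W_K).reshape(seq_len, n_kv_heads, head_dim).transpose(1, 0, 2)  # [g, T, d]
    V = (X @ W_V).reshape(seq_len, n_kv_heads, head_dim).transpose(1, 0, 2)
    # Only g KV heads exist; each is reused by `group` consecutive query heads
    K = np.repeat(K, group, axis=0)                                          # [h, T, d]
    V = np.repeat(V, group, axis=0)
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(head_dim)                    # [h, T, T]
    future = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    heads = weights @ V                                                      # [h, T, d]
    concat = heads.transpose(1, 0, 2).reshape(seq_len, n_heads * head_dim)
    return concat @ W_O

rng = np.random.default_rng(4)
T, d_model, h, d = 5, 16, 8, 2
X = rng.standard_normal((T, d_model))
W_Q = rng.standard_normal((d_model, h * d))
W_O = rng.standard_normal((h * d, d_model))
for g in (8, 2, 1):   # MHA, GQA, MQA
    W_K = rng.standard_normal((d_model, g * d))
    W_V = rng.standard_normal((d_model, g * d))
    out = gqa_attention(X, W_Q, W_K, W_V, W_O, h, g, d)
    print(g, out.shape, "cached values per token per layer:", 2 * g * d)
```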

KV Cache Savings

The KV cache shrinks in proportion to g/h, the ratio of KV heads to query heads. Each row below is relative to full MHA with the same number of query heads:

Attention Type   KV Heads   KV Cache Size (relative to MHA)   Example Model
MHA (full)       40         1.0x (baseline)                   Hypothetical Maverick with full MHA
GQA (8 groups)   8          0.2x (5x reduction)               LLaMA 4 Maverick (actual, 40 query heads)
GQA (8 groups)   8          0.125x (8x reduction)             LLaMA 2 70B (64 query heads)
MQA (1 group)    1          0.025x (40x reduction)            (theoretical, 40 query heads)

For LLaMA 4 Maverick with 8 KV heads instead of 40:

  • KV cache per token: 2 * 48 layers * 8 heads * 128 dim * 2 bytes = 196,608 bytes (192 KiB), roughly 0.2 MB per token
  • For 100,000 tokens: about 19.7 GB (compared to about 98 GB with full MHA)
  • For 1,000,000 tokens: about 197 GB (compared to about 983 GB with full MHA)

This 5x reduction in KV cache is a major part of what makes million-token context windows practical. Without GQA, the KV cache alone for a single 1M-token sequence would require nearly a terabyte of memory, assuming every layer caches every token.

Why GQA Works

The reason GQA works well despite sharing KV heads is that the key and value vectors primarily encode “what information is available at this position,” while the query vectors encode “what information am I looking for.” Different query heads within the same group can still learn to look for different things, even though they are searching through the same set of keys and values.

An analogy: imagine a library where the books (values) are organized on shelves with labels (keys). Multiple researchers (query heads) can search the same library using different search criteria. One researcher might be looking for books about syntax, another for books about semantics. They are searching the same shelves with the same labels, but their different search criteria (queries) lead them to select different books. The library does not need to be duplicated for each researcher.

Ainslie et al. (2023) showed that GQA with 8 KV heads achieved quality comparable to full MHA on a range of benchmarks, while providing significant inference speedups. They also demonstrated that existing MHA models could be “uptrained” to use GQA by mean-pooling the existing KV head weights within each group and then fine-tuning for a small fraction of the original training steps.

Source: Ainslie et al., 2023. GQA achieves quality close to MHA while matching the speed of MQA. Uptraining from MHA to GQA requires only a fraction (alpha) of the original pre-training steps.
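The mean-pooling step of uptraining can be written in a few lines. This sketch converts an MHA key (or value) projection into a GQA projection by averaging the per-head column blocks within each contiguous group; the continued training that follows in Ainslie et al.'s recipe is not shown, and `mean_pool_kv_heads` is a hypothetical name.

```python
import numpy as np

def mean_pool_kv_heads(W_kv, n_heads, n_kv_heads, head_dim):
    """Convert an MHA key (or value) projection [d_model, n_heads*head_dim] into a GQA
    projection [d_model, n_kv_heads*head_dim] by averaging the heads in each group."""
    d_model = W_kv.shape[0]
    group = n_heads // n_kv_heads
    per_head = W_kv.reshape(d_model, n_heads, head_dim)               # split columns by head
    grouped = per_head.reshape(d_model, n_kv_heads, group, head_dim)  # contiguous groups
    return grouped.mean(axis=2).reshape(d_model, n_kv_heads * head_dim)

rng = np.random.default_rng(5)
W_K_mha = rng.standard_normal((16, 8 * 2))   # toy: d_model=16, 8 heads, head_dim=2
W_K_gqa = mean_pool_kv_heads(W_K_mha, n_heads=8, n_kv_heads=2, head_dim=2)
print(W_K_gqa.shape)                         # (16, 4)
```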

Models That Use GQA

GQA has become the dominant attention mechanism in modern LLMs:

Model              Query Heads   KV Heads   Group Size   Year
LLaMA 2 70B        64            8          8            2023
Mistral 7B         32            8          4            2023
LLaMA 3 8B         32            8          4            2024
LLaMA 3 70B        64            8          8            2024
LLaMA 4 Maverick   40            8          5            2025
LLaMA 4 Scout      40            8          5            2025

Sources: LLaMA 2 70B from Meta (2023), 64 query heads, 8 KV heads; Mistral 7B from Mistral AI (September 2023), 32 query heads, 8 KV heads; LLaMA 3 from Meta (April 2024); LLaMA 4 from Ollama model metadata and HuggingFace Transformers (v5.3.0), released April 5, 2025.

The pattern is clear: 8 KV heads has become a common choice across model sizes, with the number of query heads varying based on the model’s hidden dimension. One interpretation is that 8 KV heads provide enough representational capacity for the key-value space, while the query space benefits from more heads to capture diverse attention patterns.


Multi-Head Latent Attention (MLA): DeepSeek’s Innovation

While GQA reduces the KV cache by sharing KV heads across groups of query heads, DeepSeek took a fundamentally different approach with Multi-Head Latent Attention (MLA), introduced in DeepSeek-V2 (2024) and used in DeepSeek-V3 (December 2024).

Source: DeepSeek-V3 Technical Report, arXiv:2412.19437, December 2024, Section 2.1.1.

The Core Idea: Low-Rank Compression

Instead of storing separate key and value vectors for each head, MLA compresses all key and value information into a single low-dimensional latent vector per token. During attention computation, this latent vector is “decompressed” back into full key and value vectors using learned up-projection matrices.

The process works like this:

  1. For each token, compute a compressed latent vector c_KV by projecting the input through a down-projection matrix:

    c_KV = W_down * h_t    shape: [d_c]

    Where d_c is the KV compression dimension (512 in DeepSeek-V3), much smaller than the full key dimension (128 heads * 128 dim = 16,384, plus another 16,384 for the values).

  2. During attention, decompress c_KV back into full keys and values using up-projection matrices:

    K = W_up_K * c_KV    shape: [num_heads * head_dim]
    V = W_up_V * c_KV    shape: [num_heads * head_dim]

  3. MLA also compresses the queries. The input is projected down to a query latent vector of dimension d_c’ (1,536 in DeepSeek-V3), then projected back up to the full query dimension. This reduces activation memory during training, though it does not affect the KV cache size during inference.

  4. For positional information, MLA uses a separate “decoupled” key component that carries RoPE. This decoupled key has a small dimension (d_h^R = 64 in DeepSeek-V3), is computed once per token and shared by all heads, and is concatenated with each head’s content-based key from step 2. A matching decoupled query component, computed per head, also carries RoPE and is concatenated with the content-based query.
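The compress-then-decompress path can be sketched with toy sizes. This NumPy example uses row vectors (so c_KV = h_t @ W_down rather than W_down * h_t), uses made-up dimensions, and omits the decoupled RoPE key, the query compression, and the attention step itself; it only shows which tensor would be cached and how per-head keys and values are rebuilt from it.

```python
import numpy as np

rng = np.random.default_rng(6)
# Toy sizes standing in for DeepSeek-V3's (hidden 7,168, d_c 512, 128 heads x 128 dim)
seq_len, hidden, d_c, n_heads, head_dim = 4, 64, 8, 4, 16

H = rng.standard_normal((seq_len, hidden))                      # token hidden states h_t (rows)
W_down = rng.standard_normal((hidden, d_c)) / np.sqrt(hidden)    # shared down-projection
W_up_K = rng.standard_normal((d_c, n_heads * head_dim)) / np.sqrt(d_c)
W_up_V = rng.standard_normal((d_c, n_heads * head_dim)) / np.sqrt(d_c)

c_KV = H @ W_down                                          # [seq_len, d_c]: what would be cached
K = (c_KV @ W_up_K).reshape(seq_len, n_heads, head_dim)    # per-head keys, rebuilt on the fly
V = (c_KV @ W_up_V).reshape(seq_len, n_heads, head_dim)    # per-head values, rebuilt on the fly

print("cached per token:", c_KV.shape[1], "values")                  # 8
print("full K and V per token:", 2 * n_heads * head_dim, "values")  # 128
```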

KV Cache in MLA

The critical advantage: during inference, the model only needs to cache the compressed latent vector c_KV (512 dimensions) and the small decoupled RoPE key (64 dimensions) per token, instead of the full key and value vectors for all heads. This is a dramatic compression:

  • Standard MHA (128 heads, head_dim 128): cache 2 * 128 * 128 = 32,768 values per token per layer
  • GQA (8 KV heads, head_dim 128): cache 2 * 8 * 128 = 2,048 values per token per layer
  • MLA (d_c = 512, d_rope = 64): cache 512 + 64 = 576 values per token per layer

MLA achieves roughly a 57x compression compared to standard MHA and a 3.6x compression compared to GQA with 8 KV heads, while, according to DeepSeek’s reports, maintaining quality comparable to full MHA. This extreme efficiency is one reason DeepSeek-V3 can handle its 128K-token context window effectively.

Source: DeepSeek-V3 Technical Report, arXiv:2412.19437, December 2024, Section 2.1.1. d_c = 512 (KV compression dimension), d_c’ = 1,536 (query compression dimension), d_h^R = 64 (decoupled RoPE key dimension), n_h = 128 attention heads, d_h = 128 per-head dimension. DeepSeek-V3 has 61 layers, hidden dimension 7,168, and 128K-token context window.
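The cache sizes and compression ratios above can be reproduced directly. The whole-model figures at the end assume 2 bytes per cached value across DeepSeek-V3's 61 layers; that precision is an assumption for illustration, not a statement about any particular deployment.

```python
# Cached values per token per layer (counts of numbers, not bytes), DeepSeek-V3 sizes
n_h, d_h = 128, 128          # attention heads, per-head dimension
d_c, d_rope = 512, 64        # KV latent dimension, decoupled RoPE key dimension

mha = 2 * n_h * d_h          # 32,768
gqa8 = 2 * 8 * d_h           # 2,048 (hypothetical GQA with 8 KV heads, same head_dim)
mla = d_c + d_rope           # 576

print(mha, gqa8, mla)
print(round(mha / mla, 1), round(gqa8 / mla, 1))   # 56.9 3.6

# Whole model, 61 layers, assuming 2 bytes per cached value
layers = 61
print(mla * layers * 2, "bytes per token with MLA")       # 70272
print(mha * layers * 2, "bytes per token with full MHA")
```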

The Tradeoff

MLA’s compression introduces additional matrix multiplications (the up-projections) during attention computation, which adds some compute cost. However, since inference is typically memory-bandwidth-bound (limited by how fast data can be read from GPU memory) rather than compute-bound, the reduced memory footprint can more than compensate for the extra computation, making inference faster overall despite the additional matrix operations.


Comparing MHA, GQA, MQA, and MLA

Let’s put all four attention variants side by side:

Property                       MHA                 GQA                             MQA                  MLA
Query heads                    h                   h                               h                    h
KV heads                       h                   g (where g < h)                 1                    h (via decompression)
KV cache per token per layer   2 * h * d_k         2 * g * d_k                     2 * d_k              d_c + d_rope
Quality                        Baseline            Near MHA                        Slight degradation   Near MHA
Inference speed                Slowest             Fast                            Fastest              Fast
Used by                        LLaMA 2 7B, GPT-2   LLaMA 2 70B, LLaMA 4, Mistral   PaLM, Gemini 1.0     DeepSeek-V2, DeepSeek-V3

Sources: MQA used by PaLM (Google, 2022) and early Gemini models. GQA used by LLaMA 2 70B, LLaMA 3, LLaMA 4, Mistral. MLA used by DeepSeek-V2 and DeepSeek-V3.

The trend in the field is clear: full MHA is being replaced by more memory-efficient variants. GQA has become the most widely adopted approach due to its simplicity and strong quality-efficiency tradeoff. MLA offers even greater compression but requires more architectural complexity. The choice between GQA and MLA depends on the specific deployment constraints and engineering preferences of each team.


Worked Example: Multi-Head Attention with GQA

Let’s trace through a complete multi-head attention computation using Grouped Query Attention. The layout follows LLaMA 4 Maverick’s (Maverick shares each of its 8 KV heads among 5 of its 40 query heads), but it is scaled down drastically for readability.

Setup

We will use:

  • 4 query heads, 2 KV heads (group size = 2, so query heads 0-1 share KV head 0, query heads 2-3 share KV head 1)
  • head_dim = 2 (tiny, for illustration; real models use 128)
  • 3-token sequence: “She wrote code”

Input matrix X (the token embeddings; RoPE is left out to keep the numbers simple), shape [3 x 8], so hidden_size = 8:

X = [[ 0.5,  0.3, -0.1,  0.7,  0.2, -0.4,  0.6,  0.1],   # "She"
     [-0.2,  0.8,  0.4, -0.3,  0.5,  0.1, -0.1,  0.9],   # "wrote"
     [ 0.3, -0.5,  0.6,  0.2, -0.3,  0.7,  0.4, -0.2]]   # "code"

Step 1: Compute Queries (4 heads, each with head_dim=2)

W_Q has shape [8 x 8] (hidden_size x (4 heads * 2 dim)). After projection and reshape:

Q_head0 = [[ 0.42,  0.15],   # "She"
            [ 0.31, -0.22],   # "wrote"
            [-0.18,  0.47]]   # "code"

Q_head1 = [[ 0.28,  0.33],
            [-0.14,  0.51],
            [ 0.39, -0.11]]

Q_head2 = [[-0.05,  0.44],
            [ 0.52,  0.18],
            [ 0.11,  0.36]]

Q_head3 = [[ 0.37, -0.29],
            [ 0.08,  0.41],
            [-0.23,  0.55]]

Step 2: Compute Keys and Values (2 KV heads, each with head_dim=2)

W_K and W_V each have shape [8 x 4] (hidden_size x (2 KV heads * 2 dim)). After projection and reshape:

K_kv0 = [[ 0.35,  0.21],   # "She"
          [-0.12,  0.48],   # "wrote"
          [ 0.27, -0.15]]   # "code"

K_kv1 = [[ 0.18, -0.33],
          [ 0.41,  0.26],
          [-0.09,  0.52]]

V_kv0 = [[ 0.55,  0.12],
          [-0.28,  0.67],
          [ 0.43, -0.31]]

V_kv1 = [[-0.15,  0.48],
          [ 0.62,  0.09],
          [ 0.21,  0.73]]

Step 3: Attention Computation

Query heads 0 and 1 both use K_kv0 and V_kv0. Query heads 2 and 3 both use K_kv1 and V_kv1.

For Head 0 (using KV head 0):

scores = Q_head0 @ K_kv0.T / sqrt(2)

"She" attending:
  to "She":   (0.42*0.35 + 0.15*0.21) / 1.414 = (0.1470 + 0.0315) / 1.414 = 0.126
  to "wrote": (0.42*(-0.12) + 0.15*0.48) / 1.414 = (-0.0504 + 0.0720) / 1.414 = 0.015
  to "code":  (0.42*0.27 + 0.15*(-0.15)) / 1.414 = (0.1134 - 0.0225) / 1.414 = 0.064

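After causal masking (each token can see only itself and earlier tokens), the scaled scores and softmax weights for all three query tokens are:

"She":   [ 0.126,   -inf,   -inf]  --> softmax --> [1.000, 0.000, 0.000]
"wrote": [ 0.044, -0.101,   -inf]  --> softmax --> [0.536, 0.464, 0.000]
"code":  [ 0.025,  0.175, -0.084]  --> softmax --> [0.327, 0.380, 0.293]

Multiplying these weights by V_kv0 gives Head 0’s output: [0.550, 0.120] for “She”, [0.165, 0.375] for “wrote”, and [0.200, 0.203] for “code”.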

For Head 1 (also using KV head 0, but different queries): The same K_kv0 and V_kv0 are used, but Q_head1 produces different scores, leading to different attention weights. This is the key insight of GQA: even though the keys and values are shared, different query projections produce different attention patterns.

For Heads 2 and 3 (using KV head 1): These heads use K_kv1 and V_kv1, which encode different aspects of the tokens. The attention patterns will be different from Heads 0 and 1, not only because the queries are different, but also because the keys and values are different.
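The hand calculation can be reproduced in NumPy. This sketch takes the Step 1 and Step 2 matrices exactly as given and maps query heads 0-1 to KV head 0 and heads 2-3 to KV head 1. It then applies the causal mask and prints every head’s attention weights. W_O is not given in the example, so the sketch stops at the concatenated [3 x 8] head outputs that Step 4 would feed into it.

```python
import numpy as np

# Per-head queries from Step 1 (rows: "She", "wrote", "code")
Q = [
    np.array([[0.42, 0.15], [0.31, -0.22], [-0.18, 0.47]]),   # head 0
    np.array([[0.28, 0.33], [-0.14, 0.51], [0.39, -0.11]]),   # head 1
    np.array([[-0.05, 0.44], [0.52, 0.18], [0.11, 0.36]]),    # head 2
    np.array([[0.37, -0.29], [0.08, 0.41], [-0.23, 0.55]]),   # head 3
]
# Shared keys and values from Step 2 (one pair per KV head)
K = [
    np.array([[0.35, 0.21], [-0.12, 0.48], [0.27, -0.15]]),   # KV head 0
    np.array([[0.18, -0.33], [0.41, 0.26], [-0.09, 0.52]]),   # KV head 1
]
V = [
    np.array([[0.55, 0.12], [-0.28, 0.67], [0.43, -0.31]]),   # KV head 0
    np.array([[-0.15, 0.48], [0.62, 0.09], [0.21, 0.73]]),    # KV head 1
]

group_size = 2                                          # 4 query heads / 2 KV heads
head_dim = 2
causal = np.triu(np.ones((3, 3), dtype=bool), k=1)      # True marks a future token

weights, outputs = [], []
for q_head in range(4):
    kv = q_head // group_size                           # heads 0-1 -> KV 0, heads 2-3 -> KV 1
    scores = Q[q_head] @ K[kv].T / np.sqrt(head_dim)    # scaled dot products
    scores = np.where(causal, -np.inf, scores)          # hide future tokens
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)                  # softmax over each row
    weights.append(w)
    outputs.append(w @ V[kv])                           # [3 x 2] output for this head

np.set_printoptions(precision=3, suppress=True)
for h, w in enumerate(weights):
    print(f"Head {h} (KV head {h // group_size}) weights:\n{w}")

concat = np.concatenate(outputs, axis=-1)               # [3 x 8], the input to W_O
print("Concatenated output shape:", concat.shape)
```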

Step 4: Concatenate and Project

Each head produces a [3 x 2] output. Concatenating all 4 heads gives [3 x 8], which is then multiplied by W_O [8 x 8] to produce the final [3 x 8] output.

The final output for each token is a rich combination of information gathered by all four heads, each attending to different aspects of the context through different query-key comparisons.


Hands-On: Implementing GQA

Let’s implement Grouped Query Attention and compare it to standard Multi-Head Attention:

import numpy as np

def attention(Q, K, V, mask=None):
    """Scaled dot-product attention."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)
    if mask is not None:
        scores = np.where(mask, -1e9, scores)
    exp_s = np.exp(scores - np.max(scores, axis=-1, keepdims=True))
    weights = exp_s / np.sum(exp_s, axis=-1, keepdims=True)
    return weights @ V, weights

def grouped_query_attention(X, W_Q, W_K, W_V, W_O,
                            n_q_heads, n_kv_heads, head_dim, mask=None):
    """Grouped Query Attention (GQA).

    n_q_heads: number of query heads
    n_kv_heads: number of key/value heads
    head_dim: dimension per head
    """
    seq_len = X.shape[0]
    group_size = n_q_heads // n_kv_heads

    Q_all = (X @ W_Q).reshape(seq_len, n_q_heads, head_dim)
    K_all = (X @ W_K).reshape(seq_len, n_kv_heads, head_dim)
    V_all = (X @ W_V).reshape(seq_len, n_kv_heads, head_dim)

    head_outputs = []
    for q_head in range(n_q_heads):
        kv_head = q_head // group_size  # which KV head this query uses
        out, _ = attention(
            Q_all[:, q_head, :],
            K_all[:, kv_head, :],
            V_all[:, kv_head, :],
            mask
        )
        head_outputs.append(out)

    concat = np.concatenate(head_outputs, axis=-1)
    return concat @ W_O


# Compare MHA vs GQA
np.random.seed(42)
seq_len, d_model, head_dim = 8, 64, 16
n_q_heads = 4
mask = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
X = np.random.randn(seq_len, d_model) * 0.3

# MHA: 4 query heads, 4 KV heads
W_Q_mha = np.random.randn(d_model, n_q_heads * head_dim) * 0.1
W_K_mha = np.random.randn(d_model, n_q_heads * head_dim) * 0.1
W_V_mha = np.random.randn(d_model, n_q_heads * head_dim) * 0.1
W_O_mha = np.random.randn(n_q_heads * head_dim, d_model) * 0.1

out_mha = grouped_query_attention(
    X, W_Q_mha, W_K_mha, W_V_mha, W_O_mha,
    n_q_heads=4, n_kv_heads=4, head_dim=head_dim, mask=mask
)

# GQA: 4 query heads, 2 KV heads
n_kv_heads_gqa = 2
W_Q_gqa = np.random.randn(d_model, n_q_heads * head_dim) * 0.1
W_K_gqa = np.random.randn(d_model, n_kv_heads_gqa * head_dim) * 0.1
W_V_gqa = np.random.randn(d_model, n_kv_heads_gqa * head_dim) * 0.1
W_O_gqa = np.random.randn(n_q_heads * head_dim, d_model) * 0.1

out_gqa = grouped_query_attention(
    X, W_Q_gqa, W_K_gqa, W_V_gqa, W_O_gqa,
    n_q_heads=4, n_kv_heads=2, head_dim=head_dim, mask=mask
)

# MQA: 4 query heads, 1 KV head
W_K_mqa = np.random.randn(d_model, 1 * head_dim) * 0.1
W_V_mqa = np.random.randn(d_model, 1 * head_dim) * 0.1

out_mqa = grouped_query_attention(
    X, W_Q_gqa, W_K_mqa, W_V_mqa, W_O_gqa,
    n_q_heads=4, n_kv_heads=1, head_dim=head_dim, mask=mask
)

print("Output shapes (all should be [8, 64]):")
print(f"  MHA: {out_mha.shape}")
print(f"  GQA: {out_gqa.shape}")
print(f"  MQA: {out_mqa.shape}")

# Compare parameter counts
print("\nParameter counts:")
mha_kv_params = 2 * d_model * n_q_heads * head_dim
gqa_kv_params = 2 * d_model * n_kv_heads_gqa * head_dim
mqa_kv_params = 2 * d_model * 1 * head_dim
print(f"  MHA K+V params: {mha_kv_params:,}")
print(f"  GQA K+V params: {gqa_kv_params:,} ({gqa_kv_params/mha_kv_params:.0%} of MHA)")
print(f"  MQA K+V params: {mqa_kv_params:,} ({mqa_kv_params/mha_kv_params:.0%} of MHA)")

# KV cache comparison (per token, in values)
print("\nKV cache per token (number of values):")
print(f"  MHA: {2 * n_q_heads * head_dim}")
print(f"  GQA: {2 * n_kv_heads_gqa * head_dim}")
print(f"  MQA: {2 * 1 * head_dim}")

This code demonstrates that all three variants produce outputs of the same shape, but GQA and MQA use significantly fewer KV parameters and require less KV cache memory. The output values differ because the weight matrices are different (randomly initialized), but in a trained model, GQA achieves quality close to MHA despite the reduced KV capacity.
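A common alternative to the per-head loop is to repeat each KV head so that every query head gets a matching copy, then compute all heads in one batched operation. The sketch below does that with np.repeat; gqa_vectorized and tie_kv_weights are hypothetical names used only for this illustration. It also checks a useful way to think about GQA: the result is exactly what MHA would compute if the K and V weights of the query heads in each group were tied (copied). Given the same weights and causal mask, it should agree with the loop-based grouped_query_attention above up to floating-point rounding.

```python
import numpy as np

def gqa_vectorized(X, W_Q, W_K, W_V, W_O, n_q_heads, n_kv_heads, head_dim):
    """Causal GQA for all heads at once, by repeating each KV head (sketch)."""
    seq_len = X.shape[0]
    group_size = n_q_heads // n_kv_heads
    # Split projections into heads and put the head axis first: [heads, seq_len, head_dim]
    Q = (X @ W_Q).reshape(seq_len, n_q_heads, head_dim).transpose(1, 0, 2)
    K = (X @ W_K).reshape(seq_len, n_kv_heads, head_dim).transpose(1, 0, 2)
    V = (X @ W_V).reshape(seq_len, n_kv_heads, head_dim).transpose(1, 0, 2)
    # Repeat KV heads so query head i lines up with KV head i // group_size
    K = np.repeat(K, group_size, axis=0)
    V = np.repeat(V, group_size, axis=0)
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(head_dim)         # [heads, seq_len, seq_len]
    mask = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)  # True marks a future token
    scores = np.where(mask, -np.inf, scores)
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)
    heads = w @ V                                                 # [heads, seq_len, head_dim]
    concat = heads.transpose(1, 0, 2).reshape(seq_len, n_q_heads * head_dim)
    return concat @ W_O

def tie_kv_weights(W, n_kv_heads, n_q_heads, head_dim):
    """Hypothetical helper: copy each KV head's columns once per query head in its group."""
    d_model = W.shape[0]
    blocks = W.reshape(d_model, n_kv_heads, head_dim)
    tied = np.repeat(blocks, n_q_heads // n_kv_heads, axis=1)
    return tied.reshape(d_model, n_q_heads * head_dim)

rng = np.random.default_rng(0)
seq_len, d_model, head_dim, n_q, n_kv = 8, 64, 16, 4, 2
X = rng.standard_normal((seq_len, d_model)) * 0.3
W_Q = rng.standard_normal((d_model, n_q * head_dim)) * 0.1
W_K = rng.standard_normal((d_model, n_kv * head_dim)) * 0.1
W_V = rng.standard_normal((d_model, n_kv * head_dim)) * 0.1
W_O = rng.standard_normal((n_q * head_dim, d_model)) * 0.1

out_gqa = gqa_vectorized(X, W_Q, W_K, W_V, W_O, n_q, n_kv, head_dim)
# The same computation posed as MHA (one KV head per query head) with tied K/V weights
out_tied_mha = gqa_vectorized(X, W_Q,
                              tie_kv_weights(W_K, n_kv, n_q, head_dim),
                              tie_kv_weights(W_V, n_kv, n_q, head_dim),
                              W_O, n_q, n_q, head_dim)
print(out_gqa.shape, np.allclose(out_gqa, out_tied_mha))
```

Repeating the KV heads only at compute time keeps the cache small: what gets stored is still the n_kv_heads un-repeated keys and values.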


Visualizing What Different Heads Learn

Let’s build a visualization that shows how different attention heads in a multi-head attention layer produce different attention patterns for the same input:

import numpy as np
import matplotlib.pyplot as plt

def attention_weights(Q, K, mask=None):
    """Compute attention weights (no value multiplication)."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)
    if mask is not None:
        scores = np.where(mask, -1e9, scores)
    exp_s = np.exp(scores - np.max(scores, axis=-1, keepdims=True))
    return exp_s / np.sum(exp_s, axis=-1, keepdims=True)

# Simulate a sentence with 8 tokens
tokens = ["The", "cat", "that", "I", "saw", "was", "on", "mat"]
n = len(tokens)
d_model = 64
n_heads = 4
head_dim = d_model // n_heads  # 16

np.random.seed(12)
X = np.random.randn(n, d_model) * 0.3

# Causal mask: True above the diagonal marks future tokens to hide
mask = np.triu(np.ones((n, n), dtype=bool), k=1)

fig, axes = plt.subplots(1, n_heads, figsize=(20, 5))

for h in range(n_heads):
    # Separate (random, untrained) Q and K projections for each head
    W_Q_h = np.random.randn(d_model, head_dim) * 0.2
    W_K_h = np.random.randn(d_model, head_dim) * 0.2

    Q_h = X @ W_Q_h
    K_h = X @ W_K_h
    w = attention_weights(Q_h, K_h, mask)

    ax = axes[h]
    im = ax.imshow(w, cmap="Blues", vmin=0, vmax=1, aspect="auto")
    ax.set_xticks(range(n))
    ax.set_xticklabels(tokens, rotation=45, ha="right", fontsize=9)
    ax.set_yticks(range(n))
    ax.set_yticklabels(tokens, fontsize=9)
    ax.set_title(f"Head {h}", fontsize=12)
    ax.set_xlabel("Key (attending to)")
    if h == 0:
        ax.set_ylabel("Query (attending from)")

plt.suptitle("Multi-Head Attention: Each Head Learns Different Patterns",
             fontsize=14, y=1.02)
plt.tight_layout()
plt.savefig("multi_head_patterns.png", dpi=150, bbox_inches="tight")
plt.show()
print("Plot saved to multi_head_patterns.png")

When you run this, you will see four different heatmaps, one per head. Even with random (untrained) weights, the heads produce different attention patterns, although with such small random weights the differences are subtle. In a trained model, these differences would be much more pronounced and linguistically meaningful: one head might concentrate its weight just below the diagonal (each token attending to the previous token), another might show strong off-diagonal connections (attending to specific syntactic partners), and another might show broad, diffuse attention (gathering global context).
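One way to put a number on “focused” versus “diffuse” heads is the entropy of each query’s attention distribution: close to 0 when a head puts all its weight on one token, and higher when the weight is spread out. The sketch below uses two idealized, hand-built patterns rather than trained weights: a previous-token head and a head that spreads its weight evenly over all visible tokens. The attention_entropy function is a hypothetical helper; it could equally be applied to each head’s w matrix from the plotting loop above.

```python
import numpy as np

def attention_entropy(weights):
    """Mean entropy (in nats) of the attention distribution in each query row."""
    p = np.clip(weights, 1e-12, 1.0)                  # avoid log(0); zero weights add nothing
    return float(np.mean(np.sum(-weights * np.log(p), axis=-1)))

n = 8

# Idealized previous-token head: row i puts all its weight on token i - 1
# (row 0 has no previous token, so it attends to itself)
prev_token = np.zeros((n, n))
prev_token[0, 0] = 1.0
prev_token[np.arange(1, n), np.arange(n - 1)] = 1.0

# Idealized diffuse head: uniform weight over every visible (current or earlier) token
visible = np.tril(np.ones((n, n)))
diffuse = visible / visible.sum(axis=-1, keepdims=True)

print(f"previous-token head: {attention_entropy(prev_token):.3f} nats")
print(f"diffuse head:        {attention_entropy(diffuse):.3f} nats")
```

For the diffuse head, row i spreads its weight over i + 1 tokens and so has entropy log(i + 1); the sharply focused head scores 0.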


Multi-Head Attention in the Full Model Pipeline

Let’s update the model pipeline from Chapter 7 to show exactly where multi-head attention fits, using LLaMA 4 Maverick’s real configuration (note: Maverick is a Mixture-of-Experts model with 128 experts, which we will cover in Chapter 12; the attention mechanism described here applies to every layer regardless of whether the FFN uses dense or MoE architecture):

Step 1: Tokenization (Chapter 4)
  "She wrote elegant code"
  --> [2296, 6267, 26880, 2574]
  --> 4 tokens

Step 2: Embedding Lookup (Chapter 5)
  Each token ID --> row in embedding table (202,048 x 5,120)
  --> Matrix of shape [4 x 5,120]

Step 3: For each of the 48 Transformer layers:

  a) RMSNorm (Chapter 10)
     Normalize each token's vector

  b) Multi-Head Attention with GQA (THIS CHAPTER)
     - Project input to Q: [4 x 5,120] @ [5,120 x 5,120] = [4 x 5,120]
       Reshape to [4 x 40 x 128] (40 query heads, 128 dim each)
     - Project input to K: [4 x 5,120] @ [5,120 x 1,024] = [4 x 1,024]
       Reshape to [4 x 8 x 128] (8 KV heads, 128 dim each)
     - Project input to V: [4 x 5,120] @ [5,120 x 1,024] = [4 x 1,024]
       Reshape to [4 x 8 x 128] (8 KV heads, 128 dim each)
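     - Apply RoPE to Q and K (Chapter 6); LLaMA 4 interleaves layers that
       use no positional encoding, and those layers skip this step
     - For each of the 40 query heads:
         Determine which KV head to use (query_head // 5, since 40 / 8 = 5)
         Compute: softmax(Q_head @ K_kv^T / sqrt(128) + causal_mask) @ V_kv
         (causal_mask is 0 for visible tokens and -inf for future tokens)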
     - Concatenate 40 head outputs: [4 x 5,120]
     - Output projection: [4 x 5,120] @ [5,120 x 5,120] = [4 x 5,120]

  c) Residual Connection (Chapter 10)
     Add attention output to input

  d) RMSNorm (Chapter 10)

  e) Feed-Forward Network (Chapter 9)
     Process each token independently

  f) Residual Connection (Chapter 10)

Step 4: Final RMSNorm + Output Projection
  --> 202,048 logits, which softmax turns into probabilities for the next token

Parameter Count for Attention in LLaMA 4 Maverick

Let’s compute the exact parameter count for the attention mechanism:

  • W_Q: [5,120 x 5,120] = 26,214,400 parameters (40 query heads * 128 dim)
  • W_K: [5,120 x 1,024] = 5,242,880 parameters (8 KV heads * 128 dim)
  • W_V: [5,120 x 1,024] = 5,242,880 parameters (8 KV heads * 128 dim)
  • W_O: [5,120 x 5,120] = 26,214,400 parameters

Total per layer: 62,914,560 parameters (approximately 62.9 million)

Compare this to what full MHA would require:

  • W_K (full MHA): [5,120 x 5,120] = 26,214,400 parameters (40 KV heads * 128 dim)
  • W_V (full MHA): [5,120 x 5,120] = 26,214,400 parameters

Full MHA total per layer: 26.2M + 26.2M + 26.2M + 26.2M = 104.9 million parameters

GQA saves 104,857,600 - 62,914,560 = 41,943,040 parameters (about 41.9 million) per layer, a 40% reduction in attention parameters. Across all 48 layers, that is about 2.01 billion fewer parameters, which means less memory for weights and less computation in the K and V projections.
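The counts above are easy to reproduce. This sketch recomputes the per-layer attention parameters for LLaMA 4 Maverick’s sizes (hidden_size 5,120, 40 query heads, 8 KV heads, head_dim 128, 48 layers, as listed in the source note). For comparison, it also computes the KV cache each token adds across all layers. It assumes no bias terms in the attention projections.

```python
hidden_size, n_q_heads, n_kv_heads, head_dim, n_layers = 5_120, 40, 8, 128, 48

def attention_params(kv_heads):
    """Parameters in W_Q, W_K, W_V, W_O for one layer (no biases assumed)."""
    w_q = hidden_size * n_q_heads * head_dim
    w_k = w_v = hidden_size * kv_heads * head_dim
    w_o = n_q_heads * head_dim * hidden_size
    return w_q + w_k + w_v + w_o

gqa = attention_params(n_kv_heads)   # Maverick's actual layout
mha = attention_params(n_q_heads)    # hypothetical full MHA: one KV head per query head
saved = mha - gqa

print(f"GQA per layer:   {gqa:,}")
print(f"MHA per layer:   {mha:,}")
print(f"Saved per layer: {saved:,} ({saved / mha:.0%})")
print(f"Saved over {n_layers} layers: {saved * n_layers:,}")

# KV cache added per generated token, summed over all layers (number of values)
kv_cache_gqa = 2 * n_kv_heads * head_dim * n_layers
kv_cache_mha = 2 * n_q_heads * head_dim * n_layers
print(f"KV cache per token: GQA {kv_cache_gqa:,} vs MHA {kv_cache_mha:,} values "
      f"({kv_cache_mha // kv_cache_gqa}x smaller)")
```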

Source: LLaMA 4 Maverick from Ollama model metadata and HuggingFace Transformers (v5.3.0): vocab_size = 202,048, hidden_size = 5,120, num_attention_heads = 40, num_key_value_heads = 8, head_dim = 128, num_hidden_layers = 48.


The Concatenation and Projection Step: Why It Matters

The output projection W_O is often overlooked, but it plays a crucial role. After concatenating the outputs of all heads, each position in the concatenated vector corresponds to a specific head. Without the output projection, the model would have to use the raw concatenation, where the first 128 values come from head 0, the next 128 from head 1, and so on. The heads would be isolated from each other.

The output projection W_O mixes information across heads. Each element of the output is a weighted combination of values from all heads. This allows the model to combine the different types of information captured by different heads into a unified representation. For example, if head 0 captured subject-verb information and head 3 captured prepositional phrase structure, the output projection can combine these into a single vector that encodes both types of information.

Mathematically, the output projection is a linear transformation:

output = concat @ W_O

Where concat has shape [seq_len x (n_heads * head_dim)] and W_O has shape [(n_heads * head_dim) x hidden_size]. For LLaMA 4 Maverick, both dimensions are 5,120, so W_O is a square matrix with 26.2 million parameters. This is the same number of parameters as W_Q, making the output projection a significant component of the attention mechanism.
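The claim that W_O mixes heads can be made precise. Multiplying the concatenation by W_O gives the same result as passing each head’s output through its own horizontal slice of W_O (rows h * head_dim up to (h + 1) * head_dim) and summing the results. Every output element is therefore a sum of contributions from all heads. This small sketch checks that identity numerically with toy sizes and random values.

```python
import numpy as np

rng = np.random.default_rng(1)
seq_len, n_heads, head_dim, hidden_size = 3, 4, 2, 8   # toy sizes for illustration

head_outputs = [rng.standard_normal((seq_len, head_dim)) for _ in range(n_heads)]
W_O = rng.standard_normal((n_heads * head_dim, hidden_size))

# Standard form: concatenate the heads, then project
concat = np.concatenate(head_outputs, axis=-1)         # [seq_len, n_heads * head_dim]
output = concat @ W_O                                  # [seq_len, hidden_size]

# Equivalent form: project each head with its own slice of W_O, then sum
per_head = [head_outputs[h] @ W_O[h * head_dim:(h + 1) * head_dim] for h in range(n_heads)]
print(np.allclose(output, sum(per_head)))  # True
```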


Why 128 Dimensions Per Head?

Modern models have converged on 128 dimensions per head (head_dim = 128), up from 64 in the original Transformer. This is not arbitrary; it reflects a balance between several factors:

  1. Expressiveness: Each head needs enough dimensions to learn meaningful query-key comparisons. With too few dimensions (e.g., 16), the dot product between queries and keys cannot capture complex relationships. With 128 dimensions, each head can represent nuanced patterns.

  2. RoPE compatibility: Rotary Position Embeddings (Chapter 6) work by pairing dimensions and rotating them. With 128 dimensions, there are 64 rotation pairs, providing a rich set of frequencies for encoding relative position. This is important for long-context models.

  3. Hardware efficiency: GPU tensor cores operate most efficiently on matrix dimensions that are multiples of 8, 16, or 32. A head dimension of 128 aligns well with these hardware constraints, enabling efficient computation.

  4. Scaling factor: The attention scores are divided by sqrt(head_dim), which is about 11.31 for head_dim = 128 and 8.0 for head_dim = 64. If query and key components have roughly unit variance, a raw dot product over head_dim dimensions has a variance of about head_dim; dividing by sqrt(head_dim) brings it back to about 1, which keeps the softmax out of its saturated, nearly one-hot regime. Because the divisor grows with the head size, moving from 64 to 128 dimensions costs nothing in numerical stability.

The convergence on head_dim = 128 across LLaMA, Mistral, DeepSeek, and other model families suggests that this value is close to optimal for current model scales and training regimes.
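The variance argument behind the scaling factor can be checked empirically. This sketch assumes idealized query and key components drawn independently from a standard normal distribution (real activations are not this tidy). It then compares the variance of raw and scaled dot products for several head sizes.

```python
import numpy as np

rng = np.random.default_rng(0)
n_pairs = 20_000   # random query/key pairs sampled per head size

scaled_var = {}
for head_dim in (16, 64, 128):
    # Idealized assumption: i.i.d. components with mean 0 and variance 1
    q = rng.standard_normal((n_pairs, head_dim))
    k = rng.standard_normal((n_pairs, head_dim))
    raw = np.sum(q * k, axis=-1)         # unscaled dot products: variance near head_dim
    scaled = raw / np.sqrt(head_dim)     # divided by sqrt(head_dim): variance near 1
    scaled_var[head_dim] = scaled.var()
    print(f"head_dim={head_dim:>3}  var(raw)={raw.var():7.2f}  var(scaled)={scaled.var():.3f}")
```

The raw variance grows roughly in proportion to head_dim, while the scaled variance stays near 1 for every size.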


Key Takeaways

  • Multi-head attention runs multiple attention heads in parallel, each with its own learned W_Q, W_K, W_V projection matrices. Each head captures a different type of relationship between tokens (syntax, semantics, coreference, positional patterns). The outputs of all heads are concatenated and projected through W_O to produce the final output.

  • In the original Transformer (Vaswani et al., 2017), multi-head attention used 8 heads with 64 dimensions each, for a total of 512 dimensions. Modern models like LLaMA 4 Maverick use 40 query heads with 128 dimensions each, for a total of 5,120 dimensions.

  • Different heads spontaneously specialize during training. Research by Clark et al. (2019) and Voita et al. (2019) found that specific heads learn to track subject-verb agreement, determiner-noun relationships, coreference, and positional patterns, with specialized heads being the most important for model performance.

  • The KV cache (storing key and value vectors for all previous tokens during generation) is the main memory bottleneck for long-context inference. With standard Multi-Head Attention (MHA), the KV cache grows linearly with both the number of heads and the sequence length.

  • Multi-Query Attention (MQA), proposed by Shazeer (2019), shares a single KV head across all query heads, reducing the KV cache by a factor equal to the number of heads. This provides maximum memory savings but can reduce model quality.

  • Grouped Query Attention (GQA), proposed by Ainslie et al. (2023), groups query heads and assigns one KV head per group. LLaMA 4 Maverick uses 40 query heads with 8 KV heads (groups of 5), reducing the KV cache by 5x compared to full MHA while maintaining near-MHA quality. GQA has become the dominant attention mechanism in modern LLMs.

  • Multi-Head Latent Attention (MLA), used by DeepSeek-V3 (2024), compresses all key-value information into a low-dimensional latent vector (512 dimensions) per token, achieving even greater KV cache compression than GQA. MLA caches only 576 values per token per layer (the 512-dimensional latent plus a 64-dimensional decoupled RoPE key), compared to 2,048 for GQA with 8 KV heads and 32,768 for full MHA with 128 heads, both at 128 dimensions per head.

  • The output projection W_O mixes information across heads, allowing the model to combine the different types of information captured by different heads into a unified representation. It has the same number of parameters as W_Q (26.2 million in LLaMA 4 Maverick).

  • GQA reduces attention parameters by approximately 40% compared to full MHA in LLaMA 4 Maverick (62.9M vs. 104.9M per layer), saving about 2 billion parameters across all 48 layers.


What’s Next

You now understand how multi-head attention runs multiple attention heads in parallel, how Grouped Query Attention shares KV heads to reduce memory costs, and how different heads specialize in different linguistic relationships. But attention is only half of each Transformer layer. In Chapter 9, we will explore the feed-forward network (FFN), the other major component, which processes each token independently after attention has gathered contextual information, and which contains the majority of the model’s parameters.