
Understanding Transformers — Part 2: Multi-Head Attention

machine-learning deep-learning nlp transformers

In Part 1, we built scaled dot-product attention from scratch. A single attention head works — but it can only look at one “type” of relationship at a time.

Multi-head attention runs $h$ independent attention operations in parallel, each with its own learned projections, then concatenates and re-projects the results.

Motivation: one head isn’t enough

Consider the sentence: “John said that he hurt himself.”

A single attention head must simultaneously track:

  • he → John (coreference)
  • himself → John (reflexive)
  • hurt → himself (predicate-argument)

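These are structurally different relationships. A single head has one set of projections and produces one attention distribution per query token, so it has to compromise: a token involved in several relationships must split its attention weight among them, and its output becomes a blend of all of them.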

Multiple heads let the model specialise: empirically, different heads in trained models learn to track different syntactic and semantic patterns (coreference, positional proximity, subject-verb agreement, etc.).

The mechanics

For each head $i$, we learn three projection matrices:

\[W^Q_i \in \mathbb{R}^{d_{model} \times d_k}, \quad W^K_i \in \mathbb{R}^{d_{model} \times d_k}, \quad W^V_i \in \mathbb{R}^{d_{model} \times d_v}\]

Each head then computes:

\[\text{head}_i = \text{Attention}(X W^Q_i,\ X W^K_i,\ X W^V_i)\]
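Here Attention is the scaled dot-product attention from Part 1. Writing $X \in \mathbb{R}^{n \times d_{model}}$ for a sequence of $n$ token vectors ($n$ is introduced here only for the shape bookkeeping), the pieces of one head have these shapes:

\[\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V\]

\[X W^Q_i,\ X W^K_i \in \mathbb{R}^{n \times d_k}, \qquad \frac{(X W^Q_i)(X W^K_i)^\top}{\sqrt{d_k}} \in \mathbb{R}^{n \times n}, \qquad \text{head}_i \in \mathbb{R}^{n \times d_v}\]

The softmax is taken row by row, so each of the $n$ query positions gets its own distribution over the $n$ keys.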

Outputs are concatenated and projected through $W^O$:

\[\text{MultiHead}(X) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\ W^O\]

where $W^O \in \mathbb{R}^{(h \cdot d_v) \times d_{model}}$.
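The equations translate almost line for line into code if each head keeps its own weight matrices. This is a minimal, unvectorised sketch with arbitrary toy sizes and random (untrained) weights; the helper name attention is ours:

```python
import torch

torch.manual_seed(0)

# Arbitrary toy sizes, chosen for illustration only.
n, d_model, h = 5, 16, 4
d_k = d_v = d_model // h

X = torch.randn(n, d_model)

# One separate (W^Q_i, W^K_i, W^V_i) triple per head, as in the equations.
W_Q = [torch.randn(d_model, d_k) for _ in range(h)]
W_K = [torch.randn(d_model, d_k) for _ in range(h)]
W_V = [torch.randn(d_model, d_v) for _ in range(h)]
W_O = torch.randn(h * d_v, d_model)

def attention(Q, K, V):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V."""
    scores = Q @ K.T / Q.size(-1) ** 0.5       # (n, n)
    return torch.softmax(scores, dim=-1) @ V    # (n, d_v)

# head_i = Attention(X W^Q_i, X W^K_i, X W^V_i)
heads = [attention(X @ W_Q[i], X @ W_K[i], X @ W_V[i]) for i in range(h)]
concat = torch.cat(heads, dim=-1)               # (n, h * d_v)
out = concat @ W_O                              # (n, d_model)

print(concat.shape, out.shape)  # torch.Size([5, 16]) torch.Size([5, 16])
```

The PyTorch implementation later in this post computes the same thing, but packs the $h$ per-head query matrices into one $d_{model} \times d_{model}$ nn.Linear (likewise for keys and values) and splits the result with a reshape; it also adds bias terms, which the equations omit.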

Dimension arithmetic

The original paper uses $d_{model} = 512$, $h = 8$, which gives $d_k = d_v = 512/8 = 64$.

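So each head operates on a 64-dimensional subspace. Because $h \cdot d_k = d_{model}$, the projections and the per-head matrix multiplications add up to the same number of multiply-adds as a single head working on the full 512 dimensions; we just distribute the work across heads. The only thing that grows with $h$ is the attention-weight tensor: a length-$n$ sequence gets $h$ separate $n \times n$ weight matrices instead of one.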
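A quick multiply-add count makes the comparison concrete. This is a back-of-the-envelope sketch, not a profiler: the helper mha_cost is hypothetical, it assumes $d_k = d_v = d_{model}/h$, and it ignores biases, scaling, and the softmax itself. It compares $h = 8$ heads against a single head with $d_k = 512$:

```python
def mha_cost(n: int, d_model: int, h: int) -> dict:
    """Rough multiply-add count for one attention layer over a length-n sequence.

    Assumes d_k = d_v = d_model // h; ignores biases, scaling and softmax.
    """
    d_k = d_model // h
    projections = 4 * n * d_model * d_model  # W^Q, W^K, W^V (all heads) plus W^O
    scores = h * n * n * d_k                 # Q K^T for every head
    weighted_sum = h * n * n * d_k           # attention weights @ V for every head
    return {
        "matmul_macs": projections + scores + weighted_sum,
        "attn_weight_entries": h * n * n,    # size of the softmax output
    }

multi = mha_cost(n=128, d_model=512, h=8)
single = mha_cost(n=128, d_model=512, h=1)

print(multi["matmul_macs"] == single["matmul_macs"])                   # True
print(multi["attn_weight_entries"] // single["attn_weight_entries"])   # 8
```

The matrix-multiply totals match because $h \cdot d_k = d_{model}$ cancels in every term; only the number of stored attention weights scales with $h$.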

PyTorch implementation

```python
import torch
import torch.nn as nn
import math

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model    = d_model
        self.num_heads  = num_heads
        self.d_k        = d_model // num_heads

        # Projections for Q, K, V and the output. Each Linear packs all h
        # per-head matrices: head i uses output features [i*d_k, (i+1)*d_k).
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)   # output projection W^O

    def split_heads(self, x, batch_size):
        # x: (batch, seq, d_model) → (batch, heads, seq, d_k)
        x = x.view(batch_size, -1, self.num_heads, self.d_k)
        return x.transpose(1, 2)

    def forward(self, query, key, value, mask=None):
        B = query.size(0)

        # Project and split
        Q = self.split_heads(self.W_q(query), B)   # (B, h, seq_q, d_k)
        K = self.split_heads(self.W_k(key),   B)   # (B, h, seq_k, d_k)
        V = self.split_heads(self.W_v(value), B)   # (B, h, seq_k, d_k)

        # Scaled dot-product attention, computed for all heads at once
        scores = Q @ K.transpose(-2, -1) / math.sqrt(self.d_k)  # (B, h, seq_q, seq_k)
        if mask is not None:
            # mask must broadcast to (B, h, seq_q, seq_k); 0 = blocked, 1 = allowed.
            # finfo.min rather than a literal like -1e9, which overflows in float16.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        attn_weights = torch.softmax(scores, dim=-1)  # (B, h, seq_q, seq_k)
        context = attn_weights @ V                    # (B, h, seq_q, d_k)

        # Concatenate heads: (B, h, seq_q, d_k) → (B, seq_q, d_model), then project
        context = context.transpose(1, 2).contiguous().view(B, -1, self.d_model)
        return self.W_o(context), attn_weights
```
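The reshapes in split_heads and at the end of forward() are easy to get wrong, so here is a small standalone check, with arbitrary toy sizes, that the split gives each head a contiguous slice of the feature dimension and that the merge undoes it exactly:

```python
import torch

# Arbitrary toy sizes: batch 2, sequence 5, d_model 12 split into 3 heads of 4.
B, seq, d_model, h = 2, 5, 12, 3
d_k = d_model // h

x = torch.arange(B * seq * d_model, dtype=torch.float32).view(B, seq, d_model)

# Same reshape as split_heads: (B, seq, d_model) -> (B, h, seq, d_k)
heads = x.view(B, -1, h, d_k).transpose(1, 2)

# Head 1 sees exactly feature columns [d_k, 2*d_k) of every token.
assert torch.equal(heads[:, 1], x[..., d_k:2 * d_k])

# Same reshape as the end of forward(): (B, h, seq, d_k) -> (B, seq, d_model)
merged = heads.transpose(1, 2).contiguous().view(B, -1, d_model)
print(torch.equal(merged, x))  # True
```

In forward() itself, context comes out of a matmul already laid out as (B, h, seq_q, d_k), so after transpose(1, 2) it is no longer contiguous; that is why .contiguous() is called before .view() (torch.reshape would also work).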

Three variants of attention in a Transformer

The same MultiHeadAttention module is used in three different roles:

| Location | Q comes from | K, V come from | Purpose |
|---|---|---|---|
| Encoder self-attention | encoder input | encoder input | Each source token attends to every source token, itself included |
| Decoder self-attention | decoder input | decoder input | Each output token attends to itself and earlier outputs (masked) |
| Cross-attention | decoder | encoder output | Decoder attends to the full encoded source |
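For comparison, PyTorch’s built-in nn.MultiheadAttention can be wired into all three roles. This is a simplified sketch with arbitrary toy sizes that leaves out the residual connections, layer norm, and feed-forward sub-layers of a real Transformer layer. Note that its boolean attn_mask uses True for blocked positions, the opposite of the mask == 0 convention in the class above:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Arbitrary toy sizes: source length 6, target length 4.
B, src_len, tgt_len, d_model, h = 2, 6, 4, 32, 4

src = torch.randn(B, src_len, d_model)   # stands in for the embedded source
tgt = torch.randn(B, tgt_len, d_model)   # stands in for the embedded target

# One module per role, each with its own learned projections.
enc_self = nn.MultiheadAttention(d_model, h, batch_first=True)
dec_self = nn.MultiheadAttention(d_model, h, batch_first=True)
cross = nn.MultiheadAttention(d_model, h, batch_first=True)

# PyTorch convention: True in a boolean attn_mask means "may NOT attend",
# so the causal mask is True strictly above the diagonal (future positions).
causal = torch.triu(torch.ones(tgt_len, tgt_len, dtype=torch.bool), diagonal=1)

enc_out, w_enc = enc_self(src, src, src)                   # Q, K, V from the source
dec_out, w_dec = dec_self(tgt, tgt, tgt, attn_mask=causal) # Q, K, V from the target
_, w_x = cross(dec_out, enc_out, enc_out)                  # Q from decoder, K/V from encoder

# Weights are averaged over heads by default: (B, seq_q, seq_k)
print(w_enc.shape)  # torch.Size([2, 6, 6])
print(w_dec.shape)  # torch.Size([2, 4, 4])
print(w_x.shape)    # torch.Size([2, 4, 6])
```

Cross-attention is the only role where the query length and key length differ, which is why its weight matrix is rectangular. Passing average_attn_weights=False returns per-head weights of shape (B, h, seq_q, seq_k) instead.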

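The masking in decoder self-attention is crucial. During training the whole target sequence is fed to the decoder at once, so without a mask each position could simply read the token it is supposed to predict. At inference time those future tokens don’t exist yet, so training has to match that condition: a causal mask, a lower-triangular matrix of 1s, lets position $t$ attend only to positions $\le t$.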
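A toy check shows what the causal mask guarantees. This sketch assumes a single head with identity projections (Q = K = V = x) for brevity and uses the same convention as the class above, where mask == 0 means blocked; the helper self_attention is hypothetical. It perturbs only the last token and compares the outputs at all earlier positions:

```python
import math
import torch

torch.manual_seed(0)

def self_attention(x, mask=None):
    """Single-head self-attention with identity projections (Q = K = V = x).

    Positions where mask == 0 are blocked, as in the MultiHeadAttention class.
    """
    scores = x @ x.transpose(-2, -1) / math.sqrt(x.size(-1))
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    return torch.softmax(scores, dim=-1) @ x

T, d = 5, 8
causal = torch.tril(torch.ones(T, T))    # lower-triangular matrix of 1s

x = torch.randn(T, d)
x_changed = x.clone()
x_changed[-1] += 1.0                     # perturb only the last ("future") token

# With the causal mask, positions 0..T-2 cannot see the change.
masked_same = torch.allclose(self_attention(x, causal)[:-1],
                             self_attention(x_changed, causal)[:-1])
# Without it, earlier positions attend to the last token and their outputs shift.
unmasked_same = torch.allclose(self_attention(x)[:-1],
                               self_attention(x_changed)[:-1])
print(masked_same, unmasked_same)  # True False
```

Every row keeps its diagonal entry unmasked, so no row is fully blocked and float("-inf") is safe here.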

What do heads actually learn?

Clark et al. (2019) probed BERT’s attention heads and found that different heads consistently specialise:

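  • Some heads attend from verbs to their direct objects (verb → direct object)
  • Some track coreference (pronoun → antecedent)
  • Some “broad” heads, mostly in lower layers, spread their attention across the whole sentence, producing roughly a bag-of-vectors summary
  • Many heads put much of their attention on the [SEP] token, which the authors hypothesise acts as a “no-op” when a head’s pattern doesn’t apply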

This specialisation emerges from gradient descent alone — it’s not hardcoded.
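One way to quantify “broad” versus “focused”, and the measure Clark et al. use, is the average entropy of each head’s attention distribution. A minimal sketch, assuming weights shaped (batch, heads, seq_q, seq_k) like those returned by the class above; head_entropy is a hypothetical helper, and the two heads here are hand-built rather than taken from a trained model:

```python
import torch

def head_entropy(attn_weights: torch.Tensor) -> torch.Tensor:
    """Mean attention entropy per head, in nats.

    attn_weights: (batch, heads, seq_q, seq_k), each row summing to 1.
    Returns a (heads,) tensor: near 0 = focused head, near log(seq_k) = broad head.
    """
    p = attn_weights.clamp_min(1e-12)                 # avoid log(0)
    ent = -(attn_weights * p.log()).sum(dim=-1)       # (batch, heads, seq_q)
    return ent.mean(dim=(0, 2))                       # average over batch and queries

# Two hand-built heads over 4 keys: one one-hot ("focused"), one uniform ("broad").
seq = 4
focused = torch.eye(seq)                              # each query attends to one key
broad = torch.full((seq, seq), 1.0 / seq)             # each query spreads evenly
weights = torch.stack([focused, broad]).unsqueeze(0)  # (1, 2, seq, seq)

print(head_entropy(weights))  # approximately [0, log 4] = [0, 1.386]
```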

Key takeaways

  • Multi-head attention = $h$ independent attention heads, each with learned projections, concatenated and projected.
  • Dimensionality per head = $d_{model} / h$, so the total matrix-multiplication cost is about the same as one full-width head.
  • Three different configurations (self, masked self, cross) power the full Transformer.
  • Heads empirically specialise on different linguistic relationships.

Next: positional encoding — how the Transformer knows where each token is, since attention itself is position-blind.