Day 3: Self-Attention and Multi-Head Attention

Welcome to Day 3. So far, we've treated "Attention" as a single operation: a Query is scored against every Key, and those scores decide how much of each Value flows into the output.

But sentences carry several kinds of information at once. Take "The incredibly wealthy bank..."

If you run a single Attention block on the word "bank", what does it focus on? Does it focus on the grammar? (Bank is a noun). Does it focus on the context? (Bank relates to money). Does it focus on the emotion?

A Single-Head Attention block produces just one set of attention weights per word, so it has to mush all of these concepts together into one blurry average.
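
To make the "one blurry average" concrete, here is a minimal sketch of a single attention head. The sizes and the random matrices `w_q`, `w_k`, and `w_v` are illustrative assumptions standing in for learned weights; they are not taken from day3_ex.py.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, embed_dim = 4, 64  # illustrative sizes (assumption)
x = torch.randn(seq_len, embed_dim)  # one embedding per word

# Hypothetical projection matrices standing in for learned weights.
w_q = torch.randn(embed_dim, embed_dim) / embed_dim ** 0.5
w_k = torch.randn(embed_dim, embed_dim) / embed_dim ** 0.5
w_v = torch.randn(embed_dim, embed_dim) / embed_dim ** 0.5

q, k, v = x @ w_q, x @ w_k, x @ w_v

# One score per (query word, key word) pair, scaled by sqrt(embed_dim).
scores = q @ k.T / embed_dim ** 0.5
weights = F.softmax(scores, dim=-1)  # (seq_len, seq_len)

# Each output row is ONE weighted average of all the value vectors.
output = weights @ v  # (seq_len, embed_dim)

print(weights.shape, output.shape)
```

Each word gets exactly one row of weights here, so grammar, context, and tone all have to share that single distribution.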

Multi-Head Attention

Instead of using 1 massive Attention block, we use Multi-Head Attention. We split the embedding dimension into \(8\) (or \(12\), or \(96\)) smaller Attention blocks, called heads. Each head computes its own attention weights, and all of them run in parallel!

  • Head 1: Might pick up grammatical structure (Nouns, Verbs).
  • Head 2: Might pick up geographical relationships.
  • Head 3: Might pick up emotional tone.
  • Head 4: Might track pronouns and who they map to.

Nobody assigns these roles by hand. They emerge during training, and real heads are usually messier than this tidy list suggests. But because each head computes its own attention weights, the model can follow several contextual threads at once instead of averaging them into one!
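
The split itself is only a reshape. The sketch below assumes illustrative sizes (a batch of 2 sentences, 5 words each, a 64-dimension embedding, 8 heads) and shows that each head sees its own contiguous slice of the features, and that merging the heads recovers the original tensor exactly.

```python
import torch

torch.manual_seed(0)
# Illustrative sizes (assumptions, not taken from day3_ex.py).
batch_size, seq_len, embed_dim, num_heads = 2, 5, 64, 8
head_dim = embed_dim // num_heads  # 8 features per head

x = torch.randn(batch_size, seq_len, embed_dim)

# Split: (batch, seq_len, embed_dim) -> (batch, num_heads, seq_len, head_dim)
heads = x.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)

# Head 3 is simply features 24..31 of every word.
same_as_slice = torch.equal(heads[:, 3], x[:, :, 3 * head_dim:4 * head_dim])

# Merge: undo the transpose and flatten the heads back into one vector per word.
merged = heads.transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim)
round_trip = torch.equal(merged, x)

print(heads.shape, same_as_slice, round_trip)
```

Nothing is copied or lost in the split; the heads are just different views of the same numbers.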

Hands-On: Building Parallel Heads in PyTorch!

Look closely at day3_ex.py. In __init__ we work out how wide each head is and create the linear layers whose outputs get split across the heads!

# day3_ex.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

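class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        # Every head needs the same width, so the split must be exact.
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.embed_dim = embed_dim  # forward() needs this to merge the heads back
        self.num_heads = num_heads

        # We mathematically split the Embed Dim into independent "Heads"!
        self.head_dim = embed_dim // num_heads

        self.query = nn.Linear(embed_dim, embed_dim)
        # Key, Value, and Output layers, named the way forward() uses them
        # (assumed here to mirror the Query layer).
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)
        self.out = nn.Linear(embed_dim, embed_dim)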

The Forward Split

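Inside the forward() pass, the magic happens. With embed_dim = \(64\) and num_heads = \(8\), .view() reshapes each word's \(64\)-dimension vector into eight \(8\)-dimension chunks, and .transpose(1, 2) moves the head axis in front of the sequence axis. After that, every matrix multiplication runs once per head, so each head's calculation stays isolated from the others!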

    def forward(self, x):
        # x: (batch, seq_len, embed_dim)
        batch_size = x.size(0)

        # The Math Split! Project, then reshape to (batch, num_heads, seq_len, head_dim)
        q = self.query(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.key(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.value(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        # We calculate the Attention Dot Products, scaled by sqrt(head_dim)!
        # scores: (batch, num_heads, seq_len, seq_len)
        scores = torch.matmul(q, k.transpose(-2, -1)) / np.sqrt(self.head_dim)
        attention_weights = F.softmax(scores, dim=-1)  # each row sums to 1

        # We stitch the Independent Heads back together into a single Output Vector per word!
        context = torch.matmul(attention_weights, v)  # (batch, num_heads, seq_len, head_dim)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.embed_dim)
        return self.out(context), attention_weights
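
To try the whole thing end to end, here is a self-contained sketch of the module. It assumes the Key, Value, and Output layers are plain nn.Linear(embed_dim, embed_dim) layers like the Query layer, and the `split_heads` helper is a hypothetical name added here for readability. The input sizes (2 sentences, 5 words, 64 dimensions, 8 heads) are illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # Assumption: all four projections are same-sized linear layers.
        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)
        self.out = nn.Linear(embed_dim, embed_dim)

    def split_heads(self, t):
        # (batch, seq_len, embed_dim) -> (batch, num_heads, seq_len, head_dim)
        return t.view(t.size(0), -1, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x):
        batch_size = x.size(0)
        q = self.split_heads(self.query(x))
        k = self.split_heads(self.key(x))
        v = self.split_heads(self.value(x))
        scores = torch.matmul(q, k.transpose(-2, -1)) / self.head_dim ** 0.5
        attention_weights = F.softmax(scores, dim=-1)
        context = torch.matmul(attention_weights, v)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.embed_dim)
        return self.out(context), attention_weights

torch.manual_seed(0)
mha = MultiHeadAttention(embed_dim=64, num_heads=8)
x = torch.randn(2, 5, 64)  # (batch, seq_len, embed_dim)
output, attention_weights = mha(x)

print(output.shape)             # torch.Size([2, 5, 64])
print(attention_weights.shape)  # torch.Size([2, 8, 5, 5])
```

The output has the same shape as the input, while the weights carry one extra axis: a separate 5-by-5 attention map for each of the 8 heads.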

Wrapping Up Day 3

Multi-Head Attention is a core building block of the Transformer architecture behind models like GPT-4, Claude, and Gemini. With many heads in every layer (GPT-3, for example, uses \(96\)), a model can track sarcasm, syntax, math, and code all in the same breath.

But we have officially sidestepped a glaring problem. If all of these words are processed at the exact same time... how does the Transformer know what order the words were in?
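
You can see the problem in a few lines. The sketch below uses a stripped-down self-attention with no learned weights (an assumption made to keep it short; the hypothetical `self_attention` helper is not part of day3_ex.py). Shuffling the input words simply shuffles the output rows in the same way, so nothing in the result depends on where a word sat in the sentence.

```python
import torch
import torch.nn.functional as F

def self_attention(x):
    # x: (seq_len, dim). The words act as their own queries, keys, and values.
    scores = x @ x.T / x.size(-1) ** 0.5
    return F.softmax(scores, dim=-1) @ x

torch.manual_seed(0)
x = torch.randn(5, 16)                 # 5 "words", 16 dimensions (illustrative)
perm = torch.tensor([4, 2, 0, 3, 1])   # an arbitrary reordering of the words

out = self_attention(x)
out_shuffled = self_attention(x[perm])

# Attention on the shuffled sentence equals the shuffled attention output.
order_blind = torch.allclose(out[perm], out_shuffled, atol=1e-4)
print(order_blind)
```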

Tomorrow, on Day 4: Positional Encoding, we inject Trigonometry into our algorithm to solve the sequence paradox.