Understanding Self-Attention and Multi-Head Attention in Transformers

Srinivas Rahul Sapireddy

--

In the journey of understanding modern NLP architectures like Transformers, two core concepts often take center stage: self-attention and multi-head attention. These mechanisms form the backbone of how Transformers efficiently process sequential data while capturing intricate relationships between input tokens. Let’s break down these concepts step by step, using visuals and examples, to grasp their significance and implementation.

What is Self-Attention?

Self-attention allows a model to weigh the importance of each token in the input sequence relative to every other token, including itself. This mechanism dynamically computes contextual embeddings for each token.

Steps in Self-Attention:

  1. Input Representation: The input tokens (e.g., words like “money,” “bank,” “grows”) are embedded into vectors.

  2. Query, Key, and Value:
  • Each input vector is transformed into three representations, the Query (Q), Key (K), and Value (V), using learned linear transformations.

  3. Attention Score Calculation: The attention score between a pair of tokens is computed as the scaled dot product of their query and key vectors, QKᵀ / √d_k. Here, d_k is the dimensionality of the key vectors.

  4. Softmax Normalization: The raw scores are passed through a softmax function so that the weights for each token sum to 1, resulting in attention weights.

  5. Weighted Summation: The final output for each token is the weighted sum of the value vectors (a short PyTorch sketch of these steps follows this list).
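Putting these five steps together, here is a minimal, self-contained PyTorch sketch. The sequence length, embedding size, and random “embeddings” are illustrative assumptions rather than values from a trained model:

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Step 1: toy embeddings for the tokens "money", "bank", "grows"
seq_len, d_model = 3, 8                     # illustrative sizes
x = torch.rand(seq_len, d_model)            # (seq_len, d_model)

# Step 2: learned linear projections producing Q, K, and V (randomly initialized here)
W_q = nn.Linear(d_model, d_model, bias=False)
W_k = nn.Linear(d_model, d_model, bias=False)
W_v = nn.Linear(d_model, d_model, bias=False)
Q, K, V = W_q(x), W_k(x), W_v(x)

# Step 3: scaled dot-product scores, Q K^T / sqrt(d_k)
d_k = K.size(-1)
scores = Q @ K.transpose(-2, -1) / d_k ** 0.5   # (seq_len, seq_len)

# Step 4: softmax so each row of attention weights sums to 1
weights = F.softmax(scores, dim=-1)

# Step 5: weighted sum of the value vectors
contextual = weights @ V                        # (seq_len, d_model)

print("Attention weights:\n", weights)
print("Contextual embeddings shape:", contextual.shape)  # torch.Size([3, 8])

Each row of weights describes how much context the corresponding token gathers from every token in the sequence, and the matching row of contextual is its new context-aware embedding.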

Visualizing Self-Attention

Consider an example: a token like “money” computes its importance relative to itself and to the other tokens, such as “bank” and “grows.” The resulting attention weights define how much context to gather from each of them. For instance, if “bank” has multiple meanings (e.g., financial institution vs. riverbank), self-attention helps dynamically assign weights based on the surrounding context.

What is Multi-Head Attention?

While self-attention computes a single set of attention weights, multi-head attention extends this concept by creating multiple sets of weights in parallel. Each head learns to focus on different parts of the input sequence, capturing diverse relationships.

Steps in Multi-Head Attention:

  1. Multiple Attention Heads:
  • For each attention head, the input is transformed into separate Query, Key, and Value matrices.
  • Each head performs its attention computations independently.

  2. Concatenation and Final Projection:
  • The outputs from all attention heads are concatenated and passed through a final linear layer to combine their insights.

  3. Parallel Computation: Because the heads are independent of one another, they can be computed in parallel, which lets the model efficiently capture complex relationships within the input data (a short shape walkthrough follows this list).
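To make the head split concrete, consider the sizes used in the code example below: with d_model = 512 and num_heads = 8, each head works in a subspace of depth 512 / 8 = 64. A small, purely illustrative shape walkthrough:

import torch

batch_size, seq_len, d_model, num_heads = 3, 10, 512, 8
depth = d_model // num_heads                 # 512 / 8 = 64 dimensions per head

x = torch.rand(batch_size, seq_len, d_model)

# Split the model dimension across heads, then move the head axis forward
heads = x.view(batch_size, seq_len, num_heads, depth).permute(0, 2, 1, 3)
print(heads.shape)  # torch.Size([3, 8, 10, 64]): each head attends over the full sequence in a 64-dim subspace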

Visualizing Multi-Head Attention

In the example provided, tokens like “money,” “bank,” and “grows” are processed by multiple heads simultaneously. Each head focuses on distinct patterns — some might prioritize syntactic structures, while others emphasize semantic meanings.

Code Implementation of Multi-Head Attention: Here’s an example of how multi-head attention can be implemented in PyTorch:

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.d_model = d_model

        assert d_model % num_heads == 0

        self.depth = d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.fc_out = nn.Linear(d_model, d_model)

    def split_heads(self, x, batch_size):
        """Split the last dimension into (num_heads, depth)."""
        x = x.view(batch_size, -1, self.num_heads, self.depth)
        return x.permute(0, 2, 1, 3)

    def forward(self, queries, keys, values):
        batch_size = queries.size(0)

        Q = self.split_heads(self.W_q(queries), batch_size)
        K = self.split_heads(self.W_k(keys), batch_size)
        V = self.split_heads(self.W_v(values), batch_size)

        # Scaled dot-product attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.depth, dtype=torch.float32))
        attention_weights = F.softmax(scores, dim=-1)
        scaled_attention = torch.matmul(attention_weights, V)

        # Concatenate heads
        scaled_attention = scaled_attention.permute(0, 2, 1, 3).contiguous()
        concat_attention = scaled_attention.view(batch_size, -1, self.d_model)

        # Final linear layer
        output = self.fc_out(concat_attention)
        return output

# Example usage
queries = torch.rand(3, 10, 512) # (batch_size, seq_len, d_model)
keys = torch.rand(3, 10, 512)
values = torch.rand(3, 10, 512)

multi_head_attention = MultiHeadAttention(d_model=512, num_heads=8)
output = multi_head_attention(queries, keys, values)
print("Multi-Head Attention Output Shape:", output.shape)

This implementation demonstrates how multi-head attention operates by splitting the input into multiple heads, computing scaled dot-product attention for each, and combining the outputs.
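For completeness, PyTorch also provides a built-in torch.nn.MultiheadAttention module that packages the same steps: projection, per-head scaled dot-product attention, concatenation, and an output projection. A minimal usage sketch (batch_first=True, available in recent PyTorch versions, lets it accept (batch_size, seq_len, d_model) inputs like the example above):

import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)

queries = torch.rand(3, 10, 512)  # (batch_size, seq_len, d_model)
keys = torch.rand(3, 10, 512)
values = torch.rand(3, 10, 512)

# Returns the attended output and the attention weights (averaged over heads by default)
output, attn_weights = mha(queries, keys, values)
print(output.shape)        # torch.Size([3, 10, 512])
print(attn_weights.shape)  # torch.Size([3, 10, 10])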

Key Advantages of Self-Attention and Multi-Head Attention

  1. Parallel Operations: Unlike RNNs, these mechanisms process all tokens simultaneously, leading to significant speed improvements.
  2. Dynamic Context: The computed attention weights adapt dynamically based on the input, capturing nuanced relationships.
  3. No Fixed Attention Weights: The attention weights themselves are not static, learned parameters; they are computed on the fly from the input, so the way tokens are mixed adapts to each sequence.

Applications in Transformers

The Transformer architecture leverages self-attention and multi-head attention layers to encode relationships in data efficiently. These mechanisms are pivotal in tasks ranging from machine translation (e.g., translating “bank” differently in “money bank” vs. “river bank”) to question answering systems.

Deeper Dive: Mathematical Foundations and Code Implementation

To fully appreciate the power of these mechanisms, let’s dive deeper into their mathematical underpinnings and explore code implementations.

Mathematical Foundations:

  • The dot-product attention calculation involves a scaling factor (1/√d_k) to prevent the scores from growing too large, which would otherwise lead to vanishing gradients when passed through the softmax (a small demonstration follows this list).
  • Multi-head attention allows each head to specialize in capturing different patterns, such as long-term dependencies or local syntactic structures.
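To see why the scaling factor matters, here is a small, illustrative demonstration (the choice of d_k = 512 and the random vectors are assumptions made for illustration): without dividing by √d_k, dot products of high-dimensional random vectors produce scores so large that the softmax collapses toward a one-hot distribution, which is exactly the regime where its gradients vanish.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_k = 512
q = torch.randn(d_k)
k = torch.randn(4, d_k)                   # scores of one query against 4 keys

raw_scores = k @ q                        # magnitudes grow roughly with sqrt(d_k)
scaled_scores = raw_scores / d_k ** 0.5

print(F.softmax(raw_scores, dim=-1))      # typically near one-hot: the softmax saturates
print(F.softmax(scaled_scores, dim=-1))   # smoother distribution with useful gradients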

Code Implementation: Here’s a Python snippet demonstrating the scaled dot-product attention computation, the heart of self-attention, in PyTorch:

import torch
import torch.nn.functional as F

# Example inputs
queries = torch.rand(3, 5) # (seq_len, d_model)
keys = torch.rand(3, 5)
values = torch.rand(3, 5)

d_k = queries.size(-1)

# Compute scaled dot-product attention
scores = torch.matmul(queries, keys.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))
attention_weights = F.softmax(scores, dim=-1)
output = torch.matmul(attention_weights, values)

print("Attention Weights:", attention_weights)
print("Output:", output)

This example highlights the computation of scaled dot-product attention, a key component of both self-attention and multi-head attention.
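As a side note, PyTorch 2.0 and later expose this computation directly as torch.nn.functional.scaled_dot_product_attention, which applies the same QKᵀ / √d_k scaling and softmax internally. A minimal sketch, assuming a recent PyTorch version and batched (batch_size, seq_len, d_model) inputs:

import torch
import torch.nn.functional as F

queries = torch.rand(1, 3, 5)  # (batch_size, seq_len, d_model)
keys = torch.rand(1, 3, 5)
values = torch.rand(1, 3, 5)

# Computes softmax(Q K^T / sqrt(d_k)) V in a single call
output = F.scaled_dot_product_attention(queries, keys, values)
print(output.shape)  # torch.Size([1, 3, 5])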

Conclusion

By understanding the interplay between self-attention and multi-head attention, we unlock the core functionality of Transformers. These mechanisms — simple yet powerful — revolutionize how models learn and represent relationships in sequential data.

