1. The bottleneck problem, revisited¶
In seq2seq models (ch319), all information about the input is compressed into one vector. Bahdanau attention (2015) proposed that the decoder should have access to all encoder hidden states, weighted by their relevance to the current decoding step.
General attention: given a query $q$ and a set of key–value pairs $\{(k_i, v_i)\}_{i=1}^{T}$, compute a weighted sum of the values, where the weights measure query–key similarity: $\mathrm{Attn}(q) = \sum_i \alpha_i v_i$, with $\alpha = \mathrm{softmax}\big(\mathrm{score}(q, k_1), \ldots, \mathrm{score}(q, k_T)\big)$.
(Softmax: ch305. Weighted averages: ch248. Dot products: ch133.)
2. Scaled Dot-Product Attention¶
Vaswani et al. (2017) use the dot product as the score function, scaled by $\sqrt{d_k}$: $\mathrm{score}(q, k) = \frac{q^\top k}{\sqrt{d_k}}$.
The scaling prevents the dot products from growing large when $d_k$ is large (which would push softmax into regions of near-zero gradient).
In matrix form for a batch of queries: $\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$.
from typing import Optional, Tuple

import numpy as np
import matplotlib.pyplot as plt
def softmax_2d(Z: np.ndarray) -> np.ndarray:
    """Numerically stable softmax along the last axis of Z."""
    # Subtracting the row-wise max first keeps exp() from overflowing;
    # the shift cancels in the ratio, so the result is unchanged.
    exps = np.exp(Z - np.max(Z, axis=-1, keepdims=True))
    return exps / np.sum(exps, axis=-1, keepdims=True)
def scaled_dot_product_attention(Q: np.ndarray, K: np.ndarray,
                                 V: np.ndarray,
                                 mask: Optional[np.ndarray] = None
                                 ) -> Tuple[np.ndarray, np.ndarray]:
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    Q: (T_q, d_k)   K: (T_k, d_k)   V: (T_k, d_v)
    mask: optional boolean (T_q, T_k); True entries are blocked from
          attention (sent to ~0 weight).
    Returns (output (T_q, d_v), attn_weights (T_q, T_k)).
    """
    d_k = Q.shape[-1]
    # Scale by sqrt(d_k) so score variance stays O(1) as d_k grows.
    scores = Q @ K.T / np.sqrt(d_k)                 # (T_q, T_k)
    if mask is not None:
        # A large negative score becomes ~0 after softmax.
        scores = np.where(mask, -1e9, scores)
    # Numerically stable softmax over the key axis (row max cancels).
    exps = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights = exps / exps.sum(axis=-1, keepdims=True)   # (T_q, T_k)
    output = weights @ V                                # (T_q, d_v)
    return output, weights
# ── Demo: retrieve information based on query relevance ──
rng = np.random.default_rng(42)
d_k = 8; d_v = 8; T_k = 6

# Keys represent different "memories"; values carry their content.
K = rng.normal(0, 1, (T_k, d_k))
V = rng.normal(0, 1, (T_k, d_v))

# Three queries; two are engineered to be near-copies of specific keys,
# so their attention should concentrate on those keys.
T_q = 3
Q = rng.normal(0, 1, (T_q, d_k))
Q[0] = K[2] + rng.normal(0, 0.1, d_k)   # query 0 ≈ key 2
Q[1] = K[4] + rng.normal(0, 0.1, d_k)   # query 1 ≈ key 4

output, weights = scaled_dot_product_attention(Q, K, V)

fig, axes = plt.subplots(1, 3, figsize=(13, 4))

# Panel 0: attention-weight heatmap with per-cell numeric annotations.
im0 = axes[0].imshow(weights, cmap='Blues', vmin=0, vmax=1)
axes[0].set_title('Attention weights (3 queries × 6 keys)')
axes[0].set_xlabel('Key index'); axes[0].set_ylabel('Query index')
plt.colorbar(im0, ax=axes[0], fraction=0.046)
for i in range(T_q):
    for j in range(T_k):
        axes[0].text(j, i, f'{weights[i,j]:.2f}', ha='center', va='center', fontsize=7)

# Panel 1: effect of the temperature (scaling divisor) on sharpness.
# NOTE: the original wrapped this in a one-iteration zip() loop whose
# loop variables were never used; plain statements are equivalent.
temps = [0.1, 1.0, np.sqrt(d_k), 10.0]
raw_score = Q[0] @ K.T
all_weights = [softmax_2d((raw_score / t)[None, :])[0] for t in temps]
for w, t in zip(all_weights, temps):
    axes[1].plot(w, label=f'temp={t:.1f}', lw=1.5)
axes[1].set_title('Effect of temperature on attention sharpness (query 0)')
axes[1].set_xlabel('Key index'); axes[1].set_ylabel('Attention weight')
axes[1].legend(fontsize=8)
# ── Panel 2: Multi-Head Attention summarized as a text diagram ──
# No computation happens here; the panel only displays the recipe for
# splitting d_model into heads, attending per head, and re-projecting.
axes[2].axis('off')
text = (
"Multi-Head Attention\n\n"
"For h=1..H heads:\n"
" Q_h = Q @ W_q_h (d_model → d_k)\n"
" K_h = K @ W_k_h (d_model → d_k)\n"
" V_h = V @ W_v_h (d_model → d_v)\n"
" head_h = Attention(Q_h, K_h, V_h)\n\n"
"Output = Concat(head_1,...,head_H) @ W_O\n\n"
"Each head attends to different\n"
"parts of the input independently."
)
# Render the description in a rounded box, monospaced for alignment.
axes[2].text(0.05, 0.5, text, transform=axes[2].transAxes,
va='center', fontsize=9, family='monospace',
bbox=dict(boxstyle='round', facecolor='#ecf0f1', alpha=0.8))
axes[2].set_title('Multi-Head Attention structure')
plt.tight_layout()
# Persist the three-panel figure, then show it interactively.
plt.savefig('ch321_attention.png', dpi=120)
plt.show()
# Multi-Head Attention implementation
class MultiHeadAttention:
    """Multi-head scaled dot-product attention (Vaswani et al., 2017).

    Splits d_model into n_heads independent subspaces of size
    d_k = d_model // n_heads, attends in each subspace, then
    concatenates the heads and projects back to d_model.
    """

    def __init__(self, d_model: int, n_heads: int, seed: int = 0):
        assert d_model % n_heads == 0
        rng = np.random.default_rng(seed)
        self.H = n_heads
        self.d_k = d_model // n_heads
        self.d_model = d_model
        # Scaled Gaussian init for the four square projection matrices.
        s = np.sqrt(1.0 / d_model)
        self.W_q = rng.normal(0, s, (d_model, d_model))
        self.W_k = rng.normal(0, s, (d_model, d_model))
        self.W_v = rng.normal(0, s, (d_model, d_model))
        self.W_o = rng.normal(0, s, (d_model, d_model))

    @staticmethod
    def _softmax(Z: np.ndarray) -> np.ndarray:
        """Numerically stable softmax over the last axis."""
        e = np.exp(Z - Z.max(axis=-1, keepdims=True))
        return e / e.sum(axis=-1, keepdims=True)

    def forward(self, Q: np.ndarray, K: np.ndarray, V: np.ndarray,
                mask: Optional[np.ndarray] = None
                ) -> Tuple[np.ndarray, np.ndarray]:
        """Q, K, V: (T, d_model).

        mask: optional boolean (T_q, T_k); True entries are blocked
        (broadcast over heads).
        Returns (output (T_q, d_model), attn_weights (H, T_q, T_k)).
        """
        T_q, T_k = Q.shape[0], K.shape[0]

        def split(X: np.ndarray, T: int) -> np.ndarray:
            # (T, d_model) -> (H, T, d_k): head h owns contiguous columns.
            return X.reshape(T, self.H, self.d_k).transpose(1, 0, 2)

        # Project, then split into heads.
        Q_h = split(Q @ self.W_q, T_q)
        K_h = split(K @ self.W_k, T_k)
        V_h = split(V @ self.W_v, T_k)

        # Attend in all heads at once via batched matmul (replaces the
        # original per-head Python loop; same per-head computation).
        scores = Q_h @ K_h.transpose(0, 2, 1) / np.sqrt(self.d_k)  # (H, T_q, T_k)
        if mask is not None:
            # Large negative score -> ~0 weight after softmax.
            scores = np.where(mask, -1e9, scores)
        all_w = self._softmax(scores)       # (H, T_q, T_k)
        all_out = all_w @ V_h               # (H, T_q, d_k)

        # Merge heads back to (T_q, d_model) and apply output projection.
        merged = all_out.transpose(1, 0, 2).reshape(T_q, self.d_model)
        output = merged @ self.W_o
        return output, all_w
# ── Self-attention demo: the same sequence supplies Q, K and V ──
rng = np.random.default_rng(0)
T, d_model, n_heads = 8, 32, 4
X = rng.normal(0, 1, (T, d_model))
mha = MultiHeadAttention(d_model, n_heads, seed=0)
out, weights = mha.forward(X, X, X)  # self-attention
print(f"Input shape: {X.shape}")
print(f"Output shape: {out.shape}")
print(f"Attn weights: {weights.shape} (n_heads, T_q, T_k)")
print(f"Output std: {out.std():.4f}")

3. Self-attention vs cross-attention¶
Self-attention: $Q$, $K$, and $V$ all come from the same sequence. Used in Transformer encoders.
Cross-attention: $Q$ comes from the decoder; $K$ and $V$ come from the encoder. Used in encoder-decoder models.
Causal (masked) self-attention: mask the upper triangle so position $i$ can only attend to positions $j \le i$. Used in decoder-only Transformers (GPT-style).
4. Summary¶
Attention: compute a query-key similarity; use it to weight a sum of values.
Scaled dot-product: divide scores by $\sqrt{d_k}$ to prevent softmax saturation.
Multi-head: run $H$ attention functions in parallel with different learned projections.
Self-attention has $O(T^2 \cdot d)$ complexity in the sequence length $T$ — the key computational bottleneck for long sequences.
5. Forward and backward references¶
Used here: dot product (ch133), softmax (ch305), embeddings (ch320), linear projections (ch154).
This will reappear in ch322 — Transformers, where self-attention is combined with FFN layers and residual connections to form the complete Transformer block.