Skip to article frontmatterSkip to article content
Site not loading correctly?

This may be due to an incorrect BASE_URL configuration. See the MyST Documentation for reference.

1. The bottleneck problem, revisited

In seq2seq models (ch319), all information about the input is compressed into one fixed-size vector. Bahdanau et al. (2015) proposed letting the decoder access all encoder hidden states, weighted by their relevance to the current decoding step.

General attention: given a query $q$ and a set of key-value pairs $(k_i, v_i)$, compute a weighted sum of the values, where the weights measure query-key similarity:

$$\text{Attention}(q, K, V) = \sum_i \alpha_i v_i, \quad \alpha_i = \frac{\exp\big(\text{score}(q, k_i)\big)}{\sum_j \exp\big(\text{score}(q, k_j)\big)}$$

(Softmax: ch305. Weighted averages: ch248. Dot products: ch133.)


2. Scaled Dot-Product Attention

Vaswani et al. (2017) use the dot product as the score function, scaled down by $\sqrt{d_k}$:

$$\text{score}(q, k) = \frac{q \cdot k}{\sqrt{d_k}}$$

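The scaling prevents the dot products from growing large when $d_k$ is large, which would push the softmax into regions with near-zero gradient. If the components of $q$ and $k$ are independent with mean 0 and variance 1, then $q \cdot k$ has variance $d_k$; dividing by $\sqrt{d_k}$ brings it back to unit variance.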

In matrix form for a batch of queries (one query per row of $Q$, softmax applied row-wise):

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right) V$$
import numpy as np
import matplotlib.pyplot as plt


def softmax_2d(Z: np.ndarray) -> np.ndarray:
    """Return the softmax of ``Z`` taken along its last axis.

    Each slice along the last axis is exponentiated and normalised to sum
    to 1. The per-slice maximum is subtracted first. This leaves the result
    mathematically unchanged but keeps ``np.exp`` from overflowing on large
    inputs. Despite the name, any array of rank >= 1 is accepted.
    """
    exps = np.exp(Z - Z.max(axis=-1, keepdims=True))
    return exps / exps.sum(axis=-1, keepdims=True)


def scaled_dot_product_attention(Q: np.ndarray, K: np.ndarray,
                                  V: np.ndarray,
                                  mask: np.ndarray = None) -> tuple:
    """
    Compute softmax(Q K^T / sqrt(d_k)) V.

    Q: (..., T_q, d_k)  K: (..., T_k, d_k)  V: (..., T_k, d_v)
    Leading (batch / head) dimensions are optional and follow numpy matmul
    broadcasting; plain 2-D inputs behave exactly as (T_q, d_k) etc.

    mask: optional boolean array broadcastable to (..., T_q, T_k). True marks
    a query-key pair that must NOT be attended to: its score is replaced by
    -1e9, so its softmax weight is effectively zero. If every key in a row is
    masked, all scores in that row equal -1e9 and the weights become uniform.

    Returns: (output (..., T_q, d_v), attn_weights (..., T_q, T_k))
    """
    d_k = Q.shape[-1]
    # swapaxes (rather than .T) transposes only the last two axes, so stacked
    # batch/head inputs are handled correctly; for 2-D K it equals K.T.
    scores = Q @ np.swapaxes(K, -1, -2) / np.sqrt(d_k)   # (..., T_q, T_k)
    if mask is not None:
        scores = np.where(mask, -1e9, scores)
    weights = softmax_2d(scores)                         # normalise over keys
    output = weights @ V                                 # (..., T_q, d_v)
    return output, weights


# ── Demo: retrieve information based on query relevance ──
rng = np.random.default_rng(42)
d_k, d_v, T_k = 8, 8, 6

# Keys represent different "memories"; values hold what each memory returns.
K = rng.normal(0, 1, (T_k, d_k))
V = rng.normal(0, 1, (T_k, d_v))

# Three queries: query 0 is a noisy copy of key 2, query 1 a noisy copy of
# key 4, and query 2 is left random as a control.
T_q = 3
Q = rng.normal(0, 1, (T_q, d_k))
Q[0] = K[2] + rng.normal(0, 0.1, d_k)
Q[1] = K[4] + rng.normal(0, 0.1, d_k)

output, weights = scaled_dot_product_attention(Q, K, V)

fig, axes = plt.subplots(1, 3, figsize=(13, 4))

# Panel 1: attention-weight heatmap, each cell annotated with its value.
im0 = axes[0].imshow(weights, cmap='Blues', vmin=0, vmax=1)
axes[0].set_title(f'Attention weights ({T_q} queries × {T_k} keys)')
axes[0].set_xlabel('Key index')
axes[0].set_ylabel('Query index')
plt.colorbar(im0, ax=axes[0], fraction=0.046)
for (i, j), w_ij in np.ndenumerate(weights):
    # imshow puts column j on the x-axis and row i on the y-axis.
    axes[0].text(j, i, f'{w_ij:.2f}', ha='center', va='center', fontsize=7)

# Panel 2: softmax(raw_score / temp) for query 0. Dividing by a small
# temperature sharpens the distribution towards the best-matching key;
# a large one flattens it. temp = √d_k is the scaled dot-product choice.
temps = [0.1, 1.0, np.sqrt(d_k), 10.0]
raw_score = Q[0] @ K.T
all_weights = [softmax_2d(raw_score / t) for t in temps]
for w, t in zip(all_weights, temps):
    axes[1].plot(w, label=f'temp={t:.1f}', lw=1.5)
axes[1].set_title('Effect of temperature on attention sharpness (query 0)')
axes[1].set_xlabel('Key index')
axes[1].set_ylabel('Attention weight')
axes[1].legend(fontsize=8)

# Panel 3: textual sketch of the Multi-Head Attention computation.
axes[2].axis('off')
text = (
    "Multi-Head Attention\n\n"
    "For h=1..H heads:\n"
    "  Q_h = Q @ W_q_h   (d_model → d_k)\n"
    "  K_h = K @ W_k_h   (d_model → d_k)\n"
    "  V_h = V @ W_v_h   (d_model → d_v)\n"
    "  head_h = Attention(Q_h, K_h, V_h)\n\n"
    "Output = Concat(head_1,...,head_H) @ W_O\n\n"
    "Each head attends to different\n"
    "parts of the input independently."
)
axes[2].text(0.05, 0.5, text, transform=axes[2].transAxes,
             va='center', fontsize=9, family='monospace',
             bbox=dict(boxstyle='round', facecolor='#ecf0f1', alpha=0.8))
axes[2].set_title('Multi-Head Attention structure')

plt.tight_layout()
plt.savefig('ch321_attention.png', dpi=120)
plt.show()
# Multi-Head Attention implementation
class MultiHeadAttention:
    """Multi-head scaled dot-product attention.

    The projected queries, keys and values (width ``d_model``) are split into
    ``n_heads`` heads of width ``d_k = d_model // n_heads``. Every head runs
    scaled dot-product attention independently. The head outputs are then
    concatenated back to width ``d_model`` and mixed by ``W_o``.

    Parameters
    ----------
    d_model : int
        Model width; must be divisible by ``n_heads``.
    n_heads : int
        Number of heads; must be positive.
    seed : int, default 0
        Seed for the random weight initialisation.

    Raises
    ------
    ValueError
        If ``n_heads`` is not positive or does not divide ``d_model``.
    """

    def __init__(self, d_model: int, n_heads: int, seed: int = 0):
        # Explicit check instead of ``assert``: asserts vanish under ``python -O``.
        if n_heads <= 0 or d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads}), "
                "and n_heads must be positive"
            )
        rng = np.random.default_rng(seed)
        self.H = n_heads
        self.d_k = d_model // n_heads
        self.d_model = d_model
        # Gaussian init with std 1/sqrt(d_model). The draw order (q, k, v, o)
        # is fixed so that a given seed always reproduces the same weights.
        s = np.sqrt(1.0 / d_model)
        self.W_q = rng.normal(0, s, (d_model, d_model))
        self.W_k = rng.normal(0, s, (d_model, d_model))
        self.W_v = rng.normal(0, s, (d_model, d_model))
        self.W_o = rng.normal(0, s, (d_model, d_model))

    def _split_heads(self, X: np.ndarray) -> np.ndarray:
        """Reshape a projected (T, d_model) array into (H, T, d_k)."""
        return X.reshape(X.shape[0], self.H, self.d_k).transpose(1, 0, 2)

    def forward(self, Q: np.ndarray, K: np.ndarray, V: np.ndarray,
                mask: np.ndarray = None) -> tuple:
        """Apply multi-head attention.

        Q: (T_q, d_model); K, V: (T_k, d_model).
        mask: optional boolean array broadcastable to (T_q, T_k) (or to
        (H, T_q, T_k) for per-head masks). True marks a query-key pair that
        must not be attended to; its score is replaced by -1e9.

        Returns (output (T_q, d_model), attn_weights (H, T_q, T_k)).
        """
        T_q = Q.shape[0]
        Q_h = self._split_heads(Q @ self.W_q)   # (H, T_q, d_k)
        K_h = self._split_heads(K @ self.W_k)   # (H, T_k, d_k)
        V_h = self._split_heads(V @ self.W_v)   # (H, T_k, d_k)

        # All heads at once: (H, T_q, d_k) @ (H, d_k, T_k) -> (H, T_q, T_k).
        scores = Q_h @ K_h.transpose(0, 2, 1) / np.sqrt(self.d_k)
        if mask is not None:
            # A (T_q, T_k) mask broadcasts over the head axis.
            scores = np.where(mask, -1e9, scores)
        weights = softmax_2d(scores)            # softmax over keys (last axis)
        heads = weights @ V_h                   # (H, T_q, d_k)

        # Concatenate heads: (H, T_q, d_k) -> (T_q, H, d_k) -> (T_q, d_model).
        merged = heads.transpose(1, 0, 2).reshape(T_q, self.d_model)
        output = merged @ self.W_o
        return output, weights


# ── Demo: multi-head self-attention over a random sequence ──
rng = np.random.default_rng(0)
T, d_model, n_heads = 8, 32, 4
X = rng.normal(0, 1, (T, d_model))

mha = MultiHeadAttention(d_model, n_heads, seed=0)
# Self-attention: the same sequence supplies the queries, keys and values.
out, weights = mha.forward(X, X, X)

# One line per statistic, matching four separate print() calls.
print(
    f"Input shape:  {X.shape}",
    f"Output shape: {out.shape}",
    f"Attn weights: {weights.shape}  (n_heads, T_q, T_k)",
    f"Output std:   {out.std():.4f}",
    sep="\n",
)

3. Self-attention vs cross-attention

  • Self-attention: $Q$, $K$, and $V$ all come from the same sequence. Used in Transformer encoders.

  • Cross-attention: $Q$ comes from the decoder; $K$ and $V$ come from the encoder. Used in encoder-decoder models.

  • Causal (masked) self-attention: mask the entries strictly above the diagonal of the score matrix, so that position $t$ can only attend to positions $\leq t$. Used in decoder-only Transformers (GPT-style).

4. Summary

  • Attention: compute a query-key similarity; use it to weight a sum of values.

  • Scaled dot-product: divide scores by $\sqrt{d_k}$ to prevent softmax saturation.

  • Multi-head: run $H$ attention functions in parallel with different learned projections.

  • Self-attention builds a $T \times T$ weight matrix, so it has $O(T^2)$ time and memory complexity in the sequence length $T$. This is the key computational bottleneck for long sequences.

5. Forward and backward references

Used here: dot product (ch133), softmax (ch305), embeddings (ch320), linear projections (ch154).

This will reappear in ch322 — Transformers, where self-attention is combined with FFN layers and residual connections to form the complete Transformer block.