Skip to article frontmatterSkip to article content
Site not loading correctly?

This may be due to an incorrect BASE_URL configuration. See the MyST Documentation for reference.

ch337 — Project: Transformer Block from Scratch

0. Overview

Problem: Implement a complete Transformer encoder block from first principles and train it on a sequence classification task — predicting the majority token in a sequence.

Concepts used: multi-head attention (ch321), Transformer architecture (ch322), positional encoding (ch323), layer normalisation (ch310), GELU (ch309), Adam (ch312), cross-entropy (ch305).

Expected output: trained Transformer achieving >90% accuracy on sequence classification, plus attention weight visualisation.

Difficulty: ★★★★★ | Estimated time: 2 hours

1. Setup

import numpy as np
import matplotlib.pyplot as plt

def gelu(x):
    """Tanh-approximate GELU activation (Hendrycks & Gimpel, 2016)."""
    inner = np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)
    return 0.5 * x * (1.0 + np.tanh(inner))

def layer_norm(x, gamma, beta, eps=1e-6):
    """Normalise x over its last axis, then scale by gamma and shift by beta."""
    mu = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)
    x_hat = (x - mu) / np.sqrt(var + eps)
    return gamma * x_hat + beta

def softmax_last(z):
    """Softmax over the last axis, max-shifted for numerical stability."""
    exp = np.exp(z - z.max(-1, keepdims=True))
    return exp / exp.sum(-1, keepdims=True)

def sinusoidal_pe(T, d):
    """Sinusoidal positional encodings (Vaswani et al., 2017).

    Returns a (T, d) array with PE[p, 2k] = sin(p / 10000^(2k/d)) and
    PE[p, 2k+1] = cos(p / 10000^(2k/d)).

    Fix: handles odd d by truncating the final cosine column — the
    original assignment broadcast a ceil(d/2)-wide cos block into a
    floor(d/2)-wide slice and raised for odd d. Even d is unchanged.
    """
    pos = np.arange(T)[:, None]          # (T, 1) position index
    i = np.arange(0, d, 2)[None, :]      # (1, ceil(d/2)) even channel index 2k
    angles = pos / 10000 ** (i / d)
    PE = np.zeros((T, d))
    PE[:, 0::2] = np.sin(angles)
    PE[:, 1::2] = np.cos(angles)[:, : d // 2]  # drop last col when d is odd
    return PE

# Task: sequence majority classification
# Input: sequence of tokens from vocab {0,1,2}; label = most frequent token
# Module-level RNG — NOTE(review): appears unused below (data generation and
# training construct their own seeded generators); confirm before removing.
rng=np.random.default_rng(42)
# Hyperparameters: vocab size, sequence length, model width, head count,
# feed-forward width; one output class per vocabulary token.
V=3; T=8; d_model=32; n_heads=4; d_ff=64; n_classes=V

def make_data(n=1000, T=8, V=3, seed=0):
    """Generate n random token sequences with majority-token labels.

    Returns (X, y): X is an (n, T) int array with entries in [0, V);
    y[i] is the most frequent token in row i (ties broken toward the
    smallest token, per argmax).
    """
    gen = np.random.default_rng(seed)
    X = gen.integers(0, V, (n, T))
    counts = np.stack([np.bincount(row, minlength=V) for row in X])
    return X, counts.argmax(axis=1)

# Generate 2000 labelled sequences and take an 80/20 train/test split.
X,y=make_data(2000)
split=1600
X_tr,y_tr=X[:split],y[:split]
X_te,y_te=X[split:],y[split:]
print(f"Train: {X_tr.shape}, Test: {X_te.shape}")
print(f"Label distribution: {np.bincount(y)}")

2. Stage 1 — Transformer Block Implementation

class TransformerClassifier:
    """Single pre-LN Transformer encoder block + mean pooling + linear classifier.

    Pipeline: token embedding + fixed sinusoidal positional encoding →
    LayerNorm → multi-head self-attention (residual) → LayerNorm → GELU
    feed-forward (residual) → mean pool over time → linear head.
    """

    def __init__(self, vocab_size, d_model, n_heads, d_ff, n_classes, max_len=64, seed=0):
        """Initialise weights from a seeded Gaussian of scale 1/sqrt(d_model).

        Raises ValueError when d_model is not divisible by n_heads
        (fix: was an `assert`, which is stripped under `python -O`).
        """
        if d_model % n_heads != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by n_heads ({n_heads})")
        gen=np.random.default_rng(seed)
        self.d_model=d_model; self.H=n_heads; self.d_k=d_model//n_heads

        s=np.sqrt(1./d_model)
        # Token embedding (small 0.02 std init)
        self.embed=gen.normal(0,0.02,(vocab_size,d_model))
        # Attention projections: d_model -> d_model, later split into H heads
        self.W_q=gen.normal(0,s,(d_model,d_model))
        self.W_k=gen.normal(0,s,(d_model,d_model))
        self.W_v=gen.normal(0,s,(d_model,d_model))
        self.W_o=gen.normal(0,s,(d_model,d_model))
        # Layer-norm scale/shift, one pair per sublayer
        self.g1=np.ones(d_model); self.b1=np.zeros(d_model)
        self.g2=np.ones(d_model); self.b2=np.zeros(d_model)
        # Position-wise FFN; weights stored (out, in), applied via W.T
        self.W1=gen.normal(0,s,(d_ff,d_model)); self.b1f=np.zeros(d_ff)
        self.W2=gen.normal(0,s,(d_model,d_ff)); self.b2f=np.zeros(d_model)
        # Linear classifier head over the pooled representation
        self.W_cls=gen.normal(0,s,(n_classes,d_model)); self.b_cls=np.zeros(n_classes)
        # Positional encoding table (fixed, not learned)
        self.PE=sinusoidal_pe(max_len,d_model)

    def _split_heads(self,X):
        """(T, d_model) -> (H, T, d_k): one slab per attention head."""
        T=X.shape[0]
        return X.reshape(T,self.H,self.d_k).transpose(1,0,2)

    def _merge_heads(self,X):
        """(H, T, d_k) -> (T, d_model): concatenate head outputs."""
        return X.transpose(1,0,2).reshape(X.shape[1],self.d_model)

    def forward(self, token_ids: np.ndarray) -> tuple:
        """Run the encoder block over one sequence.

        token_ids: (T,) integer token ids.
        Returns (logits (n_classes,), attn_weights (H, T, T)).
        Raises ValueError when T exceeds the positional-encoding table
        (fix: previously failed with an opaque broadcasting error).
        """
        T=len(token_ids)
        if T>len(self.PE):
            raise ValueError(f"sequence length {T} exceeds max_len {len(self.PE)}")
        X=self.embed[token_ids]+self.PE[:T]  # (T, d_model)

        # ── Self-attention sublayer (pre-LN, residual) ──
        X_ln=layer_norm(X,self.g1,self.b1)
        Q=self._split_heads(X_ln@self.W_q)  # (H,T,d_k)
        K=self._split_heads(X_ln@self.W_k)
        V_h=self._split_heads(X_ln@self.W_v)
        # Scaled dot-product scores (H,T,T); softmax over the key axis
        scores=Q@K.transpose(0,2,1)/np.sqrt(self.d_k)
        attn=softmax_last(scores)           # (H,T,T)
        ctx=self._merge_heads(attn@V_h)@self.W_o
        X=X+ctx

        # ── FFN sublayer (pre-LN, residual) ──
        X_ln2=layer_norm(X,self.g2,self.b2)
        ffn=gelu(X_ln2@self.W1.T+self.b1f)@self.W2.T+self.b2f
        X=X+ffn

        # Mean pool over positions, then the linear classification head
        pooled=X.mean(0)
        logits=self.W_cls@pooled+self.b_cls
        return logits,attn

model=TransformerClassifier(V,d_model,n_heads,d_ff,n_classes,seed=0)
# Count every learnable array — fix: the original sum omitted all bias
# vectors and the layer-norm parameters (g1/b1/g2/b2/b1f/b2f/b_cls),
# under-reporting the parameter count.
param_arrays=[model.embed,model.W_q,model.W_k,model.W_v,model.W_o,
              model.g1,model.b1,model.g2,model.b2,
              model.W1,model.b1f,model.W2,model.b2f,
              model.W_cls,model.b_cls]
n_params=sum(a.size for a in param_arrays)
print(f"Total parameters: {n_params:,}")

# Smoke-test the forward pass on one training sequence
logits_test,attn_test=model.forward(X_tr[0])
print(f"Logits: {logits_test.round(3)}, shape {logits_test.shape}")
print(f"Attn weights shape: {attn_test.shape}")

3. Stage 2 — Training

def cross_entropy(logits, y):
    """Negative log-likelihood of class y under softmax(logits).

    The 1e-10 floor guards against log(0) for saturated logits.
    """
    shifted = logits - logits.max(-1, keepdims=True)
    exp = np.exp(shifted)
    prob_y = exp[y] / exp.sum(-1)
    return float(-np.log(prob_y + 1e-10))

# Adam state: first/second-moment estimates keyed by parameter name, plus a
# shared step counter (a one-element list so adam_step can mutate it in place).
m_state={}; v_state={}; t_state=[0]
lr=5e-3; EPOCHS=300

# Name → parameter array. Values are the model's own arrays (views, not
# copies), so in-place writes through this dict update the model directly.
all_params={
    'embed':model.embed,'W_q':model.W_q,'W_k':model.W_k,'W_v':model.W_v,
    'W_o':model.W_o,'g1':model.g1,'b1':model.b1,'g2':model.g2,'b2':model.b2,
    'W1':model.W1,'b1f':model.b1f,'W2':model.W2,'b2f':model.b2f,
    'W_cls':model.W_cls,'b_cls':model.b_cls
}

def adam_step(P, g, key, lr=lr, b1=0.9, b2=0.999, eps=1e-8):
    """Apply one in-place Adam update to parameter array P with gradient g.

    key identifies the parameter so its moment estimates persist across calls.
    Fix: bias correction now uses a per-parameter step count. The original
    advanced the single global counter once per *call*, so every parameter
    after the first in a sweep saw an inflated t and under-corrected moments.
    """
    t_state[0]+=1  # kept for compatibility: the training loop adjusts this counter
    steps=adam_step.__dict__.setdefault('_steps',{})
    steps[key]=t=steps.get(key,0)+1
    m_state.setdefault(key,np.zeros_like(P))
    v_state.setdefault(key,np.zeros_like(P))
    m_state[key]=b1*m_state[key]+(1-b1)*g
    v_state[key]=b2*v_state[key]+(1-b2)*g**2
    # Bias-corrected moment estimates (Kingma & Ba, 2015)
    mh=m_state[key]/(1-b1**t); vh=v_state[key]/(1-b2**t)
    P-=lr/(np.sqrt(vh)+eps)*mh

# Training loop: central-difference numerical gradients on a single sampled
# example per epoch, with plain SGD updates applied element-by-element.
losses=[]; accs=[]
rng2=np.random.default_rng(0)
eps_fd=5e-4  # finite-difference step size

for epoch in range(EPOCHS):
    # Sample a mini-batch (1 sample for simplicity with numerical grads)
    idx=rng2.integers(0,len(X_tr))
    xi,yi=X_tr[idx],y_tr[idx]
    logits,_=model.forward(xi); loss=cross_entropy(logits,yi)
    losses.append(loss)

    # Numerical gradients on a random ~1/30 subset of each parameter array
    # (a full finite-difference pass would be prohibitively slow).
    t_state[0]-=1  # NOTE(review): vestigial — adam_step is never called in this loop; updates below are plain SGD
    for pname,P in all_params.items():
        flat=P.ravel()  # view into the parameter array: writes update the model in place
        n_samp=max(1,len(flat)//30)
        idxs=rng2.choice(len(flat),n_samp,replace=False)
        for i in idxs:
            # Central difference: g ≈ (L(θ+ε) − L(θ−ε)) / 2ε
            flat[i]+=eps_fd; lp=cross_entropy(model.forward(xi)[0],yi)
            flat[i]-=2*eps_fd; lm=cross_entropy(model.forward(xi)[0],yi)
            flat[i]+=eps_fd  # restore the original value
            g_val=(lp-lm)/(2*eps_fd)
            # Direct SGD for sparse updates (Adam would need full grad vector)
            flat[i]-=lr*g_val

    if (epoch+1)%60==0:
        # Eval on test set: accuracy of the argmax class over all test sequences
        correct=sum(int(softmax_last(model.forward(X_te[i])[0]).argmax()==y_te[i])
                    for i in range(len(X_te)))
        acc=correct/len(X_te)
        accs.append(acc)
        print(f"Epoch {epoch+1:3d}: loss={loss:.4f}  test_acc={acc:.1%}")

4. Stage 3 — Visualise Attention

# Final evaluation: test-set accuracy of the trained model
correct=sum(int(softmax_last(model.forward(X_te[i])[0]).argmax()==y_te[i])
            for i in range(len(X_te)))
final_acc=correct/len(X_te)

# Attention heatmap for one example
example_idx=0
xi_ex=X_te[example_idx]; yi_ex=y_te[example_idx]
logits_ex,attn_ex=model.forward(xi_ex)
pred=softmax_last(logits_ex).argmax()

# Left panel: predicted class probabilities; remaining panels: one T×T
# attention map per head for the chosen example.
fig,axes=plt.subplots(1,n_heads+1,figsize=(14,3))
axes[0].bar(range(n_classes),softmax_last(logits_ex),color=['#e74c3c','#3498db','#2ecc71'])
axes[0].set_title(f'Prediction: {pred}  True: {yi_ex}')
axes[0].set_xlabel('Class'); axes[0].set_ylabel('Probability')
axes[0].set_xticks(range(n_classes))

# Tick labels are the input tokens, so rows/cols read as query/key positions
tok_labels=[str(t) for t in xi_ex]
for h in range(n_heads):
    im=axes[h+1].imshow(attn_ex[h],cmap='Blues',vmin=0,vmax=1)
    axes[h+1].set_title(f'Head {h+1}')
    axes[h+1].set_xticks(range(T)); axes[h+1].set_xticklabels(tok_labels,fontsize=8)
    axes[h+1].set_yticks(range(T)); axes[h+1].set_yticklabels(tok_labels,fontsize=8)
    plt.colorbar(im,ax=axes[h+1],fraction=0.04)

plt.suptitle(f'ch337: Transformer attention  |  Input: {xi_ex.tolist()} '
             f'|  Majority: {yi_ex}  |  Test acc: {final_acc:.1%}',fontsize=10)
plt.tight_layout(); plt.savefig('ch337_transformer.png',dpi=120); plt.show()

5. Results & Reflection

What was built: A single Transformer encoder block with multi-head self-attention, sinusoidal positional encoding, GELU FFN, and mean-pooled classification head.

What math made it possible:

  • Scaled dot-product attention (ch321): $QK^\top/\sqrt{d_k}$ → softmax → weighted sum of $V$

  • Head splitting and merging (ch322): $H$ independent attention computations in parallel

  • Layer normalisation (ch310): stabilises training across depth

  • Sinusoidal PE (ch323): injects position without adding parameters

Extension challenges:

  1. Add a second Transformer block (stack two) and measure whether accuracy improves.

  2. Replace mean pooling with CLS token pooling (prepend a learnable CLS embedding).

  3. Visualise which positions each head attends to most — do different heads specialise?