0. Overview¶
Problem: Implement a complete Transformer encoder block from first principles and train it on a sequence classification task — predicting the majority token in a sequence.
Concepts used: multi-head attention (ch321), Transformer architecture (ch322), positional encoding (ch323), layer normalisation (ch310), GELU (ch309), Adam (ch312), cross-entropy (ch305).
Expected output: trained Transformer achieving >90% accuracy on sequence classification, plus attention weight visualisation.
Difficulty: ★★★★★ | Estimated time: 2 hours
1. Setup¶
import numpy as np
import matplotlib.pyplot as plt
def gelu(x):
return 0.5*x*(1+np.tanh(np.sqrt(2/np.pi)*(x+0.044715*x**3)))
def layer_norm(x, gamma, beta, eps=1e-6):
mean=x.mean(-1,keepdims=True); var=x.var(-1,keepdims=True)
return gamma*(x-mean)/np.sqrt(var+eps)+beta
def softmax_last(z):
z_s=z-z.max(-1,keepdims=True); e=np.exp(z_s); return e/e.sum(-1,keepdims=True)
def sinusoidal_pe(T, d):
PE=np.zeros((T,d))
pos=np.arange(T)[:,None]; i=np.arange(0,d,2)[None,:]
div=10000**(i/d)
PE[:,0::2]=np.sin(pos/div); PE[:,1::2]=np.cos(pos/div)
return PE
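# Quick sanity check of the helpers above (a small verification sketch, not part of the model):
# softmax rows should sum to one, layer_norm should centre each row, and the sinusoidal
# encoding values are bounded by construction.
_z=np.random.default_rng(0).normal(size=(4,5))
assert np.allclose(softmax_last(_z).sum(-1),1.0)                                   # rows are distributions
assert np.allclose(layer_norm(_z,np.ones(5),np.zeros(5)).mean(-1),0.0,atol=1e-6)   # zero mean per row
_pe=sinusoidal_pe(8,32)
assert _pe.shape==(8,32) and np.abs(_pe).max()<=1.0                                 # sin/cos stay in [-1,1]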
# Task: sequence majority classification
# Input: sequence of tokens from vocab {0,1,2}; label = most frequent token
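# Example of the labelling rule (illustration only; the actual data come from make_data below):
_ex=np.array([2,0,2,1,2,0,1,2])                      # token 2 occurs four times
print(np.bincount(_ex,minlength=3).argmax())         # -> 2; ties resolve to the smallest token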
rng=np.random.default_rng(42)
V=3; T=8; d_model=32; n_heads=4; d_ff=64; n_classes=V
def make_data(n=1000, T=8, V=3, seed=0):
rng2=np.random.default_rng(seed)
X=rng2.integers(0,V,(n,T))
y=np.array([np.bincount(row,minlength=V).argmax() for row in X])
return X,y
X,y=make_data(2000)
split=1600
X_tr,y_tr=X[:split],y[:split]
X_te,y_te=X[split:],y[split:]
print(f"Train: {X_tr.shape}, Test: {X_te.shape}")
print(f"Label distribution: {np.bincount(y)}")2. Stage 1 — Transformer Block Implementation¶
class TransformerClassifier:
"""Single Transformer encoder block + mean pooling + linear classifier."""
def __init__(self, vocab_size, d_model, n_heads, d_ff, n_classes, max_len=64, seed=0):
rng=np.random.default_rng(seed)
assert d_model%n_heads==0
self.d_model=d_model; self.H=n_heads; self.d_k=d_model//n_heads
s=np.sqrt(1./d_model)
# Token embedding
self.embed=rng.normal(0,0.02,(vocab_size,d_model))
# Attention projections
self.W_q=rng.normal(0,s,(d_model,d_model))
self.W_k=rng.normal(0,s,(d_model,d_model))
self.W_v=rng.normal(0,s,(d_model,d_model))
self.W_o=rng.normal(0,s,(d_model,d_model))
# Layer norms
self.g1=np.ones(d_model); self.b1=np.zeros(d_model)
self.g2=np.ones(d_model); self.b2=np.zeros(d_model)
# FFN
self.W1=rng.normal(0,s,(d_ff,d_model)); self.b1f=np.zeros(d_ff)
self.W2=rng.normal(0,s,(d_model,d_ff)); self.b2f=np.zeros(d_model)
# Classifier head
self.W_cls=rng.normal(0,s,(n_classes,d_model)); self.b_cls=np.zeros(n_classes)
# Positional encoding (fixed)
self.PE=sinusoidal_pe(max_len,d_model)
def _split_heads(self,X): T=X.shape[0]; return X.reshape(T,self.H,self.d_k).transpose(1,0,2)
def _merge_heads(self,X): return X.transpose(1,0,2).reshape(X.shape[1],self.d_model)
def forward(self, token_ids: np.ndarray) -> tuple:
"""token_ids: (T,). Returns (logits (n_classes,), attn_weights (H,T,T))."""
T=len(token_ids)
X=self.embed[token_ids]+self.PE[:T] # (T, d_model)
# ── Self-attention sublayer ──
X_ln=layer_norm(X,self.g1,self.b1)
Q=self._split_heads(X_ln@self.W_q) # (H,T,d_k)
K=self._split_heads(X_ln@self.W_k)
V_h=self._split_heads(X_ln@self.W_v)
scores=Q@K.transpose(0,2,1)/np.sqrt(self.d_k)
attn=softmax_last(scores) # (H,T,T)
ctx=self._merge_heads(attn@V_h)@self.W_o
X=X+ctx
# ── FFN sublayer ──
X_ln2=layer_norm(X,self.g2,self.b2)
ffn=gelu(X_ln2@self.W1.T+self.b1f)@self.W2.T+self.b2f
X=X+ffn
# Mean pool → classify
pooled=X.mean(0)
logits=self.W_cls@pooled+self.b_cls
return logits,attn
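# Before instantiating the model used for training, a small invariant check on a throwaway
# instance (the sizes here are arbitrary and chosen only for the check) confirms the output
# shapes and that every attention row is a probability distribution.
_m=TransformerClassifier(vocab_size=3,d_model=8,n_heads=2,d_ff=16,n_classes=3,seed=1)
_lg,_at=_m.forward(np.array([0,1,2,1]))
assert _lg.shape==(3,) and _at.shape==(2,4,4)        # (n_classes,) and (H, T, T)
assert np.allclose(_at.sum(-1),1.0)                  # softmax rows sum to one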
model=TransformerClassifier(V,d_model,n_heads,d_ff,n_classes,seed=0)
# Count weight-matrix parameters (biases and LayerNorm gamma/beta are omitted for brevity)
n_params=(model.embed.size+model.W_q.size+model.W_k.size+model.W_v.size+
model.W_o.size+model.W1.size+model.W2.size+model.W_cls.size)
print(f"Weight-matrix parameters (excluding biases/LayerNorm): {n_params:,}")
# Test forward pass
logits_test,attn_test=model.forward(X_tr[0])
print(f"Logits: {logits_test.round(3)}, shape {logits_test.shape}")
print(f"Attn weights shape: {attn_test.shape}")3. Stage 2 — Training¶
def cross_entropy(logits, y):
p=softmax_last(logits); return float(-np.log(p[y]+1e-10))
# Adam state
m_state={}; v_state={}; t_state=[0]
lr=5e-3; EPOCHS=300
all_params={
'embed':model.embed,'W_q':model.W_q,'W_k':model.W_k,'W_v':model.W_v,
'W_o':model.W_o,'g1':model.g1,'b1':model.b1,'g2':model.g2,'b2':model.b2,
'W1':model.W1,'b1f':model.b1f,'W2':model.W2,'b2f':model.b2f,
'W_cls':model.W_cls,'b_cls':model.b_cls
}
def adam_step(P, g, key, lr=lr, b1=0.9, b2=0.999, eps=1e-8):
t_state[0]+=1; t=t_state[0]
m_state.setdefault(key,np.zeros_like(P))
v_state.setdefault(key,np.zeros_like(P))
m_state[key]=b1*m_state[key]+(1-b1)*g
v_state[key]=b2*v_state[key]+(1-b2)*g**2
mh=m_state[key]/(1-b1**t); vh=v_state[key]/(1-b2**t)
P-=lr/(np.sqrt(vh)+eps)*mh
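# adam_step is never called by the sparse numerical-gradient loop below (which falls back to
# plain SGD), so as a standalone illustration here is a toy usage sketch on a hypothetical
# quadratic objective f(w) = (w - 3)^2, whose gradient is 2*(w - 3).
_w=np.zeros(1)
for _ in range(400):
    adam_step(_w,2*(_w-3.0),key='demo',lr=0.05)      # full gradient of the toy objective
print(f"Adam demo: w = {_w.round(3)}")               # approaches 3
m_state.clear(); v_state.clear(); t_state[0]=0       # reset optimiser state before training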
losses=[]; accs=[]
rng2=np.random.default_rng(0)
eps_fd=5e-4
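# Correctness check on the finite-difference scheme (a small verification sketch; the toy
# logits and label are arbitrary): the central-difference estimate should match the
# closed-form softmax cross-entropy gradient, softmax(logits) - one_hot(y).
_logits=np.array([0.5,-1.0,2.0]); _yc=2
_analytic=softmax_last(_logits).copy(); _analytic[_yc]-=1.0    # softmax minus one-hot
_numeric=np.zeros(3)
for _i in range(3):
    _lp=_logits.copy(); _lp[_i]+=eps_fd
    _lm=_logits.copy(); _lm[_i]-=eps_fd
    _numeric[_i]=(cross_entropy(_lp,_yc)-cross_entropy(_lm,_yc))/(2*eps_fd)
assert np.allclose(_numeric,_analytic,atol=1e-4)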
for epoch in range(EPOCHS):
# Sample a mini-batch (1 sample for simplicity with numerical grads)
idx=rng2.integers(0,len(X_tr))
xi,yi=X_tr[idx],y_tr[idx]
logits,_=model.forward(xi); loss=cross_entropy(logits,yi)
losses.append(loss)
# Numerical gradients: central differences on a sparse random sample of entries
# (adam_step is not called here; see the SGD comment in the inner loop)
for pname,P in all_params.items():
flat=P.ravel()
n_samp=max(1,len(flat)//30)
idxs=rng2.choice(len(flat),n_samp,replace=False)
for i in idxs:
flat[i]+=eps_fd; lp=cross_entropy(model.forward(xi)[0],yi)
flat[i]-=2*eps_fd; lm=cross_entropy(model.forward(xi)[0],yi)
flat[i]+=eps_fd
g_val=(lp-lm)/(2*eps_fd)
# Direct SGD for sparse updates (Adam would need full grad vector)
flat[i]-=lr*g_val
if (epoch+1)%60==0:
# Eval on test set
correct=sum(int(softmax_last(model.forward(X_te[i])[0]).argmax()==y_te[i])
for i in range(len(X_te)))
acc=correct/len(X_te)
accs.append(acc)
print(f"Epoch {epoch+1:3d}: loss={loss:.4f} test_acc={acc:.1%}")4. Stage 3 — Visualise Attention¶
# Final evaluation
correct=sum(int(softmax_last(model.forward(X_te[i])[0]).argmax()==y_te[i])
for i in range(len(X_te)))
final_acc=correct/len(X_te)
# Attention heatmap for one example
example_idx=0
xi_ex=X_te[example_idx]; yi_ex=y_te[example_idx]
logits_ex,attn_ex=model.forward(xi_ex)
pred=softmax_last(logits_ex).argmax()
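# Numeric companion to the heatmaps below (relates to extension challenge 3 in the Reflection;
# a diagnostic sketch, not part of the figure): the column mean of each head's attention
# matrix shows which input position receives the most attention on this example.
for _h in range(n_heads):
    _recv=attn_ex[_h].mean(0)                        # average attention received by each position
    print(f"Head {_h+1}: most-attended position {_recv.argmax()} (token {xi_ex[_recv.argmax()]})")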
fig,axes=plt.subplots(1,n_heads+1,figsize=(14,3))
axes[0].bar(range(n_classes),softmax_last(logits_ex),color=['#e74c3c','#3498db','#2ecc71'])
axes[0].set_title(f'Prediction: {pred} True: {yi_ex}')
axes[0].set_xlabel('Class'); axes[0].set_ylabel('Probability')
axes[0].set_xticks(range(n_classes))
tok_labels=[str(t) for t in xi_ex]
for h in range(n_heads):
im=axes[h+1].imshow(attn_ex[h],cmap='Blues',vmin=0,vmax=1)
axes[h+1].set_title(f'Head {h+1}')
axes[h+1].set_xticks(range(T)); axes[h+1].set_xticklabels(tok_labels,fontsize=8)
axes[h+1].set_yticks(range(T)); axes[h+1].set_yticklabels(tok_labels,fontsize=8)
plt.colorbar(im,ax=axes[h+1],fraction=0.04)
plt.suptitle(f'ch337: Transformer attention | Input: {xi_ex.tolist()} '
f'| Majority: {yi_ex} | Test acc: {final_acc:.1%}',fontsize=10)
plt.tight_layout(); plt.savefig('ch337_transformer.png',dpi=120); plt.show()
5. Results & Reflection¶
What was built: A single Transformer encoder block with multi-head self-attention, sinusoidal positional encoding, GELU FFN, and mean-pooled classification head.
What math made it possible:
Scaled dot-product attention (ch321): scores QKᵀ/√d_k → softmax → weighted sum of V
Head splitting and merging (ch322): independent attention computations in parallel
Layer normalisation (ch310): stabilises training across depth
Sinusoidal PE (ch323): injects position without adding parameters
Extension challenges:
Add a second Transformer block (stack two) and measure whether accuracy improves.
Replace mean pooling with CLS token pooling (prepend a learnable CLS embedding; a sketch follows this list).
Visualise which positions each head attends to most — do different heads specialise?
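For the CLS-token extension, a rough sketch of the required changes (the attribute name self.cls and its initialisation scale are assumptions, not part of the implementation above):
# In __init__: add a learnable classification token
#   self.cls=rng.normal(0,0.02,d_model)
# In forward: prepend it to the embedded sequence before the attention sublayer
#   X=np.vstack([self.cls[None,:],self.embed[token_ids]])+self.PE[:T+1]
# Classify from the CLS position instead of the mean over positions
#   pooled=X[0]
Because the CLS row attends to every position in the self-attention sublayer, it can aggregate evidence from the whole sequence for the classifier.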