Skip to article frontmatterSkip to article content
Site not loading correctly?

This may be due to an incorrect BASE_URL configuration. See the MyST Documentation for reference.

ch335 — Project: CNN for Image Classification

0. Overview

Problem: Build a convolutional neural network to classify 16×16 images into 4 classes using a synthetic dataset, then evaluate with a proper train/val/test protocol.

Concepts used: convolution (ch314), pooling (ch315), CNN architectures (ch316), Adam (ch312); ResNet-style residual connections (ch316) and batch normalisation (ch310) are left as extension challenges.

Expected output: a trained CNN reaching >80% test accuracy on synthetic image data, with activation map visualisations and filter visualisations.

Difficulty: ★★★★☆ | Estimated time: 90–120 minutes

1. Setup — Synthetic Image Dataset

import numpy as np
import matplotlib.pyplot as plt

# Module-level RNG (fixed seed for reproducibility) used for epoch shuffles
# and the shape sanity check; make_synthetic_images seeds its own generator.
rng = np.random.default_rng(42)

def make_synthetic_images(n_per_class: int = 200, img_size: int = 16,
                           n_classes: int = 4, seed: int = 0) -> tuple:
    """Generate a shuffled synthetic dataset of single-channel pattern images.

    Each class is a distinct geometric shape (horizontal bar, vertical bar,
    circle outline, diagonal line) drawn over a noisy background, with the
    pattern centre jittered per image and additive noise applied on top.

    Returns (images, labels): images of shape (n_classes*n_per_class, 1,
    img_size, img_size) and an int label array, shuffled with one common
    permutation so the two stay aligned.
    """
    local_rng = np.random.default_rng(seed)
    half = img_size // 2
    images = []
    targets = []
    for label in range(n_classes):
        for _ in range(n_per_class):
            canvas = local_rng.normal(0, 0.1, (1, img_size, img_size))  # noise background
            cx = half + local_rng.integers(-3, 4)
            cy = half + local_rng.integers(-3, 4)
            if label == 0:
                # horizontal bar, 3 px thick, clipped at the image borders
                canvas[0, cy-1:cy+2, max(0, cx-5):min(img_size, cx+6)] = 1.0
            elif label == 1:
                # vertical bar
                canvas[0, max(0, cy-5):min(img_size, cy+6), cx-1:cx+2] = 1.0
            elif label == 2:
                # circle outline: pixels whose squared distance is within 4 of 16
                for row in range(img_size):
                    for col in range(img_size):
                        if abs((row-cy)**2 + (col-cx)**2 - 16) < 4:
                            canvas[0, row, col] = 1.0
            elif label == 3:
                # main diagonal through (cy, cx)
                for off in range(-img_size, img_size):
                    rr, cc = cy + off, cx + off
                    if 0 <= rr < img_size and 0 <= cc < img_size:
                        canvas[0, rr, cc] = 1.0
            canvas += local_rng.normal(0, 0.15, canvas.shape)
            images.append(canvas)
            targets.append(label)
    order = local_rng.permutation(len(targets))
    return np.array(images)[order], np.array(targets)[order]

# Build the dataset and split 70/15/15 into train/val/test. The images are
# already shuffled inside make_synthetic_images, so contiguous slicing here
# yields a valid random split.
X, y = make_synthetic_images(n_per_class=150, img_size=16, n_classes=4)
split_tr, split_va = int(0.7*len(y)), int(0.85*len(y))
X_tr,y_tr = X[:split_tr],y[:split_tr]
X_va,y_va = X[split_tr:split_va],y[split_tr:split_va]
X_te,y_te = X[split_va:],y[split_va:]
print(f"Train {X_tr.shape}, Val {X_va.shape}, Test {X_te.shape}")

# Visual sanity check: show the first training example of each class.
fig, axes = plt.subplots(1, 4, figsize=(12,3))
class_names = ['H-bar','V-bar','Circle','Diagonal']
for cls, ax in enumerate(axes):
    m = y_tr==cls; ax.imshow(X_tr[m][0,0], cmap='gray')  # first sample of class cls, channel 0
    ax.set_title(class_names[cls]); ax.axis('off')
plt.suptitle('Synthetic image classes'); plt.tight_layout()
plt.savefig('ch335_classes.png', dpi=100); plt.show()

2. Stage 1 — Define the CNN (residual blocks are left as extension challenge 1)

def relu(z):
    """Elementwise rectifier: clamp negative entries to zero."""
    return np.clip(z, 0, None)


def relu_grad(z):
    """Derivative of ReLU w.r.t. z: 1.0 where z > 0, else 0.0."""
    return np.where(z > 0, 1.0, 0.0)

def conv2d(X, W, b, stride=1, padding=0):
    """2D cross-correlation (the "conv" layer operation).

    Args:
        X: input batch, shape (B, C, H, W).
        W: filter bank, shape (F, C, kH, kW).
        b: per-filter bias, shape (F,).
        stride: step between successive patch origins.
        padding: symmetric zero-padding added to both spatial dims.

    Returns:
        Output feature maps, shape (B, F, Ho, Wo).
    """
    B, C, H, Wd = X.shape
    F, _, kH, kW = W.shape
    if padding > 0:
        X = np.pad(X, ((0, 0), (0, 0), (padding, padding), (padding, padding)))
    Ho = (X.shape[2] - kH) // stride + 1
    Wo = (X.shape[3] - kW) // stride + 1
    out = np.zeros((B, F, Ho, Wo))
    # One tensordot per output position contracts the patch against ALL
    # filters at once, replacing the original per-filter Python loop
    # (F× fewer Python-level iterations, identical numerics).
    for i in range(Ho):
        for j in range(Wo):
            patch = X[:, :, i*stride:i*stride+kH, j*stride:j*stride+kW]
            # (B,C,kH,kW) · (F,C,kH,kW) contracted over (C,kH,kW) -> (B,F)
            out[:, :, i, j] = np.tensordot(patch, W, axes=([1, 2, 3], [1, 2, 3])) + b
    return out

def max_pool(X, size=2, stride=2):
    """Spatial max pooling over (size, size) windows; X is (B, C, H, W)."""
    n, ch, h, w = X.shape
    out_h = (h - size) // stride + 1
    out_w = (w - size) // stride + 1
    pooled = np.empty((n, ch, out_h, out_w))
    for r in range(out_h):
        top = r * stride
        for c in range(out_w):
            left = c * stride
            window = X[:, :, top:top+size, left:left+size]
            pooled[:, :, r, c] = window.max(axis=(2, 3))
    return pooled

def global_avg_pool(X):
    """Collapse each (H, W) feature map to its mean: (B, C, H, W) -> (B, C)."""
    batch, channels = X.shape[:2]
    return X.reshape(batch, channels, -1).mean(axis=2)

# Simple CNN: conv → relu → pool → conv → relu → GAP → fc
class SimpleCNN:
    """Two conv stages with ReLU + 2×2 max-pool, then global average
    pooling and a single fully-connected classifier head."""

    def __init__(self, n_classes=4, seed=0):
        init_rng = np.random.default_rng(seed)
        # conv1: 1 input channel → 8 filters of 3×3, He init (fan-in 9)
        self.W1 = init_rng.normal(0, np.sqrt(2.0/9), (8, 1, 3, 3))
        self.b1 = np.zeros(8)
        # conv2: 8 channels → 16 filters of 3×3, He init (fan-in 72)
        self.W2 = init_rng.normal(0, np.sqrt(2.0/72), (16, 8, 3, 3))
        self.b2 = np.zeros(16)
        # head: 16 pooled features → n_classes logits (fan-in 16)
        self.Wfc = init_rng.normal(0, np.sqrt(2.0/16), (n_classes, 16))
        self.bfc = np.zeros(n_classes)

    def forward(self, X, training=True):
        """Run the network on X of shape (B, 1, 16, 16).

        Returns (logits, cache): logits is (B, n_classes); cache holds
        every intermediate tensor for inspection. `training` is currently
        unused (no dropout/batch-norm layers in this model).
        """
        pre1 = conv2d(X, self.W1, self.b1, stride=1, padding=1)
        act1 = relu(pre1)
        pool1 = max_pool(act1, 2, 2)            # → (B, 8, 8, 8)
        pre2 = conv2d(pool1, self.W2, self.b2, stride=1, padding=1)
        act2 = relu(pre2)
        pool2 = max_pool(act2, 2, 2)            # → (B, 16, 4, 4)
        feats = global_avg_pool(pool2)          # → (B, 16)
        logits = feats @ self.Wfc.T + self.bfc  # → (B, n_classes)
        return logits, (X, pre1, act1, pool1, pre2, act2, pool2, feats)

def softmax(Z):
    """Row-wise softmax with max-subtraction for numerical stability."""
    exps = np.exp(Z - Z.max(1, keepdims=True))
    return exps / exps.sum(1, keepdims=True)


def cross_entropy(logits, y):
    """Mean negative log-likelihood of the true classes (epsilon-guarded)."""
    probs = softmax(logits)
    picked = probs[np.arange(len(y)), y]
    return float(-np.log(picked + 1e-10).mean())

cnn=SimpleCNN(); print("CNN ready")
# Sanity check shapes: forward a dummy batch of 4 single-channel 16×16 images
x_test=rng.normal(0,1,(4,1,16,16))
logits,cache=cnn.forward(x_test)
print(f"Input: {x_test.shape} → logits: {logits.shape}")

3. Stage 2 — Training Loop

# Adam state: step counter plus per-parameter first/second moment buffers
adam_state={'t':0,'m':{},'v':{}}
# Hyperparameters: learning rate, mini-batch size, number of epochs
lr=3e-3; B=32; EPOCHS=80

def adam_update(params, grads, state, lr=0.001, b1=0.9, b2=0.999, eps=1e-8):
    """Apply one Adam step to every entry of `params`, in place.

    `state` carries the shared step counter 't' and moment dicts 'm'/'v';
    moments for unseen parameters are lazily zero-initialised. A parameter
    missing from `grads` is treated as having a zero gradient (its
    momentum still decays).
    """
    state['t'] += 1
    step = state['t']
    corr1 = 1 - b1 ** step  # bias-correction factors for this step
    corr2 = 1 - b2 ** step
    for name, value in params.items():
        if name not in state['m']:
            state['m'][name] = np.zeros_like(value)
            state['v'][name] = np.zeros_like(value)
        grad = grads.get(name, np.zeros_like(value))
        m = state['m'][name] = b1 * state['m'][name] + (1 - b1) * grad
        v = state['v'][name] = b2 * state['v'][name] + (1 - b2) * grad ** 2
        value -= lr / (np.sqrt(v / corr2) + eps) * (m / corr1)

train_losses,val_accs=[],[]

for epoch in range(EPOCHS):
    # Fresh shuffle every epoch; accumulate mean loss over mini-batches.
    perm=rng.permutation(len(X_tr)); ep_loss=0; nb=0
    for start in range(0,len(X_tr),B):
        idx=perm[start:start+B]; Xb=X_tr[idx]; yb=y_tr[idx]
        logits, cache = cnn.forward(Xb)
        loss=cross_entropy(logits,yb); ep_loss+=loss; nb+=1

        # Numerical gradients (small net — feasible for illustration)
        params={'W1':cnn.W1,'b1':cnn.b1,'W2':cnn.W2,'b2':cnn.b2,
                'Wfc':cnn.Wfc,'bfc':cnn.bfc}
        grads={}
        eps_fd=1e-4  # central-difference step size
        for pname,P in params.items():
            dP=np.zeros_like(P)
            # `flat` is a view into the live weight array, so each ±eps_fd
            # perturbation below temporarily changes the network, and the
            # final `+=eps_fd` restores the original value exactly.
            flat=P.ravel()
            # NOTE(review): only every (len//20)-th weight gets a central-
            # difference estimate per batch; all other entries keep a zero
            # gradient, so most weights move only via Adam momentum decay.
            for i in range(0,len(flat),max(1,len(flat)//20)):  # sparse sampling
                flat[i]+=eps_fd; logits2,_=cnn.forward(Xb); lp=cross_entropy(logits2,yb)
                flat[i]-=2*eps_fd; logits3,_=cnn.forward(Xb); lm=cross_entropy(logits3,yb)
                flat[i]+=eps_fd; dP.ravel()[i]=(lp-lm)/(2*eps_fd)
            grads[pname]=dP
        adam_update(params,grads,adam_state,lr=lr)

    # Val accuracy
    val_logits,_=cnn.forward(X_va,training=False)
    val_acc=float(np.mean(softmax(val_logits).argmax(1)==y_va))
    val_accs.append(val_acc); train_losses.append(ep_loss/nb)
    if (epoch+1)%20==0:
        print(f"Epoch {epoch+1:3d}: loss={ep_loss/nb:.4f}  val_acc={val_acc:.1%}")

4. Stage 3 — Evaluation and Visualisation

# Test accuracy: single forward pass over the held-out test split
te_logits,_=cnn.forward(X_te,training=False)
te_preds=softmax(te_logits).argmax(1)
te_acc=float(np.mean(te_preds==y_te))

# Panel 1: training loss curve; panel 2: validation accuracy vs. test line
fig,axes=plt.subplots(1,3,figsize=(14,4))
axes[0].plot(train_losses,color='#e74c3c',lw=2); axes[0].set_title('Train Loss'); axes[0].set_xlabel('Epoch')
axes[1].plot(val_accs,color='#3498db',lw=2)
axes[1].axhline(te_acc,color='#2ecc71',linestyle='--',lw=2,label=f'Test acc={te_acc:.1%}')
axes[1].set_title('Validation Accuracy'); axes[1].set_xlabel('Epoch'); axes[1].legend()

# Filter visualisation: tile the 8 learned 3×3 conv1 kernels into a 3×3 grid
# (one cell stays blank), each min-max normalised to [0, 1] for display.
filters=cnn.W1[:,0]  # 8 filters, 3×3 (single input channel)
grid_f=np.zeros((3*3,3*3))
for fi in range(min(8,9)):
    r,c=fi//3,fi%3
    f=filters[fi]
    grid_f[r*3:(r+1)*3, c*3:(c+1)*3]=(f-f.min())/(f.max()-f.min()+1e-8)
axes[2].imshow(grid_f,cmap='gray'); axes[2].set_title('Learned Conv1 filters (8×3×3)'); axes[2].axis('off')

plt.suptitle(f'ch335: CNN Image Classification  |  Test accuracy: {te_acc:.1%}',fontsize=12)
plt.tight_layout(); plt.savefig('ch335_results.png',dpi=120); plt.show()
print(f"Final test accuracy: {te_acc:.1%}")

5. Results & Reflection

What was built: A two-layer CNN with global average pooling, trained with Adam and numerical gradients on a synthetic 4-class image classification task.

What math made it possible:

  • 2D cross-correlation as the conv operation (ch314)

  • Max pooling for translation invariance and spatial compression (ch315)

  • Global average pooling replacing large FC layers (ch315)

  • Adam optimiser (ch312)

Extension challenges:

  1. Add a residual connection (ch316) between conv1 and conv2 (after 1×1 projection).

  2. Replace numerical gradients with analytical backprop through the conv layer.

  3. Extend the dataset to 8 classes with 8 distinct patterns and visualise the confusion matrix.