Skip to article frontmatterSkip to article content
Site not loading correctly?

This may be due to an incorrect BASE_URL configuration. See the MyST Documentation for reference.

ch338 — Project: Variational Autoencoder

0. Overview

Problem: Train a VAE on a synthetic 2D dataset of four Gaussian clusters, learn a structured latent space, and sample new points from the prior.

Concepts used: VAE theory (ch325), encoder-decoder (ch319), ELBO loss, reparameterisation trick, KL divergence (ch295), Adam (ch312), PCA (ch178).

Expected output: trained VAE with structured 2D latent space, interpolation demo, and reconstruction quality evaluation.

Difficulty: ★★★★☆ | Estimated time: 90 minutes

1. Setup

import numpy as np
import matplotlib.pyplot as plt

def relu(z):
    """Elementwise rectified linear unit: max(z, 0)."""
    return np.maximum(z, 0)
def sigmoid(z):
    """Numerically safe logistic function 1 / (1 + e^-z).

    Inputs are clipped to [-500, 500] so np.exp never overflows.
    """
    z_safe = np.clip(z, -500, 500)
    return 1 / (1 + np.exp(-z_safe))

rng = np.random.default_rng(42)

# Synthetic dataset: four isotropic Gaussian clusters, one per quadrant,
# each with covariance 0.3*I and 150 points.
centres = [(2, 2), (-2, 2), (-2, -2), (2, -2)]
n_per = 150
cluster_samples = [
    rng.multivariate_normal(c, 0.3 * np.eye(2), n_per) for c in centres
]
X_all = np.vstack(cluster_samples)
y_all = np.repeat(np.arange(4), n_per)

# Shuffle samples and labels together with one shared permutation.
perm = rng.permutation(len(X_all))
X_all, y_all = X_all[perm], y_all[perm]
print(f"Dataset: {X_all.shape}")

2. Stage 1 — VAE Implementation

class VAE2D:
    """Variational autoencoder for 2-D inputs with a 2-D Gaussian latent.

    Weights are stored as (out_features, in_features) matrices and applied
    as ``x @ W.T + b``. The encoder produces the mean and log-variance of
    the approximate posterior q(z|x); the decoder maps a latent sample back
    to input space.
    """

    def __init__(self, input_dim=2, latent_dim=2, hidden=64, seed=0):
        rng = np.random.default_rng(seed)

        def he_std(fan_in):
            # He-style initialisation scale for ReLU layers.
            return np.sqrt(2. / fan_in)

        # Encoder trunk: two ReLU layers.
        self.We1 = rng.normal(0, he_std(input_dim), (hidden, input_dim))
        self.be1 = np.zeros(hidden)
        self.We2 = rng.normal(0, he_std(hidden), (hidden, hidden))
        self.be2 = np.zeros(hidden)
        # Gaussian-parameter heads; the tiny std keeps the initial
        # posterior close to the N(0, I) prior.
        self.Wmu = rng.normal(0, 0.01, (latent_dim, hidden))
        self.bmu = np.zeros(latent_dim)
        self.Wlv = rng.normal(0, 0.01, (latent_dim, hidden))
        self.blv = np.zeros(latent_dim)
        # Decoder: mirror of the encoder trunk plus a linear output layer.
        self.Wd1 = rng.normal(0, he_std(latent_dim), (hidden, latent_dim))
        self.bd1 = np.zeros(hidden)
        self.Wd2 = rng.normal(0, he_std(hidden), (hidden, hidden))
        self.bd2 = np.zeros(hidden)
        self.Wd3 = rng.normal(0, he_std(hidden), (input_dim, hidden))
        self.bd3 = np.zeros(input_dim)
        self.latent_dim = latent_dim

    def encode(self, x):
        """Return (mu, log_var) of the approximate posterior q(z|x)."""
        h = relu(x @ self.We1.T + self.be1)
        h = relu(h @ self.We2.T + self.be2)
        mu = h @ self.Wmu.T + self.bmu
        log_var = h @ self.Wlv.T + self.blv
        return mu, log_var

    def decode(self, z):
        """Map latent points z back to input space (linear output)."""
        h = relu(z @ self.Wd1.T + self.bd1)
        h = relu(h @ self.Wd2.T + self.bd2)
        return h @ self.Wd3.T + self.bd3

    def forward(self, x, rng):
        """One stochastic pass: encode, reparameterise, decode, score.

        Returns (x_hat, mu, log_var, z, loss, recon, kl), where loss is the
        negative-ELBO surrogate recon + kl.
        """
        mu, log_var = self.encode(x)
        # Reparameterisation trick: z = mu + sigma * eps with eps ~ N(0, I),
        # which keeps the sampling step differentiable in (mu, log_var).
        std = np.exp(0.5 * log_var)
        eps = rng.standard_normal(mu.shape)
        z = mu + eps * std
        x_hat = self.decode(z)
        recon = np.mean((x - x_hat) ** 2)
        # Closed-form KL(q(z|x) || N(0, I)), averaged over batch and dims.
        kl = float(-0.5 * np.mean(1 + log_var - mu ** 2 - np.exp(log_var)))
        return x_hat, mu, log_var, z, recon + kl, recon, kl

# Compact model for the 2-D toy dataset: 32 hidden units, 2-D latent space.
vae = VAE2D(latent_dim=2, hidden=32, seed=0)
print("VAE initialised")

3. Stage 2 — Training

EPOCHS = 500
lr = 1e-3

# All trainable tensors; they are updated in place through .ravel() views,
# so the model sees every SGD step immediately.
all_params = {'We1': vae.We1, 'be1': vae.be1, 'We2': vae.We2, 'be2': vae.be2,
              'Wmu': vae.Wmu, 'bmu': vae.bmu, 'Wlv': vae.Wlv, 'blv': vae.blv,
              'Wd1': vae.Wd1, 'bd1': vae.bd1, 'Wd2': vae.Wd2, 'bd2': vae.bd2,
              'Wd3': vae.Wd3, 'bd3': vae.bd3}

elbo_losses = []; recon_losses = []; kl_losses = []

for epoch in range(EPOCHS):
    # Full batch (small dataset): one stochastic pass to log loss components.
    _, mu, lv, z, elbo, recon, kl = vae.forward(X_all, rng)
    elbo_losses.append(elbo); recon_losses.append(recon); kl_losses.append(kl)

    # Sparse central-difference gradient. forward() is stochastic, so both
    # sides of each difference must reuse the SAME latent noise (common
    # random numbers): with independent draws, the Monte-Carlo variation
    # between the two calls divided by 2*eps_fd swamps the true derivative.
    eps_fd = 1e-4
    for pname, P in all_params.items():
        flat = P.ravel()                  # view into P: writes update the model
        n_s = max(1, len(flat) // 25)     # perturb ~4% of entries per epoch
        idxs = rng.choice(len(flat), n_s, replace=False)
        for i in idxs:
            pair_seed = int(rng.integers(2 ** 31))  # shared noise for +/- sides
            flat[i] += eps_fd
            lp = vae.forward(X_all, np.random.default_rng(pair_seed))[4]
            flat[i] -= 2 * eps_fd
            lm = vae.forward(X_all, np.random.default_rng(pair_seed))[4]
            flat[i] += eps_fd             # restore parameter, then SGD step
            flat[i] -= lr * (lp - lm) / (2 * eps_fd)

    if (epoch+1)%100==0:
        print(f"Epoch {epoch+1:4d}: ELBO={elbo:.4f} recon={recon:.4f} kl={kl:.4f}")

4. Stage 3 — Latent Space Analysis

# Re-encode the full dataset to obtain latent means for plotting.
# NOTE(review): forward() also draws a fresh z sample; only mu_enc / lv_enc
# (the deterministic encoder outputs) are used below.
_,mu_enc,lv_enc,_,_,_,_=vae.forward(X_all,rng)

fig,axes=plt.subplots(1,4,figsize=(16,4))
colors=['#e74c3c','#3498db','#2ecc71','#f39c12']  # one colour per cluster label

# Panel 0 — training curves: total loss and its two ELBO components.
axes[0].plot(elbo_losses,label='ELBO',color='#2c3e50',lw=2)
axes[0].plot(recon_losses,label='Recon',color='#e74c3c',lw=1.5,linestyle='--')
axes[0].plot(kl_losses,label='KL',color='#3498db',lw=1.5,linestyle='--')
axes[0].set_title('ELBO training'); axes[0].legend(); axes[0].set_xlabel('Epoch')

# Panel 1 — raw data in input space, coloured by cluster.
for cls in range(4):
    m=y_all==cls
    axes[1].scatter(X_all[m,0],X_all[m,1],c=colors[cls],s=15,alpha=0.7)
axes[1].set_title('Input space (4 clusters)'); axes[1].set_aspect('equal')

# Panel 2 — encoded means in latent space; the dashed unit circle marks the
# 1-sigma contour of the N(0, I) prior the KL term pulls the posterior toward.
for cls in range(4):
    m=y_all==cls
    axes[2].scatter(mu_enc[m,0],mu_enc[m,1],c=colors[cls],s=15,alpha=0.7)
circle=plt.Circle((0,0),1,fill=False,color='black',linestyle='--',lw=1.5)
axes[2].add_patch(circle); axes[2].set_title('Latent space (encoded means)')
axes[2].set_aspect('equal')

# Panel 3 — generative check: decode 300 draws from the prior and overlay
# them on the real data.
z_prior=rng.standard_normal((300,2))
x_gen=vae.decode(z_prior)
axes[3].scatter(x_gen[:,0],x_gen[:,1],c='#9b59b6',s=15,alpha=0.5,label='Generated')
axes[3].scatter(X_all[:,0],X_all[:,1],c='gray',s=5,alpha=0.2,label='Real')
axes[3].set_title('Generated vs Real data'); axes[3].legend(fontsize=8)
axes[3].set_aspect('equal')

plt.suptitle('ch338: VAE — Input→Latent→Generate',fontsize=11)
plt.tight_layout(); plt.savefig('ch338_vae.png',dpi=120); plt.show()

# Linear interpolation in latent space: walk in 8 steps from one encoded
# point of cluster 0 to one of the diagonally opposite cluster 2 and decode
# each intermediate latent back to input space.
z1 = mu_enc[y_all == 0][0]
z2 = mu_enc[y_all == 2][0]
steps = np.linspace(0, 1, 8)
interp_zs = np.array([(1-t)*z1+t*z2 for t in steps])
interp_x = vae.decode(interp_zs)
print("Latent interpolation from cluster 0 to cluster 2:")
for i, (z, x) in enumerate(zip(interp_zs, interp_x)):
    print(f"  t={i/7:.2f}: z={z.round(2)} → x={x.round(2)}")

5. Results & Reflection

What was built: A VAE trained on 2D clustered data. The latent space organises the clusters near the prior $\mathcal{N}(0, I)$; new samples drawn from the prior decode to plausible data.

What math made it possible:

  • ELBO objective (ch325): the training loss is reconstruction error + KL — i.e. the negative ELBO — and minimising it jointly trains encoder and decoder

  • Reparameterisation trick (ch325): makes sampling differentiable

  • KL divergence to $\mathcal{N}(0, I)$ (ch295): has a closed-form expression, so no numerical integration is needed

  • Decode from prior at generation time: sampling from a known distribution (ch253)

Extension challenges:

  1. Increase latent dimension to 4; use PCA (ch178) to visualise it in 2D.

  2. Implement a $\beta$-VAE ($\beta = 4$) and observe whether the latent space becomes more disentangled.

  3. Train on a slightly larger dataset and measure reconstruction MSE vs KL weight tradeoff.