0. Overview¶
Problem: Train a VAE on a synthetic 2D dataset of four Gaussian clusters, learn a structured latent space, and sample new points from the prior.
Concepts used: VAE theory (ch325), encoder-decoder (ch319), ELBO loss, reparameterisation trick, KL divergence (ch295), Adam (ch312), PCA (ch178).
Expected output: trained VAE with structured 2D latent space, interpolation demo, and reconstruction quality evaluation.
Difficulty: ★★★★☆ | Estimated time: 90 minutes
1. Setup¶
import numpy as np
import matplotlib.pyplot as plt
def relu(z):
    """Elementwise rectified linear unit: max(0, z)."""
    return np.maximum(0, z)

def sigmoid(z):
    """Numerically safe logistic; the clip keeps exp() from overflowing."""
    return 1 / (1 + np.exp(-np.clip(z, -500, 500)))

rng = np.random.default_rng(42)

# Dataset: 4 isotropic Gaussian clusters in 2D, one per quadrant,
# 150 points each, then shuffled together with their labels.
centres = [(2, 2), (-2, 2), (-2, -2), (2, -2)]
n_per = 150
X_all = np.vstack([rng.multivariate_normal(c, 0.3 * np.eye(2), n_per) for c in centres])
y_all = np.repeat(np.arange(4), n_per)
perm = rng.permutation(len(X_all))
X_all, y_all = X_all[perm], y_all[perm]
print(f"Dataset: {X_all.shape}")
# 2. Stage 1 — VAE Implementation
class VAE2D:
    """Variational autoencoder mapping 2D inputs to a 2D Gaussian latent.

    Encoder: x -> relu -> relu -> (mu, log-variance).
    Decoder: z -> relu -> relu -> x_hat (linear output head).
    """

    def __init__(self, input_dim=2, latent_dim=2, hidden=64, seed=0):
        init_rng = np.random.default_rng(seed)

        def he_std(fan_in):
            # He initialisation scale for relu layers.
            return np.sqrt(2.0 / fan_in)

        # Encoder trunk.
        self.We1 = init_rng.normal(0, he_std(input_dim), (hidden, input_dim))
        self.be1 = np.zeros(hidden)
        self.We2 = init_rng.normal(0, he_std(hidden), (hidden, hidden))
        self.be2 = np.zeros(hidden)
        # Small-variance heads for mu and log-variance keep the initial
        # posterior close to the N(0, I) prior.
        self.Wmu = init_rng.normal(0, 0.01, (latent_dim, hidden))
        self.bmu = np.zeros(latent_dim)
        self.Wlv = init_rng.normal(0, 0.01, (latent_dim, hidden))
        self.blv = np.zeros(latent_dim)
        # Decoder.
        self.Wd1 = init_rng.normal(0, he_std(latent_dim), (hidden, latent_dim))
        self.bd1 = np.zeros(hidden)
        self.Wd2 = init_rng.normal(0, he_std(hidden), (hidden, hidden))
        self.bd2 = np.zeros(hidden)
        self.Wd3 = init_rng.normal(0, he_std(hidden), (input_dim, hidden))
        self.bd3 = np.zeros(input_dim)
        self.latent_dim = latent_dim

    def encode(self, x):
        """Return (mu, log_var) of the approximate posterior q(z|x)."""
        act = relu(self.be1 + x @ self.We1.T)
        act = relu(self.be2 + act @ self.We2.T)
        mu = self.bmu + act @ self.Wmu.T
        log_var = self.blv + act @ self.Wlv.T
        return mu, log_var

    def decode(self, z):
        """Map latent codes z back to input space (linear output)."""
        act = relu(self.bd1 + z @ self.Wd1.T)
        act = relu(self.bd2 + act @ self.Wd2.T)
        return self.bd3 + act @ self.Wd3.T

    def forward(self, x, rng):
        """One full pass: encode, reparameterise, decode, score the ELBO.

        Returns (x_hat, mu, log_var, z, elbo_loss, recon, kl), where
        elbo_loss = recon + kl is the quantity being minimised.
        """
        mu, log_var = self.encode(x)
        # Reparameterisation trick: z = mu + eps * std keeps the sample
        # differentiable with respect to mu and log_var.
        noise = rng.standard_normal(mu.shape)
        z = mu + noise * np.exp(0.5 * log_var)
        x_hat = self.decode(z)
        recon = np.mean((x - x_hat) ** 2)
        # Closed-form KL( q(z|x) || N(0, I) ), averaged over batch and dims.
        kl = float(-0.5 * np.mean(1 + log_var - mu ** 2 - np.exp(log_var)))
        return x_hat, mu, log_var, z, recon + kl, recon, kl
# Instantiate a small VAE: 2D data -> 32 hidden units -> 2D latent space.
vae = VAE2D(latent_dim=2, hidden=32, seed=0)
print("VAE initialised")
# 3. Stage 2 — Training
EPOCHS = 500
lr = 1e-3
# Every trainable array, updated in place via their .ravel() views below.
all_params = {'We1': vae.We1, 'be1': vae.be1, 'We2': vae.We2, 'be2': vae.be2,
              'Wmu': vae.Wmu, 'bmu': vae.bmu, 'Wlv': vae.Wlv, 'blv': vae.blv,
              'Wd1': vae.Wd1, 'bd1': vae.bd1, 'Wd2': vae.Wd2, 'bd2': vae.bd2,
              'Wd3': vae.Wd3, 'bd3': vae.bd3}
elbo_losses = []
recon_losses = []
kl_losses = []
for epoch in range(EPOCHS):
    # Full batch (small dataset)
    _, mu, lv, z, elbo, recon, kl = vae.forward(X_all, rng)
    elbo_losses.append(elbo)
    recon_losses.append(recon)
    kl_losses.append(kl)
    # Sparse numerical gradient: central differences on a random ~4% of
    # each parameter's entries per epoch.
    eps_fd = 1e-4
    for pname, P in all_params.items():
        flat = P.ravel()
        n_s = max(1, len(flat) // 25)
        idxs = rng.choice(len(flat), n_s, replace=False)
        for i in idxs:
            # Common random numbers: evaluate both perturbed losses with
            # the SAME noise draw. With independent draws the sampling
            # noise in (lp - lm) swamps the O(eps_fd) signal and the
            # finite-difference "gradient" is mostly noise.
            fd_seed = int(rng.integers(2**32))
            flat[i] += eps_fd
            _, _, _, _, lp, _, _ = vae.forward(X_all, np.random.default_rng(fd_seed))
            flat[i] -= 2 * eps_fd
            _, _, _, _, lm, _, _ = vae.forward(X_all, np.random.default_rng(fd_seed))
            flat[i] += eps_fd  # restore original value
            flat[i] -= lr * (lp - lm) / (2 * eps_fd)  # SGD step
    if (epoch + 1) % 100 == 0:
        print(f"Epoch {epoch+1:4d}: ELBO={elbo:.4f} recon={recon:.4f} kl={kl:.4f}")
# 4. Stage 3 — Latent Space Analysis
# Encode the full dataset once more to inspect the learned latent layout.
_, mu_enc, lv_enc, _, _, _, _ = vae.forward(X_all, rng)
fig, axes = plt.subplots(1, 4, figsize=(16, 4))
colors = ['#e74c3c', '#3498db', '#2ecc71', '#f39c12']
# ELBO components over training
axes[0].plot(elbo_losses, label='ELBO', color='#2c3e50', lw=2)
axes[0].plot(recon_losses, label='Recon', color='#e74c3c', lw=1.5, linestyle='--')
axes[0].plot(kl_losses, label='KL', color='#3498db', lw=1.5, linestyle='--')
axes[0].set_title('ELBO training'); axes[0].legend(); axes[0].set_xlabel('Epoch')
# Input space, coloured by cluster label
for cls in range(4):
    m = y_all == cls
    axes[1].scatter(X_all[m, 0], X_all[m, 1], c=colors[cls], s=15, alpha=0.7)
axes[1].set_title('Input space (4 clusters)'); axes[1].set_aspect('equal')
# Latent space: encoded means, with the prior's unit circle for reference
for cls in range(4):
    m = y_all == cls
    axes[2].scatter(mu_enc[m, 0], mu_enc[m, 1], c=colors[cls], s=15, alpha=0.7)
circle = plt.Circle((0, 0), 1, fill=False, color='black', linestyle='--', lw=1.5)
axes[2].add_patch(circle); axes[2].set_title('Latent space (encoded means)')
axes[2].set_aspect('equal')
# Generation: decode fresh samples from the N(0, I) prior and overlay data
z_prior = rng.standard_normal((300, 2))
x_gen = vae.decode(z_prior)
axes[3].scatter(x_gen[:, 0], x_gen[:, 1], c='#9b59b6', s=15, alpha=0.5, label='Generated')
axes[3].scatter(X_all[:, 0], X_all[:, 1], c='gray', s=5, alpha=0.2, label='Real')
axes[3].set_title('Generated vs Real data'); axes[3].legend(fontsize=8)
axes[3].set_aspect('equal')
plt.suptitle('ch338: VAE — Input→Latent→Generate', fontsize=11)
plt.tight_layout(); plt.savefig('ch338_vae.png', dpi=120); plt.show()
# Interpolation: walk a straight line in latent space between two encodings
# (8 steps, so t = i/7 below matches linspace(0, 1, 8)).
z1 = mu_enc[y_all == 0][0]
z2 = mu_enc[y_all == 2][0]
interp_zs = np.array([(1 - t) * z1 + t * z2 for t in np.linspace(0, 1, 8)])
interp_x = vae.decode(interp_zs)
print("Latent interpolation from cluster 0 to cluster 2:")
for i, (z, x) in enumerate(zip(interp_zs, interp_x)):
    print(f" t={i/7:.2f}: z={z.round(2)} → x={x.round(2)}")
# 5. Results & Reflection
What was built: A VAE trained on 2D clustered data. The latent space organises the clusters near the origin, consistent with the $\mathcal{N}(0, I)$ prior; new samples from the prior decode to plausible data.
What math made it possible:
ELBO = reconstruction + KL (ch325): jointly trains encoder and decoder
Reparameterisation trick (ch325): makes sampling differentiable
KL to $\mathcal{N}(0, I)$ (ch295): closed-form expression, no integration needed
Decode from prior at generation time: sampling from a known distribution (ch253)
Extension challenges:
Increase latent dimension to 4; use PCA (ch178) to visualise it in 2D.
Implement $\beta$-VAE ($\beta > 1$) and observe whether the latent space becomes more disentangled.
Train on a slightly larger dataset and measure reconstruction MSE vs KL weight tradeoff.