
1. From value-based to policy-based

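Q-learning learns a value function and derives the policy implicitly, $\pi(s) = \arg\max_a Q(s,a)$. This is straightforward only when the action space is small and discrete: the argmax is then a lookup over finitely many actions, whereas a continuous action space would require solving an inner optimisation problem at every step.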

Policy gradient methods instead parameterise the policy $\pi_\theta(a|s)$ directly and maximise the expected return by gradient ascent:

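$$J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta}[G(\tau)]$$

$$\nabla_\theta J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta}\left[\sum_t \nabla_\theta \log \pi_\theta(a_t|s_t) \cdot G_t\right]$$

where $\tau = (s_0, a_0, r_0, s_1, \dots)$ is a trajectory, $G(\tau)$ is its discounted return, and $G_t = \sum_{k \ge t} \gamma^{k-t} r_k$ is the discounted return from step $t$ onward (the reward-to-go).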
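The reward-to-go values satisfy the recursion $G_t = r_t + \gamma G_{t+1}$ (with $G = 0$ after the last step), so all of them can be computed in one backward pass over an episode. A minimal sketch; the function name `discounted_returns` is our own, not a library API:

```python
import numpy as np


def discounted_returns(rewards, gamma):
    """Reward-to-go G_t = sum_{k>=t} gamma**(k-t) * r_k for every step t.

    Uses the backward recursion G_t = r_t + gamma * G_{t+1}, in O(T) time.
    """
    returns = np.zeros(len(rewards))
    G = 0.0
    for t in reversed(range(len(rewards))):
        G = rewards[t] + gamma * G
        returns[t] = G
    return returns


print(discounted_returns([1.0, 1.0, 1.0], gamma=0.5))  # G_0 = 1.75, G_1 = 1.5, G_2 = 1.0
```

Filling a preallocated array avoids repeated `list.insert(0, ...)`, which is quadratic in episode length; for episodes of a few hundred steps either approach is fast enough.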

The expression for $\nabla_\theta J(\theta)$ is the REINFORCE estimator (Williams, 1992). It rests on the log-derivative trick, also known as the score-function estimator:

$$\nabla_\theta \mathbb{E}_{x \sim p_\theta}[f(x)] = \mathbb{E}_{x \sim p_\theta}\left[f(x)\, \nabla_\theta \log p_\theta(x)\right]$$
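The identity needs only $\nabla_\theta p_\theta = p_\theta \nabla_\theta \log p_\theta$ (valid wherever $p_\theta > 0$) and the assumption that differentiation and integration can be exchanged. Applied to the distribution over trajectories, writing $P$ for the environment's transition probabilities, the initial-state and dynamics terms do not depend on $\theta$ and drop out, which is why the estimator needs only $\nabla_\theta \log \pi_\theta$ and no model of the environment:

$$
\begin{aligned}
\nabla_\theta \mathbb{E}_{x \sim p_\theta}[f(x)]
  &= \int f(x)\, \nabla_\theta p_\theta(x)\, dx
   = \int f(x)\, p_\theta(x)\, \nabla_\theta \log p_\theta(x)\, dx
   = \mathbb{E}_{x \sim p_\theta}\left[f(x)\, \nabla_\theta \log p_\theta(x)\right] \\
\nabla_\theta \log p_\theta(\tau)
  &= \nabla_\theta \Big[\log p(s_0) + \sum_t \log \pi_\theta(a_t|s_t) + \sum_t \log P(s_{t+1}|s_t, a_t)\Big]
   = \sum_t \nabla_\theta \log \pi_\theta(a_t|s_t)
\end{aligned}
$$

Setting $f = G(\tau)$ gives the policy gradient. Dropping the rewards received before step $t$ from the weight on $\nabla_\theta \log \pi_\theta(a_t|s_t)$ leaves the expectation unchanged, because an action cannot affect earlier rewards; what remains is the reward-to-go, up to a factor $\gamma^t$ that implementations (including the code below) commonly omit.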

(Expected value: ch249. Log derivative: ch205. Gradient ascent: ch212.)
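A quick numerical check of the log-derivative trick on a case with a known answer. For $x \sim \mathcal{N}(\mu, \sigma^2)$ and $f(x) = x^2$, $\mathbb{E}[f] = \mu^2 + \sigma^2$, so the exact gradient with respect to $\mu$ is $2\mu$, and the score is $\nabla_\mu \log p = (x - \mu)/\sigma^2$. The Gaussian, the values of $\mu$ and $\sigma$, and the sample size are illustrative choices of ours. Only samples of $f$ are used, never its derivative, which is what lets REINFORCE work with returns that are not differentiable in $\theta$:

```python
import numpy as np


def score_function_grad(f, mu, sigma, n, rng):
    """Monte Carlo estimate of d/dmu E_{x ~ N(mu, sigma^2)}[f(x)] via the log-derivative trick."""
    x = rng.normal(mu, sigma, size=n)
    score = (x - mu) / sigma**2          # d/dmu log N(x; mu, sigma^2)
    return np.mean(f(x) * score)         # uses only values of f, never f'


rng = np.random.default_rng(0)
mu, sigma = 1.5, 1.0
est = score_function_grad(lambda x: x**2, mu, sigma, 200_000, rng)
exact = 2 * mu                           # d/dmu (mu^2 + sigma^2)
print(f"score-function estimate: {est:.3f}   exact: {exact:.3f}")
```

A short calculation gives a per-sample standard deviation of about 7 here, so with 200,000 samples the standard error is roughly 0.016. The average is accurate, but individual samples are very noisy, which is the variance problem addressed by baselines.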

import numpy as np
import matplotlib.pyplot as plt


def softmax(z):
    z_s = z - z.max()
    e = np.exp(z_s)
    return e / e.sum()

def relu(z): return np.maximum(0, z)


class PolicyNetwork:
    """Simple MLP policy for discrete action spaces."""

    def __init__(self, state_dim: int, n_actions: int,
                 hidden: int = 32, seed: int = 0):
        rng = np.random.default_rng(seed)
        self.W1 = rng.normal(0, np.sqrt(2.0/state_dim), (hidden, state_dim))
        self.b1 = np.zeros(hidden)
        self.W2 = rng.normal(0, np.sqrt(2.0/hidden), (n_actions, hidden))
        self.b2 = np.zeros(n_actions)

    def forward(self, s: np.ndarray) -> np.ndarray:
        """Returns action probabilities."""
        h = relu(self.W1 @ s + self.b1)
        logits = self.W2 @ h + self.b2
        return softmax(logits)

    def log_prob(self, s: np.ndarray, a: int) -> float:
        probs = self.forward(s)
        return float(np.log(probs[a] + 1e-10))

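    def update(self, grads: dict, lr: float):
        """Gradient ASCENT step: param <- param + lr * grad, applied in place."""
        for k, g in grads.items():
            param = getattr(self, k)
            param += lr * g  # modifies the stored array in place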


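class SimpleCartPole:
    """Simplified CartPole-like environment with linearised dynamics.
    State: (cart position, cart velocity, pole angle, pole angular velocity)."""
    def __init__(self, seed: int = 0):
        self.rng = np.random.default_rng(seed)  # seeded, so episodes are reproducible
        self.reset()

    def reset(self):
        self.x = np.zeros(4)
        self.x[:2] = self.rng.normal(0, 0.05, 2)  # small random cart offset and velocity
        self.t = 0
        return self.x.copy()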

    def step(self, a: int):
        # Simplified linearised dynamics
        force = 10.0 * (a - 0.5)  # -5 or +5
        self.x[0] += 0.02 * self.x[1]
        self.x[1] += 0.02 * (force - 0.1*self.x[1])
        self.x[2] += 0.02 * self.x[3]
        self.x[3] += 0.02 * (force*0.1 + 0.5*self.x[2] - 0.05*self.x[3])
        self.t += 1
        done = (abs(self.x[2]) > 0.3 or abs(self.x[0]) > 2.0 or self.t >= 200)
        reward = 1.0 if not done else 0.0
        return self.x.copy(), reward, done


def collect_trajectory(policy, env, rng):
    """Run one episode. Returns (states, actions, returns)."""
    s = env.reset(); done = False
    states, actions, rewards = [], [], []
    while not done:
        probs = policy.forward(s)
        a = rng.choice(len(probs), p=probs)
        s_next, r, done = env.step(a)
        states.append(s); actions.append(a); rewards.append(r)
        s = s_next

    # Compute discounted returns
    gamma = 0.99; G = 0.0; returns = []
    for r in reversed(rewards):
        G = r + gamma * G; returns.insert(0, G)

    return states, actions, np.array(returns)


def reinforce_update(policy, states, actions, returns, lr=0.01):
    """REINFORCE gradient update."""
    baseline = returns.mean()  # subtract mean as control variate
    dW1 = np.zeros_like(policy.W1)
    dW2 = np.zeros_like(policy.W2)
    db1 = np.zeros_like(policy.b1)
    db2 = np.zeros_like(policy.b2)

    for s, a, G in zip(states, actions, returns):
        advantage = G - baseline
        probs = policy.forward(s)
        h = relu(policy.W1 @ s + policy.b1)

        # Score-function gradient: d/dθ log π(a|s) * advantage.
        # For a softmax policy, d log π(a|s) / d logits = onehot(a) - probs.
        d_logits = -probs; d_logits[a] += 1.0  # (n_actions,)
        dW2 += advantage * np.outer(d_logits, h)
        db2 += advantage * d_logits
        d_h = policy.W2.T @ d_logits
        d_h *= (h > 0)  # relu gradient
        dW1 += advantage * np.outer(d_h, s)
        db1 += advantage * d_h

    n = len(states)
    policy.update({'W1': dW1/n, 'W2': dW2/n, 'b1': db1/n, 'b2': db2/n}, lr)


rng = np.random.default_rng(42)
env = SimpleCartPole()
policy = PolicyNetwork(state_dim=4, n_actions=2, hidden=32)

episode_lengths = []
for ep in range(600):
    states, actions, returns = collect_trajectory(policy, env, rng)
    reinforce_update(policy, states, actions, returns, lr=0.003)
    episode_lengths.append(len(states))

window = 30
smoothed = np.convolve(episode_lengths, np.ones(window)/window, mode='valid')

fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(episode_lengths, alpha=0.3, color='#3498db', lw=1)
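# 'valid' convolution drops the first window-1 episodes; align each average with its last episode
ax.plot(np.arange(window - 1, len(episode_lengths)), smoothed,
        color='#e74c3c', lw=2, label=f'{window}-ep moving average')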
ax.set_xlabel('Episode'); ax.set_ylabel('Episode length (survival steps)')
ax.set_title('REINFORCE on simplified CartPole')
ax.legend()
plt.tight_layout()
plt.savefig('ch330_policy_gradients.png', dpi=120)
plt.show()
print(f"Early (ep 1-50):    avg length = {np.mean(episode_lengths[:50]):.1f}")
print(f"Late  (ep 551-600): avg length = {np.mean(episode_lengths[-50:]):.1f}")
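`reinforce_update` backpropagates through the softmax using $\nabla_z \log \operatorname{softmax}(z)_a = e_a - \operatorname{softmax}(z)$, where $e_a$ is the one-hot vector for action $a$; that is the line `d_logits = -probs; d_logits[a] += 1.0`. A standalone central finite-difference check of that identity with random logits (the helper names are ours, for illustration):

```python
import numpy as np


def log_softmax(z):
    z_s = z - z.max()                      # shift for numerical stability
    return z_s - np.log(np.exp(z_s).sum())


def analytic_grad(z, a):
    """d log softmax(z)[a] / dz = onehot(a) - softmax(z), as used in reinforce_update."""
    g = -np.exp(log_softmax(z))
    g[a] += 1.0
    return g


def numeric_grad(z, a, eps=1e-6):
    """Central finite differences of log softmax(z)[a] with respect to each logit."""
    g = np.zeros_like(z)
    for i in range(len(z)):
        dz = np.zeros_like(z)
        dz[i] = eps
        g[i] = (log_softmax(z + dz)[a] - log_softmax(z - dz)[a]) / (2 * eps)
    return g


rng = np.random.default_rng(1)
z = rng.normal(size=3)                     # random logits for 3 actions
a = 2
print(analytic_grad(z, a))
print(numeric_grad(z, a))
```

The same kind of check can be extended to `W1`, `b1`, `W2` and `b2` by perturbing one parameter entry at a time.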

2. Variance reduction: baselines

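The REINFORCE estimator has high variance: every log-probability gradient is scaled by a return $G$ that can differ greatly from one trajectory to the next. Subtracting a baseline $b(s)$ that does not depend on the action $a$ can reduce this variance without introducing bias, because the score has zero mean, $\mathbb{E}_{a \sim \pi_\theta}[\nabla_\theta \log \pi_\theta(a|s)] = \nabla_\theta \sum_a \pi_\theta(a|s) = \nabla_\theta 1 = 0$: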

$$\nabla_\theta J = \mathbb{E}\left[\nabla_\theta \log \pi_\theta(a|s) \cdot (G - b(s))\right]$$

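The canonical choice is $b(s) = V(s)$. The weight $G - V(s)$ is then an estimate of the advantage $A(s,a) = Q(s,a) - V(s)$: how much better action $a$ is than the policy's average action from state $s$. The training code above uses a simpler constant baseline, the mean of the episode's per-step returns.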
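A small experiment makes the bias and variance claims concrete. It assumes a one-step, two-action bandit (not CartPole) with a softmax policy at logits $\theta = 0$ and noisy rewards with means 100 and 101, so the expected reward is $J = 100.5$ and the exact gradient with respect to the logits is $\pi_a(\bar r_a - J) = (-0.25, 0.25)$. The sketch compares single-sample score-function estimates with $b = 0$ and with $b = J$, which plays the role of $V$ in this one-state problem:

```python
import numpy as np


def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()


def score_estimates(theta, mean_reward, baseline, n, rng):
    """n single-sample REINFORCE gradient estimates for a one-step softmax bandit."""
    probs = softmax(theta)
    actions = rng.choice(len(probs), size=n, p=probs)
    rewards = mean_reward[actions] + rng.normal(0.0, 1.0, size=n)  # unit reward noise
    score = np.eye(len(probs))[actions] - probs       # grad of log pi(a) wrt logits
    return score * (rewards - baseline)[:, None]      # shape (n, n_actions)


rng = np.random.default_rng(0)
theta = np.zeros(2)
mean_reward = np.array([100.0, 101.0])
probs = softmax(theta)
J = probs @ mean_reward                               # expected reward (= V here)
true_grad = probs * (mean_reward - J)                 # exact gradient of J wrt logits

g_plain = score_estimates(theta, mean_reward, 0.0, 100_000, rng)
g_base = score_estimates(theta, mean_reward, J, 100_000, rng)

print("exact gradient:    ", true_grad)
print("no baseline:  mean", g_plain.mean(axis=0), "variance", g_plain.var(axis=0))
print("baseline b=V: mean", g_base.mean(axis=0), "variance", g_base.var(axis=0))
```

Both sample means should agree with the exact gradient, but by a short calculation the per-component variance is about 2500 without the baseline versus 0.25 with it: without a baseline every sample is scaled by a reward near 100, with it only by the small deviation $r - b$.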


3. PPO — Proximal Policy Optimisation

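PPO (Schulman et al., 2017) clips the probability ratio between the new and old policies, so that a single update cannot move the policy far from the one that collected the data and cause a destructively large step: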

$$\mathcal{L}^{\text{CLIP}}(\theta) = \mathbb{E}\left[\min\left(r_t(\theta) A_t,\; \text{clip}(r_t(\theta), 1-\varepsilon, 1+\varepsilon) A_t\right)\right]$$

where $r_t(\theta) = \pi_\theta(a_t|s_t) / \pi_{\theta_{\text{old}}}(a_t|s_t)$ is the probability ratio, $A_t$ is an advantage estimate, and $\varepsilon$ is a small clip-range hyperparameter.
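A minimal NumPy sketch of $\mathcal{L}^{\text{CLIP}}$ that only evaluates the objective; a practical implementation would differentiate it with an autodiff library and take several gradient steps per batch. It assumes the advantages $A_t$ are already computed and uses $\varepsilon = 0.2$ as an illustrative default; the function names are ours. Because of the minimum, the clip only removes the incentive to push $r_t$ further in the direction the advantage favours; moves in the unfavourable direction keep their full, unclipped penalty:

```python
import numpy as np


def clipped_surrogate_terms(ratio, advantages, eps=0.2):
    """Per-sample min(r * A, clip(r, 1 - eps, 1 + eps) * A)."""
    unclipped = ratio * advantages
    clipped = np.clip(ratio, 1.0 - eps, 1.0 + eps) * advantages
    return np.minimum(unclipped, clipped)


def ppo_clip_objective(logp_new, logp_old, advantages, eps=0.2):
    """Batch estimate of L^CLIP (to be maximised) from log-probs of the taken actions."""
    ratio = np.exp(logp_new - logp_old)   # r_t(theta) = pi_new / pi_old
    return clipped_surrogate_terms(ratio, advantages, eps).mean()


# Ratios on both sides of the clip range, with positive and negative advantages.
ratio = np.array([0.5, 1.0, 1.5, 1.5, 0.5])
adv = np.array([1.0, 1.0, 1.0, -1.0, -1.0])
print(clipped_surrogate_terms(ratio, adv))   # values 0.5, 1.0, 1.2, -1.5, -0.8
```

With $A > 0$ a ratio of 1.5 earns only $1.2A$, and with $A < 0$ a ratio of 0.5 is charged only $0.8A$, so going past the clip boundary in the favourable direction gains nothing. The other two cases (0.5 with $A > 0$, 1.5 with $A < 0$) are left unclipped.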

PPO is one of the most widely used deep RL algorithms. It is also the RL optimiser in the reinforcement learning from human feedback (RLHF) pipeline used to align large language models such as InstructGPT and ChatGPT.


4. Summary

  • Policy gradient: directly optimise $J(\theta)$ by gradient ascent using the log-derivative trick.

  • REINFORCE: collect trajectories, weight log-probs by returns, update parameters.

  • Baseline subtraction: reduces variance without bias. Advantage = Q - V.

  • PPO: clips the policy probability ratio so each update stays close to the policy that collected the data. Widely used for RLHF.


5. Forward and backward references

Used here: expected value (ch249), log derivatives (ch205), gradient ascent (ch212), Q-learning foundations (ch329).

This will reappear in ch339 — Project: RL CartPole, where a policy network is trained with REINFORCE on OpenAI Gym’s CartPole environment.