
ch331 — Model Interpretability

1. Why interpretability matters

A trained neural network is a function $f_\theta: x \to y$. We often need to understand:

  • What input features drive predictions? (for debugging and trust)

  • What has the model learned? (for scientific discovery)

  • When will the model fail? (for safety)

Interpretability is not one problem — it is a family of questions with different tools.


2. Gradient-based attribution

Saliency maps: $\frac{\partial f(x)}{\partial x_i}$ measures how much the output changes if input $i$ changes. This is just the gradient of the output w.r.t. the input, computed by standard backprop.

Integrated Gradients (Sundararajan et al., 2017): integrate gradients along a straight line from a baseline $x'$ to the input $x$:

$$\text{IG}_i(x) = (x_i - x'_i) \cdot \int_0^1 \frac{\partial f(x' + \alpha(x - x'))}{\partial x_i}\, d\alpha$$

Integrated Gradients satisfies desirable axioms (completeness: the attributions sum to $f(x) - f(x')$; sensitivity to relevant features) and is more reliable than raw gradients.

(Gradient computation: ch306. Integration: ch221.)
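The completeness axiom is easy to verify numerically. A minimal self-contained check on a toy sigmoid model (a hypothetical example chosen for illustration, not the chapter's MLP): the IG attributions should sum to $f(x) - f(x')$.

```python
import numpy as np

W_TOY = np.array([1.0, -2.0, 0.5])  # fixed weights of the toy model

def f(x):
    # Toy smooth scalar model: sigmoid of a linear score
    return 1.0 / (1.0 + np.exp(-(W_TOY @ x)))

def grad_f(x):
    # Analytic input gradient of the sigmoid score
    s = f(x)
    return s * (1 - s) * W_TOY

def integrated_gradients_toy(x, x_base, n_steps=200):
    # Midpoint-rule approximation of the path integral
    alphas = (np.arange(n_steps) + 0.5) / n_steps
    grads = np.array([grad_f(x_base + a * (x - x_base)) for a in alphas])
    return (x - x_base) * grads.mean(axis=0)

x = np.array([1.0, 0.5, -1.0])
x_base = np.zeros(3)
ig = integrated_gradients_toy(x, x_base)
print(ig.sum(), f(x) - f(x_base))  # the two values should nearly match
```

With enough integration steps the two numbers agree to several decimal places; the gap shrinks as `n_steps` grows, which is a useful sanity check for any IG implementation.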

import numpy as np
import matplotlib.pyplot as plt


def relu(z): return np.maximum(0, z)
def relu_grad(z): return (z > 0).astype(float)
def sigmoid(z): return 1/(1+np.exp(-np.clip(z,-500,500)))


class SimpleMLPInterp:
    """MLP with forward + input-gradient computation."""

    def __init__(self, layer_sizes: list, seed: int = 0):
        rng = np.random.default_rng(seed)
        self.params = []
        for i in range(len(layer_sizes)-1):
            fi, fo = layer_sizes[i], layer_sizes[i+1]
            W = rng.normal(0, np.sqrt(2.0/fi), (fo, fi))
            b = np.zeros(fo)
            self.params.append((W, b))

    def forward(self, x: np.ndarray) -> tuple:
        a = x; cache = [x]
        for i, (W, b) in enumerate(self.params):
            z = W @ a + b
            a = sigmoid(z) if i==len(self.params)-1 else relu(z)
            cache.extend([z, a])
        return a, cache

    def input_gradient(self, x: np.ndarray) -> np.ndarray:
        """Compute dOutput/dInput using backprop."""
        out, cache = self.forward(x)
        # Backward from scalar output
        n_layers = len(self.params)
        dA = np.array([1.0])  # gradient of scalar output
        for l in range(n_layers-1, -1, -1):
            z = cache[1 + l*2]
            W, _ = self.params[l]
            if l == n_layers-1:
                dZ = dA * sigmoid(z) * (1-sigmoid(z))
            else:
                dZ = dA * relu_grad(z)
            dA = W.T @ dZ
        return dA  # gradient w.r.t. input


def integrated_gradients(model, x: np.ndarray, x_baseline: np.ndarray,
                          n_steps: int = 50) -> np.ndarray:
    """IG attribution: integrate gradient from baseline to input."""
    alphas = np.linspace(0, 1, n_steps)
    grads = []
    for alpha in alphas:
        x_interp = x_baseline + alpha * (x - x_baseline)
        g = model.input_gradient(x_interp)
        grads.append(g)
    avg_grad = np.mean(grads, axis=0)
    ig = (x - x_baseline) * avg_grad
    return ig


# Train a simple model on synthetic data where features 0 and 2 are informative
rng = np.random.default_rng(42)
n = 300; d = 6
X = rng.normal(0, 1, (n, d))
# True function: depends on features 0 and 2 only
y = (X[:,0] + 2*X[:,2] > 0).astype(float)

model = SimpleMLPInterp([d, 32, 16, 1], seed=0)

# Simple training loop: manual backprop with BCE loss
lr = 0.05
for epoch in range(200):
    for i in rng.permutation(n):
        xi = X[i]; yi = y[i]
        out, cache = model.forward(xi)
        # Sigmoid output + BCE loss gives dL/dz = out - y at the last layer
        dZ = np.array([out[0] - yi])
        for l in range(len(model.params) - 1, -1, -1):
            W, b = model.params[l]
            a_prev = cache[2 * l]          # activation feeding layer l
            dW = np.outer(dZ, a_prev)
            db = dZ.copy()
            dA_prev = W.T @ dZ
            W -= lr * dW
            b -= lr * db
            if l > 0:
                dZ = dA_prev * relu_grad(cache[2 * l - 1])

# Compute attributions for a test sample
x_test = rng.normal(0, 1, d)
x_baseline = np.zeros(d)

raw_grad = model.input_gradient(x_test)
ig_attrs = integrated_gradients(model, x_test, x_baseline)

fig, axes = plt.subplots(1, 2, figsize=(12, 4))
feature_names = [f'Feature {i}' for i in range(d)]
colors_raw = ['#e74c3c' if v > 0 else '#3498db' for v in raw_grad]
colors_ig  = ['#e74c3c' if v > 0 else '#3498db' for v in ig_attrs]

axes[0].bar(feature_names, raw_grad, color=colors_raw)
axes[0].set_title('Raw gradient attribution'); axes[0].tick_params(axis='x', rotation=45)
axes[0].axhline(0, color='black', lw=0.8)

axes[1].bar(feature_names, ig_attrs, color=colors_ig)
axes[1].set_title('Integrated Gradients attribution'); axes[1].tick_params(axis='x', rotation=45)
axes[1].axhline(0, color='black', lw=0.8)

plt.suptitle('Features 0 and 2 are truly informative — attributions should highlight them', fontsize=11)
plt.tight_layout()
plt.savefig('ch331_interpretability.png', dpi=120)
plt.show()

3. LIME and SHAP

LIME (Ribeiro et al., 2016): locally approximate a complex model with an interpretable linear one. Perturb the input, observe the model's outputs, and fit a linear model to the perturbed samples, weighted by their proximity to the original input.
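A minimal LIME-style sketch (illustrative only; the real library adds kernels, feature selection, and interpretable representations). The black-box function and all parameter choices here are hypothetical:

```python
import numpy as np

def black_box(X):
    # Stand-in for a complex model: a nonlinear scalar function
    return np.sin(X[:, 0]) + X[:, 1] ** 2

def lime_explain(f, x, n_samples=500, sigma=0.5, seed=0):
    """Fit a locally weighted linear surrogate around x."""
    rng = np.random.default_rng(seed)
    # Perturb the input with Gaussian noise
    Z = x + rng.normal(0, sigma, (n_samples, x.size))
    y = f(Z)
    # Proximity weights: closer perturbations count more
    w = np.exp(-np.sum((Z - x) ** 2, axis=1) / (2 * sigma ** 2))
    # Weighted least squares with an intercept column
    A = np.hstack([np.ones((n_samples, 1)), Z])
    Aw = A * np.sqrt(w)[:, None]
    yw = y * np.sqrt(w)
    coef, *_ = np.linalg.lstsq(Aw, yw, rcond=None)
    return coef[1:]  # local linear weight per feature

x0 = np.array([0.0, 1.0])
print(lime_explain(black_box, x0))
# Near x0 the true local slopes are cos(0) = 1 and 2*x1 = 2,
# so the surrogate weights should land roughly there
```

The surrogate's coefficients approximate the local gradient of the black box; they are attenuated slightly by curvature within the sampling neighborhood, which shrinks as `sigma` does.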

SHAP (Lundberg & Lee, 2017): use Shapley values from cooperative game theory. Each feature’s contribution is its average marginal effect over all possible orderings. SHAP values satisfy desirable axioms: efficiency, symmetry, dummy, linearity.
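For a handful of features, exact Shapley values can be computed by brute-force enumeration of coalitions. A sketch (not the SHAP library; it assumes a value function that "removes" a feature by resetting it to a baseline):

```python
import numpy as np
from itertools import combinations
from math import factorial

def shapley_values(f, x, baseline):
    """Exact Shapley values: weighted average marginal contribution of each
    feature over all coalitions of the remaining features."""
    d = x.size
    phi = np.zeros(d)
    for i in range(d):
        others = [j for j in range(d) if j != i]
        for k in range(d):
            for S in combinations(others, k):
                # Shapley coalition weight |S|! (d - |S| - 1)! / d!
                w = factorial(k) * factorial(d - k - 1) / factorial(d)
                z = baseline.copy()
                z[list(S)] = x[list(S)]
                without = f(z)      # coalition S without feature i
                z[i] = x[i]
                with_i = f(z)       # coalition S plus feature i
                phi[i] += w * (with_i - without)
    return phi

def model(z):
    # Linear model: Shapley value of feature j is w_j * (x_j - baseline_j)
    w = np.array([1.0, -2.0, 0.5])
    return float(w @ z)

x = np.array([1.0, 1.0, 2.0])
base = np.zeros(3)
phi = shapley_values(model, x, base)
print(phi)                               # → [1.0, -2.0, 1.0]
print(phi.sum(), model(x) - model(base)) # efficiency axiom: the two match
```

For a linear model the Shapley values reduce to the per-feature linear contributions, and the efficiency axiom (attributions sum to $f(x) - f(x')$) holds exactly. The enumeration costs $O(2^d)$ evaluations per feature, which is why practical SHAP uses sampling or model-specific shortcuts.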


4. Mechanistic interpretability

Rather than explaining individual predictions, mechanistic interpretability asks: what algorithms do specific circuits in the network implement?

In Transformers, attention heads have been found to implement:

  • Previous token detection

  • Duplicate token detection

  • Induction heads (prefix matching for in-context learning)

This is an active research area with growing importance for AI safety.
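A previous-token head can be illustrated with hand-constructed (not learned) weights: using one-hot positional encodings, a query/key pair can be built so that each position attends to its predecessor under a causal mask.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

T = 6
pos = np.eye(T)                           # one-hot positional encodings
W_Q = np.eye(T)
# Roll the key map so that the key at position t-1 matches the query at t
W_K = np.roll(np.eye(T), shift=-1, axis=0)
Q = pos @ W_Q
K = pos @ W_K
scores = (Q @ K.T) * 10.0                 # scale to sharpen the softmax
# Causal mask: position t may only attend to positions <= t
mask = np.tril(np.ones((T, T)))
scores = np.where(mask == 1, scores, -1e9)
attn = softmax(scores, axis=-1)
print(np.round(attn, 2))
# Every row t >= 1 puts (almost) all weight on column t-1
```

Mechanistic work goes the other direction: it starts from trained weights and reverse-engineers patterns like this one out of them.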


5. Summary

  • Saliency maps: $\partial f / \partial x_i$. Fast but noisy.

  • Integrated Gradients: axiomatically principled, more reliable attribution.

  • LIME: local linear approximation; model-agnostic.

  • SHAP: Shapley-value-based; theoretically grounded, computationally expensive.

  • Mechanistic interpretability: find the circuits and algorithms inside the network.


6. Forward and backward references

Used here: backpropagation (ch306), input gradients (ch306), integration (ch221).

This will reappear in ch340 — Capstone II, where attribution methods are applied to explain the end-to-end system’s predictions.