
ch333 — Training Infrastructure

1. Why infrastructure matters

A perfectly designed model trained inefficiently wastes compute. At large scale:

  • GPT-3: ≈ 3.14 × 10^23 FLOPs, on the order of 10^4 A100-GPU-days at realistic utilization.

  • LLaMA 3 405B: ≈ 3.8 × 10^25 FLOPs.

Infrastructure determines whether a given compute budget is achievable in practice.
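
To make these numbers concrete, here is a back-of-envelope conversion from total FLOPs to GPU-days; the peak throughput and utilization figures are assumptions, not measurements:

```python
def gpu_days(total_flops: float, peak_flops: float = 312e12, mfu: float = 0.4) -> float:
    """Total GPU-days needed for `total_flops`, assuming A100 BF16 peak
    (312 TFLOP/s) and ~40% model FLOPs utilization (both assumptions)."""
    seconds_per_day = 86400
    return total_flops / (peak_flops * mfu * seconds_per_day)

print(f"GPT-3 (3.14e23 FLOPs): {gpu_days(3.14e23):,.0f} A100-GPU-days")
print(f"LLaMA 3 405B (3.8e25): {gpu_days(3.8e25):,.0f} A100-GPU-days")
```

Halving the MFU doubles the wall-clock estimate, which is why utilization is itself a first-order infrastructure concern.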


2. Data parallelism

Split the mini-batch across K GPUs. Each GPU:

  1. Holds a full copy of the model.

  2. Processes B/K samples.

  3. Computes local gradients.

  4. All-reduce: averages gradients across GPUs (ring-all-reduce algorithm).

  5. Updates local model copy.

Communication cost: O(N_params) per step; with ring all-reduce, per-GPU traffic is nearly independent of K. Efficient as long as the model fits on a single GPU.


3. Model parallelism

When the model doesn’t fit in one GPU’s memory:

Tensor parallelism (Megatron-LM): split weight matrices across GPUs. For a linear layer Y = XW: split W column-wise so each GPU computes a slice of Y (assembled with an all-gather), or split W row-wise and sum the partial outputs with an all-reduce. Megatron pairs the two layouts so each MLP block needs only one all-reduce in the forward pass.

Pipeline parallelism: assign different layers to different GPUs. GPU 1 processes layers 1–12, GPU 2 processes layers 13–24, etc. Challenge: stages sit idle while waiting for upstream results (the pipeline "bubble"); the fix is to split each batch into micro-batches so all stages stay busy.
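
For a simple GPipe-style schedule, the idle fraction has a closed form; a quick calculation shows why many micro-batches are needed (P = 4 stages here is an illustrative choice):

```python
# With P pipeline stages and M micro-batches, the warm-up/drain "bubble"
# leaves stages idle for a fraction (P - 1) / (M + P - 1) of each step,
# so increasing M amortizes the bubble away.
def bubble_fraction(P: int, M: int) -> float:
    return (P - 1) / (M + P - 1)

for M in (1, 4, 16, 64):
    print(f"P=4, M={M:3d}: idle fraction = {bubble_fraction(4, M):.3f}")
```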

ZeRO (Zero Redundancy Optimizer): shard optimizer states, gradients, and parameters across GPUs. ZeRO-3 removes all redundancy: each GPU holds only 1/K of the weights, gradients, and optimizer states.
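
The per-GPU memory under each ZeRO stage is simple arithmetic, assuming 2 bytes each for FP16 weights and gradients plus 8 bytes of FP32 Adam state (m, v) per parameter:

```python
def per_gpu_memory_GB(params: float, K: int, stage: int) -> float:
    """Per-GPU training memory under ZeRO stage 0-3, assuming
    2 (FP16 weights) + 2 (FP16 grads) + 8 (FP32 Adam m, v) bytes/param."""
    w, g, o = 2 * params, 2 * params, 8 * params
    sharded = {0: w + g + o,          # stage 0: plain data parallelism
               1: w + g + o / K,      # ZeRO-1: shard optimizer states
               2: w + g / K + o / K,  # ZeRO-2: also shard gradients
               3: (w + g + o) / K}    # ZeRO-3: also shard parameters
    return sharded[stage] / 1e9

for stage in range(4):
    print(f"ZeRO-{stage}, 175B params, K=64: "
          f"{per_gpu_memory_GB(175e9, 64, stage):8.1f} GB/GPU")
```

Only ZeRO-3 makes the per-GPU footprint shrink linearly with K, at the cost of extra parameter communication.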

(Linear algebra operations: ch151. Distributed matrix multiply: ch153.)

import numpy as np
import matplotlib.pyplot as plt


# ── Simulate memory and compute requirements at scale ──

def model_memory_GB(params: float, precision_bytes: int = 2) -> float:
    """Memory for model weights alone (FP16 = 2 bytes per param)."""
    return params * precision_bytes / 1e9

def training_memory_GB(params: float, precision_bytes: int = 2) -> float:
    """Approximate training memory: weights + gradients + Adam states (8 bytes each)."""
    weight_mem = params * precision_bytes   # FP16 weights
    grad_mem   = params * precision_bytes   # FP16 gradients
    adam_mem   = params * 4 * 2            # two FP32 Adam states (m, v)
    return (weight_mem + grad_mem + adam_mem) / 1e9

def activation_memory_GB(params: float, batch_size: int = 8, seq_len: int = 2048) -> float:
    """Rough activation memory estimate."""
    # Scales with batch and sequence length, roughly proportional to sqrt(params)
    d_model = int(np.sqrt(params / 12))  # heuristic from param count formula
    return batch_size * seq_len * d_model * 2 / 1e9


# Model sizes: 100M to 1T parameters
sizes = np.logspace(8, 12, 100)
gpu_80GB = 80.0

fig, axes = plt.subplots(1, 2, figsize=(13, 5))

# Memory requirements
mem_weights  = [model_memory_GB(N)    for N in sizes]
mem_training = [training_memory_GB(N) for N in sizes]
mem_act      = [activation_memory_GB(N) for N in sizes]

axes[0].loglog(sizes/1e9, mem_weights,  label='Inference (FP16 weights only)', lw=2, color='#2ecc71')
axes[0].loglog(sizes/1e9, mem_training, label='Training (weights+grad+Adam)',  lw=2, color='#e74c3c')
axes[0].loglog(sizes/1e9, [m+a for m,a in zip(mem_training, mem_act)],
               label='Training + activations (B=8)', lw=2, color='#9b59b6', linestyle='--')
axes[0].axhline(gpu_80GB, color='black', linestyle=':', lw=2, label='80GB GPU memory')
axes[0].axhline(gpu_80GB*8, color='gray', linestyle=':', lw=2, label='8×80GB node')
axes[0].set_xlabel('Parameters (billions)'); axes[0].set_ylabel('Memory (GB)')
axes[0].set_title('Memory requirements vs model size')
axes[0].legend(fontsize=8); axes[0].grid(True, alpha=0.3)

# GPUs needed with ZeRO-3 (distribute all states)
gpus_needed = [int(np.ceil(training_memory_GB(N) / gpu_80GB)) for N in sizes]
axes[1].loglog(sizes/1e9, gpus_needed, color='#3498db', lw=2)
axes[1].set_xlabel('Parameters (billions)'); axes[1].set_ylabel('Min GPUs (ZeRO-3 training)')
axes[1].set_title('Minimum GPUs for training\n(ZeRO-3, 80GB each)')
axes[1].grid(True, alpha=0.3)

# Annotate key models
key_models = [('GPT-3', 175e9, '#e74c3c'), ('LLaMA-3 70B', 70e9, '#f39c12'),
              ('LLaMA-3 405B', 405e9, '#9b59b6')]
for name, N, color in key_models:
    gpus = int(np.ceil(training_memory_GB(N) / gpu_80GB))
    axes[1].scatter([N/1e9], [gpus], color=color, s=80, zorder=5)
    axes[1].annotate(f'{name}\n({gpus} GPUs)', (N/1e9, gpus), fontsize=7,
                     xytext=(5, 5), textcoords='offset points')

plt.tight_layout()
plt.savefig('ch333_infrastructure.png', dpi=120)
plt.show()

print("Memory estimates for key models (FP16, B=8, seq=2048):")
for name, N, _ in key_models:
    print(f"  {name} ({N/1e9:.0f}B): training={training_memory_GB(N):.0f}GB, "
          f"min GPUs={int(np.ceil(training_memory_GB(N)/gpu_80GB))}")

4. Mixed precision training

Training in FP32 roughly doubles both time and memory relative to FP16/BF16. Mixed precision: run the forward and backward passes in FP16/BF16 (fast), but keep FP32 master weights for the optimizer update (numerically stable).

BF16 (bfloat16): same range as FP32, lower precision. Preferred for LLMs because it avoids overflow issues common with FP16.

Gradient scaling: multiply loss by a large constant before backward to prevent FP16 underflow; divide gradients by the same constant before the optimizer step.
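
The range problem and the fix are both easy to demonstrate with NumPy's float16 (the scale factor 1024 below is an arbitrary illustrative choice; real trainers adjust it dynamically):

```python
import numpy as np

# FP16 has a narrow range: max ~65504, smallest subnormal ~6e-8.
print(np.float16(1e5))    # overflows to inf (BF16, sharing FP32's
                          # exponent range, would not)
print(np.float16(1e-8))   # underflows to 0: this small gradient is lost

# Loss scaling: multiply the loss (hence all gradients) by a constant
# before the FP16 backward pass, then unscale in FP32 for the update.
scale = 1024.0
grad_true = 1e-8
g_fp16 = np.float16(grad_true * scale)   # representable in FP16 after scaling
g_fp32 = np.float32(g_fp16) / scale      # unscaled in FP32
print(g_fp32)                            # close to the true 1e-8
```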


5. Gradient checkpointing

During backprop, intermediate activations must be stored for gradient computation. Memory = O(L · B · T · d), where L = layers, B = batch size, T = sequence length, d = model width.

Gradient checkpointing: discard intermediate activations during the forward pass; recompute them on the fly during the backward pass. Trades compute for memory: costs one extra forward pass but reduces activation memory from O(L) to O(√L).
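
A minimal sketch of the recomputation logic, using a toy chain of tanh layers (layer count, width, and checkpoint spacing are illustrative assumptions):

```python
import numpy as np

# Forward through L toy layers, storing only every ~sqrt(L)-th activation
# (the checkpoints); the backward pass recomputes the rest per segment.
L, d = 16, 4
rng = np.random.default_rng(0)
Ws = [rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(L)]
layer = lambda x, W: np.tanh(x @ W)

stride = int(np.ceil(np.sqrt(L)))        # checkpoint every sqrt(L) layers
x = rng.standard_normal(d)
checkpoints = {0: x}
for i, W in enumerate(Ws):               # forward: keep checkpoints only
    x = layer(x, W)
    if (i + 1) % stride == 0:
        checkpoints[i + 1] = x

def input_to_layer(i):
    """Recompute layer i's input from the nearest earlier checkpoint."""
    start = max(k for k in checkpoints if k <= i)
    a = checkpoints[start]
    for j in range(start, i):
        a = layer(a, Ws[j])
    return a

# Sanity check against a full forward pass that stores everything.
acts = [checkpoints[0]]
for W in Ws:
    acts.append(layer(acts[-1], W))
assert all(np.allclose(input_to_layer(i), acts[i]) for i in range(L))
print(f"stored {len(checkpoints)} of {L + 1} activations")
```

Each recomputed segment costs at most √L extra layer evaluations, which is where the "one extra forward pass" bound comes from.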


6. Summary

  • Data parallelism: split the batch across GPUs; all-reduce gradients. Near-linear speedup while communication stays cheap.

  • Tensor/pipeline parallelism: split model across GPUs for models that don’t fit.

  • ZeRO: shard optimizer states/gradients/params; eliminate redundant copies.

  • Mixed precision: FP16/BF16 forward, FP32 optimizer states.

  • Gradient checkpointing: trade compute for memory; enables larger batches.


7. Forward and backward references

Used here: matrix multiplication (ch153), gradient computation (ch306), Adam optimizer (ch312), scaling laws (ch332).

This will reappear in ch340 — Capstone II, where training infrastructure decisions are discussed in the context of the end-to-end system.