1. Why infrastructure matters¶
A perfectly designed model trained inefficiently wastes compute. At large scale:
GPT-3: 3.14 × 10^23 FLOPs, on the order of 10^4 A100-GPU-days even at high utilization.
LLaMA 3 405B: ≈3.8 × 10^25 FLOPs.
Infrastructure determines whether a given compute budget is achievable in practice.
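To make these budgets concrete, here is a back-of-envelope conversion from a FLOP budget to GPU-time. The 312 TFLOP/s peak (A100 FP16 tensor cores) and the 40% model-FLOPs-utilization figure are illustrative assumptions, not measured numbers:

```python
def gpu_days(total_flops: float, peak_flops: float = 312e12, mfu: float = 0.4) -> float:
    """GPU-days needed at a given peak throughput and model-FLOPs utilization."""
    seconds = total_flops / (peak_flops * mfu)
    return seconds / 86_400  # seconds per day

print(f"GPT-3  (3.14e23 FLOPs): {gpu_days(3.14e23):,.0f} A100-GPU-days")
print(f"1e25 FLOPs budget:      {gpu_days(1e25):,.0f} A100-GPU-days")
```

At these assumptions, GPT-3's budget alone is roughly 29,000 A100-GPU-days, which is why training runs are spread over thousands of GPUs for weeks rather than run on a single node.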
2. Data parallelism¶
Split the mini-batch across GPUs. Each GPU:
Holds a full copy of the model.
Processes its own shard of the mini-batch.
Computes local gradients.
All-reduce: averages gradients across GPUs (ring-all-reduce algorithm).
Updates local model copy.
Communication cost: $\mathcal{O}(N)$ per step for a model with $N$ parameters; ring all-reduce moves $\approx 2N\,(p-1)/p$ values per GPU, nearly independent of the number of GPUs $p$. Efficient for moderate model sizes.
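The all-reduce step above can be sketched directly. Below is a minimal NumPy simulation of ring all-reduce (reduce-scatter followed by all-gather); messages within each step are snapshotted so the simultaneous sends of a real ring are modeled faithfully:

```python
import numpy as np

def ring_all_reduce_mean(grads: list) -> list:
    """Average gradient vectors across k simulated workers via ring all-reduce.

    Each worker sends/receives only 1/k of the gradient per step, for
    2(k-1) steps total, so per-worker traffic is ~2N regardless of k.
    """
    k = len(grads)
    chunks = [np.array_split(g.astype(float), k) for g in grads]
    # Reduce-scatter: at step t, worker i forwards chunk (i - t) mod k to its
    # right neighbour, which adds it in.  Snapshot messages first so all
    # "sends" within a step use pre-step values.
    for step in range(k - 1):
        msgs = [(i, (i - step) % k, chunks[i][(i - step) % k].copy()) for i in range(k)]
        for src, c, data in msgs:
            chunks[(src + 1) % k][c] += data
    # All-gather: worker (c - 1) mod k now holds the complete sum of chunk c;
    # distribute the averaged chunk to everyone (a real ring circulates it
    # in k-1 more steps).
    for c in range(k):
        full = chunks[(c - 1) % k][c] / k
        for i in range(k):
            chunks[i][c] = full.copy()
    return [np.concatenate(ch) for ch in chunks]

rng = np.random.default_rng(0)
grads = [rng.normal(size=8) for _ in range(4)]
out = ring_all_reduce_mean(grads)
expected = np.mean(grads, axis=0)
print(np.allclose(out[0], expected), np.allclose(out[3], expected))  # True True
```

After the call, every simulated worker holds the same averaged gradient, matching the naive mean, which is what lets each GPU then apply an identical optimizer step.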
3. Model parallelism¶
When the model doesn’t fit in one GPU’s memory:
Tensor parallelism (Megatron-LM): split weight matrices across GPUs. For a linear layer $Y = XW$: split $W$ column-wise across GPUs, each GPU computes a partial output $Y_i = XW_i$, and the pieces are combined with a collective (concatenation/all-gather for column splits; an all-reduce of partial sums for row splits).
Pipeline parallelism: assign different layers to different GPUs. GPU 1 processes layers 1–12, GPU 2 processes layers 13–24, etc. Challenge: GPUs are idle while waiting for upstream results → schedule micro-batches.
ZeRO (Zero Redundancy Optimizer): shard optimizer states, gradients, and parameters across GPUs. ZeRO-3 eliminates all redundancy: each of the $k$ GPUs holds only $1/k$ of everything.
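The tensor-parallel scheme is easy to verify numerically. A NumPy sketch with lists standing in for GPUs (the shapes and the shard count are arbitrary choices for illustration):

```python
import numpy as np

# Column-parallel linear layer: each "GPU" holds a slice of W's columns and
# computes a slice of Y = X @ W independently; the output slices are
# concatenated (an all-gather in a real system).
rng = np.random.default_rng(1)
X = rng.normal(size=(4, 16))             # (batch, d_in)
W = rng.normal(size=(16, 32))            # (d_in, d_out)
k = 4                                    # number of model-parallel "GPUs"

W_shards = np.split(W, k, axis=1)        # column shards, one per GPU
Y_parts = [X @ Ws for Ws in W_shards]    # local partial outputs, no communication
Y = np.concatenate(Y_parts, axis=1)      # all-gather of output slices
print(np.allclose(Y, X @ W))             # True

# Row-parallel variant: shard W's rows (and X's columns); the partial
# products must be *summed*, which is the all-reduce mentioned above.
W_rows = np.split(W, k, axis=0)
X_cols = np.split(X, k, axis=1)
Y2 = sum(Xc @ Wr for Xc, Wr in zip(X_cols, W_rows))
print(np.allclose(Y2, X @ W))            # True
```

Megatron-LM chains the two variants (column-parallel then row-parallel) across the two linear layers of an MLP block so that only one all-reduce is needed per block.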
(Linear algebra operations: ch151. Distributed matrix multiply: ch153.)
import numpy as np
import matplotlib.pyplot as plt

# ── Simulate memory and compute requirements at scale ──

def model_memory_GB(params: float, precision_bytes: int = 2) -> float:
    """Memory for model weights alone (FP16 = 2 bytes per param)."""
    return params * precision_bytes / 1e9

def training_memory_GB(params: float, precision_bytes: int = 2) -> float:
    """Approximate training memory: weights + gradients + Adam states (8 bytes/param)."""
    weight_mem = params * precision_bytes   # FP16 weights
    grad_mem = params * precision_bytes     # FP16 gradients
    adam_mem = params * 4 * 2               # two FP32 Adam states (m, v)
    return (weight_mem + grad_mem + adam_mem) / 1e9

def activation_memory_GB(params: float, batch_size: int = 8, seq_len: int = 2048) -> float:
    """Rough activation memory estimate (FP16)."""
    # Scales with batch and sequence length, roughly proportional to sqrt(params)
    d_model = int(np.sqrt(params / 12))     # heuristic from the param-count formula
    return batch_size * seq_len * d_model * 2 / 1e9

# Model sizes: 100M to 1T parameters
sizes = np.logspace(8, 12, 100)
gpu_80GB = 80.0

fig, axes = plt.subplots(1, 2, figsize=(13, 5))

# Memory requirements
mem_weights = [model_memory_GB(N) for N in sizes]
mem_training = [training_memory_GB(N) for N in sizes]
mem_act = [activation_memory_GB(N) for N in sizes]
axes[0].loglog(sizes / 1e9, mem_weights, label='Inference (FP16 weights only)', lw=2, color='#2ecc71')
axes[0].loglog(sizes / 1e9, mem_training, label='Training (weights+grad+Adam)', lw=2, color='#e74c3c')
axes[0].loglog(sizes / 1e9, [m + a for m, a in zip(mem_training, mem_act)],
               label='Training + activations (B=8)', lw=2, color='#9b59b6', linestyle='--')
axes[0].axhline(gpu_80GB, color='black', linestyle=':', lw=2, label='80GB GPU memory')
axes[0].axhline(gpu_80GB * 8, color='gray', linestyle=':', lw=2, label='8×80GB node')
axes[0].set_xlabel('Parameters (billions)'); axes[0].set_ylabel('Memory (GB)')
axes[0].set_title('Memory requirements vs model size')
axes[0].legend(fontsize=8); axes[0].grid(True, alpha=0.3)

# GPUs needed with ZeRO-3 (all training states sharded evenly)
gpus_needed = [int(np.ceil(training_memory_GB(N) / gpu_80GB)) for N in sizes]
axes[1].loglog(sizes / 1e9, gpus_needed, color='#3498db', lw=2)
axes[1].set_xlabel('Parameters (billions)'); axes[1].set_ylabel('Min GPUs (ZeRO-3 training)')
axes[1].set_title('Minimum GPUs for training\n(ZeRO-3, 80GB each)')
axes[1].grid(True, alpha=0.3)

# Annotate key models
key_models = [('GPT-3', 175e9, '#e74c3c'), ('LLaMA-3 70B', 70e9, '#f39c12'),
              ('LLaMA-3 405B', 405e9, '#9b59b6')]
for name, N, color in key_models:
    gpus = int(np.ceil(training_memory_GB(N) / gpu_80GB))
    axes[1].scatter([N / 1e9], [gpus], color=color, s=80, zorder=5)
    axes[1].annotate(f'{name}\n({gpus} GPUs)', (N / 1e9, gpus), fontsize=7,
                     xytext=(5, 5), textcoords='offset points')

plt.tight_layout()
plt.savefig('ch333_infrastructure.png', dpi=120)
plt.show()

print("Memory estimates for key models (FP16, B=8, seq=2048):")
for name, N, _ in key_models:
    print(f"  {name} ({N/1e9:.0f}B): training={training_memory_GB(N):.0f}GB, "
          f"min GPUs={int(np.ceil(training_memory_GB(N)/gpu_80GB))}")

4. Mixed precision training¶
Training entirely in FP32 uses twice the memory of FP16/BF16 and forgoes most of the hardware's low-precision tensor-core throughput. Mixed precision: use FP16/BF16 for the forward and backward passes (fast), FP32 for weight updates (numerically stable).
BF16 (bfloat16): same range as FP32, lower precision. Preferred for LLMs because it avoids overflow issues common with FP16.
Gradient scaling: multiply loss by a large constant before backward to prevent FP16 underflow; divide gradients by the same constant before the optimizer step.
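A minimal NumPy illustration of why loss scaling helps. The gradient magnitudes are chosen to sit below FP16's representable range, and the 2^14 scale is a typical but arbitrary choice:

```python
import numpy as np

# Tiny gradients underflow to zero when cast to FP16; scaling first (as the
# scaled loss would produce in the backward pass) keeps them representable,
# and unscaling in FP32 before the optimizer step recovers the true values.
true_grad = np.array([1e-8, 2e-8, -2.5e-8], dtype=np.float32)
scale = 2.0 ** 14                                      # static loss scale

naive_fp16 = true_grad.astype(np.float16)              # underflows to zero
scaled_fp16 = (true_grad * scale).astype(np.float16)   # survives in FP16
recovered = scaled_fp16.astype(np.float32) / scale     # unscale in FP32

print(naive_fp16)   # all (signed) zeros
print(recovered)    # close to true_grad
```

Production systems typically make the scale dynamic: raise it when gradients stay finite for a while, and lower it when an overflow (inf/NaN) is detected.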
5. Gradient checkpointing¶
During backprop, intermediate activations must be stored for gradient computation. Activation memory is $\mathcal{O}(L)$, where $L$ is the number of layers.
Gradient checkpointing: discard intermediate activations during the forward pass; recompute them on-the-fly during the backward pass. Trades compute for memory: costs one extra forward pass but reduces activation memory from $\mathcal{O}(L)$ to $\mathcal{O}(\sqrt{L})$ (store a checkpoint every $\approx\sqrt{L}$ layers and recompute one segment at a time).
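A toy accounting of the trade-off, assuming checkpoints every ~√L layers and that the backward pass holds all checkpoints plus one fully recomputed segment (the `peak_activations` helper is illustrative, not a real API):

```python
import math

def peak_activations(L: int, checkpoint: bool) -> int:
    """Peak number of stored layer activations for an L-layer network."""
    if not checkpoint:
        return L                          # plain backprop stores every layer
    s = max(1, round(math.sqrt(L)))       # segment length ~ sqrt(L)
    n_checkpoints = math.ceil(L / s)      # activations kept from the forward pass
    # Backward pass: all checkpoints + one recomputed segment in flight.
    return n_checkpoints + s

for L in (16, 64, 256, 1024):
    print(f"L={L:4d}: plain={L:4d}  checkpointed={peak_activations(L, True):3d}")
```

At L=1024 this drops peak storage from 1024 layer activations to 64 (~2√L), at the cost of roughly one extra forward pass of compute.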
6. Summary¶
Data parallelism: split batch across GPUs; all-reduce gradients. Near-linear speedup until communication overhead dominates.
Tensor/pipeline parallelism: split model across GPUs for models that don’t fit.
ZeRO: shard optimizer states/gradients/params; eliminate redundant copies.
Mixed precision: FP16/BF16 forward, FP32 optimizer states.
Gradient checkpointing: trade compute for memory; enables larger batches.
7. Forward and backward references¶
Used here: matrix multiplication (ch153), gradient computation (ch306), Adam optimizer (ch312), scaling laws (ch332).
This will reappear in ch340 — Capstone II, where training infrastructure decisions are discussed in the context of the end-to-end system.