
5D Parallelism for Transformer Training


Built the parallel training stack for a GPT-style model as a capstone for a hands-on distributed systems workshop, implementing data, tensor, pipeline, context, and expert parallelism from scratch, then benchmarked throughput, memory, and convergence across 1–8 GPU configurations.

PyTorch, Distributed Training, CUDA, Python, ML Systems, Infrastructure, Transformers

Training Parallelism, From Single-GPU Baseline to 5D Distributed Runs

Capstone project for a 5D Parallelism Workshop: five distributed training strategies for a GPT-style Transformer, implemented from scratch and benchmarked individually and in combination for their effect on throughput, memory, and training dynamics.


The Problem

Training large Transformers is limited by GPU memory, communication overhead, and poor utilization when work is split naively. This project explored how modern LLM training systems break the model, batch, sequence, and expert dimensions across multiple GPUs — implementing each technique from scratch rather than relying on a framework abstraction.


My Role

I built the full parallel training stack as the workshop capstone:

  • Data parallel scaling baselines across 1, 2, 4, and 8 GPUs
  • Tensor parallel column/row linear layers with all-gather/all-reduce communication
  • Pipeline parallel model partitioning and micro-batch scheduling
  • Context parallel ring attention for sequence sharding
  • Expert parallel MoE routing with top-k gating and load balancing
  • Combination runs: TP+PP+EP and TP+CP
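The top-k gating and load-balancing item in the list above can be illustrated with a small, framework-free sketch. The function names are hypothetical, and the auxiliary loss shown is one common formulation (fraction of routed tokens times mean router probability, scaled by the number of experts). The project does not say which load-balancing term it used, so treat this as an assumption.

```python
import math

def softmax(logits):
    top = max(logits)
    exps = [math.exp(x - top) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def top_k_route(router_logits, k):
    """Pick the k most probable experts for one token and renormalize their weights."""
    probs = softmax(router_logits)
    chosen = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    norm = sum(probs[i] for i in chosen)
    return [(i, probs[i] / norm) for i in chosen]

def load_balance_loss(batch_logits, k):
    """Assumed auxiliary loss: n_experts * sum_e(routed_fraction_e * mean_prob_e).

    Equals 1.0 when router probabilities are uniform and grows as routing
    concentrates on a few experts.
    """
    n_tokens = len(batch_logits)
    n_experts = len(batch_logits[0])
    counts = [0] * n_experts
    mean_prob = [0.0] * n_experts
    for logits in batch_logits:
        for e, p in enumerate(softmax(logits)):
            mean_prob[e] += p / n_tokens
        for e, _ in top_k_route(logits, k):
            counts[e] += 1
    routed_fraction = [c / (n_tokens * k) for c in counts]
    return n_experts * sum(f * p for f, p in zip(routed_fraction, mean_prob))

# One token, four experts, top-2 routing.
assignment = top_k_route([2.0, 1.0, 0.0, -1.0], k=2)
```

With expert parallelism, the expert indices returned by the router would decide which rank each token is sent to; that dispatch step is omitted here.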

Technical Implementation

| Technique | What It Splits | Key Implementation | What It Demonstrates |
|---|---|---|---|
| Data Parallelism | Batch | Replicated model + distributed batches | Scaling throughput with more GPUs |
| Tensor Parallelism | Matrix dimensions | ColumnParallelLinear and RowParallelLinear | Splitting Transformer matmuls across devices |
| Pipeline Parallelism | Layers | PipelineStage + micro-batches | Reducing idle time across model stages |
| Context Parallelism | Sequence length | Ring attention over KV chunks | Long-context attention without full-sequence memory |
| Expert Parallelism | MoE experts | Top-k router + distributed expert execution | Conditional compute and expert sharding |
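To make the tensor-parallel row of the table concrete, here is a single-process sketch of the arithmetic behind a column-parallel layer followed by a row-parallel layer. It is not the project's ColumnParallelLinear or RowParallelLinear code: the helper names are hypothetical, plain Python lists stand in for tensors, and concatenation and summation stand in for the all-gather and all-reduce collectives.

```python
from functools import reduce

def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def split_cols(m, parts):
    width = len(m[0]) // parts
    return [[row[r * width:(r + 1) * width] for row in m] for r in range(parts)]

def split_rows(m, parts):
    height = len(m) // parts
    return [m[r * height:(r + 1) * height] for r in range(parts)]

def concat_cols(shards):
    # Stand-in for all-gather along the feature dimension.
    return [sum((shard[i] for shard in shards), []) for i in range(len(shards[0]))]

def add(a, b):
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]

def tp_forward(x, w_col, w_row, tp):
    """Simulate tp ranks: column-split w_col, row-split w_row."""
    # Each "rank" holds one column shard and produces a slice of the hidden features.
    hidden_shards = [matmul(x, shard) for shard in split_cols(w_col, tp)]
    # Each rank multiplies its hidden slice by the matching row shard.
    partials = [matmul(h, shard) for h, shard in zip(hidden_shards, split_rows(w_row, tp))]
    # Stand-in for all-reduce: summing the partial outputs gives the full result.
    return concat_cols(hidden_shards), reduce(add, partials)

x = [[1, 2], [3, 4]]
w_col = [[1, 0, 2, 1], [0, 1, 1, 3]]        # 2 x 4, split by columns
w_row = [[1, 2], [0, 1], [3, 0], [1, 1]]    # 4 x 2, split by rows
gathered, reduced = tp_forward(x, w_col, w_row, tp=2)
```

Pairing the two splits this way means the hidden activations never need to be gathered between the layers; only the final sum crosses ranks. An elementwise activation between the two matmuls could be applied per shard without changing that.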

What each parallelism strategy splits — DP splits the batch, TP/PP split the model, CP/EP split the sequence and experts
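For the context-parallel case, the key idea behind ring attention is that softmax attention can be accumulated one KV chunk at a time with a running maximum and running denominator, so no rank needs the full sequence in memory. The sketch below shows that accumulation for a single query vector. Function names are hypothetical; the 1/sqrt(d) scaling, causal masking, and the actual ring send/receive of KV chunks between ranks are left out.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def full_attention(q, keys, values):
    """Reference: softmax(q . K) V over the whole sequence at once."""
    scores = [dot(q, k) for k in keys]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    denom = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / denom for d in range(dim)]

def chunked_attention(q, kv_chunks):
    """Accumulate attention over KV chunks in the order they arrive."""
    running_max = float("-inf")
    denom = 0.0
    acc = [0.0] * len(kv_chunks[0][1][0])
    for keys, values in kv_chunks:
        scores = [dot(q, k) for k in keys]
        new_max = max(running_max, max(scores))
        # Rescale what was accumulated so far to the new reference maximum.
        rescale = math.exp(running_max - new_max)
        weights = [math.exp(s - new_max) for s in scores]
        denom = denom * rescale + sum(weights)
        acc = [a * rescale + sum(w * v[d] for w, v in zip(weights, values))
               for d, a in enumerate(acc)]
        running_max = new_max
    return [a / denom for a in acc]

q = [1.0, 0.5]
keys = [[0.2, 0.1], [1.5, -0.3], [-0.7, 0.9], [0.4, 0.4]]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 3.0]]
chunks = [(keys[:2], values[:2]), (keys[2:], values[2:])]
out = chunked_attention(q, chunks)
```

Because the result does not depend on chunk order, each rank can start with its local KV chunk and fold in the others as they arrive around the ring.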


Throughput Comparison

Combined parallelism throughput comparison across all runs


Results

Steady-state averages from steps 1–48, excluding initialization and final validation overhead.
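A minimal sketch of how such a steady-state average might be computed from per-step logs. The log record shape (`step`, `tokens`, `seconds`) is hypothetical, and the sketch assumes the figure is the mean of per-step throughput rather than total tokens over total time.

```python
def steady_state_tok_per_sec(step_logs, first_step=1, last_step=48):
    """Mean tokens/sec over steps in [first_step, last_step], inclusive."""
    rates = [
        rec["tokens"] / rec["seconds"]
        for rec in step_logs
        if first_step <= rec["step"] <= last_step
    ]
    if not rates:
        raise ValueError("no steps inside the steady-state window")
    return sum(rates) / len(rates)

# Hypothetical logs: a slow step 0 (initialization), steady steps 1-48,
# and a slow step 49 that includes final validation.
logs = [{"step": 0, "tokens": 1000, "seconds": 5.0}]
logs += [{"step": s, "tokens": 1000, "seconds": 0.5} for s in range(1, 49)]
logs += [{"step": 49, "tokens": 1000, "seconds": 4.0}]
average = steady_state_tok_per_sec(logs)
```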

| Run | Parallelism | Avg Tok/Sec | Final Val Loss | Notes |
|---|---|---|---|---|
| gpu1_tp1_pp1_cp1_ep1 | DP=1 baseline | ~266K | 8.6988 | Single-GPU baseline |
| gpu2_tp1_pp1_cp1_ep1 | DP=2 | ~531K | 8.6533 | Near-linear DP scaling |
| gpu4_tp1_pp1_cp1_ep1 | DP=4 | ~1.06M | 8.6297 | Continued DP scaling |
| gpu8_tp1_pp1_cp1_ep1 | DP=8 | ~2.10M | 8.6034 | Fastest pure throughput run |
| gpu8_tp2_pp1_cp1_ep1 | TP=2, DP=4 | ~1.16M | 8.6929 | Tensor-parallel matmul split |
| gpu8_tp1_pp2_cp1_ep1 | PP=2, DP=4 | ~531K | not reported | Pipeline schedule benchmark (logging artifact) |
| gpu8_tp1_pp1_cp2_ep1 | CP=2, DP=4 | ~83K | 8.6118 | Ring attention has high communication cost |
| gpu8_tp1_pp1_cp1_ep4 | EP=4, DP=2 | ~363K | 8.9291 | MoE routing and expert sharding |
| gpu8_tp2_pp2_cp1_ep2 | TP=2, PP=2, EP=2 | ~150K | 8.9118 | Combined model/expert parallelism |
| gpu8_tp2_pp1_cp2_ep1 | TP=2, CP=2, DP=2 | ~124K | 8.7077 | Combined tensor + context parallelism |
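The run names encode the full configuration, so the data-parallel degree and the DP scaling efficiency can be derived from the table. The sketch below assumes that DP is whatever remains after dividing the GPU count by the TP, PP, CP, and EP degrees, which matches the Parallelism column above. The helper names are hypothetical, and the throughput inputs are the rounded table values, so the efficiencies are approximate.

```python
import re

RUN_NAME = re.compile(r"gpu(\d+)_tp(\d+)_pp(\d+)_cp(\d+)_ep(\d+)")

def parse_run(name):
    """Split a run name into its parallelism degrees and infer the DP degree."""
    match = RUN_NAME.fullmatch(name)
    if match is None:
        raise ValueError(f"unrecognized run name: {name}")
    gpus, tp, pp, cp, ep = map(int, match.groups())
    non_dp = tp * pp * cp * ep
    if gpus % non_dp != 0:
        raise ValueError(f"{gpus} GPUs not divisible by tp*pp*cp*ep={non_dp}")
    return {"gpus": gpus, "tp": tp, "pp": pp, "cp": cp, "ep": ep, "dp": gpus // non_dp}

def scaling_efficiency(tok_per_sec, baseline_tok_per_sec, gpus):
    """Measured throughput divided by ideal linear scaling of the 1-GPU baseline."""
    return tok_per_sec / (baseline_tok_per_sec * gpus)

# Rounded throughput values copied from the pure-DP rows of the table.
baseline = 266_000
dp_runs = {
    "gpu2_tp1_pp1_cp1_ep1": 531_000,
    "gpu4_tp1_pp1_cp1_ep1": 1_060_000,
    "gpu8_tp1_pp1_cp1_ep1": 2_100_000,
}
efficiency = {
    name: scaling_efficiency(tps, baseline, parse_run(name)["gpus"])
    for name, tps in dp_runs.items()
}
```

On these rounded numbers, all three pure-DP runs stay above 98% of ideal linear scaling.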

Training Loss

Training loss convergence across parallelism strategies

Validation Loss

Validation loss across parallelism strategies


Key Takeaways

The most important result was not simply which run was fastest, but how each parallelism strategy introduced a different tradeoff:

  • Pure data parallelism achieved the highest token throughput (~2.10M tok/sec on 8 GPUs, about 7.9× the single-GPU baseline) because replicas only need to synchronize gradients. The replicated-model pattern pays off when the model fits on a single device.
  • Tensor parallelism reduced per-rank model work but introduced collective communication at every matmul boundary.
  • Pipeline parallelism exposed the cost of pipeline bubbles and micro-batch scheduling overhead.
  • Context parallelism enabled sequence sharding, but ring attention communication dominated at this model scale (~83K tok/sec for the CP=2 run, the lowest of all configurations).
  • Expert parallelism demonstrated conditional compute and routing overhead, amplified when combined with other model-parallel strategies.
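The pipeline-bubble cost in the list above can be quantified with a simple model. The sketch assumes a GPipe-style schedule in which every stage takes one equal time slot per micro-batch and only the forward pass is laid out. The project does not state its schedule or micro-batch count, so the values used here are purely illustrative.

```python
def forward_schedule(stages, micro_batches):
    """Grid of time slots per stage; None marks an idle (bubble) slot."""
    slots = stages + micro_batches - 1
    grid = [[None] * slots for _ in range(stages)]
    for stage in range(stages):
        for mb in range(micro_batches):
            # Stage s can start micro-batch mb only after stage s-1 has finished it.
            grid[stage][stage + mb] = mb
    return grid

def idle_fraction(grid):
    cells = [cell for row in grid for cell in row]
    return cells.count(None) / len(cells)

def bubble_fraction(stages, micro_batches):
    """Closed form for the same model: (p - 1) / (m + p - 1)."""
    return (stages - 1) / (micro_batches + stages - 1)

# Illustrative values only: 2 stages (as in the PP=2 runs) and 4 micro-batches.
grid = forward_schedule(stages=2, micro_batches=4)
measured = idle_fraction(grid)
predicted = bubble_fraction(stages=2, micro_batches=4)
```

Under this model the bubble shrinks as the number of micro-batches grows relative to the number of stages, which is why micro-batch scheduling matters for pipeline throughput.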

Technical Stack

| Layer | Tool |
|---|---|
| Framework | PyTorch Distributed (NCCL backend) |
| Parallelism | Custom DP, TP, PP, CP, EP implementations |
| Model | GPT-style Transformer (slim) |
| Hardware | Up to 8 GPUs |
| Visualization | matplotlib |
Visualizationmatplotlib