This code implements a Mixture of Experts (MoE) model in PyTorch, which is a neural network architecture that uses multiple specialised "expert" networks and a routing mechanism to decide which experts should process each input.
- Weight calculations for inference vs. full model
- Introduce capacity limits per expert
- Add noise (similar to epsilon-greedy in RL, encourages exploration)
- GPT with MoE instead of the normal FFN
- Expert: A simple feedforward network that transforms input through a hidden layer with GELU activation.
- Router: The key component that decides which experts to use for each input.
- MoE: The main model that combines routing and experts, plus an output layer for classification.
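The Expert component described above can be sketched as a small PyTorch module. The layer sizes here are illustrative, not taken from the original code:

```python
import torch
import torch.nn as nn

class Expert(nn.Module):
    """A simple feedforward network: linear -> GELU -> linear."""
    def __init__(self, d_model: int, d_hidden: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_hidden),
            nn.GELU(),
            nn.Linear(d_hidden, d_model),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, d_model) -> (batch, d_model)
        return self.net(x)
```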
The Router implements a top-k gating mechanism:
- Gating: `gate` is a linear layer that produces logits for each expert
- Scoring: Softmax converts the logits to probabilities (scores); these represent how much each expert should contribute
- Selection: `topk` selects the k highest-scoring experts for each input
- Dispatch mask: A binary mask indicating which experts are active for each input
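Putting those four steps together, a minimal Router sketch might look like this. The `gate` layer and `topk` call follow the description above; the exact return signature is an assumption:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Router(nn.Module):
    """Top-k gating: score every expert, keep only the k best per input."""
    def __init__(self, d_model: int, n_experts: int, k: int = 2):
        super().__init__()
        self.gate = nn.Linear(d_model, n_experts)  # logits per expert
        self.k = k

    def forward(self, x: torch.Tensor):
        logits = self.gate(x)                       # (batch, n_experts)
        scores = F.softmax(logits, dim=-1)          # routing probabilities
        topk_scores, topk_idx = scores.topk(self.k, dim=-1)
        # Binary dispatch mask: 1 where an expert is selected for an input
        mask = torch.zeros_like(scores).scatter_(-1, topk_idx, 1.0)
        return scores, topk_scores, topk_idx, mask
```

With k smaller than the number of experts, only k expert forward passes run per input, which is where the efficiency gain comes from.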
The key insight is that instead of using all experts for every input, we only use the top-k most relevant experts, making the model more efficient and specialised.
The auxiliary loss is designed to encourage load balancing among experts. Here's the mathematical formulation:

$$\mathcal{L}_{\text{aux}} = E \cdot \sum_{i=1}^{E} f_i \cdot \frac{P_i}{B}$$

Where:
- $E$ = number of experts
- $B$ = batch size
- $f_i$ = fraction of tokens routed to expert $i$ (load)
- $P_i$ = sum of routing probabilities for expert $i$ (importance)
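Under those definitions, the loss can be computed directly from the router's scores and dispatch mask. The function name and tensor layout are assumptions consistent with the symbols above:

```python
import torch

def load_balancing_loss(scores: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Auxiliary loss: E * sum_i f_i * (P_i / B).

    scores: (B, E) softmax routing probabilities
    mask:   (B, E) binary dispatch mask from top-k selection
    """
    B, E = scores.shape
    f = mask.float().mean(dim=0)   # f_i: fraction of tokens routed to expert i
    P = scores.sum(dim=0)          # P_i: sum of routing probabilities per expert
    return E * torch.sum(f * (P / B))
```

Dividing $P_i$ by $B$ turns the summed probabilities into a mean, so the loss is minimised when both load and importance are spread uniformly across experts.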
Without load balancing, the model might:
- Overuse some experts while ignoring others
- Create training instability
- Reduce the benefits of having multiple experts
The auxiliary loss penalises scenarios where:
- An expert gets high routing probabilities (high importance) AND
- Gets assigned many tokens (high load)
This encourages the router to distribute work more evenly across experts while still allowing specialisation. The loss is typically added to the main task loss with a small coefficient (I've used 0.01 by default) during training.
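As a sketch, folding the auxiliary loss into the training objective looks like this. The placeholder loss values stand in for real model outputs and are purely illustrative:

```python
import torch

aux_coef = 0.01  # small so load balancing doesn't dominate the task loss

# Placeholders standing in for real model outputs:
task_loss = torch.tensor(2.30)  # e.g. cross-entropy on the main task
aux_loss = torch.tensor(1.00)   # load-balancing loss from the router

total_loss = task_loss + aux_coef * aux_loss
# In a real training loop: total_loss.backward(); optimizer.step()
```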
The