Training the router layer (gating network) of a Mixture-of-Experts (MoE) model by Gradient Descent means computing the gradient of the total loss with respect to the router’s parameters and using it to update those parameters.
The primary challenge in MoE training is the non-differentiability introduced by the Top-$K$ routing mechanism, which discretely selects experts. Standard backpropagation struggles with this non-smooth operation.
1. MoE Layer Output and Loss Function
MoE Layer Output ($y$)
For a given input vector $\mathbf{x}$, the router layer, typically a linear projection, produces unnormalized scores (logits) $\mathbf{h}(\mathbf{x}) = \mathbf{W}_g \mathbf{x}$, where $\mathbf{W}_g$ are the router’s parameters. These logits are then passed through a Softmax function to obtain the expert weights (or “gates”) $\mathbf{G}(\mathbf{x})$:
$$G_i(\mathbf{x}) = \frac{\exp(h_i(\mathbf{x}))}{\sum_{j=1}^{N} \exp(h_j(\mathbf{x}))}$$

where $N$ is the total number of experts.
The MoE layer output is a weighted sum of the outputs of the $K$ selected experts, $\mathbf{E}_i(\mathbf{x})$, weighted by their gate values:
$$ \mathbf{y}(\mathbf{x}) = \sum_{i=1}^{N} \mathbf{M}_i(\mathbf{x}) \cdot G_i(\mathbf{x}) \cdot \mathbf{E}_i(\mathbf{x}) \tag{Eq. MoE} $$

where $\mathbf{M}_i(\mathbf{x})$ is a Top-$K$ mask that is $1$ if expert $i$ is one of the top $K$ chosen experts for input $\mathbf{x}$, and $0$ otherwise.
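To make (Eq. MoE) concrete, here is a minimal NumPy sketch of the forward pass for a single token. The function names and the representation of experts as plain callables are illustrative assumptions, not any particular framework’s API.

```python
import numpy as np

def softmax(h):
    z = np.exp(h - h.max())  # subtract max for numerical stability
    return z / z.sum()

def moe_forward(x, W_g, experts, k):
    """Compute y(x) as in (Eq. MoE) for one token x.

    W_g: (N, d) router weights; experts: list of N callables E_i;
    k: number of experts to activate.
    """
    h = W_g @ x                    # logits h(x) = W_g x
    G = softmax(h)                 # gates G(x)
    top_k = np.argsort(h)[-k:]     # indices of the k largest logits
    M = np.zeros_like(G)           # top-K mask M(x)
    M[top_k] = 1.0
    # weighted sum over the selected experts only
    return sum(M[i] * G[i] * experts[i](x) for i in top_k)
```

An expert here could be as simple as `lambda x: A @ x` for a per-expert matrix `A`; the sketch only fixes the routing arithmetic, not the expert architecture.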
Total Loss Function ($\mathcal{L}$)
The total loss function typically consists of two main components:
- Main Task Loss ($\mathcal{L}_{\text{main}}$): A standard differentiable loss (e.g., cross-entropy or mean squared error) computed on the model’s final output $\mathbf{y}$ against the true target $\mathbf{t}$.
- Auxiliary Load Balancing Loss ($\mathcal{L}_{\text{aux}}$): An optional but common term to encourage an even distribution of tokens across experts, preventing a “router collapse” where only a few experts are used. A common formulation is: $$\mathcal{L}_{\text{aux}} = \alpha \cdot \sum_{i=1}^{N} f_i \cdot P_i$$ where $\alpha$ is a scaling coefficient, $f_i$ is the fraction of tokens dispatched to expert $i$ in the batch, and $P_i$ is the average routing probability to expert $i$ across the batch. The term is differentiable through $P_i$ (the discrete count fraction $f_i$ is treated as a constant) and penalizes large values of both $f_i$ (uneven load) and $P_i$ (over-confidence).
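For illustration, a minimal sketch (hypothetical names, assumed batch shapes) of computing this auxiliary loss from one batch of routing decisions:

```python
import numpy as np

def load_balancing_loss(G, M, alpha=0.01):
    """Compute alpha * sum_i f_i * P_i for one batch.

    G: (T, N) routing probabilities for T tokens and N experts.
    M: (T, N) binary top-K masks (1 if expert i received the token).
    """
    f = M.mean(axis=0)   # f_i: fraction of tokens dispatched to expert i
    P = G.mean(axis=0)   # P_i: average routing probability of expert i
    return alpha * float(np.dot(f, P))
```

In an autograd framework, `f` would be detached so that the gradient flows only through `P`, consistent with the note above.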
We do not discuss load balancing in depth here; readers can refer to the Appendices for further illustration, in particular of DeepSeek’s Auxiliary-Loss-Free Load Balancing.
In any case, the total loss is:
$$\mathcal{L} = \mathcal{L}_{\text{main}} + \mathcal{L}_{\text{aux}}$$

2. Gradient Calculation for the Router Parameters
Gradient Descent requires calculating the gradient of the total loss $\mathcal{L}$ with respect to the router parameters, $\mathbf{W}_g$:
$$ \nabla_{\mathbf{W}_g} \mathcal{L} = \nabla_{\mathbf{W}_g} \mathcal{L}_{\text{main}} + \nabla_{\mathbf{W}_g} \mathcal{L}_{\text{aux}} $$

A. Gradient from Auxiliary Loss ($\nabla_{\mathbf{W}_g} \mathcal{L}_{\text{aux}}$)
This is a standard backpropagation step, since $\mathcal{L}_{\text{aux}}$ is designed to be fully differentiable with respect to $\mathbf{W}_g$. The exact calculation depends on the specific form of the balancing loss; we take the DeepSeek approach as an example.
The DeepSeek MoE architecture, specifically models like DeepSeek-V3, utilizes an Auxiliary-Loss-Free Load Balancing Strategy (Loss-Free Balancing). The goal of this strategy is to maintain balanced expert utilization without introducing a differentiable auxiliary loss term $\mathcal{L}_{\text{aux}}$ into the total loss $\mathcal{L}$.
$$\mathcal{L} = \mathcal{L}_{\text{main}} \quad \text{(Auxiliary loss term } \mathcal{L}_{\text{aux}} \text{ is } \mathbf{0})$$

Consequently, the gradient contribution from a dedicated load balancing loss to the router parameters $\mathbf{W}_g$ is zero:
$$\nabla_{\mathbf{W}_g} \mathcal{L}_{\text{aux}} = \mathbf{0}$$

B. Gradient from Main Loss ($\nabla_{\mathbf{W}_g} \mathcal{L}_{\text{main}}$)
This is the challenging part due to the Top-$K$ mask $\mathbf{M}$. The mask is discrete and thus has a zero gradient almost everywhere.
The application of the Chain Rule yields the gradient via the gates $\mathbf{G}$:
$$ \nabla_{\mathbf{W}_g} \mathcal{L}_{\text{main}} = \frac{\partial \mathcal{L}_{\text{main}}}{\partial \mathbf{y}} \cdot \frac{\partial \mathbf{y}}{\partial \mathbf{G}} \cdot \frac{\partial \mathbf{G}}{\partial \mathbf{h}} \cdot \frac{\partial \mathbf{h}}{\partial \mathbf{W}_g} $$

The term $\frac{\partial \mathbf{y}}{\partial \mathbf{G}}$ involves the derivative of the output with respect to the gates:
$$ \begin{align*} \frac{\partial \mathbf{y}}{\partial G_i} &= \frac{\partial}{\partial G_i}\left[\sum_{j=1}^{N} M_j \cdot G_j \cdot \mathbf{E}_j\right]\\ &= \mathbf{M}_i(\mathbf{x}) \cdot \mathbf{E}_i(\mathbf{x}) + G_i(\mathbf{x}) \cdot \frac{\partial \mathbf{M}_i(\mathbf{x})}{\partial G_i(\mathbf{x})} \cdot \mathbf{E}_i(\mathbf{x}) \end{align*} $$

keeping only the $j = i$ term: $G_j$ and $\mathbf{E}_j$ for $j \neq i$ do not depend on $G_i$, and the cross-dependence of $M_j$ on $G_i$ through the ranking is ignored. The non-differentiable factor is $\frac{\partial \mathbf{M}_i(\mathbf{x})}{\partial G_i(\mathbf{x})}$.
The Conventional Solution: Straight-Through Estimator (STE)
In practice, the Top-$K$ operation is commonly handled by ignoring its non-differentiable component during the backward pass. This is a form of a Straight-Through Estimator (STE) or a surrogate gradient method.
- Forward Pass: The Top-$K$ function discretely selects the experts and the mask $\mathbf{M}_i(\mathbf{x})$ is computed.
- Backward Pass: The gradient calculation bypasses the non-differentiable selection: the mask is treated as a constant, so the gates $G_i(\mathbf{x})$ act as plain multipliers on the selected experts’ outputs, with no dependence of the selection itself on the parameters.
The approximate gradient for $\mathcal{L}_{\text{main}}$ is thus simplified to:
$$ \begin{align*} \nabla_{\mathbf{W}_g} \mathcal{L}_{\text{main}} &= \frac{\partial \mathcal{L}_{\text{main}}}{\partial \mathbf{y}} \cdot \frac{\partial \mathbf{y}}{\partial \mathbf{G}} \cdot \frac{\partial \mathbf{G}}{\partial \mathbf{h}} \cdot \frac{\partial \mathbf{h}}{\partial \mathbf{W}_g}\\ &\approx \sum_{i=1}^{N} \frac{\partial \mathcal{L}_{\text{main}}}{\partial \mathbf{y}} \cdot \mathbf{M}_i \cdot \mathbf{E}_i \cdot \frac{\partial G_i}{\partial \mathbf{h}} \cdot \frac{\partial \mathbf{h}}{\partial \mathbf{W}_g} \tag{Eq. Loss} \end{align*} $$

This approximation ensures that the router receives a gradient signal only from the selected experts ($M_i(\mathbf{x})=1$) and updates its parameters $\mathbf{W}_g$ in the direction that raises the gates (probabilities) of selected experts whenever they reduce the loss.
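The sketch below (a toy PyTorch setup assumed here for illustration) shows that this approximation is what autograd produces by default: `torch.topk` returns integer indices with no gradient path, so the loss differentiates through the gates of the selected experts only, as in (Eq. Loss).

```python
import torch

torch.manual_seed(0)
d, N, k = 8, 4, 2
x = torch.randn(d)
W_g = torch.randn(N, d, requires_grad=True)          # router weights
experts = [torch.nn.Linear(d, d) for _ in range(N)]  # toy experts E_i

h = W_g @ x                     # logits h(x)
G = torch.softmax(h, dim=-1)    # gates G(x)
idx = torch.topk(h, k).indices  # discrete top-k selection: no gradient path
y = sum(G[i] * experts[i](x) for i in idx)

loss = y.pow(2).mean()          # stand-in for L_main
loss.backward()
# the signal enters only through the selected gates G_i, then spreads
# to all logits via the softmax normalization
print(W_g.grad)
```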
3. Parameter Update (Gradient Descent Step)
Finally, the router parameters $\mathbf{W}_g$ are updated using the calculated total gradient and the learning rate $\eta$:
$$ \mathbf{W}_g^{\text{new}} = \mathbf{W}_g^{\text{old}} - \eta \cdot \nabla_{\mathbf{W}_g} \mathcal{L} $$

This update simultaneously aims to:
- Minimize the main task loss by promoting routing to the experts that contribute most to a correct prediction (via $\nabla_{\mathbf{W}_g} \mathcal{L}_{\text{main}}$).
- Encourage balanced expert usage by adjusting the routing probabilities to distribute the load more evenly (via $\nabla_{\mathbf{W}_g} \mathcal{L}_{\text{aux}}$).
4. The Optimization Logic
From (Eq. Loss) we notice that the main-loss gradient $\nabla_{\mathbf{W}_g} \mathcal{L}_{\text{main}}$ trains the Router $\mathbf{h}$ only through the experts selected by $\mathbf{M}(\mathbf{x})$. That is, the Router’s weights are trained on whichever Experts were chosen, whether or not those choices were good: throughout training, the Router never actively chooses an Expert, but only gradually converges toward the best Expert for each scenario present in the training data. This is why the seemingly unnatural Load Balancing mechanism is required to keep the scheme functional, which we discuss further below.
Counterintuitive as this may seem, the mechanism still works in a mathematically sound way. We start with a dataset:
$$ \mathcal{D} = \{(x_i, y_i)\}_{i=1}^N $$

and keep the form of the MoE layer fixed:

$$ \mathbf{y}_i(\mathbf{x}_i) = \sum_{j=1}^{N} \mathbf{M}_j(\mathbf{x}_i) \cdot G_j(\mathbf{x}_i) \cdot \mathbf{E}_j(\mathbf{x}_i) \tag{Eq. MoE} $$

where $\mathbf{x}_i$ is the output of the layers before the MoE with respect to $x_i$, and $\mathbf{y}_i(\mathbf{x}_i)$ is the output of the MoE layer given $\mathbf{x}_i$.
Our goal is to minimize the empirical risk:
$$ \mathcal{J}(\theta) = \frac{1}{N} \sum_{i=1}^{N} \ell(f_{\theta}(x_i), y_i) $$

where $\ell(\cdot,\cdot)$ is the loss function (e.g., MSE, cross-entropy).
A. Full-Batch Gradient Descent
In vanilla gradient descent, we update parameters using the gradient computed over all samples:
$$ \nabla_\theta \mathcal{J}(\theta) = \frac{1}{N} \sum_{i=1}^{N} \nabla_\theta \ell(f_\theta(x_i), y_i) $$

and the update rule is:
$$ \theta_{t+1} = \theta_t - \eta \nabla_\theta \mathcal{J}(\theta_t) $$

where $\eta$ is the learning rate.
This gives an exact gradient of the empirical risk but is computationally expensive — $O(N)$ per iteration.
B. Mini-Batch Gradient Descent
To save computation, we approximate the full gradient by averaging over a subset (mini-batch) of size $m \ll N$:
Let $B_t \subset \{1,\dots,N\}$ be a randomly sampled batch at iteration $t$. Then the mini-batch gradient is:
$$ g_t = \frac{1}{m} \sum_{i \in B_t} \nabla_\theta \ell(f_\theta(x_i), y_i) $$

and the update rule becomes:
$$ \theta_{t+1} = \theta_t - \eta g_t $$

Then, since each sample is drawn uniformly from the dataset,

$$ \mathbb{E}_{B_t}\left[ \frac{1}{m} \sum_{i \in B_t} \nabla_\theta \ell(f_\theta(x_i), y_i) \right] = \frac{1}{N} \sum_{i=1}^{N} \nabla_\theta \ell(f_\theta(x_i), y_i), \qquad \text{i.e.,}\quad \mathbb{E}_{B_t}[g_t] = \nabla_\theta \mathcal{J}(\theta_t). $$

That is, the mini-batch gradient $g_t$ is an unbiased estimator of the true gradient. Thus, even though each batch only uses part of the data, in expectation it moves in the same direction as the true gradient.
This reduces the computational cost to $O(m)$ per iteration.
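As a quick numerical sanity check of the unbiasedness claim, consider a toy least-squares problem (all names and sizes below are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
N, d, m = 1000, 5, 32
X = rng.normal(size=(N, d))
t = X @ rng.normal(size=d)      # targets from a ground-truth linear model
theta = rng.normal(size=d)      # current parameters

def grad(idx):
    """Gradient of the MSE loss averaged over the rows in idx."""
    r = X[idx] @ theta - t[idx]
    return 2.0 * X[idx].T @ r / len(idx)

full = grad(np.arange(N))       # exact gradient of J(theta), O(N) per step
batches = [grad(rng.choice(N, m, replace=False)) for _ in range(20000)]
est = np.mean(batches, axis=0)  # empirical estimate of E[g_t]
print(np.abs(full - est).max()) # small: g_t is unbiased, at O(m) per batch
```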
C. MoE Gradient Descent
Given the statements above and equations (Eq. Loss) and (Eq. MoE), we need to minimize:

$$ \begin{align*} \mathcal{J}(\theta) &= \frac{1}{N} \sum_{i=1}^{N} \ell(f_{\theta}(x_i), y_i)\\ &= \frac{1}{N} \sum_{i=1}^{N} \ell(\mathcal{F}(\mathbf{y}_i(\mathbf{x}_i)), y_i)\\ &= \frac{1}{N} \sum_{i=1}^{N} \ell\left(\mathcal{F}\left(\sum_{j=1}^{N} \mathbf{M}_j(\mathbf{x}_i) \cdot G_j(\mathbf{x}_i) \cdot \mathbf{E}_j(\mathbf{x}_i)\right), y_i\right) \end{align*} $$

where $\mathcal{F}$ represents the layers after the MoE.

Focusing on the MoE part: initially this layer chooses some Experts essentially at random, and gradient descent then proceeds from the resulting computation. Without loss of generality, denote the chosen Experts by the indices $\{1, 2, 3, 4\}$, where the first two are Shared Experts and the others Routed. By the definition of the Router, the descent step dedicated to the Routed Experts can be written as:
$$ \nabla_{\theta} \mathcal{L}_{\text{main}} \approx \frac{\partial \mathcal{L}_{\text{main}}}{\partial \mathbf{y}} \cdot \mathbf{M}_i \cdot \mathbf{E}_i \cdot \frac{\partial G_i}{\partial \mathbf{h}} \cdot \frac{\partial \mathbf{h}}{\partial \theta}, \quad \text{where } i \in \{3, 4\}, $$

and the backward pass then applies the update:
$$ \theta^{\text{new}} = \theta^{\text{old}} - \eta \cdot \nabla_{\theta} \mathcal{L}_{\mathbf{E}_{3, 4}\text{-dedicated}} $$

Without any balancing mechanism, the routing probability of the Routed Experts $\{3, 4\}$ keeps growing. Almost inevitably, the multi-Expert mechanism then collapses onto a handful of Experts, which defeats its purpose.
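The following deliberately simplified simulation illustrates this rich-get-richer dynamic. It assumes top-1 routing, frozen experts, and, optimistically, that the selected expert always “helped” (so its gate is always pushed up, as discussed under (Eq. Loss)); all constants are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
N, d, steps, lr = 8, 16, 5000, 0.05
W_g = 0.01 * rng.normal(size=(N, d))  # near-symmetric router init
counts = np.zeros(N, dtype=int)

for _ in range(steps):
    x = rng.normal(size=d)
    h = W_g @ x
    G = np.exp(h - h.max()); G /= G.sum()
    i = int(np.argmax(h))             # top-1 selection
    counts[i] += 1
    # simplifying assumption: the chosen expert reduced the loss, so the
    # router is pushed toward it (gradient of -log G_i w.r.t. the logits)
    grad_h = G.copy(); grad_h[i] -= 1.0
    W_g -= lr * np.outer(grad_h, x)

print(counts)  # the load concentrates on a few experts: router collapse
```

Selected rows of `W_g` accumulate norm in the directions where they already win, which makes them win more often: exactly the positive feedback loop that balancing must counteract.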
Therefore, balancing mechanisms are needed, which leads to the introduction of auxiliary balancing algorithms. Though they differ in form, these algorithms all aim at the same expectation, captured by the following equation:

$$ \mathbb{E}_{B_t} \left[\frac{1}{mN} \sum_{i \in B_t} \sum_{j=1}^{N} \mathbf{M}_j(\mathbf{x}_i) \cdot \mathbf{E}_j(\mathbf{x}_i) \cdot \frac{\partial \mathbf{G}_j(\mathbf{x}_i)}{\partial \mathbf{h}_j(\mathbf{x}_i)} \cdot \frac{\partial \mathbf{h}_j(\mathbf{x}_i)}{\partial \theta}\right] = \frac{1}{N} \sum_{i=1}^{N} \nabla_{\theta} \mathbf{E}_{*}(\mathbf{x}_i) $$

where $\mathbf{E}_{*}$ is the aggregated Expert, sharing the same architecture as a single Expert. The equation holds under the assumption that the batch is large enough.
Appendices
A. Softmax Differentiation
Given
$$ \begin{equation*} \text{softmax}(h_i) = \frac{\exp(h_i)}{\sum_{j=1}^{N}\exp(h_j)}, \end{equation*} $$

we have
$$ \begin{align*} \frac{\partial \text{softmax}(h_i)}{\partial h_i} &= \frac{\exp(h_i) \cdot \sum_{j=1}^{N} \exp(h_j) - \exp(h_i) \cdot \exp(h_i)}{\left[\sum_{j=1}^{N}\exp(h_j)\right]^2}\\ &= \frac{\exp(h_i) \cdot \sum_{j \neq i}^{N}\exp(h_j)}{\left[\sum_{j=1}^{N}\exp(h_j)\right]^2}\\ &= \text{softmax}(h_i) \cdot \left(1 - \text{softmax}(h_i)\right). \end{align*} $$
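A quick finite-difference check of this diagonal entry (toy values, purely illustrative):

```python
import numpy as np

def softmax(h):
    z = np.exp(h - h.max())
    return z / z.sum()

h = np.array([0.5, -1.2, 2.0, 0.3])
i, eps = 2, 1e-6
s = softmax(h)
analytic = s[i] * (1 - s[i])      # the diagonal entry derived above

# central finite difference of softmax(h)_i with respect to h_i
e = np.zeros_like(h); e[i] = eps
numeric = (softmax(h + e)[i] - softmax(h - e)[i]) / (2 * eps)
print(analytic, numeric)          # the two values agree to high precision
```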
B. DeepSeek Auxiliary-Loss-Free Load Balancing
The DeepSeek Router Update Mechanism
Instead of relying on a gradient from an auxiliary loss, the DeepSeek strategy achieves load balancing by directly manipulating the routing scores (logits) via a non-gradient-based, dynamic expert-wise bias. This operation happens before the Top-$K$ routing decision in the forward pass and does not contribute to the backpropagated gradient.
1. Routing Score Calculation with Bias:
The original routing score (logit) $h_i(\mathbf{x})$ is modified by an expert-specific bias $b_i$ to produce a biased score $h'_i(\mathbf{x})$:
$$h'_i(\mathbf{x}) = h_i(\mathbf{x}) + b_i$$

where $h_i(\mathbf{x}) = \mathbf{w}_i \cdot \mathbf{x}$ (for the $i$-th expert).
The final gates $G_i(\mathbf{x})$ are then computed using Softmax on the biased scores $h'_i(\mathbf{x})$.
2. Dynamic Bias Update (Non-Gradient Step):
After each forward/backward pass (or epoch), the expert-wise bias $b_i$ for each expert $i$ is dynamically updated based on the expert’s recent load.
- Let $c_i$ be the load (count of tokens routed to expert $i$) in the current batch.
- Let $\bar{c}$ be the average desired load per expert ($\bar{c} = \text{Total Tokens} / N$).
- The load violation error $e_i$ is calculated: $$e_i = c_i - \bar{c}$$
- The bias $b_i$ is updated using a simple sign-based update rule: $$b_i^{\text{new}} = b_i^{\text{old}} - \mu \cdot \text{sign}(e_i)$$ where $\mu$ is a small, positive, non-trainable step size (or update rate).
This rule enforces balance:
- If expert $i$ is overloaded ($e_i > 0$), $\text{sign}(e_i) = +1$, and the bias $b_i$ is decreased. A lower $b_i$ reduces the routing score $h'_i(\mathbf{x})$, making the expert less likely to be selected in future batches.
- If expert $i$ is underloaded ($e_i < 0$), $\text{sign}(e_i) = -1$, and the bias $b_i$ is increased, making the expert more likely to be selected.
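A minimal sketch of this update rule (hypothetical names; the step size $\mu$ is illustrative):

```python
import numpy as np

def update_bias(b, counts, mu=1e-3):
    """Sign-based, non-gradient update of the expert-wise bias b.

    b: (N,) current biases; counts: (N,) tokens routed to each expert
    in the current batch; mu: small fixed step size.
    """
    c_bar = counts.sum() / len(counts)  # desired average load per expert
    e = counts - c_bar                  # load violation error e_i
    return b - mu * np.sign(e)          # overloaded: bias down; underloaded: up

# in the next forward pass the biased scores h' = h + b drive the top-K
# decision; b is not trainable, so no gradient ever flows through it
```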
Conclusion: The load balancing is achieved through a non-gradient, heuristic update of the bias terms $b_i$, which are not part of the trainable router weights $\mathbf{W}_g$ optimized by the main gradient descent step. Thus, the gradient contribution $\nabla_{\mathbf{W}_g} \mathcal{L}_{\text{aux}}$ is conceptually and practically eliminated.
The overall router parameter update simplifies to:
$$\mathbf{W}_g^{\text{new}} = \mathbf{W}_g^{\text{old}} - \eta \cdot \nabla_{\mathbf{W}_g} \mathcal{L}_{\text{main}}$$

where the load-balancing effect is applied externally via the bias $b_i$.