Implementation Guide

Preliminary: The Original DeepSeek-V3 Paper

Auxiliary-Loss-Free Load Balancing. For MoE models, an unbalanced expert load will lead to routing collapse (Shazeer et al., 2017) and diminish computational efficiency in scenarios with expert parallelism. Conventional solutions usually rely on the auxiliary loss (Fedus et al., 2021; Lepikhin et al., 2021) to avoid unbalanced load. However, too large an auxiliary loss will impair the model performance (Wang et al., 2024a). To achieve a better trade-off between load balance and model performance, we pioneer an auxiliary-loss-free load balancing strategy (Wang et al., 2024a) to ensure load balance. To be specific, we introduce a bias term $b_i$ for each expert and add it to the corresponding affinity scores $s_{i, t}$ to determine the top-K routing:

$$ g'_{i,t} = \begin{cases} s_{i,t}, & \text{if } s_{i,t} + b_i \in \mathrm{TopK}\left(\{ s_{j,t} + b_j \mid 1 \leq j \leq N_r \}, K_r \right) \\ 0, & \text{otherwise} \end{cases} $$

Note that the bias term is only used for routing. The gating value, which will be multiplied with the FFN output, is still derived from the original affinity score $s_{i, t}$. During training, we keep monitoring the expert load on the whole batch of each training step. At the end of each step, we will decrease the bias term by $\gamma$ if its corresponding expert is overloaded, and increase it by $\gamma$ if its corresponding expert is underloaded, where $\gamma$ is a hyper-parameter called bias update speed. Through the dynamic adjustment, DeepSeek-V3 keeps balanced expert load during training, and achieves better performance than models that encourage load balance through pure auxiliary losses.
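
To make the routing rule concrete, here is a tiny PyTorch instantiation for a single token (the numbers and variable names are illustrative, not from the paper):

import torch

# Toy affinity scores s_{i,t} for one token over N_r = 4 routed experts, with K_r = 2
s = torch.tensor([0.30, 0.28, 0.25, 0.17])   # original affinity scores s_{i,t}
b = torch.tensor([-0.05, 0.00, 0.06, 0.00])  # per-expert bias terms b_i (used for routing only)

# Routing: Top-K is taken over the *biased* scores, so the selected set can change
topk_idx = torch.topk(s + b, k=2).indices     # here experts {1, 2} instead of the unbiased {0, 1}

# Gating: g'_{i,t} is still the *original* score for selected experts, 0 for the rest
g_prime = torch.zeros_like(s)
g_prime[topk_idx] = s[topk_idx]
print(topk_idx, g_prime)                      # tensor([2, 1]) tensor([0.0000, 0.2800, 0.2500, 0.0000])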


Let’s break down how to integrate this mechanism into our GROP PyTorch Lightning model.

1. The Core Concept: Auxiliary-Loss-Free Load Balancing

First, let’s internalize the logic quoted above:

  • Problem: In a Mixture of Experts (MoE) model, a “router” network sends each input token to one or more “expert” sub-networks. If the router consistently sends too many tokens to a few popular experts (“routing collapse”), we lose the computational benefit of parallelism and the model’s capacity is underutilized.
  • Traditional Solution: Add an auxiliary loss term that penalizes imbalance. This forces the router to learn to distribute the load. However, this loss can interfere with the primary task’s loss (e.g., our GROP loss), potentially harming performance.
  • DeepSeek’s Solution: Decouple the balancing mechanism from the model’s learning gradients.
    1. Biased Routing: Introduce a non-trainable bias term b_i for each expert i. The routing decision (which experts to pick) is made using logits + bias.
    2. Original Gating: The actual weight applied to an expert’s output is still calculated from the original, unbiased logits. This is crucial: the balancing mechanism affects which expert is chosen, but not how much its output is valued.
    3. Manual Bias Update: At the end of each training step, manually adjust the biases (see the short sketch after this list).
      • If an expert was overloaded, decrease its bias. This makes it slightly less likely to be chosen in the next step.
      • If an expert was underloaded, increase its bias. This makes it slightly more likely to be chosen.
    4. Result: A dynamic, self-correcting system that balances load without adding a conflicting loss term to the main optimization objective.
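
The update rule in step 3 amounts to two masked additions. A minimal sketch, assuming `expert_load` holds the per-expert token counts from the last step and `gamma` is the bias update speed (toy values below, illustrative only):

import torch

num_experts, gamma = 8, 1e-2
bias = torch.zeros(num_experts)  # non-trainable per-expert bias b_i
expert_load = torch.tensor([40., 10., 35., 15., 20., 20., 30., 30.])  # tokens routed per expert last step

with torch.no_grad():
    ideal = expert_load.mean()            # load each expert would get under perfect balance
    bias[expert_load > ideal] -= gamma    # overloaded experts become less likely to be picked
    bias[expert_load < ideal] += gamma    # underloaded experts become more likely to be picked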

2. Implementation Strategy within PyTorch Lightning

We will modify our previous GROP implementation by replacing a standard Feed-Forward Network (FFN) layer with a custom MoE layer that incorporates this balancing logic.

Here’s the plan:

  1. Create an Expert Module: A simple MLP that will serve as our expert network.
  2. Create the MoELayer Module: This is the heart of the new logic. It will contain:
    • A gating network (the router).
    • A ModuleList of Experts.
    • The load-balancing bias tensor, registered as a buffer.
    • The forward method implementing the biased Top-K routing, original-logit gating, and token dispatching.
  3. Create an ActorCriticMoE Network: This will be a new version of our ActorCritic network, which uses the MoELayer instead of a standard nn.Linear.
  4. Update the GROPLightningModule:
    • It will now use the ActorCriticMoE network.
    • The training_step will collect expert load statistics throughout an episode.
    • We will use the PyTorch Lightning hook on_train_batch_end to perform the manual bias update after each optimizer step. This is the perfect place for logic that shouldn’t be part of the autograd graph.

3. The “Best Version” Code with MoE Load Balancing

Here is the complete, modified code. I’ve heavily commented the new sections to explain what’s happening.

import gym  # classic Gym API (reset() returns obs, step() returns a 4-tuple); adjust the env calls if you use gymnasium
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

# --- MoE Components ---

class Expert(nn.Module):
    """A simple two-layer MLP to serve as an expert."""
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, x):
        return self.net(x)


class MoELayer(nn.Module):
    """
    The core Mixture of Experts layer with Auxiliary-Loss-Free Load Balancing.
    """
    def __init__(self, input_dim, output_dim, num_experts, top_k):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_experts = num_experts
        self.top_k = top_k

        # The router/gating network that produces affinity scores (logits)
        self.gating_network = nn.Linear(input_dim, num_experts)

        # The list of expert networks
        self.experts = nn.ModuleList(
            [Expert(input_dim, input_dim * 2, output_dim) for _ in range(num_experts)]
        )

        # The load-balancing bias term. Registered as a buffer so it's part of the
        # module's state (saved with checkpoints) but not considered a trainable parameter by optimizers.
        self.register_buffer('bias', torch.zeros(1, num_experts))
        # We also need to track the load for the update rule
        self.register_buffer('expert_load', torch.zeros(num_experts))

    def forward(self, x):
        """
        x: input tensor of shape (batch_size, input_dim)
        """
        batch_size, _ = x.shape

        # 1. Calculate affinity scores from the router
        original_logits = self.gating_network(x)  # shape: (batch_size, num_experts)

        # 2. Add the balancing bias for routing decisions ONLY
        biased_logits = original_logits + self.bias

        # 3. Perform Top-K routing on biased logits
        # The top-k values themselves are discarded; `top_k_indices` are what we need
        _, top_k_indices = torch.topk(biased_logits, self.top_k, dim=1)  # shape: (batch_size, top_k)

        # 4. Calculate gating values from ORIGINAL logits for the selected experts
        # We use `gather` to select the original logits corresponding to the chosen experts
        selected_original_logits = torch.gather(original_logits, 1, top_k_indices)
        gating_weights = F.softmax(selected_original_logits, dim=1)  # shape: (batch_size, top_k)

        # 5. Dispatch tokens to experts and combine outputs (a dispatch-and-scatter operation)
        final_output = torch.zeros(batch_size, self.output_dim, device=x.device)

        # Create a flat list of (token, expert) routing slots and their gating weights
        flat_top_k_indices = top_k_indices.flatten()    # (batch_size * top_k)
        flat_gating_weights = gating_weights.flatten()  # (batch_size * top_k)

        # Track the load for the bias update rule
        # `bincount` is a highly efficient way to count occurrences
        self.expert_load = torch.bincount(flat_top_k_indices, minlength=self.num_experts)

        # `repeat_interleave` creates the (batch_size * top_k, dim) tensor of inputs and the
        # matching map back to original batch positions, computed once for all experts
        expert_inputs_all = x.repeat_interleave(self.top_k, dim=0)
        batch_index_all = torch.arange(batch_size, device=x.device).repeat_interleave(self.top_k)

        # Loop over experts to process tokens in batches
        for i in range(self.num_experts):
            # Find which routing slots (from the repeated batch) go to this expert
            expert_mask = flat_top_k_indices == i  # shape: (batch_size * top_k)

            if expert_mask.sum() == 0:
                continue

            # Process the selected inputs through the expert
            expert_outputs = self.experts[i](expert_inputs_all[expert_mask])

            # Get the gating weights for these specific tokens
            # shape: (num_tokens_for_this_expert, 1)
            weights_for_expert = flat_gating_weights[expert_mask].unsqueeze(1)

            # Weight the expert outputs
            weighted_outputs = expert_outputs * weights_for_expert

            # Scatter the results back to the correct original batch positions
            # `index_add_` is a memory-efficient scatter-add operation.
            final_output.index_add_(0, batch_index_all[expert_mask], weighted_outputs)

        return final_output, self.expert_load

# --- Updated Actor-Critic Network with MoE ---

class ActorCriticMoE(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=128, num_experts=8, top_k=2):
        super().__init__()
        self.shared_base = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU()
        )

        self.moe_layer = MoELayer(
            input_dim=hidden_dim,
            output_dim=hidden_dim,
            num_experts=num_experts,
            top_k=top_k
        )

        self.actor_head = nn.Linear(hidden_dim, action_dim)
        self.critic_head = nn.Linear(hidden_dim, 1)

    def forward(self, state):
        base_features = self.shared_base(state)
        moe_output, expert_load = self.moe_layer(base_features)

        # Add a residual connection (common practice)
        x = F.relu(moe_output + base_features)

        action_logits = self.actor_head(x)
        state_value = self.critic_head(x)

        # We must return the expert_load to be used by the LightningModule
        return action_logits, state_value, expert_load

# --- The GROP LightningModule with MoE Load Balancing ---

class GROPLightningMoEModule(pl.LightningModule):
    def __init__(
        self,
        env_name: str = "CartPole-v1",
        lr: float = 3e-4,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        critic_loss_coeff: float = 0.5,
        entropy_coeff: float = 0.01,
        # MoE specific hparams
        num_experts: int = 8,
        top_k: int = 2,
        bias_update_speed: float = 1e-2,
    ):
        super().__init__()
        self.save_hyperparameters()

        self.env = gym.make(self.hparams.env_name)
        state_dim = self.env.observation_space.shape[0]
        action_dim = self.env.action_space.n

        self.total_reward = 0
        self.episode_count = 0

        self.network = ActorCriticMoE(
            state_dim, action_dim,
            num_experts=self.hparams.num_experts,
            top_k=self.hparams.top_k
        )

    # `forward` needs to handle the multiple return values from the network
    def forward(self, state):
        return self.network(state)

    def get_action(self, state):
        # Sampling an action for the environment needs no gradients; the forward
        # pass used for the loss happens separately in training_step.
        with torch.no_grad():
            state_tensor = torch.tensor(state, dtype=torch.float32, device=self.device).unsqueeze(0)
            logits, _, _ = self.forward(state_tensor)
            dist = Categorical(logits=logits)
            action = dist.sample()
        return action.item()

    def _calculate_advantages_gae(self, rewards, values, dones):
        advantages = []
        gae = 0.0
        next_value = 0.0  # terminal bootstrap; the done mask zeroes it out on the final step
        # Walk backwards over every transition so we get one advantage per step,
        # matching the length of `values`.
        for i in reversed(range(len(rewards))):
            reward, done_mask, value = rewards[i], 1.0 - dones[i], values[i]
            delta = reward + self.hparams.gamma * next_value * done_mask - value
            gae = delta + self.hparams.gamma * self.hparams.gae_lambda * done_mask * gae
            advantages.insert(0, gae)
            next_value = value
        return torch.tensor(advantages, dtype=torch.float32, device=self.device)

    def training_step(self, batch, batch_idx):
        # 1. Generate Experience
        state = self.env.reset()
        done = False
        states, actions, rewards, dones, logits_list, values = [], [], [], [], [], []

        # We need to accumulate expert load over the whole episode
        total_expert_load_in_episode = torch.zeros(self.hparams.num_experts, device=self.device)

        while not done:
            action = self.get_action(state)

            state_tensor = torch.tensor(state, dtype=torch.float32, device=self.device).unsqueeze(0)
            # Capture all outputs from the forward pass
            logits, value, expert_load = self.forward(state_tensor)

            next_state, reward, done, _ = self.env.step(action)

            states.append(state)
            actions.append(action)
            rewards.append(reward)
            dones.append(done)
            logits_list.append(logits)
            values.append(value)
            total_expert_load_in_episode += expert_load

            state = next_state

        self.total_reward = sum(rewards)
        self.episode_count += 1
        self.log("episode_reward", self.total_reward, on_step=True, on_epoch=True, prog_bar=True)

        # 2. Process Experience and Calculate Losses
        actions_v = torch.tensor(actions, dtype=torch.int64, device=self.device).view(-1, 1)
        values_v = torch.cat(values).squeeze(-1)
        all_logits = torch.cat(logits_list)

        advantages_v = self._calculate_advantages_gae(rewards, values_v.cpu().detach().numpy(), dones)
        returns_v = advantages_v + values_v.detach()
        advantages_v = (advantages_v - advantages_v.mean()) / (advantages_v.std() + 1e-8)

        # 3. GROP Actor Loss (unchanged)
        chosen_action_logits = all_logits.gather(1, actions_v).squeeze(-1)
        log_sum_exp_term = torch.logsumexp(all_logits - chosen_action_logits.unsqueeze(-1), dim=1)
        actor_loss = (advantages_v * log_sum_exp_term).mean()

        # 4. Critic Loss (unchanged)
        critic_loss = F.mse_loss(values_v, returns_v)

        # 5. Entropy Bonus (unchanged)
        dist = Categorical(logits=all_logits)
        entropy_loss = dist.entropy().mean()

        # 6. Total Loss
        total_loss = (
            actor_loss
            + self.hparams.critic_loss_coeff * critic_loss
            - self.hparams.entropy_coeff * entropy_loss
        )

        # Log the losses; the expert load is logged in `on_train_batch_end`, where it is also used
        self.log_dict({
            "total_loss": total_loss,
            "actor_loss": actor_loss,
            "critic_loss": critic_loss
        }, on_step=True, on_epoch=True)

        # CRUCIAL: Return the accumulated expert load to the hook
        return {"loss": total_loss, "expert_load": total_expert_load_in_episode}

    def on_train_batch_end(self, outputs, batch, batch_idx):
        """
        This hook is called after the training step and the optimizer step.
        This is where we implement the auxiliary-loss-free update rule.
        """
        expert_load = outputs['expert_load']  # This comes from the training_step's return dict
        total_routed_slots = expert_load.sum()

        if total_routed_slots == 0:
            return  # Avoid division by zero if an episode was empty

        # Ideal load: expert_load already counts (token, expert) routing slots, i.e. tokens * top_k
        # in total, so a perfectly balanced layer assigns total_routed_slots / num_experts to each expert.
        ideal_load_per_expert = total_routed_slots / self.hparams.num_experts

        # Find overloaded and underloaded experts
        overloaded_mask = expert_load > ideal_load_per_expert
        underloaded_mask = expert_load < ideal_load_per_expert

        # Update the bias tensor. This happens outside the autograd graph.
        with torch.no_grad():
            self.network.moe_layer.bias[0, overloaded_mask] -= self.hparams.bias_update_speed
            self.network.moe_layer.bias[0, underloaded_mask] += self.hparams.bias_update_speed

        # Log for monitoring
        load_percentage = expert_load / total_routed_slots
        for i in range(self.hparams.num_experts):
            self.log(f"expert_{i}/load_pct", load_percentage[i], on_step=True, on_epoch=False)
            self.log(f"expert_{i}/bias", self.network.moe_layer.bias[0, i], on_step=True, on_epoch=False)

        # Log load variance as a single metric for balance
        load_variance = torch.var(load_percentage)
        self.log("expert_load_variance", load_variance, on_step=True, on_epoch=False)

    def configure_optimizers(self):
        optimizer = optim.Adam(self.network.parameters(), lr=self.hparams.lr)
        return optimizer

    def train_dataloader(self):
        # Dummy single-item "dataloader": each batch simply triggers one on-policy
        # episode inside training_step; the real data comes from the environment.
        return [0]

# --- Main Training Script ---
if __name__ == '__main__':
    pl.seed_everything(42)

    # --- Configuration ---
    ENV_NAME = "CartPole-v1"
    MAX_EPOCHS = 1000  # MoE might take a bit longer to stabilize

    # --- Setup ---
    model = GROPLightningMoEModule(
        env_name=ENV_NAME,
        lr=3e-4,
        gamma=0.99,
        gae_lambda=0.95,
        num_experts=8,
        top_k=2,
        bias_update_speed=0.01
    )

    logger = TensorBoardLogger("tb_logs", name="grop_moe_cartpole")
    checkpoint_callback = ModelCheckpoint(monitor="episode_reward", dirpath="checkpoints_moe/", mode="max")

    trainer = pl.Trainer(
        max_epochs=MAX_EPOCHS,
        logger=logger,
        callbacks=[checkpoint_callback],
        accelerator="auto",
        gradient_clip_val=0.5  # Good practice for more complex models
    )

    trainer.fit(model)
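
Before a full training run, a quick shape-and-count sanity check of the MoELayer on random inputs can catch wiring mistakes early. A small sketch, assuming it is run in the same file or session as the listing above:

import torch

layer = MoELayer(input_dim=16, output_dim=16, num_experts=8, top_k=2)
x = torch.randn(32, 16)        # a fake batch of 32 "tokens"
out, load = layer(x)

print(out.shape)               # torch.Size([32, 16])
print(load, load.sum())        # per-expert routing counts; they sum to 32 * top_k = 64
assert out.shape == (32, 16)
assert load.sum().item() == 32 * 2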

4. How to Run and What to Monitor

  1. Run the script: python your_script_name.py
  2. Open TensorBoard: tensorboard --logdir tb_logs/

Now, in TensorBoard, you can observe the magic of the load balancing mechanism:

  • episode_reward: This should still go up, showing the model is learning the task.
  • expert_load_variance: This is the key metric for balancing. You should see this value decrease over time and stabilize at a very low number. This indicates that the load across experts is becoming more uniform.
  • expert_i/load_pct: Check the individual expert load percentages. Initially, some might be very high and others zero. Over time, they should all converge towards the ideal share (1 / num_experts, i.e. 1/8 = 12.5% in our default case, since the percentages are taken over all routed token-expert slots).
  • expert_i/bias: Watch the bias values. You’ll see the biases of initially overloaded experts going down (becoming negative) and the biases of underloaded experts going up (becoming positive). This is the dynamic adjustment in action.

This implementation integrates DeepSeek-V3’s auxiliary-loss-free MoE load-balancing technique into our GROP training loop, all within the organized framework of PyTorch Lightning, and shows how the balancing logic can live entirely outside the loss function and the autograd graph.