Implementation Guide
Preliminary: The Original DeepSeek-V3 Paper
Auxiliary-Loss-Free Load Balancing. For MoE models, an unbalanced expert load will lead to routing collapse (Shazeer et al., 2017) and diminish computational efficiency in scenarios with expert parallelism. Conventional solutions usually rely on the auxiliary loss (Fedus et al., 2021; Lepikhin et al., 2021) to avoid unbalanced load. However, too large an auxiliary loss will impair the model performance (Wang et al., 2024a). To achieve a better trade-off between load balance and model performance, we pioneer an auxiliary-loss-free load balancing strategy (Wang et al., 2024a) to ensure load balance. To be specific, we introduce a bias term $b_i$ for each expert and add it to the corresponding affinity scores $s_{i, t}$ to determine the top-K routing:
$$ g'_{i,t} = \begin{cases} s_{i,t}, & \text{if } s_{i,t} + b_i \in \mathrm{TopK}\left(\{ s_{j,t} + b_j \mid 1 \leq j \leq N_r \}, K_r \right) \\ 0, & \text{otherwise} \end{cases} $$

Note that the bias term is only used for routing. The gating value, which will be multiplied with the FFN output, is still derived from the original affinity score $s_{i, t}$. During training, we keep monitoring the expert load on the whole batch of each training step. At the end of each step, we will decrease the bias term by $\gamma$ if its corresponding expert is overloaded, and increase it by $\gamma$ if its corresponding expert is underloaded, where $\gamma$ is a hyper-parameter called bias update speed. Through the dynamic adjustment, DeepSeek-V3 keeps balanced expert load during training, and achieves better performance than models that encourage load balance through pure auxiliary losses.
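Before the full Lightning integration in Section 3, here is a minimal PyTorch sketch of just this selection/gating split. The tensor sizes are made up for illustration; like the implementation below, it normalizes the selected original scores with a softmax:

```python
import torch
import torch.nn.functional as F

num_tokens, N_r, K_r = 8, 16, 2      # made-up sizes for illustration
s = torch.rand(num_tokens, N_r)      # affinity scores s_{i,t}
b = torch.zeros(N_r)                 # per-expert biases b_i (non-trainable)

# Expert selection uses the BIASED scores s + b ...
_, topk_idx = torch.topk(s + b, K_r, dim=1)

# ... but the gating values are computed from the ORIGINAL scores s.
gates = F.softmax(torch.gather(s, 1, topk_idx), dim=1)
```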
Let’s break down how to build this mechanism into our GROP PyTorch Lightning model.
1. The Core Concept: Auxiliary-Loss-Free Load Balancing
First, let’s internalize the logic from the paper excerpt above:
- Problem: In a Mixture of Experts (MoE) model, a “router” network sends each input token to one or more “expert” sub-networks. If the router consistently sends too many tokens to a few popular experts (“routing collapse”), we lose the computational benefit of parallelism and the model’s capacity is underutilized.
- Traditional Solution: Add an auxiliary loss term that penalizes imbalance. This forces the router to learn to distribute the load. However, this loss can interfere with the primary task’s loss (e.g., our GROP loss), potentially harming performance.
- DeepSeek’s Solution: Decouple the balancing mechanism from the model’s learning gradients.
  - Biased Routing: Introduce a non-trainable bias term `b_i` for each expert `i`. The routing decision (which experts to pick) is made using `logits + bias`.
  - Original Gating: The actual weight applied to an expert’s output is still calculated from the original, unbiased `logits`. This is crucial: the balancing mechanism affects *which* expert is chosen, but not *how much* its output is valued.
  - Manual Bias Update: At the end of each training step, manually adjust the biases (a toy simulation of this rule follows this list).
    - If an expert was overloaded, decrease its bias. This makes it slightly less likely to be chosen in the next step.
    - If an expert was underloaded, increase its bias. This makes it slightly more likely to be chosen.
  - Result: A dynamic, self-correcting system that balances load without adding a conflicting loss term to the main optimization objective.
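To see that this update rule alone, with no gradients involved, is enough to even out the load, here is a toy, framework-free simulation. Every number in it is made up for illustration: the router has a hard-wired preference for expert 0, and the bias terms gradually cancel it.

```python
import torch

torch.manual_seed(0)
num_experts, top_k, update_speed = 4, 1, 0.05
bias = torch.zeros(num_experts)

# A router with a hard-wired preference for expert 0; per-token noise
# stands in for real input-dependent scores.
base_affinity = torch.tensor([2.0, 1.0, 0.5, 0.2])

for step in range(200):
    scores = base_affinity + 0.5 * torch.randn(256, num_experts)
    # Route on biased scores (a real layer would still gate on `scores`).
    _, chosen = torch.topk(scores + bias, top_k, dim=1)
    load = torch.bincount(chosen.flatten(), minlength=num_experts).float()
    ideal = load.sum() / num_experts
    bias[load > ideal] -= update_speed  # overloaded -> less likely next step
    bias[load < ideal] += update_speed  # underloaded -> more likely next step
    if step in (0, 199):
        print(f"step {step:3d}: load fractions = {(load / load.sum()).tolist()}")
```

At step 0 almost every token lands on expert 0; by the end the fractions hover around 25% each, purely through bias adjustment.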
2. Implementation Strategy within PyTorch Lightning
We will modify our previous GROP implementation by replacing a standard Feed-Forward Network (FFN) layer with a custom MoE layer that incorporates this balancing logic.
Here’s the plan:
- Create an `Expert` module: A simple MLP that will serve as our expert network.
- Create the `MoELayer` module: This is the heart of the new logic. It will contain:
  - A gating network (the router).
  - A `ModuleList` of `Expert`s.
  - The load-balancing `bias` tensor, registered as a buffer.
  - The `forward` method implementing the biased Top-K routing, original-logit gating, and token dispatching.
- Create an `ActorCriticMoE` network: A new version of our `ActorCritic` network, which uses the `MoELayer` instead of a standard `nn.Linear`.
- Update the `GROPLightningModule`:
  - It will now use the `ActorCriticMoE` network.
  - The `training_step` will collect expert load statistics throughout an episode.
  - We will use the PyTorch Lightning hook `on_train_batch_end` to perform the manual bias update after each optimizer step. This is the perfect place for logic that shouldn’t be part of the autograd graph.
3. The “Best Version” Code with MoE Load Balancing
Here is the complete, modified code. I’ve heavily commented the new sections to explain what’s happening.
```python
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

# --- MoE Components ---

class Expert(nn.Module):
    """A simple two-layer MLP to serve as an expert."""
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, x):
        return self.net(x)

class MoELayer(nn.Module):
    """
    The core Mixture of Experts layer with Auxiliary-Loss-Free Load Balancing.
    """
    def __init__(self, input_dim, output_dim, num_experts, top_k):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_experts = num_experts
        self.top_k = top_k

        # The router/gating network that produces affinity scores (logits)
        self.gating_network = nn.Linear(input_dim, num_experts)

        # The list of expert networks
        self.experts = nn.ModuleList(
            [Expert(input_dim, input_dim * 2, output_dim) for _ in range(num_experts)]
        )

        # The load-balancing bias term. Registered as a buffer so it's part of the
        # module's state (saved with checkpoints) but not a trainable parameter.
        self.register_buffer('bias', torch.zeros(1, num_experts))
        # We also track the most recent load for the update rule
        self.register_buffer('expert_load', torch.zeros(num_experts))

    def forward(self, x):
        """
        x: input tensor of shape (batch_size, input_dim)
        """
        batch_size, _ = x.shape

        # 1. Calculate affinity scores from the router
        original_logits = self.gating_network(x)  # (batch_size, num_experts)

        # 2. Add the balancing bias for routing decisions ONLY
        biased_logits = original_logits + self.bias

        # 3. Perform Top-K routing on the biased logits.
        #    The biased values themselves are discarded; only the indices matter.
        _, top_k_indices = torch.topk(biased_logits, self.top_k, dim=1)  # (batch_size, top_k)

        # 4. Calculate gating values from the ORIGINAL logits of the selected experts
        selected_original_logits = torch.gather(original_logits, 1, top_k_indices)
        gating_weights = F.softmax(selected_original_logits, dim=1)  # (batch_size, top_k)

        # 5. Dispatch tokens to experts and combine outputs
        final_output = torch.zeros(batch_size, self.output_dim, device=x.device)

        # Flatten the (token, slot) assignments: (batch_size * top_k,)
        flat_top_k_indices = top_k_indices.flatten()

        # Track the load for the bias update rule. `bincount` counts the
        # assignments per expert; cast to float to keep the buffer dtype.
        self.expert_load = torch.bincount(
            flat_top_k_indices, minlength=self.num_experts
        ).float()

        # Repeat each token once per selected expert so that masks line up
        # with the flattened assignments (computed once, outside the loop)
        expanded_x = x.repeat_interleave(self.top_k, dim=0)  # (batch_size * top_k, input_dim)
        expanded_batch_indices = torch.arange(batch_size, device=x.device).repeat_interleave(self.top_k)
        flat_gating_weights = gating_weights.flatten()

        # Loop over experts to process each one's assigned tokens in a single batch
        for i in range(self.num_experts):
            expert_mask = flat_top_k_indices == i  # (batch_size * top_k,)
            if not expert_mask.any():
                continue

            # Process this expert's tokens
            expert_outputs = self.experts[i](expanded_x[expert_mask])

            # Weight the outputs by the (original-logit) gating values
            weighted_outputs = expert_outputs * flat_gating_weights[expert_mask].unsqueeze(1)

            # Scatter-add the results back to the correct original batch positions.
            # `index_add_` is a memory-efficient scatter-add operation.
            final_output.index_add_(0, expanded_batch_indices[expert_mask], weighted_outputs)

        return final_output, self.expert_load

# --- Updated Actor-Critic Network with MoE ---

class ActorCriticMoE(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=128, num_experts=8, top_k=2):
        super().__init__()
        self.shared_base = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU()
        )

        self.moe_layer = MoELayer(
            input_dim=hidden_dim,
            output_dim=hidden_dim,
            num_experts=num_experts,
            top_k=top_k
        )

        self.actor_head = nn.Linear(hidden_dim, action_dim)
        self.critic_head = nn.Linear(hidden_dim, 1)

    def forward(self, state):
        base_features = self.shared_base(state)
        moe_output, expert_load = self.moe_layer(base_features)

        # Add a residual connection (common practice)
        x = F.relu(moe_output + base_features)

        action_logits = self.actor_head(x)
        state_value = self.critic_head(x)

        # We must return the expert_load to be used by the LightningModule
        return action_logits, state_value, expert_load

# --- The GROP LightningModule with MoE Load Balancing ---

class GROPLightningMoEModule(pl.LightningModule):
    def __init__(
        self,
        env_name: str = "CartPole-v1",
        lr: float = 3e-4,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        critic_loss_coeff: float = 0.5,
        entropy_coeff: float = 0.01,
        # MoE-specific hparams
        num_experts: int = 8,
        top_k: int = 2,
        bias_update_speed: float = 1e-2,
    ):
        super().__init__()
        self.save_hyperparameters()

        # Note: this assumes the classic Gym API (gym < 0.26), where reset()
        # returns the observation and step() returns a 4-tuple.
        self.env = gym.make(self.hparams.env_name)
        state_dim = self.env.observation_space.shape[0]
        action_dim = self.env.action_space.n

        self.total_reward = 0
        self.episode_count = 0

        self.network = ActorCriticMoE(
            state_dim, action_dim,
            num_experts=self.hparams.num_experts,
            top_k=self.hparams.top_k
        )

    # `forward` needs to handle the multiple return values from the network
    def forward(self, state):
        return self.network(state)

    @torch.no_grad()
    def get_action(self, state):
        """Sample an action for inference/evaluation (no graph is kept)."""
        state_tensor = torch.tensor(state, dtype=torch.float32, device=self.device).unsqueeze(0)
        logits, _, _ = self.forward(state_tensor)
        dist = Categorical(logits=logits)
        return dist.sample().item()

    def _calculate_advantages_gae(self, rewards, values, dones):
        # Standard GAE over a full episode. The rollout always runs to
        # termination, so the bootstrap value after the last step is 0.
        advantages = []
        gae = 0.0
        next_value = 0.0
        for i in reversed(range(len(rewards))):
            done_mask = 1.0 - float(dones[i])
            delta = rewards[i] + self.hparams.gamma * next_value * done_mask - values[i]
            gae = delta + self.hparams.gamma * self.hparams.gae_lambda * done_mask * gae
            advantages.insert(0, gae)
            next_value = values[i]
        return torch.tensor(advantages, dtype=torch.float32, device=self.device)

    def training_step(self, batch, batch_idx):
        # 1. Generate Experience (one full episode per "batch")
        state = self.env.reset()
        done = False
        states, actions, rewards, dones, logits_list, values = [], [], [], [], [], []

        # We need to accumulate expert load over the whole episode
        total_expert_load_in_episode = torch.zeros(self.hparams.num_experts, device=self.device)

        while not done:
            state_tensor = torch.tensor(state, dtype=torch.float32, device=self.device).unsqueeze(0)
            # A single forward pass gives us the policy logits, the value
            # estimate, and the expert load for this step
            logits, value, expert_load = self.forward(state_tensor)
            action = Categorical(logits=logits).sample().item()

            next_state, reward, done, _ = self.env.step(action)

            states.append(state)
            actions.append(action)
            rewards.append(reward)
            dones.append(done)
            logits_list.append(logits)
            values.append(value)
            total_expert_load_in_episode += expert_load

            state = next_state

        self.total_reward = sum(rewards)
        self.episode_count += 1
        self.log("episode_reward", self.total_reward, on_step=True, on_epoch=True, prog_bar=True)

        # 2. Process Experience and Calculate Advantages
        actions_v = torch.tensor(actions, dtype=torch.int64, device=self.device).view(-1, 1)
        values_v = torch.cat(values).squeeze(-1)
        all_logits = torch.cat(logits_list)

        advantages_v = self._calculate_advantages_gae(rewards, values_v.detach().cpu().numpy(), dones)
        returns_v = advantages_v + values_v.detach()
        advantages_v = (advantages_v - advantages_v.mean()) / (advantages_v.std() + 1e-8)

        # 3. GROP Actor Loss (unchanged)
        chosen_action_logits = all_logits.gather(1, actions_v).squeeze(-1)
        log_sum_exp_term = torch.logsumexp(all_logits - chosen_action_logits.unsqueeze(-1), dim=1)
        actor_loss = (advantages_v * log_sum_exp_term).mean()

        # 4. Critic Loss (unchanged)
        critic_loss = F.mse_loss(values_v, returns_v)

        # 5. Entropy Bonus (unchanged)
        dist = Categorical(logits=all_logits)
        entropy_loss = dist.entropy().mean()

        # 6. Total Loss
        total_loss = (
            actor_loss
            + self.hparams.critic_loss_coeff * critic_loss
            - self.hparams.entropy_coeff * entropy_loss
        )

        # Log everything
        self.log_dict({
            "total_loss": total_loss,
            "actor_loss": actor_loss,
            "critic_loss": critic_loss
        }, on_step=True, on_epoch=True)

        # CRUCIAL: Return the accumulated expert load so the
        # on_train_batch_end hook can perform the bias update
        return {"loss": total_loss, "expert_load": total_expert_load_in_episode}

    def on_train_batch_end(self, outputs, batch, batch_idx):
        """
        This hook is called after the training step and the optimizer step.
        This is where we implement the auxiliary-loss-free update rule.
        """
        expert_load = outputs['expert_load']  # From training_step's return dict
        # expert_load counts (token, slot) assignments, i.e. tokens * top_k in total
        total_assignments = expert_load.sum()

        if total_assignments == 0:
            return  # Avoid division by zero if an episode was empty

        # Ideal load: each expert receives an equal share of all assignments
        ideal_load_per_expert = total_assignments / self.hparams.num_experts

        # Find overloaded and underloaded experts
        overloaded_mask = expert_load > ideal_load_per_expert
        underloaded_mask = expert_load < ideal_load_per_expert

        # Update the bias tensor. This happens outside the autograd graph.
        with torch.no_grad():
            self.network.moe_layer.bias[0, overloaded_mask] -= self.hparams.bias_update_speed
            self.network.moe_layer.bias[0, underloaded_mask] += self.hparams.bias_update_speed

        # Log for monitoring
        load_percentage = expert_load / total_assignments
        for i in range(self.hparams.num_experts):
            self.log(f"expert_{i}/load_pct", load_percentage[i], on_step=True, on_epoch=False)
            self.log(f"expert_{i}/bias", self.network.moe_layer.bias[0, i], on_step=True, on_epoch=False)

        # Log load variance as a single balance metric
        load_variance = torch.var(load_percentage)
        self.log("expert_load_variance", load_variance, on_step=True, on_epoch=False)

    def configure_optimizers(self):
        return optim.Adam(self.network.parameters(), lr=self.hparams.lr)

    def train_dataloader(self):
        # Dummy dataloader: each "batch" just triggers one episode rollout
        return [0]

# --- Main Training Script ---
if __name__ == '__main__':
    pl.seed_everything(42)

    # --- Configuration ---
    ENV_NAME = "CartPole-v1"
    MAX_EPOCHS = 1000  # MoE might take a bit longer to stabilize

    # --- Setup ---
    model = GROPLightningMoEModule(
        env_name=ENV_NAME,
        lr=3e-4,
        gamma=0.99,
        gae_lambda=0.95,
        num_experts=8,
        top_k=2,
        bias_update_speed=0.01
    )

    logger = TensorBoardLogger("tb_logs", name="grop_moe_cartpole")
    checkpoint_callback = ModelCheckpoint(monitor="episode_reward", dirpath="checkpoints_moe/", mode="max")

    trainer = pl.Trainer(
        max_epochs=MAX_EPOCHS,
        logger=logger,
        callbacks=[checkpoint_callback],
        accelerator="auto",
        gradient_clip_val=0.5  # Good practice for more complex models
    )

    trainer.fit(model)
```
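Before launching a full training run, it's worth sanity-checking the `MoELayer` in isolation. A minimal smoke-test sketch, assuming it runs in a scope where the class above is defined; it verifies output shapes, the assignment count, and one manual bias update of the kind the hook performs:

```python
# Standalone smoke test for MoELayer (assumes the class definition above).
layer = MoELayer(input_dim=128, output_dim=128, num_experts=8, top_k=2)
x = torch.randn(32, 128)

out, load = layer(x)
assert out.shape == (32, 128)
assert int(load.sum()) == 32 * 2  # each token gets exactly top_k assignments

# Apply one step of the manual update rule, exactly as the Lightning hook does.
ideal = load.sum() / 8
with torch.no_grad():
    layer.bias[0, load > ideal] -= 0.01
    layer.bias[0, load < ideal] += 0.01
print("per-expert load:", load.tolist())
print("updated biases:", layer.bias.squeeze().tolist())
```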
4. How to Run and What to Monitor
- Run the script: `python your_script_name.py`
- Open TensorBoard: `tensorboard --logdir tb_logs/`
Now, in TensorBoard, you can observe the magic of the load balancing mechanism:
- `episode_reward`: This should still go up, showing the model is learning the task.
- `expert_load_variance`: This is the key metric for balancing. You should see this value decrease over time and stabilize at a very low number, indicating that the load across experts is becoming more uniform.
- `expert_i/load_pct`: Check the individual expert load percentages. Initially, some might be very high and others zero. Over time, they should all converge towards the ideal share of assignments (`1 / num_experts`, which is `1/8 = 12.5%` in our default case, since the percentages are computed over all `top_k` assignments per token).
- `expert_i/bias`: Watch the bias values. You'll see the biases of initially overloaded experts going down (becoming negative) and the biases of underloaded experts going up (becoming positive). This is the dynamic adjustment in action.
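Once training finishes, you can load the best checkpoint back and watch the policy act. A minimal evaluation sketch, assuming it is appended to the training script above (so `checkpoint_callback` and `ENV_NAME` are in scope) and that the classic Gym API is in use:

```python
# Evaluate the best checkpoint with a few rollouts.
# `best_model_path` is populated by ModelCheckpoint during trainer.fit().
best_model = GROPLightningMoEModule.load_from_checkpoint(checkpoint_callback.best_model_path)
best_model.eval()

env = gym.make(ENV_NAME)
for episode in range(3):
    state, done, episode_return = env.reset(), False, 0.0
    while not done:
        action = best_model.get_action(state)  # samples from the learned policy
        state, reward, done, _ = env.step(action)
        episode_return += reward
    print(f"episode {episode}: return = {episode_return}")
```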
This implementation successfully integrates a cutting-edge MoE load-balancing technique into a strong RL algorithm, all within the organized and powerful framework of PyTorch Lightning. It’s a robust, complex, and highly illustrative example of modern deep learning engineering.