Support router replay#353
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces support for forced_expert_ids in the Mixture of Experts (MoE) gating and forward pass implementations, allowing specific expert routing to be forced. It updates the core and backend MoE operators to accept, gather, and apply these forced expert IDs, and adds corresponding accuracy tests. The review feedback highlights critical improvements: ensuring forced_expert_ids is moved to the correct device before performing collective operations to prevent runtime errors, and validating that the forced expert indices are within the valid range to avoid illegal memory access during gathering.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| if forced_expert_ids is not None: | ||
| full_forced_expert_ids = torch.empty( | ||
| local_tokens * self.ep_size, | ||
| *forced_expert_ids.shape[1:], | ||
| dtype=forced_expert_ids.dtype, | ||
| device=forced_expert_ids.device, | ||
| ) | ||
| dist.all_gather_into_tensor(full_forced_expert_ids, forced_expert_ids.contiguous(), group=self.ep_group) | ||
| forced_expert_ids = full_forced_expert_ids |
There was a problem hiding this comment.
If forced_expert_ids is on CPU (or a different device than hidden_states), creating full_forced_expert_ids on forced_expert_ids.device and calling dist.all_gather_into_tensor will result in a runtime error (especially when using the NCCL backend, which requires CUDA tensors).
To prevent this, ensure forced_expert_ids is moved to hidden_states.device before allocating the empty tensor and performing the collective operation.
| if forced_expert_ids is not None: | |
| full_forced_expert_ids = torch.empty( | |
| local_tokens * self.ep_size, | |
| *forced_expert_ids.shape[1:], | |
| dtype=forced_expert_ids.dtype, | |
| device=forced_expert_ids.device, | |
| ) | |
| dist.all_gather_into_tensor(full_forced_expert_ids, forced_expert_ids.contiguous(), group=self.ep_group) | |
| forced_expert_ids = full_forced_expert_ids | |
| if forced_expert_ids is not None: | |
| forced_expert_ids = forced_expert_ids.to(device=hidden_states.device) | |
| full_forced_expert_ids = torch.empty( | |
| local_tokens * self.ep_size, | |
| *forced_expert_ids.shape[1:], | |
| dtype=forced_expert_ids.dtype, | |
| device=hidden_states.device, | |
| ) | |
| dist.all_gather_into_tensor(full_forced_expert_ids, forced_expert_ids.contiguous(), group=self.ep_group) | |
| forced_expert_ids = full_forced_expert_ids |
| if forced_expert_ids is not None: | ||
| full_forced_expert_ids = torch.empty( | ||
| local_tokens * self.ep_size, | ||
| *forced_expert_ids.shape[1:], | ||
| dtype=forced_expert_ids.dtype, | ||
| device=forced_expert_ids.device, | ||
| ) | ||
| dist.all_gather_into_tensor(full_forced_expert_ids, forced_expert_ids.contiguous(), group=self.ep_group) | ||
| forced_expert_ids = full_forced_expert_ids |
There was a problem hiding this comment.
If forced_expert_ids is on CPU (or a different device than hidden_states), creating full_forced_expert_ids on forced_expert_ids.device and calling dist.all_gather_into_tensor will result in a runtime error (especially when using the NCCL backend, which requires CUDA tensors).
To prevent this, ensure forced_expert_ids is moved to hidden_states.device before allocating the empty tensor and performing the collective operation.
| if forced_expert_ids is not None: | |
| full_forced_expert_ids = torch.empty( | |
| local_tokens * self.ep_size, | |
| *forced_expert_ids.shape[1:], | |
| dtype=forced_expert_ids.dtype, | |
| device=forced_expert_ids.device, | |
| ) | |
| dist.all_gather_into_tensor(full_forced_expert_ids, forced_expert_ids.contiguous(), group=self.ep_group) | |
| forced_expert_ids = full_forced_expert_ids | |
| if forced_expert_ids is not None: | |
| forced_expert_ids = forced_expert_ids.to(device=hidden_states.device) | |
| full_forced_expert_ids = torch.empty( | |
| local_tokens * self.ep_size, | |
| *forced_expert_ids.shape[1:], | |
| dtype=forced_expert_ids.dtype, | |
| device=hidden_states.device, | |
| ) | |
| dist.all_gather_into_tensor(full_forced_expert_ids, forced_expert_ids.contiguous(), group=self.ep_group) | |
| forced_expert_ids = full_forced_expert_ids |
| else: | ||
| expected_shape = (hidden_states.shape[0], self.top_k) | ||
| if tuple(forced_expert_ids.shape) != expected_shape: | ||
| raise ValueError( | ||
| f"forced_expert_ids must have shape {expected_shape}, got {tuple(forced_expert_ids.shape)}." | ||
| ) | ||
| top_k_indices = forced_expert_ids.to(device=hidden_states.device, dtype=torch.int64) | ||
| top_k_logits = torch.gather(gate_logits, dim=-1, index=top_k_indices) |
There was a problem hiding this comment.
It is highly recommended to validate that the indices in forced_expert_ids are within the valid range of [0, num_experts). If an out-of-bounds index is provided, torch.gather will fail or trigger an illegal memory access on GPU, which can crash the entire process.
| else: | |
| expected_shape = (hidden_states.shape[0], self.top_k) | |
| if tuple(forced_expert_ids.shape) != expected_shape: | |
| raise ValueError( | |
| f"forced_expert_ids must have shape {expected_shape}, got {tuple(forced_expert_ids.shape)}." | |
| ) | |
| top_k_indices = forced_expert_ids.to(device=hidden_states.device, dtype=torch.int64) | |
| top_k_logits = torch.gather(gate_logits, dim=-1, index=top_k_indices) | |
| else: | |
| expected_shape = (hidden_states.shape[0], self.top_k) | |
| if tuple(forced_expert_ids.shape) != expected_shape: | |
| raise ValueError( | |
| f"forced_expert_ids must have shape {expected_shape}, got {tuple(forced_expert_ids.shape)}." | |
| ) | |
| if (forced_expert_ids < 0).any() or (forced_expert_ids >= gate_logits.shape[-1]).any(): | |
| raise ValueError( | |
| f"forced_expert_ids must contain values in range [0, {gate_logits.shape[-1] - 1}]." | |
| ) | |
| top_k_indices = forced_expert_ids.to(device=hidden_states.device, dtype=torch.int64) | |
| top_k_logits = torch.gather(gate_logits, dim=-1, index=top_k_indices) |
Claude Code ReviewVerdict: Comment -- Adds an optional SummaryThe PR threads an optional Must fix
SuggestionsSuggestions (3)
Notes
|
No description provided.