Skip to content

Support router replay#353

Open
zhanzy178 wants to merge 2 commits into
masterfrom
zzy/router_replay
Open

Support router replay#353
zhanzy178 wants to merge 2 commits into
masterfrom
zzy/router_replay

Conversation

@zhanzy178

Copy link
Copy Markdown
Collaborator

No description provided.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces support for forced_expert_ids in the Mixture of Experts (MoE) gating and forward pass implementations, allowing specific expert routing to be forced. It updates the core and backend MoE operators to accept, gather, and apply these forced expert IDs, and adds corresponding accuracy tests. The review feedback highlights critical improvements: ensuring forced_expert_ids is moved to the correct device before performing collective operations to prevent runtime errors, and validating that the forced expert indices are within the valid range to avoid illegal memory access during gathering.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment on lines +90 to +98
if forced_expert_ids is not None:
full_forced_expert_ids = torch.empty(
local_tokens * self.ep_size,
*forced_expert_ids.shape[1:],
dtype=forced_expert_ids.dtype,
device=forced_expert_ids.device,
)
dist.all_gather_into_tensor(full_forced_expert_ids, forced_expert_ids.contiguous(), group=self.ep_group)
forced_expert_ids = full_forced_expert_ids

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

If forced_expert_ids is on CPU (or a different device than hidden_states), creating full_forced_expert_ids on forced_expert_ids.device and calling dist.all_gather_into_tensor will result in a runtime error (especially when using the NCCL backend, which requires CUDA tensors).

To prevent this, ensure forced_expert_ids is moved to hidden_states.device before allocating the empty tensor and performing the collective operation.

Suggested change
if forced_expert_ids is not None:
full_forced_expert_ids = torch.empty(
local_tokens * self.ep_size,
*forced_expert_ids.shape[1:],
dtype=forced_expert_ids.dtype,
device=forced_expert_ids.device,
)
dist.all_gather_into_tensor(full_forced_expert_ids, forced_expert_ids.contiguous(), group=self.ep_group)
forced_expert_ids = full_forced_expert_ids
if forced_expert_ids is not None:
forced_expert_ids = forced_expert_ids.to(device=hidden_states.device)
full_forced_expert_ids = torch.empty(
local_tokens * self.ep_size,
*forced_expert_ids.shape[1:],
dtype=forced_expert_ids.dtype,
device=hidden_states.device,
)
dist.all_gather_into_tensor(full_forced_expert_ids, forced_expert_ids.contiguous(), group=self.ep_group)
forced_expert_ids = full_forced_expert_ids

Comment on lines +249 to +257
if forced_expert_ids is not None:
full_forced_expert_ids = torch.empty(
local_tokens * self.ep_size,
*forced_expert_ids.shape[1:],
dtype=forced_expert_ids.dtype,
device=forced_expert_ids.device,
)
dist.all_gather_into_tensor(full_forced_expert_ids, forced_expert_ids.contiguous(), group=self.ep_group)
forced_expert_ids = full_forced_expert_ids

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

If forced_expert_ids is on CPU (or a different device than hidden_states), creating full_forced_expert_ids on forced_expert_ids.device and calling dist.all_gather_into_tensor will result in a runtime error (especially when using the NCCL backend, which requires CUDA tensors).

To prevent this, ensure forced_expert_ids is moved to hidden_states.device before allocating the empty tensor and performing the collective operation.

Suggested change
if forced_expert_ids is not None:
full_forced_expert_ids = torch.empty(
local_tokens * self.ep_size,
*forced_expert_ids.shape[1:],
dtype=forced_expert_ids.dtype,
device=forced_expert_ids.device,
)
dist.all_gather_into_tensor(full_forced_expert_ids, forced_expert_ids.contiguous(), group=self.ep_group)
forced_expert_ids = full_forced_expert_ids
if forced_expert_ids is not None:
forced_expert_ids = forced_expert_ids.to(device=hidden_states.device)
full_forced_expert_ids = torch.empty(
local_tokens * self.ep_size,
*forced_expert_ids.shape[1:],
dtype=forced_expert_ids.dtype,
device=hidden_states.device,
)
dist.all_gather_into_tensor(full_forced_expert_ids, forced_expert_ids.contiguous(), group=self.ep_group)
forced_expert_ids = full_forced_expert_ids

Comment thread mojo_opset/core/operators/moe.py Outdated
Comment on lines +338 to +345
else:
expected_shape = (hidden_states.shape[0], self.top_k)
if tuple(forced_expert_ids.shape) != expected_shape:
raise ValueError(
f"forced_expert_ids must have shape {expected_shape}, got {tuple(forced_expert_ids.shape)}."
)
top_k_indices = forced_expert_ids.to(device=hidden_states.device, dtype=torch.int64)
top_k_logits = torch.gather(gate_logits, dim=-1, index=top_k_indices)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

It is highly recommended to validate that the indices in forced_expert_ids are within the valid range of [0, num_experts). If an out-of-bounds index is provided, torch.gather will fail or trigger an illegal memory access on GPU, which can crash the entire process.

Suggested change
else:
expected_shape = (hidden_states.shape[0], self.top_k)
if tuple(forced_expert_ids.shape) != expected_shape:
raise ValueError(
f"forced_expert_ids must have shape {expected_shape}, got {tuple(forced_expert_ids.shape)}."
)
top_k_indices = forced_expert_ids.to(device=hidden_states.device, dtype=torch.int64)
top_k_logits = torch.gather(gate_logits, dim=-1, index=top_k_indices)
else:
expected_shape = (hidden_states.shape[0], self.top_k)
if tuple(forced_expert_ids.shape) != expected_shape:
raise ValueError(
f"forced_expert_ids must have shape {expected_shape}, got {tuple(forced_expert_ids.shape)}."
)
if (forced_expert_ids < 0).any() or (forced_expert_ids >= gate_logits.shape[-1]).any():
raise ValueError(
f"forced_expert_ids must contain values in range [0, {gate_logits.shape[-1] - 1}]."
)
top_k_indices = forced_expert_ids.to(device=hidden_states.device, dtype=torch.int64)
top_k_logits = torch.gather(gate_logits, dim=-1, index=top_k_indices)

@github-actions

Copy link
Copy Markdown

Claude Code Review

Verdict: Comment -- Adds an optional forced_expert_ids routing override to MoE gating with appropriate DP all-gather, but the TTX subclass path is incomplete.

Summary

The PR threads an optional forced_expert_ids tensor through MojoMoE/MojoQuantMoE and MojoMoEGating.forward, including all-gather of the forced ids under DP. The TTX gating override falls back to the base class when forced ids are provided, otherwise it uses the existing fused kernel.

Must fix

  • [BLOCKER] TTX fallback bypasses TTX gate_weight dtype assumption -- mojo_opset/backends/ttx/operators/moe.py:37-38 -- When forced_expert_ids is provided, the call goes to super().forward, which asserts gate_weight.dtype == torch.float32 and computes in float; verify this matches how TTX stores gate_weight in production (the test forces .float() with a FIXME). If TTX normally keeps gate_weight in bf16/fp16, this fallback will assert at runtime.

Suggestions

Suggestions (3)
  • [MAJOR] Duplicated DP all-gather block -- mojo_opset/core/operators/moe.py:80-97, 239-255 -- Identical forced-ids gather logic is copy-pasted into MojoMoE.forward and MojoQuantMoE.forward; factor into a helper to avoid drift.
  • [MAJOR] No validation that forced_expert_ids ids are in range -- mojo_opset/core/operators/moe.py:338-345 -- Out-of-range expert ids will silently produce garbage via torch.gather / downstream dispatch; consider an assert (at least in debug) that ids are in [0, num_experts).
  • [MINOR] Renormalizing forced gates may not be desired -- mojo_opset/core/operators/moe.py:347 -- After gather, dividing by sum still renormalizes; if the caller wants to preserve the original softmax mass on chosen experts (e.g. for speculative routing), this silently changes semantics. Worth a docstring note.

Notes

  • [CHECK] mojo_opset/core/operators/moe.py:88-96 -- all_gather_into_tensor requires the output to be contiguous and shape [world_size * local, ...]; with *forced_expert_ids.shape[1:] unpacking into torch.empty, confirm the resulting layout matches what NCCL expects when top_k dim is the only trailing dim (should be fine, but worth a sanity test under EP > 1).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant