Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions mojo_opset/backends/ttx/operators/moe.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch
from typing import Optional
from typing import Tuple

from mojo_opset.backends.ttx.kernels import moe_combine
Expand All @@ -21,15 +22,19 @@ class TTXMoEGating(MojoMoEGating):
def forward(
self,
hidden_states: torch.Tensor, # (num_tokens, hidden_size)
forced_expert_ids: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Returns (top_k_indices, top_k_gates).

Args:
hidden_states: (num_tokens, hidden_size), fp16/bf16/fp32.
forced_expert_ids: Optional (num_tokens, top_k) expert ids.
Returns:
top_k_indices: (num_tokens, top_k), int32.
top_k_gates: (num_tokens, top_k), fp32.
"""
if forced_expert_ids is not None:
return super().forward(hidden_states, forced_expert_ids=forced_expert_ids)
assert self.gate_weight.dtype == torch.float32
return moe_gating(hidden_states, self.gate_weight, self.top_k)

Expand Down
39 changes: 34 additions & 5 deletions mojo_opset/core/operators/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def __init__(
**kwargs,
)

def forward(self, hidden_states):
def forward(self, hidden_states, forced_expert_ids: Optional[torch.Tensor] = None):
# hidden_states: [num_tokens, H]
# DP input: gather peer ranks' shards so gating/dispatch see the full token set.
if self.dp_input and self.ep_size > 1:
Expand All @@ -87,8 +87,17 @@ def forward(self, hidden_states):
)
dist.all_gather_into_tensor(full, hidden_states.contiguous(), group=self.ep_group)
hidden_states = full
if forced_expert_ids is not None:
full_forced_expert_ids = torch.empty(
local_tokens * self.ep_size,
*forced_expert_ids.shape[1:],
dtype=forced_expert_ids.dtype,
device=forced_expert_ids.device,
)
dist.all_gather_into_tensor(full_forced_expert_ids, forced_expert_ids.contiguous(), group=self.ep_group)
forced_expert_ids = full_forced_expert_ids
Comment on lines +90 to +98

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

If forced_expert_ids is on CPU (or a different device than hidden_states), creating full_forced_expert_ids on forced_expert_ids.device and calling dist.all_gather_into_tensor will result in a runtime error (especially when using the NCCL backend, which requires CUDA tensors).

To prevent this, ensure forced_expert_ids is moved to hidden_states.device before allocating the empty tensor and performing the collective operation.

Suggested change
if forced_expert_ids is not None:
full_forced_expert_ids = torch.empty(
local_tokens * self.ep_size,
*forced_expert_ids.shape[1:],
dtype=forced_expert_ids.dtype,
device=forced_expert_ids.device,
)
dist.all_gather_into_tensor(full_forced_expert_ids, forced_expert_ids.contiguous(), group=self.ep_group)
forced_expert_ids = full_forced_expert_ids
if forced_expert_ids is not None:
forced_expert_ids = forced_expert_ids.to(device=hidden_states.device)
full_forced_expert_ids = torch.empty(
local_tokens * self.ep_size,
*forced_expert_ids.shape[1:],
dtype=forced_expert_ids.dtype,
device=hidden_states.device,
)
dist.all_gather_into_tensor(full_forced_expert_ids, forced_expert_ids.contiguous(), group=self.ep_group)
forced_expert_ids = full_forced_expert_ids


top_k_indices, top_k_gates = self.gating(hidden_states)
top_k_indices, top_k_gates = self.gating(hidden_states, forced_expert_ids=forced_expert_ids)
# top_k_indices, top_k_gates: [num_tokens, top_k]
sorted_hidden_states, tokens_per_expert, sorted_gates, token_indices = self.dispatch(
hidden_states, top_k_gates, top_k_indices
Expand Down Expand Up @@ -227,7 +236,7 @@ def __init__(
**kwargs,
)

def forward(self, hidden_states):
def forward(self, hidden_states, forced_expert_ids: Optional[torch.Tensor] = None):
# DP input: gather peer ranks' shards so gating/dispatch see the full token set.
if self.dp_input and self.ep_size > 1:
local_tokens = hidden_states.shape[0]
Expand All @@ -237,8 +246,17 @@ def forward(self, hidden_states):
)
dist.all_gather_into_tensor(full, hidden_states.contiguous(), group=self.ep_group)
hidden_states = full
if forced_expert_ids is not None:
full_forced_expert_ids = torch.empty(
local_tokens * self.ep_size,
*forced_expert_ids.shape[1:],
dtype=forced_expert_ids.dtype,
device=forced_expert_ids.device,
)
dist.all_gather_into_tensor(full_forced_expert_ids, forced_expert_ids.contiguous(), group=self.ep_group)
forced_expert_ids = full_forced_expert_ids
Comment on lines +249 to +257

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

If forced_expert_ids is on CPU (or a different device than hidden_states), creating full_forced_expert_ids on forced_expert_ids.device and calling dist.all_gather_into_tensor will result in a runtime error (especially when using the NCCL backend, which requires CUDA tensors).

To prevent this, ensure forced_expert_ids is moved to hidden_states.device before allocating the empty tensor and performing the collective operation.

Suggested change
if forced_expert_ids is not None:
full_forced_expert_ids = torch.empty(
local_tokens * self.ep_size,
*forced_expert_ids.shape[1:],
dtype=forced_expert_ids.dtype,
device=forced_expert_ids.device,
)
dist.all_gather_into_tensor(full_forced_expert_ids, forced_expert_ids.contiguous(), group=self.ep_group)
forced_expert_ids = full_forced_expert_ids
if forced_expert_ids is not None:
forced_expert_ids = forced_expert_ids.to(device=hidden_states.device)
full_forced_expert_ids = torch.empty(
local_tokens * self.ep_size,
*forced_expert_ids.shape[1:],
dtype=forced_expert_ids.dtype,
device=hidden_states.device,
)
dist.all_gather_into_tensor(full_forced_expert_ids, forced_expert_ids.contiguous(), group=self.ep_group)
forced_expert_ids = full_forced_expert_ids


top_k_indices, top_k_gates = self.gating(hidden_states)
top_k_indices, top_k_gates = self.gating(hidden_states, forced_expert_ids=forced_expert_ids)
sorted_hidden_states, tokens_per_expert, sorted_gates, token_indices = self.dispatch(
hidden_states,
top_k_gates,
Expand Down Expand Up @@ -299,12 +317,14 @@ def __init__(
def forward(
self,
hidden_states: torch.Tensor,
forced_expert_ids: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Forward pass for MoE Gating operator.

Input:
- hidden_states (torch.Tensor): Input tensor of shape [num_tokens, hidden_size].
- forced_expert_ids (Optional[torch.Tensor]): Expert ids of shape [num_tokens, top_k].

Output:
- top_k_indices (torch.Tensor): Output tensor of shape [num_tokens, top_k].
Expand All @@ -313,7 +333,16 @@ def forward(
assert self.gate_weight.dtype == torch.float32
gate_logits = torch.matmul(hidden_states.float(), self.gate_weight)
gate_logits = torch.softmax(gate_logits, dim=-1)
top_k_logits, top_k_indices = torch.topk(gate_logits, self.top_k, dim=-1)
if forced_expert_ids is None:
top_k_logits, top_k_indices = torch.topk(gate_logits, self.top_k, dim=-1)
else:
expected_shape = (hidden_states.shape[0], self.top_k)
if tuple(forced_expert_ids.shape) != expected_shape:
raise ValueError(
f"forced_expert_ids must have shape {expected_shape}, got {tuple(forced_expert_ids.shape)}."
)
top_k_indices = forced_expert_ids.to(device=hidden_states.device, dtype=torch.int64)
top_k_logits = torch.gather(gate_logits, dim=-1, index=top_k_indices)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

It is highly recommended to validate that the indices in forced_expert_ids are within the valid range of [0, num_experts). If an out-of-bounds index is provided, torch.gather will fail or trigger an illegal memory access on GPU, which can crash the entire process.

Suggested change
else:
expected_shape = (hidden_states.shape[0], self.top_k)
if tuple(forced_expert_ids.shape) != expected_shape:
raise ValueError(
f"forced_expert_ids must have shape {expected_shape}, got {tuple(forced_expert_ids.shape)}."
)
top_k_indices = forced_expert_ids.to(device=hidden_states.device, dtype=torch.int64)
top_k_logits = torch.gather(gate_logits, dim=-1, index=top_k_indices)
else:
expected_shape = (hidden_states.shape[0], self.top_k)
if tuple(forced_expert_ids.shape) != expected_shape:
raise ValueError(
f"forced_expert_ids must have shape {expected_shape}, got {tuple(forced_expert_ids.shape)}."
)
if (forced_expert_ids < 0).any() or (forced_expert_ids >= gate_logits.shape[-1]).any():
raise ValueError(
f"forced_expert_ids must contain values in range [0, {gate_logits.shape[-1] - 1}]."
)
top_k_indices = forced_expert_ids.to(device=hidden_states.device, dtype=torch.int64)
top_k_logits = torch.gather(gate_logits, dim=-1, index=top_k_indices)

top_k_gates = top_k_logits / torch.sum(top_k_logits, dim=-1, keepdim=True)
return top_k_indices.to(torch.int32), top_k_gates

Expand Down
111 changes: 111 additions & 0 deletions mojo_opset/tests/accuracy/operators/test_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,54 @@ def test_moe(num_experts, top_k, hidden_size, intermediate_size, num_tokens, dty
moe.forward_diff_with(moe_ref, x, mixed_tol=True)


@pytest.mark.parametrize(
"num_experts, top_k, hidden_size, intermediate_size, num_tokens",
[
(16, 4, 1024, 2048, 64),
(32, 8, 1024, 4096, 128),
],
)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@bypass_not_implemented
def test_moe_forced_expert_ids_matches_topk_route(
num_experts, top_k, hidden_size, intermediate_size, num_tokens, dtype
):
device = get_torch_device()
torch.manual_seed(0)

moe = MojoMoE(
num_experts=num_experts,
top_k=top_k,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
)

moe_ref = MojoMoE._registry.get("torch")(
num_experts=num_experts,
top_k=top_k,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
)

for p in moe_ref.parameters():
nn.init.normal_(p, std=0.02)

moe = moe.to(dtype).to(device)
moe_ref = moe_ref.to(dtype).to(device)
moe.load_state_dict(moe_ref.state_dict())

# FIXME: moe.gating.gate_weight.data should not be casted to float32
moe.gating.gate_weight.data = moe.gating.gate_weight.data.float()
moe_ref.gating.gate_weight.data = moe_ref.gating.gate_weight.data.float()

x = torch.rand(num_tokens, hidden_size, dtype=dtype, device=device)
forced_expert_ids, _ = moe_ref.gating(x)

out = moe(x, forced_expert_ids=forced_expert_ids)
out_ref = moe_ref(x)
torch.testing.assert_close(out, out_ref, atol=1e-2, rtol=1e-2)


@pytest.mark.parametrize(
"num_experts, top_k, hidden_size, num_tokens",
[
Expand Down Expand Up @@ -97,3 +145,66 @@ def test_moe_gating(num_experts, top_k, hidden_size, num_tokens, dtype):
rtol=(0, 1e-2),
ptol=(0.999, 1.0),
)


@pytest.mark.parametrize(
"num_experts, top_k, hidden_size, num_tokens",
[
(16, 4, 1024, 64),
(32, 8, 1024, 128),
],
)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@bypass_not_implemented
def test_moe_gating_forced_expert_ids_matches_topk_route(num_experts, top_k, hidden_size, num_tokens, dtype):
device = get_torch_device()
torch.manual_seed(0)

moe_gating = MojoMoEGating(
hidden_size=hidden_size,
num_experts=num_experts,
top_k=top_k,
)

for p in moe_gating.parameters():
nn.init.normal_(p, std=0.02)

moe_gating = moe_gating.to(device)
assert moe_gating.gate_weight.dtype == torch.float32

x = torch.rand(num_tokens, hidden_size, dtype=dtype, device=device)
top_k_indices, top_k_gates = moe_gating(x)
forced_indices, forced_gates = moe_gating(x, forced_expert_ids=top_k_indices)

torch.testing.assert_close(forced_indices, top_k_indices, atol=0, rtol=0)
torch.testing.assert_close(forced_gates, top_k_gates, atol=1e-5, rtol=1e-5)


@pytest.mark.parametrize("dtype", [torch.bfloat16])
@bypass_not_implemented
def test_moe_gating_forced_expert_ids_gathers_requested_routes(dtype):
device = get_torch_device()
torch.manual_seed(0)

moe_gating = MojoMoEGating(
hidden_size=8,
num_experts=4,
top_k=2,
)

for p in moe_gating.parameters():
nn.init.normal_(p, std=0.02)

moe_gating = moe_gating.to(device)
assert moe_gating.gate_weight.dtype == torch.float32

x = torch.rand(3, 8, dtype=dtype, device=device)
forced_expert_ids = torch.tensor([[0, 1], [2, 3], [3, 0]], dtype=torch.int64, device=device)
forced_indices, forced_gates = moe_gating(x, forced_expert_ids=forced_expert_ids)

gate_probs = torch.softmax(torch.matmul(x.float(), moe_gating.gate_weight), dim=-1)
expected_gates = torch.gather(gate_probs, dim=-1, index=forced_expert_ids)
expected_gates = expected_gates / torch.sum(expected_gates, dim=-1, keepdim=True)

torch.testing.assert_close(forced_indices, forced_expert_ids.to(torch.int32), atol=0, rtol=0)
torch.testing.assert_close(forced_gates, expected_gates, atol=1e-5, rtol=1e-5)
Loading