Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 67 additions & 2 deletions mojo_opset/backends/ixformer/operators/moe.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import torch
import torch.distributed as dist
from typing import Optional
from typing import Union

from mojo_opset.core import MojoMoE
Expand Down Expand Up @@ -152,7 +153,11 @@ def __del__(self):
if hasattr(self, "gdr_buffer_ptr"):
ixf_f.delete_gdr_buffer(self.gdr_buffer_ptr)

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
def forward(
self,
hidden_states: torch.Tensor,
forced_expert_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if torch.cuda.is_available() and torch.cuda.is_current_stream_capturing():
enable_cuda_graph = True
else:
Expand All @@ -170,6 +175,20 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
)
ixfd.all_gather_into_tensor(full, hidden_states.contiguous(), group=self.ep_group, async_op=True)
hidden_states = full
if forced_expert_ids is not None:
full_forced_expert_ids = torch.empty(
local_tokens * self.ep_size,
*forced_expert_ids.shape[1:],
dtype=forced_expert_ids.dtype,
device=forced_expert_ids.device,
)
ixfd.all_gather_into_tensor(
full_forced_expert_ids,
forced_expert_ids.contiguous(),
group=self.ep_group,
async_op=True,
)
forced_expert_ids = full_forced_expert_ids

# triple_gemm uses 3 bf16 components of the fp32 gate weight to emulate fp32 matmul precision on bf16 HW.
gate_logits = ixf_f.triple_gemm_bf16_bf16_fp32(
Expand All @@ -179,6 +198,20 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
self.gating.gate_weight_bf16_tn_0,
)
top_k_gates, top_k_indices = ixf_f.moe_topk_softmax(gate_logits, self.gating.top_k, renormalize=True)
if forced_expert_ids is not None:
expected_shape = (hidden_states.shape[0], self.top_k)
if tuple(forced_expert_ids.shape) != expected_shape:
raise ValueError(
f"forced_expert_ids must have shape {expected_shape}, got {tuple(forced_expert_ids.shape)}."
)
forced_expert_ids = forced_expert_ids.to(device=hidden_states.device, dtype=torch.int64)
top_k_indices = top_k_indices.to(torch.int64)
forced_expert_mask = forced_expert_ids >= 0
top_k_indices = torch.where(forced_expert_mask, forced_expert_ids, top_k_indices)
gate_probs = torch.softmax(gate_logits, dim=-1)
top_k_gates = torch.gather(gate_probs, dim=-1, index=top_k_indices)
top_k_gates = top_k_gates / torch.sum(top_k_gates, dim=-1, keepdim=True)
top_k_indices = top_k_indices.to(torch.int32)

num_tokens, dim = hidden_states.shape

Expand Down Expand Up @@ -411,7 +444,11 @@ def __del__(self):
if hasattr(self, "gdr_buffer_ptr2"):
ixf_f.delete_gdr_buffer(self.gdr_buffer_ptr2)

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
def forward(
self,
hidden_states: torch.Tensor,
forced_expert_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:

if torch.cuda.is_available() and torch.cuda.is_current_stream_capturing():
enable_cuda_graph = True
Expand All @@ -433,6 +470,20 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
)
ixfd.all_gather_into_tensor(full, hidden_states.contiguous(), group=self.ep_group, async_op=True)
hidden_states = full
if forced_expert_ids is not None:
full_forced_expert_ids = torch.empty(
local_tokens * self.ep_size,
*forced_expert_ids.shape[1:],
dtype=forced_expert_ids.dtype,
device=forced_expert_ids.device,
)
ixfd.all_gather_into_tensor(
full_forced_expert_ids,
forced_expert_ids.contiguous(),
group=self.ep_group,
async_op=True,
)
forced_expert_ids = full_forced_expert_ids

# triple_gemm uses 3 bf16 components of the fp32 gate weight to emulate fp32 matmul precision on bf16 HW.
gate_logits = ixf_f.triple_gemm_bf16_bf16_fp32(
Expand All @@ -442,6 +493,20 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
self.gating.gate_weight_bf16_tn_0,
)
top_k_gates, top_k_indices = ixf_f.moe_topk_softmax(gate_logits, self.gating.top_k, renormalize=True)
if forced_expert_ids is not None:
expected_shape = (hidden_states.shape[0], self.top_k)
if tuple(forced_expert_ids.shape) != expected_shape:
raise ValueError(
f"forced_expert_ids must have shape {expected_shape}, got {tuple(forced_expert_ids.shape)}."
)
forced_expert_ids = forced_expert_ids.to(device=hidden_states.device, dtype=torch.int64)
top_k_indices = top_k_indices.to(torch.int64)
forced_expert_mask = forced_expert_ids >= 0
top_k_indices = torch.where(forced_expert_mask, forced_expert_ids, top_k_indices)
gate_probs = torch.softmax(gate_logits, dim=-1)
top_k_gates = torch.gather(gate_probs, dim=-1, index=top_k_indices)
top_k_gates = top_k_gates / torch.sum(top_k_gates, dim=-1, keepdim=True)
top_k_indices = top_k_indices.to(torch.int32)

num_tokens, dim = hidden_states.shape

Expand Down
5 changes: 5 additions & 0 deletions mojo_opset/backends/ttx/operators/moe.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch
from typing import Optional
from typing import Tuple

from mojo_opset.backends.ttx.kernels import moe_combine
Expand All @@ -21,15 +22,19 @@ class TTXMoEGating(MojoMoEGating):
def forward(
self,
hidden_states: torch.Tensor, # (num_tokens, hidden_size)
forced_expert_ids: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Returns (top_k_indices, top_k_gates).

Args:
hidden_states: (num_tokens, hidden_size), fp16/bf16/fp32.
forced_expert_ids: Optional (num_tokens, top_k) expert ids.
Returns:
top_k_indices: (num_tokens, top_k), int32.
top_k_gates: (num_tokens, top_k), fp32.
"""
if forced_expert_ids is not None:
return super().forward(hidden_states, forced_expert_ids=forced_expert_ids)
assert self.gate_weight.dtype == torch.float32
return moe_gating(hidden_states, self.gate_weight, self.top_k)

Expand Down
38 changes: 34 additions & 4 deletions mojo_opset/core/operators/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def __init__(
**kwargs,
)

def forward(self, hidden_states):
def forward(self, hidden_states, forced_expert_ids: Optional[torch.Tensor] = None):
# hidden_states: [num_tokens, H]
# DP input: gather peer ranks' shards so gating/dispatch see the full token set.
if self.dp_input and self.ep_size > 1:
Expand All @@ -87,8 +87,17 @@ def forward(self, hidden_states):
)
dist.all_gather_into_tensor(full, hidden_states.contiguous(), group=self.ep_group)
hidden_states = full
if forced_expert_ids is not None:
full_forced_expert_ids = torch.empty(
local_tokens * self.ep_size,
*forced_expert_ids.shape[1:],
dtype=forced_expert_ids.dtype,
device=forced_expert_ids.device,
)
dist.all_gather_into_tensor(full_forced_expert_ids, forced_expert_ids.contiguous(), group=self.ep_group)
forced_expert_ids = full_forced_expert_ids
Comment on lines +90 to +98

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

If forced_expert_ids is on CPU (or a different device than hidden_states), creating full_forced_expert_ids on forced_expert_ids.device and calling dist.all_gather_into_tensor will result in a runtime error (especially when using the NCCL backend, which requires CUDA tensors).

To prevent this, ensure forced_expert_ids is moved to hidden_states.device before allocating the empty tensor and performing the collective operation.

Suggested change
if forced_expert_ids is not None:
full_forced_expert_ids = torch.empty(
local_tokens * self.ep_size,
*forced_expert_ids.shape[1:],
dtype=forced_expert_ids.dtype,
device=forced_expert_ids.device,
)
dist.all_gather_into_tensor(full_forced_expert_ids, forced_expert_ids.contiguous(), group=self.ep_group)
forced_expert_ids = full_forced_expert_ids
if forced_expert_ids is not None:
forced_expert_ids = forced_expert_ids.to(device=hidden_states.device)
full_forced_expert_ids = torch.empty(
local_tokens * self.ep_size,
*forced_expert_ids.shape[1:],
dtype=forced_expert_ids.dtype,
device=hidden_states.device,
)
dist.all_gather_into_tensor(full_forced_expert_ids, forced_expert_ids.contiguous(), group=self.ep_group)
forced_expert_ids = full_forced_expert_ids


top_k_indices, top_k_gates = self.gating(hidden_states)
top_k_indices, top_k_gates = self.gating(hidden_states, forced_expert_ids=forced_expert_ids)
# top_k_indices, top_k_gates: [num_tokens, top_k]
sorted_hidden_states, tokens_per_expert, sorted_gates, token_indices = self.dispatch(
hidden_states, top_k_gates, top_k_indices
Expand Down Expand Up @@ -227,7 +236,7 @@ def __init__(
**kwargs,
)

def forward(self, hidden_states):
def forward(self, hidden_states, forced_expert_ids: Optional[torch.Tensor] = None):
# DP input: gather peer ranks' shards so gating/dispatch see the full token set.
if self.dp_input and self.ep_size > 1:
local_tokens = hidden_states.shape[0]
Expand All @@ -237,8 +246,17 @@ def forward(self, hidden_states):
)
dist.all_gather_into_tensor(full, hidden_states.contiguous(), group=self.ep_group)
hidden_states = full
if forced_expert_ids is not None:
full_forced_expert_ids = torch.empty(
local_tokens * self.ep_size,
*forced_expert_ids.shape[1:],
dtype=forced_expert_ids.dtype,
device=forced_expert_ids.device,
)
dist.all_gather_into_tensor(full_forced_expert_ids, forced_expert_ids.contiguous(), group=self.ep_group)
forced_expert_ids = full_forced_expert_ids
Comment on lines +249 to +257

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

If forced_expert_ids is on CPU (or a different device than hidden_states), creating full_forced_expert_ids on forced_expert_ids.device and calling dist.all_gather_into_tensor will result in a runtime error (especially when using the NCCL backend, which requires CUDA tensors).

To prevent this, ensure forced_expert_ids is moved to hidden_states.device before allocating the empty tensor and performing the collective operation.

Suggested change
if forced_expert_ids is not None:
full_forced_expert_ids = torch.empty(
local_tokens * self.ep_size,
*forced_expert_ids.shape[1:],
dtype=forced_expert_ids.dtype,
device=forced_expert_ids.device,
)
dist.all_gather_into_tensor(full_forced_expert_ids, forced_expert_ids.contiguous(), group=self.ep_group)
forced_expert_ids = full_forced_expert_ids
if forced_expert_ids is not None:
forced_expert_ids = forced_expert_ids.to(device=hidden_states.device)
full_forced_expert_ids = torch.empty(
local_tokens * self.ep_size,
*forced_expert_ids.shape[1:],
dtype=forced_expert_ids.dtype,
device=hidden_states.device,
)
dist.all_gather_into_tensor(full_forced_expert_ids, forced_expert_ids.contiguous(), group=self.ep_group)
forced_expert_ids = full_forced_expert_ids


top_k_indices, top_k_gates = self.gating(hidden_states)
top_k_indices, top_k_gates = self.gating(hidden_states, forced_expert_ids=forced_expert_ids)
sorted_hidden_states, tokens_per_expert, sorted_gates, token_indices = self.dispatch(
hidden_states,
top_k_gates,
Expand Down Expand Up @@ -299,12 +317,14 @@ def __init__(
def forward(
self,
hidden_states: torch.Tensor,
forced_expert_ids: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Forward pass for MoE Gating operator.

Input:
- hidden_states (torch.Tensor): Input tensor of shape [num_tokens, hidden_size].
- forced_expert_ids (Optional[torch.Tensor]): Expert ids of shape [num_tokens, top_k].

Output:
- top_k_indices (torch.Tensor): Output tensor of shape [num_tokens, top_k].
Expand All @@ -314,6 +334,16 @@ def forward(
gate_logits = torch.matmul(hidden_states.float(), self.gate_weight)
gate_logits = torch.softmax(gate_logits, dim=-1)
top_k_logits, top_k_indices = torch.topk(gate_logits, self.top_k, dim=-1)
if forced_expert_ids is not None:
expected_shape = (hidden_states.shape[0], self.top_k)
if tuple(forced_expert_ids.shape) != expected_shape:
raise ValueError(
f"forced_expert_ids must have shape {expected_shape}, got {tuple(forced_expert_ids.shape)}."
)
forced_expert_ids = forced_expert_ids.to(device=hidden_states.device, dtype=torch.int64)
forced_expert_mask = forced_expert_ids >= 0
top_k_indices = torch.where(forced_expert_mask, forced_expert_ids, top_k_indices)
top_k_logits = torch.gather(gate_logits, dim=-1, index=top_k_indices)
top_k_gates = top_k_logits / torch.sum(top_k_logits, dim=-1, keepdim=True)
return top_k_indices.to(torch.int32), top_k_gates

Expand Down
144 changes: 144 additions & 0 deletions mojo_opset/tests/accuracy/operators/test_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,54 @@ def test_moe(num_experts, top_k, hidden_size, intermediate_size, num_tokens, dty
moe.forward_diff_with(moe_ref, x, mixed_tol=True)


@pytest.mark.parametrize(
"num_experts, top_k, hidden_size, intermediate_size, num_tokens",
[
(16, 4, 1024, 2048, 64),
(32, 8, 1024, 4096, 128),
],
)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@bypass_not_implemented
def test_moe_forced_expert_ids_matches_topk_route(
num_experts, top_k, hidden_size, intermediate_size, num_tokens, dtype
):
device = get_torch_device()
torch.manual_seed(0)

moe = MojoMoE(
num_experts=num_experts,
top_k=top_k,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
)

moe_ref = MojoMoE._registry.get("torch")(
num_experts=num_experts,
top_k=top_k,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
)

for p in moe_ref.parameters():
nn.init.normal_(p, std=0.02)

moe = moe.to(dtype).to(device)
moe_ref = moe_ref.to(dtype).to(device)
moe.load_state_dict(moe_ref.state_dict())

# FIXME: moe.gating.gate_weight.data should not be casted to float32
moe.gating.gate_weight.data = moe.gating.gate_weight.data.float()
moe_ref.gating.gate_weight.data = moe_ref.gating.gate_weight.data.float()

x = torch.rand(num_tokens, hidden_size, dtype=dtype, device=device)
forced_expert_ids, _ = moe_ref.gating(x)

out = moe(x, forced_expert_ids=forced_expert_ids)
out_ref = moe_ref(x)
torch.testing.assert_close(out, out_ref, atol=1e-2, rtol=1e-2)


@pytest.mark.parametrize(
"num_experts, top_k, hidden_size, num_tokens",
[
Expand Down Expand Up @@ -97,3 +145,99 @@ def test_moe_gating(num_experts, top_k, hidden_size, num_tokens, dtype):
rtol=(0, 1e-2),
ptol=(0.999, 1.0),
)


@pytest.mark.parametrize(
"num_experts, top_k, hidden_size, num_tokens",
[
(16, 4, 1024, 64),
(32, 8, 1024, 128),
],
)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@bypass_not_implemented
def test_moe_gating_forced_expert_ids_matches_topk_route(num_experts, top_k, hidden_size, num_tokens, dtype):
device = get_torch_device()
torch.manual_seed(0)

moe_gating = MojoMoEGating(
hidden_size=hidden_size,
num_experts=num_experts,
top_k=top_k,
)

for p in moe_gating.parameters():
nn.init.normal_(p, std=0.02)

moe_gating = moe_gating.to(device)
assert moe_gating.gate_weight.dtype == torch.float32

x = torch.rand(num_tokens, hidden_size, dtype=dtype, device=device)
top_k_indices, top_k_gates = moe_gating(x)
forced_indices, forced_gates = moe_gating(x, forced_expert_ids=top_k_indices)

torch.testing.assert_close(forced_indices, top_k_indices, atol=0, rtol=0)
torch.testing.assert_close(forced_gates, top_k_gates, atol=1e-5, rtol=1e-5)


@pytest.mark.parametrize("dtype", [torch.bfloat16])
@bypass_not_implemented
def test_moe_gating_forced_expert_ids_gathers_requested_routes(dtype):
device = get_torch_device()
torch.manual_seed(0)

moe_gating = MojoMoEGating(
hidden_size=8,
num_experts=4,
top_k=2,
)

for p in moe_gating.parameters():
nn.init.normal_(p, std=0.02)

moe_gating = moe_gating.to(device)
assert moe_gating.gate_weight.dtype == torch.float32

x = torch.rand(3, 8, dtype=dtype, device=device)
forced_expert_ids = torch.tensor([[0, 1], [2, 3], [3, 0]], dtype=torch.int64, device=device)
forced_indices, forced_gates = moe_gating(x, forced_expert_ids=forced_expert_ids)

gate_probs = torch.softmax(torch.matmul(x.float(), moe_gating.gate_weight), dim=-1)
expected_gates = torch.gather(gate_probs, dim=-1, index=forced_expert_ids)
expected_gates = expected_gates / torch.sum(expected_gates, dim=-1, keepdim=True)

torch.testing.assert_close(forced_indices, forced_expert_ids.to(torch.int32), atol=0, rtol=0)
torch.testing.assert_close(forced_gates, expected_gates, atol=1e-5, rtol=1e-5)


@pytest.mark.parametrize("dtype", [torch.bfloat16])
@bypass_not_implemented
def test_moe_gating_forced_expert_ids_negative_one_keeps_topk_route(dtype):
device = get_torch_device()
torch.manual_seed(0)

moe_gating = MojoMoEGating(
hidden_size=8,
num_experts=4,
top_k=2,
)

for p in moe_gating.parameters():
nn.init.normal_(p, std=0.02)

moe_gating = moe_gating.to(device)
assert moe_gating.gate_weight.dtype == torch.float32

x = torch.rand(3, 8, dtype=dtype, device=device)
top_k_indices, _ = moe_gating(x)
forced_expert_ids = torch.tensor([[0, -1], [-1, 3], [2, -1]], dtype=torch.int64, device=device)
forced_indices, forced_gates = moe_gating(x, forced_expert_ids=forced_expert_ids)

forced_expert_mask = forced_expert_ids >= 0
expected_indices = torch.where(forced_expert_mask, forced_expert_ids, top_k_indices.to(torch.int64))
gate_probs = torch.softmax(torch.matmul(x.float(), moe_gating.gate_weight), dim=-1)
expected_gates = torch.gather(gate_probs, dim=-1, index=expected_indices)
expected_gates = expected_gates / torch.sum(expected_gates, dim=-1, keepdim=True)

torch.testing.assert_close(forced_indices, expected_indices.to(torch.int32), atol=0, rtol=0)
torch.testing.assert_close(forced_gates, expected_gates, atol=1e-5, rtol=1e-5)