Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions sonicmoe/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def forward(
is_varlen_K: bool,
activation_type: ActivationType,
is_inference_mode_enabled: bool,
num_sms: int | None = None,
) -> torch.Tensor:
T, H = x.shape
I, H, E = w1.shape
Expand Down Expand Up @@ -171,6 +172,7 @@ def forward(
activation_type=activation_type.value,
is_glu_activation=is_glu_activation,
is_inference_mode_enabled=is_inference_mode_enabled,
num_sms=num_sms,
)

ctx.T = T
Expand All @@ -182,6 +184,7 @@ def forward(
ctx.is_varlen_K = is_varlen_K
ctx.is_glu_activation = is_glu_activation
ctx.stream_id = stream_id
ctx.num_sms = num_sms

ctx.save_for_backward(
x,
Expand Down Expand Up @@ -214,6 +217,7 @@ def backward(ctx, _: None, dz: torch.Tensor):
is_glu_activation = ctx.is_glu_activation
is_varlen_K = ctx.is_varlen_K
stream_id = ctx.stream_id
num_sms = ctx.num_sms

(
x,
Expand Down Expand Up @@ -256,6 +260,7 @@ def backward(ctx, _: None, dz: torch.Tensor):
s_scatter_idx=s_scatter_idx,
is_glu_activation=is_glu_activation,
stream_id=stream_id,
num_sms=num_sms,
)

_up_projection_backward_weight(
Expand All @@ -267,6 +272,7 @@ def backward(ctx, _: None, dz: torch.Tensor):
x_gather_idx=x_gather_idx,
is_glu_activation=is_glu_activation,
stream_id=stream_id,
num_sms=num_sms,
)

dx_reduced = torch.empty(T, H, dtype=dz.dtype, device=dz.device)
Expand All @@ -281,7 +287,7 @@ def backward(ctx, _: None, dz: torch.Tensor):
is_varlen_K=is_varlen_K,
)

return dx_reduced, dw1, db1, *[None] * 12
return dx_reduced, dw1, db1, *[None] * 13


class _DownProjection(torch.autograd.Function):
Expand All @@ -303,6 +309,7 @@ def forward(
num_activated_expert_per_token_offset: torch.Tensor,
is_varlen_K: bool,
activation_type: ActivationType,
num_sms: int | None = None,
) -> torch.Tensor:
TK = y1.size(0)
H, I, E = w2.shape
Expand All @@ -323,6 +330,7 @@ def forward(
expert_schedule_order=None,
x_gather_idx=x_gather_idx,
stream_id=stream_id,
num_sms=num_sms,
)

o = torch.empty(T, H, device=z.device, dtype=z.dtype)
Expand All @@ -344,6 +352,7 @@ def forward(
ctx.is_varlen_K = is_varlen_K
ctx.activation_type = activation_type
ctx.stream_id = stream_id
ctx.num_sms = num_sms

ctx.save_for_backward(
z,
Expand All @@ -365,6 +374,7 @@ def backward(ctx, dout: torch.Tensor):
stream_id = ctx.stream_id
is_varlen_K = ctx.is_varlen_K
activation_type = ctx.activation_type
num_sms = ctx.num_sms

(
z,
Expand Down Expand Up @@ -435,6 +445,7 @@ def backward(ctx, dout: torch.Tensor):
is_glu_activation=is_glu_activation,
activation_type=activation_type.value,
stream_id=stream_id,
num_sms=num_sms,
)

_down_projection_backward_weight(
Expand All @@ -445,13 +456,14 @@ def backward(ctx, dout: torch.Tensor):
expert_schedule_order=None,
x_gather_idx=x_gather_idx,
stream_id=stream_id,
num_sms=num_sms,
)

# TC top-K routing
if not is_varlen_K:
ds = ds.view(T, K)

return None, dz, dw2, db2, ds, *[None] * 10
return None, dz, dw2, db2, ds, *[None] * 11


def moe_TC_softmax_topk_layer(
Expand All @@ -465,6 +477,7 @@ def moe_TC_softmax_topk_layer(
stream_id: int,
activation_type: ActivationType | str = ActivationType.SWIGLU,
is_inference_mode_enabled: bool = False,
num_sms: int | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
assert ((b1 is None) and (b2 is None)) or (
(b1 is not None) and (b2 is not None)
Expand Down Expand Up @@ -501,6 +514,7 @@ def moe_TC_softmax_topk_layer(
False, # is_varlen_K
activation_type,
is_inference_mode_enabled,
num_sms,
)

o = _DownProjection.apply(
Expand All @@ -519,6 +533,7 @@ def moe_TC_softmax_topk_layer(
num_activated_expert_per_token_offset,
False, # is_varlen_K
activation_type,
num_sms,
)

return o, router_logits, expert_frequency
Expand Down Expand Up @@ -546,6 +561,7 @@ def moe_general_routing_inputs(
stream_id: int,
activation_type: ActivationType,
is_inference_mode_enabled: bool = False,
num_sms: int | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
assert ((b1 is None) and (b2 is None)) or (
(b1 is not None) and (b2 is not None)
Expand Down Expand Up @@ -578,6 +594,7 @@ def moe_general_routing_inputs(
True, # is_varlen_K
activation_type,
is_inference_mode_enabled,
num_sms,
)

o = _DownProjection.apply(
Expand All @@ -596,6 +613,7 @@ def moe_general_routing_inputs(
num_activated_expert_per_token_offset,
True, # is_varlen_K
activation_type,
num_sms,
)

return o, expert_frequency
20 changes: 12 additions & 8 deletions sonicmoe/functional/backward.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ def _up_projection_backward_act(
s_scatter_idx: torch.Tensor,
is_glu_activation: bool,
stream_id: int,
num_sms: int | None = None,
) -> None:
I, H, E = w1.size()
if is_glu_activation:
Expand All @@ -226,9 +227,9 @@ def _up_projection_backward_act(
mE_permute_order = convert_torch_tensor_to_cute_tensor(expert_schedule_order, (0,), 0, 4, 1, stream=stream_id)
current_stream = cuda.CUstream(stream_id)

compile_dx_key = ("dx", E, H, I, is_glu_activation, dx_expanded.dtype)
compile_dx_key = ("dx", E, H, I, is_glu_activation, dx_expanded.dtype, num_sms)
if compile_dx_key not in _up_projection_backward_act.compile_cache:
dx_module = HopperWgmma_MoE_Up_proj_ActGrad_Bwd(E, H, I, is_glu_activation)
dx_module = HopperWgmma_MoE_Up_proj_ActGrad_Bwd(E, H, I, is_glu_activation, num_sms=num_sms)
tensormaps = [dx_module.module.generate_tensormap(None, None, None) for _ in range(2)]
_up_projection_backward_act.compile_cache[compile_dx_key] = cute.compile(
dx_module,
Expand Down Expand Up @@ -271,6 +272,7 @@ def _up_projection_backward_weight(
x_gather_idx: torch.Tensor,
is_glu_activation: bool,
stream_id: int,
num_sms: int | None = None,
) -> None:
I, H, E = dw1.size()
if is_glu_activation:
Expand All @@ -291,9 +293,9 @@ def _up_projection_backward_weight(
mE_permute_order = convert_torch_tensor_to_cute_tensor(expert_schedule_order, (0,), 0, 4, 1, stream=stream_id)
current_stream = cuda.CUstream(stream_id)

compile_dw1_key = ("dw1", E, H, I, is_glu_activation, x.dtype)
compile_dw1_key = ("dw1", E, H, I, is_glu_activation, x.dtype, num_sms)
if compile_dw1_key not in _up_projection_backward_weight.compile_cache:
dw1_module = HopperWgmma_MoE_Up_proj_WeightGrad_Bwd(E, H, I, is_glu_activation)
dw1_module = HopperWgmma_MoE_Up_proj_WeightGrad_Bwd(E, H, I, is_glu_activation, num_sms=num_sms)
tensormaps = [dw1_module.module.generate_tensormap(None, None, None) for _ in range(1)]
_up_projection_backward_weight.compile_cache[compile_dw1_key] = cute.compile(
dw1_module,
Expand Down Expand Up @@ -342,6 +344,7 @@ def _down_projection_backward_act(
is_glu_activation: bool,
activation_type: str,
stream_id: int,
num_sms: int | None = None,
) -> None:
H, I, E = w2.size()
TK = x_gather_idx.size(0)
Expand Down Expand Up @@ -376,11 +379,11 @@ def _down_projection_backward_act(
current_stream = cuda.CUstream(stream_id)
ds_partial = None

compile_dz_key = ("dz", E, H, I, z.dtype, activation_type)
compile_dz_key = ("dz", E, H, I, z.dtype, activation_type, num_sms)
if compile_dz_key not in _down_projection_backward_act.compile_cache:
# I don't know why but this sync appears to fix a mysterious initialization bug??
torch.cuda.synchronize()
dz_module = HopperWgmma_MoE_Down_proj_ActGrad_Bwd(E, H, I, ActivationType(activation_type))
dz_module = HopperWgmma_MoE_Down_proj_ActGrad_Bwd(E, H, I, ActivationType(activation_type), num_sms=num_sms)
tensormaps = [dz_module.module.generate_tensormap(None, None, None) for _ in range(3)]

ds_partial_N = max(ceil_divide(I, dz_module.module.tile_shape_mnk[1]), 1)
Expand Down Expand Up @@ -488,6 +491,7 @@ def _down_projection_backward_weight(
expert_schedule_order: torch.Tensor | None,
x_gather_idx: torch.Tensor,
stream_id: int,
num_sms: int | None = None,
) -> None:
H, I, E = dw2.size()

Expand All @@ -503,9 +507,9 @@ def _down_projection_backward_weight(
mE_permute_order = convert_torch_tensor_to_cute_tensor(expert_schedule_order, (0,), 0, 4, 1, stream=stream_id)
current_stream = cuda.CUstream(stream_id)

compile_dw2_key = ("dw2", E, H, I, dw2.dtype)
compile_dw2_key = ("dw2", E, H, I, dw2.dtype, num_sms)
if compile_dw2_key not in _down_projection_backward_weight.compile_cache:
dw2_module = HopperWgmma_MoE_Down_proj_WeightGrad_Bwd(E, H, I)
dw2_module = HopperWgmma_MoE_Down_proj_WeightGrad_Bwd(E, H, I, num_sms=num_sms)
tensormaps = [dw2_module.module.generate_tensormap(None, None, None) for _ in range(1)]
_down_projection_backward_weight.compile_cache[compile_dw2_key] = cute.compile(
dw2_module,
Expand Down
10 changes: 6 additions & 4 deletions sonicmoe/functional/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def _up_projection_forward(
activation_type: str,
is_glu_activation: bool,
is_inference_mode_enabled: bool = False,
num_sms: int | None = None,
) -> None:
I, H, E = w1.size()
if is_glu_activation:
Expand All @@ -88,10 +89,10 @@ def _up_projection_forward(

current_stream = cuda.CUstream(stream_id)

compile_w1_key = (E, H, I, (b1 is None), x.dtype, activation_type, is_inference_mode_enabled)
compile_w1_key = (E, H, I, (b1 is None), x.dtype, activation_type, is_inference_mode_enabled, num_sms)
if compile_w1_key not in _up_projection_forward.compile_cache:
w1_module = HopperWgmma_MoE_Up_proj_Fwd(
E, H, I, activation_type=ActivationType(activation_type), inference_mode=is_inference_mode_enabled
E, H, I, activation_type=ActivationType(activation_type), inference_mode=is_inference_mode_enabled, num_sms=num_sms
)
tensormaps = [w1_module.module.generate_tensormap(None, None, None) for _ in range(2)]
_up_projection_forward.compile_cache[compile_w1_key] = cute.compile(
Expand Down Expand Up @@ -139,6 +140,7 @@ def _down_projection_forward(
expert_schedule_order: torch.Tensor,
x_gather_idx: torch.Tensor,
stream_id: int,
num_sms: int | None = None,
) -> None:
H, I, E = w2.size()

Expand All @@ -160,9 +162,9 @@ def _down_projection_forward(

current_stream = cuda.CUstream(stream_id)

compile_w2_key = (E, H, I, (b2 is None), w2.dtype)
compile_w2_key = (E, H, I, (b2 is None), w2.dtype, num_sms)
if compile_w2_key not in _down_projection_forward.compile_cache:
w2_module = HopperWgmma_MoE_Down_proj_Fwd(E, H, I)
w2_module = HopperWgmma_MoE_Down_proj_Fwd(E, H, I, num_sms=num_sms)
tensormaps = [w2_module.module.generate_tensormap(None, None, None) for _ in range(1)]
_down_projection_forward.compile_cache[compile_w2_key] = cute.compile(
w2_module, mY1, mW2, mY2, mB2, mE_offset, mX_gather, tensormaps[0], mE_permute_order, current_stream
Expand Down
Loading