Skip to content

Commit bbfe5d6

Browse files
committed
Move Q4_K env-var dispatch into emit_linear/emit_embedding so patterns.py stays unchanged.
1 parent c23c9e4 commit bbfe5d6

4 files changed

Lines changed: 53 additions & 20 deletions

File tree

backends/mlx/custom_kernel_ops/gguf/patterns.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -114,18 +114,9 @@ def __call__(self, P: MLXProgramBuilder, n: Node) -> Slot:
114114
emit_linear,
115115
)
116116
else: # q4_k
117-
from executorch.backends.mlx.custom_kernel_ops.gguf.q4k import (
118-
emit_direct_gguf,
117+
from executorch.backends.mlx.custom_kernel_ops.gguf.q4k.linear import (
118+
emit_linear,
119119
)
120-
121-
if emit_direct_gguf():
122-
from executorch.backends.mlx.custom_kernel_ops.gguf.q4k.linear import (
123-
emit_linear,
124-
)
125-
else:
126-
from executorch.backends.mlx.custom_kernel_ops.gguf.q4k.linear_mlx_native import (
127-
emit_linear,
128-
)
129120
return emit_linear(P, n, x_node, self.weight, bias_node)
130121

131122

@@ -177,8 +168,8 @@ def __call__(self, P: MLXProgramBuilder, n: Node) -> Slot:
177168
from executorch.backends.mlx.custom_kernel_ops.gguf.q4k.embedding import (
178169
emit_embedding,
179170
)
180-
else:
181-
from executorch.backends.mlx.custom_kernel_ops.gguf.q4k.embedding_mlx_native import (
182-
emit_embedding,
183-
)
171+
else: # q4_k
172+
from executorch.backends.mlx.custom_kernel_ops.gguf.q4k.embedding import (
173+
emit_embedding,
174+
)
184175
return emit_embedding(P, n, self.weight, indices_node, self.output_dtype)

backends/mlx/custom_kernel_ops/gguf/q4k/common.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,6 @@
5151
_Q4K_D_BYTES + _Q4K_DMIN_BYTES + _Q4K_SCALES_BYTES + _Q4K_QS_BYTES
5252
) # 144
5353

54-
# Q4_K mat-mat uses NL = QK_K / 32 (8 sub-blocks of 32 elements).
55-
Q4K_NL = QK_K // 32 # 8
5654

5755
# ---------------------------------------------------------------------------
5856
# Shared Metal header

backends/mlx/custom_kernel_ops/gguf/q4k/embedding.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656
"""
5757

5858

59-
def emit_embedding(
59+
def _emit_embedding_fused(
6060
P: MLXProgramBuilder,
6161
head: Node,
6262
weight_node: Node,
@@ -125,3 +125,28 @@ def emit_embedding(
125125
)
126126

127127
return out
128+
129+
130+
131+
def emit_embedding(
132+
P: MLXProgramBuilder,
133+
head: Node,
134+
weight_node: Node,
135+
indices_node: Node,
136+
output_dtype: torch.dtype,
137+
) -> Slot:
138+
"""Dispatch to fused Metal gather or the legacy MLX-native repack path."""
139+
from executorch.backends.mlx.custom_kernel_ops.gguf.q4k import emit_direct_gguf
140+
141+
if emit_direct_gguf():
142+
return _emit_embedding_fused(
143+
P, head, weight_node, indices_node, output_dtype
144+
)
145+
146+
from executorch.backends.mlx.custom_kernel_ops.gguf.q4k.embedding_mlx_native import (
147+
emit_embedding as emit_embedding_mlx_native,
148+
)
149+
150+
return emit_embedding_mlx_native(
151+
P, head, weight_node, indices_node, output_dtype
152+
)

backends/mlx/custom_kernel_ops/gguf/q4k/linear.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ def _q4k_matmul_source(has_bias: bool) -> str:
180180
short il0 = tid % NL0;
181181
short il = il0; // current dequant sub-block index within Q4_K block
182182
183-
const short offset1 = il0 / NL; // always 0 for NL=8, NL0=2
183+
const short offset1 = il0 / NL; // always 0 (il0 < NL0=2, NL=16)
184184
185185
// Pointer to weight block for this thread's assigned row.
186186
device const block_q4_K * wblk = (device const block_q4_K *) weight
@@ -417,7 +417,7 @@ def _emit_q4k_matmul(
417417
)
418418

419419

420-
def emit_linear(
420+
def _emit_linear_fused(
421421
P: MLXProgramBuilder,
422422
head: Node,
423423
x_node: Node,
@@ -513,3 +513,22 @@ def emit_linear(
513513
),
514514
)
515515
return out
516+
517+
def emit_linear(
518+
P: MLXProgramBuilder,
519+
head: Node,
520+
x_node: Node,
521+
weight_node: Node,
522+
bias_node: Optional[Node],
523+
) -> Slot:
524+
"""Dispatch to fused Metal kernels or the legacy MLX-native repack path."""
525+
from executorch.backends.mlx.custom_kernel_ops.gguf.q4k import emit_direct_gguf
526+
527+
if emit_direct_gguf():
528+
return _emit_linear_fused(P, head, x_node, weight_node, bias_node)
529+
530+
from executorch.backends.mlx.custom_kernel_ops.gguf.q4k.linear_mlx_native import (
531+
emit_linear as emit_linear_mlx_native,
532+
)
533+
534+
return emit_linear_mlx_native(P, head, x_node, weight_node, bias_node)

0 commit comments

Comments
 (0)