Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions tests/patch/test_moe_wna16_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# SPDX-License-Identifier: Apache-2.0
# 2026 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights Reserved.
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from types import SimpleNamespace

import torch
from vllm.model_executor.layers.fused_moe.activation import MoEActivation

import vllm_metax.quant_config.moe_wna16 as moe_wna16


def test_moe_wna16_apply_tolerates_missing_disable_inplace(monkeypatch):
    """Check ``apply`` still works when ``moe`` has no ``disable_inplace``.

    ``method.moe`` is an empty namespace, so the attribute is absent. The
    fused-experts callable is replaced by a recorder, and the test expects
    ``apply`` to forward ``inplace=True`` to it.
    """
    seen_kwargs = {}

    def recording_fused_experts(*args, **kwargs):
        # Remember the keyword arguments so they can be inspected below.
        seen_kwargs.update(kwargs)
        return torch.empty(1)

    def provide_recorder():
        return recording_fused_experts

    monkeypatch.setattr(moe_wna16, "get_fused_experts_fn", provide_recorder)

    # Skip __init__: only the attributes assigned here are set on the method.
    method = object.__new__(moe_wna16.MoeWNA16Method)
    method.moe = SimpleNamespace()
    method.moe_quant_config = object()

    layer_attrs = dict(
        activation=MoEActivation.SILU,
        w13_qweight=torch.empty(1),
        w2_qweight=torch.empty(1),
        apply_router_weight_on_input=False,
        global_num_experts=1,
        expert_map=None,
    )
    layer = SimpleNamespace(**layer_attrs)

    # Remaining positional arguments of ``apply``, in call order.
    positional_inputs = (
        torch.empty(1),
        torch.empty(1),
        torch.empty(1, dtype=torch.int64),
        None,
        None,
    )
    result = method.apply(layer, *positional_inputs)

    assert result.shape == (1,)
    assert seen_kwargs["inplace"] is True
31 changes: 31 additions & 0 deletions tests/patch/test_triton_custom_op_schemas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# SPDX-License-Identifier: Apache-2.0
# 2026 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights Reserved.
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


def test_custom_op_schemas_allow_act_quant_fusion_import():
    """Check the op schemas exist after importing ``vllm_metax.compat``.

    Once the compat module is imported, ``torch.ops._C`` must expose
    ``scaled_fp4_quant.out`` and ``silu_and_mul_per_block_quant``, and
    vLLM's ``act_quant_fusion`` module must import with a non-``None``
    ``SILU_MUL_OP``.
    """
    import torch

    import vllm_metax.compat  # noqa: F401

    c_ops = torch.ops._C
    assert hasattr(c_ops.scaled_fp4_quant, "out")
    assert hasattr(c_ops, "silu_and_mul_per_block_quant")

    # ``torch.accelerator`` is only checked on builds that provide it.
    has_accelerator = hasattr(torch, "accelerator")
    assert not has_accelerator or hasattr(torch.accelerator, "empty_cache")

    import vllm.compilation.passes.fusion.act_quant_fusion as act_quant_fusion

    silu_mul_op = act_quant_fusion.SILU_MUL_OP
    assert silu_mul_op is not None


def test_compat_import_tolerates_missing_torch_cuda(monkeypatch):
    """Check ``vllm_metax.compat`` reloads cleanly without ``torch.cuda``.

    The test passes as long as neither reload raises.
    """
    from importlib import reload

    import torch

    import vllm_metax.compat as compat

    with monkeypatch.context() as patched:
        # ``raising=False``: do not fail if ``torch.cuda`` is already absent.
        patched.delattr(torch, "cuda", raising=False)
        reload(compat)

    # Reload again after the patch context has exited.
    # NOTE(review): presumably restores normal module state for later tests.
    reload(compat)
1 change: 1 addition & 0 deletions vllm_metax/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# 2026 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights Reserved.

from . import compat as _compat # noqa: F401
from .version import __version__, __version_tuple__ # noqa: F401


Expand Down
67 changes: 67 additions & 0 deletions vllm_metax/compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# SPDX-License-Identifier: Apache-2.0
# 2026 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights Reserved.
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Early compatibility hooks for vLLM and MetaX runtime mismatches."""

import torch
from torch.library import Library

_FRAGMENT_LIBS: list[Library] = []


def _has_op_overload(name: str, overload_name: str | None = None) -> bool:
if not hasattr(torch.ops, "_C") or not hasattr(torch.ops._C, name):
return False
if overload_name is None:
return True
return hasattr(getattr(torch.ops._C, name), overload_name)


def _define_fragment(schema: str) -> None:
try:
lib = Library("_C", "FRAGMENT")
lib.define(schema)
_FRAGMENT_LIBS.append(lib)
except Exception:
# Another package or prior import may have registered it already.
pass


# Each entry is (op name, overload name or None, schema string). A schema is
# defined only when the op / overload is not already visible on torch.ops._C.
for _op, _overload, _schema in (
    (
        "scaled_fp4_quant",
        "out",
        "scaled_fp4_quant.out("
        "Tensor input, Tensor input_scale, bool is_sf_swizzled_layout, "
        "*, Tensor! output, Tensor! output_scale) -> ()",
    ),
    (
        "silu_and_mul_per_block_quant",
        None,
        "silu_and_mul_per_block_quant("
        "Tensor! output, Tensor input, Tensor! scales, int group_size, "
        "Tensor? scale_ub, bool is_scale_transposed) -> ()",
    ),
):
    if _has_op_overload(_op, _overload):
        continue
    _define_fragment(_schema)
# Drop the loop temporaries so the module namespace gains no extra names.
del _op, _overload, _schema

# Fill gaps in ``torch.accelerator`` with the same-named ``torch.cuda``
# attributes. Nothing happens when either module is missing, and attributes
# that ``torch.accelerator`` already has are never overwritten.
if (
    hasattr(torch, "accelerator")
    and (cuda_module := getattr(torch, "cuda", None)) is not None
):
    for _name in (
        "current_device",
        "device_count",
        "empty_cache",
        "is_available",
        "max_memory_allocated",
        "mem_get_info",
        "memory_allocated",
        "memory_reserved",
        "memory_stats",
        "reset_peak_memory_stats",
        "set_device",
        "synchronize",
    ):
        if hasattr(torch.accelerator, _name):
            continue
        if hasattr(cuda_module, _name):
            setattr(torch.accelerator, _name, getattr(cuda_module, _name))

    # ``current_device_index`` has a different name on the cuda side, so it
    # is aliased to ``cuda_module.current_device`` separately.
    if hasattr(cuda_module, "current_device") and not hasattr(
        torch.accelerator, "current_device_index"
    ):
        setattr(
            torch.accelerator,
            "current_device_index",
            cuda_module.current_device,
        )
1 change: 1 addition & 0 deletions vllm_metax/patch/bugfix/triton_support/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#
# Affected versions: v0.21.0
# -----------------------------------------------
from . import custom_op_schemas
from . import kda
from . import lora
from . import chunk_delta_h
Expand Down
6 changes: 6 additions & 0 deletions vllm_metax/patch/bugfix/triton_support/custom_op_schemas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# 2026 - Modified by MetaX Integrated Circuits (Shanghai) Co., Ltd. All Rights Reserved.
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Compatibility wrapper for legacy triton_support import paths."""

from vllm_metax.compat import * # noqa: F401,F403
3 changes: 2 additions & 1 deletion vllm_metax/quant_config/moe_wna16.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,15 @@ def apply(
)

fused_experts = get_fused_experts_fn()
disable_inplace = getattr(self.moe, "disable_inplace", False)

return fused_experts(
x,
layer.w13_qweight,
layer.w2_qweight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=not self.moe.disable_inplace,
inplace=not disable_inplace,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
Expand Down