Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions core/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@
from multiprocessing import Process, Queue as MPQueue
from transformers import AutoConfig
import socket

from schemas import http
from schemas.config import MESConfig
from ultils.dtype_utils import get_torch_dtype, check_dtype_compatibility, dtype_to_string
from core.tokenizer_manager import TokenizerManager
from core.gpu_worker import GPUWorker

# 设置多进程启动方法为 spawn(CUDA 要求)
try:
Expand All @@ -20,17 +21,16 @@
# 如果已经设置过,忽略错误
pass

# 导入新的进程模块
from core.tokenizer_manager import TokenizerManager
from core.gpu_worker import GPUWorker



class Engine:
def __init__(self, model_name, attn_backend="flash_attn", tensor_parallel_size=1, dtype="auto"):
def __init__(self, model_name, attn_backend="flash_attn", tensor_parallel_size=1, dtype="auto", quantization=None):
self._model_name = model_name
self._attn_backend = attn_backend
self._tensor_parallel_size = tensor_parallel_size
self._dtype = dtype
self._quantization = quantization

# 多进程架构
self._prepare_process = None # Prepare 进程
Expand Down Expand Up @@ -79,13 +79,15 @@ def __init__(self, model_name, attn_backend="flash_attn", tensor_parallel_size=1

print(f"[Engine] Using dtype: {self._dtype}")

# 创建 MES 配置(内部会加载量化配置)
mes_config = MESConfig(
attn_backend=self._attn_backend,
model_name=self._model_name,
max_tokens_per_batch=self._max_tokens_per_batch,
enable_monitoring=self._enable_monitoring,
dtype=self._dtype,
model_config=config,
quantization=self._quantization, # 传入用户指定的量化方法
)
self._prepare_process = Process(
target=TokenizerManager,
Expand Down
2 changes: 1 addition & 1 deletion core/gpu_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def signal_handler(signum, frame):
)
self.model.eval()
model_path = snapshot_download(self.model_name)
load_model(self.model, model_path)
load_model(self.model, model_path, self.mes_config.quantization_config)
print(f"[GPUWorker] Model loaded successfully on {self.device} with dtype {dtype_to_string(torch_dtype)}")


Expand Down
106 changes: 87 additions & 19 deletions layers/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
from torch import nn
import torch.nn.functional as F
import torch.distributed as dist
from typing import Optional
from layers.quantization.base_config import QuantizationConfig



def divide(numerator, denominator):
Expand All @@ -10,37 +13,61 @@ def divide(numerator, denominator):


class LinearBase(nn.Module):
"""Linear层基类,支持可选的量化"""

def __init__(
self,
input_size: int,
output_size: int,
tp_dim: int | None = None
tp_dim: int | None = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.input_size = input_size
self.output_size = output_size
self.tp_dim = tp_dim
self.tp_rank = dist.get_rank()
self.tp_size = dist.get_world_size()
self.quant_config = quant_config

# 根据配置初始化量化方法
if quant_config is not None:
self.quant_method = quant_config.get_quant_method(self)
else:
self.quant_method = None

def forward(self, x: torch.Tensor) -> torch.Tensor:
raise NotImplementedError


class ReplicatedLinear(LinearBase):
"""副本Linear层(无张量并行)"""

def __init__(
self,
input_size: int,
output_size: int,
bias: bool = False,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__(input_size, output_size)
super().__init__(input_size, output_size, quant_config=quant_config)
self.input_size = input_size
self.output_size = output_size
self.weight = nn.Parameter(torch.empty(self.output_size, self.input_size))
self.weight.weight_loader = self.weight_loader

# 根据量化配置创建参数
if self.quant_method is not None:
self.quant_method.create_weights(
layer=self,
input_size_per_partition=input_size,
output_partition_sizes=[output_size],
input_size=input_size,
output_size=output_size,
params_dtype=torch.float16,
)
else:
self.weight = nn.Parameter(torch.empty(self.output_size, self.input_size))
self.weight.weight_loader = self.weight_loader

Comment on lines +57 to +70

Copilot AI Feb 10, 2026

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

量化分支下调用 create_weights() 只注册了 qweight/qzeros/scales,但没有像非量化 weight 一样给这些 Parameter 绑定 weight_loader。这样在 load_model 里会走 default_weight_loader 直接 copy_(TP>1 时会因为缺少分片逻辑导致 shape mismatch),并且在 packed_modules_mapping 分支会因为缺少 weight_loader 直接异常。建议:为 qweight/qzeros/scales 分别实现并绑定专用 weight_loader(按 Column/Row 并行正确在 output/input 维度上分片),并确保 packed QKV / gate_up 的 shard_id 路径也能正确加载。

Copilot uses AI. Check for mistakes.
if bias:
self.bias = nn.Parameter(torch.empty(self.output_size))
self.bias.weight_loader = self.weight_loader
Expand All @@ -51,24 +78,41 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
param.data.copy_(loaded_weight)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return F.linear(x, self.weight, self.bias)
if self.quant_method is not None:
return self.quant_method.apply(self, x, self.bias)
else:
return F.linear(x, self.weight, self.bias)


class ColumnParallelLinear(LinearBase):
"""列并行Linear层,输出维度按TP切分"""

def __init__(
self,
input_size: int,
output_size: int,
bias: bool = False,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__(input_size, output_size, 0)
super().__init__(input_size, output_size, 0, quant_config=quant_config)
self.input_size_per_partition = input_size
self.output_size_per_partition = divide(output_size, self.tp_size)
self.weight = nn.Parameter(
torch.empty(self.output_size_per_partition, self.input_size)
)
self.weight.weight_loader = self.weight_loader

if self.quant_method is not None:
self.quant_method.create_weights(
layer=self,
input_size_per_partition=self.input_size_per_partition,
output_partition_sizes=[self.output_size_per_partition],
input_size=input_size,
output_size=output_size,
params_dtype=torch.float16,
)
else:
self.weight = nn.Parameter(
torch.empty(self.output_size_per_partition, self.input_size)
)
self.weight.weight_loader = self.weight_loader

if bias:
self.bias = nn.Parameter(torch.empty(self.output_size_per_partition))
self.bias.weight_loader = self.weight_loader
Expand All @@ -83,20 +127,25 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
param.data.copy_(loaded_weight)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return F.linear(x, self.weight, self.bias)
if self.quant_method is not None:
return self.quant_method.apply(self, x, self.bias)
else:
return F.linear(x, self.weight, self.bias)



class MergedColumnParallelLinear(ColumnParallelLinear):
"""合并的列并行Linear层(如gate_up_proj)"""

def __init__(
self,
input_size: int,
output_sizes: list[int],
bias: bool = False,
quant_config: Optional[QuantizationConfig] = None,
):
self.output_sizes = output_sizes
super().__init__(input_size, sum(output_sizes), bias=bias)
super().__init__(input_size, sum(output_sizes), bias=bias, quant_config=quant_config)

def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: int):
param_data = param.data
Expand All @@ -108,6 +157,7 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded


class QKVParallelLinear(ColumnParallelLinear):
"""QKV并行Linear层,融合Q/K/V投影"""

def __init__(
self,
Expand All @@ -116,6 +166,7 @@ def __init__(
total_num_heads: int,
total_num_kv_heads: int | None = None,
bias: bool = False,
quant_config: Optional[QuantizationConfig] = None,
):
self.head_size = head_size
self.total_num_heads = total_num_heads
Expand All @@ -125,7 +176,7 @@ def __init__(
self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
input_size = hidden_size
output_size = (self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_size
super().__init__(input_size, output_size, bias)
super().__init__(input_size, output_size, bias, quant_config=quant_config)

def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: str):
param_data = param.data
Expand All @@ -145,21 +196,34 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded


class RowParallelLinear(LinearBase):
"""行并行Linear层,输入维度按TP切分,需要all_reduce"""

def __init__(
self,
input_size: int,
output_size: int,
bias: bool = False,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__(input_size, output_size, 1)
super().__init__(input_size, output_size, 1, quant_config=quant_config)
self.input_size_per_partition = divide(input_size, self.tp_size)
self.output_size_per_partition = output_size

self.weight = nn.Parameter(
torch.empty(self.output_size, self.input_size_per_partition)
)
self.weight.weight_loader = self.weight_loader
if self.quant_method is not None:
self.quant_method.create_weights(
layer=self,
input_size_per_partition=self.input_size_per_partition,
output_partition_sizes=[self.output_size_per_partition],
input_size=input_size,
output_size=output_size,
params_dtype=torch.float16,
)
else:
self.weight = nn.Parameter(
torch.empty(self.output_size, self.input_size_per_partition)
)
self.weight.weight_loader = self.weight_loader

if bias:
self.bias = nn.Parameter(torch.empty(self.output_size))
self.bias.weight_loader = self.weight_loader
Expand All @@ -174,7 +238,11 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
param_data.copy_(loaded_weight)

def forward(self, x: torch.Tensor) -> torch.Tensor:
y = F.linear(x, self.weight, self.bias if self.tp_rank == 0 else None)
if self.quant_method is not None:
y = self.quant_method.apply(self, x, self.bias if self.tp_rank == 0 else None)
else:
y = F.linear(x, self.weight, self.bias if self.tp_rank == 0 else None)

if self.tp_size > 1:
dist.all_reduce(y)
return y
16 changes: 16 additions & 0 deletions layers/quantization/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
"""Quantization layers package"""
from layers.quantization.base_config import (
QuantizationConfig,
LinearMethodBase,
)
from layers.quantization.awq_marlin import (
AWQMarlinConfig,
AWQMarlinLinearMethod,
)

__all__ = [
'QuantizationConfig',
'LinearMethodBase',
'AWQMarlinConfig',
'AWQMarlinLinearMethod',
]
Loading
Loading