Skip to content

sglang实现,疑似在 MTP 解码层的 MLP 后缺少 TP 通信(all-reduce) #43

@cavalier501

Description

@cavalier501

您好,感谢您将 LongCat-Flash / SGLang 相关实现开源,对社区帮助非常大。

在阅读并调试 SGLang 中 LongCat-Flash 的 MTP 实现时,我注意到 MTP 解码层中关于 Tensor Parallel 通信的一处细节,想请您帮忙确认其设计是否符合预期。

现象描述

longcat_flash_nextn.py 中,
LongcatFlashDenseDecoderLayer.forward()(约第 204 行)里:

  • self.mlp 为一个 LongcatFlashMLP,其内部结构由:
    • MergedColumnParallelLinear
    • RowParallelLinear
      组成。

按照常见的 TP 设计方式,在这种 ColumnParallel → RowParallel 的 MLP 结构中,
MLP 计算完成后通常需要一次 all-reduce,以聚合各 TP rank 上的部分输出。

但在当前 MTP 路径中,我没有看到在 MLP 之后进行显式的 all_reduce() 操作。

对比主干模型实现

作为对比,在主模型路径中(longcat_flash.py 约第 483 行):

  • 在 MLP 前向计算之后,存在明确的 TP 输出聚合(all-reduce)逻辑。

因此这里看起来 MTP 分支与主干 decoder 的通信行为并不一致

想请教的问题

  • 这里是否是 MTP 层有意省略了 MLP 后的 all-reduce(例如依赖后续算子或特殊假设)?
  • 还是说这里 可能遗漏了一次 TP 通信操作

如果这是一个已知设计或有特殊考虑,非常感谢您能帮忙说明;
如果确实是遗漏,也希望能对该问题有所帮助。

再次感谢您开源和维护这套代码 🙏

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions