您好,感谢您将 LongCat-Flash / SGLang 相关实现开源,对社区帮助非常大。
在阅读并调试 SGLang 中 LongCat-Flash 的 MTP 实现时,我注意到 MTP 解码层中关于 Tensor Parallel 通信的一处细节,想请您帮忙确认其设计是否符合预期。
现象描述
在 longcat_flash_nextn.py 中,
LongcatFlashDenseDecoderLayer.forward()(约第 204 行)里:
self.mlp 为一个 LongcatFlashMLP,其内部结构由:
MergedColumnParallelLinear
- 接
RowParallelLinear
组成。
按照常见的 TP 设计方式,在这种 ColumnParallel → RowParallel 的 MLP 结构中,
MLP 计算完成后通常需要一次 all-reduce,以聚合各 TP rank 上的部分输出。
但在当前 MTP 路径中,我没有看到在 MLP 之后进行显式的 all_reduce() 操作。
对比主干模型实现
作为对比,在主模型路径中(longcat_flash.py 约第 483 行):
- 在 MLP 前向计算之后,存在明确的 TP 输出聚合(all-reduce)逻辑。
因此这里看起来 MTP 分支与主干 decoder 的通信行为并不一致。
想请教的问题
- 这里是否是 MTP 层有意省略了 MLP 后的 all-reduce(例如依赖后续算子或特殊假设)?
- 还是说这里 可能遗漏了一次 TP 通信操作?
如果这是一个已知设计或有特殊考虑,非常感谢您能帮忙说明;
如果确实是遗漏,也希望能对该问题有所帮助。
再次感谢您开源和维护这套代码 🙏
您好,感谢您将 LongCat-Flash / SGLang 相关实现开源,对社区帮助非常大。
在阅读并调试 SGLang 中 LongCat-Flash 的 MTP 实现时,我注意到 MTP 解码层中关于 Tensor Parallel 通信的一处细节,想请您帮忙确认其设计是否符合预期。
现象描述
在
longcat_flash_nextn.py中,LongcatFlashDenseDecoderLayer.forward()(约第 204 行)里:self.mlp为一个LongcatFlashMLP,其内部结构由:MergedColumnParallelLinearRowParallelLinear组成。
按照常见的 TP 设计方式,在这种 ColumnParallel → RowParallel 的 MLP 结构中,
MLP 计算完成后通常需要一次 all-reduce,以聚合各 TP rank 上的部分输出。
但在当前 MTP 路径中,我没有看到在 MLP 之后进行显式的
all_reduce()操作。对比主干模型实现
作为对比,在主模型路径中(
longcat_flash.py约第 483 行):因此这里看起来 MTP 分支与主干 decoder 的通信行为并不一致。
想请教的问题
如果这是一个已知设计或有特殊考虑,非常感谢您能帮忙说明;
如果确实是遗漏,也希望能对该问题有所帮助。
再次感谢您开源和维护这套代码 🙏