Skip to content

SAM2.1在onnxruntime和MNN上的推理结果不同 #4551

Description

@mltloveyy

使用与下方类似的代码导出了 image decoder。输入相同的 features、coords 和 labels 时，onnxruntime 能够正确推理，而 MNN 的推理结果不正确，看起来 prompt 没有生效。部分代码如下：

# prompt preprocess
# Both runtimes receive the same two clicks, each with label 1. The coordinates
# are presumably pixel positions in the decoder's input frame — TODO confirm.
points_list = [[[281, 405], [315, 602]]]
labels_list = [[1, 1]]
onnx_coords, onnx_labels = (
    numpy.array(points_list, dtype=numpy.float32),
    numpy.array(labels_list, dtype=numpy.int64),
)
# NOTE(review): ONNX is fed int64 labels but MNN is fed int32 labels; confirm
# which dtype the converted MNN graph expects for "point_labels".
mnn_coords, mnn_labels = (
    MNN.numpy.array(points_list, dtype=MNN.numpy.float32),
    MNN.numpy.array(labels_list, dtype=MNN.numpy.int32),
)

# Input names shared by the ONNX feed dict and the MNN module loader.
decoder_input_names = ["point_coords", "point_labels", "image_embed", "high_res_feats_0", "high_res_feats_1"]

# onnx
dec_onnx = ort.InferenceSession("models/sam2.1_b_dec.onnx")
input_dict = dict(
    zip(
        decoder_input_names,
        (onnx_coords, onnx_labels, image_embed_onnx, high_feats0_onnx, high_feats1_onnx),
    )
)
onnx_mask = dec_onnx.run(["mask"], input_dict)
# Linear shift/scale, (x + 30) * 6, of the decoder output before saving it as an image.
onnx_preview = (numpy.squeeze(onnx_mask) + 30) * 6
cv2.imwrite("onnx_mask.jpg", onnx_preview)

# mnn
# Reuse the exact feature arrays given to onnxruntime, so only the runtime differs.
image_embed_mnn, high_feats0_mnn, high_feats1_mnn = [
    MNN.numpy.array(feat) for feat in (image_embed_onnx, high_feats0_onnx, high_feats1_onnx)
]
dec_mnn = MNN.nn.load_module_from_file(
    "models/sam2.1_b_dec.mnn",
    decoder_input_names,
    ["mask"],
    runtime_manager=rt,
)
# Inputs are passed positionally, in the same order as the names given to the loader.
dec_outputs = dec_mnn.onForward([mnn_coords, mnn_labels, image_embed_mnn, high_feats0_mnn, high_feats1_mnn])
mnn_mask = dec_outputs[0]
mnn_preview = (MNN.numpy.squeeze(mnn_mask) + 30) * 6
MNN.cv.imwrite("mnn_mask.jpg", mnn_preview)

onnx_mask.jpg
（图片：onnx_mask.jpg，原图未随页面保存）

mnn_mask.jpg
（图片：mnn_mask.jpg，原图未随页面保存）

补充 Decoder 的定义和导出代码：

class Decoder(torch.nn.Module):
    """Bundle SAM2's prompt encoder and mask decoder into one exportable module.

    ``forward`` takes point prompts plus precomputed image features and returns
    a single mask, bilinearly resized to 1024x1024.
    """

    def __init__(self, model: SAM2Model):
        super().__init__()
        # Only the post-image-encoder parts of the model are kept.
        self.prompt_encoder = model.sam_prompt_encoder
        self.mask_decoder = model.sam_mask_decoder

    @torch.no_grad()
    def forward(
        self,
        point_coords: torch.Tensor,
        point_labels: torch.Tensor,
        image_embed: torch.Tensor,
        high_res_feats_0: torch.Tensor,
        high_res_feats_1: torch.Tensor,
    ):
        # Point prompts only: no boxes and no mask prompt.
        # NOTE(review): the reported MNN symptom is that prompts seem ignored while
        # onnxruntime is correct; the (sparse, dense) embeddings produced here are the
        # first intermediate worth comparing between the two runtimes — unverified.
        sparse, dense = self.prompt_encoder(
            points=(point_coords, point_labels),
            boxes=None,
            masks=None,
        )
        positional = self.prompt_encoder.get_dense_pe()
        feature_pyramid = [high_res_feats_0, high_res_feats_1]
        decoded = self.mask_decoder(
            image_embeddings=image_embed,
            image_pe=positional,
            sparse_prompt_embeddings=sparse,
            dense_prompt_embeddings=dense,
            multimask_output=False,
            repeat_image=False,
            high_res_features=feature_pyramid,
        )
        # The mask decoder returns four values; only the first (the mask) is used.
        raw_mask, _, _, _ = decoded
        return F.interpolate(raw_mask, (1024, 1024), mode="bilinear", align_corners=False)

def export_decoder(model, abspath_stem: str, half: bool, int8: bool) -> str:
    """Export the SAM2 point-prompt decoder to ONNX, slim it, then convert it to MNN.

    Args:
        model: Module traced by ``torch.onnx.export``. It is called with
            ``(point_coords, point_labels, image_embed, high_res_feats_0, high_res_feats_1)``
            and its output is exported under the name ``"mask"``.
        abspath_stem: Path prefix; ``"_dec.onnx"`` and ``"_dec.mnn"`` are appended to it.
        half: Forwarded unchanged to ``onnx2mnn``.
        int8: Forwarded unchanged to ``onnx2mnn``.

    Returns:
        The path of the written MNN model (``abspath_stem + "_dec.mnn"``).

    Raises:
        AssertionError: If ``file_size`` reports 0.0 MB for either exported file.
    """
    # The replacement fields use single quotes. Re-using the enclosing double quote
    # inside an f-string is a SyntaxError on Python < 3.12 (PEP 701).
    LOGGER.info(f"\n{colorstr('ONNX:')} starting export with onnx {onnx.__version__} opset {OPSET}...")

    # Dummy tracing inputs: their shapes and dtypes define the exported input signature.
    point_coords = torch.rand((1, 4, 2))  # [1, num_points, 2]
    # torch.randint's upper bound is exclusive, so labels are drawn from {-1, 0, 1, 2, 3}.
    # The default dtype is int64, so the exported "point_labels" input is int64.
    point_labels = torch.randint(-1, 4, (1, 4))  # [1, num_points]
    image_embed = torch.randn((1, 256, 64, 64))
    high_res_feats_0 = torch.randn((1, 32, 256, 256))
    high_res_feats_1 = torch.randn((1, 64, 128, 128))

    onnx_path = abspath_stem + "_dec.onnx"
    t0 = time.time()
    torch.onnx.export(
        model,
        (point_coords, point_labels, image_embed, high_res_feats_0, high_res_feats_1),
        onnx_path,
        opset_version=OPSET,
        external_data=False,
        input_names=["point_coords", "point_labels", "image_embed", "high_res_feats_0", "high_res_feats_1"],
        output_names=["mask"],
        # Only the number of prompt points is dynamic; every other dimension is fixed.
        dynamic_axes={"point_coords": {1: "num_points"}, "point_labels": {1: "num_points"}},
    )
    model_onnx = onnx.load(onnx_path)
    LOGGER.info(f"{colorstr('ONNX:')} slimming with onnxslim {onnxslim.__version__}...")
    model_onnx = onnxslim.slim(model_onnx)
    onnx.checker.check_model(model_onnx)
    onnx.save(model_onnx, onnx_path)
    t1 = time.time()
    mb = file_size(onnx_path)
    assert mb > 0.0, "0.0 MB output model size"
    LOGGER.info(f"{colorstr('ONNX:')} export success ✅ {(t1 - t0):.1f}s, saved as '{onnx_path}' ({mb:.1f} MB)")

    mnn_path = abspath_stem + "_dec.mnn"
    onnx2mnn(onnx_path, mnn_path, half, int8, "biz", colorstr("MNN:"))
    t2 = time.time()
    mb = file_size(mnn_path)
    assert mb > 0.0, "0.0 MB output model size"
    LOGGER.info(f"{colorstr('MNN:')} export success ✅ {(t2 - t1):.1f}s, saved as '{mnn_path}' ({mb:.1f} MB)")

    return mnn_path

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions