基于类似代码导出的 image_decoder，在使用相同的 features、coords 和 labels 时，onnxruntime 能够正确推理，而 MNN 的推理结果不正确，看起来 prompt 没有生效。部分代码如下：
# Prompt preprocessing: a single batch containing two clicks, both labelled 1.
points_list = [[[281, 405], [315, 602]]]
labels_list = [[1, 1]]
onnx_coords = numpy.array(points_list, dtype=numpy.float32)
onnx_labels = numpy.array(labels_list, dtype=numpy.int64)
mnn_coords = MNN.numpy.array(points_list, dtype=MNN.numpy.float32)
# NOTE(review): ONNX Runtime is fed int64 labels but MNN gets int32 here;
# presumably the MNN converter narrows int64 graph inputs to int32 — confirm.
mnn_labels = MNN.numpy.array(labels_list, dtype=MNN.numpy.int32)

# Decoder input names, in the positional order used for both runtimes below.
decoder_input_names = ["point_coords", "point_labels", "image_embed", "high_res_feats_0", "high_res_feats_1"]

# --- ONNX Runtime (reference result) ---
dec_onnx = ort.InferenceSession("models/sam2.1_b_dec.onnx")
onnx_feeds = [onnx_coords, onnx_labels, image_embed_onnx, high_feats0_onnx, high_feats1_onnx]
input_dict = dict(zip(decoder_input_names, onnx_feeds))
onnx_mask = dec_onnx.run(["mask"], input_dict)
# Offset and scale the mask logits so they can be inspected as an image.
cv2.imwrite("onnx_mask.jpg", (numpy.squeeze(onnx_mask) + 30) * 6)

# --- MNN (same features, copied from the ONNX Runtime inputs) ---
image_embed_mnn, high_feats0_mnn, high_feats1_mnn = (
    MNN.numpy.array(feat) for feat in (image_embed_onnx, high_feats0_onnx, high_feats1_onnx)
)
dec_mnn = MNN.nn.load_module_from_file(
    "models/sam2.1_b_dec.mnn",
    decoder_input_names,
    ["mask"],
    runtime_manager=rt,
)
mnn_inputs = [mnn_coords, mnn_labels, image_embed_mnn, high_feats0_mnn, high_feats1_mnn]
dec_outputs = dec_mnn.onForward(mnn_inputs)
mnn_mask = dec_outputs[0]
MNN.cv.imwrite("mnn_mask.jpg", (MNN.numpy.squeeze(mnn_mask) + 30) * 6)
onnx_mask.jpg

mnn_mask.jpg

补充 Decoder 的定义和导出代码：
class Decoder(torch.nn.Module):
    """Export wrapper: SAM2 prompt encoder + mask decoder for point prompts only.

    ``forward`` takes point prompts plus precomputed image features and returns
    the single predicted mask's logits, bilinearly resized to 1024x1024.

    The sparse and dense prompt embeddings are computed here instead of by
    calling ``self.prompt_encoder(...)``.

    NOTE(review): the stock prompt encoder appears to apply label embeddings
    with in-place boolean-mask assignments (``emb[labels == k] += w``). Those
    trace to index_put / ScatterND-style ONNX subgraphs that some converters
    and runtimes (e.g. MNN) may not reproduce faithfully. That would explain
    the prompt being ignored at inference time; confirm by diffing the old
    and new exported graphs.

    The arithmetic-mask formulation below follows ``SamOnnxModel._embed_points``
    from the official segment-anything ONNX export utilities.
    """

    def __init__(self, model: SAM2Model):
        super().__init__()
        self.prompt_encoder = model.sam_prompt_encoder
        self.mask_decoder = model.sam_mask_decoder

    def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor:
        """Return sparse point embeddings with one padding point appended.

        Intended to match the prompt encoder's padded point path, i.e. the one
        taken when ``boxes is None``:
        - label -1 yields ``not_a_point_embed``;
        - label ``i`` adds ``point_embeddings[i]`` to the positional encoding.

        Only elementwise multiply/add with 0/1 masks is used, never in-place
        indexed writes.

        NOTE(review): relies on prompt-encoder internals (``pe_layer._pe_encoding``,
        ``input_image_size``, ``not_a_point_embed``, ``point_embeddings``) — confirm
        against the installed SAM2 version.
        """
        pe = self.prompt_encoder
        coords = point_coords.to(torch.float) + 0.5  # shift to the pixel centre
        batch = coords.shape[0]

        # The padding point/label is appended after the shift; label -1 means "not a point".
        pad_point = torch.zeros((batch, 1, 2), dtype=coords.dtype, device=coords.device)
        pad_label = torch.full((batch, 1), -1.0, dtype=coords.dtype, device=coords.device)
        coords = torch.cat([coords, pad_point], dim=1)
        labels = torch.cat([point_labels.to(coords.dtype), pad_label], dim=1)

        # Normalise x by width and y by height with a broadcast divide (no slice writes).
        img_h, img_w = pe.input_image_size
        scale = torch.tensor([img_w, img_h], dtype=coords.dtype, device=coords.device)
        embedding = pe.pe_layer._pe_encoding(coords / scale)

        labels = labels.unsqueeze(-1)  # broadcast each label over the channel dim
        embedding = embedding * (labels != -1).to(embedding.dtype)
        embedding = embedding + pe.not_a_point_embed.weight * (labels == -1).to(embedding.dtype)
        for idx, point_embed in enumerate(pe.point_embeddings):
            embedding = embedding + point_embed.weight * (labels == idx).to(embedding.dtype)
        return embedding

    @torch.no_grad()
    def forward(
        self,
        point_coords: torch.Tensor,
        point_labels: torch.Tensor,
        image_embed: torch.Tensor,
        high_res_feats_0: torch.Tensor,
        high_res_feats_1: torch.Tensor,
    ):
        """Predict one mask for the given point prompts.

        Args:
            point_coords: ``[B, num_points, 2]`` point prompts; presumably
                (x, y) pixels in the encoder's input frame — confirm.
            point_labels: ``[B, num_points]`` integer labels per point.
            image_embed: image-encoder embedding passed to the mask decoder.
            high_res_feats_0: first high-resolution feature map for the decoder.
            high_res_feats_1: second high-resolution feature map for the decoder.

        Returns:
            Mask logits from ``mask_decoder`` (``multimask_output=False``),
            bilinearly interpolated to ``(1024, 1024)``.
        """
        pe = self.prompt_encoder
        sparse_embeddings = self._embed_points(point_coords, point_labels)
        # No mask prompt: broadcast the learned "no mask" embedding over the embedding grid.
        emb_h, emb_w = pe.image_embedding_size
        dense_embeddings = pe.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
            point_coords.shape[0], -1, emb_h, emb_w
        )
        pred_mask, _, _, _ = self.mask_decoder(
            image_embeddings=image_embed,
            image_pe=pe.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=False,
            repeat_image=False,
            high_res_features=[high_res_feats_0, high_res_feats_1],
        )
        mask = F.interpolate(pred_mask, (1024, 1024), mode="bilinear", align_corners=False)
        return mask
def export_decoder(model, abspath_stem, half: bool, int8: bool):
    """Export ``model`` to ONNX, slim and check it, then convert it to MNN.

    Args:
        model: module traced by ``torch.onnx.export``. It is called with
            (point_coords, point_labels, image_embed, high_res_feats_0,
            high_res_feats_1) and its output is named ``"mask"``.
        abspath_stem: output path prefix; ``"_dec.onnx"`` and ``"_dec.mnn"`` are appended.
        half: forwarded unchanged to ``onnx2mnn``.
        int8: forwarded unchanged to ``onnx2mnn``.

    Returns:
        Path of the written ``.mnn`` file.
    """
    # Inner quotes in the f-string replacement fields must differ from the outer
    # ones: reusing ``"`` inside ``f"..."`` is a SyntaxError before Python 3.12 (PEP 701).
    LOGGER.info(f"\n{colorstr('ONNX:')} starting export with onnx {onnx.__version__} opset {OPSET}...")
    # Dummy tracing inputs. Only num_points (dim 1 of the prompts) is exported as dynamic.
    point_coords = torch.rand((1, 4, 2))  # [1, num_points, 2]
    point_labels = torch.randint(-1, 4, (1, 4))  # [1, num_points], values in [-1, 3]
    image_embed = torch.randn((1, 256, 64, 64))
    high_res_feats_0 = torch.randn((1, 32, 256, 256))
    high_res_feats_1 = torch.randn((1, 64, 128, 128))
    onnx_path = abspath_stem + "_dec.onnx"
    t0 = time.time()
    torch.onnx.export(
        model,
        (point_coords, point_labels, image_embed, high_res_feats_0, high_res_feats_1),
        onnx_path,
        opset_version=OPSET,
        external_data=False,
        input_names=["point_coords", "point_labels", "image_embed", "high_res_feats_0", "high_res_feats_1"],
        output_names=["mask"],
        dynamic_axes={"point_coords": {1: "num_points"}, "point_labels": {1: "num_points"}},
    )
    model_onnx = onnx.load(onnx_path)
    LOGGER.info(f"{colorstr('ONNX:')} slimming with onnxslim {onnxslim.__version__}...")
    model_onnx = onnxslim.slim(model_onnx)
    onnx.checker.check_model(model_onnx)
    onnx.save(model_onnx, onnx_path)
    t1 = time.time()
    mb = file_size(onnx_path)
    assert mb > 0.0, "0.0 MB output model size"
    LOGGER.info(f"{colorstr('ONNX:')} export success ✅ {(t1 - t0):.1f}s, saved as '{onnx_path}' ({mb:.1f} MB)")

    mnn_path = abspath_stem + "_dec.mnn"
    onnx2mnn(onnx_path, mnn_path, half, int8, "biz", colorstr("MNN:"))
    t2 = time.time()
    mb = file_size(mnn_path)
    assert mb > 0.0, "0.0 MB output model size"
    LOGGER.info(f"{colorstr('MNN:')} export success ✅ {(t2 - t1):.1f}s, saved as '{mnn_path}' ({mb:.1f} MB)")
    return mnn_path
基于类似代码导出的 image_decoder，在使用相同的 features、coords 和 labels 时，onnxruntime 能够正确推理，而 MNN 的推理结果不正确，看起来 prompt 没有生效。部分代码如下：
onnx_mask.jpg

mnn_mask.jpg

补充 Decoder 的定义和导出代码：