
ONNX model export #17


@dadaligoudan

Hi, I want to convert the STTN PyTorch model to ONNX format for deployment. Here is my code:
import argparse
import importlib

import onnx
import torch

if __name__ == '__main__':
    # Assumed CLI: `args` is defined elsewhere in the original script; only the
    # two options this snippet uses are reproduced here.
    parser = argparse.ArgumentParser(description="Export STTN to ONNX")
    parser.add_argument("--model", type=str, default="sttn")
    parser.add_argument("--ckpt", type=str, required=True)
    args = parser.parse_args()

    device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

    # Load the original model
    net = importlib.import_module('model.' + args.model)
    model = net.InpaintGenerator().to(device)
    model.load_state_dict(torch.load(args.ckpt, map_location=device)['netG'])
    model.eval()

    # Create the wrapper
    # wrapped_model = STTN_Wrapper(model).to(device)
    wrapped_model = model.to(device)

    # Prepare sample inputs that match the real use case
    batch_size = 1
    seq_len = 11  # neighbor_nums
    height, width = 240, 432

    # Important: use the real input data format
    dummy_masked_frames = torch.randn(batch_size, seq_len, 3, height, width, device=device)
    dummy_masks = torch.randint(0, 2, (batch_size, seq_len, 1, height, width),
                                dtype=torch.float32, device=device)
    dummy_masks.requires_grad_(True)

    onnx_path = args.ckpt.replace('.pth', '.onnx')

    # Use more detailed export parameters
    torch.onnx.export(
        wrapped_model,
        (dummy_masked_frames, dummy_masks),
        onnx_path,
        export_params=True,
        opset_version=12,
        do_constant_folding=False,
        input_names=['masked_frames', 'masks'],
        output_names=['output'],
        dynamic_axes=None,
        # dynamic_axes={
        #     'masked_frames': {0: 'batch_size', 1: 'sequence_length'},
        #     'masks': {0: 'batch_size', 1: 'sequence_length'},
        #     'output': {0: 'batch_size'}
        # },
        verbose=True
    )

    # Verify the exported result: list the graph inputs that survived export
    onnx_model = onnx.load(onnx_path)
    print("Exported model inputs:")
    for i, graph_input in enumerate(onnx_model.graph.input):
        print(f"{i}. Name: {graph_input.name}, Type: {graph_input.type}")
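The verification step above only shows which inputs survived the export. A complementary check, run in PyTorch before exporting, is to call the model twice with inputs that differ only in the masks and see whether the output changes at all. If it does not, the masks input is probably not reaching the output, which would be consistent with it disappearing from the exported graph. The sketch below is framework-agnostic and uses toy list-based functions so it runs without a GPU; `output_depends_on`, `uses_mask`, and `ignores_mask` are hypothetical helpers, not part of PyTorch or ONNX.

```python
def output_depends_on(fn, args, index, alternative, same=lambda a, b: a == b):
    """Return True if swapping args[index] for `alternative` changes fn's output.

    `same` decides whether two outputs count as equal; for floating-point
    tensors, pass a tolerance-aware comparison such as torch.allclose.
    """
    baseline = fn(*args)
    perturbed = list(args)
    perturbed[index] = alternative
    return not same(baseline, fn(*perturbed))


# Hypothetical toy stand-ins: one uses its second argument, one ignores it.
def uses_mask(frames, mask):
    return [f * m for f, m in zip(frames, mask)]


def ignores_mask(frames, mask):
    return [f * 2 for f in frames]


frames, mask, other_mask = [1.0, 2.0, 3.0], [1.0, 0.0, 1.0], [0.0, 1.0, 0.0]
print(output_depends_on(uses_mask, (frames, mask), 1, other_mask))     # True
print(output_depends_on(ignores_mask, (frames, mask), 1, other_mask))  # False
```

Applied to the model in the question (an assumed usage, inside `with torch.no_grad():`), one might call `output_depends_on(wrapped_model, (dummy_masked_frames, dummy_masks), 1, 1 - dummy_masks, same=torch.allclose)`. A result of `False` would suggest the forward pass ignores `masks` numerically.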

The model has been exported to ONNX format, but it has only one input.

Could you help me with this? I'm confused about why the 'masks' input goes missing during ONNX export. Thanks a lot.
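One plausible explanation, offered as an assumption rather than a confirmed diagnosis: `torch.onnx.export` builds its graph by tracing the operations that actually run, and an input that no traced operation connects to the output can end up pruned from the exported graph. A common way this happens is calling an out-of-place method such as `Tensor.masked_fill` (which returns a new tensor, unlike the in-place `masked_fill_`) and discarding its result, so the mask is computed on but never used. It may be worth checking the model's attention code for a pattern like this. The toy sketch below models that reachability idea in plain Python; `Node` and `reachable_inputs` are hypothetical names, and this is not how PyTorch's exporter is implemented internally.

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    """One recorded operation (or graph input) in a toy traced graph."""
    name: str
    inputs: list = field(default_factory=list)


def reachable_inputs(output, graph_inputs):
    """Return the graph inputs that `output` transitively depends on.

    A tracing exporter that prunes dead code would keep only these inputs.
    """
    seen = set()
    stack = [output]
    while stack:
        node = stack.pop()
        if id(node) in seen:
            continue
        seen.add(id(node))
        stack.extend(node.inputs)
    return [inp for inp in graph_inputs if id(inp) in seen]


# Toy trace of a forward pass shaped like forward(masked_frames, masks).
masked_frames = Node("masked_frames")
masks = Node("masks")
features = Node("encoder", [masked_frames])
small_masks = Node("interpolate", [masks])
scores = Node("matmul", [features, features])
# The masked scores are computed but never used downstream, as happens when
# the return value of an out-of-place op is discarded.
masked_scores = Node("masked_fill", [scores, small_masks])
output = Node("decoder", [Node("softmax", [scores])])
print([n.name for n in reachable_inputs(output, [masked_frames, masks])])
# ['masked_frames']

# If the masked scores actually feed the softmax, both inputs survive.
fixed_output = Node("decoder", [Node("softmax", [masked_scores])])
print([n.name for n in reachable_inputs(fixed_output, [masked_frames, masks])])
# ['masked_frames', 'masks']
```

In PyTorch terms, the analogous fix would be to assign the result (`scores = scores.masked_fill(m, -1e9)`) or use the in-place variant, so the mask genuinely influences the output.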
