Skip to content

Commit 1f69af3

Browse files
committed
Vision model onnx conversion working
1 parent b4ea7a3 commit 1f69af3

5 files changed

Lines changed: 247 additions & 91 deletions

File tree

examples/gemma3/qnn/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,15 @@ Requirements:
88
* Python 3.10
99
* uv - Used throughout the setup scripts, please follow the [publically available installation instructions](https://docs.astral.sh/uv/getting-started/installation/#installation-methods)
1010

11-
This repository contains an automated setup script for Linux that can be used to help automate many of the steps listed in the tutorial above:
11+
This repository contains an automated setup script for Linux that can be used to help automate many of the steps listed in the Phi-3.5 tutorial above:
1212

1313
```bash
1414
source env_setup.sh
1515
```
1616

1717
## Optimization Process
1818

19-
Since Gemma-3-4B is a multi-modal model composed of both vision and text components, the strategy for optimizing it through Olive is to operate on the constituent models separately before configuring them to work in concert at the onnxruntime-genai stage.
19+
Since Gemma-3-4B is a multi-modal model composed of both vision and text components, the strategy for optimizing it through Olive is to operate on the constituent models before configuring them to work in concert at the onnxruntime-genai stage.
2020

2121
Thus, the following commands should be used to separately produce context binaries for the text and vision portions of the model, respectively.
2222

examples/gemma3/qnn/custom_gemma3_4b_it_vision.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,33 @@
44
# --------------------------------------------------------------------------
55

66

7+
import logging
8+
79
import torch
810
from transformers import AutoModel
911

12+
logger = logging.getLogger(__name__)
1013

11-
def load_gemma3_model(model_path):
12-
return AutoModel.from_pretrained("google/gemma-3-4b-it")
1314

15+
class Gemma3VisualEmbeddingGenerator(torch.nn.Module):
16+
def __init__(self, full_model):
17+
super().__init__()
18+
# Extract only the vision components
19+
self.vision_tower = full_model.vision_tower
20+
self.multi_modal_projector = full_model.multi_modal_projector
21+
22+
def forward(self, pixel_values):
23+
# Process images through vision tower
24+
image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
25+
selected_image_feature = image_outputs.last_hidden_state
26+
# Project to final embedding space
27+
return self.multi_modal_projector(selected_image_feature)
28+
29+
30+
def load_gemma3_model(model_path):
31+
full_model = AutoModel.from_pretrained("google/gemma-3-4b-it")
32+
logger.info("Loaded full model: %s", full_model)
1433

15-
def get_dummy_inputs(model_handler):
16-
return {
17-
"input_ids": torch.full((1, 256), 262144, dtype=torch.long), # Image token ID
18-
"pixel_values": torch.randn(1, 3, 896, 896, dtype=torch.float32),
19-
"attention_mask": torch.ones((1, 256), dtype=torch.long),
20-
}
34+
vision_model = Gemma3VisualEmbeddingGenerator(full_model)
35+
logger.info("Created vision-only model: %s", vision_model)
36+
return vision_model

examples/gemma3/qnn/env_setup.sh

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
#!/bin/bash
2+
# -------------------------------------------------------------------------
3+
# Copyright (c) Microsoft Corporation. All rights reserved.
4+
# Licensed under the MIT License.
5+
# --------------------------------------------------------------------------
16

27
# Installing setuptools to build Olive from source
38
uv pip install setuptools

examples/gemma3/qnn/gemma3-4b-vision-qnn-config.json

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,11 @@
33
"type": "PyTorchModel",
44
"model_script": "custom_gemma3_4b_it_vision.py",
55
"model_loader": "load_gemma3_model",
6-
"dummy_inputs_func": "get_dummy_inputs",
76
"io_config": {
8-
"input_names": [ "input_ids", "pixel_values", "attention_mask" ],
9-
"input_shapes": [ [ 1, 256 ], [ 1, 3, 896, 896 ], [ 1, 256 ] ],
10-
"input_types": [ "int64", "float32", "int64" ],
11-
"output_names": [ "last_hidden_state" ],
7+
"input_names": [ "pixel_values" ],
8+
"input_shapes": [ [ 1, 3, 896, 896 ] ],
9+
"input_types": [ "float32" ],
10+
"output_names": [ "image_features" ],
1211
"output_shapes": [ [ 1, 256, 2560 ] ]
1312
}
1413
},
@@ -27,16 +26,23 @@
2726
}
2827
],
2928
"passes": {
30-
"conversion": { "type": "OnnxConversion", "target_opset": 17 },
29+
"conversion": { "type": "OnnxConversion", "target_opset": 20 },
30+
"surgery": { "type": "GraphSurgeries", "surgeries": [ { "surgeon": "MatMulAddToGemm" } ] },
3131
"quantization": {
3232
"type": "OnnxStaticQuantization",
3333
"quant_preprocess": true,
3434
"data_config": "gemma_vision_data_config",
35-
"op_types_to_quantize": [ "MatMul", "LayerNormalization", "Gemm", "Sigmoid", "Gelu" ],
3635
"activation_type": "uint16",
3736
"precision": "uint8",
3837
"calibrate_method": "MinMax"
3938
},
39+
"cb": {
40+
"type": "EPContextBinaryGenerator",
41+
"provider_options": {
42+
"htp_graph_finalization_optimization_mode": "3",
43+
"offload_graph_io_quantization": "0"
44+
}
45+
},
4046
"add_metadata": { "type": "AddOliveMetadata", "graph_name": "gemma-3-4b-it-vision" }
4147
},
4248
"target": "qnn_system",

0 commit comments

Comments
 (0)