@@ -56,6 +56,41 @@ def prepare_dinov3_inputs(
5656 return nchw_images
5757
5858
class MnistDinoV3ConvStem(nnx.Module):
    """Light conv stem that gives the DINO backbone a stronger MNIST front-end.

    The stem is two stages of ``conv -> LayerNorm -> GELU``. Both convolutions
    use a 3x3 kernel, ``"SAME"`` padding and no bias. The first maps 1 input
    feature to ``hidden_dim`` features; the second maps ``hidden_dim`` features
    to 3. Each LayerNorm is sized to the feature count of the convolution that
    precedes it, and GELU is evaluated with ``approximate=False``.

    Args:
        hidden_dim: Number of features produced by the first convolution.
        rngs: Provides the ``params`` key. A single key is drawn and split into
            four sub-keys, assigned in order to conv1, norm1, conv2 and norm2.
    """

    def __init__(self, *, hidden_dim: int, rngs: nnx.Rngs):
        # One draw from the params stream, fanned out so every layer gets its
        # own independent key (order: conv1, norm1, conv2, norm2).
        key_conv1, key_norm1, key_conv2, key_norm2 = jax.random.split(
            rngs.params(), 4
        )

        def make_conv(n_in: int, n_out: int, key) -> nnx.Conv:
            # Both stem convolutions share everything except their widths.
            return nnx.Conv(
                in_features=n_in,
                out_features=n_out,
                kernel_size=(3, 3),
                padding="SAME",
                use_bias=False,
                rngs=nnx.Rngs(key),
            )

        self.conv1 = make_conv(1, hidden_dim, key_conv1)
        self.norm1 = nnx.LayerNorm(hidden_dim, rngs=nnx.Rngs(key_norm1))
        self.conv2 = make_conv(hidden_dim, 3, key_conv2)
        self.norm2 = nnx.LayerNorm(3, rngs=nnx.Rngs(key_norm2))

    def __call__(self, images: jax.Array) -> jax.Array:
        """Apply both conv -> norm -> GELU stages and return the result.

        NOTE(review): conv1 is built with ``in_features=1``, so ``images`` is
        presumably a single-channel batch — confirm the layout with the caller.
        """
        activations = images
        stages = ((self.conv1, self.norm1), (self.conv2, self.norm2))
        for conv, norm in stages:
            activations = nnx.gelu(norm(conv(activations)), approximate=False)
        return activations
92+
93+
5994class MnistDinoV3Classifier (nnx .Module ):
6095 """Small MNIST classifier using the DINOv3 ViT backbone from jax2onnx."""
6196
@@ -72,16 +107,23 @@ def __init__(
72107 head_hidden_dim : int = 192 ,
73108 head_dropout_rate : float = 0.1 ,
74109 pool_features : str = "cls_mean" ,
110+ use_conv_stem : bool = True ,
111+ stem_hidden_dim : int = 32 ,
75112 rngs : nnx .Rngs ,
76113 ):
77114 params_key = rngs .params ()
78- backbone_key , head_key = jax .random .split (params_key )
115+ stem_key , backbone_key , head_key = jax .random .split (params_key , 3 )
79116 head_norm_key , head_hidden_key , head_dropout_key , head_out_key = jax .random .split (
80117 head_key , 4
81118 )
82119
83120 self .img_size = int (img_size )
84121 self .pool_features = pool_features
122+ self .input_stem = (
123+ MnistDinoV3ConvStem (hidden_dim = stem_hidden_dim , rngs = nnx .Rngs (stem_key ))
124+ if use_conv_stem
125+ else None
126+ )
85127 self .backbone = DinoVisionTransformer (
86128 img_size = img_size ,
87129 patch_size = patch_size ,
@@ -144,7 +186,10 @@ def _pool_head_features(self, tokens: jax.Array) -> jax.Array:
144186 def __call__ (
145187 self , images : jax .Array , * , deterministic : bool = True
146188 ) -> jax .Array :
147- backbone_inputs = prepare_dinov3_inputs (images , expected_size = self .img_size )
189+ stemmed_images = self .input_stem (images ) if self .input_stem is not None else images
190+ backbone_inputs = prepare_dinov3_inputs (
191+ stemmed_images , expected_size = self .img_size
192+ )
148193 tokens = self ._encode_backbone (backbone_inputs )
149194 head_features = self ._pool_head_features (tokens )
150195 head_features = self .head_norm (head_features )
@@ -160,30 +205,34 @@ def get_default_config() -> MnistExampleConfig:
160205 model_config = MnistDinoV3ModelConfig (
161206 img_size = 28 ,
162207 patch_size = 4 ,
163- embed_dim = 192 ,
164- depth = 4 ,
165- num_heads = 6 ,
208+ embed_dim = 256 ,
209+ depth = 6 ,
210+ num_heads = 8 ,
166211 num_classes = 10 ,
167212 num_storage_tokens = 0 ,
168- head_hidden_dim = 192 ,
213+ head_hidden_dim = 256 ,
169214 head_dropout_rate = 0.1 ,
170215 pool_features = "cls_mean" ,
216+ use_conv_stem = True ,
217+ stem_hidden_dim = 32 ,
171218 )
172219 checkpoint_name = (
173220 "dinov3_"
174221 f"p{ model_config .patch_size } _"
175222 f"dim{ model_config .embed_dim } _"
176223 f"d{ model_config .depth } _"
177224 f"h{ model_config .num_heads } _"
178- f"{ model_config .pool_features } _checkpoints"
225+ f"{ model_config .pool_features } _"
226+ f"{ 'stem' + str (model_config .stem_hidden_dim ) if model_config .use_conv_stem else 'nostem' } _checkpoints"
179227 )
180228 return MnistExampleConfig (
181229 seed = 5678 ,
182230 training = shared_mnist_training_config (
183231 checkpoint_dir = os .path .abspath (os .path .join ("./data" , checkpoint_name )),
184232 output_dir = default_output_dir ,
185233 ),
186- # Match the ViT example more closely on token count and parameter budget.
234+ # Borrow a stronger local front-end and a slightly larger backbone to close the
235+ # gap to the stronger MNIST ViT baseline.
187236 model = model_config ,
188237 onnx = OnnxConfig (
189238 model_name = "mnist_dinov3_model" ,
0 commit comments