Skip to content

Commit 7ff2e5e

Browse files
author
enpasos
committed
dinov3 -> 36 errors
1 parent d007aef commit 7ff2e5e

5 files changed

Lines changed: 81 additions & 25 deletions

File tree

jaxamples/mnist_config.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -221,6 +221,8 @@ class MnistDinoV3ModelConfig(ConfigMixin):
221221
head_hidden_dim: int = 192
222222
head_dropout_rate: float = 0.1
223223
pool_features: str = "cls_mean"
224+
use_conv_stem: bool = True
225+
stem_hidden_dim: int = 32
224226

225227
def validate(self) -> None:
226228
_require(self.img_size > 0, "img_size must be > 0.")
@@ -239,6 +241,7 @@ def validate(self) -> None:
239241
_require(self.num_classes >= 2, "num_classes must be >= 2.")
240242
_require(self.num_storage_tokens >= 0, "num_storage_tokens must be >= 0.")
241243
_require(self.head_hidden_dim > 0, "head_hidden_dim must be > 0.")
244+
_require(self.stem_hidden_dim > 0, "stem_hidden_dim must be > 0.")
242245
_validate_dropout(self.head_dropout_rate, "head_dropout_rate")
243246
_require(
244247
self.pool_features in {"cls", "cls_mean"},

jaxamples/mnist_dinov3.py

Lines changed: 57 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -56,6 +56,41 @@ def prepare_dinov3_inputs(
5656
return nchw_images
5757

5858

59+
class MnistDinoV3ConvStem(nnx.Module):
    """Two-stage convolutional front-end for MNIST images.

    Each stage is a bias-free 3x3 ``SAME``-padded convolution followed by a
    ``LayerNorm`` and an exact (non-approximate) GELU. The first stage maps a
    single input channel to ``hidden_dim`` channels; the second maps
    ``hidden_dim`` channels to 3 channels.

    Args:
        hidden_dim: Number of channels produced by the first convolution.
        rngs: Source of the ``params`` key; it is drawn once and split into
            one key per sub-layer.
    """

    def __init__(self, *, hidden_dim: int, rngs: nnx.Rngs):
        # One draw from the params stream, fanned out into four sub-layer keys.
        k_conv1, k_norm1, k_conv2, k_norm2 = jax.random.split(rngs.params(), 4)

        # Settings common to both convolutions.
        conv_kwargs = dict(kernel_size=(3, 3), padding="SAME", use_bias=False)

        # Stage 1: 1 input channel -> hidden_dim channels.
        self.conv1 = nnx.Conv(
            in_features=1,
            out_features=hidden_dim,
            rngs=nnx.Rngs(k_conv1),
            **conv_kwargs,
        )
        self.norm1 = nnx.LayerNorm(hidden_dim, rngs=nnx.Rngs(k_norm1))

        # Stage 2: hidden_dim channels -> 3 channels.
        # NOTE(review): presumably 3 matches the channel count the backbone
        # expects -- confirm against the backbone configuration.
        self.conv2 = nnx.Conv(
            in_features=hidden_dim,
            out_features=3,
            rngs=nnx.Rngs(k_conv2),
            **conv_kwargs,
        )
        self.norm2 = nnx.LayerNorm(3, rngs=nnx.Rngs(k_norm2))

    def __call__(self, images: jax.Array) -> jax.Array:
        """Run both conv -> norm -> GELU stages and return the result.

        Args:
            images: Input batch. Assumed to be channels-last with a single
                channel, since ``conv1`` is built with ``in_features=1`` --
                TODO confirm against the caller.

        Returns:
            The activations after the second stage (3 feature channels).
        """
        stages = ((self.conv1, self.norm1), (self.conv2, self.norm2))
        activations = images
        for conv, norm in stages:
            activations = nnx.gelu(norm(conv(activations)), approximate=False)
        return activations
92+
93+
5994
class MnistDinoV3Classifier(nnx.Module):
6095
"""Small MNIST classifier using the DINOv3 ViT backbone from jax2onnx."""
6196

@@ -72,16 +107,23 @@ def __init__(
72107
head_hidden_dim: int = 192,
73108
head_dropout_rate: float = 0.1,
74109
pool_features: str = "cls_mean",
110+
use_conv_stem: bool = True,
111+
stem_hidden_dim: int = 32,
75112
rngs: nnx.Rngs,
76113
):
77114
params_key = rngs.params()
78-
backbone_key, head_key = jax.random.split(params_key)
115+
stem_key, backbone_key, head_key = jax.random.split(params_key, 3)
79116
head_norm_key, head_hidden_key, head_dropout_key, head_out_key = jax.random.split(
80117
head_key, 4
81118
)
82119

83120
self.img_size = int(img_size)
84121
self.pool_features = pool_features
122+
self.input_stem = (
123+
MnistDinoV3ConvStem(hidden_dim=stem_hidden_dim, rngs=nnx.Rngs(stem_key))
124+
if use_conv_stem
125+
else None
126+
)
85127
self.backbone = DinoVisionTransformer(
86128
img_size=img_size,
87129
patch_size=patch_size,
@@ -144,7 +186,10 @@ def _pool_head_features(self, tokens: jax.Array) -> jax.Array:
144186
def __call__(
145187
self, images: jax.Array, *, deterministic: bool = True
146188
) -> jax.Array:
147-
backbone_inputs = prepare_dinov3_inputs(images, expected_size=self.img_size)
189+
stemmed_images = self.input_stem(images) if self.input_stem is not None else images
190+
backbone_inputs = prepare_dinov3_inputs(
191+
stemmed_images, expected_size=self.img_size
192+
)
148193
tokens = self._encode_backbone(backbone_inputs)
149194
head_features = self._pool_head_features(tokens)
150195
head_features = self.head_norm(head_features)
@@ -160,30 +205,34 @@ def get_default_config() -> MnistExampleConfig:
160205
model_config = MnistDinoV3ModelConfig(
161206
img_size=28,
162207
patch_size=4,
163-
embed_dim=192,
164-
depth=4,
165-
num_heads=6,
208+
embed_dim=256,
209+
depth=6,
210+
num_heads=8,
166211
num_classes=10,
167212
num_storage_tokens=0,
168-
head_hidden_dim=192,
213+
head_hidden_dim=256,
169214
head_dropout_rate=0.1,
170215
pool_features="cls_mean",
216+
use_conv_stem=True,
217+
stem_hidden_dim=32,
171218
)
172219
checkpoint_name = (
173220
"dinov3_"
174221
f"p{model_config.patch_size}_"
175222
f"dim{model_config.embed_dim}_"
176223
f"d{model_config.depth}_"
177224
f"h{model_config.num_heads}_"
178-
f"{model_config.pool_features}_checkpoints"
225+
f"{model_config.pool_features}_"
226+
f"{'stem' + str(model_config.stem_hidden_dim) if model_config.use_conv_stem else 'nostem'}_checkpoints"
179227
)
180228
return MnistExampleConfig(
181229
seed=5678,
182230
training=shared_mnist_training_config(
183231
checkpoint_dir=os.path.abspath(os.path.join("./data", checkpoint_name)),
184232
output_dir=default_output_dir,
185233
),
186-
# Match the ViT example more closely on token count and parameter budget.
234+
# Borrow a stronger local front-end and a slightly larger backbone to close the
235+
# gap to the stronger MNIST ViT baseline.
187236
model=model_config,
188237
onnx=OnnxConfig(
189238
model_name="mnist_dinov3_model",

onnx/mnist_dinov3_model.onnx

11.6 MB
Binary file not shown.

onnx/mnist_dinov3_model_config.json

Lines changed: 14 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -1,15 +1,17 @@
11
{
22
"model": {
3-
"depth": 4,
4-
"embed_dim": 192,
3+
"depth": 6,
4+
"embed_dim": 256,
55
"head_dropout_rate": 0.1,
6-
"head_hidden_dim": 192,
6+
"head_hidden_dim": 256,
77
"img_size": 28,
88
"num_classes": 10,
9-
"num_heads": 6,
9+
"num_heads": 8,
1010
"num_storage_tokens": 0,
1111
"patch_size": 4,
12-
"pool_features": "cls_mean"
12+
"pool_features": "cls_mean",
13+
"stem_hidden_dim": 32,
14+
"use_conv_stem": true
1315
},
1416
"onnx": {
1517
"input_params": {
@@ -30,15 +32,15 @@
3032
"training": {
3133
"augmentation": {
3234
"elastic_alpha": 1.2,
33-
"elastic_probability": 0.35,
34-
"elastic_sigma": 0.9,
35+
"elastic_probability": 0.3,
36+
"elastic_sigma": 1.0,
3537
"enable_elastic": true,
3638
"enable_rect_erasing": false,
3739
"enable_rotation": true,
3840
"enable_scaling": true,
3941
"enable_translation": true,
40-
"max_rotation": 10.0,
41-
"max_translation": 2.5,
42+
"max_rotation": 12.0,
43+
"max_translation": 4.0,
4244
"rect_erase_height": 2,
4345
"rect_erase_width": 20,
4446
"rect_erasing_probability": 0.0,
@@ -47,15 +49,15 @@
4749
"scale_max_y": 1.1,
4850
"scale_min_x": 0.9,
4951
"scale_min_y": 0.9,
50-
"scaling_probability": 0.7,
52+
"scaling_probability": 0.6,
5153
"translation_probability": 0.8
5254
},
5355
"base_learning_rate": 0.0001,
5456
"batch_size": 64,
55-
"checkpoint_dir": "/home/enpasos/projects/jaxamples/data/dinov3_p4_dim192_d4_h6_cls_mean_checkpoints",
57+
"checkpoint_dir": "/home/enpasos/projects/jaxamples/data/dinov3_p4_dim256_d6_h8_cls_mean_stem32_checkpoints",
5658
"data_dir": "./data",
5759
"enable_training": true,
58-
"num_epochs_to_train_now": 500,
60+
"num_epochs_to_train_now": 700,
5961
"output_dir": "/home/enpasos/projects/jaxamples/output",
6062
"start_epoch": 0,
6163
"warmup_epochs": 5,

tests/test_mnist_dinov3.py

Lines changed: 7 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -104,15 +104,17 @@ def test_mnist_dinov3_default_config_uses_fairer_budget():
104104
assert config.training.num_epochs_to_train_now == 700
105105
assert config.training.weight_decay == pytest.approx(1e-4)
106106
assert config.training.checkpoint_dir.endswith(
107-
"dinov3_p4_dim192_d4_h6_cls_mean_checkpoints"
107+
"dinov3_p4_dim256_d6_h8_cls_mean_stem32_checkpoints"
108108
)
109109
assert config.model.patch_size == 4
110-
assert config.model.embed_dim == 192
111-
assert config.model.depth == 4
112-
assert config.model.num_heads == 6
113-
assert config.model.head_hidden_dim == 192
110+
assert config.model.embed_dim == 256
111+
assert config.model.depth == 6
112+
assert config.model.num_heads == 8
113+
assert config.model.head_hidden_dim == 256
114114
assert config.model.head_dropout_rate == pytest.approx(0.1)
115115
assert config.model.pool_features == "cls_mean"
116+
assert config.model.use_conv_stem is True
117+
assert config.model.stem_hidden_dim == 32
116118

117119

118120
def test_lr_schedule_applies_warmup_before_cosine_decay():

0 commit comments

Comments
 (0)