While digging through the predictors I noticed a small but interesting difference in the way the mask tokens are initialized:
```python
# I‑JEPA (vision_transformer_predictor.py)
self.mask_token = nn.Parameter(torch.zeros(1, 1, predictor_embed_dim))
trunc_normal_(self.mask_token, std=init_std)  # ≈ N(0, 0.02²)
```

```python
# V‑JEPA (vision_transformer_predictor.py)
self.mask_tokens = nn.ParameterList([
    nn.Parameter(torch.zeros(1, 1, predictor_embed_dim))
    for i in range(num_mask_tokens)
])
if self.predictor_pos_embed is not None:
    self._init_pos_embed(self.predictor_pos_embed.data)
self.init_std = init_std
# zero_init_mask_tokens=True is the default, so this loop is skipped
# and the mask tokens remain exactly zero.
if not zero_init_mask_tokens:
    for mt in self.mask_tokens:
        trunc_normal_(mt, std=init_std)
```
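To make the practical difference concrete, here is a minimal PyTorch sketch. It assumes the mask token is simply added to the target positional embeddings before entering the predictor; `make_mask_tokens` and `probe` are hypothetical helpers, and the squared-norm loss is a toy stand-in for the real prediction loss. Because the token enters additively, it still receives a nonzero gradient when it starts at zero, so zero-init tokens remain trainable. At step 0 the visible differences are that all mask tokens are identical and the masked positions carry only positional information.

```python
import torch
import torch.nn as nn


def make_mask_tokens(num_mask_tokens, dim, init_std=0.02, zero_init=True):
    """Hypothetical helper mirroring the V-JEPA snippet above."""
    tokens = nn.ParameterList(
        [nn.Parameter(torch.zeros(1, 1, dim)) for _ in range(num_mask_tokens)]
    )
    if not zero_init:
        for mt in tokens:
            # PyTorch's trunc_normal_ truncates at absolute bounds a=-2, b=2 by default.
            nn.init.trunc_normal_(mt, std=init_std)
    return tokens


def probe(zero_init, dim=8, n_tgt=4, seed=0):
    """Return (all mask tokens are zero, token 0 receives a nonzero gradient)."""
    torch.manual_seed(seed)
    tokens = make_mask_tokens(num_mask_tokens=2, dim=dim, zero_init=zero_init)
    # Random stand-in for the positional embeddings of the target positions (assumption).
    pos_emb = torch.randn(1, n_tgt, dim)
    # Masked-position input: mask token broadcast over targets, plus positional embedding.
    pred_in = tokens[0].expand(1, n_tgt, dim) + pos_emb
    # Toy loss standing in for the real predictor + regression loss.
    loss = pred_in.pow(2).mean()
    loss.backward()
    # No optimizer step is taken, so the parameters still hold their initial values.
    all_zero = all(bool((mt == 0).all()) for mt in tokens)
    has_grad = tokens[0].grad is not None and bool(tokens[0].grad.abs().sum() > 0)
    return all_zero, has_grad


for zero_init in (True, False):
    print("zero_init =", zero_init, "->", probe(zero_init))
```

In both settings token 0 gets a gradient; only the starting point differs, which is why the question below is about optimization behavior rather than whether the tokens can learn at all.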
Could you share the motivation behind switching to zero initialization for the video version?

- Did zero‑init improve training stability for long spatio‑temporal sequences or multi‑mask‑token setups?
- Have you compared convergence speed or final performance between zero‑init and trunc‑normal on the same video benchmarks?

Thanks again for your time and for the great work!