Why are mask tokens zero‑initialized in V‑JEPA while they are randomly initialized (trunc_normal) in I‑JEPA? #92

@k007ke

While digging through the predictors I noticed a small but interesting difference in the way the mask tokens are initialized:

# I‑JEPA (vision_transformer_predictor.py)
self.mask_token = nn.Parameter(torch.zeros(1, 1, predictor_embed_dim))
trunc_normal_(self.mask_token, std=init_std)   # ≈ N(0, 0.02²)

# V‑JEPA (vision_transformer_predictor.py)
self.mask_tokens = nn.ParameterList([
    nn.Parameter(torch.zeros(1, 1, predictor_embed_dim))
    for i in range(num_mask_tokens)
])
if self.predictor_pos_embed is not None:
    self._init_pos_embed(self.predictor_pos_embed.data)
self.init_std = init_std
# zero_init_mask_tokens=True is the default, so the branch below is skipped
# and the mask tokens remain exactly zero.
if not zero_init_mask_tokens:
    for mt in self.mask_tokens:
        trunc_normal_(mt, std=init_std)
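
For anyone who wants to poke at this outside the repo, here is a minimal standalone sketch of the two schemes side by side. It assumes PyTorch and uses `torch.nn.init.trunc_normal_` as a stand-in for the repos' own `trunc_normal_` helper; `make_mask_tokens` and the embedding size of 384 are hypothetical and chosen only for illustration.

```python
import torch
import torch.nn as nn
from torch.nn.init import trunc_normal_  # stand-in for the repos' own helper (assumption)


def make_mask_tokens(embed_dim, num_mask_tokens=1, zero_init=True, init_std=0.02):
    """Hypothetical helper mirroring the V-JEPA snippet quoted above."""
    tokens = nn.ParameterList([
        nn.Parameter(torch.zeros(1, 1, embed_dim))
        for _ in range(num_mask_tokens)
    ])
    # Only overwrite the zeros when zero-init is disabled (the I-JEPA-style path).
    if not zero_init:
        for mt in tokens:
            trunc_normal_(mt, std=init_std)
    return tokens


torch.manual_seed(0)
zero_tokens = make_mask_tokens(384, num_mask_tokens=2, zero_init=True)
rand_tokens = make_mask_tokens(384, num_mask_tokens=2, zero_init=False)

# Summary statistics for each scheme.
zero_abs_max = max(mt.abs().max().item() for mt in zero_tokens)
rand_abs_max = max(mt.abs().max().item() for mt in rand_tokens)
rand_std = torch.cat([mt.flatten() for mt in rand_tokens]).std().item()

print(f"zero-init:    max |x| = {zero_abs_max}")
print(f"trunc-normal: max |x| = {rand_abs_max:.4f}, std = {rand_std:.4f}")
```

With the default truncation bounds of `trunc_normal_` (±2) and a standard deviation of 0.02, the truncation almost never triggers, which is why I describe the I‑JEPA case as approximately N(0, 0.02²) above.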

Could you share the motivation behind switching to zero initialization for the video version?

  • Did zero‑init improve training stability for long spatio‑temporal sequences or multi‑mask‑token setups?

  • Have you compared convergence speed or final performance between zero‑init and trunc‑normal on the same video benchmarks?

Thanks for your time and for the great work!
