Fix rotary transform deferred init and some other fixes#419
Conversation
There was a problem hiding this comment.
Pull Request Overview
This PR fixes a critical bug in the RotaryTransform class that caused incorrect initialization when using deferred initialization (meta device), where inverse frequencies would be initialized to zero or NaN. The PR also corrects several typos and adds debugging utilities.
Key changes:
- Fixed RotaryTransform to properly handle deferred initialization by creating a
reset_parametersmethod that respects the device context - Added a test to verify deferred initialization produces the same weights as eager initialization
- Corrected attribute name from
attention_config.attention_configtoattention_config.qk_norm_configin CausalSelfAttention - Fixed multiple spelling errors: "stoppping" → "stopping" and "savig" → "saving"
Reviewed Changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 1 comment.
Show a summary per file
| File | Description |
|---|---|
| src/modalities/models/gpt2/gpt2_model.py | Fixed RotaryTransform deferred initialization bug and corrected CausalSelfAttention attribute name typo |
| tests/nn/model_initialization/test_deferred_initialization.py | Added comprehensive test to verify deferred vs eager initialization produces identical results |
| tests/utility.py | Added debug utilities: deterministic CUDA context manager and NaN detection hooks |
| src/modalities/gym.py | Fixed typo in parameter name: early_stoppping → early_stopping |
| src/modalities/checkpointing/checkpoint_saving_strategies.py | Fixed typo in parameter name across multiple methods |
| src/modalities/checkpointing/checkpoint_saving.py | Fixed typo in parameter name and docstring |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
…nsform_deferred_init
…nsform_deferred_init
|
Note: I checked all the other components going into the GPT2 model and nothing else should be impacted by this bug. However, it would be could to keep this in mind when adding/changing models in the future because neither our code nor PyTorch detects non-shape operations being performed on tensors using the meta device. |
le1nux
left a comment
There was a problem hiding this comment.
Nice work and good catch of this bug!
I left a few small remarks and ideas.
|
|
||
| self.reset_parameters() | ||
|
|
||
| def reset_parameters(self): |
There was a problem hiding this comment.
Regarding reset_parameters(), I found this discussion interesting:
https://github.com/pytorch/torchtitan/blob/58fa181ed3543e19c1cff3014f1b61b919d38cd1/torchtitan/models/llama3/model/model.py#L414-L423
There was a problem hiding this comment.
not init_weights() is non-pytorch function that they call in the train.py and then initialises all modules that contain weights in a recursive fashion.
There was a problem hiding this comment.
For our architecture, it might be interesting to add something like this but with some "initializer" parameter. So that our initialization component just calls this on the top-level model and gives the chosen initialization method as parameter.
Forcing this method to exist would hopefully help to prevent future modules or models to suffer from the same bug as the rotary transforms did.
| self.reset_parameters() | ||
|
|
||
| def reset_parameters(self): | ||
| device = self.inv_freq.device if hasattr(self, "inv_freq") else None |
There was a problem hiding this comment.
A comment would be great why there are two cases where
- inv_freq exists and has a device attribute
- it does not exist in which case device is set to None
Also the impact of setting device to None below, should be documented.
| inv_freq = 1.0 / (base_freq ** (torch.arange(0, dim_model, 2).float() / dim_model)) | ||
| self.base_freq = base_freq | ||
|
|
||
| self.reset_parameters() |
There was a problem hiding this comment.
do we have to call this explicitly in the constructor? Wouldn't our weight init routines call it?
There was a problem hiding this comment.
We still want the constructor to build a valid and complete instance of the module. In particular for unit testing or checkpoint loading where this initialization component is not used.
| module_path: str | None, | ||
| target: torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor, ...], | ||
| target_name: str, | ||
| ): |
There was a problem hiding this comment.
we are just logging here. Since we have so much output in the tests and also runs, I would either rename it to has_nan() -> bool and return a flag indicating the presence of NAN or raise an exception additional to the logging
There was a problem hiding this comment.
I added a parameter to make this raise exceptions. Adding a return is not possible since it is a hook.
| _detect_nan(module, module_path, output, "output") | ||
|
|
||
|
|
||
| def register_nan_hooks(model: torch.nn.Module): |
There was a problem hiding this comment.
Do we use this, debug_nan_hook and _detect_nan somewhere?
If not, I still think it can be helpful and we could keep it. I just would document it somewhere
There was a problem hiding this comment.
I now turned these utility functions into components and also added them to one of the example configs.
| weight_init_type=WeightInitTypes.SCALED, | ||
| mean=0.0, | ||
| std=0.02, | ||
| num_layers=2, |
There was a problem hiding this comment.
should we read this from the model directly instead of hardcoding?
There was a problem hiding this comment.
what do you mean with from the model?
| gpt2_model_eager = _apply_initialization(gpt2_model_eager) | ||
| with torch.device("meta"): | ||
| gpt2_model_deferred = _build_gpt2_model() | ||
| gpt2_model_deferred = _apply_initialization(gpt2_model_deferred) |
There was a problem hiding this comment.
should we add a check that all parameters and buffers are on meta device before init?
There was a problem hiding this comment.
Good idea! I added a separate test checking that deferred init params are on meta first then on cuda device.
…d outputs. Also added option to raise an exception in debug_nan_hook.
There was a problem hiding this comment.
where do we use these hooks? Could not find a reference in the code. Would be good to clarify how to use them.
There was a problem hiding this comment.
I now turned these utility functions into components and also added them to one of the example configs.
…on meta then on cuda device.
Also used them in the PP example config.
le1nux
left a comment
There was a problem hiding this comment.
LGTM :) Also great to have these debugging components now! 👍
What does this PR do?
Mainly fixes a bug causing RotaryTransform to be initialized wrongly when using deferred initialization.
Previously, the inverse frequencies would be initialized to zero (or NaN when using deterministic algorithms).
General Changes
Breaking Changes
Checklist before submitting final PR
python tests/tests.py)CHANGELOG_DEV.md)