Fix rotary transform deferred init and some other fixes #419

Merged
BlueCrescent merged 12 commits into main from fix_rotary_transform_deferred_init
Dec 2, 2025
BlueCrescent commented Nov 12, 2025

Member

What does this PR do?

Mainly fixes a bug causing RotaryTransform to be initialized wrongly when using deferred initialization.
Previously, the inverse frequencies would be initialized to zero (or NaN when using deterministic algorithms).
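
For orientation, the correct inverse frequencies are never zero. The snippet below is a dependency-free sketch that evaluates the same expression that appears in the review diff, `1.0 / (base_freq ** (arange(0, dim_model, 2) / dim_model))`. The function name `inverse_frequencies` and the values `dim_model=8` and `base_freq=10000` are assumptions chosen for illustration, not taken from this PR.

```python
def inverse_frequencies(dim_model: int, base_freq: float = 10_000.0) -> list[float]:
    """Plain-Python evaluation of 1 / base_freq ** (i / dim_model) for i = 0, 2, 4, ..."""
    return [1.0 / (base_freq ** (i / dim_model)) for i in range(0, dim_model, 2)]


# Illustrative values only: dim_model=8 yields four frequencies.
freqs = inverse_frequencies(dim_model=8)
for index, value in enumerate(freqs):
    print(index, value)
```

With these values the entries are roughly 1, 0.1, 0.01 and 0.001, so a buffer of all zeros or NaNs is a clear sign that the values were never computed on a real device.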

General Changes

  • Fixed RotaryTransform bug.
  • Added a corresponding test for deferred initialization.
  • Removed some typos.
  • Added utility functions that were very helpful in debugging this.

Breaking Changes

  • None

Checklist before submitting final PR

  • My PR is minimal and addresses one issue in isolation
  • I have merged the latest version of the target branch into this feature branch
  • I have reviewed my own code w.r.t. correct implementation, missing type hints, proper documentation, etc.
  • I have run a sample config for model training
  • I have checked that all tests run through (python tests/tests.py)
  • I have updated the internal changelog (CHANGELOG_DEV.md)

Copilot AI left a comment

Contributor

Pull Request Overview

This PR fixes a critical bug in the RotaryTransform class that caused incorrect initialization when using deferred initialization (meta device), where inverse frequencies would be initialized to zero or NaN. The PR also corrects several typos and adds debugging utilities.

Key changes:

  • Fixed RotaryTransform to properly handle deferred initialization by creating a reset_parameters method that respects the device context
  • Added a test to verify deferred initialization produces the same weights as eager initialization
  • Corrected attribute name from attention_config.attention_config to attention_config.qk_norm_config in CausalSelfAttention
  • Fixed multiple spelling errors: "stoppping" → "stopping" and "savig" → "saving"
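
The reset_parameters pattern from the first bullet could look roughly like the sketch below. It is not the code from this PR: the class name `RotaryFrequencies` is hypothetical, it assumes an even `dim_model`, and it materializes on the CPU to stay runnable anywhere. The idea is that the constructor only allocates the buffer, and every value is computed in `reset_parameters`, which can be called again after the module has been moved off the meta device.

```python
import torch
from torch import nn


class RotaryFrequencies(nn.Module):
    """Hypothetical stand-in for a module that owns a computed, non-trainable buffer."""

    def __init__(self, dim_model: int, base_freq: int = 10_000):
        super().__init__()
        self.dim_model = dim_model  # assumed to be even in this sketch
        self.base_freq = base_freq
        # Allocate only. Under torch.device("meta") this has shape and dtype but no data.
        self.register_buffer("inv_freq", torch.empty(dim_model // 2))
        # Still build a complete instance when the module is constructed eagerly.
        self.reset_parameters()

    def reset_parameters(self) -> None:
        # Recompute the values on whatever device the buffer currently lives on.
        device = self.inv_freq.device
        exponents = torch.arange(0, self.dim_model, 2, device=device).float() / self.dim_model
        self.inv_freq = 1.0 / (self.base_freq**exponents)


# Eager construction computes real values immediately.
eager = RotaryFrequencies(dim_model=8)

# Deferred construction: only shapes exist until the module is materialized.
with torch.device("meta"):
    deferred = RotaryFrequencies(dim_model=8)
deferred = deferred.to_empty(device="cpu")  # allocates uninitialized memory
deferred.reset_parameters()  # fills in the actual values

print(torch.equal(eager.inv_freq, deferred.inv_freq))
```

Without the second `reset_parameters()` call, the buffer would keep whatever `to_empty` left in memory, which matches the symptom described in this PR.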

Reviewed Changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 1 comment.

Summary per file:

  • src/modalities/models/gpt2/gpt2_model.py: Fixed RotaryTransform deferred initialization bug and corrected CausalSelfAttention attribute name typo
  • tests/nn/model_initialization/test_deferred_initialization.py: Added comprehensive test to verify deferred vs eager initialization produces identical results
  • tests/utility.py: Added debug utilities: deterministic CUDA context manager and NaN detection hooks
  • src/modalities/gym.py: Fixed typo in parameter name: early_stoppping → early_stopping
  • src/modalities/checkpointing/checkpoint_saving_strategies.py: Fixed typo in parameter name across multiple methods
  • src/modalities/checkpointing/checkpoint_saving.py: Fixed typo in parameter name and docstring

Comment thread tests/utility.py Outdated
@BlueCrescent

Member Author

Note: I checked all the other components going into the GPT2 model and nothing else should be impacted by this bug. However, it would be good to keep this in mind when adding or changing models in the future, because neither our code nor PyTorch detects non-shape operations being performed on tensors on the meta device.
Testing with deterministic algorithms activated might be useful.
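
A test could switch deterministic algorithms on through a small context manager that restores the previous setting afterwards. The sketch below is an assumption about how such a helper might look, not the utility added in this PR, and the name `deterministic_algorithms` is made up. In recent PyTorch releases, deterministic mode by default also fills otherwise uninitialized floating-point memory with NaN, which turns silently wrong values into visible NaNs.

```python
import contextlib

import torch


@contextlib.contextmanager
def deterministic_algorithms(enabled: bool = True):
    """Temporarily set PyTorch's deterministic-algorithms flag (hypothetical helper)."""
    previous = torch.are_deterministic_algorithms_enabled()
    torch.use_deterministic_algorithms(enabled)
    try:
        yield
    finally:
        # Restore the earlier on/off state; a previous warn_only setting is not preserved here.
        torch.use_deterministic_algorithms(previous)


with deterministic_algorithms():
    # Build and initialize the model under test inside this block.
    print(torch.are_deterministic_algorithms_enabled())
```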

therealdavidos self-requested a review November 27, 2025 14:17

le1nux left a comment

Member

Nice work and good catch of this bug!
I left a few small remarks and ideas.


self.reset_parameters()

def reset_parameters(self):

Member

Note: init_weights() is a non-PyTorch function that they call in train.py; it then recursively initialises all modules that contain weights.

Member Author

For our architecture, it might be interesting to add something like this but with an "initializer" parameter, so that our initialization component just calls it on the top-level model and passes the chosen initialization method as a parameter.
Forcing this method to exist would hopefully help prevent future modules or models from suffering from the same bug as the rotary transforms did.

self.reset_parameters()

def reset_parameters(self):
    device = self.inv_freq.device if hasattr(self, "inv_freq") else None

Member

A comment explaining why there are two cases would be great:

  1. inv_freq exists and has a device attribute
  2. inv_freq does not exist, in which case device is set to None

Also, the impact of setting device to None below should be documented.

Member Author

Added.

inv_freq = 1.0 / (base_freq ** (torch.arange(0, dim_model, 2).float() / dim_model))
self.base_freq = base_freq

self.reset_parameters()

Member

do we have to call this explicitly in the constructor? Wouldn't our weight init routines call it?

Member Author

We still want the constructor to build a valid and complete instance of the module. This matters in particular for unit testing or checkpoint loading, where this initialization component is not used.

module_path: str | None,
target: torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor, ...],
target_name: str,
):

Member

We are just logging here. Since we have so much output in the tests and also in runs, I would either rename it to has_nan() -> bool and return a flag indicating the presence of NaN, or raise an exception in addition to the logging.

Member Author

I added a parameter to make this raise exceptions. Adding a return is not possible since it is a hook.
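
The constraint mentioned here comes from how PyTorch forward hooks work: a value returned from a forward hook replaces the module's output, so a hook cannot hand back a boolean flag. The sketch below shows one way a NaN check with an opt-in exception could be written. It is not the implementation from this PR; `nan_forward_hook`, `register_nan_forward_hooks` and the `raise_exception` parameter name are assumptions made for illustration.

```python
import logging
from functools import partial

import torch
from torch import nn

logger = logging.getLogger(__name__)


def _contains_nan(target) -> bool:
    """Return True if a tensor, or any tensor nested in a list or tuple, holds a NaN."""
    if isinstance(target, torch.Tensor):
        return bool(torch.isnan(target).any())
    if isinstance(target, (list, tuple)):
        return any(_contains_nan(item) for item in target)
    return False


def nan_forward_hook(module, args, output, *, module_path=None, raise_exception=False):
    """Forward hook that logs NaNs in inputs or outputs, or raises if asked to."""
    for target_name, target in (("input", args), ("output", output)):
        if _contains_nan(target):
            message = f"NaN detected in {target_name} of {module_path or type(module).__name__}"
            if raise_exception:
                raise RuntimeError(message)
            logger.error(message)
    # Implicitly returns None: any other return value would replace the module's output.


def register_nan_forward_hooks(model: nn.Module, raise_exception: bool = False):
    """Attach the hook to every submodule and return the handles for later removal."""
    return [
        module.register_forward_hook(
            partial(nan_forward_hook, module_path=path, raise_exception=raise_exception)
        )
        for path, module in model.named_modules()
    ]
```

Calling `handle.remove()` on each returned handle detaches the hooks again once debugging is done.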

Comment thread src/modalities/utils/debug.py Outdated
_detect_nan(module, module_path, output, "output")


def register_nan_hooks(model: torch.nn.Module):

Member

Do we use this, debug_nan_hook and _detect_nan somewhere?

If not, I still think it can be helpful and we could keep it. I just would document it somewhere

Member Author

I now turned these utility functions into components and also added them to one of the example configs.

weight_init_type=WeightInitTypes.SCALED,
mean=0.0,
std=0.02,
num_layers=2,

Member

should we read this from the model directly instead of hardcoding?

Member Author

What do you mean by "from the model"?

gpt2_model_eager = _apply_initialization(gpt2_model_eager)
with torch.device("meta"):
    gpt2_model_deferred = _build_gpt2_model()
gpt2_model_deferred = _apply_initialization(gpt2_model_deferred)

Member

should we add a check that all parameters and buffers are on meta device before init?

Member Author

Good idea! I added a separate test checking that deferred-init parameters are first on the meta device and then on the CUDA device.
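
A check of this kind can be sketched as follows. This is not the test added in the PR: the helper name `all_tensors_on` is hypothetical, a small `nn.Sequential` stands in for the GPT2 model, and the CPU is used as the target device instead of CUDA so that the sketch runs anywhere.

```python
import itertools

import torch
from torch import nn


def all_tensors_on(model: nn.Module, device_type: str) -> bool:
    """Return True if every parameter and buffer lives on the given device type."""
    tensors = itertools.chain(model.parameters(), model.buffers())
    return all(tensor.device.type == device_type for tensor in tensors)


# Stand-in model built under the meta device: shapes only, no data.
with torch.device("meta"):
    model = nn.Sequential(nn.Linear(4, 4), nn.LayerNorm(4))
assert all_tensors_on(model, "meta")

# Materialize on a real device, then run each module's own initialization.
model = model.to_empty(device="cpu")
for module in model.modules():
    if hasattr(module, "reset_parameters"):
        module.reset_parameters()
assert all_tensors_on(model, "cpu")
```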

Collaborator

where do we use these hooks? Could not find a reference in the code. Would be good to clarify how to use them.

Member Author

I now turned these utility functions into components and also added them to one of the example configs.

le1nux left a comment

Member

LGTM :) Also great to have these debugging components now! 👍

BlueCrescent merged commit 9cbb02b into main on Dec 2, 2025
3 checks passed
BlueCrescent deleted the fix_rotary_transform_deferred_init branch on December 2, 2025 14:03
4 participants