-
Notifications
You must be signed in to change notification settings - Fork 16
Fix rotary transform deferred init and some other fixes #419
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
86fa4b3
79ec823
65b596d
3aa07be
691f4d8
a041d20
b36f0c5
9131065
406ce8c
ae73bce
c8e6eea
af06349
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -133,9 +133,19 @@ def __init__(self, n_embd: int, n_head: int, seq_length_dim: int = -2, base_freq | |
| super().__init__() | ||
| # this also holds when using TP, since n_embd is the total embedding size and | ||
| # n_head is the number of heads globally | ||
| dim_model = n_embd // n_head | ||
| self.dim_model = n_embd // n_head | ||
| self.seq_length_dim = seq_length_dim | ||
| inv_freq = 1.0 / (base_freq ** (torch.arange(0, dim_model, 2).float() / dim_model)) | ||
| self.base_freq = base_freq | ||
|
|
||
| self.reset_parameters() | ||
|
|
||
| def reset_parameters(self): | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Regarding
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. not
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For our architecture, it might be interesting to add something like this but with some "initializer" parameter. So that our initialization component just calls this on the top-level model and gives the chosen initialization method as parameter. |
||
| # If previously initialized on or moved to a device, reuse that device. | ||
| # Otherwise, use the default device of the current environment. | ||
| device = self.inv_freq.device if hasattr(self, "inv_freq") else None | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A comment would be great why there are two cases where
Also, the impact of setting device to None below should be documented.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added. |
||
| inv_freq = 1.0 / ( | ||
| self.base_freq ** (torch.arange(0, self.dim_model, 2, device=device).float() / self.dim_model) | ||
| ) | ||
| self.register_buffer("inv_freq", inv_freq) | ||
|
|
||
| self._seq_len_cached = None | ||
|
|
||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Where do we use these hooks? I could not find a reference in the code. It would be good to clarify how to use them.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I now turned these utility functions into components and also added them to one of the example configs. |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,100 @@ | ||
| import logging | ||
| import os | ||
| from contextlib import contextmanager | ||
| from typing import Any | ||
|
|
||
| import torch | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
@contextmanager
def enable_deterministic_cuda():
    """Context manager to enable deterministic CUDA operations and restore previous state.

    While the context is active the following global settings are applied:

    * ``torch.backends.cudnn.deterministic = True``
    * ``torch.backends.cudnn.benchmark = False``
    * ``torch.use_deterministic_algorithms(True)``
    * ``os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"``

    On exit (normal or via an exception) every setting is put back to the
    value it had on entry, including the ``warn_only`` flag of the
    deterministic-algorithms mode and the absence of the environment variable.
    """
    # Snapshot everything first so the ``finally`` clause can always restore it.
    prev_cudnn_deterministic = torch.backends.cudnn.deterministic
    prev_cudnn_benchmark = torch.backends.cudnn.benchmark
    prev_algos = torch.are_deterministic_algorithms_enabled()
    # ``use_deterministic_algorithms(True)`` resets warn_only to False, so the
    # previous flag has to be remembered separately to be restored faithfully.
    prev_warn_only = torch.is_deterministic_algorithms_warn_only_enabled()
    prev_cublas_cfg = os.environ.get("CUBLAS_WORKSPACE_CONFIG")

    try:
        # The setup lives inside the ``try`` so that a failure half-way through
        # still restores the settings that were already changed.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        torch.use_deterministic_algorithms(True)
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
        yield
    finally:
        torch.backends.cudnn.deterministic = prev_cudnn_deterministic
        torch.backends.cudnn.benchmark = prev_cudnn_benchmark
        torch.use_deterministic_algorithms(prev_algos, warn_only=prev_warn_only)
        if prev_cublas_cfg is None:
            # The variable was unset before; remove it instead of leaving ours behind.
            os.environ.pop("CUBLAS_WORKSPACE_CONFIG", None)
        else:
            os.environ["CUBLAS_WORKSPACE_CONFIG"] = prev_cublas_cfg
|
|
||
|
|
||
def _detect_nan(
    module: torch.nn.Module,
    module_path: str | None,
    target: torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor, ...],
    target_name: str,
    raise_exception: bool,
):
    """Report NaN values found in ``target``.

    This helper is called from a hook, so it cannot hand a result back to the
    caller; a finding is reported through the module logger and, optionally,
    by raising.

    Args:
        module: Module whose class name is included in the report.
        module_path: Optional path of the module; logged on a second line when truthy.
        target: A tensor, or a list/tuple whose tensor items are checked.
            Non-tensor items and any other target type are ignored.
        target_name: Label used in the messages (e.g. "input" or "output").
        raise_exception: If True, raise after logging instead of only logging.

    Raises:
        ValueError: If a NaN is found and ``raise_exception`` is True.
    """
    if isinstance(target, (list, tuple)):
        candidates = (item for item in target if isinstance(item, torch.Tensor))
    elif isinstance(target, torch.Tensor):
        candidates = (target,)
    else:
        return

    # Stop at the first tensor containing a NaN; later items are not inspected.
    for tensor in candidates:
        if torch.isnan(tensor).any():
            break
    else:
        return

    class_name = module.__class__.__name__
    logger.error(f"NaN detected in {target_name} {class_name}")
    if module_path:
        logger.error(f"Module path: {module_path}")
    if raise_exception:
        raise ValueError(f"NaN detected in {target_name} of module {class_name}")
|
|
||
|
|
||
def debug_nan_hook(
    module: torch.nn.Module,
    input: torch.Tensor | tuple[torch.Tensor, ...],
    output: torch.Tensor | tuple[torch.Tensor, ...] | list[torch.Tensor],
    module_path: str | None = None,
    raise_exception: bool = False,
):
    """Hook to detect NaN in forward pass.

    Checks the module's input first and its output second by delegating to
    ``_detect_nan``.

    Args:
        module: Module the hook was invoked for.
        input: Input value(s) passed to the module's forward call.
        output: Value(s) returned by the module's forward call.
        module_path: Optional module path forwarded to the NaN report.
        raise_exception: Forwarded to ``_detect_nan``; when True a detected
            NaN raises instead of only being logged.
    """
    # The input is checked before the output, so a raising input check
    # prevents the output check from running.
    for label, value in (("input", input), ("output", output)):
        _detect_nan(module, module_path, target=value, target_name=label, raise_exception=raise_exception)
|
|
||
|
|
||
| def print_forward_hook( | ||
| module: torch.nn.Module, | ||
| input: torch.Tensor | tuple[torch.Tensor, ...] | list[torch.Tensor] | dict[str, Any], | ||
| output: torch.Tensor | tuple[torch.Tensor, ...] | list[torch.Tensor] | dict[str, Any], | ||
| module_path: str | None = None, | ||
| print_shape_only: bool = False, | ||
| ): | ||
| """Hook to print input and output shapes during forward pass""" | ||
| if isinstance(input, (list, tuple)): | ||
| input_shapes = [inp.shape for inp in input if isinstance(inp, torch.Tensor)] | ||
| elif isinstance(input, torch.Tensor): | ||
| input_shapes = [input.shape] | ||
| else: | ||
| input_shapes = [] | ||
|
|
||
| if isinstance(output, (list, tuple)): | ||
| output_shapes = [out.shape for out in output if isinstance(out, torch.Tensor)] | ||
| elif isinstance(output, torch.Tensor): | ||
| output_shapes = [output.shape] | ||
| else: | ||
| output_shapes = [] | ||
|
|
||
| print( | ||
| f"Module: {module.__class__.__name__}, " | ||
| f"Path: {module_path}, " | ||
| f"Input shapes: {input_shapes}, " | ||
| f"Output shapes: {output_shapes}" | ||
| ) | ||
| if not print_shape_only: | ||
| print(f">>> Input:\n{input}") | ||
| print(f">>> Output:\n{output}") | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do we have to call this explicitly in the constructor? Wouldn't our weight init routines call it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We still want the constructor to build a valid and complete instance of the module. In particular for unit testing or checkpoint loading where this initialization component is not used.