[OFT] Rework and Matrix Exponential Mode#1556
Draft
Koratahiu wants to merge 120 commits into
Draft
Conversation
- Dynamic steps based on dtype
- revert scaled oft change
- Remove 0.999 - Default 0.95
…into scaled_optm
part of @BitcrushedHeart referenced commit
…into scaled_optm
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary:
-1).Refactoring
_cayley_batchto_compute_orthogonal_matrix.Matrix Exponential & Reworking CANS
I found a few issues with the Cayley transform in general (even when using CANS or exact math):
After searching for alternatives to Cayley Neumann, I found the Matrix Exponential, which seems to solve all the mentioned issues:
One way to implement the Matrix Exponential is by using the exact math via
torch.linalg.matrix_exp. However,matrix_expis expensive and very unstable for BF16 (it exploded in my tests).To resolve this, I applied a highly effective approximation pipeline:
This might look complex and compute-heavy, but in reality, it only requires 12 matmuls. Thanks to the 4th-order Taylor expansion, the matrix is already near orthogonality, and CANS only requires 3 iterations to converge.
The relative error compared to exact math (
torch.linalg.matrix_exp) is very small (~1e-6 to ~1e-3 for FP32 and ~1e-3 for BF16):Auto clipping mode
When setting spectral norm clipping to
-1, it now automatically applies the recommended spectral norm clipping for each technique:Usage:
Matrix Exponential CANS.-1(auto).