Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions heavyball/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class OrthoScaleMode(enum.Enum):
)
_torch_compile_double_backward_pattern = re.compile(r"compile.*does not currently support double backward")
_fd_error = (
"You can accelerate startup by globally enabling finite_differences first " #
"You can accelerate startup by globally enabling finite_differences first "
"(via opt.finite_differences=True or by subclassing it)\n"
"Original Error: "
)
Expand Down Expand Up @@ -418,9 +418,13 @@ def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):


###### START
# Taken from https://github.com/thinking-machines-lab/manifolds/blob/89dcae50f01af59f1e0570289474da3a2ecaa60b/src/msign.py#L47
# Based on https://arxiv.org/pdf/2505.16932v3
# and https://github.com/NoahAmsel/PolarExpress/blob/5454910920ca8c65afda28820cdf9e49b9436ed0/polar_express.py#L69-L82
# and https://github.com/thinking-machines-lab/manifolds/blob/89dcae50f01af59f1e0570289474da3a2ecaa60b/src/msign.py#L47
#
# under the MIT License

# Coefficients are from https://arxiv.org/pdf/2505.16932v3
ABC_LIST: list[tuple[float, float, float]] = [
(8.28721201814563, -23.595886519098837, 17.300387312530933),
(4.107059111542203, -2.9478499167379106, 0.5448431082926601),
Expand All @@ -438,7 +442,7 @@ def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
] + [ABC_LIST[-1]]


def msign(G: torch.Tensor, steps: int = 10) -> torch.Tensor:
def msign(G: torch.Tensor, steps: int = 10, eps: float = 1e-7) -> torch.Tensor:
"""
Polar Express algorithm for the matrix sign function:
https://arxiv.org/abs/2505.16932
Expand All @@ -450,7 +454,9 @@ def msign(G: torch.Tensor, steps: int = 10) -> torch.Tensor:
if should_transpose:
x = x.mT

x /= x.norm(dim=(-2, -1), keepdim=True) * 1.01
# x = x / (x.norm(dim=(-2, -1), keepdim=True) * 1.01 + eps)
stochastic_divide_with_eps_(x, x.norm(dim=(-2, -1)) * 1.01, eps)

for step in range(steps):
a, b, c = ABC_LIST_STABLE[step] if step < len(ABC_LIST_STABLE) else ABC_LIST_STABLE[-1]
s = x @ x.mT
Expand All @@ -464,7 +470,6 @@ def msign(G: torch.Tensor, steps: int = 10) -> torch.Tensor:

if should_transpose:
x = x.mT
x = torch.nan_to_num(x)
return x.float()


Expand Down
Loading