Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 34 additions & 39 deletions test/test_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,90 +2,85 @@

os.environ["TORCH_LOGS"] = "+recompiles"

import pytest
import math

import pytest
import torch
from torch import nn
from torch import linalg
from torch import linalg, nn
from torch._dynamo import config

import heavyball
from benchmark.utils import get_optim
from heavyball import utils
from heavyball.utils import (
_compilable_global_l2norm_clip_,
_compilable_global_rmsnorm_clip_,
_compilable_l2_clip_,
_compilable_rmsnorm_clip_,
_compilable_global_rmsnorm_clip_,
_compilable_global_l2norm_clip_,
clean,
set_torch,
)
from benchmark.utils import get_optim
from heavyball.utils import clean, set_torch
from heavyball import utils

config.cache_size_limit = 128


def _make_tensors_with_grad(x):
    """Build zero-valued leaf tensors whose ``.grad`` holds a copy of each tensor in ``x``.

    Args:
        x: iterable of tensors to be used as gradients.

    Returns:
        A list with one ``requires_grad=True`` zero tensor per entry of ``x``
        (same shape/dtype/device via ``zeros_like``), each carrying a cloned
        copy of the corresponding entry as its ``.grad``.
    """
    out = [torch.zeros_like(y, requires_grad=True) for y in x]
    for y, grad in zip(out, x):
        # Clone so later in-place edits of the gradients never alias ``x``.
        y.grad = grad.clone()
    return out


def _in_assertions(x):
    """Assert that no input tensor contains NaN or Inf before clipping.

    The checks only run when ``utils.compile_mode`` is ``None``.
    NOTE(review): presumably they are skipped under compilation on purpose —
    confirm against ``heavyball.utils``.

    Args:
        x: iterable of tensors about to be clipped.

    Raises:
        AssertionError: if any tensor holds a NaN or an Inf value.
    """
    # Identity comparison: ``None`` is a singleton, ``== None`` is unidiomatic.
    if utils.compile_mode is None:
        assert all(not torch.isnan(y).any() for y in x), "nan before clipping"
        assert all(not torch.isinf(y).any() for y in x), "inf before clipping"


def _out_assertions(torch_clipped, heavyball_clipped):
    """Assert that the clipped tensors match the reference and are finite.

    The checks only run when ``utils.compile_mode`` is ``None``.

    Args:
        torch_clipped: reference tensors produced by the plain-torch clipping.
        heavyball_clipped: tensors produced by the heavyball clipping helper.

    Raises:
        AssertionError: if any pair differs beyond ``torch.allclose`` default
            tolerances, or if any heavyball result contains NaN/Inf.
    """
    # Identity comparison: ``None`` is a singleton, ``== None`` is unidiomatic.
    if utils.compile_mode is None:
        # Pairs are compared positionally; zip stops at the shorter sequence.
        assert all(torch.allclose(a, b) for a, b in zip(torch_clipped, heavyball_clipped))
        assert all(not (torch.isnan(y).any() or torch.isinf(y).any()) for y in heavyball_clipped)


def _clip_non_global(x, norm, clip_at):
    """Reference per-tensor clipping: scale each tensor so its norm is at most ``clip_at``.

    Args:
        x: list of tensors to clip.
        norm: list of scalar tensors, ``norm[i]`` being the norm of ``x[i]``.
        clip_at: maximum allowed norm.

    Returns:
        A new list of tensors; entries whose norm exceeds ``clip_at`` are
        scaled down by ``clip_at / norm``, the others are multiplied by 1.0.
    """
    # ``min(..., 1.0)`` only ever shrinks (``max`` would amplify small tensors);
    # the 1e-8 floor keeps a zero norm from dividing by zero.
    scalar = [min(clip_at / max(n.item(), 1e-8), 1.0) for n in norm]
    return [s * y for s, y in zip(scalar, x)]


def _clip_global(x, norm, clip_at):
    """Reference global clipping: scale every tensor by one shared factor.

    Args:
        x: list of tensors to clip.
        norm: scalar tensor holding the norm computed over all of ``x``.
        clip_at: maximum allowed norm.

    Returns:
        A new list of tensors, each multiplied by
        ``min(clip_at / max(norm, 1e-8), 1.0)``.
    """
    # Never scale up; the 1e-8 floor keeps a zero norm from dividing by zero.
    scalar = min(clip_at / max(norm.item(), 1e-8), 1.0)
    return [scalar * y for y in x]


def _test_rmsnorm(x, clip_at):
    """Check ``_compilable_rmsnorm_clip_`` against a plain-torch per-tensor RMS clip.

    Args:
        x: list of tensors to clip.
        clip_at: maximum allowed per-tensor RMS norm.

    Returns:
        ``x`` itself, after the heavyball helper has been applied to it.
    """
    _in_assertions(x)
    # Plain RMS with no epsilon added; _clip_non_global guards the division.
    norm = [torch.sqrt(torch.mean(y**2)) for y in x]
    # The reference must be computed before the helper touches ``x``.
    torch_clipped = _clip_non_global(x, norm, clip_at)
    # The result is read back from ``x``, so this relies on the helper
    # updating ``x`` in place.
    _compilable_rmsnorm_clip_(x, clip_at)
    _out_assertions(torch_clipped, x)
    return x


def _test_l2(x, clip_at):
    """Check ``_compilable_l2_clip_`` against a plain-torch per-tensor L2 clip.

    Args:
        x: list of tensors to clip.
        clip_at: maximum allowed per-tensor L2 norm.

    Returns:
        ``x`` itself, after the heavyball helper has been applied to it.
    """
    _in_assertions(x)
    # Plain L2 norm with no epsilon added; _clip_non_global guards the division.
    norm = [linalg.vector_norm(y) for y in x]
    # The reference must be computed before the helper touches ``x``.
    torch_clipped = _clip_non_global(x, norm, clip_at)
    # The result is read back from ``x``, so this relies on the helper
    # updating ``x`` in place.
    _compilable_l2_clip_(x, clip_at)
    _out_assertions(torch_clipped, x)
    return x


def _test_global_l2norm(x, clip_at):
    """Check ``_compilable_global_l2norm_clip_`` against a plain-torch global L2 clip.

    Args:
        x: list of tensors to clip jointly.
        clip_at: maximum allowed total L2 norm.

    Returns:
        ``x`` itself, after the heavyball helper has been applied to it.
    """
    _in_assertions(x)
    # Total norm over all tensors (get_total_norm's default norm type).
    l2_norm = nn.utils.get_total_norm(x)
    # The reference must be computed before the helper touches ``x``.
    torch_clipped = _clip_global(x, l2_norm, clip_at)
    # The result is read back from ``x``, so this relies on the helper
    # updating ``x`` in place.
    _compilable_global_l2norm_clip_(x, clip_at)
    _out_assertions(torch_clipped, x)
    return x


def _test_global_rmsnorm(x, clip_at):
    """Check ``_compilable_global_rmsnorm_clip_`` against a plain-torch global RMS clip.

    Args:
        x: list of tensors to clip jointly.
        clip_at: maximum allowed global RMS norm.

    Returns:
        ``x`` itself, after the heavyball helper has been applied to it.
    """
    _in_assertions(x)
    l2_norm = nn.utils.get_total_norm(x)
    # Global RMS = total L2 norm / sqrt(total element count).
    rms_norm = l2_norm / math.sqrt(sum(y.numel() for y in x))
    # The reference must be computed before the helper touches ``x``.
    torch_clipped = _clip_global(x, rms_norm, clip_at)
    # The result is read back from ``x``, so this relies on the helper
    # updating ``x`` in place.
    _compilable_global_rmsnorm_clip_(x, clip_at)
    _out_assertions(torch_clipped, x)
    return x


@pytest.mark.parametrize("opt", heavyball.__all__)
Expand Down Expand Up @@ -113,7 +108,7 @@ def test_clip(opt, size, depth: int, iterations: int = 16, outer_iterations: int
loss = model(torch.randn((1024, size), device="cuda")).square().mean()
loss.backward()
o.step()
if utils.compile_mode != None:
if utils.compile_mode is not None:
assert all([not (torch.isnan(y).any() or torch.isinf(y).any()) for y in model.parameters()])
o.zero_grad()

Expand Down