Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,7 @@ dmypy.json

# Pyre type checker
.pyre/

# Editor / local working dirs
.idea/
transcripts/
98 changes: 98 additions & 0 deletions benchmarks/bench_klsoap_pinv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import argparse
import csv
from collections import defaultdict
from itertools import product

import torch

# Sweep axes for the benchmark; main() takes the Cartesian product of these
# (plus the seed range) and calls run_case once per combination.

# (d_a, d_b) shapes of the random matrices drawn in make_L; L itself is d_a x d_a.
SHAPES = [(64, 64), (128, 32), (32, 128), (256, 16), (16, 256)]
# Values of k handed to make_L, which averages k + 1 products G @ G.T.
WARMUP_K = [0, 1, 2, 4, 8, 16]
# Thresholds handed to reciprocal(): an absolute floor for "clamp", a cutoff
# relative to the largest eigenvalue for the other method.
EPS_VALS = [1e-12, 1e-8, 1e-4]
# Working precisions under test, mapped to the label stored in each result row.
DTYPES = {torch.float32: "fp32", torch.float64: "fp64"}
# Eigenvalue-inversion strategies compared in run_case.
METHODS = ["clamp", "pinv"]


def make_L(shape, k, seed, dtype):
    """Build an averaged Gram matrix from ``k + 1`` seeded Gaussian draws.

    Args:
        shape: ``(d_a, d_b)`` size of every random matrix that is drawn.
        k: Number of extra draws; ``k + 1`` matrices are generated in total.
        seed: Seed for the private ``torch.Generator`` used for all draws.
        dtype: dtype of the generated matrices.

    Returns:
        ``(L, G_last)`` where ``L`` is the sum of ``G @ G.T`` over all draws
        divided by ``d_b * (k + 1)`` and ``G_last`` is the final draw.
    """
    n_rows, n_cols = shape
    n_draws = k + 1
    rng = torch.Generator().manual_seed(seed)

    samples = []
    for _ in range(n_draws):
        samples.append(torch.randn(n_rows, n_cols, generator=rng, dtype=dtype))

    # Accumulate left to right starting from 0, the same order the builtin
    # sum() uses, so the floating-point result is unchanged.
    gram = 0
    for sample in samples:
        gram = gram + sample @ sample.T

    return gram / (n_cols * n_draws), samples[-1]


def reciprocal(method, eig, eps):
    """Invert eigenvalues using the requested regularisation strategy.

    Args:
        method: ``"clamp"`` floors the eigenvalues at ``eps`` before inverting.
            Any other value selects the thresholded inverse.
        eig: Tensor of eigenvalues (last dimension indexes the spectrum).
        eps: Absolute floor for ``"clamp"``; otherwise a cutoff relative to
            the largest eigenvalue along the last dimension.

    Returns:
        Tensor shaped like ``eig`` holding the (regularised) reciprocals.
    """
    if method != "clamp":
        # Thresholded inverse: entries not strictly above eps * max become 0.
        cutoff = eps * eig.amax(dim=-1, keepdim=True)
        return torch.where(eig > cutoff, eig.reciprocal(), 0.0)
    floored = eig.clamp_min(eps)
    return floored.reciprocal()


def apply_inv(Q, inv_eig, X):
    """Compute ``Q @ diag(inv_eig) @ Q.T @ X`` without forming the diagonal.

    Args:
        Q: Matrix whose columns form the basis.
        inv_eig: Per-column scale factors applied in that basis.
        X: Right-hand side to transform.
    """
    coords = Q.T @ X
    scaled = inv_eig.unsqueeze(-1) * coords
    return Q @ scaled


def run_case(shape, k, eps, dtype, seed):
    """Run one benchmark configuration and return its result row.

    The matrix from ``make_L`` is eigendecomposed in float64. A reference
    solution is formed there with the thresholded inverse (cutoff
    ``n * float64 eps``) and cast to ``dtype``. Every entry of ``METHODS`` is
    then evaluated on the spectrum cast to ``dtype`` and compared against it.

    Returns:
        Dict with the configuration (``shape``, ``dtype``, ``k``, ``eps``,
        ``seed``), the numerical ``rank`` and dimension ``d``, and for each
        method ``maxinv_<m>`` (largest inverse eigenvalue) and ``err_<m>``
        (error relative to the reference norm).
    """
    matrix, grad = make_L(shape, k, seed, dtype)

    # Reference spectrum: float64, with negative round-off clipped to zero.
    evals64, evecs64 = torch.linalg.eigh(matrix.double())
    evals64 = evals64.clamp_min(0)

    ref_cutoff = evals64.shape[-1] * torch.finfo(evals64.dtype).eps
    ref_inverse = reciprocal("pinv", evals64, ref_cutoff)
    reference = apply_inv(evecs64, ref_inverse, grad.double()).to(dtype)
    reference_norm = reference.norm().item()
    numerical_rank = (evals64 > evals64.max() * 1e-12).sum().item()

    # Spectrum in the working precision actually being benchmarked.
    evals = evals64.to(dtype)
    evecs = evecs64.to(dtype)

    row = dict(
        shape=f"{shape[0]}x{shape[1]}",
        dtype=DTYPES[dtype],
        k=k,
        eps=eps,
        seed=seed,
        rank=numerical_rank,
        d=shape[0],
    )
    for method in METHODS:
        inverse = reciprocal(method, evals, eps)
        solution = apply_inv(evecs, inverse, grad)
        row.update(
            {
                f"maxinv_{method}": inverse.max().item(),
                f"err_{method}": (solution - reference).norm().item() / reference_norm,
            }
        )
    return row


def summarize(rows):
    """Print a worst-case summary table of the benchmark rows.

    Rows are grouped by ``(shape, dtype, k, eps)``. Each statistic column
    shows the maximum of that statistic over the group's rows. ``rank/d`` is
    taken from the first row of the group.

    Args:
        rows: Iterable of result dicts as produced by ``run_case``.
    """
    buckets = defaultdict(list)
    for r in rows:
        buckets[(r["shape"], r["dtype"], r["k"], r["eps"])].append(r)

    cols = [(stat, m) for stat in ("maxinv", "err") for m in METHODS]
    header = f"{'shape':<10} {'dtype':<5} {'k':>3} {'eps':>9} {'rank/d':>8} " + " ".join(
        f"{stat + '_' + m:<13}" for stat, m in cols
    )
    print(f"\n{header}\n{'-' * len(header)}")
    for (shape, dt, k, eps), items in sorted(buckets.items()):
        rank, d = items[0]["rank"], items[0]["d"]
        vals = " ".join(f"{max(r[f'{stat}_{m}'] for r in items):>13.3e}" for stat, m in cols)
        # Format "rank/d" as a single cell so it always occupies the 8
        # right-aligned characters reserved in the header; padding only d
        # made the cell width depend on the number of digits in rank.
        rank_cell = f"{rank}/{d}"
        print(f"{shape:<10} {dt:<5} {k:>3} {eps:>9.0e} {rank_cell:>8} {vals}")


def main():
    """Command-line entry point: run the full sweep, print it, optionally save CSV.

    Options:
        --csv PATH   Write every result row to PATH.
        --seeds N    Number of seeds (0 .. N-1) per configuration; default 3.

    Exits with a usage error when ``--seeds`` is below 1.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--csv", help="Write all rows to CSV file")
    parser.add_argument("--seeds", type=int, default=3)
    args = parser.parse_args()
    # With no seeds the sweep is empty and the CSV branch would crash on
    # rows[0]; reject the value with a clear message instead.
    if args.seeds < 1:
        parser.error("--seeds must be at least 1")

    rows = [
        run_case(shape, k, eps, dtype, seed)
        for shape, k, eps, dtype, seed in product(SHAPES, WARMUP_K, EPS_VALS, DTYPES, range(args.seeds))
    ]
    summarize(rows)

    if args.csv:
        with open(args.csv, "w", newline="") as f:
            # Every row shares the same keys, so the first one defines the columns.
            w = csv.DictWriter(f, fieldnames=list(rows[0]))
            w.writeheader()
            w.writerows(rows)
        print(f"\nWrote {len(rows)} rows to {args.csv}")


# Run the benchmark sweep only when this file is executed as a script.
if __name__ == "__main__":
    main()
20 changes: 1 addition & 19 deletions benchmarks/bench_singular_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch
from torch._dynamo import config as dynamo_config

from heavyball.utils import _max_singular_value_ndim, max_singular_value, min_singular_value
from heavyball.utils import max_singular_value, min_singular_value

dynamo_config.cache_size_limit = 2**20
dynamo_config.accumulated_cache_size_limit = 2**20
Expand Down Expand Up @@ -31,7 +31,6 @@ def make_matrix(shape, cond=10, dtype=torch.float32, symmetric=False, seed=0):

SHAPES_2D = [(4, 4), (32, 32), (128, 128), (10, 5), (5, 10)]
SHAPES_SYM = [(4, 4), (32, 32), (128, 128)]
SHAPES_NDIM = [(3, 4, 5), (16, 32, 64), (16, 16, 512)]
CONDS = [1, 10, 1e4, 1e10, 1e18, 1e30, 1e300]
DTYPES = [torch.bfloat16, torch.float32, torch.float64]
POWER_ITERS = [0, 5, 20]
Expand Down Expand Up @@ -78,22 +77,6 @@ def bench_min_sv(rows):
rows.append(("min_sv", _dtype_name(dtype), pi, shape, cond, rerr, status))


def bench_ndim(rows):
    """Benchmark ``_max_singular_value_ndim`` on every shape in ``SHAPES_NDIM``.

    For each shape a seeded random tensor is moved to the CUDA device (a GPU
    is therefore required). The estimate from two power iterations is
    compared with the largest value returned by ``torch.linalg.svdvals`` in
    float64. One tuple ``("ndim", "fp32", 2, shape, 0, rel_err, status)`` is
    appended to ``rows`` per shape. ``status`` is ``"ok"`` when the estimate
    is at least the reference, ``"not_upper_bound"`` otherwise, or the
    exception class name if evaluation failed (``rel_err`` is then NaN).
    """
    for shape in SHAPES_NDIM:
        torch.manual_seed(0x172893)
        tensor = torch.randn(shape).cuda()
        reference = torch.linalg.svdvals(tensor.double()).max()
        try:
            estimate = _max_singular_value_ndim(tensor, power_iter=2)
            rerr = abs((estimate.double() - reference) / reference).item()
            bounded = (estimate.double() >= reference.double()).item()
            status = "ok" if bounded else "not_upper_bound"
        except Exception as exc:
            # Deliberately broad: a failure is recorded as a benchmark
            # outcome rather than aborting the whole run.
            rerr, status = float("nan"), type(exc).__name__
        rows.append(("ndim", "fp32", 2, shape, 0, rerr, status))


def print_pareto(rows):
from itertools import groupby

Expand Down Expand Up @@ -124,7 +107,6 @@ def main():
rows = []
bench_max_sv(rows)
bench_min_sv(rows)
bench_ndim(rows)

print_pareto(rows)

Expand Down
130 changes: 130 additions & 0 deletions benchmarks/bench_soap_variance_rotation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
"""Numerics of SOAP/KLSOAP second-moment transport under a Q rotation.

Pre-fix (none): _apply_soap_preconditioner did not pass exp_avg_sq into
get_orthogonal_matrix_QR. v stayed in the OLD eigenframe
while Q (and m) moved -- drifts with each rotation.
Strawman (linear): apply m's einsum to v. Not what was shipped; included to
show why a naive rotation does not work.
Post-fix (hadamard): v <- (R*R)^T v per side, R = Q_old^T Q_new. Equals
diag(R^T diag(v) R) -- diagonal of the rotated covariance.

Reports per method vs analytical truth (= hadamard by construction):
err ||method - truth||_inf (hadamard: 0; none/linear grow with theta)
min min(method) (linear can go negative)
dvar (sum(method) - sum(v))/sum(v) (hadamard preserves; linear does not)
"""

import argparse
import csv
import math
from collections import defaultdict
from itertools import product

import torch

# Sweep axes for the benchmark; main() takes the Cartesian product of these
# (plus the seed range) and calls run_case once per combination.

# Values of theta passed to rotate() when deriving Q_new from Q_old.
ANGLES = [0.0, 1e-3, 1e-2, 0.1, 0.5, 1.0, math.pi / 2]
# Shapes of the tensor v; run_case builds one square Q factor per dimension.
SHAPES = [(16,), (256,), (16, 16), (32, 8), (8, 32), (64, 64)]
# Initialisations understood by make_v.
V_KINDS = ["uniform", "exponential", "spike"]
# Working precisions under test, mapped to the label stored in each result row.
DTYPES = {torch.float32: "fp32", torch.float64: "fp64"}
# Transport rules compared by transport(); run_case uses the "hadamard"
# result as the reference the others are measured against.
METHODS = ["none", "linear", "hadamard"]


def haar(d, seed, dtype):
    """Return the Q factor from the QR decomposition of a seeded Gaussian matrix.

    Args:
        d: Side length of the square matrix.
        seed: Seed for the private ``torch.Generator``.
        dtype: dtype of the generated matrix.
    """
    rng = torch.Generator().manual_seed(seed)
    gaussian = torch.randn(d, d, generator=rng, dtype=dtype)
    q_factor, _ = torch.linalg.qr(gaussian)
    return q_factor


def rotate(Q, theta, seed):
    """Right-multiply ``Q`` by ``exp(theta * S)`` for a seeded skew-symmetric ``S``.

    ``S`` is the skew-symmetric part of a seeded Gaussian matrix, rescaled so
    that its norm equals ``sqrt(d)``.

    Args:
        Q: Square matrix; its last dimension ``d`` sets the size of ``S``.
        theta: Scale applied to ``S`` before exponentiation; 0 leaves ``Q``
            unchanged up to round-off.
        seed: Seed for the private ``torch.Generator`` used to draw ``S``.

    Returns:
        ``Q @ matrix_exp(theta * S)`` as a new tensor.
    """
    d = Q.shape[-1]
    g = torch.Generator().manual_seed(seed)
    A = torch.randn(d, d, generator=g, dtype=Q.dtype)
    S = (A - A.T) / 2
    scale = S.norm()
    # For d == 1 the skew-symmetric part is exactly zero; normalising would
    # divide 0 by 0 and poison the result with NaN. Leaving S at zero makes
    # the factor below the identity, which is the only rotation in 1-D.
    if scale > 0:
        S = S / scale * math.sqrt(d)
    return Q @ torch.linalg.matrix_exp(theta * S)


def make_v(shape, kind, seed, dtype):
    """Create the test tensor ``v`` for one of the supported distributions.

    Args:
        shape: Shape of the returned tensor.
        kind: ``"uniform"`` gives ``rand + 0.1``; ``"exponential"`` gives
            ``-log(rand + 1e-6)``; any other value gives a tensor filled with
            ``1e-3`` whose first element (in flattened order) is ``1.0``.
        seed: Seed for the private ``torch.Generator`` of the random kinds.
        dtype: dtype of the returned tensor.
    """
    rng = torch.Generator().manual_seed(seed)
    if kind in ("uniform", "exponential"):
        draws = torch.rand(shape, generator=rng, dtype=dtype)
        if kind == "uniform":
            return draws + 0.1
        return -torch.log(draws + 1e-6)

    # Fallback ("spike"): a small constant background with one dominant entry.
    spike = torch.full(shape, 1e-3, dtype=dtype)
    spike.view(-1)[0] = 1.0
    return spike


def transport(method, v, Q_old, Q_new):
    """Carry ``v`` from the ``Q_old`` frame to the ``Q_new`` frame.

    Args:
        method: ``"none"`` returns ``v`` itself. ``"linear"`` contracts ``v``
            with every ``Q_old`` factor and the matching ``Q_new`` factor.
            Any other value contracts ``v`` with the element-wise square of
            ``Q_old.T @ Q_new`` along each dimension.
        v: Tensor with one dimension per factor (at most four).
        Q_old: Sequence of square matrices, one per dimension of ``v``.
        Q_new: Sequence of square matrices paired with ``Q_old``.

    Returns:
        The transported tensor.
    """
    if method == "none":
        return v

    rank = len(Q_old)
    src, dst, shared = "abcd"[:rank], "efgh"[:rank], "ABCD"[:rank]

    if method != "linear":
        # Element-wise squared change-of-basis matrix for each dimension.
        squared = []
        for old, new in zip(Q_old, Q_new):
            squared.append((old.T @ new).pow(2))
        pairs = ",".join(a + b for a, b in zip(src, dst))
        return torch.einsum(f"{src},{pairs}->{dst}", v, *squared)

    old_terms = ",".join(s + a for s, a in zip(shared, src))
    new_terms = ",".join(s + b for s, b in zip(shared, dst))
    return torch.einsum(f"{src},{old_terms},{new_terms}->{dst}", v, *Q_old, *Q_new)


def measure(out, truth, total):
    """Summarise how ``out`` deviates from ``truth``.

    Args:
        out: Tensor produced by the method under test.
        truth: Reference tensor of the same shape.
        total: Reference sum used to normalise the change in ``out.sum()``.

    Returns:
        Dict with ``err`` (largest absolute difference), ``min`` (smallest
        entry of ``out``) and ``dvar`` (relative change of the sum).
    """
    worst_gap = torch.max(torch.abs(out - truth)).item()
    smallest = out.min().item()
    sum_shift = out.sum().item() - total
    return {"err": worst_gap, "min": smallest, "dvar": sum_shift / total}


def run_case(shape, theta, kind, dtype, seed):
    """Run one benchmark configuration and return its result row.

    One Q factor per dimension of ``shape`` is drawn with ``haar`` and
    perturbed with ``rotate``. ``v`` comes from ``make_v``. Every entry of
    ``METHODS`` is applied and measured against the ``"hadamard"`` result.

    Returns:
        Dict with the configuration (``shape``, ``dtype``, ``theta``,
        ``kind``, ``seed``) followed by ``err_<m>``, ``min_<m>`` and
        ``dvar_<m>`` for each method ``m``.
    """
    old_frames, new_frames = [], []
    for axis, size in enumerate(shape):
        offset = 100 * axis  # distinct seed per dimension
        frame = haar(size, seed + offset, dtype)
        old_frames.append(frame)
        new_frames.append(rotate(frame, theta, seed + 1000 + offset))

    v = make_v(shape, kind, seed + 2000, dtype)
    transported = {}
    for method in METHODS:
        transported[method] = transport(method, v, old_frames, new_frames)
    reference = transported["hadamard"]
    v_sum = v.sum().item()

    row = {
        "shape": "x".join(str(size) for size in shape),
        "dtype": DTYPES[dtype],
        "theta": theta,
        "kind": kind,
        "seed": seed,
    }
    for method in METHODS:
        for stat, value in measure(transported[method], reference, v_sum).items():
            row[f"{stat}_{method}"] = value
    return row


def summarize(rows):
    """Print a worst-case summary table of the benchmark rows.

    Rows are grouped by ``(number of dimensions, dtype, theta)``. Within a
    group ``err`` is reduced with ``max``, ``min`` with ``min`` and ``dvar``
    with the largest absolute value.

    Args:
        rows: Iterable of result dicts as produced by ``run_case``.
    """
    buckets = defaultdict(list)
    for r in rows:
        buckets[(len(r["shape"].split("x")), r["dtype"], r["theta"])].append(r)

    cols = [(stat, m) for stat in ("err", "min", "dvar") for m in METHODS]
    header = f"{'mode':<4} {'dtype':<5} {'theta':>8} " + " ".join(f"{stat + '_' + m:<13}" for stat, m in cols)
    print(f"\n{header}\n{'-' * len(header)}")
    agg = {"err": max, "min": min, "dvar": lambda xs: max(map(abs, xs))}
    for (n, dt, theta), items in sorted(buckets.items()):
        vals = " ".join(f"{agg[stat]([r[f'{stat}_{m}'] for r in items]):>13.3e}" for stat, m in cols)
        # Pad the "<n>d" label to the 4 characters the header reserves for
        # "mode"; unpadded, every later cell sat two columns left of its heading.
        mode = f"{n}d"
        print(f"{mode:<4} {dt:<5} {theta:>8.4f} {vals}")


def main():
    """Command-line entry point: run the full sweep, print it, optionally save CSV.

    Options:
        --csv PATH   Write every result row to PATH.
        --seeds N    Number of seeds (0 .. N-1) per configuration; default 3.

    Exits with a usage error when ``--seeds`` is below 1.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--csv", help="Write all rows to CSV file")
    parser.add_argument("--seeds", type=int, default=3)
    args = parser.parse_args()
    # With no seeds the sweep is empty and the CSV branch would crash on
    # rows[0]; reject the value with a clear message instead.
    if args.seeds < 1:
        parser.error("--seeds must be at least 1")

    rows = [
        run_case(shape, theta, kind, dtype, seed)
        for shape, theta, kind, dtype, seed in product(SHAPES, ANGLES, V_KINDS, DTYPES, range(args.seeds))
    ]
    summarize(rows)

    if args.csv:
        with open(args.csv, "w", newline="") as f:
            # Every row shares the same keys, so the first one defines the columns.
            w = csv.DictWriter(f, fieldnames=list(rows[0]))
            w.writeheader()
            w.writerows(rows)
        print(f"\nWrote {len(rows)} rows to {args.csv}")


# Run the benchmark sweep only when this file is executed as a script.
if __name__ == "__main__":
    main()
27 changes: 14 additions & 13 deletions ci/gpu_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ def _detect_repo_and_branch():


def api(method, path, **kwargs):
kwargs.setdefault("params", {})
kwargs["params"]["api_key"] = API_KEY
headers = kwargs.setdefault("headers", {})
headers["Authorization"] = f"Bearer {API_KEY}"
for attempt in range(3):
r = requests.request(method, f"{API_BASE}{path}", **kwargs)
if r.status_code != 429:
Expand Down Expand Up @@ -80,17 +80,18 @@ def find_offers(n):

SELF_DESTRUCT_TIMEOUT = 1800

ONSTART_TEMPLATE = """#!/bin/bash
timeout {timeout} bash -c '
ONSTART_SCRIPT = f"""#!/bin/bash
timeout {SELF_DESTRUCT_TIMEOUT} bash -c '
export PIP_BREAK_SYSTEM_PACKAGES=1 &&
if ! command -v g++ &>/dev/null; then apt-get update -qq && apt-get install -y -qq --no-install-recommends g++; fi &&
cd / && git clone --depth 1 -b {branch} {repo} /w &&
cd / && git clone --depth 1 -b "$HB_BRANCH" "$HB_REPO" /w &&
cd /w && pip install -e LightBench -q --break-system-packages 2>&1 &&
pip install -e ".[dev]" -q --break-system-packages 2>&1 &&
python -m pytest {test} --tb=short -q 2>&1; echo HEAVYBALL_EXIT=$?
python -m pytest "$HB_TEST" --tb=short -q 2>&1; echo HEAVYBALL_EXIT=$?
'
sleep 3
curl -s -X PUT "https://console.vast.ai/api/v0/instances/${{CONTAINER_ID}}/?api_key=${{CONTAINER_API_KEY}}" \
curl -s -X PUT "https://console.vast.ai/api/v0/instances/$CONTAINER_ID/" \\
-H "Authorization: Bearer $CONTAINER_API_KEY" \\
-H "Content-Type: application/json" -d '{{"state": "stopped"}}' || true
sleep 2
kill 1 2>/dev/null || true
Expand All @@ -102,12 +103,12 @@ def create_instance(offer_id, test_file):
"client_id": "me",
"image": IMAGE,
"disk": 16,
"onstart": ONSTART_TEMPLATE.format(
timeout=SELF_DESTRUCT_TIMEOUT,
branch=BRANCH,
repo=REPO_URL,
test=test_file,
),
"onstart": ONSTART_SCRIPT,
"env": {
"HB_BRANCH": BRANCH,
"HB_REPO": REPO_URL,
"HB_TEST": test_file,
},
"runtype": "ssh_direc ssh_proxy",
}
r = api("PUT", f"/asks/{offer_id}/", json=payload)
Expand Down
Loading
Loading