Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
4d5a8f7
init benchmark v2
ClashLuke Jun 24, 2025
5496e1a
init benchmark v2
ClashLuke Jun 24, 2025
cd6e227
update readme
ClashLuke Jun 24, 2025
fdc4aac
remove unused arguments
ClashLuke Jun 24, 2025
3d5d082
add plasticity/generalization benchmarks
ClashLuke Jun 25, 2025
8d5fe61
update configs
ClashLuke Jun 25, 2025
ac0469d
more tasks
ClashLuke Jun 25, 2025
2cee0c0
even more tasks
ClashLuke Jun 25, 2025
c8b5310
even more tasks
ClashLuke Jun 25, 2025
6cfdbf7
fix xor
ClashLuke Jun 25, 2025
0f811a7
fix xor
ClashLuke Jun 25, 2025
054e377
fix xor
ClashLuke Jun 25, 2025
704fb56
fix mask
ClashLuke Jun 26, 2025
11e22e4
explicitly compute loss
ClashLuke Jun 26, 2025
79a4734
simplify
ClashLuke Jun 26, 2025
cf08804
more fixes
ClashLuke Jun 26, 2025
ed888f5
disable early stopping
ClashLuke Jun 26, 2025
277afa6
new tensor key
ClashLuke Jun 27, 2025
90ce92b
don't recompile on weight decay
ClashLuke Jul 1, 2025
93c5b49
correct heatmap print
ClashLuke Jul 1, 2025
99bdeb4
potential fixes
ClashLuke Jul 1, 2025
9947f63
rework tests
ClashLuke Jul 1, 2025
1aa31d4
use old defaults
ClashLuke Jul 1, 2025
e6f208b
init visualization
ClashLuke Jul 5, 2025
6824910
faster + correct some things
ClashLuke Jul 5, 2025
540b163
rename to playground
ClashLuke Jul 5, 2025
070525f
add missing fns
ClashLuke Jul 5, 2025
993d8e3
log lr
ClashLuke Jul 5, 2025
ac37120
basic mlp
ClashLuke Jul 5, 2025
f4b3349
better MLP
ClashLuke Jul 5, 2025
24ab0e0
automatically add more hyperparams
ClashLuke Jul 5, 2025
65b2d96
Merge remote-tracking branch 'origin/benchmark2'
ClashLuke Jul 17, 2025
e9e55ad
fix clipping logic, add test for clipping functions
alexjwilliams Jul 18, 2025
a0f39c2
unify clipping logic, higher numerical stability
ClashLuke Jul 21, 2025
f7c4781
revise changes made to clipping functions in a0f39c23ed81
alexjwilliams Jul 26, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 69 additions & 35 deletions benchmark/README.md

Large diffs are not rendered by default.

64 changes: 64 additions & 0 deletions benchmark/absolute_varying_scale.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from typing import List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import typer

from benchmark.utils import param_norm_win_condition, trial
from heavyball.utils import set_torch

# Typer CLI application; Typer's rich "pretty" exception rendering is switched off.
app = typer.Typer(pretty_exceptions_enable=False)
# Runs HeavyBall's torch setup helper at import time (its effects are defined in heavyball.utils).
set_torch()

# Difficulty presets, keyed by name. Each value holds the keyword arguments
# passed to ``Model``; the parameter count grows with the difficulty level.
configs = {
    level: {"size": size}
    for level, size in (
        ("trivial", 4),
        ("easy", 16),
        ("medium", 512),
        ("hard", 8192),
        ("extreme", 2**15),
        ("nightmare", 2**17),
    )
}


class Model(nn.Module):
    """Weighted absolute-value objective over a single parameter vector.

    The loss is ``sum_i scale[i] * |param[i]|`` where ``scale`` is the ramp
    ``1, 2, ..., size`` normalised to unit L1 norm, so every coordinate
    contributes with a different magnitude.
    """

    def __init__(self, size):
        """Create ``size`` normally-initialised parameters and the fixed per-coordinate scale."""
        super().__init__()
        self.param = nn.Parameter(torch.randn(size))
        ramp = torch.arange(1, size + 1).to(torch.float32)
        # Registered as a buffer, not a parameter, so it is not part of ``parameters()``.
        self.register_buffer("scale", F.normalize(ramp, p=1, dim=0))

    def forward(self):
        """Return the scalar loss: the dot product of ``|param|`` with ``scale``."""
        magnitudes = self.param.abs()
        return torch.matmul(magnitudes, self.scale)


@app.command()
def main(
    method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"),
    dtype: List[str] = typer.Option(["float32"], help="Data type to use"),
    size: int = 1024,
    batch: int = 256,
    steps: int = 100,
    weight_decay: float = 0,
    opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"),
    trials: int = 10,
    win_condition_multiplier: float = 1.0,
    config: Optional[str] = None,
):
    # CLI entry point: build the model for the chosen difficulty preset and
    # hand it to ``trial``.
    #
    # ``method``, ``dtype``, ``size`` and ``batch`` are accepted on the command
    # line but are not read by this function; the problem size comes solely
    # from the preset. Only the first entry of ``opt`` is used.
    #
    # (Kept as comments instead of a docstring so the Typer help text is unchanged.)

    # No preset given -> fall back to "trivial". An unknown name raises KeyError.
    preset = configs[config if config else "trivial"]
    model = Model(**preset).cuda()

    # The model carries its own objective, so no data or loss function is passed.
    win_condition = param_norm_win_condition(win_condition_multiplier * 1e-7, 0)
    optimizer_name = opt[0]

    trial(
        model,
        None,
        None,
        win_condition,
        steps,
        optimizer_name,
        weight_decay=weight_decay,
        failure_threshold=2,
        trials=trials,
    )


# Run the Typer CLI only when this file is executed as a script.
if __name__ == "__main__":
    app()
16 changes: 2 additions & 14 deletions benchmark/adversarial_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,29 +49,17 @@ def main(
config: Optional[str] = None,
):
frequency = configs.get(config, {}).get("frequency", 10)
dtype = [getattr(torch, d) for d in dtype]
model = Model(frequency).cuda().double()
model = Model(frequency).cuda()

def data():
return None, None

# More lenient condition due to adversarial component
trial(
model,
data,
None,
None,
param_norm_win_condition(win_condition_multiplier * 1e-3, 0),
steps,
opt[0],
dtype[0],
1,
1,
weight_decay,
method[0],
1,
1,
failure_threshold=7,
base_lr=1e-3,
trials=trials,
) # More attempts for adversarial case

Expand Down
16 changes: 2 additions & 14 deletions benchmark/batch_size_scaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,29 +52,17 @@ def main(
config: Optional[str] = None,
):
max_batch = configs.get(config, {}).get("max_batch", 256)
dtype = [getattr(torch, d) for d in dtype]
model = Model(max_batch).cuda().double()
model = Model(max_batch).cuda()

def data():
return None, None

# Use a more lenient win condition since we have inherent noise
trial(
model,
data,
None,
None,
param_norm_win_condition(win_condition_multiplier * 1e-8, 0),
steps,
opt[0],
dtype[0],
1,
1,
weight_decay,
method[0],
1,
1,
failure_threshold=5,
base_lr=1e-3,
trials=trials,
)

Expand Down
12 changes: 1 addition & 11 deletions benchmark/beale.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,24 +61,14 @@ def main(
model = Model(coords)
model.double()

def data():
return None, None

model = trial(
model,
data,
None,
None,
loss_win_condition(win_condition_multiplier * 1e-8 * (not show_image)),
steps,
opt[0],
dtype[0],
1,
1,
weight_decay,
method[0],
1,
1,
base_lr=1e-4,
trials=trials,
return_best=show_image,
)
Expand Down
93 changes: 0 additions & 93 deletions benchmark/char_rnn.py

This file was deleted.

146 changes: 146 additions & 0 deletions benchmark/class_imbalance_rare.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
"""
Class Imbalance Rare Event Detection Benchmark

Tests an optimizer's ability to learn from severely imbalanced datasets where
rare positive events are critical to detect. This benchmark simulates real-world
scenarios like fraud detection, medical diagnosis, or anomaly detection where
the minority class is both rare and important.

The task uses a synthetic classification problem with configurable class
imbalance ratios. Success is measured by the optimizer's ability to achieve
good performance on the minority class despite the severe imbalance.
"""

from typing import List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import typer

from benchmark.utils import trial
from heavyball.utils import set_torch

# Typer CLI application; Typer's rich "pretty" exception rendering is switched off.
app = typer.Typer(pretty_exceptions_enable=False)
# Runs HeavyBall's torch setup helper at import time (its effects are defined in heavyball.utils).
set_torch()

# Difficulty presets. Every row lists, in order, the values for the fields
# named in ``_CONFIG_FIELDS``; harder presets have more samples, wider inputs,
# a smaller positive fraction and more noise.
_CONFIG_FIELDS = ("n_samples", "input_dim", "hidden_dim", "imbalance_ratio", "noise_level")
configs = {
    level: dict(zip(_CONFIG_FIELDS, values))
    for level, values in (
        ("trivial", (1000, 20, 16, 0.2, 0.1)),
        ("easy", (2000, 32, 24, 0.1, 0.2)),
        ("medium", (5000, 48, 32, 0.05, 0.3)),
        ("hard", (10000, 64, 48, 0.02, 0.4)),
        ("extreme", (20000, 96, 64, 0.01, 0.5)),
        ("nightmare", (50000, 128, 96, 0.005, 0.6)),
    )
}


class ImbalancedClassifier(nn.Module):
    """Small MLP bundled with a synthetic, imbalanced binary dataset.

    The dataset is stored in buffers on the module, so ``forward()`` takes no
    inputs and returns the focal loss of the classifier over the full dataset.
    """

    def __init__(self, n_samples, input_dim, hidden_dim, imbalance_ratio, noise_level):
        """Build the classifier and generate the dataset.

        Args:
            n_samples: Total number of examples.
            input_dim: Number of input features.
            hidden_dim: Width of the first hidden layer; the second has
                ``hidden_dim // 2`` units.
            imbalance_ratio: Fraction of positive examples; the positive count
                is ``int(n_samples * imbalance_ratio)``.
            noise_level: Scale of the Gaussian noise added to the class centres.

        Raises:
            ValueError: If the settings produce no positive example, which
                would leave the positive/negative ratio undefined.
        """
        super().__init__()
        n_positive = int(n_samples * imbalance_ratio)
        n_negative = n_samples - n_positive
        # Fail early with a clear message instead of a bare ZeroDivisionError
        # further down. The check consumes no random numbers, so seeded runs
        # with valid settings are unaffected.
        if n_positive < 1:
            raise ValueError(
                f"n_samples={n_samples} with imbalance_ratio={imbalance_ratio} yields "
                f"{n_positive} positive samples; at least one is required"
            )

        self.classifier = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 2),
        )

        # Positives are centred on +1, negatives on -1, in every feature.
        positive_X = torch.ones(n_positive, input_dim) + noise_level * torch.randn(n_positive, input_dim)
        positive_y = torch.ones(n_positive, dtype=torch.long)
        negative_X = -torch.ones(n_negative, input_dim) + noise_level * torch.randn(n_negative, input_dim)
        negative_y = torch.zeros(n_negative, dtype=torch.long)
        X = torch.cat([positive_X, negative_X], dim=0)
        y = torch.cat([positive_y, negative_y], dim=0)

        # Shuffle so the classes are interleaved.
        perm = torch.randperm(n_samples)
        self.register_buffer("X", X[perm])
        self.register_buffer("y", y[perm])
        # Negative-to-positive ratio; stored on the module but not read by forward().
        self.register_buffer("pos_weight", torch.tensor(n_negative / n_positive))

    def forward(self):
        """Return the mean focal loss (alpha=0.25, gamma=2) over the stored dataset."""
        alpha = 0.25
        gamma = 2.0
        # One log-softmax pass provides both the cross-entropy term and the
        # true-class probability, and is more stable than softmax + gather.
        log_probs = F.log_softmax(self.classifier(self.X), dim=1)
        log_p_t = log_probs.gather(1, self.y.unsqueeze(1)).squeeze(1)
        ce_loss = -log_p_t
        # alpha is applied to every sample alike, so it only rescales the loss.
        focal_weight = alpha * (1 - log_p_t.exp()) ** gamma
        return (focal_weight * ce_loss).mean()


def imbalanced_win_condition(f1_threshold, recall_threshold):
    """Build a win-condition callback based on positive-class F1 and recall.

    The returned ``win(model, loss)`` evaluates ``model.classifier`` on
    ``model.X`` against ``model.y`` (class 1 is the positive class) and returns
    ``(success, metrics)``. ``success`` is true when F1 and recall both reach
    their thresholds. The ``loss`` argument is not used.
    """

    def _count(mask):
        # Number of true entries, as a float tensor.
        return mask.sum().float()

    def win(model, loss):
        eps = 1e-7  # keeps the ratios finite when a denominator would be zero
        with torch.no_grad():
            predicted = model.classifier(model.X).argmax(dim=1)
            labels = model.y

            tp = _count((predicted == 1) & (labels == 1))
            fp = _count((predicted == 1) & (labels == 0))
            fn = _count((predicted == 0) & (labels == 1))

            precision = tp / (tp + fp + eps)
            recall = tp / (tp + fn + eps)
            f1 = 2 * precision * recall / (precision + recall + eps)

            passed = bool(f1 >= f1_threshold) and bool(recall >= recall_threshold)
            metrics = {
                "f1_score": f1.item(),
                "precision": precision.item(),
                "recall": recall.item(),
                "true_positives": tp.item(),
                "false_positives": fp.item(),
                "false_negatives": fn.item(),
            }
        return passed, metrics

    return win


@app.command()
def main(
    dtype: List[str] = typer.Option(["float32"], help="Data type to use"),
    n_samples: int = 5000,
    input_dim: int = 48,
    hidden_dim: int = 32,
    imbalance_ratio: float = 0.05,
    noise_level: float = 0.3,
    steps: int = 2000,
    weight_decay: float = 0.01,
    opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"),
    trials: int = 25,
    win_condition_multiplier: float = 1.0,
    config: Optional[str] = None,
):
    """
    Class imbalance rare event detection benchmark.

    Tests optimizer's ability to learn from severely imbalanced datasets
    where detecting rare positive events is critical.
    """
    # A named preset overrides the dataset/model options given on the command
    # line. With no preset, or a name missing from ``configs``, the lookup is
    # empty and every option keeps its command-line value.
    # ``dtype`` is accepted but not read here; only the first ``opt`` entry is used.
    preset = configs.get(config, {}) if config else {}
    n_samples = preset.get("n_samples", n_samples)
    input_dim = preset.get("input_dim", input_dim)
    hidden_dim = preset.get("hidden_dim", hidden_dim)
    imbalance_ratio = preset.get("imbalance_ratio", imbalance_ratio)
    noise_level = preset.get("noise_level", noise_level)

    model = ImbalancedClassifier(
        n_samples,
        input_dim,
        hidden_dim,
        imbalance_ratio,
        noise_level,
    ).cuda()

    # Targets shrink with the positive fraction (base F1 0.7, base recall 0.8)
    # and are capped at 1.0.
    f1_threshold = min(1.0, win_condition_multiplier * 0.7 * imbalance_ratio * 10)
    recall_threshold = min(1.0, win_condition_multiplier * 0.8 * imbalance_ratio * 8)
    win_condition = imbalanced_win_condition(f1_threshold, recall_threshold)

    # The model holds its own data and loss, so no data or loss function is passed.
    trial(
        model,
        None,
        None,
        win_condition,
        steps,
        opt[0],
        weight_decay,
        trials=trials,
        failure_threshold=4,
    )


# Run the Typer CLI only when this file is executed as a script.
if __name__ == "__main__":
    app()
Loading