Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 124 additions & 0 deletions benchmark/MNIST.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
from pathlib import Path
from typing import List

import torch
import torch.backends.opt_einsum
import torch.nn as nn
import typer
from torch.nn import functional as F
from torchvision import datasets, transforms

from benchmark.utils import loss_win_condition, trial
from heavyball.utils import set_torch

app = typer.Typer(pretty_exceptions_enable=False)
set_torch()

app = typer.Typer()


class Model(nn.Module):
def __init__(self, hidden_size: int = 128):
super().__init__()
self.flatten = nn.Flatten()
self.fc1 = nn.Linear(28 * 28, hidden_size)
# self.dropout1 = nn.Dropout(0.25)
self.fc2 = nn.Linear(hidden_size, hidden_size)
# self.dropout2 = nn.Dropout(0.5)
self.fc3 = nn.Linear(hidden_size, 10)

def forward(self, x):
x = self.flatten(x)
x = F.relu(self.fc1(x))
# x = self.dropout1(x)
x = F.relu(self.fc2(x))
# x = self.dropout2(x)
x = self.fc3(x)
return F.log_softmax(x, dim=1)


def set_deterministic_weights(model, seed=42):
"""Initialize model with deterministic weights using a fixed seed"""
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# Re-initialize all parameters
for module in model.modules():
if isinstance(module, nn.Linear):
# Use Xavier/Glorot uniform initialization with fixed seed
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)

return model


@app.command()
def main(
method: List[str] = typer.Option(["qr"], help="Eigenvector method to use (for SOAP)"),
dtype: List[str] = typer.Option(["float32"], help="Data type to use"),
hidden_size: int = 128,
batch: int = 128,
steps: int = 0,
weight_decay: float = 0,
opt: List[str] = typer.Option(["ForeachSOAP"], help="Optimizers to use"),
win_condition_multiplier: float = 1.0,
trials: int = 10,
test_loader: bool = None,
):
dtype = [getattr(torch, d) for d in dtype]

# Usage in your script:
model = Model(hidden_size).cuda()
# Load MNIST data
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

# Download data to a data directory relative to the script
data_dir = Path(__file__).parent / "data"
data_dir.mkdir(exist_ok=True)

train_dataset = datasets.MNIST(data_dir, train=True, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=batch, shuffle=False, num_workers=0, pin_memory=True
)

test_dataset = datasets.MNIST(data_dir, train=False, download=True, transform=transform)

test_loader = torch.utils.data.DataLoader(
test_dataset, batch_size=batch, shuffle=False, num_workers=0, pin_memory=True
)

data_iter = iter(train_loader)

def data():
nonlocal data_iter
try:
batch_data, batch_targets = next(data_iter)
except StopIteration:
# Reset iterator when exhausted
data_iter = iter(train_loader)
batch_data, batch_targets = next(data_iter)

return batch_data.cuda(), batch_targets.cuda()

# Custom loss function that matches the expected signature
def loss_fn(output, target):
return F.nll_loss(output, target)

trial(
model,
data,
loss_fn,
loss_win_condition(win_condition_multiplier * 0),
steps,
opt[0],
weight_decay,
failure_threshold=10,
trials=trials,
test_loader=test_loader,
)


if __name__ == "__main__":
app()
75 changes: 64 additions & 11 deletions benchmark/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import sys
import time
import warnings

from typing import Callable, Optional
import numpy as np
import optuna
import torch
Expand Down Expand Up @@ -259,6 +259,7 @@ def __init__(
win_condition,
weight_decay,
warmup_trials,
eval_callback,
ema_index: int = 0,
**kwargs,
):
Expand All @@ -277,6 +278,7 @@ def __init__(
self.warmup_trials = int(warmup_trials)
self.kwargs = kwargs
self.ema_index = ema_index
self.eval_callback = eval_callback

# up to 32768 consecutive times can the new loss be (1 - 1e-7)x larger than the preceding loss
self.validator = Validator(
Expand Down Expand Up @@ -333,6 +335,13 @@ def _inner(self, params):
if hasattr(o, "train"):
o.train()

if not hasattr(self, "test_accuracies"):
self.callback_results = []

if self.eval_callback is not None:
test_accuracy = self.eval_callback(self.m)
self.callback_results.append(test_accuracy)

for j in range(self.group):
inp, tgt = self.data()

Expand Down Expand Up @@ -362,15 +371,16 @@ def _closure():
return validator.ema_states.min().item(), self.m, loss_cpu
if validator(loss).item():
return validator.ema_states.min().item(), self.m, loss_cpu
return validator.ema_states.min().item(), self.m, loss.item()
return validator.ema_states.min().item(), self.m, loss.item(), self.callback_results

def objective(self, params):
self.attempt += 1
target, _, loss = self._inner(params)
target, _, loss, test_accuracies = self._inner(params)
if self.best_loss is None or loss < self.best_loss or not np.isfinite(self.best_loss):
self.best_loss = loss
self.best_at = self.attempt
self.avg = np.log(np.array(params))
self.callback_results = test_accuracies.copy()
if self.best_at * 8 < self.attempt and self.attempt - self.best_at > self.warmup_trials: # no improvements
raise Stop
if time.time() > self.end_time: # timeout
Expand Down Expand Up @@ -445,6 +455,7 @@ def trial(
return_best: bool = False,
warmup_trial_pct: float = 1,
random_trials: int = 10,
eval_callback: Optional[Callable] = None, # evaluate_test_accuracy(dataloader)
):
if data is None:
data = _none_data
Expand Down Expand Up @@ -474,9 +485,6 @@ def trial(
if opt.startswith("ortho-"):
opt = opt[len("ortho-") :]
kwargs["ortho_method"] = "newtonschulz-graft"
if opt == "adam":
opt = torch.optim.Adam
else:
opt = getattr(heavyball, opt)

heavyball.utils._ignore_warning("logei_candidates_func is experimental")
Expand Down Expand Up @@ -509,6 +517,7 @@ def _win_condition(*args):
_win_condition,
weight_decay,
max(trials * warmup_trial_pct, 1 + random_trials),
eval_callback,
**kwargs,
)

Expand Down Expand Up @@ -561,10 +570,54 @@ def _optuna_objective(trial):
print("Successfully found the minimum.")
else:
winning_params = {"lr": 1, "1mbeta1": 0.9, "1mbeta2": 0.999, "1mshampoo_beta": 0.999}
print(
f"Took: {end_time - start_time} | Attempt: {obj.attempt} | " #
f"{opt.__name__}(lr={winning_params['lr']:.5f}, betas=({1 - winning_params['1mbeta1']:.3f}, {1 - winning_params['1mbeta2']:.4f}), " #
f"shampoo_beta={1 - winning_params['1mshampoo_beta']:.3f}) | Best Loss: {obj.best_loss}"
)

if obj.callback_results == []:
print(
f"Took: {end_time - start_time} | Attempt: {obj.attempt} | " #
f"{opt.__name__}(lr={winning_params['lr']:.5f}, betas=({1 - winning_params['1mbeta1']:.3f}, {1 - winning_params['1mbeta2']:.4f}), " #
f"shampoo_beta={1 - winning_params['1mshampoo_beta']:.3f}) | Best Loss: {obj.best_loss}"
)
else:
print(
f"Took: {end_time - start_time} | Attempt: {obj.attempt} | " #
f"{opt.__name__}(lr={winning_params['lr']:.5f}, betas=({1 - winning_params['1mbeta1']:.3f}, {1 - winning_params['1mbeta2']:.4f}), " #
f"shampoo_beta={1 - winning_params['1mshampoo_beta']:.3f}) | Best Loss: {obj.best_loss} | Test Accuracies: {obj.callback_results}"
)

if return_best:
return obj.get_best()


def evaluate_test_accuracy(test_loader):
def _fn(model):
# Save the current training state
was_training = model.training

model.eval()
correct = 0
total = 0

with torch.no_grad():
for data, target in test_loader:
if torch.cuda.is_available():
data, target = data.cuda(), target.cuda()

output = model(data)

# Handle different output shapes
if output.dim() > 2: # Sequence modeling: [batch, seq_len, vocab_size]
pred = output.argmax(dim=-1) # [batch, seq_len]
pred_flat = pred.view(-1)
target_flat = target.view(-1)
correct += pred_flat.eq(target_flat).sum().item()
total += target_flat.numel()
else: # Regular classification: [batch, num_classes]
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
total += target.numel()

# Restore the original training state
model.train(was_training)
return correct / total

return _fn
Loading