Skip to content
Merged

Ci #80

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ on:
pull_request:
workflow_dispatch:
push:
branches: [ main ]

concurrency:
group: ci-${{ github.ref }}
cancel-in-progress: true

jobs:
pre-commit:
Expand All @@ -30,9 +33,9 @@ jobs:
python -m pip install --upgrade pip
python -m pip install ruff
- name: Ruff lint
run: ruff check heavyball scripts test
run: ruff check .
- name: Ruff format check
run: ruff format --check heavyball scripts test
run: ruff format --check .
tests:
needs: [pre-commit, style]
runs-on: ubuntu-latest
Expand All @@ -47,5 +50,5 @@ jobs:
python -m pip install --index-url https://download.pytorch.org/whl/cpu "torch>=2.2"
python -m pip install -e .[dev] --check-build-dependencies --use-pep517 --upgrade --upgrade-strategy eager --use-deprecated=legacy-resolver
python -m pip install pytest
- name: Run curated test targets
run: scripts/run_ci_tests.sh
- name: Run targeted tests
run: scripts/run_ci_tests.sh ${{ github.event_name == 'push' && 'push' || '' }}
2 changes: 1 addition & 1 deletion examples/branched_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms

import heavyball
import heavyball.chainable as C
import heavyball.utils

heavyball.utils.set_torch()

Expand Down
48 changes: 39 additions & 9 deletions heavyball/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
import functools
import math
import threading
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
from contextlib import contextmanager
from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Sequence, Tuple, Union

import numpy
import numpy as np
import optuna
import optunahub
import pandas as pd
import torch
from botorch.utils.sampling import manual_seed
from hebo.design_space.design_space import DesignSpace
from hebo.optimizers.hebo import HEBO
from optuna._transform import _SearchSpaceTransform
Expand All @@ -21,13 +21,6 @@
from optuna.study import Study
from optuna.study._study_direction import StudyDirection
from optuna.trial import FrozenTrial, TrialState
from optuna_integration.botorch import (
ehvi_candidates_func,
logei_candidates_func,
qehvi_candidates_func,
qei_candidates_func,
qparego_candidates_func,
)
from torch import Tensor
from torch.nn import functional as F

Expand All @@ -37,6 +30,33 @@
_SAMPLER_KEY = "auto:sampler"


@contextmanager
def manual_seed(seed: Optional[int] = None) -> Generator[None, None, None]:
    r"""
    Context manager for temporarily setting the torch.random seed.

    On entry the current state from ``torch.random.get_rng_state()`` is
    captured. If ``seed`` is not ``None`` the generator is seeded with it, and
    the captured state is restored on exit, including when the body raises.
    If ``seed`` is ``None`` the RNG is neither seeded nor restored.

    Args:
        seed: The seed to set the random number generator to, or ``None`` to
            leave the RNG untouched.

    Yields:
        None

    Example:
        >>> with manual_seed(1234):
        ...     X = torch.rand(3)

    Adapted from https://github.com/meta-pytorch/botorch/blob/a42cd65f9b704cdb6f2ee64db99a022eb15295d5/botorch/utils/sampling.py#L53C1-L75C50 under the MIT License
    """
    # ``Optional[int]`` rather than ``int | None``: the annotation is evaluated
    # when the function is defined, and the union operator on types is only
    # available from Python 3.10.
    old_state = torch.random.get_rng_state()
    try:
        if seed is not None:
            torch.random.manual_seed(seed)
        yield
    finally:
        # Only undo what was done: without a seed the state was never changed
        # by this context manager, so whatever the body consumed is kept.
        if seed is not None:
            torch.random.set_rng_state(old_state)


class SimpleAPIBaseSampler(BaseSampler):
def __init__(
self,
Expand Down Expand Up @@ -65,6 +85,16 @@ def _get_default_candidates_func(
"""
The original is available at https://github.com/optuna/optuna-integration/blob/156a8bc081322791015d2beefff9373ed7b24047/optuna_integration/botorch/botorch.py under the MIT License
"""

# lazy import
from optuna_integration.botorch import (
ehvi_candidates_func,
logei_candidates_func,
qehvi_candidates_func,
qei_candidates_func,
qparego_candidates_func,
)

if n_objectives > 3 and not has_constraint and not consider_running_trials:
return ehvi_candidates_func
elif n_objectives > 3:
Expand Down
11 changes: 6 additions & 5 deletions heavyball/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1590,7 +1590,7 @@ def _compilable_copy_stochastic_(target: Tensor, source: Tensor):

def copy_stochastic_(target: Tensor, source: Tensor):
    """Write ``source`` into ``target``, routing through ``stochastic_round_`` for bfloat16 targets.

    When ``target`` is bfloat16 and ``source`` is float16/float32/float64,
    ``source`` is first replaced by the result of
    ``stochastic_round_(target, source)``. In every case the (possibly
    replaced) ``source`` is then handed to ``set_(target, source)``.

    Args:
        target: Tensor that receives the values.
        source: Tensor providing the values.
    """
    if target.dtype == torch.bfloat16 and source.dtype in (torch.float16, torch.float32, torch.float64):
        # Single rounding path: the rounded tensor replaces ``source`` so the
        # unconditional ``set_`` below is the only write into ``target``.
        source = stochastic_round_(target, source)
    set_(target, source)


Expand Down Expand Up @@ -2417,10 +2417,11 @@ def bf16_matmul(x: Tensor, y: Tensor):
def if_iscompiling(fn):
base = getattr(torch, fn.__name__, None)

def _fn(x):
if torch.compiler.is_compiling() and hasattr(torch, fn.__name__):
return base(x)
return fn(x)
@functools.wraps(fn)
def _fn(*args, **kwargs):
if torch.compiler.is_compiling() and base is not None:
return base(*args, **kwargs)
return fn(*args, **kwargs)

return _fn

Expand Down
2 changes: 1 addition & 1 deletion interactive/playground.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from plotly.subplots import make_subplots
from sklearn.decomposition import PCA

import heavyball
import heavyball.chainable as C
import heavyball.utils

# TensorFlow Playground inspired colors
COLORS = {
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "heavyball"
description = "Efficient Optimizers"
version = "2.1.1"
version = "2.1.2"
authors = [{ name = "HeavyBall Authors", email = "github.heavyball@nestler.sh" }]
classifiers = ["Intended Audience :: Developers",
"Intended Audience :: Science/Research",
Expand All @@ -28,7 +28,7 @@ readme = "README.md"
requires-python = ">=3.9"

[project.optional-dependencies]
dev = ["pre-commit", "pytest", "ruff", "matplotlib", "seaborn", "pandas", "typer", "optuna", "optunahub", "hebo", "lightbench"]
dev = ["pre-commit", "pytest", "hypothesis", "ruff", "matplotlib", "seaborn", "pandas", "typer", "optuna", "optunahub", "hebo", "lightbench"]

[project.urls]
source = "https://github.com/HomebrewML/HeavyBall"
Expand Down
35 changes: 32 additions & 3 deletions scripts/run_ci_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ set -euo pipefail
export PYTEST_DISABLE_PLUGIN_AUTOLOAD="${PYTEST_DISABLE_PLUGIN_AUTOLOAD:-1}"
export PYTHONWARNINGS="${PYTHONWARNINGS:+$PYTHONWARNINGS,}ignore:pkg_resources is deprecated as an API:UserWarning,ignore:CUDA initialization:UserWarning,ignore:Can't initialize NVML:UserWarning"
export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-}"
export TORCH_COMPILE_DISABLE="${TORCH_COMPILE_DISABLE:-1}"

PYTEST_FLAGS=(--maxfail=1 --disable-warnings -q --color=no --code-highlight=no)

Expand All @@ -14,6 +15,34 @@ run_pytest() {
echo "::endgroup::"
}

run_pytest test/test_toy_training.py
run_pytest test/test_migrate_cli.py
run_pytest test/test_psgd_precond_init_stability.py::test_stable_exp_scalar -k dtype1
# Read pytest invocations from stdin, one per line, and pass each to
# run_pytest. A line is split on whitespace into separate arguments, so
# "path::test -k expr" becomes three arguments.
run_list() {
  local line
  local -a args
  while IFS= read -r line; do
    read -r -a args <<<"${line}"
    # Skip blank AND whitespace-only lines: both split into zero words, and
    # calling run_pytest with an empty argument list is never intended here.
    [[ ${#args[@]} -eq 0 ]] && continue
    run_pytest "${args[@]}"
  done
}

run_list <<'EOF'
test/test_toy_training.py
test/test_migrate_cli.py
test/test_cpu_features.py
test/test_chainable_cpu.py
test/test_helpers_cpu.py
test/test_utils_cpu.py
test/test_optimizer_cpu_smoke.py
test/test_psgd_precond_init_stability.py::test_stable_exp_scalar -k dtype1
test/test_psgd_precond_init_stability.py::test_stable_exp_tensor -k dtype1
test/test_psgd_precond_init_stability.py::test_lse_mean -k dtype1
test/test_psgd_precond_init_stability.py::test_mean_root[dtype1-4-16]
test/test_psgd_precond_init_stability.py::test_mean_root[dtype2-10-512]
test/test_psgd_precond_init_stability.py::test_divided_root[dtype1-3-5-16]
test/test_psgd_precond_init_stability.py::test_divided_root[dtype2-9-4-64]
EOF

if [[ ${1:-} == push ]]; then
run_list <<'EOF'
test/test_toy_training.py
test/test_migrate_cli.py
EOF
fi
1 change: 0 additions & 1 deletion test/test_bf16_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from torch._dynamo import config

import heavyball
import heavyball.utils
from heavyball.utils import clean, set_torch

os.environ["TORCH_LOGS"] = "+recompiles"
Expand Down
1 change: 0 additions & 1 deletion test/test_bf16_q.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from torch._dynamo import config

import heavyball
import heavyball.utils
from heavyball.utils import clean, set_torch

config.cache_size_limit = 128
Expand Down
1 change: 0 additions & 1 deletion test/test_bf16_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from torch._dynamo import config

import heavyball
import heavyball.utils
from heavyball.utils import clean, set_torch

config.cache_size_limit = 128
Expand Down
1 change: 0 additions & 1 deletion test/test_caution.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from torch._dynamo import config

import heavyball
import heavyball.utils
from heavyball.utils import clean, set_torch

config.cache_size_limit = 128
Expand Down
65 changes: 65 additions & 0 deletions test/test_chainable_cpu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import os

import torch

import heavyball.chainable as C
import heavyball.utils

# Default TORCH_COMPILE_DISABLE to "1" unless the environment already sets it.
os.environ.setdefault("TORCH_COMPILE_DISABLE", "1")
# Clear heavyball's module-level compile mode for these tests — presumably so
# they run without compilation on CPU; confirm against heavyball.utils.
heavyball.utils.compile_mode = None


def _identity_update(state, group, update, grad, param):
return update


def test_chain_applies_update_on_cpu():
    """A zero parameter chained with an all-ones gradient at lr=0.1 should end up at -0.1."""
    parameters = [torch.nn.Parameter(torch.zeros(2))]
    gradients = [torch.ones(2)]
    settings = dict(lr=0.1, caution=False, weight_decay=0.0)

    with torch.no_grad():
        C.chain(lambda _: {}, settings, gradients, parameters, _identity_update)

    wanted = torch.full((2,), -0.1)
    assert torch.allclose(parameters[0].detach(), wanted)


def test_branch_merges_multiple_paths():
    """Averaging a doubling branch and a negating branch of an all-ones update should give 0.5."""

    def double(_, __, update, ___, ____):
        return [item * 2 for item in update]

    def negate(_, __, update, ___, ____):
        return [item * -1 for item in update]

    def merge_fn(outputs):
        # Element-wise mean across branches, one entry per parameter.
        averaged = []
        for per_param in zip(*outputs):
            averaged.append(sum(per_param) / len(per_param))
        return averaged

    paths = [[double], [negate]]
    branch = C.Branch(paths, merge_fn)

    update = [torch.ones(2)]
    grad = [torch.ones(2)]
    param = [torch.nn.Parameter(torch.ones(2))]

    merged = branch(lambda _: {}, {}, update, grad, param)
    wanted = torch.full_like(update[0], 0.5)
    assert torch.allclose(merged[0], wanted)


def test_set_indices_assigns_transform_ids():
    """``set_indices`` should give the first transform index 0, and the result should be callable with a buffer supplied."""

    def base(_, __, update, ___, ____, buffer):
        assert buffer is not None
        return update

    guarded = C.ZeroGuard(base, ["buffer"])
    indexed = C.set_indices([guarded], retain=False)
    assigned = indexed[0]
    assert assigned.transform_idx == 0

    group = {"storage_dtype": "float32"}
    update = [torch.ones(1)]
    grad = [torch.ones(1)]
    param = [torch.nn.Parameter(torch.ones(1))]

    # Every state lookup yields a fresh, empty dict.
    assigned(lambda _x: {}, group, update, grad, param)
1 change: 0 additions & 1 deletion test/test_channels_last.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from torch._dynamo import config

import heavyball
import heavyball.utils
from heavyball.utils import clean, set_torch

heavyball.utils.zeroth_power_mode = "newtonschulz"
Expand Down
1 change: 0 additions & 1 deletion test/test_closure.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from torch import nn

import heavyball
import heavyball.utils
from heavyball.utils import clean, set_torch


Expand Down
Loading