Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/actions/setup/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ runs:

- name: Set up Python ${{ inputs.python-version }}
run: |
uv pip install --upgrade pip setuptools
uv pip install --upgrade pip setuptools packaging
shell: bash

- name: Install numpy
Expand Down
19 changes: 9 additions & 10 deletions test/nn/norm/test_batch_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,24 @@

from torch_geometric.nn import BatchNorm, HeteroBatchNorm
from torch_geometric.testing import is_full_test, withDevice
from torch_geometric.typing import WITH_PT212


@withDevice
@pytest.mark.parametrize('conf', [True, False])
def test_batch_norm(device, conf):
x = torch.randn(100, 16, device=device)

norm = BatchNorm(16, affine=conf, track_running_stats=conf, device=device)
norm.reset_running_stats()
norm.reset_parameters()
if WITH_PT212:
assert str(norm) == (f'BatchNorm(16, eps=1e-05, momentum=0.1, '
f'affine={conf}, bias={conf}, '
f'track_running_stats={conf})')
else:
assert str(norm) == (f'BatchNorm(16, eps=1e-05, momentum=0.1, '
f'affine={conf}, '
f'track_running_stats={conf})')

bn = getattr(norm, "module", norm)
assert bn.num_features == 16
assert bn.eps == 1e-5
assert bn.momentum == 0.1
assert bn.affine == conf
assert bn.track_running_stats == conf
assert (bn.weight is not None) == conf
assert (bn.bias is not None) == conf

if is_full_test():
torch.jit.script(norm)
Expand Down
Loading