
Wrong GNNBenchmarkDataset MNIST Data #10592

@Akulen

Description
🐛 Describe the bug

The MNIST dataset provided by GNNBenchmarkDataset differs from the dataset in the source paper "Benchmarking Graph Neural Networks". The node features appear to be identical, but the edges are not (the number of edges matches, but not their endpoints or edge features). This might be intentional, but I couldn't find any documentation explaining the discrepancy. I used the following code, run on a clone of the paper's repository, to show the differences:

import torch

from torch_geometric.datasets import GNNBenchmarkDataset
from data.superpixels import SuperPixDataset  # from the benchmarking-gnns repository

import torch_geometric
print(torch_geometric.__file__)

data_pyg = GNNBenchmarkDataset(root='data', name='MNIST', split='train')
data_bench = SuperPixDataset('MNIST').train.graph_lists


def MSE(x, y):
    return (x - y).pow(2).mean()


for i in range(10):
    g_pyg = data_pyg[i]
    g_bench = data_bench[i]

    print(g_pyg.num_nodes, g_pyg.num_edges, '|', g_bench.num_nodes(), g_bench.num_edges())

    # Node features match exactly; edge features do not.
    print('Node/Edge Features MSE:',
          MSE(torch.cat((g_pyg.x, g_pyg.pos), dim=1), g_bench.ndata['feat']),
          '/', MSE(g_pyg.edge_attr, g_bench.edata['feat'][:, 0]))

    # Compare the directed edge sets of the two graphs.
    edges_pyg = {(u.item(), v.item()) for u, v in g_pyg.edge_index.T}
    U, V = g_bench.edges()
    edges_bench = {(u.item(), v.item()) for u, v in zip(U, V)}

    print('Edges diffs:', len(edges_pyg - edges_bench), len(edges_bench - edges_pyg))

With the following outputs:

Downloading https://data.pyg.org/datasets/benchmarking-gnns/MNIST_v2.zip
Extracting data/MNIST/raw/MNIST_v2.zip
Processing...
Done!
[I] Loading dataset MNIST...
train, test, val sizes : 55000 10000 5000
[I] Finished loading.
[I] Data load time: 30.1530s
69 552 | 69 552
Node/Edge Features MSE: tensor(0.) / tensor(0.0367)
Edges diffs: 57 57
71 568 | 71 568
Node/Edge Features MSE: tensor(0.) / tensor(0.0292)
Edges diffs: 56 56
73 584 | 73 584
Node/Edge Features MSE: tensor(0.) / tensor(0.0373)
Edges diffs: 63 63
75 600 | 75 600
Node/Edge Features MSE: tensor(0.) / tensor(0.0283)
Edges diffs: 55 55
75 600 | 75 600
Node/Edge Features MSE: tensor(0.) / tensor(0.0299)
Edges diffs: 58 58
72 576 | 72 576
Node/Edge Features MSE: tensor(0.) / tensor(0.0343)
Edges diffs: 56 56
74 592 | 74 592
Node/Edge Features MSE: tensor(0.) / tensor(0.0307)
Edges diffs: 53 53
66 528 | 66 528
Node/Edge Features MSE: tensor(0.) / tensor(0.0261)
Edges diffs: 49 49
73 584 | 73 584
Node/Edge Features MSE: tensor(0.) / tensor(0.0227)
Edges diffs: 57 57
74 592 | 74 592
Node/Edge Features MSE: tensor(0.) / tensor(0.0312)
Edges diffs: 48 48

I am fairly confident the graphs are in the same order, given that the node features match exactly. This is problematic: while trying to reproduce the paper's results with a JAX implementation, I consistently got lower test accuracy with the PyG version of the dataset (~94%) than with the dataset loaded by the paper's code (~96%).

Looking into the discrepancies, I believe both datasets' graphs were generated with a k-NN algorithm to select the edges, but that a different metric was used to produce the PyG dataset. If this is intended and not a bug, it would be great if the documentation mentioned it clearly.
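To illustrate the kind of construction I suspect, here is a minimal k-NN edge-set sketch (k = 8 is consistent with the observed counts, e.g. 69 nodes / 552 directed edges). The Euclidean metric and the toy coordinates are assumptions for illustration, not necessarily what either pipeline actually uses; with real data one would compare `knn_edges(pos)` and `knn_edges` over intensity-plus-position features against each dataset's edge set:

```python
import numpy as np

def knn_edges(feat, k=8):
    """Directed k-NN edge set under Euclidean distance, self-loops excluded."""
    # Pairwise distances between all node feature vectors.
    d = np.linalg.norm(feat[:, None, :] - feat[None, :, :], axis=-1)
    np.fill_diagonal(d, np.inf)  # a node is never its own neighbor
    nbrs = np.argsort(d, axis=1)[:, :k]  # k nearest indices per node
    return {(u, int(v)) for u in range(len(feat)) for v in nbrs[u]}

# Toy superpixel-like coordinates standing in for g_pyg.pos.
rng = np.random.default_rng(0)
pos = rng.random((69, 2))
edges = knn_edges(pos, k=8)
print(len(edges))  # 69 nodes * 8 neighbors = 552 directed edges
```

Running this on the real `pos` tensors (and on other candidate feature combinations) and checking which variant reproduces each dataset's edge set would pin down the metric difference.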

As I couldn't install the paper's environment, I used the following pixi config file to run the code above:

[workspace]
authors = ["Akulen <tomas@rigaux.com>"]
channels = ["conda-forge", "dglteam/label/cu121", "https://conda.anaconda.org/dglteam/label/th21_cu121"]
name = "benchmarking"
platforms = ["linux-64"]
version = "0.1.0"

[system-requirements]
cuda = "12.0"

[dependencies]
pytorch-gpu = ">=2.10.0,<3"
dgl = { version = ">=2.4.0.th21.cu121,<3", channel = "dglteam/label/th21_cu121" }
setuptools = ">=80.10.1,<81"
packaging = ">=26.0,<27"
pandas = ">=3.0.0,<4"
tensorboardx = ">=2.6.2.2,<3"
scikit-learn = ">=1.8.0,<2"
ogb = ">=1.3.6,<2"
matplotlib = ">=3.10.8,<4"
ipython = ">=9.9.0,<10"
tensorboard = ">=2.20.0,<3"

[pypi-dependencies]
torch-geometric = ">=2.7.0, <3"

Versions

Collecting environment information...
PyTorch version: 2.10.0
Is debug build: False
CUDA used to build PyTorch: 12.9
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.35

Python version: 3.12.12 | packaged by conda-forge | (main, Jan 26 2026, 23:51:32) [GCC 14.3.0] (64-bit runtime)
Python platform: Linux-5.15.0-124-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.9.86
CUDA_MODULE_LOADING set to:
GPU models and configuration:
GPU 0: NVIDIA RTX A5000
GPU 1: NVIDIA RTX A5000
GPU 2: NVIDIA RTX A5000
GPU 3: NVIDIA RTX A5000
GPU 4: NVIDIA RTX A5000
GPU 5: NVIDIA RTX A5000

Nvidia driver version: 560.35.03
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.9.7
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A

CPU:
Architecture:                         x86_64
CPU op-mode(s):                       32-bit, 64-bit
Address sizes:                        48 bits physical, 48 bits virtual
Byte Order:                           Little Endian
CPU(s):                               96
On-line CPU(s) list:                  0-95
Vendor ID:                            AuthenticAMD
Model name:                           AMD EPYC 7413 24-Core Processor
CPU family:                           25
Model:                                1
Thread(s) per core:                   2
Core(s) per socket:                   24
Socket(s):                            2
Stepping:                             1
Frequency boost:                      enabled
CPU max MHz:                          3630.8101
CPU min MHz:                          1500.0000
BogoMIPS:                             5299.56
Flags:                                fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 invpcid_single hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold v_vmsave_vmload vgif v_spec_ctrl umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca fsrm
Virtualization:                       AMD-V
L1d cache:                            1.5 MiB (48 instances)
L1i cache:                            1.5 MiB (48 instances)
L2 cache:                             24 MiB (48 instances)
L3 cache:                             256 MiB (8 instances)
NUMA node(s):                         2
NUMA node0 CPU(s):                    0-23,48-71
NUMA node1 CPU(s):                    24-47,72-95
Vulnerability Gather data sampling:   Not affected
Vulnerability Itlb multihit:          Not affected
Vulnerability L1tf:                   Not affected
Vulnerability Mds:                    Not affected
Vulnerability Meltdown:               Not affected
Vulnerability Mmio stale data:        Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed:               Not affected
Vulnerability Spec rstack overflow:   Mitigation; safe RET
Vulnerability Spec store bypass:      Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:             Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:             Mitigation; Retpolines; IBPB conditional; IBRS_FW; STIBP always-on; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds:                  Not affected
Vulnerability Tsx async abort:        Not affected

Versions of relevant libraries:
[pip3] Could not collect
[conda] cuda-cudart               12.9.79              h5888daf_0    conda-forge
[conda] cuda-cudart_linux-64      12.9.79              h3f2d84a_0    conda-forge
[conda] cuda-cupti                12.9.79              h676940d_1    conda-forge
[conda] cuda-nvrtc                12.9.86              hecca717_1    conda-forge
[conda] cuda-nvtx                 12.9.79              hecca717_1    conda-forge
[conda] cudnn                     9.10.2.21            hbcb9cd8_0    conda-forge
[conda] libblas                   3.11.0           5_h5875eb1_mkl    conda-forge
[conda] libcblas                  3.11.0           5_hfef963f_mkl    conda-forge
[conda] libcublas                 12.9.1.4             h676940d_1    conda-forge
[conda] libcudnn                  9.10.2.21            hf7e9902_0    conda-forge
[conda] libcudnn-dev              9.10.2.21            h58dd1b1_0    conda-forge
[conda] libcufft                  11.4.1.4             hecca717_1    conda-forge
[conda] libcurand                 10.3.10.19           h676940d_1    conda-forge
[conda] libcusolver               11.7.5.82            h676940d_2    conda-forge
[conda] libcusparse               12.5.10.65           hecca717_2    conda-forge
[conda] liblapack                 3.11.0           5_h5e43f62_mkl    conda-forge
[conda] libmagma                  2.9.0                ha7672b3_6    conda-forge
[conda] libnvjitlink              12.9.86              hecca717_2    conda-forge
[conda] libtorch                  2.10.0          cuda129_mkl_hf5f578d_300    conda-forge
[conda] mkl                       2025.3.0           h0e700b2_463    conda-forge
[conda] nccl                      2.29.2.1             h4d09622_1    conda-forge
[conda] numpy                     2.4.1           py312h33ff503_0    conda-forge
[conda] optree                    0.18.0          py312hd9148b4_0    conda-forge
[conda] pytorch                   2.10.0          cuda129_mkl_py312_h2ff76c1_300    conda-forge
[conda] pytorch-gpu               2.10.0          cuda129_mkl_h0d04637_300    conda-forge
[conda] tbb                       2022.3.0             hb700be7_2    conda-forge
[conda] torch-geometric           2.7.0                    pypi_0    pypi
[conda] triton                    3.5.1           cuda129py312h811769c_0    conda-forge
