
sisinflab/WaveDiT


WaveDiT: Distribution-Aware Wavelet Flow Matching for Efficient 3D Brain MRI Synthesis

Try the demo on Hugging Face Spaces · WaveDiT Studio for macOS · Models on Hugging Face · Paper on Hugging Face · arXiv

🤗 Try WaveDiT in your browser: pick an age, generate a synthetic 3D brain MRI, and explore it interactively (triplane + 3D viewer with clip-plane slicing). No install needed → huggingface.co/spaces/danesed/WaveDiT-demo

Official PyTorch implementation of "WaveDiT: Distribution-Aware Wavelet Flow Matching for Efficient 3D Brain MRI Synthesis" (MICCAI 2026).

WaveDiT synthesises full-resolution, high-fidelity, conditional 3D brain MRIs by performing flow matching in the 3D Haar wavelet domain with a slice-wise HDiT backbone, guided by Morpheus, a state-aware uncertainty scheduler that adaptively weights the loss and sampling across frequency bands.

[Figure: WaveDiT architecture]

Links: 🤗 Live demo · 🤗 Models · Project page · HF paper · arXiv

Key features

  • Wavelet flow matching: operates on the 8-channel 3D Haar latent (1 LLL + 7 HF bands).
  • Morpheus uncertainty scheduler: Bayesian heteroscedastic loss weighting + uncertainty-minimising sampling guidance.
  • HDiT backbone: neighbourhood + spatio-depth factorised attention for efficient 3D modelling.
  • Multiple flow formulations: cfm, rectified, ot_fm.
  • Conditional synthesis: numeric and categorical metadata (e.g. age), with classifier-free guidance.
  • Single-file configs: one YAML fully describes a run; checkpoints are self-contained for generation.
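The 8-channel latent comes from a single-level 3D Haar transform. A minimal PyTorch sketch follows. It assumes the orthonormal Haar convention (1/√2 per axis) and an LLL, LLH, LHL, LHH, HLL, HLH, HHL, HHH channel order; the repository's own wavelets/ module may normalise or order the bands differently, and the function names here are hypothetical. Under these assumptions a 224³ volume becomes an 8 × 112³ tensor, and the transform is exactly invertible and energy-preserving.

```python
import torch

SQRT2 = 2.0 ** 0.5


def _analysis(t: torch.Tensor, dim: int):
    """One Haar analysis step along `dim`: orthonormal low-pass (sum) and high-pass (difference)."""
    t = t.movedim(dim, -1)
    even, odd = t[..., 0::2], t[..., 1::2]
    low, high = (even + odd) / SQRT2, (even - odd) / SQRT2
    return low.movedim(-1, dim), high.movedim(-1, dim)


def _synthesis(low: torch.Tensor, high: torch.Tensor, dim: int) -> torch.Tensor:
    """Inverse of `_analysis`: rebuild even/odd samples and interleave them along `dim`."""
    low, high = low.movedim(dim, -1), high.movedim(dim, -1)
    even, odd = (low + high) / SQRT2, (low - high) / SQRT2
    out = torch.stack((even, odd), dim=-1).flatten(-2)  # e0, o0, e1, o1, ...
    return out.movedim(-1, dim)


def haar_dwt3d(x: torch.Tensor) -> torch.Tensor:
    """(B, C, D, H, W) with even D, H, W -> (B, 8*C, D/2, H/2, W/2).

    Channel blocks are ordered LLL, LLH, LHL, LHH, HLL, HLH, HHL, HHH
    (letters = depth, height, width), so block 0 is the low-frequency approximation.
    """
    bands = [x]
    for dim in (2, 3, 4):  # depth, height, width
        bands = [half for band in bands for half in _analysis(band, dim)]
    return torch.cat(bands, dim=1)


def haar_idwt3d(y: torch.Tensor, channels: int = 1) -> torch.Tensor:
    """Exact inverse of `haar_dwt3d` for an input that had `channels` channels."""
    bands = list(torch.split(y, channels, dim=1))
    for dim in (4, 3, 2):  # undo width, then height, then depth
        bands = [_synthesis(bands[i], bands[i + 1], dim) for i in range(0, len(bands), 2)]
    return bands[0]
```

A constant volume puts all of its energy in the LLL block and leaves the seven high-frequency bands at zero, which is the sparsity that the heavy-tailed HF statistics reflect on real scans.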

Morpheus: state-aware uncertainty

Wavelet subbands are not statistically equal: the low-frequency approximation stays close to Gaussian, while the high-frequency bands are sparse and heavy-tailed, and these statistics shift along the flow trajectory. Morpheus is a lightweight network that, at each step, reads the statistical signature of the current noisy state (per-band mean, standard deviation, max amplitude, L2 energy, skewness and kurtosis) and predicts a per-band log-variance. That prediction plays two roles:

  • Weighting the loss: with s the predicted per-band log-variance, it forms a Bayesian heteroscedastic objective (0.5 * exp(-s) * ||v - v_target||^2 + 0.5 * s) that down-weights inherently unpredictable high-frequency content, while the 0.5 * s term prevents trivial variance inflation. The result is state-dependent precision instead of a uniform MSE.
  • Conditioning the backbone: the projected log-variances become a frequency hint, injected alongside the time, slice and age embeddings, so the transformer adapts its prediction to the current reliability of each band, during both training and sampling.
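To make the Morpheus inputs and objective concrete, the sketch below computes the six per-band statistics and the heteroscedastic loss. It assumes a (B, K, D, H, W) layout, takes "L2 energy" to be the mean squared value, and uses a hypothetical two-layer MLP (LogVarianceHead) in place of the real Morpheus network, whose architecture is not described here. It does not reproduce the frequency-hint conditioning or the sampling guidance.

```python
import torch
from torch import nn


def band_statistics(x_t: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """(B, K, D, H, W) noisy wavelet state -> (B, K, 6) per-band signature:
    mean, std, max |amplitude|, energy, skewness, kurtosis."""
    flat = x_t.flatten(start_dim=2)                      # (B, K, N)
    mean = flat.mean(dim=-1)
    centred = flat - mean.unsqueeze(-1)
    std = (centred.pow(2).mean(dim=-1) + eps).sqrt()
    max_amp = flat.abs().amax(dim=-1)
    energy = flat.pow(2).mean(dim=-1)                    # assumption: mean squared value
    z = centred / std.unsqueeze(-1)
    skew = z.pow(3).mean(dim=-1)
    kurt = z.pow(4).mean(dim=-1)                         # 3.0 for a Gaussian (not excess)
    return torch.stack([mean, std, max_amp, energy, skew, kurt], dim=-1)


class LogVarianceHead(nn.Module):
    """Hypothetical stand-in for Morpheus: a shared MLP mapping each band's stats to s."""

    def __init__(self, hidden: int = 32):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(6, hidden), nn.SiLU(), nn.Linear(hidden, 1))

    def forward(self, x_t: torch.Tensor) -> torch.Tensor:
        return self.mlp(band_statistics(x_t)).squeeze(-1)  # (B, K) log-variances


def heteroscedastic_loss(v_pred, v_target, log_var):
    """Mean over batch and bands of 0.5 * exp(-s) * err + 0.5 * s, with err = per-band MSE."""
    err = (v_pred - v_target).pow(2).flatten(start_dim=2).mean(dim=-1)  # (B, K)
    return (0.5 * torch.exp(-log_var) * err + 0.5 * log_var).mean()
```

Setting the derivative with respect to s to zero for one band gives s = log(err), so a band that is inherently hard to predict settles at a larger s and contributes a smaller weighted error, while the 0.5 * s term stops s from growing without bound.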

Installation

conda create -n wavedit_env python=3.11 && conda activate wavedit_env
pip install -r requirements.txt
# Optional but recommended: fused neighbourhood-attention CUDA kernels (pick the wheel that matches your PyTorch/CUDA build):
pip install natten -f https://whl.natten.org
# Optional, faster global attention:
# pip install -U xformers

Developed for Python 3.11 and PyTorch 2.6 (CUDA recommended).

NATTEN is optional. It is the fastest backend and the reference implementation of the neighbourhood attention used in the default config, but WaveDiT also provides an equivalent built-in pure-PyTorch fallback, so the model runs without NATTEN, including on CPU. The backend is chosen automatically; to override it, set the environment variable WAVEDIT_NA_BACKEND=auto|natten|torch.
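For intuition about what a pure-PyTorch fallback has to compute, here is a dense-mask reference for 2D neighbourhood attention with NATTEN-style border handling: each query attends to exactly k × k keys, and windows near the edges shift inward instead of shrinking. It is a sketch, not WaveDiT's fallback: it materialises an (H·W) × (H·W) mask, ignores dilation, assumes an odd kernel no larger than the grid, and the function names are hypothetical.

```python
import torch
import torch.nn.functional as F


def neighbourhood_mask(h: int, w: int, kernel_size: int) -> torch.Tensor:
    """Boolean (h*w, h*w) mask; True where a query token may attend a key token."""
    def axis_mask(n: int) -> torch.Tensor:
        idx = torch.arange(n)
        # Per-query window start, clamped so the window never leaves the grid.
        start = (idx - kernel_size // 2).clamp(0, n - kernel_size)
        return (idx[None, :] >= start[:, None]) & (idx[None, :] < start[:, None] + kernel_size)

    rows, cols = axis_mask(h), axis_mask(w)
    # Query (i, j) may see key (a, b) iff a is in row i's window and b in column j's window.
    return (rows[:, None, :, None] & cols[None, :, None, :]).reshape(h * w, h * w)


def neighbourhood_attention_2d(q, k, v, kernel_size: int) -> torch.Tensor:
    """q, k, v: (B, heads, H, W, head_dim). Dense-mask reference with O((H*W)^2) memory."""
    b, heads, h, w, d = q.shape
    mask = neighbourhood_mask(h, w, kernel_size).to(q.device)
    q, k, v = (t.reshape(b, heads, h * w, d) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)  # True = attend
    return out.reshape(b, heads, h, w, d)
```

Dense masking makes the semantics easy to check on small grids, but its quadratic memory is exactly what fused kernels avoid, which is why NATTEN remains the recommended backend for full-resolution runs.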

Repository layout

configs/            One YAML per experiment (cfm, rectified, ot_fm)
train.sh            bash train.sh [config.yaml]      -> launches training
generate.sh         bash generate.sh <ckpt> [outdir] -> generates samples
scripts/
  train.py          config-driven training entry point
  generate.py       generation (specific condition sets or linear interpolation)
  prepare_metadata.py  build the metadata CSV from NIfTI folders
tools/
  slim_checkpoint.py   strip optimiser state for release/inference
wavedit/
  config.py         typed config loaded from YAML
  data/             unified dataset (CSV / filename), augmentation, collation
  wavelets/         differentiable 3D Haar DWT/IDWT
  models/           WaveletFlowMatching, DiT3D backbone, Morpheus, sampling, hdit/
  training/         Trainer + checkpoint I/O
  generation/       sample generation
  evaluation/       metrics + W&B visualisation
  utils/            logging + seeding

Data

See data/README.md. In short, build a catalog once:

python scripts/prepare_metadata.py --input-dirs /path/to/scans --output-csv ./data/dataset.csv

then point data.metadata_csv in your config at it. Raw scans and catalogs are git-ignored and must be obtained from the original dataset providers.

Training

Edit a config (data paths, architecture, hyper-parameters) and launch:

bash train.sh configs/cfm.yaml

Or run the entry point directly:

PYTHONPATH=. python scripts/train.py configs/cfm.yaml

Each run writes to <checkpoint_dir>/<run_name>/: best.pth, last.pth, a copy of the resolved config.yaml, and logs. Set logging.wandb: true for W&B metrics and visualisations. Switch the objective with model.flow (cfm | rectified | ot_fm).
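The three model.flow options differ in how noise and data are paired and interpolated during training. The sketch below shows one common reading of each, which is an assumption rather than the repository's code: rectified uses the straight path x_t = (1 - t)·x0 + t·x1 with target x1 - x0; cfm uses the Lipman et al. conditional OT path with a small sigma_min; ot_fm keeps the straight path but first re-pairs the noise batch by minibatch optimal transport. Here t = 0 is noise and t = 1 is data, and the helper names are hypothetical. ot_pair solves the assignment by brute force, which is only viable for tiny batches.

```python
import itertools

import torch


def interpolate(x0, x1, t, flow="rectified", sigma_min=1e-4):
    """Return (x_t, target velocity) for noise x0, data x1 and per-sample times t of shape (B,)."""
    t = t.view(-1, *([1] * (x1.dim() - 1)))            # broadcast t over non-batch dims
    if flow in ("rectified", "ot_fm"):
        x_t = (1 - t) * x0 + t * x1                     # straight path from noise to data
        target = x1 - x0
    elif flow == "cfm":
        x_t = (1 - (1 - sigma_min) * t) * x0 + t * x1   # conditional OT path with sigma_min
        target = x1 - (1 - sigma_min) * x0
    else:
        raise ValueError(f"unknown flow {flow!r}")
    return x_t, target


def ot_pair(x0, x1):
    """Reorder the noise batch so that sum_j ||x0[j] - x1[j]||^2 is minimal.

    Brute force over permutations, for toy batches only; a real implementation
    would use a Hungarian or Sinkhorn solver instead.
    """
    b = x0.shape[0]
    cost = torch.cdist(x0.flatten(1), x1.flatten(1)).pow(2)   # (B, B) squared distances
    cols = torch.arange(b)
    best = min(itertools.permutations(range(b)),
               key=lambda p: cost[torch.tensor(p), cols].sum().item())
    return x0[torch.tensor(best)], x1
```

Because every path is linear in t, the regression target is simply the constant velocity along it; OT pairing shortens and straightens those paths on average, which tends to make few-step sampling more accurate.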

Warm-start a new variant

A variant that differs from a trained one only in patch size does not need to be trained from scratch. scripts/weight_inheritance.py transfers almost all of a trained checkpoint's weights to the new model: the HDiT body transfers 1:1, and only the two patch projections are resized to the new token grid with a pseudo-inverse patch resize. Because a coarse-to-fine warm start begins already in distribution, it converges far faster than a from-scratch run.

# Example: warm-start FinePatch (4x4) from a trained Base (8x8)
python scripts/weight_inheritance.py \
    --donor  checkpoints/WaveDiT_CFM_Base/best.pth \
    --config configs/cfm_FinePatch.yaml \
    --output checkpoints/WaveDiT_CFM_FinePatch/last.pth
bash train.sh configs/cfm_FinePatch.yaml   # resumes at epoch 0 with the inherited weights

The released finest variant, WaveDiT-FinePatch2 (patch 2×2), was produced this way: it was warm-started from WaveDiT-FinePatch (patch 4×4), which cut its training time drastically compared with training from scratch while keeping very high sample quality.
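A pseudo-inverse patch resize can be sketched in the style of FlexiViT's PI-resize: with B the linear operator that bilinearly resizes a flattened old-size patch to the new size, the new kernel w' solves Bᵀw' = w in the least-squares sense, so that ⟨x, w⟩ ≈ ⟨Bx, w'⟩ for every patch x. This is an assumption about what scripts/weight_inheritance.py does, not its code; the helpers are hypothetical and only cover an input patch-embedding kernel of shape (C_out, C_in, p, p). The equality is exact when the new patch is larger (B then has full column rank); for the coarse-to-fine direction used here (8×8 to 4×4) it is only a least-squares fit, which subsequent training refines.

```python
import torch
import torch.nn.functional as F


def resize_matrix(p_old: int, p_new: int) -> torch.Tensor:
    """Matrix B of shape (p_new^2, p_old^2) that bilinearly resizes a flattened p_old x p_old patch."""
    basis = torch.eye(p_old * p_old).reshape(-1, 1, p_old, p_old)   # one one-hot patch per pixel
    resized = F.interpolate(basis, size=(p_new, p_new), mode="bilinear", align_corners=False)
    return resized.reshape(p_old * p_old, p_new * p_new).T          # columns = resized basis patches


def pi_resize_patch_embed(weight: torch.Tensor, p_new: int) -> torch.Tensor:
    """(C_out, C_in, p, p) patch-embedding kernel -> (C_out, C_in, p_new, p_new)."""
    c_out, c_in, p_old, _ = weight.shape
    B = resize_matrix(p_old, p_new).to(weight)
    # Least-squares solution of B^T w_new = w_old, i.e. w_new = pinv(B^T) w_old.
    pinv_bt = torch.linalg.pinv(B.T)                                 # (p_new^2, p_old^2)
    w_old = weight.reshape(c_out * c_in, p_old * p_old)
    w_new = w_old @ pinv_bt.T                                        # (C_out*C_in, p_new^2)
    return w_new.reshape(c_out, c_in, p_new, p_new)
```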

Pretrained weights

All variants are published on the Hugging Face Hub at danesed/WaveDiT. Each .pth is self-contained (architecture and condition metadata embedded), so generation needs only the file.

| Model | Variant | Params | Full-res inference VRAM¹ | Download |
|---|---|---|---|---|
| Base | patch 8×8 (baseline) | 142M | ~3.1 GB (runs from 4 GB) | WaveDiT-Base.pth |
| FinePatch | patch 4×4 | 142M | ~8.4 GB (runs from 10 GB) | WaveDiT-FinePatch.pth |
| FinePatch2 | patch 2×2, warm-started | 142M | ~27 GB (runs from 32 GB) | WaveDiT-FinePatch2.pth |
| Deep | depth 4/4 | 262M | ~3.1 GB (runs from 4 GB) | WaveDiT-Deep.pth |
| Wide | width 2048, d_ff 8192 | 506M | ~5.6 GB (runs from 8 GB) | WaveDiT-Wide.pth |

¹ Peak VRAM for full-resolution (224³) generation with batch size 1, bf16 and 10-step Heun sampling, measured with torch.cuda.max_memory_reserved. The HDiT backbone scales readily: patch size, width and depth are config knobs over a compact wavelet representation, so WaveDiT fits a wide range of hardware budgets. Full-resolution inference runs on GPUs with 4 GB or more (Base), and the same configs scale training down to modest GPUs by adjusting the batch size or variant. No high-end accelerator is required to use the models.
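The footnote's numbers come from torch.cuda.max_memory_reserved. A small helper for reproducing that kind of measurement on your own GPU might look like this. It resets the peak counter first so earlier allocations do not leak into the reading, reports GiB (2³⁰ bytes, which may differ slightly from the table's GB convention), and returns None on CPU-only machines. The sampling function in the usage comment is hypothetical.

```python
import torch


def measure_peak_reserved(fn, *args, **kwargs):
    """Run fn(*args, **kwargs); return (result, peak reserved CUDA memory in GiB, or None on CPU)."""
    if not torch.cuda.is_available():
        return fn(*args, **kwargs), None      # nothing to measure without a GPU
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()      # start the peak counter from the current state
    result = fn(*args, **kwargs)
    torch.cuda.synchronize()                  # wait for all queued kernels to finish
    return result, torch.cuda.max_memory_reserved() / 2**30


# Hypothetical usage (sample_volume is not a WaveDiT function):
# with torch.autocast("cuda", dtype=torch.bfloat16):
#     volume, peak_gib = measure_peak_reserved(sample_volume, ckpt, num_steps=10)
```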

WaveDiT-FinePatch2 was trained by warm start (weight inheritance) from WaveDiT-FinePatch rather than from scratch (see Warm-start a new variant). This cut its training time drastically while still reaching the finest 2×2 token grid with very high sample quality.

from huggingface_hub import hf_hub_download
ckpt = hf_hub_download("danesed/WaveDiT", "WaveDiT-FinePatch2.pth")

Generation

Checkpoints are self-contained (they embed the config and condition metadata), so generation needs only the checkpoint and your sampling choices.

# Specific condition sets (N samples each)
# NOTE: global flags (--cfg-scale, --num-flow-steps, --sampler, --save-size, ...) go BEFORE the subcommand.
PYTHONPATH=. python scripts/generate.py checkpoints/WaveDiT_CFM/best.pth out/ \
    --cfg-scale 1.5 --num-flow-steps 10 --sampler heun --save-size 182 218 182 \
    specific --conditions "age=45.0" "age=70.5" --num-samples 10

# Linearly interpolate one condition (one sample per step)
PYTHONPATH=. python scripts/generate.py checkpoints/WaveDiT_CFM/best.pth out/ \
    linear --condition age --min 6 --max 95 --num 100

Or use the launcher: bash generate.sh checkpoints/WaveDiT_CFM/best.pth.

| Argument | Meaning |
|---|---|
| --cfg-scale | Classifier-free guidance scale (1.0 = none). |
| --num-flow-steps | ODE integration steps (overrides the checkpoint default). |
| --sampler | heun (2nd order) or euler. |
| --morpheus-scale | Uncertainty-guidance scale (0 disables it). |
| --save-size | Center-crop saved volumes to D H W (default: full model output). |
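The sampling flags map onto a standard ODE integration loop. The sketch below shows classifier-free guidance (v_u + s·(v_c - v_u), so --cfg-scale 1.0 leaves the conditional prediction unchanged), Euler versus Heun updates over --num-flow-steps uniform steps, and a centre crop like --save-size. It assumes a model(x, t, cond) callable that treats cond=None as the dropped condition and integrates from t = 0 (noise) to t = 1 (data). These are assumptions, the names are hypothetical, and Morpheus guidance (--morpheus-scale) is not modelled.

```python
import torch


def guided_velocity(model, x, t, cond, cfg_scale):
    """Classifier-free guidance: v_u + s * (v_c - v_u); s = 1.0 returns the conditional velocity."""
    v_cond = model(x, t, cond)
    if cfg_scale == 1.0:
        return v_cond
    v_uncond = model(x, t, None)              # assumption: None means "condition dropped"
    return v_uncond + cfg_scale * (v_cond - v_uncond)


@torch.no_grad()
def sample(model, x0, cond, num_steps=10, sampler="heun", cfg_scale=1.5):
    """Integrate dx/dt = v(x, t) from t=0 (noise x0) to t=1 on a uniform grid."""
    ts = torch.linspace(0.0, 1.0, num_steps + 1)
    x = x0
    for t0, t1 in zip(ts[:-1], ts[1:]):
        dt = t1 - t0
        v0 = guided_velocity(model, x, t0, cond, cfg_scale)
        x_euler = x + dt * v0                 # first-order predictor
        if sampler == "euler":
            x = x_euler
        elif sampler == "heun":
            v1 = guided_velocity(model, x_euler, t1, cond, cfg_scale)
            x = x + dt * 0.5 * (v0 + v1)      # trapezoidal (2nd-order) corrector
        else:
            raise ValueError(f"unknown sampler {sampler!r}")
    return x


def center_crop(vol: torch.Tensor, size):
    """Center-crop the last three dims of `vol` to `size` = (D, H, W)."""
    slices = []
    for n, s in zip(vol.shape[-3:], size):
        start = (n - s) // 2
        slices.append(slice(start, start + s))
    return vol[(..., *slices)]
```

Heun spends two velocity evaluations per step (four with guidance), so it trades compute for accuracy: on the toy ODE dx/dt = x with 10 steps, its final error is roughly 30 times smaller than Euler's.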

Citation

@misc{danese2026waveditdistributionawarewaveletflow,
      title={WaveDiT: Distribution-Aware Wavelet Flow Matching for Efficient 3D Brain MRI Synthesis},
      author={Danilo Danese and Angela Lombardi and Giuseppe Fasano and Matteo Attimonelli and Tommaso Di Noia},
      year={2026},
      eprint={2606.08670},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2606.08670},
}

Acknowledgements

WaveDiT builds on the wavelet-domain analysis and multi-level evaluation protocol of our previous work, FlowLet.

The HDiT backbone is adapted from k-diffusion.

The invertible 3D wavelet transform builds on the great work of WDM.

See LICENSE.
