🤗 Try WaveDiT in your browser: pick an age, generate a synthetic 3D brain MRI, and explore it interactively (triplane + 3D viewer with clip-plane slicing). No install needed → huggingface.co/spaces/danesed/WaveDiT-demo
Official PyTorch implementation of "WaveDiT: Distribution-Aware Wavelet Flow Matching for Efficient 3D Brain MRI Synthesis" (MICCAI 2026).
WaveDiT synthesises full-resolution, high-fidelity, conditional 3D brain MRIs by performing flow matching in the 3D Haar wavelet domain with a slice-wise HDiT backbone, guided by Morpheus, a state-aware uncertainty scheduler that adaptively weights the loss and sampling across frequency bands.
Links: 🤗 Live demo · 🤗 Models · Project page · HF paper · arXiv
- Wavelet flow matching: operates on the 8-channel 3D Haar latent (1 LLL + 7 HF bands).
- Morpheus uncertainty scheduler: Bayesian heteroscedastic loss weighting + uncertainty-minimising sampling guidance.
- HDiT backbone: neighbourhood + spatio-depth factorised attention for efficient 3D modelling.
- Multiple flow formulations: cfm, rectified, ot_fm.
- Conditional synthesis: numeric and categorical metadata (e.g. age), with classifier-free guidance.
- Single-file configs: one YAML fully describes a run; checkpoints are self-contained for generation.
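To make the 8-channel latent concrete, here is a minimal NumPy sketch of a single-level 3D Haar transform. It is not the repository's `wavedit/wavelets` module; the function names and the band ordering (index 0 = LLL) are assumptions chosen for illustration.

```python
import numpy as np

def haar_split(x: np.ndarray, axis: int):
    """Split one axis into orthonormal Haar low-pass and high-pass halves."""
    even = np.take(x, np.arange(0, x.shape[axis], 2), axis=axis)
    odd = np.take(x, np.arange(1, x.shape[axis], 2), axis=axis)
    return (even + odd) / np.sqrt(2.0), (even - odd) / np.sqrt(2.0)

def haar_merge(low: np.ndarray, high: np.ndarray, axis: int) -> np.ndarray:
    """Invert haar_split by re-interleaving the even and odd samples."""
    even = (low + high) / np.sqrt(2.0)
    odd = (low - high) / np.sqrt(2.0)
    shape = list(low.shape)
    shape[axis] *= 2
    return np.stack([even, odd], axis=axis + 1).reshape(shape)

def haar_dwt3d(volume: np.ndarray) -> np.ndarray:
    """(D, H, W) -> (8, D/2, H/2, W/2); band 0 is LLL, bands 1..7 are high-frequency."""
    bands = [volume]
    for axis in range(3):
        bands = [half for band in bands for half in haar_split(band, axis)]
    return np.stack(bands)

def haar_idwt3d(coeffs: np.ndarray) -> np.ndarray:
    """(8, D/2, H/2, W/2) -> (D, H, W), undoing the splits in reverse axis order."""
    bands = list(coeffs)
    for axis in (2, 1, 0):
        bands = [haar_merge(bands[i], bands[i + 1], axis) for i in range(0, len(bands), 2)]
    return bands[0]
```

Under this sketch a 224³ volume would map to eight 112³ subbands, and because the transform is orthonormal the inverse reconstructs the input exactly (up to floating-point error).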
Wavelet subbands are not statistically equal: the low-frequency approximation stays close to Gaussian, while the high-frequency bands are sparse and heavy-tailed, and these statistics shift along the flow trajectory. Morpheus is a lightweight network that, at each step, reads the statistical signature of the current noisy state (per-band mean, standard deviation, max amplitude, L2 energy, skewness and kurtosis) and predicts a per-band log-variance. That prediction plays two roles:
- Weighting the loss: it forms a Bayesian heteroscedastic objective (0.5 * exp(-s) * ||v - v_target||^2 + 0.5 * s) that down-weights inherently unpredictable high-frequency content, while the 0.5 * s term prevents trivial variance inflation. The result is state-dependent precision instead of a uniform MSE.
- Conditioning the backbone: the projected log-variances become a frequency hint, injected alongside the time, slice and age embeddings, so the transformer adapts its prediction to the current reliability of each band, during both training and sampling.
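The two ingredients described above, the per-band statistical signature and the heteroscedastic objective, can be sketched in NumPy as follows. This is an illustration rather than the released Morpheus code: the function names are hypothetical, "L2 energy" is taken as the sum of squares, kurtosis is the non-excess fourth standardised moment, and the squared error is averaged over each band rather than summed.

```python
import numpy as np

def band_statistics(x: np.ndarray) -> np.ndarray:
    """x: (bands, ...) -> (bands, 6): mean, std, max |x|, L2 energy, skewness, kurtosis."""
    flat = x.reshape(x.shape[0], -1)
    mean = flat.mean(axis=1)
    std = flat.std(axis=1)
    z = (flat - mean[:, None]) / (std[:, None] + 1e-8)  # standardised values per band
    return np.stack(
        [
            mean,
            std,
            np.abs(flat).max(axis=1),
            (flat ** 2).sum(axis=1),  # assumption: energy = sum of squares
            (z ** 3).mean(axis=1),    # skewness
            (z ** 4).mean(axis=1),    # kurtosis (non-excess)
        ],
        axis=1,
    )

def heteroscedastic_loss(v_pred: np.ndarray, v_target: np.ndarray, log_var: np.ndarray) -> float:
    """0.5 * exp(-s) * squared error + 0.5 * s, with one log-variance s per band."""
    sq_err = ((v_pred - v_target) ** 2).reshape(v_pred.shape[0], -1).mean(axis=1)
    per_band = 0.5 * np.exp(-log_var) * sq_err + 0.5 * log_var
    return float(per_band.mean())
```

With s = 0 for every band the objective reduces to half the ordinary MSE. For a fixed error e, the per-band term is minimised at s = log(e), which is why the 0.5 * s penalty stops the predicted variance from growing without bound.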
conda create -n wavedit_env python=3.11 && conda activate wavedit_env
pip install -r requirements.txt
# Optional but recommended: fused neighbourhood-attention CUDA kernels (match your build):
pip install natten -f https://whl.natten.org
# Optional, faster global attention:
# pip install -U xformers
Developed for Python 3.11 and PyTorch 2.6 (CUDA recommended).
NATTEN is optional. It is the fastest, reference implementation of the neighbourhood attention used in the default config, but WaveDiT provides an equivalent built-in pure-PyTorch fallback, so the model runs without NATTEN, including on CPU. The backend is chosen automatically; override it with WAVEDIT_NA_BACKEND=auto|natten|torch.
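As an illustration of what automatic selection with an override could look like, the sketch below resolves the backend from the environment variable. Only the variable name and its three values come from this README; the function `resolve_na_backend` and its fallback rule are hypothetical and may differ from what WaveDiT actually does.

```python
import importlib.util
import os

VALID_BACKENDS = ("auto", "natten", "torch")

def resolve_na_backend(env=None) -> str:
    """Hypothetical resolver: honour WAVEDIT_NA_BACKEND, else prefer NATTEN when installed."""
    env = os.environ if env is None else env
    choice = env.get("WAVEDIT_NA_BACKEND", "auto").lower()
    if choice not in VALID_BACKENDS:
        raise ValueError(f"WAVEDIT_NA_BACKEND must be one of {VALID_BACKENDS}, got {choice!r}")
    if choice == "auto":
        # Assumption: "auto" falls back to pure PyTorch when the natten package is absent.
        has_natten = importlib.util.find_spec("natten") is not None
        return "natten" if has_natten else "torch"
    return choice
```

Forcing the fallback for a debugging session would then amount to setting WAVEDIT_NA_BACKEND=torch in the environment before launching.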
configs/                  One YAML per experiment (cfm, rectified, ot_fm)
train.sh                  bash train.sh [config.yaml] -> launches training
generate.sh               bash generate.sh <ckpt> [outdir] -> generates samples
scripts/
  train.py                config-driven training entry point
  generate.py             generation (specific condition sets or linear interpolation)
  prepare_metadata.py     build the metadata CSV from NIfTI folders
tools/
  slim_checkpoint.py      strip optimiser state for release/inference
wavedit/
  config.py               typed config loaded from YAML
  data/                   unified dataset (CSV / filename), augmentation, collation
  wavelets/               differentiable 3D Haar DWT/IDWT
  models/                 WaveletFlowMatching, DiT3D backbone, Morpheus, sampling, hdit/
  training/               Trainer + checkpoint I/O
  generation/             sample generation
  evaluation/             metrics + W&B visualisation
  utils/                  logging + seeding
See data/README.md. In short, build a catalog once:
python scripts/prepare_metadata.py --input-dirs /path/to/scans --output-csv ./data/dataset.csv
Then point data.metadata_csv in your config at it. Raw scans and catalogs are
git-ignored and must be obtained from the original dataset providers.
Edit a config (data paths, architecture, hyper-parameters) and launch:
bash train.sh configs/cfm.yaml
Or run the entry point directly:
PYTHONPATH=. python scripts/train.py configs/cfm.yaml
Each run writes to <checkpoint_dir>/<run_name>/: best.pth, last.pth, a copy of the
resolved config.yaml, and logs. Set logging.wandb: true for W&B metrics and
visualisations. Switch the objective with model.flow (cfm | rectified | ot_fm).
A variant that differs from a trained one only in patch size does not need to train from
scratch. scripts/weight_inheritance.py hands almost all of a trained checkpoint's
weights to the new model: the HDiT body transfers 1:1, and only the two patch projections
are resized to the new token grid with a pseudo-inverse patch resize. A coarse-to-fine
warm start therefore begins already in distribution and converges far faster than a from-scratch run.
# Example: warm-start FinePatch (4x4) from a trained Base (8x8)
python scripts/weight_inheritance.py \
--donor checkpoints/WaveDiT_CFM_Base/best.pth \
--config configs/cfm_FinePatch.yaml \
--output checkpoints/WaveDiT_CFM_FinePatch/last.pth
bash train.sh configs/cfm_FinePatch.yaml   # resumes at epoch 0 with the inherited weights
The released finest variant, WaveDiT-FinePatch2 (patch 2×2), was produced this way,
warm-started from WaveDiT-FinePatch (patch 4×4), which cut its training time drastically
versus training from scratch while keeping very high sample quality.
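The idea behind a pseudo-inverse patch resize can be sketched in NumPy. This is not the code in scripts/weight_inheritance.py: it assumes a FlexiViT-style rule, a single-channel 2D kernel, and linear interpolation as the resize operator, and every function name is hypothetical.

```python
import numpy as np

def resize_matrix_1d(old: int, new: int) -> np.ndarray:
    """Matrix B (new x old) that linearly resamples a length-`old` signal to length `new`."""
    eye = np.eye(old)
    positions = np.linspace(0.0, old - 1.0, new)
    # Column i of B is the resampled i-th basis vector.
    return np.stack(
        [np.interp(positions, np.arange(old), eye[i]) for i in range(old)], axis=1
    )

def pi_resize_patch_weight(weight: np.ndarray, new: int) -> np.ndarray:
    """Resize an (old, old) patch-projection kernel to (new, new) via the pseudo-inverse rule."""
    old = weight.shape[0]
    b1 = resize_matrix_1d(old, new)
    b2 = np.kron(b1, b1)  # separable 2D resize acting on row-major flattened patches
    # Choose w_new so that <B x, w_new> matches <x, w> as closely as possible.
    new_flat = np.linalg.pinv(b2.T) @ weight.reshape(-1)
    return new_flat.reshape(new, new)
```

When the patch grows, the match between old and new token values is exact. When it shrinks, as in the 8×8 to 4×4 warm start described above, the rule gives a least-squares approximation, which would explain why a short period of further training is still needed.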
All variants are published on the Hugging Face Hub at
danesed/WaveDiT. Each .pth is self-contained
(architecture and condition metadata embedded), so generation needs only the file.
| Model | Variant | Params | Full-res inference VRAM¹ | Download |
|---|---|---|---|---|
| Base | patch 8×8 (baseline) | 142M | ~3.1 GB (runs from 4 GB) | WaveDiT-Base.pth |
| FinePatch | patch 4×4 | 142M | ~8.4 GB (runs from 10 GB) | WaveDiT-FinePatch.pth |
| FinePatch2 | patch 2×2, warm-started | 142M | ~27 GB (runs from 32 GB) | WaveDiT-FinePatch2.pth |
| Deep | depth 4/4 | 262M | ~3.1 GB (runs from 4 GB) | WaveDiT-Deep.pth |
| Wide | width 2048, d_ff 8192 | 506M | ~5.6 GB (runs from 8 GB) | WaveDiT-Wide.pth |
¹ Peak VRAM for full-resolution (224³) generation at batch size 1 in bf16 with a 10-step Heun
sampler, measured with torch.cuda.max_memory_reserved. The HDiT backbone scales across
hardware budgets because patch size, width and depth are config knobs over a compact wavelet
representation: full-resolution inference runs on GPUs from 4 GB
upward (Base), and the same configs scale training down to modest GPUs by adjusting
batch size or variant. No high-end accelerator is required to use the models.
WaveDiT-FinePatch2 was trained by warm start (weight inheritance) from
WaveDiT-FinePatch, not from scratch (see Warm-start a new variant),
which cut its training time drastically while reaching the finest 2×2 token grid at very
high sample quality.
from huggingface_hub import hf_hub_download
ckpt = hf_hub_download("danesed/WaveDiT", "WaveDiT-FinePatch2.pth")
Checkpoints are self-contained (they embed the config and condition metadata), so generation needs only the checkpoint and your sampling choices.
# Specific condition sets (N samples each)
# NOTE: global flags (--cfg-scale, --num-flow-steps, --sampler, --save-size, ...) go BEFORE the subcommand.
PYTHONPATH=. python scripts/generate.py checkpoints/WaveDiT_CFM/best.pth out/ \
--cfg-scale 1.5 --num-flow-steps 10 --sampler heun --save-size 182 218 182 \
specific --conditions "age=45.0" "age=70.5" --num-samples 10
# Linearly interpolate one condition (one sample per step)
PYTHONPATH=. python scripts/generate.py checkpoints/WaveDiT_CFM/best.pth out/ \
linear --condition age --min 6 --max 95 --num 100
Or use the launcher: bash generate.sh checkpoints/WaveDiT_CFM/best.pth.
| Argument | Meaning |
|---|---|
| --cfg-scale | Classifier-free guidance scale (1.0 = none). |
| --num-flow-steps | ODE integration steps (overrides the checkpoint default). |
| --sampler | heun (2nd order) or euler. |
| --morpheus-scale | Uncertainty-guidance scale (0 disables it). |
| --save-size | Center-crop saved volumes to D H W (default: full model output). |
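To show how the guidance scale and the sampler choice interact with the step count, here is a minimal, framework-free sketch. It is not the repository's sampler: the function names are hypothetical, time is assumed to run from 0 to 1, and the guidance formula is the standard classifier-free combination, which matches the table's note that a scale of 1.0 means no guidance.

```python
import math
from typing import Callable

def guided_velocity(v_cond: float, v_uncond: float, cfg_scale: float) -> float:
    """Classifier-free guidance: a scale of 1.0 returns the conditional prediction unchanged."""
    return v_uncond + cfg_scale * (v_cond - v_uncond)

def integrate(velocity: Callable[[float, float], float], x: float,
              num_steps: int, sampler: str = "heun") -> float:
    """Integrate dx/dt = velocity(x, t) from t = 0 to t = 1 in num_steps equal steps."""
    h = 1.0 / num_steps
    for i in range(num_steps):
        t = i * h
        k1 = velocity(x, t)
        if sampler == "euler":
            x = x + h * k1                       # first order, one evaluation per step
        elif sampler == "heun":
            k2 = velocity(x + h * k1, t + h)     # slope at the Euler-predicted endpoint
            x = x + 0.5 * h * (k1 + k2)          # second order, two evaluations per step
        else:
            raise ValueError(f"unknown sampler: {sampler!r}")
    return x
```

On the toy ODE dx/dt = -x with 10 steps, Heun lands much closer to the exact value exp(-1) than Euler does, at the cost of two velocity evaluations per step instead of one. The same code works on NumPy arrays if the velocity function accepts them.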
@misc{danese2026waveditdistributionawarewaveletflow,
title={WaveDiT: Distribution-Aware Wavelet Flow Matching for Efficient 3D Brain MRI Synthesis},
author={Danilo Danese and Angela Lombardi and Giuseppe Fasano and Matteo Attimonelli and Tommaso Di Noia},
year={2026},
eprint={2606.08670},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2606.08670},
}
WaveDiT builds on the wavelet-domain analysis and multi-level evaluation protocol of our previous work, FlowLet.
The HDiT backbone is adapted from k-diffusion.
The invertible 3D wavelet transform builds on the great work of WDM.
See LICENSE.
