CLAP HTSAT Multi-Node Scaling Benchmark

Throughput and scaling efficiency measurements for CLAP contrastive audio-text models on the Jupiter HPC cluster, from 1 to 128 nodes (4 to 512 NVIDIA GH200 GPUs). Two model sizes are benchmarked:

  • HTSAT-tiny-Roberta-base (111.20 GFLOPs/sample, BS=616)
  • HTSAT-base-Roberta-base (130.09 GFLOPs/sample, BS=416)

Results Summary

HTSAT-tiny-Roberta-base

| Nodes | GPUs | Samples/s/GPU | Total Samples/s | MFU (%) | Scaling Efficiency (%) |
|-------|------|---------------|-----------------|---------|-------------------------|
| 1 | 4 | 623.7 | 2,495 | 21.0 | 100.0 |
| 2 | 8 | 620.7 | 4,966 | 20.9 | 99.5 |
| 4 | 16 | 616.5 | 9,865 | 20.8 | 98.9 |
| 8 | 32 | 611.6 | 19,572 | 20.6 | 98.1 |
| 16 | 64 | 601.8 | 38,512 | 20.3 | 96.5 |
| 32 | 128 | 586.0 | 75,008 | 19.8 | 94.0 |
| 64 | 256 | 573.3 | 146,767 | 19.3 | 91.9 |
| 128 | 512 | 544.6 | 278,833 | 18.4 | 87.3 |

HTSAT-base-Roberta-base

| Nodes | GPUs | Samples/s/GPU | Total Samples/s | MFU (%) | Scaling Efficiency (%) |
|-------|------|---------------|-----------------|---------|-------------------------|
| 1 | 4 | 538.2 | 2,153 | 21.2 | 100.0 |
| 2 | 8 | 530.5 | 4,244 | 20.9 | 98.6 |
| 4 | 16 | 523.7 | 8,379 | 20.7 | 97.3 |
| 8 | 32 | 521.9 | 16,701 | 20.6 | 96.9 |
| 16 | 64 | 507.8 | 32,502 | 20.0 | 94.4 |
| 32 | 128 | 505.4 | 64,691 | 19.9 | 93.9 |
| 64 | 256 | 485.3 | 124,245 | 19.1 | 90.2 |
| 128 | 512 | 454.8 | 232,876 | 18.0 | 84.5 |

Key Takeaways

Both models scale near-linearly with standard DDP (no gradient accumulation, no torch.compile, no FSDP):

  • HTSAT-tiny: 87.3% scaling efficiency at 512 GPUs (112× speedup on 128× hardware)
  • HTSAT-base: 84.5% scaling efficiency at 512 GPUs (108× speedup on 128× hardware)
  • MFU stays above 18% for both models at all node counts
  • The larger model (HTSAT-base) scales slightly worse because each optimizer step must allreduce a larger gradient payload over DDP (see the back-of-envelope estimate below)
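
A back-of-envelope estimate makes the last point concrete. This is a minimal sketch assuming fp32 gradient buffers and a ring allreduce that moves roughly twice the payload per GPU (both assumptions; neither is measured here), applied to HTSAT-tiny's ~155M parameters:

# Rough DDP allreduce volume for HTSAT-tiny (~30M audio + ~125M text params),
# assuming fp32 gradients and a ring allreduce moving ~2x the payload per GPU
awk 'BEGIN {
  grad_mb = 155e6 * 4 / 1e6            # fp32 gradient buffer, in MB
  printf "%.0f MB gradients, ~%.2f GB moved per GPU per step\n", grad_mb, 2 * grad_mb / 1e3
}'
# -> 620 MB gradients, ~1.24 GB moved per GPU per step

HTSAT-base carries more parameters, so its per-step payload, and hence its allreduce time, is correspondingly larger.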

How to Reproduce

Prerequisites

# Activate the conda environment (Jupiter cluster)
source /e/project1/laionize/gijs/miniforge3/etc/profile.d/conda.sh
conda activate clapv2

# The training code repo
cd /e/project1/laionize/gijs/laionclapv2/clapv2

Quick Start: Run a Single Benchmark

# 1-node benchmark (4 GPUs, ~15 min)
MODEL=HTSAT-tiny-Roberta-base BATCH_SIZE=616 ACCUM_FREQ=1 FWD_GFLOPS=111.20 LABEL=test \
  sbatch --nodes=1 --time=00:15:00 \
    --output=/e/scratch/reformo/gijs/data/benchmarks/slurm_logs/htsat_tiny_test_%j.log \
    benchmark_htsat_scaling.sh

Full Scaling Sweep (1–128 Nodes)

# Submit all 8 node counts. Walltimes are tuned per scale:
# - 1–8 nodes: 15 min (fast init)
# - 16–32 nodes: 25 min (slower init with more ranks)
# - 64–128 nodes: 30 min (HuggingFace model loading across many ranks)

bash submit_all.sh
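
For orientation, the per-scale walltime logic above amounts to roughly the following loop (a hypothetical sketch, shown for HTSAT-tiny only; the actual submit_all.sh in this directory is the source of truth):

# Hypothetical sweep loop; output paths omitted for brevity
for NODES in 1 2 4 8 16 32 64 128; do
  case $NODES in
    1|2|4|8) TIME=00:15:00 ;;
    16|32)   TIME=00:25:00 ;;
    *)       TIME=00:30:00 ;;
  esac
  MODEL=HTSAT-tiny-Roberta-base BATCH_SIZE=616 ACCUM_FREQ=1 FWD_GFLOPS=111.20 LABEL=n$NODES \
    sbatch --nodes=$NODES --time=$TIME benchmark_htsat_scaling.sh
done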

Monitor and Collect Results

# Check job status
squeue -u $USER | grep htsat

# Watch a running job's throughput
tail -f /e/scratch/reformo/gijs/data/benchmarks/slurm_logs/htsat_tiny_*.log | grep "Train Epoch"

# Collect results after all jobs complete
cd /e/project1/laionize/gijs/clapv2
python open_clap_scaling/helpers/compute_throughput_flops.py \
  /e/scratch/reformo/gijs/data/benchmarks/clap_scale_eff_htsat/htsat_tiny_bs616_af1_n1_*/

Configuration Details

Model (HTSAT-tiny-Roberta-base)

| Component | Details |
|-----------|---------|
| Audio encoder | HTSAT-tiny (embed_dim=96, depths=[2,2,6,2], heads=[4,8,16,32]) |
| Text encoder | RoBERTa-base (768-dim, 12 layers) |
| Loss | Standard CLIP contrastive loss |
| Forward GFLOPs | 111.20 per sample |
| Parameters | ~30M (audio) + ~125M (text) |

Training Configuration

| Parameter | Value | Rationale |
|-----------|-------|-----------|
| Batch size / GPU | 616 | Largest multiple of 8 that fits in GPU memory without compile |
| Gradient accumulation | 1 | HTSAT scales well without it; af>1 adds overhead |
| torch.compile | OFF | Avoids compile warmup time and memory overhead |
| Precision | amp_bfloat16 | Mixed precision for throughput on GH200 |
| Data | synthetic-audio | Random waveforms; eliminates I/O variance for clean measurements |
| Workers | 8 | DataLoader worker processes |
| Optimizer | AdamW (lr=3e-4, no warmup, no schedule) | Benchmarking only; step throughput doesn't depend on LR |
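
Put together, the configuration corresponds roughly to an invocation like the one below. This is hypothetical: the flag names are assumed from the upstream open_clip training CLI and may differ in this fork, and the benchmark script wraps the real command in srun:

# Hypothetical invocation; flag names assumed from the open_clip training CLI
python -m open_clip_train.main \
  --model HTSAT-tiny-Roberta-base \
  --batch-size 616 \
  --precision amp_bfloat16 \
  --dataset-type synthetic-audio \
  --workers 8 \
  --lr 3e-4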

Hardware (Jupiter HPC)

| Detail | Value |
|--------|-------|
| GPU | NVIDIA GH200 (95 GB HBM3) |
| Peak BF16 | 989 TFLOPS per GPU |
| GPUs per node | 4 |
| CPUs per task | 72 (ARM) |
| Interconnect | InfiniBand (mlx5) |
| Network | Direct GPU RDMA via NCCL_NET_GDR_LEVEL=PHB |

NCCL Tuning (Applied to All Runs)

# InfiniBand configuration
export NCCL_IB_HCA=mlx5
export NCCL_SOCKET_IFNAME=ib0
export NCCL_NET_GDR_LEVEL=PHB

# Large-scale tuning (helps at 16+ nodes)
export NCCL_MIN_NCHANNELS=16
export NCCL_MAX_NCHANNELS=32
export NCCL_BUFFSIZE=16777216        # 16 MB
export NCCL_ALGO=Tree,Ring
export NCCL_CROSS_NIC=1
export NCCL_IB_QPS_PER_CONNECTION=2

MFU Calculation

MFU = (fwd_GFLOPs × 3 × samples_per_sec_per_gpu) / (peak_TFLOPS × 1000) × 100
    = (111.20 × 3 × sps/gpu) / (989 × 1000) × 100

Scaling Efficiency = sps_per_gpu_at_N_nodes / sps_per_gpu_at_1_node × 100

The factor of 3 accounts for the backward pass costing roughly twice the forward pass: 1× forward + 2× backward FLOPs per training step.
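
As a sanity check, plugging the first HTSAT-tiny table row into the formula reproduces the reported MFU:

# Worked example: HTSAT-tiny at 1 node (623.7 samples/s/GPU)
awk 'BEGIN {
  tflops = 111.20 * 3 * 623.7 / 1000   # achieved TFLOPS per GPU
  printf "%.1f TFLOPS -> %.1f%% MFU\n", tflops, tflops / 989 * 100
}'
# -> 208.1 TFLOPS -> 21.0% MFU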

Files in This Directory

This directory is fully standalone — it contains all source code, configs, and scripts needed to reproduce the benchmark on any SLURM cluster with NVIDIA GPUs.

open_clap_scaling/
├── README.md                          # This file
├── results_htsat_tiny.csv             # HTSAT-tiny results (1–128 nodes)
├── results_htsat_base.csv             # HTSAT-base results (1–128 nodes)
├── benchmark_htsat_scaling.sh         # SLURM benchmark script (parameterized, commented)
├── submit_all.sh                      # One-command sweep: submits 1–128 node jobs
├── requirements.txt                   # Base Python dependencies
├── requirements-training.txt          # Training-specific dependencies
├── helpers/
│   └── compute_throughput_flops.py    # Log parser → throughput, TFLOPS, MFU
└── src/                               # Complete training source code
    ├── open_clip/                     # Model definitions
    │   ├── __init__.py
    │   ├── htsat.py                   # HTSAT audio encoder
    │   ├── model.py                   # CLAP model class
    │   ├── factory.py                 # Model creation
    │   ├── hf_model.py                # HuggingFace text encoder wrapper
    │   ├── hf_configs.py              # HF model configurations
    │   ├── audio.py                   # Audio processing pipeline
    │   ├── loss.py                    # Contrastive loss
    │   ├── tokenizer.py               # Text tokenization
    │   ├── model_configs/             # JSON model configs
    │   │   ├── HTSAT-tiny-Roberta-base.json
    │   │   ├── HTSAT-base-Roberta-base.json
    │   │   └── ...
    │   └── ...                        # Other model variants
    └── open_clip_train/               # Training loop
        ├── __init__.py
        ├── main.py                    # Entry point (python -m open_clip_train.main)
        ├── train.py                   # Training loop with DDP, grad accum, no_sync
        ├── data.py                    # Data loading (webdataset + synthetic)
        ├── distributed.py             # DDP/FSDP setup
        ├── params.py                  # CLI argument parsing
        ├── precision.py               # Mixed precision (amp_bfloat16)
        ├── scheduler.py               # LR scheduling
        ├── profiler.py                # FLOP profiling
        └── zero_shot.py               # Zero-shot evaluation

Adapting to Another Cluster

The benchmark script uses Jupiter-specific paths that need updating for other clusters:

  1. Conda path (line ~52): source /e/project1/laionize/gijs/miniforge3/etc/profile.d/conda.sh
  2. HF cache (line ~77): export HF_HOME="/e/scratch/reformo/gijs/cache"
  3. SLURM account/partition (lines 2-7): --account=reformo, --partition=booster
  4. Tar list (line ~97): Points to a cached tar file list (only needed for the train_data argument format; synthetic data doesn't actually read these files)
  5. Log directory (line ~90): Output path for experiment logs

The NCCL InfiniBand settings (NCCL_IB_HCA, NCCL_SOCKET_IFNAME) may need adjusting for different network hardware.
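
To discover the right values on an unfamiliar cluster, commands along these lines usually suffice (exact tooling varies; ibstat comes from the infiniband-diags package):

# List InfiniBand HCAs (candidates for NCCL_IB_HCA)
ibstat | grep "CA '"
# List network interface names (candidates for NCCL_SOCKET_IFNAME)
ip -o link show | awk -F': ' '{print $2}'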

Installing Dependencies

conda create -n clapv2 python=3.12
conda activate clapv2
pip install -r requirements-training.txt
