Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitmodules

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we need to add it as a submodule

Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[submodule "external/WAVE"]
path = external/WAVE
url = https://github.com/TCL606/WAVE.git
706 changes: 706 additions & 0 deletions MTEB benchmark runbook.md

Large diffs are not rendered by default.

273 changes: 273 additions & 0 deletions MTEB-WAVE-7B.md

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions external/WAVE
Submodule WAVE added at a248b6
3 changes: 2 additions & 1 deletion mteb/abstasks/multilabel_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,8 @@ def _undersample_data_indices(
"""
sample_indices = []
if idxs is None:
idxs = list(np.arange(len(y)))
# plain ints: datasets>=4 lazy Columns reject numpy integer keys
idxs = list(range(len(y)))

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't have problems here with datasets v4

self.np_rng.shuffle(idxs)
label_counter: dict[int, int] = defaultdict(int)
for i in idxs:
Expand Down
522 changes: 522 additions & 0 deletions mteb/models/model_implementations/wave_models.py

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you keep changes only to wave implementation and pyproject?

Large diffs are not rendered by default.

26 changes: 26 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,31 @@ qwen-vl = ["transformers>=4.57.0", "qwen-vl-utils>=0.0.14"]
# omnivinci = ["transformers==4.46.0", "torch==2.8.*", "einops>=0.8.0", "openai-whisper>=20240930", "soundfile>=0.13.1", "opencv-python-headless>=4.0.0", "kaldiio>=2.18.0", "decord>=0.6.0", "librosa>=0.10.0", "beartype>=0.18.0", "accelerate>=1.0.0", "s2wrapper @ git+https://github.com/bfshi/scaling_on_scales", "deepspeed>=0.9.0", "torchvision>=0.15.0", "torchaudio>=2.0.0", "torchcodec==0.7.0", "av>=10.0.0", "numpy>=1.23.0,<2.0.0", "flash-attn>=2.6.3"]
qwen_omni_utils = ["qwen_omni_utils"]
embeddinggemma = ["transformers>=4.56.0"]
# WAVE-7B (tsinghua-ee/WAVE-7B). Versions mirror the upstream repo's requirements.txt
# (https://github.com/TCL606/WAVE), vendored as a submodule at external/WAVE. Also requires the
# `audio` and `video` extras, plus the external BEATs_iter3_plus.pt checkpoint at load time.
wave = [
# Upstream WAVE pins torch==2.6.0 (training repro); inference works on 2.7.1, which is
# required for MTEB's video stack: datasets>=4 (Video -> torchcodec VideoDecoder) needs a
# torchcodec whose torch pairing is >=2.7 (datasets 4.8+ even needs torch>=2.8).
"torch==2.7.1",
"torchvision==0.22.1",
"torchaudio==2.7.1",
"torchcodec==0.4.*", # torch 2.7.1 pairing
"datasets[audio]>=4,<4.8",
"transformers==4.51.3",
"sentence-transformers<5", # ST>=5 hard-imports torchcodec at module top
"liger_kernel==0.5.10",
"decord>=0.6.0",
"soundfile>=0.13.1",
"ffmpeg-python>=0.2.0",
"peft>=0.11.0",
"accelerate>=1.7.0",
"setuptools", # triton needs it at import
# flash-attn is intentionally NOT listed: cold installs would attempt a source build
# (needs torch + nvcc at build time). Install the matching prebuilt wheel afterwards —
# scripts/setup_wave_env.sh does this automatically.
]
# ebind = ["ebind @ git+https://github.com/encord-team/ebind@7909701e9372353ca678b9515f9f61cf87c83c71", "soundfile>=0.13.1", "transformers>=4.57.1,!=5.0.0"]


Expand Down Expand Up @@ -457,6 +482,7 @@ conflicts = [
{ extra = "muq" },
{ extra = "transformers-v4" },
{ extra = "qwen-vl" },
{ extra = "wave" }, # pins transformers==4.51.3
# { extra = "omnivinci" },
{ extra = "embeddinggemma" },
# { extra = "ebind" },
Expand Down
127 changes: 127 additions & 0 deletions scripts/setup_wave_env.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
#!/bin/bash
# Set up everything needed to evaluate tsinghua-ee/WAVE-7B with MTEB on a new
# (internet-connected) cluster or login node. See MTEB-WAVE-7B.md for background.
#
# Usage:
# bash scripts/setup_wave_env.sh [WORK_DIR] [--prefetch-model] [--prefetch-data TASK1,TASK2,...]
#
# WORK_DIR fast scratch workspace (default: /expscratch/$USER/wave-mteb,
# else $SCRATCH/wave-mteb, else $HOME/wave-mteb)
# --prefetch-model download the 18 GB WAVE-7B snapshot now (otherwise on first get_model)
# --prefetch-data warm the HF cache for the given MTEB task names before GPU time
#
# After it finishes it prints the exports needed in job scripts.
set -euo pipefail

REPO_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
WAVE_REVISION="6d42651d34bf1a7d83d5779397d6ce0316a4cf4f"
FLASH_ATTN_RELEASE="v2.7.4.post1"

# ---- args -------------------------------------------------------------------
WORK=""
PREFETCH_MODEL=0
PREFETCH_DATA=""
while [[ $# -gt 0 ]]; do
case "$1" in
--prefetch-model) PREFETCH_MODEL=1 ;;
--prefetch-data) PREFETCH_DATA="$2"; shift ;;
*) WORK="$1" ;;
esac
shift
done
if [[ -z "$WORK" ]]; then
if [[ -d "/expscratch/$USER" ]]; then WORK="/expscratch/$USER/wave-mteb"
elif [[ -n "${SCRATCH:-}" ]]; then WORK="$SCRATCH/wave-mteb"
else WORK="$HOME/wave-mteb"; fi
fi
echo "==> workspace: $WORK"
mkdir -p "$WORK/beats" "$WORK/logs" "$WORK/.cache"

# ---- caches off /home -------------------------------------------------------
export XDG_CACHE_HOME="$WORK/.cache"
export HF_HOME="$WORK/.cache/huggingface"
export UV_CACHE_DIR="$WORK/.cache/uv"
export PIP_CACHE_DIR="$WORK/.cache/pip"
mkdir -p "$HF_HOME"

# ---- code: submodule --------------------------------------------------------
cd "$REPO_DIR"
git submodule update --init --recursive external/WAVE
test -f external/WAVE/qwenvl/model/qwen2_5_omni/modeling_qwen2_5_omni.py

# ---- env --------------------------------------------------------------------
if ! command -v uv >/dev/null; then
curl -LsSf https://astral.sh/uv/install.sh | sh
export PATH="$HOME/.local/bin:$PATH"
fi
if [[ ! -x "$WORK/.venv/bin/python" ]]; then
uv venv "$WORK/.venv" --python 3.10
fi
# shellcheck disable=SC1091
source "$WORK/.venv/bin/activate"
uv pip install -e ".[wave,audio,video]"

# ---- flash-attn: matching prebuilt wheel (no source builds) ------------------
SPEC=$(python - <<'PY'
import sys, torch
py = f"cp{sys.version_info.major}{sys.version_info.minor}"
tch = ".".join(torch.__version__.split("+")[0].split(".")[:2])
cu = f"cu{torch.version.cuda.split('.')[0]}" if torch.version.cuda else "cpu"
abi = "TRUE" if torch._C._GLIBCXX_USE_CXX11_ABI else "FALSE"
print(f"{cu}torch{tch}cxx11abi{abi}-{py}")
PY
)
CU_TORCH_ABI="${SPEC%-*}"
PYTAG="${SPEC#*-}"
WHEEL="flash_attn-${FLASH_ATTN_RELEASE#v}+${CU_TORCH_ABI}-${PYTAG}-${PYTAG}-linux_x86_64.whl"
echo "==> flash-attn wheel: $WHEEL"
uv pip install --no-build-isolation \
"https://github.com/Dao-AILab/flash-attention/releases/download/${FLASH_ATTN_RELEASE}/${WHEEL}"

# ---- BEATs checkpoint (required at model load; NOT auto-downloaded) ----------
BEATS="$WORK/beats/BEATs_iter3_plus_AS2M.pt"
if [[ ! -s "$BEATS" ]]; then
curl -fL --retry 3 -o "$BEATS" \
"https://huggingface.co/datasets/Bencr/beats-checkpoints/resolve/main/BEATs_iter3_plus_AS2M.pt"
fi
export WAVE_BEATS_PATH="$BEATS"

# ---- optional prefetches ------------------------------------------------------
if [[ "$PREFETCH_MODEL" == 1 ]]; then
python - <<PY
from huggingface_hub import snapshot_download
p = snapshot_download("tsinghua-ee/WAVE-7B", revision="$WAVE_REVISION")
print("model snapshot:", p)
PY
fi
if [[ -n "$PREFETCH_DATA" ]]; then
python - <<PY
import mteb
for name in "$PREFETCH_DATA".split(","):
task = mteb.get_tasks(tasks=[name.strip()])[0]
print("prefetching:", name)
task.load_data()
PY
fi

# ---- preflight ----------------------------------------------------------------
echo "==> preflight"
python -c "import mteb; m = mteb.get_model_meta('tsinghua-ee/WAVE-7B'); print('registry OK:', m.name, m.embed_dim)"
python -c "from torchcodec.decoders import VideoDecoder" 2>/dev/null \
&& echo "torchcodec OK (ffmpeg libs found)" \
|| echo "WARNING: torchcodec cannot find FFmpeg libs - load an ffmpeg 4-7 module or install ffmpeg (needed for video tasks)"
command -v nvidia-smi >/dev/null \
&& nvidia-smi --query-gpu=name --format=csv,noheader | head -1 \
|| echo "NOTE: no GPU on this node - run evaluations via your scheduler (WAVE needs bf16: A100/L40S/H100, not V100)"

cat <<EOF

Done. In job scripts, set:
source $WORK/.venv/bin/activate
export HF_HOME=$HF_HOME
export WAVE_BEATS_PATH=$BEATS
# plus ffmpeg libs for video tasks, e.g. on HLTCOE: module load ffmpeg/6.0.1

Smoke test (GPU node):
python -c "import mteb; m = mteb.get_model('tsinghua-ee/WAVE-7B'); print(type(m).__name__)"
EOF
Loading