10 changes: 10 additions & 0 deletions .github/workflows/test.yml
@@ -26,6 +26,16 @@ jobs:
           python -m pip install --upgrade pip
           pip install -e ".[all]"

+      # Hugging Face Hub rate-limits unauthenticated requests from shared runner IPs (HTTP 429).
+      # With a warm cache, huggingface_hub falls back to cached files when a HEAD request fails.
+      - name: Cache Hugging Face models
+        uses: actions/cache@v4
+        with:
+          path: ~/.cache/huggingface
+          key: huggingface-${{ runner.os }}-${{ hashFiles('tests/**/*.py') }}
+          restore-keys: |
+            huggingface-${{ runner.os }}-
+
       - name: Run tests with pytest (except "slow" tests)
         run: |
           pytest -m "not slow"
3 changes: 3 additions & 0 deletions .gitignore
@@ -7,6 +7,9 @@
 /models/
 /outputs/

+# Ignore the uv.lock generated when running `uv run` / `uv sync`.
+uv.lock
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
 ### Added

 - Add ColQwen3.5 and BiQwen3.5 support (model + processor). Pretrained checkpoint: [athrael-soju/colqwen3.5-4.5B-v3](https://huggingface.co/athrael-soju/colqwen3.5-4.5B-v3).
+- Add optional `[lik]` extra (`late-interaction-kernels>=0.4.1,<0.5.0`) that routes `score_multi_vector` and the five ColBERT losses through the fused Triton MaxSim kernel on CUDA Ampere+ / Apple Silicon, with a transparent torch fallback elsewhere. `COLPALI_SCORES_BACKEND` selects the backend (mirrors PyLate's `PYLATE_SCORES_BACKEND`): `auto` (default), `torch` (force the fallback), or `lik` (strict; raises when the kernel cannot run).

 ### Changed

@@ -18,6 +19,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
 ### Fixed

 - Fix ModernVBERT wrappers to rely on the upstream Hugging Face implementation and keep checkpoint key conversion mapping working with current Transformers v5 loading.
+- Fix `ContrastiveTrainer._get_train_sampler` to accept the dataset argument that Transformers v5 passes positionally (single-dataset training crashed with a `TypeError` at dataloader build).
+- Fix `ContrastiveTrainer` to prime `query_prefix`/`pos_prefix`/`neg_prefix` from the collator in `__init__` (single-dataset training crashed with an `AttributeError` in `compute_loss`, as only the multi-dataset path set them).

 ## [0.3.14] - 2026-02-24

10 changes: 10 additions & 0 deletions README.md
@@ -61,6 +61,16 @@ Mac users using MPS with the ColQwen models have reported errors with torch 2.6.
 > [!WARNING]
 > For ColPali versions above v1.0, make sure to install the `colpali-engine` package from source or with a version above v0.2.0.

+### Fused MaxSim kernels (optional)
+
+The optional `[lik]` extra installs [`late-interaction-kernels`](https://github.com/hcompai/late-interaction-kernels), a fused Triton MaxSim kernel used automatically on CUDA Ampere+ / Apple Silicon for scoring and the ColBERT losses. It avoids materializing the `[B, B, Lq, Ld]` score tensor, whose memory cost grows quadratically with the batch size and can become the allocation that caps it: in our ColQwen2 + LoRA benchmark on an 80 GB H100, this raised the largest trainable batch size from 64 to 128, with unchanged end-to-end throughput. Full benchmark results in [illuin-tech/colpali#412](https://github.com/illuin-tech/colpali/pull/412):
+
+```bash
+pip install "colpali-engine[lik]"
+```
+
+The `COLPALI_SCORES_BACKEND` environment variable selects the backend (mirrors PyLate's `PYLATE_SCORES_BACKEND`): `auto` (default) uses the kernel when eligible and silently falls back to torch, `torch` forces the pure-torch reference, and `lik` requires the kernel and raises if it cannot run.
+
 ## Development docs

 - [Adding a new model family](docs/add_model_family.md)
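
As a back-of-the-envelope check on the README claim above that the `[B, B, Lq, Ld]` score tensor grows quadratically with the batch size, the sketch below computes its dense size. The sequence lengths (`lq=32`, `ld=768`) and the 2-byte element size are illustrative assumptions, not values taken from the benchmark, and the helper name `score_tensor_bytes` is hypothetical.

```python
def score_tensor_bytes(batch_size: int, lq: int, ld: int, bytes_per_elem: int = 2) -> int:
    """Bytes needed to materialize one dense [B, B, Lq, Ld] score tensor."""
    return batch_size * batch_size * lq * ld * bytes_per_elem


if __name__ == "__main__":
    # Assumed shapes: 32 query tokens, 768 document tokens, 2 bytes per element.
    for b in (32, 64, 128):
        mib = score_tensor_bytes(b, lq=32, ld=768) / 2**20
        print(f"B={b:>3}: {mib:,.0f} MiB")
```

Under these assumptions, each doubling of `B` multiplies the footprint by four; autograd may keep more than one buffer of this shape alive, so the real peak can be higher.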
55 changes: 34 additions & 21 deletions colpali_engine/loss/late_interaction_losses.py
@@ -2,6 +2,8 @@
 import torch.nn.functional as F  # noqa: N812
 from torch.nn import CrossEntropyLoss

+from colpali_engine.utils.maxsim import maxsim_inbatch, maxsim_kd
+

 class ColbertModule(torch.nn.Module):
     """
@@ -90,6 +92,23 @@ def _aggregate(
             return self._smooth_max(scores_raw, dim=dim_max).sum(dim=dim_sum)
         return scores_raw.amax(dim=dim_max).sum(dim=dim_sum)

+    def _inbatch_scores(self, query_embeddings: torch.Tensor, doc_embeddings: torch.Tensor) -> torch.Tensor:
+        """
+        Compute the in-batch MaxSim score matrix and apply optional length normalization.
+
+        Routes through the fused late-interaction kernel when ``use_smooth_max`` is False;
+        smooth-max keeps the logsumexp path since the kernel only exposes hard max.
+        """
+        if self.use_smooth_max:
+            raw = torch.einsum("bnd,csd->bcns", query_embeddings, doc_embeddings)
+            scores = self._aggregate(raw, True, dim_max=3, dim_sum=2)
+        else:
+            scores = maxsim_inbatch(query_embeddings, doc_embeddings)
+        if self.normalize_scores:
+            lengths = (query_embeddings[:, :, 0] != 0).sum(dim=1)
+            scores = self._apply_normalization(scores, lengths)
+        return scores
+
     def _filter_high_negatives(self, scores: torch.Tensor, pos_idx: torch.Tensor) -> None:
         """
         Down-weight negatives whose score exceeds a fraction of the positive score.
@@ -149,11 +168,7 @@ def forward(self, query_embeddings: torch.Tensor, doc_embeddings: torch.Tensor,
         Returns:
             Tensor: Scalar loss value.
         """
-        lengths = (query_embeddings[:, :, 0] != 0).sum(dim=1)
-        raw = torch.einsum("bnd,csd->bcns", query_embeddings, doc_embeddings)
-        scores = self._aggregate(raw, self.use_smooth_max, dim_max=3, dim_sum=2)
-        if self.normalize_scores:
-            scores = self._apply_normalization(scores, lengths)
+        scores = self._inbatch_scores(query_embeddings, doc_embeddings)

         batch_size = scores.size(0)
         idx, pos_idx = self._get_idx(batch_size, offset, scores.device)
@@ -235,9 +250,13 @@ def forward(
         pos_raw = torch.einsum(
             "bnd,bsd->bns", query_embeddings, doc_embeddings[offset : offset + neg_doc_embeddings.size(0)]
         )
-        neg_raw = torch.einsum("bnd,blsd->blns", query_embeddings, neg_doc_embeddings)
         pos_scores = self._aggregate(pos_raw, self.use_smooth_max, dim_max=2, dim_sum=1)
-        neg_scores = self._aggregate(neg_raw, self.use_smooth_max, dim_max=3, dim_sum=2)
+        if self.use_smooth_max:
+            # Smooth-max keeps the logsumexp path since the kernel only exposes hard max.
+            neg_raw = torch.einsum("bnd,blsd->blns", query_embeddings, neg_doc_embeddings)
+            neg_scores = self._aggregate(neg_raw, True, dim_max=3, dim_sum=2)
+        else:
+            neg_scores = maxsim_kd(query_embeddings, neg_doc_embeddings)

         if self.normalize_scores:
             pos_scores = self._apply_normalization(pos_scores, lengths)
@@ -293,12 +312,7 @@ def forward(self, query_embeddings: torch.Tensor, doc_embeddings: torch.Tensor,
         Returns:
             Tensor: Scalar loss value.
         """
-        lengths = (query_embeddings[:, :, 0] != 0).sum(dim=1)
-        raw = torch.einsum("bnd,csd->bcns", query_embeddings, doc_embeddings)
-        scores = self._aggregate(raw, self.use_smooth_max, dim_max=3, dim_sum=2)
-
-        if self.normalize_scores:
-            scores = self._apply_normalization(scores, lengths)
+        scores = self._inbatch_scores(query_embeddings, doc_embeddings)

         batch_size = scores.size(0)
         idx, pos_idx = self._get_idx(batch_size, offset, scores.device)
@@ -381,9 +395,13 @@ def forward(
         pos_raw = torch.einsum(
             "bnd,bld->bnl", query_embeddings, doc_embeddings[offset : offset + query_embeddings.size(0)]
         )
-        neg_raw = torch.einsum("bnd,bsld->bsnl", query_embeddings, neg_doc_embeddings)  # B x Nneg x Nq x Lneg
         pos_scores = self._aggregate(pos_raw, self.use_smooth_max, dim_max=2, dim_sum=1)
-        neg_scores = self._aggregate(neg_raw, self.use_smooth_max, dim_max=3, dim_sum=2)
+        if self.use_smooth_max:
+            # Smooth-max keeps the logsumexp path since the kernel only exposes hard max.
+            neg_raw = torch.einsum("bnd,bsld->bsnl", query_embeddings, neg_doc_embeddings)  # B x Nneg x Nq x Lneg
+            neg_scores = self._aggregate(neg_raw, True, dim_max=3, dim_sum=2)
+        else:
+            neg_scores = maxsim_kd(query_embeddings, neg_doc_embeddings)

         if self.normalize_scores:
             pos_scores = self._apply_normalization(pos_scores, lengths)
@@ -440,12 +458,7 @@ def forward(self, query_embeddings: torch.Tensor, doc_embeddings: torch.Tensor,
             Tensor: Scalar loss value.
         """

-        lengths = (query_embeddings[:, :, 0] != 0).sum(dim=1)
-        raw = torch.einsum("bnd,csd->bcns", query_embeddings, doc_embeddings)
-        scores = self._aggregate(raw, self.use_smooth_max, dim_max=3, dim_sum=2)
-
-        if self.normalize_scores:
-            scores = self._apply_normalization(scores, lengths)
+        scores = self._inbatch_scores(query_embeddings, doc_embeddings)

         batch_size = scores.size(0)
         idx, pos_idx = self._get_idx(batch_size, offset, scores.device)
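
The loss changes above keep the einsum path whenever `use_smooth_max` is set, because the fused kernel only exposes a hard max. As a rough illustration of why the two are not interchangeable, the sketch below compares a hard max with a temperature-scaled log-sum-exp. The body of `_smooth_max` is not shown in this diff, so the form `tau * logsumexp(x / tau)`, the default `tau`, and the function name `smooth_max` are assumptions used only for illustration.

```python
import math
from typing import Sequence


def smooth_max(values: Sequence[float], tau: float = 0.1) -> float:
    """Temperature-scaled log-sum-exp: a differentiable upper bound on max(values)."""
    m = max(values)  # subtract the max before exponentiating for numerical stability
    return m + tau * math.log(sum(math.exp((v - m) / tau) for v in values))


if __name__ == "__main__":
    sims = [0.2, 0.9, 0.85]  # assumed token-level similarities for one query token
    print("hard max  :", max(sims))
    print("smooth max:", smooth_max(sims, tau=0.1))
```

The smooth value always lies between `max(values)` and `max(values) + tau * log(n)`, and it approaches the hard max as `tau` shrinks, so a kernel that returns only the hard max cannot reproduce it.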
12 changes: 11 additions & 1 deletion colpali_engine/trainer/contrastive_trainer.py
@@ -52,6 +52,13 @@ def __init__(self, loss_func, is_vision_model, compute_symetric_loss=False, *arg
         self.train_dataset_list = train_dataset_list
         self.eval_dataset_list = eval_dataset_list
         self.compute_symetric_loss = compute_symetric_loss
+        # Prime the prefixes from the collator. The multi-dataset path also
+        # sets these inside get_train_dataloader, but the single-dataset path
+        # never does, and compute_loss reads them on every step.
+        collator = kwargs.get("data_collator")
+        self.query_prefix = getattr(collator, "query_prefix", "query_")
+        self.pos_prefix = getattr(collator, "pos_doc_prefix", "doc_")
+        self.neg_prefix = getattr(collator, "neg_doc_prefix", "neg_doc_")

     def get_train_dataloader(self) -> DataLoader:
         """
@@ -116,8 +123,11 @@ def get_train_dataloader(self) -> DataLoader:

         return self.accelerator.prepare(dataloader)

-    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
+    def _get_train_sampler(self, dataset=None) -> Optional[torch.utils.data.Sampler]:
         if self.train_dataset_list is None:
+            # transformers 5.x passes the dataset positionally; older versions do not.
+            if dataset is not None:
+                return super()._get_train_sampler(dataset)
             return super()._get_train_sampler()

         # Use SingleDatasetBatchSampler to ensure that each dataset in the list is sampled independently
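
The `_get_train_sampler` change above makes one override work against two parent signatures. A minimal, dependency-free sketch of that pattern follows; `OldBase`, `NewBase`, and `make_trainer` are hypothetical stand-ins for illustration and do not reproduce the actual Transformers classes.

```python
from typing import Optional


class OldBase:
    """Stand-in for a parent whose hook takes no dataset argument."""

    def _get_train_sampler(self) -> str:
        return "old-sampler"


class NewBase:
    """Stand-in for a parent whose hook receives the dataset."""

    def _get_train_sampler(self, dataset: Optional[list] = None) -> str:
        return f"new-sampler({len(dataset or [])} items)"


def make_trainer(base: type) -> type:
    """Build a subclass of `base` whose override tolerates both call conventions."""

    class Trainer(base):
        def _get_train_sampler(self, dataset=None):
            # Forward the dataset only when the caller actually supplied one,
            # so the call also works against a parent that accepts no argument.
            if dataset is not None:
                return super()._get_train_sampler(dataset)
            return super()._get_train_sampler()

    return Trainer


if __name__ == "__main__":
    print(make_trainer(OldBase)()._get_train_sampler())
    print(make_trainer(NewBase)()._get_train_sampler([1, 2, 3]))
```

Defaulting `dataset` to `None` avoids inspecting the installed library version, at the cost of treating an explicit `None` the same as an omitted argument.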
75 changes: 75 additions & 0 deletions colpali_engine/utils/_lik_backend.py
@@ -0,0 +1,75 @@
+"""``late-interaction-kernels`` (LIK) backend for MaxSim scoring.
+
+Lazily imported by ``colpali_engine.utils.maxsim`` so the dependency stays optional.
+Mirrors PyLate's ``_lik_backend`` (lightonai/pylate#222): each entry point validates its
+inputs and raises ``LIKUnsupportedError`` when the kernel cannot run, so the dispatcher can
+fall back silently in ``auto`` mode and re-raise in strict ``lik`` mode.
+"""
+
+import torch
+
+_IMPORT_OK: bool | None = None
+
+
+class LIKUnsupportedError(Exception):
+    """LIK cannot handle this call. The only exception ``auto`` mode swallows: real kernel errors propagate."""
+
+
+def _lik_device(query: torch.Tensor, doc: torch.Tensor) -> str:
+    """Validate the inputs against the kernel's constraints; return ``"cuda"`` or ``"mps"``."""
+    if not is_available():
+        raise LIKUnsupportedError(
+            "late-interaction-kernels is not installed (pip install 'colpali-engine[lik]') "
+            "or no CUDA/MPS accelerator is present."
+        )
+    if query.device != doc.device:
+        raise LIKUnsupportedError(f"query and doc are on different devices ({query.device} vs {doc.device}).")
+    # Below the Triton tile/MMA floor on the embedding dim the kernel can't beat einsum.
+    if query.shape[-1] < 8:
+        raise LIKUnsupportedError(f"embedding dim {query.shape[-1]} is below the kernel's tile floor (8).")
+    if query.is_cuda:
+        if torch.cuda.get_device_capability(query.device)[0] < 8:  # bf16 tensor cores need Ampere+
+            raise LIKUnsupportedError("the CUDA kernel requires compute capability >= 8 (Ampere or newer).")
+        return "cuda"
+    if query.device.type == "mps":
+        return "mps"
+    raise LIKUnsupportedError(f"unsupported device type {query.device.type!r} (CUDA or MPS required).")
+
+
+def is_available() -> bool:
+    """Memoized: ``late_interaction_kernels`` imports and a CUDA/MPS accelerator is present."""
+    global _IMPORT_OK
+    if _IMPORT_OK is not None:
+        return _IMPORT_OK
+    try:
+        import late_interaction_kernels  # noqa: F401
+    except ImportError:
+        _IMPORT_OK = False
+        return _IMPORT_OK
+    _IMPORT_OK = torch.cuda.is_available() or torch.backends.mps.is_available()
+    return _IMPORT_OK
+
+
+def maxsim_inbatch_lik(query: torch.Tensor, doc: torch.Tensor) -> torch.Tensor:
+    """In-batch MaxSim through the fused kernel. Raises ``LIKUnsupportedError`` when ineligible."""
+    device = _lik_device(query, doc)
+    if device == "cuda":
+        from late_interaction_kernels.autograd import maxsim
+
+        return maxsim(query, doc)
+    from late_interaction_kernels.mps import maxsim_mps
+
+    return maxsim_mps(query, doc, normalize=False)
+
+
+def maxsim_kd_lik(query: torch.Tensor, doc: torch.Tensor) -> torch.Tensor:
+    """Per-query candidate MaxSim through the fused kernel. Raises ``LIKUnsupportedError`` when ineligible."""
+    if _lik_device(query, doc) != "cuda":
+        raise LIKUnsupportedError("the KD (candidate-list) layout has no MPS kernel.")
+    if doc.dim() != 4:
+        raise LIKUnsupportedError(f"KD layout expects doc [B, N, L, d], got {doc.dim()}-D.")
+
+    from late_interaction_kernels.autograd import maxsim
+
+    # A 4-D doc triggers the kernel's kd_layout path (one fused launch, no packing).
+    return maxsim(query, doc)
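
The module docstring above describes a contract in which each entry point raises `LIKUnsupportedError` so that a dispatcher can fall back in `auto` mode and re-raise in strict `lik` mode. A minimal, dependency-free sketch of that contract follows; `UnsupportedError`, `fast_sum`, `reference_sum`, `dispatch_sum`, and the `DEMO_BACKEND` variable are hypothetical stand-ins and are not part of `colpali_engine`.

```python
import os
from typing import List


class UnsupportedError(Exception):
    """Hypothetical analogue of LIKUnsupportedError: the fast path declines this call."""


def fast_sum(values: List[float]) -> float:
    """Stand-in fast path that only accepts inputs with at least 8 elements."""
    if len(values) < 8:
        raise UnsupportedError(f"need at least 8 values, got {len(values)}")
    return float(sum(values))


def reference_sum(values: List[float]) -> float:
    """Stand-in reference path that handles every input."""
    total = 0.0
    for value in values:
        total += value
    return total


def dispatch_sum(values: List[float]) -> float:
    """Pick a path from DEMO_BACKEND: 'auto', 'reference', or 'fast' (strict)."""
    backend = os.environ.get("DEMO_BACKEND", "auto").lower()
    if backend not in ("auto", "reference", "fast"):
        raise ValueError(f"DEMO_BACKEND must be auto, reference, or fast, got {backend!r}.")
    if backend != "reference":
        try:
            return fast_sum(values)
        except UnsupportedError:
            if backend == "fast":
                raise  # strict mode surfaces the refusal to the caller
            # auto mode: fall through to the reference path below
    return reference_sum(values)


if __name__ == "__main__":
    os.environ["DEMO_BACKEND"] = "auto"
    print(dispatch_sum([1.0, 2.0, 3.0]))  # too short for the fast path, so the reference path runs
```

Only the dedicated "unsupported" exception is caught, so any other error raised inside the fast path still propagates instead of being masked by the fallback.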
64 changes: 64 additions & 0 deletions colpali_engine/utils/maxsim.py
@@ -0,0 +1,64 @@
+"""MaxSim dispatch: route late-interaction scoring through ``late-interaction-kernels`` (LIK)
+or the pure-torch einsum reference, selected by ``COLPALI_SCORES_BACKEND``.
+
+``maxsim_inbatch`` scores the in-batch ``[B, Lq, d] x [B, Ld, d]`` grid; ``maxsim_kd`` scores
+the per-query candidate layout ``[B, N, Ld, d]`` used by the negative-doc losses.
+
+``COLPALI_SCORES_BACKEND`` (read per call) mirrors PyLate's ``PYLATE_SCORES_BACKEND``:
+``auto`` (default) uses LIK when eligible and silently falls back to torch, ``torch`` forces
+the reference, ``lik`` requires the kernel and raises ``LIKUnsupportedError`` when it cannot run.
+
+Pad tokens must be exactly zero: both paths rely on zero-padding, not an explicit mask.
+"""
+
+import os
+
+import torch
+
+_VALID_BACKENDS = ("auto", "torch", "lik")
+
+
+def _resolve_backend() -> str:
+    """Read ``COLPALI_SCORES_BACKEND`` at call time so it can be flipped at runtime."""
+    backend = os.environ.get("COLPALI_SCORES_BACKEND", "auto").lower()
+    if backend not in _VALID_BACKENDS:
+        raise ValueError(f"COLPALI_SCORES_BACKEND must be one of {_VALID_BACKENDS}, got {backend!r}.")
+    return backend
+
+
+def _torch_maxsim(query: torch.Tensor, doc: torch.Tensor) -> torch.Tensor:
+    """Reference in-batch MaxSim: ``einsum("bnd,csd->bcns").amax(-1).sum(-1)``."""
+    return torch.einsum("bnd,csd->bcns", query, doc).amax(dim=3).sum(dim=2)
+
+
+def _torch_maxsim_kd(query: torch.Tensor, doc: torch.Tensor) -> torch.Tensor:
+    """Reference KD MaxSim over ``doc [B, N, L, d]`` (each query has its own N candidates)."""
+    return torch.einsum("bnd,bksd->bkns", query, doc).amax(dim=3).sum(dim=2)
+
+
+def maxsim_inbatch(query: torch.Tensor, doc: torch.Tensor) -> torch.Tensor:
+    """In-batch MaxSim: ``[B_q, B_d]`` scores from zero-padded ``query [B_q, L_q, d]`` and ``doc [B_d, L_d, d]``."""
+    backend = _resolve_backend()
+    if backend != "torch":
+        from colpali_engine.utils import _lik_backend
+
+        try:
+            return _lik_backend.maxsim_inbatch_lik(query, doc)
+        except _lik_backend.LIKUnsupportedError:
+            if backend == "lik":
+                raise
+    return _torch_maxsim(query, doc)
+
+
+def maxsim_kd(query: torch.Tensor, doc: torch.Tensor) -> torch.Tensor:
+    """Per-query candidate MaxSim: ``[B, N]`` from zero-padded ``query [B, L_q, d]`` and ``doc [B, N, L_d, d]``."""
+    backend = _resolve_backend()
+    if backend != "torch":
+        from colpali_engine.utils import _lik_backend
+
+        try:
+            return _lik_backend.maxsim_kd_lik(query, doc)
+        except _lik_backend.LIKUnsupportedError:
+            if backend == "lik":
+                raise
+    return _torch_maxsim_kd(query, doc)
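
The docstrings above define MaxSim through einsum expressions and note that padding must be exactly zero because no mask is applied. A plain-Python sketch of the same computation on nested lists may make both points easier to follow; it is an illustrative reference written for this page, not the library implementation, and the names `dot`, `maxsim`, and `maxsim_inbatch` below are local to the sketch.

```python
from typing import List

Token = List[float]


def dot(a: Token, b: Token) -> float:
    """Dot product of two equal-length token embeddings."""
    return sum(x * y for x, y in zip(a, b))


def maxsim(query: List[Token], doc: List[Token]) -> float:
    """For each query token take the best dot product over doc tokens, then sum."""
    return sum(max(dot(q, d) for d in doc) for q in query)


def maxsim_inbatch(queries: List[List[Token]], docs: List[List[Token]]) -> List[List[float]]:
    """Score every query against every doc: result[i][j] = maxsim(queries[i], docs[j])."""
    return [[maxsim(q, d) for d in docs] for q in queries]


if __name__ == "__main__":
    query = [[1.0, 0.0], [0.0, 1.0]]
    doc = [[0.5, 0.5], [1.0, 0.0]]
    pad = [0.0, 0.0]
    print(maxsim(query, doc))
    print(maxsim(query + [pad], doc))  # zero query pad adds max(0, 0) = 0
    print(maxsim([[1.0, 0.0]], [[-1.0, 0.0]]))
    print(maxsim([[1.0, 0.0]], [[-1.0, 0.0], pad]))  # zero doc pad floors the max at 0
```

In this sketch a zero query row contributes exactly 0 to the sum, while a zero doc row only floors each per-token maximum at 0, which changes the score only when every real similarity for that query token is negative.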
3 changes: 2 additions & 1 deletion colpali_engine/utils/processing_utils.py
@@ -15,6 +15,7 @@
         "FastPlaid is not installed.If you want to use it:Instal with `pip install --no-deps fast-plaid fastkmeans`"
     )

+from colpali_engine.utils.maxsim import maxsim_inbatch
 from colpali_engine.utils.torch_utils import get_torch_device


@@ -176,7 +177,7 @@ def score_multi_vector(
                 ps_batch = torch.nn.utils.rnn.pad_sequence(
                     ps[j : j + batch_size], batch_first=True, padding_value=0
                 ).to(device)
-                scores_batch.append(torch.einsum("bnd,csd->bcns", qs_batch, ps_batch).max(dim=3)[0].sum(dim=2))
+                scores_batch.append(maxsim_inbatch(qs_batch, ps_batch))
             scores_batch = torch.cat(scores_batch, dim=1).cpu()
             scores_list.append(scores_batch)

3 changes: 3 additions & 0 deletions pyproject.toml
@@ -60,11 +60,14 @@ interpretability = [
     "seaborn>=0.13.2,<1.0.0",
 ]

+lik = ["late-interaction-kernels>=0.4.1,<0.5.0"]
+
 dev = ["pytest>=8.0.0", "ruff>=0.4.0"]

 all = [
     "colpali-engine[dev]",
     "colpali-engine[interpretability]",
+    "colpali-engine[lik]",
     "colpali-engine[train]",
 ]
