Skip to content

Commit a3c8a7c

Browse files
Merge pull request #114 from DataScienceUIBK/copilot/test-reranker-implementation
feat: Add DuoT5, RankLLaMA, and DeAR rerankers
2 parents c1ac072 + 68fac06 commit a3c8a7c

7 files changed

Lines changed: 1398 additions & 3 deletions

File tree

rankify/models/dear_reranker.py

Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
import copy
2+
from typing import List, Optional
3+
4+
import torch
5+
from tqdm import tqdm
6+
from transformers import AutoModelForSequenceClassification, AutoTokenizer
7+
8+
from rankify.dataset.dataset import Context, Document
9+
from rankify.models.base import BaseRanking
10+
11+
12+
class DeARReranker(BaseRanking):
13+
"""
14+
Implements **DeAR (Dual-Stage Document Reranking)**, a family of
15+
efficient pointwise rerankers based on **LLaMA-3.2** and trained with
16+
**Binary Cross-Entropy** (BCE) or **RankNet** loss via knowledge
17+
distillation from a large teacher model.
18+
19+
The model scores query–document pairs using the prompt format:
20+
21+
.. code-block::
22+
23+
query: <query> [SEP] document: <document>
24+
25+
Multiple DeAR variants are supported (3B-CE, 3B-RankNet, 8B-CE, LoRA).
26+
27+
References:
28+
- **Abdallah et al. (2025)**: *DeAR: Dual-Stage Document Reranking
29+
with Reasoning Agents via LLM Distillation*.
30+
[Paper](https://arxiv.org/abs/2508.16998)
31+
32+
Attributes:
33+
method (str): The name of the reranking method.
34+
model_name (str): HuggingFace model identifier.
35+
device (str): Computation device (``"cuda"`` or ``"cpu"``).
36+
tokenizer (AutoTokenizer): Tokenizer for the DeAR model.
37+
model (AutoModelForSequenceClassification): The DeAR reranking model.
38+
batch_size (int): Batch size for inference.
39+
max_length (int): Maximum tokenisation length (default 228 per paper).
40+
41+
Example:
42+
```python
43+
from rankify.dataset.dataset import Document, Question, Answer, Context
44+
from rankify.models.reranking import Reranking
45+
46+
question = Question("When did Thomas Edison invent the light bulb?")
47+
answers = Answer(["1879"])
48+
contexts = [
49+
Context(text="Lightning strike at Seoul National University", id=1),
50+
Context(text="Thomas Edison invented the light bulb in 1879", id=2),
51+
Context(text="Coffee is good for diet", id=3),
52+
]
53+
document = Document(question=question, answers=answers, contexts=contexts)
54+
55+
model = Reranking(method='dear_reranker', model_name='dear-3b-reranker-ce-v1')
56+
model.rank([document])
57+
58+
for ctx in document.reorder_contexts:
59+
print(ctx.text)
60+
```
61+
"""
62+
63+
def __init__(
64+
self,
65+
method: str = None,
66+
model_name: str = None,
67+
api_key: str = None,
68+
**kwargs,
69+
):
70+
"""
71+
Initialises **DeARReranker**.
72+
73+
Args:
74+
method (str, optional): Reranking method name.
75+
model_name (str): HuggingFace model identifier
76+
(e.g. ``"abdoelsayed/dear-3b-reranker-ce-v1"``).
77+
api_key (str, optional): Unused; present for framework consistency.
78+
**kwargs:
79+
- device (str): ``"cuda"`` or ``"cpu"``. Default: auto-detect.
80+
- batch_size (int): Inference batch size. Default: ``32``.
81+
- max_length (int): Max tokenisation length. Default: ``228``.
82+
- dtype: Torch dtype. Default: ``bfloat16`` on CUDA, ``float32`` on CPU.
83+
"""
84+
self.method = method
85+
self.model_name = model_name
86+
87+
device_str = kwargs.get(
88+
"device", "cuda" if torch.cuda.is_available() else "cpu"
89+
)
90+
self.device = device_str
91+
self.batch_size = kwargs.get("batch_size", 32)
92+
# Paper trains at max_length=228; expose as a tunable kwarg
93+
self.max_length = kwargs.get("max_length", 228)
94+
95+
# Dtype: bfloat16 on GPU (matches paper), float32 on CPU
96+
if "dtype" in kwargs:
97+
dtype = kwargs["dtype"]
98+
elif device_str == "cuda":
99+
dtype = torch.bfloat16
100+
else:
101+
dtype = torch.float32
102+
103+
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
104+
if self.tokenizer.pad_token is None:
105+
self.tokenizer.pad_token = self.tokenizer.eos_token
106+
107+
self.model = AutoModelForSequenceClassification.from_pretrained(
108+
model_name,
109+
torch_dtype=dtype,
110+
device_map="auto" if device_str == "cuda" else None,
111+
)
112+
if device_str != "cuda":
113+
self.model = self.model.to(device_str)
114+
self.model.eval()
115+
116+
# ------------------------------------------------------------------
117+
# Public interface
118+
# ------------------------------------------------------------------
119+
120+
@torch.inference_mode()
121+
def rank(self, documents: List[Document]) -> List[Document]:
122+
"""
123+
Reranks contexts within each document using **DeAR** relevance scores.
124+
125+
Args:
126+
documents (List[Document]): Documents whose contexts to rerank.
127+
128+
Returns:
129+
List[Document]: Documents with updated ``reorder_contexts``.
130+
"""
131+
for document in tqdm(documents, desc="Reranking Documents"):
132+
query = document.question.question
133+
contexts = copy.deepcopy(document.contexts)
134+
135+
query_texts = [f"query: {query}"] * len(contexts)
136+
doc_texts = [f"document: {ctx.text}" for ctx in contexts]
137+
138+
scores = self._score_batched(query_texts, doc_texts)
139+
140+
for ctx, score in zip(contexts, scores):
141+
ctx.score = score
142+
143+
document.reorder_contexts = sorted(
144+
contexts, key=lambda x: x.score, reverse=True
145+
)
146+
147+
return documents
148+
149+
# ------------------------------------------------------------------
150+
# Internal helpers
151+
# ------------------------------------------------------------------
152+
153+
def _score_batched(
154+
self,
155+
query_texts: List[str],
156+
doc_texts: List[str],
157+
) -> List[float]:
158+
"""
159+
Compute relevance scores for pre-formatted ``(query, document)`` pairs.
160+
161+
Args:
162+
query_texts: Already-formatted query strings (``"query: …"``).
163+
doc_texts: Already-formatted document strings (``"document: …"``).
164+
165+
Returns:
166+
List of float scores, one per pair.
167+
"""
168+
scores: List[float] = []
169+
for start in range(0, len(query_texts), self.batch_size):
170+
q_batch = query_texts[start : start + self.batch_size]
171+
d_batch = doc_texts[start : start + self.batch_size]
172+
173+
tokenized = self.tokenizer(
174+
q_batch,
175+
d_batch,
176+
return_tensors="pt",
177+
padding=True,
178+
truncation=True,
179+
max_length=self.max_length,
180+
)
181+
tokenized = {
182+
k: v.to(self.model.device) for k, v in tokenized.items()
183+
}
184+
185+
logits = self.model(**tokenized).logits # (batch, 1)
186+
batch_scores = logits.squeeze(-1).cpu().tolist()
187+
188+
if isinstance(batch_scores, float):
189+
scores.append(batch_scores)
190+
else:
191+
scores.extend(batch_scores)
192+
193+
return scores

0 commit comments

Comments
 (0)