Skip to content

Commit 34c1f9c

Browse files
committed
Refactored SemHash, moved more functions to utils
1 parent a8f1420 commit 34c1f9c

4 files changed

Lines changed: 119 additions & 109 deletions

File tree

semhash/datamodels.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,11 @@
55
from collections.abc import Hashable, Sequence
66
from dataclasses import dataclass, field
77
from functools import cached_property
8-
from typing import Any, Generic, TypeAlias, TypeVar
8+
from typing import Any, Generic
99

1010
from frozendict import frozendict
1111

12-
from semhash.utils import to_frozendict
13-
14-
Record = TypeVar("Record", str, dict[str, Any])
15-
DuplicateList: TypeAlias = list[tuple[Record, float]]
12+
from semhash.utils import DuplicateList, Record, to_frozendict
1613

1714

1815
@dataclass

semhash/semhash.py

Lines changed: 17 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,14 @@
1414
from semhash.datamodels import DeduplicationResult, DuplicateRecord, FilterResult, Record
1515
from semhash.index import Index
1616
from semhash.records import add_scores_to_records, map_deduplication_result_to_strings
17-
from semhash.utils import Encoder, compute_candidate_limit, to_frozendict
17+
from semhash.utils import (
18+
Encoder,
19+
compute_candidate_limit,
20+
featurize,
21+
prepare_records,
22+
remove_exact_duplicates,
23+
to_frozendict,
24+
)
1825

1926

2027
class SemHash(Generic[Record]):
@@ -33,95 +40,6 @@ def __init__(self, index: Index, model: Encoder, columns: Sequence[str], was_str
3340
self._was_string = was_string
3441
self._ranking_cache: FilterResult | None = None
3542

36-
@staticmethod
37-
def _featurize(
38-
records: Sequence[dict[str, str]],
39-
columns: Sequence[str],
40-
model: Encoder,
41-
) -> np.ndarray:
42-
"""
43-
Featurize a list of records using the model.
44-
45-
:param records: A list of records.
46-
:param columns: Columns to featurize.
47-
:param model: An Encoder model.
48-
:return: The embeddings of the records.
49-
"""
50-
# Extract the embeddings for each column across all records
51-
embeddings_per_col = []
52-
for col in columns:
53-
col_texts = [r[col] for r in records]
54-
col_emb = model.encode(col_texts)
55-
embeddings_per_col.append(np.asarray(col_emb))
56-
57-
return np.concatenate(embeddings_per_col, axis=1)
58-
59-
@classmethod
60-
def _remove_exact_duplicates(
61-
cls,
62-
records: Sequence[dict[str, str]],
63-
columns: Sequence[str],
64-
reference_records: list[list[dict[str, str]]] | None = None,
65-
) -> tuple[list[dict[str, str]], list[tuple[dict[str, str], list[dict[str, str]]]]]:
66-
"""
67-
Remove exact duplicates based on the unpacked string representation of each record.
68-
69-
If reference_records is None, the function will only check for duplicates within the records list.
70-
71-
:param records: A list of records to check for exact duplicates.
72-
:param columns: Columns to unpack.
73-
:param reference_records: A list of records to compare against. These are already unpacked
74-
:return: A list of deduplicated records and a list of duplicates.
75-
"""
76-
deduplicated = []
77-
duplicates = []
78-
79-
column_set = set(columns)
80-
# Build a seen set from reference_records if provided
81-
seen: defaultdict[frozendict[str, str], list[dict[str, str]]] = defaultdict(list)
82-
if reference_records is not None:
83-
for record_set in reference_records:
84-
key = to_frozendict(record_set[0], column_set)
85-
seen[key] = list(record_set)
86-
in_one_set = reference_records is None
87-
88-
for record in records:
89-
frozen_record = frozendict({k: v for k, v in record.items() if k in column_set})
90-
if duplicated_records := seen.get(frozen_record):
91-
duplicates.append((record, duplicated_records))
92-
else:
93-
deduplicated.append(record)
94-
# Only add current documents to seen if no reference set is used
95-
if in_one_set:
96-
seen[frozen_record].append(record)
97-
98-
return deduplicated, duplicates
99-
100-
@staticmethod
101-
def _prepare_records(
102-
records: Sequence[Record], columns: Sequence[str] | None
103-
) -> tuple[list[dict[str, str]], Sequence[str], bool]:
104-
"""
105-
Validate and prepare records for processing.
106-
107-
:param records: A list of records (strings or dictionaries).
108-
:param columns: Columns to use if records are dictionaries.
109-
:return: Tuple of (dict_records, columns, was_string).
110-
:raises ValueError: If columns are not provided for dictionary records.
111-
"""
112-
if columns is None and isinstance(records[0], dict):
113-
raise ValueError("Columns must be specified when passing dictionaries.")
114-
115-
if isinstance(records[0], str):
116-
columns = ["text"]
117-
dict_records: list[dict[str, str]] = [{"text": str(record)} for record in records]
118-
was_string = True
119-
else:
120-
dict_records = list(records)
121-
was_string = False
122-
123-
return dict_records, columns, was_string
124-
12543
@classmethod
12644
def from_embeddings(
12745
cls,
@@ -152,10 +70,10 @@ def from_embeddings(
15270
raise ValueError(f"Number of embeddings ({len(embeddings)}) must match number of records ({len(records)})")
15371

15472
# Prepare and validate records
155-
dict_records, columns, was_string = cls._prepare_records(records, columns)
73+
dict_records, columns, was_string = prepare_records(records, columns)
15674

15775
# Remove exact duplicates
158-
deduplicated_records, exact_duplicates = cls._remove_exact_duplicates(dict_records, columns)
76+
deduplicated_records, exact_duplicates = remove_exact_duplicates(dict_records, columns)
15977

16078
# Build items list. Each item is a list of exact duplicates
16179
items: list[list[dict[str, str]]] = [[record] for record in deduplicated_records]
@@ -208,14 +126,14 @@ def from_records(
208126
:return: A SemHash instance with a fitted vicinity index.
209127
"""
210128
# Prepare and validate records
211-
dict_records, columns, was_string = cls._prepare_records(records, columns)
129+
dict_records, columns, was_string = prepare_records(records, columns)
212130

213131
# If no model is provided, load the default model
214132
if model is None:
215133
model = StaticModel.from_pretrained("minishlab/potion-base-8M")
216134

217135
# Remove exact duplicates
218-
deduplicated_records, duplicates = cls._remove_exact_duplicates(dict_records, columns)
136+
deduplicated_records, duplicates = remove_exact_duplicates(dict_records, columns)
219137

220138
col_set = set(columns)
221139
duplicate_map = defaultdict(list)
@@ -231,7 +149,7 @@ def from_records(
231149
items.append(i)
232150

233151
# Create embeddings for deduplicated records only
234-
embeddings = cls._featurize(deduplicated_records, columns, model)
152+
embeddings = featurize(deduplicated_records, columns, model)
235153

236154
# Build the Vicinity index
237155
backend = ann_backend if use_ann else Backend.BASIC
@@ -263,7 +181,7 @@ def deduplicate(
263181
dict_records = self._validate_if_strings(records)
264182

265183
# Remove exact duplicates before embedding
266-
dict_records, exact_duplicates = self._remove_exact_duplicates(
184+
dict_records, exact_duplicates = remove_exact_duplicates(
267185
records=dict_records, columns=self.columns, reference_records=self.index.items
268186
)
269187
duplicate_records = []
@@ -279,7 +197,7 @@ def deduplicate(
279197
)
280198

281199
# Compute embeddings for the new records
282-
embeddings = self._featurize(records=dict_records, columns=self.columns, model=self.model)
200+
embeddings = featurize(records=dict_records, columns=self.columns, model=self.model)
283201
# Query the fitted index
284202
results = self.index.query_threshold(embeddings, threshold=threshold)
285203

@@ -536,7 +454,7 @@ def _rank_by_average_similarity(
536454
:return: A FilterResult containing the ranking (records sorted and their average similarity scores).
537455
"""
538456
dict_records = self._validate_if_strings(records)
539-
embeddings = self._featurize(records=dict_records, columns=self.columns, model=self.model)
457+
embeddings = featurize(records=dict_records, columns=self.columns, model=self.model)
540458
results = self.index.query_top_k(embeddings, k=100, vectors_are_in_index=False)
541459

542460
# Compute the average similarity for each record.
@@ -600,7 +518,7 @@ def _diversify(
600518
if not candidates:
601519
return FilterResult(selected=[], filtered=[], scores_selected=[], scores_filtered=[])
602520

603-
embeddings = self._featurize(records=candidates, columns=self.columns, model=self.model)
521+
embeddings = featurize(records=candidates, columns=self.columns, model=self.model)
604522
result = diversify(
605523
embeddings=embeddings,
606524
scores=np.array(relevance),

semhash/utils.py

Lines changed: 94 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
1+
from collections import defaultdict
12
from collections.abc import Sequence
2-
from typing import Any, Protocol
3+
from typing import Any, Protocol, TypeAlias, TypeVar
34

45
import numpy as np
56
from frozendict import frozendict
67

8+
# Type definitions
9+
Record = TypeVar("Record", str, dict[str, Any])
10+
DuplicateList: TypeAlias = list[tuple[Record, float]]
11+
712

813
class Encoder(Protocol):
914
"""An encoder protocol for SemHash."""
@@ -54,3 +59,91 @@ def compute_candidate_limit(
5459
# 4) enforce upper bound (and never exceed the dataset)
5560
limit = min(limit, max_candidates, total)
5661
return limit
62+
63+
64+
def featurize(
    records: Sequence[dict[str, str]],
    columns: Sequence[str],
    model: Encoder,
) -> np.ndarray:
    """
    Featurize a list of records using the model.

    Each column is encoded with one `model.encode` call over the texts of all records,
    and the per-column embeddings are concatenated along axis 1 in the order given by
    `columns`. This assumes `model.encode` returns a 2-D array-like with one row per text.

    :param records: A list of records. Every record must contain all of `columns`.
    :param columns: Columns to featurize. Must not be empty.
    :param model: An Encoder model.
    :return: The embeddings of the records.
    :raises ValueError: If no columns are given.
    """
    if len(columns) == 0:
        # Without this guard np.concatenate fails on an empty list with a message
        # that does not point at the missing columns.
        raise ValueError("At least one column is required to featurize records.")

    # One encode call per column, so the model sees all texts of a column as a batch.
    embeddings_per_col = [np.asarray(model.encode([record[col] for record in records])) for col in columns]

    return np.concatenate(embeddings_per_col, axis=1)
85+
86+
87+
def remove_exact_duplicates(
    records: Sequence[dict[str, str]],
    columns: Sequence[str],
    reference_records: list[list[dict[str, str]]] | None = None,
) -> tuple[list[dict[str, str]], list[tuple[dict[str, str], list[dict[str, str]]]]]:
    """
    Remove exact duplicates based on the unpacked string representation of each record.

    Two records count as exact duplicates when `to_frozendict` yields the same key for
    both, given the selected columns.

    If reference_records is None, the function will only check for duplicates within the records list:
    the first occurrence of a key is kept and later occurrences are reported as duplicates of it.
    If reference_records is given, records are only compared against the reference groups, never
    against each other.

    :param records: A list of records to check for exact duplicates.
    :param columns: Columns to unpack.
    :param reference_records: A list of record groups to compare against. These are already unpacked.
        Each group is keyed by its first record; empty groups are ignored.
    :return: A list of deduplicated records and a list of duplicates. Each duplicate is a tuple of
        the duplicated record and the list of records it duplicates.
    """
    column_set = set(columns)
    # Records are only remembered as they are seen when no reference set is used.
    in_one_set = reference_records is None

    # Map each record key to the records already known under that key.
    seen: dict[frozendict[str, str], list[dict[str, str]]] = {}
    for record_set in reference_records or []:
        if not record_set:
            # Nothing to derive a key from.
            continue
        seen[to_frozendict(record_set[0], column_set)] = list(record_set)

    deduplicated = []
    duplicates = []
    for record in records:
        # Same helper as for the reference keys, so both sides always build comparable keys.
        key = to_frozendict(record, column_set)
        if duplicated_records := seen.get(key):
            duplicates.append((record, duplicated_records))
            continue

        deduplicated.append(record)
        if in_one_set:
            seen[key] = [record]

    return deduplicated, duplicates
125+
126+
127+
def prepare_records(
    records: Sequence[Record], columns: Sequence[str] | None
) -> tuple[list[dict[str, str]], Sequence[str], bool]:
    """
    Validate and prepare records for processing.

    The type of the first record decides how the whole input is treated. String records are
    wrapped into dictionaries with a single "text" key and the columns become ["text"],
    regardless of the `columns` argument. Other records are copied into a new list as they are.

    :param records: A list of records (strings or dictionaries). Must not be empty.
    :param columns: Columns to use if records are dictionaries.
    :return: Tuple of (dict_records, columns, was_string).
    :raises ValueError: If records is empty, or if columns are not provided for dictionary records.
    """
    if len(records) == 0:
        # The input kind is inferred from the first record, so there must be one.
        raise ValueError("Records must not be empty.")

    first_record = records[0]

    if isinstance(first_record, str):
        return [{"text": str(record)} for record in records], ["text"], True

    if columns is None and isinstance(first_record, dict):
        raise ValueError("Columns must be specified when passing dictionaries.")

    return list(records), columns, False

tests/test_semhash.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -189,22 +189,24 @@ def test_self_filter_outliers(use_ann: bool, model: Encoder, train_texts: list[s
189189
def test__diversify(monkeypatch: pytest.MonkeyPatch) -> None:
190190
"""Test the _diversify method."""
191191
# Create a dummy SemHash instance
192-
semhash = SemHash(index=None, model=None, columns=["text"], was_string=True) # type: ignore
192+
from semhash import semhash as semhash_module
193+
194+
semhash_instance = SemHash(index=None, model=None, columns=["text"], was_string=True) # type: ignore
193195
# Prepare a fake ranking with three records
194196
records = ["a", "b", "c"]
195197
scores = [3.0, 2.0, 1.0]
196198
ranking = FilterResult(selected=records, filtered=[], scores_selected=scores, scores_filtered=[])
197199
# Create dummy embeddings for the records
198200
embeddings = np.array([[1.0, 0.0], [0.5, 0.5], [0.0, 1.0]])
199201
# Monkeypatch featurize to return the dummy embeddings
200-
monkeypatch.setattr(semhash, "_featurize", lambda records, columns, model: embeddings)
202+
monkeypatch.setattr(semhash_module, "featurize", lambda records, columns, model: embeddings)
201203

202204
# Test diversity=0.0: pure relevance, should pick top 2 by score
203-
result_rel = semhash._diversify(ranking, candidate_limit=3, selection_size=2, diversity=0.0)
205+
result_rel = semhash_instance._diversify(ranking, candidate_limit=3, selection_size=2, diversity=0.0)
204206
assert result_rel.selected == ["a", "b"]
205207

206208
# Test diversity=1.0: pure diversity, should first pick 'a', then pick most dissimilar: 'c'
207-
result_div = semhash._diversify(ranking, candidate_limit=3, selection_size=2, diversity=1.0)
209+
result_div = semhash_instance._diversify(ranking, candidate_limit=3, selection_size=2, diversity=1.0)
208210
assert result_div.selected == ["a", "c"]
209211

210212

0 commit comments

Comments
 (0)