Skip to content

Commit 97f2596

Browse files
committed
Fixed from_embeddings bug
1 parent 4c58279 commit 97f2596

2 files changed

Lines changed: 85 additions & 2 deletions

File tree

semhash/semhash.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,10 +143,19 @@ def from_embeddings(
143143
break
144144

145145
# Build index mapping for embeddings (accounting for removed exact duplicates)
146+
# We need to keep only the first occurrence of each unique record
147+
col_set = set(columns)
148+
seen_hashes = set()
146149
embedding_indices = []
150+
147151
for i, record in enumerate(dict_records):
148-
if record in deduplicated_records:
149-
embedding_indices.append(i)
152+
frozen = to_frozendict(record, col_set)
153+
# Only keep the first occurrence of each unique record
154+
if frozen not in seen_hashes:
155+
# Check if this record hash is in the deduplicated set
156+
if any(to_frozendict(dedup_rec, col_set) == frozen for dedup_rec in deduplicated_records):
157+
embedding_indices.append(i)
158+
seen_hashes.add(frozen)
150159

151160
# Select embeddings for non-exact-duplicate records
152161
deduplicated_embeddings = embeddings[embedding_indices]

tests/test_semhash.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,3 +233,77 @@ def test_from_embeddings(model: Encoder, train_texts: list[str]) -> None:
233233

234234
assert len(result1.selected) == len(result2.selected)
235235
assert len(result1.filtered) == len(result2.filtered)
236+
237+
238+
def test_from_embeddings_with_exact_duplicates(model: Encoder) -> None:
239+
"""
240+
Test that from_embeddings correctly handles exact duplicates in input.
241+
242+
This is a regression test for Issue #1: the bug where duplicate records
243+
would cause duplicate embeddings to be kept in the index.
244+
"""
245+
# Create records with exact duplicates
246+
records = [
247+
"apple", # 0
248+
"banana", # 1
249+
"apple", # 2 - duplicate of 0
250+
"cherry", # 3
251+
"banana", # 4 - duplicate of 1
252+
"date", # 5
253+
]
254+
255+
# Generate embeddings for all records (including duplicates)
256+
embeddings = model.encode(records)
257+
258+
# Create SemHash from embeddings
259+
semhash = SemHash.from_embeddings(embeddings=embeddings, records=records, model=model)
260+
261+
# The index should only contain 4 unique records (apple, banana, cherry, date)
262+
assert len(semhash.index.vectors) == 4, f"Expected 4 unique vectors, got {len(semhash.index.vectors)}"
263+
assert len(semhash.index.items) == 4, f"Expected 4 items, got {len(semhash.index.items)}"
264+
265+
# Verify that duplicates are grouped correctly
266+
# Each item is a list of exact duplicates
267+
items_by_text = {}
268+
for item in semhash.index.items:
269+
text = item[0]["text"]
270+
items_by_text[text] = len(item)
271+
272+
# apple and banana should have 2 records each (original + duplicate)
273+
# cherry and date should have 1 record each
274+
assert items_by_text["apple"] == 2, "apple should have 2 records"
275+
assert items_by_text["banana"] == 2, "banana should have 2 records"
276+
assert items_by_text["cherry"] == 1, "cherry should have 1 record"
277+
assert items_by_text["date"] == 1, "date should have 1 record"
278+
279+
# Verify embeddings correspond to first occurrences
280+
# The vectors should match embeddings at indices [0, 1, 3, 5]
281+
# (order may vary in the index, so we can't do exact comparison)
282+
# but the count should be correct
283+
assert semhash.index.vectors.shape[0] == 4
284+
285+
286+
def test_from_embeddings_dict_records_with_duplicates(model: Encoder) -> None:
287+
"""Test that from_embeddings handles duplicates correctly with dictionary records."""
288+
records = [
289+
{"id": "1", "text": "apple"},
290+
{"id": "2", "text": "banana"},
291+
{"id": "3", "text": "apple"}, # Duplicate based on 'text' column
292+
{"id": "4", "text": "cherry"},
293+
]
294+
295+
# Generate embeddings
296+
texts = [r["text"] for r in records]
297+
embeddings = model.encode(texts)
298+
299+
# Create SemHash using only 'text' column for deduplication
300+
semhash = SemHash.from_embeddings(embeddings=embeddings, records=records, columns=["text"], model=model)
301+
302+
# Should have 3 unique 'text' values
303+
assert len(semhash.index.vectors) == 3
304+
assert len(semhash.index.items) == 3
305+
306+
# Find the item with "apple" text
307+
apple_items = [item for item in semhash.index.items if item[0]["text"] == "apple"]
308+
assert len(apple_items) == 1, "Should find exactly one item group for 'apple'"
309+
assert len(apple_items[0]) == 2, "The 'apple' item should contain 2 records"

0 commit comments

Comments
 (0)